2019年6月11日火曜日

時系列のクロスバリデーション

tidymodels に含まれる rsample パッケージの rolling_origin という関数が便利そうだったので試してみた内容を記載してみる。表題の通り時系列データに対するクロスバリデーションを行う。

この関数を知る切っ掛けになったのがこちら。ありがとうございます。

https://blog.hoxo-m.com/entry/2019/06/08/220307

rolling_origin 関数


時系列データのクロスバリデーションに関する考え方はこちらのサイト Rでのナウなデータ分割のやり方: rsampleパッケージによる交差検証 や書籍 前処理大全 に任せるとして、rolling_origin 関数の skip 引数の意味が分かりづらかったので少し試してみる事にする。

まず、今回使用するデータを準備。
"S4248SM144NCEN" というカラム名が何かは知らないけどデータの名称からして恐らく飲酒量とかその辺りだと思われるのでとりあえず "amount" という項目名に変更しておく。

library(tidyverse)
library(tidymodels)

df.drinks <- rsample::drinks %>%
  dplyr::rename(amount = S4248SM144NCEN)

# date amount
1 1992-01-01 3,459
2 1992-02-01 3,458
3 1992-03-01 4,002

このデータは 1992/01/01 から 2017/09/01 まで続く月次データであり、全部で 309 件のレコードが存在する。
このデータを下記の要件にて分割する事を考えてみる。
  • 学習期間: 24ヶ月
  • 検証期間: 12ヶ月
  • スライド: 1ヶ月

上記を実現したのがこちら
df.rolling <- rsample::rolling_origin(
  df.drinks,
  initial = 24,
  assess = 12,
  skip = 0,
  cumulative = F
)

df.rolling
# Rolling origin forecast resampling 
# A tibble: 274 x 2
   splits          id      
                
 1 <split [24/12]> Slice001
 2 <split [24/12]> Slice002
 3 <split [24/12]> Slice003
 4 <split [24/12]> Slice004
 5 <split [24/12]> Slice005
 6 <split [24/12]> Slice006
 7 <split [24/12]> Slice007
 8 <split [24/12]> Slice008
 9 <split [24/12]> Slice009
10 <split [24/12]> Slice010
# … with 264 more rows

学習データの 1 番目を取り出してみた。確かに 1992/1/1 を含む 24 ヶ月間のデータとなっている。
df.rolling$splits[[1]] %>%
  rsample::analysis()

# A tibble: 24 x 2
   date       amount
         
 1 1992-01-01   3459
 2 1992-02-01   3458
 3 1992-03-01   4002
 4 1992-04-01   4564
 5 1992-05-01   4221
 6 1992-06-01   4529
 7 1992-07-01   4466
 8 1992-08-01   4137
 9 1992-09-01   4126
10 1992-10-01   4259
# … with 14 more rows

検証用データの 1 番目を取り出してみる。こちらも想定通り 1994 年の 12 ヶ月分のデータとなっている。
df.rolling$splits[[1]] %>%
  rsample::assessment()

# A tibble: 12 x 2
   date       amount
         
 1 1994-01-01   3075
 2 1994-02-01   3377
 3 1994-03-01   4443
 4 1994-04-01   4261
 5 1994-05-01   4460
 6 1994-06-01   4985
 7 1994-07-01   4324
 8 1994-08-01   4719
 9 1994-09-01   4374
10 1994-10-01   4248
11 1994-11-01   4784
12 1994-12-01   4971

2 番目の学習用/検証用のデータも取り出してみる。想定通り 1 ヶ月スライドしたデータとなっている。
df.rolling$splits[[2]] %>%
  rsample::analysis()
# A tibble: 24 x 2
   date       amount
         
 1 1992-02-01   3458
 2 1992-03-01   4002
 3 1992-04-01   4564
 4 1992-05-01   4221
 5 1992-06-01   4529
 6 1992-07-01   4466
 7 1992-08-01   4137
 8 1992-09-01   4126
 9 1992-10-01   4259
10 1992-11-01   4240
# … with 14 more rows

df.rolling$splits[[2]] %>%
  rsample::assessment()
# A tibble: 12 x 2
   date       amount
         
 1 1994-02-01   3377
 2 1994-03-01   4443
 3 1994-04-01   4261
 4 1994-05-01   4460
 5 1994-06-01   4985
 6 1994-07-01   4324
 7 1994-08-01   4719
 8 1994-09-01   4374
 9 1994-10-01   4248
10 1994-11-01   4784
11 1994-12-01   4971
12 1995-01-01   3370

ここで最初に疑問に思ったのが rolling_origin 関数の skip 引数の仕様である。例えば下記のようなデータ仕様を考えてみる。
  • 学習期間: 24ヶ月
  • 検証期間: 12ヶ月
  • スライド: 12ヶ月
つまり 1 年おきにデータを学習/予測(ex. 年の始めに 2 年分のデータを使って直近 1 年間の予測を行う)するようなケースとなる。 この時に直感的には下記のようなコードを想定していた。
df.rolling2 <- rsample::rolling_origin(
  df.drinks,
  initial = 24,
  assess = 12,
  skip = 12,
  cumulative = F
)

12 ヶ月データをスキップするので skip 引数に 12 を指定しているが、これだと下記のように意図した通りにはならない。 2 番目の学習データが 1993/1/1 ではなく 1993/2/1 からのスタートになってしまっており、実際には 13 ヶ月のスキップとなってしまっている。
df.rolling2$splits[[2]] %>%
  rsample::analysis()
# A tibble: 24 x 2
   date       amount
         
 1 1993-02-01   3261
 2 1993-03-01   4160
 3 1993-04-01   4377
 4 1993-05-01   4307
 5 1993-06-01   4696
 6 1993-07-01   4458
 7 1993-08-01   4457
 8 1993-09-01   4364
 9 1993-10-01   4236
10 1993-11-01   4500
# … with 14 more rows

これはそもそも rolling_origin 関数における skip 引数のデフォルト値が 0 であり、この指定によって 1 ヶ月おきのスキップを意図している事による。 つまりそこから更にずらす幅を追加したい時に skip 引数を使いなさいという事なのだろうなと。

これはドキュメントをちゃんと読むとそれっぽい事が記載されていた。ちゃんと読みなさいよと
When skip = 0, the resampling data sets will increment by one position.
あとから考えてみると skip = 0 の指定でスライドしないのであれば無限ループになってしまう orz

なので、上記の仕様であれば下記のコードが正解となる。
df.rolling3 <- rsample::rolling_origin(
  df.drinks,
  initial = 24,
  assess = 12,
  skip = 11, # 11 = 12 - 1
  cumulative = F
)

df.rolling3$splits[[2]] %>%
  rsample::analysis()
# A tibble: 24 x 2
   date       amount
         
 1 1993-01-01   3031
 2 1993-02-01   3261
 3 1993-03-01   4160
 4 1993-04-01   4377
 5 1993-05-01   4307
 6 1993-06-01   4696
 7 1993-07-01   4458
 8 1993-08-01   4457
 9 1993-09-01   4364
10 1993-10-01   4236
# … with 14 more rows

最初から複数月のスライドを意図したデータを作ろうとしていたのが失敗で、スライド幅 1 のケースから順番に試していれば混乱しなかったんだろうなと。難しい

クロスバリデーションによるモデル比較

せっかく rolling_origin の使い方がだいたい分かったので、これを用いて時系列モデルの比較を行ってみる事にする。

まず比較対象となるモデルの一覧を list で作成する。実際には当該リストに含まれるのは学習済みモデルを返す関数である事に注意。
lst.models <- list(
  # ローカルレベルモデル + 確定的季節成分
  d1_ss = function(df.data) {
    library(KFAS)
    model <- SSModel(
      amount ~
        SSMtrend(degree = 1, Q = NA) +
        SSMseasonal(period = 12, Q = 0),
      data = df.data,
      H = NA
    )
    fitSSM(
      model,
      inits = c(1, 1),
      method = "BFGS"
    )
  },

  # ローカルレベルモデル + 確率的季節成分
  d1_sv = function(df.data) {
    library(KFAS)
    model <- SSModel(
      amount ~
        SSMtrend(degree = 1, Q = NA) +
        SSMseasonal(period = 12, Q = NA),
      data = df.data,
      H = NA
    )
    fitSSM(
      model,
      inits = c(1, 1, 1),
      method = "BFGS"
    )
  },

  # ローカル線形トレンドモデル + 確定的季節成分
  d2_ss = function(df.data) {
    library(KFAS)
    model <- SSModel(
      amount ~
        SSMtrend(degree = 2, Q = list(matrix(NA), matrix(NA))) +
        SSMseasonal(period = 12, Q = 0),
      data = df.data,
      H = NA
    )
    fitSSM(
      model,
      inits = c(1, 1, 1),
      method = "BFGS"
    )
  },

  # ローカル線形トレンドモデル + 確率的季節成分
  d2_sv = function(df.data) {
    library(KFAS)
    model <- SSModel(
      amount ~
        SSMtrend(degree = 2, Q = list(matrix(NA), matrix(NA))) +
        SSMseasonal(period = 12, Q = NA),
      data = df.data,
      H = NA
    )
    fitSSM(
      model,
      inits = c(1, 1, 1, 1),
      method = "BFGS"
    )
  }
)

リストの要素を指定してデータを渡せば KFAS による学習済みモデルが返却される。
lst.models$d1_ss(df.drinks)

### ↓以下結果 ###

$optim.out
$optim.out$par
[1]  9.355875 12.207159

$optim.out$value
[1] 2290.374

$optim.out$counts
function gradient 
      29       10 

$optim.out$convergence
[1] 0

$optim.out$message
NULL


$model
Call:
SSModel(formula = amount ~ SSMtrend(degree = 1, Q = NA) + SSMseasonal(period = 12, 
    Q = 0), data = df.data, H = NA)

State space model object of class SSModel

Dimensions:
[1] Number of time points: 309
[1] Number of time series: 1
[1] Number of disturbances: 2
[1] Number of states: 12
Names of the states:
 [1]  level        sea_dummy1   sea_dummy2   sea_dummy3   sea_dummy4   sea_dummy5   sea_dummy6   sea_dummy7   sea_dummy8 
[10]  sea_dummy9   sea_dummy10  sea_dummy11
Distributions of the time series:
[1]  gaussian

Object is a valid object of class SSModel.

こんな感じに予測も可能。
predict(lst.models$d1_ss(df.drinks)$model)

Time Series:
Start = 1 
End = 309 
Frequency = 1 
             fit
  [1,]  2719.484
  [2,]  3162.725
  [3,]  4213.908
  [4,]  4128.718
  [5,]  4758.920 ...

上記を用いて下記のようにモデル毎に評価結果を返す関数を作って実験を行う。
評価指標は CV データ毎に算出した RMSE の平均を用いる事にする。
# h: 予測する範囲 ex. h=1 で 1 期先予測
# skip: スキップ数
rmse_per_model_cv <- function(h, skip = 0) {
  # 時系列 CV の作成
  df.rolling <- rsample::rolling_origin(df.drinks, initial = 24, assess = h,  cumulative = F, skip = skip)

  # 各モデルの適用
  purrr::map(lst.models, function(model) {

    # クロスバリデーション
    purrr::map_dbl(df.rolling$splits, function(split) {
      df.train <- rsample::analysis(split)
      df.test  <- rsample::assessment(split)

      # モデルの学習
      fit <- model(df.train)

      # 学習済みモデルによる予測と RMSE の算出
      # columns: date, amount, predicted
      df.test %>%
        dplyr::mutate(
          predicted = predict(fit$model, n.ahead = h)[, "fit"] %>% as.numeric()
        ) %>%
        yardstick::rmse(amount, predicted) %>%
        .$.estimate
    })
  })
}

モデル毎に 24 ヶ月の学習データを用いて 12 期先までの予測を行う試行を 23 回行い、23 回分の RMSE の平均値をモデルの評価値とする。
# 予測結果の一覧を取得
rmses <- rmse_per_model_cv(h = 12, skip = 11)

# モデル毎に RMSE の平均値を算出
purrr::map_dbl(rmses, ~ mean(.))

   d1_ss    d1_sv    d2_ss    d2_sv 
561.9146 472.5309 442.3523 449.9628

今回のデータだとモデル d2_ss(ローカル線形トレンド+確定的季節成分) の予測精度が最も高いという結果になった。

0 件のコメント:

コメントを投稿