2020年6月2日火曜日

不均衡データのダウンサンプリングとキャリブレーション

ダウンサンプリングしたデータで 2 値分類を行う際に混入するバイアスの除去を検討する。
基本的には この論文ダウンサンプリングによる予測確率のバイアス を参考にしている。

データの読み込み


MNIST データを用いて特定のラベル(今回は 5)を特定する 2 値分類の問題を考える。
Kaggle のサイト から train.csv をダウンロードして取り込む。
library(tidyverse)

# 保存したファイルを読み込む
df.mnist <- readr::read_csv(
  file = "path/to/dir/train.csv",
  col_types = cols(
    .default = col_integer()
  )
) %>%

  dplyr::mutate(
    # ラベル(正解)が "5" かそれ以外かで分類しカテゴリ値に変換
    # 後の yardstick::mn_log_loss の仕様により正解(T)をカテゴリの最初の水準に指定する必要がある事に注意
    y = (label == 5) %>% factor(levels = c(T, F))
  ) %>%
  dplyr::select(-label)

ダウンサンプリング


recipes::step_downsample を用いて正解ラベル(y) が False であるレコードのダウンサンプリングを行う。
under_ratio = 1 の指定により正解ラベル(y)が「TRUE : FALSE = 1 : 1」となるようにサンプリングが実施される。
※詳細は ドキュメント を参照。
# 乱数 seed の固定
# rsample::initial_split 用
set.seed(1025)

lst.splitted <-

  # データを train/test の 2 つに分類
  rsample::initial_split(df.mnist, prop = 3/4, strata = "y") %>%

  {
    split <- (.)

    # ダウンサンプリング用の recipe を定義
    recipe <- rsample::training(split) %>%
      recipes::recipe(y ~ ., .) %>%
      recipes::step_downsample(y, under_ratio = 1, seed = 1025)

    list(
      train = rsample::training(split),
      train_sampled = recipes::prep(recipe, training = rsample::training(split)) %>% recipes::juice(),
      test  = rsample::testing(split)
    )
  }

学習の実施


モデルとして RandomForest を用いて学習を行う。
元データとサンプリング済データで学習時間に約 10 倍近い差が出ておりサンプリングによるメリットが大きく現れている。
  • 元データ: 約 220 秒 / 31,500 件
  • サンプリングデータ: 約 20 秒 / 5,646 件
model <- parsnip::rand_forest(
  mode = "classification",
  trees = 500,
  mtry = 300,
  min_n = 2
) %>%
  parsnip::set_engine(
    engine = "ranger",
    max.depth = 20,
    num.threads = 8,
    seed = 1025
  )

# 全データを用いた学習: 約 220 秒
system.time(
  fit.original <- model %>%
    parsnip::fit(y ~ ., lst.splitted$train)
)

# サンプリングデータを用いた学習:  約 20 秒
system.time(
  fit.downsampled <- model %>%
    parsnip::fit(y ~ ., lst.splitted$train_sampled)
)

キャリブレーション


キャリブレーションはオリジナルの予測確率を $p_{s}$ として下記で与えられる。 \begin{eqnarray} p = \frac{ p_{s} }{ p_{s} + \frac{ 1 - p_{s} }{ \beta } } \nonumber \end{eqnarray} ここで $\beta = p(s = 1 | −)$ により $\beta$ は負例におけるサンプリング割合であり、下記により $\beta \fallingdotseq 0.0984$ となる。
  • 訓練データ: 31,500 件
  • サンプリングデータ: 5,646 件
    • y = TRUE: 2,823 件
    • y = FALSE: 2,823 件
  • 非サンプリングデータ: 31,500 - 5,646 = 25,854 件
  • 2,823 / (2,823 + 25,854) ≒ 0.0984
# 訓練データ全体: 31,500
nrow(lst.splitted$train)

# サンプリング対象: 5,646
nrow(lst.splitted$train_sampled)

# サンプリング対象の内訳
# - T: 2,823
# - F: 2,823
lst.splitted$train_sampled %>% dplyr::count(y)

# beta: 0.09844126
beta <- 2823 / (2823 + 25854)

予測データ


この後の可視化および評価に備えて予測確率を正解ラベルと共にまとめておく。
正例と判定するしきい値はデフォルトの 0.5 を用いている。
# 各予測確率の一覧
# 正例のしきい値: 0.5
df.predicted <- tibble(
  # 正解ラベル
  actual = lst.splitted$test$y,

  # 元データによる学習と予測
  original.proba = predict(fit.original, lst.splitted$test, type = "prob")$.pred_TRUE,
  original.pred  = (original.proba > 0.5) %>% factor(levels = c(T, F)),

  # サンプリングデータによる学習と予測
  downsampled.proba = predict(fit.downsampled, lst.splitted$test, type = "prob")$.pred_TRUE,
  downsampled.pred  = (downsampled.proba > 0.5) %>% factor(levels = c(T, F)),

  # キャリブレーション適用済みの予測
  calibrated.proba = downsampled.proba / (downsampled.proba + (1 - downsampled.proba) / beta),
  calibrated.pred  = (calibrated.proba > 0.5) %>% factor(levels = c(T, F))
)

df.predicted
actual original.proba original.pred downsampled.proba downsampled.pred calibrated.proba calibrated.pred
FALSE 0.0006013 FALSE 0.0091151 FALSE 0.0009047 FALSE
FALSE 0.0024995 FALSE 0.0100703 FALSE 0.0010004 FALSE
FALSE 0.1053711 FALSE 0.2660000 FALSE 0.0344460 FALSE
FALSE 0.0001306 FALSE 0.0199857 FALSE 0.0020035 FALSE
FALSE 0.0003486 FALSE 0.0040777 FALSE 0.0004029 FALSE


予測確率の分布


各予測確率の分布を可視化してみる。
df.predicted %>%

  # wide-form => long-form
  dplyr::select(
    original.proba,
    downsampled.proba,
    calibrated.proba
  ) %>%
  tidyr::pivot_longer(cols = dplyr::everything(), names_to = "type", names_pattern = "(.+)\\.proba", , values_to = "prob") %>%

  # 可視化時の並び順(facet_grid)を指定
  dplyr::mutate(type = forcats::fct_relevel(type, "original", "downsampled", "calibrated")) %>%

  # 可視化
  ggplot(aes(prob)) +
    geom_histogram(aes(y = ..density..), position = "identity", binwidth = 0.010, boundary = 0, colour = "white", alpha = 1/2) +
    geom_density(aes(fill = type), colour = "white", alpha = 1/3) +
    geom_vline(
      data = function(df) { dplyr::group_by(df, type) %>% dplyr::summarise(avg = mean(prob)) },
      aes(xintercept = avg, colour = type),
      linetype = 2,
      size = 1,
      alpha = 1/2
    ) +
    guides(fill = F, colour = F) +
    labs(
      x = "Probability",
      y = NULL
    ) +
    facet_grid(type ~ ., scales = "free_y")

上から順に [ サンプリングなし(original) / サンプリングのみ(downsampled) / キャリブレーション適用(calibrated) ] の順で並んでいる。
サンプリングのみの予測確率(downsampled)は正例側(右側)に寄っており、これがダウンサンプリングにおけるバイアスの影響だと考えられる。



Calibration Curve


0.1 刻みでビン化した予測確率の区間毎に予測確率(横軸)と正例(縦軸)の平均を算出して可視化を行う。
df.predicted %>%

  # ビン化
  dplyr::mutate(
    bins.downsampled = cut(downsampled.proba, breaks = seq(0, 1, 0.1)),
    bins.calibrated  = cut(calibrated.proba,  breaks = seq(0, 1, 0.1))
  ) %>%

  # 集約処理
  {
    data <- (.)

    # サンプリングのみ(downsampled)
    df.downsampled <- data %>%
      dplyr::group_by(bins = bins.downsampled) %>%
      dplyr::summarise(
        n = n(),
        avg_proba = mean(downsampled.proba),
        ratio = mean(actual == "TRUE")
      ) %>%
      dplyr::mutate(type = "downsampled")

    # calibration 適用(calibrated)
    df.calibrated <- data %>%
      dplyr::group_by(bins = bins.calibrated) %>%
      dplyr::summarise(
        n = n(),
        avg_proba = mean(calibrated.proba),
        ratio = mean(actual == "TRUE")
      ) %>%
      dplyr::mutate(type = "calibrated")

    dplyr::bind_rows(
      df.downsampled,
      df.calibrated
    )
  } %>%

  # 可視化
  ggplot(aes(avg_proba, ratio)) +
    geom_point(aes(size = n, colour = type), show.legend = F) +
    geom_line(aes(colour = type)) +
    geom_abline(slope = 1, intercept = 0, linetype = 2, alpha = 1/2) +
    scale_size_area() +
    labs(
      x = "Predict Probability",
      y = "Positive Ratio",
      colour = NULL
    )

左下から右上へ対角線上に伸びる黒い点線に沿うほど信頼性が高いと考えられるが、今回の calibration による実現はない。
見た感じでは 2 つの曲線がグラフの中心 (x, y) = (0.5, 0.5) に対して点対称になっており、 downsampled 曲線が予測確率の高い範囲(1.0 付近)で黒点線に沿っているのに対し calibrated 曲線が予測確率の低い範囲(0.0 付近)で黒点線に沿うような状況になっている事から今回の calibration によって正例と負例のどちらを重視するのかが入れ替わっているものと考えられる (ただの感想なので正しい解釈が欲しい…)。
ダウンサンプリングによって少数である正例側へと生じたバイアスを取り除くという意味では妥当と考えて良いのかもしれない。

下記の図では丸の大きさによって該当するレコード数を表現しており、左下の負例の辺りに大きなサイズの丸が存在している。これは正例の比率が小さい(約10%)事から発生していると考えられるが、件数の多い範囲(0.0 付近)における予測確率の精度が向上するという事はデータ全体における精度の向上として現れるのではないかと考えられる。実際この後に calibration を実施した予測確率において Log Loss の低下(=精度向上)が見られる事を確認する。




Log Loss


評価指標としてまず Log Loss の値を確認する。
# LogLoss: サンプリングなし
yardstick::mn_log_loss(df.predicted, actual, original.proba) %>%
  dplyr::select(metric = .metric, original = .estimate) %>%

  # LogLoss: サンプリングのみ
  dplyr::left_join(
    yardstick::mn_log_loss(df.predicted, actual, downsampled.proba) %>%
      dplyr::select(metric = .metric, downsampled = .estimate),
    by = "metric"
  ) %>%

  # LogLoss: キャリブレーション適用
  dplyr::left_join(
    yardstick::mn_log_loss(df.predicted, actual, calibrated.proba) %>%
      dplyr::select(metric = .metric, calibrated = .estimate),
    by = "metric"
  )

各指標(行)ごとに最も良いスコアと最も悪いスコアをそれぞれ赤と青で色付けしている。

サンプリングのみ(downsampled)の場合と比較して calibration によって明らかに Log Loss の改善(0.12=>0.06)が見られている。
予測確率の分布で見たように calibration によって分布が 0 と 1 の付近に寄るようになった事と、多数側のサンプルである負例での精度が向上している事に依るものと思われる。
しかしながらサンプリングなし(original)で最も良い結果が出ている事から、今回のデータで Log Loss が評価指標になる場合は不均衡への対応を行わないという選択肢は十分にある(他の対応方法を否定するものではない)。

Log Loss
metric original downsampled calibrated
mn_log_loss 0.0464 0.1223 0.0651


混同行列


  • 縦方向の T/F: Actual
  • 横方向の T/F: Predict
  • しきい値: 0.5

# サンプリングなし
table(df.predicted[, c("actual", "original.pred")])
TRUE FALSE
TRUE 848 124
FALSE 12 9,516

# サンプリングのみ
table(df.predicted[, c("actual", "downsampled.pred")])
サンプリングなしと比較して予測が正例(左列)側に寄っており、バイアスの影響が見られる。
TRUE FALSE
TRUE 944 28
FALSE 241 9,287

# キャリブレーション適用
table(df.predicted[, c("actual", "calibrated.pred")])
サンプリングのみの場合と比較して予測が全体的に負例(右列)側に寄っており、ある程度バイアスの解消が出来ているものと思われる。
一方で、False Negative(右上) に該当するサンプル数が大幅に増加してしまっている。
TRUE FALSE
TRUE 700 272
FALSE 5 9,523

Accuracy/Precision/Recall/F


しきい値は 0.5
# 使用する指標を指定
metrics <- yardstick::metric_set(
  yardstick::accuracy,
  yardstick::precision,
  yardstick::recall,
  yardstick::f_meas
)

# サンプリングなし(original)
metrics(df.predicted, truth = actual, estimate = original.pred) %>%
  dplyr::select(metric = .metric, original = .estimate) %>%

  # サンプリングのみ(downsampled)
  dplyr::left_join(
    metrics(df.predicted, truth = actual, estimate = downsampled.pred) %>%
      dplyr::select(metric = .metric, downsampled = .estimate),
    by = "metric"
  ) %>%

  # キャリブレーション適用(calibrated)
  dplyr::left_join(
    metrics(df.predicted, truth = actual, estimate = calibrated.pred) %>%
      dplyr::select(metric = .metric, calibrated = .estimate),
    by = "metric"
  )
各指標(行)ごとに最も良いスコアと最も悪いスコアをそれぞれ赤と青で色付けしている。

総合的には F 値(f_meas)の最も高いサンプリングなし(original)の場合が最も良い予測であると考えられる。
一方で calibration を適用するとサンプリングのみ(downsampled)の場合よりも成績が悪化してしまっている。
metric original downsampled calibrated
accuracy 0.987 0.974 0.974
precision 0.986 0.797 0.993
recall 0.872 0.971 0.720
f_meas 0.926 0.875 0.835


ここで判定に用いているしきい値を変更する事を考えてみる。
予測確率の分布においてその平均値が 0.1〜0.2 付近である事を考えると、上記で用いたしきい値の 0.5 はもう少し小さい値を取る方が良いと思われる。


しきい値の変更


# しきい値を 0.0〜1.0 の範囲で変更して各指標を算出
df.evals <- purrr::map_dfr(seq(0, 1, 0.01), function(threshold) {

  # 予測値の算出
  df.predicted <- tibble(
    actual = lst.splitted$test$y,

    # サンプリングなし
    original.proba = predict(fit.original, lst.splitted$test, type = "prob")$.pred_TRUE,
    original.pred  = (original.proba > threshold) %>% factor(levels = c(T, F)),

    # サンプリングのみ
    downsampled.proba = predict(fit.downsampled, lst.splitted$test, type = "prob")$.pred_TRUE,
    downsampled.pred  = (downsampled.proba > threshold) %>% factor(levels = c(T, F)),

    # キャリブレーション適用
    calibrated.proba = downsampled.proba / (downsampled.proba + (1 - downsampled.proba) / beta),
    calibrated.pred  = (calibrated.proba > threshold) %>% factor(levels = c(T, F))
  )

  # 評価指標
  metrics <- yardstick::metric_set(
    yardstick::accuracy,
    yardstick::precision,
    yardstick::recall,
    yardstick::f_meas
  )

  # サンプリングなし
  metrics(df.predicted, truth = actual, estimate = original.pred) %>%
    dplyr::select(metric = .metric, original = .estimate) %>%

    # サンプリングのみ
    dplyr::left_join(
      metrics(df.predicted, truth = actual, estimate = downsampled.pred) %>%
        dplyr::select(metric = .metric, downsampled = .estimate),
      by = "metric"
    ) %>%

    # キャリブレーション適用
    dplyr::left_join(
      metrics(df.predicted, truth = actual, estimate = calibrated.pred) %>%
        dplyr::select(metric = .metric, calibrated = .estimate),
      by = "metric"
    ) %>%

    # しきい値の追加
    dplyr::mutate(threshold = threshold)
})


df.evals %>%

  # wide-form => long-form
  tidyr::pivot_longer(cols = c(original, downsampled, calibrated), names_to = "type", values_to = "score") %>%

  # 可視化時の並び順を指定
  dplyr::mutate(
    metric = forcats::fct_relevel(metric, "accuracy", "precision", "recall", "f_meas"),
    type = forcats::fct_relevel(type, "original", "downsampled", "calibrated")
  ) %>%

  # 可視化
  ggplot(aes(threshold, score)) +
    geom_line(aes(colour = type)) +
    geom_vline(xintercept = c(0.15), linetype = 2, alpha = 1/2) +
    scale_x_continuous(breaks = seq(0, 1, 0.1)) +
    labs(
      x = "Threshold",
      y = NULL,
      colour = NULL
    ) +
    facet_grid(metric ~ .)

下の図に示すようにしきい値の値によって各指標は大きく変動する。
サンプリングのみの場合とそれ以外では F 値(f_meas)において逆の傾向が出ている。
しきい値として 0.15 を採用するとキャリブレーションを適用した予測値の F 値(f_meas)が極大となる。


まとめ


  • ダウンサンプリングにより学習時間が大幅(今回は約 1/10)に削減される
  • ダウンサンプリングにより正例寄りにバイアスが発生する
  • キャリブレーションによりダウンサンプリングによるバイアスは(ある程度)解消される
  • ダウンサンプリングにより Log Loss は悪化する
  • キャリブレーションによりダウンサンプリングによる Log Loss の悪化は改善可能
  • 少なくとも今回のデータではしきい値の調整は必須
  • しきい値の調整によりキャリブレーション適用後の各指標をサンプリングなしと同程度まで改善可能

しきい値: 0.15
metric original downsampled calibrated
accuracy 0.977 0.801 0.984
precision 0.813 0.318 0.898
recall 0.976 0.999 0.939
f_meas 0.887 0.482 0.918
mn_log_loss 0.046 0.122 0.065

0 件のコメント:

コメントを投稿