AutoGluon学習が12時間かけても振るわなかった話
1. 結果サマリー
コンペ提出スコア (RMSE)
20.46
内部検証ベストスコア (-RMSE)
-8.41
スコアの乖離
12.05
分析: 内部での検証スコア(-8.41)は非常に優秀ですが、提出スコア(20.46)との間に大きな乖離が見られます。これは、作成されたモデルが手元の学習データに過剰に適合し(過学習)、未知のテストデータに対する性能が低い状態であることを強く示唆しています。
2. 根本原因の分析
原因①: GPUメモリ不足による有力モデルの脱落
学習ログを調査した結果、AutoGluonが学習を試みたモデルのうち、DeepAR と TiDE という2つの強力なディープラーニングモデルがGPUメモリ不足(CUDA out of memory)で学習に失敗し、自動的にスキップされていました。 これにより、最終的なアンサンブルモデルを構成するモデルの多様性が損なわれ、性能が頭打ちになったと考えられます。
証拠: 学習ログの抜粋
# DeepARのエラー
Warning: Exception caused DeepAR to fail during training... Skipping this model.
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 11.96 GiB. GPU 0 has a total capacity of 8.00 GiB...
# TiDEのエラー
Warning: Exception caused TiDE to fail during training... Skipping this model.
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 13.70 GiB. GPU 0 has a total capacity of 8.00 GiB...
原因②: 検証戦略の脆弱性
ログの冒頭に以下の警告が出ています。
Time series in train_data are too short for chosen num_val_windows=2. Reducing num_val_windows to 1.
これは、モデルの性能を評価するための検証用データセットが1セットしか作れなかったことを意味します。単一の検証セットでの評価は過学習を見抜きにくくし、今回のスコア乖離の一因となった可能性があります。
3. スコア改善のためのアクションプラン
上記の原因を解消し、スコアを向上させるための具体的な改善策を提案します。
-
【最優先】GPUメモリ不足の解消
脱落したモデルを学習に参加させることがスコア改善の最大の鍵です。
fit()
時にhyperparameters
引数を使い、各ディープラーニングモデルのバッチサイズを明示的に小さく設定します。hyperparameters = { # メモリ不足で失敗したモデル "DeepAR": {"batch_size": 32}, "TiDE": {"batch_size": 32}, # 他のGPUモデルも念のため設定 "PatchTST": {"batch_size": 32}, "Chronos": {"batch_size": 16}, } predictor.fit( train_data_ag, hyperparameters=hyperparameters, # 他の引数はそのまま ... )
-
【推奨】全データでの再学習(リフィット)
モデルの汎化性能を高めるために、
refit_full=True
オプションを追加します。これはルール違反にはなりません。これにより、利用可能な過去データを最大限活用できます。predictor.fit( train_data_ag, refit_full=True, # この行を追加 # 他の引数はそのまま ... )
4. 内部評価リーダーボード
学習時に生成された各モデルの性能一覧です。score_val
が内部検証スコアを示します。
model | score_val | pred_time_val | fit_time_marginal | fit_order |
---|---|---|---|---|
WeightedEnsemble | -8.4101 | 330.44 | 0.56 | 10 |
ChronosFineTuned[bolt_small] | -9.6343 | 3.88 | 581.89 | 8 |
DirectTabular | -10.7523 | 11.63 | 55.03 | 3 |
PatchTST | -11.3727 | 3.44 | 18448.01 | 9 |
NPTS | -11.6348 | 298.13 | 0.41 | 4 |
ChronosZeroShot[bolt_base] | -12.8146 | 8.67 | 13.20 | 7 |
SeasonalNaive | -14.0179 | 2.43 | 0.48 | 1 |
AutoETS | -16.6551 | 2.74 | 0.51 | 6 |
DynamicOptimizedTheta | -17.4808 | 4.69 | 0.45 | 5 |
RecursiveTabular | -56.2560 | 476.26 | 60.25 | 2 |
コメント