ml-lightning-basics
PyTorch Lightningの包括的なガイド - LightningModule、Trainer、分散学習、PyTorch 2.0のtorch.compile統合、Lightning Fabric、本番運用のベストプラクティスについて解説します。機械学習モデルの開発から本番環境でのデプロイまで、PyTorch Lightningを活用した実装方法を学べます。LightningModuleを使ったモデル構築、Trainerによる効率的な学習管理、複数GPUやTPUを活用した分散学習の設定、最新のPyTorch 2.0機能の統合方法を含みます。さらにLightning Fabricを用いた低レベルの制御や、本番環境で安定稼働させるための実践的なベストプラクティスについても紹介します。
description の原文を見る
Comprehensive guide for PyTorch Lightning - LightningModule, Trainer, distributed training, PyTorch 2.0 torch.compile integration, Lightning Fabric, and production best practices
SKILL.md 本文
ML研究向けPyTorch Lightning
概要
PyTorch Lightningは、PyTorchコードを整理し、研究とエンジニアリングを分離する業界標準フレームワークです。ボイラープレートコードを排除しながら、PyTorchの完全な柔軟性を維持し、研究者がトレーニングインフラストラクチャではなくモデルロジックに集中できるようにします。
主な利点:
- ボイラープレートコードの90%を削減
- 自動分散トレーニング(DDP、FSDP、DeepSpeed)
- ハードウェア非依存(CPU、GPU、TPU、MPS)
- 組み込みベストプラクティス(チェックポイント、ロギング、プロファイリング)
- PyTorch 2.0との完全な互換性とtorch.compile対応
- 初日から本番対応コード
リソース:
- 公式ドキュメント: https://lightning.ai/docs/pytorch/stable/
- スタイルガイド: https://lightning.ai/docs/pytorch/stable/starter/style_guide.html
- パフォーマンスガイド: https://lightning.ai/docs/pytorch/stable/advanced/training_tricks.html
コア概念
1. LightningModule
LightningModuleは、モデル+トレーニングロジックを自己完結型クラスにカプセル化します。
完全な例:
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
class ImageClassifier(L.LightningModule):
def __init__(self, backbone="resnet18", num_classes=10, lr=1e-3):
super().__init__()
# 重要: チェックポイント用にすべてのハイパーパラメータを保存
self.save_hyperparameters()
# モデルアーキテクチャを定義
if backbone == "resnet18":
from torchvision.models import resnet18
self.model = resnet18(num_classes=num_classes)
else:
raise ValueError(f"Unknown backbone: {backbone}")
# メトリクス(効率性のためTorchMetricsを使用)
from torchmetrics import Accuracy
self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
def forward(self, x):
"""フォワードパス - 推論のみ."""
return self.model(x)
def training_step(self, batch, batch_idx):
"""1バッチのトレーニングロジック."""
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
# メトリクスを更新してログ
acc = self.train_acc(y_hat, y)
self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
return loss # ロスを返す必須
def validation_step(self, batch, batch_idx):
"""検証ロジック."""
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
acc = self.val_acc(y_hat, y)
self.log("val/loss", loss, prog_bar=True, sync_dist=True)
self.log("val/acc", acc, prog_bar=True, sync_dist=True)
def test_step(self, batch, batch_idx):
"""テストロジック(オプション)."""
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
acc = (y_hat.argmax(dim=1) == y).float().mean()
self.log("test/loss", loss)
self.log("test/acc", acc)
def configure_optimizers(self):
"""オプティマイザと学習率スケジューラを設定."""
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.hparams.lr,
weight_decay=0.01,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=self.trainer.max_epochs,
eta_min=1e-6,
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "epoch",
"frequency": 1,
},
}
主要メソッド:
| メソッド | 目的 | 必須 |
|---|---|---|
__init__ | モデルアーキテクチャ、ハイパーパラメータ | はい |
forward | 推論ロジック(トレーニングコードなし) | はい |
training_step | 1バッチのトレーニングロジック | はい |
validation_step | 検証ロジック | 推奨 |
test_step | テストロジック | オプション |
configure_optimizers | オプティマイザとスケジューラの設定 | はい |
2. LightningDataModule
再利用可能で再現可能な方法で、すべてのデータロードロジックを整理します。
class ImageDataModule(L.LightningDataModule):
def __init__(self, data_dir="./data", batch_size=32, num_workers=4):
super().__init__()
self.save_hyperparameters()
def prepare_data(self):
"""データをダウンロード(1回実行、単一GPUで実行)."""
# データセットをダウンロード
# ここではインスタンス変数を設定しないこと(self.x = yなど)
from torchvision.datasets import CIFAR10
CIFAR10(self.hparams.data_dir, train=True, download=True)
CIFAR10(self.hparams.data_dir, train=False, download=True)
def setup(self, stage=None):
"""データを読み込み(分散時に各GPUで実行)."""
from torchvision.datasets import CIFAR10
from torchvision import transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
if stage == "fit" or stage is None:
full_dataset = CIFAR10(
self.hparams.data_dir,
train=True,
transform=transform
)
# トレーニング/検証に分割
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
self.train_dataset, self.val_dataset = torch.utils.data.random_split(
full_dataset, [train_size, val_size]
)
if stage == "test" or stage is None:
self.test_dataset = CIFAR10(
self.hparams.data_dir,
train=False,
transform=transform
)
def train_dataloader(self):
return torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.hparams.batch_size,
shuffle=True,
num_workers=self.hparams.num_workers,
pin_memory=True,
persistent_workers=True, # エポック間でワーカーを保持
)
def val_dataloader(self):
return torch.utils.data.DataLoader(
self.val_dataset,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
pin_memory=True,
persistent_workers=True,
)
def test_dataloader(self):
return torch.utils.data.DataLoader(
self.test_dataset,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
)
3. Trainer
Trainerはトレーニングループ、ハードウェア管理、ロギングを自動化します。
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
# Trainerを作成
trainer = Trainer(
# ハードウェア
accelerator="auto", # 自動検出: GPU、CPU、TPU、MPS
devices="auto", # 利用可能なすべてのデバイスを使用
precision="16-mixed", # 混合精度(高速、メモリ効率的)
# トレーニング期間
max_epochs=100,
min_epochs=10,
# 検証
check_val_every_n_epoch=1,
val_check_interval=1.0, # トレーニングエポックの割合
# ロギング
log_every_n_steps=50,
enable_progress_bar=True,
# コールバック
callbacks=[
ModelCheckpoint(
monitor="val/loss",
mode="min",
save_top_k=3,
filename="epoch_{epoch:02d}-loss_{val/loss:.4f}",
),
EarlyStopping(
monitor="val/loss",
patience=10,
mode="min",
),
LearningRateMonitor(logging_interval="step"),
],
# ロガー
logger=WandbLogger(project="my-project", name="experiment-1"),
# デバッグ
fast_dev_run=False, # 1バッチテスト用にTrueに設定
overfit_batches=0, # デバッグ用にN個のバッチでオーバーフィット
# 再現性
deterministic=True,
benchmark=False, # 入力サイズが一定の場合、Trueに設定
)
# トレーニング
model = ImageClassifier()
datamodule = ImageDataModule()
trainer.fit(model, datamodule=datamodule)
# テスト
trainer.test(model, datamodule=datamodule, ckpt_path="best")
PyTorch 2.0統合
PyTorch 2.0のtorch.compileは、グラフコンパイルを通じて大幅なスピードアップ(平均40%以上)を提供します。
LightningでtorchコンパイルAPIを使用
方法1: モデル全体をコンパイル
class MyModel(L.LightningModule):
def __init__(self):
super().__init__()
self.model = torch.compile(
YourModel(),
mode="default" # または "reduce-overhead"、"max-autotune"
)
方法2: Trainerで設定(推奨)
model = MyModel()
# 自動的にコンパイル
trainer = Trainer(max_epochs=10)
compiled_model = torch.compile(model, mode="default")
trainer.fit(compiled_model, datamodule=dm)
torch.compileモード
| モード | 最適化レベル | コンパイル時間 | ユースケース |
|---|---|---|---|
default | バランス型 | 高速 | 開発、一般的な用途 |
reduce-overhead | カーネル起動のオーバーヘッド最小化 | 中程度 | 小バッチサイズ、CPUボトルネック |
max-autotune | 最大パフォーマンス | 低速 | 本番環境、長時間トレーニング |
パフォーマンス例:
import torch
# 標準モデル
model = MyModel()
# コンパイルにより平均40%高速化
compiled_model = torch.compile(model, mode="max-autotune")
コンパイルベストプラクティス
やるべき事:
- 開発時には
mode="default"を使用 - 本番環境では
mode="max-autotune"を使用 - ボトルネック特定のためにプロファイル実行
- モデルアーキテクチャを静的に保つ(動的シェイプなし)
やるべきでない事:
- 高度に動的なモデル(可変長RNN)でコンパイル
- CPU上でのスピードアップを期待(torch.compileはGPU中心)
- コンパイルされたモジュールとされていないモジュールを混在
グラフブレーク(パフォーマンスの問題)
グラフブレークは、PyTorchがセクションをコンパイルできない場合に発生します:
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
# 回避: Pythonの制御フローはグラフを破壊
if batch_idx % 100 == 0:
print(f"Batch {batch_idx}") # コンパイルを破壊
loss = F.cross_entropy(y_hat, y)
return loss
グラフブレークをチェック:
import torch._dynamo as dynamo
# リセットしてロギングを有効化
dynamo.reset()
dynamo.config.verbose = True
model = torch.compile(model, mode="default")
# グラフブレークの位置を示す警告が表示されます
分散トレーニング
Lightningは分散トレーニングを簡単にします - 1つの引数を変更するだけです。
DDP(Distributed Data Parallel)
標準マルチGPUトレーニング:
# 単一GPU
trainer = Trainer(accelerator="gpu", devices=1)
# マルチGPU(自動DDP)
trainer = Trainer(accelerator="gpu", devices=4, strategy="ddp")
# すべてのGPU
trainer = Trainer(accelerator="gpu", devices="auto", strategy="ddp")
DDP spawn(Windows互換性):
trainer = Trainer(accelerator="gpu", devices=4, strategy="ddp_spawn")
FSDP(Fully Sharded Data Parallel)
単一GPUメモリに収まらないモデル用:
from lightning.pytorch.strategies import FSDPStrategy
trainer = Trainer(
accelerator="gpu",
devices=8,
strategy=FSDPStrategy(
sharding_strategy="FULL_SHARD", # パラメータ、勾配、オプティマイザをシャード
auto_wrap_policy={nn.Linear}, # 線形層を自動ラップ
),
)
DeepSpeed
極度に大規模なモデル用(数十億パラメータ):
from lightning.pytorch.strategies import DeepSpeedStrategy
trainer = Trainer(
accelerator="gpu",
devices=8,
strategy=DeepSpeedStrategy(
stage=3, # ZeRO Stage 3(最もメモリ効率的)
offload_optimizer=True, # オプティマイザをCPUにオフロード
offload_parameters=True, # パラメータをCPUにオフロード
),
precision="16-mixed",
)
DeepSpeed設定ファイル:
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
}
},
"fp16": {
"enabled": true
}
}
Lightning Fabric
FabricはLightningの軽量抽象化 - Trainerより高い制御性、生のPyTorchよりボイラープレートが少ない。
Fabricを使用する場合
- Lightning特性を備えたカスタムトレーニングループ
- PyTorchから段階的に移行
- きめ細かい制御を必要とする研究
例:
import lightning as L
from lightning.fabric import Fabric
# Fabricを初期化
fabric = L.Fabric(
accelerator="cuda",
devices=2,
precision="16-mixed"
)
fabric.launch()
# モデルとオプティマイザをセットアップ
model = YourModel()
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)
# データをセットアップ
train_loader = fabric.setup_dataloaders(train_loader)
# カスタムトレーニングループ
for epoch in range(epochs):
for batch in train_loader:
optimizer.zero_grad()
loss = model(batch)
fabric.backward(loss)
optimizer.step()
# 自動ロギング
fabric.log("train_loss", loss)
高度なパターン
複数オプティマイザ
def configure_optimizers(self):
# 異なる部分に対して異なる学習率
opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.001)
opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001)
return [opt_g, opt_d], []
手動オプティマイザ
class GANModel(L.LightningModule):
def __init__(self):
super().__init__()
self.automatic_optimization = False # 自動を無効化
def training_step(self, batch, batch_idx):
opt_g, opt_d = self.optimizers()
# ジェネレータを訓練
loss_g = self.generator_loss(batch)
opt_g.zero_grad()
self.manual_backward(loss_g)
opt_g.step()
# ディスクリミネータを訓練
loss_d = self.discriminator_loss(batch)
opt_d.zero_grad()
self.manual_backward(loss_d)
opt_d.step()
self.log_dict({"loss_g": loss_g, "loss_d": loss_d})
勾配累積
# 有効バッチサイズ = batch_size * accumulate_grad_batches
trainer = Trainer(
accumulate_grad_batches=4, # 更新前に4バッチを累積
)
勾配クリッピング
trainer = Trainer(
gradient_clip_val=1.0, # 勾配を最大ノルム1.0にクリップ
gradient_clip_algorithm="norm", # または "value"
)
コールバック
組み込みコールバック
from lightning.pytorch.callbacks import (
ModelCheckpoint,
EarlyStopping,
LearningRateMonitor,
RichProgressBar,
ModelSummary,
TQDMProgressBar,
)
callbacks = [
# 最良モデルを保存
ModelCheckpoint(
monitor="val/loss",
mode="min",
save_top_k=3,
filename="best-{epoch:02d}-{val_loss:.2f}",
),
# 早期停止
EarlyStopping(
monitor="val/loss",
patience=10,
mode="min",
verbose=True,
),
# 学習率をログ
LearningRateMonitor(logging_interval="step"),
# Rich進捗バー
RichProgressBar(),
# モデルサマリー
ModelSummary(max_depth=2),
]
カスタムコールバック
from lightning.pytorch.callbacks import Callback
class PrintCallback(Callback):
def on_train_start(self, trainer, pl_module):
print("トレーニング開始!")
def on_train_epoch_end(self, trainer, pl_module):
epoch = trainer.current_epoch
train_loss = trainer.callback_metrics.get("train/loss")
val_loss = trainer.callback_metrics.get("val/loss")
print(f"エポック {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
def on_validation_end(self, trainer, pl_module):
# カスタムアーティファクトを保存
pass
ロギング
複数ロガー
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger
wandb_logger = WandbLogger(project="my-project", name="run-1")
tb_logger = TensorBoardLogger("logs/", name="my_model")
trainer = Trainer(logger=[wandb_logger, tb_logger])
高度なロギング
def training_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
# スカラーをログ
self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
# 複数メトリクスをログ
self.log_dict({
"train/loss": loss,
"train/acc": acc,
"train/f1": f1,
}, on_epoch=True)
# ヒストグラムをログ(TensorBoard/W&B用)
if batch_idx % 100 == 0:
self.logger.experiment.add_histogram(
"gradients/layer1",
self.model.layer1.weight.grad,
self.global_step
)
return loss
ベストプラクティス
✅ やるべき事
__init__で常にself.save_hyperparameters()を呼び出す - 再現性のため- DataModuleを使用 - すべてのデータロジックをカプセル化
self.log()でログ - 自動集約と同期のため- コールバックを使用 - チェックポイント、早期停止、監視
- 混合精度を有効化 -
precision="16-mixed"でスピードアップ - DataLoaderで
pin_memory=Trueとpersistent_workers=Trueを使用 - 再現性のため
deterministic=Trueを設定 - クイックサニティチェック用に
fast_dev_run=Trueを使用 - TorchMetricsを使用 - 効率的なメトリクス計算
- PyTorch 2.0以降ではモデルを
torch.compileでコンパイル
❌ やるべきでない事
training_stepからロスを返すのを忘れないprepare_data()で状態を設定しない - 代わりにsetup()を使用.to(device)でテンソルを手動で移動しない - Lightningが処理- ロギング用に
print()を使用しない -self.log()を使用 - ハイパーパラメータをハードコードしない -
self.hparamsを使用 - 分散メトリクスで
sync_dist=Trueを無視しない - LightningのトレーニングループとPyTorchの生のループを混在させない
一般的な問題と解決方法
問題1: NaNロス
# 解決策1: 勾配クリッピング
trainer = Trainer(gradient_clip_val=1.0)
# 解決策2: 学習率を下げる
optimizer = torch.optim.Adam(params, lr=1e-4) # 1e-3の代わり
# 解決策3: 完全精度
trainer = Trainer(precision=32) # 16-mixedの代わり
問題2: メモリ不足
# 解決策1: バッチサイズを減らす
datamodule = MyDataModule(batch_size=16) # 32の代わり
# 解決策2: 勾配累積
trainer = Trainer(accumulate_grad_batches=4)
# 解決策3: 混合精度
trainer = Trainer(precision="16-mixed")
問題3: トレーニングが遅い
# 解決策1: モデルをコンパイル(PyTorch 2.0以降)
model = torch.compile(model, mode="max-autotune")
# 解決策2: ボトルネックをプロファイル
trainer = Trainer(profiler="simple") # または "advanced"
# 解決策3: num_workersを増やす
datamodule = MyDataModule(num_workers=8) # CPUコア数に合わせる
必須リソース
公式ドキュメント
- Lightningドキュメント: https://lightning.ai/docs/pytorch/stable/
- APIリファレンス: https://lightning.ai/docs/pytorch/stable/api_references.html
- スタイルガイド: https://lightning.ai/docs/pytorch/stable/starter/style_guide.html
チュートリアル
- 15分でLightning: https://lightning.ai/docs/pytorch/stable/starter/introduction.html
- PyTorchからLightningへ: https://lightning.ai/docs/pytorch/stable/starter/converting.html
- 分散トレーニング: https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html
PyTorch 2.0
- torch.compileの概要: https://pytorch.org/get-started/pytorch-2-0/
- torch.compileチュートリアル: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
- パフォーマンスティップス: https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html
コミュニティ
- Lightning Bolts: https://lightning-bolts.readthedocs.io/ (モデル実装)
- GitHubディスカッション: https://github.com/Lightning-AI/pytorch-lightning/discussions
クイックリファレンス
最小限の動作例:
import lightning as L
import torch
class MinimalModel(L.LightningModule):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(10, 1)
def forward(self, x):
return self.layer(x)
def training_step(self, batch, batch_idx):
x, y = batch
loss = torch.nn.functional.mse_loss(self(x), y)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters())
# トレーニング
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, train_dataloader)
まとめ
PyTorch Lightningは以下を提供します:
- シンプルさ: ボイラープレートを削減、研究に集中
- スケーラビリティ: ノートパソコンからスーパーコンピュータまで、1行で対応
- 速度: PyTorch 2.0統合、自動最適化
- 柔軟性: 必要なときはPyTorchの完全な制御が可能
- 本番対応: 初日からデプロイメント準備完了
PyTorch 2.0のtorch.compileと組み合わせることで、Lightningは最小限のコードで最大のパフォーマンスを実現します。
ライセンス: MIT(寛容ライセンスのため全文を引用しています) · 原本リポジトリ
詳細情報
- 作者
- nishide-dev
- ライセンス
- MIT
- 最終更新
- 2026/4/6
Source: https://github.com/nishide-dev/claude-code-ml-research / ライセンス: MIT