Agent Skills by ALSEL
Anthropic Claudeデータ・分析⭐ リポ 0品質スコア 65/100

ml-lightning-basics

PyTorch Lightningの包括的なガイド - LightningModule、Trainer、分散学習、PyTorch 2.0のtorch.compile統合、Lightning Fabric、本番運用のベストプラクティスについて解説します。機械学習モデルの開発から本番環境でのデプロイまで、PyTorch Lightningを活用した実装方法を学べます。LightningModuleを使ったモデル構築、Trainerによる効率的な学習管理、複数GPUやTPUを活用した分散学習の設定、最新のPyTorch 2.0機能の統合方法を含みます。さらにLightning Fabricを用いた低レベルの制御や、本番環境で安定稼働させるための実践的なベストプラクティスについても紹介します。

description の原文を見る

Comprehensive guide for PyTorch Lightning - LightningModule, Trainer, distributed training, PyTorch 2.0 torch.compile integration, Lightning Fabric, and production best practices

SKILL.md 本文

ML研究向けPyTorch Lightning

概要

PyTorch Lightningは、PyTorchコードを整理し、研究とエンジニアリングを分離する業界標準フレームワークです。ボイラープレートコードを排除しながら、PyTorchの完全な柔軟性を維持し、研究者がトレーニングインフラストラクチャではなくモデルロジックに集中できるようにします。

主な利点:

  • ボイラープレートコードの90%を削減
  • 自動分散トレーニング(DDP、FSDP、DeepSpeed)
  • ハードウェア非依存(CPU、GPU、TPU、MPS)
  • 組み込みベストプラクティス(チェックポイント、ロギング、プロファイリング)
  • PyTorch 2.0との完全な互換性とtorch.compile対応
  • 初日から本番対応コード

リソース:


コア概念

1. LightningModule

LightningModuleは、モデル+トレーニングロジックを自己完結型クラスにカプセル化します。

完全な例:

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImageClassifier(L.LightningModule):
    def __init__(self, backbone="resnet18", num_classes=10, lr=1e-3):
        super().__init__()

        # 重要: チェックポイント用にすべてのハイパーパラメータを保存
        self.save_hyperparameters()

        # モデルアーキテクチャを定義
        if backbone == "resnet18":
            from torchvision.models import resnet18
            self.model = resnet18(num_classes=num_classes)
        else:
            raise ValueError(f"Unknown backbone: {backbone}")

        # メトリクス(効率性のためTorchMetricsを使用)
        from torchmetrics import Accuracy
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """フォワードパス - 推論のみ."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """1バッチのトレーニングロジック."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        # メトリクスを更新してログ
        acc = self.train_acc(y_hat, y)
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return loss  # ロスを返す必須

    def validation_step(self, batch, batch_idx):
        """検証ロジック."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        acc = self.val_acc(y_hat, y)
        self.log("val/loss", loss, prog_bar=True, sync_dist=True)
        self.log("val/acc", acc, prog_bar=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """テストロジック(オプション)."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()

        self.log("test/loss", loss)
        self.log("test/acc", acc)

    def configure_optimizers(self):
        """オプティマイザと学習率スケジューラを設定."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=0.01,
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=1e-6,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
            },
        }

主要メソッド:

メソッド目的必須
__init__モデルアーキテクチャ、ハイパーパラメータはい
forward推論ロジック(トレーニングコードなし)はい
training_step1バッチのトレーニングロジックはい
validation_step検証ロジック推奨
test_stepテストロジックオプション
configure_optimizersオプティマイザとスケジューラの設定はい

2. LightningDataModule

再利用可能で再現可能な方法で、すべてのデータロードロジックを整理します。

class ImageDataModule(L.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=32, num_workers=4):
        super().__init__()
        self.save_hyperparameters()

    def prepare_data(self):
        """データをダウンロード(1回実行、単一GPUで実行)."""
        # データセットをダウンロード
        # ここではインスタンス変数を設定しないこと(self.x = yなど)
        from torchvision.datasets import CIFAR10
        CIFAR10(self.hparams.data_dir, train=True, download=True)
        CIFAR10(self.hparams.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """データを読み込み(分散時に各GPUで実行)."""
        from torchvision.datasets import CIFAR10
        from torchvision import transforms

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        if stage == "fit" or stage is None:
            full_dataset = CIFAR10(
                self.hparams.data_dir,
                train=True,
                transform=transform
            )
            # トレーニング/検証に分割
            train_size = int(0.9 * len(full_dataset))
            val_size = len(full_dataset) - train_size
            self.train_dataset, self.val_dataset = torch.utils.data.random_split(
                full_dataset, [train_size, val_size]
            )

        if stage == "test" or stage is None:
            self.test_dataset = CIFAR10(
                self.hparams.data_dir,
                train=False,
                transform=transform
            )

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            persistent_workers=True,  # エポック間でワーカーを保持
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            persistent_workers=True,
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
        )

3. Trainer

Trainerはトレーニングループ、ハードウェア管理、ロギングを自動化します。

from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger

# Trainerを作成
trainer = Trainer(
    # ハードウェア
    accelerator="auto",  # 自動検出: GPU、CPU、TPU、MPS
    devices="auto",  # 利用可能なすべてのデバイスを使用
    precision="16-mixed",  # 混合精度(高速、メモリ効率的)

    # トレーニング期間
    max_epochs=100,
    min_epochs=10,

    # 検証
    check_val_every_n_epoch=1,
    val_check_interval=1.0,  # トレーニングエポックの割合

    # ロギング
    log_every_n_steps=50,
    enable_progress_bar=True,

    # コールバック
    callbacks=[
        ModelCheckpoint(
            monitor="val/loss",
            mode="min",
            save_top_k=3,
            filename="epoch_{epoch:02d}-loss_{val/loss:.4f}",
        ),
        EarlyStopping(
            monitor="val/loss",
            patience=10,
            mode="min",
        ),
        LearningRateMonitor(logging_interval="step"),
    ],

    # ロガー
    logger=WandbLogger(project="my-project", name="experiment-1"),

    # デバッグ
    fast_dev_run=False,  # 1バッチテスト用にTrueに設定
    overfit_batches=0,  # デバッグ用にN個のバッチでオーバーフィット

    # 再現性
    deterministic=True,
    benchmark=False,  # 入力サイズが一定の場合、Trueに設定
)

# トレーニング
model = ImageClassifier()
datamodule = ImageDataModule()
trainer.fit(model, datamodule=datamodule)

# テスト
trainer.test(model, datamodule=datamodule, ckpt_path="best")

PyTorch 2.0統合

PyTorch 2.0のtorch.compileは、グラフコンパイルを通じて大幅なスピードアップ(平均40%以上)を提供します。

LightningでtorchコンパイルAPIを使用

方法1: モデル全体をコンパイル

class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.compile(
            YourModel(),
            mode="default"  # または "reduce-overhead"、"max-autotune"
        )

方法2: Trainerで設定(推奨)

model = MyModel()

# 自動的にコンパイル
trainer = Trainer(max_epochs=10)
compiled_model = torch.compile(model, mode="default")
trainer.fit(compiled_model, datamodule=dm)

torch.compileモード

モード最適化レベルコンパイル時間ユースケース
defaultバランス型高速開発、一般的な用途
reduce-overheadカーネル起動のオーバーヘッド最小化中程度小バッチサイズ、CPUボトルネック
max-autotune最大パフォーマンス低速本番環境、長時間トレーニング

パフォーマンス例:

import torch

# 標準モデル
model = MyModel()
# コンパイルにより平均40%高速化
compiled_model = torch.compile(model, mode="max-autotune")

コンパイルベストプラクティス

やるべき事:

  • 開発時にはmode="default"を使用
  • 本番環境ではmode="max-autotune"を使用
  • ボトルネック特定のためにプロファイル実行
  • モデルアーキテクチャを静的に保つ(動的シェイプなし)

やるべきでない事:

  • 高度に動的なモデル(可変長RNN)でコンパイル
  • CPU上でのスピードアップを期待(torch.compileはGPU中心)
  • コンパイルされたモジュールとされていないモジュールを混在

グラフブレーク(パフォーマンスの問題)

グラフブレークは、PyTorchがセクションをコンパイルできない場合に発生します:

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)

    # 回避: Pythonの制御フローはグラフを破壊
    if batch_idx % 100 == 0:
        print(f"Batch {batch_idx}")  # コンパイルを破壊

    loss = F.cross_entropy(y_hat, y)
    return loss

グラフブレークをチェック:

import torch._dynamo as dynamo

# リセットしてロギングを有効化
dynamo.reset()
dynamo.config.verbose = True

model = torch.compile(model, mode="default")
# グラフブレークの位置を示す警告が表示されます

分散トレーニング

Lightningは分散トレーニングを簡単にします - 1つの引数を変更するだけです。

DDP(Distributed Data Parallel)

標準マルチGPUトレーニング:

# 単一GPU
trainer = Trainer(accelerator="gpu", devices=1)

# マルチGPU(自動DDP)
trainer = Trainer(accelerator="gpu", devices=4, strategy="ddp")

# すべてのGPU
trainer = Trainer(accelerator="gpu", devices="auto", strategy="ddp")

DDP spawn(Windows互換性):

trainer = Trainer(accelerator="gpu", devices=4, strategy="ddp_spawn")

FSDP(Fully Sharded Data Parallel)

単一GPUメモリに収まらないモデル用:

from lightning.pytorch.strategies import FSDPStrategy

trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy=FSDPStrategy(
        sharding_strategy="FULL_SHARD",  # パラメータ、勾配、オプティマイザをシャード
        auto_wrap_policy={nn.Linear},  # 線形層を自動ラップ
    ),
)

DeepSpeed

極度に大規模なモデル用(数十億パラメータ):

from lightning.pytorch.strategies import DeepSpeedStrategy

trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy=DeepSpeedStrategy(
        stage=3,  # ZeRO Stage 3(最もメモリ効率的)
        offload_optimizer=True,  # オプティマイザをCPUにオフロード
        offload_parameters=True,  # パラメータをCPUにオフロード
    ),
    precision="16-mixed",
)

DeepSpeed設定ファイル:

{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    }
  },
  "fp16": {
    "enabled": true
  }
}

Lightning Fabric

FabricはLightningの軽量抽象化 - Trainerより高い制御性、生のPyTorchよりボイラープレートが少ない。

Fabricを使用する場合

  • Lightning特性を備えたカスタムトレーニングループ
  • PyTorchから段階的に移行
  • きめ細かい制御を必要とする研究

例:

import lightning as L
from lightning.fabric import Fabric

# Fabricを初期化
fabric = L.Fabric(
    accelerator="cuda",
    devices=2,
    precision="16-mixed"
)
fabric.launch()

# モデルとオプティマイザをセットアップ
model = YourModel()
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

# データをセットアップ
train_loader = fabric.setup_dataloaders(train_loader)

# カスタムトレーニングループ
for epoch in range(epochs):
    for batch in train_loader:
        optimizer.zero_grad()
        loss = model(batch)
        fabric.backward(loss)
        optimizer.step()

        # 自動ロギング
        fabric.log("train_loss", loss)

高度なパターン

複数オプティマイザ

def configure_optimizers(self):
    # 異なる部分に対して異なる学習率
    opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.001)
    opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001)
    return [opt_g, opt_d], []

手動オプティマイザ

class GANModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # 自動を無効化

    def training_step(self, batch, batch_idx):
        opt_g, opt_d = self.optimizers()

        # ジェネレータを訓練
        loss_g = self.generator_loss(batch)
        opt_g.zero_grad()
        self.manual_backward(loss_g)
        opt_g.step()

        # ディスクリミネータを訓練
        loss_d = self.discriminator_loss(batch)
        opt_d.zero_grad()
        self.manual_backward(loss_d)
        opt_d.step()

        self.log_dict({"loss_g": loss_g, "loss_d": loss_d})

勾配累積

# 有効バッチサイズ = batch_size * accumulate_grad_batches
trainer = Trainer(
    accumulate_grad_batches=4,  # 更新前に4バッチを累積
)

勾配クリッピング

trainer = Trainer(
    gradient_clip_val=1.0,  # 勾配を最大ノルム1.0にクリップ
    gradient_clip_algorithm="norm",  # または "value"
)

コールバック

組み込みコールバック

from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
    RichProgressBar,
    ModelSummary,
    TQDMProgressBar,
)

callbacks = [
    # 最良モデルを保存
    ModelCheckpoint(
        monitor="val/loss",
        mode="min",
        save_top_k=3,
        filename="best-{epoch:02d}-{val_loss:.2f}",
    ),

    # 早期停止
    EarlyStopping(
        monitor="val/loss",
        patience=10,
        mode="min",
        verbose=True,
    ),

    # 学習率をログ
    LearningRateMonitor(logging_interval="step"),

    # Rich進捗バー
    RichProgressBar(),

    # モデルサマリー
    ModelSummary(max_depth=2),
]

カスタムコールバック

from lightning.pytorch.callbacks import Callback

class PrintCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("トレーニング開始!")

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        train_loss = trainer.callback_metrics.get("train/loss")
        val_loss = trainer.callback_metrics.get("val/loss")
        print(f"エポック {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

    def on_validation_end(self, trainer, pl_module):
        # カスタムアーティファクトを保存
        pass

ロギング

複数ロガー

from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger

wandb_logger = WandbLogger(project="my-project", name="run-1")
tb_logger = TensorBoardLogger("logs/", name="my_model")

trainer = Trainer(logger=[wandb_logger, tb_logger])

高度なロギング

def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)

    # スカラーをログ
    self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)

    # 複数メトリクスをログ
    self.log_dict({
        "train/loss": loss,
        "train/acc": acc,
        "train/f1": f1,
    }, on_epoch=True)

    # ヒストグラムをログ(TensorBoard/W&B用)
    if batch_idx % 100 == 0:
        self.logger.experiment.add_histogram(
            "gradients/layer1",
            self.model.layer1.weight.grad,
            self.global_step
        )

    return loss

ベストプラクティス

✅ やるべき事

  1. __init__で常にself.save_hyperparameters()を呼び出す - 再現性のため
  2. DataModuleを使用 - すべてのデータロジックをカプセル化
  3. self.log()でログ - 自動集約と同期のため
  4. コールバックを使用 - チェックポイント、早期停止、監視
  5. 混合精度を有効化 - precision="16-mixed"でスピードアップ
  6. DataLoaderでpin_memory=Truepersistent_workers=Trueを使用
  7. 再現性のためdeterministic=Trueを設定
  8. クイックサニティチェック用にfast_dev_run=Trueを使用
  9. TorchMetricsを使用 - 効率的なメトリクス計算
  10. PyTorch 2.0以降ではモデルをtorch.compileでコンパイル

❌ やるべきでない事

  1. training_stepからロスを返すのを忘れない
  2. prepare_data()で状態を設定しない - 代わりにsetup()を使用
  3. .to(device)でテンソルを手動で移動しない - Lightningが処理
  4. ロギング用にprint()を使用しない - self.log()を使用
  5. ハイパーパラメータをハードコードしない - self.hparamsを使用
  6. 分散メトリクスでsync_dist=Trueを無視しない
  7. LightningのトレーニングループとPyTorchの生のループを混在させない

一般的な問題と解決方法

問題1: NaNロス

# 解決策1: 勾配クリッピング
trainer = Trainer(gradient_clip_val=1.0)

# 解決策2: 学習率を下げる
optimizer = torch.optim.Adam(params, lr=1e-4)  # 1e-3の代わり

# 解決策3: 完全精度
trainer = Trainer(precision=32)  # 16-mixedの代わり

問題2: メモリ不足

# 解決策1: バッチサイズを減らす
datamodule = MyDataModule(batch_size=16)  # 32の代わり

# 解決策2: 勾配累積
trainer = Trainer(accumulate_grad_batches=4)

# 解決策3: 混合精度
trainer = Trainer(precision="16-mixed")

問題3: トレーニングが遅い

# 解決策1: モデルをコンパイル(PyTorch 2.0以降)
model = torch.compile(model, mode="max-autotune")

# 解決策2: ボトルネックをプロファイル
trainer = Trainer(profiler="simple")  # または "advanced"

# 解決策3: num_workersを増やす
datamodule = MyDataModule(num_workers=8)  # CPUコア数に合わせる

必須リソース

公式ドキュメント

チュートリアル

PyTorch 2.0

コミュニティ


クイックリファレンス

最小限の動作例:

import lightning as L
import torch

class MinimalModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

# トレーニング
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, train_dataloader)

まとめ

PyTorch Lightningは以下を提供します:

  • シンプルさ: ボイラープレートを削減、研究に集中
  • スケーラビリティ: ノートパソコンからスーパーコンピュータまで、1行で対応
  • 速度: PyTorch 2.0統合、自動最適化
  • 柔軟性: 必要なときはPyTorchの完全な制御が可能
  • 本番対応: 初日からデプロイメント準備完了

PyTorch 2.0のtorch.compileと組み合わせることで、Lightningは最小限のコードで最大のパフォーマンスを実現します。

ライセンス: MIT(寛容ライセンスのため全文を引用しています) · 原本リポジトリ

詳細情報

作者
nishide-dev
リポジトリ
nishide-dev/claude-code-ml-research
ライセンス
MIT
最終更新
2026/4/6

Source: https://github.com/nishide-dev/claude-code-ml-research / ライセンス: MIT

本サイトは GitHub 上で公開されているオープンソースの SKILL.md ファイルをクロール・インデックス化したものです。 各スキルの著作権は原作者に帰属します。掲載に問題がある場合は info@alsel.co.jp または /takedown フォームよりご連絡ください。
原作者: nishide-dev · nishide-dev/claude-code-ml-research · ライセンス: MIT