config-knowledge-distillation
教師・生徒モデル間の不一致を活用した拡散ベースのデータ拡張により、共変量シフト下での生徒モデルのロバスト性を向上させます。この手法は、虚偽の特徴量を標的とすることで、モデルの堅牢性を改善します。
description の原文を見る
Improve student model robustness under covariate shift by using diffusion-based augmentation that targets spurious features via teacher-student disagreement.
SKILL.md 本文
ConfiG: Confidence-Guided Data Augmentation for Knowledge Distillation
コアコンセプト
学習データに含まれるテスト時に存在しない疑似特徴がある場合、知識蒸留はこれらのバイアスを学生モデルに保存する可能性があります。ConfiGは教師と学生の不一致を最大化する拡張画像を生成することでこの問題に対処し、学生が学習した疑似相関を正確に対象とします。この拡散ベースのアプローチにより、学生はデータセットのバイアスを克服しながら知識転移を維持できます。
アーキテクチャの概要
- 問題: 知識蒸留はバイアスのあるデータセットからの疑似相関を転移し、未知のグループへの汎化を低下させます
- 解決策: 信頼度ガイド付き拡散は教師-学生の不一致を活用して対抗的な例を生成します
- メカニズム: 潜在変数を最適化して教師の信頼度を最大化しながら学生の信頼度を最小化し、疑似特徴を対象とします
- 理論的根拠: 命題1は信頼度ガイド付き拡張が分布的汎化ギャップを削減することを証明しています
- 相乗効果: モデル中心のバイアス軽減(TABなど)と連携し、データとモデルのアプローチが補完的であることを示します
実装
ステップ1: 共変量シフトと汎化分解を理解する
import torch
import numpy as np
from typing import Tuple, List
class CovariateShiftAnalyzer:
"""Analyze how spurious features affect generalization"""
def decompose_error(self, teacher_model, student_model,
train_data, test_data,
group_labels_test) -> Dict:
"""
Decompose generalization error into two components:
1. Teacher quality: how well teacher generalizes
2. Distributional gap: how much student differs from teacher distribution
Key insight: ConfiG targets reducing the gap, not improving teacher quality.
"""
# Evaluate on training distribution (in-distribution)
train_acc = self.evaluate_accuracy(student_model, train_data)
# Evaluate on test distribution
test_acc = self.evaluate_accuracy(student_model, test_data)
# Evaluate per test group
group_accs = {}
for group_id in np.unique(group_labels_test):
group_mask = group_labels_test == group_id
group_test = test_data[group_mask]
group_accs[group_id] = self.evaluate_accuracy(student_model, group_test)
# Decompose error
overall_gap = train_acc - test_acc
group_gaps = {g: train_acc - acc for g, acc in group_accs.items()}
print("=== Generalization Analysis ===")
print(f"Overall train accuracy: {train_acc:.1%}")
print(f"Overall test accuracy: {test_acc:.1%}")
print(f"Overall gap: {overall_gap:.1%}")
print("\nPer-group performance:")
for group_id, acc in group_accs.items():
print(f" Group {group_id}: {acc:.1%} (gap: {group_gaps[group_id]:.1%})")
return {
'overall_gap': overall_gap,
'group_gaps': group_gaps,
'group_accs': group_accs,
}
def identify_spurious_correlations(self, train_data,
train_labels,
group_labels) -> List[Tuple]:
"""
Identify features that are predictive in training but spurious.
Example: Blond hair → Female in CelebA, but this breaks in real data.
"""
spurious_correlations = []
# Analyze correlations between features and groups
unique_groups = np.unique(group_labels)
for group_id in unique_groups:
group_mask = group_labels == group_id
# Extract feature statistics for this group
group_data = train_data[group_mask]
other_data = train_data[~group_mask]
# Compute feature divergence between groups
# Features with high divergence are likely spurious
feature_divergence = self._compute_feature_divergence(
group_data, other_data
)
# Identify top divergent features
top_divergent = sorted(
enumerate(feature_divergence),
key=lambda x: x[1],
reverse=True
)[:5]
for feature_idx, divergence in top_divergent:
spurious_correlations.append({
'group': group_id,
'feature': feature_idx,
'divergence': divergence,
})
return spurious_correlations
def _compute_feature_divergence(self, group_data, other_data) -> np.ndarray:
"""Measure KL divergence of feature distributions"""
# Simplified: compute histogram divergence per feature
divergences = []
for feature_idx in range(group_data.shape[1]):
# Histogram of feature in group vs out
hist_group = np.histogram(group_data[:, feature_idx], bins=20)[0]
hist_other = np.histogram(other_data[:, feature_idx], bins=20)[0]
# Normalize
hist_group = hist_group / (np.sum(hist_group) + 1e-8)
hist_other = hist_other / (np.sum(hist_other) + 1e-8)
# KL divergence
kl = np.sum(hist_group * np.log((hist_group + 1e-8) / (hist_other + 1e-8)))
divergences.append(kl)
return np.array(divergences)
ステップ2: 信頼度ガイド付き拡張を実装する
import torch.nn.functional as F
class ConfidenceGuidedDiffusion:
"""Generate augmented images targeting spurious features"""
def __init__(self, diffusion_model, teacher_model, student_model):
self.diffusion = diffusion_model # Pretrained diffusion (e.g., Stable Diffusion)
self.teacher = teacher_model
self.student = student_model
def generate_augmented_sample(self, image: torch.Tensor,
label: int,
gamma: float = 2.0,
num_iterations: int = 100) -> torch.Tensor:
"""
Generate augmented image by optimizing latent vector z.
Objective: maximize loss(z) = t(z)^γ + (1-f(z))^γ
where:
t(z) = teacher confidence on augmented sample
f(z) = student confidence on augmented sample
γ = 2.0 (empirically optimal)
This targets spurious features the student learned.
"""
# Initialize random latent in diffusion space
z = torch.randn(1, 4, image.shape[1]//8, image.shape[2]//8)
z.requires_grad = True
# Optimizer for latent variables
optimizer = torch.optim.Adam([z], lr=0.01)
for iteration in range(num_iterations):
# Decode latent to image
augmented_image = self.diffusion.decode(z)
# Get predictions
with torch.no_grad():
teacher_logits = self.teacher(augmented_image)
student_logits = self.student(augmented_image)
teacher_probs = F.softmax(teacher_logits, dim=-1)
student_probs = F.softmax(student_logits, dim=-1)
# Extract confidence for true label
teacher_conf = teacher_probs[0, label] # High = good, preserve label
student_conf = student_probs[0, label] # Low = good, challenge student
# Confidence-guided loss
loss = (teacher_conf ** gamma) + ((1.0 - student_conf) ** gamma)
# Backward step
optimizer.zero_grad()
(-loss).backward() # Maximize by minimizing negative
optimizer.step()
if (iteration + 1) % 20 == 0:
print(f"Iter {iteration + 1}: "
f"teacher_conf={teacher_conf:.3f}, "
f"student_conf={student_conf:.3f}, "
f"loss={loss:.3f}")
# Decode final latent
with torch.no_grad():
augmented = self.diffusion.decode(z)
return augmented
def generate_augmented_dataset(self, train_images: torch.Tensor,
train_labels: torch.Tensor,
augmentation_ratio: float = 0.5) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generate augmented samples for a subset of training data.
Focus on misclassified or uncertain samples.
"""
# Identify samples where student is uncertain
with torch.no_grad():
student_logits = self.student(train_images)
student_probs = F.softmax(student_logits, dim=-1)
student_confidence = torch.max(student_probs, dim=-1)[0]
# Select samples with low confidence (more need augmentation)
num_to_augment = int(len(train_images) * augmentation_ratio)
uncertain_indices = torch.argsort(student_confidence)[:num_to_augment]
# Generate augmentations
augmented_images = []
augmented_labels = []
for idx in uncertain_indices:
image = train_images[idx].unsqueeze(0)
label = train_labels[idx].item()
print(f"Generating augmentation {len(augmented_images) + 1}/{num_to_augment}")
augmented = self.generate_augmented_sample(image, label)
augmented_images.append(augmented)
augmented_labels.append(label)
# Concatenate original and augmented
all_images = torch.cat([train_images] + augmented_images, dim=0)
all_labels = torch.cat([
train_labels,
torch.tensor(augmented_labels, device=train_labels.device)
], dim=0)
return all_images, all_labels
ステップ3: ConfiGを使った知識蒸留を実装する
class KDWithConfiG:
"""Knowledge distillation enhanced with confidence-guided augmentation"""
def __init__(self, teacher_model, student_model,
diffusion_model, temperature: float = 4.0):
self.teacher = teacher_model
self.student = student_model
self.diffusion = diffusion_model
self.temperature = temperature
self.confidence_aug = ConfidenceGuidedDiffusion(
diffusion_model, teacher_model, student_model
)
self.optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4)
def distillation_loss(self, student_logits: torch.Tensor,
teacher_logits: torch.Tensor) -> torch.Tensor:
"""KL divergence between student and teacher distributions"""
student_probs = F.softmax(student_logits / self.temperature, dim=-1)
teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
kl_loss = F.kl_div(
torch.log(student_probs + 1e-8),
teacher_probs.detach(),
reduction='batchmean'
)
return kl_loss * (self.temperature ** 2)
def training_step(self, batch_images: torch.Tensor,
batch_labels: torch.Tensor) -> float:
"""Single training step on original + augmented data"""
# Forward pass
student_logits = self.student(batch_images)
with torch.no_grad():
teacher_logits = self.teacher(batch_images)
# Compute distillation loss
loss = self.distillation_loss(student_logits, teacher_logits)
# Backward
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
def train_with_augmentation(self, train_images: torch.Tensor,
train_labels: torch.Tensor,
num_epochs: int = 10) -> Dict:
"""Train student with periodic augmentation"""
history = {'loss': []}
for epoch in range(num_epochs):
print(f"\n=== Epoch {epoch + 1}/{num_epochs} ===")
# Original training step
loss = self.training_step(train_images, train_labels)
history['loss'].append(loss)
print(f"Original data loss: {loss:.4f}")
# Every N epochs: generate augmentations and finetune
if (epoch + 1) % 3 == 0:
print("Generating confidence-guided augmentations...")
aug_images, aug_labels = self.confidence_aug.generate_augmented_dataset(
train_images, train_labels, augmentation_ratio=0.3
)
print(f"Augmented dataset size: {len(aug_images)}")
# Fine-tune on augmented data
for aug_epoch in range(2):
aug_loss = self.training_step(aug_images, aug_labels)
print(f" Augmented epoch {aug_epoch + 1}: loss={aug_loss:.4f}")
return history
ステップ4: モデル中心のバイアス軽減との統合
class HybridBiasMitigation:
"""Combine data-centric (ConfiG) and model-centric (TAB) approaches"""
def __init__(self, teacher_model, student_model, diffusion_model):
self.kd_config = KDWithConfiG(
teacher_model, student_model, diffusion_model
)
# Model-centric: Train-aware batch normalization (TAB)
self.model_centric = TrainAwareBatchNorm(student_model)
def train_hybrid(self, train_images: torch.Tensor,
train_labels: torch.Tensor,
group_labels: torch.Tensor) -> Dict:
"""
Data-centric: ConfiG augmentation targeting spurious features
Model-centric: TAB encouraging invariant features
"""
print("Starting hybrid bias mitigation training...")
print("Data-centric: Confidence-guided augmentation")
print("Model-centric: Train-aware batch normalization")
# Phase 1: Data-centric augmentation
kd_history = self.kd_config.train_with_augmentation(
train_images, train_labels, num_epochs=10
)
# Phase 2: Model-centric refinement with TAB
tab_history = self.model_centric.train_with_tab(
train_images, train_labels, group_labels, num_epochs=5
)
return {
'kd_history': kd_history,
'tab_history': tab_history,
}
実践的なガイダンス
-
疑似相関を特定する: まず学習データを分析して、ラベルと相関しているが、テストデータでは存在する可能性が低い特徴(例: 背景、照明、特定のオブジェクト)を見つけます。
-
教師の品質が重要: ConfiGはロバストな教師モデルを想定しています。教師がバイアスを持っている場合、拡張は役に立ちません。十分に学習された教師または合成データを使用してください。
-
ガンマパラメータ: gamma=2.0が経験的に最適です。より高いγは高い不一致のサンプルに学習を集中させ、より低いγはより広く分散させます。
-
拡張比率: 学習データの30〜50%を拡張し、不確実なサンプルに焦点を当てます。過度な拡張は分布内パフォーマンスを低下させる可能性があります。
-
ハイブリッドアプローチが最適: データ中心(ConfiG)拡張とモデル中心のアプローチ(TAB、グループ正規化)を組み合わせます。これらは補完的であり、共に使用するとより良い結果を達成します。
-
計算コスト: 拡散ベースの拡張は高コスト(画像あたり100イテレーション以上)です。拡張データセットをバッチ前処理し、オンラインで実行しないようにしてください。
参考資料
- 論文: ConfiG (2506.02294)
- 主要な革新: 疑似特徴を対象とした信頼度ガイド付き拡散
- アーキテクチャ: 教師-学生の不一致→拡張目的
- データセット: CelebA、SpuCo Birds、Spurious ImageNet
- 結果: 共変量シフト下での従来の拡張方法と比較して優れたパフォーマンス
ライセンス: MIT(寛容ライセンスのため全文を引用しています) · 原本リポジトリ
詳細情報
- 作者
- ADu2021
- リポジトリ
- ADu2021/skillXiv
- ライセンス
- MIT
- 最終更新
- 2026/3/26
Source: https://github.com/ADu2021/skillXiv / ライセンス: MIT
関連スキル
hugging-face-trackio
Trackioを使用してMLトレーニング実験を追跡・可視化できます。トレーニング中のメトリクスログ記録(Python API)、トレーニング診断のアラート発火、ログされたメトリクスの取得・分析(CLI)が必要な場合に活用してください。リアルタイムダッシュボード表示、Webhookを使用したアラート、HF Space同期、自動化向けのJSON出力に対応しています。
btc-bottom-model
ビットコインのサイクルタイミングモデルで、加重スコアリングシステムを搭載しています。日次パルス(4指標、32ポイント)とウィークリー構造(9指標、68ポイント)の2カテゴリーにわたる13の指標を追跡し、0~100のマーケットヒートスコアを算出します。ETFフロー、ファンディングレート、ロング/ショート比率、恐怖・貪欲指数、LTH-MVRV、NUPL、SOPR(LTH+STH)、LTH供給率、移動平均倍率(365日MA、200週MA)、週次RSI、出来高トレンドに対応します。市場サイクル全体を通じて買いと売りの両方の推奨を提供します。ビットコインの底値拾い、BTCサイクルポジション、買い時・売り時、オンチェーン指標、MVRV、NUPL、SOPR、LTH動向、ETFの流出入、ファンディングレート、恐怖指数、ビットコインが過熱状態か、マイナーコスト、暗号資産市場のセンチメント、BTCのポジションサイジング、「今ビットコインを買うべきか」「BTCが天井をつけているか」「オンチェーン指標は何を示しているか」といった質問の際にこのスキルを活用します。
protein_solubility_optimization
タンパク質の溶解性最適化 - タンパク質の溶解性を最適化します。タンパク質の特性を計算し、溶解性と親水性を予測し、有効な変異を提案します。タンパク質配列の特性計算、タンパク質機能の予測、親水性計算、ゼロショット配列予測を含むタンパク質エンジニアリング業務に使用できます。3つのSCPサーバーから4つのツールを統合しています。
research-lookup
Parallel Chat APIまたはPerplexity sonar-pro-searchを使用して、最新の研究情報を検索できます。学術論文の検索にも対応しています。クエリは自動的に最適なバックエンドにルーティングされるため、論文の検索、研究データの収集、科学情報の検証に活用できます。
tree-formatting
ggtree(R)またはiTOL(ウェブ)を使用して、系統樹の可視化とフォーマットを行います。系統樹を図として描画する際、ツリーレイアウトの選択、分類学に基づく枝やラベルの色付け、クレードの折りたたみ、サポート値の表示、またはツリーへのオーバーレイ追加が必要な場合に使用してください。系統推定(protein-phylogenyスキルを使用)やドメイン注釈(今後の独立したスキル)には使用しないでください。
querying-indonesian-gov-data
インドネシア政府の50以上のAPIとデータソースに接続できます。BPJPH(ハラール認証)、BOM(食品安全)、OJK(金融適正性)、BPS(統計)、BMKG(気象・地震)、インドネシア中央銀行(為替レート)、IDX(株式)、CKAN公開データポータル、pasal.id(第三者法MCP)に対応しています。インドネシア政府データを活用したアプリ開発、.go.idウェブサイトのスクレイピング、ハラール認証の確認、企業の法的適正性の検証、金融機関ステータスの照会、またはインドネシアMCPサーバーへの接続時に使用できます。CSRF処理、CKAN API使用方法、IP制限回避など、すぐに実行可能なPythonパターンを含んでいます。