mint-cot-visual-reasoning
Interleave Tokensを活用して、数学的推論の各ステップで関連する画像領域を動的に選択し、細粒度のビジュアルトークンを数学的推論に統合できます。
description の原文を見る
Integrates fine-grained visual tokens into mathematical reasoning via Interleave Tokens that dynamically select relevant image regions for each reasoning step.
SKILL.md 本文
MINT-CoT: 数学推論における インターリーブされたビジュアルトークン
コアコンセプト
図を含む数学推論には、テキストの推論ステップとビジュアル領域の間の正確なアラインメントが必要です。既存のアプローチは粗いバウンディングボックスを使用しており、幾何学と図の解釈に不可欠な細粒度のビジュアル理解を制限しています。MINT-CoTは、デコーダ状態とビジュアルトークン間の類似度スコアを計算することで、推論中に非矩形画像領域の動的選択を可能にするInterleave Tokensを導入します。トークンレベルのアラインメントを含む54Kのデータセットと3段階の段階的なトレーニングにより、数学ベンチマークにおいて大幅な改善が実現します。
アーキテクチャ概要
- Interleave Tokens: デコーダの隠れ状態とビジュアルトークン埋め込み間の類似度を計算して関連するビジュアル領域を選択する特殊トークン
- 細粒度選択: 非矩形領域の選択を可能にし、任意の形状で図の要素をキャプチャします
- MINT-CoTデータセット: 推論ステップと画像領域間のトークンレベルのアラインメントを含む54Kのアノテーション済み問題
- 3段階トレーニング: テキストのみのCoT → インターリーブされたCoT教師あり → インターリーブされたCoT強化学習
- 自動アノテーション: 効率的なデータセット構築のための4ステップパイプライン(グリッディング、OCR、キーワード抽出、アラインメント)
- GRPO統合: 強化学習フェーズで推論品質をエンドツーエンドで最適化
実装
以下のコードはInterleave Tokenメカニズムとトレーニングパイプラインを示しています:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional
class InterleaveToken(nn.Module):
"""
Special token that selects relevant visual regions during reasoning.
"""
def __init__(self, hidden_dim: int, num_visual_tokens: int):
super().__init__()
self.hidden_dim = hidden_dim
self.num_visual_tokens = num_visual_tokens
# Projections for similarity computation
self.query_proj = nn.Linear(hidden_dim, hidden_dim)
self.visual_key_proj = nn.Linear(hidden_dim, hidden_dim)
def forward(self, decoder_state: torch.Tensor,
visual_tokens: torch.Tensor,
threshold: float = 0.5) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Select visual regions by computing similarity with decoder state.
decoder_state: (hidden_dim,) hidden state before this reasoning step
visual_tokens: (num_visual_tokens, hidden_dim) image patches encoded
threshold: similarity threshold for region selection
Returns: (selected_regions, selection_mask)
"""
# Project for similarity computation
query = self.query_proj(decoder_state) # (hidden_dim,)
keys = self.visual_key_proj(visual_tokens) # (num_visual_tokens, hidden_dim)
# Compute cosine similarity
query_norm = F.normalize(query, p=2, dim=-1)
keys_norm = F.normalize(keys, p=2, dim=-1)
similarity = torch.matmul(keys_norm, query_norm) # (num_visual_tokens,)
# Soft selection: apply softmax for differentiable selection
selection_weights = F.softmax(similarity * 10.0, dim=0) # Temperature=10
# Hard threshold: which regions to include
selection_mask = (similarity > threshold).float()
# Weighted combination of selected visual tokens
selected_regions = torch.matmul(
selection_weights.unsqueeze(0), visual_tokens
) # (1, hidden_dim)
return selected_regions, selection_mask
class MINTCoTModel(nn.Module):
"""
Multimodal model with interleaved visual tokens for math reasoning.
"""
def __init__(self, hidden_dim: int = 4096, vocab_size: int = 32000,
num_visual_tokens: int = 256):
super().__init__()
self.hidden_dim = hidden_dim
self.vocab_size = vocab_size
# Language model components
self.embedding = nn.Embedding(vocab_size, hidden_dim)
self.transformer = nn.TransformerDecoder(
nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=8, dim_feedforward=4*hidden_dim,
batch_first=True),
num_layers=12
)
self.output_proj = nn.Linear(hidden_dim, vocab_size)
# Vision components
self.vision_encoder = nn.Identity() # In practice, use CLIP or similar
self.interleave_token = InterleaveToken(hidden_dim, num_visual_tokens)
def encode_image_to_tokens(self, image: torch.Tensor) -> torch.Tensor:
"""
Encode image into visual tokens (e.g., patch embeddings).
image: (B, 3, H, W)
Returns: (B, num_visual_tokens, hidden_dim)
"""
# Vision encoder produces patch embeddings
# Simplified: assuming fixed 16x16 = 256 patches
visual_tokens = self.vision_encoder(image) # (B, num_patches, hidden_dim)
return visual_tokens
def forward_with_interleave(self, input_ids: torch.Tensor,
image: torch.Tensor,
visual_token_mask: torch.Tensor) -> torch.Tensor:
"""
Forward pass with interleaved visual token selection.
input_ids: (batch, seq_len) token sequence
image: (batch, 3, H, W) input image
visual_token_mask: (batch, seq_len, num_visual_tokens) which tokens trigger visual selection
Returns: (batch, seq_len, vocab_size) logits
"""
batch_size, seq_len = input_ids.shape
# Encode image
visual_tokens = self.encode_image_to_tokens(image) # (batch, num_visual_tokens, hidden_dim)
# Embed text tokens
text_embeds = self.embedding(input_ids) # (batch, seq_len, hidden_dim)
# Interleave visual tokens where needed
interleaved_embeds = text_embeds.clone()
for pos in range(seq_len):
if visual_token_mask[:, pos].any():
# This position triggers visual selection
decoder_state = text_embeds[:, pos] # (batch, hidden_dim)
# For each batch item, select visual regions
for b in range(batch_size):
if visual_token_mask[b, pos]:
selected_regions, _ = self.interleave_token(
decoder_state[b], visual_tokens[b]
)
# Blend selected visual information with text embedding
interleaved_embeds[b, pos] = 0.7 * text_embeds[b, pos] + 0.3 * selected_regions.squeeze(0)
# Transformer forward
output = self.transformer(interleaved_embeds, memory=None)
logits = self.output_proj(output)
return logits
class MINTCoTDataset:
"""
Dataset construction pipeline for MINT-CoT.
"""
def __init__(self):
self.grid_size = 16 # 16x16 grid
def grid_image(self, image: torch.Tensor) -> List[Tuple[int, int, int, int]]:
"""
Divide image into grid cells.
Returns: list of (x1, y1, x2, y2) coordinates
"""
_, h, w = image.shape
cell_h = h // self.grid_size
cell_w = w // self.grid_size
grid_cells = []
for i in range(self.grid_size):
for j in range(self.grid_size):
x1, y1 = j * cell_w, i * cell_h
x2, y2 = x1 + cell_w, y1 + cell_h
grid_cells.append((x1, y1, x2, y2))
return grid_cells
def extract_keywords_from_step(self, reasoning_step: str) -> List[str]:
"""
Extract keywords from a reasoning step using GPT-4o.
In practice, use language model API.
"""
# Simplified: would call GPT-4o in real implementation
keywords = reasoning_step.split()[:3] # Placeholder
return keywords
def annotate_visual_regions(self, image: torch.Tensor,
reasoning_step: str) -> List[int]:
"""
Map reasoning step keywords to image grid cells.
Returns: list of grid cell indices relevant to this step.
"""
keywords = self.extract_keywords_from_step(reasoning_step)
grid_cells = self.grid_image(image)
# In practice, use OCR to locate keyword positions in image
# For now, return placeholder annotations
relevant_cells = [0, 1, 16, 17] # Example cells
return relevant_cells
def create_dataset_sample(self, image: torch.Tensor,
problem: str,
reasoning_chain: List[str],
answer: str) -> dict:
"""
Create single dataset sample with token-level visual annotations.
"""
# Tokenize problem + reasoning chain
full_text = problem + " " + " ".join(reasoning_chain) + " " + answer
tokens = full_text.split() # Simplified tokenization
# Annotate which tokens trigger visual selection
visual_token_mask = [0] * len(tokens)
for step_idx, step in enumerate(reasoning_chain):
step_keywords = self.extract_keywords_from_step(step)
# Find token positions for this step
step_start = sum(len(r.split()) for r in reasoning_chain[:step_idx])
step_end = step_start + len(step.split())
# Mark these tokens as visual
for pos in range(step_start, min(step_end, len(visual_token_mask))):
visual_token_mask[pos] = 1
return {
'image': image,
'tokens': tokens,
'visual_mask': visual_token_mask,
'answer': answer
}
class MINTCoTTrainer:
"""
Three-stage training for MINT-CoT.
"""
def __init__(self, model: MINTCoTModel):
self.model = model
self.optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
def stage1_text_cot_sft(self, dataset: List[dict], epochs: int = 3):
"""Stage 1: Supervised fine-tuning on text-only chain-of-thought."""
for epoch in range(epochs):
for sample in dataset:
# Train on text tokens only
logits = self.model(torch.tensor(sample['tokens']),
sample['image'],
torch.zeros_like(sample['visual_mask']))
loss = F.cross_entropy(logits.view(-1, self.model.vocab_size),
torch.tensor(sample['tokens']))
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
def stage2_interleaved_cot_sft(self, dataset: List[dict], epochs: int = 3):
"""Stage 2: Supervised fine-tuning with interleaved visual tokens."""
for epoch in range(epochs):
for sample in dataset:
# Include visual selection mask
visual_mask = torch.tensor(sample['visual_mask']).unsqueeze(0)
logits = self.model(torch.tensor(sample['tokens']).unsqueeze(0),
sample['image'].unsqueeze(0),
visual_mask)
loss = F.cross_entropy(logits.view(-1, self.model.vocab_size),
torch.tensor(sample['tokens']))
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
def stage3_interleaved_grpo(self, dataset: List[dict], num_iterations: int = 100):
"""Stage 3: Group Relative Policy Optimization for reasoning quality."""
for _ in range(num_iterations):
# GRPO: generate multiple reasoning chains, rank by correctness
# Optimize for better reasoning while using visual tokens
pass # Implementation details omitted for brevity
実装ガイダンス
グリッドサイズの選択: 16×16グリッド(256セル)は図の理解のための適切な粒度を提供します。細かい詳細を含む高解像度画像の場合は、32×32に増やしてください。
類似度閾値: Interleave Token選択で50%の類似度閾値として0.5に設定します。低い閾値(0.3)は領域選択を増やし、高い閾値(0.7)はより選別的な選択にします。
キーワード抽出: 推論ステップから正確なキーワード抽出を行うにはGPT-4oを使用します。別の方法として、特定の数学領域用にドメイン固有のキーワードリストを使用できます。
アノテーション品質: データセット品質が重要です。トレーニング前に、OCR精度と推論ステップおよび画像領域間の適切なアラインメントを確認してください。
トレーニングスケジュール: 3段階の進行を厳密に従ってください。各段階は前の段階に基づいており、段階をスキップするとパフォーマンスが低下します。
ビジュアライゼーション: トレーニング中に選択された画像領域をビジュアライズして、Interleave Token選択が推論ステップとアラインメントしていることを確認します。
リファレンス
MINT-CoTは数学推論に大幅な改善を実現します:
- MathVista(数学サブセット): +32.59%の改善
- GeoQA: +26.92%の改善
- 幾何学関連のタスク: 最先端のモデルを上回ります
細粒度のビジュアルアラインメントを含む54Kのアノテーション済みデータセットにより、モデルは図の情報を効果的に活用できます。このアプローチは、幾何学、図ベースの推論、およびビジュアル情報が不可欠な科学的問題解決に特に価値があります。
ライセンス: MIT(寛容ライセンスのため全文を引用しています) · 原本リポジトリ
詳細情報
- 作者
- ADu2021
- リポジトリ
- ADu2021/skillXiv
- ライセンス
- MIT
- 最終更新
- 2026/3/26
Source: https://github.com/ADu2021/skillXiv / ライセンス: MIT
関連スキル
agent-browser
AI エージェント向けのブラウザ自動化 CLI です。ウェブサイトとの対話が必要な場合に使用します。ページ遷移、フォーム入力、ボタンクリック、スクリーンショット取得、データ抽出、ウェブアプリのテスト、ブラウザ操作の自動化など、あらゆるブラウザタスクに対応できます。「ウェブサイトを開く」「フォームに記入する」「ボタンをクリックする」「スクリーンショットを取得する」「ページからデータを抽出する」「このウェブアプリをテストする」「サイトにログインする」「ブラウザ操作を自動化する」といった要求や、プログラマティックなウェブ操作が必要なタスクで起動します。
anyskill
AnySkill — あなたのプライベート・スキルクラウド。GitHubを基盤としたリポジトリからエージェントスキルを管理、同期、動的にロードできます。自然言語でクラウドスキルを検索し、オンデマンドでプロンプトを自動ロード、カスタムスキルのアップロードと共有、スキルバンドルの一括インストールが可能です。OpenClaw、Antigravity、Claude Code、Cursorに対応しています。
engram
AIエージェント向けの永続的なメモリシステムです。バグ修正、意思決定、発見、設定変更の後はmem_saveを使用してください。ユーザーが「覚えている」「記憶している」と言及した場合、または以前のセッションと重複する作業を開始する際はmem_searchを使用します。セッション終了前にmem_session_summaryを使用して、コンテキストを保持してください。
skyvern
AI駆動のブラウザ自動化により、任意のウェブサイトを自動化できます。フォーム入力、データ抽出、ファイルダウンロード、ログイン、複数ステップのワークフロー実行など、ユーザーがウェブサイトと連携する必要があるときに使用します。Skyvernは、LLMとコンピュータビジョンを活用して、未知のサイトも自動操作可能です。Python SDK、TypeScript SDK、REST API、MCPサーバー、またはCLIを通じて統合できます。
pinchbench
PinchBenchベンチマークを実行して、OpenClawエージェントの実世界タスクにおけるパフォーマンスを評価できます。モデルの機能テスト、モデル間の比較、ベンチマーク結果のリーダーボード提出、またはOpenClawのセットアップがカレンダー、メール、リサーチ、コーディング、複数ステップのワークフローにどの程度対応しているかを確認する際に使用します。
openui
OpenUIとOpenUI Langを使用してジェネレーティブUIアプリを構築できます。これらはLLM生成インターフェースのためのトークン効率的なオープン標準です。OpenUI、@openuidev、ジェネレーティブUI、LLMからのストリーミングUI、AI向けコンポーネントライブラリ、またはjson-render/A2UIの置き換えについて述べる際に使用します。スキャフォルディング、defineComponent、システムプロンプト、Renderer、およびOpenUI Lang出力のデバッグに対応しています。