vocabtrim-speculative-decoding
高頻度トークンに絞られたドラフタ語彙によって、推測デコーディングを高速化します。未使用の語彙エントリを削除することで、再学習なしにメモリ制約環境で16%の高速化を実現できます。
description の原文を見る
Accelerate speculative decoding by pruning drafter vocabulary to high-frequency tokens. Achieves 16% speedup in memory-bound settings by eliminating unused vocabulary entries without retraining.
SKILL.md 本文
VocabTrim: メモリ効率的な推測デコーディング用語彙削減
推測デコーディングは、小規模なドラフタモデルを使用して推論ステップごとに複数のトークンを提案し、ターゲット検証モデルがそれらを受け入れるか拒否するかの判定を行います。ドラフタが高速な場合、このアプローチは推論を2~3倍高速化できます。しかし、ドラフタの言語モデリングヘッド(すべての語彙トークンにわたるロジットを出力する最終層)がメモリボトルネックになります。128Kの語彙トークンを持つLlama-3の場合、毎ステップで128K個すべてのトークンにわたるロジット計算は、ドラフタが頻繁に出現するトークンのごく一部からのみサンプリングしているにもかかわらず、メモリと計算を浪費します。
VocabTrimは、ドラフタが推論中に実際にサンプリングする高頻度トークンのみを含むようにドラフタの語彙を再構築することでこの問題を解決します。ドラフタは「予測しやすい」トークン(一般的な単語、句読点)に偏り、レアトークンをめったにサンプリングしないという観点から出発しています。語彙を最頻出の25~50Kトークンに削減することで、受け入れ率への影響がほぼ無視できる範囲で、LMヘッド計算の60~75%を排除できます。
コアコンセプト
VocabTrimは単純な原則に基づいています:ドラフタの完全な語彙を、それが頻繁に生成するトークンのみを含む小さな語彙に置き換える。重要な洞察は以下の通りです:
- ドラフタは本来、高頻度で予測可能なトークンに焦点を当てます。なぜなら高速かつ正確であろうとするためです
- ターゲットモデルは検証時に完全な語彙にアクセスできるため、ドラフタのカバレッジが完全である必要はありません
- トークンのリマッピングにより、生成時にドラフタの出力を完全な語彙インデックスに変換できます
- 本手法は訓練不要です:キャリブレーションデータからトークン周波数を測定し、LMヘッドを再構築するだけです
結果として得られるのは、ステップごとのロジット数が削減される(メモリ転送量が減少)より小規模なドラフタで、拒否率の増加はわずか1~3%に留まりながら、行列乗算も高速化されます。
アーキテクチャ概要
VocabTrimはドラフタの言語モデリングヘッドのみを変更します:
- トークン周波数測定:キャリブレーションデータでドラフタを実行し、生成時にサンプリングされるトークンをカウント
- 語彙選択:キャリブレーションセット全体でサンプリングされる最頻出K個のトークンを抽出
- LMヘッド再構築:選択されたトークンインデックスのみの重みを含む新しい小規模LMヘッド行列を作成
- トークンリマッピング:推論時に、ドラフタの予測値(小規模語彙内のインデックス)を検証器用の完全な語彙インデックスに変換
- 検証器は変更なし:ターゲットモデルは完全な語彙インデックスを受け取り、変更なしで動作
実装
ステップ1:キャリブレーションデータでトークン周波数を測定
代表的なデータでドラフタからサンプリングし、出力分布で最も頻繁に出現するトークンをカウントします。
def measure_drafter_token_frequencies(drafter, tokenizer, calibration_data,
num_samples=100000, sample_length=256):
"""
Run the drafter on calibration data and measure which tokens
are sampled most frequently. This identifies which tokens matter for inference speed.
"""
token_counts = {}
total_tokens_sampled = 0
for batch_idx, batch in enumerate(calibration_data):
input_ids = tokenizer.encode(batch, return_tensors='pt').to(device)
# Generate from drafter, collecting all sampled tokens
with torch.no_grad():
for step in range(sample_length):
outputs = drafter(input_ids)
logits = outputs.logits[:, -1, :] # Last position logits
# Sample from the distribution (how the drafter would generate in practice)
probs = torch.softmax(logits, dim=-1)
sampled_tokens = torch.multinomial(probs, num_samples=1)
# Count token occurrences
for token_id in sampled_tokens.flatten().tolist():
token_counts[token_id] = token_counts.get(token_id, 0) + 1
total_tokens_sampled += 1
input_ids = torch.cat([input_ids, sampled_tokens], dim=1)
if (batch_idx + 1) * sample_length >= num_samples:
break
# Sort by frequency
sorted_tokens = sorted(token_counts.items(),
key=lambda x: x[1], reverse=True)
return dict(sorted_tokens)
ステップ2:語彙を選択してリマッピングを作成
最頻出K個のトークンを選択し、小規模語彙から完全な語彙へのインデックスマッピングを作成します。
def create_vocab_mapping(token_frequencies, target_vocab_size=32000):
"""
Select the most frequent tokens and create a mapping.
This tells us which full vocabulary indices to keep.
"""
# Select top-K tokens by frequency
top_tokens = sorted(token_frequencies.items(),
key=lambda x: x[1], reverse=True)[:target_vocab_size]
# Create two mappings:
# 1. full_to_trimmed: maps full vocabulary indices to trimmed indices
# 2. trimmed_to_full: maps trimmed indices back to full vocabulary
full_to_trimmed = {}
trimmed_to_full = {}
for trimmed_idx, (full_token_id, count) in enumerate(top_tokens):
full_to_trimmed[full_token_id] = trimmed_idx
trimmed_to_full[trimmed_idx] = full_token_id
# Calculate coverage: what percentage of real samples are covered
total_count = sum(token_frequencies.values())
selected_count = sum(count for _, count in top_tokens)
coverage = selected_count / total_count
return {
'full_to_trimmed': full_to_trimmed,
'trimmed_to_full': trimmed_to_full,
'vocab_size': target_vocab_size,
'coverage': coverage
}
ステップ3:ドラフタのLMヘッドを再構築
必要な重みの行のみを抽出することで、新しい小規模な言語モデリングヘッドを作成します。
def reconstruct_lm_head(drafter, vocab_mapping):
"""
Extract the drafter's LM head weights for selected tokens only.
This creates a new head with fewer output dimensions.
"""
original_head = drafter.lm_head
original_weight = original_head.weight # Shape: [vocab_size, hidden_dim]
# Select only rows corresponding to kept tokens
selected_indices = torch.tensor(
[vocab_mapping['trimmed_to_full'][i]
for i in range(vocab_mapping['vocab_size'])],
device=original_weight.device
)
# Extract selected rows
trimmed_weight = original_weight[selected_indices, :]
# Create new head with smaller output vocabulary
trimmed_head = torch.nn.Linear(
original_head.in_features,
vocab_mapping['vocab_size'],
bias=(original_head.bias is not None)
)
# Copy weights
with torch.no_grad():
trimmed_head.weight.copy_(trimmed_weight)
if original_head.bias is not None:
original_bias = original_head.bias
selected_bias = original_bias[selected_indices]
trimmed_head.bias.copy_(selected_bias)
return trimmed_head
ステップ4:推測デコーディングループ内でトークンリマッピングを実装
生成時に、ドラフタインデックス(削減された語彙から)を完全な語彙インデックスに変換します。
def speculative_decode_with_trimmed_vocab(target_model, drafter,
input_ids, vocab_mapping,
num_steps=256, draft_length=4):
"""
Run speculative decoding with a trimmed-vocabulary drafter.
The drafter outputs indices in the small vocabulary,
which we remap to the full vocabulary before verification.
"""
full_to_trimmed = vocab_mapping['full_to_trimmed']
trimmed_to_full = vocab_mapping['trimmed_to_full']
current_ids = input_ids
for step in range(num_steps):
# Drafter generates draft tokens
draft_tokens = []
drafter_ids = current_ids
for draft_step in range(draft_length):
with torch.no_grad():
# Forward pass through drafter with TRIMMED vocabulary
drafter_outputs = drafter(drafter_ids)
drafter_logits = drafter_outputs.logits[:, -1, :]
# Sample from trimmed vocabulary
trimmed_probs = torch.softmax(drafter_logits, dim=-1)
trimmed_sample = torch.multinomial(trimmed_probs, num_samples=1)
# REMAP: convert trimmed index to full vocabulary index
full_sample = torch.tensor(
[[trimmed_to_full[int(trimmed_sample[0, 0])]]],
device=trimmed_sample.device
)
draft_tokens.append(full_sample)
drafter_ids = torch.cat([drafter_ids, full_sample], dim=1)
# Concatenate draft tokens for verification
draft_ids = torch.cat(draft_tokens, dim=1)
candidate_ids = torch.cat([current_ids, draft_ids], dim=1)
# Target model verifies candidate tokens (with FULL vocabulary)
with torch.no_grad():
target_outputs = target_model(candidate_ids)
target_logits = target_outputs.logits
# Verification: accept tokens while probabilities remain high
for accept_idx in range(draft_length):
position = current_ids.shape[1] + accept_idx
candidate_token = draft_ids[0, accept_idx].item()
# Get target model's probability for this token
position_logits = target_logits[0, position - 1, :]
target_probs = torch.softmax(position_logits, dim=-1)
target_prob = target_probs[candidate_token].item()
# Accept or reject
if torch.rand(1).item() < target_prob:
# Accept: token is added permanently
current_ids = torch.cat([current_ids, draft_ids[:, accept_idx:accept_idx+1]], dim=1)
else:
# Reject: resample from target model
resampled = torch.multinomial(target_probs, num_samples=1)
current_ids = torch.cat([current_ids, resampled], dim=1)
break
return current_ids
ステップ5:高速化とカバレッジのトレードオフを評価
実際の推論高速化と、ドラフタが語彙外のトークンを見逃す頻度を測定します。
def evaluate_vocabtrim_performance(target_model, drafter, vocab_mapping,
eval_prompts, vocab_mapping_baseline):
"""
Compare trimmed drafter (VocabTrim) against baseline speculative decoding.
Key metrics: speedup, acceptance rate, coverage.
"""
results = {
'trimmed': [],
'baseline': []
}
for prompt in eval_prompts:
input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
# Baseline: full vocabulary drafter
start_time = time.time()
output_baseline = speculative_decode_baseline(target_model, drafter,
input_ids, max_length=256)
time_baseline = time.time() - start_time
# Trimmed: VocabTrim drafter
start_time = time.time()
output_trimmed = speculative_decode_with_trimmed_vocab(
target_model, drafter, input_ids, vocab_mapping, num_steps=256
)
time_trimmed = time.time() - start_time
# Compute acceptance rate
# (tokens where drafter was right and target accepted them)
acceptance_baseline = count_accepted_tokens(output_baseline)
acceptance_trimmed = count_accepted_tokens(output_trimmed)
results['baseline'].append({
'time': time_baseline,
'acceptance': acceptance_baseline
})
results['trimmed'].append({
'time': time_trimmed,
'acceptance': acceptance_trimmed
})
speedup = np.mean([r['time'] for r in results['baseline']]) / \
np.mean([r['time'] for r in results['trimmed']])
acceptance_degradation = (np.mean([r['acceptance'] for r in results['baseline']]) -
np.mean([r['acceptance'] for r in results['trimmed']])) / \
np.mean([r['acceptance'] for r in results['baseline']])
return {
'speedup': speedup,
'acceptance_degradation': acceptance_degradation,
'vocab_coverage': vocab_mapping['coverage']
}
実装ガイダンス
| ハイパーパラメータ | 推奨値 | 説明 |
|---|---|---|
| ターゲット語彙サイズ | 25K~50K | 高いほど高速化されますが、効果に限界があります |
| カバレッジ閾値 | 85~95% | サンプリングされるほとんどのトークンがカバーされていることを確認 |
| キャリブレーションデータサイズ | 10K~50K件 | 代表的なデータで周波数を測定 |
| 受け入れ率許容度 | 97~99% | 16%の高速化のために1~3%の低下を許容可能 |
| リマッピングオーバーヘッド | <1% | 推論高速化と比べて無視できるレベル |
VocabTrimを使用すべき場合:
- ドラフト-検証アーキテクチャで推測デコーディングを使用している
- ドラフタが大規模な語彙を持っている(Llama-3で128Kトークン)
- 推論がメモリバウンド(コンピュートバウンドではない)
- 再訓練なしで10~20%の推論高速化が必要
VocabTrimを使用すべきでない場合:
- ドラフタが既に小規模な語彙を持っている(50K未満)
- コンピュートバウンド(語彙削減は役に立ちません)
- 受け入れ率の低下がゼロである必要がある
- 非常にレアなトークンのカバレッジが必要なワークロード
よくある落とし穴:
- 不適切なキャリブレーションデータ選択:ドメイン内データを使用してください(推論と同じドメイン)。ウェブテキストで周波数を測定してもコード生成では機能しません。
- 語彙が小さすぎる:10Kトークンのみを選択すると、受け入れ率は急落します。通常は30~50Kがベストバランスです。
- レアトークンの忘却:コード生成などのタスクではレアトークンが必要です。受け入れ率が低下した場合は段階的に語彙サイズを増やしてください。
- リマッピングオーバーヘッド:マッピングを反復処理するのではなく、高速リマッピング用にハッシュテーブルまたはテンソルルックアップを使用してください。
参考文献
VocabTrim: Vocabulary Pruning for Efficient Speculative Decoding in LLMs https://arxiv.org/abs/2506.22694
ライセンス: MIT(寛容ライセンスのため全文を引用しています) · 原本リポジトリ
詳細情報
- 作者
- ADu2021
- リポジトリ
- ADu2021/skillXiv
- ライセンス
- MIT
- 最終更新
- 2026/3/26
Source: https://github.com/ADu2021/skillXiv / ライセンス: MIT
関連スキル
agent-browser
AI エージェント向けのブラウザ自動化 CLI です。ウェブサイトとの対話が必要な場合に使用します。ページ遷移、フォーム入力、ボタンクリック、スクリーンショット取得、データ抽出、ウェブアプリのテスト、ブラウザ操作の自動化など、あらゆるブラウザタスクに対応できます。「ウェブサイトを開く」「フォームに記入する」「ボタンをクリックする」「スクリーンショットを取得する」「ページからデータを抽出する」「このウェブアプリをテストする」「サイトにログインする」「ブラウザ操作を自動化する」といった要求や、プログラマティックなウェブ操作が必要なタスクで起動します。
anyskill
AnySkill — あなたのプライベート・スキルクラウド。GitHubを基盤としたリポジトリからエージェントスキルを管理、同期、動的にロードできます。自然言語でクラウドスキルを検索し、オンデマンドでプロンプトを自動ロード、カスタムスキルのアップロードと共有、スキルバンドルの一括インストールが可能です。OpenClaw、Antigravity、Claude Code、Cursorに対応しています。
engram
AIエージェント向けの永続的なメモリシステムです。バグ修正、意思決定、発見、設定変更の後はmem_saveを使用してください。ユーザーが「覚えている」「記憶している」と言及した場合、または以前のセッションと重複する作業を開始する際はmem_searchを使用します。セッション終了前にmem_session_summaryを使用して、コンテキストを保持してください。
skyvern
AI駆動のブラウザ自動化により、任意のウェブサイトを自動化できます。フォーム入力、データ抽出、ファイルダウンロード、ログイン、複数ステップのワークフロー実行など、ユーザーがウェブサイトと連携する必要があるときに使用します。Skyvernは、LLMとコンピュータビジョンを活用して、未知のサイトも自動操作可能です。Python SDK、TypeScript SDK、REST API、MCPサーバー、またはCLIを通じて統合できます。
pinchbench
PinchBenchベンチマークを実行して、OpenClawエージェントの実世界タスクにおけるパフォーマンスを評価できます。モデルの機能テスト、モデル間の比較、ベンチマーク結果のリーダーボード提出、またはOpenClawのセットアップがカレンダー、メール、リサーチ、コーディング、複数ステップのワークフローにどの程度対応しているかを確認する際に使用します。
openui
OpenUIとOpenUI Langを使用してジェネレーティブUIアプリを構築できます。これらはLLM生成インターフェースのためのトークン効率的なオープン標準です。OpenUI、@openuidev、ジェネレーティブUI、LLMからのストリーミングUI、AI向けコンポーネントライブラリ、またはjson-render/A2UIの置き換えについて述べる際に使用します。スキャフォルディング、defineComponent、システムプロンプト、Renderer、およびOpenUI Lang出力のデバッグに対応しています。