汎用データ・分析⭐ リポ 34品質スコア 55/100
JAX
機械学習と数学解析でJAXを利用するための必須ツールです。コア概念、変換機能、ML特有の機能、制御フロー、並列処理といった要素をカバーしており、これらを組み合わせることで、高性能な計算パイプラインを構築できます。
description の原文を見る
Essential tools for using JAX in machine learning and mathematical analysis, covering core concepts, transformations, ML specifics, control flow, and parallelism.
SKILL.md 本文
注意: このスキルのライセンスは ライセンス未確認 です。本サイトでは本文プレビューのみを表示しています。利用前に GitHub の原本でライセンス条件をご確認ください。
JAX スキル
JAX は Autograd と XLA を統合したもので、高性能な機械学習研究向けです。
目次
- 概念と理論
- イミュータビリティ
- 4つの変換
- Pytree
- コード例
jit、grad、vmap、randomの使用法- 制御フロー(
scan、cond、fori_loop) - 並列処理(
sharding)
一般的なワークフロー
1. 新しいモデルの開発
- パラメータを Pytree(辞書/データクラス)として定義します。
- フォワードパス関数(純粋関数)を定義します。
- 損失関数を定義します。
jax.value_and_gradを使用して勾配を取得します。jax.jitを使用して更新ステップを高速化します。- スニペットについては examples.md をご覧ください。
2. 形状/NaN のデバッグ
- JIT を無効化します:
jax.config.update("jax_disable_jit", True)で標準的な Python ツールを使用してデバッグできます。 - JI
...
詳細情報
- 作者
- diegosouzapw
- ライセンス
- 不明
- 最終更新
- 2026/3/2
Source: https://github.com/diegosouzapw/awesome-omni-skill / ライセンス: 未指定