JAX,由Google开发、並由Nvidia做出部分貢獻[4][5][6]Python机器学习框架,用於变换数值函数。JAX结合了修改版的Autograd自動微分系統[7],以及來自OpenXLA專案的編譯器XLA英语Accelerated Linear Algebra[8],可加速數值線性運算。其設計目標是在介面與程式設計風格上盡可能與NumPy保持相容,使使用者能夠以熟悉的方式撰寫高效能運算程式。此外,JAX亦可與TensorFlowPyTorch等機器學習框架整合使用。[9][10]

JAX
開發者Google, Nvidia[1]
首次发布2019年10月31日,​6年前​(2019-10-31[2]
当前版本
  • 0.8.1(2025年11月16日;穩定版本)[3]
編輯維基數據鏈接
源代码库github.com/jax-ml/jax
编程语言Python, C++
操作系统Linux, macOS, Windows
平台Python, NumPy
类型机器学习
许可协议Apache 2.0
网站docs.jax.dev/en/latest/

主要功能

编辑

JAX的主要功能是[4]

grad

编辑

下面的代码演示grad函数的自动微分。

# 导入库
from jax import grad
import jax.numpy as jnp

# 定义logistic函数
def logistic(x):  
    return jnp.exp(x) / (jnp.exp(x) + 1)

# 获得logistic函数的梯度函数
grad_logistic = grad(logistic)

# 求值logistic函数在x = 1处的梯度 
grad_log_out = grad_logistic(1.0)   
print(grad_log_out)

最终的输出为:

0.19661194

jit

编辑

下面的代码演示jit函数的优化。

# 导入库
from jax import jit
import jax.numpy as jnp

# 定义cube函数
def cube(x):
    return x * x * x

# 生成数据
x = jnp.ones((10000, 10000))

# 创建cube函数的jit版本
jit_cube = jit(cube)

# 应用cube函数和jit_cube函数于相同数据来比较其速度
cube(x)
jit_cube(x)

可见jit_cube的运行时间显著的短于cube

vmap

编辑

下面的代码展示vmap函数的通过SIMD的向量化。

# 导入库
from functools import partial
from jax import vmap
import jax.numpy as jnp

# 定义函数
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads

使用JAX的库

编辑

一些Python库使用JAX作为后端,这包括:

参见

编辑

引用

编辑
  1. ^ jax/AUTHORS at main · jax-ml/jax. GitHub. [December 21, 2024]. 
  2. ^ jax-v0.1.49. 
  3. ^ https://github.com/jax-ml/jax/releases/tag/jax-v0.8.1.
  4. ^ 4.0 4.1 Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao, JAX: Autograd and XLA, Astrophysics Source Code Library (Google), 2022-06-18 [2022-06-18], Bibcode:2021ascl.soft11002B, (原始内容存档于2022-06-18) 
  5. ^ Frostig, Roy; Johnson, Matthew James; Leary, Chris. Compiling machine learning programs via high-level tracing (PDF). MLsys. 2018-02-02: 1–3. (原始内容存档 (PDF)于2022-06-21). 
  6. ^ Using JAX to accelerate our research. www.deepmind.com. [2022-06-18]. (原始内容存档于2022-06-18) (英语). 
  7. ^ autograd. [2023-09-23]. (原始内容存档于2022-07-18). 
  8. ^ XLA. [2023-09-23]. (原始内容存档于2022-09-01). 
  9. ^ Lynley, Matthew. Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta. Business Insider. [2022-06-21]. (原始内容存档于2022-06-21) (美国英语). 
  10. ^ Why is Google's JAX so popular?. Analytics India Magazine. 2022-04-25 [2022-06-18]. (原始内容存档于2022-06-18) (美国英语). 
  11. ^ Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29 [2022-07-29], (原始内容存档于2022-09-03) 
  12. ^ Kidger, Patrick, Equinox, 2022-07-29 [2022-07-29], (原始内容存档于2023-09-19) 
  13. ^ Kidger, Patrick, Diffrax, 2023-08-05 [2023-08-08], (原始内容存档于2023-08-10) 
  14. ^ Optax, DeepMind, 2022-07-28 [2022-07-29], (原始内容存档于2023-06-07) 
  15. ^ Lineax, Google, 2023-08-08 [2023-08-08], (原始内容存档于2023-08-10) 
  16. ^ RLax, DeepMind, 2022-07-29 [2022-07-29], (原始内容存档于2023-04-26) 
  17. ^ Jraph - A library for graph neural networks in jax., DeepMind, 2023-08-08 [2023-08-08], (原始内容存档于2022-11-23) 
  18. ^ jaxtyping, Google, 2023-08-08 [2023-08-08], (原始内容存档于2023-08-10) 
  19. ^ NumPyro - Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU. [2022-08-31]. (原始内容存档于2022-08-31). 
  20. ^ Brax - Massively parallel rigidbody physics simulation on accelerator hardware. [2022-08-31]. (原始内容存档于2022-08-31). 

外部链接

编辑


📚 Artikel Terkait di Wikipedia

SSE2

編譯器,它大量增加SSE2於Windows應用程式開發。 自從GCC 3推出,它能夠自動生成SSE/SSE2純量碼。而SSE/SSE2的自動向量化(英语:Automatic vectorization)也新增在GCC 4。 Sun Studio Compiler Suite在使用此-xvector=simd參數時也能夠產生SSE2指令碼。

并行计算

超字级并行(Superword level parallelism)是一种基于循环展开和基本块向量化的自动向量化(英语:Automatic vectorization)技术。它与循环向量化算法的不同之处在于,它可以利用内联代码(英语:inline code)的并行性(英语:Parallelism