JAX: Composable Transformations for Python+NumPy Programs

Summary
JAX is a Python library designed for high-performance numerical computing and large-scale machine learning. It offers composable function transformations: automatic differentiation, JIT compilation to accelerators (GPU/TPU), and auto-vectorization. Together, these let developers write flexible and efficient numerical programs.
Introduction
JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. It provides composable transformations of Python and NumPy programs, enabling automatic differentiation, Just-In-Time (JIT) compilation to GPUs and TPUs, and automatic vectorization. Under the hood, JAX uses XLA (Accelerated Linear Algebra) to compile and scale NumPy programs efficiently across hardware accelerators.
Installation
Getting started with JAX is straightforward. Here are the basic installation commands for common platforms:
- CPU: `pip install -U jax`
- NVIDIA GPU: `pip install -U "jax[cuda13]"`
- Google TPU: `pip install -U "jax[tpu]"`
For more detailed instructions, including alternative CUDA versions, compiling from source, or using Docker, please refer to the official JAX installation documentation.
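After installing, a quick sanity check (a minimal sketch, not from the official docs) confirms that JAX imports cleanly and can see at least one device:

```python
import jax
import jax.numpy as jnp

# Lists the devices JAX will run on, e.g. [CpuDevice(id=0)] on a CPU-only install
print(jax.devices())

# A tiny computation to confirm jax.numpy works
print(jnp.arange(3.0) * 2.0)  # [0. 2. 4.]
```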
Examples
JAX's core strength lies in its composable function transformations. Here are some examples:
Automatic Differentiation with jax.grad
JAX can automatically differentiate native Python and NumPy functions, even through complex control flow, recursion, and closures.
```python
import jax
import jax.numpy as jnp

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = jax.grad(tanh)
print(grad_tanh(1.0))
# prints 0.4199743

# Differentiate to any order
print(jax.grad(jax.grad(jax.grad(tanh)))(1.0))
# prints 0.62162673

# Differentiation with Python control flow
def abs_val(x):
    if x > 0:
        return x
    else:
        return -x

abs_val_grad = jax.grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0
```
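In training loops you usually want the loss value as well as its gradient. `jax.value_and_grad` returns both in one call; here is a minimal sketch with a made-up quadratic loss:

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Hypothetical loss: sum of squared weights
    return jnp.sum(w ** 2)

# value_and_grad returns (loss(w), dloss/dw) in a single pass
value, grad = jax.value_and_grad(loss)(jnp.array([1.0, 2.0]))
print(value)  # 5.0
print(grad)   # [2. 4.]
```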
Just-In-Time Compilation with jax.jit
Use jax.jit to compile your functions end-to-end with XLA, significantly boosting performance on accelerators.
```python
import jax
import jax.numpy as jnp

def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jax.jit(slow_f)
# %timeit -n10 -r3 fast_f(x)
# %timeit -n10 -r3 slow_f(x)
```
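Outside a notebook, you can time a jitted function with the standard library instead of `%timeit`. One caveat worth showing: JAX dispatches work asynchronously, so call `.block_until_ready()` to measure actual execution, and exclude the first call, which includes compilation. A small sketch:

```python
import time
import jax
import jax.numpy as jnp

def slow_f(x):
    return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jax.jit(slow_f)

# Warm-up call triggers XLA compilation; exclude it from timing.
fast_f(x).block_until_ready()

start = time.perf_counter()
fast_f(x).block_until_ready()  # blocks until the result is actually computed
print(f"jit: {time.perf_counter() - start:.4f}s")
```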
Auto-vectorization with jax.vmap
jax.vmap maps a function over array axes. Rather than looping in Python, it pushes the loop down into the function's primitive operations, which typically improves performance and eliminates the need to manage batch dimensions by hand.
```python
import jax
import jax.numpy as jnp

def l1_distance(x, y):
    assert x.ndim == y.ndim == 1  # only works on 1D inputs
    return jnp.sum(jnp.abs(x - y))

def pairwise_distances(dist1D, xs):
    return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)

xs = jax.random.normal(jax.random.key(0), (100, 3))
dists = pairwise_distances(l1_distance, xs)
dists.shape  # (100, 100)
```
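The `in_axes` argument controls which arguments get batched. A common pattern is sharing one set of parameters across a batch of inputs; here is a minimal sketch with a hypothetical `predict` function and made-up shapes:

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Operates on a single example: x has shape (3,)
    return jnp.dot(w, x)

w = jnp.ones(3)
xs = jnp.ones((10, 3))  # batch of 10 examples

# in_axes=(None, 0): don't batch w, map over the leading axis of xs
batched = jax.vmap(predict, in_axes=(None, 0))(w, xs)
print(batched.shape)  # (10,)
```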
Why Use JAX?
JAX offers several compelling advantages for developers working with numerical computation and machine learning:
- Composable Transformations: Its core design lets you combine function transformations like `grad`, `jit`, and `vmap` in arbitrary ways, leading to highly flexible and efficient code.
- Automatic Differentiation: Easily compute gradients of complex Python and NumPy functions, with support for higher-order derivatives and control flow. This is crucial for optimizing machine learning models.
- High Performance: JIT compilation to XLA lets your code run at native speed on CPUs, GPUs, and TPUs without significant code changes.
- Simplified Batching: `jax.vmap` automates vectorizing functions, simplifying batch-processing code and improving performance.
- Scalability: JAX provides robust tools for scaling computations across thousands of devices, supporting automatic parallelization, explicit sharding, and manual per-device programming.
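Because the transformations compose, they can be stacked freely. As a small sketch (with a made-up loss function), here is `jit` applied to the gradient of a `vmap`-batched computation:

```python
import jax
import jax.numpy as jnp

def loss(w, xs):
    # Mean squared norm over a batch, via a vmapped per-example function
    per_example = jax.vmap(lambda x: jnp.sum((w * x) ** 2))
    return jnp.mean(per_example(xs))

# grad and jit compose: compile the gradient of the batched loss
grad_loss = jax.jit(jax.grad(loss))

w = jnp.ones(3)
xs = jnp.ones((4, 3))
print(grad_loss(w, xs))  # [2. 2. 2.]
```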
Links
- GitHub Repository: jax-ml/jax
- Official Documentation: JAX Documentation
- Installation Guide: JAX Installation