JAX: Composable Transformations for Python+NumPy Programs

Summary

JAX is a Python library designed for high-performance numerical computing and large-scale machine learning. It offers composable function transformations: automatic differentiation, JIT compilation to accelerators (GPU/TPU), and auto-vectorization. Together, these transformations let developers write numerical programs that are both flexible and efficient.

Repository Info

Updated on March 26, 2026

Introduction

JAX is a powerful Python library for accelerator-oriented array computation and program transformation, specifically designed for high-performance numerical computing and large-scale machine learning. It allows for composable transformations of Python and NumPy programs, enabling features like automatic differentiation, Just-In-Time (JIT) compilation to GPUs/TPUs, and automatic vectorization. JAX leverages XLA (Accelerated Linear Algebra) to compile and scale your NumPy programs efficiently across various hardware accelerators.

Installation

Getting started with JAX is straightforward. Here are the basic installation commands for common platforms:

  • CPU:
    pip install -U jax
    
  • NVIDIA GPU:
    pip install -U "jax[cuda13]"
    
  • Google TPU:
    pip install -U "jax[tpu]"
    

For more detailed instructions, including alternative CUDA versions, compiling from source, or using Docker, please refer to the official JAX installation documentation.
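Once installed, a quick sanity check confirms that JAX can see a backend and run a computation end-to-end (this snippet assumes only a working JAX install, on any of the platforms above):

```python
import jax
import jax.numpy as jnp

# List the devices JAX detected (CPU, GPU, or TPU entries depending on install).
print(jax.devices())

# Run a trivial computation to confirm the backend works.
x = jnp.arange(5.0)
print(jnp.sum(x ** 2))  # 0 + 1 + 4 + 9 + 16 = 30.0
```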

Examples

JAX's core strength lies in its composable function transformations. Here are some examples:

Automatic Differentiation with jax.grad

JAX can automatically differentiate native Python and NumPy functions, even through complex control flow, recursion, and closures.

import jax
import jax.numpy as jnp

def tanh(x):
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = jax.grad(tanh)
print(grad_tanh(1.0))
# prints 0.4199743

# Differentiate to any order
print(jax.grad(jax.grad(jax.grad(tanh)))(1.0))
# prints 0.62162673

# Differentiation with Python control flow
def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = jax.grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0
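When you need both a function's value and its gradient, as in a typical training loop, `jax.value_and_grad` computes them in a single pass. A small sketch reusing the `tanh` function above:

```python
import jax
import jax.numpy as jnp

def tanh(x):
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

# Returns (tanh(1.0), d/dx tanh(x) at x=1.0) in one call.
value, grad = jax.value_and_grad(tanh)(1.0)
print(value)  # ~0.7615942
print(grad)   # ~0.4199743
```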

Just-In-Time Compilation with jax.jit

Use jax.jit to compile your functions end-to-end with XLA, significantly boosting performance on accelerators.

import jax
import jax.numpy as jnp

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jax.jit(slow_f)
# %timeit -n10 -r3 fast_f(x)
# %timeit -n10 -r3 slow_f(x)
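Outside a notebook, note that JAX dispatches work asynchronously, so timing a jitted function means waiting on the result with `block_until_ready()`. A rough timing sketch (array size reduced here; the exact speedup depends on your hardware, and the first jitted call includes compilation time):

```python
import time
import jax
import jax.numpy as jnp

def slow_f(x):
  return x * x + x * 2.0

fast_f = jax.jit(slow_f)
x = jnp.ones((2000, 2000))

fast_f(x).block_until_ready()  # warm up: the first call triggers compilation

t0 = time.perf_counter()
fast_f(x).block_until_ready()  # block so we time the computation, not dispatch
print(f"jit:    {time.perf_counter() - t0:.4f} s")

t0 = time.perf_counter()
slow_f(x).block_until_ready()
print(f"no jit: {time.perf_counter() - t0:.4f} s")
```

Compilation does not change results: the jitted and unjitted functions compute the same values.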

Auto-vectorization with jax.vmap

jax.vmap maps a function over array axes. Instead of looping in Python, it pushes the loop down into the function's primitive operations, which typically improves performance and removes the need to manage batch dimensions by hand.

import jax
import jax.numpy as jnp

def l1_distance(x, y):
  assert x.ndim == y.ndim == 1  # only works on 1D inputs
  return jnp.sum(jnp.abs(x - y))

def pairwise_distances(dist1D, xs):
  return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)

xs = jax.random.normal(jax.random.key(0), (100, 3))
dists = pairwise_distances(l1_distance, xs)
dists.shape  # (100, 100)
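The same nesting idea gives per-example gradients: composing `jax.vmap` with `jax.grad` differentiates a scalar function for every example in a batch at once. A minimal sketch with a made-up quadratic loss (the function and batch here are illustrative, not from the JAX docs):

```python
import jax
import jax.numpy as jnp

def loss(x):
  # Hypothetical per-example scalar loss.
  return jnp.sum(x ** 2)

xs = jnp.arange(6.0).reshape(3, 2)       # batch of 3 examples, 2 features each
per_example_grads = jax.vmap(jax.grad(loss))(xs)
print(per_example_grads)                 # gradient of sum(x**2) is 2*x
print(per_example_grads.shape)           # (3, 2): one gradient per example
```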

Why Use JAX?

JAX offers several compelling advantages for developers working with numerical computation and machine learning:

  • Composable Transformations: Its core design allows you to combine powerful function transformations like grad, jit, and vmap in arbitrary ways, leading to highly flexible and efficient code.
  • Automatic Differentiation: Easily compute gradients of complex Python and NumPy functions, supporting higher-order derivatives and control flow. This is crucial for optimizing machine learning models.
  • High Performance: Leverage JIT compilation to XLA, enabling your code to run at native speed on CPUs, GPUs, and TPUs without significant code changes.
  • Simplified Batching: jax.vmap automates the process of vectorizing functions, simplifying code for batch processing and improving performance.
  • Scalability: JAX provides robust tools for scaling computations across thousands of devices, supporting automatic parallelization, explicit sharding, and manual per-device programming.
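These transformations compose in arbitrary order: a batched, compiled derivative is just `jit(vmap(grad(...)))`. A sketch reusing the `tanh` function from the examples above:

```python
import jax
import jax.numpy as jnp

def tanh(x):
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

# Compose: differentiate, vectorize over a batch, then compile with XLA.
batched_grad_tanh = jax.jit(jax.vmap(jax.grad(tanh)))

xs = jnp.linspace(-2.0, 2.0, 5)
print(batched_grad_tanh(xs))  # sech^2(x) evaluated at each point
```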
