# JAX: Composable Transformations for Python+NumPy Programs

This repository profile is provided by osrepos.com, an open source repository discovery platform.

Source: osrepos.com
Repository profile: https://osrepos.com/repo/jax-ml-jax
Generated for open source discovery and AI-assisted research.

GitHub: https://github.com/jax-ml/jax
OSRepos URL: https://osrepos.com/repo/jax-ml-jax

## Summary

JAX is a Python library for high-performance numerical computing and large-scale machine learning. It offers composable function transformations such as automatic differentiation, just-in-time (JIT) compilation for accelerators (GPU/TPU), and automatic vectorization. Together, these let developers write numerical programs that are both flexible and efficient.

## Topics

- jax
- Python
- Machine Learning
- Deep Learning
- Numerical Computing
- AI
- GPU
- TPU

## Repository Information

Last analyzed by OSRepos: Thu Mar 26 2026 16:09:03 GMT+0000 (Western European Standard Time)

## Safety Notice

OSRepos shares public repositories for knowledge and discovery only. Review source code, dependencies, licenses, and security implications before running or installing anything.

## Content

## Introduction

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. It provides composable transformations of Python and NumPy programs, including automatic differentiation, just-in-time (JIT) compilation to GPUs and TPUs, and automatic vectorization. JAX uses XLA (Accelerated Linear Algebra) to compile NumPy-style programs and scale them across hardware accelerators.
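
Because `jax.numpy` mirrors much of the NumPy API, NumPy-style code can often be ported with little more than an import change. The minimal sketch below evaluates the same expression with both libraries. One difference worth knowing up front: JAX arrays are immutable, so in-place item assignment is replaced by the functional `.at[...]` update syntax, which returns a new array.

```python
import numpy as np
import jax.numpy as jnp

# The same expression written against NumPy and against jax.numpy
x_np = np.linspace(0.0, 1.0, 5)
x_jx = jnp.linspace(0.0, 1.0, 5)

y_np = np.sin(x_np) ** 2 + np.cos(x_np) ** 2
y_jx = jnp.sin(x_jx) ** 2 + jnp.cos(x_jx) ** 2

# JAX defaults to float32 while NumPy uses float64, so compare with a tolerance
print(np.allclose(y_np, np.asarray(y_jx)))  # True

# JAX arrays are immutable: `x_jx[0] = 10.0` would raise a TypeError.
# A functional update returns a new array and leaves the original unchanged.
z = x_jx.at[0].set(10.0)
print(z[0], x_jx[0])
```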

## Installation

Getting started with JAX is straightforward. Here are the basic installation commands for common platforms:

*   **CPU:**
    ```bash
    pip install -U jax
    ```
*   **NVIDIA GPU:**
    ```bash
    pip install -U "jax[cuda13]"
    ```
*   **Google TPU:**
    ```bash
    pip install -U "jax[tpu]"
    ```

For more detailed instructions, including alternative CUDA versions, compiling from source, or using Docker, please refer to the [official JAX installation documentation](https://docs.jax.dev/en/latest/installation.html).
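
After installing, a quick sanity check is to ask JAX which backend and devices it found. The snippet below is a minimal sketch using the standard `jax.default_backend()` and `jax.devices()` functions; on a CPU-only install it should report a CPU device, while a working GPU or TPU install should list the accelerator instead. The exact device names printed vary by platform.

```python
import jax
import jax.numpy as jnp

# Report the installed version and the backend JAX selected by default
print("JAX version:", jax.__version__)
print("Default backend:", jax.default_backend())  # e.g. "cpu", "gpu", or "tpu"
print("Devices:", jax.devices())

# Run a small computation and confirm which device holds the result
x = jnp.arange(4.0)
y = (x * 2.0).block_until_ready()
print(y, y.devices())
```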

## Examples

JAX's core strength lies in its composable function transformations. Here are some examples:

### Automatic Differentiation with `jax.grad`

JAX can automatically differentiate native Python and NumPy functions, even through complex control flow, recursion, and closures.

```python
import jax
import jax.numpy as jnp

def tanh(x):
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = jax.grad(tanh)
print(grad_tanh(1.0))
# prints 0.4199743

# Differentiate to any order
print(jax.grad(jax.grad(jax.grad(tanh)))(1.0))
# prints 0.62162673

# Differentiation with Python control flow
def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = jax.grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0
```
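
The sketch below extends the example to two more cases the paragraph above mentions, a Python loop and a closure, and shows two related entry points: `jax.value_and_grad`, which returns the function value and its gradient together, and the `argnums` parameter of `jax.grad`, which selects which positional arguments to differentiate. The functions `make_poly` and `weighted_sq` are hypothetical examples chosen so the derivatives are easy to check by hand.

```python
import jax
import jax.numpy as jnp

# A Python loop inside a closure over `scale` (hypothetical example)
def make_poly(scale):
  def poly(x):
    total = 0.0
    for k in range(1, 4):  # total = scale * (x + x**2 + x**3)
      total = total + scale * x ** k
    return total
  return poly

poly = make_poly(2.0)
# d/dx 2*(x + x^2 + x^3) = 2*(1 + 2x + 3x^2), which is 12.0 at x = 1.0
print(jax.grad(poly)(1.0))

# value_and_grad returns f(x) and f'(x) in one call: here 6.0 and 12.0
value, slope = jax.value_and_grad(poly)(1.0)
print(value, slope)

# argnums selects which positional argument(s) to differentiate
def weighted_sq(w, x):
  return jnp.sum(w * x ** 2)

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])
dw, dx = jax.grad(weighted_sq, argnums=(0, 1))(w, x)
print(dw)  # equals x ** 2, i.e. 9 and 16
print(dx)  # equals 2 * w * x, i.e. 6 and 16
```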

### Just-In-Time Compilation with `jax.jit`

Use `jax.jit` to compile your functions end-to-end with XLA, significantly boosting performance on accelerators.

```python
import jax
import jax.numpy as jnp

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jax.jit(slow_f)
# %timeit -n10 -r3 fast_f(x)
# %timeit -n10 -r3 slow_f(x)
```
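
JAX dispatches work asynchronously, so a call can return before the computation has actually finished. A fair timing therefore waits on `block_until_ready()` and excludes the first call of the jitted function, which includes tracing and compilation. The plain-Python sketch below uses `time.perf_counter` in place of IPython's `%timeit`; `slow_f` is repeated so the block runs on its own, the array is shrunk to `(1000, 1000)` to keep it quick, and the `bench` helper is a hypothetical name. Actual speedups depend on hardware.

```python
import time
import jax
import jax.numpy as jnp

def slow_f(x):
  return x * x + x * 2.0

fast_f = jax.jit(slow_f)
x = jnp.ones((1000, 1000))

# The first call traces and compiles; run it once so it is not timed
fast_f(x).block_until_ready()

def bench(fn, arg, reps=10):
  """Average wall-clock seconds per call, waiting for each result to finish."""
  start = time.perf_counter()
  for _ in range(reps):
    fn(arg).block_until_ready()  # without this, we would time dispatch only
  return (time.perf_counter() - start) / reps

print(f"jit:   {bench(fast_f, x):.6f} s/call")
print(f"eager: {bench(slow_f, x):.6f} s/call")

# The compiled and eager versions should agree numerically
print(jnp.allclose(fast_f(x), slow_f(x)))  # True
```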

### Auto-vectorization with `jax.vmap`

`jax.vmap` maps a function along array axes. Instead of looping in Python, it pushes the loop down into the primitive operations for better performance, so you don't have to manage batch dimensions by hand.

```python
import jax
import jax.numpy as jnp

def l1_distance(x, y):
  assert x.ndim == y.ndim == 1  # only works on 1D inputs
  return jnp.sum(jnp.abs(x - y))

def pairwise_distances(dist1D, xs):
  return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)

xs = jax.random.normal(jax.random.key(0), (100, 3))
dists = pairwise_distances(l1_distance, xs)
print(dists.shape)  # (100, 100)
```
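
To make the nested `in_axes` arguments in `pairwise_distances` concrete, the sketch below names each `vmap` layer and checks the result against a hand-written broadcasting version. The broadcasting comparison is added here for illustration and is not part of the original example; both produce the same matrix, but `vmap` lets `l1_distance` stay written for single 1-D vectors.

```python
import jax
import jax.numpy as jnp

def l1_distance(x, y):
  assert x.ndim == y.ndim == 1  # still sees 1-D inputs under vmap
  return jnp.sum(jnp.abs(x - y))

xs = jax.random.normal(jax.random.key(0), (100, 3))

# Inner layer: map over rows of the first argument, reuse the second as-is
row_vs_point = jax.vmap(l1_distance, in_axes=(0, None))
# Outer layer: map over rows of the second argument, reuse the first as-is
pairwise = jax.vmap(row_vs_point, in_axes=(None, 0))
dists = pairwise(xs, xs)  # dists[i, j] = l1_distance(xs[j], xs[i])

# Equivalent manual broadcasting: (100, 1, 3) - (1, 100, 3) -> (100, 100, 3)
manual = jnp.sum(jnp.abs(xs[:, None, :] - xs[None, :, :]), axis=-1)

print(dists.shape)                  # (100, 100)
print(jnp.allclose(dists, manual))  # True
```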

## Why Use JAX?

JAX offers several compelling advantages for developers working with numerical computation and machine learning:

*   **Composable Transformations:** Function transformations such as `grad`, `jit`, and `vmap` can be combined in arbitrary ways, which makes code both flexible and efficient.
*   **Automatic Differentiation:** Easily compute gradients of complex Python and NumPy functions, supporting higher-order derivatives and control flow. This is crucial for optimizing machine learning models.
*   **High Performance:** JIT compilation with XLA fuses operations and generates optimized code for CPUs, GPUs, and TPUs, usually with few changes to the original function.
*   **Simplified Batching:** `jax.vmap` automates the process of vectorizing functions, simplifying code for batch processing and improving performance.
*   **Scalability:** JAX provides robust tools for scaling computations across thousands of devices, supporting automatic parallelization, explicit sharding, and manual per-device programming.
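
Because the transformations compose, they can be stacked. A common pattern is computing per-example gradients by wrapping `jax.grad` in `jax.vmap` and compiling the result with `jax.jit`. The linear model and squared-error loss below are hypothetical placeholders chosen for illustration; any function of `(params, x, y)` that returns a scalar would work the same way.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear model with a squared-error loss for ONE example
def loss(params, x, y):
  w, b = params
  pred = jnp.dot(x, w) + b
  return (pred - y) ** 2

# grad: differentiate w.r.t. params (argument 0)
# vmap: map over the batch axis of x and y, share params (in_axes=None)
# jit:  compile the whole composed function
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

key_x, key_y = jax.random.split(jax.random.key(0))
xs = jax.random.normal(key_x, (8, 3))  # batch of 8 inputs with 3 features
ys = jax.random.normal(key_y, (8,))
params = (jnp.zeros(3), 0.0)

# With w = 0 and b = 0, each gradient is -2*y*x for w and -2*y for b
grads_w, grads_b = per_example_grads(params, xs, ys)
print(grads_w.shape, grads_b.shape)  # (8, 3) (8,)
```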

## Links

*   **GitHub Repository:** [jax-ml/jax](https://github.com/jax-ml/jax)
*   **Official Documentation:** [JAX Documentation](https://docs.jax.dev/en/latest/)
*   **Installation Guide:** [JAX Installation](https://docs.jax.dev/en/latest/installation.html)