{"name":"JAX: Composable Transformations for Python+NumPy Programs","description":"JAX is a Python library for high-performance numerical computing and large-scale machine learning. It offers composable function transformations such as automatic differentiation, just-in-time (JIT) compilation for accelerators (GPU/TPU), and automatic vectorization, which together let developers write flexible and efficient numerical programs.","github":"https://github.com/jax-ml/jax","url":"https://osrepos.com/repo/jax-ml-jax","source":"osrepos.com","sourceDescription":"This repository profile is provided by osrepos.com, an open source repository discovery platform.","repositoryProfile":"https://osrepos.com/repo/jax-ml-jax","generatedFor":"open source discovery and AI-assisted research","markdown":"https://osrepos.com/repo/jax-ml-jax.md","json":"https://osrepos.com/repo/jax-ml-jax.json","topics":["jax","Python","Machine Learning","Deep Learning","Numerical Computing","AI","GPU","TPU"],"keywords":["jax","Python","Machine Learning","Deep Learning","Numerical Computing","AI","GPU","TPU"],"stars":null,"summary":"JAX is a Python library for high-performance numerical computing and large-scale machine learning. It offers composable function transformations such as automatic differentiation, just-in-time (JIT) compilation for accelerators (GPU/TPU), and automatic vectorization, which together let developers write flexible and efficient numerical programs.","content":"## Introduction\n\nJAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. It supports composable transformations of Python and NumPy programs, including automatic differentiation, just-in-time (JIT) compilation for GPUs and TPUs, and automatic vectorization. 
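
Because these transformations are ordinary higher-order functions, they can be stacked. As a minimal sketch (not taken from the JAX documentation), the hypothetical `loss` function below is differentiated with `jax.grad`, batched over inputs with `jax.vmap`, and compiled with `jax.jit` to produce per-example gradients; the function, the parameter value, and the inputs are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar loss of a parameter w and a single input x.
def loss(w, x):
  return jnp.tanh(w * x) ** 2

# grad: derivative of loss with respect to its first argument, w.
# vmap: map over a batch of x values (axis 0) while sharing w (None).
# jit: compile the composed function with XLA.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

xs = jnp.linspace(-1.0, 1.0, 5)
g = per_example_grads(0.5, xs)
print(g.shape)  # (5,)
```

Here `in_axes=(None, 0)` shares the scalar parameter `w` across the batch while mapping over the leading axis of `xs`, so `g` holds one gradient per input.
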
JAX uses XLA (Accelerated Linear Algebra) to compile NumPy-style programs and scale them efficiently across hardware accelerators.\n\n## Installation\n\nGetting started with JAX is straightforward. The basic installation commands for common platforms are:\n\n*   **CPU:**\n    ```bash\n    pip install -U jax\n    ```\n*   **NVIDIA GPU:**\n    ```bash\n    pip install -U \"jax[cuda13]\"\n    ```\n*   **Google TPU:**\n    ```bash\n    pip install -U \"jax[tpu]\"\n    ```\n\nFor more detailed instructions, including alternative CUDA versions, building from source, or using Docker, see the [official JAX installation documentation](https://docs.jax.dev/en/latest/installation.html){target=\"_blank\"}.\n\n## Examples\n\nJAX's core strength lies in its composable function transformations. Here are some examples:\n\n### Automatic Differentiation with `jax.grad`\n\nJAX can automatically differentiate native Python and NumPy functions, even through Python control flow, recursion, and closures.\n\n```python\nimport jax\nimport jax.numpy as jnp\n\ndef tanh(x):\n  y = jnp.exp(-2.0 * x)\n  return (1.0 - y) / (1.0 + y)\n\ngrad_tanh = jax.grad(tanh)\nprint(grad_tanh(1.0))\n# prints 0.4199743\n\n# Differentiate to any order\nprint(jax.grad(jax.grad(jax.grad(tanh)))(1.0))\n# prints 0.62162673\n\n# Differentiation with Python control flow\ndef abs_val(x):\n  if x > 0:\n    return x\n  else:\n    return -x\n\nabs_val_grad = jax.grad(abs_val)\nprint(abs_val_grad(1.0))   # prints 1.0\nprint(abs_val_grad(-1.0))  # prints -1.0\n```\n\n### Just-In-Time Compilation with `jax.jit`\n\nUse `jax.jit` to compile a function end-to-end with XLA. Compilation lets XLA fuse element-wise operations into fewer kernels and avoids per-operation dispatch from Python, which can substantially speed up code on accelerators.\n\n```python\nimport jax\nimport jax.numpy as jnp\n\ndef slow_f(x):\n  # Element-wise ops see a large benefit from fusion\n  return x * x + x * 2.0\n\nx = jnp.ones((5000, 5000))\nfast_f = jax.jit(slow_f)\n# JAX dispatches work asynchronously, so block until results are ready when timing\n# %timeit -n10 -r3 fast_f(x).block_until_ready()\n# %timeit -n10 -r3 slow_f(x).block_until_ready()\n```\n\n### Auto-vectorization with 
`jax.vmap`\n\n`jax.vmap` maps a function along array axes. Instead of looping in Python, it pushes the loop down into the function's primitive operations, so a function written for a single example can be applied to a whole batch without manually managing batch dimensions.\n\n```python\nimport jax\nimport jax.numpy as jnp\n\ndef l1_distance(x, y):\n  assert x.ndim == y.ndim == 1  # only works on 1D inputs\n  return jnp.sum(jnp.abs(x - y))\n\ndef pairwise_distances(dist1D, xs):\n  # Inner vmap maps over rows of the first argument; outer vmap maps over rows of the second\n  return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)\n\nxs = jax.random.normal(jax.random.key(0), (100, 3))\ndists = pairwise_distances(l1_distance, xs)\nprint(dists.shape)  # (100, 100)\n```\n\n## Why Use JAX?\n\nJAX offers several advantages for developers working on numerical computation and machine learning:\n\n*   **Composable Transformations:** Function transformations such as `grad`, `jit`, and `vmap` can be nested in arbitrary combinations (for example, `jit(vmap(grad(f)))`), which keeps code flexible without sacrificing performance.\n*   **Automatic Differentiation:** Compute gradients of complex Python and NumPy functions, including higher-order derivatives and derivatives through Python control flow. 
This is essential for training machine learning models with gradient-based optimization.\n*   **High Performance:** JIT compilation with XLA produces fused, optimized kernels for CPUs, GPUs, and TPUs, often with little more than a `jax.jit` wrapper around existing code.\n*   **Simplified Batching:** `jax.vmap` vectorizes functions automatically, so per-example code can be batched without hand-written batch dimensions.\n*   **Scalability:** JAX provides tools for scaling computations across thousands of devices, supporting compiler-based automatic parallelization, explicit sharding, and manual per-device programming.\n\n## Links\n\n*   **GitHub Repository:** [jax-ml/jax](https://github.com/jax-ml/jax){target=\"_blank\"}\n*   **Official Documentation:** [JAX Documentation](https://docs.jax.dev/en/latest/){target=\"_blank\"}\n*   **Installation Guide:** [JAX Installation](https://docs.jax.dev/en/latest/installation.html){target=\"_blank\"}","metrics":{"detailViews":10,"githubClicks":5},"dates":{"published":null,"modified":"2026-03-26T16:09:03.000Z"}}