JAX is the new kid on the block when it comes to machine learning frameworks – although the Tensorflow competitor has technically been around since late 2018, only recently has JAX started to gain traction within the broader machine learning research community.

So what is JAX, exactly? According to the official JAX documentation:

JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

Just like the label says, at its simplest JAX is accelerator-backed numpy with some convenience functions for common machine learning ops.

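import jax
import jax.numpy as np

# example parameters for a 128-unit dense layer
W = jax.random.normal(jax.random.PRNGKey(0), (128, 128))
b = np.zeros(128)

def gpu_backed_hidden_layer(x):
    return jax.nn.relu(np.dot(W, x) + b)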

You get numpy's well-thought-out API, honed since 2006, with the performance characteristics of modern ML workhorses like Tensorflow and PyTorch.
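Here is a quick sanity check of the "numpy on accelerators" claim. The sketch below runs the same expression through numpy and jax.numpy and compares the results. The input values are arbitrary. jax.devices() simply reports whichever backend (CPU, GPU or TPU) JAX picked up. The last lines show the most visible API difference: JAX arrays are immutable, so "in-place" updates go through the .at[...] syntax and return a new array.

```python
import numpy as onp
import jax
import jax.numpy as np

# whichever devices JAX found: CPU, GPU or TPU
print(jax.devices())

x = onp.linspace(-2.0, 2.0, 5).astype(onp.float32)

# the same numpy-style expression, evaluated by numpy and by jax.numpy
numpy_result = onp.tanh(x) @ onp.outer(x, x)
jax_result = np.tanh(x) @ np.outer(x, x)

# JAX arrays are immutable: .at[...].set(...) returns an updated copy
jax_x = np.asarray(x)
updated = jax_x.at[0].set(10.0)
```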

JAX also includes a sizable chunk of the scipy project, exposed through jax.scipy:

from jax.scipy.linalg import svd

x = np.arange(12.0).reshape(4, 3)
# like scipy, svd returns U, the singular values, and Vh
U, singular_values, Vh = svd(x)
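jax.scipy.linalg.svd follows scipy's conventions, so a minimal correctness check is that its three factors multiply back to the input. The sketch below uses an arbitrary 4×3 matrix and the economy-size factorization (full_matrices=False).

```python
import jax.numpy as np
from jax.scipy.linalg import svd

x = np.arange(12.0).reshape(4, 3)

# economy-size factorization: U is (4, 3), Vh is (3, 3)
U, singular_values, Vh = svd(x, full_matrices=False)

# U @ diag(s) @ Vh should reconstruct x up to floating-point error
reconstruction = U @ np.diag(singular_values) @ Vh
print(singular_values)
```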

Although an accelerator-backed version of numpy + scipy is already quite useful, JAX has a few other tricks up its sleeve. First, let's look at JAX's extensive support for automatic differentiation.

Autograd

Autograd is a library for efficient computation of gradients over numpy and native python code. Autograd also happens to be the spiritual (and largely literal) predecessor to JAX. Although the original autograd repository is no longer actively developed, much of the core team that worked on autograd has moved on to work full-time on the JAX project.

Just like autograd, JAX lets you differentiate a python function's output with respect to its inputs by simply calling grad. grad expects a function that returns a scalar, such as a loss:

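from jax import grad

def hidden_layer(x):
    return jax.nn.relu(np.dot(W, x) + b)

# grad needs a scalar-valued function, so reduce the layer's output to a scalar first
grad_hidden_layer = grad(lambda x: np.sum(hidden_layer(x)))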
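By default, grad differentiates with respect to the first positional argument. When a loss depends on several parameters, the argnums argument selects which ones to differentiate. The sketch below uses a hypothetical scalar linear model that is small enough to check by hand. With w=2, b=1, x=3 and y=5, the prediction is 7 and the error is 2. That gives dL/dw = 2 * 2 * 3 = 12 and dL/db = 2 * 2 = 4.

```python
from jax import grad

# hypothetical squared-error loss for a scalar linear model y ≈ w * x + b
def loss(w, b, x, y):
    prediction = w * x + b
    return (prediction - y) ** 2

# differentiate with respect to both w (argument 0) and b (argument 1)
dloss_dw, dloss_db = grad(loss, argnums=(0, 1))(2.0, 1.0, 3.0, 5.0)
print(dloss_dw, dloss_db)
```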

You can also differentiate through native python control structures -- no wrestling with tf.cond required.

def absolute_value(x):
    if x >= 0:
        return x
    else:
        return -x

grad_absolute_value = grad(absolute_value)
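Evaluating that gradient on either side of zero shows grad following whichever python branch actually runs. At x = 0 the `x >= 0` branch is taken, so this particular implementation reports a derivative of 1.0 there. That is a convention of the code, not a mathematical fact. Python branching on input values works under grad, but not inside jax.jit. Under jit, jax.lax.cond or np.where is needed instead.

```python
from jax import grad

def absolute_value(x):
    # ordinary python branching; grad traces whichever branch executes
    if x >= 0:
        return x
    else:
        return -x

grad_absolute_value = grad(absolute_value)

for point in (-3.0, 0.0, 2.5):
    print(point, grad_absolute_value(point))
```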

JAX also includes support for taking higher-order derivatives -- the grad function can be chained arbitrarily.

from jax import grad
from jax.numpy import tanh

# grads all the way down: the third derivative of tanh, evaluated at 1.0
print(grad(grad(grad(tanh)))(1.0))
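Each additional grad produces the next derivative. The chain is easy to check at x = 0, where the derivatives of tanh are known in closed form: tanh'(0) = 1, tanh''(0) = 0 and tanh'''(0) = -2.

```python
import jax.numpy as np
from jax import grad

first = grad(np.tanh)
second = grad(grad(np.tanh))
third = grad(grad(grad(np.tanh)))

print(first(0.0), second(0.0), third(0.0))
```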

By default, grad gives you reverse-mode gradients -- the most common mode for computing gradients, which relies on caching activations from the forward pass to make the backward pass efficient. Reverse-mode differentiation is typically the most efficient method for computing parameter updates, because a single backward pass yields the gradient of a scalar loss with respect to every parameter. However, it's not always the best choice, especially when implementing optimization methods that rely on higher-order derivatives. Forward mode is cheaper when a function has few inputs and many outputs. JAX has first-class support for both reverse-mode and forward-mode automatic differentiation through jacrev and jacfwd, respectively:

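from jax import jacfwd, jacrev

# fn: any scalar-valued function of a vector; jacrev gives its gradient,
# and jacfwd of that gradient gives the Hessian
hessian_fn = jacfwd(jacrev(fn))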
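For a scalar function of n variables, forward-over-reverse is the usual way to get the n×n Hessian. jacrev computes the gradient in one reverse pass, and jacfwd then differentiates that gradient with forward passes. The sketch below checks this on a quadratic form f(x) = 0.5 * x^T A x with an arbitrary symmetric matrix A, whose Hessian is exactly A. The convenience function jax.hessian packages the same composition, so it should agree.

```python
import jax
import jax.numpy as np
from jax import jacfwd, jacrev

# a symmetric matrix, so the Hessian of 0.5 * x^T A x is exactly A
A = np.array([[2.0, 1.0, 0.0],
              [1.0, 3.0, 0.5],
              [0.0, 0.5, 4.0]])

def quadratic(x):
    return 0.5 * x @ A @ x

x0 = np.array([1.0, -2.0, 0.5])

# forward-over-reverse: reverse mode for the gradient, forward mode on top of it
hessian_fn = jacfwd(jacrev(quadratic))
H = hessian_fn(x0)

# jax.hessian is the packaged version of the same composition
H_builtin = jax.hessian(quadratic)(x0)
```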

Aside from grad, jacfwd and jacrev, JAX's automatic differentiation support includes utilities for computing linear approximations to functions (jax.linearize), defining custom gradient rules (jax.custom_jvp and jax.custom_vjp), and more.
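Here is a sketch of two of those utilities. jax.linearize returns a function's value at a point together with its linear (Jacobian-vector product) approximation there. jax.custom_jvp lets you supply your own derivative rule. The log1pexp function below is a standard numerical-stability example. Differentiating log(1 + exp(x)) naively at x = 100 overflows exp in float32 and yields nan. The hand-written rule uses the algebraically equivalent form 1 - 1/(1 + exp(x)) and returns 1.0 instead. The function names f and log1pexp are illustrative.

```python
import jax
import jax.numpy as np

# --- linear approximation at a point with jax.linearize ---
def f(x):
    return np.sin(x) * x

# y is f(1.0); f_lin(t) computes f'(1.0) * t
y, f_lin = jax.linearize(f, 1.0)
slope = f_lin(1.0)  # f'(1.0) = cos(1) * 1 + sin(1)

# --- a custom derivative rule with jax.custom_jvp ---
@jax.custom_jvp
def log1pexp(x):
    return np.log(1.0 + np.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # numerically stable form of exp(x) / (1 + exp(x))
    ans_dot = (1.0 - 1.0 / (1.0 + np.exp(x))) * x_dot
    return ans, ans_dot

print(jax.grad(log1pexp)(100.0))
```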

XLA

XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra code -- and it's the backbone that allows JAX to translate python and numpy expressions into accelerator-backed operations.

In addition to letting JAX translate python + numpy code into operations that run on an accelerator (as in our first example), XLA also allows JAX to fuse several operations into a single kernel. XLA looks for clusters of nodes in a computational graph that can be rewritten to reduce computation or intermediate storage of variables. Tensorflow's documentation on XLA uses the following example to illustrate the kind of problem that benefits from XLA compilation.

def unoptimized_fn(x, y, z):
  return np.sum(x + y * z)

Without XLA, this would run as three separate kernels -- a multiplication, an addition, and an additive reduction. With XLA, it becomes a single kernel responsible for all three, saving time and memory because the intermediate results never need to be stored.

Turning on support for this operation-rewriting is as straightforward as decorating a function with @jax.jit:

from jax import jit

@jit
def xla_optimized_fn(x, y, z):
  return np.sum(x + y * z)
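To see what XLA has to work with, jax.make_jaxpr prints the intermediate representation that JAX traces from a python function. For this function, it lists the three primitives (mul, add and reduce_sum) that XLA can then fuse. The jitted and un-jitted versions should return the same value. The input arrays below are arbitrary.

```python
import jax
import jax.numpy as np
from jax import jit

def unoptimized_fn(x, y, z):
    return np.sum(x + y * z)

xla_optimized_fn = jit(unoptimized_fn)

x, y, z = np.ones(1000), np.full(1000, 2.0), np.arange(1000.0)

# the traced program: one mul, one add, one reduce_sum
jaxpr = jax.make_jaxpr(unoptimized_fn)(x, y, z)
print(jaxpr)

eager_result = unoptimized_fn(x, y, z)
jitted_result = xla_optimized_fn(x, y, z)
```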

Like JAX's other function transformations, jax.jit is fully composable:

xla_optimized_grad = jit(grad(xla_optimized_fn))
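Because the function is a sum, its gradient with respect to x is a vector of ones. With respect to y and z, the gradients are z and y, respectively. The sketch below checks that jit(grad(...)) produces those values, using grad's argnums parameter to pick the inputs. The input vectors are arbitrary.

```python
import jax.numpy as np
from jax import grad, jit

@jit
def xla_optimized_fn(x, y, z):
    return np.sum(x + y * z)

x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 5.0, 6.0])
z = np.array([7.0, 8.0, 9.0])

# gradient w.r.t. x (argument 0 by default), compiled with jit
xla_optimized_grad = jit(grad(xla_optimized_fn))
dx = xla_optimized_grad(x, y, z)

# gradients w.r.t. y and z: d/dy = z and d/dz = y
dy, dz = jit(grad(xla_optimized_fn, argnums=(1, 2)))(x, y, z)
```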

Vectorization and Parallelism

Although Autograd and XLA form the core of the JAX library, there are two more JAX functions that stand out from the crowd. You can use jax.vmap and jax.pmap for vectorization and SPMD-based (single program multiple data) parallelism.

To illustrate the benefit of vmap, we'll return to the example of our simple dense layer that operates on a single example represented by the vector x.

# convention to distinguish between
# jax.numpy and numpy
import numpy as onp

def hidden_layer(x):
    return jax.nn.relu(np.dot(W, x) + b)

print(hidden_layer(onp.random.randn(128)).shape)
# (128,)

We've written our hidden layer to take a single vector input, but in practice we almost always batch our inputs to leverage vectorized computation. With JAX, you can take any function that accepts a single input and allow it to accept a batch of inputs using jax.vmap:

from jax import vmap

batch_hidden_layer = vmap(hidden_layer)
print(batch_hidden_layer(onp.random.randn(32, 128)).shape)
# (32, 128)

The beauty of this is that you can more or less ignore the batch dimension in your model function, which leaves one less tensor dimension to keep in your head while you're composing your model.
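One way to confirm that vmap does the right thing is to compare it against a hand-batched version, where the matrix-vector product becomes a matrix-matrix product. The parameters W and b and the batch below are randomly initialized placeholders.

```python
import jax
import jax.numpy as np
from jax import vmap

# placeholder parameters for a 128 -> 128 dense layer
W = jax.random.normal(jax.random.PRNGKey(0), (128, 128))
b = jax.random.normal(jax.random.PRNGKey(1), (128,))

def hidden_layer(x):
    # written for a single example x of shape (128,)
    return jax.nn.relu(np.dot(W, x) + b)

batch = jax.random.normal(jax.random.PRNGKey(2), (32, 128))

# vmap maps hidden_layer over axis 0 of the batch
vmapped = vmap(hidden_layer)(batch)

# the manually batched equivalent: each row x becomes W @ x, i.e. batch @ W.T
manual = jax.nn.relu(batch @ W.T + b)
```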

If you have several inputs that should all be vectorized, or you want to vectorize along an axis other than axis 0, you can specify this with the in_axes argument.

batch_hidden_layer = vmap(hidden_layer, in_axes=(0,))
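Here is a sketch of in_axes with more than one argument. The hypothetical scaled_sum takes a single example and a shared scalar. in_axes=(0, None) maps over axis 0 of the first argument and passes the second through unbatched. in_axes=(1, None) does the same for examples stored column-wise, so both calls should agree.

```python
import jax.numpy as np
from jax import vmap

# hypothetical function of one example x (shape (3,)) and a shared scalar scale
def scaled_sum(x, scale):
    return scale * np.sum(x)

rows = np.arange(12.0).reshape(4, 3)   # 4 examples, one per row
cols = rows.T                          # the same 4 examples, one per column

# batch over axis 0 of `rows`; `scale` is not batched (None)
by_row = vmap(scaled_sum, in_axes=(0, None))(rows, 2.0)

# batch over axis 1 of `cols` instead
by_col = vmap(scaled_sum, in_axes=(1, None))(cols, 2.0)
```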

JAX's utility for SPMD parallelism follows a very similar API. If you have a 4-GPU machine and a batch of 4 examples, you can use pmap to run one example per device.

from jax import pmap

# leading dimension must not exceed the number of XLA-enabled devices
spmd_hidden_layer = pmap(hidden_layer)
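pmap can be tried even on a single-device machine, because the leading axis only needs to match the number of devices actually present. The sketch below sizes the input with jax.local_device_count(). It also names the mapped axis via axis_name so that jax.lax.psum, a collective that sums across devices, can normalize each column over that axis. The axis name "devices" is an arbitrary label. On a laptop this typically runs over a single CPU device.

```python
import jax
import jax.numpy as np
from jax import pmap

n_devices = jax.local_device_count()

# leading axis = number of devices; each device receives one row of shape (3,)
x = np.arange(1.0, n_devices * 3 + 1.0).reshape(n_devices, 3)

# axis_name lets collectives like psum refer to the mapped (device) axis
normalize = pmap(lambda v: v / jax.lax.psum(v, axis_name="devices"),
                 axis_name="devices")
normalized = normalize(x)
```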

And like always, you can compose functions to your heart's content.

# hypothetical setup for high-throughput inference
outputs = pmap(vmap(hidden_layer))(onp.random.randn(4, 32, 128))
print(outputs.shape)
# (4, 32, 128)

Why JAX?

The tools we use have a disproportionate impact on the research arcs we explore. Subconsciously or consciously, we constrain our ideas to the space of ideas we know how to implement efficiently. As Roman Ring points out in his blog post exploring the next generation of ML tools, AlexNet was primarily a software engineering achievement that enabled decades of good machine learning ideas to be properly tested.

JAX is an important step forward not because it has a cleaner API than existing machine learning frameworks, or because it's better than Tensorflow and PyTorch at doing the things they were designed for, but because it allows us to more easily experiment with a broader space of ideas than was previously possible.

If you dive in and start using JAX for your own projects, you might be frustrated by how little JAX seems to do on the surface. Writing training loops feels very manual. Managing parameters requires custom code. You even have to create and split your own PRNG keys every time you want a new random value. But in a way, that's also JAX's greatest strength.

It doesn't hide the details behind a curtain. The internals are extensively documented, and it's clear that JAX cares about enabling other developers to contribute. JAX makes very few assumptions about how you intend to use it, and in doing so gives you the flexibility to do things that range from unpleasant to impossible in other frameworks.
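To make the manual training loops, hand-managed parameters and explicit PRNG keys mentioned above concrete, here is a minimal sketch. It splits a PRNG key for each random draw, keeps parameters in a plain dict, and writes the SGD update by hand. The linear-regression task, learning rate and step count are illustrative choices, not anything prescribed by JAX.

```python
import jax
import jax.numpy as np

# every random draw needs an explicit key; split to get fresh, independent keys
key = jax.random.PRNGKey(42)
key, w_key, x_key, noise_key = jax.random.split(key, 4)

# synthetic data from a known linear model: y = 3 * x - 1 + noise
xs = jax.random.normal(x_key, (256,))
ys = 3.0 * xs - 1.0 + 0.1 * jax.random.normal(noise_key, (256,))

# parameters are just a dict (a pytree); JAX does not manage them for you
params = {"w": jax.random.normal(w_key, ()), "b": np.zeros(())}

def loss_fn(params, xs, ys):
    predictions = params["w"] * xs + params["b"]
    return np.mean((predictions - ys) ** 2)

@jax.jit
def sgd_step(params, xs, ys, learning_rate=0.1):
    # grad w.r.t. a dict argument returns gradients with the same dict structure
    grads = jax.grad(loss_fn)(params, xs, ys)
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

for step in range(200):
    params = sgd_step(params, xs, ys)

print(params)
```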

Whenever you wrap a lower-level API in a higher-level abstraction, you make assumptions about the space of possible uses the end user might have. When you have a very targeted application in mind, this makes for beautifully concise APIs that get you the result you want with minimal configuration. Especially with the recent emphasis on Keras and higher-level APIs in TF 2.0, writing Tensorflow feels a bit like using a 3D printer -- it's push-button simple, and as long as you want a plastic object that fits within the print surface, it's going to work like a charm.

Working with JAX is like being given access to a full-blown machine shop. Yes, just about everything around you could lop off a finger or otherwise cause grave bodily harm if you're not careful. But the freedom to implement and explore wild, "out-there" ideas that might just work makes using JAX worth it. So put on your metaphorical safety goggles and start building something weird with JAX.

JAX Ecosystem

Although the JAX ecosystem is still quite fragmented, some frameworks built on JAX already offer a light layer of abstraction over the core APIs. Of particular note are:

  • Flax: a functional framework designed for flexibility
  • Trax: Google Brain's spiritual successor to Tensor2Tensor, with both TF and JAX backend support
  • Stax: a neural network library that ships as part of JAX's experimental module
  • RLax: pronounced "relax", a JAX package for reinforcement learning
  • Haiku: a JAX variant of the Sonnet neural network library

Preregister your "-ax" libraries before it's too late!

Other JAX Blog Posts and Resources

All things considered, JAX is still in its infancy compared to the incumbents Tensorflow and PyTorch – but there are resources out there to learn JAX if you look hard enough. Here's a short list of the resources that I've found useful.