How to think in JAX

  • JAX arrays are immutable, meaning that once created, their contents cannot be changed.
  • jax.numpy is a high-level wrapper that provides a familiar, NumPy-like interface.
  • jax.lax is a lower-level API that is stricter and often more powerful.
  • All JAX operations are implemented in terms of operations in XLA – the Accelerated Linear Algebra compiler.
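Here is a minimal sketch of what immutability means in practice (the array values are arbitrary). Assigning to an element raises a TypeError. The indexed-update syntax x.at[...].set(...) and x.at[...].add(...) returns a new array and leaves the original unchanged.

```python
import jax.numpy as jnp

x = jnp.arange(5)

# In-place item assignment is not allowed on JAX arrays.
try:
    x[0] = 10
except TypeError as err:
    print("TypeError:", err)

# Indexed updates return a *new* array; x itself is left untouched.
y = x.at[0].set(10)
z = x.at[1:3].add(100)

print(x.tolist())  # [0, 1, 2, 3, 4]
print(y.tolist())  # [10, 1, 2, 3, 4]
print(z.tolist())  # [0, 101, 102, 3, 4]
```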

  • Not all JAX code can be JIT compiled, because JIT compilation requires array shapes to be static & known at compile time.
  • Function arguments that you don’t want to be traced can be marked as static (e.g. with the static_argnums or static_argnames parameters of jax.jit).
  • Python control flow statements (such as if and while) in a JIT-compiled function cannot depend on traced values.
  • It is important to understand which values and operations will be static and which will be traced.
  • Just as values can be either static or traced, operations can be static or traced.
  • Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.
  • Use numpy for operations that you want to be static.
  • Use jax.numpy for operations that you want to be traced.
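The following sketch illustrates the static-shape requirement. It uses a hypothetical helper, get_negatives, whose output length depends on the values in its input rather than only on its shape. The helper runs eagerly but cannot be JIT compiled. A fixed-shape alternative built with jnp.where compiles without trouble.

```python
import jax
import jax.numpy as jnp

def get_negatives(x):
    # The output length depends on the *values* in x, not just its shape.
    return x[x < 0]

x = jnp.array([-1.0, 2.0, -3.0, 4.0])

print(get_negatives(x).tolist())  # [-1.0, -3.0]  (fine without jit)

try:
    jax.jit(get_negatives)(x)
except jax.errors.NonConcreteBooleanIndexError as err:
    # Under jit, x < 0 is traced, so the result shape is unknown at compile time.
    print(type(err).__name__)

@jax.jit
def keep_negatives(x):
    # Static output shape (same as x): non-negative entries are replaced by 0.
    return jnp.where(x < 0, x, 0.0)

print(keep_negatives(x).tolist())  # [-1.0, 0.0, -3.0, 0.0]
```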
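This sketch connects static arguments to control flow. The hypothetical function f branches in Python on its neg argument, so it fails when neg is traced. There are two ways around this. One is to mark neg as static with static_argnames, which compiles a separate version for each distinct value. The other is to move the branch into the structured primitive jax.lax.cond, which accepts a traced predicate.

```python
import jax
import jax.numpy as jnp

def f(x, neg):
    # Python control flow on `neg`: requires a concrete (non-traced) value.
    return -x if neg else x

# Tracing `neg` fails: a tracer cannot be converted to a Python bool.
try:
    jax.jit(f)(1.0, True)
except jax.errors.ConcretizationTypeError as err:
    print(type(err).__name__)

# Option 1: make `neg` static, so it is a compile-time constant.
# Each distinct value of `neg` triggers a separate compilation.
f_static = jax.jit(f, static_argnames="neg")
print(float(f_static(1.0, True)))   # -1.0
print(float(f_static(1.0, False)))  # 1.0

# Option 2: keep `neg` traced and branch with lax.cond instead of `if`.
@jax.jit
def f_cond(x, neg):
    return jax.lax.cond(neg, lambda v: -v, lambda v: v, x)

print(float(f_cond(1.0, True)))  # -1.0
```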
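A reshape shows the difference between static and traced operations (the function names here are illustrative). In flatten_traced, the new size is computed with jax.numpy, which yields a traced value that cannot serve as a shape. In flatten_static, np.prod is applied to the static x.shape, so it is evaluated in Python at trace time.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def flatten_traced(x):
    # jnp operations are traced, so the computed size is not a concrete int.
    return x.reshape(jnp.array(x.shape).prod())

@jax.jit
def flatten_static(x):
    # x.shape is static under jit; np.prod evaluates it in Python at trace time.
    return x.reshape(np.prod(x.shape))

x = jnp.ones((2, 3))
print(flatten_static(x).shape)  # (6,)

try:
    flatten_traced(x)
except TypeError as err:
    print("TypeError:", err)
```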

Debugging runtime values

import jax
import jax.numpy as jnp

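@jax.jit
def f(x):
  # jax.debug.print shows runtime values, even inside jit-compiled code.
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  # Pauses execution and opens a debugger prompt for inspecting values.
  jax.debug.breakpoint()
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y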

f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
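Compare this with a plain Python print inside a jitted function. That print runs only while JAX traces the function, so it shows an abstract tracer rather than a value. The minimal sketch below uses a hypothetical double function, plus a global counter whose only purpose is to make tracing visible.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    print("print sees:", x)  # Trace time: shows an abstract tracer, not the value
    jax.debug.print("jax.debug.print sees: {}", x)  # Run time: shows the actual value
    return x * 2

double(jnp.float32(1.0))
double(jnp.float32(2.0))  # Same shape and dtype: the cached compilation is reused

print(trace_count)  # 1
```

The second call has the same shape and dtype, so JAX reuses the compiled version. On that call the Python print and the counter do not run again, while jax.debug.print fires on both calls.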

GPU performance tips

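import os

# Set XLA_FLAGS before JAX initializes its backend (safest: before importing jax).
os.environ['XLA_FLAGS'] = (
    # Use the Triton-based GEMM (matmul) emitter for any GEMM it supports.
    '--xla_gpu_triton_gemm_any=True '
    # Overlap asynchronous communication with computation.
    '--xla_gpu_enable_latency_hiding_scheduler=true '
)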

Pytrees

  • Any object whose type is not in the pytree container registry is considered a leaf pytree.
  • Any object whose type is in the pytree container registry, and which contains pytrees, is considered a pytree.
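The sketch below demonstrates the two rules above using the hypothetical classes Opaque and Point. Opaque is not registered, so an Opaque object counts as a single leaf. Point is registered with jax.tree_util.register_pytree_node, so JAX traverses it like a built-in container and its fields become leaves.

```python
import jax
import jax.numpy as jnp

class Opaque:
    """Hypothetical class that is NOT registered, so JAX treats it as a leaf."""
    def __init__(self, value):
        self.value = value

class Point:
    """Hypothetical container class, registered as a pytree node below."""
    def __init__(self, x, y):
        self.x, self.y = x, y

# flatten: Point -> (children, aux_data); unflatten: (aux_data, children) -> Point
jax.tree_util.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),
    lambda aux, children: Point(*children),
)

tree = {"a": [1, 2.0], "b": (jnp.ones(3), Opaque(0))}
leaves = jax.tree_util.tree_leaves(tree)
print(len(leaves))  # 4: 1, 2.0, the array, and the Opaque object itself

# A registered container is traversed, so its fields become leaves.
print(jax.tree_util.tree_leaves(Point(3, 4)))  # [3, 4]

# tree_map applies a function to every leaf and rebuilds the same structure.
scaled = jax.tree_util.tree_map(lambda leaf: leaf * 10, Point(3, 4))
print(scaled.x, scaled.y)  # 30 40
```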