JAX Notes
How to think in JAX
- JAX arrays are immutable, meaning that once created their contents cannot be changed.
- jax.numpy is a high-level wrapper that provides a familiar interface.
- jax.lax is a lower-level API that is stricter and often more powerful.
- All JAX operations are implemented in terms of operations in XLA, the Accelerated Linear Algebra compiler.
- Not all JAX code can be JIT compiled, because JIT compilation requires array shapes to be static and known at compile time.
- Variables that you don't want to be traced can be marked as static (for example, with the static_argnums or static_argnames arguments of jax.jit).
- Control flow statements in a JIT-compiled function cannot depend on traced values.
- Understanding which values and operations will be static and which will be traced is a key part of using jax.jit effectively.
- Just as values can be either static or traced, operations can be static or traced.
- Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.
- Use numpy for operations that you want to be static.
- Use jax.numpy for operations that you want to be traced.
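The points above can be sketched in a few lines. This is a minimal illustration, not taken from the original notes: the function names (maybe_negate, relu, flatten) and the sample arrays are hypothetical and chosen only to show immutability, a static argument, value-dependent branching, and a static numpy operation inside a traced function.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

# Immutability: item assignment on a JAX array is an error, so updates go
# through .at[...].set(...), which returns a new array and leaves x unchanged.
x = jnp.arange(5)
y = x.at[0].set(10)

# Static argument: `negate` (position 1) is marked static, so ordinary Python
# control flow may branch on it. Each distinct value triggers a recompile.
@partial(jax.jit, static_argnums=1)
def maybe_negate(v, negate):
    return -v if negate else v

# Traced values cannot drive a Python `if`; jnp.where expresses the same
# choice as a traced, element-wise operation instead.
@jax.jit
def relu(v):
    return jnp.where(v > 0, v, 0.0)

# Static vs. traced operations: v.shape is static, so np.prod runs in Python
# while tracing and yields a concrete integer for the reshape.
@jax.jit
def flatten(v):
    size = int(np.prod(v.shape))  # static: evaluated at trace time
    return v.reshape(size)        # traced: compiled and executed by XLA

negated = maybe_negate(jnp.array([1.0, 2.0]), True)
clipped = relu(jnp.array([-1.0, 2.0]))
flat = flatten(jnp.ones((2, 3)))
```

Using jnp.prod in flatten instead would produce a traced value, which cannot be used as a shape under jit.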
Debugging runtime values
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
    jax.debug.print("🤯 {x} 🤯", x=x)
    y = jnp.sin(x)
    jax.debug.breakpoint()
    jax.debug.print("🤯 {y} 🤯", y=y)
    return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
GPU performance tips
import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
)
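One practical caveat: assigning to XLA_FLAGS directly overwrites anything already set in the environment. The sketch below is one possible way to append the same two flags instead. It assumes the flags need to be in place before JAX initializes its backend, so running it before `import jax` is the safest ordering; the variable names extra_flags and existing are illustrative only.

```python
import os

# Flags taken from the snippet above.
extra_flags = [
    "--xla_gpu_triton_gemm_any=True",
    "--xla_gpu_enable_latency_hiding_scheduler=true",
]

# Keep whatever the user already exported, then append the new flags.
existing = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = " ".join([existing, *extra_flags]).strip()
```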
Pytrees
- Any object whose type is not in the pytree container registry is considered a leaf pytree.
- Any object whose type is in the pytree container registry, and which contains pytrees, is considered a pytree.
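As a small illustration of these two definitions, the sketch below builds a nested container and inspects it with jax.tree_util. The params structure is a made-up example: dicts and lists are registered containers, while the arrays and the Python float inside them are leaves.

```python
import jax
import jax.numpy as jnp

# A pytree: a dict (container) holding an array (leaf) and a list (container)
# that in turn holds an array and a float (both leaves).
params = {"w": jnp.ones((2, 2)), "b": [jnp.zeros(2), 3.0]}

# Flatten to the list of leaves; dict entries are visited in sorted key order.
leaves = jax.tree_util.tree_leaves(params)

# Apply a function to every leaf while keeping the container structure.
doubled = jax.tree_util.tree_map(lambda leaf: leaf * 2, params)

# An object that is not a registered container is itself a single-leaf pytree.
single = jax.tree_util.tree_leaves(1.0)
```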