essay · living document

my jax journey

Notes on learning to think in JAX — from autodiff and tracing, through the hardware substrate underneath, to the modern LLM stack and writing Pallas kernels. A synthesis I keep updated for myself, published in case it's useful to anyone walking the same path.

updated: may 2026
length: 35 chapters across 7 parts
covers: JAX paradigm · TPU/GPU hardware · Flax NNX · sharding · LLMs · Pallas
format: living document; I revise as the ecosystem moves
sources & attribution
Most of the technical content is synthesized from primary sources, and I'm grateful to their authors. The framing of the roofline, sharding math, and TPU/GPU chapters leans heavily on Google's open "How to Scale Your Model" book; equations reproduced verbatim carry attribution. The ecosystem chapters draw on the JAX, Flax, Pallas, Grain, and Orbax documentation. The LLM-stack chapters lean on landmark papers (FlashAttention, RoPE/YaRN, Switch/Mixtral, μP, speculative decoding), cited inline at point of use. My contribution is the synthesis, the ordering, the worked examples, and any errors.
contents · 35 chapters