2026
My JAX Journey
A long-form personal field guide to JAX: the paradigm, the hardware substrate underneath (roofline, TPUs, GPUs, collectives), the modern ecosystem (Flax NNX, Grain, Orbax), sharded computation, the LLM stack (FlashAttention, KV caches, MoE, μP), and writing Pallas kernels. It is synthesized from Google's scaling-book, the JAX docs, and a year of my own notes, and kept up to date as the ecosystem moves.