senior ai engineer · bangladesh → remote
my jax journey
Notes on learning to think in JAX — from autodiff and tracing, through the hardware substrate underneath, to the modern LLM stack and writing Pallas kernels. A synthesis I keep updated for myself, published in case it's useful to anyone walking the same path.
sources & attribution
Most of the technical content is synthesized from primary sources I'm grateful to. The framing of the roofline, sharding math, and TPU/GPU chapters leans heavily on Google's open "How to Scale Your Model" book; verbatim equations are reproduced with attribution. The ecosystem chapters draw from the JAX, Flax, Pallas, Grain, and Orbax documentation. The LLM-stack chapters lean on landmark papers (FlashAttention, RoPE/YaRN, Switch/Mixtral, μP, speculative decoding), cited inline at point of use.
My contribution is the synthesis, the ordering, the worked examples, and any errors.