Performance
Training speed and memory you can expect, plus the knobs you can use when defaults don’t fit your workload.
What to expect on M4 Pro 48 GB
These numbers come from our own runs on an M4 Pro 48 GB Mac. They’re rough indicators, not benchmarks — treat them as “starting expectations.” Your numbers will move with model size, quantization, context length, batch size, dataset shape, and macOS version. All rows below use LoRA rank 8–16.
| Workload | Model | Context | Batch | Speed | Peak GB | Notes |
|---|---|---|---|---|---|---|
| SFT | Qwen3-0.6B-bf16 | 1024 | 1 | ~2.2 it/s | 2.6 | 16 LoRA layers (q, v), no checkpointing |
| SFT | Qwen3-0.6B-bf16 | 4096 | 1 | ~0.28 it/s | 18 | full LoRA, no checkpointing |
| DPO | Qwen3.5-0.8B-MLX-4bit | 512 | 1 | ~540 ms/sample | ~1 | shares the prompt forward between chosen & rejected |
| DPO | Qwen3.5-0.8B-MLX-4bit | 512 | 2 | ~810 ms/sample | ~1.2 | batched (prompt-share applies only at batch size 1) |
| ORPO | Qwen3.5-0.8B-MLX-4bit | 512 | 1 | ~430 ms/sample | ~1 | shares the prompt forward between chosen & rejected |
| ORPO | Qwen3-0.6B-bf16 | 4096 | 1 | ~0.09 it/s | 9.2 | set use_gradient_checkpointing="unsloth" |
| ORPO | Qwen3-0.6B-bf16 | 2048 | 1 | ~0.41 it/s | 16.8 | no checkpointing needed |
| GRPO | Qwen3-0.6B-bf16 | 4096 | 1 | ~0.11 it/s | 5.3 | 4 generations × 256 tokens per step |
| GRPO | Qwen3.5-0.8B-4bit | ~128–512 | 1 | ~4.8 s/iter | 15.9 | 4 generations × 128 tokens per step |
| Embedding | MiniLM-L6-v2-bf16 (22M) | 128 | 32 | ~25 ms/step | <0.5 | InfoNCE contrastive |
it/s counts full training steps per second, where each step includes forward, backward, optimizer, and per-step overhead. ms/sample divides wall time by the total samples processed across the run, which is the fair number when comparing batch sizes. Peak GB is the high-water mark over the whole run as reported by mx.get_peak_memory(), not the steady-state working set.
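The two speed units can be related with a little arithmetic. The sketch below is illustrative only: the helper names are hypothetical, and the 100-step, 162-second run is made up for the example rather than taken from a measured log.

```python
def iterations_per_second(total_steps: int, wall_seconds: float) -> float:
    """Full optimizer steps completed per second of wall time."""
    return total_steps / wall_seconds


def ms_per_sample(total_steps: int, batch_size: int, wall_seconds: float) -> float:
    """Wall time divided by every sample processed in the run, in milliseconds."""
    total_samples = total_steps * batch_size
    return 1000.0 * wall_seconds / total_samples


# Hypothetical run: 100 steps at batch size 2 that took 162 seconds.
rate = iterations_per_second(100, 162.0)
per_sample = ms_per_sample(100, 2, 162.0)
print(f"{rate:.2f} it/s, {per_sample:.0f} ms/sample")
```

Because ms/sample normalizes by batch size, it stays comparable when the batch size changes, whereas it/s does not.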
Scaling down to smaller Macs
As a rule of thumb, halving the available memory means halving the context length you can comfortably train at, for the same model class. If you’re on 24 GB, the M4 Pro 48 GB rows above should run at roughly half the maximum context; on 16 GB, around a quarter. Below 16 GB, stick to 4-bit models in the 0.5–1B range at context length 1024 or shorter.
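One way to apply this rule of thumb is to count how many whole halvings separate your machine from the 48 GB reference. The function below is a hypothetical helper, not part of the library, and rounding down to whole halvings is an assumption chosen so that 16 GB lands on the "around a quarter" figure given above.

```python
import math


def comfortable_context(reference_context: int, available_gb: float,
                        reference_gb: float = 48.0) -> int:
    """Scale a context length from the reference machine to a smaller one.

    Each full halving of memory halves the context. Partial halvings are
    rounded down (conservative), and the result never exceeds the reference.
    """
    if available_gb <= 0 or reference_gb <= 0:
        raise ValueError("memory sizes must be positive")
    halvings = math.floor(math.log2(available_gb / reference_gb))
    halvings = min(halvings, 0)  # do not extrapolate beyond the reference row
    return max(1, int(reference_context * 2.0 ** halvings))


for gb in (48, 24, 16):
    print(gb, "GB ->", comfortable_context(4096, gb))
```

Treat the result as a starting point for experimentation; model size and quantization shift the real limit.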
A note on DPO/ORPO at batch size 1
At per_device_train_batch_size=1, the trainer forwards the prompt once and reuses the result for both the chosen and rejected branches. The saving scales with how much of the sequence is prompt vs. response — the rows above are at roughly equal prompt and response length. With very short prompts the speedup is modest; with very long prompts (multi-turn or reasoning data), it gets larger.
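To get a feel for how the saving depends on the prompt/response split, here is a back-of-the-envelope token count. It is a simplified model, not the trainer's actual accounting: it assumes chosen and rejected responses have the same length and that cost is proportional to tokens forwarded, ignoring attention's dependence on sequence length.

```python
def prompt_share_saving(prompt_tokens: int, response_tokens: int) -> float:
    """Fraction of forwarded tokens saved by forwarding the prompt once.

    Without sharing: prompt + response is forwarded twice (chosen, rejected).
    With sharing:    the prompt is forwarded once, each response once.
    """
    separate = 2 * (prompt_tokens + response_tokens)
    shared = prompt_tokens + 2 * response_tokens
    return 1.0 - shared / separate


print(prompt_share_saving(256, 256))  # equal prompt and response
print(prompt_share_saving(32, 480))   # very short prompt
print(prompt_share_saving(960, 64))   # very long prompt
```

Under this model the saving approaches, but never reaches, half of the forwarded tokens as the prompt grows.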
Performance knobs
Most things are already on by default. These are the levers worth knowing about when defaults don’t fit.
Gradient checkpointing
Trades roughly 2× backward time for about half the activation memory. Off by default because most short-context runs don’t need it. Turn it on when you hit OOM at long context or with a large model:
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    use_gradient_checkpointing="unsloth",  # opt in for ctx >= 4096 or 7B+
)
This flag is now respected by every trainer (not just SFT — it used to silently no-op for DPO, ORPO, KTO, SimPO, GRPO, CPT, and the audio/embedding/VLM trainers).
SFT validation cost
The default is light — 5 validation batches per evaluation pass, evaluated every max(save_steps, 200) steps. Set val_batches=0 to skip evaluation entirely on benchmarking or short runs:
SFTConfig(
    val_batches=5,       # default; set 0 to skip eval
    steps_per_eval=200,  # default = max(save_steps, 200)
)
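The cadence described above can be worked out ahead of a run. This is a sketch with hypothetical helper names; it assumes evaluation fires on exact multiples of the interval, which the text above does not spell out.

```python
def eval_interval(save_steps: int, floor: int = 200) -> int:
    """Default evaluation interval: max(save_steps, 200)."""
    return max(save_steps, floor)


def eval_steps(max_steps: int, save_steps: int, val_batches: int = 5) -> list[int]:
    """Steps at which an evaluation pass would run (assumed: every multiple)."""
    if val_batches == 0:
        return []  # val_batches=0 skips evaluation entirely
    interval = eval_interval(save_steps)
    return list(range(interval, max_steps + 1, interval))


print(eval_steps(max_steps=1000, save_steps=100))
print(eval_steps(max_steps=1000, save_steps=500))
print(eval_steps(max_steps=1000, save_steps=100, val_batches=0))
```

With the default of 5 validation batches, the total evaluation cost is roughly 5 forward passes times the number of steps listed.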
DPO reference model
DPO now uses a frozen reference policy by default: the chosen/rejected log-probabilities of the base model are computed once at the start and reused for the whole run. This matches the standard DPO formulation. If you want the older behaviour (using the policy itself with stop_gradient as a stand-in reference), opt out by setting precompute_ref_logprobs=False:
DPOConfig(
    precompute_ref_logprobs=True,  # default; False for legacy stop-grad reference
)
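To see why the choice of reference matters, here is the standard DPO loss written for a single preference pair in plain Python. It is a scalar sketch, not the trainer's implementation; the function name, the example log-probabilities, and beta=0.1 are illustrative assumptions.

```python
import math


def dpo_loss(policy_chosen: float, policy_rejected: float,
             ref_chosen: float, ref_rejected: float, beta: float = 0.1) -> float:
    """-log(sigmoid(beta * margin)) for one pair of sequence log-probabilities.

    margin = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
    """
    margin = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
    x = beta * margin
    # softplus(-x) == -log(sigmoid(x)), written in a numerically stable form
    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))


# Frozen reference: log-probs computed once from the base model and reused.
print(dpo_loss(-8.0, -14.0, ref_chosen=-10.0, ref_rejected=-12.0))

# Legacy stand-in: the reference values equal the policy's own values, so the
# margin is zero and the loss value is always log(2), whatever the policy does.
print(dpo_loss(-8.0, -14.0, ref_chosen=-8.0, ref_rejected=-14.0))
```

In the legacy case gradients still flow through the policy terms, but the reported loss value stays at log(2), so it is not a useful progress signal.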
Environment variables
Two escape hatches for debugging or A/B testing:
| Variable | Effect |
|---|---|
| MLX_TUNE_DISABLE_COMPILE=1 | Disable @mx.compile globally and run trainers eagerly. Useful for narrowing down a compile-related issue or comparing eager vs. compiled wall time. |
| MLX_TUNE_BUCKET_SIZE=N | Override the default 64-token padding bucket used by collators. Set to 1 to effectively disable bucketing. |
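The effect of the bucket size can be illustrated with a small rounding helper. This is a sketch of the general idea only: both function names are hypothetical, and rounding up to the next multiple is an assumption about how the collators pad, not a description of their code.

```python
import os


def bucket_size(default: int = 64) -> int:
    """Read MLX_TUNE_BUCKET_SIZE, falling back to the documented default of 64."""
    return int(os.environ.get("MLX_TUNE_BUCKET_SIZE", default))


def padded_length(length: int, bucket: int) -> int:
    """Round a sequence length up to the next multiple of the bucket size."""
    return -(-length // bucket) * bucket  # ceiling division, then scale back


for n in (100, 128, 129):
    print(n, "->", padded_length(n, bucket_size()))
```

With a bucket of 1 every length maps to itself, which is why that setting effectively disables bucketing. Setting either variable before the library is imported is the safe choice, since when the library reads them is not stated here.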
What’s automatic
You don’t need to think about any of this — it’s already happening underneath. Mentioned here so you know what you’re getting:
- Memory-friendly log-probs. Preference and policy losses avoid materializing a full (batch, length, vocab) log-softmax tensor. This matters most on large-vocabulary models like Qwen3 (V≈152k) and Llama 3+ (V≈128k).
- KV-cached GRPO generation. Each rollout reuses the prompt cache instead of re-forwarding the prefix every step.
- Shared prompt cache across GRPO rollouts. The prompt forward runs once per prompt, then forks for each of the N completions in the group.
- @mx.compile step wrappers with eager fallback: if a trace fails for any reason, training continues without it.
- Length bucketing in collators: consecutive batches hit the same compile-cache slot instead of triggering a recompile per shape.
- Gradient checkpointing propagation: use_gradient_checkpointing now actually works for every trainer (see above).
- Prompt-prefix sharing for DPO/ORPO/SimPO at batch size 1: the prompt is forwarded once and reused for both branches.
- Wired memory limit set at the start of every training run, matching mlx-lm’s training-time configuration.
- RNN-T loss refactor for Parakeet TDT — ~1.5× faster backward pass, bit-identical with the prior implementation.
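To make the first item in the list above concrete, the sketch below estimates how large a full log-softmax tensor would be and shows the identity that makes it unnecessary: the log-probability of one target token is its logit minus the log-sum-exp over the vocabulary. The function names are hypothetical, 4 bytes per value (fp32) is an assumption, and this is plain Python rather than the library's MLX code.

```python
import math


def full_logsoftmax_gib(batch: int, length: int, vocab: int,
                        bytes_per_value: int = 4) -> float:
    """Size in GiB of a materialized (batch, length, vocab) tensor."""
    return batch * length * vocab * bytes_per_value / 2**30


def token_logprob(logits: list[float], target: int) -> float:
    """log p(target) = logits[target] - logsumexp(logits).

    Only one scalar per position is kept, not a vocab-sized row.
    """
    m = max(logits)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[target] - lse


# One 4096-token sequence with a 152k-entry vocabulary, assuming fp32 values.
print(f"{full_logsoftmax_gib(1, 4096, 152_000):.2f} GiB")
print(token_logprob([2.0, 0.5, -1.0], target=0))
```

Gathering the target logit and reducing over the vocabulary keeps the stored output at (batch, length) instead of (batch, length, vocab).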