JEPA on Apple Silicon
JEPA (Joint-Embedding Predictive Architecture) is the self-supervised vision paradigm Yann LeCun has championed (in his Path Towards Autonomous Machine Intelligence position paper) as an alternative to generative and contrastive learning. mlx-tune brings the whole JEPA family to your Mac: train LeJEPA from scratch, or load pretrained I-JEPA (images) and V-JEPA 2 (video) and fine-tune them with LoRA — all natively on MLX, no CUDA.
To our knowledge this is the first solid, tested, train-and-fine-tune JEPA implementation on Apple Silicon. Both pretrained ports (I-JEPA, V-JEPA 2) are numerically identical to the official HuggingFace PyTorch models (cosine similarity 1.000000), verified in the test suite.
Three ways to use JEPA
mlx-tune exposes JEPA through an Unsloth-flavoured API. Pick the track that matches your data and goal.
LeJEPA (self-supervised pretraining). The newest, leanest JEPA (Balestriero & LeCun, 2025). A single Vision Transformer trained with a multi-view prediction loss plus SIGReg (a distribution regulariser). No EMA teacher, no predictor, no stop-gradient, which is exactly why it ports cleanly to MLX. Use it to learn representations on your own image set without labels.
I-JEPA (pretrained image encoder). Load Meta's pretrained I-JEPA (ViT-Huge, self-supervised on ImageNet) and use it for feature extraction, linear probing, or LoRA / full fine-tuning on your downstream classification task.
V-JEPA 2 (pretrained video world model). Load Meta's V-JEPA 2 (trained on >1M hours of video): fine-tune for video classification, extract clip features, run the masked-latent predictor (anticipate a clip's future, score surprise), or use Meta's fine-tuned SSv2 action classifiers with zero training. All of it is ported faithfully to MLX.
What is JEPA (and is it an LLM)?
JEPA is a vision/video architecture, not a language model. Every JEPA here is a Vision Transformer that operates on image patches or video tubelets. The "JEPA" idea is how it's trained: take two views of an input, encode both, and train the model so one view's representation predicts the other's — in latent space, never reconstructing raw pixels. This captures predictable, abstract structure (semantics, dynamics) and discards unpredictable detail.
The enemy of every JEPA is representation collapse (the encoder outputs a constant). Different variants fight it differently:
| Variant | Domain | Anti-collapse mechanism |
|---|---|---|
| I-JEPA (2023) | Images | EMA target encoder (teacher) + predictor; masked target blocks |
| V-JEPA 2 (2025) | Video | Same recipe, spatiotemporal masks; world-model framing |
| LeJEPA (2025) | Any | SIGReg — provably regularise embeddings to an isotropic Gaussian. No teacher/predictor/stop-grad |
LoRA is not LLM-specific — it works on any model built from large linear layers, and a transformer is mostly linear layers (q/k/v/out projections + MLP), whether its inputs are words or image patches. So you can LoRA-fine-tune a 630M-parameter pretrained JEPA encoder on a Mac, training <1% of the weights.
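To see where a figure like "<1% of the weights" comes from, here is a back-of-envelope count. It is a sketch under assumptions that are not taken from mlx-tune: the standard ViT-H shape (width 1280, MLP width 5120, 32 blocks), rank-8 adapters on all six linear layers of every block, and biases, norms and the patch embedding ignored. The helper names are hypothetical.

```python
# Back-of-envelope LoRA budget for a ViT-H-sized encoder.
# Assumed (not from mlx-tune): standard ViT-H shape, rank-8 adapters on
# q/k/v/out and both MLP layers; biases, norms and patch embed are ignored.

def linear_params(d_in: int, d_out: int) -> int:
    """Weights in a dense layer (bias ignored)."""
    return d_in * d_out

def lora_params(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) beside the frozen weight."""
    return r * (d_in + d_out)

dim, mlp_dim, depth, rank = 1280, 5120, 32, 8

# (d_in, d_out) for the q, k, v, out projections and the two MLP layers.
layers = [(dim, dim)] * 4 + [(dim, mlp_dim), (mlp_dim, dim)]

base = depth * sum(linear_params(i, o) for i, o in layers)
adapters = depth * sum(lora_params(i, o, rank) for i, o in layers)

print(f"base linear weights : {base / 1e6:.1f}M")
print(f"LoRA adapter weights: {adapters / 1e6:.2f}M")
print(f"trainable fraction  : {100 * adapters / base:.2f}%")
```

Under these assumptions the adapters stay below 1% of the frozen linear weights; adapting fewer layers or lowering the rank shrinks the fraction further.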
Supported models
LeJEPA presets (random init — train from scratch)
| Preset | Dim | Depth | Heads | Use |
|---|---|---|---|---|
| vit-debug | 64 | 2 | 2 | tests / quick checks |
| vit-tiny | 192 | 12 | 3 | small datasets, fast |
| vit-small | 384 | 12 | 6 | balanced |
| vit-base | 768 | 12 | 12 | larger datasets |
Pretrained checkpoints (downloaded & converted on first use)
| Model | HF repo | Params | Loader |
|---|---|---|---|
| I-JEPA ViT-H/14 | facebook/ijepa_vith14_1k | 630M | FastJEPAModel |
| I-JEPA ViT-g/16 | facebook/ijepa_vitg16_22k | 1.0B | FastJEPAModel |
| V-JEPA 2 ViT-L | facebook/vjepa2-vitl-fpc64-256 | 326M | FastVideoJEPAModel |
| V-JEPA 2 ViT-H/g | facebook/vjepa2-vit{h,g}-fpc64-* | 0.6–1B | FastVideoJEPAModel |
Weights are downloaded as safetensors and converted directly to MLX. torch is only needed if you want to run the HuggingFace parity tests.
LeJEPA — self-supervised pretraining
Train a ViT on your own images, no labels required. The loss has a single hyperparameter lam (the SIGReg weight, ~0.05).
from mlx_tune import FastJEPAModel, JEPATrainer, JEPAConfig, linear_probe
# 1. Build a randomly-initialised ViT (LeJEPA trains from scratch).
model, _ = FastJEPAModel.from_pretrained("vit-tiny", img_size=128)
# 2. Self-supervised pretraining — `images` is just a list of HWC arrays / PIL images.
trainer = JEPATrainer(
    model,
    args=JEPAConfig(num_epochs=5, batch_size=64, lam=0.05),
    train_dataset=images,
)
trainer.train()
# 3. Use the learned encoder: extract features, or evaluate with a linear probe.
feats = model.encode(images) # (N, dim) frozen features
acc = linear_probe(model, tr_x, tr_y, te_x, te_y) # quick representation-quality check
# 4. Save / reload the encoder.
model.save_pretrained("my_lejepa")
model, _ = FastJEPAModel.from_pretrained("my_lejepa")
LeJEPA proves the ideal embedding distribution is an isotropic Gaussian. SIGReg projects embeddings onto many random 1-D directions and, on each, tests whether the projection looks like a standard normal (Epps–Pulley characteristic-function test). The total loss is (1−lam)·prediction + lam·SIGReg. The training loss correlates with downstream accuracy — useful for model selection without labels.
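The SIGReg idea described above can be sketched in plain Python. This is a simplified illustration, not mlx-tune's or the paper's implementation: the frequency grid, the Gaussian weighting window, the Riemann-sum quadrature and all function names are assumptions chosen for readability.

```python
import cmath
import math
import random

def random_unit_vector(dim, rng):
    """Draw a random direction on the unit sphere."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def epps_pulley_1d(xs, t_max=3.0, n_t=17):
    """Weighted squared gap between the empirical characteristic function
    of xs and the characteristic function of N(0, 1), exp(-t^2 / 2)."""
    step = 2 * t_max / (n_t - 1)
    total = 0.0
    for k in range(n_t):
        t = -t_max + k * step
        ecf = sum(cmath.exp(1j * t * x) for x in xs) / len(xs)
        target = math.exp(-0.5 * t * t)
        weight = math.exp(-0.5 * t * t)  # assumed weighting window
        total += weight * abs(ecf - target) ** 2 * step
    return total

def sigreg(embeddings, num_slices=16, seed=0):
    """Average the 1-D statistic over random projection directions."""
    rng = random.Random(seed)
    dim = len(embeddings[0])
    total = 0.0
    for _ in range(num_slices):
        u = random_unit_vector(dim, rng)
        proj = [sum(a * b for a, b in zip(e, u)) for e in embeddings]
        total += epps_pulley_1d(proj)
    return total / num_slices

def lejepa_loss(prediction_loss, sigreg_loss, lam=0.05):
    """Total loss as stated above: (1 - lam) * prediction + lam * SIGReg."""
    return (1 - lam) * prediction_loss + lam * sigreg_loss

rng = random.Random(0)
gaussian = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(256)]
collapsed = [[0.0] * 8 for _ in range(256)]  # encoder outputs a constant

healthy_score = sigreg(gaussian)
collapsed_score = sigreg(collapsed)
print(f"gaussian embeddings : {healthy_score:.4f}")
print(f"collapsed embeddings: {collapsed_score:.4f}")
```

A collapsed encoder drives the prediction term to zero but scores badly on the regulariser, which is why the combined loss can do without a teacher or stop-gradient.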
FastJEPAModel
mlx_tune.jepa
FastJEPAModel.from_pretrained()
One entry point, three behaviours, dispatched on model_name:
| If model_name is… | Then |
|---|---|
| a preset ("vit-tiny"…) | random-init ViT for LeJEPA from-scratch training |
| a saved directory | reload a previously saved JEPA encoder |
| a HF repo ("facebook/ijepa_*") | download + convert pretrained I-JEPA to MLX |
model.encode()
Run the frozen encoder over a list of images and return pooled features (N, dim) for downstream tasks (clustering, retrieval, probing).
JEPATrainer & JEPAConfig
mlx_tune.jepa
Key JEPAConfig fields
| Field | Default | Description |
|---|---|---|
| lam | 0.05 | SIGReg weight, the single LeJEPA hyperparameter |
| n_global / n_local | 2 / 6 | Number of global / local multi-crop views per image |
| num_slices | 1024 | Random projection directions for SIGReg |
| img_size | 128 | Pixel resolution all views are rendered at |
| learning_rate | 5e-4 | AdamW LR (linear warmup → cosine decay) |
I-JEPA — pretrained image encoder
Load Meta's pretrained I-JEPA and put its features to work. Inputs default to 224×224 (the resolution it was trained at); other sizes also work via position-embedding interpolation (see Native resolution & training over a real corpus).
from mlx_tune import FastJEPAModel, linear_probe, knn_probe, attentive_probe
# Downloads ~2.5 GB on first run, then converts HF -> MLX (no torch needed).
model, _ = FastJEPAModel.from_pretrained("facebook/ijepa_vith14_1k")
# Three frozen-feature probes — none of them backprop through the 630M encoder:
linear = linear_probe(model, tr_x, tr_y, te_x, te_y) # logistic regression on mean-pooled features
knn = knn_probe(model, tr_x, tr_y, te_x, te_y, k=20) # label-light, no training
attentive = attentive_probe(model, tr_x, tr_y, te_x, te_y) # attention-pooling head — the paper's eval
# Or just grab features for retrieval / clustering.
feats = model.encode(my_images) # (N, 1280) mean-pooled
tokens = model.encode_tokens(my_images) # (N, T, 1280) per-patch tokens
I-JEPA and V-JEPA 2 encoders don't mean-pool cleanly — a plain linear_probe on pooled features under-reads their quality. attentive_probe trains a small attention-pooling head over the token features (the canonical evaluation in both papers, and V-JEPA 2's own classification head). It's the most faithful measure of what these encoders know. knn_probe is a cheap, training-free, label-light sanity check. (Note: this is about readout strength; the encoder features themselves are bit-for-bit identical to the reference model regardless of probe.)
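For intuition about what a training-free probe does with frozen features, here is a minimal k-NN sketch. It is not the library's knn_probe: the cosine metric, the majority vote and the function names are assumptions, and the toy 2-D vectors stand in for real (N, dim) encoder features.

```python
import math
from collections import Counter

def cosine(a, b):
    """Cosine similarity between two non-zero feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

def knn_predict(train_feats, train_labels, query, k=3):
    """Label a query by majority vote among its k most similar neighbours."""
    ranked = sorted(
        range(len(train_feats)),
        key=lambda i: cosine(train_feats[i], query),
        reverse=True,
    )
    votes = Counter(train_labels[i] for i in ranked[:k])
    return votes.most_common(1)[0][0]

def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=3):
    """Fraction of test features whose k-NN vote matches the true label."""
    hits = sum(
        knn_predict(train_feats, train_labels, q, k) == y
        for q, y in zip(test_feats, test_labels)
    )
    return hits / len(test_labels)

# Toy stand-ins for frozen encoder features: two well-separated classes.
train_feats = [[1.0, 0.1], [0.9, 0.2], [1.0, -0.1],
               [0.1, 1.0], [0.2, 0.9], [-0.1, 1.0]]
train_labels = [0, 0, 0, 1, 1, 1]
test_feats = [[0.8, 0.1], [0.1, 0.7]]
test_labels = [0, 1]

accuracy = knn_accuracy(train_feats, train_labels, test_feats, test_labels)
print(f"k-NN accuracy: {accuracy:.2f}")
```

Nothing is trained here, so the score depends only on how well the features already separate the classes; an attentive probe adds a learned readout on top of the per-token features instead.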
Domain-adaptive continued pretraining (warm-start LeJEPA)
Because LeJEPA and the pretrained encoders are all ViTs, you can continue self-supervised pretraining from a pretrained checkpoint — adapt Meta's ImageNet-trained I-JEPA to your own domain with the LeJEPA objective, no labels. Just pass the loaded model straight to JEPATrainer:
from mlx_tune import FastJEPAModel, JEPATrainer, JEPAConfig
# Start from pretrained I-JEPA instead of random init...
model, _ = FastJEPAModel.from_pretrained("facebook/ijepa_vith14_1k")
# ...and keep training it on your unlabeled domain images with LeJEPA / SIGReg.
JEPATrainer(model, args=JEPAConfig(img_size=224, num_epochs=3, learning_rate=1e-4),
            train_dataset=my_domain_images).train()
model.save_pretrained("ijepa_my_domain") # domain-adapted encoder
You're refining strong features, not learning from scratch — use a smaller learning rate (e.g. 1e-4) than a from-scratch LeJEPA run so you don't wash out the pretrained representation.
V-JEPA 2 — pretrained video world model
Load Meta's V-JEPA 2 video encoder. Clips are (T, 256, 256, 3) arrays; T must be a multiple of the tubelet size (2). Frames are resized to 256×256 automatically.
from mlx_tune import (
    FastVideoJEPAModel, video_linear_probe, video_knn_probe, video_attentive_probe,
)
model, _ = FastVideoJEPAModel.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
# Frozen clip features, or any of the three video probes.
clip_feats = model.encode(my_clips) # (N, 1024)
linear = video_linear_probe(model, tr_x, tr_y, te_x, te_y)
knn = video_knn_probe(model, tr_x, tr_y, te_x, te_y, k=20)
attentive = video_attentive_probe(model, tr_x, tr_y, te_x, te_y) # paper's eval
V-JEPA 2 uses a 3D tubelet Conv3d patch embed plus rotary position embeddings split across the frame / height / width axes. The MLX port replicates the reference rotation exactly, so encoder outputs match the HuggingFace model bit-for-bit.
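The clip-shape constraint and the resulting token grid can be worked out by hand. The sketch below assumes a spatial patch size of 16 (not stated above, so treat it as an assumption) together with the tubelet size of 2 and the 256×256 crop from the text; the function names are hypothetical.

```python
def tubelet_grid(num_frames, crop_size=256, patch_size=16, tubelet_size=2):
    """Return the (time, height, width) token grid for one clip.

    patch_size=16 is an assumed value; tubelet_size and crop_size follow
    the constraints described in the text.
    """
    if num_frames % tubelet_size != 0:
        raise ValueError(
            f"num_frames={num_frames} must be a multiple of {tubelet_size}"
        )
    if crop_size % patch_size != 0:
        raise ValueError(
            f"crop_size={crop_size} must be a multiple of {patch_size}"
        )
    side = crop_size // patch_size
    return (num_frames // tubelet_size, side, side)

def num_tokens(num_frames, **kwargs):
    """Total tokens the encoder sees for a clip of num_frames frames."""
    t, h, w = tubelet_grid(num_frames, **kwargs)
    return t * h * w

def usable_frames(num_frames, tubelet_size=2):
    """Largest frame count <= num_frames that satisfies the constraint."""
    return num_frames - num_frames % tubelet_size

for frames in (8, 16, 64):
    print(frames, tubelet_grid(frames), num_tokens(frames))
```

Token count grows linearly with clip length, which is worth keeping in mind when picking T for fine-tuning on limited memory.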
The predictor: anticipate the rest of a clip
The masked-latent predictor — the "world model" half of V-JEPA 2 — is ported too (HF parity cosine 1.000000) and loads by default. Give it the latents of the first frames and it predicts the latents of the rest; latent_energy scores how surprising the actual future was (anomaly / discontinuity detection in representation space):
from mlx_tune import FastVideoJEPAModel, latent_energy
model, _ = FastVideoJEPAModel.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
# Context = first 4 frames; predict the latents of everything after.
predicted, actual = model.predict_latents(clip, context_frames=4)
energy = latent_energy(predicted, actual) # scalar surprise
e_map = latent_energy(predicted, actual, per_token=True) # localise it
# A clip with a hard cut scores higher energy than a coherent one.
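To make the "surprise" score concrete, here is a toy version of an energy between predicted and actual latents. The distance used by the library's latent_energy is not specified above, so the mean absolute difference here is an assumption, and latent_energy_sketch is a hypothetical stand-in rather than the real function.

```python
def token_energy(pred, actual):
    """Mean absolute difference between two latent vectors (assumed metric)."""
    return sum(abs(p - a) for p, a in zip(pred, actual)) / len(pred)

def latent_energy_sketch(predicted, actual, per_token=False):
    """Score how far the actual latents are from the predicted ones.

    predicted / actual: lists of token vectors with matching shapes.
    per_token=True returns one energy per token to localise the surprise.
    """
    energies = [token_energy(p, a) for p, a in zip(predicted, actual)]
    if per_token:
        return energies
    return sum(energies) / len(energies)

predicted = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]
coherent = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]    # future matches prediction
with_cut = [[0.0, 1.0], [1.0, 0.0], [2.5, -1.5]]   # last token departs sharply

print(latent_energy_sketch(predicted, coherent))
print(latent_energy_sketch(predicted, with_cut))
print(latent_energy_sketch(predicted, with_cut, per_token=True))
```

The scalar form ranks clips by overall surprise, while the per-token form shows which part of the latent grid disagrees with the prediction.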
Meta's fine-tuned action classifiers (zero training)
Meta's VJEPA2ForVideoClassification checkpoints (attentive pooler + head, fine-tuned on Something-Something-v2) load straight into MLX — from_pretrained auto-detects them. Logits match HuggingFace (cosine 1.000000, same top-1):
clf, _ = FastVideoJEPAModel.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2")
results = clf.predict([clip], top_k=5) # 174 SSv2 action classes
for r in results[0]:
    print(f"{r['prob']:.3f} {r['label']}")
clf.save_pretrained("ssv2_mlx") # MLX-native save/reload
mlx-tune now ports the V-JEPA 2 encoder and predictor — features, fine-tuning, masked latent prediction, and anticipation/surprise scoring. What remains out of scope is V-JEPA 2-AC: the separately post-trained action-conditioned predictor and the robot planning loop (CEM / MPC over action sequences). If you want a trainable world model with planning on your Mac today, see LeWM below.
Downstream classification (frozen / LoRA / full)
Both pretrained tracks share the same downstream API. Attach a classification head and pick how much of the encoder to train:
| finetune= | What trains | Use when |
|---|---|---|
| "frozen" | just the linear head | fast, tiny data, sanity check |
| "lora" (default) | LoRA adapters in every block + head (<1% of params) | adapt cheaply without overfitting |
| "full" | everything | lots of data, max accuracy |
Image classification (I-JEPA or LeJEPA encoder)
from mlx_tune import (
    FastJEPAModel, JEPAClassifierTrainer, JEPAClassifierConfig,
)
model, _ = FastJEPAModel.from_pretrained("facebook/ijepa_vith14_1k")
# Attach a head; LoRA-adapt the encoder. (frozen base is frozen FIRST, then wrapped)
clf = FastJEPAModel.for_image_classification(
    model, num_classes=10, finetune="lora", r=8,
)
trainer = JEPAClassifierTrainer(
    clf,
    JEPAClassifierConfig(img_size=224, batch_size=6, num_epochs=5),
    train_images, train_labels,
    eval_images=val_images, eval_labels=val_labels,
)
trainer.train()
print(f"accuracy: {trainer.evaluate():.3f}")
Video classification (V-JEPA 2 encoder)
from mlx_tune import (
    FastVideoJEPAModel, VideoClassifierTrainer, VideoClassifierConfig,
)
model, _ = FastVideoJEPAModel.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
clf = FastVideoJEPAModel.for_video_classification(model, num_classes=2, finetune="lora", r=8)
trainer = VideoClassifierTrainer(
    clf, VideoClassifierConfig(batch_size=2, num_epochs=5),
    train_clips, train_labels, eval_videos=val_clips, eval_labels=val_labels,
)
trainer.train()
The classifier configs default to learning_rate=3e-4 with warmup_ratio=0.15. Warmup matters: LoRA adapters perturbing all 24–32 transformer layers from step 0 can diverge without it. For frozen / full you can raise the LR.
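To visualise what warmup_ratio=0.15 does to the learning rate, here is a small schedule sketch. The linear-warmup shape and the cosine decay after it are assumptions for the classifier trainers (only the defaults 3e-4 and 0.15 come from the text), and lr_at is a hypothetical helper.

```python
import math

def lr_at(step, total_steps, base_lr=3e-4, warmup_ratio=0.15):
    """Learning rate at a given step: linear warmup, then cosine decay.

    The decay shape is an assumption; only base_lr and warmup_ratio
    mirror the defaults quoted in the text.
    """
    warmup_steps = max(1, round(total_steps * warmup_ratio))
    if step < warmup_steps:
        # Ramp from base_lr / warmup_steps up to base_lr.
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))

total = 100
schedule = [lr_at(s, total) for s in range(total)]
for s in (0, 7, 14, 50, 99):
    print(f"step {s:3d}: lr = {schedule[s]:.2e}")
```

During the first 15% of steps the adapters receive only small updates, which is the window in which an un-warmed run is most likely to diverge.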
Save, reload, and predict
Trained classifiers save the encoder, LoRA adapters, and head together. Reloading reconstructs the architecture (re-applying the LoRA structure) and produces identical predictions.
# after trainer.train() ...
clf.save_pretrained("my_classifier")
# Reload in a fresh process and run inference.
clf = FastJEPAModel.load_classifier("my_classifier") # video: FastVideoJEPAModel.load_classifier
preds = clf.predict(new_images) # class ids (N,)
probs = clf.predict(new_images, return_probs=True) # softmax probs (N, num_classes)
The full train → save → load → predict path is covered by the test suite for frozen, LoRA, and full modes — reloaded predictions match the trained model exactly.
LLM-JEPA — the JEPA objective for LLM fine-tuning
Every track above is vision. LLM-JEPA (arXiv 2509.14252, Huang, LeCun & Balestriero) brings the same Joint-Embedding Predictive idea to language-model fine-tuning. It is a training objective, not a model: it augments standard next-token prediction (NTP) with a JEPA term that aligns the embeddings of two “views” of the same item — e.g. a natural-language description (text) and its regex / SQL / code (code). It is implemented here for the first time on MLX.
L = L_NTP + λ · d(Pred(Enc(text)), Enc(code))
- Enc(x): the last non-pad token’s final-layer hidden state for view x.
- Pred: an optional predictor formed by appending num_predictors learnable [PRED] slots (default 0 ⇒ identity).
- d: cosine dissimilarity (default), or l2 / mse / infonce.
The fine-tuned artifact is a normal LoRA-fine-tuned LLM — save / merge / generate work exactly as with SFT. The JEPA term only shapes the representation during training (the paper reports accuracy gains and resistance to overfitting). The dataset is a list of dicts with the two views under text_field / code_field (default "text" / "code"), falling back to prompt / completion.
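The objective can be written out numerically for a single example. The sketch below assumes the identity predictor (num_predictors=0), takes "cosine dissimilarity" to mean 1 minus cosine similarity, and models jepa_ratio as a per-step keep probability; these readings and the function names are assumptions, and the embeddings are toy vectors rather than real hidden states.

```python
import math
import random

def cosine_dissimilarity(u, v):
    """1 - cosine similarity (assumed definition of the default distance)."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (nu * nv)

def llm_jepa_loss(ntp_loss, text_emb, code_emb,
                  jepa_lambda=0.1, jepa_ratio=-1.0, rng=random):
    """Combine next-token loss with the JEPA alignment term.

    jepa_ratio < 0 keeps the JEPA term on every step; otherwise the term
    is kept with probability jepa_ratio (loss dropout).
    """
    keep = jepa_ratio < 0 or rng.random() < jepa_ratio
    if not keep:
        return ntp_loss
    return ntp_loss + jepa_lambda * cosine_dissimilarity(text_emb, code_emb)

aligned = llm_jepa_loss(2.0, [1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
orthogonal = llm_jepa_loss(2.0, [1.0, 0.0], [0.0, 1.0])
baseline = llm_jepa_loss(2.0, [1.0, 0.0], [0.0, 1.0], jepa_lambda=0.0)

print(f"aligned views   : {aligned:.3f}")
print(f"orthogonal views: {orthogonal:.3f}")
print(f"lambda = 0      : {baseline:.3f}")
```

When the two views already share a direction the extra term vanishes, and with jepa_lambda=0 the objective reduces to plain next-token prediction.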
LLMJEPATrainer / LLMJEPAConfig
| Config parameter | Default | Description |
|---|---|---|
| jepa_lambda | 0.1 | Weight on the JEPA term |
| jepa_distance | "cosine" | cosine / l2 / mse / infonce |
| num_predictors | 0 | k learnable [PRED] slots (0 ⇒ predictor is identity) |
| jepa_ratio | -1.0 | JEPA-loss dropout: probability of keeping the term per step (-1 = always on) |
| ntp_on | "all" | NTP over the 3-way concat [combined, text, code], or "combined" only |
| response_only | False | Mask the text prefix of the combined sequence during NTP |
from mlx_tune import FastLanguageModel, LLMJEPATrainer, LLMJEPAConfig
model, tokenizer = FastLanguageModel.from_pretrained("mlx-community/Qwen3.5-0.8B-MLX-4bit")
model = FastLanguageModel.get_peft_model(model, r=16, lora_alpha=16)
# Two "views" of the same item: a description and its regex
data = [
    {"text": "match one or more digits", "code": r"\d+"},
    {"text": "match a hex color code", "code": r"#[0-9a-fA-F]{6}"},
    # ...
]
trainer = LLMJEPATrainer(
    model, data, tokenizer=tokenizer,
    args=LLMJEPAConfig(jepa_lambda=0.1, jepa_distance="cosine", max_steps=40),
)
trainer.train()
# The result is a normal LoRA model — generate / merge as usual:
print(model.generate(prompt="match one or more digits", max_tokens=16))
# model.save_pretrained_merged("./llm_jepa_merged")
For real benchmarks use paired-view datasets like NL-RX (regex), Spider (SQL) or GSM8K (math) — the datasets the paper evaluates. Set jepa_lambda=0 to train a pure-NTP baseline for an A/B comparison.
Dense & regression heads (counting, depth, segmentation)
I-JEPA's own paper-headline tasks are depth prediction and object counting — i.e. "more than classify". Three non-classification heads attach to any JEPA encoder and reuse the same frozen / LoRA / full fine-tuning machinery:
| Builder | Head | Loss / metric |
|---|---|---|
| for_image_regression(model, out_dim) | JEPAForImageRegression: scalar/vector (counting) | MSE / MAE, RMSE |
| for_dense_prediction(model, C, task="regression") | JEPAForDensePrediction: per-pixel value map (depth) | MSE / MAE |
| for_dense_prediction(model, C, task="segmentation") | JEPAForDensePrediction: per-pixel class logits | per-pixel CE / pixel-acc |
from mlx_tune import (FastJEPAModel, JEPAClassifierConfig,
                      JEPARegressionTrainer, JEPADenseTrainer)
model, _ = FastJEPAModel.from_pretrained("facebook/ijepa_vith14_1k")
# object counting (scalar regression)
reg = FastJEPAModel.for_image_regression(model, out_dim=1, finetune="lora")
JEPARegressionTrainer(reg, JEPAClassifierConfig(img_size=224),
                      images, counts).train()
print(reg.evaluate()) # {"mae": ..., "rmse": ...}
# depth map (dense regression) — targets are per-pixel maps at img_size
dense = FastJEPAModel.for_dense_prediction(model, out_channels=1, task="regression")
# segmentation — out_channels = num classes, integer mask targets
seg = FastJEPAModel.for_dense_prediction(model, out_channels=21, task="segmentation")
Dense heads predict at the patch grid and bilinearly upsample to img_size.
Reload with FastJEPAModel.load_regressor(dir) / load_dense(dir).
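As an illustration of the patch-grid-then-upsample step, here is a plain-Python bilinear upsampler for a single-channel map. The half-pixel sampling convention and edge clamping are assumptions (the library may align corners differently), and bilinear_upsample is a hypothetical helper, not part of mlx-tune.

```python
def bilinear_upsample(grid, out_h, out_w):
    """Bilinearly resize a 2-D list of floats to (out_h, out_w).

    Uses half-pixel centres with clamping at the borders (assumed
    convention).
    """
    in_h, in_w = len(grid), len(grid[0])
    out = []
    for y in range(out_h):
        # Map the output pixel centre back into input coordinates.
        sy = min(max((y + 0.5) * in_h / out_h - 0.5, 0.0), in_h - 1.0)
        y0 = int(sy)
        y1 = min(y0 + 1, in_h - 1)
        wy = sy - y0
        row = []
        for x in range(out_w):
            sx = min(max((x + 0.5) * in_w / out_w - 0.5, 0.0), in_w - 1.0)
            x0 = int(sx)
            x1 = min(x0 + 1, in_w - 1)
            wx = sx - x0
            top = grid[y0][x0] * (1 - wx) + grid[y0][x1] * wx
            bottom = grid[y1][x0] * (1 - wx) + grid[y1][x1] * wx
            row.append(top * (1 - wy) + bottom * wy)
        out.append(row)
    return out

# A 2x2 "patch grid" of predictions upsampled to 4x4 pixels.
patch_map = [[0.0, 1.0],
             [2.0, 3.0]]
pixel_map = bilinear_upsample(patch_map, 4, 4)
for row in pixel_map:
    print(row)
```

With a 14-pixel patch and 224×224 inputs the head would predict on a 16×16 grid, so each output pixel is interpolated from its nearest patch predictions in the same way.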
Native resolution & training over a real corpus
Non-224 inputs. Pretrained I-JEPA position embeddings were tied to 224²; now you can load at any (patch-divisible) size — the learned grid is bicubically interpolated at load (and on-the-fly for off-grid inputs). The native-resolution path is bit-identical.
# run a pretrained encoder at 384x384
model, _ = FastJEPAModel.from_pretrained("facebook/ijepa_vith14_1k", img_size=384)
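A quick way to check which sizes sit exactly on the patch grid is sketched below. It assumes that "patch-divisible" means a multiple of the patch size (14 for ViT-H/14) and uses hypothetical helper names; it says nothing about how the loader treats sizes that fall between grid points.

```python
def patch_grid(img_size, patch_size=14):
    """Side length of the token grid for a patch-divisible image size."""
    if img_size % patch_size != 0:
        raise ValueError(
            f"img_size={img_size} is not a multiple of patch_size={patch_size}"
        )
    return img_size // patch_size

def nearest_valid_sizes(img_size, patch_size=14):
    """Closest patch-divisible sizes at or below and at or above img_size."""
    lower = (img_size // patch_size) * patch_size
    upper = lower if lower == img_size else lower + patch_size
    return lower, upper

print(patch_grid(224))            # native grid side
print(nearest_valid_sizes(384))   # sizes bracketing 384
print(patch_grid(392))            # grid side after interpolation
```

Under this reading, 384 lies between two grid-aligned sizes for a 14-pixel patch, so it would go through whatever off-grid handling the loader applies, while a grid-aligned size only needs the interpolated position embeddings.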
Streaming + resumable checkpoints. ImageFolderDataset decodes images lazily from a folder (recursively) on the prefetch thread, so the corpus never has to fit in memory. Set save_steps + resume to checkpoint the encoder + optimizer state and continue an interrupted run.
from mlx_tune import FastJEPAModel, JEPAConfig, JEPATrainer, ImageFolderDataset
model, _ = FastJEPAModel.from_pretrained("vit-base", img_size=224)
ds = ImageFolderDataset("/path/to/images") # lazy, recursive
cfg = JEPAConfig(img_size=224, batch_size=64, num_epochs=10,
                 save_steps=500, resume=True) # resumes from /checkpoint
JEPATrainer(model, cfg, ds).train()
LeWM — a latent world model you can train (and plan with)
LeWM (LeWorldModel, arXiv 2603.19312, Maes/Le Lidec/Scieur/LeCun/Balestriero) trains a latent world model end-to-end from pixels with only two loss terms and one tunable hyperparameter:
L = next-embedding prediction + λ · SIGReg(embeddings)
No stop-gradient, no EMA target — SIGReg (the same regularizer as LeJEPA) is what stops the encoder collapsing. The trained model is a small latent world model you can plan with via CEM / MPC over latent rollouts.
This is the trainable, on-device world model: the counterpart to V-JEPA 2-AC's load-and-plan workflow, which currently exists only as a design writeup for a future release (see jepa.md §9).
from mlx_tune import FastWorldModel, LeWMConfig, LeWMTrainer, plan_cem, PointMassEnv
env = PointMassEnv(size=48) # toy 2-D control env
data = env.collect(n_episodes=80, ep_len=10) # {"frames", "actions"} trajectories
model = FastWorldModel.from_pretrained("lewm-tiny", img_size=48, action_dim=2)
LeWMTrainer(model, LeWMConfig(img_size=48, action_dim=2, sigreg_lambda=0.05,
                              max_steps=250), data).train()
# plan: pick actions whose predicted latent matches a goal image's latent (MPC)
goal_z = model.encode([env.render([0.8, 0.2])])[0]
z = model.encode([env.render()])[0]
action = plan_cem(model, z, goal_z, horizon=3, action_dim=2)
The planner is exact (cost→0 on known dynamics) and training reduces the prediction loss; convincing closed-loop control needs the paper's training budget / real control datasets (pass them in the {"frames","actions"} format). This is a demo of the pipeline, not a SOTA-control claim.
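The "cost→0 on known dynamics" behaviour can be reproduced with a few lines of cross-entropy-method planning. This sketch swaps the learned latent model for exact point-mass dynamics (each action is a 2-D displacement), so it illustrates only the planner; plan_cem_sketch, its hyperparameters and the toy dynamics are assumptions, not the library's plan_cem.

```python
import random

def rollout(state, actions):
    """Known toy dynamics: each action displaces the 2-D position."""
    x, y = state
    for ax, ay in actions:
        x, y = x + ax, y + ay
    return (x, y)

def goal_cost(state, goal):
    """Squared distance between the final state and the goal."""
    return (state[0] - goal[0]) ** 2 + (state[1] - goal[1]) ** 2

def unflatten(flat):
    """Turn [ax0, ay0, ax1, ay1, ...] into a list of (ax, ay) actions."""
    return [(flat[i], flat[i + 1]) for i in range(0, len(flat), 2)]

def plan_cem_sketch(start, goal, horizon=3, n_samples=64, n_elite=8,
                    n_iters=30, seed=0):
    """Cross-entropy method over action sequences.

    Each iteration samples sequences from a diagonal Gaussian, keeps the
    lowest-cost elites, and refits the Gaussian to them.
    """
    rng = random.Random(seed)
    dims = horizon * 2
    mean, std = [0.0] * dims, [0.5] * dims
    for _ in range(n_iters):
        samples = [[rng.gauss(m, s) for m, s in zip(mean, std)]
                   for _ in range(n_samples)]
        samples.sort(key=lambda f: goal_cost(rollout(start, unflatten(f)), goal))
        elite = samples[:n_elite]
        mean = [sum(e[d] for e in elite) / n_elite for d in range(dims)]
        # Floor the spread so sampling never degenerates completely.
        std = [max(1e-3, (sum((e[d] - mean[d]) ** 2 for e in elite)
                          / n_elite) ** 0.5) for d in range(dims)]
    return unflatten(mean)

start, goal = (0.2, 0.5), (0.8, 0.2)
plan = plan_cem_sketch(start, goal)
final_cost = goal_cost(rollout(start, plan), goal)
print(f"first action: {plan[0]}")
print(f"final cost  : {final_cost:.2e}")
```

In MPC style only the first planned action would be executed before replanning from the new state; with a learned model the rollout would run in latent space and inherit the model's prediction error.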
Runnable examples
| File | What it shows |
|---|---|
| examples/58_lejepa_pretraining.py | LeJEPA from-scratch pretraining on a CIFAR-10 subset (synthetic fallback) → linear probe |
| examples/59_ijepa_feature_extraction.py | Pretrained I-JEPA: linear probe + LoRA fine-tune + save/load/predict |
| examples/60_vjepa2_video.py | Pretrained V-JEPA 2: video linear probe + LoRA fine-tune + save/load/predict |
| examples/61_llm_jepa_finetuning.py | LLM-JEPA: fine-tune an LLM with the JEPA objective (NL→regex views, Qwen3.5-0.8B) |
| examples/62_lewm_world_model.py | LeWM: train a latent world model from pixels + CEM/MPC planning (toy point-mass) |
| examples/63_jepa_dense_regression.py | Dense & regression heads: counting, depth maps, segmentation |
| examples/64_vjepa2_predictor_classifier.py | V-JEPA 2 predictor (anticipation + surprise energy) + Meta's pretrained SSv2 classifier |
python examples/58_lejepa_pretraining.py
python examples/59_ijepa_feature_extraction.py
python examples/60_vjepa2_video.py
python examples/61_llm_jepa_finetuning.py
python examples/62_lewm_world_model.py
python examples/63_jepa_dense_regression.py
python examples/64_vjepa2_predictor_classifier.py
Tips & gotchas
The pretrained I-JEPA position embeddings were trained on a 224×224 grid, so 224 is the default and the loader resizes to it for you. To run at another resolution, pass img_size= at load: the grid is bicubically interpolated and the native-224 path stays bit-identical. See Native resolution & training over a real corpus.
The 3D RoPE grid is tied to crop_size=256. Frames are resized to 256×256; the number of frames T must be a multiple of the tubelet size (2). The predictor is ported too (loads by default) — use predict_latents for masked latent prediction; only the action-conditioned V-JEPA 2-AC variant is out of scope.
I-JEPA LoRA fine-tune (batch 6, 224×224) peaks ~13 GB; V-JEPA 2 LoRA (8-frame clip) peaks ~7.6 GB. Frozen-feature probing is far lighter since there's no backprop through the encoder.
Linear-probe / fine-tune accuracy on small datasets on a Mac is a capability demonstration, not a SOTA claim. For serious results, prototype here and scale the same code on a larger machine.
Attention uses MLX's fused scaled_dot_product_attention; trainers configure the wired-memory limit and overlap image/video preprocessing with GPU compute on a background thread — so the first trial is fast.