JEPA on Apple Silicon

JEPA (Joint-Embedding Predictive Architecture) is the self-supervised vision paradigm that Yann LeCun has championed as an alternative to generative and contrastive learning, notably in his position paper A Path Towards Autonomous Machine Intelligence. mlx-tune brings the whole JEPA family to your Mac. You can train LeJEPA from scratch, or load pretrained I-JEPA (images) and V-JEPA 2 (video) and fine-tune them with LoRA. Everything runs natively on MLX, with no CUDA required.

First trainable JEPA on MLX

To our knowledge, this is the first tested JEPA implementation on Apple Silicon that supports both training and fine-tuning. Both pretrained ports (I-JEPA, V-JEPA 2) match the official HuggingFace PyTorch models (cosine similarity 1.000000), as verified in the test suite.

Three ways to use JEPA

mlx-tune exposes JEPA through an Unsloth-flavoured API. Pick the track that matches your data and goal.

1. LeJEPA — pretrain from scratch (images)

The newest, leanest JEPA (Balestriero & LeCun, 2025). A single Vision Transformer trained with a multi-view prediction loss plus SIGReg (a distribution regulariser). No EMA teacher, no predictor, no stop-gradient — which is exactly why it ports cleanly to MLX. Use it to learn representations on your own image set without labels.

2. I-JEPA — pretrained image encoder (Meta)

Load Meta's pretrained I-JEPA (ViT-Huge, self-supervised on ImageNet) and use it for feature extraction, linear probing, or LoRA / full fine-tuning on your downstream classification task.

3. V-JEPA 2 — pretrained video world model (Meta)

Load Meta's V-JEPA 2 (trained on >1M hours of video): fine-tune for video classification, extract clip features, run the masked-latent predictor (anticipate a clip's future, score surprise), or use Meta's fine-tuned SSv2 action classifiers with zero training — all ported faithfully to MLX.

What is JEPA (and is it an LLM)?

JEPA is a vision/video architecture, not a language model. Every JEPA here is a Vision Transformer that operates on image patches or video tubelets. The "JEPA" idea is how it's trained: take two views of an input, encode both, and train the model so one view's representation predicts the other's — in latent space, never reconstructing raw pixels. This captures predictable, abstract structure (semantics, dynamics) and discards unpredictable detail.

The enemy of every JEPA is representation collapse (the encoder outputs a constant). Different variants fight it differently:

| Variant | Domain | Anti-collapse mechanism |
| --- | --- | --- |
| I-JEPA (2023) | Images | EMA target encoder (teacher) + predictor; masked target blocks |
| V-JEPA 2 (2025) | Video | Same recipe with spatiotemporal masks; world-model framing |
| LeJEPA (2025) | Any | SIGReg: provably regularises embeddings toward an isotropic Gaussian. No teacher, predictor or stop-grad |

Why does LoRA apply here?

LoRA is not LLM-specific — it works on any model built from large linear layers, and a transformer is mostly linear layers (q/k/v/out projections + MLP), whether its inputs are words or image patches. So you can LoRA-fine-tune a 630M-parameter pretrained JEPA encoder on a Mac, training <1% of the weights.
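Here is a minimal sketch of why LoRA adds so few trainable weights. It is written in plain Python rather than mlx-tune's MLX code. The frozen weight W is kept, and a low-rank update B·A (rank r, scaled by alpha / r) is added on top. Because B starts at zeros, the adapted layer is initially identical to the base layer. The helper names are hypothetical. The parameter count assumes LoRA on the q/k/v/out projections and both MLP layers of a ViT-H block (width 1280, MLP width 5120), and it ignores biases.

```python
def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def lora_linear(W, A, B, x, alpha, r):
    """y = W x + (alpha / r) * B (A x): frozen base output plus a low-rank update."""
    base = matvec(W, x)                 # frozen path
    update = matvec(B, matvec(A, x))    # A: (r, d_in), B: (d_out, r), both trainable
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


def lora_param_fraction(layer_shapes, r):
    """Trainable LoRA weights divided by frozen weights for a list of (d_in, d_out) layers."""
    frozen = sum(d_in * d_out for d_in, d_out in layer_shapes)
    lora = sum(r * (d_in + d_out) for d_in, d_out in layer_shapes)
    return lora / frozen


# Assumed ViT-H block layout: q, k, v, out (1280 -> 1280) and an MLP 1280 -> 5120 -> 1280.
vit_h_block = [(1280, 1280)] * 4 + [(1280, 5120), (5120, 1280)]
print(f"trainable fraction at r=8: {lora_param_fraction(vit_h_block, 8):.4%}")
```

Under those assumptions, r=8 gives 184,320 adapter weights against 19,660,800 frozen ones per block, about 0.94%. That is consistent with the "<1% of the weights" figure. The real ratio depends on which layers mlx-tune wraps.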

Supported models

LeJEPA presets (random init — train from scratch)

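| Preset | Dim | Depth | Heads | Use |
| --- | --- | --- | --- | --- |
| vit-debug | 64 | 2 | 2 | tests / quick checks |
| vit-tiny | 192 | 12 | 3 | small datasets, fast |
| vit-small | 384 | 12 | 6 | balanced |
| vit-base | 768 | 12 | 12 | larger datasets |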
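As a rough size check for the presets, assume a standard ViT block with an MLP ratio of 4. This is an assumption, since the table does not list the ratio. Under it, each block holds about 12·dim² weights. The helper below is illustrative. It ignores biases, norms, and the patch and position embeddings.

```python
def approx_vit_params(dim, depth, mlp_ratio=4):
    """Per block: q/k/v/out projections (4 * dim^2) plus a 2-layer MLP (2 * mlp_ratio * dim^2)."""
    per_block = 4 * dim * dim + 2 * mlp_ratio * dim * dim
    return depth * per_block


# (dim, depth, heads), copied from the preset table above.
PRESETS = {
    "vit-debug": (64, 2, 2),
    "vit-tiny": (192, 12, 3),
    "vit-small": (384, 12, 6),
    "vit-base": (768, 12, 12),
}

for name, (dim, depth, heads) in PRESETS.items():
    assert dim % heads == 0, "embedding dim must split evenly across attention heads"
    print(f"{name:10s} head_dim={dim // heads:3d}  ~{approx_vit_params(dim, depth) / 1e6:.1f}M params")
```

This gives roughly 5.3M (vit-tiny), 21.2M (vit-small) and 84.9M (vit-base) parameters. Every preset uses a head dimension of 64 except vit-debug, which uses 32. The same formula with ViT-H's width 1280 and depth 32 gives about 629M, in line with the 630M quoted for I-JEPA ViT-H/14.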

Pretrained checkpoints (downloaded & converted on first use)

| Model | HF repo | Params | Loader |
| --- | --- | --- | --- |
| I-JEPA ViT-H/14 | facebook/ijepa_vith14_1k | 630M | FastJEPAModel |
| I-JEPA ViT-g/16 | facebook/ijepa_vitg16_22k | 1.0B | FastJEPAModel |
| V-JEPA 2 ViT-L | facebook/vjepa2-vitl-fpc64-256 | 326M | FastVideoJEPAModel |
| V-JEPA 2 ViT-H/g | facebook/vjepa2-vit{h,g}-fpc64-* | 0.6–1B | FastVideoJEPAModel |

No PyTorch needed at load time

Weights are downloaded as safetensors and converted directly to MLX. torch is only needed if you want to run the HuggingFace parity tests.

LeJEPA — self-supervised pretraining

Train a ViT on your own images, no labels required. The loss has a single hyperparameter lam (the SIGReg weight, ~0.05).

from mlx_tune import FastJEPAModel, JEPATrainer, JEPAConfig, linear_probe

# 1. Build a randomly-initialised ViT (LeJEPA trains from scratch).
model, _ = FastJEPAModel.from_pretrained("vit-tiny", img_size=128)

# 2. Self-supervised pretraining. `images` is your own unlabeled data: a list of HWC arrays / PIL images.
trainer = JEPATrainer(
    model,
    args=JEPAConfig(num_epochs=5, batch_size=64, lam=0.05),
    train_dataset=images,
)
trainer.train()

# 3. Use the learned encoder: extract features, or evaluate with a linear probe.
#    tr_x / te_x are lists of labelled images and tr_y / te_y their class labels (your own data).
feats = model.encode(images)                       # (N, dim) frozen features
acc = linear_probe(model, tr_x, tr_y, te_x, te_y)  # quick representation-quality check

# 4. Save / reload the encoder.
model.save_pretrained("my_lejepa")
model, _ = FastJEPAModel.from_pretrained("my_lejepa")

How SIGReg prevents collapse

LeJEPA proves that an isotropic Gaussian is the optimal embedding distribution for minimising downstream prediction risk. SIGReg projects the embeddings onto many random 1-D directions. On each direction, it tests whether the projection looks like a standard normal, using the Epps–Pulley characteristic-function test. The total loss is (1−lam)·prediction + lam·SIGReg. The training loss correlates with downstream accuracy, which makes it useful for model selection without labels.
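Here is a minimal pure-Python sketch of the SIGReg idea. It works in four steps:

- Project the embeddings onto random unit directions.
- Compare each projection's empirical characteristic function with the standard normal's exp(-t²/2).
- Weight the squared gap by exp(-t²/2).
- Integrate with the trapezoidal rule.

The integration grid (17 points on [-5, 5]), the slice count and all function names are assumptions for illustration, not mlx-tune's code. The sketch also only evaluates the statistic, whereas training backpropagates through it.

```python
import math
import random


def random_unit_directions(dim, num_slices, seed=0):
    """Sample num_slices random directions on the unit sphere in R^dim."""
    rng = random.Random(seed)
    dirs = []
    for _ in range(num_slices):
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(c * c for c in v))
        dirs.append([c / norm for c in v])
    return dirs


def epps_pulley(samples, t_max=5.0, n_points=17):
    """n * integral of |ECF(t) - exp(-t^2/2)|^2 * exp(-t^2/2) dt over [-t_max, t_max]."""
    n = len(samples)
    ts = [-t_max + 2.0 * t_max * k / (n_points - 1) for k in range(n_points)]
    errs = []
    for t in ts:
        target = math.exp(-0.5 * t * t)                  # CF of N(0, 1), reused as the weight
        re = sum(math.cos(t * x) for x in samples) / n   # real part of the empirical CF
        im = sum(math.sin(t * x) for x in samples) / n   # imaginary part of the empirical CF
        errs.append(((re - target) ** 2 + im ** 2) * target)
    h = ts[1] - ts[0]
    integral = h * (sum(errs) - 0.5 * (errs[0] + errs[-1]))  # trapezoidal rule
    return n * integral


def sigreg(embeddings, num_slices=64, seed=0):
    """Average the 1-D normality statistic over random projections of (N, D) embeddings."""
    dim = len(embeddings[0])
    stats = []
    for a in random_unit_directions(dim, num_slices, seed):
        proj = [sum(e_i * a_i for e_i, a_i in zip(e, a)) for e in embeddings]
        stats.append(epps_pulley(proj))
    return sum(stats) / len(stats)


def lejepa_loss(prediction_loss, sigreg_loss, lam=0.05):
    """Total loss as stated above: (1 - lam) * prediction + lam * SIGReg."""
    return (1.0 - lam) * prediction_loss + lam * sigreg_loss
```

On isotropic Gaussian embeddings the statistic stays small. On collapsed embeddings, where every row is identical, it grows with the batch size. That growth is the signal SIGReg uses to push the encoder away from collapse.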

FastJEPAModel

mlx_tune.jepa

FastJEPAModel.from_pretrained()

FastJEPAModel.from_pretrained(model_name, img_size=None, patch_size=16) → Tuple[JEPAModelWrapper, None]

One entry point, three behaviours, dispatched on model_name:

| If model_name is… | Then |
| --- | --- |
| a preset ("vit-tiny"…) | random-init ViT for LeJEPA from-scratch training |
| a saved directory | reload a previously saved JEPA encoder |
| a HuggingFace repo ("facebook/ijepa_*") | download + convert pretrained I-JEPA to MLX |
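The dispatch in the table can be sketched as follows. The helper, its check order (presets, then local directories, then repo ids) and the preset set are assumptions for illustration, not mlx-tune's actual logic:

```python
import os

# Preset names from the LeJEPA preset table.
PRESET_NAMES = {"vit-debug", "vit-tiny", "vit-small", "vit-base"}


def resolve_source(model_name, is_dir=os.path.isdir):
    """Classify model_name into one of the three from_pretrained behaviours (hypothetical helper)."""
    if model_name in PRESET_NAMES:
        return "preset"    # build a randomly initialised ViT
    if is_dir(model_name):
        return "saved"     # reload a previously saved encoder
    if "/" in model_name:
        return "hf_repo"   # e.g. "facebook/ijepa_vith14_1k": download and convert
    raise ValueError(f"unrecognised model_name: {model_name!r}")
```

Checking local directories before repo ids means a saved folder always wins over a same-named Hub repo. The real loader may resolve that ambiguity differently.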

model.encode()

model.encode(images, batch_size=64) → mx.array

Run the frozen encoder over a list of images and return pooled features (N, dim) for downstream tasks (clustering, retrieval, probing).
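The batching-and-pooling pattern behind a call like model.encode can be sketched in plain Python. This sketch uses mean pooling, which is how the I-JEPA features are described below. Here `token_fn` is a hypothetical stand-in for the frozen encoder that maps a batch of images to per-image (T, D) token lists:

```python
def batched(items, batch_size):
    """Yield consecutive chunks of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def mean_pool(tokens):
    """(T, D) token features -> (D,) feature by averaging over the T tokens."""
    num_tokens = len(tokens)
    return [sum(column) / num_tokens for column in zip(*tokens)]


def encode(images, token_fn, batch_size=64):
    """Run a token encoder batch by batch and return one pooled feature per image."""
    feats = []
    for batch in batched(images, batch_size):
        for tokens in token_fn(batch):   # token_fn: list of images -> list of (T, D) token lists
            feats.append(mean_pool(tokens))
    return feats                          # (N, D)
```

Batching bounds peak memory regardless of how many images are passed in. Pooling per image keeps the output a fixed (N, D) shape for clustering, retrieval or probing.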

JEPATrainer & JEPAConfig

mlx_tune.jepa

Key JEPAConfig fields

| Field | Default | Description |
| --- | --- | --- |
| lam | 0.05 | SIGReg weight, the single LeJEPA hyperparameter |
| n_global / n_local | 2 / 6 | Number of global / local multi-crop views per image |
| num_slices | 1024 | Number of random projection directions for SIGReg |
| img_size | 128 | Pixel resolution at which all views are rendered |
| learning_rate | 5e-4 | AdamW learning rate (linear warmup, then cosine decay) |
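Here is a sketch of the schedule named in the learning_rate row: linear warmup, then cosine decay. The warmup length is not documented for JEPAConfig, so `warmup_ratio` here is an assumed parameter, as is `min_lr`. `lr_at` is a hypothetical helper:

```python
import math


def lr_at(step, total_steps, peak_lr=5e-4, warmup_ratio=0.1, min_lr=0.0):
    """Learning rate at `step`: linear warmup to peak_lr, then cosine decay to min_lr."""
    warmup_steps = max(1, round(warmup_ratio * total_steps))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps          # ramp up linearly
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)                            # hold at min_lr past the end
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

The schedule peaks exactly at the end of warmup and reaches `min_lr` at the final step. The warmup phase is what keeps the first updates small.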

I-JEPA — pretrained image encoder

Load Meta's pretrained I-JEPA and put its features to work. Inputs default to 224×224, the resolution the model was trained at. Other sizes also work via position-embedding interpolation (see Native resolution & training over a real corpus).

from mlx_tune import FastJEPAModel, linear_probe, knn_probe, attentive_probe

# Downloads ~2.5 GB on first run, then converts HF -> MLX (no torch needed).
model, _ = FastJEPAModel.from_pretrained("facebook/ijepa_vith14_1k")

# Three frozen-feature probes — none of them backprop through the 630M encoder:
linear    = linear_probe(model, tr_x, tr_y, te_x, te_y)     # logistic regression on mean-pooled features
knn       = knn_probe(model, tr_x, tr_y, te_x, te_y, k=20)  # label-light, no training
attentive = attentive_probe(model, tr_x, tr_y, te_x, te_y)  # attention-pooling head — the paper's eval

# Or just grab features for retrieval / clustering.
feats   = model.encode(my_images)          # (N, 1280)  mean-pooled
tokens  = model.encode_tokens(my_images)   # (N, T, 1280)  per-patch tokens

Use the attentive probe for the real numbers

I-JEPA and V-JEPA 2 encoders don't mean-pool cleanly, so a plain linear_probe on pooled features under-reads their quality. attentive_probe trains a small attention-pooling head over the token features. This is the canonical evaluation in both papers, and it is also V-JEPA 2's own classification head. It is the most faithful measure of what these encoders know. knn_probe is a cheap, training-free, label-light sanity check. (Note: this concerns readout strength only. The encoder features themselves match the reference model regardless of which probe you use.)
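To see why an attentive probe can read out more than mean pooling, here is a minimal single-head, single-query sketch. A learnable query scores each token, and the softmax-weighted sum of tokens replaces the plain mean. With an all-zero query every token gets the same weight, so mean pooling is a special case the probe can learn its way out of. The real probes add learned key/value projections and more heads, which this sketch omits. The function names are illustrative.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attentive_pool(tokens, query):
    """A learnable query attends over (T, D) tokens and returns a (D,) pooled feature."""
    d = len(query)
    scores = [sum(q * x for q, x in zip(query, tok)) / math.sqrt(d) for tok in tokens]
    weights = softmax(scores)
    return [sum(w * tok[j] for w, tok in zip(weights, tokens)) for j in range(d)]


def linear_head(feature, weight, bias):
    """Class logits = weight @ feature + bias (weight holds one row per class)."""
    return [sum(w * f for w, f in zip(row, feature)) + b for row, b in zip(weight, bias)]
```

During probing only the query, the projections and the head are trained. The encoder's token features stay frozen.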

Domain-adaptive continued pretraining (warm-start LeJEPA)

Because LeJEPA and the pretrained encoders are all ViTs, you can continue self-supervised pretraining from a pretrained checkpoint — adapt Meta's ImageNet-trained I-JEPA to your own domain with the LeJEPA objective, no labels. Just pass the loaded model straight to JEPATrainer:

from mlx_tune import FastJEPAModel, JEPATrainer, JEPAConfig

# Start from pretrained I-JEPA instead of random init...
model, _ = FastJEPAModel.from_pretrained("facebook/ijepa_vith14_1k")

# ...and keep training it on your unlabeled domain images with LeJEPA / SIGReg.
JEPATrainer(model, args=JEPAConfig(img_size=224, num_epochs=3, learning_rate=1e-4),
            train_dataset=my_domain_images).train()

model.save_pretrained("ijepa_my_domain")   # domain-adapted encoder

Lower the LR when warm-starting

You're refining strong features, not learning from scratch — use a smaller learning rate (e.g. 1e-4) than a from-scratch LeJEPA run so you don't wash out the pretrained representation.

V-JEPA 2 — pretrained video world model

Load Meta's V-JEPA 2 video encoder. Clips are (T, H, W, 3) frame arrays. Frames are resized to 256×256 automatically, and the number of frames T must be a multiple of the tubelet size (2).

from mlx_tune import (
    FastVideoJEPAModel, video_linear_probe, video_knn_probe, video_attentive_probe,
)

model, _ = FastVideoJEPAModel.from_pretrained("facebook/vjepa2-vitl-fpc64-256")

# Frozen clip features, or any of the three video probes.
clip_feats = model.encode(my_clips)                            # (N, 1024)
linear     = video_linear_probe(model, tr_x, tr_y, te_x, te_y)
knn        = video_knn_probe(model, tr_x, tr_y, te_x, te_y, k=20)
attentive  = video_attentive_probe(model, tr_x, tr_y, te_x, te_y)  # paper's eval
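The clip-shape constraints above determine how many tokens the encoder sees. The sketch below assumes 16×16 spatial patches and 2-frame tubelets at a 256×256 crop. The patch size is an assumption, since this page only states the tubelet size and crop. Both helpers are hypothetical. Trimming trailing frames is just one possible fix-up; mlx-tune may reject or handle such clips differently.

```python
def tubelet_tokens(num_frames, crop_size=256, patch_size=16, tubelet_size=2):
    """Number of spatiotemporal tokens for one clip."""
    if num_frames % tubelet_size:
        raise ValueError(f"num_frames={num_frames} is not a multiple of tubelet_size={tubelet_size}")
    grid = crop_size // patch_size                 # patches per side
    return (num_frames // tubelet_size) * grid * grid


def trim_to_tubelets(frames, tubelet_size=2):
    """Drop trailing frames so the frame count is a multiple of tubelet_size."""
    usable = len(frames) - len(frames) % tubelet_size
    return frames[:usable]
```

Under these assumptions a 16-frame clip becomes 8 × 16 × 16 = 2048 tokens. Token count grows linearly with clip length, which is why short clips keep video fine-tuning memory manageable.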

Faithful 3D RoPE port

V-JEPA 2 uses a 3D tubelet Conv3d patch embed plus rotary position embeddings split across the frame / height / width axes. The MLX port replicates the reference rotation exactly, so encoder outputs match the HuggingFace model (cosine similarity 1.000000).
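Here is a minimal sketch of the 3D RoPE idea. It splits a query/key vector into three slices and rotates each slice with ordinary 1-D RoPE, driven by the frame, row and column index respectively. The layout is an assumption: interleaved (even, odd) pairs, equal even-sized slices, leftover dimensions passed through unrotated, and base 10000. V-JEPA 2's exact pairing convention and slice sizes may differ.

```python
import math


def rope_1d(chunk, pos, base=10000.0):
    """Rotate consecutive (even, odd) pairs of `chunk` by angle pos * base**(-i / len(chunk))."""
    d = len(chunk)
    out = list(chunk)
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        x0, x1 = chunk[i], chunk[i + 1]
        out[i] = x0 * c - x1 * s
        out[i + 1] = x0 * s + x1 * c
    return out


def rope_3d(vec, t, h, w):
    """Apply 1-D RoPE with the frame, height and width index to three equal slices of vec."""
    axis_dim = 2 * ((len(vec) // 3) // 2)   # even-sized slice per axis (assumed layout)
    out = []
    for k, pos in enumerate((t, h, w)):
        out += rope_1d(vec[k * axis_dim:(k + 1) * axis_dim], pos)
    out += vec[3 * axis_dim:]                # leftover dims pass through unrotated
    return out
```

Because each slice is a pure rotation, the dot product between a rotated query and a rotated key depends only on their relative (frame, row, column) offset, not on absolute position.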

The predictor: anticipate the rest of a clip

The masked-latent predictor — the "world model" half of V-JEPA 2 — is ported too (HF parity cosine 1.000000) and loads by default. Give it the latents of the first frames and it predicts the latents of the rest; latent_energy scores how surprising the actual future was (anomaly / discontinuity detection in representation space):

from mlx_tune import FastVideoJEPAModel, latent_energy

model, _ = FastVideoJEPAModel.from_pretrained("facebook/vjepa2-vitl-fpc64-256")

# Context = first 4 frames; predict the latents of everything after.
predicted, actual = model.predict_latents(clip, context_frames=4)

energy = latent_energy(predicted, actual)                  # scalar surprise
e_map  = latent_energy(predicted, actual, per_token=True)  # localise it
# A clip with a hard cut scores higher energy than a coherent one.
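A surprise score of this kind can be sketched as a distance between predicted and actual latent tokens. The sketch assumes an L1 distance averaged over feature dimensions, which is the regression distance V-JEPA's training objective uses. mlx-tune's latent_energy may define it differently, and `latent_energy_sketch` is a hypothetical name.

```python
def token_energy(pred_token, actual_token):
    """Mean absolute difference between one predicted and one actual latent token."""
    return sum(abs(p - a) for p, a in zip(pred_token, actual_token)) / len(pred_token)


def latent_energy_sketch(predicted, actual, per_token=False):
    """predicted / actual: lists of (D,) latent tokens for the part of the clip being predicted."""
    energies = [token_energy(p, a) for p, a in zip(predicted, actual)]
    if per_token:
        return energies                      # localise where the surprise is
    return sum(energies) / len(energies)     # one scalar surprise score
```

A hard cut makes the actual future latents diverge from what the predictor expected, so the energy rises. The per-token view shows where in the clip it rises.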

Meta's fine-tuned action classifiers (zero training)

Meta's VJEPA2ForVideoClassification checkpoints (attentive pooler + head, fine-tuned on Something-Something-v2) load straight into MLX — from_pretrained auto-detects them. Logits match HuggingFace (cosine 1.000000, same top-1):

clf, _ = FastVideoJEPAModel.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2")

results = clf.predict([clip], top_k=5)      # 174 SSv2 action classes
for r in results[0]:
    print(f"{r['prob']:.3f}  {r['label']}")

clf.save_pretrained("ssv2_mlx")             # MLX-native save/reload
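
The top_k readout in the loop above amounts to a softmax over the class logits followed by a sort. Here is a sketch with a hypothetical helper that returns dicts with the same 'prob' / 'label' keys the loop prints:

```python
import math


def top_k_predictions(logits, labels, top_k=5):
    """Softmax over class logits, then the top_k {'label', 'prob'} entries, highest first."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]   # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    return [{"label": labels[i], "prob": probs[i]} for i in order]
```

With 174 SSv2 classes, top_k=5 keeps only the five most likely actions. The probabilities are over all 174 classes, so they need not sum to 1 across the five returned entries.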

Scope: action-conditioned planning still out

mlx-tune now ports the V-JEPA 2 encoder and predictor — features, fine-tuning, masked latent prediction, and anticipation/surprise scoring. What remains out of scope is V-JEPA 2-AC: the separately post-trained action-conditioned predictor and the robot planning loop (CEM / MPC over action sequences). If you want a trainable world model with planning on your Mac today, see LeWM below.

Downstream classification (frozen / LoRA / full)

Both pretrained tracks share the same downstream API. Attach a classification head and pick how much of the encoder to train:

| finetune= | What trains | Use when |
| --- | --- | --- |
| "frozen" | just the linear head | fast, tiny data, sanity check |
| "lora" (default) | LoRA adapters in every block + head (<1% of params) | adapt cheaply without overfitting |
| "full" | everything | lots of data, max accuracy |

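The finetune modes in the table amount to a rule for which parameters receive gradients. The sketch below assumes a hypothetical naming scheme in which head weights start with `head.` and LoRA adapters contain `lora_a` / `lora_b`. mlx-tune's actual parameter names may differ.

```python
def is_trainable(param_name, finetune):
    """Whether a parameter gets gradients under each finetune mode (hypothetical naming)."""
    is_head = param_name.startswith("head.")
    is_lora = ".lora_a" in param_name or ".lora_b" in param_name
    if finetune == "frozen":
        return is_head               # only the classification head
    if finetune == "lora":
        return is_head or is_lora    # head plus the low-rank adapters
    if finetune == "full":
        return True                  # every encoder weight as well
    raise ValueError(f"unknown finetune mode: {finetune!r}")
```

Freezing everything except the head and adapters is what keeps "lora" mode cheap. The optimizer state only has to cover the small trainable subset.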
Image classification (I-JEPA or LeJEPA encoder)

from mlx_tune import (
    FastJEPAModel, JEPAClassifierTrainer, JEPAClassifierConfig,
)

model, _ = FastJEPAModel.from_pretrained("facebook/ijepa_vith14_1k")

# Attach a head and LoRA-adapt the encoder (the base weights are frozen first, then wrapped with LoRA).
clf = FastJEPAModel.for_image_classification(
    model, num_classes=10, finetune="lora", r=8,
)

trainer = JEPAClassifierTrainer(
    clf,
    JEPAClassifierConfig(img_size=224, batch_size=6, num_epochs=5),
    train_images, train_labels,
    eval_images=val_images, eval_labels=val_labels,
)
trainer.train()
print(f"accuracy: {trainer.evaluate():.3f}")

Video classification (V-JEPA 2 encoder)

from mlx_tune import (
    FastVideoJEPAModel, VideoClassifierTrainer, VideoClassifierConfig,
)

model, _ = FastVideoJEPAModel.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
clf = FastVideoJEPAModel.for_video_classification(model, num_classes=2, finetune="lora", r=8)

trainer = VideoClassifierTrainer(
    clf, VideoClassifierConfig(batch_size=2, num_epochs=5),
    train_clips, train_labels, eval_videos=val_clips, eval_labels=val_labels,
)
trainer.train()

LoRA on a deep ViT needs warmup

The classifier configs default to learning_rate=3e-4 with warmup_ratio=0.15. Warmup matters: LoRA adapters perturbing all 24–32 transformer layers from step 0 can diverge without it. For frozen / full you can raise the LR.

Save, reload, and predict

Trained classifiers save the encoder, LoRA adapters, and head together. Reloading reconstructs the architecture (re-applying the LoRA structure) and produces identical predictions.

# after trainer.train() ...
clf.save_pretrained("my_classifier")

# Reload in a fresh process and run inference.
clf = FastJEPAModel.load_classifier("my_classifier")        # video: FastVideoJEPAModel.load_classifier
preds = clf.predict(new_images)                             # class ids   (N,)
probs = clf.predict(new_images, return_probs=True)          # softmax probs (N, num_classes)

Lifecycle verified

The full train → save → load → predict path is covered by the test suite for frozen, LoRA, and full modes — reloaded predictions match the trained model exactly.

LLM-JEPA — the JEPA objective for LLM fine-tuning

Every track above is vision. LLM-JEPA (arXiv 2509.14252, Huang, LeCun & Balestriero) brings the same Joint-Embedding Predictive idea to language-model fine-tuning. It is a training objective, not a model: it augments standard next-token prediction (NTP) with a JEPA term that aligns the embeddings of two “views” of the same item — e.g. a natural-language description (text) and its regex / SQL / code (code). It is implemented here for the first time on MLX.

Objective

$$\mathcal{L} = \mathcal{L}_{\mathrm{NTP}} + \lambda \cdot d\big(\mathrm{Pred}(\mathrm{Enc}(\text{text})),\ \mathrm{Enc}(\text{code})\big)$$

  • Enc(x) — the last non-pad token’s final-layer hidden state for view x.
  • Pred — an optional predictor formed by appending num_predictors learnable [PRED] slots (default 0 ⇒ identity).
  • d — cosine dissimilarity (default), or l2 / mse / infonce.

The fine-tuned artifact is a normal LoRA-fine-tuned LLM, so save / merge / generate work exactly as with SFT. The JEPA term only shapes the representation during training; the paper reports accuracy gains and resistance to overfitting. The dataset is a list of dicts holding the two views under the keys named by text_field / code_field (default "text" / "code"). If those keys are missing, the trainer falls back to prompt / completion.
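The objective can be sketched in plain Python for a single (text, code) pair. The sketch assumes the predictor is the identity (num_predictors=0) and that the embeddings were already read from the last non-pad token. The exact `l2` and `mse` definitions are assumptions. `infonce` is omitted because it contrasts each pair against the other pairs in a batch. All helper names are hypothetical.

```python
import math


def last_non_pad_index(token_ids, pad_id):
    """Position of the last non-padding token, where Enc(x) reads the final hidden state."""
    for i in range(len(token_ids) - 1, -1, -1):
        if token_ids[i] != pad_id:
            return i
    raise ValueError("sequence is all padding")


def jepa_distance(a, b, kind="cosine"):
    """Distance between the two view embeddings."""
    if kind == "cosine":    # cosine dissimilarity: 1 - cos(a, b)
        dot = sum(x * y for x, y in zip(a, b))
        na = math.sqrt(sum(x * x for x in a))
        nb = math.sqrt(sum(y * y for y in b))
        return 1.0 - dot / (na * nb)
    if kind == "l2":        # Euclidean distance (assumed definition)
        return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))
    if kind == "mse":       # mean squared error (assumed definition)
        return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)
    raise ValueError(f"unsupported distance {kind!r} in this sketch")


def llm_jepa_loss(ntp_loss, text_emb, code_emb, jepa_lambda=0.1, kind="cosine"):
    """L = L_NTP + lambda * d(Pred(Enc(text)), Enc(code)) with Pred = identity."""
    return ntp_loss + jepa_lambda * jepa_distance(text_emb, code_emb, kind)
```

Setting jepa_lambda to 0 reduces this to the plain next-token loss. That is the NTP baseline the Tip below suggests for an A/B comparison.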

LLMJEPATrainer / LLMJEPAConfig

| Config parameter | Default | Description |
| --- | --- | --- |
| jepa_lambda | 0.1 | Weight on the JEPA term |
| jepa_distance | "cosine" | cosine / l2 / mse / infonce |
| num_predictors | 0 | k learnable [PRED] slots (0 ⇒ predictor is identity) |
| jepa_ratio | -1.0 | JEPA-loss dropout: probability of keeping the JEPA term on a given step (-1 = always on) |
| ntp_on | "all" | "all": NTP over all three sequences [combined, text, code]; "combined": the combined sequence only |
| response_only | False | Mask the text prefix of the combined sequence during NTP |

from mlx_tune import FastLanguageModel, LLMJEPATrainer, LLMJEPAConfig

model, tokenizer = FastLanguageModel.from_pretrained("mlx-community/Qwen3.5-0.8B-MLX-4bit")
model = FastLanguageModel.get_peft_model(model, r=16, lora_alpha=16)

# Two "views" of the same item: a description and its regex
data = [
    {"text": "match one or more digits", "code": r"\d+"},
    {"text": "match a hex color code",   "code": r"#[0-9a-fA-F]{6}"},
    # ...
]

trainer = LLMJEPATrainer(
    model, data, tokenizer=tokenizer,
    args=LLMJEPAConfig(jepa_lambda=0.1, jepa_distance="cosine", max_steps=40),
)
trainer.train()

# The result is a normal LoRA model — generate / merge as usual:
print(model.generate(prompt="match one or more digits", max_tokens=16))
# model.save_pretrained_merged("./llm_jepa_merged")

Tip

For real benchmarks use paired-view datasets like NL-RX (regex), Spider (SQL) or GSM8K (math) — the datasets the paper evaluates. Set jepa_lambda=0 to train a pure-NTP baseline for an A/B comparison.

Dense & regression heads (counting, depth, segmentation)

Beyond classification, the I-JEPA paper also evaluates object counting and depth prediction. Three non-classification heads attach to any JEPA encoder and reuse the same frozen / LoRA / full fine-tuning machinery:

| Builder | Head | Loss / metric |
| --- | --- | --- |
| for_image_regression(model, out_dim) | JEPAForImageRegression: scalar/vector output (counting) | MSE / MAE, RMSE |
| for_dense_prediction(model, C, task="regression") | JEPAForDensePrediction: per-pixel value map (depth) | MSE / MAE |
| for_dense_prediction(model, C, task="segmentation") | JEPAForDensePrediction: per-pixel class logits | per-pixel CE / pixel-acc |

from mlx_tune import (FastJEPAModel, JEPAClassifierConfig,
                      JEPARegressionTrainer, JEPADenseTrainer)

model, _ = FastJEPAModel.from_pretrained("facebook/ijepa_vith14_1k")

# object counting (scalar regression)
reg = FastJEPAModel.for_image_regression(model, out_dim=1, finetune="lora")
JEPARegressionTrainer(reg, JEPAClassifierConfig(img_size=224),
                      images, counts).train()
print(reg.evaluate())                 # {"mae": ..., "rmse": ...}

# depth map (dense regression) — targets are per-pixel maps at img_size
dense = FastJEPAModel.for_dense_prediction(model, out_channels=1, task="regression")
# segmentation — out_channels = num classes, integer mask targets
seg = FastJEPAModel.for_dense_prediction(model, out_channels=21, task="segmentation")

Dense heads predict at the patch grid and bilinearly upsample to img_size. Reload with FastJEPAModel.load_regressor(dir) / load_dense(dir).
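Upsampling from the patch grid to img_size can be sketched as plain bilinear interpolation on one channel. The sketch uses half-pixel centres (the convention often called align_corners=False) and clamps at the edges. Both choices are assumptions about the library. `bilinear_upsample` is a hypothetical helper.

```python
def bilinear_upsample(grid, out_h, out_w):
    """Bilinearly resize a 2-D grid (one output channel) to out_h x out_w."""
    in_h, in_w = len(grid), len(grid[0])

    def src_coord(dst, in_size, out_size):
        # Map an output pixel centre back into input coordinates, clamped to the grid.
        s = (dst + 0.5) * in_size / out_size - 0.5
        return min(max(s, 0.0), in_size - 1)

    out = []
    for y in range(out_h):
        sy = src_coord(y, in_h, out_h)
        y0 = int(sy)
        y1 = min(y0 + 1, in_h - 1)
        fy = sy - y0
        row = []
        for x in range(out_w):
            sx = src_coord(x, in_w, out_w)
            x0 = int(sx)
            x1 = min(x0 + 1, in_w - 1)
            fx = sx - x0
            top = grid[y0][x0] * (1 - fx) + grid[y0][x1] * fx
            bottom = grid[y1][x0] * (1 - fx) + grid[y1][x1] * fx
            row.append(top * (1 - fy) + bottom * fy)
        out.append(row)
    return out
```

For depth, the upsampled map is compared to the target directly. For segmentation, each class's logit map would be upsampled and the per-pixel argmax taken as the predicted class.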

Native resolution & training over a real corpus

Non-224 inputs. Pretrained I-JEPA position embeddings were learned on a 224² grid. You can now load at any patch-divisible size. The learned grid is bicubically interpolated at load time, and on the fly for off-grid inputs. At the native 224 resolution the path is unchanged and stays bit-identical.

# run a pretrained encoder at 448x448 (448 = 32 x 14, divisible by the ViT-H/14 patch size)
model, _ = FastJEPAModel.from_pretrained("facebook/ijepa_vith14_1k", img_size=448)
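
Position-embedding interpolation can be sketched as separable bicubic resampling of the patch grid, applied to each embedding channel. The grid side follows from img_size / patch_size, for example 224 / 14 = 16 and 448 / 14 = 32, while 384 is not divisible by 14. The kernel is the Keys cubic with a = -0.75, a coefficient commonly used for bicubic resizing. Half-pixel centres and edge replication are further assumptions. These helpers are illustrative, not mlx-tune's loader code.

```python
import math


def patch_grid(img_size, patch_size):
    """Side length of the patch grid; the image size must be divisible by the patch size."""
    if img_size % patch_size:
        raise ValueError(f"img_size={img_size} is not divisible by patch_size={patch_size}")
    return img_size // patch_size


def cubic_weight(x, a=-0.75):
    """Keys cubic convolution kernel."""
    x = abs(x)
    if x <= 1.0:
        return (a + 2.0) * x ** 3 - (a + 3.0) * x ** 2 + 1.0
    if x < 2.0:
        return a * x ** 3 - 5.0 * a * x ** 2 + 8.0 * a * x - 4.0 * a
    return 0.0


def resample_1d(values, out_len):
    """Bicubic resampling of one row/column with half-pixel centres and edge replication."""
    in_len = len(values)
    out = []
    for i in range(out_len):
        s = (i + 0.5) * in_len / out_len - 0.5      # source coordinate of output sample i
        base = math.floor(s)
        acc = 0.0
        for k in range(base - 1, base + 3):           # the 4 nearest source samples
            acc += values[min(max(k, 0), in_len - 1)] * cubic_weight(s - k)
        out.append(acc)
    return out


def bicubic_resize(grid, out_h, out_w):
    """Separable bicubic resize of one channel of the position-embedding grid."""
    rows = [resample_1d(row, out_w) for row in grid]                          # along width
    cols = [resample_1d([r[x] for r in rows], out_h) for x in range(out_w)]   # along height
    return [[cols[x][y] for x in range(out_w)] for y in range(out_h)]
```

Resizing to the same grid size leaves every value unchanged with this kernel. That matches the idea that the native-resolution path is not perturbed by interpolation.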

Streaming + resumable checkpoints. ImageFolderDataset decodes images lazily from a folder (recursively) on the prefetch thread, so the corpus never has to fit in memory. Set save_steps + resume to checkpoint the encoder + optimizer state and continue an interrupted run.

from mlx_tune import FastJEPAModel, JEPAConfig, JEPATrainer, ImageFolderDataset

model, _ = FastJEPAModel.from_pretrained("vit-base", img_size=224)
ds = ImageFolderDataset("/path/to/images")     # lazy, recursive
cfg = JEPAConfig(img_size=224, batch_size=64, num_epochs=10,
                 save_steps=500, resume=True)   # resumes from /checkpoint
JEPATrainer(model, cfg, ds).train()
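
The lazy-loading and prefetch pattern described above can be sketched with the standard library. The dataset stores only file paths, and a background thread pulls items so decoding overlaps with compute. `LazyImageFolder`, `find_images` and `prefetch` are hypothetical stand-ins, not ImageFolderDataset's code. The extension list is an assumption, and resumable checkpointing is not covered.

```python
import os
import queue
import threading

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def find_images(root):
    """Recursively list image files under root (paths only; nothing is decoded here)."""
    paths = []
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            if os.path.splitext(name)[1].lower() in IMAGE_EXTS:
                paths.append(os.path.join(dirpath, name))
    return sorted(paths)


class LazyImageFolder:
    """Holds file paths; an image is decoded only when its item is requested."""

    def __init__(self, root, loader):
        self.paths = find_images(root)
        self.loader = loader   # a function that opens and decodes one image file

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        return self.loader(self.paths[index])


def prefetch(iterable, depth=4):
    """Consume `iterable` on a background thread, keeping up to `depth` items ready."""
    q = queue.Queue(maxsize=depth)
    done = object()
    errors = []

    def worker():
        try:
            for item in iterable:
                q.put(item)
        except Exception as exc:   # surface loader errors to the consumer
            errors.append(exc)
        finally:
            q.put(done)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = q.get()
        if item is done:
            if errors:
                raise errors[0]
            return
        yield item
```

Passing a generator such as `prefetch(ds[i] for i in order)` means the worker thread advances the generator, so decoding happens off the main thread. The bounded queue keeps only `depth` decoded images in memory at a time.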

LeWM — a latent world model you can train (and plan with)

LeWM (LeWorldModel, arXiv 2603.19312, Maes/Le Lidec/Scieur/LeCun/Balestriero) trains a latent world model end-to-end from pixels with only two loss terms and one tunable hyperparameter:

L = next-embedding prediction + λ · SIGReg(embeddings)

No stop-gradient, no EMA target — SIGReg (the same regularizer as LeJEPA) is what stops the encoder collapsing. The trained model is a small latent world model you can plan with via CEM / MPC over latent rollouts.

This is the trainable, on-device world model — the counterpart to V-JEPA 2-AC's load-and-plan (which is a design writeup for a future release, see jepa.md §9).

from mlx_tune import FastWorldModel, LeWMConfig, LeWMTrainer, plan_cem, PointMassEnv

env = PointMassEnv(size=48)                       # toy 2-D control env
data = env.collect(n_episodes=80, ep_len=10)      # {"frames", "actions"} trajectories

model = FastWorldModel.from_pretrained("lewm-tiny", img_size=48, action_dim=2)
LeWMTrainer(model, LeWMConfig(img_size=48, action_dim=2, sigreg_lambda=0.05,
                              max_steps=250), data).train()

# plan: pick actions whose predicted latent matches a goal image's latent (MPC)
goal_z = model.encode([env.render([0.8, 0.2])])[0]
z = model.encode([env.render()])[0]
action = plan_cem(model, z, goal_z, horizon=3, action_dim=2)

Honest scope

The planner is exact (cost→0 on known dynamics) and training reduces the prediction loss; convincing closed-loop control needs the paper's training budget / real control datasets (pass them in the {"frames","actions"} format). This is a demo of the pipeline, not a SOTA-control claim.
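The CEM / MPC planning step can be sketched in plain Python. A toy additive dynamics z' = z + a stands in for the learned latent predictor. `plan_cem_sketch` is hypothetical, not mlx-tune's plan_cem, and the sample counts, elite size and iteration count are assumptions.

```python
import random


def rollout(dynamics, z0, actions):
    """Apply the (latent) dynamics model to each action in turn."""
    z = z0
    for a in actions:
        z = dynamics(z, a)
    return z


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def plan_cem_sketch(dynamics, z0, goal, horizon, action_dim,
                    n_samples=200, n_elite=20, iters=10, init_std=0.5, seed=0):
    """Cross-entropy method over action sequences; returns (first action, cost of the mean plan)."""
    rng = random.Random(seed)
    mean = [[0.0] * action_dim for _ in range(horizon)]
    std = [[init_std] * action_dim for _ in range(horizon)]
    for _ in range(iters):
        # 1. Sample candidate action sequences around the current mean.
        samples = [[[rng.gauss(mean[t][d], std[t][d]) for d in range(action_dim)]
                    for t in range(horizon)] for _ in range(n_samples)]
        # 2. Rank them by how close the predicted final latent lands to the goal latent.
        samples.sort(key=lambda seq: sq_dist(rollout(dynamics, z0, seq), goal))
        elites = samples[:n_elite]
        # 3. Refit the sampling distribution to the elite sequences.
        for t in range(horizon):
            for d in range(action_dim):
                vals = [seq[t][d] for seq in elites]
                m = sum(vals) / n_elite
                mean[t][d] = m
                std[t][d] = max((sum((v - m) ** 2 for v in vals) / n_elite) ** 0.5, 1e-3)
    cost = sq_dist(rollout(dynamics, z0, mean), goal)
    return mean[0], cost   # MPC: execute only the first action, then re-encode and replan
```

On the toy dynamics the mean plan's cost shrinks toward zero, which mirrors the "cost→0 on known dynamics" check above. With a learned model, plan quality is bounded by how accurate the latent rollouts are.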

Runnable examples

| File | What it shows |
| --- | --- |
| examples/58_lejepa_pretraining.py | LeJEPA from-scratch pretraining on a CIFAR-10 subset (synthetic fallback) → linear probe |
| examples/59_ijepa_feature_extraction.py | Pretrained I-JEPA: linear probe + LoRA fine-tune + save/load/predict |
| examples/60_vjepa2_video.py | Pretrained V-JEPA 2: video linear probe + LoRA fine-tune + save/load/predict |
| examples/61_llm_jepa_finetuning.py | LLM-JEPA: fine-tune an LLM with the JEPA objective (NL→regex views, Qwen3.5-0.8B) |
| examples/62_lewm_world_model.py | LeWM: train a latent world model from pixels + CEM/MPC planning (toy point-mass) |
| examples/63_jepa_dense_regression.py | Dense & regression heads: counting, depth maps, segmentation |
| examples/64_vjepa2_predictor_classifier.py | V-JEPA 2 predictor (anticipation + surprise energy) + Meta's pretrained SSv2 classifier |

python examples/58_lejepa_pretraining.py
python examples/59_ijepa_feature_extraction.py
python examples/60_vjepa2_video.py
python examples/61_llm_jepa_finetuning.py
python examples/62_lewm_world_model.py
python examples/63_jepa_dense_regression.py
python examples/64_vjepa2_predictor_classifier.py

Tips & gotchas

I-JEPA defaults to 224×224 (other sizes supported)

The pretrained position embeddings were trained on a 224×224 grid, so 224 is the default and the loader resizes to it for you. To run at another resolution, pass img_size= at load time. The grid is then bicubically interpolated, and the native-224 path stays bit-identical. See Native resolution & training over a real corpus.

V-JEPA 2 clips: 256×256, frames multiple of 2

The 3D RoPE grid is tied to crop_size=256. Frames are resized to 256×256; the number of frames T must be a multiple of the tubelet size (2). The predictor is ported too (loads by default) — use predict_latents for masked latent prediction; only the action-conditioned V-JEPA 2-AC variant is out of scope.

Memory footprint (M4 Pro, 48 GB)

I-JEPA LoRA fine-tune (batch 6, 224×224) peaks ~13 GB; V-JEPA 2 LoRA (8-frame clip) peaks ~7.6 GB. Frozen-feature probing is far lighter since there's no backprop through the encoder.

Small-scale results are demos

Linear-probe / fine-tune accuracy on small datasets on a Mac is a capability demonstration, not a SOTA claim. For serious results, prototype here and scale the same code on a larger machine.

Performance

Attention uses MLX's fused scaled_dot_product_attention. The trainers also configure the wired-memory limit and overlap image/video preprocessing with GPU compute on a background thread, so you get good throughput out of the box.