LLM Fine-Tuning

Fine-tune language models on Apple Silicon with LoRA. SFT, DPO, GRPO, KTO, SimPO — all natively on MLX.

Full SFT pipeline in 20 lines

Load a model, add LoRA, train, and save — all with the same API you already know from Unsloth.

from mlx_tune import FastLanguageModel, SFTTrainer, SFTConfig
from datasets import load_dataset

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="mlx-community/Llama-3.2-1B-Instruct-4bit",
    max_seq_length=2048,
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
    model, r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_alpha=16,
)
dataset = load_dataset("yahma/alpaca-cleaned", split="train[:100]")
trainer = SFTTrainer(
    model=model, train_dataset=dataset, tokenizer=tokenizer,
    args=SFTConfig(output_dir="outputs", max_steps=50, learning_rate=2e-4),
)
trainer.train()
model.save_pretrained("lora_model")

FastLanguageModel

mlx_tune.model

Main entry point for loading and configuring language models. Mirrors Unsloth’s FastLanguageModel API.

FastLanguageModel.from_pretrained()

FastLanguageModel.from_pretrained(model_name, max_seq_length=2048, load_in_4bit=False, load_in_8bit=False, use_gradient_checkpointing="unsloth", ...) → Tuple[MLXModelWrapper, Tokenizer]

Load a pretrained language model from HuggingFace or a local path.

  • model_name (str) — HuggingFace model ID (e.g., "mlx-community/Llama-3.2-1B-Instruct-4bit") or local path
  • max_seq_length (int) — Maximum sequence length for training and inference
  • load_in_4bit (bool) — Load model with 4-bit quantization (QLoRA)
  • load_in_8bit (bool) — Load model with 8-bit quantization
  • use_gradient_checkpointing (str | bool) — Gradient checkpointing mode (accepted for Unsloth compat)

FastLanguageModel.get_peft_model()

FastLanguageModel.get_peft_model(model, r=16, target_modules=None, lora_alpha=16, lora_dropout=0.0, bias="none", use_rslora=False, random_state=3407, ...) → MLXModelWrapper

Add LoRA adapters to the model for parameter-efficient fine-tuning.

  • model (MLXModelWrapper) — Model returned by from_pretrained()
  • r (int) — LoRA rank (higher = more parameters, better quality)
  • target_modules (list[str]) — Modules to apply LoRA to. Default: ["q_proj", "k_proj", "v_proj", "o_proj"]
  • lora_alpha (int) — LoRA scaling factor. Recommended: equal to r
  • lora_dropout (float) — Dropout for LoRA layers
  • bias (str) — Bias mode: "none", "all", or "lora_only"
  • use_rslora (bool) — Use rank-stabilized LoRA scaling

FastLanguageModel.for_training()

FastLanguageModel.for_training(model) → MLXModelWrapper

Enable training mode: disables KV caching, enables dropout.

FastLanguageModel.for_inference()

FastLanguageModel.for_inference(model, use_cache=True) → MLXModelWrapper

Enable inference mode: activates KV caching, disables dropout. Always call before generating.

FastLanguageModel.convert()

FastLanguageModel.convert(hf_model, output_dir="mlx_model", quantize=False, q_bits=4, dtype=None, upload_repo=None)

Convert a HuggingFace model to MLX format. Optionally quantize and upload.

  • hf_model (str) — HuggingFace model ID (e.g., "meta-llama/Llama-3-8B")
  • output_dir (str) — Local directory to save the converted model
  • quantize (bool) — Whether to quantize during conversion
  • q_bits (int) — Quantization bits (4 or 8)
  • dtype (str, optional) — Target dtype (e.g., "float16")
  • upload_repo (str, optional) — HuggingFace repo ID to upload the converted model

MLXModelWrapper

mlx_tune.model

Wrapper providing Unsloth-compatible methods on MLX models. Returned by FastLanguageModel.from_pretrained().

model.generate()

model.generate(prompt, max_tokens=256, temperature=0.7, top_p=0.9, min_p=0.0, ...) → str

Generate text from a prompt. Call FastLanguageModel.for_inference(model) first.

  • prompt (str) — Input text or formatted chat prompt
  • max_tokens (int) — Maximum number of tokens to generate
  • temperature (float) — Sampling temperature (0.0 = greedy)
  • top_p (float) — Nucleus sampling probability
  • min_p (float) — Minimum probability filter (recommended: 0.1)
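Putting the pieces together — a minimal generation sketch using the documented API (the prompt text is illustrative):

```python
from mlx_tune import FastLanguageModel

# Load, then switch to inference mode (enables KV caching, disables dropout)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="mlx-community/Llama-3.2-1B-Instruct-4bit",
    max_seq_length=2048,
)
model = FastLanguageModel.for_inference(model)

text = model.generate(
    "Explain LoRA in one sentence.",
    max_tokens=64,
    temperature=0.7,
    min_p=0.1,  # recommended filter per the table above
)
print(text)
```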

model.save_pretrained()

model.save_pretrained(output_dir)

Save LoRA adapters only. Writes adapters.safetensors and adapter_config.json.

model.save_pretrained_merged()

model.save_pretrained_merged(output_dir, tokenizer, save_method="merged_16bit")

Fuse LoRA weights into the base model and save the full merged model.

model.save_pretrained_gguf()

model.save_pretrained_gguf(output_dir, tokenizer, quantization_method="q4_k_m")

Export to GGUF format for use with Ollama, llama.cpp, and other inference engines.

GGUF Limitation

GGUF export works with non-quantized models only; the inability to export quantized (4-bit) models is an mlx-lm limitation, not an mlx-tune bug. Load the model with load_in_4bit=False before exporting.

model.push_to_hub()

model.push_to_hub(repo_id)

Push the model or adapters to HuggingFace Hub.

model.stream_generate()

model.stream_generate(prompt, ...) — Stream text token by token
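A consumption sketch, assuming stream_generate yields text chunks as they are produced (the iteration protocol is an assumption; only the method name comes from this reference):

```python
# Print tokens as they arrive instead of waiting for the full completion
for chunk in model.stream_generate("Tell me a joke.", max_tokens=64):
    print(chunk, end="", flush=True)
```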

SFT Training

mlx_tune.sft_trainer

SFTTrainer

SFTTrainer(model, train_dataset, tokenizer=None, eval_dataset=None, args=None, ...)

Supervised fine-tuning trainer. API-compatible with TRL’s SFTTrainer. Automatically detects dataset format (Alpaca, ShareGPT, ChatML) and converts to the correct training format.

  • model (MLXModelWrapper) — Model with LoRA adapters configured
  • train_dataset (Dataset) — HuggingFace dataset or list of dicts
  • tokenizer (Tokenizer) — Tokenizer from from_pretrained()
  • eval_dataset (Dataset, optional) — Evaluation dataset
  • args (SFTConfig) — Training configuration

trainer.train() — Start training. Returns training statistics.
trainer.save_model(output_dir) — Save the trained model.

SFTConfig

SFTConfig(output_dir="outputs", per_device_train_batch_size=2, gradient_accumulation_steps=4, learning_rate=2e-4, lr_scheduler_type="linear", warmup_steps=5, num_train_epochs=1, max_steps=-1, logging_steps=1, save_steps=500, max_seq_length=2048, optim="adam", weight_decay=0.01, seed=3407, report_to="none", ...)

Training configuration. Compatible with TRL’s SFTConfig parameters.

  • output_dir (default "outputs") — Directory for checkpoints and logs
  • per_device_train_batch_size (default 2) — Batch size per device
  • gradient_accumulation_steps (default 4) — Number of gradient accumulation steps
  • learning_rate (default 2e-4) — Peak learning rate
  • max_steps (default -1) — Total training steps (-1 = use epochs)
  • max_seq_length (default 2048) — Maximum sequence length
  • optim (default "adam") — Optimizer (use "adam" for MLX)
  • warmup_steps (default 5) — Linear warmup steps
  • lr_scheduler_type (default "linear") — LR scheduler: linear, cosine, constant
  • logging_steps (default 1) — Log metrics every N steps
  • save_steps (default 500) — Save checkpoint every N steps
  • weight_decay (default 0.01) — Weight decay for regularization
  • seed (default 3407) — Random seed for reproducibility

RL Trainers

mlx_tune.rl_trainers

All RL trainers use proper loss implementations with full log-probability computation — not wrappers around SFT.

DPOTrainer

DPOTrainer(model, train_dataset, args=None, tokenizer=None, ...)

Direct Preference Optimization. Uses proper DPO loss with log-probability computation over chosen/rejected pairs.

DPOConfig

DPOConfig(output_dir, beta=0.1, max_steps=-1, learning_rate=5e-7, ...)

  • beta (default 0.1) — KL penalty coefficient (higher = more conservative)
  • learning_rate (default 5e-7) — Lower than SFT to avoid reward hacking
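DPO consumes chosen/rejected pairs. A minimal sketch of the TRL-style record layout such trainers typically expect (field contents are illustrative; the commented trainer call assumes a model and tokenizer loaded as in the quickstart):

```python
# TRL-style preference pairs: one prompt with a preferred and a dispreferred answer
preference_data = [
    {
        "prompt": "What does LoRA stand for?",
        "chosen": "LoRA stands for Low-Rank Adaptation.",
        "rejected": "LoRA is a kind of radio protocol.",
    },
]

# Sketch of the trainer call (not executed here):
# trainer = DPOTrainer(model=model, train_dataset=preference_data, tokenizer=tokenizer,
#                      args=DPOConfig(output_dir="dpo_outputs", beta=0.1))
# trainer.train()
```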

ORPOTrainer

ORPOTrainer(model, train_dataset, args=None, ...)

Odds Ratio Preference Optimization. Combines SFT loss with odds-ratio preference alignment.

ORPOConfig(output_dir, beta=0.1, ...)

GRPOTrainer

GRPOTrainer(model, train_dataset, args=None, reward_fn=None, tokenizer=None, ...)

Group Relative Policy Optimization (DeepSeek R1 style). Generates multiple completions per prompt and optimizes based on relative rewards.

GRPOConfig

GRPOConfig(output_dir, num_generations=4, beta=0.04, ...)

  • num_generations (default 4) — Number of completions generated per prompt
  • beta (default 0.04) — KL divergence coefficient
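GRPO needs a reward_fn to score each generated completion. A sketch of a tag-checking reward in the style of Example 55's <reasoning>/<answer> shaping, assuming reward_fn receives a list of completion strings and returns one float per completion (the exact callback signature may differ):

```python
def format_reward(completions):
    """Reward completions that wrap their output in <reasoning>/<answer> tags."""
    rewards = []
    for text in completions:
        score = 0.0
        if "<reasoning>" in text and "</reasoning>" in text:
            score += 0.5
        if "<answer>" in text and "</answer>" in text:
            score += 0.5
        rewards.append(score)
    return rewards

# Passed to the trainer (not executed here):
# trainer = GRPOTrainer(model=model, train_dataset=prompts, reward_fn=format_reward,
#                       tokenizer=tokenizer, args=GRPOConfig(output_dir="grpo_outputs"))
```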

KTOTrainer

KTOTrainer(model, train_dataset, args=None, tokenizer=None, ...)

Kahneman-Tversky Optimization. Works with binary feedback (desirable/undesirable) instead of paired preferences. Supports both TRL format (prompt+completion+label) and legacy format (text+label).

KTOConfig

KTOConfig(output_dir, beta=0.1, desirable_weight=1.0, undesirable_weight=1.0, learning_rate=5e-7, ...)

  • beta (default 0.1) — Temperature coefficient for KL penalty
  • desirable_weight (default 1.0) — Weight for desirable (positive) examples
  • undesirable_weight (default 1.0) — Weight for undesirable (negative) examples
  • learning_rate (default 5e-7) — Learning rate
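KTO takes single completions with a binary label rather than preference pairs. A sketch of the two record layouts described above (field contents are illustrative):

```python
# TRL format: prompt + completion + boolean label
kto_trl = [
    {"prompt": "Summarize LoRA.",
     "completion": "Low-rank adapters for cheap fine-tuning.",
     "label": True},
    {"prompt": "Summarize LoRA.",
     "completion": "I refuse to answer.",
     "label": False},
]

# Legacy format: full text + label
kto_legacy = [
    {"text": "Q: Summarize LoRA. A: Low-rank adapters for cheap fine-tuning.",
     "label": True},
]
```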

SimPOTrainer

SimPOTrainer(model, train_dataset, args=None, tokenizer=None, ...)

Simple Preference Optimization. No reference model needed — uses length-normalized log probabilities as implicit rewards. More memory efficient than DPO.

SimPOConfig

SimPOConfig(output_dir, beta=2.0, gamma=0.5, learning_rate=5e-7, ...)

  • beta (default 2.0) — Temperature coefficient (typically higher than DPO)
  • gamma (default 0.5) — Target reward margin between chosen and rejected
  • learning_rate (default 5e-7) — Learning rate
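The implicit reward described above is the length-normalized log probability scaled by beta, and the loss pushes the chosen reward above the rejected one by at least the margin gamma. A plain-numpy illustration of that objective (the token log-probs are made up):

```python
import numpy as np

def simpo_loss(logp_chosen, logp_rejected, beta=2.0, gamma=0.5):
    """SimPO: -log sigmoid(beta * mean(logp_chosen) - beta * mean(logp_rejected) - gamma)."""
    r_chosen = beta * np.mean(logp_chosen)      # length-normalized implicit reward
    r_rejected = beta * np.mean(logp_rejected)
    margin = r_chosen - r_rejected - gamma
    return -np.log(1.0 / (1.0 + np.exp(-margin)))

# Chosen answer has higher per-token log-probs, so the loss is small
loss = simpo_loss([-0.5, -0.4, -0.6], [-2.0, -1.8], beta=2.0, gamma=0.5)
```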

Continual Pretraining

mlx_tune.cpt_trainer

Continual pretraining (CPT) lets you adapt a pretrained model to a new domain or language using raw text. Unlike SFT, loss is computed on all tokens (not just responses). Supports both LoRA-based CPT and full-weight training.

When to use CPT:
  • Language adaptation — Teach a model a new language with raw text corpora
  • Domain knowledge — Inject domain-specific knowledge (medical, legal, scientific)
  • Code capabilities — Extend a model’s programming language coverage

CPTTrainer

CPTTrainer(model, train_dataset, tokenizer=None, args=None, ...)

Continual pretraining trainer. Trains on raw text with loss on all tokens. Optionally trains embed_tokens and lm_head layers with a separate (decoupled) learning rate.

  • model (MLXModelWrapper) — Model with LoRA adapters (or base model for full-weight CPT)
  • train_dataset (Dataset) — Raw text dataset (each sample has a "text" field)
  • tokenizer (Tokenizer) — Tokenizer from from_pretrained()
  • args (CPTConfig) — Training configuration

trainer.train() — Start continual pretraining. Returns training statistics.

CPTConfig

CPTConfig(output_dir="./cpt_outputs", learning_rate=5e-5, include_embeddings=True, embedding_learning_rate=None, ...)

Configuration for continual pretraining. Extends SFTConfig with CPT-specific options.

  • output_dir (default "./cpt_outputs") — Directory for checkpoints and logs
  • learning_rate (default 5e-5) — Learning rate for LoRA / main parameters
  • include_embeddings (default True) — Auto-add embed_tokens + lm_head to target modules and unfreeze them
  • embedding_learning_rate (default learning_rate / 5) — Decoupled learning rate for embed_tokens / lm_head
  • per_device_train_batch_size (default 2) — Batch size per device
  • max_steps (default -1) — Total training steps (-1 = use epochs)
  • max_seq_length (default 2048) — Maximum sequence length for chunking raw text

LoRA CPT Example

from mlx_tune import FastLanguageModel, CPTTrainer, CPTConfig

model, tokenizer = FastLanguageModel.from_pretrained(
    "mlx-community/SmolLM2-360M-Instruct",  # prefer a base (non-instruct) model for CPT
    max_seq_length=2048,
)
model = FastLanguageModel.get_peft_model(model, r=16, lora_alpha=16)

# Raw text dataset — loss on ALL tokens (no chat template)
dataset = [{"text": "Your domain-specific text here..."}, ...]

trainer = CPTTrainer(
    model=model, train_dataset=dataset, tokenizer=tokenizer,
    args=CPTConfig(
        learning_rate=5e-5,
        embedding_learning_rate=5e-6,         # 10x smaller for embeddings
        include_embeddings=True,              # Auto-adds embed_tokens + lm_head
        max_steps=100,
    ),
)
trainer.train()

Examples

MoE Fine-Tuning

Mixture of Experts

mlx-tune automatically detects MoE architectures and applies per-expert LoRA via LoRASwitchLinear. No special API — use the same FastLanguageModel and SFTTrainer as dense models.

How it works: When you specify target_modules=["gate_proj", ...], mlx-tune inspects the model’s actual layer structure and resolves paths dynamically:
  • Expert layers: mlp.switch_mlp.gate_proj → LoRASwitchLinear (per-expert LoRA)
  • Shared experts: mlp.shared_expert.gate_proj → LoRALinear
  • Dense layers: mlp.gate_proj → LoRALinear (mixed architectures)
  • Router: mlp.gate — automatically excluded (not fine-tuned)
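Since no special API is needed, a MoE fine-tune looks like a dense one with MLP projections added to the target list. A sketch using a model ID from the table of supported models (up_proj/down_proj are assumed to be valid targets alongside the documented gate_proj):

```python
from mlx_tune import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    "mlx-community/Qwen3-30B-A3B-Instruct-2507-4bit",
    max_seq_length=2048,
    load_in_4bit=True,
)
# gate_proj/up_proj/down_proj resolve to per-expert LoRASwitchLinear on expert
# layers and plain LoRALinear on shared/dense layers; the router is skipped
model = FastLanguageModel.get_peft_model(
    model, r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
)
```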

Supported MoE Models

  • Arcee Trinity-Nano (AFMoE) — 6B total / 1B active — mlx-community/Trinity-Nano-Preview-4bit (128 experts + 1 shared, sigmoid routing, gated attention)
  • Gemma 4 26B-A4B — 26B total / ~4B active — mlx-community/gemma-4-26b-a4b-it-4bit (VLM path)
  • Qwen3.5-35B-A3B — 35B total / 3B active — mlx-community/Qwen3.5-35B-A3B-4bit
  • Qwen3-30B-A3B — 30B total / 3B active — mlx-community/Qwen3-30B-A3B-Instruct-2507-4bit
  • Phi-3.5-MoE — 42B total / 6.6B active — mlx-community/Phi-3.5-MoE-instruct-4bit
  • Mixtral-8x7B — 46B total / 12B active — mlx-community/Mixtral-8x7B-Instruct-v0.1-hf-4bit-mlx
  • ...and all other MoE architectures supported by mlx-lm (39+ total)

Embedding Models

mlx_tune.embeddings

Fine-tune sentence embedding models for semantic search using contrastive learning (InfoNCE loss). Supports BERT, ModernBERT, Qwen3-Embedding, Microsoft Harrier, and other sentence-transformers compatible architectures.

How it works: Architecture is auto-detected from config.json model_type. LoRA targets are resolved per architecture:
  • BERT / XLM-RoBERTa: attention.self.{query, key, value}
  • ModernBERT: attn.Wqkv (fused QKV)
  • Qwen3 / Gemma3 / Harrier: self_attn.{q_proj, k_proj, v_proj, o_proj}

Supported Models

  • all-MiniLM-L6-v2 — BERT, mean pooling — mlx-community/all-MiniLM-L6-v2-bf16
  • Qwen3-Embedding-0.6B — Qwen3, last-token pooling — mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ
  • Harrier 0.6B — Qwen3, last-token pooling — microsoft/harrier-oss-v1-0.6b
  • Harrier 270M — Gemma3, last-token pooling — microsoft/harrier-oss-v1-270m
  • ...and other models supported by mlx-embeddings (BERT, XLM-RoBERTa, ModernBERT, etc.)

FastEmbeddingModel

FastEmbeddingModel.from_pretrained(model_name, max_seq_length=512, pooling_strategy="mean", ...) → (EmbeddingModelWrapper, Tokenizer)

Load a sentence embedding model. Architecture and LoRA targets are auto-detected.

  • model_name (str) — HuggingFace model ID or local path
  • max_seq_length (int) — Maximum token length (default: 512)
  • pooling_strategy (str) — "mean", "cls", or "last_token" (use "last_token" for decoder-based models)
  • load_in_4bit (bool) — Load in 4-bit quantization

FastEmbeddingModel.get_peft_model(model, r=16, lora_alpha=16, ...) → EmbeddingModelWrapper

Apply LoRA adapters. Targets are auto-detected from model architecture.

  • r (int) — LoRA rank (default: 16)
  • target_modules (list, optional) — Override auto-detected targets
  • lora_alpha (int) — LoRA scaling factor (default: 16)
  • lora_dropout (float) — Dropout rate (default: 0.0)

from mlx_tune import FastEmbeddingModel

# BERT-based (mean pooling)
model, tokenizer = FastEmbeddingModel.from_pretrained(
    "mlx-community/all-MiniLM-L6-v2-bf16",
)
# Decoder-based (last-token pooling)
model, tokenizer = FastEmbeddingModel.from_pretrained(
    "microsoft/harrier-oss-v1-0.6b",
    pooling_strategy="last_token",
)
model = FastEmbeddingModel.get_peft_model(model, r=8, lora_alpha=16)

EmbeddingSFTTrainer

EmbeddingSFTTrainer(model, tokenizer, data_collator, train_dataset, args)

Train embedding models with contrastive learning. Supports InfoNCE (in-batch negatives), cosine embedding, and triplet loss.

  • loss_type (default "infonce") — "infonce", "cosine", or "triplet"
  • temperature (default 0.05) — InfoNCE temperature (lower = sharper)
  • per_device_train_batch_size (default 32) — Batch size (larger = more in-batch negatives)
  • learning_rate (default 2e-5) — Optimizer learning rate
  • max_steps — Total training steps
  • max_seq_length (default 512) — Maximum token length
  • normalize_embeddings (default True) — L2-normalize embeddings before loss

from mlx_tune import EmbeddingSFTTrainer, EmbeddingSFTConfig, EmbeddingDataCollator

train_data = [
    {"anchor": "What is LoRA?", "positive": "Low-rank adaptation for efficient fine-tuning."},
    # ...
]

trainer = EmbeddingSFTTrainer(
    model=model, tokenizer=tokenizer,
    data_collator=EmbeddingDataCollator(model, tokenizer),
    train_dataset=train_data,
    args=EmbeddingSFTConfig(
        loss_type="infonce", temperature=0.05,
        per_device_train_batch_size=10, max_steps=30,
    ),
)
trainer.train()

# Encode and compare
embeddings = model.encode(["query text", "document text"])
similarity = (embeddings[0] * embeddings[1]).sum().item()
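The InfoNCE objective used above treats each anchor's own positive as the correct "class" among all in-batch positives. A plain-numpy sketch of the loss, independent of mlx-tune:

```python
import numpy as np

def infonce_loss(anchors, positives, temperature=0.05):
    """In-batch InfoNCE: row i of the similarity matrix should peak at column i."""
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)      # L2-normalize
    p = positives / np.linalg.norm(positives, axis=1, keepdims=True)
    logits = a @ p.T / temperature               # (batch, batch) scaled cosine sims
    logits -= logits.max(axis=1, keepdims=True)  # stabilize the softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))          # cross-entropy, labels = diagonal

rng = np.random.default_rng(0)
anchors = rng.normal(size=(8, 16))
loss_matched = infonce_loss(anchors, anchors)                  # perfect pairs
loss_random = infonce_loss(anchors, rng.normal(size=(8, 16)))  # unrelated positives
```

Larger batches add more in-batch negatives per row, which is why the default batch size for embedding training is comparatively high.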

Examples

  • Example 27 — BERT/MiniLM embedding fine-tuning
  • Example 28 — Qwen3-Embedding fine-tuning (4-bit)
  • Example 31 — Microsoft Harrier 0.6B (cross-lingual search)
  • Example 32 — Microsoft Harrier 270M (code/doc search)

Chat Templates

mlx_tune.chat_templates

get_chat_template()

get_chat_template(tokenizer, chat_template="auto", ...) → Tokenizer

Apply a chat template to the tokenizer. Supports 15 model families with "auto" detection from model name.

  • llama-3 (aliases: llama3, llama-3.1, llama-3.2)
  • gemma (aliases: gemma-2, gemma2)
  • qwen-2.5 (aliases: qwen25, qwen2.5)
  • qwen-3 (alias: qwen3)
  • phi-3 (alias: phi3)
  • phi-4 (alias: phi4)
  • mistral-7b (alias: mistral)
  • deepseek (alias: deepseek-v2)
  • command-r (alias: cohere)
  • llama-2 (alias: llama2)
  • neural-chat
  • solar
  • tulu-2
  • zephyr
  • alpaca
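A usage sketch: resolve the template, then format a conversation, assuming the returned tokenizer exposes the standard HuggingFace apply_chat_template method (the message content is illustrative):

```python
from mlx_tune import FastLanguageModel, get_chat_template

model, tokenizer = FastLanguageModel.from_pretrained(
    "mlx-community/Llama-3.2-1B-Instruct-4bit",
)
# "auto" picks the template from the model name (here: llama-3)
tokenizer = get_chat_template(tokenizer, chat_template="auto")

messages = [{"role": "user", "content": "Hello!"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
```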

train_on_responses_only()

train_on_responses_only(trainer, instruction_part, response_part, ...) → Trainer

Modify trainer to compute loss only on assistant response tokens (prompt tokens are masked). Significantly improves training quality.

from mlx_tune import train_on_responses_only

trainer = train_on_responses_only(
    trainer,
    instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
    response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
)

Dataset Utilities

to_sharegpt(dataset, conversation_extension="auto", ...) → dataset — Merge multi-turn conversations
detect_dataset_format(sample) → str — Returns "alpaca", "sharegpt", or "chatml"
apply_column_mapping(dataset, column_mapping) → dataset — Rename columns with a mapping dict
infer_column_mapping(dataset) → Dict — Auto-detect column mapping from dataset
HFDatasetConfig(dataset_name, split, ...) — Structured HuggingFace dataset configuration
load_dataset_with_config(config) → dataset — Load dataset from an HFDatasetConfig
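For reference, the three record shapes detect_dataset_format() distinguishes (field contents are illustrative; the return values follow the signature above):

```python
# Alpaca: instruction / input / output columns
alpaca = {"instruction": "Translate to French.", "input": "Hello", "output": "Bonjour"}

# ShareGPT: a "conversations" list of from/value turns
sharegpt = {"conversations": [{"from": "human", "value": "Hi"},
                              {"from": "gpt", "value": "Hello!"}]}

# ChatML: a "messages" list of role/content turns
chatml = {"messages": [{"role": "user", "content": "Hi"},
                       {"role": "assistant", "content": "Hello!"}]}

# detect_dataset_format(alpaca)  # expected: "alpaca"
```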

Save, merge, and export

After training, you have several options for saving and deploying your model.

Save LoRA adapters

model.save_pretrained("./lora_model")

Merge LoRA into base model

model.save_pretrained_merged("./merged", tokenizer)

Export to GGUF

model.save_pretrained_gguf("./gguf", tokenizer, quantization_method="q4_k_m")

Convert HuggingFace model to MLX

FastLanguageModel.convert("meta-llama/Llama-3-8B", quantize=True)

Push to HuggingFace Hub

model.push_to_hub("username/my-model")

GGUF Limitation

GGUF export works with non-quantized models only. If you loaded with load_in_4bit=True, you must reload the base model without quantization before exporting to GGUF. This is an mlx-lm limitation.

Example scripts

Ready-to-run examples covering the full LLM fine-tuning workflow.

  • 01–03 — Basics: model loading, LoRA configuration, inference
  • 04–07 — SFT training: dataset preparation, training loop, saving
  • 08 — Full Unsloth-compatible SFT pipeline (recommended starting point)
  • 09 — RL training overview: all 5 methods in one script
  • 21 — DPO: end-to-end preference tuning with Qwen3.5
  • 22 — GRPO: reasoning training with custom reward functions (DeepSeek R1 style)
  • 23 — ORPO: combined SFT + preference alignment (no reference model)
  • 24 — KTO: binary feedback training (no paired preferences needed)
  • 25 — SimPO: simple preference optimization (no reference model)
  • 29 — MoE: Qwen3.5-35B-A3B fine-tuning (35B total, 3B active)
  • 30 — MoE: Phi-3.5-MoE fine-tuning (42B total, Microsoft)
  • 54 — Trinity-Nano (AFMoE): instruction SFT on the 6B/1B-active MoE with 128 experts + 1 shared
  • 55 — Trinity-Nano (AFMoE): GRPO reasoning with <reasoning>/<answer> reward shaping
  • 56 — Trinity-Nano (AFMoE): continual pretraining with decoupled embedding LR
  • 41 — LFM2: Liquid Foundation Model SFT fine-tuning (hybrid gated-conv + GQA)
  • 42 — LFM2: thinking/reasoning fine-tuning
  • 43 — CPT: language adaptation with continual pretraining
  • 44 — CPT: domain knowledge injection (medical/legal/scientific)
  • 45 — CPT: code capabilities extension
  • 46 — CPT + LFM2: continual pretraining on Liquid Foundation Model
Browse all examples on GitHub · See the Examples page for code snippets.