Vision Model Fine-Tuning
Fine-tune vision-language models such as Gemma 4 and Qwen3.5 on Apple Silicon, with real LoRA on both the vision encoder and the language model.
VLM fine-tuning in 25 lines
A complete Qwen3.5 Vision fine-tuning pipeline.
```python
from mlx_tune import FastVisionModel, VLMSFTTrainer, VLMSFTConfig, UnslothVisionDataCollator
from datasets import load_dataset

# Load a vision-language model
model, processor = FastVisionModel.from_pretrained(
    "mlx-community/Qwen3.5-0.8B-bf16",
    max_seq_length=1024,
)

# Add LoRA to both vision encoder and language model
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers=True,
    finetune_language_layers=True,
    r=16, lora_alpha=16,
)

# Load your dataset
dataset = load_dataset("your-dataset", split="train[:100]")

# Train
FastVisionModel.for_training(model)
collator = UnslothVisionDataCollator(model, processor)
trainer = VLMSFTTrainer(
    model=model, tokenizer=processor,
    data_collator=collator, train_dataset=dataset,
    args=VLMSFTConfig(output_dir="./vlm_output", max_steps=30),
)
trainer.train()

# Inference
FastVisionModel.for_inference(model)
response = model.generate("Describe this image.", image="photo.jpg")
print(response)
```
FastVisionModel
mlx_tune.vlm

FastVisionModel.from_pretrained()
Load a vision-language model from HuggingFace. Returns a model wrapper and a processor (not a tokenizer).
| Parameter | Type | Description |
|---|---|---|
| model_name | str | HuggingFace model ID (e.g., "mlx-community/Qwen3.5-0.8B-bf16") or local path |
| max_seq_length | int, optional | Maximum sequence length for training/inference |
| load_in_4bit | bool | Load the model with 4-bit quantization |
Unlike FastLanguageModel, this returns a processor as the second value (not a tokenizer). The processor handles both text tokenization and image preprocessing.
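For memory-constrained setups, load_in_4bit is the main lever. A minimal sketch, assuming the checkpoint supports on-load 4-bit quantization:

```python
from mlx_tune import FastVisionModel

# Quantize weights to 4-bit at load time (assumes the checkpoint supports this path);
# this trades a little quality for a large reduction in unified-memory use.
model, processor = FastVisionModel.from_pretrained(
    "mlx-community/Qwen3.5-0.8B-bf16",
    max_seq_length=1024,
    load_in_4bit=True,
)
```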
FastVisionModel.get_peft_model()
Add LoRA adapters to vision, audio, and/or language components of the model.
| Parameter | Type | Default | Description |
|---|---|---|---|
| finetune_vision_layers | bool | True | Apply LoRA to vision encoder layers |
| finetune_language_layers | bool | True | Apply LoRA to language model layers |
| finetune_audio_layers | bool | False | Apply LoRA to audio tower layers (Gemma 4 E2B/E4B Conformer) |
| finetune_attention_modules | bool | True | Apply LoRA to attention modules (q, k, v, o projections) |
| finetune_mlp_modules | bool | True | Apply LoRA to MLP/feed-forward modules |
| r | int | 16 | LoRA rank (higher = more parameters, better quality) |
| lora_alpha | int | 16 | LoRA scaling factor. Recommended: equal to r |
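For text-heavy tasks you can restrict LoRA to the language side; a sketch using only the parameters documented above, continuing from the loading snippet:

```python
# Freeze the vision encoder and adapt only the language model.
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers=False,     # leave the vision encoder untouched
    finetune_language_layers=True,    # LoRA on the language model
    finetune_attention_modules=True,  # q, k, v, o projections
    finetune_mlp_modules=True,        # feed-forward layers
    r=16, lora_alpha=16,              # alpha = r, as recommended
)
```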
FastVisionModel.for_training()
Enable training mode. Required before starting any training loop.
FastVisionModel.for_inference()
Enable inference mode. Activates KV caching and disables dropout. Always call before generating.
VLMModelWrapper
mlx_tune.vlm

Wrapper returned by FastVisionModel.from_pretrained(). Provides Unsloth-compatible methods for generation, saving, and loading.
model.generate()
Generate a response from a text prompt with an optional image or audio. Works for vision+text, audio+text, and text-only inputs.
| Parameter | Type | Description |
|---|---|---|
| prompt | str | Text prompt for the model |
| image | str, optional | Path to an image file, or None for text-only |
| audio | str, optional | Path to a .wav audio file for STT/ASR (Gemma 4 E2B/E4B only) |
| max_tokens | int | Maximum number of tokens to generate |
| temperature | float | Sampling temperature (0.0 = greedy) |
| min_p | float | Minimum probability threshold for sampling (recommended: 0.1) |
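Putting the sampling parameters together (the prompt, path, and values are illustrative, continuing from a loaded model):

```python
# Switch to inference mode, then generate with min_p sampling.
FastVisionModel.for_inference(model)

response = model.generate(
    "Describe this image in one sentence.",
    image="photo.jpg",   # omit (or pass None) for a text-only prompt
    max_tokens=128,      # cap on generated tokens
    temperature=0.7,     # 0.0 would be greedy decoding
    min_p=0.1,           # recommended minimum-probability threshold
)
print(response)
```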
model.save_pretrained()
Save LoRA adapters to disk. Writes adapters.safetensors, adapter_config.json (mlx-lm compatible format), and config.json.
model.load_adapter()
Load previously saved LoRA adapters into the model.
model.save_pretrained_merged()
Fuse LoRA weights into the base model and save the full merged model.
Training
mlx_tune.vlm

VLMSFTTrainer
Native MLX training loop for vision-language models. Handles forward pass, loss computation, and gradient updates internally.
| Parameter | Type | Description |
|---|---|---|
| model | VLMModelWrapper | Model with LoRA adapters configured |
| tokenizer | Processor | Processor from FastVisionModel.from_pretrained() |
| data_collator | UnslothVisionDataCollator | Data collator for image/text batching |
| train_dataset | Dataset | HuggingFace dataset with image and text fields |
| args | VLMSFTConfig | Training configuration |
Batch size is forced to 1. Each image produces a different number of vision tokens (e.g., Qwen models generate a different num_patches per image), so samples cannot be stacked into larger batches.
VLMSFTConfig
Training configuration for VLM fine-tuning. Compatible with TRL’s SFTConfig parameters.
| Parameter | Default | Description |
|---|---|---|
| per_device_train_batch_size | 1 | Batch size (forced to 1 for VLM) |
| max_steps | 30 | Total training steps |
| learning_rate | 2e-4 | Peak learning rate |
| output_dir | "./vlm_outputs" | Directory for checkpoints and logs |
| train_on_completions | True | Compute loss only on assistant response tokens |
| gradient_accumulation_steps | 4 | Number of steps to accumulate gradients before updating |
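Since the per-device batch size is pinned to 1, gradient_accumulation_steps is how you raise the effective batch size. An illustrative configuration:

```python
# Effective batch size = per_device_train_batch_size (1) x gradient_accumulation_steps (8).
config = VLMSFTConfig(
    output_dir="./vlm_output",
    max_steps=100,
    learning_rate=2e-4,
    gradient_accumulation_steps=8,  # accumulate 8 samples before each optimizer update
    train_on_completions=True,      # loss only on assistant response tokens (the default)
)
```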
UnslothVisionDataCollator
Data collator for vision tasks. Handles image preprocessing, vision token insertion via the processor’s chat template, and batch preparation for VLM training.
| Parameter | Type | Description |
|---|---|---|
| model | VLMModelWrapper | The vision-language model |
| processor | Processor | Processor from FastVisionModel.from_pretrained() |
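The collator consumes chat-style samples whose content lists interleave typed image and text entries (the audio note under Best practices uses the same convention). The exact field layout below is an assumption; verify it against your dataset.

```python
# Illustrative sample layout (field names are assumptions, mirroring the typed-content convention).
sample = {
    "messages": [
        {"role": "user", "content": [
            {"type": "image", "image": "photo.jpg"},
            {"type": "text", "text": "What is in this image?"},
        ]},
        {"role": "assistant", "content": [
            {"type": "text", "text": "A cat sleeping on a sofa."},
        ]},
    ],
}

# The collator turns such samples into model-ready batches via the processor's chat template.
collator = UnslothVisionDataCollator(model, processor)
```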
VLMGRPOTrainer
Group Relative Policy Optimization (GRPO) trainer for vision-language models. Generates multiple completions per prompt, scores them with a reward function, and updates the policy using relative advantages within each group.
| Parameter | Type | Description |
|---|---|---|
| model | VLMModelWrapper | Model with LoRA adapters from FastVisionModel.get_peft_model() |
| train_dataset | list[dict] | List of dicts with prompt (str), image (PIL Image or file path), and answer (str) keys |
| processor | Processor | Processor from FastVisionModel.from_pretrained() |
| reward_fn | Callable | Function (response_text, ground_truth) → float that scores each completion |
| args | VLMGRPOConfig | Training configuration |
For each sample, GRPO generates num_generations completions, scores them with your reward_fn, computes advantages relative to the group mean, and updates the model to favor higher-reward responses. A KL penalty (controlled by beta) prevents the policy from diverging too far from the reference.
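The group-relative advantage step is simple enough to sketch directly. The snippet below illustrates the idea rather than the library's internal code, and the standard-deviation normalization is an assumption borrowed from standard GRPO formulations:

```python
def exact_match_reward(response, answer):
    # Binary reward: 1.0 if the ground truth appears in the response.
    return 1.0 if answer.lower() in response.lower() else 0.0

def group_advantages(completions, answer, reward_fn):
    # Score each completion in the group, then center (and normalize) within the group.
    rewards = [reward_fn(c, answer) for c in completions]
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = var ** 0.5 if var > 0 else 1.0  # avoid division by zero when all rewards are equal
    return [(r - mean) / std for r in rewards]

# Two completions for the same prompt: the matching one gets a positive advantage.
print(group_advantages(["A cat on a sofa.", "A dog in a park."], "a cat", exact_match_reward))
# -> [1.0, -1.0]
```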
VLMGRPOConfig
Configuration for VLM GRPO training.
| Parameter | Default | Description |
|---|---|---|
| beta | 0.04 | KL penalty coefficient. Higher values keep the model closer to the reference policy |
| num_generations | 2 | Number of completions generated per prompt for advantage estimation |
| temperature | 0.7 | Sampling temperature for generation (higher = more diverse completions) |
| max_completion_length | 128 | Maximum tokens per generated completion |
| output_dir | "./vlm_grpo_outputs" | Directory for checkpoints and logs |
| learning_rate | 1e-6 | Learning rate (typically lower than SFT) |
| max_steps | -1 | Maximum training steps. -1 trains for one full epoch |
| logging_steps | 1 | Log training metrics every N steps |
| save_steps | 100 | Save checkpoint every N steps |
GRPO usage example
```python
from mlx_tune import FastVisionModel
from mlx_tune.vlm import VLMGRPOTrainer, VLMGRPOConfig

# Load and configure model
model, processor = FastVisionModel.from_pretrained(
    "mlx-community/Qwen3.5-0.8B-bf16",
)
model = FastVisionModel.get_peft_model(model, r=16, lora_alpha=16)

# Define a reward function
def reward_fn(response, answer):
    return 1.0 if answer.lower() in response.lower() else 0.0

# Prepare dataset: list of dicts with prompt, image, answer
vision_data = [
    {"prompt": "What is in this image?", "image": "photo.jpg", "answer": "a cat"},
    # ...
]

# Train with GRPO
FastVisionModel.for_training(model)
trainer = VLMGRPOTrainer(
    model=model,
    train_dataset=vision_data,
    processor=processor,
    reward_fn=reward_fn,
    args=VLMGRPOConfig(num_generations=2, max_steps=10),
)
result = trainer.train()
```
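Rewards need not be binary: reward_fn simply maps (response_text, ground_truth) to a float, so graded scores work too. A sketch that gives partial credit for keyword overlap:

```python
# Graded reward: fraction of ground-truth keywords that appear in the response.
def keyword_reward(response, answer):
    keywords = answer.lower().split()
    if not keywords:
        return 0.0
    hits = sum(1 for word in keywords if word in response.lower())
    return hits / len(keywords)
```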
Save, load, and merge
After training, you can save adapters for later use, reload them into a fresh model, or merge LoRA weights into the base model.
Save adapters
```python
# Save LoRA adapters only (small files)
model.save_pretrained("./vlm_adapters")
```
Load adapters
```python
# Load adapters into a fresh model
model, processor = FastVisionModel.from_pretrained(
    "mlx-community/Qwen3.5-0.8B-bf16",
    max_seq_length=1024,
)
model.load_adapter("./vlm_adapters")
```
Merge LoRA into base model
```python
# Fuse LoRA weights and save full model
model.save_pretrained_merged("./vlm_merged", processor)
```
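Because from_pretrained accepts local paths, a merged checkpoint can presumably be reloaded like any other model directory; treat the sketch below as an assumption to verify:

```python
# Load the fused model back without any adapters.
model, processor = FastVisionModel.from_pretrained(
    "./vlm_merged",       # local path to the merged checkpoint
    max_seq_length=1024,
)
FastVisionModel.for_inference(model)
print(model.generate("Describe this image.", image="photo.jpg"))
```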
Working examples
Complete scripts you can run directly.
10 — Qwen3.5 Vision Fine-Tuning
Fine-tune Qwen3.5 on image-text pairs with LoRA on both vision and language layers.
11 — Qwen3.5 Text-Only VLM Fine-Tuning
Fine-tune a VLM on text-only data without images. Useful for improving the language component.
26 — Vision GRPO Training
End-to-end GRPO training for VLMs: completions are generated conditioned on image+text prompts and scored by reward functions.
47 — Gemma 4 Audio ASR
Fine-tune Gemma 4 E4B for speech-to-text via the built-in Conformer audio tower. Language-layer LoRA.
48 — Gemma 4 Audio Understanding
Audio QA with optional audio tower LoRA for domain-specific acoustic adaptation.
Best practices
Gemma 4 E2B/E4B have a 12-layer Conformer audio tower for STT/ASR. Use {"type": "audio", "audio": "path.wav"} in your dataset messages. Audio goes through the VLM pipeline (FastVisionModel), not FastSTTModel. Set finetune_audio_layers=True to apply LoRA to the Conformer attention layers for domain-specific acoustic adaptation. The 26B and 31B variants do not have audio towers.
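Following the {"type": "audio", "audio": "path.wav"} convention above, an ASR training sample might look like the sketch below; everything beyond the audio entry is an assumption to check against your dataset.

```python
# Illustrative ASR sample for Gemma 4 E2B/E4B (structure assumed beyond the audio entry).
sample = {
    "messages": [
        {"role": "user", "content": [
            {"type": "audio", "audio": "clip_001.wav"},
            {"type": "text", "text": "Transcribe this audio."},
        ]},
        {"role": "assistant", "content": [
            {"type": "text", "text": "hello world this is a test"},  # ground-truth transcript
        ]},
    ],
}
```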
VLM training is forced to batch_size=1. Each image produces a different number of vision tokens (variable num_patches), so samples cannot be stacked into batches. Use gradient_accumulation_steps to simulate larger effective batch sizes.
Qwen3.5 models generate <think>...</think> reasoning tags in their output. mlx-tune automatically strips these from generate() results so you get clean responses.
You can fine-tune a VLM on text-only data (no images needed). This is useful for improving the language model component while preserving vision capabilities.
Completion-only training is enabled by default via train_on_completions=True in VLMSFTConfig. Loss is computed only on assistant response tokens, not on the prompt or system message, which significantly improves training quality.
Always call FastVisionModel.for_training(model) before training and FastVisionModel.for_inference(model) before generation. Forgetting this will cause incorrect behavior or errors.
16 GB+ unified RAM is recommended for 3B VLMs; larger models (7B+) need 32 GB or more. Use 4-bit quantized models from mlx-community/ to reduce memory usage.