Vision Model Fine-Tuning

Fine-tune vision-language models such as Gemma 4 and Qwen3.5 on Apple Silicon, with real LoRA on both the vision encoder and the language model.

VLM fine-tuning in 25 lines

A complete Qwen3.5 Vision fine-tuning pipeline.

from mlx_tune import FastVisionModel, VLMSFTTrainer, VLMSFTConfig, UnslothVisionDataCollator
from datasets import load_dataset

# Load a vision-language model
model, processor = FastVisionModel.from_pretrained(
    "mlx-community/Qwen3.5-0.8B-bf16",
    max_seq_length=1024,
)

# Add LoRA to both vision encoder and language model
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers=True,
    finetune_language_layers=True,
    r=16, lora_alpha=16,
)

# Load your dataset
dataset = load_dataset("your-dataset", split="train[:100]")

# Train
FastVisionModel.for_training(model)
collator = UnslothVisionDataCollator(model, processor)
trainer = VLMSFTTrainer(
    model=model, tokenizer=processor,
    data_collator=collator, train_dataset=dataset,
    args=VLMSFTConfig(output_dir="./vlm_output", max_steps=30),
)
trainer.train()

# Inference
FastVisionModel.for_inference(model)
response = model.generate("Describe this image.", image="photo.jpg")
print(response)

FastVisionModel

mlx_tune.vlm

FastVisionModel.from_pretrained()

FastVisionModel.from_pretrained(model_name, max_seq_length=None, load_in_4bit=False, ...) → Tuple[VLMModelWrapper, Processor]

Load a vision-language model from HuggingFace. Returns a model wrapper and a processor (not a tokenizer).

Parameter | Type | Description
model_name | str | HuggingFace model ID (e.g., "mlx-community/Qwen3.5-0.8B-bf16") or local path
max_seq_length | int, optional | Maximum sequence length for training/inference
load_in_4bit | bool | Load model with 4-bit quantization
Note

Unlike FastLanguageModel, this returns a processor as the second value (not a tokenizer). The processor handles both text tokenization and image preprocessing.
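
For larger models, load_in_4bit=True reduces memory use at load time. A minimal sketch, reusing the illustrative checkpoint from the quickstart; substitute whichever model you actually use:

from mlx_tune import FastVisionModel

# Sketch: load a VLM with 4-bit quantization to reduce memory.
# The model ID is illustrative; use any mlx-community checkpoint.
model, processor = FastVisionModel.from_pretrained(
    "mlx-community/Qwen3.5-0.8B-bf16",
    max_seq_length=2048,
    load_in_4bit=True,
)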

FastVisionModel.get_peft_model()

FastVisionModel.get_peft_model(model, finetune_vision_layers=True, finetune_language_layers=True, finetune_audio_layers=False, finetune_attention_modules=True, finetune_mlp_modules=True, r=16, lora_alpha=16, ...) → VLMModelWrapper

Add LoRA adapters to vision, audio, and/or language components of the model.

Parameter | Type | Default | Description
finetune_vision_layers | bool | True | Apply LoRA to vision encoder layers
finetune_language_layers | bool | True | Apply LoRA to language model layers
finetune_audio_layers | bool | False | Apply LoRA to audio tower layers (Gemma 4 E2B/E4B Conformer)
finetune_attention_modules | bool | True | Apply LoRA to attention modules (q, k, v, o projections)
finetune_mlp_modules | bool | True | Apply LoRA to MLP/feed-forward modules
r | int | 16 | LoRA rank (higher = more parameters, better quality)
lora_alpha | int | 16 | LoRA scaling factor. Recommended: equal to r
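
For example, to adapt only the language model while keeping the vision encoder frozen (a common setup for text-only fine-tuning), flip the corresponding flags. A sketch using only the parameters documented above:

# Sketch: LoRA on the language model only; vision encoder stays frozen.
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers=False,     # do not adapt the vision encoder
    finetune_language_layers=True,    # adapt the language model
    finetune_attention_modules=True,  # LoRA on q, k, v, o projections
    finetune_mlp_modules=True,        # LoRA on feed-forward layers
    r=16, lora_alpha=16,
)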

FastVisionModel.for_training()

FastVisionModel.for_training(model)

Enable training mode. Required before starting any training loop.

FastVisionModel.for_inference()

FastVisionModel.for_inference(model)

Enable inference mode. Activates KV caching and disables dropout. Always call before generating.

VLMModelWrapper

mlx_tune.vlm

Wrapper returned by FastVisionModel.from_pretrained(). Provides Unsloth-compatible methods for generation, saving, and loading.

model.generate()

model.generate(prompt, image=None, audio=None, max_tokens=256, temperature=0.0, min_p=0.0) → str

Generate a response from a text prompt with an optional image or audio. Works for vision+text, audio+text, and text-only inputs.

Parameter | Type | Description
prompt | str | Text prompt for the model
image | str, optional | Path to an image file, or None for text-only
audio | str, optional | Path to a .wav audio file for STT/ASR (Gemma 4 E2B/E4B only)
max_tokens | int | Maximum number of tokens to generate
temperature | float | Sampling temperature (0.0 = greedy)
min_p | float | Minimum probability threshold for sampling (recommended: 0.1)
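
A short usage sketch (prompts and the file path are placeholders):

# Image + text with light sampling.
caption = model.generate(
    "Describe this image.",
    image="photo.jpg",   # placeholder path
    max_tokens=128,
    temperature=0.7,
    min_p=0.1,           # recommended when sampling
)

# Text-only: omit image and audio.
answer = model.generate("Summarize LoRA in one sentence.", max_tokens=64)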

model.save_pretrained()

model.save_pretrained(output_dir)

Save LoRA adapters to disk. Writes adapters.safetensors, adapter_config.json (mlx-lm compatible format), and config.json.

model.load_adapter()

model.load_adapter(adapter_path)

Load previously saved LoRA adapters into the model.

model.save_pretrained_merged()

model.save_pretrained_merged(output_dir, processor)

Fuse LoRA weights into the base model and save the full merged model.

Training

mlx_tune.vlm

VLMSFTTrainer

VLMSFTTrainer(model, tokenizer, data_collator, train_dataset, args=None)

Native MLX training loop for vision-language models. Handles forward pass, loss computation, and gradient updates internally.

Parameter | Type | Description
model | VLMModelWrapper | Model with LoRA adapters configured
tokenizer | Processor | Processor from FastVisionModel.from_pretrained()
data_collator | UnslothVisionDataCollator | Data collator for image/text batching
train_dataset | Dataset | HuggingFace dataset with image and text fields
args | VLMSFTConfig | Training configuration
trainer.train() — Start training. Returns training statistics.
Important

Batch size is forced to 1. Images produce variable vision token counts per sample (e.g., Qwen models generate different num_patches per image), so batching is not possible.

VLMSFTConfig

VLMSFTConfig(per_device_train_batch_size=1, max_steps=30, learning_rate=2e-4, output_dir="./vlm_outputs", train_on_completions=True, gradient_accumulation_steps=4, ...)

Training configuration for VLM fine-tuning. Compatible with TRL’s SFTConfig parameters.

Parameter | Default | Description
per_device_train_batch_size | 1 | Batch size (forced to 1 for VLM)
max_steps | 30 | Total training steps
learning_rate | 2e-4 | Peak learning rate
output_dir | "./vlm_outputs" | Directory for checkpoints and logs
train_on_completions | True | Compute loss only on assistant response tokens
gradient_accumulation_steps | 4 | Number of steps to accumulate gradients before updating
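
Because the batch size is fixed at 1, gradient_accumulation_steps controls the effective batch size. A sketch of a typical configuration using the fields documented above:

# Sketch: effective batch size of 8 via gradient accumulation.
config = VLMSFTConfig(
    output_dir="./vlm_outputs",
    max_steps=100,
    learning_rate=2e-4,
    gradient_accumulation_steps=8,  # 1 sample/step x 8 accumulation steps
    train_on_completions=True,      # loss on assistant response tokens only
)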

UnslothVisionDataCollator

UnslothVisionDataCollator(model, processor)

Data collator for vision tasks. Handles image preprocessing, vision token insertion via the processor’s chat template, and batch preparation for VLM training.

Parameter | Type | Description
model | VLMModelWrapper | The vision-language model
processor | Processor | Processor from FastVisionModel.from_pretrained()
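
The collator consumes conversation-style samples. The schema below follows the common Unsloth/HuggingFace multimodal message convention and is an assumption, not a guarantee of mlx_tune's exact required format; check the bundled example scripts for the authoritative layout:

# Assumed sample format (Unsloth-style multimodal messages).
sample = {
    "messages": [
        {"role": "user", "content": [
            {"type": "image", "image": "photo.jpg"},
            {"type": "text", "text": "What is in this image?"},
        ]},
        {"role": "assistant", "content": [
            {"type": "text", "text": "A cat sitting on a windowsill."},
        ]},
    ]
}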

VLMGRPOTrainer

VLMGRPOTrainer(model, train_dataset, processor, reward_fn, args=None)

Group Relative Policy Optimization (GRPO) trainer for vision-language models. Generates multiple completions per prompt, scores them with a reward function, and updates the policy using relative advantages within each group.

Parameter | Type | Description
model | VLMModelWrapper | Model with LoRA adapters from FastVisionModel.get_peft_model()
train_dataset | list[dict] | List of dicts with prompt (str), image (PIL Image or file path), and answer (str) keys
processor | Processor | Processor from FastVisionModel.from_pretrained()
reward_fn | Callable | Function (response_text, ground_truth) → float that scores each completion
args | VLMGRPOConfig | Training configuration
trainer.train() — Start GRPO training. Returns training statistics.
How GRPO works

For each sample, GRPO generates num_generations completions, scores them with your reward_fn, computes advantages relative to the group mean, and updates the model to favor higher-reward responses. A KL penalty (controlled by beta) prevents the policy from diverging too far from the reference.
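
A self-contained sketch of the group-relative advantage computation (illustrative only, not the trainer's internal code):

# Illustrative: score a group of completions and compute advantages.
completions = ["a cat", "a dog", "a cat on a mat", "a bird"]
ground_truth = "a cat"

def reward_fn(response, answer):
    return 1.0 if answer.lower() in response.lower() else 0.0

rewards = [reward_fn(c, ground_truth) for c in completions]  # [1.0, 0.0, 1.0, 0.0]
mean_reward = sum(rewards) / len(rewards)                    # 0.5
advantages = [r - mean_reward for r in rewards]              # [0.5, -0.5, 0.5, -0.5]
# The policy update favors completions with positive advantage;
# the KL penalty (beta) keeps the policy close to the reference model.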

VLMGRPOConfig

VLMGRPOConfig(beta=0.04, num_generations=2, temperature=0.7, max_completion_length=128, output_dir="./vlm_grpo_outputs", learning_rate=1e-6, max_steps=-1, logging_steps=1, save_steps=100)

Configuration for VLM GRPO training.

Parameter | Default | Description
beta | 0.04 | KL penalty coefficient. Higher values keep the model closer to the reference policy
num_generations | 2 | Number of completions generated per prompt for advantage estimation
temperature | 0.7 | Sampling temperature for generation (higher = more diverse completions)
max_completion_length | 128 | Maximum tokens per generated completion
output_dir | "./vlm_grpo_outputs" | Directory for checkpoints and logs
learning_rate | 1e-6 | Learning rate (typically lower than SFT)
max_steps | -1 | Maximum training steps. -1 trains for one full epoch
logging_steps | 1 | Log training metrics every N steps
save_steps | 100 | Save checkpoint every N steps

GRPO usage example

from mlx_tune import FastVisionModel
from mlx_tune.vlm import VLMGRPOTrainer, VLMGRPOConfig

# Load and configure model
model, processor = FastVisionModel.from_pretrained(
    "mlx-community/Qwen3.5-0.8B-bf16",
)
model = FastVisionModel.get_peft_model(model, r=16, lora_alpha=16)

# Define a reward function
def reward_fn(response, answer):
    return 1.0 if answer.lower() in response.lower() else 0.0

# Prepare dataset: list of dicts with prompt, image, answer
vision_data = [
    {"prompt": "What is in this image?", "image": "photo.jpg", "answer": "a cat"},
    # ...
]

# Train with GRPO
FastVisionModel.for_training(model)
trainer = VLMGRPOTrainer(
    model=model,
    train_dataset=vision_data,
    processor=processor,
    reward_fn=reward_fn,
    args=VLMGRPOConfig(num_generations=2, max_steps=10),
)
result = trainer.train()

Save, load, and merge

After training, you can save adapters for later use, reload them into a fresh model, or merge LoRA weights into the base model.

Save adapters

# Save LoRA adapters only (small files)
model.save_pretrained("./vlm_adapters")

Load adapters

# Load adapters into a fresh model
model, processor = FastVisionModel.from_pretrained(
    "mlx-community/Qwen3.5-0.8B-bf16",
    max_seq_length=1024,
)
model.load_adapter("./vlm_adapters")

Merge LoRA into base model

# Fuse LoRA weights and save full model
model.save_pretrained_merged("./vlm_merged", processor)

Working examples

Complete scripts you can run directly.

10 — Qwen3.5 Vision Fine-Tuning

Fine-tune Qwen3.5 on image-text pairs with LoRA on both vision and language layers.

Tags: VLM, Vision, Qwen

View source →

11 — Qwen3.5 Text-Only VLM Fine-Tuning

Fine-tune a VLM on text-only data without images. Useful for improving the language component.

Tags: VLM, Text-Only, Qwen

View source →

26 — Vision GRPO Training

End-to-end Vision GRPO training extending GRPO to VLMs. Generates completions conditioned on image+text prompts, scored by reward functions.

Tags: VLM, GRPO, RL

View source →

47 — Gemma 4 Audio ASR

Fine-tune Gemma 4 E4B for speech-to-text via the built-in Conformer audio tower. Language-layer LoRA.

Tags: VLM, Audio, Gemma 4

View source →

48 — Gemma 4 Audio Understanding

Audio QA with optional audio tower LoRA for domain-specific acoustic adaptation.

Tags: VLM, Audio, Audio LoRA

View source →

Best practices

Gemma 4 Audio Fine-Tuning

Gemma 4 E2B/E4B have a 12-layer Conformer audio tower for STT/ASR. Use {"type": "audio", "audio": "path.wav"} in your dataset messages. Audio goes through the VLM pipeline (FastVisionModel), not FastSTTModel. Set finetune_audio_layers=True to apply LoRA to the Conformer attention layers for domain-specific acoustic adaptation. The 26B and 31B variants do not have audio towers.
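
A sketch of an ASR fine-tuning setup under these constraints. The model ID is illustrative, and the message structure around the documented audio entry is assumed; only the get_peft_model flags come from the reference above:

from mlx_tune import FastVisionModel

# Illustrative model ID; substitute a real Gemma 4 E2B/E4B checkpoint.
model, processor = FastVisionModel.from_pretrained("mlx-community/gemma-4-E2B-bf16")

model = FastVisionModel.get_peft_model(
    model,
    finetune_audio_layers=True,     # LoRA on the Conformer audio tower
    finetune_language_layers=True,
    finetune_vision_layers=False,
    r=16, lora_alpha=16,
)

# Example dataset message with an audio part (surrounding format assumed).
sample = {
    "messages": [
        {"role": "user", "content": [
            {"type": "audio", "audio": "clip.wav"},
            {"type": "text", "text": "Transcribe this audio."},
        ]},
        {"role": "assistant", "content": [
            {"type": "text", "text": "Hello, world."},
        ]},
    ]
}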

Batch size must be 1

VLM training is forced to batch_size=1. Each image produces a different number of vision tokens (variable num_patches), so samples cannot be stacked into batches. Use gradient_accumulation_steps to simulate larger effective batch sizes.

Qwen3.5 think tags

Qwen3.5 models generate <think>...</think> reasoning tags in their output. mlx-tune automatically strips these from generate() results so you get clean responses.

Text-only VLM training

You can fine-tune a VLM on text-only data (no images needed). This is useful for improving the language model component while preserving vision capabilities.

Response-only training

Enabled by default via train_on_completions=True in VLMSFTConfig. Loss is computed only on assistant response tokens, not on the prompt or system message. This significantly improves training quality.

Training vs inference mode

Always call FastVisionModel.for_training(model) before training and FastVisionModel.for_inference(model) before generation. Forgetting this will cause incorrect behavior or errors.

Memory requirements

16 GB or more of unified memory is recommended for 3B VLMs; larger models (7B+) need 32 GB or more. Use 4-bit quantized models from mlx-community/ to reduce memory usage.