API Reference

All public APIs exported from mlx_tune. Import everything from the top-level package.

```python
from mlx_tune import FastLanguageModel, SFTTrainer, SFTConfig
# ... or any other export listed below
```

Core Model

mlx_tune.model

FastLanguageModel

Main entry point for loading and configuring language models. Mirrors Unsloth’s FastLanguageModel API.

FastLanguageModel.from_pretrained()

FastLanguageModel.from_pretrained(model_name, max_seq_length=None, load_in_4bit=False, load_in_8bit=False, use_gradient_checkpointing="unsloth", ...) → Tuple[MLXModelWrapper, Tokenizer]

Load a pretrained language model from HuggingFace.

| Parameter | Type | Description |
| --- | --- | --- |
| model_name | str | HuggingFace model ID (e.g., "mlx-community/Llama-3.2-1B-Instruct-4bit") or local path |
| max_seq_length | int, optional | Maximum sequence length for training/inference |
| load_in_4bit | bool | Load model with 4-bit quantization (QLoRA) |
| load_in_8bit | bool | Load model with 8-bit quantization |

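To give intuition for what load_in_4bit implies, here is a minimal sketch of group-wise affine 4-bit quantization, the scheme behind 4-bit MLX community checkpoints. This is illustrative arithmetic only, not mlx_tune's loader, which relies on MLX's fused quantized kernels.

```python
# Sketch: quantize a group of float weights to 4-bit ints plus a
# per-group scale and bias, then dequantize and bound the error.

def quantize_group(weights, bits=4):
    """Quantize one group of float weights to unsigned ints + scale/bias."""
    levels = (1 << bits) - 1                    # 15 levels for 4-bit
    w_min, w_max = min(weights), max(weights)
    scale = (w_max - w_min) / levels or 1.0     # guard all-equal groups
    q = [round((w - w_min) / scale) for w in weights]
    return q, scale, w_min

def dequantize_group(q, scale, bias):
    """Recover approximate float weights from the quantized group."""
    return [x * scale + bias for x in q]

group = [0.02, -0.11, 0.35, -0.4, 0.0, 0.19, -0.27, 0.08]
q, scale, bias = quantize_group(group)
recovered = dequantize_group(q, scale, bias)
max_err = max(abs(a - b) for a, b in zip(group, recovered))
assert max_err <= scale / 2   # rounding error is at most half a step
```

The memory win is what makes QLoRA practical: weights shrink roughly 4x versus fp16, at the cost of the bounded rounding error shown above.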

FastLanguageModel.get_peft_model()

FastLanguageModel.get_peft_model(model, r=16, target_modules=None, lora_alpha=16, lora_dropout=0.0, bias="none", use_rslora=False, random_state=3407, ...) → MLXModelWrapper

Add LoRA adapters to the model for parameter-efficient fine-tuning.

| Parameter | Type | Description |
| --- | --- | --- |
| model | MLXModelWrapper | Model from from_pretrained() |
| r | int | LoRA rank (higher = more parameters, better quality) |
| target_modules | list[str], optional | Modules to apply LoRA to. Default: ["q_proj", "k_proj", "v_proj", "o_proj"] |
| lora_alpha | int | LoRA scaling factor. Recommended: equal to r |
| lora_dropout | float | Dropout for LoRA layers |

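Back-of-envelope sizing for r and lora_alpha: each adapted (d_out x d_in) weight gains two low-rank factors, A (r x d_in) and B (d_out x r), and the update is scaled by lora_alpha / r. The dimensions below are illustrative assumptions (a Llama-3.2-1B-style model), not values read from mlx_tune.

```python
# Sketch: count the extra trainable parameters LoRA adds and compute
# the effective update scale. Pure arithmetic, not mlx_tune code.

def lora_extra_params(d_out, d_in, r):
    return r * d_in + d_out * r            # params in A plus params in B

def lora_scaling(lora_alpha, r):
    return lora_alpha / r                  # multiplier applied to B @ A

# Assume a 2048 x 2048 attention projection
per_matrix = lora_extra_params(2048, 2048, r=16)
assert per_matrix == 2 * 16 * 2048         # 65,536 params per matrix

# Default target_modules adapt 4 matrices per layer; assume 16 layers
total = per_matrix * 4 * 16
print(f"trainable LoRA params: {total:,}")  # a few million vs ~1B frozen

# With the recommended lora_alpha == r, updates are applied at scale 1.0
assert lora_scaling(16, 16) == 1.0
```

Doubling r doubles both the trainable parameter count and (at fixed lora_alpha) halves the update scale, which is why keeping lora_alpha equal to r is the common default.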

FastLanguageModel.for_inference()

FastLanguageModel.for_inference(model, use_cache=True) → MLXModelWrapper

Enable inference mode: activates KV caching, disables dropout. Always call before generating.

MLXModelWrapper

Internal wrapper providing Unsloth-compatible methods on MLX models. Returned by FastLanguageModel.from_pretrained().

Key Methods

model.save_pretrained(output_dir) — Save LoRA adapters only
model.save_pretrained_merged(output_dir, tokenizer) — Save full merged model
model.save_pretrained_gguf(output_dir, tokenizer, dequantize=False) — Export to GGUF format
model.generate(prompt, max_tokens=256, temperature=0.7, min_p=None, ...) — Generate text
model.stream_generate(prompt, ...) — Stream text token by token
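The min_p parameter of generate() refers to min-p sampling: keep only tokens whose probability is at least min_p times the probability of the most likely token. A plain-Python sketch of that filter, assuming the standard definition rather than mlx_tune's exact implementation:

```python
import math

def min_p_filter(logits, min_p=0.1):
    """Return the token ids that survive min_p filtering."""
    m = max(logits)
    probs = [math.exp(l - m) for l in logits]     # stable softmax
    total = sum(probs)
    probs = [p / total for p in probs]
    threshold = min_p * max(probs)                # scale by the top token
    return [i for i, p in enumerate(probs) if p >= threshold]

logits = [5.0, 4.5, 1.0, -2.0]
kept = min_p_filter(logits, min_p=0.2)
assert kept == [0, 1]   # the low-probability tail is cut, plausible tokens stay
```

Unlike a fixed top-k, the number of surviving tokens adapts to how peaked the distribution is, which is why min-p pairs well with moderate temperatures.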

SFT Training

mlx_tune.sft_trainer

SFTTrainer

SFTTrainer(model, train_dataset, tokenizer=None, eval_dataset=None, args=None, ...)

Supervised fine-tuning trainer. API-compatible with TRL’s SFTTrainer.

| Parameter | Type | Description |
| --- | --- | --- |
| model | MLXModelWrapper | Model with LoRA adapters configured |
| train_dataset | Dataset | HuggingFace dataset or list of dicts |
| tokenizer | Tokenizer | Tokenizer from from_pretrained() |
| args | SFTConfig | Training configuration |

trainer.train() — Start training. Returns training statistics.
trainer.save_model(output_dir) — Save the trained model.

SFTConfig

SFTConfig(output_dir="outputs", per_device_train_batch_size=2, gradient_accumulation_steps=4, learning_rate=2e-4, lr_scheduler_type="linear", warmup_steps=5, num_train_epochs=1, max_steps=-1, logging_steps=1, save_steps=500, max_seq_length=2048, optim="adam", weight_decay=0.01, seed=3407, report_to="none", ...)

Training configuration. Compatible with TRL’s SFTConfig parameters.

| Parameter | Default | Description |
| --- | --- | --- |
| output_dir | "outputs" | Directory for checkpoints and logs |
| per_device_train_batch_size | 2 | Batch size per device |
| gradient_accumulation_steps | 4 | Number of gradient accumulation steps |
| learning_rate | 2e-4 | Peak learning rate |
| max_steps | -1 | Total training steps (-1 = use epochs) |
| max_seq_length | 2048 | Maximum sequence length |
| optim | "adam" | Optimizer (use "adam" for MLX) |
| warmup_steps | 5 | Linear warmup steps |
| lr_scheduler_type | "linear" | LR scheduler: linear, cosine, constant |

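How the defaults interact: the effective batch size is per_device_train_batch_size times gradient_accumulation_steps, and the "linear" scheduler warms up for warmup_steps then decays to zero. A plain-Python sketch of that schedule, under the assumption that "linear" means standard warmup-then-linear-decay:

```python
def effective_batch_size(per_device=2, grad_accum=4):
    return per_device * grad_accum

def linear_lr(step, max_steps, peak_lr=2e-4, warmup_steps=5):
    """Linear warmup to peak_lr, then linear decay to zero."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    frac = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return peak_lr * max(0.0, 1.0 - frac)

assert effective_batch_size() == 8            # defaults: 2 * 4
assert linear_lr(0, 100) == 2e-4 / 5          # first warmup step
assert linear_lr(4, 100) == 2e-4              # peak at end of warmup
assert linear_lr(100, 100) == 0.0             # fully decayed
```

Raising gradient_accumulation_steps is the usual way to grow the effective batch without growing peak memory, since only one micro-batch is resident at a time.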

RL Trainers

mlx_tune.rl_trainers

DPOTrainer

DPOTrainer(model, train_dataset, args=None, tokenizer=None, ...)

Direct Preference Optimization trainer. Uses proper DPO loss with log-probability computation over chosen/rejected pairs.

DPOConfig(output_dir, beta=0.1, max_steps=-1, learning_rate=5e-5, ...)
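The DPO objective in miniature: maximize the margin between the policy's and the reference model's log-probability gaps on chosen versus rejected responses, scaled by beta. A scalar sketch of the loss; mlx_tune computes these log-probabilities over full sequences.

```python
import math

def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """-log sigmoid(beta * ((pi_c - pi_r) - (ref_c - ref_r))) on summed log-probs."""
    margin = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
    return -math.log(1 / (1 + math.exp(-beta * margin)))

# Policy prefers the chosen response more than the reference does: low loss
good = dpo_loss(-10.0, -20.0, -12.0, -14.0)
# Policy prefers the rejected response: high loss
bad = dpo_loss(-20.0, -10.0, -12.0, -14.0)
assert good < math.log(2) < bad   # log(2) is the loss at zero margin
```

A larger beta sharpens the penalty for ranking the rejected response above the chosen one, which is the knob DPOConfig's beta=0.1 default controls.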

ORPOTrainer

ORPOTrainer(model, train_dataset, args=None, ...)

Odds Ratio Preference Optimization. Combines SFT loss with odds-ratio preference alignment.

ORPOConfig(output_dir, lambda_orpo=1.0, ...)
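ORPO in miniature: the usual SFT negative log-likelihood on the chosen response plus lambda_orpo times an odds-ratio penalty pushing the chosen response's odds above the rejected one's. A scalar sketch over average per-token log-probabilities, not mlx_tune's batched implementation:

```python
import math

def odds(logp):
    """Odds p / (1 - p) from an average per-token log-probability."""
    p = math.exp(logp)
    return p / (1.0 - p)

def orpo_loss(chosen_logp, rejected_logp, lambda_orpo=1.0):
    nll = -chosen_logp                                   # SFT term
    ratio = math.log(odds(chosen_logp) / odds(rejected_logp))
    or_term = -math.log(1 / (1 + math.exp(-ratio)))      # -log sigmoid
    return nll + lambda_orpo * or_term

# Preferring the chosen response lowers the loss
assert orpo_loss(-0.5, -2.0) < orpo_loss(-2.0, -0.5)
```

Because the SFT term is built in, ORPO needs no separate reference model, which is its main practical advantage over DPO.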

GRPOTrainer

GRPOTrainer(model, train_dataset, args=None, reward_fn=None, ...)

Group Relative Policy Optimization (DeepSeek R1 style). Generates multiple completions per prompt and optimizes based on relative rewards.

GRPOConfig(output_dir, num_generations=4, kl_coeff=0.1, ...)
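The core of GRPO: sample num_generations completions per prompt, score them with reward_fn, and normalize the rewards within each group so every completion's advantage is relative to its siblings. A sketch of that advantage computation, not the full policy-gradient step:

```python
import statistics

def group_advantages(rewards):
    """Standardize rewards within one prompt's group of completions."""
    mean = statistics.mean(rewards)
    std = statistics.pstdev(rewards) or 1.0   # avoid dividing by zero
    return [(r - mean) / std for r in rewards]

# num_generations=4 completions for one prompt, scored by reward_fn
rewards = [1.0, 0.0, 0.0, 1.0]
adv = group_advantages(rewards)
assert adv == [1.0, -1.0, -1.0, 1.0]   # correct answers get positive advantage
```

Because advantages are relative within the group, no learned value model is needed; the kl_coeff term then keeps the policy close to its starting point.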

KTOTrainer & SimPOTrainer

KTOTrainer(model, train_dataset, args=None, ...) — Kahneman-Tversky Optimization
SimPOTrainer(model, train_dataset, args=None, ...) — Simple Preference Optimization

Utilities

prepare_preference_dataset(dataset, ...) → dataset
create_reward_function(reward_type="simple") → Callable
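A reward function of the kind create_reward_function returns is just a callable scoring each completion. This is a hypothetical example (the tag names and scores are assumptions, not mlx_tune's "simple" reward) rewarding R1-style completions that wrap reasoning in think tags:

```python
def format_reward(completion: str) -> float:
    """Score a completion on structural format, not correctness."""
    score = 0.0
    if "<think>" in completion and "</think>" in completion:
        score += 1.0                       # reasoning block present
    if completion.rstrip().endswith("</answer>"):
        score += 0.5                       # closed answer block
    return score

assert format_reward("<think>2+2=4</think><answer>4</answer>") == 1.5
assert format_reward("4") == 0.0
```

Any function with this shape can be passed as reward_fn to GRPOTrainer; stacking several cheap structural rewards like this is a common way to bootstrap before adding task-specific correctness checks.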

Vision Models

mlx_tune.vlm

FastVisionModel

Vision-Language Model API. Mirrors Unsloth’s FastVisionModel.

FastVisionModel.from_pretrained()

FastVisionModel.from_pretrained(model_name, max_seq_length=None, load_in_4bit=False, ...) → Tuple[VLMModelWrapper, Processor]

Load a vision-language model. Returns model wrapper and processor (not tokenizer).

FastVisionModel.get_peft_model()

FastVisionModel.get_peft_model(model, finetune_vision_layers=True, finetune_language_layers=True, finetune_attention_modules=True, finetune_mlp_modules=True, r=16, lora_alpha=16, ...) → VLMModelWrapper

Add LoRA adapters to vision and/or language components.

| Parameter | Type | Description |
| --- | --- | --- |
| finetune_vision_layers | bool | Apply LoRA to vision encoder |
| finetune_language_layers | bool | Apply LoRA to language model |
| finetune_attention_modules | bool | Apply LoRA to attention modules |
| finetune_mlp_modules | bool | Apply LoRA to MLP modules |

FastVisionModel.for_training(model) — Enable training mode
FastVisionModel.for_inference(model) — Enable inference mode

VLMSFTTrainer

VLMSFTTrainer(model, train_dataset, tokenizer, data_collator=None, args=None, ...)

Vision-Language model trainer. Batch size is forced to 1 (images produce variable token counts).

VLMSFTConfig

VLMSFTConfig(per_device_train_batch_size=1, max_steps=30, learning_rate=2e-4, ...) — VLM-specific training config

UnslothVisionDataCollator

UnslothVisionDataCollator(model, processor) — Data collator for vision tasks

Handles image preprocessing, vision token insertion, and batch preparation for VLM training.
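What vision token insertion has to do, in miniature: splice a run of image placeholder tokens into the text token stream where the image sits. The token ids below are hypothetical; the real UnslothVisionDataCollator works on processor output, not raw lists.

```python
IMAGE_TOKEN = 32000   # hypothetical placeholder id

def insert_vision_tokens(text_tokens, image_pos, num_patches):
    """Expand one image marker position into num_patches placeholder tokens."""
    return (text_tokens[:image_pos]
            + [IMAGE_TOKEN] * num_patches
            + text_tokens[image_pos:])

tokens = [1, 15, 99, 2]                      # e.g., BOS, "describe", "this", EOS
batch = insert_vision_tokens(tokens, image_pos=1, num_patches=3)
assert batch == [1, 32000, 32000, 32000, 15, 99, 2]
```

Different images yield different num_patches values, so sequence lengths vary per example; that is why VLMSFTTrainer pins per_device_train_batch_size to 1.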

Chat Templates

mlx_tune.chat_templates

get_chat_template()

get_chat_template(tokenizer, chat_template="auto", ...) → Tokenizer

Apply a chat template to the tokenizer. Supports 15 model templates.

| Parameter | Description |
| --- | --- |
| tokenizer | Tokenizer to update |
| chat_template | Template name or "auto" for auto-detection. Options: llama-3, llama-2, gemma, qwen-2.5, qwen-3, phi-3, phi-4, mistral-7b, deepseek, command-r, neural-chat, solar, tulu-2, zephyr, alpaca |

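What applying a chat template does conceptually: render a message list into the model's expected prompt string. Below is an illustrative llama-3-style renderer written by hand, not the actual Jinja2 template shipped in CHAT_TEMPLATES:

```python
def render_llama3(messages, add_generation_prompt=True):
    """Render role/content messages in a llama-3-style header format."""
    out = "<|begin_of_text|>"
    for m in messages:
        out += (f"<|start_header_id|>{m['role']}<|end_header_id|>\n\n"
                f"{m['content']}<|eot_id|>")
    if add_generation_prompt:
        out += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return out

prompt = render_llama3([{"role": "user", "content": "Hi"}])
assert prompt.startswith("<|begin_of_text|>")
assert prompt.endswith("<|start_header_id|>assistant<|end_header_id|>\n\n")
```

In practice you never render by hand: after get_chat_template, tokenizer.apply_chat_template produces the equivalent string with the correct special tokens for the detected model family.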

train_on_responses_only()

train_on_responses_only(trainer, instruction_part, response_part, ...) → Trainer

Modify trainer to compute loss only on assistant response tokens (not prompts). Significantly improves training quality.
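The mechanism behind response-only training: label positions outside the assistant response spans are set to an ignore index so the loss skips them. A string-level sketch with hypothetical marker tokens; the real implementation masks token ids produced by instruction_part and response_part.

```python
IGNORE_INDEX = -100   # the conventional ignore index for LM losses

def mask_non_response(tokens, labels, instruction_part, response_part):
    """Keep labels only inside response spans; mask everything else."""
    masked, in_response = [], False
    for tok, lab in zip(tokens, labels):
        if tok == response_part:
            in_response = True
            masked.append(IGNORE_INDEX)     # the marker itself carries no loss
            continue
        if tok == instruction_part:
            in_response = False
        masked.append(lab if in_response else IGNORE_INDEX)
    return masked

tokens = ["<user>", "2+2?", "<assistant>", "4", "<user>", "thanks"]
labels = [10, 11, 12, 13, 14, 15]
out = mask_non_response(tokens, labels, "<user>", "<assistant>")
assert out == [-100, -100, -100, 13, -100, -100]  # only "4" is trained on
```

Only the assistant's tokens contribute gradient, so the model is never trained to regurgitate prompts, which is where the quality gain comes from.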

Dataset Utilities

to_sharegpt(dataset, conversation_extension="auto", ...) → dataset — Merge multi-turn conversations
detect_dataset_format(sample) → str — Returns "alpaca", "sharegpt", or "chatml"
standardize_sharegpt(dataset) → dataset
standardize_sharegpt_enhanced(dataset, role_mapping={}, ...) → dataset
convert_to_mlx_format(dataset, ...) → dataset
apply_column_mapping(dataset, column_mapping) → dataset — Rename columns
infer_column_mapping(dataset) → Dict — Auto-detect column mapping
HFDatasetConfig(dataset_name, split, ...) — Structured dataset configuration
load_dataset_with_config(config) → dataset
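A plausible reading of detect_dataset_format's heuristics, based on the conventional key layouts of the three formats. This is a sketch under that assumption; the shipped detector may inspect more fields.

```python
def detect_format(sample: dict) -> str:
    """Classify one example as alpaca, sharegpt, or chatml by its keys."""
    if "instruction" in sample and "output" in sample:
        return "alpaca"
    if "conversations" in sample:        # [{"from": ..., "value": ...}]
        return "sharegpt"
    if "messages" in sample:             # [{"role": ..., "content": ...}]
        return "chatml"
    raise ValueError("unknown dataset format")

assert detect_format({"instruction": "Add", "input": "2 2", "output": "4"}) == "alpaca"
assert detect_format({"conversations": [{"from": "human", "value": "hi"}]}) == "sharegpt"
assert detect_format({"messages": [{"role": "user", "content": "hi"}]}) == "chatml"
```

Auto-detection of this kind is why standardize_sharegpt and convert_to_mlx_format can usually be called without manual column mapping.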

Template Helpers

list_chat_templates() → List[str] — List all available template names
get_template_info(template_name) → Dict — Get template details
get_template_for_model(model_name) → str — Auto-detect template from model name

Constants

CHAT_TEMPLATES — Dict of all template definitions (Jinja2)
TEMPLATE_ALIASES — Dict of alias mappings (e.g., "llama3" → "llama-3")
DEFAULT_SYSTEM_MESSAGES — Dict of default system prompts per model

Loss Functions

mlx_tune.losses

Low-level loss functions for custom training loops.

sft_loss(logits, targets, ...) → loss
dpo_loss(chosen_logits, rejected_logits, ...) → loss
orpo_loss(chosen_logits, rejected_logits, ...) → loss
kto_loss(generated_logits, ...) → loss
simpo_loss(chosen_logits, rejected_logits, ...) → loss
grpo_loss(completions, rewards, ...) → loss
grpo_batch_loss(batch_completions, batch_rewards, ...) → loss
compute_log_probs(logits, tokens, ...) → log_probs
compute_log_probs_with_lengths(logits, tokens, lengths, ...) → log_probs
compute_reference_logprobs(model, tokens, ...) → logprobs
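What compute_log_probs does at its core: log-softmax the logits at each step and pick out the log-probability of the realized token. A pure-Python sketch over one sequence; the real function vectorizes over batches with MLX arrays.

```python
import math

def compute_log_probs(logits, tokens):
    """logits: per-step vocab score lists; tokens: realized token ids."""
    out = []
    for step_logits, tok in zip(logits, tokens):
        m = max(step_logits)                              # for stability
        log_z = m + math.log(sum(math.exp(l - m) for l in step_logits))
        out.append(step_logits[tok] - log_z)              # log p(token)
    return out

logits = [[2.0, 0.0], [0.0, 3.0]]
lp = compute_log_probs(logits, tokens=[0, 1])
assert all(p < 0 for p in lp)   # log-probabilities are always negative
```

Summing the returned values gives the sequence log-probability that the preference losses above (DPO, ORPO, SimPO) compare between chosen and rejected responses.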

Utilities

mlx_tune.trainer
prepare_dataset(dataset_name=None, dataset_path=None, split="train", ...) → dataset
format_chat_template(messages, tokenizer, add_generation_prompt=False) → str
create_training_data(dataset, ...) → formatted_data
save_model_hf_format(model, tokenizer, output_path)
export_to_gguf(model, tokenizer, output_path, ...) — See GGUF limitations
get_training_config(model_name, ...) → Dict
load_vlm_dataset(dataset, ...) → dataset