Examples
Standalone Python scripts covering model loading, fine-tuning, RL training, and vision models. Run any example directly with python examples/<filename>.py.
Basics
01 — Simple Loading
Load a model from HuggingFace with FastLanguageModel.from_pretrained().
from mlx_tune import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
"mlx-community/Llama-3.2-1B-Instruct-4bit",
max_seq_length=2048,
load_in_4bit=True,
)
print("Model loaded successfully!")
02 — LoRA Configuration
Add LoRA adapters with get_peft_model() for parameter-efficient training.
from mlx_tune import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
"mlx-community/Llama-3.2-1B-Instruct-4bit",
max_seq_length=2048, load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_alpha=16,
lora_dropout=0,
)
print("LoRA configured: rank=16")
03 — Inference
Generate text with a loaded model using for_inference().
from mlx_tune import FastLanguageModel
from mlx_lm import generate
model, tokenizer = FastLanguageModel.from_pretrained(
"mlx-community/Llama-3.2-1B-Instruct-4bit",
max_seq_length=2048, load_in_4bit=True,
)
FastLanguageModel.for_inference(model)
prompt = "What is machine learning?"
response = generate(model.model, tokenizer,
prompt=prompt, max_tokens=100)
print(response)
SFT Training
04 — Simple Fine-tuning
LoRA setup and basic training configuration walkthrough.
from mlx_tune import FastLanguageModel, SFTTrainer, SFTConfig
model, tokenizer = FastLanguageModel.from_pretrained(
"mlx-community/Llama-3.2-1B-Instruct-4bit",
max_seq_length=2048, load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(model, r=16, lora_alpha=16)
05 — Complete Workflow
Full pipeline: load, configure, train, and save a fine-tuned model.
from mlx_tune import FastLanguageModel, SFTTrainer, SFTConfig
from datasets import load_dataset
model, tokenizer = FastLanguageModel.from_pretrained(
"mlx-community/Llama-3.2-1B-Instruct-4bit",
max_seq_length=2048, load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(model, r=16, lora_alpha=16)
dataset = load_dataset("yahma/alpaca-cleaned", split="train[:100]")
trainer = SFTTrainer(
model=model, train_dataset=dataset, tokenizer=tokenizer,
args=SFTConfig(output_dir="outputs", max_steps=50, learning_rate=2e-4),
)
trainer.train()
model.save_pretrained("lora_model")
06 — Real Training Test
Actual training run with SFTTrainer on a small dataset with logging.
from mlx_tune import FastLanguageModel, SFTTrainer, SFTConfig
# Loads model, configures LoRA, trains on real data
# with logging_steps=1 to monitor training loss
# See full source for dataset preparation details
07 — Unsloth Comparison
Side-by-side comparison of Unsloth vs mlx-tune API usage.
# Shows Unsloth code (commented) alongside mlx-tune equivalent
# Demonstrates that the API is 100% compatible
# See workflow page for detailed translation guide
08 — Exact Unsloth Pipeline
Complete Unsloth-compatible SFT workflow with chat templates, response masking, and export.
from mlx_tune import FastLanguageModel, SFTTrainer, SFTConfig
from mlx_tune import get_chat_template, train_on_responses_only
from datasets import load_dataset
model, tokenizer = FastLanguageModel.from_pretrained(
"mlx-community/Llama-3.2-1B-Instruct-4bit",
max_seq_length=2048, load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(model, r=16, lora_alpha=16)
tokenizer = get_chat_template(tokenizer, chat_template="llama-3")
dataset = load_dataset("yahma/alpaca-cleaned", split="train[:100]")
trainer = SFTTrainer(
model=model, train_dataset=dataset, tokenizer=tokenizer,
args=SFTConfig(output_dir="outputs", max_steps=50, learning_rate=2e-4),
)
trainer = train_on_responses_only(trainer,
instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
)
trainer.train()
model.save_pretrained("lora_model")
model.save_pretrained_merged("merged", tokenizer)
41 — LFM2 SFT Fine-Tuning
Fine-tune Liquid AI’s LFM2 with SFT. Hybrid gated-conv + GQA architecture with ChatML format. Uses LFM2-specific target modules: in_proj, out_proj, w1, w2, w3.
from mlx_tune import FastLanguageModel, SFTTrainer, SFTConfig
model, tokenizer = FastLanguageModel.from_pretrained(
"mlx-community/LFM2-350M-4bit",
max_seq_length=2048, load_in_4bit=True,
)
# LFM2-specific targets: out_proj/in_proj (attention), w1/w2/w3 (gated conv)
model = FastLanguageModel.get_peft_model(
model, r=16,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj",
"in_proj", "w1", "w2", "w3"],
lora_alpha=16,
)
trainer = SFTTrainer(
model=model, train_dataset=dataset, tokenizer=tokenizer,
args=SFTConfig(output_dir="outputs_lfm2", max_steps=50,
learning_rate=2e-4),
)
trainer.train()
42 — LFM2.5-Thinking Reasoning
Fine-tune LFM2.5-1.2B-Thinking for chain-of-thought reasoning. Uses <think> tags for internal reasoning (like Qwen3.5).
from mlx_tune import FastLanguageModel, SFTTrainer, SFTConfig
model, tokenizer = FastLanguageModel.from_pretrained(
"mlx-community/LFM2.5-1.2B-Thinking-4bit",
max_seq_length=4096, load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
model, r=16,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj",
"in_proj", "w1", "w2", "w3"],
)
trainer = SFTTrainer(
model=model, train_dataset=thinking_dataset, tokenizer=tokenizer,
args=SFTConfig(output_dir="outputs_lfm2_think", max_steps=50),
)
trainer.train()
RL Methods
09 — RL Training Methods (Overview)
Quick demo of all 5 RL trainers with preference datasets.
from mlx_tune import (
DPOTrainer, DPOConfig,
ORPOTrainer, ORPOConfig,
GRPOTrainer, GRPOConfig,
KTOTrainer, KTOConfig,
SimPOTrainer, SimPOConfig,
)
21 — DPO Preference Tuning
End-to-end DPO training with chosen/rejected preference pairs. Trains a real model with proper log-probability loss.
from mlx_tune import FastLanguageModel, DPOTrainer, DPOConfig
model, tokenizer = FastLanguageModel.from_pretrained("mlx-community/Qwen3.5-0.8B-MLX-4bit")
model = FastLanguageModel.get_peft_model(model, r=16, lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"])
config = DPOConfig(beta=0.1, learning_rate=5e-7, max_steps=30)
trainer = DPOTrainer(model=model, train_dataset=preference_data,
tokenizer=tokenizer, args=config)
trainer.train()
22 — GRPO Reasoning Training
DeepSeek R1-style reasoning training. Generates multiple completions, scores with custom reward functions, and updates via policy gradient.
from mlx_tune import FastLanguageModel, GRPOTrainer, GRPOConfig
def combined_reward(response, ground_truth):
return 0.7 * correctness_reward(response, ground_truth) + \
0.3 * format_reward(response, ground_truth)
config = GRPOConfig(num_generations=2, max_completion_length=128,
learning_rate=1e-6, max_steps=10)
trainer = GRPOTrainer(model=model, train_dataset=reasoning_data,
tokenizer=tokenizer, reward_fn=combined_reward, args=config)
trainer.train()
23 — ORPO Preference Tuning
Combined SFT + preference alignment in one training step. No reference model needed — more memory efficient than DPO.
from mlx_tune import FastLanguageModel, ORPOTrainer, ORPOConfig
config = ORPOConfig(beta=0.1, learning_rate=8e-6, max_steps=30)
trainer = ORPOTrainer(model=model, train_dataset=preference_data,
tokenizer=tokenizer, args=config)
trainer.train()
24 — KTO Binary Feedback
Kahneman-Tversky Optimization with binary feedback. No paired preferences needed — just label responses as good or bad.
from mlx_tune import FastLanguageModel, KTOTrainer, KTOConfig
# Binary feedback: each sample independently labeled
kto_data = [
{"prompt": "What is ML?", "completion": "Machine learning is...", "label": True},
{"prompt": "What is ML?", "completion": "idk", "label": False},
]
config = KTOConfig(beta=0.1, learning_rate=5e-7, max_steps=30)
trainer = KTOTrainer(model=model, train_dataset=kto_data,
tokenizer=tokenizer, args=config)
trainer.train()
25 — SimPO Simple Preference
Simple Preference Optimization. No reference model needed — uses length-normalized log probs as implicit rewards.
from mlx_tune import FastLanguageModel, SimPOTrainer, SimPOConfig
config = SimPOConfig(beta=2.0, gamma=0.5, learning_rate=5e-7, max_steps=30)
trainer = SimPOTrainer(model=model, train_dataset=preference_data,
tokenizer=tokenizer, args=config)
trainer.train()
Vision Models
10 — Qwen3.5 Vision Fine-tuning
Fine-tune a vision-language model on image+text data (LaTeX OCR dataset).
from mlx_tune import FastVisionModel, UnslothVisionDataCollator, VLMSFTTrainer
from mlx_tune.vlm import VLMSFTConfig
from datasets import load_dataset
model, processor = FastVisionModel.from_pretrained(
"mlx-community/Qwen3.5-0.8B-bf16", load_in_4bit=False,
)
model = FastVisionModel.get_peft_model(model,
finetune_vision_layers=True, finetune_language_layers=True,
r=16, lora_alpha=16,
)
dataset = load_dataset("unsloth/LaTeX_OCR", split="train")
FastVisionModel.for_training(model)
trainer = VLMSFTTrainer(
model=model, tokenizer=processor,
data_collator=UnslothVisionDataCollator(model, processor),
train_dataset=converted_dataset,
args=VLMSFTConfig(max_steps=30, learning_rate=2e-4),
)
trainer.train()
11 — Qwen3.5 Text Fine-tuning
Fine-tune Qwen3.5 on text-only data without requiring any images.
from mlx_tune import FastVisionModel, UnslothVisionDataCollator, VLMSFTTrainer
from mlx_tune.vlm import VLMSFTConfig
# Qwen3.5 can be fine-tuned on text-only data
# No images needed — the model handles pure text conversations
model, processor = FastVisionModel.from_pretrained(
"mlx-community/Qwen3.5-0.8B-bf16", load_in_4bit=False,
)
# ... same training setup, just without image data
26 — Vision GRPO Training
End-to-end Vision GRPO training extending GRPO to VLMs. Generates completions conditioned on image+text prompts, scored by reward functions.
from mlx_tune import FastVisionModel
from mlx_tune.vlm import VLMGRPOTrainer, VLMGRPOConfig
model, processor = FastVisionModel.from_pretrained(
"mlx-community/Qwen3.5-0.8B-bf16",
)
model = FastVisionModel.get_peft_model(model, r=16, lora_alpha=16)
trainer = VLMGRPOTrainer(
model=model, train_dataset=vision_data,
processor=processor, reward_fn=reward_fn,
args=VLMGRPOConfig(num_generations=2, max_steps=10),
)
trainer.train()
38 — Gemma 4 Vision Fine-Tuning
Fine-tune Google Gemma 4 E4B on LaTeX OCR. All Gemma 4 models are VLMs — use FastVisionModel.
from mlx_tune import FastVisionModel, UnslothVisionDataCollator, VLMSFTTrainer
from mlx_tune.vlm import VLMSFTConfig
model, processor = FastVisionModel.from_pretrained(
"mlx-community/gemma-4-e4b-it-4bit",
)
model = FastVisionModel.get_peft_model(
model, finetune_vision_layers=True,
finetune_language_layers=True, r=16, lora_alpha=16,
)
trainer = VLMSFTTrainer(
model=model, tokenizer=processor,
data_collator=UnslothVisionDataCollator(model, processor),
train_dataset=dataset,
args=VLMSFTConfig(max_steps=30, learning_rate=2e-4),
)
trainer.train()
39 — Gemma 4 Text-to-SQL
Text-only fine-tuning through the VLM path. Uses Google’s official Gemma 4 example dataset.
from mlx_tune import FastVisionModel, UnslothVisionDataCollator, VLMSFTTrainer
model, processor = FastVisionModel.from_pretrained(
"mlx-community/gemma-4-e4b-it-4bit",
)
model = FastVisionModel.get_peft_model(
model, finetune_vision_layers=False,
finetune_language_layers=True, r=16, lora_alpha=16,
)
# Text-only fine-tuning (no images needed)
trainer = VLMSFTTrainer(
model=model, tokenizer=processor,
data_collator=UnslothVisionDataCollator(model, processor),
train_dataset=sql_dataset,
args=VLMSFTConfig(max_steps=30, learning_rate=2e-4),
)
trainer.train()
40 — Gemma 4 MoE Fine-Tuning
Fine-tune Gemma 4 26B-A4B MoE (128 experts, top-8 routing + 1 shared). LoRA auto-targets expert layers.
from mlx_tune import FastVisionModel, UnslothVisionDataCollator, VLMSFTTrainer
model, processor = FastVisionModel.from_pretrained(
"mlx-community/gemma-4-26b-a4b-it-4bit",
)
model = FastVisionModel.get_peft_model(
model, finetune_vision_layers=True,
finetune_language_layers=True, r=8, lora_alpha=8,
)
trainer = VLMSFTTrainer(
model=model, tokenizer=processor,
data_collator=UnslothVisionDataCollator(model, processor),
train_dataset=dataset,
args=VLMSFTConfig(max_steps=20, learning_rate=1e-4),
)
trainer.train()
47 — Gemma 4 Audio ASR
Fine-tune Gemma 4 E4B for speech-to-text via the built-in 12-layer Conformer audio tower. Language-layer LoRA only.
from mlx_tune import FastVisionModel, UnslothVisionDataCollator, VLMSFTTrainer
from mlx_tune.vlm import VLMSFTConfig
model, processor = FastVisionModel.from_pretrained(
"mlx-community/gemma-4-e4b-it-4bit",
)
model = FastVisionModel.get_peft_model(model,
finetune_vision_layers=False, finetune_language_layers=True,
finetune_audio_layers=False, r=16, lora_alpha=16,
)
dataset = [{"messages": [
{"role": "user", "content": [
{"type": "audio", "audio": "audio.wav"},
{"type": "text", "text": "Transcribe this audio."},
]},
{"role": "assistant", "content": [
{"type": "text", "text": "The quick brown fox jumps over the lazy dog."},
]},
]}]
trainer = VLMSFTTrainer(model=model, tokenizer=processor,
data_collator=UnslothVisionDataCollator(model, processor),
train_dataset=dataset,
args=VLMSFTConfig(max_steps=30, learning_rate=2e-4),
)
trainer.train()
48 — Gemma 4 Audio Understanding
Audio QA fine-tuning with audio tower LoRA for domain-specific acoustic adaptation.
from mlx_tune import FastVisionModel, UnslothVisionDataCollator, VLMSFTTrainer
from mlx_tune.vlm import VLMSFTConfig
model, processor = FastVisionModel.from_pretrained(
"mlx-community/gemma-4-e4b-it-4bit",
)
model = FastVisionModel.get_peft_model(model,
finetune_vision_layers=False, finetune_language_layers=True,
finetune_audio_layers=True, # LoRA on Conformer attention!
r=8, lora_alpha=16,
)
dataset = [{"messages": [
{"role": "user", "content": [
{"type": "audio", "audio": "audio.wav"},
{"type": "text", "text": "What language is being spoken?"},
]},
{"role": "assistant", "content": [
{"type": "text", "text": "The speaker is speaking English."},
]},
]}]
trainer = VLMSFTTrainer(model=model, tokenizer=processor,
data_collator=UnslothVisionDataCollator(model, processor),
train_dataset=dataset,
args=VLMSFTConfig(max_steps=30, learning_rate=1e-4),
)
trainer.train()
Audio Models
12 — Orpheus TTS Fine-Tuning
Fine-tune Orpheus-3B for text-to-speech with SNAC audio codec and LoRA.
from mlx_tune import FastTTSModel, TTSSFTTrainer, TTSSFTConfig, TTSDataCollator
from datasets import load_dataset, Audio
model, tokenizer = FastTTSModel.from_pretrained(
"mlx-community/orpheus-3b-0.1-ft-bf16",
codec_model="mlx-community/snac_24khz",
)
model = FastTTSModel.get_peft_model(model, r=16, lora_alpha=16)
dataset = load_dataset("MrDragonFox/Elise", split="train[:100]")
dataset = dataset.cast_column("audio", Audio(sampling_rate=24000))
collator = TTSDataCollator(model, tokenizer)
trainer = TTSSFTTrainer(
model=model, tokenizer=tokenizer,
data_collator=collator, train_dataset=dataset,
args=TTSSFTConfig(output_dir="./tts_output", max_steps=60),
)
trainer.train()
FastTTSModel.for_inference(model)
audio = model.generate("Hello, how are you today?")
import soundfile as sf
sf.write("output.wav", audio, 24000)
13 — Whisper STT Fine-Tuning
Fine-tune Whisper for speech-to-text with encoder-decoder LoRA on Apple Silicon.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio
model, processor = FastSTTModel.from_pretrained(
"mlx-community/whisper-tiny-asr-fp16"
)
model = FastSTTModel.get_peft_model(
model, r=16, finetune_encoder=True, finetune_decoder=True,
)
dataset = load_dataset("mozilla-foundation/common_voice_11_0", "en",
split="train[:100]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
collator = STTDataCollator(model, processor)
trainer = STTSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=STTSFTConfig(output_dir="./stt_output", max_steps=100),
)
trainer.train()
FastSTTModel.for_inference(model)
result = model.transcribe(audio_array)
print(result)
14 — OuteTTS Fine-Tuning
Fine-tune OuteTTS-1B (Llama + DAC codec) for voice cloning with text-based audio tokens.
from mlx_tune import FastTTSModel, TTSSFTTrainer, TTSSFTConfig, TTSDataCollator
from datasets import load_dataset, Audio
model, tokenizer = FastTTSModel.from_pretrained(
"mlx-community/Llama-OuteTTS-1.0-1B-8bit",
)
model = FastTTSModel.get_peft_model(model, r=16, lora_alpha=16)
dataset = load_dataset("MrDragonFox/Elise", split="train[:10]")
dataset = dataset.cast_column("audio", Audio(sampling_rate=24000))
collator = TTSDataCollator(model, tokenizer)
trainer = TTSSFTTrainer(
model=model, tokenizer=tokenizer,
data_collator=collator, train_dataset=dataset,
args=TTSSFTConfig(output_dir="./outetts_output", max_steps=60),
)
trainer.train()
15 — Spark-TTS Fine-Tuning
Fine-tune Spark-TTS (0.5B, Qwen2 + BiCodec) — ultra-light TTS for Apple Silicon.
from mlx_tune import FastTTSModel, TTSSFTTrainer, TTSSFTConfig, TTSDataCollator
from datasets import load_dataset, Audio
model, tokenizer = FastTTSModel.from_pretrained(
"mlx-community/Spark-TTS-0.5B-bf16",
)
model = FastTTSModel.get_peft_model(model, r=16, lora_alpha=16)
dataset = load_dataset("MrDragonFox/Elise", split="train[:10]")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
collator = TTSDataCollator(model, tokenizer)
trainer = TTSSFTTrainer(
model=model, tokenizer=tokenizer,
data_collator=collator, train_dataset=dataset,
args=TTSSFTConfig(output_dir="./spark_output", max_steps=60, sample_rate=16000),
)
trainer.train()
20 — Qwen3-TTS Fine-Tuning
Fine-tune Alibaba's Qwen3-TTS (1.7B) for multilingual text-to-speech. Dual-embedding architecture with 16-codebook built-in speech tokenizer.
from mlx_tune import FastTTSModel, TTSSFTTrainer, TTSSFTConfig, TTSDataCollator
from datasets import load_dataset, Audio
model, tokenizer = FastTTSModel.from_pretrained(
"mlx-community/Qwen3-TTS-12Hz-1.7B-VoiceDesign-bf16",
)
model = FastTTSModel.get_peft_model(
model, r=16, lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
)
dataset = load_dataset("MrDragonFox/Elise", split="train[:10]")
dataset = dataset.cast_column("audio", Audio(sampling_rate=24000))
collator = TTSDataCollator(model=model, tokenizer=tokenizer,
text_column="text", audio_column="audio")
trainer = TTSSFTTrainer(
model=model, tokenizer=tokenizer,
data_collator=collator, train_dataset=dataset,
args=TTSSFTConfig(output_dir="./qwen3_tts_output", max_steps=60),
)
trainer.train()
16 — Moonshine STT Fine-Tuning
Fine-tune Moonshine for speech-to-text with raw conv frontend (no mel spectrogram).
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio
model, processor = FastSTTModel.from_pretrained(
"UsefulSensors/moonshine-tiny",
)
model = FastSTTModel.get_peft_model(
model, r=16, lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
split="train[:10]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
collator = STTDataCollator(model, processor, language="en", task="transcribe",
audio_column="audio", text_column="sentence")
trainer = STTSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=STTSFTConfig(output_dir="./moonshine_output", max_steps=100),
)
trainer.train()
17 — Qwen3-ASR STT Fine-Tuning
Fine-tune Alibaba's Qwen3-ASR (audio-LLM) for multilingual speech recognition. Audio features are injected into a Qwen3 language model.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
model, processor = FastSTTModel.from_pretrained(
"mlx-community/Qwen3-ASR-1.7B-8bit",
)
model = FastSTTModel.get_peft_model(
model, r=16, lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
finetune_encoder=True, finetune_decoder=True,
)
dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
split="train[:10]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
collator = STTDataCollator(model, processor, language="en", task="transcribe",
audio_column="audio", text_column="sentence")
trainer = STTSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=STTSFTConfig(output_dir="./qwen3_asr_output", max_steps=100),
)
trainer.train()
18 — NVIDIA Canary STT Fine-Tuning
Fine-tune NVIDIA Canary for multilingual ASR. FastConformer encoder with Transformer decoder, supports 25+ languages and translation.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
model, processor = FastSTTModel.from_pretrained(
"eelcor/canary-1b-v2-mlx",
)
model = FastSTTModel.get_peft_model(
model, r=16, lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
finetune_encoder=True, finetune_decoder=True,
)
dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
split="train[:10]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
collator = STTDataCollator(model, processor, language="en", task="transcribe",
audio_column="audio", text_column="sentence")
trainer = STTSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=STTSFTConfig(output_dir="./canary_output", max_steps=100),
)
trainer.train()
19 — Voxtral STT Fine-Tuning
Fine-tune Mistral's Voxtral Mini (3B audio-LLM). Combines an audio encoder with a Llama decoder for speech recognition.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
model, processor = FastSTTModel.from_pretrained(
"mlx-community/Voxtral-Mini-3B-2507-bf16",
)
model = FastSTTModel.get_peft_model(
model, r=16, lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
finetune_encoder=True, finetune_decoder=True,
)
dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
split="train[:10]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
collator = STTDataCollator(model, processor, language="en", task="transcribe",
audio_column="audio", text_column="sentence")
trainer = STTSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=STTSFTConfig(output_dir="./voxtral_output", max_steps=100),
)
trainer.train()
49 — Voxtral Realtime Streaming STT
Fine-tune Mistral's streaming Voxtral Realtime. Causal encoder + AdaRMSNorm decoder, decoder LoRA only.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset
# All quantizations work via the same code path
model, processor = FastSTTModel.from_pretrained(
"mlx-community/Voxtral-Mini-4B-Realtime-2602-4bit", # ~2.5 GB
)
# Mistral-internal LoRA target naming (NOT q_proj!)
model = FastSTTModel.get_peft_model(
model,
r=8, lora_alpha=16,
target_modules=["wq", "wk", "wv", "wo"],
finetune_encoder=False, # causal encoder stays frozen
finetune_decoder=True,
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy",
"clean", split="validation").select(range(20))
collator = STTDataCollator(model, processor, language="en", task="transcribe",
audio_column="audio", text_column="text")
trainer = STTSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=ds,
args=STTSFTConfig(
per_device_train_batch_size=1, # required for Voxtral RT
max_steps=20,
learning_rate=2e-5, # low LR — AdaRMSNorm is sensitive
output_dir="./voxtral_realtime_output",
),
)
trainer.train()
# Save / reload / merge all work the same way
model.save_pretrained("./out/adapter")
model.save_pretrained_merged("./out/merged")
50 — Parakeet TDT English (LibriSpeech)
Fine-tune NVIDIA Parakeet TDT on English. Warm-started CTC head, LoRA encoder, baseline pipeline demo.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
# Warm-started CTC head from the pretrained joint network
model, processor = FastSTTModel.from_pretrained(
"mlx-community/parakeet-tdt-0.6b-v3",
)
model = FastSTTModel.get_peft_model(model, r=16, lora_alpha=32)
trainer = STTSFTTrainer(
model=model, tokenizer=processor,
data_collator=STTDataCollator(model=model, processor=processor),
train_dataset=train_ds,
args=STTSFTConfig(max_steps=150, learning_rate=3e-4, loss_type="ctc"),
)
trainer.train()
# Both inference paths still work
text_ctc = model.transcribe_ctc(audio_array) # new CTC head
text_tdt = model.transcribe_tdt(audio_array) # original TDT path
51 — Parakeet Welsh (new Latin-script language)
Fine-tune Parakeet on Welsh (FLEURS cy_gb). Latin-script chars already in the vocab — no extension needed.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset
model, processor = FastSTTModel.from_pretrained(
"mlx-community/parakeet-tdt-0.6b-v3",
warm_start_ctc_head=False, # avoid English-biased prior
)
model = FastSTTModel.get_peft_model(model, r=16, lora_alpha=32)
# Welsh from FLEURS — same code path for any Latin / Cyrillic / Greek language
ds = load_dataset("google/fleurs", "cy_gb", split="train[:80]",
trust_remote_code=True)
trainer = STTSFTTrainer(
model=model, tokenizer=processor,
data_collator=STTDataCollator(model=model, processor=processor),
train_dataset=ds,
args=STTSFTConfig(max_steps=100, learning_rate=3e-4, loss_type="ctc"),
)
trainer.train()
52 — Parakeet Bengali (char vocab extension)
Bengali characters all encode to UNK. Auto-extends the vocab with one call, then trains. Same path works for Hindi / Thai / CJK / Korean / Hebrew.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset
model, processor = FastSTTModel.from_pretrained(
"mlx-community/parakeet-tdt-0.6b-v3",
warm_start_ctc_head=False,
)
# Auto-detects ~60 missing Bengali chars and resizes the CTC head,
# joint output projection, and decoder embedding in lockstep.
ds = load_dataset("google/fleurs", "bn_in", split="train",
trust_remote_code=True)
texts = [ds[i]["transcription"] for i in range(len(ds))]
added = model.extend_vocabulary(texts, strategy="char")
print(f"Added {len(added)} new tokens")
model = FastSTTModel.get_peft_model(model, r=16, lora_alpha=32)
trainer = STTSFTTrainer(
model=model, tokenizer=processor,
data_collator=STTDataCollator(model=model, processor=processor),
train_dataset=ds,
args=STTSFTConfig(max_steps=150, learning_rate=3e-4, loss_type="ctc"),
)
trainer.train()
53 — Parakeet Arabic (BPE vocab extension)
Retrain SentencePiece BPE on the target corpus and combine it with the pretrained tokenizer. Better token efficiency than char-level.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset
model, processor = FastSTTModel.from_pretrained(
"mlx-community/parakeet-tdt-0.6b-v3",
warm_start_ctc_head=False,
)
# Trains a fresh 500-piece BPE on the Arabic corpus and installs an
# aggregate tokenizer over the pretrained + new SentencePiece models.
ds = load_dataset("google/fleurs", "ar_eg", split="train",
trust_remote_code=True)
texts = [ds[i]["transcription"] for i in range(len(ds))]
model.extend_vocabulary(texts, strategy="bpe", bpe_vocab_size=500)
model = FastSTTModel.get_peft_model(model, r=16, lora_alpha=32)
trainer = STTSFTTrainer(
model=model, tokenizer=processor,
data_collator=STTDataCollator(model=model, processor=processor),
train_dataset=ds,
args=STTSFTConfig(max_steps=150, learning_rate=3e-4, loss_type="ctc"),
)
trainer.train()
# Aux BPE model is persisted alongside the adapters
model.save_pretrained("./parakeet_arabic_finetuned/adapters")
Embedding Models
27 — BERT Embedding Fine-Tuning
Fine-tune all-MiniLM-L6-v2 (BERT) for semantic similarity using InfoNCE contrastive loss with in-batch negatives. Includes training, encoding, and similarity testing.
from mlx_tune import FastEmbeddingModel, EmbeddingSFTTrainer, EmbeddingSFTConfig, EmbeddingDataCollator
model, tokenizer = FastEmbeddingModel.from_pretrained(
"mlx-community/all-MiniLM-L6-v2-bf16",
pooling_strategy="mean",
)
model = FastEmbeddingModel.get_peft_model(model, r=16, lora_alpha=16,
target_modules=["query", "key", "value"])
train_data = [
{"anchor": "A man is eating food.", "positive": "A man is having a meal."},
# ... more pairs
]
trainer = EmbeddingSFTTrainer(
model=model, tokenizer=tokenizer,
data_collator=EmbeddingDataCollator(model, tokenizer),
train_dataset=train_data,
args=EmbeddingSFTConfig(loss_type="infonce", temperature=0.05, max_steps=50),
)
trainer.train()
# Encode and compare
embeddings = model.encode(["Hello world", "Hi there"])
similarity = (embeddings[0] * embeddings[1]).sum().item()
28 — Qwen3-Embedding Fine-Tuning
Fine-tune Qwen3-Embedding-0.6B (4-bit quantized) for domain-specific search. Decoder-based architecture with last-token pooling. Includes adapter save/load verification.
from mlx_tune import FastEmbeddingModel, EmbeddingSFTTrainer, EmbeddingSFTConfig, EmbeddingDataCollator
model, tokenizer = FastEmbeddingModel.from_pretrained(
"mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",
pooling_strategy="last_token", # Decoder-based model
)
model = FastEmbeddingModel.get_peft_model(model, r=8, lora_alpha=16)
# Auto-detected targets: q_proj, k_proj, v_proj, o_proj
trainer = EmbeddingSFTTrainer(
model=model, tokenizer=tokenizer,
data_collator=EmbeddingDataCollator(model, tokenizer),
train_dataset=train_data,
args=EmbeddingSFTConfig(loss_type="infonce", temperature=0.05, max_steps=30),
)
trainer.train()
# Save and reload adapter
model.save_pretrained("./adapter")
model2.load_adapter("./adapter") # On fresh model with LoRA applied
31 — Harrier 0.6B Embedding Fine-Tuning
Fine-tune Microsoft Harrier 0.6B for cross-lingual semantic search. Multilingual embedding model supporting 94 languages with last-token pooling. Includes adapter save/load verification.
from mlx_tune import FastEmbeddingModel, EmbeddingSFTTrainer, EmbeddingSFTConfig, EmbeddingDataCollator
model, tokenizer = FastEmbeddingModel.from_pretrained(
"microsoft/harrier-oss-v1-0.6b",
pooling_strategy="last_token", # Harrier uses last-token pooling
)
model = FastEmbeddingModel.get_peft_model(model, r=8, lora_alpha=16)
# Auto-detected targets: q_proj, k_proj, v_proj, o_proj
trainer = EmbeddingSFTTrainer(
model=model, tokenizer=tokenizer,
data_collator=EmbeddingDataCollator(model, tokenizer),
train_dataset=train_data, # Cross-lingual query-passage pairs
args=EmbeddingSFTConfig(loss_type="infonce", temperature=0.05, max_steps=30),
)
trainer.train()
# Save and reload adapter
model.save_pretrained("./adapter")
model2.load_adapter("./adapter") # On fresh model with LoRA applied
32 — Harrier 270M Embedding Fine-Tuning
Fine-tune Microsoft Harrier 270M for code/documentation search. Lightweight model (~540MB) ideal for fast iteration on Apple Silicon. Includes adapter save/load verification.
from mlx_tune import FastEmbeddingModel, EmbeddingSFTTrainer, EmbeddingSFTConfig, EmbeddingDataCollator
model, tokenizer = FastEmbeddingModel.from_pretrained(
"microsoft/harrier-oss-v1-270m",
pooling_strategy="last_token", # Harrier uses last-token pooling
max_seq_length=256,
)
model = FastEmbeddingModel.get_peft_model(model, r=8, lora_alpha=16)
# Auto-detected targets: q_proj, k_proj, v_proj, o_proj
trainer = EmbeddingSFTTrainer(
model=model, tokenizer=tokenizer,
data_collator=EmbeddingDataCollator(model, tokenizer, max_seq_length=256),
train_dataset=train_data, # Code/documentation Q&A pairs
args=EmbeddingSFTConfig(loss_type="infonce", temperature=0.05, max_steps=30, max_seq_length=256),
)
trainer.train()
# Save and reload adapter
model.save_pretrained("./adapter")
model2.load_adapter("./adapter") # On fresh model with LoRA applied
OCR Models
33 — OCR Document Fine-Tuning
Fine-tune DeepSeek-OCR on LaTeX document images. Learns to convert document scans into structured LaTeX markup with high fidelity.
from mlx_tune import FastOCRModel, OCRSFTTrainer, OCRSFTConfig, OCRDataCollator
from datasets import load_dataset
model, processor = FastOCRModel.from_pretrained(
"mlx-community/DeepSeek-OCR-0.5B-4bit",
)
model = FastOCRModel.get_peft_model(model, r=16, lora_alpha=16)
dataset = load_dataset("unsloth/LaTeX_OCR", split="train[:100]")
collator = OCRDataCollator(model, processor)
trainer = OCRSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=OCRSFTConfig(output_dir="./ocr_output", max_steps=60),
)
trainer.train()
34 — VLM-to-OCR Fine-Tuning
Adapt the Qwen3.5 vision-language model for OCR tasks. Leverages the VLM’s visual understanding to extract text from LaTeX document images.
from mlx_tune import FastOCRModel, OCRSFTTrainer, OCRSFTConfig, OCRDataCollator
from datasets import load_dataset
model, processor = FastOCRModel.from_pretrained(
"mlx-community/Qwen3.5-0.8B-bf16",
)
model = FastOCRModel.get_peft_model(
model, r=16, lora_alpha=16,
finetune_vision_layers=True, finetune_language_layers=True,
)
dataset = load_dataset("unsloth/LaTeX_OCR", split="train[:100]")
collator = OCRDataCollator(model, processor)
trainer = OCRSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=OCRSFTConfig(output_dir="./vlm_ocr_output", max_steps=60),
)
trainer.train()
35 — Handwriting OCR
Fine-tune DeepSeek-OCR-2 on handwritten text recognition using the Handwriting-OCR dataset. Handles varied writing styles and quality.
from mlx_tune import FastOCRModel, OCRSFTTrainer, OCRSFTConfig, OCRDataCollator
from datasets import load_dataset
model, processor = FastOCRModel.from_pretrained(
"mlx-community/DeepSeek-OCR-2-1B-4bit",
)
model = FastOCRModel.get_peft_model(model, r=16, lora_alpha=16)
dataset = load_dataset("Teklia/Handwriting-OCR", split="train[:100]")
collator = OCRDataCollator(model, processor, image_column="image", text_column="text")
trainer = OCRSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=OCRSFTConfig(output_dir="./handwriting_output", max_steps=60),
)
trainer.train()
36 — OCR GRPO Training
Train Qwen3.5 for OCR using GRPO with character error rate (CER) reward. Generates multiple OCR attempts and optimizes via policy gradient with edit-distance scoring.
from mlx_tune import FastOCRModel
from mlx_tune.ocr import OCRGRPOTrainer, OCRGRPOConfig
def cer_reward(response, ground_truth):
"""Character Error Rate reward: 1.0 - CER"""
import editdistance
dist = editdistance.eval(response, ground_truth)
return max(0.0, 1.0 - dist / max(len(ground_truth), 1))
model, processor = FastOCRModel.from_pretrained(
"mlx-community/Qwen3.5-0.8B-bf16",
)
model = FastOCRModel.get_peft_model(model, r=16, lora_alpha=16)
trainer = OCRGRPOTrainer(
model=model, processor=processor,
train_dataset=ocr_data, reward_fn=cer_reward,
args=OCRGRPOConfig(num_generations=2, max_completion_length=256,
learning_rate=1e-6, max_steps=20),
)
trainer.train()
37 — Multilingual OCR
Fine-tune GLM-OCR on CORD-v2 receipt dataset for multilingual document understanding. Extracts structured fields from scanned receipts across languages.
from mlx_tune import FastOCRModel, OCRSFTTrainer, OCRSFTConfig, OCRDataCollator
from datasets import load_dataset
model, processor = FastOCRModel.from_pretrained(
"mlx-community/GLM-OCR-2B-4bit",
)
model = FastOCRModel.get_peft_model(model, r=16, lora_alpha=16)
dataset = load_dataset("naver-clova-ix/cord-v2", split="train[:100]")
collator = OCRDataCollator(model, processor, image_column="image", text_column="ground_truth")
trainer = OCRSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=OCRSFTConfig(output_dir="./multilingual_ocr_output", max_steps=60),
)
trainer.train()
MoE Fine-Tuning
Fine-tune Mixture of Experts models with per-expert LoRA. Same API as dense models — MoE detection is automatic.
29 — Qwen3.5 MoE Fine-Tuning
Fine-tune Qwen3.5-35B-A3B (35B total, 3B active) with automatic MoE detection. Expert layers get LoRASwitchLinear, shared experts and dense layers get standard LoRALinear.
from mlx_tune import FastLanguageModel, SFTTrainer, SFTConfig
model, tok = FastLanguageModel.from_pretrained(
"mlx-community/Qwen3.5-35B-A3B-4bit", load_in_4bit=True)
# Same target_modules — MoE paths resolved automatically
model = FastLanguageModel.get_peft_model(model, r=8,
target_modules=["q_proj","k_proj","v_proj","o_proj",
"gate_proj","up_proj","down_proj"])
# Prints: "MoE architecture detected — LoRA will target expert layers"
SFTTrainer(model=model, train_dataset=dataset, tokenizer=tok,
args=SFTConfig(output_dir="outputs_moe", max_steps=10)).train()
30 — Phi-3.5 MoE Fine-Tuning
Fine-tune Microsoft’s Phi-3.5-MoE-instruct (42B total, 16 experts, top-2 routing). Demonstrates architecture-agnostic MoE support — PhiMoE uses block_sparse_moe paths.
from mlx_tune import FastLanguageModel, SFTTrainer, SFTConfig
# Microsoft Phi-3.5-MoE — different MoE structure, same API
model, tok = FastLanguageModel.from_pretrained(
"mlx-community/Phi-3.5-MoE-instruct-4bit", load_in_4bit=True)
model = FastLanguageModel.get_peft_model(model, r=8,
target_modules=["q_proj","k_proj","v_proj","o_proj",
"gate_proj","up_proj","down_proj"])
SFTTrainer(model=model, train_dataset=dataset, tokenizer=tok,
args=SFTConfig(output_dir="outputs_phi_moe", max_steps=10)).train()
54 — Trinity-Nano SFT (AFMoE)
Instruction SFT on Arcee’s Trinity-Nano-Preview (6B total / 1B active, 128 experts + 1 shared, sigmoid routing, gated attention) using mlabonne/FineTome-100k. Per-expert LoRA auto-applies via LoRASwitchLinear.
from mlx_tune import FastLanguageModel, SFTTrainer, SFTConfig
from datasets import load_dataset
model, tok = FastLanguageModel.from_pretrained(
"mlx-community/Trinity-Nano-Preview-4bit", load_in_4bit=True)
model = FastLanguageModel.get_peft_model(model, r=8, lora_alpha=16,
target_modules=["q_proj","k_proj","v_proj","o_proj",
"gate_proj","up_proj","down_proj"])
ds = load_dataset("mlabonne/FineTome-100k", split="train[:200]")
# ShareGPT (from/value) -> ChatML via apply_chat_template
SFTTrainer(model=model, tokenizer=tok, train_dataset=ds,
args=SFTConfig(max_steps=20, learning_rate=2e-4)).train()
model.save_pretrained("outputs_trinity_sft/saved_adapters")
model.save_pretrained_merged("outputs_trinity_sft/merged_model", tokenizer=tok)
55 — Trinity-Nano GRPO on GSM8K
Two-phase reasoning recipe: Phase-1 SFT warmup on GSM8K rationales reshaped into <reasoning>/<answer>, then Phase-2 GRPO with a layered reward (1.0 strict → 0.25 any-tag). Fixes the “zero reward variance” no-op trap on preview models.
from mlx_tune import (FastLanguageModel, SFTTrainer, SFTConfig,
GRPOTrainer, GRPOConfig)
# Phase 1 — SFT warmup teaches the XML format
SFTTrainer(model=model, tokenizer=tok, train_dataset=cot_formatted_gsm8k,
args=SFTConfig(max_steps=20, learning_rate=2e-4)).train()
# Phase 2 — GRPO with layered reward (partial credit creates variance)
def soft_reward(response, gt):
# 1.0 strict XML+correct / 0.7 format+correct / 0.5 correct-only
# 0.25 any-tag / 0.0 otherwise
...
GRPOTrainer(model=model, tokenizer=tok, train_dataset=grpo_prompts,
reward_fn=soft_reward,
args=GRPOConfig(num_generations=4, temperature=1.0,
learning_rate=5e-6, max_steps=8)).train()
56 — Trinity-Nano CPT on WikiText
Continual pretraining on Salesforce/wikitext (wikitext-2-raw-v1) with decoupled embedding LR. Handles Trinity’s quantized lm_head via the v0.4.25 _is_quantized fix that sees through LoRALinear wrappers.
from mlx_tune import FastLanguageModel, CPTTrainer, CPTConfig
from datasets import load_dataset
ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train")
# filter non-empty, then load into CPTTrainer
CPTTrainer(model=model, tokenizer=tok, train_dataset=ds,
args=CPTConfig(
learning_rate=5e-5,
embedding_learning_rate=5e-6, # 10x smaller protects embeddings
include_embeddings=True,
max_steps=20,
)).train()
Continual Pretraining
43 — CPT Language Adaptation
Teach a model a new language with continual pretraining on raw text. Loss is computed on all tokens, with optional embed_tokens/lm_head training at a decoupled learning rate.
from mlx_tune import FastLanguageModel, CPTTrainer, CPTConfig
model, tokenizer = FastLanguageModel.from_pretrained(
"mlx-community/SmolLM2-135M-Instruct",
max_seq_length=2048,
)
model = FastLanguageModel.get_peft_model(model, r=16, lora_alpha=16)
# Raw text in target language
dataset = [{"text": "Raw text in the target language..."}, ...]
trainer = CPTTrainer(
model=model, train_dataset=dataset, tokenizer=tokenizer,
args=CPTConfig(
output_dir="cpt_language", learning_rate=2e-5,
include_embeddings=True,
embedding_learning_rate=1e-5, max_steps=100,
),
)
trainer.train()
44 — CPT Domain Knowledge
Inject domain-specific knowledge (medical, legal, scientific) into a model via continual pretraining on domain corpora.
from mlx_tune import FastLanguageModel, CPTTrainer, CPTConfig
model, tokenizer = FastLanguageModel.from_pretrained(
"mlx-community/SmolLM2-360M-Instruct",
max_seq_length=2048,
)
model = FastLanguageModel.get_peft_model(model, r=16, lora_alpha=16)
# Domain-specific text corpus
dataset = load_dataset("medical_papers", split="train")
trainer = CPTTrainer(
model=model, train_dataset=dataset, tokenizer=tokenizer,
args=CPTConfig(
output_dir="cpt_domain", learning_rate=2e-5,
max_steps=200,
),
)
trainer.train()
45 — CPT Code Capabilities
Extend a model’s programming language coverage by pretraining on code corpora. Improves code generation and understanding.
from mlx_tune import FastLanguageModel, CPTTrainer, CPTConfig
model, tokenizer = FastLanguageModel.from_pretrained(
"mlx-community/SmolLM2-135M-Instruct",
max_seq_length=4096,
)
model = FastLanguageModel.get_peft_model(model, r=16, lora_alpha=16)
# Code corpus
dataset = load_dataset("code_corpus", split="train")
trainer = CPTTrainer(
model=model, train_dataset=dataset, tokenizer=tokenizer,
args=CPTConfig(
output_dir="cpt_code", learning_rate=2e-5,
max_seq_length=4096, max_steps=300,
),
)
trainer.train()
46 — LFM2 + Continual Pretraining
Combine LFM2 with continual pretraining. Adapts Liquid AI’s hybrid architecture to a new domain with raw text training.
from mlx_tune import FastLanguageModel, CPTTrainer, CPTConfig
model, tokenizer = FastLanguageModel.from_pretrained(
"mlx-community/LFM2-350M-4bit",
max_seq_length=2048, load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
model, r=16,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj",
"in_proj", "w1", "w2", "w3"],
)
trainer = CPTTrainer(
model=model, train_dataset=domain_text, tokenizer=tokenizer,
args=CPTConfig(
output_dir="cpt_lfm2", learning_rate=2e-5,
include_embeddings=True,
embedding_learning_rate=1e-5, max_steps=100,
),
)
trainer.train()