Audio Fine-Tuning
Fine-tune TTS and STT models on Apple Silicon. mlx-tune is the first library to offer native LoRA fine-tuning for audio models on MLX — train text-to-speech and speech-to-text models directly on your Mac.
Installation
Audio fine-tuning requires the optional [audio] extra, which installs mlx-audio, soundfile, and librosa.
# Install with audio support
uv pip install 'mlx-tune[audio]'
# System dependency: FFmpeg (required by mlx-audio for some codecs)
brew install ffmpeg
An Apple Silicon Mac with 16 GB+ of unified RAM is recommended; per-model memory estimates are listed under Tips below.
mlx-tune pins datasets<4.0.0. Version 4.0+ dropped soundfile support and requires torchcodec (which has FFmpeg version conflicts on macOS). If you see "please install torchcodec" errors, downgrade: uv pip install 'datasets<4.0.0'.
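A minimal sanity check for the pin before launching a long run (assumes datasets is importable):
import datasets
# mlx-tune pins datasets<4.0.0; verify the installed version.
major = int(datasets.__version__.split(".")[0])
assert major < 4, f"datasets {datasets.__version__}: run uv pip install 'datasets<4.0.0'"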
Supported Models
mlx-tune supports multiple TTS and STT model architectures. Models are auto-detected from their name — no manual profile configuration needed.
Text-to-Speech
| Model | Params | Codec | Sample Rate | Architecture |
|---|---|---|---|---|
| Orpheus-3B | 3B | SNAC | 24 kHz | Llama (decoder-only) |
| OuteTTS-1B | 1B | DAC | 24 kHz | Llama (decoder-only) |
| Spark-TTS | 0.5B | BiCodec | 16 kHz | Qwen2 (decoder-only) |
| Sesame/CSM-1B | 1B | Mimi | 24 kHz | Backbone + Decoder |
| Qwen3-TTS | 1.7B | Built-in (16-codebook Split RVQ) | 24 kHz | Talker (decoder-only) |
Speech-to-Text
| Model | Params | Preprocessor | Sample Rate | Architecture |
|---|---|---|---|---|
| Whisper (all sizes) | 39M–1.5B | Log-mel spectrogram | 16 kHz | Encoder-Decoder |
| Distil-Whisper | 756M | Log-mel spectrogram | 16 kHz | Encoder-Decoder (Whisper) |
| Moonshine | 27M–330M | Raw conv frontend | 16 kHz | Encoder-Decoder |
| Qwen3-ASR | 1.7B | Audio encoder | 16 kHz | Audio-LLM (Qwen3) |
| NVIDIA Canary | 1B | Parakeet mel | 16 kHz | Encoder-Decoder (Conformer) |
| Voxtral | 3B | Mel spectrogram | 16 kHz | Audio-LLM (Llama) |
| Voxtral Realtime | 4B | Causal mel + delay tokens | 16 kHz | Streaming (causal encoder + AdaRMSNorm decoder) |
| Parakeet TDT NEW | 0.6B–1.1B | Parakeet log-mel (128) | 16 kHz | FastConformer + TDT transducer, streaming (CTC/RNN-T/TDT losses) |
Gemma 4 E2B/E4B models have a built-in 12-layer Conformer audio tower for STT/ASR. Unlike the standalone STT models above, Gemma 4 audio is a modality encoder on a VLM — it uses FastVisionModel, UnslothVisionDataCollator, and VLMSFTTrainer instead of the STT-specific classes.
See the VLM page for details, or examples 47 (ASR) and 48 (Audio Understanding).
from mlx_tune import FastVisionModel
model, processor = FastVisionModel.from_pretrained("mlx-community/gemma-4-e4b-it-4bit")
model = FastVisionModel.get_peft_model(model,
finetune_vision_layers=False, finetune_language_layers=True,
finetune_audio_layers=False, # True for acoustic domain adaptation
r=16, lora_alpha=16)
# Dataset: {"type": "audio", "audio": "path.wav"} in messages
# Inference: model.generate(audio="path.wav", prompt="Transcribe.")
Text-to-Speech
mlx_tune.tts
TTS models in mlx-tune are decoder-only language models that predict discrete audio tokens. The API is the same regardless of model — FastTTSModel auto-detects the architecture and codec. Each model uses a different audio codec (SNAC, DAC, BiCodec, or Mimi), but this is handled transparently.
Orpheus-3B
Llama-based model using the SNAC codec (3 codebooks, 24 kHz). The original and largest supported TTS model.
from mlx_tune import FastTTSModel, TTSSFTTrainer, TTSSFTConfig, TTSDataCollator
from datasets import load_dataset, Audio
model, tokenizer = FastTTSModel.from_pretrained(
"mlx-community/orpheus-3b-0.1-ft-bf16",
)
model = FastTTSModel.get_peft_model(model, r=16, lora_alpha=16)
dataset = load_dataset("MrDragonFox/Elise", split="train[:100]")
dataset = dataset.cast_column("audio", Audio(sampling_rate=24000))
collator = TTSDataCollator(model, tokenizer)
trainer = TTSSFTTrainer(
model=model, tokenizer=tokenizer,
data_collator=collator, train_dataset=dataset,
args=TTSSFTConfig(output_dir="./tts_output", max_steps=60),
)
trainer.train()
OuteTTS-1B
Llama-based model using the DAC codec (2 codebooks, 24 kHz). Uses text-based audio tokens (<|c1_X|>, <|c2_X|>) instead of numeric offsets. Smaller and faster than Orpheus.
from mlx_tune import FastTTSModel, TTSSFTTrainer, TTSSFTConfig, TTSDataCollator
from datasets import load_dataset, Audio
model, tokenizer = FastTTSModel.from_pretrained(
"mlx-community/Llama-OuteTTS-1.0-1B-8bit",
)
model = FastTTSModel.get_peft_model(model, r=16, lora_alpha=16)
dataset = load_dataset("MrDragonFox/Elise", split="train[:100]")
dataset = dataset.cast_column("audio", Audio(sampling_rate=24000))
collator = TTSDataCollator(model, tokenizer)
trainer = TTSSFTTrainer(
model=model, tokenizer=tokenizer,
data_collator=collator, train_dataset=dataset,
args=TTSSFTConfig(output_dir="./outetts_output", max_steps=60),
)
trainer.train()
Spark-TTS (0.5B)
Qwen2-based model using BiCodec (global + semantic tokens, 16 kHz). Ultra-light at only 0.5B parameters — ideal for Apple Silicon with limited RAM.
from mlx_tune import FastTTSModel, TTSSFTTrainer, TTSSFTConfig, TTSDataCollator
from datasets import load_dataset, Audio
model, tokenizer = FastTTSModel.from_pretrained(
"mlx-community/Spark-TTS-0.5B-bf16",
)
model = FastTTSModel.get_peft_model(model, r=16, lora_alpha=16)
# Spark uses 16kHz (not 24kHz like Orpheus/OuteTTS)
dataset = load_dataset("MrDragonFox/Elise", split="train[:100]")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
collator = TTSDataCollator(model, tokenizer)
trainer = TTSSFTTrainer(
model=model, tokenizer=tokenizer,
data_collator=collator, train_dataset=dataset,
args=TTSSFTConfig(
output_dir="./spark_output", max_steps=60,
sample_rate=16000,
),
)
trainer.train()
Qwen3-TTS (1.7B)
Alibaba's multilingual TTS model (ZH, EN, JA, KO, and more) with voice design. A 28-layer talker transformer predicts discrete audio codes via a built-in 16-codebook speech tokenizer running at 12.5 Hz. Its dual-embedding architecture sums the text and codec token embeddings at each input position.
from mlx_tune import FastTTSModel, TTSSFTTrainer, TTSSFTConfig, TTSDataCollator
from datasets import load_dataset, Audio
model, tokenizer = FastTTSModel.from_pretrained(
"mlx-community/Qwen3-TTS-12Hz-1.7B-VoiceDesign-bf16",
)
model = FastTTSModel.get_peft_model(model, r=16, lora_alpha=16)
dataset = load_dataset("MrDragonFox/Elise", split="train[:100]")
dataset = dataset.cast_column("audio", Audio(sampling_rate=24000))
collator = TTSDataCollator(model, tokenizer)
trainer = TTSSFTTrainer(
model=model, tokenizer=tokenizer,
data_collator=collator, train_dataset=dataset,
args=TTSSFTConfig(output_dir="./qwen3_tts_output", max_steps=60),
)
trainer.train()
FastTTSModel API
Main entry point for loading and configuring TTS models. Auto-detects model architecture, codec, and token format from the model name.
FastTTSModel.from_pretrained()
Load a pretrained TTS model. Auto-detects the model profile and loads the appropriate codec.
| Parameter | Type | Description |
|---|---|---|
| model_name | str | HuggingFace model ID or local path. Supported: Orpheus, OuteTTS, Spark-TTS, Sesame/CSM, Qwen3-TTS |
| max_seq_length | int | Maximum sequence length for training/inference |
| codec_model | str, optional | Override codec model ID. Auto-detected from profile if not specified |
| load_in_4bit | bool | Load model with 4-bit quantization |
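Putting the parameters together (a sketch; max_seq_length=2048 is an illustrative value):
from mlx_tune import FastTTSModel

# Codec and profile are auto-detected from the model name.
model, tokenizer = FastTTSModel.from_pretrained(
    "mlx-community/orpheus-3b-0.1-ft-bf16",
    max_seq_length=2048,
    load_in_4bit=True,
)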
FastTTSModel.get_peft_model()
Add LoRA adapters to the TTS model for parameter-efficient fine-tuning.
| Parameter | Type | Description |
|---|---|---|
| model | TTSModelWrapper | Model from from_pretrained() |
| r | int | LoRA rank (higher = more parameters) |
| target_modules | list[str], optional | Modules to apply LoRA to. Default: ["q_proj", "k_proj", "v_proj", "o_proj"] |
| lora_alpha | int | LoRA scaling factor. Recommended: equal to r |
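For example, a higher-rank run with the default projections made explicit (a sketch; r=32 is illustrative, model comes from from_pretrained()):
from mlx_tune import FastTTSModel

model = FastTTSModel.get_peft_model(
    model,
    r=32,
    lora_alpha=32,  # recommended: equal to r
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # the documented default
)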
FastTTSModel.convert()
Convert a HuggingFace TTS model to MLX format.
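For example (the same call appears under Post-Training below):
from mlx_tune import FastTTSModel

# Convert and 4-bit-quantize a HuggingFace TTS checkpoint to MLX format.
FastTTSModel.convert(
    "canopylabs/orpheus-3b",
    quantize=True,
    q_bits=4,
)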
TTSModelWrapper Methods
TTSSFTTrainer
TTS fine-tuning trainer. Batch size is forced to 1 (audio sequences have variable lengths).
| Parameter | Type | Description |
|---|---|---|
| model | TTSModelWrapper | Model with LoRA adapters configured |
| tokenizer | Tokenizer | Tokenizer from from_pretrained() |
| data_collator | TTSDataCollator | Collator for encoding audio |
| train_dataset | Dataset | Dataset with audio and text columns |
| args | TTSSFTConfig | Training configuration |
TTSSFTConfig
TTS-specific training configuration. Extends SFTConfig with audio parameters.
| Parameter | Default | Description |
|---|---|---|
| per_device_train_batch_size | 1 | Forced to 1 for audio training |
| gradient_accumulation_steps | 4 | Number of gradient accumulation steps |
| max_steps | 60 | Total training steps |
| learning_rate | 2e-4 | Peak learning rate |
| sample_rate | 24000 | Audio sample rate in Hz |
| train_on_completions | True | Train only on audio tokens (not text prompt) |
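A configuration spelling out each of the defaults above (a sketch):
from mlx_tune import TTSSFTConfig

args = TTSSFTConfig(
    output_dir="./tts_output",
    per_device_train_batch_size=1,   # forced to 1 for audio
    gradient_accumulation_steps=4,   # effective batch size 4
    max_steps=60,
    learning_rate=2e-4,
    sample_rate=24000,               # match your dataset's Audio() cast
    train_on_completions=True,       # loss on audio tokens only
)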
TTSDataCollator
Data collator for TTS training. Encodes audio via the appropriate codec (SNAC, DAC, BiCodec, or Mimi), tokenizes text prompts, builds concatenated sequences, and masks prompt tokens so loss is computed only on audio outputs. Handles both numeric and text-based audio token formats automatically.
Speech-to-Text
mlx_tune.stt
Most STT models in mlx-tune are encoder-decoder architectures: the encoder processes audio input (mel spectrograms or raw waveforms) and the decoder generates text tokens autoregressively, with LoRA adapters applied to both encoder and decoder attention blocks. Audio-LLMs (Qwen3-ASR, Voxtral), the streaming Voxtral Realtime, and the Parakeet transducer follow different layouts, described per model below. The API auto-detects the model type.
Whisper / Distil-Whisper
Whisper processes 80-bin log-mel spectrograms padded to 30 seconds. All Whisper sizes (tiny through large-v3) and Distil-Whisper variants are supported. LoRA targets: query, key, value, out in self-attention and cross-attention.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio
model, processor = FastSTTModel.from_pretrained(
"mlx-community/whisper-tiny-asr-fp16",
)
model = FastSTTModel.get_peft_model(
model, r=16, finetune_encoder=True, finetune_decoder=True,
)
dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
split="train[:100]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
collator = STTDataCollator(model, processor, language="en", task="transcribe")
trainer = STTSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=STTSFTConfig(output_dir="./stt_output", max_steps=100),
)
trainer.train()
Distil-Whisper variants auto-detect as the whisper profile and work with zero configuration changes. Just swap the model name:
FastSTTModel.from_pretrained("mlx-community/distil-whisper-large-v3")
Moonshine
Moonshine from Useful Sensors is an efficient STT model designed for edge devices. It processes raw audio waveforms directly through a convolutional frontend (no mel spectrogram needed). Very fast inference. LoRA targets: q_proj, k_proj, v_proj, o_proj.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio
model, processor = FastSTTModel.from_pretrained(
"UsefulSensors/moonshine-tiny",
)
model = FastSTTModel.get_peft_model(
model, r=16, lora_alpha=16,
# Moonshine uses different attention names than Whisper
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
finetune_encoder=True, finetune_decoder=True,
)
dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
split="train[:100]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
collator = STTDataCollator(model, processor, language="en", task="transcribe")
trainer = STTSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=STTSFTConfig(output_dir="./moonshine_output", max_steps=100),
)
trainer.train()
Qwen3-ASR
Alibaba's Qwen3-ASR is an audio-LLM that injects audio features into a Qwen3 language model. Supports 30+ languages. The audio encoder has 24 layers and the Qwen3 decoder has 28 layers. LoRA targets both encoder and decoder.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio
model, processor = FastSTTModel.from_pretrained(
"mlx-community/Qwen3-ASR-1.7B-8bit",
)
model = FastSTTModel.get_peft_model(
model, r=16, lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
finetune_encoder=True, finetune_decoder=True,
)
# Same dataset pattern as the Whisper example above
dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
                       split="train[:100]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
collator = STTDataCollator(model, processor, language="en", task="transcribe")
trainer = STTSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=STTSFTConfig(output_dir="./qwen3_asr_output", max_steps=100),
)
trainer.train()
NVIDIA Canary
NVIDIA's Canary model uses a FastConformer encoder and Transformer decoder with cross-attention. Supports 25+ languages and speech translation. Encoder uses linear_q/k/v/linear_out (not q_proj), decoder uses q_proj/k_proj/v_proj/out_proj.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio
model, processor = FastSTTModel.from_pretrained(
"eelcor/canary-1b-v2-mlx",
)
model = FastSTTModel.get_peft_model(
model, r=16, lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
finetune_encoder=True, finetune_decoder=True,
)
dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
                       split="train[:100]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
collator = STTDataCollator(model, processor, language="en", task="transcribe")
trainer = STTSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=STTSFTConfig(output_dir="./canary_output", max_steps=100),
)
trainer.train()
Voxtral
Mistral's Voxtral combines an audio encoder (32 layers) with a Llama language model decoder (30 layers). It processes mel spectrograms and merges audio embeddings into the LLM's token sequence. LoRA targets: encoder q_proj/k_proj/v_proj/out_proj, decoder q_proj/k_proj/v_proj/o_proj.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio
model, processor = FastSTTModel.from_pretrained(
"mlx-community/Voxtral-Mini-3B-2507-bf16",
)
model = FastSTTModel.get_peft_model(
model, r=16, lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
finetune_encoder=True, finetune_decoder=True,
)
dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
                       split="train[:100]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
collator = STTDataCollator(model, processor, language="en", task="transcribe")
trainer = STTSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=STTSFTConfig(output_dir="./voxtral_output", max_steps=100),
)
trainer.train()
Voxtral Realtime NEW v0.4.21
Mistral's streaming ASR model — a fundamentally different architecture from regular Voxtral. Causal audio encoder (32 layers, sliding window 750) feeds an adapter MLP into a 26-layer LLM decoder with AdaRMSNorm time conditioning and additive embedding fusion (the input at each position is adapter_out[pos] + decoder.embed_token(token), no audio placeholder tokens).
This is the third architecture type in mlx-tune's STT subsystem (alongside encoder_decoder and audio_llm). LoRA targets use Mistral-internal naming: wq/wk/wv/wo + feed_forward_w1/w2/w3 — NOT q_proj/gate_proj.
mistral-common
mlx-audio's TekkenTokenizer is decode-only. Fine-tuning needs the encode side from mistral-common, which is included in the [audio] extras since v0.4.21. Install with uv pip install 'mlx-tune[audio]' --upgrade.
The Tekken tokenizer is byte-level BPE (131K vocab) — it encodes any UTF-8 text with lossless roundtrip. The audio encoder processes mel spectrograms language-agnostically. So fine-tuning on a new language requires zero tokenizer changes — just swap the dataset. Verified E2E with Turkish (FLEURS), loss 10.1→8.8 in 10 steps. Works with Common Voice, FLEURS, VoxPopuli, or any HF dataset with an audio + text column.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio
# All quantizations work via the same code path (LoRA on QuantizedLinear too)
model, processor = FastSTTModel.from_pretrained(
"mlx-community/Voxtral-Mini-4B-Realtime-2602-4bit", # ~2.5 GB
# or: "...-2602-fp16" (~8 GB, highest fidelity)
# or: "...-Realtime-6bit"
)
model = FastSTTModel.get_peft_model(
model,
r=8, lora_alpha=16,
target_modules=["wq", "wk", "wv", "wo"], # Mistral-internal naming
finetune_encoder=False, # causal encoder stays frozen
finetune_decoder=True,
)
# Any HF dataset with an audio + text column works (Common Voice, FLEURS, VoxPopuli, ...)
dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
                       split="train[:100]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
collator = STTDataCollator(model, processor, language="en", task="transcribe")
trainer = STTSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=STTSFTConfig(
per_device_train_batch_size=1, # required for Voxtral RT
max_steps=20,
learning_rate=2e-5, # low LR critical — AdaRMSNorm is sensitive
output_dir="./voxtral_realtime_output",
),
)
trainer.train()
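Because the tokenizer and encoder are language-agnostic, training on a new language is only a dataset swap. A sketch with Turkish FLEURS, matching the verified E2E run above (google/fleurs and the split size are illustrative choices, not part of the library):
from datasets import load_dataset, Audio

# Swap in FLEURS Turkish; the rest of the pipeline is unchanged.
dataset = load_dataset("google/fleurs", "tr_tr", split="train[:100]")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))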
NVIDIA Parakeet TDT NEW v0.4.22
NVIDIA's state-of-the-art FastConformer + Token-and-Duration Transducer ASR. The fourth architecture type in mlx-tune's STT subsystem alongside encoder_decoder, audio_llm, and voxtral_realtime.
Parakeet is unusual: the decoder is a 2-layer LSTM prediction network (not LoRA-friendly), the joint network outputs [B, T, U+1, V+1+D] where D=5 duration bins, and training needs transducer loss. mlx-tune solves this by:
- Mounting a new CTC head on the Conformer encoder, warm-started from the pretrained joint network's output projection.
- Three loss functions in pure MLX: ctc_loss (fastest), rnnt_loss (Graves 2012 forward algorithm), tdt_loss (extends RNN-T with duration prediction).
- LoRA on the Conformer encoder's self-attention (linear_q/k/v/linear_out/linear_pos) — note the Conformer-internal naming, not q_proj.
- Automatic vocabulary extension via model.extend_vocabulary() for any Unicode language. Two strategies: "char" (per-character auto extension, zero config) or "bpe" (retrain SentencePiece on your corpus, merge via aggregate tokenizer).
- BatchNorm freeze in the Conformer conv modules to prevent running-stat drift during fine-tuning.
- Streaming preserved: chunked streaming inference works for both the native TDT path (model.generate(audio, stream=True)) on the original training languages, and the new CTC head (model.stream_transcribe_ctc()) for fine-tuned and new-language models.
Parakeet v3's SentencePiece vocabulary covers Latin, Cyrillic, and Greek scripts well (~7,900 pieces across 25 European languages). For these scripts you can fine-tune directly with zero tokenizer changes. For non-Latin scripts (Arabic, Bengali, Devanagari, CJK, Korean, Thai, Hebrew...), call model.extend_vocabulary(texts, strategy="char") or "bpe" and mlx-tune will automatically resize the CTC head, joint output projection, and decoder embedding in lockstep.
from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio
# Load the 0.6B FastConformer TDT model
model, processor = FastSTTModel.from_pretrained(
"mlx-community/parakeet-tdt-0.6b-v3",
# For non-training-set languages (Welsh, Bengali, Arabic, etc.), pass
# warm_start_ctc_head=False to start from a fresh random head — the
# warm-start has an English-biased prior that hurts new-language convergence.
)
# For non-Latin languages: auto-extend the vocabulary BEFORE applying LoRA.
# This scans `texts` for characters the pretrained SP tokenizer can't encode,
# adds them as new tokens, and resizes the CTC head + joint output + decoder
# embedding. Works for any Unicode language out of the box.
#
# added = model.extend_vocabulary(
# all_transcripts, strategy="char", min_count=1
# )
# # Or BPE retraining for better compression:
# added = model.extend_vocabulary(
# all_transcripts, strategy="bpe", bpe_vocab_size=500
# )
# LoRA on the Conformer self-attention (default: q/k/v/out + linear_pos)
model = FastSTTModel.get_peft_model(model, r=16, lora_alpha=32)
dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
                       split="train[:100]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
collator = STTDataCollator(model, processor)
trainer = STTSFTTrainer(
model=model, processor=processor,
data_collator=collator, train_dataset=dataset,
args=STTSFTConfig(
per_device_train_batch_size=1,
max_steps=150,
learning_rate=3e-4,
loss_type="ctc", # or "rnnt" / "tdt" / "hybrid"
output_dir="./parakeet_output",
),
)
trainer.train()
# Transcribe via the new CTC head
text = model.transcribe_ctc("audio.wav")
# The original TDT decoder still works for languages Parakeet was trained on
text = model.transcribe_tdt("audio.wav")
# Streaming CTC via chunked inference
for partial in model.stream_transcribe_ctc("long_audio.wav"):
print(partial)
| Loss type | Forward pass | Notes |
|---|---|---|
| ctc | encoder → new CTC head | Fastest, simplest. Joint network not touched. |
| rnnt | encoder → joint(enc, dec) slice [:V+1] | Standard RNN-T (Graves 2012). Uses existing joint + LSTM. |
| tdt | encoder → joint(enc, dec) full | Matches NVIDIA's native TDT recipe with duration regularization. |
| hybrid | CTC + TDT weighted sum | Matches NVIDIA's Stage 2 recipe (ctc_weight=0.3, tdt_weight=1.0). |
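Selecting a loss is a one-line config change, e.g. the hybrid recipe (a sketch; the weight defaults come from the table above):
from mlx_tune import STTSFTConfig

args = STTSFTConfig(
    loss_type="hybrid",   # CTC + TDT weighted sum (NVIDIA Stage 2 recipe)
    max_steps=150,
    learning_rate=3e-4,
    output_dir="./parakeet_hybrid",
)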
FastSTTModel API
Main entry point for loading and configuring STT models. Auto-detects model architecture and preprocessor from the model name.
FastSTTModel.from_pretrained()
Load a pretrained STT model. Returns model wrapper and processor.
| Parameter | Type | Description |
|---|---|---|
| model_name | str | HuggingFace model ID or local path. Supported: Whisper, Distil-Whisper, Moonshine, Qwen3-ASR, Canary, Voxtral, Voxtral Realtime, Parakeet TDT |
| max_seq_length | int | Maximum decoder sequence length |
FastSTTModel.get_peft_model()
Add LoRA adapters to encoder and/or decoder attention layers.
| Parameter | Type | Description |
|---|---|---|
| model | STTModelWrapper | Model from from_pretrained() |
| r | int | LoRA rank |
| finetune_encoder | bool | Apply LoRA to encoder attention blocks |
| finetune_decoder | bool | Apply LoRA to decoder attention blocks |
| lora_alpha | int | LoRA scaling factor. Recommended: equal to r |
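One common pattern is decoder-only adaptation (a sketch; assumes model came from from_pretrained()):
from mlx_tune import FastSTTModel

model = FastSTTModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,            # recommended: equal to r
    finetune_encoder=False,   # keep the audio encoder frozen
    finetune_decoder=True,    # adapt the text decoder only
)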
FastSTTModel.convert()
Convert a HuggingFace Whisper model to MLX format. Uses mlx_audio.convert with model_domain="stt".
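A sketch, assuming the signature mirrors FastTTSModel.convert (quantize/q_bits):
from mlx_tune import FastSTTModel

# Convert a HF Whisper checkpoint to MLX format (signature assumed above).
FastSTTModel.convert(
    "openai/whisper-small",
    quantize=True,
    q_bits=4,
)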
STTModelWrapper Methods
STTSFTTrainer
STT fine-tuning trainer. Uses sequence-to-sequence training with encoder-decoder architecture. Batch size is forced to 1.
| Parameter | Type | Description |
|---|---|---|
| model | STTModelWrapper | Model with LoRA adapters configured |
| processor | STTProcessor | Processor from from_pretrained() |
| data_collator | STTDataCollator | Collator for computing mel spectrograms |
| train_dataset | Dataset | Dataset with audio and transcript columns |
| args | STTSFTConfig | Training configuration |
STTSFTConfig
STT-specific training configuration. Extends SFTConfig with audio parameters.
| Parameter | Default | Description |
|---|---|---|
| per_device_train_batch_size | 1 | Forced to 1 for audio training |
| gradient_accumulation_steps | 4 | Number of gradient accumulation steps |
| max_steps | 100 | Total training steps |
| learning_rate | 1e-5 | Peak learning rate (lower than TTS) |
| sample_rate | 16000 | Audio sample rate in Hz (Whisper uses 16 kHz) |
| language | "en" | Target language for transcription |
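All of the defaults from the table, written out (a sketch):
from mlx_tune import STTSFTConfig

args = STTSFTConfig(
    output_dir="./stt_output",
    per_device_train_batch_size=1,   # forced to 1 for audio
    gradient_accumulation_steps=4,
    max_steps=100,
    learning_rate=1e-5,              # lower than TTS
    sample_rate=16000,
    language="en",
)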
STTDataCollator
Data collator for STT training. For Whisper models, computes 80-bin log-mel spectrograms padded to 30 seconds. For Moonshine, passes raw audio through the conv frontend. Tokenizes transcription targets with language-specific SOT tokens.
STTProcessor
Processor wrapper for STT models. Provides mel spectrogram computation and text tokenization.
Post-Training
Save, merge, convert, and share your fine-tuned audio models. These operations work the same for both TTS and STT.
Save adapters
# Save LoRA adapters only (small, shareable)
model.save_pretrained("./adapters")
Merge LoRA into base model
# Fuse LoRA weights into the base model and save
model.save_pretrained_merged("./merged_model")
Convert HuggingFace model to MLX
# Convert and optionally quantize a HF model
FastTTSModel.convert(
"canopylabs/orpheus-3b",
quantize=True,
q_bits=4,
)
Push to HuggingFace Hub
model.push_to_hub(
"username/my-model",
token="hf_...",
private=True,
)
Examples
12 — Orpheus TTS Fine-Tuning
Fine-tune Orpheus-3B for text-to-speech with SNAC audio codec. Includes data loading, training, and inference.
13 — Whisper STT Fine-Tuning
Fine-tune Whisper for speech-to-text transcription. Encoder-decoder LoRA with mel spectrogram inputs.
14 — OuteTTS Fine-Tuning
Fine-tune OuteTTS-1B (Llama-based, DAC codec) for voice cloning. Text-based audio tokens, same API as Orpheus.
15 — Spark-TTS Fine-Tuning
Fine-tune Spark-TTS (0.5B, Qwen2-based) with BiCodec. Ultra-light TTS ideal for limited RAM.
16 — Moonshine STT Fine-Tuning
Fine-tune Moonshine for speech-to-text. Raw conv frontend (no mel spectrogram), fast edge-device STT.
17 — Qwen3-ASR Fine-Tuning
Fine-tune Alibaba's Qwen3-ASR (audio-LLM). Audio features injected into Qwen3 LLM for multilingual ASR.
18 — Canary STT Fine-Tuning
Fine-tune NVIDIA Canary (FastConformer + Transformer decoder). Multilingual ASR with 25+ languages.
19 — Voxtral STT Fine-Tuning
Fine-tune Mistral's Voxtral (audio encoder + Llama decoder). 3B audio-LLM for speech recognition.
49 — Voxtral Realtime Streaming STT NEW
Fine-tune Mistral's streaming Voxtral Realtime on LibriSpeech. Causal encoder + AdaRMSNorm decoder + additive embedding fusion. Works on all quantization variants (4bit/6bit/fp16).
Tips
Audio training forces batch_size=1. Audio sequences have variable lengths, so batching is not supported. Use gradient_accumulation_steps to simulate larger effective batch sizes.
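For example, an effective batch size of 8 (a sketch):
from mlx_tune import STTSFTConfig

args = STTSFTConfig(
    per_device_train_batch_size=1,   # fixed for audio training
    gradient_accumulation_steps=8,   # effective batch size = 1 x 8
    output_dir="./output",
)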
mlx-tune auto-detects the model type from the model name. You don't need to specify profiles, codecs, or token formats manually — just pass the model ID to from_pretrained().
Different models use different sample rates. Always cast your dataset audio column to match:
- 24 kHz: Orpheus, OuteTTS, Sesame/CSM, Qwen3-TTS
- 16 kHz: Spark-TTS, Whisper, Distil-Whisper, Moonshine, Qwen3-ASR, Canary, Voxtral, Voxtral Realtime, Parakeet TDT
dataset.cast_column("audio", Audio(sampling_rate=24000))
Use MLX-format models from mlx-community on HuggingFace:
- Orpheus: mlx-community/orpheus-3b-0.1-ft-bf16
- OuteTTS: mlx-community/Llama-OuteTTS-1.0-1B-8bit
- Spark: mlx-community/Spark-TTS-0.5B-bf16
- Whisper: mlx-community/whisper-*-asr-fp16
- Distil-Whisper: mlx-community/distil-whisper-large-v3
- Moonshine: UsefulSensors/moonshine-tiny
- Qwen3-ASR: mlx-community/Qwen3-ASR-1.7B-8bit
- Canary: eelcor/canary-1b-v2-mlx
- Voxtral: mlx-community/Voxtral-Mini-3B-2507-bf16
16 GB+ unified RAM recommended. Approximate requirements:
- Spark-TTS (0.5B): ~4 GB
- OuteTTS (1B, 8-bit): ~4 GB
- Whisper-tiny / Moonshine-tiny: ~2 GB
- Orpheus-3B (4-bit): ~8 GB
- Distil-Whisper-large-v3: ~6 GB
- Qwen3-ASR (1.7B, 8-bit): ~4 GB
- Canary (1B): ~4 GB
- Voxtral (3B, bf16): ~12 GB
TTS training uses response-only training by default (train_on_completions=True). Loss is computed only on audio tokens, not on the text prompt — this matches how the model is used at inference time.
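If you do want loss over the text prompt as well (e.g. for an ablation), flip the flag (a sketch):
from mlx_tune import TTSSFTConfig

args = TTSSFTConfig(
    train_on_completions=False,  # compute loss over prompt + audio tokens
    output_dir="./tts_full_seq",
)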