Audio Fine-Tuning

Fine-tune TTS and STT models on Apple Silicon. mlx-tune is the first library to offer native LoRA fine-tuning for audio models on MLX — train text-to-speech and speech-to-text models directly on your Mac.

Installation

Audio fine-tuning requires the optional [audio] extra, which installs mlx-audio, soundfile, and librosa.

# Install with audio support
uv pip install 'mlx-tune[audio]'

# System dependency: FFmpeg (required by mlx-audio for some codecs)
brew install ffmpeg

Requirements

Python packages (mlx-audio, soundfile, librosa) are installed automatically with the [audio] extra. FFmpeg is a system dependency needed by some audio codecs; install it via brew install ffmpeg. An Apple Silicon Mac with 16 GB+ unified memory is recommended.

datasets version

mlx-tune pins datasets<4.0.0. Version 4.0+ dropped soundfile support and requires torchcodec (which has FFmpeg version conflicts on macOS). If you see "please install torchcodec" errors, downgrade: uv pip install 'datasets<4.0.0'.

Supported Models

mlx-tune supports multiple TTS and STT model architectures. Models are auto-detected from their name — no manual profile configuration needed.

Text-to-Speech

| Model | Params | Codec | Sample Rate | Architecture |
| --- | --- | --- | --- | --- |
| Orpheus-3B | 3B | SNAC | 24 kHz | Llama (decoder-only) |
| OuteTTS-1B | 1B | DAC | 24 kHz | Llama (decoder-only) |
| Spark-TTS | 0.5B | BiCodec | 16 kHz | Qwen2 (decoder-only) |
| Sesame/CSM-1B | 1B | Mimi | 24 kHz | Backbone + Decoder |
| Qwen3-TTS | 1.7B | Built-in (16-codebook Split RVQ) | 24 kHz | Talker (decoder-only) |

Speech-to-Text

| Model | Params | Preprocessor | Sample Rate | Architecture |
| --- | --- | --- | --- | --- |
| Whisper (all sizes) | 39M–1.5B | Log-mel spectrogram | 16 kHz | Encoder-Decoder |
| Distil-Whisper | 756M | Log-mel spectrogram | 16 kHz | Encoder-Decoder (Whisper) |
| Moonshine | 27M–330M | Raw conv frontend | 16 kHz | Encoder-Decoder |
| Qwen3-ASR | 1.7B | Audio encoder | 16 kHz | Audio-LLM (Qwen3) |
| NVIDIA Canary | 1B | Parakeet mel | 16 kHz | Encoder-Decoder (Conformer) |
| Voxtral | 3B | Mel spectrogram | 16 kHz | Audio-LLM (Llama) |
| Voxtral Realtime | 4B | Causal mel + delay tokens | 16 kHz | Streaming (causal encoder + AdaRMSNorm decoder) |
| Parakeet TDT | 0.6B–1.1B | Parakeet log-mel (128) | 16 kHz | FastConformer + TDT transducer, streaming (CTC/RNN-T/TDT losses) |
Gemma 4 Audio (VLM-based STT)

Gemma 4 E2B/E4B models have a built-in 12-layer Conformer audio tower for STT/ASR. Unlike the standalone STT models above, Gemma 4 audio is a modality encoder on a VLM — it uses FastVisionModel, UnslothVisionDataCollator, and VLMSFTTrainer instead of the STT-specific classes.

See the VLM page for details, or examples 47 (ASR) and 48 (Audio Understanding).

from mlx_tune import FastVisionModel

model, processor = FastVisionModel.from_pretrained("mlx-community/gemma-4-e4b-it-4bit")
model = FastVisionModel.get_peft_model(model,
    finetune_vision_layers=False, finetune_language_layers=True,
    finetune_audio_layers=False,  # True for acoustic domain adaptation
    r=16, lora_alpha=16)

# Dataset: {"type": "audio", "audio": "path.wav"} in messages
# Inference: model.generate(audio="path.wav", prompt="Transcribe.")

Text-to-Speech

mlx_tune.tts

TTS models in mlx-tune are decoder-only language models that predict discrete audio tokens. The API is the same regardless of model — FastTTSModel auto-detects the architecture and codec. Each model uses a different audio codec (SNAC, DAC, BiCodec, or Mimi), but this is handled transparently.
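For intuition, multi-codebook flattening can be sketched as giving each codebook a disjoint ID range so all codes fit in one token stream. The codebook count and offsets below are illustrative only, not the actual SNAC/DAC/BiCodec/Mimi layouts:

```python
# Illustrative only: flatten multi-codebook codec frames into one token
# stream by giving each codebook a disjoint ID range. The offsets here
# are hypothetical, not the real SNAC/DAC layout.
CODEBOOK_SIZE = 4096

def flatten_codes(frames):
    """frames: list of per-frame tuples, one codec ID per codebook."""
    flat = []
    for frame in frames:
        for cb, code in enumerate(frame):
            flat.append(cb * CODEBOOK_SIZE + code)  # disjoint range per codebook
    return flat

def unflatten_codes(flat, n_codebooks=3):
    """Invert flatten_codes: recover per-frame codec ID tuples."""
    frames = []
    for i in range(0, len(flat), n_codebooks):
        frame = tuple(tok - cb * CODEBOOK_SIZE
                      for cb, tok in enumerate(flat[i:i + n_codebooks]))
        frames.append(frame)
    return frames

frames = [(12, 7, 901), (3, 4095, 0)]
flat = flatten_codes(frames)
assert unflatten_codes(flat) == frames  # lossless roundtrip
```

Because the ranges are disjoint, a single decoder-only LM can predict all codebooks as ordinary next-token prediction.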

Orpheus-3B

Llama-based model using the SNAC codec (3 codebooks, 24 kHz). The original and largest supported TTS model.

from mlx_tune import FastTTSModel, TTSSFTTrainer, TTSSFTConfig, TTSDataCollator
from datasets import load_dataset, Audio

model, tokenizer = FastTTSModel.from_pretrained(
    "mlx-community/orpheus-3b-0.1-ft-bf16",
)
model = FastTTSModel.get_peft_model(model, r=16, lora_alpha=16)

dataset = load_dataset("MrDragonFox/Elise", split="train[:100]")
dataset = dataset.cast_column("audio", Audio(sampling_rate=24000))

collator = TTSDataCollator(model, tokenizer)
trainer = TTSSFTTrainer(
    model=model, tokenizer=tokenizer,
    data_collator=collator, train_dataset=dataset,
    args=TTSSFTConfig(output_dir="./tts_output", max_steps=60),
)
trainer.train()

OuteTTS-1B

Llama-based model using the DAC codec (2 codebooks, 24 kHz). Uses text-based audio tokens (<|c1_X|>, <|c2_X|>) instead of numeric offsets. Smaller and faster than Orpheus.

from mlx_tune import FastTTSModel, TTSSFTTrainer, TTSSFTConfig, TTSDataCollator
from datasets import load_dataset, Audio

model, tokenizer = FastTTSModel.from_pretrained(
    "mlx-community/Llama-OuteTTS-1.0-1B-8bit",
)
model = FastTTSModel.get_peft_model(model, r=16, lora_alpha=16)

dataset = load_dataset("MrDragonFox/Elise", split="train[:100]")
dataset = dataset.cast_column("audio", Audio(sampling_rate=24000))

collator = TTSDataCollator(model, tokenizer)
trainer = TTSSFTTrainer(
    model=model, tokenizer=tokenizer,
    data_collator=collator, train_dataset=dataset,
    args=TTSSFTConfig(output_dir="./outetts_output", max_steps=60),
)
trainer.train()
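The text-based token format can be illustrated by rendering codec IDs as token strings. The <|c1_X|>/<|c2_X|> pattern comes from the description above; the helper itself is a hypothetical sketch, not OuteTTS's actual formatting code:

```python
def codes_to_text_tokens(frames):
    """Render 2-codebook codec frames as OuteTTS-style text tokens.
    frames: list of (c1, c2) codec IDs per frame. Illustrative sketch only."""
    parts = []
    for c1, c2 in frames:
        parts.append(f"<|c1_{c1}|><|c2_{c2}|>")
    return "".join(parts)

s = codes_to_text_tokens([(17, 402), (901, 3)])
assert s == "<|c1_17|><|c2_402|><|c1_901|><|c2_3|>"
```

Since these strings are ordinary vocabulary entries, the text tokenizer handles them directly, with no numeric-offset bookkeeping.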

Spark-TTS (0.5B)

Qwen2-based model using BiCodec (global + semantic tokens, 16 kHz). Ultra-light at only 0.5B parameters — ideal for Apple Silicon with limited RAM.

from mlx_tune import FastTTSModel, TTSSFTTrainer, TTSSFTConfig, TTSDataCollator
from datasets import load_dataset, Audio

model, tokenizer = FastTTSModel.from_pretrained(
    "mlx-community/Spark-TTS-0.5B-bf16",
)
model = FastTTSModel.get_peft_model(model, r=16, lora_alpha=16)

# Spark uses 16kHz (not 24kHz like Orpheus/OuteTTS)
dataset = load_dataset("MrDragonFox/Elise", split="train[:100]")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

collator = TTSDataCollator(model, tokenizer)
trainer = TTSSFTTrainer(
    model=model, tokenizer=tokenizer,
    data_collator=collator, train_dataset=dataset,
    args=TTSSFTConfig(
        output_dir="./spark_output", max_steps=60,
        sample_rate=16000,
    ),
)
trainer.train()

Qwen3-TTS (1.7B)

Alibaba's multilingual TTS model (ZH, EN, JA, KO, and more) with voice design. A 28-layer talker transformer predicts discrete audio codes from a built-in 16-codebook speech tokenizer running at 12.5 Hz. Its dual-embedding design sums the text and codec token embeddings at each position.

from mlx_tune import FastTTSModel, TTSSFTTrainer, TTSSFTConfig, TTSDataCollator
from datasets import load_dataset, Audio

model, tokenizer = FastTTSModel.from_pretrained(
    "mlx-community/Qwen3-TTS-12Hz-1.7B-VoiceDesign-bf16",
)
model = FastTTSModel.get_peft_model(model, r=16, lora_alpha=16)

dataset = load_dataset("MrDragonFox/Elise", split="train[:100]")
dataset = dataset.cast_column("audio", Audio(sampling_rate=24000))

collator = TTSDataCollator(model, tokenizer)
trainer = TTSSFTTrainer(
    model=model, tokenizer=tokenizer,
    data_collator=collator, train_dataset=dataset,
    args=TTSSFTConfig(output_dir="./qwen3_tts_output", max_steps=60),
)
trainer.train()

FastTTSModel API

Main entry point for loading and configuring TTS models. Auto-detects model architecture, codec, and token format from the model name.

FastTTSModel.from_pretrained()

FastTTSModel.from_pretrained(model_name, max_seq_length=2048, codec_model=None, load_in_4bit=False, ...) → Tuple[TTSModelWrapper, Tokenizer]

Load a pretrained TTS model. Auto-detects the model profile and loads the appropriate codec.

| Parameter | Type | Description |
| --- | --- | --- |
| model_name | str | HuggingFace model ID or local path. Supported: Orpheus, OuteTTS, Spark-TTS, Sesame/CSM, Qwen3-TTS |
| max_seq_length | int | Maximum sequence length for training/inference |
| codec_model | str, optional | Override codec model ID. Auto-detected from profile if not specified |
| load_in_4bit | bool | Load model with 4-bit quantization |

FastTTSModel.get_peft_model()

FastTTSModel.get_peft_model(model, r=16, target_modules=None, lora_alpha=16, ...) → TTSModelWrapper

Add LoRA adapters to the TTS model for parameter-efficient fine-tuning.

| Parameter | Type | Description |
| --- | --- | --- |
| model | TTSModelWrapper | Model from from_pretrained() |
| r | int | LoRA rank (higher = more parameters) |
| target_modules | list[str], optional | Modules to apply LoRA to. Default: ["q_proj", "k_proj", "v_proj", "o_proj"] |
| lora_alpha | int | LoRA scaling factor. Recommended: equal to r |

FastTTSModel.for_training(model) — Enable training mode
FastTTSModel.for_inference(model) — Enable inference mode

FastTTSModel.convert()

FastTTSModel.convert(hf_model, output_dir="mlx_model", quantize=False, q_bits=4, dtype=None, upload_repo=None)

Convert a HuggingFace TTS model to MLX format.

TTSModelWrapper Methods

model.encode_audio(audio_array, sr=None) → List[int] — Encode audio waveform to codec token IDs
model.decode_audio(token_ids) → np.ndarray — Decode codec token IDs back to audio waveform
model.generate(text, speaker=None, max_tokens=1250, ...) → np.ndarray — Generate audio from text
model.save_pretrained(output_dir) — Save LoRA adapters
model.load_adapter(adapter_path) — Load saved adapters
model.save_pretrained_merged(output_dir, tokenizer=None, push_to_hub=False, repo_id=None) — Save full merged model
model.push_to_hub(repo_id) — Push adapters to HuggingFace Hub

TTSSFTTrainer

TTSSFTTrainer(model, tokenizer, data_collator, train_dataset, args=None)

TTS fine-tuning trainer. Batch size is forced to 1 (audio sequences have variable lengths).

| Parameter | Type | Description |
| --- | --- | --- |
| model | TTSModelWrapper | Model with LoRA adapters configured |
| tokenizer | Tokenizer | Tokenizer from from_pretrained() |
| data_collator | TTSDataCollator | Collator for encoding audio |
| train_dataset | Dataset | Dataset with audio and text columns |
| args | TTSSFTConfig | Training configuration |

TTSSFTConfig

TTSSFTConfig(per_device_train_batch_size=1, gradient_accumulation_steps=4, max_steps=60, learning_rate=2e-4, output_dir="./tts_outputs", sample_rate=24000, train_on_completions=True, ...)

TTS-specific training configuration. Extends SFTConfig with audio parameters.

| Parameter | Default | Description |
| --- | --- | --- |
| per_device_train_batch_size | 1 | Forced to 1 for audio training |
| gradient_accumulation_steps | 4 | Number of gradient accumulation steps |
| max_steps | 60 | Total training steps |
| learning_rate | 2e-4 | Peak learning rate |
| sample_rate | 24000 | Audio sample rate in Hz |
| train_on_completions | True | Train only on audio tokens (not the text prompt) |

TTSDataCollator

TTSDataCollator(model, tokenizer, text_column="text", audio_column="audio")

Data collator for TTS training. Encodes audio via the appropriate codec (SNAC, DAC, BiCodec, or Mimi), tokenizes text prompts, builds concatenated sequences, and masks prompt tokens so loss is computed only on audio outputs. Handles both numeric and text-based audio token formats automatically.
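The masking step can be pictured as follows. This is a simplified stand-in for TTSDataCollator's internals; the -100 ignore index is an assumed convention borrowed from common trainers, not confirmed mlx-tune API:

```python
IGNORE_INDEX = -100  # assumed ignore value; the real collator may differ

def build_labels(prompt_ids, audio_ids):
    """Concatenate prompt + audio tokens and mask the prompt in the labels,
    so loss is computed only on the audio span. Illustrative sketch only."""
    input_ids = list(prompt_ids) + list(audio_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(audio_ids)
    return input_ids, labels

inp, lab = build_labels([101, 7, 42], [9000, 9001, 9002])
# Loss is taken only where labels != IGNORE_INDEX, i.e. the audio tokens.
```

This is the mechanism behind train_on_completions=True: the text prompt contributes context but no gradient.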

Speech-to-Text

mlx_tune.stt

STT models in mlx-tune span several architecture families. The classic layout is encoder-decoder: the encoder processes audio input (mel spectrograms or raw waveforms) and the decoder generates text tokens autoregressively, with LoRA adapters applied to both encoder and decoder attention blocks. Audio-LLM, streaming, and transducer variants are covered in their own sections below. The API auto-detects the model type.

Whisper / Distil-Whisper

Whisper processes 80-bin log-mel spectrograms padded to 30 seconds. All Whisper sizes (tiny through large-v3) and Distil-Whisper variants are supported. LoRA targets: query, key, value, out in self-attention and cross-attention.

from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio

model, processor = FastSTTModel.from_pretrained(
    "mlx-community/whisper-tiny-asr-fp16",
)
model = FastSTTModel.get_peft_model(
    model, r=16, finetune_encoder=True, finetune_decoder=True,
)

dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
                       split="train[:100]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

collator = STTDataCollator(model, processor, language="en", task="transcribe")
trainer = STTSFTTrainer(
    model=model, processor=processor,
    data_collator=collator, train_dataset=dataset,
    args=STTSFTConfig(output_dir="./stt_output", max_steps=100),
)
trainer.train()

Distil-Whisper

Distil-Whisper variants auto-detect as the whisper profile and work with zero configuration changes. Just swap the model name:

FastSTTModel.from_pretrained("mlx-community/distil-whisper-large-v3")

Moonshine

Moonshine from Useful Sensors is an efficient STT model designed for edge devices. It processes raw audio waveforms directly through a convolutional frontend (no mel spectrogram needed). Very fast inference. LoRA targets: q_proj, k_proj, v_proj, o_proj.

from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio

model, processor = FastSTTModel.from_pretrained(
    "UsefulSensors/moonshine-tiny",
)
model = FastSTTModel.get_peft_model(
    model, r=16, lora_alpha=16,
    # Moonshine uses different attention names than Whisper
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    finetune_encoder=True, finetune_decoder=True,
)

dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
                       split="train[:100]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

collator = STTDataCollator(model, processor, language="en", task="transcribe")
trainer = STTSFTTrainer(
    model=model, processor=processor,
    data_collator=collator, train_dataset=dataset,
    args=STTSFTConfig(output_dir="./moonshine_output", max_steps=100),
)
trainer.train()

Qwen3-ASR

Alibaba's Qwen3-ASR is an audio-LLM that injects audio features into a Qwen3 language model. Supports 30+ languages. The audio encoder has 24 layers and the Qwen3 decoder has 28 layers. LoRA targets both encoder and decoder.

from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio

model, processor = FastSTTModel.from_pretrained(
    "mlx-community/Qwen3-ASR-1.7B-8bit",
)
model = FastSTTModel.get_peft_model(
    model, r=16, lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    finetune_encoder=True, finetune_decoder=True,
)

dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
                       split="train[:100]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

collator = STTDataCollator(model, processor, language="en", task="transcribe")
trainer = STTSFTTrainer(
    model=model, processor=processor,
    data_collator=collator, train_dataset=dataset,
    args=STTSFTConfig(output_dir="./qwen3_asr_output", max_steps=100),
)
trainer.train()

NVIDIA Canary

NVIDIA's Canary model uses a FastConformer encoder and Transformer decoder with cross-attention. Supports 25+ languages and speech translation. Encoder uses linear_q/k/v/linear_out (not q_proj), decoder uses q_proj/k_proj/v_proj/out_proj.

from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio

model, processor = FastSTTModel.from_pretrained(
    "eelcor/canary-1b-v2-mlx",
)
model = FastSTTModel.get_peft_model(
    model, r=16, lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
    finetune_encoder=True, finetune_decoder=True,
)

dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
                       split="train[:100]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

collator = STTDataCollator(model, processor, language="en", task="transcribe")
trainer = STTSFTTrainer(
    model=model, processor=processor,
    data_collator=collator, train_dataset=dataset,
    args=STTSFTConfig(output_dir="./canary_output", max_steps=100),
)
trainer.train()

Voxtral

Mistral's Voxtral combines an audio encoder (32 layers) with a Llama language model decoder (30 layers). It processes mel spectrograms and merges audio embeddings into the LLM's token sequence. LoRA targets: encoder q_proj/k_proj/v_proj/out_proj, decoder q_proj/k_proj/v_proj/o_proj.

from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio

model, processor = FastSTTModel.from_pretrained(
    "mlx-community/Voxtral-Mini-3B-2507-bf16",
)
model = FastSTTModel.get_peft_model(
    model, r=16, lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    finetune_encoder=True, finetune_decoder=True,
)

dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
                       split="train[:100]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

collator = STTDataCollator(model, processor, language="en", task="transcribe")
trainer = STTSFTTrainer(
    model=model, processor=processor,
    data_collator=collator, train_dataset=dataset,
    args=STTSFTConfig(output_dir="./voxtral_output", max_steps=100),
)
trainer.train()

Voxtral Realtime (new in v0.4.21)

Mistral's streaming ASR model has a fundamentally different architecture from regular Voxtral. A causal audio encoder (32 layers, sliding window 750) feeds through an adapter MLP into a 26-layer LLM decoder with AdaRMSNorm time conditioning and additive embedding fusion: the input at each position is adapter_out[pos] + decoder.embed_token(token), with no audio placeholder tokens.
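A toy NumPy sketch of that fusion rule (all sizes below are made up for illustration, not the real model's dimensions):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, vocab, T = 8, 32, 5                   # toy sizes, illustrative only

embed = rng.normal(size=(vocab, d_model))      # decoder token embedding table
adapter_out = rng.normal(size=(T, d_model))    # per-position audio features

tokens = np.array([3, 7, 7, 1, 0])             # decoder token IDs

# Voxtral Realtime-style additive fusion: no audio placeholder tokens; the
# decoder input at each position is the adapter output plus the embedding.
decoder_in = adapter_out + embed[tokens]
assert decoder_in.shape == (T, d_model)
```

Contrast this with regular Voxtral, where audio embeddings are spliced into the token sequence at placeholder positions rather than summed element-wise.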

This is the third architecture type in mlx-tune's STT subsystem (alongside encoder_decoder and audio_llm). LoRA targets use Mistral-internal naming: wq/wk/wv/wo plus feed_forward_w1/w2/w3, not q_proj/gate_proj.

Requires mistral-common

mlx-audio's TekkenTokenizer is decode-only. Fine-tuning needs the encode side from mistral-common, which is included in the [audio] extras since v0.4.21. Install with uv pip install 'mlx-tune[audio]' --upgrade.

New language adaptation

The Tekken tokenizer is byte-level BPE (131K vocab), so it encodes any UTF-8 text with a lossless roundtrip, and the audio encoder processes mel spectrograms language-agnostically. Fine-tuning on a new language therefore requires zero tokenizer changes: just swap the dataset. Verified end-to-end with Turkish (FLEURS), loss 10.1→8.8 in 10 steps. Works with Common Voice, FLEURS, VoxPopuli, or any HF dataset with an audio + text column.

from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio

# LibriSpeech (as in example 49); any HF dataset with audio + text columns works
dataset = load_dataset("openslr/librispeech_asr", "clean", split="train.100[:100]")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

# All quantizations work via the same code path (LoRA on QuantizedLinear too)
model, processor = FastSTTModel.from_pretrained(
    "mlx-community/Voxtral-Mini-4B-Realtime-2602-4bit",  # ~2.5 GB
    # or: "...-2602-fp16"  (~8 GB, highest fidelity)
    # or: "...-Realtime-6bit"
)
model = FastSTTModel.get_peft_model(
    model,
    r=8, lora_alpha=16,
    target_modules=["wq", "wk", "wv", "wo"],  # Mistral-internal naming
    finetune_encoder=False,  # causal encoder stays frozen
    finetune_decoder=True,
)

collator = STTDataCollator(model, processor, language="en", task="transcribe")
trainer = STTSFTTrainer(
    model=model, processor=processor,
    data_collator=collator, train_dataset=dataset,
    args=STTSFTConfig(
        per_device_train_batch_size=1,  # required for Voxtral RT
        max_steps=20,
        learning_rate=2e-5,  # low LR critical — AdaRMSNorm is sensitive
        output_dir="./voxtral_realtime_output",
    ),
)
trainer.train()

NVIDIA Parakeet TDT (new in v0.4.22)

NVIDIA's state-of-the-art FastConformer + Token-and-Duration Transducer ASR. The fourth architecture type in mlx-tune's STT subsystem alongside encoder_decoder, audio_llm, and voxtral_realtime.

Parakeet is unusual: the decoder is a 2-layer LSTM prediction network (not LoRA-friendly), the joint network outputs [B, T, U+1, V+1+D] where D=5 duration bins, and training needs transducer loss. mlx-tune solves this by:

  • Mounting a new CTC head on the Conformer encoder, warm-started from the pretrained joint network's output projection.
  • Three loss functions in pure MLX: ctc_loss (fastest), rnnt_loss (Graves 2012 forward algorithm), tdt_loss (extends RNN-T with duration prediction).
  • LoRA on the Conformer encoder's self-attention (linear_q/k/v/linear_out/linear_pos) — note the Conformer-internal naming, not q_proj.
  • Automatic vocabulary extension via model.extend_vocabulary() for any Unicode language. Two strategies: "char" (per-character auto extension, zero config) or "bpe" (retrain SentencePiece on your corpus, merge via aggregate tokenizer).
  • BatchNorm freeze in the Conformer conv modules to prevent running-stat drift during fine-tuning.
  • Streaming preserved: chunked streaming inference works for both the native TDT path (model.generate(audio, stream=True)) on the original training languages, and the new CTC head (model.stream_transcribe_ctc()) for fine-tuned and new-language models.

Language coverage

Parakeet v3's SentencePiece vocabulary covers Latin, Cyrillic, and Greek scripts well (~7,900 pieces across 25 European languages). For these scripts you can fine-tune directly with zero tokenizer changes. For non-Latin scripts (Arabic, Bengali, Devanagari, CJK, Korean, Thai, Hebrew...), call model.extend_vocabulary(texts, strategy="char") or "bpe" and mlx-tune will automatically resize the CTC head, joint output projection, and decoder embedding in lockstep.
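The "char" strategy can be pictured as a scan for characters the current vocabulary cannot encode. This is a simplified stand-in for what model.extend_vocabulary does before resizing the heads; only the min_count parameter mirrors the documented API:

```python
from collections import Counter

def find_new_chars(texts, known_vocab, min_count=1):
    """Count characters absent from the vocabulary and return those seen
    at least min_count times. Simplified sketch of the "char" strategy;
    the real method also resizes the CTC head, joint, and embedding."""
    counts = Counter(ch for text in texts
                     for ch in text if ch not in known_vocab)
    return sorted(ch for ch, n in counts.items() if n >= min_count)

vocab = set("abcdefghijklmnopqrstuvwxyz ")
new = find_new_chars(["merhaba dünya", "çok iyi"], vocab)
# The returned characters would be appended as new tokens before training.
```

In the real API the returned tokens trigger a lockstep resize of the CTC head, joint output projection, and decoder embedding.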

from mlx_tune import FastSTTModel, STTSFTTrainer, STTSFTConfig, STTDataCollator
from datasets import load_dataset, Audio

dataset = load_dataset("mozilla-foundation/common_voice_16_0", "en",
                       split="train[:100]", trust_remote_code=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

# Load the 0.6B FastConformer TDT model
model, processor = FastSTTModel.from_pretrained(
    "mlx-community/parakeet-tdt-0.6b-v3",
    # For non-training-set languages (Welsh, Bengali, Arabic, etc.), pass
    # warm_start_ctc_head=False to start from a fresh random head — the
    # warm-start has an English-biased prior that hurts new-language convergence.
)

# For non-Latin languages: auto-extend the vocabulary BEFORE applying LoRA.
# This scans `texts` for characters the pretrained SP tokenizer can't encode,
# adds them as new tokens, and resizes the CTC head + joint output + decoder
# embedding. Works for any Unicode language out of the box.
#
# added = model.extend_vocabulary(
#     all_transcripts, strategy="char", min_count=1
# )
# # Or BPE retraining for better compression:
# added = model.extend_vocabulary(
#     all_transcripts, strategy="bpe", bpe_vocab_size=500
# )

# LoRA on the Conformer self-attention (default: q/k/v/out + linear_pos)
model = FastSTTModel.get_peft_model(model, r=16, lora_alpha=32)

collator = STTDataCollator(model, processor)
trainer = STTSFTTrainer(
    model=model, processor=processor,
    data_collator=collator, train_dataset=dataset,
    args=STTSFTConfig(
        per_device_train_batch_size=1,
        max_steps=150,
        learning_rate=3e-4,
        loss_type="ctc",  # or "rnnt" / "tdt" / "hybrid"
        output_dir="./parakeet_output",
    ),
)
trainer.train()

# Transcribe via the new CTC head
text = model.transcribe_ctc("audio.wav")

# The original TDT decoder still works for languages Parakeet was trained on
text = model.transcribe_tdt("audio.wav")

# Streaming CTC via chunked inference
for partial in model.stream_transcribe_ctc("long_audio.wav"):
    print(partial)

| Loss type | Forward pass | Notes |
| --- | --- | --- |
| ctc | encoder → new CTC head | Fastest, simplest. Joint network not touched. |
| rnnt | encoder → joint(enc, dec), slice [:V+1] | Standard RNN-T (Graves 2012). Uses existing joint + LSTM. |
| tdt | encoder → joint(enc, dec), full output | Matches NVIDIA's native TDT recipe with duration regularization. |
| hybrid | CTC + TDT weighted sum | Matches NVIDIA's Stage 2 recipe (ctc_weight=0.3, tdt_weight=1.0). |
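mlx-tune's ctc_loss is implemented in MLX; purely as an illustration of the recursion it computes, here is a minimal NumPy version of the CTC forward algorithm (unbatched and unvectorized, for clarity only):

```python
import numpy as np

def ctc_loss(log_probs, target, blank=0):
    """Negative log-likelihood of `target` under CTC.
    log_probs: [T, V] log-softmax frames; target: label IDs without blanks.
    Illustrative NumPy sketch, not mlx-tune's MLX implementation."""
    ext = [blank]
    for t in target:
        ext += [t, blank]                      # interleave blanks: length 2U+1
    S, T = len(ext), log_probs.shape[0]
    alpha = np.full((T, S), -np.inf)           # forward log-probabilities
    alpha[0, 0] = log_probs[0, blank]
    alpha[0, 1] = log_probs[0, ext[1]]
    for t in range(1, T):
        for s in range(S):
            cands = [alpha[t - 1, s]]          # stay on the same state
            if s > 0:
                cands.append(alpha[t - 1, s - 1])   # advance one state
            if s > 1 and ext[s] != blank and ext[s] != ext[s - 2]:
                cands.append(alpha[t - 1, s - 2])   # skip a blank
            alpha[t, s] = np.logaddexp.reduce(cands) + log_probs[t, ext[s]]
    # Valid alignments end on the last label or the trailing blank.
    return -np.logaddexp(alpha[T - 1, S - 1], alpha[T - 1, S - 2])
```

With a single frame and target [2], the loss reduces to -log p(2), which is a quick sanity check for any CTC implementation.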

FastSTTModel API

Main entry point for loading and configuring STT models. Auto-detects model architecture and preprocessor from the model name.

FastSTTModel.from_pretrained()

FastSTTModel.from_pretrained(model_name, max_seq_length=448, ...) → Tuple[STTModelWrapper, STTProcessor]

Load a pretrained STT model. Returns model wrapper and processor.

| Parameter | Type | Description |
| --- | --- | --- |
| model_name | str | HuggingFace model ID or local path. Supported: Whisper, Distil-Whisper, Moonshine, Qwen3-ASR, Canary, Voxtral, Voxtral Realtime, Parakeet TDT |
| max_seq_length | int | Maximum decoder sequence length |

FastSTTModel.get_peft_model()

FastSTTModel.get_peft_model(model, r=16, finetune_encoder=True, finetune_decoder=True, lora_alpha=16, ...) → STTModelWrapper

Add LoRA adapters to encoder and/or decoder attention layers.

| Parameter | Type | Description |
| --- | --- | --- |
| model | STTModelWrapper | Model from from_pretrained() |
| r | int | LoRA rank |
| finetune_encoder | bool | Apply LoRA to encoder attention blocks |
| finetune_decoder | bool | Apply LoRA to decoder attention blocks |
| lora_alpha | int | LoRA scaling factor. Recommended: equal to r |

FastSTTModel.for_training(model) — Enable training mode
FastSTTModel.for_inference(model) — Enable inference mode

FastSTTModel.convert()

FastSTTModel.convert(hf_model, output_dir="mlx_model", quantize=False, q_bits=4, dtype=None, upload_repo=None)

Convert a HuggingFace Whisper model to MLX format. Uses mlx_audio.convert with model_domain="stt".

STTModelWrapper Methods

model.transcribe(audio, language=None, ...) → str — Transcribe audio to text
model.save_pretrained(output_dir) — Save LoRA adapters
model.load_adapter(adapter_path) — Load saved adapters
model.save_pretrained_merged(output_dir, processor=None, push_to_hub=False, repo_id=None) — Save full merged model
model.push_to_hub(repo_id) — Push adapters to HuggingFace Hub

STTSFTTrainer

STTSFTTrainer(model, processor, data_collator, train_dataset, args=None)

STT fine-tuning trainer. Uses sequence-to-sequence training with encoder-decoder architecture. Batch size is forced to 1.

| Parameter | Type | Description |
| --- | --- | --- |
| model | STTModelWrapper | Model with LoRA adapters configured |
| processor | STTProcessor | Processor from from_pretrained() |
| data_collator | STTDataCollator | Collator for computing mel spectrograms |
| train_dataset | Dataset | Dataset with audio and transcript columns |
| args | STTSFTConfig | Training configuration |

STTSFTConfig

STTSFTConfig(per_device_train_batch_size=1, gradient_accumulation_steps=4, max_steps=100, learning_rate=1e-5, output_dir="./stt_outputs", sample_rate=16000, language="en", ...)

STT-specific training configuration. Extends SFTConfig with audio parameters.

| Parameter | Default | Description |
| --- | --- | --- |
| per_device_train_batch_size | 1 | Forced to 1 for audio training |
| gradient_accumulation_steps | 4 | Number of gradient accumulation steps |
| max_steps | 100 | Total training steps |
| learning_rate | 1e-5 | Peak learning rate (lower than TTS) |
| sample_rate | 16000 | Audio sample rate in Hz (Whisper uses 16 kHz) |
| language | "en" | Target language for transcription |

STTDataCollator

STTDataCollator(model, processor, language="en", task="transcribe", audio_column="audio", text_column="text")

Data collator for STT training. For Whisper models, computes 80-bin log-mel spectrograms padded to 30 seconds. For Moonshine, passes raw audio through the conv frontend. Tokenizes transcription targets with language-specific SOT tokens.
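The fixed 30-second window can be sketched as follows. This is a stand-in for the collator's internal padding step, assuming Whisper's usual 16 kHz × 30 s = 480,000-sample input before the mel transform:

```python
import numpy as np

def pad_or_trim(audio, sr=16000, seconds=30):
    """Zero-pad or trim a waveform to exactly `seconds` of audio,
    mirroring Whisper's fixed 30-second input window. Sketch only."""
    n = sr * seconds
    if len(audio) >= n:
        return audio[:n]
    return np.pad(audio, (0, n - len(audio)))

clip = np.ones(16000)                      # 1 second of audio at 16 kHz
assert pad_or_trim(clip).shape == (480000,)
```

Moonshine skips this step entirely: its conv frontend consumes the raw variable-length waveform.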

STTProcessor

STTProcessor(model, tokenizer=None, hf_processor=None)

Processor wrapper for STT models. Provides mel spectrogram computation and text tokenization.

processor.compute_mel(audio) → mx.array — Compute mel spectrogram from audio
processor.tokenize(text) → List[int] — Tokenize transcription text

Post-Training

Save, merge, convert, and share your fine-tuned audio models. These operations work the same for both TTS and STT.

Save adapters

# Save LoRA adapters only (small, shareable)
model.save_pretrained("./adapters")

Merge LoRA into base model

# Fuse LoRA weights into the base model and save
model.save_pretrained_merged("./merged_model")

Convert HuggingFace model to MLX

# Convert and optionally quantize a HF model
FastTTSModel.convert(
    "canopylabs/orpheus-3b",
    quantize=True,
    q_bits=4,
)

Push to HuggingFace Hub

model.push_to_hub(
    "username/my-model",
    token="hf_...",
    private=True,
)

Examples

12 — Orpheus TTS Fine-Tuning

Fine-tune Orpheus-3B for text-to-speech with SNAC audio codec. Includes data loading, training, and inference.


View on GitHub →

13 — Whisper STT Fine-Tuning

Fine-tune Whisper for speech-to-text transcription. Encoder-decoder LoRA with mel spectrogram inputs.


View on GitHub →

14 — OuteTTS Fine-Tuning

Fine-tune OuteTTS-1B (Llama-based, DAC codec) for voice cloning. Text-based audio tokens, same API as Orpheus.


View on GitHub →

15 — Spark-TTS Fine-Tuning

Fine-tune Spark-TTS (0.5B, Qwen2-based) with BiCodec. Ultra-light TTS ideal for limited RAM.


View on GitHub →

16 — Moonshine STT Fine-Tuning

Fine-tune Moonshine for speech-to-text. Raw conv frontend (no mel spectrogram), fast edge-device STT.


View on GitHub →

17 — Qwen3-ASR Fine-Tuning

Fine-tune Alibaba's Qwen3-ASR (audio-LLM). Audio features injected into Qwen3 LLM for multilingual ASR.


View on GitHub →

18 — Canary STT Fine-Tuning

Fine-tune NVIDIA Canary (FastConformer + Transformer decoder). Multilingual ASR with 25+ languages.


View on GitHub →

19 — Voxtral STT Fine-Tuning

Fine-tune Mistral's Voxtral (audio encoder + Llama decoder). 3B audio-LLM for speech recognition.


View on GitHub →

49 — Voxtral Realtime Streaming STT

Fine-tune Mistral's streaming Voxtral Realtime on LibriSpeech. Causal encoder + AdaRMSNorm decoder + additive embedding fusion. Works on all quantization variants (4bit/6bit/fp16).


View on GitHub →

Tips

Batch Size

Audio training forces batch_size=1. Audio sequences have variable lengths, so batching is not supported. Use gradient_accumulation_steps to simulate larger effective batch sizes.
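The accumulation trick can be sketched in NumPy: average the per-sample gradients over the accumulation window, then take one optimizer step. Plain SGD is used here for illustration; the trainer's actual optimizer and averaging details may differ:

```python
import numpy as np

def accumulated_update(param, grads, lr=0.1):
    """One SGD step from k per-sample gradients (effective batch = k).
    Illustrative sketch, not the trainer's internals."""
    g = np.mean(grads, axis=0)        # average over the accumulation window
    return param - lr * g

w = np.array([1.0, 2.0])
micro_grads = [np.array([0.4, 0.0]), np.array([0.0, 0.4]),
               np.array([0.4, 0.0]), np.array([0.0, 0.4])]
w = accumulated_update(w, micro_grads)   # same step a batch of 4 would take
```

So batch_size=1 with gradient_accumulation_steps=4 costs four forward/backward passes per update but matches the gradient statistics of a batch of 4.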

Auto-Detection

mlx-tune auto-detects the model type from the model name. You don't need to specify profiles, codecs, or token formats manually — just pass the model ID to from_pretrained().
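Conceptually the detection is a substring match on the model ID. The table and precedence below are invented for illustration and are not mlx-tune's actual profile registry:

```python
# Hypothetical name-based profile detection; illustrative only.
PROFILES = [
    ("orpheus", "tts/orpheus"),
    ("outetts", "tts/outetts"),
    ("spark", "tts/spark"),
    ("whisper", "stt/whisper"),   # also matches Distil-Whisper variants
    ("moonshine", "stt/moonshine"),
    ("parakeet", "stt/parakeet"),
]

def detect_profile(model_name):
    """Return the first profile whose key appears in the model ID."""
    name = model_name.lower()
    for key, profile in PROFILES:
        if key in name:
            return profile
    raise ValueError(f"no profile matches {model_name!r}")

assert detect_profile("mlx-community/whisper-tiny-asr-fp16") == "stt/whisper"
```

This is also why Distil-Whisper needs no configuration: its model ID contains "whisper", so it resolves to the same profile.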

Sample Rates

Different models use different sample rates. Always cast your dataset audio column to match:

  • 24 kHz: Orpheus, OuteTTS, Sesame/CSM, Qwen3-TTS
  • 16 kHz: Spark-TTS, Whisper, Distil-Whisper, Moonshine, Qwen3-ASR, Canary, Voxtral

dataset.cast_column("audio", Audio(sampling_rate=24000))

Recommended Models

Use MLX-format models from mlx-community on HuggingFace:

  • Orpheus: mlx-community/orpheus-3b-0.1-ft-bf16
  • OuteTTS: mlx-community/Llama-OuteTTS-1.0-1B-8bit
  • Spark: mlx-community/Spark-TTS-0.5B-bf16
  • Whisper: mlx-community/whisper-*-asr-fp16
  • Distil-Whisper: mlx-community/distil-whisper-large-v3
  • Moonshine: UsefulSensors/moonshine-tiny
  • Qwen3-ASR: mlx-community/Qwen3-ASR-1.7B-8bit
  • Canary: eelcor/canary-1b-v2-mlx
  • Voxtral: mlx-community/Voxtral-Mini-3B-2507-bf16

Memory

16 GB+ unified RAM recommended. Approximate requirements:

  • Spark-TTS (0.5B): ~4 GB
  • OuteTTS (1B, 8-bit): ~4 GB
  • Whisper-tiny / Moonshine-tiny: ~2 GB
  • Orpheus-3B (4-bit): ~8 GB
  • Distil-Whisper-large-v3: ~6 GB
  • Qwen3-ASR (1.7B, 8-bit): ~4 GB
  • Canary (1B): ~4 GB
  • Voxtral (3B, bf16): ~12 GB

Response-Only Training

TTS training uses response-only training by default (train_on_completions=True). Loss is computed only on audio tokens, not on the text prompt — this matches how the model is used at inference time.