This repository provides the official implementation of CLD, a lightweight language-detection module for multilingual ASR. The codebase ships a pip-installable Python package (`cld/`) together with training/benchmark scripts implemented in JAX and optimized via ADMM for high performance in low-resource settings. In short, the package attaches a small language-detection head (convex NN, small NN, or linear SVM) to ASR encoder representations and uses it to select the language token (Whisper) or adapter (MMS) before decoding.
- High Accuracy: Excels in binary and multiclass language detection (Table 2).
- Low-Resource Robustness: Effective with limited data (Figures 1 & 2).
- Efficient: 13x training speedup over traditional NN training, thanks to ADMM optimization and JAX.
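For context on the optimizer, the sketch below is a minimal, self-contained ADMM loop for an L1-regularized least-squares problem in plain NumPy. It only illustrates the ADMM update pattern behind the speedup claim; it is not the repo's JAX implementation, and the names (`soft_threshold`, `admm_lasso`) are ours.

```python
import numpy as np

def soft_threshold(v, k):
    """Proximal operator of k * ||.||_1 (elementwise shrinkage)."""
    return np.sign(v) * np.maximum(np.abs(v) - k, 0.0)

def admm_lasso(A, b, beta=0.1, rho=1.0, iters=100):
    """Minimize 0.5*||Ax - b||^2 + beta*||x||_1 via the ADMM splitting x = z."""
    n = A.shape[1]
    x = np.zeros(n); z = np.zeros(n); u = np.zeros(n)  # u is the scaled dual
    # The x-update is a ridge system; factor once since rho is fixed.
    M = np.linalg.inv(A.T @ A + rho * np.eye(n))
    Atb = A.T @ b
    for _ in range(iters):
        x = M @ (Atb + rho * (z - u))          # quadratic subproblem
        z = soft_threshold(x + u, beta / rho)  # L1 proximal step
        u = u + x - z                          # dual update
    return z

rng = np.random.default_rng(0)
A = rng.standard_normal((40, 10))
x_true = np.zeros(10); x_true[:3] = [2.0, -1.5, 1.0]
b = A @ x_true
z = admm_lasso(A, b, beta=0.1, rho=1.0)
```

Each iteration costs one matrix-vector solve plus elementwise operations, which is why the pattern maps well onto JAX-compiled updates.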
This repo supports two common setups:

- Package-only install (inference usage):

  ```bash
  pip install -e .
  ```

- Full training/benchmark environment (recommended if you run the scripts in this repo):

  ```bash
  pip install -e ".[train]"
  ```

If you prefer installing from the pinned dependency list instead:

```bash
pip install -r requirements.txt
```

Quick inference example:

```python
import numpy as np

from cld import ASRModel, CVXNNLangDetectHead, NNLangDetectHead, SVMLangDetectHead

# 1) Load the base ASR model
languages = ["en", "hi", "id", "ms", "zh"]
asr = ASRModel.from_pretrained("openai/whisper-small", config={"languages": languages})

# 2) Load a language detection head artifact (choose ONE)
# head = CVXNNLangDetectHead.load("path/to/whisper-small_trained_cvx_mlp.pkl", asr)
# head = NNLangDetectHead.load("path/to/openai_whisper-small_nn_head.pkl", asr)
# head = SVMLangDetectHead.load("path/to/openai_whisper-small_linear_svm.pkl", asr)

# 3) Attach the head and run inference
asr.set_lang_detect_head(head)
audio_16k_mono: np.ndarray = ...  # shape (T,), sampling rate 16 kHz
pred_langs, pred_texts = asr.predict(audio_16k_mono)
print(pred_langs[0], pred_texts[0])
```

All training/evaluation scripts expect a Hugging Face `DatasetDict` saved to disk (loaded via `datasets.load_from_disk(...)`) with splits like `train`, `valid`, `test`. Use our `data_ingestion.py` script to prepare your data.
```bash
python data_ingestion.py \
  --config configs/en_hi_config.json \
  --out data/en_hi \
  --common-voice-dir /absolute/path/to/CommonVoice \
  --augment
```

- Required: `--config` JSON config (see example below), `--out` save directory.
- Optional: `--augment` enables audiomentations; `--musan-dir` for background noise; `--common-voice-dir` for local Common Voice.
- Output: a saved `DatasetDict` at `data/en_hi` with columns `audio`, `text`, `lang`, `accent`.
Minimal config example (see more in configs/):
```json
{
  "name": "English-Hindi example",
  "languages": {
    "en": {
      "accents": [
        { "code": "us", "column_name": "United States English", "dataset": "common_voice" }
      ]
    },
    "hi": {
      "accents": [
        { "code": "hi", "column_name": "", "dataset": "common_voice" }
      ]
    }
  },
  "params": {
    "samples_per_class": 1000,
    "split": { "train": 0.8, "val": 0.1, "test": 0.1 }
  }
}
```

Notes:
- Common Voice selection matches `column_name` against the `accents` column in `validated.tsv`. Use `override_code` to point to alternative folders (see `configs/final_config.json`).
- Lahaja examples match by `native_language` (e.g., `"Telugu"`, `"Konkani"`).
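A quick sanity check on a config file before ingestion can be done with the standard library alone. The field names below follow the minimal example above; the tolerance check on split fractions and the derived per-split sizes are our own convention, not something the scripts require.

```python
import json

config = json.loads("""
{
  "name": "English-Hindi example",
  "languages": {
    "en": {"accents": [{"code": "us", "column_name": "United States English", "dataset": "common_voice"}]},
    "hi": {"accents": [{"code": "hi", "column_name": "", "dataset": "common_voice"}]}
  },
  "params": {
    "samples_per_class": 1000,
    "split": {"train": 0.8, "val": 0.1, "test": 0.1}
  }
}
""")

langs = sorted(config["languages"])
splits = config["params"]["split"]
assert abs(sum(splits.values()) - 1.0) < 1e-9, "split fractions must sum to 1"
per_class = config["params"]["samples_per_class"]
# Expected number of examples per class in each split.
sizes = {k: int(round(v * per_class)) for k, v in splits.items()}
print(langs, sizes)
```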
All heads are trained on pooled encoder embeddings extracted by `ASRModel.load_data(...)` from a dataset on disk.
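For intuition, "pooled" here means reducing the encoder's time axis to a single vector per utterance. A common choice is masked mean pooling, sketched below; this is an illustration of the idea, not necessarily the exact pooling `ASRModel.load_data` performs.

```python
import numpy as np

def mean_pool(frames: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Average encoder frames (T, D) over valid timesteps given a 0/1 mask (T,)."""
    valid = mask[:, None]  # (T, 1), broadcast over the feature dim
    return (frames * valid).sum(axis=0) / valid.sum()

T, D = 6, 4
frames = np.arange(T * D, dtype=float).reshape(T, D)   # fake encoder output
mask = np.array([1, 1, 1, 1, 0, 0], dtype=float)       # last two frames are padding
emb = mean_pool(frames, mask)                           # (D,) utterance embedding
```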
```bash
python train_cvxnn.py \
  --model_name openai/whisper-small \
  --dataset_path data/multiclass \
  --languages en,hi,id,ms,zh \
  --output_dir models/lang_heads \
  --neuron 64 \
  --beta 0.001 \
  --rho 0.1 \
  --admm_iters 6
```

This produces a pickled artifact like:

`models/lang_heads/openai/whisper-small/openai_whisper-small_trained_cvx_mlp.pkl`
```bash
python train_nn.py \
  --dataset_path data/multiclass \
  --model_name openai/whisper-small \
  --languages en,hi,id,ms,zh \
  --output_dir models/lang_heads \
  --num_train_epochs 10 \
  --learning_rate 1e-3 \
  --per_device_train_batch_size 256
```

This produces a pickled artifact like:

`models/lang_heads/openai/whisper-small/openai_whisper-small_nn_head.pkl`
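Conceptually, the NN head is a classifier over pooled embeddings trained by gradient descent on cross-entropy. The toy full-batch softmax-regression loop below shows where `--learning_rate` and the epoch count enter; the data is synthetic and this is not the repo's actual head architecture.

```python
import numpy as np

rng = np.random.default_rng(0)
D, L, N = 8, 3, 90  # embedding dim, number of languages, number of examples
# Synthetic pooled embeddings: one well-separated Gaussian cluster per language.
centers = rng.normal(0, 3, (L, D))
y = np.repeat(np.arange(L), N // L)
X = centers[y] + rng.normal(0, 0.5, (N, D))

W = np.zeros((D, L)); b = np.zeros(L)
lr = 1e-1                      # cf. --learning_rate
for _ in range(200):           # cf. --num_train_epochs (full-batch here)
    logits = X @ W + b
    p = np.exp(logits - logits.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)
    g = p.copy(); g[np.arange(N), y] -= 1.0; g /= N  # dCE/dlogits
    W -= lr * (X.T @ g)
    b -= lr * g.sum(axis=0)

acc = (np.argmax(X @ W + b, axis=1) == y).mean()
```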
```bash
python train_linear_svm.py \
  --model_name openai/whisper-small \
  --data_dir data/multiclass \
  --languages en,hi,id,ms,zh \
  --output_dir models/lang_heads \
  --C 1.0 \
  --max_iter 5000
```

This produces a pickled artifact like:

`models/lang_heads/openai/whisper-small/openai_whisper-small_linear_svm.pkl`
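The linear-SVM head is conceptually a scikit-learn-style linear SVM fit on the pooled embeddings. The toy fit below shows how the `--C` and `--max_iter` flags map onto `LinearSVC` hyperparameters; the clusters are synthetic stand-ins for per-language embeddings, and we are not claiming the repo uses `LinearSVC` internally.

```python
import numpy as np
from sklearn.svm import LinearSVC

rng = np.random.default_rng(0)
# Two well-separated "language" clusters standing in for pooled embeddings.
X = np.vstack([rng.normal(-3, 0.5, (50, 8)), rng.normal(3, 0.5, (50, 8))])
y = np.array(["en"] * 50 + ["hi"] * 50)

clf = LinearSVC(C=1.0, max_iter=5000)  # mirrors --C 1.0 --max_iter 5000
clf.fit(X, y)
acc = clf.score(X, y)  # training accuracy on the separable toy data
```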
Use `train_whisper.py` to fine-tune a Whisper checkpoint on a preprocessed dataset directory:
```bash
python train_whisper.py \
  --data_dir data/multiclass \
  --model_id openai/whisper-small \
  --output_dir models/whisper-small-finetuned \
  --num_train_epochs 3 \
  --learning_rate 1e-5 \
  --per_device_train_batch_size 8 \
  --per_device_eval_batch_size 8 \
  --gradient_accumulation_steps 1 \
  --eval_strategy steps \
  --eval_steps 1000 \
  --save_steps 1000
```

Optional logging:
```bash
python train_whisper.py ... \
  --wandb_project CLD \
  --run_name whisper-small-finetune-final_dry
```

Use `benchmark_cld.py` to evaluate language detection and transcription quality (WER/CER) on the test split.
Convex NN head:

```bash
python benchmark_cld.py \
  --dataset_path data/multiclass \
  --model_name openai/whisper-small \
  --cld_type cvx \
  --cld_path models/lang_heads/openai/whisper-small/openai_whisper-small_trained_cvx_mlp.pkl \
  --languages en,hi,id,ms,zh \
  --batch_size 32 \
  --no_wandb
```

NN head:

```bash
python benchmark_cld.py \
  --dataset_path data/multiclass \
  --model_name openai/whisper-small \
  --cld_type nn \
  --cld_path models/lang_heads/openai/whisper-small/openai_whisper-small_nn_head.pkl \
  --languages en,hi,id,ms,zh \
  --batch_size 32 \
  --no_wandb
```

Linear SVM head:

```bash
python benchmark_cld.py \
  --dataset_path data/multiclass \
  --model_name openai/whisper-small \
  --cld_type linear_svm \
  --cld_path models/lang_heads/openai/whisper-small/openai_whisper-small_linear_svm.pkl \
  --languages en,hi,id,ms,zh \
  --batch_size 32 \
  --no_wandb
```

Vanilla baseline (no detection head):

```bash
python benchmark_cld.py \
  --dataset_path data/multiclass \
  --model_name openai/whisper-small \
  --cld_type vanilla \
  --languages en,hi,id,ms,zh \
  --batch_size 32 \
  --no_wandb
```

Paper results (Table 5):
To reproduce the evaluation numbers for a given head, run `benchmark_cld.py` as shown above.
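For reference, WER is the word-level Levenshtein distance between hypothesis and reference divided by the reference length (CER is the same computation over characters). `benchmark_cld.py` presumably uses a library implementation; the metric itself is just a small dynamic program:

```python
def wer(reference: str, hypothesis: str) -> float:
    """Word error rate: word-level edit distance / reference word count."""
    ref, hyp = reference.split(), hypothesis.split()
    # Rolling-row Levenshtein distance over word sequences.
    d = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev = d[0]
        d[0] = i
        for j, h in enumerate(hyp, 1):
            cur = d[j]
            d[j] = min(d[j] + 1,            # deletion
                       d[j - 1] + 1,        # insertion
                       prev + (r != h))     # substitution (or match)
            prev = cur
    return d[-1] / len(ref)
```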

