jhu-clsp/mmBERT-base


mmBERT: A Modern Multilingual Encoder

License: MIT | Paper | Model Collection | GitHub

TL;DR: A state-of-the-art multilingual encoder trained on 3T+ tokens across 1800+ languages, introducing novel techniques for learning low-resource languages during the decay phase.

mmBERT is a modern multilingual encoder that significantly outperforms previous generation models like XLM-R on classification, embedding, and retrieval tasks. Built on the ModernBERT architecture with novel multilingual training innovations, mmBERT demonstrates that low-resource languages can be effectively learned during the decay phase of training. It is also significantly faster than any previous multilingual encoder.

Quick Start

Installation

pip install "torch>=1.9.0"
pip install "transformers>=4.48.0"  # ModernBERT architecture support

Usage

from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("jhu-clsp/mmBERT-base")
model = AutoModel.from_pretrained("jhu-clsp/mmBERT-base")

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs)
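# outputs.last_hidden_state contains the contextual token embeddings,
# with shape [batch_size, sequence_length, hidden_size] (768 for mmBERT-base)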

Model Description

mmBERT represents the first significant advancement over XLM-R for massively multilingual encoder models. Key features include:

  1. Massive Language Coverage - Trained on over 1800 languages with progressive inclusion strategy
  2. Modern Architecture - Built on ModernBERT foundation with Flash Attention 2 and unpadding techniques
  3. Novel Training Recipe - Introduces inverse mask scheduling and temperature sampling
  4. Open Training Data - Complete 3T+ token dataset publicly available
  5. Decay Phase Innovation - Demonstrates effective learning of low-resource languages in final training phase

The model uses bidirectional attention with masked language modeling objectives, optimized specifically for multilingual understanding and cross-lingual transfer.

Novel Training Innovations

Progressive Language Addition: Start with 60 high-resource languages, expand to 110 languages during mid-training, then include all 1833 languages in the decay phase.

Inverse Mask Schedule: Reduce mask ratio from 30% → 15% → 5% across training phases for progressively refined learning.

Inverse Temperature Sampling: Adjust multilingual sampling from a high-resource bias (τ=0.7) toward near-uniform sampling (τ=0.3).

Model Merging: Combine English-focused, high-resource, and all-language decay variants using TIES merging.
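
As a rough illustration of the temperature sampling and inverse mask schedule described above: each language's token share is raised to the power τ and renormalized, so lowering τ across phases flattens the sampling distribution toward uniform, while the mask ratio drops by phase. The numbers and function name below are illustrative, not the exact training configuration.

import numpy as np

def language_sampling_probs(token_counts, tau):
    # Sample languages proportional to (token share) ** tau:
    # tau = 1.0 keeps the raw data distribution (high-resource bias),
    # smaller tau flattens it toward uniform.
    shares = np.asarray(token_counts, dtype=np.float64)
    shares = shares / shares.sum()
    weights = shares ** tau
    return weights / weights.sum()

# Illustrative token counts for three languages of very different sizes
counts = [1_000_000_000, 50_000_000, 1_000_000]
print(language_sampling_probs(counts, tau=0.7))  # earlier phases: closer to the data distribution
print(language_sampling_probs(counts, tau=0.3))  # decay phase: much closer to uniform

# Inverse mask schedule: MLM mask ratio decreases across training phases
mask_ratio_by_phase = {"pre-training": 0.30, "mid-training": 0.15, "decay": 0.05}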

Model Family

| Model | Total Params | Non-embed Params | Languages | Download |
|---|---|---|---|---|
| mmBERT-small | 140M | 42M | 1800+ | [jhu-clsp/mmBERT-small](https://huggingface.co/jhu-clsp/mmBERT-small) |
| mmBERT-base | 307M | 110M | 1800+ | [jhu-clsp/mmBERT-base](https://huggingface.co/jhu-clsp/mmBERT-base) |

Training Data

mmBERT training data is publicly available across different phases:

| Phase | Dataset | Tokens | Description |
|---|---|---|---|
| Pre-training P1 | mmbert-pretrain-p1 | 2.3T | 60 languages, foundational training |
| Pre-training P2 | mmbert-pretrain-p2 | - | Extension data for pre-training phase |
| Pre-training P3 | mmbert-pretrain-p3 | - | Final pre-training data |
| Mid-training | mmbert-midtraining | 600B | 110 languages, context extension to 8K |
| Decay Phase | mmbert-decay | 100B | 1833 languages, premium quality |

Data Sources: Filtered DCLM (English), FineWeb2 (multilingual), FineWeb2-HQ (20 high-resource languages), Wikipedia (MegaWika), code repositories (StarCoder, ProLong), academic papers (ArXiv, PeS2o), and community discussions (StackExchange).

Model Architecture

| Parameter | mmBERT-small | mmBERT-base |
|---|---|---|
| Layers | 22 | 22 |
| Hidden Size | 384 | 768 |
| Intermediate Size | 1152 | 1152 |
| Attention Heads | 6 | 12 |
| Total Parameters | 140M | 307M |
| Non-embedding Parameters | 42M | 110M |
| Max Sequence Length | 8192 | 8192 |
| Vocabulary Size | 256,000 | 256,000 |
| Tokenizer | Gemma 2 | Gemma 2 |
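
These values can be read directly from the released configuration. The field names below follow the Hugging Face ModernBERT config class; the printed values are expected to match the mmBERT-base column above.

from transformers import AutoConfig

config = AutoConfig.from_pretrained("jhu-clsp/mmBERT-base")
print(config.num_hidden_layers)        # expected: 22
print(config.hidden_size)              # expected: 768
print(config.num_attention_heads)      # expected: 12
print(config.max_position_embeddings)  # expected: 8192
print(config.vocab_size)               # expected: ~256,000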

Usage Examples

Masked Language Modeling

from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

tokenizer = AutoTokenizer.from_pretrained("jhu-clsp/mmBERT-base")
model = AutoModelForMaskedLM.from_pretrained("jhu-clsp/mmBERT-base")

def predict_masked_token(text):
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    
    mask_indices = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)
    predictions = outputs.logits[mask_indices]
    top_tokens = torch.topk(predictions, 5, dim=-1)
    
    return [tokenizer.decode(token) for token in top_tokens.indices[0]]

# Works across languages; using tokenizer.mask_token keeps the example
# correct regardless of the exact mask string in the vocabulary
texts = [
    f"The capital of France is {tokenizer.mask_token}.",
    f"La capital de España es {tokenizer.mask_token}.",
    f"Die Hauptstadt von Deutschland ist {tokenizer.mask_token}."
]

for text in texts:
    predictions = predict_masked_token(text)
    print(f"Text: {text}")
    print(f"Predictions: {predictions}")

Cross-lingual Embeddings

from transformers import AutoTokenizer, AutoModel
import torch
from sklearn.metrics.pairwise import cosine_similarity

tokenizer = AutoTokenizer.from_pretrained("jhu-clsp/mmBERT-base")
model = AutoModel.from_pretrained("jhu-clsp/mmBERT-base")

def get_embeddings(texts):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
        # Mean-pool over real tokens only, excluding padding positions
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        embeddings = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    return embeddings.numpy()

multilingual_texts = [
    "Artificial intelligence is transforming technology",
    "La inteligencia artificial está transformando la tecnología",
    "L'intelligence artificielle transforme la technologie", 
    "人工智能正在改变技术"
]

embeddings = get_embeddings(multilingual_texts)
similarities = cosine_similarity(embeddings)
print("Cross-lingual similarity matrix:")
print(similarities)
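
Mean pooling over the raw encoder outputs already yields sensible cross-lingual similarities, but the base model is trained with a masked language modeling objective rather than a contrastive one; for retrieval-quality embeddings, fine-tune it as in the dense retrieval example below.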

Fine-tuning Examples

Dense Retrieval with Sentence Transformers

import argparse
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=8e-5)
    parser.add_argument("--model_name", type=str, default="jhu-clsp/mmBERT-base")
    args = parser.parse_args()
    
    lr = args.lr
    model_name = args.model_name
    model_shortname = model_name.split("/")[-1]

    model = SentenceTransformer(model_name)

    dataset = load_dataset(
        "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
        "triplet-hard",
        split="train",
    )
    dataset_dict = dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"].select(range(1_250_000))
    eval_dataset = dataset_dict["test"]

    loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=16)
    run_name = f"{model_shortname}-DPR-{lr}"
    
    training_args = SentenceTransformerTrainingArguments(
        output_dir=f"output/{model_shortname}/{run_name}",
        num_train_epochs=1,
        per_device_train_batch_size=512,
        per_device_eval_batch_size=512,
        warmup_ratio=0.05,
        fp16=False,
        bf16=True,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        learning_rate=lr,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=2,
        logging_steps=500,
        run_name=run_name,
    )

    dev_evaluator = TripletEvaluator(
        anchors=eval_dataset["query"],
        positives=eval_dataset["positive"],
        negatives=eval_dataset["negative"],
        name="msmarco-co-condenser-dev",
    )
    dev_evaluator(model)

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=dev_evaluator,
    )
    trainer.train()

    model.save_pretrained(f"output/{model_shortname}/{run_name}/final")
    model.push_to_hub(run_name, private=False)

if __name__ == "__main__":
    main()

Cross-lingual Classification

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification, 
    TrainingArguments,
    Trainer
)
from datasets import load_dataset
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return {
        'accuracy': accuracy_score(labels, predictions),
        'f1': f1_score(labels, predictions, average='weighted')
    }

def main():
    model_name = "jhu-clsp/mmBERT-base"
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=3
    )
    
    dataset = load_dataset("xnli", "all_languages")
    
    def tokenize_function(examples):
        texts = [f"{p} {tokenizer.sep_token} {h}" 
                for p, h in zip(examples["premise"], examples["hypothesis"])]
        
        return tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=512
        )
    
    train_dataset = dataset["train"].map(tokenize_function, batched=True)
    eval_dataset = dataset["validation"].map(tokenize_function, batched=True)
    
    training_args = TrainingArguments(
        output_dir="./mmbert-xnli",
        learning_rate=3e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=3,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
    )
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,  # enables dynamic padding via DataCollatorWithPadding
        compute_metrics=compute_metrics,
    )
    
    trainer.train()
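
    # Zero-shot cross-lingual transfer: evaluate the English-trained model on another
    # XNLI language (French here) without any additional training
    eval_fr = load_dataset("xnli", "fr")["validation"].map(tokenize_function, batched=True)
    print(trainer.evaluate(eval_dataset=eval_fr))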

if __name__ == "__main__":
    main()

Multilingual Reranking

import logging
from datasets import load_dataset
from sentence_transformers.cross_encoder import (
    CrossEncoder,
    CrossEncoderModelCardData,
    CrossEncoderTrainer,
    CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
from sentence_transformers.util import mine_hard_negatives
from sentence_transformers import SentenceTransformer
import torch

def main():
    model_name = "jhu-clsp/mmBERT-base"
    train_batch_size = 32
    num_epochs = 2
    num_hard_negatives = 7

    model = CrossEncoder(
        model_name,
        model_card_data=CrossEncoderModelCardData(
            language="multilingual",
            license="mit",
        ),
    )
    
    full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(50_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=42)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]

    embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", device="cpu")
    hard_train_dataset = mine_hard_negatives(
        train_dataset,
        embedding_model,
        num_negatives=num_hard_negatives,
        margin=0,
        range_min=0,
        range_max=100,
        sampling_strategy="top",
        batch_size=2048,
        output_format="labeled-pair",
        use_faiss=True,
    )

    loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))

    nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=train_batch_size,
    )

    args = CrossEncoderTrainingArguments(
        output_dir="./mmbert-reranker",
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,
        bf16=True,
        dataloader_num_workers=4,
        load_best_model_at_end=True,
        metric_for_best_model="eval_msmarco_ndcg@10",
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        logging_steps=200,
        seed=42,
    )

    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=hard_train_dataset,
        loss=loss,
        evaluator=nano_beir_evaluator,
    )
    trainer.train()

    model.save_pretrained("./mmbert-reranker/final")

if __name__ == "__main__":
    main()

Training Data

mmBERT was trained on a carefully curated 3T+ token multilingual dataset:

| Phase | Tokens | Description |
|---|---|---|
| Pre-training P1 | 2.3T | 60 languages, diverse data mixture |
| Pre-training P2 | - | Extension data for pre-training |
| Pre-training P3 | - | Final pre-training data |
| Mid-training | 600B | 110 languages, context extension |
| Decay Phase | 100B | 1833 languages, premium quality |

Primary Sources:

  • Filtered DCLM: High-quality English content
  • FineWeb2: Broad multilingual web coverage (1800+ languages)
  • FineWeb2-HQ: Filtered subset of 20 high-resource languages
  • Code: StarCoder and ProLong repositories
  • Academic: ArXiv papers and PeS2o scientific content
  • Reference: Wikipedia (MegaWika) and textbooks
  • Community: StackExchange discussions
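
The phase datasets in the table above are released on the Hugging Face Hub. A minimal streaming example is sketched below; the repository ID is an assumption based on the model's namespace, so check the collection linked at the top for the exact names.

from datasets import load_dataset

# "jhu-clsp/mmbert-decay" is an assumed repository ID; verify against the mmBERT collection
decay_data = load_dataset("jhu-clsp/mmbert-decay", split="train", streaming=True)
print(next(iter(decay_data)))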

Citation

If you use mmBERT in your research, please cite our work:

@misc{marone2025mmbertmodernmultilingualencoder,
      title={mmBERT: A Modern Multilingual Encoder with Annealed Language Learning}, 
      author={Marc Marone and Orion Weller and William Fleshman and Eugene Yang and Dawn Lawrie and Benjamin Van Durme},
      year={2025},
      eprint={2509.06888},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2509.06888}, 
}

"""

DEPLOY IN 60 SECONDS

Run mmBERT-base on Runcrate

Deploy on H100, A100, or RTX GPUs. Pay only for what you use. No setup required.