
Multistral 🚀

Multistral is a flexible multimodal small language model that seamlessly combines text, vision, and audio processing capabilities. Built on top of proven architectures including Ministral-3 (text), Pixtral (vision), and Voxtral (audio), Multistral provides two variants: Multistral (HF-based) and Multistral2 (fully-owned implementation), both with 4B text parameters (~5B multimodal).

🌟 Features

  • 🔤 Text Processing: Advanced language understanding and generation based on Ministral-3
  • 👁️ Vision Understanding: Image analysis and visual reasoning with Pixtral integration
  • 🎵 Audio Processing: Speech and audio understanding through Voxtral encoder
  • 🧠 Dual Architecture: Choose between Multistral (HF-based) and Multistral2 (fully-owned) implementations
  • 📦 Flexible Presets: Pre-configured 4B presets (dense_4B, multistral2_dense_4B), one per implementation variant
  • 💬 Harmony Chat Format: Built-in support for chat datasets with selective token masking
  • 🔧 Training Utilities: SafeSFT trainer and data collators for efficient fine-tuning
  • ⚡ Optimized Performance: Support for bfloat16, CUDA, and memory-efficient inference

🚀 Quick Start

Installation

# Basic installation
pip install multistral

# With all optional dependencies
pip install "multistral[all]"

# Development installation
git clone https://git.flety.net/damien/multistral.git
cd multistral
pip install -e ".[dev]"
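
A quick sanity check from Python (whether the package exposes a __version__ attribute is an assumption, hence the getattr fallback):

import multistral

print(getattr(multistral, "__version__", "no __version__ attribute exposed"))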

Initializing a Model

from multistral import (
    MultistralConfig,
    MultistralForConditionalGeneration,
    MultistralTokenizer,
)
from multistral.configs import get_preset

# Get a preset configuration (dense_4B)
preset = get_preset("dense_4B")

# Create configuration from preset
config = MultistralConfig(
    text_config=preset["text"],
    vision_config=preset["vision"],
    audio_config=preset["audio"],
    multimodal_config=preset["multimodal"],
)

# Initialize model with random weights
model = MultistralForConditionalGeneration(config)
tokenizer = MultistralTokenizer()

print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters")

Text Generation

# Simple text generation
inputs = tokenizer("Hello, how are you?", return_tensors="pt")
outputs = model.generate(**inputs, max_length=50)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

Multimodal Usage (Text + Vision)

from PIL import Image

from multistral import SpecialTokens

# Load the image and mark its position in the prompt with the <|image|> token
image = Image.open("image.jpg")
text_prompt = f"{SpecialTokens.IMAGE.value} What is in this image?"

# Tokenize the prompt; vision features are injected at the <|image|> position
# inside the model forward pass (the exact keyword used to pass the image
# tensor depends on the preprocessing pipeline and is omitted here)
inputs = tokenizer(text_prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_length=100)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

📖 Documentation

Model Architecture

Multistral combines three modalities with a unified transformer architecture:

Text Component (Ministral-3)

  • Available Presets:
    • dense_4B: 56 layers, 2560 hidden size (~4.0B parameters)
    • multistral2_dense_4B: 56 layers, 2560 hidden size with GQA and sliding window
  • Common Features:
    • Vocabulary size: 128,000 tokens
    • Context length: 131,072 tokens (with YaRN scaling)
    • Sliding window: 32,768 tokens
    • 20 attention heads with 4 key/value heads (GQA; see the derived geometry after this list)
    • RoPE positional encoding with YaRN scaling
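
For reference, here is the attention geometry these numbers imply (derived arithmetic, not an official specification):

# Attention geometry implied by the dense_4B numbers above
hidden_size = 2560
num_attention_heads = 20
num_key_value_heads = 4

head_dim = hidden_size // num_attention_heads  # 128 dimensions per head
kv_width = num_key_value_heads * head_dim      # 512: KV projections are 5x narrower than Q
print(head_dim, kv_width)                      # 128 512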

Vision Component (Pixtral)

  • Unified across all presets
  • Architecture:
    • Hidden size: 1024 dimensions
    • 24 transformer layers
    • 16 attention heads
    • Image resolution: 1024×1024
    • Patch size: 16×16
    • ~0.4B parameters

Audio Component (Voxtral)

  • Unified across all presets
  • Architecture:
    • Hidden size: 1024 dimensions
    • 24 transformer layers
    • 16 attention heads
    • Input: 128 mel-frequency bins
    • Max sequence length: 1500 frames
    • ~0.4B parameters

Integration

  • Multimodal projectors bridge vision/audio features to the text embedding space (see the sketch after this list)
  • Efficient feature extraction and alignment
  • Support for any combination of modalities
  • Two model variants: Multistral (HF-based) and Multistral2 (fully-owned implementation)
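
The projector idea can be sketched in a few lines. This is an illustrative stand-in (a single linear layer, with names invented here), not Multistral's actual module:

import torch
import torch.nn as nn

# Maps encoder features (vision/audio, hidden size 1024) into the
# text embedding space (hidden size 2560 for the 4B presets)
class ModalityProjector(nn.Module):
    def __init__(self, encoder_dim: int = 1024, text_dim: int = 2560):
        super().__init__()
        self.proj = nn.Linear(encoder_dim, text_dim)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # features: (batch, seq_len, encoder_dim) -> (batch, seq_len, text_dim)
        return self.proj(features)

# Projected features are then spliced into the text embedding sequence at the
# positions of the <|image|> / <|audio|> placeholder tokens.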

Configuration

You can either use pre-configured presets or create custom configurations:

Using Presets (Recommended):

from multistral import MultistralConfig, Multistral2Config
from multistral.configs import get_preset

# Dense models
dense_4b = get_preset("dense_4B")

# Create config from preset
config = MultistralConfig(
    text_config=dense_4b["text"],
    vision_config=dense_4b["vision"],
    audio_config=dense_4b["audio"],
    multimodal_config=dense_4b["multimodal"],
)

# For Multistral2 variant
multistral2_preset = get_preset("multistral2_dense_4B")
config2 = Multistral2Config(
    text_config=multistral2_preset["text"],
    vision_config=multistral2_preset["vision"],
    audio_config=multistral2_preset["audio"],
    multimodal_config=multistral2_preset["multimodal"],
)

Custom Configuration:

from multistral import MultistralConfig

config = MultistralConfig(
    text_config={
        "vocab_size": 128000,
        "hidden_size": 2560,
        "num_hidden_layers": 56,
        "num_attention_heads": 20,
        "intermediate_size": 6400,
    },
    vision_config={
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "image_size": 1024,
        "patch_size": 16,
    },
    audio_config={
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "num_mel_bins": 128,
    },
)

Special Tokens and Chat Format

Multistral includes specialized tokens for multimodal tasks and structured chat interactions:

from multistral import SpecialTokens

# Multimodal tokens
print(SpecialTokens.IMAGE.value)      # <|image|>
print(SpecialTokens.AUDIO.value)      # <|audio|>
print(SpecialTokens.VIDEO.value)      # <|video|>

# Conversation structure
print(SpecialTokens.BEGIN.value)      # <|begin|>
print(SpecialTokens.END.value)        # <|end|>

Harmony Chat Format

Multistral uses the Harmony format for structured training on conversation datasets. This format allows selective token masking to control which tokens contribute to the loss:

from multistral import DataCollatorHarmony, DataCollatorHarmonyConfig, parse_for_training

# Configure Harmony data collator
collator_config = DataCollatorHarmonyConfig(
    ignore_index=-100,          # Don't compute loss on these tokens
    mode="prompt_response",      # or "full_conversation"
)

collator = DataCollatorHarmony(collator_config)

# Use with trainer
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=collator,
    # ... other trainer arguments
)
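
To make the selective masking concrete, here is a minimal illustration of what "prompt_response" masking means (made-up token ids, not the collator's actual output):

import torch

# Labels are a copy of input_ids with prompt positions set to ignore_index (-100),
# so cross-entropy is computed only on the response tokens
input_ids = torch.tensor([[101, 102, 103, 201, 202, 203]])  # 3 prompt + 3 response tokens
labels = input_ids.clone()
labels[0, :3] = -100  # mask the prompt; these positions are skipped by the loss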

For more details, see Harmony Documentation.

🔧 Advanced Usage

Memory Optimization

import torch
from multistral import MultistralConfig, MultistralForConditionalGeneration
from multistral.configs import get_preset

# Build the model from a preset
preset = get_preset("dense_4B")
config = MultistralConfig(
    text_config=preset["text"],
    vision_config=preset["vision"],
    audio_config=preset["audio"],
    multimodal_config=preset["multimodal"],
)

model = MultistralForConditionalGeneration(config)

# Use bfloat16 to halve weight memory relative to float32
model = model.to(dtype=torch.bfloat16)

# Enable gradient checkpointing to trade compute for activation memory during training
model.gradient_checkpointing_enable()
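
For inference, wrapping generation in torch.inference_mode() additionally skips autograd bookkeeping; this is standard PyTorch, reusing inputs from the text-generation example above:

# Generate without building an autograd graph
with torch.inference_mode():
    outputs = model.generate(**inputs, max_length=50)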

Parameter Freezing

from multistral import freeze_embeddings, freeze_vision_audio, freeze_bottom_layers

# Freeze vision and audio encoders, train only text
freeze_vision_audio(model)

# Freeze the bottom transformer layers for efficient fine-tuning
freeze_bottom_layers(model, num_layers=12)

# Optionally freeze the token embeddings as well
# freeze_embeddings(model)

# Check trainable parameters
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,}")

Training with SafeSFTTrainer

from multistral import SafeSFTTrainer, SafeSFTConfig
from transformers import TrainingArguments

config = SafeSFTConfig(
    dataset_text_field="text",
    remove_unused_columns=False,
)

training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    learning_rate=1e-5,
    save_strategy="steps",
    save_steps=100,
    bf16=True,
)

trainer = SafeSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    sft_config=config,
)

trainer.train()
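
Here, dataset can be any datasets.Dataset exposing the configured "text" field; a minimal illustrative example:

from datasets import Dataset

# Tiny dataset with the "text" field named in SafeSFTConfig above
dataset = Dataset.from_list([
    {"text": "Question: What is Multistral? Answer: A multimodal small language model."},
])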

📦 Package Structure

Core Modules

  • multistral.MultistralConfig: Configuration class for the dense multimodal model (HF-based)
  • multistral.MultistralForConditionalGeneration: Main dense model for conditional generation (HF-based)
  • multistral.Multistral2Config: Configuration class for the fully-owned implementation
  • multistral.Multistral2ForConditionalGeneration: Fully-owned model implementation
  • multistral.MultistralTokenizer: Custom tokenizer with special multimodal tokens
  • multistral.SpecialTokens: Enumeration of all special tokens

Training & Data

  • multistral.SafeSFTTrainer: Efficient supervised fine-tuning trainer
  • multistral.DataCollatorHarmony: Data collator for Harmony chat format
  • multistral.parse_for_training: Parser for conversation datasets

Utilities

  • multistral.freeze_embeddings: Freeze embedding layers
  • multistral.freeze_vision_audio: Freeze vision/audio encoders
  • multistral.freeze_bottom_layers: Freeze bottom N transformer layers
  • multistral.load_streaming_datasets: Load and stream large datasets
  • multistral.StreamingCollator: Data collator for streaming datasets

🎯 Model Specifications

Dense Models (All parameters always active)

Preset                 Layers  Hidden  Heads      Text Params  Memory (bf16)
dense_4B               56      2560    20         ~4.0B        ~8.0GB
multistral2_dense_4B   56      2560    20 (4 KV)  ~4.0B        ~8.0GB
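
The memory column follows directly from the parameter count: bf16 stores 2 bytes per parameter, so the text weights alone come to about 8 GB (activations, the KV cache, and the ~0.8B vision/audio parameters come on top):

# Weight memory for the text stack in bf16: 2 bytes per parameter
text_params = 4.0e9
print(f"{text_params * 2 / 1e9:.1f} GB")  # 8.0 GB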

Shared Components

Component         Dimension  Layers  Heads  Parameters
Vision (Pixtral)  1024       24      16     ~0.4B
Audio (Voxtral)   1024       24      16     ~0.4B

Note: Vision and audio encoders are identical across all presets and do not count toward the "size" label.

🛠️ Development Setup

# Clone repository
git clone https://git.flety.net/damien/multistral.git
cd multistral

# Install in development mode
pip install -e ".[dev]"

# Run tests
pytest tests/

# Run type checking
mypy multistral/

# Format code
black multistral/ tests/

📄 License

This project is licensed under the ISC License - see the LICENSE file for details.

🙏 Acknowledgments

  • Mistral AI for the Ministral-3 text architecture
  • Mistral AI for the Pixtral vision model
  • Mistral AI for the Voxtral audio encoder
  • Hugging Face for the Transformers library
  • PyTorch team for the deep learning framework
  • Damien Fleuriot for the Multistral implementation and architecture