Multistral 🚀
Multistral is a flexible multimodal small language model that combines text, vision, and audio processing in a single architecture. Built on top of proven components, Ministral-3 (text), Pixtral (vision), and Voxtral (audio), it currently ships a single preset: a 4B text model (roughly 5B parameters including the multimodal encoders).
🌟 Features
- 🔤 Text Processing: Advanced language understanding and generation based on Ministral-3
- 👁️ Vision Understanding: Image analysis and visual reasoning with Pixtral integration
- 🎵 Audio Processing: Speech and audio understanding through Voxtral encoder
- 🧠 Dual Architecture: Choose between dense (all parameters active) and MoE (only a subset of parameters active per token) variants
- 📦 Flexible Presets: Pre-configured model sizes (currently 4B)
- 💬 Harmony Chat Format: Built-in support for chat datasets with selective token masking
- 🔧 Training Utilities: SafeSFT trainer and data collators for efficient fine-tuning
- ⚡ Optimized Performance: Support for bfloat16, CUDA, and memory-efficient inference
🚀 Quick Start
Installation
# Basic installation
pip install multistral
# With all optional dependencies
pip install multistral[all]
# Development installation
git clone https://github.com/your-username/multistral.git
cd multistral
pip install -e ".[dev]"
Initializing a Model
from multistral import (
    MultistralConfig,
    MultistralForConditionalGeneration,
    MultistralTokenizer,
)
from multistral.configs import get_preset
# Get a preset configuration (dense_4B)
preset = get_preset("dense_4B")
# Create configuration from preset
config = MultistralConfig(
    text_config=preset["text"],
    vision_config=preset["vision"],
    audio_config=preset["audio"],
    multimodal_config=preset["multimodal"],
)
# Initialize model with random weights
model = MultistralForConditionalGeneration(config)
tokenizer = MultistralTokenizer.from_pretrained("models/multistral-tokenizer")
print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
Text Generation
# Simple text generation
inputs = tokenizer("Hello, how are you?", return_tensors="pt")
outputs = model.generate(**inputs, max_length=50)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
Multimodal Usage (Text + Vision)
from PIL import Image
# Load image
image = Image.open("image.jpg")
text_prompt = "What is in this image?"
# Process with tokenizer and model
inputs = tokenizer(text_prompt, return_tensors="pt")
# Vision processing is integrated into the model forward pass
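# (how `image` is wired into generate depends on the package's processor API; not shown here)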
outputs = model.generate(**inputs, max_length=100)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
📖 Documentation
Model Architecture
Multistral combines three modalities with a unified transformer architecture:
Text Component (Ministral-3)
- Available Presets:
- dense_4B: 56 layers, 2560 hidden size (~4.0B parameters)
- Common Features:
- Vocabulary size: 128,000 tokens
- Context length: 131,072 tokens (with YaRN scaling; see the sketch after this list)
- Sliding window: 32,768 tokens
- 20 attention heads
- RoPE positional encoding with YaRN scaling
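As a rough illustration of how YaRN-style context extension is typically configured (the rope_scaling key and its fields follow the common Hugging Face convention and are assumptions here, not Multistral's confirmed API):
# Hypothetical: YaRN scaling stretching a 32,768-token base window to 131,072 tokens
# (field names borrowed from the Hugging Face convention; Multistral's keys may differ)
text_config = {
    "max_position_embeddings": 131072,
    "rope_scaling": {
        "rope_type": "yarn",
        "factor": 4.0,  # 32,768 * 4 = 131,072
    },
}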
Vision Component (Pixtral)
- Unified across all presets
- Architecture:
- Hidden size: 1024 dimensions
- 24 transformer layers
- 16 attention heads
- Image resolution: 1024×1024
- Patch size: 16×16
- ~0.4B parameters
Audio Component (Voxtral)
- Unified across all presets
- Architecture:
- Hidden size: 1024 dimensions
- 24 transformer layers
- 16 attention heads
- Input: 128 mel-frequency bins
- Max sequence length: 1500 frames
- ~0.4B parameters
Integration
- Multimodal projectors bridge vision/audio features to the text embedding space (see the sketch below)
- Efficient feature extraction and alignment
- Support for any combination of modalities
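A minimal sketch of what such a projector can look like, assuming a simple two-layer MLP; the ModalityProjector class below is illustrative, not Multistral's actual implementation:
import torch
import torch.nn as nn

class ModalityProjector(nn.Module):
    # Maps 1024-d vision/audio encoder features into the 2560-d text embedding space
    def __init__(self, encoder_dim: int = 1024, text_dim: int = 2560):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(encoder_dim, text_dim),
            nn.GELU(),
            nn.Linear(text_dim, text_dim),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # features: (batch, num_patches_or_frames, encoder_dim)
        return self.proj(features)

projected = ModalityProjector()(torch.randn(1, 64, 1024))  # -> (1, 64, 2560)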
Configuration
You can either use pre-configured presets or create custom configurations:
Using Presets (Recommended):
from multistral import MultistralConfig
from multistral.configs import get_preset
# Dense 4B preset
dense_4b = get_preset("dense_4B")
# Create config from preset
config = MultistralConfig(
    text_config=dense_4b["text"],
    vision_config=dense_4b["vision"],
    audio_config=dense_4b["audio"],
    multimodal_config=dense_4b["multimodal"],
)
Custom Configuration:
from multistral import MultistralConfig
config = MultistralConfig(
    text_config={
        "vocab_size": 128000,
        "hidden_size": 2560,
        "num_hidden_layers": 56,
        "num_attention_heads": 20,
        "intermediate_size": 6400,
    },
    vision_config={
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "image_size": 1024,
        "patch_size": 16,
    },
    audio_config={
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "num_mel_bins": 128,
    },
)
Special Tokens and Chat Format
Multistral includes specialized tokens for multimodal tasks and structured chat interactions:
from multistral import SpecialTokens
# Multimodal tokens
print(SpecialTokens.IMAGE.value) # <|image|>
print(SpecialTokens.AUDIO.value) # <|audio|>
print(SpecialTokens.VIDEO.value) # <|video|>
# Conversation structure
print(SpecialTokens.BEGIN.value) # <|begin|>
print(SpecialTokens.END.value) # <|end|>
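These tokens can be spliced directly into prompt strings to mark where non-text inputs belong. A hypothetical example (the exact placement conventions are an assumption; use the Harmony format below for structured chats):
# Hypothetical prompt layout; actual placement rules may differ
prompt = (
    f"{SpecialTokens.BEGIN.value}"
    f"Describe this picture: {SpecialTokens.IMAGE.value}"
    f"{SpecialTokens.END.value}"
)
print(prompt)  # <|begin|>Describe this picture: <|image|><|end|>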
Harmony Chat Format
Multistral uses the Harmony format for structured training on conversation datasets. This format allows selective token masking to control which tokens contribute to the loss:
from multistral import DataCollatorHarmony, DataCollatorHarmonyConfig, parse_for_training
# Configure Harmony data collator
collator_config = DataCollatorHarmonyConfig(
    ignore_index=-100,  # Don't compute loss on these tokens
    mode="prompt_response",  # or "full_conversation"
)
collator = DataCollatorHarmony(collator_config)
# Use with trainer
from transformers import Trainer
trainer = Trainer(
    model=model,
    data_collator=collator,
    # ... other trainer arguments
)
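To make the masking concrete, here is a minimal sketch of the idea behind ignore_index=-100 (illustrative only; the collator's internals may differ). PyTorch's cross-entropy loss skips label positions equal to -100, so masking the prompt means only response tokens contribute to the loss:
import torch

input_ids = torch.tensor([[11, 12, 13, 14, 21, 22, 23]])  # 4 prompt tokens + 3 response tokens
labels = input_ids.clone()
labels[0, :4] = -100  # prompt positions are ignored by the loss
# In "prompt_response" mode, only the response positions now produce gradients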
For more details, see Harmony Documentation.
🔧 Advanced Usage
Memory Optimization
import torch
from multistral import MultistralConfig, MultistralForConditionalGeneration
from multistral.configs import get_preset
# Load with memory optimization
preset = get_preset("dense_4B")
config = MultistralConfig(
    text_config=preset["text"],
    vision_config=preset["vision"],
    audio_config=preset["audio"],
    multimodal_config=preset["multimodal"],
)
model = MultistralForConditionalGeneration(config)
# Use bfloat16 and memory-efficient loading
model = model.to(dtype=torch.bfloat16)
# Enable gradient checkpointing for training
model.gradient_checkpointing_enable()
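For inference, standard PyTorch practice applies on top of this; nothing Multistral-specific is assumed in the following (it presumes a CUDA device and a Hugging Face-style tokenizer output):
# Move to GPU and disable autograd bookkeeping during generation
model = model.to("cuda").eval()
inputs = {k: v.to("cuda") for k, v in tokenizer("Hello!", return_tensors="pt").items()}
with torch.inference_mode():
    outputs = model.generate(**inputs, max_length=50)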
Parameter Freezing
from multistral import freeze_embeddings, freeze_vision_audio, freeze_bottom_layers
# Freeze vision and audio encoders, train only text
freeze_vision_audio(model)
# Freeze bottom layers for efficient fine-tuning
freeze_bottom_layers(model, num_layers=12)
# Check trainable parameters
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,}")
Training with SafeSFTTrainer
from multistral import SafeSFTTrainer, SafeSFTConfig
from transformers import TrainingArguments
config = SafeSFTConfig(
    dataset_text_field="text",
    remove_unused_columns=False,
)
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    learning_rate=1e-5,
    save_strategy="steps",
    save_steps=100,
    bf16=True,
)
trainer = SafeSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    sft_config=config,
)
trainer.train()
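Afterwards, checkpointing follows the usual Hugging Face pattern (save_model is standard Trainer API; save_pretrained on the custom tokenizer is an assumption):
# Persist the fine-tuned weights and tokenizer for later reloading
trainer.save_model("./output/final")
tokenizer.save_pretrained("./output/final")  # assumes an HF-style tokenizer API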
📦 Package Structure
Core Modules
- multistral.MultistralConfig: Configuration classes
- multistral.MultistralForConditionalGeneration: Model classes
- multistral.MultistralTokenizer: Custom tokenizer with special multimodal tokens
- multistral.SpecialTokens: Enumeration of all special tokens
Training & Data
- multistral.SafeSFTTrainer: Efficient supervised fine-tuning trainer
- multistral.DataCollatorHarmony: Data collator for the Harmony chat format
- multistral.parse_for_training: Parser for conversation datasets
Utilities
- multistral.freeze_embeddings: Freeze embedding layers
- multistral.freeze_vision_audio: Freeze vision/audio encoders
- multistral.freeze_bottom_layers: Freeze the bottom N transformer layers
- multistral.load_streaming_datasets: Load and stream large datasets
- multistral.StreamingCollator: Data collator for streaming datasets
🎯 Model Specifications
Dense Models (All parameters always active)
| Preset | Layers | Hidden | Heads | Text Params | Memory (bf16) |
|---|---|---|---|---|---|
| dense_4B | 56 | 2560 | 20 | ~4.0B | ~8.0GB |
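The memory figure is raw parameter storage: 4.0B parameters × 2 bytes per bf16 weight ≈ 8.0 GB, before activations, KV cache, or optimizer state.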
Shared Components
| Component | Dimension | Layers | Heads | Parameters |
|---|---|---|---|---|
| Vision (Pixtral) | 1024 | 24 | 16 | ~0.4B |
| Audio (Voxtral) | 1024 | 24 | 16 | ~0.4B |
Note: Vision and audio encoders are identical across all presets and do not count toward the "size" label.
Development Setup
# Clone repository
git clone https://github.com/your-username/multistral.git
cd multistral
# Install in development mode
pip install -e ".[dev]"
# Run tests
pytest tests/
# Run type checking
mypy multistral/
# Format code
black multistral/ tests/
📄 License
This project is licensed under the ISC License - see the LICENSE file for details.
🙏 Acknowledgments
- Mistral AI for the Ministral-3 text architecture
- Mistral AI for the Pixtral vision model
- Mistral AI for the Voxtral audio encoder
- Hugging Face for the Transformers library
- PyTorch team for the deep learning framework