Matryoshka Representation Learning

Aditya Kusupati, Gantavya Bhatt, Aniket Rege, Matthew Wallingford, Aditya Sinha, Vivek Ramanujan, William Howard-Snyder, Kaifeng Chen, Sham Kakade, Prateek Jain, Ali Farhadi

The official repo with ResNet implementation can be found at

The paper “Matryoshka Representation Learning” (MRL) introduces an approach to representation learning that makes embedding vectors more efficient and more flexible to use. By encoding information at multiple granularities within a single vector, MRL allows substantial compression (embeddings up to 14 times smaller), which supports scalable, resource-efficient deployment across machine learning applications. Embeddings can be used adaptively according to computational constraints, so the same model can serve both large-scale server environments and resource-limited edge devices. The paper’s experiments also highlight MRL’s robustness, particularly in out-of-distribution scenarios and long-tail few-shot classification. OpenAI’s text-embedding-3-small model, which maintains performance at significantly reduced vector sizes, is an example of MRL adopted in practice.

The paper validates the method across multiple domains and datasets. Directions worth exploring further include dynamic tuning of the nesting dimensions and real-world deployment trials. It is recommended reading for anyone working on efficient and adaptive machine learning.

Mind Map

graph LR

root["Matryoshka Representation Learning"]
root --> branch1["Research Question/Objective"]
branch1 -.-> leaf1["Improve representation learning"]
branch1 -.-> leaf2["Reduce embedding size"]
branch1 -.-> leaf3["Enhance computational efficiency"]

root --> branch2["Methodology"]
branch2 -.-> leaf4["Matryoshka Representation Learning (MRL)"]
leaf4 -.-> subleaf1["Multigranular embeddings"]
leaf4 -.-> subleaf2["Hierarchical structure"]
branch2 -.-> leaf5["Gradient-based training"]
branch2 -.-> leaf6["Approximate Nearest Neighbor Search (ANNS)"]

root --> branch3["Key Findings/Contributions"]
branch3 -.-> leaf7["Efficient classification and retrieval"]
branch3 -.-> leaf8["Scalable to large datasets"]
branch3 -.-> leaf9["Improved accuracy in out-of-distribution scenarios"]

root --> branch4["Data and Analysis"]
branch4 -.-> leaf10["ImageNet-1K"]
branch4 -.-> leaf11["JFT-300M"]
branch4 -.-> leaf12["ALIGN"]

root --> branch5["Results and Discussion"]
branch5 -.-> leaf13["Efficiency in embedding size reduction"]
leaf13 -.-> subleaf3["Up to 14x smaller embeddings"]
branch5 -.-> leaf14["Enhanced flexibility for adaptive deployment"]
leaf14 -.-> subleaf4["Dynamic computational budget adaptation"]
branch5 -.-> leaf15["Improved robustness and generalization"]
leaf15 -.-> subleaf5["Better performance in long-tail tasks"]

root --> branch6["Implications"]
branch6 -.-> leaf16["Practical for varied environments"]
branch6 -.-> leaf17["Optimal performance with limited resources"]

root --> branch7["Limitations"]
branch7 -.-> leaf18["Training overhead"]
branch7 -.-> leaf19["Complex implementation"]

root --> branch8["Future Research Directions"]
branch8 -.-> leaf20["Optimize weightings dynamically"]
branch8 -.-> leaf21["Explore new loss functions"]
branch8 -.-> leaf22["Dataset-aware retrieval customization"]
branch8 -.-> leaf23["Integrate with learned search data structures"]

Highlights explained

1. Introduction of Matryoshka Representation Learning (MRL)

What it means:
Matryoshka Representation Learning (MRL) is a novel method proposed in the paper that encodes information at multiple granularities within a single embedding vector. This structure allows parts of the vector to represent different levels of detail, similar to nesting dolls (Matryoshkas).

Why it’s significant:
MRL allows for flexible and efficient deployment of machine learning models, enabling the use of different portions of an embedding vector based on the available computational resources and the specific requirements of a task. This flexibility can lead to significant reductions in computational costs and storage requirements without sacrificing model accuracy.

Relation to existing work:
MRL builds upon previous concepts in representation learning and efficient neural networks, enhancing them by providing a practical method for dynamically adjusting the granularity of embeddings. Its use in applications like OpenAI’s text-embedding-3-small demonstrates its real-world impact in reducing the size of embedding vectors while maintaining performance.
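The nesting idea can be made concrete with a small sketch. It assumes an embedding whose leading coordinates are already useful on their own, which is what MRL training is meant to achieve. The random vectors, the `truncate` helper, and the chosen dimensions are illustrative assumptions, not taken from the paper or the official repo.

```python
import numpy as np

def truncate(embeddings: np.ndarray, dim: int) -> np.ndarray:
    """Keep the first `dim` coordinates of each row and L2-normalize the result."""
    prefix = embeddings[:, :dim]
    norms = np.linalg.norm(prefix, axis=1, keepdims=True)
    return prefix / np.clip(norms, 1e-12, None)

rng = np.random.default_rng(0)
full = rng.normal(size=(5, 768))  # stand-in for 5 MRL-trained embeddings (assumption)

for dim in (8, 64, 768):
    small = truncate(full, dim)
    # Cosine similarity of the first vector against all rows at this granularity
    sims = small @ small[0]
    print(dim, small.shape, np.round(sims[:3], 3))
```

With random vectors the similarities themselves carry no meaning. The sketch only shows the mechanics: every granularity is a prefix of the same stored vector, so no second model or second index of full vectors is needed to produce it.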

2. Efficiency in Embedding Size Reduction

What it means:
The MRL method allows for a significant reduction in the size of embedding vectors—up to 14 times smaller—by retaining essential information while compressing the representation.

Why it’s significant:
Reducing embedding sizes is crucial for scalable machine learning systems, especially those dealing with large datasets or requiring real-time processing. Smaller embeddings lower both storage and processing requirements, making it feasible to deploy high-performance models on resource-constrained devices.

Relation to existing work:
This efficiency addresses a long-standing challenge in machine learning of balancing model complexity with resource constraints. MRL’s approach is distinct in that it compresses embeddings adaptively, contrasting with traditional fixed-size embeddings.
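To see what a smaller embedding means for storage, the arithmetic can be sketched as follows. The corpus size and the two dimensions are hypothetical values chosen for illustration, and the calculation assumes uncompressed float32 vectors with no index overhead.

```python
def index_size_bytes(num_vectors: int, dim: int, bytes_per_value: int = 4) -> int:
    """Raw storage for dense vectors (float32 by default), ignoring any index overhead."""
    return num_vectors * dim * bytes_per_value

num_vectors = 1_000_000          # hypothetical corpus size
full_dim, small_dim = 768, 64    # hypothetical full and truncated embedding sizes

full_bytes = index_size_bytes(num_vectors, full_dim)
small_bytes = index_size_bytes(num_vectors, small_dim)

print(f"full:  {full_bytes / 1e9:.3f} GB")
print(f"small: {small_bytes / 1e9:.3f} GB")
print(f"reduction: {full_bytes / small_bytes:.0f}x")
```

With these made-up numbers the raw vectors shrink from about 3.07 GB to about 0.26 GB, a 12x reduction. The ratio is simply `full_dim / small_dim`; whether accuracy holds at the smaller size is an empirical question that this arithmetic does not answer.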

3. Enhanced Flexibility for Adaptive Deployment

What it means:
MRL allows models to adapt dynamically to different computational budgets and tasks by utilizing different parts of the embedding vector as needed. This adaptability makes it possible to scale the computational effort required for inference based on available resources.

Why it’s significant:
This flexibility is valuable for deploying machine learning models in environments where computational power and memory vary, such as edge devices, mobile applications, and large-scale server farms. A single model can run at whatever embedding size each environment can afford, rather than requiring a separately trained model per budget.

Relation to existing work:
Adaptive deployment has been a key area of research in efficient neural networks. MRL’s contribution lies in its practical implementation of nesting multiple levels of information detail within a single embedding vector, enhancing the versatility compared to previous approaches.
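One way to use different parts of the vector at inference time is a two-stage search: shortlist candidates cheaply with a short prefix, then re-rank only that shortlist with the full vector. The sketch below is a minimal illustration under stated assumptions, not the authors' implementation: it uses brute-force NumPy distances instead of an ANN index, random data in place of trained embeddings, and a hypothetical `shortlist_then_rerank` helper.

```python
import numpy as np

def shortlist_then_rerank(query, corpus, shortlist_dim, shortlist_size, top_k):
    """Cheap search on a prefix, then exact re-ranking of the shortlist with all dims."""
    # Stage 1: squared L2 distance on the first `shortlist_dim` coordinates only
    coarse = np.sum((corpus[:, :shortlist_dim] - query[:shortlist_dim]) ** 2, axis=1)
    candidates = np.argsort(coarse)[:shortlist_size]
    # Stage 2: full-dimensional distance, computed only for the shortlisted rows
    fine = np.sum((corpus[candidates] - query) ** 2, axis=1)
    order = np.argsort(fine)[:top_k]
    return candidates[order], fine[order]

rng = np.random.default_rng(0)
corpus = rng.normal(size=(1000, 256))             # stand-in for stored embeddings
query = corpus[42] + 0.01 * rng.normal(size=256)  # near-duplicate of row 42

ids, dists = shortlist_then_rerank(query, corpus, shortlist_dim=16,
                                   shortlist_size=50, top_k=5)
print(ids, np.round(dists, 3))
```

The first stage touches 16 values per stored vector instead of 256, and the expensive full-dimensional comparison runs on 50 rows instead of 1000. The shortlist dimension and size are tuning knobs: a prefix that is too short can drop the true neighbor before the re-ranking stage ever sees it.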


The proof-of-concept (PoC) code below implements MRL on top of a BERT transformer encoder for text, together with nearest-neighbor retrieval, for educational purposes. The retrieval step uses scikit-learn's exact NearestNeighbors as a stand-in for ANN; at larger scale, use an engineering-optimized ANN package such as an HNSW implementation. The PoC shows high validation loss because it trains on very few examples. Please refer to the official repo for other optimizations and its image classification examples.

pip install torch transformers scikit-learn numpy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import numpy as np
from sklearn.neighbors import NearestNeighbors

print("Importing necessary libraries...")

class MRLTransformer(nn.Module):
    def __init__(self, bert_model, nesting_dims, num_classes):
        super(MRLTransformer, self).__init__()
        self.bert = bert_model
        self.nesting_dims = nesting_dims
        self.num_classes = num_classes
        # Create nested classifiers for each dimension in nesting_dims
        self.classifiers = nn.ModuleList([
            nn.Linear(dim, num_classes) for dim in nesting_dims
        ])
        print(f"Initialized MRLTransformer with nesting dimensions: {nesting_dims}")

    def forward(self, input_ids, attention_mask):
        # Get BERT embeddings
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        # Apply nested classifiers to get predictions at different granularities
        nested_logits = []
        for i, classifier in enumerate(self.classifiers):
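            # Slice the first nesting_dims[i] coordinates and classify from that prefix
            nested_output = pooled_output[:, :self.nesting_dims[i]]
            nested_logits.append(classifier(nested_output))
        return nested_logits, pooled_output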

class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        print(f"Created dataset with {len(texts)} examples")
    def __len__(self):
        return len(self.texts)
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        # Tokenize and encode the text
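        encoding = self.tokenizer.encode_plus(
            text, add_special_tokens=True, max_length=self.max_length,
            padding='max_length', truncation=True, return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }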

def train_mrl_transformer(model, train_loader, val_loader, num_epochs, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=2e-5)
    model.to(device)
    print(f"Model moved to device: {device}")
    for epoch in range(num_epochs):
        model.train()  # enable dropout for the training phase
        train_loss = 0.0
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("Training phase:")
        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            optimizer.zero_grad()
            nested_logits, _ = model(input_ids, attention_mask)
            # Calculate loss for all nested dimensions
            loss = sum(criterion(logits, labels) for logits in nested_logits)
            # Backpropagate and update the weights
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            if batch_idx % 5 == 0:
                print(f"  Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item():.4f}")
        train_loss /= len(train_loader)
        print("\nValidation phase:")
        model.eval()  # disable dropout for evaluation
        val_loss = 0.0
        correct = [0] * len(model.nesting_dims)
        total = 0
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)
                nested_logits, _ = model(input_ids, attention_mask)
                loss = sum(criterion(logits, labels) for logits in nested_logits)
                val_loss += loss.item()
                for i, logits in enumerate(nested_logits):
                    _, predicted = torch.max(logits, 1)
                    correct[i] += (predicted == labels).sum().item()
                total += labels.size(0)
                if batch_idx % 5 == 0:
                    print(f"  Batch {batch_idx+1}/{len(val_loader)}, Loss: {loss.item():.4f}")
        val_loss /= len(val_loader)
        accuracies = [100 * c / total for c in correct]
        print(f'\nEpoch {epoch+1} Summary:')
        print(f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
        for i, acc in enumerate(accuracies):
            print(f'Accuracy at dim {model.nesting_dims[i]}: {acc:.2f}%')

class ANNRetrieval:
    def __init__(self, embeddings, nesting_dims):
        self.embeddings = embeddings
        self.nesting_dims = nesting_dims
        self.ann_indices = {}
        print("Initializing ANN indices for each nesting dimension...")
        for dim in nesting_dims:
            self.ann_indices[dim] = NearestNeighbors(n_neighbors=10, algorithm='auto')
            self.ann_indices[dim].fit(embeddings[:, :dim])
        print("ANN indices initialized")
    def retrieve(self, query, dim):
        distances, indices = self.ann_indices[dim].kneighbors([query[:dim]])
        return distances[0], indices[0]

# Main execution
if __name__ == "__main__":
    print("Setting up device...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    print("\nInitializing BERT model and tokenizer...")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_model = BertModel.from_pretrained('bert-base-uncased')
    print("BERT model and tokenizer initialized")

    nesting_dims = [8, 16, 32, 64, 128, 256, 512, 768]
    num_classes = 4

    print("\nCreating MRLTransformer model...")
    model = MRLTransformer(bert_model, nesting_dims, num_classes)
    print("MRLTransformer model created")

    print("\nPreparing dataset...")
    texts = [
        "This movie was absolutely fantastic, I loved every minute of it!",
        "The food at this restaurant was terrible, I wouldn't recommend it to anyone.",
        "I'm feeling quite neutral about this product, it's neither good nor bad.",
        "The customer service was outstanding, they went above and beyond to help me.",
        "This book was a disappointment, it didn't live up to the hype at all.",
        "The concert was amazing, the band's energy was infectious!",
        "I found this article to be quite informative and well-written.",
        "The hotel room was a bit small, but overall it was a pleasant stay.",
        "This software is incredibly buggy, it's causing me a lot of frustration.",
        "The scenery on this hike was breathtaking, I'd highly recommend it.",
    ]
    labels = [0, 3, 1, 0, 3, 0, 1, 1, 3, 0]  # 0: Very Positive, 1: Somewhat Positive, 2: Somewhat Negative, 3: Very Negative

    print("\nCreating datasets and dataloaders...")
    train_dataset = TextDataset(texts, labels, tokenizer, max_length=128)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(train_dataset, batch_size=4)
    print("Datasets and dataloaders created")

    print("\nStarting training...")
    num_epochs = 5
    train_mrl_transformer(model, train_loader, val_loader, num_epochs, device)
    print("Training completed")

    print("\nGenerating embeddings for ANN retrieval...")
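    model.eval()  # disable dropout so the embeddings are deterministic
    embeddings = []
    with torch.no_grad():
        for text in texts:
            encoding = tokenizer.encode_plus(text, add_special_tokens=True, max_length=128,
                                             padding='max_length', truncation=True, return_tensors='pt')
            _, pooled_output = model(encoding['input_ids'].to(device), encoding['attention_mask'].to(device))
            # Collect the full-size embedding; ANNRetrieval slices prefixes from it
            embeddings.append(pooled_output.cpu().numpy()[0])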
    embeddings = np.array(embeddings)
    print(f"Generated embeddings shape: {embeddings.shape}")

    print("\nInitializing ANN retrieval...")
    ann_retrieval = ANNRetrieval(embeddings, nesting_dims)

    print("\nPerforming inference and retrieval on test examples...")
    test_texts = [
        "This product exceeded my expectations in every way.",
        "The service was okay, but there's definitely room for improvement.",
    ]

    with torch.no_grad():
        for test_text in test_texts:
            print(f"\nTest text: {test_text}")
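            encoding = tokenizer.encode_plus(test_text, add_special_tokens=True, max_length=128,
                                             padding='max_length', truncation=True, return_tensors='pt')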
            nested_logits, query_embedding = model(encoding['input_ids'].to(device), encoding['attention_mask'].to(device))

            print("Classification results:")
            for i, logits in enumerate(nested_logits):
                probs = torch.softmax(logits, dim=1)
                print(f"Dimension {nesting_dims[i]}: {probs[0]}")
            print("Class probabilities: Very Positive, Somewhat Positive, Somewhat Negative, Very Negative")

            print("\nRetrieval results:")
            for dim in nesting_dims:
                distances, indices = ann_retrieval.retrieve(query_embedding.cpu().numpy()[0], dim)
                print(f"\nTop 3 results for dimension {dim}:")
                for d, i in zip(distances[:3], indices[:3]):
                    print(f"  Distance: {d:.4f}, Text: {texts[i][:50]}...")

    print("\nScript execution completed")
