In-context Continual Learning Assisted by an External Continual Learner (No more RAG!)

Saleh Momeni, Sahisnu Mazumder, Zixuan Ke, Bing Liu

The paper “In-context Continual Learning Assisted by an External Continual Learner” addresses two persistent problems in class-incremental learning (CIL) for natural language processing (NLP): catastrophic forgetting (CF) and inter-task class separation (ICS). Its method, InCA, pairs in-context learning (ICL) on a large language model (LLM) with an external continual learner (ECL) that models each class as a Gaussian distribution and pre-selects a small set of candidate classes for every input. A key technical choice is to measure class similarity with the Mahalanobis distance rather than a metric such as cosine similarity. Because the distance is computed under a shared covariance matrix, it accounts for correlations between embedding dimensions and for the shape and orientation of the class distributions, which should help with concept drift and with zero-shot generalization to queries unlike the examples seen so far. Since only the pre-selected classes enter the prompt, the approach stays within LLM token limits, and because no parameters are updated, CF is avoided entirely. The paper reports that InCA significantly outperforms current state-of-the-art methods across multiple benchmark datasets, which speaks to its scalability and robustness.
The paper stands out for offering a practical route to continual learning in dynamic NLP environments. It also suggests further work on extending InCA to more diverse NLP tasks, optimizing class representations, and integrating minimal tuning strategies.
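
To make the contrast between Mahalanobis and plain Euclidean distance concrete, here is a minimal NumPy sketch. The two class means, the shared covariance, and the query are made-up 2-D values chosen for illustration, not figures from the paper, and `mahalanobis` is a helper name introduced here.

```python
import numpy as np

def mahalanobis(x, mean, cov):
    """Mahalanobis distance from x to a class mean under covariance cov."""
    diff = x - mean
    # Solve cov @ z = diff rather than forming the explicit inverse.
    return float(np.sqrt(diff @ np.linalg.solve(cov, diff)))

# Toy 2-D "embeddings" (assumed values, for illustration only).
means = {
    "class_a": np.array([3.0, 1.0]),
    "class_b": np.array([1.0, 2.5]),
}
# Shared covariance: high variance along the first axis, low along the second.
shared_cov = np.array([[4.0, 0.0],
                       [0.0, 0.25]])
query = np.array([1.0, 1.0])

nearest_euclid = min(means, key=lambda n: np.linalg.norm(query - means[n]))
nearest_maha = min(means, key=lambda n: mahalanobis(query, means[n], shared_cov))

print("nearest by Euclidean distance:", nearest_euclid)
print("nearest by Mahalanobis distance:", nearest_maha)
```

With these numbers the query is nearer to class_b in Euclidean terms but nearer to class_a in Mahalanobis terms, because its offset from class_a lies along the high-variance axis.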

Highlights explained

Introduction of InCA for Class-Incremental Learning (CIL)

  • Explanation: InCA, short for In-context Continual Learning Assisted by an External Continual Learner, integrates an external continual learner (ECL) with in-context learning (ICL) methods using large language models (LLMs) to tackle class-incremental learning challenges.
  • Significance: The method avoids catastrophic forgetting (CF) and eases inter-task class separation (ICS) in NLP tasks without any parameter updates, addressing two key problems in traditional CIL frameworks.
  • Relation to Existing Work: Conventional CL methods typically tune model parameters. InCA leaves them unchanged, which inherently avoids CF and improves scalability.

Efficient Class Pre-selection with Gaussian Modeling

  • Explanation: InCA employs Gaussian distributions over embeddings of class-related tags to pre-select probable classes for each test instance, reducing the in-context learning prompt length.
  • Significance: This approach allows for efficient handling of LLM token limits by focusing only on the most relevant classes, thus maintaining model performance while ensuring scalability.
  • Impact: By addressing prompt length issues, this method significantly outperforms existing baselines in scalability, making class-incremental learning more feasible in NLP tasks.
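
The pre-selection step described above might look roughly like the following sketch. It assumes the tag embeddings have already been averaged into one query vector and that class means and a shared covariance are available. `top_k_classes`, the class names, and the toy 3-D vectors are hypothetical stand-ins for real embeddings.

```python
import numpy as np

def top_k_classes(query_vec, class_means, shared_cov, k):
    """Return the k class names closest to query_vec by Mahalanobis distance."""
    names = list(class_means)
    diffs = np.stack([query_vec - class_means[n] for n in names])  # shape (C, d)
    # Solve shared_cov @ Z = diffs.T once for all classes, avoiding an explicit inverse.
    solved = np.linalg.solve(shared_cov, diffs.T).T                 # shape (C, d)
    sq_dists = np.einsum("cd,cd->c", diffs, solved)                 # squared distances
    order = np.argsort(sq_dists)[:k]
    return [names[i] for i in order]

# Toy class means (assumed values, not real embeddings).
class_means = {
    "billing":  np.array([1.0, 0.0, 0.0]),
    "shipping": np.array([0.0, 1.0, 0.0]),
    "returns":  np.array([0.8, 0.2, 0.0]),
    "login":    np.array([0.0, 0.0, 1.0]),
}
shared_cov = np.eye(3)
query_vec = np.array([1.0, 0.0, 0.1])

candidates = top_k_classes(query_vec, class_means, shared_cov, k=2)
print(candidates)
```

Only the classes in `candidates` would then be described in the prompt, so prompt length depends on k rather than on the total number of classes learned.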

Avoidance of Catastrophic Forgetting and Scalability Issues

  • Explanation: InCA’s architecture avoids updating LLM parameters, thereby circumventing catastrophic forgetting entirely. The ECL accumulates class statistics rather than storing past data.
  • Significance: By neither retaining previous task data nor updating LLM weights, InCA resolves scalability issues and maintains high class recall without the risk of performance deteriorating over time.
  • Impact: The avoidance of CF and ICS enhances model reliability and performance consistency across various datasets and class complexities.
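
One way to accumulate class statistics without keeping past data is to store only counts, running means, and a pooled scatter matrix. The sketch below is an assumption about how this could be done, not the paper's exact procedure; `RunningClassStats` is a hypothetical name, and the random data stands in for tag embeddings.

```python
import numpy as np

class RunningClassStats:
    """Keeps per-class means and a pooled covariance without storing examples."""

    def __init__(self, dim):
        self.counts = {}                      # samples seen per class
        self.means = {}                       # running mean per class
        self.scatter = np.zeros((dim, dim))   # summed within-class scatter
        self.total = 0                        # samples seen over all classes

    def update(self, name, batch):
        batch = np.asarray(batch, dtype=float)
        m = len(batch)
        if m == 0:
            return
        n = self.counts.get(name, 0)
        batch_mean = batch.mean(axis=0)
        centered = batch - batch_mean
        self.scatter += centered.T @ centered
        if n == 0:
            self.means[name] = batch_mean
        else:
            # Merge old and new statistics; the correction term accounts for
            # the shift between the old mean and the new batch mean.
            delta = batch_mean - self.means[name]
            self.scatter += np.outer(delta, delta) * (n * m / (n + m))
            self.means[name] = self.means[name] + delta * (m / (n + m))
        self.counts[name] = n + m
        self.total += m

    def shared_covariance(self):
        # Pooled within-class covariance (divides by the total sample count).
        return self.scatter / max(self.total, 1)

rng = np.random.default_rng(0)
class_a = rng.normal(size=(10, 3))
class_b = rng.normal(size=(6, 3)) + 2.0

stats = RunningClassStats(dim=3)
stats.update("a", class_a[:4])   # first batch for class a
stats.update("b", class_b)       # a different class arrives in between
stats.update("a", class_a[4:])   # class a is revisited later

print(stats.means["a"], class_a.mean(axis=0))
```

The running mean for class a matches the mean computed over all ten samples at once, even though the raw batches are discarded after each update.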

Integration with Long-Context LLMs

  • Explanation: InCA integrates with LLMs capable of handling longer contexts but demonstrates superior performance by intelligently selecting class data instead of relying on extended input lengths.
  • Significance: This highlights the efficacy of the ECL’s focused class selection over simply expanding LLM context windows, which can degrade performance due to irrelevant information.
  • Impact: In practice, this integration optimizes LLM utilization, leading to more accurate and efficient predictions, especially in environments with extensive class sets.
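
As a rough illustration of why focused selection keeps prompts short, the sketch below builds a classification prompt from class summaries. The prompt wording, `build_icl_prompt`, and the 50 placeholder classes are all assumptions made for this example. The paper's actual prompt is not reproduced here.

```python
def build_icl_prompt(query, class_summaries, candidate_classes):
    """Build a classification prompt that lists only the candidate classes."""
    lines = ["Classify the query into exactly one of the classes below.", ""]
    for name in candidate_classes:
        lines.append(f"- {name}: {class_summaries[name]}")
    lines += ["", f"Query: {query}", "Class:"]
    return "\n".join(lines)

# 50 placeholder classes with placeholder summaries (hypothetical data).
class_summaries = {
    f"class_{i:02d}": f"Summary text describing class number {i}."
    for i in range(50)
}
query = "Example user request"

# Prompt with every class versus prompt with three pre-selected candidates.
full_prompt = build_icl_prompt(query, class_summaries, list(class_summaries))
short_prompt = build_icl_prompt(
    query, class_summaries, ["class_03", "class_17", "class_42"]
)

print(len(full_prompt), len(short_prompt))
```

The full prompt grows with every class learned, while the short prompt stays roughly constant in size for a fixed number of candidates.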

Code

The implementation centers on InCA’s key idea: Gaussian class modeling with Mahalanobis distance to keep the LLM context small, with each class represented by its mean embedding vector and all classes sharing one covariance matrix. The ECL (GaussianClassModel) handles incremental class updates without storing raw examples, while cosine similarity between the query and the class summary provides an additional confidence score. Note that predict returns the ECL’s top-ranked class directly; it does not make a final in-context classification call to the LLM. The demo exercises concept drift (from legacy to modern technology examples) and generalization to unseen queries for a class learned from only two examples (connectivity issues). The script expects an OpenAI API key in the OPENAI_API_KEY environment variable.

Bash
pip install openai numpy
Python
import os
from dataclasses import dataclass
from typing import List, Dict, Any, Optional
import numpy as np
from openai import OpenAI
import json

# Initialize OpenAI client with API key from environment variable
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

@dataclass
class TagResponse:
    """JSON response format for tag generation"""
    tags: List[str]

@dataclass
class SummaryResponse:
    """JSON response format for class summary generation"""
    summary: str

@dataclass
class ClassificationResponse:
    """JSON response format for classification"""
    class_name: str
    confidence: float
    is_out_of_domain: bool = False

class GaussianClassModel:
    """Implements the External Continual Learner (ECL) using Gaussian distributions"""
    
    def __init__(self, embedding_dim: int = 1536):
        self.means = {}  # Class means
        self.shared_covariance = np.eye(embedding_dim) * 0.1  # Initialize with small variance
        self.class_count = 0
        self.min_covar_eigenval = 1e-6  # Minimum eigenvalue for numerical stability
        
    def _get_embedding(self, text: str) -> np.ndarray:
        """Get embedding for a text string using OpenAI's embedding model"""
        response = client.embeddings.create(
            input=text,
            model="text-embedding-3-small"
        )
        return np.array(response.data[0].embedding)
    
    def _get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Get embeddings for multiple texts"""
        if not texts:
            return np.array([])
        response = client.embeddings.create(
            input=texts,
            model="text-embedding-3-small"
        )
        return np.array([data.embedding for data in response.data])
        
    def _stabilize_covariance(self, cov_matrix: np.ndarray) -> np.ndarray:
        """Ensure covariance matrix is numerically stable"""
        # Add small constant to diagonal for numerical stability
        cov_matrix += np.eye(cov_matrix.shape[0]) * self.min_covar_eigenval
        
        # Ensure symmetry
        cov_matrix = (cov_matrix + cov_matrix.T) / 2
        
        # Ensure positive definiteness through eigenvalue decomposition
        eigvals, eigvecs = np.linalg.eigh(cov_matrix)
        eigvals = np.maximum(eigvals, self.min_covar_eigenval)
        cov_matrix = eigvecs @ np.diag(eigvals) @ eigvecs.T
        
        return cov_matrix
    
    def update_class_statistics(self, class_name: str, tags: List[str]):
        """Update Gaussian statistics for a class"""
        embeddings = self._get_embeddings(tags)
        if len(embeddings) == 0:
            return
            
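        # Update mean for the class
        if not hasattr(self, "tag_counts"):
            self.tag_counts = {}  # tags seen so far per class (created lazily)
        n_old = self.tag_counts.get(class_name, 0)
        n_new = len(embeddings)
        if class_name not in self.means:
            self.means[class_name] = np.mean(embeddings, axis=0)
            self.class_count += 1
        else:
            # Incremental mean update, weighting the old mean by its tag count
            old_mean = self.means[class_name]
            self.means[class_name] = (n_old * old_mean + embeddings.sum(axis=0)) / (n_old + n_new)
        self.tag_counts[class_name] = n_old + n_new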
        
        # Update shared covariance matrix
        diff = embeddings - self.means[class_name]
        class_cov = (diff.T @ diff) / max(len(tags), 1)
        
        # Update shared covariance with stability check
        if self.class_count > 1:
            self.shared_covariance = ((self.class_count - 1) * self.shared_covariance + class_cov) / self.class_count
        else:
            self.shared_covariance = class_cov
            
        self.shared_covariance = self._stabilize_covariance(self.shared_covariance)
    
    def get_top_k_classes(self, query_tags: List[str], k: int) -> tuple[List[str], List[float]]:
        """Get top k most similar classes using Mahalanobis distance"""
        # Nothing to rank without query tags or without any learned class
        if not query_tags or not self.means:
            return [], []
            
        query_embeddings = self._get_embeddings(query_tags)
        if len(query_embeddings) == 0:
            return [], []
            
        query_mean = np.mean(query_embeddings, axis=0)
        
        # Calculate Mahalanobis distance to each class
        distances = {}
        inv_cov = np.linalg.inv(self.shared_covariance)
        
        for class_name, class_mean in self.means.items():
            diff = query_mean - class_mean
            dist = np.sqrt(max(0, diff.T @ inv_cov @ diff))  # Ensure non-negative
            distances[class_name] = dist.item()
        
        # Convert distances to similarity scores (inverse of distance)
        similarities = {cls: 1.0 / (dist + 1e-6) for cls, dist in distances.items()}
        
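        # Normalize by the largest similarity: the best class always scores
        # close to 1.0, so these are relative rankings, not calibrated confidences
        max_sim = max(similarities.values()) + 1e-6
        similarities = {cls: sim/max_sim for cls, sim in similarities.items()}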
        
        # Sort by similarity (highest first)
        sorted_classes = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
        classes, scores = zip(*sorted_classes[:k])
        
        return list(classes), list(scores)

class InCA:
    """Implementation of In-context Continual Learning Assisted by an ECL"""
    
    def __init__(self, embedding_dim: int = 1536):
        self.ecl = GaussianClassModel(embedding_dim)
        self.class_summaries = {}
        self.confidence_threshold = 0.3
        
    def _calculate_semantic_similarity(self, query: str, class_summary: str) -> float:
        """Calculate semantic similarity between query and class summary"""
        if not class_summary:
            return 0.0
            
        query_emb = self.ecl._get_embedding(query)
        summary_emb = self.ecl._get_embedding(class_summary)
        
        similarity = np.dot(query_emb, summary_emb) / (
            np.linalg.norm(query_emb) * np.linalg.norm(summary_emb)
        )
        
        return max(0, (similarity + 1) / 2)  # Normalize to [0, 1]
        
    def _generate_tags(self, query: str, examples: List[Dict[str, Any]]) -> List[str]:
        """Generate semantic tags using GPT-4o-mini with JSON mode"""
        prompt = {
            "role": "system",
            "content": """You are a tag generator that creates semantic tags for text input. 
                         Generate 5-10 relevant tags that capture the key concepts and intent.
                         Respond in JSON format with a 'tags' array."""
        }
        
        query_request = {
            "query": query,
            "examples": examples
        }
        
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                prompt,
                {"role": "user", "content": json.dumps(query_request)}
            ],
            response_format={"type": "json_object"},
            temperature=0.3
        )
        
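        try:
            result = json.loads(response.choices[0].message.content)
            return result.get("tags", [])
        except (json.JSONDecodeError, TypeError, AttributeError):
            # Malformed or missing JSON: fall back to no tags
            return []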
            
    def _generate_class_summary(self, class_name: str, examples: List[str]) -> str:
        """Generate class summary using GPT-4o-mini with JSON mode"""
        prompt = {
            "role": "system",
            "content": """You are a class summarizer that creates concise descriptions.
                         Create a clear, specific summary of the class based on examples.
                         Respond in JSON format with a 'summary' field."""
        }
        
        summary_request = {
            "class_name": class_name,
            "examples": examples
        }
        
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                prompt,
                {"role": "user", "content": json.dumps(summary_request)}
            ],
            response_format={"type": "json_object"},
            temperature=0.3
        )
        
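        try:
            result = json.loads(response.choices[0].message.content)
            return result.get("summary", "")
        except (json.JSONDecodeError, TypeError, AttributeError):
            # Malformed or missing JSON: fall back to an empty summary
            return ""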
            
    def learn_class(self, class_name: str, examples: List[str], tag_examples: List[Dict[str, Any]]):
        """Learn a new class incrementally"""
        summary = self._generate_class_summary(class_name, examples)
        self.class_summaries[class_name] = summary
        
        all_tags = []
        for example in examples:
            tags = self._generate_tags(example, tag_examples)
            all_tags.extend(tags)
        
        self.ecl.update_class_statistics(class_name, all_tags)
        
    def predict(self, query: str, k: int = 3) -> Dict[str, Any]:
        """Predict class for new query with improved confidence and out-of-domain detection"""
        tags = self._generate_tags(query, [])
        top_classes, similarities = self.ecl.get_top_k_classes(tags, k)
        
        if not top_classes:
            return {
                "class_name": "unknown",
                "confidence": 0.0,
                "is_out_of_domain": True,
                "candidate_classes": [],
                "statistical_confidence": 0.0,
                "semantic_confidence": 0.0
            }
        
        # Calculate statistical confidence
        statistical_confidence = similarities[0] if similarities else 0.0
        
        # Calculate semantic confidence
        max_class = top_classes[0]
        class_summary = self.class_summaries.get(max_class, "")
        semantic_confidence = self._calculate_semantic_similarity(query, class_summary)
        
        # Combined confidence score
        final_confidence = 0.6 * statistical_confidence + 0.4 * semantic_confidence
        
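        # Out-of-domain detection. Note: get_top_k_classes normalizes the top
        # similarity to roughly 1.0, so the statistical condition is rarely met
        # and this flag seldom fires without a threshold on the raw distance
        is_out_of_domain = (semantic_confidence < self.confidence_threshold and 
                           statistical_confidence < self.confidence_threshold)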
        
        return {
            "class_name": max_class,
            "confidence": final_confidence,
            "is_out_of_domain": is_out_of_domain,
            "candidate_classes": top_classes,
            "semantic_confidence": semantic_confidence,
            "statistical_confidence": statistical_confidence
        }

def main():
    """Example usage demonstrating concept drift and zero-shot extension"""
    # Initialize InCA
    inca = InCA()
    
    # Example tag templates
    tag_examples = [
        {
            "query": "Getting errors while saving to floppy disk",
            "tags": ["storage", "error", "data", "hardware", "save", "disk_issue"]
        },
        {
            "query": "Cloud backup not syncing",
            "tags": ["storage", "cloud", "sync", "backup", "data", "connectivity"]
        }
    ]
    
    # Phase 1: Legacy Technology
    print("\n=== Phase 1: Training with Legacy Technology ===")
    legacy_classes = {
        "storage_issue": [
            "Floppy disk not reading",
            "CD-ROM drive making noise",
            "Can't format my diskette",
            "Zip drive not recognized"
        ],
        "display_problem": [
            "Monitor showing blue screen",
            "CRT display flickering",
            "Screen resolution too low",
            "Monitor colors look wrong"
        ]
    }
    
    for class_name, examples in legacy_classes.items():
        print(f"\nLearning legacy class: {class_name}")
        inca.learn_class(class_name, examples, tag_examples)
    
    # Phase 2: test legacy and modern (concept drift) queries against the legacy classes
    for phase, queries in [
        ("Legacy Technology", [
            "My floppy drive isn't working",
            "CRT monitor showing ghost images",
            "Tape backup failed"
        ]),
        ("Modern Technology (Concept Drift)", [
            "Cloud storage not syncing with my device",
            "Can't access my Google Drive files",
            "4K display has dead pixels",
            "OLED screen showing burn-in",
            "NVMe SSD not showing up in BIOS"
        ])
    ]:
        print(f"\n=== Testing with {phase} ===")
        for query in queries:
            result = inca.predict(query)
            print(f"\nQuery: {query}")
            print(f"Predicted class: {result['class_name']}")
            print(f"Confidence: {result['confidence']:.2f}")
            print(f"Out of domain: {result['is_out_of_domain']}")
            print(f"Statistical confidence: {result['statistical_confidence']:.2f}")
            print(f"Semantic confidence: {result['semantic_confidence']:.2f}")
    
    # Phase 3: Zero-shot Extension
    print("\n=== Phase 3: Zero-shot Class Extension ===")
    new_class = {
        "connectivity_issue": [
            "WiFi keeps disconnecting",
            "Network connection unstable"
        ]
    }
    
    for class_name, examples in new_class.items():
        print(f"\nLearning new class with minimal examples: {class_name}")
        inca.learn_class(class_name, examples, tag_examples)
    
    # Test zero-shot queries
    zero_shot_queries = [
        "Bluetooth not pairing with my device",
        "5G signal drops in my area",
        "VPN connection keeps timing out",
        "Ethernet port not detecting cable",
        "DNS server not responding"
    ]
    
    print("\nTesting zero-shot generalization:")
    for query in zero_shot_queries:
        result = inca.predict(query)
        print(f"\nQuery: {query}")
        print(f"Predicted class: {result['class_name']}")
        print(f"Confidence: {result['confidence']:.2f}")
        print(f"Out of domain: {result['is_out_of_domain']}")
        print(f"Statistical confidence: {result['statistical_confidence']:.2f}")
        print(f"Semantic confidence: {result['semantic_confidence']:.2f}")

if __name__ == "__main__":
    main()
