Ashutosh Joshi, Sheikh Muhammad Sarwar, Samarth Varshney, Sreyashi Nag, Shrivats Agrawal, Juhi Naik
The paper introduces REAPER (Reasoning-based PlannER), a lightweight, language-model-based planner that makes Retrieval-Augmented Generation (RAG) systems more efficient in complex dialog settings such as the Amazon Rufus shopping chatbot. Slow, multi-step Chain-of-Thought (CoT) reasoning is the main source of delay in traditional RAG systems. REAPER avoids it by generating the entire retrieval plan in a single step. This cuts retrieval-plan latency to 207 milliseconds, compared with about 2 seconds per step for larger models such as Claude3-Sonnet, without sacrificing accuracy. The result suits real-time applications where response speed is critical. The paper reports 95% accuracy in tool selection and 92% in argument generation, showing that smaller LLMs can handle these tasks well. REAPER also scales with little training data: adding a new retrieval source requires only 286 examples. Its data generation techniques, Tool Evolve (TEvo), Tool-Task Generator (TTG), and Diverse Query Sampler (DQS), improve robustness and diversity and help counter model bias. Together, these contributions make the paper useful for researchers and practitioners who want to optimize dialog systems. Future research could apply REAPER to other domains and integrate it with other AI systems so that it adapts dynamically to evolving user needs.
Mind Map
```mermaid
graph LR
    root["REAPER: Reasoning based Retrieval Planning for Complex RAG Systems"]
    root --> branch1["Research Question/Objective"]
    root --> branch2["Methodology"]
    root --> branch3["Key Findings/Contributions"]
    root --> branch4["Data and Analysis"]
    root --> branch5["Results and Discussion"]
    root --> branch6["Implications"]
    root --> branch7["Limitations"]
    root --> branch8["Future Research Directions"]
    branch1 -.-> leaf1["Optimize retrieval in RAG systems"]
    branch1 -.-> leaf2["Reduce latency in conversational shopping assistants"]
    branch2 -.-> leaf3["Use smaller LLM for planning"]
    branch2 -.-> leaf4["Single-step retrieval plan generation"]
    branch2 -.-> leaf5["Data generation: TEvo, TTG, DQS"]
    branch3 -.-> leaf6["Introduction of REAPER"]
    branch3 -.-> leaf7["Latency reduction"]
    branch3 -.-> leaf8["High accuracy in tool selection"]
    branch3 -.-> leaf9["Scalability with minimal training data"]
    branch4 -.-> leaf10["Training on 6K in-domain queries"]
    branch4 -.-> leaf11["Evaluation against traditional models"]
    branch4 -.-> leaf12["Comparison with models like Claude3-Sonnet"]
    branch5 -.-> leaf13["Latency vs. performance metrics"]
    branch5 -.-> leaf14["Tool selection accuracy: 95%"]
    branch5 -.-> leaf15["Tool argument accuracy: 92%"]
    branch5 -.-> leaf16["Scalability results with new cases"]
    branch6 -.-> leaf17["Impact on real-time applications"]
    branch6 -.-> leaf18["Improvement in user experience"]
    branch6 -.-> leaf19["Potential for wider AI integration"]
    branch7 -.-> leaf20["Technical jargon complexity"]
    branch7 -.-> leaf21["Evaluation dataset bias"]
    branch7 -.-> leaf22["In-context learning limitations"]
    branch8 -.-> leaf23["Real-world traffic testing"]
    branch8 -.-> leaf24["Integration with reinforcement learning"]
    branch8 -.-> leaf25["Scalability with more tools"]
    branch8 -.-> leaf26["Applications in other domains"]
```
Highlights explained
1. Introduction of REAPER: A Lightweight LLM-Based Planner
Explanation:
REAPER stands for Reasoning-based PlannER and is a lightweight language model-based planner designed to optimize retrieval in Retrieval-Augmented Generation (RAG) systems. It uses a smaller Large Language Model (LLM) for the planning phase to generate retrieval plans efficiently.
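The planner produces the whole retrieval plan as one piece of text, and downstream code turns that text into tool calls. The following is a minimal sketch of that hand-off. It assumes the `Step N: tool("arg", ...)` format that the prompt in the Code section requests. `ToolCall` and `parse_plan` are hypothetical names for illustration, not REAPER's actual interface.

```python
import re
from dataclasses import dataclass


@dataclass(frozen=True)
class ToolCall:
    """One step of a retrieval plan: a tool name plus its string arguments (hypothetical type)."""
    tool: str
    args: tuple[str, ...]


# Matches lines such as: Step 1: prod_qna("TechPro Laptop", "price")
STEP_RE = re.compile(r'Step\s+\d+:\s*(\w+)\s*\((.*)\)')
# Extracts each double-quoted argument inside the parentheses
ARG_RE = re.compile(r'"([^"]*)"')


def parse_plan(text: str) -> list[ToolCall]:
    """Convert the planner's single text output into an ordered list of tool calls."""
    calls = []
    for line in text.splitlines():
        match = STEP_RE.search(line)
        if match:
            tool, raw_args = match.groups()
            calls.append(ToolCall(tool, tuple(ARG_RE.findall(raw_args))))
    return calls


plan_text = """Step 1: prod_qna("TechPro Laptop", "price")
Step 2: review_summary("TechPro Laptop")"""

for call in parse_plan(plan_text):
    print(call)
```

Because the plan arrives in one generation, the tools can be executed without going back to the LLM between steps. Text that contains no `Step` lines, such as "no retrieval needed", simply yields an empty plan.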
Significance:
The use of a smaller LLM for planning significantly reduces the latency typically incurred by traditional RAG systems without compromising the accuracy of the responses. This is crucial for applications like the Amazon Rufus shopping chatbot, where response speed and accuracy directly impact user experience.
Relation to Existing Work:
REAPER builds on Retrieval-Augmented Generation frameworks, such as that of Lewis et al. (2020) and later work, and adds a more efficient planning mechanism that addresses the latency of earlier RAG pipelines.
2. Significant Reduction in Latency
Explanation:
REAPER achieves a latency of 207 milliseconds for the entire retrieval plan, compared to 2 seconds per step for more powerful models like Claude3-Sonnet.
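The two figures are not directly comparable: one is per plan, the other per step. A back-of-the-envelope comparison helps put them on the same scale. This sketch uses only the quoted numbers and assumes that each plan step costs one ~2 s call to the larger model. It ignores the time spent in the retrieval tools themselves.

```python
# Figures quoted in the text
REAPER_PLAN_SECONDS = 0.207      # whole retrieval plan, one generation
LARGE_MODEL_STEP_SECONDS = 2.0   # per reasoning step (e.g. Claude3-Sonnet)


def large_model_latency(steps: int) -> float:
    """Assumed cost model: every plan step costs one ~2 s call to the large model."""
    return steps * LARGE_MODEL_STEP_SECONDS


for steps in (1, 2, 3):
    baseline = large_model_latency(steps)
    speedup = baseline / REAPER_PLAN_SECONDS
    print(f"{steps} step(s): {baseline:.1f} s vs {REAPER_PLAN_SECONDS:.3f} s (~{speedup:.1f}x)")
```

Under this assumption the gap is already about 9.7x for a one-step plan. It grows linearly with plan length: about 19.3x for two steps and 29.0x for three.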
Significance:
The drastic reduction in latency makes REAPER well suited to real-time applications such as conversational shopping assistants, where quick interaction is critical and faster responses directly improve user satisfaction.
Relation to Existing Work:
Existing RAG systems often suffer from high latency due to multiple steps of reasoning and retrieval. REAPER’s approach directly addresses this issue, making it a practical solution for deployment in live systems.
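The structural difference is how often the LLM is consulted. In an interleaved reason-then-retrieve loop, the model is called again after each retrieval. A single-shot planner is called once and then hands off to the tools. The sketch below is simplified: `CountingLLM` is a hypothetical stand-in that only counts calls, and the step strings are illustrative.

```python
class CountingLLM:
    """Hypothetical stand-in for an LLM endpoint that only counts how often it is called."""

    def __init__(self) -> None:
        self.calls = 0

    def generate(self, prompt: str) -> str:
        self.calls += 1
        return f"<output for: {prompt[:30]}>"


def interleaved_planning(llm: CountingLLM, steps: list[str]) -> None:
    # Multi-step style: the model is queried again after every retrieval
    # to decide the next step, so the number of calls grows with plan length.
    observation = ""
    for step in steps:
        observation = llm.generate(f"Previous result: {observation}. Next step: {step}")


def single_shot_planning(llm: CountingLLM, steps: list[str]) -> None:
    # Single-step style: one generation yields the whole plan; the tools
    # then run without any further LLM calls.
    llm.generate("Produce the full plan for: " + "; ".join(steps))


steps = ["prod_qna(price)", "prod_qna(warranty)", "review_summary"]
interleaved, single_shot = CountingLLM(), CountingLLM()
interleaved_planning(interleaved, steps)
single_shot_planning(single_shot, steps)
print(interleaved.calls, single_shot.calls)  # 3 1
```

Because each LLM call dominates latency, the number of calls is a reasonable first-order proxy for end-to-end planning time.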
3. High Accuracy in Tool Selection and Argument Generation
Explanation:
REAPER achieves 95% accuracy in tool selection and 92% accuracy in generating correct tool arguments. It demonstrates that a smaller LLM can maintain high performance levels in these critical tasks.
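Tool-selection and argument accuracy can be scored from predicted and gold plans. The sketch below assumes exact-match scoring per query; the paper's precise metric definitions may differ. Under this assumption, a query counts for tool accuracy if its sequence of tool names matches the gold plan. It counts for argument accuracy only if the full calls, names and arguments, match. The data is made up for illustration.

```python
Call = tuple[str, tuple[str, ...]]  # (tool name, arguments)


def plan_accuracies(gold: list[list[Call]], pred: list[list[Call]]) -> tuple[float, float]:
    """Return (tool-selection accuracy, argument accuracy) under exact-match scoring."""
    if len(gold) != len(pred):
        raise ValueError("gold and pred must cover the same queries")
    tool_hits = arg_hits = 0
    for g, p in zip(gold, pred):
        if [name for name, _ in g] == [name for name, _ in p]:
            tool_hits += 1
            if g == p:  # same tools and same arguments
                arg_hits += 1
    return tool_hits / len(gold), arg_hits / len(gold)


gold = [
    [("prod_qna", ("TechPro Laptop", "price"))],
    [("order_status", ("789",))],
    [("review_summary", ("SmartPhone X",))],
    [("prod_qna", ("SmartPhone X", "warranty")), ("prod_qna", ("SmartPhone X", "price"))],
]
pred = [
    [("prod_qna", ("TechPro Laptop", "price"))],   # exact match
    [("order_status", ("123",))],                  # right tool, wrong argument
    [("prod_qna", ("SmartPhone X", "reviews"))],   # wrong tool
    [("prod_qna", ("SmartPhone X", "warranty")), ("prod_qna", ("SmartPhone X", "price"))],
]
print(plan_accuracies(gold, pred))  # (0.75, 0.5)
```

With this definition, argument accuracy can never exceed tool accuracy. That ordering is consistent with the reported 95% and 92%.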
Significance:
High accuracy ensures that the chatbot can provide relevant and precise information to users, improving the overall effectiveness of the shopping assistant. This balance of efficiency and performance is essential for maintaining user trust and satisfaction.
Relation to Existing Work:
While traditional RAG systems can achieve high accuracy, they often do so at the expense of latency. REAPER’s ability to maintain high accuracy with reduced latency is a significant improvement.
Code
```bash
pip install openai scikit-learn numpy
```

```python
import os
import random
import re

import numpy as np
from openai import OpenAI
from sklearn.metrics.pairwise import cosine_similarity

# Initialize the OpenAI client (expects OPENAI_API_KEY in the environment)
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Simulated knowledge graph
knowledge_graph = {
    "products": {
        "laptop123": {"name": "TechPro Laptop", "price": 999, "category": "Electronics", "battery_life": "8 hours"},
        "phone456": {"name": "SmartPhone X", "price": 699, "category": "Electronics", "battery_life": "12 hours"},
    },
    "customer_orders": {
        "789": {"customer": "John Doe", "product": "laptop123", "status": "Shipped"},
    },
    "reviews": {
        "laptop123": [{"rating": 4.5, "text": "Great performance for the price!"}],
        "phone456": [{"rating": 4.0, "text": "Good phone, but battery life could be better."}],
    },
}


# Simulated retrieval tools
class KnowledgeGraphRetriever:
    def __init__(self, knowledge_graph):
        self.knowledge_graph = knowledge_graph

    def _get_product_id(self, product_name):
        # Map a product name (case-insensitive) to its ID; None if unknown
        for product_id, product_info in self.knowledge_graph["products"].items():
            if product_info["name"].lower() == product_name.lower():
                return product_id
        return None

    def prod_qna(self, product_id_or_name, query):
        product_id = self._get_product_id(product_id_or_name) or product_id_or_name
        product = self.knowledge_graph["products"].get(product_id)
        if product:
            # Normalize e.g. "Battery_Life" -> "battery life"
            query = query.lower().replace("_", " ").strip()
            if query == "price":
                return f"The price of {product['name']} is ${product['price']}."
            elif query == "battery life":
                return f"The battery life of {product['name']} is approximately {product.get('battery_life', '12 hours')}."
            elif query == "availability":
                return f"{product['name']} is currently {'in stock' if random.choice([True, False]) else 'out of stock'}."
            elif query == "return policy":
                return f"The return policy for {product['name']} is 30 days."
            elif query == "warranty":
                return f"The warranty for {product['name']} is {product.get('warranty', '1 year')}."
            else:
                return f"Product: {product['name']}, Price: ${product['price']}, Category: {product['category']}"
        return "Product not found"

    def order_status(self, order_id):
        order = self.knowledge_graph["customer_orders"].get(order_id)
        if order:
            return f"Order {order_id} for {order['product']} is {order['status']}"
        return "Order not found"

    def review_summary(self, product_id_or_name):
        product_id = self._get_product_id(product_id_or_name) or product_id_or_name
        reviews = self.knowledge_graph["reviews"].get(product_id, [])
        if reviews:
            avg_rating = sum(review['rating'] for review in reviews) / len(reviews)
            return f"Average rating for {product_id_or_name}: {avg_rating:.1f}, Sample review: {reviews[0]['text']}"
        return "No reviews found"

    # Alias methods matching the name variations produced by tevo()
    get_prod_qna = prod_qna_v2 = prod_qna
    get_review_summary = review_summary_v2 = review_summary
    get_order_status = order_status_v2 = order_status


retriever = KnowledgeGraphRetriever(knowledge_graph)

# Note: in the paper, TEvo, TTG, and DQS are data-generation techniques for training
# the planner. This demo applies TEvo and TTG at inference time purely for illustration.


# Tool Evolve (TEvo)
# Creates variations of tool names and descriptions to make the model robust to them
def tevo(tools):
    evolved_tools = []
    for tool in tools:
        name_variations = [tool['name'], f"{tool['name']}_v2", f"get_{tool['name']}"]
        desc_variations = [
            tool['description'],
            f"This tool {tool['description'].lower()}",
            f"Use this to {tool['description'].lower()}",
        ]
        evolved_tools.append({
            'name': random.choice(name_variations),
            'description': random.choice(desc_variations),
        })
    return evolved_tools


# Tool-Task Generator (TTG)
# Picks a secondary, retrieval-related task to exercise the model's reasoning
def ttg(primary_task):
    secondary_tasks = [
        "Explain why the tool in Step 1 is the most appropriate for answering the query.",
        "What additional information might be needed to provide a more comprehensive answer?",
        "If the primary tool fails, what would be an alternative approach to answer the query?",
        "How could this plan be expanded to provide more detailed information to the customer?",
        "Identify any assumptions made in this plan and explain how they might affect the response.",
    ]
    return random.choice(secondary_tasks)


# Diverse Query Sampler (DQS)
# Selects the additional queries that are least similar to every initial query
def get_embedding(text, model="text-embedding-3-small"):
    response = client.embeddings.create(input=[text], model=model)
    return response.data[0].embedding


def dqs(initial_queries, additional_queries, n_samples=5):
    initial_embeddings = [get_embedding(q) for q in initial_queries]
    additional_embeddings = [get_embedding(q) for q in additional_queries]
    # Shape: (len(initial_queries), len(additional_queries))
    similarities = cosine_similarity(initial_embeddings, additional_embeddings)
    # Rank candidates by their closest initial query and keep the least similar ones
    diverse_indices = np.argsort(similarities.max(axis=0))[:n_samples]
    return [additional_queries[i] for i in diverse_indices]


# REAPER prompt
# Builds the planner prompt with evolved tool descriptions; {query} and {context}
# are filled in later with str.format()
def get_reaper_prompt(tools):
    evolved_tools = tevo(tools)
    prompt = """You are an AI assistant for a retail store. Generate a single, coherent step-by-step plan to retrieve information from the knowledge graph to answer customer questions. Use these tools:
"""
    for tool in evolved_tools:
        prompt += f"{tool['name']}: {tool['description']}\n"
    prompt += """
Generate a plan using these tools. Each step should use only one tool and should be in the following format:
Step X: tool_name("arg1", "arg2")
If no tool is needed, state that the question can be answered without retrieval. Do not generate multiple plans.
Customer query: {query}
Context: {context}
Plan:
"""
    return prompt


# REAPER function
# Generates the retrieval plan in a single LLM call.
# gpt-4o is only a stand-in here; REAPER itself uses a smaller LLM as the planner.
def reaper(query, context="", model="gpt-4o"):
    tools = [
        {'name': 'prod_qna', 'description': 'Fetches specific information for a product.'},
        {'name': 'order_status', 'description': 'Retrieves the status of an order.'},
        {'name': 'review_summary', 'description': 'Provides a summary of reviews for a product.'},
    ]
    prompt = get_reaper_prompt(tools).format(query=query, context=context)
    # Apply TTG to create a secondary task
    secondary_task = ttg(query)
    prompt += f"\n\nAdditional task: {secondary_task}"
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.1,
    )
    return response.choices[0].message.content, secondary_task


# Execute plan function
# Parses "Step N: tool(args)" lines and calls the matching retriever method
def execute_plan(plan):
    results = []
    for line in plan.splitlines():
        # Drop markdown emphasis the model may add, e.g. "**Step 1:**"
        line = line.replace("*", "").strip()
        match = re.search(r'Step \d+:\s*(\w+)\s*\((.*?)\)', line)
        if not match:
            continue
        tool_name, raw_args = match.groups()
        # Split on commas and strip whitespace plus double or single quotes
        args = [arg.strip().strip('"\'') for arg in raw_args.split(',') if arg.strip()]
        if tool_name in ['prod_qna', 'prod_qna_v2', 'get_prod_qna']:
            if len(args) == 2:
                results.append(retriever.prod_qna(args[0], args[1]))
            else:
                results.append(f"Error: Invalid number of arguments for {tool_name}")
        elif tool_name in ['review_summary', 'review_summary_v2', 'get_review_summary']:
            if len(args) == 1:
                results.append(retriever.review_summary(args[0]))
            else:
                results.append(f"Error: Invalid number of arguments for {tool_name}")
        elif tool_name in ['order_status', 'order_status_v2', 'get_order_status']:
            if len(args) == 1:
                results.append(retriever.order_status(args[0]))
            else:
                results.append(f"Error: Invalid number of arguments for {tool_name}")
        else:
            results.append(f"Unknown tool: {tool_name}")
    return "\n".join(results) if results else "Could not execute any steps in the plan."


# Main execution
if __name__ == "__main__":
    # Initial set of queries
    initial_queries = [
        "What's the price of the TechPro Laptop?",
        "What's the status of my order number 789?",
    ]
    # Additional set of queries
    additional_queries = [
        "What's the battery life of the SmartPhone X?",
        "Can you summarize the reviews for the SmartPhone X?",
        "Is the TechPro Laptop available, and what's its return policy?",
        "What accessories are available for the SmartPhone X and how much do they cost?",
        "Compare the battery life of the SmartPhone X and the TechPro Laptop",
        "What's the warranty and price of the SmartPhone X?",
        "Are there any negative reviews about the battery life of the SmartPhone X?",
        "What's the status of my order 123 and when will it be delivered?",
        "Tell me about the camera quality and storage options for the SmartPhone X",
        "What's the most popular feature of the TechPro Laptop according to customer reviews?",
    ]
    print("Initial Queries:")
    for query in initial_queries:
        print(f"- {query}")
    print("\nAdditional Queries:")
    for query in additional_queries:
        print(f"- {query}")
    # Use DQS to select diverse queries
    diverse_queries = dqs(initial_queries, additional_queries)
    print("\nDiverse Queries selected by DQS:")
    for query in diverse_queries:
        print(f"- {query}")
    print("\nExecuting REAPER for Initial and Diverse Queries:")
    all_queries = initial_queries + diverse_queries
    for query in all_queries:
        print(f"\nQuery: {query}")
        plan, secondary_task = reaper(query)
        # Keep only the text after "Plan:" if the model echoes the header
        main_plan = plan.split("Plan:")[-1].strip() if "Plan:" in plan else plan
        print("REAPER Plan:")
        print(main_plan)
        print("\nAdditional Task:")
        print(secondary_task)
        print("\nExecution Result:")
        result = execute_plan(main_plan)
        print(result)
        print("=" * 70)
```
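The `dqs()` function above needs the OpenAI embeddings API. To check its selection rule offline, one option is to swap in a toy bag-of-words vector in place of real embeddings. This is a hypothetical substitute, and its rankings will differ from those of `text-embedding-3-small`. The rule itself is kept the same: score each candidate by its highest cosine similarity to any initial query, then keep the `n_samples` lowest scores. Ties are broken by original order here.

```python
import math
from collections import Counter


def bow_embedding(text: str) -> Counter:
    """Toy stand-in for an embedding model: lowercase word counts, punctuation removed."""
    cleaned = "".join(c.lower() if c.isalnum() or c.isspace() else " " for c in text)
    return Counter(cleaned.split())


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[w] * b[w] for w in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def dqs_offline(initial, additional, n_samples=2):
    """Same selection rule as dqs(), but with bag-of-words vectors instead of API embeddings."""
    init_vecs = [bow_embedding(q) for q in initial]
    scores = [max(cosine(bow_embedding(q), v) for v in init_vecs) for q in additional]
    ranked = sorted(range(len(additional)), key=lambda i: scores[i])
    return [additional[i] for i in ranked[:n_samples]]


initial = ["What's the price of the TechPro Laptop?", "What's the status of my order number 789?"]
additional = [
    "What's the price of the SmartPhone X?",
    "Can you summarize the reviews for the SmartPhone X?",
    "What's the status of my order 123?",
]
print(dqs_offline(initial, additional, n_samples=1))
```

The review-summary query shares almost no words with the initial queries, so it is picked first. The order-status query is a near-paraphrase of an initial query, so it is ranked last.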