Q*: Improving Multi-step Reasoning for LLMs with Deliberative Planning (Idea behind OpenAI o1)

Chaojie Wang, Yanchen Deng, Zhiyi Lyu, Liang Zeng, Jujie He, Shuicheng Yan, Bo An

The paper introduces Q*, a framework aimed at enhancing the multi-step reasoning capabilities of Large Language Models (LLMs) through deliberative planning. Conceptualizing reasoning as a Markov Decision Process (MDP), Q* leverages heuristic search in the style of the A* algorithm together with plug-and-play Q-value models to guide decision-making without task-specific fine-tuning. This approach addresses the limitations of LLMs’ auto-regressive generation, improving accuracy and consistency on complex tasks such as math problem-solving and code generation. Extensive empirical validation on datasets like GSM8K, MATH, and MBPP demonstrates that Q* significantly outperforms existing baselines. The Q* framework takes an approach similar to OpenAI’s advanced o1 model, which underscores its practical relevance and effectiveness. Researchers and graduate students will find this paper valuable for its theoretical grounding, innovative methodology, and strong empirical results. Future research could explore refining heuristic functions, enhancing scalability, and broadening task validation to further optimize and generalize the framework. The detailed and well-structured presentation of complex methodologies makes this paper an essential read for those looking to advance the capabilities of LLMs in multi-step reasoning tasks. Code in this blog was generated by combining OpenAI o1 and Claude 3.5 Sonnet.

Mind Map

A high-resolution PDF and an interactive HTML version can be found in the GitHub repo.

Highlights explained

1. Introduction of Q* Framework for LLMs

a. Explanation

The Q* framework is a novel deliberation and planning framework designed to enhance Large Language Models’ (LLMs) multi-step reasoning capabilities. It guides the decision-making process by leveraging a heuristic search approach without the need for task-specific fine-tuning.

b. Significance

The Q* framework addresses the limitations of LLMs’ auto-regressive generation process by introducing a more structured and logical approach to reasoning. This enhances the accuracy and consistency of LLMs in complex tasks, such as math problems and code generation, without incurring significant computational overhead.

c. Relation to Existing Work

Q* builds on foundational AI concepts like Markov Decision Processes (MDP) and heuristic search algorithms (e.g., A*) while avoiding the need for laborious prompt engineering. Its generalizability across tasks makes it a versatile tool in AI research and applications.

2. Formalization of Multi-step Reasoning as an MDP

a. Explanation

In Q*, multi-step reasoning is conceptualized as a Markov Decision Process (MDP), where the state is represented by the concatenation of the input prompt and the reasoning steps generated so far, the action is the next reasoning step, and the reward evaluates the correctness of the solution.

b. Significance

This formalization allows the reasoning process to be systematically guided and optimized, leveraging the well-established MDP framework. By using this structured approach, the Q* framework can ensure that each reasoning step moves closer to the correct solution.

c. Relation to Existing Work

MDPs are well-studied in reinforcement learning and decision-making fields. Q* innovatively applies this concept to LLMs, specifically enhancing their reasoning processes without requiring extensive modifications to the models.
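
To make the formalization concrete, here is a minimal sketch of the MDP view in Python. The class names, the deterministic transition, and the toy exact-match reward are illustrative assumptions on my part, not the paper's implementation.

Python
from dataclasses import dataclass, field
from typing import List

@dataclass
class ReasoningState:
    """State s_t: the input prompt plus all reasoning steps generated so far."""
    question: str
    steps: List[str] = field(default_factory=list)

    def as_text(self) -> str:
        # The LLM conditions on this concatenation when proposing the next step.
        return "\n".join([self.question] + self.steps)

def transition(state: ReasoningState, action: str) -> ReasoningState:
    """Deterministic transition: appending the chosen step yields the next state."""
    return ReasoningState(state.question, state.steps + [action])

def terminal_reward(state: ReasoningState, ground_truth: str) -> float:
    """Toy reward: 1 if the final step contains the correct answer, 0 otherwise."""
    return 1.0 if state.steps and ground_truth in state.steps[-1] else 0.0

Because the transition is deterministic, planning over this MDP reduces to searching over sequences of reasoning steps, which is exactly what the heuristic search in the next highlight exploits.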

3. Heuristic Function with Plug-and-Play Q-value Models

a. Explanation

Q* utilizes plug-and-play Q-value models as heuristic functions to estimate the expected utility of state-action pairs. These Q-value models can be trained using various methods, such as offline reinforcement learning, rollouts, or completions with stronger LLMs like GPT-4.

b. Significance

By providing a robust heuristic function, Q* can effectively guide LLMs in selecting the most promising next reasoning step. This approach reduces the need for extensive fine-tuning and allows for more efficient and accurate decision-making in complex tasks.

c. Relation to Existing Work

Heuristic search methods and Q-value estimation techniques are well-known in planning and reinforcement learning. Q* integrates these concepts into LLMs’ reasoning processes, creating a more efficient and effective framework for multi-step reasoning.
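
The interplay between aggregated path utility and the Q-value heuristic can be sketched as an A*-style best-first search. The callables expand, q_value, and path_utility below are placeholders standing in for the paper's step generator, plug-and-play Q-value model, and process-based utility; this is a sketch under those assumptions, not the authors' code.

Python
import heapq
from typing import Callable, List, Tuple

def a_star_reasoning(
    question: str,
    expand: Callable[[str], List[str]],          # proposes candidate next steps for a state
    q_value: Callable[[str, str], float],        # heuristic h: plug-and-play Q-value model
    path_utility: Callable[[List[str]], float],  # aggregated utility g of the path so far
    is_terminal: Callable[[str], bool],
    max_expansions: int = 50,
) -> List[str]:
    """Best-first search over reasoning paths, ranked by f = g + h as in A*."""
    # heapq is a min-heap, so scores are negated to pop the best state first.
    frontier: List[Tuple[float, str, List[str]]] = [(0.0, question, [])]
    while frontier and max_expansions > 0:
        max_expansions -= 1
        _neg_f, state, path = heapq.heappop(frontier)
        if is_terminal(state):
            return path
        for step in expand(state):
            new_path = path + [step]
            f = path_utility(new_path) + q_value(state, step)  # f(s) = g(s) + h(s)
            heapq.heappush(frontier, (-f, f"{state}\n{step}", new_path))
    return []

The PoC below simplifies this to a fixed-width beam search, keeping only the top-k scored states per iteration instead of maintaining a global priority queue.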

Code

The PoC code uses the small GPT-4o-mini model to illustrate the Q* process.

Bash
pip install pydantic openai
Python
"""
Q* Algorithm Implementation

Overview:
This implementation demonstrates the Q* algorithm, a novel approach to enhance multi-step reasoning
in Large Language Models (LLMs). The core idea is to formulate the reasoning process as a Markov
Decision Process (MDP) and use a heuristic search algorithm (inspired by A*) to guide the LLM
through the problem-solving steps.

Key components:
1. MathProblemSolver: Simulates an LLM solving math problems step-by-step.
2. QValueEstimator: Estimates the value of each potential next step.
3. QStar: Implements the main Q* algorithm, using beam search and Q-value guidance.

The algorithm works by generating multiple possible solution paths, evaluating them using
estimated Q-values, and selecting the most promising path at each step. This approach allows
for more deliberate and accurate problem-solving compared to standard auto-regressive LLM generation.

Thought flow:
1. Initialize the problem state
2. Generate possible next steps
3. Estimate Q-values for each step
4. Select the best steps using beam search
5. Repeat steps 2-4 until a solution is found or max steps are reached
6. Verify and potentially correct the final solution

This implementation serves as a proof of concept and can be extended to various reasoning tasks
beyond math problem solving.
"""

import os
import json
from typing import List, Optional
from pydantic import BaseModel
from openai import OpenAI
import re

# Initialize OpenAI client
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
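
# Note: the calls below use the legacy `functions`/`function_call` parameters of the
# Chat Completions API; they still work but are deprecated in favor of `tools`/`tool_choice`.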

class MathStep(BaseModel):
    step: str
    explanation: str

class MathSolution(BaseModel):
    steps: List[MathStep]
    final_answer: str
    numerical_answer: Optional[float] = None

class QValue(BaseModel):
    state: str
    action: str
    value: float

class MathProblemSolver:
    def solve(self, question: str) -> MathSolution:
        """
        Simulates an LLM solving a math problem step-by-step.
        
        Why here: This method encapsulates the LLM's problem-solving capability,
        allowing us to use it as a black box within the Q* algorithm. It's crucial
        for generating initial solutions and potential next steps.
        """
        print(f"MathProblemSolver: Solving question - {question}")
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are an expert math problem solver. Provide a step-by-step solution focusing on direct, logical steps leading to the final answer in the requested units."},
                {"role": "user", "content": f"Solve this math problem step by step and provide a clear numerical answer in the requested units: {question}"}
            ],
            functions=[{
                "name": "provide_solution",
                "description": "Provide a step-by-step solution to a math problem",
                "parameters": MathSolution.model_json_schema()
            }],
            function_call={"name": "provide_solution"}
        )
        solution = MathSolution.model_validate_json(response.choices[0].message.function_call.arguments)
        print(f"MathProblemSolver: Solution found with {len(solution.steps)} steps")
        return solution

class QValueEstimator:
    def estimate(self, state: str, action: str) -> QValue:
        """
        Estimates the Q-value for a given state-action pair.
        
        Why here: Q-value estimation is central to the Q* algorithm. It provides
        a heuristic for evaluating the potential of each step, guiding the search
        towards more promising solutions. This method allows us to leverage the
        LLM's knowledge for this crucial evaluation.
        """
        print(f"QValueEstimator: Estimating Q-value for action - {action[:30]}...")
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are an expert in evaluating problem-solving steps. Focus on steps that directly lead to the solution in the requested units."},
                {"role": "user", "content": f"Estimate the likelihood (0-1) that this step will lead to the correct solution in the requested units.\nCurrent state: {state}\nProposed step: {action}"}
            ],
            functions=[{
                "name": "estimate_q_value",
                "description": "Estimate the Q-value for a state-action pair",
                "parameters": QValue.model_json_schema()
            }],
            function_call={"name": "estimate_q_value"}
        )
        q_value = QValue.model_validate_json(response.choices[0].message.function_call.arguments)
        print(f"QValueEstimator: Q-value estimated as {q_value.value}")
        return q_value

class QStar:
    def __init__(self, solver: MathProblemSolver, q_estimator: QValueEstimator):
        self.solver = solver
        self.q_estimator = q_estimator
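        # Note: in this proof of concept the search loop builds solutions via
        # _generate_next_steps, so self.solver is stored but never invoked by solve().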

    def solve(self, question: str, max_steps: int = 7, beam_width: int = 3) -> MathSolution:
        """
        Implements the Q* algorithm to solve a math problem.
        
        Why here: This is the core of the Q* algorithm. It orchestrates the entire
        problem-solving process, combining beam search with Q-value guidance to
        explore and evaluate multiple solution paths simultaneously.
        """
        print(f"QStar: Starting to solve - {question}")
        initial_state = question
        beam = [(initial_state, [], 0)]  # Each element is (state, path, score)
        seen_steps = set()
        best_solution = None

        for step in range(max_steps):
            print(f"QStar: Step {step + 1}/{max_steps}")
            candidates = []
            for state, path, score in beam:
                # Generate and evaluate possible next steps
                next_steps = self._generate_next_steps(state, seen_steps)
                for next_step in next_steps:
                    if next_step.step not in seen_steps:
                        new_state = f"{state}\n{next_step.step}"
                        q_value = self.q_estimator.estimate(state, next_step.step)
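                        # Accumulate per-step Q-values: this running score plays the role
                        # of the aggregated utility g(s) that ranks beam candidates.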
                        new_score = score + q_value.value
                        candidates.append((new_state, path + [next_step], new_score))
                        seen_steps.add(next_step.step)
            
            if not candidates:
                print("QStar: No more candidates, breaking")
                break

            # Select top candidates (beam search)
            beam = sorted(candidates, key=lambda x: x[2], reverse=True)[:beam_width]
            print(f"QStar: Top candidate score: {beam[0][2]}")

            # Check if a complete solution is found
            current_solution = self._format_solution(beam[0][1])
            if current_solution.numerical_answer is not None:
                best_solution = current_solution
                if self._is_complete_solution(question, beam[0][0], current_solution):
                    print("QStar: Complete solution found")
                    return best_solution

        # If no complete solution found, use the best partial solution
        if best_solution is None:
            print("QStar: No complete solution found, using best partial solution")
            best_solution = self._format_solution(beam[0][1])
        
        # Verify and potentially correct the solution
        print("QStar: Verifying and potentially correcting the solution")
        verified_solution = self._verify_and_correct_solution(question, best_solution)
        return verified_solution

    def _generate_next_steps(self, state: str, seen_steps: set) -> List[MathStep]:
        """
        Generates possible next steps for the problem-solving process.
        
        Why here: This method is crucial for exploring the solution space. By leveraging
        the LLM to generate diverse and relevant steps, we can consider multiple
        problem-solving approaches, enhancing the algorithm's ability to find optimal solutions.
        """
        print("QStar: Generating next steps")
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are an expert math problem solver. Generate possible next steps, focusing on reaching a numerical answer in the requested units. Avoid repeating steps or adding unnecessary ones."},
                {"role": "user", "content": f"Generate 3 possible next steps for this math problem, focusing on reaching a numerical answer in the requested units:\n{state}\n\nAvoid these steps: {seen_steps}"}
            ],
            functions=[{
                "name": "generate_steps",
                "description": "Generate possible next steps for a math problem",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "steps": {
                            "type": "array",
                            "items": MathStep.model_json_schema()
                        }
                    },
                    "required": ["steps"]
                }
            }],
            function_call={"name": "generate_steps"}
        )
        
        function_call = response.choices[0].message.function_call
        if function_call and function_call.arguments:
            try:
                args = json.loads(function_call.arguments)
                steps = [MathStep.model_validate(step) for step in args.get("steps", [])]
            except json.JSONDecodeError:
                print("QStar: Error decoding JSON in _generate_next_steps")
                steps = []
        else:
            print("QStar: No function call in _generate_next_steps response")
            steps = []

        print(f"QStar: Generated {len(steps)} next steps")
        for i, step in enumerate(steps, 1):
            print(f"  Step {i}: {step.step}")
            print(f"    Explanation: {step.explanation}")
        return steps

    def _is_complete_solution(self, question: str, state: str, solution: MathSolution) -> bool:
        """
        Checks if the current solution is complete and correct.
        
        Why here: This method is essential for determining when to stop the search process.
        It ensures that we not only have a numerical answer but also that the solution
        directly addresses the original question, preventing unnecessary additional steps.
        """
        print("QStar: Checking if solution is complete")
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are an expert in evaluating math solutions. Ensure the solution directly answers the question in the requested units."},
                {"role": "user", "content": f"Does this represent a complete solution to the math problem with a clear numerical answer in the requested units? Answer with 'Yes' or 'No' and explain briefly.\n\nQuestion: {question}\n\nSolution:\n{state}\nFinal Answer: {solution.final_answer}\nNumerical Answer: {solution.numerical_answer}"}
            ]
        )
        answer = response.choices[0].message.content.strip().lower()
        # The prompt asks for a leading 'Yes'/'No', so check the start of the reply
        # rather than searching the whole explanation for the word "yes".
        is_complete = answer.startswith("yes") and solution.numerical_answer is not None
        print(f"QStar: Solution completeness: {is_complete}")
        return is_complete

    def _format_solution(self, steps: List[MathStep]) -> MathSolution:
        """
        Formats the solution steps into a MathSolution object.
        
        Why here: This method standardizes the solution format, ensuring consistency
        in how solutions are represented and making it easier to process and evaluate them.
        """
        if not steps:
            # Guard: the search can end with an empty path if no candidate steps were generated.
            return MathSolution(steps=[], final_answer="No solution found", numerical_answer=None)
        final_step = steps[-1]
        numerical_answer = self._extract_numerical_answer(" ".join(step.step for step in steps))
        return MathSolution(steps=steps, final_answer=final_step.step, numerical_answer=numerical_answer)

    def _extract_numerical_answer(self, text: str) -> Optional[float]:
        """
        Extracts the numerical answer from the solution text.
        
        Why here: Accurate extraction of the numerical answer is crucial for evaluating
        the correctness of solutions. This method uses regex patterns to identify
        numerical answers in various formats, improving the robustness of the solution evaluation.
        """
        # First, look for patterns like "x = 7" or "The answer is 42 km/h" at the end of the text
        match = re.search(r'(?:x\s*=\s*|answer\s*is\s*|equals\s*|result\s*is\s*)(\d+(?:\.\d+)?)\s*(?:km/h|m/s|mph)?\s*$', text, re.IGNORECASE)
        if match:
            return float(match.group(1))
        
        # If not found at the end, search throughout the text
        match = re.search(r'(?:x\s*=\s*|answer\s*is\s*|equals\s*|result\s*is\s*)(\d+(?:\.\d+)?)\s*(?:km/h|m/s|mph)?', text, re.IGNORECASE)
        if match:
            return float(match.group(1))
        
        # If still not found, look for the last number in the text
        numbers = re.findall(r'\d+(?:\.\d+)?', text)
        if numbers:
            return float(numbers[-1])
        
        return None

    def _verify_and_correct_solution(self, question: str, solution: MathSolution) -> MathSolution:
        """
        Verifies the solution and corrects it if necessary.
        
        Why here: This final verification step is crucial for ensuring the accuracy
        of the solution. It leverages the LLM's capability to check the solution's
        correctness and make necessary corrections, adding an extra layer of reliability
        to the Q* algorithm's output.
        """
        print("QStar: Verifying and correcting solution")
        verification_prompt = f"""
        Question: {question}
        
        Proposed solution steps:
        {' '.join(step.step for step in solution.steps)}
        
        Proposed final answer: {solution.final_answer}
        Extracted numerical answer: {solution.numerical_answer}
        
        Please verify if the solution and numerical answer are correct and in the requested units. If not, provide the correct numerical answer and explain the correction. If possible, perform a simple calculation to check the answer.
        """
        
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are an expert math problem solver and verifier. Ensure the answer is correct and in the requested units."},
                {"role": "user", "content": verification_prompt}
            ]
        )
        
        verification_result = response.choices[0].message.content.strip()
        
        if "correct" in verification_result.lower():
            print("QStar: Solution verified as correct")
            return solution
        else:
            print("QStar: Solution needs correction")
            corrected_answer = self._extract_numerical_answer(verification_result)
            if corrected_answer is not None:
                solution.numerical_answer = corrected_answer
                solution.final_answer = f"The correct answer is {corrected_answer}"
                solution.steps.append(MathStep(step=f"Correction: {corrected_answer}", explanation=verification_result))
            return solution

def main():
    """
    Main function to demonstrate the Q* algorithm on sample math problems.
    
    Why here: This function serves as an entry point and demonstration of how
    the Q* algorithm can be applied to various math problems. It showcases the
    versatility of the approach across different problem types.
    """
    solver = MathProblemSolver()
    q_estimator = QValueEstimator()
    q_star = QStar(solver, q_estimator)

    problems = [
        "If a train travels 120 km in 2 hours, what is its average speed in km/h?",
        "A rectangle has a length of 8 cm and a width of 5 cm. What is its area?",
        "If x + 3 = 10, what is the value of x?",
        "A store offers a 20% discount on a $50 item. What is the final price?",
        "If 3x - 7 = 14, solve for x."
    ]

    for i, problem in enumerate(problems, 1):
        print(f"\nSolving Problem {i}: {problem}")
        solution = q_star.solve(problem)
        print(f"\nProblem {i}: {problem}")
        print("Solution:")
        for step in solution.steps:
            print(f"- {step.step}")
            print(f"  Explanation: {step.explanation}")
        print(f"Final Answer: {solution.final_answer}")
        if solution.numerical_answer is not None:
            print(f"Numerical Answer: {solution.numerical_answer}")
        print("-" * 50)

if __name__ == "__main__":
    main()

"""
Conclusion and Key Takeaways:

1. MDP Formulation: The Q* algorithm formulates multi-step reasoning as a Markov
   Decision Process, allowing for a structured approach to problem-solving.

2. Heuristic Search: By using Q-values as a heuristic, the algorithm efficiently
   navigates the solution space, prioritizing promising paths.

3. LLM Integration: The implementation leverages LLMs for multiple aspects of
   problem-solving, including step generation, Q-value estimation, and solution
   verification.

4. Beam Search: The use of beam search allows the algorithm to explore multiple
   solution paths simultaneously, increasing the chances of finding optimal solutions.

5. Verification Step: The final verification and correction step adds an extra
   layer of reliability to the solutions produced by the algorithm.

6. Flexibility: This implementation demonstrates how the Q* approach can be applied
   to various types of math problems, showcasing its potential for broader applications
   in multi-step reasoning tasks.

7. Interpretability: By breaking down the problem-solving process into discrete steps
   and providing explanations, the Q* algorithm offers a more interpretable approach
   compared to end-to-end black-box models.

Future Directions:
- Extend the implementation to other domains beyond math problem solving.
- Explore more sophisticated Q-value estimation techniques, possibly incorporating
  reinforcement learning approaches.
- Investigate ways to optimize the algorithm for better efficiency, especially for
  more complex, multi-step reasoning tasks.
- Develop methods to incorporate domain-specific knowledge or constraints into the
  Q* framework for specialized applications.

This implementation serves as a proof of concept for the Q* algorithm, demonstrating
its potential to enhance the multi-step reasoning capabilities of large language models.
"""

3 thoughts on “Q*: Improving Multi-step Reasoning for LLMs with Deliberative Planning (Idea behind OpenAI o1)”

    1. Very possible. Q* only demonstrated a fairly simple search strategy; surely more complex and smarter branching and pruning methods could be used. I expect fancier evaluation functions as well, like the policy and value evaluations from AlphaGo’s MCTS.
