Agent Q: Advanced Reasoning and Learning for Autonomous AI Agents

Pranav Putta, Edmund Mills, Naman Garg, Sumeet Motwani, Chelsea Finn, Divyansh Garg, Rafael Rafailov

The paper “Agent Q: Advanced Reasoning and Learning for Autonomous AI Agents” presents a framework for autonomous AI agents that perform web navigation and multi-step decision-making tasks by combining Monte Carlo Tree Search (MCTS) with Direct Preference Optimization (DPO). MCTS, the heuristic search algorithm used in AlphaGo, lets the agent explore and evaluate multiple candidate actions before committing to one, which improves its decisions in dynamic and partially observable environments such as web pages. Combined with DPO’s feedback-driven fine-tuning, this lets the agent iteratively refine its policy. The authors report improved performance in both a simulated environment (WebShop) and a real-world one (OpenTable), outperforming baseline methods and, in their reported results, even human benchmarks.

Self-critique mechanisms additionally let the agent learn from its own past actions, supporting continued improvement and adaptation to new tasks. The work illustrates the potential of integrating search techniques with reinforcement learning for real-world, web-based tasks. The methodology and baseline comparisons in the paper support the effectiveness of the proposed framework and offer a foundation for future research. Possible next steps include dynamic critic models, more diverse web environments, and real-time adaptation.

Mind Map

graph LR

    root["Agent Q: Advanced Reasoning and Learning for Autonomous AI Agents"]
    
    root --> branch1["Research Question/Objective"]
    root --> branch2["Methodology"]
    root --> branch3["Key Findings/Contributions"]
    root --> branch4["Data and Analysis"]
    root --> branch5["Results and Discussion"]
    root --> branch6["Implications"]
    root --> branch7["Limitations"]
    root --> branch8["Future Research Directions"]
    
    branch1 -.-> leaf1["Enhance AI agents' decision-making in web navigation"]
    branch1 -.-> leaf2["Integration of MCTS and DPO"]
    
    branch2 -.-> leaf3["Framework Development"]
    branch2 -.-> leaf4["Self-Critique Mechanisms"]
    branch2 -.-> leaf5["Iterative Fine-Tuning"]
    
    branch3 -.-> leaf6["Superior performance in WebShop"]
    branch3 -.-> leaf7["Real-world validation with OpenTable"]
    branch3 -.-> leaf8["Algorithmic Innovation"]
    
    branch4 -.-> leaf9["Input/Output format"]
    branch4 -.-> leaf10["Use of GPT-4-V for feedback"]
    
    branch5 -.-> leaf11["Success rates in WebShop and OpenTable"]
    branch5 -.-> leaf12["Comparison with baselines"]
    
    branch6 -.-> leaf13["Potential for broad deployment in web industry"]
    
    branch7 -.-> leaf14["High complexity"]
    branch7 -.-> leaf15["Reliance on frozen critic model"]
    branch7 -.-> leaf16["Scalability concerns"]
    
    branch8 -.-> leaf17["Dynamic Critic Models"]
    branch8 -.-> leaf18["Diverse Web Environments"]
    branch8 -.-> leaf19["Safety and Risk Management"]
    branch8 -.-> leaf20["Integration with Other RL Methods"]
    branch8 -.-> leaf21["Real-Time Adaptation"]

Highlights explained

1. Integration of Monte Carlo Tree Search (MCTS) for Agent Planning

Explanation:

Monte Carlo Tree Search (MCTS) is a heuristic search algorithm for sequential decision processes, best known from game-playing systems such as AlphaGo. In this paper, MCTS is used to improve the decision-making of autonomous AI agents navigating complex web environments.

Significance:

The use of MCTS allows the agent to systematically explore multiple potential actions and their consequences before committing to a decision. This method helps in efficiently handling the vast decision space inherent in web navigation tasks, thus enabling better performance and more informed decision-making.
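The exploration step can be illustrated with the UCB1 rule, which the code later in this post also uses to choose among child nodes. The following is a minimal sketch, not the paper's implementation: the helper names `ucb1` and `select_action` and the action statistics are hypothetical, and the exploration constant 1.41 is simply the value used in the listing below.

```python
import math

def ucb1(total_value, visits, parent_visits, c=1.41):
    """UCB1 score: average value plus an exploration bonus that shrinks with visits."""
    if visits == 0:
        return float("inf")  # unvisited actions are always tried first
    exploitation = total_value / visits
    exploration = c * math.sqrt(math.log(parent_visits) / visits)
    return exploitation + exploration

def select_action(stats, parent_visits):
    """Pick the action with the highest UCB1 score."""
    return max(stats, key=lambda a: ucb1(stats[a][0], stats[a][1], parent_visits))

# Hypothetical statistics per action: (total_value, visits)
stats = {
    "search laptop": (2.4, 4),
    "view_cart": (0.3, 3),
    "checkout": (0.0, 0),
}
parent_visits = sum(visits for _, visits in stats.values())
print(select_action(stats, parent_visits))
```

Because "checkout" has never been visited, it receives an infinite score and is selected; once every action has been tried, the bonus term steers the search toward actions that are either promising or rarely explored.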

Context and Impact:

MCTS’s integration into AI planning for web navigation tasks represents a novel application of this technique, traditionally used in discrete, fully observable game environments. It showcases the versatility of MCTS and its potential for improving the reliability and effectiveness of autonomous agents in dynamic, partially observable environments like web pages.

2. Combination of MCTS with Direct Preference Optimization (DPO)

Explanation:

Direct Preference Optimization (DPO) fine-tunes a policy directly on pairs of preferred and rejected outputs, without training a separate reward model; the paper uses an off-policy variant. In this paper, DPO is combined with MCTS to iteratively refine the agent’s decision-making model using logged feedback on past decisions.
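For a single preference pair, the standard DPO loss is the negative log-sigmoid of the scaled difference between the policy-to-reference log-ratios of the preferred and rejected actions. The sketch below computes it from log-probabilities; the function name `dpo_loss` and the example probabilities are illustrative assumptions, not values from the paper.

```python
import math

def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=1.0):
    """DPO loss for one (preferred, rejected) pair, given log-probabilities.

    logp_*     : log-probability under the policy being trained
    ref_logp_* : log-probability under the frozen reference policy
    """
    margin = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
    # -log(sigmoid(margin)) == softplus(-margin), written in a numerically stable form
    x = -margin
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

# Hypothetical probabilities: the policy prefers the winning action more than the reference does
loss = dpo_loss(math.log(0.8), math.log(0.2), math.log(0.5), math.log(0.5))
print(round(loss, 4))
```

When the policy and the reference agree, the margin is zero and the loss equals log 2; the loss falls as the policy shifts probability toward the preferred action relative to the reference.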

Significance:

This combination allows the agent to benefit from both the broad exploration capabilities of MCTS and the focused, feedback-driven improvements of DPO. It leads to enhanced learning efficiency and more robust performance by integrating exploration and exploitation in the training process.
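One way to connect the two components is to turn the value estimates that the search assigns to sibling actions into preference pairs for DPO. The following is a simplified sketch under stated assumptions: the function name `preference_pairs`, the threshold of 0.1, and the example values are hypothetical and mirror the simplified code later in this post rather than the paper's exact procedure.

```python
from itertools import combinations

def preference_pairs(observation, action_values, threshold=0.1):
    """Build (observation, preferred, rejected) triples from per-action value estimates.

    Pairs whose values differ by no more than `threshold` are skipped,
    since the search gives no clear preference between them.
    """
    pairs = []
    for (a1, v1), (a2, v2) in combinations(action_values.items(), 2):
        if abs(v1 - v2) > threshold:
            winner, loser = (a1, a2) if v1 > v2 else (a2, a1)
            pairs.append((observation, winner, loser))
    return pairs

# Hypothetical value estimates produced by a search at one page
values = {"search laptop": 0.8, "view_cart": 0.3, "checkout": 0.25}
pairs = preference_pairs("home page", values)
for _, winner, loser in pairs:
    print(f"prefer {winner!r} over {loser!r}")
```

Here "view_cart" and "checkout" are too close in value to form a pair, so only the two comparisons involving "search laptop" are kept.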

Context and Impact:

By leveraging DPO alongside MCTS, the paper introduces a novel hybrid approach that improves upon traditional reinforcement learning methods. This integration could influence future work in reinforcement learning by demonstrating the effective synergy between search-based exploration and preference-based optimization.

3. Self-Critique Mechanisms for Enhanced Learning

Explanation:

The paper introduces self-critique mechanisms where the agent evaluates its own decisions and the proposed actions during the search process, using an AI-based critic model. This feedback is then used to guide future actions and refine the decision-making process.
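As a rough illustration of how critic feedback might be used, the sketch below parses a numeric rating from a critic's text reply and blends it with a search-based value estimate. The helper names `parse_rating` and `blended_value`, the "Rating:" reply format, the default score, and the weight `alpha` are all assumptions made for this example; a simple weighted average is used here and is not claimed to be the paper's exact formulation.

```python
import re

def parse_rating(critique, default=0.5):
    """Extract a rating in [0, 1] from critic text such as 'Rating: 0.8'.

    Returns `default` when no rating can be found, and clamps out-of-range values.
    """
    match = re.search(r"Rating:\s*\[?\s*([0-9]*\.?[0-9]+)", critique)
    if match is None:
        return default
    return min(1.0, max(0.0, float(match.group(1))))

def blended_value(search_value, critic_rating, alpha=0.5):
    """Weighted average of a search-based value and a critic rating."""
    return alpha * search_value + (1.0 - alpha) * critic_rating

# Hypothetical critic reply for a proposed action
reply = "Rating: 0.8\nExplanation: Searching is a sensible first step on the home page."
rating = parse_rating(reply)
print(rating, blended_value(0.2, rating))
```

Falling back to a neutral default keeps a malformed critic reply from crashing the search, at the cost of giving that action an uninformative score.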

Significance:

Self-critique allows the agent to autonomously identify and learn from its mistakes, leading to continuous performance improvements. It enhances the agent’s ability to adapt to new tasks and environments by iteratively refining its strategy based on past performance.

Context and Impact:

Self-critique mechanisms represent an important shift towards autonomous, self-improving AI systems. This method not only improves immediate task performance but also aids in generalization across different scenarios, potentially reducing the need for extensive human supervision and retraining.

Code

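The following code trains AgentQ with MCTS on a simulated WebShop example. The DPO step uses a random number generator as a stand-in for the actual policy and reference model probabilities. The script expects the OPENAI_API_KEY environment variable to be set, and rendering the tree visualizations requires the Graphviz system binaries in addition to the graphviz Python package. Adjust the number of simulations, the search depth, and the number of episodes as needed.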

Shell
pip install graphviz openai numpy
Python
import os
import random
import numpy as np
from typing import List, Tuple, Dict
import openai
from collections import deque
from openai import OpenAI

from graphviz import Digraph
import time 
import colorsys

# Set up OpenAI API key
openai.api_key = os.environ["OPENAI_API_KEY"]

print("Agent Q: Advanced Reasoning and Learning for Autonomous AI Agents")
print("================================================================")

class WebShop:
    def __init__(self):
        self.items = [
            {"id": 1, "name": "Laptop", "price": 999, "category": "Electronics"},
            {"id": 2, "name": "Smartphone", "price": 599, "category": "Electronics"},
            {"id": 3, "name": "Headphones", "price": 199, "category": "Electronics"},
            {"id": 4, "name": "T-shirt", "price": 29, "category": "Clothing"},
            {"id": 5, "name": "Jeans", "price": 59, "category": "Clothing"},
        ]
        self.current_page = "home"
        self.cart = []
        self.search_query = ""
        self.current_item = None  # Initialize current_item

    def get_observation(self) -> str:
        if self.current_page == "home":
            return """
            <html>
                <body>
                    <h1>WebShop Home</h1>
                    <input type="text" id="search_bar" placeholder="Search for items">
                    <button id="search_button">Search</button>
                    <button id="view_cart">View Cart</button>
                </body>
            </html>
            """
        elif self.current_page == "search_results":
            items_html = "".join([f'<div class="item" id="item_{item["id"]}">{item["name"]} - ${item["price"]}</div>' for item in self.items if self.search_query.lower() in item["name"].lower()])
            return f"""
            <html>
                <body>
                    <h1>Search Results for "{self.search_query}"</h1>
                    {items_html}
                    <button id="back_to_home">Back to Home</button>
                </body>
            </html>
            """
        elif self.current_page == "item_details":
            if self.current_item is not None:
                item = self.items[self.current_item]
                return f"""
                <html>
                    <body>
                        <h1>{item["name"]}</h1>
                        <p>Price: ${item["price"]}</p>
                        <p>Category: {item["category"]}</p>
                        <button id="add_to_cart">Add to Cart</button>
                        <button id="back_to_results">Back to Results</button>
                    </body>
                </html>
                """
            else:
                return "Error: No item selected"
        elif self.current_page == "cart":
            cart_items = "".join([f'<div class="cart_item">{item["name"]} - ${item["price"]}</div>' for item in self.cart])
            return f"""
            <html>
                <body>
                    <h1>Shopping Cart</h1>
                    {cart_items}
                    <button id="checkout">Checkout</button>
                    <button id="back_to_home">Back to Home</button>
                </body>
            </html>
            """

    def take_action(self, action: str) -> Tuple[str, float, bool]:
        print(f"Taking action: {action}")
        if action.startswith("search"):
            self.search_query = action.split(" ", 1)[1]
            self.current_page = "search_results"
            return self.get_observation(), 0, False
        elif action.startswith("view_item"):
            item_id = int(action.split()[-1])
            self.current_item = item_id - 1
            self.current_page = "item_details"
            return self.get_observation(), 0, False
        elif action == "add_to_cart":
            if self.current_item is not None:
                self.cart.append(self.items[self.current_item])
                return f"Added {self.items[self.current_item]['name']} to cart", 0.1, False
            else:
                return "Error: No item selected", -0.1, False
        elif action == "view_cart":
            self.current_page = "cart"
            return self.get_observation(), 0, False
        elif action == "checkout":
            if len(self.cart) > 0:
                return "Checkout successful", 1, True
            else:
                return "Cart is empty", -0.1, False
        elif action == "back_to_home":
            self.current_page = "home"
            self.current_item = None
            return self.get_observation(), 0, False
        elif action == "back_to_results":
            self.current_page = "search_results"
            self.current_item = None
            return self.get_observation(), 0, False
        else:
            return "Invalid action", -0.1, False

class DPO:
    def __init__(self, beta=1.0):
        self.beta = beta
        self.reference_model = None  # This should be initialized with the initial policy

    def optimize(self, preference_pairs):
        # Implement Direct Preference Optimization
        print("Optimizing policy using Direct Preference Optimization (DPO)")
        losses = []
        for h, a_w, a_l in preference_pairs:
            # In a full implementation, these would be computed using the actual policy and reference model
            pi_w, pi_l = np.random.random(), np.random.random()
            ref_w, ref_l = np.random.random(), np.random.random()
            
            loss = -np.log(1 / (1 + np.exp(self.beta * (np.log(pi_l / ref_l) - np.log(pi_w / ref_w)))))
            losses.append(loss)
        
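        if not losses:
            # No pair passed the value threshold, so there is nothing to average
            print("No preference pairs; skipping DPO update")
            return 0.0

        avg_loss = np.mean(losses)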
        print(f"Average DPO loss: {avg_loss}")
        return avg_loss

class ReplayBuffer:
    def __init__(self, max_size=10000):
        self.buffer = deque(maxlen=max_size)

    def add(self, state, action, reward, next_state, done):
        # Add experience to the replay buffer
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Sample a batch of experiences from the replay buffer
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

class AgentQ:
    def __init__(self, env: WebShop, mcts_simulations: int = 10, mcts_max_depth: int = 5, max_steps_per_episode: int = 100):
        self.env = env
        self.mcts = MCTS(env, num_simulations=mcts_simulations, max_depth=mcts_max_depth)
        self.dpo = DPO()
        self.replay_buffer = ReplayBuffer()
        self.value_threshold = 0.1
        self.episode_trees = []  # Store MCTS trees for each step
        self.max_steps_per_episode = max_steps_per_episode

    def get_action(self, observation: str) -> str:
        print("Agent Q: Selecting action using MCTS")
        action, root = self.mcts.search(observation)
        self.episode_trees.append(root)  # Store the root of the MCTS tree
        return action

    def generate_preference_pairs(self, node):
        print("Generating preference pairs for DPO")
        pairs = []
        for i, child1 in enumerate(node.children):
            for j, child2 in enumerate(node.children[i+1:]):
                if abs(child1.value() - child2.value()) > self.value_threshold:
                    if child1.value() > child2.value():
                        pairs.append((node.observation, child1.action, child2.action))
                    else:
                        pairs.append((node.observation, child2.action, child1.action))
        print(f"Generated {len(pairs)} preference pairs")
        return pairs

    def train(self, num_episodes: int):
        print(f"Training Agent Q for {num_episodes} episodes")
        start_time = time.time()
        total_steps = 0
        total_rewards = []

        for episode in range(num_episodes):
            episode_start_time = time.time()
            print(f"\nEpisode {episode + 1}/{num_episodes}")
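            # WebShop defines no reset method, so re-run __init__ to restore the home page and an empty cart
            self.env.__init__()
            observation = self.env.get_observation()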
            done = False
            total_reward = 0
            step = 0
            self.episode_trees = []  # Reset episode trees for new episode

            while not done and step < self.max_steps_per_episode:
                action = self.get_action(observation)
                next_observation, reward, done = self.env.take_action(action)
                self.replay_buffer.add(observation, action, reward, next_observation, done)
                total_reward += reward
                
                preference_pairs = self.generate_preference_pairs(self.episode_trees[-1])
                self.dpo.optimize(preference_pairs)
                
                observation = next_observation
                step += 1
                total_steps += 1

                # Print progress every 10 steps
                if step % 10 == 0:
                    print(f"  Step {step}: Current reward = {total_reward:.2f}")

            if step >= self.max_steps_per_episode:
                print(f"  Episode terminated after reaching max steps ({self.max_steps_per_episode})")

            episode_duration = time.time() - episode_start_time
            total_rewards.append(total_reward)
            avg_reward = sum(total_rewards) / len(total_rewards)
            
            # Visualize MCTS trees for this episode
            self.visualize_episode_mcts_trees(episode + 1, max_depth=3, max_children=5)

            print(f"Episode {episode + 1} completed:")
            print(f"  Total reward: {total_reward:.2f}")
            print(f"  Steps taken: {step}")
            print(f"  Episode duration: {episode_duration:.2f} seconds")
            print(f"  Average reward so far: {avg_reward:.2f}")
            
            self.visualize_episode(episode + 1)  # Visualize the entire episode

        total_duration = time.time() - start_time
        print("\nTraining completed:")
        print(f"Total episodes: {num_episodes}")
        print(f"Total steps: {total_steps}")
        print(f"Average steps per episode: {total_steps / num_episodes:.2f}")
        print(f"Average reward: {sum(total_rewards) / len(total_rewards):.2f}")
        print(f"Total duration: {total_duration:.2f} seconds")
        print(f"Average duration per episode: {total_duration / num_episodes:.2f} seconds")

    def visualize_full_mcts_tree(self, episode_number, step_number, max_depth=None, max_children=None):
        dot = Digraph(comment=f'Episode {episode_number}, Step {step_number} Full MCTS Tree')
        dot.attr(rankdir='TB', size='30,30')

        def add_node(node, parent_id=None, depth=0):
            if max_depth is not None and depth > max_depth:
                return

            node_id = str(id(node))
            label = f"{node.action if node.action else 'root'}\n"
            label += f"Visits: {node.visits}\n"
            label += f"Value: {node.value():.2f}"

            color = self.get_color_for_value(node.value())
            shape = self.get_node_shape(node.action) if node.action else 'doubleoctagon'

            dot.node(node_id, label, style='filled', color=color, shape=shape)

            if parent_id:
                dot.edge(parent_id, node_id)

            children = node.children
            if max_children is not None and len(children) > max_children:
                # If we have too many children, select a subset
                children = sorted(children, key=lambda c: c.visits, reverse=True)[:max_children]
                children.append(MCTSNode("...", parent=node, action="(more)"))

            for child in children:
                add_node(child, node_id, depth + 1)

        root = self.episode_trees[step_number]
        add_node(root)

        # Add a legend
        with dot.subgraph(name='cluster_legend') as c:
            c.attr(label='Legend')
            c.node('legend_high', 'High Value', style='filled', color=self.get_color_for_value(1.0))
            c.node('legend_med', 'Medium Value', style='filled', color=self.get_color_for_value(0.5))
            c.node('legend_low', 'Low Value', style='filled', color=self.get_color_for_value(0.0))
            c.node('legend_search', 'Search Action', shape='diamond')
            c.node('legend_view', 'View Action', shape='ellipse')
            c.node('legend_cart', 'Cart Action', shape='box')

        filename = f'episode_{episode_number}_step_{step_number}_full_mcts_tree'
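        # view=False: one tree is rendered per step, so opening a viewer each time would flood the desktop
        dot.render(filename, view=False, format='png', cleanup=True)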
        print(f"Full MCTS tree visualization saved as {filename}.png")

    def visualize_episode_mcts_trees(self, episode_number, max_depth=None, max_children=None):
        for step, tree in enumerate(self.episode_trees):
            self.visualize_full_mcts_tree(episode_number, step, max_depth, max_children)
    
    def visualize_episode(self, episode_number):
        dot = Digraph(comment=f'Episode {episode_number} MCTS Trees')
        dot.attr(rankdir='LR', size='30,30')

        # Add a legend
        with dot.subgraph(name='cluster_legend') as c:
            c.attr(label='Legend')
            c.node('legend_high', 'High Value', style='filled', color=self.get_color_for_value(1.0))
            c.node('legend_med', 'Medium Value', style='filled', color=self.get_color_for_value(0.5))
            c.node('legend_low', 'Low Value', style='filled', color=self.get_color_for_value(0.0))
            c.node('legend_search', 'Search Action', shape='diamond')
            c.node('legend_view', 'View Action', shape='ellipse')
            c.node('legend_cart', 'Cart Action', shape='box')

        for step, root in enumerate(self.episode_trees):
            with dot.subgraph(name=f'cluster_{step}') as c:
                c.attr(label=f'Step {step + 1}')
                self.add_nodes_edges(root, c, f's{step}_')

        # view=False: save the image without opening a viewer window for every episode
        dot.render(f'episode_{episode_number}_mcts_trees', view=False, format='png', cleanup=True)

    def add_nodes_edges(self, node, graph, prefix, parent_id=None, depth=0):
        if depth > self.mcts.max_depth:
            return

        node_id = f"{prefix}{id(node)}"
        label = f"{node.action if node.action else 'root'}\n"
        label += f"Visits: {node.visits}\n"
        label += f"Value: {node.value():.2f}"
        
        color = self.get_color_for_value(node.value())
        shape = self.get_node_shape(node.action) if node.action else 'doubleoctagon'
        
        graph.node(node_id, label, style='filled', color=color, shape=shape)
        
        if parent_id:
            graph.edge(parent_id, node_id)
        
        for child in node.children:
            self.add_nodes_edges(child, graph, prefix, node_id, depth + 1)

    def get_color_for_value(self, value):
        # Use HSV color space for a smooth transition from red to yellow to green
        hue = value * 0.3  # This will give a range from red (0) to green (0.3)
        saturation = 0.7  # Reduce saturation for less intense colors
        brightness = 0.9  # Keep brightness high but not maximum (separate name avoids shadowing `value`)
        r, g, b = colorsys.hsv_to_rgb(hue, saturation, brightness)
        return f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"

    def get_node_shape(self, action):
        if action.startswith('search'):
            return 'diamond'
        elif action.startswith('view'):
            return 'ellipse'
        elif action in ['add_to_cart', 'checkout']:
            return 'box'
        else:
            return 'oval'
    
    def self_critique(self, observation: str, action: str) -> float:
        print("Performing self-critique")
        prompt = f"""
        Given the current observation: "{observation}"
        And the proposed action: "{action}"
        
        Rate the quality of this action on a scale from 0 to 1, where 1 is the best possible action and 0 is the worst.
        Provide a brief explanation for your rating.
        
        Rating:
        Explanation:
        """
        
        client = OpenAI()
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are an AI assistant evaluating the quality of actions in a web shopping environment."},
                {"role": "user", "content": prompt}
            ]
        )
        
        critique = response.choices[0].message.content
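        try:
            rating = float(critique.split("Rating:")[1].split("\n")[0].strip())
            explanation = critique.split("Explanation:")[1].strip()
        except (IndexError, ValueError):
            # The model did not follow the requested format; fall back to a neutral score
            rating, explanation = 0.5, critique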
        print(f"Self-critique rating: {rating}")
        print(f"Explanation: {explanation}")
        return rating

class MCTSNode:
    def __init__(self, observation: str, parent: 'MCTSNode' = None, action: str = None):
        self.observation = observation
        self.parent = parent
        self.action = action
        self.children: List['MCTSNode'] = []
        self.visits = 0
        self.total_value = 0
        self.untried_actions = self.get_possible_actions()

    def get_possible_actions(self):
        # Define possible actions for the node
        return ["search laptop", "search smartphone", "view_item 1", "view_item 2", "add_to_cart", "view_cart", "checkout", "back_to_home"]

    def fully_expanded(self):
        # Check if all possible actions have been tried
        return len(self.untried_actions) == 0

    def get_untried_action(self):
        # Get an untried action
        if not self.untried_actions:
            return None
        return self.untried_actions.pop()

    def add_child(self, child_node):
        # Add a child node
        self.children.append(child_node)

    def update(self, value):
        # Update the node's statistics
        self.visits += 1
        self.total_value += value

    def value(self):
        # Calculate the node's value
        return self.total_value / self.visits if self.visits > 0 else 0

    def ucb_score(self, c=1.41):
        # Calculate the UCB1 score for the node
        if self.visits == 0:
            return float('inf')
        return self.value() + c * np.sqrt(np.log(self.parent.visits) / self.visits)

class MCTS:
    def __init__(self, env: WebShop, num_simulations: int = 10, max_depth: int = 5):
        self.env = env
        self.num_simulations = num_simulations
        self.max_depth = max_depth
        self.root = None

    def search(self, observation: str) -> Tuple[str, MCTSNode]:
        print(f"Starting MCTS search with {self.num_simulations} simulations and max depth {self.max_depth}")
        self.root = MCTSNode(observation)
        
        for i in range(self.num_simulations):
            print(f"Simulation {i + 1}/{self.num_simulations}")
            node = self.select(self.root)
            value = self.simulate(node)
            self.backpropagate(node, value)
        
        best_child = self.best_child(self.root)
        if best_child:
            print(f"MCTS search completed. Best action: {best_child.action}")
            return best_child.action, self.root
        else:
            print("MCTS search failed to find a best action. Returning a random action.")
            return random.choice(self.root.get_possible_actions()), self.root

    def select(self, node, depth=0):
        print(f"MCTS: Selection phase (depth {depth})")
        while node.children and depth < self.max_depth:
            if not node.fully_expanded():
                return self.expand(node, depth)
            else:
                node = self.ucb_select(node)
            depth += 1
        if depth < self.max_depth and not node.fully_expanded():
            return self.expand(node, depth)
        return node

    def expand(self, node, depth):
        print(f"MCTS: Expansion phase (depth {depth})")
        action = node.get_untried_action()
        if action is None:
            # If no untried actions, select a random action
            action = random.choice(node.get_possible_actions())
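        # Note: this steps the shared environment object, so the search itself changes
        # the live page and cart state; a fuller implementation would act on a copy.
        next_observation, reward, done = self.env.take_action(action)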
        child = MCTSNode(next_observation, parent=node, action=action)
        node.add_child(child)
        return child

    def simulate(self, node):
        print(f"MCTS: Simulation phase (depth {self.get_node_depth(node)})")
        return self.ai_process_supervision(node.observation, node.action)

    def backpropagate(self, node, value):
        print(f"MCTS: Backpropagation phase (starting from depth {self.get_node_depth(node)})")
        while node is not None:
            node.update(value)
            node = node.parent

    def get_node_depth(self, node):
        depth = 0
        while node.parent is not None:
            depth += 1
            node = node.parent
        return depth

    def best_child(self, node):
        if not node.children:
            return None
        return max(node.children, key=lambda c: c.visits)

    def ucb_select(self, node):
        return max(node.children, key=lambda c: c.ucb_score())

    def ai_process_supervision(self, observation: str, action: str) -> float:
        #print("Performing AI process supervision using GPT-4")
        prompt = f"""
        You are an AI assistant evaluating the quality of actions in a web shopping environment.
        
        Current webpage observation:
        {observation}
        
        Proposed action:
        {action}
        
        Rate the quality of this action on a scale from 0 to 1, where 1 is the best possible action and 0 is the worst.
        Consider factors such as relevance to the current page, progression towards a shopping goal, and adherence to typical web navigation patterns.
        
        Provide your rating and a brief explanation in the following format:
        
        Rating: [Your rating between 0 and 1]
        Explanation: [Your brief explanation]
        """
        
        try:
            client = OpenAI()
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You are an AI assistant evaluating web navigation actions."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=150
            )
            
            critique = response.choices[0].message.content
            rating_line = [line for line in critique.split('\n') if line.startswith("Rating:")][0]
            rating = float(rating_line.split(":")[1].strip())
            
            print(f"AI Process Supervision - Action: {action}, Rating: {rating}")
            print(f"Explanation: {critique.split('Explanation:')[1].strip()}")
            
            return rating
        except Exception as e:
            print(f"Error in AI process supervision: {e}")
            return random.random()  # Fallback to random rating in case of API error
    
    def get_color_for_value(self, value):
        # This function returns a color based on the node's value
        # Green for high values, red for low values
        r = int(255 * (1 - value))
        g = int(255 * value)
        b = 0
        return f"#{r:02x}{g:02x}{b:02x}"
    
    def visualize_tree(self):
        dot = Digraph(comment='MCTS Tree')
        dot.attr(rankdir='TB', size='8,8')
        
        def add_nodes_edges(node, parent_id=None, depth=0):
            if depth > self.max_depth:
                return

            node_id = str(id(node))
            label = f"{node.action if node.action else 'root'}\n"
            label += f"Visits: {node.visits}\n"
            label += f"Value: {node.value():.2f}"
            
            color = self.get_color_for_value(node.value())
            
            dot.node(node_id, label, style='filled', color=color)
            
            if parent_id:
                dot.edge(parent_id, node_id)
            
            for child in node.children:
                add_nodes_edges(child, node_id, depth + 1)
        
        add_nodes_edges(self.root)
        dot.render('mcts_tree', view=True, format='png', cleanup=True)

def main():
    print("Initializing WebShop environment and Agent Q")
    env = WebShop()
    agent = AgentQ(env)

    print("Training Agent Q")
    agent.train(num_episodes=5)  # You can adjust the number of episodes as needed

    print("\nTesting trained Agent Q")
    observation = env.get_observation()
    done = False
    total_reward = 0
    step = 0

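    # Cap the test loop so it cannot run forever if checkout is never reached
    while not done and step < agent.max_steps_per_episode: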
        print(f"\nStep {step + 1}")
        print("Current Observation:")
        print(observation)
        
        action = agent.get_action(observation)
        print(f"\nChosen action: {action}")
        
        critique_score = agent.self_critique(observation, action)
        print(f"Self-critique score: {critique_score}")
        
        observation, reward, done = env.take_action(action)
        total_reward += reward
        print(f"Reward: {reward}")
        print(f"Total reward so far: {total_reward}")
        
        step += 1

    print(f"\nTest completed. Final total reward: {total_reward}")

if __name__ == "__main__":
    main()
