{ "cells": [ { "cell_type": "markdown", "id": "9749f69b", "metadata": {}, "source": [ "# Cartpole-v1" ] }, { "cell_type": "code", "execution_count": 1, "id": "50995613", "metadata": { "ExecuteTime": { "end_time": "2025-03-04T15:41:33.943786Z", "start_time": "2025-03-04T15:41:33.830063Z" } }, "outputs": [], "source": [ "import gymnasium as gym" ] }, { "cell_type": "code", "execution_count": 2, "id": "78ed5a1b", "metadata": { "ExecuteTime": { "end_time": "2025-03-04T15:41:33.972694Z", "start_time": "2025-03-04T15:41:33.944621Z" } }, "outputs": [], "source": [ "from gymcts.gymcts_agent import GymctsAgent\n", "from gymcts.gymcts_deepcopy_wrapper import DeepCopyMCTSGymEnvWrapper" ] }, { "cell_type": "code", "execution_count": 3, "id": "c90dd1f6", "metadata": { "ExecuteTime": { "end_time": "2025-03-04T15:41:49.986143Z", "start_time": "2025-03-04T15:41:33.973323Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "reward: 1.0, step: 10\n", "reward: 1.0, step: 20\n", "reward: 1.0, step: 30\n", "reward: 1.0, step: 40\n", "reward: 1.0, step: 50\n", "reward: 1.0, step: 60\n", "reward: 1.0, step: 70\n", "reward: 1.0, step: 80\n", "reward: 1.0, step: 90\n", "reward: 1.0, step: 100\n", "reward: 1.0, step: 110\n", "reward: 1.0, step: 120\n", "reward: 1.0, step: 130\n", "reward: 1.0, step: 140\n", "reward: 1.0, step: 150\n", "reward: 1.0, step: 160\n", "reward: 1.0, step: 170\n", "reward: 1.0, step: 180\n", "reward: 1.0, step: 190\n", "reward: 1.0, step: 200\n", "reward: 1.0, step: 210\n", "reward: 1.0, step: 220\n", "reward: 1.0, step: 230\n", "reward: 1.0, step: 240\n", "reward: 1.0, step: 250\n", "reward: 1.0, step: 260\n", "reward: 1.0, step: 270\n", "reward: 1.0, step: 280\n", "reward: 1.0, step: 290\n", "reward: 1.0, step: 300\n", "reward: 1.0, step: 310\n", "reward: 1.0, step: 320\n", "reward: 1.0, step: 330\n", "reward: 1.0, step: 340\n", "reward: 1.0, step: 350\n", "reward: 1.0, step: 360\n", "reward: 1.0, step: 370\n", "reward: 1.0, step: 380\n", "reward: 1.0, step: 390\n", "reward: 1.0, step: 400\n", "reward: 1.0, step: 410\n", "reward: 1.0, step: 420\n", "reward: 1.0, step: 430\n", "reward: 1.0, step: 440\n", "reward: 1.0, step: 450\n", "reward: 1.0, step: 460\n", "reward: 1.0, step: 470\n", "reward: 1.0, step: 480\n", "reward: 1.0, step: 490\n", "reward: 1.0, step: 500\n", "CartPole-v1 successfully balanced!\n" ] } ], "source": [ "if __name__ == '__main__':\n", "\n", " # 0. create the environment\n", " env = gym.make(\"CartPole-v1\")\n", " env.reset(seed=42)\n", "\n", "\n", " # 1. wrap the environment with the naive wrapper or a custom gymcts wrapper\n", " env = DeepCopyMCTSGymEnvWrapper(env)\n", "\n", " # 2. create the agent\n", " agent = GymctsAgent(\n", " env=env,\n", " number_of_simulations_per_step=50,\n", " clear_mcts_tree_after_step=True,\n", " )\n", "\n", " # 3. solve the environment\n", " terminal = False\n", " step = 0\n", " while not terminal:\n", " action, _ = agent.perform_mcts_step()\n", " obs, rew, term, trun, info = env.step(action)\n", " terminal = term or trun\n", "\n", " step += 1\n", "\n", " # log to console every 10 steps\n", " if step % 10 == 0:\n", " print(f\"reward: {rew}, step: {step}\")\n", "\n", " if step >= 475:\n", " print(\"CartPole-v1 successfully balanced!\")\n", " else:\n", " print(\"CartPole-v1 failed to balance.\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "255bf1a9", "metadata": { "ExecuteTime": { "end_time": "2025-03-04T15:41:49.988317Z", "start_time": "2025-03-04T15:41:49.986938Z" }, "lines_to_next_cell": 2 }, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "cell_metadata_filter": "-all", "main_language": "python", "notebook_metadata_filter": "-all" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }