{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Getting started with MDPax\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/joefarrington/mdpax/blob/main/examples/getting_started.ipynb)\n", "\n", "This notebook demonstrates MDPax's key features through increasingly complex examples. We'll start with a simple forest management problem and work our way up to larger, more realistic problems.\n", "\n", "If you're running the notebook in Google Colab, you should verify that you're using a GPU instance. Click Runtime > Change runtime type and ensure \"GPU\" is selected as the Hardware accelerator. You can confirm GPU availability by running `!nvidia-smi` in a code cell. \n", "\n", "## Prerequisites\n", "\n", "If you're new to Markov Decision Processes (MDPs), you may find these introductory resources useful:\n", "- [Reinforcement Learning: An Introduction - Chapter 3 | Sutton & Barto](http://incompleteideas.net/book/RLbook2020.pdf)\n", "- 📺 [Markov Decision Processes 1 - Value Iteration | Stanford CS221](https://www.youtube.com/watch?v=9g32v7bK3Co)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Installation and imports" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're running the notebook in Google Colab you may need to restart the runtime after installing the dependencies so that the updated packages can be loaded. You will receive a \"Warning\" message in the output of the cell below if this is the case." 
] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "try:\n", "    # Other dependencies will be installed on Colab already\n", "    import mdptoolbox\n", "    import mdpax\n", "except ImportError:\n", "    if 'google.colab' in sys.modules:\n", "        # Automatically install mdpax if running in Colab, environment is temporary\n", "        !pip install \"mdpax[examples-colab] @ git+https://github.com/joefarrington/mdpax.git\"\n", "    else:\n", "        print(\"Dependencies not installed. Please follow the installation instructions in the README: https://github.com/joefarrington/mdpax\")" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "import jax\n", "\n", "# In general we recommend using double precision. It is particularly helpful when\n", "# performing comparisons with pymdptoolbox, which uses NumPy and therefore defaults to\n", "# double precision.\n", "jax.config.update(\"jax_enable_x64\", True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simple example: forest management\n", "\n", "Let's start with a simple example problem introduced in pymdptoolbox, an alternative library for solving MDPs in Python, so that we can compare our results. \n", "\n", "This problem involves deciding whether we should cut down a forest or wait to let it mature.\n", "\n", "The state is the current age of the forest and our actions are 0 (wait) and 1 (cut). We receive a reward of 1 if we cut down the forest before it is mature, a reward of $r_1$ if we wait in the oldest state, and a reward of $r_2$ if we cut the forest in the oldest state. There is a risk of fire occurring, with probability $p$ of a fire at each timestep. 
If we choose to cut down the forest, or if there is a fire, the forest returns to age 0 (newly planted).\n", "\n", "The two key base classes in mdpax are `Problem` and `Solver` - the `Problem` class is used to define the MDP (in this case the forest problem) and the `Solver` class is used to define algorithms for fitting policies (in this case, value iteration)." ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "import jax.numpy as jnp\n", "import mdptoolbox\n", "import numpy as np\n", "\n", "from mdpax.problems.forest import Forest\n", "from mdpax.solvers.value_iteration import ValueIteration" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2025-01-05 21:40:55.398\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.core.solver\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m159\u001b[0m - \u001b[1mSolver initialized with forest problem\u001b[0m\n", "\u001b[32m2025-01-05 21:41:01.698\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.utils.checkpointing\u001b[0m:\u001b[36m_setup_checkpointing\u001b[0m:\u001b[36m123\u001b[0m - \u001b[1mCheckpointing not enabled\u001b[0m\n", "\u001b[32m2025-01-05 21:41:01.938\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m497\u001b[0m - \u001b[1mIteration 1 span: 4.0000\u001b[0m\n", "\u001b[32m2025-01-05 21:41:02.061\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m497\u001b[0m - \u001b[1mIteration 2 span: 2.4300\u001b[0m\n", "\u001b[32m2025-01-05 21:41:02.074\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m497\u001b[0m - \u001b[1mIteration 3 span: 0.8100\u001b[0m\n", "\u001b[32m2025-01-05 21:41:02.079\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mmdpax.solvers.value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m497\u001b[0m - \u001b[1mIteration 4 span: 0.0000\u001b[0m\n", "\u001b[32m2025-01-05 21:41:02.081\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m502\u001b[0m - \u001b[1mConvergence threshold reached at iteration 4\u001b[0m\n", "\u001b[32m2025-01-05 21:41:02.083\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m521\u001b[0m - \u001b[1mExtracting policy\u001b[0m\n", "\u001b[32m2025-01-05 21:41:02.232\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m523\u001b[0m - \u001b[1mPolicy extracted\u001b[0m\n", "\u001b[32m2025-01-05 21:41:02.233\u001b[0m | \u001b[32m\u001b[1mSUCCESS \u001b[0m | \u001b[36mmdpax.solvers.value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m525\u001b[0m - \u001b[32m\u001b[1mValue iteration completed\u001b[0m\n" ] } ], "source": [ "# Create and solve the basic forest problem with MDPax\n", "problem = Forest(S=3, r1=4.0, r2=2.0, p=0.1) # Small forest with 3 states\n", "solver = ValueIteration(problem, gamma=0.9, epsilon=0.01)\n", "solution = solver.solve()" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MDPax Solution:\n", "--------------\n", "Values:\n", "[ 5.052 8.292 12.292]\n", "Policy:\n", "[0 0 0]\n", "Iterations to converge: 4\n" ] } ], "source": [ "# solution is a dataclass from which we can extract the values, policy, and iteration count\n", "mdpax_values = solution.values\n", "mdpax_policy = solution.policy\n", "mdpax_iteration = solution.info.iteration\n", "\n", "# Print the solution from MDPax\n", "print(\"MDPax Solution:\")\n", "print(\"--------------\")\n", "print(f\"Values:\\n{np.round(mdpax_values.flatten(), 4)}\") # State values\n", 
"print(f\"Policy:\\n{mdpax_policy.flatten()}\") # Optimal actions\n", "print(f\"Iterations to converge: {mdpax_iteration}\") # From solver info" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "MDPax relies on a functional description of the MDP, provided by the `Problem` class. This includes a function giving the probability of each random event given a state and an action, and a deterministic transition function that gives the next state and the reward given a state, an action and a random event. \n", "\n", "Many other libraries, such as pymdptoolbox, instead require the user to provide a transition matrix $\\mathbf{P}$ and a reward matrix $\\mathbf{R}$ for the MDP. \n", "\n", "The transition matrix $\\mathbf{P}$ has dimensions (`n_actions`, `n_states`, `n_states`), and element $\\mathbf{P}_{a,s,s'}$ is the probability of transitioning to state $s'$ when taking action $a$ in state $s$. \n", "\n", "The reward matrix $\\mathbf{R}$ has dimensions (`n_states`, `n_actions`), and element $\\mathbf{R}_{s,a}$ gives the expected reward when taking action $a$ in state $s$.\n", "\n", "The MDPax `Problem` class has a built-in method, `build_transition_and_reward_matrices`, that constructs these matrices from the functions describing the problem. We can use it to build $\\mathbf{P}$ and $\\mathbf{R}$ for the forest management problem and solve it with pymdptoolbox to check our solution."
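, "\n", "\n", "To make this concrete, the cell below is a minimal sketch, in plain NumPy, of how $\\mathbf{P}$ and $\\mathbf{R}$ can be built from a functional description. It does not use the MDPax API: `event_probability` and `transition` are hypothetical stand-ins written to follow the description of the forest problem above, with the random event being whether or not a fire occurs. Enumerating every (state, action, event) triple and accumulating the event probabilities gives $\\mathbf{P}$ with shape (`n_actions`, `n_states`, `n_states`) and $\\mathbf{R}$ with shape (`n_states`, `n_actions`)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"# Hypothetical functional description of the forest problem (not the MDPax API).\n",
"# States are forest ages 0..S-1, actions are 0 (wait) and 1 (cut), and the\n",
"# random event is 1 if a fire occurs during the timestep and 0 otherwise.\n",
"S, r1, r2, p = 3, 4.0, 2.0, 0.1\n",
"n_states, n_actions, events = S, 2, (0, 1)\n",
"\n",
"\n",
"def event_probability(state, action, event):\n",
"    # A fire occurs with probability p, independently of the state and action\n",
"    return p if event == 1 else 1.0 - p\n",
"\n",
"\n",
"def transition(state, action, event):\n",
"    # Deterministic map (state, action, event) -> (next_state, reward)\n",
"    if action == 1:\n",
"        # Cut: earn r2 in the oldest state and 1 otherwise; the forest is replanted\n",
"        return 0, (r2 if state == S - 1 else 1.0)\n",
"    # Wait: earn r1 in the oldest state; a fire resets the forest to age 0\n",
"    reward = r1 if state == S - 1 else 0.0\n",
"    next_state = 0 if event == 1 else min(state + 1, S - 1)\n",
"    return next_state, reward\n",
"\n",
"\n",
"# Enumerate every (state, action, event) triple to fill P[a, s, s'] and R[s, a]\n",
"P = np.zeros((n_actions, n_states, n_states))\n",
"R = np.zeros((n_states, n_actions))\n",
"for s in range(n_states):\n",
"    for a in range(n_actions):\n",
"        for e in events:\n",
"            prob = event_probability(s, a, e)\n",
"            s_next, reward = transition(s, a, e)\n",
"            P[a, s, s_next] += prob\n",
"            R[s, a] += prob * reward  # probability-weighted, so R holds expected rewards\n",
"\n",
"print(P.shape, R.shape)  # (2, 3, 3) (3, 2)\n",
"print(np.allclose(P.sum(axis=2), 1.0))  # True" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Each row of $\\mathbf{P}_a$ sums to one, which is a cheap sanity check on any construction of this kind. Because each reward is weighted by the probability of the event that produced it, the same loop would still give the expected reward $\\mathbf{R}_{s,a}$ if the reward depended on the random event. The next cell overwrites `P` and `R` with the matrices built by MDPax."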
] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "# Get transition and reward matrices for comparison\n", "P, R = problem.build_transition_and_reward_matrices()\n", "# Convert to numpy arrays for pymdptoolbox\n", "P = np.array(P)\n", "R = np.array(R)" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "# Solve with pymdptoolbox\n", "vi = mdptoolbox.mdp.ValueIteration(P, R, discount=0.9, epsilon=0.01)\n", "vi.run()" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "pymdptoolbox Solution:\n", "--------------------\n", "Values:\n", "[ 5.052 8.292 12.292]\n", "Policy:\n", "(0, 0, 0)\n", "Iterations to converge: 4\n" ] } ], "source": [ "# Extract the solution from pymdptoolbox class\n", "toolbox_values = vi.V\n", "toolbox_policy = vi.policy\n", "toolbox_iteration = vi.iter\n", "\n", "\n", "# Print the solution from pymdptoolbox\n", "print(\"\\npymdptoolbox Solution:\")\n", "print(\"--------------------\")\n", "print(f\"Values:\\n{np.round(toolbox_values, 4)}\")\n", "print(f\"Policy:\\n{toolbox_policy}\")\n", "print(f\"Iterations to converge: {toolbox_iteration}\")" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Solutions match?\n", "Values close?: True\n", "Policies match?: True\n", "Number of iterations match?: True\n" ] } ], "source": [ "# Verify solutions match\n", "print(\"\\nSolutions match?\")\n", "print(f\"Values close?: {np.allclose(mdpax_values.flatten(), toolbox_values, rtol=1e-2)}\")\n", "print(f\"Policies match?: {np.array_equal(mdpax_policy.flatten(), np.array(toolbox_policy))}\")\n", "print(f\"Number of iterations match?: {mdpax_iteration == toolbox_iteration}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this very small problem, pymdptoolbox will be faster than 
MDPax because of the cost of transferring data to and from the GPU and the upfront cost of [JIT compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html) in JAX." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## A larger problem: perishable inventory management with substitution\n", "\n", "Perishable inventory management problems are known to be computationally challenging to solve exactly (e.g. with value iteration) because the state must represent the age profile of the stock (how many units of each age are held), so the size of the state space grows exponentially with the maximum useful life of the product. \n", "\n", "In this example, we consider a perishable inventory management problem introduced by [Hendrix et al. (2019)](https://doi.org/10.1002/cmm4.1027). \n", "\n", "The decision maker must place a replenishment order each day for two perishable products, product A and product B. Orders are placed in the morning and arrive immediately before the start of the next period. Demand for each product each day is random. Some customers who want product B may accept product A, and substitutions are made once demand for product A has been met as far as possible. \n", "\n", "The goal is to maximise average daily profit (sales revenue less an ordering cost per unit), so we use relative value iteration (with no discounting of future rewards) as the `Solver`. \n", "\n", "The smallest example considered by Hendrix et al. has 11,025 states and 105 actions. Since $\\mathbf{P}$ has dimensions (`n_actions`, `n_states`, `n_states`), the transition matrix for this problem would have $105 \\times 11{,}025 \\times 11{,}025 \\approx 1.3 \\times 10^{10}$ (around 13 billion) elements. 
Just storing this matrix as 64-bit floats would require over 100GB of RAM!\n", "\n", "To start with, so that we can compare our results with pymdptoolbox, we'll look at a smaller version of the problem with 625 states and 25 actions.\n", "\n", "A note on sparsity: the estimates of transition matrix size in this introductory notebook do not take potential sparsity into account. pymdptoolbox supports sparse arrays, which would reduce the memory needed to represent the transition matrices. Support for sparse arrays in JAX is currently [experimental](https://jax.readthedocs.io/en/latest/jax.experimental.sparse.html). We may investigate the potential benefits of sparsity in a future release." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Basic case" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of states: 625\n", "Number of actions: 25\n" ] } ], "source": [ "from mdpax.problems.perishable_inventory.hendrix_two_product import (\n", " HendrixTwoProductPerishable,\n", ")\n", "from mdpax.solvers.relative_value_iteration import RelativeValueIteration\n", "\n", "problem = HendrixTwoProductPerishable(max_useful_life = 2, # Products can be used for 2 periods after arrival, then are discarded\n", " demand_poisson_mean_a=2, # Mean daily demand for product A\n", " demand_poisson_mean_b=2, # Mean daily demand for product B\n", " max_order_quantity_a=4, # Maximum order quantity for product A\n", " max_order_quantity_b=4) # Maximum order quantity for product B\n", "\n", "print(f\"Number of states: {problem.n_states}\")\n", "print(f\"Number of actions: {problem.n_actions}\") " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As above, we'll first solve it using MDPax."
] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2025-01-05 21:41:02.977\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.core.solver\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m159\u001b[0m - \u001b[1mSolver initialized with hendrix_two_product problem\u001b[0m\n", "\u001b[32m2025-01-05 21:41:03.688\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.utils.checkpointing\u001b[0m:\u001b[36m_setup_checkpointing\u001b[0m:\u001b[36m123\u001b[0m - \u001b[1mCheckpointing not enabled\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.398\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 1: span: 2.84548, gain: 7.8527\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.415\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 2: span: 0.60450, gain: 1.0389\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.428\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 3: span: 0.10172, gain: 1.5746\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.440\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 4: span: 0.07247, gain: 1.5865\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.454\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 5: span: 0.03349, gain: 1.5201\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.467\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - 
\u001b[1mIteration 6: span: 0.01414, gain: 1.5402\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.481\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 7: span: 0.00648, gain: 1.5447\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.503\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 8: span: 0.00331, gain: 1.5442\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.515\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 9: span: 0.00229, gain: 1.5417\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.530\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 10: span: 0.00091, gain: 1.5439\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.542\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 11: span: 0.00045, gain: 1.5435\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.553\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 12: span: 0.00028, gain: 1.5432\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.565\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 13: span: 0.00011, gain: 1.5434\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.579\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 14: span: 0.00009, gain: 
1.5435\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.581\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m174\u001b[0m - \u001b[1mConvergence threshold reached at iteration 14\u001b[0m\n", "\u001b[32m2025-01-05 21:41:04.584\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m193\u001b[0m - \u001b[1mExtracting policy\u001b[0m\n", "\u001b[32m2025-01-05 21:41:05.282\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m195\u001b[0m - \u001b[1mPolicy extracted\u001b[0m\n", "\u001b[32m2025-01-05 21:41:05.282\u001b[0m | \u001b[32m\u001b[1mSUCCESS \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m197\u001b[0m - \u001b[32m\u001b[1mRelative value iteration completed\u001b[0m\n" ] } ], "source": [ "solver = RelativeValueIteration(problem, epsilon=1e-4)\n", "solution = solver.solve()\n", "mdpax_values = solution.values\n", "mdpax_policy = solution.policy\n", "mdpax_average_daily_profit = solution.info.gain" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The gain converges to the mean reward per timestep, so the mean profit per day is approximately \\$1.54.\n", "\n", "Next, because this is a small problem, we can build $\\mathbf{P}$ and $\\mathbf{R}$ and solve the problem using pymdptoolbox. For this problem, the value function is initialized using the one-step-ahead revenue. The `RelativeValueIteration` solver calculates this automatically from the `Problem`, but we need to provide it manually to pymdptoolbox."
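, "\n", "\n", "Before doing that, it may help to see what the span and gain reported in the solver logs measure. The cell below is a minimal sketch of textbook relative value iteration in plain NumPy. It is run on a hypothetical two-state MDP whose optimal gain is easy to work out by hand. The function `relative_value_iteration` is illustrative only and is not the MDPax implementation, which may compute its span and gain estimates differently." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"\n",
"def relative_value_iteration(P, R, epsilon=1e-4, max_iter=10_000, ref_state=0):\n",
"    # Textbook relative value iteration for an average-reward MDP (hypothetical\n",
"    # helper for illustration, not the MDPax implementation).\n",
"    # P has shape (n_actions, n_states, n_states); R has shape (n_states, n_actions).\n",
"    h = np.zeros(R.shape[0])\n",
"    for iteration in range(1, max_iter + 1):\n",
"        # Undiscounted Bellman backup: q[s, a] = R[s, a] + sum_t P[a, s, t] * h[t]\n",
"        q = R + np.einsum('ast,t->sa', P, h)\n",
"        th = q.max(axis=1)\n",
"        diff = th - h\n",
"        # For unichain MDPs, min(diff) <= optimal gain <= max(diff), so the span\n",
"        # (max minus min) bounds the error of the midpoint gain estimate\n",
"        span = diff.max() - diff.min()\n",
"        gain = 0.5 * (diff.max() + diff.min())\n",
"        # Subtract the value of a reference state to keep the values bounded\n",
"        h = th - th[ref_state]\n",
"        if span < epsilon:\n",
"            break\n",
"    policy = q.argmax(axis=1)\n",
"    return h, policy, gain, iteration\n",
"\n",
"\n",
"# Hypothetical two-state MDP. In state 0, action 0 stays (reward 1) and action 1\n",
"# moves to state 1 (reward 0). In state 1, action 0 stays (reward 2) and action 1\n",
"# moves to state 0 (reward 0).\n",
"P_toy = np.zeros((2, 2, 2))\n",
"P_toy[0, 0, 0] = P_toy[1, 0, 1] = P_toy[0, 1, 1] = P_toy[1, 1, 0] = 1.0\n",
"R_toy = np.array([[1.0, 0.0], [2.0, 0.0]])\n",
"\n",
"toy_h, toy_policy, toy_gain, toy_iterations = relative_value_iteration(P_toy, R_toy)\n",
"print(toy_gain, toy_policy, toy_iterations)  # 2.0 [1 0] 3" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "On the toy problem, the sketch stops after three iterations with a gain of 2, the reward per step for staying in state 1. The policy `[1, 0]` means: move from state 0 to state 1, then stay there. Subtracting the value of a reference state at each iteration does not change which action is best, but it stops the undiscounted values from growing without bound. The next cell returns to the Hendrix problem and builds $\\mathbf{P}$, $\\mathbf{R}$ and the initial values for pymdptoolbox."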
] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [], "source": [ "# Compute P, R and the initial values from the Problem\n", "P, R = problem.build_transition_and_reward_matrices()\n", "initial_values = jax.vmap(problem.initial_value)(problem.state_space)\n", "# Convert to NumPy arrays for pymdptoolbox\n", "P = np.array(P)\n", "R = np.array(R)\n", "initial_values = np.array(initial_values)" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "# Run relative value iteration with pymdptoolbox, starting from the same initial values\n", "rvi = mdptoolbox.mdp.RelativeValueIteration(P, R, epsilon=1e-4)\n", "rvi.V = initial_values\n", "rvi.run()\n", "\n", "# Extract the solution from the pymdptoolbox object\n", "toolbox_values = rvi.V\n", "toolbox_policy = rvi.policy\n", "toolbox_average_daily_profit = rvi.average_reward" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "pymdptoolbox gives us the index of the best action in each state. This was fine in the forest example, where each action is identified only by its index (0 for wait, 1 for cut). In this example, where an action is an order quantity for each of product A and product B, we need to look up the actual action to compare with the policy from MDPax. \n", "\n", "We can do this by indexing into the `action_space` attribute of our `Problem`. See our [tutorial](https://mdpax.readthedocs.io/en/latest/notebooks/create_custom_problem.html) on implementing your own problem for more information on state, action and event spaces."
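, "\n", "\n", "As a minimal sketch of this lookup, the cell below builds a hypothetical action space that lists every (order_a, order_b) pair with each quantity between 0 and 4. It then converts a few action indices into order quantities. The ordering MDPax uses for `problem.action_space` may differ; only the indexing pattern matters here." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import itertools\n",
"\n",
"import jax.numpy as jnp\n",
"\n",
"# Hypothetical action space for illustration: every (order_a, order_b) pair with\n",
"# each order quantity in 0..4, giving 25 actions.\n",
"max_order_a, max_order_b = 4, 4\n",
"example_action_space = jnp.array(\n",
"    list(itertools.product(range(max_order_a + 1), range(max_order_b + 1)))\n",
")\n",
"print(example_action_space.shape)  # (25, 2)\n",
"\n",
"# A pymdptoolbox-style policy stores one action index per state (three states here)\n",
"example_policy_indices = jnp.array([0, 7, 24])\n",
"\n",
"# Take rows along axis 0 to turn each index into its (order_a, order_b) pair\n",
"example_policy_actions = example_action_space.take(example_policy_indices, axis=0)\n",
"print(example_policy_actions)  # rows (0, 0), (1, 2) and (4, 4)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Row $i$ of the action space is the action with index $i$. Taking rows along axis 0 therefore maps a whole policy of indices to order quantities in one vectorized operation; this is equivalent to `example_action_space[example_policy_indices]`. The next cell applies the same pattern to `problem.action_space`."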
] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [], "source": [ "toolbox_policy = problem.action_space.take(jnp.array(toolbox_policy),axis=0)" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Solutions match?\n", "Mean daily profit matches?: True\n", "Values close?: True\n", "Policies match?: True\n", "Number of iterations match?: True\n" ] } ], "source": [ "# Verify solutions match\n", "print(\"\\nSolutions match?\")\n", "print(f\"Mean daily profit matches?: {np.allclose(mdpax_average_daily_profit,toolbox_average_daily_profit, rtol=1e-2)}\")\n", "print(f\"Values close?: {np.allclose(mdpax_values.flatten(), toolbox_values, rtol=1e-2)}\")\n", "print(f\"Policies match?: {np.array_equal(mdpax_policy, toolbox_policy)}\")\n", "print(f\"Number of iterations match?: {mdpax_iteration == toolbox_iteration}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Larger cases" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll now consider two larger versions of the problem which were included in Hendrix et al. (2019)." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Case 1" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of states: 11025\n", "Number of actions: 105\n" ] } ], "source": [ "problem = HendrixTwoProductPerishable(max_useful_life = 2, # Products can be used for 2 periods after arrival, then are discarded\n", " demand_poisson_mean_a=7, # Mean daily demand for product A\n", " demand_poisson_mean_b=3, # Mean daily demand for product B\n", " max_order_quantity_a=14, # Maximum order quantity for product A\n", " max_order_quantity_b=6) # Maximum order quantity for product B\n", "\n", "print(f\"Number of states: {problem.n_states}\")\n", "print(f\"Number of actions: {problem.n_actions}\") " ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2025-01-05 21:41:15.889\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.core.solver\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m159\u001b[0m - \u001b[1mSolver initialized with hendrix_two_product problem\u001b[0m\n", "\u001b[32m2025-01-05 21:41:16.582\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.utils.checkpointing\u001b[0m:\u001b[36m_setup_checkpointing\u001b[0m:\u001b[36m123\u001b[0m - \u001b[1mCheckpointing not enabled\u001b[0m\n", "\u001b[32m2025-01-05 21:41:17.612\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 1: span: 6.51383, gain: 19.9606\u001b[0m\n", "\u001b[32m2025-01-05 21:41:17.682\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 2: span: 1.24455, gain: 3.4498\u001b[0m\n", "\u001b[32m2025-01-05 21:41:17.751\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 3: span: 0.30959, gain: 4.3591\u001b[0m\n", "\u001b[32m2025-01-05 21:41:17.813\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 4: span: 0.17163, gain: 4.6679\u001b[0m\n", "\u001b[32m2025-01-05 21:41:17.875\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 5: span: 0.03459, gain: 4.5060\u001b[0m\n", "\u001b[32m2025-01-05 21:41:17.936\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 6: span: 0.01877, gain: 4.5114\u001b[0m\n", "\u001b[32m2025-01-05 21:41:18.001\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 7: span: 0.00573, gain: 4.5271\u001b[0m\n", "\u001b[32m2025-01-05 21:41:18.064\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 8: span: 0.00128, gain: 4.5217\u001b[0m\n", "\u001b[32m2025-01-05 21:41:18.127\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 9: span: 0.00083, gain: 4.5220\u001b[0m\n", "\u001b[32m2025-01-05 21:41:18.189\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 10: span: 0.00027, gain: 4.5227\u001b[0m\n", "\u001b[32m2025-01-05 21:41:18.252\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 11: span: 0.00010, gain: 4.5226\u001b[0m\n", "\u001b[32m2025-01-05 21:41:18.313\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 12: span: 0.00006, gain: 4.5225\u001b[0m\n", "\u001b[32m2025-01-05 21:41:18.314\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m174\u001b[0m - \u001b[1mConvergence threshold reached at iteration 12\u001b[0m\n", "\u001b[32m2025-01-05 21:41:18.316\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m193\u001b[0m - \u001b[1mExtracting policy\u001b[0m\n", "\u001b[32m2025-01-05 21:41:19.149\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m195\u001b[0m - \u001b[1mPolicy extracted\u001b[0m\n", "\u001b[32m2025-01-05 21:41:19.149\u001b[0m | \u001b[32m\u001b[1mSUCCESS \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m197\u001b[0m - \u001b[32m\u001b[1mRelative value iteration completed\u001b[0m\n" ] } ], "source": [ "solver = RelativeValueIteration(problem, epsilon=1e-4)\n", "solution = solver.solve()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This takes less than 20s [1], including setting up the problem, on a Google Colab GPU instance compared to the 206s reported by Hendrix et al. for their implementation using MATLAB on the CPU. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Case 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hendrix et al. reported that the largest problem they were able to solve within a week had 1.2Mn states, and took 80 hours. 
Using MDPax on a Google Colab GPU instance, it should take less than 3 minutes [2].\n", "\n", "Storing the transition matrix for this problem as 64-bit floats would require over 1PB (more than a million GB) of RAM!" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of states: 1157625\n", "Number of actions: 105\n" ] } ], "source": [ "problem = HendrixTwoProductPerishable(max_useful_life = 3, # Products can be used for 3 periods after arrival, then are discarded\n", " demand_poisson_mean_a=7, # Mean daily demand for product A\n", " demand_poisson_mean_b=3, # Mean daily demand for product B\n", " max_order_quantity_a=20, # Maximum order quantity for product A\n", " max_order_quantity_b=4) # Maximum order quantity for product B\n", "\n", "print(f\"Number of states: {problem.n_states}\")\n", "print(f\"Number of actions: {problem.n_actions}\")" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2025-01-05 21:41:25.660\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.core.solver\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m159\u001b[0m - \u001b[1mSolver initialized with hendrix_two_product problem\u001b[0m\n", "\u001b[32m2025-01-05 21:41:28.899\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.utils.checkpointing\u001b[0m:\u001b[36m_setup_checkpointing\u001b[0m:\u001b[36m123\u001b[0m - \u001b[1mCheckpointing not enabled\u001b[0m\n", "\u001b[32m2025-01-05 21:41:41.182\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 1: span: 6.54445, gain: 19.9912\u001b[0m\n", "\u001b[32m2025-01-05 21:41:50.995\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 2: span: 6.27677, gain: 
9.7235\u001b[0m\n", "\u001b[32m2025-01-05 21:42:00.809\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 3: span: 1.48057, gain: 3.4468\u001b[0m\n", "\u001b[32m2025-01-05 21:42:10.537\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 4: span: 0.57104, gain: 4.3561\u001b[0m\n", "\u001b[32m2025-01-05 21:42:20.258\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 5: span: 0.09908, gain: 4.8482\u001b[0m\n", "\u001b[32m2025-01-05 21:42:29.982\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 6: span: 0.09773, gain: 4.9156\u001b[0m\n", "\u001b[32m2025-01-05 21:42:39.682\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 7: span: 0.02806, gain: 4.8388\u001b[0m\n", "\u001b[32m2025-01-05 21:42:49.461\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 8: span: 0.01050, gain: 4.8225\u001b[0m\n", "\u001b[32m2025-01-05 21:42:59.228\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 9: span: 0.00353, gain: 4.8284\u001b[0m\n", "\u001b[32m2025-01-05 21:43:08.963\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 10: span: 0.00176, gain: 4.8307\u001b[0m\n", "\u001b[32m2025-01-05 
21:43:18.692\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 11: span: 0.00066, gain: 4.8297\u001b[0m\n", "\u001b[32m2025-01-05 21:43:28.446\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 12: span: 0.00020, gain: 4.8293\u001b[0m\n", "\u001b[32m2025-01-05 21:43:38.192\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 13: span: 0.00010, gain: 4.8293\u001b[0m\n", "\u001b[32m2025-01-05 21:43:38.193\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m174\u001b[0m - \u001b[1mConvergence threshold reached at iteration 13\u001b[0m\n", "\u001b[32m2025-01-05 21:43:38.195\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m193\u001b[0m - \u001b[1mExtracting policy\u001b[0m\n", "\u001b[32m2025-01-05 21:43:48.803\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m195\u001b[0m - \u001b[1mPolicy extracted\u001b[0m\n", "\u001b[32m2025-01-05 21:43:48.803\u001b[0m | \u001b[32m\u001b[1mSUCCESS \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m197\u001b[0m - \u001b[32m\u001b[1mRelative value iteration completed\u001b[0m\n" ] } ], "source": [ "solver = RelativeValueIteration(problem, epsilon=1e-4)\n", "solution = solver.solve()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Checkpointing\n", "\n", "Solving some large problems can take a long time, so MDPax supports checkpointing: if there is a problem during a run, you can restart it from a checkpoint instead of starting again from scratch.\n", "\n", "Checkpointing is not enabled by default because it is not very useful for smaller problems. You can activate it by setting `checkpoint_frequency` to a positive integer (for example, `checkpoint_frequency=1`) when instantiating a solver. The solver will then save a checkpoint every `checkpoint_frequency` iterations, and again once it meets the convergence threshold.\n", "\n", "Let's start by running a problem to convergence to get a reference policy." ] },
{ "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2025-01-05 21:43:50.218\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.core.solver\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m159\u001b[0m - \u001b[1mSolver initialized with hendrix_two_product problem\u001b[0m\n", "\u001b[32m2025-01-05 21:43:50.850\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.utils.checkpointing\u001b[0m:\u001b[36m_setup_checkpointing\u001b[0m:\u001b[36m123\u001b[0m - \u001b[1mCheckpointing not enabled\u001b[0m\n", "\u001b[32m2025-01-05 21:43:51.852\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 1: span: 6.52970, gain: 19.9641\u001b[0m\n", "\u001b[32m2025-01-05 21:43:51.996\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 2: span: 1.23312, gain: 3.4442\u001b[0m\n", "\u001b[32m2025-01-05 21:43:52.098\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 3: span: 0.29569, gain: 4.3579\u001b[0m\n", "\u001b[32m2025-01-05 21:43:52.191\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 4: span: 0.18104, gain: 4.6498\u001b[0m\n", 
"\u001b[32m2025-01-05 21:43:52.284\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 5: span: 0.04119, gain: 4.4774\u001b[0m\n", "\u001b[32m2025-01-05 21:43:52.374\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 6: span: 0.01627, gain: 4.4977\u001b[0m\n", "\u001b[32m2025-01-05 21:43:52.466\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 7: span: 0.00502, gain: 4.5060\u001b[0m\n", "\u001b[32m2025-01-05 21:43:52.557\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 8: span: 0.00201, gain: 4.5021\u001b[0m\n", "\u001b[32m2025-01-05 21:43:52.648\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 9: span: 0.00060, gain: 4.5032\u001b[0m\n", "\u001b[32m2025-01-05 21:43:52.738\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 10: span: 0.00040, gain: 4.5033\u001b[0m\n", "\u001b[32m2025-01-05 21:43:52.830\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 11: span: 0.00016, gain: 4.5031\u001b[0m\n", "\u001b[32m2025-01-05 21:43:52.919\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 12: span: 0.00007, gain: 4.5032\u001b[0m\n", "\u001b[32m2025-01-05 21:43:52.921\u001b[0m | 
\u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m174\u001b[0m - \u001b[1mConvergence threshold reached at iteration 12\u001b[0m\n", "\u001b[32m2025-01-05 21:43:52.923\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m193\u001b[0m - \u001b[1mExtracting policy\u001b[0m\n", "\u001b[32m2025-01-05 21:43:53.714\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m195\u001b[0m - \u001b[1mPolicy extracted\u001b[0m\n", "\u001b[32m2025-01-05 21:43:53.715\u001b[0m | \u001b[32m\u001b[1mSUCCESS \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m197\u001b[0m - \u001b[32m\u001b[1mRelative value iteration completed\u001b[0m\n" ] } ], "source": [ "problem = HendrixTwoProductPerishable(max_useful_life = 2, \n", " demand_poisson_mean_a=5, \n", " demand_poisson_mean_b=5, \n", " max_order_quantity_a=10, \n", " max_order_quantity_b=10)\n", "solver_a = RelativeValueIteration(problem, epsilon=1e-4)\n", "result_from_full_run = solver_a.solve()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's imagine the run was interrupted. To mimic this, we activate checkpointing and set `max_iterations` lower than the number of iterations required for convergence. We'll save a checkpoint every iteration, keep only the most recent checkpoint, and save it in the directory `checkpoints/getting_started/incomplete_run`." 
] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2025-01-05 21:43:53.725\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.core.solver\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m159\u001b[0m - \u001b[1mSolver initialized with hendrix_two_product problem\u001b[0m\n", "\u001b[32m2025-01-05 21:43:54.310\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.utils.checkpointing\u001b[0m:\u001b[36m_setup_checkpointing\u001b[0m:\u001b[36m143\u001b[0m - \u001b[1mFull checkpointing enabled with problem and solver reconstruction\u001b[0m\n", "\u001b[32m2025-01-05 21:43:54.310\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.utils.checkpointing\u001b[0m:\u001b[36m_setup_checkpointing\u001b[0m:\u001b[36m152\u001b[0m - \u001b[1mSaving checkpoints every 1 iteration(s) to /home/joefarrington/other_learning/mdpax/examples/checkpoints/getting_started/incomplete_run\u001b[0m\n", "\u001b[32m2025-01-05 21:43:55.201\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 1: span: 6.52970, gain: 19.9641\u001b[0m\n", "\u001b[32m2025-01-05 21:43:55.341\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 2: span: 1.23312, gain: 3.4442\u001b[0m\n", "\u001b[32m2025-01-05 21:43:55.437\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 3: span: 0.29569, gain: 4.3579\u001b[0m\n", "\u001b[32m2025-01-05 21:43:55.528\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 4: span: 0.18104, gain: 4.6498\u001b[0m\n", "\u001b[32m2025-01-05 21:43:55.620\u001b[0m | 
\u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 5: span: 0.04119, gain: 4.4774\u001b[0m\n", "\u001b[32m2025-01-05 21:43:55.623\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m186\u001b[0m - \u001b[1mMaximum iterations reached\u001b[0m\n", "\u001b[32m2025-01-05 21:43:55.624\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m193\u001b[0m - \u001b[1mExtracting policy\u001b[0m\n", "\u001b[32m2025-01-05 21:43:56.419\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m195\u001b[0m - \u001b[1mPolicy extracted\u001b[0m\n", "\u001b[32m2025-01-05 21:43:56.420\u001b[0m | \u001b[32m\u001b[1mSUCCESS \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m197\u001b[0m - \u001b[32m\u001b[1mRelative value iteration completed\u001b[0m\n" ] } ], "source": [ "solver_b = RelativeValueIteration(problem, epsilon=1e-4, checkpoint_frequency=1, checkpoint_dir=\"checkpoints/getting_started/incomplete_run\")\n", "result_from_incomplete_run = solver_b.solve(max_iterations=5)" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2025-01-05 21:43:58.150\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.core.solver\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m159\u001b[0m - \u001b[1mSolver initialized with hendrix_two_product problem\u001b[0m\n", "\u001b[32m2025-01-05 21:43:58.781\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.utils.checkpointing\u001b[0m:\u001b[36m_setup_checkpointing\u001b[0m:\u001b[36m143\u001b[0m - \u001b[1mFull checkpointing enabled with problem and solver reconstruction\u001b[0m\n", 
"\u001b[32m2025-01-05 21:43:58.781\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.utils.checkpointing\u001b[0m:\u001b[36m_setup_checkpointing\u001b[0m:\u001b[36m152\u001b[0m - \u001b[1mSaving checkpoints every 1 iteration(s) to /home/joefarrington/other_learning/mdpax/examples/checkpoints/getting_started/continued_run\u001b[0m\n" ] } ], "source": [ "solver_c = RelativeValueIteration.restore(checkpoint_dir=\"checkpoints/getting_started/incomplete_run\", new_checkpoint_dir=\"checkpoints/getting_started/continued_run\")" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Values restored correctly: True\n", "Iteration restored correctly: True\n" ] } ], "source": [ "print(f\"Values restored correctly: {np.all(solver_c.values == result_from_incomplete_run.values)}\")\n", "print(f\"Iteration restored correctly: {solver_c.iteration == result_from_incomplete_run.info.iteration}\")" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2025-01-05 21:43:59.775\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 6: span: 0.01627, gain: 4.4977\u001b[0m\n", "\u001b[32m2025-01-05 21:43:59.887\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 7: span: 0.00502, gain: 4.5060\u001b[0m\n", "\u001b[32m2025-01-05 21:43:59.981\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 8: span: 0.00201, gain: 4.5021\u001b[0m\n", "\u001b[32m2025-01-05 21:44:00.076\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 9: span: 0.00060, gain: 4.5032\u001b[0m\n", "\u001b[32m2025-01-05 21:44:00.165\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 10: span: 0.00040, gain: 4.5033\u001b[0m\n", "\u001b[32m2025-01-05 21:44:00.262\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 11: span: 0.00016, gain: 4.5031\u001b[0m\n", "\u001b[32m2025-01-05 21:44:00.353\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m169\u001b[0m - \u001b[1mIteration 12: span: 0.00007, gain: 4.5032\u001b[0m\n", "\u001b[32m2025-01-05 21:44:00.354\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m174\u001b[0m - \u001b[1mConvergence threshold reached at iteration 12\u001b[0m\n", "\u001b[32m2025-01-05 21:44:00.355\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m193\u001b[0m - \u001b[1mExtracting policy\u001b[0m\n", "\u001b[32m2025-01-05 21:44:01.152\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m195\u001b[0m - \u001b[1mPolicy extracted\u001b[0m\n", "\u001b[32m2025-01-05 21:44:01.153\u001b[0m | \u001b[32m\u001b[1mSUCCESS \u001b[0m | \u001b[36mmdpax.solvers.relative_value_iteration\u001b[0m:\u001b[36msolve\u001b[0m:\u001b[36m197\u001b[0m - \u001b[32m\u001b[1mRelative value iteration completed\u001b[0m\n" ] } ], "source": [ "result_from_continued_run = solver_c.solve()" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "Policy from restored run same as full run: True\n" ] } ], "source": [ "print(f\"Policy from restored run same as full run: {np.all(result_from_continued_run.policy == result_from_full_run.policy)}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Next Steps\n", "\n", "- Try the next [tutorial](https://mdpax.readthedocs.io/en/latest/notebooks/create_custom_problem.html) to learn how to implement your own problems using MDPax [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/joefarrington/mdpax/blob/main/examples/create_custom_problem.ipynb)\n", "- Read the [MDPax documentation](https://mdpax.readthedocs.io/en/latest/index.html) for a detailed API reference" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Footnotes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* [1] Based on 50 runs on 2025-10-25 using a Tesla T4 GPU instance and Colab Runtime version 2025.10. Maximum runtime 19.2s, mean runtime 18.2s, standard deviation 0.4s.\n", "\n", "* [2] Based on 50 runs on 2025-10-25 using a Tesla T4 GPU instance and Colab Runtime version 2025.10. Maximum runtime 138.5s, mean runtime 135.3s, standard deviation 0.7s." ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 2 }