Getting started with MDPax


This notebook demonstrates MDPax’s key features through increasingly complex examples. We’ll start with a simple forest management problem and work our way up to larger, more realistic problems.

If you’re running the notebook in Google Colab, you should verify that you’re using a GPU instance. Click Runtime > Change runtime type and ensure “GPU” is selected as the Hardware accelerator. You can confirm GPU availability by running !nvidia-smi in a code cell.
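JAX can also report which devices it can see, which lets you check for a GPU from Python rather than the shell. This is a minimal sketch using jax.devices() and jax.default_backend(). On a Colab GPU runtime we would expect a GPU device to be listed, but the exact device names and backend string depend on the JAX version and the hardware.

```python
import jax

# All devices visible to JAX; on a Colab GPU runtime this should include a GPU device
devices = jax.devices()
print(devices)

# Platform JAX uses by default for arrays and computations (e.g. a GPU backend, or "cpu")
backend = jax.default_backend()
print(f"Default JAX backend: {backend}")
```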

Prerequisites

If you’re new to Markov Decision Processes (MDPs), you may find these introductory resources useful:

Installation and imports

If you’re running the notebook in Google Colab, you may need to restart the runtime after installing the dependencies so that the updated packages can be loaded. If so, you will see a “Warning” message in the output of the cell below.

[51]:
import sys

try:
    # Other dependencies will be installed on Colab already
    import mdptoolbox
    import mdpax
except ImportError:
    if 'google.colab' in sys.modules:
        # Automatically install mdpax if running in Colab, environment is temporary
        !pip install "mdpax[examples-colab] @ git+https://github.com/joefarrington/mdpax.git"
    else:
        print("Dependencies not installed. Please follow the installation instructions in the README: https://github.com/joefarrington/mdpax")
[52]:
import jax

# In general we recommend using double precision and it is particularly helpful when
# performing comparisons with pymdptoolbox which uses NumPy and therefore defaults to
# double precision.
jax.config.update("jax_enable_x64", True)

Simple example: forest management

Let’s start with a simple example problem introduced in pymdptoolbox, an alternative library for solving MDPs in Python, so that we can compare our results.

This problem involves deciding whether we should cut down a forest, or wait to let it mature.

The state is the current age of the forest, and the actions are 0 (wait) and 1 (cut). We receive a reward of 1 if we cut down the forest before it is mature, a reward of \(r_1\) if we wait in the oldest state, and a reward of \(r_2\) if we cut the forest in the oldest state. At each timestep a fire occurs with probability \(p\). If we cut down the forest, or if there is a fire, the forest returns to age 0 (newly planted).
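The dynamics above can be written as a single transition function and rolled out over time with jax.lax.scan. This is a minimal sketch, not MDPax's Forest implementation: forest_step and simulate are hypothetical names, and we assume (following pymdptoolbox's forest example) that cutting a newly planted forest at age 0 earns a reward of 0.

```python
import jax
import jax.numpy as jnp


def forest_step(state, action, fire, n_states, r1, r2):
    """One forest transition: returns (next_state, reward)."""
    oldest = n_states - 1
    # Waiting only pays r1 in the oldest state
    wait_reward = jnp.where(state == oldest, r1, 0.0)
    # Cutting pays r2 in the oldest state, 0 at age 0 (assumed, as in pymdptoolbox), 1 otherwise
    cut_reward = jnp.where(state == oldest, r2, jnp.where(state == 0, 0.0, 1.0))
    reward = jnp.where(action == 1, cut_reward, wait_reward)
    # Cutting or fire resets the age to 0; otherwise the forest ages, capped at the oldest state
    grown = jnp.minimum(state + 1, oldest)
    next_state = jnp.where((action == 1) | fire, 0, grown)
    return next_state, reward


def simulate(key, policy, n_steps, n_states=3, r1=4.0, r2=2.0, p=0.1):
    """Roll out a deterministic policy (one action per state) starting from age 0."""
    fires = jax.random.bernoulli(key, p, shape=(n_steps,))

    def body(state, fire):
        action = policy[state]
        next_state, reward = forest_step(state, action, fire, n_states, r1, r2)
        return next_state, (state, reward)

    _, (states, rewards) = jax.lax.scan(body, jnp.int32(0), fires)
    return states, rewards


# With p=0.0 there are no fires, so the rollout is deterministic
states, rewards = simulate(jax.random.PRNGKey(0), jnp.array([0, 0, 0]), n_steps=5, p=0.0)
print(states, rewards)
```

Always waiting lets the forest reach the oldest state, after which it earns \(r_1\) every step.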

The two key base classes in MDPax are Problem and Solver: the Problem class defines the MDP (in this case the forest problem), and the Solver class defines an algorithm for finding a policy (in this case, value iteration).

[53]:
import jax.numpy as jnp
import mdptoolbox
import numpy as np

from mdpax.problems.forest import Forest
from mdpax.solvers.value_iteration import ValueIteration
[54]:
# Create and solve the basic forest problem with MDPax
problem = Forest(S=3, r1=4.0, r2=2.0, p=0.1)  # Small forest with 3 states
solver = ValueIteration(problem, gamma=0.9, epsilon=0.01)
solution = solver.solve()
2025-01-05 21:40:55.398 | INFO     | mdpax.core.solver:__init__:159 - Solver initialized with forest problem
2025-01-05 21:41:01.698 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:123 - Checkpointing not enabled
2025-01-05 21:41:01.938 | INFO     | mdpax.solvers.value_iteration:solve:497 - Iteration 1 span: 4.0000
2025-01-05 21:41:02.061 | INFO     | mdpax.solvers.value_iteration:solve:497 - Iteration 2 span: 2.4300
2025-01-05 21:41:02.074 | INFO     | mdpax.solvers.value_iteration:solve:497 - Iteration 3 span: 0.8100
2025-01-05 21:41:02.079 | INFO     | mdpax.solvers.value_iteration:solve:497 - Iteration 4 span: 0.0000
2025-01-05 21:41:02.081 | INFO     | mdpax.solvers.value_iteration:solve:502 - Convergence threshold reached at iteration 4
2025-01-05 21:41:02.083 | INFO     | mdpax.solvers.value_iteration:solve:521 - Extracting policy
2025-01-05 21:41:02.232 | INFO     | mdpax.solvers.value_iteration:solve:523 - Policy extracted
2025-01-05 21:41:02.233 | SUCCESS  | mdpax.solvers.value_iteration:solve:525 - Value iteration completed
[55]:
# solution is a dataclass from which we can extract the values, policy, and iteration count
mdpax_values = solution.values
mdpax_policy = solution.policy
mdpax_iteration = solution.info.iteration

# Print the solution from MDPax
print("MDPax Solution:")
print("--------------")
print(f"Values:\n{np.round(mdpax_values.flatten(), 4)}")  # State values
print(f"Policy:\n{mdpax_policy.flatten()}")  # Optimal actions
print(f"Iterations to converge: {mdpax_iteration}")  # From solver info
MDPax Solution:
--------------
Values:
[ 5.052  8.292 12.292]
Policy:
[0 0 0]
Iterations to converge: 4

MDPax relies on a functional description of the MDP, using the Problem class. This includes defining the probability of a random event given a state and an action, and a deterministic transition function that gives the reward and the next state given a state, action and random event.

Many other libraries, such as pymdptoolbox, require the user to provide a transition matrix \(\mathbf{P}\) and a reward matrix \(\mathbf{R}\) for the MDP.

Transition matrix \(\mathbf{P}\) has dimensions (n_actions, n_states, n_states) and element \(\mathbf{P}_{a,s,s'}\) is the probability of transitioning to state \(s'\) when taking action \(a\) in state \(s\).

Reward matrix \(\mathbf{R}\) has dimensions (n_states, n_actions) and element \(\mathbf{R}_{s,a}\) gives the expected reward when taking action \(a\) in state \(s\).
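With \(\mathbf{P}\) and \(\mathbf{R}\) in this layout, one Bellman backup is a single einsum over next states. The sketch below is plain dense value iteration, not MDPax's ValueIteration. Its stopping rule, span(V_new − V) < epsilon(1 − gamma)/gamma, is our assumption, chosen to be consistent with the span values logged for the forest run above. The two-state MDP is a hypothetical toy.

```python
import jax
import jax.numpy as jnp


@jax.jit
def bellman_backup(P, R, V, gamma):
    """Q[s, a] = R[s, a] + gamma * sum_{s'} P[a, s, s'] * V[s']; returns max and argmax over a."""
    Q = R + gamma * jnp.einsum("ast,t->sa", P, V)
    return Q.max(axis=1), Q.argmax(axis=1)


def value_iteration(P, R, gamma=0.9, epsilon=0.01, max_iterations=1000):
    """Dense value iteration with an assumed span-based stopping rule."""
    threshold = epsilon * (1 - gamma) / gamma
    V = jnp.zeros(R.shape[0])
    for iteration in range(1, max_iterations + 1):
        V_new, policy = bellman_backup(P, R, V, gamma)
        span = jnp.max(V_new - V) - jnp.min(V_new - V)
        V = V_new
        if span < threshold:
            break
    return V, policy, iteration


# Toy MDP: action 0 stays in the current state, action 1 moves to the other state
P = jnp.array([[[1.0, 0.0], [0.0, 1.0]],   # action 0: stay
               [[0.0, 1.0], [1.0, 0.0]]])  # action 1: move
R = jnp.array([[1.0, 0.0],   # state 0: stay pays 1, move pays 0
               [2.0, 0.0]])  # state 1: stay pays 2, move pays 0
V, policy, n_iter = value_iteration(P, R)
print(V, policy, n_iter)
```

Here the loop stops after four iterations with the optimal policy (move from state 0, then stay in state 1), even though V is still far from the converged values (18, 20). A span-based rule targets a near-optimal greedy policy rather than accurate absolute values.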

The MDPax Problem class has a built-in method for constructing these matrices based on the functions describing the problem, which we can use to construct \(\mathbf{P}\) and \(\mathbf{R}\) and solve the forest management problem using pymdptoolbox to check our solution.
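The internals of that built-in method are not shown here. As a minimal sketch of the general idea (an assumption, not MDPax's code), \(\mathbf{P}\) and \(\mathbf{R}\) can be assembled from a functional description by evaluating the deterministic transition for every (state, action, event) combination with jax.vmap and weighting the results by the event probabilities. The functions event_probabilities and transition are hypothetical names for the forest problem, using the same reward conventions as pymdptoolbox's forest example.

```python
import jax
import jax.numpy as jnp

N_STATES, N_ACTIONS, N_EVENTS = 3, 2, 2  # events: 0 = no fire, 1 = fire
R1, R2, P_FIRE = 4.0, 2.0, 0.1


def event_probabilities(state, action):
    """P(event | state, action); for the forest it depends on neither."""
    return jnp.array([1.0 - P_FIRE, P_FIRE])


def transition(state, action, event):
    """Deterministic (next_state, reward) given a state, action and random event."""
    oldest = N_STATES - 1
    cut = action == 1
    next_state = jnp.where(cut | (event == 1), 0, jnp.minimum(state + 1, oldest))
    cut_reward = jnp.where(state == oldest, R2, jnp.where(state == 0, 0.0, 1.0))
    wait_reward = jnp.where(state == oldest, R1, 0.0)
    return next_state, jnp.where(cut, cut_reward, wait_reward)


def build_matrices():
    """Return P with shape (A, S, S) and R with shape (S, A)."""
    states = jnp.arange(N_STATES)
    actions = jnp.arange(N_ACTIONS)
    events = jnp.arange(N_EVENTS)

    def row(state, action):
        probs = event_probabilities(state, action)
        next_states, rewards = jax.vmap(transition, in_axes=(None, None, 0))(state, action, events)
        # Scatter-add each event's probability onto the state it leads to
        p_row = jnp.zeros(N_STATES).at[next_states].add(probs)
        # Expected reward over events
        return p_row, jnp.dot(probs, rewards)

    per_state = jax.vmap(row, in_axes=(0, None))         # vectorize over states
    per_action = jax.vmap(per_state, in_axes=(None, 0))  # then over actions
    P, R_by_action = per_action(states, actions)         # shapes (A, S, S) and (A, S)
    return P, R_by_action.T


P, R = build_matrices()
print(P)
print(R)
```

Because the number of random events per (state, action) pair is usually much smaller than the number of states, this functional description is far more compact than the matrices it generates.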

[56]:
# Get transition and reward matrices for comparison
P, R = problem.build_transition_and_reward_matrices()
# Convert to numpy arrays for pymdptoolbox
P = np.array(P)
R = np.array(R)
[57]:
# Solve with pymdptoolbox
vi = mdptoolbox.mdp.ValueIteration(P, R, discount=0.9, epsilon=0.01)
vi.run()
[58]:
# Extract the solution from pymdptoolbox class
toolbox_values = vi.V
toolbox_policy = vi.policy
toolbox_iteration = vi.iter


# Print the solution from pymdptoolbox
print("\npymdptoolbox Solution:")
print("--------------------")
print(f"Values:\n{np.round(toolbox_values, 4)}")
print(f"Policy:\n{toolbox_policy}")
print(f"Iterations to converge: {toolbox_iteration}")

pymdptoolbox Solution:
--------------------
Values:
[ 5.052  8.292 12.292]
Policy:
(0, 0, 0)
Iterations to converge: 4
[59]:
# Verify solutions match
print("\nSolutions match?")
print(f"Values close?: {np.allclose(mdpax_values.flatten(), toolbox_values, rtol=1e-2)}")
print(f"Policies match?: {np.array_equal(mdpax_policy.flatten(), np.array(toolbox_policy))}")
print(f"Number of iterations match?: {mdpax_iteration == toolbox_iteration}")

Solutions match?
Values close?: True
Policies match?: True
Number of iterations match?: True

For a problem this small, pymdptoolbox is likely to be faster than MDPax because of the cost of transferring data to and from the GPU and the upfront cost of JIT compilation in JAX.
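The compilation overhead is easy to observe: the first call to a jax.jit-compiled function includes tracing and compilation, while later calls with the same input shapes and dtypes reuse the cached executable. This is a minimal timing sketch with a hypothetical toy_update function; absolute numbers depend on the hardware.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_update(v):
    # Stand-in for a Bellman-style update: max over three "actions" per state
    return jnp.max(v[:, None] + jnp.arange(3.0), axis=1)


def time_call(fn, x):
    start = time.perf_counter()
    # JAX dispatches work asynchronously, so wait for the result before stopping the clock
    fn(x).block_until_ready()
    return time.perf_counter() - start


x = jax.device_put(jnp.arange(1_000.0))  # explicit transfer to the default device
first = time_call(toy_update, x)   # includes tracing and XLA compilation
second = time_call(toy_update, x)  # reuses the compiled executable
print(f"first call: {first * 1e3:.2f} ms, second call: {second * 1e3:.2f} ms")
```

For large problems the one-off compilation cost is amortized over many iterations, which is where a GPU pays off.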

A larger problem: perishable inventory management with substitution

Perishable inventory management problems are known to be computationally challenging to solve exactly (e.g. with value iteration) because the state must represent the age-profile of the stock (how many units of each age are held) and therefore the size of the state space grows exponentially with the maximum useful life of the product.
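The growth can be made concrete by counting combinations. The sketch below reproduces the state and action counts printed for the three configurations used in this notebook, under the assumption (inferred from those printed counts, not from the MDPax source) that the state holds one stock level per product per age class, each between 0 and the maximum order quantity.

```python
def n_actions(max_order_a, max_order_b):
    # Every combination of order quantities 0..max for each product
    return (max_order_a + 1) * (max_order_b + 1)


def n_states(max_useful_life, max_order_a, max_order_b):
    # Assumed state: one stock level per product per age class,
    # so the count is exponential in max_useful_life
    return n_actions(max_order_a, max_order_b) ** max_useful_life


for m, qa, qb in [(2, 4, 4), (2, 14, 6), (3, 20, 4)]:
    print(f"useful life {m}, max orders ({qa}, {qb}): "
          f"{n_states(m, qa, qb):,} states, {n_actions(qa, qb)} actions")
```

Under this assumption, increasing the useful life by one period multiplies the number of states by the number of actions (105 for the larger cases).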

In this example, we consider a perishable inventory management problem introduced by Hendrix et al. (2019).

The decision maker must place a replenishment order each day for two perishable products, product A and product B. Orders are placed in the morning and arrive immediately before the start of the next period. Demand for each product each day is random. Some customers who want product B may accept product A, and substitutions are made once demand for product A has been met as far as possible.

The goal is to maximise average daily profit (sales revenue less an ordering cost per unit), so we use relative value iteration, which does not discount future rewards, as the Solver.
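For readers new to the average-reward setting, here is a minimal textbook version of relative value iteration, not MDPax's RelativeValueIteration. Each undiscounted backup is followed by subtracting the value of a reference state to keep the values bounded. The minimum and maximum of (Th − h) bound the optimal gain, so we stop when their span falls below epsilon and report the midpoint. Whether MDPax computes its logged gain this way is an assumption. Convergence of this scheme relies on conditions such as aperiodicity, and the two-state MDP is a hypothetical toy.

```python
import jax
import jax.numpy as jnp


@jax.jit
def bellman(P, R, h):
    # Undiscounted backup: Q[s, a] = R[s, a] + sum_{s'} P[a, s, s'] * h[s']
    Q = R + jnp.einsum("ast,t->sa", P, h)
    return Q.max(axis=1), Q.argmax(axis=1)


def relative_value_iteration(P, R, epsilon=1e-4, ref_state=0, max_iterations=10_000):
    h = jnp.zeros(R.shape[0])
    for iteration in range(1, max_iterations + 1):
        Th, policy = bellman(P, R, h)
        diff = Th - h
        lower, upper = diff.min(), diff.max()  # bounds on the optimal gain
        if upper - lower < epsilon:
            break
        # Subtract the reference state's value so h stays bounded
        h = Th - Th[ref_state]
    gain = (lower + upper) / 2
    return gain, h, policy, iteration


# Toy MDP: action 0 stays in the current state, action 1 moves to the other state;
# staying in state 1 pays 2 per step, so the best long-run average reward is 2
P = jnp.array([[[1.0, 0.0], [0.0, 1.0]],
               [[0.0, 1.0], [1.0, 0.0]]])
R = jnp.array([[1.0, 0.0],
               [2.0, 0.0]])
gain, h, policy, n_iter = relative_value_iteration(P, R)
print(f"gain: {float(gain):.4f}, policy: {policy}, iterations: {n_iter}")
```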

The smallest example considered by Hendrix et al. has 11,025 states and 105 actions. Since \(\mathbf{P}\) has dimensions (n_actions, n_states, n_states), the transition matrix for this problem would have \(105 \times 11{,}025 \times 11{,}025 \approx 1.28 \times 10^{10}\) (about 13 billion) elements. Just storing this matrix as 64-bit floats would require over 100GB of RAM!
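The memory arithmetic for a dense (n_actions, n_states, n_states) transition matrix stored as 64-bit floats (8 bytes per element), ignoring sparsity, for the problem sizes used in this notebook:

```python
def dense_transition_bytes(n_states, n_actions, bytes_per_element=8):
    # P has shape (n_actions, n_states, n_states)
    return n_actions * n_states**2 * bytes_per_element


for label, n_states, n_actions in [
    ("smaller version", 625, 25),
    ("smallest case in Hendrix et al.", 11_025, 105),
    ("Case 2", 1_157_625, 105),
]:
    n_bytes = dense_transition_bytes(n_states, n_actions)
    print(f"{label}: {n_bytes / 1e9:,.2f} GB")
```

The smaller version needs only about 78MB, which is why it can be compared against pymdptoolbox, while Case 2 would need over 1PB.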

To begin with, so that we can compare our results with pymdptoolbox, we’ll look at a smaller version of the problem with 625 states and 25 actions.

A note on sparsity: The comments on the size of the transition matrices in this introductory notebook do not take potential sparsity into account. pymdptoolbox supports sparse arrays, which would reduce the memory required to represent the transition matrices. Support for sparse arrays in JAX is currently experimental. We may investigate the potential benefits of sparsity in a future release.

Basic case

[60]:
from mdpax.problems.perishable_inventory.hendrix_two_product import (
    HendrixTwoProductPerishable,
)
from mdpax.solvers.relative_value_iteration import RelativeValueIteration

problem = HendrixTwoProductPerishable(max_useful_life = 2, # Products can be used for 2 periods after arrival, then are discarded
                                      demand_poisson_mean_a=2, # Mean daily demand for product A
                                      demand_poisson_mean_b=2, # Mean daily demand for product B
                                      max_order_quantity_a=4, # Maximum order quantity for product A
                                      max_order_quantity_b=4) # Maximum order quantity for product B

print(f"Number of states: {problem.n_states}")
print(f"Number of actions: {problem.n_actions}")
Number of states: 625
Number of actions: 25

As above, we’ll first solve the problem using MDPax.

[61]:
solver = RelativeValueIteration(problem, epsilon=1e-4)
solution = solver.solve()
mdpax_values = solution.values
mdpax_policy = solution.policy
mdpax_average_daily_profit = solution.info.gain
2025-01-05 21:41:02.977 | INFO     | mdpax.core.solver:__init__:159 - Solver initialized with hendrix_two_product problem
2025-01-05 21:41:03.688 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:123 - Checkpointing not enabled
2025-01-05 21:41:04.398 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 1: span: 2.84548, gain: 7.8527
2025-01-05 21:41:04.415 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 2: span: 0.60450, gain: 1.0389
2025-01-05 21:41:04.428 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 3: span: 0.10172, gain: 1.5746
2025-01-05 21:41:04.440 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 4: span: 0.07247, gain: 1.5865
2025-01-05 21:41:04.454 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 5: span: 0.03349, gain: 1.5201
2025-01-05 21:41:04.467 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 6: span: 0.01414, gain: 1.5402
2025-01-05 21:41:04.481 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 7: span: 0.00648, gain: 1.5447
2025-01-05 21:41:04.503 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 8: span: 0.00331, gain: 1.5442
2025-01-05 21:41:04.515 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 9: span: 0.00229, gain: 1.5417
2025-01-05 21:41:04.530 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 10: span: 0.00091, gain: 1.5439
2025-01-05 21:41:04.542 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 11: span: 0.00045, gain: 1.5435
2025-01-05 21:41:04.553 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 12: span: 0.00028, gain: 1.5432
2025-01-05 21:41:04.565 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 13: span: 0.00011, gain: 1.5434
2025-01-05 21:41:04.579 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 14: span: 0.00009, gain: 1.5435
2025-01-05 21:41:04.581 | INFO     | mdpax.solvers.relative_value_iteration:solve:174 - Convergence threshold reached at iteration 14
2025-01-05 21:41:04.584 | INFO     | mdpax.solvers.relative_value_iteration:solve:193 - Extracting policy
2025-01-05 21:41:05.282 | INFO     | mdpax.solvers.relative_value_iteration:solve:195 - Policy extracted
2025-01-05 21:41:05.282 | SUCCESS  | mdpax.solvers.relative_value_iteration:solve:197 - Relative value iteration completed

The gain converges to the mean reward per timestep, so we can see that the mean profit per day is $1.54.
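To see why the gain is a mean reward per timestep: for a fixed policy whose Markov chain has a unique stationary distribution \(\pi\), the gain is the stationary average reward \(g = \sum_s \pi(s)\, r(s)\). The sketch below computes this for a hypothetical two-state chain by solving \(\pi(\mathbf{P}_\pi - I) = 0\) with \(\sum_s \pi(s) = 1\). It only illustrates the definition; it is not how MDPax computes solution.info.gain.

```python
import jax.numpy as jnp


def policy_gain(P_pi, r_pi):
    """Long-run average reward of a fixed policy with a unique stationary distribution."""
    n = P_pi.shape[0]
    # Stack the balance equations (P_pi - I)^T pi = 0 with the constraint sum(pi) = 1
    A = jnp.vstack([(P_pi - jnp.eye(n)).T, jnp.ones((1, n))])
    b = jnp.concatenate([jnp.zeros(n), jnp.ones(1)])
    pi, *_ = jnp.linalg.lstsq(A, b)
    return pi, jnp.dot(pi, r_pi)


P_pi = jnp.array([[0.9, 0.1],
                  [0.5, 0.5]])   # transition matrix under the fixed policy
r_pi = jnp.array([1.0, 7.0])    # expected reward per step in each state
pi, gain = policy_gain(P_pi, r_pi)  # pi is approximately [5/6, 1/6], gain approximately 2.0
print(pi, gain)
```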

Because this is a small problem, we can also build \(\mathbf{P}\) and \(\mathbf{R}\) and solve the problem using pymdptoolbox. For this problem, the value function is initialized using the one-step ahead revenue. The RelativeValueIteration solver calculates this automatically from the Problem, but we need to provide it manually to pymdptoolbox.

[62]:
# Compute P, R and initial values from Problem
P, R = problem.build_transition_and_reward_matrices()
initial_values = jax.vmap(problem.initial_value)(problem.state_space)
P = np.array(P)
R = np.array(R)
initial_values = np.array(initial_values)
[63]:
# Run relative value iteration with pymdptoolbox, starting from the same initial values
rvi = mdptoolbox.mdp.RelativeValueIteration(P, R, epsilon=1e-4)
rvi.V = initial_values
rvi.run()

# Extract the solution from pymdptoolbox class
toolbox_values = rvi.V
toolbox_policy = rvi.policy
toolbox_average_daily_profit = rvi.average_reward

pymdptoolbox gives us the index of the best action. This was fine in the forest example, where the actions are not numeric and are identified only by an index. In this example, where an action is an order quantity for each of product A and product B, we need to look up the actual action in order to compare it with the policy from MDPax.

We can do this by indexing into the action_space attribute of our Problem. See our tutorial on implementing your own problem for more information on state, action and event spaces.
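The lookup itself is ordinary array indexing. In this minimal sketch the action space is a hypothetical stand-in built by enumerating every (order quantity A, order quantity B) pair with the B quantity varying fastest; the actual row order of MDPax's action_space may differ, which is why the notebook indexes into problem.action_space rather than reconstructing it.

```python
import jax.numpy as jnp

max_order_a, max_order_b = 4, 4

# Hypothetical action space: one row per (order_a, order_b) pair, order_b varying fastest
qa, qb = jnp.meshgrid(jnp.arange(max_order_a + 1), jnp.arange(max_order_b + 1), indexing="ij")
action_space = jnp.stack([qa.ravel(), qb.ravel()], axis=1)  # shape (25, 2)

# Action indices such as a tabular solver might return (one per state)
action_indices = jnp.array([0, 6, 7, 24])
actions = action_space.take(action_indices, axis=0)  # rows of (order_a, order_b)
print(actions)
```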

[64]:
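# Replace each action index with the corresponding row of the action space (order quantities for A and B)
toolbox_policy = problem.action_space.take(jnp.array(toolbox_policy), axis=0)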
[65]:
# Verify solutions match
print("\nSolutions match?")
print(f"Mean daily profit matches?: {np.allclose(mdpax_average_daily_profit, toolbox_average_daily_profit, rtol=1e-2)}")
print(f"Values close?: {np.allclose(mdpax_values.flatten(), toolbox_values, rtol=1e-2)}")
print(f"Policies match?: {np.array_equal(mdpax_policy, toolbox_policy)}")

Solutions match?
Mean daily profit matches?: True
Values close?: True
Policies match?: True

Larger cases

We’ll now consider two larger versions of the problem which were included in Hendrix et al. (2019).

Case 1

[66]:
problem = HendrixTwoProductPerishable(max_useful_life = 2, # Products can be used for 2 periods after arrival, then are discarded
                                      demand_poisson_mean_a=7, # Mean daily demand for product A
                                      demand_poisson_mean_b=3, # Mean daily demand for product B
                                      max_order_quantity_a=14, # Maximum order quantity for product A
                                      max_order_quantity_b=6) # Maximum order quantity for product B

print(f"Number of states: {problem.n_states}")
print(f"Number of actions: {problem.n_actions}")
Number of states: 11025
Number of actions: 105
[67]:
solver = RelativeValueIteration(problem, epsilon=1e-4)
solution = solver.solve()
2025-01-05 21:41:15.889 | INFO     | mdpax.core.solver:__init__:159 - Solver initialized with hendrix_two_product problem
2025-01-05 21:41:16.582 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:123 - Checkpointing not enabled
2025-01-05 21:41:17.612 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 1: span: 6.51383, gain: 19.9606
2025-01-05 21:41:17.682 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 2: span: 1.24455, gain: 3.4498
2025-01-05 21:41:17.751 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 3: span: 0.30959, gain: 4.3591
2025-01-05 21:41:17.813 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 4: span: 0.17163, gain: 4.6679
2025-01-05 21:41:17.875 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 5: span: 0.03459, gain: 4.5060
2025-01-05 21:41:17.936 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 6: span: 0.01877, gain: 4.5114
2025-01-05 21:41:18.001 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 7: span: 0.00573, gain: 4.5271
2025-01-05 21:41:18.064 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 8: span: 0.00128, gain: 4.5217
2025-01-05 21:41:18.127 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 9: span: 0.00083, gain: 4.5220
2025-01-05 21:41:18.189 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 10: span: 0.00027, gain: 4.5227
2025-01-05 21:41:18.252 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 11: span: 0.00010, gain: 4.5226
2025-01-05 21:41:18.313 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 12: span: 0.00006, gain: 4.5225
2025-01-05 21:41:18.314 | INFO     | mdpax.solvers.relative_value_iteration:solve:174 - Convergence threshold reached at iteration 12
2025-01-05 21:41:18.316 | INFO     | mdpax.solvers.relative_value_iteration:solve:193 - Extracting policy
2025-01-05 21:41:19.149 | INFO     | mdpax.solvers.relative_value_iteration:solve:195 - Policy extracted
2025-01-05 21:41:19.149 | SUCCESS  | mdpax.solvers.relative_value_iteration:solve:197 - Relative value iteration completed

This takes less than 20s [1] on a Google Colab GPU instance, including setting up the problem, compared with the 206s reported by Hendrix et al. for their implementation using MATLAB on the CPU.

Case 2

Hendrix et al. reported that the largest problem they were able to solve within a week had 1.2Mn states, and took 80 hours. Using MDPax on a Google Colab GPU instance it should take less than 3 minutes [2].

Storing the transition matrix for this problem as 64-bit floats would require over 1PB (or 1Mn GB) of RAM!

[68]:
problem = HendrixTwoProductPerishable(max_useful_life = 3, # Products can be used for 3 periods after arrival, then are discarded
                                      demand_poisson_mean_a=7, # Mean daily demand for product A
                                      demand_poisson_mean_b=3, # Mean daily demand for product B
                                      max_order_quantity_a=20, # Maximum order quantity for product A
                                      max_order_quantity_b=4) # Maximum order quantity for product B

print(f"Number of states: {problem.n_states}")
print(f"Number of actions: {problem.n_actions}")
Number of states: 1157625
Number of actions: 105
[69]:
solver = RelativeValueIteration(problem, epsilon=1e-4)
solution = solver.solve()
2025-01-05 21:41:25.660 | INFO     | mdpax.core.solver:__init__:159 - Solver initialized with hendrix_two_product problem
2025-01-05 21:41:28.899 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:123 - Checkpointing not enabled
2025-01-05 21:41:41.182 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 1: span: 6.54445, gain: 19.9912
2025-01-05 21:41:50.995 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 2: span: 6.27677, gain: 9.7235
2025-01-05 21:42:00.809 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 3: span: 1.48057, gain: 3.4468
2025-01-05 21:42:10.537 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 4: span: 0.57104, gain: 4.3561
2025-01-05 21:42:20.258 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 5: span: 0.09908, gain: 4.8482
2025-01-05 21:42:29.982 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 6: span: 0.09773, gain: 4.9156
2025-01-05 21:42:39.682 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 7: span: 0.02806, gain: 4.8388
2025-01-05 21:42:49.461 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 8: span: 0.01050, gain: 4.8225
2025-01-05 21:42:59.228 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 9: span: 0.00353, gain: 4.8284
2025-01-05 21:43:08.963 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 10: span: 0.00176, gain: 4.8307
2025-01-05 21:43:18.692 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 11: span: 0.00066, gain: 4.8297
2025-01-05 21:43:28.446 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 12: span: 0.00020, gain: 4.8293
2025-01-05 21:43:38.192 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 13: span: 0.00010, gain: 4.8293
2025-01-05 21:43:38.193 | INFO     | mdpax.solvers.relative_value_iteration:solve:174 - Convergence threshold reached at iteration 13
2025-01-05 21:43:38.195 | INFO     | mdpax.solvers.relative_value_iteration:solve:193 - Extracting policy
2025-01-05 21:43:48.803 | INFO     | mdpax.solvers.relative_value_iteration:solve:195 - Policy extracted
2025-01-05 21:43:48.803 | SUCCESS  | mdpax.solvers.relative_value_iteration:solve:197 - Relative value iteration completed

Checkpointing

Large problems can take a long time to solve, so MDPax supports checkpointing: if a run is interrupted, you can restart it from the most recent checkpoint.

Checkpointing is not enabled by default, because it is not very useful for smaller problems. You can activate it by setting checkpoint_frequency to a positive integer when instantiating a solver. The solver will then save a checkpoint every checkpoint_frequency iterations, and again once it meets the convergence threshold.
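The underlying idea can be sketched independently of MDPax: periodically write the iteration counter and current values to disk, then reload the latest file and carry on. This is a generic NumPy sketch with a toy update standing in for a Bellman backup; save_checkpoint, load_latest_checkpoint and solve are hypothetical names, and the sketch does not reproduce MDPax's checkpoint format or its restore API.

```python
import os
import tempfile

import numpy as np


def save_checkpoint(directory, iteration, values):
    os.makedirs(directory, exist_ok=True)
    # Zero-padded file names sort in iteration order
    path = os.path.join(directory, f"checkpoint_{iteration:06d}.npz")
    np.savez(path, iteration=iteration, values=values)


def load_latest_checkpoint(directory):
    files = sorted(f for f in os.listdir(directory) if f.startswith("checkpoint_"))
    with np.load(os.path.join(directory, files[-1])) as data:
        return int(data["iteration"]), data["values"]


def solve(values, start_iteration=0, max_iterations=100, tol=1e-6,
          checkpoint_dir=None, checkpoint_frequency=0):
    iteration = start_iteration
    while iteration < max_iterations:
        new_values = 0.5 * values + 1.0  # toy contraction with fixed point 2.0
        iteration += 1
        converged = np.max(np.abs(new_values - values)) < tol
        values = new_values
        # Save every checkpoint_frequency iterations, and once more on convergence
        if checkpoint_dir and checkpoint_frequency > 0 and (
            iteration % checkpoint_frequency == 0 or converged
        ):
            save_checkpoint(checkpoint_dir, iteration, values)
        if converged:
            break
    return iteration, values


# Reference: an uninterrupted run
full_iteration, full_values = solve(np.zeros(3))

with tempfile.TemporaryDirectory() as checkpoint_dir:
    # "Interrupted" run: stop after 5 iterations, checkpointing every iteration
    solve(np.zeros(3), max_iterations=5, checkpoint_dir=checkpoint_dir, checkpoint_frequency=1)
    restored_iteration, restored_values = load_latest_checkpoint(checkpoint_dir)

# Resume from the checkpoint and continue to convergence
resumed_iteration, resumed_values = solve(restored_values, start_iteration=restored_iteration)
print(full_iteration, resumed_iteration, np.array_equal(full_values, resumed_values))
```

Because the update is deterministic, resuming from the saved values reproduces the uninterrupted run exactly.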

Let’s start by running a problem to convergence to get a reference policy.

[70]:
problem = HendrixTwoProductPerishable(max_useful_life = 2,
                                      demand_poisson_mean_a=5,
                                      demand_poisson_mean_b=5,
                                      max_order_quantity_a=10,
                                      max_order_quantity_b=10)
solver_a = RelativeValueIteration(problem, epsilon=1e-4)
result_from_full_run = solver_a.solve()
2025-01-05 21:43:50.218 | INFO     | mdpax.core.solver:__init__:159 - Solver initialized with hendrix_two_product problem
2025-01-05 21:43:50.850 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:123 - Checkpointing not enabled
2025-01-05 21:43:51.852 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 1: span: 6.52970, gain: 19.9641
2025-01-05 21:43:51.996 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 2: span: 1.23312, gain: 3.4442
2025-01-05 21:43:52.098 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 3: span: 0.29569, gain: 4.3579
2025-01-05 21:43:52.191 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 4: span: 0.18104, gain: 4.6498
2025-01-05 21:43:52.284 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 5: span: 0.04119, gain: 4.4774
2025-01-05 21:43:52.374 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 6: span: 0.01627, gain: 4.4977
2025-01-05 21:43:52.466 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 7: span: 0.00502, gain: 4.5060
2025-01-05 21:43:52.557 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 8: span: 0.00201, gain: 4.5021
2025-01-05 21:43:52.648 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 9: span: 0.00060, gain: 4.5032
2025-01-05 21:43:52.738 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 10: span: 0.00040, gain: 4.5033
2025-01-05 21:43:52.830 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 11: span: 0.00016, gain: 4.5031
2025-01-05 21:43:52.919 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 12: span: 0.00007, gain: 4.5032
2025-01-05 21:43:52.921 | INFO     | mdpax.solvers.relative_value_iteration:solve:174 - Convergence threshold reached at iteration 12
2025-01-05 21:43:52.923 | INFO     | mdpax.solvers.relative_value_iteration:solve:193 - Extracting policy
2025-01-05 21:43:53.714 | INFO     | mdpax.solvers.relative_value_iteration:solve:195 - Policy extracted
2025-01-05 21:43:53.715 | SUCCESS  | mdpax.solvers.relative_value_iteration:solve:197 - Relative value iteration completed

Now, let’s imagine the run was interrupted. To mimic this, we set max_iterations to fewer than the number of iterations required for convergence and activate checkpointing. We’ll save a checkpoint every iteration, keep only the most recent checkpoint, and save checkpoints in the directory checkpoints/getting_started/incomplete_run.

[71]:
solver_b = RelativeValueIteration(problem, epsilon=1e-4, checkpoint_frequency=1, checkpoint_dir="checkpoints/getting_started/incomplete_run")
result_from_incomplete_run = solver_b.solve(max_iterations=5)
2025-01-05 21:43:53.725 | INFO     | mdpax.core.solver:__init__:159 - Solver initialized with hendrix_two_product problem
2025-01-05 21:43:54.310 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:143 - Full checkpointing enabled with problem and solver reconstruction
2025-01-05 21:43:54.310 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:152 - Saving checkpoints every 1 iteration(s) to /home/joefarrington/other_learning/mdpax/examples/checkpoints/getting_started/incomplete_run
2025-01-05 21:43:55.201 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 1: span: 6.52970, gain: 19.9641
2025-01-05 21:43:55.341 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 2: span: 1.23312, gain: 3.4442
2025-01-05 21:43:55.437 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 3: span: 0.29569, gain: 4.3579
2025-01-05 21:43:55.528 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 4: span: 0.18104, gain: 4.6498
2025-01-05 21:43:55.620 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 5: span: 0.04119, gain: 4.4774
2025-01-05 21:43:55.623 | INFO     | mdpax.solvers.relative_value_iteration:solve:186 - Maximum iterations reached
2025-01-05 21:43:55.624 | INFO     | mdpax.solvers.relative_value_iteration:solve:193 - Extracting policy
2025-01-05 21:43:56.419 | INFO     | mdpax.solvers.relative_value_iteration:solve:195 - Policy extracted
2025-01-05 21:43:56.420 | SUCCESS  | mdpax.solvers.relative_value_iteration:solve:197 - Relative value iteration completed
[72]:
solver_c = RelativeValueIteration.restore(checkpoint_dir="checkpoints/getting_started/incomplete_run", new_checkpoint_dir="checkpoints/getting_started/continued_run")
2025-01-05 21:43:58.150 | INFO     | mdpax.core.solver:__init__:159 - Solver initialized with hendrix_two_product problem
2025-01-05 21:43:58.781 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:143 - Full checkpointing enabled with problem and solver reconstruction
2025-01-05 21:43:58.781 | INFO     | mdpax.utils.checkpointing:_setup_checkpointing:152 - Saving checkpoints every 1 iteration(s) to /home/joefarrington/other_learning/mdpax/examples/checkpoints/getting_started/continued_run
[73]:
print(f"Values restored correctly: {np.all(solver_c.values == result_from_incomplete_run.values)}")
print(f"Iteration restored correctly: {solver_c.iteration == result_from_incomplete_run.info.iteration}")
Values restored correctly: True
Iteration restored correctly: True
[74]:
result_from_continued_run = solver_c.solve()
2025-01-05 21:43:59.775 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 6: span: 0.01627, gain: 4.4977
2025-01-05 21:43:59.887 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 7: span: 0.00502, gain: 4.5060
2025-01-05 21:43:59.981 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 8: span: 0.00201, gain: 4.5021
2025-01-05 21:44:00.076 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 9: span: 0.00060, gain: 4.5032
2025-01-05 21:44:00.165 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 10: span: 0.00040, gain: 4.5033
2025-01-05 21:44:00.262 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 11: span: 0.00016, gain: 4.5031
2025-01-05 21:44:00.353 | INFO     | mdpax.solvers.relative_value_iteration:solve:169 - Iteration 12: span: 0.00007, gain: 4.5032
2025-01-05 21:44:00.354 | INFO     | mdpax.solvers.relative_value_iteration:solve:174 - Convergence threshold reached at iteration 12
2025-01-05 21:44:00.355 | INFO     | mdpax.solvers.relative_value_iteration:solve:193 - Extracting policy
2025-01-05 21:44:01.152 | INFO     | mdpax.solvers.relative_value_iteration:solve:195 - Policy extracted
2025-01-05 21:44:01.153 | SUCCESS  | mdpax.solvers.relative_value_iteration:solve:197 - Relative value iteration completed
[75]:
print(f"Policy from restored run same as full run: {np.all(result_from_continued_run.policy == result_from_full_run.policy)}")
Policy from restored run same as full run: True

Next Steps

  • Try the next tutorial to learn how to implement your own problems using MDPax

  • Read the MDPax documentation for detailed API reference

Footnotes

  • [1] Based on 50 runs on 2025-10-25 using a Tesla T4 GPU instance and Colab Runtime version 2025.10. Maximum runtime 19.2s, mean runtime 18.2s, standard deviation 0.4s.

  • [2] Based on 50 runs on 2025-10-25 using a Tesla T4 GPU instance and Colab Runtime version 2025.10. Maximum runtime 138.5s, mean runtime 135.3s, standard deviation 0.7s.