MDPax
MDPax is designed for researchers and practitioners who want to solve large Markov Decision Process (MDP) problems but don’t want to become experts in graphics processing unit (GPU) programming. By using JAX, we can take advantage of the massive parallel processing power of GPUs while describing new problems using a simple Python interface.
You can run MDPax on your local GPU, or try it for free using Google Colab, which provides access to GPUs in the cloud with no setup required.
Key capabilities
Solve MDPs with millions of states using value iteration or policy iteration
Automatic support for one or more identical GPUs
Flexible interface for defining your own MDP problem or solver algorithm
Asynchronous checkpointing using Orbax
Ready-to-use examples including perishable inventory problems from recent literature
Overview
MDPax is a Python package for solving large-scale MDPs, leveraging JAX’s support for vectorization, parallelization, and just-in-time (JIT) compilation on GPUs.
The package is adapted from the research code developed for Farrington et al. (2025) (a preprint was released in 2023). We demonstrated that this approach is particularly well-suited to perishable inventory management problems, where the state space grows exponentially with the number of products and the maximum useful life of the products. By implementing the problems in JAX and using consumer-grade GPUs (or freely available GPUs on services such as Google Colab), it is possible to compute the exact solution for realistically sized perishable inventory problems for which this had recently been reported to be infeasible or impractical.
Traditional value iteration implementations face two main challenges with large state spaces:
Memory requirements - a full transition matrix stores a probability for every combination of state, action, and next state, so its size grows with the square of the number of states
Computational complexity - nested loops over states, actions, and possible next states become prohibitively expensive
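To make the memory point concrete, the short calculation below estimates the size of a dense transition array P[a, s, s'] stored as 32-bit floats and compares it with the value vector alone. The function names and the choice of float32 are illustrative assumptions, not part of MDPax.

```python
def dense_transition_bytes(n_states: int, n_actions: int, bytes_per_entry: int = 4) -> int:
    """Size of a dense P[a, s, s'] array: one entry per (action, state, next state)."""
    return n_actions * n_states * n_states * bytes_per_entry


def value_vector_bytes(n_states: int, bytes_per_entry: int = 4) -> int:
    """Size of the value function: one entry per state."""
    return n_states * bytes_per_entry


n_states, n_actions = 1_000_000, 2
print(f"Dense transition array: {dense_transition_bytes(n_states, n_actions) / 1e12:.0f} TB")  # 8 TB
print(f"Value vector:           {value_vector_bytes(n_states) / 1e6:.0f} MB")  # 4 MB
```

With a million states and only two actions, the dense array needs about 8 TB, far more than any single GPU provides, while the value vector needs about 4 MB.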
MDPax addresses these challenges by:
Using a functional approach where users specify a deterministic transition function in terms of state, action, and random event, rather than providing the full transition matrix
Leveraging JAX’s transformations to optimize computation:
vmap to vectorize operations across states and actions
pmap to parallelize across multiple GPU devices where available
jit to compile operations once and reuse them efficiently across many value iteration steps
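The approach in the list above can be sketched in plain JAX. This is not MDPax's implementation or API: it re-expresses the forest problem from the Quick Start as a deterministic transition function of (state, action, random event), assuming pymdptoolbox's default parameters (3 states, r1 = 4, r2 = 2, fire probability 0.1). Nested vmap calls evaluate every state-action-event combination, and jit compiles the Bellman backup once. pmap is omitted to keep it short, and the sup-norm stopping rule is a common textbook choice that may differ from the one MDPax uses.

```python
import jax
import jax.numpy as jnp

# Forest problem parameters (assumed: pymdptoolbox defaults).
N_STATES, R1, R2, P_FIRE = 3, 4.0, 2.0, 0.1
WAIT, CUT = 0, 1
NO_FIRE, FIRE = 0, 1

states = jnp.arange(N_STATES)
actions = jnp.array([WAIT, CUT])
events = jnp.array([NO_FIRE, FIRE])
event_probs = jnp.array([1.0 - P_FIRE, P_FIRE])


def transition(state, action, event):
    """Deterministic next state and reward for one (state, action, event)."""
    grown = jnp.minimum(state + 1, N_STATES - 1)
    # Waiting without a fire ages the forest; a fire or a cut resets it to state 0.
    next_state = jnp.where((action == WAIT) & (event == NO_FIRE), grown, 0)
    oldest = state == N_STATES - 1
    wait_reward = jnp.where(oldest, R1, 0.0)
    cut_reward = jnp.where(oldest, R2, jnp.where(state == 0, 0.0, 1.0))
    reward = jnp.where(action == WAIT, wait_reward, cut_reward)
    return next_state, reward


def q_value(values, state, action, gamma):
    """Expected one-step return, averaging over the random event."""
    next_states, rewards = jax.vmap(transition, in_axes=(None, None, 0))(state, action, events)
    return jnp.sum(event_probs * (rewards + gamma * values[next_states]))


@jax.jit
def bellman_backup(values, gamma):
    """One value iteration step over all states and actions at once."""
    q_over_actions = jax.vmap(q_value, in_axes=(None, None, 0, None))
    q = jax.vmap(q_over_actions, in_axes=(None, 0, None, None))(values, states, actions, gamma)
    return q.max(axis=1), q.argmax(axis=1)  # new values and greedy policy, each (n_states,)


def value_iteration(gamma=0.9, epsilon=0.01, max_iterations=500):
    values = jnp.zeros(N_STATES)
    threshold = epsilon * (1 - gamma) / (2 * gamma)
    for _ in range(max_iterations):
        new_values, policy = bellman_backup(values, gamma)
        delta = float(jnp.max(jnp.abs(new_values - values)))
        values = new_values
        if delta < threshold:
            break
    return values, policy


values, policy = value_iteration()
print(policy)  # [0 0 0]: "wait" in every state
print(values)
```

Because transition returns a single next state for each random event, the work per step scales with states times actions times events, and nothing of size states squared is ever stored.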
While MDPax can run on CPU or GPU hardware, it is specifically designed for large problems (millions of states) on GPU. For small to medium-sized problems, especially when running on CPU, existing packages like pymdptoolbox may be more efficient due to JAX’s JIT compilation overhead and GPU memory transfer costs. These overheads become negligible for larger problems where the benefits of parallelization and vectorization dominate.
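A quick way to see the compilation overhead mentioned above is to time the first and second calls of a jitted function: the first call traces and compiles, while later calls with the same input shapes and dtypes reuse the cached executable. The function below is an arbitrary stand-in for a solver step, and the timings you see will depend on your hardware.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_update(x):
    """Arbitrary element-wise work standing in for one solver step."""
    for _ in range(20):  # Python loop is unrolled once, at trace time
        x = jnp.tanh(x) + 0.5 * x
    return x


def timed_call(fn, x):
    start = time.perf_counter()
    out = fn(x).block_until_ready()  # wait for asynchronous dispatch to finish
    return out, time.perf_counter() - start


x = jnp.linspace(0.0, 1.0, 1_000_000)
first, t_first = timed_call(toy_update, x)    # includes tracing and compilation
second, t_second = timed_call(toy_update, x)  # reuses the compiled executable
print(f"first call: {t_first * 1e3:.1f} ms, second call: {t_second * 1e3:.1f} ms")
```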
Installation
MDPax can be installed from PyPI using pip:
pip install mdpax
The main dependencies are:
jax
chex
numpyro
orbax
loguru
hydra-core
jaxtyping
numpy
See pyproject.toml for the complete list of dependencies and version requirements.
GPU (recommended)
MDPax is designed for GPU-accelerated computation and works best on Linux systems with NVIDIA GPUs.
For GPU support, ensure your NVIDIA drivers and CUDA toolkit are compatible with JAX. See the JAX installation guide for details.
CPU only
MDPax will automatically fall back to CPU on Linux if no GPU is detected, though performance will be significantly slower for large problems. If CUDA libraries are installed but no GPU hardware is available, you may need to force CPU execution by setting:
export JAX_PLATFORMS=cpu
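If you prefer to force CPU execution from inside a Python script or notebook rather than the shell, the same environment variable can be set with os.environ, as long as this happens before JAX is first imported in the process. A minimal sketch:

```python
import os

# Set this before the first `import jax` in this process; setting it later may not take effect.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax  # noqa: E402

print(jax.devices())  # only CPU devices should be listed
```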
Windows/macOS: JAX does not currently support GPUs on Windows and only has experimental support for Apple GPUs on macOS. MDPax therefore uses CPU-only versions of JAX on these platforms, giving reduced performance.
Examples
If you want to run the example notebooks, install the additional dependencies with:
pip install "mdpax[examples]"
Google Colab
You can try MDPax without any local installation using Google Colab, which provides free GPU access. See our Getting Started notebook for an interactive introduction.
To use a GPU in Colab, click Runtime > Change runtime type and select a GPU as the hardware accelerator. You can confirm GPU availability by running !nvidia-smi in a code cell.
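nvidia-smi confirms that the runtime has a GPU; to check that JAX itself can see it, you can also run the following in a cell. jax.default_backend() reports the platform JAX will use by default, and jax.device_count() shows how many devices are visible, which matters when running across more than one GPU.

```python
import jax

print("Default backend:", jax.default_backend())  # "gpu" when JAX can use a CUDA GPU
print("Device count:", jax.device_count())
print("Devices:", jax.devices())
```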
Quick Start
The following example shows how to solve a simple forest management problem (adapted from pymdptoolbox’s example):
from mdpax.problems import Forest
from mdpax.solvers import ValueIteration
# Create forest management problem
problem = Forest()
# Create solver with discount factor gamma = 0.9,
# and convergence tolerance epsilon = 0.01
solver = ValueIteration(problem, gamma=0.9, epsilon=0.01)
# Solve the problem (automatically uses GPU if available)
solution = solver.solve(max_iterations=500)
# Access the optimal policy and value function
print(solution.policy) # array([[0], [0], [0]]) - "wait" for all states
print(solution.values) # value for each state under optimal policy
This example demonstrates the core workflow:
Create a problem instance
Initialize a solver
Solve to get the optimal policy and value function
Citation
If you use this software in your research, please cite our paper published in the Journal of Open Source Software (JOSS). Citation details are provided in the CITATION.cff file. You can also use the “Cite this repository” option in the About section of this GitHub repository to export the details in APA or BibTeX format.
If relevant to your work, please also consider citing the paper describing the original research from which MDPax was adapted:
@article{farrington_going_2025,
title = {Going faster to see further: graphics processing unit-accelerated value iteration and simulation for perishable inventory control using {JAX}},
url = {https://doi.org/10.1007/s10479-025-06551-6},
doi = {10.1007/s10479-025-06551-6},
journal = {Annals of Operations Research},
author = {Farrington, Joseph and Wong, Wai Keong and Li, Kezhi and Utley, Martin},
month = mar,
year = {2025},
}
License
MDPax is released under the MIT License. See the LICENSE file for details.
The forest management example problem is adapted from pymdptoolbox (BSD 3-Clause License, Copyright (c) 2011-2013 Steven A. W. Cordwell and Copyright (c) 2009 INRA). Our implementation is original, using the mdpax.core.problems.Problem class.
Acknowledgments
This library is based on research code developed during Joseph Farrington’s PhD at University College London under the supervision of Ken Li, Martin Utley, and Wai Keong Wong.
The PhD was generously supported by:
UKRI training grant EP/S021612/1, the CDT in AI-enabled Healthcare Systems
The Clinical and Research Informatics Unit at the NIHR University College London Hospitals Biomedical Research Centre