API Reference

Core

Problem

class mdpax.core.problem.Problem[source]

Abstract base class for MDP problems.

This class defines the interface for Markov Decision Process (MDP) problems. To implement a custom problem, subclass this class and implement the following:

Required Implementations:
  • name (property): Unique identifier for this problem type

  • state_to_index: Convert state vectors to indices

  • _construct_state_space: Define the full state space

  • _construct_action_space: Define the full action space

  • _construct_random_event_space: Define the space of random events

  • random_event_probability: Define transition probabilities

  • transition: Define state transitions and rewards

Optional Implementations:
  • initial_value: Custom initialization of value function (default: 0.0)

  • _setup_before_space_construction: Custom setup before space construction

  • _setup_after_space_construction: Custom setup after space construction

state_space

Array of shape [n_states, state_dim] containing all possible states

action_space

Array of shape [n_actions, action_dim] containing all possible actions

random_event_space

Array of shape [n_events, event_dim] containing all possible random events

n_states

Number of states in the problem

n_actions

Number of actions in the problem

n_random_events

Number of random events in the problem

name

A unique identifier for this problem type

Shape Requirements:
  • Single state: [state_dim]

  • Single action: [action_dim]

  • Single random event: [event_dim]

  • State space: [n_states, state_dim]

  • Action space: [n_actions, action_dim]

  • Random event space: [n_events, event_dim]
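The shape conventions above can be illustrated with a small sketch. It uses plain Python lists instead of JAX arrays so that it runs without any dependencies; the toy problem (stock level plus a weekday flag) and every name in it are hypothetical and not part of mdpax.

```python
from itertools import product

# Hypothetical toy problem: stock on hand (0..3) and a weekday flag (0..1).
MAX_STOCK = 3
N_WEEKDAY_FLAGS = 2

# State space: one row per state, one column per state component.
state_space = [list(s) for s in product(range(MAX_STOCK + 1), range(N_WEEKDAY_FLAGS))]
# Action space: order quantity 0..2, stored as length-1 vectors.
action_space = [[q] for q in range(3)]
# Random event space: demand 0..2, stored as length-1 vectors.
random_event_space = [[d] for d in range(3)]


def shape(rows):
    """Return (n_rows, row_length) for a rectangular list of lists."""
    return (len(rows), len(rows[0]))


print(shape(state_space))         # [n_states, state_dim]
print(shape(action_space))        # [n_actions, action_dim]
print(shape(random_event_space))  # [n_events, event_dim]
```

Even a scalar action or event is stored as a length-1 vector, so that every space stays two-dimensional.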

Note

All array operations should be implemented using JAX for compatibility with JIT compilation and vmap/pmap.

See also

For an interactive tutorial on how to implement a custom problem, see https://mdpax.readthedocs.io/en/latest/create_custom_problem.html

__init__()[source]

Initialize problem with all spaces and lookups constructed immediately.

abstract _construct_action_space() Float[Array, 'n_actions action_dim'][source]

Build an array of all possible actions.

Returns:

Array of shape [n_actions, action_dim] containing all possible actions

abstract _construct_random_event_space() Float[Array, 'n_events event_dim'][source]

Build an array of all possible random events.

Returns:

Array of shape [n_events, event_dim] containing all possible random events

abstract _construct_state_space() Float[Array, 'n_states state_dim'][source]

Build an array of all possible states.

Returns:

Array of shape [n_states, state_dim] containing all possible states

_setup_after_space_construction() None[source]

Setup operations run after constructing spaces.

_setup_before_space_construction() None[source]

Setup operations needed before constructing spaces.

property action_space: Float[Array, 'n_actions action_dim']

Array of shape [n_actions, action_dim] containing all possible actions

build_transition_and_reward_matrices(normalization_tolerance: float = 0.0001) tuple[Float[Array, 'n_actions n_states n_states'], Float[Array, 'n_states n_actions']][source]

Build transition and reward matrices for the MDP.

This method constructs the full transition probability and reward matrices for comparison with other solvers (e.g., mdptoolbox) on small problems. Not recommended for large state/action spaces.

The transition probability matrix P has shape [n_actions, n_states, n_states], where P[a, s, s'] is the probability of transitioning from state s to state s' under action a.

The reward matrix R has shape [n_states, n_actions], where R[s, a] is the expected immediate reward for taking action a in state s.

Parameters:

normalization_tolerance – If probabilities sum to within this tolerance of 1, adjust the largest probability to make them sum exactly to 1. Set to 0 to disable this behavior.

Returns:

Tuple containing the transition probability matrix [n_actions, n_states, n_states] and the expected reward matrix [n_states, n_actions]

Note

This method is primarily for testing and comparison purposes. It explicitly constructs the full transition matrices, which is impractical for large state spaces and is likely to exhaust available memory. The main solver implementations use the transition() method directly instead.
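As a rough illustration of what such a construction involves, the sketch below accumulates P and R by looping over every (action, state, random event) triple of a tiny hypothetical problem. It is written in plain Python rather than JAX, all functions and numbers are made up for the example, and the normalization step is only one plausible reading of the normalization_tolerance description, not mdpax's actual code.

```python
states = [[0], [1], [2]]  # hypothetical stock levels
actions = [[0], [1]]      # hypothetical order quantities
events = [[0], [1]]       # hypothetical demand values


def state_to_index(state):
    return state[0]


def event_probability(state, action, event):
    return 0.4 if event[0] == 0 else 0.6


def transition(state, action, event):
    stock = min(state[0] + action[0], 2)
    sold = min(stock, event[0])
    return [stock - sold], 3.0 * sold - 1.0 * action[0]


def build_matrices(normalization_tolerance=1e-4):
    n_s, n_a = len(states), len(actions)
    P = [[[0.0] * n_s for _ in range(n_s)] for _ in range(n_a)]  # P[a][s][s']
    R = [[0.0] * n_a for _ in range(n_s)]                        # R[s][a]
    for a, action in enumerate(actions):
        for s, state in enumerate(states):
            for event in events:
                p = event_probability(state, action, event)
                next_state, reward = transition(state, action, event)
                P[a][s][state_to_index(next_state)] += p
                R[s][a] += p * reward
            # Assumed normalization: nudge the largest entry so the row sums to 1.
            row = P[a][s]
            gap = 1.0 - sum(row)
            if 0.0 < abs(gap) <= normalization_tolerance:
                row[row.index(max(row))] += gap
    return P, R


P, R = build_matrices()
```

The nested loops visit n_actions × n_states × n_events combinations and P alone holds n_actions × n_states² entries, which is why this approach only suits small problems.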

initial_policy(state: Float[Array, 'state_dim']) Float[Array, 'action_dim'][source]

Get initial policy for a state.

By default, raises NotImplementedError to indicate no custom policy is defined. Can be overridden to provide a custom initial policy.

Parameters:

state – Current state vector [state_dim]

Returns:

Initial action vector [action_dim] for the state

Raises:

NotImplementedError – If no custom initial policy is defined

initial_value(state: Float[Array, 'state_dim']) float[source]

Return initial value estimate for a given state.

By default returns 0.0 for all states. Override this method to provide problem-specific initial value estimates.

Parameters:

state – State vector [state_dim]

Returns:

Initial value estimate for the given state

property n_actions: int

Number of actions in the problem.

property n_random_events: int

Number of random events in the problem.

property n_states: int

Number of states in the problem.

abstract property name: str

A unique identifier for this problem type

abstract random_event_probability(state: Float[Array, 'state_dim'], action: Float[Array, 'action_dim'], random_event: Float[Array, 'event_dim']) float[source]

Calculate probability of random event given state-action pair.

Parameters:
  • state – Current state vector [state_dim]

  • action – Action vector [action_dim]

  • random_event – Random event vector [event_dim]

Returns:

Probability of the random event occurring

Note

  • Probabilities must sum to 1 over all possible random events for each state-action pair

  • This method should be implemented to work efficiently with JAX vectorization over batches of states/actions and be compatible with JIT compilation
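The sum-to-one requirement can be checked directly on a small problem. The sketch below is a plain-Python stand-in (not a JAX implementation) with a hypothetical binomial demand model whose success probability depends on a weekday flag stored in the state; none of these names come from mdpax.

```python
from math import comb, isclose

random_event_space = [[d] for d in range(3)]  # hypothetical demand 0..2


def random_event_probability(state, action, random_event):
    """Binomial(2, p) demand, with p depending on a weekday flag in the state."""
    p = 0.3 if state[1] == 0 else 0.6
    d = random_event[0]
    return comb(2, d) * p**d * (1 - p) ** (2 - d)


def total_probability(state, action):
    """Sum the event probabilities for one state-action pair."""
    return sum(random_event_probability(state, action, e) for e in random_event_space)


# The total should be 1 for every state-action pair.
for state in ([1, 0], [1, 1]):
    assert isclose(total_probability(state, [2]), 1.0)
```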

property random_event_space: Float[Array, 'n_events event_dim']

Array of shape [n_events, event_dim] containing all possible random events

property state_space: Float[Array, 'n_states state_dim']

Array of shape [n_states, state_dim] containing all possible states

abstract state_to_index(state: Float[Array, 'state_dim']) int[source]

Convert state vector to index.

Parameters:

state – Vector representation of a state [state_dim]

Returns:

Index of the state in state_space

Note

This mapping must be consistent with the ordering in state_space
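One way to keep the mapping consistent is to derive both the state space and the index from the same ordering. The sketch below assumes a hypothetical state space built with itertools.product, for which a row-major (mixed-radix) index matches the row order; it uses plain Python integers, whereas a JAX implementation would need to handle float-typed state vectors.

```python
from itertools import product

DIMS = (4, 2)  # hypothetical: 4 stock levels x 2 weekday flags
state_space = [list(s) for s in product(*(range(n) for n in DIMS))]


def state_to_index(state):
    """Row-major (mixed-radix) index matching itertools.product ordering."""
    index = 0
    for value, size in zip(state, DIMS):
        index = index * size + value
    return index


# Consistency check: every state maps back to its own row in state_space.
assert all(state_to_index(s) == i for i, s in enumerate(state_space))
```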

abstract transition(state: Float[Array, 'state_dim'], action: Float[Array, 'action_dim'], random_event: Float[Array, 'event_dim']) tuple[Float[Array, 'state_dim'], float][source]

Compute next state and reward for a transition.

Parameters:
  • state – Current state vector [state_dim]

  • action – Action vector [action_dim]

  • random_event – Random event vector [event_dim]

Returns:

Tuple containing the next state vector [state_dim] and the immediate reward

Note

This method should be implemented to work efficiently with JAX vectorization over batches of states/actions and be compatible with JIT compilation
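To show how a (next state, reward) pair is typically consumed, the sketch below combines a hypothetical transition function with event probabilities to compute the expected return of one state-action pair. It is plain Python rather than JAX, and the inventory dynamics, prices, and helper names are assumptions made for the example.

```python
random_event_space = [[0], [1], [2]]  # hypothetical demand values


def transition(state, action, random_event):
    """Hypothetical inventory dynamics: order, then meet demand, capped at 3 units."""
    stock = min(state[0] + action[0], 3)
    sold = min(stock, random_event[0])
    reward = 5.0 * sold - 1.0 * action[0]
    return [stock - sold], reward


def random_event_probability(state, action, random_event):
    return {0: 0.25, 1: 0.5, 2: 0.25}[random_event[0]]


def expected_return(state, action, values, gamma):
    """Expected reward plus discounted value of the next state, over all events."""
    total = 0.0
    for event in random_event_space:
        next_state, reward = transition(state, action, event)
        p = random_event_probability(state, action, event)
        total += p * (reward + gamma * values[next_state[0]])
    return total


print(expected_return([0], [1], [0.0, 0.0, 0.0, 0.0], gamma=0.9))
```

Because the expectation is taken over the random event space, the full [n_states, n_states] transition matrix never has to be stored.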

class mdpax.core.problem.ProblemConfig(_target_: str = '???')[source]

Base configuration for all MDP problems.

This serves as the base configuration class that all specific problem configurations should inherit from. It enforces that all problems must specify their target class.

_target_

Full path to the problem class for Hydra instantiation

Type:

str
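The inheritance pattern described above might look like the following sketch. It uses standard-library dataclasses only; the class names, the target path, and the max_stock field are hypothetical, and the string '???' is treated as the marker for a value that still has to be supplied.

```python
from dataclasses import dataclass, fields


@dataclass
class ProblemConfigSketch:
    """Stand-in for the documented base config; '???' marks a required value."""

    _target_: str = "???"


@dataclass
class InventoryConfigSketch(ProblemConfigSketch):
    """Hypothetical problem config; the target path and fields are made up."""

    _target_: str = "my_package.problems.Inventory"
    max_stock: int = 3


def missing_fields(config):
    """List the fields that still hold the '???' placeholder."""
    return [f.name for f in fields(config) if getattr(config, f.name) == "???"]


print(missing_fields(ProblemConfigSketch()))
print(missing_fields(InventoryConfigSketch()))
```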

Solver

class mdpax.core.solver.Solver(problem: Problem | None = None, config: SolverConfig | None = None, **kwargs)[source]

Abstract base class for MDP solvers.

Provides common functionality for solving MDPs using parallel processing across devices with batched state updates. Subclasses must implement the core solution algorithm while inheriting the parallel processing and batching infrastructure.

Required Implementations:
  • _setup_convergence_testing: Setup convergence testing functions and thresholds.

  • _iteration_step: Perform one iteration of the solution algorithm.

Optional Implementations:
  • _setup_additional_components: Hook for additional setup in derived classes.

Shape Requirements:
  • Values: [n_states]

  • Policy: [n_states, action_dim]

  • Batched states: [n_devices, n_batches, batch_size, state_dim]

  • Batched results: [n_devices, n_batches, batch_size]

problem

Problem instance being solved

gamma

Discount factor

epsilon

Convergence threshold

max_batch_size

Maximum batch size for parallel processing

values

Current value function [n_states] or None

policy

Current policy [n_states, action_dim] or None

iteration

Current iteration count

batch_processor

Utility for handling batched computations

verbose

Current verbosity level

solver_state

Current solver state containing values, policy, and info

n_devices

Number of available JAX devices for parallel processing

batch_size

Actual batch size being used (may be less than max_batch_size)

n_pad

Number of padding elements added to make batches fit devices

Note

  • All array operations use JAX for efficient parallel processing

  • States are automatically batched and padded for device distribution

  • Subclasses should use jax.jit and jax.pmap for performance
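The relationship between max_batch_size, batch_size, and n_pad can be pictured with a small calculation. The rule below is an assumed, simplified batching plan written for illustration; mdpax's actual rule may differ, and the function name is hypothetical.

```python
import math


def plan_batches(n_states, n_devices, max_batch_size):
    """Hypothetical batching plan; the library's actual rule may differ."""
    states_per_device = math.ceil(n_states / n_devices)
    # The batch never needs to be larger than one device's share of the states.
    batch_size = min(max_batch_size, states_per_device)
    n_batches = math.ceil(states_per_device / batch_size)
    # Padding fills the [n_devices, n_batches, batch_size] layout completely.
    n_pad = n_devices * n_batches * batch_size - n_states
    return batch_size, n_batches, n_pad


print(plan_batches(n_states=10, n_devices=2, max_batch_size=4))
print(plan_batches(n_states=10, n_devices=1, max_batch_size=1024))
```

In the second call the batch size shrinks to the number of states, which is the situation the batch_size attribute describes as "may be less than max_batch_size".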

__init__(problem: Problem | None = None, config: SolverConfig | None = None, **kwargs)[source]

abstract _iteration_step() tuple[Float[Array, 'n_states'], float][source]

Perform one iteration of the solution algorithm.

Returns:

Tuple of (new values, convergence measure) where new values has shape [n_states]

_setup_additional_components() None[source]

Hook for additional setup in derived classes.

abstract _setup_convergence_testing() None[source]

Set up convergence testing functions and thresholds.

property batch_size: int

Actual batch size being used.

property n_devices: int

Number of available devices.

property n_pad: int

Number of padding elements added.

set_verbosity(level: int | str) None[source]

Set the verbosity level for solver output.

Parameters:

level – Verbosity level, either as integer (0-4) or string (‘ERROR’, ‘WARNING’, ‘INFO’, ‘DEBUG’, ‘TRACE’)

Integer levels map to:
  • 0: Minimal output (only errors)

  • 1: Show warnings and errors

  • 2: Show main progress (default)

  • 3: Show detailed progress

  • 4: Show everything
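The two accepted forms of level can be related with a small helper. The pairing of integers with level names below is inferred from the order in which both are listed, so treat it as an assumption; the helper itself is hypothetical and not part of the library.

```python
LEVEL_NAMES = ["ERROR", "WARNING", "INFO", "DEBUG", "TRACE"]  # assumed order for 0..4


def normalize_verbosity(level):
    """Return the level name for an integer (0-4) or a case-insensitive string."""
    if isinstance(level, str):
        name = level.upper()
        if name not in LEVEL_NAMES:
            raise ValueError(f"Unknown verbosity level: {level!r}")
        return name
    if not 0 <= level < len(LEVEL_NAMES):
        raise ValueError(f"Verbosity level must be between 0 and 4, got {level}")
    return LEVEL_NAMES[level]


print(normalize_verbosity(2))
print(normalize_verbosity("debug"))
```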

solve(max_iterations: int = 2000) SolverState[source]

Run solver to convergence or max iterations.

Parameters:

max_iterations – Maximum number of iterations to run

Returns:

SolverState containing final values [n_states], optimal policy [n_states, action_dim], and SolverInfo including iteration count

property solver_state: SolverState

Get solver state for checkpointing.

class mdpax.core.solver.SolverConfig(_target_: str = '???', problem: ProblemConfig = '???', gamma: float = '???', epsilon: float = '???', max_batch_size: int = '???', jax_double_precision: bool = '???', verbose: int = '???')[source]

Base configuration for all MDP solvers.

This serves as the base configuration class that all specific solver configurations should inherit from. It defines common parameters used across different solvers.

class mdpax.core.solver.SolverState(values: Float[Array, 'n_states'] | None, policy: Float[Array, 'n_states action_dim'] | None, info: SolverInfo)[source]

Base runtime state for all solvers.

Contains the core state that must be maintained by all solvers. Specific solvers can extend the info field with solver-specific metadata.

values

Current value function [n_states]

Type:

jaxtyping.Float[Array, ‘n_states’] | None

policy

Current policy (if computed) [n_states, action_dim]

Type:

jaxtyping.Float[Array, ‘n_states action_dim’] | None

info

Solver metadata

Type:

mdpax.core.solver.SolverInfo

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values

class mdpax.core.solver.SolverInfo(iteration: int)[source]

Base solver information.

Contains common metadata needed by all solvers. Specific solvers can extend this with additional fields.

iteration

Current iteration count

Type:

int

items() a set-like object providing a view on D's items
keys() a set-like object providing a view on D's keys
values() an object providing a view on D's values

Solvers

class mdpax.solvers.value_iteration.ValueIteration(problem: Problem | None = None, config: ValueIterationConfig | None = None, **kwargs)[source]

Bases: Solver, CheckpointMixin

Value iteration solver for MDPs.

This solver implements synchronous value iteration with parallel state updates across devices. States are automatically batched and padded for efficient parallel processing.

Convergence testing uses the span of differences in values by default (convergence_test='span'). If the value function is needed for further analysis, use convergence_test='max_diff' to test the maximum absolute difference between successive iterations.
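The difference between the two tests shows up when every value changes by the same amount between iterations. The sketch below uses made-up value vectors and hypothetical helper names to compare the two measures; it does not reproduce the solver's thresholds.

```python
def span(diffs):
    """Range of the per-state differences."""
    return max(diffs) - min(diffs)


def max_abs_diff(diffs):
    """Largest absolute per-state difference."""
    return max(abs(d) for d in diffs)


old_values = [0.0, 1.0, 2.0]
new_values = [10.0, 11.0, 12.0]  # every state shifted by the same constant
diffs = [n - o for n, o in zip(new_values, old_values)]

print(span(diffs))          # small: the relative ordering of states has settled
print(max_abs_diff(diffs))  # large: the values themselves are still moving
```

A span-based test can therefore stop while the values are still drifting by a constant, which does not affect the greedy policy but does matter if the values themselves are used afterwards.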

The default settings match the behaviour of pymdptoolbox’s ValueIteration class.

Supports checkpointing for long-running problems using the CheckpointMixin.

Parameters:
  • problem – Problem instance or None if using config

  • config – Configuration object. If provided, other kwargs are ignored.

  • **kwargs – Parameters matching ValueIterationConfig. See Config class for detailed parameter descriptions.

Config

alias of ValueIterationConfig

property batch_size: int

Actual batch size being used.

property has_full_config: bool

Check if solver and problem have complete configs for reconstruction.

Checks that both solver and problem have configs with _target_ properties, which are required for Hydra to reconstruct the objects.

Returns:

True if both solver and problem have complete configs, False otherwise.

property is_checkpointing_enabled: bool

Check if checkpointing is enabled.

Returns:

True if checkpointing is properly configured and enabled, False otherwise.

load_checkpoint(checkpoint_dir: str | Path, step: int | None = None) None

Load solver state from checkpoint.

Must be called on an already constructed solver instance that has a problem attached. State is loaded from checkpoint_dir, but new checkpoints continue to be saved to the directory specified during construction (self.checkpoint_dir).

Parameters:
  • checkpoint_dir – Directory containing checkpoint to load from

  • step – Specific step to load. If None, loads the latest checkpoint.

Returns:

None

property n_devices: int

Number of available devices.

property n_pad: int

Number of padding elements added.

classmethod restore(checkpoint_dir: str | Path, step: int | None = None, new_checkpoint_dir: str | Path | None = None, checkpoint_frequency: int | None = None, max_checkpoints: int | None = None, enable_async_checkpointing: bool | None = None) Solver

Load solver from checkpoint.

This class method reconstructs a solver instance from a checkpoint, using the stored config to recreate both the problem and solver with the correct parameters.

Parameters:
  • checkpoint_dir – Directory containing checkpoints.

  • step – Specific step to load. If None, loads the latest checkpoint.

  • new_checkpoint_dir – Optional new directory for future checkpoints. Useful when restoring to a different location.

  • checkpoint_frequency – Optional new checkpoint frequency.

  • max_checkpoints – Optional new maximum number of checkpoints.

  • enable_async_checkpointing – Optional new async checkpointing setting.

Returns:

Reconstructed solver instance with restored state.

save(step: int) None

Save current solver state to checkpoint.

Parameters:

step – Current iteration/step number to associate with the checkpoint.

Returns:

None

set_verbosity(level: int | str) None

Set the verbosity level for solver output.

Parameters:

level – Verbosity level, either as integer (0-4) or string (‘ERROR’, ‘WARNING’, ‘INFO’, ‘DEBUG’, ‘TRACE’)

Integer levels map to:
  • 0: Minimal output (only errors)

  • 1: Show warnings and errors

  • 2: Show main progress (default)

  • 3: Show detailed progress

  • 4: Show everything

solve(max_iterations: int = 2000) SolverState[source]

Run solver to convergence or max iterations.

Parameters:

max_iterations – Maximum number of iterations to run

Returns:

SolverState containing final values [n_states], optimal policy [n_states, action_dim], and SolverInfo including iteration count

property solver_state: SolverState

Get solver state for checkpointing.

class mdpax.solvers.relative_value_iteration.RelativeValueIteration(problem: Problem | None = None, config: RelativeValueIterationConfig | None = None, **kwargs)[source]

Bases: ValueIteration

Relative value iteration solver for average reward MDPs.

This solver extends standard value iteration to handle average reward MDPs by:
  • Using gamma=1.0 (no discounting)

  • Tracking and subtracting a gain term to handle unbounded values

Convergence testing is based on the span of value differences.
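A minimal sketch of the idea, assuming a reference-state form of relative value iteration on a hypothetical two-state, two-action MDP: after each undiscounted backup, the backed-up value of state 0 is used as the gain estimate and subtracted from all values. This is plain Python for illustration, and the way mdpax chooses its gain term may differ.

```python
# Hypothetical MDP: P[s][a] lists next-state probabilities, R[s][a] is the reward.
P = [
    [[1.0, 0.0], [0.0, 1.0]],  # state 0: stay / move to state 1
    [[0.0, 1.0], [1.0, 0.0]],  # state 1: stay / move to state 0
]
R = [[1.0, 0.0], [2.0, 0.0]]


def relative_value_iteration(epsilon=1e-6, max_iterations=1000):
    values = [0.0, 0.0]
    gain = 0.0
    for _ in range(max_iterations):
        # Undiscounted Bellman backup (gamma = 1.0).
        backed_up = [
            max(R[s][a] + sum(p * v for p, v in zip(P[s][a], values)) for a in range(2))
            for s in range(2)
        ]
        diffs = [b - v for b, v in zip(backed_up, values)]
        # Subtract the reference state's value so the values stay bounded.
        gain = backed_up[0]
        values = [b - gain for b in backed_up]
        if max(diffs) - min(diffs) < epsilon:  # span of value differences
            break
    return values, gain


values, gain = relative_value_iteration()
print(values, gain)
```

In this toy problem the best long-run behaviour is to move to state 1 and stay there, collecting a reward of 2 per step, and the gain estimate settles at that value.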

Supports checkpointing for long-running problems using the CheckpointMixin.

Parameters:
  • problem – Problem instance or None if using config

  • config – Configuration object. If provided, other kwargs are ignored.

  • **kwargs – Parameters matching RelativeValueIterationConfig. See Config class for detailed parameter descriptions.

Config

alias of RelativeValueIterationConfig

property batch_size: int

Actual batch size being used.

property has_full_config: bool

Check if solver and problem have complete configs for reconstruction.

Checks that both solver and problem have configs with _target_ properties, which are required for Hydra to reconstruct the objects.

Returns:

True if both solver and problem have complete configs, False otherwise.

property is_checkpointing_enabled: bool

Check if checkpointing is enabled.

Returns:

True if checkpointing is properly configured and enabled, False otherwise.

load_checkpoint(checkpoint_dir: str | Path, step: int | None = None) None

Load solver state from checkpoint.

Must be called on an already constructed solver instance that has a problem attached. State is loaded from checkpoint_dir, but new checkpoints continue to be saved to the directory specified during construction (self.checkpoint_dir).

Parameters:
  • checkpoint_dir – Directory containing checkpoint to load from

  • step – Specific step to load. If None, loads the latest checkpoint.

Returns:

None

property n_devices: int

Number of available devices.

property n_pad: int

Number of padding elements added.

classmethod restore(checkpoint_dir: str | Path, step: int | None = None, new_checkpoint_dir: str | Path | None = None, checkpoint_frequency: int | None = None, max_checkpoints: int | None = None, enable_async_checkpointing: bool | None = None) Solver

Load solver from checkpoint.

This class method reconstructs a solver instance from a checkpoint, using the stored config to recreate both the problem and solver with the correct parameters.

Parameters:
  • checkpoint_dir – Directory containing checkpoints.

  • step – Specific step to load. If None, loads the latest checkpoint.

  • new_checkpoint_dir – Optional new directory for future checkpoints. Useful when restoring to a different location.

  • checkpoint_frequency – Optional new checkpoint frequency.

  • max_checkpoints – Optional new maximum number of checkpoints.

  • enable_async_checkpointing – Optional new async checkpointing setting.

Returns:

Reconstructed solver instance with restored state.

save(step: int) None

Save current solver state to checkpoint.

Parameters:

step – Current iteration/step number to associate with the checkpoint.

Returns:

None

set_verbosity(level: int | str) None

Set the verbosity level for solver output.

Parameters:

level – Verbosity level, either as integer (0-4) or string (‘ERROR’, ‘WARNING’, ‘INFO’, ‘DEBUG’, ‘TRACE’)

Integer levels map to:
  • 0: Minimal output (only errors)

  • 1: Show warnings and errors

  • 2: Show main progress (default)

  • 3: Show detailed progress

  • 4: Show everything

solve(max_iterations: int = 2000) RelativeValueIterationState[source]

Run solver to convergence or max iterations.

Parameters:

max_iterations – Maximum number of iterations to run

Returns:

SolverState containing final values [n_states], optimal policy [n_states, action_dim], and SolverInfo including iteration count and gain

property solver_state: RelativeValueIterationState

Get solver state for checkpointing.

class mdpax.solvers.periodic_value_iteration.PeriodicValueIteration(problem: Problem | None = None, config: PeriodicValueIterationConfig | None = None, **kwargs)[source]

Bases: ValueIteration

Periodic value iteration solver for MDPs.

This solver is particularly useful for problems with periodic structure in the state space, where it may require fewer iterations to reach convergence than standard value iteration.

Convergence testing is based on the span of value differences over a period. For undiscounted problems, this is simply the span of differences between current values and values from one period ago. For discounted problems, we sum the consecutive differences over the period, scaling each by the appropriate discount factor. The solver stores a history of values over the period to perform this comparison.
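For the undiscounted case, the comparison against values from one period ago can be sketched as follows. The value vectors are made up to mimic a problem with period 2, the helper name is hypothetical, and the discounted variant is not shown because its exact scaling is not spelled out here.

```python
from collections import deque


def span_between(older, newer):
    """Span of the per-state differences between two value vectors."""
    diffs = [n - o for n, o in zip(newer, older)]
    return max(diffs) - min(diffs)


period = 2
history = deque(maxlen=period + 1)  # current values plus one full period back

# Made-up iterates whose increments alternate with period 2.
for values in ([0.0, 0.0], [1.0, 3.0], [4.0, 4.0], [5.0, 7.0], [8.0, 8.0]):
    history.append(values)

one_step_span = span_between(history[-2], history[-1])
periodic_span = span_between(history[0], history[-1])
print(one_step_span, periodic_span)
```

The one-step span keeps oscillating for these iterates, while the span over a full period is already zero, which is the situation in which a periodic test can stop earlier.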

Supports checkpointing for long-running problems using the CheckpointMixin.

Parameters:
  • problem – Problem instance or None if using config

  • config – Configuration object. If provided, other kwargs are ignored.

  • **kwargs – Parameters matching PeriodicValueIterationConfig. See Config class for detailed parameter descriptions.

Config

alias of PeriodicValueIterationConfig

property batch_size: int

Actual batch size being used.

property has_full_config: bool

Check if solver and problem have complete configs for reconstruction.

Checks that both solver and problem have configs with _target_ properties, which are required for Hydra to reconstruct the objects.

Returns:

True if both solver and problem have complete configs, False otherwise.

property is_checkpointing_enabled: bool

Check if checkpointing is enabled.

Returns:

True if checkpointing is properly configured and enabled, False otherwise.

load_checkpoint(checkpoint_dir: str | Path, step: int | None = None) None

Load solver state from checkpoint.

Must be called on an already constructed solver instance that has a problem attached. State is loaded from checkpoint_dir, but new checkpoints continue to be saved to the directory specified during construction (self.checkpoint_dir).

Parameters:
  • checkpoint_dir – Directory containing checkpoint to load from

  • step – Specific step to load. If None, loads the latest checkpoint.

Returns:

None

property n_devices: int

Number of available devices.

property n_pad: int

Number of padding elements added.

classmethod restore(checkpoint_dir: str | Path, step: int | None = None, new_checkpoint_dir: str | Path | None = None, checkpoint_frequency: int | None = None, max_checkpoints: int | None = None, enable_async_checkpointing: bool | None = None) Solver

Load solver from checkpoint.

This class method reconstructs a solver instance from a checkpoint, using the stored config to recreate both the problem and solver with the correct parameters.

Parameters:
  • checkpoint_dir – Directory containing checkpoints.

  • step – Specific step to load. If None, loads the latest checkpoint.

  • new_checkpoint_dir – Optional new directory for future checkpoints. Useful when restoring to a different location.

  • checkpoint_frequency – Optional new checkpoint frequency.

  • max_checkpoints – Optional new maximum number of checkpoints.

  • enable_async_checkpointing – Optional new async checkpointing setting.

Returns:

Reconstructed solver instance with restored state.

save(step: int) None

Save current solver state to checkpoint.

Parameters:

step – Current iteration/step number to associate with the checkpoint.

Returns:

None

set_verbosity(level: int | str) None

Set the verbosity level for solver output.

Parameters:

level – Verbosity level, either as integer (0-4) or string (‘ERROR’, ‘WARNING’, ‘INFO’, ‘DEBUG’, ‘TRACE’)

Integer levels map to:
  • 0: Minimal output (only errors)

  • 1: Show warnings and errors

  • 2: Show main progress (default)

  • 3: Show detailed progress

  • 4: Show everything

solve(max_iterations: int = 2000) PeriodicValueIterationState[source]

Run solver to convergence or max iterations.

Parameters:

max_iterations – Maximum number of iterations to run

Returns:

SolverState containing final values [n_states], optimal policy [n_states, action_dim], and SolverInfo including iteration count and value history.

property solver_state: PeriodicValueIterationState

Get solver state for checkpointing.

class mdpax.solvers.semi_async_value_iteration.SemiAsyncValueIteration(problem: Problem | None = None, config: SemiAsyncValueIterationConfig | None = None, **kwargs)[source]

Bases: ValueIteration

Semi-asynchronous value iteration solver with flexible batch ordering.

This solver extends standard value iteration by:
  • Processing states in batches with updated values immediately available to subsequent batches on the same device

  • Supporting a fixed or random (shuffle) state update order
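The two points above can be sketched on a single device with plain Python. States within a batch are updated together from the same snapshot, but each later batch reads the values written by earlier batches; the chain problem, the backup function, and all helper names are hypothetical.

```python
import random


def make_batches(n_states, batch_size, shuffle=False, seed=0):
    """Split state indices into batches, in fixed or shuffled order."""
    order = list(range(n_states))
    if shuffle:
        random.Random(seed).shuffle(order)
    return [order[i:i + batch_size] for i in range(0, n_states, batch_size)]


def backup(s, values, gamma=0.5):
    """Hypothetical chain: state s earns 1 and moves to s - 1; state 0 earns nothing."""
    return 0.0 if s == 0 else 1.0 + gamma * values[s - 1]


def semi_async_sweep(values, batches):
    """Update batch by batch; later batches see earlier batches' new values."""
    values = list(values)
    for batch in batches:
        new = {s: backup(s, values) for s in batch}  # synchronous within a batch
        for s, v in new.items():
            values[s] = v
    return values


start = [0.0, 0.0, 0.0, 0.0]
print(semi_async_sweep(start, make_batches(4, batch_size=2)))  # two batches
print(semi_async_sweep(start, make_batches(4, batch_size=4)))  # one batch: fully synchronous
```

With two batches, state 2 already benefits from the updated value of state 1 within the same sweep, which is how information can spread faster than in a fully synchronous update.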

Supports checkpointing for long-running problems using the CheckpointMixin.

Parameters:
  • problem – Problem instance or None if using config

  • config – Configuration object. If provided, other kwargs are ignored.

  • **kwargs – Parameters matching SemiAsyncValueIterationConfig. See Config class for detailed parameter descriptions.

Config

alias of SemiAsyncValueIterationConfig

property batch_size: int

Actual batch size being used.

property has_full_config: bool

Check if solver and problem have complete configs for reconstruction.

Checks that both solver and problem have configs with _target_ properties, which are required for Hydra to reconstruct the objects.

Returns:

True if both solver and problem have complete configs, False otherwise.

property is_checkpointing_enabled: bool

Check if checkpointing is enabled.

Returns:

True if checkpointing is properly configured and enabled, False otherwise.

load_checkpoint(checkpoint_dir: str | Path, step: int | None = None) None

Load solver state from checkpoint.

Must be called on an already constructed solver instance that has a problem attached. State is loaded from checkpoint_dir, but new checkpoints continue to be saved to the directory specified during construction (self.checkpoint_dir).

Parameters:
  • checkpoint_dir – Directory containing checkpoint to load from

  • step – Specific step to load. If None, loads the latest checkpoint.

Returns:

None

property n_devices: int

Number of available devices.

property n_pad: int

Number of padding elements added.

classmethod restore(checkpoint_dir: str | Path, step: int | None = None, new_checkpoint_dir: str | Path | None = None, checkpoint_frequency: int | None = None, max_checkpoints: int | None = None, enable_async_checkpointing: bool | None = None) Solver

Load solver from checkpoint.

This class method reconstructs a solver instance from a checkpoint, using the stored config to recreate both the problem and solver with the correct parameters.

Parameters:
  • checkpoint_dir – Directory containing checkpoints.

  • step – Specific step to load. If None, loads the latest checkpoint.

  • new_checkpoint_dir – Optional new directory for future checkpoints. Useful when restoring to a different location.

  • checkpoint_frequency – Optional new checkpoint frequency.

  • max_checkpoints – Optional new maximum number of checkpoints.

  • enable_async_checkpointing – Optional new async checkpointing setting.

Returns:

Reconstructed solver instance with restored state.

save(step: int) None

Save current solver state to checkpoint.

Parameters:

step – Current iteration/step number to associate with the checkpoint.

Returns:

None

set_verbosity(level: int | str) None

Set the verbosity level for solver output.

Parameters:

level – Verbosity level, either as integer (0-4) or string (‘ERROR’, ‘WARNING’, ‘INFO’, ‘DEBUG’, ‘TRACE’)

Integer levels map to:
  • 0: Minimal output (only errors)

  • 1: Show warnings and errors

  • 2: Show main progress (default)

  • 3: Show detailed progress

  • 4: Show everything

solve(max_iterations: int = 2000) SemiAsyncValueIterationState[source]

Run solver to convergence or max iterations.

Parameters:

max_iterations – Maximum number of iterations to run

Returns:

SolverState containing final values [n_states], optimal policy [n_states, action_dim], and SolverInfo including batch ordering info

property solver_state: SemiAsyncValueIterationState

Get solver state for checkpointing.

class mdpax.solvers.policy_iteration.PolicyIteration(problem: Problem | None = None, config: PolicyIterationConfig | None = None, **kwargs)[source]

Bases: ValueIteration

Policy iteration solver for MDPs.

This solver implements policy iteration with parallel state updates across devices. States are automatically batched and padded for efficient parallel processing.

The algorithm alternates between:
  1. Policy evaluation: computing values for the current policy using an iterative method with batched updates

  2. Policy improvement: one-step lookahead to find better policy

The algorithm is considered to have converged when the policy does not change between successive iterations. For each iteration, the convergence of policy evaluation is tested using the span of differences in values between successive iterations by default (convergence_test='span').

By default, the value estimates from the previous policy are used as the starting estimates for the next policy evaluation (reset_values_for_each_policy_eval=False). To start policy evaluation from the initial values in each iteration, set reset_values_for_each_policy_eval=True.
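The alternation between evaluation and improvement, including the warm start from the previous policy's values, can be sketched on a hypothetical two-state, two-action discounted MDP. This is plain Python without batching, and for simplicity the evaluation loop stops on the maximum absolute difference rather than the span.

```python
GAMMA = 0.5
# Hypothetical MDP: P[s][a] lists next-state probabilities, R[s][a] is the reward.
P = [
    [[1.0, 0.0], [0.0, 1.0]],  # state 0: stay / move to state 1
    [[0.0, 1.0], [1.0, 0.0]],  # state 1: stay / move to state 0
]
R = [[0.5, 0.0], [2.0, 0.0]]


def q_value(s, a, values):
    return R[s][a] + GAMMA * sum(p * v for p, v in zip(P[s][a], values))


def evaluate(policy, values, epsilon=1e-10, max_eval_iter=10_000):
    """Iterative policy evaluation, stopping on the max absolute difference."""
    for _ in range(max_eval_iter):
        new = [q_value(s, policy[s], values) for s in range(2)]
        done = max(abs(n - v) for n, v in zip(new, values)) < epsilon
        values = new
        if done:
            break
    return values


def policy_iteration(max_iterations=100):
    policy, values = [0, 0], [0.0, 0.0]
    for _ in range(max_iterations):
        values = evaluate(policy, values)  # warm start from the previous values
        # One-step lookahead: pick the best action in each state.
        improved = [max(range(2), key=lambda a: q_value(s, a, values)) for s in range(2)]
        if improved == policy:  # converged: the policy no longer changes
            break
        policy = improved
    return policy, values


policy, values = policy_iteration()
print(policy, values)
```

Passing the previous values into evaluate corresponds to the default warm-start behaviour; resetting them to the initial values before each evaluation would correspond to reset_values_for_each_policy_eval=True.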

To match the behaviour of pymdptoolbox’s PolicyIteration class (with iterative evaluation) use the following arguments:

  • reset_values_for_each_policy_eval=True

  • convergence_test='max_diff'

  • max_eval_iter=10000

  • epsilon=1e-4
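Collected as keyword arguments, these settings might be passed as follows. The dictionary only restates the values listed above; the commented-out call assumes mdpax is installed and that problem is an existing Problem instance.

```python
# Settings listed above for matching pymdptoolbox's PolicyIteration
# (with iterative evaluation).
pymdptoolbox_like_kwargs = {
    "reset_values_for_each_policy_eval": True,
    "convergence_test": "max_diff",
    "max_eval_iter": 10000,
    "epsilon": 1e-4,
}

# Requires mdpax and a Problem instance, so it is left commented out here:
# from mdpax.solvers.policy_iteration import PolicyIteration
# solver = PolicyIteration(problem, **pymdptoolbox_like_kwargs)
print(pymdptoolbox_like_kwargs)
```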

Supports checkpointing for long-running problems using the CheckpointMixin.

Parameters:
  • problem – Problem instance or None if using config

  • config – Configuration object. If provided, other kwargs are ignored.

  • **kwargs – Parameters matching PolicyIterationConfig. See Config class for detailed parameter descriptions.

Config

alias of PolicyIterationConfig

property batch_size: int

Actual batch size being used.

property has_full_config: bool

Check if solver and problem have complete configs for reconstruction.

Checks that both solver and problem have configs with _target_ properties, which are required for Hydra to reconstruct the objects.

Returns:

True if both solver and problem have complete configs, False otherwise.

property is_checkpointing_enabled: bool

Check if checkpointing is enabled.

Returns:

True if checkpointing is properly configured and enabled, False otherwise.

load_checkpoint(checkpoint_dir: str | Path, step: int | None = None) None

Load solver state from checkpoint.

Must be called on an already constructed solver instance that has a problem attached. State is loaded from checkpoint_dir, but new checkpoints continue to be saved to the directory specified during construction (self.checkpoint_dir).

Parameters:
  • checkpoint_dir – Directory containing checkpoint to load from

  • step – Specific step to load. If None, loads the latest checkpoint.

Returns:

None

property n_devices: int

Number of available devices.

property n_pad: int

Number of padding elements added.

classmethod restore(checkpoint_dir: str | Path, step: int | None = None, new_checkpoint_dir: str | Path | None = None, checkpoint_frequency: int | None = None, max_checkpoints: int | None = None, enable_async_checkpointing: bool | None = None) Solver

Load solver from checkpoint.

This class method reconstructs a solver instance from a checkpoint, using the stored config to recreate both the problem and solver with the correct parameters.

Parameters:
  • checkpoint_dir – Directory containing checkpoints.

  • step – Specific step to load. If None, loads the latest checkpoint.

  • new_checkpoint_dir – Optional new directory for future checkpoints. Useful when restoring to a different location.

  • checkpoint_frequency – Optional new checkpoint frequency.

  • max_checkpoints – Optional new maximum number of checkpoints.

  • enable_async_checkpointing – Optional new async checkpointing setting.

Returns:

Reconstructed solver instance with restored state.

save(step: int) None

Save current solver state to checkpoint.

Parameters:

step – Current iteration/step number to associate with the checkpoint.

Returns:

None

set_verbosity(level: int | str) None

Set the verbosity level for solver output.

Parameters:

level – Verbosity level, either as an integer (0-4) or a string ('ERROR', 'WARNING', 'INFO', 'DEBUG', 'TRACE')

Integer levels map to:
  • 0: Minimal output (only errors)

  • 1: Show warnings and errors

  • 2: Show main progress (default)

  • 3: Show detailed progress

  • 4: Show everything
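The sketch below shows one way to normalise the two accepted input forms to a single level name. It assumes that integer level i corresponds to the i-th string in the order listed above; the helper name and the tuple are hypothetical and not part of mdpax.

```python
# Assumed correspondence: 0 -> ERROR, 1 -> WARNING, 2 -> INFO, 3 -> DEBUG, 4 -> TRACE.
LEVEL_NAMES = ("ERROR", "WARNING", "INFO", "DEBUG", "TRACE")


def normalize_verbosity(level):
    """Return the level name for an integer (0-4) or string level."""
    if isinstance(level, str):
        name = level.upper()
        if name not in LEVEL_NAMES:
            raise ValueError(f"Unknown verbosity level: {level!r}")
        return name
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(level, bool) or not isinstance(level, int):
        raise TypeError("level must be an int or a str")
    if not 0 <= level < len(LEVEL_NAMES):
        raise ValueError("integer level must be between 0 and 4")
    return LEVEL_NAMES[level]
```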

solve(max_iterations: int = 1000) SolverState[source]

Run solver to convergence or max iterations.

Policy iteration is guaranteed to converge in a finite number of iterations for discounted MDPs with finite state and action spaces. The solver stops when the policy stops changing or when max_iterations is reached.

Parameters:

max_iterations – Maximum number of iterations to run

Returns:

SolverState containing final values [n_states], optimal policy [n_states], and SolverInfo including iteration count
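The stopping rule can be illustrated with a small pure-Python sketch of policy iteration with iterative evaluation. This is not the mdpax implementation (which uses JAX arrays and batching); the function name, the nested-list layout of P and R, and the tolerance values are assumptions chosen for illustration.

```python
def policy_iteration_sketch(P, R, gamma=0.9, max_iterations=1000,
                            eval_tol=1e-8, max_eval_iter=10000):
    """P[s][a][t] is a transition probability, R[s][a] an expected reward."""
    n_states, n_actions = len(P), len(P[0])
    policy = [0] * n_states
    values = [0.0] * n_states

    def q_value(s, a):
        return R[s][a] + gamma * sum(P[s][a][t] * values[t] for t in range(n_states))

    for iteration in range(1, max_iterations + 1):
        # Policy evaluation: repeat the Bellman backup for the current policy
        # until the largest change in any state value falls below eval_tol.
        for _ in range(max_eval_iter):
            new_values = [q_value(s, policy[s]) for s in range(n_states)]
            diff = max(abs(n - o) for n, o in zip(new_values, values))
            values = new_values
            if diff < eval_tol:
                break
        # Policy improvement: act greedily with respect to the evaluated values.
        new_policy = [max(range(n_actions), key=lambda a, s=s: q_value(s, a))
                      for s in range(n_states)]
        if new_policy == policy:  # policy unchanged, so stop
            return values, policy, iteration
        policy = new_policy
    return values, policy, max_iterations


# Toy MDP with two states and two actions (illustrative numbers only).
P = [[[1.0, 0.0], [0.0, 1.0]],   # state 0: action 0 stays, action 1 moves to state 1
     [[0.0, 1.0], [1.0, 0.0]]]   # state 1: action 0 stays, action 1 moves to state 0
R = [[0.0, 1.0],
     [2.0, 0.0]]
values, policy, n_iter = policy_iteration_sketch(P, R)
```

In this toy problem the greedy policy moves from state 0 to state 1 and then stays there, and the loop exits as soon as an improvement step leaves the policy unchanged.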

property solver_state: SolverState

Get solver state for checkpointing.

Problems

Basic Problems

class mdpax.problems.forest.Forest(config: ForestConfig | None = None, **kwargs)[source]

Bases: Problem

Forest management MDP problem.

The forest management problem involves deciding whether to cut down trees for immediate reward or wait for them to grow larger. There is a risk of fire destroying the forest during each time step.

Adapted from the example problem in pymdptoolbox.

State Space (state_dim = 1):
Vector containing:
  • Tree age: 1 element in range [0, S-1] (newly planted to mature forest)

Action Space (action_dim = 1):
Vector containing:
  • Decision: 1 element in range {0=wait, 1=cut}

Random Events (event_dim = 1):
Vector containing:
  • Fire occurrence: 1 element in range {0=no_fire, 1=fire}

Dynamics:
  1. Choose to cut or wait

  2. If wait:
    • Check for fire (probability p if waiting)

    • If fire, reset to age 0 with no reward

    • If no fire, age increases by 1 (up to S-1) and receive r1 reward if in oldest state

  3. If cut:
    • Receive reward r2 if in oldest state and 1 otherwise

    • Reset to age 0
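The dynamics above can be sketched in plain Python as follows. This is an illustration with scalar inputs, not the JAX code used by Forest; the function name and the default values of S, r1 and r2 are assumptions.

```python
def forest_transition_sketch(age, action, fire, S=3, r1=4.0, r2=2.0):
    """Return (next_age, reward); action 0=wait, 1=cut; fire 0=no_fire, 1=fire."""
    oldest = S - 1
    if action == 1:
        # Cut: reward r2 in the oldest state and 1 otherwise, then replant.
        reward = r2 if age == oldest else 1.0
        return 0, reward
    if fire == 1:
        # Wait and fire: the forest burns down with no reward.
        return 0, 0.0
    # Wait and no fire: the forest ages (capped at S-1); r1 only in the oldest state.
    reward = r1 if age == oldest else 0.0
    return min(age + 1, oldest), reward
```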

Parameters:
  • config – Configuration object. If provided, keyword arguments are ignored.

  • **kwargs – Parameters matching ForestConfig. See ForestConfig for detailed parameter descriptions.

References

Config

alias of ForestConfig

property name: str

A unique identifier for this problem type

random_event_probability(state: Float[Array, 'state_dim'], action: Float[Array, 'action_dim'], random_event: Float[Array, 'event_dim']) float[source]

Compute probability of random event given state-action pair.

When waiting:
  • No fire probability is 1 - p

  • Fire probability is p

When cutting:
  • No fire probability is 1

  • Fire probability is 0
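A plain-Python sketch of the two cases above, using scalar inputs instead of JAX vectors. The function name and the default value of p are assumptions for illustration.

```python
def fire_event_probability_sketch(action, fire, p=0.1):
    """Probability of the fire event; action 0=wait, 1=cut; fire 0=no_fire, 1=fire."""
    if action == 1:
        # Cutting: a fire cannot occur, so all mass is on no_fire.
        return 1.0 if fire == 0 else 0.0
    # Waiting: fire occurs with probability p.
    return p if fire == 1 else 1.0 - p
```

For either action the two event probabilities sum to one, which is what a solver needs when it takes an expectation over random events.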

Parameters:
  • state – Current state vector [state_dim]

  • action – Action vector [action_dim]

  • random_event – Random event vector [event_dim]

Returns:

Probability of the random event occurring

state_to_index(state: Float[Array, 'state_dim']) int[source]

Convert state vector to index.

Parameters:

state – Vector representation of a state [state_dim]

Returns:

Index of the state in state_space

transition(state: Float[Array, 'state_dim'], action: Float[Array, 'action_dim'], random_event: Float[Array, 'event_dim']) tuple[Float[Array, 'state_dim'], float][source]

Compute next state and reward for a transition.

Processes one step of the forest management system:
  1. Choose to cut or wait

  2. If wait:
    • Check for fire (probability p if waiting)

    • If fire, reset to age 0 with no reward

    • If no fire, age increases by 1 (up to S-1) and receive r1 reward if in oldest state

  3. If cut:
    • Receive reward r2 if in oldest state and 1 otherwise

    • Reset to age 0

Parameters:
  • state – Current state vector [state_dim]

  • action – Action vector [action_dim]

  • random_event – Random event vector [event_dim]

Returns:

Tuple containing the next state vector [state_dim] and the immediate reward

Perishable Inventory Problems

class mdpax.problems.perishable_inventory.de_moor_single_product.DeMoorSingleProductPerishable(config: DeMoorSingleProductPerishableConfig | None = None, **kwargs)[source]

Bases: Problem

Perishable inventory MDP problem from De Moor et al. (2022).

Models a single-product, single-echelon, periodic review perishable inventory replenishment problem where all stock has the same remaining useful life at arrival.

State Space (state_dim = lead_time + max_useful_life - 1):
Vector containing:
  • Orders in transit: [lead_time-1] elements in range [0, max_order_quantity]

  • Stock by age: [max_useful_life] elements in range [0, max_order_quantity], ordered with oldest units on the right

Action Space (action_dim = 1):
Vector containing:
  • Order quantity: 1 element in range [0, max_order_quantity]

Random Events (event_dim = 1):
Vector containing:
  • Demand: 1 element in range [0, max_demand]

Dynamics:
  1. Place replenishment order

  2. Sample demand from truncated, discretized gamma distribution

  3. Issue stock using FIFO or LIFO policy

  4. Age remaining stock one period and discard expired units

  5. Reward is negative of total costs:
    • Variable ordering costs (per unit ordered)

    • Shortage costs (per unit of unmet demand)

    • Wastage costs (per unit that expires)

    • Holding costs (per unit in stock at end of period)

  6. Receive order placed lead_time - 1 periods ago immediately before the next period
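Steps 3 and 4 can be sketched in plain Python, with stock held as a list ordered with the oldest units on the right. This is an illustration only: the function names are hypothetical, costs and orders in transit are left out, and the real class operates on JAX arrays.

```python
def issue_stock_sketch(stock, demand, policy="fifo"):
    """Issue units to meet demand; return (remaining_stock, unmet_demand).

    FIFO issues the oldest units first (from the right of the list),
    LIFO issues the newest units first (from the left).
    """
    remaining = list(stock)
    if policy == "fifo":
        order = range(len(remaining) - 1, -1, -1)
    else:
        order = range(len(remaining))
    unmet = demand
    for i in order:
        issued = min(remaining[i], unmet)
        remaining[i] -= issued
        unmet -= issued
    return remaining, unmet


def age_stock_sketch(stock, arriving=0):
    """Age stock by one period; return (new_stock, expired_units).

    The rightmost (oldest) units expire and `arriving` units enter as the newest.
    """
    expired = stock[-1]
    return [arriving] + list(stock[:-1]), expired
```

For example, with stock `[3, 2, 1]` and demand 2, FIFO leaves `[3, 1, 0]` while LIFO leaves `[1, 2, 1]`, so the choice of issuing policy changes how many units later expire.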

Parameters:
  • config – Configuration object. If provided, keyword arguments are ignored.

  • **kwargs – Parameters matching DeMoorSingleProductPerishableConfig. See Config class for detailed parameter descriptions.

References

Config

alias of DeMoorSingleProductPerishableConfig

property name: str

A unique identifier for this problem type

random_event_probability(state: Float[Array, 'state_dim'], action: Float[Array, 'action_dim'], random_event: Float[Array, 'event_dim']) float[source]

Compute probability of random event given state and action.

Demand follows a discretized gamma distribution with mean demand_gamma_mean and coefficient of variation demand_gamma_cov. The demand distribution is independent of the current state and action.
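For reference, a gamma distribution with mean m and coefficient of variation c has shape 1/c² and scale m·c². The sketch below performs only that conversion; the function name is hypothetical, and how the class discretizes and truncates the distribution is not shown here.

```python
def gamma_shape_scale_sketch(mean, cov):
    """Convert mean and coefficient of variation to gamma (shape, scale).

    Uses mean = shape * scale and cov = 1 / sqrt(shape).
    """
    shape = 1.0 / cov ** 2
    scale = mean * cov ** 2
    return shape, scale
```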

Parameters:
  • state – Current state vector [state_dim]

  • action – Action vector [action_dim]

  • random_event – Random event vector [event_dim]

Returns:

Probability of this demand value occurring

state_to_index(state: Float[Array, 'state_dim']) int[source]

Convert state vector to index.

Parameters:

state – State vector to convert [state_dim]

Returns:

Index of the state in state_space

transition(state: Float[Array, 'state_dim'], action: Float[Array, 'action_dim'], random_event: Float[Array, 'event_dim']) tuple[Float[Array, 'state_dim'], float][source]

Compute next state and reward for a transition.

Processes one step of the perishable inventory system:
  1. Place replenishment order

  2. Sample demand from truncated, discretized gamma distribution

  3. Issue stock using FIFO or LIFO policy

  4. Age remaining stock one period and discard expired units

  5. Reward is negative of total costs:
    • Variable ordering costs (per unit ordered)

    • Shortage costs (per unit of unmet demand)

    • Wastage costs (per unit that expires)

    • Holding costs (per unit in stock at end of period)

  6. Receive order placed lead_time - 1 periods ago immediately before the next period

Parameters:
  • state – Current state vector [state_dim]

  • action – Action vector [action_dim]

  • random_event – Random event vector [event_dim]

Returns:

Tuple containing the next state vector [state_dim] and the immediate reward

class mdpax.problems.perishable_inventory.hendrix_two_product.HendrixTwoProductPerishable(config: HendrixTwoProductPerishableConfig | None = None, **kwargs)[source]

Bases: Problem

Two-product perishable inventory MDP problem from Hendrix et al. (2019).

Models a two-product, single-echelon, periodic review perishable inventory replenishment problem where all stock has the same remaining useful life at arrival and there is the possibility of substitution between products.

State Space (state_dim = 2 * max_useful_life):
Vector containing:
  • Product A stock by age: [max_useful_life] elements in range [0, max_order_quantity_a], ordered with oldest units on the right

  • Product B stock by age: [max_useful_life] elements in range [0, max_order_quantity_b], ordered with oldest units on the right

Action Space (action_dim = 2):
Vector containing:
  • Product A order quantity: 1 element in range [0, max_order_quantity_a]

  • Product B order quantity: 1 element in range [0, max_order_quantity_b]

Random Events (event_dim = 2):
Vector containing:
  • Product A units issued: 1 element in range [0, max_stock_a]

  • Product B units issued: 1 element in range [0, max_stock_b]

Dynamics:
  1. Place replenishment order

  2. Random event determines units issued of each product, incorporating both:
    • Poisson-distributed demand for each product

    • Possible substitution of product A for product B when demand for B exceeds its stock

  3. Issue stock using FIFO policy for each product

  4. Age remaining stock one period and discard expired units

  5. Reward is revenue from units issued less variable ordering costs

  6. Receive order placed at the start of the period immediately before the next period

Parameters:
  • config – Configuration object. If provided, keyword arguments are ignored.

  • **kwargs – Parameters matching HendrixTwoProductPerishableConfig. See Config class for detailed parameter descriptions.

References

Note

The three random elements in the transition are the demand for each product and the number of units of demand for product B willing to accept substitution with product A. For consistency with the original implementation, the random events are taken to be the number of units of each product issued. The transition is deterministic given the number of units of each product issued.

Config

alias of HendrixTwoProductPerishableConfig

initial_value(state: Float[Array, 'state_dim']) float[source]

Return initial value estimate for a given state.

Initial value estimate based on one-step ahead expected sales revenue.

Parameters:

state – State vector [state_dim]

Returns:

Initial value estimate for the given state

property name: str

A unique identifier for this problem type

random_event_probability(state: Float[Array, 'state_dim'], action: Float[Array, 'action_dim'], random_event: Float[Array, 'event_dim']) float[source]

Compute probability of random event given state and action.

The number of units issued of each product follows a compound distribution:
  • Demand for each product is Poisson distributed

  • For product B, if demand exceeds stock, excess demand can be satisfied by product A with binomial probability
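As a simplified illustration, the sketch below computes the distribution of units issued for a single product when substitution is ignored: issued units equal Poisson demand capped at the available stock. It does not reproduce the compound distribution used by the class, and the function names are hypothetical.

```python
import math


def poisson_pmf(k, mean):
    """Probability that Poisson demand with the given mean equals k."""
    return math.exp(-mean) * mean ** k / math.factorial(k)


def issued_probability_sketch(issued, stock, mean_demand):
    """P(units issued = issued) when issued = min(demand, stock), no substitution."""
    if issued > stock:
        return 0.0
    if issued < stock:
        # Demand was below stock, so it was met exactly.
        return poisson_pmf(issued, mean_demand)
    # All stock issued: demand was at least as large as stock.
    return 1.0 - sum(poisson_pmf(k, mean_demand) for k in range(stock))
```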

Parameters:
  • state – Current state vector [state_dim]

  • action – Action vector [action_dim]

  • random_event – Random event vector [event_dim]

Returns:

Probability of this combination of issued units for both products

state_to_index(state: Float[Array, 'state_dim']) int[source]

Convert state vector to index.

Parameters:

state – State vector to convert [state_dim]

Returns:

Integer index of the state in state_space

transition(state: Float[Array, 'state_dim'], action: Float[Array, 'action_dim'], random_event: Float[Array, 'event_dim']) tuple[Float[Array, 'state_dim'], float][source]

Compute next state and reward for a transition.

Processes one step of the two-product perishable inventory system:
  1. Place replenishment order

  2. Random event determines units issued of each product, incorporating both:
    • Poisson-distributed demand for each product

    • Possible substitution of product A for product B when demand for B exceeds its stock

  3. Issue stock using FIFO policy for each product

  4. Age remaining stock one period and discard expired units

  5. Reward is revenue from units issued less variable ordering costs

  6. Receive order placed at the start of the period immediately before the next period

Parameters:
  • state – Current state vector [state_dim]

  • action – Action vector [action_dim]

  • random_event – Random event vector [event_dim]

Returns:

Tuple containing the next state vector [state_dim] and the immediate reward

class mdpax.problems.perishable_inventory.mirjalili_platelet.MirjaliliPlateletPerishable(config: MirjaliliPlateletPerishableConfig | None = None, **kwargs)[source]

Bases: Problem

Platelet inventory MDP problem from Mirjalili (2022).

Models a single-product, single-echelon, periodic review perishable inventory replenishment problem for platelets in a hospital blood bank where the products have a fixed maximum useful life but uncertain remaining useful life at arrival. The distribution of remaining useful life at arrival may depend on the order quantity.

State Space (state_dim = max_useful_life):
Vector containing:
  • Weekday: 1 element in range [0, 6] (Monday to Sunday)

  • Stock by age: [max_useful_life-1] elements in range [0, max_order_quantity], ordered with oldest units on the right

Action Space (action_dim = 1):
Vector containing:
  • Order quantity: 1 element in range [0, max_order_quantity]

Random Events (event_dim = max_useful_life + 1):
Vector containing:
  • Demand: 1 element in range [0, max_demand]

  • Stock received by age: [max_useful_life] elements in range [0, max_order_quantity] summing to at most max_order_quantity

Dynamics:
  1. Place replenishment order

  2. Immediately receive the order, where the remaining useful life of the units at arrival is sampled from a multinomial distribution with parameters that may depend on the order quantity

  3. Sample demand from weekday-specific truncated negative binomial distribution

  4. Issue stock using OUFO (Oldest Units First Out) policy

  5. Age remaining stock one period and discard expired units

  6. Reward is negative of total costs:
    • Variable ordering costs (per unit ordered)

    • Fixed ordering costs (when order > 0)

    • Shortage costs (per unit of unmet demand)

    • Wastage costs (per unit that expires)

    • Holding costs (per unit in stock at end of period, including expiring units)

  7. Update weekday to next day of week

Parameters:
  • config – Configuration object. If provided, keyword arguments are ignored.

  • **kwargs – Parameters matching MirjaliliPlateletPerishableConfig. See Config class for detailed parameter descriptions.

References

Note

  • In the original source, the demand distribution is a truncated negative binomial distribution over the number of failures before reaching a specified number of successes, parameterized by n (target number of successes) and delta (expected value).

  • The probability of success of a trial is n/(n + delta).
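The parameterization in this note can be checked numerically with a short sketch. It evaluates the untruncated negative binomial probability mass function with success probability n/(n + delta); truncation at max_demand is not handled, and the function name is hypothetical.

```python
import math


def neg_binomial_pmf_sketch(k, n, delta):
    """P(k failures before n successes) with success probability n / (n + delta).

    n may be non-integer, so the coefficient is computed with log-gamma.
    Requires n > 0 and delta > 0.
    """
    p = n / (n + delta)
    log_coeff = math.lgamma(k + n) - math.lgamma(k + 1) - math.lgamma(n)
    return math.exp(log_coeff + n * math.log(p) + k * math.log(1.0 - p))


# Numerical check that the expected number of failures is delta.
n, delta = 3.0, 5.0
probs = [neg_binomial_pmf_sketch(k, n, delta) for k in range(500)]
mean_demand = sum(k * q for k, q in enumerate(probs))
```

With these inputs the probabilities sum to approximately one and the mean is approximately delta, consistent with delta being described as the expected value.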

Config

alias of MirjaliliPlateletPerishableConfig

property name: str

A unique identifier for this problem type

random_event_probability(state: Float[Array, 'state_dim'], action: Float[Array, 'action_dim'], random_event: Float[Array, 'event_dim']) float[source]

Compute probability of random event given state and action.

Combines demand probabilities (based on weekday) with order receipt probabilities (based on action).

Parameters:
  • state – Current state vector [state_dim]

  • action – Action vector [action_dim]

  • random_event – Random event vector [event_dim]

Returns:

Probability of this combination of demand and received stock

state_to_index(state: Float[Array, 'state_dim']) int[source]

Convert state vector to index.

Parameters:

state – State vector to convert [state_dim]

Returns:

Index of the state in state_space

transition(state: Float[Array, 'state_dim'], action: Float[Array, 'action_dim'], random_event: Float[Array, 'event_dim']) tuple[Float[Array, 'state_dim'], float][source]

Compute next state and reward for a transition.

Processes one step of the platelet inventory system:
  1. Place replenishment order

  2. Immediately receive the order, where the remaining useful life of the units at arrival is sampled from a multinomial distribution with parameters that may depend on the order quantity

  3. Sample demand from weekday-specific truncated negative binomial distribution

  4. Issue stock using OUFO (Oldest Units First Out) policy

  5. Age remaining stock one period and discard expired units

  6. Reward is negative of total costs:
    • Variable ordering costs (per unit ordered)

    • Fixed ordering costs (when order > 0)

    • Shortage costs (per unit of unmet demand)

    • Wastage costs (per unit that expires)

    • Holding costs (per unit in stock at end of period, including expiring units)

  7. Update weekday to next day of week

Parameters:
  • state – Current state vector [state_dim]

  • action – Action vector [action_dim]

  • random_event – Random event vector [event_dim]

Returns:

Tuple containing the next state vector [state_dim] and the immediate reward

Utils

Batch Processing

class mdpax.utils.batch_processing.BatchProcessor(n_states: int, state_dim: int, max_batch_size: int = 1024, pmap_device_count: Int[Array, ''] | None = None)[source]

Bases: object

Handles batching and padding of state spaces for parallel processing.

This class manages the batching of states for efficient parallel processing across multiple devices. It handles:
  • Determining batch sizes based on problem size and available devices

  • Padding state arrays to fit batch dimensions

  • Reshaping arrays for device distribution

  • Removing padding and batching from results

Parameters:
  • n_states – Total number of states in the problem.

  • state_dim – Dimensionality of each state vector.

  • max_batch_size – Maximum allowed size for batches. Defaults to 1024.

  • pmap_device_count – Number of JAX devices to use. If None, uses all available.

n_states

Total number of states in the problem.

Type:

int

state_dim

Dimensionality of each state vector.

Type:

int

n_devices

Number of devices being used.

Type:

int

batch_size

Actual batch size after adjusting for problem size and devices.

Type:

int

n_pad

Number of padding elements added.

Type:

int

n_batches

Number of batches per device.

Type:

int

Example

>>> n_states = 1000
>>> state_dim = 3
>>> processor = BatchProcessor(n_states, state_dim)
>>> states = jnp.zeros((n_states, state_dim))
>>> batched = processor.prepare_batches(states)  # Shape: [n_devices, n_batches, batch_size, 3]
>>> results = some_operation(batched)
>>> unbatched = processor.unbatch_results(results)  # Shape: [1000, ...]
property batch_shape: Tuple[int, int, int]

Get the shape of batched data.

Returns:

Tuple of (n_devices, n_batches, batch_size) indicating how the data will be distributed across devices and batches.

prepare_batches(states: Float[Array, 'n_states state_dim']) Float[Array, 'n_devices n_batches batch_size state_dim'][source]

Prepare states for batch processing.

Pads the state array if needed and reshapes it for distribution across devices.

Parameters:

states – Array of states with shape [n_states, state_dim].

Returns:

Array of batched and padded states with shape [n_devices, n_batches, batch_size, state_dim].

unbatch_results(batched_results: Float[Array, 'n_devices n_batches batch_size *dims']) Float[Array, 'n_states *dims'][source]

Remove batching and padding from results.

Parameters:

batched_results – Results from batch processing with shape [n_devices, n_batches, batch_size, *dims] where *dims are any additional dimensions from the operation.

Returns:

Array of unbatched and unpadded results with shape [n_states, *dims].
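The round trip through prepare_batches and unbatch_results can be sketched with nested Python lists. The sizing rule below (split states evenly across devices, cap the batch size, pad the remainder) is an assumption about how such a processor could work, not a description of BatchProcessor's internals; all function names are hypothetical.

```python
import math


def plan_batches_sketch(n_states, n_devices, max_batch_size=1024):
    """Return (batch_size, n_batches, n_pad) under the assumed sizing rule."""
    states_per_device = math.ceil(n_states / n_devices)
    batch_size = min(max_batch_size, states_per_device)
    n_batches = math.ceil(states_per_device / batch_size)
    n_pad = n_devices * n_batches * batch_size - n_states
    return batch_size, n_batches, n_pad


def prepare_batches_sketch(states, n_devices, n_batches, batch_size, pad_value=None):
    """Pad and reshape a flat list to [n_devices][n_batches][batch_size]."""
    total = n_devices * n_batches * batch_size
    padded = iter(list(states) + [pad_value] * (total - len(states)))
    return [[[next(padded) for _ in range(batch_size)]
             for _ in range(n_batches)]
            for _ in range(n_devices)]


def unbatch_results_sketch(batched, n_states):
    """Flatten the batched structure and drop the trailing padding."""
    flat = [x for device in batched for batch in device for x in batch]
    return flat[:n_states]


states = list(range(10))
batch_size, n_batches, n_pad = plan_batches_sketch(len(states), n_devices=2, max_batch_size=3)
batched = prepare_batches_sketch(states, 2, n_batches, batch_size)
restored = unbatch_results_sketch(batched, len(states))
```

Because padding is appended at the end and removed by slicing to n_states, the order of the original states is preserved through the round trip.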

Checkpointing

class mdpax.utils.checkpointing.CheckpointMixin[source]

Bases: ABC

Mixin to add checkpointing capabilities to a solver.

Quick Start:
  • Working with built-in problems? Use restore() to load checkpoints

  • Working in a notebook with a custom-defined problem? Use load_checkpoint()

  • Not sure? The system will automatically detect which mode to use and log accordingly

The checkpointing system uses Hydra’s structured configs, which are Python dataclasses that contain all the information needed to instantiate objects. These configs exist for all example Problems and Solvers in mdpax (see Forest, DeMoorSingleProductPerishable, etc.).

For custom problems, there are two options:

  1. Full Reconstruction:
    • Define your problem in a module (not a notebook)

    • Create a Hydra config class for your problem (see Forest for a simple example)

    • Allows automatic reconstruction of both problem and solver

  2. Manual Reconstruction:
    • Use when working in notebooks or without configs

    • Requires manually reconstructing problem and solver before loading state

    • More flexible but less automated
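For the full-reconstruction option, a structured config is a dataclass whose `_target_` field names the class to instantiate. The sketch below shows the general shape only: the class name, the import path in `_target_`, and the `size` field are hypothetical, and the exact fields mdpax expects are best taken from ForestConfig.

```python
from dataclasses import asdict, dataclass


@dataclass
class MyProblemConfig:
    """Hypothetical structured config for a custom problem."""

    # Import path of the class to build; must be importable from a module.
    _target_: str = "my_package.my_problem.MyProblem"
    # Constructor arguments of the problem follow as ordinary fields.
    size: int = 4


config = MyProblemConfig(size=8)
config_dict = asdict(config)
```

Keeping the problem in an importable module matters because the `_target_` path is resolved by import when the object is reconstructed.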

Required Implementation:
_restore_state_from_checkpoint(state: Dict[str, Any]) -> None:

Restore solver state from a checkpoint state dictionary.

checkpoint_dir

Directory where checkpoints are stored.

Type:

Path

checkpoint_frequency

Number of iterations between checkpoints, 0 to disable.

Type:

int

max_checkpoints

Maximum number of checkpoints to retain.

Type:

int

enable_async_checkpointing

Whether async checkpointing is enabled.

Type:

bool

checkpoint_manager

Orbax checkpoint manager instance.

Type:

checkpoint.CheckpointManager

Examples

>>> # Full Reconstruction with built-in problem
>>> from mdpax.problems import Forest  # Problem with Hydra config
>>> from mdpax.solvers import ValueIteration  # Solver with Hydra config
>>>
>>> problem = Forest(S=4)
>>> solver = ValueIteration(
...     problem=problem,
...     checkpoint_dir="checkpoints/run1",
...     checkpoint_frequency=5
... )
>>> solver.solve(max_iterations=100)
>>>
>>> # Later, full reconstruction:
>>> solver = ValueIteration.restore("checkpoints/run1")  # Recreates everything
>>> # Manual Reconstruction (e.g., custom problem in notebook)
>>> class MyProblem:  # Custom problem without config
...     def __init__(self, size):
...         self.size = size
...
>>> problem = MyProblem(size=4)
>>> solver = ValueIteration(
...     problem=problem,
...     checkpoint_dir="checkpoints/run1",
...     checkpoint_frequency=5
... )
>>> solver.solve(max_iterations=100)
>>>
>>> # Later, manual reconstruction:
>>> problem = MyProblem(size=4)  # Must recreate problem
>>> solver = ValueIteration(
...     problem=problem,
...     checkpoint_dir="checkpoints/run2",  # New save location
... )
>>> solver.load_checkpoint("checkpoints/run1")  # Load from original location
>>> solver.solve(max_iterations=50)  # New checkpoints go to run2
abstract _restore_state_from_checkpoint(state: dict[str, Any]) None[source]

Restore solver state from checkpoint.

Parameters:

state – Dictionary containing solver state from checkpoint.

property has_full_config: bool

Check if solver and problem have complete configs for reconstruction.

Checks that both solver and problem have configs with _target_ properties, which are required for Hydra to reconstruct the objects.

Returns:

True if both solver and problem have complete configs, False otherwise.

property is_checkpointing_enabled: bool

Check if checkpointing is enabled.

Returns:

True if checkpointing is properly configured and enabled, False otherwise.

load_checkpoint(checkpoint_dir: str | Path, step: int | None = None) None[source]

Load solver state from checkpoint.

Must be called on an already constructed solver instance with problem. Will load state from checkpoint_dir, but continue saving new checkpoints to the directory specified during construction (self.checkpoint_dir).

Parameters:
  • checkpoint_dir – Directory containing checkpoint to load from

  • step – Specific step to load. If None, loads the latest checkpoint.

Returns:

None

classmethod restore(checkpoint_dir: str | Path, step: int | None = None, new_checkpoint_dir: str | Path | None = None, checkpoint_frequency: int | None = None, max_checkpoints: int | None = None, enable_async_checkpointing: bool | None = None) Solver[source]

Load solver from checkpoint.

This class method reconstructs a solver instance from a checkpoint, using the stored config to recreate both the problem and solver with the correct parameters.

Parameters:
  • checkpoint_dir – Directory containing checkpoints.

  • step – Specific step to load. If None, loads the latest checkpoint.

  • new_checkpoint_dir – Optional new directory for future checkpoints. Useful when restoring to a different location.

  • checkpoint_frequency – Optional new checkpoint frequency.

  • max_checkpoints – Optional new maximum number of checkpoints.

  • enable_async_checkpointing – Optional new async checkpointing setting.

Returns:

Reconstructed solver instance with restored state.

save(step: int) None[source]

Save current solver state to checkpoint.

Parameters:

step – Current iteration/step number to associate with the checkpoint.

Returns:

None

Spaces

mdpax.utils.spaces.create_range_space(mins: Array, maxs: Array) tuple[Array, Callable][source]

Create a range-based discrete space and its indexing function.

Helper function for creating any discrete space (states, actions, or random events) that can be represented as ranges in each dimension. Creates both the space (all possible vectors) and a function to map vectors to indices.

The ranges in each dimension are inclusive of both mins and maxs. For example, if mins=[0] and maxs=[2], the space will include vectors [0], [1], and [2].

Parameters:
  • mins – Lower bounds for each dimension [dim], inclusive

  • maxs – Upper bounds for each dimension [dim], inclusive

Returns:

Tuple of (space, index_fn), where space is an array of all possible vectors [n_elements, dim] and index_fn is a function that maps a vector to its unique index.

Note

The 'clip' mode in index_fn means any vector will map to a valid index. This is necessary for compatibility with JAX’s jit compilation but may lead to unexpected results if the vector is not within the bounds. See https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ravel_multi_index.html
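The effect of clipping can be seen in a pure-Python sketch of the same idea. It enumerates the space in row-major order and clips each coordinate into its range before computing the index; it uses lists rather than JAX arrays, and the function name is hypothetical.

```python
from itertools import product


def create_range_space_sketch(mins, maxs):
    """Return (space, index_fn) for inclusive integer ranges in each dimension."""
    sizes = [hi - lo + 1 for lo, hi in zip(mins, maxs)]
    # The last dimension varies fastest (row-major order).
    space = [list(v) for v in product(*(range(lo, hi + 1) for lo, hi in zip(mins, maxs)))]

    def index_fn(vector):
        index = 0
        for value, lo, size in zip(vector, mins, sizes):
            # Clip the offset into [0, size - 1] so every vector maps to a valid index.
            offset = min(max(value - lo, 0), size - 1)
            index = index * size + offset
        return index

    return space, index_fn


space, to_index = create_range_space_sketch([0, 0], [1, 2])
in_bounds = to_index([1, 1])
out_of_bounds = to_index([5, 5])  # clipped to [1, 2], the last element
```

An out-of-bounds vector therefore silently maps to the index of the nearest in-bounds vector, which is why inputs should be validated before indexing if that behaviour is not wanted.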

Example

>>> # For a state space with 2 dimensions: [0,1] x [0,2] (inclusive)
>>> state_space, state_to_index = create_range_space(jnp.array([0, 0]), jnp.array([1, 2]))
>>> print(state_space)  # All 6 combinations including bounds
[[0 0]
 [0 1]
 [0 2]
 [1 0]
 [1 1]
 [1 2]]
>>> print(state_to_index(jnp.array([1, 1])))
4
>>> # For an action space with 1 dimension: [0,5] (inclusive)
>>> action_space, action_to_index = create_range_space(jnp.array([0]), jnp.array([5]))
>>> print(action_space)  # All 6 values from 0 to 5
[[0]
 [1]
 [2]
 [3]
 [4]
 [5]]