"""Policy iteration solver for MDPs."""
import jax
import jax.numpy as jnp
from hydra.conf import dataclass
from jaxtyping import Array, Float
from loguru import logger
from mdpax.core.problem import Problem, ProblemConfig
from mdpax.core.solver import SolverConfig, SolverState
from mdpax.solvers.value_iteration import ValueIteration
from mdpax.utils.types import (
ActionSpace,
BatchedStates,
Policy,
RandomEventSpace,
StateBatch,
ValueFunction,
)
@dataclass
class PolicyIterationConfig(SolverConfig):
    """Settings for the Policy Iteration solver.

    The solver runs policy iteration with work spread across devices; every
    policy evaluation sweep is computed in batches so that large state spaces
    remain tractable.

    Args:
        problem: Optional problem configuration. May be omitted when a Problem
            instance is handed to the solver directly; if that instance
            carries a config, it is extracted and stored here. Must be a
            ProblemConfig when given.
        gamma: Discount factor in [0, 1].
        epsilon: Convergence threshold for value changes during policy
            evaluation (must be positive).
        max_batch_size: Maximum number of states processed in parallel on each
            device (must be positive).
        jax_double_precision: Whether to use float64 precision.
        verbose: Logging verbosity level (must be 0-4).
        checkpoint_dir: Directory in which checkpoints are stored.
        checkpoint_frequency: How often to save checkpoints (must be
            non-negative, 0 to disable).
        max_checkpoints: Maximum number of checkpoints to keep (must be
            non-negative).
        enable_async_checkpointing: Whether to save checkpoints asynchronously.
        max_eval_iter: Maximum number of sweeps per policy evaluation when the
            iterative method is used (must be positive).
        convergence_test: Strategy for testing convergence ("span" or
            "max_diff").
        reset_values_for_each_policy_eval: Whether every policy evaluation
            restarts from the initial values.

    Raises:
        TypeError: If ``problem`` is given but is not a ProblemConfig.
        ValueError: If any numeric or string option is outside its allowed
            range (see ``__post_init__``).
    """

    _target_: str = "mdpax.solvers.policy_iteration.PolicyIteration"
    problem: ProblemConfig | None = None
    gamma: float = 0.99
    epsilon: float = 1e-3
    max_batch_size: int = 1024
    jax_double_precision: bool = True
    verbose: int = 2
    checkpoint_dir: str | None = None
    checkpoint_frequency: int = 0
    max_checkpoints: int = 1
    enable_async_checkpointing: bool = True
    max_eval_iter: int = 100
    convergence_test: str = "span"
    reset_values_for_each_policy_eval: bool = False

    def __post_init__(self) -> None:
        """Check every option and raise on the first one that is invalid.

        The checks run in a fixed order, so when several options are wrong
        the error reported is always the one for the earliest check.
        """
        problem_is_acceptable = self.problem is None or isinstance(
            self.problem, ProblemConfig
        )
        if not problem_is_acceptable:
            raise TypeError("problem must be a ProblemConfig if provided")

        if not (0 <= self.gamma <= 1):
            raise ValueError("gamma must be between 0 and 1")

        if self.epsilon <= 0:
            raise ValueError("epsilon must be positive")

        if self.max_batch_size <= 0:
            raise ValueError("max_batch_size must be positive")

        if self.checkpoint_frequency < 0:
            raise ValueError("checkpoint_frequency must be non-negative")

        if self.max_checkpoints < 0:
            raise ValueError("max_checkpoints must be non-negative")

        if not (0 <= self.verbose <= 4):
            raise ValueError("verbose must be between 0 and 4")

        if self.max_eval_iter <= 0:
            raise ValueError("max_eval_iter must be positive")

        supported_tests = ("span", "max_diff")
        if self.convergence_test not in supported_tests:
            raise ValueError("convergence_test must be 'span' or 'max_diff'")
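# [docs]  (appears to be a documentation "view source" link marker left over from extraction; not part of the module)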
class PolicyIteration(ValueIteration):
    """Policy iteration solver for MDPs.

    This solver implements policy iteration with parallel state updates across
    devices. States are automatically batched and padded for efficient parallel
    processing.

    The algorithm alternates between:

    1. Policy evaluation: computing values for the current policy using an
       iterative method with batched updates
    2. Policy improvement: one-step lookahead to find a better policy

    The algorithm is considered to have converged when the policy does not
    change between successive iterations. For each iteration, the convergence
    of policy evaluation is tested using the span of differences in values
    between successive iterations by default (convergence_test='span').

    By default, the value estimates from the previous policy are used as the
    starting estimates for the next policy evaluation
    (reset_values_for_each_policy_eval=False). To start policy evaluation from
    the initial values in each iteration, set
    reset_values_for_each_policy_eval=True.

    To match the behaviour of pymdptoolbox's PolicyIteration class (with
    iterative evaluation) use the following arguments:

    - reset_values_for_each_policy_eval=True
    - convergence_test='max_diff'
    - max_eval_iter=10000
    - epsilon=1e-4

    Supports checkpointing for long-running problems using the CheckpointMixin.

    Args:
        problem: Problem instance or None if using config
        config: Configuration object. If provided, other kwargs are ignored.
        **kwargs: Parameters matching :class:`PolicyIterationConfig`.
            See Config class for detailed parameter descriptions.
    """

    Config = PolicyIterationConfig

    def __init__(
        self,
        problem: Problem | None = None,
        config: PolicyIterationConfig | None = None,
        **kwargs,
    ):
        """Initialize the solver by delegating to the parent constructor."""
        super().__init__(problem=problem, config=config, **kwargs)

    def _setup_jax_functions(self) -> None:
        """Setup JAX functions for policy iteration.

        Runs the parent setup first, then builds two pmapped callables. In
        both, the carry tuple is broadcast to every device (``None`` axes) and
        only the batched states are split along their leading axis.
        """
        super()._setup_jax_functions()

        # Policy evaluation: carry is (actions, random_events, gamma, values, policy)
        self._calculate_policy_values_scan_state_batches_pmap = jax.pmap(
            self._calculate_policy_values_scan_state_batches,
            in_axes=((None, None, None, None, None), 0),
        )

        # Policy extraction: the scanned function is not defined in this class
        # (presumably inherited from ValueIteration); it takes a 4-element carry.
        self._extract_policy_idx_scan_state_batches_pmap = jax.pmap(
            self._extract_policy_idx_scan_state_batches,
            in_axes=((None, None, None, None), 0),
        )

    def _initialize_solver_state_elements(self) -> None:
        """Initialize values, policy and the iteration counter.

        Statement order matters: values are zeroed first so that the fallback
        in ``_initialize_policy`` (policy extraction with zero values) sees
        zeros, and only afterwards are the real starting values computed.
        """
        # Set values to zero for policy initialization
        self.values = jnp.zeros(self.problem.n_states)
        self.policy = self._initialize_policy()

        self.values = self._initialize_values(self.batched_states)
        if self.config.reset_values_for_each_policy_eval:
            # Store initial values so each policy evaluation can restart from them
            self.initial_values = self.values

        self.iteration = 0

    def _initialize_policy(self) -> Policy:
        """Initialize policy as custom policy or by maximizing immediate reward.

        If the problem provides an initial policy, it is applied to every
        state in ``problem.state_space``. If the problem raises
        ``NotImplementedError``, the policy is extracted via
        ``_extract_policy`` instead (with zero values, this maximizes
        immediate reward).

        Returns:
            The initial policy, one entry per state.
        """
        try:
            # Try to use problem's initial policy
            return jax.vmap(self.problem.initial_policy)(self.problem.state_space)
        except NotImplementedError:
            # Extract policy using zero values (maximizes immediate reward)
            return self._extract_policy()

    def _calculate_policy_value_state_batch(
        self,
        carry: tuple[ActionSpace, RandomEventSpace, float, ValueFunction, Policy],
        state_batch: StateBatch,
    ) -> tuple[tuple, Float[Array, "batch_size"]]:
        """Calculate values for a batch of states using their policy actions.

        Written as a ``jax.lax.scan`` body: the carry is passed through
        unchanged and the per-batch values are the scan output.

        Args:
            carry: Tuple of (actions, random_events, gamma, values, policy)
            state_batch: Batch of states to update [batch_size, state_dim]

        Returns:
            Tuple of (carry, new_values) where new_values has shape [batch_size]
        """
        # The action space is part of the carry but is not needed here: the
        # action for each state comes from the policy.
        _actions, random_events, gamma, values, policy = carry

        # Look up each state's row in the policy (rows already hold action vectors)
        batch_indices = jax.vmap(self.problem.state_to_index)(state_batch)
        batch_actions = policy[batch_indices]

        # Evaluate each (state, policy action) pair; the remaining arguments
        # are shared across the batch
        new_values = jax.vmap(
            self._calculate_updated_state_action_value,
            in_axes=(0, 0, None, None, None),
        )(state_batch, batch_actions, random_events, gamma, values)
        return carry, new_values

    def _calculate_policy_values_scan_state_batches(
        self,
        carry: tuple[ActionSpace, RandomEventSpace, float, ValueFunction, Policy],
        padded_batched_states: BatchedStates,
    ) -> Float[Array, "n_devices n_batches batch_size"]:
        """Calculate policy values for multiple batches of states.

        Uses jax.lax.scan to loop over batches efficiently. This function is
        wrapped with ``jax.pmap`` in ``_setup_jax_functions``; the shapes below
        describe the pmapped call, where the leading ``n_devices`` axis is
        split across devices.

        Args:
            carry: Tuple of (actions, random_events, gamma, values, policy)
            padded_batched_states: States prepared for batch processing
                Shape: [n_devices, n_batches, batch_size, state_dim]

        Returns:
            Array of updated values for all states
            [n_devices, n_batches, batch_size]
        """
        # Only the stacked per-batch outputs are needed; the final carry is
        # identical to the input carry.
        _, new_values = jax.lax.scan(
            self._calculate_policy_value_state_batch,
            carry,
            padded_batched_states,
        )
        return new_values

    def _calculate_policy_values(
        self,
        policy: Policy,
        values: ValueFunction,
    ) -> ValueFunction:
        """Calculate new values using only the actions specified by the policy.

        Uses batched computation and parallel processing across devices for
        efficiency with large state spaces.

        Args:
            policy: Current policy, one row of action components per state
            values: Current values [n_states]

        Returns:
            New values as a flat array
        """
        # Process batches in parallel across devices
        padded_batched_values = self._calculate_policy_values_scan_state_batches_pmap(
            (
                self.problem.action_space,
                self.problem.random_event_space,
                self.gamma,
                values,
                policy,
            ),
            self.batched_states,
        )

        # Unbatch and remove padding, then flatten to one value per state
        new_values = self._unbatch_results(padded_batched_values)
        return new_values.reshape(-1)

    def _evaluate_policy(
        self,
        policy: Policy,
        starting_values: ValueFunction | None = None,
    ) -> ValueFunction:
        """Evaluate policy using iterative updates with batched computation.

        Similar to value iteration but only considers the current policy's
        action for each state. Sweeps are repeated until the solver's
        convergence measure drops below ``self.conv_threshold`` or
        ``config.max_eval_iter`` sweeps have been performed.

        Args:
            policy: Current policy to evaluate, one row per state
            starting_values: Values to start from [n_states]. If None, starts
                from ``self.initial_values`` when
                ``config.reset_values_for_each_policy_eval`` is set, otherwise
                from the current ``self.values``.

        Returns:
            The most recently computed values for the policy
        """
        if starting_values is not None:
            values = starting_values
        elif self.config.reset_values_for_each_policy_eval:
            values = self.initial_values
        else:
            values = self.values

        # Iterate until values converge or max iterations reached
        for eval_iter in range(self.config.max_eval_iter):
            # Calculate new values using only the policy's actions
            new_values = self._calculate_policy_values(policy, values)

            # Compare the new sweep against the previous one
            conv = self._convergence_test_fn(new_values, values)

            # Always keep the newest iterate, so the converged sweep is
            # returned rather than discarded (same as the max-iteration exit).
            values = new_values

            if self.verbose:
                logger.debug(
                    f"Policy evaluation iteration {eval_iter+1}: "
                    f"{self._convergence_desc}: {conv:{self.convergence_format}}"
                )

            if conv < self.conv_threshold:
                break

        return values

    def _iteration_step(self) -> tuple[Policy, int]:
        """Perform one iteration of policy iteration.

        1. Evaluate current policy to get values (starting values determined
           by config); the result is stored in ``self.values``
        2. Improve policy using one-step lookahead

        ``self.policy`` is not modified here; the caller installs the new
        policy.

        Returns:
            Tuple of (new_policy, n_changed) where:
                new_policy: Improved policy, one row per state
                n_changed: Number of states where policy changed
        """
        # Evaluate current policy (starting point determined by config)
        self.values = self._evaluate_policy(self.policy)

        # Improve policy using the policy extraction not defined in this class
        new_policy = self._extract_policy()

        # A state counts as changed if any component of its action vector
        # differs. Converted to a Python int to match the annotated return type.
        n_changed = int(jnp.any(new_policy != self.policy, axis=1).sum())
        return new_policy, n_changed

    def solve(self, max_iterations: int = 1000) -> SolverState:
        """Run solver to convergence or max iterations.

        Policy iteration is guaranteed to converge in finite iterations for
        discounted MDPs. Stops when policy stops changing or max iterations
        reached.

        Args:
            max_iterations: Maximum number of iterations to run. With a value
                of 0 no iteration is performed and the current solver state is
                returned.

        Returns:
            SolverState containing final values, the final policy,
            and SolverInfo including iteration count
        """
        for _ in range(max_iterations):
            self.iteration += 1

            # Do one iteration of policy iteration
            new_policy, n_changed = self._iteration_step()
            self.policy = new_policy

            # Log progress
            logger.info(
                f"Iteration {self.iteration}: Policy updated for {n_changed} state(s) "
                f"({n_changed/self.problem.n_states*100:.2f}%)"
            )

            # Check for convergence
            if n_changed == 0:
                logger.info(f"Policy converged at iteration {self.iteration}")
                break

            # Save checkpoint if enabled
            if (
                self.is_checkpointing_enabled
                and self.iteration % self.checkpoint_frequency == 0
            ):
                self.save(self.iteration)
        else:
            # Runs only when the loop ended without converging, including the
            # case where max_iterations is 0 and the loop body never ran.
            logger.info("Maximum iterations reached")

        # Final checkpoint if enabled
        if self.is_checkpointing_enabled:
            self.save(self.iteration)

        logger.success("Policy iteration completed")
        return self.solver_state