Source code for mdpax.solvers.periodic_value_iteration

"""Value iteration solver with periodic convergence checking."""

import chex
import jax.numpy as jnp
import numpy as np
from hydra.conf import MISSING, dataclass
from jaxtyping import Array, Float
from loguru import logger

from mdpax.core.problem import Problem, ProblemConfig
from mdpax.core.solver import SolverConfig, SolverInfo, SolverState
from mdpax.solvers.value_iteration import ValueIteration
from mdpax.utils.logging import get_convergence_format
from mdpax.utils.types import (
    ValueFunction,
)


@dataclass
class PeriodicValueIterationConfig(SolverConfig):
    """Configuration for the Periodic Value Iteration solver.

    This solver extends value iteration to check for convergence over a
    specified period length rather than between consecutive iterations.

    Args:
        problem: Optional problem configuration. If not provided, can pass a Problem
            instance directly to the solver. If a Problem instance with a config is
            provided to the solver, its config will be extracted and stored here.
        period: Number of iterations to check for periodic convergence (must be positive)
        gamma: Discount factor in [0,1]
        epsilon: Convergence threshold for value changes (must be positive)
        max_batch_size: Maximum states to process in parallel on each device (must be positive)
        jax_double_precision: Whether to use float64 precision
        verbose: Logging verbosity level (must be 0-4)
        checkpoint_dir: Directory to store checkpoints
        checkpoint_frequency: How often to save checkpoints (must be non-negative, 0 to disable)
        max_checkpoints: Maximum number of checkpoints to keep (must be non-negative)
        enable_async_checkpointing: Whether to save checkpoints asynchronously
        clear_value_history_on_convergence: Whether to clear history after convergence

    Example:
        >>> # Using a Problem instance with config
        >>> problem = Forest(S=4)  # Has config
        >>> solver = PeriodicValueIteration(problem=problem, period=2)  # Config extracted automatically

        >>> # Or using a ProblemConfig directly
        >>> problem_config = ForestConfig(S=4)
        >>> config = PeriodicValueIterationConfig(problem=problem_config, period=2)
        >>> solver = PeriodicValueIteration(config=config)

        >>> # Or using a Problem instance without config
        >>> problem = CustomProblem()  # No config
        >>> solver = PeriodicValueIteration(problem=problem, period=2)  # Checkpointing will be disabled
    """

    _target_: str = "mdpax.solvers.periodic_value_iteration.PeriodicValueIteration"
    problem: ProblemConfig | None = None
    period: int = MISSING
    gamma: float = 0.99
    epsilon: float = 1e-3
    max_batch_size: int = 1024
    jax_double_precision: bool = True
    verbose: int = 2
    checkpoint_dir: str | None = None
    checkpoint_frequency: int = 0
    max_checkpoints: int = 1
    enable_async_checkpointing: bool = True
    clear_value_history_on_convergence: bool = True

    def __post_init__(self) -> None:
        """Validate configuration parameters."""
        if self.problem is not None and not isinstance(self.problem, ProblemConfig):
            raise TypeError("problem must be a ProblemConfig if provided")
        if self.period <= 0:
            raise ValueError("Period must be positive")
        if self.gamma == 1.0 and self.period < 2:
            raise ValueError("Period must be at least 2 for undiscounted case")
        if not 0 <= self.gamma <= 1:
            raise ValueError("gamma must be between 0 and 1")
        if self.epsilon <= 0:
            raise ValueError("epsilon must be positive")
        if self.max_batch_size <= 0:
            raise ValueError("max_batch_size must be positive")
        if self.checkpoint_frequency < 0:
            raise ValueError("checkpoint_frequency must be non-negative")
        if self.max_checkpoints < 0:
            raise ValueError("max_checkpoints must be non-negative")
        if not 0 <= self.verbose <= 4:
            raise ValueError("verbose must be between 0 and 4")


@chex.dataclass(frozen=True)
class PeriodicValueIterationInfo(SolverInfo):
    """Runtime information for periodic value iteration.

    Attributes:
        value_history: History of value functions [period, n_states]
        history_index: Current position in circular history buffer
        period: Length of period being checked
    """

    value_history: Float[Array, "period n_states"]
    history_index: int
    period: int


@chex.dataclass(frozen=True)
class PeriodicValueIterationState(SolverState):
    """Runtime state for periodic value iteration.

    Attributes:
        values: Current value function [n_states]
        policy: Current policy [n_states, action_dim]
        info: Solver metadata including value history
    """

    info: PeriodicValueIterationInfo


[docs] class PeriodicValueIteration(ValueIteration): """Periodic value iteration solver for MDPs This is particularly useful for problems with periodic structure in the state space, where it may require fewer iterations to reach convergence than standard value iteration. Convergence testing is based on the span of value differences over a period. For undiscounted problems, this is simply the span of differences between current values and values from one period ago. For discounted problems, we sum the consecutive differences over the period, scaling each by the appropriate discount factor. The solver stores a history of values over the period to perform this comparison. Supports checkpointing for long-running problems using the CheckpointMixin. Args: problem: Problem instance or None if using config config: Configuration object. If provided, other kwargs are ignored. **kwargs: Parameters matching :class:`PeriodicValueIterationConfig`. See Config class for detailed parameter descriptions. """ Config = PeriodicValueIterationConfig def __init__( self, problem: Problem | None = None, config: PeriodicValueIterationConfig | None = None, **kwargs, ): """Initialize solver.""" super().__init__(problem=problem, config=config, **kwargs) def _setup_config( self, problem: Problem, config: SolverConfig | None = None, **kwargs ) -> None: super()._setup_config(problem, config, **kwargs) self.clear_value_history_on_convergence = ( self.config.clear_value_history_on_convergence ) self.period = self.config.period def _setup_convergence_testing(self) -> None: """Setup convergence test. For periodic value iteration, we use the span of differences over a period as the convergence measure. The convergence threshold is simply epsilon (no adjustment for gamma needed since we're comparing over a period). """ self._convergence_test_fn = self._get_periodic_span self.conv_threshold = self.epsilon self._convergence_desc = "period_span" # Get convergence format for logging convergence metrics self.convergence_format = get_convergence_format(float(self.conv_threshold)) def _initialize_solver_state_elements(self) -> None: """Initialize solver state elements.""" super()._initialize_solver_state_elements() # Initialize value history in CPU memory self.value_history = np.zeros((self.period + 1, self.problem.n_states)) self.history_index: int = 0 self.value_history[0] = np.array(self.values) def _get_periodic_span( self, new_values: ValueFunction, old_values: ValueFunction, history_index: int, period: int, value_history: Float[Array, "period_plus_one n_states"], iteration: int, gamma: float, ) -> float: """Calculate convergence measure based on span of differences over a period. For undiscounted problems (gamma=1), this is the span of differences between current values and values from one period ago. For discounted problems, we sum consecutive differences over the period, adjusting for the discount factor. Args: new_values: Current value function [n_states] old_values: Previous value function [n_states] (unused) history_index: Current position in circular history buffer period: Length of period to check value_history: History of values [period+1, n_states] iteration: Current iteration number gamma: Discount factor Returns: Span of differences over the period """ if iteration < period: return float("inf") if gamma == 1.0: return self._calculate_period_span_without_discount( new_values, history_index, period, value_history ) else: return self._calculate_period_span_with_discount( new_values, history_index, period, value_history, iteration, gamma ) def _calculate_period_span_without_discount( self, values: ValueFunction, history_index: int, period: int, value_history: Float[Array, "period_plus_one n_states"], ) -> float: """Calculate span of value changes over a period without discounting. For problems without discounting (gamma=1), we simply compare current values with values from one period ago. The circular buffer is arranged so that index (i+1) % (period+1) contains values from one period ago. Args: values: Current value function [n_states] history_index: Current position in circular history buffer period: Length of period to check value_history: History of values [period+1, n_states] Returns: Span of differences over the period """ prev_index = (history_index + 1) % (period + 1) values_prev = jnp.array(value_history[prev_index]) return self._get_span(values, values_prev) def _calculate_period_span_with_discount( self, values: ValueFunction, history_index: int, period: int, value_history: Float[Array, "period_plus_one n_states"], iteration: int, gamma: float, ) -> float: """Calculate span of undiscounted value changes over a period. For discounted problems (gamma<1), we sum the differences between consecutive steps in the period, adjusting for the discount factor. The differences are scaled by 1/gamma^(current_iteration - p - 1) to remove discounting. Args: values: Current value function [n_states] history_index: Current position in circular history buffer period: Length of period to check value_history: History of values [period+1, n_states] iteration: Current iteration number gamma: Discount factor Returns: Span of differences over the period """ period_deltas = np.zeros_like(values) for p in range(period): curr_index = (history_index - p) % (period + 1) prev_index = (curr_index - 1) % (period + 1) values_curr = value_history[curr_index] values_prev = value_history[prev_index] period_deltas += (values_curr - values_prev) / ( gamma ** (iteration - p - 1) ) period_deltas = jnp.array(period_deltas) return jnp.max(period_deltas) - jnp.min(period_deltas) def _iteration_step(self) -> tuple[ValueFunction, float]: """Perform one iteration of the solution algorithm. Returns: Tuple of (new values, convergence measure) where new values has shape [n_states] """ # Get new values using parent's batch processing new_values = self._update_values( self.batched_states, self.problem.action_space, self.problem.random_event_space, self.gamma, self.values, ) # Store values in history (CPU) self.history_index = (self.history_index + 1) % (self.period + 1) self.value_history[self.history_index] = np.array(new_values) # Calculate convergence using the test function conv = self._convergence_test_fn( new_values, self.values, self.history_index, self.period, self.value_history, self.iteration, self.gamma, ) return new_values, conv
[docs] def solve(self, max_iterations: int = 2000) -> PeriodicValueIterationState: """Run solver to convergence or max iterations. Args: max_iterations: Maximum number of iterations to run Returns: SolverState containing final values [n_states], optimal policy [n_states, action_dim], and SolverInfo including iteration count and value history. """ for _ in range(max_iterations): self.iteration += 1 new_values, conv = self._iteration_step() self.values = new_values logger.info( f"Iteration {self.iteration}: {self._convergence_desc}: {conv:{self.convergence_format}}" ) # Check convergence if conv < self.conv_threshold: logger.info( f"Convergence threshold reached at iteration {self.iteration}" ) break if ( self.is_checkpointing_enabled and self.iteration % self.checkpoint_frequency == 0 ): self.save(self.iteration) if conv >= self.conv_threshold: logger.info("Maximum iterations reached") # Final checkpoint if enabled if self.is_checkpointing_enabled: self.save(self.iteration) # Extract policy if converged or on final iteration logger.info("Extracting policy") self.policy = self._extract_policy() logger.info("Policy extracted") logger.success("Periodic value iteration completed") if conv < self.conv_threshold: self._clear_value_history() return self.solver_state
def _clear_value_history(self) -> None: """Clear value history to free memory after convergence.""" if self.clear_value_history_on_convergence: self.value_history = None @property def solver_state(self) -> PeriodicValueIterationState: """Get solver state for checkpointing.""" return PeriodicValueIterationState( values=self.values, policy=self.policy, info=PeriodicValueIterationInfo( iteration=self.iteration, value_history=self.value_history, history_index=self.history_index, period=self.period, ), ) def _restore_state_from_checkpoint( self, solver_state: PeriodicValueIterationState ) -> None: """Restore solver state from checkpoint.""" self.values = solver_state.values self.policy = solver_state.policy self.iteration = solver_state.info.iteration self.value_history = solver_state.info.value_history self.history_index = solver_state.info.history_index self.period = solver_state.info.period