Source code for mdpax.solvers.relative_value_iteration

"""Relative value iteration solver for average reward MDPs."""

import chex
from hydra.conf import dataclass
from loguru import logger

from mdpax.core.problem import Problem, ProblemConfig
from mdpax.core.solver import SolverConfig, SolverInfo, SolverState
from mdpax.solvers.value_iteration import ValueIteration
from mdpax.utils.logging import get_convergence_format
from mdpax.utils.types import ValueFunction


@dataclass
class RelativeValueIterationConfig(SolverConfig):
    """Configuration for the Relative Value Iteration solver.

    Args:
        problem: Optional problem configuration. If not provided, can pass a Problem
            instance directly to the solver. If a Problem instance with a config is
            provided to the solver, its config will be extracted and stored here.
        epsilon: Convergence threshold for value changes
        gamma: Discount factor (fixed at 1.0 for relative value iteration)
        max_batch_size: Maximum states to process in parallel on each device
        jax_double_precision: Whether to use float64 precision
        verbose: Logging verbosity level (must be 0-4)
        checkpoint_dir: Directory to store checkpoints
        checkpoint_frequency: How often to save checkpoints (must be non-negative, 0 to disable)
        max_checkpoints: Maximum number of checkpoints to keep (must be non-negative)
        enable_async_checkpointing: Whether to save checkpoints asynchronously

    Example:
        >>> # Using a Problem instance with config
        >>> problem = Forest(S=4)  # Has config
        >>> solver = RelativeValueIteration(problem=problem)  # Config extracted automatically

        >>> # Or using a ProblemConfig directly
        >>> problem_config = ForestConfig(S=4)
        >>> config = RelativeValueIterationConfig(problem=problem_config)
        >>> solver = RelativeValueIteration(config=config)

        >>> # Or using a Problem instance without config
        >>> problem = CustomProblem()  # No config
        >>> solver = RelativeValueIteration(problem=problem)  # Checkpointing will be disabled
    """

    _target_: str = "mdpax.solvers.relative_value_iteration.RelativeValueIteration"
    problem: ProblemConfig | None = None
    gamma: float = 1.0
    epsilon: float = 1e-3
    max_batch_size: int = 1024
    jax_double_precision: bool = True
    verbose: int = 2
    checkpoint_dir: str | None = None
    checkpoint_frequency: int = 0
    max_checkpoints: int = 1
    enable_async_checkpointing: bool = True

    def __post_init__(self) -> None:
        """Validate configuration parameters."""
        if self.problem is not None and not isinstance(self.problem, ProblemConfig):
            raise TypeError("problem must be a ProblemConfig if provided")
        if not self.gamma == 1.0:
            raise ValueError("gamma must be 1.0 for relative value iteration")
        if self.epsilon <= 0:
            raise ValueError("epsilon must be positive")
        if self.max_batch_size <= 0:
            raise ValueError("max_batch_size must be positive")
        if self.checkpoint_frequency < 0:
            raise ValueError("checkpoint_frequency must be non-negative")
        if self.max_checkpoints < 0:
            raise ValueError("max_checkpoints must be non-negative")
        if not 0 <= self.verbose <= 4:
            raise ValueError("verbose must be between 0 and 4")


@chex.dataclass(frozen=True)
class RelativeValueIterationInfo(SolverInfo):
    """Runtime information for relative value iteration.

    Attributes:
        gain: Current gain term for value function adjustment,
            converges to average reward per timestep
    """

    gain: float


@chex.dataclass(frozen=True)
class RelativeValueIterationState(SolverState):
    """Runtime state for relative value iteration.

    Attributes:
        values: Current value function [n_states]
        policy: Current policy [n_states, action_dim]
        info: Solver metadata including gain term
    """

    info: RelativeValueIterationInfo


[docs] class RelativeValueIteration(ValueIteration): """Relative value iteration solver for average reward MDPs. This solver extends standard value iteration to handle average reward MDPs by: - Using gamma=1.0 (no discounting) - Tracking and subtracting a gain term to handle unbounded values Convergence testing is based on the span of value differences. Supports checkpointing for long-running problems using the CheckpointMixin. Args: problem: Problem instance or None if using config config: Configuration object. If provided, other kwargs are ignored. **kwargs: Parameters matching :class:`RelativeValueIterationConfig`. See Config class for detailed parameter descriptions. """ Config = RelativeValueIterationConfig def __init__( self, problem: Problem | None = None, config: RelativeValueIterationConfig | None = None, **kwargs, ): """Initialize the solver.""" super().__init__(problem=problem, config=config, **kwargs) def _setup_convergence_testing(self) -> None: """Setup convergence test.""" self._convergence_test_fn = self._get_span self.conv_threshold = self.epsilon self._convergence_desc = "span" # Get convergence format for logging convergence metrics self.convergence_format = get_convergence_format(float(self.conv_threshold)) def _initialize_solver_state_elements(self) -> None: """Initialize solver state elements.""" super()._initialize_solver_state_elements() self.gain = 0.0 def _iteration_step(self) -> tuple[ValueFunction, float]: """Perform one iteration of the solution algorithm. Returns: Tuple of (new values, convergence measure) where new values has shape [n_states] """ # Get new values using parent's batch processing new_values, _ = super()._iteration_step() # Calculate value differences new_values = new_values - self.gain span = self._get_span(new_values, self.values) self.gain = new_values[-1] return new_values, span
[docs] def solve(self, max_iterations: int = 2000) -> RelativeValueIterationState: """Run solver to convergence or max iterations. Args: max_iterations: Maximum number of iterations to run Returns: SolverState containing final values [n_states], optimal policy [n_states, action_dim], and SolverInfo including iteration count and gain """ for _ in range(max_iterations): self.iteration += 1 new_values, conv = self._iteration_step() self.values = new_values logger.info( f"Iteration {self.iteration}: span: {conv:{self.convergence_format}}, gain: {self.gain:.4f}" ) if conv < self.epsilon: logger.info( f"Convergence threshold reached at iteration {self.iteration}" ) break if ( self.is_checkpointing_enabled and self.iteration % self.checkpoint_frequency == 0 ): self.save(self.iteration) if conv >= self.epsilon: logger.info("Maximum iterations reached") # Final checkpoint if enabled if self.is_checkpointing_enabled: self.save(self.iteration) # Extract policy if converged or on final iteration logger.info("Extracting policy") self.policy = self._extract_policy() logger.info("Policy extracted") logger.success("Relative value iteration completed") return self.solver_state
@property def solver_state(self) -> RelativeValueIterationState: """Get solver state for checkpointing.""" return RelativeValueIterationState( values=self.values, policy=self.policy, info=RelativeValueIterationInfo( iteration=self.iteration, gain=self.gain, ), ) def _restore_state_from_checkpoint( self, solver_state: RelativeValueIterationState ) -> None: """Restore solver state from checkpoint.""" self.values = solver_state.values self.policy = solver_state.policy self.iteration = solver_state.info.iteration self.gain = solver_state.info.gain