# Source code for mdpax.solvers.semi_async_value_iteration

"""Semi-asynchronous value iteration solver with different batch ordering strategies."""

import chex
import jax
import jax.numpy as jnp
import jax.random as random
from hydra.conf import dataclass
from jaxtyping import Array, Float
from loguru import logger

from mdpax.core.problem import Problem, ProblemConfig
from mdpax.core.solver import SolverConfig, SolverInfo, SolverState
from mdpax.solvers.value_iteration import ValueIteration
from mdpax.utils.types import (
    ActionSpace,
    ActionVector,
    BatchedStates,
    RandomEventSpace,
    StateBatch,
    StateVector,
    ValueFunction,
)


@dataclass
class SemiAsyncValueIterationConfig(SolverConfig):
    """Settings for the Semi-Asynchronous Value Iteration solver.

    The solver updates states batch by batch (asynchronously across batches)
    and runs batches in parallel across devices.

    Args:
        problem: Optional problem configuration; must be a ProblemConfig when
            given. May be omitted when a Problem instance is handed to the
            solver directly, in which case that instance's config (if any) is
            extracted and stored here.
        gamma: Discount factor in [0,1]
        epsilon: Convergence threshold for value changes (must be positive)
        max_batch_size: Maximum states to process in parallel on each device
            (must be positive)
        jax_double_precision: Whether to use float64 precision
        verbose: Logging verbosity level (must be 0-4)
        checkpoint_dir: Directory to store checkpoints
        checkpoint_frequency: How often to save checkpoints (must be
            non-negative, 0 to disable)
        max_checkpoints: Maximum number of checkpoints to keep (must be
            non-negative)
        enable_async_checkpointing: Whether to save checkpoints asynchronously
        convergence_test: Strategy for testing convergence ("span" or "max_diff")
        shuffle_states: Whether to shuffle state update order each iteration
        random_seed: Random seed for shuffling states
    """

    _target_: str = "mdpax.solvers.semi_async_value_iteration.SemiAsyncValueIteration"
    problem: ProblemConfig | None = None
    gamma: float = 0.99
    epsilon: float = 1e-3
    max_batch_size: int = 1024
    jax_double_precision: bool = True
    verbose: int = 2
    checkpoint_dir: str | None = None
    checkpoint_frequency: int = 0
    max_checkpoints: int = 1
    enable_async_checkpointing: bool = True
    convergence_test: str = "span"
    shuffle_states: bool = False
    random_seed: int = 42

    def __post_init__(self) -> None:
        """Check every field and raise on the first invalid one.

        Raises:
            TypeError: If ``problem`` is given but is not a ProblemConfig.
            ValueError: If any numeric or string option is out of range.
        """
        problem_ok = self.problem is None or isinstance(self.problem, ProblemConfig)
        if not problem_ok:
            raise TypeError("problem must be a ProblemConfig if provided")

        # Ordered (is-violated predicate, message) pairs. Predicates are
        # wrapped in lambdas so each is only evaluated once all earlier
        # checks have passed, exactly like a chain of ``if`` statements.
        range_checks = (
            (lambda: not 0 <= self.gamma <= 1, "gamma must be between 0 and 1"),
            (lambda: self.epsilon <= 0, "epsilon must be positive"),
            (lambda: self.max_batch_size <= 0, "max_batch_size must be positive"),
            (
                lambda: self.checkpoint_frequency < 0,
                "checkpoint_frequency must be non-negative",
            ),
            (
                lambda: self.max_checkpoints < 0,
                "max_checkpoints must be non-negative",
            ),
            (lambda: not 0 <= self.verbose <= 4, "verbose must be between 0 and 4"),
            (
                lambda: self.convergence_test not in ["span", "max_diff"],
                "Convergence test must be 'span' or 'max_diff'",
            ),
        )
        for is_violated, message in range_checks:
            if is_violated():
                raise ValueError(message)


@chex.dataclass(frozen=True)
class SemiAsyncValueIterationInfo(SolverInfo):
    """Runtime information for semi-async value iteration.

    Immutable (frozen) chex dataclass that extends ``SolverInfo`` with the
    batch ordering, so the ordering travels with the rest of the solver
    metadata (presumably including the iteration counter inherited from
    ``SolverInfo`` — confirm against that class).

    Attributes:
        batch_order: Current ordering of batches, or ``None`` when no explicit
            ordering has been assigned.
    """

    # Required field (no default): callers must pass it explicitly, using
    # None when there is no batch ordering.
    batch_order: jnp.ndarray | None


@chex.dataclass(frozen=True)
class SemiAsyncValueIterationState(SolverState):
    """Runtime state for semi-async value iteration.

    Immutable (frozen) chex dataclass. It narrows the type of ``info`` to
    :class:`SemiAsyncValueIterationInfo`; the ``values`` and ``policy`` fields
    are presumably inherited from ``SolverState`` (not visible here — confirm).

    Attributes:
        values: Current value function [n_states]
        policy: Current policy [n_states, action_dim]
        info: Solver metadata including batch ordering info
    """

    # Solver metadata, specialised to carry the batch ordering.
    info: SemiAsyncValueIterationInfo


class SemiAsyncValueIteration(ValueIteration):
    """Semi-asynchronous value iteration solver with flexible batch ordering.

    This solver extends standard value iteration by:

    - Processing states in batches with updated values immediately available
      to subsequent batches on the same device
    - Supporting a fixed or random (shuffle) state update order

    Supports checkpointing for long-running problems using the CheckpointMixin.

    Args:
        problem: Problem instance or None if using config
        config: Configuration object. If provided, other kwargs are ignored.
        **kwargs: Parameters matching :class:`SemiAsyncValueIterationConfig`.
            See Config class for detailed parameter descriptions.
    """

    Config = SemiAsyncValueIterationConfig

    def __init__(
        self,
        problem: Problem | None = None,
        config: SemiAsyncValueIterationConfig | None = None,
        **kwargs,
    ):
        """Initialize the solver."""
        super().__init__(problem=problem, config=config, **kwargs)

    def _setup_config(
        self, problem: Problem, config: SolverConfig | None = None, **kwargs
    ) -> None:
        """Run the parent config setup, then seed the PRNG key used for shuffling."""
        super()._setup_config(problem, config, **kwargs)
        self.key = random.PRNGKey(self.config.random_seed)

    def _setup_jax_functions(self) -> None:
        """Create the jitted callables used by this solver (after the parent's)."""
        super()._setup_jax_functions()

        # JIT compile core computations
        # NOTE(review): static_argnums=(0,) marks ``state`` as a static
        # argument, which requires callers to pass a hashable value rather
        # than an array. No caller is visible in this file; confirm.
        self._jitted_calculate_updated_state_action_value = jax.jit(
            self._calculate_updated_state_action_value, static_argnums=(0,)
        )

        # JIT compile state shuffling functions
        self._jitted_shuffle_states = jax.jit(self._shuffle_states)
        self._jitted_reorder_values = jax.jit(self._reorder_values)

        # JIT compile batch ordering. ``n_batches`` determines the output
        # shape (jnp.arange(n_batches)), so it has to be a static argument.
        self._jitted_compute_batch_assignments = jax.jit(
            self._compute_batch_assignments, static_argnums=(0,)
        )

    def _initialize_solver_state_elements(self) -> None:
        """Initialize parent solver state, then clear the batch ordering."""
        super()._initialize_solver_state_elements()
        self.batch_order = None
        self.inverse_order = None

    def _build_padding_mask(self, batched_states: BatchedStates) -> jnp.ndarray:
        """Build a boolean mask marking the padded slots of ``batched_states``.

        Slots are numbered in flattened order over the first three axes; every
        slot whose number is >= ``problem.n_states`` is marked as padding. This
        assumes the batch processor appends all padding at the end.

        Args:
            batched_states: Batched states whose first three axes are
                (n_devices, n_batches, batch_size)

        Returns:
            Boolean array [n_devices, n_batches, batch_size], True for padding
        """
        n_devices, n_batches, batch_size = batched_states.shape[:3]
        n_slots = n_devices * n_batches * batch_size
        return (jnp.arange(n_slots) >= self.problem.n_states).reshape(
            n_devices, n_batches, batch_size
        )

    def _shuffle_states(
        self, key: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Shuffle the states for random processing order.

        Args:
            key: PRNG key for random shuffling

        Returns:
            Tuple of (shuffled_state_idxs, padded_batched_states, padding_mask)
        """
        # Generate permutation of state indices
        shuffled_state_idxs = jax.random.permutation(
            key, jnp.arange(self.problem.n_states)
        )

        # Shuffle states using these indices
        shuffled_states = self.problem.state_space[shuffled_state_idxs]

        # Prepare batched states from shuffled states
        padded_batched_states = self.batch_processor.prepare_batches(shuffled_states)

        # Create padding mask based on problem size and batch shape
        padding_mask = self._build_padding_mask(padded_batched_states)

        return shuffled_state_idxs, padded_batched_states, padding_mask

    def _reorder_values(
        self, shuffled_state_idxs: jnp.ndarray, values: jnp.ndarray
    ) -> jnp.ndarray:
        """Reorder values back to original state order.

        Args:
            shuffled_state_idxs: Indices used to shuffle states
            values: Values in shuffled order

        Returns:
            Values reordered to match original state order
        """
        # argsort of a permutation is its inverse permutation
        return values[jnp.argsort(shuffled_state_idxs)]

    def _get_value_next_state(
        self, next_state: StateVector, values: ValueFunction
    ) -> float:
        """Lookup the value of the next state in the value function.

        Args:
            next_state: Next state vector [state_dim]
            values: Current value function [n_states]

        Returns:
            Value of the next state
        """
        # Values are always in their original positions, so just get the index
        return values[self.problem.state_to_index(next_state)]

    def _calculate_updated_state_action_value(
        self,
        state: StateVector,
        action: ActionVector,
        random_events: RandomEventSpace,
        gamma: float,
        values: ValueFunction,
    ) -> float:
        """Calculate the expected value for a state-action pair.

        Similar to value iteration but uses :meth:`_get_value_next_state` for
        next state lookups, so values are read from their original positions
        in the state space.

        Args:
            state: Current state vector [state_dim]
            action: Action vector [action_dim]
            random_events: All possible random events [n_events, event_dim]
            gamma: Discount factor
            values: Current value function [n_states]

        Returns:
            Expected value for the state-action pair
        """
        next_states, single_step_rewards = jax.vmap(
            self.problem.transition,
            in_axes=(None, None, 0),
        )(state, action, random_events)

        # Look up the current value of every possible next state
        next_state_values = jax.vmap(
            self._get_value_next_state,
            in_axes=(0, None),
        )(next_states, values)

        probs = jax.vmap(
            self.problem.random_event_probability,
            in_axes=(None, None, 0),
        )(state, action, random_events)

        # Expectation over random events of (reward + discounted next value)
        return (single_step_rewards + gamma * next_state_values).dot(probs)

    def _calculate_updated_value(
        self,
        state: StateVector,
        actions: ActionSpace,
        random_events: RandomEventSpace,
        gamma: float,
        values: ValueFunction,
    ) -> float:
        """Calculate the maximum expected value over all actions for a state.

        Args:
            state: Current state vector [state_dim]
            actions: All possible actions [n_actions, action_dim]
            random_events: All possible random events [n_events, event_dim]
            gamma: Discount factor
            values: Current value function [n_states]

        Returns:
            Maximum expected value over all actions
        """
        return jnp.max(
            jax.vmap(
                self._calculate_updated_state_action_value,
                in_axes=(None, 0, None, None, None),
            )(state, actions, random_events, gamma, values)
        )

    def _calculate_updated_value_state_batch(
        self,
        carry: tuple[ActionSpace, RandomEventSpace, float, ValueFunction],
        state_batch: StateBatch,
    ) -> tuple[tuple, Float[Array, "batch_size"]]:
        """Calculate updated values for a batch of states.

        Similar to value iteration but ensures values reflect most recent updates.

        Args:
            carry: Tuple of (actions, random_events, gamma, values)
            state_batch: Batch of states to update [batch_size, state_dim]

        Returns:
            Tuple of (carry, new_values) where new_values has shape [batch_size]
        """
        actions, random_events, gamma, values = carry
        new_values = jax.vmap(
            self._calculate_updated_value,
            in_axes=(0, None, None, None, None),
        )(state_batch, actions, random_events, gamma, values)
        return carry, new_values

    def _batch_get_indices(self, state_batch: StateBatch) -> jnp.ndarray:
        """Get indices for a batch of states.

        Args:
            state_batch: Batch of states [batch_size, state_dim]

        Returns:
            Array of indices [batch_size]
        """
        return jax.vmap(self.problem.state_to_index)(state_batch)

    def _calculate_updated_value_scan_state_batches(
        self,
        carry: tuple[ActionSpace, RandomEventSpace, float, ValueFunction],
        batched_input: tuple[BatchedStates, jnp.ndarray],
    ) -> Float[Array, "n_devices n_batches batch_size"]:
        """Update values for multiple batches of states.

        Key difference from value iteration: values in carry are updated after
        each batch to make them available for subsequent batches. The batch
        order determines the sequence of processing, but values are always
        stored in their original positions in the state space.

        NOTE(review): the shapes below are as seen by the caller of the pmapped
        wrapper; inside the wrapper the leading device axis is presumably
        already mapped away, so the scan runs over the batch axis — confirm
        against the parent class's pmap setup.

        Args:
            carry: Tuple of (actions, random_events, gamma, values)
            batched_input: Tuple of (batched_states, padding_mask) where:
                - batched_states has shape [n_devices, n_batches, batch_size, state_dim]
                - padding_mask has shape [n_devices, n_batches, batch_size]

        Returns:
            Array of updated values for all states [n_devices, n_batches, batch_size]
        """
        batched_states, padding_mask = batched_input

        def scan_fn(carry, batch_input):
            actions, random_events, gamma, current_values = carry
            batch, batch_padding_mask = batch_input

            # Process current batch using current_values from carry
            _, new_batch_values = self._calculate_updated_value_state_batch(
                (actions, random_events, gamma, current_values), batch
            )

            # Get indices for this batch - these map directly to positions in state space
            batch_indices = self._batch_get_indices(batch)

            # Update values in their original positions in state space.
            # Padding slots are redirected to an out-of-bounds index and
            # dropped, so a padding state that happens to map onto a real
            # state's index can never overwrite that state's fresh value.
            drop_index = current_values.shape[0]
            scatter_indices = jnp.where(batch_padding_mask, drop_index, batch_indices)
            updated_values = current_values.at[scatter_indices].set(
                new_batch_values, mode="drop"
            )

            return (actions, random_events, gamma, updated_values), new_batch_values

        # Process batches in the specified order
        # NOTE(review): batch_order is read at trace time, so a compiled
        # wrapper may keep using the ordering seen on its first call — confirm.
        if self.batch_order is not None:
            batched_states = batched_states[self.batch_order]
            padding_mask = padding_mask[self.batch_order]

        # Scan over batches, threading the updated values through the carry
        _, new_values = jax.lax.scan(scan_fn, carry, (batched_states, padding_mask))
        return new_values

    def _update_values(
        self,
        batched_states: BatchedStates,
        actions: ActionSpace,
        random_events: RandomEventSpace,
        gamma: float,
        values: ValueFunction,
    ) -> ValueFunction:
        """Update values for a batch of states using parallel processing.

        The semi-async nature comes from:

        1. Optionally shuffling states before processing
        2. Processing states in batches with immediate value updates
        3. Values are always stored in their correct positions

        Args:
            batched_states: Batched state vectors [n_devices, n_batches, batch_size, state_dim]
            actions: Action space [n_actions, action_dim]
            random_events: Random event space [n_events, event_dim]
            gamma: Discount factor
            values: Current value function [n_states]

        Returns:
            Array of updated values [n_states] in natural state order
        """
        shuffled_state_idxs = None
        if self.config.shuffle_states:
            # Random strategy: draw a fresh permutation and re-batch the states
            self.key, subkey = random.split(self.key)
            shuffled_state_idxs, batched_states, padding_mask = (
                self._jitted_shuffle_states(subkey)
            )
        else:
            # Fixed order: only the padding mask for the given batches is needed
            padding_mask = self._build_padding_mask(batched_states)

        # Process batches semi-asynchronously
        padded_batched_values = self._calculate_updated_value_scan_state_batches_pmap(
            (actions, random_events, gamma, values), (batched_states, padding_mask)
        )

        # Unpad the values
        new_values = self._unbatch_results(padded_batched_values)

        # If states were shuffled, reorder values back to original state order
        if shuffled_state_idxs is not None:
            new_values = self._jitted_reorder_values(shuffled_state_idxs, new_values)

        return new_values

    def _iteration_step(self) -> tuple[ValueFunction, float]:
        """Run one iteration of semi-async value iteration.

        Returns:
            Tuple of (new values, convergence measure) where new values has
            shape [n_states]
        """
        new_values = self._update_values(
            self.batched_states,
            self.problem.action_space,
            self.problem.random_event_space,
            self.gamma,
            self.values,
        )

        # Check convergence using selected test
        conv = self._convergence_test_fn(new_values, self.values)
        return new_values, conv

    def solve(self, max_iterations: int = 2000) -> SemiAsyncValueIterationState:
        """Run solver to convergence or max iterations.

        Args:
            max_iterations: Maximum number of iterations to run. With a value
                of 0 or less no iteration is run and the policy is extracted
                from the current values.

        Returns:
            SolverState containing final values [n_states], optimal policy
            [n_states, action_dim], and SolverInfo including batch ordering info
        """
        # Tracked explicitly so the post-loop logic never reads ``conv``,
        # which does not exist when the loop body is never entered.
        converged = False

        for _ in range(max_iterations):
            self.iteration += 1
            new_values, conv = self._iteration_step()
            self.values = new_values

            logger.info(
                f"Iteration {self.iteration}: {self._convergence_desc}: {conv:{self.convergence_format}}"
            )

            if conv < self.conv_threshold:
                logger.info(
                    f"Convergence threshold reached at iteration {self.iteration}"
                )
                converged = True
                break

            if (
                self.is_checkpointing_enabled
                and self.iteration % self.checkpoint_frequency == 0
            ):
                self.save(self.iteration)

        if not converged:
            logger.info("Maximum iterations reached")

        # Final checkpoint if enabled
        if self.is_checkpointing_enabled:
            self.save(self.iteration)

        # Extract policy if converged or on final iteration
        logger.info("Extracting policy")
        self.policy = self._extract_policy()
        logger.info("Policy extracted")

        logger.success("Semi-async value iteration completed")
        return self.solver_state

    @property
    def solver_state(self) -> SemiAsyncValueIterationState:
        """Get solver state for checkpointing."""
        return SemiAsyncValueIterationState(
            values=self.values,
            policy=self.policy,
            info=SemiAsyncValueIterationInfo(
                iteration=self.iteration,
                batch_order=self.batch_order,
            ),
        )

    def _restore_state_from_checkpoint(
        self, solver_state: SemiAsyncValueIterationState
    ) -> None:
        """Restore solver state from checkpoint.

        Args:
            solver_state: State previously produced by :attr:`solver_state`
        """
        self.values = solver_state.values
        self.policy = solver_state.policy
        self.iteration = solver_state.info.iteration
        self.batch_order = solver_state.info.batch_order

    def _compute_batch_assignments(
        self,
        n_batches: int,
        is_random: bool,
        key: jnp.ndarray,
    ) -> tuple[tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
        """Determine batch processing order for semi-async updates.

        Args:
            n_batches: Number of batches to order
            is_random: Whether to use random ordering
            key: PRNG key for random ordering

        Returns:
            Nested tuple ``((order, inverse_order), key)``. For random ordering
            the returned key is a fresh key split from the input; for fixed
            ordering the input key is returned unchanged.
        """

        def random_order(key):
            new_key, subkey = random.split(key)
            # Generate permutation for batch order
            order = random.permutation(subkey, jnp.arange(n_batches))
            return (order, self._compute_inverse_order(order)), new_key

        def fixed_order(key):
            order = jnp.arange(n_batches)
            # The identity ordering is its own inverse
            return (order, order), key

        return jax.lax.cond(
            is_random,
            random_order,
            fixed_order,
            key,
        )

    def _compute_inverse_order(self, order: jnp.ndarray) -> jnp.ndarray:
        """Compute inverse mapping for batch order.

        Args:
            order: Current batch order

        Returns:
            Inverse mapping for value lookups, i.e. ``inverse[order[i]] == i``
        """
        inverse = jnp.zeros_like(order)
        return inverse.at[order].set(jnp.arange(len(order)))

    def _reorder_batches(self) -> None:
        """Update batch processing order for semi-async updates."""
        # Compute new batch order and its inverse in one step
        is_random = self.config.shuffle_states
        n_batches = self.batched_states.shape[1]

        (new_order, new_inverse), self.key = self._jitted_compute_batch_assignments(
            n_batches, is_random, self.key
        )

        self.batch_order = new_order
        self.inverse_order = new_inverse