Source code for mdpax.utils.checkpointing

"""Checkpointing functionality for solvers."""

from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Any

import orbax.checkpoint as checkpoint
from hydra.utils import instantiate
from loguru import logger
from omegaconf import OmegaConf

from mdpax.core.solver import Solver


[docs] class CheckpointMixin(ABC): """Mixin to add checkpointing capabilities to a solver. Quick Start: - Working with built-in problems? Use restore() to load checkpoints - Working in a notebook with a custom defined problem? Use load_checkpoint() - Not sure? The system will automatically detect which mode to use and log accordingly The checkpointing system uses Hydra's structured configs, which are Python dataclasses that contain all information needed to instantiate objects. These configs exist for all example Problems and Solvers in mdpax (see Forest, DeMoorSingleProduct, etc.). For custom problems, there are two options: 1. Full Reconstruction: - Define your problem in a module (not a notebook) - Create a Hydra config class for your problem (see Forest for a simple example) - Allows automatic reconstruction of both problem and solver 2. Manual Reconstruction: - Use when working in notebooks or without configs - Requires manually reconstructing problem and solver before loading state - More flexible but less automated Required Implementation: _restore_state_from_checkpoint(state: Dict[str, Any]) -> None: Restore solver state from a checkpoint state dictionary. Attributes: checkpoint_dir (Path): Directory where checkpoints are stored. checkpoint_frequency (int): Number of iterations between checkpoints, 0 to disable. max_checkpoints (int): Maximum number of checkpoints to retain. enable_async_checkpointing (bool): Whether async checkpointing is enabled. checkpoint_manager (checkpoint.CheckpointManager): Orbax checkpoint manager instance. Examples: >>> # Full Reconstruction with built-in problem >>> from mdpax.problems import Forest # Problem with Hydra config >>> from mdpax.solvers import ValueIteration # Solver with Hydra config >>> >>> problem = Forest(S=4) >>> solver = ValueIteration( ... problem=problem, ... checkpoint_dir="checkpoints/run1", ... checkpoint_frequency=5 ... ) >>> solver.solve(max_iterations=100) >>> >>> # Later, full reconstruction: >>> solver = ValueIteration.restore("checkpoints/run1") # Recreates everything >>> # Manual Reconstruction (e.g., custom problem in notebook) >>> class MyProblem: # Custom problem without config ... def __init__(self, size): ... self.size = size ... >>> problem = MyProblem(size=4) >>> solver = ValueIteration( ... problem=problem, ... checkpoint_dir="checkpoints/run1", ... checkpoint_frequency=5 ... ) >>> solver.solve(max_iterations=100) >>> >>> # Later, manual reconstruction: >>> problem = MyProblem(size=4) # Must recreate problem >>> solver = ValueIteration( ... problem=problem, ... checkpoint_dir="checkpoints/run2", # New save location ... ) >>> solver.load_checkpoint("checkpoints/run1") # Load from original location >>> solver.solve(max_iterations=50) # New checkpoints go to run2 """ def _setup_checkpointing( self, checkpoint_dir: str | Path | None = None, checkpoint_frequency: int = 0, max_checkpoints: int = 1, enable_async_checkpointing: bool = True, ) -> None: """Configure checkpointing behavior for the solver. Args: checkpoint_dir: Directory for storing checkpoints. If None, creates a timestamped directory under 'checkpoints/{problem_name}/'. checkpoint_frequency: How often to save checkpoints in iterations. Set to 0 to disable checkpointing. max_checkpoints: Maximum number of checkpoints to retain. Older checkpoints are automatically removed. enable_async_checkpointing: Whether to use asynchronous checkpointing for better performance. """ # Validation if checkpoint_frequency < 0: raise ValueError("checkpoint_frequency must be non-negative") if max_checkpoints < 0: raise ValueError("max_checkpoints must be non-negative") # Store basic settings self.checkpoint_frequency = checkpoint_frequency self.max_checkpoints = max_checkpoints self.enable_async_checkpointing = enable_async_checkpointing self.checkpoint_manager = None # Early return if checkpointing not requested if self.checkpoint_frequency == 0: logger.info("Checkpointing not enabled") return # Setup checkpoint directory if checkpoint_dir is None: from datetime import datetime current_datetime = datetime.now().strftime("%Y%m%d/%H:%M:%S") checkpoint_dir = Path( f"checkpoints/{self.problem.name}/{current_datetime}/" ) self.checkpoint_dir = Path(checkpoint_dir).absolute() self.checkpoint_dir.mkdir(parents=True, exist_ok=True) # Create checkpoint manager and save config self.checkpoint_manager = self._create_checkpoint_manager( self.checkpoint_dir, max_checkpoints, enable_async_checkpointing ) if self.has_full_config: self._save_solver_config() logger.info( "Full checkpointing enabled with problem and solver reconstruction" ) else: logger.info( "Lightweight checkpointing enabled - problem and solver must be " "reconstructed manually" ) logger.info( f"Saving checkpoints every {self.checkpoint_frequency} " f"iteration(s) to {self.checkpoint_dir}" ) @property def is_checkpointing_enabled(self) -> bool: """Check if checkpointing is enabled. Returns: True if checkpointing is properly configured and enabled, False otherwise. """ return ( hasattr(self, "checkpoint_frequency") and self.checkpoint_frequency > 0 and hasattr(self, "checkpoint_manager") ) @property def has_full_config(self) -> bool: """Check if solver and problem have complete configs for reconstruction. Checks that both solver and problem have configs with _target_ properties, which are required for Hydra to reconstruct the objects. Returns: True if both solver and problem have complete configs, False otherwise. """ return ( hasattr(self.problem, "config") and self.problem.config is not None and hasattr(self, "config") and self.config is not None and hasattr(self.problem.config, "_target_") and self.problem.config._target_ is not None and hasattr(self.config, "_target_") and self.config._target_ is not None )
[docs] def load_checkpoint( self, checkpoint_dir: str | Path, step: int | None = None ) -> None: """Load solver state from checkpoint. Must be called on an already constructed solver instance with problem. Will load state from checkpoint_dir, but continue saving new checkpoints to the directory specified during construction (self.checkpoint_dir). Args: checkpoint_dir: Directory containing checkpoint to load from step: Specific step to load. If None, loads the latest checkpoint. Returns: None """ checkpoint_dir = Path(checkpoint_dir).absolute() load_manager = self._create_checkpoint_manager( checkpoint_dir=checkpoint_dir, max_checkpoints=1, # Only need to keep one when loading enable_async_checkpointing=True, ) template_cp_state = self.solver_state step = step or load_manager.latest_step() if step is None: raise ValueError(f"No checkpoints found in {checkpoint_dir}") cp_state = load_manager.restore( step, args=checkpoint.args.StandardRestore(template_cp_state), ) self._restore_state_from_checkpoint(cp_state)
def _save_solver_config(self) -> None: """Save the solver config to the checkpoint directory.""" OmegaConf.save(self.config, self.checkpoint_dir / "config.yaml") @classmethod def _create_checkpoint_manager( cls, checkpoint_dir: str | Path, max_checkpoints: int, enable_async_checkpointing: bool, ) -> checkpoint.CheckpointManager: """Create an Orbax checkpoint manager. Args: checkpoint_dir: Directory for storing checkpoints. max_checkpoints: Maximum number of checkpoints to retain. enable_async_checkpointing: Whether to use async checkpointing. Returns: Configured Orbax checkpoint manager. """ # Configure Orbax options = checkpoint.CheckpointManagerOptions( max_to_keep=max_checkpoints, create=True, enable_async_checkpointing=enable_async_checkpointing, ) return checkpoint.CheckpointManager( checkpoint_dir, options=options, ) @contextmanager def _checkpoint_operation(self): """Context manager for checkpoint operations. Ensures async operations complete before exiting context. """ try: yield finally: if self.enable_async_checkpointing: self.checkpoint_manager.wait_until_finished()
[docs] def save(self, step: int) -> None: """Save current solver state to checkpoint. Args: step: Current iteration/step number to associate with the checkpoint. Returns: None """ if not self.is_checkpointing_enabled: return # Get state to checkpoint cp_state = self.solver_state # Save checkpoint self.checkpoint_manager.save(step, args=checkpoint.args.StandardSave(cp_state)) status = "queued" if self.enable_async_checkpointing else "saved" logger.debug(f"Checkpoint {status} for iteration {step}")
[docs] @classmethod def restore( cls, checkpoint_dir: str | Path, step: int | None = None, new_checkpoint_dir: str | Path | None = None, checkpoint_frequency: int | None = None, max_checkpoints: int | None = None, enable_async_checkpointing: bool | None = None, ) -> Solver: """Load solver from checkpoint. This class method reconstructs a solver instance from a checkpoint, using the stored config to recreate both the problem and solver with the correct parameters. Args: checkpoint_dir: Directory containing checkpoints. step: Specific step to load. If None, loads the latest checkpoint. new_checkpoint_dir: Optional new directory for future checkpoints. Useful when restoring to a different location. checkpoint_frequency: Optional new checkpoint frequency. max_checkpoints: Optional new maximum number of checkpoints. enable_async_checkpointing: Optional new async checkpointing setting. Returns: Reconstructed solver instance with restored state. """ # Initialize checkpoint manager checkpoint_dir = Path(checkpoint_dir).absolute() config_path = checkpoint_dir / "config.yaml" if not config_path.exists(): raise FileNotFoundError( f"No config file found in {checkpoint_dir}. " "If you're working with a problem without configs (e.g., in a notebook), " "you need to:\n" "1. Create your problem instance\n" "2. Create your solver instance with that problem\n" "3. Use solver.load_checkpoint() instead of the class restore() method" ) # Create solver instance with nested problem config = OmegaConf.load(config_path) # Update config with new values (only allow new values if provided) if new_checkpoint_dir is not None: new_checkpoint_dir = Path(new_checkpoint_dir).absolute() config.checkpoint_dir = new_checkpoint_dir if checkpoint_frequency is not None: config.checkpoint_frequency = checkpoint_frequency if max_checkpoints is not None: config.max_checkpoints = max_checkpoints if enable_async_checkpointing is not None: config.enable_async_checkpointing = enable_async_checkpointing solver = instantiate(config) template_cp_state = solver.solver_state manager = cls._create_checkpoint_manager(checkpoint_dir, 1, True) # Get step to restore step = step or manager.latest_step() if step is None: raise ValueError(f"No checkpoints found in {checkpoint_dir}") # Restore state cp_state = manager.restore( step, args=checkpoint.args.StandardRestore(template_cp_state), ) # Restore runtime state solver._restore_state_from_checkpoint(cp_state) return solver
[docs] @abstractmethod def _restore_state_from_checkpoint(self, state: dict[str, Any]) -> None: """Restore solver state from checkpoint. Args: state: Dictionary containing solver state from checkpoint. """ pass