Source code for mdpax.utils.spaces

"""Utilities for working with state, action and random event spaces."""

import itertools
from typing import Callable

import jax.numpy as jnp
import numpy as np
from jaxtyping import Array


def create_range_space(mins: Array, maxs: Array) -> tuple[Array, Callable]:
    """Create a range-based discrete space and its indexing function.

    Helper function for creating any discrete space (states, actions, or
    random events) that can be represented as ranges in each dimension.
    Creates both the space (all possible vectors) and a function to map
    vectors to indices.

    The ranges in each dimension are inclusive of both mins and maxs. For
    example, if mins=[0] and maxs=[2], the space will include vectors
    [0], [1], and [2].

    Args:
        mins: Lower bounds for each dimension [dim], inclusive
        maxs: Upper bounds for each dimension [dim], inclusive

    Returns:
        space: Array of all possible vectors [n_elements, dim]
        index_fn: Function that maps a vector to the index of its row in
            ``space``

    Raises:
        ValueError: If ``mins`` and ``maxs`` are not 1-D arrays of the same
            length, or if any upper bound is below its lower bound.

    Note:
        The 'clip' mode in index_fn means any vector will map to a valid
        index. This is necessary for compatibility with JAX's jit
        compilation but may lead to unexpected results if the vector is
        not within the bounds: each out-of-range component is clipped to
        the nearest bound of its dimension.
        https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ravel_multi_index.html

    Example:
        >>> # For a state space with 2 dimensions: [0,1] x [0,2] (inclusive)
        >>> state_space, state_to_index = create_range_space(
        ...     jnp.array([0, 0]), jnp.array([1, 2])
        ... )
        >>> print(state_space)  # All 6 combinations including bounds
        [[0 0]
         [0 1]
         [0 2]
         [1 0]
         [1 1]
         [1 2]]
        >>> print(state_to_index(jnp.array([1, 1])))
        4

        >>> # For an action space with 1 dimension: [0,5] (inclusive)
        >>> action_space, action_to_index = create_range_space(
        ...     jnp.array([0]), jnp.array([5])
        ... )
        >>> print(action_space)  # All 6 values from 0 to 5
        [[0]
         [1]
         [2]
         [3]
         [4]
         [5]]
    """
    mins = np.asarray(mins, dtype=np.int32)
    maxs = np.asarray(maxs, dtype=np.int32)

    # Fail loudly on malformed bounds; otherwise zip() below would silently
    # truncate, or the space would come out empty with negative dimensions.
    if mins.ndim != 1 or mins.shape != maxs.shape:
        raise ValueError(
            "mins and maxs must be 1-D arrays of the same length, got shapes "
            f"{mins.shape} and {maxs.shape}"
        )
    if np.any(maxs < mins):
        raise ValueError(
            f"maxs must be >= mins in every dimension, got mins={mins.tolist()} "
            f"and maxs={maxs.tolist()}"
        )

    # Plain Python ints keep these values static under jit and avoid
    # promoting the dtype of the vector components in index_fn.
    offsets = tuple(int(low) for low in mins)
    dimensions = tuple(
        int(size) for size in maxs - mins + 1  # +1 because bounds are inclusive
    )

    ranges = [
        np.arange(min_val, max_val + 1)  # +1 to include max_val
        for min_val, max_val in zip(mins, maxs)
    ]
    space = jnp.array(list(itertools.product(*ranges)), dtype=jnp.int32)

    def index_fn(vector: Array) -> Array:
        """Map vector to unique index using row-major (C-style) ordering."""
        components = tuple(vector)
        if len(components) != len(dimensions):
            raise ValueError(
                f"vector has {len(components)} components, expected "
                f"{len(dimensions)}"
            )
        # Shift each component so that the lower bound maps to 0; without
        # this the index is only correct when every lower bound is zero.
        shifted = tuple(
            component - offset for component, offset in zip(components, offsets)
        )
        return jnp.ravel_multi_index(shifted, dimensions, mode="clip")

    return space, index_fn