"""Forest management MDP problem."""
import jax.numpy as jnp
from hydra.conf import dataclass
from mdpax.core.problem import Problem, ProblemConfig
from mdpax.utils.types import (
ActionSpace,
ActionVector,
RandomEventSpace,
RandomEventVector,
Reward,
StateSpace,
StateVector,
)
@dataclass
class ForestConfig(ProblemConfig):
    """Parameter bundle for the Forest problem.

    Args:
        S: Number of states (tree ages from 0 to S-1). Controls the maximum
            age the forest can reach. Must be positive.
        r1: Reward for waiting when forest in oldest state.
        r2: Reward for cutting when forest in oldest state.
        p: Base probability of fire. Must be in [0,1].

    Raises:
        ValueError: From ``__post_init__`` when ``S`` is not positive or
            ``p`` lies outside [0, 1].

    Example:
        >>> config = ForestConfig(S=4, r1=5.0, r2=2.0, p=0.1)
        >>> problem = Forest(config=config)

        # Or using kwargs:
        >>> problem = Forest(S=4, r1=5.0)  # Other params use defaults
    """

    # Dotted path of the problem class this config belongs to
    # (presumably consumed by Hydra-style instantiation - confirm).
    _target_: str = "mdpax.problems.forest.Forest"
    S: int = 3
    r1: float = 4.0
    r2: float = 2.0
    p: float = 0.1

    def __post_init__(self) -> None:
        """Check the freshly populated fields and reject invalid values.

        Raises:
            ValueError: If ``S`` is zero or negative, or if ``p`` is not
                within the closed interval [0, 1].
        """
        no_states = self.S <= 0
        if no_states:
            raise ValueError("Number of states (S) must be positive")

        p_is_probability = 0 <= self.p <= 1
        if not p_is_probability:
            raise ValueError("Probability (p) must be between 0 and 1")
class Forest(Problem):
    """Forest management MDP problem.

    The forest management problem involves deciding whether to cut down trees
    for immediate reward or wait for them to grow larger. There is a risk of
    fire destroying the forest during each time step in which the manager
    waits.

    Adapted from the example problem in pymdptoolbox.

    State Space (state_dim = 1):
        Vector containing:
            - Tree age: 1 element in range [0, S-1] (newly planted to mature forest)

    Action Space (action_dim = 1):
        Vector containing:
            - Decision: 1 element in range {0=wait, 1=cut}

    Random Events (event_dim = 1):
        Vector containing:
            - Fire occurrence: 1 element in range {0=no_fire, 1=fire}

    Dynamics:
        1. Choose to cut or wait
        2. If wait:
            - Fire occurs with probability p
            - If fire, reset to age 0
            - If no fire, age increases by 1 (up to S-1)
            - Reward is r1 if the forest is in the oldest state when the
              decision is taken (independent of whether a fire follows),
              and 0 otherwise
        3. If cut:
            - Fire occurs with probability 0
            - Reward is r2 in the oldest state, 0 at age 0, and 1 otherwise
            - Reset to age 0

    Args:
        config: Configuration object. If provided, keyword arguments are ignored.
        **kwargs: Parameters matching :class:`ForestConfig`. See ForestConfig
            for detailed parameter descriptions.

    References:
        - pymdptoolbox: https://github.com/sawcordwell/pymdptoolbox/blob/master/src/mdptoolbox/example.py
    """

    Config = ForestConfig

    def __init__(self, config: ForestConfig | None = None, **kwargs):
        """Initialize the Forest problem.

        Args:
            config: Ready-made configuration. When given, ``kwargs`` are
                not used.
            **kwargs: Forwarded to :class:`ForestConfig` when ``config`` is
                None.
        """
        if config is not None:
            self.config = config
        else:
            self.config = self.Config(**kwargs)

        self.S = self.config.S
        self.r1 = self.config.r1
        self.r2 = self.config.r2
        self.p = self.config.p

        # Event probabilities indexed as [action, event]:
        #   row 0 (wait): [no fire, fire] = [1 - p, p]
        #   row 1 (cut):  [no fire, fire] = [1, 0]
        self._probability_matrix = jnp.array([[1 - self.p, self.p], [1, 0]])

        # Parameters are assigned before the base initializer runs; the
        # _construct_* hooks below read them (presumably the base class calls
        # those hooks during __init__ - confirm against Problem).
        super().__init__()

    @property
    def name(self) -> str:
        """A unique identifier for this problem type"""
        return "forest"

    def _construct_state_space(self) -> StateSpace:
        """Build array of all possible states.

        Returns:
            Array of shape [n_states, state_dim] containing all possible
            states: the tree ages 0..S-1, one per row.
        """
        return jnp.arange(self.S, dtype=jnp.int32).reshape(-1, 1)

    def state_to_index(self, state: StateVector) -> int:
        """Convert state vector to index.

        The state space lists ages in increasing order starting at 0, so the
        age stored in the state is already its row index.

        Args:
            state: Vector representation of a state [state_dim]

        Returns:
            Index of the state in state_space
        """
        return state[0]

    def _construct_action_space(self) -> ActionSpace:
        """Build array of all possible actions.

        Returns:
            Array of shape [n_actions, action_dim] containing all possible
            actions: row 0 is wait (0), row 1 is cut (1).
        """
        return jnp.array([[0], [1]], dtype=jnp.int32)

    def _construct_random_event_space(self) -> RandomEventSpace:
        """Build array of all possible random events.

        Returns:
            Array of shape [n_events, event_dim] containing all possible
            random events: row 0 is no fire (0), row 1 is fire (1).
        """
        return jnp.array([[0], [1]], dtype=jnp.int32)

    def random_event_probability(
        self, state: StateVector, action: ActionVector, random_event: RandomEventVector
    ) -> float:
        """Compute probability of random event given state-action pair.

        When waiting:
            - No fire probability is 1 - p
            - Fire probability is p

        When cutting:
            - No fire probability is 1
            - Fire probability is 0

        Args:
            state: Current state vector [state_dim]. Not used: the fire
                probability does not depend on tree age.
            action: Action vector [action_dim]
            random_event: Random event vector [event_dim]

        Returns:
            Probability of the random event occurring
        """
        return self._probability_matrix[action[0], random_event[0]]

    def transition(
        self, state: StateVector, action: ActionVector, random_event: RandomEventVector
    ) -> tuple[StateVector, Reward]:
        """Compute next state and reward for a transition.

        Reward (depends on the current age and the action only, never on
        the fire event):
            - Cut: r2 in the oldest state, 0 at age 0, 1 otherwise
            - Wait: r1 in the oldest state, 0 otherwise

        Next state:
            - Age 0 if the forest was cut or a fire occurred
            - Otherwise the age increases by 1, capped at S-1

        Args:
            state: Current state vector [state_dim]
            action: Action vector [action_dim]
            random_event: Random event vector [event_dim]

        Returns:
            Tuple containing the next state vector [state_dim] and
            the immediate reward
        """
        age = state[0]
        is_cut = action[0] == 1
        is_fire = random_event[0] == 1
        is_oldest = age == self.S - 1

        # The oldest-state check comes first, so when S == 1 cutting at
        # age 0 pays r2 rather than 0.
        cut_reward = jnp.where(is_oldest, self.r2, jnp.where(age == 0, 0.0, 1.0))
        # Waiting pays only in the oldest state, and does so whether or not
        # a fire follows.
        wait_reward = jnp.where(is_oldest, self.r1, 0.0)
        reward = jnp.where(is_cut, cut_reward, wait_reward)

        # Cutting and fire both send the forest back to age 0; otherwise it
        # grows by one year up to the maximum age.
        next_age = jnp.where(is_cut | is_fire, 0, jnp.minimum(age + 1, self.S - 1))
        next_state = jnp.array([next_age]).astype(jnp.int32)

        return next_state, reward