"""Perishable inventory MDP problem from Hendrix et al. (2019)."""
import chex
import jax
import jax.numpy as jnp
import numpy as np
import scipy.stats
from hydra.conf import dataclass
from jaxtyping import Array, Float
from mdpax.core.problem import Problem, ProblemConfig
from mdpax.utils.spaces import create_range_space
from mdpax.utils.types import (
ActionSpace,
ActionVector,
RandomEventSpace,
RandomEventVector,
Reward,
StateSpace,
StateVector,
)
@dataclass
class HendrixTwoProductPerishableConfig(ProblemConfig):
    """Parameter set for the HendrixTwoProductPerishable problem.

    Args:
        max_useful_life: Number of periods before stock expires. Must be >= 1.
        demand_poisson_mean_a: Mean of Poisson distribution for product A
            demand. Must be positive.
        demand_poisson_mean_b: Mean of Poisson distribution for product B
            demand. Must be positive.
        substitution_probability: Probability of substituting A for B when B
            is out. Must be in [0,1].
        variable_order_cost_a: Cost per unit of product A ordered
        variable_order_cost_b: Cost per unit of product B ordered
        sales_price_a: Revenue per unit of product A sold
        sales_price_b: Revenue per unit of product B sold
        max_order_quantity_a: Maximum units of product A that can be ordered.
            Must be positive.
        max_order_quantity_b: Maximum units of product B that can be ordered.
            Must be positive.

    Example:
        >>> config = HendrixTwoProductPerishableConfig(
        ...     max_useful_life=3,
        ...     demand_poisson_mean_a=4.0,
        ...     max_order_quantity_a=15,
        ... )
        >>> problem = HendrixTwoProductPerishable(config=config)

        # Or using kwargs:
        >>> problem = HendrixTwoProductPerishable(max_useful_life=3)
    """

    # Import path of the problem class this configuration instantiates.
    _target_: str = (
        "mdpax.problems.perishable_inventory.hendrix_two_product.HendrixTwoProductPerishable"
    )
    max_useful_life: int = 2
    demand_poisson_mean_a: float = 5.0
    demand_poisson_mean_b: float = 5.0
    substitution_probability: float = 0.5
    variable_order_cost_a: float = 0.5
    variable_order_cost_b: float = 0.5
    sales_price_a: float = 1.0
    sales_price_b: float = 1.0
    max_order_quantity_a: int = 10
    max_order_quantity_b: int = 10

    def __post_init__(self) -> None:
        """Check each constrained parameter in turn.

        Raises:
            ValueError: For the first parameter found outside its allowed range.
        """
        if self.max_useful_life < 1:
            raise ValueError("max_useful_life must be greater than or equal to 1")

        # Both Poisson means share the same strict-positivity requirement.
        for mean_field in ("demand_poisson_mean_a", "demand_poisson_mean_b"):
            if getattr(self, mean_field) <= 0:
                raise ValueError(f"{mean_field} must be positive")

        within_unit_interval = 0 <= self.substitution_probability <= 1
        if not within_unit_interval:
            raise ValueError("substitution_probability must be between 0 and 1")

        # Order-quantity caps must allow at least one unit to be ordered.
        for quantity_field in ("max_order_quantity_a", "max_order_quantity_b"):
            if getattr(self, quantity_field) <= 0:
                raise ValueError(f"{quantity_field} must be positive")
class HendrixTwoProductPerishable(Problem):
    """Two-product perishable inventory MDP problem from Hendrix et al. (2019).

    Models a two-product, single-echelon, periodic review perishable
    inventory replenishment problem where all stock has the same remaining
    useful life at arrival and there is the possibility for substitution
    between products.

    State Space (state_dim = 2 * max_useful_life):
        Vector containing:
            - Product A stock by age: [max_useful_life] elements in range
              [0, max_order_quantity_a], ordered with oldest units on the right
            - Product B stock by age: [max_useful_life] elements in range
              [0, max_order_quantity_b], ordered with oldest units on the right

    Action Space (action_dim = 2):
        Vector containing:
            - Product A order quantity: 1 element in range [0, max_order_quantity_a]
            - Product B order quantity: 1 element in range [0, max_order_quantity_b]

    Random Events (event_dim = 2):
        Vector containing:
            - Product A units issued: 1 element in range [0, max_stock_a]
            - Product B units issued: 1 element in range [0, max_stock_b]

        where max_stock_x = max_order_quantity_x * max_useful_life.

    Dynamics:
        1. Place replenishment order
        2. Random event determines units issued of each product, incorporating both:
            - Poisson-distributed demand for each product
            - Possible substitution from A to B when B's demand exceeds stock
        3. Issue stock using FIFO policy for each product
        4. Age remaining stock one period and discard expired units
        5. Reward is revenue from units issued less variable ordering costs
        6. Receive order placed at the start of the period immediately before
           the next period

    Args:
        config: Configuration object. If provided, keyword arguments are ignored.
        **kwargs: Parameters matching :class:`HendrixTwoProductPerishableConfig`.
            See Config class for detailed parameter descriptions.

    References:
        - Hendrix et al. (2019): https://doi.org/10.1002/cmm4.1027

    Note:
        The three random elements in the transition are the demand for each
        product and the number of units of demand for product B willing to
        accept substitution with product A. For consistency with the original
        implementation, the random events are taken to be the number of units
        of each product issued. The transition is deterministic given the
        number of products of each type issued.
    """

    Config = HendrixTwoProductPerishableConfig

    def __init__(
        self, config: HendrixTwoProductPerishableConfig | None = None, **kwargs
    ):
        # An explicit config wins; kwargs are only used to build one otherwise.
        if config is not None:
            self.config = config
        else:
            self.config = self.Config(**kwargs)

        self.max_useful_life = self.config.max_useful_life
        self.demand_poisson_mean_a = self.config.demand_poisson_mean_a
        self.demand_poisson_mean_b = self.config.demand_poisson_mean_b
        self.substitution_probability = self.config.substitution_probability
        self.variable_order_cost_a = self.config.variable_order_cost_a
        self.variable_order_cost_b = self.config.variable_order_cost_b
        self.sales_price_a = self.config.sales_price_a
        self.sales_price_b = self.config.sales_price_b
        # Per-product vectors in (A, B) order, matching the action and
        # random event layouts so rewards reduce to dot products.
        self.variable_order_costs = jnp.array(
            [self.variable_order_cost_a, self.variable_order_cost_b]
        )
        self.sales_prices = jnp.array([self.sales_price_a, self.sales_price_b])
        self.max_order_quantity_a = self.config.max_order_quantity_a
        self.max_order_quantity_b = self.config.max_order_quantity_b

        # Parameters are assigned before the base initialiser runs.
        super().__init__()

    @property
    def name(self) -> str:
        """A unique identifier for this problem type"""
        return "hendrix_two_product"

    def _setup_before_space_construction(self) -> None:
        """Setup before space construction."""
        # Compute dynamic limits on stock and demand
        self.max_stock_a = self.max_order_quantity_a * self.max_useful_life
        self.max_stock_b = self.max_order_quantity_b * self.max_useful_life
        self.max_demand = self.max_useful_life * (
            max(self.max_order_quantity_a, self.max_order_quantity_b) + 2
        )

        # Build lookup tables for state, action, and random event components
        # so they can be used to index into state, action, and random event
        # vectors by name in transition function
        self.state_component_lookup = self._construct_state_component_lookup()
        self.action_component_lookup = self._construct_action_component_lookup()
        self.random_event_component_lookup = (
            self._construct_random_event_component_lookup()
        )

    def _setup_after_space_construction(self) -> None:
        """Setup after space construction."""
        # Precompute conditional probabilities. Order matters: the pz table
        # is built from self.pu.
        self.pu = self._calculate_pu()
        self.pz = self._calculate_pz()

    def _construct_state_space(self) -> StateSpace:
        """Build array of all possible states.

        Returns:
            Array of shape [n_states, state_dim] containing all possible states
        """
        mins = np.zeros(2 * self.max_useful_life, dtype=np.int32)
        # Each age bucket holds at most one order's worth of stock.
        maxs = np.hstack(
            [
                np.full(
                    self.max_useful_life,
                    self.max_order_quantity_a,
                ),
                np.full(
                    self.max_useful_life,
                    self.max_order_quantity_b,
                ),
            ]
        )
        state_space, self._state_to_index_fn = create_range_space(mins, maxs)
        return state_space

    def _construct_action_space(self) -> ActionSpace:
        """Build array of all possible actions.

        Returns:
            Array of shape [n_actions, action_dim] containing all possible actions
        """
        mins = np.array([0, 0])
        maxs = np.array([self.max_order_quantity_a, self.max_order_quantity_b])
        action_space, _ = create_range_space(mins, maxs)
        return action_space

    def _construct_random_event_space(self) -> RandomEventSpace:
        """Build array of all possible random events.

        Returns:
            Array of shape [n_events, event_dim] containing all possible random events
        """
        mins = np.array([0, 0])
        maxs = np.array([self.max_stock_a, self.max_stock_b])
        random_event_space, self._random_event_to_index = create_range_space(mins, maxs)
        return random_event_space

    def state_to_index(self, state: StateVector) -> int:
        """Convert state vector to index.

        Args:
            state: State vector to convert [state_dim]

        Returns:
            Integer index of the state in state_space
        """
        return self._state_to_index_fn(state)

    def random_event_probability(
        self,
        state: StateVector,
        action: ActionVector,
        random_event: RandomEventVector,
    ) -> float:
        """Compute probability of random event given state and action.

        The number of units issued of each product follows a compound distribution:
            - Demand for each product is Poisson distributed
            - For product B, if demand exceeds stock, excess demand can be
              satisfied by product A with binomial probability

        The probability depends only on the total stock of each product in
        ``state``; ``action`` is accepted for interface compatibility but is
        not read.

        Args:
            state: Current state vector [state_dim]
            action: Action vector [action_dim]
            random_event: Random event vector [event_dim]

        Returns:
            Probability of this combination of issued units for both products
        """
        stock_a = jnp.sum(state[self.state_component_lookup["stock_a"]])
        stock_b = jnp.sum(state[self.state_component_lookup["stock_b"]])

        # The four cases below populate disjoint cells of an
        # [max_stock_a + 1, max_stock_b + 1] table indexed by (issued_a, issued_b).
        # Issued a less than stock of a, issued b less than stock of b
        probs_1 = self._get_probs_ia_lt_stock_a_ib_lt_stock_b(stock_a, stock_b)
        # Issued a equal to stock of a, issued b less than stock of b
        probs_2 = self._get_probs_ia_eq_stock_a_ib_lt_stock_b(stock_a, stock_b)
        # Issued a less than stock of a, issued b equal to stock of b
        probs_3 = self._get_probs_ia_lt_stock_a_ib_eq_stock_b(stock_a, stock_b)
        # Issued a equal to stock of a, issued b equal to stock of b
        probs_4 = self._get_probs_ia_eq_stock_a_ib_eq_stock_b(stock_a, stock_b)

        all_probs = (probs_1 + probs_2 + probs_3 + probs_4).reshape(-1)
        return all_probs[self._random_event_to_index(random_event)]

    def transition(
        self,
        state: StateVector,
        action: ActionVector,
        random_event: RandomEventVector,
    ) -> tuple[StateVector, Reward]:
        """Compute next state and reward for a transition.

        Processes one step of the two-product perishable inventory system:
            1. Place replenishment order
            2. Random event determines units issued of each product, incorporating both:
                - Poisson-distributed demand for each product
                - Possible substitution from A to B when B's demand exceeds stock
            3. Issue stock using FIFO policy for each product
            4. Age remaining stock one period and discard expired units
            5. Reward is revenue from units issued less variable ordering costs
            6. Receive order placed at the start of the period immediately
               before the next period

        Args:
            state: Current state vector [state_dim]
            action: Action vector [action_dim]
            random_event: Random event vector [event_dim]

        Returns:
            Tuple containing the next state vector [state_dim] and
            the immediate reward
        """
        issued_a = random_event[self.random_event_component_lookup["issued_a"]]
        issued_b = random_event[self.random_event_component_lookup["issued_b"]]

        opening_stock_a = state[self.state_component_lookup["stock_a"]]
        opening_stock_b = state[self.state_component_lookup["stock_b"]]

        stock_after_issue_a = self._issue_fifo(opening_stock_a, issued_a)
        stock_after_issue_b = self._issue_fifo(opening_stock_b, issued_b)

        # Pass through the random outcome (units issued)
        single_step_reward = self._calculate_single_step_reward(
            state, action, random_event
        )

        # Age stock one day and receive the order from the morning: the new
        # order enters on the left, and the right-most (oldest) bucket is dropped.
        closing_stock_a = jnp.hstack(
            [
                action[self.action_component_lookup["order_quantity_a"]],
                stock_after_issue_a[0 : self.max_useful_life - 1],
            ]
        )
        closing_stock_b = jnp.hstack(
            [
                action[self.action_component_lookup["order_quantity_b"]],
                stock_after_issue_b[0 : self.max_useful_life - 1],
            ]
        )

        next_state = jnp.concatenate([closing_stock_a, closing_stock_b], axis=-1)
        return (
            next_state,
            single_step_reward,
        )

    def initial_value(self, state: StateVector) -> float:
        """Return initial value estimate for a given state.

        Initial value estimate based on one-step ahead expected sales revenue.

        Args:
            state: State vector [state_dim]

        Returns:
            Initial value estimate for the given state
        """
        return self._calculate_expected_sales_revenue(state)

    # Transition function helper methods
    # ----------------------------------

    def _construct_state_component_lookup(self) -> dict[str, int | slice]:
        """Build mapping from state components to indices."""
        m = self.max_useful_life
        return {
            "stock_a": slice(0, m),
            "stock_b": slice(m, 2 * m),
        }

    def _construct_action_component_lookup(self) -> dict[str, int]:
        """Build mapping from action components to indices."""
        return {
            "order_quantity_a": 0,
            "order_quantity_b": 1,
        }

    def _construct_random_event_component_lookup(self) -> dict[str, int]:
        """Build mapping from random event components to indices."""
        return {
            "issued_a": 0,
            "issued_b": 1,
        }

    def _issue_fifo(
        self, opening_stock: Float[Array, "max_useful_life"], demand: int
    ) -> Float[Array, "max_useful_life"]:
        """Issue stock using FIFO (First-In-First-Out) policy.

        Issues stock starting with oldest items first (right side of vector).
        Uses scan to process each age category in sequence.

        Args:
            opening_stock: Current stock levels by age [max_useful_life]
            demand: Total customer demand to satisfy

        Returns:
            Array of updated stock levels after issuing [max_useful_life]
        """
        # reverse=True walks the vector right-to-left, i.e. oldest stock first,
        # carrying the still-unmet demand between age buckets.
        _, remaining_stock = jax.lax.scan(
            self._issue_one_step, demand, opening_stock, reverse=True
        )
        return remaining_stock

    def _issue_one_step(
        self, remaining_demand: int, stock_element: int
    ) -> tuple[int, int]:
        """Process one age category during stock issuing.

        Both arguments must support ``.clip`` (array scalars, as supplied by
        ``jax.lax.scan`` in :meth:`_issue_fifo`).

        Args:
            remaining_demand: Unfulfilled demand to satisfy
            stock_element: Available stock of current age

        Returns:
            Tuple containing the remaining demand and remaining stock
            after processing this age category
        """
        # Stock is computed first, from the demand carried in; only then is
        # the carried demand reduced by what this bucket supplied.
        remaining_stock = (stock_element - remaining_demand).clip(0)
        remaining_demand = (remaining_demand - stock_element).clip(0)
        return remaining_demand, remaining_stock

    def _calculate_single_step_reward(
        self,
        state: StateVector,
        action: ActionVector,
        random_event: RandomEventVector,
    ) -> Reward:
        """Calculate reward (revenue minus costs) for one transition step.

        Computes total reward by combining:
            - Variable ordering costs (negative)
            - Sales revenue from issued stock (positive)

        Args:
            state: Current state vector [state_dim] (not read)
            action: Action vector [action_dim]
            random_event: Random event vector [event_dim] containing units issued

        Returns:
            Revenue minus costs for this step
        """
        cost = jnp.dot(action, self.variable_order_costs)
        revenue = jnp.dot(random_event, self.sales_prices)
        return revenue - cost

    # Random event probability helper methods
    # ---------------------------------------

    def _calculate_pu(self) -> Float[Array, "max_demand_plus_one max_stock_b_plus_one"]:
        """Calculate conditional probabilities for substitution demand.

        Returns:
            Array of probabilities where pu[u,y] is Prob(u|y),
            the conditional probability of u substitution demand given y units
            of product B in stock. Shape is [max_demand + 1, max_stock_b + 1].
        """
        pu = np.zeros((self.max_demand + 1, self.max_stock_b + 1))
        for y in range(self.max_stock_b + 1):
            for u in range(self.max_demand - y):
                # x is the excess demand for B beyond the y units in stock
                # (total demand x + y, kept below max_demand); u of those x
                # units accept substitution, so x starts at u.
                x = np.arange(u, self.max_demand - y)
                pu[u, y] = scipy.stats.poisson.pmf(
                    x + y, self.demand_poisson_mean_b
                ).dot(scipy.stats.binom.pmf(u, x, self.substitution_probability))
        return jnp.array(pu)

    def _calculate_pz(self) -> Float[Array, "max_demand_plus_one max_stock_b_plus_one"]:
        """Calculate conditional probabilities for total demand for product A.

        Reads ``self.pu``, so it must run after :meth:`_calculate_pu` has been
        stored on the instance.

        Returns:
            Array of probabilities where pz[z,y] is Prob(z|y),
            the conditional probability of z total demand for product A given
            demand for product B is at least equal to y units in stock.
            Shape is [max_demand + 1, max_stock_b + 1].
        """
        pz = np.zeros((self.max_demand + 1, self.max_stock_b + 1))
        pa = scipy.stats.poisson.pmf(
            np.arange(self.max_demand + 1), self.demand_poisson_mean_a
        )
        # No demand for a itself, and no subst demand
        pz[0, :] = pa[0] * self.pu[0, :]

        # Take one host-side copy of the table so the nested loop below does
        # plain NumPy indexing instead of one JAX gather per (z, y) pair.
        pu = np.asarray(self.pu)
        for z in range(1, self.max_demand + 1):
            # Split z into k units of own demand for A and z - k units of
            # substitution demand, for k = 0..z.
            own_demand = np.arange(0, z + 1)
            pa_terms = pa[own_demand]
            for y in range(self.max_stock_b + 1):
                pz[z, y] = pa_terms.dot(pu[z - own_demand, y])
        return jnp.array(pz)

    def _get_probs_ia_lt_stock_a_ib_lt_stock_b(
        self, stock_a: chex.Array, stock_b: chex.Array
    ) -> chex.Array:
        """Calculate probabilities for case where issued quantities are below stock levels."""
        # P(i_a, i_b) = P(d_a=ia) * P(d_b=ib)
        # Easy cases, all demand met and no substitution
        prob_da = jax.scipy.stats.poisson.pmf(
            jnp.arange(self.max_stock_a + 1), self.demand_poisson_mean_a
        )
        # Masks zero out entries at or above the current stock level.
        prob_da_masked = prob_da * (jnp.arange(self.max_stock_a + 1) < stock_a)
        prob_db = jax.scipy.stats.poisson.pmf(
            jnp.arange(self.max_stock_b + 1), self.demand_poisson_mean_b
        )
        prob_db_masked = prob_db * (jnp.arange(self.max_stock_b + 1) < stock_b)
        issued_probs = jnp.outer(prob_da_masked, prob_db_masked)
        return issued_probs

    def _get_probs_ia_eq_stock_a_ib_lt_stock_b(
        self, stock_a: chex.Array, stock_b: chex.Array
    ) -> chex.Array:
        """Calculate probabilities for case where product A issued equals stock level."""
        # Therefore P(i_a, i_b) = P(d_a>=ia) * P(d_b=ib)
        # No substitution
        issued_probs = jnp.zeros((self.max_stock_a + 1, self.max_stock_b + 1))
        # Demand for a at least stock_a, but demand for b less than stock_b
        prob_da_gteq_stock_a = 1 - jax.scipy.stats.poisson.cdf(
            stock_a - 1, self.demand_poisson_mean_a
        )
        prob_db = jax.scipy.stats.poisson.pmf(
            jnp.arange(self.max_stock_b + 1), self.demand_poisson_mean_b
        )
        prob_db_masked = prob_db * (jnp.arange(self.max_stock_b + 1) < stock_b)
        probs = prob_da_gteq_stock_a * prob_db_masked
        # Only the row issued_a == stock_a is populated in this case.
        issued_probs = issued_probs.at[stock_a, :].add(probs)
        return issued_probs

    def _get_probs_ia_lt_stock_a_ib_eq_stock_b(
        self, stock_a: chex.Array, stock_b: chex.Array
    ) -> chex.Array:
        """Calculate probabilities for case where product B issued equals stock level."""
        # Therefore total demand for a is < stock_a, demand for b >= stock_b
        issued_probs = jnp.zeros((self.max_stock_a + 1, self.max_stock_b + 1))
        # Demand for b at least stock_b, so substitution possible.
        # Column stock_b of pz is taken with dynamic_slice rather than plain
        # indexing.
        probs_issued_a = jax.lax.dynamic_slice(
            self.pz, (0, stock_b), (self.max_demand + 1, 1)
        ).reshape(-1)
        probs_issued_a_masked = probs_issued_a * (
            jnp.arange(len(probs_issued_a)) < stock_a
        )
        # Trim array to max_stock_a
        probs_issued_a_masked = jax.lax.dynamic_slice(
            probs_issued_a_masked, (0,), (self.max_stock_a + 1,)
        )
        # Only the column issued_b == stock_b is populated in this case.
        issued_probs = issued_probs.at[:, stock_b].add(probs_issued_a_masked)
        return issued_probs

    def _get_probs_ia_eq_stock_a_ib_eq_stock_b(
        self, stock_a: chex.Array, stock_b: chex.Array
    ) -> chex.Array:
        """Calculate probabilities for case where both products issued equal stock levels."""
        # Therefore total demand for a is >= stock_a, demand for b >= stock_b
        issued_probs = jnp.zeros((self.max_stock_a + 1, self.max_stock_b + 1))
        # Demand for b at least stock_b, so substitution possible
        probs_issued_a = jax.lax.dynamic_slice(
            self.pz, (0, stock_b), (self.max_demand + 1, 1)
        ).reshape(-1)
        # Dot with a boolean mask sums the tail of the distribution, z >= stock_a.
        prob_combined_demand_gteq_stock_a = probs_issued_a.dot(
            jnp.arange(len(probs_issued_a)) >= stock_a
        )
        # A single cell, (stock_a, stock_b), is populated in this case.
        issued_probs = issued_probs.at[stock_a, stock_b].add(
            prob_combined_demand_gteq_stock_a
        )
        return issued_probs

    # Initial value helper methods
    # ---------------------------------------

    def _calculate_sales_revenue_for_possible_random_events(
        self,
    ) -> Float[Array, "n_events"]:
        """Calculate the sales revenue for each possible random event.

        Returns:
            Array of sales revenue for each possible random event [n_events]
        """
        return (self.random_event_space.dot(self.sales_prices)).reshape(-1)

    def _calculate_expected_sales_revenue(self, state: StateVector) -> float:
        """Calculate the expected sales revenue for a given state.

        Args:
            state: State vector to calculate expected revenue for [state_dim]

        Returns:
            Expected sales revenue for one step from this state
        """
        # The action argument is a placeholder (0): random_event_probability
        # does not read it. Only the random event axis is mapped over.
        issued_probabilities = jax.vmap(
            self.random_event_probability, in_axes=(None, None, 0)
        )(state, 0, self.random_event_space)
        expected_sales_revenue = issued_probabilities.dot(
            self._calculate_sales_revenue_for_possible_random_events()
        )
        return expected_sales_revenue