# Source code for ethicalgardeners.observation

"""
The Observation module defines how agents perceive the environment in the
Ethical Gardeners simulation.

This module implements different observation strategies to control what
information agents can access about the environment. It provides:

1. :py:class:`ObservationStrategy`: An abstract strategy interface for
implementing custom observation methods

2. Two concrete implementations:

   - :py:class:`TotalObservation`: Complete grid visibility
   - :py:class:`PartialObservation`: Limited visibility range

Observations are formatted as numpy arrays compatible with Gymnasium
environments.

Custom observation strategies can be implemented by extending the
ObservationStrategy class and implementing the required methods.
"""
from abc import ABC, abstractmethod
from gymnasium.spaces import Box
import numpy as np

from ethicalgardeners.agent import Agent
from ethicalgardeners.constants import FEATURES_PER_CELL
from ethicalgardeners.gridworld import CellType


class ObservationStrategy(ABC):
    """
    Common interface for every observation strategy.

    A strategy decides what an agent is allowed to perceive. It describes
    the observation space handed to each agent and builds concrete
    observations from the current state of the world. Subclasses must
    implement :py:meth:`observation_space` and :py:meth:`get_observation`.
    """

    def __init__(self):
        """
        Initialise the strategy. The base class holds no state of its own.
        """

    @abstractmethod
    def observation_space(self, agent: Agent):
        """
        Describe the space that observations for ``agent`` belong to.

        Args:
            agent (:py:class:`.Agent`): Agent whose observation space is
                requested.

        Returns:
            gymnasium.spaces.Space: The observation space for ``agent``.
        """

    @abstractmethod
    def get_observation(self, grid_world, agent: Agent):
        """
        Build what ``agent`` currently perceives of ``grid_world``.

        Args:
            grid_world (:py:class:`.GridWorld`): Current state of the grid.
            agent (:py:class:`.Agent`): Agent doing the observing.

        Returns:
            numpy.ndarray: The observation for ``agent``.
        """
class TotalObservation(ObservationStrategy):
    """
    Strategy that provides agents with a complete view of the entire grid.

    This strategy gives agents perfect information about the state of the
    environment, including all cells, agents, and flowers. Each cell in the
    grid is represented as a vector of features:

    * Cell type (normalized): the cell's :py:class:`.CellType` value divided
      by the number of cell types.
    * Pollution level (normalized): the cell's pollution rescaled between
      the grid's minimum and maximum pollution levels. 0 when the cell has
      no pollution value, or when the minimum and maximum are equal.
    * Flower presence and type (normalized): 0 if no flower is present,
      otherwise (flower type + 1) divided by the number of flower types.
    * Flower growth stage (normalized): 0 if no flower is present, otherwise
      (current growth stage + 1) divided by (number of growth stages + 1).
    * Agent presence (normalized): 0 if no agent is present (or the agent
      is not listed in the grid world's agents), otherwise (index of the
      agent in the grid world + 1) divided by the total number of agents.
    * Agent's X position (normalized): the observing agent's X position
      divided by (grid width - 1); 0 when the grid is one cell wide.
    * Agent's Y position (normalized): the observing agent's Y position
      divided by (grid height - 1); 0 when the grid is one cell high.

    The observation array is indexed as ``obs[x, y, feature]``.

    Attributes:
        observation_shape (tuple): The dimensions of the observation
            (width, height, FEATURES_PER_CELL).
    """

    def __init__(self, grid_world):
        """
        Create the total observation strategy.

        Args:
            grid_world (:py:class:`.GridWorld`): The grid world environment
                to observe. Only its width and height are read here.
        """
        super().__init__()
        self.observation_shape = (grid_world.width, grid_world.height,
                                  FEATURES_PER_CELL)

    def observation_space(self, agent: Agent):
        """
        Define the observation space as a Box covering the full grid.

        Args:
            agent (:py:class:`.Agent`): The agent for which to define the
                observation space (unused: every agent sees the same space).

        Returns:
            gymnasium.spaces.Box: A box space with bounds [0, 1] and
            dimensions (width, height, FEATURES_PER_CELL).
        """
        return Box(low=0, high=1, shape=self.observation_shape,
                   dtype=np.float32)

    def get_observation(self, grid_world, agent: Agent):
        """
        Generate a complete observation of the entire grid.

        Args:
            grid_world (:py:class:`.GridWorld`): The current state of the
                grid.
            agent (:py:class:`.Agent`): The agent for which to generate the
                observation.

        Returns:
            numpy.ndarray: A float32 array of shape ``observation_shape``
            holding the full grid state, indexed as ``obs[x, y, feature]``.
        """
        obs = np.zeros(self.observation_shape, dtype=np.float32)

        # Features 6 and 7 describe the observing agent, not the cell, so
        # they are identical for every cell. The guards avoid dividing by
        # zero on a grid that is a single cell wide or high.
        obs[:, :, 5] = (agent.position[0] / (grid_world.width - 1)
                        if grid_world.width > 1 else 0.0)
        obs[:, :, 6] = (agent.position[1] / (grid_world.height - 1)
                        if grid_world.height > 1 else 0.0)

        for x in range(self.observation_shape[0]):
            for y in range(self.observation_shape[1]):
                cell = grid_world.get_cell((x, y))

                # Feature 1: Cell type (normalized)
                obs[x, y, 0] = cell.cell_type.value / len(CellType)

                # Feature 2: Pollution level (normalized)
                if cell.pollution is not None:
                    min_pollution = grid_world.min_pollution
                    pollution_range = grid_world.max_pollution - min_pollution
                    # A degenerate range (min == max) would divide by zero;
                    # the feature is left at 0 in that case.
                    if pollution_range != 0:
                        obs[x, y, 1] = (
                            (cell.pollution - min_pollution) / pollution_range
                        )

                if cell.has_flower():
                    flower = cell.flower
                    # Feature 3: Flower presence and type (normalized).
                    # +1 because flower types start at 0, so a type-0
                    # flower is still distinguishable from "no flower".
                    obs[x, y, 2] = (
                        (flower.flower_type + 1) / len(grid_world.flowers_data)
                    )
                    # Feature 4: Flower growth stage (normalized).
                    # +1 for the same reason: stage 0 must not read as 0.
                    obs[x, y, 3] = (
                        (flower.current_growth_stage + 1)
                        / (flower.num_growth_stage + 1)
                    )

                # Feature 5: Agent presence (normalized)
                if cell.has_agent():
                    try:
                        agent_idx = grid_world.agents.index(cell.agent)
                    except ValueError:
                        # list.index raises (it never returns None) when the
                        # agent is not registered in grid_world.agents;
                        # such an agent is reported as absent (feature = 0).
                        pass
                    else:
                        # +1 because agent indices start at 0
                        obs[x, y, 4] = (
                            (agent_idx + 1) / len(grid_world.agents)
                        )

        return obs
class PartialObservation(ObservationStrategy):
    """
    Strategy that provides agents with a limited view around their position.

    This strategy simulates limited perception by only showing agents a
    square area centered on their current position. Each visible cell inside
    the grid is represented as a vector of features (cells outside the grid
    are all zeros):

    * Cell type (normalized): the cell's :py:class:`.CellType` value divided
      by the number of cell types.
    * Pollution level (normalized): the cell's pollution rescaled between
      the grid's minimum and maximum pollution levels. 0 when the cell has
      no pollution value, or when the minimum and maximum are equal.
    * Flower presence and type (normalized): 0 if no flower is present,
      otherwise (flower type + 1) divided by the number of flower types.
    * Flower growth stage (normalized): 0 if no flower is present, otherwise
      (current growth stage + 1) divided by (number of growth stages + 1).
    * Agent presence (normalized): 0 if no agent is present (or the agent
      is not listed in the grid world's agents), otherwise (index of the
      agent in the grid world + 1) divided by the total number of agents.
    * Agent's X position (normalized): the observing agent's X position
      divided by (grid width - 1); 0 when the grid is one cell wide.
    * Agent's Y position (normalized): the observing agent's Y position
      divided by (grid height - 1); 0 when the grid is one cell high.

    The window is indexed as ``obs[i, j, feature]``, where row ``i``
    follows the Y offset and column ``j`` the X offset from the agent.
    NOTE(review): TotalObservation indexes its array as ``obs[x, y]``
    instead; confirm this difference in axis order is intended.

    Attributes:
        obs_range (int): The visibility range in cells around the agent's
            position.
        observation_shape (tuple): The dimensions of the observation
            (2*obs_range+1, 2*obs_range+1, FEATURES_PER_CELL).
    """

    def __init__(self, obs_range=1):
        """
        Create the partial observation strategy.

        Args:
            obs_range (int, optional): The number of cells visible in each
                direction from the agent. Defaults to 1 (a 3x3 window).

        Raises:
            ValueError: If ``obs_range`` is negative, which would give the
                observation negative dimensions.
        """
        super().__init__()
        if obs_range < 0:
            raise ValueError(
                f"obs_range must be non-negative, got {obs_range}")
        self.obs_range = obs_range
        self.observation_shape = (2 * obs_range + 1, 2 * obs_range + 1,
                                  FEATURES_PER_CELL)

    def observation_space(self, agent: Agent):
        """
        Define the observation space as a Box sized by the visibility range.

        Args:
            agent (:py:class:`.Agent`): The agent for which to define the
                observation space (unused: every agent sees the same space).

        Returns:
            gymnasium.spaces.Box: A box space with bounds [0, 1] and
            dimensions (2*obs_range+1, 2*obs_range+1, FEATURES_PER_CELL).
        """
        return Box(low=0, high=1, shape=self.observation_shape,
                   dtype=np.float32)

    def get_observation(self, grid_world, agent: Agent):
        """
        Generate a partial observation centered on the agent's position.

        Each cell in the visible area is represented with multiple features.
        Areas outside the grid boundaries appear as zeros in the observation.

        Args:
            grid_world (:py:class:`.GridWorld`): The current state of the
                grid.
            agent (:py:class:`.Agent`): The agent for which to generate the
                observation.

        Returns:
            numpy.ndarray: A float32 array of shape ``observation_shape``
            holding the visible portion of the grid with all features.
        """
        obs = np.zeros(self.observation_shape, dtype=np.float32)
        agent_x, agent_y = agent.position

        # Features 6 and 7 describe the observing agent, so they are the
        # same for every visible cell. The guards avoid dividing by zero on
        # a grid that is a single cell wide or high.
        agent_x_norm = (agent_x / (grid_world.width - 1)
                        if grid_world.width > 1 else 0.0)
        agent_y_norm = (agent_y / (grid_world.height - 1)
                        if grid_world.height > 1 else 0.0)

        for i in range(self.observation_shape[0]):
            y = agent_y + i - self.obs_range
            if not 0 <= y < grid_world.height:
                continue  # whole row lies outside the grid: stays zeros
            for j in range(self.observation_shape[1]):
                x = agent_x + j - self.obs_range
                if not 0 <= x < grid_world.width:
                    continue  # outside the grid: stays zeros

                cell = grid_world.get_cell((x, y))

                # Feature 1: Cell type (normalized)
                obs[i, j, 0] = cell.cell_type.value / len(CellType)

                # Feature 2: Pollution level (normalized)
                if cell.pollution is not None:
                    min_pollution = grid_world.min_pollution
                    pollution_range = grid_world.max_pollution - min_pollution
                    # A degenerate range (min == max) would divide by zero;
                    # the feature is left at 0 in that case.
                    if pollution_range != 0:
                        obs[i, j, 1] = (
                            (cell.pollution - min_pollution) / pollution_range
                        )

                if cell.has_flower():
                    flower = cell.flower
                    # Feature 3: Flower presence and type (normalized).
                    # +1 because flower types start at 0, so a type-0
                    # flower is still distinguishable from "no flower".
                    obs[i, j, 2] = (
                        (flower.flower_type + 1) / len(grid_world.flowers_data)
                    )
                    # Feature 4: Flower growth stage (normalized).
                    # +1 for the same reason: stage 0 must not read as 0.
                    obs[i, j, 3] = (
                        (flower.current_growth_stage + 1)
                        / (flower.num_growth_stage + 1)
                    )

                # Feature 5: Agent presence (normalized)
                if cell.has_agent():
                    try:
                        agent_idx = grid_world.agents.index(cell.agent)
                    except ValueError:
                        # list.index raises (it never returns None) when the
                        # agent is not registered in grid_world.agents;
                        # such an agent is reported as absent (feature = 0).
                        pass
                    else:
                        # +1 because agent indices start at 0
                        obs[i, j, 4] = (
                            (agent_idx + 1) / len(grid_world.agents)
                        )

                # Features 6 and 7: observing agent's position (normalized)
                obs[i, j, 5] = agent_x_norm
                obs[i, j, 6] = agent_y_norm

        return obs