How to Extend the Ethical Gardeners Environment
The Ethical Gardeners environment was designed to be extensible, so that you can tailor it to your use case. This tutorial explains how to extend each of its components.
Some components can be extended by creating new classes and passing them to the existing classes, but for most extensions it is easier, or even necessary, to modify the existing code directly. We therefore recommend cloning the repository rather than installing it as a Python package if your goal is to extend the environment.
1. Adding or Modifying Reward Functions
Reward functions determine how agents are evaluated in the environment. The RewardFunctions class contains methods for calculating rewards.
Current Structure
The RewardFunctions class calculates three types of rewards:
- Ecological (compute_ecology_reward())
- Well-being (compute_wellbeing_reward())
- Biodiversity (compute_biodiversity_reward())
These rewards are then combined in the compute_reward() method, which averages the three.
How to Add a New Reward Function
1. Add a new method to the RewardFunctions class.
2. Modify the compute_reward() method to include your new reward.
Example: Adding a Collaboration Reward
def compute_collaboration_reward(self, grid_world_prev, grid_world, agent, action):
    """
    Compute a reward based on how well agents collaborate to maintain
    balanced planting across the grid.

    Args:
        grid_world_prev: The grid world before the action
        grid_world: The current grid world after the action
        agent: The agent performing the action
        action: The action performed

    Returns:
        float: Normalized collaboration reward between -1 and 1
    """
    # Code to calculate the collaboration reward goes here;
    # this placeholder returns a neutral reward of 0.0.
    reward = 0.0
    return reward
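The body of compute_collaboration_reward() is left to you. As one possible reading of "balanced planting across the grid", the sketch below splits the grid into four quadrants, counts the flowers in each, and maps the spread to [-1, 1]. It is a hypothetical illustration, not part of the environment: count_flowers_per_quadrant() and balance_score() are made-up helpers, and they take a plain 2D list of booleans (True where a flower grows) so that the snippet runs on its own. Inside the real method, you could build that list with grid_world.get_cell(...) and cell.has_flower().

```python
def count_flowers_per_quadrant(flower_grid):
    """Hypothetical helper: count flowers in each quadrant of a 2D list of booleans."""
    rows = len(flower_grid)
    cols = len(flower_grid[0]) if rows else 0
    counts = [0, 0, 0, 0]  # top-left, top-right, bottom-left, bottom-right
    for i in range(rows):
        for j in range(cols):
            if flower_grid[i][j]:
                # Booleans act as 0/1 here: bottom half adds 2, right half adds 1
                quadrant = 2 * (i >= rows // 2) + (j >= cols // 2)
                counts[quadrant] += 1
    return counts


def balance_score(counts):
    """Hypothetical helper: 1.0 = perfectly even spread, -1.0 = all flowers in one region."""
    total = sum(counts)
    if total == 0:
        return 0.0  # no flowers yet: neutral reward
    imbalance = (max(counts) - min(counts)) / total  # always in [0, 1]
    return 1.0 - 2.0 * imbalance


example_grid = [
    [True, False, False, True],
    [False, False, False, False],
    [True, False, False, True],
    [False, False, False, False],
]
print(balance_score(count_flowers_per_quadrant(example_grid)))  # 1.0
```

Whether to reward the balance of grid_world itself or its change since grid_world_prev is a design choice; the latter credits the agent only for what its action changed.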
Then, modify the compute_reward() method:
def compute_reward(self, grid_world_prev, grid_world, agent, action):
    """Compute the multi-objective reward for an agent."""
    ecology_reward = self.compute_ecology_reward(grid_world_prev, grid_world, agent, action)
    wellbeing_reward = self.compute_wellbeing_reward(grid_world_prev, grid_world, agent, action)
    biodiversity_reward = self.compute_biodiversity_reward(grid_world_prev, grid_world, agent, action)
    collaboration_reward = self.compute_collaboration_reward(grid_world_prev, grid_world, agent, action)

    return {
        'ecology': ecology_reward,
        'wellbeing': wellbeing_reward,
        'biodiversity': biodiversity_reward,
        'collaboration': collaboration_reward,
        'total': (ecology_reward + wellbeing_reward + biodiversity_reward + collaboration_reward) / 4
    }
Changing the Reward Functions Used
You can also change the compute_reward() method to return only a subset of the already
implemented reward functions. For example, if you only want to compute the ecology and
well-being rewards:
def compute_reward(self, grid_world_prev, grid_world, agent, action):
    """Compute the multi-objective reward for an agent."""
    ecology_reward = self.compute_ecology_reward(grid_world_prev, grid_world, agent, action)
    wellbeing_reward = self.compute_wellbeing_reward(grid_world_prev, grid_world, agent, action)

    return {
        'ecology': ecology_reward,
        'wellbeing': wellbeing_reward,
        'total': (ecology_reward + wellbeing_reward) / 2
    }
You can also change the aggregation method used to compute the total reward. What matters
is that 'total' holds the final (scalar) reward sent to each agent; the other entries of the
returned dictionary are sent as additional data in info. They can be used for analysis, or
even given to multi-objective agents (if the learning algorithm supports this).
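For instance, a weighted average lets you favour some objectives over others. The sketch below assumes the per-objective rewards are collected in a dictionary like the one compute_reward() returns; aggregate_rewards() and the weight values are illustrative assumptions, not part of the environment.

```python
def aggregate_rewards(rewards, weights):
    """Illustrative helper: weighted average of per-objective rewards.

    rewards: dict mapping objective name -> reward (e.g. 'ecology', 'wellbeing')
    weights: dict mapping the same names -> non-negative weight
    """
    total_weight = sum(weights[name] for name in rewards)
    if total_weight == 0:
        raise ValueError("At least one weight must be positive")
    weighted = sum(weights[name] * value for name, value in rewards.items())
    # Normalizing by the weight sum keeps the total within the range of the inputs
    return weighted / total_weight


rewards = {'ecology': 0.5, 'wellbeing': -0.5, 'biodiversity': 1.0}
weights = {'ecology': 2.0, 'wellbeing': 1.0, 'biodiversity': 1.0}  # hypothetical weights
rewards['total'] = aggregate_rewards(rewards, weights)
print(rewards['total'])  # 0.375
```

With equal weights, this reduces to the plain average that compute_reward() uses by default.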
2. Adding an Observation Type
Observations determine how agents perceive the environment. The observation module contains observation strategies.
Current Structure
The module implements an abstract ObservationStrategy class with two concrete implementations:
- TotalObservation: provides a complete view of the grid
- PartialObservation: provides a limited view around the agent
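To picture what a limited view around the agent means, the sketch below extracts a square window of side 2 * obs_range + 1 centred on the agent and pads cells outside the grid with a fill value. It works on a plain 2D list and is only an illustration under those assumptions: how PartialObservation actually builds its observation (features, padding value, axis order) may differ, and extract_window() is a hypothetical name.

```python
def extract_window(grid, center, obs_range, fill=0.0):
    """Hypothetical helper: (2*obs_range+1) x (2*obs_range+1) neighbourhood of `center`.

    grid: 2D list indexed as grid[i][j]
    center: (i, j) position of the agent
    Positions outside the grid are replaced by `fill`.
    """
    ci, cj = center
    rows, cols = len(grid), len(grid[0])
    window = []
    for i in range(ci - obs_range, ci + obs_range + 1):
        row = []
        for j in range(cj - obs_range, cj + obs_range + 1):
            if 0 <= i < rows and 0 <= j < cols:
                row.append(grid[i][j])
            else:
                row.append(fill)  # padding for out-of-bounds positions
        window.append(row)
    return window


grid = [
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
]
print(extract_window(grid, (0, 0), 1))
# [[0.0, 0.0, 0.0], [0.0, 1, 2], [0.0, 4, 5]]
```

Padding keeps the observation shape constant wherever the agent stands, which a fixed-size observation space requires.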
How to Add a New Observation Type
1. Create a new class that inherits from ObservationStrategy.
2. Implement the observation_space() and get_observation() methods.
3. Modify the make_env() function to include your new observation type.
Example: Adding a Total Observation with Only Pollution Levels and Flower Growth Stages
import numpy as np
from gymnasium.spaces import Box

# ObservationStrategy is the abstract base class from the observation module.


class TotalObservationPollutionFlowers(ObservationStrategy):
    """
    Strategy that provides agents with a full view of the grid,
    only including pollution levels and flower growth stages.
    """

    def __init__(self, grid_world):
        """
        Create the observation strategy.

        Args:
            grid_world: The grid world environment to observe
        """
        super().__init__()
        # One channel per feature: 0 = pollution level, 1 = flower growth stage
        self.observation_shape = (grid_world.width, grid_world.height, 2)

    # The observation_space method determines the structure of observations that agents will receive;
    # it must return a Gymnasium Space. Here, we use a Box, which means that observations are
    # arrays of shape `self.observation_shape`, each element being a float32 between 0 and 1.
    def observation_space(self, agent):
        """Define the observation space."""
        return Box(low=0, high=1, shape=self.observation_shape, dtype=np.float32)

    def get_observation(self, grid_world, agent):
        """Generate a complete observation restricted to a subset of the grid's features."""
        obs = np.zeros(self.observation_shape, dtype=np.float32)

        for x in range(self.observation_shape[0]):
            for y in range(self.observation_shape[1]):
                cell = grid_world.get_cell((x, y))

                # Feature 1: Pollution level (normalized); stays 0 for cells without pollution
                pollution_normalized = 0.0
                if cell.pollution is not None:
                    pollution_normalized = (
                        (cell.pollution - grid_world.min_pollution) /
                        (grid_world.max_pollution -
                         grid_world.min_pollution)
                    )
                obs[x, y, 0] = pollution_normalized

                # Feature 2: Flower growth stage (normalized)
                if cell.has_flower():
                    growth_stage_normalized = (
                        (cell.flower.current_growth_stage + 1) /
                        (cell.flower.num_growth_stage + 1)
                    )
                    obs[x, y, 1] = growth_stage_normalized

                # You can add any features you want by adding more channels to the obs array
                # (and increasing the last dimension of self.observation_shape accordingly).

        return obs
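The two normalizations in get_observation() can be checked in isolation. The helpers below are hypothetical and mirror those formulas using plain numbers instead of a cell. Note that adding 1 to the growth stage gives a freshly planted flower (stage 0) a non-zero value, so it stays distinguishable from an empty cell, whose channel keeps its initial 0.

```python
def normalize_pollution(pollution, min_pollution, max_pollution):
    """Hypothetical helper: min-max normalization; None (e.g. obstacles) maps to 0.0."""
    if pollution is None:
        return 0.0
    return (pollution - min_pollution) / (max_pollution - min_pollution)


def normalize_growth_stage(current_growth_stage, num_growth_stage):
    """Hypothetical helper: stage shifted by 1 so that stage 0 differs from 'no flower' (0.0)."""
    return (current_growth_stage + 1) / (num_growth_stage + 1)


print(normalize_pollution(50, 0, 100))    # 0.5
print(normalize_pollution(None, 0, 100))  # 0.0
print(normalize_growth_stage(0, 3))       # 0.25
```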
Then, add your new observation type to the make_env() function:
def make_env(config):
    """Create the environment based on the configuration."""
    # Existing code...
    elif observation_type == "partial":
        obs_range = config.observation.get("range", 1)
        observation_strategy = PartialObservation(
            obs_range
        )
    elif observation_type == "total_pollution_flowers":
        observation_strategy = TotalObservationPollutionFlowers(
            grid_world=grid_world
        )
    # Existing code...
3. Adding or Modifying Metrics
Metrics track agent performance and the state of the environment.
Current Structure
The metrics class stores metrics in a dictionary and provides two ways to output them:
- export_metrics() exports the metrics to a CSV file
- send_metrics() sends the metrics to Weights & Biases
How to Add New Metrics
1. Add new keys to the metrics dictionary during initialization.
2. Update the metrics during the simulation.
3. Modify _prepare_metrics() to include your new metrics.
Example: Adding Diversity Metrics
def __init__(self, ...):
    # Existing code...
    self.metrics = {
        # Existing metrics...
        # New metrics
        "diversity": {},
        "agent_cooperation_score": 0.0,
    }

def update_metrics(self, grid_world, agents, rewards):
    """Update metrics based on the current state of the grid."""
    # Existing code to update metrics...

    # Calculate the new metrics and store them in self.metrics, e.g.
    # self.metrics["diversity"] = ... and self.metrics["agent_cooperation_score"] = ...

def _prepare_metrics(self):
    """Prepare a formatted dictionary of metrics for export or sending."""
    metrics_dict = {
        # Existing metrics...
    }

    # Add new metrics
    metrics_dict['diversity'] = self.metrics["diversity"]
    metrics_dict['agent_cooperation_score'] = self.metrics["agent_cooperation_score"]

    return metrics_dict
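How the new metrics are calculated in update_metrics() is left open. As one hypothetical way to fill the "diversity" entry, the sketch below counts flowers per type and computes the Shannon index H = -sum(p_i * ln(p_i)) over the observed types. It takes a plain list of flower types rather than reading the grid, since how a cell exposes its flower type is not covered in this tutorial; compute_diversity() is an illustrative name.

```python
import math
from collections import Counter


def compute_diversity(flower_types):
    """Hypothetical helper: per-type flower counts and the Shannon diversity index.

    flower_types: iterable with one entry (e.g. an int type id) per flower on the grid.
    """
    counts = Counter(flower_types)
    total = sum(counts.values())
    shannon = 0.0  # stays 0.0 when there are no flowers or only one type
    for count in counts.values():
        p = count / total
        shannon -= p * math.log(p)
    return {"counts": dict(counts), "shannon": shannon}


result = compute_diversity([0, 0, 1, 1])
print(result["counts"])             # {0: 2, 1: 2}
print(round(result["shannon"], 4))  # 0.6931 (ln 2: two equally common types)
```

The index grows both with the number of flower types and with how evenly flowers are spread among them.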
4. Adding Actions and Handling Them
Actions determine what agents can do in the environment. The action and actionhandler modules manage actions.
Current Structure
- action defines the enumeration of possible actions
- ActionHandler implements action handling
How to Add New Actions
1. Modify the create_action_enum() function in action.
2. Add a handling method in ActionHandler.
3. Update the handle_action() method to call your new method.
4. Update the update_action_mask() method to include your new action.
Example: Adding a Pollution Cleaning Action
First, modify the create_action_enum() function in the action module:
from enum import Enum, auto

def create_action_enum(num_flower_type=1):
    """Dynamically create an enumeration of actions."""
    actions = {
        'UP': 0,
        'DOWN': 1,
        'LEFT': 2,
        'RIGHT': 3,
        'HARVEST': 4,
        'WAIT': 5,
        'CLEAN': 6,  # New action for cleaning pollution
    }

    # Planting actions are numbered automatically after the last explicit value
    for i in range(num_flower_type):
        action_name = f'PLANT_TYPE_{i}'
        actions[action_name] = auto()

    return Enum('Action', actions, type=_ActionEnum)
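Because CLEAN takes the value 6, auto() numbers the planting actions from 7 onwards, so every PLANT_TYPE_i value shifts by one compared with the enumeration without CLEAN; code that relies on raw action integers may need updating. The standalone check below reproduces the function on a plain Enum, without the _ActionEnum type, which is not shown in this tutorial.

```python
from enum import Enum, auto


def create_action_enum(num_flower_type=1):
    """Same mapping as above, built on a plain Enum for a standalone check."""
    actions = {
        'UP': 0,
        'DOWN': 1,
        'LEFT': 2,
        'RIGHT': 3,
        'HARVEST': 4,
        'WAIT': 5,
        'CLEAN': 6,
    }
    for i in range(num_flower_type):
        # auto() continues from the last value: 7, 8, ...
        actions[f'PLANT_TYPE_{i}'] = auto()
    return Enum('Action', actions)


Action = create_action_enum(num_flower_type=2)
print([(a.name, a.value) for a in Action][-3:])
# [('CLEAN', 6), ('PLANT_TYPE_0', 7), ('PLANT_TYPE_1', 8)]
```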
Then, add a method in ActionHandler:
def clean_pollution(self, agent):
    """
    Clean pollution at the agent's current position.

    This action reduces pollution in the current cell by a fixed amount.

    Args:
        agent: The agent performing the cleaning action
    """
    # handle the cleaning action
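A minimal sketch of the cleaning logic, assuming a fixed cleaning amount (the CLEAN_AMOUNT constant is hypothetical) and clamping at the minimum pollution, as update_pollution() does when a flower reduces pollution. A small stand-in class replaces the real Cell so that the snippet runs on its own; inside ActionHandler, the cell would come from self.grid_world.get_cell(agent.position) and the lower bound from self.grid_world.min_pollution.

```python
from dataclasses import dataclass
from typing import Optional

CLEAN_AMOUNT = 5  # hypothetical: pollution removed by one CLEAN action


@dataclass
class StubCell:
    """Stand-in for the environment's Cell; only the pollution attribute is used."""
    pollution: Optional[float]


def clean_cell(cell, min_pollution, amount=CLEAN_AMOUNT):
    """Reduce the cell's pollution by `amount`, never going below `min_pollution`."""
    if cell.pollution is None:
        return  # cells without pollution (e.g. obstacles) cannot be cleaned
    cell.pollution = max(cell.pollution - amount, min_pollution)


cell = StubCell(pollution=12)
clean_cell(cell, min_pollution=0)
print(cell.pollution)  # 7
clean_cell(cell, min_pollution=0, amount=10)
print(cell.pollution)  # 0
```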
Finally, update handle_action():
def handle_action(self, agent, action):
    """Process an agent's action and execute it in the grid world."""
    if action in [self.action_enum.UP, self.action_enum.DOWN,
                  self.action_enum.LEFT, self.action_enum.RIGHT]:
        self.move_agent(agent, action)
    elif action == self.action_enum.HARVEST:
        self.harvest_flower(agent)
    elif action == self.action_enum.WAIT:
        self.wait(agent)
    elif action == self.action_enum.CLEAN:
        self.clean_pollution(agent)
    else:  # Assume action is a PLANT_TYPE_i action
        self.plant_flower(agent, action.flower_type)
Don’t forget to update update_action_mask() to handle the new action:
def update_action_mask(self, agent):
    """Update the action mask for the agent."""
    # Existing code...

    # Only allow cleaning on cells that track pollution (obstacles have pollution None)
    cell = self.grid_world.get_cell(agent.position)
    if cell.pollution is None:
        mask[self.action_enum.CLEAN.value] = 0

    # Rest of existing code...
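5. Adding a Cell Type
Cell types define different parts of the environment. They are defined in the gridworld module.
Current Structure
The CellType enumeration currently defines two cell types, GROUND and OBSTACLE. Each Cell stores its type, an optional flower and agent, and a pollution level (None for obstacles).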
How to Add a New Cell Type
1. Add a new value to the CellType enumeration.
2. Modify the Cell class to handle the new type.
3. Update the methods of the Cell class.
4. Modify the grid initialization to include the new cell type.
5. Modify the config.
6. Modify the renderers to visualize the new cell type.
Example: Adding a “WATER” Cell Type
from enum import Enum

class CellType(Enum):
    """Enum representing the possible types of cells in the grid world."""
    GROUND = 0
    OBSTACLE = 1
    WATER = 2  # New cell type
Modify the Cell class to handle this new type:
def __init__(self, cell_type, pollution=50, pollution_increment=1):
    """Create a new cell."""
    self.cell_type = cell_type
    self.flower = None
    self.agent = None
    if cell_type == CellType.GROUND:
        self.pollution = pollution
    elif cell_type == CellType.OBSTACLE:
        self.pollution = None
    elif cell_type == CellType.WATER:
        self.pollution = pollution * 0.5  # Water initially has less pollution
    self.pollution_increment = pollution_increment

def update_pollution(self, min_pollution, max_pollution):
    """Update the pollution level of the cell based on its current state."""
    if self.pollution is None:
        return

    if self.has_flower():
        self.pollution = max(
            self.pollution - self.flower.get_pollution_reduction(),
            min_pollution
        )
    else:
        # Water self-cleans
        if self.cell_type == CellType.WATER:
            self.pollution = max(
                self.pollution - self.pollution_increment * 0.5,
                min_pollution
            )
        else:
            self.pollution = min(
                self.pollution + self.pollution_increment,
                max_pollution
            )

def can_walk_on(self):
    """Check if agents can walk on this cell."""
    return self.cell_type in [CellType.GROUND, CellType.WATER]

def can_plant_on(self):
    """Check if a flower can be planted in this cell."""
    # Cannot plant in water
    return self.cell_type == CellType.GROUND and not self.has_flower()
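To see the effect of these rules, the standalone sketch below repeats the no-flower branch of update_pollution() for a few steps. The step_pollution() helper, the string cell types (used instead of CellType for brevity) and the numeric values are illustrative: ground pollution rises by pollution_increment per step, while water pollution falls by half that amount, both clamped to [min_pollution, max_pollution].

```python
def step_pollution(cell_type, pollution, increment, min_pollution, max_pollution):
    """Hypothetical helper: one update of a cell without a flower, following the rules above."""
    if cell_type == "WATER":
        return max(pollution - increment * 0.5, min_pollution)
    return min(pollution + increment, max_pollution)


ground, water = 50, 25  # water starts at half the default pollution of 50
for _ in range(3):
    ground = step_pollution("GROUND", ground, 1, 0, 100)
    water = step_pollution("WATER", water, 1, 0, 100)
print(ground, water)  # 53 23.5
```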
Modify the grid initialization to include the new cell type by doing one of the following:
Add the following to init_from_file() after placing ground and obstacle cells:
elif cell_code == 'W':
    grid[i][j] = Cell(CellType.WATER)
Add a water_ratio parameter to init_random() and add the following after placing ground and obstacle cells and updating valid_positions:
# Place water cells randomly among the remaining valid positions
indices = np.arange(len(valid_positions))  # choice needs indices
num_waters = int(water_ratio * width * height)
selected_indices = random_generator.choice(indices,
                                           num_waters,
                                           replace=False)
water_positions = [valid_positions[i] for i in selected_indices]

for pos in water_positions:
    i, j = pos
    grid[i][j] = Cell(CellType.WATER)
    valid_positions.remove(pos)
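The same placement can be written with Python's standard random module, which also makes it easy to guard against requesting more water cells than there are free positions (random_generator.choice with replace=False raises a ValueError in that case). The place_water() helper below is a hypothetical standalone version: it only returns the chosen positions and updates valid_positions, leaving the creation of Cell(CellType.WATER) objects to the grid code.

```python
import random


def place_water(valid_positions, width, height, water_ratio, rng):
    """Hypothetical helper: pick water positions and remove them from valid_positions.

    The number of water cells is capped at the number of free positions.
    """
    num_waters = min(int(water_ratio * width * height), len(valid_positions))
    water_positions = rng.sample(valid_positions, num_waters)  # no duplicates
    for pos in water_positions:
        valid_positions.remove(pos)
    return water_positions


positions = [(i, j) for i in range(3) for j in range(3)]
water = place_water(positions, width=3, height=3, water_ratio=0.3, rng=random.Random(0))
print(len(water), len(positions))  # 2 7
```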
If you want to use init_from_code(), you don’t need to modify the code.
Modify the config to include the new cell type:
- Modify from_code.yaml to place water cells, or add the water_ratio parameter to random.yaml, if you want to place water cells in the grid.
- Modify console.yaml, graphical.yaml and full.yaml to include the new cell type in the characters and colors dictionaries.
Modify the renderers to visualize the new cell type:
Add the following to the render() method of ConsoleRenderer after defining the character for ground and obstacle cells:
elif cell.cell_type == CellType.WATER:
    cell_char = self.characters.get('water', 'W')
Add the following to the render() method of GraphicalRenderer after defining the color for ground and obstacle cells:
elif cell.cell_type == CellType.WATER:
    cell_color = self.colors['water']
6. Adding a New Renderer Type
Renderers visualize the simulation environment. The renderer module defines an abstract Renderer class and concrete implementations like GraphicalRenderer (using Pygame) and ConsoleRenderer (text-based).
Current Structure
- Renderer: abstract base class with the methods:
  - init(): sets up the rendering environment
  - render(): renders the current state (abstract)
  - display_render(): updates the display with the current frame (abstract)
  - end_render(): finalizes rendering and handles cleanup
- Concrete implementations:
  - GraphicalRenderer: colorful Pygame visualization
  - ConsoleRenderer: text-based visualization in the terminal
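The division of work between these methods can be sketched with a toy renderer. The base class below is a simplified stand-in written with abc, not the environment's actual Renderer, and the driver loop is an assumption about how the environment calls the methods (init once, render and display_render at every step, end_render once); only the method names and signatures come from this tutorial.

```python
from abc import ABC, abstractmethod


class RendererSketch(ABC):
    """Simplified stand-in for the Renderer base class (not the real implementation)."""

    def init(self, grid_world):
        pass  # optional setup

    @abstractmethod
    def render(self, grid_world, agents):
        """Prepare the current frame without displaying it."""

    @abstractmethod
    def display_render(self):
        """Show the frame prepared by render()."""

    def end_render(self):
        pass  # optional cleanup and saving


class LoggingRenderer(RendererSketch):
    """Toy renderer that records which methods were called, in order."""

    def __init__(self):
        self.calls = []

    def init(self, grid_world):
        self.calls.append("init")

    def render(self, grid_world, agents):
        self.calls.append("render")

    def display_render(self):
        self.calls.append("display_render")

    def end_render(self):
        self.calls.append("end_render")


renderer = LoggingRenderer()
renderer.init(grid_world=None)
for _ in range(2):  # two simulation steps
    renderer.render(grid_world=None, agents={})
    renderer.display_render()
renderer.end_render()
print(renderer.calls)
# ['init', 'render', 'display_render', 'render', 'display_render', 'end_render']
```

Keeping render() and display_render() separate lets a renderer prepare frames for post-analysis even when nothing is shown on screen.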
How to Add a New Renderer
1. Create a new class that inherits from Renderer.
2. Implement the required abstract methods.
3. Register your renderer in the configuration system.
4. Modify the make_env() function to include your new renderer.
Example: Adding a Heatmap Renderer
import time
import warnings

import numpy as np


class HeatmapRenderer(Renderer):
    """
    Renderer that visualizes the environment as a pollution heatmap using matplotlib.

    This renderer focuses on pollution levels across the grid, displaying them
    as a color-coded heatmap.
    """

    def __init__(self, post_analysis_on=False, out_dir_path=None, cmap='coolwarm'):
        """
        Create the heatmap renderer.

        Args:
            post_analysis_on (bool, optional): Flag to enable saving frames for
                post-simulation video generation. Defaults to False.
            out_dir_path (str, optional): Directory path where output files will be saved.
                Required if post_analysis_on is True. Defaults to None.
            cmap (str, optional): Matplotlib colormap to use for the heatmap.
                Defaults to 'coolwarm'.
        """
        super().__init__()
        self.cmap = cmap
        self.fig = None
        self.ax = None
        self.heatmap = None
        self.agent_markers = []
        self.post_analysis_on = post_analysis_on
        self.out_dir_path = out_dir_path
        self.frames = []

        # Initialize run_id for video output naming
        self._run_id = None
        if post_analysis_on:
            self._run_id = int(time.time())

        try:
            import matplotlib.pyplot as plt
            self.plt = plt
        except ImportError:
            warnings.warn("Cannot import matplotlib. Heatmap renderer will be disabled.")
            self.display = False
            self.post_analysis_on = False

    def init(self, grid_world):
        """
        Initialize the matplotlib figure based on the grid world dimensions.
        """
        if self.display or self.post_analysis_on:
            # Create a new figure and axis
            self.fig, self.ax = self.plt.subplots()
            self.ax.set_title("Pollution Heatmap")

            # Create an initial empty heatmap (one row per grid row)
            self.heatmap = self.ax.imshow(
                [[0 for _ in range(grid_world.width)] for _ in range(grid_world.height)],
                cmap=self.cmap,
                vmin=grid_world.min_pollution,
                vmax=grid_world.max_pollution
            )

    # This method should not display anything directly; it should only prepare the data to be displayed.
    # The actual display is handled in display_render.
    def render(self, grid_world, agents):
        """
        Render the current state of the grid world as a heatmap.
        """
        if self.display or self.post_analysis_on:
            # Fill in pollution data (cells without pollution, e.g. obstacles, are shown as 0)
            pollution_data = [[0 for _ in range(grid_world.width)] for _ in range(grid_world.height)]
            for i in range(grid_world.height):
                for j in range(grid_world.width):
                    cell = grid_world.get_cell((i, j))
                    pollution_data[i][j] = cell.pollution if cell.pollution is not None else 0

            # Update the heatmap data
            self.heatmap.set_data(pollution_data)

            # Remove the previous agent markers, then draw agents at their new positions
            for marker in self.agent_markers:
                marker.remove()
            self.agent_markers = []
            for agent in agents.values():
                i, j = agent.position
                self.agent_markers.extend(self.ax.plot(j, i, 'o', color='black'))

            # If post_analysis is enabled, save the current frame
            if self.post_analysis_on:
                # Draw the figure, then copy its RGBA pixel buffer (Agg-based canvases)
                self.fig.canvas.draw()
                image = np.array(self.fig.canvas.buffer_rgba())
                self.frames.append(image)

    # Compared to render, this method is usually very simple, just updating the display with the current frame.
    # It allows the use of a flag to disable display while still saving frames for post-analysis.
    def display_render(self):
        """
        Display the rendered frame in a matplotlib window.
        """
        if self.display:
            self.fig.canvas.draw_idle()
            # Give the GUI event loop time to show the updated figure
            self.plt.pause(0.001)

    # This method usually handles cleanup and saving of any post-analysis data, so it is not needed if no
    # operations are done at the end of the rendering.
    def end_render(self):
        """
        Finalize the rendering process and clean up resources.
        """
        # If post_analysis is enabled and we have frames, create a video
        if self.post_analysis_on and self.frames:
            # Copy the _create_video method from GraphicalRenderer into this class and call it here;
            # it is assumed here to return the path of the saved video.
            output_path = self._create_video()
            print(f"Heatmap video saved at {output_path}")

        # Close the matplotlib figure
        if self.fig is not None:
            self.plt.close(self.fig)
To use this new renderer, you would configure it in your config:
renderer:
  heatmap:
    enabled: true
    post_analysis_on: true
    out_dir_path: "./videos"
    cmap: 'coolwarm'
And modify the make_env() function to include the new renderer:
# Initialise renderer
self.renderers = []

# Existing renderers

if config.renderer.heatmap.get("enabled", False):
    post_analysis_on = config.renderer.heatmap.get(
        "post_analysis_on", False
    )
    out_dir = config.renderer.heatmap.get("out_dir_path", "./videos")
    cmap = config.renderer.heatmap.get("cmap", 'coolwarm')

    heatmap_renderer = HeatmapRenderer(
        post_analysis_on=post_analysis_on,
        out_dir_path=out_dir,
        cmap=cmap
    )
    self.renderers.append(heatmap_renderer)