Using Ethical Gardeners with Custom Algorithms
This tutorial explains how to integrate your own reinforcement learning algorithms with the Ethical Gardeners environment.
Compatibility with Custom Algorithms
The Ethical Gardeners environment is designed to work with various reinforcement learning algorithms, including custom ones. While we provide built-in support for Stable Baselines 3, you can use your own algorithms by following these guidelines.
Requirements for Custom Algorithms
To use your algorithm with the utility functions provided by Ethical Gardeners:
For the train() function, your algorithm should have:

- a learn() method that accepts a total_timesteps parameter
- a save() method to persist the trained model

For the evaluate() and predict_action() functions, your algorithm should have:

- a predict() method that takes observations (and, if your algorithm supports action masking, action masks) and returns actions
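The interface above can be sketched as a minimal custom algorithm. The class below is purely illustrative (its name, constructor, and internals are not part of Ethical Gardeners), but it exposes the learn(), save(), and predict() methods that the utility functions expect, following the SB3 convention of returning an (action, state) tuple from predict():

```python
import numpy as np


class RandomAlgorithm:
    """A hypothetical custom algorithm exposing the interface expected
    by the Ethical Gardeners utility functions."""

    def __init__(self, n_actions, seed=0):
        self.n_actions = n_actions
        self.rng = np.random.default_rng(seed)

    def learn(self, total_timesteps):
        # A real algorithm would collect experience from the environment
        # and update its policy here; this sketch does nothing.
        for _ in range(total_timesteps):
            pass

    def save(self, path):
        # Persist whatever state is needed to restore the model later.
        np.save(path, np.array([self.n_actions]))

    def predict(self, observation, action_masks=None, deterministic=False):
        # If an action mask is provided, sample only among valid actions;
        # otherwise sample uniformly over the whole action space.
        if action_masks is not None:
            valid = np.flatnonzero(action_masks)
        else:
            valid = np.arange(self.n_actions)
        action = int(self.rng.choice(valid))
        # SB3-style return value: (action, recurrent state)
        return action, None
```

An instance of such a class could then be passed to train(), evaluate(), or predict_action() in place of an SB3 model.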
Using the Core Utility Functions
The following code snippet is a minimal example of how to use the training, evaluation, and prediction functions with a custom algorithm. It uses the MaskablePPO algorithm from sb3-contrib as an example of an algorithm that supports action masking, and the DQN algorithm from Stable Baselines 3 as an example of one that does not.
Because the train() function accepts a model, the model must be instantiated beforehand. In the example, we instantiate the model with a policy and the environment, which is either a default Ethical Gardeners environment created with the make_env() function or a vectorized environment wrapping multiple environments.
For training, evaluation, and prediction, you must indicate whether your algorithm supports action masking: set the needs_action_mask parameter to True if it does, and to False otherwise.
from copy import deepcopy

import hydra
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
# from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv  # , SubprocVecEnv

from ethicalgardeners import algorithms, make_env
from ethicalgardeners.main import run_simulation, _find_config_path


@hydra.main(version_base=None, config_path=_find_config_path())
def main(config):
    algo = "maskable_ppo"  # "dqn" or "maskable_ppo"

    # ---- Training ----
    num_envs = 10
    total_timesteps = 0
    configs = [deepcopy(config) for _ in range(num_envs)]
    for i, cfg in enumerate(configs):
        cfg["random_seed"] = 42 + i  # a different seed for each env
        cfg["num_iterations"] = 2048
        total_timesteps += cfg["num_iterations"]

    # When num_envs > 1, multiple environments are created using the provided
    # configs and run in parallel using either SubprocVecEnv or DummyVecEnv.
    # When num_envs = 1, a single environment is created using the first config.
    if num_envs > 1:
        env_fns = [algorithms.make_env_thunk(make_env, cfg) for cfg in configs]
        vec_cls = DummyVecEnv  # or SubprocVecEnv
        env = vec_cls(env_fns)
    else:
        env = algorithms.make_SB3_env(make_env, configs[0])

    # Create the model using the provided model function
    model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=3)
    # or model = DQN("MlpPolicy", env, verbose=3)

    algorithms.train(model, algo, total_timesteps=total_timesteps)
    env.close()

    # ---- Evaluation ----
    env = make_env(config)
    policy_path = algorithms.get_latest_policy(algo)
    model = MaskablePPO.load(policy_path)
    # or model = DQN.load(policy_path)

    round_rewards, total_rewards, winrate, scores = algorithms.evaluate(
        env, model, algo,
        num_games=5, deterministic=True,
        needs_action_mask=True  # True for MaskablePPO, False for DQN
    )
    print("Rewards by round: ", round_rewards)
    print("Total rewards (incl. negative rewards): ", total_rewards)
    print("Winrate: ", winrate)
    print("Final scores: ", scores)

    # ---- Use the trained model as an agent in the environment ----
    env = make_env(config)
    agent_algorithms = [model for _ in range(2)]

    # Main loop for the environment
    run_simulation(
        env, agent_algorithms, deterministic=[True, True],
        needs_action_mask=[True, True]
    )


if __name__ == "__main__":
    main()