ethicalgardeners.algorithms.predict_action
- ethicalgardeners.algorithms.predict_action(model, observation, action_mask, needs_action_mask=False, deterministic=True, **kwargs)
Predict the next action with the model, taking the action mask into account when required.
The action mask is passed to the model only if the algorithm supports it (e.g. MaskablePPO). Otherwise, if the predicted action is not valid according to the mask, a valid action is chosen at random instead.
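The behaviour described above (mask forwarded only for maskable algorithms, random valid fallback otherwise) can be sketched as follows. This is a hypothetical re-implementation for illustration, not the library's source: the name `predict_action_sketch`, the toy `AlwaysZero` model, and the use of the `action_masks` keyword (the name sb3_contrib's `MaskablePPO.predict` uses) are assumptions, and the sketch also assumes at least one action is valid.

```python
import numpy as np

_rng = np.random.default_rng()


def predict_action_sketch(model, observation, action_mask,
                          needs_action_mask=False, deterministic=True,
                          **kwargs):
    """Hypothetical sketch of the documented predict_action behaviour."""
    mask = np.asarray(action_mask, dtype=bool)
    if needs_action_mask:
        # Assumption: maskable algorithms (e.g. sb3_contrib's MaskablePPO)
        # receive the mask through the ``action_masks`` keyword of predict.
        action, _ = model.predict(observation, deterministic=deterministic,
                                  action_masks=mask, **kwargs)
    else:
        # Non-maskable algorithms (e.g. DQN) predict without the mask.
        action, _ = model.predict(observation, deterministic=deterministic,
                                  **kwargs)
    # SB3-style predict returns an array; reduce it to a plain int.
    action = int(np.asarray(action).item())
    if not mask[action]:
        # Fallback: replace an invalid action by a random valid one
        # (assumes the mask contains at least one valid action).
        action = int(_rng.choice(np.flatnonzero(mask)))
    return action


class AlwaysZero:
    """Hypothetical toy model whose predict always returns action 0."""

    def predict(self, observation, deterministic=True, **kwargs):
        return np.array(0), None


print(predict_action_sketch(AlwaysZero(), None, [1, 0, 1]))  # 0 is valid: 0
print(predict_action_sketch(AlwaysZero(), None, [0, 1, 1]))  # 0 invalid: 1 or 2
```

With a maskable algorithm the fallback should rarely trigger, since the model already restricts itself to valid actions; the random replacement mainly matters for algorithms such as DQN that cannot see the mask.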
- Parameters:
model – A trained model instance used for prediction. It must provide a predict method like the one in Stable Baselines3.
observation – The current observation from the environment.
action_mask – The action mask indicating which actions are valid.
needs_action_mask – Whether the algorithm requires an action mask (e.g. MaskablePPO) or not (e.g. DQN).
deterministic – Whether to use deterministic actions when predicting with the model.
**kwargs – Additional keyword arguments to pass to the model’s predict method.
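Since `model` only needs a Stable Baselines3-style `predict` method, a small stub can presumably stand in for a trained model, for example in tests. The class below is hypothetical (it belongs to neither ethicalgardeners nor Stable Baselines3); its keyword names `state`, `episode_start` and `deterministic` follow SB3's `predict`, while `action_masks` is the extra keyword of sb3_contrib's `MaskablePPO.predict`. Like SB3, it returns an `(action, state)` tuple.

```python
import numpy as np


class ToyPolicy:
    """Hypothetical stand-in for a trained model (SB3-style predict only)."""

    def __init__(self, n_actions, seed=None):
        self.n_actions = n_actions
        self.rng = np.random.default_rng(seed)

    def predict(self, observation, state=None, episode_start=None,
                deterministic=False, action_masks=None):
        # Restrict the candidates to valid actions when a mask is given.
        if action_masks is not None:
            candidates = np.flatnonzero(action_masks)
        else:
            candidates = np.arange(self.n_actions)
        if deterministic:
            # Deterministic mode: always pick the lowest-index candidate.
            action = candidates[0]
        else:
            # Stochastic mode: sample one candidate uniformly.
            action = self.rng.choice(candidates)
        return np.asarray(action), state


model = ToyPolicy(n_actions=4)
action, _ = model.predict(None, deterministic=True,
                          action_masks=np.array([0, 0, 1, 1]))
print(int(action))  # 2
```

Any extra keywords given to predict_action through **kwargs would be forwarded to a method like this one, so they must match the parameters it actually accepts.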