# Source code for smartgrid.observation.observation_manager

"""
The ObservationManager is responsible for computing observations.
"""
import dataclasses
from typing import Dict, Type

from smartgrid.agents import Agent
from smartgrid.world import World
from .base_observation import BaseObservation, Observation
from .global_observation import GlobalObservation
from .local_observation import LocalObservation


def _create_observation_type(
        global_observation_type: Type[GlobalObservation],
        local_observation_type: Type[LocalObservation]
) -> Type[Observation]:
    """
    Dynamically build the dataclass that represents a full Observation.

    The returned class inherits from ``global_observation_type``,
    ``local_observation_type`` and :py:class:`Observation` (in this order),
    so its dataclass fields are the union of the global and local fields.
    It also keeps the original global and local observation objects in two
    hidden fields (``_global_obs`` and ``_local_obs``), so that they can be
    retrieved afterwards.

    :param global_observation_type: The class of global observations.
    :param local_observation_type: The class of local observations.
    :return: A new frozen dataclass type, whose ``__qualname__`` is
        ``'Observation'``.
    """

    def _hidden_field():
        # A dataclass field that is hidden from `repr` and comparisons. The
        # `include: False` metadata flag marks it as one that should not show
        # up in `fields` and `asdict` either -- basically, not anywhere.
        return dataclasses.field(
            repr=False,
            compare=False,
            metadata={'include': False},
        )

    @dataclasses.dataclass(frozen=True)
    class _Observation(global_observation_type, local_observation_type, Observation):
        # Original "global_obs" and "local_obs" objects, kept for retrieval
        # only (see `_hidden_field`).
        _global_obs: global_observation_type = _hidden_field()
        _local_obs: local_observation_type = _hidden_field()

        # Without this, `str`/`repr` would print the auto-generated
        # `_create_observation_type.<locals>._Observation(...)`, which is
        # noisy for third-party users; `'Observation'` is simpler to read,
        # even if not strictly accurate (the class is function-local).
        __qualname__ = 'Observation'

        @classmethod
        def create(cls,
                   global_observation: global_observation_type,
                   local_observation: local_observation_type):
            """
            Merge a global and a local observation into a new instance.

            Every value from ``global_observation.asdict()`` and
            ``local_observation.asdict()`` becomes a keyword argument of
            the constructor. A name present in both raises ``TypeError``
            (duplicate keyword). The source objects are also stored in the
            hidden fields.
            """
            global_values = global_observation.asdict()
            local_values = local_observation.asdict()
            return cls(
                **global_values,
                **local_values,
                _global_obs=global_observation,
                _local_obs=local_observation,
            )

        def get_global_observation(self):
            """Return the global observation this object was built from."""
            return self._global_obs

        def get_local_observation(self):
            """Return the local observation this object was built from."""
            return self._local_obs

    return _Observation


class ObservationManager:
    """
    The ObservationManager is responsible for computing observations.

    Its primary purpose is to allow extensibility: the attributes
    :py:attr:`.global_observation` and :py:attr:`.local_observation`, which
    are set through the constructor, control which Observation classes will
    be used in the simulator. It is thus possible to subclass
    :py:class:`.GlobalObservation` and/or :py:class:`.LocalObservation` to
    use different observations.

    The computing calls (:py:meth:`.compute_agent` and
    :py:meth:`.compute_global`) are delegated to the corresponding calls
    through these attributes.
    """

    global_observation: Type[GlobalObservation]
    """
    The class that will be used to compute global observations.

    It should be a subclass of :py:class:`.GlobalObservation` to ensure that
    necessary methods are present. Please note that this field should be set
    to a *class* itself, not an instance, e.g., ``GlobalObservation``
    (instead of ``GlobalObservation()``).
    """

    local_observation: Type[LocalObservation]
    """
    The class that will be used to compute local observations.

    It should be a subclass of :py:class:`.LocalObservation` to ensure that
    necessary methods are present. Please note that this field should be set
    to a *class* itself, not an instance, e.g., ``LocalObservation``
    (instead of ``LocalObservation()``).
    """

    observation: Type[Observation]
    """
    The class that represents the "whole" observation (local and global).

    It combines fields from the :py:attr:`.global_observation` and
    :py:attr:`.local_observation` dataclasses. Because these two attributes
    are set at runtime, this class is dynamically created. To simplify usage,
    it supports the methods defined in :py:class:`.BaseObservation`
    (``fields``, ``asdict``, and transformation to NumPy array with
    ``np.asarray``).
    """

    def __init__(
        self,
        local_observation: Type[LocalObservation] = LocalObservation,
        global_observation: Type[GlobalObservation] = GlobalObservation,
    ):
        """
        Create an ObservationManager from a pair of observation classes.

        :param local_observation: The class used to compute local (per-agent)
            observations. Pass the class itself, not an instance.
        :param global_observation: The class used to compute global
            observations. Pass the class itself, not an instance.
        """
        self.global_observation = global_observation
        self.local_observation = local_observation
        # The merged dataclass depends on both configured classes, so it is
        # built once here rather than being a static class.
        self.observation = _create_observation_type(
            global_observation, local_observation
        )

    def compute_agent(self, world: World, agent: Agent) -> LocalObservation:
        """
        Create the local observation for an Agent.
        """
        return self.local_observation.compute(world, agent)

    def compute_global(self, world: World) -> GlobalObservation:
        """
        Create the global observation for the World.
        """
        return self.global_observation.compute(world)

    def compute(self, world: World, agent: Agent) -> Observation:
        """
        Create the full (global + local) observation for an Agent.

        The global observation (:py:meth:`.compute_global`) and the agent's
        local observation (:py:meth:`.compute_agent`) are computed, then
        merged into an instance of :py:attr:`.observation`.
        """
        global_obs = self.compute_global(world)
        local_obs = self.compute_agent(world, agent)
        return self.observation.create(global_obs, local_obs)

    @property
    def shape(self) -> Dict[str, int]:
        """
        Describe the shapes of the various Observations (local, global, merged).

        :rtype: dict
        :return: A dict comprised of: ``agent_state``, ``local_state``, and
            ``global_state``. Each of these fields describes the shape of the
            corresponding observation, i.e., its number of values (the number
            of fields of the corresponding observation class). Note that
            ``agent_state`` refers to the merged (both local and global) case.
        """
        nb_local = len(self.local_observation.fields())
        nb_global = len(self.global_observation.fields())
        return {
            "agent_state": nb_local + nb_global,
            "local_state": nb_local,
            "global_state": nb_global,
        }

    def reset(self):
        """
        Reset the ObservationManager.

        This delegates to the ``reset`` of both the global and the local
        observation classes. It is particularly important to reset the
        memoization process of :py:class:`.GlobalObservation`.
        """
        self.global_observation.reset()
        self.local_observation.reset()