Source code for mate.wrappers.enhanced_observation

# pylint: disable=missing-module-docstring

from typing import Any, Dict, Optional, Tuple, Union

import gym
import numpy as np

from mate import constants as consts
from mate.wrappers.typing import BaseEnvironmentType, WrapperMeta, assert_base_environment


try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal


[docs]class EnhancedObservation(gym.ObservationWrapper, metaclass=WrapperMeta): """Enhance the agent's observation, which sets all observation mask to True. The targets can observe the empty status of all warehouses even when far away. """ def __init__( self, env: BaseEnvironmentType, team: Literal['both', 'camera', 'target', 'none'] = 'both' ) -> None: assert_base_environment(env) assert team in ( 'both', 'camera', 'target', 'none', ), f'Invalid argument team {team!r}. Expect one of {("both", "camera", "target", "none")}.' # pylint: disable-next=import-outside-toplevel from mate.wrappers.relative_coordinates import RelativeCoordinates assert not isinstance(env, RelativeCoordinates), ( f'You should use wrapper `{self.__class__}` before `RelativeCoordinates`. ' f'Please check your wrapper order. ' f'Got env = {env}.' ) # pylint: disable-next=import-outside-toplevel from mate.wrappers.rescaled_observation import RescaledObservation assert not isinstance(env, RescaledObservation), ( f'You should use wrapper `{self.__class__}` before `RescaledObservation`. ' f'Please check your wrapper order. ' f'Got env = {env}.' ) super().__init__(env) self.team = team self.enhanced_camera = self.team in ('camera', 'both') self.enhanced_target = self.team in ('target', 'both') numbers = (env.num_cameras, env.num_targets, env.num_obstacles) self.camera_slices = consts.camera_observation_slices_of(*numbers) self.target_slices = consts.target_observation_slices_of(*numbers) self.target_indices = consts.target_observation_indices_of(*numbers) self.target_empty_bits_slice = slice( self.target_indices[2] - consts.NUM_WAREHOUSES, self.target_indices[2] )
[docs] def load_config(self, config: Optional[Union[Dict[str, Any], str]] = None) -> None: """Reinitialize the Multi-Agent Tracking Environment from a dictionary mapping or a JSON/YAML file.""" self.env.load_config(config=config) self.__init__(self.env, team=self.team) # pylint: disable=unnecessary-dunder-call
[docs] def observation( self, observation: Tuple[np.ndarray, np.ndarray] ) -> Tuple[np.ndarray, np.ndarray]: if not (self.enhanced_camera or self.enhanced_target): return observation camera_joint_observation, target_joint_observation = observation offset = consts.PRESERVED_DIM camera_states_public = camera_joint_observation[ ..., offset : offset + consts.CAMERA_STATE_DIM_PUBLIC ] target_states_public = target_joint_observation[ ..., offset : offset + consts.TARGET_STATE_DIM_PUBLIC ] camera_states_public_flagged = np.hstack( [camera_states_public, np.ones((self.num_cameras, 1))] ) target_states_public_flagged = np.hstack( [target_states_public, np.ones((self.num_targets, 1))] ) obstacle_states_flagged = self.obstacle_states_flagged if self.enhanced_camera: camera_joint_observation[ ..., self.camera_slices['opponent_states_with_mask'] ] = target_states_public_flagged.ravel()[np.newaxis, ...] camera_joint_observation[ ..., self.camera_slices['obstacle_states_with_mask'] ] = obstacle_states_flagged.ravel()[np.newaxis, ...] camera_joint_observation[ ..., self.camera_slices['teammate_states_with_mask'] ] = camera_states_public_flagged.ravel()[np.newaxis, ...] if self.enhanced_target: target_joint_observation[..., self.target_empty_bits_slice] = np.logical_not( self.remaining_cargoes ).all(axis=-1)[np.newaxis, ...] target_joint_observation[ ..., self.target_slices['opponent_states_with_mask'] ] = camera_states_public_flagged.ravel()[np.newaxis, ...] target_joint_observation[ ..., self.target_slices['obstacle_states_with_mask'] ] = obstacle_states_flagged.ravel()[np.newaxis, ...] target_joint_observation[ ..., self.target_slices['teammate_states_with_mask'] ] = target_states_public_flagged.ravel()[np.newaxis, ...] return ( camera_joint_observation.astype(np.float64), target_joint_observation.astype(np.float64), )
def __str__(self) -> str: return f'<{type(self).__name__}(team={self.team}){self.env}>'