# Source code for mate.wrappers.more_training_information

# pylint: disable=missing-module-docstring

import itertools
from typing import List, Tuple, Union

import gym
import numpy as np

from mate import constants as consts
from mate.wrappers.typing import BaseEnvironmentType, WrapperMeta, assert_base_environment


class MoreTrainingInformation(gym.Wrapper, metaclass=WrapperMeta):
    """Add more environment and agent information to the info field of step(),
    enabling full observability of the environment.

    (Not used in the evaluation script.)
    """

    def __init__(self, env: BaseEnvironmentType) -> None:
        assert_base_environment(env)
        # Reject an ``env`` that is (per ``isinstance``) already wrapped by this class.
        assert not isinstance(
            env, MoreTrainingInformation
        ), f'You should not use wrapper `{self.__class__}` more than once.'

        super().__init__(env)

    def step(
        self, action: Tuple[np.ndarray, np.ndarray]
    ) -> Union[
        # original form
        Tuple[
            Tuple[np.ndarray, np.ndarray], Tuple[float, float], bool, Tuple[List[dict], List[dict]]
        ],
        # repeated reward and individual done
        Tuple[
            Tuple[np.ndarray, np.ndarray],
            Tuple[List[float], List[float]],
            Tuple[List[bool], List[bool]],
            Tuple[List[dict], List[dict]],
        ],
    ]:
        """Step the wrapped environment and enrich every agent's info dict.

        Reward and done are ignored by this wrapper, so both return forms shown in
        the annotation pass through unchanged. Only the info dicts are mutated in
        place, and the wrapped environment's result tuple is returned as-is.

        Camera infos gain ``num_tracked`` and ``is_sensed``. Target infos gain
        ``goal``, ``goal_distance``, ``warehouse_distances``, ``individual_done``,
        ``is_tracked`` and ``is_colliding``. Every info additionally receives
        copies of the global state, the private agent states, the obstacle
        states, all view masks and the cargo bookkeeping arrays.

        Args:
            action: Joint actions for cameras and targets, forwarded to ``self.env.step``.

        Returns:
            The result of ``self.env.step(action)``, with its info dicts updated.
        """
        (
            (camera_joint_observation, target_joint_observation),
            _,
            _,
            (camera_infos, target_infos),
        ) = results = self.env.step(action)

        # Order matters only for ``self.state()``, which is called last, as before.
        self._update_camera_infos(camera_infos)
        self._update_target_infos(target_infos)
        self._add_full_observability(
            camera_infos, target_infos, camera_joint_observation, target_joint_observation
        )

        return results

    def _update_camera_infos(self, camera_infos: List[dict]) -> None:
        """Add per-camera view statistics to each camera's info dict (in place)."""
        for c, camera_info in enumerate(camera_infos):
            camera_info.update(
                # Count of True entries in row ``c`` of the camera->target view mask
                # (presumably the number of targets camera ``c`` sees).
                num_tracked=self.camera_target_view_mask[c, ...].sum(),
                # Whether any entry of column ``c`` of the target->camera view mask is set.
                is_sensed=self.target_camera_view_mask[..., c].any(),
            )

    def _update_target_infos(self, target_infos: List[dict]) -> None:
        """Add goal, warehouse-distance and status fields to each target's info dict (in place)."""
        for t, target_info in enumerate(target_infos):
            goal = self.target_goals[t]
            # Distances reduced by WAREHOUSE_RADIUS and clipped at 0 (presumably the
            # distance to each warehouse's boundary rather than its center).
            warehouse_distances = np.maximum(
                self.target_warehouse_distances[t] - consts.WAREHOUSE_RADIUS,
                0.0,
                dtype=np.float64,
            )
            # A negative goal presumably means "no goal"; fall back to half the terrain width.
            goal_distance = (
                warehouse_distances[goal] if goal >= 0 else consts.TERRAIN_WIDTH / 2.0
            )
            target_info.update(
                goal=goal,
                goal_distance=goal_distance,
                warehouse_distances=warehouse_distances,
                individual_done=self.target_dones[t],
                is_tracked=self.camera_target_view_mask[..., t].any(),
                is_colliding=self.targets[t].is_colliding,
            )

    def _add_full_observability(
        self,
        camera_infos: List[dict],
        target_infos: List[dict],
        camera_joint_observation: np.ndarray,
        target_joint_observation: np.ndarray,
    ) -> None:
        """Attach copies of the global environment state to every agent's info dict (in place)."""
        # Private agent states start right after the preserved leading observation dims.
        offset = consts.PRESERVED_DIM
        camera_states_private = camera_joint_observation[
            ..., offset : offset + consts.CAMERA_STATE_DIM_PRIVATE
        ]
        target_states_private = target_joint_observation[
            ..., offset : offset + consts.TARGET_STATE_DIM_PRIVATE
        ]
        remaining_cargo_counts = self.remaining_cargoes.sum(axis=-1)

        state = self.state()
        for info in itertools.chain(camera_infos, target_infos):
            # Each info gets its own copies, so mutating one agent's arrays cannot
            # affect another agent's info or the environment's internal buffers.
            info.update(
                state=state.copy(),
                camera_states=camera_states_private.copy(),
                target_states=target_states_private.copy(),
                obstacle_states=self.obstacle_states.copy(),
                camera_target_view_mask=self.camera_target_view_mask.copy(),
                camera_obstacle_view_mask=self.camera_obstacle_view_mask.copy(),
                target_camera_view_mask=self.target_camera_view_mask.copy(),
                target_obstacle_view_mask=self.target_obstacle_view_mask.copy(),
                target_target_view_mask=self.target_target_view_mask.copy(),
                remaining_cargoes=self.remaining_cargoes.copy(),
                remaining_cargo_counts=remaining_cargo_counts.copy(),
                awaiting_cargo_counts=self.awaiting_cargo_counts.copy(),
            )