# pylint: disable=missing-module-docstring
from typing import Any, Dict, Optional, Tuple, Union
import gym
import numpy as np
from gym import spaces
from mate import constants as consts
from mate.wrappers.typing import BaseEnvironmentType, WrapperMeta, assert_base_environment
def indices_of_nearest_grid_point(continuous: np.ndarray, grid: np.ndarray) -> np.ndarray:
    """Snap each continuous point to the index of its closest grid point.

    Args:
        continuous: Points to snap; expected shape ``(num_points, dim)``.
        grid: Candidate grid points; expected shape ``(num_grid_points, dim)``.

    Returns:
        Integer array of shape ``(num_points,)`` holding, for every point, the row
        index in ``grid`` with the smallest Euclidean distance. Ties resolve to the
        lowest index (``argmin`` semantics).
    """
    # Broadcast to pairwise offsets of shape (num_grid_points, num_points, dim).
    offsets = continuous - grid[:, None, :]
    distances = np.linalg.norm(offsets, axis=-1)
    # Minimize over the grid axis, leaving one index per input point.
    return distances.argmin(axis=0)
class DiscreteCamera(gym.ActionWrapper, metaclass=WrapperMeta):
    """Wrap the environment so that cameras act in a discrete action space.

    Every camera selects one index out of ``levels * levels``. The index picks a
    point on a uniform ``levels x levels`` grid over ``[-1, 1] x [-1, 1]``. That
    point is multiplied elementwise by
    ``(camera_rotation_step, camera_zooming_step)`` to obtain the continuous
    camera action. Target actions are forwarded untouched.
    """

    def __init__(self, env: BaseEnvironmentType, levels: int = 5) -> None:
        # Validate the wrapped environment and the requested resolution up front.
        assert_base_environment(env)
        assert not isinstance(
            env, DiscreteCamera
        ), f'You should not use wrapper `{self.__class__}` more than once. Got env = {env}.'
        assert levels >= 3 and levels % 2 == 1, (
            f'The discrete level must be an odd number that not less than 3. '
            f'Got levels = {levels}.'
        )
        assert env.num_cameras > 0, 'There must be at least one camera in the environment.'

        super().__init__(env)
        self.levels = levels

        # Scale from the normalized grid to real camera actions:
        # component 0 is rotation, component 1 is zooming.
        self.action_high = np.array(
            (env.camera_rotation_step, env.camera_zooming_step), dtype=np.float64
        )
        self.normalized_action_grid = self.discrete_action_grid(levels=levels)

        # Discrete spaces: one Discrete(levels ** 2) per camera; targets keep theirs.
        self.camera_action_space = spaces.Discrete(levels ** 2)
        self.camera_joint_action_space = spaces.Tuple(
            spaces=tuple(self.camera_action_space for _ in range(env.num_cameras))
        )
        self.action_space = spaces.Tuple(
            spaces=(self.camera_joint_action_space, env.target_joint_action_space)
        )

    def load_config(self, config: Optional[Union[Dict[str, Any], str]] = None) -> None:
        """Reload the wrapped environment's configuration and rebuild this wrapper.

        ``config`` is either a dictionary mapping or a path to a JSON/YAML file. It
        is forwarded to the wrapped environment's ``load_config``.
        """
        self.env.load_config(config=config)
        # Spaces and scaling factors are derived from the environment, so recompute
        # them against the reloaded configuration with the same number of levels.
        self.__init__(self.env, levels=self.levels)  # pylint: disable=unnecessary-dunder-call

    def action(self, action: Tuple[np.ndarray, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        """Translate discrete camera indices into continuous camera actions.

        ``action`` is ``(camera_indices, target_joint_action)``. Each camera index
        selects a row of ``self.normalized_action_grid``, which is then scaled by
        ``self.action_high``. The target part is returned as-is.
        """
        camera_indices, target_joint_action = action
        camera_indices = np.asarray(camera_indices, dtype=np.int64).ravel()
        index_tuple = tuple(camera_indices)
        assert self.camera_joint_action_space.contains(index_tuple), (
            f'Joint action {index_tuple} outside given '
            f'joint action space {self.camera_joint_action_space}.'
        )
        grid_points = self.normalized_action_grid[camera_indices]
        return self.action_high * grid_points, target_joint_action

    def reverse_action(
        self, action: Tuple[np.ndarray, np.ndarray]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Map continuous camera actions back to the nearest discrete indices.

        Each camera's action is divided by ``self.action_high`` and snapped to the
        closest row of ``self.normalized_action_grid``. The target part is returned
        unchanged.
        """
        camera_continuous, target_joint_action = action
        # One row per camera.
        per_camera = np.asarray(camera_continuous, dtype=np.float64).reshape(
            self.num_cameras, consts.CAMERA_ACTION_DIM
        )
        camera_indices = indices_of_nearest_grid_point(
            per_camera / self.action_high, self.normalized_action_grid
        )
        return camera_indices, target_joint_action

    def __str__(self) -> str:
        wrapper_name = self.__class__.__name__
        return f'<{wrapper_name}(levels={self.levels}){self.env}>'

    @staticmethod
    def discrete_action_grid(levels):
        """Return a uniform grid over ``[-1, 1] x [-1, 1]`` with one row per action.

        Row ``i + levels * j`` holds ``(x_i, y_j)``, where
        ``x_i = -1 + 2 * i / (levels - 1)`` and likewise for ``y_j``.
        ``levels`` must be an odd number of at least 3.
        """
        assert levels >= 3 and levels % 2 == 1, (
            f'The discrete level must be an odd number that not less than 3. '
            f'Got levels = {levels}.'
        )
        ticks = np.linspace(-1.0, 1.0, num=levels)
        # meshgrid's default 'xy' indexing puts x along the fastest-varying axis,
        # so the flattened index is i + levels * j.
        grid_x, grid_y = np.meshgrid(ticks, ticks)
        return np.stack((grid_x, grid_y), axis=-1).reshape(-1, consts.CAMERA_ACTION_DIM)
class DiscreteTarget(gym.ActionWrapper, metaclass=WrapperMeta):
    """Wrap the environment to allow targets to use discrete actions.

    Each target selects one of ``levels * levels`` indices. An index picks a point
    of a ``levels x levels`` grid over ``[-1, 1] x [-1, 1]`` that has been squeezed
    onto the closed unit disk (see :meth:`discrete_action_grid`). That point is then
    multiplied by the per-target scale in ``self.action_high`` to obtain the
    continuous target action. Camera actions are forwarded unchanged.
    """

    def __init__(self, env: BaseEnvironmentType, levels: int = 5) -> None:
        """Build the discrete target action spaces for ``env``.

        Args:
            env: The base environment to wrap. It must not already be wrapped with
                :class:`DiscreteTarget`.
            levels: Number of grid ticks per action dimension. It must be odd and at
                least 3, presumably so that the grid contains the zero action.
        """
        assert_base_environment(env)
        assert not isinstance(
            env, DiscreteTarget
        ), f'You should not use wrapper `{self.__class__}` more than once. Got env = {env}.'
        assert levels >= 3 and levels % 2 == 1, (
            f'The discrete level must be an odd number that not less than 3. '
            f'Got levels = {levels}.'
        )
        super().__init__(env)
        self.levels = levels
        # One Discrete(levels * levels) choice per target; the joint space is a tuple of them.
        self.target_action_space = spaces.Discrete(levels * levels)
        self.target_joint_action_space = spaces.Tuple(
            spaces=(self.target_action_space,) * env.num_targets
        )
        # Cameras keep the wrapped environment's action space.
        self.action_space = spaces.Tuple(
            spaces=(env.camera_joint_action_space, self.target_joint_action_space)
        )
        # Per-target scaling factors, shape (num_targets, TARGET_ACTION_DIM). Initialized
        # from env.target_step_size and refreshed from each target's step_size in reset().
        self.action_high = env.target_step_size * np.ones(
            (env.num_targets, consts.TARGET_ACTION_DIM), dtype=np.float64
        )
        self.normalized_action_grid = self.discrete_action_grid(levels=self.levels)

    def load_config(self, config: Optional[Union[Dict[str, Any], str]] = None) -> None:
        """Reinitialize the Multi-Agent Tracking Environment from a dictionary mapping or a JSON/YAML file."""
        self.env.load_config(config=config)
        # Spaces and scaling factors depend on the environment, so rebuild them
        # against the reloaded configuration with the same number of levels.
        self.__init__(self.env, levels=self.levels)  # pylint: disable=unnecessary-dunder-call

    def reset(self, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """Reset the wrapped environment and refresh the per-target action scales.

        After the reset, row ``t`` of ``self.action_high`` is set to
        ``self.targets[t].step_size``. Discrete actions are therefore scaled by each
        target's current step size. Returns whatever the wrapped ``reset`` returns.
        """
        joint_observations = self.env.reset(**kwargs)
        for t, target in enumerate(self.targets):
            self.action_high[t] = target.step_size
        return joint_observations

    def action(self, action: Tuple[np.ndarray, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        """Convert joint action of targets from discrete to continuous.

        ``action`` is ``(camera_joint_action, target_indices)``. Each target index
        selects a row of ``self.normalized_action_grid``, which is scaled by that
        target's row of ``self.action_high``.
        """
        camera_joint_action, target_joint_action_discrete = action
        target_joint_action_discrete = np.asarray(
            target_joint_action_discrete, dtype=np.int64
        ).ravel()
        assert self.target_joint_action_space.contains(tuple(target_joint_action_discrete)), (
            f'Joint action {tuple(target_joint_action_discrete)} outside given '
            f'joint action space {self.target_joint_action_space}.'
        )
        target_joint_action_continuous = (
            self.action_high * self.normalized_action_grid[target_joint_action_discrete]
        )
        return camera_joint_action, target_joint_action_continuous

    def reverse_action(
        self, action: Tuple[np.ndarray, np.ndarray]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Convert joint action of targets from continuous to discrete.

        Each target's action is divided by its row of ``self.action_high`` and
        snapped to the closest row of ``self.normalized_action_grid``. Camera actions
        are returned unchanged.
        """
        camera_joint_action, target_joint_action_continuous = action
        target_joint_action_continuous = np.asarray(
            target_joint_action_continuous, dtype=np.float64
        )
        # One row per target, matching the (num_targets, TARGET_ACTION_DIM) layout of
        # self.action_high. `.shape` is a tuple attribute, so `.reshape` is required here.
        target_joint_action_continuous = target_joint_action_continuous.reshape(
            self.num_targets, consts.TARGET_ACTION_DIM
        )
        target_joint_action_discrete = indices_of_nearest_grid_point(
            target_joint_action_continuous / self.action_high, self.normalized_action_grid
        )
        return camera_joint_action, target_joint_action_discrete

    def __str__(self) -> str:
        return f'<{self.__class__.__name__}(levels={self.levels}){self.env}>'

    @staticmethod
    def discrete_action_grid(levels):
        """Return the normalized target action grid, one row per discrete action.

        A uniform ``levels x levels`` grid over ``[-1, 1] x [-1, 1]`` is built so
        that row ``i + levels * j`` is ``(x_i, y_j)``, with
        ``x_i = -1 + 2 * i / (levels - 1)`` and likewise for ``y_j``.

        Every point ``p`` is then divided by
        ``bound(p) = 1 / cos(phi) = sqrt(1 + (min(|x|, |y|) / max(|x|, |y|)) ** 2)``,
        where ``phi`` is the angle between ``p`` and its nearest coordinate axis. As a
        result, the Euclidean norm of each output point equals the max-abs norm of
        the input point. This maps the square onto the unit disk; the origin stays
        at the origin.
        """
        assert levels >= 3 and levels % 2 == 1, (
            f'The discrete level must be an odd number that not less than 3. '
            f'Got levels = {levels}.'
        )
        action_grid = np.stack(
            np.meshgrid(
                np.linspace(start=-1.0, stop=+1.0, num=levels, endpoint=True),
                np.linspace(start=-1.0, stop=+1.0, num=levels, endpoint=True),
            ),
            axis=-1,
        ).reshape(-1, consts.TARGET_ACTION_DIM)
        angle = np.arctan2(action_grid[..., -1], action_grid[..., 0])
        # Fold the angle into [-pi/4, pi/4): this is the offset phi from the nearest axis.
        # 1 / cos(phi) avoids the division by zero the min/max formula has on the axes.
        bound = 1.0 / np.cos(np.pi * ((angle / np.pi + 0.25) % 0.5 - 0.25))
        normalized_action_grid = action_grid / bound[..., np.newaxis]
        return normalized_action_grid