Source code for mate.wrappers.extra_communication_delays
# pylint: disable=missing-module-docstring
import heapq
from typing import Callable, Iterable, Tuple, Union
import gym
import numpy as np
from mate.utils import Message
from mate.wrappers.typing import MateEnvironmentType, WrapperMeta, assert_mate_environment
class ExtraCommunicationDelays(gym.Wrapper, metaclass=WrapperMeta):
    """Add extra message delays to communication channels. (Not used in the evaluation script.)

    Users can use this wrapper to implement a communication channel with random delays.

    Outgoing messages are buffered in a min-heap keyed by their delivery step
    (``episode_step + delay`` at the time they are sent). Buffered messages whose
    delivery step has been reached are forwarded to the wrapped environment the next
    time :meth:`send_messages` is called. If no agent calls it on a given step, due
    messages wait until the next call.
    """

    def __init__(
        self,
        env: MateEnvironmentType,
        delay: Union[int, Callable[[MateEnvironmentType, Message], int]] = 3,
    ) -> None:
        """Wrap ``env`` with a delayed communication channel.

        Args:
            env: The MATE environment to wrap.
            delay: Either a constant positive integer (number of steps to delay every
                message), or a callable ``(env, message) -> int`` evaluated once per
                message with the unwrapped environment, e.g. to sample random delays.

        Raises:
            AssertionError: If ``env`` is not a MATE environment or ``delay`` is neither
                callable nor a positive integer (only when assertions are enabled).
        """
        assert_mate_environment(env)
        assert callable(delay) or (isinstance(delay, int) and delay > 0), (
            f'The argument `delay` should be a callable function or a constant positive integer. '
            f'Got delay = {delay}.'
        )

        super().__init__(env)

        # A function with signature: (env, message) -> int
        # or a constant positive integer.
        self.delay = delay
        # Min-heap of (delivery_step, sequence_number, message) entries.
        self.heap = []
        # Monotonic tie-breaker for heap entries that share a delivery step.
        self._heap_counter = 0

    def reset(self, **kwargs) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
        """Discard all undelivered messages, then reset the wrapped environment.

        Clearing the heap ensures no message buffered in a previous episode is
        delivered in the new one. Keyword arguments are forwarded to ``env.reset``
        and its result is returned unchanged.
        """
        self.heap = []
        return self.env.reset(**kwargs)

    def send_messages(self, messages: Union[Message, Iterable[Message]]) -> None:
        """Buffer the messages from an agent to others in the same team.

        The environment will send the messages to recipients through method
        receive_messages(), and also info field of step() results.

        Each given message is scheduled for delivery ``delay`` steps after the current
        ``episode_step``. Afterwards, every buffered message (from any team) whose
        delivery step has been reached is forwarded to the wrapped environment. Messages
        due on the same step keep their sending order.

        Args:
            messages: A single message or an iterable of messages, all from the same team.

        Raises:
            AssertionError: If the given messages come from more than one team
                (only when assertions are enabled).
        """
        if isinstance(messages, Message):
            messages = (messages,)

        messages = list(messages)
        assert (
            len({m.team for m in messages}) <= 1
        ), f'All messages must be from the same team. Got messages = {messages}.'

        current_step = self.episode_step

        for message in messages:
            if callable(self.delay):
                delay = self.delay(self.unwrapped, message)
            else:
                delay = self.delay
            # The sequence number breaks ties between messages due on the same step.
            # Without it, heapq would fall back to comparing Message objects, whose
            # ordering (if any) is not defined here. It also keeps same-step messages
            # in sending order.
            heapq.heappush(self.heap, (current_step + delay, self._heap_counter, message))
            self._heap_counter += 1

        due = []
        while self.heap and self.heap[0][0] <= current_step:
            _, _, message = heapq.heappop(self.heap)
            due.append(message)

        # The heap is shared by both teams, so one flush may release messages from
        # more than one team. The wrapped env's send_messages presumably enforces the
        # same single-team contract asserted above (NOTE(review): confirm against the
        # env), so forward each team's messages in a separate call. This preserves
        # the original single call, including an empty one, when at most one team
        # is present.
        by_team = {}
        for message in due:
            by_team.setdefault(message.team, []).append(message)

        if len(by_team) <= 1:
            self.env.send_messages(due)
        else:
            for team_messages in by_team.values():
                self.env.send_messages(team_messages)