Source code for mate.agents.mixture

"""Agent wrapper to built a mixture of agents."""

from abc import ABCMeta
from typing import List, Optional

import numpy as np
from gym.utils import seeding

from mate.agents.base import AgentBase, AgentType, CameraAgentBase, TargetAgentBase


__all__ = ['MixtureCameraAgent', 'MixtureTargetAgent']


# pylint: disable-next=too-many-instance-attributes
class MixtureAgentMixIn(AgentBase, metaclass=ABCMeta):
    """Helper class for mixture of agents.

    Randomly choose a underlying agent in candidates at episode start.
    """

    def __init__(self, candidates, weights=None, mixture_seed=None, seed=None):
        """Initialize the agent.
        This function will be called only once on initialization.
        """

        candidates = list(candidates)
        if weights is None:
            weights = [1.0] * len(candidates)
        weights = np.array(weights, dtype=np.float64, copy=True).ravel()
        assert len(candidates) == len(weights), (
            f'The number of sample weights must be equal to the number of agent candidates. '
            f'Got weights = {weights}.'
        )
        assert (weights >= 0.0).all() and weights.any(), (
            f'The sample weights for agent candidates must be non-negative and '
            f'should have at least one positive value. '
            f'Got weights = {weights}.'
        )

        self.candidates = [candidates[a] for a in np.flatnonzero(weights)]
        self.weights = weights[weights > 0.0]
        self.weights /= self.weights.sum()

        super().__init__(seed=seed)

        self._np_random_mixture = None
        self.seed_mixture(seed=mixture_seed)

        self.current_agent = None

    def clone(self) -> AgentType:
        """Clone an independent copy of the agent."""

        candidates = [candidate.clone() for candidate in self.candidates]
        seed = self.np_random.randint(np.iinfo(int).max)
        mixture_seed = self.np_random.randint(np.iinfo(int).max)

        clone = type(self)(
            candidates=candidates, weights=self.weights, mixture_seed=mixture_seed, seed=seed
        )
        return clone

    def spawn(self, num_agents: int) -> List[AgentType]:
        """Spawn new agents."""

        agents = [self.clone() for _ in range(num_agents)]
        mixture_seed = self.np_random.randint(np.iinfo(int).max)
        for agent in agents:
            agent.seed_mixture(seed=mixture_seed)

        return agents

    @property
    def np_random_mixture(self) -> np.random.RandomState:  # pylint: disable=no-member
        """The random number generator of the policy mixture."""

        if self._np_random_mixture is None:
            self.seed_mixture()
        return self._np_random_mixture

    def seed(self, seed: Optional[int] = None) -> List[int]:
        """Set seed for the agent's random number generator.
        This function will be called before the first call of reset().
        """

        seeds = super().seed(seed)

        int_max = np.iinfo(int).max
        for candidate in self.candidates:
            seeds.append(candidate.seed(self.np_random.randint(int_max))[0])

        return seeds

    def seed_mixture(self, seed: Optional[int] = None) -> List[int]:
        """Set seed for the random number generator of the policy mixture."""

        self._np_random_mixture, seed = seeding.np_random(seed=seed)
        return [seed]

    def reset(self, observation):
        """Reset the agent.
        This function will be called immediately after env.reset().
        """

        super().reset(observation)

        self.current_agent = self.np_random_mixture.choice(self.candidates, p=self.weights)
        self.current_agent.reset(observation)

    def observe(self, observation, info=None):
        """The agent observe the environment before sending messages.
        This function will be called before send_requests().
        """

        self.state, self.last_observation, self.last_info, _ = self.check_inputs(observation, info)

        self.current_agent.observe(observation, info)

    def act(self, observation, info=None, deterministic=None):
        """Get the agent action by the observation.
        This function will be called before every env.step().
        """

        self.state, observation, info, _ = self.check_inputs(observation, info)

        action = self.current_agent.act(observation, info, deterministic=deterministic)
        return action

    def send_requests(self):
        """Prepare messages to communicate with other agents in the same team.
        This function will be called after observe() but before receive_requests().
        """

        return self.current_agent.send_requests()

    def receive_requests(self, messages):
        """Receive messages from other agents in the same team.
        This function will be called after send_requests().
        """

        self.last_requests = messages = tuple(messages)

        self.current_agent.receive_requests(messages)

    def send_responses(self):
        """Prepare messages to communicate with other agents in the same team.
        This function will be called after receive_requests().
        """

        return self.current_agent.send_responses()

    def receive_responses(self, messages):
        """Receive messages from other agents in the same team.
        This function will be called after send_responses() but before act().
        """

        self.last_responses = messages = tuple(messages)

        self.current_agent.receive_responses(messages)

    def __str__(self):
        return super().__str__() + str(tuple(zip(self.weights, self.candidates)))

    def __repr__(self):
        return super().__repr__() + repr(tuple(zip(self.weights, self.candidates)))


[docs]class MixtureCameraAgent(MixtureAgentMixIn, CameraAgentBase): """Mixture Camera Agent Randomly choose a underlying camera agent in candidates at episode start. """
[docs]class MixtureTargetAgent(MixtureAgentMixIn, TargetAgentBase): """Mixture Target Agent Randomly choose a underlying target agent in candidates at episode start. """