Source code for mate.agents.greedy

"""Built-in greedy rule-based agents."""

import numpy as np

from mate.agents.base import CameraAgentBase, TargetAgentBase
from mate.constants import MAX_CAMERA_VIEWING_ANGLE, NUM_WAREHOUSES, WAREHOUSES
from mate.utils import normalize_angle, sin_deg

__all__ = ['GreedyCameraAgent', 'GreedyTargetAgent']

class GreedyCameraAgent(CameraAgentBase):  # pylint: disable=too-many-instance-attributes
    """Greedy Camera Agent

    Arbitrarily tracks the nearest target. If no target found, use previous
    action or generate a new random action.
    """

    def __init__(
        self, seed=None, memory_period=25, filterout_unloaded=False, filterout_beyond_range=True
    ):
        """Initialize the agent.

        This function will be called only once on initialization.

        Args:
            seed: Forwarded unchanged to the base-class constructor.
            memory_period: Number of ``process_messages()`` calls a target stays
                in memory after it was last reported as tracked.
            filterout_unloaded: If true, ``act()`` ignores remembered targets
                that are not loaded, unless they have never been seen loaded.
            filterout_beyond_range: If true, ``act()`` ignores remembered
                targets farther away than ``range_factor * max_sight_range``,
                and the same kind of range test is applied when exchanging
                messages with teammates.
        """
        super().__init__(seed=seed)

        self.filterout_unloaded = filterout_unloaded
        self.filterout_beyond_range = filterout_beyond_range
        self.range_factor = 1.1  # 110%

        # Per-target memory; all three are (re)built in reset().
        self.memory = None  # last known state of every target
        self.time2forget = None  # remaining steps before a target is forgotten
        self.never_loaded = None  # True until the target has been seen loaded
        self.memory_period = memory_period
        self.prev_action = self.DEFAULT_ACTION

        # Communication bookkeeping.
        self.neighboring_teammate_states = {}  # sender index -> teammate state
        self.message2send = {}  # pending outgoing message content
        self.communication_delay = None  # per-teammate cool-down counters

    def reset(self, observation):
        """Reset the agent.

        This function will be called immediately after env.reset().
        """
        super().reset(observation)

        target_states, tracked_bits = self.get_all_opponent_states(observation)
        self.memory = list(target_states)
        # Targets tracked right now get a full memory period, all others zero.
        self.time2forget = self.memory_period * np.asarray(tracked_bits, dtype=np.int64)
        # `np.bool_` rather than `np.bool8`: the latter alias was removed in NumPy 2.0.
        self.never_loaded = np.ones(self.num_targets, dtype=np.bool_)
        self.prev_action = self.DEFAULT_ACTION

        self.neighboring_teammate_states.clear()
        self.message2send.clear()
        self.communication_delay = np.zeros(self.num_teammates, dtype=np.int64)
        # Queue a copy of our own state for the next send_responses() call.
        self.message2send['state'] = self.state.copy()

    def observe(self, observation, info=None):
        """The agent observe the environment before sending messages.

        This function will be called before send_responses().
        """
        self.state, observation, info, messages = self.check_inputs(observation, info)

        self.process_messages(observation, messages)

    def act(self, observation, info=None, deterministic=None):
        """Get the agent action by the observation.

        This function will be called before every env.step().

        Arbitrarily track the nearest target. If no target found, use previous
        action or generate a new random action.

        Args:
            observation: Raw observation, validated through ``check_inputs()``.
            info: Optional info, validated through ``check_inputs()``.
            deterministic: Accepted but not used by this agent.

        Returns:
            The action to take; it is also stored in ``self.prev_action``.
        """
        self.state, observation, info, _ = self.check_inputs(observation, info)

        # Only targets whose forget counter has not reached zero are candidates.
        tracked_targets = [self.memory[t] for t in np.flatnonzero(self.time2forget)]
        if self.filterout_beyond_range:
            threshold = self.range_factor * self.state.max_sight_range
            tracked_targets = [ts for ts in tracked_targets if (ts - self.state).norm < threshold]
        if self.filterout_unloaded:
            tracked_targets = [
                ts for ts in tracked_targets if ts.is_loaded or self.never_loaded[ts.index]
            ]

        if tracked_targets:
            action = self.act_from_target_states(tracked_targets)
        elif self.np_random.binomial(1, 0.1) != 0:
            # With probability 0.1, draw a fresh random action ...
            action = self.action_space.sample()
        else:
            # ... otherwise repeat the previous action.
            action = self.prev_action

        self.prev_action = action
        return action

    def process_messages(self, observation, messages):  # pylint: disable=unused-argument
        """Process observation and prepare messages to teammates."""
        # Age every memory entry by one step, never going below zero.
        self.time2forget = np.maximum(self.time2forget - 1, 0, dtype=np.int64)

        target_states, tracked_bits = self.get_all_opponent_states(observation)
        for t in np.flatnonzero(tracked_bits):
            self.time2forget[t] = self.memory_period
            self.memory[t] = target_states[t]
            if target_states[t].is_loaded:
                self.never_loaded[t] = False
            # Queue the fresh target state for the next send_responses() call.
            self.message2send.setdefault('target_states', []).append(target_states[t])

    def act_from_target_states(self, target_states):
        """Place the selected target at the center of the field of view.

        Args:
            target_states: Non-empty collection of candidate target states.

        Returns:
            An array ``[orientation change, viewing angle change]`` clipped to
            the bounds of ``self.action_space``.
        """
        assert (
            len(target_states) > 0
        ), 'You should provide at least one target to compute the action.'

        def select_target():
            """Select the nearest target."""
            return min(target_states, key=lambda ts: (ts - self.state).norm)

        def best_orientation():
            """Return the direction from this camera to the selected target."""
            return (target_state - self.state).angle

        def best_viewing_angle():
            """Return the viewing angle to aim for given the target distance."""
            distance = (target_state - self.state).norm

            if (
                distance * (1.0 + sin_deg(self.state.min_viewing_angle / 2.0))
                >= self.state.max_sight_range
            ):
                return self.state.min_viewing_angle

            widest_angle = min(180.0, MAX_CAMERA_VIEWING_ANGLE)
            area_product = self.state.viewing_angle * np.square(self.state.sight_range)
            if distance <= np.sqrt(area_product / 180.0) / 2.0:
                return widest_angle

            # Start from the widest angle and run 20 rounds, each recomputing
            # the sight range for the current angle and then the angle that
            # yields the same `angle * sight_range**2` product as the current
            # camera state.
            best = widest_angle
            for _ in range(20):
                sight_range = distance * (1.0 + sin_deg(min(best / 2.0, 90.0)))
                best = area_product / np.square(sight_range)
            return np.clip(
                best, a_min=self.state.min_viewing_angle, a_max=MAX_CAMERA_VIEWING_ANGLE
            )

        # The helper closures above read `target_state` from this scope.
        target_state = select_target()
        return np.asarray(
            [
                normalize_angle(best_orientation() - self.state.orientation),
                best_viewing_angle() - self.state.viewing_angle,
            ]
        ).clip(min=self.action_space.low, max=self.action_space.high)

    def send_responses(self):
        """Prepare messages to communicate with other agents in the same team.

        This function will be called before receive_responses().

        Send the newest target states to teammates if necessary.

        Returns:
            A list of packed messages, possibly empty.
        """
        messages = []

        # Tick down every per-teammate cool-down counter.
        self.communication_delay = np.maximum(self.communication_delay - 1, 0, dtype=np.int64)

        if len(self.message2send) > 0:
            for c in range(self.num_cameras):
                # Skip ourselves and teammates that are still cooling down.
                if c == self.index or self.communication_delay[c] > 0:
                    continue

                content = self.message2send.copy()
                if 'target_states' in content:
                    if c in self.neighboring_teammate_states and self.filterout_beyond_range:
                        # Only forward targets close enough to that teammate.
                        teammate_state = self.neighboring_teammate_states[c]
                        threshold = self.range_factor * teammate_state.max_sight_range
                        content['target_states'] = [
                            ts
                            for ts in content['target_states']
                            if (ts - teammate_state).norm < threshold
                        ]
                        if len(content['target_states']) == 0:
                            del content['target_states']
                    else:
                        # NOTE(review): this branch is also taken whenever
                        # `filterout_beyond_range` is False, so no target states
                        # are shared at all in that mode -- confirm this is
                        # intended.
                        del content['target_states']

                if len(content) > 0:
                    messages.append(self.pack_message(recipient=c, content=content))
                    # Random cool-down before messaging this teammate again.
                    delay = self.np_random.randint(self.memory_period // 4, 2 * self.memory_period)
                    self.communication_delay[c] = delay

            self.message2send.clear()

        return messages

    def receive_responses(self, messages):
        """Receive messages from other agents in the same team.

        This function will be called after send_responses() but before act().

        Receive and process messages from teammates.
        """
        self.last_responses = tuple(messages)

        for message in self.last_responses:
            if 'state' in message.content:
                teammate_state = message.content['state']
                is_neighboring = True
                if self.filterout_beyond_range:
                    distance = (teammate_state - self.state).norm
                    threshold = (
                        self.state.max_sight_range
                        + self.range_factor * teammate_state.max_sight_range
                    )
                    is_neighboring = distance < threshold

                # Keep the neighbor table in sync with the range test: record
                # neighbors, drop senders that are not (or no longer) in range.
                if is_neighboring:
                    self.neighboring_teammate_states[message.sender] = teammate_state
                else:
                    self.neighboring_teammate_states.pop(message.sender, None)

            # Merge target states reported by the teammate into our own memory.
            for target_state in message.content.get('target_states', []):
                self.memory[target_state.index] = target_state
                self.time2forget[target_state.index] = self.memory_period
                if target_state.is_loaded:
                    self.never_loaded[target_state.index] = False
class GreedyTargetAgent(TargetAgentBase):  # pylint: disable=too-many-instance-attributes
    """Greedy Target Agent

    Arbitrarily runs towards the destination (desired warehouse) with some noise.
    """

    def __init__(self, seed=None, noise_scale=0.5):
        """Initialize the agent.

        This function will be called only once on initialization.

        Args:
            seed: Forwarded unchanged to the base-class constructor.
            noise_scale: Multiplier applied to a random action-space sample to
                produce the noise added to the greedy action.
        """
        super().__init__(seed=seed)

        self.noise_scale = float(noise_scale)
        self.goal_bits = None
        self.prev_state = None
        self.prev_noise = None

        # Warehouses not yet known to be empty; shrinks as information arrives.
        self.non_empty_warehouses = set(range(NUM_WAREHOUSES))
        self.need_communication = False

    @property
    def goal(self):
        """Index of the current warehouse."""
        if self.goal_bits is not None and self.goal_bits.any():
            return np.flatnonzero(self.goal_bits)[0]
        return None

    @property
    def goal_location(self):
        """Location of the current warehouse."""
        goal = self.goal
        if goal is not None:
            return WAREHOUSES[goal]
        return None

    def reset(self, observation):
        """Reset the agent.

        This function will be called immediately after env.reset().
        """
        super().reset(observation)

        self.prev_state = self.state
        # Scale the initial noise by `noise_scale`, exactly as act() does, so
        # that `noise_scale=0` really yields a noise-free agent.
        self.prev_noise = self.noise_scale * self.action_space.sample()
        self.goal_bits = self.state.goal_bits.copy()

        self.non_empty_warehouses = set(range(NUM_WAREHOUSES))
        self.need_communication = False

    def observe(self, observation, info=None):
        """The agent observe the environment before sending messages.

        This function will be called before send_responses().
        """
        self.state, observation, info, messages = self.check_inputs(observation, info)

        self.process_messages(observation, messages)

    def act(self, observation, info=None, deterministic=None):
        """Get the agent action by the observation.

        This function will be called before every env.step().

        Arbitrarily run towards the warehouse with some noise.

        Args:
            observation: Raw observation, validated through ``check_inputs()``.
            info: Optional info, validated through ``check_inputs()``.
            deterministic: Accepted but not used by this agent.

        Returns:
            The noisy action, clipped to the bounds of ``self.action_space``.
        """
        self.state, observation, info, _ = self.check_inputs(observation, info)

        # A goal reported by the environment always takes precedence.
        if self.state.goal_bits.any():
            self.goal_bits = self.state.goal_bits

        # Without any goal, or with a self-chosen goal that is now known to be
        # empty, pick a random warehouse among those still believed non-empty.
        if self.goal is None or (
            not self.state.goal_bits.any() and self.goal not in self.non_empty_warehouses
        ):
            self.goal_bits = np.zeros_like(self.state.goal_bits)
            if len(self.non_empty_warehouses) > 0:
                new_goal = self.np_random.choice(list(self.non_empty_warehouses))
                self.goal_bits[new_goal] = 1

        prev_actual_action = self.state.location - self.prev_state.location

        if self.goal is not None:
            action = self.goal_location - self.state.location
        else:
            action = np.zeros_like(self.state.location)
        # Shorten the move to at most one step.
        step_size = np.linalg.norm(action)
        if step_size > self.state.step_size:
            action *= self.state.step_size / step_size

        # Resample the noise rarely while the agent is actually moving, and
        # much more often when the last step covered little distance.
        prob = 0.05 if np.linalg.norm(prev_actual_action) > 0.2 * self.state.step_size else 0.75
        if self.np_random.binomial(1, prob) != 0:
            noise = self.noise_scale * self.action_space.sample()
        else:
            noise = self.prev_noise

        action = (action + noise).clip(min=self.action_space.low, max=self.action_space.high)
        self.prev_state = self.state
        self.prev_noise = noise
        return action

    def process_messages(self, observation, messages):  # pylint: disable=unused-argument
        """Process observation and prepare messages to teammates."""
        seen_empty_warehouses = set(np.flatnonzero(self.state.empty_bits))
        # Only newly discovered empty warehouses trigger a message to teammates.
        if len(seen_empty_warehouses.intersection(self.non_empty_warehouses)) > 0:
            self.non_empty_warehouses.difference_update(seen_empty_warehouses)
            self.need_communication = True

    def send_responses(self):
        """Prepare messages to communicate with other agents in the same team.

        This function will be called before receive_responses().

        Send indices of non-empty warehouses to teammate if necessary.

        Returns:
            A list holding at most one packed message.
        """
        messages = []

        if self.need_communication:
            content = {'non_empty_warehouses': self.non_empty_warehouses.copy()}
            messages.append(self.pack_message(content=content))  # broadcasting
            self.need_communication = False

        return messages

    def receive_responses(self, messages):
        """Receive messages from other agents in the same team.

        This function will be called after send_responses() but before act().

        Receive and process messages from teammates. Messages that carry no
        ``'non_empty_warehouses'`` entry are ignored.
        """
        self.last_responses = tuple(messages)

        for message in self.last_responses:
            if 'non_empty_warehouses' in message.content:
                # A warehouse stays "non-empty" only if every teammate agrees.
                self.non_empty_warehouses.intersection_update(
                    message.content['non_empty_warehouses']
                )