Source code for coax.envs._connect_four

from gymnasium import Env
from gymnasium.spaces import Discrete, MultiDiscrete
import numpy as np

from .._base.errors import UnavailableActionError, EpisodeDoneError


__all__ = (
    'ConnectFourEnv',
)


class ConnectFourEnv(Env):
    r"""
    An adversarial environment for playing the `Connect-Four game
    <https://en.wikipedia.org/wiki/Connect_Four>`_.

    Attributes
    ----------
    action_space : gymnasium.spaces.Discrete(7)

        The action space.

    observation_space : MultiDiscrete(nvec)

        The state observation space, representing the position of the current
        player's tokens (``s[1:,:,0]``) and the other player's tokens
        (``s[1:,:,1]``) as well as a mask over the space of actions, indicating
        which actions are available to the current player (``s[0,:,0]``) or the
        other player (``s[0,:,1]``).

        **Note:** The "current" player is relative to whose turn it is, which
        means that the entries ``s[:,:,0]`` and ``s[:,:,1]`` swap between
        turns.

    max_time_steps : int

        Maximum number of timesteps within each episode.

    available_actions : array of int

        Array of available actions. This list shrinks when columns saturate.

    win_reward : 1.0

        The reward associated with a win.

    loss_reward : -1.0

        The reward associated with a loss.

    draw_reward : 0.0

        The reward associated with a draw.

    """  # noqa: E501

    # class attributes
    num_rows = 6
    num_cols = 7
    num_players = 2
    win_reward = 1.0
    loss_reward = -win_reward
    draw_reward = 0.0
    action_space = Discrete(num_cols)
    observation_space = MultiDiscrete(
        nvec=np.full((num_rows + 1, num_cols, num_players), 2, dtype='uint8'))
    max_time_steps = int(num_rows * num_cols)
    filters = np.array([
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1]],
        [[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [0, 0, 0, 0]],
        [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
        [[0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
        [[0, 0, 0, 0], [1, 1, 1, 1], [0, 0, 0, 0], [0, 0, 0, 0]],
        [[1, 1, 1, 1], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
        [[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]],
        [[0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0]],
        [[0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 1, 0]],
        [[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1]],
    ], dtype='uint8')

    def __init__(self):
        self._init_state()
    def reset(self):
        r"""
        Reset the environment to the starting position.

        Returns
        -------
        s : 3d-array, shape: [num_rows + 1, num_cols, num_players]

            A state observation, representing the position of the current
            player's tokens (``s[1:,:,0]``) and the other player's tokens
            (``s[1:,:,1]``) as well as a mask over the space of actions,
            indicating which actions are available to the current player
            (``s[0,:,0]``) or the other player (``s[0,:,1]``).

            **Note:** The "current" player is relative to whose turn it is,
            which means that the entries ``s[:,:,0]`` and ``s[:,:,1]`` swap
            between turns.

        """
        self._init_state()
        return self.state
    def step(self, a):
        r"""
        Take one step in the MDP, following the single-player convention from
        gymnasium.

        Parameters
        ----------
        a : int, options: {0, 1, 2, 3, 4, 5, 6}

            The action to be taken. The action is the zero-based count of the
            possible insertion slots, starting from the left of the board.

        Returns
        -------
        s_next : array, shape: [num_rows + 1, num_cols, num_players]

            A next-state observation, representing the position of the current
            player's tokens (``s[1:,:,0]``) and the other player's tokens
            (``s[1:,:,1]``) as well as a mask over the space of actions,
            indicating which actions are available to the current player
            (``s[0,:,0]``) or the other player (``s[0,:,1]``).

            **Note:** The "current" player is relative to whose turn it is,
            which means that the entries ``s[:,:,0]`` and ``s[:,:,1]`` swap
            between turns.

        r : float

            Reward associated with the transition
            :math:`(s, a)\to s_\text{next}`.

            **Note:** Since the "current" player is relative to whose turn it
            is, you need to be careful about aligning the rewards with the
            correct state or state-action pair. In particular, this reward
            :math:`r` is the one associated with :math:`s` and :math:`a`, i.e.
            *not* aligned with :math:`s_\text{next}`.

        done : bool

            Whether the episode is done.

        info : dict or None

            A dict with some extra information (or None).

        """
        if self.done:
            raise EpisodeDoneError("please reset env to start new episode")
        if not self.action_space.contains(a):
            raise ValueError(f"invalid action: {repr(a)}")
        if a not in self.available_actions:
            raise UnavailableActionError("action is not available")

        # swap players
        self._players = np.roll(self._players, -1)

        # update state
        self._state[self._levels[a], a] = self._players[0]
        self._prev_action = a

        # run logic
        self.done, reward = self._done_reward(a)

        return self.state, reward, self.done, {'state_id': self.state_id}
    def render(self, *args, **kwargs):
        r"""
        Render the current state of the environment.

        """
        # lookup for symbols
        symbol = {
            1: u'\u25CF',   # player 1 token (agent)
            2: u'\u25CB',   # player 2 token (adversary)
            -1: u'\u25BD',  # indicator for player 1's last action
            -2: u'\u25BC',  # indicator for player 2's last action
        }

        # render board
        hrule = '+---' * self.num_cols + '+\n'
        board = "  "
        board += "   ".join(
            symbol.get(-(a == self._prev_action) * self._players[1], " ")
            for a in range(self.num_cols))
        board += "  \n"
        board += hrule
        for i in range(self.num_rows):
            board += "| "
            board += " | ".join(
                symbol.get(self._state[i, j], " ")
                for j in range(self.num_cols))
            board += " |\n"
            board += hrule
        board += "  0   1   2   3   4   5   6  \n"  # actions

        print(board)
    @property
    def state(self):
        stacked_layers = np.stack((
            (self._state == self._players[0]).astype('uint8'),
            (self._state == self._players[1]).astype('uint8'),
        ), axis=-1)  # shape: [num_rows, num_cols, num_players]
        available_actions_mask = np.zeros(
            (1, self.num_cols, self.num_players), dtype='uint8')
        available_actions_mask[0, self.available_actions, :] = 1
        return np.concatenate((available_actions_mask, stacked_layers), axis=0)

    @property
    def state_id(self):
        p = str(self._players[0])
        d = '1' if self.done else '0'
        if self._prev_action is None:
            a = str(self.num_cols)
        else:
            a = str(self._prev_action)
        s = ''.join(self._state.ravel().astype('str'))  # base-3 string
        s = '{:017x}'.format(int(s, 3))                 # 17-char hex string
        return p + d + a + s                            # 20-char hex string

    def set_state(self, state_id):
        # decode state id
        p = int(state_id[0], 16)
        d = int(state_id[1], 16)
        a = int(state_id[2], 16)
        assert p in (1, 2)
        assert d in (0, 1)
        assert self.action_space.contains(a) or a == self.num_cols
        self._players[0] = p      # 1 or 2
        self._players[1] = 3 - p  # 2 or 1
        self.done = d == 1
        self._prev_action = None if a == self.num_cols else a
        s = np.base_repr(int(state_id[3:], 16), 3)
        z = np.zeros(self.num_rows * self.num_cols, dtype='uint8')
        z[-len(s):] = np.array(list(s), dtype='uint8')
        self._state = z.reshape((self.num_rows, self.num_cols))
        self._levels = np.full(self.num_cols, self.num_rows - 1, dtype='uint8')
        for j in range(self.num_cols):
            for i in self._state[::-1, j]:
                if i == 0:
                    break
                self._levels[j] -= 1

    @property
    def available_actions(self):
        actions = np.argwhere(
            (self._levels >= 0) & (self._levels < self.num_rows)).ravel()
        assert actions.size <= self.num_cols
        return actions

    @property
    def available_actions_mask(self):
        mask = np.zeros(self.num_cols, dtype='bool')
        mask[self.available_actions] = True
        return mask

    def _init_state(self):
        self._prev_action = None
        self._players = np.array([1, 2], dtype='uint8')
        self._state = np.zeros((self.num_rows, self.num_cols), dtype='uint8')
        self._levels = np.full(self.num_cols, self.num_rows - 1, dtype='uint8')
        self.done = False

    def _done_reward(self, a):
        r"""
        Check whether the last action `a` by the current player resulted in a
        win or draw for player 1 (the agent). This contains the main logic and
        implements the rules of the game.

        """
        assert self.action_space.contains(a)

        # update filling levels
        self._levels[a] -= 1

        s = self._state == self._players[0]
        for i0 in range(2, -1, -1):
            i1 = i0 + 4
            for j0 in range(4):
                j1 = j0 + 4
                if np.any(np.tensordot(self.filters, s[i0:i1, j0:j1]) == 4):
                    return True, 1.0

        # check for a draw
        if len(self.available_actions) == 0:
            return True, 0.0

        # this is what's returned throughout the episode
        return False, 0.0
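

# -----------------------------------------------------------------------------
# Example usage: a minimal self-play sketch, not part of the module above. It
# assumes the class is importable, e.g. as ``coax.envs.ConnectFourEnv``, and it
# only relies on the methods and properties defined in this file.
# -----------------------------------------------------------------------------

import numpy as np
from coax.envs import ConnectFourEnv

env = ConnectFourEnv()
s = env.reset()
done = False

while not done:
    # sample a uniformly random move among the columns that aren't full yet
    a = np.random.choice(env.available_actions)
    s, r, done, info = env.step(a)

env.render()

# the reward r is aligned with the last (s, a) pair, i.e. with the player who
# made the final move (see the docstring of ``step``)
print("final reward:", r)

# ``state_id`` encodes the full position as a 20-char hex string, and
# ``set_state`` restores it, which is convenient for tree search or replay
saved = info['state_id']
env.reset()
env.set_state(saved)
env.render()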