kaggle-environments 1.16.11__py2.py3-none-any.whl → 1.17.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaggle-environments might be problematic. Click here for more details.
- kaggle_environments/__init__.py +18 -8
- kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py +14 -7
- kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py +5 -4
- kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py +33 -17
- kaggle_environments/envs/open_spiel/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/connect_four/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four.js +296 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy.py +86 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy_test.py +57 -0
- kaggle_environments/envs/open_spiel/observation.py +133 -0
- kaggle_environments/envs/open_spiel/open_spiel.py +416 -0
- kaggle_environments/envs/open_spiel/proxy.py +139 -0
- kaggle_environments/envs/open_spiel/proxy_test.py +64 -0
- kaggle_environments/envs/open_spiel/test_open_spiel.py +18 -0
- {kaggle_environments-1.16.11.dist-info → kaggle_environments-1.17.3.dist-info}/METADATA +25 -13
- {kaggle_environments-1.16.11.dist-info → kaggle_environments-1.17.3.dist-info}/RECORD +21 -10
- {kaggle_environments-1.16.11.dist-info → kaggle_environments-1.17.3.dist-info}/WHEEL +1 -1
- {kaggle_environments-1.16.11.dist-info → kaggle_environments-1.17.3.dist-info}/entry_points.txt +0 -0
- {kaggle_environments-1.16.11.dist-info → kaggle_environments-1.17.3.dist-info/licenses}/LICENSE +0 -0
- {kaggle_environments-1.16.11.dist-info → kaggle_environments-1.17.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""Change Connect Four state and action string representations."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from ... import proxy
|
|
7
|
+
import pyspiel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ConnectFourState(proxy.State):
    """Connect Four state proxy that exposes states and actions as JSON.

    Players are rendered as 'x' (player 0) and 'o' (player 1); the board is
    reported with row 0 as the bottom row, and actions are ``{'col': c}``.
    """

    def _player_string(self, player: int) -> str:
        """Map an OpenSpiel player id to its display string.

        Negative ids are OpenSpiel special players and map to the lower-cased
        ``pyspiel.PlayerId`` name; 0/1 map to 'x'/'o'.

        Raises:
            ValueError: For any other non-negative id.
        """
        tokens = {0: 'x', 1: 'o'}
        if player < 0:
            return pyspiel.PlayerId(player).name.lower()
        if player in tokens:
            return tokens[player]
        raise ValueError(f'Invalid player: {player}')

    def state_dict(self) -> dict[str, Any]:
        """Return the state as a dict with board, current player and outcome.

        ``winner`` is None while the game is running, otherwise 'x', 'o' or
        'draw' depending on which player has the larger return.
        """
        # to_string() lists rows top-first; flip so index 0 is the bottom row.
        lines = self.to_string().strip().split('\n')
        board = [list(line) for line in lines[::-1]]
        terminal = self.is_terminal()
        winner = None
        if terminal:
            returns = self.returns()
            if returns[0] > returns[1]:
                winner = 'x'
            elif returns[1] > returns[0]:
                winner = 'o'
            else:
                winner = 'draw'
        return {
            'board': board,
            'current_player': self._player_string(self.current_player()),
            'is_terminal': terminal,
            'winner': winner,
        }

    def to_json(self) -> str:
        """Serialize ``state_dict()`` to a JSON string."""
        return json.dumps(self.state_dict())

    def action_to_dict(self, action: int) -> dict[str, Any]:
        """Wrap a column index as an action dict."""
        return {'col': action}

    def action_to_json(self, action: int) -> str:
        """Serialize a column index as a JSON action."""
        return json.dumps(self.action_to_dict(action))

    def dict_to_action(self, action_dict: dict[str, Any]) -> int:
        """Extract the column index from an action dict."""
        return int(action_dict['col'])

    def json_to_action(self, action_json: str) -> int:
        """Parse a JSON action back into a column index."""
        return self.dict_to_action(json.loads(action_json))

    def observation_string(self, player: int) -> str:
        """Return the JSON observation for ``player``."""
        return self.observation_json(player)

    def observation_json(self, player: int) -> str:
        """Return the full JSON state; the player argument is ignored."""
        del player
        return self.to_json()

    def __str__(self):
        return self.to_json()
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class ConnectFourGame(proxy.Game):
    """Proxy around OpenSpiel's ``connect_four`` that produces JSON-speaking states."""

    def __init__(self, params: Any | None = None):
        # Falsy params (None or {}) load the game with default parameters.
        super().__init__(
            pyspiel.load_game('connect_four', params or {}),
            short_name='connect_four_proxy',
            long_name='Connect Four (proxy)',
        )

    def new_initial_state(self, *args) -> ConnectFourState:
        """Wrap a fresh initial state of the underlying game in ConnectFourState."""
        initial = self.__wrapped__.new_initial_state(*args)
        return ConnectFourState(initial, game=self)


# Register the proxy with pyspiel using the type reported by a default instance.
pyspiel.register_game(ConnectFourGame().get_type(), ConnectFourGame)
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Test for proxied Connect Four game."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
from absl.testing import absltest
|
|
6
|
+
from absl.testing import parameterized
|
|
7
|
+
import pyspiel
|
|
8
|
+
from . import connect_four_proxy as connect_four
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
NUM_ROWS = 6
NUM_COLS = 7


class ConnectFourTest(parameterized.TestCase):
    """Tests for the JSON-speaking Connect Four proxy."""

    def test_game_is_registered(self):
        game = pyspiel.load_game('connect_four_proxy')
        self.assertIsInstance(game, connect_four.ConnectFourGame)

    def test_random_sim(self):
        game = connect_four.ConnectFourGame()
        pyspiel.random_sim_test(game, num_sims=10, serialize=False, verbose=False)

    def test_state_to_json(self):
        game = connect_four.ConnectFourGame()
        state = game.new_initial_state()
        json_state = json.loads(state.to_json())
        expected_board = [['.'] * NUM_COLS for _ in range(NUM_ROWS)]
        self.assertEqual(json_state['board'], expected_board)
        self.assertEqual(json_state['current_player'], 'x')
        state.apply_action(3)
        json_state = json.loads(state.to_json())
        expected_board[0][3] = 'x'
        self.assertEqual(json_state['board'], expected_board)
        self.assertEqual(json_state['current_player'], 'o')
        state.apply_action(2)
        json_state = json.loads(state.to_json())
        expected_board[0][2] = 'o'
        self.assertEqual(json_state['board'], expected_board)
        self.assertEqual(json_state['current_player'], 'x')
        state.apply_action(2)
        json_state = json.loads(state.to_json())
        expected_board[1][2] = 'x'
        self.assertEqual(json_state['board'], expected_board)
        self.assertEqual(json_state['current_player'], 'o')

    def test_terminal_state_reports_winner(self):
        game = connect_four.ConnectFourGame()
        state = game.new_initial_state()
        # x stacks column 0 while o stacks column 1; x completes four first.
        for action in [0, 1, 0, 1, 0, 1, 0]:
            state.apply_action(action)
        json_state = json.loads(state.to_json())
        self.assertTrue(json_state['is_terminal'])
        self.assertEqual(json_state['winner'], 'x')

    def test_action_to_json(self):
        game = connect_four.ConnectFourGame()
        state = game.new_initial_state()
        # Previously this compared action_to_json(3) with itself (tautology);
        # pin the concrete encoding and the JSON round trip instead.
        self.assertEqual(json.loads(state.action_to_json(3)), {'col': 3})
        self.assertEqual(state.json_to_action(state.action_to_json(3)), 3)


if __name__ == '__main__':
    absltest.main()
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# Copyright 2019 DeepMind Technologies Limited
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""An observation of a game.
|
|
16
|
+
|
|
17
|
+
This is intended to be the main way to get observations of states in Python.
|
|
18
|
+
The usage pattern is as follows:
|
|
19
|
+
|
|
20
|
+
0. Create the game we will be playing
|
|
21
|
+
1. Create each kind of observation required, using `make_observation`
|
|
22
|
+
2. Every time a new observation is required, call:
|
|
23
|
+
`observation.set_from(state, player)`
|
|
24
|
+
The tensor contained in the Observation class will be updated with an
|
|
25
|
+
observation of the supplied state. This tensor is updated in-place, so if
|
|
26
|
+
you wish to retain it, you must make a copy.
|
|
27
|
+
|
|
28
|
+
The following options are available when creating an Observation:
|
|
29
|
+
- perfect_recall: if true, each observation must allow the observing player to
|
|
30
|
+
reconstruct their history of actions and observations.
|
|
31
|
+
- public_info: if true, the observation should include public information
|
|
32
|
+
- private_info: specifies for which players private information should be
|
|
33
|
+
included - all players, the observing player, or no players
|
|
34
|
+
- params: game-specific parameters for observations
|
|
35
|
+
|
|
36
|
+
We ultimately aim to have all games support all combinations of these arguments.
|
|
37
|
+
However, initially many games will only support the combinations corresponding
|
|
38
|
+
to ObservationTensor and InformationStateTensor:
|
|
39
|
+
- ObservationTensor: perfect_recall=False, public_info=True,
|
|
40
|
+
private_info=SinglePlayer
|
|
41
|
+
- InformationStateTensor: perfect_recall=True, public_info=True,
|
|
42
|
+
private_info=SinglePlayer
|
|
43
|
+
|
|
44
|
+
Three formats of observation are supported:
|
|
45
|
+
a. 1-D numpy array, accessed by `observation.tensor`
|
|
46
|
+
b. Dict of numpy arrays, accessed by `observation.dict`. These are pieces of the
|
|
47
|
+
1-D array, reshaped. The np.array objects refer to the same memory as the
|
|
48
|
+
1-D array (no copying!).
|
|
49
|
+
c. String, hopefully human-readable (primarily for debugging purposes)
|
|
50
|
+
|
|
51
|
+
For usage examples, see `observation_test.py`.
|
|
52
|
+
"""
|
|
53
|
+
|
|
54
|
+
import numpy as np
|
|
55
|
+
|
|
56
|
+
import pyspiel
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Corresponds to the old information_state_XXX methods: a perfect-recall
# observation type (the observing player can reconstruct its own history).
INFO_STATE_OBS_TYPE = pyspiel.IIGObservationType(perfect_recall=True)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class _Observation:
    """Contains an observation from a game.

    When the observer provides a tensor, ``tensor`` is a flat float32 view
    over the underlying pyspiel buffer and ``dict`` maps each named piece to a
    reshaped view of its slice. Both share memory with the buffer, so they
    reflect every subsequent ``set_from`` call. Without a tensor,
    ``tensor`` is None and ``dict`` is empty.
    """

    def __init__(self, game, observer):
        self._observation = pyspiel._Observation(game, observer)
        self.dict = {}
        if not self._observation.has_tensor():
            self.tensor = None
            return
        self.tensor = np.frombuffer(self._observation, np.float32)
        start = 0
        for info in self._observation.tensors_info():
            end = start + np.prod(info.shape, dtype=np.int64)
            self.dict[info.name] = self.tensor[start:end].reshape(info.shape)
            start = end

    def set_from(self, state, player):
        """Fill the observation buffer from ``state`` as seen by ``player``."""
        self._observation.set_from(state, player)

    def string_from(self, state, player):
        """Return the string observation, or None if strings are unsupported."""
        if not self._observation.has_string():
            return None
        return self._observation.string_from(state, player)

    def compress(self):
        """Return the compressed form of the current observation."""
        return self._observation.compress()

    def decompress(self, compressed_observation):
        """Restore the observation buffer from ``compressed_observation``."""
        self._observation.decompress(compressed_observation)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def make_observation(
    game,
    imperfect_information_observation_type=None,
    params=None,
):
    """Build an observation object for ``game``.

    Games implemented in Python (exposing ``make_py_observer``) build their own
    observer. Otherwise a pyspiel observer is requested, with the observation
    type only when one is given.

    Returns:
        The Python observer, an ``_Observation`` wrapping the pyspiel
        observer, or None when the requested observation type is unsupported.
    """
    params = params or {}
    if hasattr(game, 'make_py_observer'):
        return game.make_py_observer(imperfect_information_observation_type, params)
    if imperfect_information_observation_type is None:
        observer = game.make_observer(params)
    else:
        observer = game.make_observer(imperfect_information_observation_type, params)
    return None if observer is None else _Observation(game, observer)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class IIGObserverForPublicInfoGame:
    """Observer for imperfect-information observations of public-info games.

    There is no tensor to fill in. The string form is the action history when
    public information is requested, and the empty string otherwise.
    """

    def __init__(self, iig_obs_type, params):
        if params:
            raise ValueError(f'Observation parameters not supported; passed {params}')
        self._iig_obs_type = iig_obs_type
        self.tensor = None
        self.dict = {}

    def set_from(self, state, player):
        """No-op: this observer has no tensor to update."""
        del state, player

    def string_from(self, state, player):
        """Return the history string if public info is requested, else ''."""
        del player
        return state.history_str() if self._iig_obs_type.public_info else ''
|
|
@@ -0,0 +1,416 @@
|
|
|
1
|
+
"""Kaggle environment wrapper for OpenSpiel games."""
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import os
|
|
5
|
+
import pathlib
|
|
6
|
+
import random
|
|
7
|
+
from typing import Any, Callable
|
|
8
|
+
|
|
9
|
+
from kaggle_environments import core
|
|
10
|
+
from kaggle_environments import utils
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pyspiel
|
|
13
|
+
from .games.connect_four import connect_four_proxy
|
|
14
|
+
|
|
15
|
+
DEFAULT_ACT_TIMEOUT = 5
DEFAULT_RUN_TIMEOUT = 1200
DEFAULT_EPISODE_STEP_BUFFER = 100  # To account for timeouts, retries, etc...

# Template copied (deep-copied) per registered game; PLACEHOLDER_* values are
# filled in by the registration code.
BASE_SPEC_TEMPLATE = {
    "name": "PLACEHOLDER_NAME",
    "title": "PLACEHOLDER_TITLE",
    "description": "PLACEHOLDER_DESCRIPTION",
    "version": "0.1.0",
    "agents": ["PLACEHOLDER_NUM_AGENTS"],

    "configuration": {
        "episodeSteps": -1,
        "actTimeout": DEFAULT_ACT_TIMEOUT,
        "runTimeout": DEFAULT_RUN_TIMEOUT,
        "openSpielGameString": {
            "description": "The full game string including parameters.",
            "type": "string",
            "default": "PLACEHOLDER_GAME_STRING"
        },
        "openSpielGameName": {
            "description": "The short_name of the OpenSpiel game to load.",
            "type": "string",
            "default": "PLACEHOLDER_GAME_SHORT_NAME"
        },
    },
    "observation": {
        "properties": {
            "openSpielGameString": {
                "description": "Full game string including parameters.",
                "type": "string"
            },
            "openSpielGameName": {
                "description": "Short name of the OpenSpiel game.",
                "type": "string"
            },
            "observation_string": {
                "description": "String representation of state.",
                "type": "string"
            },
            # TODO(jhtschultz): add legal action strings
            "legal_actions": {
                "description": "List of OpenSpiel legal actions.",
                "type": "array",
                "items": {
                    "type": "integer"
                }
            },
            "chance_outcome_probs": {
                "description": "List of probabilities for chance outcomes.",
                "type": "array",
                "items": {
                    # JSON Schema has no "float" type; "number" covers floats.
                    "type": "number"
                }
            },
            "current_player": {
                "description": "ID of player whose turn it is.",
                "type": "integer"
            },
            "is_terminal": {
                "description": "Boolean indicating game end.",
                "type": "boolean"
            },
            "player_id": {
                "description": "ID of the agent receiving this observation.",
                "type": "integer"
            },
            "remainingOverageTime": 60,
            "step": 0
        },
        "default": {}
    },
    "action": {
        "type": ["integer"],
        "minimum": -1,
        "default": -1
    },
    "reward": {
        "type": ["number"],
        "default": 0.0
    },
}
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# Module-level cache of the currently loaded OpenSpiel game and its live state.
_OS_GLOBAL_GAME = None
_OS_GLOBAL_STATE = None


def _get_open_spiel_game(env_config: utils.Struct) -> pyspiel.Game:
    """Return the OpenSpiel game named by ``env_config``, loading it if needed.

    The cached game in ``_OS_GLOBAL_GAME`` is reused while its string form
    equals the configured ``openSpielGameString``; otherwise it is replaced
    (printing a warning if a different game was already cached).
    """
    global _OS_GLOBAL_GAME
    requested = env_config.get("openSpielGameString")
    cached = _OS_GLOBAL_GAME
    if requested == str(cached):
        return cached
    if cached is not None:
        print(f"WARNING: Overwriting game. Old: {cached}. New {requested}")
    _OS_GLOBAL_GAME = pyspiel.load_game(requested)
    return _OS_GLOBAL_GAME
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _apply_agent_action(
    kaggle_state: list[utils.Struct],
    os_state: pyspiel.State,
    player: int,
) -> int | None:
    """Apply ``player``'s submitted action to ``os_state`` if it is legal.

    Marks the agent INVALID when the action is not legal and ERROR when
    OpenSpiel rejects it. Returns the applied action, or None if none was.
    """
    agent = kaggle_state[player]
    if agent.status != "ACTIVE":
        return None
    action_submitted = agent.action
    if action_submitted not in os_state.legal_actions():
        agent.status = "INVALID"
        return None
    try:
        os_state.apply_action(action_submitted)
    except Exception:  # pylint: disable=broad-exception-caught
        agent.status = "ERROR"
        return None
    return action_submitted


def _resolve_chance_nodes(os_state: pyspiel.State) -> None:
    """Sample chance outcomes until ``os_state`` is no longer a chance node."""
    while os_state.is_chance_node():
        chance_actions, chance_probs = zip(*os_state.chance_outcomes())
        action = np.random.choice(chance_actions, p=chance_probs)
        # np.random.choice returns a numpy scalar; hand OpenSpiel a plain int.
        os_state.apply_action(int(action))


def interpreter(
    state: list[utils.Struct],
    env: core.Environment,
) -> list[utils.Struct]:
    """Updates environment using player responses and returns new observations.

    Per call:
      1. On the first step of an episode, start from a fresh OpenSpiel state.
      2. Otherwise apply the current player's submitted action; an illegal
         action marks the agent INVALID, a rejected one marks it ERROR.
      3. Sample through any chance nodes.
      4. Rewrite every agent's observation, reward, info and status.

    Args:
        state: Per-agent Kaggle states; mutated in place.
        env: The Kaggle environment being stepped.

    Returns:
        The same list object passed as ``state``, updated.

    Raises:
        ValueError: If no agent is ACTIVE while the env is not done, if called
            at a chance node, on an unknown OpenSpiel player id, or if an
            ACTIVE agent has no legal actions.
        NotImplementedError: At simultaneous-move nodes.
    """
    global _OS_GLOBAL_STATE
    kaggle_state = state
    del state

    if env.done:
        return kaggle_state

    # --- Get Game Info ---
    game = _get_open_spiel_game(env.configuration)
    num_players = game.num_players()
    statuses = [kaggle_state[player].status for player in range(num_players)]
    if not any(status == "ACTIVE" for status in statuses):
        raise ValueError("Environment not done and no active agents.")

    # --- Initialization / Reset ---
    # Always rebuild on an episode's first step. The module-level state
    # otherwise still holds the previous episode's (possibly terminal) state
    # after env.reset(). env.done is known False here, so it cannot serve as
    # the reset trigger.
    is_initial_step = len(env.steps) == 1
    if _OS_GLOBAL_STATE is None or is_initial_step:
        _OS_GLOBAL_STATE = game.new_initial_state()

    # --- Maybe apply agent action ---
    os_current_player = _OS_GLOBAL_STATE.current_player()
    action_applied = None
    if is_initial_step:
        pass  # The setup step never applies agent actions.
    elif 0 <= os_current_player < num_players:
        action_applied = _apply_agent_action(
            kaggle_state, _OS_GLOBAL_STATE, os_current_player
        )
    elif os_current_player == pyspiel.PlayerId.SIMULTANEOUS:
        raise NotImplementedError
    elif os_current_player == pyspiel.PlayerId.TERMINAL:
        pass
    elif os_current_player == pyspiel.PlayerId.CHANCE:
        raise ValueError("Interpreter should not be called at chance nodes.")
    else:
        raise ValueError(f"Unknown OpenSpiel player ID: {os_current_player}")

    # --- Update state info ---
    _resolve_chance_nodes(_OS_GLOBAL_STATE)
    is_terminal = _OS_GLOBAL_STATE.is_terminal()
    # NOTE(review): the trailing None presumably covers an extra non-player
    # agent slot; confirm against the environment's agent layout.
    agent_returns = _OS_GLOBAL_STATE.returns() + [None]
    next_agent = _OS_GLOBAL_STATE.current_player()
    obs_str = str(_OS_GLOBAL_STATE)  # Identical for every agent.

    for i, agent_state in enumerate(kaggle_state):
        input_status = agent_state.status
        if input_status in ("TIMEOUT", "ERROR", "INVALID"):
            status, reward = input_status, None
        elif is_terminal:
            status, reward = "DONE", agent_returns[i]
        elif next_agent == i:
            status, reward = "ACTIVE", agent_returns[i]
        else:
            status, reward = "INACTIVE", agent_returns[i]

        info_dict = {}
        # Store the applied action in info for potential debugging/analysis
        if os_current_player == i and action_applied is not None:
            info_dict["action_applied"] = action_applied

        legal_actions = _OS_GLOBAL_STATE.legal_actions(i)
        if status == "ACTIVE" and not legal_actions:
            raise ValueError(
                f"Active agent {i} has no legal actions in state {_OS_GLOBAL_STATE}."
            )

        # Apply updates
        obs_update_dict = {
            "observation_string": obs_str,
            "legal_actions": legal_actions,
            "current_player": next_agent,
            "is_terminal": is_terminal,
            "player_id": i,
        }
        for k, v in obs_update_dict.items():
            setattr(agent_state.observation, k, v)
        agent_state.reward = reward
        agent_state.info = info_dict
        agent_state.status = status

    return kaggle_state
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def renderer(state: list[utils.Struct], env: core.Environment) -> str:
    """Kaggle text renderer: the last agent's observation string.

    Returns a placeholder when that string is empty. On any lookup failure it
    prints the environment name and state, then re-raises the exception.
    """
    try:
        text = state[-1].observation["observation_string"]
        return text if text else "<Empty observation string>"
    except Exception:  # pylint: disable=broad-exception-caught
        print(f"Error rendering {env.name} at state: {state}.")
        raise
|
|
240
|
+
|
|
241
|
+
# --- HTML Renderer Logic ---
|
|
242
|
+
|
|
243
|
+
def _default_html_renderer() -> str:
|
|
244
|
+
"""Provides the JavaScript string for the default HTML renderer."""
|
|
245
|
+
return """
|
|
246
|
+
function renderer(context) {
|
|
247
|
+
const { parent, environment, step } = context;
|
|
248
|
+
parent.innerHTML = ''; // Clear previous rendering
|
|
249
|
+
|
|
250
|
+
const currentStepData = environment.steps[step];
|
|
251
|
+
if (!currentStepData) {
|
|
252
|
+
parent.textContent = "Waiting for step data...";
|
|
253
|
+
return;
|
|
254
|
+
}
|
|
255
|
+
const numAgents = currentStepData.length;
|
|
256
|
+
const gameMasterIndex = numAgents - 1;
|
|
257
|
+
let obsString = "Observation not available for this step.";
|
|
258
|
+
let title = `Step: ${step}`;
|
|
259
|
+
|
|
260
|
+
if (environment.configuration && environment.configuration.openSpielGameName) {
|
|
261
|
+
title = `${environment.configuration.openSpielGameName} - Step: ${step}`;
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
// Try to get obs_string from game_master of current step
|
|
265
|
+
if (currentStepData[gameMasterIndex] &&
|
|
266
|
+
currentStepData[gameMasterIndex].observation &&
|
|
267
|
+
typeof currentStepData[gameMasterIndex].observation.observation_string === 'string') {
|
|
268
|
+
obsString = currentStepData[gameMasterIndex].observation.observation_string;
|
|
269
|
+
}
|
|
270
|
+
// Fallback to initial step if current is unavailable (e.g. very first render call)
|
|
271
|
+
else if (step === 0 && environment.steps[0] && environment.steps[0][gameMasterIndex] &&
|
|
272
|
+
environment.steps[0][gameMasterIndex].observation &&
|
|
273
|
+
typeof environment.steps[0][gameMasterIndex].observation.observation_string === 'string') {
|
|
274
|
+
obsString = environment.steps[0][gameMasterIndex].observation.observation_string;
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
const pre = document.createElement("pre");
|
|
278
|
+
pre.style.fontFamily = "monospace";
|
|
279
|
+
pre.style.margin = "10px";
|
|
280
|
+
pre.style.border = "1px solid #ccc";
|
|
281
|
+
pre.style.padding = "10px";
|
|
282
|
+
pre.style.backgroundColor = "#f9f9f9";
|
|
283
|
+
pre.style.whiteSpace = "pre-wrap";
|
|
284
|
+
pre.style.wordBreak = "break-all";
|
|
285
|
+
|
|
286
|
+
pre.textContent = `${title}\\n\\n${obsString}`;
|
|
287
|
+
parent.appendChild(pre);
|
|
288
|
+
}
|
|
289
|
+
"""
|
|
290
|
+
|
|
291
|
+
def _get_html_renderer_content(
|
|
292
|
+
open_spiel_short_name: str,
|
|
293
|
+
base_path_for_custom_renderers: pathlib.Path,
|
|
294
|
+
default_renderer_func: Callable[[], str]
|
|
295
|
+
) -> str:
|
|
296
|
+
"""
|
|
297
|
+
Tries to load a custom JS renderer for the game.
|
|
298
|
+
Falls back to the default renderer if not found or on error.
|
|
299
|
+
"""
|
|
300
|
+
if "proxy" not in open_spiel_short_name:
|
|
301
|
+
return default_renderer_func()
|
|
302
|
+
sanitized_game_name = open_spiel_short_name.replace('-', '_').replace('.', '_')
|
|
303
|
+
sanitized_game_name = sanitized_game_name.removesuffix("_proxy")
|
|
304
|
+
custom_renderer_js_path = (
|
|
305
|
+
base_path_for_custom_renderers /
|
|
306
|
+
sanitized_game_name /
|
|
307
|
+
f"{sanitized_game_name}.js"
|
|
308
|
+
)
|
|
309
|
+
if custom_renderer_js_path.is_file():
|
|
310
|
+
try:
|
|
311
|
+
with open(custom_renderer_js_path, "r", encoding="utf-8") as f:
|
|
312
|
+
content = f.read()
|
|
313
|
+
print(f"INFO: Using custom HTML renderer for {open_spiel_short_name} from {custom_renderer_js_path}")
|
|
314
|
+
return content
|
|
315
|
+
except Exception as e_render:
|
|
316
|
+
pass
|
|
317
|
+
return default_renderer_func()
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
# --- Agents ---
|
|
321
|
+
def random_agent(
|
|
322
|
+
observation: dict[str, Any],
|
|
323
|
+
configuration: dict[str, Any],
|
|
324
|
+
) -> int:
|
|
325
|
+
"""A built-in random agent specifically for OpenSpiel environments."""
|
|
326
|
+
del configuration
|
|
327
|
+
legal_actions = observation.get("legal_actions")
|
|
328
|
+
if not legal_actions:
|
|
329
|
+
return None
|
|
330
|
+
action = random.choice(legal_actions)
|
|
331
|
+
return int(action)
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
agents = {
|
|
335
|
+
"random": random_agent,
|
|
336
|
+
}
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def _register_open_spiel_envs(
    games_list: list[str] | None = None,
) -> dict[str, Any]:
    """Build Kaggle environment registrations for OpenSpiel games.

    Args:
        games_list: OpenSpiel short names to wrap; defaults to every name in
            ``pyspiel.registered_names()``.

    Returns:
        Mapping from Kaggle env name (``open_spiel_<sanitized short name>``)
        to its specification, interpreter, renderers and agents.

    Games that fail to load with default parameters, or that provide neither
    an information-state string nor an observation string, are skipped. Both
    kinds are counted in the printed summary.
    """
    successfully_loaded_games = []
    skipped_games = []
    registered_envs = {}
    custom_renderers_base = pathlib.Path(__file__).parent.resolve() / "games"
    if games_list is None:
        games_list = pyspiel.registered_names()

    for short_name in games_list:
        try:
            game = pyspiel.load_game(short_name)
            game_type = game.get_type()
            # Observations are delivered as strings, so games without any
            # string representation cannot be wrapped.
            if not (game_type.provides_information_state_string
                    or game_type.provides_observation_string):
                skipped_games.append(short_name)
                continue
            game_spec = copy.deepcopy(BASE_SPEC_TEMPLATE)
            env_name = f"open_spiel_{short_name.replace('-', '_').replace('.', '_')}"
            game_spec["name"] = env_name
            game_spec["title"] = f"Open Spiel: {short_name}"
            game_spec["description"] = """
Kaggle environment wrapper for OpenSpiel games.
For game implementation details see:
https://github.com/google-deepmind/open_spiel/tree/master/open_spiel/games
""".strip()
            game_spec["agents"] = [game.num_players()]
            game_spec["configuration"]["episodeSteps"] = (
                game.max_history_length() + DEFAULT_EPISODE_STEP_BUFFER
            )
            game_spec["configuration"]["openSpielGameString"]["default"] = str(game)
            game_spec["configuration"]["openSpielGameName"]["default"] = short_name
            game_spec["observation"]["properties"]["openSpielGameString"][
                "default"] = str(game)
            game_spec["observation"]["properties"]["openSpielGameName"][
                "default"] = short_name

            # Building html_renderer_callable is a bit convoluted but other approaches
            # failed for a variety of reasons. Returning a simple lambda function
            # doesn't work because of late-binding. The last env registered will
            # overwrite all previous renderers.
            js_string_content = _get_html_renderer_content(
                open_spiel_short_name=short_name,
                base_path_for_custom_renderers=custom_renderers_base,
                default_renderer_func=_default_html_renderer,
            )

            def create_html_renderer_closure(captured_content):
                def html_renderer_callable_no_args():
                    return captured_content
                return html_renderer_callable_no_args

            html_renderer_callable = create_html_renderer_closure(js_string_content)

            registered_envs[env_name] = {
                "specification": game_spec,
                "interpreter": interpreter,
                "renderer": renderer,
                "html_renderer": html_renderer_callable,
                "agents": agents,
            }
            successfully_loaded_games.append(short_name)

        except Exception:  # pylint: disable=broad-exception-caught
            # Best effort: any game that cannot be loaded or described is skipped.
            skipped_games.append(short_name)

    print(f"""
Successfully loaded OpenSpiel environments: {len(successfully_loaded_games)}.
OpenSpiel games skipped: {len(skipped_games)}.
""".strip())

    return registered_envs


registered_open_spiel_envs = _register_open_spiel_envs()
|