kaggle-environments 1.16.11__py2.py3-none-any.whl → 1.17.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaggle-environments might be problematic. Click here for more details.

@@ -20,7 +20,7 @@ from .core import *
20
20
  from .main import http_request
21
21
  from . import errors
22
22
 
23
- __version__ = "1.16.11"
23
+ __version__ = "1.17.2"
24
24
 
25
25
  __all__ = ["Agent", "environments", "errors", "evaluate", "http_request",
26
26
  "make", "register", "utils", "__version__",
@@ -31,13 +31,23 @@ __all__ = ["Agent", "environments", "errors", "evaluate", "http_request",
31
31
  for name in listdir(utils.envs_path):
32
32
  try:
33
33
  env = import_module(f".envs.{name}.{name}", __name__)
34
- register(name, {
35
- "agents": getattr(env, "agents", []),
36
- "html_renderer": getattr(env, "html_renderer", None),
37
- "interpreter": getattr(env, "interpreter"),
38
- "renderer": getattr(env, "renderer"),
39
- "specification": getattr(env, "specification"),
40
- })
34
+ if name == "open_spiel":
35
+ for env_name, env_dict in env.registered_open_spiel_envs.items():
36
+ register(env_name, {
37
+ "agents": env_dict.get("agents"),
38
+ "html_renderer": env_dict.get("html_renderer"),
39
+ "interpreter": env_dict.get("interpreter"),
40
+ "renderer": env_dict.get("renderer"),
41
+ "specification": env_dict.get("specification"),
42
+ })
43
+ else:
44
+ register(name, {
45
+ "agents": getattr(env, "agents", []),
46
+ "html_renderer": getattr(env, "html_renderer", None),
47
+ "interpreter": getattr(env, "interpreter"),
48
+ "renderer": getattr(env, "renderer"),
49
+ "specification": getattr(env, "specification"),
50
+ })
41
51
  except Exception as e:
42
52
  if "football" not in name:
43
53
  print("Loading environment %s failed: %s" % (name, e))
@@ -148,6 +148,11 @@ class LuxAIS3Env(environment.Environment):
148
148
  i : max_sensor_range * 2 + 1 - i,
149
149
  i : max_sensor_range * 2 + 1 - i,
150
150
  ].set(val)
151
+ # vision of position at center of update has an extra 10
152
+ update = update.at[
153
+ max_sensor_range,
154
+ max_sensor_range,
155
+ ].add(10)
151
156
  vision_power_map = jax.lax.dynamic_update_slice(
152
157
  vision_power_map,
153
158
  update=update + existing_vision_power,
@@ -227,6 +232,10 @@ class LuxAIS3Env(environment.Environment):
227
232
  state = state.replace(
228
233
  units_mask=(state.units.energy[..., 0] >= 0) & state.units_mask
229
234
  )
235
+
236
+ """spawn relic nodes based on schedule"""
237
+ relic_nodes_mask = (state.steps >= state.relic_spawn_schedule) & (state.relic_spawn_schedule != -1)
238
+ state = state.replace(relic_nodes_mask=relic_nodes_mask)
230
239
 
231
240
  """ process unit movement """
232
241
  # 0 is do nothing, 1 is move up, 2 is move right, 3 is move down, 4 is move left, 5 is sap
@@ -646,11 +655,10 @@ class LuxAIS3Env(environment.Environment):
646
655
  axis=(0, 1),
647
656
  )
648
657
  new_tile_types_map = jnp.where(
649
- state.steps * params.nebula_tile_drift_speed % 1 == 0,
658
+ (state.steps - 1) * abs(params.nebula_tile_drift_speed) % 1 > state.steps * abs(params.nebula_tile_drift_speed) % 1,
650
659
  new_tile_types_map,
651
660
  state.map_features.tile_type,
652
661
  )
653
- # new_energy_nodes = state.energy_nodes + jnp.array([1 * jnp.sign(params.energy_node_drift_speed), -1 * jnp.sign(params.energy_node_drift_speed)])
654
662
 
655
663
  energy_node_deltas = jnp.round(
656
664
  jax.random.uniform(
@@ -663,8 +671,6 @@ class LuxAIS3Env(environment.Environment):
663
671
  energy_node_deltas_symmetric = jnp.stack(
664
672
  [-energy_node_deltas[:, 1], -energy_node_deltas[:, 0]], axis=-1
665
673
  )
666
- # TODO symmetric movement
667
- # energy_node_deltas = jnp.round(jax.random.uniform(key=key, shape=(params.max_energy_nodes // 2, 2), minval=-params.energy_node_drift_magnitude, maxval=params.energy_node_drift_magnitude)).astype(jnp.int16)
668
674
  energy_node_deltas = jnp.concatenate(
669
675
  (energy_node_deltas, energy_node_deltas_symmetric)
670
676
  )
@@ -677,7 +683,7 @@ class LuxAIS3Env(environment.Environment):
677
683
  ),
678
684
  )
679
685
  new_energy_nodes = jnp.where(
680
- state.steps * params.energy_node_drift_speed % 1 == 0,
686
+ (state.steps - 1) * abs(params.energy_node_drift_speed) % 1 > state.steps * abs(params.energy_node_drift_speed) % 1,
681
687
  new_energy_nodes,
682
688
  state.energy_nodes,
683
689
  )
@@ -688,7 +694,9 @@ class LuxAIS3Env(environment.Environment):
688
694
 
689
695
  # Compute relic scores
690
696
  def team_relic_score(unit_counts_map):
691
- scores = (unit_counts_map > 0) & (state.relic_nodes_map_weights > 0)
697
+ # not all relic nodes are spawned in yet, but relic nodes map ids are precomputed for all to be spawned relic nodes
698
+ # for efficiency. So we check if the relic node (by id) is spawned in yet. relic nodes mask is always increasing so we can do a simple trick below
699
+ scores = (unit_counts_map > 0) & (state.relic_nodes_map_weights <= state.relic_nodes_mask.sum() // 2) & (state.relic_nodes_map_weights > 0)
692
700
  return jnp.sum(scores, dtype=jnp.int32)
693
701
 
694
702
  # note we need to recompue unit counts since units can get removed due to collisions
@@ -771,7 +779,6 @@ class LuxAIS3Env(environment.Environment):
771
779
  )
772
780
  state = self.compute_energy_features(state, params)
773
781
  state = self.compute_sensor_masks(state, params)
774
-
775
782
  return self.get_obs(state, params=params, key=key), state
776
783
 
777
784
  @functools.partial(jax.jit, static_argnums=(0,))
@@ -46,6 +46,7 @@ class EnvParams:
46
46
  min_energy_per_tile: int = -20
47
47
 
48
48
  max_relic_nodes: int = 6
49
+ """max relic nodes in the entire map. This number should be tuned carefully as relic node spawning code is hardcoded against this number 6"""
49
50
  relic_config_size: int = 5
50
51
  fog_of_war: bool = True
51
52
  """
@@ -87,15 +88,15 @@ class EnvParams:
87
88
  env_params_ranges = dict(
88
89
  # map_type=[1],
89
90
  unit_move_cost=list(range(1, 6)),
90
- unit_sensor_range=list(range(2, 5)),
91
- nebula_tile_vision_reduction=list(range(0, 4)),
92
- nebula_tile_energy_reduction=[0, 0, 10, 25],
91
+ unit_sensor_range=[1, 2, 3, 4],
92
+ nebula_tile_vision_reduction=list(range(0, 8)),
93
+ nebula_tile_energy_reduction=[0, 1, 2, 3, 5, 25],
93
94
  unit_sap_cost=list(range(30, 51)),
94
95
  unit_sap_range=list(range(3, 8)),
95
96
  unit_sap_dropoff_factor=[0.25, 0.5, 1],
96
97
  unit_energy_void_factor=[0.0625, 0.125, 0.25, 0.375],
97
98
  # map randomizations
98
- nebula_tile_drift_speed=[-0.05, -0.025, 0.025, 0.05],
99
+ nebula_tile_drift_speed=[-0.15, -0.1, -0.05, -0.025, 0.025, 0.05, 0.1, 0.15],
99
100
  energy_node_drift_speed=[0.01, 0.02, 0.03, 0.04, 0.05],
100
101
  energy_node_drift_magnitude=list(range(3, 6)),
101
102
  )
@@ -66,7 +66,10 @@ class EnvState:
66
66
  relic_nodes_mask: chex.Array
67
67
  """Mask of relic nodes in the environment with shape (N, ) for N max relic nodes"""
68
68
  relic_nodes_map_weights: chex.Array
69
- """Map of relic nodes in the environment with shape (H, W) for H height, W width. True if a relic node is present, False otherwise. This is generated from other state"""
69
+ """Map of relic nodes in the environment with shape (H, W) for H height, W width. Each element is equal to the 1-indexed id of the relic node. This is generated from other state"""
70
+
71
+ relic_spawn_schedule: chex.Array
72
+ """Relic spawn schedule in the environment with shape (N, ) for N max relic nodes. Elements are the game timestep at which the relic node spawns"""
70
73
 
71
74
  map_features: MapTile
72
75
  """Map features in the environment with shape (W, H, 2) for W width, H height
@@ -121,7 +124,7 @@ class EnvObs:
121
124
 
122
125
  def serialize_env_states(env_states: list[EnvState]):
123
126
  def serialize_array(root: EnvState, arr, key_path: str = ""):
124
- if key_path in ["sensor_mask", "relic_nodes_mask", "energy_nodes_mask", "energy_node_fns", "relic_nodes_map_weights"]:
127
+ if key_path in ["sensor_mask", "relic_nodes_mask", "energy_nodes_mask", "energy_node_fns", "relic_nodes_map_weights", "relic_spawn_schedule"]:
125
128
  return None
126
129
  if key_path == "relic_nodes":
127
130
  return root.relic_nodes[root.relic_nodes_mask].tolist()
@@ -186,9 +189,10 @@ def gen_state(key: chex.PRNGKey, env_params: EnvParams, max_units: int, num_team
186
189
 
187
190
  # TODO (this could be optimized better)
188
191
  def update_relic_node(relic_nodes_map_weights, relic_data):
189
- relic_node, relic_node_config, mask = relic_data
192
+ relic_node, relic_node_config, mask, relic_node_id = relic_data
190
193
  start_y = relic_node[1] - relic_config_size // 2
191
194
  start_x = relic_node[0] - relic_config_size // 2
195
+
192
196
  for dy in range(relic_config_size):
193
197
  for dx in range(relic_config_size):
194
198
  y, x = start_y + dy, start_x + dx
@@ -196,14 +200,20 @@ def gen_state(key: chex.PRNGKey, env_params: EnvParams, max_units: int, num_team
196
200
  jnp.logical_and(y >= 0, x >= 0),
197
201
  jnp.logical_and(y < map_height, x < map_width),
198
202
  )
203
+ # ensure we don't override previous spawns
204
+ has_points = jnp.logical_and(
205
+ relic_nodes_map_weights > 0,
206
+ relic_nodes_map_weights <= relic_node_id + 1
207
+ )
199
208
  relic_nodes_map_weights = jnp.where(
200
- valid_pos & mask,
201
- relic_nodes_map_weights.at[x, y].add(relic_node_config[dx, dy].astype(jnp.int16)),
209
+ valid_pos & mask & jnp.logical_not(has_points) & relic_node_config[dx, dy],
210
+ relic_nodes_map_weights.at[x, y].set(relic_node_config[dx, dy].astype(jnp.int16) * (relic_node_id + 1)),
202
211
  relic_nodes_map_weights,
203
212
  )
204
213
  return relic_nodes_map_weights, None
205
214
 
206
215
  # this is really slow...
216
+
207
217
  relic_nodes_map_weights, _ = jax.lax.scan(
208
218
  update_relic_node,
209
219
  relic_nodes_map_weights,
@@ -211,8 +221,10 @@ def gen_state(key: chex.PRNGKey, env_params: EnvParams, max_units: int, num_team
211
221
  generated["relic_nodes"],
212
222
  generated["relic_node_configs"],
213
223
  generated["relic_nodes_mask"],
224
+ jnp.arange(max_relic_nodes, dtype=jnp.int16) % (max_relic_nodes // 2),
214
225
  ),
215
226
  )
227
+
216
228
  state = EnvState(
217
229
  units=UnitState(position=jnp.zeros(shape=(num_teams, max_units, 2), dtype=jnp.int16), energy=jnp.zeros(shape=(num_teams, max_units, 1), dtype=jnp.int16)),
218
230
  units_mask=jnp.zeros(
@@ -225,9 +237,10 @@ def gen_state(key: chex.PRNGKey, env_params: EnvParams, max_units: int, num_team
225
237
  energy_nodes_mask=generated["energy_nodes_mask"],
226
238
  # energy_field=jnp.zeros(shape=(params.map_height, params.map_width), dtype=jnp.int16),
227
239
  relic_nodes=generated["relic_nodes"],
228
- relic_nodes_mask=generated["relic_nodes_mask"],
240
+ relic_nodes_mask=jnp.zeros(shape=(max_relic_nodes), dtype=jnp.bool), # as relic nodes are spawn in, we start with them all invisible.
229
241
  relic_node_configs=generated["relic_node_configs"],
230
242
  relic_nodes_map_weights=relic_nodes_map_weights,
243
+ relic_spawn_schedule=generated["relic_spawn_schedule"],
231
244
  sensor_mask=jnp.zeros(
232
245
  shape=(num_teams, map_height, map_width),
233
246
  dtype=jnp.bool,
@@ -276,7 +289,7 @@ def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int
276
289
  flat_indices = jnp.argsort(noise.ravel())[-max_relic_nodes // 2:] # Get indices of two highest values
277
290
  highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape))
278
291
 
279
- # relic nodes have a fixed density of 25% nearby tiles can yield points
292
+ # relic nodes have a fixed density of 20% nearby tiles can yield points
280
293
  relic_node_configs = (
281
294
  jax.random.randint(
282
295
  key,
@@ -291,18 +304,15 @@ def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int
291
304
  >= 7.5
292
305
  )
293
306
  highest_positions = highest_positions.astype(jnp.int16)
294
- relic_nodes_mask = relic_nodes_mask.at[0].set(True)
295
- relic_nodes_mask = relic_nodes_mask.at[1].set(True)
296
307
  mirrored_positions = jnp.stack([map_width - highest_positions[:, 1] - 1, map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1)
297
308
  relic_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0)
298
309
 
299
310
  key, subkey = jax.random.split(key)
300
- relic_nodes_mask_half = jax.random.randint(key, (max_relic_nodes // 2, ), minval=0, maxval=2).astype(jnp.bool)
301
- relic_nodes_mask_half = relic_nodes_mask_half.at[0].set(True)
302
- relic_nodes_mask = relic_nodes_mask.at[:max_relic_nodes // 2].set(relic_nodes_mask_half)
303
- relic_nodes_mask = relic_nodes_mask.at[max_relic_nodes // 2:].set(relic_nodes_mask_half)
304
- # import ipdb;ipdb.set_trace()
311
+ num_spawned_relic_nodes = jax.random.randint(key, (1, ), minval=1, maxval=(max_relic_nodes // 2) + 1)
312
+ relic_nodes_mask_half = jnp.arange(max_relic_nodes // 2) < num_spawned_relic_nodes
313
+ relic_nodes_mask = jnp.concat([relic_nodes_mask_half, relic_nodes_mask_half], axis=0)
305
314
  relic_node_configs = relic_node_configs.at[max_relic_nodes // 2:].set(relic_node_configs[:max_relic_nodes // 2].transpose(0, 2, 1)[:, ::-1, ::-1])
315
+ # note that relic nodes mask is always increasing.
306
316
 
307
317
  ### Generate energy nodes ###
308
318
  key, subkey = jax.random.split(key)
@@ -318,7 +328,6 @@ def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int
318
328
  energy_nodes_mask = energy_nodes_mask.at[:max_energy_nodes // 2].set(energy_nodes_mask_half)
319
329
  energy_nodes_mask = energy_nodes_mask.at[max_energy_nodes // 2:].set(energy_nodes_mask_half)
320
330
 
321
- # TODO (stao): provide more randomization options for energy node functions.
322
331
  energy_node_fns = jnp.array(
323
332
  [
324
333
  [0, 1.2, 1, 4],
@@ -331,8 +340,14 @@ def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int
331
340
  # [1, 4, 0, 0]
332
341
  ]
333
342
  )
334
- # import ipdb; ipdb.set_trace()
335
- # energy_node_fns = jnp.concat([energy_node_fns, jnp.zeros((params.max_energy_nodes - 2, 4), dtype=jnp.float32)], axis=0)
343
+
344
+ # generate a random relic spawn schedule
345
+ # if number is -1, then relic node is never spawned, otherwise spawn at that game timestep
346
+ assert max_relic_nodes == 6, "random map generation is hardcoded to use 6 relic nodes at most per map"
347
+ key, subkey = jax.random.split(key)
348
+ relic_spawn_schedule_half = jax.random.randint(key, (max_relic_nodes //2, ), minval=0, maxval=params.max_steps_in_match // 2) + jnp.arange(3) * (params.max_steps_in_match + 1)
349
+ relic_spawn_schedule = jnp.concat([relic_spawn_schedule_half, relic_spawn_schedule_half], axis=0)
350
+ relic_spawn_schedule = jnp.where(relic_nodes_mask, relic_spawn_schedule, -1)
336
351
 
337
352
 
338
353
  return dict(
@@ -343,6 +358,7 @@ def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int
343
358
  energy_nodes_mask=energy_nodes_mask,
344
359
  relic_nodes_mask=relic_nodes_mask,
345
360
  relic_node_configs=relic_node_configs,
361
+ relic_spawn_schedule=relic_spawn_schedule,
346
362
  )
347
363
  def interpolant(t):
348
364
  return t*t*t*(t*(t*6 - 15) + 10)
@@ -0,0 +1,354 @@
1
+ """Kaggle environment wrapper for OpenSpiel games."""
2
+
3
+ import copy
4
+ import random
5
+ from typing import Any
6
+
7
+ from kaggle_environments import core
8
+ from kaggle_environments import utils
9
+ import numpy as np
10
+ import pyspiel
11
+
12
+
13
DEFAULT_ACT_TIMEOUT = 5
DEFAULT_RUN_TIMEOUT = 1200
DEFAULT_EPISODE_STEP_BUFFER = 100  # To account for timeouts, retries, etc...

# Specification shared by every wrapped OpenSpiel game. The PLACEHOLDER_*
# values are overwritten per game by `_register_open_spiel_envs` on a deep copy.
BASE_SPEC_TEMPLATE = {
    "name": "PLACEHOLDER_NAME",
    "title": "PLACEHOLDER_TITLE",
    "description": "PLACEHOLDER_DESCRIPTION",
    "version": "0.1.0",
    "agents": ["PLACEHOLDER_NUM_AGENTS"],

    "configuration": {
        "episodeSteps": -1,
        "actTimeout": DEFAULT_ACT_TIMEOUT,
        "runTimeout": DEFAULT_RUN_TIMEOUT,
        "openSpielGameString": {
            "description": "The full game string including parameters.",
            "type": "string",
            "default": "PLACEHOLDER_GAME_STRING"
        },
        "openSpielGameName": {
            "description": "The short_name of the OpenSpiel game to load.",
            "type": "string",
            "default": "PLACEHOLDER_GAME_SHORT_NAME"
        },
    },
    "observation": {
        "properties": {
            "openSpielGameString": {
                "description": "Full game string including parameters.",
                "type": "string"
            },
            "openSpielGameName": {
                "description": "Short name of the OpenSpiel game.",
                "type": "string"
            },
            "observation_string": {
                "description": "String representation of state.",
                "type": "string"
            },
            # TODO(jhtschultz): add legal action strings
            "legal_actions": {
                "description": "List of OpenSpiel legal actions.",
                "type": "array",
                "items": {
                    "type": "integer"
                }
            },
            "chance_outcome_probs": {
                "description": "List of probabilities for chance outcomes.",
                "type": "array",
                "items": {
                    # JSON Schema has no "float" type; "number" covers floats.
                    "type": "number"
                }
            },
            "current_player": {
                "description": "ID of player whose turn it is.",
                "type": "integer"
            },
            "is_terminal": {
                "description": "Boolean indicating game end.",
                "type": "boolean"
            },
            "player_id": {
                "description": "ID of the agent receiving this observation.",
                "type": "integer"
            },
            "remainingOverageTime": 60,
            "step": 0
        },
        "default": {}
    },
    "action": {
        "type": ["integer"],
        "minimum": -1,
        "default": -1
    },
    "reward": {
        "type": ["number"],
        "default": 0.0
    },
}


# Module-level cache of the currently loaded OpenSpiel game and its live state;
# read and replaced by `_get_open_spiel_game` and `interpreter`.
_OS_GLOBAL_GAME = None
_OS_GLOBAL_STATE = None
99
+
100
+
101
def _get_open_spiel_game(env_config: utils.Struct) -> pyspiel.Game:
    """Returns the OpenSpiel game selected by the configuration, loading it if needed.

    The most recently loaded game is cached in `_OS_GLOBAL_GAME` and reused
    while the configured game string matches `str()` of the cached game.

    Args:
        env_config: Kaggle environment configuration; its
            `openSpielGameString` entry is passed to `pyspiel.load_game`.

    Returns:
        The loaded `pyspiel.Game`.
    """
    global _OS_GLOBAL_GAME, _OS_GLOBAL_STATE
    game_string = env_config.get("openSpielGameString")
    if _OS_GLOBAL_GAME is not None:
        if game_string == str(_OS_GLOBAL_GAME):
            return _OS_GLOBAL_GAME
        print(
            f"WARNING: Overwriting game. Old: {_OS_GLOBAL_GAME}. New {game_string}"
        )
        # A state belongs to the game that created it; drop it so the
        # interpreter builds a fresh state for the newly loaded game.
        _OS_GLOBAL_STATE = None
    _OS_GLOBAL_GAME = pyspiel.load_game(game_string)
    return _OS_GLOBAL_GAME
112
+
113
+
114
def _has_player_moves(os_state) -> bool:
    """Returns True if `os_state` is terminal or any player applied an action.

    Chance outcomes (negative player ids) are ignored, so a fresh state whose
    opening chance nodes were already resolved still counts as unplayed.
    """
    if os_state.is_terminal():
        return True
    return any(
        player_action.player >= 0 for player_action in os_state.full_history()
    )


def interpreter(
    state: list[utils.Struct],
    env: core.Environment,
) -> list[utils.Struct]:
    """Updates environment using player responses and returns new observations.

    Applies the current OpenSpiel player's submitted action when it is legal,
    samples through any following chance nodes, and then rewrites every
    agent's status, reward, info, and observation fields.

    Args:
        state: Per-agent Kaggle states; entry `i` carries the action and
            status of OpenSpiel player `i`.
        env: The Kaggle environment; `env.configuration` selects the game and
            `len(env.steps) == 1` marks the initial step.

    Returns:
        The same list of agent states, updated in place.

    Raises:
        ValueError: if no agent is ACTIVE, the state is at a chance node when
            an action is expected, the player id is unknown, or an ACTIVE agent
            has no legal actions.
        NotImplementedError: for simultaneous-move states.
    """
    global _OS_GLOBAL_STATE
    kaggle_state = state
    del state

    if env.done:
        return kaggle_state

    # --- Get Game Info ---
    game = _get_open_spiel_game(env.configuration)
    num_players = game.num_players()
    statuses = [kaggle_state[player].status for player in range(num_players)]
    if "ACTIVE" not in statuses:
        raise ValueError("Environment not done and no active agents.")

    # --- Initialization / Reset ---
    # `env.done` has already returned above, so a reset cannot be keyed on it.
    # Start a new OpenSpiel state when none exists, or when a new episode
    # begins while the cached state still holds moves from an earlier episode.
    # An unplayed state is kept so chance outcomes already shown to agents are
    # not redrawn.
    is_initial_step = len(env.steps) == 1
    if _OS_GLOBAL_STATE is None or (
        is_initial_step and _has_player_moves(_OS_GLOBAL_STATE)
    ):
        _OS_GLOBAL_STATE = game.new_initial_state()

    # --- Maybe apply agent action ---
    os_current_player = _OS_GLOBAL_STATE.current_player()
    action_applied = None
    if is_initial_step:
        pass  # No agent action is applied on the initial step.
    elif 0 <= os_current_player < num_players:
        current_agent = kaggle_state[os_current_player]
        if current_agent.status == "ACTIVE":
            action_submitted = current_agent.action
            if action_submitted in _OS_GLOBAL_STATE.legal_actions():
                try:
                    _OS_GLOBAL_STATE.apply_action(action_submitted)
                    action_applied = action_submitted
                except Exception:  # pylint: disable=broad-exception-caught
                    current_agent.status = "ERROR"
            else:
                current_agent.status = "INVALID"
    elif os_current_player == pyspiel.PlayerId.SIMULTANEOUS:
        raise NotImplementedError
    elif os_current_player == pyspiel.PlayerId.TERMINAL:
        pass
    elif os_current_player == pyspiel.PlayerId.CHANCE:
        raise ValueError("Interpreter should not be called at chance nodes.")
    else:
        raise ValueError(f"Unknown OpenSpiel player ID: {os_current_player}")

    # --- Resolve chance nodes by sampling their outcome distribution ---
    while _OS_GLOBAL_STATE.is_chance_node():
        outcomes, probs = zip(*_OS_GLOBAL_STATE.chance_outcomes())
        _OS_GLOBAL_STATE.apply_action(int(np.random.choice(outcomes, p=probs)))

    # --- Update state info ---
    is_terminal = _OS_GLOBAL_STATE.is_terminal()
    # NOTE(review): the trailing None presumably pads for an agent beyond the
    # game's players; confirm whether such an agent can exist.
    agent_returns = _OS_GLOBAL_STATE.returns() + [None]
    next_agent = _OS_GLOBAL_STATE.current_player()
    # The state is not modified inside the loop below, so render it once.
    obs_str = str(_OS_GLOBAL_STATE)

    for i, agent_state in enumerate(kaggle_state):
        input_status = agent_state.status
        if input_status in ["TIMEOUT", "ERROR", "INVALID"]:
            # Failure statuses are sticky and carry no reward.
            status = input_status
            reward = None
        elif is_terminal:
            status = "DONE"
            reward = agent_returns[i]
        elif next_agent == i:
            status = "ACTIVE"
            reward = agent_returns[i]
        else:
            status = "INACTIVE"
            reward = agent_returns[i]
        # NOTE(review): when one agent becomes INVALID/ERROR the others keep
        # ACTIVE/INACTIVE here; confirm how core ends such an episode.

        info_dict = {}
        # Store the applied action in info for potential debugging/analysis
        if os_current_player == i and action_applied is not None:
            info_dict["action_applied"] = action_applied

        legal_actions = _OS_GLOBAL_STATE.legal_actions(i)
        if status == "ACTIVE" and not legal_actions:
            raise ValueError(
                f"Active agent {i} has no legal actions in state {_OS_GLOBAL_STATE}."
            )

        # Apply updates
        obs_update_dict = {
            "observation_string": obs_str,
            "legal_actions": legal_actions,
            "current_player": next_agent,
            "is_terminal": is_terminal,
            "player_id": i,
        }
        for k, v in obs_update_dict.items():
            setattr(agent_state.observation, k, v)
        agent_state.reward = reward
        agent_state.info = info_dict
        agent_state.status = status

    return kaggle_state
228
+
229
+
230
def renderer(state: list[utils.Struct], env: core.Environment) -> str:
    """Kaggle text renderer: returns the observation string of the last agent.

    Args:
        state: Per-agent states for the step being rendered; the last entry's
            `observation["observation_string"]` is used.
        env: The environment; only `env.name` is read, for the error message.

    Returns:
        The observation string, or "<Empty observation string>" when it is
        empty.

    Raises:
        Any error raised while reading the observation is logged together
        with the state and then re-raised with its original traceback.
    """
    try:
        obs_str = state[-1].observation["observation_string"]
    except Exception:  # pylint: disable=broad-exception-caught
        print(f"Error rendering {env.name} at state: {state}.")
        raise
    return obs_str if obs_str else "<Empty observation string>"
238
+
239
+
240
def html_renderer():
    """Returns JavaScript source for a minimal HTML renderer of text observations.

    The returned `renderer(context)` function draws the step number and the
    step's `observation_string` inside a `<pre>` element appended to
    `context.parent`.
    """
    return """
function renderer(context) {
  const { parent, environment, step } = context;
  parent.innerHTML = '';  // Clear previous rendering

  // Guard against a missing step before touching it. The interpreter writes
  // the same observation_string for every agent, so use the first agent that
  // has one.
  const stepData = (environment && environment.steps && environment.steps[step]) || [];
  let obsString = "Observation not available for this step.";
  for (const agentState of stepData) {
    const obs = agentState && agentState.observation;
    if (obs && obs.observation_string !== undefined) {
      obsString = obs.observation_string;
      break;
    }
  }

  // Create a <pre> element to preserve formatting
  const pre = document.createElement("pre");
  pre.style.fontFamily = "monospace";  // Ensure monospace font
  pre.style.margin = "10px";
  pre.style.border = "1px solid #ccc";
  pre.style.padding = "5px";
  pre.style.backgroundColor = "#f0f0f0";

  // Set the text content (safer than innerHTML for plain text)
  pre.textContent = `Step: ${step}\\n\\n${obsString}`;  // Add step number for context

  parent.appendChild(pre);
}
"""
275
+
276
+
277
+ # --- Agents ---
278
+ def random_agent(
279
+ observation: dict[str, Any],
280
+ configuration: dict[str, Any],
281
+ ) -> int:
282
+ """A built-in random agent specifically for OpenSpiel environments."""
283
+ del configuration
284
+ legal_actions = observation.get("legal_actions")
285
+ if not legal_actions:
286
+ return None
287
+ action = random.choice(legal_actions)
288
+ return int(action)
289
+
290
+
291
# Built-in agents available to every wrapped OpenSpiel environment, by name.
agents = dict(random=random_agent)
294
+
295
+
296
def _register_open_spiel_envs(
    games_list: list[str] | None = None,
) -> dict[str, Any]:
    """Builds Kaggle environment definitions for OpenSpiel games.

    Args:
        games_list: OpenSpiel game short names to wrap. Defaults to
            `pyspiel.registered_names()`.

    Returns:
        Mapping from Kaggle env name (`open_spiel_<short_name>` with `-` and
        `.` replaced by `_`) to a dict with `specification`, `interpreter`,
        `renderer`, `html_renderer`, and `agents` entries.
    """
    successfully_loaded_games = []
    skipped_games = []
    registered_envs = {}
    if games_list is None:
        games_list = pyspiel.registered_names()
    for short_name in games_list:
        try:
            game = pyspiel.load_game(short_name)
            game_type = game.get_type()
            # Skip games that provide neither an information-state string
            # nor an observation string.
            if not (
                game_type.provides_information_state_string
                or game_type.provides_observation_string
            ):
                continue
            # `interpreter` only drives player ids 0..n-1, TERMINAL, and
            # CHANCE; it raises for SIMULTANEOUS and any other id, so only
            # turn-based games are registered.
            if game_type.dynamics != pyspiel.GameType.Dynamics.SEQUENTIAL:
                skipped_games.append(short_name)
                continue

            env_name = f"open_spiel_{short_name.replace('-', '_').replace('.', '_')}"
            game_string = str(game)

            game_spec = copy.deepcopy(BASE_SPEC_TEMPLATE)
            game_spec["name"] = env_name
            game_spec["title"] = f"Open Spiel: {short_name}"
            game_spec["description"] = (
                "Kaggle environment wrapper for OpenSpiel games.\n"
                "For game implementation details see:\n"
                "https://github.com/google-deepmind/open_spiel/tree/master/open_spiel/games"
            )
            game_spec["agents"] = [game.num_players()]

            configuration = game_spec["configuration"]
            configuration["episodeSteps"] = (
                game.max_history_length() + DEFAULT_EPISODE_STEP_BUFFER
            )
            configuration["openSpielGameString"]["default"] = game_string
            configuration["openSpielGameName"]["default"] = short_name

            obs_properties = game_spec["observation"]["properties"]
            obs_properties["openSpielGameString"]["default"] = game_string
            obs_properties["openSpielGameName"]["default"] = short_name

            registered_envs[env_name] = {
                "specification": game_spec,
                "interpreter": interpreter,
                "renderer": renderer,
                "html_renderer": html_renderer,
                "agents": agents,
            }
            successfully_loaded_games.append(short_name)
        except Exception:  # pylint: disable=broad-exception-caught
            # Best effort: a game that fails to load or inspect is counted as
            # skipped instead of aborting registration of the remaining games.
            skipped_games.append(short_name)

    print(
        f"Successfully loaded OpenSpiel environments: {len(successfully_loaded_games)}.\n"
        f"OpenSpiel games skipped: {len(skipped_games)}."
    )

    return registered_envs
352
+
353
+
354
# Every loadable OpenSpiel game, keyed by Kaggle env name; built once at import.
registered_open_spiel_envs = _register_open_spiel_envs(games_list=None)
@@ -0,0 +1,18 @@
1
+ import sys
2
+ from kaggle_environments import make
3
+ import open_spiel as open_spiel_env
4
+
5
+
6
def test_envs_load():
    """Registering all OpenSpiel games should yield a non-empty registry."""
    envs = open_spiel_env._register_open_spiel_envs()
    assert envs, "expected at least one OpenSpiel environment to register"
    assert "open_spiel_tic_tac_toe" in envs
9
+
10
+
11
def test_tic_tac_toe_playthrough():
    """Two random agents should play OpenSpiel tic-tac-toe to completion."""
    envs = open_spiel_env._register_open_spiel_envs(["tic_tac_toe"])
    assert "open_spiel_tic_tac_toe" in envs
    env = make("open_spiel_tic_tac_toe", debug=True)
    env.run(["random", "random"])
    env_json = env.toJSON()
    assert env_json["name"] == "open_spiel_tic_tac_toe"
    assert all(status == "DONE" for status in env_json["statuses"])
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: kaggle-environments
3
- Version: 1.16.11
3
+ Version: 1.17.2
4
4
  Summary: Kaggle Environments
5
5
  Home-page: https://github.com/Kaggle/kaggle-environments
6
6
  Author: Kaggle
@@ -10,17 +10,29 @@ Keywords: Kaggle
10
10
  Requires-Python: >=3.8
11
11
  Description-Content-Type: text/markdown
12
12
  License-File: LICENSE
13
- Requires-Dist: Chessnut >=0.4.1
14
- Requires-Dist: Flask >=1.1.2
15
- Requires-Dist: gymnasium ==0.29.0
16
- Requires-Dist: jsonschema >=3.0.1
17
- Requires-Dist: numpy >=1.19.5
18
- Requires-Dist: pettingzoo ==1.24.0
19
- Requires-Dist: requests >=2.25.1
20
- Requires-Dist: scipy >=1.11.2
21
- Requires-Dist: shimmy >=1.2.1
22
- Requires-Dist: stable-baselines3 ==2.1.0
23
- Requires-Dist: transformers >=4.33.1
13
+ Requires-Dist: jsonschema>=3.0.1
14
+ Requires-Dist: Flask>=1.1.2
15
+ Requires-Dist: numpy>=1.19.5
16
+ Requires-Dist: requests>=2.25.1
17
+ Requires-Dist: pettingzoo==1.24.0
18
+ Requires-Dist: gymnasium==0.29.0
19
+ Requires-Dist: stable-baselines3==2.1.0
20
+ Requires-Dist: transformers>=4.33.1
21
+ Requires-Dist: scipy>=1.11.2
22
+ Requires-Dist: shimmy>=1.2.1
23
+ Requires-Dist: Chessnut>=0.4.1
24
+ Requires-Dist: open_spiel>=1.5.0
25
+ Dynamic: author
26
+ Dynamic: author-email
27
+ Dynamic: description
28
+ Dynamic: description-content-type
29
+ Dynamic: home-page
30
+ Dynamic: keywords
31
+ Dynamic: license
32
+ Dynamic: license-file
33
+ Dynamic: requires-dist
34
+ Dynamic: requires-python
35
+ Dynamic: summary
24
36
 
25
37
  # [<img src="https://kaggle.com/static/images/site-logo.png" height="50" style="margin-bottom:-15px" />](https://kaggle.com) Environments
26
38
 
@@ -1,4 +1,4 @@
1
- kaggle_environments/__init__.py,sha256=aTT8xgR7BnpIGdkhn5vptp6y0a5dPlMvW5NEtq_T0gA,1683
1
+ kaggle_environments/__init__.py,sha256=gfZpgAMb07ABEi0ZWy18VtUqHEaY-tBcrDkSitkbcwM,2189
2
2
  kaggle_environments/agent.py,sha256=j9rLnCK_Gy0eRIuvlJ9vcMh3vxn-Wvu-pjCpannOolc,6703
3
3
  kaggle_environments/api.py,sha256=eLBKqr11Ku4tdsMUdUqy74FIVEA_hdV3_QUpX84x3Z8,798
4
4
  kaggle_environments/core.py,sha256=IrEkN9cIA2djBAxI8Sz1GRpGNKjhqbnBdV6irAeTm8Q,27851
@@ -169,13 +169,13 @@ kaggle_environments/envs/lux_ai_s3/lux_ai_s3.json,sha256=wgEUeFM0-XsUZRGNLd2OMrb
169
169
  kaggle_environments/envs/lux_ai_s3/lux_ai_s3.py,sha256=8JT30NPZN5R__azfk0PDXeXlKw6S0KqxIbllVw8RYvM,5815
170
170
  kaggle_environments/envs/lux_ai_s3/test_lux.py,sha256=cfiEv4re7pvZ9TeG9HdvGOhHb0da272w8CDUSZn5bpU,273
171
171
  kaggle_environments/envs/lux_ai_s3/luxai_s3/__init__.py,sha256=2hwayynTiOtSr3V1-gjZfosn0Y3sOSKvNrYHhHeAyhY,28
172
- kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py,sha256=EHfijkn3mtBOUKTL6uw9Wa1YiMdmiC04ZLSI4j9YR0Y,39787
172
+ kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py,sha256=zBsmV-G1bfIeyJC8eOq3VTlXXhZSm4uOWh6NG0hlpGk,40286
173
173
  kaggle_environments/envs/lux_ai_s3/luxai_s3/globals.py,sha256=cDPe9qJirhADf9V5Geftir8ccdXWSmGsmUD8KskJ8JU,281
174
- kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py,sha256=7StYyQcbKqJc-VstOtBDkE4FrL6YBee071jQeQ6e7xE,3009
174
+ kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py,sha256=-lMOd8OeRD_6jAINOByezPGoCSZVZVEIRSxj2BMM3ns,3179
175
175
  kaggle_environments/envs/lux_ai_s3/luxai_s3/profiler.py,sha256=DLwP5zAAyP-eNp6gtr81ketNvHQIfWJr-KLfJUwMiPo,5136
176
176
  kaggle_environments/envs/lux_ai_s3/luxai_s3/pygame_render.py,sha256=ZMnGPFtDA8ysjuyW24ySwlnoQQp7Q8bIimIIVXjWD9Y,10843
177
177
  kaggle_environments/envs/lux_ai_s3/luxai_s3/spaces.py,sha256=BpEUN5NlgfBRnFglWZGxOx-pMtn0oAhYL3v7tON6UA0,951
178
- kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py,sha256=ldzNOD3EHjgE7lJ5uHQO8ObaO8VlMlWBAL_6FSKtYZk,17061
178
+ kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py,sha256=i2WO12f_D9n5z77GgdVxV_JwKCaXnmEfujkvsYPVA9U,18169
179
179
  kaggle_environments/envs/lux_ai_s3/luxai_s3/utils.py,sha256=v_wveZrwyRjkhRHC6a8_dFx5RlTBLrpOOpTPq79LC0k,267
180
180
  kaggle_environments/envs/lux_ai_s3/luxai_s3/wrappers.py,sha256=ec6VtDyOeKHa2khSZtcdIQ3JBl6J5IgPGvlVBLdrzyY,6282
181
181
  kaggle_environments/envs/lux_ai_s3/test_agents/python/agent.py,sha256=bL6ma6dax2w7KIE1YOS4h27xG_plu1eeo876mssAX5Q,3985
@@ -188,6 +188,8 @@ kaggle_environments/envs/mab/agents.py,sha256=vPHNN5oRcbTG3FaW9iYmoeQjufXFJMjYOL
188
188
  kaggle_environments/envs/mab/mab.js,sha256=zsKGVRL9qFyUoukRj-ES5dOh8Wig7UzNf0z5Potw84E,3256
189
189
  kaggle_environments/envs/mab/mab.json,sha256=VAlpjJ7_ytYO648swQW_ICjC5JKTAdmnShuGggeSX4A,2077
190
190
  kaggle_environments/envs/mab/mab.py,sha256=bkSIxkstS98Vr3eOA9kxQkseDqa1MlG2Egfzeaf-8EA,5241
191
+ kaggle_environments/envs/open_spiel/open_spiel.py,sha256=eGoV9EpOz2v_41NSZ6mJMGcwF29Y-0fZA3IuoOd8Psk,11812
192
+ kaggle_environments/envs/open_spiel/test_open_spiel.py,sha256=ZDzXa8te8MCdgJ7d-8IdC8al72r3YPxYIJ7Gjhk-NkM,534
191
193
  kaggle_environments/envs/rps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
192
194
  kaggle_environments/envs/rps/agents.py,sha256=iBtBjPbutWickm-K1EzBEYvLWj5fvU3ks0AYQMYWgEI,2140
193
195
  kaggle_environments/envs/rps/helpers.py,sha256=NUqhJafNSzlC_ArwDIYzbLx15pkmBpzfVuG8Iv4wX9U,966
@@ -201,9 +203,9 @@ kaggle_environments/envs/tictactoe/tictactoe.js,sha256=NZDT-oSG0a6a-rso9Ldh9qkJw
201
203
  kaggle_environments/envs/tictactoe/tictactoe.json,sha256=zMXZ8-fpT7FBhzz2FFBvRLn4XwtngjEqOieMvI6cCj8,1121
202
204
  kaggle_environments/envs/tictactoe/tictactoe.py,sha256=uq3sTHWNMg0dxX2v9pTbJAKM7fwerxQt7OQjCX96m-Y,3657
203
205
  kaggle_environments/static/player.html,sha256=XyVoe0XxMa2MO1fTDY_rjyjzPN-JZgbVwJIDoLSnlw0,23016
204
- kaggle_environments-1.16.11.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
205
- kaggle_environments-1.16.11.dist-info/METADATA,sha256=lZY_4hgG1gBRPGKLxBIB-vgsaCrQd9ASdjGFUQEYFXM,10700
206
- kaggle_environments-1.16.11.dist-info/WHEEL,sha256=m9WAupmBd2JGDsXWQGJgMGXIWbQY3F5c2xBJbBhq0nY,110
207
- kaggle_environments-1.16.11.dist-info/entry_points.txt,sha256=HbVC-LKGQFV6lEEYBYyDTtrkHgdHJUWQ8_qt9KHGqz4,70
208
- kaggle_environments-1.16.11.dist-info/top_level.txt,sha256=v3MMWIPMQFcI-WuF_dJngHWe9Bb2yH_6p4wat1x4gAc,20
209
- kaggle_environments-1.16.11.dist-info/RECORD,,
206
+ kaggle_environments-1.17.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
207
+ kaggle_environments-1.17.2.dist-info/METADATA,sha256=FeHjMl4YWtcFWOflJV9P6vToQ6ZGJ0n8CZ6ijFsfr-E,10955
208
+ kaggle_environments-1.17.2.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
209
+ kaggle_environments-1.17.2.dist-info/entry_points.txt,sha256=HbVC-LKGQFV6lEEYBYyDTtrkHgdHJUWQ8_qt9KHGqz4,70
210
+ kaggle_environments-1.17.2.dist-info/top_level.txt,sha256=v3MMWIPMQFcI-WuF_dJngHWe9Bb2yH_6p4wat1x4gAc,20
211
+ kaggle_environments-1.17.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.41.1)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py2-none-any
5
5
  Tag: py3-none-any