kaggle-environments 1.16.11__py2.py3-none-any.whl → 1.17.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaggle-environments might be problematic. Click here for more details.
- kaggle_environments/__init__.py +18 -8
- kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py +14 -7
- kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py +5 -4
- kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py +33 -17
- kaggle_environments/envs/open_spiel/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/connect_four/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four.js +296 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy.py +86 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy_test.py +57 -0
- kaggle_environments/envs/open_spiel/observation.py +133 -0
- kaggle_environments/envs/open_spiel/open_spiel.py +416 -0
- kaggle_environments/envs/open_spiel/proxy.py +139 -0
- kaggle_environments/envs/open_spiel/proxy_test.py +64 -0
- kaggle_environments/envs/open_spiel/test_open_spiel.py +18 -0
- {kaggle_environments-1.16.11.dist-info → kaggle_environments-1.17.3.dist-info}/METADATA +25 -13
- {kaggle_environments-1.16.11.dist-info → kaggle_environments-1.17.3.dist-info}/RECORD +21 -10
- {kaggle_environments-1.16.11.dist-info → kaggle_environments-1.17.3.dist-info}/WHEEL +1 -1
- {kaggle_environments-1.16.11.dist-info → kaggle_environments-1.17.3.dist-info}/entry_points.txt +0 -0
- {kaggle_environments-1.16.11.dist-info → kaggle_environments-1.17.3.dist-info/licenses}/LICENSE +0 -0
- {kaggle_environments-1.16.11.dist-info → kaggle_environments-1.17.3.dist-info}/top_level.txt +0 -0
kaggle_environments/__init__.py
CHANGED
|
@@ -20,7 +20,7 @@ from .core import *
|
|
|
20
20
|
from .main import http_request
|
|
21
21
|
from . import errors
|
|
22
22
|
|
|
23
|
-
__version__ = "1.
|
|
23
|
+
__version__ = "1.17.3"
|
|
24
24
|
|
|
25
25
|
__all__ = ["Agent", "environments", "errors", "evaluate", "http_request",
|
|
26
26
|
"make", "register", "utils", "__version__",
|
|
@@ -31,13 +31,23 @@ __all__ = ["Agent", "environments", "errors", "evaluate", "http_request",
|
|
|
31
31
|
for name in listdir(utils.envs_path):
|
|
32
32
|
try:
|
|
33
33
|
env = import_module(f".envs.{name}.{name}", __name__)
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
34
|
+
if name == "open_spiel":
|
|
35
|
+
for env_name, env_dict in env.registered_open_spiel_envs.items():
|
|
36
|
+
register(env_name, {
|
|
37
|
+
"agents": env_dict.get("agents"),
|
|
38
|
+
"html_renderer": env_dict.get("html_renderer"),
|
|
39
|
+
"interpreter": env_dict.get("interpreter"),
|
|
40
|
+
"renderer": env_dict.get("renderer"),
|
|
41
|
+
"specification": env_dict.get("specification"),
|
|
42
|
+
})
|
|
43
|
+
else:
|
|
44
|
+
register(name, {
|
|
45
|
+
"agents": getattr(env, "agents", []),
|
|
46
|
+
"html_renderer": getattr(env, "html_renderer", None),
|
|
47
|
+
"interpreter": getattr(env, "interpreter"),
|
|
48
|
+
"renderer": getattr(env, "renderer"),
|
|
49
|
+
"specification": getattr(env, "specification"),
|
|
50
|
+
})
|
|
41
51
|
except Exception as e:
|
|
42
52
|
if "football" not in name:
|
|
43
53
|
print("Loading environment %s failed: %s" % (name, e))
|
|
@@ -148,6 +148,11 @@ class LuxAIS3Env(environment.Environment):
|
|
|
148
148
|
i : max_sensor_range * 2 + 1 - i,
|
|
149
149
|
i : max_sensor_range * 2 + 1 - i,
|
|
150
150
|
].set(val)
|
|
151
|
+
# vision of position at center of update has an extra 10
|
|
152
|
+
update = update.at[
|
|
153
|
+
max_sensor_range,
|
|
154
|
+
max_sensor_range,
|
|
155
|
+
].add(10)
|
|
151
156
|
vision_power_map = jax.lax.dynamic_update_slice(
|
|
152
157
|
vision_power_map,
|
|
153
158
|
update=update + existing_vision_power,
|
|
@@ -227,6 +232,10 @@ class LuxAIS3Env(environment.Environment):
|
|
|
227
232
|
state = state.replace(
|
|
228
233
|
units_mask=(state.units.energy[..., 0] >= 0) & state.units_mask
|
|
229
234
|
)
|
|
235
|
+
|
|
236
|
+
"""spawn relic nodes based on schedule"""
|
|
237
|
+
relic_nodes_mask = (state.steps >= state.relic_spawn_schedule) & (state.relic_spawn_schedule != -1)
|
|
238
|
+
state = state.replace(relic_nodes_mask=relic_nodes_mask)
|
|
230
239
|
|
|
231
240
|
""" process unit movement """
|
|
232
241
|
# 0 is do nothing, 1 is move up, 2 is move right, 3 is move down, 4 is move left, 5 is sap
|
|
@@ -646,11 +655,10 @@ class LuxAIS3Env(environment.Environment):
|
|
|
646
655
|
axis=(0, 1),
|
|
647
656
|
)
|
|
648
657
|
new_tile_types_map = jnp.where(
|
|
649
|
-
state.steps * params.nebula_tile_drift_speed % 1
|
|
658
|
+
(state.steps - 1) * abs(params.nebula_tile_drift_speed) % 1 > state.steps * abs(params.nebula_tile_drift_speed) % 1,
|
|
650
659
|
new_tile_types_map,
|
|
651
660
|
state.map_features.tile_type,
|
|
652
661
|
)
|
|
653
|
-
# new_energy_nodes = state.energy_nodes + jnp.array([1 * jnp.sign(params.energy_node_drift_speed), -1 * jnp.sign(params.energy_node_drift_speed)])
|
|
654
662
|
|
|
655
663
|
energy_node_deltas = jnp.round(
|
|
656
664
|
jax.random.uniform(
|
|
@@ -663,8 +671,6 @@ class LuxAIS3Env(environment.Environment):
|
|
|
663
671
|
energy_node_deltas_symmetric = jnp.stack(
|
|
664
672
|
[-energy_node_deltas[:, 1], -energy_node_deltas[:, 0]], axis=-1
|
|
665
673
|
)
|
|
666
|
-
# TODO symmetric movement
|
|
667
|
-
# energy_node_deltas = jnp.round(jax.random.uniform(key=key, shape=(params.max_energy_nodes // 2, 2), minval=-params.energy_node_drift_magnitude, maxval=params.energy_node_drift_magnitude)).astype(jnp.int16)
|
|
668
674
|
energy_node_deltas = jnp.concatenate(
|
|
669
675
|
(energy_node_deltas, energy_node_deltas_symmetric)
|
|
670
676
|
)
|
|
@@ -677,7 +683,7 @@ class LuxAIS3Env(environment.Environment):
|
|
|
677
683
|
),
|
|
678
684
|
)
|
|
679
685
|
new_energy_nodes = jnp.where(
|
|
680
|
-
state.steps * params.energy_node_drift_speed % 1
|
|
686
|
+
(state.steps - 1) * abs(params.energy_node_drift_speed) % 1 > state.steps * abs(params.energy_node_drift_speed) % 1,
|
|
681
687
|
new_energy_nodes,
|
|
682
688
|
state.energy_nodes,
|
|
683
689
|
)
|
|
@@ -688,7 +694,9 @@ class LuxAIS3Env(environment.Environment):
|
|
|
688
694
|
|
|
689
695
|
# Compute relic scores
|
|
690
696
|
def team_relic_score(unit_counts_map):
|
|
691
|
-
|
|
697
|
+
# not all relic nodes are spawned in yet, but relic nodes map ids are precomputed for all to be spawned relic nodes
|
|
698
|
+
# for efficiency. So we check if the relic node (by id) is spawned in yet. relic nodes mask is always increasing so we can do a simple trick below
|
|
699
|
+
scores = (unit_counts_map > 0) & (state.relic_nodes_map_weights <= state.relic_nodes_mask.sum() // 2) & (state.relic_nodes_map_weights > 0)
|
|
692
700
|
return jnp.sum(scores, dtype=jnp.int32)
|
|
693
701
|
|
|
694
702
|
# note we need to recompue unit counts since units can get removed due to collisions
|
|
@@ -771,7 +779,6 @@ class LuxAIS3Env(environment.Environment):
|
|
|
771
779
|
)
|
|
772
780
|
state = self.compute_energy_features(state, params)
|
|
773
781
|
state = self.compute_sensor_masks(state, params)
|
|
774
|
-
|
|
775
782
|
return self.get_obs(state, params=params, key=key), state
|
|
776
783
|
|
|
777
784
|
@functools.partial(jax.jit, static_argnums=(0,))
|
|
@@ -46,6 +46,7 @@ class EnvParams:
|
|
|
46
46
|
min_energy_per_tile: int = -20
|
|
47
47
|
|
|
48
48
|
max_relic_nodes: int = 6
|
|
49
|
+
"""max relic nodes in the entire map. This number should be tuned carefully as relic node spawning code is hardcoded against this number 6"""
|
|
49
50
|
relic_config_size: int = 5
|
|
50
51
|
fog_of_war: bool = True
|
|
51
52
|
"""
|
|
@@ -87,15 +88,15 @@ class EnvParams:
|
|
|
87
88
|
env_params_ranges = dict(
|
|
88
89
|
# map_type=[1],
|
|
89
90
|
unit_move_cost=list(range(1, 6)),
|
|
90
|
-
unit_sensor_range=
|
|
91
|
-
nebula_tile_vision_reduction=list(range(0,
|
|
92
|
-
nebula_tile_energy_reduction=[0,
|
|
91
|
+
unit_sensor_range=[1, 2, 3, 4],
|
|
92
|
+
nebula_tile_vision_reduction=list(range(0, 8)),
|
|
93
|
+
nebula_tile_energy_reduction=[0, 1, 2, 3, 5, 25],
|
|
93
94
|
unit_sap_cost=list(range(30, 51)),
|
|
94
95
|
unit_sap_range=list(range(3, 8)),
|
|
95
96
|
unit_sap_dropoff_factor=[0.25, 0.5, 1],
|
|
96
97
|
unit_energy_void_factor=[0.0625, 0.125, 0.25, 0.375],
|
|
97
98
|
# map randomizations
|
|
98
|
-
nebula_tile_drift_speed=[-0.05, -0.025, 0.025, 0.05],
|
|
99
|
+
nebula_tile_drift_speed=[-0.15, -0.1, -0.05, -0.025, 0.025, 0.05, 0.1, 0.15],
|
|
99
100
|
energy_node_drift_speed=[0.01, 0.02, 0.03, 0.04, 0.05],
|
|
100
101
|
energy_node_drift_magnitude=list(range(3, 6)),
|
|
101
102
|
)
|
|
@@ -66,7 +66,10 @@ class EnvState:
|
|
|
66
66
|
relic_nodes_mask: chex.Array
|
|
67
67
|
"""Mask of relic nodes in the environment with shape (N, ) for N max relic nodes"""
|
|
68
68
|
relic_nodes_map_weights: chex.Array
|
|
69
|
-
"""Map of relic nodes in the environment with shape (H, W) for H height, W width.
|
|
69
|
+
"""Map of relic nodes in the environment with shape (H, W) for H height, W width. Each element is equal to the 1-indexed id of the relic node. This is generated from other state"""
|
|
70
|
+
|
|
71
|
+
relic_spawn_schedule: chex.Array
|
|
72
|
+
"""Relic spawn schedule in the environment with shape (N, ) for N max relic nodes. Elements are the game timestep at which the relic node spawns"""
|
|
70
73
|
|
|
71
74
|
map_features: MapTile
|
|
72
75
|
"""Map features in the environment with shape (W, H, 2) for W width, H height
|
|
@@ -121,7 +124,7 @@ class EnvObs:
|
|
|
121
124
|
|
|
122
125
|
def serialize_env_states(env_states: list[EnvState]):
|
|
123
126
|
def serialize_array(root: EnvState, arr, key_path: str = ""):
|
|
124
|
-
if key_path in ["sensor_mask", "relic_nodes_mask", "energy_nodes_mask", "energy_node_fns", "relic_nodes_map_weights"]:
|
|
127
|
+
if key_path in ["sensor_mask", "relic_nodes_mask", "energy_nodes_mask", "energy_node_fns", "relic_nodes_map_weights", "relic_spawn_schedule"]:
|
|
125
128
|
return None
|
|
126
129
|
if key_path == "relic_nodes":
|
|
127
130
|
return root.relic_nodes[root.relic_nodes_mask].tolist()
|
|
@@ -186,9 +189,10 @@ def gen_state(key: chex.PRNGKey, env_params: EnvParams, max_units: int, num_team
|
|
|
186
189
|
|
|
187
190
|
# TODO (this could be optimized better)
|
|
188
191
|
def update_relic_node(relic_nodes_map_weights, relic_data):
|
|
189
|
-
relic_node, relic_node_config, mask = relic_data
|
|
192
|
+
relic_node, relic_node_config, mask, relic_node_id = relic_data
|
|
190
193
|
start_y = relic_node[1] - relic_config_size // 2
|
|
191
194
|
start_x = relic_node[0] - relic_config_size // 2
|
|
195
|
+
|
|
192
196
|
for dy in range(relic_config_size):
|
|
193
197
|
for dx in range(relic_config_size):
|
|
194
198
|
y, x = start_y + dy, start_x + dx
|
|
@@ -196,14 +200,20 @@ def gen_state(key: chex.PRNGKey, env_params: EnvParams, max_units: int, num_team
|
|
|
196
200
|
jnp.logical_and(y >= 0, x >= 0),
|
|
197
201
|
jnp.logical_and(y < map_height, x < map_width),
|
|
198
202
|
)
|
|
203
|
+
# ensure we don't override previous spawns
|
|
204
|
+
has_points = jnp.logical_and(
|
|
205
|
+
relic_nodes_map_weights > 0,
|
|
206
|
+
relic_nodes_map_weights <= relic_node_id + 1
|
|
207
|
+
)
|
|
199
208
|
relic_nodes_map_weights = jnp.where(
|
|
200
|
-
valid_pos & mask,
|
|
201
|
-
relic_nodes_map_weights.at[x, y].
|
|
209
|
+
valid_pos & mask & jnp.logical_not(has_points) & relic_node_config[dx, dy],
|
|
210
|
+
relic_nodes_map_weights.at[x, y].set(relic_node_config[dx, dy].astype(jnp.int16) * (relic_node_id + 1)),
|
|
202
211
|
relic_nodes_map_weights,
|
|
203
212
|
)
|
|
204
213
|
return relic_nodes_map_weights, None
|
|
205
214
|
|
|
206
215
|
# this is really slow...
|
|
216
|
+
|
|
207
217
|
relic_nodes_map_weights, _ = jax.lax.scan(
|
|
208
218
|
update_relic_node,
|
|
209
219
|
relic_nodes_map_weights,
|
|
@@ -211,8 +221,10 @@ def gen_state(key: chex.PRNGKey, env_params: EnvParams, max_units: int, num_team
|
|
|
211
221
|
generated["relic_nodes"],
|
|
212
222
|
generated["relic_node_configs"],
|
|
213
223
|
generated["relic_nodes_mask"],
|
|
224
|
+
jnp.arange(max_relic_nodes, dtype=jnp.int16) % (max_relic_nodes // 2),
|
|
214
225
|
),
|
|
215
226
|
)
|
|
227
|
+
|
|
216
228
|
state = EnvState(
|
|
217
229
|
units=UnitState(position=jnp.zeros(shape=(num_teams, max_units, 2), dtype=jnp.int16), energy=jnp.zeros(shape=(num_teams, max_units, 1), dtype=jnp.int16)),
|
|
218
230
|
units_mask=jnp.zeros(
|
|
@@ -225,9 +237,10 @@ def gen_state(key: chex.PRNGKey, env_params: EnvParams, max_units: int, num_team
|
|
|
225
237
|
energy_nodes_mask=generated["energy_nodes_mask"],
|
|
226
238
|
# energy_field=jnp.zeros(shape=(params.map_height, params.map_width), dtype=jnp.int16),
|
|
227
239
|
relic_nodes=generated["relic_nodes"],
|
|
228
|
-
relic_nodes_mask=
|
|
240
|
+
relic_nodes_mask=jnp.zeros(shape=(max_relic_nodes), dtype=jnp.bool), # as relic nodes are spawn in, we start with them all invisible.
|
|
229
241
|
relic_node_configs=generated["relic_node_configs"],
|
|
230
242
|
relic_nodes_map_weights=relic_nodes_map_weights,
|
|
243
|
+
relic_spawn_schedule=generated["relic_spawn_schedule"],
|
|
231
244
|
sensor_mask=jnp.zeros(
|
|
232
245
|
shape=(num_teams, map_height, map_width),
|
|
233
246
|
dtype=jnp.bool,
|
|
@@ -276,7 +289,7 @@ def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int
|
|
|
276
289
|
flat_indices = jnp.argsort(noise.ravel())[-max_relic_nodes // 2:] # Get indices of two highest values
|
|
277
290
|
highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape))
|
|
278
291
|
|
|
279
|
-
# relic nodes have a fixed density of
|
|
292
|
+
# relic nodes have a fixed density of 20% nearby tiles can yield points
|
|
280
293
|
relic_node_configs = (
|
|
281
294
|
jax.random.randint(
|
|
282
295
|
key,
|
|
@@ -291,18 +304,15 @@ def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int
|
|
|
291
304
|
>= 7.5
|
|
292
305
|
)
|
|
293
306
|
highest_positions = highest_positions.astype(jnp.int16)
|
|
294
|
-
relic_nodes_mask = relic_nodes_mask.at[0].set(True)
|
|
295
|
-
relic_nodes_mask = relic_nodes_mask.at[1].set(True)
|
|
296
307
|
mirrored_positions = jnp.stack([map_width - highest_positions[:, 1] - 1, map_height - highest_positions[:, 0] - 1], dtype=jnp.int16, axis=-1)
|
|
297
308
|
relic_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0)
|
|
298
309
|
|
|
299
310
|
key, subkey = jax.random.split(key)
|
|
300
|
-
|
|
301
|
-
relic_nodes_mask_half =
|
|
302
|
-
relic_nodes_mask =
|
|
303
|
-
relic_nodes_mask = relic_nodes_mask.at[max_relic_nodes // 2:].set(relic_nodes_mask_half)
|
|
304
|
-
# import ipdb;ipdb.set_trace()
|
|
311
|
+
num_spawned_relic_nodes = jax.random.randint(key, (1, ), minval=1, maxval=(max_relic_nodes // 2) + 1)
|
|
312
|
+
relic_nodes_mask_half = jnp.arange(max_relic_nodes // 2) < num_spawned_relic_nodes
|
|
313
|
+
relic_nodes_mask = jnp.concat([relic_nodes_mask_half, relic_nodes_mask_half], axis=0)
|
|
305
314
|
relic_node_configs = relic_node_configs.at[max_relic_nodes // 2:].set(relic_node_configs[:max_relic_nodes // 2].transpose(0, 2, 1)[:, ::-1, ::-1])
|
|
315
|
+
# note that relic nodes mask is always increasing.
|
|
306
316
|
|
|
307
317
|
### Generate energy nodes ###
|
|
308
318
|
key, subkey = jax.random.split(key)
|
|
@@ -318,7 +328,6 @@ def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int
|
|
|
318
328
|
energy_nodes_mask = energy_nodes_mask.at[:max_energy_nodes // 2].set(energy_nodes_mask_half)
|
|
319
329
|
energy_nodes_mask = energy_nodes_mask.at[max_energy_nodes // 2:].set(energy_nodes_mask_half)
|
|
320
330
|
|
|
321
|
-
# TODO (stao): provide more randomization options for energy node functions.
|
|
322
331
|
energy_node_fns = jnp.array(
|
|
323
332
|
[
|
|
324
333
|
[0, 1.2, 1, 4],
|
|
@@ -331,8 +340,14 @@ def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int
|
|
|
331
340
|
# [1, 4, 0, 0]
|
|
332
341
|
]
|
|
333
342
|
)
|
|
334
|
-
|
|
335
|
-
#
|
|
343
|
+
|
|
344
|
+
# generate a random relic spawn schedule
|
|
345
|
+
# if number is -1, then relic node is never spawned, otherwise spawn at that game timestep
|
|
346
|
+
assert max_relic_nodes == 6, "random map generation is hardcoded to use 6 relic nodes at most per map"
|
|
347
|
+
key, subkey = jax.random.split(key)
|
|
348
|
+
relic_spawn_schedule_half = jax.random.randint(key, (max_relic_nodes //2, ), minval=0, maxval=params.max_steps_in_match // 2) + jnp.arange(3) * (params.max_steps_in_match + 1)
|
|
349
|
+
relic_spawn_schedule = jnp.concat([relic_spawn_schedule_half, relic_spawn_schedule_half], axis=0)
|
|
350
|
+
relic_spawn_schedule = jnp.where(relic_nodes_mask, relic_spawn_schedule, -1)
|
|
336
351
|
|
|
337
352
|
|
|
338
353
|
return dict(
|
|
@@ -343,6 +358,7 @@ def gen_map(key: chex.PRNGKey, params: EnvParams, map_type: int, map_height: int
|
|
|
343
358
|
energy_nodes_mask=energy_nodes_mask,
|
|
344
359
|
relic_nodes_mask=relic_nodes_mask,
|
|
345
360
|
relic_node_configs=relic_node_configs,
|
|
361
|
+
relic_spawn_schedule=relic_spawn_schedule,
|
|
346
362
|
)
|
|
347
363
|
def interpolant(t):
|
|
348
364
|
return t*t*t*(t*(t*6 - 15) + 10)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,296 @@
|
|
|
1
|
+
function renderer(options) {
  // Draws one step of an OpenSpiel Connect Four episode into `options.parent`:
  // a title, an SVG board and a status panel. The view is rebuilt from scratch
  // on every call (the parent's contents are cleared first).
  //
  //   options.environment - episode object; `environment.steps[step]` is read
  //                         as an array of per-agent entries and the LAST
  //                         entry's `observation` is used as the state source.
  //   options.step        - index into `environment.steps`.
  //   options.parent      - DOM element that receives the rendered view.
  //
  // Game state is taken, in order of preference, from
  // `observation.observation_string` (JSON text), `observation.json` (JSON
  // text), or the observation object itself when it already carries a 6x7
  // `board` and a `current_player`. Cell values "o"/"x" (trimmed,
  // case-insensitive) are drawn yellow/red; anything else is drawn empty.
  const { environment, step, parent } = options;

  const DEFAULT_NUM_ROWS = 6;
  const DEFAULT_NUM_COLS = 7;
  const PLAYER_SYMBOLS = ['O', 'X']; // O: Player 0 (Yellow), X: Player 1 (Red)
  const PLAYER_COLORS = ['#facc15', '#ef4444']; // Yellow for 'O', Red for 'X'
  const EMPTY_CELL_COLOR = '#e5e7eb';
  const ERROR_CELL_COLOR = '#FF00FF'; // Magenta: the data row was missing or malformed
  const BOARD_COLOR = '#3b82f6';

  const SVG_NS = "http://www.w3.org/2000/svg";
  const CELL_UNIT_SIZE = 100;
  const CIRCLE_RADIUS = CELL_UNIT_SIZE * 0.42;

  const hasDocument = typeof document !== 'undefined';

  let currentBoardSvgElement = null;
  let currentStatusTextElement = null;
  let currentWinnerTextElement = null;
  let currentMessageBoxElement = hasDocument ? document.getElementById('messageBox') : null;

  // Creates the fixed-position toast element used by _showMessage and
  // attaches it to document.body.
  function _createMessageBox() {
    const box = document.createElement('div');
    box.id = 'messageBox';
    box.style.position = 'fixed';
    box.style.top = '10px';
    box.style.left = '50%';
    box.style.transform = 'translateX(-50%)';
    box.style.padding = '0.75rem 1rem';
    box.style.borderRadius = '0.375rem';
    box.style.boxShadow = '0 2px 4px rgba(0,0,0,0.1)';
    box.style.zIndex = '1000';
    box.style.opacity = '0';
    box.style.transition = 'opacity 0.3s ease-in-out, background-color 0.3s';
    box.style.fontSize = '0.875rem';
    box.style.fontFamily = "'Inter', sans-serif";
    document.body.appendChild(box);
    return box;
  }

  // Shows `message` in the toast for `duration` ms; red for type 'error',
  // green otherwise. No-op when there is no document body.
  function _showMessage(message, type = 'info', duration = 3000) {
    if (!hasDocument || !document.body) return;
    if (!currentMessageBoxElement) {
      currentMessageBoxElement = _createMessageBox();
    }
    // The toast lives in document.body. If `parent` is the body itself, the
    // reset in _ensureRendererElements detached it, so re-attach before use.
    if (currentMessageBoxElement.isConnected === false) {
      document.body.appendChild(currentMessageBoxElement);
    }
    const box = currentMessageBoxElement;
    box.textContent = message;
    box.style.backgroundColor = type === 'error' ? '#ef4444' : '#10b981';
    box.style.color = 'white';
    box.style.opacity = '1';
    // Cancel a pending hide from an earlier message (possibly from a previous
    // renderer call, hence the id is kept on the element) so this message
    // stays visible for its full duration.
    clearTimeout(box._hideTimerId);
    box._hideTimerId = setTimeout(() => { box.style.opacity = '0'; }, duration);
  }

  // Clears `parentElementToClear` and builds the title, a rows x cols SVG board
  // of empty circles (ids `cell-<visualRow>-<col>`, row 0 at the top) and the
  // status panel. Returns false when there is no parent or no DOM.
  function _ensureRendererElements(parentElementToClear, rows, cols) {
    if (!parentElementToClear || !hasDocument) return false;
    parentElementToClear.innerHTML = '';

    // Derive the drawing area from the requested board size so the background
    // and the circles always agree.
    const viewBoxWidth = cols * CELL_UNIT_SIZE;
    const viewBoxHeight = rows * CELL_UNIT_SIZE;

    const container = document.createElement('div');
    container.style.display = 'flex';
    container.style.flexDirection = 'column';
    container.style.alignItems = 'center';
    container.style.padding = '20px';
    container.style.boxSizing = 'border-box';
    container.style.width = '100%';
    container.style.height = '100%';
    container.style.fontFamily = "'Inter', sans-serif";

    const title = document.createElement('h1');
    title.textContent = 'Connect Four';
    title.style.fontSize = '1.875rem';
    title.style.fontWeight = 'bold';
    title.style.marginBottom = '1rem';
    title.style.textAlign = 'center';
    title.style.color = '#2563eb';
    container.appendChild(title);

    currentBoardSvgElement = document.createElementNS(SVG_NS, "svg");
    currentBoardSvgElement.setAttribute("viewBox", `0 0 ${viewBoxWidth} ${viewBoxHeight}`);
    currentBoardSvgElement.setAttribute("preserveAspectRatio", "xMidYMid meet");
    currentBoardSvgElement.style.width = "auto";
    currentBoardSvgElement.style.maxWidth = "500px";
    currentBoardSvgElement.style.maxHeight = `calc(100vh - 200px)`;
    currentBoardSvgElement.style.aspectRatio = `${cols} / ${rows}`;
    currentBoardSvgElement.style.display = "block";
    currentBoardSvgElement.style.margin = "0 auto 20px auto";

    const boardBgRect = document.createElementNS(SVG_NS, "rect");
    boardBgRect.setAttribute("x", "0");
    boardBgRect.setAttribute("y", "0");
    boardBgRect.setAttribute("width", viewBoxWidth.toString());
    boardBgRect.setAttribute("height", viewBoxHeight.toString());
    boardBgRect.setAttribute("fill", BOARD_COLOR);
    boardBgRect.setAttribute("rx", (CELL_UNIT_SIZE * 0.1).toString());
    currentBoardSvgElement.appendChild(boardBgRect);

    // SVG circles are created with (0,0) being the top-left visual circle.
    for (let r_visual = 0; r_visual < rows; r_visual++) {
      for (let c_visual = 0; c_visual < cols; c_visual++) {
        const circle = document.createElementNS(SVG_NS, "circle");
        const cx = c_visual * CELL_UNIT_SIZE + CELL_UNIT_SIZE / 2;
        const cy = r_visual * CELL_UNIT_SIZE + CELL_UNIT_SIZE / 2;
        circle.setAttribute("id", `cell-${r_visual}-${c_visual}`);
        circle.setAttribute("cx", cx.toString());
        circle.setAttribute("cy", cy.toString());
        circle.setAttribute("r", CIRCLE_RADIUS.toString());
        circle.setAttribute("fill", EMPTY_CELL_COLOR);
        currentBoardSvgElement.appendChild(circle);
      }
    }
    container.appendChild(currentBoardSvgElement);

    const statusContainer = document.createElement('div');
    statusContainer.style.padding = '10px 15px';
    statusContainer.style.backgroundColor = 'white';
    statusContainer.style.borderRadius = '8px';
    statusContainer.style.boxShadow = '0 4px 6px -1px rgba(0,0,0,0.1), 0 2px 4px -1px rgba(0,0,0,0.06)';
    statusContainer.style.textAlign = 'center';
    statusContainer.style.width = 'auto';
    statusContainer.style.minWidth = '200px';
    statusContainer.style.maxWidth = '90vw';
    container.appendChild(statusContainer);

    currentStatusTextElement = document.createElement('p');
    currentStatusTextElement.style.fontSize = '1.1rem';
    currentStatusTextElement.style.fontWeight = '600';
    currentStatusTextElement.style.margin = '0 0 5px 0';
    statusContainer.appendChild(currentStatusTextElement);

    currentWinnerTextElement = document.createElement('p');
    currentWinnerTextElement.style.fontSize = '1.25rem';
    currentWinnerTextElement.style.fontWeight = '700';
    currentWinnerTextElement.style.margin = '5px 0 0 0';
    statusContainer.appendChild(currentWinnerTextElement);

    parentElementToClear.appendChild(container);

    // Guard: document.body can be null (e.g. script run before <body> exists).
    if (document.body && !document.body.hasAttribute('data-renderer-initialized')) {
      document.body.setAttribute('data-renderer-initialized', 'true');
    }
    return true;
  }

  // Returns the SVG circle at the given visual position, or null.
  function _cellAt(visualRow, col) {
    return currentBoardSvgElement.querySelector(`#cell-${visualRow}-${col}`);
  }

  // Maps a player/cell value ("o"/"x", case-insensitive) to its display symbol
  // and colour; returns null for anything else.
  function _playerDisplay(value) {
    const key = String(value).toLowerCase();
    if (key === 'o') return { symbol: PLAYER_SYMBOLS[0], color: PLAYER_COLORS[0] };
    if (key === 'x') return { symbol: PLAYER_SYMBOLS[1], color: PLAYER_COLORS[1] };
    return null;
  }

  // Paints `gameStateToDisplay.board` onto the SVG and fills the status panel.
  // A missing/empty board resets every cell to empty and shows a waiting text.
  function _renderBoardDisplay_svg(gameStateToDisplay, displayRows, displayCols) {
    if (!currentBoardSvgElement || !currentStatusTextElement || !currentWinnerTextElement) return;

    if (!gameStateToDisplay || !Array.isArray(gameStateToDisplay.board) || gameStateToDisplay.board.length === 0) {
      currentStatusTextElement.textContent = "Waiting for game data...";
      currentWinnerTextElement.textContent = "";
      for (let r_visual = 0; r_visual < displayRows; r_visual++) {
        for (let c_visual = 0; c_visual < displayCols; c_visual++) {
          const circleElement = _cellAt(r_visual, c_visual);
          if (circleElement) circleElement.setAttribute("fill", EMPTY_CELL_COLOR);
        }
      }
      return;
    }

    const { board, current_player, is_terminal, winner } = gameStateToDisplay;

    for (let r_data = 0; r_data < displayRows; r_data++) {
      const dataRow = board[r_data];
      // Data row 0 is drawn at the bottom of the SVG (visual row displayRows - 1).
      const visualRow = (displayRows - 1) - r_data;
      const rowIsValid = Array.isArray(dataRow) && dataRow.length === displayCols;
      for (let c = 0; c < displayCols; c++) {
        const circleElement = _cellAt(visualRow, c);
        if (!circleElement) continue;
        if (!rowIsValid) {
          circleElement.setAttribute("fill", ERROR_CELL_COLOR);
          continue;
        }
        const player = _playerDisplay(String(dataRow[c]).trim());
        circleElement.setAttribute("fill", player ? player.color : EMPTY_CELL_COLOR);
      }
    }

    currentStatusTextElement.innerHTML = '';
    currentWinnerTextElement.innerHTML = '';
    if (is_terminal) {
      currentStatusTextElement.textContent = "Game Over!";
      if (winner === null || winner === undefined) {
        currentWinnerTextElement.textContent = "Game ended.";
      } else if (String(winner).toLowerCase() === 'draw') {
        currentWinnerTextElement.textContent = "It's a Draw!";
      } else {
        const winnerDisplay = _playerDisplay(winner);
        if (winnerDisplay) {
          // Only constant symbol/colour values are interpolated into the markup.
          currentWinnerTextElement.innerHTML = `Player <span style="color: ${winnerDisplay.color}; font-weight: bold;">${winnerDisplay.symbol}</span> Wins!`;
        } else {
          currentWinnerTextElement.textContent = `Winner: ${String(winner).toUpperCase()}`;
        }
      }
    } else {
      const playerDisplay = _playerDisplay(current_player);
      if (playerDisplay) {
        currentStatusTextElement.innerHTML = `Current Player: <span style="color: ${playerDisplay.color}; font-weight: bold;">${playerDisplay.symbol}</span>`;
      } else {
        currentStatusTextElement.textContent = "Waiting for player...";
      }
    }
  }

  // Parses a JSON text field; returns null for non-strings, blank strings, or
  // parse failures (the latter also raise an error toast tagged with `label`).
  function _parseStateJson(text, label) {
    if (typeof text !== 'string' || text.trim() === '') return null;
    try {
      return JSON.parse(text);
    } catch (e) {
      _showMessage(`Error: Corrupted game state (${label}).`, 'error');
      return null;
    }
  }

  // --- Main execution logic ---
  if (!_ensureRendererElements(parent, DEFAULT_NUM_ROWS, DEFAULT_NUM_COLS)) {
    if (parent && typeof parent.innerHTML !== 'undefined') {
      parent.innerHTML = "<p style='color:red; font-family: sans-serif;'>Critical Error: Renderer element setup failed.</p>";
    }
    return;
  }

  // Shows an empty board with the given status line.
  const showPlaceholder = (statusText) => {
    _renderBoardDisplay_svg(null, DEFAULT_NUM_ROWS, DEFAULT_NUM_COLS);
    currentStatusTextElement.textContent = statusText;
  };

  if (!environment || !environment.steps || !environment.steps[step]) {
    showPlaceholder("Initializing environment...");
    return;
  }

  const currentStepAgents = environment.steps[step];
  if (!Array.isArray(currentStepAgents) || currentStepAgents.length === 0) {
    showPlaceholder("Waiting for agent data...");
    return;
  }

  // The last agent entry carries the observation used for rendering.
  const gameMasterAgent = currentStepAgents[currentStepAgents.length - 1];
  if (!gameMasterAgent || typeof gameMasterAgent.observation === 'undefined') {
    showPlaceholder("Waiting for observation data...");
    return;
  }
  const observationForRenderer = gameMasterAgent.observation;

  let gameSpecificState = null;
  if (observationForRenderer) {
    gameSpecificState = _parseStateJson(observationForRenderer.observation_string, 'obs_string');
    if (!gameSpecificState) {
      gameSpecificState = _parseStateJson(observationForRenderer.json, 'json');
    }
    // Fall back to the raw observation when it already has the expected shape.
    if (!gameSpecificState &&
        Array.isArray(observationForRenderer.board) &&
        typeof observationForRenderer.current_player !== 'undefined' &&
        observationForRenderer.board.length === DEFAULT_NUM_ROWS &&
        Array.isArray(observationForRenderer.board[0]) &&
        observationForRenderer.board[0].length === DEFAULT_NUM_COLS) {
      gameSpecificState = observationForRenderer;
    }
  }

  _renderBoardDisplay_svg(gameSpecificState, DEFAULT_NUM_ROWS, DEFAULT_NUM_COLS);
}
|