kaggle-environments 1.15.2__py2.py3-none-any.whl → 1.16.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaggle-environments might be problematic. Click here for more details.

Files changed (69)
  1. kaggle_environments/__init__.py +1 -1
  2. kaggle_environments/envs/chess/chess.js +63 -22
  3. kaggle_environments/envs/chess/chess.json +4 -4
  4. kaggle_environments/envs/chess/chess.py +209 -51
  5. kaggle_environments/envs/chess/test_chess.py +43 -1
  6. kaggle_environments/envs/connectx/connectx.ipynb +3183 -0
  7. kaggle_environments/envs/football/football.ipynb +75 -0
  8. kaggle_environments/envs/halite/halite.ipynb +44736 -0
  9. kaggle_environments/envs/kore_fleets/kore_fleets.ipynb +112 -0
  10. kaggle_environments/envs/kore_fleets/starter_bots/java/Bot.java +54 -0
  11. kaggle_environments/envs/kore_fleets/starter_bots/java/README.md +26 -0
  12. kaggle_environments/envs/kore_fleets/starter_bots/java/jars/hamcrest-core-1.3.jar +0 -0
  13. kaggle_environments/envs/kore_fleets/starter_bots/java/jars/junit-4.13.2.jar +0 -0
  14. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Board.java +518 -0
  15. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Cell.java +61 -0
  16. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Configuration.java +24 -0
  17. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Direction.java +166 -0
  18. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Fleet.java +72 -0
  19. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/KoreJson.java +97 -0
  20. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Observation.java +72 -0
  21. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Pair.java +13 -0
  22. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Player.java +68 -0
  23. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Point.java +65 -0
  24. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Shipyard.java +70 -0
  25. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/ShipyardAction.java +59 -0
  26. kaggle_environments/envs/kore_fleets/starter_bots/java/test/BoardTest.java +567 -0
  27. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ConfigurationTest.java +25 -0
  28. kaggle_environments/envs/kore_fleets/starter_bots/java/test/KoreJsonTest.java +62 -0
  29. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ObservationTest.java +46 -0
  30. kaggle_environments/envs/kore_fleets/starter_bots/java/test/PointTest.java +21 -0
  31. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ShipyardTest.java +22 -0
  32. kaggle_environments/envs/kore_fleets/starter_bots/ts/README.md +55 -0
  33. kaggle_environments/envs/lux_ai_2021/README.md +3 -0
  34. kaggle_environments/envs/lux_ai_2021/dimensions/754.js.LICENSE.txt +296 -0
  35. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/simple.tar.gz +0 -0
  36. kaggle_environments/envs/lux_ai_2021/testing.md +23 -0
  37. kaggle_environments/envs/lux_ai_2021/todo.md.og +18 -0
  38. kaggle_environments/envs/lux_ai_s2/.gitignore +1 -0
  39. kaggle_environments/envs/lux_ai_s2/README.md +21 -0
  40. kaggle_environments/envs/lux_ai_s2/luxai_s2/.DS_Store +0 -0
  41. kaggle_environments/envs/lux_ai_s2/luxai_s2/map_generator/.DS_Store +0 -0
  42. kaggle_environments/envs/lux_ai_s3/README.md +21 -0
  43. kaggle_environments/envs/lux_ai_s3/agents.py +4 -0
  44. kaggle_environments/envs/lux_ai_s3/index.html +42 -0
  45. kaggle_environments/envs/lux_ai_s3/lux_ai_s3.json +47 -0
  46. kaggle_environments/envs/lux_ai_s3/lux_ai_s3.py +138 -0
  47. kaggle_environments/envs/lux_ai_s3/luxai_s3/__init__.py +1 -0
  48. kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py +924 -0
  49. kaggle_environments/envs/lux_ai_s3/luxai_s3/globals.py +13 -0
  50. kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py +101 -0
  51. kaggle_environments/envs/lux_ai_s3/luxai_s3/profiler.py +140 -0
  52. kaggle_environments/envs/lux_ai_s3/luxai_s3/pygame_render.py +270 -0
  53. kaggle_environments/envs/lux_ai_s3/luxai_s3/spaces.py +30 -0
  54. kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py +399 -0
  55. kaggle_environments/envs/lux_ai_s3/luxai_s3/utils.py +12 -0
  56. kaggle_environments/envs/lux_ai_s3/luxai_s3/wrappers.py +187 -0
  57. kaggle_environments/envs/lux_ai_s3/test_agents/python/agent.py +71 -0
  58. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/__init__.py +0 -0
  59. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/kit.py +27 -0
  60. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/utils.py +17 -0
  61. kaggle_environments/envs/lux_ai_s3/test_agents/python/main.py +53 -0
  62. kaggle_environments/envs/lux_ai_s3/test_lux.py +9 -0
  63. kaggle_environments/envs/tictactoe/tictactoe.ipynb +1393 -0
  64. {kaggle_environments-1.15.2.dist-info → kaggle_environments-1.16.0.dist-info}/METADATA +2 -2
  65. {kaggle_environments-1.15.2.dist-info → kaggle_environments-1.16.0.dist-info}/RECORD +69 -11
  66. {kaggle_environments-1.15.2.dist-info → kaggle_environments-1.16.0.dist-info}/WHEEL +1 -1
  67. {kaggle_environments-1.15.2.dist-info → kaggle_environments-1.16.0.dist-info}/LICENSE +0 -0
  68. {kaggle_environments-1.15.2.dist-info → kaggle_environments-1.16.0.dist-info}/entry_points.txt +0 -0
  69. {kaggle_environments-1.15.2.dist-info → kaggle_environments-1.16.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,924 @@
1
+ import functools
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import chex
5
+ import gymnax
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ from gymnax.environments import environment, spaces
10
+ from jax import lax
11
+
12
+ from luxai_s3.params import EnvParams, env_params_ranges
13
+ from luxai_s3.spaces import MultiDiscrete
14
+ from luxai_s3.state import (
15
+ ASTEROID_TILE,
16
+ ENERGY_NODE_FNS,
17
+ NEBULA_TILE,
18
+ EnvObs,
19
+ EnvState,
20
+ MapTile,
21
+ UnitState,
22
+ gen_state
23
+ )
24
+ from luxai_s3.pygame_render import LuxAIPygameRenderer
25
+
26
+
27
+ class LuxAIS3Env(environment.Environment):
28
+ def __init__(
29
+ self, auto_reset=False, fixed_env_params: EnvParams = EnvParams(), **kwargs
30
+ ):
31
+ super().__init__(**kwargs)
32
+ self.renderer = LuxAIPygameRenderer()
33
+ self.auto_reset = auto_reset
34
+ self.fixed_env_params = fixed_env_params
35
+ """fixed env params for concrete/static values. Necessary for jit/vmap capability with randomly sampled maps which must of consistent shape"""
36
+
37
+ @property
38
+ def default_params(self) -> EnvParams:
39
+ params = EnvParams()
40
+ params = jax.tree_map(jax.numpy.array, params)
41
+ return params
42
+
43
    def compute_unit_counts_map(self, state: EnvState, params: EnvParams):
        """Return per-team unit occupancy counts for every map tile.

        Output shape is (num_teams, map_width, map_height), dtype int16;
        entry [t, x, y] is the number of alive team-t units on tile (x, y).
        Units masked out in ``state.units_mask`` contribute nothing.
        """
        # map of total units per team on each tile, shape (num_teams, map_width, map_height)
        unit_counts_map = jnp.zeros(
            (self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height), dtype=jnp.int16
        )

        def update_unit_counts_map(unit_position, unit_mask, unit_counts_map):
            # Add 1 at the unit's tile when the unit is alive (mask True), else 0.
            unit_counts_map = unit_counts_map.at[
                unit_position[0], unit_position[1]
            ].add(unit_mask.astype(jnp.int16))
            return unit_counts_map

        for t in range(self.fixed_env_params.num_teams):
            # The vmapped call broadcasts the (all-zero) team map with in_axes=None,
            # producing one single-increment copy per unit; summing over the unit
            # axis then yields the per-tile counts. This only stays correct because
            # the base map passed in is all zeros.
            unit_counts_map = unit_counts_map.at[t].add(
                jnp.sum(
                    jax.vmap(update_unit_counts_map, in_axes=(0, 0, None), out_axes=0)(
                        state.units.position[t], state.units_mask[t], unit_counts_map[t]
                    ),
                    axis=0,
                    dtype=jnp.int16
                )
            )
        return unit_counts_map
66
+
67
    def compute_energy_features(self, state: EnvState, params: EnvParams):
        """Recompute the per-tile energy field from the energy nodes.

        Each active energy node contributes a field over the whole map via its
        node function (selected through ``lax.switch`` from ``ENERGY_NODE_FNS``
        applied to the tile's distance to the node). The summed field is
        mean-shifted upward when too low, rounded to int16, clipped to
        [min_energy_per_tile, max_energy_per_tile], and written back into
        ``state.map_features.energy``. Returns the updated state.
        """
        # first compute an array of shape (map_height, map_width, num_energy_nodes) with values equal to the distance of the tile to the energy node
        mm = jnp.meshgrid(jnp.arange(self.fixed_env_params.map_width), jnp.arange(self.fixed_env_params.map_height))
        mm = jnp.stack([mm[0], mm[1]]).T.astype(jnp.int16)  # mm[x, y] gives [x, y]
        distances_to_nodes = jax.vmap(lambda pos: jnp.linalg.norm(mm - pos, axis=-1))(
            state.energy_nodes
        )

        def compute_energy_field(node_fn_spec, distances_to_node, mask):
            # node_fn_spec packs the function index plus its three parameters;
            # inactive nodes (mask False) contribute a zero field.
            fn_i, x, y, z = node_fn_spec
            return jnp.where(
                mask,
                lax.switch(
                    fn_i.astype(jnp.int16), ENERGY_NODE_FNS, distances_to_node, x, y, z
                ),
                jnp.zeros_like(distances_to_node),
            )

        energy_field = jax.vmap(compute_energy_field)(
            state.energy_node_fns, distances_to_nodes, state.energy_nodes_mask
        )
        # Shift the whole field up so its mean is at least 0.25 — applied
        # before summing over nodes, so the shift is per-node-field.
        energy_field = jnp.where(
            energy_field.mean() < 0.25,
            energy_field + (0.25 - energy_field.mean()),
            energy_field,
        )
        energy_field = jnp.round(energy_field.sum(0)).astype(jnp.int16)
        energy_field = jnp.clip(
            energy_field, params.min_energy_per_tile, params.max_energy_per_tile
        )
        state = state.replace(
            map_features=state.map_features.replace(energy=energy_field)
        )
        return state
101
+
102
    def compute_sensor_masks(self, state, params: EnvParams):
        """Compute the vision power map and sensor mask for both teams.

        Algorithm:

        For each team, generate an integer vision power array over the map.
        For each unit in the team, add a concentric-square contribution (a kind
        of sensing power/depth) to every tile within the unit's sensor range.

        With the per-team vision power maps built, subtract the nebula vision
        reduction on every nebula tile. Any tile where the vision power map is
        > 0 can then be sensed by that team — that boolean map is the sensor
        mask. Both ``sensor_mask`` and ``vision_power_map`` are written back
        onto the returned state.
        """

        # The map is padded by the *maximum possible* sensor range so the
        # dynamic slices below never run off the edge regardless of sampled params.
        max_sensor_range = env_params_ranges["unit_sensor_range"][-1]
        vision_power_map_padding = max_sensor_range
        vision_power_map = jnp.zeros(
            shape=(
                self.fixed_env_params.num_teams,
                self.fixed_env_params.map_height + 2 * vision_power_map_padding,
                self.fixed_env_params.map_width + 2 * vision_power_map_padding,
            ),
            dtype=jnp.int16,
        )

        # Update sensor mask based on the sensor range
        def update_vision_power_map(unit_pos, vision_power_map):
            x, y = unit_pos
            existing_vision_power = jax.lax.dynamic_slice(
                vision_power_map,
                start_indices=(
                    x - max_sensor_range + vision_power_map_padding,
                    y - max_sensor_range + vision_power_map_padding,
                ),
                slice_sizes=(
                    max_sensor_range * 2 + 1,
                    max_sensor_range * 2 + 1,
                ),
            )
            # Build a pyramid of nested squares centered on the unit: tiles
            # closer to the unit get higher vision power. Rings outside the
            # (dynamic) params.unit_sensor_range get 0; the loop bound uses the
            # static max so it stays traceable.
            update = jnp.zeros_like(existing_vision_power)
            for i in range(max_sensor_range + 1):
                val = jnp.where(i > max_sensor_range - params.unit_sensor_range - 1, i + 1 - (max_sensor_range - params.unit_sensor_range), 0).astype(jnp.int16)
                update = update.at[
                    i : max_sensor_range * 2 + 1 - i,
                    i : max_sensor_range * 2 + 1 - i,
                ].set(val)
            vision_power_map = jax.lax.dynamic_update_slice(
                vision_power_map,
                update=update + existing_vision_power,
                start_indices=(
                    x - max_sensor_range + vision_power_map_padding,
                    y - max_sensor_range + vision_power_map_padding,
                ),
            )
            return vision_power_map

        # Apply the sensor mask update for all units of both teams
        def update_unit_vision_power_map(unit_pos, unit_mask, vision_power_map):
            # Dead/unspawned units (mask False) leave the map untouched.
            return jax.lax.cond(
                unit_mask,
                lambda: update_vision_power_map(unit_pos, vision_power_map),
                lambda: vision_power_map,
            )

        def update_team_vision_power_map(team_units, unit_mask, vision_power_map):
            # Sequentially accumulate every unit's contribution via scan.
            def body_fun(carry, i):
                vision_power_map = carry
                return (
                    update_unit_vision_power_map(
                        team_units.position[i], unit_mask[i], vision_power_map
                    ),
                    None,
                )

            vision_power_map, _ = jax.lax.scan(
                body_fun, vision_power_map, jnp.arange(self.fixed_env_params.max_units)
            )
            return vision_power_map

        vision_power_map = jax.vmap(update_team_vision_power_map)(
            state.units, state.units_mask, vision_power_map
        )
        # Strip the padding back off to recover the real map extent.
        vision_power_map = vision_power_map[
            :,
            vision_power_map_padding:-vision_power_map_padding,
            vision_power_map_padding:-vision_power_map_padding,
        ]
        # handle nebula tiles
        vision_power_map = (
            vision_power_map
            - (state.map_features.tile_type == NEBULA_TILE).astype(jnp.int16)
            * params.nebula_tile_vision_reduction
        )

        sensor_mask = vision_power_map > 0
        state = state.replace(sensor_mask=sensor_mask)
        state = state.replace(vision_power_map=vision_power_map)
        return state
199
+
200
    # @functools.partial(jax.jit, static_argnums=(0, 4))
    def step_env(
        self,
        key: chex.PRNGKey,
        state: EnvState,
        action: Union[int, float, chex.Array],
        params: EnvParams,
    ) -> Tuple[EnvObs, EnvState, jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
        """Advance the environment by one timestep.

        Pipeline (order is load-bearing):
          1. refresh the energy field;
          2. cull units (match just reset, or energy < 0);
          3. resolve unit movement;
          4. resolve sap actions (direct + adjacent dropoff damage);
          5. resolve collisions and energy-void fields (both use the energy
             snapshot taken *before* sapping);
          6. apply tile energy gain/nebula reduction to units;
          7. spawn new units on the spawn-rate cadence;
          8. recompute sensor masks; drift nebula tiles and energy nodes;
          9. score relic tiles, settle match winners, bump step counters.

        `action` is a dict with "player_0"/"player_1" arrays; the per-unit
        action is [action_type, sap_dx, sap_dy]. Returns
        (obs, state, reward_dict, terminated, truncated, info).
        """

        state = self.compute_energy_features(state, params)

        action = jnp.stack([action["player_0"], action["player_1"]])

        # remove all units if the match ended in the previous step indicated by a reset of match_steps to 0
        state = state.replace(
            units_mask=jnp.where(
                state.match_steps == 0,
                jnp.zeros_like(state.units_mask),
                state.units_mask,
            )
        )
        """remove units that have less than 0 energy"""
        # we remove units at the start of the timestep so that the visualizer can show the unit with negative energy and is marked for removal soon.
        state = state.replace(
            units_mask=(state.units.energy[..., 0] >= 0) & state.units_mask
        )

        """ process unit movement """
        # 0 is do nothing, 1 is move up, 2 is move right, 3 is move down, 4 is move left, 5 is sap
        # Define movement directions
        directions = jnp.array(
            [
                [0, 0],  # Do nothing
                [0, -1],  # Move up
                [1, 0],  # Move right
                [0, 1],  # Move down
                [-1, 0],  # Move left
            ],
            dtype=jnp.int16,
        )

        def move_unit(unit: UnitState, action, mask):
            new_pos = unit.position + directions[action]
            # Check if the new position is on an asteroid tile (impassable)
            is_blocked = (
                state.map_features.tile_type[new_pos[0], new_pos[1]] == ASTEROID_TILE
            )
            enough_energy = unit.energy >= params.unit_move_cost
            # If blocked, keep the original position
            # new_pos = jnp.where(is_blocked, unit.position, new_pos)
            # Ensure the new position is within the map boundaries
            new_pos = jnp.clip(
                new_pos,
                0,
                jnp.array(
                    [params.map_width - 1, params.map_height - 1], dtype=jnp.int16
                ),
            )
            unit_moved = (
                mask & ~is_blocked & enough_energy & (action < 5) & (action > 0)
            )
            # Update the unit's position only if it's active. Note energy is used if unit tries to move off map. Energy is not used if unit tries to move into an asteroid tile.
            return UnitState(
                position=jnp.where(unit_moved, new_pos, unit.position),
                energy=jnp.where(
                    unit_moved, unit.energy - params.unit_move_cost, unit.energy
                ),
            )

        # Move units for both teams (outer vmap over teams, inner over units)
        move_actions = action[..., 0]
        state = state.replace(
            units=jax.vmap(
                lambda team_units, team_action, team_mask: jax.vmap(
                    move_unit, in_axes=(0, 0, 0)
                )(team_units, team_action, team_mask),
                in_axes=(0, 0, 0),
            )(state.units, move_actions, state.units_mask)
        )

        # Snapshot taken after movement but before sapping; sap eligibility and
        # the void-field computation below both read this pre-sap energy.
        original_unit_energy = state.units.energy
        """original amount of energy of all units"""

        """apply sap actions"""
        sap_action_mask = action[..., 0] == 5
        sap_action_deltas = action[..., 1:]

        def sap_unit(
            current_energy: jnp.ndarray,
            all_units: UnitState,
            sap_action_mask,
            sap_action_deltas,
            units_mask,
        ):
            # TODO (stao): clean up this code. It is probably slower than it needs to be and could be vmapped perhaps.
            for t in range(self.fixed_env_params.num_teams):
                other_team_ids = jnp.array(
                    [t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t]
                )
                team_sap_action_deltas = sap_action_deltas[t]  # (max_units, 2)
                team_sap_action_mask = sap_action_mask[t]
                other_team_unit_mask = units_mask[
                    other_team_ids
                ]  # (other_teams, max_units)
                team_sapped_positions = (
                    all_units.position[t] + team_sap_action_deltas
                )  # (max_units, 2)
                # whether the unit is really sapping or not (needs to exist, have enough energy, and a valid sap action)
                team_unit_sapped = (
                    units_mask[t]
                    & team_sap_action_mask
                    & (current_energy[t, :, 0] >= params.unit_sap_cost)
                    & (
                        jnp.max(jnp.abs(team_sap_action_deltas), axis=-1)
                        <= params.unit_sap_range
                    )
                )  # (max_units)
                # the sap target must also land on the map
                team_unit_sapped = (
                    team_unit_sapped
                    & (team_sapped_positions >= 0).all(-1)
                    & (team_sapped_positions[:, 0] < self.fixed_env_params.map_width)
                    & (team_sapped_positions[:, 1] < self.fixed_env_params.map_height)
                )
                # the number of times other units are sapped
                other_units_sapped_count = jnp.sum(
                    team_unit_sapped[None, None, :]
                    & jnp.all(
                        all_units.position[other_team_ids][:, :, None]
                        == team_sapped_positions[None],
                        axis=-1,
                    ),
                    axis=-1,
                    dtype=jnp.int16,
                )  # (len(other_team_ids), max_units)
                # remove unit_sap_cost energy from opposition units that were in the middle of a sap action.
                all_units = all_units.replace(
                    energy=all_units.energy.at[other_team_ids].set(
                        jnp.where(
                            other_team_unit_mask[:, :, None]
                            & (other_units_sapped_count[:, :, None] > 0),
                            all_units.energy[other_team_ids]
                            - params.unit_sap_cost
                            * other_units_sapped_count[:, :, None],
                            all_units.energy[other_team_ids],
                        )
                    )
                )

                # remove unit_sap_cost * unit_sap_dropoff_factor energy from opposition units that were on tiles adjacent to the center of a sap action.
                adjacent_offsets = jnp.array(
                    [
                        [-1, -1],
                        [-1, 0],
                        [-1, 1],
                        [0, -1],
                        [0, 1],
                        [1, -1],
                        [1, 0],
                        [1, 1],
                    ], dtype=jnp.int16
                )
                team_sapped_adjacent_positions = (
                    team_sapped_positions[:, None, :] + adjacent_offsets
                )  # (max_units, len(adjacent_offsets), 2)
                other_units_adjacent_sapped_count = jnp.sum(
                    team_unit_sapped[None, None, :, None]
                    & jnp.all(
                        all_units.position[other_team_ids][:, :, None, None]
                        == team_sapped_adjacent_positions[None],
                        axis=-1,
                    ),
                    axis=(-1, -2),
                    dtype=jnp.int16,
                )  # (len(other_team_ids), max_units)
                all_units = all_units.replace(
                    energy=all_units.energy.at[other_team_ids].set(
                        jnp.where(
                            other_team_unit_mask[:, :, None]
                            & (other_units_adjacent_sapped_count[:, :, None] > 0),
                            all_units.energy[other_team_ids]
                            - jnp.array(
                                params.unit_sap_cost.astype(jnp.float32)
                                * params.unit_sap_dropoff_factor
                                * other_units_adjacent_sapped_count[:, :, None].astype(jnp.float32),
                                dtype=jnp.int16,
                            ),
                            all_units.energy[other_team_ids],
                        )
                    )
                )

                # remove unit_sap_cost energy from units that tried to sap some position within the unit's range
                all_units = all_units.replace(
                    energy=all_units.energy.at[t].set(
                        jnp.where(
                            team_unit_sapped[:, None],
                            all_units.energy[t] - params.unit_sap_cost,
                            all_units.energy[t],
                        )
                    )
                )
            return all_units

        state = state.replace(
            units=sap_unit(
                original_unit_energy,
                state.units,
                sap_action_mask,
                sap_action_deltas,
                state.units_mask,
            )
        )

        """resolve collisions and energy void fields"""

        # compute energy void fields for all teams and the energy + unit counts
        unit_aggregate_energy_void_map = jnp.zeros(
            shape=(self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height),
            dtype=jnp.int16,
        )
        unit_counts_map = self.compute_unit_counts_map(state, params)
        unit_aggregate_energy_map = jnp.zeros(
            shape=(self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height),
            dtype=jnp.int16,
        )
        for t in range(self.fixed_env_params.num_teams):

            def scan_body(carry, x):
                # Accumulate this unit's (pre-sap) energy onto its own tile and
                # spill it onto the four orthogonally adjacent tiles (void field).
                agg_energy_void_map, agg_energy_map = carry
                unit_energy, unit_position, unit_mask = x
                agg_energy_map = agg_energy_map.at[
                    unit_position[0], unit_position[1]
                ].add(unit_energy[0] * unit_mask.astype(jnp.int16))
                for deltas in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                    new_pos = unit_position + jnp.array(deltas, dtype=jnp.int16)
                    in_map = (
                        (new_pos[0] >= 0)
                        & (new_pos[0] < self.fixed_env_params.map_width)
                        & (new_pos[1] >= 0)
                        & (new_pos[1] < self.fixed_env_params.map_height)
                    )
                    agg_energy_void_map = agg_energy_void_map.at[
                        new_pos[0], new_pos[1]
                    ].add(unit_energy[0] * unit_mask.astype(jnp.int16) * in_map.astype(jnp.int16))
                return (agg_energy_void_map, agg_energy_map), None

            agg_energy_void_map, agg_energy_map = jax.lax.scan(
                scan_body,
                (unit_aggregate_energy_void_map[t], unit_aggregate_energy_map[t]),
                (original_unit_energy[t], state.units.position[t], state.units_mask[t]),
            )[0]
            unit_aggregate_energy_void_map = unit_aggregate_energy_void_map.at[t].add(
                agg_energy_void_map
            )
            unit_aggregate_energy_map = unit_aggregate_energy_map.at[t].add(
                agg_energy_map
            )

        # resolve collisions and keep only the surviving units
        for t in range(self.fixed_env_params.num_teams):
            other_team_ids = jnp.array(
                [t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t]
            )
            # get the energy map for the current team
            opposing_unit_counts_map = unit_counts_map[other_team_ids].sum(
                axis=0
            )  # (map_width, map_height)
            team_energy_map = unit_aggregate_energy_map[t]
            opposing_aggregate_energy_map = unit_aggregate_energy_map[
                other_team_ids
            ].max(
                axis=0
            )  # (map_width, map_height)
            # unit survives if there are no opposing units on the tile, or if the opposing unit stack has less energy on the tile than the current unit's stack
            surviving_unit_mask = jax.vmap(
                lambda unit_position: (
                    opposing_unit_counts_map[unit_position[0], unit_position[1]] == 0
                )
                | (
                    opposing_aggregate_energy_map[unit_position[0], unit_position[1]]
                    < team_energy_map[unit_position[0], unit_position[1]]
                )
            )(state.units.position[t])
            state = state.replace(
                units_mask=state.units_mask.at[t].set(
                    surviving_unit_mask & state.units_mask[t]
                )
            )
        # apply energy void fields
        for t in range(self.fixed_env_params.num_teams):
            other_team_ids = jnp.array(
                [t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t]
            )
            oppposition_energy_void_map = unit_aggregate_energy_void_map[
                other_team_ids
            ].sum(
                axis=0
            )  # (map_width, map_height)
            # unit on team t loses energy to void field equal to params.unit_energy_void_factor * void_energy / num units stacked with unit on the same tile
            team_unit_energy = state.units.energy[t] - jnp.floor(
                jax.vmap(
                    lambda unit_position: params.unit_energy_void_factor
                    * oppposition_energy_void_map[unit_position[0], unit_position[1]].astype(jnp.float32)
                    / unit_counts_map[t][unit_position[0], unit_position[1]].astype(jnp.float32)
                )(state.units.position[t])[..., None]
            ).astype(jnp.int16)
            state = state.replace(
                units=state.units.replace(
                    energy=state.units.energy.at[t].set(team_unit_energy)
                )
            )

        """apply energy field to the units"""

        # Update unit energy based on the energy field and nebula tile of their current position
        def update_unit_energy(unit: UnitState, mask):
            x, y = unit.position
            energy_gain = (
                state.map_features.energy[x, y]
                - (state.map_features.tile_type[x, y] == NEBULA_TILE).astype(jnp.int16)
                * params.nebula_tile_energy_reduction
            )
            # if energy gain is less than 0
            # new_energy = jnp.where((unit.energy < 0) & (energy_gain < 0))
            new_energy = jnp.clip(
                unit.energy + energy_gain,
                params.min_unit_energy,
                params.max_unit_energy,
            )
            # if unit already had negative energy due to opposition units and after energy field/nebula tile it is still below 0, then it will be removed next step
            # and we keep its energy value at whatever it is
            new_energy = jnp.where(
                (unit.energy < 0) & (unit.energy + energy_gain < 0),
                unit.energy,
                new_energy,
            )
            return UnitState(
                position=unit.position, energy=jnp.where(mask, new_energy, unit.energy)
            )

        # Apply the energy update for all units of both teams
        state = state.replace(
            units=jax.vmap(
                lambda team_units, team_mask: jax.vmap(update_unit_energy)(
                    team_units, team_mask
                )
            )(state.units, state.units_mask)
        )

        """spawn new units in"""
        spawn_units_in = state.match_steps % params.spawn_rate == 0

        # TODO (stao): only logic in code that probably doesn't handle more than 2 teams, everything else is vmapped across teams
        def spawn_team_units(state: EnvState):
            # argmin over the boolean mask picks the first free (False) slot.
            team_0_unit_count = state.units_mask[0].sum()
            team_1_unit_count = state.units_mask[1].sum()
            team_0_new_unit_id = state.units_mask[0].argmin()
            team_1_new_unit_id = state.units_mask[1].argmin()
            # Team 0 spawns at (0, 0); each field is updated only when the team
            # is below its unit cap.
            state = state.replace(
                units=state.units.replace(
                    position=jnp.where(
                        team_0_unit_count < params.max_units,
                        state.units.position.at[0, team_0_new_unit_id, :].set(
                            jnp.array([0, 0], dtype=jnp.int16)
                        ),
                        state.units.position,
                    )
                )
            )
            state = state.replace(
                units=state.units.replace(
                    energy=jnp.where(
                        team_0_unit_count < params.max_units,
                        state.units.energy.at[0, team_0_new_unit_id, :].set(
                            jnp.array([params.init_unit_energy], dtype=jnp.int16)
                        ),
                        state.units.energy,
                    )
                )
            )
            # Team 1 spawns in the opposite corner.
            state = state.replace(
                units=state.units.replace(
                    position=jnp.where(
                        team_1_unit_count < params.max_units,
                        state.units.position.at[1, team_1_new_unit_id, :].set(
                            jnp.array(
                                [params.map_width - 1, params.map_height - 1],
                                dtype=jnp.int16,
                            )
                        ),
                        state.units.position,
                    )
                )
            )
            state = state.replace(
                units=state.units.replace(
                    energy=jnp.where(
                        team_1_unit_count < params.max_units,
                        state.units.energy.at[1, team_1_new_unit_id, :].set(
                            jnp.array([params.init_unit_energy], dtype=jnp.int16)
                        ),
                        state.units.energy,
                    )
                )
            )
            state = state.replace(
                units_mask=state.units_mask.at[0, team_0_new_unit_id].set(
                    jnp.where(
                        team_0_unit_count < params.max_units,
                        True,
                        state.units_mask[0, team_0_new_unit_id],
                    )
                )
            )
            state = state.replace(
                units_mask=state.units_mask.at[1, team_1_new_unit_id].set(
                    jnp.where(
                        team_1_unit_count < params.max_units,
                        True,
                        state.units_mask[1, team_1_new_unit_id],
                    )
                )
            )
            # state = jnp.where(team_0_unit_count < params.max_units, spawn_unit(state, 0, team_0_new_unit_id, [0, 0], params), state)
            # state = jnp.where(team_1_unit_count < params.max_units, spawn_unit(state, 1, team_1_new_unit_id, [params.map_width - 1, params.map_height - 1], params), state)
            return state

        state = jax.lax.cond(
            spawn_units_in, lambda: spawn_team_units(state), lambda: state
        )

        state = self.compute_sensor_masks(state, params)

        # Shift objects around in space
        # Move the nebula tiles in state.map_features.tile_types up by 1 and to the right by 1
        # this is also symmetric nebula tile movement
        new_tile_types_map = jnp.roll(
            state.map_features.tile_type,
            shift=(
                1 * jnp.sign(params.nebula_tile_drift_speed),
                -1 * jnp.sign(params.nebula_tile_drift_speed),
            ),
            axis=(0, 1),
        )
        # Drift only on steps where steps * drift_speed lands on an integer.
        new_tile_types_map = jnp.where(
            state.steps * params.nebula_tile_drift_speed % 1 == 0,
            new_tile_types_map,
            state.map_features.tile_type,
        )
        # new_energy_nodes = state.energy_nodes + jnp.array([1 * jnp.sign(params.energy_node_drift_speed), -1 * jnp.sign(params.energy_node_drift_speed)])

        # Random drift for the first half of the energy nodes; the second half
        # mirrors them to keep the map symmetric.
        energy_node_deltas = jnp.round(
            jax.random.uniform(
                key=key,
                shape=(self.fixed_env_params.max_energy_nodes // 2, 2),
                minval=-params.energy_node_drift_magnitude,
                maxval=params.energy_node_drift_magnitude,
            )
        ).astype(jnp.int16)
        energy_node_deltas_symmetric = jnp.stack(
            [-energy_node_deltas[:, 1], -energy_node_deltas[:, 0]], axis=-1
        )
        # TODO symmetric movement
        # energy_node_deltas = jnp.round(jax.random.uniform(key=key, shape=(params.max_energy_nodes // 2, 2), minval=-params.energy_node_drift_magnitude, maxval=params.energy_node_drift_magnitude)).astype(jnp.int16)
        energy_node_deltas = jnp.concatenate(
            (energy_node_deltas, energy_node_deltas_symmetric)
        )
        new_energy_nodes = jnp.clip(
            state.energy_nodes + energy_node_deltas,
            min=jnp.array([0, 0], dtype=jnp.int16),
            max=jnp.array([self.fixed_env_params.map_width, self.fixed_env_params.map_height], dtype=jnp.int16),
        )
        new_energy_nodes = jnp.where(
            state.steps * params.energy_node_drift_speed % 1 == 0,
            new_energy_nodes,
            state.energy_nodes,
        )
        state = state.replace(
            map_features=state.map_features.replace(tile_type=new_tile_types_map),
            energy_nodes=new_energy_nodes,
        )

        # Compute relic scores
        def team_relic_score(unit_counts_map):
            scores = (unit_counts_map > 0) & (state.relic_nodes_map_weights > 0)
            return jnp.sum(scores, dtype=jnp.int32)

        # note we need to recompute unit counts since units can get removed due to collisions
        team_scores = jax.vmap(team_relic_score)(
            self.compute_unit_counts_map(state, params)
        )
        # Update team points
        state = state.replace(team_points=state.team_points + team_scores)

        # if match ended, then remove all units, update team wins, reset team points
        winner_by_points = jnp.where(
            state.team_points.max() > state.team_points.min(),
            jnp.argmax(state.team_points),
            -1,
        )
        # tie-break on total remaining unit energy, then randomly
        winner_by_energy = jnp.sum(
            state.units.energy[..., 0] * state.units_mask.astype(jnp.int16), axis=1
        )
        winner_by_energy = jnp.where(
            winner_by_energy.max() > winner_by_energy.min(),
            jnp.argmax(winner_by_energy),
            -1,
        )

        winner = jnp.where(
            winner_by_points != -1,
            winner_by_points,
            jnp.where(
                winner_by_energy != -1,
                winner_by_energy,
                jax.random.randint(key, shape=(), minval=0, maxval=params.num_teams),
            ),
        )
        match_ended = state.match_steps >= params.max_steps_in_match

        # match_steps is set to -1 here so the increment below lands it on 0,
        # which triggers the unit wipe at the top of the next step.
        state = state.replace(
            match_steps=jnp.where(match_ended, -1, state.match_steps),
            team_points=jnp.where(
                match_ended, jnp.zeros_like(state.team_points), state.team_points
            ),
            team_wins=jnp.where(
                match_ended, state.team_wins.at[winner].add(1), state.team_wins
            ),
        )
        # Update state's step count
        state = state.replace(steps=state.steps + 1, match_steps=state.match_steps + 1)
        truncated = (
            state.steps
            >= (params.max_steps_in_match + 1) * params.match_count_per_episode
        )
        reward = dict()
        for k in range(self.fixed_env_params.num_teams):
            reward[f"player_{k}"] = state.team_wins[k]
        terminated = self.is_terminal(state, params)
        return (
            lax.stop_gradient(self.get_obs(state, params, key=key)),
            lax.stop_gradient(state),
            reward,
            terminated,
            truncated,
            {"discount": self.discount(state, params)},
        )
747
+
748
+ def reset_env(
749
+ self, key: chex.PRNGKey, params: EnvParams
750
+ ) -> Tuple[EnvObs, EnvState]:
751
+ """Reset environment state by sampling initial position."""
752
+
753
+ state = gen_state(
754
+ key=key,
755
+ env_params=params,
756
+ max_units=self.fixed_env_params.max_units,
757
+ num_teams=self.fixed_env_params.num_teams,
758
+ map_type=self.fixed_env_params.map_type,
759
+ map_width=self.fixed_env_params.map_width,
760
+ map_height=self.fixed_env_params.map_height,
761
+ max_energy_nodes=self.fixed_env_params.max_energy_nodes,
762
+ max_relic_nodes=self.fixed_env_params.max_relic_nodes,
763
+ relic_config_size=self.fixed_env_params.relic_config_size,
764
+ )
765
+ state = self.compute_energy_features(state, params)
766
+ state = self.compute_sensor_masks(state, params)
767
+
768
+ return self.get_obs(state, params=params, key=key), state
769
+
770
+ @functools.partial(jax.jit, static_argnums=(0,))
771
+ def step(
772
+ self,
773
+ key: chex.PRNGKey,
774
+ state: EnvState,
775
+ action: Union[int, float, chex.Array],
776
+ params: Optional[EnvParams] = None,
777
+ ) -> Tuple[EnvObs, EnvState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
778
+ """Performs step transitions in the environment."""
779
+ # Use default env parameters if no others specified
780
+ if params is None:
781
+ params = self.default_params
782
+ key, key_reset = jax.random.split(key)
783
+ obs_st, state_st, reward, terminated, truncated, info = self.step_env(
784
+ key, state, action, params
785
+ )
786
+ info["final_state"] = state_st
787
+ info["final_observation"] = obs_st
788
+ done = terminated | truncated
789
+
790
+ if self.auto_reset:
791
+ obs_re, state_re = self.reset_env(key_reset, params)
792
+ # Use lax.cond to efficiently choose between obs_re and obs_st
793
+ obs = jax.lax.cond(
794
+ done,
795
+ lambda: obs_re,
796
+ lambda: obs_st
797
+ )
798
+ state = jax.lax.cond(
799
+ done,
800
+ lambda: state_re,
801
+ lambda: state_st
802
+ )
803
+ else:
804
+ obs = obs_st
805
+ state = state_st
806
+
807
+ # all agents terminate/truncate at same time
808
+ terminated_dict = dict()
809
+ truncated_dict = dict()
810
+ for k in range(self.fixed_env_params.num_teams):
811
+ terminated_dict[f"player_{k}"] = terminated
812
+ truncated_dict[f"player_{k}"] = truncated
813
+ info[f"player_{k}"] = dict()
814
+ return obs, state, reward, terminated_dict, truncated_dict, info
815
+
816
+ @functools.partial(jax.jit, static_argnums=(0,))
817
+ def reset(
818
+ self, key: chex.PRNGKey, params: Optional[EnvParams] = None
819
+ ) -> Tuple[chex.Array, EnvState]:
820
+ """Performs resetting of environment."""
821
+ # Use default env parameters if no others specified
822
+ if params is None:
823
+ params = self.default_params
824
+
825
+ obs, state = self.reset_env(key, params)
826
+ return obs, state
827
+
828
    # @functools.partial(jax.jit, static_argnums=(0, 2))
    def get_obs(self, state: EnvState, params=None, key=None) -> EnvObs:
        """Return observation from raw state, handling partial observability.

        Builds one EnvObs per team: entities outside that team's sensor mask
        (enemy units, relic nodes, map tiles) are masked out, with hidden
        numeric fields encoded as -1.
        """
        obs = dict()

        # A unit is visible iff its own mask bit is set AND it stands on a
        # tile the observing team currently senses.
        def update_unit_mask(unit_position, unit_mask, sensor_mask):
            return unit_mask & sensor_mask[unit_position[0], unit_position[1]]

        # Vectorize the single-unit visibility test over all units of a team.
        def update_team_unit_mask(unit_position, unit_mask, sensor_mask):
            return jax.vmap(update_unit_mask, in_axes=(0, 0, None))(
                unit_position, unit_mask, sensor_mask
            )

        # Same idea for relic nodes: a node stays visible only if it lies
        # inside the observing team's sensor mask.
        def update_relic_nodes_mask(relic_nodes_mask, relic_nodes, sensor_mask):
            return jax.vmap(
                lambda r_mask, r, s_mask: r_mask & s_mask[r[0], r[1]],
                in_axes=(0, 0, None),
            )(relic_nodes_mask, relic_nodes, sensor_mask)

        for t in range(self.fixed_env_params.num_teams):
            # Opposing teams only: their units must pass the visibility test,
            # while team t's own units keep their original mask bits.
            other_team_ids = jnp.array(
                [t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t]
            )
            new_unit_masks = jax.vmap(update_team_unit_mask, in_axes=(0, 0, None))(
                state.units.position[other_team_ids],
                state.units_mask[other_team_ids],
                state.sensor_mask[t],
            )
            new_unit_masks = state.units_mask.at[other_team_ids].set(new_unit_masks)

            new_relic_nodes_mask = update_relic_nodes_mask(
                state.relic_nodes_mask, state.relic_nodes, state.sensor_mask[t]
            )
            # Hidden entities are encoded as -1 so downstream consumers can
            # distinguish "unseen" from real coordinates/energy values.
            team_obs = EnvObs(
                units=UnitState(
                    position=jnp.where(
                        new_unit_masks[..., None], state.units.position, -1
                    ),
                    # NOTE(review): energy appears to carry a trailing
                    # singleton axis that is squeezed here — confirm against
                    # the UnitState definition.
                    energy=jnp.where(new_unit_masks[..., None], state.units.energy, -1)[
                        ..., 0
                    ],
                ),
                units_mask=new_unit_masks,
                sensor_mask=state.sensor_mask[t],
                map_features=MapTile(
                    energy=jnp.where(
                        state.sensor_mask[t], state.map_features.energy, -1
                    ),
                    tile_type=jnp.where(
                        state.sensor_mask[t], state.map_features.tile_type, -1
                    ),
                ),
                # Scoreboard/time fields are global and shared by all teams.
                team_points=state.team_points,
                team_wins=state.team_wins,
                steps=state.steps,
                match_steps=state.match_steps,
                relic_nodes=jnp.where(
                    new_relic_nodes_mask[..., None], state.relic_nodes, -1
                ),
                relic_nodes_mask=new_relic_nodes_mask,
            )
            obs[f"player_{t}"] = team_obs
        return obs
891
+
892
+ @functools.partial(jax.jit, static_argnums=(0, ))
893
+ def is_terminal(self, state: EnvState, params: EnvParams) -> jnp.ndarray:
894
+ """Check whether state is terminal. This never occurs. Game is only done when the time limit is reached."""
895
+ terminated = jnp.array(False)
896
+ return terminated
897
+
898
+ @property
899
+ def name(self) -> str:
900
+ """Environment name."""
901
+ return "Lux AI Season 3"
902
+
903
+ def render(self, state: EnvState, params: EnvParams):
904
+ self.renderer.render(state, params)
905
+
906
+ def action_space(self, params: Optional[EnvParams] = None):
907
+ """Action space of the environment."""
908
+ low = np.zeros((self.fixed_env_params.max_units, 3))
909
+ low[:, 1:] = -env_params_ranges["unit_sap_range"][-1]
910
+ high = np.ones((self.fixed_env_params.max_units, 3)) * 6
911
+ high[:, 1:] = env_params_ranges["unit_sap_range"][-1]
912
+ return spaces.Dict(
913
+ dict(player_0=MultiDiscrete(low, high), player_1=MultiDiscrete(low, high))
914
+ )
915
+
916
    def observation_space(self, params: EnvParams):
        """Observation space of the environment.

        NOTE(review): Discrete(10) looks like a placeholder — the actual
        observations returned by get_obs are per-player EnvObs structures.
        Confirm before relying on this space for shape inference.
        """
        return spaces.Discrete(10)
919
+
920
    def state_space(self, params: EnvParams):
        """State space of the environment.

        NOTE(review): Discrete(10) looks like a placeholder — the true state
        is an EnvState pytree. Confirm before relying on this space.
        """
        return spaces.Discrete(10)
923
+
924
+