kaggle-environments 0.2.1__py3-none-any.whl → 1.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kaggle-environments might be problematic.

Files changed (215)
  1. kaggle_environments/__init__.py +49 -13
  2. kaggle_environments/agent.py +177 -124
  3. kaggle_environments/api.py +31 -0
  4. kaggle_environments/core.py +295 -170
  5. kaggle_environments/envs/cabt/cabt.js +164 -0
  6. kaggle_environments/envs/cabt/cabt.json +28 -0
  7. kaggle_environments/envs/cabt/cabt.py +186 -0
  8. kaggle_environments/envs/cabt/cg/__init__.py +0 -0
  9. kaggle_environments/envs/cabt/cg/cg.dll +0 -0
  10. kaggle_environments/envs/cabt/cg/game.py +75 -0
  11. kaggle_environments/envs/cabt/cg/libcg.so +0 -0
  12. kaggle_environments/envs/cabt/cg/sim.py +48 -0
  13. kaggle_environments/envs/cabt/test_cabt.py +120 -0
  14. kaggle_environments/envs/chess/chess.js +4289 -0
  15. kaggle_environments/envs/chess/chess.json +60 -0
  16. kaggle_environments/envs/chess/chess.py +4241 -0
  17. kaggle_environments/envs/chess/test_chess.py +60 -0
  18. kaggle_environments/envs/connectx/connectx.ipynb +3186 -0
  19. kaggle_environments/envs/connectx/connectx.js +1 -1
  20. kaggle_environments/envs/connectx/connectx.json +15 -1
  21. kaggle_environments/envs/connectx/connectx.py +6 -23
  22. kaggle_environments/envs/connectx/test_connectx.py +70 -24
  23. kaggle_environments/envs/football/football.ipynb +75 -0
  24. kaggle_environments/envs/football/football.json +91 -0
  25. kaggle_environments/envs/football/football.py +277 -0
  26. kaggle_environments/envs/football/helpers.py +95 -0
  27. kaggle_environments/envs/football/test_football.py +360 -0
  28. kaggle_environments/envs/halite/__init__.py +0 -0
  29. kaggle_environments/envs/halite/halite.ipynb +44741 -0
  30. kaggle_environments/envs/halite/halite.js +199 -83
  31. kaggle_environments/envs/halite/halite.json +31 -18
  32. kaggle_environments/envs/halite/halite.py +164 -303
  33. kaggle_environments/envs/halite/helpers.py +720 -0
  34. kaggle_environments/envs/halite/test_halite.py +190 -0
  35. kaggle_environments/envs/hungry_geese/__init__.py +0 -0
  36. kaggle_environments/envs/{battlegeese/battlegeese.js → hungry_geese/hungry_geese.js} +38 -22
  37. kaggle_environments/envs/{battlegeese/battlegeese.json → hungry_geese/hungry_geese.json} +21 -14
  38. kaggle_environments/envs/hungry_geese/hungry_geese.py +316 -0
  39. kaggle_environments/envs/hungry_geese/test_hungry_geese.py +0 -0
  40. kaggle_environments/envs/identity/identity.json +6 -5
  41. kaggle_environments/envs/identity/identity.py +15 -2
  42. kaggle_environments/envs/kore_fleets/__init__.py +0 -0
  43. kaggle_environments/envs/kore_fleets/helpers.py +1005 -0
  44. kaggle_environments/envs/kore_fleets/kore_fleets.ipynb +114 -0
  45. kaggle_environments/envs/kore_fleets/kore_fleets.js +658 -0
  46. kaggle_environments/envs/kore_fleets/kore_fleets.json +164 -0
  47. kaggle_environments/envs/kore_fleets/kore_fleets.py +555 -0
  48. kaggle_environments/envs/kore_fleets/starter_bots/java/Bot.java +54 -0
  49. kaggle_environments/envs/kore_fleets/starter_bots/java/README.md +26 -0
  50. kaggle_environments/envs/kore_fleets/starter_bots/java/jars/hamcrest-core-1.3.jar +0 -0
  51. kaggle_environments/envs/kore_fleets/starter_bots/java/jars/junit-4.13.2.jar +0 -0
  52. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Board.java +518 -0
  53. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Cell.java +61 -0
  54. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Configuration.java +24 -0
  55. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Direction.java +166 -0
  56. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Fleet.java +72 -0
  57. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/KoreJson.java +97 -0
  58. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Observation.java +72 -0
  59. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Pair.java +13 -0
  60. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Player.java +68 -0
  61. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Point.java +65 -0
  62. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Shipyard.java +70 -0
  63. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/ShipyardAction.java +59 -0
  64. kaggle_environments/envs/kore_fleets/starter_bots/java/main.py +73 -0
  65. kaggle_environments/envs/kore_fleets/starter_bots/java/test/BoardTest.java +567 -0
  66. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ConfigurationTest.java +25 -0
  67. kaggle_environments/envs/kore_fleets/starter_bots/java/test/KoreJsonTest.java +62 -0
  68. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ObservationTest.java +46 -0
  69. kaggle_environments/envs/kore_fleets/starter_bots/java/test/PointTest.java +21 -0
  70. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ShipyardTest.java +22 -0
  71. kaggle_environments/envs/kore_fleets/starter_bots/java/test/configuration.json +1 -0
  72. kaggle_environments/envs/kore_fleets/starter_bots/java/test/fullob.json +1 -0
  73. kaggle_environments/envs/kore_fleets/starter_bots/java/test/observation.json +1 -0
  74. kaggle_environments/envs/kore_fleets/starter_bots/python/__init__.py +0 -0
  75. kaggle_environments/envs/kore_fleets/starter_bots/python/main.py +27 -0
  76. kaggle_environments/envs/kore_fleets/starter_bots/ts/Bot.ts +34 -0
  77. kaggle_environments/envs/kore_fleets/starter_bots/ts/DoNothingBot.ts +12 -0
  78. kaggle_environments/envs/kore_fleets/starter_bots/ts/MinerBot.ts +62 -0
  79. kaggle_environments/envs/kore_fleets/starter_bots/ts/README.md +55 -0
  80. kaggle_environments/envs/kore_fleets/starter_bots/ts/interpreter.ts +402 -0
  81. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Board.ts +514 -0
  82. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Cell.ts +63 -0
  83. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Configuration.ts +25 -0
  84. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Direction.ts +169 -0
  85. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Fleet.ts +76 -0
  86. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/KoreIO.ts +70 -0
  87. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Observation.ts +45 -0
  88. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Pair.ts +11 -0
  89. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Player.ts +68 -0
  90. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Point.ts +65 -0
  91. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Shipyard.ts +72 -0
  92. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/ShipyardAction.ts +58 -0
  93. kaggle_environments/envs/kore_fleets/starter_bots/ts/main.py +73 -0
  94. kaggle_environments/envs/kore_fleets/starter_bots/ts/miner.py +73 -0
  95. kaggle_environments/envs/kore_fleets/starter_bots/ts/package.json +23 -0
  96. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/BoardTest.ts +551 -0
  97. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ConfigurationTest.ts +16 -0
  98. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ObservationTest.ts +33 -0
  99. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/PointTest.ts +17 -0
  100. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ShipyardTest.ts +18 -0
  101. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/configuration.json +1 -0
  102. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/fullob.json +1 -0
  103. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/observation.json +1 -0
  104. kaggle_environments/envs/kore_fleets/starter_bots/ts/tsconfig.json +22 -0
  105. kaggle_environments/envs/kore_fleets/test_kore_fleets.py +331 -0
  106. kaggle_environments/envs/lux_ai_2021/README.md +3 -0
  107. kaggle_environments/envs/lux_ai_2021/__init__.py +0 -0
  108. kaggle_environments/envs/lux_ai_2021/agents.py +11 -0
  109. kaggle_environments/envs/lux_ai_2021/dimensions/754.js +2 -0
  110. kaggle_environments/envs/lux_ai_2021/dimensions/754.js.LICENSE.txt +296 -0
  111. kaggle_environments/envs/lux_ai_2021/dimensions/main.js +1 -0
  112. kaggle_environments/envs/lux_ai_2021/index.html +43 -0
  113. kaggle_environments/envs/lux_ai_2021/lux_ai_2021.json +100 -0
  114. kaggle_environments/envs/lux_ai_2021/lux_ai_2021.py +231 -0
  115. kaggle_environments/envs/lux_ai_2021/test_agents/__init__.py +0 -0
  116. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_constants.js +6 -0
  117. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_constants.json +59 -0
  118. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_objects.js +145 -0
  119. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/io.js +14 -0
  120. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/kit.js +209 -0
  121. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/map.js +107 -0
  122. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/parser.js +79 -0
  123. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/main.js +88 -0
  124. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/main.py +75 -0
  125. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/simple.tar.gz +0 -0
  126. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/__init__.py +0 -0
  127. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/annotate.py +20 -0
  128. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/constants.py +25 -0
  129. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game.py +86 -0
  130. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_constants.json +59 -0
  131. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_constants.py +7 -0
  132. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_map.py +106 -0
  133. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_objects.py +154 -0
  134. kaggle_environments/envs/lux_ai_2021/test_agents/python/random_agent.py +38 -0
  135. kaggle_environments/envs/lux_ai_2021/test_agents/python/simple_agent.py +82 -0
  136. kaggle_environments/envs/lux_ai_2021/test_lux.py +19 -0
  137. kaggle_environments/envs/lux_ai_2021/testing.md +23 -0
  138. kaggle_environments/envs/lux_ai_2021/todo.md.og +18 -0
  139. kaggle_environments/envs/lux_ai_s3/README.md +21 -0
  140. kaggle_environments/envs/lux_ai_s3/agents.py +5 -0
  141. kaggle_environments/envs/lux_ai_s3/index.html +42 -0
  142. kaggle_environments/envs/lux_ai_s3/lux_ai_s3.json +47 -0
  143. kaggle_environments/envs/lux_ai_s3/lux_ai_s3.py +178 -0
  144. kaggle_environments/envs/lux_ai_s3/luxai_s3/__init__.py +1 -0
  145. kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py +819 -0
  146. kaggle_environments/envs/lux_ai_s3/luxai_s3/globals.py +9 -0
  147. kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py +101 -0
  148. kaggle_environments/envs/lux_ai_s3/luxai_s3/profiler.py +141 -0
  149. kaggle_environments/envs/lux_ai_s3/luxai_s3/pygame_render.py +222 -0
  150. kaggle_environments/envs/lux_ai_s3/luxai_s3/spaces.py +27 -0
  151. kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py +464 -0
  152. kaggle_environments/envs/lux_ai_s3/luxai_s3/utils.py +12 -0
  153. kaggle_environments/envs/lux_ai_s3/luxai_s3/wrappers.py +156 -0
  154. kaggle_environments/envs/lux_ai_s3/test_agents/python/agent.py +78 -0
  155. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/__init__.py +0 -0
  156. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/kit.py +31 -0
  157. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/utils.py +17 -0
  158. kaggle_environments/envs/lux_ai_s3/test_agents/python/main.py +66 -0
  159. kaggle_environments/envs/lux_ai_s3/test_lux.py +9 -0
  160. kaggle_environments/envs/mab/__init__.py +0 -0
  161. kaggle_environments/envs/mab/agents.py +12 -0
  162. kaggle_environments/envs/mab/mab.js +100 -0
  163. kaggle_environments/envs/mab/mab.json +74 -0
  164. kaggle_environments/envs/mab/mab.py +146 -0
  165. kaggle_environments/envs/open_spiel/__init__.py +0 -0
  166. kaggle_environments/envs/open_spiel/games/__init__.py +0 -0
  167. kaggle_environments/envs/open_spiel/games/chess/chess.js +441 -0
  168. kaggle_environments/envs/open_spiel/games/chess/image_config.jsonl +20 -0
  169. kaggle_environments/envs/open_spiel/games/chess/openings.jsonl +20 -0
  170. kaggle_environments/envs/open_spiel/games/connect_four/__init__.py +0 -0
  171. kaggle_environments/envs/open_spiel/games/connect_four/connect_four.js +284 -0
  172. kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy.py +86 -0
  173. kaggle_environments/envs/open_spiel/games/go/__init__.py +0 -0
  174. kaggle_environments/envs/open_spiel/games/go/go.js +481 -0
  175. kaggle_environments/envs/open_spiel/games/go/go_proxy.py +99 -0
  176. kaggle_environments/envs/open_spiel/games/tic_tac_toe/__init__.py +0 -0
  177. kaggle_environments/envs/open_spiel/games/tic_tac_toe/tic_tac_toe.js +345 -0
  178. kaggle_environments/envs/open_spiel/games/tic_tac_toe/tic_tac_toe_proxy.py +98 -0
  179. kaggle_environments/envs/open_spiel/games/universal_poker/__init__.py +0 -0
  180. kaggle_environments/envs/open_spiel/games/universal_poker/universal_poker.js +431 -0
  181. kaggle_environments/envs/open_spiel/games/universal_poker/universal_poker_proxy.py +159 -0
  182. kaggle_environments/envs/open_spiel/html_playthrough_generator.py +31 -0
  183. kaggle_environments/envs/open_spiel/observation.py +128 -0
  184. kaggle_environments/envs/open_spiel/open_spiel.py +565 -0
  185. kaggle_environments/envs/open_spiel/proxy.py +138 -0
  186. kaggle_environments/envs/open_spiel/test_open_spiel.py +191 -0
  187. kaggle_environments/envs/rps/__init__.py +0 -0
  188. kaggle_environments/envs/rps/agents.py +84 -0
  189. kaggle_environments/envs/rps/helpers.py +25 -0
  190. kaggle_environments/envs/rps/rps.js +117 -0
  191. kaggle_environments/envs/rps/rps.json +63 -0
  192. kaggle_environments/envs/rps/rps.py +90 -0
  193. kaggle_environments/envs/rps/test_rps.py +110 -0
  194. kaggle_environments/envs/rps/utils.py +7 -0
  195. kaggle_environments/envs/tictactoe/test_tictactoe.py +43 -77
  196. kaggle_environments/envs/tictactoe/tictactoe.ipynb +1397 -0
  197. kaggle_environments/envs/tictactoe/tictactoe.json +10 -2
  198. kaggle_environments/envs/tictactoe/tictactoe.py +1 -1
  199. kaggle_environments/errors.py +2 -4
  200. kaggle_environments/helpers.py +377 -0
  201. kaggle_environments/main.py +340 -0
  202. kaggle_environments/schemas.json +23 -18
  203. kaggle_environments/static/player.html +206 -74
  204. kaggle_environments/utils.py +46 -73
  205. kaggle_environments-1.20.0.dist-info/METADATA +25 -0
  206. kaggle_environments-1.20.0.dist-info/RECORD +211 -0
  207. {kaggle_environments-0.2.1.dist-info → kaggle_environments-1.20.0.dist-info}/WHEEL +1 -2
  208. kaggle_environments-1.20.0.dist-info/entry_points.txt +3 -0
  209. kaggle_environments/envs/battlegeese/battlegeese.py +0 -223
  210. kaggle_environments/temp.py +0 -14
  211. kaggle_environments-0.2.1.dist-info/METADATA +0 -393
  212. kaggle_environments-0.2.1.dist-info/RECORD +0 -32
  213. kaggle_environments-0.2.1.dist-info/entry_points.txt +0 -3
  214. kaggle_environments-0.2.1.dist-info/top_level.txt +0 -1
  215. {kaggle_environments-0.2.1.dist-info → kaggle_environments-1.20.0.dist-info/licenses}/LICENSE +0 -0
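
Many of the entries above add entirely new environments (chess, kore_fleets, lux_ai_s3, mab, open_spiel, ...). As a rough orientation only, not part of the diff itself, the sketch below shows how environments in this package are typically driven through its make/run API; the assumption that the new environments register under their directory names mirrors how the existing ones (such as connectx) are registered.

# A minimal usage sketch, assuming the standard kaggle_environments API.
from kaggle_environments import make

env = make("connectx", debug=True)   # pre-existing environment, updated in this release
env.run(["random", "random"])        # "random" is a built-in connectx agent
print(env.render(mode="ansi"))       # use mode="ipython" inside a notebook
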
kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py (new file)
@@ -0,0 +1,819 @@
+ import functools
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import chex
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from gymnax.environments import environment, spaces
+ from jax import lax
+
+ from luxai_s3.params import EnvParams, env_params_ranges
+ from luxai_s3.pygame_render import LuxAIPygameRenderer
+ from luxai_s3.spaces import MultiDiscrete
+ from luxai_s3.state import ASTEROID_TILE, ENERGY_NODE_FNS, NEBULA_TILE, EnvObs, EnvState, MapTile, UnitState, gen_state
+
+
+ class LuxAIS3Env(environment.Environment):
+     def __init__(self, auto_reset=False, fixed_env_params: EnvParams = EnvParams(), **kwargs):
+         super().__init__(**kwargs)
+         self.renderer = LuxAIPygameRenderer()
+         self.auto_reset = auto_reset
+         self.fixed_env_params = fixed_env_params
+         """fixed env params for concrete/static values. Necessary for jit/vmap capability with randomly sampled maps, which must be of consistent shape"""
+
+     @property
+     def default_params(self) -> EnvParams:
+         params = EnvParams()
+         params = jax.tree_map(jax.numpy.array, params)
+         return params
+
+     def compute_unit_counts_map(self, state: EnvState, params: EnvParams, exclude_negative_energy_units: bool = False):
+         # map of total units per team on each tile, shape (num_teams, map_width, map_height)
+         unit_counts_map = jnp.zeros(
+             (self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height),
+             dtype=jnp.int16,
+         )
+
+         def update_unit_counts_map(unit_position, unit_mask, unit_energy_nonnegative, unit_counts_map):
+             if exclude_negative_energy_units:
+                 mask = unit_mask & unit_energy_nonnegative
+             else:
+                 mask = unit_mask
+             unit_counts_map = unit_counts_map.at[unit_position[0], unit_position[1]].add(mask.astype(jnp.int16))
+             return unit_counts_map
+
+         for t in range(self.fixed_env_params.num_teams):
+             unit_counts_map = unit_counts_map.at[t].add(
+                 jnp.sum(
+                     jax.vmap(update_unit_counts_map, in_axes=(0, 0, 0, None), out_axes=0)(
+                         state.units.position[t],
+                         state.units_mask[t],
+                         state.units.energy[t, :, 0] >= 0,
+                         unit_counts_map[t],
+                     ),
+                     axis=0,
+                     dtype=jnp.int16,
+                 )
+             )
+         return unit_counts_map
+
+     def compute_energy_features(self, state: EnvState, params: EnvParams):
+         # first compute an array of shape (map_height, map_width, num_energy_nodes) with values equal to the distance of the tile to the energy node
+         mm = jnp.meshgrid(jnp.arange(self.fixed_env_params.map_width), jnp.arange(self.fixed_env_params.map_height))
+         mm = jnp.stack([mm[0], mm[1]]).T.astype(jnp.int16)  # mm[x, y] gives [x, y]
+         distances_to_nodes = jax.vmap(lambda pos: jnp.linalg.norm(mm - pos, axis=-1))(state.energy_nodes)
+
+         def compute_energy_field(node_fn_spec, distances_to_node, mask):
+             fn_i, x, y, z = node_fn_spec
+             return jnp.where(
+                 mask,
+                 lax.switch(fn_i.astype(jnp.int16), ENERGY_NODE_FNS, distances_to_node, x, y, z),
+                 jnp.zeros_like(distances_to_node),
+             )
+
+         energy_field = jax.vmap(compute_energy_field)(
+             state.energy_node_fns, distances_to_nodes, state.energy_nodes_mask
+         )
+         energy_field = jnp.where(
+             energy_field.mean() < 0.25,
+             energy_field + (0.25 - energy_field.mean()),
+             energy_field,
+         )
+         energy_field = jnp.round(energy_field.sum(0)).astype(jnp.int16)
+         energy_field = jnp.clip(energy_field, params.min_energy_per_tile, params.max_energy_per_tile)
+         state = state.replace(map_features=state.map_features.replace(energy=energy_field))
+         return state
+
+     def compute_sensor_masks(self, state, params: EnvParams):
+         """Compute the vision power and sensor mask for both teams
+
+         Algorithm:
+
+         For each team, generate an integer vision power array over the map.
+         For each unit in the team, add the unit sensor range value (it's kind of like the unit's sensing power/depth) to each tile in the unit's sensor range.
+         Clamp the vision power array to the range [0, unit_sensing_range].
+
+         With the 2 vision power maps, take the nebula vision mask * nebula vision power and subtract it from the vision power maps.
+         Now any time the vision power map has a value > 0, the team can sense the tile. This forms the sensor mask.
+         """
+
+         max_sensor_range = env_params_ranges["unit_sensor_range"][-1]
+         vision_power_map_padding = max_sensor_range
+         vision_power_map = jnp.zeros(
+             shape=(
+                 self.fixed_env_params.num_teams,
+                 self.fixed_env_params.map_height + 2 * vision_power_map_padding,
+                 self.fixed_env_params.map_width + 2 * vision_power_map_padding,
+             ),
+             dtype=jnp.int16,
+         )
+
+         # Update sensor mask based on the sensor range
+         def update_vision_power_map(unit_pos, vision_power_map):
+             x, y = unit_pos
+             existing_vision_power = jax.lax.dynamic_slice(
+                 vision_power_map,
+                 start_indices=(
+                     x - max_sensor_range + vision_power_map_padding,
+                     y - max_sensor_range + vision_power_map_padding,
+                 ),
+                 slice_sizes=(
+                     max_sensor_range * 2 + 1,
+                     max_sensor_range * 2 + 1,
+                 ),
+             )
+             update = jnp.zeros_like(existing_vision_power)
+             for i in range(max_sensor_range + 1):
+                 val = jnp.where(
+                     i > max_sensor_range - params.unit_sensor_range - 1,
+                     i + 1 - (max_sensor_range - params.unit_sensor_range),
+                     0,
+                 ).astype(jnp.int16)
+                 update = update.at[
+                     i : max_sensor_range * 2 + 1 - i,
+                     i : max_sensor_range * 2 + 1 - i,
+                 ].set(val)
+             # vision of position at center of update has an extra 10
+             update = update.at[
+                 max_sensor_range,
+                 max_sensor_range,
+             ].add(10)
+             vision_power_map = jax.lax.dynamic_update_slice(
+                 vision_power_map,
+                 update=update + existing_vision_power,
+                 start_indices=(
+                     x - max_sensor_range + vision_power_map_padding,
+                     y - max_sensor_range + vision_power_map_padding,
+                 ),
+             )
+             return vision_power_map
+
+         # Apply the sensor mask update for all units of both teams
+         def update_unit_vision_power_map(unit_pos, unit_mask, vision_power_map):
+             return jax.lax.cond(
+                 unit_mask,
+                 lambda: update_vision_power_map(unit_pos, vision_power_map),
+                 lambda: vision_power_map,
+             )
+
+         def update_team_vision_power_map(team_units, unit_mask, vision_power_map):
+             def body_fun(carry, i):
+                 vision_power_map = carry
+                 return (
+                     update_unit_vision_power_map(team_units.position[i], unit_mask[i], vision_power_map),
+                     None,
+                 )
+
+             vision_power_map, _ = jax.lax.scan(body_fun, vision_power_map, jnp.arange(self.fixed_env_params.max_units))
+             return vision_power_map
+
+         vision_power_map = jax.vmap(update_team_vision_power_map)(state.units, state.units_mask, vision_power_map)
+         vision_power_map = vision_power_map[
+             :,
+             vision_power_map_padding:-vision_power_map_padding,
+             vision_power_map_padding:-vision_power_map_padding,
+         ]
+         # handle nebula tiles
+         vision_power_map = (
+             vision_power_map
+             - (state.map_features.tile_type == NEBULA_TILE).astype(jnp.int16) * params.nebula_tile_vision_reduction
+         )
+
+         sensor_mask = vision_power_map > 0
+         state = state.replace(sensor_mask=sensor_mask)
+         state = state.replace(vision_power_map=vision_power_map)
+         return state
+
+     # @functools.partial(jax.jit, static_argnums=(0, 4))
+     def step_env(
+         self,
+         key: chex.PRNGKey,
+         state: EnvState,
+         action: Union[int, float, chex.Array],
+         params: EnvParams,
+     ) -> Tuple[EnvObs, EnvState, jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
+         state = self.compute_energy_features(state, params)
+
+         action = jnp.stack([action["player_0"], action["player_1"]])
+
+         # remove all units if the match ended in the previous step, indicated by a reset of match_steps to 0
+         state = state.replace(
+             units_mask=jnp.where(
+                 state.match_steps == 0,
+                 jnp.zeros_like(state.units_mask),
+                 state.units_mask,
+             )
+         )
+         """remove units that have less than 0 energy"""
+         # we remove units at the start of the timestep so that the visualizer can show the unit with negative energy that is marked for removal soon.
+         state = state.replace(units_mask=(state.units.energy[..., 0] >= 0) & state.units_mask)
+
+         """spawn relic nodes based on schedule"""
+         relic_nodes_mask = (state.steps >= state.relic_spawn_schedule) & (state.relic_spawn_schedule != -1)
+         state = state.replace(relic_nodes_mask=relic_nodes_mask)
+
+         """process unit movement"""
+         # 0 is do nothing, 1 is move up, 2 is move right, 3 is move down, 4 is move left, 5 is sap
+         # Define movement directions
+         directions = jnp.array(
+             [
+                 [0, 0],  # Do nothing
+                 [0, -1],  # Move up
+                 [1, 0],  # Move right
+                 [0, 1],  # Move down
+                 [-1, 0],  # Move left
+             ],
+             dtype=jnp.int16,
+         )
+
+         def move_unit(unit: UnitState, action, mask):
+             new_pos = unit.position + directions[action]
+             # Check if the new position is on a map feature of value 2
+             is_blocked = state.map_features.tile_type[new_pos[0], new_pos[1]] == ASTEROID_TILE
+             enough_energy = unit.energy >= params.unit_move_cost
+             # If blocked, keep the original position
+             # new_pos = jnp.where(is_blocked, unit.position, new_pos)
+             # Ensure the new position is within the map boundaries
+             new_pos = jnp.clip(
+                 new_pos,
+                 0,
+                 jnp.array([params.map_width - 1, params.map_height - 1], dtype=jnp.int16),
+             )
+             unit_moved = mask & ~is_blocked & enough_energy & (action < 5) & (action > 0)
+             # Update the unit's position only if it's active. Note energy is used if the unit tries to move off the map. Energy is not used if the unit tries to move into an asteroid tile.
+             return UnitState(
+                 position=jnp.where(unit_moved, new_pos, unit.position),
+                 energy=jnp.where(unit_moved, unit.energy - params.unit_move_cost, unit.energy),
+             )
+
+         # Move units for both teams
+         move_actions = action[..., 0]
+         state = state.replace(
+             units=jax.vmap(
+                 lambda team_units, team_action, team_mask: jax.vmap(move_unit, in_axes=(0, 0, 0))(
+                     team_units, team_action, team_mask
+                 ),
+                 in_axes=(0, 0, 0),
+             )(state.units, move_actions, state.units_mask)
+         )
+
+         original_unit_energy = state.units.energy
+         """original amount of energy of all units"""
+
+         """apply sap actions"""
+         sap_action_mask = action[..., 0] == 5
+         sap_action_deltas = action[..., 1:]
+
+         def sap_unit(
+             current_energy: jnp.ndarray,
+             all_units: UnitState,
+             sap_action_mask,
+             sap_action_deltas,
+             units_mask,
+         ):
+             # TODO (stao): clean up this code. It is probably slower than it needs to be and could perhaps be vmapped.
+             for t in range(self.fixed_env_params.num_teams):
+                 other_team_ids = jnp.array([t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t])
+                 team_sap_action_deltas = sap_action_deltas[t]  # (max_units, 2)
+                 team_sap_action_mask = sap_action_mask[t]
+                 other_team_unit_mask = units_mask[other_team_ids]  # (other_teams, max_units)
+                 team_sapped_positions = all_units.position[t] + team_sap_action_deltas  # (max_units, 2)
+                 # whether the unit is really sapping or not (needs to exist, have enough energy, and a valid sap action)
+                 team_unit_sapped = (
+                     units_mask[t]
+                     & team_sap_action_mask
+                     & (current_energy[t, :, 0] >= params.unit_sap_cost)
+                     & (jnp.max(jnp.abs(team_sap_action_deltas), axis=-1) <= params.unit_sap_range)
+                 )  # (max_units)
+                 team_unit_sapped = (
+                     team_unit_sapped
+                     & (team_sapped_positions >= 0).all(-1)
+                     & (team_sapped_positions[:, 0] < self.fixed_env_params.map_width)
+                     & (team_sapped_positions[:, 1] < self.fixed_env_params.map_height)
+                 )
+                 # the number of times other units are sapped
+                 other_units_sapped_count = jnp.sum(
+                     team_unit_sapped[None, None, :]
+                     & jnp.all(
+                         all_units.position[other_team_ids][:, :, None] == team_sapped_positions[None],
+                         axis=-1,
+                     ),
+                     axis=-1,
+                     dtype=jnp.int16,
+                 )  # (len(other_team_ids), max_units)
+                 # remove unit_sap_cost energy from opposition units that were in the middle of a sap action.
+                 all_units = all_units.replace(
+                     energy=all_units.energy.at[other_team_ids].set(
+                         jnp.where(
+                             other_team_unit_mask[:, :, None] & (other_units_sapped_count[:, :, None] > 0),
+                             all_units.energy[other_team_ids]
+                             - params.unit_sap_cost * other_units_sapped_count[:, :, None],
+                             all_units.energy[other_team_ids],
+                         )
+                     )
+                 )
+
+                 # remove unit_sap_cost * unit_sap_dropoff_factor energy from opposition units that were on tiles adjacent to the center of a sap action.
+                 adjacent_offsets = jnp.array(
+                     [
+                         [-1, -1],
+                         [-1, 0],
+                         [-1, 1],
+                         [0, -1],
+                         [0, 1],
+                         [1, -1],
+                         [1, 0],
+                         [1, 1],
+                     ],
+                     dtype=jnp.int16,
+                 )
+                 team_sapped_adjacent_positions = (
+                     team_sapped_positions[:, None, :] + adjacent_offsets
+                 )  # (max_units, len(adjacent_offsets), 2)
+                 other_units_adjacent_sapped_count = jnp.sum(
+                     team_unit_sapped[None, None, :, None]
+                     & jnp.all(
+                         all_units.position[other_team_ids][:, :, None, None] == team_sapped_adjacent_positions[None],
+                         axis=-1,
+                     ),
+                     axis=(-1, -2),
+                     dtype=jnp.int16,
+                 )  # (len(other_team_ids), max_units)
+                 all_units = all_units.replace(
+                     energy=all_units.energy.at[other_team_ids].set(
+                         jnp.where(
+                             other_team_unit_mask[:, :, None] & (other_units_adjacent_sapped_count[:, :, None] > 0),
+                             all_units.energy[other_team_ids]
+                             - jnp.array(
+                                 jnp.array(params.unit_sap_cost, dtype=jnp.float32)
+                                 * params.unit_sap_dropoff_factor
+                                 * other_units_adjacent_sapped_count[:, :, None].astype(jnp.float32),
+                                 dtype=jnp.int16,
+                             ),
+                             all_units.energy[other_team_ids],
+                         )
+                     )
+                 )
+
+                 # remove unit_sap_cost energy from units that tried to sap some position within the unit's range
+                 all_units = all_units.replace(
+                     energy=all_units.energy.at[t].set(
+                         jnp.where(
+                             team_unit_sapped[:, None],
+                             all_units.energy[t] - params.unit_sap_cost,
+                             all_units.energy[t],
+                         )
+                     )
+                 )
+             return all_units
+
+         state = state.replace(
+             units=sap_unit(
+                 original_unit_energy,
+                 state.units,
+                 sap_action_mask,
+                 sap_action_deltas,
+                 state.units_mask,
+             )
+         )
+
+         """resolve collisions and energy void fields"""
+
+         # compute energy void fields for all teams and the energy + unit counts
+         unit_aggregate_energy_void_map = jnp.zeros(
+             shape=(self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height),
+             dtype=jnp.int16,
+         )
+         unit_counts_map = self.compute_unit_counts_map(state, params)
+         unit_aggregate_energy_map = jnp.zeros(
+             shape=(self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height),
+             dtype=jnp.int16,
+         )
+         for t in range(self.fixed_env_params.num_teams):
+
+             def scan_body(carry, x):
+                 agg_energy_void_map, agg_energy_map = carry
+                 unit_energy, unit_position, unit_mask = x
+                 agg_energy_map = agg_energy_map.at[unit_position[0], unit_position[1]].add(
+                     unit_energy[0] * unit_mask.astype(jnp.int16)
+                 )
+                 for deltas in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
+                     new_pos = unit_position + jnp.array(deltas, dtype=jnp.int16)
+                     in_map = (
+                         (new_pos[0] >= 0)
+                         & (new_pos[0] < self.fixed_env_params.map_width)
+                         & (new_pos[1] >= 0)
+                         & (new_pos[1] < self.fixed_env_params.map_height)
+                     )
+                     agg_energy_void_map = agg_energy_void_map.at[new_pos[0], new_pos[1]].add(
+                         unit_energy[0] * unit_mask.astype(jnp.int16) * in_map.astype(jnp.int16)
+                     )
+                 return (agg_energy_void_map, agg_energy_map), None
+
+             agg_energy_void_map, agg_energy_map = jax.lax.scan(
+                 scan_body,
+                 (unit_aggregate_energy_void_map[t], unit_aggregate_energy_map[t]),
+                 (original_unit_energy[t], state.units.position[t], state.units_mask[t]),
+             )[0]
+             unit_aggregate_energy_void_map = unit_aggregate_energy_void_map.at[t].add(agg_energy_void_map)
+             unit_aggregate_energy_map = unit_aggregate_energy_map.at[t].add(agg_energy_map)
+
+         # resolve collisions and keep only the surviving units
+         for t in range(self.fixed_env_params.num_teams):
+             other_team_ids = jnp.array([t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t])
+             # get the energy map for the current team
+             opposing_unit_counts_map = unit_counts_map[other_team_ids].sum(axis=0)  # (map_width, map_height)
+             team_energy_map = unit_aggregate_energy_map[t]
+             opposing_aggregate_energy_map = unit_aggregate_energy_map[other_team_ids].max(
+                 axis=0
+             )  # (map_width, map_height)
+             # unit survives if there are no opposing units on the tile, or if the opposing unit stack has less energy on the tile than the current unit
+             surviving_unit_mask = jax.vmap(
+                 lambda unit_position: (opposing_unit_counts_map[unit_position[0], unit_position[1]] == 0)
+                 | (
+                     opposing_aggregate_energy_map[unit_position[0], unit_position[1]]
+                     < team_energy_map[unit_position[0], unit_position[1]]
+                 )
+             )(state.units.position[t])
+             state = state.replace(units_mask=state.units_mask.at[t].set(surviving_unit_mask & state.units_mask[t]))
+         # apply energy void fields
+         for t in range(self.fixed_env_params.num_teams):
+             other_team_ids = jnp.array([t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t])
+             oppposition_energy_void_map = unit_aggregate_energy_void_map[other_team_ids].sum(
+                 axis=0
+             )  # (map_width, map_height)
+             # unit on team t loses energy to the void field equal to params.unit_energy_void_factor * void_energy / num units stacked with the unit on the same tile
+             team_unit_energy = state.units.energy[t] - jnp.floor(
+                 jax.vmap(
+                     lambda unit_position: params.unit_energy_void_factor
+                     * oppposition_energy_void_map[unit_position[0], unit_position[1]].astype(jnp.float32)
+                     / unit_counts_map[t][unit_position[0], unit_position[1]].astype(jnp.float32)
+                 )(state.units.position[t])[..., None]
+             ).astype(jnp.int16)
+             state = state.replace(units=state.units.replace(energy=state.units.energy.at[t].set(team_unit_energy)))
+
+         """apply energy field to the units"""
+
+         # Update unit energy based on the energy field and nebula tile of their current position
+         def update_unit_energy(unit: UnitState, mask):
+             x, y = unit.position
+             energy_gain = (
+                 state.map_features.energy[x, y]
+                 - (state.map_features.tile_type[x, y] == NEBULA_TILE).astype(jnp.int16)
+                 * params.nebula_tile_energy_reduction
+             )
+             # if energy gain is less than 0
+             # new_energy = jnp.where((unit.energy < 0) & (energy_gain < 0))
+             new_energy = jnp.clip(
+                 unit.energy + energy_gain,
+                 params.min_unit_energy,
+                 params.max_unit_energy,
+             )
+             # if unit already had negative energy due to opposition units and after energy field/nebula tile it is still below 0, then it will be removed next step
+             # and we keep its energy value at whatever it is
+             new_energy = jnp.where(
+                 (unit.energy < 0) & (unit.energy + energy_gain < 0),
+                 unit.energy,
+                 new_energy,
+             )
+             return UnitState(position=unit.position, energy=jnp.where(mask, new_energy, unit.energy))
+
+         # Apply the energy update for all units of both teams
+         state = state.replace(
+             units=jax.vmap(lambda team_units, team_mask: jax.vmap(update_unit_energy)(team_units, team_mask))(
+                 state.units, state.units_mask
+             )
+         )
+
+         """spawn new units in"""
+         spawn_units_in = state.match_steps % params.spawn_rate == 0
+
+         # TODO (stao): the only logic in the code that probably doesn't handle more than 2 teams; everything else is vmapped across teams
+         def spawn_team_units(state: EnvState):
+             team_0_unit_count = state.units_mask[0].sum()
+             team_1_unit_count = state.units_mask[1].sum()
+             team_0_new_unit_id = state.units_mask[0].argmin()
+             team_1_new_unit_id = state.units_mask[1].argmin()
+             state = state.replace(
+                 units=state.units.replace(
+                     position=jnp.where(
+                         team_0_unit_count < params.max_units,
+                         state.units.position.at[0, team_0_new_unit_id, :].set(jnp.array([0, 0], dtype=jnp.int16)),
+                         state.units.position,
+                     )
+                 )
+             )
+             state = state.replace(
+                 units=state.units.replace(
+                     energy=jnp.where(
+                         team_0_unit_count < params.max_units,
+                         state.units.energy.at[0, team_0_new_unit_id, :].set(
+                             jnp.array([params.init_unit_energy], dtype=jnp.int16)
+                         ),
+                         state.units.energy,
+                     )
+                 )
+             )
+             state = state.replace(
+                 units=state.units.replace(
+                     position=jnp.where(
+                         team_1_unit_count < params.max_units,
+                         state.units.position.at[1, team_1_new_unit_id, :].set(
+                             jnp.array(
+                                 [params.map_width - 1, params.map_height - 1],
+                                 dtype=jnp.int16,
+                             )
+                         ),
+                         state.units.position,
+                     )
+                 )
+             )
+             state = state.replace(
+                 units=state.units.replace(
+                     energy=jnp.where(
+                         team_1_unit_count < params.max_units,
+                         state.units.energy.at[1, team_1_new_unit_id, :].set(
+                             jnp.array([params.init_unit_energy], dtype=jnp.int16)
+                         ),
+                         state.units.energy,
+                     )
+                 )
+             )
+             state = state.replace(
+                 units_mask=state.units_mask.at[0, team_0_new_unit_id].set(
+                     jnp.where(
+                         team_0_unit_count < params.max_units,
+                         True,
+                         state.units_mask[0, team_0_new_unit_id],
+                     )
+                 )
+             )
+             state = state.replace(
+                 units_mask=state.units_mask.at[1, team_1_new_unit_id].set(
+                     jnp.where(
+                         team_1_unit_count < params.max_units,
+                         True,
+                         state.units_mask[1, team_1_new_unit_id],
+                     )
+                 )
+             )
+             # state = jnp.where(team_0_unit_count < params.max_units, spawn_unit(state, 0, team_0_new_unit_id, [0, 0], params), state)
+             # state = jnp.where(team_1_unit_count < params.max_units, spawn_unit(state, 1, team_1_new_unit_id, [params.map_width - 1, params.map_height - 1], params), state)
+             return state
+
+         state = jax.lax.cond(spawn_units_in, lambda: spawn_team_units(state), lambda: state)
+
+         state = self.compute_sensor_masks(state, params)
+
+         # Shift objects around in space
+         # Move the nebula tiles in state.map_features.tile_types up by 1 and to the right by 1
+         # this is also symmetric nebula tile movement
+         new_tile_types_map = jnp.roll(
+             state.map_features.tile_type,
+             shift=(
+                 1 * jnp.sign(params.nebula_tile_drift_speed),
+                 -1 * jnp.sign(params.nebula_tile_drift_speed),
+             ),
+             axis=(0, 1),
+         )
+         new_tile_types_map = jnp.where(
+             (state.steps - 1) * abs(params.nebula_tile_drift_speed) % 1
+             > state.steps * abs(params.nebula_tile_drift_speed) % 1,
+             new_tile_types_map,
+             state.map_features.tile_type,
+         )
+
+         energy_node_deltas = jnp.round(
+             jax.random.uniform(
+                 key=key,
+                 shape=(self.fixed_env_params.max_energy_nodes // 2, 2),
+                 minval=-params.energy_node_drift_magnitude,
+                 maxval=params.energy_node_drift_magnitude,
+             )
+         ).astype(jnp.int16)
+         energy_node_deltas_symmetric = jnp.stack([-energy_node_deltas[:, 1], -energy_node_deltas[:, 0]], axis=-1)
+         energy_node_deltas = jnp.concatenate((energy_node_deltas, energy_node_deltas_symmetric))
+         new_energy_nodes = jnp.clip(
+             state.energy_nodes + energy_node_deltas,
+             jnp.array([0, 0], dtype=jnp.int16),
+             jnp.array([self.fixed_env_params.map_width - 1, self.fixed_env_params.map_height - 1], dtype=jnp.int16),
+         )
+         new_energy_nodes = jnp.where(
+             (state.steps - 1) * abs(params.energy_node_drift_speed) % 1
+             > state.steps * abs(params.energy_node_drift_speed) % 1,
+             new_energy_nodes,
+             state.energy_nodes,
+         )
+         state = state.replace(
+             map_features=state.map_features.replace(tile_type=new_tile_types_map),
+             energy_nodes=new_energy_nodes,
+         )
+
+         # Compute relic scores
+         def team_relic_score(unit_counts_map):
+             # not all relic nodes are spawned in yet, but relic node map ids are precomputed for all to-be-spawned relic nodes
+             # for efficiency. So we check if the relic node (by id) is spawned in yet. The relic nodes mask is always increasing so we can do the simple trick below
+             scores = (
+                 (unit_counts_map > 0)
+                 & (state.relic_nodes_map_weights <= state.relic_nodes_mask.sum() // 2)
+                 & (state.relic_nodes_map_weights > 0)
+             )
+             return jnp.sum(scores, dtype=jnp.int32)
+
+         # note we need to recompute unit counts since units can get removed due to collisions
+         team_scores = jax.vmap(team_relic_score)(
+             self.compute_unit_counts_map(state, params, exclude_negative_energy_units=True)
+         )
+         # Update team points
+         state = state.replace(team_points=state.team_points + team_scores)
+
+         # if match ended, then remove all units, update team wins, reset team points
+         winner_by_points = jnp.where(
+             state.team_points.max() > state.team_points.min(),
+             jnp.argmax(state.team_points),
+             -1,
+         )
+         winner_by_energy = jnp.sum(state.units.energy[..., 0] * state.units_mask.astype(jnp.int16), axis=1)
+         winner_by_energy = jnp.where(
+             winner_by_energy.max() > winner_by_energy.min(),
+             jnp.argmax(winner_by_energy),
+             -1,
+         )
+
+         winner = jnp.where(
+             winner_by_points != -1,
+             winner_by_points,
+             jnp.where(
+                 winner_by_energy != -1,
+                 winner_by_energy,
+                 jax.random.randint(key, shape=(), minval=0, maxval=params.num_teams),
+             ),
+         )
+         match_ended = state.match_steps >= params.max_steps_in_match
+
+         state = state.replace(
+             match_steps=jnp.where(match_ended, -1, state.match_steps),
+             team_points=jnp.where(match_ended, jnp.zeros_like(state.team_points), state.team_points),
+             team_wins=jnp.where(match_ended, state.team_wins.at[winner].add(1), state.team_wins),
+         )
+         # Update state's step count
+         state = state.replace(steps=state.steps + 1, match_steps=state.match_steps + 1)
+         truncated = state.steps >= (params.max_steps_in_match + 1) * params.match_count_per_episode
+         reward = dict()
+         for k in range(self.fixed_env_params.num_teams):
+             reward[f"player_{k}"] = state.team_wins[k]
+         terminated = self.is_terminal(state, params)
+         return (
+             lax.stop_gradient(self.get_obs(state, params, key=key)),
+             lax.stop_gradient(state),
+             reward,
+             terminated,
+             truncated,
+             {"discount": self.discount(state, params)},
+         )
+
+     def reset_env(self, key: chex.PRNGKey, params: EnvParams) -> Tuple[EnvObs, EnvState]:
+         """Reset environment state by sampling initial position."""
+
+         state = gen_state(
+             key=key,
+             env_params=params,
+             max_units=self.fixed_env_params.max_units,
+             num_teams=self.fixed_env_params.num_teams,
+             map_type=self.fixed_env_params.map_type,
+             map_width=self.fixed_env_params.map_width,
+             map_height=self.fixed_env_params.map_height,
+             max_energy_nodes=self.fixed_env_params.max_energy_nodes,
+             max_relic_nodes=self.fixed_env_params.max_relic_nodes,
+             relic_config_size=self.fixed_env_params.relic_config_size,
+         )
+         state = self.compute_energy_features(state, params)
+         state = self.compute_sensor_masks(state, params)
+         return self.get_obs(state, params=params, key=key), state
+
+     @functools.partial(jax.jit, static_argnums=(0,))
+     def step(
+         self,
+         key: chex.PRNGKey,
+         state: EnvState,
+         action: Union[int, float, chex.Array],
+         params: Optional[EnvParams] = None,
+     ) -> Tuple[EnvObs, EnvState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
+         """Performs step transitions in the environment."""
+         # Use default env parameters if no others specified
+         if params is None:
+             params = self.default_params
+         key, key_reset = jax.random.split(key)
+         obs_st, state_st, reward, terminated, truncated, info = self.step_env(key, state, action, params)
+         info["final_state"] = state_st
+         info["final_observation"] = obs_st
+         done = terminated | truncated
+
+         if self.auto_reset:
+             obs_re, state_re = self.reset_env(key_reset, params)
+             # Use lax.cond to efficiently choose between obs_re and obs_st
+             obs = jax.lax.cond(done, lambda: obs_re, lambda: obs_st)
+             state = jax.lax.cond(done, lambda: state_re, lambda: state_st)
+         else:
+             obs = obs_st
+             state = state_st
+
+         # all agents terminate/truncate at the same time
+         terminated_dict = dict()
+         truncated_dict = dict()
+         for k in range(self.fixed_env_params.num_teams):
+             terminated_dict[f"player_{k}"] = terminated
+             truncated_dict[f"player_{k}"] = truncated
+             info[f"player_{k}"] = dict()
+         return obs, state, reward, terminated_dict, truncated_dict, info
+
+     @functools.partial(jax.jit, static_argnums=(0,))
+     def reset(self, key: chex.PRNGKey, params: Optional[EnvParams] = None) -> Tuple[chex.Array, EnvState]:
+         """Performs resetting of environment."""
+         # Use default env parameters if no others specified
+         if params is None:
+             params = self.default_params
+
+         obs, state = self.reset_env(key, params)
+         return obs, state
+
+     # @functools.partial(jax.jit, static_argnums=(0, 2))
+     def get_obs(self, state: EnvState, params=None, key=None) -> EnvObs:
+         """Return observation from raw state, handling partial observability."""
+         obs = dict()
+
+         def update_unit_mask(unit_position, unit_mask, sensor_mask):
+             return unit_mask & sensor_mask[unit_position[0], unit_position[1]]
+
+         def update_team_unit_mask(unit_position, unit_mask, sensor_mask):
+             return jax.vmap(update_unit_mask, in_axes=(0, 0, None))(unit_position, unit_mask, sensor_mask)
+
+         def update_relic_nodes_mask(relic_nodes_mask, relic_nodes, sensor_mask):
+             return jax.vmap(
+                 lambda r_mask, r, s_mask: r_mask & s_mask[r[0], r[1]],
+                 in_axes=(0, 0, None),
+             )(relic_nodes_mask, relic_nodes, sensor_mask)
+
+         for t in range(self.fixed_env_params.num_teams):
+             other_team_ids = jnp.array([t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t])
+             new_unit_masks = jax.vmap(update_team_unit_mask, in_axes=(0, 0, None))(
+                 state.units.position[other_team_ids],
+                 state.units_mask[other_team_ids],
+                 state.sensor_mask[t],
+             )
+             new_unit_masks = state.units_mask.at[other_team_ids].set(new_unit_masks)
+
+             new_relic_nodes_mask = update_relic_nodes_mask(
+                 state.relic_nodes_mask, state.relic_nodes, state.sensor_mask[t]
+             )
+             team_obs = EnvObs(
+                 units=UnitState(
+                     position=jnp.where(new_unit_masks[..., None], state.units.position, -1),
+                     energy=jnp.where(new_unit_masks[..., None], state.units.energy, -1)[..., 0],
+                 ),
+                 units_mask=new_unit_masks,
+                 sensor_mask=state.sensor_mask[t],
+                 map_features=MapTile(
+                     energy=jnp.where(state.sensor_mask[t], state.map_features.energy, -1),
+                     tile_type=jnp.where(state.sensor_mask[t], state.map_features.tile_type, -1),
+                 ),
+                 team_points=state.team_points,
+                 team_wins=state.team_wins,
+                 steps=state.steps,
+                 match_steps=state.match_steps,
+                 relic_nodes=jnp.where(new_relic_nodes_mask[..., None], state.relic_nodes, -1),
+                 relic_nodes_mask=new_relic_nodes_mask,
+             )
+             obs[f"player_{t}"] = team_obs
+         return obs
+
+     @functools.partial(jax.jit, static_argnums=(0,))
+     def is_terminal(self, state: EnvState, params: EnvParams) -> jnp.ndarray:
+         """Check whether state is terminal. This never occurs. The game is only done when the time limit is reached."""
+         terminated = jnp.array(False)
+         return terminated
+
+     @property
+     def name(self) -> str:
+         """Environment name."""
+         return "Lux AI Season 3"
+
+     def render(self, state: EnvState, params: EnvParams):
+         self.renderer.render(state, params)
+
+     def action_space(self, params: Optional[EnvParams] = None):
+         """Action space of the environment."""
+         low = np.zeros((self.fixed_env_params.max_units, 3))
+         low[:, 1:] = -env_params_ranges["unit_sap_range"][-1]
+         high = np.ones((self.fixed_env_params.max_units, 3)) * 6
+         high[:, 1:] = env_params_ranges["unit_sap_range"][-1]
+         return spaces.Dict(dict(player_0=MultiDiscrete(low, high), player_1=MultiDiscrete(low, high)))
+
+     def observation_space(self, params: EnvParams):
+         """Observation space of the environment."""
+         return spaces.Discrete(10)
+
+     def state_space(self, params: EnvParams):
+         """State space of the environment."""
+         return spaces.Discrete(10)
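
For orientation only (not part of the diff): a minimal sketch of driving the gymnax-style LuxAIS3Env shown above directly. It assumes the bundled luxai_s3 package is importable under that name, as its own imports suggest, and that the default EnvParams are used; the all-zero action is simply the "do nothing" action type from the movement table in step_env.

# Hedged usage sketch for the LuxAIS3Env added in this file.
import jax
import jax.numpy as jnp
from luxai_s3.env import LuxAIS3Env

env = LuxAIS3Env(auto_reset=False)
params = env.default_params
key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
obs, state = env.reset(reset_key, params)

# Each player submits a (max_units, 3) integer array of [action_type, sap_dx, sap_dy];
# action_type 0 is a no-op, 1-4 are moves, 5 is sap.
max_units = env.fixed_env_params.max_units
action = {
    "player_0": jnp.zeros((max_units, 3), dtype=jnp.int16),
    "player_1": jnp.zeros((max_units, 3), dtype=jnp.int16),
}
key, step_key = jax.random.split(key)
obs, state, reward, terminated, truncated, info = env.step(step_key, state, action, params)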