kaggle-environments 0.2.1__py3-none-any.whl → 1.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaggle-environments might be problematic. Click here for more details.

Files changed (215)
  1. kaggle_environments/__init__.py +49 -13
  2. kaggle_environments/agent.py +177 -124
  3. kaggle_environments/api.py +31 -0
  4. kaggle_environments/core.py +295 -170
  5. kaggle_environments/envs/cabt/cabt.js +164 -0
  6. kaggle_environments/envs/cabt/cabt.json +28 -0
  7. kaggle_environments/envs/cabt/cabt.py +186 -0
  8. kaggle_environments/envs/cabt/cg/__init__.py +0 -0
  9. kaggle_environments/envs/cabt/cg/cg.dll +0 -0
  10. kaggle_environments/envs/cabt/cg/game.py +75 -0
  11. kaggle_environments/envs/cabt/cg/libcg.so +0 -0
  12. kaggle_environments/envs/cabt/cg/sim.py +48 -0
  13. kaggle_environments/envs/cabt/test_cabt.py +120 -0
  14. kaggle_environments/envs/chess/chess.js +4289 -0
  15. kaggle_environments/envs/chess/chess.json +60 -0
  16. kaggle_environments/envs/chess/chess.py +4241 -0
  17. kaggle_environments/envs/chess/test_chess.py +60 -0
  18. kaggle_environments/envs/connectx/connectx.ipynb +3186 -0
  19. kaggle_environments/envs/connectx/connectx.js +1 -1
  20. kaggle_environments/envs/connectx/connectx.json +15 -1
  21. kaggle_environments/envs/connectx/connectx.py +6 -23
  22. kaggle_environments/envs/connectx/test_connectx.py +70 -24
  23. kaggle_environments/envs/football/football.ipynb +75 -0
  24. kaggle_environments/envs/football/football.json +91 -0
  25. kaggle_environments/envs/football/football.py +277 -0
  26. kaggle_environments/envs/football/helpers.py +95 -0
  27. kaggle_environments/envs/football/test_football.py +360 -0
  28. kaggle_environments/envs/halite/__init__.py +0 -0
  29. kaggle_environments/envs/halite/halite.ipynb +44741 -0
  30. kaggle_environments/envs/halite/halite.js +199 -83
  31. kaggle_environments/envs/halite/halite.json +31 -18
  32. kaggle_environments/envs/halite/halite.py +164 -303
  33. kaggle_environments/envs/halite/helpers.py +720 -0
  34. kaggle_environments/envs/halite/test_halite.py +190 -0
  35. kaggle_environments/envs/hungry_geese/__init__.py +0 -0
  36. kaggle_environments/envs/{battlegeese/battlegeese.js → hungry_geese/hungry_geese.js} +38 -22
  37. kaggle_environments/envs/{battlegeese/battlegeese.json → hungry_geese/hungry_geese.json} +21 -14
  38. kaggle_environments/envs/hungry_geese/hungry_geese.py +316 -0
  39. kaggle_environments/envs/hungry_geese/test_hungry_geese.py +0 -0
  40. kaggle_environments/envs/identity/identity.json +6 -5
  41. kaggle_environments/envs/identity/identity.py +15 -2
  42. kaggle_environments/envs/kore_fleets/__init__.py +0 -0
  43. kaggle_environments/envs/kore_fleets/helpers.py +1005 -0
  44. kaggle_environments/envs/kore_fleets/kore_fleets.ipynb +114 -0
  45. kaggle_environments/envs/kore_fleets/kore_fleets.js +658 -0
  46. kaggle_environments/envs/kore_fleets/kore_fleets.json +164 -0
  47. kaggle_environments/envs/kore_fleets/kore_fleets.py +555 -0
  48. kaggle_environments/envs/kore_fleets/starter_bots/java/Bot.java +54 -0
  49. kaggle_environments/envs/kore_fleets/starter_bots/java/README.md +26 -0
  50. kaggle_environments/envs/kore_fleets/starter_bots/java/jars/hamcrest-core-1.3.jar +0 -0
  51. kaggle_environments/envs/kore_fleets/starter_bots/java/jars/junit-4.13.2.jar +0 -0
  52. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Board.java +518 -0
  53. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Cell.java +61 -0
  54. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Configuration.java +24 -0
  55. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Direction.java +166 -0
  56. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Fleet.java +72 -0
  57. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/KoreJson.java +97 -0
  58. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Observation.java +72 -0
  59. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Pair.java +13 -0
  60. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Player.java +68 -0
  61. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Point.java +65 -0
  62. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Shipyard.java +70 -0
  63. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/ShipyardAction.java +59 -0
  64. kaggle_environments/envs/kore_fleets/starter_bots/java/main.py +73 -0
  65. kaggle_environments/envs/kore_fleets/starter_bots/java/test/BoardTest.java +567 -0
  66. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ConfigurationTest.java +25 -0
  67. kaggle_environments/envs/kore_fleets/starter_bots/java/test/KoreJsonTest.java +62 -0
  68. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ObservationTest.java +46 -0
  69. kaggle_environments/envs/kore_fleets/starter_bots/java/test/PointTest.java +21 -0
  70. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ShipyardTest.java +22 -0
  71. kaggle_environments/envs/kore_fleets/starter_bots/java/test/configuration.json +1 -0
  72. kaggle_environments/envs/kore_fleets/starter_bots/java/test/fullob.json +1 -0
  73. kaggle_environments/envs/kore_fleets/starter_bots/java/test/observation.json +1 -0
  74. kaggle_environments/envs/kore_fleets/starter_bots/python/__init__.py +0 -0
  75. kaggle_environments/envs/kore_fleets/starter_bots/python/main.py +27 -0
  76. kaggle_environments/envs/kore_fleets/starter_bots/ts/Bot.ts +34 -0
  77. kaggle_environments/envs/kore_fleets/starter_bots/ts/DoNothingBot.ts +12 -0
  78. kaggle_environments/envs/kore_fleets/starter_bots/ts/MinerBot.ts +62 -0
  79. kaggle_environments/envs/kore_fleets/starter_bots/ts/README.md +55 -0
  80. kaggle_environments/envs/kore_fleets/starter_bots/ts/interpreter.ts +402 -0
  81. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Board.ts +514 -0
  82. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Cell.ts +63 -0
  83. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Configuration.ts +25 -0
  84. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Direction.ts +169 -0
  85. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Fleet.ts +76 -0
  86. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/KoreIO.ts +70 -0
  87. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Observation.ts +45 -0
  88. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Pair.ts +11 -0
  89. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Player.ts +68 -0
  90. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Point.ts +65 -0
  91. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Shipyard.ts +72 -0
  92. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/ShipyardAction.ts +58 -0
  93. kaggle_environments/envs/kore_fleets/starter_bots/ts/main.py +73 -0
  94. kaggle_environments/envs/kore_fleets/starter_bots/ts/miner.py +73 -0
  95. kaggle_environments/envs/kore_fleets/starter_bots/ts/package.json +23 -0
  96. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/BoardTest.ts +551 -0
  97. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ConfigurationTest.ts +16 -0
  98. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ObservationTest.ts +33 -0
  99. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/PointTest.ts +17 -0
  100. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ShipyardTest.ts +18 -0
  101. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/configuration.json +1 -0
  102. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/fullob.json +1 -0
  103. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/observation.json +1 -0
  104. kaggle_environments/envs/kore_fleets/starter_bots/ts/tsconfig.json +22 -0
  105. kaggle_environments/envs/kore_fleets/test_kore_fleets.py +331 -0
  106. kaggle_environments/envs/lux_ai_2021/README.md +3 -0
  107. kaggle_environments/envs/lux_ai_2021/__init__.py +0 -0
  108. kaggle_environments/envs/lux_ai_2021/agents.py +11 -0
  109. kaggle_environments/envs/lux_ai_2021/dimensions/754.js +2 -0
  110. kaggle_environments/envs/lux_ai_2021/dimensions/754.js.LICENSE.txt +296 -0
  111. kaggle_environments/envs/lux_ai_2021/dimensions/main.js +1 -0
  112. kaggle_environments/envs/lux_ai_2021/index.html +43 -0
  113. kaggle_environments/envs/lux_ai_2021/lux_ai_2021.json +100 -0
  114. kaggle_environments/envs/lux_ai_2021/lux_ai_2021.py +231 -0
  115. kaggle_environments/envs/lux_ai_2021/test_agents/__init__.py +0 -0
  116. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_constants.js +6 -0
  117. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_constants.json +59 -0
  118. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_objects.js +145 -0
  119. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/io.js +14 -0
  120. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/kit.js +209 -0
  121. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/map.js +107 -0
  122. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/parser.js +79 -0
  123. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/main.js +88 -0
  124. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/main.py +75 -0
  125. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/simple.tar.gz +0 -0
  126. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/__init__.py +0 -0
  127. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/annotate.py +20 -0
  128. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/constants.py +25 -0
  129. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game.py +86 -0
  130. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_constants.json +59 -0
  131. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_constants.py +7 -0
  132. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_map.py +106 -0
  133. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_objects.py +154 -0
  134. kaggle_environments/envs/lux_ai_2021/test_agents/python/random_agent.py +38 -0
  135. kaggle_environments/envs/lux_ai_2021/test_agents/python/simple_agent.py +82 -0
  136. kaggle_environments/envs/lux_ai_2021/test_lux.py +19 -0
  137. kaggle_environments/envs/lux_ai_2021/testing.md +23 -0
  138. kaggle_environments/envs/lux_ai_2021/todo.md.og +18 -0
  139. kaggle_environments/envs/lux_ai_s3/README.md +21 -0
  140. kaggle_environments/envs/lux_ai_s3/agents.py +5 -0
  141. kaggle_environments/envs/lux_ai_s3/index.html +42 -0
  142. kaggle_environments/envs/lux_ai_s3/lux_ai_s3.json +47 -0
  143. kaggle_environments/envs/lux_ai_s3/lux_ai_s3.py +178 -0
  144. kaggle_environments/envs/lux_ai_s3/luxai_s3/__init__.py +1 -0
  145. kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py +819 -0
  146. kaggle_environments/envs/lux_ai_s3/luxai_s3/globals.py +9 -0
  147. kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py +101 -0
  148. kaggle_environments/envs/lux_ai_s3/luxai_s3/profiler.py +141 -0
  149. kaggle_environments/envs/lux_ai_s3/luxai_s3/pygame_render.py +222 -0
  150. kaggle_environments/envs/lux_ai_s3/luxai_s3/spaces.py +27 -0
  151. kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py +464 -0
  152. kaggle_environments/envs/lux_ai_s3/luxai_s3/utils.py +12 -0
  153. kaggle_environments/envs/lux_ai_s3/luxai_s3/wrappers.py +156 -0
  154. kaggle_environments/envs/lux_ai_s3/test_agents/python/agent.py +78 -0
  155. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/__init__.py +0 -0
  156. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/kit.py +31 -0
  157. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/utils.py +17 -0
  158. kaggle_environments/envs/lux_ai_s3/test_agents/python/main.py +66 -0
  159. kaggle_environments/envs/lux_ai_s3/test_lux.py +9 -0
  160. kaggle_environments/envs/mab/__init__.py +0 -0
  161. kaggle_environments/envs/mab/agents.py +12 -0
  162. kaggle_environments/envs/mab/mab.js +100 -0
  163. kaggle_environments/envs/mab/mab.json +74 -0
  164. kaggle_environments/envs/mab/mab.py +146 -0
  165. kaggle_environments/envs/open_spiel/__init__.py +0 -0
  166. kaggle_environments/envs/open_spiel/games/__init__.py +0 -0
  167. kaggle_environments/envs/open_spiel/games/chess/chess.js +441 -0
  168. kaggle_environments/envs/open_spiel/games/chess/image_config.jsonl +20 -0
  169. kaggle_environments/envs/open_spiel/games/chess/openings.jsonl +20 -0
  170. kaggle_environments/envs/open_spiel/games/connect_four/__init__.py +0 -0
  171. kaggle_environments/envs/open_spiel/games/connect_four/connect_four.js +284 -0
  172. kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy.py +86 -0
  173. kaggle_environments/envs/open_spiel/games/go/__init__.py +0 -0
  174. kaggle_environments/envs/open_spiel/games/go/go.js +481 -0
  175. kaggle_environments/envs/open_spiel/games/go/go_proxy.py +99 -0
  176. kaggle_environments/envs/open_spiel/games/tic_tac_toe/__init__.py +0 -0
  177. kaggle_environments/envs/open_spiel/games/tic_tac_toe/tic_tac_toe.js +345 -0
  178. kaggle_environments/envs/open_spiel/games/tic_tac_toe/tic_tac_toe_proxy.py +98 -0
  179. kaggle_environments/envs/open_spiel/games/universal_poker/__init__.py +0 -0
  180. kaggle_environments/envs/open_spiel/games/universal_poker/universal_poker.js +431 -0
  181. kaggle_environments/envs/open_spiel/games/universal_poker/universal_poker_proxy.py +159 -0
  182. kaggle_environments/envs/open_spiel/html_playthrough_generator.py +31 -0
  183. kaggle_environments/envs/open_spiel/observation.py +128 -0
  184. kaggle_environments/envs/open_spiel/open_spiel.py +565 -0
  185. kaggle_environments/envs/open_spiel/proxy.py +138 -0
  186. kaggle_environments/envs/open_spiel/test_open_spiel.py +191 -0
  187. kaggle_environments/envs/rps/__init__.py +0 -0
  188. kaggle_environments/envs/rps/agents.py +84 -0
  189. kaggle_environments/envs/rps/helpers.py +25 -0
  190. kaggle_environments/envs/rps/rps.js +117 -0
  191. kaggle_environments/envs/rps/rps.json +63 -0
  192. kaggle_environments/envs/rps/rps.py +90 -0
  193. kaggle_environments/envs/rps/test_rps.py +110 -0
  194. kaggle_environments/envs/rps/utils.py +7 -0
  195. kaggle_environments/envs/tictactoe/test_tictactoe.py +43 -77
  196. kaggle_environments/envs/tictactoe/tictactoe.ipynb +1397 -0
  197. kaggle_environments/envs/tictactoe/tictactoe.json +10 -2
  198. kaggle_environments/envs/tictactoe/tictactoe.py +1 -1
  199. kaggle_environments/errors.py +2 -4
  200. kaggle_environments/helpers.py +377 -0
  201. kaggle_environments/main.py +340 -0
  202. kaggle_environments/schemas.json +23 -18
  203. kaggle_environments/static/player.html +206 -74
  204. kaggle_environments/utils.py +46 -73
  205. kaggle_environments-1.20.0.dist-info/METADATA +25 -0
  206. kaggle_environments-1.20.0.dist-info/RECORD +211 -0
  207. {kaggle_environments-0.2.1.dist-info → kaggle_environments-1.20.0.dist-info}/WHEEL +1 -2
  208. kaggle_environments-1.20.0.dist-info/entry_points.txt +3 -0
  209. kaggle_environments/envs/battlegeese/battlegeese.py +0 -223
  210. kaggle_environments/temp.py +0 -14
  211. kaggle_environments-0.2.1.dist-info/METADATA +0 -393
  212. kaggle_environments-0.2.1.dist-info/RECORD +0 -32
  213. kaggle_environments-0.2.1.dist-info/entry_points.txt +0 -3
  214. kaggle_environments-0.2.1.dist-info/top_level.txt +0 -1
  215. {kaggle_environments-0.2.1.dist-info → kaggle_environments-1.20.0.dist-info/licenses}/LICENSE +0 -0
@@ -0,0 +1,464 @@
1
+ import functools
2
+
3
+ import chex
4
+ import flax
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+ from flax import struct
9
+
10
+ from luxai_s3.params import MAP_TYPES, EnvParams
11
+
12
# Tile type identifiers written into MapTile.tile_type by gen_map.
EMPTY_TILE = 0
NEBULA_TILE = 1
ASTEROID_TILE = 2


def _sine_energy_fn(d, x, y, z):
    """Energy function 0: ``sin(d * x + y) * z`` where ``d`` is the distance to the node."""
    return jnp.sin(d * x + y) * z


def _linear_energy_fn(d, x, y, z):
    """Energy function 1: ``(x / (d + 1) + y) * z`` where ``d`` is the distance to the node."""
    return (x / (d + 1) + y) * z


# Candidate energy node functions; presumably selected by the first number of
# each row of EnvState.energy_node_fns, with the remaining numbers as x, y, z.
ENERGY_NODE_FNS = [_sine_energy_fn, _linear_energy_fn]
17
+
18
+
19
@struct.dataclass
class UnitState:
    """State of every unit slot, stored as a flax pytree of arrays.

    gen_state allocates ``position`` with shape (T, N, 2) and ``energy`` with
    shape (T, N, 1), both int16, for T teams and N max units.
    """

    # Position of each unit; the last axis holds (x, y).
    position: chex.Array
    # Energy of each unit; the last axis has size 1.
    energy: chex.Array
25
+
26
+
27
@struct.dataclass
class MapTile:
    """Per-tile map features, stored as a flax pytree of 2D arrays.

    gen_map initialises both fields as int16 zero arrays with one entry per
    map tile and then fills in ``tile_type``.
    """

    # Energy of each tile, derived from the energy nodes and energy_node_fns.
    energy: chex.Array
    # Type of each tile: EMPTY_TILE, NEBULA_TILE or ASTEROID_TILE.
    tile_type: chex.Array
33
+
34
+
35
@struct.dataclass
class EnvState:
    """Full (unobscured) environment state, stored as a flax pytree.

    Shape notation: T teams, N max units or max nodes (per field), W map
    width, H map height, K relic configuration size.
    """

    # Units as a UnitState pytree: ``position`` (T, N, 2) holding (x, y) and
    # ``energy`` (T, N, 1), as allocated by gen_state.
    units: UnitState
    # Which unit slots are in use, shape (T, N).
    units_mask: chex.Array
    # Energy node positions, shape (N, 2) holding (x, y) per node.
    energy_nodes: chex.Array
    # Energy node functions used to compute the map's energy field, one row
    # per energy node slot (gen_map builds an (N, 4) array). The first number
    # selects the function and the remaining numbers parameterize it; the
    # function is applied to the distance of a map tile to the energy node
    # together with those parameters.
    energy_node_fns: chex.Array
    # Which energy nodes are active, shape (N,).
    energy_nodes_mask: chex.Array
    # Relic node positions, shape (N, 2) holding (x, y) per node.
    relic_nodes: chex.Array
    # Relic node configurations, shape (N, K, K); True entries mark tiles
    # around the node that can yield points.
    relic_node_configs: chex.Array
    # Which relic nodes are currently spawned/visible, shape (N,). gen_state
    # starts with all entries False.
    relic_nodes_mask: chex.Array
    # Point-yielding tiles, generated from the other relic fields. gen_state
    # allocates shape (W, H) and indexes it [x, y]. 0 means the tile yields no
    # points; otherwise the value is 1 + (relic index mod N // 2), so a relic
    # node and its mirrored twin share the same value.
    relic_nodes_map_weights: chex.Array
    # Game step at which each relic node spawns, shape (N,); -1 means never.
    relic_spawn_schedule: chex.Array
    # Map features as a MapTile pytree (per-tile energy and tile_type arrays).
    map_features: MapTile
    # Sensor mask, shape (T, H, W). Generated from other state.
    sensor_mask: chex.Array
    # Vision power map, shape (T, H, W). Generated from other state.
    vision_power_map: chex.Array
    # Points per team, shape (T,).
    team_points: chex.Array
    # Wins per team, shape (T,).
    team_wins: chex.Array
    # Total steps taken in the environment.
    steps: int = 0
    # Steps taken in the current match.
    match_steps: int = 0
95
+
96
+
97
@struct.dataclass
class EnvObs:
    """Partial observation of the environment state."""

    # Units as a UnitState pytree, nominally (T, N, ...) for T teams and N max
    # units; the last axes hold position (x, y) and energy.
    units: UnitState
    # Which unit slots are in use, shape (T, N).
    units_mask: chex.Array
    # NOTE(review): undocumented upstream; presumably the per-tile visibility
    # mask of the observing side — TODO confirm shape.
    sensor_mask: chex.Array
    # Map features (per-tile energy and tile type), see MapTile.
    map_features: MapTile
    # Positions of all relic nodes, shape (N, 2) holding (x, y); -1 where the
    # node is not visible.
    relic_nodes: chex.Array
    # Mask over all relic nodes, shape (N,).
    relic_nodes_mask: chex.Array
    # Points per team, shape (T,).
    team_points: chex.Array
    # Wins per team, shape (T,).
    team_wins: chex.Array
    # Total steps taken in the environment.
    steps: int = 0
    # Steps taken in the current match.
    match_steps: int = 0
126
+
127
+
128
def serialize_env_states(env_states: list[EnvState]):
    """Convert EnvState objects into JSON-friendly nested dicts and lists.

    Each state is turned into a state dict with
    ``flax.serialization.to_state_dict`` and then converted recursively:

    * the top-level fields ``sensor_mask``, ``relic_nodes_mask``,
      ``energy_nodes_mask``, ``energy_node_fns``, ``relic_nodes_map_weights``
      and ``relic_spawn_schedule`` are omitted;
    * ``relic_nodes``, ``relic_node_configs`` and ``energy_nodes`` keep only
      the rows selected by their masks;
    * jax and numpy arrays (and numpy scalars) become Python lists/scalars;
    * dict entries whose converted value is None are dropped.

    Args:
        env_states: states to serialize, in order.

    Returns:
        A list with one nested dict per state.
    """
    # Top-level fields left out of the serialized output.
    omitted_keys = (
        "sensor_mask",
        "relic_nodes_mask",
        "energy_nodes_mask",
        "energy_node_fns",
        "relic_nodes_map_weights",
        "relic_spawn_schedule",
    )

    def serialize_array(root: EnvState, arr, key_path: str = ""):
        if key_path in omitted_keys:
            return None
        # Node arrays keep only the rows selected by their masks.
        if key_path == "relic_nodes":
            return root.relic_nodes[root.relic_nodes_mask].tolist()
        if key_path == "relic_node_configs":
            return root.relic_node_configs[root.relic_nodes_mask].tolist()
        if key_path == "energy_nodes":
            return root.energy_nodes[root.energy_nodes_mask].tolist()
        # Accept numpy values as well as jax ones (e.g. states fetched to host),
        # matching serialize_env_actions; otherwise they would be returned
        # unconverted and break JSON encoding.
        if isinstance(arr, (jnp.ndarray, np.ndarray, np.generic)):
            return arr.tolist()
        if isinstance(arr, dict):
            ret = dict()
            for k, v in arr.items():
                new_key = key_path + "/" + k if key_path else k
                new_val = serialize_array(root, v, new_key)
                if new_val is not None:
                    ret[k] = new_val
            return ret
        return arr

    return [serialize_array(state, flax.serialization.to_state_dict(state)) for state in env_states]
163
+
164
+
165
def serialize_env_actions(env_actions: list):
    """Convert a list of per-step actions into plain Python structures.

    Each action is first turned into a state dict with
    ``flax.serialization.to_state_dict``. Numpy/jax arrays then become nested
    lists, dicts are converted recursively (entries whose converted value is
    None are dropped), and anything else is returned unchanged.
    """

    def to_builtin(value, path: str = ""):
        if isinstance(value, (np.ndarray, jnp.ndarray)):
            return value.tolist()
        if not isinstance(value, dict):
            return value
        converted = {}
        for name, child in value.items():
            child_value = to_builtin(child, path + "/" + name if path else name)
            if child_value is not None:
                converted[name] = child_value
        return converted

    return [to_builtin(flax.serialization.to_state_dict(action)) for action in env_actions]
188
+
189
+
190
def state_to_flat_obs(state: EnvState) -> chex.Array:
    """Flatten an EnvState into a single observation array.

    Not implemented. Raises instead of silently returning None, which would
    contradict the ``chex.Array`` return type and hide the missing feature.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("state_to_flat_obs is not implemented")
192
+
193
+
194
def flat_obs_to_state(flat_obs: chex.Array) -> EnvState:
    """Rebuild an EnvState from a flat observation array.

    Not implemented. Raises instead of silently returning None, which would
    contradict the ``EnvState`` return type and hide the missing feature.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("flat_obs_to_state is not implemented")
196
+
197
+
198
@functools.partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9))
def gen_state(
    key: chex.PRNGKey,
    env_params: EnvParams,
    max_units: int,
    num_teams: int,
    map_type: int,
    map_width: int,
    map_height: int,
    max_energy_nodes: int,
    max_relic_nodes: int,
    relic_config_size: int,
) -> EnvState:
    """Build the initial EnvState for a new game, generating the map via gen_map.

    Units (position/energy), units_mask, team_points, team_wins, sensor_mask
    and vision_power_map all start as zeros/False. Relic nodes are generated
    but start hidden (relic_nodes_mask all False). relic_nodes_map_weights is
    precomputed from the generated relic nodes, their configs and the
    generated spawn mask.

    Args:
        key: JAX PRNG key, passed to gen_map.
        env_params: environment parameters, passed to gen_map.
        max_units, num_teams, map_type, map_width, map_height,
        max_energy_nodes, max_relic_nodes, relic_config_size: static sizes and
            options (static_argnums 2-9).

    Returns:
        The freshly generated EnvState.
    """
    # NOTE(review): gen_map's positional order is (..., map_height, map_width, ...),
    # but map_width is passed in its map_height slot and vice versa. The map
    # arrays therefore come back shaped (map_width, map_height), which matches
    # the [x, y] indexing of relic_nodes_map_weights below. This only matters
    # for non-square maps — confirm it is intended.
    generated = gen_map(
        key, env_params, map_type, map_width, map_height, max_energy_nodes, max_relic_nodes, relic_config_size
    )
    # Indexed [x, y]; 0 means the tile yields no points.
    relic_nodes_map_weights = jnp.zeros(shape=(map_width, map_height), dtype=jnp.int16)

    # TODO (this could be optimized better)
    def update_relic_node(relic_nodes_map_weights, relic_data):
        # Stamp one relic node's KxK config onto the weight map, centred on the node.
        relic_node, relic_node_config, mask, relic_node_id = relic_data
        start_y = relic_node[1] - relic_config_size // 2
        start_x = relic_node[0] - relic_config_size // 2

        for dy in range(relic_config_size):
            for dx in range(relic_config_size):
                y, x = start_y + dy, start_x + dx
                valid_pos = jnp.logical_and(
                    jnp.logical_and(y >= 0, x >= 0),
                    jnp.logical_and(y < map_height, x < map_width),
                )
                # ensure we don't override previous spawns: tiles already holding a
                # weight in (0, relic_node_id + 1], i.e. claimed by a relic with the
                # same or a lower id, are kept.
                has_points = jnp.logical_and(relic_nodes_map_weights > 0, relic_nodes_map_weights <= relic_node_id + 1)
                # jnp.where is elementwise, and the updated map differs from the old
                # one only at [x, y], so at most that single tile changes here.
                relic_nodes_map_weights = jnp.where(
                    valid_pos & mask & jnp.logical_not(has_points) & relic_node_config[dx, dy],
                    relic_nodes_map_weights.at[x, y].set(
                        relic_node_config[dx, dy].astype(jnp.int16) * (relic_node_id + 1)
                    ),
                    relic_nodes_map_weights,
                )
        return relic_nodes_map_weights, None

    # this is really slow...

    # Scan over all relic node slots. Ids are taken modulo max_relic_nodes // 2,
    # so node i and its mirrored twin i + max_relic_nodes // 2 write the same weight.
    # The generated spawn mask is used here, not the all-False state mask.
    relic_nodes_map_weights, _ = jax.lax.scan(
        update_relic_node,
        relic_nodes_map_weights,
        (
            generated["relic_nodes"],
            generated["relic_node_configs"],
            generated["relic_nodes_mask"],
            jnp.arange(max_relic_nodes, dtype=jnp.int16) % (max_relic_nodes // 2),
        ),
    )

    state = EnvState(
        units=UnitState(
            position=jnp.zeros(shape=(num_teams, max_units, 2), dtype=jnp.int16),
            energy=jnp.zeros(shape=(num_teams, max_units, 1), dtype=jnp.int16),
        ),
        units_mask=jnp.zeros(shape=(num_teams, max_units), dtype=jnp.bool),
        team_points=jnp.zeros(shape=(num_teams), dtype=jnp.int32),
        team_wins=jnp.zeros(shape=(num_teams), dtype=jnp.int32),
        energy_nodes=generated["energy_nodes"],
        energy_node_fns=generated["energy_node_fns"],
        energy_nodes_mask=generated["energy_nodes_mask"],
        # energy_field=jnp.zeros(shape=(params.map_height, params.map_width), dtype=jnp.int16),
        relic_nodes=generated["relic_nodes"],
        relic_nodes_mask=jnp.zeros(
            shape=(max_relic_nodes), dtype=jnp.bool
        ),  # relic nodes are spawned in over time, so they all start invisible.
        relic_node_configs=generated["relic_node_configs"],
        relic_nodes_map_weights=relic_nodes_map_weights,
        relic_spawn_schedule=generated["relic_spawn_schedule"],
        sensor_mask=jnp.zeros(
            shape=(num_teams, map_height, map_width),
            dtype=jnp.bool,
        ),
        vision_power_map=jnp.zeros(shape=(num_teams, map_height, map_width), dtype=jnp.int16),
        map_features=generated["map_features"],
    )
    return state
280
+
281
+
282
@functools.partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7))
def gen_map(
    key: chex.PRNGKey,
    params: EnvParams,
    map_type: int,
    map_height: int,
    map_width: int,
    max_energy_nodes: int,
    max_relic_nodes: int,
    relic_config_size: int,
) -> dict:
    """Procedurally generate a mirror-symmetric map layout.

    Only the "random" map type is implemented. Tiles, relic nodes and energy
    nodes are generated for one half and mirrored across the anti-diagonal,
    so both sides of the map are equivalent.

    Args:
        key: JAX PRNG key driving all randomness.
        params: environment parameters; only ``max_steps_in_match`` is read.
        map_type: index into ``MAP_TYPES`` (static under jit).
        map_height, map_width: map dimensions (static). The mirroring below
            requires a square map.
        max_energy_nodes: number of energy node slots (static).
        max_relic_nodes: number of relic node slots (static); must be 6.
        relic_config_size: side length K of each KxK relic config (static).

    Returns:
        dict with keys ``map_features`` (MapTile), ``energy_nodes`` (N, 2),
        ``energy_node_fns``, ``relic_nodes`` (N, 2), ``energy_nodes_mask``,
        ``relic_nodes_mask``, ``relic_node_configs`` (N, K, K) and
        ``relic_spawn_schedule`` (N,; -1 means the node never spawns).

    Raises:
        NotImplementedError: if ``MAP_TYPES[map_type]`` is not "random".
        AssertionError: if ``max_relic_nodes`` is not 6.
    """
    # Only the "random" generator exists; fail with a clear error instead of an
    # UnboundLocalError at the return statement for any other map type.
    if MAP_TYPES[map_type] != "random":
        raise NotImplementedError(f"map generation is not implemented for map type {MAP_TYPES[map_type]!r}")

    map_features = MapTile(
        energy=jnp.zeros(shape=(map_height, map_width), dtype=jnp.int16),
        tile_type=jnp.zeros(shape=(map_height, map_width), dtype=jnp.int16),
    )
    energy_nodes_mask = jnp.zeros(shape=(max_energy_nodes), dtype=jnp.bool)

    ### Generate nebula tiles ###
    key, subkey = jax.random.split(key)
    perlin_noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4))
    noise = jnp.where(perlin_noise > 0.5, 1, 0)
    # Symmetrize about the main diagonal, then flip the rows. The result is
    # symmetric under (i, j) -> (n - 1 - j, n - 1 - i), the same anti-diagonal
    # mirroring applied to node positions below (square maps only).
    noise = noise | noise.T
    noise = noise[::-1, :]
    map_features = map_features.replace(tile_type=jnp.where(noise, NEBULA_TILE, 0))

    ### Generate asteroid tiles ###
    key, subkey = jax.random.split(key)
    perlin_noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (8, 8))
    noise = jnp.where(perlin_noise < -0.5, 1, 0)
    # Same anti-diagonal mirroring as for the nebula tiles.
    noise = noise | noise.T
    noise = noise[::-1, :]
    map_features = map_features.replace(
        tile_type=jnp.place(map_features.tile_type, noise, ASTEROID_TILE, inplace=False)
    )

    ### Generate relic nodes ###
    key, subkey = jax.random.split(key)
    noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4))
    # Positions of the highest noise values, one per first-half relic node.
    # -max_relic_nodes // 2 parses as (-max_relic_nodes) // 2, which equals
    # -(max_relic_nodes // 2) only for even counts.
    flat_indices = jnp.argsort(noise.ravel())[-max_relic_nodes // 2 :]
    highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape))

    # Relic nodes have a fixed density of 20% nearby tiles that can yield
    # points: only randint values 8 and 9 (out of 0..9) pass the >= 7.5 check.
    # NOTE(review): this draws from `key` rather than a fresh subkey, and `key`
    # is split again afterwards; JAX advises against reusing keys. Kept as-is
    # so existing seeds keep producing the same maps.
    relic_node_configs = (
        jax.random.randint(
            key,
            shape=(
                max_relic_nodes,
                relic_config_size,
                relic_config_size,
            ),
            minval=0,
            maxval=10,
        ).astype(jnp.float32)
        >= 7.5
    )
    highest_positions = highest_positions.astype(jnp.int16)
    # Mirror each position (p0, p1) to (map_width - 1 - p1, map_height - 1 - p0).
    mirrored_positions = jnp.stack(
        [map_width - highest_positions[:, 1] - 1, map_height - highest_positions[:, 0] - 1],
        dtype=jnp.int16,
        axis=-1,
    )
    relic_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0)

    key, subkey = jax.random.split(key)
    # NOTE(review): uses `key` rather than `subkey` (see the note above).
    num_spawned_relic_nodes = jax.random.randint(key, (1,), minval=1, maxval=(max_relic_nodes // 2) + 1)
    # Prefix mask: the first num_spawned_relic_nodes entries of each half are active.
    relic_nodes_mask_half = jnp.arange(max_relic_nodes // 2) < num_spawned_relic_nodes
    relic_nodes_mask = jnp.concat([relic_nodes_mask_half, relic_nodes_mask_half], axis=0)
    # Mirrored relic nodes get the anti-diagonal reflection of their twin's config.
    relic_node_configs = relic_node_configs.at[max_relic_nodes // 2 :].set(
        relic_node_configs[: max_relic_nodes // 2].transpose(0, 2, 1)[:, ::-1, ::-1]
    )

    ### Generate energy nodes ###
    key, subkey = jax.random.split(key)
    noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4))
    # Positions of the highest noise values (same parsing caveat as above).
    flat_indices = jnp.argsort(noise.ravel())[-max_energy_nodes // 2 :]
    highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape)).astype(jnp.int16)
    mirrored_positions = jnp.stack(
        [map_width - highest_positions[:, 1] - 1, map_height - highest_positions[:, 0] - 1],
        dtype=jnp.int16,
        axis=-1,
    )
    energy_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0)
    key, subkey = jax.random.split(key)
    # Each first-half energy node is randomly on or off, but node 0 is always on;
    # the mirrored half copies the same mask.
    # NOTE(review): uses `key` rather than `subkey`.
    energy_nodes_mask_half = jax.random.randint(key, (max_energy_nodes // 2,), minval=0, maxval=2).astype(jnp.bool)
    energy_nodes_mask_half = energy_nodes_mask_half.at[0].set(True)
    energy_nodes_mask = energy_nodes_mask.at[: max_energy_nodes // 2].set(energy_nodes_mask_half)
    energy_nodes_mask = energy_nodes_mask.at[max_energy_nodes // 2 :].set(energy_nodes_mask_half)

    # Six hardcoded [fn_id, params...] rows; fn_id presumably indexes ENERGY_NODE_FNS.
    energy_node_fns = jnp.array(
        [
            [0, 1.2, 1, 4],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 1.2, 1, 4],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
        ]
    )

    # generate a random relic spawn schedule
    # if number is -1, then relic node is never spawned, otherwise spawn at that game timestep
    assert max_relic_nodes == 6, "random map generation is hardcoded to use 6 relic nodes at most per map"
    key, subkey = jax.random.split(key)
    # First-half node i spawns at a random step in the first half of a match,
    # offset by i * (max_steps_in_match + 1) game steps (presumably match i).
    # NOTE(review): uses `key` rather than `subkey`.
    relic_spawn_schedule_half = jax.random.randint(
        key, (max_relic_nodes // 2,), minval=0, maxval=params.max_steps_in_match // 2
    ) + jnp.arange(max_relic_nodes // 2) * (params.max_steps_in_match + 1)
    relic_spawn_schedule = jnp.concat([relic_spawn_schedule_half, relic_spawn_schedule_half], axis=0)
    relic_spawn_schedule = jnp.where(relic_nodes_mask, relic_spawn_schedule, -1)

    return dict(
        map_features=map_features,
        energy_nodes=energy_nodes,
        energy_node_fns=energy_node_fns,
        relic_nodes=relic_nodes,
        energy_nodes_mask=energy_nodes_mask,
        relic_nodes_mask=relic_nodes_mask,
        relic_node_configs=relic_node_configs,
        relic_spawn_schedule=relic_spawn_schedule,
    )
412
+
413
+
414
def interpolant(t):
    """Quintic fade curve ``6t^5 - 15t^4 + 10t^3`` used for Perlin interpolation.

    Maps 0 -> 0 and 1 -> 1 with zero first and second derivatives at both ends.
    Works elementwise on scalars and arrays.
    """
    blend = t * (t * 6 - 15) + 10
    cubed = t * t * t
    return cubed * blend
416
+
417
+
418
@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def generate_perlin_noise_2d(key, shape, res, tileable=(False, False), interpolant=interpolant):
    """Generate a 2D array of Perlin noise.

    Args:
        key: JAX PRNG key used to draw the random gradient angles.
        shape: The shape of the generated array (tuple of two ints).
            This must be a multiple of res. Static under jit.
        res: The number of periods of noise to generate along each
            axis (tuple of two ints). Note shape must be a multiple of
            res. Static under jit.
        tileable: If the noise should be tileable along each axis
            (tuple of two bools). Defaults to (False, False). Static under jit.
        interpolant: The interpolation function, defaults to
            t*t*t*(t*(t*6 - 15) + 10). Static under jit.

    Returns:
        A JAX array of shape ``shape`` with the generated noise.

    Raises:
        ValueError: If shape is not a multiple of res.
    """
    # shape and res are static, so this check runs once at trace time.
    if shape[0] % res[0] != 0 or shape[1] % res[1] != 0:
        raise ValueError(f"shape {shape} must be a multiple of res {res}")
    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])
    # Fractional position of every output cell within its noise period.
    grid = jnp.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(1, 2, 0) % 1
    # Gradients: one random unit vector per lattice point.
    angles = 2 * jnp.pi * jax.random.uniform(key, (res[0] + 1, res[1] + 1))
    gradients = jnp.dstack((jnp.cos(angles), jnp.sin(angles)))
    # JAX arrays are immutable (item assignment raises TypeError), so make the
    # opposite edges share gradients with functional .at[...].set updates.
    if tileable[0]:
        gradients = gradients.at[-1, :].set(gradients[0, :])
    if tileable[1]:
        gradients = gradients.at[:, -1].set(gradients[:, 0])
    # Upsample so each lattice cell covers d output cells, then take the four
    # corner gradients of every output cell.
    gradients = gradients.repeat(d[0], 0).repeat(d[1], 1)
    g00 = gradients[: -d[0], : -d[1]]
    g10 = gradients[d[0] :, : -d[1]]
    g01 = gradients[: -d[0], d[1] :]
    g11 = gradients[d[0] :, d[1] :]

    # Ramps
    n00 = jnp.sum(jnp.dstack((grid[:, :, 0], grid[:, :, 1])) * g00, 2)
    n10 = jnp.sum(jnp.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2)
    n01 = jnp.sum(jnp.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2)
    n11 = jnp.sum(jnp.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2)
    # Interpolation
    t = interpolant(grid)
    n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10
    n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11
    return jnp.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1)
@@ -0,0 +1,12 @@
1
+ import numpy as np
2
+
3
+
4
def to_numpy(x):
    """Recursively convert ``x`` into NumPy data.

    Dicts are converted value-by-value, with keys kept as-is. An existing
    ``np.ndarray`` is returned unchanged (same object, no copy). Everything
    else, including lists, scalars and other array-likes, is passed through
    ``np.array``.
    """
    if isinstance(x, dict):
        return {name: to_numpy(item) for name, item in x.items()}
    if isinstance(x, np.ndarray):
        return x
    # Lists and any remaining array-likes/scalars share the same conversion.
    return np.array(x)
@@ -0,0 +1,156 @@
1
+ # TODO (stao): Add lux ai s3 env to gymnax api wrapper, which is the old gym api
2
+ import dataclasses
3
+ import json
4
+ import os
5
+ from typing import Any, SupportsFloat
6
+
7
+ import flax
8
+ import flax.serialization
9
+ import gymnasium as gym
10
+ import jax
11
+ import numpy as np
12
+
13
+ from luxai_s3.env import LuxAIS3Env
14
+ from luxai_s3.params import EnvParams, env_params_ranges
15
+ from luxai_s3.state import serialize_env_actions, serialize_env_states
16
+ from luxai_s3.utils import to_numpy
17
+
18
+
19
class LuxAIS3GymEnv(gym.Env):
    """Gymnasium-style wrapper around the JAX ``LuxAIS3Env``.

    The wrapper owns a PRNG key and splits it on every ``reset`` and ``step``.
    Each ``reset`` samples fresh game parameters from ``env_params_ranges``,
    unless ``options["params"]`` supplies them explicitly. When
    ``numpy_output`` is true, observations (and step reward/terminated/
    truncated) are converted with ``to_numpy``.
    """

    # Subset of the game parameters reported to agents as ``info["params"]`` on reset.
    _AGENT_VISIBLE_PARAMS = (
        "max_units",
        "match_count_per_episode",
        "max_steps_in_match",
        "map_height",
        "map_width",
        "num_teams",
        "unit_move_cost",
        "unit_sap_cost",
        "unit_sap_range",
        "unit_sensor_range",
    )

    def __init__(self, numpy_output: bool = False):
        self.numpy_output = numpy_output
        self.rng_key = jax.random.key(0)
        self.jax_env = LuxAIS3Env(auto_reset=False)
        self.env_params: EnvParams = EnvParams()

        # One action row per unit. Column 0 ranges over [0, 6]; columns 1-2 range
        # over [-unit_sap_range, unit_sap_range].
        unit_count = self.env_params.max_units
        sap_range = self.env_params.unit_sap_range
        low = np.zeros((unit_count, 3))
        low[:, 1:] = -sap_range
        high = np.full((unit_count, 3), 6.0)
        high[:, 1:] = sap_range
        self.action_space = gym.spaces.Dict(
            {team: gym.spaces.Box(low=low, high=high, dtype=np.int16) for team in ("player_0", "player_1")}
        )

    def render(self):
        """Render the current state with the underlying JAX env (requires a prior ``reset``)."""
        self.jax_env.render(self.state, self.env_params)

    def _sample_env_params(self) -> EnvParams:
        """Draw one value per entry of ``env_params_ranges``, advancing ``self.rng_key`` once per entry."""
        # TODO (stao): check why this keeps recompiling when marking structs as static args
        sampled = {}
        for name, choices in env_params_ranges.items():
            self.rng_key, subkey = jax.random.split(self.rng_key)
            sampled[name] = jax.random.choice(subkey, jax.numpy.array(choices)).item()
        return EnvParams(**sampled)

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[Any, dict[str, Any]]:
        """Start a new game.

        Returns ``(obs, info)``. ``info`` holds the agent-visible parameters
        (``params``), every parameter (``full_params``) and the raw env state
        (``state``).
        """
        if seed is not None:
            self.rng_key = jax.random.key(seed)
        self.rng_key, reset_key = jax.random.split(self.rng_key)

        # Random parameters are always drawn, so the key sequence does not depend on
        # options; explicitly supplied params then take precedence.
        sampled_params = self._sample_env_params()
        explicit = options is not None and "params" in options
        params = options["params"] if explicit else sampled_params

        self.env_params = params
        obs, self.state = self.jax_env.reset(reset_key, params=params)
        if self.numpy_output:
            obs = to_numpy(flax.serialization.to_state_dict(obs))

        full_params = dataclasses.asdict(params)
        visible_params = {name: full_params[name] for name in self._AGENT_VISIBLE_PARAMS}
        return obs, dict(params=visible_params, full_params=full_params, state=self.state)

    def step(self, action: Any) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
        """Advance the game by one step; ``info`` is returned exactly as the JAX env produced it."""
        self.rng_key, step_key = jax.random.split(self.rng_key)
        obs, self.state, reward, terminated, truncated, info = self.jax_env.step(
            step_key, self.state, action, self.env_params
        )
        if not self.numpy_output:
            return obs, reward, terminated, truncated, info
        return (
            to_numpy(flax.serialization.to_state_dict(obs)),
            to_numpy(reward),
            to_numpy(terminated),
            to_numpy(truncated),
            info,
        )
89
+
90
+
91
+ # TODO: vectorized gym wrapper
92
+
93
+
94
class RecordEpisode(gym.Wrapper):
    """Record the states and actions of a ``LuxAIS3GymEnv`` episode and dump them as JSON.

    When ``save_dir`` is set, a finished episode (one with at least one step)
    is written to ``<save_dir>/episode_<id>.json``. This happens on the next
    ``reset`` if ``save_on_reset`` is true, and on ``close`` if
    ``save_on_close`` is true.

    Without a ``save_dir``, nothing is written automatically. Recorded data
    stays in ``self.episode`` and can be written with ``save_episode``.
    """

    def __init__(
        self,
        env: LuxAIS3GymEnv,
        save_dir: str | None = None,
        save_on_close: bool = True,
        save_on_reset: bool = True,
    ):
        super().__init__(env)
        self.episode = dict(states=[], actions=[], metadata=dict())
        self.episode_id = 0
        self.save_dir = save_dir
        self.save_on_close = save_on_close
        self.save_on_reset = save_on_reset
        # Number of steps recorded into the current buffer since it was last cleared.
        self.episode_steps = 0
        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)

    def _can_autosave(self) -> bool:
        """Return True if there is a destination directory and recorded steps to write."""
        # Without this save_dir check, os.path.join(None, ...) raised TypeError on the
        # default configuration.
        return self.save_dir is not None and self.episode_steps > 0

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[Any, dict[str, Any]]:
        """Optionally save the previous episode, then reset the env and record the initial state."""
        if self.save_on_reset and self._can_autosave():
            self._save_episode_and_reset()
        obs, info = self.env.reset(seed=seed, options=options)

        self.episode["metadata"]["seed"] = seed
        self.episode["params"] = flax.serialization.to_state_dict(info["full_params"])
        self.episode["states"].append(info["state"])
        return obs, info

    def step(self, action: Any) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the env and record ``info["final_state"]`` together with the action taken."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.episode_steps += 1
        self.episode["states"].append(info["final_state"])
        self.episode["actions"].append(action)
        return obs, reward, terminated, truncated, info

    def serialize_episode_data(self, episode=None):
        """Build a JSON-serializable dict (observations, actions, metadata, params) for ``episode``.

        Defaults to the episode currently being recorded. Raises ``KeyError`` if
        ``episode`` has no ``params`` (i.e. ``reset`` was never called for it).
        """
        if episode is None:
            episode = self.episode
        ret = dict()
        ret["observations"] = serialize_env_states(episode["states"])
        if "actions" in episode:
            ret["actions"] = serialize_env_actions(episode["actions"])
        ret["metadata"] = episode["metadata"]
        ret["params"] = episode["params"]
        return ret

    def save_episode(self, save_path: str):
        """Write the current episode to ``save_path`` as JSON and start a fresh buffer."""
        episode = self.serialize_episode_data()
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(episode, f)
        self.episode = dict(states=[], actions=[], metadata=dict())
        # The buffer is now empty. Resetting the step count stops a later autosave from
        # serializing it (which would raise KeyError for the missing "params").
        self.episode_steps = 0

    def _save_episode_and_reset(self):
        """Save to ``<save_dir>/episode_<episode_id>.json`` and advance the episode counters."""
        self.save_episode(os.path.join(self.save_dir, f"episode_{self.episode_id}.json"))
        self.episode_id += 1
        self.episode_steps = 0

    def close(self):
        """Optionally save the in-progress episode, then close the wrapped env."""
        try:
            if self.save_on_close and self._can_autosave():
                self._save_episode_and_reset()
        finally:
            # gym.Wrapper.close closes the wrapped env; the original override skipped it.
            super().close()