kaggle-environments 0.2.1__py3-none-any.whl → 1.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaggle-environments might be problematic. Click here for more details.

Files changed (214) hide show
  1. kaggle_environments/__init__.py +49 -13
  2. kaggle_environments/agent.py +177 -124
  3. kaggle_environments/api.py +31 -0
  4. kaggle_environments/core.py +295 -170
  5. kaggle_environments/envs/cabt/cabt.js +164 -0
  6. kaggle_environments/envs/cabt/cabt.json +28 -0
  7. kaggle_environments/envs/cabt/cabt.py +186 -0
  8. kaggle_environments/envs/cabt/cg/__init__.py +0 -0
  9. kaggle_environments/envs/cabt/cg/cg.dll +0 -0
  10. kaggle_environments/envs/cabt/cg/game.py +75 -0
  11. kaggle_environments/envs/cabt/cg/libcg.so +0 -0
  12. kaggle_environments/envs/cabt/cg/sim.py +48 -0
  13. kaggle_environments/envs/cabt/test_cabt.py +120 -0
  14. kaggle_environments/envs/chess/chess.js +4289 -0
  15. kaggle_environments/envs/chess/chess.json +60 -0
  16. kaggle_environments/envs/chess/chess.py +4241 -0
  17. kaggle_environments/envs/chess/test_chess.py +60 -0
  18. kaggle_environments/envs/connectx/connectx.ipynb +3186 -0
  19. kaggle_environments/envs/connectx/connectx.js +1 -1
  20. kaggle_environments/envs/connectx/connectx.json +15 -1
  21. kaggle_environments/envs/connectx/connectx.py +6 -23
  22. kaggle_environments/envs/connectx/test_connectx.py +70 -24
  23. kaggle_environments/envs/football/football.ipynb +75 -0
  24. kaggle_environments/envs/football/football.json +91 -0
  25. kaggle_environments/envs/football/football.py +277 -0
  26. kaggle_environments/envs/football/helpers.py +95 -0
  27. kaggle_environments/envs/football/test_football.py +360 -0
  28. kaggle_environments/envs/halite/__init__.py +0 -0
  29. kaggle_environments/envs/halite/halite.ipynb +44741 -0
  30. kaggle_environments/envs/halite/halite.js +199 -83
  31. kaggle_environments/envs/halite/halite.json +31 -18
  32. kaggle_environments/envs/halite/halite.py +164 -303
  33. kaggle_environments/envs/halite/helpers.py +720 -0
  34. kaggle_environments/envs/halite/test_halite.py +190 -0
  35. kaggle_environments/envs/hungry_geese/__init__.py +0 -0
  36. kaggle_environments/envs/{battlegeese/battlegeese.js → hungry_geese/hungry_geese.js} +38 -22
  37. kaggle_environments/envs/{battlegeese/battlegeese.json → hungry_geese/hungry_geese.json} +21 -14
  38. kaggle_environments/envs/hungry_geese/hungry_geese.py +316 -0
  39. kaggle_environments/envs/hungry_geese/test_hungry_geese.py +0 -0
  40. kaggle_environments/envs/identity/identity.json +6 -5
  41. kaggle_environments/envs/identity/identity.py +15 -2
  42. kaggle_environments/envs/kore_fleets/__init__.py +0 -0
  43. kaggle_environments/envs/kore_fleets/helpers.py +1005 -0
  44. kaggle_environments/envs/kore_fleets/kore_fleets.ipynb +114 -0
  45. kaggle_environments/envs/kore_fleets/kore_fleets.js +658 -0
  46. kaggle_environments/envs/kore_fleets/kore_fleets.json +164 -0
  47. kaggle_environments/envs/kore_fleets/kore_fleets.py +555 -0
  48. kaggle_environments/envs/kore_fleets/starter_bots/java/Bot.java +54 -0
  49. kaggle_environments/envs/kore_fleets/starter_bots/java/README.md +26 -0
  50. kaggle_environments/envs/kore_fleets/starter_bots/java/jars/hamcrest-core-1.3.jar +0 -0
  51. kaggle_environments/envs/kore_fleets/starter_bots/java/jars/junit-4.13.2.jar +0 -0
  52. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Board.java +518 -0
  53. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Cell.java +61 -0
  54. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Configuration.java +24 -0
  55. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Direction.java +166 -0
  56. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Fleet.java +72 -0
  57. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/KoreJson.java +97 -0
  58. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Observation.java +72 -0
  59. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Pair.java +13 -0
  60. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Player.java +68 -0
  61. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Point.java +65 -0
  62. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Shipyard.java +70 -0
  63. kaggle_environments/envs/kore_fleets/starter_bots/java/kore/ShipyardAction.java +59 -0
  64. kaggle_environments/envs/kore_fleets/starter_bots/java/main.py +73 -0
  65. kaggle_environments/envs/kore_fleets/starter_bots/java/test/BoardTest.java +567 -0
  66. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ConfigurationTest.java +25 -0
  67. kaggle_environments/envs/kore_fleets/starter_bots/java/test/KoreJsonTest.java +62 -0
  68. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ObservationTest.java +46 -0
  69. kaggle_environments/envs/kore_fleets/starter_bots/java/test/PointTest.java +21 -0
  70. kaggle_environments/envs/kore_fleets/starter_bots/java/test/ShipyardTest.java +22 -0
  71. kaggle_environments/envs/kore_fleets/starter_bots/java/test/configuration.json +1 -0
  72. kaggle_environments/envs/kore_fleets/starter_bots/java/test/fullob.json +1 -0
  73. kaggle_environments/envs/kore_fleets/starter_bots/java/test/observation.json +1 -0
  74. kaggle_environments/envs/kore_fleets/starter_bots/python/__init__.py +0 -0
  75. kaggle_environments/envs/kore_fleets/starter_bots/python/main.py +27 -0
  76. kaggle_environments/envs/kore_fleets/starter_bots/ts/Bot.ts +34 -0
  77. kaggle_environments/envs/kore_fleets/starter_bots/ts/DoNothingBot.ts +12 -0
  78. kaggle_environments/envs/kore_fleets/starter_bots/ts/MinerBot.ts +62 -0
  79. kaggle_environments/envs/kore_fleets/starter_bots/ts/README.md +55 -0
  80. kaggle_environments/envs/kore_fleets/starter_bots/ts/interpreter.ts +402 -0
  81. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Board.ts +514 -0
  82. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Cell.ts +63 -0
  83. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Configuration.ts +25 -0
  84. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Direction.ts +169 -0
  85. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Fleet.ts +76 -0
  86. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/KoreIO.ts +70 -0
  87. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Observation.ts +45 -0
  88. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Pair.ts +11 -0
  89. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Player.ts +68 -0
  90. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Point.ts +65 -0
  91. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Shipyard.ts +72 -0
  92. kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/ShipyardAction.ts +58 -0
  93. kaggle_environments/envs/kore_fleets/starter_bots/ts/main.py +73 -0
  94. kaggle_environments/envs/kore_fleets/starter_bots/ts/miner.py +73 -0
  95. kaggle_environments/envs/kore_fleets/starter_bots/ts/package.json +23 -0
  96. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/BoardTest.ts +551 -0
  97. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ConfigurationTest.ts +16 -0
  98. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ObservationTest.ts +33 -0
  99. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/PointTest.ts +17 -0
  100. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ShipyardTest.ts +18 -0
  101. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/configuration.json +1 -0
  102. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/fullob.json +1 -0
  103. kaggle_environments/envs/kore_fleets/starter_bots/ts/test/observation.json +1 -0
  104. kaggle_environments/envs/kore_fleets/starter_bots/ts/tsconfig.json +22 -0
  105. kaggle_environments/envs/kore_fleets/test_kore_fleets.py +331 -0
  106. kaggle_environments/envs/lux_ai_2021/README.md +3 -0
  107. kaggle_environments/envs/lux_ai_2021/__init__.py +0 -0
  108. kaggle_environments/envs/lux_ai_2021/agents.py +11 -0
  109. kaggle_environments/envs/lux_ai_2021/dimensions/754.js +2 -0
  110. kaggle_environments/envs/lux_ai_2021/dimensions/754.js.LICENSE.txt +296 -0
  111. kaggle_environments/envs/lux_ai_2021/dimensions/main.js +1 -0
  112. kaggle_environments/envs/lux_ai_2021/index.html +43 -0
  113. kaggle_environments/envs/lux_ai_2021/lux_ai_2021.json +100 -0
  114. kaggle_environments/envs/lux_ai_2021/lux_ai_2021.py +231 -0
  115. kaggle_environments/envs/lux_ai_2021/test_agents/__init__.py +0 -0
  116. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_constants.js +6 -0
  117. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_constants.json +59 -0
  118. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_objects.js +145 -0
  119. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/io.js +14 -0
  120. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/kit.js +209 -0
  121. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/map.js +107 -0
  122. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/parser.js +79 -0
  123. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/main.js +88 -0
  124. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/main.py +75 -0
  125. kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/simple.tar.gz +0 -0
  126. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/__init__.py +0 -0
  127. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/annotate.py +20 -0
  128. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/constants.py +25 -0
  129. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game.py +86 -0
  130. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_constants.json +59 -0
  131. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_constants.py +7 -0
  132. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_map.py +106 -0
  133. kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_objects.py +154 -0
  134. kaggle_environments/envs/lux_ai_2021/test_agents/python/random_agent.py +38 -0
  135. kaggle_environments/envs/lux_ai_2021/test_agents/python/simple_agent.py +82 -0
  136. kaggle_environments/envs/lux_ai_2021/test_lux.py +19 -0
  137. kaggle_environments/envs/lux_ai_2021/testing.md +23 -0
  138. kaggle_environments/envs/lux_ai_2021/todo.md.og +18 -0
  139. kaggle_environments/envs/lux_ai_s3/README.md +21 -0
  140. kaggle_environments/envs/lux_ai_s3/agents.py +5 -0
  141. kaggle_environments/envs/lux_ai_s3/index.html +42 -0
  142. kaggle_environments/envs/lux_ai_s3/lux_ai_s3.json +47 -0
  143. kaggle_environments/envs/lux_ai_s3/lux_ai_s3.py +178 -0
  144. kaggle_environments/envs/lux_ai_s3/luxai_s3/__init__.py +1 -0
  145. kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py +819 -0
  146. kaggle_environments/envs/lux_ai_s3/luxai_s3/globals.py +9 -0
  147. kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py +101 -0
  148. kaggle_environments/envs/lux_ai_s3/luxai_s3/profiler.py +141 -0
  149. kaggle_environments/envs/lux_ai_s3/luxai_s3/pygame_render.py +222 -0
  150. kaggle_environments/envs/lux_ai_s3/luxai_s3/spaces.py +27 -0
  151. kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py +464 -0
  152. kaggle_environments/envs/lux_ai_s3/luxai_s3/utils.py +12 -0
  153. kaggle_environments/envs/lux_ai_s3/luxai_s3/wrappers.py +156 -0
  154. kaggle_environments/envs/lux_ai_s3/test_agents/python/agent.py +78 -0
  155. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/__init__.py +0 -0
  156. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/kit.py +31 -0
  157. kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/utils.py +17 -0
  158. kaggle_environments/envs/lux_ai_s3/test_agents/python/main.py +66 -0
  159. kaggle_environments/envs/lux_ai_s3/test_lux.py +9 -0
  160. kaggle_environments/envs/mab/__init__.py +0 -0
  161. kaggle_environments/envs/mab/agents.py +12 -0
  162. kaggle_environments/envs/mab/mab.js +100 -0
  163. kaggle_environments/envs/mab/mab.json +74 -0
  164. kaggle_environments/envs/mab/mab.py +146 -0
  165. kaggle_environments/envs/open_spiel/__init__.py +0 -0
  166. kaggle_environments/envs/open_spiel/games/__init__.py +0 -0
  167. kaggle_environments/envs/open_spiel/games/chess/chess.js +441 -0
  168. kaggle_environments/envs/open_spiel/games/chess/image_config.jsonl +20 -0
  169. kaggle_environments/envs/open_spiel/games/chess/openings.jsonl +20 -0
  170. kaggle_environments/envs/open_spiel/games/connect_four/__init__.py +0 -0
  171. kaggle_environments/envs/open_spiel/games/connect_four/connect_four.js +284 -0
  172. kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy.py +86 -0
  173. kaggle_environments/envs/open_spiel/games/go/__init__.py +0 -0
  174. kaggle_environments/envs/open_spiel/games/go/go.js +481 -0
  175. kaggle_environments/envs/open_spiel/games/go/go_proxy.py +99 -0
  176. kaggle_environments/envs/open_spiel/games/tic_tac_toe/__init__.py +0 -0
  177. kaggle_environments/envs/open_spiel/games/tic_tac_toe/tic_tac_toe.js +345 -0
  178. kaggle_environments/envs/open_spiel/games/tic_tac_toe/tic_tac_toe_proxy.py +98 -0
  179. kaggle_environments/envs/open_spiel/games/universal_poker/__init__.py +0 -0
  180. kaggle_environments/envs/open_spiel/games/universal_poker/universal_poker.js +431 -0
  181. kaggle_environments/envs/open_spiel/games/universal_poker/universal_poker_proxy.py +159 -0
  182. kaggle_environments/envs/open_spiel/html_playthrough_generator.py +31 -0
  183. kaggle_environments/envs/open_spiel/observation.py +128 -0
  184. kaggle_environments/envs/open_spiel/open_spiel.py +565 -0
  185. kaggle_environments/envs/open_spiel/proxy.py +138 -0
  186. kaggle_environments/envs/open_spiel/test_open_spiel.py +191 -0
  187. kaggle_environments/envs/rps/__init__.py +0 -0
  188. kaggle_environments/envs/rps/agents.py +84 -0
  189. kaggle_environments/envs/rps/helpers.py +25 -0
  190. kaggle_environments/envs/rps/rps.js +117 -0
  191. kaggle_environments/envs/rps/rps.json +63 -0
  192. kaggle_environments/envs/rps/rps.py +90 -0
  193. kaggle_environments/envs/rps/test_rps.py +110 -0
  194. kaggle_environments/envs/rps/utils.py +7 -0
  195. kaggle_environments/envs/tictactoe/test_tictactoe.py +43 -77
  196. kaggle_environments/envs/tictactoe/tictactoe.ipynb +1397 -0
  197. kaggle_environments/envs/tictactoe/tictactoe.json +10 -2
  198. kaggle_environments/envs/tictactoe/tictactoe.py +1 -1
  199. kaggle_environments/errors.py +2 -4
  200. kaggle_environments/helpers.py +377 -0
  201. kaggle_environments/main.py +340 -0
  202. kaggle_environments/schemas.json +23 -18
  203. kaggle_environments/static/player.html +206 -74
  204. kaggle_environments/utils.py +46 -73
  205. {kaggle_environments-0.2.1.dist-info → kaggle_environments-1.20.1.dist-info}/METADATA +36 -114
  206. kaggle_environments-1.20.1.dist-info/RECORD +211 -0
  207. {kaggle_environments-0.2.1.dist-info → kaggle_environments-1.20.1.dist-info}/WHEEL +1 -2
  208. kaggle_environments-1.20.1.dist-info/entry_points.txt +3 -0
  209. kaggle_environments/envs/battlegeese/battlegeese.py +0 -223
  210. kaggle_environments/temp.py +0 -14
  211. kaggle_environments-0.2.1.dist-info/RECORD +0 -32
  212. kaggle_environments-0.2.1.dist-info/entry_points.txt +0 -3
  213. kaggle_environments-0.2.1.dist-info/top_level.txt +0 -1
  214. {kaggle_environments-0.2.1.dist-info → kaggle_environments-1.20.1.dist-info/licenses}/LICENSE +0 -0
@@ -0,0 +1,9 @@
1
+ import os
2
+
3
# Whether coloured terminal output is enabled.
# Colours are on by default and are switched off only when the LUX_COLORS
# environment variable is set to the exact string "False".
# The earlier expression compared with `== "False"`, which inverted the switch
# (LUX_COLORS=False enabled colours), and it wrapped a lookup that cannot raise
# in a bare `except:`.
TERM_COLORS = os.environ.get("LUX_COLORS", "True") != "False"
@@ -0,0 +1,101 @@
1
+ from flax import struct
2
+
3
+ MAP_TYPES = ["dev0", "random"]
4
+
5
+
6
@struct.dataclass
class EnvParams:
    """Tunable parameters of the environment, declared as a flax struct dataclass.

    Every field has a default value. The bare strings that follow some fields
    are per-attribute notes; at runtime they are no-op expression statements.
    """

    max_steps_in_match: int = 100
    # NOTE(review): presumably an index into MAP_TYPES - confirm against the map generator.
    map_type: int = 1
    """Map generation algorithm. Can change between games"""
    map_width: int = 24
    map_height: int = 24
    num_teams: int = 2
    match_count_per_episode: int = 5
    """number of matches to play in one episode"""

    # configs for units
    max_units: int = 16
    init_unit_energy: int = 100
    min_unit_energy: int = 0
    max_unit_energy: int = 400
    unit_move_cost: int = 2
    spawn_rate: int = 3

    unit_sap_cost: int = 10
    """
    The unit sap cost is the amount of energy a unit uses when it saps another unit. Can change between games.
    """
    unit_sap_range: int = 4
    """
    The unit sap range is the range of the unit's sap action.
    """
    unit_sap_dropoff_factor: float = 0.5
    """
    The unit sap dropoff factor multiplied by unit_sap_drain
    """
    unit_energy_void_factor: float = 0.125
    """
    The unit energy void factor multiplied by unit_energy
    """

    # configs for energy nodes
    max_energy_nodes: int = 6
    max_energy_per_tile: int = 20
    min_energy_per_tile: int = -20

    max_relic_nodes: int = 6
    """max relic nodes in the entire map. This number should be tuned carefully as relic node spawning code is hardcoded against this number 6"""
    relic_config_size: int = 5
    fog_of_war: bool = True
    """
    whether there is fog of war or not
    """
    unit_sensor_range: int = 2
    """
    The unit sensor range is the range of the unit's sensor.
    Units provide "vision power" over tiles in range, equal to manhattan distance to the unit.

    vision power > 0 that team can see the tiles properties
    """

    # nebula tile params
    nebula_tile_vision_reduction: int = 1
    """
    The nebula tile vision reduction is the amount of vision reduction a nebula tile provides.
    A tile can be seen if the vision power over it is > 0.
    """

    nebula_tile_energy_reduction: int = 0
    """amount of energy nebula tiles reduce from a unit"""

    nebula_tile_drift_speed: float = -0.05
    """
    how fast nebula tiles drift in one of the diagonal directions over time. If positive, flows to the top/right, negative flows to bottom/left
    """
    # TODO (stao): allow other kinds of symmetric drifts?

    # Annotated as float: the default (0.02) and every candidate value in
    # env_params_ranges are fractional, so the previous `int` annotation was wrong.
    energy_node_drift_speed: float = 0.02
    """
    how fast energy nodes will move around over time
    """
    energy_node_drift_magnitude: int = 5

    # option to change sap configurations
85
+
86
+
87
# Candidate values for individual EnvParams fields, keyed by field name.
# NOTE(review): presumably one value per key is sampled when a game is
# randomized - confirm against the code that consumes this mapping.
env_params_ranges = {
    # "map_type": [1],
    "unit_move_cost": [*range(1, 6)],
    "unit_sensor_range": [1, 2, 3, 4],
    "nebula_tile_vision_reduction": [*range(0, 8)],
    "nebula_tile_energy_reduction": [0, 1, 2, 3, 5, 25],
    "unit_sap_cost": [*range(30, 51)],
    "unit_sap_range": [*range(3, 8)],
    "unit_sap_dropoff_factor": [0.25, 0.5, 1],
    "unit_energy_void_factor": [0.0625, 0.125, 0.25, 0.375],
    # map randomizations
    "nebula_tile_drift_speed": [-0.15, -0.1, -0.05, -0.025, 0.025, 0.05, 0.1, 0.15],
    "energy_node_drift_speed": [0.01, 0.02, 0.03, 0.04, 0.05],
    "energy_node_drift_magnitude": [*range(3, 6)],
}
@@ -0,0 +1,141 @@
1
+ import os
2
+ import time
3
+ from collections import defaultdict
4
+ from typing import Literal
5
+
6
+ import numpy as np
7
+ import psutil
8
+ import pynvml
9
+
10
+
11
def flatten_dict_keys(d: dict, prefix=""):
    """Return a one-level copy of ``d`` whose keys are "/"-joined paths.

    Nested dicts are expanded recursively; every other value is stored under
    ``prefix`` plus its key. A nested dict that is empty contributes no entry.
    """
    flat = {}
    for key, value in d.items():
        path = prefix + key
        if not isinstance(value, dict):
            flat[path] = value
            continue
        # Descend, using the current path as the prefix of the child keys.
        flat.update(flatten_dict_keys(value, path + "/"))
    return flat
20
+
21
+
22
class Profiler:
    """
    A simple class to help profile/benchmark simulator code

    Timing samples are appended to ``self.stats[name]`` by ``profile`` and
    summarised by ``log_stats``. GPU memory is read through NVML for the GPU
    at index 0, so constructing a Profiler requires a working NVML installation.
    """

    def __init__(self, output_format: Literal["stdout", "json"], synchronize_torch: bool = True) -> None:
        """Store the options and open an NVML handle for GPU 0.

        Args:
            output_format: ``log`` only prints when this is "stdout".
            synchronize_torch: stored on the instance; nothing in this class
                reads it at the moment.
        """
        self.output_format = output_format
        self.synchronize_torch = synchronize_torch
        self.stats = defaultdict(list)
        # Initialize NVML
        pynvml.nvmlInit()

        # Get handle for the first GPU (index 0)
        self.handle = pynvml.nvmlDeviceGetHandleByIndex(0)

        # Get the PID of the current process
        self.current_pid = os.getpid()

    def log(self, msg):
        """log a message to stdout"""
        if self.output_format == "stdout":
            print(msg)

    def update_csv(self, csv_path: str, data: dict):
        """Update a csv file with the given data (a dict representing a unique identifier of the result row)
        and stats. If the file does not exist, it will be created. The update will replace an existing row
        if the given data matches the data in the row. If there are multiple matches, only the first match
        will be replaced and the rest are deleted.

        When ``data`` is empty there is nothing to match on, so the stats are
        appended as a new row (previously this case raised AttributeError).
        """
        import pandas as pd

        if os.path.exists(csv_path):
            df = pd.read_csv(csv_path)
        else:
            df = pd.DataFrame()
        stats_flat = flatten_dict_keys(self.stats)

        # Make sure every stats column exists before rows are written.
        for k in stats_flat:
            if k not in df:
                df[k] = None

        # A row matches when every identifier column equals the requested
        # value; a requested value of None matches missing (NaN) cells.
        cond = None
        for k, wanted in data.items():
            if k not in df:
                df[k] = None
            mask = df[k].isna() if wanted is None else df[k] == wanted
            cond = mask if cond is None else cond & mask

        data_dict = {**data, **stats_flat}
        if cond is None or not cond.any():
            df = pd.concat([df, pd.DataFrame(data_dict, index=[len(df)])])
        else:
            matches = df.loc[cond].index
            # replace the first instance
            df.loc[matches[0]] = data_dict
            # delete other instances
            df.drop(matches[1:], inplace=True)
        df.to_csv(csv_path, index=False)

    def profile(self, function, name: str, total_steps: int, num_envs: int, trials=1):
        """Call ``function`` ``trials`` times and record one stats entry per call.

        Args:
            function: zero-argument callable to time.
            name: key under which the samples are stored in ``self.stats``.
            total_steps: number of steps one call of ``function`` performs.
            num_envs: number of parallel environments stepped per step.
            trials: how many times ``function`` is called.

        CPU and GPU memory are sampled once, before the first trial, and the
        same values are stored with every trial.
        """
        print(f"start recording {name} metrics")
        process = psutil.Process(os.getpid())
        cpu_mem_use = process.memory_info().rss
        gpu_mem_use = self.get_current_process_gpu_memory()
        if gpu_mem_use is None:
            gpu_mem_use = 0

        for _ in range(trials):
            # perf_counter is monotonic and high resolution, unlike time.time(),
            # which can jump with clock adjustments and has coarse resolution on
            # some platforms.
            stime = time.perf_counter()
            function()
            dt = time.perf_counter() - stime
            # dt: delta time (s)
            # fps: frames per second
            # psps: parallel steps per second (number of env.step calls per second)
            self.stats[name].append(
                dict(
                    dt=dt,
                    fps=total_steps * num_envs / dt,
                    psps=total_steps / dt,
                    total_steps=total_steps,
                    cpu_mem_use=cpu_mem_use,
                    gpu_mem_use=gpu_mem_use,
                )
            )

    def log_stats(self, name: str):
        """Log the mean (and, with more than one trial, the std) of the samples for ``name``.

        Does nothing when no samples have been recorded under ``name``.
        """
        stats = self.stats[name]
        more_than_one_trial = len(stats) > 1
        if len(stats) == 0:
            return
        # average the stats
        avg_stats = defaultdict(list)
        for data in stats:
            for k, v in data.items():
                avg_stats[k].append(v)
        stats = {k: {"avg": np.mean(v), "std": np.std(v) if len(v) > 1 else None} for k, v in avg_stats.items()}
        self.log(f"{name} ({len(self.stats[name])} trials)")
        self.log(
            f"AVG: {stats['fps']['avg']:0.3f} steps/s, {stats['psps']['avg']:0.3f} parallel steps/s, {stats['total_steps']['avg']} steps in {stats['dt']['avg']:0.3f}s"
        )
        if more_than_one_trial:
            self.log(
                f"STD: {stats['fps']['std']:0.3f} steps/s, {stats['psps']['std']:0.3f} parallel steps/s, {stats['total_steps']['std']} steps in {stats['dt']['std']:0.3f}s"
            )
        self.log(
            f"{' ' * 4}CPU mem: {stats['cpu_mem_use']['avg'] / (1024**2):0.3f} MB, GPU mem: {stats['gpu_mem_use']['avg'] / (1024**2):0.3f} MB"
        )

    def get_current_process_gpu_memory(self):
        """Return the GPU memory NVML reports for this process, or None if it is not listed."""
        # Get all processes running on the GPU
        processes = pynvml.nvmlDeviceGetComputeRunningProcesses(self.handle)

        # Iterate through the processes to find the current process
        for process in processes:
            if process.pid == self.current_pid:
                return process.usedGpuMemory
        return None
@@ -0,0 +1,222 @@
1
+ import numpy as np
2
+
3
+ from luxai_s3.params import EnvParams
4
+ from luxai_s3.state import ASTEROID_TILE, NEBULA_TILE, EnvState
5
+
6
+ try:
7
+ import pygame
8
+ except:
9
+ pass
10
+
11
+ TILE_SIZE = 64
12
+
13
+
14
class LuxAIPygameRenderer:
    """Interactive pygame viewer that draws one environment state per ``render`` call.

    Keyboard controls handled while rendering:
      * space - pause / resume (``render`` keeps redrawing until resumed)
      * r     - toggle the relic-spot overlay
      * s     - toggle the sensor-mask overlay
      * e     - toggle the energy-field overlay
    """

    def __init__(self):
        # The window, drawing surface, clock and display options are created
        # lazily by the first render() call.
        self.clock = None
        self.screen = None
        self.surface = None
        self.display_options = None

    def render(self, state: EnvState, params: EnvParams):
        """Render the environment.

        Creates the window on first use, draws ``state`` and processes pending
        keyboard events. Returns after one frame unless the viewer is paused,
        in which case it keeps redrawing at up to 60 ticks per second.
        """
        # Set up when pygame is not initialised OR when this instance has no
        # window yet. Checking only pygame.get_init() left self.screen unset if
        # another component had already initialised pygame.
        if not pygame.get_init() or getattr(self, "screen", None) is None:
            pygame.init()
            self.clock = pygame.time.Clock()
            # Set up the display
            screen_width = params.map_width * TILE_SIZE
            screen_height = params.map_height * TILE_SIZE
            self.screen = pygame.display.set_mode((screen_width, screen_height))
            self.surface = pygame.Surface(self.screen.get_size(), pygame.SRCALPHA)
            pygame.display.set_caption("Lux AI Season 3")

            # NOTE: "show_grid" and "show_vision_power_map" are stored but no
            # drawing code in this class reads them.
            self.display_options = {
                "show_grid": True,
                "show_relic_spots": False,
                "show_sensor_mask": True,
                "show_vision_power_map": True,
                "show_energy_field": False,
            }

        # Handle events to keep the window responsive
        toggles = {
            "r": "show_relic_spots",
            "s": "show_sensor_mask",
            "e": "show_energy_field",
        }
        render_state = "running"
        while True:
            self._update_display(state, params)
            for event in pygame.event.get():
                if event.type != pygame.TEXTINPUT:
                    continue
                if event.text == " ":
                    render_state = "paused" if render_state == "running" else "running"
                elif event.text in toggles:
                    option = toggles[event.text]
                    self.display_options[option] = not self.display_options[option]
            if render_state != "paused":
                break
            # Paused: throttle the redraw loop and keep polling for input.
            self.clock.tick(60)

    def _update_display(self, state: EnvState, params: EnvParams):
        """Redraw the full frame for ``state`` and flip the display."""

        def tile_rect(x, y):
            # Screen rectangle covering map tile (x, y).
            return pygame.Rect(x * TILE_SIZE, y * TILE_SIZE, TILE_SIZE, TILE_SIZE)

        def draw_rect_alpha(surface, color, rect):
            # Draw a translucent rectangle by blitting a per-pixel-alpha surface.
            shape_surf = pygame.Surface(pygame.Rect(rect).size, pygame.SRCALPHA)
            pygame.draw.rect(shape_surf, color, shape_surf.get_rect())
            surface.blit(shape_surf, rect)

        # Fill the screen with a background color
        self.screen.fill((200, 200, 200))
        self.surface.fill((200, 200, 200, 255))  # Light gray background

        # Draw the grid of tiles
        for x in range(params.map_width):
            for y in range(params.map_height):
                tile_type = state.map_features.tile_type[x, y]
                if tile_type == NEBULA_TILE:
                    color = (166, 177, 225, 255)  # Light blue (a6b1e1) for nebula tiles
                elif tile_type == ASTEROID_TILE:
                    color = (51, 56, 68, 255)
                else:
                    color = (255, 255, 255, 255)  # White for other tile types
                pygame.draw.rect(self.surface, color, tile_rect(x, y))  # Draw filled squares

        # Draw relic node configs if display option is enabled
        if self.display_options["show_relic_spots"]:
            mask = state.relic_nodes_map_weights
            for x in range(params.map_width):
                for y in range(params.map_height):
                    if mask[x, y] > 0:
                        draw_rect_alpha(self.surface, (255, 215, 0, 50), tile_rect(x, y))

        # Draw energy nodes
        for i in range(params.max_energy_nodes):
            if state.energy_nodes_mask[i]:
                x, y = state.energy_nodes[i, :2]
                center_x = (x + 0.5) * TILE_SIZE
                center_y = (y + 0.5) * TILE_SIZE
                radius = TILE_SIZE // 4  # Adjust this value to change the size of the circle
                pygame.draw.circle(
                    self.surface,
                    (0, 255, 0, 255),
                    (int(center_x), int(center_y)),
                    radius,
                )

        # Draw relic nodes
        for i in range(params.max_relic_nodes):
            if state.relic_nodes_mask[i]:
                x, y = state.relic_nodes[i, :2]
                rect_size = TILE_SIZE // 2  # Make the square smaller than the tile
                rect_x = x * TILE_SIZE + (TILE_SIZE - rect_size) // 2
                rect_y = y * TILE_SIZE + (TILE_SIZE - rect_size) // 2
                rect = pygame.Rect(rect_x, rect_y, rect_size, rect_size)
                pygame.draw.rect(self.surface, (173, 151, 32, 255), rect)  # Dark gold color

        # Draw sensor mask
        if self.display_options["show_sensor_mask"]:
            for team in range(params.num_teams):
                for x in range(params.map_width):
                    for y in range(params.map_height):
                        if state.sensor_mask[team, x, y]:
                            draw_rect_alpha(self.surface, (255, 0, 0, 25), tile_rect(x, y))

        # Draw the energy field: value as text, green tint for positive energy,
        # red tint otherwise, with opacity scaled by the configured tile limits.
        if self.display_options["show_energy_field"]:
            font = pygame.font.Font(None, 32)  # You may need to adjust the font size
            for x in range(params.map_width):
                for y in range(params.map_height):
                    energy_field_value = state.map_features.energy[x, y]
                    text = font.render(str(energy_field_value), True, (255, 255, 255))
                    text_rect = text.get_rect(center=((x + 0.5) * TILE_SIZE, (y + 0.5) * TILE_SIZE))
                    self.surface.blit(text, text_rect)
                    if energy_field_value > 0:
                        tint = (0, 255, 0, 255 * energy_field_value / params.max_energy_per_tile)
                    else:
                        tint = (255, 0, 0, 255 * energy_field_value / params.min_energy_per_tile)
                    draw_rect_alpha(self.surface, tint, tile_rect(x, y))

        # Draw units and count how many share each tile.
        # Iterates over params.num_teams (as the sensor-mask pass does) instead
        # of a hard-coded 2, so the two passes cannot disagree.
        unit_counts = {}
        for team in range(params.num_teams):
            color = (255, 0, 0, 255) if team == 0 else (0, 0, 255, 255)  # Red for team 0, Blue otherwise
            for i in range(params.max_units):
                if not state.units_mask[team, i]:
                    continue
                x, y = np.array(state.units.position[team, i])
                center_x = (x + 0.5) * TILE_SIZE
                center_y = (y + 0.5) * TILE_SIZE
                radius = TILE_SIZE // 3  # Adjust this value to change the size of the circle
                pygame.draw.circle(self.surface, color, (int(center_x), int(center_y)), radius)
                pos = (x, y)
                unit_counts[pos] = unit_counts.get(pos, 0) + 1

        # Draw unit counts on top of the unit circles
        font = pygame.font.Font(None, 32)  # You may need to adjust the font size
        for (x, y), count in unit_counts.items():
            text = font.render(str(count), True, (255, 255, 255))  # White text
            text_rect = text.get_rect(center=((x + 0.5) * TILE_SIZE, (y + 0.5) * TILE_SIZE))
            self.surface.blit(text, text_rect)

        # Draw the grid lines
        for x in range(params.map_width + 1):
            pygame.draw.line(
                self.surface,
                (100, 100, 100),
                (x * TILE_SIZE, 0),
                (x * TILE_SIZE, params.map_height * TILE_SIZE),
            )
        for y in range(params.map_height + 1):
            pygame.draw.line(
                self.surface,
                (100, 100, 100),
                (0, y * TILE_SIZE),
                (params.map_width * TILE_SIZE, y * TILE_SIZE),
            )

        self.screen.blit(self.surface, (0, 0))
        # Update the display
        pygame.display.flip()
@@ -0,0 +1,27 @@
1
+ import chex
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import numpy as np
5
+ from gymnax.environments.spaces import Space
6
+
7
+
8
class MultiDiscrete(Space):
    """Minimal jittable class for multi discrete gymnax spaces.

    Each component ``i`` takes integer values in ``[low[i], high[i])``.
    """

    def __init__(self, low: np.ndarray, high: np.ndarray):
        """Store the per-component bounds.

        Args:
            low: inclusive lower bounds.
            high: exclusive upper bounds; must have the same shape as ``low``.
        """
        # Validate before doing arithmetic so mismatched shapes are reported
        # here and are not silently broadcast by the subtraction below.
        assert low.shape == high.shape
        self.low = low
        self.high = high
        self.dist = self.high - self.low
        self.shape = low.shape
        self.dtype = jnp.int16

    def sample(self, rng: chex.PRNGKey) -> chex.Array:
        """Draw one sample of shape ``self.shape`` by scaling uniform noise into the bounds."""
        return (jax.random.uniform(rng, shape=self.shape, minval=0, maxval=1) * self.dist + self.low).astype(self.dtype)

    def contains(self, x) -> jnp.ndarray:
        """Check whether specific object is within space.

        Returns a scalar boolean array that is True when every component of
        ``x`` satisfies ``low <= x < high``. The previous body compared against
        ``self.n``, which this class never defines, so it always raised
        AttributeError.
        """
        range_cond = jnp.logical_and(jnp.all(x >= self.low), jnp.all(x < self.high))
        return range_cond