kaggle-environments 0.2.0__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaggle-environments might be problematic. Click here for more details.
- kaggle_environments/__init__.py +49 -13
- kaggle_environments/agent.py +177 -124
- kaggle_environments/api.py +31 -0
- kaggle_environments/core.py +298 -173
- kaggle_environments/envs/cabt/cabt.js +164 -0
- kaggle_environments/envs/cabt/cabt.json +28 -0
- kaggle_environments/envs/cabt/cabt.py +186 -0
- kaggle_environments/envs/cabt/cg/__init__.py +0 -0
- kaggle_environments/envs/cabt/cg/cg.dll +0 -0
- kaggle_environments/envs/cabt/cg/game.py +75 -0
- kaggle_environments/envs/cabt/cg/libcg.so +0 -0
- kaggle_environments/envs/cabt/cg/sim.py +48 -0
- kaggle_environments/envs/cabt/test_cabt.py +120 -0
- kaggle_environments/envs/chess/chess.js +4289 -0
- kaggle_environments/envs/chess/chess.json +60 -0
- kaggle_environments/envs/chess/chess.py +4241 -0
- kaggle_environments/envs/chess/test_chess.py +60 -0
- kaggle_environments/envs/connectx/connectx.ipynb +3186 -0
- kaggle_environments/envs/connectx/connectx.js +1 -1
- kaggle_environments/envs/connectx/connectx.json +15 -1
- kaggle_environments/envs/connectx/connectx.py +6 -23
- kaggle_environments/envs/connectx/test_connectx.py +70 -24
- kaggle_environments/envs/football/football.ipynb +75 -0
- kaggle_environments/envs/football/football.json +91 -0
- kaggle_environments/envs/football/football.py +277 -0
- kaggle_environments/envs/football/helpers.py +95 -0
- kaggle_environments/envs/football/test_football.py +360 -0
- kaggle_environments/envs/halite/__init__.py +0 -0
- kaggle_environments/envs/halite/halite.ipynb +44741 -0
- kaggle_environments/envs/halite/halite.js +199 -83
- kaggle_environments/envs/halite/halite.json +31 -18
- kaggle_environments/envs/halite/halite.py +164 -303
- kaggle_environments/envs/halite/helpers.py +720 -0
- kaggle_environments/envs/halite/test_halite.py +190 -0
- kaggle_environments/envs/hungry_geese/__init__.py +0 -0
- kaggle_environments/envs/{battlegeese/battlegeese.js → hungry_geese/hungry_geese.js} +38 -22
- kaggle_environments/envs/{battlegeese/battlegeese.json → hungry_geese/hungry_geese.json} +22 -15
- kaggle_environments/envs/hungry_geese/hungry_geese.py +316 -0
- kaggle_environments/envs/hungry_geese/test_hungry_geese.py +0 -0
- kaggle_environments/envs/identity/identity.json +6 -5
- kaggle_environments/envs/identity/identity.py +15 -2
- kaggle_environments/envs/kore_fleets/__init__.py +0 -0
- kaggle_environments/envs/kore_fleets/helpers.py +1005 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.ipynb +114 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.js +658 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.json +164 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.py +555 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/Bot.java +54 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/README.md +26 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/jars/hamcrest-core-1.3.jar +0 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/jars/junit-4.13.2.jar +0 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Board.java +518 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Cell.java +61 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Configuration.java +24 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Direction.java +166 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Fleet.java +72 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/KoreJson.java +97 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Observation.java +72 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Pair.java +13 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Player.java +68 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Point.java +65 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Shipyard.java +70 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/ShipyardAction.java +59 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/main.py +73 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/BoardTest.java +567 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ConfigurationTest.java +25 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/KoreJsonTest.java +62 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ObservationTest.java +46 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/PointTest.java +21 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ShipyardTest.java +22 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/configuration.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/fullob.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/observation.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/python/__init__.py +0 -0
- kaggle_environments/envs/kore_fleets/starter_bots/python/main.py +27 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/Bot.ts +34 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/DoNothingBot.ts +12 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/MinerBot.ts +62 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/README.md +55 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/interpreter.ts +402 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Board.ts +514 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Cell.ts +63 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Configuration.ts +25 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Direction.ts +169 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Fleet.ts +76 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/KoreIO.ts +70 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Observation.ts +45 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Pair.ts +11 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Player.ts +68 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Point.ts +65 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Shipyard.ts +72 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/ShipyardAction.ts +58 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/main.py +73 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/miner.py +73 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/package.json +23 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/BoardTest.ts +551 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ConfigurationTest.ts +16 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ObservationTest.ts +33 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/PointTest.ts +17 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ShipyardTest.ts +18 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/configuration.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/fullob.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/observation.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/tsconfig.json +22 -0
- kaggle_environments/envs/kore_fleets/test_kore_fleets.py +331 -0
- kaggle_environments/envs/lux_ai_2021/README.md +3 -0
- kaggle_environments/envs/lux_ai_2021/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_2021/agents.py +11 -0
- kaggle_environments/envs/lux_ai_2021/dimensions/754.js +2 -0
- kaggle_environments/envs/lux_ai_2021/dimensions/754.js.LICENSE.txt +296 -0
- kaggle_environments/envs/lux_ai_2021/dimensions/main.js +1 -0
- kaggle_environments/envs/lux_ai_2021/index.html +43 -0
- kaggle_environments/envs/lux_ai_2021/lux_ai_2021.json +100 -0
- kaggle_environments/envs/lux_ai_2021/lux_ai_2021.py +231 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_constants.js +6 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_constants.json +59 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_objects.js +145 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/io.js +14 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/kit.js +209 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/map.js +107 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/parser.js +79 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/main.js +88 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/main.py +75 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/simple.tar.gz +0 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/annotate.py +20 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/constants.py +25 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game.py +86 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_constants.json +59 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_constants.py +7 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_map.py +106 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_objects.py +154 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/random_agent.py +38 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/simple_agent.py +82 -0
- kaggle_environments/envs/lux_ai_2021/test_lux.py +19 -0
- kaggle_environments/envs/lux_ai_2021/testing.md +23 -0
- kaggle_environments/envs/lux_ai_2021/todo.md.og +18 -0
- kaggle_environments/envs/lux_ai_s3/README.md +21 -0
- kaggle_environments/envs/lux_ai_s3/agents.py +5 -0
- kaggle_environments/envs/lux_ai_s3/index.html +42 -0
- kaggle_environments/envs/lux_ai_s3/lux_ai_s3.json +47 -0
- kaggle_environments/envs/lux_ai_s3/lux_ai_s3.py +178 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/__init__.py +1 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py +819 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/globals.py +9 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py +101 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/profiler.py +141 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/pygame_render.py +222 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/spaces.py +27 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py +464 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/utils.py +12 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/wrappers.py +156 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/agent.py +78 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/kit.py +31 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/utils.py +17 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/main.py +66 -0
- kaggle_environments/envs/lux_ai_s3/test_lux.py +9 -0
- kaggle_environments/envs/mab/__init__.py +0 -0
- kaggle_environments/envs/mab/agents.py +12 -0
- kaggle_environments/envs/mab/mab.js +100 -0
- kaggle_environments/envs/mab/mab.json +74 -0
- kaggle_environments/envs/mab/mab.py +146 -0
- kaggle_environments/envs/open_spiel/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/chess/chess.js +441 -0
- kaggle_environments/envs/open_spiel/games/chess/image_config.jsonl +20 -0
- kaggle_environments/envs/open_spiel/games/chess/openings.jsonl +20 -0
- kaggle_environments/envs/open_spiel/games/connect_four/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four.js +284 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy.py +86 -0
- kaggle_environments/envs/open_spiel/games/go/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/go/go.js +481 -0
- kaggle_environments/envs/open_spiel/games/go/go_proxy.py +99 -0
- kaggle_environments/envs/open_spiel/games/tic_tac_toe/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/tic_tac_toe/tic_tac_toe.js +345 -0
- kaggle_environments/envs/open_spiel/games/tic_tac_toe/tic_tac_toe_proxy.py +98 -0
- kaggle_environments/envs/open_spiel/games/universal_poker/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/universal_poker/universal_poker.js +431 -0
- kaggle_environments/envs/open_spiel/games/universal_poker/universal_poker_proxy.py +159 -0
- kaggle_environments/envs/open_spiel/html_playthrough_generator.py +31 -0
- kaggle_environments/envs/open_spiel/observation.py +128 -0
- kaggle_environments/envs/open_spiel/open_spiel.py +565 -0
- kaggle_environments/envs/open_spiel/proxy.py +138 -0
- kaggle_environments/envs/open_spiel/test_open_spiel.py +191 -0
- kaggle_environments/envs/rps/__init__.py +0 -0
- kaggle_environments/envs/rps/agents.py +84 -0
- kaggle_environments/envs/rps/helpers.py +25 -0
- kaggle_environments/envs/rps/rps.js +117 -0
- kaggle_environments/envs/rps/rps.json +63 -0
- kaggle_environments/envs/rps/rps.py +90 -0
- kaggle_environments/envs/rps/test_rps.py +110 -0
- kaggle_environments/envs/rps/utils.py +7 -0
- kaggle_environments/envs/tictactoe/test_tictactoe.py +43 -77
- kaggle_environments/envs/tictactoe/tictactoe.ipynb +1397 -0
- kaggle_environments/envs/tictactoe/tictactoe.json +10 -2
- kaggle_environments/envs/tictactoe/tictactoe.py +1 -1
- kaggle_environments/errors.py +2 -4
- kaggle_environments/helpers.py +377 -0
- kaggle_environments/main.py +214 -50
- kaggle_environments/schemas.json +23 -18
- kaggle_environments/static/player.html +206 -74
- kaggle_environments/utils.py +46 -73
- kaggle_environments-1.20.0.dist-info/METADATA +25 -0
- kaggle_environments-1.20.0.dist-info/RECORD +211 -0
- {kaggle_environments-0.2.0.dist-info → kaggle_environments-1.20.0.dist-info}/WHEEL +1 -2
- kaggle_environments-1.20.0.dist-info/entry_points.txt +3 -0
- kaggle_environments/envs/battlegeese/battlegeese.py +0 -219
- kaggle_environments/temp.py +0 -14
- kaggle_environments-0.2.0.dist-info/METADATA +0 -393
- kaggle_environments-0.2.0.dist-info/RECORD +0 -33
- kaggle_environments-0.2.0.dist-info/entry_points.txt +0 -3
- kaggle_environments-0.2.0.dist-info/top_level.txt +0 -1
- {kaggle_environments-0.2.0.dist-info → kaggle_environments-1.20.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
from flax import struct
|
|
2
|
+
|
|
3
|
+
# Names of the available map generation algorithms.
# NOTE(review): EnvParams.map_type is an int (default 1); presumably it indexes
# into this list (1 -> "random") — confirm against the env code.
MAP_TYPES = ["dev0", "random"]
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@struct.dataclass
class EnvParams:
    """Parameters of a Lux AI Season 3 environment.

    Declared as a flax ``struct.dataclass`` (a frozen dataclass usable as a JAX
    pytree). Candidate values for fields that can change between games are
    listed in ``env_params_ranges``.
    """

    max_steps_in_match: int = 100
    map_type: int = 1
    """Map generation algorithm. Can change between games."""
    map_width: int = 24
    map_height: int = 24
    num_teams: int = 2
    match_count_per_episode: int = 5
    """Number of matches to play in one episode."""

    # configs for units
    max_units: int = 16
    init_unit_energy: int = 100
    min_unit_energy: int = 0
    max_unit_energy: int = 400
    unit_move_cost: int = 2
    spawn_rate: int = 3

    unit_sap_cost: int = 10
    """
    The amount of energy a unit uses when it saps another unit. Can change between games.
    """
    unit_sap_range: int = 4
    """
    The range of the unit's sap action.
    """
    unit_sap_dropoff_factor: float = 0.5
    """
    The unit sap dropoff factor, multiplied by the sap energy drain
    (presumably unit_sap_cost; there is no unit_sap_drain parameter).
    """
    unit_energy_void_factor: float = 0.125
    """
    The unit energy void factor, multiplied by unit energy.
    """

    # configs for energy nodes
    max_energy_nodes: int = 6
    max_energy_per_tile: int = 20
    min_energy_per_tile: int = -20

    max_relic_nodes: int = 6
    """Max relic nodes in the entire map. This number should be tuned carefully, as the relic node spawning code is hardcoded against this number (6)."""
    relic_config_size: int = 5
    fog_of_war: bool = True
    """
    Whether there is fog of war or not.
    """
    unit_sensor_range: int = 2
    """
    The range of the unit's sensor.
    Units provide "vision power" over tiles in range, equal to manhattan distance to the unit.

    A team can see a tile's properties wherever its vision power over that tile is > 0.
    """

    # nebula tile params
    nebula_tile_vision_reduction: int = 1
    """
    The amount of vision reduction a nebula tile provides.
    A tile can be seen if the vision power over it is > 0.
    """

    nebula_tile_energy_reduction: int = 0
    """Amount of energy nebula tiles reduce from a unit."""

    nebula_tile_drift_speed: float = -0.05
    """
    How fast nebula tiles drift in one of the diagonal directions over time. If positive, they flow to the top/right; if negative, to the bottom/left.
    """
    # TODO (stao): allow other kinds of symmetric drifts?

    # float, not int: the default and every candidate in env_params_ranges are fractional.
    energy_node_drift_speed: float = 0.02
    """
    How fast energy nodes will move around over time.
    """
    energy_node_drift_magnitude: int = 5
|
|
83
|
+
|
|
84
|
+
# option to change sap configurations
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# Candidate values for EnvParams fields that may vary between games
# (presumably one value per field is sampled when a game is set up).
# Key order is kept stable in case consumers iterate over it.
env_params_ranges = {
    # "map_type": [1],
    "unit_move_cost": [1, 2, 3, 4, 5],
    "unit_sensor_range": [1, 2, 3, 4],
    "nebula_tile_vision_reduction": [0, 1, 2, 3, 4, 5, 6, 7],
    "nebula_tile_energy_reduction": [0, 1, 2, 3, 5, 25],
    "unit_sap_cost": list(range(30, 51)),
    "unit_sap_range": [3, 4, 5, 6, 7],
    "unit_sap_dropoff_factor": [0.25, 0.5, 1],
    "unit_energy_void_factor": [0.0625, 0.125, 0.25, 0.375],
    # map randomizations
    "nebula_tile_drift_speed": [-0.15, -0.1, -0.05, -0.025, 0.025, 0.05, 0.1, 0.15],
    "energy_node_drift_speed": [0.01, 0.02, 0.03, 0.04, 0.05],
    "energy_node_drift_magnitude": [3, 4, 5],
}
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import time
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import psutil
|
|
8
|
+
import pynvml
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def flatten_dict_keys(d: dict, prefix=""):
    """Flatten a nested dict into a single-level dict with "/"-joined keys.

    Nested dicts are expanded recursively; each leaf value is stored under the
    path of keys leading to it, e.g. ``{"a": {"b": 1}}`` -> ``{"a/b": 1}``.
    Non-string keys are formatted with ``str`` so they can be joined.

    Args:
        d: The (possibly nested) dict to flatten.
        prefix: String prepended to every output key.

    Returns:
        A new flat dict mapping key paths to leaf values.
    """
    out = {}
    for k, v in d.items():
        # f-string instead of `prefix + k` so int/tuple keys don't raise TypeError
        key = f"{prefix}{k}"
        if isinstance(v, dict):
            out.update(flatten_dict_keys(v, key + "/"))
        else:
            out[key] = v
    return out
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Profiler:
    """
    A simple class to help profile/benchmark simulator code.

    Per-trial results are accumulated in ``self.stats`` as
    ``{name: [trial_dict, ...]}``. They can be printed with ``log_stats`` or
    persisted with ``update_csv``. GPU memory is read through NVML for GPU
    index 0, so constructing a Profiler requires a working NVIDIA driver.
    """

    def __init__(self, output_format: Literal["stdout", "json"], synchronize_torch: bool = True) -> None:
        """
        Args:
            output_format: "stdout" makes ``log`` print; any other value (e.g. "json") silences it.
            synchronize_torch: Stored on the instance. NOTE(review): not read anywhere in this class.
        """
        self.output_format = output_format
        self.synchronize_torch = synchronize_torch
        self.stats = defaultdict(list)
        # Initialize NVML
        pynvml.nvmlInit()

        # Get handle for the first GPU (index 0)
        self.handle = pynvml.nvmlDeviceGetHandleByIndex(0)

        # PID of this process, used to find it among the GPU's compute processes
        self.current_pid = os.getpid()

    def log(self, msg):
        """Print ``msg`` to stdout when ``output_format`` is "stdout"; otherwise do nothing."""
        if self.output_format == "stdout":
            print(msg)

    def update_csv(self, csv_path: str, data: dict):
        """Upsert one result row into a CSV file.

        ``data`` is a dict of identifier columns for the row; it is combined
        with the flattened ``self.stats``. If the file does not exist it is
        created. If existing rows match every identifier in ``data`` (a ``None``
        value matches empty cells), the first match is replaced and the other
        matches are deleted; otherwise the row is appended. An empty ``data``
        matches nothing, so the row is always appended.

        NOTE(review): values of the flattened stats are lists of per-trial
        dicts; building a one-row DataFrame from them presumably only works
        when exactly one trial was recorded per name — confirm intended usage.
        """
        import pandas as pd

        df = pd.read_csv(csv_path) if os.path.exists(csv_path) else pd.DataFrame()
        stats_flat = flatten_dict_keys(self.stats)

        for k in stats_flat:
            if k not in df:
                df[k] = None

        # Boolean row mask: rows whose identifier columns all equal `data`.
        cond = None
        for k in data:
            if k not in df:
                df[k] = None
            mask = df[k].isna() if data[k] is None else df[k] == data[k]
            cond = mask if cond is None else cond & mask

        data_dict = {**data, **stats_flat}
        # `cond is None` means no identifiers were given: nothing to match, so append.
        if cond is None or not cond.any():
            df = pd.concat([df, pd.DataFrame(data_dict, index=[len(df)])])
        else:
            matches = df.loc[cond].index
            # replace the first instance
            df.loc[matches[0]] = data_dict
            # delete other instances
            df.drop(matches[1:], inplace=True)
        df.to_csv(csv_path, index=False)

    def profile(self, function, name: str, total_steps: int, num_envs: int, trials=1):
        """Run ``function()`` ``trials`` times and append timing stats to ``self.stats[name]``.

        Args:
            function: Zero-argument callable to benchmark.
            name: Key under which each trial's stats dict is appended.
            total_steps: Number of env.step calls ``function`` performs per run.
            num_envs: Number of environments stepped in parallel (scales ``fps``).
            trials: Number of timed runs.
        """
        print(f"start recording {name} metrics")
        process = psutil.Process(os.getpid())
        # Memory is sampled once, before any trial runs, and shared by all trials.
        cpu_mem_use = process.memory_info().rss
        gpu_mem_use = self.get_current_process_gpu_memory()
        if gpu_mem_use is None:
            gpu_mem_use = 0

        for _ in range(trials):
            # perf_counter is monotonic and high-resolution, unlike time.time()
            stime = time.perf_counter()
            function()
            dt = time.perf_counter() - stime
            # dt: delta time (s)
            # fps: frames per second
            # psps: parallel steps per second (number of env.step calls per second)
            self.stats[name].append(
                dict(
                    dt=dt,
                    fps=total_steps * num_envs / dt,
                    psps=total_steps / dt,
                    total_steps=total_steps,
                    cpu_mem_use=cpu_mem_use,
                    gpu_mem_use=gpu_mem_use,
                )
            )
            # NOTE(review): asynchronous GPU work is not synchronized before
            # reading the clock (synchronize_torch is unused).

    def log_stats(self, name: str):
        """Log the mean (and, for more than one trial, the std) of the stats recorded under ``name``.

        Does nothing if no trials were recorded for ``name``.
        """
        stats = self.stats[name]
        if len(stats) == 0:
            return
        more_than_one_trial = len(stats) > 1
        # Gather each metric across trials, then reduce to mean / std.
        per_key = defaultdict(list)
        for data in stats:
            for k, v in data.items():
                per_key[k].append(v)
        stats = {k: {"avg": np.mean(v), "std": np.std(v) if len(v) > 1 else None} for k, v in per_key.items()}
        self.log(f"{name} ({len(self.stats[name])} trials)")
        self.log(
            f"AVG: {stats['fps']['avg']:0.3f} steps/s, {stats['psps']['avg']:0.3f} parallel steps/s, {stats['total_steps']['avg']} steps in {stats['dt']['avg']:0.3f}s"
        )
        if more_than_one_trial:
            self.log(
                f"STD: {stats['fps']['std']:0.3f} steps/s, {stats['psps']['std']:0.3f} parallel steps/s, {stats['total_steps']['std']} steps in {stats['dt']['std']:0.3f}s"
            )
        self.log(
            f"{' ' * 4}CPU mem: {stats['cpu_mem_use']['avg'] / (1024**2):0.3f} MB, GPU mem: {stats['gpu_mem_use']['avg'] / (1024**2):0.3f} MB"
        )

    def get_current_process_gpu_memory(self):
        """Return the GPU memory (bytes, as reported by NVML) used by this process on GPU 0.

        Returns None when this process is not among GPU 0's compute processes.
        NVML may also report None for ``usedGpuMemory`` when it is unavailable.
        """
        for process in pynvml.nvmlDeviceGetComputeRunningProcesses(self.handle):
            if process.pid == self.current_pid:
                return process.usedGpuMemory
        return None
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from luxai_s3.params import EnvParams
|
|
4
|
+
from luxai_s3.state import ASTEROID_TILE, NEBULA_TILE, EnvState
|
|
5
|
+
|
|
6
|
+
try:
    import pygame
except ImportError:
    # pygame is optional: it is only used by LuxAIPygameRenderer's drawing methods.
    # Catch ImportError only, so KeyboardInterrupt/SystemExit and real errors surface.
    pass
|
|
10
|
+
|
|
11
|
+
# Side length, in pixels, of one map tile in the pygame window
# (the window is map_width * TILE_SIZE by map_height * TILE_SIZE pixels).
TILE_SIZE = 64
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class LuxAIPygameRenderer:
    """Interactive pygame viewer for Lux AI Season 3 environment states.

    Keyboard controls while the window has focus:
        space -- pause / resume (while paused, ``render`` keeps redrawing and does not return)
        r     -- toggle the relic-node spot overlay
        s     -- toggle the team sensor-mask overlay
        e     -- toggle the energy-field overlay
    """

    def __init__(self):
        # Display resources are created lazily on the first ``render`` call.
        self.screen = None
        self.surface = None
        self.clock = None
        self.font = None

    def render(self, state: EnvState, params: EnvParams):
        """Render the environment.

        Draws ``state`` once and returns, unless the viewer gets paused (space),
        in which case it keeps redrawing at up to 60 FPS until unpaused.

        Args:
            state: Environment state to draw.
            params: Environment parameters (map size, entity limits, energy bounds).
        """
        # Set up the window if pygame is not initialized, or if pygame was
        # initialized elsewhere before this renderer created its own display.
        if not pygame.get_init() or getattr(self, "screen", None) is None:
            self._setup_display(params)

        # Handle events to keep the window responsive
        render_state = "running"
        while True:
            self._update_display(state, params)
            for event in pygame.event.get():
                if event.type != pygame.TEXTINPUT:
                    continue
                if event.text == " ":
                    render_state = "paused" if render_state == "running" else "running"
                elif event.text == "r":
                    self._toggle("show_relic_spots")
                elif event.text == "s":
                    self._toggle("show_sensor_mask")
                elif event.text == "e":
                    self._toggle("show_energy_field")
            if render_state != "paused":
                break
            self.clock.tick(60)

    def _setup_display(self, params: EnvParams):
        """Initialize pygame and create the window, drawing surface, clock, font and overlay options."""
        pygame.init()
        self.clock = pygame.time.Clock()
        screen_width = params.map_width * TILE_SIZE
        screen_height = params.map_height * TILE_SIZE
        self.screen = pygame.display.set_mode((screen_width, screen_height))
        self.surface = pygame.Surface(self.screen.get_size(), pygame.SRCALPHA)
        pygame.display.set_caption("Lux AI Season 3")
        # Created once here instead of reloading the default font on every frame.
        self.font = pygame.font.Font(None, 32)

        self.display_options = {
            "show_grid": True,
            "show_relic_spots": False,
            "show_sensor_mask": True,
            "show_vision_power_map": True,  # NOTE(review): not drawn yet
            "show_energy_field": False,
        }

    def _toggle(self, option: str):
        """Flip the boolean display option ``option``."""
        self.display_options[option] = not self.display_options[option]

    @staticmethod
    def _tile_rect(x, y):
        """Return the screen rectangle covering map tile (x, y)."""
        return pygame.Rect(x * TILE_SIZE, y * TILE_SIZE, TILE_SIZE, TILE_SIZE)

    @staticmethod
    def _draw_rect_alpha(surface, color, rect):
        """Draw a filled, alpha-blended rectangle onto ``surface``."""
        shape_surf = pygame.Surface(pygame.Rect(rect).size, pygame.SRCALPHA)
        pygame.draw.rect(shape_surf, color, shape_surf.get_rect())
        surface.blit(shape_surf, rect)

    @staticmethod
    def _alpha(value, limit) -> int:
        """Scale ``value / limit`` to an integer alpha in [0, 255] (0 when ``limit`` is 0)."""
        if limit == 0:
            return 0
        # pygame colors need integer components in 0..255
        return int(min(255.0, max(0.0, 255.0 * float(value) / limit)))

    def _update_display(self, state: EnvState, params: EnvParams):
        """Redraw every layer of ``state`` onto the window and flip the display."""
        # Fill the screen with a background color
        self.screen.fill((200, 200, 200))
        self.surface.fill((200, 200, 200, 255))  # Light gray background

        # Tiles: nebula light blue (#a6b1e1), asteroid dark slate, everything else white
        for x in range(params.map_width):
            for y in range(params.map_height):
                tile_type = state.map_features.tile_type[x, y]
                if tile_type == NEBULA_TILE:
                    color = (166, 177, 225, 255)
                elif tile_type == ASTEROID_TILE:
                    color = (51, 56, 68, 255)
                else:
                    color = (255, 255, 255, 255)
                pygame.draw.rect(self.surface, color, self._tile_rect(x, y))

        # Relic-node spots (tiles with positive weight), translucent gold
        if self.display_options["show_relic_spots"]:
            weights = state.relic_nodes_map_weights
            for x in range(params.map_width):
                for y in range(params.map_height):
                    if weights[x, y] > 0:
                        self._draw_rect_alpha(self.surface, (255, 215, 0, 50), self._tile_rect(x, y))

        # Energy nodes: green circles
        for i in range(params.max_energy_nodes):
            if state.energy_nodes_mask[i]:
                x, y = state.energy_nodes[i, :2]
                center = (int((x + 0.5) * TILE_SIZE), int((y + 0.5) * TILE_SIZE))
                pygame.draw.circle(self.surface, (0, 255, 0, 255), center, TILE_SIZE // 4)

        # Relic nodes: dark-gold squares half the tile size, centered in the tile
        rect_size = TILE_SIZE // 2
        offset = (TILE_SIZE - rect_size) // 2
        for i in range(params.max_relic_nodes):
            if state.relic_nodes_mask[i]:
                x, y = state.relic_nodes[i, :2]
                rect = pygame.Rect(x * TILE_SIZE + offset, y * TILE_SIZE + offset, rect_size, rect_size)
                pygame.draw.rect(self.surface, (173, 151, 32, 255), rect)

        # Sensor mask: translucent red per team that senses the tile
        if self.display_options["show_sensor_mask"]:
            for team in range(params.num_teams):
                for x in range(params.map_width):
                    for y in range(params.map_height):
                        if state.sensor_mask[team, x, y]:
                            self._draw_rect_alpha(self.surface, (255, 0, 0, 25), self._tile_rect(x, y))

        # Energy field: value as text, tinted green (positive) or red (non-positive)
        if self.display_options["show_energy_field"]:
            for x in range(params.map_width):
                for y in range(params.map_height):
                    energy = state.map_features.energy[x, y]
                    text = self.font.render(str(energy), True, (255, 255, 255))
                    text_rect = text.get_rect(center=((x + 0.5) * TILE_SIZE, (y + 0.5) * TILE_SIZE))
                    self.surface.blit(text, text_rect)
                    if energy > 0:
                        color = (0, 255, 0, self._alpha(energy, params.max_energy_per_tile))
                    else:
                        color = (255, 0, 0, self._alpha(energy, params.min_energy_per_tile))
                    self._draw_rect_alpha(self.surface, color, self._tile_rect(x, y))

        # Units: red for team 0, blue otherwise; also count units per tile
        unit_counts = {}
        for team in range(params.num_teams):
            color = (255, 0, 0, 255) if team == 0 else (0, 0, 255, 255)
            for i in range(params.max_units):
                if not state.units_mask[team, i]:
                    continue
                x, y = state.units.position[team, i]
                center = (int((x + 0.5) * TILE_SIZE), int((y + 0.5) * TILE_SIZE))
                pygame.draw.circle(self.surface, color, center, TILE_SIZE // 3)
                # presumably converted via np.array so the coordinates are hashable dict keys
                px, py = np.array(state.units.position[team, i])
                unit_counts[(px, py)] = unit_counts.get((px, py), 0) + 1

        # Unit counts per tile, in white
        for (x, y), count in unit_counts.items():
            text = self.font.render(str(count), True, (255, 255, 255))
            text_rect = text.get_rect(center=((x + 0.5) * TILE_SIZE, (y + 0.5) * TILE_SIZE))
            self.surface.blit(text, text_rect)

        # Grid lines
        if self.display_options["show_grid"]:
            for x in range(params.map_width + 1):
                pygame.draw.line(
                    self.surface,
                    (100, 100, 100),
                    (x * TILE_SIZE, 0),
                    (x * TILE_SIZE, params.map_height * TILE_SIZE),
                )
            for y in range(params.map_height + 1):
                pygame.draw.line(
                    self.surface,
                    (100, 100, 100),
                    (0, y * TILE_SIZE),
                    (params.map_width * TILE_SIZE, y * TILE_SIZE),
                )

        self.screen.blit(self.surface, (0, 0))
        # Update the display
        pygame.display.flip()
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import chex
|
|
2
|
+
import jax
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
import numpy as np
|
|
5
|
+
from gymnax.environments.spaces import Space
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MultiDiscrete(Space):
    """Minimal jittable class for multi discrete gymnax spaces.

    Element ``i`` takes integer values in the half-open range
    ``[low[i], high[i])``; samples are cast to ``int16``.
    """

    def __init__(self, low: np.ndarray, high: np.ndarray):
        self.low = low
        self.high = high
        self.dist = self.high - self.low
        assert low.shape == high.shape
        self.shape = low.shape
        self.dtype = jnp.int16

    def sample(self, rng: chex.PRNGKey) -> chex.Array:
        """Draw a uniformly random element of the space."""
        u = jax.random.uniform(rng, shape=self.shape, minval=0, maxval=1)
        # floor (rather than astype's truncation toward zero) keeps draws uniform
        # over [low, high) when low is negative; identical for non-negative bounds.
        return jnp.floor(u * self.dist + self.low).astype(self.dtype)

    def contains(self, x) -> jnp.ndarray:
        """Check whether specific object is within space (low <= x < high for every element)."""
        # The previous code compared against a nonexistent `self.n` and always raised.
        range_cond = jnp.logical_and(x >= self.low, x < self.high)
        return jnp.all(range_cond)
|