kaggle-environments 0.2.1__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaggle-environments might be problematic; see the package registry's advisory for more details.
- kaggle_environments/__init__.py +49 -13
- kaggle_environments/agent.py +177 -124
- kaggle_environments/api.py +31 -0
- kaggle_environments/core.py +295 -170
- kaggle_environments/envs/cabt/cabt.js +164 -0
- kaggle_environments/envs/cabt/cabt.json +28 -0
- kaggle_environments/envs/cabt/cabt.py +186 -0
- kaggle_environments/envs/cabt/cg/__init__.py +0 -0
- kaggle_environments/envs/cabt/cg/cg.dll +0 -0
- kaggle_environments/envs/cabt/cg/game.py +75 -0
- kaggle_environments/envs/cabt/cg/libcg.so +0 -0
- kaggle_environments/envs/cabt/cg/sim.py +48 -0
- kaggle_environments/envs/cabt/test_cabt.py +120 -0
- kaggle_environments/envs/chess/chess.js +4289 -0
- kaggle_environments/envs/chess/chess.json +60 -0
- kaggle_environments/envs/chess/chess.py +4241 -0
- kaggle_environments/envs/chess/test_chess.py +60 -0
- kaggle_environments/envs/connectx/connectx.ipynb +3186 -0
- kaggle_environments/envs/connectx/connectx.js +1 -1
- kaggle_environments/envs/connectx/connectx.json +15 -1
- kaggle_environments/envs/connectx/connectx.py +6 -23
- kaggle_environments/envs/connectx/test_connectx.py +70 -24
- kaggle_environments/envs/football/football.ipynb +75 -0
- kaggle_environments/envs/football/football.json +91 -0
- kaggle_environments/envs/football/football.py +277 -0
- kaggle_environments/envs/football/helpers.py +95 -0
- kaggle_environments/envs/football/test_football.py +360 -0
- kaggle_environments/envs/halite/__init__.py +0 -0
- kaggle_environments/envs/halite/halite.ipynb +44741 -0
- kaggle_environments/envs/halite/halite.js +199 -83
- kaggle_environments/envs/halite/halite.json +31 -18
- kaggle_environments/envs/halite/halite.py +164 -303
- kaggle_environments/envs/halite/helpers.py +720 -0
- kaggle_environments/envs/halite/test_halite.py +190 -0
- kaggle_environments/envs/hungry_geese/__init__.py +0 -0
- kaggle_environments/envs/{battlegeese/battlegeese.js → hungry_geese/hungry_geese.js} +38 -22
- kaggle_environments/envs/{battlegeese/battlegeese.json → hungry_geese/hungry_geese.json} +21 -14
- kaggle_environments/envs/hungry_geese/hungry_geese.py +316 -0
- kaggle_environments/envs/hungry_geese/test_hungry_geese.py +0 -0
- kaggle_environments/envs/identity/identity.json +6 -5
- kaggle_environments/envs/identity/identity.py +15 -2
- kaggle_environments/envs/kore_fleets/__init__.py +0 -0
- kaggle_environments/envs/kore_fleets/helpers.py +1005 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.ipynb +114 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.js +658 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.json +164 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.py +555 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/Bot.java +54 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/README.md +26 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/jars/hamcrest-core-1.3.jar +0 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/jars/junit-4.13.2.jar +0 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Board.java +518 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Cell.java +61 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Configuration.java +24 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Direction.java +166 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Fleet.java +72 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/KoreJson.java +97 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Observation.java +72 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Pair.java +13 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Player.java +68 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Point.java +65 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Shipyard.java +70 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/ShipyardAction.java +59 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/main.py +73 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/BoardTest.java +567 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ConfigurationTest.java +25 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/KoreJsonTest.java +62 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ObservationTest.java +46 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/PointTest.java +21 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ShipyardTest.java +22 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/configuration.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/fullob.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/observation.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/python/__init__.py +0 -0
- kaggle_environments/envs/kore_fleets/starter_bots/python/main.py +27 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/Bot.ts +34 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/DoNothingBot.ts +12 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/MinerBot.ts +62 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/README.md +55 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/interpreter.ts +402 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Board.ts +514 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Cell.ts +63 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Configuration.ts +25 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Direction.ts +169 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Fleet.ts +76 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/KoreIO.ts +70 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Observation.ts +45 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Pair.ts +11 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Player.ts +68 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Point.ts +65 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Shipyard.ts +72 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/ShipyardAction.ts +58 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/main.py +73 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/miner.py +73 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/package.json +23 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/BoardTest.ts +551 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ConfigurationTest.ts +16 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ObservationTest.ts +33 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/PointTest.ts +17 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ShipyardTest.ts +18 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/configuration.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/fullob.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/observation.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/tsconfig.json +22 -0
- kaggle_environments/envs/kore_fleets/test_kore_fleets.py +331 -0
- kaggle_environments/envs/lux_ai_2021/README.md +3 -0
- kaggle_environments/envs/lux_ai_2021/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_2021/agents.py +11 -0
- kaggle_environments/envs/lux_ai_2021/dimensions/754.js +2 -0
- kaggle_environments/envs/lux_ai_2021/dimensions/754.js.LICENSE.txt +296 -0
- kaggle_environments/envs/lux_ai_2021/dimensions/main.js +1 -0
- kaggle_environments/envs/lux_ai_2021/index.html +43 -0
- kaggle_environments/envs/lux_ai_2021/lux_ai_2021.json +100 -0
- kaggle_environments/envs/lux_ai_2021/lux_ai_2021.py +231 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_constants.js +6 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_constants.json +59 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_objects.js +145 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/io.js +14 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/kit.js +209 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/map.js +107 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/parser.js +79 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/main.js +88 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/main.py +75 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/simple.tar.gz +0 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/annotate.py +20 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/constants.py +25 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game.py +86 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_constants.json +59 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_constants.py +7 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_map.py +106 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_objects.py +154 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/random_agent.py +38 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/simple_agent.py +82 -0
- kaggle_environments/envs/lux_ai_2021/test_lux.py +19 -0
- kaggle_environments/envs/lux_ai_2021/testing.md +23 -0
- kaggle_environments/envs/lux_ai_2021/todo.md.og +18 -0
- kaggle_environments/envs/lux_ai_s3/README.md +21 -0
- kaggle_environments/envs/lux_ai_s3/agents.py +5 -0
- kaggle_environments/envs/lux_ai_s3/index.html +42 -0
- kaggle_environments/envs/lux_ai_s3/lux_ai_s3.json +47 -0
- kaggle_environments/envs/lux_ai_s3/lux_ai_s3.py +178 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/__init__.py +1 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py +819 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/globals.py +9 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py +101 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/profiler.py +141 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/pygame_render.py +222 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/spaces.py +27 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py +464 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/utils.py +12 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/wrappers.py +156 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/agent.py +78 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/kit.py +31 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/utils.py +17 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/main.py +66 -0
- kaggle_environments/envs/lux_ai_s3/test_lux.py +9 -0
- kaggle_environments/envs/mab/__init__.py +0 -0
- kaggle_environments/envs/mab/agents.py +12 -0
- kaggle_environments/envs/mab/mab.js +100 -0
- kaggle_environments/envs/mab/mab.json +74 -0
- kaggle_environments/envs/mab/mab.py +146 -0
- kaggle_environments/envs/open_spiel/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/chess/chess.js +441 -0
- kaggle_environments/envs/open_spiel/games/chess/image_config.jsonl +20 -0
- kaggle_environments/envs/open_spiel/games/chess/openings.jsonl +20 -0
- kaggle_environments/envs/open_spiel/games/connect_four/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four.js +284 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy.py +86 -0
- kaggle_environments/envs/open_spiel/games/go/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/go/go.js +481 -0
- kaggle_environments/envs/open_spiel/games/go/go_proxy.py +99 -0
- kaggle_environments/envs/open_spiel/games/tic_tac_toe/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/tic_tac_toe/tic_tac_toe.js +345 -0
- kaggle_environments/envs/open_spiel/games/tic_tac_toe/tic_tac_toe_proxy.py +98 -0
- kaggle_environments/envs/open_spiel/games/universal_poker/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/universal_poker/universal_poker.js +431 -0
- kaggle_environments/envs/open_spiel/games/universal_poker/universal_poker_proxy.py +159 -0
- kaggle_environments/envs/open_spiel/html_playthrough_generator.py +31 -0
- kaggle_environments/envs/open_spiel/observation.py +128 -0
- kaggle_environments/envs/open_spiel/open_spiel.py +565 -0
- kaggle_environments/envs/open_spiel/proxy.py +138 -0
- kaggle_environments/envs/open_spiel/test_open_spiel.py +191 -0
- kaggle_environments/envs/rps/__init__.py +0 -0
- kaggle_environments/envs/rps/agents.py +84 -0
- kaggle_environments/envs/rps/helpers.py +25 -0
- kaggle_environments/envs/rps/rps.js +117 -0
- kaggle_environments/envs/rps/rps.json +63 -0
- kaggle_environments/envs/rps/rps.py +90 -0
- kaggle_environments/envs/rps/test_rps.py +110 -0
- kaggle_environments/envs/rps/utils.py +7 -0
- kaggle_environments/envs/tictactoe/test_tictactoe.py +43 -77
- kaggle_environments/envs/tictactoe/tictactoe.ipynb +1397 -0
- kaggle_environments/envs/tictactoe/tictactoe.json +10 -2
- kaggle_environments/envs/tictactoe/tictactoe.py +1 -1
- kaggle_environments/errors.py +2 -4
- kaggle_environments/helpers.py +377 -0
- kaggle_environments/main.py +340 -0
- kaggle_environments/schemas.json +23 -18
- kaggle_environments/static/player.html +206 -74
- kaggle_environments/utils.py +46 -73
- kaggle_environments-1.20.0.dist-info/METADATA +25 -0
- kaggle_environments-1.20.0.dist-info/RECORD +211 -0
- {kaggle_environments-0.2.1.dist-info → kaggle_environments-1.20.0.dist-info}/WHEEL +1 -2
- kaggle_environments-1.20.0.dist-info/entry_points.txt +3 -0
- kaggle_environments/envs/battlegeese/battlegeese.py +0 -223
- kaggle_environments/temp.py +0 -14
- kaggle_environments-0.2.1.dist-info/METADATA +0 -393
- kaggle_environments-0.2.1.dist-info/RECORD +0 -32
- kaggle_environments-0.2.1.dist-info/entry_points.txt +0 -3
- kaggle_environments-0.2.1.dist-info/top_level.txt +0 -1
- {kaggle_environments-0.2.1.dist-info → kaggle_environments-1.20.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,464 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
|
|
3
|
+
import chex
|
|
4
|
+
import flax
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import numpy as np
|
|
8
|
+
from flax import struct
|
|
9
|
+
|
|
10
|
+
from luxai_s3.params import MAP_TYPES, EnvParams
|
|
11
|
+
|
|
12
|
+
# Tile-type codes written into ``MapTile.tile_type`` (0 is the default/empty tile).
EMPTY_TILE = 0
NEBULA_TILE = 1
ASTEROID_TILE = 2

# Energy-node field functions, each called as ``fn(d, x, y, z)``.
# Presumably ``d`` is the tile's distance to the energy node and ``x, y, z`` are the
# parameters taken from an ``energy_node_fns`` row -- TODO confirm against the caller.
ENERGY_NODE_FNS = [
    # index 0: z * sin(d * x + y)
    lambda d, x, y, z: jnp.sin(d * x + y) * z,
    # index 1: z * (x / (d + 1) + y)
    lambda d, x, y, z: (x / (d + 1) + y) * z,
]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@struct.dataclass
class UnitState:
    """Per-unit state stored as parallel arrays.

    ``gen_state`` allocates ``position`` with shape (T, N, 2) and ``energy`` with
    shape (T, N, 1), both int16, for T teams and N max units.
    """

    #: Unit positions; the last axis holds (x, y).
    position: chex.Array
    #: Unit energies; the last axis has size 1.
    energy: chex.Array
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@struct.dataclass
class MapTile:
    """Per-tile map features stored as parallel 2-D arrays covering the whole map.

    ``gen_map`` allocates both fields as int16 arrays of the map's size.
    """

    #: Energy of each tile, generated via energy_nodes and energy_node_fns.
    energy: chex.Array
    #: Type of each tile (EMPTY_TILE, NEBULA_TILE or ASTEROID_TILE).
    tile_type: chex.Array
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@struct.dataclass
class EnvState:
    """Full (unobserved) state of the environment.

    Shapes below use T = number of teams, N = the relevant maximum count
    (units, energy nodes or relic nodes), K = relic configuration size,
    W = map width and H = map height, as allocated by ``gen_state``.
    """

    #: Units. ``units.position`` has shape (T, N, 2) holding (x, y) and
    #: ``units.energy`` has shape (T, N, 1).
    units: UnitState
    #: Mask of existing units, shape (T, N).
    units_mask: chex.Array
    #: Energy node positions, shape (N, 2) holding (x, y).
    energy_nodes: chex.Array
    #: Energy node functions used to compute the map's energy field. Each row is a
    #: sequence of numbers: the first selects the function and the rest
    #: parameterize it. The function is applied to the distance from a map tile to
    #: the energy node and to those parameters.
    energy_node_fns: chex.Array
    #: Mask of active energy nodes, shape (N,).
    energy_nodes_mask: chex.Array
    #: Relic node positions, shape (N, 2) holding (x, y).
    relic_nodes: chex.Array
    #: Relic node configurations, shape (N, K, K) for a KxK configuration window.
    relic_node_configs: chex.Array
    #: Mask of visible relic nodes, shape (N,). ``gen_state`` starts it all False
    #: because relic nodes spawn in over time.
    relic_nodes_mask: chex.Array
    #: Relic-node tile ownership map, shape (W, H), indexed [x, y]. A non-zero
    #: entry is the 1-indexed id of the relic node owning that tile (mirrored
    #: relic-node pairs share an id); 0 means unclaimed. Generated from other state.
    relic_nodes_map_weights: chex.Array
    #: Game timestep at which each relic node spawns, shape (N,); -1 for relic
    #: nodes that never spawn.
    relic_spawn_schedule: chex.Array
    #: Map features: a MapTile whose ``energy`` and ``tile_type`` arrays each
    #: have shape (W, H).
    map_features: MapTile
    #: Per-team sensor mask, shape (T, H, W). Generated from other state.
    sensor_mask: chex.Array
    #: Per-team vision power map, shape (T, H, W). Generated from other state.
    vision_power_map: chex.Array
    #: Team points, shape (T,).
    team_points: chex.Array
    #: Team wins, shape (T,).
    team_wins: chex.Array
    #: Steps taken in the environment.
    steps: int = 0
    #: Steps taken in the current match.
    match_steps: int = 0
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@struct.dataclass
class EnvObs:
    """Partial observation of environment"""

    #: Observed units for T teams and N max units: ``units.position`` holds (x, y)
    #: per unit and ``units.energy`` the unit's energy (separate arrays, not a
    #: single 3-feature array).
    units: UnitState
    #: Mask of observed units, shape (T, N).
    units_mask: chex.Array
    #: Sensor mask of the observation; presumably marks which tiles are visible to
    #: the observing team -- TODO confirm shape against the observation builder.
    sensor_mask: chex.Array
    #: Observed map features (a MapTile with ``energy`` and ``tile_type`` arrays).
    map_features: MapTile
    #: Position of all relic nodes, shape (N, 2) holding (x, y). Number is -1 if
    #: not visible.
    relic_nodes: chex.Array
    #: Mask of all relic nodes, shape (N,).
    relic_nodes_mask: chex.Array
    #: Team points, shape (T,).
    team_points: chex.Array
    #: Team wins, shape (T,).
    team_wins: chex.Array
    #: Steps taken in the environment.
    steps: int = 0
    #: Steps taken in the current match.
    match_steps: int = 0
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def serialize_env_states(env_states: list[EnvState]):
    """Convert a sequence of EnvState objects into JSON-friendly Python structures.

    Each state is turned into a nested dict with ``flax.serialization.to_state_dict``.
    Array leaves are converted with ``.tolist()``: jax arrays, numpy arrays and
    numpy scalars (e.g. from ``jax.device_get``) alike. Internal or derived fields
    are omitted. Relic nodes, relic node configs and energy nodes are filtered by
    their masks so only active nodes appear.

    Args:
        env_states: Sequence of environment states, one per step.

    Returns:
        A list with one nested dict per state.
    """
    # Top-level fields that are not emitted at all.
    skipped_keys = {
        "sensor_mask",
        "relic_nodes_mask",
        "energy_nodes_mask",
        "energy_node_fns",
        "relic_nodes_map_weights",
        "relic_spawn_schedule",
    }

    def serialize_array(root: EnvState, arr, key_path: str = ""):
        if key_path in skipped_keys:
            return None
        # Node arrays are emitted pre-filtered by their masks (read from the
        # original state, not from the state dict).
        if key_path == "relic_nodes":
            return root.relic_nodes[root.relic_nodes_mask].tolist()
        if key_path == "relic_node_configs":
            return root.relic_node_configs[root.relic_nodes_mask].tolist()
        if key_path == "energy_nodes":
            return root.energy_nodes[root.energy_nodes_mask].tolist()
        # Accept numpy as well as jax leaves so host-side states serialize too.
        if isinstance(arr, (jnp.ndarray, np.ndarray, np.generic)):
            return arr.tolist()
        elif isinstance(arr, dict):
            ret = dict()
            for k, v in arr.items():
                new_key = key_path + "/" + k if key_path else k
                new_val = serialize_array(root, v, new_key)
                if new_val is not None:
                    ret[k] = new_val
            return ret
        return arr

    steps = []
    for state in env_states:
        state_dict = flax.serialization.to_state_dict(state)
        steps.append(serialize_array(state, state_dict))

    return steps
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def serialize_env_actions(env_actions: list):
    """Convert a sequence of per-step actions into JSON-friendly Python structures.

    Every action is passed through ``flax.serialization.to_state_dict``. Numpy and
    jax arrays found anywhere inside (recursively through dicts) are converted with
    ``.tolist()``, and dict entries whose converted value is ``None`` are dropped.

    Args:
        env_actions: Iterable of actions, one per step.

    Returns:
        A list with one converted entry per action.
    """

    def to_plain(value, key_path: str = ""):
        if isinstance(value, (np.ndarray, jnp.ndarray)):
            return value.tolist()
        if not isinstance(value, dict):
            return value
        converted = dict()
        for name, child in value.items():
            child_path = key_path + "/" + name if key_path else name
            plain_child = to_plain(child, child_path)
            if plain_child is not None:
                converted[name] = plain_child
        return converted

    return [to_plain(flax.serialization.to_state_dict(action)) for action in env_actions]
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def state_to_flat_obs(state: EnvState) -> chex.Array:
    """Flatten an EnvState into a single observation array.

    Not implemented yet. Raising makes accidental use fail loudly instead of
    silently returning ``None`` in place of an array.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("state_to_flat_obs is not implemented")
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def flat_obs_to_state(flat_obs: chex.Array) -> EnvState:
    """Rebuild an EnvState from a flat observation array.

    Not implemented yet. Raising makes accidental use fail loudly instead of
    silently returning ``None`` in place of a state.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("flat_obs_to_state is not implemented")
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
@functools.partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9))
def gen_state(
    key: chex.PRNGKey,
    env_params: EnvParams,
    max_units: int,
    num_teams: int,
    map_type: int,
    map_width: int,
    map_height: int,
    max_energy_nodes: int,
    max_relic_nodes: int,
    relic_config_size: int,
) -> EnvState:
    """Generate the initial EnvState for a new game.

    Calls ``gen_map`` to lay out map features, relic nodes and energy nodes. It then
    paints each relic node's KxK configuration into ``relic_nodes_map_weights``.
    Finally it allocates zero-initialised unit, team and visibility arrays.
    Arguments max_units..relic_config_size (positions 2-9) are static under jax.jit.

    Args:
        key: PRNG key forwarded to ``gen_map``.
        env_params: Environment parameters forwarded to ``gen_map``.
        max_units: Maximum number of units per team (N in the unit arrays).
        num_teams: Number of teams (T).
        map_type: Index into ``MAP_TYPES`` selecting the map generator.
        map_width: Map width (W).
        map_height: Map height (H).
        max_energy_nodes: Maximum number of energy nodes.
        max_relic_nodes: Maximum number of relic nodes.
        relic_config_size: Side length K of each KxK relic configuration.

    Returns:
        A new EnvState. ``relic_nodes_mask`` starts all False, and ``steps`` /
        ``match_steps`` take their dataclass defaults.
    """
    # NOTE(review): map_width/map_height are passed positionally into gen_map's
    # (map_height, map_width) parameter slots, i.e. swapped relative to gen_map's
    # names. Harmless for square maps; confirm intent before supporting non-square maps.
    generated = gen_map(
        key, env_params, map_type, map_width, map_height, max_energy_nodes, max_relic_nodes, relic_config_size
    )
    # Relic ownership map with shape (W, H); indexed [x, y] below.
    relic_nodes_map_weights = jnp.zeros(shape=(map_width, map_height), dtype=jnp.int16)

    # TODO (this could be optimized better)
    def update_relic_node(relic_nodes_map_weights, relic_data):
        # lax.scan body: paint one relic node's KxK config around its position.
        relic_node, relic_node_config, mask, relic_node_id = relic_data
        start_y = relic_node[1] - relic_config_size // 2
        start_x = relic_node[0] - relic_config_size // 2

        # These Python loops unroll at trace time into K*K sequential updates.
        for dy in range(relic_config_size):
            for dx in range(relic_config_size):
                y, x = start_y + dy, start_x + dx
                # Guards out-of-bounds (including negative) x/y, which .at[] would
                # otherwise wrap or drop.
                valid_pos = jnp.logical_and(
                    jnp.logical_and(y >= 0, x >= 0),
                    jnp.logical_and(y < map_height, x < map_width),
                )
                # ensure we don't override previous spawns
                # has_points is evaluated for the whole map. jnp.where below is
                # elementwise, and its two branches differ only at [x, y], so
                # effectively only has_points[x, y] matters. A tile already owned
                # by a node with id <= this node's id is left untouched.
                has_points = jnp.logical_and(relic_nodes_map_weights > 0, relic_nodes_map_weights <= relic_node_id + 1)
                relic_nodes_map_weights = jnp.where(
                    valid_pos & mask & jnp.logical_not(has_points) & relic_node_config[dx, dy],
                    relic_nodes_map_weights.at[x, y].set(
                        relic_node_config[dx, dy].astype(jnp.int16) * (relic_node_id + 1)
                    ),
                    relic_nodes_map_weights,
                )
        return relic_nodes_map_weights, None

    # this is really slow...

    relic_nodes_map_weights, _ = jax.lax.scan(
        update_relic_node,
        relic_nodes_map_weights,
        (
            generated["relic_nodes"],
            generated["relic_node_configs"],
            generated["relic_nodes_mask"],
            # ids 0..max/2-1 repeated twice, so each mirrored relic pair shares an id
            jnp.arange(max_relic_nodes, dtype=jnp.int16) % (max_relic_nodes // 2),
        ),
    )

    state = EnvState(
        units=UnitState(
            position=jnp.zeros(shape=(num_teams, max_units, 2), dtype=jnp.int16),
            energy=jnp.zeros(shape=(num_teams, max_units, 1), dtype=jnp.int16),
        ),
        units_mask=jnp.zeros(shape=(num_teams, max_units), dtype=jnp.bool),
        team_points=jnp.zeros(shape=(num_teams), dtype=jnp.int32),
        team_wins=jnp.zeros(shape=(num_teams), dtype=jnp.int32),
        energy_nodes=generated["energy_nodes"],
        energy_node_fns=generated["energy_node_fns"],
        energy_nodes_mask=generated["energy_nodes_mask"],
        # energy_field=jnp.zeros(shape=(params.map_height, params.map_width), dtype=jnp.int16),
        relic_nodes=generated["relic_nodes"],
        relic_nodes_mask=jnp.zeros(
            shape=(max_relic_nodes), dtype=jnp.bool
        ),  # relic nodes spawn in over time, so we start with them all invisible.
        relic_node_configs=generated["relic_node_configs"],
        relic_nodes_map_weights=relic_nodes_map_weights,
        relic_spawn_schedule=generated["relic_spawn_schedule"],
        sensor_mask=jnp.zeros(
            shape=(num_teams, map_height, map_width),
            dtype=jnp.bool,
        ),
        vision_power_map=jnp.zeros(shape=(num_teams, map_height, map_width), dtype=jnp.int16),
        map_features=generated["map_features"],
    )
    return state
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
@functools.partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7))
def gen_map(
    key: chex.PRNGKey,
    params: EnvParams,
    map_type: int,
    map_height: int,
    map_width: int,
    max_energy_nodes: int,
    max_relic_nodes: int,
    relic_config_size: int,
) -> dict:
    """Procedurally generate map features, relic nodes and energy nodes.

    Only the ``"random"`` map type is implemented. Generated features are mirrored
    across the anti-diagonal. Nebula/asteroid masks are OR-ed with their transpose
    and then row-flipped. For relic/energy nodes, the second half of each node array
    mirrors the first half. Arguments map_type..relic_config_size (positions 2-7)
    are static under jax.jit.

    Args:
        key: PRNG key.
        params: Environment parameters; only ``params.max_steps_in_match`` is read.
        map_type: Index into ``MAP_TYPES``.
        map_height: Size of the first axis of the generated map arrays.
        map_width: Size of the second axis of the generated map arrays. The
            transpose-based mirroring only broadcasts when map_height == map_width.
        max_energy_nodes: Number of energy-node slots (half generated, half mirrored).
        max_relic_nodes: Number of relic-node slots; must be 6.
        relic_config_size: Side length K of each KxK relic configuration.

    Returns:
        dict with keys ``map_features``, ``energy_nodes``, ``energy_node_fns``,
        ``relic_nodes``, ``energy_nodes_mask``, ``relic_nodes_mask``,
        ``relic_node_configs`` and ``relic_spawn_schedule``.

    Raises:
        NotImplementedError: if ``MAP_TYPES[map_type]`` is not ``"random"``.
        ValueError: if ``max_relic_nodes`` is not 6.
    """
    # Validate static arguments up front. These checks run once at trace time.
    # Previously a non-"random" map type fell through to an UnboundLocalError.
    if MAP_TYPES[map_type] != "random":
        raise NotImplementedError(f"map type {MAP_TYPES[map_type]!r} is not supported")
    # generate a random relic spawn schedule: the code below is hardcoded for 6 nodes
    if max_relic_nodes != 6:
        raise ValueError("random map generation is hardcoded to use 6 relic nodes at most per map")

    map_features = MapTile(
        energy=jnp.zeros(shape=(map_height, map_width), dtype=jnp.int16),
        tile_type=jnp.zeros(shape=(map_height, map_width), dtype=jnp.int16),
    )
    energy_nodes_mask = jnp.zeros(shape=(max_energy_nodes), dtype=jnp.bool)

    # NOTE(review): several draws below use `key` directly instead of the freshly
    # split `subkey`, and that `key` is split again afterwards, so some random
    # streams are reused/correlated. Kept unchanged on purpose: switching to
    # `subkey` would change every generated map for a given seed.

    ### Generate nebula tiles ###
    key, subkey = jax.random.split(key)
    perlin_noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4))
    noise = jnp.where(perlin_noise > 0.5, 1, 0)
    # symmetrise about the main diagonal, then flip rows -> mirrored across the anti-diagonal
    noise = noise | noise.T
    noise = noise[::-1, :]
    map_features = map_features.replace(tile_type=jnp.where(noise, NEBULA_TILE, EMPTY_TILE))

    ### Generate asteroid tiles ###
    key, subkey = jax.random.split(key)
    perlin_noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (8, 8))
    noise = jnp.where(perlin_noise < -0.5, 1, 0)
    # same anti-diagonal mirroring as the nebula tiles
    noise = noise | noise.T
    noise = noise[::-1, :]
    map_features = map_features.replace(
        tile_type=jnp.place(map_features.tile_type, noise, ASTEROID_TILE, inplace=False)
    )

    ### Generate relic nodes ###
    key, subkey = jax.random.split(key)
    noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4))
    # Indices of the max_relic_nodes // 2 highest noise values (first half of the nodes)
    flat_indices = jnp.argsort(noise.ravel())[-max_relic_nodes // 2 :]
    highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape))

    # relic nodes have a fixed density: values 8 and 9 out of [0, 10) -> 20% of
    # nearby tiles can yield points
    relic_node_configs = (
        jax.random.randint(
            key,
            shape=(
                max_relic_nodes,
                relic_config_size,
                relic_config_size,
            ),
            minval=0,
            maxval=10,
        ).astype(jnp.float32)
        >= 7.5
    )
    highest_positions = highest_positions.astype(jnp.int16)
    mirrored_positions = jnp.stack(
        [map_width - highest_positions[:, 1] - 1, map_height - highest_positions[:, 0] - 1],
        dtype=jnp.int16,
        axis=-1,
    )
    relic_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0)

    key, subkey = jax.random.split(key)
    num_spawned_relic_nodes = jax.random.randint(key, (1,), minval=1, maxval=(max_relic_nodes // 2) + 1)
    # note that relic nodes mask is always increasing: the first k of each half are active
    relic_nodes_mask_half = jnp.arange(max_relic_nodes // 2) < num_spawned_relic_nodes
    relic_nodes_mask = jnp.concat([relic_nodes_mask_half, relic_nodes_mask_half], axis=0)
    # mirrored nodes get the mirrored (anti-transposed) configuration
    relic_node_configs = relic_node_configs.at[max_relic_nodes // 2 :].set(
        relic_node_configs[: max_relic_nodes // 2].transpose(0, 2, 1)[:, ::-1, ::-1]
    )

    ### Generate energy nodes ###
    key, subkey = jax.random.split(key)
    noise = generate_perlin_noise_2d(subkey, (map_height, map_width), (4, 4))
    # Indices of the max_energy_nodes // 2 highest noise values
    flat_indices = jnp.argsort(noise.ravel())[-max_energy_nodes // 2 :]
    highest_positions = jnp.column_stack(jnp.unravel_index(flat_indices, noise.shape)).astype(jnp.int16)
    mirrored_positions = jnp.stack(
        [map_width - highest_positions[:, 1] - 1, map_height - highest_positions[:, 0] - 1],
        dtype=jnp.int16,
        axis=-1,
    )
    energy_nodes = jnp.concat([highest_positions, mirrored_positions], axis=0)
    key, subkey = jax.random.split(key)
    energy_nodes_mask_half = jax.random.randint(key, (max_energy_nodes // 2,), minval=0, maxval=2).astype(jnp.bool)
    # always keep at least one energy node (and its mirror) active
    energy_nodes_mask_half = energy_nodes_mask_half.at[0].set(True)
    energy_nodes_mask = energy_nodes_mask.at[: max_energy_nodes // 2].set(energy_nodes_mask_half)
    energy_nodes_mask = energy_nodes_mask.at[max_energy_nodes // 2 :].set(energy_nodes_mask_half)

    # NOTE(review): hard-coded to 6 rows regardless of max_energy_nodes; presumably
    # max_energy_nodes == 6 -- confirm.
    energy_node_fns = jnp.array(
        [
            [0, 1.2, 1, 4],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 1.2, 1, 4],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
        ]
    )

    # Random relic spawn schedule: half-node i spawns during match i at a random
    # step in its first half. -1 means the relic node is never spawned.
    key, subkey = jax.random.split(key)
    relic_spawn_schedule_half = jax.random.randint(
        key, (max_relic_nodes // 2,), minval=0, maxval=params.max_steps_in_match // 2
    ) + jnp.arange(max_relic_nodes // 2) * (params.max_steps_in_match + 1)
    relic_spawn_schedule = jnp.concat([relic_spawn_schedule_half, relic_spawn_schedule_half], axis=0)
    relic_spawn_schedule = jnp.where(relic_nodes_mask, relic_spawn_schedule, -1)

    return dict(
        map_features=map_features,
        energy_nodes=energy_nodes,
        energy_node_fns=energy_node_fns,
        relic_nodes=relic_nodes,
        energy_nodes_mask=energy_nodes_mask,
        relic_nodes_mask=relic_nodes_mask,
        relic_node_configs=relic_node_configs,
        relic_spawn_schedule=relic_spawn_schedule,
    )
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
def interpolant(t):
    """Return the quintic Perlin fade curve 6t^5 - 15t^4 + 10t^3 evaluated at ``t``.

    Works element-wise on anything supporting ``*``, ``-`` and ``+`` with
    Python numbers (scalars, numpy or JAX arrays).
    """
    cubed = t * t * t
    polynomial_tail = t * (t * 6 - 15) + 10
    return cubed * polynomial_tail
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def generate_perlin_noise_2d(key, shape, res, tileable=(False, False), interpolant=interpolant):
    """Generate a 2D JAX array of Perlin noise.

    Args:
        key: JAX PRNG key used to draw the random gradient angles.
        shape: The shape of the generated array (tuple of two ints).
            This must be a multiple of res.
        res: The number of periods of noise to generate along each
            axis (tuple of two ints). Note shape must be a multiple of
            res.
        tileable: If the noise should be tileable along each axis
            (tuple of two bools). Defaults to (False, False).
        interpolant: The interpolation function, defaults to
            t*t*t*(t*(t*6 - 15) + 10).

    Returns:
        A JAX array of shape ``shape`` with the generated noise.

    Raises:
        ValueError: If shape is not a multiple of res.
    """
    # shape and res are static jit arguments, so this plain Python check runs
    # at trace time and gives a clear error instead of a later shape mismatch.
    if shape[0] % res[0] != 0 or shape[1] % res[1] != 0:
        raise ValueError(f"shape {shape} must be a multiple of res {res}")
    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])
    # Fractional position of every output cell inside its lattice cell.
    grid = jnp.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(1, 2, 0) % 1
    # Gradients: one random unit vector per lattice corner.
    angles = 2 * jnp.pi * jax.random.uniform(key, (res[0] + 1, res[1] + 1))
    gradients = jnp.dstack((jnp.cos(angles), jnp.sin(angles)))
    # JAX arrays are immutable, so wrap-around corners are copied with
    # functional .at[].set() updates rather than item assignment.
    if tileable[0]:
        gradients = gradients.at[-1, :].set(gradients[0, :])
    if tileable[1]:
        gradients = gradients.at[:, -1].set(gradients[:, 0])
    # Expand each corner gradient over the d[0] x d[1] cells it influences.
    gradients = gradients.repeat(d[0], 0).repeat(d[1], 1)
    g00 = gradients[: -d[0], : -d[1]]
    g10 = gradients[d[0] :, : -d[1]]
    g01 = gradients[: -d[0], d[1] :]
    g11 = gradients[d[0] :, d[1] :]

    # Ramps: dot product of each corner gradient with the offset to that corner.
    n00 = jnp.sum(jnp.dstack((grid[:, :, 0], grid[:, :, 1])) * g00, 2)
    n10 = jnp.sum(jnp.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2)
    n01 = jnp.sum(jnp.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2)
    n11 = jnp.sum(jnp.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2)
    # Interpolation
    t = interpolant(grid)
    n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10
    n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11
    return jnp.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1)
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
# TODO (stao): Add lux ai s3 env to gymnax api wrapper, which is the old gym api
|
|
2
|
+
import dataclasses
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from typing import Any, SupportsFloat
|
|
6
|
+
|
|
7
|
+
import flax
|
|
8
|
+
import flax.serialization
|
|
9
|
+
import gymnasium as gym
|
|
10
|
+
import jax
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from luxai_s3.env import LuxAIS3Env
|
|
14
|
+
from luxai_s3.params import EnvParams, env_params_ranges
|
|
15
|
+
from luxai_s3.state import serialize_env_actions, serialize_env_states
|
|
16
|
+
from luxai_s3.utils import to_numpy
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class LuxAIS3GymEnv(gym.Env):
    """Gymnasium adapter around the JAX ``LuxAIS3Env``.

    The JAX env is built with ``auto_reset=False``. This wrapper owns the JAX
    PRNG key and the current env state, and splits a fresh subkey for every
    reset/step.
    """

    def __init__(self, numpy_output: bool = False):
        """Create the wrapper.

        Args:
            numpy_output: If True, ``reset``/``step`` pass observations (and, in
                ``step``, reward/terminated/truncated) through ``to_numpy``
                instead of returning JAX values.
        """
        self.numpy_output = numpy_output
        self.rng_key = jax.random.key(0)
        self.jax_env = LuxAIS3Env(auto_reset=False)
        self.env_params: EnvParams = EnvParams()

        # One int16 row of 3 columns per unit: column 0 bounded to [0, 6],
        # columns 1-2 bounded to [-unit_sap_range, unit_sap_range].
        # NOTE(review): bounds come from the default EnvParams and are not
        # refreshed when reset() samples new parameters; confirm intended.
        # NOTE(review): no observation_space is defined on this env.
        low = np.zeros((self.env_params.max_units, 3))
        low[:, 1:] = -self.env_params.unit_sap_range
        high = np.ones((self.env_params.max_units, 3)) * 6
        high[:, 1:] = self.env_params.unit_sap_range
        self.action_space = gym.spaces.Dict(
            dict(
                player_0=gym.spaces.Box(low=low, high=high, dtype=np.int16),
                player_1=gym.spaces.Box(low=low, high=high, dtype=np.int16),
            )
        )

    def render(self):
        """Render the current state with the underlying JAX env (requires a prior reset)."""
        self.jax_env.render(self.state, self.env_params)

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[Any, dict[str, Any]]:
        """Start a new game with freshly sampled parameters.

        Args:
            seed: If given, re-seeds the JAX PRNG key (and gymnasium's np_random).
            options: If it contains ``"params"``, that EnvParams object is used
                instead of the randomly sampled one.

        Returns:
            ``(obs, info)`` where ``info`` has ``"params"`` (the agent-visible
            subset of the parameters), ``"full_params"`` (all parameters as a
            dict) and ``"state"`` (the JAX env state).
        """
        # The gymnasium Env.reset contract: seeds self.np_random. This does not
        # touch the JAX key handled below.
        super().reset(seed=seed)
        if seed is not None:
            self.rng_key = jax.random.key(seed)
        self.rng_key, reset_key = jax.random.split(self.rng_key)
        # generate random game parameters
        # TODO (stao): check why this keeps recompiling when marking structs as static args
        # A subkey is split per entry even when options["params"] overrides the
        # result, so the key sequence is identical in both cases.
        randomized_game_params = dict()
        for k, v in env_params_ranges.items():
            self.rng_key, subkey = jax.random.split(self.rng_key)
            randomized_game_params[k] = jax.random.choice(subkey, jax.numpy.array(v)).item()
        params = EnvParams(**randomized_game_params)
        if options is not None and "params" in options:
            params = options["params"]

        self.env_params = params
        obs, self.state = self.jax_env.reset(reset_key, params=params)
        if self.numpy_output:
            obs = to_numpy(flax.serialization.to_state_dict(obs))

        # only keep the following game parameters available to the agent
        params_dict = dataclasses.asdict(params)
        agent_visible_keys = (
            "max_units",
            "match_count_per_episode",
            "max_steps_in_match",
            "map_height",
            "map_width",
            "num_teams",
            "unit_move_cost",
            "unit_sap_cost",
            "unit_sap_range",
            "unit_sensor_range",
        )
        params_dict_kept = {k: params_dict[k] for k in agent_visible_keys}
        return obs, dict(params=params_dict_kept, full_params=params_dict, state=self.state)

    def step(self, action: Any) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
        """Advance the game one step with ``action`` (forwarded as-is to the JAX env).

        Returns:
            ``(obs, reward, terminated, truncated, info)`` from ``LuxAIS3Env.step``.
        """
        self.rng_key, step_key = jax.random.split(self.rng_key)
        obs, self.state, reward, terminated, truncated, info = self.jax_env.step(
            step_key, self.state, action, self.env_params
        )
        if self.numpy_output:
            obs = to_numpy(flax.serialization.to_state_dict(obs))
            reward = to_numpy(reward)
            terminated = to_numpy(terminated)
            truncated = to_numpy(truncated)
            # info is deliberately returned unconverted, even with numpy_output.
        return obs, reward, terminated, truncated, info
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# TODO: vectorized gym wrapper
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class RecordEpisode(gym.Wrapper):
    """Gymnasium wrapper that records every state and action of an episode.

    Episodes are serialized with ``serialize_env_states`` /
    ``serialize_env_actions`` and written as JSON to
    ``<save_dir>/episode_<id>.json`` on reset and/or close.
    """

    def __init__(
        self,
        env: LuxAIS3GymEnv,
        save_dir: str | None = None,
        save_on_close: bool = True,
        save_on_reset: bool = True,
    ):
        """Wrap ``env`` and start with an empty episode buffer.

        Args:
            env: The environment to wrap.
            save_dir: Directory episodes are auto-saved to (created if missing).
                If None, nothing is written automatically: the buffer is
                discarded on auto-save and ``save_episode`` must be called
                explicitly to persist an episode.
            save_on_close: Save the in-progress episode when ``close`` is called.
            save_on_reset: Save the previous episode when ``reset`` is called.
        """
        super().__init__(env)
        self.episode = dict(states=[], actions=[], metadata=dict())
        self.episode_id = 0
        self.save_dir = save_dir
        self.save_on_close = save_on_close
        self.save_on_reset = save_on_reset
        self.episode_steps = 0
        if save_dir is not None:
            from pathlib import Path

            Path(save_dir).mkdir(parents=True, exist_ok=True)

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[Any, dict[str, Any]]:
        """Reset the wrapped env, saving the previous episode first if configured.

        Records the seed, the full parameter dict and the initial state.
        """
        if self.save_on_reset and self.episode_steps > 0:
            self._save_episode_and_reset()
        obs, info = self.env.reset(seed=seed, options=options)

        self.episode["metadata"]["seed"] = seed
        self.episode["params"] = flax.serialization.to_state_dict(info["full_params"])
        self.episode["states"].append(info["state"])
        return obs, info

    def step(self, action: Any) -> tuple[Any, SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the wrapped env and record ``info["final_state"]`` and the action."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.episode_steps += 1
        self.episode["states"].append(info["final_state"])
        self.episode["actions"].append(action)
        return obs, reward, terminated, truncated, info

    def serialize_episode_data(self, episode=None):
        """Return a JSON-serializable dict (observations/actions/metadata/params) for ``episode``.

        Defaults to the episode currently being recorded.
        """
        if episode is None:
            episode = self.episode
        ret = dict()
        ret["observations"] = serialize_env_states(episode["states"])
        if "actions" in episode:
            ret["actions"] = serialize_env_actions(episode["actions"])
        ret["metadata"] = episode["metadata"]
        ret["params"] = episode["params"]
        return ret

    def save_episode(self, save_path: str):
        """Write the current episode as JSON to ``save_path`` and clear the buffer."""
        episode = self.serialize_episode_data()
        with open(save_path, "w") as f:
            json.dump(episode, f)
        self.episode = dict(states=[], actions=[], metadata=dict())
        # The buffer is now empty (and has no "params"); resetting the counter
        # stops a later reset()/close() from re-saving it and failing on KeyError.
        self.episode_steps = 0

    def _save_episode_and_reset(self):
        """Save to a generated path based on self.save_dir and the episode id, then update counters.

        When save_dir is None there is nowhere to write, so the recorded
        episode is discarded instead of crashing in os.path.join(None, ...).
        """
        if self.save_dir is not None:
            self.save_episode(os.path.join(self.save_dir, f"episode_{self.episode_id}.json"))
        else:
            self.episode = dict(states=[], actions=[], metadata=dict())
        self.episode_id += 1
        self.episode_steps = 0

    def close(self):
        """Save the in-progress episode if configured, then close the wrapped env."""
        if self.save_on_close and self.episode_steps > 0:
            self._save_episode_and_reset()
        # gym.Wrapper.close forwards to the wrapped env's close().
        super().close()
|