kaggle-environments 0.2.1__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaggle-environments might be problematic. Click here for more details.
- kaggle_environments/__init__.py +49 -13
- kaggle_environments/agent.py +177 -124
- kaggle_environments/api.py +31 -0
- kaggle_environments/core.py +295 -170
- kaggle_environments/envs/cabt/cabt.js +164 -0
- kaggle_environments/envs/cabt/cabt.json +28 -0
- kaggle_environments/envs/cabt/cabt.py +186 -0
- kaggle_environments/envs/cabt/cg/__init__.py +0 -0
- kaggle_environments/envs/cabt/cg/cg.dll +0 -0
- kaggle_environments/envs/cabt/cg/game.py +75 -0
- kaggle_environments/envs/cabt/cg/libcg.so +0 -0
- kaggle_environments/envs/cabt/cg/sim.py +48 -0
- kaggle_environments/envs/cabt/test_cabt.py +120 -0
- kaggle_environments/envs/chess/chess.js +4289 -0
- kaggle_environments/envs/chess/chess.json +60 -0
- kaggle_environments/envs/chess/chess.py +4241 -0
- kaggle_environments/envs/chess/test_chess.py +60 -0
- kaggle_environments/envs/connectx/connectx.ipynb +3186 -0
- kaggle_environments/envs/connectx/connectx.js +1 -1
- kaggle_environments/envs/connectx/connectx.json +15 -1
- kaggle_environments/envs/connectx/connectx.py +6 -23
- kaggle_environments/envs/connectx/test_connectx.py +70 -24
- kaggle_environments/envs/football/football.ipynb +75 -0
- kaggle_environments/envs/football/football.json +91 -0
- kaggle_environments/envs/football/football.py +277 -0
- kaggle_environments/envs/football/helpers.py +95 -0
- kaggle_environments/envs/football/test_football.py +360 -0
- kaggle_environments/envs/halite/__init__.py +0 -0
- kaggle_environments/envs/halite/halite.ipynb +44741 -0
- kaggle_environments/envs/halite/halite.js +199 -83
- kaggle_environments/envs/halite/halite.json +31 -18
- kaggle_environments/envs/halite/halite.py +164 -303
- kaggle_environments/envs/halite/helpers.py +720 -0
- kaggle_environments/envs/halite/test_halite.py +190 -0
- kaggle_environments/envs/hungry_geese/__init__.py +0 -0
- kaggle_environments/envs/{battlegeese/battlegeese.js → hungry_geese/hungry_geese.js} +38 -22
- kaggle_environments/envs/{battlegeese/battlegeese.json → hungry_geese/hungry_geese.json} +21 -14
- kaggle_environments/envs/hungry_geese/hungry_geese.py +316 -0
- kaggle_environments/envs/hungry_geese/test_hungry_geese.py +0 -0
- kaggle_environments/envs/identity/identity.json +6 -5
- kaggle_environments/envs/identity/identity.py +15 -2
- kaggle_environments/envs/kore_fleets/__init__.py +0 -0
- kaggle_environments/envs/kore_fleets/helpers.py +1005 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.ipynb +114 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.js +658 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.json +164 -0
- kaggle_environments/envs/kore_fleets/kore_fleets.py +555 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/Bot.java +54 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/README.md +26 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/jars/hamcrest-core-1.3.jar +0 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/jars/junit-4.13.2.jar +0 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Board.java +518 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Cell.java +61 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Configuration.java +24 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Direction.java +166 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Fleet.java +72 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/KoreJson.java +97 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Observation.java +72 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Pair.java +13 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Player.java +68 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Point.java +65 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/Shipyard.java +70 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/kore/ShipyardAction.java +59 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/main.py +73 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/BoardTest.java +567 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ConfigurationTest.java +25 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/KoreJsonTest.java +62 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ObservationTest.java +46 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/PointTest.java +21 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/ShipyardTest.java +22 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/configuration.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/fullob.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/java/test/observation.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/python/__init__.py +0 -0
- kaggle_environments/envs/kore_fleets/starter_bots/python/main.py +27 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/Bot.ts +34 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/DoNothingBot.ts +12 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/MinerBot.ts +62 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/README.md +55 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/interpreter.ts +402 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Board.ts +514 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Cell.ts +63 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Configuration.ts +25 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Direction.ts +169 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Fleet.ts +76 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/KoreIO.ts +70 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Observation.ts +45 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Pair.ts +11 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Player.ts +68 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Point.ts +65 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/Shipyard.ts +72 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/kore/ShipyardAction.ts +58 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/main.py +73 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/miner.py +73 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/package.json +23 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/BoardTest.ts +551 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ConfigurationTest.ts +16 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ObservationTest.ts +33 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/PointTest.ts +17 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/ShipyardTest.ts +18 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/configuration.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/fullob.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/test/observation.json +1 -0
- kaggle_environments/envs/kore_fleets/starter_bots/ts/tsconfig.json +22 -0
- kaggle_environments/envs/kore_fleets/test_kore_fleets.py +331 -0
- kaggle_environments/envs/lux_ai_2021/README.md +3 -0
- kaggle_environments/envs/lux_ai_2021/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_2021/agents.py +11 -0
- kaggle_environments/envs/lux_ai_2021/dimensions/754.js +2 -0
- kaggle_environments/envs/lux_ai_2021/dimensions/754.js.LICENSE.txt +296 -0
- kaggle_environments/envs/lux_ai_2021/dimensions/main.js +1 -0
- kaggle_environments/envs/lux_ai_2021/index.html +43 -0
- kaggle_environments/envs/lux_ai_2021/lux_ai_2021.json +100 -0
- kaggle_environments/envs/lux_ai_2021/lux_ai_2021.py +231 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_constants.js +6 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_constants.json +59 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/game_objects.js +145 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/io.js +14 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/kit.js +209 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/map.js +107 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/lux/parser.js +79 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/main.js +88 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/main.py +75 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/js_simple/simple.tar.gz +0 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/annotate.py +20 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/constants.py +25 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game.py +86 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_constants.json +59 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_constants.py +7 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_map.py +106 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/lux/game_objects.py +154 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/random_agent.py +38 -0
- kaggle_environments/envs/lux_ai_2021/test_agents/python/simple_agent.py +82 -0
- kaggle_environments/envs/lux_ai_2021/test_lux.py +19 -0
- kaggle_environments/envs/lux_ai_2021/testing.md +23 -0
- kaggle_environments/envs/lux_ai_2021/todo.md.og +18 -0
- kaggle_environments/envs/lux_ai_s3/README.md +21 -0
- kaggle_environments/envs/lux_ai_s3/agents.py +5 -0
- kaggle_environments/envs/lux_ai_s3/index.html +42 -0
- kaggle_environments/envs/lux_ai_s3/lux_ai_s3.json +47 -0
- kaggle_environments/envs/lux_ai_s3/lux_ai_s3.py +178 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/__init__.py +1 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/env.py +819 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/globals.py +9 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/params.py +101 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/profiler.py +141 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/pygame_render.py +222 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/spaces.py +27 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/state.py +464 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/utils.py +12 -0
- kaggle_environments/envs/lux_ai_s3/luxai_s3/wrappers.py +156 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/agent.py +78 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/__init__.py +0 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/kit.py +31 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/lux/utils.py +17 -0
- kaggle_environments/envs/lux_ai_s3/test_agents/python/main.py +66 -0
- kaggle_environments/envs/lux_ai_s3/test_lux.py +9 -0
- kaggle_environments/envs/mab/__init__.py +0 -0
- kaggle_environments/envs/mab/agents.py +12 -0
- kaggle_environments/envs/mab/mab.js +100 -0
- kaggle_environments/envs/mab/mab.json +74 -0
- kaggle_environments/envs/mab/mab.py +146 -0
- kaggle_environments/envs/open_spiel/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/chess/chess.js +441 -0
- kaggle_environments/envs/open_spiel/games/chess/image_config.jsonl +20 -0
- kaggle_environments/envs/open_spiel/games/chess/openings.jsonl +20 -0
- kaggle_environments/envs/open_spiel/games/connect_four/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four.js +284 -0
- kaggle_environments/envs/open_spiel/games/connect_four/connect_four_proxy.py +86 -0
- kaggle_environments/envs/open_spiel/games/go/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/go/go.js +481 -0
- kaggle_environments/envs/open_spiel/games/go/go_proxy.py +99 -0
- kaggle_environments/envs/open_spiel/games/tic_tac_toe/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/tic_tac_toe/tic_tac_toe.js +345 -0
- kaggle_environments/envs/open_spiel/games/tic_tac_toe/tic_tac_toe_proxy.py +98 -0
- kaggle_environments/envs/open_spiel/games/universal_poker/__init__.py +0 -0
- kaggle_environments/envs/open_spiel/games/universal_poker/universal_poker.js +431 -0
- kaggle_environments/envs/open_spiel/games/universal_poker/universal_poker_proxy.py +159 -0
- kaggle_environments/envs/open_spiel/html_playthrough_generator.py +31 -0
- kaggle_environments/envs/open_spiel/observation.py +128 -0
- kaggle_environments/envs/open_spiel/open_spiel.py +565 -0
- kaggle_environments/envs/open_spiel/proxy.py +138 -0
- kaggle_environments/envs/open_spiel/test_open_spiel.py +191 -0
- kaggle_environments/envs/rps/__init__.py +0 -0
- kaggle_environments/envs/rps/agents.py +84 -0
- kaggle_environments/envs/rps/helpers.py +25 -0
- kaggle_environments/envs/rps/rps.js +117 -0
- kaggle_environments/envs/rps/rps.json +63 -0
- kaggle_environments/envs/rps/rps.py +90 -0
- kaggle_environments/envs/rps/test_rps.py +110 -0
- kaggle_environments/envs/rps/utils.py +7 -0
- kaggle_environments/envs/tictactoe/test_tictactoe.py +43 -77
- kaggle_environments/envs/tictactoe/tictactoe.ipynb +1397 -0
- kaggle_environments/envs/tictactoe/tictactoe.json +10 -2
- kaggle_environments/envs/tictactoe/tictactoe.py +1 -1
- kaggle_environments/errors.py +2 -4
- kaggle_environments/helpers.py +377 -0
- kaggle_environments/main.py +340 -0
- kaggle_environments/schemas.json +23 -18
- kaggle_environments/static/player.html +206 -74
- kaggle_environments/utils.py +46 -73
- kaggle_environments-1.20.0.dist-info/METADATA +25 -0
- kaggle_environments-1.20.0.dist-info/RECORD +211 -0
- {kaggle_environments-0.2.1.dist-info → kaggle_environments-1.20.0.dist-info}/WHEEL +1 -2
- kaggle_environments-1.20.0.dist-info/entry_points.txt +3 -0
- kaggle_environments/envs/battlegeese/battlegeese.py +0 -223
- kaggle_environments/temp.py +0 -14
- kaggle_environments-0.2.1.dist-info/METADATA +0 -393
- kaggle_environments-0.2.1.dist-info/RECORD +0 -32
- kaggle_environments-0.2.1.dist-info/entry_points.txt +0 -3
- kaggle_environments-0.2.1.dist-info/top_level.txt +0 -1
- {kaggle_environments-0.2.1.dist-info → kaggle_environments-1.20.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,819 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
|
3
|
+
|
|
4
|
+
import chex
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import numpy as np
|
|
8
|
+
from gymnax.environments import environment, spaces
|
|
9
|
+
from jax import lax
|
|
10
|
+
|
|
11
|
+
from luxai_s3.params import EnvParams, env_params_ranges
|
|
12
|
+
from luxai_s3.pygame_render import LuxAIPygameRenderer
|
|
13
|
+
from luxai_s3.spaces import MultiDiscrete
|
|
14
|
+
from luxai_s3.state import ASTEROID_TILE, ENERGY_NODE_FNS, NEBULA_TILE, EnvObs, EnvState, MapTile, UnitState, gen_state
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class LuxAIS3Env(environment.Environment):
|
|
18
|
+
def __init__(self, auto_reset=False, fixed_env_params: Optional[EnvParams] = None, **kwargs):
    """Create the Lux AI Season 3 environment.

    Args:
        auto_reset: stored as ``self.auto_reset``; its consumers are outside
            this part of the file.
        fixed_env_params: parameters treated as concrete/static values (the
            visible code reads ``num_teams``, ``map_width``, ``map_height`` and
            ``max_units`` from it to size arrays). Static shapes are necessary
            for jit/vmap capability with randomly sampled maps, which must be
            of consistent shape. When omitted, a fresh ``EnvParams()`` is
            created for this instance.
        **kwargs: forwarded to ``environment.Environment.__init__``.
    """
    super().__init__(**kwargs)
    self.renderer = LuxAIPygameRenderer()
    self.auto_reset = auto_reset
    # Fixed env params for concrete/static values. Necessary for jit/vmap
    # capability with randomly sampled maps, which must be of consistent shape.
    # Built per call instead of as a default argument: a default expression is
    # evaluated once at definition time and would be shared by every instance.
    self.fixed_env_params = fixed_env_params if fixed_env_params is not None else EnvParams()
|
|
24
|
+
|
|
25
|
+
@property
def default_params(self) -> EnvParams:
    """Return a fresh ``EnvParams()`` with every pytree leaf converted via ``jnp.array``.

    Returns:
        A new ``EnvParams`` instance on each access; its leaves are JAX arrays.
    """
    # ``jax.tree_map`` is deprecated (and removed in newer JAX releases);
    # ``jax.tree_util.tree_map`` is the stable equivalent across versions.
    return jax.tree_util.tree_map(jnp.array, EnvParams())
|
|
30
|
+
|
|
31
|
+
def compute_unit_counts_map(self, state: EnvState, params: EnvParams, exclude_negative_energy_units: bool = False):
    """Count how many units of each team occupy every tile.

    Args:
        state: environment state. Reads ``state.units.position`` (indexed
            ``[team, unit, axis]`` with axis 0 = x, 1 = y),
            ``state.units_mask`` (``[team, unit]``) and
            ``state.units.energy`` (``[team, unit, 0]``).
        params: unused here; kept for a uniform method signature.
        exclude_negative_energy_units: when True, units whose energy is
            negative are not counted.

    Returns:
        int16 array of shape ``(num_teams, map_width, map_height)``. Entry
        ``[t, x, y]`` is the number of counted units of team ``t`` on tile
        ``(x, y)``.
    """
    unit_counts_map = jnp.zeros(
        (self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height),
        dtype=jnp.int16,
    )

    for t in range(self.fixed_env_params.num_teams):
        counted = state.units_mask[t]
        if exclude_negative_energy_units:
            counted = counted & (state.units.energy[t, :, 0] >= 0)
        positions = state.units.position[t]
        # One scatter-add per team. Duplicate (x, y) indices accumulate, so
        # units sharing a tile are all counted; uncounted units add 0. This is
        # O(max_units) work with no (max_units, map_width, map_height)
        # intermediate array.
        unit_counts_map = unit_counts_map.at[t, positions[:, 0], positions[:, 1]].add(
            counted.astype(jnp.int16)
        )
    return unit_counts_map
|
|
60
|
+
|
|
61
|
+
def compute_energy_features(self, state: EnvState, params: EnvParams):
    """Recompute the per-tile energy field from the energy nodes and store it in ``state``.

    Every energy node contributes ``ENERGY_NODE_FNS[fn_i](distance, x, y, z)``
    to each tile. Here ``(fn_i, x, y, z)`` is the node's row in
    ``state.energy_node_fns`` and ``distance`` is the Euclidean distance from
    the tile to the node's position in ``state.energy_nodes``. Nodes whose
    ``state.energy_nodes_mask`` entry is False contribute 0 at this stage.

    If the mean over all per-node fields (inactive nodes included) is below
    0.25, every per-node entry (inactive nodes included) is shifted up by
    ``0.25 - mean``. The per-node fields are then summed, rounded, cast to
    int16 and clipped to
    ``[params.min_energy_per_tile, params.max_energy_per_tile]``.

    Returns:
        ``state`` with ``map_features.energy`` replaced by the new
        ``(map_width, map_height)`` int16 field.
    """
    width = self.fixed_env_params.map_width
    height = self.fixed_env_params.map_height

    # tile_coords[x, y] == [x, y]; shape (map_width, map_height, 2)
    grid_x, grid_y = jnp.meshgrid(jnp.arange(width), jnp.arange(height), indexing="ij")
    tile_coords = jnp.stack([grid_x, grid_y], axis=-1).astype(jnp.int16)

    def distance_field(node_pos):
        return jnp.linalg.norm(tile_coords - node_pos, axis=-1)

    # One distance map per energy node: (num_nodes, map_width, map_height)
    node_distances = jax.vmap(distance_field)(state.energy_nodes)

    def node_energy(fn_spec, dist, active):
        fn_index, x, y, z = fn_spec
        raw = lax.switch(fn_index.astype(jnp.int16), ENERGY_NODE_FNS, dist, x, y, z)
        return jnp.where(active, raw, jnp.zeros_like(dist))

    per_node = jax.vmap(node_energy)(state.energy_node_fns, node_distances, state.energy_nodes_mask)

    # Lift the whole stack so its overall mean is at least 0.25.
    mean_energy = per_node.mean()
    per_node = jnp.where(mean_energy < 0.25, per_node + (0.25 - mean_energy), per_node)

    total = jnp.round(per_node.sum(0)).astype(jnp.int16)
    total = jnp.clip(total, params.min_energy_per_tile, params.max_energy_per_tile)
    return state.replace(map_features=state.map_features.replace(energy=total))
|
|
87
|
+
|
|
88
|
+
def compute_sensor_masks(self, state, params: EnvParams):
    """Compute the vision power map and sensor mask for every team.

    Algorithm:

    For each team, start from an all-zero int16 vision power map. It is padded
    by the maximum possible sensor range on every side, so kernels placed near
    the map edge never need clipping.

    For every existing unit of the team, add a square vision kernel centred on
    the unit. Within the kernel, a tile at Chebyshev distance ``d`` from the
    unit receives ``unit_sensor_range + 1 - d`` when
    ``d <= unit_sensor_range`` and 0 otherwise. The unit's own tile receives
    an extra 10. Contributions from several units are summed; no clamping is
    applied.

    After cropping the padding, ``nebula_tile_vision_reduction`` is subtracted
    from every team's vision power on nebula tiles. A team can sense a tile
    wherever its resulting vision power is > 0; this forms the sensor mask.

    Both maps are indexed ``[team, x, y]`` with shape
    ``(num_teams, map_width, map_height)``. This matches how unit positions
    and ``state.map_features.tile_type`` are indexed elsewhere in this class.

    Returns:
        ``state`` with ``sensor_mask`` and ``vision_power_map`` replaced.
    """
    max_sensor_range = env_params_ranges["unit_sensor_range"][-1]
    padding = max_sensor_range
    map_width = self.fixed_env_params.map_width
    map_height = self.fixed_env_params.map_height
    kernel_size = max_sensor_range * 2 + 1

    # Axes are (team, x, y) so a unit position (x, y) indexes the map directly;
    # this keeps non-square maps consistent with the tile_type layout.
    vision_power_map = jnp.zeros(
        shape=(
            self.fixed_env_params.num_teams,
            map_width + 2 * padding,
            map_height + 2 * padding,
        ),
        dtype=jnp.int16,
    )

    # The per-unit kernel depends only on the sensor range, so build it once.
    # Nested squares are written from the outside in: square i covers Chebyshev
    # distance <= max_sensor_range - i from the centre, and later (inner)
    # squares overwrite earlier ones.
    vision_kernel = jnp.zeros((kernel_size, kernel_size), dtype=jnp.int16)
    for i in range(max_sensor_range + 1):
        val = jnp.where(
            i > max_sensor_range - params.unit_sensor_range - 1,
            i + 1 - (max_sensor_range - params.unit_sensor_range),
            0,
        ).astype(jnp.int16)
        vision_kernel = vision_kernel.at[i : kernel_size - i, i : kernel_size - i].set(val)
    # vision of position at center of the kernel has an extra 10
    vision_kernel = vision_kernel.at[max_sensor_range, max_sensor_range].add(10)

    def add_unit_vision(unit_pos, team_map):
        """Add the vision kernel centred on ``unit_pos`` to one team's padded map."""
        x, y = unit_pos
        start = (
            x - max_sensor_range + padding,
            y - max_sensor_range + padding,
        )
        window = jax.lax.dynamic_slice(team_map, start_indices=start, slice_sizes=(kernel_size, kernel_size))
        return jax.lax.dynamic_update_slice(team_map, update=vision_kernel + window, start_indices=start)

    def add_unit_vision_if_alive(unit_pos, unit_mask, team_map):
        """Apply ``add_unit_vision`` only for units that exist."""
        return jax.lax.cond(
            unit_mask,
            lambda: add_unit_vision(unit_pos, team_map),
            lambda: team_map,
        )

    def update_team_vision_power_map(team_units, unit_mask, team_map):
        """Accumulate the vision of all of one team's units into its map."""

        def body_fun(carry, i):
            return add_unit_vision_if_alive(team_units.position[i], unit_mask[i], carry), None

        team_map, _ = jax.lax.scan(body_fun, team_map, jnp.arange(self.fixed_env_params.max_units))
        return team_map

    vision_power_map = jax.vmap(update_team_vision_power_map)(state.units, state.units_mask, vision_power_map)
    # Crop the padding with explicit end indices (a ``-padding`` stop would
    # produce an empty map if the padding were ever 0).
    vision_power_map = vision_power_map[
        :,
        padding : padding + map_width,
        padding : padding + map_height,
    ]

    # handle nebula tiles
    vision_power_map = (
        vision_power_map
        - (state.map_features.tile_type == NEBULA_TILE).astype(jnp.int16) * params.nebula_tile_vision_reduction
    )

    sensor_mask = vision_power_map > 0
    state = state.replace(sensor_mask=sensor_mask)
    state = state.replace(vision_power_map=vision_power_map)
    return state
|
|
187
|
+
|
|
188
|
+
# @functools.partial(jax.jit, static_argnums=(0, 4))
|
|
189
|
+
def step_env(
|
|
190
|
+
self,
|
|
191
|
+
key: chex.PRNGKey,
|
|
192
|
+
state: EnvState,
|
|
193
|
+
action: Union[int, float, chex.Array],
|
|
194
|
+
params: EnvParams,
|
|
195
|
+
) -> Tuple[EnvObs, EnvState, jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
|
|
196
|
+
state = self.compute_energy_features(state, params)
|
|
197
|
+
|
|
198
|
+
action = jnp.stack([action["player_0"], action["player_1"]])
|
|
199
|
+
|
|
200
|
+
# remove all units if the match ended in the previous step indicated by a reset of match_steps to 0
|
|
201
|
+
state = state.replace(
|
|
202
|
+
units_mask=jnp.where(
|
|
203
|
+
state.match_steps == 0,
|
|
204
|
+
jnp.zeros_like(state.units_mask),
|
|
205
|
+
state.units_mask,
|
|
206
|
+
)
|
|
207
|
+
)
|
|
208
|
+
"""remove units that have less than 0 energy"""
|
|
209
|
+
# we remove units at the start of the timestep so that the visualizer can show the unit with negative energy and is marked for removal soon.
|
|
210
|
+
state = state.replace(units_mask=(state.units.energy[..., 0] >= 0) & state.units_mask)
|
|
211
|
+
|
|
212
|
+
"""spawn relic nodes based on schedule"""
|
|
213
|
+
relic_nodes_mask = (state.steps >= state.relic_spawn_schedule) & (state.relic_spawn_schedule != -1)
|
|
214
|
+
state = state.replace(relic_nodes_mask=relic_nodes_mask)
|
|
215
|
+
|
|
216
|
+
""" process unit movement """
|
|
217
|
+
# 0 is do nothing, 1 is move up, 2 is move right, 3 is move down, 4 is move left, 5 is sap
|
|
218
|
+
# Define movement directions
|
|
219
|
+
directions = jnp.array(
|
|
220
|
+
[
|
|
221
|
+
[0, 0], # Do nothing
|
|
222
|
+
[0, -1], # Move up
|
|
223
|
+
[1, 0], # Move right
|
|
224
|
+
[0, 1], # Move down
|
|
225
|
+
[-1, 0], # Move left
|
|
226
|
+
],
|
|
227
|
+
dtype=jnp.int16,
|
|
228
|
+
)
|
|
229
|
+
|
|
230
|
+
def move_unit(unit: UnitState, action, mask):
|
|
231
|
+
new_pos = unit.position + directions[action]
|
|
232
|
+
# Check if the new position is on a map feature of value 2
|
|
233
|
+
is_blocked = state.map_features.tile_type[new_pos[0], new_pos[1]] == ASTEROID_TILE
|
|
234
|
+
enough_energy = unit.energy >= params.unit_move_cost
|
|
235
|
+
# If blocked, keep the original position
|
|
236
|
+
# new_pos = jnp.where(is_blocked, unit.position, new_pos)
|
|
237
|
+
# Ensure the new position is within the map boundaries
|
|
238
|
+
new_pos = jnp.clip(
|
|
239
|
+
new_pos,
|
|
240
|
+
0,
|
|
241
|
+
jnp.array([params.map_width - 1, params.map_height - 1], dtype=jnp.int16),
|
|
242
|
+
)
|
|
243
|
+
unit_moved = mask & ~is_blocked & enough_energy & (action < 5) & (action > 0)
|
|
244
|
+
# Update the unit's position only if it's active. Note energy is used if unit tries to move off map. Energy is not used if unit tries to move into an asteroid tile.
|
|
245
|
+
return UnitState(
|
|
246
|
+
position=jnp.where(unit_moved, new_pos, unit.position),
|
|
247
|
+
energy=jnp.where(unit_moved, unit.energy - params.unit_move_cost, unit.energy),
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
# Move units for both teams
|
|
251
|
+
move_actions = action[..., 0]
|
|
252
|
+
state = state.replace(
|
|
253
|
+
units=jax.vmap(
|
|
254
|
+
lambda team_units, team_action, team_mask: jax.vmap(move_unit, in_axes=(0, 0, 0))(
|
|
255
|
+
team_units, team_action, team_mask
|
|
256
|
+
),
|
|
257
|
+
in_axes=(0, 0, 0),
|
|
258
|
+
)(state.units, move_actions, state.units_mask)
|
|
259
|
+
)
|
|
260
|
+
|
|
261
|
+
original_unit_energy = state.units.energy
|
|
262
|
+
"""original amount of energy of all units"""
|
|
263
|
+
|
|
264
|
+
"""apply sap actions"""
|
|
265
|
+
sap_action_mask = action[..., 0] == 5
|
|
266
|
+
sap_action_deltas = action[..., 1:]
|
|
267
|
+
|
|
268
|
+
def sap_unit(
|
|
269
|
+
current_energy: jnp.ndarray,
|
|
270
|
+
all_units: UnitState,
|
|
271
|
+
sap_action_mask,
|
|
272
|
+
sap_action_deltas,
|
|
273
|
+
units_mask,
|
|
274
|
+
):
|
|
275
|
+
# TODO (stao): clean up this code. It is probably slower than it needs be and could be vmapped perhaps.
|
|
276
|
+
for t in range(self.fixed_env_params.num_teams):
|
|
277
|
+
other_team_ids = jnp.array([t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t])
|
|
278
|
+
team_sap_action_deltas = sap_action_deltas[t] # (max_units, 2)
|
|
279
|
+
team_sap_action_mask = sap_action_mask[t]
|
|
280
|
+
other_team_unit_mask = units_mask[other_team_ids] # (other_teams, max_units)
|
|
281
|
+
team_sapped_positions = all_units.position[t] + team_sap_action_deltas # (max_units, 2)
|
|
282
|
+
# whether the unit is really sapping or not (needs to exist, have enough energy, and a valid sap action)
|
|
283
|
+
team_unit_sapped = (
|
|
284
|
+
units_mask[t]
|
|
285
|
+
& team_sap_action_mask
|
|
286
|
+
& (current_energy[t, :, 0] >= params.unit_sap_cost)
|
|
287
|
+
& (jnp.max(jnp.abs(team_sap_action_deltas), axis=-1) <= params.unit_sap_range)
|
|
288
|
+
) # (max_units)
|
|
289
|
+
team_unit_sapped = (
|
|
290
|
+
team_unit_sapped
|
|
291
|
+
& (team_sapped_positions >= 0).all(-1)
|
|
292
|
+
& (team_sapped_positions[:, 0] < self.fixed_env_params.map_width)
|
|
293
|
+
& (team_sapped_positions[:, 1] < self.fixed_env_params.map_height)
|
|
294
|
+
)
|
|
295
|
+
# the number of times other units are sapped
|
|
296
|
+
other_units_sapped_count = jnp.sum(
|
|
297
|
+
team_unit_sapped[None, None, :]
|
|
298
|
+
& jnp.all(
|
|
299
|
+
all_units.position[other_team_ids][:, :, None] == team_sapped_positions[None],
|
|
300
|
+
axis=-1,
|
|
301
|
+
),
|
|
302
|
+
axis=-1,
|
|
303
|
+
dtype=jnp.int16,
|
|
304
|
+
) # (len(other_team_ids), max_units)
|
|
305
|
+
# remove unit_sap_cost energy from opposition units that were in the middle of a sap action.
|
|
306
|
+
all_units = all_units.replace(
|
|
307
|
+
energy=all_units.energy.at[other_team_ids].set(
|
|
308
|
+
jnp.where(
|
|
309
|
+
other_team_unit_mask[:, :, None] & (other_units_sapped_count[:, :, None] > 0),
|
|
310
|
+
all_units.energy[other_team_ids]
|
|
311
|
+
- params.unit_sap_cost * other_units_sapped_count[:, :, None],
|
|
312
|
+
all_units.energy[other_team_ids],
|
|
313
|
+
)
|
|
314
|
+
)
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
# remove unit_sap_cost * unit_sap_dropoff_factor energy from opposition units that were on tiles adjacent to the center of a sap action.
|
|
318
|
+
adjacent_offsets = jnp.array(
|
|
319
|
+
[
|
|
320
|
+
[-1, -1],
|
|
321
|
+
[-1, 0],
|
|
322
|
+
[-1, 1],
|
|
323
|
+
[0, -1],
|
|
324
|
+
[0, 1],
|
|
325
|
+
[1, -1],
|
|
326
|
+
[1, 0],
|
|
327
|
+
[1, 1],
|
|
328
|
+
],
|
|
329
|
+
dtype=jnp.int16,
|
|
330
|
+
)
|
|
331
|
+
team_sapped_adjacent_positions = (
|
|
332
|
+
team_sapped_positions[:, None, :] + adjacent_offsets
|
|
333
|
+
) # (max_units, len(adjacent_offsets), 2)
|
|
334
|
+
other_units_adjacent_sapped_count = jnp.sum(
|
|
335
|
+
team_unit_sapped[None, None, :, None]
|
|
336
|
+
& jnp.all(
|
|
337
|
+
all_units.position[other_team_ids][:, :, None, None] == team_sapped_adjacent_positions[None],
|
|
338
|
+
axis=-1,
|
|
339
|
+
),
|
|
340
|
+
axis=(-1, -2),
|
|
341
|
+
dtype=jnp.int16,
|
|
342
|
+
) # (len(other_team_ids), max_units)
|
|
343
|
+
all_units = all_units.replace(
|
|
344
|
+
energy=all_units.energy.at[other_team_ids].set(
|
|
345
|
+
jnp.where(
|
|
346
|
+
other_team_unit_mask[:, :, None] & (other_units_adjacent_sapped_count[:, :, None] > 0),
|
|
347
|
+
all_units.energy[other_team_ids]
|
|
348
|
+
- jnp.array(
|
|
349
|
+
jnp.array(params.unit_sap_cost, dtype=jnp.float32)
|
|
350
|
+
* params.unit_sap_dropoff_factor
|
|
351
|
+
* other_units_adjacent_sapped_count[:, :, None].astype(jnp.float32),
|
|
352
|
+
dtype=jnp.int16,
|
|
353
|
+
),
|
|
354
|
+
all_units.energy[other_team_ids],
|
|
355
|
+
)
|
|
356
|
+
)
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
# remove unit_sap_cost energy from units that tried to sap some position within the unit's range
|
|
360
|
+
all_units = all_units.replace(
|
|
361
|
+
energy=all_units.energy.at[t].set(
|
|
362
|
+
jnp.where(
|
|
363
|
+
team_unit_sapped[:, None],
|
|
364
|
+
all_units.energy[t] - params.unit_sap_cost,
|
|
365
|
+
all_units.energy[t],
|
|
366
|
+
)
|
|
367
|
+
)
|
|
368
|
+
)
|
|
369
|
+
return all_units
|
|
370
|
+
|
|
371
|
+
state = state.replace(
|
|
372
|
+
units=sap_unit(
|
|
373
|
+
original_unit_energy,
|
|
374
|
+
state.units,
|
|
375
|
+
sap_action_mask,
|
|
376
|
+
sap_action_deltas,
|
|
377
|
+
state.units_mask,
|
|
378
|
+
)
|
|
379
|
+
)
|
|
380
|
+
|
|
381
|
+
"""resolve collisions and energy void fields"""
|
|
382
|
+
|
|
383
|
+
# compute energy void fields for all teams and the energy + unit counts
|
|
384
|
+
unit_aggregate_energy_void_map = jnp.zeros(
|
|
385
|
+
shape=(self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height),
|
|
386
|
+
dtype=jnp.int16,
|
|
387
|
+
)
|
|
388
|
+
unit_counts_map = self.compute_unit_counts_map(state, params)
|
|
389
|
+
unit_aggregate_energy_map = jnp.zeros(
|
|
390
|
+
shape=(self.fixed_env_params.num_teams, self.fixed_env_params.map_width, self.fixed_env_params.map_height),
|
|
391
|
+
dtype=jnp.int16,
|
|
392
|
+
)
|
|
393
|
+
for t in range(self.fixed_env_params.num_teams):
|
|
394
|
+
|
|
395
|
+
def scan_body(carry, x):
|
|
396
|
+
agg_energy_void_map, agg_energy_map = carry
|
|
397
|
+
unit_energy, unit_position, unit_mask = x
|
|
398
|
+
agg_energy_map = agg_energy_map.at[unit_position[0], unit_position[1]].add(
|
|
399
|
+
unit_energy[0] * unit_mask.astype(jnp.int16)
|
|
400
|
+
)
|
|
401
|
+
for deltas in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
|
|
402
|
+
new_pos = unit_position + jnp.array(deltas, dtype=jnp.int16)
|
|
403
|
+
in_map = (
|
|
404
|
+
(new_pos[0] >= 0)
|
|
405
|
+
& (new_pos[0] < self.fixed_env_params.map_width)
|
|
406
|
+
& (new_pos[1] >= 0)
|
|
407
|
+
& (new_pos[1] < self.fixed_env_params.map_height)
|
|
408
|
+
)
|
|
409
|
+
agg_energy_void_map = agg_energy_void_map.at[new_pos[0], new_pos[1]].add(
|
|
410
|
+
unit_energy[0] * unit_mask.astype(jnp.int16) * in_map.astype(jnp.int16)
|
|
411
|
+
)
|
|
412
|
+
return (agg_energy_void_map, agg_energy_map), None
|
|
413
|
+
|
|
414
|
+
agg_energy_void_map, agg_energy_map = jax.lax.scan(
|
|
415
|
+
scan_body,
|
|
416
|
+
(unit_aggregate_energy_void_map[t], unit_aggregate_energy_map[t]),
|
|
417
|
+
(original_unit_energy[t], state.units.position[t], state.units_mask[t]),
|
|
418
|
+
)[0]
|
|
419
|
+
unit_aggregate_energy_void_map = unit_aggregate_energy_void_map.at[t].add(agg_energy_void_map)
|
|
420
|
+
unit_aggregate_energy_map = unit_aggregate_energy_map.at[t].add(agg_energy_map)
|
|
421
|
+
|
|
422
|
+
# resolve collisions and keep only the surviving units
|
|
423
|
+
for t in range(self.fixed_env_params.num_teams):
|
|
424
|
+
other_team_ids = jnp.array([t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t])
|
|
425
|
+
# get the energy map for the current team
|
|
426
|
+
opposing_unit_counts_map = unit_counts_map[other_team_ids].sum(axis=0) # (map_width, map_height)
|
|
427
|
+
team_energy_map = unit_aggregate_energy_map[t]
|
|
428
|
+
opposing_aggregate_energy_map = unit_aggregate_energy_map[other_team_ids].max(
|
|
429
|
+
axis=0
|
|
430
|
+
) # (map_width, map_height)
|
|
431
|
+
# unit survives if there are opposing units on the tile, and if the opposing unit stack has less energy on the tile than the current unit
|
|
432
|
+
surviving_unit_mask = jax.vmap(
|
|
433
|
+
lambda unit_position: (opposing_unit_counts_map[unit_position[0], unit_position[1]] == 0)
|
|
434
|
+
| (
|
|
435
|
+
opposing_aggregate_energy_map[unit_position[0], unit_position[1]]
|
|
436
|
+
< team_energy_map[unit_position[0], unit_position[1]]
|
|
437
|
+
)
|
|
438
|
+
)(state.units.position[t])
|
|
439
|
+
state = state.replace(units_mask=state.units_mask.at[t].set(surviving_unit_mask & state.units_mask[t]))
|
|
440
|
+
# apply energy void fields
|
|
441
|
+
for t in range(self.fixed_env_params.num_teams):
|
|
442
|
+
other_team_ids = jnp.array([t2 for t2 in range(self.fixed_env_params.num_teams) if t2 != t])
|
|
443
|
+
oppposition_energy_void_map = unit_aggregate_energy_void_map[other_team_ids].sum(
|
|
444
|
+
axis=0
|
|
445
|
+
) # (map_width, map_height)
|
|
446
|
+
# unit on team t loses energy to void field equal to params.unit_energy_void_factor * void_energy / num units stacked with unit on the same tile
|
|
447
|
+
team_unit_energy = state.units.energy[t] - jnp.floor(
|
|
448
|
+
jax.vmap(
|
|
449
|
+
lambda unit_position: params.unit_energy_void_factor
|
|
450
|
+
* oppposition_energy_void_map[unit_position[0], unit_position[1]].astype(jnp.float32)
|
|
451
|
+
/ unit_counts_map[t][unit_position[0], unit_position[1]].astype(jnp.float32)
|
|
452
|
+
)(state.units.position[t])[..., None]
|
|
453
|
+
).astype(jnp.int16)
|
|
454
|
+
state = state.replace(units=state.units.replace(energy=state.units.energy.at[t].set(team_unit_energy)))
|
|
455
|
+
|
|
456
|
+
"""apply energy field to the units"""
|
|
457
|
+
|
|
458
|
+
# Update unit energy based on the energy field and nebula tileof their current position
|
|
459
|
+
def update_unit_energy(unit: UnitState, mask):
|
|
460
|
+
x, y = unit.position
|
|
461
|
+
energy_gain = (
|
|
462
|
+
state.map_features.energy[x, y]
|
|
463
|
+
- (state.map_features.tile_type[x, y] == NEBULA_TILE).astype(jnp.int16)
|
|
464
|
+
* params.nebula_tile_energy_reduction
|
|
465
|
+
)
|
|
466
|
+
# if energy gain is less than 0
|
|
467
|
+
# new_energy = jnp.where((unit.energy < 0) & (energy_gain < 0))
|
|
468
|
+
new_energy = jnp.clip(
|
|
469
|
+
unit.energy + energy_gain,
|
|
470
|
+
params.min_unit_energy,
|
|
471
|
+
params.max_unit_energy,
|
|
472
|
+
)
|
|
473
|
+
# if unit already had negative energy due to opposition units and after energy field/nebula tile it is still below 0, then it will be removed next step
|
|
474
|
+
# and we keep its energy value at whatever it is
|
|
475
|
+
new_energy = jnp.where(
|
|
476
|
+
(unit.energy < 0) & (unit.energy + energy_gain < 0),
|
|
477
|
+
unit.energy,
|
|
478
|
+
new_energy,
|
|
479
|
+
)
|
|
480
|
+
return UnitState(position=unit.position, energy=jnp.where(mask, new_energy, unit.energy))
|
|
481
|
+
|
|
482
|
+
# Apply the energy update for all units of both teams
|
|
483
|
+
state = state.replace(
|
|
484
|
+
units=jax.vmap(lambda team_units, team_mask: jax.vmap(update_unit_energy)(team_units, team_mask))(
|
|
485
|
+
state.units, state.units_mask
|
|
486
|
+
)
|
|
487
|
+
)
|
|
488
|
+
|
|
489
|
+
"""spawn new units in"""
|
|
490
|
+
spawn_units_in = state.match_steps % params.spawn_rate == 0
|
|
491
|
+
|
|
492
|
+
# TODO (stao): only logic in code that probably doesn't not handle more than 2 teams, everything else is vmapped across teams
|
|
493
|
+
def spawn_team_units(state: EnvState):
|
|
494
|
+
team_0_unit_count = state.units_mask[0].sum()
|
|
495
|
+
team_1_unit_count = state.units_mask[1].sum()
|
|
496
|
+
team_0_new_unit_id = state.units_mask[0].argmin()
|
|
497
|
+
team_1_new_unit_id = state.units_mask[1].argmin()
|
|
498
|
+
state = state.replace(
|
|
499
|
+
units=state.units.replace(
|
|
500
|
+
position=jnp.where(
|
|
501
|
+
team_0_unit_count < params.max_units,
|
|
502
|
+
state.units.position.at[0, team_0_new_unit_id, :].set(jnp.array([0, 0], dtype=jnp.int16)),
|
|
503
|
+
state.units.position,
|
|
504
|
+
)
|
|
505
|
+
)
|
|
506
|
+
)
|
|
507
|
+
state = state.replace(
|
|
508
|
+
units=state.units.replace(
|
|
509
|
+
energy=jnp.where(
|
|
510
|
+
team_0_unit_count < params.max_units,
|
|
511
|
+
state.units.energy.at[0, team_0_new_unit_id, :].set(
|
|
512
|
+
jnp.array([params.init_unit_energy], dtype=jnp.int16)
|
|
513
|
+
),
|
|
514
|
+
state.units.energy,
|
|
515
|
+
)
|
|
516
|
+
)
|
|
517
|
+
)
|
|
518
|
+
state = state.replace(
|
|
519
|
+
units=state.units.replace(
|
|
520
|
+
position=jnp.where(
|
|
521
|
+
team_1_unit_count < params.max_units,
|
|
522
|
+
state.units.position.at[1, team_1_new_unit_id, :].set(
|
|
523
|
+
jnp.array(
|
|
524
|
+
[params.map_width - 1, params.map_height - 1],
|
|
525
|
+
dtype=jnp.int16,
|
|
526
|
+
)
|
|
527
|
+
),
|
|
528
|
+
state.units.position,
|
|
529
|
+
)
|
|
530
|
+
)
|
|
531
|
+
)
|
|
532
|
+
state = state.replace(
|
|
533
|
+
units=state.units.replace(
|
|
534
|
+
energy=jnp.where(
|
|
535
|
+
team_1_unit_count < params.max_units,
|
|
536
|
+
state.units.energy.at[1, team_1_new_unit_id, :].set(
|
|
537
|
+
jnp.array([params.init_unit_energy], dtype=jnp.int16)
|
|
538
|
+
),
|
|
539
|
+
state.units.energy,
|
|
540
|
+
)
|
|
541
|
+
)
|
|
542
|
+
)
|
|
543
|
+
state = state.replace(
|
|
544
|
+
units_mask=state.units_mask.at[0, team_0_new_unit_id].set(
|
|
545
|
+
jnp.where(
|
|
546
|
+
team_0_unit_count < params.max_units,
|
|
547
|
+
True,
|
|
548
|
+
state.units_mask[0, team_0_new_unit_id],
|
|
549
|
+
)
|
|
550
|
+
)
|
|
551
|
+
)
|
|
552
|
+
state = state.replace(
|
|
553
|
+
units_mask=state.units_mask.at[1, team_1_new_unit_id].set(
|
|
554
|
+
jnp.where(
|
|
555
|
+
team_1_unit_count < params.max_units,
|
|
556
|
+
True,
|
|
557
|
+
state.units_mask[1, team_1_new_unit_id],
|
|
558
|
+
)
|
|
559
|
+
)
|
|
560
|
+
)
|
|
561
|
+
# state = jnp.where(team_0_unit_count < params.max_units, spawn_unit(state, 0, team_0_new_unit_id, [0, 0], params), state)
|
|
562
|
+
# state = jnp.where(team_1_unit_count < params.max_units, spawn_unit(state, 1, team_1_new_unit_id, [params.map_width - 1, params.map_height - 1], params), state)
|
|
563
|
+
return state
|
|
564
|
+
|
|
565
|
+
state = jax.lax.cond(spawn_units_in, lambda: spawn_team_units(state), lambda: state)
|
|
566
|
+
|
|
567
|
+
state = self.compute_sensor_masks(state, params)
|
|
568
|
+
|
|
569
|
+
# Shift objects around in space
|
|
570
|
+
# Move the nebula tiles in state.map_features.tile_types up by 1 and to the right by 1
|
|
571
|
+
# this is also symmetric nebula tile movement
|
|
572
|
+
new_tile_types_map = jnp.roll(
|
|
573
|
+
state.map_features.tile_type,
|
|
574
|
+
shift=(
|
|
575
|
+
1 * jnp.sign(params.nebula_tile_drift_speed),
|
|
576
|
+
-1 * jnp.sign(params.nebula_tile_drift_speed),
|
|
577
|
+
),
|
|
578
|
+
axis=(0, 1),
|
|
579
|
+
)
|
|
580
|
+
new_tile_types_map = jnp.where(
|
|
581
|
+
(state.steps - 1) * abs(params.nebula_tile_drift_speed) % 1
|
|
582
|
+
> state.steps * abs(params.nebula_tile_drift_speed) % 1,
|
|
583
|
+
new_tile_types_map,
|
|
584
|
+
state.map_features.tile_type,
|
|
585
|
+
)
|
|
586
|
+
|
|
587
|
+
energy_node_deltas = jnp.round(
|
|
588
|
+
jax.random.uniform(
|
|
589
|
+
key=key,
|
|
590
|
+
shape=(self.fixed_env_params.max_energy_nodes // 2, 2),
|
|
591
|
+
minval=-params.energy_node_drift_magnitude,
|
|
592
|
+
maxval=params.energy_node_drift_magnitude,
|
|
593
|
+
)
|
|
594
|
+
).astype(jnp.int16)
|
|
595
|
+
energy_node_deltas_symmetric = jnp.stack([-energy_node_deltas[:, 1], -energy_node_deltas[:, 0]], axis=-1)
|
|
596
|
+
energy_node_deltas = jnp.concatenate((energy_node_deltas, energy_node_deltas_symmetric))
|
|
597
|
+
new_energy_nodes = jnp.clip(
|
|
598
|
+
state.energy_nodes + energy_node_deltas,
|
|
599
|
+
jnp.array([0, 0], dtype=jnp.int16),
|
|
600
|
+
jnp.array([self.fixed_env_params.map_width - 1, self.fixed_env_params.map_height - 1], dtype=jnp.int16),
|
|
601
|
+
)
|
|
602
|
+
new_energy_nodes = jnp.where(
|
|
603
|
+
(state.steps - 1) * abs(params.energy_node_drift_speed) % 1
|
|
604
|
+
> state.steps * abs(params.energy_node_drift_speed) % 1,
|
|
605
|
+
new_energy_nodes,
|
|
606
|
+
state.energy_nodes,
|
|
607
|
+
)
|
|
608
|
+
state = state.replace(
|
|
609
|
+
map_features=state.map_features.replace(tile_type=new_tile_types_map),
|
|
610
|
+
energy_nodes=new_energy_nodes,
|
|
611
|
+
)
|
|
612
|
+
|
|
613
|
+
# Compute relic scores
|
|
614
|
+
def team_relic_score(unit_counts_map):
|
|
615
|
+
# not all relic nodes are spawned in yet, but relic nodes map ids are precomputed for all to be spawned relic nodes
|
|
616
|
+
# for efficiency. So we check if the relic node (by id) is spawned in yet. relic nodes mask is always increasing so we can do a simple trick below
|
|
617
|
+
scores = (
|
|
618
|
+
(unit_counts_map > 0)
|
|
619
|
+
& (state.relic_nodes_map_weights <= state.relic_nodes_mask.sum() // 2)
|
|
620
|
+
& (state.relic_nodes_map_weights > 0)
|
|
621
|
+
)
|
|
622
|
+
return jnp.sum(scores, dtype=jnp.int32)
|
|
623
|
+
|
|
624
|
+
# note we need to recompue unit counts since units can get removed due to collisions
|
|
625
|
+
team_scores = jax.vmap(team_relic_score)(
|
|
626
|
+
self.compute_unit_counts_map(state, params, exclude_negative_energy_units=True)
|
|
627
|
+
)
|
|
628
|
+
# Update team points
|
|
629
|
+
state = state.replace(team_points=state.team_points + team_scores)
|
|
630
|
+
|
|
631
|
+
# if match ended, then remove all units, update team wins, reset team points
|
|
632
|
+
winner_by_points = jnp.where(
|
|
633
|
+
state.team_points.max() > state.team_points.min(),
|
|
634
|
+
jnp.argmax(state.team_points),
|
|
635
|
+
-1,
|
|
636
|
+
)
|
|
637
|
+
winner_by_energy = jnp.sum(state.units.energy[..., 0] * state.units_mask.astype(jnp.int16), axis=1)
|
|
638
|
+
winner_by_energy = jnp.where(
|
|
639
|
+
winner_by_energy.max() > winner_by_energy.min(),
|
|
640
|
+
jnp.argmax(winner_by_energy),
|
|
641
|
+
-1,
|
|
642
|
+
)
|
|
643
|
+
|
|
644
|
+
winner = jnp.where(
|
|
645
|
+
winner_by_points != -1,
|
|
646
|
+
winner_by_points,
|
|
647
|
+
jnp.where(
|
|
648
|
+
winner_by_energy != -1,
|
|
649
|
+
winner_by_energy,
|
|
650
|
+
jax.random.randint(key, shape=(), minval=0, maxval=params.num_teams),
|
|
651
|
+
),
|
|
652
|
+
)
|
|
653
|
+
match_ended = state.match_steps >= params.max_steps_in_match
|
|
654
|
+
|
|
655
|
+
state = state.replace(
|
|
656
|
+
match_steps=jnp.where(match_ended, -1, state.match_steps),
|
|
657
|
+
team_points=jnp.where(match_ended, jnp.zeros_like(state.team_points), state.team_points),
|
|
658
|
+
team_wins=jnp.where(match_ended, state.team_wins.at[winner].add(1), state.team_wins),
|
|
659
|
+
)
|
|
660
|
+
# Update state's step count
|
|
661
|
+
state = state.replace(steps=state.steps + 1, match_steps=state.match_steps + 1)
|
|
662
|
+
truncated = state.steps >= (params.max_steps_in_match + 1) * params.match_count_per_episode
|
|
663
|
+
reward = dict()
|
|
664
|
+
for k in range(self.fixed_env_params.num_teams):
|
|
665
|
+
reward[f"player_{k}"] = state.team_wins[k]
|
|
666
|
+
terminated = self.is_terminal(state, params)
|
|
667
|
+
return (
|
|
668
|
+
lax.stop_gradient(self.get_obs(state, params, key=key)),
|
|
669
|
+
lax.stop_gradient(state),
|
|
670
|
+
reward,
|
|
671
|
+
terminated,
|
|
672
|
+
truncated,
|
|
673
|
+
{"discount": self.discount(state, params)},
|
|
674
|
+
)
|
|
675
|
+
|
|
676
|
+
def reset_env(self, key: chex.PRNGKey, params: EnvParams) -> Tuple[Dict[str, EnvObs], EnvState]:
    """Create a fresh initial state and the matching per-player observations.

    A new state is generated by ``gen_state`` from ``key``, ``params`` and the
    fixed environment parameters. Its energy features and sensor masks are then
    computed before the state is observed.

    Args:
        key: PRNG key forwarded to ``gen_state`` and ``get_obs``.
        params: environment parameters.

    Returns:
        ``(obs, state)``: ``obs`` is the dict produced by ``get_obs``, keyed
        ``"player_{t}"``, and ``state`` is the initial ``EnvState``.
    """
    fixed = self.fixed_env_params
    state = gen_state(
        key=key,
        env_params=params,
        max_units=fixed.max_units,
        num_teams=fixed.num_teams,
        map_type=fixed.map_type,
        map_width=fixed.map_width,
        map_height=fixed.map_height,
        max_energy_nodes=fixed.max_energy_nodes,
        max_relic_nodes=fixed.max_relic_nodes,
        relic_config_size=fixed.relic_config_size,
    )
    state = self.compute_energy_features(state, params)
    # Sensor masks must be populated before get_obs, which filters by them.
    state = self.compute_sensor_masks(state, params)
    return self.get_obs(state, params=params, key=key), state
|
|
694
|
+
|
|
695
|
+
@functools.partial(jax.jit, static_argnums=(0,))
def step(
    self,
    key: chex.PRNGKey,
    state: EnvState,
    action: Union[int, float, chex.Array],
    params: Optional[EnvParams] = None,
) -> Tuple[
    Dict[str, EnvObs],
    EnvState,
    Dict[str, jnp.ndarray],
    Dict[str, jnp.ndarray],
    Dict[str, jnp.ndarray],
    Dict[Any, Any],
]:
    """Perform one environment transition, auto-resetting if enabled.

    Args:
        key: PRNG key, split into a transition key and a reset key.
        state: current environment state.
        action: actions forwarded unchanged to ``step_env``.
        params: environment parameters; ``self.default_params`` is used when ``None``.

    Returns:
        A 6-tuple ``(obs, state, reward, terminated, truncated, info)``.
        ``reward`` is the dict returned by ``step_env``. ``terminated`` and
        ``truncated`` are dicts keyed ``"player_{k}"`` that share the same value
        for every player. ``info`` is ``step_env``'s info dict extended with
        ``"final_state"``/``"final_observation"`` (the pre-reset transition
        result) and one empty dict per player.
    """
    # Use default env parameters if no others specified
    if params is None:
        params = self.default_params
    key, key_reset = jax.random.split(key)
    obs_st, state_st, reward, terminated, truncated, info = self.step_env(key, state, action, params)
    # Expose the raw transition result even when auto-reset replaces it below.
    info["final_state"] = state_st
    info["final_observation"] = obs_st
    done = terminated | truncated

    if self.auto_reset:
        # reset_env is always evaluated here; lax.cond only selects which result is returned.
        obs_re, state_re = self.reset_env(key_reset, params)
        obs, state = jax.lax.cond(done, lambda: (obs_re, state_re), lambda: (obs_st, state_st))
    else:
        obs, state = obs_st, state_st

    # All agents terminate/truncate at the same time.
    player_keys = [f"player_{k}" for k in range(self.fixed_env_params.num_teams)]
    terminated_dict = {player: terminated for player in player_keys}
    truncated_dict = {player: truncated for player in player_keys}
    for player in player_keys:
        info[player] = {}
    return obs, state, reward, terminated_dict, truncated_dict, info
|
|
730
|
+
|
|
731
|
+
@functools.partial(jax.jit, static_argnums=(0,))
def reset(self, key: chex.PRNGKey, params: Optional[EnvParams] = None) -> Tuple[Dict[str, EnvObs], EnvState]:
    """Reset the environment (jit-compiled wrapper around ``reset_env``).

    Args:
        key: PRNG key forwarded to ``reset_env``.
        params: environment parameters; ``self.default_params`` is used when ``None``.

    Returns:
        ``(obs, state)`` from ``reset_env``, where ``obs`` is a dict of
        per-player observations.
    """
    effective_params = self.default_params if params is None else params
    initial_obs, initial_state = self.reset_env(key, effective_params)
    return initial_obs, initial_state
|
|
740
|
+
|
|
741
|
+
# @functools.partial(jax.jit, static_argnums=(0, 2))
def get_obs(self, state: EnvState, params=None, key=None) -> Dict[str, EnvObs]:
    """Return per-player observations of ``state`` under partial observability.

    For each team ``t`` an ``EnvObs`` is built in which:
      * opposing units are kept only if their mask is set and they sit on a
        tile where ``state.sensor_mask[t]`` is True; team ``t``'s own unit
        mask is passed through unfiltered;
      * relic nodes are kept only if set in ``state.relic_nodes_mask`` and on
        a sensed tile;
      * hidden unit positions/energies, hidden relic positions and unsensed
        map tiles (energy and tile type) are replaced with -1.
    Team points, team wins and step counters are copied through unchanged.

    ``params`` and ``key`` are accepted for interface compatibility and are
    not used.

    Returns:
        A dict mapping ``"player_{t}"`` to that team's ``EnvObs``.
    """

    def _visible(position, alive, sensor):
        # Visible iff the entity exists and its tile is inside the sensor mask.
        return alive & sensor[position[0], position[1]]

    # Per-entity rule vmapped over a batch of entities (sensor mask shared);
    # used for one team's units and for the relic nodes.
    visible_each = jax.vmap(_visible, in_axes=(0, 0, None))
    # Same rule vmapped again over several teams at once.
    visible_teams = jax.vmap(visible_each, in_axes=(0, 0, None))

    num_teams = self.fixed_env_params.num_teams
    observations = {}
    for team in range(num_teams):
        opponents = jnp.array([other for other in range(num_teams) if other != team])
        sensor = state.sensor_mask[team]

        # Only the opponents' masks are filtered by this team's sensors.
        units_mask = state.units_mask.at[opponents].set(
            visible_teams(state.units.position[opponents], state.units_mask[opponents], sensor)
        )
        relics_mask = visible_each(state.relic_nodes, state.relic_nodes_mask, sensor)
        unit_visible = units_mask[..., None]

        observations[f"player_{team}"] = EnvObs(
            units=UnitState(
                position=jnp.where(unit_visible, state.units.position, -1),
                # Trailing axis of the stored energy is dropped in the observation.
                energy=jnp.where(unit_visible, state.units.energy, -1)[..., 0],
            ),
            units_mask=units_mask,
            sensor_mask=sensor,
            map_features=MapTile(
                energy=jnp.where(sensor, state.map_features.energy, -1),
                tile_type=jnp.where(sensor, state.map_features.tile_type, -1),
            ),
            team_points=state.team_points,
            team_wins=state.team_wins,
            steps=state.steps,
            match_steps=state.match_steps,
            relic_nodes=jnp.where(relics_mask[..., None], state.relic_nodes, -1),
            relic_nodes_mask=relics_mask,
        )
    return observations
|
|
790
|
+
|
|
791
|
+
@functools.partial(jax.jit, static_argnums=(0,))
def is_terminal(self, state: EnvState, params: EnvParams) -> jnp.ndarray:
    """Report whether ``state`` is terminal.

    Always yields a JAX boolean ``False``: an episode never terminates on its
    own and only ends through truncation (the step-count limit).
    """
    return jnp.array(False)
|
|
796
|
+
|
|
797
|
+
@property
def name(self) -> str:
    """Human-readable title of this environment."""
    env_title = "Lux AI Season 3"
    return env_title
|
|
801
|
+
|
|
802
|
+
def render(self, state: EnvState, params: EnvParams):
    """Draw ``state`` with this environment's renderer (returns None)."""
    renderer = self.renderer
    renderer.render(state, params)
|
|
804
|
+
|
|
805
|
+
def action_space(self, params: Optional[EnvParams] = None):
    """Action space of the environment: one ``MultiDiscrete`` per player.

    Each player's bounds are ``(max_units, 3)`` arrays. Column 0 has low 0 and
    high 6. Columns 1-2 have low ``-R`` and high ``R``, where ``R`` is the last
    entry of ``env_params_ranges["unit_sap_range"]``. Whether ``high`` is
    inclusive depends on ``MultiDiscrete``'s semantics.

    Players are keyed ``"player_{k}"`` for ``k`` in
    ``range(fixed_env_params.num_teams)``, matching the per-player dicts built by
    ``step``/``get_obs``. For the usual two teams this is exactly
    ``player_0``/``player_1``.
    """
    max_units = self.fixed_env_params.max_units
    # NOTE(review): assumes the range list is sorted so [-1] is the largest sap range.
    sap_range = env_params_ranges["unit_sap_range"][-1]
    low = np.zeros((max_units, 3))
    low[:, 1:] = -sap_range
    high = np.ones((max_units, 3)) * 6
    high[:, 1:] = sap_range
    return spaces.Dict(
        {f"player_{k}": MultiDiscrete(low, high) for k in range(self.fixed_env_params.num_teams)}
    )
|
|
812
|
+
|
|
813
|
+
def observation_space(self, params: EnvParams):
    """Observation space of the environment.

    NOTE(review): placeholder. A 10-category ``Discrete`` space does not describe
    the per-player dict of observations that ``get_obs`` returns; confirm before
    relying on it.
    """
    placeholder_categories = 10
    return spaces.Discrete(placeholder_categories)
|
|
816
|
+
|
|
817
|
+
def state_space(self, params: EnvParams):
    """State space of the environment.

    NOTE(review): placeholder. A 10-category ``Discrete`` space does not describe
    the structured ``EnvState``; confirm before relying on it.
    """
    placeholder_categories = 10
    return spaces.Discrete(placeholder_categories)
|