parabellum 0.0.0__py3-none-any.whl → 0.0.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum/__init__.py +3 -3
- parabellum/env.py +51 -277
- parabellum/geo.py +213 -0
- parabellum/types.py +138 -0
- parabellum/utils.py +59 -0
- parabellum-0.0.73.dist-info/METADATA +46 -0
- parabellum-0.0.73.dist-info/RECORD +8 -0
- {parabellum-0.0.0.dist-info → parabellum-0.0.73.dist-info}/WHEEL +1 -1
- parabellum/.ipynb_checkpoints/__init__-checkpoint.py +0 -4
- parabellum/.ipynb_checkpoints/env-checkpoint.py +0 -296
- parabellum/.ipynb_checkpoints/vis-checkpoint.py +0 -230
- parabellum/vis.py +0 -230
- parabellum-0.0.0.dist-info/METADATA +0 -55
- parabellum-0.0.0.dist-info/RECORD +0 -9
parabellum/types.py
ADDED
@@ -0,0 +1,138 @@
# types.py
# parabellum types
# by: Noah Syrkis

# imports
from chex import dataclass
from jaxtyping import Array, Float32, Int
import jax.numpy as jnp
from parabellum.geo import geography_fn
from dataclasses import field


@dataclass
class Kind:
    hp: int
    dam: int
    speed: int
    reach: int
    sight: int
    blast: int
    r: float


@dataclass
class Rules:
    troop = Kind(hp=120, dam=15, speed=2, reach=4, sight=4, blast=1, r=1)
    armor = Kind(hp=150, dam=12, speed=1, reach=8, sight=16, blast=3, r=2)
    plane = Kind(hp=80, dam=20, speed=4, reach=16, sight=32, blast=2, r=2)
    civil = Kind(hp=100, dam=0, speed=3, reach=5, sight=10, blast=1, r=2)
    medic = Kind(hp=100, dam=-10, speed=3, reach=5, sight=10, blast=1, r=2)

    def __post_init__(self):
        self.hp = jnp.array((self.troop.hp, self.armor.hp, self.plane.hp, self.civil.hp, self.medic.hp))
        self.dam = jnp.array((self.troop.dam, self.armor.dam, self.plane.dam, self.civil.dam, self.medic.dam))
        self.r = jnp.array((self.troop.r, self.armor.r, self.plane.r, self.civil.r, self.medic.r))
        self.speed = jnp.array(
            (self.troop.speed, self.armor.speed, self.plane.speed, self.civil.speed, self.medic.speed)
        )
        self.reach = jnp.array(
            (self.troop.reach, self.armor.reach, self.plane.reach, self.civil.reach, self.medic.reach)
        )
        self.sight = jnp.array(
            (self.troop.sight, self.armor.sight, self.plane.sight, self.civil.sight, self.medic.sight)
        )
        self.blast = jnp.array(
            (self.troop.blast, self.armor.blast, self.plane.blast, self.civil.blast, self.medic.blast)
        )


@dataclass
class Team:
    troop: int = 1
    armor: int = 0
    plane: int = 0
    civil: int = 0
    medic: int = 0

    def __post_init__(self):
        self.length: int = self.troop + self.armor + self.plane + self.civil + self.medic
        self.types: Array = jnp.repeat(
            jnp.arange(5), jnp.array((self.troop, self.armor, self.plane, self.civil, self.medic))
        )


# dataclasses
@dataclass
class State:
    pos: Array
    hp: Array
    # target: Array


@dataclass
class Obs:
    # idxs: Array
    hp: Array
    pos: Array
    type: Array
    team: Array
    dist: Array
    mask: Array
    reach: Array
    sight: Array
    speed: Array

    @property
    def ally(self):
        return (self.team == self.team[0]) & self.mask

    @property
    def enemy(self):
        return (self.team != self.team[0]) & self.mask


@dataclass
class Action:
    pos: Array
    kind: Int[Array, "..."]  # 0 = invalid, 1 = move, 2 = cast

    @property
    def invalid(self):
        return self.kind == 0

    @property
    def move(self):
        return self.kind == 1

    @property
    def cast(self):  # cast bomb, bullet or medicin
        return self.kind == 2


@dataclass
class Config:  # Remove frozen=True for now
    steps: int = 123
    place: str = "Palazzo della Civiltà Italiana, Rome, Italy"
    force: float = 0.5
    sims: int = 2
    size: int = 64
    knn: int = 2
    blu: Team = field(default_factory=lambda: Team())
    red: Team = field(default_factory=lambda: Team())
    rules: Rules = field(default_factory=lambda: Rules())

    def __post_init__(self):
        # Pre-compute everything once
        self.types: Array = jnp.concat((self.blu.types, self.red.types))
        self.teams: Array = jnp.repeat(jnp.arange(2), jnp.array((self.blu.length, self.red.length)))
        self.map: Array = geography_fn(self.place, self.size)  # Computed once here
        self.hp: Array = self.rules.hp
        self.dam: Array = self.rules.dam
        self.r: Array = self.rules.r
        self.speed: Array = self.rules.speed
        self.reach: Array = self.rules.reach
        self.sight: Array = self.rules.sight
        self.blast: Array = self.rules.blast
        self.length: int = self.blu.length + self.red.length
        self.root: Array = jnp.int32(jnp.sqrt(self.length))
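For orientation, a minimal usage sketch of the dataclasses above (not part of the package): it assumes parabellum 0.0.73 is installed and that geography_fn can resolve the default place, which requires network access, so treat it as an illustration of the derived attributes rather than a guaranteed-runnable test.

# Hypothetical usage of parabellum.types (illustration only).
from parabellum.types import Config, Team

# Team.__post_init__ derives .length and a per-unit kind index in .types.
cfg = Config(blu=Team(troop=4, armor=1), red=Team(troop=6))

assert cfg.length == 11        # blu.length + red.length
print(cfg.teams)               # 0 for blue units, 1 for red units
print(cfg.types)               # kind index per unit (0=troop, 1=armor, ...)
print(cfg.hp[cfg.types])       # per-unit starting hp via the stacked Rules arrays
print(cfg.map.shape)           # rasterised terrain from geography_fn(place, size)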
parabellum/utils.py
ADDED
@@ -0,0 +1,59 @@
# %% utils.py
# parabellum ut


# Imports
import esch
import jax.numpy as jnp
import numpy as np
from einops import rearrange, repeat
from jax import tree
from PIL import Image
from parabellum.types import Config

# Twilight colors (used in neurocope)
red = "#EA344A"
blue = "#2B60F6"


# %% Plotting
def gif_fn(cfg: Config, seq, scale=4):  # animate positions TODO: remove dead units
    pos = seq.pos.astype(int)
    cord = jnp.concat((jnp.arange(pos.shape[0]).repeat(pos.shape[1])[..., None], pos.reshape(-1, 2)), axis=1).T
    idxs = cord[:, seq.hp.flatten().astype(bool) > 0]
    imgs = 1 - np.array(repeat(cfg.map, "... -> a ...", a=len(pos)).at[*idxs].set(1))
    imgs = [Image.fromarray(img).resize(np.array(img.shape[:2]) * scale, Image.NEAREST) for img in imgs * 255]  # type: ignore
    imgs[0].save("/Users/nobr/desk/s3/btc2sim/sims.gif", save_all=True, append_images=imgs[1:], duration=10, loop=0)


def svg_fn(cfg: Config, seq, action, fname, targets=None, fps=2, debug=False):
    # set up and background
    e = esch.Drawing(h=cfg.size, w=cfg.size, row=1, col=seq.pos.shape[0], debug=debug, pad=10)
    esch.grid_fn(repeat(np.array(cfg.map, dtype=float), f"... -> {seq.pos.shape[0]} ...") * 0.5, e, shape="square")

    # loop through teams
    for i in jnp.unique(cfg.teams):  # c#fg.teams.unique():
        col = "red" if i == 1 else "blue"

        # loop through types
        for j in jnp.unique(cfg.types):
            mask = (cfg.teams == i) & (cfg.types == j)
            size, blast = float(cfg.rules.r[j]), float(cfg.rules.blast[j])
            subset = np.array(rearrange(seq.pos, "a b c d -> a c d b"), dtype=float)[:, mask]
            # print(tree.map(jnp.shape, action), mask.shape)
            sub_action = tree.map(lambda x: x[:, :, mask], action)
            # print(tree.map(jnp.shape, sub_action))
            esch.sims_fn(e, subset, action=sub_action, fps=fps, col=col, stroke=col, size=size, blast=blast)

            if debug:
                sight, reach = float(cfg.rules.sight[j]), float(cfg.rules.reach[j])
                esch.sims_fn(e, subset, action=None, col="none", fps=fps, size=reach, stroke="grey")
                esch.sims_fn(e, subset, action=None, col="none", fps=fps, size=sight, stroke="yellow")

    if targets is not None:
        pos = np.array(repeat(targets, f"... -> {seq.pos.shape[0]} ..."))
        arr = np.ones(pos.shape[:-1])
        esch.mesh_fn(e, pos, arr, shape="square", col="purple")

    # save
    e.dwg.saveas(fname)
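The frame-stamping trick in gif_fn above (prepend a frame index to every (y, x) unit position, then write all frames with a single scatter) can be shown in isolation. The sketch below uses dummy data and plain jax.numpy only; the names T, N, S and the toy positions are made up for illustration and are not part of the package.

# Standalone sketch of gif_fn's frame-stamping trick (dummy data).
import jax.numpy as jnp

T, N, S = 3, 4, 8                                    # frames, units, map size
grid = jnp.zeros((S, S))                             # stand-in for cfg.map
pos = jnp.ones((T, N, 2), dtype=jnp.int32) * jnp.arange(T)[:, None, None]  # toy positions
hp = jnp.ones((T, N)).at[:, -1].set(0)               # last unit is "dead"

frame_idx = jnp.arange(T).repeat(N)[..., None]       # (T*N, 1) frame index column
cord = jnp.concatenate((frame_idx, pos.reshape(-1, 2)), axis=1).T   # (3, T*N)
idxs = cord[:, hp.flatten() > 0]                     # drop dead units, as the TODO intends
frames = jnp.broadcast_to(grid, (T, S, S)).at[idxs[0], idxs[1], idxs[2]].set(1)
print(frames.shape)                                  # (T, S, S) with live units stamped in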
parabellum-0.0.73.dist-info/METADATA
ADDED
@@ -0,0 +1,46 @@
Metadata-Version: 2.4
Name: parabellum
Version: 0.0.73
Summary: Parabellum environment for parallel warfare simulation
Author-email: Noah Syrkis <desk@syrkis.com>
Requires-Python: <3.12,>=3.11
Requires-Dist: brax<0.13,>=0.12.1
Requires-Dist: cachier<4,>=3.1.2
Requires-Dist: cartopy<0.24,>=0.23.0
Requires-Dist: contextily<2,>=1.6.0
Requires-Dist: distrax<0.2,>=0.1.5
Requires-Dist: einops<0.9,>=0.8.0
Requires-Dist: equinox>=0.12.2
Requires-Dist: evosax<0.2,>=0.1.6
Requires-Dist: flashbax<0.2,>=0.1.2
Requires-Dist: flax<0.11,>=0.10.4
Requires-Dist: folium<0.18,>=0.17.0
Requires-Dist: geopy<3,>=2.4.1
Requires-Dist: gymnax<0.0.9,>=0.0.8
Requires-Dist: ipykernel<7,>=6.29.5
Requires-Dist: ipython>=8.36.0
Requires-Dist: jax-tqdm<0.4,>=0.3.1
Requires-Dist: jax<0.7,>=0.6.0
Requires-Dist: jaxkd>=0.1.0
Requires-Dist: jaxtyping<0.3,>=0.2.33
Requires-Dist: jupyterlab<5,>=4.2.2
Requires-Dist: navix<0.8,>=0.7.0
Requires-Dist: notebook>=7.4.2
Requires-Dist: numpy>=2
Requires-Dist: omegaconf<3,>=2.3.0
Requires-Dist: optax<0.3,>=0.2.4
Requires-Dist: osmnx==2.0.0b0
Requires-Dist: pandas<3,>=2.2.2
Requires-Dist: poetry<2,>=1.8.3
Requires-Dist: rasterio<2,>=1.3.10
Requires-Dist: stadiamaps<4,>=3.2.1
Requires-Dist: tensorboard>=2.19.0
Requires-Dist: tensorflow>=2.19.0
Requires-Dist: tqdm<5,>=4.66.4
Requires-Dist: wandb<0.20,>=0.19.7
Requires-Dist: xprof>=2.20.0
Description-Content-Type: text/markdown

# Parabellum

TODO: switch to red and blue team semantics (not enemy and ally)
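If the wheel is installed, the pins above can be read back at runtime with the standard library; a small sketch (it assumes parabellum 0.0.73 is installed in the current environment):

# Read the installed distribution's metadata (mirrors the METADATA file above).
from importlib.metadata import version, requires

assert version("parabellum") == "0.0.73"
for req in requires("parabellum") or []:
    print(req)    # e.g. "jax<0.7,>=0.6.0", "einops<0.9,>=0.8.0", ...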
parabellum-0.0.73.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
parabellum/__init__.py,sha256=yl3tJXQYNnBgAviJINdHrcALb1177n0CFZ2lGiacbQA,109
parabellum/env.py,sha256=8KaqUwAmP9FXrINOxGWwZmQJeM6eigeWdjNhlWUNi68,3236
parabellum/geo.py,sha256=lZ4TQvpfPGW3LV9X-1lfWh3LnwoRCtl_DkN6h26xMng,6922
parabellum/types.py,sha256=mkBwCA-2_syJAZBeNie9iCG9Cy0SvANMAjCo93XIS1A,3981
parabellum/utils.py,sha256=wOE2nfcdSvGOBFaGlTEIzS6w451arpT9VSOf_3Zhjwc,2497
parabellum-0.0.73.dist-info/METADATA,sha256=jHqmHqdrCHM2tIam9bMZ7xwvR1BFLy9bc9qSviKcYAc,1481
parabellum-0.0.73.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
parabellum-0.0.73.dist-info/RECORD,,
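Each RECORD row is path,sha256=<urlsafe base64 of the digest, padding stripped>,size per the wheel spec; a standalone sketch of recomputing such a hash for verification (the path in the usage comment is hypothetical):

# Recompute a wheel RECORD-style hash for one file.
import base64
import hashlib
from pathlib import Path

def record_hash(path: Path) -> str:
    # sha256 digest, urlsafe base64, trailing '=' padding removed.
    digest = hashlib.sha256(path.read_bytes()).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode()

# Example (hypothetical install location):
# print(record_hash(Path("site-packages/parabellum/types.py")))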
parabellum/.ipynb_checkpoints/env-checkpoint.py
REMOVED
@@ -1,296 +0,0 @@
"""Parabellum environment based on SMAX"""

import jax.numpy as jnp
import jax
import numpy as np
from jax import random
from jax import jit
from flax.struct import dataclass
import chex
from jaxmarl.environments.smax.smax_env import State, SMAX
from typing import Tuple, Dict
from functools import partial


@dataclass
class Scenario:
    """Parabellum scenario"""

    obstacle_coords: chex.Array
    obstacle_deltas: chex.Array

    unit_types: chex.Array
    num_allies: int
    num_enemies: int

    smacv2_position_generation: bool = False
    smacv2_unit_type_generation: bool = False


# default scenario
scenarios = {
    "default": Scenario(
        jnp.array([[6, 10], [26, 10]]) * 8,
        jnp.array([[0, 12], [0, 1]]) * 8,
        jnp.zeros((19,), dtype=jnp.uint8),
        9,
        10,
    )
}


class Parabellum(SMAX):
    def __init__(
        self,
        scenario: Scenario = scenarios["default"],
        unit_type_attack_blasts=jnp.array([0, 0, 0, 0, 0, 0]) + 8,
        **kwargs,
    ):
        super().__init__(scenario=scenario, **kwargs)
        self.unit_type_attack_blasts = unit_type_attack_blasts
        self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
        self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)
        self.max_steps = 200
        # overwrite supers _world_step method

    def _push_units_away(self, state: State, firmness: float = 1.0):  # we do it inside the _world_step to allow more obstacles constraints
        return state

    def _our_push_units_away(self, pos, unit_types, firmness: float = 1.0):  # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
        delta_matrix = pos[:, None] - pos[None, :]
        dist_matrix = (
            jnp.linalg.norm(delta_matrix, axis=-1)
            + jnp.identity(self.num_agents)
            + 1e-6
        )
        radius_matrix = (
            self.unit_type_radiuses[unit_types][:, None]
            + self.unit_type_radiuses[unit_types][None, :]
        )
        overlap_term = jax.nn.relu(radius_matrix / dist_matrix - 1.0)
        unit_positions = (
            pos
            + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
        )
        return unit_positions

    @partial(jax.jit, static_argnums=(0,))  # replace the _world_step method
    def _world_step(  # modified version of JaxMARL's SMAX _world_step
        self,
        key: chex.PRNGKey,
        state: State,
        actions: Tuple[chex.Array, chex.Array],
    ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:

        @partial(jax.vmap, in_axes=(None, None, 0, 0))
        def inter_fn(pos, new_pos, obs, obs_end):
            d1 = jnp.cross(obs - pos, new_pos - pos)
            d2 = jnp.cross(obs_end - pos, new_pos - pos)
            d3 = jnp.cross(pos - obs, obs_end - obs)
            d4 = jnp.cross(new_pos - obs, obs_end - obs)
            return (d1 * d2 <= 0) & (d3 * d4 <= 0)

        def update_position(idx, vec):
            # Compute the movements slightly strangely.
            # The velocities below are for diagonal directions
            # because these are easier to encode as actions than the four
            # diagonal directions. Then rotate the velocity 45
            # degrees anticlockwise to compute the movement.
            pos = state.unit_positions[idx]
            new_pos = (
                pos
                + vec
                * self.unit_type_velocities[state.unit_types[idx]]
                * self.time_per_step
            )
            # avoid going out of bounds
            new_pos = jnp.maximum(
                jnp.minimum(new_pos, jnp.array([self.map_width, self.map_height])),
                jnp.zeros((2,)),
            )

            #######################################################################
            ############################################ avoid going into obstacles
            obs = self.obstacle_coords
            obs_end = obs + self.obstacle_deltas
            inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
            new_pos = jnp.where(inters, pos, new_pos)

            #######################################################################
            #######################################################################

            return new_pos

        #######################################################################
        ######################################### units close enough to get hit

        def bystander_fn(attacked_idx):
            idxs = jnp.zeros((self.num_agents,))
            idxs *= (
                jnp.linalg.norm(
                    state.unit_positions - state.unit_positions[attacked_idx], axis=-1
                )
                < self.unit_type_attack_blasts[state.unit_types[attacked_idx]]
            )
            return idxs

        #######################################################################
        #######################################################################

        def update_agent_health(idx, action, key):  # TODO: add attack blasts
            # for team 1, their attack actions are labelled in
            # reverse order because that is the order they are
            # observed in
            attacked_idx = jax.lax.cond(
                idx < self.num_allies,
                lambda: action + self.num_allies - self.num_movement_actions,
                lambda: self.num_allies - 1 - (action - self.num_movement_actions),
            )
            # deal with no-op attack actions (i.e. agents that are moving instead)
            attacked_idx = jax.lax.select(
                action < self.num_movement_actions, idx, attacked_idx
            )

            attack_valid = (
                (
                    jnp.linalg.norm(
                        state.unit_positions[idx] - state.unit_positions[attacked_idx]
                    )
                    < self.unit_type_attack_ranges[state.unit_types[idx]]
                )
                & state.unit_alive[idx]
                & state.unit_alive[attacked_idx]
            )
            attack_valid = attack_valid & (idx != attacked_idx)
            attack_valid = attack_valid & (state.unit_weapon_cooldowns[idx] <= 0.0)
            health_diff = jax.lax.select(
                attack_valid,
                -self.unit_type_attacks[state.unit_types[idx]],
                0.0,
            )
            # design choice based on the pysc2 randomness details.
            # See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness

            #########################################################
            ############################### Add bystander health diff

            bystander_idxs = bystander_fn(attacked_idx)  # TODO: use
            bystander_valid = (
                jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
                .astype(jnp.bool_)
                .astype(jnp.float32)
            )
            bystander_health_diff = (
                bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
            )

            #########################################################
            #########################################################

            cooldown_deviation = jax.random.uniform(
                key, minval=-self.time_per_step, maxval=2 * self.time_per_step
            )
            cooldown = (
                self.unit_type_weapon_cooldowns[state.unit_types[idx]]
                + cooldown_deviation
            )
            cooldown_diff = jax.lax.select(
                attack_valid,
                # subtract the current cooldown because we are
                # going to add it back. This way we effectively
                # set the new cooldown to `cooldown`
                cooldown - state.unit_weapon_cooldowns[idx],
                -self.time_per_step,
            )
            return (
                health_diff,
                attacked_idx,
                cooldown_diff,
                (bystander_health_diff, bystander_idxs),
            )

        def perform_agent_action(idx, action, key):
            movement_action, attack_action = action
            new_pos = update_position(idx, movement_action)
            health_diff, attacked_idxes, cooldown_diff, (bystander) = (
                update_agent_health(idx, attack_action, key)
            )

            return new_pos, (health_diff, attacked_idxes), cooldown_diff, bystander

        keys = jax.random.split(key, num=self.num_agents)
        pos, (health_diff, attacked_idxes), cooldown_diff, bystander = jax.vmap(
            perform_agent_action
        )(jnp.arange(self.num_agents), actions, keys)

        # checked that no unit passed through an obstacles
        new_pos = self._our_push_units_away(pos, state.unit_types)

        # avoid going into obstacles after being pushed
        obs = self.obstacle_coords
        obs_end = obs + self.obstacle_deltas

        def check_obstacles(pos, new_pos, obs, obs_end):
            inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
            return jnp.where(inters, pos, new_pos)

        pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(pos, new_pos, obs, obs_end)

        # Multiple enemies can attack the same unit.
        # We have `(health_diff, attacked_idx)` pairs.
        # `jax.lax.scatter_add` aggregates these exactly
        # in the way we want -- duplicate idxes will have their
        # health differences added together. However, it is a
        # super thin wrapper around the XLA scatter operation,
        # which has this bonkers syntax and requires this dnums
        # parameter. The usage here was inferred from a test:
        # https://github.com/google/jax/blob/main/tests/lax_test.py#L2296
        dnums = jax.lax.ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,),
        )
        unit_health = jnp.maximum(
            jax.lax.scatter_add(
                state.unit_health,
                jnp.expand_dims(attacked_idxes, 1),
                health_diff,
                dnums,
            ),
            0.0,
        )

        #########################################################
        ############################ subtracting bystander health

        _, bystander_health_diff = bystander
        unit_health -= bystander_health_diff.sum(axis=0)  # might be axis=1

        #########################################################
        #########################################################

        unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
        state = state.replace(
            unit_health=unit_health,
            unit_positions=pos,
            unit_weapon_cooldowns=unit_weapon_cooldowns,
        )
        return state


if __name__ == "__main__":
    env = Parabellum(map_width=256, map_height=256)
    rng, key = random.split(random.PRNGKey(0))
    obs, state = env.reset(key)
    state_seq = []
    for step in range(100):
        rng, key = random.split(rng)
        key_act = random.split(key, len(env.agents))
        actions = {
            agent: jax.random.randint(key_act[i], (), 0, 5)
            for i, agent in enumerate(env.agents)
        }
        _, state, _, _, _ = env.step(key, state, actions)
        state_seq.append((obs, state, actions))
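The removed _world_step relies on jax.lax.scatter_add to merge damage when several attackers hit the same unit, as its comment explains. A standalone sketch of that aggregation, using the friendlier .at[].add() wrapper instead of raw scatter_add; the numbers are dummies, not the package's values:

# Duplicate attacked indices have their health diffs summed;
# .at[].add() is the high-level wrapper around the same scatter-add.
import jax.numpy as jnp

unit_health = jnp.array([10.0, 10.0, 10.0])
attacked_idx = jnp.array([1, 1, 2])           # two attackers hit unit 1
health_diff = jnp.array([-3.0, -4.0, -2.0])

new_health = jnp.maximum(unit_health.at[attacked_idx].add(health_diff), 0.0)
print(new_health)                             # [10., 3., 8.]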