parabellum 0.2.26__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,42 +1,44 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: parabellum
3
- Version: 0.2.26
3
+ Version: 0.3.1
4
4
  Summary: Parabellum environment for parallel warfare simulation
5
5
  Home-page: https://github.com/syrkis/parabellum
6
6
  License: MIT
7
7
  Keywords: warfare,simulation,parallel,environment
8
8
  Author: Noah Syrkis
9
9
  Author-email: desk@syrkis.com
10
- Requires-Python: >=3.11,<4.0
10
+ Requires-Python: >=3.11,<3.12
11
11
  Classifier: License :: OSI Approved :: MIT License
12
12
  Classifier: Programming Language :: Python :: 3
13
13
  Classifier: Programming Language :: Python :: 3.11
14
- Classifier: Programming Language :: Python :: 3.12
14
+ Requires-Dist: cartopy (>=0.23.0,<0.24.0)
15
15
  Requires-Dist: contextily (>=1.6.0,<2.0.0)
16
16
  Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
17
+ Requires-Dist: einops (>=0.8.0,<0.9.0)
17
18
  Requires-Dist: folium (>=0.17.0,<0.18.0)
18
- Requires-Dist: geopandas (>=1.0.0,<2.0.0)
19
19
  Requires-Dist: geopy (>=2.4.1,<3.0.0)
20
20
  Requires-Dist: ipykernel (>=6.29.5,<7.0.0)
21
21
  Requires-Dist: jax (==0.4.17)
22
22
  Requires-Dist: jaxmarl (==0.0.3)
23
+ Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
23
24
  Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
24
25
  Requires-Dist: moviepy (>=1.0.3,<2.0.0)
25
26
  Requires-Dist: numpy (<2)
26
27
  Requires-Dist: opencv-python (>=4.10.0.84,<5.0.0.0)
27
- Requires-Dist: osmnx (>=1.9.3,<2.0.0)
28
+ Requires-Dist: osmnx (==2.0.0b0)
28
29
  Requires-Dist: pandas (>=2.2.2,<3.0.0)
29
30
  Requires-Dist: poetry (>=1.8.3,<2.0.0)
30
31
  Requires-Dist: pygame (>=2.5.2,<3.0.0)
31
32
  Requires-Dist: rasterio (>=1.3.10,<2.0.0)
32
33
  Requires-Dist: seaborn (>=0.13.2,<0.14.0)
34
+ Requires-Dist: stadiamaps (>=3.2.1,<4.0.0)
33
35
  Requires-Dist: tqdm (>=4.66.4,<5.0.0)
34
36
  Project-URL: Repository, https://github.com/syrkis/parabellum
35
37
  Description-Content-Type: text/markdown
36
38
 
37
39
  # Parabellum
38
40
 
39
- Ultra-scalable JaxMARL based warfare simulation engine developed with Armasuisse funding.
41
+ Ultra-scalable, JaxMARL-based warfare simulation engine.
40
42
 
41
43
  [![Documentation Status](https://readthedocs.org/projects/parabellum/badge/?version=latest)](https://parabellum.readthedocs.io/en/latest/?badge=latest)
42
44
 
@@ -93,3 +95,4 @@ Full documentation: [parabellum.readthedocs.io](https://parabellum.readthedocs.i
93
95
  ## License
94
96
 
95
97
  MIT
98
+
@@ -1,6 +1,6 @@
1
1
  # Parabellum
2
2
 
3
- Ultra-scalable JaxMARL based warfare simulation engine developed with Armasuisse funding.
3
+ Ultra-scalable, JaxMARL-based warfare simulation engine.
4
4
 
5
5
  [![Documentation Status](https://readthedocs.org/projects/parabellum/badge/?version=latest)](https://parabellum.readthedocs.io/en/latest/?badge=latest)
6
6
 
@@ -56,4 +56,4 @@ Full documentation: [parabellum.readthedocs.io](https://parabellum.readthedocs.i
56
56
 
57
57
  ## License
58
58
 
59
- MIT
59
+ MIT
@@ -1,18 +1,22 @@
1
- from .env import Environment, Scenario, scenarios, make_scenario, State
1
+ from .env import Environment, Scenario, make_scenario, State
2
2
  from .vis import Visualizer, Skin
3
- from .map import terrain_fn
4
3
  from .gun import bullet_fn
5
- # from .aid import aid
4
from . import vis, terrain_db, env, tps

# Public API of the package: re-exported submodules first, then the
# most commonly used names from them.
__all__ = [
    # submodules
    "env",
    "terrain_db",
    "vis",
    "tps",
    # from .env
    "Environment",
    "Scenario",
    "make_scenario",
    "State",
    # from .vis
    "Visualizer",
    "Skin",
    # from .gun
    "bullet_fn",
]
@@ -0,0 +1,27 @@
1
+ # aid.py
2
+ # what you call utils.py when you want file names to be 3 letters
3
+ # by: Noah Syrkis
4
+
5
+ # imports
6
+ import os
7
+ from collections import namedtuple
8
+ from typing import Tuple
9
+ import cartopy.crs as ccrs
10
+
11
# types
# Bounding box with named edges; access as bbox.north, bbox.south, ...
BBox = namedtuple("BBox", "north south east west")  # type: ignore
13
+
14
+
15
# coordinate function
def to_mercator(bbox: BBox) -> BBox:
    """Project a PlateCarree (longitude/latitude) bounding box into cartopy Mercator.

    The south-west corner (west, south) and the north-east corner (east, north)
    are transformed independently; the returned ``BBox`` holds the projected
    x values in ``west``/``east`` and y values in ``south``/``north``.
    """
    source = ccrs.PlateCarree()
    target = ccrs.Mercator()
    x_sw, y_sw = target.transform_point(bbox.west, bbox.south, source)
    x_ne, y_ne = target.transform_point(bbox.east, bbox.north, source)
    return BBox(west=x_sw, south=y_sw, east=x_ne, north=y_ne)
21
+
22
+
23
def to_platecarree(bbox: BBox) -> BBox:
    """Inverse of ``to_mercator``: map a cartopy Mercator bounding box back to PlateCarree.

    The south-west corner (west, south) and the north-east corner (east, north)
    are transformed independently; the returned ``BBox`` holds the resulting
    PlateCarree x values in ``west``/``east`` and y values in ``south``/``north``.
    """
    source = ccrs.Mercator()
    target = ccrs.PlateCarree()
    x_sw, y_sw = target.transform_point(bbox.west, bbox.south, source)
    x_ne, y_ne = target.transform_point(bbox.east, bbox.north, source)
    return BBox(west=x_sw, south=y_sw, east=x_ne, north=y_ne)
@@ -0,0 +1,561 @@
1
+ """Parabellum environment based on SMAX"""
2
+
3
+ import jax.numpy as jnp
4
+ import jax
5
+ from jax import random, Array, vmap, jit
6
+ from flax.struct import dataclass
7
+ import chex
8
+ from jaxmarl.environments.smax.smax_env import SMAX
9
+
10
+ from math import ceil
11
+
12
+ from typing import Tuple, Dict, cast
13
+ from functools import partial
14
+ from parabellum import tps, geo, terrain_db
15
+
16
+
17
@dataclass
class Scenario:
    """Parabellum scenario: where the battle happens and which units take part.

    Field order matters: ``make_scenario`` constructs this class positionally.
    """

    # place name (or terrain_db key) the terrain was built from
    place: str
    # terrain rasters (building / water / forest layers are read by Environment)
    terrain: tps.Terrain
    # one row per unit: (x, y, width, height) of that unit's spawning sector,
    # given as fractions of the map extent (sectors_fn multiplies them by the
    # raster dimensions)
    unit_starting_sectors: jnp.ndarray
    # unit-type id of every unit, allies first, then enemies
    unit_types: chex.Array
    num_allies: int
    num_enemies: int

    smacv2_position_generation: bool = False
    smacv2_unit_type_generation: bool = False
30
+
31
+
32
@dataclass
class State:
    """Per-episode simulation state (one entry per unit in every array field)."""

    unit_positions: Array  # (num_agents, 2) unit coordinates
    unit_alive: Array  # boolean alive flag per unit
    unit_teams: Array  # 0 for allies, 1 for enemies (as set in Environment.reset)
    unit_health: Array  # current health per unit
    unit_types: Array  # unit-type id per unit
    unit_weapon_cooldowns: Array  # remaining weapon cooldown per unit
    prev_movement_actions: Array  # (num_agents, 2), zeros at reset
    prev_attack_actions: Array  # int32 per unit, zeros at reset
    time: int  # 0 at reset
    terminal: bool  # False at reset
45
+
46
+
47
def make_scenario(
    place,
    size,
    unit_starting_sectors,
    allies_type,
    n_allies,
    enemies_type,
    n_enemies,
):
    """Build a ``Scenario`` for ``place``.

    Args:
        place: key into ``terrain_db.db``; any other value is passed to
            ``geo.geography_fn`` to build the terrain.
        size: map size forwarded to the terrain builder.
        unit_starting_sectors: either a list of ``(unit_ids, sector)`` pairs, where
            ``sector`` is ``[x, y, width, height]`` given as fractions of the map,
            or an already expanded per-unit value that is used unchanged.
            Units that appear in no pair may spawn anywhere on the map. If a unit
            appears in several pairs, the last pair wins.
        allies_type: one unit-type id for every ally, or a sequence of exactly
            ``n_allies`` ids.
        n_allies: number of allied units.
        enemies_type: like ``allies_type``, for the enemies.
        n_enemies: number of enemy units.

    Returns:
        A ``Scenario`` whose ``unit_types`` (uint8) list the allies first.

    Raises:
        ValueError: if a per-unit type sequence does not have the expected length.
    """

    def expand_types(type_spec, count, label):
        """Turn a scalar type id or a sequence of ids into a list of ``count`` ids."""
        if jnp.ndim(type_spec) == 0:
            return [type_spec] * count
        # list() so that tuples / arrays can be concatenated with the other team
        type_list = list(type_spec)
        if len(type_list) != count:
            raise ValueError(
                f"{label} has {len(type_list)} entries but {count} units were requested"
            )
        return type_list

    if place in terrain_db.db:
        terrain = terrain_db.make_terrain(terrain_db.db[place], size)
    else:
        terrain = geo.geography_fn(place, size)

    if isinstance(unit_starting_sectors, list):
        # Whole map. Sectors are fractions of the map extent. The previous
        # [0, 0, size, size] described the same region, because the slicing in
        # sectors_fn clamps at the raster border.
        default_sector = [0, 0, 1, 1]
        per_unit_sectors = []
        for unit_idx in range(n_allies + n_enemies):
            selected_sector = default_sector
            for unit_ids, sector in unit_starting_sectors:
                if unit_idx in unit_ids:
                    selected_sector = sector  # later pairs override earlier ones
            per_unit_sectors.append(selected_sector)
        unit_starting_sectors = per_unit_sectors

    allies = expand_types(allies_type, n_allies, "allies_type")
    enemies = expand_types(enemies_type, n_enemies, "enemies_type")
    unit_types = jnp.array(allies + enemies, dtype=jnp.uint8)
    return Scenario(
        place,
        terrain,
        unit_starting_sectors,  # type: ignore
        unit_types,
        n_allies,
        n_enemies,
    )
97
+
98
+
99
def scenario_fn(place):
    """Build a fixed 10-vs-10 scenario for ``place``.

    This is a simpler alternative to ``make_scenario``. Allies get unit type 0
    and enemies unit type 1. Every unit may spawn anywhere on the map: its
    sector is ``[0, 0, 1, 1]``, in map fractions.
    """
    terrain = geo.geography_fn(place)
    num_allies, num_enemies = 10, 10
    ally_types = jnp.zeros((num_allies,), dtype=jnp.uint8)
    enemy_types = jnp.ones((num_enemies,), dtype=jnp.uint8)
    whole_map_sectors = jnp.tile(
        jnp.array([0, 0, 1, 1]), (num_allies + num_enemies, 1)
    )
    return Scenario(
        place=place,
        terrain=terrain,
        unit_starting_sectors=whole_map_sectors,
        unit_types=jnp.concatenate([ally_types, enemy_types]),
        num_allies=num_allies,
        num_enemies=num_enemies,
    )
115
+
116
+
117
def spawn_fn(rng: jnp.ndarray, units_spawning_sectors):
    """Draw one spawn position per unit from that unit's valid cells.

    Each entry of ``units_spawning_sectors`` holds the candidate cells as two
    rows (first-axis indices, second-axis indices), as produced by
    ``sectors_fn``. For each unit, one candidate cell is picked uniformly. An
    independent offset in [0.25, 0.75) is then added to both coordinates.
    Returns a float32 array with one row per unit.
    """
    key = rng
    positions = []
    for cells in units_spawning_sectors:
        key, pick_key, jitter_key = random.split(key, 3)
        xs, ys = cells[0], cells[1]
        jitter = random.uniform(jitter_key, (2,)) * 0.5 + 0.25
        chosen = random.choice(pick_key, xs.shape[0])
        positions.append(jnp.array([xs[chosen], ys[chosen]]) + jitter)
    return jnp.array(positions, dtype=jnp.float32)
127
+
128
+
129
def sectors_fn(sectors: jnp.ndarray, invalid_spawn_areas: jnp.ndarray):
    """Collect the valid spawn cells inside every unit's starting sector.

    Args:
        sectors: one ``(x, y, width, height)`` row per unit. Every value is a
            fraction of the map extent and is multiplied by the matching raster
            dimension below.
        invalid_spawn_areas: 2-D raster. Cells with a non-zero value cannot be
            spawned on.

    Returns:
        A list with one ``(2, k)`` int array per unit. Row 0 holds the absolute
        first-axis indices and row 1 the second-axis indices of the ``k`` valid
        cells of that unit's sector.

    Raises:
        ValueError: if a sector contains no valid cell. The message shows the
            offending sector specification.
    """
    width, height = invalid_spawn_areas.shape
    spawning_sectors = []
    for sector in sectors:
        # int() truncates like the former int32 cast and yields plain slice bounds
        coordx = int(sector[0] * width)
        coordy = int(sector[1] * height)
        # Keep `sector` bound to the spec. It used to be overwritten by this mask,
        # which made the error message below print booleans instead of the sector.
        free_cells = (
            invalid_spawn_areas[
                coordx : coordx + ceil(sector[2] * width),
                coordy : coordy + ceil(sector[3] * height),
            ]
            == 0
        )
        valid = jnp.nonzero(free_cells)
        if valid[0].shape[0] == 0:
            raise ValueError(f"Sector {sector} only contains invalid spawn areas.")
        # shift sector-local indices back to absolute raster indices
        spawning_sectors.append(
            jnp.array(valid) + jnp.array([coordx, coordy]).reshape((2, -1))
        )
    return spawning_sectors
154
+
155
+
156
class Environment(SMAX):
    """SMAX variant that plays out on a rasterised real-world terrain.

    Differences from SMAX that are visible in this class:

    * units spawn inside per-unit sectors, restricted to cells free of
      buildings and water;
    * moves whose bounding box touches a building or water cell are cancelled;
    * a unit is observed only when the sampled line between the two units
      crosses no building or forest cell;
    * the default unit pushing is disabled and replaced by a push applied
      inside ``_world_step``.
    """

    def __init__(self, scenario: Scenario, **kwargs):
        """Create the environment for ``scenario``.

        Args:
            scenario: provides the terrain, spawning sectors and unit types.
            **kwargs: forwarded to ``SMAX.__init__``, except two
                Parabellum-only options that are consumed here:
                ``unit_type_pushable`` (per-type flag, default
                ``[1, 1, 0, 0, 0, 1]``) and ``reset_when_done`` (default
                ``True``). ``unit_type_velocities`` and ``max_steps`` are
                forwarded and also read here (defaults: the table below, and
                200).
        """
        map_height, map_width = scenario.terrain.building.shape
        args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
        # Parabellum-only options must not be forwarded to SMAX.__init__.
        self.unit_type_pushable = kwargs.pop(
            "unit_type_pushable", jnp.array([1, 1, 0, 0, 0, 1])
        )
        # NOTE: stored for callers; nothing in this class reads it.
        self.reset_when_done = kwargs.pop("reset_when_done", True)
        super().__init__(**args, walls_cause_death=False, **kwargs)
        self.terrain = scenario.terrain
        self.unit_starting_sectors = scenario.unit_starting_sectors
        self.scenario = scenario
        self.unit_type_velocities = (
            kwargs["unit_type_velocities"]
            if "unit_type_velocities" in kwargs
            else jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15]) / 2.5
        )
        # Blast radius per unit type (all zero: area damage is still a TODO).
        # It is sized like the velocity table, so every unit-type id has an
        # entry; it used to be hard-coded to 3 entries.
        self.unit_type_attack_blasts = jnp.zeros(
            (len(self.unit_type_velocities),), dtype=jnp.float32
        )
        # Default episode length. An explicit max_steps kwarg is honoured
        # instead of being silently overwritten.
        self.max_steps = kwargs.get("max_steps", 200)
        # Disable SMAX's state-level push; _world_step calls _our_push_units_away.
        self._push_units_away = lambda state, firmness=1: state
        self.spawning_sectors = sectors_fn(
            self.unit_starting_sectors,
            scenario.terrain.building + scenario.terrain.water,
        )
        # number of points sampled along a line-of-sight segment
        self.resolution = (
            jnp.array(jnp.max(self.unit_type_sight_ranges), dtype=jnp.int32) * 2
        )
        # interpolation weights in [0, 1], one row per coordinate
        self.t = jnp.tile(jnp.linspace(0, 1, self.resolution), (2, 1))

    def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:  # type: ignore
        """Environment-specific reset.

        Spawns every unit in its sector with the full health of its type. Returns
        the per-agent observations, which also carry a ``"world_state"`` entry,
        together with the initial state.
        """
        unit_positions = spawn_fn(rng, self.spawning_sectors)
        unit_teams = jnp.zeros((self.num_agents,))
        unit_teams = unit_teams.at[self.num_allies :].set(1)
        unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
        # unit types come from the scenario
        unit_types = cast(Array, self.scenario.unit_types)
        unit_health = self.unit_type_health[unit_types]
        state = State(
            unit_positions=unit_positions,
            unit_alive=jnp.ones((self.num_agents,), dtype=jnp.bool_),
            unit_teams=unit_teams,
            unit_health=unit_health,
            unit_types=unit_types,
            prev_movement_actions=jnp.zeros((self.num_agents, 2)),
            prev_attack_actions=jnp.zeros((self.num_agents,), dtype=jnp.int32),
            time=0,
            terminal=False,
            unit_weapon_cooldowns=unit_weapon_cooldowns,
        )
        # no-op: _push_units_away is replaced by an identity lambda in __init__
        state = self._push_units_away(state)  # type: ignore
        obs = self.get_obs(state)
        # expose the global state alongside the per-agent observations
        world_state = self.get_world_state(state)
        obs["world_state"] = jax.lax.stop_gradient(world_state)
        return obs, state

    def get_obs_unit_list(self, state: State) -> Dict[str, chex.Array]:  # type: ignore
        """Build each agent's unit-list observation.

        Agent ``i`` sees the features of the ``num_agents - 1`` other units
        (allies first; the second team counts in reverse), followed by its own
        features. Another unit's features are zeroed unless all of these hold:
        both units are alive, the unit is within ``i``'s sight range, and no cell
        sampled on the segment between them is a building or forest cell.
        """

        def get_features(i, j):
            """Get features of unit j as seen from unit i"""
            # Can just keep them symmetrical for now.
            # j here means 'the jth unit that is not i'
            # The observation is such that allies are always first
            # so for units in the second team we count in reverse.
            j = jax.lax.cond(
                i < self.num_allies, lambda: j, lambda: self.num_agents - j - 1
            )
            offset = jax.lax.cond(i < self.num_allies, lambda: 1, lambda: -1)
            j_idx = jax.lax.cond(
                ((j < i) & (i < self.num_allies)) | ((j > i) & (i >= self.num_allies)),
                lambda: j,
                lambda: j + offset,
            )
            empty_features = jnp.zeros(shape=(len(self.unit_features),))
            features = self._observe_features(state, i, j_idx)
            visible = (
                jnp.linalg.norm(state.unit_positions[j_idx] - state.unit_positions[i])
                < self.unit_type_sight_ranges[state.unit_types[i]]
            )
            return jax.lax.cond(
                visible
                & state.unit_alive[i]
                & state.unit_alive[j_idx]
                & self.has_line_of_sight(
                    state.unit_positions[j_idx],
                    state.unit_positions[i],
                    self.terrain.building + self.terrain.forest,
                ),
                lambda: features,
                lambda: empty_features,
            )

        get_all_features_for_unit = jax.vmap(get_features, in_axes=(None, 0))
        get_all_features = jax.vmap(get_all_features_for_unit, in_axes=(0, None))
        other_unit_obs = get_all_features(
            jnp.arange(self.num_agents), jnp.arange(self.num_agents - 1)
        )
        other_unit_obs = other_unit_obs.reshape((self.num_agents, -1))
        get_all_self_features = jax.vmap(self._get_own_features, in_axes=(None, 0))
        own_unit_obs = get_all_self_features(state, jnp.arange(self.num_agents))
        obs = jnp.concatenate([other_unit_obs, own_unit_obs], axis=-1)
        return {agent: obs[self.agent_ids[agent]] for agent in self.agents}

    def _our_push_units_away(
        self, pos, unit_types, firmness: float = 1.0
    ):  # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
        """Displace pushable units away from units whose radii overlap them.

        Units whose type has ``unit_type_pushable == 0`` keep their position.
        Returns the new positions array.
        """
        delta_matrix = pos[:, None] - pos[None, :]
        # the identity avoids dividing by a unit's zero distance to itself
        dist_matrix = (
            jnp.linalg.norm(delta_matrix, axis=-1)
            + jnp.identity(self.num_agents)
            + 1e-6
        )
        radius_matrix = (
            self.unit_type_radiuses[unit_types][:, None]
            + self.unit_type_radiuses[unit_types][None, :]
        )
        overlap_term = jax.nn.relu(radius_matrix / dist_matrix - 1.0)
        unit_positions = (
            pos
            + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
        )
        return jnp.where(
            self.unit_type_pushable[unit_types][:, None], unit_positions, pos
        )

    def has_line_of_sight(self, source, target, raster_input):
        """Return True if no sampled cell of the segment is set in ``raster_input``.

        ``self.resolution`` points (endpoints included) are sampled uniformly
        between ``source`` and ``target`` and truncated to integer cells.
        The target should be within sight range of the source; otherwise the
        sampling may skip cells.
        """
        cells = jnp.array(
            source[:, jnp.newaxis] * self.t + (1 - self.t) * target[:, jnp.newaxis],
            dtype=jnp.int32,
        )
        mask = jnp.zeros(raster_input.shape).at[cells[0, :], cells[1, :]].set(1)
        flag = ~jnp.any(jnp.logical_and(mask, raster_input))
        return flag

    @partial(jax.jit, static_argnums=(0,))  # replace the _world_step method
    def _world_step(  # modified version of JaxMARL's SMAX _world_step
        self,
        key: chex.PRNGKey,
        state: State,
        actions: Tuple[chex.Array, chex.Array],
    ) -> State:
        """Advance the simulation by one world step.

        The step runs in this order:

        1. Each unit moves; moves into building or water cells are cancelled.
        2. Each unit attempts its attack.
        3. Overlapping units push each other, unless the push crosses an obstacle.
        4. Positions are clamped to the map.
        5. Damage, which may target the same unit several times, is
           accumulated and weapon cooldowns are updated.
        """

        def raster_crossing(pos, new_pos, mask: jnp.ndarray):
            """True if any non-zero ``mask`` cell lies in the box spanned by both cells."""
            pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
            minimum = jnp.minimum(pos, new_pos)
            maximum = jnp.maximum(pos, new_pos)
            mask = jnp.where(jnp.arange(mask.shape[0]) >= minimum[0], mask.T, 0).T
            mask = jnp.where(jnp.arange(mask.shape[0]) <= maximum[0], mask.T, 0).T
            mask = jnp.where(jnp.arange(mask.shape[1]) >= minimum[1], mask, 0)
            mask = jnp.where(jnp.arange(mask.shape[1]) <= maximum[1], mask, 0)
            return jnp.any(mask)

        def update_position(idx, vec):
            """Move unit ``idx`` along ``vec`` scaled by its type velocity and step time."""
            pos = cast(Array, state.unit_positions[idx])
            new_pos = (
                pos
                + vec
                * self.unit_type_velocities[state.unit_types[idx]]
                * self.time_per_step
            )
            # avoid going out of bounds
            new_pos = jnp.maximum(
                jnp.minimum(
                    new_pos, jnp.array([self.map_width - 1, self.map_height - 1])
                ),
                jnp.zeros((2,)),
            )
            # avoid going into obstacles: stay put if the move's box touches one
            clash = raster_crossing(
                pos, new_pos, self.terrain.building + self.terrain.water
            )
            new_pos = jnp.where(clash, pos, new_pos)
            return new_pos

        def bystander_fn(attacked_idx):
            """Return a 0/1 float mask of units inside the attacked unit's blast radius.

            Not used yet (see the TODO in update_agent_health).
            """
            distances = jnp.linalg.norm(
                state.unit_positions - state.unit_positions[attacked_idx], axis=-1
            )
            blast = self.unit_type_attack_blasts[state.unit_types[attacked_idx]]
            # The previous version multiplied a zeros array by this comparison,
            # so it always returned all zeros.
            return (distances < blast).astype(jnp.float32)

        def update_agent_health(idx, action, key):
            """Return (health_diff, attacked_idx, cooldown_diff) for unit ``idx``."""
            # for team 1, their attack actions are labelled in
            # reverse order because that is the order they are
            # observed in
            attacked_idx = jax.lax.cond(
                idx < self.num_allies,
                lambda: action + self.num_allies - self.num_movement_actions,
                lambda: self.num_allies - 1 - (action - self.num_movement_actions),
            )
            attacked_idx = cast(int, attacked_idx)  # Cast to int
            # deal with no-op attack actions (i.e. agents that are moving instead)
            attacked_idx = jax.lax.select(
                action < self.num_movement_actions, idx, attacked_idx
            )
            distance = jnp.linalg.norm(
                state.unit_positions[idx] - state.unit_positions[attacked_idx]
            )
            attack_valid = (
                (distance <= self.unit_type_attack_ranges[state.unit_types[idx]])
                & state.unit_alive[idx]
                & state.unit_alive[attacked_idx]
            )
            attack_valid = attack_valid & (idx != attacked_idx)
            attack_valid = attack_valid & (state.unit_weapon_cooldowns[idx] <= 0.0)
            health_diff = jax.lax.select(
                attack_valid,
                -self.unit_type_attacks[state.unit_types[idx]],
                0.0,
            )
            # unit type 1 deals damage scaled by distance / its attack range
            health_diff = jnp.where(
                state.unit_types[idx] == 1,
                health_diff
                * distance
                / self.unit_type_attack_ranges[state.unit_types[idx]],
                health_diff,
            )
            # TODO: apply area damage to bystanders via bystander_fn(attacked_idx).

            # design choice based on the pysc2 randomness details.
            # See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
            cooldown_deviation = jax.random.uniform(
                key, minval=-self.time_per_step, maxval=2 * self.time_per_step
            )
            cooldown = (
                self.unit_type_weapon_cooldowns[state.unit_types[idx]]
                + cooldown_deviation
            )
            cooldown_diff = jax.lax.select(
                attack_valid,
                # subtract the current cooldown because we are
                # going to add it back. This way we effectively
                # set the new cooldown to `cooldown`
                cooldown - state.unit_weapon_cooldowns[idx],
                -self.time_per_step,
            )
            return (
                health_diff,
                attacked_idx,
                cooldown_diff,
            )

        def perform_agent_action(idx, action, key):
            movement_action, attack_action = action
            new_pos = update_position(idx, movement_action)
            health_diff, attacked_idxes, cooldown_diff = update_agent_health(
                idx, attack_action, key
            )
            return new_pos, (health_diff, attacked_idxes), cooldown_diff

        keys = jax.random.split(key, num=self.num_agents)
        pos, (health_diff, attacked_idxes), cooldown_diff = jax.vmap(
            perform_agent_action
        )(jnp.arange(self.num_agents), actions, keys)

        # units push each other, unless the push would cross an obstacle
        new_pos = self._our_push_units_away(pos, state.unit_types)
        clash = jax.vmap(raster_crossing, in_axes=(0, 0, None))(
            pos, new_pos, self.terrain.building + self.terrain.water
        )
        pos = jax.vmap(jnp.where)(clash, pos, new_pos)
        # avoid going out of bounds
        pos = jnp.maximum(
            jnp.minimum(pos, jnp.array([self.map_width - 1, self.map_height - 1])),  # type: ignore
            jnp.zeros((2,)),
        )

        # Multiple enemies can attack the same unit.
        # We have `(health_diff, attacked_idx)` pairs.
        # `jax.lax.scatter_add` aggregates these exactly
        # in the way we want -- duplicate idxes will have their
        # health differences added together. However, it is a
        # super thin wrapper around the XLA scatter operation,
        # which has this bonkers syntax and requires this dnums
        # parameter. The usage here was inferred from a test:
        # https://github.com/google/jax/blob/main/tests/lax_test.py#L2296
        dnums = jax.lax.ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,),
        )
        unit_health = jnp.maximum(
            jax.lax.scatter_add(
                state.unit_health,
                jnp.expand_dims(attacked_idxes, 1),
                health_diff,
                dnums,
            ),
            0.0,
        )
        # TODO: subtract bystander damage here once bystander_fn is wired in.

        unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
        # replace unit health, unit positions and unit weapon cooldowns
        state = state.replace(  # type: ignore
            unit_health=unit_health,
            unit_positions=pos,
            unit_weapon_cooldowns=unit_weapon_cooldowns,
        )
        return state
522
+
523
+
524
if __name__ == "__main__":
    # Smoke test / rough benchmark: n_envs parallel 10-vs-10 battles with random actions.
    import time

    n_envs = 4
    n_steps = 10

    n_allies = 10
    scenario_kwargs = {
        "allies_type": 0,
        "n_allies": n_allies,
        "enemies_type": 0,
        "n_enemies": n_allies,
        "place": "Vesterbro, Copenhagen, Denmark",
        "size": 100,
        "unit_starting_sectors": [
            ([i for i in range(n_allies)], [0.0, 0.45, 0.1, 0.1]),
            ([n_allies + i for i in range(n_allies)], [0.8, 0.5, 0.1, 0.1]),
        ],
    }
    scenario = make_scenario(**scenario_kwargs)
    env = Environment(scenario)
    rng, reset_rng = random.split(random.PRNGKey(0))
    reset_key = random.split(reset_rng, n_envs)
    obs, state = vmap(env.reset)(reset_key)
    state_seq = []

    step = vmap(jit(env.step))
    tic = time.time()
    for _ in range(n_steps):
        rng, act_rng, step_rng = random.split(rng, 3)
        act_key = random.split(act_rng, (len(env.agents), n_envs))
        act = {
            a: vmap(env.action_space(a).sample)(act_key[i])
            for i, a in enumerate(env.agents)
        }
        step_key = random.split(step_rng, n_envs)
        state_seq.append((step_key, state, act))
        obs, state, reward, done, infos = step(step_key, state, act)
    # JAX dispatches asynchronously: wait for the last step before reading the clock.
    jax.block_until_ready(state)
    elapsed = time.time() - tic
    # The first step call also JIT-compiles env.step, so this includes compilation time.
    print(f"{n_steps} steps x {n_envs} envs took {elapsed:.2f}s (incl. JIT compilation)")