parabellum 0.2.26__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/__init__.py CHANGED
@@ -1,18 +1,22 @@
1
- from .env import Environment, Scenario, scenarios, make_scenario, State
1
+ from .env import Environment, Scenario, make_scenario, State
2
2
  from .vis import Visualizer, Skin
3
- from .map import terrain_fn
4
3
  from .gun import bullet_fn
5
- # from .aid import aid
4
+ from . import vis
5
+ from . import terrain_db
6
+ from . import env
7
+ from . import tps
6
8
  # from .run import run
7
9
 
8
10
  __all__ = [
11
+ "env",
12
+ "terrain_db",
13
+ "vis",
14
+ "tps",
9
15
  "Environment",
10
16
  "Scenario",
11
- "scenarios",
12
17
  "make_scenario",
13
18
  "State",
14
19
  "Visualizer",
15
20
  "Skin",
16
- "terrain_fn",
17
21
  "bullet_fn",
18
22
  ]
parabellum/aid.py CHANGED
@@ -3,3 +3,25 @@
3
3
  # by: Noah Syrkis
4
4
 
5
5
  # imports
6
+ import os
7
+ from collections import namedtuple
8
+ from typing import Tuple
9
+ import cartopy.crs as ccrs
10
+
11
+ # types
12
+ BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
13
+
14
+
15
+ # coordinate function
16
+ def to_mercator(bbox: BBox) -> BBox:
17
+ proj = ccrs.Mercator()
18
+ west, south = proj.transform_point(bbox.west, bbox.south, ccrs.PlateCarree())
19
+ east, north = proj.transform_point(bbox.east, bbox.north, ccrs.PlateCarree())
20
+ return BBox(north=north, south=south, east=east, west=west)
21
+
22
+
23
+ def to_platecarree(bbox: BBox) -> BBox:
24
+ proj = ccrs.PlateCarree()
25
+ west, south = proj.transform_point(bbox.west, bbox.south, ccrs.Mercator())
26
+ east, north = proj.transform_point(bbox.east, bbox.north, ccrs.Mercator())
27
+ return BBox(north=north, south=south, east=east, west=west)
parabellum/env.py CHANGED
@@ -2,15 +2,16 @@
2
2
 
3
3
  import jax.numpy as jnp
4
4
  import jax
5
- import numpy as np
6
- from jax import random, Array
7
- from jax import jit
5
+ from jax import random, Array, vmap, jit
8
6
  from flax.struct import dataclass
9
7
  import chex
10
- from jax import vmap
11
8
  from jaxmarl.environments.smax.smax_env import SMAX
9
+
10
+ from math import ceil
11
+
12
12
  from typing import Tuple, Dict, cast
13
13
  from functools import partial
14
+ from parabellum import tps, geo, terrain_db
14
15
 
15
16
 
16
17
  @dataclass
@@ -18,8 +19,8 @@ class Scenario:
18
19
  """Parabellum scenario"""
19
20
 
20
21
  place: str
21
- terrain_raster: jnp.ndarray
22
- unit_starting_sectors: jnp.ndarray
22
+ terrain: tps.Terrain
23
+ unit_starting_sectors: jnp.ndarray # must be of size (num_units, 4) where sectors[i] = (x, y, width, height) of the ith unit's spawning sector (in % of the real map)
23
24
  unit_types: chex.Array
24
25
  num_allies: int
25
26
  num_enemies: int
@@ -27,9 +28,11 @@ class Scenario:
27
28
  smacv2_position_generation: bool = False
28
29
  smacv2_unit_type_generation: bool = False
29
30
 
31
+
30
32
  @dataclass
31
33
  class State:
32
- unit_positions: Array
34
+ # terrain: Array
35
+ unit_positions: Array # fsfds
33
36
  unit_alive: Array
34
37
  unit_teams: Array
35
38
  unit_health: Array
@@ -41,93 +44,156 @@ class State:
41
44
  terminal: bool
42
45
 
43
46
 
44
- # default scenario
45
- scenarios = {
46
- "default": Scenario(
47
- "Identity Town",
48
- jnp.eye(64, dtype=jnp.uint8),
49
- jnp.array([[0, 0, 0.2, 0.2], [0.7,0.7,0.2,0.2]]),
50
- jnp.zeros((19,), dtype=jnp.uint8),
51
- 9,
52
- 10,
53
- )
54
- }
55
-
56
-
57
- def make_scenario(place, terrain_raster, unit_starting_sectors, allies_type, n_allies, enemies_type, n_enemies):
47
+ def make_scenario(
48
+ place,
49
+ size,
50
+ unit_starting_sectors,
51
+ allies_type,
52
+ n_allies,
53
+ enemies_type,
54
+ n_enemies,
55
+ ):
56
+ if place in terrain_db.db:
57
+ terrain = terrain_db.make_terrain(terrain_db.db[place], size)
58
+ else:
59
+ terrain = geo.geography_fn(place, size)
60
+ if type(unit_starting_sectors) == list:
61
+ default_sector = [
62
+ 0,
63
+ 0,
64
+ size,
65
+ size,
66
+ ] # Noah feel confident that this is right. This means 50% chance. Sorry timothee if you end up here later. my bad bro.
67
+ correct_unit_starting_sectors = []
68
+ for i in range(n_allies + n_enemies):
69
+ selected_sector = None
70
+ for unit_ids, sector in unit_starting_sectors:
71
+ if i in unit_ids:
72
+ selected_sector = sector
73
+ if selected_sector is None:
74
+ selected_sector = default_sector
75
+ correct_unit_starting_sectors.append(selected_sector)
76
+ unit_starting_sectors = correct_unit_starting_sectors
58
77
  if type(allies_type) == int:
59
78
  allies = [allies_type] * n_allies
60
79
  else:
61
- assert(len(allies_type) == n_allies)
80
+ assert len(allies_type) == n_allies
62
81
  allies = allies_type
63
82
 
64
83
  if type(enemies_type) == int:
65
84
  enemies = [enemies_type] * n_enemies
66
85
  else:
67
- assert(len(enemies_type) == n_enemies)
86
+ assert len(enemies_type) == n_enemies
68
87
  enemies = enemies_type
69
88
  unit_types = jnp.array(allies + enemies, dtype=jnp.uint8)
70
- return Scenario(place, terrain_raster, unit_starting_sectors, unit_types, n_allies, n_enemies)
71
-
72
-
73
- def spawn_fn(pool, offset: jnp.ndarray, n: int, rng: jnp.ndarray):
74
- """Spawns n agents on a map."""
75
- rng, key_start, key_noise = random.split(rng, 3)
76
- noise = random.uniform(key_noise, (n, 2)) * 0.5
77
-
78
- # select n random (x, y)-coords where sector == True
79
- idxs = random.choice(key_start, pool[0].shape[0], (n,), replace=False)
80
- coords = jnp.array([pool[0][idxs], pool[1][idxs]]).T
81
-
82
- return coords + noise + offset
89
+ return Scenario(
90
+ place,
91
+ terrain,
92
+ unit_starting_sectors, # type: ignore
93
+ unit_types,
94
+ n_allies,
95
+ n_enemies,
96
+ )
83
97
 
84
98
 
85
- def sector_fn(terrain: jnp.ndarray, sector_id: int):
86
- """return sector slice of terrain"""
87
- width, height = terrain.shape
88
- coordx, coordy = sector_id // 5 * width // 5, sector_id % 5 * height // 5
89
- sector = terrain[coordx : coordx + width // 5, coordy : coordy + height // 5] == 0
90
- offset = jnp.array([coordx, coordy])
91
- # sector is jnp.nonzero
92
- return jnp.nonzero(sector), offset
99
+ def scenario_fn(place):
100
+ # scenario function for Noah, cos the one above is confusing
101
+ terrain = geo.geography_fn(place)
102
+ num_allies = 10
103
+ num_enemies = 10
104
+ unit_types = jnp.array([0] * num_allies + [1] * num_enemies, dtype=jnp.uint8)
105
+ # start units in default sectors
106
+ unit_starting_sectors = jnp.array([[0, 0, 1, 1]] * (num_allies + num_enemies))
107
+ return Scenario(
108
+ place=place,
109
+ terrain=terrain,
110
+ unit_starting_sectors=unit_starting_sectors,
111
+ unit_types=unit_types,
112
+ num_allies=num_allies,
113
+ num_enemies=num_enemies,
114
+ )
93
115
 
94
116
 
95
- def sector_fn(terrain: jnp.ndarray, sector: jnp.ndarray):
96
- """return sector slice of terrain"""
97
- width, height = terrain.shape
98
- coordx, coordy = int(sector[0] * width), int(sector[1] * height)
99
- sector = terrain[coordy : coordy + int(sector[3] * height), coordx : coordx + int(sector[2] * width)] == 0
100
- offset = jnp.array([coordx, coordy])
101
- # sector is jnp.nonzero
102
- return jnp.nonzero(sector.T), offset
117
+ def spawn_fn(rng: jnp.ndarray, units_spawning_sectors):
118
+ """Spawns n agents on a map."""
119
+ spawn_positions = []
120
+ for sector in units_spawning_sectors:
121
+ rng, key_start, key_noise = random.split(rng, 3)
122
+ noise = 0.25 + random.uniform(key_noise, (2,)) * 0.5
123
+ idx = random.choice(key_start, sector[0].shape[0])
124
+ coord = jnp.array([sector[0][idx], sector[1][idx]])
125
+ spawn_positions.append(coord + noise)
126
+ return jnp.array(spawn_positions, dtype=jnp.float32)
127
+
128
+
129
+ def sectors_fn(sectors: jnp.ndarray, invalid_spawn_areas: jnp.ndarray):
130
+ """
131
+ sectors must be of size (num_units, 4) where sectors[i] = (x, y, width, height) of the ith unit's spawning sector (in % of the real map)
132
+ """
133
+ width, height = invalid_spawn_areas.shape
134
+ spawning_sectors = []
135
+ for sector in sectors:
136
+ coordx, coordy = (
137
+ jnp.array(sector[0] * width, dtype=jnp.int32),
138
+ jnp.array(sector[1] * height, dtype=jnp.int32),
139
+ )
140
+ sector = (
141
+ invalid_spawn_areas[
142
+ coordx : coordx + ceil(sector[2] * width),
143
+ coordy : coordy + ceil(sector[3] * height),
144
+ ]
145
+ == 0
146
+ )
147
+ valid = jnp.nonzero(sector)
148
+ if valid[0].shape[0] == 0:
149
+ raise ValueError(f"Sector {sector} only contains invalid spawn areas.")
150
+ spawning_sectors.append(
151
+ jnp.array(valid) + jnp.array([coordx, coordy]).reshape((2, -1))
152
+ )
153
+ return spawning_sectors
103
154
 
104
155
 
105
156
  class Environment(SMAX):
106
157
  def __init__(self, scenario: Scenario, **kwargs):
107
- map_height, map_width = scenario.terrain_raster.shape
158
+ map_height, map_width = scenario.terrain.building.shape
108
159
  args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
160
+ if "unit_type_pushable" in kwargs:
161
+ self.unit_type_pushable = kwargs["unit_type_pushable"]
162
+ del kwargs["unit_type_pushable"]
163
+ else:
164
+ self.unit_type_pushable = jnp.array([1, 1, 0, 0, 0, 1])
165
+ if "reset_when_done" in kwargs:
166
+ self.reset_when_done = kwargs["reset_when_done"]
167
+ del kwargs["reset_when_done"]
168
+ else:
169
+ self.reset_when_done = True
109
170
  super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
110
- self.terrain_raster = scenario.terrain_raster
171
+ self.terrain = scenario.terrain
111
172
  self.unit_starting_sectors = scenario.unit_starting_sectors
112
173
  # self.unit_type_names = ["tinker", "tailor", "soldier", "spy"]
113
174
  # self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
114
175
  # self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
115
176
  self.scenario = scenario
116
- self.unit_type_velocities=jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15])/2.5
177
+ self.unit_type_velocities = (
178
+ jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15]) / 2.5
179
+ if "unit_type_velocities" not in kwargs
180
+ else kwargs["unit_type_velocities"]
181
+ )
117
182
  self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
118
183
  self.max_steps = 200
119
- self._push_units_away = lambda state, firmness = 1: state # overwrite push units
120
- self.team0_sector, self.team0_sector_offset = sector_fn(self.terrain_raster, self.unit_starting_sectors[0]) # sector_fn(self.terrain_raster, 0)
121
- self.team1_sector, self.team1_sector_offset = sector_fn(self.terrain_raster, self.unit_starting_sectors[1]) # sector_fn(self.terrain_raster, 24)
122
-
184
+ self._push_units_away = lambda state, firmness=1: state # overwrite push units
185
+ self.spawning_sectors = sectors_fn(
186
+ self.unit_starting_sectors,
187
+ scenario.terrain.building + scenario.terrain.water,
188
+ )
189
+ self.resolution = (
190
+ jnp.array(jnp.max(self.unit_type_sight_ranges), dtype=jnp.int32) * 2
191
+ )
192
+ self.t = jnp.tile(jnp.linspace(0, 1, self.resolution), (2, 1))
123
193
 
124
- @partial(jax.jit, static_argnums=(0,))
125
- def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
194
+ def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]: # type: ignore
126
195
  """Environment-specific reset."""
127
- ally_key, enemy_key = jax.random.split(rng)
128
- team_0_start = spawn_fn(self.team0_sector, self.team0_sector_offset, self.num_allies, ally_key)
129
- team_1_start = spawn_fn(self.team1_sector, self.team1_sector_offset, self.num_enemies, enemy_key)
130
- unit_positions = jnp.concatenate([team_0_start, team_1_start])
196
+ unit_positions = spawn_fn(rng, self.spawning_sectors)
131
197
  unit_teams = jnp.zeros((self.num_agents,))
132
198
  unit_teams = unit_teams.at[self.num_allies :].set(1)
133
199
  unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
@@ -145,18 +211,71 @@ class Environment(SMAX):
145
211
  time=0,
146
212
  terminal=False,
147
213
  unit_weapon_cooldowns=unit_weapon_cooldowns,
214
+ # terrain=self.terrain,
148
215
  )
149
- state = self._push_units_away(state) # type: ignore
216
+ state = self._push_units_away(state) # type: ignore could be slow
150
217
  obs = self.get_obs(state)
218
+ # remove world_state from obs
151
219
  world_state = self.get_world_state(state)
152
- # obs["world_state"] = jax.lax.stop_gradient(world_state)
220
+ obs["world_state"] = jax.lax.stop_gradient(world_state)
153
221
  return obs, state
154
222
 
155
- def step_env(self, rng, state: State, action: Array):
156
- obs, state, rewards, dones, infos = super().step_env(rng, state, action)
223
+ # def step_env(self, rng, state: State, action: Array): # type: ignore
224
+ # obs, state, rewards, dones, infos = super().step_env(rng, state, action)
157
225
  # delete world_state from obs
158
- obs.pop("world_state")
159
- return obs, state, rewards, dones, infos
226
+ # obs.pop("world_state")
227
+ # if not self.reset_when_done:
228
+ # for key in dones.keys():
229
+ # dones[key] = False
230
+ # return obs, state, rewards, dones, infos
231
+
232
+ def get_obs_unit_list(self, state: State) -> Dict[str, chex.Array]: # type: ignore
233
+ """Applies observation function to state."""
234
+
235
+ def get_features(i, j):
236
+ """Get features of unit j as seen from unit i"""
237
+ # Can just keep them symmetrical for now.
238
+ # j here means 'the jth unit that is not i'
239
+ # The observation is such that allies are always first
240
+ # so for units in the second team we count in reverse.
241
+ j = jax.lax.cond(
242
+ i < self.num_allies, lambda: j, lambda: self.num_agents - j - 1
243
+ )
244
+ offset = jax.lax.cond(i < self.num_allies, lambda: 1, lambda: -1)
245
+ j_idx = jax.lax.cond(
246
+ ((j < i) & (i < self.num_allies)) | ((j > i) & (i >= self.num_allies)),
247
+ lambda: j,
248
+ lambda: j + offset,
249
+ )
250
+ empty_features = jnp.zeros(shape=(len(self.unit_features),))
251
+ features = self._observe_features(state, i, j_idx)
252
+ visible = (
253
+ jnp.linalg.norm(state.unit_positions[j_idx] - state.unit_positions[i])
254
+ < self.unit_type_sight_ranges[state.unit_types[i]]
255
+ )
256
+ return jax.lax.cond(
257
+ visible
258
+ & state.unit_alive[i]
259
+ & state.unit_alive[j_idx]
260
+ & self.has_line_of_sight(
261
+ state.unit_positions[j_idx],
262
+ state.unit_positions[i],
263
+ self.terrain.building + self.terrain.forest,
264
+ ),
265
+ lambda: features,
266
+ lambda: empty_features,
267
+ )
268
+
269
+ get_all_features_for_unit = jax.vmap(get_features, in_axes=(None, 0))
270
+ get_all_features = jax.vmap(get_all_features_for_unit, in_axes=(0, None))
271
+ other_unit_obs = get_all_features(
272
+ jnp.arange(self.num_agents), jnp.arange(self.num_agents - 1)
273
+ )
274
+ other_unit_obs = other_unit_obs.reshape((self.num_agents, -1))
275
+ get_all_self_features = jax.vmap(self._get_own_features, in_axes=(None, 0))
276
+ own_unit_obs = get_all_self_features(state, jnp.arange(self.num_agents))
277
+ obs = jnp.concatenate([other_unit_obs, own_unit_obs], axis=-1)
278
+ return {agent: obs[self.agent_ids[agent]] for agent in self.agents}
160
279
 
161
280
  def _our_push_units_away(
162
281
  self, pos, unit_types, firmness: float = 1.0
@@ -176,7 +295,19 @@ class Environment(SMAX):
176
295
  pos
177
296
  + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
178
297
  )
179
- return unit_positions
298
+ return jnp.where(
299
+ self.unit_type_pushable[unit_types][:, None], unit_positions, pos
300
+ )
301
+
302
+ def has_line_of_sight(self, source, target, raster_input):
303
+ # suppose that the target is in sight_range of source, otherwise the line of sight might miss some cells
304
+ cells = jnp.array(
305
+ source[:, jnp.newaxis] * self.t + (1 - self.t) * target[:, jnp.newaxis],
306
+ dtype=jnp.int32,
307
+ )
308
+ mask = jnp.zeros(raster_input.shape).at[cells[0, :], cells[1, :]].set(1)
309
+ flag = ~jnp.any(jnp.logical_and(mask, raster_input))
310
+ return flag
180
311
 
181
312
  @partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
182
313
  def _world_step( # modified version of JaxMARL's SMAX _world_step
@@ -185,24 +316,15 @@ class Environment(SMAX):
185
316
  state: State,
186
317
  actions: Tuple[chex.Array, chex.Array],
187
318
  ) -> State:
188
- @partial(jax.vmap, in_axes=(None, None, 0, 0))
189
- def intersect_fn(pos, new_pos, obs, obs_end):
190
- d1 = jnp.cross(obs - pos, new_pos - pos)
191
- d2 = jnp.cross(obs_end - pos, new_pos - pos)
192
- d3 = jnp.cross(pos - obs, obs_end - obs)
193
- d4 = jnp.cross(new_pos - obs, obs_end - obs)
194
- return (d1 * d2 <= 0) & (d3 * d4 <= 0)
195
-
196
- def raster_crossing(pos, new_pos):
319
+ def raster_crossing(pos, new_pos, mask: jnp.ndarray):
197
320
  pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
198
- raster = jnp.copy(self.terrain_raster)
199
321
  minimum = jnp.minimum(pos, new_pos)
200
322
  maximum = jnp.maximum(pos, new_pos)
201
- raster = jnp.where(jnp.arange(raster.shape[0]) >= minimum[0], raster, 0)
202
- raster = jnp.where(jnp.arange(raster.shape[0]) <= maximum[0], raster, 0)
203
- raster = jnp.where(jnp.arange(raster.shape[1]) >= minimum[1], raster.T, 0).T
204
- raster = jnp.where(jnp.arange(raster.shape[1]) <= maximum[1], raster.T, 0).T
205
- return jnp.any(raster)
323
+ mask = jnp.where(jnp.arange(mask.shape[0]) >= minimum[0], mask.T, 0).T
324
+ mask = jnp.where(jnp.arange(mask.shape[0]) <= maximum[0], mask.T, 0).T
325
+ mask = jnp.where(jnp.arange(mask.shape[1]) >= minimum[1], mask, 0)
326
+ mask = jnp.where(jnp.arange(mask.shape[1]) <= maximum[1], mask, 0)
327
+ return jnp.any(mask)
206
328
 
207
329
  def update_position(idx, vec):
208
330
  # Compute the movements slightly strangely.
@@ -219,13 +341,17 @@ class Environment(SMAX):
219
341
  )
220
342
  # avoid going out of bounds
221
343
  new_pos = jnp.maximum(
222
- jnp.minimum(new_pos, jnp.array([self.map_width, self.map_height])),
344
+ jnp.minimum(
345
+ new_pos, jnp.array([self.map_width - 1, self.map_height - 1])
346
+ ),
223
347
  jnp.zeros((2,)),
224
348
  )
225
349
 
226
350
  #######################################################################
227
351
  ############################################ avoid going into obstacles
228
- clash = raster_crossing(pos, new_pos)
352
+ clash = raster_crossing(
353
+ pos, new_pos, self.terrain.building + self.terrain.water
354
+ )
229
355
  new_pos = jnp.where(clash, pos, new_pos)
230
356
 
231
357
  #######################################################################
@@ -263,14 +389,11 @@ class Environment(SMAX):
263
389
  attacked_idx = jax.lax.select(
264
390
  action < self.num_movement_actions, idx, attacked_idx
265
391
  )
266
-
392
+ distance = jnp.linalg.norm(
393
+ state.unit_positions[idx] - state.unit_positions[attacked_idx]
394
+ )
267
395
  attack_valid = (
268
- (
269
- jnp.linalg.norm(
270
- state.unit_positions[idx] - state.unit_positions[attacked_idx]
271
- )
272
- < self.unit_type_attack_ranges[state.unit_types[idx]]
273
- )
396
+ (distance <= self.unit_type_attack_ranges[state.unit_types[idx]])
274
397
  & state.unit_alive[idx]
275
398
  & state.unit_alive[attacked_idx]
276
399
  )
@@ -281,21 +404,28 @@ class Environment(SMAX):
281
404
  -self.unit_type_attacks[state.unit_types[idx]],
282
405
  0.0,
283
406
  )
407
+ health_diff = jnp.where(
408
+ state.unit_types[idx] == 1,
409
+ health_diff
410
+ * distance
411
+ / self.unit_type_attack_ranges[state.unit_types[idx]],
412
+ health_diff,
413
+ )
284
414
  # design choice based on the pysc2 randomness details.
285
415
  # See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
286
416
 
287
417
  #########################################################
288
418
  ############################### Add bystander health diff
289
419
 
290
- bystander_idxs = bystander_fn(attacked_idx) # TODO: use
291
- bystander_valid = (
292
- jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
293
- .astype(jnp.bool_) # type: ignore
294
- .astype(jnp.float32)
295
- )
296
- bystander_health_diff = (
297
- bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
298
- )
420
+ # bystander_idxs = bystander_fn(attacked_idx) # TODO: use
421
+ # bystander_valid = (
422
+ # jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
423
+ # .astype(jnp.bool_) # type: ignore
424
+ # .astype(jnp.float32)
425
+ # )
426
+ # bystander_health_diff = (
427
+ # bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
428
+ # )
299
429
 
300
430
  #########################################################
301
431
  #########################################################
@@ -319,29 +449,34 @@ class Environment(SMAX):
319
449
  health_diff,
320
450
  attacked_idx,
321
451
  cooldown_diff,
322
- (bystander_health_diff, bystander_idxs),
452
+ # (bystander_health_diff, bystander_idxs),
323
453
  )
324
454
 
325
455
  def perform_agent_action(idx, action, key):
326
456
  movement_action, attack_action = action
327
457
  new_pos = update_position(idx, movement_action)
328
- health_diff, attacked_idxes, cooldown_diff, (bystander) = (
329
- update_agent_health(idx, attack_action, key)
458
+ health_diff, attacked_idxes, cooldown_diff = update_agent_health(
459
+ idx, attack_action, key
330
460
  )
331
461
 
332
- return new_pos, (health_diff, attacked_idxes), cooldown_diff, bystander
462
+ return new_pos, (health_diff, attacked_idxes), cooldown_diff
333
463
 
334
464
  keys = jax.random.split(key, num=self.num_agents)
335
- pos, (health_diff, attacked_idxes), cooldown_diff, bystander = jax.vmap(
465
+ pos, (health_diff, attacked_idxes), cooldown_diff = jax.vmap(
336
466
  perform_agent_action
337
467
  )(jnp.arange(self.num_agents), actions, keys)
338
468
 
339
469
  # units push each other
340
470
  new_pos = self._our_push_units_away(pos, state.unit_types)
341
- clash = jax.vmap(raster_crossing)(pos, new_pos)
471
+ clash = jax.vmap(raster_crossing, in_axes=(0, 0, None))(
472
+ pos, new_pos, self.terrain.building + self.terrain.water
473
+ )
342
474
  pos = jax.vmap(jnp.where)(clash, pos, new_pos)
343
475
  # avoid going out of bounds
344
- pos = jnp.maximum(jnp.minimum(pos, jnp.array([self.map_width, self.map_height])),jnp.zeros((2,)),)
476
+ pos = jnp.maximum(
477
+ jnp.minimum(pos, jnp.array([self.map_width - 1, self.map_height - 1])), # type: ignore
478
+ jnp.zeros((2,)),
479
+ )
345
480
 
346
481
  # Multiple enemies can attack the same unit.
347
482
  # We have `(health_diff, attacked_idx)` pairs.
@@ -370,8 +505,8 @@ class Environment(SMAX):
370
505
  #########################################################
371
506
  ############################ subtracting bystander health
372
507
 
373
- _, bystander_health_diff = bystander
374
- unit_health -= bystander_health_diff.sum(axis=0) # might be axis=1
508
+ # _, bystander_health_diff = bystander
509
+ # unit_health -= bystander_health_diff.sum(axis=0) # might be axis=1
375
510
 
376
511
  #########################################################
377
512
  #########################################################
@@ -389,15 +524,30 @@ class Environment(SMAX):
389
524
  if __name__ == "__main__":
390
525
  n_envs = 4
391
526
 
392
- env = Environment(scenarios["default"])
527
+ n_allies = 10
528
+ scenario_kwargs = {
529
+ "allies_type": 0,
530
+ "n_allies": n_allies,
531
+ "enemies_type": 0,
532
+ "n_enemies": n_allies,
533
+ "place": "Vesterbro, Copenhagen, Denmark",
534
+ "size": 100,
535
+ "unit_starting_sectors": [
536
+ ([i for i in range(n_allies)], [0.0, 0.45, 0.1, 0.1]),
537
+ ([n_allies + i for i in range(n_allies)], [0.8, 0.5, 0.1, 0.1]),
538
+ ],
539
+ }
540
+ scenario = make_scenario(**scenario_kwargs)
541
+ env = Environment(scenario)
393
542
  rng, reset_rng = random.split(random.PRNGKey(0))
394
543
  reset_key = random.split(reset_rng, n_envs)
395
544
  obs, state = vmap(env.reset)(reset_key)
396
545
  state_seq = []
397
546
 
398
- print(state.unit_positions)
399
- exit()
547
+ import time
400
548
 
549
+ step = vmap(jit(env.step))
550
+ tic = time.time()
401
551
  for i in range(10):
402
552
  rng, act_rng, step_rng = random.split(rng, 3)
403
553
  act_key = random.split(act_rng, (len(env.agents), n_envs))
@@ -407,4 +557,5 @@ if __name__ == "__main__":
407
557
  }
408
558
  step_key = random.split(step_rng, n_envs)
409
559
  state_seq.append((step_key, state, act))
410
- obs, state, reward, done, infos = vmap(env.step)(step_key, state, act)
560
+ obs, state, reward, done, infos = step(step_key, state, act)
561
+ tic = time.time()
parabellum/geo.py ADDED
@@ -0,0 +1,130 @@
1
+ # %% geo.py
2
+ # script for geospatial level generation
3
+ # by: Noah Syrkis
4
+
5
+ # %% Imports
6
+ from parabellum import tps
7
+ import rasterio
8
+ from rasterio import features, transform
9
+ from geopy.geocoders import Nominatim
10
+ from geopy.distance import distance
11
+ import contextily as cx
12
+ import jax.numpy as jnp
13
+ import cartopy.crs as ccrs
14
+ from jaxtyping import Array
15
+ import numpy as np
16
+ from shapely import box
17
+ import osmnx as ox
18
+ import geopandas as gpd
19
+ from collections import namedtuple
20
+ from typing import Tuple
21
+ import matplotlib.pyplot as plt
22
+ import seaborn as sns
23
+ import os
24
+ from jax.scipy.signal import convolve
25
+
26
+ # %% Types
27
+ Coords = Tuple[float, float]
28
+ BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
29
+
30
+ # %% Constants
31
+ provider = cx.providers.Stadia.StamenTerrain( # type: ignore
32
+ api_key="86d0d32b-d2fe-49af-8db8-f7751f58e83f"
33
+ )
34
+ provider["url"] = provider["url"] + "?api_key={api_key}"
35
+ tags = {
36
+ "building": True,
37
+ "water": True,
38
+ "highway": True,
39
+ "landuse": [
40
+ "grass",
41
+ "forest",
42
+ "flowerbed",
43
+ "greenfield",
44
+ "village_green",
45
+ "recreation_ground",
46
+ ],
47
+ "leisure": "garden",
48
+ } # "road": True}
49
+
50
+
51
+ # %% Coordinate function
52
+ def get_coordinates(place: str) -> Coords:
53
+ geolocator = Nominatim(user_agent="parabellum")
54
+ point = geolocator.geocode(place)
55
+ return point.latitude, point.longitude # type: ignore
56
+
57
+
58
+ def get_bbox(place: str, buffer) -> BBox:
59
+ """Get bounding box from place name in crs 4326."""
60
+ coords = get_coordinates(place)
61
+ north = distance(meters=buffer).destination(coords, bearing=0).latitude
62
+ south = distance(meters=buffer).destination(coords, bearing=180).latitude
63
+ east = distance(meters=buffer).destination(coords, bearing=90).longitude
64
+ west = distance(meters=buffer).destination(coords, bearing=270).longitude
65
+ return BBox(north, south, east, west)
66
+
67
+
68
+ def basemap_fn(bbox: BBox, gdf) -> Array:
69
+ fig, ax = plt.subplots(figsize=(20, 20), subplot_kw={"projection": ccrs.Mercator()})
70
+ gdf.plot(ax=ax, color="black", alpha=0, edgecolor="black") # type: ignore
71
+ cx.add_basemap(ax, crs=gdf.crs, source=provider, zoom="auto") # type: ignore
72
+ bbox = gdf.total_bounds
73
+ ax.set_extent([bbox[0], bbox[2], bbox[1], bbox[3]], crs=ccrs.Mercator()) # type: ignore
74
+ plt.axis("off")
75
+ plt.tight_layout(pad=0)
76
+ fig.canvas.draw()
77
+ image = jnp.array(fig.canvas.renderer._renderer) # type: ignore
78
+ plt.close(fig)
79
+ return image
80
+
81
+
82
+ def geography_fn(place, buffer=400):
83
+ bbox = get_bbox(place, buffer)
84
+ map_data = ox.features_from_bbox(bbox=bbox, tags=tags)
85
+ gdf = gpd.GeoDataFrame(map_data)
86
+ gdf = gdf.clip(box(bbox.west, bbox.south, bbox.east, bbox.north)).to_crs(
87
+ "EPSG:3857"
88
+ )
89
+ raster = raster_fn(gdf, shape=(buffer, buffer))
90
+ basemap = jnp.rot90(basemap_fn(bbox, gdf), 3)
91
+ # 0: building", 1: "water", 2: "highway", 3: "forest", 4: "garden"
92
+ kernel = jnp.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
93
+ trans = lambda x: jnp.rot90(x, 3)
94
+ terrain = tps.Terrain(
95
+ building=trans(raster[0]),
96
+ water=trans(
97
+ raster[1] - convolve(raster[1] * raster[2], kernel, mode="same") > 0
98
+ ),
99
+ forest=trans(jnp.logical_or(raster[3], raster[4])),
100
+ basemap=basemap,
101
+ )
102
+ return terrain
103
+
104
+
105
+ def raster_fn(gdf, shape) -> Array:
106
+ bbox = gdf.total_bounds
107
+ t = transform.from_bounds(*bbox, *shape) # type: ignore
108
+ raster = jnp.array([feature_fn(t, feature, gdf, shape) for feature in tags])
109
+ return raster
110
+
111
+
112
+ def feature_fn(t, feature, gdf, shape):
113
+ if feature not in gdf.columns:
114
+ return jnp.zeros(shape)
115
+ gdf = gdf[~gdf[feature].isna()]
116
+ raster = features.rasterize(gdf.geometry, out_shape=shape, transform=t, fill=0) # type: ignore
117
+ return raster
118
+
119
+
120
+ # %%
121
+ if __name__ == "__main__":
122
+ place = "Thun, Switzerland"
123
+ terrain = geography_fn(place, 300)
124
+
125
+ fig, axes = plt.subplots(1, 5, figsize=(20, 20))
126
+ axes[0].imshow(terrain.building, cmap="gray")
127
+ axes[1].imshow(terrain.water, cmap="gray")
128
+ axes[2].imshow(terrain.forest, cmap="gray")
129
+ axes[3].imshow(terrain.building + terrain.water + terrain.forest)
130
+ axes[4].imshow(terrain.basemap)
parabellum/pcg.py ADDED
@@ -0,0 +1,53 @@
1
+ # pcg.py
2
+ # procedural content generation
3
+ # by: Noah Syrkis
4
+
5
+ # %% Imports
6
+ from jax import random, vmap
7
+ import jax.numpy as jnp
8
+ import matplotlib.pyplot as plt
9
+ from functools import partial
10
+
11
+
12
+ # %% Functions
13
+ seed = 0
14
+ n = 100
15
+ rng = random.PRNGKey(seed)
16
+ Y = random.uniform(rng, (n,))
17
+
18
+
19
+ def g(t):
20
+ return (1 - jnp.cos(jnp.pi * t)) / 2
21
+
22
+
23
+ def lerp(a, b, t):
24
+ t -= jnp.floor(t) # the fractional part of t
25
+ return (1 - t) * a + t * b
26
+
27
+
28
+ def cerp(a, b, t):
29
+ t -= jnp.floor(t) # the fractional part of t
30
+ return g(1 - t) * a + g(t) * b
31
+
32
+
33
+ def body_fn(x):
34
+ i = jnp.floor(x).astype(jnp.uint8)
35
+ return cerp(Y[i], Y[i + 1], x)
36
+
37
+
38
+ @partial(vmap, in_axes=(None, 0, None))
39
+ def noise_fn(y, t, n):
40
+ return y[t % n]
41
+
42
+
43
+ @vmap
44
+ def perlin_fn(t):
45
+ return noise_fn(Y, t * jnp.arange(n * 3), n)
46
+
47
+
48
+ xs = jnp.linspace(0, 1, 1000)
49
+ noise = perlin_fn(2 ** jnp.arange(3)).sum(0)
50
+
51
+ fig, ax = plt.subplots(figsize=(20, 4), dpi=100)
52
+ ax.set_ylim(0, 1)
53
+ ax.plot(noise / noise.max())
parabellum/run.py CHANGED
@@ -21,7 +21,7 @@ bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
21
21
 
22
22
 
23
23
  # types
24
- State = jaxmarl.environments.smax.smax_env.State # type: ignore
24
+ State = jaxmarl.environments.smax.smax_env.State # type: ignore
25
25
  Obs = Reward = Done = Action = Dict[str, jnp.ndarray]
26
26
  StateSeq = List[Tuple[jnp.ndarray, State, Action]]
27
27
 
@@ -0,0 +1,117 @@
1
+ # %%
2
+ import numpy as np
3
+ import jax.numpy as jnp
4
+ from parabellum import tps
5
+
6
+
7
+ # %%
8
+ def map_raster_from_line(raster, line, size):
9
+ x0, y0, dx, dy = line
10
+ x0 = int(x0*size)
11
+ y0 = int(y0*size)
12
+ dx = int(dx*size)
13
+ dy = int(dy*size)
14
+ max_T = int(2**0.5 * size)
15
+ for t in range(max_T+1):
16
+ alpha = t/float(max_T)
17
+ x = x0 if dx == 0 else int((1 - alpha) * x0 + alpha * (x0+dx))
18
+ y = y0 if dy == 0 else int((1 - alpha) * y0 + alpha * (y0+dy))
19
+ if 0<=x<size and 0<=y<size:
20
+ raster[x, y] = 1
21
+ return raster
22
+
23
+
24
+ # %%
25
+ def map_raster_from_rect(raster, rect, size):
26
+ x0, y0, dx, dy = rect
27
+ x0 = int(x0*size)
28
+ y0 = int(y0*size)
29
+ dx = int(dx*size)
30
+ dy = int(dy*size)
31
+ raster[x0:x0+dx, y0:y0+dy] = 1
32
+ return raster
33
+
34
+
35
+ # %%
36
+ building_color = jnp.array([201,199,198, 255])
37
+ water_color = jnp.array([193, 237, 254, 255])
38
+ forest_color = jnp.array([197,214,185, 255])
39
+ empty_color = jnp.array([255, 255, 255, 255])
40
+
41
+ def make_terrain(terrain_args, size):
42
+ args = {}
43
+ for key, config in terrain_args.items():
44
+ raster = np.zeros((size, size))
45
+ if config is not None:
46
+ for elem in config:
47
+ if "line" in elem:
48
+ raster = map_raster_from_line(raster, elem["line"], size)
49
+ elif "rect" in elem:
50
+ raster = map_raster_from_rect(raster, elem["rect"], size)
51
+ args[key] = jnp.array(raster.T)
52
+ basemap = jnp.where(args["building"][:,:,None], jnp.tile(building_color, (size, size, 1)), jnp.tile(empty_color, (size,size, 1)))
53
+ basemap = jnp.where(args["water"][:,:,None], jnp.tile(water_color, (size, size, 1)), basemap)
54
+ basemap = jnp.where(args["forest"][:,:,None], jnp.tile(forest_color, (size, size, 1)), basemap)
55
+ args["basemap"] = basemap
56
+ return tps.Terrain(**args)
57
+
58
+
59
+ # %%
60
+ db = {
61
+ "blank": {'building': None, 'water': None, 'forest': None},
62
+ "F": {'building': [{"line": [0.25, 0.33, 0.5, 0]}, {"line":[0.75, 0.33, 0., 0.25]}, {"line":[0.50, 0.33, 0., 0.25]}], 'water': None, 'forest': None},
63
+ "stronghold": {'building': [
64
+ {"line":[0.2, 0.275, 0.2, 0.]}, {"line":[0.2, 0.275, 0.0, 0.2]},
65
+ {"line":[0.4, 0.275, 0.0, 0.2]}, {"line":[0.2, 0.475, 0.2, 0.]},
66
+
67
+ {"line":[0.2, 0.525, 0.2, 0.]}, {"line": [0.2, 0.525, 0.0, 0.2]},
68
+ {"line":[0.4, 0.525, 0.0, 0.2]}, {"line": [0.2, 0.725, 0.525, 0.]},
69
+
70
+ {"line":[0.75, 0.25, 0., 0.2]}, {"line":[0.75, 0.55, 0., 0.19]},
71
+ {"line":[0.6, 0.25, 0.15, 0.]}], 'water': None, 'forest': None},
72
+ "playground": {'building': [{"line":[0.5, 0.5, 0.5, 0.]}], 'water': None, 'forest': None},
73
+ "water_park": {
74
+ 'building': [{"line":[0.5, 0.5, 0.5, 0.]}],
75
+ "water": [{"rect":[0., 0.8, 0.1, 0.05]}, {"rect": [0.2, 0.8, 0.8, 0.05]}],
76
+ "forest": [{"rect": [0., 0., 1., 0.2]}]
77
+ },
78
+ "triangle": {'building': [{"line": [0.33, 0., 0., 1.]}, {"line": [0.66, 0., 0., 1.]}], 'water': None, 'forest': None},
79
+ "u_shape": {
80
+ 'building': [],
81
+ "water": [{"rect": [0.15, 0.2, 0.1, 0.5]}, {"rect": [0.4, 0.2, 0.1, 0.5]}, {"rect": [0.2, 0.2, 0.25, 0.1]}],
82
+ "forest": []
83
+ },
84
+ }
85
+
86
+ # %% [raw]
87
+ # import matplotlib.pyplot as plt
88
+ # size = 50
89
+ # raster = np.zeros((size, size))
90
+ # rect = [0.2, 0.3, 0.05, 0.4]
91
+ # raster = map_raster_from_rect(raster, rect, size)
92
+ # rect = [0.4, 0.3, 0.05, 0.4]
93
+ # raster = map_raster_from_rect(raster, rect, size)
94
+ # rect = [0.2, 0.3, 0.25, 0.05]
95
+ # raster = map_raster_from_rect(raster, rect, size)
96
+ # rect = [0.2, 0.7, 0.25, 0.05]
97
+ # raster = map_raster_from_rect(raster, rect, size)
98
+ # rect = [0.6, 0.3, 0.4, 0.45]
99
+ # raster = map_raster_from_rect(raster, rect, size)
100
+ # plt.imshow(jnp.rot90(raster))
101
+
102
+ # %% [markdown]
103
+ # # Main
104
+
105
+ # %%
106
+ if __name__ == "__main__":
107
+ import matplotlib.pyplot as plt
108
+
109
+ # %%
110
+ terrain = make_terrain(db["u_shape"], size=50)
111
+
112
+ # %%
113
+ plt.imshow(jnp.rot90(terrain.basemap))
114
+
115
+ # %%
116
+
117
+ # %%
parabellum/tps.py ADDED
@@ -0,0 +1,17 @@
1
+ # %%
2
+ # tps.py
3
+ # parabellum types and dataclasses
4
+ # by: Noah Syrkis
5
+
6
+ # %% Imports
7
+ from chex import dataclass
8
+ from jaxtyping import Array
9
+
10
+
11
+ # %% Dataclasses
12
+ @dataclass
13
+ class Terrain:
14
+ building: Array
15
+ water: Array
16
+ forest: Array
17
+ basemap: Array
parabellum/vis.py CHANGED
@@ -3,26 +3,19 @@ Visualizer for the Parabellum environment
3
3
  """
4
4
 
5
5
  # Standard library imports
6
- from functools import partial
7
- from typing import Optional, List, Tuple
6
+ from typing import Optional, Tuple
8
7
  import cv2
9
- from PIL import Image
10
8
 
11
9
  # JAX and JAX-related imports
12
10
  import jax
13
11
  from chex import dataclass
14
- import chex
15
- from jax import vmap, tree_util, Array, jit
12
+ from jax import vmap, Array
16
13
  import jax.numpy as jnp
17
- from jaxmarl.environments.multi_agent_env import MultiAgentEnv
18
- from jaxmarl.environments.smax import SMAX
19
14
  from jaxmarl.viz.visualizer import SMAXVisualizer
20
15
 
21
16
  # Third-party imports
22
17
  import numpy as np
23
18
  import pygame
24
- import cv2
25
- from tqdm import tqdm
26
19
 
27
20
  # Local imports
28
21
  import parabellum as pb
@@ -35,12 +28,12 @@ class Skin:
35
28
  maskmap: Array # maskmap of buildings
36
29
  bg: Tuple[int, int, int] = (255, 255, 255)
37
30
  fg: Tuple[int, int, int] = (0, 0, 0)
38
- ally: Tuple[int, int, int] = (0, 255, 0)
39
- enemy: Tuple[int, int, int] = (255, 0, 0)
40
- pad: int = 100
41
- size: int = 1000 # excluding padding
31
+ ally: Tuple[int, int, int] = (0, 255, 0)
32
+ enemy: Tuple[int, int, int] = (255, 0, 0)
33
+ pad: int = 100
34
+ size: int = 1000 # excluding padding
42
35
  fps: int = 24
43
- vis_size: int = 1000 # size of the map in Vis (exluding padding)
36
+ vis_size: int = 1000 # size of the map in Vis (exluding padding)
44
37
  scale: Optional[float] = None
45
38
 
46
39
 
@@ -57,10 +50,23 @@ class Visualizer(SMAXVisualizer):
57
50
  self.env = env
58
51
 
59
52
  def animate(self, save_fname: Optional[str] = "output/parabellum", view=None):
60
- expanded_state_seq, expanded_action_seq = expand_fn(self.env, self.state_seq, self.action_seq)
61
- state_seq_seq, action_seq_seq = unbatch_fn(expanded_state_seq, expanded_action_seq)
62
- for idx, (state_seq, action_seq) in enumerate(zip(state_seq_seq, action_seq_seq)):
63
- animate_fn(self.env, self.skin, self.image, state_seq, action_seq, f"{save_fname}_{idx}.mp4")
53
+ expanded_state_seq, expanded_action_seq = expand_fn(
54
+ self.env, self.state_seq, self.action_seq
55
+ )
56
+ state_seq_seq, action_seq_seq = unbatch_fn(
57
+ expanded_state_seq, expanded_action_seq
58
+ )
59
+ for idx, (state_seq, action_seq) in enumerate(
60
+ zip(state_seq_seq, action_seq_seq)
61
+ ):
62
+ animate_fn(
63
+ self.env,
64
+ self.skin,
65
+ self.image,
66
+ state_seq,
67
+ action_seq,
68
+ f"{save_fname}_{idx}.mp4",
69
+ )
64
70
 
65
71
 
66
72
  # functions
@@ -70,22 +76,29 @@ def animate_fn(env, skin, image, state_seq, action_seq, save_fname):
70
76
  for idx, (state_tup, action) in enumerate(zip(state_seq, action_seq)):
71
77
  frames += [frame_fn(env, skin, image, state_tup[1], action, idx)]
72
78
  # use cv2 to write frames to video
73
- fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore
74
- out = cv2.VideoWriter(save_fname, fourcc, skin.fps, (skin.size + skin.pad * 2, skin.size + skin.pad * 2))
79
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v") # type: ignore
80
+ out = cv2.VideoWriter(
81
+ save_fname,
82
+ fourcc,
83
+ skin.fps,
84
+ (skin.size + skin.pad * 2, skin.size + skin.pad * 2),
85
+ )
75
86
  for frame in frames:
76
87
  out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
77
88
  out.release()
78
89
  pygame.quit()
79
90
 
80
91
 
81
- def init_frame(env, skin, image, state: pb.State, action: Array, idx: int) -> pygame.Surface:
92
+ def init_frame(
93
+ env, skin, image, state: pb.State, action: Array, idx: int
94
+ ) -> pygame.Surface:
82
95
  dims = (skin.size + skin.pad * 2, skin.size + skin.pad * 2)
83
96
  frame = pygame.Surface(dims, pygame.SRCALPHA | pygame.HWSURFACE)
84
97
  return frame
85
98
 
86
99
 
87
100
  def transform_frame(env, skin, frame):
88
- #frame = np.rot90(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 2)
101
+ # frame = np.rot90(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 2)
89
102
  frame = np.flip(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 0)
90
103
  return frame
91
104
 
@@ -102,7 +115,7 @@ def frame_fn(env, skin, image, state: pb.State, action: Array, idx: int) -> np.n
102
115
 
103
116
 
104
117
  def render_background(env, skin, image, frame, state, action):
105
- coords = (skin.pad-5, skin.pad-5, skin.size+10, skin.size+10)
118
+ coords = (skin.pad - 5, skin.pad - 5, skin.size + 10, skin.size + 10)
106
119
  frame.fill(skin.bg)
107
120
  frame.blit(image, coords)
108
121
  pygame.draw.rect(frame, skin.fg, coords, 3)
@@ -116,6 +129,7 @@ def render_action(env, skin, image, frame, state, action):
116
129
  def render_bullet(env, skin, image, frame, state, action):
117
130
  return frame
118
131
 
132
+
119
133
  def render_agents(env, skin, image, frame, state, action):
120
134
  units = state.unit_positions, state.unit_teams, state.unit_types, state.unit_health
121
135
  for idx, (pos, team, kind, health) in enumerate(zip(*units)):
@@ -136,7 +150,11 @@ def text_fn(text):
136
150
 
137
151
  def image_fn(skin: Skin): # TODO:
138
152
  """Create an image for background (basemap or maskmap)"""
139
- motif = cv2.resize(np.array(skin.maskmap.T), (skin.size, skin.size), interpolation=cv2.INTER_NEAREST).astype(np.uint8)
153
+ motif = cv2.resize(
154
+ np.array(skin.maskmap.T),
155
+ (skin.size, skin.size),
156
+ interpolation=cv2.INTER_NEAREST,
157
+ ).astype(np.uint8)
140
158
  motif = (motif > 0).astype(np.uint8)
141
159
  image = np.zeros((skin.size, skin.size, 3), dtype=np.uint8) + skin.bg
142
160
  image[motif == 1] = skin.fg
@@ -150,7 +168,9 @@ def unbatch_fn(state_seq, action_seq):
150
168
  if is_multi_run(state_seq):
151
169
  n_envs = state_seq[0][1].unit_positions.shape[0]
152
170
  state_seq_seq = [jax.tree_map(lambda x: x[i], state_seq) for i in range(n_envs)]
153
- action_seq_seq = [jax.tree_map(lambda x: x[i], action_seq) for i in range(n_envs)]
171
+ action_seq_seq = [
172
+ jax.tree_map(lambda x: x[i], action_seq) for i in range(n_envs)
173
+ ]
154
174
  else:
155
175
  state_seq_seq = [state_seq]
156
176
  action_seq_seq = [action_seq]
@@ -161,7 +181,9 @@ def expand_fn(env, state_seq, action_seq):
161
181
  """Expand the state sequence"""
162
182
  fn = env.expand_state_seq
163
183
  state_seq = vmap(fn)(state_seq) if is_multi_run(state_seq) else fn(state_seq)
164
- action_seq = [action_seq[i // env.world_steps_per_env_step] for i in range(len(state_seq))]
184
+ action_seq = [
185
+ action_seq[i // env.world_steps_per_env_step] for i in range(len(state_seq))
186
+ ]
165
187
  return state_seq, action_seq
166
188
 
167
189
 
@@ -1,42 +1,44 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: parabellum
3
- Version: 0.2.26
3
+ Version: 0.3.1
4
4
  Summary: Parabellum environment for parallel warfare simulation
5
5
  Home-page: https://github.com/syrkis/parabellum
6
6
  License: MIT
7
7
  Keywords: warfare,simulation,parallel,environment
8
8
  Author: Noah Syrkis
9
9
  Author-email: desk@syrkis.com
10
- Requires-Python: >=3.11,<4.0
10
+ Requires-Python: >=3.11,<3.12
11
11
  Classifier: License :: OSI Approved :: MIT License
12
12
  Classifier: Programming Language :: Python :: 3
13
13
  Classifier: Programming Language :: Python :: 3.11
14
- Classifier: Programming Language :: Python :: 3.12
14
+ Requires-Dist: cartopy (>=0.23.0,<0.24.0)
15
15
  Requires-Dist: contextily (>=1.6.0,<2.0.0)
16
16
  Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
17
+ Requires-Dist: einops (>=0.8.0,<0.9.0)
17
18
  Requires-Dist: folium (>=0.17.0,<0.18.0)
18
- Requires-Dist: geopandas (>=1.0.0,<2.0.0)
19
19
  Requires-Dist: geopy (>=2.4.1,<3.0.0)
20
20
  Requires-Dist: ipykernel (>=6.29.5,<7.0.0)
21
21
  Requires-Dist: jax (==0.4.17)
22
22
  Requires-Dist: jaxmarl (==0.0.3)
23
+ Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
23
24
  Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
24
25
  Requires-Dist: moviepy (>=1.0.3,<2.0.0)
25
26
  Requires-Dist: numpy (<2)
26
27
  Requires-Dist: opencv-python (>=4.10.0.84,<5.0.0.0)
27
- Requires-Dist: osmnx (>=1.9.3,<2.0.0)
28
+ Requires-Dist: osmnx (==2.0.0b0)
28
29
  Requires-Dist: pandas (>=2.2.2,<3.0.0)
29
30
  Requires-Dist: poetry (>=1.8.3,<2.0.0)
30
31
  Requires-Dist: pygame (>=2.5.2,<3.0.0)
31
32
  Requires-Dist: rasterio (>=1.3.10,<2.0.0)
32
33
  Requires-Dist: seaborn (>=0.13.2,<0.14.0)
34
+ Requires-Dist: stadiamaps (>=3.2.1,<4.0.0)
33
35
  Requires-Dist: tqdm (>=4.66.4,<5.0.0)
34
36
  Project-URL: Repository, https://github.com/syrkis/parabellum
35
37
  Description-Content-Type: text/markdown
36
38
 
37
39
  # Parabellum
38
40
 
39
- Ultra-scalable JaxMARL based warfare simulation engine developed with Armasuisse funding.
41
+ Ultra-scalable JaxMARL based warfare simulation engine.
40
42
 
41
43
  [![Documentation Status](https://readthedocs.org/projects/parabellum/badge/?version=latest)](https://parabellum.readthedocs.io/en/latest/?badge=latest)
42
44
 
@@ -93,3 +95,4 @@ Full documentation: [parabellum.readthedocs.io](https://parabellum.readthedocs.i
93
95
  ## License
94
96
 
95
97
  MIT
98
+
@@ -0,0 +1,13 @@
1
+ parabellum/__init__.py,sha256=hIOLir7wgaf_HU4j8uos7PaCrofqPQcr3FcMlBsZyr8,406
2
+ parabellum/aid.py,sha256=BPabjN4BUq1HRhkwbc9pCNsXSF_ALiG8W8cHWTWeEH4,900
3
+ parabellum/env.py,sha256=0mDqQ7-OI-oufBMMBoUt72Kf5OvHr9thilLGzszlICY,22569
4
+ parabellum/geo.py,sha256=PwEwspOppTPrHIXDZB_nGPTnVFIvDzbh2WtqzVKMUaM,4198
5
+ parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
6
+ parabellum/pcg.py,sha256=d8KC_lbc4WUUUPaTdPJSx27VMGioys3jSGOWJ-2EahU,968
7
+ parabellum/run.py,sha256=Q53__AxzROZNgfZLVU5LDdcT61UMCkmQ_Q5wWUIrnqo,3473
8
+ parabellum/terrain_db.py,sha256=XTKlpLAi3ZwoVw4-KS-Eh15NKsBKP-yt8v6FJGUtwdM,3960
9
+ parabellum/tps.py,sha256=of-RBdelAbNCHQZd1I22RWmZkwUEh6f161mx0X_G2tE,257
10
+ parabellum/vis.py,sha256=ABHveJj0fLRWkxOv3LFIXK20QtdGhjskuFLsp7iTFu0,6185
11
+ parabellum-0.3.1.dist-info/METADATA,sha256=RrSY6CrhwpVlbdJzacX2iVkh_MEgtZkZZCPHJWJJjqo,2707
12
+ parabellum-0.3.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
+ parabellum-0.3.1.dist-info/RECORD,,
parabellum/map.py DELETED
@@ -1,100 +0,0 @@
1
- # map.py
2
- # parabellum map functions
3
- # by: Noah Syrkis
4
-
5
- # imports
6
- import jax.numpy as jnp
7
- from geopy.geocoders import Nominatim
8
- import geopandas as gpd
9
- import osmnx as ox
10
- import contextily as cx
11
- import matplotlib.pyplot as plt
12
- from rasterio import features
13
- import rasterio.transform
14
- from typing import Optional, Tuple
15
- from geopy.location import Location
16
- from shapely.geometry import Point
17
- import os
18
- import pickle
19
-
20
- # constants
21
- geolocator = Nominatim(user_agent="parabellum")
22
- BUILDING_TAGS = {"building": True}
23
-
24
- def get_location(place: str) -> Tuple[float, float]:
25
- """Get coordinates for a given place."""
26
- coords: Optional[Location] = geolocator.geocode(place) # type: ignore
27
- if coords is None:
28
- raise ValueError(f"Could not geocode the place: {place}")
29
- return (coords.latitude, coords.longitude)
30
-
31
- def get_building_geometry(point: Tuple[float, float], size: int) -> gpd.GeoDataFrame:
32
- """Get building geometry for a given point and size."""
33
- geometry = ox.features_from_point(point, tags=BUILDING_TAGS, dist=size // 2)
34
- return gpd.GeoDataFrame(geometry).set_crs("EPSG:4326")
35
-
36
- def rasterize_geometry(gdf: gpd.GeoDataFrame, size: int) -> jnp.ndarray:
37
- """Rasterize geometry and return as a JAX array."""
38
- w, s, e, n = gdf.total_bounds
39
- transform = rasterio.transform.from_bounds(w, s, e, n, size, size)
40
- raster = features.rasterize(gdf.geometry, out_shape=(size, size), transform=transform)
41
- return jnp.array(jnp.flip(raster, 0) ).astype(jnp.uint8)
42
-
43
- # +
44
- def get_from_cache(place, size):
45
- if os.path.exists("./cache"):
46
- name = str(hash((place, size))) + ".pk"
47
- if os.path.exists("./cache/" + name):
48
- with open("./cache/" + name, "rb") as f:
49
- (mask, base) = pickle.load(f)
50
- return (mask, base.astype(jnp.int64))
51
- return (None, None)
52
-
53
- def save_in_cache(place, size, mask, base):
54
- if not os.path.exists("./cache"):
55
- os.makedirs("./cache")
56
- name = str(hash((place, size))) + ".pk"
57
- with open("./cache/" + name, "wb") as f:
58
- pickle.dump((mask, base), f)
59
-
60
- def terrain_fn(place: str, size: int = 1000, with_cache: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]:
61
- """Returns a rasterized map of buildings for a given location."""
62
- if with_cache:
63
- mask, base = get_from_cache(place, size)
64
- if not with_cache or mask is None:
65
- point = get_location(place)
66
- gdf = get_building_geometry(point, size)
67
- mask = rasterize_geometry(gdf, size)
68
- base = get_basemap(place, size)
69
- if with_cache:
70
- save_in_cache(place, size, mask, base)
71
- return mask, base
72
-
73
-
74
- # -
75
-
76
- def get_basemap(place: str, size: int = 1000) -> jnp.ndarray:
77
- """Returns a basemap for a given place as a JAX array."""
78
- point = get_location(place)
79
- gdf = get_building_geometry(point, size)
80
- basemap, _ = cx.bounds2img(*gdf.total_bounds, ll=True)
81
- # get the middle size x size square
82
- basemap = basemap[(basemap.shape[0] - size) // 2:(basemap.shape[0] + size) // 2,
83
- (basemap.shape[1] - size) // 2:(basemap.shape[1] + size) // 2]
84
- return basemap # jnp.array(jnp.rot90(basemap, 2)).astype(jnp.uint8)
85
-
86
-
87
- if __name__ == "__main__":
88
- place = "Cauvicourt, 14190, France"
89
- mask, base = terrain_fn(place, 500)
90
-
91
- fig, ax = plt.subplots(1, 3, figsize=(15, 5))
92
- ax[0].imshow(jnp.flip(mask,0)) # type: ignore
93
- ax[1].imshow(base) # type: ignore
94
- ax[2].imshow(base) # type: ignore
95
- ax[2].imshow(jnp.flip(mask,0), alpha=jnp.flip(mask,0)) # type: ignore
96
- plt.show()
97
-
98
-
99
-
100
-
@@ -1,10 +0,0 @@
1
- parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
2
- parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
3
- parabellum/env.py,sha256=H1YGAtUYNJd8OHnZ3sOEXbag5L0WjtJHBGL8ymGPvoE,16898
4
- parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
5
- parabellum/map.py,sha256=UwMqwySasX5oLw9v5YMJARAPwvQThLTRW36NpbwvBC8,3564
6
- parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
7
- parabellum/vis.py,sha256=q7_OIMjzt-7nBOojVVW7Wiiq9ojsjaltIsH6eOOxPKk,6116
8
- parabellum-0.2.26.dist-info/METADATA,sha256=AJwdmHRRPG2MosgufeWqQH7LsDRtTvFFGMh1azey9zA,2671
9
- parabellum-0.2.26.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
10
- parabellum-0.2.26.dist-info/RECORD,,