parabellum 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/__init__.py CHANGED
@@ -1,4 +1,18 @@
1
- from .env import Parabellum, Scenario, scenarios
2
- from .vis import Visualizer
1
+ from .env import Environment, Scenario, scenarios, make_scenario, State
2
+ from .vis import Visualizer, Skin
3
+ from .map import terrain_fn
4
+ from .gun import bullet_fn
5
+ # from .aid import aid
6
+ # from .run import run
3
7
 
4
- __all__ = ["Parabellum", "Visualizer", "Scenario", "scenarios"]
8
__all__ = [
    # environment, scenarios and state
    "Environment", "Scenario", "scenarios", "make_scenario", "State",
    # rendering
    "Visualizer", "Skin",
    # map and bullet helpers
    "terrain_fn", "bullet_fn",
]
parabellum/aid.py ADDED
@@ -0,0 +1,5 @@
1
+ # aid.py
2
+ # what you call utils.py when you want file names to be 3 letters
3
+ # by: Noah Syrkis
4
+
5
+ # imports
parabellum/env.py CHANGED
@@ -3,13 +3,13 @@
3
3
  import jax.numpy as jnp
4
4
  import jax
5
5
  import numpy as np
6
- from jax import random
6
+ from jax import random, Array
7
7
  from jax import jit
8
8
  from flax.struct import dataclass
9
9
  import chex
10
10
  from jax import vmap
11
- from jaxmarl.environments.smax.smax_env import State, SMAX
12
- from typing import Tuple, Dict
11
+ from jaxmarl.environments.smax.smax_env import SMAX
12
+ from typing import Tuple, Dict, cast
13
13
  from functools import partial
14
14
 
15
15
 
@@ -17,11 +17,8 @@ from functools import partial
17
17
  class Scenario:
18
18
  """Parabellum scenario"""
19
19
 
20
- terrain_raster: chex.Array
21
-
22
- obstacle_coords: chex.Array # TODO: use map instead of obstacles
23
- obstacle_deltas: chex.Array
24
-
20
+ place: str
21
+ terrain_raster: jnp.ndarray
25
22
  unit_types: chex.Array
26
23
  num_allies: int
27
24
  num_enemies: int
@@ -29,13 +26,25 @@ class Scenario:
29
26
  smacv2_position_generation: bool = False
30
27
  smacv2_unit_type_generation: bool = False
31
28
 
29
@dataclass
class State:
    """Environment state constructed by ``Environment.reset``.

    NOTE(review): the field set appears to mirror jaxmarl's SMAX ``State`` so the
    inherited SMAX methods can consume it — confirm when upgrading jaxmarl.
    """

    unit_positions: Array  # per-unit positions; reset concatenates ally then enemy spawns
    unit_alive: Array  # presumably per-unit alive flags — confirm
    unit_teams: Array  # per-unit team id: 0 for the first num_allies units, 1 after
    unit_health: Array  # per-unit health, initialised from unit_type_health at reset
    unit_types: Array  # per-unit type index, taken from the scenario at reset
    unit_weapon_cooldowns: Array  # per-unit weapon cooldowns (zeros at reset)
    prev_movement_actions: Array
    prev_attack_actions: Array
    time: int  # presumably the environment step counter — confirm
    terminal: bool  # episode-finished flag (False at reset)
41
+
32
42
 
33
43
  # default scenario
34
44
  scenarios = {
35
45
  "default": Scenario(
46
+ "Identity Town",
36
47
  jnp.eye(64, dtype=jnp.uint8),
37
- jnp.array([[80, 0], [16, 12]]),
38
- jnp.array([[0, 80], [0, 20]]),
39
48
  jnp.zeros((19,), dtype=jnp.uint8),
40
49
  9,
41
50
  10,
@@ -43,57 +52,63 @@ scenarios = {
43
52
  }
44
53
 
45
54
 
46
- class Parabellum(SMAX):
55
def make_scenario(place, terrain_raster, num_allies=9, num_enemies=10):
    """Create a scenario in which every unit has unit type 0.

    Args:
        place: name of the location the terrain belongs to.
        terrain_raster: 2-D terrain array used as the map.
        num_allies: number of units on the ally team.
        num_enemies: number of units on the enemy team.

    Returns:
        A ``Scenario`` with one ``uint8`` zero unit type per agent.
    """
    all_type_zero = jnp.zeros((num_allies + num_enemies,), dtype=jnp.uint8)
    return Scenario(
        place=place,
        terrain_raster=terrain_raster,
        unit_types=all_type_zero,
        num_allies=num_allies,
        num_enemies=num_enemies,
    )
60
+
61
+
62
def spawn_fn(pool, offset: jnp.ndarray, n: int, rng: jnp.ndarray):
    """Spawn ``n`` agents on distinct cells drawn from ``pool``.

    Args:
        pool: pair of index arrays ``(rows, cols)`` (as returned by
            ``jnp.nonzero``) listing the allowed cells.
        offset: added to every spawn position (e.g. a sector's origin).
        n: number of agents; must not exceed the number of cells in ``pool``.
        rng: PRNG key.

    Returns:
        ``(n, 2)`` float array: chosen cell + uniform jitter in ``[0, 0.5)`` + offset.
    """
    _, pick_key, jitter_key = random.split(rng, 3)
    rows, cols = pool[0], pool[1]

    # sample n distinct cells without replacement
    chosen = random.choice(pick_key, rows.shape[0], (n,), replace=False)
    cells = jnp.stack([rows[chosen], cols[chosen]], axis=1)

    jitter = 0.5 * random.uniform(jitter_key, (n, 2))
    return cells + jitter + offset
72
+
73
+
74
def sector_fn(terrain: jnp.ndarray, sector_id: int, grid: int = 5):
    """Return the free cells of one sector of ``terrain`` and that sector's origin.

    The terrain is divided into a ``grid`` x ``grid`` layout of equally sized
    sectors, numbered row-major from 0 to ``grid * grid - 1``. Cells equal to 0
    count as free (presumably "no building" — confirm against the raster source).

    Args:
        terrain: 2-D terrain raster.
        sector_id: sector index in ``[0, grid * grid)``.
        grid: number of sectors along each axis (default 5, the previous
            hard-coded value).

    Returns:
        ``(free_cells, offset)`` where ``free_cells`` is the ``jnp.nonzero``
        tuple of free-cell indices relative to the sector and ``offset`` is the
        sector's ``[first-axis, second-axis]`` origin in ``terrain``.

    Raises:
        ValueError: if ``grid`` is not positive or ``sector_id`` is out of range.
            Previously an out-of-range id silently produced an empty pool that
            only failed later when spawning.
    """
    if grid < 1:
        raise ValueError(f"grid must be positive, got {grid}")
    if not 0 <= sector_id < grid * grid:
        raise ValueError(f"sector_id must be in [0, {grid * grid}), got {sector_id}")
    width, height = terrain.shape
    coordx = sector_id // grid * width // grid
    coordy = sector_id % grid * height // grid
    sector = terrain[coordx : coordx + width // grid, coordy : coordy + height // grid] == 0
    offset = jnp.array([coordx, coordy])
    return jnp.nonzero(sector), offset
82
+
83
+
84
+ class Environment(SMAX):
47
85
  def __init__(self, scenario: Scenario, **kwargs):
48
86
  map_height, map_width = scenario.terrain_raster.shape
49
87
  args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
50
- super(Parabellum, self).__init__(**args, **kwargs)
88
+ super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
51
89
  self.terrain_raster = scenario.terrain_raster
52
- self.obstacle_coords = scenario.obstacle_coords
53
- self.obstacle_deltas = scenario.obstacle_deltas
54
- self.unit_type_attack_blasts = jnp.zeros((19,), dtype=jnp.float32)
90
+ # self.unit_type_names = ["tinker", "tailor", "soldier", "spy"]
91
+ # self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
92
+ # self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
93
+ self.scenario = scenario
94
+ self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
55
95
  self.max_steps = 200
56
- self._push_units_away = lambda x: x # overwrite push units
96
+ self._push_units_away = lambda state, firmness = 1: state # overwrite push units
97
+ self.top_sector, self.top_sector_offset = sector_fn(self.terrain_raster, 0)
98
+ self.low_sector, self.low_sector_offset = sector_fn(self.terrain_raster, 24)
57
99
 
58
100
  @partial(jax.jit, static_argnums=(0,))
59
- def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
101
+ def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
60
102
  """Environment-specific reset."""
61
- key, team_0_key, team_1_key = jax.random.split(key, num=3)
62
- team_0_start = jnp.stack(
63
- [jnp.array([self.map_width / 4, self.map_height / 2])] * self.num_allies
64
- )
65
- team_0_start_noise = jax.random.uniform(
66
- team_0_key, shape=(self.num_allies, 2), minval=-2, maxval=2
67
- )
68
- team_0_start = team_0_start + team_0_start_noise
69
- team_1_start = jnp.stack(
70
- [jnp.array([self.map_width / 4 * 3, self.map_height / 2])]
71
- * self.num_enemies
72
- )
73
- team_1_start_noise = jax.random.uniform(
74
- team_1_key, shape=(self.num_enemies, 2), minval=-2, maxval=2
75
- )
76
- team_1_start = team_1_start + team_1_start_noise
103
+ ally_key, enemy_key = jax.random.split(rng)
104
+ team_0_start = spawn_fn(self.top_sector, self.top_sector_offset, self.num_allies, ally_key)
105
+ team_1_start = spawn_fn(self.low_sector, self.low_sector_offset, self.num_enemies, enemy_key)
77
106
  unit_positions = jnp.concatenate([team_0_start, team_1_start])
78
- key, pos_key = jax.random.split(key)
79
- generated_unit_positions = self.position_generator.generate(pos_key)
80
- unit_positions = jax.lax.select(
81
- self.smacv2_position_generation, generated_unit_positions, unit_positions
82
- )
83
107
  unit_teams = jnp.zeros((self.num_agents,))
84
108
  unit_teams = unit_teams.at[self.num_allies :].set(1)
85
109
  unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
86
110
  # default behaviour spawn all marines
87
- unit_types = (
88
- jnp.zeros((self.num_agents,), dtype=jnp.uint8)
89
- if self.scenario is None
90
- else self.scenario
91
- )
92
- key, unit_type_key = jax.random.split(key)
93
- generated_unit_types = self.unit_type_generator.generate(unit_type_key)
94
- unit_types = jax.lax.select(
95
- self.smacv2_unit_type_generation, generated_unit_types, unit_types
96
- )
111
+ unit_types = cast(Array, self.scenario.unit_types)
97
112
  unit_health = self.unit_type_health[unit_types]
98
113
  state = State(
99
114
  unit_positions=unit_positions,
@@ -107,7 +122,7 @@ class Parabellum(SMAX):
107
122
  terminal=False,
108
123
  unit_weapon_cooldowns=unit_weapon_cooldowns,
109
124
  )
110
- state = self._push_units_away(state)
125
+ state = self._push_units_away(state) # type: ignore
111
126
  obs = self.get_obs(state)
112
127
  world_state = self.get_world_state(state)
113
128
  obs["world_state"] = jax.lax.stop_gradient(world_state)
@@ -139,7 +154,7 @@ class Parabellum(SMAX):
139
154
  key: chex.PRNGKey,
140
155
  state: State,
141
156
  actions: Tuple[chex.Array, chex.Array],
142
- ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
157
+ ) -> State:
143
158
  @partial(jax.vmap, in_axes=(None, None, 0, 0))
144
159
  def intersect_fn(pos, new_pos, obs, obs_end):
145
160
  d1 = jnp.cross(obs - pos, new_pos - pos)
@@ -165,7 +180,7 @@ class Parabellum(SMAX):
165
180
  # because these are easier to encode as actions than the four
166
181
  # diagonal directions. Then rotate the velocity 45
167
182
  # degrees anticlockwise to compute the movement.
168
- pos = state.unit_positions[idx]
183
+ pos = cast(Array, state.unit_positions[idx])
169
184
  new_pos = (
170
185
  pos
171
186
  + vec
@@ -180,12 +195,12 @@ class Parabellum(SMAX):
180
195
 
181
196
  #######################################################################
182
197
  ############################################ avoid going into obstacles
183
- obs = self.obstacle_coords
198
+ """ obs = self.obstacle_coords
184
199
  obs_end = obs + self.obstacle_deltas
185
- inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end))
186
- rastersects = raster_crossing(pos, new_pos)
187
- flag = jnp.logical_or(inters, rastersects)
188
- new_pos = jnp.where(flag, pos, new_pos)
200
+ inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end)) """
201
+ clash = raster_crossing(pos, new_pos)
202
+ # flag = jnp.logical_or(inters, rastersects)
203
+ new_pos = jnp.where(clash, pos, new_pos)
189
204
 
190
205
  #######################################################################
191
206
  #######################################################################
@@ -217,6 +232,7 @@ class Parabellum(SMAX):
217
232
  lambda: action + self.num_allies - self.num_movement_actions,
218
233
  lambda: self.num_allies - 1 - (action - self.num_movement_actions),
219
234
  )
235
+ attacked_idx = cast(int, attacked_idx) # Cast to int
220
236
  # deal with no-op attack actions (i.e. agents that are moving instead)
221
237
  attacked_idx = jax.lax.select(
222
238
  action < self.num_movement_actions, idx, attacked_idx
@@ -248,7 +264,7 @@ class Parabellum(SMAX):
248
264
  bystander_idxs = bystander_fn(attacked_idx) # TODO: use
249
265
  bystander_valid = (
250
266
  jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
251
- .astype(jnp.bool_)
267
+ .astype(jnp.bool_) # type: ignore
252
268
  .astype(jnp.float32)
253
269
  )
254
270
  bystander_health_diff = (
@@ -310,14 +326,14 @@ class Parabellum(SMAX):
310
326
  [self.map_width, 0],
311
327
  ]
312
328
  )
313
- obstacle_coords = jnp.concatenate(
329
+ """ obstacle_coords = jnp.concatenate(
314
330
  [self.obstacle_coords, bondaries_coords]
315
331
  ) # add the map boundaries to the obstacles to avoid
316
332
  obstacle_deltas = jnp.concatenate(
317
333
  [self.obstacle_deltas, bondaries_deltas]
318
334
  ) # add the map boundaries to the obstacles to avoid
319
335
  obst_start = obstacle_coords
320
- obst_end = obst_start + obstacle_deltas
336
+ obst_end = obst_start + obstacle_deltas """
321
337
 
322
338
  def check_obstacles(pos, new_pos, obst_start, obst_end):
323
339
  intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
@@ -325,9 +341,9 @@ class Parabellum(SMAX):
325
341
  flag = jnp.logical_or(intersects, rastersect)
326
342
  return jnp.where(flag, pos, new_pos)
327
343
 
328
- pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
344
+ """ pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
329
345
  pos, new_pos, obst_start, obst_end
330
- )
346
+ ) """
331
347
 
332
348
  # Multiple enemies can attack the same unit.
333
349
  # We have `(health_diff, attacked_idx)` pairs.
@@ -363,7 +379,8 @@ class Parabellum(SMAX):
363
379
  #########################################################
364
380
 
365
381
  unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
366
- state = state.replace(
382
+ # replace unit health, unit positions and unit weapon cooldowns
383
+ state = state.replace( # type: ignore
367
384
  unit_health=unit_health,
368
385
  unit_positions=pos,
369
386
  unit_weapon_cooldowns=unit_weapon_cooldowns,
@@ -373,13 +390,16 @@ class Parabellum(SMAX):
373
390
 
374
391
  if __name__ == "__main__":
375
392
  n_envs = 4
376
- kwargs = dict(map_width=64, map_height=64)
377
- env = Parabellum(scenarios["default"], **kwargs)
393
+
394
+ env = Environment(scenarios["default"])
378
395
  rng, reset_rng = random.split(random.PRNGKey(0))
379
396
  reset_key = random.split(reset_rng, n_envs)
380
397
  obs, state = vmap(env.reset)(reset_key)
381
398
  state_seq = []
382
399
 
400
+ print(state.unit_positions)
401
+ exit()
402
+
383
403
  for i in range(10):
384
404
  rng, act_rng, step_rng = random.split(rng, 3)
385
405
  act_key = random.split(act_rng, (len(env.agents), n_envs))
parabellum/gun.py ADDED
@@ -0,0 +1,70 @@
1
+ # gun.py
2
+ # parabellum bullet rendering associated functions
3
+ # by: Noah Syrkis
4
+
5
+ # imports
6
+ from functools import partial
7
+ import jax.numpy as jnp
8
+
9
+
10
def dist_fn(env, pos):
    """Compute Euclidean distances between every ally and every enemy.

    Allies are the first ``env.num_allies`` rows of ``pos``; enemies follow.

    Returns:
        ``{"ally": d, "enemy": d.T}`` where ``d[i, j]`` is the distance from
        ally ``i`` to enemy ``j``.
    """
    n_ally = env.num_allies
    # only the ally x enemy block is needed, so skip the full pairwise matrix
    gaps = pos[:n_ally, None, :] - pos[None, n_ally:, :]
    ally_to_enemy = jnp.sqrt(jnp.sum(gaps**2, axis=-1))
    return {"ally": ally_to_enemy, "enemy": ally_to_enemy.T}
15
+
16
+
17
def range_fn(env, dists, ranges):
    """Flag which opposing units each unit can reach.

    Args:
        env: provides ``num_allies`` (allies come first in ``ranges``).
        dists: output of ``dist_fn``.
        ranges: per-unit attack range.

    Returns:
        ``{"ally": bool matrix, "enemy": bool matrix}`` shaped like ``dists``;
        an entry is ``True`` when the distance is strictly below the attacker's range.
    """
    split = env.num_allies
    ally_reach = ranges[:split, None]
    enemy_reach = ranges[split:, None]
    return {
        "ally": dists["ally"] < ally_reach,
        "enemy": dists["enemy"] < enemy_reach,
    }
21
+
22
+
23
def target_fn(acts, in_range, team, num_movement_actions=5):
    """Return a one-hot matrix of valid attack targets for one team.

    Action ids below ``num_movement_actions`` are movement/no-op; an attack
    action ``a`` refers to target ``a - num_movement_actions``. SMAX labels the
    enemy team's attack targets in reverse order
    (``num_allies - 1 - (action - num_movement_actions)``), so that index is
    mirrored for ``team == "enemy"``.

    Args:
        acts: mapping ``agent_name -> action``; agents whose name starts with
            ``team`` are used, in mapping order.
        in_range: output of ``range_fn``.
        team: ``"ally"`` or ``"enemy"``.
        num_movement_actions: number of non-attack actions (SMAX default 5).

    Returns:
        ``(n_team, n_targets)`` array with a 1 where a unit attacks an
        in-range target and 0 elsewhere.
    """
    team_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T
    n_targets = in_range[team].shape[-1]
    target = team_acts - num_movement_actions
    if team == "enemy":
        target = n_targets - 1 - target
    # Non-attack actions point at an extra dummy column that is dropped below.
    # (The previous condition was inverted: attacks were discarded and movement
    # actions produced spurious targets through negative indexing.)
    target = jnp.where(team_acts >= num_movement_actions, target, -1)
    one_hot = jnp.eye(n_targets + 1)[target][..., :-1]
    return one_hot * in_range[team]
28
+
29
+
30
def attack_fn(env, state_seq):
    """Compute, for every step of a rollout, which in-range targets each unit attacks.

    Args:
        env: environment providing ``num_allies`` and ``unit_type_attack_ranges``.
        state_seq: sequence of ``(key, state, actions)`` tuples.

    Returns:
        One ``{"ally": ..., "enemy": ...}`` dict per step, holding the
        ``target_fn`` one-hot target matrices.
    """

    def step_attacks(state, acts):
        reach = env.unit_type_attack_ranges[state.unit_types]
        in_range = range_fn(env, dist_fn(env, state.unit_positions), reach)
        return {team: target_fn(acts, in_range, team) for team in ("ally", "enemy")}

    return [step_attacks(state, acts) for _, state, acts in state_seq]
40
+
41
+
42
def bullet_fn(env, states):
    """Turn a rollout into per-step bullet segments for rendering.

    Args:
        env: environment providing ``num_allies`` (also forwarded to ``attack_fn``).
        states: sequence of ``(key, state, actions)`` tuples.

    Returns:
        One ``(ally_bullets, enemy_bullets)`` pair per consecutive pair of
        states. Each array stacks ``[source, target]`` positions per shot
        (shape ``(k, 2, D)``): the source is the shooter's position in the
        current state, the target is the victim's position in the next state.
    """
    attack_seq = attack_fn(env, states)
    n_allies = env.num_allies

    def shots(one_hot, team):
        # (shooter, victim) index pairs, each relative to its own team
        return jnp.stack(jnp.where(one_hot[team] == 1)).T

    bullet_seq = []
    for one_hot, (_, now, _), (_, nxt, _) in zip(attack_seq, states[:-1], states[1:]):
        ally_shots = shots(one_hot, "ally")
        enemy_shots = shots(one_hot, "enemy")
        # enemy units follow the allies in unit_positions, hence the offsets
        ally_bullets = jnp.stack(
            (
                now.unit_positions[ally_shots[:, 0]],
                nxt.unit_positions[ally_shots[:, 1] + n_allies],
            ),
            axis=1,
        )
        enemy_bullets = jnp.stack(
            (
                now.unit_positions[enemy_shots[:, 0] + n_allies],
                nxt.unit_positions[enemy_shots[:, 1]],
            ),
            axis=1,
        )
        bullet_seq.append((ally_bullets, enemy_bullets))
    return bullet_seq
parabellum/map.py CHANGED
@@ -1,16 +1,68 @@
1
1
  # map.py
2
- # parabellum map functions
2
+ # parabellum map functions
3
3
  # by: Noah Syrkis
4
4
 
5
5
  # imports
6
6
  import jax.numpy as jnp
7
- import jax
7
+ from geopy.geocoders import Nominatim
8
+ import geopandas as gpd
9
+ import osmnx as ox
10
+ import contextily as cx
11
+ import matplotlib.pyplot as plt
12
+ from rasterio import features
13
+ import rasterio.transform
14
+ from typing import Optional, Tuple
15
+ from geopy.location import Location
16
+ from shapely.geometry import Point
8
17
 
18
+ # constants
19
+ geolocator = Nominatim(user_agent="parabellum")
20
+ BUILDING_TAGS = {"building": True}
9
21
 
10
- # functions
11
- def map_fn(width, height, obst_coord, obst_delta):
12
- """Create a map from the given width, height, and obstacle coordinates and deltas."""
13
- m = jnp.zeros((width, height))
14
- for (x, y), (dx, dy) in zip(obst_coord, obst_delta):
15
- m = m.at[x : x + dx, y : y + dy].set(1)
16
- return m
22
def get_location(place: str) -> Tuple[float, float]:
    """Geocode ``place`` with the module's Nominatim geolocator.

    Returns:
        ``(latitude, longitude)`` of the best match.

    Raises:
        ValueError: if the geocoder finds no match.
    """
    location: Optional[Location] = geolocator.geocode(place)  # type: ignore
    if location is not None:
        return (location.latitude, location.longitude)
    raise ValueError(f"Could not geocode the place: {place}")
28
+
29
def get_building_geometry(point: Tuple[float, float], size: int) -> gpd.GeoDataFrame:
    """Download OSM building footprints around ``point``.

    Args:
        point: ``(latitude, longitude)``.
        size: edge length of the area of interest; half of it is used as the
            query distance around ``point``.

    Returns:
        GeoDataFrame of buildings with CRS set to EPSG:4326.
    """
    radius = size // 2
    buildings = ox.features_from_point(point, tags=BUILDING_TAGS, dist=radius)
    frame = gpd.GeoDataFrame(buildings)
    return frame.set_crs("EPSG:4326")
33
+
34
def rasterize_geometry(gdf: gpd.GeoDataFrame, size: int) -> jnp.ndarray:
    """Burn ``gdf``'s geometries into a ``size`` x ``size`` ``uint8`` JAX raster.

    The raster spans the total bounds of ``gdf`` and is rotated by 180°.
    """
    west, south, east, north = gdf.total_bounds
    affine = rasterio.transform.from_bounds(west, south, east, north, size, size)
    burned = features.rasterize(gdf.geometry, out_shape=(size, size), transform=affine)
    # two quarter turns = 180° rotation
    return jnp.rot90(jnp.asarray(burned), 2).astype(jnp.uint8)
40
+
41
def terrain_fn(place: str, size: int = 1000) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return ``(mask, basemap)`` for ``place``.

    ``mask`` is the ``size`` x ``size`` rasterised building map and ``basemap``
    a web-tile image cropped around the same building bounds. The place is
    geocoded and the buildings downloaded only once and shared by both outputs;
    previously ``get_basemap`` repeated both network requests.
    NOTE(review): the tile image's pixel scale depends on the tile zoom, not on
    ``size`` — confirm the two outputs are meant to align pixel-for-pixel.
    """
    point = get_location(place)
    gdf = get_building_geometry(point, size)
    mask = rasterize_geometry(gdf, size)
    base = _basemap_from_geometry(gdf, size)
    return mask, base


def get_basemap(place: str, size: int = 1000) -> jnp.ndarray:
    """Return a basemap image for ``place`` as a JAX ``uint8`` array."""
    point = get_location(place)
    gdf = get_building_geometry(point, size)
    return _basemap_from_geometry(gdf, size)


def _basemap_from_geometry(gdf: "gpd.GeoDataFrame", size: int) -> jnp.ndarray:
    """Fetch tiles covering ``gdf``'s bounds and crop the central ``size`` x ``size`` window."""
    basemap, _ = cx.bounds2img(*gdf.total_bounds, ll=True)
    rows, cols = basemap.shape[:2]
    # Clamp at 0: when the tile image is smaller than ``size`` the old slice
    # started at a negative index and returned only a tail strip of the image.
    top = max((rows - size) // 2, 0)
    left = max((cols - size) // 2, 0)
    basemap = basemap[top : top + size, left : left + size]
    # 180° rotation, the same orientation rasterize_geometry gives the mask
    return jnp.array(jnp.rot90(basemap, 2)).astype(jnp.uint8)
58
+
59
+
60
if __name__ == "__main__":
    # Demo: fetch the building mask and basemap for one town and show them side by side.
    # (The unused ``seaborn`` import was dropped; it only added a dependency.)
    place = "Thun, Switzerland"
    mask, base = terrain_fn(place)

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(mask)  # type: ignore
    ax[1].imshow(base)  # type: ignore
    plt.show()
parabellum/run.py CHANGED
@@ -10,6 +10,7 @@ import darkdetect
10
10
  import jax.numpy as jnp
11
11
  from chex import dataclass
12
12
  import jaxmarl
13
+ from jax import Array
13
14
  from typing import Tuple, List, Dict, Optional
14
15
  import parabellum as pb
15
16
 
@@ -20,7 +21,7 @@ bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
20
21
 
21
22
 
22
23
  # types
23
- State = jaxmarl.environments.smax.smax_env.State
24
+ State = jaxmarl.environments.smax.smax_env.State # type: ignore
24
25
  Obs = Reward = Done = Action = Dict[str, jnp.ndarray]
25
26
  StateSeq = List[Tuple[jnp.ndarray, State, Action]]
26
27
 
@@ -35,12 +36,12 @@ class Control:
35
36
  @dataclass
36
37
  class Game:
37
38
  clock: pygame.time.Clock
38
- state: State
39
+ state: State # type: ignore
39
40
  obs: Dict
40
41
  state_seq: StateSeq
41
42
  control: Control
42
- env: pb.Parabellum
43
- rng: random.PRNGKey
43
+ env: pb.Environment
44
+ rng: Array
44
45
 
45
46
 
46
47
  def handle_event(event, control_state):
@@ -100,7 +101,7 @@ def step_fn(game):
100
101
 
101
102
  # state
102
103
  if __name__ == "__main__":
103
- env = pb.Parabellum(pb.scenarios["default"])
104
+ env = pb.Environment(pb.scenarios["default"])
104
105
  pygame.init()
105
106
  screen = pygame.display.set_mode((1000, 1000))
106
107
  render = partial(render_fn, screen)
@@ -115,7 +116,7 @@ if __name__ == "__main__":
115
116
  state=state,
116
117
  obs=obs,
117
118
  )
118
- game = Game(**kwargs)
119
+ game = Game(**kwargs) # type: ignore
119
120
 
120
121
  while game.control.running:
121
122
  game = control_fn(game)
@@ -123,5 +124,3 @@ if __name__ == "__main__":
123
124
  game = game if game.control.paused else render(game)
124
125
 
125
126
  pygame.quit()
126
-
127
-
parabellum/vis.py CHANGED
@@ -2,265 +2,168 @@
2
2
  Visualizer for the Parabellum environment
3
3
  """
4
4
 
5
- from tqdm import tqdm
6
- import jax.numpy as jnp
7
- import jax
8
- from jax import vmap
9
- from jax import tree_util
5
+ # Standard library imports
10
6
  from functools import partial
11
- import darkdetect
12
- import numpy as np
13
- import pygame
14
- import os
15
- from moviepy.editor import ImageSequenceClip
16
- from typing import Optional
7
+ from typing import Optional, List, Tuple
8
+ import cv2
9
+ from PIL import Image
10
+
11
+ # JAX and JAX-related imports
12
+ import jax
13
+ from chex import dataclass
14
+ import chex
15
+ from jax import vmap, tree_util, Array, jit
16
+ import jax.numpy as jnp
17
17
  from jaxmarl.environments.multi_agent_env import MultiAgentEnv
18
+ from jaxmarl.environments.smax import SMAX
18
19
  from jaxmarl.viz.visualizer import SMAXVisualizer
19
20
 
20
- # default dict
21
- from collections import defaultdict
22
-
21
+ # Third-party imports
22
+ import numpy as np
23
+ import pygame
24
+ import cv2
25
+ from tqdm import tqdm
23
26
 
24
- # constants
25
- action_to_symbol = {0: "↑", 1: "→", 2: "↓", 3: "←", 4: "Ø"}
27
+ # Local imports
28
+ import parabellum as pb
26
29
 
27
30
 
28
- def small_multiples():
29
- # make video of small multiples based on all videos in output
30
- video_files = [f"output/parabellum_{i}.mp4" for i in range(4)]
31
- # load mp4 videos and make a grid
32
- clips = [ImageSequenceClip.load(filename) for filename in video_files]
33
- print(len(clips))
31
+ # skin dataclass
32
@dataclass
class Skin:
    """Visual settings used when rendering Parabellum rollouts."""

    # basemap: Array # basemap of buildings
    maskmap: Array  # maskmap of buildings
    bg: Tuple[int, int, int] = (255, 255, 255)  # background colour (frames are RGB)
    fg: Tuple[int, int, int] = (0, 0, 0)  # foreground colour
    ally: Tuple[int, int, int] = (0, 255, 0)  # ally colour
    enemy: Tuple[int, int, int] = (255, 0, 0)  # enemy colour
    pad: int = 100  # padding in pixels; frames are size + 2 * pad on each side
    size: int = 1000  # excluding padding
    fps: int = 24  # frame rate of the written video
    vis_size: int = 1000  # size of the map in Vis (excluding padding)
    scale: Optional[float] = None  # pixels per map unit; Visualizer uses size / env.map_width
34
45
 
35
46
 
36
47
  class Visualizer(SMAXVisualizer):
37
- def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
38
- super().__init__(env, state_seq, reward_seq)
39
- # remove fig and ax from super
40
- self.fig, self.ax = None, None
41
- self.bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
42
- self.fg = (255, 255, 255) if darkdetect.isDark() else (0, 0, 0)
43
- self.s = 1000
44
- self.scale = self.s / self.env.map_width
45
- self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
48
+ def __init__(self, env: pb.Environment, state_seq, skin: Skin, reward_seq=None):
49
+ super(Visualizer, self).__init__(env, state_seq, reward_seq)
50
+
46
51
  # self.bullet_seq = vmap(partial(bullet_fn, self.env))(self.state_seq)
52
+ self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
53
+ self.state_seq = state_seq
54
+ self.image = image_fn(skin)
55
+ self.skin = skin
56
+ self.skin.scale = self.skin.size / env.map_width # assumes square map
57
+ self.env = env
58
+
47
59
 
48
- def animate(self, save_fname: str = "output/parabellum.mp4"):
49
- multi_dim = self.state_seq[0][1].unit_positions.ndim > 2
50
- if multi_dim:
51
- n_envs = self.state_seq[0][1].unit_positions.shape[0]
52
- if not self.have_expanded:
53
- state_seqs = vmap(self.env.expand_state_seq)(self.state_seq)
54
- self.have_expanded = True
55
- for i in range(n_envs):
56
- state_seq = jax.tree_map(lambda x: x[i], state_seqs)
57
- action_seq = jax.tree_map(lambda x: x[i], self.action_seq)
58
- self.animate_one(
59
- state_seq, action_seq, save_fname.replace(".mp4", f"_{i}.mp4")
60
- )
61
- else:
62
- state_seq = self.env.expand_state_seq(self.state_seq)
63
- self.animate_one(state_seq, self.action_seq, save_fname)
64
-
65
- def animate_one(self, state_seq, action_seq, save_fname):
66
- frames = [] # frames for the video
67
- pygame.init() # initialize pygame
68
- terrain = np.array(self.env.terrain_raster)
69
- rgb_array = np.zeros((terrain.shape[0], terrain.shape[1], 3), dtype=np.uint8)
70
- if darkdetect.isLight():
71
- rgb_array += 255
72
- rgb_array[terrain == 1] = self.fg
73
- mask_surface = pygame.surfarray.make_surface(rgb_array)
74
- mask_surface = pygame.transform.scale(mask_surface, (self.s, self.s))
75
-
76
- for idx, (_, state, _) in tqdm(enumerate(state_seq), total=len(self.state_seq)):
77
- action = action_seq[idx // self.env.world_steps_per_env_step]
78
- screen = pygame.Surface(
79
- (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
80
- )
81
- screen.fill(self.bg) # fill the screen with the background color
82
- screen.blit(mask_surface, (0, 0))
83
-
84
- self.render_agents(screen, state) # render the agents
85
- self.render_action(screen, action)
86
- self.render_obstacles(screen) # render the obstacles
87
-
88
- # bullets
89
- """ if idx < len(self.bullet_seq) * 8:
90
- bullets = self.bullet_seq[idx // 8]
91
- self.render_bullets(screen, bullets, idx % 8) """
92
-
93
- # rotate the screen and append to frames
94
- frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))
95
- # save the images
96
- clip = ImageSequenceClip(frames, fps=48)
97
- clip.write_videofile(save_fname, fps=48)
98
- clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
99
- pygame.quit()
100
-
101
- return clip
102
-
103
- def render_agents(self, screen, state):
104
- time_tuple = zip(
105
- state.unit_positions,
106
- state.unit_teams,
107
- state.unit_types,
108
- state.unit_health,
109
- )
110
- for idx, (pos, team, kind, hp) in enumerate(time_tuple):
111
- face_col = self.fg if int(team.item()) == 0 else self.bg
112
- pos = tuple((pos * self.scale).tolist())
113
- # draw the agent
114
- if hp > 0:
115
- hp_frac = hp / self.env.unit_type_health[kind]
116
- unit_size = self.env.unit_type_radiuses[kind]
117
- radius = jnp.ceil((unit_size * self.scale * hp_frac)).astype(int) + 1
118
- pygame.draw.circle(screen, face_col, pos, radius)
119
- pygame.draw.circle(screen, self.fg, pos, radius, 1)
120
-
121
- # draw the sight range
122
- # sight_range = self.env.unit_type_sight_ranges[kind] * self.scale
123
- # pygame.draw.circle(screen, self.fg, pos, sight_range.astype(int), 2)
124
-
125
- # draw attack range
126
- # attack_range = self.env.unit_type_attack_ranges[kind] * self.scale
127
- # pygame.draw.circle(screen, self.fg, pos, attack_range.astype(int), 2)
128
- # work out which agents are being shot
129
-
130
- def render_action(self, screen, action):
131
- if self.env.action_type != "discrete":
132
- return
133
-
134
- def coord_fn(idx, n, team):
135
- return (
136
- self.s / 20 if team == 0 else self.s - self.s / 20,
137
- # vertically centered so that n / 2 is above and below the center
138
- self.s / 2 - (n / 2) * self.s / 20 + idx * self.s / 20,
139
- )
140
-
141
- for idx in range(self.env.num_allies):
142
- symb = action_to_symbol.get(action[f"ally_{idx}"].astype(int).item(), "Ø")
143
- font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
144
- text = font.render(symb, True, self.fg)
145
- coord = coord_fn(idx, self.env.num_allies, 0)
146
- screen.blit(text, coord)
147
-
148
- for idx in range(self.env.num_enemies):
149
- symb = action_to_symbol.get(action[f"enemy_{idx}"].astype(int).item(), "Ø")
150
- font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
151
- text = font.render(symb, True, self.fg)
152
- coord = coord_fn(idx, self.env.num_enemies, 1)
153
- screen.blit(text, coord)
154
-
155
- def render_obstacles(self, screen):
156
- for c, d in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
157
- d = tuple(((c + d) * self.scale).tolist())
158
- c = tuple((c * self.scale).tolist())
159
- pygame.draw.line(screen, self.fg, c, d, 5)
160
-
161
- def render_bullets(self, screen, bullets, jdx):
162
- jdx += 1
163
- ally_bullets, enemy_bullets = bullets
164
- for source, target in ally_bullets:
165
- position = source + (target - source) * jdx / 8
166
- position *= self.scale
167
- pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
168
- for source, target in enemy_bullets:
169
- position = source + (target - source) * jdx / 8
170
- position *= self.scale
171
- pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
60
+ def animate(self, save_fname: Optional[str] = "output/parabellum", view=None):
61
+ expanded_state_seq, expanded_action_seq = expand_fn(self.env, self.state_seq, self.action_seq)
62
+ state_seq_seq, action_seq_seq = unbatch_fn(expanded_state_seq, expanded_action_seq)
63
+ for idx, (state_seq, action_seq) in enumerate(zip(state_seq_seq, action_seq_seq)):
64
+ animate_fn(self.env, self.skin, self.image, state_seq, action_seq, f"{save_fname}_{idx}.mp4")
172
65
 
173
66
 
174
67
  # functions
175
- # bullet functions
176
- def dist_fn(env, pos): # computing the distances between all ally and enemy agents
177
- delta = pos[None, :, :] - pos[:, None, :]
178
- dist = jnp.sqrt((delta**2).sum(axis=2))
179
- dist = dist[: env.num_allies, env.num_allies :]
180
- return {"ally": dist, "enemy": dist.T}
181
-
182
-
183
- def range_fn(env, dists, ranges): # computing what targets are in range
184
- ally_range = dists["ally"] < ranges[: env.num_allies][:, None]
185
- enemy_range = dists["enemy"] < ranges[env.num_allies :][:, None]
186
- return {"ally": ally_range, "enemy": enemy_range}
187
-
188
-
189
- def target_fn(acts, in_range, team): # computing the one hot valid targets
190
- t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T
191
- t_targets = jnp.where(t_acts > 4, -1, t_acts - 5) # first 5 are move actions
192
- t_attacks = jnp.eye(in_range[team].shape[1] + 1)[t_targets][:, :-1]
193
- return t_attacks * in_range[team] # one hot valid targets
194
-
195
-
196
- def attack_fn(env, state_seq): # one hot attack list
197
- attacks = []
198
- for _, state, acts in state_seq:
199
- dists = dist_fn(env, state.unit_positions)
200
- ranges = env.unit_type_attack_ranges[state.unit_types]
201
- in_range = range_fn(env, dists, ranges)
202
- target = partial(target_fn, acts, in_range)
203
- attack = {"ally": target("ally"), "enemy": target("enemy")}
204
- attacks.append(attack)
205
- return attacks
206
-
207
-
208
- def bullet_fn(env, states):
209
- bullet_seq = []
210
- attack_seq = attack_fn(env, states)
211
-
212
- def aux_fn(team):
213
- bullets = jnp.stack(jnp.where(one_hot[team] == 1)).T
214
- # bullets = bullets.at[:, 2 if team == "ally" else 1].add(env.num_allies)
215
- return bullets
216
-
217
- state_zip = zip(states[:-1], states[1:])
218
- for i, ((_, state, _), (_, n_state, _)) in enumerate(state_zip):
219
- one_hot = attack_seq[i]
220
- ally_bullets, enemy_bullets = aux_fn("ally"), aux_fn("enemy")
221
-
222
- ally_bullets_source = state.unit_positions[ally_bullets[:, 0]]
223
- enemy_bullets_target = n_state.unit_positions[enemy_bullets[:, 1]]
224
-
225
- enemy_bullets_source = state.unit_positions[
226
- enemy_bullets[:, 0] + env.num_allies
227
- ]
228
- ally_bullets_target = n_state.unit_positions[
229
- ally_bullets[:, 1] + env.num_allies
230
- ]
231
-
232
- ally_bullets = jnp.stack((ally_bullets_source, ally_bullets_target), axis=1)
233
- enemy_bullets = jnp.stack((enemy_bullets_source, enemy_bullets_target), axis=1)
234
-
235
- bullet_seq.append((ally_bullets, enemy_bullets))
236
- return bullet_seq
237
-
238
-
239
- # test the visualizer
240
- if __name__ == "__main__":
241
- from jax import random, numpy as jnp
242
- from parabellum import Parabellum, scenarios
243
-
244
- # small_multiples() # testing small multiples (not working yet)
245
- # exit()
246
-
247
- n_envs = 2
248
- env = Parabellum(scenarios["default"], action_type="discrete")
249
- rng, reset_rng = random.split(random.PRNGKey(0))
250
- reset_key = random.split(reset_rng, n_envs)
251
- obs, state = vmap(env.reset)(reset_key)
252
- state_seq = []
253
-
254
- for i in range(100):
255
- rng, act_rng, step_rng = random.split(rng, 3)
256
- act_key = random.split(act_rng, (len(env.agents), n_envs))
257
- act = {
258
- a: vmap(env.action_space(a).sample)(act_key[i])
259
- for i, a in enumerate(env.agents)
260
- }
261
- step_key = random.split(step_rng, n_envs)
262
- state_seq.append((step_key, state, act))
263
- obs, state, reward, done, infos = vmap(env.step)(step_key, state, act)
264
-
265
- vis = Visualizer(env, state_seq)
266
- vis.animate()
68
+ def animate_fn(env, skin, image, state_seq, action_seq, save_fname):
69
+ pygame.init()
70
+ frames = []
71
+ for idx, (state_tup, action) in enumerate(zip(state_seq, action_seq)):
72
+ frames += [frame_fn(env, skin, image, state_tup[1], action, idx)]
73
+ # use cv2 to write frames to video
74
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore
75
+ out = cv2.VideoWriter(save_fname, fourcc, skin.fps, (skin.size + skin.pad * 2, skin.size + skin.pad * 2))
76
+ for frame in frames:
77
+ out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
78
+ out.release()
79
+ pygame.quit()
80
+
81
+
82
+ def init_frame(env, skin, image, state: pb.State, action: Array, idx: int) -> pygame.Surface:
83
+ dims = (skin.size + skin.pad * 2, skin.size + skin.pad * 2)
84
+ frame = pygame.Surface(dims, pygame.SRCALPHA | pygame.HWSURFACE)
85
+ return frame
86
+
87
+
88
+ def transform_frame(env, skin, frame):
89
+ frame = np.rot90(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 2)
90
+ return frame
91
+
92
+
93
+ def frame_fn(env, skin, image, state: pb.State, action: Array, idx: int) -> np.ndarray:
94
+ """Create a frame"""
95
+ frame = init_frame(env, skin, image, state, action, idx)
96
+
97
+ pipeline = [render_background, render_agents, render_action, render_bullet]
98
+ for fn in pipeline:
99
+ frame = fn(env, skin, image, frame, state, action)
100
+
101
+ return transform_frame(env, skin, frame)
102
+
103
+
104
+ def render_background(env, skin, image, frame, state, action):
105
+ coords = (skin.pad-5, skin.pad-5, skin.size+10, skin.size+10)
106
+ frame.fill(skin.bg)
107
+ frame.blit(image, coords)
108
+ pygame.draw.rect(frame, skin.fg, coords, 3)
109
+ return frame
110
+
111
+
112
+ def render_action(env, skin, image, frame, state, action):
113
+ return frame
114
+
115
+
116
+ def render_bullet(env, skin, image, frame, state, action):
117
+ return frame
118
+
119
+ def render_agents(env, skin, image, frame, state, action):
120
+ units = state.unit_positions, state.unit_teams, state.unit_types, state.unit_health
121
+ for idx, (pos, team, kind, health) in enumerate(zip(*units)):
122
+ pos = tuple((pos * skin.scale).astype(int) + skin.pad)
123
+ # draw the agent
124
+ if health > 0:
125
+ unit_size = env.unit_type_radiuses[kind]
126
+ radius = float(jnp.ceil((unit_size * skin.scale)).astype(int) + 1)
127
+ pygame.draw.circle(frame, skin.fg, pos, radius, 1)
128
+ pygame.draw.circle(frame, skin.bg, pos, radius + 1, 1)
129
+ return frame
130
+
131
+
132
+ def text_fn(text):
133
+ """rotate text upside down because of pygame issue"""
134
+ return pygame.transform.rotate(text, 180)
135
+
136
+
137
+ def image_fn(skin: Skin): # TODO:
138
+ """Create an image for background (basemap or maskmap)"""
139
+ motif = cv2.resize(np.array(skin.maskmap.T), (skin.size, skin.size), interpolation=cv2.INTER_LANCZOS4).astype(np.uint8)
140
+ motif = (motif > 0).astype(np.uint8)
141
+ image = np.zeros((skin.size, skin.size, 3), dtype=np.uint8) + skin.bg
142
+ image[motif == 1] = skin.fg
143
+ image = pygame.surfarray.make_surface(image)
144
+ image = pygame.transform.scale(image, (skin.size, skin.size))
145
+ return image
146
+
147
+
148
+ def unbatch_fn(state_seq, action_seq):
149
+ """state seq is a list of tuples of (step_key, state, actions)."""
150
+ if is_multi_run(state_seq):
151
+ n_envs = state_seq[0][1].unit_positions.shape[0]
152
+ state_seq_seq = [jax.tree_map(lambda x: x[i], state_seq) for i in range(n_envs)]
153
+ action_seq_seq = [jax.tree_map(lambda x: x[i], action_seq) for i in range(n_envs)]
154
+ else:
155
+ state_seq_seq = [state_seq]
156
+ action_seq_seq = [action_seq]
157
+ return state_seq_seq, action_seq_seq
158
+
159
+
160
+ def expand_fn(env, state_seq, action_seq):
161
+ """Expand the state sequence"""
162
+ fn = env.expand_state_seq
163
+ state_seq = vmap(fn)(state_seq) if is_multi_run(state_seq) else fn(state_seq)
164
+ action_seq = [action_seq[i // env.world_steps_per_env_step] for i in range(len(state_seq))]
165
+ return state_seq, action_seq
166
+
167
+
168
+ def is_multi_run(state_seq):
169
+ return state_seq[0][1].unit_positions.ndim > 2
@@ -0,0 +1,95 @@
1
+ Metadata-Version: 2.1
2
+ Name: parabellum
3
+ Version: 0.2.21
4
+ Summary: Parabellum environment for parallel warfare simulation
5
+ Home-page: https://github.com/syrkis/parabellum
6
+ License: MIT
7
+ Keywords: warfare,simulation,parallel,environment
8
+ Author: Noah Syrkis
9
+ Author-email: desk@syrkis.com
10
+ Requires-Python: >=3.11,<4.0
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Requires-Dist: contextily (>=1.6.0,<2.0.0)
16
+ Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
17
+ Requires-Dist: folium (>=0.17.0,<0.18.0)
18
+ Requires-Dist: geopandas (>=1.0.0,<2.0.0)
19
+ Requires-Dist: geopy (>=2.4.1,<3.0.0)
20
+ Requires-Dist: ipykernel (>=6.29.5,<7.0.0)
21
+ Requires-Dist: jax (==0.4.17)
22
+ Requires-Dist: jaxmarl (==0.0.3)
23
+ Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
24
+ Requires-Dist: moviepy (>=1.0.3,<2.0.0)
25
+ Requires-Dist: numpy (<2)
26
+ Requires-Dist: opencv-python (>=4.10.0.84,<5.0.0.0)
27
+ Requires-Dist: osmnx (>=1.9.3,<2.0.0)
28
+ Requires-Dist: pandas (>=2.2.2,<3.0.0)
29
+ Requires-Dist: poetry (>=1.8.3,<2.0.0)
30
+ Requires-Dist: pygame (>=2.5.2,<3.0.0)
31
+ Requires-Dist: rasterio (>=1.3.10,<2.0.0)
32
+ Requires-Dist: seaborn (>=0.13.2,<0.14.0)
33
+ Requires-Dist: tqdm (>=4.66.4,<5.0.0)
34
+ Project-URL: Repository, https://github.com/syrkis/parabellum
35
+ Description-Content-Type: text/markdown
36
+
37
+ # Parabellum
38
+
39
+ Ultra-scalable JaxMARL based warfare simulation engine developed with Armasuisse funding.
40
+
41
+ [![Documentation Status](https://readthedocs.org/projects/parabellum/badge/?version=latest)](https://parabellum.readthedocs.io/en/latest/?badge=latest)
42
+
43
+ ## Features
44
+
45
+ - Obstacles and terrain integration
46
+ - Rasterized maps
47
+ - Blast radii simulation
48
+ - Friendly fire mechanics
49
+ - Pygame visualization
50
+ - JAX-based parallelization
51
+
52
+ ## Install
53
+
54
+ ```bash
55
+ pip install parabellum
56
+ ```
57
+
58
+ ## Quick Start
59
+
60
+ ```python
61
+ import parabellum as pb
62
+ from jax import random
63
+
64
+ terrain = pb.terrain_fn("Thun, Switzerland", 1000)
65
+ scenario = pb.make_scenario("Thun", terrain, 10, 10)
66
+ env = pb.Parabellum(scenario)
67
+
68
+ rng, key = random.split(random.PRNGKey(0))
69
+ obs, state = env.reset(key)
70
+
71
+ # Simulation loop
72
+ for _ in range(100):
73
+ rng, rng_act, key_step = random.split(key)
74
+ key_act = random.split(rng_act, len(env.agents))
75
+ act = {a: env.action_space(a).sample(k) for a, k in zip(env.agents, key_act)}
76
+ obs, state, reward, done, info = env.step(key_step, act, state)
77
+
78
+ # Visualize
79
+ vis = pb.Visualizer(env, state_sequence)
80
+ vis.animate()
81
+ ```
82
+
83
+ ## Documentation
84
+
85
+ Full documentation: [parabellum.readthedocs.io](https://parabellum.readthedocs.io)
86
+
87
+ ## Team
88
+
89
+ - Noah Syrkis
90
+ - Timothée Anne
91
+ - Supervisor: Sebastian Risi
92
+
93
+ ## License
94
+
95
+ MIT
@@ -0,0 +1,10 @@
1
+ parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
2
+ parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
3
+ parabellum/env.py,sha256=L6GHlLxywpkV1bRnZcYBURREPP4CRfet_pEwCt5DB04,16724
4
+ parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
5
+ parabellum/map.py,sha256=EUcPe4Upu9MQzS8h15IVPGCaAyRPLSkmoLd5ZT-V4Pk,2599
6
+ parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
7
+ parabellum/vis.py,sha256=uXTnhJL23JLQHW9by-M4bF73dSVA5TIkpNdfo_Go2Ro,6045
8
+ parabellum-0.2.21.dist-info/METADATA,sha256=-K-3eYl1BvR3tFsiTxTyfHErQJdgPQZx08iq_kS2544,2671
9
+ parabellum-0.2.21.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
10
+ parabellum-0.2.21.dist-info/RECORD,,
@@ -1,104 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: parabellum
3
- Version: 0.2.19
4
- Summary: Parabellum environment for parallel warfare simulation
5
- Home-page: https://github.com/syrkis/parabellum
6
- License: MIT
7
- Keywords: warfare,simulation,parallel,environment
8
- Author: Noah Syrkis
9
- Author-email: desk@syrkis.com
10
- Requires-Python: >=3.11,<4.0
11
- Classifier: License :: OSI Approved :: MIT License
12
- Classifier: Programming Language :: Python :: 3
13
- Classifier: Programming Language :: Python :: 3.11
14
- Classifier: Programming Language :: Python :: 3.12
15
- Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
16
- Requires-Dist: jax (==0.4.17)
17
- Requires-Dist: jaxmarl (==0.0.3)
18
- Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
19
- Requires-Dist: moviepy (>=1.0.3,<2.0.0)
20
- Requires-Dist: poetry (>=1.8.3,<2.0.0)
21
- Requires-Dist: pygame (>=2.5.2,<3.0.0)
22
- Requires-Dist: tqdm (>=4.66.4,<5.0.0)
23
- Project-URL: Repository, https://github.com/syrkis/parabellum
24
- Description-Content-Type: text/markdown
25
-
26
- # parabellum
27
-
28
- Parabellum is an ultra-scalable, high-performance warfare simulation engine.
29
- It is based on JaxMARL's SMAX environment, but has been heavily modified to
30
- support a wide range of new features and improvements.
31
-
32
- ## Installation
33
-
34
- Parabellum is written in Python 3.11 and can be installed using pip:
35
-
36
- ```bash
37
- pip install parabellum
38
- ```
39
-
40
- ## Usage
41
-
42
- Parabellum is designed to be used in conjunction with JAX, a high-performance
43
- numerical computing library. Here is a simple example of how to use Parabellum
44
- to simulate a game with 10 agents and 10 enemies, each taking random actions:
45
-
46
- ```python
47
- import parabellum as pb
48
- from jax import random
49
-
50
- # define the scenario
51
- kwargs = dict(obstacle_coords=[(7, 7)], obstacle_deltas=[(10, 0)])
52
- scenario = pb.Scenario(**kwargs) # <- Scenario is an important part of parabellum
53
-
54
- # create the environment
55
- kwargs = dict(map_width=256, map_height=256, num_agents=10, num_enemies=10)
56
- env = pb.Parabellum(**kwargs) # <- Parabellum is the central class of parabellum
57
-
58
- # initiate stochasticity
59
- rng = random.PRNGKey(0)
60
- rng, key = random.split(rng)
61
-
62
- # initialize the environment state
63
- obs, state = env.reset(key)
64
- state_sequence = []
65
-
66
- for _ in range(1000):
67
-
68
- # manage stochasticity
69
- rng, rng_act, key_step = random.split(key)
70
- key_act = random.split(rng_act, len(env.agents))
71
-
72
- # sample actions and append to state sequence
73
- act = {a: env.action_space(a).sample(k)
74
- for a, k in zip(env.agents, key_act)}
75
-
76
- # step the environment
77
- state_sequence.append((key_act, state, act))
78
- obs, state, reward, done, info = env.step(key_step, act, state)
79
-
80
-
81
- # save visualization of the state sequence
82
- vis = pb.Visualizer(env, state_sequence) # <- Visualizer is a nice to have class
83
- vis.animate()
84
- ```
85
-
86
-
87
- ## Features
88
-
89
- - Obstacles — can be inserted in
90
-
91
- ## TODO
92
-
93
- - [x] Parallel pygame vis
94
- - [ ] Parallel bullet renderings
95
- - [ ] Combine parallell plots into one (maybe out of parabellum scope)
96
- - [ ] Color for health?
97
- - [ ] Add the ability to see ongoing game.
98
- - [ ] Bug test friendly fire.
99
- - [x] Start sim from arbitrary state.
100
- - [ ] Save when the episode ends in some state/obs variable
101
- - [ ] Look for the source of the bug when using more Allies than Enemies
102
- - [ ] Y inversed axis for parabellum visualization
103
- - [ ] Units see through obstacles?
104
-
@@ -1,8 +0,0 @@
1
- parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
2
- parabellum/env.py,sha256=Z8zpdCaEi5HFwN0Vd2hukOarkPSg0EKZErTRts3JQ5E,16023
3
- parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
4
- parabellum/run.py,sha256=0IWpqcQ_qfFeElbupF5vOs_CByFfpXYuGGUHYuurFM4,3412
5
- parabellum/vis.py,sha256=JFVTnBg-LV4jZNw6cysU6NS8ZxeMpg5wz3JOi-lrnzY,10699
6
- parabellum-0.2.19.dist-info/METADATA,sha256=DsEBAlESj8BwGSphmPPylStoXH_g_x_Iy3WJ3KEwjc0,3223
7
- parabellum-0.2.19.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
8
- parabellum-0.2.19.dist-info/RECORD,,