parabellum 0.2.18__py3-none-any.whl → 0.2.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/__init__.py CHANGED
@@ -1,4 +1,12 @@
1
- from .env import Parabellum, Scenario, scenarios
1
+ from .env import Environment, Scenario, scenarios, make_scenario
2
2
  from .vis import Visualizer
3
+ from .map import terrain_fn
3
4
 
4
- __all__ = ["Parabellum", "Visualizer", "Scenario", "scenarios"]
5
+ __all__ = [
6
+ "Environment",
7
+ "Scenario",
8
+ "scenarios",
9
+ "make_scenario",
10
+ "Visualizer",
11
+ "terrain_fn",
12
+ ]
parabellum/env.py CHANGED
@@ -17,11 +17,8 @@ from functools import partial
17
17
  class Scenario:
18
18
  """Parabellum scenario"""
19
19
 
20
+ place: str
20
21
  terrain_raster: chex.Array
21
-
22
- obstacle_coords: chex.Array # TODO: use map instead of obstacles
23
- obstacle_deltas: chex.Array
24
-
25
22
  unit_types: chex.Array
26
23
  num_allies: int
27
24
  num_enemies: int
@@ -33,9 +30,8 @@ class Scenario:
33
30
  # default scenario
34
31
  scenarios = {
35
32
  "default": Scenario(
36
- jnp.eye(128, dtype=jnp.uint8),
37
- jnp.array([[80, 0], [16, 12]]),
38
- jnp.array([[0, 80], [0, 20]]),
33
+ "Identity Town",
34
+ jnp.eye(64, dtype=jnp.uint8),
39
35
  jnp.zeros((19,), dtype=jnp.uint8),
40
36
  9,
41
37
  10,
@@ -43,57 +39,63 @@ scenarios = {
43
39
  }
44
40
 
45
41
 
46
- class Parabellum(SMAX):
42
def make_scenario(place, terrain_raster, num_allies=9, num_enemies=10):
    """Assemble a ``Scenario`` for ``terrain_raster`` located at ``place``.

    Every unit in both teams is assigned unit type 0.

    Args:
        place: human-readable location name stored on the scenario.
        terrain_raster: 2-D raster used as the map.
        num_allies: number of units on the allied team.
        num_enemies: number of units on the enemy team.

    Returns:
        A ``Scenario`` built positionally from the values above.
    """
    total_units = num_allies + num_enemies
    all_type_zero = jnp.zeros((total_units,), dtype=jnp.uint8)
    return Scenario(place, terrain_raster, all_type_zero, num_allies, num_enemies)
47
+
48
+
49
def spawn_fn(pool: jnp.ndarray, offset: jnp.ndarray, n: int, rng: jnp.ndarray):
    """Place ``n`` agents on distinct cells drawn from ``pool``.

    Args:
        pool: pair of index arrays ``(axis0_idx, axis1_idx)`` as returned by
            ``jnp.nonzero`` (e.g. the first element of ``sector_fn``'s result).
        offset: length-2 shift added to every position (the sector origin).
        n: number of agents; must not exceed the number of cells in ``pool``
            because cells are sampled without replacement.
        rng: PRNG key.

    Returns:
        ``(n, 2)`` float array: chosen cell coordinates plus a uniform jitter
        in ``[0, 0.5)`` per axis, shifted by ``offset``.
    """
    _, start_key, jitter_key = random.split(rng, 3)
    jitter = 0.5 * random.uniform(jitter_key, (n, 2))

    # Pick n distinct cells so no two agents share a starting cell.
    first_axis, second_axis = pool[0], pool[1]
    picks = random.choice(start_key, first_axis.shape[0], (n,), replace=False)
    cells = jnp.stack([first_axis[picks], second_axis[picks]], axis=1)

    return cells + jitter + offset
59
+
60
+
61
def sector_fn(terrain: jnp.ndarray, sector_id: int, n_sectors: int = 5):
    """Return the free cells of one sector of ``terrain`` and the sector origin.

    The raster is divided into an ``n_sectors`` x ``n_sectors`` grid whose ids
    are numbered row-major: ``sector_id // n_sectors`` picks the block along
    axis 0 and ``sector_id % n_sectors`` the block along axis 1. With the
    default 5x5 grid, sector 0 is the first corner and sector 24 the opposite one.

    Args:
        terrain: 2-D raster; cells equal to 0 are treated as free.
        sector_id: sector index in ``[0, n_sectors ** 2)``.
        n_sectors: number of sectors per axis (default 5).

    Returns:
        ``(pool, offset)``. ``pool`` is the ``jnp.nonzero`` tuple of free-cell
        indices *local to the sector*. ``offset`` is the sector's origin
        ``[axis0, axis1]`` in ``terrain``; add it to local indices to get global
        coordinates.

    Raises:
        ValueError: if ``sector_id`` is outside the grid, which would
            otherwise silently yield an empty pool.
    """
    if not 0 <= sector_id < n_sectors * n_sectors:
        raise ValueError(
            f"sector_id must be in [0, {n_sectors * n_sectors}), got {sector_id}"
        )
    rows, cols = terrain.shape
    # Same integer arithmetic as before: ((block * extent) // n_sectors).
    row0 = (sector_id // n_sectors) * rows // n_sectors
    col0 = (sector_id % n_sectors) * cols // n_sectors
    free = terrain[row0 : row0 + rows // n_sectors, col0 : col0 + cols // n_sectors] == 0
    offset = jnp.array([row0, col0])
    return jnp.nonzero(free), offset
69
+
70
+
71
+ class Environment(SMAX):
47
72
  def __init__(self, scenario: Scenario, **kwargs):
48
73
  map_height, map_width = scenario.terrain_raster.shape
49
74
  args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
50
- super(Parabellum, self).__init__(**args, **kwargs)
75
+ super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
51
76
  self.terrain_raster = scenario.terrain_raster
52
- self.obstacle_coords = scenario.obstacle_coords
53
- self.obstacle_deltas = scenario.obstacle_deltas
54
- self.unit_type_attack_blasts = jnp.zeros((19,), dtype=jnp.float32)
77
+ # self.unit_type_names = ["tinker", "tailor", "soldier", "spy"]
78
+ # self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
79
+ # self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
80
+ self.scenario = scenario
81
+ self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
55
82
  self.max_steps = 200
56
83
  self._push_units_away = lambda x: x # overwrite push units
84
+ self.top_sector = sector_fn(self.terrain_raster, 0)
85
+ self.low_sector = sector_fn(self.terrain_raster, 24)
57
86
 
58
87
  @partial(jax.jit, static_argnums=(0,))
59
- def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
88
+ def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
60
89
  """Environment-specific reset."""
61
- key, team_0_key, team_1_key = jax.random.split(key, num=3)
62
- team_0_start = jnp.stack(
63
- [jnp.array([self.map_width / 4, self.map_height / 2])] * self.num_allies
64
- )
65
- team_0_start_noise = jax.random.uniform(
66
- team_0_key, shape=(self.num_allies, 2), minval=-2, maxval=2
67
- )
68
- team_0_start = team_0_start + team_0_start_noise
69
- team_1_start = jnp.stack(
70
- [jnp.array([self.map_width / 4 * 3, self.map_height / 2])]
71
- * self.num_enemies
72
- )
73
- team_1_start_noise = jax.random.uniform(
74
- team_1_key, shape=(self.num_enemies, 2), minval=-2, maxval=2
75
- )
76
- team_1_start = team_1_start + team_1_start_noise
90
+ ally_key, enemy_key = jax.random.split(rng)
91
+ team_0_start = spawn_fn(*self.top_sector, self.num_allies, ally_key)
92
+ team_1_start = spawn_fn(*self.low_sector, self.num_enemies, enemy_key)
77
93
  unit_positions = jnp.concatenate([team_0_start, team_1_start])
78
- key, pos_key = jax.random.split(key)
79
- generated_unit_positions = self.position_generator.generate(pos_key)
80
- unit_positions = jax.lax.select(
81
- self.smacv2_position_generation, generated_unit_positions, unit_positions
82
- )
83
94
  unit_teams = jnp.zeros((self.num_agents,))
84
95
  unit_teams = unit_teams.at[self.num_allies :].set(1)
85
96
  unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
86
97
  # default behaviour spawn all marines
87
- unit_types = (
88
- jnp.zeros((self.num_agents,), dtype=jnp.uint8)
89
- if self.scenario is None
90
- else self.scenario
91
- )
92
- key, unit_type_key = jax.random.split(key)
93
- generated_unit_types = self.unit_type_generator.generate(unit_type_key)
94
- unit_types = jax.lax.select(
95
- self.smacv2_unit_type_generation, generated_unit_types, unit_types
96
- )
98
+ unit_types = self.scenario.unit_types
97
99
  unit_health = self.unit_type_health[unit_types]
98
100
  state = State(
99
101
  unit_positions=unit_positions,
@@ -180,12 +182,12 @@ class Parabellum(SMAX):
180
182
 
181
183
  #######################################################################
182
184
  ############################################ avoid going into obstacles
183
- obs = self.obstacle_coords
185
+ """ obs = self.obstacle_coords
184
186
  obs_end = obs + self.obstacle_deltas
185
- inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end))
186
- rastersects = raster_crossing(pos, new_pos)
187
- flag = jnp.logical_or(inters, rastersects)
188
- new_pos = jnp.where(flag, pos, new_pos)
187
+ inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end)) """
188
+ clash = raster_crossing(pos, new_pos)
189
+ # flag = jnp.logical_or(inters, rastersects)
190
+ new_pos = jnp.where(clash, pos, new_pos)
189
191
 
190
192
  #######################################################################
191
193
  #######################################################################
@@ -310,14 +312,14 @@ class Parabellum(SMAX):
310
312
  [self.map_width, 0],
311
313
  ]
312
314
  )
313
- obstacle_coords = jnp.concatenate(
315
+ """ obstacle_coords = jnp.concatenate(
314
316
  [self.obstacle_coords, bondaries_coords]
315
317
  ) # add the map boundaries to the obstacles to avoid
316
318
  obstacle_deltas = jnp.concatenate(
317
319
  [self.obstacle_deltas, bondaries_deltas]
318
320
  ) # add the map boundaries to the obstacles to avoid
319
321
  obst_start = obstacle_coords
320
- obst_end = obst_start + obstacle_deltas
322
+ obst_end = obst_start + obstacle_deltas """
321
323
 
322
324
  def check_obstacles(pos, new_pos, obst_start, obst_end):
323
325
  intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
@@ -325,9 +327,9 @@ class Parabellum(SMAX):
325
327
  flag = jnp.logical_or(intersects, rastersect)
326
328
  return jnp.where(flag, pos, new_pos)
327
329
 
328
- pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
330
+ """ pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
329
331
  pos, new_pos, obst_start, obst_end
330
- )
332
+ ) """
331
333
 
332
334
  # Multiple enemies can attack the same unit.
333
335
  # We have `(health_diff, attacked_idx)` pairs.
@@ -373,13 +375,16 @@ class Parabellum(SMAX):
373
375
 
374
376
  if __name__ == "__main__":
375
377
  n_envs = 4
376
- kwargs = dict(map_width=64, map_height=64)
377
- env = Parabellum(scenarios["default"], **kwargs)
378
+
379
+ env = Environment(scenarios["default"])
378
380
  rng, reset_rng = random.split(random.PRNGKey(0))
379
381
  reset_key = random.split(reset_rng, n_envs)
380
382
  obs, state = vmap(env.reset)(reset_key)
381
383
  state_seq = []
382
384
 
385
+ print(state.unit_positions)
386
+ exit()
387
+
383
388
  for i in range(10):
384
389
  rng, act_rng, step_rng = random.split(rng, 3)
385
390
  act_key = random.split(act_rng, (len(env.agents), n_envs))
parabellum/map.py CHANGED
@@ -4,13 +4,46 @@
4
4
 
5
5
  # imports
6
6
  import jax.numpy as jnp
7
- import jax
7
+ from geopy.geocoders import Nominatim
8
+ import geopandas as gpd
9
+ import osmnx as ox
10
+ import rasterio
11
+ from jax import random
12
+ from rasterio import features
13
+ import rasterio.transform
14
+
15
+ # constants
16
+ geolocator = Nominatim(user_agent="parabellum")
17
+ tags = {"building": True}
8
18
 
9
19
 
10
20
  # functions
11
- def map_fn(width, height, obst_coord, obst_delta):
12
- """Create a map from the given width, height, and obstacle coordinates and deltas."""
13
- m = jnp.zeros((width, height))
14
- for (x, y), (dx, dy) in zip(obst_coord, obst_delta):
15
- m = m.at[x : x + dx, y : y + dy].set(1)
16
- return m
21
def terrain_fn(place: str, size: int = 1000):
    """Rasterize OpenStreetMap building footprints around a named place.

    ``place`` is geocoded with Nominatim. Building features (``tags``) within
    ``size // 2`` of that point are fetched with osmnx and burned into a
    ``size`` x ``size`` grid covering the features' total bounds.

    Args:
        place: free-text location, e.g. ``"Thun, Switzerland"``.
        size: raster side length in cells; also sets the query radius.

    Returns:
        ``(size, size)`` uint8 array, 1 where a building covers the cell and
        0 elsewhere, rotated by 180 degrees.

    Raises:
        ValueError: if the geocoder cannot resolve ``place``.
    """
    # location info
    location = geolocator.geocode(place)
    if location is None:  # geopy returns None for unknown places
        raise ValueError(f"Could not geocode place: {place!r}")
    coords = (location.latitude, location.longitude)

    # shape info
    geometry = ox.features_from_point(coords, tags=tags, dist=size // 2)
    gdf = gpd.GeoDataFrame(geometry).set_crs("EPSG:4326")

    # raster info
    t = rasterio.transform.from_bounds(*gdf.total_bounds, size, size)
    raster = features.rasterize(gdf.geometry, out_shape=(size, size), transform=t)

    # rotate 180 degrees
    raster = jnp.rot90(raster, 2)

    return jnp.array(raster).astype(jnp.uint8)
40
+
41
+
42
if __name__ == "__main__":
    # spawn_fn and sector_fn are defined in env.py, not in this module.
    from parabellum.env import sector_fn, spawn_fn

    place = "Vesterbro, Copenhagen, Denmark"
    terrain = terrain_fn(place)
    rng, key = random.split(random.PRNGKey(0))
    # spawn_fn expects the free-cell pool and origin of one sector,
    # not the raw raster.
    pool, offset = sector_fn(terrain, 12)
    agents = spawn_fn(pool, offset, 100, key)
    print(agents)
parabellum/run.py CHANGED
@@ -39,7 +39,7 @@ class Game:
39
39
  obs: Dict
40
40
  state_seq: StateSeq
41
41
  control: Control
42
- env: pb.Parabellum
42
+ env: pb.Environment
43
43
  rng: random.PRNGKey
44
44
 
45
45
 
@@ -123,5 +123,3 @@ if __name__ == "__main__":
123
123
  game = game if game.control.paused else render(game)
124
124
 
125
125
  pygame.quit()
126
-
127
-
parabellum/vis.py CHANGED
@@ -1,4 +1,6 @@
1
- """Visualizer for the Parabellum environment"""
1
+ """
2
+ Visualizer for the Parabellum environment
3
+ """
2
4
 
3
5
  from tqdm import tqdm
4
6
  import jax.numpy as jnp
@@ -9,6 +11,7 @@ from functools import partial
9
11
  import darkdetect
10
12
  import numpy as np
11
13
  import pygame
14
+ import matplotlib.pyplot as plt
12
15
  import os
13
16
  from moviepy.editor import ImageSequenceClip
14
17
  from typing import Optional
@@ -31,15 +34,24 @@ def small_multiples():
31
34
  print(len(clips))
32
35
 
33
36
 
37
def text_fn(text):
    """Return ``text`` (a rendered pygame surface) turned upside down.

    ``animate_one`` rotates each finished frame by 180 degrees, so glyphs are
    pre-flipped here to read upright in the saved video.
    """
    half_turn = 180
    return pygame.transform.rotate(text, half_turn)
40
+
41
+
34
42
  class Visualizer(SMAXVisualizer):
35
- def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
36
- super().__init__(env, state_seq, reward_seq)
37
- # remove fig and ax from super
43
+ def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None, skin=None):
38
44
  self.fig, self.ax = None, None
45
+ super().__init__(env, state_seq, reward_seq)
46
+ # clear the figure made by SMAXVisualizer
47
+ plt.close()
39
48
  self.bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
40
- self.fg = (235, 235, 235) if darkdetect.isDark() else (20, 20, 20)
41
- self.s = 1000
42
- self.scale = self.s / self.env.map_width
49
+ self.fg = (255, 255, 255) if darkdetect.isDark() else (0, 0, 0)
50
+ self.pad = 75
51
+ # TODO: make sure it's always a 1024x1024 image
52
+ self.width = 1000
53
+ self.s = self.width + self.pad + self.pad
54
+ self.scale = self.width / env.map_width
43
55
  self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
44
56
  # self.bullet_seq = vmap(partial(bullet_fn, self.env))(self.state_seq)
45
57
 
@@ -63,11 +75,13 @@ class Visualizer(SMAXVisualizer):
63
75
  def animate_one(self, state_seq, action_seq, save_fname):
64
76
  frames = [] # frames for the video
65
77
  pygame.init() # initialize pygame
66
- terrain = np.array(self.env.terrain_raster)
78
+ terrain = np.array(self.env.terrain_raster.T)
67
79
  rgb_array = np.zeros((terrain.shape[0], terrain.shape[1], 3), dtype=np.uint8)
80
+ if darkdetect.isLight():
81
+ rgb_array += 255
68
82
  rgb_array[terrain == 1] = self.fg
69
83
  mask_surface = pygame.surfarray.make_surface(rgb_array)
70
- mask_surface = pygame.transform.scale(mask_surface, (self.s, self.s))
84
+ mask_surface = pygame.transform.scale(mask_surface, (self.width, self.width))
71
85
 
72
86
  for idx, (_, state, _) in tqdm(enumerate(state_seq), total=len(self.state_seq)):
73
87
  action = action_seq[idx // self.env.world_steps_per_env_step]
@@ -75,11 +89,33 @@ class Visualizer(SMAXVisualizer):
75
89
  (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
76
90
  )
77
91
  screen.fill(self.bg) # fill the screen with the background color
78
- screen.blit(mask_surface, (0, 0))
92
+ screen.blit(
93
+ mask_surface,
94
+ (self.pad, self.pad, self.width, self.width),
95
+ )
96
+ # add env.scenario.place to the title (in top padding)
97
+ font = pygame.font.SysFont("Fira Code", 18)
98
+ width = self.env.map_width
99
+ title = f"{width}x{width}m in {self.env.scenario.place}"
100
+ text = text_fn(font.render(title, True, self.fg))
101
+ # center the text
102
+ screen.blit(text, (self.s // 2 - text.get_width() // 2, self.pad // 4))
103
+ # draw edge around terrain
104
+ pygame.draw.rect(
105
+ screen,
106
+ self.fg,
107
+ (
108
+ self.pad - 2,
109
+ self.pad - 2,
110
+ self.width + 4,
111
+ self.width + 4,
112
+ ),
113
+ 2,
114
+ )
79
115
 
80
116
  self.render_agents(screen, state) # render the agents
81
117
  self.render_action(screen, action)
82
- self.render_obstacles(screen) # render the obstacles
118
+ # self.render_obstacles(screen) # render the obstacles
83
119
 
84
120
  # bullets
85
121
  """ if idx < len(self.bullet_seq) * 8:
@@ -87,15 +123,16 @@ class Visualizer(SMAXVisualizer):
87
123
  self.render_bullets(screen, bullets, idx % 8) """
88
124
 
89
125
  # rotate the screen and append to frames
90
- frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))
126
+ pixels = pygame.surfarray.pixels3d(screen).swapaxes(0, 1)
127
+ # rotate the screen 180 degrees (transpose and flip)
128
+ pixels = np.rot90(pixels, 2) # pygame starts in bottom left
129
+ frames.append(pixels)
91
130
  # save the images
92
131
  clip = ImageSequenceClip(frames, fps=48)
93
132
  clip.write_videofile(save_fname, fps=48)
94
- # clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
133
+ clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
95
134
  pygame.quit()
96
135
 
97
- return clip
98
-
99
136
  def render_agents(self, screen, state):
100
137
  time_tuple = zip(
101
138
  state.unit_positions,
@@ -105,7 +142,7 @@ class Visualizer(SMAXVisualizer):
105
142
  )
106
143
  for idx, (pos, team, kind, hp) in enumerate(time_tuple):
107
144
  face_col = self.fg if int(team.item()) == 0 else self.bg
108
- pos = tuple((pos * self.scale).tolist())
145
+ pos = tuple(((pos * self.scale) + self.pad).tolist())
109
146
  # draw the agent
110
147
  if hp > 0:
111
148
  hp_frac = hp / self.env.unit_type_health[kind]
@@ -124,26 +161,31 @@ class Visualizer(SMAXVisualizer):
124
161
  # work out which agents are being shot
125
162
 
126
163
  def render_action(self, screen, action):
127
- def coord_fn(idx, n, team):
164
+ if self.env.action_type != "discrete":
165
+ return
166
+
167
+ def coord_fn(idx, n, team, text):
168
+ text_adj = text.get_width() / 2
169
+ is_ally = team == "ally"
128
170
  return (
129
- self.s / 20 if team == 0 else self.s - self.s / 20,
130
171
  # vertically centered so that n / 2 is above and below the center
172
+ self.pad + self.width + self.pad / 2 - text_adj
173
+ if is_ally
174
+ else self.pad / 2 - text_adj,
131
175
  self.s / 2 - (n / 2) * self.s / 20 + idx * self.s / 20,
132
176
  )
133
177
 
134
- for idx in range(self.env.num_allies):
135
- symb = action_to_symbol.get(action[f"ally_{idx}"].astype(int).item(), "Ø")
136
- font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
137
- text = font.render(symb, True, self.fg)
138
- coord = coord_fn(idx, self.env.num_allies, 0)
139
- screen.blit(text, coord)
140
-
141
- for idx in range(self.env.num_enemies):
142
- symb = action_to_symbol.get(action[f"enemy_{idx}"].astype(int).item(), "Ø")
143
- font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
144
- text = font.render(symb, True, self.fg)
145
- coord = coord_fn(idx, self.env.num_enemies, 1)
146
- screen.blit(text, coord)
178
+ for team, number in [("ally", 0), ("enemy", 1)]:
179
+ for idx in range(self.env.num_allies):
180
+ symb = action_to_symbol.get(
181
+ action[f"{team}_{idx}"].astype(int).item(), "Ø"
182
+ )
183
+ font = pygame.font.SysFont(
184
+ "Fira Code", jnp.sqrt(self.s).astype(int).item()
185
+ )
186
+ text = text_fn(font.render(symb, True, self.fg))
187
+ coord = coord_fn(idx, self.env.num_allies, team, text)
188
+ screen.blit(text, coord)
147
189
 
148
190
  def render_obstacles(self, screen):
149
191
  for c, d in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
@@ -238,7 +280,7 @@ if __name__ == "__main__":
238
280
  # exit()
239
281
 
240
282
  n_envs = 2
241
- env = Parabellum(scenarios["default"])
283
+ env = Parabellum(scenarios["default"], action_type="discrete")
242
284
  rng, reset_rng = random.split(random.PRNGKey(0))
243
285
  reset_key = random.split(reset_rng, n_envs)
244
286
  obs, state = vmap(env.reset)(reset_key)
@@ -248,7 +290,7 @@ if __name__ == "__main__":
248
290
  rng, act_rng, step_rng = random.split(rng, 3)
249
291
  act_key = random.split(act_rng, (len(env.agents), n_envs))
250
292
  act = {
251
- a: jnp.ones_like(vmap(env.action_space(a).sample)(act_key[i]))
293
+ a: vmap(env.action_space(a).sample)(act_key[i])
252
294
  for i, a in enumerate(env.agents)
253
295
  }
254
296
  step_key = random.split(step_rng, n_envs)
@@ -0,0 +1,90 @@
1
+ Metadata-Version: 2.1
2
+ Name: parabellum
3
+ Version: 0.2.20
4
+ Summary: Parabellum environment for parallel warfare simulation
5
+ Home-page: https://github.com/syrkis/parabellum
6
+ License: MIT
7
+ Keywords: warfare,simulation,parallel,environment
8
+ Author: Noah Syrkis
9
+ Author-email: desk@syrkis.com
10
+ Requires-Python: >=3.11,<4.0
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
16
+ Requires-Dist: geopandas (>=1.0.0,<2.0.0)
17
+ Requires-Dist: geopy (>=2.4.1,<3.0.0)
18
+ Requires-Dist: jax (==0.4.17)
19
+ Requires-Dist: jaxmarl (==0.0.3)
20
+ Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
21
+ Requires-Dist: moviepy (>=1.0.3,<2.0.0)
22
+ Requires-Dist: numpy (<2)
23
+ Requires-Dist: osmnx (>=1.9.3,<2.0.0)
24
+ Requires-Dist: poetry (>=1.8.3,<2.0.0)
25
+ Requires-Dist: pygame (>=2.5.2,<3.0.0)
26
+ Requires-Dist: rasterio (>=1.3.10,<2.0.0)
27
+ Requires-Dist: seaborn (>=0.13.2,<0.14.0)
28
+ Requires-Dist: tqdm (>=4.66.4,<5.0.0)
29
+ Project-URL: Repository, https://github.com/syrkis/parabellum
30
+ Description-Content-Type: text/markdown
31
+
32
+ # Parabellum
33
+
34
+ Ultra-scalable, JaxMARL-based warfare simulation engine developed with Armasuisse funding.
35
+
36
+ [![Documentation Status](https://readthedocs.org/projects/parabellum/badge/?version=latest)](https://parabellum.readthedocs.io/en/latest/?badge=latest)
37
+
38
+ ## Features
39
+
40
+ - Obstacles and terrain integration
41
+ - Rasterized maps
42
+ - Blast radii simulation
43
+ - Friendly fire mechanics
44
+ - Pygame visualization
45
+ - JAX-based parallelization
46
+
47
+ ## Install
48
+
49
+ ```bash
50
+ pip install parabellum
51
+ ```
52
+
53
+ ## Quick Start
54
+
55
+ ```python
56
+ import parabellum as pb
57
+ from jax import random
58
+
59
+ terrain = pb.terrain_fn("Thun, Switzerland", 1000)
60
+ scenario = pb.make_scenario("Thun", terrain, 10, 10)
61
+ env = pb.Environment(scenario)
62
+
63
+ rng, key = random.split(random.PRNGKey(0))
64
+ obs, state = env.reset(key)
65
+ state_sequence = []
66
+ # Simulation loop
67
+ for _ in range(100):
68
+ rng, rng_act, key_step = random.split(rng, 3)
69
+ key_act = random.split(rng_act, len(env.agents))
70
+ act = {a: env.action_space(a).sample(k) for a, k in zip(env.agents, key_act)}
71
+ state_sequence.append((key_act, state, act))
+ obs, state, reward, done, info = env.step(key_step, act, state)
72
+
73
+ # Visualize
74
+ vis = pb.Visualizer(env, state_sequence)
75
+ vis.animate()
76
+ ```
77
+
78
+ ## Documentation
79
+
80
+ Full documentation: [parabellum.readthedocs.io](https://parabellum.readthedocs.io)
81
+
82
+ ## Team
83
+
84
+ - Noah Syrkis
85
+ - Timothée Anne
86
+ - Supervisor: Sebastian Risi
87
+
88
+ ## License
89
+
90
+ MIT
@@ -0,0 +1,8 @@
1
+ parabellum/__init__.py,sha256=cI1kxVQ274VZrBLxUYvB_j6UqhFY44sId4NgVjbf678,245
2
+ parabellum/env.py,sha256=B2WaEzCQnakVB1AuYqFsb7aiPIbaUsuq0CjdKq4pQm8,16204
3
+ parabellum/map.py,sha256=CdPebRuGafj9fAydAyTa8bqpr-PldZw7heKRReGTDBg,1257
4
+ parabellum/run.py,sha256=HyPdz5iVD8q0iYaZL2Nf02fGHByHpecUGoPlQrq9v8s,3411
5
+ parabellum/vis.py,sha256=uIhN9VSJlT4XoNuNCiU_Bw2VSNoJFOYlScCTNrMkGsg,11977
6
+ parabellum-0.2.20.dist-info/METADATA,sha256=IO3Wsx_nTM3v63FtbnJqaEfPfXOyhNGtKOm_ppchUqo,2453
7
+ parabellum-0.2.20.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
8
+ parabellum-0.2.20.dist-info/RECORD,,
@@ -1,104 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: parabellum
3
- Version: 0.2.18
4
- Summary: Parabellum environment for parallel warfare simulation
5
- Home-page: https://github.com/syrkis/parabellum
6
- License: MIT
7
- Keywords: warfare,simulation,parallel,environment
8
- Author: Noah Syrkis
9
- Author-email: desk@syrkis.com
10
- Requires-Python: >=3.11,<4.0
11
- Classifier: License :: OSI Approved :: MIT License
12
- Classifier: Programming Language :: Python :: 3
13
- Classifier: Programming Language :: Python :: 3.11
14
- Classifier: Programming Language :: Python :: 3.12
15
- Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
16
- Requires-Dist: jax (==0.4.17)
17
- Requires-Dist: jaxmarl (==0.0.3)
18
- Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
19
- Requires-Dist: moviepy (>=1.0.3,<2.0.0)
20
- Requires-Dist: poetry (>=1.8.3,<2.0.0)
21
- Requires-Dist: pygame (>=2.5.2,<3.0.0)
22
- Requires-Dist: tqdm (>=4.66.4,<5.0.0)
23
- Project-URL: Repository, https://github.com/syrkis/parabellum
24
- Description-Content-Type: text/markdown
25
-
26
- # parabellum
27
-
28
- Parabellum is an ultra-scalable, high-performance warfare simulation engine.
29
- It is based on JaxMARL's SMAX environment, but has been heavily modified to
30
- support a wide range of new features and improvements.
31
-
32
- ## Installation
33
-
34
- Parabellum is written in Python 3.11 and can be installed using pip:
35
-
36
- ```bash
37
- pip install parabellum
38
- ```
39
-
40
- ## Usage
41
-
42
- Parabellum is designed to be used in conjunction with JAX, a high-performance
43
- numerical computing library. Here is a simple example of how to use Parabellum
44
- to simulate a game with 10 agents and 10 enemies, each taking random actions:
45
-
46
- ```python
47
- import parabellum as pb
48
- from jax import random
49
-
50
- # define the scenario
51
- kwargs = dict(obstacle_coords=[(7, 7)], obstacle_deltas=[(10, 0)])
52
- scenario = pb.Scenario(**kwargs) # <- Scenario is an important part of parabellum
53
-
54
- # create the environment
55
- kwargs = dict(map_width=256, map_height=256, num_agents=10, num_enemies=10)
56
- env = pb.Parabellum(**kwargs) # <- Parabellum is the central class of parabellum
57
-
58
- # initiate stochasticity
59
- rng = random.PRNGKey(0)
60
- rng, key = random.split(rng)
61
-
62
- # initialize the environment state
63
- obs, state = env.reset(key)
64
- state_sequence = []
65
-
66
- for _ in range(1000):
67
-
68
- # manage stochasticity
69
- rng, rng_act, key_step = random.split(key)
70
- key_act = random.split(rng_act, len(env.agents))
71
-
72
- # sample actions and append to state sequence
73
- act = {a: env.action_space(a).sample(k)
74
- for a, k in zip(env.agents, key_act)}
75
-
76
- # step the environment
77
- state_sequence.append((key_act, state, act))
78
- obs, state, reward, done, info = env.step(key_step, act, state)
79
-
80
-
81
- # save visualization of the state sequence
82
- vis = pb.Visualizer(env, state_sequence) # <- Visualizer is a nice to have class
83
- vis.animate()
84
- ```
85
-
86
-
87
- ## Features
88
-
89
- - Obstacles — can be inserted in
90
-
91
- ## TODO
92
-
93
- - [x] Parallel pygame vis
94
- - [ ] Parallel bullet renderings
95
- - [ ] Combine parallell plots into one (maybe out of parabellum scope)
96
- - [ ] Color for health?
97
- - [ ] Add the ability to see ongoing game.
98
- - [ ] Bug test friendly fire.
99
- - [x] Start sim from arbitrary state.
100
- - [ ] Save when the episode ends in some state/obs variable
101
- - [ ] Look for the source of the bug when using more Allies than Enemies
102
- - [ ] Y inversed axis for parabellum visualization
103
- - [ ] Units see through obstacles?
104
-
@@ -1,8 +0,0 @@
1
- parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
2
- parabellum/env.py,sha256=d6agGy-kTRIg_r0QKCL_7iztzwhaTfsb4yhtUQfdgx0,16024
3
- parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
4
- parabellum/run.py,sha256=0IWpqcQ_qfFeElbupF5vOs_CByFfpXYuGGUHYuurFM4,3412
5
- parabellum/vis.py,sha256=vb1ndWuGHBhG9Y_9ygFUDRGRfTTzKESD43jNCHMp2hQ,10564
6
- parabellum-0.2.18.dist-info/METADATA,sha256=-DXRDSO86-I9gZlFrnkRz6wkYplXGQY5TeCHfiPOOno,3223
7
- parabellum-0.2.18.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
8
- parabellum-0.2.18.dist-info/RECORD,,