parabellum 0.2.19__py3-none-any.whl → 0.2.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/__init__.py CHANGED
@@ -1,4 +1,12 @@
1
- from .env import Parabellum, Scenario, scenarios
1
+ from .env import Environment, Scenario, scenarios, make_scenario
2
2
  from .vis import Visualizer
3
+ from .map import terrain_fn
3
4
 
4
- __all__ = ["Parabellum", "Visualizer", "Scenario", "scenarios"]
5
+ __all__ = [
6
+ "Environment",
7
+ "Scenario",
8
+ "scenarios",
9
+ "make_scenario",
10
+ "Visualizer",
11
+ "terrain_fn",
12
+ ]
parabellum/env.py CHANGED
@@ -17,11 +17,8 @@ from functools import partial
17
17
  class Scenario:
18
18
  """Parabellum scenario"""
19
19
 
20
+ place: str
20
21
  terrain_raster: chex.Array
21
-
22
- obstacle_coords: chex.Array # TODO: use map instead of obstacles
23
- obstacle_deltas: chex.Array
24
-
25
22
  unit_types: chex.Array
26
23
  num_allies: int
27
24
  num_enemies: int
@@ -33,9 +30,8 @@ class Scenario:
33
30
  # default scenario
34
31
  scenarios = {
35
32
  "default": Scenario(
33
+ "Identity Town",
36
34
  jnp.eye(64, dtype=jnp.uint8),
37
- jnp.array([[80, 0], [16, 12]]),
38
- jnp.array([[0, 80], [0, 20]]),
39
35
  jnp.zeros((19,), dtype=jnp.uint8),
40
36
  9,
41
37
  10,
@@ -43,57 +39,63 @@ scenarios = {
43
39
  }
44
40
 
45
41
 
46
- class Parabellum(SMAX):
42
def make_scenario(place, terrain_raster, num_allies=9, num_enemies=10):
    """Assemble a Scenario for ``terrain_raster`` with uniform unit types.

    Args:
        place: human-readable location name stored on the scenario.
        terrain_raster: 2-D terrain array used as the map.
        num_allies: number of units on team 0.
        num_enemies: number of units on team 1.

    Returns:
        A ``Scenario`` whose ``unit_types`` is a uint8 array of zeros with one
        entry per unit.
    """
    total_units = num_allies + num_enemies
    # every unit gets type 0
    zero_types = jnp.zeros((total_units,), dtype=jnp.uint8)
    return Scenario(place, terrain_raster, zero_types, num_allies, num_enemies)
47
+
48
+
49
def spawn_fn(pool: jnp.ndarray, offset: jnp.ndarray, n: int, rng: jnp.ndarray):
    """Spawn ``n`` agents on distinct candidate cells of a sector.

    Args:
        pool: ``(rows, cols)`` tuple of index arrays of candidate cells, as
            returned by ``jnp.nonzero`` (see ``sector_fn``).
        offset: length-2 array added to every coordinate to map the
            pool's local indices back to map coordinates.
        n: number of agents to spawn; must not exceed the number of cells.
        rng: JAX PRNG key.

    Returns:
        ``(n, 2)`` array of positions: a distinct cell index plus uniform
        noise in ``[0, 0.5)`` on each axis, shifted by ``offset``.

    Raises:
        ValueError: if ``pool`` holds fewer than ``n`` cells.
    """
    num_cells = pool[0].shape[0]  # static shape, so this check is jit-safe
    if n > num_cells:
        raise ValueError(f"cannot spawn {n} agents on {num_cells} free cells")

    # Split into three keys (the first is unused) so the random stream is
    # unchanged from earlier behaviour for any given rng.
    _, key_start, key_noise = random.split(rng, 3)
    noise = random.uniform(key_noise, (n, 2)) * 0.5

    # sample n distinct cells from the pool
    idxs = random.choice(key_start, num_cells, (n,), replace=False)
    coords = jnp.stack([pool[0][idxs], pool[1][idxs]], axis=-1)

    return coords + noise + offset
59
+
60
+
61
def sector_fn(terrain: jnp.ndarray, sector_id: int, grid: int = 5):
    """Return the free cells of one sector of ``terrain``.

    The raster is tiled into ``grid`` x ``grid`` sectors numbered row-major,
    so with the default grid sector ``0`` is one corner and ``24`` the
    opposite corner.

    Args:
        terrain: 2-D raster; cells equal to ``0`` are treated as free.
        sector_id: sector index in ``[0, grid * grid)``.
        grid: number of sectors along each axis (default 5).

    Returns:
        ``(pool, offset)``. ``pool`` is the ``jnp.nonzero`` tuple of
        sector-local indices of free cells. ``offset`` is the sector's first
        (row, col) index, so ``pool + offset`` gives raster indices.

    Raises:
        ValueError: if ``sector_id`` is out of range.
    """
    if not 0 <= sector_id < grid * grid:
        raise ValueError(f"sector_id must be in [0, {grid * grid}), got {sector_id}")
    rows, cols = terrain.shape
    row, col = divmod(sector_id, grid)
    # Derive both bounds from the full extent so adjacent sectors tile the
    # raster with no gaps, even when its size is not a multiple of grid.
    x0, x1 = row * rows // grid, (row + 1) * rows // grid
    y0, y1 = col * cols // grid, (col + 1) * cols // grid
    free = terrain[x0:x1, y0:y1] == 0
    offset = jnp.array([x0, y0])
    return jnp.nonzero(free), offset
69
+
70
+
71
+ class Environment(SMAX):
47
72
  def __init__(self, scenario: Scenario, **kwargs):
48
73
  map_height, map_width = scenario.terrain_raster.shape
49
74
  args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
50
- super(Parabellum, self).__init__(**args, **kwargs)
75
+ super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
51
76
  self.terrain_raster = scenario.terrain_raster
52
- self.obstacle_coords = scenario.obstacle_coords
53
- self.obstacle_deltas = scenario.obstacle_deltas
54
- self.unit_type_attack_blasts = jnp.zeros((19,), dtype=jnp.float32)
77
+ # self.unit_type_names = ["tinker", "tailor", "soldier", "spy"]
78
+ # self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
79
+ # self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
80
+ self.scenario = scenario
81
+ self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
55
82
  self.max_steps = 200
56
83
  self._push_units_away = lambda x: x # overwrite push units
84
+ self.top_sector = sector_fn(self.terrain_raster, 0)
85
+ self.low_sector = sector_fn(self.terrain_raster, 24)
57
86
 
58
87
  @partial(jax.jit, static_argnums=(0,))
59
- def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
88
+ def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
60
89
  """Environment-specific reset."""
61
- key, team_0_key, team_1_key = jax.random.split(key, num=3)
62
- team_0_start = jnp.stack(
63
- [jnp.array([self.map_width / 4, self.map_height / 2])] * self.num_allies
64
- )
65
- team_0_start_noise = jax.random.uniform(
66
- team_0_key, shape=(self.num_allies, 2), minval=-2, maxval=2
67
- )
68
- team_0_start = team_0_start + team_0_start_noise
69
- team_1_start = jnp.stack(
70
- [jnp.array([self.map_width / 4 * 3, self.map_height / 2])]
71
- * self.num_enemies
72
- )
73
- team_1_start_noise = jax.random.uniform(
74
- team_1_key, shape=(self.num_enemies, 2), minval=-2, maxval=2
75
- )
76
- team_1_start = team_1_start + team_1_start_noise
90
+ ally_key, enemy_key = jax.random.split(rng)
91
+ team_0_start = spawn_fn(*self.top_sector, self.num_allies, ally_key)
92
+ team_1_start = spawn_fn(*self.low_sector, self.num_enemies, enemy_key)
77
93
  unit_positions = jnp.concatenate([team_0_start, team_1_start])
78
- key, pos_key = jax.random.split(key)
79
- generated_unit_positions = self.position_generator.generate(pos_key)
80
- unit_positions = jax.lax.select(
81
- self.smacv2_position_generation, generated_unit_positions, unit_positions
82
- )
83
94
  unit_teams = jnp.zeros((self.num_agents,))
84
95
  unit_teams = unit_teams.at[self.num_allies :].set(1)
85
96
  unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
86
97
  # default behaviour spawn all marines
87
- unit_types = (
88
- jnp.zeros((self.num_agents,), dtype=jnp.uint8)
89
- if self.scenario is None
90
- else self.scenario
91
- )
92
- key, unit_type_key = jax.random.split(key)
93
- generated_unit_types = self.unit_type_generator.generate(unit_type_key)
94
- unit_types = jax.lax.select(
95
- self.smacv2_unit_type_generation, generated_unit_types, unit_types
96
- )
98
+ unit_types = self.scenario.unit_types
97
99
  unit_health = self.unit_type_health[unit_types]
98
100
  state = State(
99
101
  unit_positions=unit_positions,
@@ -180,12 +182,12 @@ class Parabellum(SMAX):
180
182
 
181
183
  #######################################################################
182
184
  ############################################ avoid going into obstacles
183
- obs = self.obstacle_coords
185
+ """ obs = self.obstacle_coords
184
186
  obs_end = obs + self.obstacle_deltas
185
- inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end))
186
- rastersects = raster_crossing(pos, new_pos)
187
- flag = jnp.logical_or(inters, rastersects)
188
- new_pos = jnp.where(flag, pos, new_pos)
187
+ inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end)) """
188
+ clash = raster_crossing(pos, new_pos)
189
+ # flag = jnp.logical_or(inters, rastersects)
190
+ new_pos = jnp.where(clash, pos, new_pos)
189
191
 
190
192
  #######################################################################
191
193
  #######################################################################
@@ -310,14 +312,14 @@ class Parabellum(SMAX):
310
312
  [self.map_width, 0],
311
313
  ]
312
314
  )
313
- obstacle_coords = jnp.concatenate(
315
+ """ obstacle_coords = jnp.concatenate(
314
316
  [self.obstacle_coords, bondaries_coords]
315
317
  ) # add the map boundaries to the obstacles to avoid
316
318
  obstacle_deltas = jnp.concatenate(
317
319
  [self.obstacle_deltas, bondaries_deltas]
318
320
  ) # add the map boundaries to the obstacles to avoid
319
321
  obst_start = obstacle_coords
320
- obst_end = obst_start + obstacle_deltas
322
+ obst_end = obst_start + obstacle_deltas """
321
323
 
322
324
  def check_obstacles(pos, new_pos, obst_start, obst_end):
323
325
  intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
@@ -325,9 +327,9 @@ class Parabellum(SMAX):
325
327
  flag = jnp.logical_or(intersects, rastersect)
326
328
  return jnp.where(flag, pos, new_pos)
327
329
 
328
- pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
330
+ """ pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
329
331
  pos, new_pos, obst_start, obst_end
330
- )
332
+ ) """
331
333
 
332
334
  # Multiple enemies can attack the same unit.
333
335
  # We have `(health_diff, attacked_idx)` pairs.
@@ -373,13 +375,16 @@ class Parabellum(SMAX):
373
375
 
374
376
  if __name__ == "__main__":
375
377
  n_envs = 4
376
- kwargs = dict(map_width=64, map_height=64)
377
- env = Parabellum(scenarios["default"], **kwargs)
378
+
379
+ env = Environment(scenarios["default"])
378
380
  rng, reset_rng = random.split(random.PRNGKey(0))
379
381
  reset_key = random.split(reset_rng, n_envs)
380
382
  obs, state = vmap(env.reset)(reset_key)
381
383
  state_seq = []
382
384
 
385
+ print(state.unit_positions)
386
+ exit()
387
+
383
388
  for i in range(10):
384
389
  rng, act_rng, step_rng = random.split(rng, 3)
385
390
  act_key = random.split(act_rng, (len(env.agents), n_envs))
parabellum/map.py CHANGED
@@ -4,13 +4,46 @@
4
4
 
5
5
  # imports
6
6
  import jax.numpy as jnp
7
- import jax
7
+ from geopy.geocoders import Nominatim
8
+ import geopandas as gpd
9
+ import osmnx as ox
10
+ import rasterio
11
+ from jax import random
12
+ from rasterio import features
13
+ import rasterio.transform
14
+
15
+ # constants
16
+ geolocator = Nominatim(user_agent="parabellum")
17
+ tags = {"building": True}
8
18
 
9
19
 
10
20
  # functions
11
- def map_fn(width, height, obst_coord, obst_delta):
12
- """Create a map from the given width, height, and obstacle coordinates and deltas."""
13
- m = jnp.zeros((width, height))
14
- for (x, y), (dx, dy) in zip(obst_coord, obst_delta):
15
- m = m.at[x : x + dx, y : y + dy].set(1)
16
- return m
21
def terrain_fn(place: str, size: int = 1000):
    """Return a rasterized building map around a named place.

    Geocodes ``place`` with Nominatim, downloads OpenStreetMap features
    matching ``tags`` within ``size // 2`` of that point, and burns their
    geometries into a ``size`` x ``size`` raster (1 where a feature lies,
    0 elsewhere).

    Args:
        place: free-text location understood by Nominatim.
        size: raster side length in cells; half of it is also used as the
            ``dist`` passed to ``osmnx.features_from_point``.

    Returns:
        ``(size, size)`` uint8 JAX array, rotated 180 degrees.

    Raises:
        ValueError: if Nominatim cannot geocode ``place``.
    """
    # location info (geocode returns None when nothing matches)
    location = geolocator.geocode(place)
    if location is None:
        raise ValueError(f"could not geocode place: {place!r}")
    coords = (location.latitude, location.longitude)

    # shape info
    geometry = ox.features_from_point(coords, tags=tags, dist=size // 2)
    gdf = gpd.GeoDataFrame(geometry).set_crs("EPSG:4326")

    # raster info
    # NOTE(review): bounds come from the returned features, not the queried
    # square, so the raster extent may differ slightly from the query area.
    transform = rasterio.transform.from_bounds(*gdf.total_bounds, size, size)
    raster = features.rasterize(gdf.geometry, out_shape=(size, size), transform=transform)

    # rotate 180 degrees
    return jnp.rot90(jnp.asarray(raster), 2).astype(jnp.uint8)
40
+
41
+
42
if __name__ == "__main__":
    # spawn_fn and sector_fn are defined in env.py, not in this module
    from parabellum.env import sector_fn, spawn_fn

    place = "Vesterbro, Copenhagen, Denmark"
    terrain = terrain_fn(place)
    rng, key = random.split(random.PRNGKey(0))
    # spawn 100 agents in sector 12, the centre of the 5x5 sector grid
    pool, offset = sector_fn(terrain, 12)
    agents = spawn_fn(pool, offset, 100, key)
    print(agents)
parabellum/run.py CHANGED
@@ -39,7 +39,7 @@ class Game:
39
39
  obs: Dict
40
40
  state_seq: StateSeq
41
41
  control: Control
42
- env: pb.Parabellum
42
+ env: pb.Environment
43
43
  rng: random.PRNGKey
44
44
 
45
45
 
@@ -123,5 +123,3 @@ if __name__ == "__main__":
123
123
  game = game if game.control.paused else render(game)
124
124
 
125
125
  pygame.quit()
126
-
127
-
parabellum/vis.py CHANGED
@@ -11,6 +11,7 @@ from functools import partial
11
11
  import darkdetect
12
12
  import numpy as np
13
13
  import pygame
14
+ import matplotlib.pyplot as plt
14
15
  import os
15
16
  from moviepy.editor import ImageSequenceClip
16
17
  from typing import Optional
@@ -33,15 +34,24 @@ def small_multiples():
33
34
  print(len(clips))
34
35
 
35
36
 
37
def text_fn(text):
    """Return the pygame surface ``text`` turned upside down.

    Finished frames are rotated 180 degrees with ``np.rot90(..., 2)`` before
    being saved, so text is pre-rotated by the same amount to read upright.
    """
    half_turn = 180  # degrees
    flipped = pygame.transform.rotate(text, half_turn)
    return flipped
40
+
41
+
36
42
  class Visualizer(SMAXVisualizer):
37
- def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
38
- super().__init__(env, state_seq, reward_seq)
39
- # remove fig and ax from super
43
+ def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None, skin=None):
40
44
  self.fig, self.ax = None, None
45
+ super().__init__(env, state_seq, reward_seq)
46
+ # clear the figure made by SMAXVisualizer
47
+ plt.close()
41
48
  self.bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
42
49
  self.fg = (255, 255, 255) if darkdetect.isDark() else (0, 0, 0)
43
- self.s = 1000
44
- self.scale = self.s / self.env.map_width
50
+ self.pad = 75
51
+ # TODO: make sure it's always a 1024x1024 image
52
+ self.width = 1000
53
+ self.s = self.width + self.pad + self.pad
54
+ self.scale = self.width / env.map_width
45
55
  self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
46
56
  # self.bullet_seq = vmap(partial(bullet_fn, self.env))(self.state_seq)
47
57
 
@@ -65,13 +75,13 @@ class Visualizer(SMAXVisualizer):
65
75
  def animate_one(self, state_seq, action_seq, save_fname):
66
76
  frames = [] # frames for the video
67
77
  pygame.init() # initialize pygame
68
- terrain = np.array(self.env.terrain_raster)
78
+ terrain = np.array(self.env.terrain_raster.T)
69
79
  rgb_array = np.zeros((terrain.shape[0], terrain.shape[1], 3), dtype=np.uint8)
70
80
  if darkdetect.isLight():
71
81
  rgb_array += 255
72
82
  rgb_array[terrain == 1] = self.fg
73
83
  mask_surface = pygame.surfarray.make_surface(rgb_array)
74
- mask_surface = pygame.transform.scale(mask_surface, (self.s, self.s))
84
+ mask_surface = pygame.transform.scale(mask_surface, (self.width, self.width))
75
85
 
76
86
  for idx, (_, state, _) in tqdm(enumerate(state_seq), total=len(self.state_seq)):
77
87
  action = action_seq[idx // self.env.world_steps_per_env_step]
@@ -79,11 +89,33 @@ class Visualizer(SMAXVisualizer):
79
89
  (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
80
90
  )
81
91
  screen.fill(self.bg) # fill the screen with the background color
82
- screen.blit(mask_surface, (0, 0))
92
+ screen.blit(
93
+ mask_surface,
94
+ (self.pad, self.pad, self.width, self.width),
95
+ )
96
+ # add env.scenario.place to the title (in top padding)
97
+ font = pygame.font.SysFont("Fira Code", 18)
98
+ width = self.env.map_width
99
+ title = f"{width}x{width}m in {self.env.scenario.place}"
100
+ text = text_fn(font.render(title, True, self.fg))
101
+ # center the text
102
+ screen.blit(text, (self.s // 2 - text.get_width() // 2, self.pad // 4))
103
+ # draw edge around terrain
104
+ pygame.draw.rect(
105
+ screen,
106
+ self.fg,
107
+ (
108
+ self.pad - 2,
109
+ self.pad - 2,
110
+ self.width + 4,
111
+ self.width + 4,
112
+ ),
113
+ 2,
114
+ )
83
115
 
84
116
  self.render_agents(screen, state) # render the agents
85
117
  self.render_action(screen, action)
86
- self.render_obstacles(screen) # render the obstacles
118
+ # self.render_obstacles(screen) # render the obstacles
87
119
 
88
120
  # bullets
89
121
  """ if idx < len(self.bullet_seq) * 8:
@@ -91,15 +123,16 @@ class Visualizer(SMAXVisualizer):
91
123
  self.render_bullets(screen, bullets, idx % 8) """
92
124
 
93
125
  # rotate the screen and append to frames
94
- frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))
126
+ pixels = pygame.surfarray.pixels3d(screen).swapaxes(0, 1)
127
+ # rotate the screen 180 degrees (transpose and flip)
128
+ pixels = np.rot90(pixels, 2) # pygame starts in bottom left
129
+ frames.append(pixels)
95
130
  # save the images
96
131
  clip = ImageSequenceClip(frames, fps=48)
97
132
  clip.write_videofile(save_fname, fps=48)
98
133
  clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
99
134
  pygame.quit()
100
135
 
101
- return clip
102
-
103
136
  def render_agents(self, screen, state):
104
137
  time_tuple = zip(
105
138
  state.unit_positions,
@@ -109,7 +142,7 @@ class Visualizer(SMAXVisualizer):
109
142
  )
110
143
  for idx, (pos, team, kind, hp) in enumerate(time_tuple):
111
144
  face_col = self.fg if int(team.item()) == 0 else self.bg
112
- pos = tuple((pos * self.scale).tolist())
145
+ pos = tuple(((pos * self.scale) + self.pad).tolist())
113
146
  # draw the agent
114
147
  if hp > 0:
115
148
  hp_frac = hp / self.env.unit_type_health[kind]
@@ -131,26 +164,28 @@ class Visualizer(SMAXVisualizer):
131
164
  if self.env.action_type != "discrete":
132
165
  return
133
166
 
134
- def coord_fn(idx, n, team):
167
+ def coord_fn(idx, n, team, text):
168
+ text_adj = text.get_width() / 2
169
+ is_ally = team == "ally"
135
170
  return (
136
- self.s / 20 if team == 0 else self.s - self.s / 20,
137
171
  # vertically centered so that n / 2 is above and below the center
172
+ self.pad + self.width + self.pad / 2 - text_adj
173
+ if is_ally
174
+ else self.pad / 2 - text_adj,
138
175
  self.s / 2 - (n / 2) * self.s / 20 + idx * self.s / 20,
139
176
  )
140
177
 
141
- for idx in range(self.env.num_allies):
142
- symb = action_to_symbol.get(action[f"ally_{idx}"].astype(int).item(), "Ø")
143
- font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
144
- text = font.render(symb, True, self.fg)
145
- coord = coord_fn(idx, self.env.num_allies, 0)
146
- screen.blit(text, coord)
147
-
148
- for idx in range(self.env.num_enemies):
149
- symb = action_to_symbol.get(action[f"enemy_{idx}"].astype(int).item(), "Ø")
150
- font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
151
- text = font.render(symb, True, self.fg)
152
- coord = coord_fn(idx, self.env.num_enemies, 1)
153
- screen.blit(text, coord)
178
+ for team, number in [("ally", 0), ("enemy", 1)]:
179
+ for idx in range(self.env.num_allies):
180
+ symb = action_to_symbol.get(
181
+ action[f"{team}_{idx}"].astype(int).item(), "Ø"
182
+ )
183
+ font = pygame.font.SysFont(
184
+ "Fira Code", jnp.sqrt(self.s).astype(int).item()
185
+ )
186
+ text = text_fn(font.render(symb, True, self.fg))
187
+ coord = coord_fn(idx, self.env.num_allies, team, text)
188
+ screen.blit(text, coord)
154
189
 
155
190
  def render_obstacles(self, screen):
156
191
  for c, d in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
@@ -0,0 +1,90 @@
1
+ Metadata-Version: 2.1
2
+ Name: parabellum
3
+ Version: 0.2.20
4
+ Summary: Parabellum environment for parallel warfare simulation
5
+ Home-page: https://github.com/syrkis/parabellum
6
+ License: MIT
7
+ Keywords: warfare,simulation,parallel,environment
8
+ Author: Noah Syrkis
9
+ Author-email: desk@syrkis.com
10
+ Requires-Python: >=3.11,<4.0
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
16
+ Requires-Dist: geopandas (>=1.0.0,<2.0.0)
17
+ Requires-Dist: geopy (>=2.4.1,<3.0.0)
18
+ Requires-Dist: jax (==0.4.17)
19
+ Requires-Dist: jaxmarl (==0.0.3)
20
+ Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
21
+ Requires-Dist: moviepy (>=1.0.3,<2.0.0)
22
+ Requires-Dist: numpy (<2)
23
+ Requires-Dist: osmnx (>=1.9.3,<2.0.0)
24
+ Requires-Dist: poetry (>=1.8.3,<2.0.0)
25
+ Requires-Dist: pygame (>=2.5.2,<3.0.0)
26
+ Requires-Dist: rasterio (>=1.3.10,<2.0.0)
27
+ Requires-Dist: seaborn (>=0.13.2,<0.14.0)
28
+ Requires-Dist: tqdm (>=4.66.4,<5.0.0)
29
+ Project-URL: Repository, https://github.com/syrkis/parabellum
30
+ Description-Content-Type: text/markdown
31
+
32
+ # Parabellum
33
+
34
+ Ultra-scalable, JaxMARL-based warfare simulation engine developed with Armasuisse funding.
35
+
36
+ [![Documentation Status](https://readthedocs.org/projects/parabellum/badge/?version=latest)](https://parabellum.readthedocs.io/en/latest/?badge=latest)
37
+
38
+ ## Features
39
+
40
+ - Obstacles and terrain integration
41
+ - Rasterized maps
42
+ - Blast radii simulation
43
+ - Friendly fire mechanics
44
+ - Pygame visualization
45
+ - JAX-based parallelization
46
+
47
+ ## Install
48
+
49
+ ```bash
50
+ pip install parabellum
51
+ ```
52
+
53
+ ## Quick Start
54
+
55
+ ```python
56
+ import parabellum as pb
57
+ from jax import random
58
+
59
+ terrain = pb.terrain_fn("Thun, Switzerland", 1000)
60
+ scenario = pb.make_scenario("Thun", terrain, 10, 10)
61
+ env = pb.Environment(scenario)
62
+
63
+ rng, key = random.split(random.PRNGKey(0))
64
+ obs, state = env.reset(key)
65
+ state_sequence = []
66
+ # Simulation loop
67
+ for _ in range(100):
68
+ rng, rng_act, key_step = random.split(rng, 3)
69
+ key_act = random.split(rng_act, len(env.agents))
70
+ act = {a: env.action_space(a).sample(k) for a, k in zip(env.agents, key_act)}
71
+ state_sequence.append((key_act, state, act))
72
+ obs, state, reward, done, info = env.step(key_step, act, state)
73
+ # Visualize
74
+ vis = pb.Visualizer(env, state_sequence)
75
+ vis.animate()
76
+ ```
77
+
78
+ ## Documentation
79
+
80
+ Full documentation: [parabellum.readthedocs.io](https://parabellum.readthedocs.io)
81
+
82
+ ## Team
83
+
84
+ - Noah Syrkis
85
+ - Timothée Anne
86
+ - Supervisor: Sebastian Risi
87
+
88
+ ## License
89
+
90
+ MIT
@@ -0,0 +1,8 @@
1
+ parabellum/__init__.py,sha256=cI1kxVQ274VZrBLxUYvB_j6UqhFY44sId4NgVjbf678,245
2
+ parabellum/env.py,sha256=B2WaEzCQnakVB1AuYqFsb7aiPIbaUsuq0CjdKq4pQm8,16204
3
+ parabellum/map.py,sha256=CdPebRuGafj9fAydAyTa8bqpr-PldZw7heKRReGTDBg,1257
4
+ parabellum/run.py,sha256=HyPdz5iVD8q0iYaZL2Nf02fGHByHpecUGoPlQrq9v8s,3411
5
+ parabellum/vis.py,sha256=uIhN9VSJlT4XoNuNCiU_Bw2VSNoJFOYlScCTNrMkGsg,11977
6
+ parabellum-0.2.20.dist-info/METADATA,sha256=IO3Wsx_nTM3v63FtbnJqaEfPfXOyhNGtKOm_ppchUqo,2453
7
+ parabellum-0.2.20.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
8
+ parabellum-0.2.20.dist-info/RECORD,,
@@ -1,104 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: parabellum
3
- Version: 0.2.19
4
- Summary: Parabellum environment for parallel warfare simulation
5
- Home-page: https://github.com/syrkis/parabellum
6
- License: MIT
7
- Keywords: warfare,simulation,parallel,environment
8
- Author: Noah Syrkis
9
- Author-email: desk@syrkis.com
10
- Requires-Python: >=3.11,<4.0
11
- Classifier: License :: OSI Approved :: MIT License
12
- Classifier: Programming Language :: Python :: 3
13
- Classifier: Programming Language :: Python :: 3.11
14
- Classifier: Programming Language :: Python :: 3.12
15
- Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
16
- Requires-Dist: jax (==0.4.17)
17
- Requires-Dist: jaxmarl (==0.0.3)
18
- Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
19
- Requires-Dist: moviepy (>=1.0.3,<2.0.0)
20
- Requires-Dist: poetry (>=1.8.3,<2.0.0)
21
- Requires-Dist: pygame (>=2.5.2,<3.0.0)
22
- Requires-Dist: tqdm (>=4.66.4,<5.0.0)
23
- Project-URL: Repository, https://github.com/syrkis/parabellum
24
- Description-Content-Type: text/markdown
25
-
26
- # parabellum
27
-
28
- Parabellum is an ultra-scalable, high-performance warfare simulation engine.
29
- It is based on JaxMARL's SMAX environment, but has been heavily modified to
30
- support a wide range of new features and improvements.
31
-
32
- ## Installation
33
-
34
- Parabellum is written in Python 3.11 and can be installed using pip:
35
-
36
- ```bash
37
- pip install parabellum
38
- ```
39
-
40
- ## Usage
41
-
42
- Parabellum is designed to be used in conjunction with JAX, a high-performance
43
- numerical computing library. Here is a simple example of how to use Parabellum
44
- to simulate a game with 10 agents and 10 enemies, each taking random actions:
45
-
46
- ```python
47
- import parabellum as pb
48
- from jax import random
49
-
50
- # define the scenario
51
- kwargs = dict(obstacle_coords=[(7, 7)], obstacle_deltas=[(10, 0)])
52
- scenario = pb.Scenario(**kwargs) # <- Scenario is an important part of parabellum
53
-
54
- # create the environment
55
- kwargs = dict(map_width=256, map_height=256, num_agents=10, num_enemies=10)
56
- env = pb.Parabellum(**kwargs) # <- Parabellum is the central class of parabellum
57
-
58
- # initiate stochasticity
59
- rng = random.PRNGKey(0)
60
- rng, key = random.split(rng)
61
-
62
- # initialize the environment state
63
- obs, state = env.reset(key)
64
- state_sequence = []
65
-
66
- for _ in range(1000):
67
-
68
- # manage stochasticity
69
- rng, rng_act, key_step = random.split(key)
70
- key_act = random.split(rng_act, len(env.agents))
71
-
72
- # sample actions and append to state sequence
73
- act = {a: env.action_space(a).sample(k)
74
- for a, k in zip(env.agents, key_act)}
75
-
76
- # step the environment
77
- state_sequence.append((key_act, state, act))
78
- obs, state, reward, done, info = env.step(key_step, act, state)
79
-
80
-
81
- # save visualization of the state sequence
82
- vis = pb.Visualizer(env, state_sequence) # <- Visualizer is a nice to have class
83
- vis.animate()
84
- ```
85
-
86
-
87
- ## Features
88
-
89
- - Obstacles — can be inserted in
90
-
91
- ## TODO
92
-
93
- - [x] Parallel pygame vis
94
- - [ ] Parallel bullet renderings
95
- - [ ] Combine parallell plots into one (maybe out of parabellum scope)
96
- - [ ] Color for health?
97
- - [ ] Add the ability to see ongoing game.
98
- - [ ] Bug test friendly fire.
99
- - [x] Start sim from arbitrary state.
100
- - [ ] Save when the episode ends in some state/obs variable
101
- - [ ] Look for the source of the bug when using more Allies than Enemies
102
- - [ ] Y inversed axis for parabellum visualization
103
- - [ ] Units see through obstacles?
104
-
@@ -1,8 +0,0 @@
1
- parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
2
- parabellum/env.py,sha256=Z8zpdCaEi5HFwN0Vd2hukOarkPSg0EKZErTRts3JQ5E,16023
3
- parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
4
- parabellum/run.py,sha256=0IWpqcQ_qfFeElbupF5vOs_CByFfpXYuGGUHYuurFM4,3412
5
- parabellum/vis.py,sha256=JFVTnBg-LV4jZNw6cysU6NS8ZxeMpg5wz3JOi-lrnzY,10699
6
- parabellum-0.2.19.dist-info/METADATA,sha256=DsEBAlESj8BwGSphmPPylStoXH_g_x_Iy3WJ3KEwjc0,3223
7
- parabellum-0.2.19.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
8
- parabellum-0.2.19.dist-info/RECORD,,