parabellum 0.2.19__py3-none-any.whl → 0.2.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum/__init__.py +10 -2
- parabellum/env.py +59 -54
- parabellum/map.py +40 -7
- parabellum/run.py +1 -3
- parabellum/vis.py +63 -28
- parabellum-0.2.20.dist-info/METADATA +90 -0
- parabellum-0.2.20.dist-info/RECORD +8 -0
- parabellum-0.2.19.dist-info/METADATA +0 -104
- parabellum-0.2.19.dist-info/RECORD +0 -8
- {parabellum-0.2.19.dist-info → parabellum-0.2.20.dist-info}/WHEEL +0 -0
parabellum/__init__.py
CHANGED
@@ -1,4 +1,12 @@
|
|
1
|
-
from .env import
|
1
|
+
from .env import Environment, Scenario, scenarios, make_scenario
|
2
2
|
from .vis import Visualizer
|
3
|
+
from .map import terrain_fn
|
3
4
|
|
4
|
-
__all__ = [
|
5
|
+
__all__ = [
|
6
|
+
"Environment",
|
7
|
+
"Scenario",
|
8
|
+
"scenarios",
|
9
|
+
"make_scenario",
|
10
|
+
"Visualizer",
|
11
|
+
"terrain_fn",
|
12
|
+
]
|
parabellum/env.py
CHANGED
@@ -17,11 +17,8 @@ from functools import partial
|
|
17
17
|
class Scenario:
|
18
18
|
"""Parabellum scenario"""
|
19
19
|
|
20
|
+
place: str
|
20
21
|
terrain_raster: chex.Array
|
21
|
-
|
22
|
-
obstacle_coords: chex.Array # TODO: use map instead of obstacles
|
23
|
-
obstacle_deltas: chex.Array
|
24
|
-
|
25
22
|
unit_types: chex.Array
|
26
23
|
num_allies: int
|
27
24
|
num_enemies: int
|
@@ -33,9 +30,8 @@ class Scenario:
|
|
33
30
|
# default scenario
|
34
31
|
scenarios = {
|
35
32
|
"default": Scenario(
|
33
|
+
"Identity Town",
|
36
34
|
jnp.eye(64, dtype=jnp.uint8),
|
37
|
-
jnp.array([[80, 0], [16, 12]]),
|
38
|
-
jnp.array([[0, 80], [0, 20]]),
|
39
35
|
jnp.zeros((19,), dtype=jnp.uint8),
|
40
36
|
9,
|
41
37
|
10,
|
@@ -43,57 +39,63 @@ scenarios = {
|
|
43
39
|
}
|
44
40
|
|
45
41
|
|
46
|
-
|
42
|
+
def make_scenario(place, terrain_raster, num_allies=9, num_enemies=10):
|
43
|
+
"""Create a scenario"""
|
44
|
+
num_agents = num_allies + num_enemies
|
45
|
+
unit_types = jnp.zeros((num_agents,)).astype(jnp.uint8)
|
46
|
+
return Scenario(place, terrain_raster, unit_types, num_allies, num_enemies)
|
47
|
+
|
48
|
+
|
49
|
+
def spawn_fn(pool: jnp.ndarray, offset: jnp.ndarray, n: int, rng: jnp.ndarray):
|
50
|
+
"""Spawns n agents on a map."""
|
51
|
+
rng, key_start, key_noise = random.split(rng, 3)
|
52
|
+
noise = random.uniform(key_noise, (n, 2)) * 0.5
|
53
|
+
|
54
|
+
# select n random (x, y)-coords where sector == True
|
55
|
+
idxs = random.choice(key_start, pool[0].shape[0], (n,), replace=False)
|
56
|
+
coords = jnp.array([pool[0][idxs], pool[1][idxs]]).T
|
57
|
+
|
58
|
+
return coords + noise + offset
|
59
|
+
|
60
|
+
|
61
|
+
def sector_fn(terrain: jnp.ndarray, sector_id: int):
|
62
|
+
"""return sector slice of terrain"""
|
63
|
+
width, height = terrain.shape
|
64
|
+
coordx, coordy = sector_id // 5 * width // 5, sector_id % 5 * height // 5
|
65
|
+
sector = terrain[coordx : coordx + width // 5, coordy : coordy + height // 5] == 0
|
66
|
+
offset = jnp.array([coordx, coordy])
|
67
|
+
# sector is jnp.nonzero
|
68
|
+
return jnp.nonzero(sector), offset
|
69
|
+
|
70
|
+
|
71
|
+
class Environment(SMAX):
|
47
72
|
def __init__(self, scenario: Scenario, **kwargs):
|
48
73
|
map_height, map_width = scenario.terrain_raster.shape
|
49
74
|
args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
|
50
|
-
super(
|
75
|
+
super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
|
51
76
|
self.terrain_raster = scenario.terrain_raster
|
52
|
-
self.
|
53
|
-
self.
|
54
|
-
self.
|
77
|
+
# self.unit_type_names = ["tinker", "tailor", "soldier", "spy"]
|
78
|
+
# self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
|
79
|
+
# self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
|
80
|
+
self.scenario = scenario
|
81
|
+
self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
|
55
82
|
self.max_steps = 200
|
56
83
|
self._push_units_away = lambda x: x # overwrite push units
|
84
|
+
self.top_sector = sector_fn(self.terrain_raster, 0)
|
85
|
+
self.low_sector = sector_fn(self.terrain_raster, 24)
|
57
86
|
|
58
87
|
@partial(jax.jit, static_argnums=(0,))
|
59
|
-
def reset(self,
|
88
|
+
def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
|
60
89
|
"""Environment-specific reset."""
|
61
|
-
|
62
|
-
team_0_start =
|
63
|
-
|
64
|
-
)
|
65
|
-
team_0_start_noise = jax.random.uniform(
|
66
|
-
team_0_key, shape=(self.num_allies, 2), minval=-2, maxval=2
|
67
|
-
)
|
68
|
-
team_0_start = team_0_start + team_0_start_noise
|
69
|
-
team_1_start = jnp.stack(
|
70
|
-
[jnp.array([self.map_width / 4 * 3, self.map_height / 2])]
|
71
|
-
* self.num_enemies
|
72
|
-
)
|
73
|
-
team_1_start_noise = jax.random.uniform(
|
74
|
-
team_1_key, shape=(self.num_enemies, 2), minval=-2, maxval=2
|
75
|
-
)
|
76
|
-
team_1_start = team_1_start + team_1_start_noise
|
90
|
+
ally_key, enemy_key = jax.random.split(rng)
|
91
|
+
team_0_start = spawn_fn(*self.top_sector, self.num_allies, ally_key)
|
92
|
+
team_1_start = spawn_fn(*self.low_sector, self.num_enemies, enemy_key)
|
77
93
|
unit_positions = jnp.concatenate([team_0_start, team_1_start])
|
78
|
-
key, pos_key = jax.random.split(key)
|
79
|
-
generated_unit_positions = self.position_generator.generate(pos_key)
|
80
|
-
unit_positions = jax.lax.select(
|
81
|
-
self.smacv2_position_generation, generated_unit_positions, unit_positions
|
82
|
-
)
|
83
94
|
unit_teams = jnp.zeros((self.num_agents,))
|
84
95
|
unit_teams = unit_teams.at[self.num_allies :].set(1)
|
85
96
|
unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
|
86
97
|
# default behaviour spawn all marines
|
87
|
-
unit_types =
|
88
|
-
jnp.zeros((self.num_agents,), dtype=jnp.uint8)
|
89
|
-
if self.scenario is None
|
90
|
-
else self.scenario
|
91
|
-
)
|
92
|
-
key, unit_type_key = jax.random.split(key)
|
93
|
-
generated_unit_types = self.unit_type_generator.generate(unit_type_key)
|
94
|
-
unit_types = jax.lax.select(
|
95
|
-
self.smacv2_unit_type_generation, generated_unit_types, unit_types
|
96
|
-
)
|
98
|
+
unit_types = self.scenario.unit_types
|
97
99
|
unit_health = self.unit_type_health[unit_types]
|
98
100
|
state = State(
|
99
101
|
unit_positions=unit_positions,
|
@@ -180,12 +182,12 @@ class Parabellum(SMAX):
|
|
180
182
|
|
181
183
|
#######################################################################
|
182
184
|
############################################ avoid going into obstacles
|
183
|
-
obs = self.obstacle_coords
|
185
|
+
""" obs = self.obstacle_coords
|
184
186
|
obs_end = obs + self.obstacle_deltas
|
185
|
-
inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end))
|
186
|
-
|
187
|
-
flag = jnp.logical_or(inters, rastersects)
|
188
|
-
new_pos = jnp.where(
|
187
|
+
inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end)) """
|
188
|
+
clash = raster_crossing(pos, new_pos)
|
189
|
+
# flag = jnp.logical_or(inters, rastersects)
|
190
|
+
new_pos = jnp.where(clash, pos, new_pos)
|
189
191
|
|
190
192
|
#######################################################################
|
191
193
|
#######################################################################
|
@@ -310,14 +312,14 @@ class Parabellum(SMAX):
|
|
310
312
|
[self.map_width, 0],
|
311
313
|
]
|
312
314
|
)
|
313
|
-
obstacle_coords = jnp.concatenate(
|
315
|
+
""" obstacle_coords = jnp.concatenate(
|
314
316
|
[self.obstacle_coords, bondaries_coords]
|
315
317
|
) # add the map boundaries to the obstacles to avoid
|
316
318
|
obstacle_deltas = jnp.concatenate(
|
317
319
|
[self.obstacle_deltas, bondaries_deltas]
|
318
320
|
) # add the map boundaries to the obstacles to avoid
|
319
321
|
obst_start = obstacle_coords
|
320
|
-
obst_end = obst_start + obstacle_deltas
|
322
|
+
obst_end = obst_start + obstacle_deltas """
|
321
323
|
|
322
324
|
def check_obstacles(pos, new_pos, obst_start, obst_end):
|
323
325
|
intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
|
@@ -325,9 +327,9 @@ class Parabellum(SMAX):
|
|
325
327
|
flag = jnp.logical_or(intersects, rastersect)
|
326
328
|
return jnp.where(flag, pos, new_pos)
|
327
329
|
|
328
|
-
pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
|
330
|
+
""" pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
|
329
331
|
pos, new_pos, obst_start, obst_end
|
330
|
-
)
|
332
|
+
) """
|
331
333
|
|
332
334
|
# Multiple enemies can attack the same unit.
|
333
335
|
# We have `(health_diff, attacked_idx)` pairs.
|
@@ -373,13 +375,16 @@ class Parabellum(SMAX):
|
|
373
375
|
|
374
376
|
if __name__ == "__main__":
|
375
377
|
n_envs = 4
|
376
|
-
|
377
|
-
env =
|
378
|
+
|
379
|
+
env = Environment(scenarios["default"])
|
378
380
|
rng, reset_rng = random.split(random.PRNGKey(0))
|
379
381
|
reset_key = random.split(reset_rng, n_envs)
|
380
382
|
obs, state = vmap(env.reset)(reset_key)
|
381
383
|
state_seq = []
|
382
384
|
|
385
|
+
print(state.unit_positions)
|
386
|
+
exit()
|
387
|
+
|
383
388
|
for i in range(10):
|
384
389
|
rng, act_rng, step_rng = random.split(rng, 3)
|
385
390
|
act_key = random.split(act_rng, (len(env.agents), n_envs))
|
parabellum/map.py
CHANGED
@@ -4,13 +4,46 @@
|
|
4
4
|
|
5
5
|
# imports
|
6
6
|
import jax.numpy as jnp
|
7
|
-
import
|
7
|
+
from geopy.geocoders import Nominatim
|
8
|
+
import geopandas as gpd
|
9
|
+
import osmnx as ox
|
10
|
+
import rasterio
|
11
|
+
from jax import random
|
12
|
+
from rasterio import features
|
13
|
+
import rasterio.transform
|
14
|
+
|
15
|
+
# constants
|
16
|
+
geolocator = Nominatim(user_agent="parabellum")
|
17
|
+
tags = {"building": True}
|
8
18
|
|
9
19
|
|
10
20
|
# functions
|
11
|
-
def
|
12
|
-
"""
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
21
|
+
def terrain_fn(place: str, size: int = 1000):
|
22
|
+
"""Returns a rasterized map of a given location."""
|
23
|
+
|
24
|
+
# location info
|
25
|
+
location = geolocator.geocode(place)
|
26
|
+
coords = (location.latitude, location.longitude)
|
27
|
+
|
28
|
+
# shape info
|
29
|
+
geometry = ox.features_from_point(coords, tags=tags, dist=size // 2)
|
30
|
+
gdf = gpd.GeoDataFrame(geometry).set_crs("EPSG:4326")
|
31
|
+
|
32
|
+
# raster info
|
33
|
+
t = rasterio.transform.from_bounds(*gdf.total_bounds, size, size)
|
34
|
+
raster = features.rasterize(gdf.geometry, out_shape=(size, size), transform=t)
|
35
|
+
|
36
|
+
# rotate 180 degrees
|
37
|
+
raster = jnp.rot90(raster, 2)
|
38
|
+
|
39
|
+
return jnp.array(raster).astype(jnp.uint8)
|
40
|
+
|
41
|
+
|
42
|
+
if __name__ == "__main__":
|
43
|
+
import seaborn as sns
|
44
|
+
|
45
|
+
place = "Vesterbro, Copenhagen, Denmark"
|
46
|
+
terrain = terrain_fn(place)
|
47
|
+
rng, key = random.split(random.PRNGKey(0))
|
48
|
+
agents = spawn_fn(terrain, 12, 100, key)
|
49
|
+
print(agents)
|
parabellum/run.py
CHANGED
@@ -39,7 +39,7 @@ class Game:
|
|
39
39
|
obs: Dict
|
40
40
|
state_seq: StateSeq
|
41
41
|
control: Control
|
42
|
-
env: pb.
|
42
|
+
env: pb.Environment
|
43
43
|
rng: random.PRNGKey
|
44
44
|
|
45
45
|
|
@@ -123,5 +123,3 @@ if __name__ == "__main__":
|
|
123
123
|
game = game if game.control.paused else render(game)
|
124
124
|
|
125
125
|
pygame.quit()
|
126
|
-
|
127
|
-
|
parabellum/vis.py
CHANGED
@@ -11,6 +11,7 @@ from functools import partial
|
|
11
11
|
import darkdetect
|
12
12
|
import numpy as np
|
13
13
|
import pygame
|
14
|
+
import matplotlib.pyplot as plt
|
14
15
|
import os
|
15
16
|
from moviepy.editor import ImageSequenceClip
|
16
17
|
from typing import Optional
|
@@ -33,15 +34,24 @@ def small_multiples():
|
|
33
34
|
print(len(clips))
|
34
35
|
|
35
36
|
|
37
|
+
def text_fn(text):
|
38
|
+
"""rotate text upside down because of pygame issue"""
|
39
|
+
return pygame.transform.rotate(text, 180)
|
40
|
+
|
41
|
+
|
36
42
|
class Visualizer(SMAXVisualizer):
|
37
|
-
def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
|
38
|
-
super().__init__(env, state_seq, reward_seq)
|
39
|
-
# remove fig and ax from super
|
43
|
+
def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None, skin=None):
|
40
44
|
self.fig, self.ax = None, None
|
45
|
+
super().__init__(env, state_seq, reward_seq)
|
46
|
+
# clear the figure made by SMAXVisualizer
|
47
|
+
plt.close()
|
41
48
|
self.bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
|
42
49
|
self.fg = (255, 255, 255) if darkdetect.isDark() else (0, 0, 0)
|
43
|
-
self.
|
44
|
-
|
50
|
+
self.pad = 75
|
51
|
+
# TODO: make sure it's always a 1024x1024 image
|
52
|
+
self.width = 1000
|
53
|
+
self.s = self.width + self.pad + self.pad
|
54
|
+
self.scale = self.width / env.map_width
|
45
55
|
self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
|
46
56
|
# self.bullet_seq = vmap(partial(bullet_fn, self.env))(self.state_seq)
|
47
57
|
|
@@ -65,13 +75,13 @@ class Visualizer(SMAXVisualizer):
|
|
65
75
|
def animate_one(self, state_seq, action_seq, save_fname):
|
66
76
|
frames = [] # frames for the video
|
67
77
|
pygame.init() # initialize pygame
|
68
|
-
terrain = np.array(self.env.terrain_raster)
|
78
|
+
terrain = np.array(self.env.terrain_raster.T)
|
69
79
|
rgb_array = np.zeros((terrain.shape[0], terrain.shape[1], 3), dtype=np.uint8)
|
70
80
|
if darkdetect.isLight():
|
71
81
|
rgb_array += 255
|
72
82
|
rgb_array[terrain == 1] = self.fg
|
73
83
|
mask_surface = pygame.surfarray.make_surface(rgb_array)
|
74
|
-
mask_surface = pygame.transform.scale(mask_surface, (self.
|
84
|
+
mask_surface = pygame.transform.scale(mask_surface, (self.width, self.width))
|
75
85
|
|
76
86
|
for idx, (_, state, _) in tqdm(enumerate(state_seq), total=len(self.state_seq)):
|
77
87
|
action = action_seq[idx // self.env.world_steps_per_env_step]
|
@@ -79,11 +89,33 @@ class Visualizer(SMAXVisualizer):
|
|
79
89
|
(self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
|
80
90
|
)
|
81
91
|
screen.fill(self.bg) # fill the screen with the background color
|
82
|
-
screen.blit(
|
92
|
+
screen.blit(
|
93
|
+
mask_surface,
|
94
|
+
(self.pad, self.pad, self.width, self.width),
|
95
|
+
)
|
96
|
+
# add env.scenario.place to the title (in top padding)
|
97
|
+
font = pygame.font.SysFont("Fira Code", 18)
|
98
|
+
width = self.env.map_width
|
99
|
+
title = f"{width}x{width}m in {self.env.scenario.place}"
|
100
|
+
text = text_fn(font.render(title, True, self.fg))
|
101
|
+
# center the text
|
102
|
+
screen.blit(text, (self.s // 2 - text.get_width() // 2, self.pad // 4))
|
103
|
+
# draw edge around terrain
|
104
|
+
pygame.draw.rect(
|
105
|
+
screen,
|
106
|
+
self.fg,
|
107
|
+
(
|
108
|
+
self.pad - 2,
|
109
|
+
self.pad - 2,
|
110
|
+
self.width + 4,
|
111
|
+
self.width + 4,
|
112
|
+
),
|
113
|
+
2,
|
114
|
+
)
|
83
115
|
|
84
116
|
self.render_agents(screen, state) # render the agents
|
85
117
|
self.render_action(screen, action)
|
86
|
-
self.render_obstacles(screen) # render the obstacles
|
118
|
+
# self.render_obstacles(screen) # render the obstacles
|
87
119
|
|
88
120
|
# bullets
|
89
121
|
""" if idx < len(self.bullet_seq) * 8:
|
@@ -91,15 +123,16 @@ class Visualizer(SMAXVisualizer):
|
|
91
123
|
self.render_bullets(screen, bullets, idx % 8) """
|
92
124
|
|
93
125
|
# rotate the screen and append to frames
|
94
|
-
|
126
|
+
pixels = pygame.surfarray.pixels3d(screen).swapaxes(0, 1)
|
127
|
+
# rotate the screen 180 degrees (transpose and flip)
|
128
|
+
pixels = np.rot90(pixels, 2) # pygame starts in bottom left
|
129
|
+
frames.append(pixels)
|
95
130
|
# save the images
|
96
131
|
clip = ImageSequenceClip(frames, fps=48)
|
97
132
|
clip.write_videofile(save_fname, fps=48)
|
98
133
|
clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
|
99
134
|
pygame.quit()
|
100
135
|
|
101
|
-
return clip
|
102
|
-
|
103
136
|
def render_agents(self, screen, state):
|
104
137
|
time_tuple = zip(
|
105
138
|
state.unit_positions,
|
@@ -109,7 +142,7 @@ class Visualizer(SMAXVisualizer):
|
|
109
142
|
)
|
110
143
|
for idx, (pos, team, kind, hp) in enumerate(time_tuple):
|
111
144
|
face_col = self.fg if int(team.item()) == 0 else self.bg
|
112
|
-
pos = tuple((pos * self.scale).tolist())
|
145
|
+
pos = tuple(((pos * self.scale) + self.pad).tolist())
|
113
146
|
# draw the agent
|
114
147
|
if hp > 0:
|
115
148
|
hp_frac = hp / self.env.unit_type_health[kind]
|
@@ -131,26 +164,28 @@ class Visualizer(SMAXVisualizer):
|
|
131
164
|
if self.env.action_type != "discrete":
|
132
165
|
return
|
133
166
|
|
134
|
-
def coord_fn(idx, n, team):
|
167
|
+
def coord_fn(idx, n, team, text):
|
168
|
+
text_adj = text.get_width() / 2
|
169
|
+
is_ally = team == "ally"
|
135
170
|
return (
|
136
|
-
self.s / 20 if team == 0 else self.s - self.s / 20,
|
137
171
|
# vertically centered so that n / 2 is above and below the center
|
172
|
+
self.pad + self.width + self.pad / 2 - text_adj
|
173
|
+
if is_ally
|
174
|
+
else self.pad / 2 - text_adj,
|
138
175
|
self.s / 2 - (n / 2) * self.s / 20 + idx * self.s / 20,
|
139
176
|
)
|
140
177
|
|
141
|
-
for
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
coord = coord_fn(idx, self.env.num_enemies, 1)
|
153
|
-
screen.blit(text, coord)
|
178
|
+
for team, number in [("ally", 0), ("enemy", 1)]:
|
179
|
+
for idx in range(self.env.num_allies):
|
180
|
+
symb = action_to_symbol.get(
|
181
|
+
action[f"{team}_{idx}"].astype(int).item(), "Ø"
|
182
|
+
)
|
183
|
+
font = pygame.font.SysFont(
|
184
|
+
"Fira Code", jnp.sqrt(self.s).astype(int).item()
|
185
|
+
)
|
186
|
+
text = text_fn(font.render(symb, True, self.fg))
|
187
|
+
coord = coord_fn(idx, self.env.num_allies, team, text)
|
188
|
+
screen.blit(text, coord)
|
154
189
|
|
155
190
|
def render_obstacles(self, screen):
|
156
191
|
for c, d in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
|
@@ -0,0 +1,90 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: parabellum
|
3
|
+
Version: 0.2.20
|
4
|
+
Summary: Parabellum environment for parallel warfare simulation
|
5
|
+
Home-page: https://github.com/syrkis/parabellum
|
6
|
+
License: MIT
|
7
|
+
Keywords: warfare,simulation,parallel,environment
|
8
|
+
Author: Noah Syrkis
|
9
|
+
Author-email: desk@syrkis.com
|
10
|
+
Requires-Python: >=3.11,<4.0
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
15
|
+
Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
|
16
|
+
Requires-Dist: geopandas (>=1.0.0,<2.0.0)
|
17
|
+
Requires-Dist: geopy (>=2.4.1,<3.0.0)
|
18
|
+
Requires-Dist: jax (==0.4.17)
|
19
|
+
Requires-Dist: jaxmarl (==0.0.3)
|
20
|
+
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
21
|
+
Requires-Dist: moviepy (>=1.0.3,<2.0.0)
|
22
|
+
Requires-Dist: numpy (<2)
|
23
|
+
Requires-Dist: osmnx (>=1.9.3,<2.0.0)
|
24
|
+
Requires-Dist: poetry (>=1.8.3,<2.0.0)
|
25
|
+
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
26
|
+
Requires-Dist: rasterio (>=1.3.10,<2.0.0)
|
27
|
+
Requires-Dist: seaborn (>=0.13.2,<0.14.0)
|
28
|
+
Requires-Dist: tqdm (>=4.66.4,<5.0.0)
|
29
|
+
Project-URL: Repository, https://github.com/syrkis/parabellum
|
30
|
+
Description-Content-Type: text/markdown
|
31
|
+
|
32
|
+
# Parabellum
|
33
|
+
|
34
|
+
Ultra-scalable JaxMARL based warfare simulation engine developed with Armasuisse funding.
|
35
|
+
|
36
|
+
[Documentation Status](https://parabellum.readthedocs.io/en/latest/?badge=latest)
|
37
|
+
|
38
|
+
## Features
|
39
|
+
|
40
|
+
- Obstacles and terrain integration
|
41
|
+
- Rasterized maps
|
42
|
+
- Blast radii simulation
|
43
|
+
- Friendly fire mechanics
|
44
|
+
- Pygame visualization
|
45
|
+
- JAX-based parallelization
|
46
|
+
|
47
|
+
## Install
|
48
|
+
|
49
|
+
```bash
|
50
|
+
pip install parabellum
|
51
|
+
```
|
52
|
+
|
53
|
+
## Quick Start
|
54
|
+
|
55
|
+
```python
|
56
|
+
import parabellum as pb
|
57
|
+
from jax import random
|
58
|
+
|
59
|
+
terrain = pb.terrain_fn("Thun, Switzerland", 1000)
|
60
|
+
scenario = pb.make_scenario("Thun", terrain, 10, 10)
|
61
|
+
env = pb.Environment(scenario)
|
62
|
+
|
63
|
+
rng, key = random.split(random.PRNGKey(0))
|
64
|
+
obs, state = env.reset(key)
|
65
|
+
|
66
|
+
# Simulation loop
|
67
|
+
for _ in range(100):
|
68
|
+
rng, rng_act, key_step = random.split(rng, 3)
|
69
|
+
key_act = random.split(rng_act, len(env.agents))
|
70
|
+
act = {a: env.action_space(a).sample(k) for a, k in zip(env.agents, key_act)}
|
71
|
+
obs, state, reward, done, info = env.step(key_step, act, state)
|
72
|
+
|
73
|
+
# Visualize
|
74
|
+
vis = pb.Visualizer(env, state_sequence)
|
75
|
+
vis.animate()
|
76
|
+
```
|
77
|
+
|
78
|
+
## Documentation
|
79
|
+
|
80
|
+
Full documentation: [parabellum.readthedocs.io](https://parabellum.readthedocs.io)
|
81
|
+
|
82
|
+
## Team
|
83
|
+
|
84
|
+
- Noah Syrkis
|
85
|
+
- Timothée Anne
|
86
|
+
- Supervisor: Sebastian Risi
|
87
|
+
|
88
|
+
## License
|
89
|
+
|
90
|
+
MIT
|
@@ -0,0 +1,8 @@
|
|
1
|
+
parabellum/__init__.py,sha256=cI1kxVQ274VZrBLxUYvB_j6UqhFY44sId4NgVjbf678,245
|
2
|
+
parabellum/env.py,sha256=B2WaEzCQnakVB1AuYqFsb7aiPIbaUsuq0CjdKq4pQm8,16204
|
3
|
+
parabellum/map.py,sha256=CdPebRuGafj9fAydAyTa8bqpr-PldZw7heKRReGTDBg,1257
|
4
|
+
parabellum/run.py,sha256=HyPdz5iVD8q0iYaZL2Nf02fGHByHpecUGoPlQrq9v8s,3411
|
5
|
+
parabellum/vis.py,sha256=uIhN9VSJlT4XoNuNCiU_Bw2VSNoJFOYlScCTNrMkGsg,11977
|
6
|
+
parabellum-0.2.20.dist-info/METADATA,sha256=IO3Wsx_nTM3v63FtbnJqaEfPfXOyhNGtKOm_ppchUqo,2453
|
7
|
+
parabellum-0.2.20.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
8
|
+
parabellum-0.2.20.dist-info/RECORD,,
|
@@ -1,104 +0,0 @@
|
|
1
|
-
Metadata-Version: 2.1
|
2
|
-
Name: parabellum
|
3
|
-
Version: 0.2.19
|
4
|
-
Summary: Parabellum environment for parallel warfare simulation
|
5
|
-
Home-page: https://github.com/syrkis/parabellum
|
6
|
-
License: MIT
|
7
|
-
Keywords: warfare,simulation,parallel,environment
|
8
|
-
Author: Noah Syrkis
|
9
|
-
Author-email: desk@syrkis.com
|
10
|
-
Requires-Python: >=3.11,<4.0
|
11
|
-
Classifier: License :: OSI Approved :: MIT License
|
12
|
-
Classifier: Programming Language :: Python :: 3
|
13
|
-
Classifier: Programming Language :: Python :: 3.11
|
14
|
-
Classifier: Programming Language :: Python :: 3.12
|
15
|
-
Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
|
16
|
-
Requires-Dist: jax (==0.4.17)
|
17
|
-
Requires-Dist: jaxmarl (==0.0.3)
|
18
|
-
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
19
|
-
Requires-Dist: moviepy (>=1.0.3,<2.0.0)
|
20
|
-
Requires-Dist: poetry (>=1.8.3,<2.0.0)
|
21
|
-
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
22
|
-
Requires-Dist: tqdm (>=4.66.4,<5.0.0)
|
23
|
-
Project-URL: Repository, https://github.com/syrkis/parabellum
|
24
|
-
Description-Content-Type: text/markdown
|
25
|
-
|
26
|
-
# parabellum
|
27
|
-
|
28
|
-
Parabellum is an ultra-scalable, high-performance warfare simulation engine.
|
29
|
-
It is based on JaxMARL's SMAX environment, but has been heavily modified to
|
30
|
-
support a wide range of new features and improvements.
|
31
|
-
|
32
|
-
## Installation
|
33
|
-
|
34
|
-
Parabellum is written in Python 3.11 and can be installed using pip:
|
35
|
-
|
36
|
-
```bash
|
37
|
-
pip install parabellum
|
38
|
-
```
|
39
|
-
|
40
|
-
## Usage
|
41
|
-
|
42
|
-
Parabellum is designed to be used in conjunction with JAX, a high-performance
|
43
|
-
numerical computing library. Here is a simple example of how to use Parabellum
|
44
|
-
to simulate a game with 10 agents and 10 enemies, each taking random actions:
|
45
|
-
|
46
|
-
```python
|
47
|
-
import parabellum as pb
|
48
|
-
from jax import random
|
49
|
-
|
50
|
-
# define the scenario
|
51
|
-
kwargs = dict(obstacle_coords=[(7, 7)], obstacle_deltas=[(10, 0)])
|
52
|
-
scenario = pb.Scenario(**kwargs) # <- Scenario is an important part of parabellum
|
53
|
-
|
54
|
-
# create the environment
|
55
|
-
kwargs = dict(map_width=256, map_height=256, num_agents=10, num_enemies=10)
|
56
|
-
env = pb.Parabellum(**kwargs) # <- Parabellum is the central class of parabellum
|
57
|
-
|
58
|
-
# initiate stochasticity
|
59
|
-
rng = random.PRNGKey(0)
|
60
|
-
rng, key = random.split(rng)
|
61
|
-
|
62
|
-
# initialize the environment state
|
63
|
-
obs, state = env.reset(key)
|
64
|
-
state_sequence = []
|
65
|
-
|
66
|
-
for _ in range(1000):
|
67
|
-
|
68
|
-
# manage stochasticity
|
69
|
-
rng, rng_act, key_step = random.split(key)
|
70
|
-
key_act = random.split(rng_act, len(env.agents))
|
71
|
-
|
72
|
-
# sample actions and append to state sequence
|
73
|
-
act = {a: env.action_space(a).sample(k)
|
74
|
-
for a, k in zip(env.agents, key_act)}
|
75
|
-
|
76
|
-
# step the environment
|
77
|
-
state_sequence.append((key_act, state, act))
|
78
|
-
obs, state, reward, done, info = env.step(key_step, act, state)
|
79
|
-
|
80
|
-
|
81
|
-
# save visualization of the state sequence
|
82
|
-
vis = pb.Visualizer(env, state_sequence) # <- Visualizer is a nice to have class
|
83
|
-
vis.animate()
|
84
|
-
```
|
85
|
-
|
86
|
-
|
87
|
-
## Features
|
88
|
-
|
89
|
-
- Obstacles — can be inserted in
|
90
|
-
|
91
|
-
## TODO
|
92
|
-
|
93
|
-
- [x] Parallel pygame vis
|
94
|
-
- [ ] Parallel bullet renderings
|
95
|
-
- [ ] Combine parallell plots into one (maybe out of parabellum scope)
|
96
|
-
- [ ] Color for health?
|
97
|
-
- [ ] Add the ability to see ongoing game.
|
98
|
-
- [ ] Bug test friendly fire.
|
99
|
-
- [x] Start sim from arbitrary state.
|
100
|
-
- [ ] Save when the episode ends in some state/obs variable
|
101
|
-
- [ ] Look for the source of the bug when using more Allies than Enemies
|
102
|
-
- [ ] Y inversed axis for parabellum visualization
|
103
|
-
- [ ] Units see through obstacles?
|
104
|
-
|
@@ -1,8 +0,0 @@
|
|
1
|
-
parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
|
2
|
-
parabellum/env.py,sha256=Z8zpdCaEi5HFwN0Vd2hukOarkPSg0EKZErTRts3JQ5E,16023
|
3
|
-
parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
|
4
|
-
parabellum/run.py,sha256=0IWpqcQ_qfFeElbupF5vOs_CByFfpXYuGGUHYuurFM4,3412
|
5
|
-
parabellum/vis.py,sha256=JFVTnBg-LV4jZNw6cysU6NS8ZxeMpg5wz3JOi-lrnzY,10699
|
6
|
-
parabellum-0.2.19.dist-info/METADATA,sha256=DsEBAlESj8BwGSphmPPylStoXH_g_x_Iy3WJ3KEwjc0,3223
|
7
|
-
parabellum-0.2.19.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
8
|
-
parabellum-0.2.19.dist-info/RECORD,,
|
File without changes
|