parabellum 0.2.18__py3-none-any.whl → 0.2.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum/__init__.py +10 -2
- parabellum/env.py +60 -55
- parabellum/map.py +40 -7
- parabellum/run.py +1 -3
- parabellum/vis.py +75 -33
- parabellum-0.2.20.dist-info/METADATA +90 -0
- parabellum-0.2.20.dist-info/RECORD +8 -0
- parabellum-0.2.18.dist-info/METADATA +0 -104
- parabellum-0.2.18.dist-info/RECORD +0 -8
- {parabellum-0.2.18.dist-info → parabellum-0.2.20.dist-info}/WHEEL +0 -0
parabellum/__init__.py
CHANGED
@@ -1,4 +1,12 @@
|
|
1
|
-
from .env import
|
1
|
+
from .env import Environment, Scenario, scenarios, make_scenario
|
2
2
|
from .vis import Visualizer
|
3
|
+
from .map import terrain_fn
|
3
4
|
|
4
|
-
__all__ = [
|
5
|
+
__all__ = [
|
6
|
+
"Environment",
|
7
|
+
"Scenario",
|
8
|
+
"scenarios",
|
9
|
+
"make_scenario",
|
10
|
+
"Visualizer",
|
11
|
+
"terrain_fn",
|
12
|
+
]
|
parabellum/env.py
CHANGED
@@ -17,11 +17,8 @@ from functools import partial
|
|
17
17
|
class Scenario:
|
18
18
|
"""Parabellum scenario"""
|
19
19
|
|
20
|
+
place: str
|
20
21
|
terrain_raster: chex.Array
|
21
|
-
|
22
|
-
obstacle_coords: chex.Array # TODO: use map instead of obstacles
|
23
|
-
obstacle_deltas: chex.Array
|
24
|
-
|
25
22
|
unit_types: chex.Array
|
26
23
|
num_allies: int
|
27
24
|
num_enemies: int
|
@@ -33,9 +30,8 @@ class Scenario:
|
|
33
30
|
# default scenario
|
34
31
|
scenarios = {
|
35
32
|
"default": Scenario(
|
36
|
-
|
37
|
-
jnp.
|
38
|
-
jnp.array([[0, 80], [0, 20]]),
|
33
|
+
"Identity Town",
|
34
|
+
jnp.eye(64, dtype=jnp.uint8),
|
39
35
|
jnp.zeros((19,), dtype=jnp.uint8),
|
40
36
|
9,
|
41
37
|
10,
|
@@ -43,57 +39,63 @@ scenarios = {
|
|
43
39
|
}
|
44
40
|
|
45
41
|
|
46
|
-
|
42
|
+
def make_scenario(place, terrain_raster, num_allies=9, num_enemies=10):
    """Create a scenario in which every unit has type index 0.

    Args:
        place: Human-readable name of the location; stored on the scenario.
        terrain_raster: 2-D terrain array; stored on the scenario unchanged.
        num_allies: Number of allied units.
        num_enemies: Number of enemy units.

    Returns:
        A ``Scenario`` with one ``uint8`` unit type per agent
        (``num_allies + num_enemies`` entries, all zero).
    """
    num_agents = num_allies + num_enemies
    # Allocate as uint8 directly (same as the "default" scenario) instead of
    # creating a float array and casting it afterwards.
    unit_types = jnp.zeros((num_agents,), dtype=jnp.uint8)
    return Scenario(place, terrain_raster, unit_types, num_allies, num_enemies)
|
47
|
+
|
48
|
+
|
49
|
+
def spawn_fn(pool: jnp.ndarray, offset: jnp.ndarray, n: int, rng: jnp.ndarray):
    """Spawn ``n`` agents on distinct cells drawn from ``pool``.

    Args:
        pool: Pair of equally long index arrays, read as ``pool[0]`` and
            ``pool[1]`` (the form ``jnp.nonzero`` returns for a 2-D mask).
        offset: Value added to every sampled coordinate.
        n: Number of agents to spawn; cells are sampled without replacement.
        rng: JAX PRNG key.

    Returns:
        Array of shape ``(n, 2)``: sampled cell + uniform noise in
        ``[0, 0.5)`` per axis + ``offset``.
    """
    # Split three ways and discard the first key so the random stream stays
    # identical to the previous implementation, which rebound it to an
    # unused ``rng``.
    _, key_start, key_noise = random.split(rng, 3)
    noise = random.uniform(key_noise, (n, 2)) * 0.5

    # pick n distinct candidate cells
    idxs = random.choice(key_start, pool[0].shape[0], (n,), replace=False)
    coords = jnp.stack([pool[0][idxs], pool[1][idxs]], axis=-1)

    return coords + noise + offset
|
59
|
+
|
60
|
+
|
61
|
+
def sector_fn(terrain: jnp.ndarray, sector_id: int, grid: int = 5):
    """Return the free cells of one sector of ``terrain``.

    The terrain is divided into ``grid`` x ``grid`` sectors. ``sector_id``
    selects one of them: ``sector_id // grid`` indexes the first axis and
    ``sector_id % grid`` the second.

    Args:
        terrain: 2-D array; cells equal to 0 count as free.
        sector_id: Sector index in ``[0, grid * grid)``.
        grid: Number of sectors along each axis (defaults to 5).

    Returns:
        ``(cells, offset)`` where ``cells`` is the ``jnp.nonzero`` tuple of
        free cells relative to the sector's corner and ``offset`` is that
        corner as a length-2 array.

    Raises:
        ValueError: If ``sector_id`` lies outside the grid.
    """
    if not 0 <= sector_id < grid * grid:
        # An out-of-range id would otherwise produce an empty or wrapped
        # slice and only fail later, when sampling from the empty pool.
        raise ValueError(
            f"sector_id must be in [0, {grid * grid}), got {sector_id}"
        )
    width, height = terrain.shape
    coordx = sector_id // grid * width // grid
    coordy = sector_id % grid * height // grid
    sector = (
        terrain[coordx : coordx + width // grid, coordy : coordy + height // grid]
        == 0
    )
    offset = jnp.array([coordx, coordy])
    return jnp.nonzero(sector), offset
|
69
|
+
|
70
|
+
|
71
|
+
class Environment(SMAX):
|
47
72
|
def __init__(self, scenario: Scenario, **kwargs):
    """Set up the environment from a scenario.

    The map size passed to the parent constructor is taken from the shape
    of ``scenario.terrain_raster`` (first axis is the height, second the
    width). Any extra keyword arguments are forwarded to the parent.
    """
    height, width = scenario.terrain_raster.shape
    super().__init__(
        scenario=scenario,
        map_height=height,
        map_width=width,
        walls_cause_death=False,
        **kwargs,
    )
    self.scenario = scenario
    self.terrain_raster = scenario.terrain_raster
    # self.unit_type_names = ["tinker", "tailor", "soldier", "spy"]
    # self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
    # self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
    self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32)  # TODO: add
    self.max_steps = 200
    self._push_units_away = lambda x: x  # overwrite push units
    # Candidate spawn cells: sector 0 and sector 24 of the terrain grid.
    self.top_sector = sector_fn(self.terrain_raster, 0)
    self.low_sector = sector_fn(self.terrain_raster, 24)
|
57
86
|
|
58
87
|
@partial(jax.jit, static_argnums=(0,))
|
59
|
-
def reset(self,
|
88
|
+
def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
|
60
89
|
"""Environment-specific reset."""
|
61
|
-
|
62
|
-
team_0_start =
|
63
|
-
|
64
|
-
)
|
65
|
-
team_0_start_noise = jax.random.uniform(
|
66
|
-
team_0_key, shape=(self.num_allies, 2), minval=-2, maxval=2
|
67
|
-
)
|
68
|
-
team_0_start = team_0_start + team_0_start_noise
|
69
|
-
team_1_start = jnp.stack(
|
70
|
-
[jnp.array([self.map_width / 4 * 3, self.map_height / 2])]
|
71
|
-
* self.num_enemies
|
72
|
-
)
|
73
|
-
team_1_start_noise = jax.random.uniform(
|
74
|
-
team_1_key, shape=(self.num_enemies, 2), minval=-2, maxval=2
|
75
|
-
)
|
76
|
-
team_1_start = team_1_start + team_1_start_noise
|
90
|
+
ally_key, enemy_key = jax.random.split(rng)
|
91
|
+
team_0_start = spawn_fn(*self.top_sector, self.num_allies, ally_key)
|
92
|
+
team_1_start = spawn_fn(*self.low_sector, self.num_enemies, enemy_key)
|
77
93
|
unit_positions = jnp.concatenate([team_0_start, team_1_start])
|
78
|
-
key, pos_key = jax.random.split(key)
|
79
|
-
generated_unit_positions = self.position_generator.generate(pos_key)
|
80
|
-
unit_positions = jax.lax.select(
|
81
|
-
self.smacv2_position_generation, generated_unit_positions, unit_positions
|
82
|
-
)
|
83
94
|
unit_teams = jnp.zeros((self.num_agents,))
|
84
95
|
unit_teams = unit_teams.at[self.num_allies :].set(1)
|
85
96
|
unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
|
86
97
|
# default behaviour spawn all marines
|
87
|
-
unit_types =
|
88
|
-
jnp.zeros((self.num_agents,), dtype=jnp.uint8)
|
89
|
-
if self.scenario is None
|
90
|
-
else self.scenario
|
91
|
-
)
|
92
|
-
key, unit_type_key = jax.random.split(key)
|
93
|
-
generated_unit_types = self.unit_type_generator.generate(unit_type_key)
|
94
|
-
unit_types = jax.lax.select(
|
95
|
-
self.smacv2_unit_type_generation, generated_unit_types, unit_types
|
96
|
-
)
|
98
|
+
unit_types = self.scenario.unit_types
|
97
99
|
unit_health = self.unit_type_health[unit_types]
|
98
100
|
state = State(
|
99
101
|
unit_positions=unit_positions,
|
@@ -180,12 +182,12 @@ class Parabellum(SMAX):
|
|
180
182
|
|
181
183
|
#######################################################################
|
182
184
|
############################################ avoid going into obstacles
|
183
|
-
obs = self.obstacle_coords
|
185
|
+
""" obs = self.obstacle_coords
|
184
186
|
obs_end = obs + self.obstacle_deltas
|
185
|
-
inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end))
|
186
|
-
|
187
|
-
flag = jnp.logical_or(inters, rastersects)
|
188
|
-
new_pos = jnp.where(
|
187
|
+
inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end)) """
|
188
|
+
clash = raster_crossing(pos, new_pos)
|
189
|
+
# flag = jnp.logical_or(inters, rastersects)
|
190
|
+
new_pos = jnp.where(clash, pos, new_pos)
|
189
191
|
|
190
192
|
#######################################################################
|
191
193
|
#######################################################################
|
@@ -310,14 +312,14 @@ class Parabellum(SMAX):
|
|
310
312
|
[self.map_width, 0],
|
311
313
|
]
|
312
314
|
)
|
313
|
-
obstacle_coords = jnp.concatenate(
|
315
|
+
""" obstacle_coords = jnp.concatenate(
|
314
316
|
[self.obstacle_coords, bondaries_coords]
|
315
317
|
) # add the map boundaries to the obstacles to avoid
|
316
318
|
obstacle_deltas = jnp.concatenate(
|
317
319
|
[self.obstacle_deltas, bondaries_deltas]
|
318
320
|
) # add the map boundaries to the obstacles to avoid
|
319
321
|
obst_start = obstacle_coords
|
320
|
-
obst_end = obst_start + obstacle_deltas
|
322
|
+
obst_end = obst_start + obstacle_deltas """
|
321
323
|
|
322
324
|
def check_obstacles(pos, new_pos, obst_start, obst_end):
|
323
325
|
intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
|
@@ -325,9 +327,9 @@ class Parabellum(SMAX):
|
|
325
327
|
flag = jnp.logical_or(intersects, rastersect)
|
326
328
|
return jnp.where(flag, pos, new_pos)
|
327
329
|
|
328
|
-
pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
|
330
|
+
""" pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
|
329
331
|
pos, new_pos, obst_start, obst_end
|
330
|
-
)
|
332
|
+
) """
|
331
333
|
|
332
334
|
# Multiple enemies can attack the same unit.
|
333
335
|
# We have `(health_diff, attacked_idx)` pairs.
|
@@ -373,13 +375,16 @@ class Parabellum(SMAX):
|
|
373
375
|
|
374
376
|
if __name__ == "__main__":
|
375
377
|
n_envs = 4
|
376
|
-
|
377
|
-
env =
|
378
|
+
|
379
|
+
env = Environment(scenarios["default"])
|
378
380
|
rng, reset_rng = random.split(random.PRNGKey(0))
|
379
381
|
reset_key = random.split(reset_rng, n_envs)
|
380
382
|
obs, state = vmap(env.reset)(reset_key)
|
381
383
|
state_seq = []
|
382
384
|
|
385
|
+
print(state.unit_positions)
|
386
|
+
exit()
|
387
|
+
|
383
388
|
for i in range(10):
|
384
389
|
rng, act_rng, step_rng = random.split(rng, 3)
|
385
390
|
act_key = random.split(act_rng, (len(env.agents), n_envs))
|
parabellum/map.py
CHANGED
@@ -4,13 +4,46 @@
|
|
4
4
|
|
5
5
|
# imports
|
6
6
|
import jax.numpy as jnp
|
7
|
-
import
|
7
|
+
from geopy.geocoders import Nominatim
|
8
|
+
import geopandas as gpd
|
9
|
+
import osmnx as ox
|
10
|
+
import rasterio
|
11
|
+
from jax import random
|
12
|
+
from rasterio import features
|
13
|
+
import rasterio.transform
|
14
|
+
|
15
|
+
# constants
|
16
|
+
geolocator = Nominatim(user_agent="parabellum")
|
17
|
+
tags = {"building": True}
|
8
18
|
|
9
19
|
|
10
20
|
# functions
|
11
|
-
def
|
12
|
-
"""
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
21
|
+
def terrain_fn(place: str, size: int = 1000):
    """Return a rasterized building map of a given location.

    Args:
        place: Free-text place name, resolved with the module-level
            ``geolocator``.
        size: Side length of the returned square raster. It is also used,
            halved, as the ``dist`` argument when fetching features around
            the geocoded point.

    Returns:
        ``uint8`` array of shape ``(size, size)``, rotated by 180 degrees.
        NOTE(review): presumably cells covered by a building are 1 and all
        others 0 (rasterio's default burn/fill values) - confirm.

    Raises:
        ValueError: If ``place`` cannot be geocoded.
    """
    # location info
    location = geolocator.geocode(place)
    if location is None:
        # geocode() yields None for an unknown place; fail with a clear
        # message instead of an AttributeError on ``location.latitude``.
        raise ValueError(f"could not geocode place: {place!r}")
    coords = (location.latitude, location.longitude)

    # shape info
    geometry = ox.features_from_point(coords, tags=tags, dist=size // 2)
    gdf = gpd.GeoDataFrame(geometry).set_crs("EPSG:4326")

    # raster info: the transform spans the bounds of the fetched features
    t = rasterio.transform.from_bounds(*gdf.total_bounds, size, size)
    raster = features.rasterize(gdf.geometry, out_shape=(size, size), transform=t)

    # rotate 180 degrees
    return jnp.rot90(raster, 2).astype(jnp.uint8)
|
40
|
+
|
41
|
+
|
42
|
+
if __name__ == "__main__":
    # Manual smoke test: rasterize a real neighbourhood and summarise it.
    # The previous version called ``spawn_fn``, which is not defined or
    # imported in the visible part of this module and takes different
    # arguments in env.py.
    place = "Vesterbro, Copenhagen, Denmark"
    terrain = terrain_fn(place)
    print(terrain.shape, terrain.mean())
|
parabellum/run.py
CHANGED
@@ -39,7 +39,7 @@ class Game:
|
|
39
39
|
obs: Dict
|
40
40
|
state_seq: StateSeq
|
41
41
|
control: Control
|
42
|
-
env: pb.
|
42
|
+
env: pb.Environment
|
43
43
|
rng: random.PRNGKey
|
44
44
|
|
45
45
|
|
@@ -123,5 +123,3 @@ if __name__ == "__main__":
|
|
123
123
|
game = game if game.control.paused else render(game)
|
124
124
|
|
125
125
|
pygame.quit()
|
126
|
-
|
127
|
-
|
parabellum/vis.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Visualizer for the Parabellum environment
|
3
|
+
"""
|
2
4
|
|
3
5
|
from tqdm import tqdm
|
4
6
|
import jax.numpy as jnp
|
@@ -9,6 +11,7 @@ from functools import partial
|
|
9
11
|
import darkdetect
|
10
12
|
import numpy as np
|
11
13
|
import pygame
|
14
|
+
import matplotlib.pyplot as plt
|
12
15
|
import os
|
13
16
|
from moviepy.editor import ImageSequenceClip
|
14
17
|
from typing import Optional
|
@@ -31,15 +34,24 @@ def small_multiples():
|
|
31
34
|
print(len(clips))
|
32
35
|
|
33
36
|
|
37
|
+
def text_fn(text):
    """Return ``text`` rotated by 180 degrees (works around a pygame issue)."""
    flipped = pygame.transform.rotate(text, 180)
    return flipped
|
40
|
+
|
41
|
+
|
34
42
|
class Visualizer(SMAXVisualizer):
|
35
|
-
def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
|
36
|
-
super().__init__(env, state_seq, reward_seq)
|
37
|
-
# remove fig and ax from super
|
43
|
+
def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None, skin=None):
|
38
44
|
self.fig, self.ax = None, None
|
45
|
+
super().__init__(env, state_seq, reward_seq)
|
46
|
+
# clear the figure made by SMAXVisualizer
|
47
|
+
plt.close()
|
39
48
|
self.bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
|
40
|
-
self.fg = (
|
41
|
-
self.
|
42
|
-
|
49
|
+
self.fg = (255, 255, 255) if darkdetect.isDark() else (0, 0, 0)
|
50
|
+
self.pad = 75
|
51
|
+
# TODO: make sure it's always a 1024x1024 image
|
52
|
+
self.width = 1000
|
53
|
+
self.s = self.width + self.pad + self.pad
|
54
|
+
self.scale = self.width / env.map_width
|
43
55
|
self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
|
44
56
|
# self.bullet_seq = vmap(partial(bullet_fn, self.env))(self.state_seq)
|
45
57
|
|
@@ -63,11 +75,13 @@ class Visualizer(SMAXVisualizer):
|
|
63
75
|
def animate_one(self, state_seq, action_seq, save_fname):
|
64
76
|
frames = [] # frames for the video
|
65
77
|
pygame.init() # initialize pygame
|
66
|
-
terrain = np.array(self.env.terrain_raster)
|
78
|
+
terrain = np.array(self.env.terrain_raster.T)
|
67
79
|
rgb_array = np.zeros((terrain.shape[0], terrain.shape[1], 3), dtype=np.uint8)
|
80
|
+
if darkdetect.isLight():
|
81
|
+
rgb_array += 255
|
68
82
|
rgb_array[terrain == 1] = self.fg
|
69
83
|
mask_surface = pygame.surfarray.make_surface(rgb_array)
|
70
|
-
mask_surface = pygame.transform.scale(mask_surface, (self.
|
84
|
+
mask_surface = pygame.transform.scale(mask_surface, (self.width, self.width))
|
71
85
|
|
72
86
|
for idx, (_, state, _) in tqdm(enumerate(state_seq), total=len(self.state_seq)):
|
73
87
|
action = action_seq[idx // self.env.world_steps_per_env_step]
|
@@ -75,11 +89,33 @@ class Visualizer(SMAXVisualizer):
|
|
75
89
|
(self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
|
76
90
|
)
|
77
91
|
screen.fill(self.bg) # fill the screen with the background color
|
78
|
-
screen.blit(
|
92
|
+
screen.blit(
|
93
|
+
mask_surface,
|
94
|
+
(self.pad, self.pad, self.width, self.width),
|
95
|
+
)
|
96
|
+
# add env.scenario.place to the title (in top padding)
|
97
|
+
font = pygame.font.SysFont("Fira Code", 18)
|
98
|
+
width = self.env.map_width
|
99
|
+
title = f"{width}x{width}m in {self.env.scenario.place}"
|
100
|
+
text = text_fn(font.render(title, True, self.fg))
|
101
|
+
# center the text
|
102
|
+
screen.blit(text, (self.s // 2 - text.get_width() // 2, self.pad // 4))
|
103
|
+
# draw edge around terrain
|
104
|
+
pygame.draw.rect(
|
105
|
+
screen,
|
106
|
+
self.fg,
|
107
|
+
(
|
108
|
+
self.pad - 2,
|
109
|
+
self.pad - 2,
|
110
|
+
self.width + 4,
|
111
|
+
self.width + 4,
|
112
|
+
),
|
113
|
+
2,
|
114
|
+
)
|
79
115
|
|
80
116
|
self.render_agents(screen, state) # render the agents
|
81
117
|
self.render_action(screen, action)
|
82
|
-
self.render_obstacles(screen) # render the obstacles
|
118
|
+
# self.render_obstacles(screen) # render the obstacles
|
83
119
|
|
84
120
|
# bullets
|
85
121
|
""" if idx < len(self.bullet_seq) * 8:
|
@@ -87,15 +123,16 @@ class Visualizer(SMAXVisualizer):
|
|
87
123
|
self.render_bullets(screen, bullets, idx % 8) """
|
88
124
|
|
89
125
|
# rotate the screen and append to frames
|
90
|
-
|
126
|
+
pixels = pygame.surfarray.pixels3d(screen).swapaxes(0, 1)
|
127
|
+
# rotate the screen 180 degrees (transpose and flip)
|
128
|
+
pixels = np.rot90(pixels, 2) # pygame starts in bottom left
|
129
|
+
frames.append(pixels)
|
91
130
|
# save the images
|
92
131
|
clip = ImageSequenceClip(frames, fps=48)
|
93
132
|
clip.write_videofile(save_fname, fps=48)
|
94
|
-
|
133
|
+
clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
|
95
134
|
pygame.quit()
|
96
135
|
|
97
|
-
return clip
|
98
|
-
|
99
136
|
def render_agents(self, screen, state):
|
100
137
|
time_tuple = zip(
|
101
138
|
state.unit_positions,
|
@@ -105,7 +142,7 @@ class Visualizer(SMAXVisualizer):
|
|
105
142
|
)
|
106
143
|
for idx, (pos, team, kind, hp) in enumerate(time_tuple):
|
107
144
|
face_col = self.fg if int(team.item()) == 0 else self.bg
|
108
|
-
pos = tuple((pos * self.scale).tolist())
|
145
|
+
pos = tuple(((pos * self.scale) + self.pad).tolist())
|
109
146
|
# draw the agent
|
110
147
|
if hp > 0:
|
111
148
|
hp_frac = hp / self.env.unit_type_health[kind]
|
@@ -124,26 +161,31 @@ class Visualizer(SMAXVisualizer):
|
|
124
161
|
# work out which agents are being shot
|
125
162
|
|
126
163
|
def render_action(self, screen, action):
    """Draw each agent's current action symbol in the side padding.

    Nothing is drawn unless the environment uses discrete actions.

    Args:
        screen: pygame surface to draw on.
        action: Mapping from agent name (``"ally_<i>"`` / ``"enemy_<i>"``)
            to that agent's action.
    """
    if self.env.action_type != "discrete":
        return

    def coord_fn(idx, n, team, text):
        """Top-left corner for symbol ``idx`` of ``n`` in a team's column."""
        text_adj = text.get_width() / 2
        is_ally = team == "ally"
        return (
            # allies go in the padding past the terrain, enemies before it
            self.pad + self.width + self.pad / 2 - text_adj
            if is_ally
            else self.pad / 2 - text_adj,
            # vertically centered so that n / 2 is above and below the center
            self.s / 2 - (n / 2) * self.s / 20 + idx * self.s / 20,
        )

    # The font does not depend on the agent, so build it once.
    font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())

    # Each team is iterated over its own size. Using num_allies for both
    # teams skipped enemies when there were more enemies than allies and
    # raised KeyError when there were fewer.
    team_sizes = (("ally", self.env.num_allies), ("enemy", self.env.num_enemies))
    for team, size in team_sizes:
        for idx in range(size):
            symb = action_to_symbol.get(
                action[f"{team}_{idx}"].astype(int).item(), "Ø"
            )
            text = text_fn(font.render(symb, True, self.fg))
            screen.blit(text, coord_fn(idx, size, team, text))
|
147
189
|
|
148
190
|
def render_obstacles(self, screen):
|
149
191
|
for c, d in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
|
@@ -238,7 +280,7 @@ if __name__ == "__main__":
|
|
238
280
|
# exit()
|
239
281
|
|
240
282
|
n_envs = 2
|
241
|
-
env = Parabellum(scenarios["default"])
|
283
|
+
env = Parabellum(scenarios["default"], action_type="discrete")
|
242
284
|
rng, reset_rng = random.split(random.PRNGKey(0))
|
243
285
|
reset_key = random.split(reset_rng, n_envs)
|
244
286
|
obs, state = vmap(env.reset)(reset_key)
|
@@ -248,7 +290,7 @@ if __name__ == "__main__":
|
|
248
290
|
rng, act_rng, step_rng = random.split(rng, 3)
|
249
291
|
act_key = random.split(act_rng, (len(env.agents), n_envs))
|
250
292
|
act = {
|
251
|
-
a:
|
293
|
+
a: vmap(env.action_space(a).sample)(act_key[i])
|
252
294
|
for i, a in enumerate(env.agents)
|
253
295
|
}
|
254
296
|
step_key = random.split(step_rng, n_envs)
|
@@ -0,0 +1,90 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: parabellum
|
3
|
+
Version: 0.2.20
|
4
|
+
Summary: Parabellum environment for parallel warfare simulation
|
5
|
+
Home-page: https://github.com/syrkis/parabellum
|
6
|
+
License: MIT
|
7
|
+
Keywords: warfare,simulation,parallel,environment
|
8
|
+
Author: Noah Syrkis
|
9
|
+
Author-email: desk@syrkis.com
|
10
|
+
Requires-Python: >=3.11,<4.0
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
15
|
+
Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
|
16
|
+
Requires-Dist: geopandas (>=1.0.0,<2.0.0)
|
17
|
+
Requires-Dist: geopy (>=2.4.1,<3.0.0)
|
18
|
+
Requires-Dist: jax (==0.4.17)
|
19
|
+
Requires-Dist: jaxmarl (==0.0.3)
|
20
|
+
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
21
|
+
Requires-Dist: moviepy (>=1.0.3,<2.0.0)
|
22
|
+
Requires-Dist: numpy (<2)
|
23
|
+
Requires-Dist: osmnx (>=1.9.3,<2.0.0)
|
24
|
+
Requires-Dist: poetry (>=1.8.3,<2.0.0)
|
25
|
+
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
26
|
+
Requires-Dist: rasterio (>=1.3.10,<2.0.0)
|
27
|
+
Requires-Dist: seaborn (>=0.13.2,<0.14.0)
|
28
|
+
Requires-Dist: tqdm (>=4.66.4,<5.0.0)
|
29
|
+
Project-URL: Repository, https://github.com/syrkis/parabellum
|
30
|
+
Description-Content-Type: text/markdown
|
31
|
+
|
32
|
+
# Parabellum
|
33
|
+
|
34
|
+
Ultra-scalable, JaxMARL-based warfare simulation engine developed with Armasuisse funding.
|
35
|
+
|
36
|
+
[Documentation status](https://parabellum.readthedocs.io/en/latest/?badge=latest)
|
37
|
+
|
38
|
+
## Features
|
39
|
+
|
40
|
+
- Obstacles and terrain integration
|
41
|
+
- Rasterized maps
|
42
|
+
- Blast radii simulation
|
43
|
+
- Friendly fire mechanics
|
44
|
+
- Pygame visualization
|
45
|
+
- JAX-based parallelization
|
46
|
+
|
47
|
+
## Install
|
48
|
+
|
49
|
+
```bash
|
50
|
+
pip install parabellum
|
51
|
+
```
|
52
|
+
|
53
|
+
## Quick Start
|
54
|
+
|
55
|
+
```python
|
56
|
+
import parabellum as pb
|
57
|
+
from jax import random
|
58
|
+
|
59
|
+
terrain = pb.terrain_fn("Thun, Switzerland", 1000)
|
60
|
+
scenario = pb.make_scenario("Thun", terrain, 10, 10)
|
61
|
+
env = pb.Environment(scenario)
|
62
|
+
|
63
|
+
rng, key = random.split(random.PRNGKey(0))
|
64
|
+
obs, state = env.reset(key)
state_sequence = []

# Simulation loop
for _ in range(100):
    rng, rng_act, key_step = random.split(rng, 3)
    key_act = random.split(rng_act, len(env.agents))
    act = {a: env.action_space(a).sample(k) for a, k in zip(env.agents, key_act)}
    state_sequence.append((key_act, state, act))
    obs, state, reward, done, info = env.step(key_step, act, state)
|
72
|
+
|
73
|
+
# Visualize
|
74
|
+
vis = pb.Visualizer(env, state_sequence)
|
75
|
+
vis.animate()
|
76
|
+
```
|
77
|
+
|
78
|
+
## Documentation
|
79
|
+
|
80
|
+
Full documentation: [parabellum.readthedocs.io](https://parabellum.readthedocs.io)
|
81
|
+
|
82
|
+
## Team
|
83
|
+
|
84
|
+
- Noah Syrkis
|
85
|
+
- Timothée Anne
|
86
|
+
- Supervisor: Sebastian Risi
|
87
|
+
|
88
|
+
## License
|
89
|
+
|
90
|
+
MIT
|
@@ -0,0 +1,8 @@
|
|
1
|
+
parabellum/__init__.py,sha256=cI1kxVQ274VZrBLxUYvB_j6UqhFY44sId4NgVjbf678,245
|
2
|
+
parabellum/env.py,sha256=B2WaEzCQnakVB1AuYqFsb7aiPIbaUsuq0CjdKq4pQm8,16204
|
3
|
+
parabellum/map.py,sha256=CdPebRuGafj9fAydAyTa8bqpr-PldZw7heKRReGTDBg,1257
|
4
|
+
parabellum/run.py,sha256=HyPdz5iVD8q0iYaZL2Nf02fGHByHpecUGoPlQrq9v8s,3411
|
5
|
+
parabellum/vis.py,sha256=uIhN9VSJlT4XoNuNCiU_Bw2VSNoJFOYlScCTNrMkGsg,11977
|
6
|
+
parabellum-0.2.20.dist-info/METADATA,sha256=IO3Wsx_nTM3v63FtbnJqaEfPfXOyhNGtKOm_ppchUqo,2453
|
7
|
+
parabellum-0.2.20.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
8
|
+
parabellum-0.2.20.dist-info/RECORD,,
|
@@ -1,104 +0,0 @@
|
|
1
|
-
Metadata-Version: 2.1
|
2
|
-
Name: parabellum
|
3
|
-
Version: 0.2.18
|
4
|
-
Summary: Parabellum environment for parallel warfare simulation
|
5
|
-
Home-page: https://github.com/syrkis/parabellum
|
6
|
-
License: MIT
|
7
|
-
Keywords: warfare,simulation,parallel,environment
|
8
|
-
Author: Noah Syrkis
|
9
|
-
Author-email: desk@syrkis.com
|
10
|
-
Requires-Python: >=3.11,<4.0
|
11
|
-
Classifier: License :: OSI Approved :: MIT License
|
12
|
-
Classifier: Programming Language :: Python :: 3
|
13
|
-
Classifier: Programming Language :: Python :: 3.11
|
14
|
-
Classifier: Programming Language :: Python :: 3.12
|
15
|
-
Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
|
16
|
-
Requires-Dist: jax (==0.4.17)
|
17
|
-
Requires-Dist: jaxmarl (==0.0.3)
|
18
|
-
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
19
|
-
Requires-Dist: moviepy (>=1.0.3,<2.0.0)
|
20
|
-
Requires-Dist: poetry (>=1.8.3,<2.0.0)
|
21
|
-
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
22
|
-
Requires-Dist: tqdm (>=4.66.4,<5.0.0)
|
23
|
-
Project-URL: Repository, https://github.com/syrkis/parabellum
|
24
|
-
Description-Content-Type: text/markdown
|
25
|
-
|
26
|
-
# parabellum
|
27
|
-
|
28
|
-
Parabellum is an ultra-scalable, high-performance warfare simulation engine.
|
29
|
-
It is based on JaxMARL's SMAX environment, but has been heavily modified to
|
30
|
-
support a wide range of new features and improvements.
|
31
|
-
|
32
|
-
## Installation
|
33
|
-
|
34
|
-
Parabellum is written in Python 3.11 and can be installed using pip:
|
35
|
-
|
36
|
-
```bash
|
37
|
-
pip install parabellum
|
38
|
-
```
|
39
|
-
|
40
|
-
## Usage
|
41
|
-
|
42
|
-
Parabellum is designed to be used in conjunction with JAX, a high-performance
|
43
|
-
numerical computing library. Here is a simple example of how to use Parabellum
|
44
|
-
to simulate a game with 10 agents and 10 enemies, each taking random actions:
|
45
|
-
|
46
|
-
```python
|
47
|
-
import parabellum as pb
|
48
|
-
from jax import random
|
49
|
-
|
50
|
-
# define the scenario
|
51
|
-
kwargs = dict(obstacle_coords=[(7, 7)], obstacle_deltas=[(10, 0)])
|
52
|
-
scenario = pb.Scenario(**kwargs) # <- Scenario is an important part of parabellum
|
53
|
-
|
54
|
-
# create the environment
|
55
|
-
kwargs = dict(map_width=256, map_height=256, num_agents=10, num_enemies=10)
|
56
|
-
env = pb.Parabellum(**kwargs) # <- Parabellum is the central class of parabellum
|
57
|
-
|
58
|
-
# initiate stochasticity
|
59
|
-
rng = random.PRNGKey(0)
|
60
|
-
rng, key = random.split(rng)
|
61
|
-
|
62
|
-
# initialize the environment state
|
63
|
-
obs, state = env.reset(key)
|
64
|
-
state_sequence = []
|
65
|
-
|
66
|
-
for _ in range(1000):
|
67
|
-
|
68
|
-
# manage stochasticity
|
69
|
-
rng, rng_act, key_step = random.split(key)
|
70
|
-
key_act = random.split(rng_act, len(env.agents))
|
71
|
-
|
72
|
-
# sample actions and append to state sequence
|
73
|
-
act = {a: env.action_space(a).sample(k)
|
74
|
-
for a, k in zip(env.agents, key_act)}
|
75
|
-
|
76
|
-
# step the environment
|
77
|
-
state_sequence.append((key_act, state, act))
|
78
|
-
obs, state, reward, done, info = env.step(key_step, act, state)
|
79
|
-
|
80
|
-
|
81
|
-
# save visualization of the state sequence
|
82
|
-
vis = pb.Visualizer(env, state_sequence) # <- Visualizer is a nice to have class
|
83
|
-
vis.animate()
|
84
|
-
```
|
85
|
-
|
86
|
-
|
87
|
-
## Features
|
88
|
-
|
89
|
-
- Obstacles — can be inserted in
|
90
|
-
|
91
|
-
## TODO
|
92
|
-
|
93
|
-
- [x] Parallel pygame vis
|
94
|
-
- [ ] Parallel bullet renderings
|
95
|
-
- [ ] Combine parallell plots into one (maybe out of parabellum scope)
|
96
|
-
- [ ] Color for health?
|
97
|
-
- [ ] Add the ability to see ongoing game.
|
98
|
-
- [ ] Bug test friendly fire.
|
99
|
-
- [x] Start sim from arbitrary state.
|
100
|
-
- [ ] Save when the episode ends in some state/obs variable
|
101
|
-
- [ ] Look for the source of the bug when using more Allies than Enemies
|
102
|
-
- [ ] Y inversed axis for parabellum visualization
|
103
|
-
- [ ] Units see through obstacles?
|
104
|
-
|
@@ -1,8 +0,0 @@
|
|
1
|
-
parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
|
2
|
-
parabellum/env.py,sha256=d6agGy-kTRIg_r0QKCL_7iztzwhaTfsb4yhtUQfdgx0,16024
|
3
|
-
parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
|
4
|
-
parabellum/run.py,sha256=0IWpqcQ_qfFeElbupF5vOs_CByFfpXYuGGUHYuurFM4,3412
|
5
|
-
parabellum/vis.py,sha256=vb1ndWuGHBhG9Y_9ygFUDRGRfTTzKESD43jNCHMp2hQ,10564
|
6
|
-
parabellum-0.2.18.dist-info/METADATA,sha256=-DXRDSO86-I9gZlFrnkRz6wkYplXGQY5TeCHfiPOOno,3223
|
7
|
-
parabellum-0.2.18.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
8
|
-
parabellum-0.2.18.dist-info/RECORD,,
|
File without changes
|