parabellum 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum/__init__.py +17 -3
- parabellum/aid.py +5 -0
- parabellum/env.py +84 -64
- parabellum/gun.py +70 -0
- parabellum/map.py +61 -9
- parabellum/run.py +7 -8
- parabellum/vis.py +150 -247
- parabellum-0.2.21.dist-info/METADATA +95 -0
- parabellum-0.2.21.dist-info/RECORD +10 -0
- parabellum-0.2.19.dist-info/METADATA +0 -104
- parabellum-0.2.19.dist-info/RECORD +0 -8
- {parabellum-0.2.19.dist-info → parabellum-0.2.21.dist-info}/WHEEL +0 -0
parabellum/__init__.py
CHANGED
@@ -1,4 +1,18 @@
|
|
1
|
-
from .env import
|
2
|
-
from .vis import Visualizer
|
1
|
+
from .env import Environment, Scenario, scenarios, make_scenario, State
|
2
|
+
from .vis import Visualizer, Skin
|
3
|
+
from .map import terrain_fn
|
4
|
+
from .gun import bullet_fn
|
5
|
+
# from .aid import aid
|
6
|
+
# from .run import run
|
3
7
|
|
4
|
-
__all__ = [
|
8
|
+
__all__ = [
|
9
|
+
"Environment",
|
10
|
+
"Scenario",
|
11
|
+
"scenarios",
|
12
|
+
"make_scenario",
|
13
|
+
"State",
|
14
|
+
"Visualizer",
|
15
|
+
"Skin",
|
16
|
+
"terrain_fn",
|
17
|
+
"bullet_fn",
|
18
|
+
]
|
parabellum/aid.py
ADDED
parabellum/env.py
CHANGED
@@ -3,13 +3,13 @@
|
|
3
3
|
import jax.numpy as jnp
|
4
4
|
import jax
|
5
5
|
import numpy as np
|
6
|
-
from jax import random
|
6
|
+
from jax import random, Array
|
7
7
|
from jax import jit
|
8
8
|
from flax.struct import dataclass
|
9
9
|
import chex
|
10
10
|
from jax import vmap
|
11
|
-
from jaxmarl.environments.smax.smax_env import
|
12
|
-
from typing import Tuple, Dict
|
11
|
+
from jaxmarl.environments.smax.smax_env import SMAX
|
12
|
+
from typing import Tuple, Dict, cast
|
13
13
|
from functools import partial
|
14
14
|
|
15
15
|
|
@@ -17,11 +17,8 @@ from functools import partial
|
|
17
17
|
class Scenario:
|
18
18
|
"""Parabellum scenario"""
|
19
19
|
|
20
|
-
|
21
|
-
|
22
|
-
obstacle_coords: chex.Array # TODO: use map instead of obstacles
|
23
|
-
obstacle_deltas: chex.Array
|
24
|
-
|
20
|
+
place: str
|
21
|
+
terrain_raster: jnp.ndarray
|
25
22
|
unit_types: chex.Array
|
26
23
|
num_allies: int
|
27
24
|
num_enemies: int
|
@@ -29,13 +26,25 @@ class Scenario:
|
|
29
26
|
smacv2_position_generation: bool = False
|
30
27
|
smacv2_unit_type_generation: bool = False
|
31
28
|
|
29
|
+
@dataclass
class State:
    """Simulation state of a Parabellum environment at one step."""

    # per-unit arrays, indexed by unit (allies first, then enemies, as built in reset)
    unit_positions: Array
    unit_alive: Array
    unit_teams: Array
    unit_health: Array
    unit_types: Array
    unit_weapon_cooldowns: Array
    # presumably the actions chosen on the previous step — confirm against step()
    prev_movement_actions: Array
    prev_attack_actions: Array
    # episode bookkeeping
    time: int
    terminal: bool
|
41
|
+
|
32
42
|
|
33
43
|
# default scenario
|
34
44
|
scenarios = {
|
35
45
|
"default": Scenario(
|
46
|
+
"Identity Town",
|
36
47
|
jnp.eye(64, dtype=jnp.uint8),
|
37
|
-
jnp.array([[80, 0], [16, 12]]),
|
38
|
-
jnp.array([[0, 80], [0, 20]]),
|
39
48
|
jnp.zeros((19,), dtype=jnp.uint8),
|
40
49
|
9,
|
41
50
|
10,
|
@@ -43,57 +52,63 @@ scenarios = {
|
|
43
52
|
}
|
44
53
|
|
45
54
|
|
46
|
-
|
55
|
+
def make_scenario(place, terrain_raster, num_allies=9, num_enemies=10):
    """Build a ``Scenario`` for ``place`` on ``terrain_raster``.

    Every one of the ``num_allies + num_enemies`` units is given unit type 0,
    stored as a ``uint8`` array.
    """
    total_units = num_allies + num_enemies
    all_type_zero = jnp.zeros((total_units,), dtype=jnp.uint8)
    return Scenario(
        place,
        terrain_raster,
        all_type_zero,
        num_allies,
        num_enemies,
    )
|
60
|
+
|
61
|
+
|
62
|
+
def spawn_fn(pool, offset: jnp.ndarray, n: int, rng: jnp.ndarray):
    """Place ``n`` units on distinct candidate cells.

    Args:
        pool: ``(rows, cols)`` index arrays of candidate cells, as produced by
            ``jnp.nonzero`` in ``sector_fn``.
        offset: added to every position (sector-local -> map coordinates).
        n: number of units. Cells are drawn without replacement, so ``n`` must
            not exceed the number of candidate cells.
        rng: PRNG key.

    Returns:
        ``(n, 2)`` positions: chosen cell index + uniform jitter in ``[0, 0.5)``
        + ``offset``.
    """
    _, pick_key, jitter_key = random.split(rng, 3)
    jitter = random.uniform(jitter_key, (n, 2)) * 0.5

    rows, cols = pool[0], pool[1]
    chosen = random.choice(pick_key, rows.shape[0], (n,), replace=False)
    cells = jnp.stack([rows[chosen], cols[chosen]], axis=1)

    return cells + jitter + offset
|
72
|
+
|
73
|
+
|
74
|
+
def sector_fn(terrain: jnp.ndarray, sector_id: int, grid: int = 5):
    """Return the free cells of one sector of ``terrain``.

    The raster is split into a ``grid`` x ``grid`` grid of sectors numbered
    row-major: ``sector_id // grid`` picks the block along the first axis and
    ``sector_id % grid`` the block along the second axis.

    Args:
        terrain: 2-D raster; cells equal to 0 are free.
        sector_id: sector index in ``[0, grid * grid)``.
        grid: number of sectors per side (default 5, the previous hard-coded value).

    Returns:
        ``(cells, offset)``. ``cells`` is the ``jnp.nonzero`` tuple of free-cell
        indices local to the sector. ``offset`` is the sector's top-left index in
        ``terrain``.

    Raises:
        ValueError: if ``sector_id`` is outside ``[0, grid * grid)``.
    """
    if not 0 <= sector_id < grid * grid:
        # An out-of-range id yields an empty slice here, which only fails later
        # (e.g. when spawn_fn samples cells without replacement).
        raise ValueError(f"sector_id must be in [0, {grid * grid}), got {sector_id}")
    n_rows, n_cols = terrain.shape
    row0 = sector_id // grid * n_rows // grid
    col0 = sector_id % grid * n_cols // grid
    free = terrain[row0 : row0 + n_rows // grid, col0 : col0 + n_cols // grid] == 0
    return jnp.nonzero(free), jnp.array([row0, col0])
|
82
|
+
|
83
|
+
|
84
|
+
class Environment(SMAX):
|
47
85
|
def __init__(self, scenario: Scenario, **kwargs):
|
48
86
|
map_height, map_width = scenario.terrain_raster.shape
|
49
87
|
args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
|
50
|
-
super(
|
88
|
+
super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
|
51
89
|
self.terrain_raster = scenario.terrain_raster
|
52
|
-
self.
|
53
|
-
self.
|
54
|
-
self.
|
90
|
+
# self.unit_type_names = ["tinker", "tailor", "soldier", "spy"]
|
91
|
+
# self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
|
92
|
+
# self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
|
93
|
+
self.scenario = scenario
|
94
|
+
self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
|
55
95
|
self.max_steps = 200
|
56
|
-
self._push_units_away = lambda
|
96
|
+
self._push_units_away = lambda state, firmness = 1: state # overwrite push units
|
97
|
+
self.top_sector, self.top_sector_offset = sector_fn(self.terrain_raster, 0)
|
98
|
+
self.low_sector, self.low_sector_offset = sector_fn(self.terrain_raster, 24)
|
57
99
|
|
58
100
|
@partial(jax.jit, static_argnums=(0,))
|
59
|
-
def reset(self,
|
101
|
+
def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
|
60
102
|
"""Environment-specific reset."""
|
61
|
-
|
62
|
-
team_0_start =
|
63
|
-
|
64
|
-
)
|
65
|
-
team_0_start_noise = jax.random.uniform(
|
66
|
-
team_0_key, shape=(self.num_allies, 2), minval=-2, maxval=2
|
67
|
-
)
|
68
|
-
team_0_start = team_0_start + team_0_start_noise
|
69
|
-
team_1_start = jnp.stack(
|
70
|
-
[jnp.array([self.map_width / 4 * 3, self.map_height / 2])]
|
71
|
-
* self.num_enemies
|
72
|
-
)
|
73
|
-
team_1_start_noise = jax.random.uniform(
|
74
|
-
team_1_key, shape=(self.num_enemies, 2), minval=-2, maxval=2
|
75
|
-
)
|
76
|
-
team_1_start = team_1_start + team_1_start_noise
|
103
|
+
ally_key, enemy_key = jax.random.split(rng)
|
104
|
+
team_0_start = spawn_fn(self.top_sector, self.top_sector_offset, self.num_allies, ally_key)
|
105
|
+
team_1_start = spawn_fn(self.low_sector, self.low_sector_offset, self.num_enemies, enemy_key)
|
77
106
|
unit_positions = jnp.concatenate([team_0_start, team_1_start])
|
78
|
-
key, pos_key = jax.random.split(key)
|
79
|
-
generated_unit_positions = self.position_generator.generate(pos_key)
|
80
|
-
unit_positions = jax.lax.select(
|
81
|
-
self.smacv2_position_generation, generated_unit_positions, unit_positions
|
82
|
-
)
|
83
107
|
unit_teams = jnp.zeros((self.num_agents,))
|
84
108
|
unit_teams = unit_teams.at[self.num_allies :].set(1)
|
85
109
|
unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
|
86
110
|
# default behaviour spawn all marines
|
87
|
-
unit_types = (
|
88
|
-
jnp.zeros((self.num_agents,), dtype=jnp.uint8)
|
89
|
-
if self.scenario is None
|
90
|
-
else self.scenario
|
91
|
-
)
|
92
|
-
key, unit_type_key = jax.random.split(key)
|
93
|
-
generated_unit_types = self.unit_type_generator.generate(unit_type_key)
|
94
|
-
unit_types = jax.lax.select(
|
95
|
-
self.smacv2_unit_type_generation, generated_unit_types, unit_types
|
96
|
-
)
|
111
|
+
unit_types = cast(Array, self.scenario.unit_types)
|
97
112
|
unit_health = self.unit_type_health[unit_types]
|
98
113
|
state = State(
|
99
114
|
unit_positions=unit_positions,
|
@@ -107,7 +122,7 @@ class Parabellum(SMAX):
|
|
107
122
|
terminal=False,
|
108
123
|
unit_weapon_cooldowns=unit_weapon_cooldowns,
|
109
124
|
)
|
110
|
-
state = self._push_units_away(state)
|
125
|
+
state = self._push_units_away(state) # type: ignore
|
111
126
|
obs = self.get_obs(state)
|
112
127
|
world_state = self.get_world_state(state)
|
113
128
|
obs["world_state"] = jax.lax.stop_gradient(world_state)
|
@@ -139,7 +154,7 @@ class Parabellum(SMAX):
|
|
139
154
|
key: chex.PRNGKey,
|
140
155
|
state: State,
|
141
156
|
actions: Tuple[chex.Array, chex.Array],
|
142
|
-
) ->
|
157
|
+
) -> State:
|
143
158
|
@partial(jax.vmap, in_axes=(None, None, 0, 0))
|
144
159
|
def intersect_fn(pos, new_pos, obs, obs_end):
|
145
160
|
d1 = jnp.cross(obs - pos, new_pos - pos)
|
@@ -165,7 +180,7 @@ class Parabellum(SMAX):
|
|
165
180
|
# because these are easier to encode as actions than the four
|
166
181
|
# diagonal directions. Then rotate the velocity 45
|
167
182
|
# degrees anticlockwise to compute the movement.
|
168
|
-
pos = state.unit_positions[idx]
|
183
|
+
pos = cast(Array, state.unit_positions[idx])
|
169
184
|
new_pos = (
|
170
185
|
pos
|
171
186
|
+ vec
|
@@ -180,12 +195,12 @@ class Parabellum(SMAX):
|
|
180
195
|
|
181
196
|
#######################################################################
|
182
197
|
############################################ avoid going into obstacles
|
183
|
-
obs = self.obstacle_coords
|
198
|
+
""" obs = self.obstacle_coords
|
184
199
|
obs_end = obs + self.obstacle_deltas
|
185
|
-
inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end))
|
186
|
-
|
187
|
-
flag = jnp.logical_or(inters, rastersects)
|
188
|
-
new_pos = jnp.where(
|
200
|
+
inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end)) """
|
201
|
+
clash = raster_crossing(pos, new_pos)
|
202
|
+
# flag = jnp.logical_or(inters, rastersects)
|
203
|
+
new_pos = jnp.where(clash, pos, new_pos)
|
189
204
|
|
190
205
|
#######################################################################
|
191
206
|
#######################################################################
|
@@ -217,6 +232,7 @@ class Parabellum(SMAX):
|
|
217
232
|
lambda: action + self.num_allies - self.num_movement_actions,
|
218
233
|
lambda: self.num_allies - 1 - (action - self.num_movement_actions),
|
219
234
|
)
|
235
|
+
attacked_idx = cast(int, attacked_idx) # Cast to int
|
220
236
|
# deal with no-op attack actions (i.e. agents that are moving instead)
|
221
237
|
attacked_idx = jax.lax.select(
|
222
238
|
action < self.num_movement_actions, idx, attacked_idx
|
@@ -248,7 +264,7 @@ class Parabellum(SMAX):
|
|
248
264
|
bystander_idxs = bystander_fn(attacked_idx) # TODO: use
|
249
265
|
bystander_valid = (
|
250
266
|
jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
|
251
|
-
.astype(jnp.bool_)
|
267
|
+
.astype(jnp.bool_) # type: ignore
|
252
268
|
.astype(jnp.float32)
|
253
269
|
)
|
254
270
|
bystander_health_diff = (
|
@@ -310,14 +326,14 @@ class Parabellum(SMAX):
|
|
310
326
|
[self.map_width, 0],
|
311
327
|
]
|
312
328
|
)
|
313
|
-
obstacle_coords = jnp.concatenate(
|
329
|
+
""" obstacle_coords = jnp.concatenate(
|
314
330
|
[self.obstacle_coords, bondaries_coords]
|
315
331
|
) # add the map boundaries to the obstacles to avoid
|
316
332
|
obstacle_deltas = jnp.concatenate(
|
317
333
|
[self.obstacle_deltas, bondaries_deltas]
|
318
334
|
) # add the map boundaries to the obstacles to avoid
|
319
335
|
obst_start = obstacle_coords
|
320
|
-
obst_end = obst_start + obstacle_deltas
|
336
|
+
obst_end = obst_start + obstacle_deltas """
|
321
337
|
|
322
338
|
def check_obstacles(pos, new_pos, obst_start, obst_end):
|
323
339
|
intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
|
@@ -325,9 +341,9 @@ class Parabellum(SMAX):
|
|
325
341
|
flag = jnp.logical_or(intersects, rastersect)
|
326
342
|
return jnp.where(flag, pos, new_pos)
|
327
343
|
|
328
|
-
pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
|
344
|
+
""" pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
|
329
345
|
pos, new_pos, obst_start, obst_end
|
330
|
-
)
|
346
|
+
) """
|
331
347
|
|
332
348
|
# Multiple enemies can attack the same unit.
|
333
349
|
# We have `(health_diff, attacked_idx)` pairs.
|
@@ -363,7 +379,8 @@ class Parabellum(SMAX):
|
|
363
379
|
#########################################################
|
364
380
|
|
365
381
|
unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
|
366
|
-
|
382
|
+
# replace unit health, unit positions and unit weapon cooldowns
|
383
|
+
state = state.replace( # type: ignore
|
367
384
|
unit_health=unit_health,
|
368
385
|
unit_positions=pos,
|
369
386
|
unit_weapon_cooldowns=unit_weapon_cooldowns,
|
@@ -373,13 +390,16 @@ class Parabellum(SMAX):
|
|
373
390
|
|
374
391
|
if __name__ == "__main__":
|
375
392
|
n_envs = 4
|
376
|
-
|
377
|
-
env =
|
393
|
+
|
394
|
+
env = Environment(scenarios["default"])
|
378
395
|
rng, reset_rng = random.split(random.PRNGKey(0))
|
379
396
|
reset_key = random.split(reset_rng, n_envs)
|
380
397
|
obs, state = vmap(env.reset)(reset_key)
|
381
398
|
state_seq = []
|
382
399
|
|
400
|
+
print(state.unit_positions)
|
401
|
+
exit()
|
402
|
+
|
383
403
|
for i in range(10):
|
384
404
|
rng, act_rng, step_rng = random.split(rng, 3)
|
385
405
|
act_key = random.split(act_rng, (len(env.agents), n_envs))
|
parabellum/gun.py
ADDED
@@ -0,0 +1,70 @@
|
|
1
|
+
# gun.py
|
2
|
+
# parabellum functions associated with bullet rendering
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# imports
|
6
|
+
from functools import partial
|
7
|
+
import jax.numpy as jnp
|
8
|
+
|
9
|
+
|
10
|
+
def dist_fn(env, pos):
    """Euclidean distances between every ally and every enemy.

    The first ``env.num_allies`` rows of ``pos`` are allies and the rest are
    enemies. Returns ``{"ally": d, "enemy": d.T}``, where ``d[i, j]`` is the
    distance from ally ``i`` to enemy ``j``.
    """
    n = env.num_allies
    allies, enemies = pos[:n], pos[n:]
    offsets = enemies[None, :, :] - allies[:, None, :]
    ally_view = jnp.sqrt((offsets**2).sum(axis=2))
    return {"ally": ally_view, "enemy": ally_view.T}
|
15
|
+
|
16
|
+
|
17
|
+
def range_fn(env, dists, ranges):
    """Boolean masks of which opponents each unit can reach.

    ``ranges`` holds one attack range per unit (allies first). Row ``i`` of
    each mask marks the opponents strictly closer than unit ``i``'s range.
    """
    split = env.num_allies
    ally_reach = ranges[:split]
    enemy_reach = ranges[split:]
    return {
        "ally": dists["ally"] < ally_reach[:, None],
        "enemy": dists["enemy"] < enemy_reach[:, None],
    }
|
21
|
+
|
22
|
+
|
23
|
+
def target_fn(acts, in_range, team, num_movement_actions: int = 5):
    """One-hot matrix of the opponents each unit of ``team`` validly attacks.

    Args:
        acts: mapping of agent name to discrete action. Entries whose key starts
            with ``team`` are used, in insertion order.
        in_range: ``range_fn`` output; ``in_range[team][i, j]`` is truthy when
            opponent ``j`` is within unit ``i``'s range.
        team: ``"ally"`` or ``"enemy"``.
        num_movement_actions: actions ``0 .. num_movement_actions - 1`` are
            moves; action ``num_movement_actions + k`` attacks target ``k``.

    Returns:
        Array shaped like ``in_range[team]``. It has a 1 at ``[i, target]`` when
        unit ``i`` attacks an in-range opponent and 0 elsewhere.
    """
    t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T
    # Attack actions map to a target index. Movement actions map to -1, which
    # selects the extra last row of the identity; that column is dropped below,
    # so moving units get an all-zero row (previously this condition was inverted).
    is_attack = t_acts >= num_movement_actions
    t_targets = jnp.where(is_attack, t_acts - num_movement_actions, -1)
    n_targets = in_range[team].shape[1]
    t_attacks = jnp.eye(n_targets + 1)[t_targets][..., :-1]
    if team == "enemy":
        # The environment resolves enemy attack k against ally
        # ``num_allies - 1 - k`` (reverse order), so mirror the target columns.
        t_attacks = t_attacks[..., ::-1]
    return t_attacks * in_range[team]  # one hot valid targets
|
28
|
+
|
29
|
+
|
30
|
+
def attack_fn(env, state_seq):
    """Per-step one-hot attack matrices for both teams.

    ``state_seq`` is a sequence of ``(_, state, actions)`` tuples. For each one,
    the attack ranges of ``state.unit_types`` are checked against the current
    ally/enemy distances, then combined with the chosen actions by ``target_fn``.

    Returns:
        A list with one ``{"ally": ..., "enemy": ...}`` dict per element of
        ``state_seq``.
    """

    def attacks_at(state, acts):
        dists = dist_fn(env, state.unit_positions)
        reach = range_fn(env, dists, env.unit_type_attack_ranges[state.unit_types])
        return {team: target_fn(acts, reach, team) for team in ("ally", "enemy")}

    return [attacks_at(state, acts) for _, state, acts in state_seq]
|
40
|
+
|
41
|
+
|
42
|
+
def bullet_fn(env, states):
    """Bullet segments for each transition in ``states``.

    ``states`` is a sequence of ``(_, state, _)`` tuples. For every consecutive
    pair ``(state, next_state)``, each attack reported by ``attack_fn`` becomes a
    ``(source, target)`` pair. The source is the shooter's position in ``state``
    and the target is the victim's position in ``next_state``.

    Returns:
        A list with one ``(ally_bullets, enemy_bullets)`` tuple per transition.
        Each array stacks ``(source, target)`` positions along axis 1.
    """
    attack_seq = attack_fn(env, states)
    offset = env.num_allies  # enemies follow the allies in the unit arrays

    def index_pairs(one_hot, team):
        # (shooter, victim) indices, each local to its own team
        return jnp.stack(jnp.where(one_hot[team] == 1)).T

    bullet_seq = []
    transitions = zip(states[:-1], states[1:])
    for step, ((_, state, _), (_, next_state, _)) in enumerate(transitions):
        here = state.unit_positions
        there = next_state.unit_positions
        ally_idx = index_pairs(attack_seq[step], "ally")
        enemy_idx = index_pairs(attack_seq[step], "enemy")

        ally_bullets = jnp.stack(
            (here[ally_idx[:, 0]], there[ally_idx[:, 1] + offset]), axis=1
        )
        enemy_bullets = jnp.stack(
            (here[enemy_idx[:, 0] + offset], there[enemy_idx[:, 1]]), axis=1
        )
        bullet_seq.append((ally_bullets, enemy_bullets))
    return bullet_seq
|
parabellum/map.py
CHANGED
@@ -1,16 +1,68 @@
|
|
1
1
|
# map.py
|
2
|
-
#
|
2
|
+
# parabellum map functions
|
3
3
|
# by: Noah Syrkis
|
4
4
|
|
5
5
|
# imports
|
6
6
|
import jax.numpy as jnp
|
7
|
-
import
|
7
|
+
from geopy.geocoders import Nominatim
|
8
|
+
import geopandas as gpd
|
9
|
+
import osmnx as ox
|
10
|
+
import contextily as cx
|
11
|
+
import matplotlib.pyplot as plt
|
12
|
+
from rasterio import features
|
13
|
+
import rasterio.transform
|
14
|
+
from typing import Optional, Tuple
|
15
|
+
from geopy.location import Location
|
16
|
+
from shapely.geometry import Point
|
8
17
|
|
18
|
+
# constants
|
19
|
+
geolocator = Nominatim(user_agent="parabellum")
|
20
|
+
BUILDING_TAGS = {"building": True}
|
9
21
|
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
22
|
+
def get_location(place: str) -> Tuple[float, float]:
    """Geocode ``place`` with Nominatim and return ``(latitude, longitude)``.

    Raises:
        ValueError: if the geocoder finds no match.
    """
    location: Optional[Location] = geolocator.geocode(place)  # type: ignore
    if location is not None:
        return (location.latitude, location.longitude)
    raise ValueError(f"Could not geocode the place: {place}")
|
28
|
+
|
29
|
+
def get_building_geometry(point: Tuple[float, float], size: int) -> gpd.GeoDataFrame:
    """Fetch OSM features tagged as buildings within ``size // 2`` of ``point``.

    ``point`` is the ``(latitude, longitude)`` pair from ``get_location``. The
    result is a GeoDataFrame whose CRS is set to EPSG:4326.
    """
    radius = size // 2
    buildings = ox.features_from_point(point, tags=BUILDING_TAGS, dist=radius)
    frame = gpd.GeoDataFrame(buildings)
    return frame.set_crs("EPSG:4326")
|
33
|
+
|
34
|
+
def rasterize_geometry(gdf: gpd.GeoDataFrame, size: int) -> jnp.ndarray:
    """Burn the geometries of ``gdf`` into a ``size`` x ``size`` uint8 mask.

    The grid spans ``gdf.total_bounds``. The mask is rotated by 180 degrees
    before being returned as a JAX array.
    """
    west, south, east, north = gdf.total_bounds
    affine = rasterio.transform.from_bounds(west, south, east, north, size, size)
    burned = features.rasterize(gdf.geometry, out_shape=(size, size), transform=affine)
    rotated = jnp.rot90(burned, 2)
    return jnp.array(rotated).astype(jnp.uint8)
|
40
|
+
|
41
|
+
def _basemap_from_buildings(gdf: gpd.GeoDataFrame, size: int) -> jnp.ndarray:
    """Fetch web tiles covering ``gdf`` and return their central crop, rotated 180 degrees."""
    basemap, _ = cx.bounds2img(*gdf.total_bounds, ll=True)
    height, width = basemap.shape[0], basemap.shape[1]
    # Clamp the start at 0 so a tile image smaller than ``size`` is not wrapped
    # around by a negative slice index.
    rows = slice(max((height - size) // 2, 0), (height + size) // 2)
    cols = slice(max((width - size) // 2, 0), (width + size) // 2)
    return jnp.array(jnp.rot90(basemap[rows, cols], 2)).astype(jnp.uint8)


def terrain_fn(place: str, size: int = 1000) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return ``(mask, basemap)`` for ``place``.

    ``mask`` is the rasterized building mask from ``rasterize_geometry``.
    ``basemap`` is the tile image covering the same buildings. The place is
    geocoded and the buildings downloaded only once, and both outputs share
    them; calling ``get_basemap`` here would repeat both network requests.
    """
    point = get_location(place)
    gdf = get_building_geometry(point, size)
    mask = rasterize_geometry(gdf, size)
    base = _basemap_from_buildings(gdf, size)
    return mask, base
|
48
|
+
|
49
|
+
def get_basemap(place: str, size: int = 1000) -> jnp.ndarray:
    """Return web-map tiles around ``place`` as a JAX ``uint8`` array.

    The tile image spans the bounds of the buildings found around ``place``.
    Its central ``size`` x ``size`` window is rotated by 180 degrees. If the
    image is smaller than ``size`` along an axis, that whole axis is kept.
    """
    point = get_location(place)
    gdf = get_building_geometry(point, size)
    basemap, _ = cx.bounds2img(*gdf.total_bounds, ll=True)
    height, width = basemap.shape[0], basemap.shape[1]
    # get the middle size x size square; clamp the start at 0 because a negative
    # start index would wrap around and return only a sliver of the image
    rows = slice(max((height - size) // 2, 0), (height + size) // 2)
    cols = slice(max((width - size) // 2, 0), (width + size) // 2)
    return jnp.array(jnp.rot90(basemap[rows, cols], 2)).astype(jnp.uint8)
|
58
|
+
|
59
|
+
|
60
|
+
if __name__ == "__main__":
    # Demo: fetch and display the building mask and basemap for one town.
    # (The unused `seaborn` import was removed; it is not a dependency of this module.)
    place = "Thun, Switzerland"
    mask, base = terrain_fn(place)

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(mask)  # type: ignore
    ax[1].imshow(base)  # type: ignore
    plt.show()
|
parabellum/run.py
CHANGED
@@ -10,6 +10,7 @@ import darkdetect
|
|
10
10
|
import jax.numpy as jnp
|
11
11
|
from chex import dataclass
|
12
12
|
import jaxmarl
|
13
|
+
from jax import Array
|
13
14
|
from typing import Tuple, List, Dict, Optional
|
14
15
|
import parabellum as pb
|
15
16
|
|
@@ -20,7 +21,7 @@ bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
|
|
20
21
|
|
21
22
|
|
22
23
|
# types
|
23
|
-
State = jaxmarl.environments.smax.smax_env.State
|
24
|
+
State = jaxmarl.environments.smax.smax_env.State # type: ignore
|
24
25
|
Obs = Reward = Done = Action = Dict[str, jnp.ndarray]
|
25
26
|
StateSeq = List[Tuple[jnp.ndarray, State, Action]]
|
26
27
|
|
@@ -35,12 +36,12 @@ class Control:
|
|
35
36
|
@dataclass
class Game:
    """Bundle threaded through the interactive pygame loop (``game = control_fn(game)``)."""

    clock: pygame.time.Clock
    state: State  # type: ignore
    obs: Dict
    state_seq: StateSeq  # recorded (array, state, action) tuples
    control: Control
    env: pb.Environment
    rng: Array
|
44
45
|
|
45
46
|
|
46
47
|
def handle_event(event, control_state):
|
@@ -100,7 +101,7 @@ def step_fn(game):
|
|
100
101
|
|
101
102
|
# state
|
102
103
|
if __name__ == "__main__":
|
103
|
-
env = pb.
|
104
|
+
env = pb.Environment(pb.scenarios["default"])
|
104
105
|
pygame.init()
|
105
106
|
screen = pygame.display.set_mode((1000, 1000))
|
106
107
|
render = partial(render_fn, screen)
|
@@ -115,7 +116,7 @@ if __name__ == "__main__":
|
|
115
116
|
state=state,
|
116
117
|
obs=obs,
|
117
118
|
)
|
118
|
-
game = Game(**kwargs)
|
119
|
+
game = Game(**kwargs) # type: ignore
|
119
120
|
|
120
121
|
while game.control.running:
|
121
122
|
game = control_fn(game)
|
@@ -123,5 +124,3 @@ if __name__ == "__main__":
|
|
123
124
|
game = game if game.control.paused else render(game)
|
124
125
|
|
125
126
|
pygame.quit()
|
126
|
-
|
127
|
-
|
parabellum/vis.py
CHANGED
@@ -2,265 +2,168 @@
|
|
2
2
|
Visualizer for the Parabellum environment
|
3
3
|
"""
|
4
4
|
|
5
|
-
|
6
|
-
import jax.numpy as jnp
|
7
|
-
import jax
|
8
|
-
from jax import vmap
|
9
|
-
from jax import tree_util
|
5
|
+
# Standard library imports
|
10
6
|
from functools import partial
|
11
|
-
import
|
12
|
-
import
|
13
|
-
import
|
14
|
-
|
15
|
-
|
16
|
-
|
7
|
+
from typing import Optional, List, Tuple
|
8
|
+
import cv2
|
9
|
+
from PIL import Image
|
10
|
+
|
11
|
+
# JAX and JAX-related imports
|
12
|
+
import jax
|
13
|
+
from chex import dataclass
|
14
|
+
import chex
|
15
|
+
from jax import vmap, tree_util, Array, jit
|
16
|
+
import jax.numpy as jnp
|
17
17
|
from jaxmarl.environments.multi_agent_env import MultiAgentEnv
|
18
|
+
from jaxmarl.environments.smax import SMAX
|
18
19
|
from jaxmarl.viz.visualizer import SMAXVisualizer
|
19
20
|
|
20
|
-
#
|
21
|
-
|
22
|
-
|
21
|
+
# Third-party imports
|
22
|
+
import numpy as np
|
23
|
+
import pygame
|
24
|
+
import cv2
|
25
|
+
from tqdm import tqdm
|
23
26
|
|
24
|
-
#
|
25
|
-
|
27
|
+
# Local imports
|
28
|
+
import parabellum as pb
|
26
29
|
|
27
30
|
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
#
|
32
|
-
|
33
|
-
|
31
|
+
# skin dataclass
|
32
|
+
@dataclass
class Skin:
    """Rendering settings for the Parabellum visualizer.

    Colours are RGB tuples; sizes are in pixels.
    """

    maskmap: Array  # mask map of buildings
    # colours
    bg: Tuple[int, int, int] = (255, 255, 255)  # frame background
    fg: Tuple[int, int, int] = (0, 0, 0)  # map border and unit outlines
    ally: Tuple[int, int, int] = (0, 255, 0)
    enemy: Tuple[int, int, int] = (255, 0, 0)
    # geometry and timing
    pad: int = 100  # padding around the map on every side
    size: int = 1000  # map side length, excluding padding
    fps: int = 24  # frame rate of the written video
    vis_size: int = 1000  # size of the map in Vis (excluding padding)
    scale: Optional[float] = None  # pixels per map unit; filled in by Visualizer
|
34
45
|
|
35
46
|
|
36
47
|
class Visualizer(SMAXVisualizer):
    """Render Parabellum rollouts to MP4 files."""

    def __init__(self, env: pb.Environment, state_seq, skin: Skin, reward_seq=None):
        """Prepare a rollout for rendering.

        Args:
            env: environment the rollout came from; ``env.map_width`` sets the scale.
            state_seq: rollout as a sequence of ``(_, state, action)`` tuples.
            skin: rendering settings. It is not modified; a copy with ``scale``
                filled in is stored on the visualizer.
            reward_seq: forwarded to ``SMAXVisualizer``.
        """
        super().__init__(env, state_seq, reward_seq)

        # self.bullet_seq = vmap(partial(bullet_fn, self.env))(self.state_seq)
        self.action_seq = [action for _, _, action in state_seq]  # bcs SMAX bug
        self.state_seq = state_seq
        self.image = image_fn(skin)
        # Store a copy rather than mutating the caller's Skin. Otherwise a Skin
        # shared between visualizers of differently sized maps would carry the
        # scale of whichever visualizer was built last.
        self.skin = skin.replace(scale=skin.size / env.map_width)  # assumes square map
        self.env = env

    def animate(self, save_fname: Optional[str] = "output/parabellum", view=None):
        """Write ``{save_fname}_{i}.mp4`` for each sequence produced by ``unbatch_fn``.

        ``view`` is accepted but currently unused.
        """
        expanded_state_seq, expanded_action_seq = expand_fn(self.env, self.state_seq, self.action_seq)
        state_seq_seq, action_seq_seq = unbatch_fn(expanded_state_seq, expanded_action_seq)
        for idx, (state_seq, action_seq) in enumerate(zip(state_seq_seq, action_seq_seq)):
            animate_fn(self.env, self.skin, self.image, state_seq, action_seq, f"{save_fname}_{idx}.mp4")
|
172
65
|
|
173
66
|
|
174
67
|
# functions
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
def
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
return
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
68
|
+
def animate_fn(env, skin, image, state_seq, action_seq, save_fname):
    """Render one episode to an MP4 file.

    Each ``(state_tuple, action)`` pair is drawn with ``frame_fn`` and written to
    ``save_fname`` at ``skin.fps`` frames per second.

    Args:
        env: environment passed through to the frame renderers.
        skin: ``Skin`` providing the frame geometry (``size``, ``pad``) and ``fps``.
        image: background image passed to the frame renderers.
        state_seq: sequence of tuples whose second element is the state to draw.
        action_seq: actions aligned with ``state_seq``.
        save_fname: output video path.
    """
    pygame.init()
    side = skin.size + skin.pad * 2  # frames are square: map plus padding on both sides
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # type: ignore
    out = cv2.VideoWriter(save_fname, fourcc, skin.fps, (side, side))
    try:
        # Stream frames straight to the writer instead of buffering the whole
        # episode (and every frame's pixel view) in memory first.
        for idx, (state_tup, action) in enumerate(zip(state_seq, action_seq)):
            frame = frame_fn(env, skin, image, state_tup[1], action, idx)
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalize the file and shut pygame down, even if rendering fails.
        out.release()
        pygame.quit()
|
80
|
+
|
81
|
+
|
82
|
+
def init_frame(env, skin, image, state: pb.State, action: Array, idx: int) -> pygame.Surface:
    """Create an empty square canvas with side ``skin.size + 2 * skin.pad``.

    The surface has per-pixel alpha (``SRCALPHA``) and the ``HWSURFACE`` flag.
    Only ``skin`` is read; the remaining arguments are currently unused.
    """
    side = skin.size + skin.pad * 2
    return pygame.Surface((side, side), pygame.SRCALPHA | pygame.HWSURFACE)
|
86
|
+
|
87
|
+
|
88
|
+
def transform_frame(env, skin, frame):
    """Convert a rendered surface into an ``(H, W, 3)`` RGB array rotated 180 degrees.

    ``pygame.surfarray`` arrays are indexed ``[x, y]``. ``swapaxes(0, 1)`` turns
    that into row-major ``[y, x]`` image layout, and ``np.rot90(..., 2)`` rotates
    the image by 180 degrees.
    """
    # array3d copies the pixels, so the surface is not kept locked and alive the
    # way a pixels3d view would (animate_fn holds every frame in a list).
    pixels = pygame.surfarray.array3d(frame).swapaxes(0, 1)
    # rot90 returns a negative-stride view; hand cv2 a plain C-contiguous buffer.
    return np.ascontiguousarray(np.rot90(pixels, 2))
|
91
|
+
|
92
|
+
|
93
|
+
def frame_fn(env, skin, image, state: pb.State, action: Array, idx: int) -> np.ndarray:
    """Render one frame and return it as a numpy array.

    A blank canvas from ``init_frame`` is threaded, in order, through the
    background, agent, action and bullet renderers. Each stage receives the
    surface returned by the previous one. ``transform_frame`` converts the
    final surface into an array.
    """
    canvas = init_frame(env, skin, image, state, action, idx)
    stages = (render_background, render_agents, render_action, render_bullet)
    for stage in stages:
        canvas = stage(env, skin, image, canvas, state, action)
    return transform_frame(env, skin, canvas)
|
102
|
+
|
103
|
+
|
104
|
+
def render_background(env, skin, image, frame, state, action):
    """Paint the background colour, the terrain image and a border onto ``frame``.

    The terrain ``image`` is blitted at ``(skin.pad, skin.pad)``, the same origin
    ``render_agents`` uses for unit positions, so map and units line up
    (assumes ``image`` is ``skin.size`` x ``skin.size``, as built by ``image_fn``).
    A 3 px border is drawn 5 px outside the image on every side.
    """
    frame.fill(skin.bg)
    # A blit only uses the destination's top-left corner. Blitting at the border
    # rect (pad - 5) would shift the map 5 px up-left relative to the units.
    frame.blit(image, (skin.pad, skin.pad))
    margin = 5
    border = (
        skin.pad - margin,
        skin.pad - margin,
        skin.size + 2 * margin,
        skin.size + 2 * margin,
    )
    pygame.draw.rect(frame, skin.fg, border, 3)
    return frame
|
110
|
+
|
111
|
+
|
112
|
+
def render_action(env, skin, image, frame, state, action):
    """Pipeline stage reserved for drawing actions.

    Not implemented yet: ``frame`` is returned unchanged.
    """
    return frame
|
114
|
+
|
115
|
+
|
116
|
+
def render_bullet(env, skin, image, frame, state, action):
    """Pipeline stage reserved for drawing bullets.

    Not implemented yet: ``frame`` is returned unchanged.
    """
    return frame
|
118
|
+
|
119
|
+
def render_agents(env, skin, image, frame, state, action):
    """Draw every unit with positive health as a double-ring outline.

    Positions are scaled by ``skin.scale`` and offset by ``skin.pad``. The ring
    radius comes from ``env.unit_type_radiuses`` for the unit's type. An inner
    ``skin.fg`` ring is drawn, then an outer ``skin.bg`` ring one pixel larger.
    Team is currently not used for styling.
    """
    columns = zip(state.unit_positions, state.unit_teams, state.unit_types, state.unit_health)
    for position, _team, unit_type, hp in columns:
        if not hp > 0:
            continue  # dead units are not drawn
        center = tuple((position * skin.scale).astype(int) + skin.pad)
        ring = float(jnp.ceil((env.unit_type_radiuses[unit_type] * skin.scale)).astype(int) + 1)
        pygame.draw.circle(frame, skin.fg, center, ring, 1)
        pygame.draw.circle(frame, skin.bg, center, ring + 1, 1)
    return frame
|
130
|
+
|
131
|
+
|
132
|
+
def text_fn(text):
    """Return ``text`` rotated by 180 degrees (upside down).

    The original note says this works around a pygame issue. Presumably it
    pre-compensates for the 180-degree ``np.rot90(..., 2)`` in
    ``transform_frame`` so labels read upright; confirm before relying on it.
    """
    half_turn = 180
    return pygame.transform.rotate(text, half_turn)
|
135
|
+
|
136
|
+
|
137
|
+
def image_fn(skin: Skin):  # TODO:
    """Create the background surface (basemap or maskmap) from ``skin.maskmap``.

    The mask is transposed (``make_surface`` treats array axis 0 as x), resized
    to ``skin.size`` x ``skin.size`` with Lanczos interpolation, and thresholded.
    Pixels whose interpolated value is at least 0.5 are painted ``skin.fg``;
    the rest are painted ``skin.bg``.

    Returns:
        A ``pygame.Surface`` of size ``(skin.size, skin.size)``.
    """
    # Resize in float32: cv2.resize rejects bool arrays, and thresholding the
    # float result avoids truncating interpolated values (e.g. 0.9999) to 0 the
    # way an astype(uint8) before the comparison would.
    mask = np.asarray(skin.maskmap.T, dtype=np.float32)
    resized = cv2.resize(mask, (skin.size, skin.size), interpolation=cv2.INTER_LANCZOS4)
    motif = resized >= 0.5
    image = np.empty((skin.size, skin.size, 3), dtype=np.uint8)
    image[...] = skin.bg
    image[motif] = skin.fg
    # The array is already skin.size x skin.size, so no extra rescale is needed.
    return pygame.surfarray.make_surface(image)
|
146
|
+
|
147
|
+
|
148
|
+
def unbatch_fn(state_seq, action_seq):
    """Split (possibly batched) rollouts into one sequence per environment.

    ``state_seq`` is a list of ``(step_key, state, actions)`` tuples. When
    ``is_multi_run`` detects a leading environment axis, every leaf of both
    sequences is sliced along that axis. Otherwise the inputs are wrapped in
    single-element lists.

    Returns:
        ``(state_seq_seq, action_seq_seq)``: two lists with one entry per environment.
    """
    if not is_multi_run(state_seq):
        return [state_seq], [action_seq]

    num_envs = state_seq[0][1].unit_positions.shape[0]

    def select(tree, env_idx):
        # Take the env_idx-th slice of every leaf in ``tree``.
        return jax.tree_util.tree_map(lambda leaf: leaf[env_idx], tree)

    per_env_states = [select(state_seq, k) for k in range(num_envs)]
    per_env_actions = [select(action_seq, k) for k in range(num_envs)]
    return per_env_states, per_env_actions
|
158
|
+
|
159
|
+
|
160
|
+
def expand_fn(env, state_seq, action_seq):
    """Expand a rollout to world-step resolution via ``env.expand_state_seq``.

    For batched rollouts (see ``is_multi_run``) the expansion is ``vmap``-ed.
    Expanded step ``i`` is paired with
    ``action_seq[i // env.world_steps_per_env_step]``, so each env-step action
    is repeated for the world steps it spans.

    Returns:
        ``(expanded_state_seq, repeated_actions)``.
    """
    expand = env.expand_state_seq
    if is_multi_run(state_seq):
        expanded = vmap(expand)(state_seq)
    else:
        expanded = expand(state_seq)
    repeated_actions = [
        action_seq[step // env.world_steps_per_env_step] for step in range(len(expanded))
    ]
    return expanded, repeated_actions
|
166
|
+
|
167
|
+
|
168
|
+
def is_multi_run(state_seq):
    """Return True if the rollout carries a leading environment (batch) axis.

    Looks at the state stored at index 1 of the first entry. Its
    ``unit_positions`` has more than 2 dimensions only when batched.
    """
    first_state = state_seq[0][1]
    n_dims = first_state.unit_positions.ndim
    return n_dims > 2
|
@@ -0,0 +1,95 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: parabellum
|
3
|
+
Version: 0.2.21
|
4
|
+
Summary: Parabellum environment for parallel warfare simulation
|
5
|
+
Home-page: https://github.com/syrkis/parabellum
|
6
|
+
License: MIT
|
7
|
+
Keywords: warfare,simulation,parallel,environment
|
8
|
+
Author: Noah Syrkis
|
9
|
+
Author-email: desk@syrkis.com
|
10
|
+
Requires-Python: >=3.11,<4.0
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
15
|
+
Requires-Dist: contextily (>=1.6.0,<2.0.0)
|
16
|
+
Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
|
17
|
+
Requires-Dist: folium (>=0.17.0,<0.18.0)
|
18
|
+
Requires-Dist: geopandas (>=1.0.0,<2.0.0)
|
19
|
+
Requires-Dist: geopy (>=2.4.1,<3.0.0)
|
20
|
+
Requires-Dist: ipykernel (>=6.29.5,<7.0.0)
|
21
|
+
Requires-Dist: jax (==0.4.17)
|
22
|
+
Requires-Dist: jaxmarl (==0.0.3)
|
23
|
+
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
24
|
+
Requires-Dist: moviepy (>=1.0.3,<2.0.0)
|
25
|
+
Requires-Dist: numpy (<2)
|
26
|
+
Requires-Dist: opencv-python (>=4.10.0.84,<5.0.0.0)
|
27
|
+
Requires-Dist: osmnx (>=1.9.3,<2.0.0)
|
28
|
+
Requires-Dist: pandas (>=2.2.2,<3.0.0)
|
29
|
+
Requires-Dist: poetry (>=1.8.3,<2.0.0)
|
30
|
+
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
31
|
+
Requires-Dist: rasterio (>=1.3.10,<2.0.0)
|
32
|
+
Requires-Dist: seaborn (>=0.13.2,<0.14.0)
|
33
|
+
Requires-Dist: tqdm (>=4.66.4,<5.0.0)
|
34
|
+
Project-URL: Repository, https://github.com/syrkis/parabellum
|
35
|
+
Description-Content-Type: text/markdown
|
36
|
+
|
37
|
+
# Parabellum
|
38
|
+
|
39
|
+
An ultra-scalable, JaxMARL-based warfare simulation engine developed with Armasuisse funding.
|
40
|
+
|
41
|
+
[![Documentation Status](https://readthedocs.org/projects/parabellum/badge/?version=latest)](https://parabellum.readthedocs.io/en/latest/?badge=latest)
|
42
|
+
|
43
|
+
## Features
|
44
|
+
|
45
|
+
- Obstacles and terrain integration
|
46
|
+
- Rasterized maps
|
47
|
+
- Blast radii simulation
|
48
|
+
- Friendly fire mechanics
|
49
|
+
- Pygame visualization
|
50
|
+
- JAX-based parallelization
|
51
|
+
|
52
|
+
## Install
|
53
|
+
|
54
|
+
```bash
|
55
|
+
pip install parabellum
|
56
|
+
```
|
57
|
+
|
58
|
+
## Quick Start
|
59
|
+
|
60
|
+
```python
|
61
|
+
import parabellum as pb
|
62
|
+
from jax import random
|
63
|
+
|
64
|
+
terrain = pb.terrain_fn("Thun, Switzerland", 1000)
|
65
|
+
scenario = pb.make_scenario("Thun", terrain, 10, 10)
|
66
|
+
env = pb.Environment(scenario)
|
67
|
+
|
68
|
+
rng, key = random.split(random.PRNGKey(0))
|
69
|
+
obs, state = env.reset(key)
|
70
|
+
|
71
|
+
# Simulation loop
|
72
|
+
state_sequence = []
for _ in range(100):
|
73
|
+
rng, rng_act, key_step = random.split(rng, 3)
|
74
|
+
key_act = random.split(rng_act, len(env.agents))
|
75
|
+
act = {a: env.action_space(a).sample(k) for a, k in zip(env.agents, key_act)}
|
76
|
+
state_sequence.append((key_step, state, act))
obs, state, reward, done, info = env.step(key_step, act, state)
|
77
|
+
|
78
|
+
# Visualize
|
79
|
+
vis = pb.Visualizer(env, state_sequence)
|
80
|
+
vis.animate()
|
81
|
+
```
|
82
|
+
|
83
|
+
## Documentation
|
84
|
+
|
85
|
+
Full documentation: [parabellum.readthedocs.io](https://parabellum.readthedocs.io)
|
86
|
+
|
87
|
+
## Team
|
88
|
+
|
89
|
+
- Noah Syrkis
|
90
|
+
- Timothée Anne
|
91
|
+
- Supervisor: Sebastian Risi
|
92
|
+
|
93
|
+
## License
|
94
|
+
|
95
|
+
MIT
|
@@ -0,0 +1,10 @@
|
|
1
|
+
parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
|
2
|
+
parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
|
3
|
+
parabellum/env.py,sha256=L6GHlLxywpkV1bRnZcYBURREPP4CRfet_pEwCt5DB04,16724
|
4
|
+
parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
|
5
|
+
parabellum/map.py,sha256=EUcPe4Upu9MQzS8h15IVPGCaAyRPLSkmoLd5ZT-V4Pk,2599
|
6
|
+
parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
|
7
|
+
parabellum/vis.py,sha256=uXTnhJL23JLQHW9by-M4bF73dSVA5TIkpNdfo_Go2Ro,6045
|
8
|
+
parabellum-0.2.21.dist-info/METADATA,sha256=-K-3eYl1BvR3tFsiTxTyfHErQJdgPQZx08iq_kS2544,2671
|
9
|
+
parabellum-0.2.21.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
10
|
+
parabellum-0.2.21.dist-info/RECORD,,
|
@@ -1,104 +0,0 @@
|
|
1
|
-
Metadata-Version: 2.1
|
2
|
-
Name: parabellum
|
3
|
-
Version: 0.2.19
|
4
|
-
Summary: Parabellum environment for parallel warfare simulation
|
5
|
-
Home-page: https://github.com/syrkis/parabellum
|
6
|
-
License: MIT
|
7
|
-
Keywords: warfare,simulation,parallel,environment
|
8
|
-
Author: Noah Syrkis
|
9
|
-
Author-email: desk@syrkis.com
|
10
|
-
Requires-Python: >=3.11,<4.0
|
11
|
-
Classifier: License :: OSI Approved :: MIT License
|
12
|
-
Classifier: Programming Language :: Python :: 3
|
13
|
-
Classifier: Programming Language :: Python :: 3.11
|
14
|
-
Classifier: Programming Language :: Python :: 3.12
|
15
|
-
Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
|
16
|
-
Requires-Dist: jax (==0.4.17)
|
17
|
-
Requires-Dist: jaxmarl (==0.0.3)
|
18
|
-
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
19
|
-
Requires-Dist: moviepy (>=1.0.3,<2.0.0)
|
20
|
-
Requires-Dist: poetry (>=1.8.3,<2.0.0)
|
21
|
-
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
22
|
-
Requires-Dist: tqdm (>=4.66.4,<5.0.0)
|
23
|
-
Project-URL: Repository, https://github.com/syrkis/parabellum
|
24
|
-
Description-Content-Type: text/markdown
|
25
|
-
|
26
|
-
# parabellum
|
27
|
-
|
28
|
-
Parabellum is an ultra-scalable, high-performance warfare simulation engine.
|
29
|
-
It is based on JaxMARL's SMAX environment, but has been heavily modified to
|
30
|
-
support a wide range of new features and improvements.
|
31
|
-
|
32
|
-
## Installation
|
33
|
-
|
34
|
-
Parabellum is written in Python 3.11 and can be installed using pip:
|
35
|
-
|
36
|
-
```bash
|
37
|
-
pip install parabellum
|
38
|
-
```
|
39
|
-
|
40
|
-
## Usage
|
41
|
-
|
42
|
-
Parabellum is designed to be used in conjunction with JAX, a high-performance
|
43
|
-
numerical computing library. Here is a simple example of how to use Parabellum
|
44
|
-
to simulate a game with 10 agents and 10 enemies, each taking random actions:
|
45
|
-
|
46
|
-
```python
|
47
|
-
import parabellum as pb
|
48
|
-
from jax import random
|
49
|
-
|
50
|
-
# define the scenario
|
51
|
-
kwargs = dict(obstacle_coords=[(7, 7)], obstacle_deltas=[(10, 0)])
|
52
|
-
scenario = pb.Scenario(**kwargs) # <- Scenario is an important part of parabellum
|
53
|
-
|
54
|
-
# create the environment
|
55
|
-
kwargs = dict(map_width=256, map_height=256, num_agents=10, num_enemies=10)
|
56
|
-
env = pb.Parabellum(**kwargs) # <- Parabellum is the central class of parabellum
|
57
|
-
|
58
|
-
# initiate stochasticity
|
59
|
-
rng = random.PRNGKey(0)
|
60
|
-
rng, key = random.split(rng)
|
61
|
-
|
62
|
-
# initialize the environment state
|
63
|
-
obs, state = env.reset(key)
|
64
|
-
state_sequence = []
|
65
|
-
|
66
|
-
for _ in range(1000):
|
67
|
-
|
68
|
-
# manage stochasticity
|
69
|
-
rng, rng_act, key_step = random.split(rng, 3)
|
70
|
-
key_act = random.split(rng_act, len(env.agents))
|
71
|
-
|
72
|
-
# sample actions and append to state sequence
|
73
|
-
act = {a: env.action_space(a).sample(k)
|
74
|
-
for a, k in zip(env.agents, key_act)}
|
75
|
-
|
76
|
-
# step the environment
|
77
|
-
state_sequence.append((key_act, state, act))
|
78
|
-
obs, state, reward, done, info = env.step(key_step, act, state)
|
79
|
-
|
80
|
-
|
81
|
-
# save visualization of the state sequence
|
82
|
-
vis = pb.Visualizer(env, state_sequence) # <- Visualizer is a nice to have class
|
83
|
-
vis.animate()
|
84
|
-
```
|
85
|
-
|
86
|
-
|
87
|
-
## Features
|
88
|
-
|
89
|
-
- Obstacles — can be inserted into a scenario via `obstacle_coords` and `obstacle_deltas`
|
90
|
-
|
91
|
-
## TODO
|
92
|
-
|
93
|
-
- [x] Parallel pygame vis
|
94
|
-
- [ ] Parallel bullet renderings
|
95
|
-
- [ ] Combine parallel plots into one (maybe out of parabellum scope)
|
96
|
-
- [ ] Color for health?
|
97
|
-
- [ ] Add the ability to see ongoing game.
|
98
|
-
- [ ] Bug test friendly fire.
|
99
|
-
- [x] Start sim from arbitrary state.
|
100
|
-
- [ ] Save when the episode ends in some state/obs variable
|
101
|
-
- [ ] Look for the source of the bug when using more Allies than Enemies
|
102
|
-
- [ ] Inverted Y axis in parabellum visualization
|
103
|
-
- [ ] Units see through obstacles?
|
104
|
-
|
@@ -1,8 +0,0 @@
|
|
1
|
-
parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
|
2
|
-
parabellum/env.py,sha256=Z8zpdCaEi5HFwN0Vd2hukOarkPSg0EKZErTRts3JQ5E,16023
|
3
|
-
parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
|
4
|
-
parabellum/run.py,sha256=0IWpqcQ_qfFeElbupF5vOs_CByFfpXYuGGUHYuurFM4,3412
|
5
|
-
parabellum/vis.py,sha256=JFVTnBg-LV4jZNw6cysU6NS8ZxeMpg5wz3JOi-lrnzY,10699
|
6
|
-
parabellum-0.2.19.dist-info/METADATA,sha256=DsEBAlESj8BwGSphmPPylStoXH_g_x_Iy3WJ3KEwjc0,3223
|
7
|
-
parabellum-0.2.19.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
8
|
-
parabellum-0.2.19.dist-info/RECORD,,
|
File without changes
|