parabellum 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum/__init__.py +2 -2
- parabellum/env.py +61 -42
- parabellum/map.py +16 -0
- parabellum/run.py +125 -0
- parabellum/vis.py +78 -54
- parabellum-0.2.14.dist-info/METADATA +104 -0
- parabellum-0.2.14.dist-info/RECORD +8 -0
- parabellum/.ipynb_checkpoints/__init__-checkpoint.py +0 -4
- parabellum/.ipynb_checkpoints/env-checkpoint.py +0 -296
- parabellum/.ipynb_checkpoints/vis-checkpoint.py +0 -230
- parabellum-0.2.13.dist-info/METADATA +0 -56
- parabellum-0.2.13.dist-info/RECORD +0 -9
- {parabellum-0.2.13.dist-info → parabellum-0.2.14.dist-info}/WHEEL +0 -0
parabellum/__init__.py
CHANGED
parabellum/env.py
CHANGED
@@ -7,6 +7,7 @@ from jax import random
|
|
7
7
|
from jax import jit
|
8
8
|
from flax.struct import dataclass
|
9
9
|
import chex
|
10
|
+
from jax import vmap
|
10
11
|
from jaxmarl.environments.smax.smax_env import State, SMAX
|
11
12
|
from typing import Tuple, Dict
|
12
13
|
from functools import partial
|
@@ -16,7 +17,7 @@ from functools import partial
|
|
16
17
|
class Scenario:
|
17
18
|
"""Parabellum scenario"""
|
18
19
|
|
19
|
-
obstacle_coords: chex.Array
|
20
|
+
obstacle_coords: chex.Array # TODO: use map instead of obstacles
|
20
21
|
obstacle_deltas: chex.Array
|
21
22
|
|
22
23
|
unit_types: chex.Array
|
@@ -30,6 +31,8 @@ class Scenario:
|
|
30
31
|
# default scenario
|
31
32
|
scenarios = {
|
32
33
|
"default": Scenario(
|
34
|
+
jnp.array([[6, 10], [26, 10]]) * 8,
|
35
|
+
jnp.array([[0, 12], [0, 1]]) * 8,
|
33
36
|
jnp.array([[6, 10], [26, 10]]) * 8,
|
34
37
|
jnp.array([[0, 12], [0, 1]]) * 8,
|
35
38
|
jnp.zeros((19,), dtype=jnp.uint8),
|
@@ -40,24 +43,17 @@ scenarios = {
|
|
40
43
|
|
41
44
|
|
42
45
|
class Parabellum(SMAX):
|
43
|
-
def __init__(
|
44
|
-
self
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
):
|
49
|
-
super().__init__(scenario=scenario, **kwargs)
|
50
|
-
self.unit_type_attack_blasts = unit_type_attack_blasts
|
51
|
-
self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
|
52
|
-
self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)
|
46
|
+
def __init__(self, scenario: Scenario, **kwargs):
|
47
|
+
super(Parabellum, self).__init__(**kwargs)
|
48
|
+
self.obstacle_coords = scenario.obstacle_coords
|
49
|
+
self.obstacle_deltas = scenario.obstacle_deltas
|
50
|
+
self.unit_type_attack_blasts = jnp.zeros((19,), dtype=jnp.float32)
|
53
51
|
self.max_steps = 200
|
54
|
-
# overwrite
|
55
|
-
|
56
|
-
|
57
|
-
def _push_units_away(self, state: State, firmness: float = 1.0): # we do it inside the _world_step to allow more obstacles constraints
|
58
|
-
return state
|
52
|
+
self._push_units_away = lambda x: x # overwrite push units
|
59
53
|
|
60
|
-
def _our_push_units_away(
|
54
|
+
def _our_push_units_away(
|
55
|
+
self, pos, unit_types, firmness: float = 1.0
|
56
|
+
): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
|
61
57
|
delta_matrix = pos[:, None] - pos[None, :]
|
62
58
|
dist_matrix = (
|
63
59
|
jnp.linalg.norm(delta_matrix, axis=-1)
|
@@ -74,7 +70,7 @@ class Parabellum(SMAX):
|
|
74
70
|
+ firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
|
75
71
|
)
|
76
72
|
return unit_positions
|
77
|
-
|
73
|
+
|
78
74
|
@partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
|
79
75
|
def _world_step( # modified version of JaxMARL's SMAX _world_step
|
80
76
|
self,
|
@@ -82,7 +78,6 @@ class Parabellum(SMAX):
|
|
82
78
|
state: State,
|
83
79
|
actions: Tuple[chex.Array, chex.Array],
|
84
80
|
) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
|
85
|
-
|
86
81
|
@partial(jax.vmap, in_axes=(None, None, 0, 0))
|
87
82
|
def inter_fn(pos, new_pos, obs, obs_end):
|
88
83
|
d1 = jnp.cross(obs - pos, new_pos - pos)
|
@@ -224,19 +219,39 @@ class Parabellum(SMAX):
|
|
224
219
|
perform_agent_action
|
225
220
|
)(jnp.arange(self.num_agents), actions, keys)
|
226
221
|
|
227
|
-
#
|
222
|
+
# units push each other
|
228
223
|
new_pos = self._our_push_units_away(pos, state.unit_types)
|
229
224
|
|
230
|
-
# avoid going into obstacles after being pushed
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
225
|
+
# avoid going into obstacles after being pushed
|
226
|
+
|
227
|
+
bondaries_coords = jnp.array(
|
228
|
+
[[0, 0], [0, 0], [self.map_width, 0], [0, self.map_height]]
|
229
|
+
)
|
230
|
+
bondaries_deltas = jnp.array(
|
231
|
+
[
|
232
|
+
[self.map_width, 0],
|
233
|
+
[0, self.map_height],
|
234
|
+
[0, self.map_height],
|
235
|
+
[self.map_width, 0],
|
236
|
+
]
|
237
|
+
)
|
238
|
+
obstacle_coords = jnp.concatenate(
|
239
|
+
[self.obstacle_coords, bondaries_coords]
|
240
|
+
) # add the map boundaries to the obstacles to avoid
|
241
|
+
obstacle_deltas = jnp.concatenate(
|
242
|
+
[self.obstacle_deltas, bondaries_deltas]
|
243
|
+
) # add the map boundaries to the obstacles to avoid
|
244
|
+
obst_start = obstacle_coords
|
245
|
+
obst_end = obst_start + obstacle_deltas
|
246
|
+
|
247
|
+
def check_obstacles(pos, new_pos, obst_start, obst_end):
|
248
|
+
inters = jnp.any(inter_fn(pos, new_pos, obst_start, obst_end))
|
236
249
|
return jnp.where(inters, pos, new_pos)
|
237
|
-
|
238
|
-
pos = jax.vmap(check_obstacles, in_axes=(0,0,None,None))(
|
239
|
-
|
250
|
+
|
251
|
+
pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
|
252
|
+
pos, new_pos, obst_start, obst_end
|
253
|
+
)
|
254
|
+
|
240
255
|
# Multiple enemies can attack the same unit.
|
241
256
|
# We have `(health_diff, attacked_idx)` pairs.
|
242
257
|
# `jax.lax.scatter_add` aggregates these exactly
|
@@ -278,19 +293,23 @@ class Parabellum(SMAX):
|
|
278
293
|
)
|
279
294
|
return state
|
280
295
|
|
296
|
+
|
281
297
|
if __name__ == "__main__":
|
282
|
-
|
283
|
-
|
284
|
-
|
298
|
+
n_envs = 4
|
299
|
+
kwargs = dict(map_width=64, map_height=64)
|
300
|
+
env = Parabellum(scenarios["default"], **kwargs)
|
301
|
+
rng, reset_rng = random.split(random.PRNGKey(0))
|
302
|
+
reset_key = random.split(reset_rng, n_envs)
|
303
|
+
obs, state = vmap(env.reset)(reset_key)
|
285
304
|
state_seq = []
|
286
|
-
for step in range(100):
|
287
|
-
rng, key = random.split(rng)
|
288
|
-
key_act = random.split(key, len(env.agents))
|
289
|
-
actions = {
|
290
|
-
agent: jax.random.randint(key_act[i], (), 0, 5)
|
291
|
-
for i, agent in enumerate(env.agents)
|
292
|
-
}
|
293
|
-
_, state, _, _, _ = env.step(key, state, actions)
|
294
|
-
state_seq.append((obs, state, actions))
|
295
|
-
|
296
305
|
|
306
|
+
for i in range(10):
|
307
|
+
rng, act_rng, step_rng = random.split(rng, 3)
|
308
|
+
act_key = random.split(act_rng, (len(env.agents), n_envs))
|
309
|
+
act = {
|
310
|
+
a: vmap(env.action_space(a).sample)(act_key[i])
|
311
|
+
for i, a in enumerate(env.agents)
|
312
|
+
}
|
313
|
+
step_key = random.split(step_rng, n_envs)
|
314
|
+
state_seq.append((step_key, state, act))
|
315
|
+
obs, state, reward, done, infos = vmap(env.step)(step_key, state, act)
|
parabellum/map.py
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
# map.py
|
2
|
+
# parabellum map functions
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# imports
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import jax
|
8
|
+
|
9
|
+
|
10
|
+
# functions
|
11
|
+
def map_fn(width, height, obst_coord, obst_delta):
|
12
|
+
"""Create a map from the given width, height, and obstacle coordinates and deltas."""
|
13
|
+
m = jnp.zeros((width, height))
|
14
|
+
for (x, y), (dx, dy) in zip(obst_coord, obst_delta):
|
15
|
+
m = m.at[x : x + dx, y : y + dy].set(1)
|
16
|
+
return m
|
parabellum/run.py
ADDED
@@ -0,0 +1,125 @@
|
|
1
|
+
# run.py
|
2
|
+
# parabellum run game live
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# Noah Syrkis
|
6
|
+
import pygame
|
7
|
+
from jax import random
|
8
|
+
from functools import partial
|
9
|
+
import darkdetect
|
10
|
+
import jax.numpy as jnp
|
11
|
+
from chex import dataclass
|
12
|
+
import jaxmarl
|
13
|
+
from typing import Tuple, List, Dict, Optional
|
14
|
+
import parabellum as pb
|
15
|
+
|
16
|
+
|
17
|
+
# constants
|
18
|
+
fg = (255, 255, 255) if darkdetect.isDark() else (0, 0, 0)
|
19
|
+
bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
|
20
|
+
|
21
|
+
|
22
|
+
# types
|
23
|
+
State = jaxmarl.environments.smax.smax_env.State
|
24
|
+
Obs = Reward = Done = Action = Dict[str, jnp.ndarray]
|
25
|
+
StateSeq = List[Tuple[jnp.ndarray, State, Action]]
|
26
|
+
|
27
|
+
|
28
|
+
@dataclass
|
29
|
+
class Control:
|
30
|
+
running: bool = True
|
31
|
+
paused: bool = False
|
32
|
+
click: Optional[Tuple[int, int]] = None
|
33
|
+
|
34
|
+
|
35
|
+
@dataclass
|
36
|
+
class Game:
|
37
|
+
clock: pygame.time.Clock
|
38
|
+
state: State
|
39
|
+
obs: Dict
|
40
|
+
state_seq: StateSeq
|
41
|
+
control: Control
|
42
|
+
env: pb.Parabellum
|
43
|
+
rng: random.PRNGKey
|
44
|
+
|
45
|
+
|
46
|
+
def handle_event(event, control_state):
|
47
|
+
"""Handle pygame events."""
|
48
|
+
if event.type == pygame.QUIT:
|
49
|
+
control_state.running = False
|
50
|
+
if event.type == pygame.MOUSEBUTTONDOWN:
|
51
|
+
pos = pygame.mouse.get_pos()
|
52
|
+
control_state.click = pos
|
53
|
+
if event.type == pygame.MOUSEBUTTONUP:
|
54
|
+
control_state.click = None
|
55
|
+
if event.type == pygame.KEYDOWN: # any key press pauses
|
56
|
+
control_state.paused = not control_state.paused
|
57
|
+
return control_state
|
58
|
+
|
59
|
+
|
60
|
+
def control_fn(game):
|
61
|
+
"""Handle pygame events."""
|
62
|
+
for event in pygame.event.get():
|
63
|
+
game.control = handle_event(event, game.control)
|
64
|
+
return game
|
65
|
+
|
66
|
+
|
67
|
+
def render_fn(screen, game):
|
68
|
+
"""Render the game."""
|
69
|
+
if len(game.state_seq) < 3:
|
70
|
+
return game
|
71
|
+
for rng, state, action in env.expand_state_seq(game.state_seq[-2:])[-8:]:
|
72
|
+
screen.fill(bg)
|
73
|
+
if game.control.click is not None:
|
74
|
+
pygame.draw.circle(screen, "red", game.control.click, 10)
|
75
|
+
unit_positions = state.unit_positions
|
76
|
+
for pos in unit_positions:
|
77
|
+
pos = (pos / env.map_width * 800).tolist()
|
78
|
+
pygame.draw.circle(screen, fg, pos, 5)
|
79
|
+
pygame.display.flip()
|
80
|
+
game.clock.tick(24) # limits FPS to 24
|
81
|
+
return game
|
82
|
+
|
83
|
+
|
84
|
+
def step_fn(game):
|
85
|
+
"""Step in parabellum."""
|
86
|
+
rng, act_rng, step_key = random.split(game.rng, 3)
|
87
|
+
act_key = random.split(act_rng, env.num_agents)
|
88
|
+
action = {
|
89
|
+
a: env.action_space(a).sample(act_key[i]) for i, a in enumerate(env.agents)
|
90
|
+
}
|
91
|
+
state_seq_entry = (step_key, game.state, action)
|
92
|
+
# append state_seq_entry to state_seq
|
93
|
+
game.state_seq.append(state_seq_entry)
|
94
|
+
obs, state, reward, done, info = env.step(step_key, game.state, action)
|
95
|
+
game.state = state
|
96
|
+
game.obs = obs
|
97
|
+
game.rng = rng
|
98
|
+
return game
|
99
|
+
|
100
|
+
|
101
|
+
# state
|
102
|
+
if __name__ == "__main__":
|
103
|
+
env = pb.Parabellum(pb.scenarios["default"])
|
104
|
+
pygame.init()
|
105
|
+
screen = pygame.display.set_mode((1000, 1000))
|
106
|
+
render = partial(render_fn, screen)
|
107
|
+
rng, key = random.split(random.PRNGKey(0))
|
108
|
+
obs, state = env.reset(key)
|
109
|
+
kwargs = dict(
|
110
|
+
control=Control(),
|
111
|
+
env=env,
|
112
|
+
rng=rng,
|
113
|
+
state_seq=[], # [(key, state, action)]
|
114
|
+
clock=pygame.time.Clock(),
|
115
|
+
state=state,
|
116
|
+
obs=obs,
|
117
|
+
)
|
118
|
+
game = Game(**kwargs)
|
119
|
+
|
120
|
+
while game.control.running:
|
121
|
+
game = control_fn(game)
|
122
|
+
game = game if game.control.paused else step_fn(game)
|
123
|
+
game = game if game.control.paused else render(game)
|
124
|
+
|
125
|
+
pygame.quit()
|
parabellum/vis.py
CHANGED
@@ -4,9 +4,11 @@ from tqdm import tqdm
|
|
4
4
|
import jax.numpy as jnp
|
5
5
|
import jax
|
6
6
|
from jax import vmap
|
7
|
+
from jax import tree_util
|
7
8
|
from functools import partial
|
8
9
|
import darkdetect
|
9
10
|
import pygame
|
11
|
+
import os
|
10
12
|
from moviepy.editor import ImageSequenceClip
|
11
13
|
from typing import Optional
|
12
14
|
from jaxmarl.environments.multi_agent_env import MultiAgentEnv
|
@@ -20,6 +22,14 @@ from collections import defaultdict
|
|
20
22
|
action_to_symbol = {0: "↑", 1: "→", 2: "↓", 3: "←", 4: "Ø"}
|
21
23
|
|
22
24
|
|
25
|
+
def small_multiples():
|
26
|
+
# make video of small multiples based on all videos in output
|
27
|
+
video_files = [f"output/parabellum_{i}.mp4" for i in range(4)]
|
28
|
+
# load mp4 videos and make a grid
|
29
|
+
clips = [ImageSequenceClip.load(filename) for filename in video_files]
|
30
|
+
print(len(clips))
|
31
|
+
|
32
|
+
|
23
33
|
class Visualizer(SMAXVisualizer):
|
24
34
|
def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
|
25
35
|
super().__init__(env, state_seq, reward_seq)
|
@@ -30,7 +40,54 @@ class Visualizer(SMAXVisualizer):
|
|
30
40
|
self.s = 1000
|
31
41
|
self.scale = self.s / self.env.map_width
|
32
42
|
self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
|
33
|
-
self.bullet_seq = bullet_fn
|
43
|
+
# self.bullet_seq = vmap(partial(bullet_fn, self.env))(self.state_seq)
|
44
|
+
|
45
|
+
def animate(self, save_fname: str = "output/parabellum.mp4"):
|
46
|
+
multi_dim = self.state_seq[0][1].unit_positions.ndim > 1
|
47
|
+
if multi_dim:
|
48
|
+
n_envs = self.state_seq[0][1].unit_positions.shape[0]
|
49
|
+
if not self.have_expanded:
|
50
|
+
state_seqs = vmap(env.expand_state_seq)(self.state_seq)
|
51
|
+
self.have_expanded = True
|
52
|
+
for i in range(n_envs):
|
53
|
+
state_seq = jax.tree_map(lambda x: x[i], state_seqs)
|
54
|
+
action_seq = jax.tree_map(lambda x: x[i], self.action_seq)
|
55
|
+
self.animate_one(
|
56
|
+
state_seq, action_seq, save_fname.replace(".mp4", f"_{i}.mp4")
|
57
|
+
)
|
58
|
+
else:
|
59
|
+
state_seq = env.expand_state_seq(self.state_seq)
|
60
|
+
self.animate_one(state_seq, self.action_seq, save_fname)
|
61
|
+
|
62
|
+
def animate_one(self, state_seq, action_seq, save_fname):
|
63
|
+
frames = [] # frames for the video
|
64
|
+
pygame.init() # initialize pygame
|
65
|
+
for idx, (_, state, _) in tqdm(enumerate(state_seq), total=len(self.state_seq)):
|
66
|
+
action = action_seq[idx // self.env.world_steps_per_env_step]
|
67
|
+
screen = pygame.Surface(
|
68
|
+
(self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
|
69
|
+
)
|
70
|
+
screen.fill(self.bg) # fill the screen with the background color
|
71
|
+
|
72
|
+
self.render_agents(screen, state) # render the agents
|
73
|
+
self.render_action(screen, action)
|
74
|
+
self.render_obstacles(screen) # render the obstacles
|
75
|
+
|
76
|
+
# bullets
|
77
|
+
""" if idx < len(self.bullet_seq) * 8:
|
78
|
+
bullets = self.bullet_seq[idx // 8]
|
79
|
+
self.render_bullets(screen, bullets, idx % 8) """
|
80
|
+
|
81
|
+
# rotate the screen and append to frames
|
82
|
+
frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))
|
83
|
+
|
84
|
+
# save the images
|
85
|
+
clip = ImageSequenceClip(frames, fps=48)
|
86
|
+
clip.write_videofile(save_fname, fps=48)
|
87
|
+
# clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
|
88
|
+
pygame.quit()
|
89
|
+
|
90
|
+
return clip
|
34
91
|
|
35
92
|
def render_agents(self, screen, state):
|
36
93
|
time_tuple = zip(
|
@@ -100,39 +157,6 @@ class Visualizer(SMAXVisualizer):
|
|
100
157
|
position *= self.scale
|
101
158
|
pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
|
102
159
|
|
103
|
-
def animate(self, save_fname: str = "parabellum.mp4"):
|
104
|
-
if not self.have_expanded:
|
105
|
-
self.expand_state_seq()
|
106
|
-
frames = [] # frames for the video
|
107
|
-
pygame.init() # initialize pygame
|
108
|
-
for idx, (_, state, _) in tqdm(
|
109
|
-
enumerate(self.state_seq), total=len(self.state_seq)
|
110
|
-
):
|
111
|
-
screen = pygame.Surface(
|
112
|
-
(self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
|
113
|
-
)
|
114
|
-
screen.fill(self.bg) # fill the screen with the background color
|
115
|
-
|
116
|
-
self.render_agents(screen, state) # render the agents
|
117
|
-
self.render_action(screen, self.action_seq[idx // 8])
|
118
|
-
self.render_obstacles(screen) # render the obstacles
|
119
|
-
|
120
|
-
# bullets
|
121
|
-
if idx < len(self.bullet_seq) * 8:
|
122
|
-
bullets = self.bullet_seq[idx // 8]
|
123
|
-
self.render_bullets(screen, bullets, idx % 8)
|
124
|
-
|
125
|
-
# rotate the screen and append to frames
|
126
|
-
frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))
|
127
|
-
|
128
|
-
# save the images
|
129
|
-
clip = ImageSequenceClip(frames, fps=48)
|
130
|
-
clip.write_videofile(save_fname, fps=48)
|
131
|
-
# clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
|
132
|
-
pygame.quit()
|
133
|
-
|
134
|
-
return clip
|
135
|
-
|
136
160
|
|
137
161
|
# functions
|
138
162
|
# bullet functions
|
@@ -201,30 +225,30 @@ def bullet_fn(env, states):
|
|
201
225
|
|
202
226
|
# test the visualizer
|
203
227
|
if __name__ == "__main__":
|
204
|
-
from parabellum import Parabellum, Scenario
|
205
228
|
from jax import random, numpy as jnp
|
229
|
+
from parabellum import Parabellum, scenarios
|
230
|
+
|
231
|
+
# small_multiples() # testing small multiples (not working yet)
|
232
|
+
# exit()
|
206
233
|
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
rng, key = random.split(random.PRNGKey(0))
|
214
|
-
obs, state = env.reset(key)
|
234
|
+
n_envs = 100
|
235
|
+
kwargs = dict(map_width=64, map_height=64)
|
236
|
+
env = Parabellum(scenarios["default"], **kwargs)
|
237
|
+
rng, reset_rng = random.split(random.PRNGKey(0))
|
238
|
+
reset_key = random.split(reset_rng, n_envs)
|
239
|
+
obs, state = vmap(env.reset)(reset_key)
|
215
240
|
state_seq = []
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
241
|
+
|
242
|
+
for i in range(100):
|
243
|
+
rng, act_rng, step_rng = random.split(rng, 3)
|
244
|
+
act_key = random.split(act_rng, (len(env.agents), n_envs))
|
245
|
+
act = {
|
246
|
+
a: vmap(env.action_space(a).sample)(act_key[i])
|
247
|
+
for i, a in enumerate(env.agents)
|
222
248
|
}
|
223
|
-
|
224
|
-
|
225
|
-
obs, state, reward, done, infos = env.step(
|
249
|
+
step_key = random.split(step_rng, n_envs)
|
250
|
+
state_seq.append((step_key, state, act))
|
251
|
+
obs, state, reward, done, infos = vmap(env.step)(step_key, state, act)
|
226
252
|
|
227
253
|
vis = Visualizer(env, state_seq)
|
228
254
|
vis.animate()
|
229
|
-
|
230
|
-
|
@@ -0,0 +1,104 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: parabellum
|
3
|
+
Version: 0.2.14
|
4
|
+
Summary: Parabellum environment for parallel warfare simulation
|
5
|
+
Home-page: https://github.com/syrkis/parabellum
|
6
|
+
License: MIT
|
7
|
+
Keywords: warfare,simulation,parallel,environment
|
8
|
+
Author: Noah Syrkis
|
9
|
+
Author-email: desk@syrkis.com
|
10
|
+
Requires-Python: >=3.11,<4.0
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
15
|
+
Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
|
16
|
+
Requires-Dist: jax (==0.4.17)
|
17
|
+
Requires-Dist: jaxmarl (==0.0.3)
|
18
|
+
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
19
|
+
Requires-Dist: moviepy (>=1.0.3,<2.0.0)
|
20
|
+
Requires-Dist: poetry (>=1.8.3,<2.0.0)
|
21
|
+
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
22
|
+
Requires-Dist: tqdm (>=4.66.4,<5.0.0)
|
23
|
+
Project-URL: Repository, https://github.com/syrkis/parabellum
|
24
|
+
Description-Content-Type: text/markdown
|
25
|
+
|
26
|
+
# parabellum
|
27
|
+
|
28
|
+
Parabellum is an ultra-scalable, high-performance warfare simulation engine.
|
29
|
+
It is based on JaxMARL's SMAX environment, but has been heavily modified to
|
30
|
+
support a wide range of new features and improvements.
|
31
|
+
|
32
|
+
## Installation
|
33
|
+
|
34
|
+
Parabellum is written in Python and requires Python 3.11 or newer. It can be installed using pip:
|
35
|
+
|
36
|
+
```bash
|
37
|
+
pip install parabellum
|
38
|
+
```
|
39
|
+
|
40
|
+
## Usage
|
41
|
+
|
42
|
+
Parabellum is designed to be used in conjunction with JAX, a high-performance
|
43
|
+
numerical computing library. Here is a simple example of how to use Parabellum
|
44
|
+
to simulate a game with 10 agents and 10 enemies, each taking random actions:
|
45
|
+
|
46
|
+
```python
|
47
|
+
import parabellum as pb
|
48
|
+
from jax import random
|
49
|
+
|
50
|
+
# define the scenario
|
51
|
+
kwargs = dict(obstacle_coords=[(7, 7)], obstacle_deltas=[(10, 0)])
|
52
|
+
scenario = pb.Scenario(**kwargs) # <- Scenario is an important part of parabellum
|
53
|
+
|
54
|
+
# create the environment
|
55
|
+
kwargs = dict(map_width=256, map_height=256, num_agents=10, num_enemies=10)
|
56
|
+
env = pb.Parabellum(scenario, **kwargs) # <- Parabellum is the central class of parabellum
|
57
|
+
|
58
|
+
# initiate stochasticity
|
59
|
+
rng = random.PRNGKey(0)
|
60
|
+
rng, key = random.split(rng)
|
61
|
+
|
62
|
+
# initialize the environment state
|
63
|
+
obs, state = env.reset(key)
|
64
|
+
state_sequence = []
|
65
|
+
|
66
|
+
for _ in range(1000):
|
67
|
+
|
68
|
+
# manage stochasticity
|
69
|
+
rng, rng_act, key_step = random.split(rng, 3)
|
70
|
+
key_act = random.split(rng_act, len(env.agents))
|
71
|
+
|
72
|
+
# sample actions and append to state sequence
|
73
|
+
act = {a: env.action_space(a).sample(k)
|
74
|
+
for a, k in zip(env.agents, key_act)}
|
75
|
+
|
76
|
+
# step the environment
|
77
|
+
state_sequence.append((key_step, state, act))
|
78
|
+
obs, state, reward, done, info = env.step(key_step, state, act)
|
79
|
+
|
80
|
+
|
81
|
+
# save visualization of the state sequence
|
82
|
+
vis = pb.Visualizer(env, state_sequence) # <- Visualizer is a nice to have class
|
83
|
+
vis.animate()
|
84
|
+
```
|
85
|
+
|
86
|
+
|
87
|
+
## Features
|
88
|
+
|
89
|
+
- Obstacles — line-segment obstacles can be inserted into the map through a `Scenario`'s `obstacle_coords` and `obstacle_deltas`.
|
90
|
+
|
91
|
+
## TODO
|
92
|
+
|
93
|
+
- [x] Parallel pygame vis
|
94
|
+
- [ ] Parallel bullet renderings
|
95
|
+
- [ ] Combine parallel plots into one (maybe out of parabellum scope)
|
96
|
+
- [ ] Color for health?
|
97
|
+
- [ ] Add the ability to watch an ongoing game.
|
98
|
+
- [ ] Bug test friendly fire.
|
99
|
+
- [x] Start sim from arbitrary state.
|
100
|
+
- [ ] Save when the episode ends in some state/obs variable
|
101
|
+
- [ ] Look for the source of the bug when using more Allies than Enemies
|
102
|
+
- [ ] Inverted y-axis in the parabellum visualization
|
103
|
+
- [ ] Units see through obstacles?
|
104
|
+
|
@@ -0,0 +1,8 @@
|
|
1
|
+
parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
|
2
|
+
parabellum/env.py,sha256=rCn6iPLeFpqitncD9nEc0KA6N9JCMmiSyP9u2meOJxk,12325
|
3
|
+
parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
|
4
|
+
parabellum/run.py,sha256=lVNBsMc8HY4Tqdjs_1MXGBvIzuN05brbRiqp0xlRc6c,3410
|
5
|
+
parabellum/vis.py,sha256=u7ifxWzHf96WgLTz_hw0ijy6-7wePd7lf0p-yD-NCQY,10212
|
6
|
+
parabellum-0.2.14.dist-info/METADATA,sha256=wEiXzwPfnigG5ZSANPFwGEjLCDU5D0c7qbpvEi6Gbm8,3223
|
7
|
+
parabellum-0.2.14.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
8
|
+
parabellum-0.2.14.dist-info/RECORD,,
|
@@ -1,296 +0,0 @@
|
|
1
|
-
"""Parabellum environment based on SMAX"""
|
2
|
-
|
3
|
-
import jax.numpy as jnp
|
4
|
-
import jax
|
5
|
-
import numpy as np
|
6
|
-
from jax import random
|
7
|
-
from jax import jit
|
8
|
-
from flax.struct import dataclass
|
9
|
-
import chex
|
10
|
-
from jaxmarl.environments.smax.smax_env import State, SMAX
|
11
|
-
from typing import Tuple, Dict
|
12
|
-
from functools import partial
|
13
|
-
|
14
|
-
|
15
|
-
@dataclass
|
16
|
-
class Scenario:
|
17
|
-
"""Parabellum scenario"""
|
18
|
-
|
19
|
-
obstacle_coords: chex.Array
|
20
|
-
obstacle_deltas: chex.Array
|
21
|
-
|
22
|
-
unit_types: chex.Array
|
23
|
-
num_allies: int
|
24
|
-
num_enemies: int
|
25
|
-
|
26
|
-
smacv2_position_generation: bool = False
|
27
|
-
smacv2_unit_type_generation: bool = False
|
28
|
-
|
29
|
-
|
30
|
-
# default scenario
|
31
|
-
scenarios = {
|
32
|
-
"default": Scenario(
|
33
|
-
jnp.array([[6, 10], [26, 10]]) * 8,
|
34
|
-
jnp.array([[0, 12], [0, 1]]) * 8,
|
35
|
-
jnp.zeros((19,), dtype=jnp.uint8),
|
36
|
-
9,
|
37
|
-
10,
|
38
|
-
)
|
39
|
-
}
|
40
|
-
|
41
|
-
|
42
|
-
class Parabellum(SMAX):
|
43
|
-
def __init__(
|
44
|
-
self,
|
45
|
-
scenario: Scenario = scenarios["default"],
|
46
|
-
unit_type_attack_blasts=jnp.array([0, 0, 0, 0, 0, 0]) + 8,
|
47
|
-
**kwargs,
|
48
|
-
):
|
49
|
-
super().__init__(scenario=scenario, **kwargs)
|
50
|
-
self.unit_type_attack_blasts = unit_type_attack_blasts
|
51
|
-
self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
|
52
|
-
self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)
|
53
|
-
self.max_steps = 200
|
54
|
-
# overwrite supers _world_step method
|
55
|
-
|
56
|
-
|
57
|
-
def _push_units_away(self, state: State, firmness: float = 1.0): # we do it inside the _world_step to allow more obstacles constraints
|
58
|
-
return state
|
59
|
-
|
60
|
-
def _our_push_units_away(self, pos, unit_types, firmness: float = 1.0): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
|
61
|
-
delta_matrix = pos[:, None] - pos[None, :]
|
62
|
-
dist_matrix = (
|
63
|
-
jnp.linalg.norm(delta_matrix, axis=-1)
|
64
|
-
+ jnp.identity(self.num_agents)
|
65
|
-
+ 1e-6
|
66
|
-
)
|
67
|
-
radius_matrix = (
|
68
|
-
self.unit_type_radiuses[unit_types][:, None]
|
69
|
-
+ self.unit_type_radiuses[unit_types][None, :]
|
70
|
-
)
|
71
|
-
overlap_term = jax.nn.relu(radius_matrix / dist_matrix - 1.0)
|
72
|
-
unit_positions = (
|
73
|
-
pos
|
74
|
-
+ firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
|
75
|
-
)
|
76
|
-
return unit_positions
|
77
|
-
|
78
|
-
@partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
|
79
|
-
def _world_step( # modified version of JaxMARL's SMAX _world_step
|
80
|
-
self,
|
81
|
-
key: chex.PRNGKey,
|
82
|
-
state: State,
|
83
|
-
actions: Tuple[chex.Array, chex.Array],
|
84
|
-
) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
|
85
|
-
|
86
|
-
@partial(jax.vmap, in_axes=(None, None, 0, 0))
|
87
|
-
def inter_fn(pos, new_pos, obs, obs_end):
|
88
|
-
d1 = jnp.cross(obs - pos, new_pos - pos)
|
89
|
-
d2 = jnp.cross(obs_end - pos, new_pos - pos)
|
90
|
-
d3 = jnp.cross(pos - obs, obs_end - obs)
|
91
|
-
d4 = jnp.cross(new_pos - obs, obs_end - obs)
|
92
|
-
return (d1 * d2 <= 0) & (d3 * d4 <= 0)
|
93
|
-
|
94
|
-
def update_position(idx, vec):
|
95
|
-
# Compute the movements slightly strangely.
|
96
|
-
# The velocities below are for diagonal directions
|
97
|
-
# because these are easier to encode as actions than the four
|
98
|
-
# diagonal directions. Then rotate the velocity 45
|
99
|
-
# degrees anticlockwise to compute the movement.
|
100
|
-
pos = state.unit_positions[idx]
|
101
|
-
new_pos = (
|
102
|
-
pos
|
103
|
-
+ vec
|
104
|
-
* self.unit_type_velocities[state.unit_types[idx]]
|
105
|
-
* self.time_per_step
|
106
|
-
)
|
107
|
-
# avoid going out of bounds
|
108
|
-
new_pos = jnp.maximum(
|
109
|
-
jnp.minimum(new_pos, jnp.array([self.map_width, self.map_height])),
|
110
|
-
jnp.zeros((2,)),
|
111
|
-
)
|
112
|
-
|
113
|
-
#######################################################################
|
114
|
-
############################################ avoid going into obstacles
|
115
|
-
obs = self.obstacle_coords
|
116
|
-
obs_end = obs + self.obstacle_deltas
|
117
|
-
inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
|
118
|
-
new_pos = jnp.where(inters, pos, new_pos)
|
119
|
-
|
120
|
-
#######################################################################
|
121
|
-
#######################################################################
|
122
|
-
|
123
|
-
return new_pos
|
124
|
-
|
125
|
-
#######################################################################
|
126
|
-
######################################### units close enough to get hit
|
127
|
-
|
128
|
-
def bystander_fn(attacked_idx):
|
129
|
-
idxs = jnp.zeros((self.num_agents,))
|
130
|
-
idxs *= (
|
131
|
-
jnp.linalg.norm(
|
132
|
-
state.unit_positions - state.unit_positions[attacked_idx], axis=-1
|
133
|
-
)
|
134
|
-
< self.unit_type_attack_blasts[state.unit_types[attacked_idx]]
|
135
|
-
)
|
136
|
-
return idxs
|
137
|
-
|
138
|
-
#######################################################################
|
139
|
-
#######################################################################
|
140
|
-
|
141
|
-
def update_agent_health(idx, action, key): # TODO: add attack blasts
|
142
|
-
# for team 1, their attack actions are labelled in
|
143
|
-
# reverse order because that is the order they are
|
144
|
-
# observed in
|
145
|
-
attacked_idx = jax.lax.cond(
|
146
|
-
idx < self.num_allies,
|
147
|
-
lambda: action + self.num_allies - self.num_movement_actions,
|
148
|
-
lambda: self.num_allies - 1 - (action - self.num_movement_actions),
|
149
|
-
)
|
150
|
-
# deal with no-op attack actions (i.e. agents that are moving instead)
|
151
|
-
attacked_idx = jax.lax.select(
|
152
|
-
action < self.num_movement_actions, idx, attacked_idx
|
153
|
-
)
|
154
|
-
|
155
|
-
attack_valid = (
|
156
|
-
(
|
157
|
-
jnp.linalg.norm(
|
158
|
-
state.unit_positions[idx] - state.unit_positions[attacked_idx]
|
159
|
-
)
|
160
|
-
< self.unit_type_attack_ranges[state.unit_types[idx]]
|
161
|
-
)
|
162
|
-
& state.unit_alive[idx]
|
163
|
-
& state.unit_alive[attacked_idx]
|
164
|
-
)
|
165
|
-
attack_valid = attack_valid & (idx != attacked_idx)
|
166
|
-
attack_valid = attack_valid & (state.unit_weapon_cooldowns[idx] <= 0.0)
|
167
|
-
health_diff = jax.lax.select(
|
168
|
-
attack_valid,
|
169
|
-
-self.unit_type_attacks[state.unit_types[idx]],
|
170
|
-
0.0,
|
171
|
-
)
|
172
|
-
# design choice based on the pysc2 randomness details.
|
173
|
-
# See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
|
174
|
-
|
175
|
-
#########################################################
|
176
|
-
############################### Add bystander health diff
|
177
|
-
|
178
|
-
bystander_idxs = bystander_fn(attacked_idx) # TODO: use
|
179
|
-
bystander_valid = (
|
180
|
-
jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
|
181
|
-
.astype(jnp.bool_)
|
182
|
-
.astype(jnp.float32)
|
183
|
-
)
|
184
|
-
bystander_health_diff = (
|
185
|
-
bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
|
186
|
-
)
|
187
|
-
|
188
|
-
#########################################################
|
189
|
-
#########################################################
|
190
|
-
|
191
|
-
cooldown_deviation = jax.random.uniform(
|
192
|
-
key, minval=-self.time_per_step, maxval=2 * self.time_per_step
|
193
|
-
)
|
194
|
-
cooldown = (
|
195
|
-
self.unit_type_weapon_cooldowns[state.unit_types[idx]]
|
196
|
-
+ cooldown_deviation
|
197
|
-
)
|
198
|
-
cooldown_diff = jax.lax.select(
|
199
|
-
attack_valid,
|
200
|
-
# subtract the current cooldown because we are
|
201
|
-
# going to add it back. This way we effectively
|
202
|
-
# set the new cooldown to `cooldown`
|
203
|
-
cooldown - state.unit_weapon_cooldowns[idx],
|
204
|
-
-self.time_per_step,
|
205
|
-
)
|
206
|
-
return (
|
207
|
-
health_diff,
|
208
|
-
attacked_idx,
|
209
|
-
cooldown_diff,
|
210
|
-
(bystander_health_diff, bystander_idxs),
|
211
|
-
)
|
212
|
-
|
213
|
-
def perform_agent_action(idx, action, key):
|
214
|
-
movement_action, attack_action = action
|
215
|
-
new_pos = update_position(idx, movement_action)
|
216
|
-
health_diff, attacked_idxes, cooldown_diff, (bystander) = (
|
217
|
-
update_agent_health(idx, attack_action, key)
|
218
|
-
)
|
219
|
-
|
220
|
-
return new_pos, (health_diff, attacked_idxes), cooldown_diff, bystander
|
221
|
-
|
222
|
-
keys = jax.random.split(key, num=self.num_agents)
|
223
|
-
pos, (health_diff, attacked_idxes), cooldown_diff, bystander = jax.vmap(
|
224
|
-
perform_agent_action
|
225
|
-
)(jnp.arange(self.num_agents), actions, keys)
|
226
|
-
|
227
|
-
# checked that no unit passed through an obstacles
|
228
|
-
new_pos = self._our_push_units_away(pos, state.unit_types)
|
229
|
-
|
230
|
-
# avoid going into obstacles after being pushed
|
231
|
-
obs = self.obstacle_coords
|
232
|
-
obs_end = obs + self.obstacle_deltas
|
233
|
-
|
234
|
-
def check_obstacles(pos, new_pos, obs, obs_end):
|
235
|
-
inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
|
236
|
-
return jnp.where(inters, pos, new_pos)
|
237
|
-
|
238
|
-
pos = jax.vmap(check_obstacles, in_axes=(0,0,None,None))(pos, new_pos, obs, obs_end)
|
239
|
-
|
240
|
-
# Multiple enemies can attack the same unit.
|
241
|
-
# We have `(health_diff, attacked_idx)` pairs.
|
242
|
-
# `jax.lax.scatter_add` aggregates these exactly
|
243
|
-
# in the way we want -- duplicate idxes will have their
|
244
|
-
# health differences added together. However, it is a
|
245
|
-
# super thin wrapper around the XLA scatter operation,
|
246
|
-
# which has this bonkers syntax and requires this dnums
|
247
|
-
# parameter. The usage here was inferred from a test:
|
248
|
-
# https://github.com/google/jax/blob/main/tests/lax_test.py#L2296
|
249
|
-
dnums = jax.lax.ScatterDimensionNumbers(
|
250
|
-
update_window_dims=(),
|
251
|
-
inserted_window_dims=(0,),
|
252
|
-
scatter_dims_to_operand_dims=(0,),
|
253
|
-
)
|
254
|
-
unit_health = jnp.maximum(
|
255
|
-
jax.lax.scatter_add(
|
256
|
-
state.unit_health,
|
257
|
-
jnp.expand_dims(attacked_idxes, 1),
|
258
|
-
health_diff,
|
259
|
-
dnums,
|
260
|
-
),
|
261
|
-
0.0,
|
262
|
-
)
|
263
|
-
|
264
|
-
#########################################################
|
265
|
-
############################ subtracting bystander health
|
266
|
-
|
267
|
-
_, bystander_health_diff = bystander
|
268
|
-
unit_health -= bystander_health_diff.sum(axis=0) # might be axis=1
|
269
|
-
|
270
|
-
#########################################################
|
271
|
-
#########################################################
|
272
|
-
|
273
|
-
unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
|
274
|
-
state = state.replace(
|
275
|
-
unit_health=unit_health,
|
276
|
-
unit_positions=pos,
|
277
|
-
unit_weapon_cooldowns=unit_weapon_cooldowns,
|
278
|
-
)
|
279
|
-
return state
|
280
|
-
|
281
|
-
if __name__ == "__main__":
|
282
|
-
env = Parabellum(map_width=256, map_height=256)
|
283
|
-
rng, key = random.split(random.PRNGKey(0))
|
284
|
-
obs, state = env.reset(key)
|
285
|
-
state_seq = []
|
286
|
-
for step in range(100):
|
287
|
-
rng, key = random.split(rng)
|
288
|
-
key_act = random.split(key, len(env.agents))
|
289
|
-
actions = {
|
290
|
-
agent: jax.random.randint(key_act[i], (), 0, 5)
|
291
|
-
for i, agent in enumerate(env.agents)
|
292
|
-
}
|
293
|
-
_, state, _, _, _ = env.step(key, state, actions)
|
294
|
-
state_seq.append((obs, state, actions))
|
295
|
-
|
296
|
-
|
@@ -1,230 +0,0 @@
|
|
1
|
-
"""Visualizer for the Parabellum environment"""
|
2
|
-
|
3
|
-
from tqdm import tqdm
|
4
|
-
import jax.numpy as jnp
|
5
|
-
import jax
|
6
|
-
from jax import vmap
|
7
|
-
from functools import partial
|
8
|
-
import darkdetect
|
9
|
-
import pygame
|
10
|
-
from moviepy.editor import ImageSequenceClip
|
11
|
-
from typing import Optional
|
12
|
-
from jaxmarl.environments.multi_agent_env import MultiAgentEnv
|
13
|
-
from jaxmarl.viz.visualizer import SMAXVisualizer
|
14
|
-
|
15
|
-
# default dict
|
16
|
-
from collections import defaultdict
|
17
|
-
|
18
|
-
|
19
|
-
# constants
|
20
|
-
action_to_symbol = {0: "↑", 1: "→", 2: "↓", 3: "←", 4: "Ø"}
|
21
|
-
|
22
|
-
|
23
|
-
class Visualizer(SMAXVisualizer):
|
24
|
-
def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
|
25
|
-
super().__init__(env, state_seq, reward_seq)
|
26
|
-
# remove fig and ax from super
|
27
|
-
self.fig, self.ax = None, None
|
28
|
-
self.bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
|
29
|
-
self.fg = (235, 235, 235) if darkdetect.isDark() else (20, 20, 20)
|
30
|
-
self.s = 1000
|
31
|
-
self.scale = self.s / self.env.map_width
|
32
|
-
self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
|
33
|
-
self.bullet_seq = bullet_fn(self.env, self.state_seq)
|
34
|
-
|
35
|
-
def render_agents(self, screen, state):
|
36
|
-
time_tuple = zip(
|
37
|
-
state.unit_positions,
|
38
|
-
state.unit_teams,
|
39
|
-
state.unit_types,
|
40
|
-
state.unit_health,
|
41
|
-
)
|
42
|
-
for idx, (pos, team, kind, hp) in enumerate(time_tuple):
|
43
|
-
face_col = self.fg if int(team.item()) == 0 else self.bg
|
44
|
-
pos = tuple((pos * self.scale).tolist())
|
45
|
-
|
46
|
-
# draw the agent
|
47
|
-
if hp > 0:
|
48
|
-
hp_frac = hp / self.env.unit_type_health[kind]
|
49
|
-
unit_size = self.env.unit_type_radiuses[kind]
|
50
|
-
radius = jnp.ceil((unit_size * self.scale * hp_frac)).astype(int) + 1
|
51
|
-
pygame.draw.circle(screen, face_col, pos, radius)
|
52
|
-
pygame.draw.circle(screen, self.fg, pos, radius, 1)
|
53
|
-
|
54
|
-
# draw the sight range
|
55
|
-
# sight_range = self.env.unit_type_sight_ranges[kind] * self.scale
|
56
|
-
# pygame.draw.circle(screen, self.fg, pos, sight_range.astype(int), 2)
|
57
|
-
|
58
|
-
# draw attack range
|
59
|
-
# attack_range = self.env.unit_type_attack_ranges[kind] * self.scale
|
60
|
-
# pygame.draw.circle(screen, self.fg, pos, attack_range.astype(int), 2)
|
61
|
-
# work out which agents are being shot
|
62
|
-
|
63
|
-
def render_action(self, screen, action):
|
64
|
-
def coord_fn(idx, n, team):
|
65
|
-
return (
|
66
|
-
self.s / 20 if team == 0 else self.s - self.s / 20,
|
67
|
-
# vertically centered so that n / 2 is above and below the center
|
68
|
-
self.s / 2 - (n / 2) * self.s / 20 + idx * self.s / 20,
|
69
|
-
)
|
70
|
-
|
71
|
-
for idx in range(self.env.num_allies):
|
72
|
-
symb = action_to_symbol.get(action[f"ally_{idx}"].astype(int).item(), "Ø")
|
73
|
-
font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
|
74
|
-
text = font.render(symb, True, self.fg)
|
75
|
-
coord = coord_fn(idx, self.env.num_allies, 0)
|
76
|
-
screen.blit(text, coord)
|
77
|
-
|
78
|
-
for idx in range(self.env.num_enemies):
|
79
|
-
symb = action_to_symbol.get(action[f"enemy_{idx}"].astype(int).item(), "Ø")
|
80
|
-
font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
|
81
|
-
text = font.render(symb, True, self.fg)
|
82
|
-
coord = coord_fn(idx, self.env.num_enemies, 1)
|
83
|
-
screen.blit(text, coord)
|
84
|
-
|
85
|
-
def render_obstacles(self, screen):
|
86
|
-
for c, d in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
|
87
|
-
d = tuple(((c + d) * self.scale).tolist())
|
88
|
-
c = tuple((c * self.scale).tolist())
|
89
|
-
pygame.draw.line(screen, self.fg, c, d, 5)
|
90
|
-
|
91
|
-
def render_bullets(self, screen, bullets, jdx):
|
92
|
-
jdx += 1
|
93
|
-
ally_bullets, enemy_bullets = bullets
|
94
|
-
for source, target in ally_bullets:
|
95
|
-
position = source + (target - source) * jdx / 8
|
96
|
-
position *= self.scale
|
97
|
-
pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
|
98
|
-
for source, target in enemy_bullets:
|
99
|
-
position = source + (target - source) * jdx / 8
|
100
|
-
position *= self.scale
|
101
|
-
pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
|
102
|
-
|
103
|
-
def animate(self, save_fname: str = "parabellum.mp4"):
|
104
|
-
if not self.have_expanded:
|
105
|
-
self.expand_state_seq()
|
106
|
-
frames = [] # frames for the video
|
107
|
-
pygame.init() # initialize pygame
|
108
|
-
for idx, (_, state, _) in tqdm(
|
109
|
-
enumerate(self.state_seq), total=len(self.state_seq)
|
110
|
-
):
|
111
|
-
screen = pygame.Surface(
|
112
|
-
(self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
|
113
|
-
)
|
114
|
-
screen.fill(self.bg) # fill the screen with the background color
|
115
|
-
|
116
|
-
self.render_agents(screen, state) # render the agents
|
117
|
-
self.render_action(screen, self.action_seq[idx // 8])
|
118
|
-
self.render_obstacles(screen) # render the obstacles
|
119
|
-
|
120
|
-
# bullets
|
121
|
-
if idx < len(self.bullet_seq) * 8:
|
122
|
-
bullets = self.bullet_seq[idx // 8]
|
123
|
-
self.render_bullets(screen, bullets, idx % 8)
|
124
|
-
|
125
|
-
# rotate the screen and append to frames
|
126
|
-
frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))
|
127
|
-
|
128
|
-
# save the images
|
129
|
-
clip = ImageSequenceClip(frames, fps=48)
|
130
|
-
clip.write_videofile(save_fname, fps=48)
|
131
|
-
# clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
|
132
|
-
pygame.quit()
|
133
|
-
|
134
|
-
return clip
|
135
|
-
|
136
|
-
|
137
|
-
# functions
|
138
|
-
# bullet functions
|
139
|
-
def dist_fn(env, pos): # computing the distances between all ally and enemy agents
|
140
|
-
delta = pos[None, :, :] - pos[:, None, :]
|
141
|
-
dist = jnp.sqrt((delta**2).sum(axis=2))
|
142
|
-
dist = dist[: env.num_allies, env.num_allies :]
|
143
|
-
return {"ally": dist, "enemy": dist.T}
|
144
|
-
|
145
|
-
|
146
|
-
def range_fn(env, dists, ranges): # computing what targets are in range
|
147
|
-
ally_range = dists["ally"] < ranges[: env.num_allies][:, None]
|
148
|
-
enemy_range = dists["enemy"] < ranges[env.num_allies :][:, None]
|
149
|
-
return {"ally": ally_range, "enemy": enemy_range}
|
150
|
-
|
151
|
-
|
152
|
-
def target_fn(acts, in_range, team): # computing the one hot valid targets
|
153
|
-
t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T
|
154
|
-
t_targets = jnp.where(t_acts > 4, -1, t_acts - 5) # first 5 are move actions
|
155
|
-
t_attacks = jnp.eye(in_range[team].shape[1] + 1)[t_targets][:, :-1]
|
156
|
-
return t_attacks * in_range[team] # one hot valid targets
|
157
|
-
|
158
|
-
|
159
|
-
def attack_fn(env, state_seq): # one hot attack list
|
160
|
-
attacks = []
|
161
|
-
for _, state, acts in state_seq:
|
162
|
-
dists = dist_fn(env, state.unit_positions)
|
163
|
-
ranges = env.unit_type_attack_ranges[state.unit_types]
|
164
|
-
in_range = range_fn(env, dists, ranges)
|
165
|
-
target = partial(target_fn, acts, in_range)
|
166
|
-
attack = {"ally": target("ally"), "enemy": target("enemy")}
|
167
|
-
attacks.append(attack)
|
168
|
-
return attacks
|
169
|
-
|
170
|
-
|
171
|
-
def bullet_fn(env, states):
|
172
|
-
bullet_seq = []
|
173
|
-
attack_seq = attack_fn(env, states)
|
174
|
-
|
175
|
-
def aux_fn(team):
|
176
|
-
bullets = jnp.stack(jnp.where(one_hot[team] == 1)).T
|
177
|
-
# bullets = bullets.at[:, 2 if team == "ally" else 1].add(env.num_allies)
|
178
|
-
return bullets
|
179
|
-
|
180
|
-
state_zip = zip(states[:-1], states[1:])
|
181
|
-
for i, ((_, state, _), (_, n_state, _)) in enumerate(state_zip):
|
182
|
-
one_hot = attack_seq[i]
|
183
|
-
ally_bullets, enemy_bullets = aux_fn("ally"), aux_fn("enemy")
|
184
|
-
|
185
|
-
ally_bullets_source = state.unit_positions[ally_bullets[:, 0]]
|
186
|
-
enemy_bullets_target = n_state.unit_positions[enemy_bullets[:, 1]]
|
187
|
-
|
188
|
-
enemy_bullets_source = state.unit_positions[
|
189
|
-
enemy_bullets[:, 0] + env.num_allies
|
190
|
-
]
|
191
|
-
ally_bullets_target = n_state.unit_positions[
|
192
|
-
ally_bullets[:, 1] + env.num_allies
|
193
|
-
]
|
194
|
-
|
195
|
-
ally_bullets = jnp.stack((ally_bullets_source, ally_bullets_target), axis=1)
|
196
|
-
enemy_bullets = jnp.stack((enemy_bullets_source, enemy_bullets_target), axis=1)
|
197
|
-
|
198
|
-
bullet_seq.append((ally_bullets, enemy_bullets))
|
199
|
-
return bullet_seq
|
200
|
-
|
201
|
-
|
202
|
-
# test the visualizer
|
203
|
-
if __name__ == "__main__":
|
204
|
-
from parabellum import Parabellum, Scenario
|
205
|
-
from jax import random, numpy as jnp
|
206
|
-
|
207
|
-
s = Scenario(jnp.array([[16, 0]]),
|
208
|
-
jnp.array([[0, 32]]) * 8,
|
209
|
-
jnp.zeros((19,), dtype=jnp.uint8),
|
210
|
-
9,
|
211
|
-
10)
|
212
|
-
env = Parabellum(map_width=32, map_height=32, walls_cause_death=False, scenario=s)
|
213
|
-
rng, key = random.split(random.PRNGKey(0))
|
214
|
-
obs, state = env.reset(key)
|
215
|
-
state_seq = []
|
216
|
-
for step in range(50):
|
217
|
-
rng, key = random.split(rng)
|
218
|
-
key_act = random.split(key, len(env.agents))
|
219
|
-
actions = {
|
220
|
-
agent: jnp.array(1)
|
221
|
-
for i, agent in enumerate(env.agents)
|
222
|
-
}
|
223
|
-
state_seq.append((key, state, actions))
|
224
|
-
rng, key_step = random.split(rng)
|
225
|
-
obs, state, reward, done, infos = env.step(key_step, state, actions)
|
226
|
-
|
227
|
-
vis = Visualizer(env, state_seq)
|
228
|
-
vis.animate()
|
229
|
-
|
230
|
-
|
@@ -1,56 +0,0 @@
|
|
1
|
-
Metadata-Version: 2.1
|
2
|
-
Name: parabellum
|
3
|
-
Version: 0.2.13
|
4
|
-
Summary: Parabellum environment for parallel warfare simulation
|
5
|
-
Home-page: https://github.com/syrkis/parabellum
|
6
|
-
License: MIT
|
7
|
-
Keywords: warfare,simulation,parallel,environment
|
8
|
-
Author: Noah Syrkis
|
9
|
-
Author-email: desk@syrkis.com
|
10
|
-
Requires-Python: >=3.11,<4.0
|
11
|
-
Classifier: License :: OSI Approved :: MIT License
|
12
|
-
Classifier: Programming Language :: Python :: 3
|
13
|
-
Classifier: Programming Language :: Python :: 3.11
|
14
|
-
Classifier: Programming Language :: Python :: 3.12
|
15
|
-
Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
|
16
|
-
Requires-Dist: jaxmarl (==0.0.3)
|
17
|
-
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
18
|
-
Requires-Dist: moviepy (>=1.0.3,<2.0.0)
|
19
|
-
Requires-Dist: poetry (>=1.8.3,<2.0.0)
|
20
|
-
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
21
|
-
Requires-Dist: tqdm (>=4.66.4,<5.0.0)
|
22
|
-
Project-URL: Repository, https://github.com/syrkis/parabellum
|
23
|
-
Description-Content-Type: text/markdown
|
24
|
-
|
25
|
-
# parabellum
|
26
|
-
|
27
|
-
Parabellum is an ultra-scalable, high-performance warfare simulation engine.
|
28
|
-
It is based on JaxMARL's SMAX environment, but has been heavily modified to
|
29
|
-
support a wide range of new features and improvements.
|
30
|
-
|
31
|
-
## Installation
|
32
|
-
|
33
|
-
Install through PyPI:
|
34
|
-
|
35
|
-
```bash
|
36
|
-
pip install parabellum
|
37
|
-
```
|
38
|
-
|
39
|
-
## Usage
|
40
|
-
|
41
|
-
```python
|
42
|
-
import parabellum as pb
|
43
|
-
```
|
44
|
-
|
45
|
-
## TODO
|
46
|
-
|
47
|
-
- [ ] Parallel pygame vis
|
48
|
-
- [ ] Color for health?
|
49
|
-
- [ ] Add the ability to see ongoing game.
|
50
|
-
- [ ] Bug test friendly fire.
|
51
|
-
- [ ] Start sim from arbitrary state.
|
52
|
-
- [ ] Save when the episode ends in some state/obs variable
|
53
|
-
- [ ] Look for the source of the bug when using more Allies than Enemies
|
54
|
-
- [ ] Y inversed axis for parabellum visualization
|
55
|
-
- [ ] Units see through obstacles?
|
56
|
-
|
@@ -1,9 +0,0 @@
|
|
1
|
-
parabellum/.ipynb_checkpoints/__init__-checkpoint.py,sha256=Yt1RkvkGIJdps0Axpz0ouu-Aaa07032kX04l1l7LXTw,118
|
2
|
-
parabellum/.ipynb_checkpoints/env-checkpoint.py,sha256=Z0PD3MJb9Amxl84MMtghTCF92Gr4ln9qSyRx2DSY15Y,11589
|
3
|
-
parabellum/.ipynb_checkpoints/vis-checkpoint.py,sha256=7zmFqU99gXSW6ueTeEp3CKMJ9XmrTgJkVEpktdLWd_4,8999
|
4
|
-
parabellum/__init__.py,sha256=Yt1RkvkGIJdps0Axpz0ouu-Aaa07032kX04l1l7LXTw,118
|
5
|
-
parabellum/env.py,sha256=Z0PD3MJb9Amxl84MMtghTCF92Gr4ln9qSyRx2DSY15Y,11589
|
6
|
-
parabellum/vis.py,sha256=7zmFqU99gXSW6ueTeEp3CKMJ9XmrTgJkVEpktdLWd_4,8999
|
7
|
-
parabellum-0.2.13.dist-info/METADATA,sha256=jIs4QuEQ7Act04HK2cwKw4PaXOi9Y01ucQ2a4np4-KU,1627
|
8
|
-
parabellum-0.2.13.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
9
|
-
parabellum-0.2.13.dist-info/RECORD,,
|
File without changes
|