parabellum 0.2.13__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from .env import Parabellum, Scenario
1
+ from .env import Parabellum, Scenario, scenarios
2
2
  from .vis import Visualizer
3
3
 
4
- __all__ = ["Parabellum", "Visualizer", "Scenario"]
4
+ __all__ = ["Parabellum", "Visualizer", "Scenario", "scenarios"]
parabellum/env.py CHANGED
@@ -7,6 +7,7 @@ from jax import random
7
7
  from jax import jit
8
8
  from flax.struct import dataclass
9
9
  import chex
10
+ from jax import vmap
10
11
  from jaxmarl.environments.smax.smax_env import State, SMAX
11
12
  from typing import Tuple, Dict
12
13
  from functools import partial
@@ -16,7 +17,9 @@ from functools import partial
16
17
  class Scenario:
17
18
  """Parabellum scenario"""
18
19
 
19
- obstacle_coords: chex.Array
20
+ terrain_raster: chex.Array
21
+
22
+ obstacle_coords: chex.Array # TODO: use map instead of obstacles
20
23
  obstacle_deltas: chex.Array
21
24
 
22
25
  unit_types: chex.Array
@@ -30,8 +33,9 @@ class Scenario:
30
33
  # default scenario
31
34
  scenarios = {
32
35
  "default": Scenario(
33
- jnp.array([[6, 10], [26, 10]]) * 8,
34
- jnp.array([[0, 12], [0, 1]]) * 8,
36
+ jnp.eye(128, dtype=jnp.uint8),
37
+ jnp.array([[80, 0], [16, 12]]),
38
+ jnp.array([[0, 80], [0, 20]]),
35
39
  jnp.zeros((19,), dtype=jnp.uint8),
36
40
  9,
37
41
  10,
@@ -40,24 +44,78 @@ scenarios = {
40
44
 
41
45
 
42
46
  class Parabellum(SMAX):
43
- def __init__(
44
- self,
45
- scenario: Scenario = scenarios["default"],
46
- unit_type_attack_blasts=jnp.array([0, 0, 0, 0, 0, 0]) + 8,
47
- **kwargs,
48
- ):
49
- super().__init__(scenario=scenario, **kwargs)
50
- self.unit_type_attack_blasts = unit_type_attack_blasts
51
- self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
52
- self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)
47
+ def __init__(self, scenario: Scenario, **kwargs):
48
+ map_height, map_width = scenario.terrain_raster.shape
49
+ args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
50
+ super(Parabellum, self).__init__(**args, **kwargs)
51
+ self.terrain_raster = scenario.terrain_raster
52
+ self.obstacle_coords = scenario.obstacle_coords
53
+ self.obstacle_deltas = scenario.obstacle_deltas
54
+ self.unit_type_attack_blasts = jnp.zeros((19,), dtype=jnp.float32)
53
55
  self.max_steps = 200
54
- # overwrite supers _world_step method
55
-
56
-
57
- def _push_units_away(self, state: State, firmness: float = 1.0): # we do it inside the _world_step to allow more obstacles constraints
58
- return state
59
-
60
- def _our_push_units_away(self, pos, unit_types, firmness: float = 1.0): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
56
+ self._push_units_away = lambda x: x # overwrite push units
57
+
58
@partial(jax.jit, static_argnums=(0,))
def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
    """Build a fresh episode state and its first observations.

    Allies spawn around (map_width / 4, map_height / 2) and enemies around
    (3 * map_width / 4, map_height / 2), each with uniform noise in [-2, 2).
    When the SMACv2 flags are set, generated positions / unit types replace
    these defaults.

    Returns:
        (obs, state): the observation dict, including a gradient-stopped
        "world_state" entry, and the initial State.
    """

    def spawn(anchor_x, count, noise_key):
        # `count` copies of the team anchor, each jittered independently.
        anchor = jnp.array([anchor_x, self.map_height / 2])
        jitter = jax.random.uniform(
            noise_key, shape=(count, 2), minval=-2, maxval=2
        )
        return jnp.stack([anchor] * count) + jitter

    # The key is split in the same order as before so results are unchanged.
    key, team_0_key, team_1_key = jax.random.split(key, num=3)
    allies = spawn(self.map_width / 4, self.num_allies, team_0_key)
    enemies = spawn(self.map_width / 4 * 3, self.num_enemies, team_1_key)
    default_positions = jnp.concatenate([allies, enemies])

    key, pos_key = jax.random.split(key)
    unit_positions = jax.lax.select(
        self.smacv2_position_generation,
        self.position_generator.generate(pos_key),
        default_positions,
    )

    # Team 0 occupies the first `num_allies` slots, team 1 the rest.
    unit_teams = jnp.zeros((self.num_agents,)).at[self.num_allies :].set(1)

    # Without a scenario every unit gets type 0.
    if self.scenario is None:
        default_types = jnp.zeros((self.num_agents,), dtype=jnp.uint8)
    else:
        default_types = self.scenario

    key, unit_type_key = jax.random.split(key)
    unit_types = jax.lax.select(
        self.smacv2_unit_type_generation,
        self.unit_type_generator.generate(unit_type_key),
        default_types,
    )

    state = State(
        unit_positions=unit_positions,
        unit_alive=jnp.ones((self.num_agents,), dtype=jnp.bool_),
        unit_teams=unit_teams,
        unit_health=self.unit_type_health[unit_types],
        unit_types=unit_types,
        prev_movement_actions=jnp.zeros((self.num_agents, 2)),
        prev_attack_actions=jnp.zeros((self.num_agents,), dtype=jnp.int32),
        time=0,
        terminal=False,
        unit_weapon_cooldowns=jnp.zeros((self.num_agents,)),
    )
    state = self._push_units_away(state)

    obs = self.get_obs(state)
    obs["world_state"] = jax.lax.stop_gradient(self.get_world_state(state))
    return obs, state
115
+
116
+ def _our_push_units_away(
117
+ self, pos, unit_types, firmness: float = 1.0
118
+ ): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
61
119
  delta_matrix = pos[:, None] - pos[None, :]
62
120
  dist_matrix = (
63
121
  jnp.linalg.norm(delta_matrix, axis=-1)
@@ -74,7 +132,7 @@ class Parabellum(SMAX):
74
132
  + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
75
133
  )
76
134
  return unit_positions
77
-
135
+
78
136
  @partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
79
137
  def _world_step( # modified version of JaxMARL's SMAX _world_step
80
138
  self,
@@ -82,15 +140,25 @@ class Parabellum(SMAX):
82
140
  state: State,
83
141
  actions: Tuple[chex.Array, chex.Array],
84
142
  ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
85
-
86
143
  @partial(jax.vmap, in_axes=(None, None, 0, 0))
87
- def inter_fn(pos, new_pos, obs, obs_end):
144
+ def intersect_fn(pos, new_pos, obs, obs_end):
88
145
  d1 = jnp.cross(obs - pos, new_pos - pos)
89
146
  d2 = jnp.cross(obs_end - pos, new_pos - pos)
90
147
  d3 = jnp.cross(pos - obs, obs_end - obs)
91
148
  d4 = jnp.cross(new_pos - obs, obs_end - obs)
92
149
  return (d1 * d2 <= 0) & (d3 * d4 <= 0)
93
150
 
151
+ def raster_crossing(pos, new_pos):
152
+ pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
153
+ raster = self.terrain_raster
154
+ axis = jnp.argmax(jnp.abs(new_pos - pos), axis=-1)
155
+ minimum = jnp.minimum(pos[axis], new_pos[axis]).squeeze()
156
+ maximum = jnp.maximum(pos[axis], new_pos[axis]).squeeze()
157
+ segment = jnp.where(axis == 0, raster[pos[1]], raster.T[pos[0]])
158
+ segment = jnp.where(jnp.arange(segment.shape[0]) >= minimum, segment, 0)
159
+ segment = jnp.where(jnp.arange(segment.shape[0]) <= maximum, segment, 0)
160
+ return jnp.any(segment)
161
+
94
162
  def update_position(idx, vec):
95
163
  # Compute the movements slightly strangely.
96
164
  # The velocities below are for diagonal directions
@@ -114,8 +182,10 @@ class Parabellum(SMAX):
114
182
  ############################################ avoid going into obstacles
115
183
  obs = self.obstacle_coords
116
184
  obs_end = obs + self.obstacle_deltas
117
- inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
118
- new_pos = jnp.where(inters, pos, new_pos)
185
+ inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end))
186
+ rastersects = raster_crossing(pos, new_pos)
187
+ flag = jnp.logical_or(inters, rastersects)
188
+ new_pos = jnp.where(flag, pos, new_pos)
119
189
 
120
190
  #######################################################################
121
191
  #######################################################################
@@ -224,19 +294,41 @@ class Parabellum(SMAX):
224
294
  perform_agent_action
225
295
  )(jnp.arange(self.num_agents), actions, keys)
226
296
 
227
- # checked that no unit passed through an obstacles
297
+ # units push each other
228
298
  new_pos = self._our_push_units_away(pos, state.unit_types)
229
299
 
230
- # avoid going into obstacles after being pushed
231
- obs = self.obstacle_coords
232
- obs_end = obs + self.obstacle_deltas
233
-
234
- def check_obstacles(pos, new_pos, obs, obs_end):
235
- inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
236
- return jnp.where(inters, pos, new_pos)
237
-
238
- pos = jax.vmap(check_obstacles, in_axes=(0,0,None,None))(pos, new_pos, obs, obs_end)
239
-
300
+ # avoid going into obstacles after being pushed
301
+
302
+ bondaries_coords = jnp.array(
303
+ [[0, 0], [0, 0], [self.map_width, 0], [0, self.map_height]]
304
+ )
305
+ bondaries_deltas = jnp.array(
306
+ [
307
+ [self.map_width, 0],
308
+ [0, self.map_height],
309
+ [0, self.map_height],
310
+ [self.map_width, 0],
311
+ ]
312
+ )
313
+ obstacle_coords = jnp.concatenate(
314
+ [self.obstacle_coords, bondaries_coords]
315
+ ) # add the map boundaries to the obstacles to avoid
316
+ obstacle_deltas = jnp.concatenate(
317
+ [self.obstacle_deltas, bondaries_deltas]
318
+ ) # add the map boundaries to the obstacles to avoid
319
+ obst_start = obstacle_coords
320
+ obst_end = obst_start + obstacle_deltas
321
+
322
+ def check_obstacles(pos, new_pos, obst_start, obst_end):
323
+ intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
324
+ rastersect = raster_crossing(pos, new_pos)
325
+ flag = jnp.logical_or(intersects, rastersect)
326
+ return jnp.where(flag, pos, new_pos)
327
+
328
+ pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
329
+ pos, new_pos, obst_start, obst_end
330
+ )
331
+
240
332
  # Multiple enemies can attack the same unit.
241
333
  # We have `(health_diff, attacked_idx)` pairs.
242
334
  # `jax.lax.scatter_add` aggregates these exactly
@@ -278,19 +370,23 @@ class Parabellum(SMAX):
278
370
  )
279
371
  return state
280
372
 
373
+
281
374
  if __name__ == "__main__":
282
- env = Parabellum(map_width=256, map_height=256)
283
- rng, key = random.split(random.PRNGKey(0))
284
- obs, state = env.reset(key)
375
+ n_envs = 4
376
+ kwargs = dict(map_width=64, map_height=64)
377
+ env = Parabellum(scenarios["default"], **kwargs)
378
+ rng, reset_rng = random.split(random.PRNGKey(0))
379
+ reset_key = random.split(reset_rng, n_envs)
380
+ obs, state = vmap(env.reset)(reset_key)
285
381
  state_seq = []
286
- for step in range(100):
287
- rng, key = random.split(rng)
288
- key_act = random.split(key, len(env.agents))
289
- actions = {
290
- agent: jax.random.randint(key_act[i], (), 0, 5)
291
- for i, agent in enumerate(env.agents)
292
- }
293
- _, state, _, _, _ = env.step(key, state, actions)
294
- state_seq.append((obs, state, actions))
295
-
296
382
 
383
+ for i in range(10):
384
+ rng, act_rng, step_rng = random.split(rng, 3)
385
+ act_key = random.split(act_rng, (len(env.agents), n_envs))
386
+ act = {
387
+ a: vmap(env.action_space(a).sample)(act_key[i])
388
+ for i, a in enumerate(env.agents)
389
+ }
390
+ step_key = random.split(step_rng, n_envs)
391
+ state_seq.append((step_key, state, act))
392
+ obs, state, reward, done, infos = vmap(env.step)(step_key, state, act)
parabellum/map.py ADDED
@@ -0,0 +1,16 @@
1
+ # map.py
2
+ # parabellum map functions
3
+ # by: Noah Syrkis
4
+
5
+ # imports
6
+ import jax.numpy as jnp
7
+ import jax
8
+
9
+
10
+ # functions
11
def map_fn(width, height, obst_coord, obst_delta):
    """Create a map from the given width, height, and obstacle coordinates and deltas.

    Each obstacle is given by a start cell ``(x, y)`` and an extent
    ``(dx, dy)``; every cell of the rectangle it spans is set to 1.

    Args:
        width: size of the first axis of the returned grid.
        height: size of the second axis of the returned grid.
        obst_coord: iterable of ``(x, y)`` start cells (concrete integers).
        obst_delta: iterable of ``(dx, dy)`` extents, paired with ``obst_coord``.

    Returns:
        A ``(width, height)`` array of zeros with obstacle cells set to 1.
    """
    m = jnp.zeros((width, height))
    for (x, y), (dx, dy) in zip(obst_coord, obst_delta):
        # Order the end points so negative deltas cover the same cells as
        # the equivalent positive ones instead of producing an empty slice.
        x0, x1 = sorted((int(x), int(x) + int(dx)))
        y0, y1 = sorted((int(y), int(y) + int(dy)))
        # A zero delta describes a line obstacle (e.g. delta (0, 80)); give
        # it a thickness of one cell so it is not silently dropped.
        m = m.at[x0 : max(x1, x0 + 1), y0 : max(y1, y0 + 1)].set(1)
    return m
parabellum/run.py ADDED
@@ -0,0 +1,127 @@
1
+ # run.py
2
+ # parabellum run game live
3
+ # by: Noah Syrkis
4
+
5
+ # Noah Syrkis
6
+ import pygame
7
+ from jax import random
8
+ from functools import partial
9
+ import darkdetect
10
+ import jax.numpy as jnp
11
+ from chex import dataclass
12
+ import jaxmarl
13
+ from typing import Tuple, List, Dict, Optional
14
+ import parabellum as pb
15
+
16
+
17
+ # constants
18
+ fg = (255, 255, 255) if darkdetect.isDark() else (0, 0, 0)
19
+ bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
20
+
21
+
22
+ # types
23
+ State = jaxmarl.environments.smax.smax_env.State
24
+ Obs = Reward = Done = Action = Dict[str, jnp.ndarray]
25
+ StateSeq = List[Tuple[jnp.ndarray, State, Action]]
26
+
27
+
28
@dataclass
class Control:
    """User-interface flags updated from pygame events."""

    # Cleared when a pygame.QUIT event arrives; ends the main loop.
    running: bool = True
    # Toggled by any key press; stepping and rendering are skipped while set.
    paused: bool = False
    # Mouse position while a button is held down, otherwise None.
    click: Optional[Tuple[int, int]] = None
33
+
34
+
35
@dataclass
class Game:
    """Everything the live game loop carries from one iteration to the next."""

    # pygame clock used to cap the frame rate.
    clock: pygame.time.Clock
    # Current environment state.
    state: State
    # Latest observations returned by the environment.
    obs: Dict
    # History of (step key, state, action) entries, appended before each step.
    state_seq: StateSeq
    # UI flags (running / paused / click).
    control: Control
    # The Parabellum environment being played.
    env: pb.Parabellum
    # PRNG key that the next step's keys are split from.
    rng: random.PRNGKey
44
+
45
+
46
def handle_event(event, control_state):
    """Apply a single pygame event to the control flags and return them.

    QUIT stops the loop, a mouse press records the cursor position, a mouse
    release clears it, and any key press toggles the paused flag.
    """
    kind = event.type
    if kind == pygame.QUIT:
        control_state.running = False
    if kind == pygame.MOUSEBUTTONDOWN:
        control_state.click = pygame.mouse.get_pos()
    if kind == pygame.MOUSEBUTTONUP:
        control_state.click = None
    if kind == pygame.KEYDOWN:  # any key press pauses
        control_state.paused = not control_state.paused
    return control_state
58
+
59
+
60
def control_fn(game):
    """Drain the pygame event queue into ``game.control`` and return the game."""
    pending = pygame.event.get()
    for ev in pending:
        game.control = handle_event(ev, game.control)
    return game
65
+
66
+
67
def render_fn(screen, game):
    """Render the game.

    Expands the last two recorded steps into world sub-steps and draws the
    final eight of them, one frame each, capped at 24 frames per second.
    Nothing is drawn until at least three steps have been recorded.

    Args:
        screen: pygame surface to draw on.
        game: current Game; its env, control, clock and state_seq are read.

    Returns:
        The same game object.
    """
    if len(game.state_seq) < 3:
        return game
    # Use the environment carried by the game rather than a module-level
    # global, which only exists when this file is run as a script.
    env = game.env
    for _, state, _ in env.expand_state_seq(game.state_seq[-2:])[-8:]:
        screen.fill(bg)
        if game.control.click is not None:
            pygame.draw.circle(screen, "red", game.control.click, 10)
        for pos in state.unit_positions:
            # NOTE(review): positions are scaled to 800 px while the script
            # opens a 1000x1000 window -- confirm this is intended.
            pos = (pos / env.map_width * 800).tolist()
            pygame.draw.circle(screen, fg, pos, 5)
        pygame.display.flip()
        game.clock.tick(24)  # limits FPS to 24
    return game
82
+
83
+
84
def step_fn(game):
    """Step in parabellum.

    Samples one random action per agent, records ``(step_key, state, action)``
    in ``game.state_seq`` and advances the environment by one step.

    Args:
        game: current Game; its env, rng, state and state_seq are used.

    Returns:
        The same game object with state, obs and rng updated.
    """
    # Use the environment carried by the game rather than a module-level
    # global, which only exists when this file is run as a script.
    env = game.env
    rng, act_rng, step_key = random.split(game.rng, 3)
    act_keys = random.split(act_rng, env.num_agents)
    action = {
        agent: env.action_space(agent).sample(act_keys[i])
        for i, agent in enumerate(env.agents)
    }
    # The entry is recorded before stepping, so it holds the pre-step state.
    game.state_seq.append((step_key, game.state, action))
    obs, state, _reward, _done, _info = env.step(step_key, game.state, action)
    game.state = state
    game.obs = obs
    game.rng = rng
    return game
99
+
100
+
101
+ # state
102
if __name__ == "__main__":
    # Build the default environment and open the window.
    env = pb.Parabellum(pb.scenarios["default"])
    pygame.init()
    screen = pygame.display.set_mode((1000, 1000))
    render = partial(render_fn, screen)

    # Initial state and the bundle carried through the loop.
    rng, key = random.split(random.PRNGKey(0))
    obs, state = env.reset(key)
    game = Game(
        control=Control(),
        env=env,
        rng=rng,
        state_seq=[],  # [(key, state, action)]
        clock=pygame.time.Clock(),
        state=state,
        obs=obs,
    )

    # Events are always processed; stepping and drawing stop while paused.
    while game.control.running:
        game = control_fn(game)
        if not game.control.paused:
            game = render(step_fn(game))

    pygame.quit()
126
+
127
+
parabellum/vis.py CHANGED
@@ -4,9 +4,12 @@ from tqdm import tqdm
4
4
  import jax.numpy as jnp
5
5
  import jax
6
6
  from jax import vmap
7
+ from jax import tree_util
7
8
  from functools import partial
8
9
  import darkdetect
10
+ import numpy as np
9
11
  import pygame
12
+ import os
10
13
  from moviepy.editor import ImageSequenceClip
11
14
  from typing import Optional
12
15
  from jaxmarl.environments.multi_agent_env import MultiAgentEnv
@@ -20,6 +23,14 @@ from collections import defaultdict
20
23
  action_to_symbol = {0: "↑", 1: "→", 2: "↓", 3: "←", 4: "Ø"}
21
24
 
22
25
 
26
def small_multiples():
    """Load the four per-environment videos written to ``output/``.

    Work in progress: the clips are loaded and counted but not yet combined
    into a grid.
    """
    # ImageSequenceClip has no ``load`` constructor; existing mp4 files are
    # opened with VideoFileClip from the same moviepy package.
    from moviepy.editor import VideoFileClip

    # make video of small multiples based on all videos in output
    video_files = [f"output/parabellum_{i}.mp4" for i in range(4)]
    # load mp4 videos and make a grid
    clips = [VideoFileClip(filename) for filename in video_files]
    print(len(clips))
32
+
33
+
23
34
  class Visualizer(SMAXVisualizer):
24
35
  def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
25
36
  super().__init__(env, state_seq, reward_seq)
@@ -30,7 +41,60 @@ class Visualizer(SMAXVisualizer):
30
41
  self.s = 1000
31
42
  self.scale = self.s / self.env.map_width
32
43
  self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
33
- self.bullet_seq = bullet_fn(self.env, self.state_seq)
44
+ # self.bullet_seq = vmap(partial(bullet_fn, self.env))(self.state_seq)
45
+
46
def animate(self, save_fname: str = "output/parabellum.mp4"):
    """Render the recorded state sequence to one video per environment.

    A single-environment recording has 2-D ``unit_positions`` (agents x 2)
    and is written to ``save_fname``. A batched recording carries an extra
    leading environment axis and is written to ``<save_fname>_<i>.mp4``,
    one file per environment.

    Args:
        save_fname: output path; ``.mp4`` is replaced by ``_<i>.mp4`` for
            batched recordings.
    """
    env = self.env  # the visualizer's own env, not a module-level global
    positions = self.state_seq[0][1].unit_positions
    # Single-env positions are already 2-D, so only a third axis means the
    # recording is batched over environments.
    if positions.ndim > 2:
        n_envs = positions.shape[0]
        # Expand once and keep the result, so a second call to animate()
        # does not hit an unbound variable.
        state_seqs = getattr(self, "_expanded_state_seqs", None)
        if state_seqs is None:
            state_seqs = vmap(env.expand_state_seq)(self.state_seq)
            self._expanded_state_seqs = state_seqs
            self.have_expanded = True
        for i in range(n_envs):
            state_seq = jax.tree_map(lambda x: x[i], state_seqs)
            action_seq = jax.tree_map(lambda x: x[i], self.action_seq)
            self.animate_one(
                state_seq, action_seq, save_fname.replace(".mp4", f"_{i}.mp4")
            )
    else:
        state_seq = env.expand_state_seq(self.state_seq)
        self.animate_one(state_seq, self.action_seq, save_fname)
62
+
63
def animate_one(self, state_seq, action_seq, save_fname):
    """Render one environment's expanded state sequence to a video file.

    Args:
        state_seq: expanded sequence of (key, state, action) entries.
        action_seq: per-environment-step actions; indexed by
            ``idx // world_steps_per_env_step``.
        save_fname: path of the video to write.

    Returns:
        The ImageSequenceClip that was written.
    """
    frames = []  # frames for the video
    pygame.init()  # initialize pygame

    # Terrain cells equal to 1 are painted in the foreground colour and the
    # layer is stretched to the output size once, outside the frame loop.
    terrain = np.array(self.env.terrain_raster)
    rgb_array = np.zeros((terrain.shape[0], terrain.shape[1], 3), dtype=np.uint8)
    rgb_array[terrain == 1] = self.fg
    mask_surface = pygame.surfarray.make_surface(rgb_array)
    mask_surface = pygame.transform.scale(mask_surface, (self.s, self.s))

    # The progress total is the sequence actually iterated, not the
    # unexpanded self.state_seq.
    for idx, (_, state, _) in tqdm(enumerate(state_seq), total=len(state_seq)):
        action = action_seq[idx // self.env.world_steps_per_env_step]
        screen = pygame.Surface(
            (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
        )
        screen.fill(self.bg)  # fill the screen with the background color
        screen.blit(mask_surface, (0, 0))

        self.render_agents(screen, state)  # render the agents
        self.render_action(screen, action)
        self.render_obstacles(screen)  # render the obstacles

        # Bullet rendering is disabled while bullet_seq is not computed.

        # rotate the screen and append to frames
        frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))

    # save the images
    clip = ImageSequenceClip(frames, fps=48)
    clip.write_videofile(save_fname, fps=48)
    pygame.quit()

    return clip
34
98
 
35
99
  def render_agents(self, screen, state):
36
100
  time_tuple = zip(
@@ -42,7 +106,6 @@ class Visualizer(SMAXVisualizer):
42
106
  for idx, (pos, team, kind, hp) in enumerate(time_tuple):
43
107
  face_col = self.fg if int(team.item()) == 0 else self.bg
44
108
  pos = tuple((pos * self.scale).tolist())
45
-
46
109
  # draw the agent
47
110
  if hp > 0:
48
111
  hp_frac = hp / self.env.unit_type_health[kind]
@@ -100,39 +163,6 @@ class Visualizer(SMAXVisualizer):
100
163
  position *= self.scale
101
164
  pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
102
165
 
103
- def animate(self, save_fname: str = "parabellum.mp4"):
104
- if not self.have_expanded:
105
- self.expand_state_seq()
106
- frames = [] # frames for the video
107
- pygame.init() # initialize pygame
108
- for idx, (_, state, _) in tqdm(
109
- enumerate(self.state_seq), total=len(self.state_seq)
110
- ):
111
- screen = pygame.Surface(
112
- (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
113
- )
114
- screen.fill(self.bg) # fill the screen with the background color
115
-
116
- self.render_agents(screen, state) # render the agents
117
- self.render_action(screen, self.action_seq[idx // 8])
118
- self.render_obstacles(screen) # render the obstacles
119
-
120
- # bullets
121
- if idx < len(self.bullet_seq) * 8:
122
- bullets = self.bullet_seq[idx // 8]
123
- self.render_bullets(screen, bullets, idx % 8)
124
-
125
- # rotate the screen and append to frames
126
- frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))
127
-
128
- # save the images
129
- clip = ImageSequenceClip(frames, fps=48)
130
- clip.write_videofile(save_fname, fps=48)
131
- # clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
132
- pygame.quit()
133
-
134
- return clip
135
-
136
166
 
137
167
  # functions
138
168
  # bullet functions
@@ -201,30 +231,29 @@ def bullet_fn(env, states):
201
231
 
202
232
  # test the visualizer
203
233
  if __name__ == "__main__":
204
- from parabellum import Parabellum, Scenario
205
234
  from jax import random, numpy as jnp
235
+ from parabellum import Parabellum, scenarios
236
+
237
+ # small_multiples() # testing small multiples (not working yet)
238
+ # exit()
206
239
 
207
- s = Scenario(jnp.array([[16, 0]]),
208
- jnp.array([[0, 32]]) * 8,
209
- jnp.zeros((19,), dtype=jnp.uint8),
210
- 9,
211
- 10)
212
- env = Parabellum(map_width=32, map_height=32, walls_cause_death=False, scenario=s)
213
- rng, key = random.split(random.PRNGKey(0))
214
- obs, state = env.reset(key)
240
+ n_envs = 2
241
+ env = Parabellum(scenarios["default"])
242
+ rng, reset_rng = random.split(random.PRNGKey(0))
243
+ reset_key = random.split(reset_rng, n_envs)
244
+ obs, state = vmap(env.reset)(reset_key)
215
245
  state_seq = []
216
- for step in range(50):
217
- rng, key = random.split(rng)
218
- key_act = random.split(key, len(env.agents))
219
- actions = {
220
- agent: jnp.array(1)
221
- for i, agent in enumerate(env.agents)
246
+
247
+ for i in range(100):
248
+ rng, act_rng, step_rng = random.split(rng, 3)
249
+ act_key = random.split(act_rng, (len(env.agents), n_envs))
250
+ act = {
251
+ a: jnp.ones_like(vmap(env.action_space(a).sample)(act_key[i]))
252
+ for i, a in enumerate(env.agents)
222
253
  }
223
- state_seq.append((key, state, actions))
224
- rng, key_step = random.split(rng)
225
- obs, state, reward, done, infos = env.step(key_step, state, actions)
254
+ step_key = random.split(step_rng, n_envs)
255
+ state_seq.append((step_key, state, act))
256
+ obs, state, reward, done, infos = vmap(env.step)(step_key, state, act)
226
257
 
227
258
  vis = Visualizer(env, state_seq)
228
259
  vis.animate()
229
-
230
-
@@ -0,0 +1,104 @@
1
+ Metadata-Version: 2.1
2
+ Name: parabellum
3
+ Version: 0.2.15
4
+ Summary: Parabellum environment for parallel warfare simulation
5
+ Home-page: https://github.com/syrkis/parabellum
6
+ License: MIT
7
+ Keywords: warfare,simulation,parallel,environment
8
+ Author: Noah Syrkis
9
+ Author-email: desk@syrkis.com
10
+ Requires-Python: >=3.11,<4.0
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
16
+ Requires-Dist: jax (==0.4.17)
17
+ Requires-Dist: jaxmarl (==0.0.3)
18
+ Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
19
+ Requires-Dist: moviepy (>=1.0.3,<2.0.0)
20
+ Requires-Dist: poetry (>=1.8.3,<2.0.0)
21
+ Requires-Dist: pygame (>=2.5.2,<3.0.0)
22
+ Requires-Dist: tqdm (>=4.66.4,<5.0.0)
23
+ Project-URL: Repository, https://github.com/syrkis/parabellum
24
+ Description-Content-Type: text/markdown
25
+
26
+ # parabellum
27
+
28
+ Parabellum is an ultra-scalable, high-performance warfare simulation engine.
29
+ It is based on JaxMARL's SMAX environment, but has been heavily modified to
30
+ support a wide range of new features and improvements.
31
+
32
+ ## Installation
33
+
34
+ Parabellum is written in Python 3.11 and can be installed using pip:
35
+
36
+ ```bash
37
+ pip install parabellum
38
+ ```
39
+
40
+ ## Usage
41
+
42
+ Parabellum is designed to be used in conjunction with JAX, a high-performance
43
+ numerical computing library. Here is a simple example of how to use Parabellum
44
+ to simulate a game with 10 agents and 10 enemies, each taking random actions:
45
+
46
+ ```python
47
+ import parabellum as pb
48
+ from jax import random
49
+
50
+ # define the scenario
51
+ kwargs = dict(obstacle_coords=[(7, 7)], obstacle_deltas=[(10, 0)])
52
+ scenario = pb.Scenario(**kwargs) # <- Scenario is an important part of parabellum
53
+
54
+ # create the environment
55
+ kwargs = dict(map_width=256, map_height=256, num_agents=10, num_enemies=10)
56
+ env = pb.Parabellum(**kwargs) # <- Parabellum is the central class of parabellum
57
+
58
+ # initiate stochasticity
59
+ rng = random.PRNGKey(0)
60
+ rng, key = random.split(rng)
61
+
62
+ # initialize the environment state
63
+ obs, state = env.reset(key)
64
+ state_sequence = []
65
+
66
+ for _ in range(1000):
67
+
68
+ # manage stochasticity
69
+ rng, rng_act, key_step = random.split(rng, 3)
70
+ key_act = random.split(rng_act, len(env.agents))
71
+
72
+ # sample actions and append to state sequence
73
+ act = {a: env.action_space(a).sample(k)
74
+ for a, k in zip(env.agents, key_act)}
75
+
76
+ # step the environment
77
+ state_sequence.append((key_act, state, act))
78
+ obs, state, reward, done, info = env.step(key_step, state, act)
79
+
80
+
81
+ # save visualization of the state sequence
82
+ vis = pb.Visualizer(env, state_sequence) # <- Visualizer is a nice to have class
83
+ vis.animate()
84
+ ```
85
+
86
+
87
+ ## Features
88
+
89
+ - Obstacles — can be inserted into the map through a scenario's `obstacle_coords` and `obstacle_deltas`.
90
+
91
+ ## TODO
92
+
93
+ - [x] Parallel pygame vis
94
+ - [ ] Parallel bullet renderings
95
+ - [ ] Combine parallel plots into one (maybe out of parabellum scope)
96
+ - [ ] Color for health?
97
+ - [ ] Add the ability to see ongoing game.
98
+ - [ ] Bug test friendly fire.
99
+ - [x] Start sim from arbitrary state.
100
+ - [ ] Save when the episode ends in some state/obs variable
101
+ - [ ] Look for the source of the bug when using more Allies than Enemies
102
+ - [ ] Inverted Y axis in the parabellum visualization
103
+ - [ ] Units see through obstacles?
104
+
@@ -0,0 +1,8 @@
1
+ parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
2
+ parabellum/env.py,sha256=d6agGy-kTRIg_r0QKCL_7iztzwhaTfsb4yhtUQfdgx0,16024
3
+ parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
4
+ parabellum/run.py,sha256=0IWpqcQ_qfFeElbupF5vOs_CByFfpXYuGGUHYuurFM4,3412
5
+ parabellum/vis.py,sha256=xElgv7cbI-YH4aSosoCvHW34qJpo8Pz_xr5hD-0SGB4,10554
6
+ parabellum-0.2.15.dist-info/METADATA,sha256=L3a5CPmPo2ea8En3a7yvngl7pymdgxWbWKzn9LSzIk0,3223
7
+ parabellum-0.2.15.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
8
+ parabellum-0.2.15.dist-info/RECORD,,
@@ -1,4 +0,0 @@
1
- from .env import Parabellum, Scenario
2
- from .vis import Visualizer
3
-
4
- __all__ = ["Parabellum", "Visualizer", "Scenario"]
@@ -1,296 +0,0 @@
1
- """Parabellum environment based on SMAX"""
2
-
3
- import jax.numpy as jnp
4
- import jax
5
- import numpy as np
6
- from jax import random
7
- from jax import jit
8
- from flax.struct import dataclass
9
- import chex
10
- from jaxmarl.environments.smax.smax_env import State, SMAX
11
- from typing import Tuple, Dict
12
- from functools import partial
13
-
14
-
15
- @dataclass
16
- class Scenario:
17
- """Parabellum scenario"""
18
-
19
- obstacle_coords: chex.Array
20
- obstacle_deltas: chex.Array
21
-
22
- unit_types: chex.Array
23
- num_allies: int
24
- num_enemies: int
25
-
26
- smacv2_position_generation: bool = False
27
- smacv2_unit_type_generation: bool = False
28
-
29
-
30
- # default scenario
31
- scenarios = {
32
- "default": Scenario(
33
- jnp.array([[6, 10], [26, 10]]) * 8,
34
- jnp.array([[0, 12], [0, 1]]) * 8,
35
- jnp.zeros((19,), dtype=jnp.uint8),
36
- 9,
37
- 10,
38
- )
39
- }
40
-
41
-
42
- class Parabellum(SMAX):
43
- def __init__(
44
- self,
45
- scenario: Scenario = scenarios["default"],
46
- unit_type_attack_blasts=jnp.array([0, 0, 0, 0, 0, 0]) + 8,
47
- **kwargs,
48
- ):
49
- super().__init__(scenario=scenario, **kwargs)
50
- self.unit_type_attack_blasts = unit_type_attack_blasts
51
- self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
52
- self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)
53
- self.max_steps = 200
54
- # overwrite supers _world_step method
55
-
56
-
57
- def _push_units_away(self, state: State, firmness: float = 1.0): # we do it inside the _world_step to allow more obstacles constraints
58
- return state
59
-
60
- def _our_push_units_away(self, pos, unit_types, firmness: float = 1.0): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
61
- delta_matrix = pos[:, None] - pos[None, :]
62
- dist_matrix = (
63
- jnp.linalg.norm(delta_matrix, axis=-1)
64
- + jnp.identity(self.num_agents)
65
- + 1e-6
66
- )
67
- radius_matrix = (
68
- self.unit_type_radiuses[unit_types][:, None]
69
- + self.unit_type_radiuses[unit_types][None, :]
70
- )
71
- overlap_term = jax.nn.relu(radius_matrix / dist_matrix - 1.0)
72
- unit_positions = (
73
- pos
74
- + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
75
- )
76
- return unit_positions
77
-
78
- @partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
79
- def _world_step( # modified version of JaxMARL's SMAX _world_step
80
- self,
81
- key: chex.PRNGKey,
82
- state: State,
83
- actions: Tuple[chex.Array, chex.Array],
84
- ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
85
-
86
- @partial(jax.vmap, in_axes=(None, None, 0, 0))
87
- def inter_fn(pos, new_pos, obs, obs_end):
88
- d1 = jnp.cross(obs - pos, new_pos - pos)
89
- d2 = jnp.cross(obs_end - pos, new_pos - pos)
90
- d3 = jnp.cross(pos - obs, obs_end - obs)
91
- d4 = jnp.cross(new_pos - obs, obs_end - obs)
92
- return (d1 * d2 <= 0) & (d3 * d4 <= 0)
93
-
94
- def update_position(idx, vec):
95
- # Compute the movements slightly strangely.
96
- # The velocities below are for diagonal directions
97
- # because these are easier to encode as actions than the four
98
- # diagonal directions. Then rotate the velocity 45
99
- # degrees anticlockwise to compute the movement.
100
- pos = state.unit_positions[idx]
101
- new_pos = (
102
- pos
103
- + vec
104
- * self.unit_type_velocities[state.unit_types[idx]]
105
- * self.time_per_step
106
- )
107
- # avoid going out of bounds
108
- new_pos = jnp.maximum(
109
- jnp.minimum(new_pos, jnp.array([self.map_width, self.map_height])),
110
- jnp.zeros((2,)),
111
- )
112
-
113
- #######################################################################
114
- ############################################ avoid going into obstacles
115
- obs = self.obstacle_coords
116
- obs_end = obs + self.obstacle_deltas
117
- inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
118
- new_pos = jnp.where(inters, pos, new_pos)
119
-
120
- #######################################################################
121
- #######################################################################
122
-
123
- return new_pos
124
-
125
- #######################################################################
126
- ######################################### units close enough to get hit
127
-
128
- def bystander_fn(attacked_idx):
129
- idxs = jnp.zeros((self.num_agents,))
130
- idxs *= (
131
- jnp.linalg.norm(
132
- state.unit_positions - state.unit_positions[attacked_idx], axis=-1
133
- )
134
- < self.unit_type_attack_blasts[state.unit_types[attacked_idx]]
135
- )
136
- return idxs
137
-
138
- #######################################################################
139
- #######################################################################
140
-
141
- def update_agent_health(idx, action, key): # TODO: add attack blasts
142
- # for team 1, their attack actions are labelled in
143
- # reverse order because that is the order they are
144
- # observed in
145
- attacked_idx = jax.lax.cond(
146
- idx < self.num_allies,
147
- lambda: action + self.num_allies - self.num_movement_actions,
148
- lambda: self.num_allies - 1 - (action - self.num_movement_actions),
149
- )
150
- # deal with no-op attack actions (i.e. agents that are moving instead)
151
- attacked_idx = jax.lax.select(
152
- action < self.num_movement_actions, idx, attacked_idx
153
- )
154
-
155
- attack_valid = (
156
- (
157
- jnp.linalg.norm(
158
- state.unit_positions[idx] - state.unit_positions[attacked_idx]
159
- )
160
- < self.unit_type_attack_ranges[state.unit_types[idx]]
161
- )
162
- & state.unit_alive[idx]
163
- & state.unit_alive[attacked_idx]
164
- )
165
- attack_valid = attack_valid & (idx != attacked_idx)
166
- attack_valid = attack_valid & (state.unit_weapon_cooldowns[idx] <= 0.0)
167
- health_diff = jax.lax.select(
168
- attack_valid,
169
- -self.unit_type_attacks[state.unit_types[idx]],
170
- 0.0,
171
- )
172
- # design choice based on the pysc2 randomness details.
173
- # See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
174
-
175
- #########################################################
176
- ############################### Add bystander health diff
177
-
178
- bystander_idxs = bystander_fn(attacked_idx) # TODO: use
179
- bystander_valid = (
180
- jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
181
- .astype(jnp.bool_)
182
- .astype(jnp.float32)
183
- )
184
- bystander_health_diff = (
185
- bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
186
- )
187
-
188
- #########################################################
189
- #########################################################
190
-
191
- cooldown_deviation = jax.random.uniform(
192
- key, minval=-self.time_per_step, maxval=2 * self.time_per_step
193
- )
194
- cooldown = (
195
- self.unit_type_weapon_cooldowns[state.unit_types[idx]]
196
- + cooldown_deviation
197
- )
198
- cooldown_diff = jax.lax.select(
199
- attack_valid,
200
- # subtract the current cooldown because we are
201
- # going to add it back. This way we effectively
202
- # set the new cooldown to `cooldown`
203
- cooldown - state.unit_weapon_cooldowns[idx],
204
- -self.time_per_step,
205
- )
206
- return (
207
- health_diff,
208
- attacked_idx,
209
- cooldown_diff,
210
- (bystander_health_diff, bystander_idxs),
211
- )
212
-
213
- def perform_agent_action(idx, action, key):
214
- movement_action, attack_action = action
215
- new_pos = update_position(idx, movement_action)
216
- health_diff, attacked_idxes, cooldown_diff, (bystander) = (
217
- update_agent_health(idx, attack_action, key)
218
- )
219
-
220
- return new_pos, (health_diff, attacked_idxes), cooldown_diff, bystander
221
-
222
- keys = jax.random.split(key, num=self.num_agents)
223
- pos, (health_diff, attacked_idxes), cooldown_diff, bystander = jax.vmap(
224
- perform_agent_action
225
- )(jnp.arange(self.num_agents), actions, keys)
226
-
227
- # checked that no unit passed through an obstacles
228
- new_pos = self._our_push_units_away(pos, state.unit_types)
229
-
230
- # avoid going into obstacles after being pushed
231
- obs = self.obstacle_coords
232
- obs_end = obs + self.obstacle_deltas
233
-
234
- def check_obstacles(pos, new_pos, obs, obs_end):
235
- inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
236
- return jnp.where(inters, pos, new_pos)
237
-
238
- pos = jax.vmap(check_obstacles, in_axes=(0,0,None,None))(pos, new_pos, obs, obs_end)
239
-
240
- # Multiple enemies can attack the same unit.
241
- # We have `(health_diff, attacked_idx)` pairs.
242
- # `jax.lax.scatter_add` aggregates these exactly
243
- # in the way we want -- duplicate idxes will have their
244
- # health differences added together. However, it is a
245
- # super thin wrapper around the XLA scatter operation,
246
- # which has this bonkers syntax and requires this dnums
247
- # parameter. The usage here was inferred from a test:
248
- # https://github.com/google/jax/blob/main/tests/lax_test.py#L2296
249
- dnums = jax.lax.ScatterDimensionNumbers(
250
- update_window_dims=(),
251
- inserted_window_dims=(0,),
252
- scatter_dims_to_operand_dims=(0,),
253
- )
254
- unit_health = jnp.maximum(
255
- jax.lax.scatter_add(
256
- state.unit_health,
257
- jnp.expand_dims(attacked_idxes, 1),
258
- health_diff,
259
- dnums,
260
- ),
261
- 0.0,
262
- )
263
-
264
- #########################################################
265
- ############################ subtracting bystander health
266
-
267
- _, bystander_health_diff = bystander
268
- unit_health -= bystander_health_diff.sum(axis=0) # might be axis=1
269
-
270
- #########################################################
271
- #########################################################
272
-
273
- unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
274
- state = state.replace(
275
- unit_health=unit_health,
276
- unit_positions=pos,
277
- unit_weapon_cooldowns=unit_weapon_cooldowns,
278
- )
279
- return state
280
-
281
- if __name__ == "__main__":
282
- env = Parabellum(map_width=256, map_height=256)
283
- rng, key = random.split(random.PRNGKey(0))
284
- obs, state = env.reset(key)
285
- state_seq = []
286
- for step in range(100):
287
- rng, key = random.split(rng)
288
- key_act = random.split(key, len(env.agents))
289
- actions = {
290
- agent: jax.random.randint(key_act[i], (), 0, 5)
291
- for i, agent in enumerate(env.agents)
292
- }
293
- _, state, _, _, _ = env.step(key, state, actions)
294
- state_seq.append((obs, state, actions))
295
-
296
-
@@ -1,230 +0,0 @@
1
- """Visualizer for the Parabellum environment"""
2
-
3
- from tqdm import tqdm
4
- import jax.numpy as jnp
5
- import jax
6
- from jax import vmap
7
- from functools import partial
8
- import darkdetect
9
- import pygame
10
- from moviepy.editor import ImageSequenceClip
11
- from typing import Optional
12
- from jaxmarl.environments.multi_agent_env import MultiAgentEnv
13
- from jaxmarl.viz.visualizer import SMAXVisualizer
14
-
15
- # default dict
16
- from collections import defaultdict
17
-
18
-
19
- # constants
20
- action_to_symbol = {0: "↑", 1: "→", 2: "↓", 3: "←", 4: "Ø"}
21
-
22
-
23
- class Visualizer(SMAXVisualizer):
24
- def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
25
- super().__init__(env, state_seq, reward_seq)
26
- # remove fig and ax from super
27
- self.fig, self.ax = None, None
28
- self.bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
29
- self.fg = (235, 235, 235) if darkdetect.isDark() else (20, 20, 20)
30
- self.s = 1000
31
- self.scale = self.s / self.env.map_width
32
- self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
33
- self.bullet_seq = bullet_fn(self.env, self.state_seq)
34
-
35
- def render_agents(self, screen, state):
36
- time_tuple = zip(
37
- state.unit_positions,
38
- state.unit_teams,
39
- state.unit_types,
40
- state.unit_health,
41
- )
42
- for idx, (pos, team, kind, hp) in enumerate(time_tuple):
43
- face_col = self.fg if int(team.item()) == 0 else self.bg
44
- pos = tuple((pos * self.scale).tolist())
45
-
46
- # draw the agent
47
- if hp > 0:
48
- hp_frac = hp / self.env.unit_type_health[kind]
49
- unit_size = self.env.unit_type_radiuses[kind]
50
- radius = jnp.ceil((unit_size * self.scale * hp_frac)).astype(int) + 1
51
- pygame.draw.circle(screen, face_col, pos, radius)
52
- pygame.draw.circle(screen, self.fg, pos, radius, 1)
53
-
54
- # draw the sight range
55
- # sight_range = self.env.unit_type_sight_ranges[kind] * self.scale
56
- # pygame.draw.circle(screen, self.fg, pos, sight_range.astype(int), 2)
57
-
58
- # draw attack range
59
- # attack_range = self.env.unit_type_attack_ranges[kind] * self.scale
60
- # pygame.draw.circle(screen, self.fg, pos, attack_range.astype(int), 2)
61
- # work out which agents are being shot
62
-
63
- def render_action(self, screen, action):
64
- def coord_fn(idx, n, team):
65
- return (
66
- self.s / 20 if team == 0 else self.s - self.s / 20,
67
- # vertically centered so that n / 2 is above and below the center
68
- self.s / 2 - (n / 2) * self.s / 20 + idx * self.s / 20,
69
- )
70
-
71
- for idx in range(self.env.num_allies):
72
- symb = action_to_symbol.get(action[f"ally_{idx}"].astype(int).item(), "Ø")
73
- font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
74
- text = font.render(symb, True, self.fg)
75
- coord = coord_fn(idx, self.env.num_allies, 0)
76
- screen.blit(text, coord)
77
-
78
- for idx in range(self.env.num_enemies):
79
- symb = action_to_symbol.get(action[f"enemy_{idx}"].astype(int).item(), "Ø")
80
- font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
81
- text = font.render(symb, True, self.fg)
82
- coord = coord_fn(idx, self.env.num_enemies, 1)
83
- screen.blit(text, coord)
84
-
85
- def render_obstacles(self, screen):
86
- for c, d in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
87
- d = tuple(((c + d) * self.scale).tolist())
88
- c = tuple((c * self.scale).tolist())
89
- pygame.draw.line(screen, self.fg, c, d, 5)
90
-
91
- def render_bullets(self, screen, bullets, jdx):
92
- jdx += 1
93
- ally_bullets, enemy_bullets = bullets
94
- for source, target in ally_bullets:
95
- position = source + (target - source) * jdx / 8
96
- position *= self.scale
97
- pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
98
- for source, target in enemy_bullets:
99
- position = source + (target - source) * jdx / 8
100
- position *= self.scale
101
- pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
102
-
103
- def animate(self, save_fname: str = "parabellum.mp4"):
104
- if not self.have_expanded:
105
- self.expand_state_seq()
106
- frames = [] # frames for the video
107
- pygame.init() # initialize pygame
108
- for idx, (_, state, _) in tqdm(
109
- enumerate(self.state_seq), total=len(self.state_seq)
110
- ):
111
- screen = pygame.Surface(
112
- (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
113
- )
114
- screen.fill(self.bg) # fill the screen with the background color
115
-
116
- self.render_agents(screen, state) # render the agents
117
- self.render_action(screen, self.action_seq[idx // 8])
118
- self.render_obstacles(screen) # render the obstacles
119
-
120
- # bullets
121
- if idx < len(self.bullet_seq) * 8:
122
- bullets = self.bullet_seq[idx // 8]
123
- self.render_bullets(screen, bullets, idx % 8)
124
-
125
- # rotate the screen and append to frames
126
- frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))
127
-
128
- # save the images
129
- clip = ImageSequenceClip(frames, fps=48)
130
- clip.write_videofile(save_fname, fps=48)
131
- # clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
132
- pygame.quit()
133
-
134
- return clip
135
-
136
-
137
- # functions
138
- # bullet functions
139
- def dist_fn(env, pos): # computing the distances between all ally and enemy agents
140
- delta = pos[None, :, :] - pos[:, None, :]
141
- dist = jnp.sqrt((delta**2).sum(axis=2))
142
- dist = dist[: env.num_allies, env.num_allies :]
143
- return {"ally": dist, "enemy": dist.T}
144
-
145
-
146
- def range_fn(env, dists, ranges): # computing what targets are in range
147
- ally_range = dists["ally"] < ranges[: env.num_allies][:, None]
148
- enemy_range = dists["enemy"] < ranges[env.num_allies :][:, None]
149
- return {"ally": ally_range, "enemy": enemy_range}
150
-
151
-
152
- def target_fn(acts, in_range, team): # computing the one hot valid targets
153
- t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T
154
- t_targets = jnp.where(t_acts > 4, -1, t_acts - 5) # first 5 are move actions
155
- t_attacks = jnp.eye(in_range[team].shape[1] + 1)[t_targets][:, :-1]
156
- return t_attacks * in_range[team] # one hot valid targets
157
-
158
-
159
- def attack_fn(env, state_seq): # one hot attack list
160
- attacks = []
161
- for _, state, acts in state_seq:
162
- dists = dist_fn(env, state.unit_positions)
163
- ranges = env.unit_type_attack_ranges[state.unit_types]
164
- in_range = range_fn(env, dists, ranges)
165
- target = partial(target_fn, acts, in_range)
166
- attack = {"ally": target("ally"), "enemy": target("enemy")}
167
- attacks.append(attack)
168
- return attacks
169
-
170
-
171
- def bullet_fn(env, states):
172
- bullet_seq = []
173
- attack_seq = attack_fn(env, states)
174
-
175
- def aux_fn(team):
176
- bullets = jnp.stack(jnp.where(one_hot[team] == 1)).T
177
- # bullets = bullets.at[:, 2 if team == "ally" else 1].add(env.num_allies)
178
- return bullets
179
-
180
- state_zip = zip(states[:-1], states[1:])
181
- for i, ((_, state, _), (_, n_state, _)) in enumerate(state_zip):
182
- one_hot = attack_seq[i]
183
- ally_bullets, enemy_bullets = aux_fn("ally"), aux_fn("enemy")
184
-
185
- ally_bullets_source = state.unit_positions[ally_bullets[:, 0]]
186
- enemy_bullets_target = n_state.unit_positions[enemy_bullets[:, 1]]
187
-
188
- enemy_bullets_source = state.unit_positions[
189
- enemy_bullets[:, 0] + env.num_allies
190
- ]
191
- ally_bullets_target = n_state.unit_positions[
192
- ally_bullets[:, 1] + env.num_allies
193
- ]
194
-
195
- ally_bullets = jnp.stack((ally_bullets_source, ally_bullets_target), axis=1)
196
- enemy_bullets = jnp.stack((enemy_bullets_source, enemy_bullets_target), axis=1)
197
-
198
- bullet_seq.append((ally_bullets, enemy_bullets))
199
- return bullet_seq
200
-
201
-
202
- # test the visualizer
203
- if __name__ == "__main__":
204
- from parabellum import Parabellum, Scenario
205
- from jax import random, numpy as jnp
206
-
207
- s = Scenario(jnp.array([[16, 0]]),
208
- jnp.array([[0, 32]]) * 8,
209
- jnp.zeros((19,), dtype=jnp.uint8),
210
- 9,
211
- 10)
212
- env = Parabellum(map_width=32, map_height=32, walls_cause_death=False, scenario=s)
213
- rng, key = random.split(random.PRNGKey(0))
214
- obs, state = env.reset(key)
215
- state_seq = []
216
- for step in range(50):
217
- rng, key = random.split(rng)
218
- key_act = random.split(key, len(env.agents))
219
- actions = {
220
- agent: jnp.array(1)
221
- for i, agent in enumerate(env.agents)
222
- }
223
- state_seq.append((key, state, actions))
224
- rng, key_step = random.split(rng)
225
- obs, state, reward, done, infos = env.step(key_step, state, actions)
226
-
227
- vis = Visualizer(env, state_seq)
228
- vis.animate()
229
-
230
-
@@ -1,56 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: parabellum
3
- Version: 0.2.13
4
- Summary: Parabellum environment for parallel warfare simulation
5
- Home-page: https://github.com/syrkis/parabellum
6
- License: MIT
7
- Keywords: warfare,simulation,parallel,environment
8
- Author: Noah Syrkis
9
- Author-email: desk@syrkis.com
10
- Requires-Python: >=3.11,<4.0
11
- Classifier: License :: OSI Approved :: MIT License
12
- Classifier: Programming Language :: Python :: 3
13
- Classifier: Programming Language :: Python :: 3.11
14
- Classifier: Programming Language :: Python :: 3.12
15
- Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
16
- Requires-Dist: jaxmarl (==0.0.3)
17
- Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
18
- Requires-Dist: moviepy (>=1.0.3,<2.0.0)
19
- Requires-Dist: poetry (>=1.8.3,<2.0.0)
20
- Requires-Dist: pygame (>=2.5.2,<3.0.0)
21
- Requires-Dist: tqdm (>=4.66.4,<5.0.0)
22
- Project-URL: Repository, https://github.com/syrkis/parabellum
23
- Description-Content-Type: text/markdown
24
-
25
- # parabellum
26
-
27
- Parabellum is an ultra-scalable, high-performance warfare simulation engine.
28
- It is based on JaxMARL's SMAX environment, but has been heavily modified to
29
- support a wide range of new features and improvements.
30
-
31
- ## Installation
32
-
33
- Install through PyPI:
34
-
35
- ```bash
36
- pip install parabellum
37
- ```
38
-
39
- ## Usage
40
-
41
- ```python
42
- import parabellum as pb
43
- ```
44
-
45
- ## TODO
46
-
47
- - [ ] Parallel pygame vis
48
- - [ ] Color for health?
49
- - [ ] Add the ability to see ongoing game.
50
- - [ ] Bug test friendly fire.
51
- - [ ] Start sim from arbitrary state.
52
- - [ ] Save when the episode ends in some state/obs variable
53
- - [ ] Look for the source of the bug when using more Allies than Enemies
54
- - [ ] Y inversed axis for parabellum visualization
55
- - [ ] Units see through obstacles?
56
-
@@ -1,9 +0,0 @@
1
- parabellum/.ipynb_checkpoints/__init__-checkpoint.py,sha256=Yt1RkvkGIJdps0Axpz0ouu-Aaa07032kX04l1l7LXTw,118
2
- parabellum/.ipynb_checkpoints/env-checkpoint.py,sha256=Z0PD3MJb9Amxl84MMtghTCF92Gr4ln9qSyRx2DSY15Y,11589
3
- parabellum/.ipynb_checkpoints/vis-checkpoint.py,sha256=7zmFqU99gXSW6ueTeEp3CKMJ9XmrTgJkVEpktdLWd_4,8999
4
- parabellum/__init__.py,sha256=Yt1RkvkGIJdps0Axpz0ouu-Aaa07032kX04l1l7LXTw,118
5
- parabellum/env.py,sha256=Z0PD3MJb9Amxl84MMtghTCF92Gr4ln9qSyRx2DSY15Y,11589
6
- parabellum/vis.py,sha256=7zmFqU99gXSW6ueTeEp3CKMJ9XmrTgJkVEpktdLWd_4,8999
7
- parabellum-0.2.13.dist-info/METADATA,sha256=jIs4QuEQ7Act04HK2cwKw4PaXOi9Y01ucQ2a4np4-KU,1627
8
- parabellum-0.2.13.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
9
- parabellum-0.2.13.dist-info/RECORD,,