parabellum-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .env import Parabellum
+ from .vis import Visualizer
+
+ __all__ = ["Parabellum", "Visualizer"]
parabellum/env.py ADDED
@@ -0,0 +1,257 @@
+ """Parabellum environment based on SMAX"""
+
+ import jax
+ import jax.numpy as jnp
+ from jax import random
+ from flax.struct import dataclass
+ import chex
+ from jaxmarl.environments.smax.smax_env import State, SMAX
+ from typing import Tuple, Dict
+ from functools import partial
+
+
+ @dataclass
+ class Scenario:
+     """Parabellum scenario"""
+
+     # obstacles are line segments running from coords to coords + deltas
+     obstacle_coords: chex.Array
+     obstacle_deltas: chex.Array
+
+     unit_types: chex.Array
+     num_allies: int = 9
+     num_enemies: int = 10
+
+     smacv2_position_generation: bool = False
+     smacv2_unit_type_generation: bool = False
+
+
+ # default scenario
+ scenarios = {
+     "default": Scenario(
+         jnp.array([[6, 10], [26, 10]]) * 8,
+         jnp.array([[0, 12], [0, 1]]) * 8,
+         jnp.zeros((19,), dtype=jnp.uint8),
+     )
+ }
+
+
+ class Parabellum(SMAX):
+     def __init__(
+         self,
+         scenario: Scenario = scenarios["default"],
+         unit_type_attack_blasts=jnp.full((6,), 8.0),  # blast radius per unit type
+         **kwargs,
+     ):
+         super().__init__(scenario=scenario, **kwargs)
+         self.unit_type_attack_blasts = unit_type_attack_blasts
+         self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
+         self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)
+         self.max_steps = 200
+         # _world_step below overrides the super's method
+
+     @partial(jax.jit, static_argnums=(0,))  # replaces the super's _world_step
+     def _world_step(  # modified version of JaxMARL's SMAX _world_step
+         self,
+         key: chex.PRNGKey,
+         state: State,
+         actions: Tuple[chex.Array, chex.Array],
+     ) -> State:
+         def update_position(idx, vec):
+             # Compute the movements slightly strangely.
+             # The velocities below are for diagonal directions
+             # because these are easier to encode as actions than the four
+             # cardinal directions. Then rotate the velocity 45
+             # degrees anticlockwise to compute the movement.
+             pos = state.unit_positions[idx]
+             new_pos = (
+                 pos
+                 + vec
+                 * self.unit_type_velocities[state.unit_types[idx]]
+                 * self.time_per_step
+             )
+             # avoid going out of bounds
+             new_pos = jnp.maximum(
+                 jnp.minimum(new_pos, jnp.array([self.map_width, self.map_height])),
+                 jnp.zeros((2,)),
+             )
+
+             #######################################################################
+             ############################################ avoid going into obstacles
+
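+             # standard 2D segment-intersection test: the movement segment
+             # (pos -> new_pos) crosses an obstacle segment (obs -> obs_end)
+             # iff the endpoints of each segment lie on opposite sides of the
+             # other, which the signs of the cross products below detect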
+             @partial(jax.vmap, in_axes=(None, None, 0, 0))
+             def inter_fn(pos, new_pos, obs, obs_end):
+                 d1 = jnp.cross(obs - pos, new_pos - pos)
+                 d2 = jnp.cross(obs_end - pos, new_pos - pos)
+                 d3 = jnp.cross(pos - obs, obs_end - obs)
+                 d4 = jnp.cross(new_pos - obs, obs_end - obs)
+                 return (d1 * d2 < 0) & (d3 * d4 < 0)
+
+             obs = self.obstacle_coords
+             obs_end = obs + self.obstacle_deltas
+             inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
+             new_pos = jnp.where(inters, pos, new_pos)
+
+             #######################################################################
+             #######################################################################
+
+             return new_pos
+
+         #######################################################################
+         ######################################### units close enough to get hit
+
+         def bystander_fn(attacked_idx):
+             # mask of units within the attacked unit's blast radius
+             # (ones multiplied by a boolean mask, so hit units keep a 1)
+             idxs = jnp.ones((self.num_agents,))
+             idxs *= (
+                 jnp.linalg.norm(
+                     state.unit_positions - state.unit_positions[attacked_idx], axis=-1
+                 )
+                 < self.unit_type_attack_blasts[state.unit_types[attacked_idx]]
+             )
+             return idxs
+
+         #######################################################################
+         #######################################################################
+
+         def update_agent_health(idx, action, key):
+             # for team 1, the attack actions are labelled in
+             # reverse order because that is the order in which
+             # the agents are observed
+             attacked_idx = jax.lax.cond(
+                 idx < self.num_allies,
+                 lambda: action + self.num_allies - self.num_movement_actions,
+                 lambda: self.num_allies - 1 - (action - self.num_movement_actions),
+             )
+             # deal with no-op attack actions (i.e. agents that are moving instead)
+             attacked_idx = jax.lax.select(
+                 action < self.num_movement_actions, idx, attacked_idx
+             )
+
+             attack_valid = (
+                 (
+                     jnp.linalg.norm(
+                         state.unit_positions[idx] - state.unit_positions[attacked_idx]
+                     )
+                     < self.unit_type_attack_ranges[state.unit_types[idx]]
+                 )
+                 & state.unit_alive[idx]
+                 & state.unit_alive[attacked_idx]
+             )
+             attack_valid = attack_valid & (idx != attacked_idx)
+             attack_valid = attack_valid & (state.unit_weapon_cooldowns[idx] <= 0.0)
+             health_diff = jax.lax.select(
+                 attack_valid,
+                 -self.unit_type_attacks[state.unit_types[idx]],
+                 0.0,
+             )
+             # design choice based on the pysc2 randomness details.
+             # See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
+
+             #########################################################
+             ############################### Add bystander health diff
+
+             bystander_idxs = bystander_fn(attacked_idx)
+             bystander_valid = (
+                 jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
+                 .astype(jnp.bool_)
+                 .astype(jnp.float32)
+             )
+             bystander_health_diff = (
+                 bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
+             )
+
+             #########################################################
+             #########################################################
+
+             cooldown_deviation = jax.random.uniform(
+                 key, minval=-self.time_per_step, maxval=2 * self.time_per_step
+             )
+             cooldown = (
+                 self.unit_type_weapon_cooldowns[state.unit_types[idx]]
+                 + cooldown_deviation
+             )
+             cooldown_diff = jax.lax.select(
+                 attack_valid,
+                 # subtract the current cooldown because we are
+                 # going to add it back. This way we effectively
+                 # set the new cooldown to `cooldown`
+                 cooldown - state.unit_weapon_cooldowns[idx],
+                 -self.time_per_step,
+             )
+             return (
+                 health_diff,
+                 attacked_idx,
+                 cooldown_diff,
+                 (bystander_health_diff, bystander_idxs),
+             )
+
+         def perform_agent_action(idx, action, key):
+             movement_action, attack_action = action
+             new_pos = update_position(idx, movement_action)
+             health_diff, attacked_idxes, cooldown_diff, bystander = (
+                 update_agent_health(idx, attack_action, key)
+             )
+
+             return new_pos, (health_diff, attacked_idxes), cooldown_diff, bystander
+
+         keys = jax.random.split(key, num=self.num_agents)
+         pos, (health_diff, attacked_idxes), cooldown_diff, bystander = jax.vmap(
+             perform_agent_action
+         )(jnp.arange(self.num_agents), actions, keys)
+         # Multiple enemies can attack the same unit.
+         # We have `(health_diff, attacked_idx)` pairs.
+         # `jax.lax.scatter_add` aggregates these exactly
+         # in the way we want -- duplicate idxes will have their
+         # health differences added together. However, it is a
+         # super thin wrapper around the XLA scatter operation,
+         # which has this bonkers syntax and requires this dnums
+         # parameter. The usage here was inferred from a test:
+         # https://github.com/google/jax/blob/main/tests/lax_test.py#L2296
+         dnums = jax.lax.ScatterDimensionNumbers(
+             update_window_dims=(),
+             inserted_window_dims=(0,),
+             scatter_dims_to_operand_dims=(0,),
+         )
+         unit_health = jnp.maximum(
+             jax.lax.scatter_add(
+                 state.unit_health,
+                 jnp.expand_dims(attacked_idxes, 1),
+                 health_diff,
+                 dnums,
+             ),
+             0.0,
+         )
+
+         #########################################################
+         ############################## applying bystander damage
+
+         bystander_health_diff, _ = bystander  # tuple is (health_diff, idxs)
+         # the diffs are negative, so add them; axis 0 sums over attackers
+         unit_health = jnp.maximum(unit_health + bystander_health_diff.sum(axis=0), 0.0)
+
+         #########################################################
+         #########################################################
+
+         unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
+         state = state.replace(
+             unit_health=unit_health,
+             unit_positions=pos,
+             unit_weapon_cooldowns=unit_weapon_cooldowns,
+         )
+         return state
+
+
+ if __name__ == "__main__":
+     env = Parabellum(map_width=256, map_height=256)
+     rng, key = random.split(random.PRNGKey(0))
+     obs, state = env.reset(key)
+     state_seq = []
+     for step in range(100):
+         rng, key = random.split(rng)
+         key_act = random.split(key, len(env.agents))
+         # sample random movement actions (0-4) for every agent
+         actions = {
+             agent: jax.random.randint(key_act[i], (), 0, 5)
+             for i, agent in enumerate(env.agents)
+         }
+         obs, state, _, _, _ = env.step(key, state, actions)
+         state_seq.append((obs, state, actions))
parabellum/vis.py ADDED
@@ -0,0 +1,224 @@
+ """Visualizer for the Parabellum environment"""
+
+ from tqdm import tqdm
+ import jax.numpy as jnp
+ from functools import partial
+ import darkdetect
+ import pygame
+ from moviepy.editor import ImageSequenceClip
+ from jaxmarl.environments.multi_agent_env import MultiAgentEnv
+ from jaxmarl.viz.visualizer import SMAXVisualizer
+
+
+ # constants
+
+ action_to_symbol = {0: "↑", 1: "→", 2: "↓", 3: "←", 4: "Ø"}
+
+
+ class Visualizer(SMAXVisualizer):
+     def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
+         super().__init__(env, state_seq, reward_seq)
+         # remove fig and ax from super
+         self.fig, self.ax = None, None
+         self.bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
+         self.fg = (235, 235, 235) if darkdetect.isDark() else (20, 20, 20)
+         self.s = 1000  # rendered surface is s x s pixels
+         self.scale = self.s / self.env.map_width
+         # keep the actions ourselves (works around a SMAX visualizer bug)
+         self.action_seq = [action for _, _, action in state_seq]
+         self.bullet_seq = bullet_fn(self.env, self.state_seq)
+
+     def render_agents(self, screen, state):
+         time_tuple = zip(
+             state.unit_positions,
+             state.unit_teams,
+             state.unit_types,
+             state.unit_health,
+         )
+         for idx, (pos, team, kind, hp) in enumerate(time_tuple):
+             face_col = self.fg if int(team.item()) == 0 else self.bg
+             pos = tuple((pos * self.scale).tolist())
+
+             # draw the agent (radius shrinks with remaining health)
+             if hp > 0:
+                 hp_frac = hp / self.env.unit_type_health[kind]
+                 unit_size = self.env.unit_type_radiuses[kind]
+                 radius = jnp.ceil((unit_size * self.scale * hp_frac)).astype(int) + 1
+                 pygame.draw.circle(screen, face_col, pos, radius)
+                 pygame.draw.circle(screen, self.fg, pos, radius, 1)
+
+             # optionally draw the sight range
+             # sight_range = self.env.unit_type_sight_ranges[kind] * self.scale
+             # pygame.draw.circle(screen, self.fg, pos, sight_range.astype(int), 2)
+
+             # optionally draw the attack range
+             # attack_range = self.env.unit_type_attack_ranges[kind] * self.scale
+             # pygame.draw.circle(screen, self.fg, pos, attack_range.astype(int), 2)
+         # (which agents are being shot is worked out by the bullet functions below)
+
+     def render_action(self, screen, action):
+         def coord_fn(idx, n, team):
+             return (
+                 self.s / 20 if team == 0 else self.s - self.s / 20,
+                 # vertically centered so that n / 2 is above and below the center
+                 self.s / 2 - (n / 2) * self.s / 20 + idx * self.s / 20,
+             )
+
+         font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
+         for team, prefix, n in ((0, "ally", self.env.num_allies), (1, "enemy", self.env.num_enemies)):
+             for idx in range(n):
+                 symb = action_to_symbol.get(action[f"{prefix}_{idx}"].astype(int).item(), "Ø")
+                 text = font.render(symb, True, self.fg)
+                 screen.blit(text, coord_fn(idx, n, team))
+
+     def render_obstacles(self, screen):
+         for coord, delta in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
+             end = tuple(((coord + delta) * self.scale).tolist())
+             start = tuple((coord * self.scale).tolist())
+             pygame.draw.line(screen, self.fg, start, end, 5)
+
+     def render_bullets(self, screen, bullets, jdx):
+         jdx += 1
+         ally_bullets, enemy_bullets = bullets
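+         # each env step is expanded into 8 rendered frames (see animate), so
+         # jdx / 8 is the fraction of the bullet's flight from source to target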
+         for source, target in (*ally_bullets, *enemy_bullets):
+             position = (source + (target - source) * jdx / 8) * self.scale
+             pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
+
+     def animate(self, save_fname: str = "parabellum.mp4"):
+         if not self.have_expanded:
+             self.expand_state_seq()
+         frames = []  # frames for the video
+         pygame.init()  # initialize pygame
+         for idx, (_, state, _) in tqdm(
+             enumerate(self.state_seq), total=len(self.state_seq)
+         ):
+             screen = pygame.Surface(
+                 (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
+             )
+             screen.fill(self.bg)  # fill the screen with the background color
+
+             self.render_agents(screen, state)  # render the agents
+             self.render_action(screen, self.action_seq[idx // 8])
+             self.render_obstacles(screen)  # render the obstacles
+
+             # bullets
+             if idx < len(self.bullet_seq) * 8:
+                 bullets = self.bullet_seq[idx // 8]
+                 self.render_bullets(screen, bullets, idx % 8)
+
+             # transpose to (height, width, channel) and copy, so the frame
+             # does not keep the surface locked
+             frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1).copy())
+
+         # save the frames as a video
+         clip = ImageSequenceClip(frames, fps=48)
+         clip.write_videofile(save_fname, fps=48)
+         # clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
+         pygame.quit()
+
+         return clip
+
+
+ # bullet functions
+ def dist_fn(env, pos):  # compute the distances between all ally and enemy agents
+     delta = pos[None, :, :] - pos[:, None, :]
+     dist = jnp.sqrt((delta**2).sum(axis=2))
+     dist = dist[: env.num_allies, env.num_allies :]
+     return {"ally": dist, "enemy": dist.T}
+
+
+ def range_fn(env, dists, ranges):  # compute which targets are in range
+     ally_range = dists["ally"] < ranges[: env.num_allies][:, None]
+     enemy_range = dists["enemy"] < ranges[env.num_allies :][:, None]
+     return {"ally": ally_range, "enemy": enemy_range}
+
+
+ def target_fn(acts, in_range, team):  # compute the one-hot valid targets
+     t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T
+     # first 5 actions are moves; attacks are action - 5, non-attacks become -1
+     t_targets = jnp.where(t_acts > 4, t_acts - 5, -1)
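+     # one-hot encode targets with an (m + 1)-row identity matrix: index -1
+     # (not attacking) picks the extra last row, and dropping the last column
+     # leaves those rows all-zero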
+     t_attacks = jnp.eye(in_range[team].shape[1] + 1)[t_targets][:, :-1]
+     return t_attacks * in_range[team]  # one-hot valid targets
+
+
+ def attack_fn(env, state_seq):  # one-hot attack list per step
+     attacks = []
+     for _, state, acts in state_seq:
+         dists = dist_fn(env, state.unit_positions)
+         ranges = env.unit_type_attack_ranges[state.unit_types]
+         in_range = range_fn(env, dists, ranges)
+         target = partial(target_fn, acts, in_range)
+         attack = {"ally": target("ally"), "enemy": target("enemy")}
+         attacks.append(attack)
+     return attacks
+
+
+ def bullet_fn(env, states):
+     bullet_seq = []
+     attack_seq = attack_fn(env, states)
+
+     def aux_fn(team):
+         # (attacker, target) index pairs where the one-hot attack matrix is 1
+         bullets = jnp.stack(jnp.where(one_hot[team] == 1)).T
+         # bullets = bullets.at[:, 2 if team == "ally" else 1].add(env.num_allies)
+         return bullets
+
+     state_zip = zip(states[:-1], states[1:])
+     for i, ((_, state, _), (_, n_state, _)) in enumerate(state_zip):
+         one_hot = attack_seq[i]
+         ally_bullets, enemy_bullets = aux_fn("ally"), aux_fn("enemy")
+
+         ally_bullets_source = state.unit_positions[ally_bullets[:, 0]]
+         enemy_bullets_target = n_state.unit_positions[enemy_bullets[:, 1]]
+
+         enemy_bullets_source = state.unit_positions[
+             enemy_bullets[:, 0] + env.num_allies
+         ]
+         ally_bullets_target = n_state.unit_positions[
+             ally_bullets[:, 1] + env.num_allies
+         ]
+
+         ally_bullets = jnp.stack((ally_bullets_source, ally_bullets_target), axis=1)
+         enemy_bullets = jnp.stack((enemy_bullets_source, enemy_bullets_target), axis=1)
+
+         bullet_seq.append((ally_bullets, enemy_bullets))
+     return bullet_seq
+
+
+ # test the visualizer
+ if __name__ == "__main__":
+     from parabellum import Parabellum
+     from jax import random
+
+     env = Parabellum(map_width=32, map_height=32)
+     rng, key = random.split(random.PRNGKey(0))
+     obs, state = env.reset(key)
+     state_seq = []
+     for step in range(100):
+         rng, key = random.split(rng)
+         key_act = random.split(key, len(env.agents))
+         actions = {
+             agent: env.action_space(agent).sample(key_act[i])
+             for i, agent in enumerate(env.agents)
+         }
+         state_seq.append((key, state, actions))
+         rng, key_step = random.split(rng)
+         obs, state, reward, done, infos = env.step(key_step, state, actions)
+
+     vis = Visualizer(env, state_seq)
+     vis.animate()
parabellum-0.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,37 @@
+ Metadata-Version: 2.1
+ Name: parabellum
+ Version: 0.1.0
+ Summary: Parabellum environment for parallel warfare simulation
+ Author: Noah Syrkis
+ Author-email: noah@syrkis.com
+ Requires-Python: >=3.11,<4.0
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
+ Requires-Dist: jaxmarl (==0.0.3)
+ Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
+ Requires-Dist: moviepy (>=1.0.3,<2.0.0)
+ Requires-Dist: pygame (>=2.5.2,<3.0.0)
+ Description-Content-Type: text/markdown
+
+ # parabellum
+
+ Parabellum is an ultra-scalable, high-performance warfare simulation engine.
+ It is based on JaxMARL's SMAX environment, but has been heavily modified to
+ support a wide range of new features and improvements.
+
+ ## Installation
+
+ Install through PyPI:
+
+ ```bash
+ pip install parabellum
+ ```
+
+ ## Usage
+
+ ```python
+ import parabellum as pb
+ ```
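+
+ A minimal rollout-and-render sketch, mirroring the `__main__` blocks of
+ `parabellum/env.py` and `parabellum/vis.py` (`Parabellum` and `Visualizer`
+ are the package's two exports):
+
+ ```python
+ import parabellum as pb
+ from jax import random
+
+ env = pb.Parabellum(map_width=256, map_height=256)
+ rng, key = random.split(random.PRNGKey(0))
+ obs, state = env.reset(key)
+ state_seq = []
+ for _ in range(100):
+     rng, key = random.split(rng)
+     key_act = random.split(key, len(env.agents))
+     actions = {
+         agent: env.action_space(agent).sample(key_act[i])
+         for i, agent in enumerate(env.agents)
+     }
+     state_seq.append((key, state, actions))  # (key, state, actions) triples
+     rng, key_step = random.split(rng)
+     obs, state, reward, done, infos = env.step(key_step, state, actions)
+
+ vis = pb.Visualizer(env, state_seq)
+ vis.animate()  # writes parabellum.mp4
+ ```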
+
parabellum-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,6 @@
+ parabellum/__init__.py,sha256=CQj_Z18dyAU366U-UAJt8Rggq2mwArfCrn8jfIaLAjA,96
+ parabellum/env.py,sha256=KE_wDPUbvHu02cMl9elT0cV63xVZ3DCpZIBuBM0GoyM,10074
+ parabellum/vis.py,sha256=SWlSj1z09qAXUH2W4LeMHWoICZL5mANytLwR0qDXF-A,8812
+ parabellum-0.1.0.dist-info/METADATA,sha256=jWirrM6EnME0Y3_PetV-dAwifNnLUbSKs-QJsmMCciI,934
+ parabellum-0.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ parabellum-0.1.0.dist-info/RECORD,,
parabellum-0.1.0.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: poetry-core 1.9.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any