parabellum 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
parabellum/.ipynb_checkpoints/__init__-checkpoint.py ADDED
@@ -0,0 +1,4 @@
+ from .env import Parabellum, Scenario
+ from .vis import Visualizer
+
+ __all__ = ["Parabellum", "Visualizer", "Scenario"]
parabellum/.ipynb_checkpoints/env-checkpoint.py ADDED
@@ -0,0 +1,296 @@
+ """Parabellum environment based on SMAX"""
+
+ import jax.numpy as jnp
+ import jax
+ import numpy as np
+ from jax import random
+ from jax import jit
+ from flax.struct import dataclass
+ import chex
+ from jaxmarl.environments.smax.smax_env import State, SMAX
+ from typing import Tuple, Dict
+ from functools import partial
+
+
+ @dataclass
+ class Scenario:
+     """Parabellum scenario"""
+
+     obstacle_coords: chex.Array
+     obstacle_deltas: chex.Array
+
+     unit_types: chex.Array
+     num_allies: int
+     num_enemies: int
+
+     smacv2_position_generation: bool = False
+     smacv2_unit_type_generation: bool = False
+
+
+ # default scenario
+ scenarios = {
+     "default": Scenario(
+         jnp.array([[6, 10], [26, 10]]) * 8,
+         jnp.array([[0, 12], [0, 1]]) * 8,
+         jnp.zeros((19,), dtype=jnp.uint8),
+         9,
+         10,
+     )
+ }
+
+
+ class Parabellum(SMAX):
+     def __init__(
+         self,
+         scenario: Scenario = scenarios["default"],
+         unit_type_attack_blasts=jnp.array([0, 0, 0, 0, 0, 0]) + 8,
+         **kwargs,
+     ):
+         super().__init__(scenario=scenario, **kwargs)
+         self.unit_type_attack_blasts = unit_type_attack_blasts
+         self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
+         self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)
+         self.max_steps = 200
+         # overwrite supers _world_step method
+
+
+     def _push_units_away(self, state: State, firmness: float = 1.0):  # we do it inside the _world_step to allow more obstacles constraints
+         return state
+
+     def _our_push_units_away(self, pos, unit_types, firmness: float = 1.0):  # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
+         delta_matrix = pos[:, None] - pos[None, :]
+         dist_matrix = (
+             jnp.linalg.norm(delta_matrix, axis=-1)
+             + jnp.identity(self.num_agents)
+             + 1e-6
+         )
+         radius_matrix = (
+             self.unit_type_radiuses[unit_types][:, None]
+             + self.unit_type_radiuses[unit_types][None, :]
+         )
+         overlap_term = jax.nn.relu(radius_matrix / dist_matrix - 1.0)
+         unit_positions = (
+             pos
+             + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
+         )
+         return unit_positions
+
+     @partial(jax.jit, static_argnums=(0,))  # replace the _world_step method
+     def _world_step(  # modified version of JaxMARL's SMAX _world_step
+         self,
+         key: chex.PRNGKey,
+         state: State,
+         actions: Tuple[chex.Array, chex.Array],
+     ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
+
+         @partial(jax.vmap, in_axes=(None, None, 0, 0))
+         def inter_fn(pos, new_pos, obs, obs_end):
+             d1 = jnp.cross(obs - pos, new_pos - pos)
+             d2 = jnp.cross(obs_end - pos, new_pos - pos)
+             d3 = jnp.cross(pos - obs, obs_end - obs)
+             d4 = jnp.cross(new_pos - obs, obs_end - obs)
+             return (d1 * d2 <= 0) & (d3 * d4 <= 0)
+
+         def update_position(idx, vec):
+             # Compute the movements slightly strangely.
+             # The velocities below are for diagonal directions
+             # because these are easier to encode as actions than the four
+             # diagonal directions. Then rotate the velocity 45
+             # degrees anticlockwise to compute the movement.
+             pos = state.unit_positions[idx]
+             new_pos = (
+                 pos
+                 + vec
+                 * self.unit_type_velocities[state.unit_types[idx]]
+                 * self.time_per_step
+             )
+             # avoid going out of bounds
+             new_pos = jnp.maximum(
+                 jnp.minimum(new_pos, jnp.array([self.map_width, self.map_height])),
+                 jnp.zeros((2,)),
+             )
+
+             #######################################################################
+             ############################################ avoid going into obstacles
+             obs = self.obstacle_coords
+             obs_end = obs + self.obstacle_deltas
+             inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
+             new_pos = jnp.where(inters, pos, new_pos)
+
+             #######################################################################
+             #######################################################################
+
+             return new_pos
+
+         #######################################################################
+         ######################################### units close enough to get hit
+
+         def bystander_fn(attacked_idx):
+             idxs = jnp.zeros((self.num_agents,))
+             idxs *= (
+                 jnp.linalg.norm(
+                     state.unit_positions - state.unit_positions[attacked_idx], axis=-1
+                 )
+                 < self.unit_type_attack_blasts[state.unit_types[attacked_idx]]
+             )
+             return idxs
+
+         #######################################################################
+         #######################################################################
+
+         def update_agent_health(idx, action, key):  # TODO: add attack blasts
+             # for team 1, their attack actions are labelled in
+             # reverse order because that is the order they are
+             # observed in
+             attacked_idx = jax.lax.cond(
+                 idx < self.num_allies,
+                 lambda: action + self.num_allies - self.num_movement_actions,
+                 lambda: self.num_allies - 1 - (action - self.num_movement_actions),
+             )
+             # deal with no-op attack actions (i.e. agents that are moving instead)
+             attacked_idx = jax.lax.select(
+                 action < self.num_movement_actions, idx, attacked_idx
+             )
+
+             attack_valid = (
+                 (
+                     jnp.linalg.norm(
+                         state.unit_positions[idx] - state.unit_positions[attacked_idx]
+                     )
+                     < self.unit_type_attack_ranges[state.unit_types[idx]]
+                 )
+                 & state.unit_alive[idx]
+                 & state.unit_alive[attacked_idx]
+             )
+             attack_valid = attack_valid & (idx != attacked_idx)
+             attack_valid = attack_valid & (state.unit_weapon_cooldowns[idx] <= 0.0)
+             health_diff = jax.lax.select(
+                 attack_valid,
+                 -self.unit_type_attacks[state.unit_types[idx]],
+                 0.0,
+             )
+             # design choice based on the pysc2 randomness details.
+             # See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
+
+             #########################################################
+             ############################### Add bystander health diff
+
+             bystander_idxs = bystander_fn(attacked_idx)  # TODO: use
+             bystander_valid = (
+                 jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
+                 .astype(jnp.bool_)
+                 .astype(jnp.float32)
+             )
+             bystander_health_diff = (
+                 bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
+             )
+
+             #########################################################
+             #########################################################
+
+             cooldown_deviation = jax.random.uniform(
+                 key, minval=-self.time_per_step, maxval=2 * self.time_per_step
+             )
+             cooldown = (
+                 self.unit_type_weapon_cooldowns[state.unit_types[idx]]
+                 + cooldown_deviation
+             )
+             cooldown_diff = jax.lax.select(
+                 attack_valid,
+                 # subtract the current cooldown because we are
+                 # going to add it back. This way we effectively
+                 # set the new cooldown to `cooldown`
+                 cooldown - state.unit_weapon_cooldowns[idx],
+                 -self.time_per_step,
+             )
+             return (
+                 health_diff,
+                 attacked_idx,
+                 cooldown_diff,
+                 (bystander_health_diff, bystander_idxs),
+             )
+
+         def perform_agent_action(idx, action, key):
+             movement_action, attack_action = action
+             new_pos = update_position(idx, movement_action)
+             health_diff, attacked_idxes, cooldown_diff, (bystander) = (
+                 update_agent_health(idx, attack_action, key)
+             )
+
+             return new_pos, (health_diff, attacked_idxes), cooldown_diff, bystander
+
+         keys = jax.random.split(key, num=self.num_agents)
+         pos, (health_diff, attacked_idxes), cooldown_diff, bystander = jax.vmap(
+             perform_agent_action
+         )(jnp.arange(self.num_agents), actions, keys)
+
+         # checked that no unit passed through an obstacles
+         new_pos = self._our_push_units_away(pos, state.unit_types)
+
+         # avoid going into obstacles after being pushed
+         obs = self.obstacle_coords
+         obs_end = obs + self.obstacle_deltas
+
+         def check_obstacles(pos, new_pos, obs, obs_end):
+             inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
+             return jnp.where(inters, pos, new_pos)
+
+         pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(pos, new_pos, obs, obs_end)
+
+         # Multiple enemies can attack the same unit.
+         # We have `(health_diff, attacked_idx)` pairs.
+         # `jax.lax.scatter_add` aggregates these exactly
+         # in the way we want -- duplicate idxes will have their
+         # health differences added together. However, it is a
+         # super thin wrapper around the XLA scatter operation,
+         # which has this bonkers syntax and requires this dnums
+         # parameter. The usage here was inferred from a test:
+         # https://github.com/google/jax/blob/main/tests/lax_test.py#L2296
+         dnums = jax.lax.ScatterDimensionNumbers(
+             update_window_dims=(),
+             inserted_window_dims=(0,),
+             scatter_dims_to_operand_dims=(0,),
+         )
+         unit_health = jnp.maximum(
+             jax.lax.scatter_add(
+                 state.unit_health,
+                 jnp.expand_dims(attacked_idxes, 1),
+                 health_diff,
+                 dnums,
+             ),
+             0.0,
+         )
+
+         #########################################################
+         ############################ subtracting bystander health
+
+         _, bystander_health_diff = bystander
+         unit_health -= bystander_health_diff.sum(axis=0)  # might be axis=1
+
+         #########################################################
+         #########################################################
+
+         unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
+         state = state.replace(
+             unit_health=unit_health,
+             unit_positions=pos,
+             unit_weapon_cooldowns=unit_weapon_cooldowns,
+         )
+         return state
+
+ if __name__ == "__main__":
+     env = Parabellum(map_width=256, map_height=256)
+     rng, key = random.split(random.PRNGKey(0))
+     obs, state = env.reset(key)
+     state_seq = []
+     for step in range(100):
+         rng, key = random.split(rng)
+         key_act = random.split(key, len(env.agents))
+         actions = {
+             agent: jax.random.randint(key_act[i], (), 0, 5)
+             for i, agent in enumerate(env.agents)
+         }
+         _, state, _, _, _ = env.step(key, state, actions)
+         state_seq.append((obs, state, actions))
+
+
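The obstacle check in `_world_step` is a standard 2-D segment-intersection test: each cross product reports which side of one segment an endpoint of the other lies on, and a move is blocked when the signs straddle in both directions. A minimal standalone sketch of the same test, with toy coordinates that are not from the package:

```python
import jax.numpy as jnp

def segments_intersect(pos, new_pos, obs, obs_end):
    # the sign of the cross product tells which side of a segment a point is on
    d1 = jnp.cross(obs - pos, new_pos - pos)
    d2 = jnp.cross(obs_end - pos, new_pos - pos)
    d3 = jnp.cross(pos - obs, obs_end - obs)
    d4 = jnp.cross(new_pos - obs, obs_end - obs)
    # opposite signs in both directions => the segments cross
    # (<= also catches the degenerate case of an endpoint on the wall)
    return (d1 * d2 <= 0) & (d3 * d4 <= 0)

move_from, move_to = jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0])
wall_a, wall_b = jnp.array([0.0, 2.0]), jnp.array([2.0, 0.0])
print(segments_intersect(move_from, move_to, wall_a, wall_b))  # True -> blocked
```

In the environment the same test is `jax.vmap`-ed over all obstacle segments, and the move is cancelled (`jnp.where(inters, pos, new_pos)`) if any obstacle intersects the movement segment.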
parabellum/.ipynb_checkpoints/vis-checkpoint.py ADDED
@@ -0,0 +1,230 @@
+ """Visualizer for the Parabellum environment"""
+
+ from tqdm import tqdm
+ import jax.numpy as jnp
+ import jax
+ from jax import vmap
+ from functools import partial
+ import darkdetect
+ import pygame
+ from moviepy.editor import ImageSequenceClip
+ from typing import Optional
+ from jaxmarl.environments.multi_agent_env import MultiAgentEnv
+ from jaxmarl.viz.visualizer import SMAXVisualizer
+
+ # default dict
+ from collections import defaultdict
+
+
+ # constants
+ action_to_symbol = {0: "↑", 1: "→", 2: "↓", 3: "←", 4: "Ø"}
+
+
+ class Visualizer(SMAXVisualizer):
+     def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
+         super().__init__(env, state_seq, reward_seq)
+         # remove fig and ax from super
+         self.fig, self.ax = None, None
+         self.bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
+         self.fg = (235, 235, 235) if darkdetect.isDark() else (20, 20, 20)
+         self.s = 1000
+         self.scale = self.s / self.env.map_width
+         self.action_seq = [action for _, _, action in state_seq]  # bcs SMAX bug
+         self.bullet_seq = bullet_fn(self.env, self.state_seq)
+
+     def render_agents(self, screen, state):
+         time_tuple = zip(
+             state.unit_positions,
+             state.unit_teams,
+             state.unit_types,
+             state.unit_health,
+         )
+         for idx, (pos, team, kind, hp) in enumerate(time_tuple):
+             face_col = self.fg if int(team.item()) == 0 else self.bg
+             pos = tuple((pos * self.scale).tolist())
+
+             # draw the agent
+             if hp > 0:
+                 hp_frac = hp / self.env.unit_type_health[kind]
+                 unit_size = self.env.unit_type_radiuses[kind]
+                 radius = jnp.ceil((unit_size * self.scale * hp_frac)).astype(int) + 1
+                 pygame.draw.circle(screen, face_col, pos, radius)
+                 pygame.draw.circle(screen, self.fg, pos, radius, 1)
+
+             # draw the sight range
+             # sight_range = self.env.unit_type_sight_ranges[kind] * self.scale
+             # pygame.draw.circle(screen, self.fg, pos, sight_range.astype(int), 2)
+
+             # draw attack range
+             # attack_range = self.env.unit_type_attack_ranges[kind] * self.scale
+             # pygame.draw.circle(screen, self.fg, pos, attack_range.astype(int), 2)
+             # work out which agents are being shot
+
+     def render_action(self, screen, action):
+         def coord_fn(idx, n, team):
+             return (
+                 self.s / 20 if team == 0 else self.s - self.s / 20,
+                 # vertically centered so that n / 2 is above and below the center
+                 self.s / 2 - (n / 2) * self.s / 20 + idx * self.s / 20,
+             )
+
+         for idx in range(self.env.num_allies):
+             symb = action_to_symbol.get(action[f"ally_{idx}"].astype(int).item(), "Ø")
+             font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
+             text = font.render(symb, True, self.fg)
+             coord = coord_fn(idx, self.env.num_allies, 0)
+             screen.blit(text, coord)
+
+         for idx in range(self.env.num_enemies):
+             symb = action_to_symbol.get(action[f"enemy_{idx}"].astype(int).item(), "Ø")
+             font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
+             text = font.render(symb, True, self.fg)
+             coord = coord_fn(idx, self.env.num_enemies, 1)
+             screen.blit(text, coord)
+
+     def render_obstacles(self, screen):
+         for c, d in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
+             d = tuple(((c + d) * self.scale).tolist())
+             c = tuple((c * self.scale).tolist())
+             pygame.draw.line(screen, self.fg, c, d, 5)
+
+     def render_bullets(self, screen, bullets, jdx):
+         jdx += 1
+         ally_bullets, enemy_bullets = bullets
+         for source, target in ally_bullets:
+             position = source + (target - source) * jdx / 8
+             position *= self.scale
+             pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
+         for source, target in enemy_bullets:
+             position = source + (target - source) * jdx / 8
+             position *= self.scale
+             pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
+
+     def animate(self, save_fname: str = "parabellum.mp4"):
+         if not self.have_expanded:
+             self.expand_state_seq()
+         frames = []  # frames for the video
+         pygame.init()  # initialize pygame
+         for idx, (_, state, _) in tqdm(
+             enumerate(self.state_seq), total=len(self.state_seq)
+         ):
+             screen = pygame.Surface(
+                 (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
+             )
+             screen.fill(self.bg)  # fill the screen with the background color
+
+             self.render_agents(screen, state)  # render the agents
+             self.render_action(screen, self.action_seq[idx // 8])
+             self.render_obstacles(screen)  # render the obstacles
+
+             # bullets
+             if idx < len(self.bullet_seq) * 8:
+                 bullets = self.bullet_seq[idx // 8]
+                 self.render_bullets(screen, bullets, idx % 8)
+
+             # rotate the screen and append to frames
+             frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))
+
+         # save the images
+         clip = ImageSequenceClip(frames, fps=48)
+         clip.write_videofile(save_fname, fps=48)
+         # clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
+         pygame.quit()
+
+         return clip
+
+
+ # functions
+ # bullet functions
+ def dist_fn(env, pos):  # computing the distances between all ally and enemy agents
+     delta = pos[None, :, :] - pos[:, None, :]
+     dist = jnp.sqrt((delta**2).sum(axis=2))
+     dist = dist[: env.num_allies, env.num_allies :]
+     return {"ally": dist, "enemy": dist.T}
+
+
+ def range_fn(env, dists, ranges):  # computing what targets are in range
+     ally_range = dists["ally"] < ranges[: env.num_allies][:, None]
+     enemy_range = dists["enemy"] < ranges[env.num_allies :][:, None]
+     return {"ally": ally_range, "enemy": enemy_range}
+
+
+ def target_fn(acts, in_range, team):  # computing the one hot valid targets
+     t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T
+     t_targets = jnp.where(t_acts > 4, -1, t_acts - 5)  # first 5 are move actions
+     t_attacks = jnp.eye(in_range[team].shape[1] + 1)[t_targets][:, :-1]
+     return t_attacks * in_range[team]  # one hot valid targets
+
+
+ def attack_fn(env, state_seq):  # one hot attack list
+     attacks = []
+     for _, state, acts in state_seq:
+         dists = dist_fn(env, state.unit_positions)
+         ranges = env.unit_type_attack_ranges[state.unit_types]
+         in_range = range_fn(env, dists, ranges)
+         target = partial(target_fn, acts, in_range)
+         attack = {"ally": target("ally"), "enemy": target("enemy")}
+         attacks.append(attack)
+     return attacks
+
+
+ def bullet_fn(env, states):
+     bullet_seq = []
+     attack_seq = attack_fn(env, states)
+
+     def aux_fn(team):
+         bullets = jnp.stack(jnp.where(one_hot[team] == 1)).T
+         # bullets = bullets.at[:, 2 if team == "ally" else 1].add(env.num_allies)
+         return bullets
+
+     state_zip = zip(states[:-1], states[1:])
+     for i, ((_, state, _), (_, n_state, _)) in enumerate(state_zip):
+         one_hot = attack_seq[i]
+         ally_bullets, enemy_bullets = aux_fn("ally"), aux_fn("enemy")
+
+         ally_bullets_source = state.unit_positions[ally_bullets[:, 0]]
+         enemy_bullets_target = n_state.unit_positions[enemy_bullets[:, 1]]
+
+         enemy_bullets_source = state.unit_positions[
+             enemy_bullets[:, 0] + env.num_allies
+         ]
+         ally_bullets_target = n_state.unit_positions[
+             ally_bullets[:, 1] + env.num_allies
+         ]
+
+         ally_bullets = jnp.stack((ally_bullets_source, ally_bullets_target), axis=1)
+         enemy_bullets = jnp.stack((enemy_bullets_source, enemy_bullets_target), axis=1)
+
+         bullet_seq.append((ally_bullets, enemy_bullets))
+     return bullet_seq
+
+
+ # test the visualizer
+ if __name__ == "__main__":
+     from parabellum import Parabellum, Scenario
+     from jax import random, numpy as jnp
+
+     s = Scenario(jnp.array([[16, 0]]),
+                  jnp.array([[0, 32]]) * 8,
+                  jnp.zeros((19,), dtype=jnp.uint8),
+                  9,
+                  10)
+     env = Parabellum(map_width=32, map_height=32, walls_cause_death=False, scenario=s)
+     rng, key = random.split(random.PRNGKey(0))
+     obs, state = env.reset(key)
+     state_seq = []
+     for step in range(50):
+         rng, key = random.split(rng)
+         key_act = random.split(key, len(env.agents))
+         actions = {
+             agent: jnp.array(1)
+             for i, agent in enumerate(env.agents)
+         }
+         state_seq.append((key, state, actions))
+         rng, key_step = random.split(rng)
+         obs, state, reward, done, infos = env.step(key_step, state, actions)
+
+     vis = Visualizer(env, state_seq)
+     vis.animate()
+
+
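`target_fn` above relies on a compact JAX idiom for one-hot encoding with a "none" sentinel: indexing into an identity matrix, where `-1` selects the extra last row, and the trailing `[:, :-1]` slice drops the extra column so "no target" rows come out all zero. A small sketch of just that idiom, with made-up target indices:

```python
import jax.numpy as jnp

n_targets = 3
targets = jnp.array([-1, 0, 2])  # -1 means "no attack" for that unit

# eye(n + 1) adds a spare row/column; -1 wraps to the spare last row,
# which becomes all zeros once the last column is sliced away
one_hot = jnp.eye(n_targets + 1)[targets][:, :-1]
print(one_hot)
# [[0. 0. 0.]
#  [1. 0. 0.]
#  [0. 0. 1.]]
```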
parabellum/__init__.py CHANGED
@@ -1,4 +1,4 @@
- from .env import Parabellum
+ from .env import Parabellum, Scenario
  from .vis import Visualizer

- __all__ = ["Parabellum", "Visualizer"]
+ __all__ = ["Parabellum", "Visualizer", "Scenario"]
parabellum/env.py CHANGED
@@ -20,8 +20,8 @@ class Scenario:
      obstacle_deltas: chex.Array

      unit_types: chex.Array
-     num_allies: int = 9
-     num_enemies: int = 10
+     num_allies: int
+     num_enemies: int

      smacv2_position_generation: bool = False
      smacv2_unit_type_generation: bool = False
@@ -33,6 +33,8 @@ scenarios = {
          jnp.array([[6, 10], [26, 10]]) * 8,
          jnp.array([[0, 12], [0, 1]]) * 8,
          jnp.zeros((19,), dtype=jnp.uint8),
+         9,
+         10,
      )
  }

@@ -51,6 +53,28 @@ class Parabellum(SMAX):
          self.max_steps = 200
          # overwrite supers _world_step method

+
+     def _push_units_away(self, state: State, firmness: float = 1.0):  # we do it inside the _world_step to allow more obstacles constraints
+         return state
+
+     def _our_push_units_away(self, pos, unit_types, firmness: float = 1.0):  # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
+         delta_matrix = pos[:, None] - pos[None, :]
+         dist_matrix = (
+             jnp.linalg.norm(delta_matrix, axis=-1)
+             + jnp.identity(self.num_agents)
+             + 1e-6
+         )
+         radius_matrix = (
+             self.unit_type_radiuses[unit_types][:, None]
+             + self.unit_type_radiuses[unit_types][None, :]
+         )
+         overlap_term = jax.nn.relu(radius_matrix / dist_matrix - 1.0)
+         unit_positions = (
+             pos
+             + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
+         )
+         return unit_positions
+
      @partial(jax.jit, static_argnums=(0,))  # replace the _world_step method
      def _world_step(  # modified version of JaxMARL's SMAX _world_step
          self,
@@ -58,6 +82,15 @@ class Parabellum(SMAX):
          state: State,
          actions: Tuple[chex.Array, chex.Array],
      ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
+
+         @partial(jax.vmap, in_axes=(None, None, 0, 0))
+         def inter_fn(pos, new_pos, obs, obs_end):
+             d1 = jnp.cross(obs - pos, new_pos - pos)
+             d2 = jnp.cross(obs_end - pos, new_pos - pos)
+             d3 = jnp.cross(pos - obs, obs_end - obs)
+             d4 = jnp.cross(new_pos - obs, obs_end - obs)
+             return (d1 * d2 <= 0) & (d3 * d4 <= 0)
+
          def update_position(idx, vec):
              # Compute the movements slightly strangely.
              # The velocities below are for diagonal directions
@@ -79,15 +112,6 @@ class Parabellum(SMAX):

              #######################################################################
              ############################################ avoid going into obstacles
-
-             @partial(jax.vmap, in_axes=(None, None, 0, 0))
-             def inter_fn(pos, new_pos, obs, obs_end):
-                 d1 = jnp.cross(obs - pos, new_pos - pos)
-                 d2 = jnp.cross(obs_end - pos, new_pos - pos)
-                 d3 = jnp.cross(pos - obs, obs_end - obs)
-                 d4 = jnp.cross(new_pos - obs, obs_end - obs)
-                 return (d1 * d2 < 0) & (d3 * d4 < 0)
-
              obs = self.obstacle_coords
              obs_end = obs + self.obstacle_deltas
              inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
@@ -199,6 +223,20 @@ class Parabellum(SMAX):
          pos, (health_diff, attacked_idxes), cooldown_diff, bystander = jax.vmap(
              perform_agent_action
          )(jnp.arange(self.num_agents), actions, keys)
+
+         # checked that no unit passed through an obstacles
+         new_pos = self._our_push_units_away(pos, state.unit_types)
+
+         # avoid going into obstacles after being pushed
+         obs = self.obstacle_coords
+         obs_end = obs + self.obstacle_deltas
+
+         def check_obstacles(pos, new_pos, obs, obs_end):
+             inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
+             return jnp.where(inters, pos, new_pos)
+
+         pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(pos, new_pos, obs, obs_end)
+
          # Multiple enemies can attack the same unit.
          # We have `(health_diff, attacked_idx)` pairs.
          # `jax.lax.scatter_add` aggregates these exactly
@@ -240,7 +278,6 @@ class Parabellum(SMAX):
          )
          return state

-
  if __name__ == "__main__":
      env = Parabellum(map_width=256, map_height=256)
      rng, key = random.split(random.PRNGKey(0))
@@ -255,3 +292,5 @@ if __name__ == "__main__":
          }
          _, state, _, _, _ = env.step(key, state, actions)
          state_seq.append((obs, state, actions))
+
+
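Two things stand out in this change. First, `inter_fn` moved out of `update_position` so that the new `check_obstacles` pass can reuse it after the collision push, and its test was relaxed from `<` to `<=`, so a path that merely grazes a wall endpoint now also counts as blocked. Second, the new `_our_push_units_away` resolves overlaps pairwise: units closer than the sum of their radii are displaced along the separating vector in proportion to the overlap. A self-contained sketch of that push, with assumed toy positions and radii:

```python
import jax
import jax.numpy as jnp

def push_units_away(pos, radii, firmness=1.0):
    delta = pos[:, None] - pos[None, :]  # pairwise separation vectors
    # identity + epsilon keeps the self-distance nonzero (self-delta is zero anyway)
    dist = jnp.linalg.norm(delta, axis=-1) + jnp.identity(len(pos)) + 1e-6
    radius_sum = radii[:, None] + radii[None, :]
    overlap = jax.nn.relu(radius_sum / dist - 1.0)  # positive only when overlapping
    return pos + firmness * jnp.sum(delta * overlap[:, :, None], axis=1) / 2

pos = jnp.array([[0.0, 0.0], [1.0, 0.0]])  # centres 1 apart
radii = jnp.array([1.0, 1.0])              # need 2 apart to stop touching
print(push_units_away(pos, radii))         # ≈ [[-0.5, 0.], [1.5, 0.]]
```

Each unit in the overlapping pair is pushed half the overlap in opposite directions, so after one application the toy pair ends up exactly radius-sum apart.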
parabellum/vis.py CHANGED
@@ -17,7 +17,6 @@ from collections import defaultdict


  # constants
-
  action_to_symbol = {0: "↑", 1: "→", 2: "↓", 3: "←", 4: "Ø"}


@@ -202,18 +201,23 @@ def bullet_fn(env, states):

  # test the visualizer
  if __name__ == "__main__":
-     from parabellum import make
+     from parabellum import Parabellum, Scenario
      from jax import random, numpy as jnp

-     env = make("parabellum", map_width=32, map_height=32)
+     s = Scenario(jnp.array([[16, 0]]),
+                  jnp.array([[0, 32]]) * 8,
+                  jnp.zeros((19,), dtype=jnp.uint8),
+                  9,
+                  10)
+     env = Parabellum(map_width=32, map_height=32, walls_cause_death=False, scenario=s)
      rng, key = random.split(random.PRNGKey(0))
      obs, state = env.reset(key)
      state_seq = []
-     for step in range(100):
+     for step in range(50):
          rng, key = random.split(rng)
          key_act = random.split(key, len(env.agents))
          actions = {
-             agent: env.action_space(agent).sample(key_act[i])
+             agent: jnp.array(1)
              for i, agent in enumerate(env.agents)
          }
          state_seq.append((key, state, actions))
@@ -222,3 +226,5 @@ if __name__ == "__main__":

      vis = Visualizer(env, state_seq)
      vis.animate()
+
+
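For the animation that `render_bullets` drives, each shot is drawn by linearly interpolating from the shooter's position at one environment step to the target's position at the next, across the 8 sub-frames per step that `animate` renders (`idx % 8`). A minimal sketch of that interpolation, with made-up endpoints:

```python
import jax.numpy as jnp

source = jnp.array([0.0, 0.0])  # shooter position at step t
target = jnp.array([8.0, 4.0])  # target position at step t + 1

for jdx in range(8):            # one iteration per sub-frame
    position = source + (target - source) * (jdx + 1) / 8
    print(jdx, position)        # the bullet advances 1/8 of the path per sub-frame
```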
parabellum-0.1.0.dist-info/METADATA → parabellum-0.1.2.dist-info/METADATA CHANGED
@@ -1,10 +1,14 @@
  Metadata-Version: 2.1
  Name: parabellum
- Version: 0.1.0
+ Version: 0.1.2
  Summary: Parabellum environment for parallel warfare simulation
+ Home-page: https://github.com/syrkis/parabellum
+ License: MIT
+ Keywords: warfare,simulation,parallel,environment
  Author: Noah Syrkis
- Author-email: noah@syrkis.com
+ Author-email: desk@syrkis.com
  Requires-Python: >=3.11,<4.0
+ Classifier: License :: OSI Approved :: MIT License
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
@@ -13,6 +17,7 @@ Requires-Dist: jaxmarl (==0.0.3)
  Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
  Requires-Dist: moviepy (>=1.0.3,<2.0.0)
  Requires-Dist: pygame (>=2.5.2,<3.0.0)
+ Project-URL: Repository, https://github.com/syrkis/parabellum
  Description-Content-Type: text/markdown

  # parabellum
@@ -35,3 +40,15 @@ pip install parabellum
  import parabellum as pb
  ```

+ ## TODO
+
+ - [ ] Parallel pygame vis
+ - [ ] Color for health?
+ - [ ] Add the ability to see ongoing game.
+ - [ ] Bug test friendly fire.
+ - [ ] Start sim from arbitrary state.
+ - [ ] Save when the episode ends in some state/obs variable
+ - [ ] Look for the source of the bug when using more Allies than Enemies
+ - [ ] Y inversed axis for parabellum visualization
+ - [ ] Units see through obstacles?
+
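For context, here is a minimal end-to-end use of the published API, assembled from the `__main__` blocks shown above (random movement actions; a sketch, not an official README example):

```python
import jax
from jax import random
import parabellum as pb

env = pb.Parabellum(map_width=256, map_height=256)
rng, key = random.split(random.PRNGKey(0))
obs, state = env.reset(key)
state_seq = []
for _ in range(100):
    rng, key = random.split(rng)
    key_act = random.split(key, len(env.agents))
    actions = {
        agent: jax.random.randint(key_act[i], (), 0, 5)  # random move actions
        for i, agent in enumerate(env.agents)
    }
    state_seq.append((key, state, actions))  # (key, state, actions) triples
    rng, key_step = random.split(rng)
    obs, state, reward, done, infos = env.step(key_step, state, actions)

pb.Visualizer(env, state_seq).animate()  # writes parabellum.mp4
```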
parabellum-0.1.2.dist-info/RECORD ADDED
@@ -0,0 +1,9 @@
+ parabellum/.ipynb_checkpoints/__init__-checkpoint.py,sha256=Yt1RkvkGIJdps0Axpz0ouu-Aaa07032kX04l1l7LXTw,118
+ parabellum/.ipynb_checkpoints/env-checkpoint.py,sha256=Z0PD3MJb9Amxl84MMtghTCF92Gr4ln9qSyRx2DSY15Y,11589
+ parabellum/.ipynb_checkpoints/vis-checkpoint.py,sha256=7zmFqU99gXSW6ueTeEp3CKMJ9XmrTgJkVEpktdLWd_4,8999
+ parabellum/__init__.py,sha256=Yt1RkvkGIJdps0Axpz0ouu-Aaa07032kX04l1l7LXTw,118
+ parabellum/env.py,sha256=Z0PD3MJb9Amxl84MMtghTCF92Gr4ln9qSyRx2DSY15Y,11589
+ parabellum/vis.py,sha256=7zmFqU99gXSW6ueTeEp3CKMJ9XmrTgJkVEpktdLWd_4,8999
+ parabellum-0.1.2.dist-info/METADATA,sha256=bIOF-Pkl0IcYSHUPWwVxkuNyv_W-rTfKwxmuGxRjLqk,1549
+ parabellum-0.1.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ parabellum-0.1.2.dist-info/RECORD,,
parabellum-0.1.0.dist-info/RECORD DELETED
@@ -1,6 +0,0 @@
- parabellum/__init__.py,sha256=CQj_Z18dyAU366U-UAJt8Rggq2mwArfCrn8jfIaLAjA,96
- parabellum/env.py,sha256=KE_wDPUbvHu02cMl9elT0cV63xVZ3DCpZIBuBM0GoyM,10074
- parabellum/vis.py,sha256=SWlSj1z09qAXUH2W4LeMHWoICZL5mANytLwR0qDXF-A,8812
- parabellum-0.1.0.dist-info/METADATA,sha256=jWirrM6EnME0Y3_PetV-dAwifNnLUbSKs-QJsmMCciI,934
- parabellum-0.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- parabellum-0.1.0.dist-info/RECORD,,