parabellum 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum/.ipynb_checkpoints/__init__-checkpoint.py +4 -0
- parabellum/.ipynb_checkpoints/env-checkpoint.py +296 -0
- parabellum/.ipynb_checkpoints/vis-checkpoint.py +230 -0
- parabellum/__init__.py +4 -0
- parabellum/env.py +296 -0
- parabellum/vis.py +230 -0
- parabellum-0.0.0.dist-info/METADATA +55 -0
- parabellum-0.0.0.dist-info/RECORD +9 -0
- parabellum-0.0.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,296 @@
|
|
1
|
+
"""Parabellum environment based on SMAX"""
|
2
|
+
|
3
|
+
import jax.numpy as jnp
|
4
|
+
import jax
|
5
|
+
import numpy as np
|
6
|
+
from jax import random
|
7
|
+
from jax import jit
|
8
|
+
from flax.struct import dataclass
|
9
|
+
import chex
|
10
|
+
from jaxmarl.environments.smax.smax_env import State, SMAX
|
11
|
+
from typing import Tuple, Dict
|
12
|
+
from functools import partial
|
13
|
+
|
14
|
+
|
15
|
+
@dataclass
class Scenario:
    """Parabellum scenario: obstacle layout plus unit composition.

    Fields are positional in this order, e.g.
    ``Scenario(coords, deltas, unit_types, num_allies, num_enemies)``.
    """

    # Obstacles, one line segment per row: the segment runs from
    # ``obstacle_coords[i]`` to ``obstacle_coords[i] + obstacle_deltas[i]``.
    obstacle_coords: chex.Array  # segment start points
    obstacle_deltas: chex.Array  # offset from start point to end point

    # Units. ``unit_types`` presumably holds one type index per agent
    # (length num_allies + num_enemies in the scenarios defined here).
    unit_types: chex.Array
    num_allies: int
    num_enemies: int

    # SMACv2-style random generation switches (off by default).
    smacv2_position_generation: bool = False
    smacv2_unit_type_generation: bool = False
|
28
|
+
|
29
|
+
|
30
|
+
# Built-in scenarios keyed by name. Obstacle coordinates are written as
# multiples of 8 map units.
scenarios = {
    "default": Scenario(
        obstacle_coords=jnp.array([[6, 10], [26, 10]]) * 8,
        obstacle_deltas=jnp.array([[0, 12], [0, 1]]) * 8,
        unit_types=jnp.zeros((9 + 10,), dtype=jnp.uint8),
        num_allies=9,
        num_enemies=10,
    ),
}
|
40
|
+
|
41
|
+
|
42
|
+
class Parabellum(SMAX):
    """SMAX environment extended with impassable line-segment obstacles.

    Obstacle ``i`` runs from ``scenario.obstacle_coords[i]`` to
    ``scenario.obstacle_coords[i] + scenario.obstacle_deltas[i]``. A unit whose
    move (or post-step push) would cross an obstacle stays where it was.

    Blast damage is opt-in via ``use_attack_blasts``. When enabled, every
    valid attack also damages the other units (excluding the attacker and its
    direct target) within ``unit_type_attack_blasts[attacker type]`` of the
    target. It is off by default. The first release computed blast damage
    but, through several bugs, always applied zero of it, so the default
    reproduces the original dynamics.
    """

    def __init__(
        self,
        scenario: Scenario = scenarios["default"],
        unit_type_attack_blasts=jnp.array([0, 0, 0, 0, 0, 0]) + 8,
        *,
        use_attack_blasts: bool = False,
        **kwargs,
    ):
        """Build the environment.

        Args:
            scenario: obstacle layout and unit composition.
            unit_type_attack_blasts: blast radius per unit type, indexed by
                the attacker's type. Only used when ``use_attack_blasts``.
            use_attack_blasts: opt in to blast damage on bystanders.
            **kwargs: forwarded to ``SMAX.__init__``. ``max_steps`` defaults
                to 200 here, but an explicit value is honoured.
        """
        # The original assigned ``self.max_steps = 200`` after super().__init__,
        # silently discarding any ``max_steps`` the caller passed.
        max_steps = kwargs.get("max_steps", 200)
        super().__init__(scenario=scenario, **kwargs)
        self.max_steps = max_steps
        self.unit_type_attack_blasts = unit_type_attack_blasts
        self.use_attack_blasts = use_attack_blasts
        self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
        self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)

    def _push_units_away(self, state: State, firmness: float = 1.0):
        """No-op override of SMAX's push step.

        Pushing is done inside ``_world_step`` (via ``_our_push_units_away``)
        so the pushed positions can be re-checked against obstacles.
        """
        return state

    def _our_push_units_away(self, pos, unit_types, firmness: float = 1.0):
        """Separate overlapping units; a state-free copy of SMAX's push.

        Args:
            pos: unit positions, one row per agent.
            unit_types: unit type index per agent (selects the radius).
            firmness: scale applied to the push.

        Returns:
            The pushed positions, same shape as ``pos``.
        """
        delta_matrix = pos[:, None] - pos[None, :]
        # The identity makes each unit's self-distance non-zero and the
        # epsilon guards coincident units, so the division below is safe.
        dist_matrix = (
            jnp.linalg.norm(delta_matrix, axis=-1)
            + jnp.identity(self.num_agents)
            + 1e-6
        )
        radii = self.unit_type_radiuses[unit_types]
        radius_matrix = radii[:, None] + radii[None, :]
        overlap_term = jax.nn.relu(radius_matrix / dist_matrix - 1.0)
        return (
            pos
            + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
        )

    @staticmethod
    def _segments_intersect(p, q, a, b):
        """Return True when segment p->q touches segment a->b (2-D points)."""

        def cross(u, v):
            # z-component of the 2-D cross product. This avoids jnp.cross on
            # 2-vectors, which NumPy 2 deprecates.
            return u[0] * v[1] - u[1] * v[0]

        d1 = cross(a - p, q - p)
        d2 = cross(b - p, q - p)
        d3 = cross(p - a, b - a)
        d4 = cross(q - a, b - a)
        straddle = (d1 * d2 <= 0) & (d3 * d4 <= 0)
        # The straddle test alone accepts every pair of collinear segments,
        # even disjoint ones. Requiring overlapping bounding boxes rejects them.
        boxes_overlap = jnp.all(
            (jnp.minimum(p, q) <= jnp.maximum(a, b))
            & (jnp.minimum(a, b) <= jnp.maximum(p, q))
        )
        return straddle & boxes_overlap

    @partial(jax.jit, static_argnums=(0,))  # replaces SMAX._world_step
    def _world_step(  # modified version of JaxMARL's SMAX _world_step
        self,
        key: chex.PRNGKey,
        state: State,
        actions: Tuple[chex.Array, chex.Array],
    ) -> State:
        """Advance the world by one tick.

        Moves every unit (blocking moves that cross an obstacle), resolves
        attacks and weapon cooldowns, pushes overlapping units apart, and
        undoes pushes that would cross an obstacle.

        Args:
            key: PRNG key, split per agent for the cooldown jitter.
            state: current state.
            actions: ``(movement_actions, attack_actions)`` with one entry
                per agent.

        Returns:
            The updated state. The original annotation claimed a 5-tuple, but
            the method has always returned only the state.
        """
        obs_start = self.obstacle_coords
        obs_end = obs_start + self.obstacle_deltas
        # Test one segment against every obstacle at once.
        crosses = jax.vmap(self._segments_intersect, in_axes=(None, None, 0, 0))

        def keep_clear(start, end):
            """Return ``end``, or ``start`` if start->end crosses an obstacle."""
            blocked = jnp.any(crosses(start, end, obs_start, obs_end))
            return jnp.where(blocked, start, end)

        def update_position(idx, vec):
            """Move unit ``idx`` by ``vec`` scaled by its speed and tick length."""
            pos = state.unit_positions[idx]
            new_pos = (
                pos
                + vec
                * self.unit_type_velocities[state.unit_types[idx]]
                * self.time_per_step
            )
            # avoid going out of bounds
            new_pos = jnp.maximum(
                jnp.minimum(new_pos, jnp.array([self.map_width, self.map_height])),
                jnp.zeros((2,)),
            )
            return keep_clear(pos, new_pos)

        def blast_mask(attacker_idx, target_idx):
            """1.0 for each unit caught in the blast around ``target_idx``, else 0.0.

            The direct target (already damaged) and the attacker are excluded.
            """
            dists = jnp.linalg.norm(
                state.unit_positions - state.unit_positions[target_idx], axis=-1
            )
            radius = self.unit_type_attack_blasts[state.unit_types[attacker_idx]]
            unit_ids = jnp.arange(self.num_agents)
            caught = (
                (dists < radius)
                & (unit_ids != target_idx)
                & (unit_ids != attacker_idx)
            )
            return caught.astype(jnp.float32)

        def update_agent_health(idx, action, key):
            """Resolve the attack of unit ``idx``.

            Returns ``(health_diff, attacked_idx, cooldown_diff, blast_diff)``.
            ``blast_diff`` holds the non-positive blast damage per unit and is
            all zeros unless ``use_attack_blasts`` is set.
            """
            # Team 1's attack actions are labelled in reverse order because
            # that is the order they are observed in.
            attacked_idx = jax.lax.cond(
                idx < self.num_allies,
                lambda: action + self.num_allies - self.num_movement_actions,
                lambda: self.num_allies - 1 - (action - self.num_movement_actions),
            )
            # Movement (non-attack) actions target the unit itself; the
            # ``idx != attacked_idx`` check below then invalidates them.
            attacked_idx = jax.lax.select(
                action < self.num_movement_actions, idx, attacked_idx
            )
            attack_valid = (
                (
                    jnp.linalg.norm(
                        state.unit_positions[idx] - state.unit_positions[attacked_idx]
                    )
                    < self.unit_type_attack_ranges[state.unit_types[idx]]
                )
                & state.unit_alive[idx]
                & state.unit_alive[attacked_idx]
            )
            attack_valid = attack_valid & (idx != attacked_idx)
            attack_valid = attack_valid & (state.unit_weapon_cooldowns[idx] <= 0.0)
            damage = self.unit_type_attacks[state.unit_types[idx]]
            health_diff = jax.lax.select(attack_valid, -damage, 0.0)

            if self.use_attack_blasts:
                blast_diff = -damage * jnp.where(
                    attack_valid, blast_mask(idx, attacked_idx), 0.0
                )
            else:
                blast_diff = jnp.zeros((self.num_agents,))

            # Cooldown jitter follows the pysc2 randomness design, see
            # https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
            cooldown_deviation = jax.random.uniform(
                key, minval=-self.time_per_step, maxval=2 * self.time_per_step
            )
            cooldown = (
                self.unit_type_weapon_cooldowns[state.unit_types[idx]]
                + cooldown_deviation
            )
            cooldown_diff = jax.lax.select(
                attack_valid,
                # Subtract the current cooldown because it is added back
                # below; this effectively sets the new cooldown to `cooldown`.
                cooldown - state.unit_weapon_cooldowns[idx],
                -self.time_per_step,
            )
            return health_diff, attacked_idx, cooldown_diff, blast_diff

        def perform_agent_action(idx, action, key):
            movement_action, attack_action = action
            new_pos = update_position(idx, movement_action)
            health_diff, attacked_idx, cooldown_diff, blast_diff = update_agent_health(
                idx, attack_action, key
            )
            return new_pos, (health_diff, attacked_idx), cooldown_diff, blast_diff

        keys = jax.random.split(key, num=self.num_agents)
        pos, (health_diff, attacked_idxes), cooldown_diff, blast_diff = jax.vmap(
            perform_agent_action
        )(jnp.arange(self.num_agents), actions, keys)

        # Push overlapping units apart, then undo any push that would carry a
        # unit through an obstacle.
        pushed = self._our_push_units_away(pos, state.unit_types)
        pos = jax.vmap(keep_clear)(pos, pushed)

        # Several units may attack the same target; scatter_add sums the
        # health differences of duplicate indices. The dnums usage was
        # inferred from https://github.com/google/jax/blob/main/tests/lax_test.py#L2296
        dnums = jax.lax.ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,),
        )
        unit_health = jax.lax.scatter_add(
            state.unit_health,
            jnp.expand_dims(attacked_idxes, 1),
            health_diff,
            dnums,
        )
        if self.use_attack_blasts:
            # blast_diff rows are attackers (the vmapped axis) and columns are
            # affected units, so summing over attackers gives each unit's total.
            unit_health = unit_health + blast_diff.sum(axis=0)
        unit_health = jnp.maximum(unit_health, 0.0)

        unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
        return state.replace(
            unit_health=unit_health,
            unit_positions=pos,
            unit_weapon_cooldowns=unit_weapon_cooldowns,
        )
|
280
|
+
|
281
|
+
if __name__ == "__main__":
    # Smoke test: roll out 100 steps of random actions in [0, 5).
    env = Parabellum(map_width=256, map_height=256)
    rng, key = random.split(random.PRNGKey(0))
    obs, state = env.reset(key)
    state_seq = []
    for step in range(100):
        # Separate keys for action sampling and for the env step. The
        # original reused one key for both.
        rng, key_act, key_step = random.split(rng, 3)
        key_act = random.split(key_act, len(env.agents))
        actions = {
            agent: jax.random.randint(key_act[i], (), 0, 5)
            for i, agent in enumerate(env.agents)
        }
        # Record the observation/state the actions were chosen from. The
        # original paired the initial observation with every later state.
        state_seq.append((obs, state, actions))
        obs, state, _, _, _ = env.step(key_step, state, actions)
|
295
|
+
|
296
|
+
|
@@ -0,0 +1,230 @@
|
|
1
|
+
"""Visualizer for the Parabellum environment"""
|
2
|
+
|
3
|
+
from tqdm import tqdm
|
4
|
+
import jax.numpy as jnp
|
5
|
+
import jax
|
6
|
+
from jax import vmap
|
7
|
+
from functools import partial
|
8
|
+
import darkdetect
|
9
|
+
import pygame
|
10
|
+
from moviepy.editor import ImageSequenceClip
|
11
|
+
from typing import Optional
|
12
|
+
from jaxmarl.environments.multi_agent_env import MultiAgentEnv
|
13
|
+
from jaxmarl.viz.visualizer import SMAXVisualizer
|
14
|
+
|
15
|
+
# default dict
|
16
|
+
from collections import defaultdict
|
17
|
+
|
18
|
+
|
19
|
+
# constants
|
20
|
+
action_to_symbol = dict(enumerate("↑→↓←Ø"))  # movement action id -> arrow glyph
|
21
|
+
|
22
|
+
|
23
|
+
class Visualizer(SMAXVisualizer):
    """Pygame-based video renderer for Parabellum rollouts.

    ``state_seq`` is a list of ``(key, state, actions)`` tuples, one per
    environment step, where ``actions`` maps agent names to discrete actions.
    """

    def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
        super().__init__(env, state_seq, reward_seq)
        # Drop the fig/ax set up by the superclass; drawing is done with pygame.
        self.fig, self.ax = None, None
        dark = darkdetect.isDark()  # query once; falsy means light theme
        self.bg = (0, 0, 0) if dark else (255, 255, 255)
        self.fg = (235, 235, 235) if dark else (20, 20, 20)
        self.s = 1000  # output frame side length in pixels
        self.scale = self.s / self.env.map_width
        # Rendered frames per environment step. The original hard-coded 8;
        # this assumes expand_state_seq expands each env step by the env's
        # world_steps_per_env_step (falls back to 8 if the env lacks it).
        self.steps_per_action = getattr(self.env, "world_steps_per_env_step", 8)
        # Keep the raw per-step action dicts before state_seq gets expanded.
        self.action_seq = [action for _, _, action in state_seq]  # bcs SMAX bug
        self.bullet_seq = bullet_fn(self.env, self.state_seq)

    def render_agents(self, screen, state):
        """Draw every living unit as a circle whose radius shrinks with health.

        Team 0 is filled with the foreground colour and team 1 with the
        background colour; both get a foreground outline.
        """
        units = zip(
            state.unit_positions,
            state.unit_teams,
            state.unit_types,
            state.unit_health,
        )
        for pos, team, kind, hp in units:
            if hp <= 0:
                continue
            face_col = self.fg if int(team.item()) == 0 else self.bg
            center = tuple((pos * self.scale).tolist())
            hp_frac = hp / self.env.unit_type_health[kind]
            unit_size = self.env.unit_type_radiuses[kind]
            # Convert to a plain int; pygame expects a Python number.
            radius = int(jnp.ceil(unit_size * self.scale * hp_frac)) + 1
            pygame.draw.circle(screen, face_col, center, radius)
            pygame.draw.circle(screen, self.fg, center, radius, 1)

    def render_action(self, screen, action):
        """Print each agent's action as a symbol.

        Allies are listed down the left edge and enemies down the right edge,
        vertically centred on the frame.
        """
        # Create one font per frame instead of one per agent.
        font = pygame.font.SysFont("Fira Code", int(jnp.sqrt(self.s)))
        step = self.s / 20
        teams = (
            (0, "ally", self.env.num_allies),
            (1, "enemy", self.env.num_enemies),
        )
        for team, prefix, count in teams:
            x = step if team == 0 else self.s - step
            # n / 2 rows above and below the vertical centre
            top = self.s / 2 - (count / 2) * step
            for idx in range(count):
                code = int(action[f"{prefix}_{idx}"])
                symb = action_to_symbol.get(code, "Ø")
                text = font.render(symb, True, self.fg)
                screen.blit(text, (x, top + idx * step))

    def render_obstacles(self, screen):
        """Draw each obstacle segment as a thick line."""
        for c, d in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
            end = tuple(((c + d) * self.scale).tolist())
            start = tuple((c * self.scale).tolist())
            pygame.draw.line(screen, self.fg, start, end, 5)

    def render_bullets(self, screen, bullets, jdx, steps=8):
        """Draw bullets ``(jdx + 1) / steps`` of the way from shooter to target.

        ``bullets`` is ``(ally_bullets, enemy_bullets)``, each an iterable of
        ``(source, target)`` position pairs.
        """
        progress = (jdx + 1) / steps
        for group in bullets:
            for source, target in group:
                position = (source + (target - source) * progress) * self.scale
                pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)

    def animate(self, save_fname: str = "parabellum.mp4"):
        """Render the rollout to ``save_fname`` at 48 fps and return the clip."""
        if not self.have_expanded:
            self.expand_state_seq()
        steps = self.steps_per_action
        frames = []  # frames for the video
        pygame.init()
        try:
            for idx, (_, state, _) in tqdm(
                enumerate(self.state_seq), total=len(self.state_seq)
            ):
                screen = pygame.Surface(
                    (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
                )
                screen.fill(self.bg)
                self.render_agents(screen, state)
                self.render_action(screen, self.action_seq[idx // steps])
                self.render_obstacles(screen)
                if idx < len(self.bullet_seq) * steps:
                    bullets = self.bullet_seq[idx // steps]
                    self.render_bullets(screen, bullets, idx % steps, steps)
                # array3d copies the pixels; pixels3d would keep every surface
                # locked and alive. Swap to (height, width, rgb) for moviepy.
                frames.append(pygame.surfarray.array3d(screen).swapaxes(0, 1))
            clip = ImageSequenceClip(frames, fps=48)
            clip.write_videofile(save_fname, fps=48)
        finally:
            # Always release pygame, even if rendering or encoding fails.
            pygame.quit()
        return clip
|
135
|
+
|
136
|
+
|
137
|
+
# functions
|
138
|
+
# bullet functions
|
139
|
+
def dist_fn(env, pos):
    """Pairwise ally-enemy distances.

    Agents ``0 .. env.num_allies - 1`` are allies and the rest are enemies.
    Returns ``{"ally": D, "enemy": D.T}``, where ``D[i, j]`` is the distance
    from ally ``i`` to enemy ``j``.
    """
    n = env.num_allies
    diff = pos[:, None, :] - pos[None, :, :]
    full = jnp.sqrt(jnp.sum(diff**2, axis=-1))
    cross_team = full[:n, n:]
    return {"ally": cross_team, "enemy": cross_team.T}
|
144
|
+
|
145
|
+
|
146
|
+
def range_fn(env, dists, ranges):
    """Boolean masks of which opponents each unit can reach.

    ``ranges`` holds one attack range per agent (allies first). Returns
    ``{"ally": A, "enemy": E}``, where ``A[i, j]`` is True when enemy ``j``
    is strictly closer than ally ``i``'s range, and likewise for ``E``.
    """
    split = env.num_allies
    team_ranges = (("ally", ranges[:split]), ("enemy", ranges[split:]))
    return {team: dists[team] < r[:, None] for team, r in team_ranges}
|
150
|
+
|
151
|
+
|
152
|
+
def target_fn(acts, in_range, team):  # computing the one hot valid targets
    """One-hot matrix of valid attack targets for ``team``.

    Args:
        acts: dict mapping agent names (``"ally_<i>"`` / ``"enemy_<i>"``) to
            discrete actions. Actions 0-4 are movement; ``5 + k`` attacks
            opponent slot ``k``.
        in_range: dict from ``range_fn``; ``in_range[team][i, j]`` is truthy
            when opponent ``j`` is within unit ``i``'s attack range.
        team: ``"ally"`` or ``"enemy"``.

    Returns:
        Array of shape ``(n_team, n_opponents)`` with a 1 where unit ``i``
        attacks in-range opponent ``j``, else 0.
    """
    n_opp = in_range[team].shape[1]
    t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).ravel()
    slot = t_acts - 5  # first 5 are move actions
    # The env labels enemy attack actions in reverse ally order
    # (target = num_allies - 1 - slot), so mirror that here.
    targets = slot if team == "ally" else n_opp - 1 - slot
    # The original used jnp.where(t_acts > 4, -1, t_acts - 5), which is
    # inverted: attacks were dropped, and movement actions became negative
    # indices that wrapped into bogus targets.
    valid = (t_acts > 4) & (targets >= 0) & (targets < n_opp)
    # Non-attacks and out-of-bounds slots point at an extra dummy column that
    # is sliced off, giving an all-zero row.
    targets = jnp.where(valid, targets, n_opp)
    t_attacks = jnp.eye(n_opp + 1)[targets][:, :-1]
    return t_attacks * in_range[team]  # one hot valid targets
|
157
|
+
|
158
|
+
|
159
|
+
def attack_fn(env, state_seq):
    """Return one-hot valid attacks for every ``(key, state, actions)`` step.

    Each list entry is ``{"ally": ..., "enemy": ...}``, holding ``target_fn``
    output for that team at that step.
    """

    def step_attacks(state, acts):
        # distances -> per-unit ranges -> reachable opponents -> chosen targets
        dists = dist_fn(env, state.unit_positions)
        ranges = env.unit_type_attack_ranges[state.unit_types]
        in_range = range_fn(env, dists, ranges)
        return {team: target_fn(acts, in_range, team) for team in ("ally", "enemy")}

    return [step_attacks(state, acts) for _, state, acts in state_seq]
|
169
|
+
|
170
|
+
|
171
|
+
def bullet_fn(env, states):
    """Return ``(ally_bullets, enemy_bullets)`` for each consecutive state pair.

    Each bullet array stacks ``(source, target)`` positions along axis 1.
    The source is the shooter's position in the current state and the target
    is the attacked unit's position in the next state.
    """
    attack_seq = attack_fn(env, states)
    offset = env.num_allies

    def shooter_target_pairs(one_hot, team):
        # Rows of (shooter index within its team, target index within the
        # opposing team).
        return jnp.stack(jnp.where(one_hot[team] == 1)).T

    bullet_seq = []
    steps = zip(states[:-1], states[1:], attack_seq)
    for (_, now, _), (_, nxt, _), one_hot in steps:
        ally_pairs = shooter_target_pairs(one_hot, "ally")
        enemy_pairs = shooter_target_pairs(one_hot, "enemy")
        ally_bullets = jnp.stack(
            (
                now.unit_positions[ally_pairs[:, 0]],
                nxt.unit_positions[ally_pairs[:, 1] + offset],
            ),
            axis=1,
        )
        enemy_bullets = jnp.stack(
            (
                now.unit_positions[enemy_pairs[:, 0] + offset],
                nxt.unit_positions[enemy_pairs[:, 1]],
            ),
            axis=1,
        )
        bullet_seq.append((ally_bullets, enemy_bullets))
    return bullet_seq
|
200
|
+
|
201
|
+
|
202
|
+
# test the visualizer
|
203
|
+
if __name__ == "__main__":
    # Smoke test: roll out 50 steps with every agent taking action 1, then render.
    from parabellum import Parabellum, Scenario
    from jax import random, numpy as jnp

    s = Scenario(
        jnp.array([[16, 0]]),
        jnp.array([[0, 32]]) * 8,
        jnp.zeros((19,), dtype=jnp.uint8),
        9,
        10,
    )
    env = Parabellum(map_width=32, map_height=32, walls_cause_death=False, scenario=s)
    rng, key = random.split(random.PRNGKey(0))
    obs, state = env.reset(key)
    state_seq = []
    for step in range(50):
        actions = {agent: jnp.array(1) for agent in env.agents}
        rng, key_step = random.split(rng)
        # Store the key that is actually passed to env.step. The original
        # stored an unrelated key; presumably expand_state_seq re-simulates
        # each step from the stored key, so the two should match.
        state_seq.append((key_step, state, actions))
        obs, state, reward, done, infos = env.step(key_step, state, actions)

    vis = Visualizer(env, state_seq)
    vis.animate()
|
229
|
+
|
230
|
+
|
parabellum/__init__.py
ADDED
parabellum/env.py
ADDED
@@ -0,0 +1,296 @@
|
|
1
|
+
"""Parabellum environment based on SMAX"""
|
2
|
+
|
3
|
+
import jax.numpy as jnp
|
4
|
+
import jax
|
5
|
+
import numpy as np
|
6
|
+
from jax import random
|
7
|
+
from jax import jit
|
8
|
+
from flax.struct import dataclass
|
9
|
+
import chex
|
10
|
+
from jaxmarl.environments.smax.smax_env import State, SMAX
|
11
|
+
from typing import Tuple, Dict
|
12
|
+
from functools import partial
|
13
|
+
|
14
|
+
|
15
|
+
@dataclass
class Scenario:
    """Parabellum scenario: obstacle layout plus unit composition.

    Fields are positional in this order, e.g.
    ``Scenario(coords, deltas, unit_types, num_allies, num_enemies)``.
    """

    # Obstacles, one line segment per row: the segment runs from
    # ``obstacle_coords[i]`` to ``obstacle_coords[i] + obstacle_deltas[i]``.
    obstacle_coords: chex.Array  # segment start points
    obstacle_deltas: chex.Array  # offset from start point to end point

    # Units. ``unit_types`` presumably holds one type index per agent
    # (length num_allies + num_enemies in the scenarios defined here).
    unit_types: chex.Array
    num_allies: int
    num_enemies: int

    # SMACv2-style random generation switches (off by default).
    smacv2_position_generation: bool = False
    smacv2_unit_type_generation: bool = False
|
28
|
+
|
29
|
+
|
30
|
+
# Built-in scenarios keyed by name. Obstacle coordinates are written as
# multiples of 8 map units.
scenarios = {
    "default": Scenario(
        obstacle_coords=jnp.array([[6, 10], [26, 10]]) * 8,
        obstacle_deltas=jnp.array([[0, 12], [0, 1]]) * 8,
        unit_types=jnp.zeros((9 + 10,), dtype=jnp.uint8),
        num_allies=9,
        num_enemies=10,
    ),
}
|
40
|
+
|
41
|
+
|
42
|
+
class Parabellum(SMAX):
    """SMAX environment extended with impassable line-segment obstacles.

    Obstacle ``i`` runs from ``scenario.obstacle_coords[i]`` to
    ``scenario.obstacle_coords[i] + scenario.obstacle_deltas[i]``. A unit whose
    move (or post-step push) would cross an obstacle stays where it was.

    Blast damage is opt-in via ``use_attack_blasts``. When enabled, every
    valid attack also damages the other units (excluding the attacker and its
    direct target) within ``unit_type_attack_blasts[attacker type]`` of the
    target. It is off by default. The first release computed blast damage
    but, through several bugs, always applied zero of it, so the default
    reproduces the original dynamics.
    """

    def __init__(
        self,
        scenario: Scenario = scenarios["default"],
        unit_type_attack_blasts=jnp.array([0, 0, 0, 0, 0, 0]) + 8,
        *,
        use_attack_blasts: bool = False,
        **kwargs,
    ):
        """Build the environment.

        Args:
            scenario: obstacle layout and unit composition.
            unit_type_attack_blasts: blast radius per unit type, indexed by
                the attacker's type. Only used when ``use_attack_blasts``.
            use_attack_blasts: opt in to blast damage on bystanders.
            **kwargs: forwarded to ``SMAX.__init__``. ``max_steps`` defaults
                to 200 here, but an explicit value is honoured.
        """
        # The original assigned ``self.max_steps = 200`` after super().__init__,
        # silently discarding any ``max_steps`` the caller passed.
        max_steps = kwargs.get("max_steps", 200)
        super().__init__(scenario=scenario, **kwargs)
        self.max_steps = max_steps
        self.unit_type_attack_blasts = unit_type_attack_blasts
        self.use_attack_blasts = use_attack_blasts
        self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
        self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)

    def _push_units_away(self, state: State, firmness: float = 1.0):
        """No-op override of SMAX's push step.

        Pushing is done inside ``_world_step`` (via ``_our_push_units_away``)
        so the pushed positions can be re-checked against obstacles.
        """
        return state

    def _our_push_units_away(self, pos, unit_types, firmness: float = 1.0):
        """Separate overlapping units; a state-free copy of SMAX's push.

        Args:
            pos: unit positions, one row per agent.
            unit_types: unit type index per agent (selects the radius).
            firmness: scale applied to the push.

        Returns:
            The pushed positions, same shape as ``pos``.
        """
        delta_matrix = pos[:, None] - pos[None, :]
        # The identity makes each unit's self-distance non-zero and the
        # epsilon guards coincident units, so the division below is safe.
        dist_matrix = (
            jnp.linalg.norm(delta_matrix, axis=-1)
            + jnp.identity(self.num_agents)
            + 1e-6
        )
        radii = self.unit_type_radiuses[unit_types]
        radius_matrix = radii[:, None] + radii[None, :]
        overlap_term = jax.nn.relu(radius_matrix / dist_matrix - 1.0)
        return (
            pos
            + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
        )

    @staticmethod
    def _segments_intersect(p, q, a, b):
        """Return True when segment p->q touches segment a->b (2-D points)."""

        def cross(u, v):
            # z-component of the 2-D cross product. This avoids jnp.cross on
            # 2-vectors, which NumPy 2 deprecates.
            return u[0] * v[1] - u[1] * v[0]

        d1 = cross(a - p, q - p)
        d2 = cross(b - p, q - p)
        d3 = cross(p - a, b - a)
        d4 = cross(q - a, b - a)
        straddle = (d1 * d2 <= 0) & (d3 * d4 <= 0)
        # The straddle test alone accepts every pair of collinear segments,
        # even disjoint ones. Requiring overlapping bounding boxes rejects them.
        boxes_overlap = jnp.all(
            (jnp.minimum(p, q) <= jnp.maximum(a, b))
            & (jnp.minimum(a, b) <= jnp.maximum(p, q))
        )
        return straddle & boxes_overlap

    @partial(jax.jit, static_argnums=(0,))  # replaces SMAX._world_step
    def _world_step(  # modified version of JaxMARL's SMAX _world_step
        self,
        key: chex.PRNGKey,
        state: State,
        actions: Tuple[chex.Array, chex.Array],
    ) -> State:
        """Advance the world by one tick.

        Moves every unit (blocking moves that cross an obstacle), resolves
        attacks and weapon cooldowns, pushes overlapping units apart, and
        undoes pushes that would cross an obstacle.

        Args:
            key: PRNG key, split per agent for the cooldown jitter.
            state: current state.
            actions: ``(movement_actions, attack_actions)`` with one entry
                per agent.

        Returns:
            The updated state. The original annotation claimed a 5-tuple, but
            the method has always returned only the state.
        """
        obs_start = self.obstacle_coords
        obs_end = obs_start + self.obstacle_deltas
        # Test one segment against every obstacle at once.
        crosses = jax.vmap(self._segments_intersect, in_axes=(None, None, 0, 0))

        def keep_clear(start, end):
            """Return ``end``, or ``start`` if start->end crosses an obstacle."""
            blocked = jnp.any(crosses(start, end, obs_start, obs_end))
            return jnp.where(blocked, start, end)

        def update_position(idx, vec):
            """Move unit ``idx`` by ``vec`` scaled by its speed and tick length."""
            pos = state.unit_positions[idx]
            new_pos = (
                pos
                + vec
                * self.unit_type_velocities[state.unit_types[idx]]
                * self.time_per_step
            )
            # avoid going out of bounds
            new_pos = jnp.maximum(
                jnp.minimum(new_pos, jnp.array([self.map_width, self.map_height])),
                jnp.zeros((2,)),
            )
            return keep_clear(pos, new_pos)

        def blast_mask(attacker_idx, target_idx):
            """1.0 for each unit caught in the blast around ``target_idx``, else 0.0.

            The direct target (already damaged) and the attacker are excluded.
            """
            dists = jnp.linalg.norm(
                state.unit_positions - state.unit_positions[target_idx], axis=-1
            )
            radius = self.unit_type_attack_blasts[state.unit_types[attacker_idx]]
            unit_ids = jnp.arange(self.num_agents)
            caught = (
                (dists < radius)
                & (unit_ids != target_idx)
                & (unit_ids != attacker_idx)
            )
            return caught.astype(jnp.float32)

        def update_agent_health(idx, action, key):
            """Resolve the attack of unit ``idx``.

            Returns ``(health_diff, attacked_idx, cooldown_diff, blast_diff)``.
            ``blast_diff`` holds the non-positive blast damage per unit and is
            all zeros unless ``use_attack_blasts`` is set.
            """
            # Team 1's attack actions are labelled in reverse order because
            # that is the order they are observed in.
            attacked_idx = jax.lax.cond(
                idx < self.num_allies,
                lambda: action + self.num_allies - self.num_movement_actions,
                lambda: self.num_allies - 1 - (action - self.num_movement_actions),
            )
            # Movement (non-attack) actions target the unit itself; the
            # ``idx != attacked_idx`` check below then invalidates them.
            attacked_idx = jax.lax.select(
                action < self.num_movement_actions, idx, attacked_idx
            )
            attack_valid = (
                (
                    jnp.linalg.norm(
                        state.unit_positions[idx] - state.unit_positions[attacked_idx]
                    )
                    < self.unit_type_attack_ranges[state.unit_types[idx]]
                )
                & state.unit_alive[idx]
                & state.unit_alive[attacked_idx]
            )
            attack_valid = attack_valid & (idx != attacked_idx)
            attack_valid = attack_valid & (state.unit_weapon_cooldowns[idx] <= 0.0)
            damage = self.unit_type_attacks[state.unit_types[idx]]
            health_diff = jax.lax.select(attack_valid, -damage, 0.0)

            if self.use_attack_blasts:
                blast_diff = -damage * jnp.where(
                    attack_valid, blast_mask(idx, attacked_idx), 0.0
                )
            else:
                blast_diff = jnp.zeros((self.num_agents,))

            # Cooldown jitter follows the pysc2 randomness design, see
            # https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
            cooldown_deviation = jax.random.uniform(
                key, minval=-self.time_per_step, maxval=2 * self.time_per_step
            )
            cooldown = (
                self.unit_type_weapon_cooldowns[state.unit_types[idx]]
                + cooldown_deviation
            )
            cooldown_diff = jax.lax.select(
                attack_valid,
                # Subtract the current cooldown because it is added back
                # below; this effectively sets the new cooldown to `cooldown`.
                cooldown - state.unit_weapon_cooldowns[idx],
                -self.time_per_step,
            )
            return health_diff, attacked_idx, cooldown_diff, blast_diff

        def perform_agent_action(idx, action, key):
            movement_action, attack_action = action
            new_pos = update_position(idx, movement_action)
            health_diff, attacked_idx, cooldown_diff, blast_diff = update_agent_health(
                idx, attack_action, key
            )
            return new_pos, (health_diff, attacked_idx), cooldown_diff, blast_diff

        keys = jax.random.split(key, num=self.num_agents)
        pos, (health_diff, attacked_idxes), cooldown_diff, blast_diff = jax.vmap(
            perform_agent_action
        )(jnp.arange(self.num_agents), actions, keys)

        # Push overlapping units apart, then undo any push that would carry a
        # unit through an obstacle.
        pushed = self._our_push_units_away(pos, state.unit_types)
        pos = jax.vmap(keep_clear)(pos, pushed)

        # Several units may attack the same target; scatter_add sums the
        # health differences of duplicate indices. The dnums usage was
        # inferred from https://github.com/google/jax/blob/main/tests/lax_test.py#L2296
        dnums = jax.lax.ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,),
        )
        unit_health = jax.lax.scatter_add(
            state.unit_health,
            jnp.expand_dims(attacked_idxes, 1),
            health_diff,
            dnums,
        )
        if self.use_attack_blasts:
            # blast_diff rows are attackers (the vmapped axis) and columns are
            # affected units, so summing over attackers gives each unit's total.
            unit_health = unit_health + blast_diff.sum(axis=0)
        unit_health = jnp.maximum(unit_health, 0.0)

        unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
        return state.replace(
            unit_health=unit_health,
            unit_positions=pos,
            unit_weapon_cooldowns=unit_weapon_cooldowns,
        )
|
280
|
+
|
281
|
+
if __name__ == "__main__":
    # Smoke test: roll out 100 steps of random actions and record the
    # trajectory as (key, pre-step state, actions) tuples -- the same layout
    # the vis.py demo builds and that `Visualizer` consumes.
    env = Parabellum(map_width=256, map_height=256)
    rng, key = random.split(random.PRNGKey(0))
    obs, state = env.reset(key)
    state_seq = []
    for step in range(100):
        # Independent keys for action sampling and for the environment step,
        # so the two sources of randomness are not correlated.
        rng, key_act, key_step = random.split(rng, 3)
        agent_keys = random.split(key_act, len(env.agents))
        actions = {
            # randint's upper bound is exclusive: actions 0-4 are sampled
            agent: jax.random.randint(agent_keys[i], (), 0, 5)
            for i, agent in enumerate(env.agents)
        }
        # Pair each state with the actions applied to it (before stepping).
        state_seq.append((key_step, state, actions))
        obs, state, _, _, _ = env.step(key_step, state, actions)
|
295
|
+
|
296
|
+
|
parabellum/vis.py
ADDED
@@ -0,0 +1,230 @@
|
|
1
|
+
"""Visualizer for the Parabellum environment"""
|
2
|
+
|
3
|
+
from tqdm import tqdm
|
4
|
+
import jax.numpy as jnp
|
5
|
+
import jax
|
6
|
+
from jax import vmap
|
7
|
+
from functools import partial
|
8
|
+
import darkdetect
|
9
|
+
import pygame
|
10
|
+
from moviepy.editor import ImageSequenceClip
|
11
|
+
from typing import Optional
|
12
|
+
from jaxmarl.environments.multi_agent_env import MultiAgentEnv
|
13
|
+
from jaxmarl.viz.visualizer import SMAXVisualizer
|
14
|
+
|
15
|
+
# default dict
|
16
|
+
from collections import defaultdict
|
17
|
+
|
18
|
+
|
19
|
+
# constants
# Display glyph for each action index 0-4 (up, right, down, left, none);
# callers look other indices up with a "Ø" default.
action_to_symbol = dict(enumerate("↑→↓←Ø"))
|
21
|
+
|
22
|
+
|
23
|
+
class Visualizer(SMAXVisualizer):
    """Pygame/moviepy video renderer for Parabellum rollouts.

    ``state_seq`` is a list of ``(key, state, actions)`` tuples, one per
    environment step (the layout built by the demo at the bottom of this
    module). ``animate`` expands it into per-frame states via the parent
    class, draws every frame off-screen with pygame and writes a video.
    """

    def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
        super().__init__(env, state_seq, reward_seq)
        # remove fig and ax from super: drawing is done with pygame instead
        self.fig, self.ax = None, None
        # Query the OS theme once; any falsy result is treated as light mode.
        dark = bool(darkdetect.isDark())
        self.bg = (0, 0, 0) if dark else (255, 255, 255)
        self.fg = (235, 235, 235) if dark else (20, 20, 20)
        self.s = 1000  # side length of each square frame, in pixels
        self.scale = self.s / self.env.map_width  # world units -> pixels
        # Rendered frames per environment step (previously a hard-coded 8).
        # NOTE(review): assumes the parent's expand_state_seq interpolates
        # env.world_steps_per_env_step frames per step -- confirm in jaxmarl.
        self.frames_per_step = int(getattr(self.env, "world_steps_per_env_step", 8))
        # One action dict per env step, taken from the un-expanded input.
        self.action_seq = [action for _, _, action in state_seq]  # bcs SMAX bug
        self.bullet_seq = bullet_fn(self.env, self.state_seq)

    def render_agents(self, screen, state):
        """Draw each living unit as a circle whose radius scales with its health.

        Team-0 units are filled with the foreground colour and all other teams
        with the background colour; every unit gets a 1 px foreground outline.
        """
        units = zip(
            state.unit_positions,
            state.unit_teams,
            state.unit_types,
            state.unit_health,
        )
        for pos, team, kind, hp in units:
            if hp > 0:  # dead units are not drawn
                face_col = self.fg if int(team.item()) == 0 else self.bg
                center = tuple((pos * self.scale).tolist())
                hp_frac = hp / self.env.unit_type_health[kind]
                unit_size = self.env.unit_type_radiuses[kind]
                # pygame wants a plain Python int radius, not a JAX array.
                radius = int(jnp.ceil(unit_size * self.scale * hp_frac)) + 1
                pygame.draw.circle(screen, face_col, center, radius)
                pygame.draw.circle(screen, self.fg, center, radius, 1)

    def render_action(self, screen, action):
        """Draw every agent's current action glyph in a column on its team's side.

        Allies are listed near the left edge, enemies near the right edge; each
        column is vertically centred on the frame.
        """
        row = self.s / 20
        # Build the font once per frame rather than once per agent.
        font = pygame.font.SysFont("Fira Code", int(self.s**0.5))

        def draw_column(prefix, count, x):
            for idx in range(count):
                symb = action_to_symbol.get(int(action[f"{prefix}_{idx}"]), "Ø")
                # vertically centered so that count / 2 rows are above and below
                y = self.s / 2 - (count / 2) * row + idx * row
                screen.blit(font.render(symb, True, self.fg), (x, y))

        draw_column("ally", self.env.num_allies, row)
        draw_column("enemy", self.env.num_enemies, self.s - row)

    def render_obstacles(self, screen):
        """Draw each obstacle as a 5 px line from its start to start + delta."""
        for start, delta in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
            p0 = tuple((start * self.scale).tolist())
            p1 = tuple(((start + delta) * self.scale).tolist())
            pygame.draw.line(screen, self.fg, p0, p1, 5)

    def render_bullets(self, screen, bullets, jdx):
        """Draw this step's bullets at sub-frame ``jdx`` of the step.

        ``bullets`` is an ``(ally_bullets, enemy_bullets)`` pair of
        (source, target) position rows; a bullet is drawn ``(jdx + 1) /
        frames_per_step`` of the way from source to target.
        """
        frac = (jdx + 1) / self.frames_per_step
        for team_bullets in bullets:
            for source, target in team_bullets:
                position = (source + (target - source) * frac) * self.scale
                pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)

    def animate(self, save_fname: str = "parabellum.mp4"):
        """Render the rollout to ``save_fname`` at 48 fps and return the clip."""
        if not self.have_expanded:
            self.expand_state_seq()
        per_step = self.frames_per_step
        frames = []  # frames for the video
        pygame.init()
        try:
            for idx, (_, state, _) in tqdm(
                enumerate(self.state_seq), total=len(self.state_seq)
            ):
                screen = pygame.Surface(
                    (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
                )
                screen.fill(self.bg)
                self.render_agents(screen, state)
                self.render_action(screen, self.action_seq[idx // per_step])
                self.render_obstacles(screen)
                # bullets exist only for steps that have a successor state
                if idx < len(self.bullet_seq) * per_step:
                    bullets = self.bullet_seq[idx // per_step]
                    self.render_bullets(screen, bullets, idx % per_step)
                # array3d copies the pixels (pixels3d would return a view that
                # keeps each surface locked); swap pygame's (x, y) layout to the
                # (row, col) layout video frames use.
                frames.append(pygame.surfarray.array3d(screen).swapaxes(0, 1))

            clip = ImageSequenceClip(frames, fps=48)
            clip.write_videofile(save_fname, fps=48)
        finally:
            # Always shut pygame down, even if rendering or encoding fails.
            pygame.quit()
        return clip
|
135
|
+
|
136
|
+
|
137
|
+
# functions
|
138
|
+
# bullet functions
|
139
|
+
def dist_fn(env, pos):  # computing the distances between all ally and enemy agents
    """Return Euclidean distances between every ally and every enemy.

    Allies are the first ``env.num_allies`` rows of ``pos`` and enemies the
    rest. The result maps ``"ally"`` to an (allies x enemies) matrix and
    ``"enemy"`` to its transpose.
    """
    split = env.num_allies
    allies, enemies = pos[:split], pos[split:]
    offsets = allies[:, None, :] - enemies[None, :, :]
    ally_to_enemy = jnp.sqrt(jnp.sum(offsets**2, axis=2))
    return {"ally": ally_to_enemy, "enemy": ally_to_enemy.T}
|
144
|
+
|
145
|
+
|
146
|
+
def range_fn(env, dists, ranges):  # computing what targets are in range
    """Mask which opponents lie strictly inside each shooter's attack range.

    ``ranges`` holds one range per unit (allies first). Each row of
    ``dists[team]`` is compared against that shooter's own range.
    """
    split = env.num_allies
    team_ranges = {"ally": ranges[:split], "enemy": ranges[split:]}
    return {
        team: dists[team] < team_ranges[team][:, None]
        for team in ("ally", "enemy")
    }
|
150
|
+
|
151
|
+
|
152
|
+
def target_fn(acts, in_range, team, num_movement_actions=5):  # computing the one hot valid targets
    """Return one-hot (shooter, target) attack rows for one team, masked by range.

    Args:
        acts: mapping of agent name to scalar action; only keys starting with
            ``team`` are used, in the mapping's iteration order.
        in_range: ``{team: (shooters x targets) bool mask}`` as from ``range_fn``.
        team: ``"ally"`` or ``"enemy"``.
        num_movement_actions: actions below this index are movement, not
            attacks; action ``a >= num_movement_actions`` targets opponent
            ``a - num_movement_actions``.

    Returns:
        A (shooters x targets) array with a 1 where the shooter attacks that
        target and the target is in range; rows of movement actions are zero.

    NOTE(review): jaxmarl's SMAX may index enemy attack targets in reverse
    ally order -- confirm that ``a - num_movement_actions`` is the right
    target for both teams in Parabellum's world step.
    """
    t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)])
    num_targets = in_range[team].shape[1]
    # Attack actions map to their target index. Movement actions map to -1,
    # which selects the extra last column of the identity that is sliced off,
    # giving an all-zero row. (The original had these branches inverted.)
    t_targets = jnp.where(
        t_acts >= num_movement_actions, t_acts - num_movement_actions, -1
    )
    t_attacks = jnp.eye(num_targets + 1)[t_targets][:, :-1]
    return t_attacks * in_range[team]  # one hot valid targets
|
157
|
+
|
158
|
+
|
159
|
+
def attack_fn(env, state_seq):  # one hot attack list
    """Compute the in-range one-hot attack matrices for every recorded step.

    ``state_seq`` is a sequence of 3-tuples whose first element is ignored and
    whose second and third are the state and its action dict. Returns one
    ``{"ally": ..., "enemy": ...}`` dict of ``target_fn`` outputs per entry.
    """

    def step_attacks(state, acts):
        in_range = range_fn(
            env,
            dist_fn(env, state.unit_positions),
            env.unit_type_attack_ranges[state.unit_types],
        )
        return {team: target_fn(acts, in_range, team) for team in ("ally", "enemy")}

    return [step_attacks(state, acts) for _, state, acts in state_seq]
|
169
|
+
|
170
|
+
|
171
|
+
def bullet_fn(env, states):
    """Build the bullet segments drawn between consecutive recorded steps.

    For each pair of consecutive entries in ``states`` (``(_, state, _)``
    tuples), returns an ``(ally_bullets, enemy_bullets)`` pair. Each array
    stacks, per shot, the shooter's position in the earlier state and the
    target's position in the later state. Enemies follow the
    ``env.num_allies`` allies in the unit arrays, hence the index offsets.
    """
    attack_seq = attack_fn(env, states)
    offset = env.num_allies

    def shot_indices(one_hot, team):
        # (shooter index within team, target index within the other team)
        return jnp.stack(jnp.where(one_hot[team] == 1)).T

    bullet_seq = []
    for one_hot, (_, state, _), (_, n_state, _) in zip(
        attack_seq, states[:-1], states[1:]
    ):
        ally_shots = shot_indices(one_hot, "ally")
        enemy_shots = shot_indices(one_hot, "enemy")

        ally_bullets = jnp.stack(
            (
                state.unit_positions[ally_shots[:, 0]],
                n_state.unit_positions[ally_shots[:, 1] + offset],
            ),
            axis=1,
        )
        enemy_bullets = jnp.stack(
            (
                state.unit_positions[enemy_shots[:, 0] + offset],
                n_state.unit_positions[enemy_shots[:, 1]],
            ),
            axis=1,
        )
        bullet_seq.append((ally_bullets, enemy_bullets))
    return bullet_seq
|
200
|
+
|
201
|
+
|
202
|
+
# test the visualizer
if __name__ == "__main__":
    from parabellum import Parabellum, Scenario
    from jax import random, numpy as jnp

    # Demo scenario: one obstacle starting at (16, 0) with delta (0, 256),
    # 9 allies vs 10 enemies, all of unit type 0.
    scenario = Scenario(
        obstacle_coords=jnp.array([[16, 0]]),
        obstacle_deltas=jnp.array([[0, 32]]) * 8,
        unit_types=jnp.zeros((19,), dtype=jnp.uint8),
        num_allies=9,
        num_enemies=10,
    )
    env = Parabellum(
        map_width=32, map_height=32, walls_cause_death=False, scenario=scenario
    )

    rng, key = random.split(random.PRNGKey(0))
    obs, state = env.reset(key)

    # Record 50 steps as (key, pre-step state, actions); every agent always
    # takes action 1.
    trajectory = []
    for _ in range(50):
        rng, key = random.split(rng)
        actions = {agent: jnp.array(1) for agent in env.agents}
        trajectory.append((key, state, actions))
        rng, key_step = random.split(rng)
        obs, state, reward, done, infos = env.step(key_step, state, actions)

    vis = Visualizer(env, trajectory)
    vis.animate()
|
229
|
+
|
230
|
+
|
@@ -0,0 +1,55 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: parabellum
|
3
|
+
Version: 0.0.0
|
4
|
+
Summary: Parabellum environment for parallel warfare simulation
|
5
|
+
Home-page: https://github.com/syrkis/parabellum
|
6
|
+
License: MIT
|
7
|
+
Keywords: warfare,simulation,parallel,environment
|
8
|
+
Author: Noah Syrkis
|
9
|
+
Author-email: desk@syrkis.com
|
10
|
+
Requires-Python: >=3.11,<4.0
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
15
|
+
Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
|
16
|
+
Requires-Dist: jaxmarl (==0.0.3)
|
17
|
+
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
18
|
+
Requires-Dist: moviepy (>=1.0.3,<2.0.0)
|
19
|
+
Requires-Dist: poetry (>=1.8.3,<2.0.0)
|
20
|
+
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
21
|
+
Project-URL: Repository, https://github.com/syrkis/parabellum
|
22
|
+
Description-Content-Type: text/markdown
|
23
|
+
|
24
|
+
# parabellum
|
25
|
+
|
26
|
+
Parabellum is an ultra-scalable, high-performance warfare simulation engine.
|
27
|
+
It is based on JaxMARL's SMAX environment, but has been heavily modified to
|
28
|
+
support a wide range of new features and improvements.
|
29
|
+
|
30
|
+
## Installation
|
31
|
+
|
32
|
+
Install through PyPI:
|
33
|
+
|
34
|
+
```bash
|
35
|
+
pip install parabellum
|
36
|
+
```
|
37
|
+
|
38
|
+
## Usage
|
39
|
+
|
40
|
+
```python
|
41
|
+
import parabellum as pb
|
42
|
+
```
|
43
|
+
|
44
|
+
## TODO
|
45
|
+
|
46
|
+
- [ ] Parallel pygame visualization
|
47
|
+
- [ ] Use color to indicate unit health?
|
48
|
+
- [ ] Add the ability to watch an ongoing game.
|
49
|
+
- [ ] Test friendly fire for bugs.
|
50
|
+
- [ ] Start sim from arbitrary state.
|
51
|
+
- [ ] Record when the episode ends in a state/obs variable.
|
52
|
+
- [ ] Find the source of the bug that occurs when there are more allies than enemies.
|
53
|
+
- [ ] The y-axis is inverted in the Parabellum visualization.
|
54
|
+
- [ ] Can units see through obstacles?
|
55
|
+
|
@@ -0,0 +1,9 @@
|
|
1
|
+
parabellum/.ipynb_checkpoints/__init__-checkpoint.py,sha256=Yt1RkvkGIJdps0Axpz0ouu-Aaa07032kX04l1l7LXTw,118
|
2
|
+
parabellum/.ipynb_checkpoints/env-checkpoint.py,sha256=Z0PD3MJb9Amxl84MMtghTCF92Gr4ln9qSyRx2DSY15Y,11589
|
3
|
+
parabellum/.ipynb_checkpoints/vis-checkpoint.py,sha256=7zmFqU99gXSW6ueTeEp3CKMJ9XmrTgJkVEpktdLWd_4,8999
|
4
|
+
parabellum/__init__.py,sha256=Yt1RkvkGIJdps0Axpz0ouu-Aaa07032kX04l1l7LXTw,118
|
5
|
+
parabellum/env.py,sha256=Z0PD3MJb9Amxl84MMtghTCF92Gr4ln9qSyRx2DSY15Y,11589
|
6
|
+
parabellum/vis.py,sha256=7zmFqU99gXSW6ueTeEp3CKMJ9XmrTgJkVEpktdLWd_4,8999
|
7
|
+
parabellum-0.0.0.dist-info/METADATA,sha256=cV1VBjjoFLEUmDebGlRpARDdJULpNQ6JspYJ5dqU5ns,1588
|
8
|
+
parabellum-0.0.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
9
|
+
parabellum-0.0.0.dist-info/RECORD,,
|