parabellum 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum/.ipynb_checkpoints/__init__-checkpoint.py +4 -0
- parabellum/.ipynb_checkpoints/env-checkpoint.py +296 -0
- parabellum/.ipynb_checkpoints/vis-checkpoint.py +230 -0
- parabellum/__init__.py +2 -2
- parabellum/env.py +51 -12
- parabellum/vis.py +11 -5
- {parabellum-0.1.0.dist-info → parabellum-0.1.2.dist-info}/METADATA +19 -2
- parabellum-0.1.2.dist-info/RECORD +9 -0
- parabellum-0.1.0.dist-info/RECORD +0 -6
- {parabellum-0.1.0.dist-info → parabellum-0.1.2.dist-info}/WHEEL +0 -0
@@ -0,0 +1,296 @@
|
|
1
|
+
"""Parabellum environment based on SMAX"""
|
2
|
+
|
3
|
+
import jax.numpy as jnp
|
4
|
+
import jax
|
5
|
+
import numpy as np
|
6
|
+
from jax import random
|
7
|
+
from jax import jit
|
8
|
+
from flax.struct import dataclass
|
9
|
+
import chex
|
10
|
+
from jaxmarl.environments.smax.smax_env import State, SMAX
|
11
|
+
from typing import Tuple, Dict
|
12
|
+
from functools import partial
|
13
|
+
|
14
|
+
|
15
|
+
@dataclass
class Scenario:
    """Static description of a Parabellum battle: obstacles and team layout.

    Fields are declared in the order callers pass them positionally, so the
    order must not change.
    """

    # Start point of every obstacle segment; Parabellum adds the matching
    # entry of ``obstacle_deltas`` to obtain the segment's end point.
    obstacle_coords: chex.Array
    # Offset from each obstacle's start point to its end point.
    obstacle_deltas: chex.Array

    # Unit type index per unit (the default scenario uses 19 zeros for
    # 9 allies + 10 enemies).
    unit_types: chex.Array
    # Team sizes.
    num_allies: int
    num_enemies: int

    # NOTE(review): presumably consumed by the SMAX base class when it
    # receives the scenario -- confirm against JaxMARL.
    smacv2_position_generation: bool = False
    smacv2_unit_type_generation: bool = False
|
28
|
+
|
29
|
+
|
30
|
+
# Built-in scenarios, keyed by name. "default" has two wall segments and a
# 9 vs 10 battle where every unit is of type 0.
scenarios = {
    "default": Scenario(
        obstacle_coords=jnp.array([[6, 10], [26, 10]]) * 8,
        obstacle_deltas=jnp.array([[0, 12], [0, 1]]) * 8,
        unit_types=jnp.zeros((19,), dtype=jnp.uint8),
        num_allies=9,
        num_enemies=10,
    )
}
|
40
|
+
|
41
|
+
|
42
|
+
class Parabellum(SMAX):
    """SMAX environment extended with line-segment obstacles.

    The class replaces SMAX's ``_world_step`` so that a unit whose movement
    would cross an obstacle segment stays where it is, and it moves the
    "push overlapping units apart" logic inside that step so the obstacle
    check can be repeated after the push.
    """

    def __init__(
        self,
        scenario: Scenario = scenarios["default"],
        unit_type_attack_blasts=jnp.array([0, 0, 0, 0, 0, 0]) + 8,
        **kwargs,
    ):
        """Build the environment.

        Args:
            scenario: obstacle layout and team description; also forwarded
                to the SMAX constructor.
            unit_type_attack_blasts: blast radius looked up by unit type.
            **kwargs: forwarded unchanged to the SMAX constructor.
        """
        super().__init__(scenario=scenario, **kwargs)
        self.unit_type_attack_blasts = unit_type_attack_blasts
        # Obstacles are kept as float32 start points plus float32 offsets.
        self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
        self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)
        # Assigned after the parent constructor has run.
        self.max_steps = 200

    def _push_units_away(self, state: State, firmness: float = 1.0):
        """Return ``state`` unchanged.

        Unit separation is performed inside ``_world_step`` (through
        ``_our_push_units_away``) so that obstacle constraints can be
        re-applied afterwards; this override therefore does nothing.
        """
        return state

    def _our_push_units_away(self, pos, unit_types, firmness: float = 1.0):
        """Push overlapping units apart and return the new positions.

        State-free copy of SMAX's ``_push_units_away``.

        Args:
            pos: unit positions, indexed as (unit, coordinate).
            unit_types: unit type index per unit, used to look up radii.
            firmness: scale factor applied to the push.
        """
        offsets = pos[:, None] - pos[None, :]
        # The identity keeps the self-distance on the diagonal away from
        # zero; the epsilon guards the division for coincident units.
        distances = (
            jnp.linalg.norm(offsets, axis=-1)
            + jnp.identity(self.num_agents)
            + 1e-6
        )
        radii = self.unit_type_radiuses[unit_types]
        combined_radii = radii[:, None] + radii[None, :]
        # Positive only where two units are closer than the sum of radii.
        overlap = jax.nn.relu(combined_radii / distances - 1.0)
        push = jnp.sum(offsets * overlap[:, :, None], axis=1)
        return pos + firmness * push / 2

    @partial(jax.jit, static_argnums=(0,))  # replace the _world_step method
    def _world_step(  # modified version of JaxMARL's SMAX _world_step
        self,
        key: chex.PRNGKey,
        state: State,
        actions: Tuple[chex.Array, chex.Array],
    ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
        """Advance the world by one step and return the updated state.

        Args:
            key: PRNG key, split into one key per agent.
            state: current environment state.
            actions: pair ``(movement_actions, attack_actions)``, each
                batched over agents.

        Returns:
            ``state`` with ``unit_health``, ``unit_positions`` and
            ``unit_weapon_cooldowns`` replaced.

        NOTE(review): the return annotation lists a 5-tuple, but the body
        returns a single ``State``.
        """

        @partial(jax.vmap, in_axes=(None, None, 0, 0))
        def inter_fn(pos, new_pos, obs, obs_end):
            # Segment intersection test between the move pos -> new_pos and
            # the obstacle obs -> obs_end, via signs of 2-D cross products.
            # "<=" also counts touching / collinear configurations.
            move = new_pos - pos
            wall = obs_end - obs
            d1 = jnp.cross(obs - pos, move)
            d2 = jnp.cross(obs_end - pos, move)
            d3 = jnp.cross(pos - obs, wall)
            d4 = jnp.cross(new_pos - obs, wall)
            return (d1 * d2 <= 0) & (d3 * d4 <= 0)

        def update_position(idx, vec):
            # Displacement is the movement vector scaled by the unit type's
            # velocity and by the duration of one step.
            pos = state.unit_positions[idx]
            speed = self.unit_type_velocities[state.unit_types[idx]]
            new_pos = pos + vec * speed * self.time_per_step

            # avoid going out of bounds
            upper = jnp.array([self.map_width, self.map_height])
            new_pos = jnp.maximum(jnp.minimum(new_pos, upper), jnp.zeros((2,)))

            # avoid going into obstacles: a blocked unit keeps its position
            obs = self.obstacle_coords
            obs_end = obs + self.obstacle_deltas
            blocked = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
            return jnp.where(blocked, pos, new_pos)

        def bystander_fn(attacked_idx):
            # Units close enough to the attacked unit to get hit.
            # NOTE(review): the mask is built by multiplying an all-zero
            # vector, so it is always zero and blast damage is currently
            # inert. Confirm the intended behaviour before changing this.
            blast_radius = self.unit_type_attack_blasts[state.unit_types[attacked_idx]]
            dist_to_target = jnp.linalg.norm(
                state.unit_positions - state.unit_positions[attacked_idx], axis=-1
            )
            idxs = jnp.zeros((self.num_agents,))
            idxs = idxs * (dist_to_target < blast_radius)
            return idxs

        def update_agent_health(idx, action, key):  # TODO: add attack blasts
            # for team 1, their attack actions are labelled in
            # reverse order because that is the order they are
            # observed in
            attacked_idx = jax.lax.cond(
                idx < self.num_allies,
                lambda: action + self.num_allies - self.num_movement_actions,
                lambda: self.num_allies - 1 - (action - self.num_movement_actions),
            )
            # deal with no-op attack actions (i.e. agents that are moving instead)
            attacked_idx = jax.lax.select(
                action < self.num_movement_actions, idx, attacked_idx
            )

            unit_type = state.unit_types[idx]
            in_reach = (
                jnp.linalg.norm(
                    state.unit_positions[idx] - state.unit_positions[attacked_idx]
                )
                < self.unit_type_attack_ranges[unit_type]
            )
            both_alive = state.unit_alive[idx] & state.unit_alive[attacked_idx]
            not_self = idx != attacked_idx
            weapon_ready = state.unit_weapon_cooldowns[idx] <= 0.0
            attack_valid = in_reach & both_alive & not_self & weapon_ready

            health_diff = jax.lax.select(
                attack_valid,
                -self.unit_type_attacks[unit_type],
                0.0,
            )

            # Bystander health diff (see the note in bystander_fn).
            bystander_idxs = bystander_fn(attacked_idx)  # TODO: use
            bystander_valid = (
                jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
                .astype(jnp.bool_)
                .astype(jnp.float32)
            )
            bystander_health_diff = (
                bystander_valid * -self.unit_type_attacks[unit_type]
            )

            # design choice based on the pysc2 randomness details.
            # See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
            cooldown_deviation = jax.random.uniform(
                key, minval=-self.time_per_step, maxval=2 * self.time_per_step
            )
            cooldown = self.unit_type_weapon_cooldowns[unit_type] + cooldown_deviation
            cooldown_diff = jax.lax.select(
                attack_valid,
                # subtract the current cooldown because it is added back
                # later; the net effect is setting the cooldown to `cooldown`
                cooldown - state.unit_weapon_cooldowns[idx],
                -self.time_per_step,
            )
            return (
                health_diff,
                attacked_idx,
                cooldown_diff,
                (bystander_health_diff, bystander_idxs),
            )

        def perform_agent_action(idx, action, key):
            movement_action, attack_action = action
            moved_to = update_position(idx, movement_action)
            health_diff, attacked_idxes, cooldown_diff, bystander = (
                update_agent_health(idx, attack_action, key)
            )
            return moved_to, (health_diff, attacked_idxes), cooldown_diff, bystander

        agent_keys = jax.random.split(key, num=self.num_agents)
        agent_ids = jnp.arange(self.num_agents)
        pos, (health_diff, attacked_idxes), cooldown_diff, bystander = jax.vmap(
            perform_agent_action
        )(agent_ids, actions, agent_keys)

        # Push overlapping units apart, then make sure the push did not move
        # any unit through an obstacle.
        pushed_pos = self._our_push_units_away(pos, state.unit_types)
        obs = self.obstacle_coords
        obs_end = obs + self.obstacle_deltas

        def check_obstacles(before, after, seg_start, seg_end):
            crossed = jnp.any(inter_fn(before, after, seg_start, seg_end))
            return jnp.where(crossed, before, after)

        pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
            pos, pushed_pos, obs, obs_end
        )

        # Multiple enemies can attack the same unit.
        # We have `(health_diff, attacked_idx)` pairs.
        # `jax.lax.scatter_add` aggregates these exactly
        # in the way we want -- duplicate idxes will have their
        # health differences added together. The dnums usage was
        # inferred from a test:
        # https://github.com/google/jax/blob/main/tests/lax_test.py#L2296
        dnums = jax.lax.ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,),
        )
        summed_health = jax.lax.scatter_add(
            state.unit_health,
            jnp.expand_dims(attacked_idxes, 1),
            health_diff,
            dnums,
        )
        unit_health = jnp.maximum(summed_health, 0.0)

        # Subtracting bystander health.
        # NOTE(review): update_agent_health returns
        # (bystander_health_diff, bystander_idxs), so the second element
        # taken here is the mask, not the damage. It is all zeros today,
        # which makes this subtraction a no-op.
        bystander_health_diff = bystander[1]
        unit_health = unit_health - bystander_health_diff.sum(axis=0)  # might be axis=1

        unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
        state = state.replace(
            unit_health=unit_health,
            unit_positions=pos,
            unit_weapon_cooldowns=unit_weapon_cooldowns,
        )
        return state
|
280
|
+
|
281
|
+
if __name__ == "__main__":
    # Smoke test: roll out 100 steps with random actions drawn from 0-4 on
    # the default scenario.
    env = Parabellum(map_width=256, map_height=256)
    rng, key = random.split(random.PRNGKey(0))
    obs, state = env.reset(key)
    state_seq = []
    for step in range(100):
        # Use independent keys for action sampling and for the environment
        # step; a JAX PRNG key must not be consumed twice.
        rng, key_act, key_step = random.split(rng, 3)
        agent_keys = random.split(key_act, len(env.agents))
        actions = {
            agent: jax.random.randint(agent_key, (), 0, 5)
            for agent, agent_key in zip(env.agents, agent_keys)
        }
        # Keep the fresh observation so the recorded tuple matches the state.
        obs, state, _, _, _ = env.step(key_step, state, actions)
        state_seq.append((obs, state, actions))
|
295
|
+
|
296
|
+
|
@@ -0,0 +1,230 @@
|
|
1
|
+
"""Visualizer for the Parabellum environment"""
|
2
|
+
|
3
|
+
from tqdm import tqdm
|
4
|
+
import jax.numpy as jnp
|
5
|
+
import jax
|
6
|
+
from jax import vmap
|
7
|
+
from functools import partial
|
8
|
+
import darkdetect
|
9
|
+
import pygame
|
10
|
+
from moviepy.editor import ImageSequenceClip
|
11
|
+
from typing import Optional
|
12
|
+
from jaxmarl.environments.multi_agent_env import MultiAgentEnv
|
13
|
+
from jaxmarl.viz.visualizer import SMAXVisualizer
|
14
|
+
|
15
|
+
# default dict
|
16
|
+
from collections import defaultdict
|
17
|
+
|
18
|
+
|
19
|
+
# constants
|
20
|
+
action_to_symbol = {0: "↑", 1: "→", 2: "↓", 3: "←", 4: "Ø"}
|
21
|
+
|
22
|
+
|
23
|
+
class Visualizer(SMAXVisualizer):
    """Render a recorded Parabellum rollout to a video using pygame."""

    def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None):
        """Store colours, scale and per-step actions / bullets.

        Args:
            env: the environment the rollout was produced with.
            state_seq: sequence of 3-tuples whose second element is the
                state and whose third element is the action dict.
            reward_seq: forwarded to the parent visualizer.
        """
        super().__init__(env, state_seq, reward_seq)
        # remove fig and ax from super
        self.fig, self.ax = None, None
        dark = darkdetect.isDark()  # query the OS theme once
        self.bg = (0, 0, 0) if dark else (255, 255, 255)
        self.fg = (235, 235, 235) if dark else (20, 20, 20)
        self.s = 1000  # side length of the square canvas, in pixels
        self.scale = self.s / self.env.map_width
        self.action_seq = [action for _, _, action in state_seq]  # bcs SMAX bug
        self.bullet_seq = bullet_fn(self.env, self.state_seq)

    def render_agents(self, screen, state):
        """Draw every unit with positive health as a circle.

        The circle radius shrinks with the unit's remaining health fraction.
        Team 0 is filled with the foreground colour, the other team with the
        background colour; both get a foreground outline.
        """
        units = zip(
            state.unit_positions,
            state.unit_teams,
            state.unit_types,
            state.unit_health,
        )
        for pos, team, kind, hp in units:
            if hp > 0:
                face_col = self.fg if int(team.item()) == 0 else self.bg
                center = tuple((pos * self.scale).tolist())
                hp_frac = hp / self.env.unit_type_health[kind]
                unit_size = self.env.unit_type_radiuses[kind]
                radius = jnp.ceil((unit_size * self.scale * hp_frac)).astype(int) + 1
                pygame.draw.circle(screen, face_col, center, radius)
                pygame.draw.circle(screen, self.fg, center, radius, 1)

    def render_action(self, screen, action):
        """Draw one symbol per agent along the left (allies) and right
        (enemies) edge of the canvas; unknown actions are drawn as "Ø"."""

        def coord_fn(idx, n, team):
            return (
                self.s / 20 if team == 0 else self.s - self.s / 20,
                # vertically centered so that n / 2 is above and below the center
                self.s / 2 - (n / 2) * self.s / 20 + idx * self.s / 20,
            )

        # The font does not depend on the agent, so build it once per call
        # instead of once per agent.
        font = pygame.font.SysFont("Fira Code", jnp.sqrt(self.s).astype(int).item())
        teams = (
            ("ally", self.env.num_allies, 0),
            ("enemy", self.env.num_enemies, 1),
        )
        for prefix, count, team in teams:
            for idx in range(count):
                code = action[f"{prefix}_{idx}"].astype(int).item()
                symb = action_to_symbol.get(code, "Ø")
                text = font.render(symb, True, self.fg)
                screen.blit(text, coord_fn(idx, count, team))

    def render_obstacles(self, screen):
        """Draw each obstacle as a 5-pixel-wide line from its start point to
        start + delta."""
        for c, d in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
            end = tuple(((c + d) * self.scale).tolist())
            start = tuple((c * self.scale).tolist())
            pygame.draw.line(screen, self.fg, start, end, 5)

    def render_bullets(self, screen, bullets, jdx):
        """Draw every bullet at fraction ``(jdx + 1) / 8`` of the way from its
        source to its target.

        Args:
            screen: surface to draw on.
            bullets: pair ``(ally_bullets, enemy_bullets)``; each entry
                iterates as ``(source, target)`` positions.
            jdx: sub-frame index within the current step.
        """
        jdx += 1
        # Ally bullets first, then enemy bullets, both drawn identically.
        for team_bullets in bullets:
            for source, target in team_bullets:
                position = source + (target - source) * jdx / 8
                position *= self.scale
                pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)

    def animate(self, save_fname: str = "parabellum.mp4"):
        """Render all frames and write them to ``save_fname`` at 48 fps.

        Returns:
            The ``ImageSequenceClip`` that was written.
        """
        if not self.have_expanded:
            self.expand_state_seq()
        frames = []  # frames for the video
        pygame.init()  # initialize pygame
        # NOTE(review): the factor 8 below presumably matches the number of
        # frames each recorded step is expanded into -- confirm against
        # SMAXVisualizer.expand_state_seq.
        for idx, (_, state, _) in tqdm(
            enumerate(self.state_seq), total=len(self.state_seq)
        ):
            screen = pygame.Surface(
                (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
            )
            screen.fill(self.bg)  # fill the screen with the background color

            self.render_agents(screen, state)  # render the agents
            self.render_action(screen, self.action_seq[idx // 8])
            self.render_obstacles(screen)  # render the obstacles

            # bullets (there is one fewer bullet entry than recorded steps)
            if idx < len(self.bullet_seq) * 8:
                bullets = self.bullet_seq[idx // 8]
                self.render_bullets(screen, bullets, idx % 8)

            # rotate the screen and append to frames
            frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))

        # save the images
        clip = ImageSequenceClip(frames, fps=48)
        clip.write_videofile(save_fname, fps=48)
        pygame.quit()

        return clip
|
135
|
+
|
136
|
+
|
137
|
+
# functions
|
138
|
+
# bullet functions
|
139
|
+
def dist_fn(env, pos):  # computing the distances between all ally and enemy agents
    """Return ally-to-enemy distances and their transpose.

    Args:
        env: object exposing ``num_allies``; the first ``num_allies`` rows
            of ``pos`` are treated as allies, the rest as enemies.
        pos: positions indexed as (unit, coordinate).

    Returns:
        ``{"ally": D, "enemy": D.T}`` where ``D[i, j]`` is the distance from
        ally ``i`` to enemy ``j``.
    """
    n_allies = env.num_allies
    offsets = pos[None, :, :] - pos[:, None, :]
    pairwise = jnp.sqrt((offsets**2).sum(axis=2))
    ally_to_enemy = pairwise[:n_allies, n_allies:]
    return {"ally": ally_to_enemy, "enemy": ally_to_enemy.T}
|
144
|
+
|
145
|
+
|
146
|
+
def range_fn(env, dists, ranges):  # computing what targets are in range
    """Return boolean masks of which opponents each unit can reach.

    Args:
        env: object exposing ``num_allies``.
        dists: ``{"ally": ..., "enemy": ...}`` as produced by ``dist_fn``.
        ranges: attack range per unit, allies first.

    Returns:
        ``{"ally": mask, "enemy": mask}``; an entry is True where the
        distance is strictly below the acting unit's range.
    """
    n_allies = env.num_allies
    ally_reach = ranges[:n_allies][:, None]
    enemy_reach = ranges[n_allies:][:, None]
    return {
        "ally": dists["ally"] < ally_reach,
        "enemy": dists["enemy"] < enemy_reach,
    }
|
150
|
+
|
151
|
+
|
152
|
+
def target_fn(acts, in_range, team):  # computing the one hot valid targets
    """Return a one-hot matrix of the in-range targets each unit attacks.

    Actions 0-4 are movement actions; action ``5 + k`` attacks target ``k``.
    For the ``"enemy"`` team the target order is reversed, matching the
    attack-index decoding in ``Parabellum._world_step``.

    Args:
        acts: mapping from agent name to a scalar integer action; only keys
            starting with ``team`` are used, in insertion order.
        in_range: mapping from team to a boolean (unit, target) mask.
        team: ``"ally"`` or ``"enemy"``.

    Returns:
        A (unit, target) array holding 1 where the unit attacks that target
        and the target is in range, 0 everywhere else.
    """
    num_move_actions = 5  # first 5 are move actions
    t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T
    num_targets = in_range[team].shape[1]

    offset = t_acts - num_move_actions
    if team == "enemy":
        # Enemy attack actions are labelled in reverse order.
        offset = num_targets - 1 - offset

    # Moving units and out-of-bounds targets point at the extra padding row
    # of the identity matrix, whose only non-zero column is sliced off.
    is_attack = (t_acts >= num_move_actions) & (offset >= 0) & (offset < num_targets)
    t_targets = jnp.where(is_attack, offset, num_targets)
    t_attacks = jnp.eye(num_targets + 1)[t_targets][:, :-1]
    return t_attacks * in_range[team]  # one hot valid targets
|
157
|
+
|
158
|
+
|
159
|
+
def attack_fn(env, state_seq):  # one hot attack list
    """Compute, for every recorded step, which in-range targets are attacked.

    Args:
        env: environment exposing ``num_allies`` and
            ``unit_type_attack_ranges``.
        state_seq: sequence of 3-tuples whose second element is the state
            and whose third element is the action dict.

    Returns:
        One ``{"ally": ..., "enemy": ...}`` dict of one-hot attack matrices
        per step.
    """

    def step_attacks(state, acts):
        dists = dist_fn(env, state.unit_positions)
        ranges = env.unit_type_attack_ranges[state.unit_types]
        in_range = range_fn(env, dists, ranges)
        return {team: target_fn(acts, in_range, team) for team in ("ally", "enemy")}

    return [step_attacks(state, acts) for _, state, acts in state_seq]
|
169
|
+
|
170
|
+
|
171
|
+
def bullet_fn(env, states):
    """Build the bullet trajectories drawn between consecutive steps.

    A bullet starts at the shooter's position in one step and ends at the
    target's position in the following step, so the result has one entry
    fewer than ``states``.

    Args:
        env: environment exposing ``num_allies`` (enemy units are stored
            after the allies in ``unit_positions``).
        states: list of 3-tuples whose second element is the state.

    Returns:
        A list of ``(ally_bullets, enemy_bullets)`` pairs; each array stacks
        ``(source, target)`` positions along axis 1.
    """

    def shooter_target_pairs(mask):
        # Rows of (shooter index, target index) for every set entry.
        return jnp.stack(jnp.where(mask == 1)).T

    bullet_seq = []
    steps = zip(attack_fn(env, states), states[:-1], states[1:])
    for one_hot, (_, state, _), (_, n_state, _) in steps:
        ally_pairs = shooter_target_pairs(one_hot["ally"])
        enemy_pairs = shooter_target_pairs(one_hot["enemy"])

        # Indices are team-relative; shift by num_allies to address enemies.
        ally_source = state.unit_positions[ally_pairs[:, 0]]
        ally_target = n_state.unit_positions[ally_pairs[:, 1] + env.num_allies]
        enemy_source = state.unit_positions[enemy_pairs[:, 0] + env.num_allies]
        enemy_target = n_state.unit_positions[enemy_pairs[:, 1]]

        bullet_seq.append(
            (
                jnp.stack((ally_source, ally_target), axis=1),
                jnp.stack((enemy_source, enemy_target), axis=1),
            )
        )
    return bullet_seq
|
200
|
+
|
201
|
+
|
202
|
+
# test the visualizer
|
203
|
+
if __name__ == "__main__":
    # Demo: 9 vs 10 units all taking action 1 for 50 steps on a 32x32 map
    # with a single obstacle, rendered to the default video file.
    from parabellum import Parabellum, Scenario
    from jax import random, numpy as jnp

    s = Scenario(
        jnp.array([[16, 0]]),
        jnp.array([[0, 32]]) * 8,
        jnp.zeros((19,), dtype=jnp.uint8),
        9,
        10,
    )
    env = Parabellum(map_width=32, map_height=32, walls_cause_death=False, scenario=s)
    rng, key = random.split(random.PRNGKey(0))
    obs, state = env.reset(key)
    state_seq = []
    for step in range(50):
        rng, key = random.split(rng)
        actions = {agent: jnp.array(1) for agent in env.agents}
        # Record the state *before* stepping, together with its actions.
        state_seq.append((key, state, actions))
        rng, key_step = random.split(rng)
        obs, state, reward, done, infos = env.step(key_step, state, actions)

    vis = Visualizer(env, state_seq)
    vis.animate()
|
229
|
+
|
230
|
+
|
parabellum/__init__.py
CHANGED
parabellum/env.py
CHANGED
@@ -20,8 +20,8 @@ class Scenario:
|
|
20
20
|
obstacle_deltas: chex.Array
|
21
21
|
|
22
22
|
unit_types: chex.Array
|
23
|
-
num_allies: int
|
24
|
-
num_enemies: int
|
23
|
+
num_allies: int
|
24
|
+
num_enemies: int
|
25
25
|
|
26
26
|
smacv2_position_generation: bool = False
|
27
27
|
smacv2_unit_type_generation: bool = False
|
@@ -33,6 +33,8 @@ scenarios = {
|
|
33
33
|
jnp.array([[6, 10], [26, 10]]) * 8,
|
34
34
|
jnp.array([[0, 12], [0, 1]]) * 8,
|
35
35
|
jnp.zeros((19,), dtype=jnp.uint8),
|
36
|
+
9,
|
37
|
+
10,
|
36
38
|
)
|
37
39
|
}
|
38
40
|
|
@@ -51,6 +53,28 @@ class Parabellum(SMAX):
|
|
51
53
|
self.max_steps = 200
|
52
54
|
# overwrite supers _world_step method
|
53
55
|
|
56
|
+
|
57
|
+
def _push_units_away(self, state: State, firmness: float = 1.0): # we do it inside the _world_step to allow more obstacles constraints
|
58
|
+
return state
|
59
|
+
|
60
|
+
def _our_push_units_away(self, pos, unit_types, firmness: float = 1.0): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
|
61
|
+
delta_matrix = pos[:, None] - pos[None, :]
|
62
|
+
dist_matrix = (
|
63
|
+
jnp.linalg.norm(delta_matrix, axis=-1)
|
64
|
+
+ jnp.identity(self.num_agents)
|
65
|
+
+ 1e-6
|
66
|
+
)
|
67
|
+
radius_matrix = (
|
68
|
+
self.unit_type_radiuses[unit_types][:, None]
|
69
|
+
+ self.unit_type_radiuses[unit_types][None, :]
|
70
|
+
)
|
71
|
+
overlap_term = jax.nn.relu(radius_matrix / dist_matrix - 1.0)
|
72
|
+
unit_positions = (
|
73
|
+
pos
|
74
|
+
+ firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
|
75
|
+
)
|
76
|
+
return unit_positions
|
77
|
+
|
54
78
|
@partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
|
55
79
|
def _world_step( # modified version of JaxMARL's SMAX _world_step
|
56
80
|
self,
|
@@ -58,6 +82,15 @@ class Parabellum(SMAX):
|
|
58
82
|
state: State,
|
59
83
|
actions: Tuple[chex.Array, chex.Array],
|
60
84
|
) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
|
85
|
+
|
86
|
+
@partial(jax.vmap, in_axes=(None, None, 0, 0))
|
87
|
+
def inter_fn(pos, new_pos, obs, obs_end):
|
88
|
+
d1 = jnp.cross(obs - pos, new_pos - pos)
|
89
|
+
d2 = jnp.cross(obs_end - pos, new_pos - pos)
|
90
|
+
d3 = jnp.cross(pos - obs, obs_end - obs)
|
91
|
+
d4 = jnp.cross(new_pos - obs, obs_end - obs)
|
92
|
+
return (d1 * d2 <= 0) & (d3 * d4 <= 0)
|
93
|
+
|
61
94
|
def update_position(idx, vec):
|
62
95
|
# Compute the movements slightly strangely.
|
63
96
|
# The velocities below are for diagonal directions
|
@@ -79,15 +112,6 @@ class Parabellum(SMAX):
|
|
79
112
|
|
80
113
|
#######################################################################
|
81
114
|
############################################ avoid going into obstacles
|
82
|
-
|
83
|
-
@partial(jax.vmap, in_axes=(None, None, 0, 0))
|
84
|
-
def inter_fn(pos, new_pos, obs, obs_end):
|
85
|
-
d1 = jnp.cross(obs - pos, new_pos - pos)
|
86
|
-
d2 = jnp.cross(obs_end - pos, new_pos - pos)
|
87
|
-
d3 = jnp.cross(pos - obs, obs_end - obs)
|
88
|
-
d4 = jnp.cross(new_pos - obs, obs_end - obs)
|
89
|
-
return (d1 * d2 < 0) & (d3 * d4 < 0)
|
90
|
-
|
91
115
|
obs = self.obstacle_coords
|
92
116
|
obs_end = obs + self.obstacle_deltas
|
93
117
|
inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
|
@@ -199,6 +223,20 @@ class Parabellum(SMAX):
|
|
199
223
|
pos, (health_diff, attacked_idxes), cooldown_diff, bystander = jax.vmap(
|
200
224
|
perform_agent_action
|
201
225
|
)(jnp.arange(self.num_agents), actions, keys)
|
226
|
+
|
227
|
+
# checked that no unit passed through an obstacles
|
228
|
+
new_pos = self._our_push_units_away(pos, state.unit_types)
|
229
|
+
|
230
|
+
# avoid going into obstacles after being pushed
|
231
|
+
obs = self.obstacle_coords
|
232
|
+
obs_end = obs + self.obstacle_deltas
|
233
|
+
|
234
|
+
def check_obstacles(pos, new_pos, obs, obs_end):
|
235
|
+
inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
|
236
|
+
return jnp.where(inters, pos, new_pos)
|
237
|
+
|
238
|
+
pos = jax.vmap(check_obstacles, in_axes=(0,0,None,None))(pos, new_pos, obs, obs_end)
|
239
|
+
|
202
240
|
# Multiple enemies can attack the same unit.
|
203
241
|
# We have `(health_diff, attacked_idx)` pairs.
|
204
242
|
# `jax.lax.scatter_add` aggregates these exactly
|
@@ -240,7 +278,6 @@ class Parabellum(SMAX):
|
|
240
278
|
)
|
241
279
|
return state
|
242
280
|
|
243
|
-
|
244
281
|
if __name__ == "__main__":
|
245
282
|
env = Parabellum(map_width=256, map_height=256)
|
246
283
|
rng, key = random.split(random.PRNGKey(0))
|
@@ -255,3 +292,5 @@ if __name__ == "__main__":
|
|
255
292
|
}
|
256
293
|
_, state, _, _, _ = env.step(key, state, actions)
|
257
294
|
state_seq.append((obs, state, actions))
|
295
|
+
|
296
|
+
|
parabellum/vis.py
CHANGED
@@ -17,7 +17,6 @@ from collections import defaultdict
|
|
17
17
|
|
18
18
|
|
19
19
|
# constants
|
20
|
-
|
21
20
|
action_to_symbol = {0: "↑", 1: "→", 2: "↓", 3: "←", 4: "Ø"}
|
22
21
|
|
23
22
|
|
@@ -202,18 +201,23 @@ def bullet_fn(env, states):
|
|
202
201
|
|
203
202
|
# test the visualizer
|
204
203
|
if __name__ == "__main__":
|
205
|
-
from parabellum import
|
204
|
+
from parabellum import Parabellum, Scenario
|
206
205
|
from jax import random, numpy as jnp
|
207
206
|
|
208
|
-
|
207
|
+
s = Scenario(jnp.array([[16, 0]]),
|
208
|
+
jnp.array([[0, 32]]) * 8,
|
209
|
+
jnp.zeros((19,), dtype=jnp.uint8),
|
210
|
+
9,
|
211
|
+
10)
|
212
|
+
env = Parabellum(map_width=32, map_height=32, walls_cause_death=False, scenario=s)
|
209
213
|
rng, key = random.split(random.PRNGKey(0))
|
210
214
|
obs, state = env.reset(key)
|
211
215
|
state_seq = []
|
212
|
-
for step in range(
|
216
|
+
for step in range(50):
|
213
217
|
rng, key = random.split(rng)
|
214
218
|
key_act = random.split(key, len(env.agents))
|
215
219
|
actions = {
|
216
|
-
agent:
|
220
|
+
agent: jnp.array(1)
|
217
221
|
for i, agent in enumerate(env.agents)
|
218
222
|
}
|
219
223
|
state_seq.append((key, state, actions))
|
@@ -222,3 +226,5 @@ if __name__ == "__main__":
|
|
222
226
|
|
223
227
|
vis = Visualizer(env, state_seq)
|
224
228
|
vis.animate()
|
229
|
+
|
230
|
+
|
@@ -1,10 +1,14 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: parabellum
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.2
|
4
4
|
Summary: Parabellum environment for parallel warfare simulation
|
5
|
+
Home-page: https://github.com/syrkis/parabellum
|
6
|
+
License: MIT
|
7
|
+
Keywords: warfare,simulation,parallel,environment
|
5
8
|
Author: Noah Syrkis
|
6
|
-
Author-email:
|
9
|
+
Author-email: desk@syrkis.com
|
7
10
|
Requires-Python: >=3.11,<4.0
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
8
12
|
Classifier: Programming Language :: Python :: 3
|
9
13
|
Classifier: Programming Language :: Python :: 3.11
|
10
14
|
Classifier: Programming Language :: Python :: 3.12
|
@@ -13,6 +17,7 @@ Requires-Dist: jaxmarl (==0.0.3)
|
|
13
17
|
Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
|
14
18
|
Requires-Dist: moviepy (>=1.0.3,<2.0.0)
|
15
19
|
Requires-Dist: pygame (>=2.5.2,<3.0.0)
|
20
|
+
Project-URL: Repository, https://github.com/syrkis/parabellum
|
16
21
|
Description-Content-Type: text/markdown
|
17
22
|
|
18
23
|
# parabellum
|
@@ -35,3 +40,15 @@ pip install parabellum
|
|
35
40
|
import parabellum as pb
|
36
41
|
```
|
37
42
|
|
43
|
+
## TODO
|
44
|
+
|
45
|
+
- [ ] Parallel pygame vis
|
46
|
+
- [ ] Color for health?
|
47
|
+
- [ ] Add the ability to see ongoing game.
|
48
|
+
- [ ] Bug test friendly fire.
|
49
|
+
- [ ] Start sim from arbitrary state.
|
50
|
+
- [ ] Save when the episode ends in some state/obs variable
|
51
|
+
- [ ] Look for the source of the bug when using more Allies than Enemies
|
52
|
+
- [ ] Y inversed axis for parabellum visualization
|
53
|
+
- [ ] Units see through obstacles?
|
54
|
+
|
@@ -0,0 +1,9 @@
|
|
1
|
+
parabellum/.ipynb_checkpoints/__init__-checkpoint.py,sha256=Yt1RkvkGIJdps0Axpz0ouu-Aaa07032kX04l1l7LXTw,118
|
2
|
+
parabellum/.ipynb_checkpoints/env-checkpoint.py,sha256=Z0PD3MJb9Amxl84MMtghTCF92Gr4ln9qSyRx2DSY15Y,11589
|
3
|
+
parabellum/.ipynb_checkpoints/vis-checkpoint.py,sha256=7zmFqU99gXSW6ueTeEp3CKMJ9XmrTgJkVEpktdLWd_4,8999
|
4
|
+
parabellum/__init__.py,sha256=Yt1RkvkGIJdps0Axpz0ouu-Aaa07032kX04l1l7LXTw,118
|
5
|
+
parabellum/env.py,sha256=Z0PD3MJb9Amxl84MMtghTCF92Gr4ln9qSyRx2DSY15Y,11589
|
6
|
+
parabellum/vis.py,sha256=7zmFqU99gXSW6ueTeEp3CKMJ9XmrTgJkVEpktdLWd_4,8999
|
7
|
+
parabellum-0.1.2.dist-info/METADATA,sha256=bIOF-Pkl0IcYSHUPWwVxkuNyv_W-rTfKwxmuGxRjLqk,1549
|
8
|
+
parabellum-0.1.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
9
|
+
parabellum-0.1.2.dist-info/RECORD,,
|
@@ -1,6 +0,0 @@
|
|
1
|
-
parabellum/__init__.py,sha256=CQj_Z18dyAU366U-UAJt8Rggq2mwArfCrn8jfIaLAjA,96
|
2
|
-
parabellum/env.py,sha256=KE_wDPUbvHu02cMl9elT0cV63xVZ3DCpZIBuBM0GoyM,10074
|
3
|
-
parabellum/vis.py,sha256=SWlSj1z09qAXUH2W4LeMHWoICZL5mANytLwR0qDXF-A,8812
|
4
|
-
parabellum-0.1.0.dist-info/METADATA,sha256=jWirrM6EnME0Y3_PetV-dAwifNnLUbSKs-QJsmMCciI,934
|
5
|
-
parabellum-0.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
6
|
-
parabellum-0.1.0.dist-info/RECORD,,
|
File without changes
|