parabellum 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum/env.py +87 -10
- parabellum/run.py +2 -0
- parabellum/vis.py +12 -7
- {parabellum-0.2.14.dist-info → parabellum-0.2.16.dist-info}/METADATA +1 -1
- parabellum-0.2.16.dist-info/RECORD +8 -0
- parabellum-0.2.14.dist-info/RECORD +0 -8
- {parabellum-0.2.14.dist-info → parabellum-0.2.16.dist-info}/WHEEL +0 -0
parabellum/env.py
CHANGED
@@ -17,6 +17,8 @@ from functools import partial
|
|
17
17
|
class Scenario:
|
18
18
|
"""Parabellum scenario"""
|
19
19
|
|
20
|
+
terrain_raster: chex.Array
|
21
|
+
|
20
22
|
obstacle_coords: chex.Array # TODO: use map instead of obstacles
|
21
23
|
obstacle_deltas: chex.Array
|
22
24
|
|
@@ -31,10 +33,9 @@ class Scenario:
|
|
31
33
|
# default scenario
|
32
34
|
scenarios = {
|
33
35
|
"default": Scenario(
|
34
|
-
jnp.
|
35
|
-
jnp.array([[
|
36
|
-
jnp.array([[
|
37
|
-
jnp.array([[0, 12], [0, 1]]) * 8,
|
36
|
+
jnp.eye(128, dtype=jnp.uint8),
|
37
|
+
jnp.array([[80, 0], [16, 12]]),
|
38
|
+
jnp.array([[0, 80], [0, 20]]),
|
38
39
|
jnp.zeros((19,), dtype=jnp.uint8),
|
39
40
|
9,
|
40
41
|
10,
|
@@ -44,13 +45,74 @@ scenarios = {
|
|
44
45
|
|
45
46
|
class Parabellum(SMAX):
|
46
47
|
def __init__(self, scenario: Scenario, **kwargs):
|
47
|
-
|
48
|
+
map_height, map_width = scenario.terrain_raster.shape
|
49
|
+
args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
|
50
|
+
super(Parabellum, self).__init__(**args, **kwargs)
|
51
|
+
self.terrain_raster = scenario.terrain_raster
|
48
52
|
self.obstacle_coords = scenario.obstacle_coords
|
49
53
|
self.obstacle_deltas = scenario.obstacle_deltas
|
50
54
|
self.unit_type_attack_blasts = jnp.zeros((19,), dtype=jnp.float32)
|
51
55
|
self.max_steps = 200
|
52
56
|
self._push_units_away = lambda x: x # overwrite push units
|
53
57
|
|
58
|
+
@partial(jax.jit, static_argnums=(0,))
|
59
|
+
def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
|
60
|
+
"""Environment-specific reset."""
|
61
|
+
key, team_0_key, team_1_key = jax.random.split(key, num=3)
|
62
|
+
team_0_start = jnp.stack(
|
63
|
+
[jnp.array([self.map_width / 4, self.map_height / 2])] * self.num_allies
|
64
|
+
)
|
65
|
+
team_0_start_noise = jax.random.uniform(
|
66
|
+
team_0_key, shape=(self.num_allies, 2), minval=-2, maxval=2
|
67
|
+
)
|
68
|
+
team_0_start = team_0_start + team_0_start_noise
|
69
|
+
team_1_start = jnp.stack(
|
70
|
+
[jnp.array([self.map_width / 4 * 3, self.map_height / 2])]
|
71
|
+
* self.num_enemies
|
72
|
+
)
|
73
|
+
team_1_start_noise = jax.random.uniform(
|
74
|
+
team_1_key, shape=(self.num_enemies, 2), minval=-2, maxval=2
|
75
|
+
)
|
76
|
+
team_1_start = team_1_start + team_1_start_noise
|
77
|
+
unit_positions = jnp.concatenate([team_0_start, team_1_start])
|
78
|
+
key, pos_key = jax.random.split(key)
|
79
|
+
generated_unit_positions = self.position_generator.generate(pos_key)
|
80
|
+
unit_positions = jax.lax.select(
|
81
|
+
self.smacv2_position_generation, generated_unit_positions, unit_positions
|
82
|
+
)
|
83
|
+
unit_teams = jnp.zeros((self.num_agents,))
|
84
|
+
unit_teams = unit_teams.at[self.num_allies :].set(1)
|
85
|
+
unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
|
86
|
+
# default behaviour spawn all marines
|
87
|
+
unit_types = (
|
88
|
+
jnp.zeros((self.num_agents,), dtype=jnp.uint8)
|
89
|
+
if self.scenario is None
|
90
|
+
else self.scenario
|
91
|
+
)
|
92
|
+
key, unit_type_key = jax.random.split(key)
|
93
|
+
generated_unit_types = self.unit_type_generator.generate(unit_type_key)
|
94
|
+
unit_types = jax.lax.select(
|
95
|
+
self.smacv2_unit_type_generation, generated_unit_types, unit_types
|
96
|
+
)
|
97
|
+
unit_health = self.unit_type_health[unit_types]
|
98
|
+
state = State(
|
99
|
+
unit_positions=unit_positions,
|
100
|
+
unit_alive=jnp.ones((self.num_agents,), dtype=jnp.bool_),
|
101
|
+
unit_teams=unit_teams,
|
102
|
+
unit_health=unit_health,
|
103
|
+
unit_types=unit_types,
|
104
|
+
prev_movement_actions=jnp.zeros((self.num_agents, 2)),
|
105
|
+
prev_attack_actions=jnp.zeros((self.num_agents,), dtype=jnp.int32),
|
106
|
+
time=0,
|
107
|
+
terminal=False,
|
108
|
+
unit_weapon_cooldowns=unit_weapon_cooldowns,
|
109
|
+
)
|
110
|
+
state = self._push_units_away(state)
|
111
|
+
obs = self.get_obs(state)
|
112
|
+
world_state = self.get_world_state(state)
|
113
|
+
obs["world_state"] = jax.lax.stop_gradient(world_state)
|
114
|
+
return obs, state
|
115
|
+
|
54
116
|
def _our_push_units_away(
|
55
117
|
self, pos, unit_types, firmness: float = 1.0
|
56
118
|
): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
|
@@ -79,13 +141,24 @@ class Parabellum(SMAX):
|
|
79
141
|
actions: Tuple[chex.Array, chex.Array],
|
80
142
|
) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
|
81
143
|
@partial(jax.vmap, in_axes=(None, None, 0, 0))
|
82
|
-
def
|
144
|
+
def intersect_fn(pos, new_pos, obs, obs_end):
|
83
145
|
d1 = jnp.cross(obs - pos, new_pos - pos)
|
84
146
|
d2 = jnp.cross(obs_end - pos, new_pos - pos)
|
85
147
|
d3 = jnp.cross(pos - obs, obs_end - obs)
|
86
148
|
d4 = jnp.cross(new_pos - obs, obs_end - obs)
|
87
149
|
return (d1 * d2 <= 0) & (d3 * d4 <= 0)
|
88
150
|
|
151
|
+
def raster_crossing(pos, new_pos):
|
152
|
+
pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
|
153
|
+
raster = self.terrain_raster
|
154
|
+
axis = jnp.argmax(jnp.abs(new_pos - pos), axis=-1)
|
155
|
+
minimum = jnp.minimum(pos[axis], new_pos[axis]).squeeze()
|
156
|
+
maximum = jnp.maximum(pos[axis], new_pos[axis]).squeeze()
|
157
|
+
segment = jnp.where(axis == 0, raster[pos[1]], raster.T[pos[0]])
|
158
|
+
segment = jnp.where(jnp.arange(segment.shape[0]) >= minimum, segment, 0)
|
159
|
+
segment = jnp.where(jnp.arange(segment.shape[0]) <= maximum, segment, 0)
|
160
|
+
return jnp.any(segment)
|
161
|
+
|
89
162
|
def update_position(idx, vec):
|
90
163
|
# Compute the movements slightly strangely.
|
91
164
|
# The velocities below are for diagonal directions
|
@@ -109,8 +182,10 @@ class Parabellum(SMAX):
|
|
109
182
|
############################################ avoid going into obstacles
|
110
183
|
obs = self.obstacle_coords
|
111
184
|
obs_end = obs + self.obstacle_deltas
|
112
|
-
inters = jnp.any(
|
113
|
-
|
185
|
+
inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end))
|
186
|
+
rastersects = raster_crossing(pos, new_pos)
|
187
|
+
flag = jnp.logical_or(inters, rastersects)
|
188
|
+
new_pos = jnp.where(flag, pos, new_pos)
|
114
189
|
|
115
190
|
#######################################################################
|
116
191
|
#######################################################################
|
@@ -245,8 +320,10 @@ class Parabellum(SMAX):
|
|
245
320
|
obst_end = obst_start + obstacle_deltas
|
246
321
|
|
247
322
|
def check_obstacles(pos, new_pos, obst_start, obst_end):
|
248
|
-
|
249
|
-
|
323
|
+
intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
|
324
|
+
rastersect = raster_crossing(pos, new_pos)
|
325
|
+
flag = jnp.logical_or(intersects, rastersect)
|
326
|
+
return jnp.where(flag, pos, new_pos)
|
250
327
|
|
251
328
|
pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
|
252
329
|
pos, new_pos, obst_start, obst_end
|
parabellum/run.py
CHANGED
parabellum/vis.py
CHANGED
@@ -7,6 +7,7 @@ from jax import vmap
|
|
7
7
|
from jax import tree_util
|
8
8
|
from functools import partial
|
9
9
|
import darkdetect
|
10
|
+
import numpy as np
|
10
11
|
import pygame
|
11
12
|
import os
|
12
13
|
from moviepy.editor import ImageSequenceClip
|
@@ -47,7 +48,7 @@ class Visualizer(SMAXVisualizer):
|
|
47
48
|
if multi_dim:
|
48
49
|
n_envs = self.state_seq[0][1].unit_positions.shape[0]
|
49
50
|
if not self.have_expanded:
|
50
|
-
state_seqs = vmap(env.expand_state_seq)(self.state_seq)
|
51
|
+
state_seqs = vmap(self.env.expand_state_seq)(self.state_seq)
|
51
52
|
self.have_expanded = True
|
52
53
|
for i in range(n_envs):
|
53
54
|
state_seq = jax.tree_map(lambda x: x[i], state_seqs)
|
@@ -62,12 +63,19 @@ class Visualizer(SMAXVisualizer):
|
|
62
63
|
def animate_one(self, state_seq, action_seq, save_fname):
|
63
64
|
frames = [] # frames for the video
|
64
65
|
pygame.init() # initialize pygame
|
66
|
+
terrain = np.array(self.env.terrain_raster)
|
67
|
+
rgb_array = np.zeros((terrain.shape[0], terrain.shape[1], 3), dtype=np.uint8)
|
68
|
+
rgb_array[terrain == 1] = self.fg
|
69
|
+
mask_surface = pygame.surfarray.make_surface(rgb_array)
|
70
|
+
mask_surface = pygame.transform.scale(mask_surface, (self.s, self.s))
|
71
|
+
|
65
72
|
for idx, (_, state, _) in tqdm(enumerate(state_seq), total=len(self.state_seq)):
|
66
73
|
action = action_seq[idx // self.env.world_steps_per_env_step]
|
67
74
|
screen = pygame.Surface(
|
68
75
|
(self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
|
69
76
|
)
|
70
77
|
screen.fill(self.bg) # fill the screen with the background color
|
78
|
+
screen.blit(mask_surface, (0, 0))
|
71
79
|
|
72
80
|
self.render_agents(screen, state) # render the agents
|
73
81
|
self.render_action(screen, action)
|
@@ -80,7 +88,6 @@ class Visualizer(SMAXVisualizer):
|
|
80
88
|
|
81
89
|
# rotate the screen and append to frames
|
82
90
|
frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))
|
83
|
-
|
84
91
|
# save the images
|
85
92
|
clip = ImageSequenceClip(frames, fps=48)
|
86
93
|
clip.write_videofile(save_fname, fps=48)
|
@@ -99,7 +106,6 @@ class Visualizer(SMAXVisualizer):
|
|
99
106
|
for idx, (pos, team, kind, hp) in enumerate(time_tuple):
|
100
107
|
face_col = self.fg if int(team.item()) == 0 else self.bg
|
101
108
|
pos = tuple((pos * self.scale).tolist())
|
102
|
-
|
103
109
|
# draw the agent
|
104
110
|
if hp > 0:
|
105
111
|
hp_frac = hp / self.env.unit_type_health[kind]
|
@@ -231,9 +237,8 @@ if __name__ == "__main__":
|
|
231
237
|
# small_multiples() # testing small multiples (not working yet)
|
232
238
|
# exit()
|
233
239
|
|
234
|
-
n_envs =
|
235
|
-
|
236
|
-
env = Parabellum(scenarios["default"], **kwargs)
|
240
|
+
n_envs = 2
|
241
|
+
env = Parabellum(scenarios["default"])
|
237
242
|
rng, reset_rng = random.split(random.PRNGKey(0))
|
238
243
|
reset_key = random.split(reset_rng, n_envs)
|
239
244
|
obs, state = vmap(env.reset)(reset_key)
|
@@ -243,7 +248,7 @@ if __name__ == "__main__":
|
|
243
248
|
rng, act_rng, step_rng = random.split(rng, 3)
|
244
249
|
act_key = random.split(act_rng, (len(env.agents), n_envs))
|
245
250
|
act = {
|
246
|
-
a: vmap(env.action_space(a).sample)(act_key[i])
|
251
|
+
a: jnp.ones_like(vmap(env.action_space(a).sample)(act_key[i]))
|
247
252
|
for i, a in enumerate(env.agents)
|
248
253
|
}
|
249
254
|
step_key = random.split(step_rng, n_envs)
|
@@ -0,0 +1,8 @@
|
|
1
|
+
parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
|
2
|
+
parabellum/env.py,sha256=d6agGy-kTRIg_r0QKCL_7iztzwhaTfsb4yhtUQfdgx0,16024
|
3
|
+
parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
|
4
|
+
parabellum/run.py,sha256=0IWpqcQ_qfFeElbupF5vOs_CByFfpXYuGGUHYuurFM4,3412
|
5
|
+
parabellum/vis.py,sha256=euT7VNPpKW9h0bjXwtYBa4MJRXuELfH3JnUm5ulr3s0,10559
|
6
|
+
parabellum-0.2.16.dist-info/METADATA,sha256=eXEfS4FXFp4Xrp4g3hrKnvh-fIHzmHcWlnZrIRjdF4k,3223
|
7
|
+
parabellum-0.2.16.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
8
|
+
parabellum-0.2.16.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
|
2
|
-
parabellum/env.py,sha256=rCn6iPLeFpqitncD9nEc0KA6N9JCMmiSyP9u2meOJxk,12325
|
3
|
-
parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
|
4
|
-
parabellum/run.py,sha256=lVNBsMc8HY4Tqdjs_1MXGBvIzuN05brbRiqp0xlRc6c,3410
|
5
|
-
parabellum/vis.py,sha256=u7ifxWzHf96WgLTz_hw0ijy6-7wePd7lf0p-yD-NCQY,10212
|
6
|
-
parabellum-0.2.14.dist-info/METADATA,sha256=wEiXzwPfnigG5ZSANPFwGEjLCDU5D0c7qbpvEi6Gbm8,3223
|
7
|
-
parabellum-0.2.14.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
8
|
-
parabellum-0.2.14.dist-info/RECORD,,
|
File without changes
|