parabellum 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/env.py
CHANGED
parabellum/vis.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Visualizer for the Parabellum environment
|
3
|
+
"""
|
2
4
|
|
3
5
|
from tqdm import tqdm
|
4
6
|
import jax.numpy as jnp
|
@@ -37,7 +39,7 @@ class Visualizer(SMAXVisualizer):
|
|
37
39
|
# remove fig and ax from super
|
38
40
|
self.fig, self.ax = None, None
|
39
41
|
self.bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
|
40
|
-
self.fg = (
|
42
|
+
self.fg = (255, 255, 255) if darkdetect.isDark() else (0, 0, 0)
|
41
43
|
self.s = 1000
|
42
44
|
self.scale = self.s / self.env.map_width
|
43
45
|
self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
|
@@ -57,7 +59,7 @@ class Visualizer(SMAXVisualizer):
|
|
57
59
|
state_seq, action_seq, save_fname.replace(".mp4", f"_{i}.mp4")
|
58
60
|
)
|
59
61
|
else:
|
60
|
-
state_seq = env.expand_state_seq(self.state_seq)
|
62
|
+
state_seq = self.env.expand_state_seq(self.state_seq)
|
61
63
|
self.animate_one(state_seq, self.action_seq, save_fname)
|
62
64
|
|
63
65
|
def animate_one(self, state_seq, action_seq, save_fname):
|
@@ -65,6 +67,8 @@ class Visualizer(SMAXVisualizer):
|
|
65
67
|
pygame.init() # initialize pygame
|
66
68
|
terrain = np.array(self.env.terrain_raster)
|
67
69
|
rgb_array = np.zeros((terrain.shape[0], terrain.shape[1], 3), dtype=np.uint8)
|
70
|
+
if darkdetect.isLight():
|
71
|
+
rgb_array += 255
|
68
72
|
rgb_array[terrain == 1] = self.fg
|
69
73
|
mask_surface = pygame.surfarray.make_surface(rgb_array)
|
70
74
|
mask_surface = pygame.transform.scale(mask_surface, (self.s, self.s))
|
@@ -91,7 +95,7 @@ class Visualizer(SMAXVisualizer):
|
|
91
95
|
# save the images
|
92
96
|
clip = ImageSequenceClip(frames, fps=48)
|
93
97
|
clip.write_videofile(save_fname, fps=48)
|
94
|
-
|
98
|
+
clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
|
95
99
|
pygame.quit()
|
96
100
|
|
97
101
|
return clip
|
@@ -124,6 +128,9 @@ class Visualizer(SMAXVisualizer):
|
|
124
128
|
# work out which agents are being shot
|
125
129
|
|
126
130
|
def render_action(self, screen, action):
|
131
|
+
if self.env.action_type != "discrete":
|
132
|
+
return
|
133
|
+
|
127
134
|
def coord_fn(idx, n, team):
|
128
135
|
return (
|
129
136
|
self.s / 20 if team == 0 else self.s - self.s / 20,
|
@@ -238,7 +245,7 @@ if __name__ == "__main__":
|
|
238
245
|
# exit()
|
239
246
|
|
240
247
|
n_envs = 2
|
241
|
-
env = Parabellum(scenarios["default"])
|
248
|
+
env = Parabellum(scenarios["default"], action_type="discrete")
|
242
249
|
rng, reset_rng = random.split(random.PRNGKey(0))
|
243
250
|
reset_key = random.split(reset_rng, n_envs)
|
244
251
|
obs, state = vmap(env.reset)(reset_key)
|
@@ -248,7 +255,7 @@ if __name__ == "__main__":
|
|
248
255
|
rng, act_rng, step_rng = random.split(rng, 3)
|
249
256
|
act_key = random.split(act_rng, (len(env.agents), n_envs))
|
250
257
|
act = {
|
251
|
-
a:
|
258
|
+
a: vmap(env.action_space(a).sample)(act_key[i])
|
252
259
|
for i, a in enumerate(env.agents)
|
253
260
|
}
|
254
261
|
step_key = random.split(step_rng, n_envs)
|
@@ -0,0 +1,8 @@
|
|
1
|
+
parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
|
2
|
+
parabellum/env.py,sha256=Z8zpdCaEi5HFwN0Vd2hukOarkPSg0EKZErTRts3JQ5E,16023
|
3
|
+
parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
|
4
|
+
parabellum/run.py,sha256=0IWpqcQ_qfFeElbupF5vOs_CByFfpXYuGGUHYuurFM4,3412
|
5
|
+
parabellum/vis.py,sha256=JFVTnBg-LV4jZNw6cysU6NS8ZxeMpg5wz3JOi-lrnzY,10699
|
6
|
+
parabellum-0.2.19.dist-info/METADATA,sha256=DsEBAlESj8BwGSphmPPylStoXH_g_x_Iy3WJ3KEwjc0,3223
|
7
|
+
parabellum-0.2.19.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
8
|
+
parabellum-0.2.19.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
|
2
|
-
parabellum/env.py,sha256=d6agGy-kTRIg_r0QKCL_7iztzwhaTfsb4yhtUQfdgx0,16024
|
3
|
-
parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
|
4
|
-
parabellum/run.py,sha256=0IWpqcQ_qfFeElbupF5vOs_CByFfpXYuGGUHYuurFM4,3412
|
5
|
-
parabellum/vis.py,sha256=oNRucG1pX1PKqrY-UVhRzx0PdFl9svH8cq5uDb4MCno,10559
|
6
|
-
parabellum-0.2.17.dist-info/METADATA,sha256=ZFIfGi5TwUVucprOu2jSyfrjfwxQBazUr1DqRJauw4s,3223
|
7
|
-
parabellum-0.2.17.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
8
|
-
parabellum-0.2.17.dist-info/RECORD,,
|
File without changes
|