parabellum 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/env.py CHANGED
@@ -33,7 +33,7 @@ class Scenario:
33
33
  # default scenario
34
34
  scenarios = {
35
35
  "default": Scenario(
36
- jnp.eye(128, dtype=jnp.uint8),
36
+ jnp.eye(64, dtype=jnp.uint8),
37
37
  jnp.array([[80, 0], [16, 12]]),
38
38
  jnp.array([[0, 80], [0, 20]]),
39
39
  jnp.zeros((19,), dtype=jnp.uint8),
parabellum/vis.py CHANGED
@@ -1,4 +1,6 @@
1
- """Visualizer for the Parabellum environment"""
1
+ """
2
+ Visualizer for the Parabellum environment
3
+ """
2
4
 
3
5
  from tqdm import tqdm
4
6
  import jax.numpy as jnp
@@ -37,7 +39,7 @@ class Visualizer(SMAXVisualizer):
37
39
  # remove fig and ax from super
38
40
  self.fig, self.ax = None, None
39
41
  self.bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
40
- self.fg = (235, 235, 235) if darkdetect.isDark() else (20, 20, 20)
42
+ self.fg = (255, 255, 255) if darkdetect.isDark() else (0, 0, 0)
41
43
  self.s = 1000
42
44
  self.scale = self.s / self.env.map_width
43
45
  self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
@@ -57,7 +59,7 @@ class Visualizer(SMAXVisualizer):
57
59
  state_seq, action_seq, save_fname.replace(".mp4", f"_{i}.mp4")
58
60
  )
59
61
  else:
60
- state_seq = env.expand_state_seq(self.state_seq)
62
+ state_seq = self.env.expand_state_seq(self.state_seq)
61
63
  self.animate_one(state_seq, self.action_seq, save_fname)
62
64
 
63
65
  def animate_one(self, state_seq, action_seq, save_fname):
@@ -65,6 +67,8 @@ class Visualizer(SMAXVisualizer):
65
67
  pygame.init() # initialize pygame
66
68
  terrain = np.array(self.env.terrain_raster)
67
69
  rgb_array = np.zeros((terrain.shape[0], terrain.shape[1], 3), dtype=np.uint8)
70
+ if darkdetect.isLight():
71
+ rgb_array += 255
68
72
  rgb_array[terrain == 1] = self.fg
69
73
  mask_surface = pygame.surfarray.make_surface(rgb_array)
70
74
  mask_surface = pygame.transform.scale(mask_surface, (self.s, self.s))
@@ -91,7 +95,7 @@ class Visualizer(SMAXVisualizer):
91
95
  # save the images
92
96
  clip = ImageSequenceClip(frames, fps=48)
93
97
  clip.write_videofile(save_fname, fps=48)
94
- # clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
98
+ clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
95
99
  pygame.quit()
96
100
 
97
101
  return clip
@@ -124,6 +128,9 @@ class Visualizer(SMAXVisualizer):
124
128
  # work out which agents are being shot
125
129
 
126
130
  def render_action(self, screen, action):
131
+ if self.env.action_type != "discrete":
132
+ return
133
+
127
134
  def coord_fn(idx, n, team):
128
135
  return (
129
136
  self.s / 20 if team == 0 else self.s - self.s / 20,
@@ -238,7 +245,7 @@ if __name__ == "__main__":
238
245
  # exit()
239
246
 
240
247
  n_envs = 2
241
- env = Parabellum(scenarios["default"])
248
+ env = Parabellum(scenarios["default"], action_type="discrete")
242
249
  rng, reset_rng = random.split(random.PRNGKey(0))
243
250
  reset_key = random.split(reset_rng, n_envs)
244
251
  obs, state = vmap(env.reset)(reset_key)
@@ -248,7 +255,7 @@ if __name__ == "__main__":
248
255
  rng, act_rng, step_rng = random.split(rng, 3)
249
256
  act_key = random.split(act_rng, (len(env.agents), n_envs))
250
257
  act = {
251
- a: jnp.ones_like(vmap(env.action_space(a).sample)(act_key[i]))
258
+ a: vmap(env.action_space(a).sample)(act_key[i])
252
259
  for i, a in enumerate(env.agents)
253
260
  }
254
261
  step_key = random.split(step_rng, n_envs)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: parabellum
3
- Version: 0.2.17
3
+ Version: 0.2.19
4
4
  Summary: Parabellum environment for parallel warfare simulation
5
5
  Home-page: https://github.com/syrkis/parabellum
6
6
  License: MIT
@@ -0,0 +1,8 @@
1
+ parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
2
+ parabellum/env.py,sha256=Z8zpdCaEi5HFwN0Vd2hukOarkPSg0EKZErTRts3JQ5E,16023
3
+ parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
4
+ parabellum/run.py,sha256=0IWpqcQ_qfFeElbupF5vOs_CByFfpXYuGGUHYuurFM4,3412
5
+ parabellum/vis.py,sha256=JFVTnBg-LV4jZNw6cysU6NS8ZxeMpg5wz3JOi-lrnzY,10699
6
+ parabellum-0.2.19.dist-info/METADATA,sha256=DsEBAlESj8BwGSphmPPylStoXH_g_x_Iy3WJ3KEwjc0,3223
7
+ parabellum-0.2.19.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
8
+ parabellum-0.2.19.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
2
- parabellum/env.py,sha256=d6agGy-kTRIg_r0QKCL_7iztzwhaTfsb4yhtUQfdgx0,16024
3
- parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
4
- parabellum/run.py,sha256=0IWpqcQ_qfFeElbupF5vOs_CByFfpXYuGGUHYuurFM4,3412
5
- parabellum/vis.py,sha256=oNRucG1pX1PKqrY-UVhRzx0PdFl9svH8cq5uDb4MCno,10559
6
- parabellum-0.2.17.dist-info/METADATA,sha256=ZFIfGi5TwUVucprOu2jSyfrjfwxQBazUr1DqRJauw4s,3223
7
- parabellum-0.2.17.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
8
- parabellum-0.2.17.dist-info/RECORD,,