parabellum 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/env.py CHANGED
@@ -17,6 +17,8 @@ from functools import partial
17
17
  class Scenario:
18
18
  """Parabellum scenario"""
19
19
 
20
+ terrain_raster: chex.Array
21
+
20
22
  obstacle_coords: chex.Array # TODO: use map instead of obstacles
21
23
  obstacle_deltas: chex.Array
22
24
 
@@ -31,10 +33,9 @@ class Scenario:
31
33
  # default scenario
32
34
  scenarios = {
33
35
  "default": Scenario(
34
- jnp.array([[6, 10], [26, 10]]) * 8,
35
- jnp.array([[0, 12], [0, 1]]) * 8,
36
- jnp.array([[6, 10], [26, 10]]) * 8,
37
- jnp.array([[0, 12], [0, 1]]) * 8,
36
+ jnp.eye(128, dtype=jnp.uint8),
37
+ jnp.array([[80, 0], [16, 12]]),
38
+ jnp.array([[0, 80], [0, 20]]),
38
39
  jnp.zeros((19,), dtype=jnp.uint8),
39
40
  9,
40
41
  10,
@@ -44,13 +45,74 @@ scenarios = {
44
45
 
45
46
  class Parabellum(SMAX):
46
47
  def __init__(self, scenario: Scenario, **kwargs):
47
- super(Parabellum, self).__init__(**kwargs)
48
+ map_height, map_width = scenario.terrain_raster.shape
49
+ args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
50
+ super(Parabellum, self).__init__(**args, **kwargs)
51
+ self.terrain_raster = scenario.terrain_raster
48
52
  self.obstacle_coords = scenario.obstacle_coords
49
53
  self.obstacle_deltas = scenario.obstacle_deltas
50
54
  self.unit_type_attack_blasts = jnp.zeros((19,), dtype=jnp.float32)
51
55
  self.max_steps = 200
52
56
  self._push_units_away = lambda x: x # overwrite push units
53
57
 
58
+ @partial(jax.jit, static_argnums=(0,))
59
+ def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
60
+ """Environment-specific reset."""
61
+ key, team_0_key, team_1_key = jax.random.split(key, num=3)
62
+ team_0_start = jnp.stack(
63
+ [jnp.array([self.map_width / 4, self.map_height / 2])] * self.num_allies
64
+ )
65
+ team_0_start_noise = jax.random.uniform(
66
+ team_0_key, shape=(self.num_allies, 2), minval=-2, maxval=2
67
+ )
68
+ team_0_start = team_0_start + team_0_start_noise
69
+ team_1_start = jnp.stack(
70
+ [jnp.array([self.map_width / 4 * 3, self.map_height / 2])]
71
+ * self.num_enemies
72
+ )
73
+ team_1_start_noise = jax.random.uniform(
74
+ team_1_key, shape=(self.num_enemies, 2), minval=-2, maxval=2
75
+ )
76
+ team_1_start = team_1_start + team_1_start_noise
77
+ unit_positions = jnp.concatenate([team_0_start, team_1_start])
78
+ key, pos_key = jax.random.split(key)
79
+ generated_unit_positions = self.position_generator.generate(pos_key)
80
+ unit_positions = jax.lax.select(
81
+ self.smacv2_position_generation, generated_unit_positions, unit_positions
82
+ )
83
+ unit_teams = jnp.zeros((self.num_agents,))
84
+ unit_teams = unit_teams.at[self.num_allies :].set(1)
85
+ unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
86
+ # default behaviour spawn all marines
87
+ unit_types = (
88
+ jnp.zeros((self.num_agents,), dtype=jnp.uint8)
89
+ if self.scenario is None
90
+ else self.scenario
91
+ )
92
+ key, unit_type_key = jax.random.split(key)
93
+ generated_unit_types = self.unit_type_generator.generate(unit_type_key)
94
+ unit_types = jax.lax.select(
95
+ self.smacv2_unit_type_generation, generated_unit_types, unit_types
96
+ )
97
+ unit_health = self.unit_type_health[unit_types]
98
+ state = State(
99
+ unit_positions=unit_positions,
100
+ unit_alive=jnp.ones((self.num_agents,), dtype=jnp.bool_),
101
+ unit_teams=unit_teams,
102
+ unit_health=unit_health,
103
+ unit_types=unit_types,
104
+ prev_movement_actions=jnp.zeros((self.num_agents, 2)),
105
+ prev_attack_actions=jnp.zeros((self.num_agents,), dtype=jnp.int32),
106
+ time=0,
107
+ terminal=False,
108
+ unit_weapon_cooldowns=unit_weapon_cooldowns,
109
+ )
110
+ state = self._push_units_away(state)
111
+ obs = self.get_obs(state)
112
+ world_state = self.get_world_state(state)
113
+ obs["world_state"] = jax.lax.stop_gradient(world_state)
114
+ return obs, state
115
+
54
116
  def _our_push_units_away(
55
117
  self, pos, unit_types, firmness: float = 1.0
56
118
  ): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
@@ -79,13 +141,24 @@ class Parabellum(SMAX):
79
141
  actions: Tuple[chex.Array, chex.Array],
80
142
  ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
81
143
  @partial(jax.vmap, in_axes=(None, None, 0, 0))
82
- def inter_fn(pos, new_pos, obs, obs_end):
144
+ def intersect_fn(pos, new_pos, obs, obs_end):
83
145
  d1 = jnp.cross(obs - pos, new_pos - pos)
84
146
  d2 = jnp.cross(obs_end - pos, new_pos - pos)
85
147
  d3 = jnp.cross(pos - obs, obs_end - obs)
86
148
  d4 = jnp.cross(new_pos - obs, obs_end - obs)
87
149
  return (d1 * d2 <= 0) & (d3 * d4 <= 0)
88
150
 
151
+ def raster_crossing(pos, new_pos):
152
+ pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
153
+ raster = self.terrain_raster
154
+ axis = jnp.argmax(jnp.abs(new_pos - pos), axis=-1)
155
+ minimum = jnp.minimum(pos[axis], new_pos[axis]).squeeze()
156
+ maximum = jnp.maximum(pos[axis], new_pos[axis]).squeeze()
157
+ segment = jnp.where(axis == 0, raster[pos[1]], raster.T[pos[0]])
158
+ segment = jnp.where(jnp.arange(segment.shape[0]) >= minimum, segment, 0)
159
+ segment = jnp.where(jnp.arange(segment.shape[0]) <= maximum, segment, 0)
160
+ return jnp.any(segment)
161
+
89
162
  def update_position(idx, vec):
90
163
  # Compute the movements slightly strangely.
91
164
  # The velocities below are for diagonal directions
@@ -109,8 +182,10 @@ class Parabellum(SMAX):
109
182
  ############################################ avoid going into obstacles
110
183
  obs = self.obstacle_coords
111
184
  obs_end = obs + self.obstacle_deltas
112
- inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
113
- new_pos = jnp.where(inters, pos, new_pos)
185
+ inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end))
186
+ rastersects = raster_crossing(pos, new_pos)
187
+ flag = jnp.logical_or(inters, rastersects)
188
+ new_pos = jnp.where(flag, pos, new_pos)
114
189
 
115
190
  #######################################################################
116
191
  #######################################################################
@@ -245,8 +320,10 @@ class Parabellum(SMAX):
245
320
  obst_end = obst_start + obstacle_deltas
246
321
 
247
322
  def check_obstacles(pos, new_pos, obst_start, obst_end):
248
- inters = jnp.any(inter_fn(pos, new_pos, obst_start, obst_end))
249
- return jnp.where(inters, pos, new_pos)
323
+ intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
324
+ rastersect = raster_crossing(pos, new_pos)
325
+ flag = jnp.logical_or(intersects, rastersect)
326
+ return jnp.where(flag, pos, new_pos)
250
327
 
251
328
  pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
252
329
  pos, new_pos, obst_start, obst_end
parabellum/run.py CHANGED
@@ -123,3 +123,5 @@ if __name__ == "__main__":
123
123
  game = game if game.control.paused else render(game)
124
124
 
125
125
  pygame.quit()
126
+
127
+
parabellum/vis.py CHANGED
@@ -7,6 +7,7 @@ from jax import vmap
7
7
  from jax import tree_util
8
8
  from functools import partial
9
9
  import darkdetect
10
+ import numpy as np
10
11
  import pygame
11
12
  import os
12
13
  from moviepy.editor import ImageSequenceClip
@@ -47,7 +48,7 @@ class Visualizer(SMAXVisualizer):
47
48
  if multi_dim:
48
49
  n_envs = self.state_seq[0][1].unit_positions.shape[0]
49
50
  if not self.have_expanded:
50
- state_seqs = vmap(env.expand_state_seq)(self.state_seq)
51
+ state_seqs = vmap(self.env.expand_state_seq)(self.state_seq)
51
52
  self.have_expanded = True
52
53
  for i in range(n_envs):
53
54
  state_seq = jax.tree_map(lambda x: x[i], state_seqs)
@@ -62,12 +63,19 @@ class Visualizer(SMAXVisualizer):
62
63
  def animate_one(self, state_seq, action_seq, save_fname):
63
64
  frames = [] # frames for the video
64
65
  pygame.init() # initialize pygame
66
+ terrain = np.array(self.env.terrain_raster)
67
+ rgb_array = np.zeros((terrain.shape[0], terrain.shape[1], 3), dtype=np.uint8)
68
+ rgb_array[terrain == 1] = self.fg
69
+ mask_surface = pygame.surfarray.make_surface(rgb_array)
70
+ mask_surface = pygame.transform.scale(mask_surface, (self.s, self.s))
71
+
65
72
  for idx, (_, state, _) in tqdm(enumerate(state_seq), total=len(self.state_seq)):
66
73
  action = action_seq[idx // self.env.world_steps_per_env_step]
67
74
  screen = pygame.Surface(
68
75
  (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
69
76
  )
70
77
  screen.fill(self.bg) # fill the screen with the background color
78
+ screen.blit(mask_surface, (0, 0))
71
79
 
72
80
  self.render_agents(screen, state) # render the agents
73
81
  self.render_action(screen, action)
@@ -80,7 +88,6 @@ class Visualizer(SMAXVisualizer):
80
88
 
81
89
  # rotate the screen and append to frames
82
90
  frames.append(pygame.surfarray.pixels3d(screen).swapaxes(0, 1))
83
-
84
91
  # save the images
85
92
  clip = ImageSequenceClip(frames, fps=48)
86
93
  clip.write_videofile(save_fname, fps=48)
@@ -99,7 +106,6 @@ class Visualizer(SMAXVisualizer):
99
106
  for idx, (pos, team, kind, hp) in enumerate(time_tuple):
100
107
  face_col = self.fg if int(team.item()) == 0 else self.bg
101
108
  pos = tuple((pos * self.scale).tolist())
102
-
103
109
  # draw the agent
104
110
  if hp > 0:
105
111
  hp_frac = hp / self.env.unit_type_health[kind]
@@ -231,9 +237,8 @@ if __name__ == "__main__":
231
237
  # small_multiples() # testing small multiples (not working yet)
232
238
  # exit()
233
239
 
234
- n_envs = 100
235
- kwargs = dict(map_width=64, map_height=64)
236
- env = Parabellum(scenarios["default"], **kwargs)
240
+ n_envs = 2
241
+ env = Parabellum(scenarios["default"])
237
242
  rng, reset_rng = random.split(random.PRNGKey(0))
238
243
  reset_key = random.split(reset_rng, n_envs)
239
244
  obs, state = vmap(env.reset)(reset_key)
@@ -243,7 +248,7 @@ if __name__ == "__main__":
243
248
  rng, act_rng, step_rng = random.split(rng, 3)
244
249
  act_key = random.split(act_rng, (len(env.agents), n_envs))
245
250
  act = {
246
- a: vmap(env.action_space(a).sample)(act_key[i])
251
+ a: jnp.ones_like(vmap(env.action_space(a).sample)(act_key[i]))
247
252
  for i, a in enumerate(env.agents)
248
253
  }
249
254
  step_key = random.split(step_rng, n_envs)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: parabellum
3
- Version: 0.2.14
3
+ Version: 0.2.16
4
4
  Summary: Parabellum environment for parallel warfare simulation
5
5
  Home-page: https://github.com/syrkis/parabellum
6
6
  License: MIT
@@ -0,0 +1,8 @@
1
+ parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
2
+ parabellum/env.py,sha256=d6agGy-kTRIg_r0QKCL_7iztzwhaTfsb4yhtUQfdgx0,16024
3
+ parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
4
+ parabellum/run.py,sha256=0IWpqcQ_qfFeElbupF5vOs_CByFfpXYuGGUHYuurFM4,3412
5
+ parabellum/vis.py,sha256=euT7VNPpKW9h0bjXwtYBa4MJRXuELfH3JnUm5ulr3s0,10559
6
+ parabellum-0.2.16.dist-info/METADATA,sha256=eXEfS4FXFp4Xrp4g3hrKnvh-fIHzmHcWlnZrIRjdF4k,3223
7
+ parabellum-0.2.16.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
8
+ parabellum-0.2.16.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- parabellum/__init__.py,sha256=TjZVlHZdi7CEm0gjagm3j6epPZxjR6C9G3CYtX7d-2o,142
2
- parabellum/env.py,sha256=rCn6iPLeFpqitncD9nEc0KA6N9JCMmiSyP9u2meOJxk,12325
3
- parabellum/map.py,sha256=SQeNl1kkGsnnqYoo-60zJNv36fD-8VSKasiS1_WARao,410
4
- parabellum/run.py,sha256=lVNBsMc8HY4Tqdjs_1MXGBvIzuN05brbRiqp0xlRc6c,3410
5
- parabellum/vis.py,sha256=u7ifxWzHf96WgLTz_hw0ijy6-7wePd7lf0p-yD-NCQY,10212
6
- parabellum-0.2.14.dist-info/METADATA,sha256=wEiXzwPfnigG5ZSANPFwGEjLCDU5D0c7qbpvEi6Gbm8,3223
7
- parabellum-0.2.14.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
8
- parabellum-0.2.14.dist-info/RECORD,,