parabellum 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/__init__.py CHANGED
@@ -1,12 +1,18 @@
1
- from .env import Environment, Scenario, scenarios, make_scenario
2
- from .vis import Visualizer
1
+ from .env import Environment, Scenario, scenarios, make_scenario, State
2
+ from .vis import Visualizer, Skin
3
3
  from .map import terrain_fn
4
+ from .gun import bullet_fn
5
+ # from .aid import aid
6
+ # from .run import run
4
7
 
5
8
  __all__ = [
6
9
  "Environment",
7
10
  "Scenario",
8
11
  "scenarios",
9
12
  "make_scenario",
13
+ "State",
10
14
  "Visualizer",
15
+ "Skin",
11
16
  "terrain_fn",
17
+ "bullet_fn",
12
18
  ]
parabellum/aid.py ADDED
@@ -0,0 +1,5 @@
1
+ # aid.py
2
+ # what you call utils.py when you want file names to be 3 letters
3
+ # by: Noah Syrkis
4
+
5
+ # imports
parabellum/env.py CHANGED
@@ -3,13 +3,13 @@
3
3
  import jax.numpy as jnp
4
4
  import jax
5
5
  import numpy as np
6
- from jax import random
6
+ from jax import random, Array
7
7
  from jax import jit
8
8
  from flax.struct import dataclass
9
9
  import chex
10
10
  from jax import vmap
11
- from jaxmarl.environments.smax.smax_env import State, SMAX
12
- from typing import Tuple, Dict
11
+ from jaxmarl.environments.smax.smax_env import SMAX
12
+ from typing import Tuple, Dict, cast
13
13
  from functools import partial
14
14
 
15
15
 
@@ -18,7 +18,7 @@ class Scenario:
18
18
  """Parabellum scenario"""
19
19
 
20
20
  place: str
21
- terrain_raster: chex.Array
21
+ terrain_raster: jnp.ndarray
22
22
  unit_types: chex.Array
23
23
  num_allies: int
24
24
  num_enemies: int
@@ -26,6 +26,19 @@ class Scenario:
26
26
  smacv2_position_generation: bool = False
27
27
  smacv2_unit_type_generation: bool = False
28
28
 
29
+ @dataclass
30
+ class State:
31
+ unit_positions: Array
32
+ unit_alive: Array
33
+ unit_teams: Array
34
+ unit_health: Array
35
+ unit_types: Array
36
+ unit_weapon_cooldowns: Array
37
+ prev_movement_actions: Array
38
+ prev_attack_actions: Array
39
+ time: int
40
+ terminal: bool
41
+
29
42
 
30
43
  # default scenario
31
44
  scenarios = {
@@ -46,7 +59,7 @@ def make_scenario(place, terrain_raster, num_allies=9, num_enemies=10):
46
59
  return Scenario(place, terrain_raster, unit_types, num_allies, num_enemies)
47
60
 
48
61
 
49
- def spawn_fn(pool: jnp.ndarray, offset: jnp.ndarray, n: int, rng: jnp.ndarray):
62
+ def spawn_fn(pool, offset: jnp.ndarray, n: int, rng: jnp.ndarray):
50
63
  """Spawns n agents on a map."""
51
64
  rng, key_start, key_noise = random.split(rng, 3)
52
65
  noise = random.uniform(key_noise, (n, 2)) * 0.5
@@ -80,22 +93,23 @@ class Environment(SMAX):
80
93
  self.scenario = scenario
81
94
  self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
82
95
  self.max_steps = 200
83
- self._push_units_away = lambda x: x # overwrite push units
84
- self.top_sector = sector_fn(self.terrain_raster, 0)
85
- self.low_sector = sector_fn(self.terrain_raster, 24)
96
+ self._push_units_away = lambda state, firmness = 1: state # overwrite push units
97
+ self.top_sector, self.top_sector_offset = sector_fn(self.terrain_raster, 0)
98
+ self.low_sector, self.low_sector_offset = sector_fn(self.terrain_raster, 24)
99
+
86
100
 
87
101
  @partial(jax.jit, static_argnums=(0,))
88
102
  def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
89
103
  """Environment-specific reset."""
90
104
  ally_key, enemy_key = jax.random.split(rng)
91
- team_0_start = spawn_fn(*self.top_sector, self.num_allies, ally_key)
92
- team_1_start = spawn_fn(*self.low_sector, self.num_enemies, enemy_key)
105
+ team_0_start = spawn_fn(self.top_sector, self.top_sector_offset, self.num_allies, ally_key)
106
+ team_1_start = spawn_fn(self.low_sector, self.low_sector_offset, self.num_enemies, enemy_key)
93
107
  unit_positions = jnp.concatenate([team_0_start, team_1_start])
94
108
  unit_teams = jnp.zeros((self.num_agents,))
95
109
  unit_teams = unit_teams.at[self.num_allies :].set(1)
96
110
  unit_weapon_cooldowns = jnp.zeros((self.num_agents,))
97
111
  # default behaviour spawn all marines
98
- unit_types = self.scenario.unit_types
112
+ unit_types = cast(Array, self.scenario.unit_types)
99
113
  unit_health = self.unit_type_health[unit_types]
100
114
  state = State(
101
115
  unit_positions=unit_positions,
@@ -109,12 +123,18 @@ class Environment(SMAX):
109
123
  terminal=False,
110
124
  unit_weapon_cooldowns=unit_weapon_cooldowns,
111
125
  )
112
- state = self._push_units_away(state)
126
+ state = self._push_units_away(state) # type: ignore
113
127
  obs = self.get_obs(state)
114
128
  world_state = self.get_world_state(state)
115
- obs["world_state"] = jax.lax.stop_gradient(world_state)
129
+ # obs["world_state"] = jax.lax.stop_gradient(world_state)
116
130
  return obs, state
117
131
 
132
+ def step_env(self, state: State, action: Array):
133
+ obs, state, rewards, dones, infos = super().step_env(state, action)
134
+ # delete world_state from obs
135
+ obs.pop("world_state")
136
+ return obs, state, rewards, dones, infos
137
+
118
138
  def _our_push_units_away(
119
139
  self, pos, unit_types, firmness: float = 1.0
120
140
  ): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
@@ -141,7 +161,7 @@ class Environment(SMAX):
141
161
  key: chex.PRNGKey,
142
162
  state: State,
143
163
  actions: Tuple[chex.Array, chex.Array],
144
- ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
164
+ ) -> State:
145
165
  @partial(jax.vmap, in_axes=(None, None, 0, 0))
146
166
  def intersect_fn(pos, new_pos, obs, obs_end):
147
167
  d1 = jnp.cross(obs - pos, new_pos - pos)
@@ -167,7 +187,7 @@ class Environment(SMAX):
167
187
  # because these are easier to encode as actions than the four
168
188
  # diagonal directions. Then rotate the velocity 45
169
189
  # degrees anticlockwise to compute the movement.
170
- pos = state.unit_positions[idx]
190
+ pos = cast(Array, state.unit_positions[idx])
171
191
  new_pos = (
172
192
  pos
173
193
  + vec
@@ -219,6 +239,7 @@ class Environment(SMAX):
219
239
  lambda: action + self.num_allies - self.num_movement_actions,
220
240
  lambda: self.num_allies - 1 - (action - self.num_movement_actions),
221
241
  )
242
+ attacked_idx = cast(int, attacked_idx) # Cast to int
222
243
  # deal with no-op attack actions (i.e. agents that are moving instead)
223
244
  attacked_idx = jax.lax.select(
224
245
  action < self.num_movement_actions, idx, attacked_idx
@@ -250,7 +271,7 @@ class Environment(SMAX):
250
271
  bystander_idxs = bystander_fn(attacked_idx) # TODO: use
251
272
  bystander_valid = (
252
273
  jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
253
- .astype(jnp.bool_)
274
+ .astype(jnp.bool_) # type: ignore
254
275
  .astype(jnp.float32)
255
276
  )
256
277
  bystander_health_diff = (
@@ -365,7 +386,8 @@ class Environment(SMAX):
365
386
  #########################################################
366
387
 
367
388
  unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
368
- state = state.replace(
389
+ # replace unit health, unit positions and unit weapon cooldowns
390
+ state = state.replace( # type: ignore
369
391
  unit_health=unit_health,
370
392
  unit_positions=pos,
371
393
  unit_weapon_cooldowns=unit_weapon_cooldowns,
parabellum/gun.py ADDED
@@ -0,0 +1,70 @@
1
+ # gun.py
2
+ # parabellum bullet rendering associated functions
3
+ # by: Noah Syrkis
4
+
5
+ # imports
6
+ from functools import partial
7
+ import jax.numpy as jnp
8
+
9
+
10
+ def dist_fn(env, pos): # computing the distances between all ally and enemy agents
11
+ delta = pos[None, :, :] - pos[:, None, :]
12
+ dist = jnp.sqrt((delta**2).sum(axis=2))
13
+ dist = dist[: env.num_allies, env.num_allies :]
14
+ return {"ally": dist, "enemy": dist.T}
15
+
16
+
17
+ def range_fn(env, dists, ranges): # computing what targets are in range
18
+ ally_range = dists["ally"] < ranges[: env.num_allies][:, None]
19
+ enemy_range = dists["enemy"] < ranges[env.num_allies :][:, None]
20
+ return {"ally": ally_range, "enemy": enemy_range}
21
+
22
+
23
+ def target_fn(acts, in_range, team): # computing the one hot valid targets
24
+ t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T
25
+ t_targets = jnp.where(t_acts > 4, -1, t_acts - 5) # first 5 are move actions
26
+ t_attacks = jnp.eye(in_range[team].shape[1] + 1)[t_targets][:, :-1]
27
+ return t_attacks * in_range[team] # one hot valid targets
28
+
29
+
30
+ def attack_fn(env, state_seq): # one hot attack list
31
+ attacks = []
32
+ for _, state, acts in state_seq:
33
+ dists = dist_fn(env, state.unit_positions)
34
+ ranges = env.unit_type_attack_ranges[state.unit_types]
35
+ in_range = range_fn(env, dists, ranges)
36
+ target = partial(target_fn, acts, in_range)
37
+ attack = {"ally": target("ally"), "enemy": target("enemy")}
38
+ attacks.append(attack)
39
+ return attacks
40
+
41
+
42
+ def bullet_fn(env, states):
43
+ bullet_seq = []
44
+ attack_seq = attack_fn(env, states)
45
+
46
+ def aux_fn(team):
47
+ bullets = jnp.stack(jnp.where(one_hot[team] == 1)).T
48
+ # bullets = bullets.at[:, 2 if team == "ally" else 1].add(env.num_allies)
49
+ return bullets
50
+
51
+ state_zip = zip(states[:-1], states[1:])
52
+ for i, ((_, state, _), (_, n_state, _)) in enumerate(state_zip):
53
+ one_hot = attack_seq[i]
54
+ ally_bullets, enemy_bullets = aux_fn("ally"), aux_fn("enemy")
55
+
56
+ ally_bullets_source = state.unit_positions[ally_bullets[:, 0]]
57
+ enemy_bullets_target = n_state.unit_positions[enemy_bullets[:, 1]]
58
+
59
+ enemy_bullets_source = state.unit_positions[
60
+ enemy_bullets[:, 0] + env.num_allies
61
+ ]
62
+ ally_bullets_target = n_state.unit_positions[
63
+ ally_bullets[:, 1] + env.num_allies
64
+ ]
65
+
66
+ ally_bullets = jnp.stack((ally_bullets_source, ally_bullets_target), axis=1)
67
+ enemy_bullets = jnp.stack((enemy_bullets_source, enemy_bullets_target), axis=1)
68
+
69
+ bullet_seq.append((ally_bullets, enemy_bullets))
70
+ return bullet_seq
parabellum/map.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # map.py
2
- # parabellum map functions
2
+ # parabellum map functions
3
3
  # by: Noah Syrkis
4
4
 
5
5
  # imports
@@ -7,43 +7,62 @@ import jax.numpy as jnp
7
7
  from geopy.geocoders import Nominatim
8
8
  import geopandas as gpd
9
9
  import osmnx as ox
10
- import rasterio
11
- from jax import random
10
+ import contextily as cx
11
+ import matplotlib.pyplot as plt
12
12
  from rasterio import features
13
13
  import rasterio.transform
14
+ from typing import Optional, Tuple
15
+ from geopy.location import Location
16
+ from shapely.geometry import Point
14
17
 
15
18
  # constants
16
19
  geolocator = Nominatim(user_agent="parabellum")
17
- tags = {"building": True}
20
+ BUILDING_TAGS = {"building": True}
18
21
 
22
+ def get_location(place: str) -> Tuple[float, float]:
23
+ """Get coordinates for a given place."""
24
+ coords: Optional[Location] = geolocator.geocode(place) # type: ignore
25
+ if coords is None:
26
+ raise ValueError(f"Could not geocode the place: {place}")
27
+ return (coords.latitude, coords.longitude)
19
28
 
20
- # functions
21
- def terrain_fn(place: str, size: int = 1000):
22
- """Returns a rasterized map of a given location."""
29
+ def get_building_geometry(point: Tuple[float, float], size: int) -> gpd.GeoDataFrame:
30
+ """Get building geometry for a given point and size."""
31
+ geometry = ox.features_from_point(point, tags=BUILDING_TAGS, dist=size // 2)
32
+ return gpd.GeoDataFrame(geometry).set_crs("EPSG:4326")
23
33
 
24
- # location info
25
- location = geolocator.geocode(place)
26
- coords = (location.latitude, location.longitude)
34
+ def rasterize_geometry(gdf: gpd.GeoDataFrame, size: int) -> jnp.ndarray:
35
+ """Rasterize geometry and return as a JAX array."""
36
+ w, s, e, n = gdf.total_bounds
37
+ transform = rasterio.transform.from_bounds(w, s, e, n, size, size)
38
+ raster = features.rasterize(gdf.geometry, out_shape=(size, size), transform=transform)
39
+ return jnp.array(jnp.rot90(raster, 2)).astype(jnp.uint8)
27
40
 
28
- # shape info
29
- geometry = ox.features_from_point(coords, tags=tags, dist=size // 2)
30
- gdf = gpd.GeoDataFrame(geometry).set_crs("EPSG:4326")
41
+ def terrain_fn(place: str, size: int = 1000) -> Tuple[jnp.ndarray, jnp.ndarray]:
42
+ """Returns a rasterized map of buildings for a given location."""
43
+ point = get_location(place)
44
+ gdf = get_building_geometry(point, size)
45
+ mask = rasterize_geometry(gdf, size)
46
+ base = get_basemap(place, size)
47
+ return mask, base
31
48
 
32
- # raster info
33
- t = rasterio.transform.from_bounds(*gdf.total_bounds, size, size)
34
- raster = features.rasterize(gdf.geometry, out_shape=(size, size), transform=t)
35
-
36
- # rotate 180 degrees
37
- raster = jnp.rot90(raster, 2)
38
-
39
- return jnp.array(raster).astype(jnp.uint8)
49
+ def get_basemap(place: str, size: int = 1000) -> jnp.ndarray:
50
+ """Returns a basemap for a given place as a JAX array."""
51
+ point = get_location(place)
52
+ gdf = get_building_geometry(point, size)
53
+ basemap, _ = cx.bounds2img(*gdf.total_bounds, ll=True)
54
+ # get the middle size x size square
55
+ basemap = basemap[(basemap.shape[0] - size) // 2:(basemap.shape[0] + size) // 2,
56
+ (basemap.shape[1] - size) // 2:(basemap.shape[1] + size) // 2]
57
+ return jnp.array(jnp.rot90(basemap, 2)).astype(jnp.uint8)
40
58
 
41
59
 
42
60
  if __name__ == "__main__":
43
61
  import seaborn as sns
62
+ place = "Thun, Switzerland"
63
+ mask, base = terrain_fn(place)
44
64
 
45
- place = "Vesterbro, Copenhagen, Denmark"
46
- terrain = terrain_fn(place)
47
- rng, key = random.split(random.PRNGKey(0))
48
- agents = spawn_fn(terrain, 12, 100, key)
49
- print(agents)
65
+ fig, ax = plt.subplots(1, 2, figsize=(10, 5))
66
+ ax[0].imshow(mask) # type: ignore
67
+ ax[1].imshow(base) # type: ignore
68
+ plt.show()
parabellum/run.py CHANGED
@@ -10,6 +10,7 @@ import darkdetect
10
10
  import jax.numpy as jnp
11
11
  from chex import dataclass
12
12
  import jaxmarl
13
+ from jax import Array
13
14
  from typing import Tuple, List, Dict, Optional
14
15
  import parabellum as pb
15
16
 
@@ -20,7 +21,7 @@ bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
20
21
 
21
22
 
22
23
  # types
23
- State = jaxmarl.environments.smax.smax_env.State
24
+ State = jaxmarl.environments.smax.smax_env.State # type: ignore
24
25
  Obs = Reward = Done = Action = Dict[str, jnp.ndarray]
25
26
  StateSeq = List[Tuple[jnp.ndarray, State, Action]]
26
27
 
@@ -35,12 +36,12 @@ class Control:
35
36
  @dataclass
36
37
  class Game:
37
38
  clock: pygame.time.Clock
38
- state: State
39
+ state: State # type: ignore
39
40
  obs: Dict
40
41
  state_seq: StateSeq
41
42
  control: Control
42
43
  env: pb.Environment
43
- rng: random.PRNGKey
44
+ rng: Array
44
45
 
45
46
 
46
47
  def handle_event(event, control_state):
@@ -100,7 +101,7 @@ def step_fn(game):
100
101
 
101
102
  # state
102
103
  if __name__ == "__main__":
103
- env = pb.Parabellum(pb.scenarios["default"])
104
+ env = pb.Environment(pb.scenarios["default"])
104
105
  pygame.init()
105
106
  screen = pygame.display.set_mode((1000, 1000))
106
107
  render = partial(render_fn, screen)
@@ -115,7 +116,7 @@ if __name__ == "__main__":
115
116
  state=state,
116
117
  obs=obs,
117
118
  )
118
- game = Game(**kwargs)
119
+ game = Game(**kwargs) # type: ignore
119
120
 
120
121
  while game.control.running:
121
122
  game = control_fn(game)
parabellum/vis.py CHANGED
@@ -2,36 +2,131 @@
2
2
  Visualizer for the Parabellum environment
3
3
  """
4
4
 
5
- from tqdm import tqdm
6
- import jax.numpy as jnp
7
- import jax
8
- from jax import vmap
9
- from jax import tree_util
5
+ # Standard library imports
10
6
  from functools import partial
11
- import darkdetect
12
- import numpy as np
13
- import pygame
14
- import matplotlib.pyplot as plt
15
- import os
16
- from moviepy.editor import ImageSequenceClip
17
- from typing import Optional
7
+ from typing import Optional, List, Tuple
8
+ import cv2
9
+ from PIL import Image
10
+
11
+ # JAX and JAX-related imports
12
+ import jax
13
+ from chex import dataclass
14
+ import chex
15
+ from jax import vmap, tree_util, Array, jit
16
+ import jax.numpy as jnp
18
17
  from jaxmarl.environments.multi_agent_env import MultiAgentEnv
18
+ from jaxmarl.environments.smax import SMAX
19
19
  from jaxmarl.viz.visualizer import SMAXVisualizer
20
20
 
21
- # default dict
22
- from collections import defaultdict
21
+ # Third-party imports
22
+ import numpy as np
23
+ import pygame
24
+ import cv2
25
+ from tqdm import tqdm
26
+
27
+ # Local imports
28
+ import parabellum as pb
29
+
30
+
31
+ # skin dataclass
32
+ @dataclass
33
+ class Skin:
34
+ # basemap: Array # basemap of buildings
35
+ maskmap: Array # maskmap of buildings
36
+ bg: Tuple[int, int, int] = (255, 255, 255)
37
+ fg: Tuple[int, int, int] = (0, 0, 0)
38
+ ally: Tuple[int, int, int] = (0, 255, 0)
39
+ enemy: Tuple[int, int, int] = (255, 0, 0)
40
+ pad: int = 100
41
+ size: int = 1000 # excluding padding
42
+ fps: int = 24
43
+ vis_size: int = 1000 # size of the map in Vis (excluding padding)
44
+ scale: Optional[float] = None
23
45
 
24
46
 
25
- # constants
26
- action_to_symbol = {0: "↑", 1: "→", 2: "↓", 3: "←", 4: "Ø"}
47
+ class Visualizer(SMAXVisualizer):
48
+ def __init__(self, env: pb.Environment, state_seq, skin: Skin, reward_seq=None):
49
+ super(Visualizer, self).__init__(env, state_seq, reward_seq)
27
50
 
51
+ # self.bullet_seq = vmap(partial(bullet_fn, self.env))(self.state_seq)
52
+ self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
53
+ self.state_seq = state_seq
54
+ self.image = image_fn(skin)
55
+ self.skin = skin
56
+ self.skin.scale = self.skin.size / env.map_width # assumes square map
57
+ self.env = env
58
+
59
+
60
+ def animate(self, save_fname: Optional[str] = "output/parabellum", view=None):
61
+ expanded_state_seq, expanded_action_seq = expand_fn(self.env, self.state_seq, self.action_seq)
62
+ state_seq_seq, action_seq_seq = unbatch_fn(expanded_state_seq, expanded_action_seq)
63
+ for idx, (state_seq, action_seq) in enumerate(zip(state_seq_seq, action_seq_seq)):
64
+ animate_fn(self.env, self.skin, self.image, state_seq, action_seq, f"{save_fname}_{idx}.mp4")
65
+
66
+
67
+ # functions
68
+ def animate_fn(env, skin, image, state_seq, action_seq, save_fname):
69
+ pygame.init()
70
+ frames = []
71
+ for idx, (state_tup, action) in enumerate(zip(state_seq, action_seq)):
72
+ frames += [frame_fn(env, skin, image, state_tup[1], action, idx)]
73
+ # use cv2 to write frames to video
74
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore
75
+ out = cv2.VideoWriter(save_fname, fourcc, skin.fps, (skin.size + skin.pad * 2, skin.size + skin.pad * 2))
76
+ for frame in frames:
77
+ out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
78
+ out.release()
79
+ pygame.quit()
28
80
 
29
- def small_multiples():
30
- # make video of small multiples based on all videos in output
31
- video_files = [f"output/parabellum_{i}.mp4" for i in range(4)]
32
- # load mp4 videos and make a grid
33
- clips = [ImageSequenceClip.load(filename) for filename in video_files]
34
- print(len(clips))
81
+
82
+ def init_frame(env, skin, image, state: pb.State, action: Array, idx: int) -> pygame.Surface:
83
+ dims = (skin.size + skin.pad * 2, skin.size + skin.pad * 2)
84
+ frame = pygame.Surface(dims, pygame.SRCALPHA | pygame.HWSURFACE)
85
+ return frame
86
+
87
+
88
+ def transform_frame(env, skin, frame):
89
+ frame = np.rot90(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 2)
90
+ return frame
91
+
92
+
93
+ def frame_fn(env, skin, image, state: pb.State, action: Array, idx: int) -> np.ndarray:
94
+ """Create a frame"""
95
+ frame = init_frame(env, skin, image, state, action, idx)
96
+
97
+ pipeline = [render_background, render_agents, render_action, render_bullet]
98
+ for fn in pipeline:
99
+ frame = fn(env, skin, image, frame, state, action)
100
+
101
+ return transform_frame(env, skin, frame)
102
+
103
+
104
+ def render_background(env, skin, image, frame, state, action):
105
+ coords = (skin.pad-5, skin.pad-5, skin.size+10, skin.size+10)
106
+ frame.fill(skin.bg)
107
+ frame.blit(image, coords)
108
+ pygame.draw.rect(frame, skin.fg, coords, 3)
109
+ return frame
110
+
111
+
112
+ def render_action(env, skin, image, frame, state, action):
113
+ return frame
114
+
115
+
116
+ def render_bullet(env, skin, image, frame, state, action):
117
+ return frame
118
+
119
+ def render_agents(env, skin, image, frame, state, action):
120
+ units = state.unit_positions, state.unit_teams, state.unit_types, state.unit_health
121
+ for idx, (pos, team, kind, health) in enumerate(zip(*units)):
122
+ pos = tuple((pos * skin.scale).astype(int) + skin.pad)
123
+ # draw the agent
124
+ if health > 0:
125
+ unit_size = env.unit_type_radiuses[kind]
126
+ radius = float(jnp.ceil((unit_size * skin.scale)).astype(int) + 1)
127
+ pygame.draw.circle(frame, skin.fg, pos, radius, 1)
128
+ pygame.draw.circle(frame, skin.bg, pos, radius + 1, 1)
129
+ return frame
35
130
 
36
131
 
37
132
  def text_fn(text):
@@ -39,263 +134,36 @@ def text_fn(text):
39
134
  return pygame.transform.rotate(text, 180)
40
135
 
41
136
 
42
- class Visualizer(SMAXVisualizer):
43
- def __init__(self, env: MultiAgentEnv, state_seq, reward_seq=None, skin=None):
44
- self.fig, self.ax = None, None
45
- super().__init__(env, state_seq, reward_seq)
46
- # clear the figure made by SMAXVisualizer
47
- plt.close()
48
- self.bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)
49
- self.fg = (255, 255, 255) if darkdetect.isDark() else (0, 0, 0)
50
- self.pad = 75
51
- # TODO: make sure it's always a 1024x1024 image
52
- self.width = 1000
53
- self.s = self.width + self.pad + self.pad
54
- self.scale = self.width / env.map_width
55
- self.action_seq = [action for _, _, action in state_seq] # bcs SMAX bug
56
- # self.bullet_seq = vmap(partial(bullet_fn, self.env))(self.state_seq)
137
+ def image_fn(skin: Skin): # TODO:
138
+ """Create an image for background (basemap or maskmap)"""
139
+ motif = cv2.resize(np.array(skin.maskmap.T), (skin.size, skin.size), interpolation=cv2.INTER_LANCZOS4).astype(np.uint8)
140
+ motif = (motif > 0).astype(np.uint8)
141
+ image = np.zeros((skin.size, skin.size, 3), dtype=np.uint8) + skin.bg
142
+ image[motif == 1] = skin.fg
143
+ image = pygame.surfarray.make_surface(image)
144
+ image = pygame.transform.scale(image, (skin.size, skin.size))
145
+ return image
57
146
 
58
- def animate(self, save_fname: str = "output/parabellum.mp4"):
59
- multi_dim = self.state_seq[0][1].unit_positions.ndim > 2
60
- if multi_dim:
61
- n_envs = self.state_seq[0][1].unit_positions.shape[0]
62
- if not self.have_expanded:
63
- state_seqs = vmap(self.env.expand_state_seq)(self.state_seq)
64
- self.have_expanded = True
65
- for i in range(n_envs):
66
- state_seq = jax.tree_map(lambda x: x[i], state_seqs)
67
- action_seq = jax.tree_map(lambda x: x[i], self.action_seq)
68
- self.animate_one(
69
- state_seq, action_seq, save_fname.replace(".mp4", f"_{i}.mp4")
70
- )
71
- else:
72
- state_seq = self.env.expand_state_seq(self.state_seq)
73
- self.animate_one(state_seq, self.action_seq, save_fname)
74
-
75
- def animate_one(self, state_seq, action_seq, save_fname):
76
- frames = [] # frames for the video
77
- pygame.init() # initialize pygame
78
- terrain = np.array(self.env.terrain_raster.T)
79
- rgb_array = np.zeros((terrain.shape[0], terrain.shape[1], 3), dtype=np.uint8)
80
- if darkdetect.isLight():
81
- rgb_array += 255
82
- rgb_array[terrain == 1] = self.fg
83
- mask_surface = pygame.surfarray.make_surface(rgb_array)
84
- mask_surface = pygame.transform.scale(mask_surface, (self.width, self.width))
85
-
86
- for idx, (_, state, _) in tqdm(enumerate(state_seq), total=len(self.state_seq)):
87
- action = action_seq[idx // self.env.world_steps_per_env_step]
88
- screen = pygame.Surface(
89
- (self.s, self.s), pygame.HWSURFACE | pygame.DOUBLEBUF
90
- )
91
- screen.fill(self.bg) # fill the screen with the background color
92
- screen.blit(
93
- mask_surface,
94
- (self.pad, self.pad, self.width, self.width),
95
- )
96
- # add env.scenario.place to the title (in top padding)
97
- font = pygame.font.SysFont("Fira Code", 18)
98
- width = self.env.map_width
99
- title = f"{width}x{width}m in {self.env.scenario.place}"
100
- text = text_fn(font.render(title, True, self.fg))
101
- # center the text
102
- screen.blit(text, (self.s // 2 - text.get_width() // 2, self.pad // 4))
103
- # draw edge around terrain
104
- pygame.draw.rect(
105
- screen,
106
- self.fg,
107
- (
108
- self.pad - 2,
109
- self.pad - 2,
110
- self.width + 4,
111
- self.width + 4,
112
- ),
113
- 2,
114
- )
115
-
116
- self.render_agents(screen, state) # render the agents
117
- self.render_action(screen, action)
118
- # self.render_obstacles(screen) # render the obstacles
119
-
120
- # bullets
121
- """ if idx < len(self.bullet_seq) * 8:
122
- bullets = self.bullet_seq[idx // 8]
123
- self.render_bullets(screen, bullets, idx % 8) """
124
-
125
- # rotate the screen and append to frames
126
- pixels = pygame.surfarray.pixels3d(screen).swapaxes(0, 1)
127
- # rotate the screen 180 degrees (transpose and flip)
128
- pixels = np.rot90(pixels, 2) # pygame starts in bottom left
129
- frames.append(pixels)
130
- # save the images
131
- clip = ImageSequenceClip(frames, fps=48)
132
- clip.write_videofile(save_fname, fps=48)
133
- clip.write_gif(save_fname.replace(".mp4", ".gif"), fps=24)
134
- pygame.quit()
135
-
136
- def render_agents(self, screen, state):
137
- time_tuple = zip(
138
- state.unit_positions,
139
- state.unit_teams,
140
- state.unit_types,
141
- state.unit_health,
142
- )
143
- for idx, (pos, team, kind, hp) in enumerate(time_tuple):
144
- face_col = self.fg if int(team.item()) == 0 else self.bg
145
- pos = tuple(((pos * self.scale) + self.pad).tolist())
146
- # draw the agent
147
- if hp > 0:
148
- hp_frac = hp / self.env.unit_type_health[kind]
149
- unit_size = self.env.unit_type_radiuses[kind]
150
- radius = jnp.ceil((unit_size * self.scale * hp_frac)).astype(int) + 1
151
- pygame.draw.circle(screen, face_col, pos, radius)
152
- pygame.draw.circle(screen, self.fg, pos, radius, 1)
153
-
154
- # draw the sight range
155
- # sight_range = self.env.unit_type_sight_ranges[kind] * self.scale
156
- # pygame.draw.circle(screen, self.fg, pos, sight_range.astype(int), 2)
157
-
158
- # draw attack range
159
- # attack_range = self.env.unit_type_attack_ranges[kind] * self.scale
160
- # pygame.draw.circle(screen, self.fg, pos, attack_range.astype(int), 2)
161
- # work out which agents are being shot
162
-
163
- def render_action(self, screen, action):
164
- if self.env.action_type != "discrete":
165
- return
166
-
167
- def coord_fn(idx, n, team, text):
168
- text_adj = text.get_width() / 2
169
- is_ally = team == "ally"
170
- return (
171
- # vertically centered so that n / 2 is above and below the center
172
- self.pad + self.width + self.pad / 2 - text_adj
173
- if is_ally
174
- else self.pad / 2 - text_adj,
175
- self.s / 2 - (n / 2) * self.s / 20 + idx * self.s / 20,
176
- )
177
-
178
- for team, number in [("ally", 0), ("enemy", 1)]:
179
- for idx in range(self.env.num_allies):
180
- symb = action_to_symbol.get(
181
- action[f"{team}_{idx}"].astype(int).item(), "Ø"
182
- )
183
- font = pygame.font.SysFont(
184
- "Fira Code", jnp.sqrt(self.s).astype(int).item()
185
- )
186
- text = text_fn(font.render(symb, True, self.fg))
187
- coord = coord_fn(idx, self.env.num_allies, team, text)
188
- screen.blit(text, coord)
189
-
190
- def render_obstacles(self, screen):
191
- for c, d in zip(self.env.obstacle_coords, self.env.obstacle_deltas):
192
- d = tuple(((c + d) * self.scale).tolist())
193
- c = tuple((c * self.scale).tolist())
194
- pygame.draw.line(screen, self.fg, c, d, 5)
195
-
196
- def render_bullets(self, screen, bullets, jdx):
197
- jdx += 1
198
- ally_bullets, enemy_bullets = bullets
199
- for source, target in ally_bullets:
200
- position = source + (target - source) * jdx / 8
201
- position *= self.scale
202
- pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
203
- for source, target in enemy_bullets:
204
- position = source + (target - source) * jdx / 8
205
- position *= self.scale
206
- pygame.draw.circle(screen, self.fg, tuple(position.tolist()), 3)
207
147
 
148
+ def unbatch_fn(state_seq, action_seq):
149
+ """state seq is a list of tuples of (step_key, state, actions)."""
150
+ if is_multi_run(state_seq):
151
+ n_envs = state_seq[0][1].unit_positions.shape[0]
152
+ state_seq_seq = [jax.tree_map(lambda x: x[i], state_seq) for i in range(n_envs)]
153
+ action_seq_seq = [jax.tree_map(lambda x: x[i], action_seq) for i in range(n_envs)]
154
+ else:
155
+ state_seq_seq = [state_seq]
156
+ action_seq_seq = [action_seq]
157
+ return state_seq_seq, action_seq_seq
208
158
 
209
- # functions
210
- # bullet functions
211
- def dist_fn(env, pos): # computing the distances between all ally and enemy agents
212
- delta = pos[None, :, :] - pos[:, None, :]
213
- dist = jnp.sqrt((delta**2).sum(axis=2))
214
- dist = dist[: env.num_allies, env.num_allies :]
215
- return {"ally": dist, "enemy": dist.T}
216
-
217
-
218
- def range_fn(env, dists, ranges): # computing what targets are in range
219
- ally_range = dists["ally"] < ranges[: env.num_allies][:, None]
220
- enemy_range = dists["enemy"] < ranges[env.num_allies :][:, None]
221
- return {"ally": ally_range, "enemy": enemy_range}
222
-
223
-
224
- def target_fn(acts, in_range, team): # computing the one hot valid targets
225
- t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T
226
- t_targets = jnp.where(t_acts > 4, -1, t_acts - 5) # first 5 are move actions
227
- t_attacks = jnp.eye(in_range[team].shape[1] + 1)[t_targets][:, :-1]
228
- return t_attacks * in_range[team] # one hot valid targets
229
-
230
-
231
- def attack_fn(env, state_seq): # one hot attack list
232
- attacks = []
233
- for _, state, acts in state_seq:
234
- dists = dist_fn(env, state.unit_positions)
235
- ranges = env.unit_type_attack_ranges[state.unit_types]
236
- in_range = range_fn(env, dists, ranges)
237
- target = partial(target_fn, acts, in_range)
238
- attack = {"ally": target("ally"), "enemy": target("enemy")}
239
- attacks.append(attack)
240
- return attacks
241
-
242
-
243
- def bullet_fn(env, states):
244
- bullet_seq = []
245
- attack_seq = attack_fn(env, states)
246
-
247
- def aux_fn(team):
248
- bullets = jnp.stack(jnp.where(one_hot[team] == 1)).T
249
- # bullets = bullets.at[:, 2 if team == "ally" else 1].add(env.num_allies)
250
- return bullets
251
-
252
- state_zip = zip(states[:-1], states[1:])
253
- for i, ((_, state, _), (_, n_state, _)) in enumerate(state_zip):
254
- one_hot = attack_seq[i]
255
- ally_bullets, enemy_bullets = aux_fn("ally"), aux_fn("enemy")
256
-
257
- ally_bullets_source = state.unit_positions[ally_bullets[:, 0]]
258
- enemy_bullets_target = n_state.unit_positions[enemy_bullets[:, 1]]
259
-
260
- enemy_bullets_source = state.unit_positions[
261
- enemy_bullets[:, 0] + env.num_allies
262
- ]
263
- ally_bullets_target = n_state.unit_positions[
264
- ally_bullets[:, 1] + env.num_allies
265
- ]
266
-
267
- ally_bullets = jnp.stack((ally_bullets_source, ally_bullets_target), axis=1)
268
- enemy_bullets = jnp.stack((enemy_bullets_source, enemy_bullets_target), axis=1)
269
-
270
- bullet_seq.append((ally_bullets, enemy_bullets))
271
- return bullet_seq
272
-
273
-
274
- # test the visualizer
275
- if __name__ == "__main__":
276
- from jax import random, numpy as jnp
277
- from parabellum import Parabellum, scenarios
278
-
279
- # small_multiples() # testing small multiples (not working yet)
280
- # exit()
281
-
282
- n_envs = 2
283
- env = Parabellum(scenarios["default"], action_type="discrete")
284
- rng, reset_rng = random.split(random.PRNGKey(0))
285
- reset_key = random.split(reset_rng, n_envs)
286
- obs, state = vmap(env.reset)(reset_key)
287
- state_seq = []
288
-
289
- for i in range(100):
290
- rng, act_rng, step_rng = random.split(rng, 3)
291
- act_key = random.split(act_rng, (len(env.agents), n_envs))
292
- act = {
293
- a: vmap(env.action_space(a).sample)(act_key[i])
294
- for i, a in enumerate(env.agents)
295
- }
296
- step_key = random.split(step_rng, n_envs)
297
- state_seq.append((step_key, state, act))
298
- obs, state, reward, done, infos = vmap(env.step)(step_key, state, act)
299
-
300
- vis = Visualizer(env, state_seq)
301
- vis.animate()
159
+
160
+ def expand_fn(env, state_seq, action_seq):
161
+ """Expand the state sequence"""
162
+ fn = env.expand_state_seq
163
+ state_seq = vmap(fn)(state_seq) if is_multi_run(state_seq) else fn(state_seq)
164
+ action_seq = [action_seq[i // env.world_steps_per_env_step] for i in range(len(state_seq))]
165
+ return state_seq, action_seq
166
+
167
+
168
+ def is_multi_run(state_seq):
169
+ return state_seq[0][1].unit_positions.ndim > 2
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: parabellum
3
- Version: 0.2.20
3
+ Version: 0.2.22
4
4
  Summary: Parabellum environment for parallel warfare simulation
5
5
  Home-page: https://github.com/syrkis/parabellum
6
6
  License: MIT
@@ -12,15 +12,20 @@ Classifier: License :: OSI Approved :: MIT License
12
12
  Classifier: Programming Language :: Python :: 3
13
13
  Classifier: Programming Language :: Python :: 3.11
14
14
  Classifier: Programming Language :: Python :: 3.12
15
+ Requires-Dist: contextily (>=1.6.0,<2.0.0)
15
16
  Requires-Dist: darkdetect (>=0.8.0,<0.9.0)
17
+ Requires-Dist: folium (>=0.17.0,<0.18.0)
16
18
  Requires-Dist: geopandas (>=1.0.0,<2.0.0)
17
19
  Requires-Dist: geopy (>=2.4.1,<3.0.0)
20
+ Requires-Dist: ipykernel (>=6.29.5,<7.0.0)
18
21
  Requires-Dist: jax (==0.4.17)
19
22
  Requires-Dist: jaxmarl (==0.0.3)
20
23
  Requires-Dist: jupyterlab (>=4.2.2,<5.0.0)
21
24
  Requires-Dist: moviepy (>=1.0.3,<2.0.0)
22
25
  Requires-Dist: numpy (<2)
26
+ Requires-Dist: opencv-python (>=4.10.0.84,<5.0.0.0)
23
27
  Requires-Dist: osmnx (>=1.9.3,<2.0.0)
28
+ Requires-Dist: pandas (>=2.2.2,<3.0.0)
24
29
  Requires-Dist: poetry (>=1.8.3,<2.0.0)
25
30
  Requires-Dist: pygame (>=2.5.2,<3.0.0)
26
31
  Requires-Dist: rasterio (>=1.3.10,<2.0.0)
@@ -0,0 +1,10 @@
1
+ parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
2
+ parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
3
+ parabellum/env.py,sha256=u0NuQUQMKz92Ke9IpNtwTClgxBnnEvGNqW6GgA57mps,16975
4
+ parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
5
+ parabellum/map.py,sha256=EUcPe4Upu9MQzS8h15IVPGCaAyRPLSkmoLd5ZT-V4Pk,2599
6
+ parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
7
+ parabellum/vis.py,sha256=uXTnhJL23JLQHW9by-M4bF73dSVA5TIkpNdfo_Go2Ro,6045
8
+ parabellum-0.2.22.dist-info/METADATA,sha256=FZgaXTNbHOIhwezMuyFQDFmeECghpkkyQYt3b3PVoYo,2671
9
+ parabellum-0.2.22.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
10
+ parabellum-0.2.22.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- parabellum/__init__.py,sha256=cI1kxVQ274VZrBLxUYvB_j6UqhFY44sId4NgVjbf678,245
2
- parabellum/env.py,sha256=B2WaEzCQnakVB1AuYqFsb7aiPIbaUsuq0CjdKq4pQm8,16204
3
- parabellum/map.py,sha256=CdPebRuGafj9fAydAyTa8bqpr-PldZw7heKRReGTDBg,1257
4
- parabellum/run.py,sha256=HyPdz5iVD8q0iYaZL2Nf02fGHByHpecUGoPlQrq9v8s,3411
5
- parabellum/vis.py,sha256=uIhN9VSJlT4XoNuNCiU_Bw2VSNoJFOYlScCTNrMkGsg,11977
6
- parabellum-0.2.20.dist-info/METADATA,sha256=IO3Wsx_nTM3v63FtbnJqaEfPfXOyhNGtKOm_ppchUqo,2453
7
- parabellum-0.2.20.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
8
- parabellum-0.2.20.dist-info/RECORD,,