parabellum 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/env.py CHANGED
@@ -19,6 +19,7 @@ class Scenario:
19
19
 
20
20
  place: str
21
21
  terrain_raster: jnp.ndarray
22
+ unit_starting_sectors: jnp.ndarray
22
23
  unit_types: chex.Array
23
24
  num_allies: int
24
25
  num_enemies: int
@@ -45,6 +46,7 @@ scenarios = {
45
46
  "default": Scenario(
46
47
  "Identity Town",
47
48
  jnp.eye(64, dtype=jnp.uint8),
49
+ jnp.array([[0, 0, 0.2, 0.2], [0.7,0.7,0.2,0.2]]),
48
50
  jnp.zeros((19,), dtype=jnp.uint8),
49
51
  9,
50
52
  10,
@@ -52,11 +54,20 @@ scenarios = {
52
54
  }
53
55
 
54
56
 
55
def make_scenario(place, terrain_raster, unit_starting_sectors, allies_type, n_allies, enemies_type, n_enemies):
    """Create a Scenario from a terrain raster and per-team unit types.

    Args:
        place: Name of the location the scenario is set in.
        terrain_raster: Terrain raster, stored on the Scenario unchanged.
        unit_starting_sectors: Per-team starting sectors, stored on the Scenario unchanged.
        allies_type: Either a single integer unit type shared by every ally, or an
            iterable (list, tuple, array) of exactly ``n_allies`` unit types.
        n_allies: Number of allied units.
        enemies_type: Same as ``allies_type``, for the ``n_enemies`` enemies.
        n_enemies: Number of enemy units.

    Returns:
        A ``Scenario`` whose ``unit_types`` (uint8) lists the allies first, then the enemies.

    Raises:
        ValueError: if an iterable of unit types does not have the expected length.
    """

    def _expand(unit_type, count, team):
        # Return a plain Python list of ``count`` integer unit types for one team.
        try:
            items = list(unit_type)
        except TypeError:
            # Not iterable (Python int, NumPy/JAX scalar): every unit shares this type.
            return [int(unit_type)] * count
        # Normalise to a list so the concatenation below is list concatenation,
        # not element-wise array addition.
        types = [int(t) for t in items]
        if len(types) != count:
            raise ValueError(f"expected {count} {team} unit types, got {len(types)}")
        return types

    allies = _expand(allies_type, n_allies, "ally")
    enemies = _expand(enemies_type, n_enemies, "enemy")
    unit_types = jnp.array(allies + enemies, dtype=jnp.uint8)
    return Scenario(place, terrain_raster, unit_starting_sectors, unit_types, n_allies, n_enemies)
60
71
 
61
72
 
62
73
  def spawn_fn(pool, offset: jnp.ndarray, n: int, rng: jnp.ndarray):
@@ -81,29 +92,41 @@ def sector_fn(terrain: jnp.ndarray, sector_id: int):
81
92
  return jnp.nonzero(sector), offset
82
93
 
83
94
 
95
def sector_fn(terrain: jnp.ndarray, sector: jnp.ndarray):
    """Return the free cells of a rectangular sector of ``terrain``.

    ``terrain`` is indexed ``[row, col]``, i.e. ``[y, x]``; its shape is
    ``(height, width)``.

    Args:
        terrain: 2-D raster; cells equal to 0 are treated as free.
        sector: ``[x, y, w, h]`` given as fractions of the map width/height.
            ``(x, y)`` is the sector corner with the smallest indices.

    Returns:
        ``((xs, ys), offset)``:
            - ``xs`` / ``ys`` are the sector-local x and y indices of the free
              cells (``jnp.nonzero`` of the transposed slice).
            - ``offset`` is ``[x0, y0]``, the sector origin in map cells.
    """
    # shape is (rows, cols) == (height, width); the previous `width, height`
    # unpacking swapped the axes and broke non-square rasters.
    height, width = terrain.shape
    x0, y0 = int(sector[0] * width), int(sector[1] * height)
    w, h = int(sector[2] * width), int(sector[3] * height)
    free = terrain[y0 : y0 + h, x0 : x0 + w] == 0
    offset = jnp.array([x0, y0])
    # Transpose so nonzero yields (x, y) index arrays instead of (row, col).
    return jnp.nonzero(free.T), offset
103
+
104
+
84
105
  class Environment(SMAX):
85
106
  def __init__(self, scenario: Scenario, **kwargs):
86
107
  map_height, map_width = scenario.terrain_raster.shape
87
108
  args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
88
109
  super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
89
110
  self.terrain_raster = scenario.terrain_raster
111
+ self.unit_starting_sectors = scenario.unit_starting_sectors
90
112
  # self.unit_type_names = ["tinker", "tailor", "soldier", "spy"]
91
113
  # self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
92
114
  # self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
93
115
  self.scenario = scenario
116
+ self.unit_type_velocities=jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15])/2.5
94
117
  self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
95
118
  self.max_steps = 200
96
119
  self._push_units_away = lambda state, firmness = 1: state # overwrite push units
97
- self.top_sector, self.top_sector_offset = sector_fn(self.terrain_raster, 0)
98
- self.low_sector, self.low_sector_offset = sector_fn(self.terrain_raster, 24)
120
+ self.team0_sector, self.team0_sector_offset = sector_fn(self.terrain_raster, self.unit_starting_sectors[0]) # sector_fn(self.terrain_raster, 0)
121
+ self.team1_sector, self.team1_sector_offset = sector_fn(self.terrain_raster, self.unit_starting_sectors[1]) # sector_fn(self.terrain_raster, 24)
99
122
 
100
123
 
101
124
  @partial(jax.jit, static_argnums=(0,))
102
125
  def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
103
126
  """Environment-specific reset."""
104
127
  ally_key, enemy_key = jax.random.split(rng)
105
- team_0_start = spawn_fn(self.top_sector, self.top_sector_offset, self.num_allies, ally_key)
106
- team_1_start = spawn_fn(self.low_sector, self.low_sector_offset, self.num_enemies, enemy_key)
128
+ team_0_start = spawn_fn(self.team0_sector, self.team0_sector_offset, self.num_allies, ally_key)
129
+ team_1_start = spawn_fn(self.team1_sector, self.team1_sector_offset, self.num_enemies, enemy_key)
107
130
  unit_positions = jnp.concatenate([team_0_start, team_1_start])
108
131
  unit_teams = jnp.zeros((self.num_agents,))
109
132
  unit_teams = unit_teams.at[self.num_allies :].set(1)
@@ -126,14 +149,9 @@ class Environment(SMAX):
126
149
  state = self._push_units_away(state) # type: ignore
127
150
  obs = self.get_obs(state)
128
151
  world_state = self.get_world_state(state)
129
- # obs["world_state"] = jax.lax.stop_gradient(world_state)
152
+ obs["world_state"] = jax.lax.stop_gradient(world_state)
130
153
  return obs, state
131
154
 
132
- def step_env(self, rng, state: State, action: Array):
133
- obs, state, rewards, dones, infos = super().step_env(rng, state, action)
134
- # delete world_state from obs
135
- obs.pop("world_state")
136
- return obs, state, rewards, dones, infos
137
155
 
138
156
  def _our_push_units_away(
139
157
  self, pos, unit_types, firmness: float = 1.0
@@ -172,15 +190,15 @@ class Environment(SMAX):
172
190
 
173
191
  def raster_crossing(pos, new_pos):
174
192
  pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
175
- raster = self.terrain_raster
176
- axis = jnp.argmax(jnp.abs(new_pos - pos), axis=-1)
177
- minimum = jnp.minimum(pos[axis], new_pos[axis]).squeeze()
178
- maximum = jnp.maximum(pos[axis], new_pos[axis]).squeeze()
179
- segment = jnp.where(axis == 0, raster[pos[1]], raster.T[pos[0]])
180
- segment = jnp.where(jnp.arange(segment.shape[0]) >= minimum, segment, 0)
181
- segment = jnp.where(jnp.arange(segment.shape[0]) <= maximum, segment, 0)
182
- return jnp.any(segment)
183
-
193
+ raster = jnp.copy(self.terrain_raster)
194
+ minimum = jnp.minimum(pos, new_pos)
195
+ maximum = jnp.maximum(pos, new_pos)
196
+ raster = jnp.where(jnp.arange(raster.shape[0]) >= minimum[0], raster, 0)
197
+ raster = jnp.where(jnp.arange(raster.shape[0]) <= maximum[0], raster, 0)
198
+ raster = jnp.where(jnp.arange(raster.shape[1]) >= minimum[1], raster.T, 0).T
199
+ raster = jnp.where(jnp.arange(raster.shape[1]) <= maximum[1], raster.T, 0).T
200
+ return jnp.any(raster)
201
+
184
202
  def update_position(idx, vec):
185
203
  # Compute the movements slightly strangely.
186
204
  # The velocities below are for diagonal directions
@@ -202,11 +220,7 @@ class Environment(SMAX):
202
220
 
203
221
  #######################################################################
204
222
  ############################################ avoid going into obstacles
205
- """ obs = self.obstacle_coords
206
- obs_end = obs + self.obstacle_deltas
207
- inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end)) """
208
223
  clash = raster_crossing(pos, new_pos)
209
- # flag = jnp.logical_or(inters, rastersects)
210
224
  new_pos = jnp.where(clash, pos, new_pos)
211
225
 
212
226
  #######################################################################
@@ -319,39 +333,11 @@ class Environment(SMAX):
319
333
 
320
334
  # units push each other
321
335
  new_pos = self._our_push_units_away(pos, state.unit_types)
322
-
323
- # avoid going into obstacles after being pushed
324
-
325
- bondaries_coords = jnp.array(
326
- [[0, 0], [0, 0], [self.map_width, 0], [0, self.map_height]]
327
- )
328
- bondaries_deltas = jnp.array(
329
- [
330
- [self.map_width, 0],
331
- [0, self.map_height],
332
- [0, self.map_height],
333
- [self.map_width, 0],
334
- ]
335
- )
336
- """ obstacle_coords = jnp.concatenate(
337
- [self.obstacle_coords, bondaries_coords]
338
- ) # add the map boundaries to the obstacles to avoid
339
- obstacle_deltas = jnp.concatenate(
340
- [self.obstacle_deltas, bondaries_deltas]
341
- ) # add the map boundaries to the obstacles to avoid
342
- obst_start = obstacle_coords
343
- obst_end = obst_start + obstacle_deltas """
344
-
345
- def check_obstacles(pos, new_pos, obst_start, obst_end):
346
- intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
347
- rastersect = raster_crossing(pos, new_pos)
348
- flag = jnp.logical_or(intersects, rastersect)
349
- return jnp.where(flag, pos, new_pos)
350
-
351
- """ pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
352
- pos, new_pos, obst_start, obst_end
353
- ) """
354
-
336
+ clash = jax.vmap(raster_crossing)(pos, new_pos)
337
+ pos = jax.vmap(jnp.where)(clash, pos, new_pos)
338
+ # avoid going out of bounds
339
+ pos = jnp.maximum(jnp.minimum(pos, jnp.array([self.map_width, self.map_height])),jnp.zeros((2,)),)
340
+
355
341
  # Multiple enemies can attack the same unit.
356
342
  # We have `(health_diff, attacked_idx)` pairs.
357
343
  # `jax.lax.scatter_add` aggregates these exactly
parabellum/map.py CHANGED
@@ -14,6 +14,8 @@ import rasterio.transform
14
14
  from typing import Optional, Tuple
15
15
  from geopy.location import Location
16
16
  from shapely.geometry import Point
17
+ import os
18
+ import pickle
17
19
 
18
20
  # constants
19
21
  geolocator = Nominatim(user_agent="parabellum")
@@ -36,16 +38,41 @@ def rasterize_geometry(gdf: gpd.GeoDataFrame, size: int) -> jnp.ndarray:
36
38
  w, s, e, n = gdf.total_bounds
37
39
  transform = rasterio.transform.from_bounds(w, s, e, n, size, size)
38
40
  raster = features.rasterize(gdf.geometry, out_shape=(size, size), transform=transform)
39
- return jnp.array(jnp.rot90(raster, 2)).astype(jnp.uint8)
41
+ return jnp.array(jnp.flip(raster, 0) ).astype(jnp.uint8)
40
42
 
41
# +
def _cache_path(place, size, cache_dir):
    """Return the pickle path used to cache the terrain for ``(place, size)``."""
    import hashlib

    # Use a stable digest. The built-in hash() of a str is salted per interpreter
    # process (PYTHONHASHSEED), so hash()-based names never match across runs.
    key = hashlib.sha256(repr((place, size)).encode("utf-8")).hexdigest()
    return os.path.join(cache_dir, key + ".pk")


def get_from_cache(place, size, cache_dir="./cache"):
    """Load the cached ``(mask, base)`` pair for ``(place, size)``.

    Args:
        place: Place name used as part of the cache key.
        size: Raster size used as part of the cache key.
        cache_dir: Directory holding the cache files.

    Returns:
        ``(mask, base)`` with ``base`` cast via ``astype(jnp.int64)``.
        Returns ``(None, None)`` when there is no entry or the entry is unreadable.
    """
    path = _cache_path(place, size, cache_dir)
    if not os.path.exists(path):
        return (None, None)
    # NOTE: pickle.load can execute arbitrary code; only point cache_dir at files this module wrote.
    try:
        with open(path, "rb") as f:
            mask, base = pickle.load(f)
    except (EOFError, pickle.UnpicklingError):
        # A truncated or corrupt entry is treated as a cache miss so it gets rebuilt.
        return (None, None)
    # NOTE(review): only cache hits are cast to int64; freshly computed bases keep get_basemap's dtype.
    return (mask, base.astype(jnp.int64))


def save_in_cache(place, size, mask, base, cache_dir="./cache"):
    """Pickle ``(mask, base)`` for ``(place, size)`` into ``cache_dir``, creating it if needed."""
    os.makedirs(cache_dir, exist_ok=True)
    path = _cache_path(place, size, cache_dir)
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump((mask, base), f)
    # Rename into place so an interrupted write never leaves a half-written entry.
    os.replace(tmp_path, path)


def terrain_fn(place: str, size: int = 1000, with_cache: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns a rasterized map of buildings for a given location.

    Args:
        place: Place name to geocode.
        size: Raster size passed to the building and basemap helpers.
        with_cache: When True, read from and write to the on-disk cache.

    Returns:
        ``(mask, base)``: the rasterized building mask and the basemap.
    """
    if with_cache:
        mask, base = get_from_cache(place, size)
        if mask is not None:
            return mask, base
    point = get_location(place)
    gdf = get_building_geometry(point, size)
    mask = rasterize_geometry(gdf, size)
    base = get_basemap(place, size)
    if with_cache:
        save_in_cache(place, size, mask, base)
    return mask, base
48
72
 
73
+
74
+ # -
75
+
49
76
  def get_basemap(place: str, size: int = 1000) -> jnp.ndarray:
50
77
  """Returns a basemap for a given place as a JAX array."""
51
78
  point = get_location(place)
@@ -54,15 +81,20 @@ def get_basemap(place: str, size: int = 1000) -> jnp.ndarray:
54
81
  # get the middle size x size square
55
82
  basemap = basemap[(basemap.shape[0] - size) // 2:(basemap.shape[0] + size) // 2,
56
83
  (basemap.shape[1] - size) // 2:(basemap.shape[1] + size) // 2]
57
- return jnp.array(jnp.rot90(basemap, 2)).astype(jnp.uint8)
84
+ return basemap # jnp.array(jnp.rot90(basemap, 2)).astype(jnp.uint8)
58
85
 
59
86
 
if __name__ == "__main__":
    # Visual sanity check in three panels: the building mask, the basemap,
    # and the mask overlaid on the basemap.
    place = "Cauvicourt, 14190, France"
    mask, base = terrain_fn(place, 500)
    flipped = jnp.flip(mask, 0)

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(flipped)  # type: ignore
    for panel in ax[1:]:
        panel.imshow(base)  # type: ignore
    ax[2].imshow(flipped, alpha=flipped)  # type: ignore
    plt.show()
97
+
98
+
99
+
100
+
parabellum/vis.py CHANGED
@@ -56,7 +56,6 @@ class Visualizer(SMAXVisualizer):
56
56
  self.skin.scale = self.skin.size / env.map_width # assumes square map
57
57
  self.env = env
58
58
 
59
-
60
59
  def animate(self, save_fname: Optional[str] = "output/parabellum", view=None):
61
60
  expanded_state_seq, expanded_action_seq = expand_fn(self.env, self.state_seq, self.action_seq)
62
61
  state_seq_seq, action_seq_seq = unbatch_fn(expanded_state_seq, expanded_action_seq)
@@ -86,7 +85,8 @@ def init_frame(env, skin, image, state: pb.State, action: Array, idx: int) -> py
86
85
 
87
86
 
def transform_frame(env, skin, frame):
    """Convert a rendered pygame Surface into an image array.

    ``env`` and ``skin`` are not used by this function.

    Returns:
        The Surface's RGB pixels as an array. pygame's surfarray is indexed
        ``[x, y]``, so the axes are swapped to row-major order and the rows
        are then flipped vertically.
    """
    # array3d returns a copy. pixels3d returned a view that keeps the Surface
    # locked for as long as the returned frame array is alive.
    pixels = pygame.surfarray.array3d(frame)
    return np.flip(pixels.swapaxes(0, 1), 0)
91
91
 
92
92
 
@@ -136,7 +136,7 @@ def text_fn(text):
136
136
 
137
137
  def image_fn(skin: Skin): # TODO:
138
138
  """Create an image for background (basemap or maskmap)"""
139
- motif = cv2.resize(np.array(skin.maskmap.T), (skin.size, skin.size), interpolation=cv2.INTER_LANCZOS4).astype(np.uint8)
139
+ motif = cv2.resize(np.array(skin.maskmap.T), (skin.size, skin.size), interpolation=cv2.INTER_NEAREST).astype(np.uint8)
140
140
  motif = (motif > 0).astype(np.uint8)
141
141
  image = np.zeros((skin.size, skin.size, 3), dtype=np.uint8) + skin.bg
142
142
  image[motif == 1] = skin.fg
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: parabellum
3
- Version: 0.2.23
3
+ Version: 0.2.25
4
4
  Summary: Parabellum environment for parallel warfare simulation
5
5
  Home-page: https://github.com/syrkis/parabellum
6
6
  License: MIT
@@ -0,0 +1,10 @@
1
+ parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
2
+ parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
3
+ parabellum/env.py,sha256=H1VwxHj8KqbjBJ8b7NMxkAg3Q0qwVzGulrigc26Tzkc,16663
4
+ parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
5
+ parabellum/map.py,sha256=UwMqwySasX5oLw9v5YMJARAPwvQThLTRW36NpbwvBC8,3564
6
+ parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
7
+ parabellum/vis.py,sha256=q7_OIMjzt-7nBOojVVW7Wiiq9ojsjaltIsH6eOOxPKk,6116
8
+ parabellum-0.2.25.dist-info/METADATA,sha256=UCuRLYhSUxnebs5pRs1FaWK67PUJw0kgsJ2UaaxNM9Q,2671
9
+ parabellum-0.2.25.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
10
+ parabellum-0.2.25.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
2
- parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
3
- parabellum/env.py,sha256=2IuXipHXUJAyfrjPtDW7uINb8mY4G3-xQ_lhReMGqLs,16985
4
- parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
5
- parabellum/map.py,sha256=EUcPe4Upu9MQzS8h15IVPGCaAyRPLSkmoLd5ZT-V4Pk,2599
6
- parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
7
- parabellum/vis.py,sha256=uXTnhJL23JLQHW9by-M4bF73dSVA5TIkpNdfo_Go2Ro,6045
8
- parabellum-0.2.23.dist-info/METADATA,sha256=qwEcFJksQ54_MNpMj8Zq1EC2qDgltEC9u9ivZFxZ0Eg,2671
9
- parabellum-0.2.23.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
10
- parabellum-0.2.23.dist-info/RECORD,,