parabellum 0.2.24__py3-none-any.whl → 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/env.py CHANGED
@@ -19,6 +19,7 @@ class Scenario:
19
19
 
20
20
  place: str
21
21
  terrain_raster: jnp.ndarray
22
+ unit_starting_sectors: jnp.ndarray
22
23
  unit_types: chex.Array
23
24
  num_allies: int
24
25
  num_enemies: int
@@ -45,6 +46,7 @@ scenarios = {
45
46
  "default": Scenario(
46
47
  "Identity Town",
47
48
  jnp.eye(64, dtype=jnp.uint8),
49
+ jnp.array([[0, 0, 0.2, 0.2], [0.7,0.7,0.2,0.2]]),
48
50
  jnp.zeros((19,), dtype=jnp.uint8),
49
51
  9,
50
52
  10,
@@ -52,11 +54,20 @@ scenarios = {
52
54
  }
53
55
 
54
56
 
55
- def make_scenario(place, terrain_raster, num_allies=9, num_enemies=10):
56
- """Create a scenario"""
57
- num_agents = num_allies + num_enemies
58
- unit_types = jnp.zeros((num_agents,)).astype(jnp.uint8)
59
- return Scenario(place, terrain_raster, unit_types, num_allies, num_enemies)
57
+ def make_scenario(place, terrain_raster, unit_starting_sectors, allies_type, n_allies, enemies_type, n_enemies):
58
+ if type(allies_type) == int:
59
+ allies = [allies_type] * n_allies
60
+ else:
61
+ assert(len(allies_type) == n_allies)
62
+ allies = allies_type
63
+
64
+ if type(enemies_type) == int:
65
+ enemies = [enemies_type] * n_enemies
66
+ else:
67
+ assert(len(enemies_type) == n_enemies)
68
+ enemies = enemies_type
69
+ unit_types = jnp.array(allies + enemies, dtype=jnp.uint8)
70
+ return Scenario(place, terrain_raster, unit_starting_sectors, unit_types, n_allies, n_enemies)
60
71
 
61
72
 
62
73
  def spawn_fn(pool, offset: jnp.ndarray, n: int, rng: jnp.ndarray):
@@ -81,29 +92,41 @@ def sector_fn(terrain: jnp.ndarray, sector_id: int):
81
92
  return jnp.nonzero(sector), offset
82
93
 
83
94
 
95
+ def sector_fn(terrain: jnp.ndarray, sector: jnp.ndarray):
96
+ """return sector slice of terrain"""
97
+ width, height = terrain.shape
98
+ coordx, coordy = int(sector[0] * width), int(sector[1] * height)
99
+ sector = terrain[coordy : coordy + int(sector[3] * height), coordx : coordx + int(sector[2] * width)] == 0
100
+ offset = jnp.array([coordx, coordy])
101
+ # sector is jnp.nonzero
102
+ return jnp.nonzero(sector.T), offset
103
+
104
+
84
105
  class Environment(SMAX):
85
106
  def __init__(self, scenario: Scenario, **kwargs):
86
107
  map_height, map_width = scenario.terrain_raster.shape
87
108
  args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
88
109
  super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
89
110
  self.terrain_raster = scenario.terrain_raster
111
+ self.unit_starting_sectors = scenario.unit_starting_sectors
90
112
  # self.unit_type_names = ["tinker", "tailor", "soldier", "spy"]
91
113
  # self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
92
114
  # self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
93
115
  self.scenario = scenario
116
+ self.unit_type_velocities=jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15])/2.5
94
117
  self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
95
118
  self.max_steps = 200
96
119
  self._push_units_away = lambda state, firmness = 1: state # overwrite push units
97
- self.top_sector, self.top_sector_offset = sector_fn(self.terrain_raster, 0)
98
- self.low_sector, self.low_sector_offset = sector_fn(self.terrain_raster, 24)
120
+ self.team0_sector, self.team0_sector_offset = sector_fn(self.terrain_raster, self.unit_starting_sectors[0]) # sector_fn(self.terrain_raster, 0)
121
+ self.team1_sector, self.team1_sector_offset = sector_fn(self.terrain_raster, self.unit_starting_sectors[1]) # sector_fn(self.terrain_raster, 24)
99
122
 
100
123
 
101
124
  @partial(jax.jit, static_argnums=(0,))
102
125
  def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
103
126
  """Environment-specific reset."""
104
127
  ally_key, enemy_key = jax.random.split(rng)
105
- team_0_start = spawn_fn(self.top_sector, self.top_sector_offset, self.num_allies, ally_key)
106
- team_1_start = spawn_fn(self.low_sector, self.low_sector_offset, self.num_enemies, enemy_key)
128
+ team_0_start = spawn_fn(self.team0_sector, self.team0_sector_offset, self.num_allies, ally_key)
129
+ team_1_start = spawn_fn(self.team1_sector, self.team1_sector_offset, self.num_enemies, enemy_key)
107
130
  unit_positions = jnp.concatenate([team_0_start, team_1_start])
108
131
  unit_teams = jnp.zeros((self.num_agents,))
109
132
  unit_teams = unit_teams.at[self.num_allies :].set(1)
@@ -167,15 +190,15 @@ class Environment(SMAX):
167
190
 
168
191
  def raster_crossing(pos, new_pos):
169
192
  pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
170
- raster = self.terrain_raster
171
- axis = jnp.argmax(jnp.abs(new_pos - pos), axis=-1)
172
- minimum = jnp.minimum(pos[axis], new_pos[axis]).squeeze()
173
- maximum = jnp.maximum(pos[axis], new_pos[axis]).squeeze()
174
- segment = jnp.where(axis == 0, raster[pos[1]], raster.T[pos[0]])
175
- segment = jnp.where(jnp.arange(segment.shape[0]) >= minimum, segment, 0)
176
- segment = jnp.where(jnp.arange(segment.shape[0]) <= maximum, segment, 0)
177
- return jnp.any(segment)
178
-
193
+ raster = jnp.copy(self.terrain_raster)
194
+ minimum = jnp.minimum(pos, new_pos)
195
+ maximum = jnp.maximum(pos, new_pos)
196
+ raster = jnp.where(jnp.arange(raster.shape[0]) >= minimum[0], raster, 0)
197
+ raster = jnp.where(jnp.arange(raster.shape[0]) <= maximum[0], raster, 0)
198
+ raster = jnp.where(jnp.arange(raster.shape[1]) >= minimum[1], raster.T, 0).T
199
+ raster = jnp.where(jnp.arange(raster.shape[1]) <= maximum[1], raster.T, 0).T
200
+ return jnp.any(raster)
201
+
179
202
  def update_position(idx, vec):
180
203
  # Compute the movements slightly strangely.
181
204
  # The velocities below are for diagonal directions
@@ -197,11 +220,7 @@ class Environment(SMAX):
197
220
 
198
221
  #######################################################################
199
222
  ############################################ avoid going into obstacles
200
- """ obs = self.obstacle_coords
201
- obs_end = obs + self.obstacle_deltas
202
- inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end)) """
203
223
  clash = raster_crossing(pos, new_pos)
204
- # flag = jnp.logical_or(inters, rastersects)
205
224
  new_pos = jnp.where(clash, pos, new_pos)
206
225
 
207
226
  #######################################################################
@@ -314,39 +333,11 @@ class Environment(SMAX):
314
333
 
315
334
  # units push each other
316
335
  new_pos = self._our_push_units_away(pos, state.unit_types)
317
-
318
- # avoid going into obstacles after being pushed
319
-
320
- bondaries_coords = jnp.array(
321
- [[0, 0], [0, 0], [self.map_width, 0], [0, self.map_height]]
322
- )
323
- bondaries_deltas = jnp.array(
324
- [
325
- [self.map_width, 0],
326
- [0, self.map_height],
327
- [0, self.map_height],
328
- [self.map_width, 0],
329
- ]
330
- )
331
- """ obstacle_coords = jnp.concatenate(
332
- [self.obstacle_coords, bondaries_coords]
333
- ) # add the map boundaries to the obstacles to avoid
334
- obstacle_deltas = jnp.concatenate(
335
- [self.obstacle_deltas, bondaries_deltas]
336
- ) # add the map boundaries to the obstacles to avoid
337
- obst_start = obstacle_coords
338
- obst_end = obst_start + obstacle_deltas """
339
-
340
- def check_obstacles(pos, new_pos, obst_start, obst_end):
341
- intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
342
- rastersect = raster_crossing(pos, new_pos)
343
- flag = jnp.logical_or(intersects, rastersect)
344
- return jnp.where(flag, pos, new_pos)
345
-
346
- """ pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
347
- pos, new_pos, obst_start, obst_end
348
- ) """
349
-
336
+ clash = jax.vmap(raster_crossing)(pos, new_pos)
337
+ pos = jax.vmap(jnp.where)(clash, pos, new_pos)
338
+ # avoid going out of bounds
339
+ pos = jnp.maximum(jnp.minimum(pos, jnp.array([self.map_width, self.map_height])),jnp.zeros((2,)),)
340
+
350
341
  # Multiple enemies can attack the same unit.
351
342
  # We have `(health_diff, attacked_idx)` pairs.
352
343
  # `jax.lax.scatter_add` aggregates these exactly
parabellum/map.py CHANGED
@@ -14,6 +14,8 @@ import rasterio.transform
14
14
  from typing import Optional, Tuple
15
15
  from geopy.location import Location
16
16
  from shapely.geometry import Point
17
+ import os
18
+ import pickle
17
19
 
18
20
  # constants
19
21
  geolocator = Nominatim(user_agent="parabellum")
@@ -36,16 +38,41 @@ def rasterize_geometry(gdf: gpd.GeoDataFrame, size: int) -> jnp.ndarray:
36
38
  w, s, e, n = gdf.total_bounds
37
39
  transform = rasterio.transform.from_bounds(w, s, e, n, size, size)
38
40
  raster = features.rasterize(gdf.geometry, out_shape=(size, size), transform=transform)
39
- return jnp.array(jnp.rot90(raster, 2)).astype(jnp.uint8)
41
+ return jnp.array(jnp.flip(raster, 0) ).astype(jnp.uint8)
40
42
 
41
- def terrain_fn(place: str, size: int = 1000) -> Tuple[jnp.ndarray, jnp.ndarray]:
43
+ # +
44
+ def get_from_cache(place, size):
45
+ if os.path.exists("./cache"):
46
+ name = str(hash((place, size))) + ".pk"
47
+ if os.path.exists("./cache/" + name):
48
+ with open("./cache/" + name, "rb") as f:
49
+ (mask, base) = pickle.load(f)
50
+ return (mask, base.astype(jnp.int64))
51
+ return (None, None)
52
+
53
+ def save_in_cache(place, size, mask, base):
54
+ if not os.path.exists("./cache"):
55
+ os.makedirs("./cache")
56
+ name = str(hash((place, size))) + ".pk"
57
+ with open("./cache/" + name, "wb") as f:
58
+ pickle.dump((mask, base), f)
59
+
60
+ def terrain_fn(place: str, size: int = 1000, with_cache: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]:
42
61
  """Returns a rasterized map of buildings for a given location."""
43
- point = get_location(place)
44
- gdf = get_building_geometry(point, size)
45
- mask = rasterize_geometry(gdf, size)
46
- base = get_basemap(place, size)
62
+ if with_cache:
63
+ mask, base = get_from_cache(place, size)
64
+ if not with_cache or mask is None:
65
+ point = get_location(place)
66
+ gdf = get_building_geometry(point, size)
67
+ mask = rasterize_geometry(gdf, size)
68
+ base = get_basemap(place, size)
69
+ if with_cache:
70
+ save_in_cache(place, size, mask, base)
47
71
  return mask, base
48
72
 
73
+
74
+ # -
75
+
49
76
  def get_basemap(place: str, size: int = 1000) -> jnp.ndarray:
50
77
  """Returns a basemap for a given place as a JAX array."""
51
78
  point = get_location(place)
@@ -54,15 +81,20 @@ def get_basemap(place: str, size: int = 1000) -> jnp.ndarray:
54
81
  # get the middle size x size square
55
82
  basemap = basemap[(basemap.shape[0] - size) // 2:(basemap.shape[0] + size) // 2,
56
83
  (basemap.shape[1] - size) // 2:(basemap.shape[1] + size) // 2]
57
- return jnp.array(jnp.rot90(basemap, 2)).astype(jnp.uint8)
84
+ return basemap # jnp.array(jnp.rot90(basemap, 2)).astype(jnp.uint8)
58
85
 
59
86
 
60
87
  if __name__ == "__main__":
61
- import seaborn as sns
62
- place = "Thun, Switzerland"
63
- mask, base = terrain_fn(place)
64
-
65
- fig, ax = plt.subplots(1, 2, figsize=(10, 5))
66
- ax[0].imshow(mask) # type: ignore
88
+ place = "Cauvicourt, 14190, France"
89
+ mask, base = terrain_fn(place, 500)
90
+
91
+ fig, ax = plt.subplots(1, 3, figsize=(15, 5))
92
+ ax[0].imshow(jnp.flip(mask,0)) # type: ignore
67
93
  ax[1].imshow(base) # type: ignore
94
+ ax[2].imshow(base) # type: ignore
95
+ ax[2].imshow(jnp.flip(mask,0), alpha=jnp.flip(mask,0)) # type: ignore
68
96
  plt.show()
97
+
98
+
99
+
100
+
parabellum/vis.py CHANGED
@@ -56,7 +56,6 @@ class Visualizer(SMAXVisualizer):
56
56
  self.skin.scale = self.skin.size / env.map_width # assumes square map
57
57
  self.env = env
58
58
 
59
-
60
59
  def animate(self, save_fname: Optional[str] = "output/parabellum", view=None):
61
60
  expanded_state_seq, expanded_action_seq = expand_fn(self.env, self.state_seq, self.action_seq)
62
61
  state_seq_seq, action_seq_seq = unbatch_fn(expanded_state_seq, expanded_action_seq)
@@ -86,7 +85,8 @@ def init_frame(env, skin, image, state: pb.State, action: Array, idx: int) -> py
86
85
 
87
86
 
88
87
  def transform_frame(env, skin, frame):
89
- frame = np.rot90(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 2)
88
+ #frame = np.rot90(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 2)
89
+ frame = np.flip(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 0)
90
90
  return frame
91
91
 
92
92
 
@@ -136,7 +136,7 @@ def text_fn(text):
136
136
 
137
137
  def image_fn(skin: Skin): # TODO:
138
138
  """Create an image for background (basemap or maskmap)"""
139
- motif = cv2.resize(np.array(skin.maskmap.T), (skin.size, skin.size), interpolation=cv2.INTER_LANCZOS4).astype(np.uint8)
139
+ motif = cv2.resize(np.array(skin.maskmap.T), (skin.size, skin.size), interpolation=cv2.INTER_NEAREST).astype(np.uint8)
140
140
  motif = (motif > 0).astype(np.uint8)
141
141
  image = np.zeros((skin.size, skin.size, 3), dtype=np.uint8) + skin.bg
142
142
  image[motif == 1] = skin.fg
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: parabellum
3
- Version: 0.2.24
3
+ Version: 0.2.25
4
4
  Summary: Parabellum environment for parallel warfare simulation
5
5
  Home-page: https://github.com/syrkis/parabellum
6
6
  License: MIT
@@ -0,0 +1,10 @@
1
+ parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
2
+ parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
3
+ parabellum/env.py,sha256=H1VwxHj8KqbjBJ8b7NMxkAg3Q0qwVzGulrigc26Tzkc,16663
4
+ parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
5
+ parabellum/map.py,sha256=UwMqwySasX5oLw9v5YMJARAPwvQThLTRW36NpbwvBC8,3564
6
+ parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
7
+ parabellum/vis.py,sha256=q7_OIMjzt-7nBOojVVW7Wiiq9ojsjaltIsH6eOOxPKk,6116
8
+ parabellum-0.2.25.dist-info/METADATA,sha256=UCuRLYhSUxnebs5pRs1FaWK67PUJw0kgsJ2UaaxNM9Q,2671
9
+ parabellum-0.2.25.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
10
+ parabellum-0.2.25.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
2
- parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
3
- parabellum/env.py,sha256=rzhsMiCRVBOza63XboPYTLr8MRO45HUIIYX8qqFJHmE,16726
4
- parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
5
- parabellum/map.py,sha256=EUcPe4Upu9MQzS8h15IVPGCaAyRPLSkmoLd5ZT-V4Pk,2599
6
- parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
7
- parabellum/vis.py,sha256=uXTnhJL23JLQHW9by-M4bF73dSVA5TIkpNdfo_Go2Ro,6045
8
- parabellum-0.2.24.dist-info/METADATA,sha256=0TXKsb81R0YnMOnpTIrEfDUOJVs92jXbKGpILB2WDO4,2671
9
- parabellum-0.2.24.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
10
- parabellum-0.2.24.dist-info/RECORD,,