parabellum 0.2.24__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum/env.py +46 -55
- parabellum/map.py +45 -13
- parabellum/vis.py +3 -3
- {parabellum-0.2.24.dist-info → parabellum-0.2.25.dist-info}/METADATA +1 -1
- parabellum-0.2.25.dist-info/RECORD +10 -0
- parabellum-0.2.24.dist-info/RECORD +0 -10
- {parabellum-0.2.24.dist-info → parabellum-0.2.25.dist-info}/WHEEL +0 -0
parabellum/env.py
CHANGED
@@ -19,6 +19,7 @@ class Scenario:
|
|
19
19
|
|
20
20
|
place: str
|
21
21
|
terrain_raster: jnp.ndarray
|
22
|
+
unit_starting_sectors: jnp.ndarray
|
22
23
|
unit_types: chex.Array
|
23
24
|
num_allies: int
|
24
25
|
num_enemies: int
|
@@ -45,6 +46,7 @@ scenarios = {
|
|
45
46
|
"default": Scenario(
|
46
47
|
"Identity Town",
|
47
48
|
jnp.eye(64, dtype=jnp.uint8),
|
49
|
+
jnp.array([[0, 0, 0.2, 0.2], [0.7,0.7,0.2,0.2]]),
|
48
50
|
jnp.zeros((19,), dtype=jnp.uint8),
|
49
51
|
9,
|
50
52
|
10,
|
@@ -52,11 +54,20 @@ scenarios = {
|
|
52
54
|
}
|
53
55
|
|
54
56
|
|
55
|
-
def make_scenario(place, terrain_raster,
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
57
|
+
def _expand_unit_types(unit_type, n, team):
    """Return a list of ``n`` unit-type ids from a single id or a sequence of ids.

    Args:
        unit_type: A single integer id (repeated ``n`` times) or a sequence of ids.
        n: Expected number of units.
        team: Team label used in the error message.

    Raises:
        ValueError: If a sequence of ids does not contain exactly ``n`` entries.
    """
    import numbers

    # numbers.Integral also accepts numpy integer scalars, unlike `type(x) == int`.
    if isinstance(unit_type, numbers.Integral):
        return [unit_type] * n
    # Materialize as a plain list so that tuples / arrays are concatenated below
    # (an array would otherwise be added element-wise to the other team's list).
    unit_types = list(unit_type)
    if len(unit_types) != n:
        raise ValueError(f"expected {n} {team} unit types, got {len(unit_types)}")
    return unit_types


def make_scenario(place, terrain_raster, unit_starting_sectors, allies_type, n_allies, enemies_type, n_enemies):
    """Build a ``Scenario`` from a map and per-team unit specifications.

    Args:
        place: Name of the location.
        terrain_raster: Terrain raster of the map.
        unit_starting_sectors: Starting sector of each team, one row per team
            (presumably ``[x, y, w, h]`` map fractions as read by ``sector_fn`` - confirm).
        allies_type: A single unit-type id used for every ally, or a sequence
            of ``n_allies`` unit-type ids.
        n_allies: Number of allied units.
        enemies_type: Same as ``allies_type`` for the enemy team.
        n_enemies: Number of enemy units.

    Returns:
        A ``Scenario`` whose ``unit_types`` (uint8) lists allies first, then enemies.

    Raises:
        ValueError: If a sequence of unit types has the wrong length.
    """
    allies = _expand_unit_types(allies_type, n_allies, "allies")
    enemies = _expand_unit_types(enemies_type, n_enemies, "enemies")
    unit_types = jnp.array(allies + enemies, dtype=jnp.uint8)
    return Scenario(place, terrain_raster, unit_starting_sectors, unit_types, n_allies, n_enemies)
|
60
71
|
|
61
72
|
|
62
73
|
def spawn_fn(pool, offset: jnp.ndarray, n: int, rng: jnp.ndarray):
|
@@ -81,29 +92,41 @@ def sector_fn(terrain: jnp.ndarray, sector_id: int):
|
|
81
92
|
return jnp.nonzero(sector), offset
|
82
93
|
|
83
94
|
|
95
|
+
def sector_fn(terrain: jnp.ndarray, sector: jnp.ndarray):
    """Return the free cells of a rectangular sector of ``terrain`` and its origin.

    Args:
        terrain: 2-D raster indexed ``[row, col]``; cells equal to 0 are free.
        sector: ``[x, y, w, h]`` as fractions of the map. ``x``/``w`` are scaled
            by the number of columns (width), ``y``/``h`` by the number of rows
            (height); results are truncated to whole cells.

    Returns:
        ``(coords, offset)``: ``coords`` is the ``jnp.nonzero`` tuple
        ``(xs, ys)`` of free cells relative to the sector origin, and ``offset``
        is ``jnp.array([x0, y0])``, the sector origin in cells.
    """
    # A 2-D array's shape is (rows, cols) == (height, width).
    height, width = terrain.shape
    x0, y0 = int(sector[0] * width), int(sector[1] * height)
    w, h = int(sector[2] * width), int(sector[3] * height)
    free = terrain[y0 : y0 + h, x0 : x0 + w] == 0
    offset = jnp.array([x0, y0])
    # Transpose so nonzero yields (column, row) == (x, y) index arrays.
    return jnp.nonzero(free.T), offset
|
103
|
+
|
104
|
+
|
84
105
|
class Environment(SMAX):
|
85
106
|
def __init__(self, scenario: Scenario, **kwargs):
    """Build the environment for ``scenario`` on top of the ``SMAX`` base class.

    The map size is read from ``scenario.terrain_raster.shape`` as
    ``(height, width)`` and forwarded to the base initializer together with
    ``walls_cause_death=False`` and any extra ``kwargs``. Afterwards a few
    unit-type tables are overridden, unit pushing is disabled, and the free
    cells of both teams' starting sectors are precomputed for ``reset``.
    """
    map_height, map_width = scenario.terrain_raster.shape
    args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
    super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
    self.terrain_raster = scenario.terrain_raster
    self.unit_starting_sectors = scenario.unit_starting_sectors
    self.scenario = scenario
    # Assigned after the base initializer, so these take precedence over any
    # values it set. Velocities are scaled down by a factor of 2.5.
    self.unit_type_velocities = jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15]) / 2.5
    # One (zero) blast value per unit type, sized to match the velocity table.
    # A shorter table would only "work" for higher unit types through JAX's silent
    # out-of-bounds index clamping (assuming it is indexed by unit type - confirm).
    self.unit_type_attack_blasts = jnp.zeros(self.unit_type_velocities.shape, dtype=jnp.float32)  # TODO: add
    self.max_steps = 200
    # Replace unit pushing with a no-op that returns the state unchanged.
    self._push_units_away = lambda state, firmness=1: state
    # Free-cell coordinates and origin offset of each team's starting sector;
    # reset() spawns units from these.
    self.team0_sector, self.team0_sector_offset = sector_fn(self.terrain_raster, self.unit_starting_sectors[0])
    self.team1_sector, self.team1_sector_offset = sector_fn(self.terrain_raster, self.unit_starting_sectors[1])
|
99
122
|
|
100
123
|
|
101
124
|
@partial(jax.jit, static_argnums=(0,))
|
102
125
|
def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
|
103
126
|
"""Environment-specific reset."""
|
104
127
|
ally_key, enemy_key = jax.random.split(rng)
|
105
|
-
team_0_start = spawn_fn(self.
|
106
|
-
team_1_start = spawn_fn(self.
|
128
|
+
team_0_start = spawn_fn(self.team0_sector, self.team0_sector_offset, self.num_allies, ally_key)
|
129
|
+
team_1_start = spawn_fn(self.team1_sector, self.team1_sector_offset, self.num_enemies, enemy_key)
|
107
130
|
unit_positions = jnp.concatenate([team_0_start, team_1_start])
|
108
131
|
unit_teams = jnp.zeros((self.num_agents,))
|
109
132
|
unit_teams = unit_teams.at[self.num_allies :].set(1)
|
@@ -167,15 +190,15 @@ class Environment(SMAX):
|
|
167
190
|
|
168
191
|
def raster_crossing(pos, new_pos):
    """Return True if any non-zero terrain cell lies in the box spanned by the two positions.

    Positions are ``(x, y)`` and are truncated to int32 cell indices; ``x``
    indexes raster columns and ``y`` raster rows. Both box corners are
    inclusive. Reads ``self.terrain_raster`` from the enclosing method.
    """
    pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
    raster = self.terrain_raster
    lo = jnp.minimum(pos, new_pos)
    hi = jnp.maximum(pos, new_pos)
    # Explicit (rows, 1) and (1, cols) index grids so the mask broadcasts over
    # the correct axes for any raster shape. 1-D masks would align with the
    # last axis only, which breaks on non-square rasters.
    rows = jnp.arange(raster.shape[0])[:, None]
    cols = jnp.arange(raster.shape[1])[None, :]
    in_box = (cols >= lo[0]) & (cols <= hi[0]) & (rows >= lo[1]) & (rows <= hi[1])
    return jnp.any(in_box & (raster != 0))
|
201
|
+
|
179
202
|
def update_position(idx, vec):
|
180
203
|
# Compute the movements slightly strangely.
|
181
204
|
# The velocities below are for diagonal directions
|
@@ -197,11 +220,7 @@ class Environment(SMAX):
|
|
197
220
|
|
198
221
|
#######################################################################
|
199
222
|
############################################ avoid going into obstacles
|
200
|
-
""" obs = self.obstacle_coords
|
201
|
-
obs_end = obs + self.obstacle_deltas
|
202
|
-
inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end)) """
|
203
223
|
clash = raster_crossing(pos, new_pos)
|
204
|
-
# flag = jnp.logical_or(inters, rastersects)
|
205
224
|
new_pos = jnp.where(clash, pos, new_pos)
|
206
225
|
|
207
226
|
#######################################################################
|
@@ -314,39 +333,11 @@ class Environment(SMAX):
|
|
314
333
|
|
315
334
|
# units push each other
|
316
335
|
new_pos = self._our_push_units_away(pos, state.unit_types)
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
)
|
323
|
-
bondaries_deltas = jnp.array(
|
324
|
-
[
|
325
|
-
[self.map_width, 0],
|
326
|
-
[0, self.map_height],
|
327
|
-
[0, self.map_height],
|
328
|
-
[self.map_width, 0],
|
329
|
-
]
|
330
|
-
)
|
331
|
-
""" obstacle_coords = jnp.concatenate(
|
332
|
-
[self.obstacle_coords, bondaries_coords]
|
333
|
-
) # add the map boundaries to the obstacles to avoid
|
334
|
-
obstacle_deltas = jnp.concatenate(
|
335
|
-
[self.obstacle_deltas, bondaries_deltas]
|
336
|
-
) # add the map boundaries to the obstacles to avoid
|
337
|
-
obst_start = obstacle_coords
|
338
|
-
obst_end = obst_start + obstacle_deltas """
|
339
|
-
|
340
|
-
def check_obstacles(pos, new_pos, obst_start, obst_end):
|
341
|
-
intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
|
342
|
-
rastersect = raster_crossing(pos, new_pos)
|
343
|
-
flag = jnp.logical_or(intersects, rastersect)
|
344
|
-
return jnp.where(flag, pos, new_pos)
|
345
|
-
|
346
|
-
""" pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
|
347
|
-
pos, new_pos, obst_start, obst_end
|
348
|
-
) """
|
349
|
-
|
336
|
+
clash = jax.vmap(raster_crossing)(pos, new_pos)
|
337
|
+
pos = jax.vmap(jnp.where)(clash, pos, new_pos)
|
338
|
+
# avoid going out of bounds
|
339
|
+
pos = jnp.maximum(jnp.minimum(pos, jnp.array([self.map_width, self.map_height])),jnp.zeros((2,)),)
|
340
|
+
|
350
341
|
# Multiple enemies can attack the same unit.
|
351
342
|
# We have `(health_diff, attacked_idx)` pairs.
|
352
343
|
# `jax.lax.scatter_add` aggregates these exactly
|
parabellum/map.py
CHANGED
@@ -14,6 +14,8 @@ import rasterio.transform
|
|
14
14
|
from typing import Optional, Tuple
|
15
15
|
from geopy.location import Location
|
16
16
|
from shapely.geometry import Point
|
17
|
+
import os
|
18
|
+
import pickle
|
17
19
|
|
18
20
|
# constants
|
19
21
|
geolocator = Nominatim(user_agent="parabellum")
|
@@ -36,16 +38,41 @@ def rasterize_geometry(gdf: gpd.GeoDataFrame, size: int) -> jnp.ndarray:
|
|
36
38
|
w, s, e, n = gdf.total_bounds
|
37
39
|
transform = rasterio.transform.from_bounds(w, s, e, n, size, size)
|
38
40
|
raster = features.rasterize(gdf.geometry, out_shape=(size, size), transform=transform)
|
39
|
-
return jnp.array(jnp.
|
41
|
+
return jnp.array(jnp.flip(raster, 0) ).astype(jnp.uint8)
|
40
42
|
|
41
|
-
|
43
|
+
# +
|
44
|
+
def _cache_path(place, size):
    """Return the cache file path for ``(place, size)`` under ``./cache``.

    The file name is a SHA-256 digest of ``repr((place, size))``. The built-in
    ``hash()`` is unsuitable here: ``str`` hashing is randomized per interpreter
    process (PYTHONHASHSEED), so a ``hash()``-based name changes on every run
    and the cache would never be hit across runs.
    """
    import hashlib

    digest = hashlib.sha256(repr((place, size)).encode("utf-8")).hexdigest()
    return os.path.join("./cache", digest + ".pk")


def get_from_cache(place, size):
    """Load a cached ``(mask, base)`` pair for ``(place, size)``.

    Returns:
        ``(mask, base)`` with ``base`` cast to int64, or ``(None, None)`` if no
        cache entry exists.
    """
    path = _cache_path(place, size)
    if not os.path.exists(path):
        return (None, None)
    # NOTE: pickle can execute arbitrary code on load. This assumes ./cache only
    # contains files written by save_in_cache; never point it at untrusted data.
    with open(path, "rb") as f:
        mask, base = pickle.load(f)
    return (mask, base.astype(jnp.int64))


def save_in_cache(place, size, mask, base):
    """Pickle ``(mask, base)`` into ``./cache`` under a key derived from ``(place, size)``."""
    path = _cache_path(place, size)
    # exist_ok avoids a check-then-create race on the cache directory.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump((mask, base), f)
|
59
|
+
|
60
|
+
def terrain_fn(place: str, size: int = 1000, with_cache: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns a rasterized map of buildings for a given location.

    Args:
        place: Place name, resolved with ``get_location``.
        size: Size argument passed to ``get_building_geometry``,
            ``rasterize_geometry`` and ``get_basemap``.
        with_cache: If true, try the on-disk cache first and store freshly
            computed results in it; if false, the cache is never touched.

    Returns:
        ``(mask, base)``: the building mask and the basemap. On a cache hit,
        ``base`` comes back as cast by ``get_from_cache``; otherwise it is
        ``get_basemap``'s output unchanged.
    """
    mask, base = get_from_cache(place, size) if with_cache else (None, None)
    if mask is not None:
        return mask, base

    point = get_location(place)
    mask = rasterize_geometry(get_building_geometry(point, size), size)
    base = get_basemap(place, size)
    if with_cache:
        save_in_cache(place, size, mask, base)
    return mask, base
|
48
72
|
|
73
|
+
|
74
|
+
# -
|
75
|
+
|
49
76
|
def get_basemap(place: str, size: int = 1000) -> jnp.ndarray:
|
50
77
|
"""Returns a basemap for a given place as a JAX array."""
|
51
78
|
point = get_location(place)
|
@@ -54,15 +81,20 @@ def get_basemap(place: str, size: int = 1000) -> jnp.ndarray:
|
|
54
81
|
# get the middle size x size square
|
55
82
|
basemap = basemap[(basemap.shape[0] - size) // 2:(basemap.shape[0] + size) // 2,
|
56
83
|
(basemap.shape[1] - size) // 2:(basemap.shape[1] + size) // 2]
|
57
|
-
return jnp.array(jnp.rot90(basemap, 2)).astype(jnp.uint8)
|
84
|
+
return basemap # jnp.array(jnp.rot90(basemap, 2)).astype(jnp.uint8)
|
58
85
|
|
59
86
|
|
60
87
|
if __name__ == "__main__":
    # Demo: fetch (or load from cache) one location and plot mask, basemap, overlay.
    place = "Cauvicourt, 14190, France"
    mask, base = terrain_fn(place, 500)
    # rasterize_geometry flips the raster along axis 0; flip it back for display.
    shown_mask = jnp.flip(mask, 0)

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(shown_mask)  # type: ignore
    ax[1].imshow(base)  # type: ignore
    # Overlay: the mask values double as per-pixel alpha (presumably 0/1).
    ax[2].imshow(base)  # type: ignore
    ax[2].imshow(shown_mask, alpha=shown_mask)  # type: ignore
    plt.show()
|
97
|
+
|
98
|
+
|
99
|
+
|
100
|
+
|
parabellum/vis.py
CHANGED
@@ -56,7 +56,6 @@ class Visualizer(SMAXVisualizer):
|
|
56
56
|
self.skin.scale = self.skin.size / env.map_width # assumes square map
|
57
57
|
self.env = env
|
58
58
|
|
59
|
-
|
60
59
|
def animate(self, save_fname: Optional[str] = "output/parabellum", view=None):
|
61
60
|
expanded_state_seq, expanded_action_seq = expand_fn(self.env, self.state_seq, self.action_seq)
|
62
61
|
state_seq_seq, action_seq_seq = unbatch_fn(expanded_state_seq, expanded_action_seq)
|
@@ -86,7 +85,8 @@ def init_frame(env, skin, image, state: pb.State, action: Array, idx: int) -> py
|
|
86
85
|
|
87
86
|
|
88
87
|
def transform_frame(env, skin, frame):
    """Convert a pygame surface into an ``(height, width, 3)`` RGB array.

    ``pygame.surfarray`` indexes pixels as ``[x, y]``; the axes are swapped to
    row-major ``[y, x]`` and the rows are then flipped vertically.
    ``env`` and ``skin`` are unused here.
    """
    # array3d copies the pixels. pixels3d would return a view that keeps the
    # surface locked while the array lives and aliases its buffer, so later
    # drawing on the same surface would change frames already returned.
    pixels = pygame.surfarray.array3d(frame)
    return np.flip(pixels.swapaxes(0, 1), 0)
|
91
91
|
|
92
92
|
|
@@ -136,7 +136,7 @@ def text_fn(text):
|
|
136
136
|
|
137
137
|
def image_fn(skin: Skin): # TODO:
|
138
138
|
"""Create an image for background (basemap or maskmap)"""
|
139
|
-
motif = cv2.resize(np.array(skin.maskmap.T), (skin.size, skin.size), interpolation=cv2.
|
139
|
+
motif = cv2.resize(np.array(skin.maskmap.T), (skin.size, skin.size), interpolation=cv2.INTER_NEAREST).astype(np.uint8)
|
140
140
|
motif = (motif > 0).astype(np.uint8)
|
141
141
|
image = np.zeros((skin.size, skin.size, 3), dtype=np.uint8) + skin.bg
|
142
142
|
image[motif == 1] = skin.fg
|
@@ -0,0 +1,10 @@
|
|
1
|
+
parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
|
2
|
+
parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
|
3
|
+
parabellum/env.py,sha256=H1VwxHj8KqbjBJ8b7NMxkAg3Q0qwVzGulrigc26Tzkc,16663
|
4
|
+
parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
|
5
|
+
parabellum/map.py,sha256=UwMqwySasX5oLw9v5YMJARAPwvQThLTRW36NpbwvBC8,3564
|
6
|
+
parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
|
7
|
+
parabellum/vis.py,sha256=q7_OIMjzt-7nBOojVVW7Wiiq9ojsjaltIsH6eOOxPKk,6116
|
8
|
+
parabellum-0.2.25.dist-info/METADATA,sha256=UCuRLYhSUxnebs5pRs1FaWK67PUJw0kgsJ2UaaxNM9Q,2671
|
9
|
+
parabellum-0.2.25.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
10
|
+
parabellum-0.2.25.dist-info/RECORD,,
|
@@ -1,10 +0,0 @@
|
|
1
|
-
parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
|
2
|
-
parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
|
3
|
-
parabellum/env.py,sha256=rzhsMiCRVBOza63XboPYTLr8MRO45HUIIYX8qqFJHmE,16726
|
4
|
-
parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
|
5
|
-
parabellum/map.py,sha256=EUcPe4Upu9MQzS8h15IVPGCaAyRPLSkmoLd5ZT-V4Pk,2599
|
6
|
-
parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
|
7
|
-
parabellum/vis.py,sha256=uXTnhJL23JLQHW9by-M4bF73dSVA5TIkpNdfo_Go2Ro,6045
|
8
|
-
parabellum-0.2.24.dist-info/METADATA,sha256=0TXKsb81R0YnMOnpTIrEfDUOJVs92jXbKGpILB2WDO4,2671
|
9
|
-
parabellum-0.2.24.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
10
|
-
parabellum-0.2.24.dist-info/RECORD,,
|
File without changes
|