parabellum 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum/env.py +50 -54
- parabellum/map.py +45 -13
- parabellum/vis.py +3 -3
- {parabellum-0.2.24.dist-info → parabellum-0.2.26.dist-info}/METADATA +1 -1
- parabellum-0.2.26.dist-info/RECORD +10 -0
- parabellum-0.2.24.dist-info/RECORD +0 -10
- {parabellum-0.2.24.dist-info → parabellum-0.2.26.dist-info}/WHEEL +0 -0
parabellum/env.py
CHANGED
@@ -19,6 +19,7 @@ class Scenario:
|
|
19
19
|
|
20
20
|
place: str
|
21
21
|
terrain_raster: jnp.ndarray
|
22
|
+
unit_starting_sectors: jnp.ndarray
|
22
23
|
unit_types: chex.Array
|
23
24
|
num_allies: int
|
24
25
|
num_enemies: int
|
@@ -45,6 +46,7 @@ scenarios = {
|
|
45
46
|
"default": Scenario(
|
46
47
|
"Identity Town",
|
47
48
|
jnp.eye(64, dtype=jnp.uint8),
|
49
|
+
jnp.array([[0, 0, 0.2, 0.2], [0.7,0.7,0.2,0.2]]),
|
48
50
|
jnp.zeros((19,), dtype=jnp.uint8),
|
49
51
|
9,
|
50
52
|
10,
|
@@ -52,11 +54,20 @@ scenarios = {
|
|
52
54
|
}
|
53
55
|
|
54
56
|
|
55
|
-
def make_scenario(place, terrain_raster,
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
57
|
+
def make_scenario(place, terrain_raster, unit_starting_sectors, allies_type, n_allies, enemies_type, n_enemies):
|
58
|
+
if type(allies_type) == int:
|
59
|
+
allies = [allies_type] * n_allies
|
60
|
+
else:
|
61
|
+
assert(len(allies_type) == n_allies)
|
62
|
+
allies = allies_type
|
63
|
+
|
64
|
+
if type(enemies_type) == int:
|
65
|
+
enemies = [enemies_type] * n_enemies
|
66
|
+
else:
|
67
|
+
assert(len(enemies_type) == n_enemies)
|
68
|
+
enemies = enemies_type
|
69
|
+
unit_types = jnp.array(allies + enemies, dtype=jnp.uint8)
|
70
|
+
return Scenario(place, terrain_raster, unit_starting_sectors, unit_types, n_allies, n_enemies)
|
60
71
|
|
61
72
|
|
62
73
|
def spawn_fn(pool, offset: jnp.ndarray, n: int, rng: jnp.ndarray):
|
@@ -81,29 +92,41 @@ def sector_fn(terrain: jnp.ndarray, sector_id: int):
|
|
81
92
|
return jnp.nonzero(sector), offset
|
82
93
|
|
83
94
|
|
95
|
+
def sector_fn(terrain: jnp.ndarray, sector: jnp.ndarray):
|
96
|
+
"""return sector slice of terrain"""
|
97
|
+
width, height = terrain.shape
|
98
|
+
coordx, coordy = int(sector[0] * width), int(sector[1] * height)
|
99
|
+
sector = terrain[coordy : coordy + int(sector[3] * height), coordx : coordx + int(sector[2] * width)] == 0
|
100
|
+
offset = jnp.array([coordx, coordy])
|
101
|
+
# sector is jnp.nonzero
|
102
|
+
return jnp.nonzero(sector.T), offset
|
103
|
+
|
104
|
+
|
84
105
|
class Environment(SMAX):
|
85
106
|
def __init__(self, scenario: Scenario, **kwargs):
|
86
107
|
map_height, map_width = scenario.terrain_raster.shape
|
87
108
|
args = dict(scenario=scenario, map_height=map_height, map_width=map_width)
|
88
109
|
super(Environment, self).__init__(**args, walls_cause_death=False, **kwargs)
|
89
110
|
self.terrain_raster = scenario.terrain_raster
|
111
|
+
self.unit_starting_sectors = scenario.unit_starting_sectors
|
90
112
|
# self.unit_type_names = ["tinker", "tailor", "soldier", "spy"]
|
91
113
|
# self.unit_type_health = jnp.array([100, 100, 100, 100], dtype=jnp.float32)
|
92
114
|
# self.unit_type_damage = jnp.array([10, 10, 10, 10], dtype=jnp.float32)
|
93
115
|
self.scenario = scenario
|
116
|
+
self.unit_type_velocities=jnp.array([3.15, 2.25, 4.13, 3.15, 4.13, 3.15])/2.5
|
94
117
|
self.unit_type_attack_blasts = jnp.zeros((3,), dtype=jnp.float32) # TODO: add
|
95
118
|
self.max_steps = 200
|
96
119
|
self._push_units_away = lambda state, firmness = 1: state # overwrite push units
|
97
|
-
self.
|
98
|
-
self.
|
120
|
+
self.team0_sector, self.team0_sector_offset = sector_fn(self.terrain_raster, self.unit_starting_sectors[0]) # sector_fn(self.terrain_raster, 0)
|
121
|
+
self.team1_sector, self.team1_sector_offset = sector_fn(self.terrain_raster, self.unit_starting_sectors[1]) # sector_fn(self.terrain_raster, 24)
|
99
122
|
|
100
123
|
|
101
124
|
@partial(jax.jit, static_argnums=(0,))
|
102
125
|
def reset(self, rng: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
|
103
126
|
"""Environment-specific reset."""
|
104
127
|
ally_key, enemy_key = jax.random.split(rng)
|
105
|
-
team_0_start = spawn_fn(self.
|
106
|
-
team_1_start = spawn_fn(self.
|
128
|
+
team_0_start = spawn_fn(self.team0_sector, self.team0_sector_offset, self.num_allies, ally_key)
|
129
|
+
team_1_start = spawn_fn(self.team1_sector, self.team1_sector_offset, self.num_enemies, enemy_key)
|
107
130
|
unit_positions = jnp.concatenate([team_0_start, team_1_start])
|
108
131
|
unit_teams = jnp.zeros((self.num_agents,))
|
109
132
|
unit_teams = unit_teams.at[self.num_allies :].set(1)
|
@@ -126,9 +149,14 @@ class Environment(SMAX):
|
|
126
149
|
state = self._push_units_away(state) # type: ignore
|
127
150
|
obs = self.get_obs(state)
|
128
151
|
world_state = self.get_world_state(state)
|
129
|
-
obs["world_state"] = jax.lax.stop_gradient(world_state)
|
152
|
+
# obs["world_state"] = jax.lax.stop_gradient(world_state)
|
130
153
|
return obs, state
|
131
154
|
|
155
|
+
def step_env(self, rng, state: State, action: Array):
|
156
|
+
obs, state, rewards, dones, infos = super().step_env(rng, state, action)
|
157
|
+
# delete world_state from obs
|
158
|
+
obs.pop("world_state")
|
159
|
+
return obs, state, rewards, dones, infos
|
132
160
|
|
133
161
|
def _our_push_units_away(
|
134
162
|
self, pos, unit_types, firmness: float = 1.0
|
@@ -167,14 +195,14 @@ class Environment(SMAX):
|
|
167
195
|
|
168
196
|
def raster_crossing(pos, new_pos):
|
169
197
|
pos, new_pos = pos.astype(jnp.int32), new_pos.astype(jnp.int32)
|
170
|
-
raster = self.terrain_raster
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
return jnp.any(
|
198
|
+
raster = jnp.copy(self.terrain_raster)
|
199
|
+
minimum = jnp.minimum(pos, new_pos)
|
200
|
+
maximum = jnp.maximum(pos, new_pos)
|
201
|
+
raster = jnp.where(jnp.arange(raster.shape[0]) >= minimum[0], raster, 0)
|
202
|
+
raster = jnp.where(jnp.arange(raster.shape[0]) <= maximum[0], raster, 0)
|
203
|
+
raster = jnp.where(jnp.arange(raster.shape[1]) >= minimum[1], raster.T, 0).T
|
204
|
+
raster = jnp.where(jnp.arange(raster.shape[1]) <= maximum[1], raster.T, 0).T
|
205
|
+
return jnp.any(raster)
|
178
206
|
|
179
207
|
def update_position(idx, vec):
|
180
208
|
# Compute the movements slightly strangely.
|
@@ -197,11 +225,7 @@ class Environment(SMAX):
|
|
197
225
|
|
198
226
|
#######################################################################
|
199
227
|
############################################ avoid going into obstacles
|
200
|
-
""" obs = self.obstacle_coords
|
201
|
-
obs_end = obs + self.obstacle_deltas
|
202
|
-
inters = jnp.any(intersect_fn(pos, new_pos, obs, obs_end)) """
|
203
228
|
clash = raster_crossing(pos, new_pos)
|
204
|
-
# flag = jnp.logical_or(inters, rastersects)
|
205
229
|
new_pos = jnp.where(clash, pos, new_pos)
|
206
230
|
|
207
231
|
#######################################################################
|
@@ -314,38 +338,10 @@ class Environment(SMAX):
|
|
314
338
|
|
315
339
|
# units push each other
|
316
340
|
new_pos = self._our_push_units_away(pos, state.unit_types)
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
[[0, 0], [0, 0], [self.map_width, 0], [0, self.map_height]]
|
322
|
-
)
|
323
|
-
bondaries_deltas = jnp.array(
|
324
|
-
[
|
325
|
-
[self.map_width, 0],
|
326
|
-
[0, self.map_height],
|
327
|
-
[0, self.map_height],
|
328
|
-
[self.map_width, 0],
|
329
|
-
]
|
330
|
-
)
|
331
|
-
""" obstacle_coords = jnp.concatenate(
|
332
|
-
[self.obstacle_coords, bondaries_coords]
|
333
|
-
) # add the map boundaries to the obstacles to avoid
|
334
|
-
obstacle_deltas = jnp.concatenate(
|
335
|
-
[self.obstacle_deltas, bondaries_deltas]
|
336
|
-
) # add the map boundaries to the obstacles to avoid
|
337
|
-
obst_start = obstacle_coords
|
338
|
-
obst_end = obst_start + obstacle_deltas """
|
339
|
-
|
340
|
-
def check_obstacles(pos, new_pos, obst_start, obst_end):
|
341
|
-
intersects = jnp.any(intersect_fn(pos, new_pos, obst_start, obst_end))
|
342
|
-
rastersect = raster_crossing(pos, new_pos)
|
343
|
-
flag = jnp.logical_or(intersects, rastersect)
|
344
|
-
return jnp.where(flag, pos, new_pos)
|
345
|
-
|
346
|
-
""" pos = jax.vmap(check_obstacles, in_axes=(0, 0, None, None))(
|
347
|
-
pos, new_pos, obst_start, obst_end
|
348
|
-
) """
|
341
|
+
clash = jax.vmap(raster_crossing)(pos, new_pos)
|
342
|
+
pos = jax.vmap(jnp.where)(clash, pos, new_pos)
|
343
|
+
# avoid going out of bounds
|
344
|
+
pos = jnp.maximum(jnp.minimum(pos, jnp.array([self.map_width, self.map_height])),jnp.zeros((2,)),)
|
349
345
|
|
350
346
|
# Multiple enemies can attack the same unit.
|
351
347
|
# We have `(health_diff, attacked_idx)` pairs.
|
parabellum/map.py
CHANGED
@@ -14,6 +14,8 @@ import rasterio.transform
|
|
14
14
|
from typing import Optional, Tuple
|
15
15
|
from geopy.location import Location
|
16
16
|
from shapely.geometry import Point
|
17
|
+
import os
|
18
|
+
import pickle
|
17
19
|
|
18
20
|
# constants
|
19
21
|
geolocator = Nominatim(user_agent="parabellum")
|
@@ -36,16 +38,41 @@ def rasterize_geometry(gdf: gpd.GeoDataFrame, size: int) -> jnp.ndarray:
|
|
36
38
|
w, s, e, n = gdf.total_bounds
|
37
39
|
transform = rasterio.transform.from_bounds(w, s, e, n, size, size)
|
38
40
|
raster = features.rasterize(gdf.geometry, out_shape=(size, size), transform=transform)
|
39
|
-
return jnp.array(jnp.
|
41
|
+
return jnp.array(jnp.flip(raster, 0) ).astype(jnp.uint8)
|
40
42
|
|
41
|
-
|
43
|
+
# +
|
44
|
+
def get_from_cache(place, size):
|
45
|
+
if os.path.exists("./cache"):
|
46
|
+
name = str(hash((place, size))) + ".pk"
|
47
|
+
if os.path.exists("./cache/" + name):
|
48
|
+
with open("./cache/" + name, "rb") as f:
|
49
|
+
(mask, base) = pickle.load(f)
|
50
|
+
return (mask, base.astype(jnp.int64))
|
51
|
+
return (None, None)
|
52
|
+
|
53
|
+
def save_in_cache(place, size, mask, base):
|
54
|
+
if not os.path.exists("./cache"):
|
55
|
+
os.makedirs("./cache")
|
56
|
+
name = str(hash((place, size))) + ".pk"
|
57
|
+
with open("./cache/" + name, "wb") as f:
|
58
|
+
pickle.dump((mask, base), f)
|
59
|
+
|
60
|
+
def terrain_fn(place: str, size: int = 1000, with_cache: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
42
61
|
"""Returns a rasterized map of buildings for a given location."""
|
43
|
-
|
44
|
-
|
45
|
-
mask
|
46
|
-
|
62
|
+
if with_cache:
|
63
|
+
mask, base = get_from_cache(place, size)
|
64
|
+
if not with_cache or mask is None:
|
65
|
+
point = get_location(place)
|
66
|
+
gdf = get_building_geometry(point, size)
|
67
|
+
mask = rasterize_geometry(gdf, size)
|
68
|
+
base = get_basemap(place, size)
|
69
|
+
if with_cache:
|
70
|
+
save_in_cache(place, size, mask, base)
|
47
71
|
return mask, base
|
48
72
|
|
73
|
+
|
74
|
+
# -
|
75
|
+
|
49
76
|
def get_basemap(place: str, size: int = 1000) -> jnp.ndarray:
|
50
77
|
"""Returns a basemap for a given place as a JAX array."""
|
51
78
|
point = get_location(place)
|
@@ -54,15 +81,20 @@ def get_basemap(place: str, size: int = 1000) -> jnp.ndarray:
|
|
54
81
|
# get the middle size x size square
|
55
82
|
basemap = basemap[(basemap.shape[0] - size) // 2:(basemap.shape[0] + size) // 2,
|
56
83
|
(basemap.shape[1] - size) // 2:(basemap.shape[1] + size) // 2]
|
57
|
-
return jnp.array(jnp.rot90(basemap, 2)).astype(jnp.uint8)
|
84
|
+
return basemap # jnp.array(jnp.rot90(basemap, 2)).astype(jnp.uint8)
|
58
85
|
|
59
86
|
|
60
87
|
if __name__ == "__main__":
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
ax[0].imshow(mask) # type: ignore
|
88
|
+
place = "Cauvicourt, 14190, France"
|
89
|
+
mask, base = terrain_fn(place, 500)
|
90
|
+
|
91
|
+
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
|
92
|
+
ax[0].imshow(jnp.flip(mask,0)) # type: ignore
|
67
93
|
ax[1].imshow(base) # type: ignore
|
94
|
+
ax[2].imshow(base) # type: ignore
|
95
|
+
ax[2].imshow(jnp.flip(mask,0), alpha=jnp.flip(mask,0)) # type: ignore
|
68
96
|
plt.show()
|
97
|
+
|
98
|
+
|
99
|
+
|
100
|
+
|
parabellum/vis.py
CHANGED
@@ -56,7 +56,6 @@ class Visualizer(SMAXVisualizer):
|
|
56
56
|
self.skin.scale = self.skin.size / env.map_width # assumes square map
|
57
57
|
self.env = env
|
58
58
|
|
59
|
-
|
60
59
|
def animate(self, save_fname: Optional[str] = "output/parabellum", view=None):
|
61
60
|
expanded_state_seq, expanded_action_seq = expand_fn(self.env, self.state_seq, self.action_seq)
|
62
61
|
state_seq_seq, action_seq_seq = unbatch_fn(expanded_state_seq, expanded_action_seq)
|
@@ -86,7 +85,8 @@ def init_frame(env, skin, image, state: pb.State, action: Array, idx: int) -> py
|
|
86
85
|
|
87
86
|
|
88
87
|
def transform_frame(env, skin, frame):
|
89
|
-
frame = np.rot90(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 2)
|
88
|
+
#frame = np.rot90(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 2)
|
89
|
+
frame = np.flip(pygame.surfarray.pixels3d(frame).swapaxes(0, 1), 0)
|
90
90
|
return frame
|
91
91
|
|
92
92
|
|
@@ -136,7 +136,7 @@ def text_fn(text):
|
|
136
136
|
|
137
137
|
def image_fn(skin: Skin): # TODO:
|
138
138
|
"""Create an image for background (basemap or maskmap)"""
|
139
|
-
motif = cv2.resize(np.array(skin.maskmap.T), (skin.size, skin.size), interpolation=cv2.
|
139
|
+
motif = cv2.resize(np.array(skin.maskmap.T), (skin.size, skin.size), interpolation=cv2.INTER_NEAREST).astype(np.uint8)
|
140
140
|
motif = (motif > 0).astype(np.uint8)
|
141
141
|
image = np.zeros((skin.size, skin.size, 3), dtype=np.uint8) + skin.bg
|
142
142
|
image[motif == 1] = skin.fg
|
@@ -0,0 +1,10 @@
|
|
1
|
+
parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
|
2
|
+
parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
|
3
|
+
parabellum/env.py,sha256=H1YGAtUYNJd8OHnZ3sOEXbag5L0WjtJHBGL8ymGPvoE,16898
|
4
|
+
parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
|
5
|
+
parabellum/map.py,sha256=UwMqwySasX5oLw9v5YMJARAPwvQThLTRW36NpbwvBC8,3564
|
6
|
+
parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
|
7
|
+
parabellum/vis.py,sha256=q7_OIMjzt-7nBOojVVW7Wiiq9ojsjaltIsH6eOOxPKk,6116
|
8
|
+
parabellum-0.2.26.dist-info/METADATA,sha256=AJwdmHRRPG2MosgufeWqQH7LsDRtTvFFGMh1azey9zA,2671
|
9
|
+
parabellum-0.2.26.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
10
|
+
parabellum-0.2.26.dist-info/RECORD,,
|
@@ -1,10 +0,0 @@
|
|
1
|
-
parabellum/__init__.py,sha256=-5cWXJkHnfH_CbhTEall8Wak8McAFXZHP1L8Fu7Uo5k,373
|
2
|
-
parabellum/aid.py,sha256=HWST27inTFXcp8b11izJF0U7N7DZnRTIS3n1Qfa-Ko4,106
|
3
|
-
parabellum/env.py,sha256=rzhsMiCRVBOza63XboPYTLr8MRO45HUIIYX8qqFJHmE,16726
|
4
|
-
parabellum/gun.py,sha256=nvsJdcZ2Qd6lbPlAgsUiaLhstTi1UdLQ8kOnbCenucY,2618
|
5
|
-
parabellum/map.py,sha256=EUcPe4Upu9MQzS8h15IVPGCaAyRPLSkmoLd5ZT-V4Pk,2599
|
6
|
-
parabellum/run.py,sha256=EO_F7VPwayatpSHrcbSahtinsV4QObhcx0jo-4KZO1E,3472
|
7
|
-
parabellum/vis.py,sha256=uXTnhJL23JLQHW9by-M4bF73dSVA5TIkpNdfo_Go2Ro,6045
|
8
|
-
parabellum-0.2.24.dist-info/METADATA,sha256=0TXKsb81R0YnMOnpTIrEfDUOJVs92jXbKGpILB2WDO4,2671
|
9
|
-
parabellum-0.2.24.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
10
|
-
parabellum-0.2.24.dist-info/RECORD,,
|
File without changes
|