parabellum 0.0.0__py3-none-any.whl → 0.0.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- parabellum/__init__.py +3 -3
- parabellum/env.py +51 -277
- parabellum/geo.py +213 -0
- parabellum/types.py +138 -0
- parabellum/utils.py +59 -0
- parabellum-0.0.73.dist-info/METADATA +46 -0
- parabellum-0.0.73.dist-info/RECORD +8 -0
- {parabellum-0.0.0.dist-info → parabellum-0.0.73.dist-info}/WHEEL +1 -1
- parabellum/.ipynb_checkpoints/__init__-checkpoint.py +0 -4
- parabellum/.ipynb_checkpoints/env-checkpoint.py +0 -296
- parabellum/.ipynb_checkpoints/vis-checkpoint.py +0 -230
- parabellum/vis.py +0 -230
- parabellum-0.0.0.dist-info/METADATA +0 -55
- parabellum-0.0.0.dist-info/RECORD +0 -9
parabellum/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from .
|
2
|
-
from .
|
1
|
+
from . import env, geo, types, utils
|
2
|
+
from .env import Env
|
3
3
|
|
4
|
-
__all__ = ["
|
4
|
+
__all__ = ["geo", "env", "types", "utils", "Env"]
|
parabellum/env.py
CHANGED
@@ -1,296 +1,70 @@
|
|
1
|
-
|
1
|
+
# %% env.py
|
2
|
+
# parabellum env
|
3
|
+
# by: Noah Syrkis
|
2
4
|
|
3
|
-
|
4
|
-
import jax
|
5
|
-
import numpy as np
|
6
|
-
from jax import random
|
7
|
-
from jax import jit
|
8
|
-
from flax.struct import dataclass
|
9
|
-
import chex
|
10
|
-
from jaxmarl.environments.smax.smax_env import State, SMAX
|
11
|
-
from typing import Tuple, Dict
|
5
|
+
# Imports
|
12
6
|
from functools import partial
|
7
|
+
from typing import Tuple
|
13
8
|
|
9
|
+
import jax.numpy as jnp
|
10
|
+
import jaxkd as jk
|
11
|
+
from jax import random
|
12
|
+
from jaxtyping import Array
|
14
13
|
|
15
|
-
|
16
|
-
class Scenario:
|
17
|
-
"""Parabellum scenario"""
|
18
|
-
|
19
|
-
obstacle_coords: chex.Array
|
20
|
-
obstacle_deltas: chex.Array
|
21
|
-
|
22
|
-
unit_types: chex.Array
|
23
|
-
num_allies: int
|
24
|
-
num_enemies: int
|
25
|
-
|
26
|
-
smacv2_position_generation: bool = False
|
27
|
-
smacv2_unit_type_generation: bool = False
|
28
|
-
|
29
|
-
|
30
|
-
# default scenario
|
31
|
-
scenarios = {
|
32
|
-
"default": Scenario(
|
33
|
-
jnp.array([[6, 10], [26, 10]]) * 8,
|
34
|
-
jnp.array([[0, 12], [0, 1]]) * 8,
|
35
|
-
jnp.zeros((19,), dtype=jnp.uint8),
|
36
|
-
9,
|
37
|
-
10,
|
38
|
-
)
|
39
|
-
}
|
40
|
-
|
41
|
-
|
42
|
-
class Parabellum(SMAX):
|
43
|
-
def __init__(
|
44
|
-
self,
|
45
|
-
scenario: Scenario = scenarios["default"],
|
46
|
-
unit_type_attack_blasts=jnp.array([0, 0, 0, 0, 0, 0]) + 8,
|
47
|
-
**kwargs,
|
48
|
-
):
|
49
|
-
super().__init__(scenario=scenario, **kwargs)
|
50
|
-
self.unit_type_attack_blasts = unit_type_attack_blasts
|
51
|
-
self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
|
52
|
-
self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)
|
53
|
-
self.max_steps = 200
|
54
|
-
# overwrite supers _world_step method
|
55
|
-
|
56
|
-
|
57
|
-
def _push_units_away(self, state: State, firmness: float = 1.0): # we do it inside the _world_step to allow more obstacles constraints
|
58
|
-
return state
|
59
|
-
|
60
|
-
def _our_push_units_away(self, pos, unit_types, firmness: float = 1.0): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
|
61
|
-
delta_matrix = pos[:, None] - pos[None, :]
|
62
|
-
dist_matrix = (
|
63
|
-
jnp.linalg.norm(delta_matrix, axis=-1)
|
64
|
-
+ jnp.identity(self.num_agents)
|
65
|
-
+ 1e-6
|
66
|
-
)
|
67
|
-
radius_matrix = (
|
68
|
-
self.unit_type_radiuses[unit_types][:, None]
|
69
|
-
+ self.unit_type_radiuses[unit_types][None, :]
|
70
|
-
)
|
71
|
-
overlap_term = jax.nn.relu(radius_matrix / dist_matrix - 1.0)
|
72
|
-
unit_positions = (
|
73
|
-
pos
|
74
|
-
+ firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
|
75
|
-
)
|
76
|
-
return unit_positions
|
77
|
-
|
78
|
-
@partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
|
79
|
-
def _world_step( # modified version of JaxMARL's SMAX _world_step
|
80
|
-
self,
|
81
|
-
key: chex.PRNGKey,
|
82
|
-
state: State,
|
83
|
-
actions: Tuple[chex.Array, chex.Array],
|
84
|
-
) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
|
85
|
-
|
86
|
-
@partial(jax.vmap, in_axes=(None, None, 0, 0))
|
87
|
-
def inter_fn(pos, new_pos, obs, obs_end):
|
88
|
-
d1 = jnp.cross(obs - pos, new_pos - pos)
|
89
|
-
d2 = jnp.cross(obs_end - pos, new_pos - pos)
|
90
|
-
d3 = jnp.cross(pos - obs, obs_end - obs)
|
91
|
-
d4 = jnp.cross(new_pos - obs, obs_end - obs)
|
92
|
-
return (d1 * d2 <= 0) & (d3 * d4 <= 0)
|
93
|
-
|
94
|
-
def update_position(idx, vec):
|
95
|
-
# Compute the movements slightly strangely.
|
96
|
-
# The velocities below are for diagonal directions
|
97
|
-
# because these are easier to encode as actions than the four
|
98
|
-
# diagonal directions. Then rotate the velocity 45
|
99
|
-
# degrees anticlockwise to compute the movement.
|
100
|
-
pos = state.unit_positions[idx]
|
101
|
-
new_pos = (
|
102
|
-
pos
|
103
|
-
+ vec
|
104
|
-
* self.unit_type_velocities[state.unit_types[idx]]
|
105
|
-
* self.time_per_step
|
106
|
-
)
|
107
|
-
# avoid going out of bounds
|
108
|
-
new_pos = jnp.maximum(
|
109
|
-
jnp.minimum(new_pos, jnp.array([self.map_width, self.map_height])),
|
110
|
-
jnp.zeros((2,)),
|
111
|
-
)
|
112
|
-
|
113
|
-
#######################################################################
|
114
|
-
############################################ avoid going into obstacles
|
115
|
-
obs = self.obstacle_coords
|
116
|
-
obs_end = obs + self.obstacle_deltas
|
117
|
-
inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
|
118
|
-
new_pos = jnp.where(inters, pos, new_pos)
|
119
|
-
|
120
|
-
#######################################################################
|
121
|
-
#######################################################################
|
122
|
-
|
123
|
-
return new_pos
|
124
|
-
|
125
|
-
#######################################################################
|
126
|
-
######################################### units close enough to get hit
|
127
|
-
|
128
|
-
def bystander_fn(attacked_idx):
|
129
|
-
idxs = jnp.zeros((self.num_agents,))
|
130
|
-
idxs *= (
|
131
|
-
jnp.linalg.norm(
|
132
|
-
state.unit_positions - state.unit_positions[attacked_idx], axis=-1
|
133
|
-
)
|
134
|
-
< self.unit_type_attack_blasts[state.unit_types[attacked_idx]]
|
135
|
-
)
|
136
|
-
return idxs
|
137
|
-
|
138
|
-
#######################################################################
|
139
|
-
#######################################################################
|
140
|
-
|
141
|
-
def update_agent_health(idx, action, key): # TODO: add attack blasts
|
142
|
-
# for team 1, their attack actions are labelled in
|
143
|
-
# reverse order because that is the order they are
|
144
|
-
# observed in
|
145
|
-
attacked_idx = jax.lax.cond(
|
146
|
-
idx < self.num_allies,
|
147
|
-
lambda: action + self.num_allies - self.num_movement_actions,
|
148
|
-
lambda: self.num_allies - 1 - (action - self.num_movement_actions),
|
149
|
-
)
|
150
|
-
# deal with no-op attack actions (i.e. agents that are moving instead)
|
151
|
-
attacked_idx = jax.lax.select(
|
152
|
-
action < self.num_movement_actions, idx, attacked_idx
|
153
|
-
)
|
154
|
-
|
155
|
-
attack_valid = (
|
156
|
-
(
|
157
|
-
jnp.linalg.norm(
|
158
|
-
state.unit_positions[idx] - state.unit_positions[attacked_idx]
|
159
|
-
)
|
160
|
-
< self.unit_type_attack_ranges[state.unit_types[idx]]
|
161
|
-
)
|
162
|
-
& state.unit_alive[idx]
|
163
|
-
& state.unit_alive[attacked_idx]
|
164
|
-
)
|
165
|
-
attack_valid = attack_valid & (idx != attacked_idx)
|
166
|
-
attack_valid = attack_valid & (state.unit_weapon_cooldowns[idx] <= 0.0)
|
167
|
-
health_diff = jax.lax.select(
|
168
|
-
attack_valid,
|
169
|
-
-self.unit_type_attacks[state.unit_types[idx]],
|
170
|
-
0.0,
|
171
|
-
)
|
172
|
-
# design choice based on the pysc2 randomness details.
|
173
|
-
# See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
|
174
|
-
|
175
|
-
#########################################################
|
176
|
-
############################### Add bystander health diff
|
14
|
+
from parabellum.types import Action, Config, Obs, State
|
177
15
|
|
178
|
-
bystander_idxs = bystander_fn(attacked_idx) # TODO: use
|
179
|
-
bystander_valid = (
|
180
|
-
jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
|
181
|
-
.astype(jnp.bool_)
|
182
|
-
.astype(jnp.float32)
|
183
|
-
)
|
184
|
-
bystander_health_diff = (
|
185
|
-
bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
|
186
|
-
)
|
187
16
|
|
188
|
-
|
189
|
-
|
17
|
+
# %% Dataclass
|
18
|
+
class Env:
|
19
|
+
def init(self, cfg: Config, rng: Array) -> Tuple[Obs, State]:
|
20
|
+
state = init_fn(cfg, rng) # without jit this takes forever
|
21
|
+
return obs_fn(cfg, state), state
|
190
22
|
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
cooldown = (
|
195
|
-
self.unit_type_weapon_cooldowns[state.unit_types[idx]]
|
196
|
-
+ cooldown_deviation
|
197
|
-
)
|
198
|
-
cooldown_diff = jax.lax.select(
|
199
|
-
attack_valid,
|
200
|
-
# subtract the current cooldown because we are
|
201
|
-
# going to add it back. This way we effectively
|
202
|
-
# set the new cooldown to `cooldown`
|
203
|
-
cooldown - state.unit_weapon_cooldowns[idx],
|
204
|
-
-self.time_per_step,
|
205
|
-
)
|
206
|
-
return (
|
207
|
-
health_diff,
|
208
|
-
attacked_idx,
|
209
|
-
cooldown_diff,
|
210
|
-
(bystander_health_diff, bystander_idxs),
|
211
|
-
)
|
23
|
+
def step(self, cfg: Config, rng: Array, state: State, action: Action) -> Tuple[Obs, State]:
|
24
|
+
state = step_fn(cfg, rng, state, action)
|
25
|
+
return obs_fn(cfg, state), state
|
212
26
|
|
213
|
-
def perform_agent_action(idx, action, key):
|
214
|
-
movement_action, attack_action = action
|
215
|
-
new_pos = update_position(idx, movement_action)
|
216
|
-
health_diff, attacked_idxes, cooldown_diff, (bystander) = (
|
217
|
-
update_agent_health(idx, attack_action, key)
|
218
|
-
)
|
219
27
|
|
220
|
-
|
28
|
+
# %% Functions
|
29
|
+
def init_fn(cfg: Config, rng: Array) -> State:
|
30
|
+
prob = jnp.ones((cfg.size, cfg.size)).at[cfg.map].set(0).flatten() # Set
|
31
|
+
flat = random.choice(rng, jnp.arange(prob.size), shape=(cfg.length,), p=prob, replace=True)
|
32
|
+
idxs = (flat // len(cfg.map), flat % len(cfg.map))
|
33
|
+
pos = jnp.float32(jnp.column_stack(idxs))
|
34
|
+
return State(pos=pos, hp=cfg.hp[cfg.types])
|
221
35
|
|
222
|
-
keys = jax.random.split(key, num=self.num_agents)
|
223
|
-
pos, (health_diff, attacked_idxes), cooldown_diff, bystander = jax.vmap(
|
224
|
-
perform_agent_action
|
225
|
-
)(jnp.arange(self.num_agents), actions, keys)
|
226
36
|
|
227
|
-
|
228
|
-
|
37
|
+
def obs_fn(cfg: Config, state: State) -> Obs: # return info about neighbors ---
|
38
|
+
idxs, dist = jk.extras.query_neighbors_pairwise(state.pos, state.pos, k=cfg.knn)
|
39
|
+
mask = dist < cfg.sight[cfg.types[idxs][:, 0]][..., None] # | (state.hp[idxs] > 0)
|
40
|
+
pos = (state.pos[idxs] - state.pos[:, None, ...]).at[:, 0, :].set(state.pos) * mask[..., None]
|
41
|
+
args = state.hp, cfg.types, cfg.teams, cfg.reach, cfg.sight, cfg.speed
|
42
|
+
hp, type, team, reach, sight, speed = map(lambda x: x[idxs] * mask, args)
|
43
|
+
return Obs(pos=pos, dist=dist, hp=hp, type=type, team=team, reach=reach, sight=sight, speed=speed, mask=mask)
|
229
44
|
|
230
|
-
# avoid going into obstacles after being pushed
|
231
|
-
obs = self.obstacle_coords
|
232
|
-
obs_end = obs + self.obstacle_deltas
|
233
|
-
|
234
|
-
def check_obstacles(pos, new_pos, obs, obs_end):
|
235
|
-
inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
|
236
|
-
return jnp.where(inters, pos, new_pos)
|
237
|
-
|
238
|
-
pos = jax.vmap(check_obstacles, in_axes=(0,0,None,None))(pos, new_pos, obs, obs_end)
|
239
|
-
|
240
|
-
# Multiple enemies can attack the same unit.
|
241
|
-
# We have `(health_diff, attacked_idx)` pairs.
|
242
|
-
# `jax.lax.scatter_add` aggregates these exactly
|
243
|
-
# in the way we want -- duplicate idxes will have their
|
244
|
-
# health differences added together. However, it is a
|
245
|
-
# super thin wrapper around the XLA scatter operation,
|
246
|
-
# which has this bonkers syntax and requires this dnums
|
247
|
-
# parameter. The usage here was inferred from a test:
|
248
|
-
# https://github.com/google/jax/blob/main/tests/lax_test.py#L2296
|
249
|
-
dnums = jax.lax.ScatterDimensionNumbers(
|
250
|
-
update_window_dims=(),
|
251
|
-
inserted_window_dims=(0,),
|
252
|
-
scatter_dims_to_operand_dims=(0,),
|
253
|
-
)
|
254
|
-
unit_health = jnp.maximum(
|
255
|
-
jax.lax.scatter_add(
|
256
|
-
state.unit_health,
|
257
|
-
jnp.expand_dims(attacked_idxes, 1),
|
258
|
-
health_diff,
|
259
|
-
dnums,
|
260
|
-
),
|
261
|
-
0.0,
|
262
|
-
)
|
263
45
|
|
264
|
-
|
265
|
-
|
46
|
+
def step_fn(cfg: Config, rng: Array, state: State, action: Action) -> State:
|
47
|
+
idx, norm = jk.extras.query_neighbors_pairwise(state.pos + action.pos, state.pos, k=2)
|
48
|
+
args = rng, cfg, state, action, idx, norm
|
49
|
+
return State(pos=partial(push_fn, cfg, rng, idx, norm)(move_fn(*args)), hp=blast_fn(*args)) # type: ignore
|
266
50
|
|
267
|
-
_, bystander_health_diff = bystander
|
268
|
-
unit_health -= bystander_health_diff.sum(axis=0) # might be axis=1
|
269
51
|
|
270
|
-
|
271
|
-
|
52
|
+
def move_fn(rng: Array, cfg: Config, state: State, action: Action, idx: Array, norm: Array) -> Array:
|
53
|
+
speed = cfg.speed[cfg.types][..., None] # max speed of a unit (step size, really)
|
54
|
+
pos = state.pos + action.pos.clip(-speed, speed) * action.move[..., None] # new poss
|
55
|
+
mask = ((pos < 0).any(axis=-1) | ((pos >= cfg.size).any(axis=-1)) | (cfg.map[*jnp.int32(pos).T] > 0))[..., None]
|
56
|
+
return jnp.where(mask, state.pos, pos) # compute new position
|
272
57
|
|
273
|
-
unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
|
274
|
-
state = state.replace(
|
275
|
-
unit_health=unit_health,
|
276
|
-
unit_positions=pos,
|
277
|
-
unit_weapon_cooldowns=unit_weapon_cooldowns,
|
278
|
-
)
|
279
|
-
return state
|
280
58
|
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
obs, state = env.reset(key)
|
285
|
-
state_seq = []
|
286
|
-
for step in range(100):
|
287
|
-
rng, key = random.split(rng)
|
288
|
-
key_act = random.split(key, len(env.agents))
|
289
|
-
actions = {
|
290
|
-
agent: jax.random.randint(key_act[i], (), 0, 5)
|
291
|
-
for i, agent in enumerate(env.agents)
|
292
|
-
}
|
293
|
-
_, state, _, _, _ = env.step(key, state, actions)
|
294
|
-
state_seq.append((obs, state, actions))
|
59
|
+
def blast_fn(rng: Array, cfg: Config, state: State, action: Action, idx: Array, norm: Array) -> Array:
|
60
|
+
dam = (cfg.dam[cfg.types] * action.cast)[..., None] * jnp.ones_like(idx)
|
61
|
+
return state.hp - jnp.zeros(cfg.length, dtype=jnp.int32).at[idx.flatten()].add(dam.flatten())
|
295
62
|
|
296
63
|
|
64
|
+
def push_fn(cfg: Config, rng: Array, idx: Array, norm: Array, pos: Array) -> Array:
|
65
|
+
return pos + random.normal(rng, pos.shape) * 0.1
|
66
|
+
# params need to be tweaked, and matched with unit size
|
67
|
+
pos_diff = pos[:, None, :] - pos[idx] # direction away from neighbors
|
68
|
+
mask = (norm < cfg.r[cfg.types][..., None]) & (norm > 0)
|
69
|
+
pos = pos + jnp.where(mask[..., None], pos_diff * cfg.force / (norm[..., None] + 1e-6), 0.0).sum(axis=1)
|
70
|
+
return pos + random.normal(rng, pos.shape) * 0.1
|
parabellum/geo.py
ADDED
@@ -0,0 +1,213 @@
|
|
1
|
+
# %% geo.py
|
2
|
+
# script for geospatial level generation
|
3
|
+
# by: Noah Syrkis
|
4
|
+
|
5
|
+
# %% Imports
|
6
|
+
from rasterio import features, transform
|
7
|
+
|
8
|
+
# from jax import tree
|
9
|
+
from geopy.geocoders import Nominatim
|
10
|
+
from geopy.distance import distance
|
11
|
+
import contextily as cx
|
12
|
+
import jax.numpy as jnp
|
13
|
+
import cartopy.crs as ccrs
|
14
|
+
from jaxtyping import Array
|
15
|
+
from shapely import box
|
16
|
+
import osmnx as ox
|
17
|
+
import geopandas as gpd
|
18
|
+
from collections import namedtuple
|
19
|
+
from typing import Tuple
|
20
|
+
import matplotlib.pyplot as plt
|
21
|
+
from cachier import cachier
|
22
|
+
# from jax.scipy.signal import convolve
|
23
|
+
# from parabellum.types import Terrain
|
24
|
+
|
25
|
+
# %% Types
|
26
|
+
Coords = Tuple[float, float]
|
27
|
+
BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
|
28
|
+
|
29
|
+
# %% Constants
|
30
|
+
provider = cx.providers.Stadia.StamenTerrain( # type: ignore
|
31
|
+
api_key="86d0d32b-d2fe-49af-8db8-f7751f58e83f"
|
32
|
+
)
|
33
|
+
provider["url"] = provider["url"] + "?api_key={api_key}"
|
34
|
+
tags = {
|
35
|
+
"building": True,
|
36
|
+
"water": True,
|
37
|
+
"highway": True,
|
38
|
+
"landuse": [
|
39
|
+
"grass",
|
40
|
+
"forest",
|
41
|
+
"flowerbed",
|
42
|
+
"greenfield",
|
43
|
+
"village_green",
|
44
|
+
"recreation_ground",
|
45
|
+
],
|
46
|
+
"leisure": "garden",
|
47
|
+
} # "road": True}
|
48
|
+
|
49
|
+
|
50
|
+
# %% Coordinate function
|
51
|
+
def get_coordinates(place: str) -> Coords:
|
52
|
+
geolocator = Nominatim(user_agent="parabellum")
|
53
|
+
point = geolocator.geocode(place)
|
54
|
+
return point.latitude, point.longitude # type: ignore
|
55
|
+
|
56
|
+
|
57
|
+
def get_bbox(place: str, buffer) -> BBox:
|
58
|
+
"""Get bounding box from place name in crs 4326."""
|
59
|
+
coords = get_coordinates(place)
|
60
|
+
north = distance(meters=buffer).destination(coords, bearing=0).latitude
|
61
|
+
south = distance(meters=buffer).destination(coords, bearing=180).latitude
|
62
|
+
east = distance(meters=buffer).destination(coords, bearing=90).longitude
|
63
|
+
west = distance(meters=buffer).destination(coords, bearing=270).longitude
|
64
|
+
return BBox(north, south, east, west) # type: ignore
|
65
|
+
|
66
|
+
|
67
|
+
def basemap_fn(bbox: BBox, gdf) -> Array:
|
68
|
+
fig, ax = plt.subplots(figsize=(20, 20), subplot_kw={"projection": ccrs.Mercator()})
|
69
|
+
gdf.plot(ax=ax, color="black", alpha=0, edgecolor="black") # type: ignore
|
70
|
+
cx.add_basemap(ax, crs=gdf.crs, source=provider, zoom="auto") # type: ignore
|
71
|
+
bbox = gdf.total_bounds
|
72
|
+
ax.set_extent([bbox[0], bbox[2], bbox[1], bbox[3]], crs=ccrs.Mercator()) # type: ignore
|
73
|
+
plt.axis("off")
|
74
|
+
plt.tight_layout(pad=0)
|
75
|
+
fig.canvas.draw()
|
76
|
+
image = jnp.array(fig.canvas.renderer._renderer) # type: ignore
|
77
|
+
plt.close(fig)
|
78
|
+
return image
|
79
|
+
|
80
|
+
|
81
|
+
@cachier()
|
82
|
+
def geography_fn(place, buffer):
|
83
|
+
bbox = get_bbox(place, buffer)
|
84
|
+
map_data = ox.features_from_bbox(bbox=bbox, tags=tags)
|
85
|
+
gdf = gpd.GeoDataFrame(map_data)
|
86
|
+
gdf = gdf.clip(box(bbox.west, bbox.south, bbox.east, bbox.north)).to_crs("EPSG:3857")
|
87
|
+
raster = raster_fn(gdf, shape=(buffer, buffer))
|
88
|
+
# basemap = jnp.rot90(basemap_fn(bbox, gdf), 3)
|
89
|
+
# kernel = jnp.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
|
90
|
+
trans = lambda x: jnp.bool(x) # jnp.rot90(x, 3) # noqa
|
91
|
+
terrain = trans(raster[0]) # Terrain(
|
92
|
+
# building=trans(raster[0]),
|
93
|
+
# water=trans(raster[1] - convolve(raster[1] * raster[2], kernel, mode="same") > 0),
|
94
|
+
# forest=trans(jnp.logical_or(raster[3], raster[4])),
|
95
|
+
# basemap=basemap,
|
96
|
+
# )
|
97
|
+
# terrain = tree.map(lambda x: x.astype(jnp.int16), terrain)
|
98
|
+
return terrain
|
99
|
+
|
100
|
+
|
101
|
+
# =======
|
102
|
+
# terrain = tps.Terrain(building=trans(raster[0] - convolve(raster[0]*raster[2], kernel, mode='same')>0),
|
103
|
+
# water=trans(raster[1] - convolve(raster[1]*raster[2], kernel, mode='same')>0),
|
104
|
+
# forest=trans(jnp.logical_or(raster[3], raster[4])),
|
105
|
+
# basemap=basemap)
|
106
|
+
# return terrain, gdf
|
107
|
+
# >>>>>>> aeb13033e57083cc512a60f8f60a3db47a65ac32
|
108
|
+
|
109
|
+
|
110
|
+
def raster_fn(gdf, shape) -> Array:
|
111
|
+
bbox = gdf.total_bounds
|
112
|
+
t = transform.from_bounds(*bbox, *shape) # type: ignore
|
113
|
+
raster = jnp.array([feature_fn(t, feature, gdf, shape) for feature in tags])
|
114
|
+
return raster
|
115
|
+
|
116
|
+
|
117
|
+
def feature_fn(t, feature, gdf, shape):
|
118
|
+
if feature not in gdf.columns:
|
119
|
+
return jnp.zeros(shape)
|
120
|
+
gdf = gdf[~gdf[feature].isna()]
|
121
|
+
raster = features.rasterize(gdf.geometry, out_shape=shape, transform=t, fill=0) # type: ignore
|
122
|
+
return raster
|
123
|
+
|
124
|
+
|
125
|
+
# %%
|
126
|
+
# def normalize(x):
|
127
|
+
# return (np.array(x) - m) / (M - m)
|
128
|
+
|
129
|
+
|
130
|
+
# def get_bridges(gdf):
|
131
|
+
# xmin, ymin, xmax, ymax = gdf.total_bounds
|
132
|
+
# m = np.array([xmin, ymin])
|
133
|
+
# M = np.array([xmax, ymax])
|
134
|
+
|
135
|
+
# bridges = {}
|
136
|
+
# for idx, bridge in gdf[gdf["bridge"] == "yes"].iterrows():
|
137
|
+
# if type(bridge["name"]) == str:
|
138
|
+
# bridges[idx[1]] = {
|
139
|
+
# "name": bridge["name"],
|
140
|
+
# "coords": normalize(
|
141
|
+
# [bridge.geometry.centroid.x, bridge.geometry.centroid.y]
|
142
|
+
# ),
|
143
|
+
# }
|
144
|
+
# return bridges
|
145
|
+
|
146
|
+
|
147
|
+
"""
|
148
|
+
# %%
|
149
|
+
if __name__ == "__main__":
|
150
|
+
place = "Thun, Switzerland"
|
151
|
+
<<<<<<< HEAD
|
152
|
+
terrain = geography_fn(place, 300)
|
153
|
+
|
154
|
+
=======
|
155
|
+
terrain, gdf = geography_fn(place, 300)
|
156
|
+
|
157
|
+
>>>>>>> aeb13033e57083cc512a60f8f60a3db47a65ac32
|
158
|
+
fig, axes = plt.subplots(1, 5, figsize=(20, 20))
|
159
|
+
axes[0].imshow(jnp.rot90(terrain.building), cmap="gray")
|
160
|
+
axes[1].imshow(jnp.rot90(terrain.water), cmap="gray")
|
161
|
+
axes[2].imshow(jnp.rot90(terrain.forest), cmap="gray")
|
162
|
+
axes[3].imshow(jnp.rot90(terrain.building + terrain.water + terrain.forest))
|
163
|
+
axes[4].imshow(jnp.rot90(terrain.basemap))
|
164
|
+
|
165
|
+
# %%
|
166
|
+
W, H, _ = terrain.basemap.shape
|
167
|
+
bridges = get_bridges(gdf)
|
168
|
+
|
169
|
+
# %%
|
170
|
+
print("Bridges:")
|
171
|
+
for bridge in bridges.values():
|
172
|
+
x, y = int(bridge["coords"][0]*300), int(bridge["coords"][1]*300)
|
173
|
+
print(bridge["name"], f"at ({x}, {y})")
|
174
|
+
|
175
|
+
# %%
|
176
|
+
plt.subplots(figsize=(7,7))
|
177
|
+
plt.imshow(jnp.rot90(terrain.basemap))
|
178
|
+
X = [b["coords"][0]*W for b in bridges.values()]
|
179
|
+
Y = [(1-b["coords"][1])*H for b in bridges.values()]
|
180
|
+
plt.scatter(X, Y)
|
181
|
+
for i in range(len(X)):
|
182
|
+
x,y = int(X[i]), int(Y[i])
|
183
|
+
plt.text(x, y, str((int(x/W*300), int((1-(y/H))*300))))
|
184
|
+
|
185
|
+
# %%
|
186
|
+
|
187
|
+
# %% [raw]
|
188
|
+
# fig, ax = plt.subplots(figsize=(10, 10))
|
189
|
+
# gdf.plot(ax=ax, color='lightgray') # Plot all features
|
190
|
+
# bridges.plot(ax=ax, color='red') # Highlight bridges in red
|
191
|
+
# plt.show()
|
192
|
+
|
193
|
+
# %%
|
194
|
+
|
195
|
+
"""
|
196
|
+
|
197
|
+
# BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
|
198
|
+
|
199
|
+
|
200
|
+
# def to_mercator(bbox: BBox) -> BBox:
|
201
|
+
# proj = ccrs.Mercator()
|
202
|
+
# west, south = proj.transform_point(bbox.west, bbox.south, ccrs.PlateCarree())
|
203
|
+
# east, north = proj.transform_point(bbox.east, bbox.north, ccrs.PlateCarree())
|
204
|
+
# return BBox(north=north, south=south, east=east, west=west)
|
205
|
+
#
|
206
|
+
#
|
207
|
+
# def to_platecarree(bbox: BBox) -> BBox:
|
208
|
+
# proj = ccrs.PlateCarree()
|
209
|
+
# west, south = proj.transform_point(bbox.west, bbox.south, ccrs.Mercator())
|
210
|
+
#
|
211
|
+
# east, north = proj.transform_point(bbox.east, bbox.north, ccrs.Mercator())
|
212
|
+
# return BBox(north=north, south=south, east=east, west=west)
|
213
|
+
#
|