parabellum 0.0.0__py3-none-any.whl → 0.0.73__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
parabellum/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from .env import Parabellum, Scenario
2
- from .vis import Visualizer
1
+ from . import env, geo, types, utils
2
+ from .env import Env
3
3
 
4
- __all__ = ["Parabellum", "Visualizer", "Scenario"]
4
+ __all__ = ["geo", "env", "types", "utils", "Env"]
parabellum/env.py CHANGED
@@ -1,296 +1,70 @@
1
- """Parabellum environment based on SMAX"""
1
+ # %% env.py
2
+ # parabellum env
3
+ # by: Noah Syrkis
2
4
 
3
- import jax.numpy as jnp
4
- import jax
5
- import numpy as np
6
- from jax import random
7
- from jax import jit
8
- from flax.struct import dataclass
9
- import chex
10
- from jaxmarl.environments.smax.smax_env import State, SMAX
11
- from typing import Tuple, Dict
5
+ # Imports
12
6
  from functools import partial
7
+ from typing import Tuple
13
8
 
9
+ import jax.numpy as jnp
10
+ import jaxkd as jk
11
+ from jax import random
12
+ from jaxtyping import Array
14
13
 
15
- @dataclass
16
- class Scenario:
17
- """Parabellum scenario"""
18
-
19
- obstacle_coords: chex.Array
20
- obstacle_deltas: chex.Array
21
-
22
- unit_types: chex.Array
23
- num_allies: int
24
- num_enemies: int
25
-
26
- smacv2_position_generation: bool = False
27
- smacv2_unit_type_generation: bool = False
28
-
29
-
30
- # default scenario
31
- scenarios = {
32
- "default": Scenario(
33
- jnp.array([[6, 10], [26, 10]]) * 8,
34
- jnp.array([[0, 12], [0, 1]]) * 8,
35
- jnp.zeros((19,), dtype=jnp.uint8),
36
- 9,
37
- 10,
38
- )
39
- }
40
-
41
-
42
- class Parabellum(SMAX):
43
- def __init__(
44
- self,
45
- scenario: Scenario = scenarios["default"],
46
- unit_type_attack_blasts=jnp.array([0, 0, 0, 0, 0, 0]) + 8,
47
- **kwargs,
48
- ):
49
- super().__init__(scenario=scenario, **kwargs)
50
- self.unit_type_attack_blasts = unit_type_attack_blasts
51
- self.obstacle_coords = scenario.obstacle_coords.astype(jnp.float32)
52
- self.obstacle_deltas = scenario.obstacle_deltas.astype(jnp.float32)
53
- self.max_steps = 200
54
- # overwrite supers _world_step method
55
-
56
-
57
- def _push_units_away(self, state: State, firmness: float = 1.0): # we do it inside the _world_step to allow more obstacles constraints
58
- return state
59
-
60
- def _our_push_units_away(self, pos, unit_types, firmness: float = 1.0): # copy of SMAX._push_units_away but used without state and called inside _world_step to allow more obstacles constraints
61
- delta_matrix = pos[:, None] - pos[None, :]
62
- dist_matrix = (
63
- jnp.linalg.norm(delta_matrix, axis=-1)
64
- + jnp.identity(self.num_agents)
65
- + 1e-6
66
- )
67
- radius_matrix = (
68
- self.unit_type_radiuses[unit_types][:, None]
69
- + self.unit_type_radiuses[unit_types][None, :]
70
- )
71
- overlap_term = jax.nn.relu(radius_matrix / dist_matrix - 1.0)
72
- unit_positions = (
73
- pos
74
- + firmness * jnp.sum(delta_matrix * overlap_term[:, :, None], axis=1) / 2
75
- )
76
- return unit_positions
77
-
78
- @partial(jax.jit, static_argnums=(0,)) # replace the _world_step method
79
- def _world_step( # modified version of JaxMARL's SMAX _world_step
80
- self,
81
- key: chex.PRNGKey,
82
- state: State,
83
- actions: Tuple[chex.Array, chex.Array],
84
- ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
85
-
86
- @partial(jax.vmap, in_axes=(None, None, 0, 0))
87
- def inter_fn(pos, new_pos, obs, obs_end):
88
- d1 = jnp.cross(obs - pos, new_pos - pos)
89
- d2 = jnp.cross(obs_end - pos, new_pos - pos)
90
- d3 = jnp.cross(pos - obs, obs_end - obs)
91
- d4 = jnp.cross(new_pos - obs, obs_end - obs)
92
- return (d1 * d2 <= 0) & (d3 * d4 <= 0)
93
-
94
- def update_position(idx, vec):
95
- # Compute the movements slightly strangely.
96
- # The velocities below are for diagonal directions
97
- # because these are easier to encode as actions than the four
98
- # diagonal directions. Then rotate the velocity 45
99
- # degrees anticlockwise to compute the movement.
100
- pos = state.unit_positions[idx]
101
- new_pos = (
102
- pos
103
- + vec
104
- * self.unit_type_velocities[state.unit_types[idx]]
105
- * self.time_per_step
106
- )
107
- # avoid going out of bounds
108
- new_pos = jnp.maximum(
109
- jnp.minimum(new_pos, jnp.array([self.map_width, self.map_height])),
110
- jnp.zeros((2,)),
111
- )
112
-
113
- #######################################################################
114
- ############################################ avoid going into obstacles
115
- obs = self.obstacle_coords
116
- obs_end = obs + self.obstacle_deltas
117
- inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
118
- new_pos = jnp.where(inters, pos, new_pos)
119
-
120
- #######################################################################
121
- #######################################################################
122
-
123
- return new_pos
124
-
125
- #######################################################################
126
- ######################################### units close enough to get hit
127
-
128
- def bystander_fn(attacked_idx):
129
- idxs = jnp.zeros((self.num_agents,))
130
- idxs *= (
131
- jnp.linalg.norm(
132
- state.unit_positions - state.unit_positions[attacked_idx], axis=-1
133
- )
134
- < self.unit_type_attack_blasts[state.unit_types[attacked_idx]]
135
- )
136
- return idxs
137
-
138
- #######################################################################
139
- #######################################################################
140
-
141
- def update_agent_health(idx, action, key): # TODO: add attack blasts
142
- # for team 1, their attack actions are labelled in
143
- # reverse order because that is the order they are
144
- # observed in
145
- attacked_idx = jax.lax.cond(
146
- idx < self.num_allies,
147
- lambda: action + self.num_allies - self.num_movement_actions,
148
- lambda: self.num_allies - 1 - (action - self.num_movement_actions),
149
- )
150
- # deal with no-op attack actions (i.e. agents that are moving instead)
151
- attacked_idx = jax.lax.select(
152
- action < self.num_movement_actions, idx, attacked_idx
153
- )
154
-
155
- attack_valid = (
156
- (
157
- jnp.linalg.norm(
158
- state.unit_positions[idx] - state.unit_positions[attacked_idx]
159
- )
160
- < self.unit_type_attack_ranges[state.unit_types[idx]]
161
- )
162
- & state.unit_alive[idx]
163
- & state.unit_alive[attacked_idx]
164
- )
165
- attack_valid = attack_valid & (idx != attacked_idx)
166
- attack_valid = attack_valid & (state.unit_weapon_cooldowns[idx] <= 0.0)
167
- health_diff = jax.lax.select(
168
- attack_valid,
169
- -self.unit_type_attacks[state.unit_types[idx]],
170
- 0.0,
171
- )
172
- # design choice based on the pysc2 randomness details.
173
- # See https://github.com/deepmind/pysc2/blob/master/docs/environment.md#determinism-and-randomness
174
-
175
- #########################################################
176
- ############################### Add bystander health diff
14
+ from parabellum.types import Action, Config, Obs, State
177
15
 
178
- bystander_idxs = bystander_fn(attacked_idx) # TODO: use
179
- bystander_valid = (
180
- jnp.where(attack_valid, bystander_idxs, jnp.zeros((self.num_agents,)))
181
- .astype(jnp.bool_)
182
- .astype(jnp.float32)
183
- )
184
- bystander_health_diff = (
185
- bystander_valid * -self.unit_type_attacks[state.unit_types[idx]]
186
- )
187
16
 
188
- #########################################################
189
- #########################################################
17
+ # %% Dataclass
18
+ class Env:
19
+ def init(self, cfg: Config, rng: Array) -> Tuple[Obs, State]:
20
+ state = init_fn(cfg, rng) # without jit this takes forever
21
+ return obs_fn(cfg, state), state
190
22
 
191
- cooldown_deviation = jax.random.uniform(
192
- key, minval=-self.time_per_step, maxval=2 * self.time_per_step
193
- )
194
- cooldown = (
195
- self.unit_type_weapon_cooldowns[state.unit_types[idx]]
196
- + cooldown_deviation
197
- )
198
- cooldown_diff = jax.lax.select(
199
- attack_valid,
200
- # subtract the current cooldown because we are
201
- # going to add it back. This way we effectively
202
- # set the new cooldown to `cooldown`
203
- cooldown - state.unit_weapon_cooldowns[idx],
204
- -self.time_per_step,
205
- )
206
- return (
207
- health_diff,
208
- attacked_idx,
209
- cooldown_diff,
210
- (bystander_health_diff, bystander_idxs),
211
- )
23
+ def step(self, cfg: Config, rng: Array, state: State, action: Action) -> Tuple[Obs, State]:
24
+ state = step_fn(cfg, rng, state, action)
25
+ return obs_fn(cfg, state), state
212
26
 
213
- def perform_agent_action(idx, action, key):
214
- movement_action, attack_action = action
215
- new_pos = update_position(idx, movement_action)
216
- health_diff, attacked_idxes, cooldown_diff, (bystander) = (
217
- update_agent_health(idx, attack_action, key)
218
- )
219
27
 
220
- return new_pos, (health_diff, attacked_idxes), cooldown_diff, bystander
28
+ # %% Functions
29
def init_fn(cfg: Config, rng: Array) -> State:
    """Sample initial unit positions on free map cells and set starting hp.

    Cells of ``cfg.map`` that are > 0 are treated as blocked and get zero
    sampling probability; every other cell is equally likely. ``cfg.length``
    cells are drawn with replacement, so several units may share a cell.

    Returns:
        ``State`` with float32 ``(row, col)`` positions of shape
        ``(cfg.length, 2)`` and per-unit ``hp`` taken from ``cfg.hp[cfg.types]``.
    """
    # jnp.where handles boolean and integer maps alike. `.at[mask]` with a
    # boolean mask needs concrete values (fails under jit), and with an integer
    # map it would index whole rows instead of masking cells.
    prob = jnp.where(cfg.map > 0, 0.0, 1.0).reshape(-1)
    flat = random.choice(rng, jnp.arange(prob.size), shape=(cfg.length,), p=prob, replace=True)
    # Undo the row-major flattening using the map's column count.
    row, col = jnp.divmod(flat, cfg.map.shape[1])
    pos = jnp.float32(jnp.column_stack((row, col)))
    return State(pos=pos, hp=cfg.hp[cfg.types])
221
35
 
222
- keys = jax.random.split(key, num=self.num_agents)
223
- pos, (health_diff, attacked_idxes), cooldown_diff, bystander = jax.vmap(
224
- perform_agent_action
225
- )(jnp.arange(self.num_agents), actions, keys)
226
36
 
227
- # checked that no unit passed through an obstacles
228
- new_pos = self._our_push_units_away(pos, state.unit_types)
37
def obs_fn(cfg: Config, state: State) -> Obs:
    """Build every unit's observation of its ``cfg.knn`` nearest neighbours.

    The code treats slot 0 of each neighbour list as the observing unit itself:
    that slot carries the unit's absolute position, while other slots hold
    offsets relative to it. A neighbour is visible (``mask``) when its distance
    is below the observer's sight radius; all gathered fields are zeroed where
    not visible.
    """
    idxs, dist = jk.extras.query_neighbors_pairwise(state.pos, state.pos, k=cfg.knn)
    # sight / reach / speed are per-type tables (they are indexed with
    # cfg.types elsewhere), so expand them to per-unit arrays before
    # gathering by neighbour (unit) index.
    unit_sight = cfg.sight[cfg.types]
    unit_reach = cfg.reach[cfg.types]
    unit_speed = cfg.speed[cfg.types]
    # TODO(review): consider also requiring state.hp[idxs] > 0 (dead units).
    mask = dist < unit_sight[idxs[:, 0]][..., None]
    rel = state.pos[idxs] - state.pos[:, None, ...]
    pos = rel.at[:, 0, :].set(state.pos) * mask[..., None]
    per_unit = (state.hp, cfg.types, cfg.teams, unit_reach, unit_sight, unit_speed)
    hp, unit_type, team, reach, sight, speed = (x[idxs] * mask for x in per_unit)
    return Obs(
        pos=pos,
        dist=dist,
        hp=hp,
        type=unit_type,
        team=team,
        reach=reach,
        sight=sight,
        speed=speed,
        mask=mask,
    )
229
44
 
230
- # avoid going into obstacles after being pushed
231
- obs = self.obstacle_coords
232
- obs_end = obs + self.obstacle_deltas
233
-
234
- def check_obstacles(pos, new_pos, obs, obs_end):
235
- inters = jnp.any(inter_fn(pos, new_pos, obs, obs_end))
236
- return jnp.where(inters, pos, new_pos)
237
-
238
- pos = jax.vmap(check_obstacles, in_axes=(0,0,None,None))(pos, new_pos, obs, obs_end)
239
-
240
- # Multiple enemies can attack the same unit.
241
- # We have `(health_diff, attacked_idx)` pairs.
242
- # `jax.lax.scatter_add` aggregates these exactly
243
- # in the way we want -- duplicate idxes will have their
244
- # health differences added together. However, it is a
245
- # super thin wrapper around the XLA scatter operation,
246
- # which has this bonkers syntax and requires this dnums
247
- # parameter. The usage here was inferred from a test:
248
- # https://github.com/google/jax/blob/main/tests/lax_test.py#L2296
249
- dnums = jax.lax.ScatterDimensionNumbers(
250
- update_window_dims=(),
251
- inserted_window_dims=(0,),
252
- scatter_dims_to_operand_dims=(0,),
253
- )
254
- unit_health = jnp.maximum(
255
- jax.lax.scatter_add(
256
- state.unit_health,
257
- jnp.expand_dims(attacked_idxes, 1),
258
- health_diff,
259
- dnums,
260
- ),
261
- 0.0,
262
- )
263
45
 
264
- #########################################################
265
- ############################ subtracting bystander health
46
def step_fn(cfg: Config, rng: Array, state: State, action: Action) -> State:
    """Advance the simulation by one tick.

    A 2-nearest-neighbour query is run between the units' target points
    (``state.pos + action.pos``) and their current positions. The result is
    handed to the movement, jitter/push and damage helpers. The new state keeps
    the pushed positions and the post-damage hp.
    """
    target = state.pos + action.pos
    idx, norm = jk.extras.query_neighbors_pairwise(target, state.pos, k=2)
    moved = move_fn(rng, cfg, state, action, idx, norm)
    new_pos = push_fn(cfg, rng, idx, norm, moved)
    new_hp = blast_fn(rng, cfg, state, action, idx, norm)
    return State(pos=new_pos, hp=new_hp)  # type: ignore
266
50
 
267
- _, bystander_health_diff = bystander
268
- unit_health -= bystander_health_diff.sum(axis=0) # might be axis=1
269
51
 
270
- #########################################################
271
- #########################################################
52
def move_fn(rng: Array, cfg: Config, state: State, action: Action, idx: Array, norm: Array) -> Array:
    """Apply each unit's requested displacement unless it lands somewhere illegal.

    ``action.pos`` is clipped per axis to the unit type's ``cfg.speed`` and only
    applied where ``action.move`` is set. A unit keeps its old position if the
    proposed one lies outside ``[0, cfg.size)`` on either axis or on a map cell
    whose value is > 0. ``rng``, ``idx`` and ``norm`` are unused here.
    """
    limit = cfg.speed[cfg.types][..., None]
    step = action.pos.clip(-limit, limit) * action.move[..., None]
    proposal = state.pos + step
    cell = jnp.int32(proposal)
    out_of_bounds = (proposal < 0).any(axis=-1) | (proposal >= cfg.size).any(axis=-1)
    blocked = cfg.map[cell[:, 0], cell[:, 1]] > 0
    reject = (out_of_bounds | blocked)[..., None]
    return jnp.where(reject, state.pos, proposal)
272
57
 
273
- unit_weapon_cooldowns = state.unit_weapon_cooldowns + cooldown_diff
274
- state = state.replace(
275
- unit_health=unit_health,
276
- unit_positions=pos,
277
- unit_weapon_cooldowns=unit_weapon_cooldowns,
278
- )
279
- return state
280
58
 
281
- if __name__ == "__main__":
282
- env = Parabellum(map_width=256, map_height=256)
283
- rng, key = random.split(random.PRNGKey(0))
284
- obs, state = env.reset(key)
285
- state_seq = []
286
- for step in range(100):
287
- rng, key = random.split(rng)
288
- key_act = random.split(key, len(env.agents))
289
- actions = {
290
- agent: jax.random.randint(key_act[i], (), 0, 5)
291
- for i, agent in enumerate(env.agents)
292
- }
293
- _, state, _, _, _ = env.step(key, state, actions)
294
- state_seq.append((obs, state, actions))
59
def blast_fn(rng: Array, cfg: Config, state: State, action: Action, idx: Array, norm: Array) -> Array:
    """Return unit hp after this tick's attacks.

    Each unit deals its type's ``cfg.dam`` scaled by ``action.cast`` to every unit
    listed in its row of ``idx``; hits on the same unit accumulate. Damage is
    summed into an int32 buffer of length ``cfg.length`` and subtracted from
    ``state.hp`` (no clamping at zero). ``rng`` and ``norm`` are unused.
    """
    per_attacker = cfg.dam[cfg.types] * action.cast
    per_hit = per_attacker[..., None] * jnp.ones_like(idx)
    taken = jnp.zeros(cfg.length, dtype=jnp.int32)
    taken = taken.at[idx.flatten()].add(per_hit.flatten())
    return state.hp - taken
295
62
 
296
63
 
64
def push_fn(cfg: Config, rng: Array, idx: Array, norm: Array, pos: Array, repel: bool = False) -> Array:
    """Jitter positions with Gaussian noise, optionally repelling close neighbours.

    Args:
        cfg: needs ``types``; with ``repel=True`` also ``r`` (per-type radius)
            and ``force``.
        rng: key for the noise (standard deviation 0.1).
        idx, norm: neighbour indices and distances, one row per unit.
        pos: positions to perturb.
        repel: when True, each unit is pushed away from every neighbour closer
            than its type's ``cfg.r`` (excluding distance 0) before the noise is
            added. Disabled by default because its parameters still need tuning.

    Returns:
        The perturbed positions, same shape as ``pos``.
    """
    if repel:
        # TODO: params need to be tweaked, and matched with unit size.
        pos_diff = pos[:, None, :] - pos[idx]  # direction away from each neighbour
        mask = (norm < cfg.r[cfg.types][..., None]) & (norm > 0)
        push = jnp.where(mask[..., None], pos_diff * cfg.force / (norm[..., None] + 1e-6), 0.0)
        pos = pos + push.sum(axis=1)
    return pos + random.normal(rng, pos.shape) * 0.1
parabellum/geo.py ADDED
@@ -0,0 +1,213 @@
1
+ # %% geo.py
2
+ # script for geospatial level generation
3
+ # by: Noah Syrkis
4
+
5
+ # %% Imports
6
+ from rasterio import features, transform
7
+
8
+ # from jax import tree
9
+ from geopy.geocoders import Nominatim
10
+ from geopy.distance import distance
11
+ import contextily as cx
12
+ import jax.numpy as jnp
13
+ import cartopy.crs as ccrs
14
+ from jaxtyping import Array
15
+ from shapely import box
16
+ import osmnx as ox
17
+ import geopandas as gpd
18
+ from collections import namedtuple
19
+ from typing import Tuple
20
+ import matplotlib.pyplot as plt
21
+ from cachier import cachier
22
+ # from jax.scipy.signal import convolve
23
+ # from parabellum.types import Terrain
24
+
25
+ # %% Types
26
Coords = Tuple[float, float]
BBox = namedtuple("BBox", ["north", "south", "east", "west"])  # type: ignore

# %% Constants
import os  # noqa: E402  (kept next to its only use: the API-key lookup below)

# Stadia Maps terrain tiles used by basemap_fn. The API key is read from the
# STADIA_API_KEY environment variable rather than committed to source (a key
# that has been published should be revoked). If the variable is unset the key
# is empty, and tile requests will presumably be rejected by the provider.
provider = cx.providers.Stadia.StamenTerrain(  # type: ignore
    api_key=os.environ.get("STADIA_API_KEY", "")
)
provider["url"] = provider["url"] + "?api_key={api_key}"

# OSM tags queried by geography_fn. raster_fn produces one raster channel per
# key, in this dict's order (building, water, highway, landuse, leisure).
tags = {
    "building": True,
    "water": True,
    "highway": True,
    "landuse": [
        "grass",
        "forest",
        "flowerbed",
        "greenfield",
        "village_green",
        "recreation_ground",
    ],
    "leisure": "garden",
}  # "road": True}
48
+
49
+
50
+ # %% Coordinate function
51
def get_coordinates(place: str) -> Coords:
    """Geocode *place* with Nominatim and return its ``(latitude, longitude)``.

    Raises:
        ValueError: if Nominatim finds no match for *place*.
    """
    geolocator = Nominatim(user_agent="parabellum")
    point = geolocator.geocode(place)
    if point is None:  # geocode() returns None when nothing matches
        raise ValueError(f"Could not geocode place: {place!r}")
    return point.latitude, point.longitude
55
+
56
+
57
def get_bbox(place: str, buffer) -> BBox:
    """Return the lat/lon box (CRS 4326) extending *buffer* metres from *place*.

    North/south are the latitudes reached by walking *buffer* metres due
    north/south of the geocoded point; east/west are the longitudes reached
    walking due east/west.
    """
    center = get_coordinates(place)
    reach = distance(meters=buffer)
    edges = {
        side: reach.destination(center, bearing=bearing)
        for side, bearing in (("north", 0), ("south", 180), ("east", 90), ("west", 270))
    }
    return BBox(
        north=edges["north"].latitude,
        south=edges["south"].latitude,
        east=edges["east"].longitude,
        west=edges["west"].longitude,
    )  # type: ignore
65
+
66
+
67
def basemap_fn(bbox: BBox, gdf) -> Array:
    """Render the terrain basemap behind *gdf* and return the figure as an image array.

    The map extent comes from ``gdf.total_bounds``; *bbox* is accepted for API
    compatibility but is not used. The result is the figure's RGBA pixel buffer,
    presumably shaped (height, width, 4) with dtype uint8.
    """
    import numpy as np  # local: only needed to wrap the canvas pixel buffer

    fig, ax = plt.subplots(figsize=(20, 20), subplot_kw={"projection": ccrs.Mercator()})
    # alpha=0 draws the geometries invisibly (presumably just to set up the axes).
    gdf.plot(ax=ax, color="black", alpha=0, edgecolor="black")  # type: ignore
    cx.add_basemap(ax, crs=gdf.crs, source=provider, zoom="auto")  # type: ignore
    # Separate names so the `bbox` argument is not silently shadowed.
    xmin, ymin, xmax, ymax = gdf.total_bounds
    ax.set_extent([xmin, xmax, ymin, ymax], crs=ccrs.Mercator())  # type: ignore
    ax.axis("off")
    fig.tight_layout(pad=0)
    fig.canvas.draw()
    # buffer_rgba() is the public Agg accessor for the rendered pixels, unlike
    # the private renderer._renderer attribute.
    image = jnp.array(np.asarray(fig.canvas.buffer_rgba()))
    plt.close(fig)
    return image
79
+
80
+
81
@cachier()
def geography_fn(place, buffer):
    """Fetch OSM features around *place* and rasterise them into a boolean mask.

    The box extending *buffer* metres in each direction from the geocoded place
    is queried for ``tags``. The features are clipped to that box, reprojected
    to EPSG:3857 and rasterised onto a ``(buffer, buffer)`` grid. Only channel 0
    (the ``building`` tag) is returned, cast to bool. Results are memoised by
    ``cachier``.
    """
    bbox = get_bbox(place, buffer)
    # NOTE(review): osmnx 2.x expects bbox=(left, bottom, right, top), whereas
    # BBox is (north, south, east, west) — the osmnx 1.x order. Confirm the
    # pinned osmnx version.
    features = ox.features_from_bbox(bbox=bbox, tags=tags)
    clip_box = box(bbox.west, bbox.south, bbox.east, bbox.north)
    gdf = gpd.GeoDataFrame(features).clip(clip_box).to_crs("EPSG:3857")
    raster = raster_fn(gdf, shape=(buffer, buffer))
    # Water / forest channels and a rendered basemap are computed or available
    # but not yet returned (a richer Terrain structure was sketched for them).
    return raster[0].astype(jnp.bool_)
99
+
100
+
101
+ # =======
102
+ # terrain = tps.Terrain(building=trans(raster[0] - convolve(raster[0]*raster[2], kernel, mode='same')>0),
103
+ # water=trans(raster[1] - convolve(raster[1]*raster[2], kernel, mode='same')>0),
104
+ # forest=trans(jnp.logical_or(raster[3], raster[4])),
105
+ # basemap=basemap)
106
+ # return terrain, gdf
107
+ # >>>>>>> aeb13033e57083cc512a60f8f60a3db47a65ac32
108
+
109
+
110
def raster_fn(gdf, shape) -> Array:
    """Rasterise every key of ``tags`` over the extent of *gdf*.

    Args:
        gdf: GeoDataFrame whose ``total_bounds`` define the raster extent.
        shape: ``(rows, cols)`` of the output grid.

    Returns:
        Array of shape ``(len(tags), rows, cols)``, one channel per tag key in
        ``tags`` order.
    """
    west, south, east, north = gdf.total_bounds
    rows, cols = shape
    # from_bounds takes (width, height) = (cols, rows), while rasterize's
    # out_shape is (rows, cols); the two must be swapped for non-square grids.
    t = transform.from_bounds(west, south, east, north, cols, rows)  # type: ignore
    return jnp.array([feature_fn(t, feature, gdf, shape) for feature in tags])
115
+
116
+
117
def feature_fn(t, feature, gdf, shape):
    """Rasterise the rows of *gdf* that have a value in the *feature* column.

    Cells covered by a matching geometry get rasterize's default burn value (1);
    all others get ``fill=0``. An all-zero grid of *shape* is returned when the
    column is missing or has no non-null rows.
    """
    if feature not in gdf.columns:
        return jnp.zeros(shape)
    subset = gdf[~gdf[feature].isna()]
    if subset.empty:
        # Nothing to burn; rasterize may reject an empty geometry sequence.
        return jnp.zeros(shape)
    return features.rasterize(subset.geometry, out_shape=shape, transform=t, fill=0)  # type: ignore
123
+
124
+
125
+ # %%
126
+ # def normalize(x):
127
+ # return (np.array(x) - m) / (M - m)
128
+
129
+
130
+ # def get_bridges(gdf):
131
+ # xmin, ymin, xmax, ymax = gdf.total_bounds
132
+ # m = np.array([xmin, ymin])
133
+ # M = np.array([xmax, ymax])
134
+
135
+ # bridges = {}
136
+ # for idx, bridge in gdf[gdf["bridge"] == "yes"].iterrows():
137
+ # if type(bridge["name"]) == str:
138
+ # bridges[idx[1]] = {
139
+ # "name": bridge["name"],
140
+ # "coords": normalize(
141
+ # [bridge.geometry.centroid.x, bridge.geometry.centroid.y]
142
+ # ),
143
+ # }
144
+ # return bridges
145
+
146
+
147
+ """
148
+ # %%
149
+ if __name__ == "__main__":
150
+ place = "Thun, Switzerland"
151
+ <<<<<<< HEAD
152
+ terrain = geography_fn(place, 300)
153
+
154
+ =======
155
+ terrain, gdf = geography_fn(place, 300)
156
+
157
+ >>>>>>> aeb13033e57083cc512a60f8f60a3db47a65ac32
158
+ fig, axes = plt.subplots(1, 5, figsize=(20, 20))
159
+ axes[0].imshow(jnp.rot90(terrain.building), cmap="gray")
160
+ axes[1].imshow(jnp.rot90(terrain.water), cmap="gray")
161
+ axes[2].imshow(jnp.rot90(terrain.forest), cmap="gray")
162
+ axes[3].imshow(jnp.rot90(terrain.building + terrain.water + terrain.forest))
163
+ axes[4].imshow(jnp.rot90(terrain.basemap))
164
+
165
+ # %%
166
+ W, H, _ = terrain.basemap.shape
167
+ bridges = get_bridges(gdf)
168
+
169
+ # %%
170
+ print("Bridges:")
171
+ for bridge in bridges.values():
172
+ x, y = int(bridge["coords"][0]*300), int(bridge["coords"][1]*300)
173
+ print(bridge["name"], f"at ({x}, {y})")
174
+
175
+ # %%
176
+ plt.subplots(figsize=(7,7))
177
+ plt.imshow(jnp.rot90(terrain.basemap))
178
+ X = [b["coords"][0]*W for b in bridges.values()]
179
+ Y = [(1-b["coords"][1])*H for b in bridges.values()]
180
+ plt.scatter(X, Y)
181
+ for i in range(len(X)):
182
+ x,y = int(X[i]), int(Y[i])
183
+ plt.text(x, y, str((int(x/W*300), int((1-(y/H))*300))))
184
+
185
+ # %%
186
+
187
+ # %% [raw]
188
+ # fig, ax = plt.subplots(figsize=(10, 10))
189
+ # gdf.plot(ax=ax, color='lightgray') # Plot all features
190
+ # bridges.plot(ax=ax, color='red') # Highlight bridges in red
191
+ # plt.show()
192
+
193
+ # %%
194
+
195
+ """
196
+
197
+ # BBox = namedtuple("BBox", ["north", "south", "east", "west"]) # type: ignore
198
+
199
+
200
+ # def to_mercator(bbox: BBox) -> BBox:
201
+ # proj = ccrs.Mercator()
202
+ # west, south = proj.transform_point(bbox.west, bbox.south, ccrs.PlateCarree())
203
+ # east, north = proj.transform_point(bbox.east, bbox.north, ccrs.PlateCarree())
204
+ # return BBox(north=north, south=south, east=east, west=west)
205
+ #
206
+ #
207
+ # def to_platecarree(bbox: BBox) -> BBox:
208
+ # proj = ccrs.PlateCarree()
209
+ # west, south = proj.transform_point(bbox.west, bbox.south, ccrs.Mercator())
210
+ #
211
+ # east, north = proj.transform_point(bbox.east, bbox.north, ccrs.Mercator())
212
+ # return BBox(north=north, south=south, east=east, west=west)
213
+ #