ssb-sgis 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sgis/maps/map.py CHANGED
@@ -83,8 +83,7 @@ class Map:
83
83
  scheme: str = DEFAULT_SCHEME,
84
84
  **kwargs,
85
85
  ):
86
- if not all(isinstance(gdf, GeoDataFrame) for gdf in gdfs):
87
- gdfs, column = self._separate_args(gdfs, column)
86
+ gdfs, column, kwargs = self._separate_args(gdfs, column, kwargs)
88
87
 
89
88
  self._column = column
90
89
  self.bins = bins
@@ -139,10 +138,10 @@ class Map:
139
138
  self.show.append(show)
140
139
  self.labels = new_labels
141
140
 
142
- if self.show:
141
+ if len(self._gdfs):
143
142
  last_show = self.show[-1]
144
143
  else:
145
- last_show = True
144
+ last_show = show
146
145
 
147
146
  # pop all geometry-like items from kwargs into self._gdfs
148
147
  self.kwargs = {}
@@ -296,9 +295,17 @@ class Map:
296
295
  def _separate_args(
297
296
  args: tuple,
298
297
  column: str | None,
298
+ kwargs: dict,
299
299
  ) -> tuple[tuple[GeoDataFrame], str]:
300
300
  """Separate GeoDataFrames from string (column argument)."""
301
301
 
302
+ def as_dict(obj):
303
+ if hasattr(obj, "__dict__"):
304
+ return obj.__dict__
305
+ elif isinstance(obj, dict):
306
+ return obj
307
+ raise TypeError
308
+
302
309
  gdfs: tuple[GeoDataFrame] = ()
303
310
  for arg in args:
304
311
  if isinstance(arg, str):
@@ -310,8 +317,27 @@ class Map:
310
317
  )
311
318
  elif isinstance(arg, (GeoDataFrame, GeoSeries, Geometry)):
312
319
  gdfs = gdfs + (arg,)
313
-
314
- return gdfs, column
320
+ elif isinstance(arg, dict) or hasattr(arg, "__dict__"):
321
+ # add dicts or classes with GeoDataFrames to kwargs
322
+ more_gdfs = {}
323
+ for key, value in as_dict(arg).items():
324
+ if isinstance(value, (GeoDataFrame, GeoSeries, Geometry)):
325
+ more_gdfs[key] = value
326
+ elif isinstance(value, dict) or hasattr(value, "__dict__"):
327
+ try:
328
+ # same as above, one level down
329
+ more_gdfs |= {
330
+ k: v
331
+ for k, v in value.items()
332
+ if isinstance(v, (GeoDataFrame, GeoSeries, Geometry))
333
+ }
334
+ except Exception:
335
+ # no need to raise here
336
+ pass
337
+
338
+ kwargs |= more_gdfs
339
+
340
+ return gdfs, column, kwargs
315
341
 
316
342
  def _prepare_continous_map(self):
317
343
  """Create bins if not already done and adjust k if needed."""
sgis/maps/maps.py CHANGED
@@ -56,7 +56,7 @@ def _get_location_mask(kwargs: dict, gdfs) -> tuple[GeoDataFrame | None, dict]:
56
56
 
57
57
 
58
58
  def explore(
59
- *gdfs: GeoDataFrame,
59
+ *gdfs: GeoDataFrame | dict[str, GeoDataFrame],
60
60
  column: str | None = None,
61
61
  center: Any | None = None,
62
62
  labels: tuple[str] | None = None,
@@ -121,6 +121,8 @@ def explore(
121
121
  >>> explore(roads, points, column="meters", cmap="plasma", max_zoom=60)
122
122
  """
123
123
 
124
+ gdfs, column, kwargs = Map._separate_args(gdfs, column, kwargs)
125
+
124
126
  loc_mask, kwargs = _get_location_mask(kwargs | {"size": size}, gdfs)
125
127
 
126
128
  kwargs.pop("size", None)
@@ -145,7 +147,11 @@ def explore(
145
147
  elif isinstance(center, GeoDataFrame):
146
148
  mask = center
147
149
  else:
148
- mask = to_gdf_func(center, crs=gdfs[0].crs)
150
+ try:
151
+ mask = to_gdf_func(center, crs=gdfs[0].crs)
152
+ except IndexError:
153
+ df = [x for x in kwargs.values() if hasattr(x, "crs")][0]
154
+ mask = to_gdf_func(center, crs=df.crs)
149
155
 
150
156
  if get_geom_type(mask) in ["point", "line"]:
151
157
  mask = mask.buffer(size)
@@ -176,7 +182,7 @@ def explore(
176
182
  if not kwargs.pop("explore", True):
177
183
  return qtm(m._gdf, column=m.column, cmap=m._cmap, k=m.k)
178
184
 
179
- m.explore()
185
+ return m.explore()
180
186
 
181
187
 
182
188
  def samplemap(
@@ -251,20 +257,11 @@ def samplemap(
251
257
  if gdfs and isinstance(gdfs[-1], (float, int)):
252
258
  *gdfs, size = gdfs
253
259
 
260
+ gdfs, column, kwargs = Map._separate_args(gdfs, column, kwargs)
261
+
254
262
  mask, kwargs = _get_location_mask(kwargs | {"size": size}, gdfs)
255
263
  kwargs.pop("size")
256
264
 
257
- if mask is not None:
258
- gdfs, column = Explore._separate_args(gdfs, column)
259
- gdfs, kwargs = _prepare_clipmap(
260
- *gdfs,
261
- mask=mask,
262
- labels=labels,
263
- **kwargs,
264
- )
265
- if not gdfs:
266
- return
267
-
268
265
  if explore:
269
266
  m = Explore(
270
267
  *gdfs,
@@ -277,6 +274,12 @@ def samplemap(
277
274
  )
278
275
  if m.gdfs is None:
279
276
  return
277
+ if mask is not None:
278
+ m._gdfs = [gdf.clip(mask) for gdf in m._gdfs]
279
+ m._gdf = m._gdf.clip(mask)
280
+ m._nan_idx = m._gdf[m._column].isna()
281
+ m._get_unique_values()
282
+
280
283
  m.samplemap(size, sample_from_first=sample_from_first)
281
284
 
282
285
  else:
@@ -311,38 +314,6 @@ def samplemap(
311
314
  qtm(m._gdf, column=m.column, cmap=m._cmap, k=m.k)
312
315
 
313
316
 
314
- def _prepare_clipmap(*gdfs, mask, labels, **kwargs):
315
- if mask is None:
316
- mask, kwargs = _get_location_mask(kwargs, gdfs)
317
- if mask is None and len(gdfs) > 1:
318
- *gdfs, mask = gdfs
319
- elif mask is None:
320
- raise ValueError("Must speficy mask.")
321
-
322
- # storing object names in dict here, since the names disappear after clip
323
- if not labels:
324
- namedict = make_namedict(gdfs)
325
- kwargs["namedict"] = namedict
326
-
327
- clipped: tuple[GeoDataFrame] = ()
328
-
329
- if mask is not None:
330
- for gdf in gdfs:
331
- clipped_ = gdf.clip(mask)
332
- clipped = clipped + (clipped_,)
333
-
334
- else:
335
- for gdf in gdfs[:-1]:
336
- clipped_ = gdf.clip(gdfs[-1])
337
- clipped = clipped + (clipped_,)
338
-
339
- if not any(len(gdf) for gdf in clipped):
340
- warnings.warn("None of the GeoDataFrames are within the mask extent.")
341
- return None, None
342
-
343
- return clipped, kwargs
344
-
345
-
346
317
  def clipmap(
347
318
  *gdfs: GeoDataFrame,
348
319
  column: str | None = None,
@@ -391,23 +362,18 @@ def clipmap(
391
362
  samplemap: same functionality, but shows only a random area of a given size.
392
363
  """
393
364
 
394
- gdfs, column = Explore._separate_args(gdfs, column)
365
+ gdfs, column, kwargs = Map._separate_args(gdfs, column, kwargs)
395
366
 
396
- clipped, kwargs = _prepare_clipmap(
397
- *gdfs,
398
- mask=mask,
399
- labels=labels,
400
- **kwargs,
401
- )
402
- if not clipped:
403
- return
367
+ if mask is None and len(gdfs) > 1:
368
+ mask = gdfs[-1]
369
+ gdfs = gdfs[:-1]
404
370
 
405
371
  center = kwargs.pop("center", None)
406
372
  size = kwargs.pop("size", None)
407
373
 
408
374
  if explore:
409
375
  m = Explore(
410
- *clipped,
376
+ *gdfs,
411
377
  column=column,
412
378
  labels=labels,
413
379
  browser=browser,
@@ -418,6 +384,10 @@ def clipmap(
418
384
  if m.gdfs is None:
419
385
  return
420
386
 
387
+ m._gdfs = [gdf.clip(mask) for gdf in m._gdfs]
388
+ m._gdf = m._gdf.clip(mask)
389
+ m._nan_idx = m._gdf[m._column].isna()
390
+ m._get_unique_values()
421
391
  m.explore(center=center, size=size)
422
392
  else:
423
393
  m = Map(
@@ -426,6 +396,14 @@ def clipmap(
426
396
  labels=labels,
427
397
  **kwargs,
428
398
  )
399
+ if m.gdfs is None:
400
+ return
401
+
402
+ m._gdfs = [gdf.clip(mask) for gdf in m._gdfs]
403
+ m._gdf = m._gdf.clip(mask)
404
+ m._nan_idx = m._gdf[m._column].isna()
405
+ m._get_unique_values()
406
+
429
407
  qtm(m._gdf, column=m.column, cmap=m._cmap, k=m.k)
430
408
 
431
409
 
@@ -449,8 +427,12 @@ def explore_locals(*gdfs, to_gdf: bool = True, **kwargs):
449
427
  continue
450
428
  if not to_gdf:
451
429
  continue
452
- if hasattr(value, "__len__") and not len(value):
430
+ try:
431
+ if hasattr(value, "__len__") and not len(value):
432
+ continue
433
+ except TypeError:
453
434
  continue
435
+
454
436
  try:
455
437
  gdf = clean_geoms(to_gdf_func(value))
456
438
  if len(gdf):
@@ -509,7 +491,7 @@ def qtm(
509
491
  See also:
510
492
  ThematicMap: Class with more options for customising the plot.
511
493
  """
512
- gdfs, column = Explore._separate_args(gdfs, column)
494
+ gdfs, column, kwargs = Map._separate_args(gdfs, column, kwargs)
513
495
 
514
496
  new_kwargs = {}
515
497
  for key, value in kwargs.items():
@@ -0,0 +1,61 @@
1
+ from xyzservices import TileProvider, Bunch, providers
2
+
3
# Tile providers from Kartverket's open tile cache ("gatekeeper" endpoint).
# All gmaps-style layers share one URL template and one attribution; only the
# display name and the layer id differ between them.
_KARTVERKET_GMAPS_URL = (
    "https://opencache.statkart.no/gatekeeper/gk/gk.open_gmaps?layers="
)
_KARTVERKET_ATTRIBUTION = "© Kartverket"
_KARTVERKET_HTML_ATTRIBUTION = '&copy; <a href="https://kartverket.no">Kartverket</a>'


def _kartverket_gmaps_provider(name: str, layer: str) -> TileProvider:
    """Build a TileProvider for one layer of Kartverket's gmaps tile cache.

    Args:
        name: Human-readable provider name.
        layer: Value of the ``layers`` query parameter in the tile URL.
    """
    return TileProvider(
        name=name,
        # plain concatenation keeps the {z}/{x}/{y} placeholders literal
        url=_KARTVERKET_GMAPS_URL + layer + "&zoom={z}&x={x}&y={y}",
        attribution=_KARTVERKET_ATTRIBUTION,
        html_attribution=_KARTVERKET_HTML_ATTRIBUTION,
    )


kartverket = Bunch(
    norgeskart=_kartverket_gmaps_provider("Norgeskart", "norgeskart_bakgrunn"),
    bakgrunnskart_forenklet=_kartverket_gmaps_provider(
        "Norgeskart forenklet", "bakgrunnskart_forenklet"
    ),
    norges_grunnkart=_kartverket_gmaps_provider(
        "Norges grunnkart", "norges_grunnkart"
    ),
    norges_grunnkart_gråtone=_kartverket_gmaps_provider(
        "Norges grunnkart gråtone", "norges_grunnkart_graatone"
    ),
    n50=_kartverket_gmaps_provider("N5 til N50 kartdata", "kartdata3"),
    topogråtone=_kartverket_gmaps_provider(
        "Topografisk norgeskart gråtone", "topo4graatone"
    ),
    toporaster=_kartverket_gmaps_provider("Topografisk raster", "toporaster4"),
    # "Norge i bilder" is served through a WMTS endpoint rather than the
    # gmaps one, with its own attribution and zoom limit, so it is spelled out.
    norge_i_bilder=TileProvider(
        name="Norge i bilder",
        url="https://opencache.statkart.no/gatekeeper/gk/gk.open_nib_web_mercator_wmts_v2?SERVICE=WMTS&REQUEST=GetTile&VERSION=1.0.0&LAYER=Nibcache_web_mercator_v2&STYLE=default&FORMAT=image/jpgpng&tileMatrixSet=default028mm&tileMatrix={z}&tileRow={y}&tileCol={x}",
        max_zoom=19,
        attribution="© Geovekst",
    ),
)

# Every xyzservices provider, plus the Kartverket layers under "Kartverket".
xyz = Bunch({"Kartverket": kartverket} | providers)
@@ -3,15 +3,63 @@
3
3
  import geopandas as gpd
4
4
  import numpy as np
5
5
  import pandas as pd
6
- from geopandas import GeoDataFrame
6
+ from geopandas import GeoDataFrame, GeoSeries
7
7
  from pandas import DataFrame
8
8
  from shapely import shortest_line
9
9
 
10
- from ..geopandas_tools.conversion import coordinate_array
10
+ from ..geopandas_tools.conversion import coordinate_array, to_geoseries
11
+ from ..geopandas_tools.geometry_types import get_geom_type
11
12
  from ..geopandas_tools.neighbors import k_nearest_neighbors
12
13
  from .nodes import make_edge_wkt_cols, make_node_ids
13
14
 
14
15
 
16
def close_network_holes_to(
    lines: GeoDataFrame | GeoSeries,
    extend_to: GeoDataFrame | GeoSeries,
    max_distance: int | float,
    max_angle: int | float,
) -> GeoDataFrame | GeoSeries:
    """Close network holes by drawing straight lines from deadends to given points.

    New straight lines are drawn from deadends in ``lines`` to nearby points in
    ``extend_to`` when they satisfy the ``max_distance`` / ``max_angle``
    conditions checked by ``_close_holes_all_lines``.

    Args:
        lines: Line geometries. If a GeoSeries is given, a GeoSeries is returned.
        extend_to: Singlepart point geometries the deadends may be extended to.
            The object passed in is not modified.
        max_distance: Maximum length of a new line.
        max_angle: Maximum angle, passed on to ``_close_holes_all_lines``.

    Returns:
        The lines with the new lines appended. A GeoSeries if ``lines`` was a
        GeoSeries; otherwise a GeoDataFrame with a fresh integer index and
        without the node/edge helper columns.

    Raises:
        ValueError: If ``extend_to`` contains anything but singlepart points.
    """
    if isinstance(extend_to, GeoSeries):
        extend_to = extend_to.to_frame("geometry")

    # validate the target points before doing any work on the lines
    if not (extend_to.geom_type == "Point").all():
        raise ValueError("'extend_to' must be singlepart point geometries")

    was_geoseries = isinstance(lines, GeoSeries)
    if was_geoseries:
        lines = lines.to_frame("geometry")

    lines, _ = make_node_ids(lines)

    # assign() returns new frames, so a GeoDataFrame passed as 'extend_to'
    # does not get 'wkt'/'node_id' columns added to it as a side effect.
    extend_to = extend_to.assign(wkt=extend_to.geometry.to_wkt()).drop_duplicates(
        "wkt"
    )
    extend_to = extend_to.assign(node_id=range(len(extend_to)))

    # idx_start=0 (close_network_holes uses 1): the nearest neighbour is kept
    # as a candidate, presumably because the 'extend_to' points are not nodes
    # of 'lines', so the nearest one is not the deadend itself.
    new_lines: GeoSeries = _close_holes_all_lines(
        lines, extend_to, max_distance=max_distance, max_angle=max_angle, idx_start=0
    )

    if was_geoseries:
        return pd.concat([lines.geometry, new_lines])

    new_lines = gpd.GeoDataFrame(
        {"geometry": new_lines}, geometry="geometry", crs=lines.crs
    )

    # node/edge helper columns (presumably added by make_node_ids) that
    # should not leak into the returned frame
    helper_columns = [
        "source_wkt",
        "target_wkt",
        "source",
        "target",
        "n_source",
        "n_target",
    ]
    return pd.concat([lines, new_lines], ignore_index=True).drop(
        columns=helper_columns
    )
61
+
62
+
15
63
  def close_network_holes(
16
64
  gdf: GeoDataFrame,
17
65
  max_distance: int | float,
@@ -88,32 +136,42 @@ def close_network_holes(
88
136
  intentional. They are road blocks where most cars aren't allowed to pass. Fill the
89
137
  holes only if it makes the travel times/routes more realistic.
90
138
  """
91
- gdf, nodes = make_node_ids(gdf)
92
139
 
93
- new_lines = _find_holes_all_lines(
94
- gdf,
95
- nodes,
96
- max_distance,
97
- max_angle=max_angle,
140
+ lines, nodes = make_node_ids(gdf)
141
+
142
+ # remove duplicates of lines going both directions
143
+ lines["sorted"] = [
144
+ "_".join(sorted([s, t]))
145
+ for s, t in zip(lines["source"], lines["target"], strict=True)
146
+ ]
147
+
148
+ new_lines: GeoSeries = _close_holes_all_lines(
149
+ lines.drop_duplicates("sorted"), nodes, max_distance, max_angle, idx_start=1
150
+ )
151
+
152
+ new_lines = gpd.GeoDataFrame(
153
+ {"geometry": new_lines}, geometry="geometry", crs=gdf.crs
98
154
  )
99
155
 
100
156
  if not len(new_lines):
101
- gdf[hole_col] = 0 if hole_col not in gdf.columns else gdf[hole_col].fillna(0)
102
- return gdf
157
+ lines[hole_col] = (
158
+ 0 if hole_col not in lines.columns else lines[hole_col].fillna(0)
159
+ )
160
+ return lines
103
161
 
104
162
  new_lines = make_edge_wkt_cols(new_lines)
105
163
 
106
- wkt_id_dict = {
107
- wkt: id for wkt, id in zip(nodes["wkt"], nodes["node_id"], strict=True)
108
- }
164
+ wkt_id_dict = dict(zip(nodes["wkt"], nodes["node_id"], strict=True))
109
165
  new_lines["source"] = new_lines["source_wkt"].map(wkt_id_dict)
110
166
  new_lines["target"] = new_lines["target_wkt"].map(wkt_id_dict)
111
167
 
112
168
  if hole_col:
113
169
  new_lines[hole_col] = 1
114
- gdf[hole_col] = 0 if hole_col not in gdf.columns else gdf[hole_col].fillna(0)
170
+ lines[hole_col] = (
171
+ 0 if hole_col not in lines.columns else lines[hole_col].fillna(0)
172
+ )
115
173
 
116
- return pd.concat([gdf, new_lines], ignore_index=True)
174
+ return pd.concat([lines, new_lines], ignore_index=True)
117
175
 
118
176
 
119
177
  def get_angle(array_a, array_b):
@@ -200,49 +258,29 @@ def close_network_holes_to_deadends(
200
258
  return pd.concat([gdf, new_lines], ignore_index=True)
201
259
 
202
260
 
203
- def _find_holes_all_lines(
204
- lines: GeoDataFrame,
205
- nodes: GeoDataFrame,
206
- max_distance: int | float,
207
- max_angle: int,
208
- ) -> GeoDataFrame | DataFrame:
209
- """Creates lines between deadends and closest node.
210
-
211
- Creates lines if distance is less than max_distance and angle less than max_angle.
212
-
213
- wkt: well-known text, e.g. "POINT (60 10)"
214
- """
215
- k = 50 if len(nodes) >= 50 else len(nodes)
216
- crs = nodes.crs
217
-
218
- # remove duplicates of lines going both directions
219
- lines["sorted"] = [
220
- "_".join(sorted([s, t]))
221
- for s, t in zip(lines["source"], lines["target"], strict=True)
222
- ]
223
-
224
- no_dups = lines.drop_duplicates("sorted")
225
-
226
- no_dups, nodes = make_node_ids(no_dups)
261
+ def _close_holes_all_lines(
262
+ lines, nodes, max_distance, max_angle, idx_start: int
263
+ ) -> GeoSeries:
264
+ k = min(len(nodes), 50)
227
265
 
228
266
  # make point gdf for the deadends and the other endpoint of the deadend lines
229
- deadends_target = no_dups.loc[no_dups.n_target == 1].rename(
267
+ deadends_target = lines.loc[lines["n_target"] == 1].rename(
230
268
  columns={"target_wkt": "wkt", "source_wkt": "wkt_other_end"}
231
269
  )
232
- deadends_source = no_dups.loc[no_dups.n_source == 1].rename(
270
+ deadends_source = lines.loc[lines["n_source"] == 1].rename(
233
271
  columns={"source_wkt": "wkt", "target_wkt": "wkt_other_end"}
234
272
  )
235
273
  deadends = pd.concat([deadends_source, deadends_target], ignore_index=True)
236
274
 
237
275
  if len(deadends) <= 1:
238
- return DataFrame()
276
+ return GeoSeries()
239
277
 
240
278
  deadends_other_end = deadends.copy()
241
279
  deadends_other_end["geometry"] = gpd.GeoSeries.from_wkt(
242
- deadends_other_end["wkt_other_end"], crs=crs
280
+ deadends_other_end["wkt_other_end"]
243
281
  )
244
282
 
245
- deadends["geometry"] = gpd.GeoSeries.from_wkt(deadends["wkt"], crs=crs)
283
+ deadends["geometry"] = gpd.GeoSeries.from_wkt(deadends["wkt"])
246
284
 
247
285
  deadends_array = coordinate_array(deadends)
248
286
  nodes_array = coordinate_array(nodes)
@@ -255,7 +293,7 @@ def _find_holes_all_lines(
255
293
  # and endpoints of the new lines in lists, looping through the k neighbour points
256
294
  new_sources: list[str] = []
257
295
  new_targets: list[str] = []
258
- for i in np.arange(1, k):
296
+ for i in np.arange(idx_start, k):
259
297
  # to break out of the loop if no new_targets that meet the condition are found
260
298
  len_now = len(new_sources)
261
299
 
@@ -263,7 +301,7 @@ def _find_holes_all_lines(
263
301
  indices = all_indices[:, i]
264
302
  dists = all_dists[:, i]
265
303
 
266
- these_nodes_array = coordinate_array(nodes.loc[indices])
304
+ these_nodes_array = coordinate_array(nodes.iloc[indices])
267
305
 
268
306
  if np.all(deadends_other_end_array == these_nodes_array):
269
307
  continue
@@ -286,7 +324,7 @@ def _find_holes_all_lines(
286
324
 
287
325
  from_wkt = deadends.loc[condition, "wkt"]
288
326
  to_idx = indices[condition]
289
- to_wkt = nodes.loc[to_idx, "wkt"]
327
+ to_wkt = nodes.iloc[to_idx]["wkt"]
290
328
 
291
329
  # now add the wkts to the lists of new sources and targets. If the source
292
330
  # is already added, the new wks will not be added again
@@ -301,18 +339,10 @@ def _find_holes_all_lines(
301
339
  if len_now == len(new_sources):
302
340
  break
303
341
 
304
- # make GeoDataFrame with straight lines
305
- new_sources = gpd.GeoSeries.from_wkt(new_sources, crs=crs)
306
- new_targets = gpd.GeoSeries.from_wkt(new_targets, crs=crs)
307
- new_lines = shortest_line(new_sources, new_targets)
308
- new_lines = gpd.GeoDataFrame({"geometry": new_lines}, geometry="geometry", crs=crs)
309
-
310
- if not len(new_lines):
311
- return new_lines
312
-
313
- new_lines = make_edge_wkt_cols(new_lines)
314
-
315
- return new_lines
342
+ # make GeoSeries with straight lines
343
+ new_sources = gpd.GeoSeries.from_wkt(new_sources, crs=lines.crs)
344
+ new_targets = gpd.GeoSeries.from_wkt(new_targets, crs=lines.crs)
345
+ return shortest_line(new_sources, new_targets)
316
346
 
317
347
 
318
348
  def _find_holes_deadends(
@@ -362,7 +392,4 @@ def _find_holes_deadends(
362
392
  new_lines = shortest_line(from_geom, to_geom)
363
393
  new_lines = gpd.GeoDataFrame({"geometry": new_lines}, geometry="geometry", crs=crs)
364
394
 
365
- if not len(new_lines):
366
- return new_lines
367
-
368
395
  return new_lines
@@ -101,7 +101,7 @@ def split_lines_by_nearest_point(
101
101
 
102
102
  gdf = gdf.copy()
103
103
 
104
- # move the points to the nearest exact point of the line
104
+ # move the points to the nearest exact location on the line
105
105
  if max_distance:
106
106
  snapped = snap_within_distance(points, gdf, max_distance=max_distance)
107
107
  else:
@@ -167,12 +167,18 @@ def split_lines_by_nearest_point(
167
167
  # the snapped points.
168
168
 
169
169
  splitted = change_line_endpoint(
170
- splitted, dists_source, pointmapper_source, change_what="first"
170
+ splitted,
171
+ indices=dists_source.index,
172
+ pointmapper=pointmapper_source,
173
+ change_what="first",
171
174
  ) # i=0)
172
175
 
173
176
  # same for the lines where the target was split, but change the last coordinate
174
177
  splitted = change_line_endpoint(
175
- splitted, dists_target, pointmapper_target, change_what="last"
178
+ splitted,
179
+ indices=dists_target.index,
180
+ pointmapper=pointmapper_target,
181
+ change_what="last",
176
182
  ) # , i=-1)
177
183
 
178
184
  if splitted_col:
@@ -185,7 +191,7 @@ def split_lines_by_nearest_point(
185
191
 
186
192
  def change_line_endpoint(
187
193
  gdf: GeoDataFrame,
188
- dists: pd.DataFrame,
194
+ indices: pd.Index,
189
195
  pointmapper: pd.Series,
190
196
  change_what: str | int,
191
197
  ) -> GeoDataFrame:
@@ -204,7 +210,7 @@ def change_line_endpoint(
204
210
  f"change_what should be 'first' or 'last' or 0 or -1. Got {change_what}"
205
211
  )
206
212
 
207
- is_relevant = gdf.index.isin(dists.index)
213
+ is_relevant = gdf.index.isin(indices)
208
214
  relevant_lines = gdf.loc[is_relevant]
209
215
 
210
216
  relevant_lines.geometry = extract_unique_points(relevant_lines.geometry)
@@ -35,7 +35,7 @@ def get_connected_components(gdf: GeoDataFrame) -> GeoDataFrame:
35
35
 
36
36
  Removing the isolated network islands.
37
37
 
38
- >>> connected_roads = get_connected_components(roads).query("connected == 1")
38
+ >>> connected_roads = get_connected_components(roads).loc[lambda x: x["connected"] == 1]
39
39
  >>> roads.connected.value_counts()
40
40
  1.0 85638
41
41
  Name: connected, dtype: int64
@@ -171,7 +171,7 @@ def _prepare_make_edge_cols(
171
171
 
172
172
  geom_col = lines._geometry_column_name
173
173
 
174
- # some LinearRings are coded as LineStrings and need to be removed manually
174
+ # some LineStrings are in fact rings and must be removed manually
175
175
  boundary = lines[geom_col].boundary
176
176
  circles = boundary.loc[boundary.is_empty]
177
177
  lines = lines[~lines.index.isin(circles.index)]
@@ -20,7 +20,8 @@ def traveling_salesman_problem(
20
20
  return_to_start: If True (default), the path
21
21
  will make a full circle to the startpoint.
22
22
  If False, a dummy node will be added to make the
23
- salesman focus only on getting to the last node.
23
+ salesman focus only on getting to the last node. Not
24
+ guaranteed to work, meaning the wrong edge might be removed.
24
25
  distances: Optional DataFrame of distances between all points.
25
26
  If not provided, the calculation is done within this function.
26
27
  The DataFrame should be identical to the DataFrame created
@@ -75,7 +76,7 @@ def traveling_salesman_problem(
75
76
  & (x["neighbor_index"].isin(points.index))
76
77
  ]
77
78
 
78
- # need integer index
79
+ # need range integer index (0, 1, ..., n-1)
79
80
  to_int_idx = {idx: i for i, idx in enumerate(points.index)}
80
81
  points.index = points.index.map(to_int_idx)
81
82
  points = points.sort_index()
@@ -92,11 +93,12 @@ def traveling_salesman_problem(
92
93
  distances = distances.sort_values(
93
94
  ["mean_distance", "distance"], ascending=[True, False]
94
95
  )
96
+
95
97
  max_dist_idx = distances["mean_distance"].idxmax()
96
98
 
97
99
  dummy_node_idx = points.index.max() + 1
98
100
  n_points = dummy_node_idx + 1
99
- max_dist_and_some = distances["distance"].max() * 1.1
101
+ max_dist_and_some = distances["distance"].sum() * 1.01
100
102
 
101
103
  # add edges in both directions to the dummy node
102
104
  dummy_node = pd.DataFrame(
@@ -152,4 +154,6 @@ def traveling_salesman_problem(
152
154
 
153
155
  best_path = best_path[idx_start:] + best_path[:idx_start]
154
156
 
155
- return [idx_to_point[i] for i in best_path if i != dummy_node_idx]
157
+ as_points = [idx_to_point[i] for i in best_path if i != dummy_node_idx]
158
+
159
+ return as_points # + [as_points[0]]