BERATools 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153) hide show
  1. beratools/__init__.py +1 -7
  2. beratools/core/algo_centerline.py +491 -351
  3. beratools/core/algo_common.py +497 -0
  4. beratools/core/algo_cost.py +192 -0
  5. beratools/core/{dijkstra_algorithm.py → algo_dijkstra.py} +503 -460
  6. beratools/core/algo_footprint_rel.py +577 -0
  7. beratools/core/algo_line_grouping.py +944 -0
  8. beratools/core/algo_merge_lines.py +214 -0
  9. beratools/core/algo_split_with_lines.py +304 -0
  10. beratools/core/algo_tiler.py +428 -0
  11. beratools/core/algo_vertex_optimization.py +469 -0
  12. beratools/core/constants.py +52 -86
  13. beratools/core/logger.py +76 -85
  14. beratools/core/tool_base.py +196 -133
  15. beratools/gui/__init__.py +11 -15
  16. beratools/gui/{beratools.json → assets/beratools.json} +2185 -2300
  17. beratools/gui/batch_processing_dlg.py +513 -463
  18. beratools/gui/bt_data.py +481 -487
  19. beratools/gui/bt_gui_main.py +710 -691
  20. beratools/gui/main.py +26 -0
  21. beratools/gui/map_window.py +162 -146
  22. beratools/gui/tool_widgets.py +725 -493
  23. beratools/tools/Beratools_r_script.r +1120 -1120
  24. beratools/tools/Ht_metrics.py +116 -116
  25. beratools/tools/__init__.py +7 -7
  26. beratools/tools/batch_processing.py +136 -132
  27. beratools/tools/canopy_threshold_relative.py +672 -670
  28. beratools/tools/canopycostraster.py +222 -222
  29. beratools/tools/centerline.py +136 -176
  30. beratools/tools/common.py +857 -885
  31. beratools/tools/fl_regen_csf.py +428 -428
  32. beratools/tools/forest_line_attributes.py +408 -408
  33. beratools/tools/line_footprint_absolute.py +213 -363
  34. beratools/tools/line_footprint_fixed.py +436 -282
  35. beratools/tools/line_footprint_functions.py +733 -720
  36. beratools/tools/line_footprint_relative.py +73 -64
  37. beratools/tools/line_grouping.py +45 -0
  38. beratools/tools/ln_relative_metrics.py +615 -615
  39. beratools/tools/r_cal_lpi_elai.r +24 -24
  40. beratools/tools/r_generate_pd_focalraster.r +100 -100
  41. beratools/tools/r_interface.py +79 -79
  42. beratools/tools/r_point_density.r +8 -8
  43. beratools/tools/rpy_chm2trees.py +86 -86
  44. beratools/tools/rpy_dsm_chm_by.py +81 -81
  45. beratools/tools/rpy_dtm_by.py +63 -63
  46. beratools/tools/rpy_find_cellsize.py +43 -43
  47. beratools/tools/rpy_gnd_csf.py +74 -74
  48. beratools/tools/rpy_hummock_hollow.py +85 -85
  49. beratools/tools/rpy_hummock_hollow_raster.py +71 -71
  50. beratools/tools/rpy_las_info.py +51 -51
  51. beratools/tools/rpy_laz2las.py +40 -40
  52. beratools/tools/rpy_lpi_elai_lascat.py +466 -466
  53. beratools/tools/rpy_normalized_lidar_by.py +56 -56
  54. beratools/tools/rpy_percent_above_dbh.py +80 -80
  55. beratools/tools/rpy_points2trees.py +88 -88
  56. beratools/tools/rpy_vegcoverage.py +94 -94
  57. beratools/tools/tiler.py +48 -206
  58. beratools/tools/tool_template.py +69 -54
  59. beratools/tools/vertex_optimization.py +61 -620
  60. beratools/tools/zonal_threshold.py +144 -144
  61. beratools-0.2.2.dist-info/METADATA +108 -0
  62. beratools-0.2.2.dist-info/RECORD +74 -0
  63. {beratools-0.2.0.dist-info → beratools-0.2.2.dist-info}/WHEEL +1 -1
  64. {beratools-0.2.0.dist-info → beratools-0.2.2.dist-info}/licenses/LICENSE +22 -22
  65. beratools/gui/cli.py +0 -18
  66. beratools/gui/gui.json +0 -8
  67. beratools/gui_tk/ASCII Banners.txt +0 -248
  68. beratools/gui_tk/__init__.py +0 -20
  69. beratools/gui_tk/beratools_main.py +0 -515
  70. beratools/gui_tk/bt_widgets.py +0 -442
  71. beratools/gui_tk/cli.py +0 -18
  72. beratools/gui_tk/img/BERALogo.png +0 -0
  73. beratools/gui_tk/img/closed.gif +0 -0
  74. beratools/gui_tk/img/closed.png +0 -0
  75. beratools/gui_tk/img/open.gif +0 -0
  76. beratools/gui_tk/img/open.png +0 -0
  77. beratools/gui_tk/img/tool.gif +0 -0
  78. beratools/gui_tk/img/tool.png +0 -0
  79. beratools/gui_tk/main.py +0 -14
  80. beratools/gui_tk/map_window.py +0 -144
  81. beratools/gui_tk/runner.py +0 -1481
  82. beratools/gui_tk/tooltip.py +0 -55
  83. beratools/third_party/pyqtlet2/__init__.py +0 -9
  84. beratools/third_party/pyqtlet2/leaflet/__init__.py +0 -26
  85. beratools/third_party/pyqtlet2/leaflet/control/__init__.py +0 -6
  86. beratools/third_party/pyqtlet2/leaflet/control/control.py +0 -59
  87. beratools/third_party/pyqtlet2/leaflet/control/draw.py +0 -52
  88. beratools/third_party/pyqtlet2/leaflet/control/layers.py +0 -20
  89. beratools/third_party/pyqtlet2/leaflet/core/Parser.py +0 -24
  90. beratools/third_party/pyqtlet2/leaflet/core/__init__.py +0 -2
  91. beratools/third_party/pyqtlet2/leaflet/core/evented.py +0 -180
  92. beratools/third_party/pyqtlet2/leaflet/layer/__init__.py +0 -5
  93. beratools/third_party/pyqtlet2/leaflet/layer/featuregroup.py +0 -34
  94. beratools/third_party/pyqtlet2/leaflet/layer/icon/__init__.py +0 -1
  95. beratools/third_party/pyqtlet2/leaflet/layer/icon/icon.py +0 -30
  96. beratools/third_party/pyqtlet2/leaflet/layer/imageoverlay.py +0 -18
  97. beratools/third_party/pyqtlet2/leaflet/layer/layer.py +0 -105
  98. beratools/third_party/pyqtlet2/leaflet/layer/layergroup.py +0 -45
  99. beratools/third_party/pyqtlet2/leaflet/layer/marker/__init__.py +0 -1
  100. beratools/third_party/pyqtlet2/leaflet/layer/marker/marker.py +0 -91
  101. beratools/third_party/pyqtlet2/leaflet/layer/tile/__init__.py +0 -2
  102. beratools/third_party/pyqtlet2/leaflet/layer/tile/gridlayer.py +0 -4
  103. beratools/third_party/pyqtlet2/leaflet/layer/tile/tilelayer.py +0 -16
  104. beratools/third_party/pyqtlet2/leaflet/layer/vector/__init__.py +0 -5
  105. beratools/third_party/pyqtlet2/leaflet/layer/vector/circle.py +0 -15
  106. beratools/third_party/pyqtlet2/leaflet/layer/vector/circlemarker.py +0 -18
  107. beratools/third_party/pyqtlet2/leaflet/layer/vector/path.py +0 -5
  108. beratools/third_party/pyqtlet2/leaflet/layer/vector/polygon.py +0 -14
  109. beratools/third_party/pyqtlet2/leaflet/layer/vector/polyline.py +0 -18
  110. beratools/third_party/pyqtlet2/leaflet/layer/vector/rectangle.py +0 -14
  111. beratools/third_party/pyqtlet2/leaflet/map/__init__.py +0 -1
  112. beratools/third_party/pyqtlet2/leaflet/map/map.py +0 -220
  113. beratools/third_party/pyqtlet2/mapwidget.py +0 -45
  114. beratools/third_party/pyqtlet2/web/custom.js +0 -43
  115. beratools/third_party/pyqtlet2/web/map.html +0 -23
  116. beratools/third_party/pyqtlet2/web/modules/leaflet_193/images/layers-2x.png +0 -0
  117. beratools/third_party/pyqtlet2/web/modules/leaflet_193/images/layers.png +0 -0
  118. beratools/third_party/pyqtlet2/web/modules/leaflet_193/images/marker-icon-2x.png +0 -0
  119. beratools/third_party/pyqtlet2/web/modules/leaflet_193/images/marker-icon.png +0 -0
  120. beratools/third_party/pyqtlet2/web/modules/leaflet_193/images/marker-shadow.png +0 -0
  121. beratools/third_party/pyqtlet2/web/modules/leaflet_193/leaflet.css +0 -656
  122. beratools/third_party/pyqtlet2/web/modules/leaflet_193/leaflet.js +0 -6
  123. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/.codeclimate.yml +0 -14
  124. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/.editorconfig +0 -4
  125. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/.gitattributes +0 -22
  126. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/.travis.yml +0 -43
  127. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/LICENSE +0 -20
  128. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/images/layers-2x.png +0 -0
  129. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/images/layers.png +0 -0
  130. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/images/marker-icon-2x.png +0 -0
  131. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/images/marker-icon.png +0 -0
  132. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/images/marker-shadow.png +0 -0
  133. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/images/spritesheet-2x.png +0 -0
  134. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/images/spritesheet.png +0 -0
  135. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/images/spritesheet.svg +0 -156
  136. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/leaflet.draw.css +0 -10
  137. beratools/third_party/pyqtlet2/web/modules/leaflet_draw_414/leaflet.draw.js +0 -10
  138. beratools/third_party/pyqtlet2/web/modules/leaflet_rotatedMarker_020/LICENSE +0 -22
  139. beratools/third_party/pyqtlet2/web/modules/leaflet_rotatedMarker_020/leaflet.rotatedMarker.js +0 -57
  140. beratools/tools/forest_line_ecosite.py +0 -216
  141. beratools/tools/lapis_all.py +0 -103
  142. beratools/tools/least_cost_path_from_chm.py +0 -152
  143. beratools-0.2.0.dist-info/METADATA +0 -63
  144. beratools-0.2.0.dist-info/RECORD +0 -142
  145. /beratools/gui/{img → assets}/BERALogo.png +0 -0
  146. /beratools/gui/{img → assets}/closed.gif +0 -0
  147. /beratools/gui/{img → assets}/closed.png +0 -0
  148. /beratools/{gui_tk → gui/assets}/gui.json +0 -0
  149. /beratools/gui/{img → assets}/open.gif +0 -0
  150. /beratools/gui/{img → assets}/open.png +0 -0
  151. /beratools/gui/{img → assets}/tool.gif +0 -0
  152. /beratools/gui/{img → assets}/tool.png +0 -0
  153. {beratools-0.2.0.dist-info → beratools-0.2.2.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,428 @@
1
+ """
2
+ Copyright (C) 2025 Applied Geospatial Research Group.
3
+
4
+ This script is licensed under the GNU General Public License v3.0.
5
+ See <https://gnu.org/licenses/gpl-3.0> for full license details.
6
+
7
+ Author: Richard Zeng
8
+
9
+ Description:
10
+ This script is part of the BERA Tools.
11
+ Webpage: https://github.com/appliedgrg/beratools
12
+
13
+ The purpose of this script is to provide algorithms for
14
+ partitioning vector and raster data.
15
+ """
16
+
17
+ import os
18
+
19
+ import geopandas as gpd
20
+ import matplotlib.pyplot as plt
21
+ import numpy as np
22
+ import rasterio
23
+ import shapely.geometry as sh_geom
24
+ import shapely.ops as sh_ops
25
+ from rasterio.mask import mask
26
+ from sklearn.cluster import KMeans
27
+ from sklearn.neighbors import KernelDensity
28
+
29
+
30
def create_square_buffer(polygon, buffer_size):
    """
    Build an axis-aligned square centred on the polygon's bounding box.

    The square shares its centre with the bounding box of ``polygon`` and has
    a side length of ``buffer_size``. Its size does not depend on the extent
    of the polygon, so it is not guaranteed to contain the polygon.

    Args:
        polygon (shapely.geometry.Polygon): Geometry whose ``bounds`` are used
            to locate the centre.
        buffer_size (float): The side length of the square.

    Returns:
        shapely.geometry.Polygon: The square, with vertices listed in the
        order bottom-left, top-left, top-right, bottom-right.

    """
    # bounds is (minx, miny, maxx, maxy)
    min_x, min_y, max_x, max_y = polygon.bounds

    # Centre of the bounding box
    center_x = (min_x + max_x) / 2
    center_y = (min_y + max_y) / 2

    # Edges of the square, half a side length away from the centre
    half_side = buffer_size / 2
    left, right = center_x - half_side, center_x + half_side
    bottom, top = center_y - half_side, center_y + half_side

    corners = [(left, bottom), (left, top), (right, top), (right, bottom)]
    return sh_geom.Polygon(corners)
58
+
59
class DensityBasedClustering:
    """Cluster grouped line features and clip a raster to each cluster's extent.

    ``run()`` chains the individual steps: lines are merged per ``group``
    value and reduced to length-weighted centroids, the centroids are
    clustered with KMeans, cluster membership is adjusted so that the summed
    weights of the clusters stay inside a tolerance band, the labelled lines
    are written to a GeoPackage and the input raster is clipped to the
    bounding box of every cluster.
    """

    def __init__(
        self,
        in_line,
        in_raster,
        out_file,
        n_clusters=8,
        tile_buffer=50,
        bandwidth=1.5,
        layer=None,
    ):
        """
        Store the configuration; no file is opened here.

        Args:
            in_line: Path of the line dataset read by
                ``read_points_from_geopackage``.
            in_raster: Path of the raster clipped by
                ``generate_and_clip_rasters``.
            out_file: Path of the GeoPackage written by
                ``save_final_clustering_to_geopackage``.
            n_clusters (int): Number of clusters requested from KMeans.
            tile_buffer: Stored on the instance; none of the methods of this
                class reads it.
            bandwidth (float): Bandwidth handed to ``KernelDensity``.
            layer: Layer name passed to ``geopandas.read_file``.

        """
        self.input_file = in_line
        self.input_raster = in_raster
        self.output_file = out_file
        self.n_clusters = n_clusters
        self.tile_buffer = tile_buffer
        self.bandwidth = bandwidth
        self.layer = layer

        self.gdf = None  # original lines, set by read_points_from_geopackage
        self.centroid_gdf = None  # per-group centroids, set by run

    def read_points_from_geopackage(self):
        """
        Read the lines and derive one weighted centroid per 'group' value.

        The lines of every group are merged with ``unary_union``; the centroid
        of the merged geometry becomes the point and its length the 'weight'.

        Returns:
            tuple: ``(centroid_gdf, gdf)`` where ``centroid_gdf`` holds the
            columns 'group', 'weight' and the active geometry 'centroid', and
            ``gdf`` is the unmodified line dataset (also kept as ``self.gdf``).

        """
        # Load the lines from GeoPackage
        self.gdf = gpd.read_file(self.input_file, layer=self.layer)

        # Merge lines by group
        grouped = (
            self.gdf.groupby("group")["geometry"]
            .apply(sh_ops.unary_union)
            .reset_index()
        )
        merged_gdf = gpd.GeoDataFrame(grouped, geometry=grouped["geometry"])

        # Centroid and total length of every merged group
        merged_gdf["centroid"] = merged_gdf.geometry.centroid
        merged_gdf["weight"] = merged_gdf.geometry.length
        merged_gdf = merged_gdf.drop(columns="geometry")

        # Use the centroids as geometry and carry over the CRS of the lines
        centroid_gdf = gpd.GeoDataFrame(merged_gdf, geometry="centroid")
        centroid_gdf.set_crs(self.gdf.crs, allow_override=True, inplace=True)

        # Keep valid, non-empty points only. The copy detaches the result from
        # the unfiltered frame, because a 'cluster' column is added to it later.
        is_valid_point = centroid_gdf.geometry.apply(
            lambda geom: isinstance(geom, sh_geom.Point) and not geom.is_empty
        )
        centroid_gdf = centroid_gdf[is_valid_point].copy()
        return centroid_gdf, self.gdf

    def extract_coordinates_and_weights(self, gdf):
        """
        Extract coordinates and weights from a point GeoDataFrame.

        Args:
            gdf: GeoDataFrame with point geometries and a 'weight' column.

        Returns:
            tuple: ``(points, weights)``; ``points`` has one coordinate tuple
            per row, ``weights`` is the 'weight' column as an array.

        """
        points = np.array([point.coords[0] for point in gdf.geometry])
        weights = gdf["weight"].values
        return points, weights

    def estimate_density(self, points):
        """
        Fit a Gaussian kernel density estimate to the points.

        Args:
            points: Array of coordinates, one row per point.

        Returns:
            KernelDensity: The fitted estimator (bandwidth ``self.bandwidth``).

        """
        kde = KernelDensity(kernel="gaussian", bandwidth=self.bandwidth)
        kde.fit(points)
        return kde

    def sample_points(self, kde, grid_points, n_samples=200, random_state=None):
        """
        Draw grid points with a probability proportional to their density.

        Args:
            kde: Object whose ``score_samples(grid_points)`` returns one
                log-density per grid point.
            grid_points: Array of candidate coordinates, one row per point.
            n_samples (int): Number of points to draw (with replacement).
            random_state (int, optional): Seed for a private random generator.
                With the default ``None`` the global NumPy random state is
                used.

        Returns:
            numpy.ndarray: ``n_samples`` rows taken from ``grid_points``.

        Raises:
            ValueError: If none of the log-densities is finite.

        """
        log_density = np.asarray(kde.score_samples(grid_points), dtype=float).ravel()

        peak = np.max(log_density)
        if not np.isfinite(peak):
            raise ValueError("Density is not finite at any grid point.")

        # Subtracting the maximum leaves the normalised probabilities unchanged
        # but keeps exp() from underflowing to zero everywhere, which happens
        # when the bandwidth is small compared with the extent of the grid.
        density = np.exp(log_density - peak)
        probabilities = density / density.sum()

        if random_state is None:
            rng = np.random
        else:
            rng = np.random.RandomState(random_state)
        sampled_indices = rng.choice(
            grid_points.shape[0], size=n_samples, p=probabilities
        )
        return grid_points[sampled_indices]

    def initial_clustering(self, points):
        """
        Perform KMeans clustering with a fixed random state.

        Args:
            points: Array of coordinates, one row per point.

        Returns:
            tuple: ``(labels, kmeans)``, the label of every point and the
            fitted KMeans estimator.

        """
        kmeans = KMeans(n_clusters=self.n_clusters, random_state=42)
        kmeans_labels = kmeans.fit_predict(points)
        return kmeans_labels, kmeans

    def rebalance_with_weight_sum_constraint(
        self, kmeans_labels, points, weights, kmeans, tolerance=0.5, max_iterations=20
    ):
        """
        Move points between clusters until the weight sums are balanced.

        The target of every cluster is ``sum(weights) / n_clusters``. A cluster
        is overloaded when it exceeds the target by more than
        ``tolerance * target`` and deficient when it falls short by more than
        that margin.

        * An overloaded cluster hands a point to the cluster with the nearest
          other centroid, provided that cluster is deficient.
        * A deficient cluster takes points from overloaded clusters, nearest
          to its own centroid first.

        The running weight sums are updated after every move, so a cluster
        stops giving or taking as soon as it is back inside the margin.

        Args:
            kmeans_labels: Cluster label of every point; not modified.
            points: Array of coordinates, one row per point.
            weights: Weight of every point.
            kmeans: Object providing ``cluster_centers_``. The centres are
                not recomputed while points are moved.
            tolerance (float): Allowed deviation as a fraction of the target.
            max_iterations (int): Upper bound on the number of passes.

        Returns:
            numpy.ndarray: New label array of the same length as the input.

        Raises:
            ValueError: If ``kmeans_labels`` and ``weights`` differ in length.

        """
        if len(kmeans_labels) != len(weights):
            raise ValueError(
                f"""Length mismatch: kmeans_labels has {len(kmeans_labels)} entries,
                but weights has {len(weights)} entries."""
            )

        # Work on a copy so that the caller's labels stay untouched
        labels = np.array(kmeans_labels, copy=True)
        points = np.asarray(points)
        weights = np.asarray(weights)
        centers = kmeans.cluster_centers_

        total_weight = np.sum(weights)
        target_weight = total_weight / self.n_clusters
        margin = tolerance * target_weight

        for iteration in range(max_iterations):
            cluster_weights = np.zeros(self.n_clusters)
            for cluster_id in range(self.n_clusters):
                cluster_weights[cluster_id] = np.sum(weights[labels == cluster_id])

            weight_differences = cluster_weights - target_weight
            if not np.any(np.abs(weight_differences) > margin):
                print(f"Rebalancing completed after {iteration + 1} iterations.")
                break

            for cluster_id in range(self.n_clusters):
                surplus = cluster_weights[cluster_id] - target_weight

                if surplus > margin:
                    # Overloaded: give points to the nearest deficient cluster
                    for idx in np.where(labels == cluster_id)[0]:
                        if cluster_weights[cluster_id] - target_weight <= margin:
                            break  # back inside the margin

                        distances = np.linalg.norm(points[idx] - centers, axis=1)
                        distances[cluster_id] = np.inf
                        receiver = np.argmin(distances)

                        if cluster_weights[receiver] < target_weight - margin:
                            labels[idx] = receiver
                            cluster_weights[cluster_id] -= weights[idx]
                            cluster_weights[receiver] += weights[idx]

                elif surplus < -margin:
                    # Deficient: take the nearest points of overloaded clusters
                    distances = np.linalg.norm(points - centers[cluster_id], axis=1)
                    for idx in np.argsort(distances, kind="stable"):
                        if cluster_weights[cluster_id] - target_weight >= -margin:
                            break  # back inside the margin

                        donor = labels[idx]
                        if donor == cluster_id:
                            continue

                        if cluster_weights[donor] > target_weight + margin:
                            labels[idx] = cluster_id
                            cluster_weights[donor] -= weights[idx]
                            cluster_weights[cluster_id] += weights[idx]
        else:
            print("Maximum iterations reached without achieving perfect balance.")

        self.print_cluster_weight_sums(labels, weights)
        return labels

    def print_cluster_weight_sums(self, kmeans_labels, weights):
        """Print the summed weight of every cluster."""
        for cluster_id in range(self.n_clusters):
            cluster_weight_sum = np.sum(weights[kmeans_labels == cluster_id])
            print(f"Cluster {cluster_id} weight sum: {cluster_weight_sum}")

    def save_final_clustering_to_geopackage(
        self,
        all_points,
        all_weights,
        kmeans_labels_final,
        group_field,
        crs,
        gdf,
        centroid_gdf,
    ):
        """
        Save the clustering result to the output GeoPackage.

        The first ``len(centroid_gdf)`` labels are written to a new 'cluster'
        column of ``centroid_gdf`` (the frame is modified in place) and joined
        to the lines on 'group'. All lines go to the layer 'final_clusters',
        and the lines of every cluster to a layer 'cluster_<id>'.

        Args:
            all_points: Not used.
            all_weights: Not used.
            kmeans_labels_final: Labels; the leading entries must belong to
                the rows of ``centroid_gdf``, in the same order.
            group_field: Not used.
            crs: Not used.
            gdf: Line dataset with a 'group' column.
            centroid_gdf: Per-group centroids with a 'group' column.

        """
        # Only assign labels to centroids
        centroid_gdf["cluster"] = kmeans_labels_final[: len(centroid_gdf)]

        # Attach the cluster of every group to the original lines
        gdf_merged = gdf.merge(
            centroid_gdf[["group", "cluster"]], on="group", how="left"
        )

        # Save the whole lines dataset with clustering information
        gdf_merged.to_file(self.output_file, layer="final_clusters", driver="GPKG")

        # Save separate layers for each cluster
        for cluster_id in range(self.n_clusters):
            cluster_gdf = gdf_merged[gdf_merged["cluster"] == cluster_id]
            cluster_gdf.to_file(
                self.output_file,
                layer=f"cluster_{cluster_id}",
                driver="GPKG",
            )

    def plot_clusters(self, points, labels, centroids, title):
        """
        Show a scatter plot of the points coloured by cluster.

        Args:
            points: Array of coordinates, one row per point.
            labels: Cluster label of every point.
            centroids: Array of cluster centre coordinates.
            title (str): Title of the plot.

        """
        fig, ax = plt.subplots(figsize=(10, 10))
        scatter = ax.scatter(
            points[:, 0], points[:, 1], c=labels, cmap="tab10", s=50, alpha=0.6
        )
        ax.scatter(
            centroids[:, 0],
            centroids[:, 1],
            c="red",
            marker="*",
            s=200,
            label="Centroids",
        )
        plt.colorbar(scatter, ax=ax, label="Cluster ID")
        ax.legend()  # without this call the "Centroids" label is never drawn
        ax.set_title(title)
        plt.show()

    def generate_and_clip_rasters(self, kmeans_labels_final):
        """
        Clip the input raster to the bounding box of every cluster.

        The rasters are written as ``cluster_<id>.tif`` into a folder
        'rasters' next to the input line file. All bands of the input raster
        are written, and the remaining metadata is copied from the input.

        Args:
            kmeans_labels_final: Passed on to ``get_lines_for_cluster``.

        """
        parent_folder = os.path.dirname(self.input_file)
        output_folder = os.path.join(parent_folder, "rasters")
        os.makedirs(output_folder, exist_ok=True)

        with rasterio.open(self.input_raster) as src:
            for cluster_id in range(self.n_clusters):
                cluster_lines = self.get_lines_for_cluster(
                    kmeans_labels_final, cluster_id
                )
                if not cluster_lines:
                    print(f"No lines found: {cluster_id}, skipping raster generation.")
                    continue

                multi_line = sh_geom.MultiLineString(cluster_lines)
                if multi_line.is_empty:
                    print(f"Warning: No coordinates found: {cluster_id}. Skipping...")
                    continue

                # Extent of all lines of the cluster
                min_x, min_y, max_x, max_y = multi_line.bounds

                print(
                    f"""Cluster {cluster_id} BBox:
                    ({min_x}, {min_y}), ({max_x}, {max_y})"""
                )

                # Create a Polygon representing the bounding box
                bounding_box = sh_geom.Polygon(
                    [
                        (min_x, min_y),
                        (max_x, min_y),
                        (max_x, max_y),
                        (min_x, max_y),
                        (min_x, min_y),
                    ]
                )

                # out_image keeps the layout (bands, rows, cols)
                out_image, out_transform = mask(src, [bounding_box], crop=True)

                # Start from the source metadata so that band count, dtype,
                # nodata and CRS are carried over; only size and transform
                # change through the crop.
                out_meta = src.meta.copy()
                out_meta.update(
                    driver="GTiff",
                    height=out_image.shape[1],
                    width=out_image.shape[2],
                    transform=out_transform,
                )

                cluster_raster_path = os.path.join(
                    output_folder, f"cluster_{cluster_id}.tif"
                )
                with rasterio.open(cluster_raster_path, "w", **out_meta) as dest:
                    dest.write(out_image)

                print(f"Cluster {cluster_id} raster saved to {cluster_raster_path}")

    def get_lines_for_cluster(self, kmeans_labels_final, cluster_id):
        """
        Retrieve the lines corresponding to a specific cluster.

        The cluster of every group is read from the 'cluster' column of
        ``self.centroid_gdf``, which ``save_final_clustering_to_geopackage``
        fills in.

        Args:
            kmeans_labels_final: Not used.
            cluster_id (int): Cluster whose lines are requested.

        Returns:
            list: LineString objects; MultiLineStrings are split into their
            parts and other geometry types are left out.

        Raises:
            RuntimeError: If the centroids have not been prepared yet.

        """
        if self.centroid_gdf is None or self.gdf is None:
            raise RuntimeError(
                "Cluster assignments are not available; call run() first."
            )

        # Groups whose centroid was assigned to this cluster
        in_cluster = self.centroid_gdf["cluster"] == cluster_id
        group_values = self.centroid_gdf.loc[in_cluster, "group"]

        # Lines of the original dataset that belong to these groups
        cluster_lines = []
        for group_value in group_values:
            in_group = self.gdf["group"] == group_value
            cluster_lines.extend(self.gdf.loc[in_group, "geometry"])

        # Flatten any MultiLineString objects into individual LineStrings
        flattened_lines = []
        for geom in cluster_lines:
            if isinstance(geom, sh_geom.MultiLineString):
                flattened_lines.extend(geom.geoms)
            elif isinstance(geom, sh_geom.LineString):
                flattened_lines.append(geom)

        return flattened_lines

    def run(self):
        """Run the full clustering process."""
        # Step 1: Read points and original lines
        centroid_gdf, gdf = self.read_points_from_geopackage()

        # get_lines_for_cluster reads the centroids from the instance
        self.centroid_gdf = centroid_gdf

        # Step 2: Extract coordinates and weights
        points, weights = self.extract_coordinates_and_weights(centroid_gdf)

        # Step 3: Estimate density
        kde = self.estimate_density(points)

        # Step 4: Sample additional points based on density, from a regular
        # 200 x 200 grid over the extent of the centroids
        x_min, y_min = points.min(axis=0)
        x_max, y_max = points.max(axis=0)
        xx, yy = np.meshgrid(
            np.linspace(x_min, x_max, 200), np.linspace(y_min, y_max, 200)
        )
        grid_points = np.vstack([xx.ravel(), yy.ravel()]).T
        sampled_points = self.sample_points(kde, grid_points, n_samples=200)

        # Combine original and sampled points; sampled points get weight 1
        all_points = np.vstack([points, sampled_points])
        all_weights = np.concatenate([weights, np.ones(sampled_points.shape[0])])

        # Preserve the 'group' field; sampled points get the default value -1
        group_field = np.concatenate(
            [centroid_gdf["group"].values, np.full(sampled_points.shape[0], -1)]
        )

        # Step 5: Initial clustering on the centroids only
        kmeans_labels_initial, kmeans_initial = self.initial_clustering(points)

        # Assign clusters to the new sampled points
        kmeans_labels_all = np.concatenate(
            [kmeans_labels_initial, kmeans_initial.predict(sampled_points)]
        )

        # Step 6: Rebalance clusters with weight sum constraints
        kmeans_labels_final = self.rebalance_with_weight_sum_constraint(
            kmeans_labels_all, all_points, all_weights, kmeans_initial
        )

        # Step 7: Save final clustering to GeoPackage
        self.save_final_clustering_to_geopackage(
            all_points,
            all_weights,
            kmeans_labels_final,
            group_field,
            centroid_gdf.crs,
            gdf,
            centroid_gdf,
        )

        # Step 8: Generate and clip rasters for each cluster
        self.generate_and_clip_rasters(kmeans_labels_final)
428
+