BERATools 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. beratools/__init__.py +3 -0
  2. beratools/core/__init__.py +0 -0
  3. beratools/core/algo_centerline.py +476 -0
  4. beratools/core/algo_common.py +489 -0
  5. beratools/core/algo_cost.py +185 -0
  6. beratools/core/algo_dijkstra.py +492 -0
  7. beratools/core/algo_footprint_rel.py +693 -0
  8. beratools/core/algo_line_grouping.py +941 -0
  9. beratools/core/algo_merge_lines.py +255 -0
  10. beratools/core/algo_split_with_lines.py +296 -0
  11. beratools/core/algo_vertex_optimization.py +451 -0
  12. beratools/core/constants.py +56 -0
  13. beratools/core/logger.py +92 -0
  14. beratools/core/tool_base.py +126 -0
  15. beratools/gui/__init__.py +11 -0
  16. beratools/gui/assets/BERALogo.png +0 -0
  17. beratools/gui/assets/beratools.json +471 -0
  18. beratools/gui/assets/closed.gif +0 -0
  19. beratools/gui/assets/closed.png +0 -0
  20. beratools/gui/assets/gui.json +8 -0
  21. beratools/gui/assets/open.gif +0 -0
  22. beratools/gui/assets/open.png +0 -0
  23. beratools/gui/assets/tool.gif +0 -0
  24. beratools/gui/assets/tool.png +0 -0
  25. beratools/gui/bt_data.py +485 -0
  26. beratools/gui/bt_gui_main.py +700 -0
  27. beratools/gui/main.py +27 -0
  28. beratools/gui/tool_widgets.py +730 -0
  29. beratools/tools/__init__.py +7 -0
  30. beratools/tools/canopy_threshold_relative.py +769 -0
  31. beratools/tools/centerline.py +127 -0
  32. beratools/tools/check_seed_line.py +48 -0
  33. beratools/tools/common.py +622 -0
  34. beratools/tools/line_footprint_absolute.py +203 -0
  35. beratools/tools/line_footprint_fixed.py +480 -0
  36. beratools/tools/line_footprint_functions.py +884 -0
  37. beratools/tools/line_footprint_relative.py +75 -0
  38. beratools/tools/tool_template.py +72 -0
  39. beratools/tools/vertex_optimization.py +57 -0
  40. beratools-0.1.0.dist-info/METADATA +134 -0
  41. beratools-0.1.0.dist-info/RECORD +44 -0
  42. beratools-0.1.0.dist-info/WHEEL +4 -0
  43. beratools-0.1.0.dist-info/entry_points.txt +2 -0
  44. beratools-0.1.0.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,489 @@
1
+ """
2
+ Copyright (C) 2025 Applied Geospatial Research Group.
3
+
4
+ This script is licensed under the GNU General Public License v3.0.
5
+ See <https://gnu.org/licenses/gpl-3.0> for full license details.
6
+
7
+ Author: Richard Zeng
8
+
9
+ Description:
10
+ This script is part of the BERA Tools.
11
+ Webpage: https://github.com/appliedgrg/beratools
12
+
13
+ The purpose of this script is to provide common algorithms
14
+ and utility functions/classes.
15
+ """
16
+
17
+ import math
18
+ import tempfile
19
+ from pathlib import Path
20
+
21
+ import geopandas as gpd
22
+ import numpy as np
23
+ import pyproj
24
+ import rasterio
25
+ import shapely
26
+ import shapely.affinity as sh_aff
27
+ import shapely.geometry as sh_geom
28
+ import shapely.ops as sh_ops
29
+ import skimage.graph as sk_graph
30
+ from osgeo import gdal
31
+ from scipy import ndimage
32
+
33
+ import beratools.core.algo_cost as algo_cost
34
+ import beratools.core.constants as bt_const
35
+
36
+ gpd.options.io_engine = "pyogrio"
37
+ DISTANCE_THRESHOLD = 2 # 1 meter for intersection neighborhood
38
+
39
+
40
def process_single_item(cls_obj):
    """
    Run a class object's compute() method (universal multiprocessing worker).

    Args:
        cls_obj: Class object to be processed

    Returns:
        cls_obj: Class object after processing, or None when compute() raised

    """
    try:
        cls_obj.compute()
    except Exception as e:
        import traceback

        print(f"❌ Exception during compute() for object: {e}")
        traceback.print_exc()
        return None

    return cls_obj
60
+
61
+
62
def read_geospatial_file(file_path, layer=None):
    """
    Read a geospatial file, clean the geometries and return a GeoDataFrame.

    Args:
        file_path (str): The path to the geospatial file (e.g., .shp, .gpkg).
        layer (str, optional): The specific layer to read if the file is
        multi-layered (e.g., GeoPackage).

    Returns:
        GeoDataFrame: The cleaned GeoDataFrame containing the data from the file
        with valid geometries only.
        None: If there is an error reading the file or layer.

    """
    try:
        # only pass layer= through when one was requested
        read_kwargs = {} if layer is None else {"layer": layer}
        gdf = gpd.read_file(file_path, **read_kwargs)

        # drop invalid/missing/empty geometries, then tag rows with a
        # temporary unique id used by downstream tools
        gdf = clean_geometries(gdf)
        gdf["BT_UID"] = range(len(gdf))
        return gdf
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        return None
93
+
94
+
95
def has_multilinestring(gdf):
    """Return True when at least one geometry in *gdf* is a MultiLineString."""
    for geom in gdf.geometry:
        if isinstance(geom, sh_geom.MultiLineString):
            return True
    return False
100
+
101
+
102
def clean_geometries(gdf):
    """
    Remove rows with invalid, None, or empty geometries from the GeoDataFrame.

    Args:
        gdf (GeoDataFrame): The GeoDataFrame to clean.

    Returns:
        GeoDataFrame: The cleaned GeoDataFrame with valid, non-null,
        and non-empty geometries.

    """
    # keep only valid geometries, then drop missing ones, then drop empties
    gdf = gdf[gdf.geometry.is_valid]
    gdf = gdf[~gdf.geometry.isna()]
    non_empty = gdf.geometry.apply(lambda geom: not geom.is_empty)
    return gdf[non_empty]
119
+
120
+
121
def clean_line_geometries(line_gdf):
    """Drop missing/empty geometries and lines shorter than SMALL_BUFFER."""
    if line_gdf is None or line_gdf.empty:
        return line_gdf

    keep = ~line_gdf.geometry.isna() & ~line_gdf.geometry.is_empty
    line_gdf = line_gdf[keep]
    return line_gdf[line_gdf.geometry.length > bt_const.SMALL_BUFFER]
132
+
133
+
134
def prepare_lines_gdf(file_path, layer=None, proc_segments=True):
    """
    Split lines at vertices or return original rows.

    It handles MultiLineString by exploding into individual LineStrings.

    Args:
        file_path (str): path to the line vector file.
        layer (str, optional): layer name for multi-layer sources.
        proc_segments (bool): if True, split each line into two-point segments.

    Returns:
        list[GeoDataFrame]: one single-row GeoDataFrame per line/segment.

    """
    gdf = read_geospatial_file(file_path, layer=layer)
    if gdf is None:
        # file could not be read; nothing to prepare
        return []

    # Explode MultiLineStrings into individual LineStrings
    if has_multilinestring(gdf):
        gdf = gdf.explode(index_parts=False)

    split_gdf_list = []
    attr_cols = [col for col in gdf.columns if col != "geometry"]

    # iterrows preserves original column names; itertuples would mangle
    # names that are not valid Python identifiers (e.g. containing spaces),
    # which breaks attribute access
    for _, row in gdf.iterrows():
        line = row.geometry

        if proc_segments:
            # split the line into segments between consecutive vertices
            coords = list(line.coords)
            for i in range(len(coords) - 1):
                segment = sh_geom.LineString([coords[i], coords[i + 1]])
                attributes = {col: row[col] for col in attr_cols}
                split_gdf_list.append(
                    gpd.GeoDataFrame([attributes], geometry=[segment], crs=gdf.crs)
                )
        else:
            # keep the original line as a single-row GeoDataFrame
            attributes = {col: row[col] for col in attr_cols}
            split_gdf_list.append(
                gpd.GeoDataFrame([attributes], geometry=[line], crs=gdf.crs)
            )

    return split_gdf_list
173
+
174
+
175
+ # TODO use function from common
176
def morph_raster(corridor_thresh, canopy_raster, exp_shk_cell, cell_size_x):
    """
    Morphologically clean the corridor raster (stamp CC and max line width).

    Classifies cells where corridor + canopy is zero, optionally expands and
    shrinks the classification (FLM Expand/Shrink equivalent), then applies
    a boundary clean and returns the cleaned raster.
    """
    combined = corridor_thresh + canopy_raster
    raster_class = np.ma.where(combined == 0, 1, 0).data

    if exp_shk_cell > 0 and cell_size_x < 1:
        # FLM original Expand then Shrink equivalent (morphological closing)
        win = int(exp_shk_cell * 2 + 1)
        expanded = ndimage.grey_dilation(raster_class, size=(win, win))
        file_shrink = ndimage.grey_erosion(expanded, size=(win, win))
    else:
        if bt_const.BT_DEBUGGING:
            print("No Expand And Shrink cell performed.")
        file_shrink = raster_class

    # Boundary clean; sigma=0 keeps values (kept for parity with FLM)
    return ndimage.gaussian_filter(file_shrink, sigma=0, mode="nearest")
200
+
201
+
202
def closest_point_to_line(point, line):
    """Project *point* onto *line* and return the nearest point on the line."""
    if not line:
        return None

    return line.interpolate(line.project(sh_geom.Point(point)))
208
+
209
+
210
def line_coord_list(line):
    """Return the vertices of *line* as a list of shapely Points."""
    point_list = []
    try:
        # walk every vertex of the line and collect it as a 2D Point
        for coord in line.coords:
            if coord:
                point_list.append(sh_geom.Point(coord[0], coord[1]))
    except Exception as e:
        print(e)

    return point_list
221
+
222
+
223
def intersection_of_lines(line_1, line_2):
    """
    Only LINESTRING is dealt with for now.

    Args:
        line_1 : first shapely line
        line_2 : second shapely line

    Returns:
        sh_geom.Point: intersection point (centroid when the intersection is
        a collection or line), the raw intersection otherwise, or None.

    """
    # intersection collection, may contain points and lines
    inter = None
    if line_1 and line_2:
        inter = line_1.intersection(line_2)

    # TODO: intersection may return GeometryCollection, LineString or MultiLineString
    if inter:
        # isinstance is the idiomatic type check and also covers subclasses
        if isinstance(
            inter,
            (sh_geom.GeometryCollection, sh_geom.LineString, sh_geom.MultiLineString),
        ):
            return inter.centroid

    return inter
250
+
251
+
252
def get_angle(line, vertex_index):
    """
    Calculate the angle of the first or last segment.

    # TODO: use np.arctan2 instead of np.arctan

    Args:
        line: LineString
        vertex_index: 0 or -1 of the line vertices. Consider the multipart.

    """
    pts = line_coord_list(line)

    if vertex_index == 0:
        pt_1, pt_2 = pts[0], pts[1]
    elif vertex_index == -1:
        pt_1, pt_2 = pts[-1], pts[-2]

    delta_x = pt_2.x - pt_1.x
    delta_y = pt_2.y - pt_1.y

    if np.isclose(pt_1.x, pt_2.x):
        # (near-)vertical segment: pick +/- pi/2 by vertical direction
        angle = np.pi / 2
        if delta_y > 0:
            angle = np.pi / 2
        elif delta_y < 0:
            angle = -np.pi / 2
    else:
        angle = np.arctan(delta_y / delta_x)

    # arctan is in range [-pi/2, pi/2]; regulate all angles to [-pi/2, 3*pi/2]
    if delta_x < 0:
        angle += np.pi  # second or third quadrant

    return angle
288
+
289
+
290
def points_are_close(pt1, pt2):
    """Return True when pt1 and pt2 are within DISTANCE_THRESHOLD on both axes."""
    dx = abs(pt1.x - pt2.x)
    dy = abs(pt1.y - pt2.y)
    return dx < DISTANCE_THRESHOLD and dy < DISTANCE_THRESHOLD
295
+
296
+
297
def generate_raster_footprint(in_raster, latlon=True):
    """
    Compute the footprint polygon of a raster via gdal.Footprint.

    Rasters larger than 1024 px on their longest side are first downsampled
    to a temporary overview image before the footprint is traced.

    Args:
        in_raster (str): path to the input raster.
        latlon (bool): when True, reproject the footprint to EPSG:4326.

    Returns:
        shapely geometry of the raster footprint.

    """
    inter_img = "image_overview.tif"

    src_ds = gdal.Open(in_raster)
    width, height = src_ds.RasterXSize, src_ds.RasterYSize
    src_crs = src_ds.GetSpatialRef().ExportToWkt()

    geom = None
    with tempfile.TemporaryDirectory() as tmp_folder:
        if bt_const.BT_DEBUGGING:
            print("Temporary folder: {}".format(tmp_folder))

        if max(width, height) <= 1024:
            # small raster: footprint the source file directly
            inter_img = in_raster
        else:
            # downsample the longest side to 1024 px (0 keeps aspect ratio)
            if width >= height:
                options = gdal.TranslateOptions(width=1024, height=0)
            else:
                options = gdal.TranslateOptions(width=0, height=1024)

            inter_img = Path(tmp_folder).joinpath(inter_img).as_posix()
            gdal.Translate(inter_img, src_ds, options=options)

        # gdal.Footprint with format="GeoJSON" returns a GeoJSON-like dict;
        # the first feature holds the footprint geometry
        shapes = gdal.Footprint(None, inter_img, dstSRS=src_crs, format="GeoJSON")
        target_feat = shapes["features"][0]
        geom = sh_geom.shape(target_feat["geometry"])

    if latlon:
        # reproject the footprint from the source CRS to WGS84
        out_crs = pyproj.CRS("EPSG:4326")
        transformer = pyproj.Transformer.from_crs(pyproj.CRS(src_crs), out_crs)

        geom = sh_ops.transform(transformer.transform, geom)

    return geom
331
+
332
+
333
def save_raster_to_file(in_raster_mem, in_meta, out_raster_file):
    """
    Save raster matrix in memory to file.

    Args:
        in_raster_mem: numpy raster
        in_meta: rasterio profile/meta used to create the output file
        out_raster_file: output raster file path

    """
    with rasterio.open(out_raster_file, "w", **in_meta) as sink:
        # write the matrix as band 1
        sink.write(in_raster_mem, indexes=1)
345
+
346
+
347
def generate_perpendicular_line_precise(points, offset=20):
    """
    Generate a perpendicular line to the input line at the given point.

    Args:
        points (list[Point]): 2 or 3 consecutive vertices; the perpendicular
            is generated at points[1].
        offset (float): The length of the perpendicular line.

    Returns:
        shapely.geometry.LineString: The generated perpendicular line,
        or None when the number of points is not 2 or 3.

    """
    # Compute the angle of the line
    if len(points) not in [2, 3]:
        return None

    center = points[1]
    perp_line = None

    if len(points) == 2:
        # two points: rotate a horizontal segment centered at `center`
        # by the segment angle + 90 degrees
        head = points[0]
        tail = points[1]

        delta_x = head.x - tail.x
        delta_y = head.y - tail.y
        angle = 0.0

        if math.isclose(delta_x, 0.0):
            # vertical segment; avoid division by zero
            angle = math.pi / 2
        else:
            angle = math.atan(delta_y / delta_x)

        start = [center.x + offset / 2.0, center.y]
        end = [center.x - offset / 2.0, center.y]
        line = sh_geom.LineString([start, end])
        perp_line = sh_aff.rotate(line, angle + math.pi / 2.0, origin=center, use_radians=True)
    elif len(points) == 3:
        # three points: orient the perpendicular along the bisector of the
        # angle formed at `center` by the two adjacent vertices
        head = points[0]
        tail = points[2]

        angle_1 = _line_angle(center, head)
        angle_2 = _line_angle(center, tail)
        angle_diff = (angle_2 - angle_1) / 2.0
        # point at half the offset along the center->head direction
        head_new = sh_geom.Point(
            center.x + offset / 2.0 * math.cos(angle_1),
            center.y + offset / 2.0 * math.sin(angle_1),
        )
        if head.has_z:
            head_new = shapely.force_3d(head_new)
        try:
            # rotate the half-segment onto the bisector, then mirror it to
            # obtain the opposite half; join the two endpoints
            perp_seg_1 = sh_geom.LineString([center, head_new])
            perp_seg_1 = sh_aff.rotate(perp_seg_1, angle_diff, origin=center, use_radians=True)
            perp_seg_2 = sh_aff.rotate(perp_seg_1, math.pi, origin=center, use_radians=True)
            perp_line = sh_geom.LineString([list(perp_seg_1.coords)[1], list(perp_seg_2.coords)[1]])
        except Exception as e:
            print(e)

    return perp_line
405
+
406
+
407
def _line_angle(point_1, point_2):
    """
    Calculate the angle of the line.

    Args:
        point_1, point_2: start and end points of shapely line

    Returns:
        float: angle in radians in (-pi, pi], as returned by math.atan2.

    """
    return math.atan2(point_2.y - point_1.y, point_2.x - point_1.x)
420
+
421
+
422
def corridor_raster(raster_clip, out_meta, source, destination, cell_size, corridor_threshold):
    """
    Calculate corridor raster.

    Args:
        raster_clip (raster): cost raster (2D, or 3D with a single band)
        out_meta : raster file meta
        source (list of point tuple(s)): start point in row/col
        destination (list of point tuple(s)): end point in row/col
        cell_size (tuple): (cell_size_x, cell_size_y)
        corridor_threshold (double): normalized-cost cutoff

    Returns:
        corridor raster (1.0 at/above threshold, 0.0 below), or None on error

    """
    try:
        # drop the band axis if present; NaNs are replaced with
        # BT_NODATA_COST in place as a workaround for MCP_Geometric
        if len(raster_clip.shape) > 2:
            raster_clip = np.squeeze(raster_clip, axis=0)

        algo_cost.remove_nan_from_array_refactor(raster_clip)

        # accumulated cost surface from the source point(s)
        mcp_source = sk_graph.MCP_Geometric(raster_clip, sampling=cell_size)
        source_cost_acc = mcp_source.find_costs(source)[0]
        del mcp_source

        # accumulated cost surface from the destination point(s)
        mcp_dest = sk_graph.MCP_Geometric(raster_clip, sampling=cell_size)
        dest_cost_acc = mcp_dest.find_costs(destination)[0]

        # corridor is the sum of both accumulated surfaces
        corridor = np.ma.masked_invalid(source_cost_acc + dest_cost_acc)

        # normalize by the corridor minimum (fallback 0.5), then threshold
        corr_min = np.ma.min(corridor)
        corr_min = float(corr_min) if corr_min is not None else 0.5

        corridor_norm = corridor - corr_min
        corridor_thresh_cl = np.ma.where(corridor_norm >= corridor_threshold, 1.0, 0.0)
    except Exception as e:
        print(e)
        print("corridor_raster: Exception occurred.")
        return None

    return corridor_thresh_cl
474
+
475
+
476
def remove_holes(geom):
    """Return *geom* with interior rings removed (Polygon/MultiPolygon only)."""
    if geom.geom_type == "Polygon":
        # rebuild from the exterior ring only when holes exist
        return sh_geom.Polygon(geom.exterior) if geom.interiors else geom
    if geom.geom_type == "MultiPolygon":
        filled = [
            sh_geom.Polygon(poly.exterior) if poly.interiors else poly
            for poly in geom.geoms
        ]
        return sh_geom.MultiPolygon(filled)
    return geom  # other geometry types pass through unchanged
@@ -0,0 +1,185 @@
1
+ """
2
+ Copyright (C) 2025 Applied Geospatial Research Group.
3
+
4
+ This script is licensed under the GNU General Public License v3.0.
5
+ See <https://gnu.org/licenses/gpl-3.0> for full license details.
6
+
7
+ Author: Richard Zeng
8
+
9
+ Description:
10
+ This script is part of the BERA Tools.
11
+ Webpage: https://github.com/appliedgrg/beratools
12
+
13
+ This file hosts cost raster related functions.
14
+ """
15
+
16
+ import numpy as np
17
+ import scipy
18
+
19
+ import beratools.core.constants as bt_const
20
+
21
+
22
def cost_raster(
    in_raster,
    meta,
    tree_radius=2.5,
    canopy_ht_threshold=2.5,
    max_line_dist=2.5,
    canopy_avoid=0.4,
    cost_raster_exponent=1.5,
):
    """
    General version of cost_raster.

    To be merged later: variables and consistent nodata solution

    """
    if len(in_raster.shape) > 2:
        in_raster = np.squeeze(in_raster, axis=0)

    # clamp canopy_avoid to [0, 1]
    avoidance = min(1, max(0, canopy_avoid))
    cell_x = meta["transform"][0]
    cell_y = -meta["transform"][4]

    # focal window sized from the tree radius in cells
    kernel_radius = int(tree_radius / cell_x)
    kernel = circle_kernel_refactor(2 * kernel_radius + 1, kernel_radius)
    dyn_canopy_ndarray = dyn_np_cc_map(in_raster, canopy_ht_threshold)

    cc_std, cc_mean = cost_focal_stats(dyn_canopy_ndarray, kernel)
    cc_smooth = cost_norm_dist_transform(dyn_canopy_ndarray, max_line_dist, [cell_x, cell_y])

    cost_clip = dyn_np_cost_raster_refactor(
        dyn_canopy_ndarray, cc_mean, cc_std, cc_smooth, avoidance, cost_raster_exponent
    )

    # TODO use nan or BT_DATA?
    nodata_mask = in_raster == bt_const.BT_NODATA
    cost_clip[nodata_mask] = np.nan
    dyn_canopy_ndarray[nodata_mask] = np.nan

    return cost_clip, dyn_canopy_ndarray
60
+
61
+
62
def remove_nan_from_array_refactor(matrix, replacement_value=bt_const.BT_NODATA_COST):
    """Replace NaN entries of *matrix* in place with *replacement_value*."""
    np.copyto(matrix, replacement_value, where=np.isnan(matrix))
65
+
66
+
67
def dyn_np_cc_map(in_chm, canopy_ht_threshold):
    """
    Create a new canopy raster.

    MaskedArray based on the threshold comparison of in_chm (canopy height model)
    with canopy_ht_threshold. It assigns 1.0 where the condition is True (canopy)
    and 0.0 where the condition is False (non-canopy).

    """
    is_canopy = in_chm >= canopy_ht_threshold
    return np.ma.where(is_canopy, 1.0, 0.0).astype(float)
78
+
79
+
80
def cost_focal_stats(canopy_ndarray, kernel):
    """
    Compute NaN-aware focal standard deviation and mean of the canopy raster.

    Args:
        canopy_ndarray (np.ma.MaskedArray): canopy classification raster
        kernel (ndarray): focal window footprint

    Returns:
        tuple: (std_array, mean_array)

    """
    # masked cells become NaN so the nan-aware reducers skip them
    in_ndarray = np.ma.where(canopy_ndarray.mask, np.nan, canopy_ndarray)

    # np.nanmean/np.nanstd are passed directly as the filter callables
    mean_array = scipy.ndimage.generic_filter(
        in_ndarray, np.nanmean, footprint=kernel, mode="nearest"
    )
    std_array = scipy.ndimage.generic_filter(
        in_ndarray, np.nanstd, footprint=kernel, mode="nearest"
    )

    return std_array, mean_array
96
+
97
+
98
def cost_norm_dist_transform(canopy_ndarray, max_line_dist, sampling):
    """Compute a distance-based cost map based on the proximity of valid data points."""
    # regular array with masked cells turned into NaN
    filled = canopy_ndarray.filled(np.nan)

    # Euclidean distance from each valid cell to the nearest NaN cell
    euc_dist = scipy.ndimage.distance_transform_edt(
        np.logical_not(np.isnan(filled)), sampling=sampling
    )

    # re-apply the mask as NaN
    euc_dist[canopy_ndarray.mask] = np.nan

    # invert and normalize: 1 at distance 0, tapering to 0 at max_line_dist
    max_dist = float(max_line_dist)
    normalized = max_dist - euc_dist
    normalized[normalized <= 0.0] = 0.0
    return normalized / max_dist
117
+
118
+
119
def dyn_np_cost_raster_refactor(canopy_ndarray, cc_mean, cc_std, cc_smooth, avoidance, cost_raster_exponent):
    """
    Combine focal statistics and distance smoothing into the final cost raster.

    Canopy cells receive maximal cost 1; open cells get a blend of the
    normalized (mean - std)/(mean + std) ratio and the distance-based
    smoothing, weighted by the avoidance factor. The result is
    exponentiated and raised to cost_raster_exponent.
    """
    lower = cc_mean - cc_std
    upper = cc_mean + cc_std

    # safe division: 0 where the denominator is zero
    ratio = np.divide(
        lower,
        upper,
        where=upper != 0,
        out=np.zeros(lower.shape, dtype=float),
    )

    # map the ratio from [-1, 1] onto [0, 1]
    scaled = (1 + ratio) / 2

    # zero out cells where mean + std is not positive
    scaled = np.where(upper <= 0, 0, scaled)

    # blend the focal ratio with the distance smoothing by the avoidance weight
    blended = scaled * (1 - avoidance) + cc_smooth * avoidance

    # canopy cells (classification == 1) are maximal cost
    cost = np.where(canopy_ndarray.data == 1, 1, blended)

    # exponential transform, then raise to the configured exponent
    return np.power(np.exp(cost), float(cost_raster_exponent))
152
+
153
+
154
def circle_kernel_refactor(size, radius):
    """
    Create a circular (disk) kernel using NumPy.

    Args:
        size : kernel size (width and height in cells)
        radius : radius of the circle in cells

    Returns:
        kernel (ndarray): float array with 1.0 inside the circle, 0.0 outside.

    Examples:
        kernel_scipy = create_circle_kernel_scipy(17, 8)
        will replicate xarray-spatial kernel
        cell_x = 0.3
        cell_y = 0.3
        tree_radius = 2.5
        convolution.circle_kernel(cell_x, cell_y, tree_radius)

    """
    # distance of every grid cell from the kernel center
    center = (size - 1) / 2
    rows, cols = np.ogrid[:size, :size]
    dist = np.hypot(cols - center, rows - center)

    # cells within the radius form the disk
    return (dist <= radius).astype(float)