tobac-1.6.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. tobac/__init__.py +112 -0
  2. tobac/analysis/__init__.py +31 -0
  3. tobac/analysis/cell_analysis.py +628 -0
  4. tobac/analysis/feature_analysis.py +212 -0
  5. tobac/analysis/spatial.py +619 -0
  6. tobac/centerofgravity.py +226 -0
  7. tobac/feature_detection.py +1758 -0
  8. tobac/merge_split.py +324 -0
  9. tobac/plotting.py +2321 -0
  10. tobac/segmentation/__init__.py +10 -0
  11. tobac/segmentation/watershed_segmentation.py +1316 -0
  12. tobac/testing.py +1179 -0
  13. tobac/tests/segmentation_tests/test_iris_xarray_segmentation.py +0 -0
  14. tobac/tests/segmentation_tests/test_segmentation.py +1183 -0
  15. tobac/tests/segmentation_tests/test_segmentation_time_pad.py +104 -0
  16. tobac/tests/test_analysis_spatial.py +1109 -0
  17. tobac/tests/test_convert.py +265 -0
  18. tobac/tests/test_datetime.py +216 -0
  19. tobac/tests/test_decorators.py +148 -0
  20. tobac/tests/test_feature_detection.py +1321 -0
  21. tobac/tests/test_generators.py +273 -0
  22. tobac/tests/test_import.py +24 -0
  23. tobac/tests/test_iris_xarray_match_utils.py +244 -0
  24. tobac/tests/test_merge_split.py +351 -0
  25. tobac/tests/test_pbc_utils.py +497 -0
  26. tobac/tests/test_sample_data.py +197 -0
  27. tobac/tests/test_testing.py +747 -0
  28. tobac/tests/test_tracking.py +714 -0
  29. tobac/tests/test_utils.py +650 -0
  30. tobac/tests/test_utils_bulk_statistics.py +789 -0
  31. tobac/tests/test_utils_coordinates.py +328 -0
  32. tobac/tests/test_utils_internal.py +97 -0
  33. tobac/tests/test_xarray_utils.py +232 -0
  34. tobac/tracking.py +613 -0
  35. tobac/utils/__init__.py +27 -0
  36. tobac/utils/bulk_statistics.py +360 -0
  37. tobac/utils/datetime.py +184 -0
  38. tobac/utils/decorators.py +540 -0
  39. tobac/utils/general.py +753 -0
  40. tobac/utils/generators.py +87 -0
  41. tobac/utils/internal/__init__.py +2 -0
  42. tobac/utils/internal/coordinates.py +430 -0
  43. tobac/utils/internal/iris_utils.py +462 -0
  44. tobac/utils/internal/label_props.py +82 -0
  45. tobac/utils/internal/xarray_utils.py +439 -0
  46. tobac/utils/mask.py +364 -0
  47. tobac/utils/periodic_boundaries.py +419 -0
  48. tobac/wrapper.py +244 -0
  49. tobac-1.6.2.dist-info/METADATA +154 -0
  50. tobac-1.6.2.dist-info/RECORD +53 -0
  51. tobac-1.6.2.dist-info/WHEEL +5 -0
  52. tobac-1.6.2.dist-info/licenses/LICENSE +29 -0
  53. tobac-1.6.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1758 @@
1
+ """First step toward working with *tobac*. Detects features from input 2D or 3D data.
2
+
3
+ This module can work with any two-dimensional or three-dimensional field.
4
+ To identify the features, contiguous regions above or
5
+ below a threshold are determined and labelled individually.
6
+ To describe the specific location of the feature at a
7
+ specific point in time, different spatial properties
8
+ are used to describe the identified region. [1]_
9
+
10
+ References
11
+ ----------
12
+ .. [1] Heikenfeld, M., Marinescu, P. J., Christensen, M.,
13
+ Watson-Parris, D., Senf, F., van den Heever, S. C.
14
+ & Stier, P. (2019). tobac 1.2: towards a flexible
15
+ framework for tracking and analysis of clouds in
16
+ diverse datasets. Geoscientific Model Development,
17
+ 12(11), 4551-4570.
18
+ """
19
+
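A minimal usage sketch of this module's public entry point (illustrative, not part of the package source; `field` stands for a hypothetical (time, y, x) xarray.DataArray on an assumed 1 km grid):

    import tobac

    # detect features at increasingly strict thresholds of the tracked variable
    features = tobac.feature_detection_multithreshold(
        field,
        dxy=1000.0,                  # grid spacing in meters
        threshold=[4.0, 6.0, 8.0],   # least to most extreme
        target="maximum",
        n_min_threshold=10,          # require at least 10 contiguous pixels
    )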
20
+ from __future__ import annotations
21
+ import logging
22
+ import warnings
23
+
24
+ from typing import Optional, Union, Callable, Any
25
+ from typing_extensions import Literal
26
+
27
+ import numpy as np
28
+ import pandas as pd
29
+ import xarray as xr
30
+ from scipy.spatial import KDTree
31
+ from sklearn.neighbors import BallTree
32
+
33
+ from tobac.utils import decorators
34
+ from tobac.utils import get_statistics
35
+ from tobac.utils import internal as internal_utils
36
+ from tobac.utils import periodic_boundaries as pbc_utils
37
+ from tobac.utils.general import spectral_filtering
38
+ from tobac.utils.generators import field_and_features_over_time
39
+
40
+
41
+ def feature_position(
42
+ hdim1_indices: list[int],
43
+ hdim2_indices: list[int],
44
+ vdim_indices: Union[list[int], None] = None,
45
+ region_small: np.ndarray = None,
46
+ region_bbox: Union[list[int], tuple[int]] = None,
47
+ track_data: np.ndarray = None,
48
+ threshold_i: float = None,
49
+ position_threshold: Literal[
50
+ "center", "extreme", "weighted_diff", "weighted abs"
51
+ ] = "center",
52
+ target: Optional[Literal["maximum", "minimum"]] = None,
53
+ PBC_flag: Literal["none", "hdim_1", "hdim_2", "both"] = "none",
54
+ hdim1_min: int = 0,
55
+ hdim1_max: int = 0,
56
+ hdim2_min: int = 0,
57
+ hdim2_max: int = 0,
58
+ ) -> tuple[float]:
59
+ """Determine feature position with regard to the horizontal
60
+ dimensions in pixels from the identified region above
61
+ threshold values.
62
+
63
+ :hidden:
64
+
65
+ Parameters
66
+ ----------
67
+ hdim1_indices : list
68
+ indices of pixels in region along first horizontal
69
+ dimension
70
+
71
+ hdim2_indices : list
72
+ indices of pixels in region along second horizontal
73
+ dimension
74
+
75
+ vdim_indices : list, optional
76
+ List of indices of feature along optional vdim (typically ```z```)
77
+
78
+ region_small : 2D or 3D array-like
79
+ A True/False array containing True where the threshold
80
+ is met and False where the threshold isn't met. This
81
+ array should be the size specified by region_bbox,
82
+ and can be a subset of the overall input array
83
+ (i.e., ```track_data```).
84
+
85
+ region_bbox : list or tuple with length of 4 or 6
86
+ The coordinates that region_small occupies within the total track_data
87
+ array. This is in the order that the coordinates come from the
88
+ ```get_label_props_in_dict``` function. For 2D data, this should be:
89
+ (hdim1 start, hdim 2 start, hdim 1 end, hdim 2 end). For 3D data, this
90
+ is: (vdim start, hdim1 start, hdim 2 start, vdim end, hdim 1 end, hdim 2 end).
91
+
92
+ track_data : 2D or 3D array-like
93
+ 2D or 3D array containing the data
94
+
95
+ threshold_i : float
96
+ The threshold value that we are testing against
97
+
98
+ position_threshold : {'center', 'extreme', 'weighted_diff', 'weighted_abs'}
99
+ How to select the single point position from our data.
100
+ 'center' picks the geometrical centre of the region,
101
+ and is typically not recommended. 'extreme' picks the
102
+ maximum or minimum value inside the region (max/min set by
103
+ ```target```). 'weighted_diff' picks the centre of the
104
+ region weighted by the distance from the threshold value.
105
+ 'weighted_abs' picks the centre of the region weighted by
106
+ the absolute values of the field.
107
+
108
+ target : {'maximum', 'minimum'}
109
+ Used only when position_threshold is set to 'extreme',
110
+ this sets whether it is looking for maxima or minima.
111
+
112
+ PBC_flag : {'none', 'hdim_1', 'hdim_2', 'both'}
113
+ Sets whether to use periodic boundaries, and if so in which directions.
114
+ 'none' means that we do not have periodic boundaries
115
+ 'hdim_1' means that we are periodic along hdim1
116
+ 'hdim_2' means that we are periodic along hdim2
117
+ 'both' means that we are periodic along both horizontal dimensions
118
+
119
+ hdim1_min : int
120
+ Minimum real array index of the first horizontal dimension (for PBCs)
121
+
122
+ hdim1_max : int
123
+ Maximum real array index of the first horizontal dimension (for PBCs)
124
+ Note that this coordinate is INCLUSIVE, meaning that this is
125
+ the maximum coordinate value, and it is not a length.
126
+
127
+ hdim2_min : int
128
+ Minimum real array index of the second horizontal dimension (for PBCs)
129
+
130
+ hdim2_max : int
131
+ Maximum real array index of the second horizontal dimension (for PBCs)
132
+ Note that this coordinate is INCLUSIVE, meaning that this is
133
+ the maximum coordinate value, and it is not a length.
134
+
135
+ Returns
136
+ -------
137
+ 2-element or 3-element tuple of floats
138
+ If input data is 2D, this will be a 2-element tuple of floats,
139
+ where the first element is the feature position along the first
140
+ horizontal dimension and the second element is the feature position
141
+ along the second horizontal dimension.
142
+ If input data is 3D, this will be a 3-element tuple of floats, where
143
+ the first element is the feature position along the vertical dimension
144
+ and the second two elements are the feature position on the first and
145
+ second horizontal dimensions.
146
+ Note for PBCs: this point *can* be >hdim1_max or hdim2_max if the
147
+ point is between hdim1_max and hdim1_min. For example, if a feature
148
+ lies exactly between hdim1_max and hdim1_min, the output could be
149
+ between hdim1_max and hdim1_max+1. While a value between hdim1_min-1
150
+ and hdim1_min would also be valid, we choose to overflow on the max side of things.
151
+ Notes
152
+ -----
153
+ """
154
+
155
+ # First, if necessary, run PBC processing.
156
+ # processing of PBC indices
157
+ # checks to see if minimum and maximum values are present in dimensional array
158
+ # then if true, adds max value to any indices past the halfway point of their
159
+ # respective dimension. this, in essence, shifts the set of points to the high side.
160
+ pbc_options = ["hdim_1", "hdim_2", "both"]
161
+
162
+ if len(region_bbox) == 4:
163
+ # 2D case
164
+ is_3D = False
165
+ track_data_region = track_data[
166
+ region_bbox[0] : region_bbox[2], region_bbox[1] : region_bbox[3]
167
+ ]
168
+ elif len(region_bbox) == 6:
169
+ # 3D case
170
+ is_3D = True
171
+ track_data_region = track_data[
172
+ region_bbox[0] : region_bbox[3],
173
+ region_bbox[1] : region_bbox[4],
174
+ region_bbox[2] : region_bbox[5],
175
+ ]
176
+ else:
177
+ raise ValueError("region_bbox must have 4 or 6 elements.")
178
+ # whether or not to run the means at the end
179
+ run_mean = False
180
+ if position_threshold == "center":
181
+ # get position as geometrical centre of identified region:
182
+
183
+ hdim1_weights = np.ones(np.size(hdim1_indices))
184
+ hdim2_weights = np.ones(np.size(hdim2_indices))
185
+ if is_3D:
186
+ vdim_weights = np.ones(np.size(vdim_indices))
187
+
188
+ run_mean = True
189
+
190
+ elif position_threshold == "extreme":
191
+ # get position as max/min position inside the identified region:
192
+ if target == "maximum":
193
+ index = np.argmax(track_data_region[region_small])
194
+ if target == "minimum":
195
+ index = np.argmin(track_data_region[region_small])
196
+ hdim1_index = hdim1_indices[index]
197
+ hdim2_index = hdim2_indices[index]
198
+ if is_3D:
199
+ vdim_index = vdim_indices[index]
200
+
201
+ elif position_threshold == "weighted_diff":
202
+ # get position as centre of identified region, weighted by difference from the threshold:
203
+ weights = np.abs(track_data_region[region_small] - threshold_i)
204
+ if np.sum(weights) == 0:
205
+ weights = None
206
+ hdim1_weights = weights
207
+ hdim2_weights = weights
208
+ if is_3D:
209
+ vdim_weights = weights
210
+
211
+ run_mean = True
212
+
213
+ elif position_threshold == "weighted_abs":
214
+ # get position as centre of identified region, weighted by absolute values of the field:
215
+ weights = np.abs(track_data_region[region_small])
216
+ if np.sum(weights) == 0:
217
+ weights = None
218
+ hdim1_weights = weights
219
+ hdim2_weights = weights
220
+ if is_3D:
221
+ vdim_weights = weights
222
+ run_mean = True
223
+
224
+ else:
225
+ raise ValueError(
226
+ "position_threshold must be center,extreme,weighted_diff or weighted_abs"
227
+ )
228
+
229
+ if run_mean:
230
+ if PBC_flag in ("hdim_1", "both"):
231
+ hdim1_index = pbc_utils.weighted_circmean(
232
+ hdim1_indices, weights=hdim1_weights, high=hdim1_max + 1, low=hdim1_min
233
+ )
234
+ hdim1_index = np.clip(hdim1_index, 0, hdim1_max + 1)
235
+ else:
236
+ hdim1_index = np.average(hdim1_indices, weights=hdim1_weights)
237
+ hdim1_index = np.clip(hdim1_index, 0, hdim1_max)
238
+ if PBC_flag in ("hdim_2", "both"):
239
+ hdim2_index = pbc_utils.weighted_circmean(
240
+ hdim2_indices, weights=hdim2_weights, high=hdim2_max + 1, low=hdim2_min
241
+ )
242
+ hdim2_index = np.clip(hdim2_index, 0, hdim2_max + 1)
243
+ else:
244
+ hdim2_index = np.average(hdim2_indices, weights=hdim2_weights)
245
+ hdim2_index = np.clip(hdim2_index, 0, hdim2_max)
246
+ if is_3D:
247
+ vdim_index = np.average(vdim_indices, weights=vdim_weights)
248
+
249
+ if is_3D:
250
+ return vdim_index, hdim1_index, hdim2_index
251
+ else:
252
+ return hdim1_index, hdim2_index
253
+
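A sketch of how feature_position selects a position, using made-up inputs (illustrative, not part of the package source; assumes this module's functions are importable):

    import numpy as np

    # 5x5 field with a 2x2 block of pixels exceeding a threshold of 3.0
    track_data = np.zeros((5, 5))
    track_data[1:3, 2:4] = [[4.0, 5.0], [4.5, 6.0]]
    hdim1_indices = [1, 1, 2, 2]                # row index of each region pixel
    hdim2_indices = [2, 3, 2, 3]                # column index of each region pixel
    region_bbox = (1, 2, 3, 4)                  # (hdim1 start, hdim2 start, hdim1 end, hdim2 end)
    region_small = np.ones((2, 2), dtype=bool)  # every bbox pixel belongs to the region

    # 'extreme' with target='maximum' returns the pixel holding the 6.0: (2, 3)
    feature_position(
        hdim1_indices, hdim2_indices,
        region_small=region_small, region_bbox=region_bbox,
        track_data=track_data, threshold_i=3.0,
        position_threshold="extreme", target="maximum",
        hdim1_max=4, hdim2_max=4,
    )

    # with periodic boundaries, 'center' uses a circular mean: pixels in rows 0 and 9
    # of a 10-row domain average to 9.5, overflowing on the max side as documented above
    region_small_pbc = np.zeros((10, 1), dtype=bool)
    region_small_pbc[[0, 9], 0] = True
    feature_position(
        [0, 9], [5, 5],
        region_small=region_small_pbc, region_bbox=(0, 5, 10, 6),
        track_data=np.ones((10, 10)), threshold_i=0.0,
        position_threshold="center",
        PBC_flag="hdim_1", hdim1_max=9, hdim2_max=9,
    )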
254
+
255
+ def test_overlap(
256
+ region_inner: list[tuple[int]], region_outer: list[tuple[int]]
257
+ ) -> bool:
258
+ """Test for overlap between two regions
259
+
260
+ :hidden:
261
+
262
+ Parameters
263
+ ----------
264
+ region_inner : list
265
+ list of 2-element tuples defining the indices of
266
+ all cells in the region
267
+
268
+ region_outer : list
269
+ list of 2-element tuples defining the indices of
270
+ all cells in the region
271
+
272
+ Returns
273
+ -------
274
+ overlap : bool
275
+ True if there are any shared points between the two
276
+ regions
277
+ """
278
+
279
+ overlap = frozenset(region_outer).isdisjoint(region_inner)
280
+ return not overlap
281
+
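For illustration (not part of the package source), test_overlap simply reports whether the two regions share any point:

    test_overlap([(2, 2), (2, 3)], [(2, 3), (4, 5)])   # -> True, (2, 3) is shared
    test_overlap([(0, 0)], [(2, 3), (4, 5)])           # -> False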
282
+
283
+ def remove_parents(
284
+ features_thresholds: pd.DataFrame,
285
+ regions_i: dict,
286
+ regions_old: dict,
287
+ strict_thresholding: bool = False,
288
+ ) -> pd.DataFrame:
289
+ """Remove parents of newly detected feature regions.
290
+
291
+ Remove features whose regions surround newly
292
+ detected feature regions.
293
+
294
+ :hidden:
295
+
296
+ Parameters
297
+ ----------
298
+ features_thresholds : pandas.DataFrame
299
+ Dataframe containing detected features.
300
+
301
+ regions_i : dict
302
+ Dictionary containing the regions greater/less than or equal to
303
+ threshold for the newly detected feature
304
+ (feature ids as keys).
305
+
306
+ regions_old : dict
307
+ Dictionary containing the regions greater/less than or equal to
308
+ threshold from previous threshold
309
+ (feature ids as keys).
310
+
311
+ strict_thresholding : bool, optional
312
+ If True, a feature can only be detected if all previous thresholds have been met.
313
+ Default is False.
314
+ Returns
315
+ -------
316
+ features_thresholds : pandas.DataFrame
317
+ Dataframe containing detected features excluding those
318
+ that are superseded by newly detected ones.
+
+ regions_old : dict
+ Updated dictionary of feature regions (feature ids as keys),
+ to be compared against at the next threshold step.
319
+ """
320
+
321
+ try:
322
+ all_curr_pts = np.concatenate([vals for idx, vals in regions_i.items()])
323
+ except ValueError:
324
+ # the case where there are no new regions
325
+ if strict_thresholding:
326
+ return features_thresholds, {}
327
+ else:
328
+ return features_thresholds, regions_old
329
+ try:
330
+ all_old_pts = np.concatenate([vals for idx, vals in regions_old.items()])
331
+ except ValueError:
332
+ # the case where there are no old regions
333
+ if strict_thresholding:
334
+ return (
335
+ features_thresholds[
336
+ ~features_thresholds["idx"].isin(list(regions_i.keys()))
337
+ ],
338
+ {},
339
+ )
340
+ else:
341
+ return features_thresholds, regions_i
342
+
343
+ old_feat_arr = np.empty((len(all_old_pts)))
344
+ curr_loc = 0
345
+ for idx_old in regions_old:
346
+ old_feat_arr[curr_loc : curr_loc + len(regions_old[idx_old])] = idx_old
347
+ curr_loc += len(regions_old[idx_old])
348
+
349
+ _, common_ix_new, common_ix_old = np.intersect1d(
350
+ all_curr_pts, all_old_pts, return_indices=True
351
+ )
352
+ list_remove = np.unique(old_feat_arr[common_ix_old])
353
+
354
+ if strict_thresholding:
355
+ new_feat_arr = np.empty((len(all_curr_pts)))
356
+ curr_loc = 0
357
+ for idx_new in regions_i:
358
+ new_feat_arr[curr_loc : curr_loc + len(regions_i[idx_new])] = idx_new
359
+ curr_loc += len(regions_i[idx_new])
360
+ regions_i_overlap = np.unique(new_feat_arr[common_ix_new])
361
+ no_prev_feature = np.array(list(regions_i.keys()))[
362
+ np.logical_not(np.isin(list(regions_i.keys()), regions_i_overlap))
363
+ ]
364
+ list_remove = np.concatenate([list_remove, no_prev_feature])
365
+
366
+ # remove parent regions:
367
+ if features_thresholds is not None:
368
+ features_thresholds = features_thresholds[
369
+ ~features_thresholds["idx"].isin(list_remove)
370
+ ]
371
+
372
+ if strict_thresholding:
373
+ keep_new_keys = np.isin(list(regions_i.keys()), features_thresholds["idx"])
374
+ regions_old = {
375
+ k: v for i, (k, v) in enumerate(regions_i.items()) if keep_new_keys[i]
376
+ }
377
+ else:
378
+ keep_old_keys = np.isin(
379
+ list(regions_old.keys()), features_thresholds["idx"]
380
+ )
381
+ regions_old = {
382
+ k: v for i, (k, v) in enumerate(regions_old.items()) if keep_old_keys[i]
383
+ }
384
+ regions_old.update(regions_i)
385
+ else:
386
+ regions_old = regions_i
387
+
388
+ return features_thresholds, regions_old
389
+
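A sketch of the non-strict behaviour with made-up regions (illustrative, not part of the package source); region points are given as raveled 1D indices, as produced later in this module:

    import numpy as np
    import pandas as pd

    # feature 1 was found at a weaker threshold; features 2 and 3 at a stricter one
    features = pd.DataFrame({"idx": [1, 2, 3], "threshold_value": [10, 20, 20]})
    regions_old = {1: np.array([5, 6, 7, 8])}
    regions_i = {2: np.array([6]), 3: np.array([8])}

    features, regions = remove_parents(features, regions_i, regions_old)
    # feature 1 overlaps its two children and is dropped; the returned regions
    # dict now holds the regions of features 2 and 3 for the next threshold step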
390
+
391
+ def feature_detection_threshold(
392
+ data_i: np.ndarray,
393
+ i_time: int,
394
+ threshold: float = None,
395
+ min_num: int = 0,
396
+ target: Literal["maximum", "minimum"] = "maximum",
397
+ position_threshold: Literal[
398
+ "center", "extreme", "weighted_diff", "weighted_abs"
399
+ ] = "center",
400
+ sigma_threshold: float = 0.5,
401
+ n_erosion_threshold: int = 0,
402
+ n_min_threshold: Union[int, dict[float, int], list[int]] = 0,
403
+ min_distance: float = 0,
404
+ idx_start: int = 0,
405
+ PBC_flag: Literal["none", "hdim_1", "hdim_2", "both"] = "none",
406
+ vertical_axis: int = 0,
407
+ **kwargs: dict[str, Any],
408
+ ) -> tuple[pd.DataFrame, dict]:
409
+ """Find features based on individual threshold value.
410
+
411
+ :hidden:
412
+
413
+ Parameters
414
+ ----------
415
+ data_i : np.ndarray
416
+ 2D or 3D field to perform the feature detection (single timestep) on.
417
+
418
+ i_time : int
419
+ Number of the current timestep.
420
+
421
+ threshold : float, optional
422
+ Threshold value used to select target regions to track. The feature detection is inclusive of the
423
+ threshold value(s), i.e. values greater/less than or equal are included in the target region.
424
+ Default is None.
+
+ min_num : int, optional
+ This parameter is not used in the function. Default is 0.
426
+
427
+ target : {'maximum', 'minimum'}, optional
428
+ Flag to determine if tracking is targeting minima or maxima
429
+ in the data. Default is 'maximum'.
430
+
431
+ position_threshold : {'center', 'extreme', 'weighted_diff',
432
+ 'weighted_abs'}, optional
433
+ Flag choosing method used for the position of the tracked
434
+ feature. Default is 'center'.
435
+
436
+ sigma_threshold: float, optional
437
+ Standard deviation for initial filtering step. Default is 0.5.
438
+
439
+ n_erosion_threshold: int, optional
440
+ Number of pixels by which to erode the identified features.
441
+ Default is 0.
442
+
443
+ n_min_threshold : int, dict of float to int, or list of int, optional
444
+ Minimum number of identified contiguous pixels for a feature to be detected. Default is 0.
445
+ If given as a list, the number of elements must match number of thresholds.
446
+ If given as a dict, the keys need to match the thresholds and the values are the minimum number of identified contiguous pixels for a feature using that specific threshold.
447
+
448
+ min_distance : float, optional
449
+ Minimum distance between detected features (in meters). Default is 0.
450
+
451
+ idx_start : int, optional
452
+ Feature id to start with. Default is 0.
453
+
454
+ PBC_flag : {'none', 'hdim_1', 'hdim_2', 'both'}
455
+ Sets whether to use periodic boundaries, and if so in which directions.
456
+ 'none' means that we do not have periodic boundaries
457
+ 'hdim_1' means that we are periodic along hdim1
458
+ 'hdim_2' means that we are periodic along hdim2
459
+ 'both' means that we are periodic along both horizontal dimensions
460
+ vertical_axis: int
461
+ The vertical axis number of the data.
462
+
463
+ kwargs : dict
464
+ Additional keyword arguments.
465
+
466
+
467
+ Returns
468
+ -------
469
+ features_threshold : pandas DataFrame
470
+ Detected features for individual threshold.
471
+
472
+ regions : dict
473
+ Dictionary containing the regions above/below threshold used
474
+ for each feature (feature ids as keys).
475
+ """
476
+
477
+ from skimage.measure import label
478
+ from skimage.morphology import binary_erosion
479
+ from copy import deepcopy
480
+
481
+ if min_num != 0:
482
+ warnings.warn(
483
+ "min_num parameter has no effect and will be deprecated in a future version of tobac. "
484
+ "Please use n_min_threshold instead",
485
+ FutureWarning,
486
+ )
487
+
488
+ # If we are given a 3D data array, we should do 3D feature detection.
489
+ is_3D = len(data_i.shape) == 3
490
+
491
+ # We need to transpose the input data
492
+ if is_3D:
493
+ if vertical_axis == 1:
494
+ data_i = np.transpose(data_i, axes=(1, 0, 2))
495
+ elif vertical_axis == 2:
496
+ data_i = np.transpose(data_i, axes=(2, 0, 1))
497
+
498
+ # select pixels meeting the threshold (inclusive), depending on the target:
499
+ if target == "maximum":
500
+ mask = data_i >= threshold
501
+ elif target == "minimum":
502
+ mask = data_i <= threshold
505
+ # erode selected regions by n pixels
506
+ if n_erosion_threshold > 0:
507
+ if is_3D:
508
+ selem = np.ones(
509
+ (n_erosion_threshold, n_erosion_threshold, n_erosion_threshold)
510
+ )
511
+ else:
512
+ selem = np.ones((n_erosion_threshold, n_erosion_threshold))
513
+ mask = binary_erosion(mask, selem)
514
+ # detect individual regions, label and count the number of pixels included:
515
+ labels, num_labels = label(mask, background=0, return_num=True)
516
+ if not is_3D:
517
+ # add a dummy vertical dimension so labels is a (1, y, x) array, making calculations easier.
518
+ labels = labels[np.newaxis, :, :]
519
+ # these are [min, max], meaning that the max value is inclusive and a valid
520
+ # value.
521
+ z_min = 0
522
+ z_max = labels.shape[0] - 1
523
+ y_min = 0
524
+ y_max = labels.shape[1] - 1
525
+ x_min = 0
526
+ x_max = labels.shape[2] - 1
527
+
528
+ # deal with PBCs
529
+ # all options that involve dealing with periodic boundaries
530
+ pbc_options = ["hdim_1", "hdim_2", "both"]
531
+ if PBC_flag not in pbc_options and PBC_flag != "none":
532
+ raise ValueError(
533
+ "Options for periodic are currently: none, " + ", ".join(pbc_options)
534
+ )
535
+
536
+ # we need to deal with PBCs in some way.
537
+ if PBC_flag in pbc_options and num_labels > 0:
538
+ #
539
+ # create our copy of `labels` to edit
540
+ labels_2 = deepcopy(labels)
541
+ # points we've already edited
542
+ skip_list = np.array([])
543
+ # labels that touch the PBC walls
544
+ wall_labels = np.array([], dtype=np.int32)
545
+
546
+ all_label_props = internal_utils.get_label_props_in_dict(labels)
547
+ [
548
+ all_labels_max_size,
549
+ all_label_locs_v,
550
+ all_label_locs_h1,
551
+ all_label_locs_h2,
552
+ ] = internal_utils.get_indices_of_labels_from_reg_prop_dict(all_label_props)
553
+
554
+ # find the points along the boundaries
555
+
556
+ # along hdim_1 or both horizontal boundaries
557
+ if PBC_flag == "hdim_1" or PBC_flag == "both":
558
+ # north and south wall
559
+ ns_wall = np.unique(labels[:, (y_min, y_max), :])
560
+ wall_labels = np.append(wall_labels, ns_wall)
561
+
562
+ # along hdim_2 or both horizontal boundaries
563
+ if PBC_flag == "hdim_2" or PBC_flag == "both":
564
+ # east/west wall
565
+ ew_wall = np.unique(labels[:, :, (x_min, x_max)])
566
+ wall_labels = np.append(wall_labels, ew_wall)
567
+
568
+ wall_labels = np.unique(wall_labels)
569
+
570
+ for label_ind in wall_labels:
571
+ new_label_ind = label_ind
572
+ # 0 isn't a real index
573
+ if label_ind == 0:
574
+ continue
575
+ # skip this label if we have already dealt with it.
576
+ if np.any(label_ind == skip_list):
577
+ continue
578
+
579
+ # create list for skip labels for this wall label only
580
+ skip_list_thisind = list()
581
+
582
+ # get all locations of this label.
583
+ # TODO: harmonize x/y/z vs hdim1/hdim2/vdim.
584
+ label_locs_v = all_label_locs_v[label_ind]
585
+ label_locs_h1 = all_label_locs_h1[label_ind]
586
+ label_locs_h2 = all_label_locs_h2[label_ind]
587
+
588
+ # loop through every point in the label
589
+ for label_z, label_y, label_x in zip(
590
+ label_locs_v, label_locs_h1, label_locs_h2
591
+ ):
592
+ # check if this is the special case of being a corner point.
593
+ # if it's doubly periodic AND on both x and y boundaries, it's a corner point
594
+ # and we have to look at the other corner.
595
+ # here, we will only look at the corner point and let the below deal with x/y only.
596
+ if PBC_flag == "both" and (
597
+ np.any(label_y == [y_min, y_max])
598
+ and np.any(label_x == [x_min, x_max])
599
+ ):
600
+ # adjust x and y points to the other side
601
+ y_val_alt = pbc_utils.adjust_pbc_point(label_y, y_min, y_max)
602
+ x_val_alt = pbc_utils.adjust_pbc_point(label_x, x_min, x_max)
603
+
604
+ label_on_corner = labels[label_z, y_val_alt, x_val_alt]
605
+
606
+ if (label_on_corner != 0) and (
607
+ ~np.any(label_on_corner == skip_list)
608
+ ):
609
+ # alt_inds = np.where(labels==alt_label_3)
610
+ # get a list of indices where the label on the corner is so we can switch
611
+ # them in the new list.
612
+
613
+ labels_2[
614
+ all_label_locs_v[label_on_corner],
615
+ all_label_locs_h1[label_on_corner],
616
+ all_label_locs_h2[label_on_corner],
617
+ ] = label_ind
618
+ skip_list = np.append(skip_list, label_on_corner)
619
+ skip_list_thisind = np.append(
620
+ skip_list_thisind, label_on_corner
621
+ )
622
+
623
+ # if it's labeled and has already been dealt with for this label
624
+ elif (
625
+ (label_on_corner != 0)
626
+ and (np.any(label_on_corner == skip_list))
627
+ and (np.any(label_on_corner == skip_list_thisind))
628
+ ):
629
+ # print("skip_list_thisind label - has already been treated this index")
630
+ continue
631
+
632
+ # if it's labeled and has already been dealt with via a previous label
633
+ elif (
634
+ (label_on_corner != 0)
635
+ and (np.any(label_on_corner == skip_list))
636
+ and (~np.any(label_on_corner == skip_list_thisind))
637
+ ):
638
+ # find the updated label, and overwrite all of label_ind indices with
639
+ # updated label
640
+ labels_2_alt = labels_2[label_z, y_val_alt, x_val_alt]
641
+ labels_2[label_locs_v, label_locs_h1, label_locs_h2] = (
642
+ labels_2_alt
643
+ )
644
+ skip_list = np.append(skip_list, label_ind)
645
+ break
646
+
647
+ # on the hdim1 boundary and periodic on hdim1
648
+ if (PBC_flag == "hdim_1" or PBC_flag == "both") and np.any(
649
+ label_y == [y_min, y_max]
650
+ ):
651
+ y_val_alt = pbc_utils.adjust_pbc_point(label_y, y_min, y_max)
652
+
653
+ # get the label value on the opposite side
654
+ label_alt = labels[label_z, y_val_alt, label_x]
655
+
656
+ # if it's labeled and not already been dealt with
657
+ if (label_alt != 0) and (~np.any(label_alt == skip_list)):
658
+ # find the indices where it has the label value on opposite side and change
659
+ # their value to original side
660
+ # print(all_label_locs_v[label_alt], alt_inds[0])
661
+ labels_2[
662
+ all_label_locs_v[label_alt],
663
+ all_label_locs_h1[label_alt],
664
+ all_label_locs_h2[label_alt],
665
+ ] = new_label_ind
666
+ # we have already dealt with this label.
667
+ skip_list = np.append(skip_list, label_alt)
668
+ skip_list_thisind = np.append(skip_list_thisind, label_alt)
669
+
670
+ # if it's labeled and has already been dealt with for this label
671
+ elif (
672
+ (label_alt != 0)
673
+ and (np.any(label_alt == skip_list))
674
+ and (np.any(label_alt == skip_list_thisind))
675
+ ):
676
+ continue
677
+
678
+ # if it's labeled and has already been dealt with
679
+ elif (
680
+ (label_alt != 0)
681
+ and (np.any(label_alt == skip_list))
682
+ and (~np.any(label_alt == skip_list_thisind))
683
+ ):
684
+ # find the updated label, and overwrite all of label_ind indices with
685
+ # updated label
686
+ labels_2_alt = labels_2[label_z, y_val_alt, label_x]
687
+ labels_2[label_locs_v, label_locs_h1, label_locs_h2] = (
688
+ labels_2_alt
689
+ )
690
+ new_label_ind = labels_2_alt
691
+ skip_list = np.append(skip_list, label_ind)
692
+
693
+ if (PBC_flag == "hdim_2" or PBC_flag == "both") and np.any(
694
+ label_x == [x_min, x_max]
695
+ ):
696
+ x_val_alt = pbc_utils.adjust_pbc_point(label_x, x_min, x_max)
697
+
698
+ # get the label value on the opposite side
699
+ label_alt = labels[label_z, label_y, x_val_alt]
700
+
701
+ # if it's labeled and not already been dealt with
702
+ if (label_alt != 0) and (~np.any(label_alt == skip_list)):
703
+ # find the indices where it has the label value on opposite side and change
704
+ # their value to original side
705
+ labels_2[
706
+ all_label_locs_v[label_alt],
707
+ all_label_locs_h1[label_alt],
708
+ all_label_locs_h2[label_alt],
709
+ ] = new_label_ind
710
+ # we have already dealt with this label.
711
+ skip_list = np.append(skip_list, label_alt)
712
+ skip_list_thisind = np.append(skip_list_thisind, label_alt)
713
+
714
+ # if it's labeled and has already been dealt with for this label
715
+ elif (
716
+ (label_alt != 0)
717
+ and (np.any(label_alt == skip_list))
718
+ and (np.any(label_alt == skip_list_thisind))
719
+ ):
720
+ continue
721
+
722
+ # if it's labeled and has already been dealt with
723
+ elif (
724
+ (label_alt != 0)
725
+ and (np.any(label_alt == skip_list))
726
+ and (~np.any(label_alt == skip_list_thisind))
727
+ ):
728
+ # find the updated label, and overwrite all of label_ind indices with
729
+ # updated label
730
+ labels_2_alt = labels_2[label_z, label_y, x_val_alt]
731
+ labels_2[label_locs_v, label_locs_h1, label_locs_h2] = (
732
+ labels_2_alt
733
+ )
734
+ new_label_ind = labels_2_alt
735
+ skip_list = np.append(skip_list, label_ind)
736
+
737
+ # copy over new labels after we have adjusted everything
738
+ labels = labels_2
739
+
740
+ # END PBC treatment
741
+ # we need to get label properties again after we handle PBCs.
742
+
743
+ label_props = internal_utils.get_label_props_in_dict(labels)
744
+ if len(label_props) > 0:
745
+ (
746
+ total_indices_all,
747
+ vdim_indices_all,
748
+ hdim1_indices_all,
749
+ hdim2_indices_all,
750
+ ) = internal_utils.get_indices_of_labels_from_reg_prop_dict(label_props)
751
+
752
+ # values, count = np.unique(labels[:,:].ravel(), return_counts=True)
753
+ # values_counts=dict(zip(values, count))
754
+ # Filter out regions that have less pixels than n_min_threshold
755
+ # values_counts={k:v for k, v in values_counts.items() if v>n_min_threshold}
756
+
757
+ # proceed only if at least one region was detected
758
+ if num_labels > 0:
759
+ # create empty list to store individual features for this threshold
760
+ list_features_threshold = list()
761
+ # create empty dict to store regions for individual features for this threshold
762
+ regions = dict()
763
+ # create empty list of features to remove from parent threshold value
764
+
765
+ region = np.empty(mask.shape, dtype=bool)
766
+ # loop over individual regions:
767
+ for cur_idx in total_indices_all:
768
+ # skip this if there aren't enough points to be considered a real feature
769
+ # as defined above by n_min_threshold
770
+ curr_count = total_indices_all[cur_idx]
771
+ if curr_count <= n_min_threshold:
772
+ continue
773
+ if is_3D:
774
+ vdim_indices = vdim_indices_all[cur_idx]
775
+ else:
776
+ vdim_indices = None
777
+ hdim1_indices = hdim1_indices_all[cur_idx]
778
+ hdim2_indices = hdim2_indices_all[cur_idx]
779
+
780
+ label_bbox = label_props[cur_idx].bbox
781
+ (
782
+ bbox_zstart,
783
+ bbox_ystart,
784
+ bbox_xstart,
785
+ bbox_zend,
786
+ bbox_yend,
787
+ bbox_xend,
788
+ ) = label_bbox
789
+ bbox_zsize = bbox_zend - bbox_zstart
790
+ bbox_xsize = bbox_xend - bbox_xstart
791
+ bbox_ysize = bbox_yend - bbox_ystart
792
+ # build small region box
793
+ if is_3D:
794
+ region_small = np.full((bbox_zsize, bbox_ysize, bbox_xsize), False)
795
+ region_small[
796
+ vdim_indices - bbox_zstart,
797
+ hdim1_indices - bbox_ystart,
798
+ hdim2_indices - bbox_xstart,
799
+ ] = True
800
+
801
+ else:
802
+ region_small = np.full((bbox_ysize, bbox_xsize), False)
803
+ region_small[
804
+ hdim1_indices - bbox_ystart, hdim2_indices - bbox_xstart
805
+ ] = True
806
+ # we are 2D and need to remove the dummy 3D coordinate.
807
+ label_bbox = (
808
+ label_bbox[1],
809
+ label_bbox[2],
810
+ label_bbox[4],
811
+ label_bbox[5],
812
+ )
813
+
814
+ # [hdim1_indices,hdim2_indices]= np.nonzero(region)
815
+ # write region for individual threshold and feature to dict
816
+
817
+ """
818
+ This block of code creates 1D coordinates from the input
819
+ 2D or 3D coordinates. Dealing with 1D coordinates is substantially
820
+ faster than having to carry around (x, y, z) or (x, y) as
821
+ separate arrays. This also makes comparisons in remove_parents
822
+ substantially faster.
823
+ """
824
+ if is_3D:
825
+ region_i = np.ravel_multi_index(
826
+ (hdim1_indices, hdim2_indices, vdim_indices),
827
+ (y_max + 1, x_max + 1, z_max + 1),
828
+ )
829
+ else:
830
+ region_i = np.ravel_multi_index(
831
+ (hdim1_indices, hdim2_indices), (y_max + 1, x_max + 1)
832
+ )
833
+
834
+ regions[cur_idx + idx_start] = region_i
835
+ # Determine feature position for region by one of the following methods:
836
+ single_indices = feature_position(
837
+ hdim1_indices,
838
+ hdim2_indices,
839
+ vdim_indices=vdim_indices,
840
+ region_small=region_small,
841
+ region_bbox=label_bbox,
842
+ track_data=data_i,
843
+ threshold_i=threshold,
844
+ position_threshold=position_threshold,
845
+ target=target,
846
+ PBC_flag=PBC_flag,
847
+ hdim2_min=x_min,
848
+ hdim2_max=x_max,
849
+ hdim1_min=y_min,
850
+ hdim1_max=y_max,
851
+ )
852
+ if is_3D:
853
+ vdim_index, hdim1_index, hdim2_index = single_indices
854
+ else:
855
+ hdim1_index, hdim2_index = single_indices
856
+ # create individual DataFrame row in tracky format for identified feature
857
+ appending_dict = {
858
+ "frame": int(i_time),
859
+ "idx": cur_idx + idx_start,
860
+ "hdim_1": hdim1_index,
861
+ "hdim_2": hdim2_index,
862
+ "num": curr_count,
863
+ "threshold_value": threshold,
864
+ }
865
+ column_names = [
866
+ "frame",
867
+ "idx",
868
+ "hdim_1",
869
+ "hdim_2",
870
+ "num",
871
+ "threshold_value",
872
+ ]
873
+ if is_3D:
874
+ appending_dict["vdim"] = vdim_index
875
+ column_names = [
876
+ "frame",
877
+ "idx",
878
+ "vdim",
879
+ "hdim_1",
880
+ "hdim_2",
881
+ "num",
882
+ "threshold_value",
883
+ ]
884
+ list_features_threshold.append(appending_dict)
885
+ # after looping thru proto-features, check if any exceed num threshold
886
+ # if they do not, provide a blank pandas df and regions dict
887
+ if list_features_threshold == []:
888
+ # print("no features above num value at threshold: ",threshold)
889
+ features_threshold = pd.DataFrame()
890
+ regions = dict()
891
+ # if they do, provide a dataframe with features organized with 2D and 3D metadata
892
+ else:
893
+ # print("at least one feature above num value at threshold: ",threshold)
894
+ # print("column_names, after cur_idx loop: ",column_names)
895
+ features_threshold = pd.DataFrame(
896
+ list_features_threshold, columns=column_names
897
+ )
898
+ # features_threshold=pd.DataFrame(list_features_threshold, columns = column_names)
899
+ else:
900
+ features_threshold = pd.DataFrame()
901
+ regions = dict()
902
+
903
+ return features_threshold, regions
904
+
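A sketch of a call on a small synthetic field (illustrative, not part of the package source; assumes this module's functions are importable):

    import numpy as np

    field = np.zeros((10, 10))
    field[3:6, 4:8] = 5.0   # one 12-pixel block above the threshold

    features, regions = feature_detection_threshold(
        field, 0,
        threshold=4.0, target="maximum",
        position_threshold="center", n_min_threshold=5,
    )
    # features is a one-row DataFrame with columns
    # ['frame', 'idx', 'hdim_1', 'hdim_2', 'num', 'threshold_value'];
    # regions maps the feature idx to the raveled pixel indices of its region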
905
+
906
+ @internal_utils.irispandas_to_xarray()
907
+ def feature_detection_multithreshold_timestep(
908
+ data_i: xr.DataArray,
909
+ i_time: int,
910
+ threshold: list[float] = None,
911
+ min_num: int = 0,
912
+ target: Literal["maximum", "minimum"] = "maximum",
913
+ position_threshold: Literal[
914
+ "center", "extreme", "weighted_diff", "weighted abs"
915
+ ] = "center",
916
+ sigma_threshold: float = 0.5,
917
+ n_erosion_threshold: int = 0,
918
+ n_min_threshold: Union[int, dict[float, int], list[int]] = 0,
919
+ min_distance: float = 0,
920
+ feature_number_start: int = 1,
921
+ PBC_flag: Literal["none", "hdim_1", "hdim_2", "both"] = "none",
922
+ vertical_axis: int = None,
923
+ dxy: float = -1,
924
+ wavelength_filtering: tuple[float] = None,
925
+ strict_thresholding: bool = False,
926
+ statistic: Union[dict[str, Union[Callable, tuple[Callable, dict]]], None] = None,
927
+ statistics_unsmoothed: bool = False,
928
+ return_labels: bool = False,
929
+ **kwargs: dict[str, Any],
930
+ ) -> Union[pd.DataFrame, tuple[xr.DataArray, pd.DataFrame]]:
931
+ """Find features in each timestep.
932
+
933
+ Based on iteratively finding regions above/below a set of
934
+ thresholds. Smoothing the input data with a Gaussian filter makes the
935
+ output less sensitive to noise in the input data.
936
+
937
+ :hidden:
938
+
939
+ Parameters
940
+ ----------
941
+
942
+ data_i : iris.cube.Cube or xarray.DataArray
943
+ 2D or 3D field to perform the feature detection (single timestep) on.
944
+
945
+ i_time : int
946
+ Number of the current timestep.
947
+
948
+ threshold : list of floats, optional
949
+ Threshold value used to select target regions to track. The feature detection is inclusive of the threshold value(s), i.e. values greater/less than or equal are included in the target region. Default is None.
950
+
951
+ min_num : int, optional
952
+ This parameter is not used in the function. Default is 0.
953
+
954
+ target : {'maximum', 'minimum'}, optional
956
+ Flag to determine if tracking is targeting minima or maxima
956
+ in the data. Default is 'maximum'.
957
+
958
+ position_threshold : {'center', 'extreme', 'weighted_diff',
959
+ 'weighted_abs'}, optional
960
+ Flag choosing method used for the position of the tracked
961
+ feature. Default is 'center'.
962
+
963
+ sigma_threshold: float, optional
964
+ Standard deviation for initial filtering step. Default is 0.5.
965
+
966
+ n_erosion_threshold: int, optional
967
+ Number of pixels by which to erode the identified features.
968
+ Default is 0.
969
+
970
+ n_min_threshold : int, dict of float to int, or list of int, optional
971
+ Minimum number of identified contiguous pixels for a feature to be detected. Default is 0.
972
+ If given as a list, the number of elements must match number of thresholds.
973
+ If given as a dict, the keys need to match the thresholds and the values are the minimum number of identified contiguous pixels for a feature using that specific threshold.
974
+
975
+ min_distance : float, optional
976
+ Minimum distance between detected features (in meters). Default is 0.
977
+
978
+ feature_number_start : int, optional
979
+ Feature id to start with. Default is 1.
980
+
981
+ PBC_flag : {'none', 'hdim_1', 'hdim_2', 'both'}
982
+ Sets whether to use periodic boundaries, and if so in which directions.
983
+ 'none' means that we do not have periodic boundaries
984
+ 'hdim_1' means that we are periodic along hdim1
985
+ 'hdim_2' means that we are periodic along hdim2
986
+ 'both' means that we are periodic along both horizontal dimensions
987
+
988
+ vertical_axis: int
989
+ The vertical axis number of the data.
990
+ dxy : float
991
+ Grid spacing in meters.
992
+
993
+ wavelength_filtering: tuple, optional
994
+ Minimum and maximum wavelength for spectral filtering in meters. Default is None.
995
+
996
+ strict_thresholding: Bool, optional
997
+ If True, a feature can only be detected if all previous thresholds have been met.
998
+ Default is False.
999
+
1000
+ statistic : dict, optional
1001
+ Default is None. Optional parameter to calculate bulk statistics within feature detection.
1002
+ Dictionary with callable function(s) to apply over the region of each detected feature and the name of the statistics to appear in the feature output dataframe. The functions should be the values and the names of the metrics the keys (e.g. {'mean': np.mean}).
1003
+
1004
+ statistics_unsmoothed: bool, optional
1005
+ Default is False. If True, calculate the statistics on the raw data instead of the smoothed input data.
1006
+
1007
+ return_labels: bool, optional
1008
+ Default is False. If True, return the label fields.
1009
+
1010
+ kwargs : dict
1011
+ Additional keyword arguments.
1012
+
1013
+
1014
+ Returns
1015
+ -------
1016
+ features_threshold : pandas DataFrame
1017
+ Detected features for individual timestep.
1018
+
1019
+ labels : xarray DataArray, optional
1020
+ Label fields for the respective thresholds. Only returned if
1021
+ return_labels is True.
1022
+ """
1023
+ # Handle scipy deprecation gracefully
1024
+ try:
1025
+ from scipy.ndimage import gaussian_filter
1026
+ except ImportError:
1027
+ from scipy.ndimage.filters import gaussian_filter
1028
+
1029
+ if min_num != 0:
1030
+ warnings.warn(
1031
+ "min_num parameter has no effect and will be deprecated in a future version of tobac. "
1032
+ "Please use n_min_threshold instead",
1033
+ FutureWarning,
1034
+ )
1035
+
1036
+ # get actual numpy array and make a copy so as not to change the data in the iris cube
1037
+ track_data = data_i.values.copy()
1038
+
1039
+ # statistics on the unsmoothed data require a statistic function to be provided
1040
+ if statistics_unsmoothed:
1041
+ if not statistic:
1042
+ raise ValueError(
1043
+ "Please provide the input parameter statistic to determine what statistics to calculate."
1044
+ )
1045
+
1046
+ track_data = gaussian_filter(
1047
+ track_data, sigma=sigma_threshold
1048
+ ) # smooth data slightly to create rounded, continuous field
1049
+
1050
+ # spectrally filter the input data, if desired
1051
+ if wavelength_filtering is not None:
1052
+ track_data = spectral_filtering(
1053
+ dxy, track_data, wavelength_filtering[0], wavelength_filtering[1]
1054
+ )
1055
+
1056
+ # sort thresholds from least extreme to most extreme
1057
+ threshold_sorted = sorted(threshold, reverse=target == "minimum")
1058
+
1059
+ # check if each threshold has a n_min_threshold (minimum nr. of grid cells associated with
1060
+ # thresholds), if multiple n_min_threshold are given
1061
+ if isinstance(n_min_threshold, list) or isinstance(n_min_threshold, dict):
1062
+ if len(n_min_threshold) != len(threshold):
1063
+ raise ValueError(
1064
+ "Number of elements in n_min_threshold needs to be the same as thresholds, if "
1065
+ "n_min_threshold is given as dict or list."
1066
+ )
1067
+
1068
+ # check if thresholds in dict correspond to given thresholds
1069
+ if isinstance(n_min_threshold, dict):
1070
+ if threshold_sorted != sorted(
1071
+ n_min_threshold.keys(), reverse=(target == "minimum")
1072
+ ):
1073
+ raise ValueError(
1074
+ "Ambiguous input for threshold values. If n_min_threshold is given as a dict,"
1075
+ " the keys must correspond to the values in threshold."
1076
+ )
1077
+ # sort dictionary by keys (threshold values) so that they match sorted thresholds and
1078
+ # get values for n_min_threshold
1079
+ n_min_threshold = [
1080
+ n_min_threshold[threshold] for threshold in threshold_sorted
1081
+ ]
1082
+
1083
+ elif isinstance(n_min_threshold, list):
1084
+ # if n_min_threshold is a list, sort it such that it still matches with the sorted
1085
+ # threshold values
1086
+ n_min_threshold = [
1087
+ x
1088
+ for _, x in sorted(
1089
+ zip(threshold, n_min_threshold), reverse=(target == "minimum")
1090
+ )
1091
+ ]
1092
+ elif (
1093
+ not isinstance(n_min_threshold, list)
1094
+ and not isinstance(n_min_threshold, dict)
1095
+ and not isinstance(n_min_threshold, int)
1096
+ ):
1097
+ raise ValueError(
1098
+ "N_min_threshold must be an integer. If multiple values for n_min_threshold are given,"
1099
+ " please provide a dictionary or list."
1100
+ )
1101
+
1102
+ # create empty lists to store regions and features for individual timestep
1103
+ features_thresholds = pd.DataFrame()
1104
+ for i_threshold, threshold_i in enumerate(threshold_sorted):
1105
+ if i_threshold > 0 and not features_thresholds.empty:
1106
+ idx_start = features_thresholds["idx"].max() + feature_number_start
1107
+ else:
1108
+ idx_start = feature_number_start - 1
1109
+
1110
+ # select n_min_threshold for respective threshold, if multiple values are given
1111
+ if isinstance(n_min_threshold, list):
1112
+ n_min_threshold_i = n_min_threshold[i_threshold]
1113
+ else:
1114
+ n_min_threshold_i = n_min_threshold
1115
+
1116
+ features_threshold_i, regions_i = feature_detection_threshold(
1117
+ track_data,
1118
+ i_time,
1119
+ threshold=threshold_i,
1120
+ sigma_threshold=sigma_threshold,
1121
+ min_num=min_num,
1122
+ target=target,
1123
+ position_threshold=position_threshold,
1124
+ n_erosion_threshold=n_erosion_threshold,
1125
+ n_min_threshold=n_min_threshold_i,
1126
+ min_distance=min_distance,
1127
+ idx_start=idx_start,
1128
+ PBC_flag=PBC_flag,
1129
+ vertical_axis=vertical_axis,
1130
+ )
1131
+ if any([x is not None for x in features_threshold_i]):
1132
+ features_thresholds = pd.concat(
1133
+ [features_thresholds, features_threshold_i], ignore_index=True
1134
+ )
1135
+
1136
+ # For multiple thresholds, and features found in both the current and previous step,
1137
+ # remove "parent" features from the DataFrame
1138
+ if i_threshold > 0 and not features_thresholds.empty:
1141
+ features_thresholds, regions_old = remove_parents(
1142
+ features_thresholds,
1143
+ regions_i,
1144
+ regions_old,
1145
+ strict_thresholding=strict_thresholding,
1146
+ )
1147
+ elif i_threshold == 0:
1148
+ regions_old = regions_i
1149
+
1150
+ logging.debug(
1151
+ "Finished feature detection for threshold "
1152
+ + str(i_threshold)
1153
+ + " : "
1154
+ + str(threshold_i)
1155
+ )
1156
+
1157
+ if return_labels or statistic:
1158
+ # reconstruct the labeled regions based on the regions dict
1159
+ labels = np.zeros(track_data.shape)
1160
+ labels = labels.astype(int)
1161
+ for key in regions_old.keys():
1162
+ labels.ravel()[regions_old[key]] = key
1163
+ # apply function to get statistics based on labeled regions and functions provided by the user
1164
+ # the feature dataframe is updated by appending a column for each metric
1165
+
1166
+ if statistic:
1167
+ # select which data to use according to statistics_unsmoothed option
1168
+ stats_data = data_i.values if statistics_unsmoothed else track_data
1169
+
1170
+ features_thresholds = get_statistics(
1171
+ features_thresholds,
1172
+ labels,
1173
+ stats_data,
1174
+ statistic=statistic,
1175
+ index=np.unique(labels[labels > 0]),
1176
+ id_column="idx",
1177
+ )
1178
+
1179
+ # Create the final output
1180
+ if return_labels:
1181
+ label_fields = xr.DataArray(
1182
+ labels,
1183
+ coords=data_i.coords,
1184
+ dims=data_i.dims,
1185
+ name="label_fields",
1186
+ ).assign_attrs(threshold=threshold)
1187
+
1188
+ return label_fields, features_thresholds
1189
+
1190
+ else:
1191
+ return features_thresholds
1192
+
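A sketch of a single-timestep call with per-threshold minimum sizes and a bulk statistic (illustrative, not part of the package source; assumes this module's functions are importable):

    import numpy as np
    import xarray as xr

    rng = np.random.default_rng(0)
    data = xr.DataArray(rng.random((50, 50)) * 10.0, dims=("y", "x"))

    features = feature_detection_multithreshold_timestep(
        data, 0,
        threshold=[6.0, 8.0],
        n_min_threshold={6.0: 10, 8.0: 4},   # stricter threshold, smaller minimum area
        target="maximum",
        position_threshold="weighted_diff",
        statistic={"feature_max": np.max},   # adds a 'feature_max' column per feature
    )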
1193
+
1194
+ @decorators.irispandas_to_xarray(save_iris_info=True)
1195
+ def feature_detection_multithreshold(
1196
+ field_in: xr.DataArray,
1197
+ dxy: float = None,
1198
+ threshold: list[float] = None,
1199
+ min_num: int = 0,
1200
+ target: Literal["maximum", "minimum"] = "maximum",
1201
+ position_threshold: Literal[
1202
+ "center", "extreme", "weighted_diff", "weighted abs"
1203
+ ] = "center",
1204
+ sigma_threshold: float = 0.5,
1205
+ n_erosion_threshold: int = 0,
1206
+ n_min_threshold: Union[int, dict[float, int], list[int]] = 0,
1207
+ min_distance: float = 0,
1208
+ feature_number_start: int = 1,
1209
+ PBC_flag: Literal["none", "hdim_1", "hdim_2", "both"] = "none",
1210
+ vertical_coord: Optional[str] = None,
1211
+ vertical_axis: Optional[int] = None,
1212
+ detect_subset: Optional[dict] = None,
1213
+ wavelength_filtering: Optional[tuple] = None,
1214
+ dz: Union[float, None] = None,
1215
+ strict_thresholding: bool = False,
1216
+ statistic: Union[dict[str, Union[Callable, tuple[Callable, dict]]], None] = None,
1217
+ statistics_unsmoothed: bool = False,
1218
+ return_labels: bool = False,
1219
+ use_standard_names: Optional[bool] = None,
1220
+ converted_from_iris: bool = False,
1221
+ **kwargs: dict[str, Any],
1222
+ ) -> Union[pd.DataFrame, tuple[xr.DataArray, pd.DataFrame]]:
1223
+ """Perform feature detection based on contiguous regions.
1224
+
1225
+ The regions are above/below a threshold.
1226
+
1227
+
1228
+ Parameters
1229
+ ----------
1230
+ field_in : iris.cube.Cube or xarray.DataArray
1231
+ 2D or 3D field to perform the tracking on (needs to have coordinate
1232
+ 'time' along one of its dimensions).
1233
+
1234
+ dxy : float
1235
+ Grid spacing of the input data (in meter).
1236
+
1237
+ threshold : list of floats, optional
1238
+ Threshold values used to select target regions to track. The feature detection is inclusive of the threshold value(s), i.e. values greater/less than or equal are included in the target region. Default is None.
1239
+
1240
+ target : {'maximum', 'minimum'}, optional
1241
+ Flag to determine if tracking is targeting minima or maxima in
1242
+ the data. Default is 'maximum'.
1243
+
1244
+ position_threshold : {'center', 'extreme', 'weighted_diff',
1245
+ 'weighted_abs'}, optional
1246
+ Flag choosing method used for the position of the tracked
1247
+ feature. Default is 'center'.
1248
+
1249
+ sigma_threshold: float, optional
1250
+ Standard deviation for initial filtering step. Default is 0.5.
1251
+
1252
+ n_erosion_threshold: int, optional
1253
+ Number of pixels by which to erode the identified features.
1254
+ Default is 0.
1255
+
1256
+ n_min_threshold : int, dict of float to int, or list of int, optional
1257
+ Minimum number of identified contiguous pixels for a feature to be detected. Default is 0.
1258
+ If given as a list, the number of elements must match number of thresholds.
1259
+ If given as a dict, the keys need to match the thresholds and the values are the minimum number of identified contiguous pixels for a feature using that specific threshold.
1260
+
1261
+ min_distance : float, optional
1262
+ Minimum distance between detected features (in meters). Default is 0.
1263
+
1264
+ feature_number_start : int, optional
1265
+ Feature id to start with. Default is 1.
1266
+
1267
+ PBC_flag : {'none', 'hdim_1', 'hdim_2', 'both'}
1268
+ Sets whether to use periodic boundaries, and if so in which directions.
1269
+ 'none' means that we do not have periodic boundaries
1270
+ 'hdim_1' means that we are periodic along hdim1
1271
+ 'hdim_2' means that we are periodic along hdim2
1272
+ 'both' means that we are periodic along both horizontal dimensions
1273
+ vertical_coord: str
1274
+ Name of the vertical coordinate. If None, tries to auto-detect.
1275
+ It looks for the coordinate or the dimension name corresponding
1276
+ to the string.
1277
+ vertical_axis: int or None.
1278
+ The vertical axis number of the data. If None, uses vertical_coord
1279
+ to determine axis. This must be >=0.
1280
+ detect_subset: dict-like or None
1281
+ Whether to run feature detection on only a subset of the data.
1282
+ If this is not None, it will subset the grid that we run feature detection
1283
+ on to the range specified for each axis specified. The format of this dict is:
1284
+ {axis-number: (start, end)}, where axis-number is the number of the axis to subset,
1285
+ start is inclusive, and end is exclusive.
1286
+ For example, if your data are oriented as (time, z, y, x) and you want to
1287
+ only detect on values between z levels 10 and 29, you would set:
1288
+ {1: (10, 30)}.
1289
+ wavelength_filtering: tuple, optional
1290
+ Minimum and maximum wavelength for horizontal spectral filtering in meter.
1291
+ Default is None.
1292
+
1293
+ dz : float
1294
+ Constant vertical grid spacing (m), optional. If not specified
1295
+ and the input is 3D, this function requires that a vertical coordinate (e.g.
1296
+ `altitude`) is available in `field_in`. If you specify a value here, this function assumes
1297
+ that it is the constant z spacing between points, even if ```z_coordinate_name```
1298
+ is specified.
1299
+
1300
+ strict_thresholding: Bool, optional
1301
+ If True, a feature can only be detected if all previous thresholds have been met.
1302
+ Default is False.
1303
+
1304
+ use_standard_names: bool
1305
+ If true, when interpolating a coordinate, it looks for a standard_name
1306
+ and uses that to name the output coordinate, to mimic iris functionality.
1307
+ If false, uses the actual name of the coordinate to output.
1308
+
1309
+ statistic : dict, optional
1310
+ Default is None. Optional parameter to calculate bulk statistics within feature detection.
1311
+ Dictionary with callable function(s) to apply over the region of each detected feature and
1312
+ the name of the statistics to appear in the feature output dataframe.
1313
+ The functions should be the values and the names of the metric the keys (e.g. {'mean': np.mean})
1314
+
1315
+ statistics_unsmoothed: bool, optional
1316
+ Default is False. If True, calculate the statistics on the raw data instead of the smoothed input data.
1317
+
1318
+ return_labels: bool, optional
1319
+ Default is False. If True, return the label fields.
1320
+
1321
+ preserve_iris_datetime_types: bool, optional, default: True
1322
+ If True, for iris input, preserve the original datetime type (typically
1323
+ `cftime.DatetimeGregorian`) where possible. For xarray input, this parameter has no
1324
+ effect.
1325
+
1326
+ kwargs : dict
1327
+ Additional keyword arguments.
1328
+
1329
+
1330
+ Returns
1331
+ -------
1332
+ features : pandas.DataFrame
1333
+ Detected features. The structure of this dataframe is explained
1334
+ `here <https://tobac.readthedocs.io/en/latest/data_input.html>`__
1335
+
1336
+ labels : xarray DataArray, optional
1337
+ Label fields for the respective thresholds. Only returned if
1338
+ return_labels is True.
1339
+
1340
+ """
1341
+ from .utils import add_coordinates, add_coordinates_3D
1342
+
1343
+ time_var_name: str = "time"
1344
+ logging.debug("start feature detection based on thresholds")
1345
+
1346
+ ndim_time = internal_utils.find_axis_from_coord(field_in, time_var_name)
1347
+
1348
+ # Check whether we need to run 2D or 3D feature detection
1349
+ if field_in.ndim == 3:
1350
+ logging.debug("Running 2D feature detection")
1351
+ is_3D = False
1352
+ elif field_in.ndim == 4:
1353
+ logging.debug("Running 3D feature detection")
1354
+ is_3D = True
1355
+ else:
1356
+ raise ValueError("Feature detection only works with 2D or 3D data")
1357
+
1358
+ if detect_subset is not None:
1359
+ raise NotImplementedError("Subsetting feature detection not yet supported.")
1360
+
1361
+ if detect_subset is not None and ndim_time in detect_subset:
1362
+ raise NotImplementedError("Cannot subset on time")
1363
+
1364
+ # Remember if dz is set and not vertical coord for min distance filtering
1365
+ use_dz_for_filtering = dz is not None
1366
+
1367
+ if is_3D:
1368
+ # We need to determine the time axis so that we can determine the
1369
+ # vertical axis in each timestep if vertical_axis is not none.
1370
+
1371
+ if vertical_axis is not None and vertical_coord is not None:
1372
+ raise ValueError(
1373
+ "Only one of vertical_axis or vertical_coord should be set."
1374
+ )
1375
+
1376
+ if vertical_axis is None:
1377
+ # We need to determine vertical axis.
1378
+ # first, find the name of the vertical axis
1379
+ vertical_axis_name = internal_utils.find_vertical_coord_name(
1380
+ field_in, vertical_coord=vertical_coord
1381
+ )
1382
+ # then find our axis number.
1383
+ vertical_axis = internal_utils.find_axis_from_coord(
1384
+ field_in, vertical_axis_name
1385
+ )
1386
+
1387
+ if vertical_axis is None:
1388
+ raise ValueError("Cannot find vertical coordinate.")
1389
+
1390
+ if vertical_axis < 0:
1391
+ raise ValueError("vertical_axis must be >=0.")
1392
+ # adjust vertical axis number down based on time
1393
+ if ndim_time < vertical_axis:
1394
+ # We only need to adjust the axis number if the time axis
1395
+ # is a lower axis number than the specified vertical coordinate.
1396
+
1397
+ vertical_axis = vertical_axis - 1
1398
+
1399
+     # if threshold is put in as a single value, turn it into a list
+     if isinstance(threshold, (int, float)):
+         threshold = [threshold]
+
+     # If wavelength_filtering is given, check that the wavelengths are not
+     # larger than the distances along x and y, and not smaller than or equal
+     # to the grid spacing; warn if dxy and the minimum wavelength have about
+     # the same order of magnitude.
+     if wavelength_filtering is not None:
+         if is_3D:
+             raise ValueError("Wavelength filtering is not supported for 3D input data.")
+         else:
+             distance_x = field_in.shape[1] * dxy
+             distance_y = field_in.shape[2] * dxy
+             distance = min(distance_x, distance_y)
+
+             # make sure the smaller value is taken as the minimum and the larger as the maximum
+             lambda_min = min(wavelength_filtering)
+             lambda_max = max(wavelength_filtering)
+
+             if lambda_min > distance or lambda_max > distance:
+                 raise ValueError(
+                     "The given wavelengths cannot be larger than the total distance in m"
+                     " along the axes of the domain."
+                 )
+
+             elif lambda_min <= dxy:
+                 raise ValueError(
+                     "The given minimum wavelength cannot be smaller than the grid spacing"
+                     " dxy. Please note that both dxy and the values for"
+                     " wavelength_filtering should be given in meters."
+                 )
+
+             elif np.floor(np.log10(lambda_min)) - np.floor(np.log10(dxy)) <= 1:
+                 warnings.warn(
+                     "Warning: the values for dxy and the minimum wavelength are close in"
+                     " order of magnitude. Please note that both dxy and the values for"
+                     " wavelength_filtering should be given in meters."
+                 )
+
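+     # Worked example (hypothetical values, assuming the domain spans more than
+     # 50 km): with dxy = 1000 m and wavelength_filtering = (2000, 50000) m,
+     # lambda_min = 2000 > dxy passes the hard checks, while
+     # floor(log10(2000)) - floor(log10(1000)) = 0 <= 1 raises the
+     # order-of-magnitude warning above.
+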
+     # Initialize lists and xarrays for holding results
+     list_features_timesteps = []
+     if return_labels:
+         label_fields = xr.DataArray(
+             np.zeros(field_in.shape, dtype=int),
+             coords=field_in.coords,
+             dims=field_in.dims,
+             name="label_fields",
+         ).assign_attrs(threshold=threshold)
+     else:
+         label_fields = None
+
+     for i_time, time_i in enumerate(field_in.coords[time_var_name]):
+         data_i = field_in.isel({time_var_name: i_time})
+
+         timestep_result = feature_detection_multithreshold_timestep(
+             data_i,
+             i_time,
+             threshold=threshold,
+             sigma_threshold=sigma_threshold,
+             min_num=min_num,
+             target=target,
+             position_threshold=position_threshold,
+             n_erosion_threshold=n_erosion_threshold,
+             n_min_threshold=n_min_threshold,
+             min_distance=min_distance,
+             feature_number_start=feature_number_start,
+             PBC_flag=PBC_flag,
+             vertical_axis=vertical_axis,
+             dxy=dxy,
+             wavelength_filtering=wavelength_filtering,
+             strict_thresholding=strict_thresholding,
+             statistic=statistic,
+             statistics_unsmoothed=statistics_unsmoothed,
+             return_labels=return_labels,
+         )
+         # Unpack the returned data depending on the return_labels flag
+         if return_labels:
+             label_fields_i, features_thresholds_i = timestep_result
+             label_fields.loc[{time_var_name: time_i}] = label_fields_i
+         else:
+             features_thresholds_i = timestep_result
+
+         list_features_timesteps.append(features_thresholds_i)
+
+         logging.debug("Finished feature detection for %s", time_i)
+
+ logging.debug("feature detection: merging DataFrames")
1486
+ # Check if features are detected and then concatenate features from different timesteps into
1487
+ # one pandas DataFrame
1488
+ # If no features are detected raise error
1489
+ if any([not x.empty for x in list_features_timesteps]):
1490
+ features = pd.concat(list_features_timesteps, ignore_index=True)
1491
+ features["feature"] = features.index + feature_number_start
1492
+
1493
+ if use_standard_names is None:
1494
+ use_standard_names = True if converted_from_iris else False
1495
+
1496
+ if "vdim" in features:
1497
+ features = add_coordinates_3D(
1498
+ features,
1499
+ field_in,
1500
+ vertical_coord=vertical_coord,
1501
+ use_standard_names=use_standard_names,
1502
+ )
1503
+ else:
1504
+ features = add_coordinates(
1505
+ features,
1506
+ field_in,
1507
+ use_standard_names=use_standard_names,
1508
+ )
1509
+
1510
+ # Loop over DataFrame to remove features that are closer than distance_min to each
1511
+ # other:
1512
+ filtered_features = []
1513
+ if min_distance > 0:
1514
+ hdim1_ax, hdim2_ax = internal_utils.find_hdim_axes_3D(
1515
+ field_in, vertical_coord=vertical_coord
1516
+ )
1517
+ hdim1_max = field_in.shape[hdim1_ax] - 1
1518
+ hdim2_max = field_in.shape[hdim2_ax] - 1
1519
+
1520
+ for _, features_frame in features.groupby("frame"):
1521
+ filtered_features.append(
1522
+ filter_min_distance(
1523
+ features_frame,
1524
+ dxy=dxy,
1525
+ dz=dz if use_dz_for_filtering else None,
1526
+ min_distance=min_distance,
1527
+ z_coordinate_name=(
1528
+ None if use_dz_for_filtering else vertical_coord
1529
+ ),
1530
+ target=target,
1531
+ PBC_flag=PBC_flag,
1532
+ min_h1=0,
1533
+ max_h1=hdim1_max,
1534
+ min_h2=0,
1535
+ max_h2=hdim2_max,
1536
+ )
1537
+ )
1538
+ features = pd.concat(filtered_features, ignore_index=True)
1539
+
1540
+         # Remap the label fields from the per-timestep region index (idx) to
+         # the final feature numbers
+         if return_labels:
+             for i, time_i, label_field_i, features_i in field_and_features_over_time(
+                 label_fields, features
+             ):
+                 wh_all_labels = np.isin(label_field_i, features_i.idx)
+
+                 remapper = xr.DataArray(
+                     features_i.feature, dims=("idx",), coords=dict(idx=features_i.idx)
+                 )
+
+                 label_fields[i].data[wh_all_labels] = remapper.loc[
+                     label_field_i.data[wh_all_labels]
+                 ]
+                 label_fields[i].data[~wh_all_labels] = 0
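+                 # For illustration (hypothetical values): if features_i.idx is
+                 # [1, 2, 5] and the corresponding features_i.feature is
+                 # [10, 11, 12], then remapper.loc[[5, 1]] evaluates to [12, 10],
+                 # so raw region labels become final feature numbers and labels
+                 # of filtered-out regions are zeroed.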
+
+     else:
+         features = None
+         label_fields = None
+         logging.debug("No features detected")
+
+     logging.debug("feature detection completed")
+
+     # Create the final output
+     if return_labels:
+         return label_fields, features
+     else:
+         return features
+
+
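+ # Usage sketch (illustrative, not part of the tobac source): variable names and
+ # threshold values below are hypothetical. For an xarray DataArray `field` with
+ # dims (time, y, x) on a 4 km grid, the function above is typically called as:
+ #
+ #     features = feature_detection_multithreshold(
+ #         field, dxy=4000, threshold=[5, 10], target="maximum", min_distance=20000
+ #     )
+ #     labels, features = feature_detection_multithreshold(
+ #         field, dxy=4000, threshold=[5, 10], return_labels=True
+ #     )
+ #
+ # Each row of the returned DataFrame describes one detected feature; with
+ # return_labels=True the label fields are returned alongside the features.
+
+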
+ def filter_min_distance(
+     features: pd.DataFrame,
+     dxy: float = None,
+     dz: float = None,
+     min_distance: float = None,
+     x_coordinate_name: str = None,
+     y_coordinate_name: str = None,
+     z_coordinate_name: str = None,
+     target: Literal["maximum", "minimum"] = "maximum",
+     PBC_flag: Literal["none", "hdim_1", "hdim_2", "both"] = "none",
+     min_h1: int = 0,
+     max_h1: int = 0,
+     min_h2: int = 0,
+     max_h2: int = 0,
+ ) -> pd.DataFrame:
+ """Function to remove features that are too close together.
1587
+ If two features are closer than `min_distance`, it keeps the
1588
+ larger feature.
1589
+
1590
+ :hidden:
1591
+
1592
+
1593
+ Parameters
1594
+ ----------
1595
+ features: pandas DataFrame
1596
+ features
1597
+ dxy: float
1598
+ Constant horzontal grid spacing (meters).
1599
+ dz: float
1600
+ Constant vertical grid spacing (meters), optional. If not specified
1601
+ and the input is 3D, this function requires that `z_coordinate_name` is available
1602
+ in the `features` input. If you specify a value here, this function assumes
1603
+ that it is the constant z spacing between points, even if ```z_coordinate_name```
1604
+ is specified.
1605
+ min_distance: float
1606
+ minimum distance between detected features (meters)
1607
+ x_coordinate_name: str
1608
+ The name of the x coordinate to calculate distance based on in meters.
1609
+ This is typically `projection_x_coordinate`. Currently unused.
1610
+ y_coordinate_name: str
1611
+ The name of the y coordinate to calculate distance based on in meters.
1612
+ This is typically `projection_y_coordinate`. Currently unused.
1613
+ z_coordinate_name: str or None
1614
+ The name of the z coordinate to calculate distance based on in meters.
1615
+ This is typically `altitude`. If None, tries to auto-detect.
1616
+ target: {'maximum', 'minimum'}, optional
1617
+ Flag to determine if tracking is targeting minima or maxima in
1618
+ the data. Default is 'maximum'.
1619
+ PBC_flag : str('none', 'hdim_1', 'hdim_2', 'both'), optional
1620
+ Sets whether to use periodic boundaries, and if so in which directions.
1621
+ 'none' means that we do not have periodic boundaries
1622
+ 'hdim_1' means that we are periodic along hdim1
1623
+ 'hdim_2' means that we are periodic along hdim2
1624
+ 'both' means that we are periodic along both horizontal dimensions
1625
+ min_h1: int, optional
1626
+ Minimum real point in hdim_1, for use with periodic boundaries.
1627
+ max_h1: int, optional
1628
+ Maximum point in hdim_1, exclusive. max_h1-min_h1 should be the size.
1629
+ min_h2: int, optional
1630
+ Minimum real point in hdim_2, for use with periodic boundaries.
1631
+ max_h2: int, optional
1632
+ Maximum point in hdim_2, exclusive. max_h2-min_h2 should be the size.
1633
+
1634
+ Returns
1635
+ -------
1636
+ pandas DataFrame
1637
+ features after filtering
1638
+ """
1639
+     # Optional coordinate names are not yet implemented, set to defaults here:
+     if dxy is None:
+         raise NotImplementedError("dxy currently must be set.")
+
+     # Check whether both dxy and the horizontal coordinate names are specified.
+     # If they are, warn that we will use dxy.
+     elif x_coordinate_name in features and y_coordinate_name in features:
+         warnings.warn(
+             "Both " + x_coordinate_name + "/" + y_coordinate_name + " and dxy "
+             "set. Using constant dxy. Set dxy to None if you want to use the "
+             "interpolated coordinates, or set `x_coordinate_name` and "
+             "`y_coordinate_name` to None to use a constant dxy."
+         )
+         y_coordinate_name = "hdim_1"
+         x_coordinate_name = "hdim_2"
+     # If only dxy is given, use hdim_1, hdim_2 as the default horizontal dimensions
+     else:
+         y_coordinate_name = "hdim_1"
+         x_coordinate_name = "hdim_2"
+
+     # if we are 3D, the vertical dimension is in features
+     is_3D = "vdim" in features
+     if is_3D:
+         if dz is None:
+             # Find the vertical coordinate name and set dz to 1 so that the
+             # coordinate values (in meters) are used directly
+             z_coordinate_name = internal_utils.find_dataframe_vertical_coord(
+                 variable_dataframe=features, vertical_coord=z_coordinate_name
+             )
+             dz = 1
+         else:
+             # Use dz; warn if both dz and a coordinate name are set
+             if z_coordinate_name is not None:
+                 warnings.warn(
+                     "Both "
+                     + z_coordinate_name
+                     + " and dz available to filter_min_distance; using constant dz. "
+                     "Set dz to None if you want to use altitude or set "
+                     "`z_coordinate_name` to None to use constant dz.",
+                     UserWarning,
+                 )
+             z_coordinate_name = "vdim"
+
+     if target not in ["minimum", "maximum"]:
+         raise ValueError(
+             "target parameter must be set to either 'minimum' or 'maximum'"
+         )
+
+     # Calculate feature locations in cartesian coordinates
+     if is_3D:
+         feature_locations = features[
+             [z_coordinate_name, y_coordinate_name, x_coordinate_name]
+         ].to_numpy()
+         feature_locations[:, 0] *= dz
+         feature_locations[:, 1:] *= dxy
+     else:
+         feature_locations = (
+             features[[y_coordinate_name, x_coordinate_name]].to_numpy() * dxy
+         )
+
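+     # For illustration (hypothetical values): with dxy = 1000 m, dz = 500 m and
+     # a feature at (vdim, hdim_1, hdim_2) = (3, 10, 20), the cartesian location
+     # computed above is (1500, 10000, 20000) in meters.
+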
+     # Create array of flags for features to remove
+     removal_flag = np.zeros(len(features), dtype=bool)
+
+     # Create a tree of feature locations in cartesian coordinates and find, for
+     # each feature, all neighbours within min_distance. Check if we have PBCs.
+     if PBC_flag in ["hdim_1", "hdim_2", "both"]:
+         # Note that we multiply by dxy to get the distances in spatial coordinates
+         dist_func = pbc_utils.build_distance_function(
+             min_h1 * dxy, max_h1 * dxy, min_h2 * dxy, max_h2 * dxy, PBC_flag, is_3D
+         )
+         features_tree = BallTree(feature_locations, metric="pyfunc", func=dist_func)
+         neighbours = features_tree.query_radius(feature_locations, r=min_distance)
+     else:
+         features_tree = KDTree(feature_locations)
+         # Find neighbours for each point
+         neighbours = features_tree.query_ball_tree(features_tree, r=min_distance)
+
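+     # In both branches, `neighbours` holds one sequence of feature indices per
+     # feature, each including the queried feature itself. For illustration
+     # (hypothetical values): neighbours = [[0], [1, 2], [1, 2]] means feature 0
+     # is isolated, while features 1 and 2 lie within min_distance of each other
+     # and are resolved against each other below.
+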
+     # Iterate over the list of neighbours to find which features to remove
+     for i, neighbour_list in enumerate(neighbours):
+         if len(neighbour_list) > 1:
+             # Remove the feature we're interested in, as it is always included
+             # in its own neighbourhood
+             neighbour_list = list(neighbour_list)
+             neighbour_list.remove(i)
+             # If the target is 'maximum', check whether any neighbour has a
+             # larger threshold value
+             if target == "maximum" and np.any(
+                 features["threshold_value"].iloc[neighbour_list]
+                 > features["threshold_value"].iloc[i]
+             ):
+                 removal_flag[i] = True
+             # If the target is 'minimum', check whether any neighbour has a
+             # smaller threshold value
+             elif target == "minimum" and np.any(
+                 features["threshold_value"].iloc[neighbour_list]
+                 < features["threshold_value"].iloc[i]
+             ):
+                 removal_flag[i] = True
+             # Otherwise, break ties between neighbours with equal threshold values
+             else:
+                 wh_equal_threshold = (
+                     features["threshold_value"].iloc[neighbour_list]
+                     == features["threshold_value"].iloc[i]
+                 )
+                 if np.any(wh_equal_threshold):
+                     # Remove this feature if any equal-threshold neighbour
+                     # covers a larger number of points
+                     if np.any(
+                         features["num"].iloc[neighbour_list][wh_equal_threshold]
+                         > features["num"].iloc[i]
+                     ):
+                         removal_flag[i] = True
+                     # Otherwise, remove it if an equal-sized neighbour has a
+                     # lower index value
+                     else:
+                         wh_equal_area = (
+                             features["num"].iloc[neighbour_list][wh_equal_threshold]
+                             == features["num"].iloc[i]
+                         )
+                         if np.any(wh_equal_area):
+                             if np.any(wh_equal_area.index[wh_equal_area] < i):
+                                 removal_flag[i] = True
+
+     # Return the features that are not flagged for removal
+     return features.iloc[~removal_flag]
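+
+
+ # Usage sketch (illustrative, not part of the tobac source): names and values
+ # are hypothetical. For a single-frame feature DataFrame `features_frame` with
+ # hdim_1/hdim_2 positions and threshold_value/num columns, on a 1 km grid:
+ #
+ #     filtered = filter_min_distance(
+ #         features_frame, dxy=1000, min_distance=5000, target="maximum"
+ #     )
+ #
+ # For any pair of features closer than 5 km, only the feature with the more
+ # extreme threshold value (ties broken by size, then by index) is kept.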