tobac-1.6.2-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. tobac/__init__.py +112 -0
  2. tobac/analysis/__init__.py +31 -0
  3. tobac/analysis/cell_analysis.py +628 -0
  4. tobac/analysis/feature_analysis.py +212 -0
  5. tobac/analysis/spatial.py +619 -0
  6. tobac/centerofgravity.py +226 -0
  7. tobac/feature_detection.py +1758 -0
  8. tobac/merge_split.py +324 -0
  9. tobac/plotting.py +2321 -0
  10. tobac/segmentation/__init__.py +10 -0
  11. tobac/segmentation/watershed_segmentation.py +1316 -0
  12. tobac/testing.py +1179 -0
  13. tobac/tests/segmentation_tests/test_iris_xarray_segmentation.py +0 -0
  14. tobac/tests/segmentation_tests/test_segmentation.py +1183 -0
  15. tobac/tests/segmentation_tests/test_segmentation_time_pad.py +104 -0
  16. tobac/tests/test_analysis_spatial.py +1109 -0
  17. tobac/tests/test_convert.py +265 -0
  18. tobac/tests/test_datetime.py +216 -0
  19. tobac/tests/test_decorators.py +148 -0
  20. tobac/tests/test_feature_detection.py +1321 -0
  21. tobac/tests/test_generators.py +273 -0
  22. tobac/tests/test_import.py +24 -0
  23. tobac/tests/test_iris_xarray_match_utils.py +244 -0
  24. tobac/tests/test_merge_split.py +351 -0
  25. tobac/tests/test_pbc_utils.py +497 -0
  26. tobac/tests/test_sample_data.py +197 -0
  27. tobac/tests/test_testing.py +747 -0
  28. tobac/tests/test_tracking.py +714 -0
  29. tobac/tests/test_utils.py +650 -0
  30. tobac/tests/test_utils_bulk_statistics.py +789 -0
  31. tobac/tests/test_utils_coordinates.py +328 -0
  32. tobac/tests/test_utils_internal.py +97 -0
  33. tobac/tests/test_xarray_utils.py +232 -0
  34. tobac/tracking.py +613 -0
  35. tobac/utils/__init__.py +27 -0
  36. tobac/utils/bulk_statistics.py +360 -0
  37. tobac/utils/datetime.py +184 -0
  38. tobac/utils/decorators.py +540 -0
  39. tobac/utils/general.py +753 -0
  40. tobac/utils/generators.py +87 -0
  41. tobac/utils/internal/__init__.py +2 -0
  42. tobac/utils/internal/coordinates.py +430 -0
  43. tobac/utils/internal/iris_utils.py +462 -0
  44. tobac/utils/internal/label_props.py +82 -0
  45. tobac/utils/internal/xarray_utils.py +439 -0
  46. tobac/utils/mask.py +364 -0
  47. tobac/utils/periodic_boundaries.py +419 -0
  48. tobac/wrapper.py +244 -0
  49. tobac-1.6.2.dist-info/METADATA +154 -0
  50. tobac-1.6.2.dist-info/RECORD +53 -0
  51. tobac-1.6.2.dist-info/WHEEL +5 -0
  52. tobac-1.6.2.dist-info/licenses/LICENSE +29 -0
  53. tobac-1.6.2.dist-info/top_level.txt +1 -0
tobac/tests/test_feature_detection.py
@@ -0,0 +1,1321 @@
1
+ import cftime
2
+ import tobac
3
+ import tobac.testing as tbtest
4
+ import tobac.feature_detection as feat_detect
5
+ import pytest
6
+ import numpy as np
7
+ import xarray as xr
8
+ from pandas.testing import assert_frame_equal
9
+
10
+
11
+ @pytest.mark.parametrize(
12
+ "test_threshs, n_min_threshold, dxy, wavelength_filtering",
13
+ [
14
+ ([1.5], 2, -1, None),
15
+ ([1, 1.5, 2], 2, 10000, (100 * 1000, 500 * 1000)),
16
+ ([1, 2, 1.5], [3, 1, 2], -1, None),
17
+ ([1, 1.5, 2], {1.5: 2, 1: 3, 2: 1}, -1, None),
18
+ ],
19
+ )
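+ # n_min_threshold is exercised in all three accepted forms: a single int, a per-threshold list, and a dict keyed by threshold value.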
20
+ def test_feature_detection_multithreshold_timestep(
21
+ test_threshs, n_min_threshold, dxy, wavelength_filtering
22
+ ):
23
+ """
24
+ Tests ```tobac.feature_detection.feature_detection_multithreshold_timestep```
25
+ """
26
+
27
+ # start by building a simple dataset with a single feature and seeing
28
+ # if we identify it
29
+
30
+ test_dset_size = (50, 50)
31
+ test_hdim_1_pt = 20.0
32
+ test_hdim_2_pt = 20.0
33
+ test_hdim_1_sz = 5
34
+ test_hdim_2_sz = 5
35
+ test_amp = 2
36
+
37
+ test_data = np.zeros(test_dset_size)
38
+ test_data = tbtest.make_feature_blob(
39
+ test_data,
40
+ test_hdim_1_pt,
41
+ test_hdim_2_pt,
42
+ h1_size=test_hdim_1_sz,
43
+ h2_size=test_hdim_2_sz,
44
+ amplitude=test_amp,
45
+ )
46
+ test_data_xr = tbtest.make_dataset_from_arr(test_data, data_type="xarray")
47
+ fd_output = feat_detect.feature_detection_multithreshold_timestep(
48
+ test_data_xr,
49
+ 0,
50
+ threshold=test_threshs,
51
+ n_min_threshold=n_min_threshold,
52
+ dxy=dxy,
53
+ wavelength_filtering=wavelength_filtering,
54
+ )
55
+
56
+ # Make sure we have only one feature
57
+ assert len(fd_output) == 1, f"Expected 1 feature, but got {len(fd_output)}"
58
+ # Make sure that the location of the feature is correct
59
+ assert fd_output.iloc[0]["hdim_1"] == pytest.approx(
60
+ test_hdim_1_pt
61
+ ), f"Expected hdim_1 to be {test_hdim_1_pt}, but got {fd_output.iloc[0]['hdim_1']}"
62
+ assert fd_output.iloc[0]["hdim_2"] == pytest.approx(
63
+ test_hdim_2_pt
64
+ ), f"Expected hdim_2 to be {test_hdim_2_pt}, but got {fd_output.iloc[0]['hdim_2']}"
65
+
66
+ labels, features = feat_detect.feature_detection_multithreshold_timestep(
67
+ test_data_xr,
68
+ 0,
69
+ threshold=test_threshs,
70
+ n_min_threshold=n_min_threshold,
71
+ dxy=dxy,
72
+ wavelength_filtering=wavelength_filtering,
73
+ return_labels=True,
74
+ )
75
+
76
+ # Make sure we have only one feature
77
+ assert (
78
+ len(features.index) == 1
79
+ ), f"Expected 1 feature, but got {len(features.index)}"
80
+
81
+ # Check if labels are returned
82
+ assert isinstance(
83
+ labels, xr.DataArray
84
+ ), "Expected label fields to be a xarray.DataArray"
85
+
86
+ # Check if labels have the correct shape
87
+ assert labels.shape == (
88
+ test_data_xr.shape[0],
89
+ test_data_xr.shape[1],
90
+ ), f"Expected labels shape to be {test_data_xr.shape}, but got {labels.shape}"
91
+
92
+ # Ensure labels have at least one non-zero entry
93
+ assert (labels > 0).any(), "No labels detected in the labels array"
94
+
95
+ # Optionally check for the threshold attribute
96
+ assert hasattr(labels, "threshold"), "Expected 'threshold' attribute in labels"
97
+ assert (
98
+ labels.attrs["threshold"] == test_threshs
99
+ ), f"Expected threshold to be {test_threshs}, but got {labels.attrs['threshold']}"
100
+
101
+ # All non-zero labels must match the feature IDs in the returned dataframe
102
+ nonzero_labels = labels.values[labels.values > 0]
103
+ unique_label_value = np.unique(nonzero_labels)[0]
104
+ feature_label_value = features["idx"].iloc[0]
105
+
106
+ assert (
107
+ unique_label_value == feature_label_value
108
+ ), f"Label field contains {unique_label_value}, but features dataframe idx is {feature_label_value}"
109
+
110
+ # All labeled points should be >= the feature's threshold value (default target is 'maximum')
111
+ mask = labels.values > 0
112
+ labeled_values = test_data_xr.values[mask]
113
+ feature_threshold = features["threshold_value"].iloc[0]
114
+ print(labeled_values)
115
+ assert np.all(labeled_values >= feature_threshold), (
116
+ f"Found labeled pixels below threshold {feature_threshold}. "
117
+ f"Minimum labeled value is {labeled_values.min()}"
118
+ )
119
+
120
+
121
+ @pytest.mark.parametrize(
122
+ "position_threshold", [("center"), ("extreme"), ("weighted_diff"), ("weighted_abs")]
123
+ )
124
+ def test_feature_detection_position(position_threshold):
125
+ """
126
+ Tests to make sure that all feature detection position_thresholds work.
127
+ """
128
+
129
+ test_dset_size = (50, 50)
130
+
131
+ test_data = np.zeros(test_dset_size)
132
+
133
+ test_data[0:5, 0:5] = 3
134
+ test_threshs = [
135
+ 1.5,
136
+ ]
137
+ test_min_num = 2
138
+
139
+ test_data_iris = tbtest.make_dataset_from_arr(test_data, data_type="iris")
140
+
141
+ fd_output = feat_detect.feature_detection_multithreshold_timestep(
142
+ test_data_iris,
143
+ 0,
144
+ threshold=test_threshs,
145
+ n_min_threshold=test_min_num,
146
+ position_threshold=position_threshold,
147
+ )
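+ # Smoke test: we only check that each position_threshold option runs without raising.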
148
+
149
+ pass
150
+
151
+
152
+ @pytest.mark.parametrize(
153
+ "feature_1_loc, feature_2_loc, dxy, dz, min_distance,"
154
+ "target, add_x_coords, add_y_coords,"
155
+ "add_z_coords, PBC_flag, expect_feature_1, expect_feature_2",
156
+ [
157
+ ( # If separation greater than min_distance, keep both features
158
+ (0, 0, 0, 4, 1),
159
+ (1, 1, 1, 4, 1),
160
+ 1000,
161
+ 100,
162
+ 1,
163
+ "maximum",
164
+ False,
165
+ False,
166
+ False,
167
+ "none",
168
+ True,
169
+ True,
170
+ ),
171
+ ( # Keep feature 1 by area
172
+ (0, 0, 0, 4, 1),
173
+ (1, 1, 1, 3, 1),
174
+ 1000,
175
+ 100,
176
+ 5000,
177
+ "maximum",
178
+ False,
179
+ False,
180
+ False,
181
+ "none",
182
+ True,
183
+ False,
184
+ ),
185
+ ( # Keep feature 2 by area
186
+ (0, 0, 0, 4, 1),
187
+ (1, 1, 1, 6, 1),
188
+ 1000,
189
+ 100,
190
+ 5000,
191
+ "maximum",
192
+ False,
193
+ False,
194
+ False,
195
+ "none",
196
+ False,
197
+ True,
198
+ ),
199
+ ( # Keep feature 1 by area
200
+ (0, 0, 0, 4, 1),
201
+ (1, 1, 1, 3, 1),
202
+ 1000,
203
+ 100,
204
+ 5000,
205
+ "minimum",
206
+ False,
207
+ False,
208
+ False,
209
+ "none",
210
+ True,
211
+ False,
212
+ ),
213
+ ( # Keep feature 2 by area
214
+ (0, 0, 0, 4, 1),
215
+ (1, 1, 1, 6, 1),
216
+ 1000,
217
+ 100,
218
+ 5000,
219
+ "minimum",
220
+ False,
221
+ False,
222
+ False,
223
+ "none",
224
+ False,
225
+ True,
226
+ ),
227
+ ( # Keep feature 1 by maximum threshold
228
+ (0, 0, 0, 4, 2),
229
+ (1, 1, 1, 10, 1),
230
+ 1000,
231
+ 100,
232
+ 5000,
233
+ "maximum",
234
+ False,
235
+ False,
236
+ False,
237
+ "none",
238
+ True,
239
+ False,
240
+ ),
241
+ ( # Keep feature 2 by maximum threshold
242
+ (0, 0, 0, 4, 2),
243
+ (1, 1, 1, 10, 3),
244
+ 1000,
245
+ 100,
246
+ 5000,
247
+ "maximum",
248
+ False,
249
+ False,
250
+ False,
251
+ "none",
252
+ False,
253
+ True,
254
+ ),
255
+ ( # Keep feature 1 by minimum threshold
256
+ (0, 0, 0, 4, -1),
257
+ (1, 1, 1, 10, 1),
258
+ 1000,
259
+ 100,
260
+ 5000,
261
+ "minimum",
262
+ False,
263
+ False,
264
+ False,
265
+ "none",
266
+ True,
267
+ False,
268
+ ),
269
+ ( # Keep feature 2 by minimum threshold
270
+ (0, 0, 0, 4, 2),
271
+ (1, 1, 1, 10, 1),
272
+ 1000,
273
+ 100,
274
+ 5000,
275
+ "minimum",
276
+ False,
277
+ False,
278
+ False,
279
+ "none",
280
+ False,
281
+ True,
282
+ ),
283
+ ( # Keep feature 1 by tie-break
284
+ (0, 0, 0, 4, 2),
285
+ (1, 1, 1, 4, 2),
286
+ 1000,
287
+ 100,
288
+ 5000,
289
+ "maximum",
290
+ False,
291
+ False,
292
+ False,
293
+ "none",
294
+ True,
295
+ False,
296
+ ),
297
+ ( # Keep feature 1 by tie-break
298
+ (0, 0, 0, 4, 2),
299
+ (1, 1, 1, 4, 2),
300
+ 1000,
301
+ 100,
302
+ 5000,
303
+ "minimum",
304
+ False,
305
+ False,
306
+ False,
307
+ "none",
308
+ True,
309
+ False,
310
+ ),
311
+ ( # If target is not maximum or minimum raise ValueError
312
+ (0, 0, 0, 4, 1),
313
+ (1, 1, 1, 4, 1),
314
+ 1000,
315
+ 100,
316
+ 1,
317
+ "__invalid_option__",
318
+ False,
319
+ False,
320
+ False,
321
+ "none",
322
+ False,
323
+ False,
324
+ ),
325
+ ( # test hdim_1 PBCs
326
+ (0, 0, 0, 4, 3),
327
+ (1, 99, 0, 4, 1),
328
+ 1000,
329
+ 100,
330
+ 3000,
331
+ "maximum",
332
+ False,
333
+ False,
334
+ False,
335
+ "hdim_1",
336
+ True,
337
+ False,
338
+ ),
339
+ ( # test hdim_2 PBCs - false case
340
+ (0, 0, 0, 4, 3),
341
+ (1, 99, 0, 4, 1),
342
+ 1000,
343
+ 100,
344
+ 3000,
345
+ "maximum",
346
+ False,
347
+ False,
348
+ False,
349
+ "hdim_2",
350
+ True,
351
+ True,
352
+ ),
353
+ ( # test hdim_2 PBCs - true case
354
+ (0, 0, 0, 4, 3),
355
+ (1, 0, 99, 4, 1),
356
+ 1000,
357
+ 100,
358
+ 3000,
359
+ "maximum",
360
+ False,
361
+ False,
362
+ False,
363
+ "hdim_2",
364
+ True,
365
+ False,
366
+ ),
367
+ ( # test both PBCs - true case
368
+ (0, 0, 0, 4, 3),
369
+ (1, 99, 99, 4, 1),
370
+ 1000,
371
+ 100,
372
+ 3000,
373
+ "maximum",
374
+ False,
375
+ False,
376
+ False,
377
+ "both",
378
+ True,
379
+ False,
380
+ ),
381
+ ( # Test using z coord name
382
+ (0, 0, 0, 4, 1),
383
+ (1, 1, 1, 4, 1),
384
+ 1000,
385
+ None,
386
+ 1,
387
+ "maximum",
388
+ False,
389
+ False,
390
+ True,
391
+ "none",
392
+ True,
393
+ True,
394
+ ),
395
+ ( # Test using z coord name
396
+ (0, 0, 0, 5, 1),
397
+ (1, 1, 1, 4, 1),
398
+ 1,
399
+ None,
400
+ 101,
401
+ "maximum",
402
+ False,
403
+ False,
404
+ True,
405
+ "none",
406
+ True,
407
+ False,
408
+ ),
409
+ ],
410
+ )
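+ # Features further apart than min_distance are both kept; otherwise the feature at the more extreme threshold survives, then the larger one, and exact ties keep the first feature.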
411
+ def test_filter_min_distance(
412
+ feature_1_loc,
413
+ feature_2_loc,
414
+ dxy,
415
+ dz,
416
+ min_distance,
417
+ target,
418
+ add_x_coords,
419
+ add_y_coords,
420
+ add_z_coords,
421
+ PBC_flag,
422
+ expect_feature_1,
423
+ expect_feature_2,
424
+ ):
425
+ """Tests tobac.feature_detection.filter_min_distance
426
+ Parameters
427
+ ----------
428
+ feature_1_loc: tuple, length of 4 or 5
429
+ Feature 1 location, num, and threshold value (assumes a 100 x 100 x 100 grid).
430
+ Assumes z, y, x, num, threshold_value for 3D where num is the size/ 'num'
431
+ column of the feature and threshold_value is the detection threshold value.
432
+ If 2D, assumes y, x, num, threshold_value.
433
+ feature_2_loc: tuple, length of 4 or 5
434
+ Feature 2 location, same format and length as `feature_1_loc`
435
+ dxy: float or None
436
+ Horizontal grid spacing
437
+ dz: float or None
438
+ Vertical grid spacing (constant)
439
+ min_distance: float
440
+ Minimum distance between features (m)
441
+ target: str ["maximum" | "minimum"]
442
+ Target maxima or minima threshold for selecting which feature to keep
443
+ add_x_coords: bool
444
+ Whether or not to add x coordinates
445
+ add_y_coords: bool
446
+ Whether or not to add y coordinates
447
+ add_z_coords: bool
448
+ Whether or not to add z coordinates
449
+ PBC_flag : str('none', 'hdim_1', 'hdim_2', 'both')
450
+ Sets whether to use periodic boundaries, and if so in which directions.
451
+ 'none' means that we do not have periodic boundaries
452
+ 'hdim_1' means that we are periodic along hdim1
453
+ 'hdim_2' means that we are periodic along hdim2
454
+ 'both' means that we are periodic along both horizontal dimensions
455
+ expect_feature_1: bool
456
+ True if we expect feature 1 to remain, false if we expect it gone.
457
+ expect_feature_2: bool
458
+ True if we expect feature 2 to remain, false if we expect it gone.
459
+ """
460
+ import pandas as pd
461
+ import numpy as np
462
+
463
+ h1_max = 100
464
+ h2_max = 100
465
+ z_max = 100
466
+
467
+ assumed_dxy = 100
468
+ assumed_dz = 100
469
+
470
+ x_coord_name = "projection_coord_x"
471
+ y_coord_name = "projection_coord_y"
472
+ z_coord_name = "projection_coord_z"
473
+
474
+ is_3D = len(feature_1_loc) == 5
475
+ start_size_loc = 3 if is_3D else 2
476
+ start_h1_loc = 1 if is_3D else 0
477
+ feat_opts_f1 = {
478
+ "start_h1": feature_1_loc[start_h1_loc],
479
+ "start_h2": feature_1_loc[start_h1_loc + 1],
480
+ "max_h1": h1_max,
481
+ "max_h2": h2_max,
482
+ "feature_size": feature_1_loc[start_size_loc],
483
+ "threshold_val": feature_1_loc[start_size_loc + 1],
484
+ "feature_num": 1,
485
+ }
486
+
487
+ feat_opts_f2 = {
488
+ "start_h1": feature_2_loc[start_h1_loc],
489
+ "start_h2": feature_2_loc[start_h1_loc + 1],
490
+ "max_h1": h1_max,
491
+ "max_h2": h2_max,
492
+ "feature_size": feature_2_loc[start_size_loc],
493
+ "threshold_val": feature_2_loc[start_size_loc + 1],
494
+ "feature_num": 2,
495
+ }
496
+ if is_3D:
497
+ feat_opts_f1["start_v"] = feature_1_loc[0]
498
+ feat_opts_f2["start_v"] = feature_2_loc[0]
499
+
500
+ feat_1_interp = tbtest.generate_single_feature(**feat_opts_f1)
501
+ feat_2_interp = tbtest.generate_single_feature(**feat_opts_f2)
502
+
503
+ feat_combined = pd.concat([feat_1_interp, feat_2_interp], ignore_index=True)
504
+
505
+ filter_dist_opts = {
506
+ "features": feat_combined,
507
+ "dxy": dxy,
508
+ "dz": dz,
509
+ "min_distance": min_distance,
510
+ "target": target,
511
+ "PBC_flag": PBC_flag,
512
+ "min_h1": 0,
513
+ "max_h1": 100,
514
+ "min_h2": 0,
515
+ "max_h2": 100,
516
+ }
517
+ if add_x_coords:
518
+ feat_combined[x_coord_name] = feat_combined["hdim_2"] * assumed_dxy
519
+ filter_dist_opts["x_coordinate_name"] = x_coord_name
520
+ if add_y_coords:
521
+ feat_combined[y_coord_name] = feat_combined["hdim_1"] * assumed_dxy
522
+ filter_dist_opts["y_coordinate_name"] = y_coord_name
523
+ if add_z_coords and is_3D:
524
+ feat_combined[z_coord_name] = feat_combined["vdim"] * assumed_dz
525
+ filter_dist_opts["z_coordinate_name"] = z_coord_name
526
+
527
+ if target not in ["maximum", "minimum"]:
528
+ with pytest.raises(ValueError):
529
+ out_feats = feat_detect.filter_min_distance(**filter_dist_opts)
530
+
531
+ else:
532
+ out_feats = feat_detect.filter_min_distance(**filter_dist_opts)
533
+
534
+ assert expect_feature_1 == (np.sum(out_feats["feature"] == 1) == 1)
535
+ assert expect_feature_2 == (np.sum(out_feats["feature"] == 2) == 1)
536
+
537
+
538
+ @pytest.mark.parametrize(
539
+ "test_dset_size, vertical_axis_num, "
540
+ "vertical_coord_name,"
541
+ " vertical_coord_opt, expected_raise",
542
+ [
543
+ ((1, 20, 30, 40), 1, "altitude", None, False),
544
+ ((1, 20, 30, 40), 2, "altitude", None, False),
545
+ ((1, 20, 30, 40), 3, "altitude", None, False),
546
+ ((1, 20, 30, 40), 1, "air_pressure", "air_pressure", False),
547
+ ((1, 20, 30, 40), 1, "air_pressure", None, True),
548
+ ((1, 20, 30, 40), 1, "model_level_number", None, False),
549
+ ((1, 20, 30, 40), 1, "altitude", None, False),
550
+ ((1, 20, 30, 40), 1, "geopotential_height", None, False),
551
+ ],
552
+ )
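+ # 'air_pressure' is only recognized when passed explicitly via vertical_coord; the auto-detected names (altitude, model_level_number, geopotential_height) need no option.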
553
+ def test_feature_detection_multiple_z_coords(
554
+ test_dset_size,
555
+ vertical_axis_num,
556
+ vertical_coord_name,
557
+ vertical_coord_opt,
558
+ expected_raise,
559
+ ):
560
+ """Tests ```tobac.feature_detection.feature_detection_multithreshold```
561
+ with different axes
562
+
563
+ Parameters
564
+ ----------
565
+ test_dset_size: tuple(int, int, int, int)
566
+ Size of the test dataset
567
+ vertical_axis_num: int (1-3, inclusive)
568
+ Which axis in test_dset_size is the vertical axis
569
+ vertical_coord_name: str
570
+ Name of the vertical coordinate.
571
+ vertical_coord_opt: str
572
+ What to pass in as the vertical_coord option to feature_detection_multithreshold
573
+ expected_raise: bool
574
+ True if we expect a ValueError to be raised, false otherwise
575
+ """
576
+ import numpy as np
577
+
578
+ # Build a test cube with the vertical axis in the requested position.
579
+ test_dxy = 1000
580
+ test_vdim_pt_1 = 8
581
+ test_hdim_1_pt_1 = 12
582
+ test_hdim_2_pt_1 = 12
583
+ test_data = np.zeros(test_dset_size)
584
+ test_data[0, 0:5, 0:5, 0:5] = 3
585
+ common_dset_opts = {
586
+ "in_arr": test_data,
587
+ "data_type": "iris",
588
+ "z_dim_name": vertical_coord_name,
589
+ }
590
+ if vertical_axis_num == 1:
591
+ test_data_iris = tbtest.make_dataset_from_arr(
592
+ time_dim_num=0, z_dim_num=1, y_dim_num=2, x_dim_num=3, **common_dset_opts
593
+ )
594
+ elif vertical_axis_num == 2:
595
+ test_data_iris = tbtest.make_dataset_from_arr(
596
+ time_dim_num=0, z_dim_num=2, y_dim_num=1, x_dim_num=3, **common_dset_opts
597
+ )
598
+ elif vertical_axis_num == 3:
599
+ test_data_iris = tbtest.make_dataset_from_arr(
600
+ time_dim_num=0, z_dim_num=3, y_dim_num=1, x_dim_num=2, **common_dset_opts
601
+ )
602
+
603
+ if not expected_raise:
604
+ out_df = feat_detect.feature_detection_multithreshold(
605
+ field_in=test_data_iris,
606
+ dxy=test_dxy,
607
+ threshold=[
608
+ 1.5,
609
+ ],
610
+ vertical_coord=vertical_coord_opt,
611
+ )
612
+ # Check that the vertical coordinate is returned.
613
+ print(out_df.columns)
614
+ assert vertical_coord_name in out_df
615
+ else:
616
+ # Expecting a raise
617
+ with pytest.raises(ValueError):
618
+ out_df = feat_detect.feature_detection_multithreshold(
619
+ field_in=test_data_iris,
620
+ dxy=test_dxy,
621
+ threshold=[
622
+ 1.5,
623
+ ],
624
+ vertical_coord=vertical_coord_opt,
625
+ )
626
+
627
+
628
+ def test_feature_detection_setting_multiple():
629
+ """Tests that an error is raised when vertical_axis and vertical_coord
630
+ are both set.
631
+ """
632
+ test_data = np.zeros((1, 5, 5, 5))
633
+ test_data[0, 0:5, 0:5, 0:5] = 3
634
+ common_dset_opts = {
635
+ "in_arr": test_data,
636
+ "data_type": "iris",
637
+ "z_dim_name": "altitude",
638
+ }
639
+ test_data_iris = tbtest.make_dataset_from_arr(
640
+ time_dim_num=0, z_dim_num=1, y_dim_num=2, x_dim_num=3, **common_dset_opts
641
+ )
642
+
643
+ with pytest.raises(ValueError):
644
+ _ = feat_detect.feature_detection_multithreshold(
645
+ field_in=test_data_iris,
646
+ dxy=10000,
647
+ threshold=[
648
+ 1.5,
649
+ ],
650
+ vertical_coord="altitude",
651
+ vertical_axis=1,
652
+ )
653
+
654
+
655
+ def test_feature_detection_multithreshold_returns():
656
+ """Tests regarding return_labels."""
657
+ test_data = np.zeros((1, 5, 5, 5))
658
+ test_data[0, 0:5, 0:5, 0:5] = 3
659
+ common_dset_opts = {
660
+ "in_arr": test_data,
661
+ "data_type": "xarray",
662
+ "z_dim_name": "altitude",
663
+ }
664
+ test_data_xr = tbtest.make_dataset_from_arr(
665
+ time_dim_num=0,
666
+ z_dim_num=1,
667
+ y_dim_num=2,
668
+ x_dim_num=3,
669
+ **common_dset_opts,
670
+ )
671
+
672
+ # Test when return_labels is True
673
+ labels, features = feat_detect.feature_detection_multithreshold(
674
+ field_in=test_data_xr,
675
+ dxy=10000,
676
+ threshold=[
677
+ 1.5,
678
+ ],
679
+ return_labels=True,
680
+ )
681
+ assert labels is not None, "Expected labels to be returned"
682
+ assert isinstance(labels, xr.DataArray), "Expected labels to be a xarray DataArray"
683
+ assert (labels > 0).any(), "Expected at least one labeled feature"
684
+
685
+
686
+ @pytest.mark.parametrize(
687
+ "test_threshs, target",
688
+ [
689
+ (([1, 2, 3], [3, 2, 1], [1, 3, 2]), "maximum"),
690
+ (([1, 2, 3], [3, 2, 1], [1, 3, 2]), "minimum"),
691
+ ],
692
+ )
693
+ def test_feature_detection_threshold_sort(test_threshs, target):
694
+ """Tests that feature detection is consistent regardless of what order they are in"""
695
+ test_dset_size = (50, 50)
696
+ test_hdim_1_pt = 20.0
697
+ test_hdim_2_pt = 20.0
698
+ test_hdim_1_sz = 5
699
+ test_hdim_2_sz = 5
700
+ test_amp = 2
701
+ test_min_num = 2
702
+
703
+ test_data = np.zeros(test_dset_size)
704
+ test_data = tbtest.make_feature_blob(
705
+ test_data,
706
+ test_hdim_1_pt,
707
+ test_hdim_2_pt,
708
+ h1_size=test_hdim_1_sz,
709
+ h2_size=test_hdim_2_sz,
710
+ amplitude=test_amp,
711
+ )
712
+ test_data_iris = tbtest.make_dataset_from_arr(test_data, data_type="iris")
713
+ fd_output_first = feat_detect.feature_detection_multithreshold_timestep(
714
+ test_data_iris,
715
+ 0,
716
+ threshold=test_threshs[0],
717
+ n_min_threshold=test_min_num,
718
+ dxy=1,
719
+ target=target,
720
+ )
721
+
722
+ for thresh_test in test_threshs[1:]:
723
+ fd_output_test = feat_detect.feature_detection_multithreshold_timestep(
724
+ test_data_iris,
725
+ 0,
726
+ threshold=thresh_test,
727
+ n_min_threshold=test_min_num,
728
+ dxy=1,
729
+ target=target,
730
+ )
731
+ assert_frame_equal(fd_output_first, fd_output_test)
732
+
733
+
734
+ @pytest.mark.parametrize(
735
+ "hdim_1_pt,"
736
+ "hdim_2_pt,"
737
+ "hdim_1_size,"
738
+ "hdim_2_size,"
739
+ "PBC_flag,"
740
+ "expected_center,",
741
+ [
742
+ (10, 10, 3, 3, "both", (10, 10)),
743
+ (0, 0, 3, 3, "both", (0, 0)),
744
+ (0, 0, 3, 3, "hdim_1", (0, 0.5)),
745
+ (0, 0, 3, 3, "hdim_2", (0.5, 0)),
746
+ (0, 10, 3, 3, "hdim_1", (0, 10)),
747
+ ],
748
+ )
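+ # Expected centers: a blob at the corner wraps along periodic dimensions (center stays at 0) and is clipped at non-periodic edges (center shifts to 0.5).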
749
+ def test_feature_detection_threshold_pbc(
750
+ hdim_1_pt, hdim_2_pt, hdim_1_size, hdim_2_size, PBC_flag, expected_center
751
+ ):
752
+ """Tests that feature detection works with periodic boundaries"""
753
+ test_dset_size = (50, 50)
754
+ test_amp = 2
755
+ test_min_num = 2
756
+
757
+ test_data = np.zeros(test_dset_size)
758
+ test_data = tbtest.make_feature_blob(
759
+ test_data,
760
+ hdim_1_pt,
761
+ hdim_2_pt,
762
+ h1_size=hdim_1_size,
763
+ h2_size=hdim_2_size,
764
+ amplitude=test_amp,
765
+ PBC_flag=PBC_flag,
766
+ )
767
+ # test_data_iris = tbtest.make_dataset_from_arr(test_data, data_type="iris")
768
+ fd_output_df, fd_output_reg = feat_detect.feature_detection_threshold(
769
+ test_data,
770
+ 0,
771
+ threshold=1,
772
+ n_min_threshold=test_min_num,
773
+ target="maximum",
774
+ PBC_flag=PBC_flag,
775
+ )
776
+ assert len(fd_output_df) == 1
777
+ assert fd_output_df["hdim_1"].values[0] == expected_center[0]
778
+ assert fd_output_df["hdim_2"].values[0] == expected_center[1]
779
+
780
+
781
+ def test_feature_detection_coords():
782
+ """Tests that the output features dataframe contains all the coords of the input iris cube"""
783
+ test_dset_size = (50, 50)
784
+ test_hdim_1_pt = 20.0
785
+ test_hdim_2_pt = 20.0
786
+ test_hdim_1_sz = 5
787
+ test_hdim_2_sz = 5
788
+ test_amp = 2
789
+ test_min_num = 2
790
+
791
+ test_data = np.zeros(test_dset_size)
792
+ test_data = tbtest.make_feature_blob(
793
+ test_data,
794
+ test_hdim_1_pt,
795
+ test_hdim_2_pt,
796
+ h1_size=test_hdim_1_sz,
797
+ h2_size=test_hdim_2_sz,
798
+ amplitude=test_amp,
799
+ )
800
+ test_data_xr = xr.DataArray(
801
+ test_data[np.newaxis, ...],
802
+ dims=("time", "y", "x"),
803
+ coords={
804
+ "time": [np.datetime64("2000-01-01T00:00:00")],
805
+ "y": np.arange(test_data.shape[0]),
806
+ "x": np.arange(test_data.shape[1]),
807
+ "2d_coord": xr.DataArray(np.random.rand(*test_data.shape), dims=("y", "x")),
808
+ },
809
+ )
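+ # The auxiliary 2d_coord checks that non-dimensional coordinates are also carried through to the output dataframe.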
810
+
811
+ fd_output = tobac.feature_detection.feature_detection_multithreshold(
812
+ test_data_xr,
813
+ threshold=[1, 2, 3],
814
+ n_min_threshold=test_min_num,
815
+ dxy=1,
816
+ target="maximum",
817
+ )
818
+
819
+ assert all([coord in fd_output for coord in test_data_xr.coords])
820
+
821
+ test_data_iris = test_data_xr.to_iris()
822
+
823
+ fd_output_iris = tobac.feature_detection.feature_detection_multithreshold(
824
+ test_data_iris,
825
+ threshold=[1, 2, 3],
826
+ n_min_threshold=test_min_num,
827
+ dxy=1,
828
+ target="maximum",
829
+ )
830
+
831
+ assert all([coord.name() in fd_output_iris for coord in test_data_iris.coords()])
832
+
833
+
834
+ def test_feature_detection_preserve_datetime():
835
+ """Tests that datetime output is of the correct type when converting to and from iris cubes"""
836
+ test_dset_size = (50, 50)
837
+ test_hdim_1_pt = 20.0
838
+ test_hdim_2_pt = 20.0
839
+ test_hdim_1_sz = 5
840
+ test_hdim_2_sz = 5
841
+ test_amp = 2
842
+ test_min_num = 2
843
+
844
+ test_data = np.zeros(test_dset_size)
845
+ test_data = tbtest.make_feature_blob(
846
+ test_data,
847
+ test_hdim_1_pt,
848
+ test_hdim_2_pt,
849
+ h1_size=test_hdim_1_sz,
850
+ h2_size=test_hdim_2_sz,
851
+ amplitude=test_amp,
852
+ )
853
+ test_data_xr = xr.DataArray(
854
+ test_data[np.newaxis, ...],
855
+ dims=("time", "y", "x"),
856
+ coords={
857
+ "time": [np.datetime64("2000-01-01T00:00:00")],
858
+ },
859
+ )
860
+
861
+ fd_output = tobac.feature_detection.feature_detection_multithreshold(
862
+ test_data_xr,
863
+ threshold=[1, 2, 3],
864
+ n_min_threshold=test_min_num,
865
+ dxy=1,
866
+ target="maximum",
867
+ )
868
+
869
+ assert isinstance(fd_output.time.to_numpy()[0], np.datetime64)
870
+
871
+ test_data_iris = test_data_xr.to_iris()
872
+
873
+ fd_output_iris_dt64 = tobac.feature_detection.feature_detection_multithreshold(
874
+ test_data_iris,
875
+ threshold=[1, 2, 3],
876
+ n_min_threshold=test_min_num,
877
+ dxy=1,
878
+ target="maximum",
879
+ preserve_iris_datetime_types=False,
880
+ )
881
+
882
+ assert isinstance(fd_output_iris_dt64.time.to_numpy()[0], np.datetime64)
883
+
884
+ fd_output_iris_cft = tobac.feature_detection.feature_detection_multithreshold(
885
+ test_data_iris,
886
+ threshold=[1, 2, 3],
887
+ n_min_threshold=test_min_num,
888
+ dxy=1,
889
+ target="maximum",
890
+ preserve_iris_datetime_types=True,
891
+ )
892
+
893
+ assert isinstance(
894
+ fd_output_iris_cft.time.to_numpy()[0], cftime.DatetimeProlepticGregorian
895
+ )
896
+
897
+
898
+ def test_feature_detection_preserve_datetime_3d():
899
+ """Tests that datetime output is of the correct type when converting to and from iris cubes with 3d data"""
900
+ test_dset_size = (10, 50, 50)
901
+ test_hdim_1_pt = 20.0
902
+ test_hdim_2_pt = 20.0
903
+ test_vdim_pt = 5
904
+ test_hdim_1_sz = 5
905
+ test_hdim_2_sz = 5
906
+ test_amp = 2
907
+ test_min_num = 2
908
+
909
+ test_data = np.zeros(test_dset_size)
910
+ test_data = tbtest.make_feature_blob(
911
+ test_data,
912
+ test_hdim_1_pt,
913
+ test_hdim_2_pt,
914
+ test_vdim_pt,
915
+ h1_size=test_hdim_1_sz,
916
+ h2_size=test_hdim_2_sz,
917
+ amplitude=test_amp,
918
+ )
919
+ test_data_xr = xr.DataArray(
920
+ test_data[np.newaxis, ...],
921
+ dims=("time", "z", "y", "x"),
922
+ coords={
923
+ "time": [np.datetime64("2000-01-01T00:00:00")],
924
+ "z": np.arange(test_data.shape[0]),
925
+ },
926
+ )
927
+
928
+ fd_output = tobac.feature_detection.feature_detection_multithreshold(
929
+ test_data_xr,
930
+ threshold=[1, 2, 3],
931
+ n_min_threshold=test_min_num,
932
+ dxy=1,
933
+ target="maximum",
934
+ )
935
+
936
+ assert isinstance(fd_output.time.to_numpy()[0], np.datetime64)
937
+
938
+ test_data_iris = test_data_xr.to_iris()
939
+
940
+ fd_output_iris_dt64 = tobac.feature_detection.feature_detection_multithreshold(
941
+ test_data_iris,
942
+ threshold=[1, 2, 3],
943
+ n_min_threshold=test_min_num,
944
+ dxy=1,
945
+ target="maximum",
946
+ preserve_iris_datetime_types=False,
947
+ )
948
+
949
+ assert isinstance(fd_output_iris_dt64.time.to_numpy()[0], np.datetime64)
950
+
951
+ fd_output_iris_cft = tobac.feature_detection.feature_detection_multithreshold(
952
+ test_data_iris,
953
+ threshold=[1, 2, 3],
954
+ n_min_threshold=test_min_num,
955
+ dxy=1,
956
+ target="maximum",
957
+ preserve_iris_datetime_types=True,
958
+ )
959
+
960
+ assert isinstance(
961
+ fd_output_iris_cft.time.to_numpy()[0], cftime.DatetimeProlepticGregorian
962
+ )
963
+
964
+
965
+ def test_feature_detection_360_day_calendar():
966
+ """Tests that datetime format and feature detection work correctly with
967
+ cftime 360-day calendars
968
+ """
969
+ test_dset_size = (50, 50)
970
+ test_hdim_1_pt = 20.0
971
+ test_hdim_2_pt = 20.0
972
+ test_hdim_1_sz = 5
973
+ test_hdim_2_sz = 5
974
+ test_amp = 2
975
+ test_min_num = 2
976
+
977
+ test_data = np.zeros(test_dset_size)
978
+ test_data = tbtest.make_feature_blob(
979
+ test_data,
980
+ test_hdim_1_pt,
981
+ test_hdim_2_pt,
982
+ h1_size=test_hdim_1_sz,
983
+ h2_size=test_hdim_2_sz,
984
+ amplitude=test_amp,
985
+ )
986
+ test_data_xr = xr.DataArray(
987
+ test_data[np.newaxis, ...],
988
+ dims=("time", "y", "x"),
989
+ coords={
990
+ "time": [cftime.Datetime360Day(2000, 1, 1)],
991
+ },
992
+ )
993
+
994
+ fd_output = tobac.feature_detection.feature_detection_multithreshold(
995
+ test_data_xr,
996
+ threshold=[1, 2, 3],
997
+ n_min_threshold=test_min_num,
998
+ dxy=1,
999
+ target="maximum",
1000
+ )
1001
+
1002
+ assert isinstance(fd_output.time.to_numpy()[0], cftime.Datetime360Day)
1003
+
1004
+ test_data_iris = test_data_xr.to_iris()
1005
+
1006
+ fd_output_iris = tobac.feature_detection.feature_detection_multithreshold(
1007
+ test_data_iris,
1008
+ threshold=[1, 2, 3],
1009
+ n_min_threshold=test_min_num,
1010
+ dxy=1,
1011
+ target="maximum",
1012
+ )
1013
+
1014
+ assert isinstance(fd_output_iris.time.to_numpy()[0], cftime.Datetime360Day)
1015
+
1016
+
1017
+ def test_strict_thresholding():
1018
+ """Tests that strict_thresholding prevents detection of features that have not met all
1019
+ previous n_min_threshold values"""
1020
+
1021
+ # Generate test dataset
1022
+ test_dset_size = (100, 100)
1023
+ test_hdim_1_pt = 50.0
1024
+ test_hdim_2_pt = 50.0
1025
+ test_hdim_1_sz = 10
1026
+ test_hdim_2_sz = 10
1027
+ test_amp = 10
1028
+ test_data = np.zeros(test_dset_size)
1029
+ test_data = tbtest.make_feature_blob(
1030
+ test_data,
1031
+ test_hdim_1_pt,
1032
+ test_hdim_2_pt,
1033
+ h1_size=test_hdim_1_sz,
1034
+ h2_size=test_hdim_2_sz,
1035
+ amplitude=test_amp,
1036
+ )
1037
+ test_data_iris = tbtest.make_dataset_from_arr(test_data, data_type="iris")
1038
+
1039
+ # All of these thresholds will be met
1040
+ thresholds = [1, 5, 7.5]
1041
+
1042
+ # The second n_min threshold can never be met
1043
+ n_min_thresholds = [0, test_data.size + 1, 0]
1044
+
1045
+ # Without strict thresholding the feature is detected at the first and last thresholds and reported at the last threshold value
1046
+ features = feat_detect.feature_detection_multithreshold_timestep(
1047
+ test_data_iris,
1048
+ 0,
1049
+ dxy=1,
1050
+ threshold=thresholds,
1051
+ n_min_threshold=n_min_thresholds,
1052
+ strict_thresholding=False,
1053
+ )
1054
+ assert len(features) == 1
1055
+ assert features["threshold_value"].item() == thresholds[-1]
1056
+
1057
+ # With strict thresholding the unmet second n_min_threshold stops detection, so the feature is reported at the first threshold value
1058
+ features = feat_detect.feature_detection_multithreshold_timestep(
1059
+ test_data_iris,
1060
+ 0,
1061
+ dxy=1,
1062
+ threshold=thresholds,
1063
+ n_min_threshold=n_min_thresholds,
1064
+ strict_thresholding=True,
1065
+ )
1066
+ assert len(features) == 1
1067
+ assert features["threshold_value"].item() == thresholds[0]
1068
+
1069
+ # Repeat for minima
1070
+ test_data_iris = tbtest.make_dataset_from_arr(10 - test_data, data_type="iris")
1071
+ # All of these thresholds will be met
1072
+ thresholds = [9, 5, 2.5]
1073
+
1074
+ # Without strict thresholding the feature is detected at the first and last thresholds and reported at the last threshold value
1075
+ features = feat_detect.feature_detection_multithreshold_timestep(
1076
+ test_data_iris,
1077
+ 0,
1078
+ dxy=1,
1079
+ threshold=thresholds,
1080
+ n_min_threshold=n_min_thresholds,
1081
+ strict_thresholding=False,
1082
+ target="minimum",
1083
+ )
1084
+ assert len(features) == 1
1085
+ assert features["threshold_value"].item() == thresholds[-1]
1086
+
1087
+ # With strict thresholding the unmet second n_min_threshold stops detection, so the feature is reported at the first threshold value
1088
+ features = feat_detect.feature_detection_multithreshold_timestep(
1089
+ test_data_iris,
1090
+ 0,
1091
+ dxy=1,
1092
+ threshold=thresholds,
1093
+ n_min_threshold=n_min_thresholds,
1094
+ strict_thresholding=True,
1095
+ target="minimum",
1096
+ )
1097
+ assert len(features) == 1
1098
+ assert features["threshold_value"].item() == thresholds[0]
1099
+
1100
+ # Test example from documentation
1101
+ input_field_arr = np.zeros((1, 101, 101))
1102
+
1103
+ for idx, side in enumerate([40, 20, 10, 5]):
1104
+ input_field_arr[
1105
+ :,
1106
+ (50 - side - 4 * idx) : (50 + side - 4 * idx),
1107
+ (50 - side - 4 * idx) : (50 + side - 4 * idx),
1108
+ ] = (
1109
+ 50 - side
1110
+ )
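+ # Builds four nested, offset squares with values 10, 30, 40 and 45; each successive threshold selects the next smaller square.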
1111
+
1112
+ input_field_iris = xr.DataArray(
1113
+ input_field_arr,
1114
+ dims=["time", "Y", "X"],
1115
+ coords={"time": [np.datetime64("2019-01-01T00:00:00")]},
1116
+ ).to_iris()
1117
+
1118
+ thresholds = [8, 29, 39, 44]
1119
+
1120
+ n_min_thresholds = [79**2, input_field_arr.size, 8**2, 3**2]
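+ # The second n_min_threshold (the full array size) can never be met, so strict thresholding stops at the first threshold value.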
1121
+
1122
+ features_demo = tobac.feature_detection_multithreshold(
1123
+ input_field_iris,
1124
+ dxy=1000,
1125
+ threshold=thresholds,
1126
+ n_min_threshold=n_min_thresholds,
1127
+ strict_thresholding=False,
1128
+ )
1129
+
1130
+ assert features_demo.iloc[0]["hdim_1"] == pytest.approx(37.5)
1131
+ assert features_demo.iloc[0]["hdim_2"] == pytest.approx(37.5)
1132
+
1133
+ # Now repeat with strict thresholding
1134
+ features_demo = tobac.feature_detection_multithreshold(
1135
+ input_field_iris,
1136
+ dxy=1000,
1137
+ threshold=thresholds,
1138
+ n_min_threshold=n_min_thresholds,
1139
+ strict_thresholding=True,
1140
+ )
1141
+
1142
+ assert features_demo.iloc[0]["hdim_1"] == pytest.approx(49.5)
1143
+ assert features_demo.iloc[0]["hdim_2"] == pytest.approx(49.5)
1144
+
1145
+
1146
+ @pytest.mark.parametrize(
1147
+ "h1_indices, h2_indices, max_h1, max_h2, PBC_flag, position_threshold, expected_output",
1148
+ (
1149
+ ([1], [1], 10, 10, "both", "center", (1, 1)),
1150
+ ([1, 2], [1, 2], 10, 10, "both", "center", (1.5, 1.5)),
1151
+ ([0, 1], [1, 2], 10, 10, "both", "center", (0.5, 1.5)),
1152
+ ([0, 10], [1, 1], 10, 10, "hdim_1", "center", (10.5, 1)),
1153
+ ([1, 1], [0, 10], 10, 10, "hdim_2", "center", (1, 10.5)),
1154
+ ([0, 10], [1, 1], 10, 10, "both", "center", (10.5, 1)),
1155
+ ([1, 1], [0, 10], 10, 10, "both", "center", (1, 10.5)),
1156
+ ([0, 10], [0, 10], 10, 10, "both", "center", (10.5, 10.5)),
1157
+ ([0, 1, 9, 10], [0, 0, 10, 10], 10, 10, "both", "center", (10.5, 10.5)),
1158
+ ),
1159
+ )
1160
+ def test_feature_position_pbc(
1161
+ h1_indices,
1162
+ h2_indices,
1163
+ max_h1,
1164
+ max_h2,
1165
+ PBC_flag,
1166
+ position_threshold,
1167
+ expected_output,
1168
+ ):
1169
+ """Tests to make sure that tobac.feature_detection.feature_position
1170
+ works properly with periodic boundaries.
1171
+ """
1172
+
1173
+ in_data = np.zeros((max_h1 + 1, max_h2 + 1))
1174
+ region = (0, 0, max_h1 + 1, max_h2 + 1)
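+ # position_threshold='center' uses only the pixel indices, so an all-zero track_data and a full-domain region_bbox are sufficient here.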
1175
+
1176
+ feat_pos_output = feat_detect.feature_position(
1177
+ h1_indices,
1178
+ h2_indices,
1179
+ hdim1_max=max_h1,
1180
+ hdim2_max=max_h2,
1181
+ PBC_flag=PBC_flag,
1182
+ position_threshold=position_threshold,
1183
+ track_data=in_data,
1184
+ region_bbox=region,
1185
+ )
1186
+ assert feat_pos_output == expected_output
1187
+
1188
+
1189
+ def test_pbc_snake_feature_detection():
1190
+ """
1191
+ Test that a "snake" feature that crosses PBCs multiple times is recognized as a single feature
1192
+ """
1193
+
1194
+ test_arr = np.zeros((50, 50))
1195
+ test_arr[::4, 0] = 2
1196
+ test_arr[1::4, 0] = 2
1197
+ test_arr[3::4, 0] = 2
1198
+
1199
+ test_arr[1::4, 49] = 2
1200
+ test_arr[2::4, 49] = 2
1201
+ test_arr[3::4, 49] = 2
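+ # Columns 0 and 49 are filled on interleaved rows, so the segments only join into one feature when hdim_2 is periodic.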
1202
+
1203
+ test_data_iris = tbtest.make_dataset_from_arr(test_arr, data_type="iris")
1204
+ fd_output = feat_detect.feature_detection_multithreshold_timestep(
1205
+ test_data_iris,
1206
+ 0,
1207
+ threshold=[1, 2, 3],
1208
+ n_min_threshold=2,
1209
+ dxy=1,
1210
+ target="maximum",
1211
+ PBC_flag="hdim_2",
1212
+ )
1213
+ assert len(fd_output) == 1
1214
+ # test hdim_1
1215
+ test_data_iris = tbtest.make_dataset_from_arr(test_arr.T, data_type="iris")
1216
+ fd_output = feat_detect.feature_detection_multithreshold_timestep(
1217
+ test_data_iris,
1218
+ 0,
1219
+ threshold=[1, 2, 3],
1220
+ n_min_threshold=2,
1221
+ dxy=1,
1222
+ target="maximum",
1223
+ PBC_flag="hdim_1",
1224
+ )
1225
+ assert len(fd_output) == 1
1226
+
1227
+
1228
+ def test_banded_feature():
1229
+ """
1230
+ Test that a feature that spans the length of the array is detected as one feature, and in the center.
1231
+ """
1232
+
1233
+ test_arr = np.zeros((50, 50))
1234
+ test_arr[20:22, :] = 2.5
1235
+ # Remove some values so that the distribution is not symmetric
1236
+ test_arr[20, 0] = 0
1237
+ test_arr[21, -1] = 0
1238
+ test_data_iris = tbtest.make_dataset_from_arr(test_arr, data_type="iris")
1239
+ fd_output = feat_detect.feature_detection_multithreshold_timestep(
1240
+ test_data_iris,
1241
+ 0,
1242
+ threshold=[1, 2, 3],
1243
+ n_min_threshold=2,
1244
+ dxy=1,
1245
+ target="maximum",
1246
+ PBC_flag="hdim_2",
1247
+ )
1248
+ assert len(fd_output) == 1
1249
+ assert fd_output.iloc[0]["hdim_1"] == 20.5
1250
+ assert fd_output.iloc[0]["hdim_2"] == 24.5
1251
+
1252
+ test_data_iris = tbtest.make_dataset_from_arr(test_arr.T, data_type="iris")
1253
+ fd_output = feat_detect.feature_detection_multithreshold_timestep(
1254
+ test_data_iris,
1255
+ 0,
1256
+ threshold=[1, 2, 3],
1257
+ n_min_threshold=2,
1258
+ dxy=1,
1259
+ target="maximum",
1260
+ PBC_flag="hdim_1",
1261
+ )
1262
+ assert len(fd_output) == 1
1263
+ assert fd_output.iloc[0]["hdim_2"] == 20.5
1264
+ assert fd_output.iloc[0]["hdim_1"] == 24.5
1265
+
1266
+ # Test different options for position_threshold
1267
+ test_data_iris = tbtest.make_dataset_from_arr(test_arr, data_type="iris")
1268
+ fd_output = feat_detect.feature_detection_multithreshold_timestep(
1269
+ test_data_iris,
1270
+ 0,
1271
+ threshold=[1, 2, 3],
1272
+ n_min_threshold=2,
1273
+ dxy=1,
1274
+ target="maximum",
1275
+ position_threshold="weighted_abs",
1276
+ PBC_flag="hdim_2",
1277
+ )
1278
+ assert len(fd_output) == 1
1279
+ assert fd_output.iloc[0]["hdim_1"] == pytest.approx(20.5)
1280
+ assert fd_output.iloc[0]["hdim_2"] == pytest.approx(24.5)
1281
+
1282
+ # Test different options for position_threshold
1283
+ test_data_iris = tbtest.make_dataset_from_arr(test_arr, data_type="iris")
1284
+ fd_output = feat_detect.feature_detection_multithreshold_timestep(
1285
+ test_data_iris,
1286
+ 0,
1287
+ threshold=[1, 2, 3],
1288
+ n_min_threshold=2,
1289
+ dxy=1,
1290
+ target="maximum",
1291
+ position_threshold="weighted_diff",
1292
+ PBC_flag="hdim_2",
1293
+ )
1294
+ assert len(fd_output) == 1
1295
+ assert fd_output.iloc[0]["hdim_1"] == pytest.approx(20.5)
1296
+ assert fd_output.iloc[0]["hdim_2"] == pytest.approx(24.5)
1297
+
1298
+ # Make a test case with a diagonal object to test corners
1299
+ test_arr = (
1300
+ np.zeros((50, 50))
1301
+ + np.diag(np.ones([50]))
1302
+ + np.diag(np.ones([49]), -1)
1303
+ + np.diag(np.ones([49]), 1)
1304
+ ) * 2.5
1305
+ # Remove some values so that the distribution is not symmetric
1306
+ test_arr[1, 0] = 0
1307
+ test_arr[-2, -1] = 0
1308
+ test_data_iris = tbtest.make_dataset_from_arr(test_arr, data_type="iris")
1309
+ fd_output = feat_detect.feature_detection_multithreshold_timestep(
1310
+ test_data_iris,
1311
+ 0,
1312
+ threshold=[1, 2, 3],
1313
+ n_min_threshold=2,
1314
+ dxy=1,
1315
+ target="maximum",
1316
+ position_threshold="weighted_diff",
1317
+ PBC_flag="both",
1318
+ )
1319
+ assert len(fd_output) == 1
1320
+ assert fd_output.iloc[0]["hdim_1"] == pytest.approx(24.5)
1321
+ assert fd_output.iloc[0]["hdim_2"] == pytest.approx(24.5)