tobac-1.6.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. tobac/__init__.py +112 -0
  2. tobac/analysis/__init__.py +31 -0
  3. tobac/analysis/cell_analysis.py +628 -0
  4. tobac/analysis/feature_analysis.py +212 -0
  5. tobac/analysis/spatial.py +619 -0
  6. tobac/centerofgravity.py +226 -0
  7. tobac/feature_detection.py +1758 -0
  8. tobac/merge_split.py +324 -0
  9. tobac/plotting.py +2321 -0
  10. tobac/segmentation/__init__.py +10 -0
  11. tobac/segmentation/watershed_segmentation.py +1316 -0
  12. tobac/testing.py +1179 -0
  13. tobac/tests/segmentation_tests/test_iris_xarray_segmentation.py +0 -0
  14. tobac/tests/segmentation_tests/test_segmentation.py +1183 -0
  15. tobac/tests/segmentation_tests/test_segmentation_time_pad.py +104 -0
  16. tobac/tests/test_analysis_spatial.py +1109 -0
  17. tobac/tests/test_convert.py +265 -0
  18. tobac/tests/test_datetime.py +216 -0
  19. tobac/tests/test_decorators.py +148 -0
  20. tobac/tests/test_feature_detection.py +1321 -0
  21. tobac/tests/test_generators.py +273 -0
  22. tobac/tests/test_import.py +24 -0
  23. tobac/tests/test_iris_xarray_match_utils.py +244 -0
  24. tobac/tests/test_merge_split.py +351 -0
  25. tobac/tests/test_pbc_utils.py +497 -0
  26. tobac/tests/test_sample_data.py +197 -0
  27. tobac/tests/test_testing.py +747 -0
  28. tobac/tests/test_tracking.py +714 -0
  29. tobac/tests/test_utils.py +650 -0
  30. tobac/tests/test_utils_bulk_statistics.py +789 -0
  31. tobac/tests/test_utils_coordinates.py +328 -0
  32. tobac/tests/test_utils_internal.py +97 -0
  33. tobac/tests/test_xarray_utils.py +232 -0
  34. tobac/tracking.py +613 -0
  35. tobac/utils/__init__.py +27 -0
  36. tobac/utils/bulk_statistics.py +360 -0
  37. tobac/utils/datetime.py +184 -0
  38. tobac/utils/decorators.py +540 -0
  39. tobac/utils/general.py +753 -0
  40. tobac/utils/generators.py +87 -0
  41. tobac/utils/internal/__init__.py +2 -0
  42. tobac/utils/internal/coordinates.py +430 -0
  43. tobac/utils/internal/iris_utils.py +462 -0
  44. tobac/utils/internal/label_props.py +82 -0
  45. tobac/utils/internal/xarray_utils.py +439 -0
  46. tobac/utils/mask.py +364 -0
  47. tobac/utils/periodic_boundaries.py +419 -0
  48. tobac/wrapper.py +244 -0
  49. tobac-1.6.2.dist-info/METADATA +154 -0
  50. tobac-1.6.2.dist-info/RECORD +53 -0
  51. tobac-1.6.2.dist-info/WHEEL +5 -0
  52. tobac-1.6.2.dist-info/licenses/LICENSE +29 -0
  53. tobac-1.6.2.dist-info/top_level.txt +1 -0
tobac/tests/test_utils_bulk_statistics.py
@@ -0,0 +1,789 @@
+ from datetime import datetime
+
+ import numpy as np
+ import dask.array as da
+ import pandas as pd
+ import xarray as xr
+ import pytest
+
+ import tobac
+ import tobac.utils as tb_utils
+ import tobac.testing as tb_test
+
+
+ @pytest.mark.parametrize("statistics_unsmoothed", [False, True])
+ def test_bulk_statistics_fd(statistics_unsmoothed):
+     """
+     Ensure that bulk statistics in feature detection work on both smoothed and raw data.
+     """
+     ### Test 2D data with time dimension
+     test_data = tb_test.make_simple_sample_data_2D().core_data()
+     common_dset_opts = {
+         "in_arr": test_data,
+         "data_type": "iris",
+     }
+     test_data_iris = tb_test.make_dataset_from_arr(
+         time_dim_num=0, y_dim_num=1, x_dim_num=2, **common_dset_opts
+     )
+     stats = {"feature_max": np.max}
+
+     # detect features
+     threshold = 7
+     fd_output = tobac.feature_detection.feature_detection_multithreshold(
+         test_data_iris,
+         dxy=1000,
+         threshold=[threshold],
+         n_min_threshold=100,
+         target="maximum",
+         statistic=stats,
+         statistics_unsmoothed=statistics_unsmoothed,
+     )
+
+     assert "feature_max" in fd_output.columns
+
+
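+ # The `statistic` mapping accepts plain callables as well as (callable, kwargs)
+ # tuples, as exercised throughout this file. A minimal sketch of a richer
+ # mapping (illustrative only; `_example_statistic_mapping` is an editor's
+ # sketch, not part of the released test suite):
+ def _example_statistic_mapping():
+     # each key becomes a column in the output feature dataframe
+     return {
+         "feature_mean": np.mean,  # plain callable applied to each feature's values
+         "feature_p95": (np.percentile, {"q": 95}),  # callable plus keyword arguments
+     }
+
+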
+ @pytest.mark.parametrize(
+     "id_column, index",
+     [
+         ("feature", [1]),
+         ("feature_id", [1]),
+         ("cell", [1]),
+     ],
+ )
+ def test_bulk_statistics(id_column, index):
+     """
+     Ensure that bulk statistics for identified features are computed as expected.
+     """
+
+     ### Test 2D data with time dimension
+     test_data = tb_test.make_simple_sample_data_2D().core_data()
+     common_dset_opts = {
+         "in_arr": test_data,
+         "data_type": "iris",
+     }
+     test_data_iris = tb_test.make_dataset_from_arr(
+         time_dim_num=0, y_dim_num=1, x_dim_num=2, **common_dset_opts
+     )
+
+     # detect features
+     threshold = 7
+     fd_output = tobac.feature_detection.feature_detection_multithreshold(
+         test_data_iris,
+         dxy=1000,
+         threshold=[threshold],
+         n_min_threshold=100,
+         target="maximum",
+     )
+
+     # perform segmentation with bulk statistics
+     stats = {
+         "segment_max": np.max,
+         "segment_min": min,
+         "percentiles": (np.percentile, {"q": 95}),
+     }
+     out_seg_mask, out_df = tobac.segmentation.segmentation_2D(
+         fd_output, test_data_iris, dxy=1000, threshold=threshold, statistic=stats
+     )
+
+     #### checks
+     out_df = out_df.rename(columns={"feature": id_column})
+
+     # ensure that bulk statistics in postprocessing give the same result
+     out_segmentation = tb_utils.get_statistics_from_mask(
+         out_df, out_seg_mask, test_data_iris, statistic=stats, id_column=id_column
+     )
+     assert out_segmentation.equals(out_df)
+
+     # ensure that column names in the new dataframe correspond to keys in the statistics dictionary
+     for key in stats.keys():
+         assert key in out_df.columns
+
+     # ensure that the statistics give the expected result
+     for frame in out_df.frame.values:
+         assert out_df[out_df.frame == frame].segment_max.values[0] == np.max(
+             test_data[frame]
+         )
+
+     ### Test the same with 3D data
+     test_data_iris = tb_test.make_sample_data_3D_3blobs()
+
+     # detect features in test dataset
+     fd_output = tobac.feature_detection.feature_detection_multithreshold(
+         test_data_iris,
+         dxy=1000,
+         threshold=[threshold],
+         n_min_threshold=100,
+         target="maximum",
+     )
+
+     # perform segmentation with bulk statistics
+     stats = {
+         "segment_max": np.max,
+         "segment_min": min,
+         "percentiles": (np.percentile, {"q": 95}),
+     }
+     out_seg_mask, out_df = tobac.segmentation.segmentation_3D(
+         fd_output, test_data_iris, dxy=1000, threshold=threshold, statistic=stats
+     )
+
+     ##### checks #####
+     out_df = out_df.rename(columns={"feature": id_column})
+     # ensure that bulk statistics in postprocessing give the same result
+     out_segmentation = tb_utils.get_statistics_from_mask(
+         out_df, out_seg_mask, test_data_iris, statistic=stats, id_column=id_column
+     )
+
+     assert out_segmentation.equals(out_df)
+
+     # ensure that column names in the new dataframe correspond to keys in the statistics dictionary
+     for key in stats.keys():
+         assert key in out_df.columns
+
+     # ensure that the statistics give the expected result
+     for frame in out_df.frame.values:
+         assert out_df[out_df.frame == frame].segment_max.values[0] == np.max(
+             test_data_iris.data[frame]
+         )
+
+
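+ # A minimal sketch of the postprocessing path tested above: recomputing stored
+ # statistics from a segmentation mask when the feature id column has been
+ # renamed (illustrative only; `_example_recompute_stats` and its arguments are
+ # an editor's sketch, not part of the released test suite):
+ def _example_recompute_stats(features_df, seg_mask, field):
+     stats = {"segment_max": np.max, "percentiles": (np.percentile, {"q": 95})}
+     # id_column tells get_statistics_from_mask which dataframe column holds the
+     # ids matching the labels in the segmentation mask (default: "feature")
+     features_df = features_df.rename(columns={"feature": "cell"})
+     return tb_utils.get_statistics_from_mask(
+         features_df, seg_mask, field, statistic=stats, id_column="cell"
+     )
+
+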
+ def test_bulk_statistics_missing_segments():
+     """
+     Test that the output feature dataframe contains all time steps, even when
+     statistics could not be calculated for some of them (in the case of
+     unmatched labels or no segment labels for a given feature).
+     """
+
+     ### Test 2D data with time dimension
+     test_data = tb_test.make_simple_sample_data_2D().core_data()
+     common_dset_opts = {
+         "in_arr": test_data,
+         "data_type": "iris",
+     }
+
+     test_data_iris = tb_test.make_dataset_from_arr(
+         time_dim_num=0, y_dim_num=1, x_dim_num=2, **common_dset_opts
+     )
+
+     # detect features
+     threshold = 7
+     fd_output = tobac.feature_detection.feature_detection_multithreshold(
+         test_data_iris,
+         dxy=1000,
+         threshold=[threshold],
+         n_min_threshold=100,
+         target="maximum",
+     )
+
+     # perform segmentation; the statistics are computed in postprocessing below
+     stats = {
+         "segment_max": np.max,
+         "segment_min": min,
+         "percentiles": (np.percentile, {"q": 95}),
+     }
+
+     out_seg_mask, out_df = tobac.segmentation.segmentation_2D(
+         fd_output, test_data_iris, dxy=1000, threshold=threshold
+     )
+
+     # specify some timesteps we set to zero
+     timesteps_to_zero = [1, 3, 10]  # 0-based indexing
+     modified_data = out_seg_mask.data.copy()
+     # Set values to zero for the specified timesteps
+     for timestep in timesteps_to_zero:
+         modified_data[timestep, :, :] = 0  # Set all values for this timestep to zero
+     # write the zeroed data back into the mask so that those timesteps really
+     # contain no segment labels
+     out_seg_mask.data = modified_data
+
+     # ensure that bulk statistics in postprocessing retain all timesteps
+     out_segmentation = tb_utils.get_statistics_from_mask(
+         out_df, out_seg_mask, test_data_iris, statistic=stats
+     )
+
+     assert out_df.time.unique().size == out_segmentation.time.unique().size
+
+
+ def test_bulk_statistics_multiple_fields():
+     """
+     Test that multiple field input to bulk_statistics works as intended
+     """
+
+     test_labels = np.array(
+         [
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 1, 0, 2, 0],
+                 [0, 1, 0, 2, 0],
+                 [0, 1, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 3, 0, 0, 0],
+                 [0, 3, 0, 4, 0],
+                 [0, 3, 0, 4, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+         ],
+         dtype=int,
+     )
+
+     test_labels = xr.DataArray(
+         test_labels,
+         dims=("time", "y", "x"),
+         coords={
+             "time": [datetime(2000, 1, 1), datetime(2000, 1, 1, 0, 5)],
+             "y": np.arange(5),
+             "x": np.arange(5),
+         },
+     )
+
+     test_values = np.array(
+         [
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 1, 0, 2, 0],
+                 [0, 2, 0, 2, 0],
+                 [0, 3, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 2, 0, 0, 0],
+                 [0, 3, 0, 3, 0],
+                 [0, 4, 0, 2, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+         ]
+     )
+
+     test_values = xr.DataArray(
+         test_values, dims=test_labels.dims, coords=test_labels.coords
+     )
+
+     test_weights = np.array(
+         [
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 0, 0, 1, 0],
+                 [0, 0, 0, 1, 0],
+                 [0, 1, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 1, 0, 0, 0],
+                 [0, 0, 0, 1, 0],
+                 [0, 0, 0, 1, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+         ]
+     )
+
+     test_weights = xr.DataArray(
+         test_weights, dims=test_labels.dims, coords=test_labels.coords
+     )
+
+     test_features = pd.DataFrame(
+         {
+             "feature": [1, 2, 3, 4],
+             "frame": [0, 0, 1, 1],
+             "time": [
+                 datetime(2000, 1, 1),
+                 datetime(2000, 1, 1),
+                 datetime(2000, 1, 1, 0, 5),
+                 datetime(2000, 1, 1, 0, 5),
+             ],
+         }
+     )
+
+     statistics_mean = {"mean": np.mean}
+
+     expected_mean_result = np.array([2, 2, 3, 2.5])
+
+     bulk_statistics_output = tb_utils.get_statistics_from_mask(
+         test_features, test_labels, test_values, statistic=statistics_mean
+     )
+
+     statistics_weighted_mean = {
+         "weighted_mean": (lambda x, y: np.average(x, weights=y))
+     }
+
+     expected_weighted_mean_result = np.array([3, 2, 2, 2.5])
+
+     bulk_statistics_output = tb_utils.get_statistics_from_mask(
+         bulk_statistics_output,
+         test_labels,
+         test_values,
+         test_weights,
+         statistic=statistics_weighted_mean,
+     )
+
+     assert np.all(bulk_statistics_output["mean"] == expected_mean_result)
+     assert np.all(
+         bulk_statistics_output["weighted_mean"] == expected_weighted_mean_result
+     )
+
+
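+ # A minimal sketch of the multiple-field pattern exercised above: every field
+ # passed positionally after the mask is sliced to each feature's pixels and
+ # handed to the statistic callable in order, so a two-field statistic takes
+ # two arguments (illustrative only; `_example_weighted_mean` is an editor's
+ # sketch, not part of the released test suite):
+ def _example_weighted_mean(features_df, labels, values, weights):
+     return tb_utils.get_statistics_from_mask(
+         features_df,
+         labels,  # segmentation mask identifying each feature's pixels
+         values,  # first field -> first argument of the callable
+         weights,  # second field -> second argument of the callable
+         statistic={"weighted_mean": lambda v, w: np.average(v, weights=w)},
+     )
+
+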
+ def test_bulk_statistics_time_invariant_field():
+     """
+     Some fields, such as area, are time invariant, so passing an array with a
+     time dimension is memory inefficient. Here we test whether
+     `get_statistics_from_mask` works when an input field has no time dimension,
+     by passing the whole field to `get_statistics` rather than a time slice.
+     """
+
+     test_labels = np.array(
+         [
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 1, 0, 2, 0],
+                 [0, 1, 0, 2, 0],
+                 [0, 1, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 3, 0, 0, 0],
+                 [0, 3, 0, 4, 0],
+                 [0, 3, 0, 4, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+         ],
+         dtype=int,
+     )
+
+     test_labels = xr.DataArray(
+         test_labels,
+         dims=("time", "y", "x"),
+         coords={
+             "time": [datetime(2000, 1, 1), datetime(2000, 1, 1, 0, 5)],
+             "y": np.arange(5),
+             "x": np.arange(5),
+         },
+     )
+
+     test_areas = np.array(
+         [
+             [0.25, 0.5, 0.75, 1, 1],
+             [0.25, 0.5, 0.75, 1, 1],
+             [0.25, 0.5, 0.75, 1, 1],
+             [0.25, 0.5, 0.75, 1, 1],
+             [0.25, 0.5, 0.75, 1, 1],
+         ]
+     )
+
+     test_areas = xr.DataArray(
+         test_areas,
+         dims=("y", "x"),
+         coords={
+             "y": np.arange(5),
+             "x": np.arange(5),
+         },
+     )
+
+     test_features = pd.DataFrame(
+         {
+             "feature": [1, 2, 3, 4],
+             "frame": [0, 0, 1, 1],
+             "time": [
+                 datetime(2000, 1, 1),
+                 datetime(2000, 1, 1),
+                 datetime(2000, 1, 1, 0, 5),
+                 datetime(2000, 1, 1, 0, 5),
+             ],
+         }
+     )
+
+     statistics_sum = {"sum": np.sum}
+
+     expected_sum_result = np.array([1.5, 2, 1.5, 2])
+
+     bulk_statistics_output = tb_utils.get_statistics_from_mask(
+         test_features, test_labels, test_areas, statistic=statistics_sum
+     )
+
+     assert np.all(bulk_statistics_output["sum"] == expected_sum_result)
+
+
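+ # A minimal sketch of the time-invariant pattern above: a static 2D grid-cell
+ # area field can be combined with a time-varying 3D mask without tiling it
+ # along the time dimension (illustrative only; `_example_feature_area` is an
+ # editor's sketch, not part of the released test suite):
+ def _example_feature_area(features_df, labels_3d, area_2d):
+     # area_2d has dims ("y", "x") only; the whole field is reused for every
+     # time step instead of being sliced along time
+     return tb_utils.get_statistics_from_mask(
+         features_df, labels_3d, area_2d, statistic={"area": np.sum}
+     )
+
+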
+ def test_bulk_statistics_broadcasting():
+     """
+     Test that field broadcasting works for bulk statistics, covering both
+     leading and trailing dimensions.
+     """
+     test_labels = np.array(
+         [
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 1, 0, 2, 0],
+                 [0, 1, 0, 2, 0],
+                 [0, 1, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 3, 0, 0, 0],
+                 [0, 3, 0, 4, 0],
+                 [0, 3, 0, 4, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+         ],
+         dtype=int,
+     )
+
+     test_labels = xr.DataArray(
+         test_labels,
+         dims=("time", "y", "x"),
+         coords={
+             "time": [datetime(2000, 1, 1), datetime(2000, 1, 1, 0, 5)],
+             "y": np.arange(5),
+             "x": np.arange(5),
+         },
+     )
+
+     test_values = np.array(
+         [
+             [0.25, 0.5, 0.75, 1, 1],
+             [1.25, 1.5, 1.75, 2, 2],
+         ]
+     )
+
+     test_values = xr.DataArray(
+         test_values,
+         dims=("time", "x"),
+         coords={"time": test_labels.time, "x": test_labels.x},
+     )
+
+     test_weights = np.array([0, 0, 1, 0, 0]).reshape([5, 1])
+
+     test_weights = xr.DataArray(
+         test_weights, dims=("y", "z"), coords={"y": test_labels.y}
+     )
+
+     test_features = pd.DataFrame(
+         {
+             "feature": [1, 2, 3, 4],
+             "frame": [0, 0, 1, 1],
+             "time": [
+                 datetime(2000, 1, 1),
+                 datetime(2000, 1, 1),
+                 datetime(2000, 1, 1, 0, 5),
+                 datetime(2000, 1, 1, 0, 5),
+             ],
+         }
+     )
+
+     statistics_sum = {"sum": np.sum}
+
+     # e.g. feature 1: value 0.5 at x=1 broadcast over its three y cells -> 1.5
+     expected_sum_result = np.array([1.5, 2, 4.5, 4])
+
+     bulk_statistics_output = tb_utils.get_statistics_from_mask(
+         test_features, test_labels, test_values, statistic=statistics_sum
+     )
+
+     statistics_weighted_sum = {"weighted_sum": (lambda x, y: np.sum(x * y))}
+
+     # the weights are non-zero only at y=2, so each feature keeps one cell
+     expected_weighted_sum_result = np.array([0.5, 1, 1.5, 2])
+
+     bulk_statistics_output = tb_utils.get_statistics_from_mask(
+         bulk_statistics_output,
+         test_labels,
+         test_values,
+         test_weights,
+         statistic=statistics_weighted_sum,
+     )
+
+     assert np.all(bulk_statistics_output["sum"] == expected_sum_result)
+     assert np.all(
+         bulk_statistics_output["weighted_sum"] == expected_weighted_sum_result
+     )
+
+
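+ # A minimal sketch of the broadcasting tested above (illustrative only;
+ # `_example_broadcast_stats` and its arguments are an editor's sketch, not
+ # part of the released test suite): fields may omit mask dimensions (values
+ # lack "y") or carry extra ones (weights add a trailing "z"); both are
+ # broadcast against the ("time", "y", "x") mask before the statistic runs.
+ def _example_broadcast_stats(features_df, labels_tyx, values_tx, weights_yz):
+     return tb_utils.get_statistics_from_mask(
+         features_df,
+         labels_tyx,
+         values_tx,
+         weights_yz,
+         statistic={"weighted_sum": lambda v, w: np.sum(v * w)},
+     )
+
+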
+ def test_get_statistics_collapse_axis():
+     """
+     Test the collapse_axis keyword of get_statistics
+     """
+     test_labels = np.array(
+         [
+             [0, 0, 0, 0, 0],
+             [0, 1, 0, 0, 0],
+             [0, 1, 0, 2, 0],
+             [0, 1, 0, 2, 0],
+             [0, 0, 0, 0, 0],
+         ],
+         dtype=int,
+     )
+
+     test_values = np.array([0.25, 0.5, 0.75, 1, 1])
+
+     test_features = pd.DataFrame(
+         {
+             "feature": [1, 2],
+             "frame": [0, 0],
+             "time": [
+                 datetime(2000, 1, 1),
+                 datetime(2000, 1, 1),
+             ],
+         }
+     )
+     statistics_sum = {"sum": np.sum}
+
+     expected_sum_result_axis0 = np.array([0.5, 1])
+     output_collapse_axis0 = tb_utils.get_statistics(
+         test_features,
+         test_labels,
+         test_values,
+         statistic=statistics_sum,
+         collapse_axis=0,
+     )
+     assert np.all(output_collapse_axis0["sum"] == expected_sum_result_axis0)
+
+     expected_sum_result_axis1 = np.array([2.25, 1.75])
+     output_collapse_axis1 = tb_utils.get_statistics(
+         test_features,
+         test_labels,
+         test_values,
+         statistic=statistics_sum,
+         collapse_axis=1,
+     )
+     assert np.all(output_collapse_axis1["sum"] == expected_sum_result_axis1)
+
+     # Check that attempting to broadcast raises a ValueError
+     with pytest.raises(ValueError):
+         _ = tb_utils.get_statistics(
+             test_features,
+             test_labels,
+             test_values.reshape([5, 1]),
+             statistic=statistics_sum,
+             collapse_axis=0,
+         )
+
+     # Check that attempting to collapse all axes raises a ValueError:
+     with pytest.raises(ValueError):
+         _ = tb_utils.get_statistics(
+             test_features,
+             test_labels,
+             test_values,
+             statistic=statistics_sum,
+             collapse_axis=[0, 1],
+         )
+
+     # Test with collapsing multiple axes
+     test_labels = np.array(
+         [
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 1, 0, 0, 0],
+                 [0, 1, 0, 2, 0],
+                 [0, 1, 0, 2, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 1, 0, 0, 0],
+                 [0, 1, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+         ],
+         dtype=int,
+     )
+     test_values = np.array([0.5, 1])
+     expected_sum_result_axis12 = np.array([1.5, 0.5])
+     output_collapse_axis12 = tb_utils.get_statistics(
+         test_features,
+         test_labels,
+         test_values,
+         statistic=statistics_sum,
+         collapse_axis=[1, 2],
+     )
+     assert np.all(output_collapse_axis12["sum"] == expected_sum_result_axis12)
+
+
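+ # A minimal sketch of the collapse_axis pattern tested above (illustrative
+ # only; `_example_collapse_vertical` is an editor's sketch, not part of the
+ # released test suite): collapsing the vertical axis of a 3D label array lets
+ # it be paired with a 2D field.
+ def _example_collapse_vertical(features_df, labels_zyx, values_yx):
+     # with collapse_axis=0, a grid point counts toward a feature if that
+     # feature's label appears anywhere along axis 0
+     return tb_utils.get_statistics(
+         features_df,
+         labels_zyx,
+         values_yx,
+         statistic={"max": np.max},
+         collapse_axis=0,
+     )
+
+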
+ def test_get_statistics_from_mask_collapse_dim():
+     """
+     Test the collapse_dim keyword of get_statistics_from_mask
+     """
+
+     test_labels = np.array(
+         [
+             [
+                 [
+                     [0, 0, 0, 0, 0],
+                     [0, 1, 0, 2, 0],
+                     [0, 1, 0, 2, 0],
+                     [0, 1, 0, 0, 0],
+                     [0, 0, 0, 0, 0],
+                 ],
+                 [
+                     [0, 0, 0, 0, 0],
+                     [0, 1, 0, 0, 0],
+                     [0, 1, 0, 3, 0],
+                     [0, 1, 0, 3, 0],
+                     [0, 0, 0, 0, 0],
+                 ],
+             ],
+         ],
+         dtype=int,
+     )
+
+     test_labels = xr.DataArray(
+         test_labels,
+         dims=("time", "z", "y", "x"),
+         coords={
+             "time": [datetime(2000, 1, 1)],
+             "z": np.arange(2),
+             "y": np.arange(5),
+             "x": np.arange(5),
+         },
+     )
+
+     test_values = np.ones([5, 5])
+
+     test_values = xr.DataArray(
+         test_values,
+         dims=("x", "y"),
+         coords={
+             "y": np.arange(5),
+             "x": np.arange(5),
+         },
+     )
+
+     test_features = pd.DataFrame(
+         {
+             "feature": [1, 2, 3],
+             "frame": [0, 0, 0],
+             "time": [
+                 datetime(2000, 1, 1),
+                 datetime(2000, 1, 1),
+                 datetime(2000, 1, 1),
+             ],
+         }
+     )
+
+     statistics_sum = {"sum": np.sum}
+
+     expected_sum_result = np.array([3, 2, 2])
+
+     # Test over a single dim
+     statistics_output = tb_utils.get_statistics_from_mask(
+         test_features,
+         test_labels,
+         test_values,
+         statistic=statistics_sum,
+         collapse_dim="z",
+     )
+
+     assert np.all(statistics_output["sum"] == expected_sum_result)
+
+     test_values = np.ones([2])
+
+     test_values = xr.DataArray(
+         test_values,
+         dims=("z",),
+         coords={
+             "z": np.arange(2),
+         },
+     )
+
+     expected_sum_result = np.array([2, 1, 1])
+
+     # Test over multiple dims
+     statistics_output = tb_utils.get_statistics_from_mask(
+         test_features,
+         test_labels,
+         test_values,
+         statistic=statistics_sum,
+         collapse_dim=("x", "y"),
+     )
+
+     assert np.all(statistics_output["sum"] == expected_sum_result)
+
+     # Test that a collapse_dim not in the labels raises an error
+     with pytest.raises(ValueError):
+         _ = tb_utils.get_statistics_from_mask(
+             test_features,
+             test_labels,
+             test_values,
+             statistic=statistics_sum,
+             collapse_dim="not_a_dim",
+         )
+
+
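+ # A minimal sketch mirroring the test above (illustrative only;
+ # `_example_collapse_dim` is an editor's sketch, not part of the released test
+ # suite): with xarray-labelled masks, dimensions are collapsed by name rather
+ # than by positional axis.
+ def _example_collapse_dim(features_df, labels_tzyx, values_xy):
+     return tb_utils.get_statistics_from_mask(
+         features_df,
+         labels_tzyx,
+         values_xy,
+         statistic={"sum": np.sum},
+         collapse_dim="z",
+     )
+
+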
+ def test_bulk_statistics_dask():
+     """
+     Test that dask input for labels and fields is handled correctly
+     """
+
+     test_labels = da.array(
+         [
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 1, 0, 2, 0],
+                 [0, 1, 0, 2, 0],
+                 [0, 1, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 3, 0, 0, 0],
+                 [0, 3, 0, 4, 0],
+                 [0, 3, 0, 4, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+         ],
+         dtype=int,
+     )
+
+     test_labels = xr.DataArray(
+         test_labels,
+         dims=("time", "y", "x"),
+         coords={
+             "time": [datetime(2000, 1, 1), datetime(2000, 1, 1, 0, 5)],
+             "y": np.arange(5),
+             "x": np.arange(5),
+         },
+     )
+
+     test_values = da.array(
+         [
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 1, 0, 2, 0],
+                 [0, 2, 0, 2, 0],
+                 [0, 3, 0, 0, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+             [
+                 [0, 0, 0, 0, 0],
+                 [0, 2, 0, 0, 0],
+                 [0, 3, 0, 3, 0],
+                 [0, 4, 0, 2, 0],
+                 [0, 0, 0, 0, 0],
+             ],
+         ]
+     )
+
+     test_values = xr.DataArray(
+         test_values, dims=test_labels.dims, coords=test_labels.coords
+     )
+
+     test_features = pd.DataFrame(
+         {
+             "feature": [1, 2, 3, 4],
+             "frame": [0, 0, 1, 1],
+             "time": [
+                 datetime(2000, 1, 1),
+                 datetime(2000, 1, 1),
+                 datetime(2000, 1, 1, 0, 5),
+                 datetime(2000, 1, 1, 0, 5),
+             ],
+         }
+     )
+
+     statistics_size = {"size": np.size}
+
+     expected_size_result = np.array([3, 2, 3, 2])
+
+     bulk_statistics_output = tb_utils.get_statistics_from_mask(
+         test_features, test_labels, test_values, statistic=statistics_size
+     )
+
+     assert np.all(bulk_statistics_output["size"] == expected_size_result)
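+
+
+ # A minimal sketch of building dask-backed inputs like those above
+ # (illustrative only; `_example_dask_inputs` is an editor's sketch, not part
+ # of the released test suite):
+ def _example_dask_inputs(labels_np, values_np, times):
+     # chunk along the time dimension so each time step is a separate task
+     labels = xr.DataArray(
+         da.from_array(labels_np, chunks=(1, -1, -1)),
+         dims=("time", "y", "x"),
+         coords={"time": times},
+     )
+     values = xr.DataArray(
+         da.from_array(values_np, chunks=(1, -1, -1)),
+         dims=labels.dims,
+         coords=labels.coords,
+     )
+     return labels, values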