tobac 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. tobac/__init__.py +112 -0
  2. tobac/analysis/__init__.py +31 -0
  3. tobac/analysis/cell_analysis.py +628 -0
  4. tobac/analysis/feature_analysis.py +212 -0
  5. tobac/analysis/spatial.py +619 -0
  6. tobac/centerofgravity.py +226 -0
  7. tobac/feature_detection.py +1758 -0
  8. tobac/merge_split.py +324 -0
  9. tobac/plotting.py +2321 -0
  10. tobac/segmentation/__init__.py +10 -0
  11. tobac/segmentation/watershed_segmentation.py +1316 -0
  12. tobac/testing.py +1179 -0
  13. tobac/tests/segmentation_tests/test_iris_xarray_segmentation.py +0 -0
  14. tobac/tests/segmentation_tests/test_segmentation.py +1183 -0
  15. tobac/tests/segmentation_tests/test_segmentation_time_pad.py +104 -0
  16. tobac/tests/test_analysis_spatial.py +1109 -0
  17. tobac/tests/test_convert.py +265 -0
  18. tobac/tests/test_datetime.py +216 -0
  19. tobac/tests/test_decorators.py +148 -0
  20. tobac/tests/test_feature_detection.py +1321 -0
  21. tobac/tests/test_generators.py +273 -0
  22. tobac/tests/test_import.py +24 -0
  23. tobac/tests/test_iris_xarray_match_utils.py +244 -0
  24. tobac/tests/test_merge_split.py +351 -0
  25. tobac/tests/test_pbc_utils.py +497 -0
  26. tobac/tests/test_sample_data.py +197 -0
  27. tobac/tests/test_testing.py +747 -0
  28. tobac/tests/test_tracking.py +714 -0
  29. tobac/tests/test_utils.py +650 -0
  30. tobac/tests/test_utils_bulk_statistics.py +789 -0
  31. tobac/tests/test_utils_coordinates.py +328 -0
  32. tobac/tests/test_utils_internal.py +97 -0
  33. tobac/tests/test_xarray_utils.py +232 -0
  34. tobac/tracking.py +613 -0
  35. tobac/utils/__init__.py +27 -0
  36. tobac/utils/bulk_statistics.py +360 -0
  37. tobac/utils/datetime.py +184 -0
  38. tobac/utils/decorators.py +540 -0
  39. tobac/utils/general.py +753 -0
  40. tobac/utils/generators.py +87 -0
  41. tobac/utils/internal/__init__.py +2 -0
  42. tobac/utils/internal/coordinates.py +430 -0
  43. tobac/utils/internal/iris_utils.py +462 -0
  44. tobac/utils/internal/label_props.py +82 -0
  45. tobac/utils/internal/xarray_utils.py +439 -0
  46. tobac/utils/mask.py +364 -0
  47. tobac/utils/periodic_boundaries.py +419 -0
  48. tobac/wrapper.py +244 -0
  49. tobac-1.6.2.dist-info/METADATA +154 -0
  50. tobac-1.6.2.dist-info/RECORD +53 -0
  51. tobac-1.6.2.dist-info/WHEEL +5 -0
  52. tobac-1.6.2.dist-info/licenses/LICENSE +29 -0
  53. tobac-1.6.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,328 @@
1
+ """Unit tests for tobac.utils.internal.coordinates"""
2
+
3
+ import pytest
4
+ import pandas as pd
5
+
6
+ import tobac.utils.internal.coordinates as coord_utils
7
+
8
+
9
def test_find_coord_in_dataframe_errors():
    """Check that find_coord_in_dataframe raises ValueError on invalid input.

    Each case pins the start of the expected error message. ``match`` is
    applied with ``re.search``, so a plain prefix is sufficient; the patterns
    deliberately carry no trailing ``*`` (a regex quantifier, not a glob
    wildcard, which would merely make the final character optional).
    """
    defaults = ["x", "projection_x_coordinate", "__other_name"]

    # Neither coord nor defaults supplied.
    with pytest.raises(ValueError, match="One of coord or defaults parameter"):
        coord_utils.find_coord_in_dataframe(pd.DataFrame(columns=["time", "x"]))

    # Explicitly requested coordinate is not a column of the dataframe.
    with pytest.raises(ValueError, match="Coordinate"):
        coord_utils.find_coord_in_dataframe(
            pd.DataFrame(columns=["time", "x"]), coord="projection_x_coordinate"
        )

    # None of the default names is present.
    with pytest.raises(ValueError, match="No coordinate found matching defaults"):
        coord_utils.find_coord_in_dataframe(
            pd.DataFrame(columns=["time", "y"]), defaults=defaults
        )

    # More than one default name is present, so the choice is ambiguous.
    with pytest.raises(ValueError, match="Multiple matching"):
        coord_utils.find_coord_in_dataframe(
            pd.DataFrame(columns=["time", "x", "projection_x_coordinate"]),
            defaults=defaults,
        )

    # Input that is neither a DataFrame nor a Series is rejected.
    with pytest.raises(ValueError, match="Input variable_dataframe is neither"):
        coord_utils.find_coord_in_dataframe("test_str", defaults=defaults)
39
+
40
+
41
def test_find_coord_in_dataframe():
    """Check that find_coord_in_dataframe returns the expected column name.

    Covers an explicitly requested coordinate, lookup through the list of
    default names, both arguments given together, and ``pd.Series`` input.
    """
    defaults = ["x", "projection_x_coordinate", "__other_name"]
    both_x_names = ["time", "x", "projection_x_coordinate"]

    # Explicit coord is returned even when another candidate column exists.
    found = coord_utils.find_coord_in_dataframe(
        pd.DataFrame(columns=both_x_names), coord="x"
    )
    assert found == "x"

    found = coord_utils.find_coord_in_dataframe(
        pd.DataFrame(columns=both_x_names), coord="projection_x_coordinate"
    )
    assert found == "projection_x_coordinate"

    # Lookup via defaults when exactly one default name is present.
    found = coord_utils.find_coord_in_dataframe(
        pd.DataFrame(columns=["time", "x", "y"]), defaults=defaults
    )
    assert found == "x"

    found = coord_utils.find_coord_in_dataframe(
        pd.DataFrame(
            columns=["time", "projection_x_coordinate", "projection_y_coordinate"]
        ),
        defaults=defaults,
    )
    assert found == "projection_x_coordinate"

    # coord and defaults given together: the call must still return coord.
    found = coord_utils.find_coord_in_dataframe(
        pd.DataFrame(columns=both_x_names), coord="x", defaults=defaults
    )
    assert found == "x"

    # pd.Series input. The dtype is given explicitly because constructing an
    # empty Series without one emits a default-dtype FutureWarning on
    # pandas 1.x; only the index labels matter for this test.
    found = coord_utils.find_coord_in_dataframe(
        pd.Series(index=both_x_names, dtype=object), coord="x"
    )
    assert found == "x"
96
+
97
+
98
def test_find_dataframe_vertical_coord_warning():
    """Check that vertical_coord="auto" emits a DeprecationWarning."""
    features = pd.DataFrame(columns=["z"])
    with pytest.warns(DeprecationWarning):
        coord_utils.find_dataframe_vertical_coord(features, vertical_coord="auto")
104
+
105
+
106
def test_find_dataframe_vertical_coord_error():
    """Check that find_dataframe_vertical_coord raises ValueError on bad input."""
    failing_calls = [
        # Requested coordinate name is not a column.
        (["z"], {"vertical_coord": "__bad_coord_name"}),
        # No default vertical coordinate present.
        (["x"], {}),
        # Several default vertical coordinates present.
        (["z", "geopotential_height"], {}),
    ]
    for columns, kwargs in failing_calls:
        with pytest.raises(ValueError):
            coord_utils.find_dataframe_vertical_coord(
                pd.DataFrame(columns=columns), **kwargs
            )
123
+
124
+
125
def test_find_dataframe_vertical_coord():
    """Check the names returned by find_dataframe_vertical_coord."""
    # Default coordinate names are detected without any hint.
    detected = coord_utils.find_dataframe_vertical_coord(pd.DataFrame(columns=["z"]))
    assert detected == "z"

    detected = coord_utils.find_dataframe_vertical_coord(
        pd.DataFrame(columns=["geopotential_height"])
    )
    assert detected == "geopotential_height"

    # An explicitly named coordinate is returned as given.
    detected = coord_utils.find_dataframe_vertical_coord(
        pd.DataFrame(columns=["p"]), vertical_coord="p"
    )
    assert detected == "p"

    # An explicit name resolves the case of several default coordinates.
    detected = coord_utils.find_dataframe_vertical_coord(
        pd.DataFrame(columns=["z", "geopotential_height"]), vertical_coord="z"
    )
    assert detected == "z"
151
+
152
+
153
def test_find_dataframe_horizontal_coords_error():
    """Check that find_dataframe_horizontal_coords raises ValueError on bad input.

    Each entry pairs the dataframe columns with the keyword arguments of one
    call that is expected to fail; the calls run in the listed order.
    """
    failing_calls = [
        # No matching coords
        (["time", "z"], {}),
        # hdim1_coord or hdim2_coord set but not coord_type
        (["time", "x", "y"], {"hdim1_coord": "y"}),
        (["time", "x", "y"], {"hdim2_coord": "x"}),
        # Only one of the two horizontal coords exists
        (["time", "x"], {}),
        (["time", "y"], {}),
        # One coord of each type (xy and latlon) exists
        (["time", "x", "lat"], {}),
        # The explicitly named hdim coord is not among the columns
        (["time", "x", "lat"], {"hdim1_coord": "y", "coord_type": "xy"}),
        (["time", "y", "lon"], {"hdim2_coord": "x", "coord_type": "xy"}),
        (["time", "x", "lon"], {"hdim1_coord": "lat", "coord_type": "latlon"}),
        (["time", "x", "lat"], {"hdim1_coord": "lon", "coord_type": "latlon"}),
    ]
    for columns, kwargs in failing_calls:
        with pytest.raises(ValueError):
            coord_utils.find_dataframe_horizontal_coords(
                pd.DataFrame(columns=columns), **kwargs
            )
213
+
214
+
215
def test_find_dataframe_horizontal_coords_error_coord_type():
    """Check that a coord_type not matching the available coords raises.

    The error is expected even though a complete pair of the *other*
    coordinate type is present in the dataframe.
    """
    mismatches = [
        (["time", "x", "y"], "latlon"),
        (["time", "lat", "lon"], "xy"),
    ]
    for columns, requested_type in mismatches:
        with pytest.raises(ValueError):
            coord_utils.find_dataframe_horizontal_coords(
                pd.DataFrame(columns=columns), coord_type=requested_type
            )
229
+
230
+
231
def test_find_dataframe_horizontal_coords_defaults_xy():
    """Check default detection of xy coordinates.

    The expected result is the tuple (hdim1 name, hdim2 name, coord type).
    """
    # Plain x/y column names
    result = coord_utils.find_dataframe_horizontal_coords(
        pd.DataFrame(columns=["time", "x", "y"])
    )
    assert result == ("y", "x", "xy")

    # Projection coordinate column names
    result = coord_utils.find_dataframe_horizontal_coords(
        pd.DataFrame(
            columns=["time", "projection_x_coordinate", "projection_y_coordinate"]
        )
    )
    assert result == ("projection_y_coordinate", "projection_x_coordinate", "xy")

    # xy takes priority over latlon when both pairs are present
    result = coord_utils.find_dataframe_horizontal_coords(
        pd.DataFrame(columns=["time", "x", "y", "lat", "lon"])
    )
    assert result == ("y", "x", "xy")
248
+
249
+
250
def test_find_dataframe_horizontal_coords_defaults_latlon():
    """Check default detection of lat/lon coordinates.

    The expected result is the tuple (hdim1 name, hdim2 name, coord type).
    """
    # Short lat/lon column names
    result = coord_utils.find_dataframe_horizontal_coords(
        pd.DataFrame(columns=["time", "lon", "lat"])
    )
    assert result == ("lat", "lon", "latlon")

    # Capitalised long names
    result = coord_utils.find_dataframe_horizontal_coords(
        pd.DataFrame(columns=["time", "Longitude", "Latitude"])
    )
    assert result == ("Latitude", "Longitude", "latlon")

    # With only one of the xy pair present, latlon is used instead
    result = coord_utils.find_dataframe_horizontal_coords(
        pd.DataFrame(columns=["time", "x", "lat", "lon"])
    )
    assert result == ("lat", "lon", "latlon")

    # coord_type="latlon" ignores a complete xy pair
    result = coord_utils.find_dataframe_horizontal_coords(
        pd.DataFrame(columns=["time", "x", "y", "lat", "lon"]), coord_type="latlon"
    )
    assert result == ("lat", "lon", "latlon")
270
+
271
+
272
def test_find_dataframe_horizontal_coords_specific():
    """Check find_dataframe_horizontal_coords with explicitly named coords."""
    xy_columns = [
        "time",
        "x",
        "y",
        "projection_x_coordinate",
        "projection_y_coordinate",
    ]

    result = coord_utils.find_dataframe_horizontal_coords(
        pd.DataFrame(columns=xy_columns),
        hdim1_coord="y",
        hdim2_coord="x",
        coord_type="xy",
    )
    assert result == ("y", "x", "xy")

    result = coord_utils.find_dataframe_horizontal_coords(
        pd.DataFrame(columns=xy_columns),
        hdim1_coord="projection_y_coordinate",
        hdim2_coord="projection_x_coordinate",
        coord_type="xy",
    )
    assert result == ("projection_y_coordinate", "projection_x_coordinate", "xy")

    # The names are returned in the order given, even with x as hdim1
    result = coord_utils.find_dataframe_horizontal_coords(
        pd.DataFrame(columns=xy_columns),
        hdim1_coord="x",
        hdim2_coord="y",
        coord_type="xy",
    )
    assert result == ("x", "y", "xy")

    # coord_type is passed through unchanged even when it does not describe
    # the named coordinates
    result = coord_utils.find_dataframe_horizontal_coords(
        pd.DataFrame(columns=["time", "x", "y", "lat", "lon"]),
        hdim1_coord="lat",
        hdim2_coord="lon",
        coord_type="xy",
    )
    assert result == ("lat", "lon", "xy")
@@ -0,0 +1,97 @@
1
+ import tobac.utils.internal as internal_utils
2
+ import tobac.testing as tbtest
3
+
4
+ import pytest
5
+ import numpy as np
6
+ import xarray as xr
7
+
8
+
9
@pytest.mark.parametrize(
    "dset_type, time_axis, vertical_axis, expected_out",
    [
        ("iris", 0, 1, (2, 3)),
        ("iris", -1, 0, (1, 2)),
        ("iris", 0, -1, (1, 2)),
        ("iris", 0, 2, (1, 3)),
        ("iris", 3, 0, (1, 2)),
        ("iris", 0, 3, (1, 2)),
        ("iris", 1, 2, (0, 3)),
        ("xarray", 0, 1, (2, 3)),
        ("xarray", 0, 2, (1, 3)),
        ("xarray", 3, 0, (1, 2)),
        ("xarray", 0, 3, (1, 2)),
        ("xarray", 1, 2, (0, 3)),
    ],
)
def test_find_hdim_axes_3D(dset_type, time_axis, vertical_axis, expected_out):
    """Tests tobac.utils.internal.find_hdim_axes_3D

    Parameters
    ----------
    dset_type: str{"xarray" or "iris"}
        type of the dataset to generate
    time_axis: int
        axis number of the time coordinate (or -1 to not have one)
    vertical_axis: int
        axis number of the vertical coordinate (or -1 to not have one)
    expected_out: tuple (int, int)
        expected output
    """
    has_time = time_axis >= 0
    has_vertical = vertical_axis >= 0
    # Two horizontal axes plus one for each optional axis that is present.
    ndims = 2 + int(has_time) + int(has_vertical)

    dset_opts = {
        "in_arr": np.zeros([2] * ndims),
        "data_type": dset_type,
    }
    if has_time:
        dset_opts["time_dim_num"] = time_axis
    if has_vertical:
        dset_opts["z_dim_num"] = vertical_axis
        dset_opts["z_dim_name"] = "altitude"

    # Axes not taken by time or vertical are horizontal: the first becomes y,
    # the last remaining one becomes x.
    free_axes = [
        axis for axis in range(ndims) if axis not in (time_axis, vertical_axis)
    ]
    if free_axes:
        dset_opts["y_dim_num"] = free_axes[0]
    if len(free_axes) > 1:
        dset_opts["x_dim_num"] = free_axes[-1]

    cube_test = tbtest.make_dataset_from_arr(**dset_opts)

    assert internal_utils.find_hdim_axes_3D(cube_test) == expected_out
69
+
70
+
71
@pytest.mark.parametrize(
    "lat_name, lon_name, lat_name_test, lon_name_test, expected_result",
    [
        ("lat", "lon", None, None, ("lat", "lon")),
        ("lat", "long", None, None, ("lat", "long")),
        ("lat", "altitude", None, None, ("lat", None)),
        ("lat", "longitude", "lat", "longitude", ("lat", "longitude")),
    ],
)
def test_detect_latlon_coord_name(
    lat_name, lon_name, lat_name_test, lon_name_test, expected_result
):
    """Tests tobac.utils.internal.detect_latlon_coord_name

    Parameters
    ----------
    lat_name: str
        name given to the first dimension/coordinate of the test dataset
    lon_name: str
        name given to the second dimension/coordinate of the test dataset
    lat_name_test: str or None
        latitude name passed to detect_latlon_coord_name
    lon_name_test: str or None
        longitude name passed to detect_latlon_coord_name
    expected_result: tuple
        expected (latitude name, longitude name) output
    """
    n_points = 50
    # Deterministic fixture values: np.empty would hand uninitialised memory
    # to the iris conversion, so the input would differ from run to run.
    in_arr = np.zeros((n_points, n_points))
    lat_vals = np.arange(n_points, dtype=float)
    lon_vals = np.arange(n_points, dtype=float)

    in_xr = xr.Dataset(
        {"data": ((lat_name, lon_name), in_arr)},
        coords={lat_name: lat_vals, lon_name: lon_vals},
    )
    out_lat_name, out_lon_name = internal_utils.detect_latlon_coord_name(
        in_xr["data"].to_iris(), lat_name_test, lon_name_test
    )
    assert out_lat_name == expected_result[0]
    assert out_lon_name == expected_result[1]
@@ -0,0 +1,232 @@
1
+ """Tests for tobac.utils.internal.xarray_utils"""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Union
6
+
7
+ import pandas as pd
8
+ import pytest
9
+ import numpy as np
10
+ import xarray as xr
11
+
12
+ import tobac.utils.internal.xarray_utils as xr_utils
13
+ import tobac.testing as tbtest
14
+ import datetime
15
+
16
+
17
@pytest.mark.parametrize(
    "dim_names, coord_dim_map, coord_looking_for, expected_out, expected_raise",
    [
        (
            ("time", "altitude", "x", "y"),
            {
                "time": ("time",),
                "latitude": ("x", "y"),
                "longitude": ("x", "y"),
                "altmsl": ("altitude", "x", "y"),
            },
            "time",
            0,
            False,
        ),
        (
            ("time", "time", "time", "time", "time"),
            {"time": ("time",)},
            "time",
            0,
            True,
        ),
        (
            ("time", "altitude", "x", "y"),
            {
                "time": ("time",),
                "latitude": ("x", "y"),
                "longitude": ("x", "y"),
                "altmsl": ("altitude", "x", "y"),
            },
            "altitude",
            1,
            False,
        ),
        (
            ("time", "altitude", "x", "y"),
            {
                "time": ("time",),
                "latitude": ("x", "y"),
                "longitude": ("x", "y"),
                "altmsl": ("altitude", "x", "y"),
            },
            "latitude",
            None,
            True,
        ),
        (
            ("time", "altitude", "x", "y"),
            {
                "time": ("time",),
                "latitude": ("x", "y"),
                "longitude": ("x", "y"),
                "altmsl": ("altitude", "x", "y"),
            },
            "x",
            2,
            False,
        ),
        (
            ("time", "altitude", "x", "y"),
            {
                "time": ("time",),
                "latitude": ("x", "y"),
                "longitude": ("x", "y"),
                "altmsl": ("altitude", "x", "y"),
            },
            "z",
            2,
            True,
        ),
        (
            ("time", "altitude", "x", "y"),
            {
                "t": ("time",),
                "latitude": ("x", "y"),
                "longitude": ("x", "y"),
                "altmsl": ("altitude", "x", "y"),
            },
            "t",
            0,
            False,
        ),
    ],
)
def test_find_axis_from_dim_coord(
    dim_names: tuple[str],
    coord_dim_map: dict,
    coord_looking_for: str,
    expected_out: Union[int, None],
    expected_raise: bool,
):
    """Tests tobac.utils.internal.xarray_utils.find_axis_from_dim_coord

    Parameters
    ----------
    dim_names: tuple[str]
        Names of the dimensions of the test array
    coord_dim_map: dict[str : tuple[str],]
        Mapping of coordinates (keys) to the dimensions they span (values)
    coord_looking_for: str
        Coordinate/dimension name passed to find_axis_from_dim_coord
    expected_out: Union[int, None]
        Expected return value; not checked when expected_raise is True
    expected_raise: bool
        Whether a ValueError is expected
    """
    # Every dimension has the same length; the values themselves are not
    # inspected by this test.
    side = 4
    coords = {
        name: (dims, np.empty((side,) * len(dims)))
        for name, dims in coord_dim_map.items()
    }
    xr_da = xr.DataArray(
        np.empty((side,) * len(dim_names)), dims=dim_names, coords=coords
    )

    if expected_raise:
        with pytest.raises(ValueError):
            xr_utils.find_axis_from_dim_coord(xr_da, coord_looking_for)
        return

    out_val = xr_utils.find_axis_from_dim_coord(xr_da, coord_looking_for)
    if expected_out is None:
        assert out_val is None
    else:
        assert out_val == expected_out
146
+
147
+
148
@pytest.mark.parametrize(
    "dim_names, coord_dim_map, feature_pos, expected_vals",
    [
        (
            ["time", "x", "y"],
            {
                "test_coord1": (tuple(), 1),
                "test_coord_time": ("time", [5, 6, 7, 8, 9, 10]),
            },
            (1, 1),
            {"test_coord1": (1, 1, 1), "test_coord_time": (5, 6, 7)},
        ),
        (
            ["time", "x", "y"],
            {
                "test_coord_datetime": (
                    "time",
                    pd.date_range(
                        datetime.datetime(2000, 1, 1),
                        datetime.datetime(2000, 1, 1, 6),
                        freq="1h",
                        inclusive="left",
                    ),
                ),
                "test_coord_time": ("time", [5, 6, 7, 8, 9, 10]),
            },
            (1, 1),
            {
                "test_coord_datetime": pd.date_range(
                    datetime.datetime(2000, 1, 1),
                    datetime.datetime(2000, 1, 1, 3),
                    freq="1h",
                    inclusive="left",
                ),
                "test_coord_time": (5, 6, 7),
            },
        ),
    ],
)
def test_add_coordinates_to_features_interpolate_along_other_dims(
    dim_names: tuple[str],
    coord_dim_map: dict,
    feature_pos: tuple[int],
    expected_vals: dict[str, tuple],
):
    """Tests xarray_utils.add_coordinates_to_features for coordinates that do
    not span the horizontal dimensions (scalar or time-only coordinates)

    Parameters
    ----------
    dim_names: tuple[str]
        Names of the dimensions of the test array
    coord_dim_map: dict
        Mapping of coordinate name to (dimensions, values), passed to
        xr.DataArray as coords together with an hourly "time" coordinate
    feature_pos: tuple[int]
        Start position of the single generated feature: (hdim_1, hdim_2) or
        (vdim, hdim_1, hdim_2)
    expected_vals: dict[str, tuple]
        Expected column values per coordinate for the generated features
    """
    time_len: int = 6
    feature_kwargs = dict(feature_num=1, num_frames=3, max_h1=100, max_h2=100)

    if len(feature_pos) == 2:
        all_feats = tbtest.generate_single_feature(
            feature_pos[0], feature_pos[1], **feature_kwargs
        )
        arr_size = (time_len, 5, 5)
    elif len(feature_pos) == 3:
        all_feats = tbtest.generate_single_feature(
            feature_pos[1],
            feature_pos[2],
            start_v=feature_pos[0],
            **feature_kwargs,
        )
        arr_size = (time_len, 1, 5, 5)
    else:
        raise ValueError("feature_pos must have 2 or 3 elements")

    # Build a new mapping instead of inserting "time" into coord_dim_map: the
    # parametrize arguments are created once at import and must not be
    # mutated by a test run.
    start_time = datetime.datetime(2000, 1, 1, 0)
    all_coords = {
        **coord_dim_map,
        "time": (
            ("time",),
            [start_time + datetime.timedelta(hours=hour) for hour in range(time_len)],
        ),
    }

    test_xr_arr = xr.DataArray(np.empty(arr_size), dims=dim_names, coords=all_coords)

    resulting_df = xr_utils.add_coordinates_to_features(all_feats, test_xr_arr)
    for coord in all_coords:
        assert coord in resulting_df
        # "time" has no entry in expected_vals; only its presence is checked.
        if coord != "time":
            assert np.all(resulting_df[coord].values == expected_vals[coord])