anemoi-utils 0.4.12-py3-none-any.whl → 0.4.14-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of anemoi-utils might be problematic; see the registry's advisory page for this release for more details.

Files changed (37):
  1. anemoi/utils/__init__.py +1 -0
  2. anemoi/utils/__main__.py +12 -2
  3. anemoi/utils/_version.py +9 -4
  4. anemoi/utils/caching.py +138 -13
  5. anemoi/utils/checkpoints.py +81 -13
  6. anemoi/utils/cli.py +83 -7
  7. anemoi/utils/commands/__init__.py +4 -0
  8. anemoi/utils/commands/config.py +19 -2
  9. anemoi/utils/commands/requests.py +18 -2
  10. anemoi/utils/compatibility.py +6 -5
  11. anemoi/utils/config.py +254 -23
  12. anemoi/utils/dates.py +204 -50
  13. anemoi/utils/devtools.py +68 -7
  14. anemoi/utils/grib.py +30 -9
  15. anemoi/utils/grids.py +85 -8
  16. anemoi/utils/hindcasts.py +25 -8
  17. anemoi/utils/humanize.py +357 -52
  18. anemoi/utils/logs.py +31 -3
  19. anemoi/utils/mars/__init__.py +46 -12
  20. anemoi/utils/mars/requests.py +15 -1
  21. anemoi/utils/provenance.py +189 -32
  22. anemoi/utils/registry.py +234 -44
  23. anemoi/utils/remote/__init__.py +386 -38
  24. anemoi/utils/remote/s3.py +252 -29
  25. anemoi/utils/remote/ssh.py +140 -8
  26. anemoi/utils/s3.py +77 -4
  27. anemoi/utils/sanitise.py +52 -7
  28. anemoi/utils/testing.py +182 -0
  29. anemoi/utils/text.py +218 -54
  30. anemoi/utils/timer.py +91 -15
  31. {anemoi_utils-0.4.12.dist-info → anemoi_utils-0.4.14.dist-info}/METADATA +8 -4
  32. anemoi_utils-0.4.14.dist-info/RECORD +38 -0
  33. {anemoi_utils-0.4.12.dist-info → anemoi_utils-0.4.14.dist-info}/WHEEL +1 -1
  34. anemoi_utils-0.4.12.dist-info/RECORD +0 -37
  35. {anemoi_utils-0.4.12.dist-info → anemoi_utils-0.4.14.dist-info}/entry_points.txt +0 -0
  36. {anemoi_utils-0.4.12.dist-info → anemoi_utils-0.4.14.dist-info/licenses}/LICENSE +0 -0
  37. {anemoi_utils-0.4.12.dist-info → anemoi_utils-0.4.14.dist-info}/top_level.txt +0 -0
anemoi/utils/dates.py CHANGED
@@ -11,11 +11,29 @@
11
11
  import calendar
12
12
  import datetime
13
13
  import re
14
+ from typing import Any
15
+ from typing import List
16
+ from typing import Optional
17
+ from typing import Set
18
+ from typing import Tuple
19
+ from typing import Union
14
20
 
15
21
  import aniso8601
16
22
 
17
23
 
18
- def normalise_frequency(frequency):
24
+ def normalise_frequency(frequency: Union[int, str]) -> int:
25
+ """Normalise frequency to hours.
26
+
27
+ Parameters
28
+ ----------
29
+ frequency : int or str
30
+ The frequency to normalise.
31
+
32
+ Returns
33
+ -------
34
+ int
35
+ The normalised frequency in hours.
36
+ """
19
37
  if isinstance(frequency, int):
20
38
  return frequency
21
39
  assert isinstance(frequency, str), (type(frequency), frequency)
@@ -25,7 +43,7 @@ def normalise_frequency(frequency):
25
43
  return {"h": v, "d": v * 24}[unit]
26
44
 
27
45
 
28
- def _no_time_zone(date) -> datetime.datetime:
46
+ def _no_time_zone(date: datetime.datetime) -> datetime.datetime:
29
47
  """Remove time zone information from a date.
30
48
 
31
49
  Parameters
@@ -43,7 +61,7 @@ def _no_time_zone(date) -> datetime.datetime:
43
61
 
44
62
 
45
63
  # this function is use in anemoi-datasets
46
- def as_datetime(date, keep_time_zone=False) -> datetime.datetime:
64
+ def as_datetime(date: Union[datetime.date, datetime.datetime, str], keep_time_zone: bool = False) -> datetime.datetime:
47
65
  """Convert a date to a datetime object, removing any time zone information.
48
66
 
49
67
  Parameters
@@ -73,7 +91,23 @@ def as_datetime(date, keep_time_zone=False) -> datetime.datetime:
73
91
  raise ValueError(f"Invalid date type: {type(date)}")
74
92
 
75
93
 
76
- def _as_datetime_list(date, default_increment):
94
+ def _as_datetime_list(
95
+ date: Union[datetime.date, datetime.datetime, str], default_increment: datetime.timedelta
96
+ ) -> iter:
97
+ """Convert a date to a list of datetime objects.
98
+
99
+ Parameters
100
+ ----------
101
+ date : datetime.date or datetime.datetime or str
102
+ The date to convert.
103
+ default_increment : datetime.timedelta
104
+ The default increment for the list.
105
+
106
+ Returns
107
+ -------
108
+ iter
109
+ An iterator of datetime objects.
110
+ """
77
111
  if isinstance(date, (list, tuple)):
78
112
  for d in date:
79
113
  yield from _as_datetime_list(d, default_increment)
@@ -102,12 +136,28 @@ def _as_datetime_list(date, default_increment):
102
136
  yield as_datetime(date)
103
137
 
104
138
 
105
- def as_datetime_list(date, default_increment=1):
139
+ def as_datetime_list(
140
+ date: Union[datetime.date, datetime.datetime, str], default_increment: int = 1
141
+ ) -> list[datetime.datetime]:
142
+ """Convert a date to a list of datetime objects.
143
+
144
+ Parameters
145
+ ----------
146
+ date : datetime.date or datetime.datetime or str
147
+ The date to convert.
148
+ default_increment : int, optional
149
+ The default increment in hours, by default 1.
150
+
151
+ Returns
152
+ -------
153
+ list of datetime.datetime
154
+ A list of datetime objects.
155
+ """
106
156
  default_increment = frequency_to_timedelta(default_increment)
107
157
  return list(_as_datetime_list(date, default_increment))
108
158
 
109
159
 
110
- def as_timedelta(frequency) -> datetime.timedelta:
160
+ def as_timedelta(frequency: Union[int, str, datetime.timedelta]) -> datetime.timedelta:
111
161
  """Convert anything to a timedelta object.
112
162
 
113
163
  Parameters
@@ -178,12 +228,23 @@ def as_timedelta(frequency) -> datetime.timedelta:
178
228
  raise ValueError(f"Cannot convert frequency {frequency} to timedelta")
179
229
 
180
230
 
181
- def frequency_to_timedelta(frequency) -> datetime.timedelta:
182
- """Convert a frequency to a timedelta object."""
231
+ def frequency_to_timedelta(frequency: Union[int, str, datetime.timedelta]) -> datetime.timedelta:
232
+ """Convert a frequency to a timedelta object.
233
+
234
+ Parameters
235
+ ----------
236
+ frequency : int or str or datetime.timedelta
237
+ The frequency to convert.
238
+
239
+ Returns
240
+ -------
241
+ datetime.timedelta
242
+ The timedelta object.
243
+ """
183
244
  return as_timedelta(frequency)
184
245
 
185
246
 
186
- def frequency_to_string(frequency) -> str:
247
+ def frequency_to_string(frequency: datetime.timedelta) -> str:
187
248
  """Convert a frequency (i.e. a datetime.timedelta) to a string.
188
249
 
189
250
  Parameters
@@ -230,20 +291,19 @@ def frequency_to_string(frequency) -> str:
230
291
  return str(frequency)
231
292
 
232
293
 
233
- def frequency_to_seconds(frequency) -> int:
294
+ def frequency_to_seconds(frequency: Union[int, str, datetime.timedelta]) -> int:
234
295
  """Convert a frequency to seconds.
235
296
 
236
297
  Parameters
237
298
  ----------
238
- frequency : _type_
239
- _description_
299
+ frequency : int or str or datetime.timedelta
300
+ The frequency to convert.
240
301
 
241
302
  Returns
242
303
  -------
243
304
  int
244
305
  Number of seconds.
245
306
  """
246
-
247
307
  result = frequency_to_timedelta(frequency).total_seconds()
248
308
  assert int(result) == result, result
249
309
  return int(result)
@@ -276,7 +336,19 @@ MONTH = {
276
336
  }
277
337
 
278
338
 
279
- def _make_day(day):
339
+ def _make_day(day: Optional[Tuple[int, List[int]]]) -> Set[int]:
340
+ """Create a set of days.
341
+
342
+ Parameters
343
+ ----------
344
+ day : int or list of int or None
345
+ The day(s) to include in the set.
346
+
347
+ Returns
348
+ -------
349
+ set of int
350
+ A set of days.
351
+ """
280
352
  if day is None:
281
353
  return set(range(1, 32))
282
354
  if not isinstance(day, list):
@@ -284,7 +356,19 @@ def _make_day(day):
284
356
  return {int(d) for d in day}
285
357
 
286
358
 
287
- def _make_week(week):
359
+ def _make_week(week: Optional[Tuple[str, List[str]]]) -> Set[int]:
360
+ """Create a set of weekdays.
361
+
362
+ Parameters
363
+ ----------
364
+ week : str or list of str or None
365
+ The weekday(s) to include in the set.
366
+
367
+ Returns
368
+ -------
369
+ set of int
370
+ A set of weekdays.
371
+ """
288
372
  if week is None:
289
373
  return set(range(7))
290
374
  if not isinstance(week, list):
@@ -292,7 +376,19 @@ def _make_week(week):
292
376
  return {DOW[w.lower()] for w in week}
293
377
 
294
378
 
295
- def _make_months(months):
379
+ def _make_months(months: Optional[Union[int, str, List[Union[int, str]]]]) -> Set[int]:
380
+ """Create a set of months.
381
+
382
+ Parameters
383
+ ----------
384
+ months : int or str or list of int or str or None
385
+ The month(s) to include in the set.
386
+
387
+ Returns
388
+ -------
389
+ set of int
390
+ A set of months.
391
+ """
296
392
  if months is None:
297
393
  return set(range(1, 13))
298
394
 
@@ -305,23 +401,32 @@ def _make_months(months):
305
401
  class DateTimes:
306
402
  """The DateTimes class is an iterator that generates datetime objects within a given range."""
307
403
 
308
- def __init__(self, start, end, increment=24, *, day_of_month=None, day_of_week=None, calendar_months=None):
309
- """_summary_
404
+ def __init__(
405
+ self,
406
+ start: Union[datetime.date, datetime.datetime, str],
407
+ end: Union[datetime.date, datetime.datetime, str],
408
+ increment: int = 24,
409
+ *,
410
+ day_of_month: Optional[Tuple[int, List[int]]] = None,
411
+ day_of_week: Optional[Tuple[str, List[str]]] = None,
412
+ calendar_months: Optional[Union[int, str, List[Union[int, str]]]] = None,
413
+ ):
414
+ """Initialize the DateTimes iterator.
310
415
 
311
416
  Parameters
312
417
  ----------
313
- start : _type_
314
- _description_
315
- end : _type_
316
- _description_
418
+ start : datetime.date or datetime.datetime or str
419
+ The start date.
420
+ end : datetime.date or datetime.datetime or str
421
+ The end date.
317
422
  increment : int, optional
318
- _description_, by default 24
319
- day_of_month : _type_, optional
320
- _description_, by default None
321
- day_of_week : _type_, optional
322
- _description_, by default None
323
- calendar_months : _type_, optional
324
- _description_, by default None
423
+ The increment in hours, by default 24.
424
+ day_of_month : int or list of int or None, optional
425
+ The day(s) of the month to include, by default None.
426
+ day_of_week : str or list of str or None, optional
427
+ The day(s) of the week to include, by default None.
428
+ calendar_months : int or str or list of int or str or None, optional
429
+ The month(s) to include, by default None.
325
430
  """
326
431
  self.start = as_datetime(start)
327
432
  self.end = as_datetime(end)
@@ -330,7 +435,14 @@ class DateTimes:
330
435
  self.day_of_week = _make_week(day_of_week)
331
436
  self.calendar_months = _make_months(calendar_months)
332
437
 
333
- def __iter__(self):
438
+ def __iter__(self) -> iter:
439
+ """Iterate over the datetime objects.
440
+
441
+ Returns
442
+ -------
443
+ iter
444
+ An iterator of datetime objects.
445
+ """
334
446
  date = self.start
335
447
  while date <= self.end:
336
448
  if (
@@ -346,13 +458,13 @@ class DateTimes:
346
458
  class Year(DateTimes):
347
459
  """Year is defined as the months of January to December."""
348
460
 
349
- def __init__(self, year, **kwargs):
350
- """_summary_
461
+ def __init__(self, year: int, **kwargs):
462
+ """Initialize the Year iterator.
351
463
 
352
464
  Parameters
353
465
  ----------
354
466
  year : int
355
- _description_
467
+ The year.
356
468
  """
357
469
  super().__init__(datetime.datetime(year, 1, 1), datetime.datetime(year, 12, 31), **kwargs)
358
470
 
@@ -360,13 +472,13 @@ class Year(DateTimes):
360
472
  class Winter(DateTimes):
361
473
  """Winter is defined as the months of December, January and February."""
362
474
 
363
- def __init__(self, year, **kwargs):
364
- """_summary_
475
+ def __init__(self, year: int, **kwargs):
476
+ """Initialize the Winter iterator.
365
477
 
366
478
  Parameters
367
479
  ----------
368
480
  year : int
369
- _description_
481
+ The year.
370
482
  """
371
483
  super().__init__(
372
484
  datetime.datetime(year, 12, 1),
@@ -378,13 +490,13 @@ class Winter(DateTimes):
378
490
  class Spring(DateTimes):
379
491
  """Spring is defined as the months of March, April and May."""
380
492
 
381
- def __init__(self, year, **kwargs):
382
- """_summary_
493
+ def __init__(self, year: int, **kwargs):
494
+ """Initialize the Spring iterator.
383
495
 
384
496
  Parameters
385
497
  ----------
386
498
  year : int
387
- _description_
499
+ The year.
388
500
  """
389
501
  super().__init__(datetime.datetime(year, 3, 1), datetime.datetime(year, 5, 31), **kwargs)
390
502
 
@@ -392,13 +504,13 @@ class Spring(DateTimes):
392
504
  class Summer(DateTimes):
393
505
  """Summer is defined as the months of June, July and August."""
394
506
 
395
- def __init__(self, year, **kwargs):
396
- """_summary_
507
+ def __init__(self, year: int, **kwargs):
508
+ """Initialize the Summer iterator.
397
509
 
398
510
  Parameters
399
511
  ----------
400
512
  year : int
401
- _description_
513
+ The year.
402
514
  """
403
515
  super().__init__(datetime.datetime(year, 6, 1), datetime.datetime(year, 8, 31), **kwargs)
404
516
 
@@ -406,13 +518,13 @@ class Summer(DateTimes):
406
518
  class Autumn(DateTimes):
407
519
  """Autumn is defined as the months of September, October and November."""
408
520
 
409
- def __init__(self, year, **kwargs):
410
- """_summary_
521
+ def __init__(self, year: int, **kwargs):
522
+ """Initialize the Autumn iterator.
411
523
 
412
524
  Parameters
413
525
  ----------
414
526
  year : int
415
- _description_
527
+ The year.
416
528
  """
417
529
  super().__init__(datetime.datetime(year, 9, 1), datetime.datetime(year, 11, 30), **kwargs)
418
530
 
@@ -420,13 +532,27 @@ class Autumn(DateTimes):
420
532
  class ConcatDateTimes:
421
533
  """ConcatDateTimes is an iterator that generates datetime objects from a list of dates."""
422
534
 
423
- def __init__(self, *dates):
535
+ def __init__(self, *dates: DateTimes):
536
+ """Initialize the ConcatDateTimes iterator.
537
+
538
+ Parameters
539
+ ----------
540
+ dates : DateTimes
541
+ The list of DateTimes objects.
542
+ """
424
543
  if len(dates) == 1 and isinstance(dates[0], list):
425
544
  dates = dates[0]
426
545
 
427
546
  self.dates = dates
428
547
 
429
- def __iter__(self):
548
+ def __iter__(self) -> iter:
549
+ """Iterate over the datetime objects.
550
+
551
+ Returns
552
+ -------
553
+ iter
554
+ An iterator of datetime objects.
555
+ """
430
556
  for date in self.dates:
431
557
  yield from date
432
558
 
@@ -434,15 +560,43 @@ class ConcatDateTimes:
434
560
  class EnumDateTimes:
435
561
  """EnumDateTimes is an iterator that generates datetime objects from a list of dates."""
436
562
 
437
- def __init__(self, dates):
563
+ def __init__(self, dates: list[Union[datetime.date, datetime.datetime, str]]):
564
+ """Initialize the EnumDateTimes iterator.
565
+
566
+ Parameters
567
+ ----------
568
+ dates : list of datetime.date or datetime.datetime or str
569
+ The list of dates.
570
+ """
438
571
  self.dates = dates
439
572
 
440
- def __iter__(self):
573
+ def __iter__(self) -> iter:
574
+ """Iterate over the datetime objects.
575
+
576
+ Returns
577
+ -------
578
+ iter
579
+ An iterator of datetime objects.
580
+ """
441
581
  for date in self.dates:
442
582
  yield as_datetime(date)
443
583
 
444
584
 
445
- def datetimes_factory(*args, **kwargs):
585
+ def datetimes_factory(*args: Any, **kwargs: Any) -> Union[DateTimes, ConcatDateTimes, EnumDateTimes]:
586
+ """Create a DateTimes, ConcatDateTimes, or EnumDateTimes object.
587
+
588
+ Parameters
589
+ ----------
590
+ *args : Any
591
+ Positional arguments.
592
+ **kwargs : Any
593
+ Keyword arguments.
594
+
595
+ Returns
596
+ -------
597
+ DateTimes or ConcatDateTimes or EnumDateTimes
598
+ The created object.
599
+ """
446
600
  if args and kwargs:
447
601
  raise ValueError("Cannot provide both args and kwargs for a list of dates")
448
602
 
anemoi/utils/devtools.py CHANGED
@@ -8,29 +8,74 @@
8
8
  # nor does it submit to any jurisdiction.
9
9
 
10
10
 
11
+ from typing import Any
12
+
11
13
  import cartopy.crs as ccrs
12
14
  import cartopy.feature as cfeature
13
15
  import matplotlib.pyplot as plt
14
16
  import matplotlib.tri as tri
15
17
  import numpy as np
16
18
 
17
- """FOR DEVELOPMENT PURPOSES ONLY
19
+ """FOR DEVELOPMENT PURPOSES ONLY.
18
20
 
19
21
  This module contains
20
-
21
22
  """
22
23
 
23
24
  # TODO: use earthkit-plots
24
25
 
25
26
 
26
- def fix(lons):
27
+ def fix(lons: np.ndarray) -> np.ndarray:
28
+ """Fix longitudes greater than 180 degrees.
29
+
30
+ Parameters
31
+ ----------
32
+ lons : np.ndarray
33
+ Array of longitudes.
34
+
35
+ Returns
36
+ -------
37
+ np.ndarray
38
+ Array of fixed longitudes.
39
+ """
27
40
  return np.where(lons > 180, lons - 360, lons)
28
41
 
29
42
 
30
43
  def plot_values(
31
- values, latitudes, longitudes, title=None, missing_value=None, min_value=None, max_value=None, **kwargs
32
- ):
33
-
44
+ values: np.ndarray,
45
+ latitudes: np.ndarray,
46
+ longitudes: np.ndarray,
47
+ title: str = None,
48
+ missing_value: float = None,
49
+ min_value: float = None,
50
+ max_value: float = None,
51
+ **kwargs: dict,
52
+ ) -> plt.Axes:
53
+ """Plot values on a map.
54
+
55
+ Parameters
56
+ ----------
57
+ values : np.ndarray
58
+ Array of values to plot.
59
+ latitudes : np.ndarray
60
+ Array of latitudes.
61
+ longitudes : np.ndarray
62
+ Array of longitudes.
63
+ title : str, optional
64
+ Title of the plot, by default None.
65
+ missing_value : float, optional
66
+ Value to use for missing data, by default None.
67
+ min_value : float, optional
68
+ Minimum value for the plot, by default None.
69
+ max_value : float, optional
70
+ Maximum value for the plot, by default None.
71
+ **kwargs : dict
72
+ Additional keyword arguments for the plot.
73
+
74
+ Returns
75
+ -------
76
+ plt.Axes
77
+ The plot axes.
78
+ """
34
79
  _, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
35
80
  ax.coastlines()
36
81
  ax.add_feature(cfeature.BORDERS, linestyle=":")
@@ -77,7 +122,23 @@ def plot_values(
77
122
  return ax
78
123
 
79
124
 
80
- def plot_field(field, title=None, **kwargs):
125
+ def plot_field(field: Any, title: str = None, **kwargs: dict) -> plt.Axes:
126
+ """Plot a field on a map.
127
+
128
+ Parameters
129
+ ----------
130
+ field : Any
131
+ The field to plot.
132
+ title : str, optional
133
+ Title of the plot, by default None.
134
+ **kwargs : dict
135
+ Additional keyword arguments for the plot.
136
+
137
+ Returns
138
+ -------
139
+ plt.Axes
140
+ The plot axes.
141
+ """
81
142
  values = field.to_numpy(flatten=True)
82
143
  latitudes, longitudes = field.grid_points()
83
144
  return plot_values(values, latitudes, longitudes, title=title, **kwargs)
anemoi/utils/grib.py CHANGED
@@ -11,11 +11,12 @@
11
11
  """Utilities for working with GRIB parameters.
12
12
 
13
13
  See https://codes.ecmwf.int/grib/param-db/ for more information.
14
-
15
14
  """
16
15
 
17
16
  import logging
18
17
  import re
18
+ from typing import Dict
19
+ from typing import Union
19
20
 
20
21
  import requests
21
22
 
@@ -25,7 +26,14 @@ LOG = logging.getLogger(__name__)
25
26
 
26
27
 
27
28
  @cached(collection="grib", expires=30 * 24 * 60 * 60)
28
- def _units():
29
+ def _units() -> Dict[str, str]:
30
+ """Fetch and cache GRIB parameter units.
31
+
32
+ Returns
33
+ -------
34
+ dict
35
+ A dictionary mapping unit ids to their names.
36
+ """
29
37
  r = requests.get("https://codes.ecmwf.int/parameter-database/api/v1/unit/")
30
38
  r.raise_for_status()
31
39
  units = r.json()
@@ -33,7 +41,24 @@ def _units():
33
41
 
34
42
 
35
43
  @cached(collection="grib", expires=30 * 24 * 60 * 60)
36
- def _search_param(name):
44
+ def _search_param(name: str) -> Dict[str, Union[str, int]]:
45
+ """Search for a GRIB parameter by name.
46
+
47
+ Parameters
48
+ ----------
49
+ name : str
50
+ Parameter name to search for.
51
+
52
+ Returns
53
+ -------
54
+ dict
55
+ A dictionary containing parameter details.
56
+
57
+ Raises
58
+ ------
59
+ KeyError
60
+ If no parameter is found.
61
+ """
37
62
  name = re.escape(name)
38
63
  r = requests.get(f"https://codes.ecmwf.int/parameter-database/api/v1/param/?search=^{name}$&regex=true")
39
64
  r.raise_for_status()
@@ -68,7 +93,6 @@ def shortname_to_paramid(shortname: str) -> int:
68
93
 
69
94
  >>> shortname_to_paramid("2t")
70
95
  167
71
-
72
96
  """
73
97
  return _search_param(shortname)["id"]
74
98
 
@@ -88,12 +112,11 @@ def paramid_to_shortname(paramid: int) -> str:
88
112
 
89
113
  >>> paramid_to_shortname(167)
90
114
  '2t'
91
-
92
115
  """
93
116
  return _search_param(str(paramid))["shortname"]
94
117
 
95
118
 
96
- def units(param) -> str:
119
+ def units(param: Union[int, str]) -> str:
97
120
  """Return the units of a GRIB parameter given its name or id.
98
121
 
99
122
  Parameters
@@ -108,14 +131,13 @@ def units(param) -> str:
108
131
 
109
132
  >>> unit(167)
110
133
  'K'
111
-
112
134
  """
113
135
 
114
136
  unit_id = str(_search_param(str(param))["unit_id"])
115
137
  return _units()[unit_id]
116
138
 
117
139
 
118
- def must_be_positive(param) -> bool:
140
+ def must_be_positive(param: Union[int, str]) -> bool:
119
141
  """Check if a parameter must be positive.
120
142
 
121
143
  Parameters
@@ -130,6 +152,5 @@ def must_be_positive(param) -> bool:
130
152
 
131
153
  >>> must_be_positive("tp")
132
154
  True
133
-
134
155
  """
135
156
  return units(param) in ["m", "kg kg**-1", "m of water equivalent"]