anemoi-utils 0.4.11__py3-none-any.whl → 0.4.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of anemoi-utils might be problematic. Click here for more details.

anemoi/utils/dates.py CHANGED
@@ -11,11 +11,29 @@
11
11
  import calendar
12
12
  import datetime
13
13
  import re
14
+ from typing import Any
15
+ from typing import List
16
+ from typing import Optional
17
+ from typing import Set
18
+ from typing import Tuple
19
+ from typing import Union
14
20
 
15
21
  import aniso8601
16
22
 
17
23
 
18
- def normalise_frequency(frequency):
24
+ def normalise_frequency(frequency: Union[int, str]) -> int:
25
+ """Normalise frequency to hours.
26
+
27
+ Parameters
28
+ ----------
29
+ frequency : int or str
30
+ The frequency to normalise.
31
+
32
+ Returns
33
+ -------
34
+ int
35
+ The normalised frequency in hours.
36
+ """
19
37
  if isinstance(frequency, int):
20
38
  return frequency
21
39
  assert isinstance(frequency, str), (type(frequency), frequency)
@@ -25,7 +43,7 @@ def normalise_frequency(frequency):
25
43
  return {"h": v, "d": v * 24}[unit]
26
44
 
27
45
 
28
- def _no_time_zone(date) -> datetime.datetime:
46
+ def _no_time_zone(date: datetime.datetime) -> datetime.datetime:
29
47
  """Remove time zone information from a date.
30
48
 
31
49
  Parameters
@@ -43,7 +61,7 @@ def _no_time_zone(date) -> datetime.datetime:
43
61
 
44
62
 
45
63
  # this function is use in anemoi-datasets
46
- def as_datetime(date, keep_time_zone=False) -> datetime.datetime:
64
+ def as_datetime(date: Union[datetime.date, datetime.datetime, str], keep_time_zone: bool = False) -> datetime.datetime:
47
65
  """Convert a date to a datetime object, removing any time zone information.
48
66
 
49
67
  Parameters
@@ -73,7 +91,23 @@ def as_datetime(date, keep_time_zone=False) -> datetime.datetime:
73
91
  raise ValueError(f"Invalid date type: {type(date)}")
74
92
 
75
93
 
76
- def _as_datetime_list(date, default_increment):
94
+ def _as_datetime_list(
95
+ date: Union[datetime.date, datetime.datetime, str], default_increment: datetime.timedelta
96
+ ) -> iter:
97
+ """Convert a date to a list of datetime objects.
98
+
99
+ Parameters
100
+ ----------
101
+ date : datetime.date or datetime.datetime or str
102
+ The date to convert.
103
+ default_increment : datetime.timedelta
104
+ The default increment for the list.
105
+
106
+ Returns
107
+ -------
108
+ iter
109
+ An iterator of datetime objects.
110
+ """
77
111
  if isinstance(date, (list, tuple)):
78
112
  for d in date:
79
113
  yield from _as_datetime_list(d, default_increment)
@@ -102,12 +136,28 @@ def _as_datetime_list(date, default_increment):
102
136
  yield as_datetime(date)
103
137
 
104
138
 
105
- def as_datetime_list(date, default_increment=1):
139
+ def as_datetime_list(
140
+ date: Union[datetime.date, datetime.datetime, str], default_increment: int = 1
141
+ ) -> list[datetime.datetime]:
142
+ """Convert a date to a list of datetime objects.
143
+
144
+ Parameters
145
+ ----------
146
+ date : datetime.date or datetime.datetime or str
147
+ The date to convert.
148
+ default_increment : int, optional
149
+ The default increment in hours, by default 1.
150
+
151
+ Returns
152
+ -------
153
+ list of datetime.datetime
154
+ A list of datetime objects.
155
+ """
106
156
  default_increment = frequency_to_timedelta(default_increment)
107
157
  return list(_as_datetime_list(date, default_increment))
108
158
 
109
159
 
110
- def as_timedelta(frequency) -> datetime.timedelta:
160
+ def as_timedelta(frequency: Union[int, str, datetime.timedelta]) -> datetime.timedelta:
111
161
  """Convert anything to a timedelta object.
112
162
 
113
163
  Parameters
@@ -155,12 +205,19 @@ def as_timedelta(frequency) -> datetime.timedelta:
155
205
  unit = {"h": "hours", "d": "days", "s": "seconds", "m": "minutes"}[unit]
156
206
  return datetime.timedelta(**{unit: v})
157
207
 
158
- m = frequency.split(":")
159
- if len(m) == 2:
160
- return datetime.timedelta(hours=int(m[0]), minutes=int(m[1]))
208
+ if re.match(r"^\d+:\d+(:\d+)?$", frequency):
209
+ m = frequency.split(":")
210
+ if len(m) == 2:
211
+ return datetime.timedelta(hours=int(m[0]), minutes=int(m[1]))
212
+
213
+ if len(m) == 3:
214
+ return datetime.timedelta(hours=int(m[0]), minutes=int(m[1]), seconds=int(m[2]))
161
215
 
162
- if len(m) == 3:
163
- return datetime.timedelta(hours=int(m[0]), minutes=int(m[1]), seconds=int(m[2]))
216
+ if re.match(r"^\d+ days?, \d+:\d+:\d+$", frequency):
217
+ m = frequency.split(", ")
218
+ days = int(m[0].split()[0])
219
+ hms = m[1].split(":")
220
+ return datetime.timedelta(days=days, hours=int(hms[0]), minutes=int(hms[1]), seconds=int(hms[2]))
164
221
 
165
222
  # ISO8601
166
223
  try:
@@ -171,12 +228,23 @@ def as_timedelta(frequency) -> datetime.timedelta:
171
228
  raise ValueError(f"Cannot convert frequency {frequency} to timedelta")
172
229
 
173
230
 
174
- def frequency_to_timedelta(frequency) -> datetime.timedelta:
175
- """Convert a frequency to a timedelta object."""
231
+ def frequency_to_timedelta(frequency: Union[int, str, datetime.timedelta]) -> datetime.timedelta:
232
+ """Convert a frequency to a timedelta object.
233
+
234
+ Parameters
235
+ ----------
236
+ frequency : int or str or datetime.timedelta
237
+ The frequency to convert.
238
+
239
+ Returns
240
+ -------
241
+ datetime.timedelta
242
+ The timedelta object.
243
+ """
176
244
  return as_timedelta(frequency)
177
245
 
178
246
 
179
- def frequency_to_string(frequency) -> str:
247
+ def frequency_to_string(frequency: datetime.timedelta) -> str:
180
248
  """Convert a frequency (i.e. a datetime.timedelta) to a string.
181
249
 
182
250
  Parameters
@@ -223,20 +291,19 @@ def frequency_to_string(frequency) -> str:
223
291
  return str(frequency)
224
292
 
225
293
 
226
- def frequency_to_seconds(frequency) -> int:
294
+ def frequency_to_seconds(frequency: Union[int, str, datetime.timedelta]) -> int:
227
295
  """Convert a frequency to seconds.
228
296
 
229
297
  Parameters
230
298
  ----------
231
- frequency : _type_
232
- _description_
299
+ frequency : int or str or datetime.timedelta
300
+ The frequency to convert.
233
301
 
234
302
  Returns
235
303
  -------
236
304
  int
237
305
  Number of seconds.
238
306
  """
239
-
240
307
  result = frequency_to_timedelta(frequency).total_seconds()
241
308
  assert int(result) == result, result
242
309
  return int(result)
@@ -269,7 +336,19 @@ MONTH = {
269
336
  }
270
337
 
271
338
 
272
- def _make_day(day):
339
+ def _make_day(day: Optional[Tuple[int, List[int]]]) -> Set[int]:
340
+ """Create a set of days.
341
+
342
+ Parameters
343
+ ----------
344
+ day : int or list of int or None
345
+ The day(s) to include in the set.
346
+
347
+ Returns
348
+ -------
349
+ set of int
350
+ A set of days.
351
+ """
273
352
  if day is None:
274
353
  return set(range(1, 32))
275
354
  if not isinstance(day, list):
@@ -277,7 +356,19 @@ def _make_day(day):
277
356
  return {int(d) for d in day}
278
357
 
279
358
 
280
- def _make_week(week):
359
+ def _make_week(week: Optional[Tuple[str, List[str]]]) -> Set[int]:
360
+ """Create a set of weekdays.
361
+
362
+ Parameters
363
+ ----------
364
+ week : str or list of str or None
365
+ The weekday(s) to include in the set.
366
+
367
+ Returns
368
+ -------
369
+ set of int
370
+ A set of weekdays.
371
+ """
281
372
  if week is None:
282
373
  return set(range(7))
283
374
  if not isinstance(week, list):
@@ -285,7 +376,19 @@ def _make_week(week):
285
376
  return {DOW[w.lower()] for w in week}
286
377
 
287
378
 
288
- def _make_months(months):
379
+ def _make_months(months: Optional[Union[int, str, List[Union[int, str]]]]) -> Set[int]:
380
+ """Create a set of months.
381
+
382
+ Parameters
383
+ ----------
384
+ months : int or str or list of int or str or None
385
+ The month(s) to include in the set.
386
+
387
+ Returns
388
+ -------
389
+ set of int
390
+ A set of months.
391
+ """
289
392
  if months is None:
290
393
  return set(range(1, 13))
291
394
 
@@ -298,23 +401,32 @@ def _make_months(months):
298
401
  class DateTimes:
299
402
  """The DateTimes class is an iterator that generates datetime objects within a given range."""
300
403
 
301
- def __init__(self, start, end, increment=24, *, day_of_month=None, day_of_week=None, calendar_months=None):
302
- """_summary_
404
+ def __init__(
405
+ self,
406
+ start: Union[datetime.date, datetime.datetime, str],
407
+ end: Union[datetime.date, datetime.datetime, str],
408
+ increment: int = 24,
409
+ *,
410
+ day_of_month: Optional[Tuple[int, List[int]]] = None,
411
+ day_of_week: Optional[Tuple[str, List[str]]] = None,
412
+ calendar_months: Optional[Union[int, str, List[Union[int, str]]]] = None,
413
+ ):
414
+ """Initialize the DateTimes iterator.
303
415
 
304
416
  Parameters
305
417
  ----------
306
- start : _type_
307
- _description_
308
- end : _type_
309
- _description_
418
+ start : datetime.date or datetime.datetime or str
419
+ The start date.
420
+ end : datetime.date or datetime.datetime or str
421
+ The end date.
310
422
  increment : int, optional
311
- _description_, by default 24
312
- day_of_month : _type_, optional
313
- _description_, by default None
314
- day_of_week : _type_, optional
315
- _description_, by default None
316
- calendar_months : _type_, optional
317
- _description_, by default None
423
+ The increment in hours, by default 24.
424
+ day_of_month : int or list of int or None, optional
425
+ The day(s) of the month to include, by default None.
426
+ day_of_week : str or list of str or None, optional
427
+ The day(s) of the week to include, by default None.
428
+ calendar_months : int or str or list of int or str or None, optional
429
+ The month(s) to include, by default None.
318
430
  """
319
431
  self.start = as_datetime(start)
320
432
  self.end = as_datetime(end)
@@ -323,7 +435,14 @@ class DateTimes:
323
435
  self.day_of_week = _make_week(day_of_week)
324
436
  self.calendar_months = _make_months(calendar_months)
325
437
 
326
- def __iter__(self):
438
+ def __iter__(self) -> iter:
439
+ """Iterate over the datetime objects.
440
+
441
+ Returns
442
+ -------
443
+ iter
444
+ An iterator of datetime objects.
445
+ """
327
446
  date = self.start
328
447
  while date <= self.end:
329
448
  if (
@@ -339,13 +458,13 @@ class DateTimes:
339
458
  class Year(DateTimes):
340
459
  """Year is defined as the months of January to December."""
341
460
 
342
- def __init__(self, year, **kwargs):
343
- """_summary_
461
+ def __init__(self, year: int, **kwargs):
462
+ """Initialize the Year iterator.
344
463
 
345
464
  Parameters
346
465
  ----------
347
466
  year : int
348
- _description_
467
+ The year.
349
468
  """
350
469
  super().__init__(datetime.datetime(year, 1, 1), datetime.datetime(year, 12, 31), **kwargs)
351
470
 
@@ -353,13 +472,13 @@ class Year(DateTimes):
353
472
  class Winter(DateTimes):
354
473
  """Winter is defined as the months of December, January and February."""
355
474
 
356
- def __init__(self, year, **kwargs):
357
- """_summary_
475
+ def __init__(self, year: int, **kwargs):
476
+ """Initialize the Winter iterator.
358
477
 
359
478
  Parameters
360
479
  ----------
361
480
  year : int
362
- _description_
481
+ The year.
363
482
  """
364
483
  super().__init__(
365
484
  datetime.datetime(year, 12, 1),
@@ -371,13 +490,13 @@ class Winter(DateTimes):
371
490
  class Spring(DateTimes):
372
491
  """Spring is defined as the months of March, April and May."""
373
492
 
374
- def __init__(self, year, **kwargs):
375
- """_summary_
493
+ def __init__(self, year: int, **kwargs):
494
+ """Initialize the Spring iterator.
376
495
 
377
496
  Parameters
378
497
  ----------
379
498
  year : int
380
- _description_
499
+ The year.
381
500
  """
382
501
  super().__init__(datetime.datetime(year, 3, 1), datetime.datetime(year, 5, 31), **kwargs)
383
502
 
@@ -385,13 +504,13 @@ class Spring(DateTimes):
385
504
  class Summer(DateTimes):
386
505
  """Summer is defined as the months of June, July and August."""
387
506
 
388
- def __init__(self, year, **kwargs):
389
- """_summary_
507
+ def __init__(self, year: int, **kwargs):
508
+ """Initialize the Summer iterator.
390
509
 
391
510
  Parameters
392
511
  ----------
393
512
  year : int
394
- _description_
513
+ The year.
395
514
  """
396
515
  super().__init__(datetime.datetime(year, 6, 1), datetime.datetime(year, 8, 31), **kwargs)
397
516
 
@@ -399,13 +518,13 @@ class Summer(DateTimes):
399
518
  class Autumn(DateTimes):
400
519
  """Autumn is defined as the months of September, October and November."""
401
520
 
402
- def __init__(self, year, **kwargs):
403
- """_summary_
521
+ def __init__(self, year: int, **kwargs):
522
+ """Initialize the Autumn iterator.
404
523
 
405
524
  Parameters
406
525
  ----------
407
526
  year : int
408
- _description_
527
+ The year.
409
528
  """
410
529
  super().__init__(datetime.datetime(year, 9, 1), datetime.datetime(year, 11, 30), **kwargs)
411
530
 
@@ -413,13 +532,27 @@ class Autumn(DateTimes):
413
532
  class ConcatDateTimes:
414
533
  """ConcatDateTimes is an iterator that generates datetime objects from a list of dates."""
415
534
 
416
- def __init__(self, *dates):
535
+ def __init__(self, *dates: DateTimes):
536
+ """Initialize the ConcatDateTimes iterator.
537
+
538
+ Parameters
539
+ ----------
540
+ dates : DateTimes
541
+ The list of DateTimes objects.
542
+ """
417
543
  if len(dates) == 1 and isinstance(dates[0], list):
418
544
  dates = dates[0]
419
545
 
420
546
  self.dates = dates
421
547
 
422
- def __iter__(self):
548
+ def __iter__(self) -> iter:
549
+ """Iterate over the datetime objects.
550
+
551
+ Returns
552
+ -------
553
+ iter
554
+ An iterator of datetime objects.
555
+ """
423
556
  for date in self.dates:
424
557
  yield from date
425
558
 
@@ -427,15 +560,43 @@ class ConcatDateTimes:
427
560
  class EnumDateTimes:
428
561
  """EnumDateTimes is an iterator that generates datetime objects from a list of dates."""
429
562
 
430
- def __init__(self, dates):
563
+ def __init__(self, dates: list[Union[datetime.date, datetime.datetime, str]]):
564
+ """Initialize the EnumDateTimes iterator.
565
+
566
+ Parameters
567
+ ----------
568
+ dates : list of datetime.date or datetime.datetime or str
569
+ The list of dates.
570
+ """
431
571
  self.dates = dates
432
572
 
433
- def __iter__(self):
573
+ def __iter__(self) -> iter:
574
+ """Iterate over the datetime objects.
575
+
576
+ Returns
577
+ -------
578
+ iter
579
+ An iterator of datetime objects.
580
+ """
434
581
  for date in self.dates:
435
582
  yield as_datetime(date)
436
583
 
437
584
 
438
- def datetimes_factory(*args, **kwargs):
585
+ def datetimes_factory(*args: Any, **kwargs: Any) -> Union[DateTimes, ConcatDateTimes, EnumDateTimes]:
586
+ """Create a DateTimes, ConcatDateTimes, or EnumDateTimes object.
587
+
588
+ Parameters
589
+ ----------
590
+ *args : Any
591
+ Positional arguments.
592
+ **kwargs : Any
593
+ Keyword arguments.
594
+
595
+ Returns
596
+ -------
597
+ DateTimes or ConcatDateTimes or EnumDateTimes
598
+ The created object.
599
+ """
439
600
  if args and kwargs:
440
601
  raise ValueError("Cannot provide both args and kwargs for a list of dates")
441
602
 
anemoi/utils/devtools.py CHANGED
@@ -8,29 +8,74 @@
8
8
  # nor does it submit to any jurisdiction.
9
9
 
10
10
 
11
+ from typing import Any
12
+
11
13
  import cartopy.crs as ccrs
12
14
  import cartopy.feature as cfeature
13
15
  import matplotlib.pyplot as plt
14
16
  import matplotlib.tri as tri
15
17
  import numpy as np
16
18
 
17
- """FOR DEVELOPMENT PURPOSES ONLY
19
+ """FOR DEVELOPMENT PURPOSES ONLY.
18
20
 
19
21
  This module contains
20
-
21
22
  """
22
23
 
23
24
  # TODO: use earthkit-plots
24
25
 
25
26
 
26
- def fix(lons):
27
+ def fix(lons: np.ndarray) -> np.ndarray:
28
+ """Fix longitudes greater than 180 degrees.
29
+
30
+ Parameters
31
+ ----------
32
+ lons : np.ndarray
33
+ Array of longitudes.
34
+
35
+ Returns
36
+ -------
37
+ np.ndarray
38
+ Array of fixed longitudes.
39
+ """
27
40
  return np.where(lons > 180, lons - 360, lons)
28
41
 
29
42
 
30
43
  def plot_values(
31
- values, latitudes, longitudes, title=None, missing_value=None, min_value=None, max_value=None, **kwargs
32
- ):
33
-
44
+ values: np.ndarray,
45
+ latitudes: np.ndarray,
46
+ longitudes: np.ndarray,
47
+ title: str = None,
48
+ missing_value: float = None,
49
+ min_value: float = None,
50
+ max_value: float = None,
51
+ **kwargs: dict,
52
+ ) -> plt.Axes:
53
+ """Plot values on a map.
54
+
55
+ Parameters
56
+ ----------
57
+ values : np.ndarray
58
+ Array of values to plot.
59
+ latitudes : np.ndarray
60
+ Array of latitudes.
61
+ longitudes : np.ndarray
62
+ Array of longitudes.
63
+ title : str, optional
64
+ Title of the plot, by default None.
65
+ missing_value : float, optional
66
+ Value to use for missing data, by default None.
67
+ min_value : float, optional
68
+ Minimum value for the plot, by default None.
69
+ max_value : float, optional
70
+ Maximum value for the plot, by default None.
71
+ **kwargs : dict
72
+ Additional keyword arguments for the plot.
73
+
74
+ Returns
75
+ -------
76
+ plt.Axes
77
+ The plot axes.
78
+ """
34
79
  _, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
35
80
  ax.coastlines()
36
81
  ax.add_feature(cfeature.BORDERS, linestyle=":")
@@ -77,7 +122,23 @@ def plot_values(
77
122
  return ax
78
123
 
79
124
 
80
- def plot_field(field, title=None, **kwargs):
125
+ def plot_field(field: Any, title: str = None, **kwargs: dict) -> plt.Axes:
126
+ """Plot a field on a map.
127
+
128
+ Parameters
129
+ ----------
130
+ field : Any
131
+ The field to plot.
132
+ title : str, optional
133
+ Title of the plot, by default None.
134
+ **kwargs : dict
135
+ Additional keyword arguments for the plot.
136
+
137
+ Returns
138
+ -------
139
+ plt.Axes
140
+ The plot axes.
141
+ """
81
142
  values = field.to_numpy(flatten=True)
82
143
  latitudes, longitudes = field.grid_points()
83
144
  return plot_values(values, latitudes, longitudes, title=title, **kwargs)
anemoi/utils/grib.py CHANGED
@@ -11,11 +11,12 @@
11
11
  """Utilities for working with GRIB parameters.
12
12
 
13
13
  See https://codes.ecmwf.int/grib/param-db/ for more information.
14
-
15
14
  """
16
15
 
17
16
  import logging
18
17
  import re
18
+ from typing import Dict
19
+ from typing import Union
19
20
 
20
21
  import requests
21
22
 
@@ -25,7 +26,14 @@ LOG = logging.getLogger(__name__)
25
26
 
26
27
 
27
28
  @cached(collection="grib", expires=30 * 24 * 60 * 60)
28
- def _units():
29
+ def _units() -> Dict[str, str]:
30
+ """Fetch and cache GRIB parameter units.
31
+
32
+ Returns
33
+ -------
34
+ dict
35
+ A dictionary mapping unit ids to their names.
36
+ """
29
37
  r = requests.get("https://codes.ecmwf.int/parameter-database/api/v1/unit/")
30
38
  r.raise_for_status()
31
39
  units = r.json()
@@ -33,7 +41,24 @@ def _units():
33
41
 
34
42
 
35
43
  @cached(collection="grib", expires=30 * 24 * 60 * 60)
36
- def _search_param(name):
44
+ def _search_param(name: str) -> Dict[str, Union[str, int]]:
45
+ """Search for a GRIB parameter by name.
46
+
47
+ Parameters
48
+ ----------
49
+ name : str
50
+ Parameter name to search for.
51
+
52
+ Returns
53
+ -------
54
+ dict
55
+ A dictionary containing parameter details.
56
+
57
+ Raises
58
+ ------
59
+ KeyError
60
+ If no parameter is found.
61
+ """
37
62
  name = re.escape(name)
38
63
  r = requests.get(f"https://codes.ecmwf.int/parameter-database/api/v1/param/?search=^{name}$&regex=true")
39
64
  r.raise_for_status()
@@ -68,7 +93,6 @@ def shortname_to_paramid(shortname: str) -> int:
68
93
 
69
94
  >>> shortname_to_paramid("2t")
70
95
  167
71
-
72
96
  """
73
97
  return _search_param(shortname)["id"]
74
98
 
@@ -88,12 +112,11 @@ def paramid_to_shortname(paramid: int) -> str:
88
112
 
89
113
  >>> paramid_to_shortname(167)
90
114
  '2t'
91
-
92
115
  """
93
116
  return _search_param(str(paramid))["shortname"]
94
117
 
95
118
 
96
- def units(param) -> str:
119
+ def units(param: Union[int, str]) -> str:
97
120
  """Return the units of a GRIB parameter given its name or id.
98
121
 
99
122
  Parameters
@@ -108,14 +131,13 @@ def units(param) -> str:
108
131
 
109
132
  >>> unit(167)
110
133
  'K'
111
-
112
134
  """
113
135
 
114
136
  unit_id = str(_search_param(str(param))["unit_id"])
115
137
  return _units()[unit_id]
116
138
 
117
139
 
118
- def must_be_positive(param) -> bool:
140
+ def must_be_positive(param: Union[int, str]) -> bool:
119
141
  """Check if a parameter must be positive.
120
142
 
121
143
  Parameters
@@ -130,6 +152,5 @@ def must_be_positive(param) -> bool:
130
152
 
131
153
  >>> must_be_positive("tp")
132
154
  True
133
-
134
155
  """
135
156
  return units(param) in ["m", "kg kg**-1", "m of water equivalent"]