pycontrails-0.47.3-cp312-cp312-macosx_11_0_arm64.whl → pycontrails-0.48.1-cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -19,16 +19,9 @@ from overrides import overrides
 import pycontrails
 from pycontrails.core import cache, datalib
 from pycontrails.core.met import MetDataset, MetVariable
-from pycontrails.datalib.ecmwf.common import ECMWFAPI, rad_accumulated_to_average
-from pycontrails.datalib.ecmwf.variables import (
-    PRESSURE_LEVEL_VARIABLES,
-    SURFACE_VARIABLES,
-    TOAIncidentSolarRadiation,
-    TopNetSolarRadiation,
-    TopNetThermalRadiation,
-)
-from pycontrails.utils import dependencies, iteration
-from pycontrails.utils.temp import temp_file
+from pycontrails.datalib.ecmwf.common import ECMWFAPI
+from pycontrails.datalib.ecmwf.variables import PRESSURE_LEVEL_VARIABLES, SURFACE_VARIABLES
+from pycontrails.utils import dependencies, iteration, temp
 from pycontrails.utils.types import DatetimeLike
 
 if TYPE_CHECKING:
@@ -241,15 +234,6 @@ class HRES(ECMWFAPI):
     #: Forecast run time, either specified or assigned by the closest previous forecast run
     forecast_time: datetime
 
-    #: User provided ``ECMWFService`` url
-    url: str | None
-
-    #: User provided ``ECMWFService`` key
-    key: str | None
-
-    #: User provided ``ECMWFService`` email
-    email: str | None
-
     __marker = object()
 
     def __init__(
@@ -278,14 +262,11 @@ class HRES(ECMWFAPI):
                 pycontrails_optional_package="ecmwf",
             )
 
-        # constants
         # ERA5 now delays creating the server attribute until it is needed to download
         # from CDS. We could do the same here.
-        self.url = url
-        self.key = key
-        self.email = email
-        self.server = ECMWFService("mars", url=self.url, key=self.key, email=self.email)
+        self.server = ECMWFService("mars", url=url, key=key, email=email)
 
         self.paths = paths
+
         if cachestore is self.__marker:
             cachestore = cache.DiskCacheStore()
         self.cachestore = cachestore
@@ -300,10 +281,24 @@ class HRES(ECMWFAPI):
         self.variables = datalib.parse_variables(variables, self.supported_variables)
 
         self.grid = datalib.parse_grid(grid, [0.1, 0.25, 0.5, 1])  # lat/lon degree resolution
-        self.stream = stream  # "enfo" = ensemble forecast, "oper" = atmospheric model/HRES
-        self.field_type = (
-            field_type  # forecast (oper), perturbed or control forecast (enfo only), or analysis
-        )
+
+        # "enfo" = ensemble forecast
+        # "oper" = atmospheric model/HRES
+        if stream not in ("oper", "enfo"):
+            msg = "Parameter stream must be 'oper' or 'enfo'"
+            raise ValueError(msg)
+
+        self.stream = stream
+
+        # "fc" = forecast
+        # "pf" = perturbed forecast
+        # "cf" = control forecast
+        # "an" = analysis
+        if field_type not in ("fc", "pf", "cf", "an"):
+            msg = "Parameter field_type must be 'fc', 'pf', 'cf', or 'an'"
+            raise ValueError(msg)
+
+        self.field_type = field_type
 
         # set specific forecast time is requested
         if forecast_time is not None:
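
The rewritten constructor validates stream and field_type eagerly, so an invalid value now fails at construction time rather than surfacing later as a rejected MARS request. A minimal standalone sketch of the pattern (the _validate_choice helper is illustrative, not part of pycontrails):

    def _validate_choice(name: str, value: str, allowed: tuple[str, ...]) -> str:
        """Raise ValueError when value is not one of the allowed choices."""
        if value not in allowed:
            msg = f"Parameter {name} must be one of {allowed}"
            raise ValueError(msg)
        return value

    stream = _validate_choice("stream", "oper", ("oper", "enfo"))  # passes
    _validate_choice("field_type", "xx", ("fc", "pf", "cf", "an"))  # raises ValueError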
@@ -476,14 +471,12 @@ class HRES(ECMWFAPI):
         request = self.generate_mars_request(self.forecast_time, self.steps, request_type="list")
 
         # hold downloaded file in named temp file
-        with temp_file() as mars_temp_filename:
-            LOG.debug(f"Performing MARS request: {request}")
-            self.server.execute(request, mars_temp_filename)
+        with temp.temp_file() as mars_temp_filename:
+            LOG.debug("Performing MARS request: %s", request)
+            self.server.execute(request, target=mars_temp_filename)
 
             with open(mars_temp_filename, "r") as f:
-                txt = f.read()
-
-            return txt
+                return f.read()
 
     def generate_mars_request(
         self,
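
Two small improvements in the hunk above: temp_file is now reached through the temp module, and the debug call switches from an f-string to %-style arguments, which defers formatting until a handler actually emits the record. A sketch of the difference (the request value is a stand-in):

    import logging

    LOG = logging.getLogger(__name__)
    request = {"class": "od", "stream": "oper"}  # stand-in value

    LOG.debug(f"Performing MARS request: {request}")   # formats immediately, even if DEBUG is off
    LOG.debug("Performing MARS request: %s", request)  # formats only when the record is emitted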
@@ -528,16 +521,16 @@ class HRES(ECMWFAPI):
             steps = self.steps
 
         # set date/time for file
-        _date = forecast_time.strftime("%Y%m%d")
-        _time = forecast_time.strftime("%H")
+        date = forecast_time.strftime("%Y%m%d")
+        time = forecast_time.strftime("%H")
 
         # make request of mars
         request: dict[str, Any] = {
             "class": "od",  # operational data
             "stream": self.stream,
             "expver": "1",  # production data only
-            "date": _date,
-            "time": _time,
+            "date": date,
+            "time": time,
             "type": self.field_type,
             "param": f"{'/'.join(self.variable_shortnames)}",
             "step": f"{'/'.join([str(s) for s in steps])}",
@@ -574,10 +567,8 @@
         step = self.step_offset + self.timesteps.index(t)
 
         # single level or pressure level
-        if self.pressure_levels == [-1]:
-            suffix = f"hressl{self.grid}{self.stream}{self.field_type}"
-        else:
-            suffix = f"hrespl{self.grid}{self.stream}{self.field_type}"
+        levtype = "sl" if self.pressure_levels == [-1] else "pl"
+        suffix = f"hres{levtype}{self.grid}{self.stream}{self.field_type}"
 
         # return cache path
         return self.cachestore.path(f"{datestr}-{step}-{suffix}.nc")
@@ -635,47 +626,31 @@ class HRES(ECMWFAPI):
         # open cache files as xr.Dataset
         ds = self.open_dataset(disk_cachepaths, **xr_kwargs)
 
-        # TODO: corner case
-        # If any files are already cached, they will not have the version attached
-        if "pycontrails_version" not in ds.attrs:
-            ds.attrs["pycontrails_version"] = pycontrails.__version__
+        ds.attrs.setdefault("pycontrails_version", pycontrails.__version__)
 
         # run the same ECMWF-specific processing on the dataset
         mds = self._process_dataset(ds, **kwargs)
 
-        # convert accumulated radiation values to average instantaneous values
-        # set minimum for all values to 0
-        # !! Note that HRES accumulates from the *start of the forecast*,
-        # so we need to take the diff of each accumulated value
-        # the 0th value is set to the 1st value so each time step has a radiation value !!
-        dt_accumulation = 60 * 60
-
-        for key in [
-            TOAIncidentSolarRadiation.standard_name,
-            TopNetSolarRadiation.standard_name,
-            TopNetThermalRadiation.standard_name,
-        ]:
-            if key in mds.data:
-                if len(mds.data["time"]) < 2:
-                    raise RuntimeError(
-                        f"HRES datasets with data variable {key} must have at least two timesteps"
-                        f" to calculate the average instantaneous value of {key}"
-                    )
-
-                # take the difference between time slices
-                dkey_dt = mds.data[key].diff("time")
-
-                # set difference value back to the data model
-                mds.data[key] = dkey_dt
-
-                # set the 0th value of the data to the 1st difference value
-                # TODO: this assumption may not be universally applicable!
-                mds.data[key][dict(time=0)] = dkey_dt[dict(time=0)]
-
-                rad_accumulated_to_average(mds, key, dt_accumulation)
-
+        self.set_metadata(mds)
         return mds
 
+    @overrides
+    def set_metadata(self, ds: xr.Dataset | MetDataset) -> None:
+        if self.stream == "oper":
+            product = "forecast"
+        elif self.stream == "enfo":
+            product = "ensemble"
+        else:
+            msg = f"Unknown stream type {self.stream}"
+            raise ValueError(msg)
+
+        ds.attrs.update(
+            provider="ECMWF",
+            dataset="HRES",
+            product=product,
+            radiation_accumulated=True,
+        )
+
     def _open_and_cache(self, xr_kwargs: dict[str, Any]) -> xr.Dataset:
         """Open and cache :class:`xr.Dataset` from :attr:`self.paths`.
 
@@ -746,7 +721,7 @@ class HRES(ECMWFAPI):
         # Open ExitStack to control temp_file context manager
         with ExitStack() as stack:
             # hold downloaded file in named temp file
-            mars_temp_grib_filename = stack.enter_context(temp_file())
+            mars_temp_grib_filename = stack.enter_context(temp.temp_file())
 
             # retrieve data from MARS
             LOG.debug(f"Performing MARS request: {request}")
@@ -754,18 +729,8 @@
 
             # translate into netcdf from grib
             LOG.debug("Translating file into netcdf")
-            mars_temp_nc_filename = stack.enter_context(temp_file())
             ds = stack.enter_context(xr.open_dataset(mars_temp_grib_filename, engine="cfgrib"))
 
-            ##### TODO: do we need to store intermediate netcdf file?
-            ds.to_netcdf(path=mars_temp_nc_filename, mode="w")
-
-            # open file, edit, and save for each hourly time step
-            ds = stack.enter_context(
-                xr.open_dataset(mars_temp_nc_filename, engine=datalib.NETCDF_ENGINE)
-            )
-            #####
-
             # run preprocessing before cache
             ds = self._preprocess_hres_dataset(ds)
 
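
The intermediate netcdf round-trip is gone: the grib file is opened with the cfgrib engine and preprocessed directly, instead of being written to netcdf and reopened. The essential pattern (the path is a placeholder, and the optional cfgrib dependency must be installed):

    import xarray as xr

    # open the grib file directly; no ds.to_netcdf(...) / reopen cycle needed
    ds = xr.open_dataset("downloaded.grib", engine="cfgrib")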
@@ -27,7 +27,7 @@ class IFS(datalib.MetDataSource):
 
     .. warning::
 
-        This data source is fully implemented.
+        This data source is not fully implemented.
 
     Parameters
     ----------
@@ -186,9 +186,17 @@ class IFS(datalib.MetDataSource):
         # harmonize variable names
         ds = met.standardize_variables(ds, self.variables)
 
-        ds.attrs["met_source"] = type(self).__name__
+        self.set_metadata(ds)
         return met.MetDataset(ds, **kwargs)
 
+    @overrides
+    def set_metadata(self, ds: xr.Dataset | met.MetDataset) -> None:
+        ds.attrs.update(
+            provider="ECMWF",
+            dataset="IFS",
+            product="forecast",
+        )
+
     @overrides
     def download_dataset(self, times: list[datetime]) -> None:
         raise NotImplementedError("IFS download is not supported")
@@ -12,7 +12,7 @@ from __future__ import annotations
 import hashlib
 import logging
 import pathlib
-from contextlib import ExitStack
+import warnings
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, Callable
 
@@ -30,15 +30,14 @@ from pycontrails.datalib.gfs.variables import (
     TOAUpwardShortwaveRadiation,
     Visibility,
 )
-from pycontrails.utils import dependencies
-from pycontrails.utils.temp import temp_file
+from pycontrails.utils import dependencies, temp
 from pycontrails.utils.types import DatetimeLike
 
 # optional imports
 if TYPE_CHECKING:
     import botocore
 
-LOG = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)
 
 #: Default GFS AWS bucket
 GFS_FORECAST_BUCKET = "noaa-gfs-bdp-pds"
@@ -72,8 +71,8 @@ class GFSForecast(datalib.MetDataSource):
         Specify latitude/longitude grid spacing in data.
         Defaults to 0.25.
     forecast_time : `DatetimeLike`, optional
-        Specify forecast run by runtime.
-        Defaults to None.
+        Specify forecast run by runtime. If None (default), the forecast time
+        is set to the 6 hour floor of the first timestep.
     cachestore : :class:`cache.CacheStore` | None, optional
         Cache data store for staging data files.
         Defaults to :class:`cache.DiskCacheStore`.
@@ -88,7 +87,7 @@ class GFSForecast(datalib.MetDataSource):
     >>> from pycontrails.datalib.gfs import GFSForecast
 
     >>> # Store data files to local disk (default behavior)
-    >>> times = ("2022-03-22 00:00:00", "2022-03-22 03:00:00", )
+    >>> times = ("2022-03-22 00:00:00", "2022-03-22 03:00:00")
    >>> gfs = GFSForecast(times, variables="air_temperature", pressure_levels=[300, 250])
     >>> gfs
     GFSForecast
@@ -97,7 +96,15 @@
         Pressure levels: [300, 250]
         Grid: 0.25
         Forecast time: 2022-03-22 00:00:00
-        Steps: [0, 1, 2, 3]
+
+    >>> gfs = GFSForecast(times, variables="air_temperature", pressure_levels=[300, 250], grid=0.5)
+    >>> gfs
+    GFSForecast
+        Timesteps: ['2022-03-22 00', '2022-03-22 03']
+        Variables: ['t']
+        Pressure levels: [300, 250]
+        Grid: 0.5
+        Forecast time: 2022-03-22 00:00:00
 
     Notes
     -----
@@ -164,12 +171,17 @@
         if time is None and paths is None:
             raise ValueError("Time input is required when paths is None")
 
-        self.timesteps = datalib.parse_timesteps(time, freq="1H")
+        # Forecast is available hourly for 0.25 degree grid,
+        # 3 hourly for 0.5 and 1 degree grid
+        # https://www.nco.ncep.noaa.gov/pmb/products/gfs/
+        freq = "1H" if grid == 0.25 else "3H"
+        self.timesteps = datalib.parse_timesteps(time, freq=freq)
+
         self.pressure_levels = datalib.parse_pressure_levels(
             pressure_levels, self.supported_pressure_levels
         )
         self.variables = datalib.parse_variables(variables, self.supported_variables)
-        self.grid = datalib.parse_grid(grid, [0.25, 0.5, 1])
+        self.grid = datalib.parse_grid(grid, (0.25, 0.5, 1))
 
         # note GFS allows unsigned requests (no credentials)
         # https://stackoverflow.com/questions/34865927/can-i-use-boto3-anonymously/34866092#34866092
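
The timestep frequency now follows the grid: GFS publishes hourly files for the 0.25 degree product but only 3-hourly files for the 0.5 and 1.0 degree products. A sketch of the effect, using pandas directly in place of datalib.parse_timesteps:

    import pandas as pd

    grid = 0.5
    freq = "1H" if grid == 0.25 else "3H"

    timesteps = pd.date_range("2022-03-22 00:00", "2022-03-22 06:00", freq=freq)
    print(list(timesteps))  # 00:00, 03:00, 06:00 for the 0.5 degree grid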
@@ -180,24 +192,19 @@
         # set specific forecast time is requested
         if forecast_time is not None:
             forecast_time_pd = pd.to_datetime(forecast_time)
-            if forecast_time_pd.hour not in [0, 6, 12, 18]:
+            if forecast_time_pd.hour % 6:
                 raise ValueError("Forecast hour must be on one of 00, 06, 12, 18")
 
             self.forecast_time = datalib.round_hour(forecast_time_pd.to_pydatetime(), 6)
 
         # if no specific forecast is requested, set the forecast time using timesteps
-        elif self.timesteps:
+        else:
             # round first element to the nearest 6 hour time (00, 06, 12, 18 UTC) for forecast_time
             self.forecast_time = datalib.round_hour(self.timesteps[0], 6)
 
-        # when no forecast_time or time input, forecast_time is defined in _open_and_cache
-
     def __repr__(self) -> str:
         base = super().__repr__()
-        return (
-            f"{base}\n\tForecast time: {getattr(self, 'forecast_time', '')}\n\tSteps:"
-            f" {getattr(self, 'steps', '')}"
-        )
+        return f"{base}\n\tForecast time: {self.forecast_time}"
 
     @property
     def supported_pressure_levels(self) -> list[int]:
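
The hour % 6 test is equivalent to the old membership check for any hour in 0-23: the remainder is zero exactly at 00, 06, 12, and 18. A quick verification:

    # a nonzero remainder means the hour is not a synoptic run hour
    for hour in range(24):
        assert bool(hour % 6) == (hour not in [0, 6, 12, 18])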
@@ -282,32 +289,6 @@
         )
         return hashlib.sha1(bytes(hashstr, "utf-8")).hexdigest()
 
-    @property
-    def step_offset(self) -> int:
-        """Difference between :attr:`forecast_time` and first timestep.
-
-        Returns
-        -------
-        int
-            Number of steps to offset in order to retrieve data starting from input time.
-            Returns 0 if :attr:`timesteps` is empty when loading from :attr:`paths`.
-        """
-        if self.timesteps:
-            return int((self.timesteps[0] - self.forecast_time).total_seconds() // 3600)
-
-        return 0
-
-    @property
-    def steps(self) -> list[int]:
-        """Forecast steps from :attr:`forecast_time` corresponding within input :attr:`time`.
-
-        Returns
-        -------
-        list[int]
-            List of forecast steps relative to :attr:`forecast_time`
-        """
-        return [self.step_offset + i for i in range(len(self.timesteps))]
-
     @property
     def _grid_string(self) -> str:
         """Return filename string for grid spacing."""
@@ -315,7 +296,7 @@
             return "0p25"
         if self.grid == 0.5:
             return "0p50"
-        if self.grid == 1:
+        if self.grid == 1.0:
             return "1p00"
         raise ValueError(f"Unsupported grid spacing {self.grid}. Must be one of 0.25, 0.5, or 1.0.")
 
@@ -336,7 +317,7 @@
         forecast_hour = str(self.forecast_time.hour).zfill(2)
         return f"gfs.{datestr}/{forecast_hour}/atmos"
 
-    def filename(self, step: int) -> str:
+    def filename(self, t: datetime) -> str:
         """Construct grib filename to retrieve from GFS bucket.
 
         String template:
@@ -349,8 +330,8 @@
 
         Parameters
         ----------
-        step : int
-            Integer step relative to forecast time
+        t : datetime
+            Timestep to download
 
         Returns
         -------
@@ -361,8 +342,10 @@
         ----------
         - https://www.nco.ncep.noaa.gov/pmb/products/gfs/
         """
+        step = pd.Timedelta(t - self.forecast_time) // pd.Timedelta(1, "h")
+        step_hour = str(step).zfill(3)
         forecast_hour = str(self.forecast_time.hour).zfill(2)
-        return f"gfs.t{forecast_hour}z.pgrb2.{self._grid_string}.f{str(step).zfill(3)}"
+        return f"gfs.t{forecast_hour}z.pgrb2.{self._grid_string}.f{step_hour}"
 
     @overrides
     def create_cachepath(self, t: datetime) -> str:
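
With the step_offset and steps properties removed, the forecast step is now computed directly from the timestep. Floor division of two pd.Timedelta values yields the whole number of hours between the timestep and the forecast run, e.g.:

    import pandas as pd
    from datetime import datetime

    forecast_time = datetime(2022, 3, 22, 0)  # stand-in run time
    t = datetime(2022, 3, 22, 9)              # stand-in timestep

    step = pd.Timedelta(t - forecast_time) // pd.Timedelta(1, "h")
    print(step)  # 9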
@@ -373,7 +356,7 @@
         datestr = self.forecast_time.strftime("%Y%m%d-%H")
 
         # get step relative to forecast forecast_time
-        step = self.step_offset + self.timesteps.index(t)
+        step = pd.Timedelta(t - self.forecast_time) // pd.Timedelta(1, "h")
 
         # single level or pressure level
         suffix = f"gfs{'sl' if self.pressure_levels == [-1] else 'pl'}{self.grid}"
@@ -384,7 +367,7 @@
     @overrides
     def download_dataset(self, times: list[datetime]) -> None:
         # get step relative to forecast forecast_time
-        LOG.debug(
+        logger.debug(
             f"Downloading GFS forecast for forecast time {self.forecast_time} and timesteps {times}"
         )
 
@@ -442,28 +425,15 @@
         ds.attrs.setdefault("pycontrails_version", pycontrails.__version__)
 
         # run the same GFS-specific processing on the dataset
-        mds = self._process_dataset(ds, **kwargs)
-
-        # set TOAUpwardShortwaveRadiation, TOAUpwardLongwaveRadiation step 0 == step 1
-        for key in [
-            TOAUpwardShortwaveRadiation.standard_name,
-            TOAUpwardLongwaveRadiation.standard_name,
-        ]:
-            # if step 0 (forecast time) exists in dimension
-            forecast_time = mds.data["forecast_time"].values
-            if key in mds.data and forecast_time in mds.data["time"]:
-                # make sure this isn't the only time in the dataset
-                if np.all(mds.data["time"].values == forecast_time):
-                    raise RuntimeError(
-                        f"GFS datasets with data variable {key} must have at least one timestep"
-                        f" after the forecast time to estimate the value of {key} at step 0"
-                    )
-
-                # set the 0th value of the data to the 1st value
-                # TODO: this assumption may not be universally applicable!
-                mds.data[key][dict(time=0)] = mds.data[key][dict(time=1)]
-
-        return mds
+        return self._process_dataset(ds, **kwargs)
+
+    @overrides
+    def set_metadata(self, ds: xr.Dataset | met.MetDataset) -> None:
+        ds.attrs.update(
+            provider="NCEP",
+            dataset="GFS",
+            product="forecast",
+        )
 
     def _download_file(self, t: datetime) -> None:
         """Download data file for forecast time and step.
@@ -487,17 +457,13 @@
             raise ValueError("Cachestore is required to download data")
 
         # construct filenames for each file
-        step = self.step_offset + self.timesteps.index(t)
-        filename = self.filename(step)
+        filename = self.filename(t)
         aws_key = f"{self.forecast_path}/{filename}"
 
-        # Open ExitStack to control temp_file context manager
-        with ExitStack() as stack:
-            # hold downloaded file in named temp file
-            temp_grib_filename = stack.enter_context(temp_file())
-
+        # Hold downloaded file in named temp file
+        with temp.temp_file() as temp_grib_filename:
             # retrieve data from AWS S3
-            LOG.debug(f"Downloading GFS file {filename} from AWS bucket to {temp_grib_filename}")
+            logger.debug(f"Downloading GFS file {filename} from AWS bucket to {temp_grib_filename}")
             if self.show_progress:
                 _download_with_progress(
                     self.client, GFS_FORECAST_BUCKET, aws_key, temp_grib_filename, filename
@@ -509,12 +475,8 @@
 
             ds = self._open_gfs_dataset(temp_grib_filename, t)
 
-            # write out data to temp, close grib file
-            temp_nc_filename = stack.enter_context(temp_file())
-            ds.to_netcdf(path=temp_nc_filename, mode="w")
-
-            # put each hourly file into cache
-            self.cachestore.put(temp_nc_filename, self.create_cachepath(t))
+            cache_path = self.create_cachepath(t)
+            ds.to_netcdf(cache_path)
 
     def _open_gfs_dataset(self, filepath: str | pathlib.Path, t: datetime) -> xr.Dataset:
         """Open GFS grib file for one forecast timestep.
@@ -532,22 +494,24 @@
             GFS dataset
         """
         # translate into netcdf from grib
-        LOG.debug(f"Translating {filepath} for timestep {str(t)} into netcdf")
+        logger.debug(f"Translating {filepath} for timestep {str(t)} into netcdf")
 
         # get step for timestep
-        step = self.step_offset + self.timesteps.index(t)
+        step = pd.Timedelta(t - self.forecast_time) // pd.Timedelta(1, "h")
 
         # open file for each variable short name individually
-        ds = xr.Dataset()
+        ds: xr.Dataset | None = None
         for variable in self.variables:
-            # radiation data is not available in the 0th step
-            if step == 0 and variable in [
+            # Radiation data is not available in the 0th step
+            is_radiation_step_zero = step == 0 and variable in (
                 TOAUpwardShortwaveRadiation,
                 TOAUpwardLongwaveRadiation,
-            ]:
-                LOG.debug(
-                    "Radiation data is not provided for the 0th step in GFS. Setting to np.nan"
-                    " using Visibility variable"
+            )
+
+            if is_radiation_step_zero:
+                warnings.warn(
+                    "Radiation data is not provided for the 0th step in GFS. "
+                    "Setting to np.nan using Visibility variable"
                 )
                 v = Visibility
             else:
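
Replacing LOG.debug with warnings.warn makes the step-0 radiation fallback visible to callers by default (a UserWarning, shown once per call site) rather than hidden behind DEBUG-level logging:

    import warnings

    warnings.warn(
        "Radiation data is not provided for the 0th step in GFS. "
        "Setting to np.nan using Visibility variable"
    )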
@@ -559,23 +523,22 @@
                 engine="cfgrib",
             )
 
-            if not len(ds):
+            if ds is None:
                 ds = tmpds
             else:
                 ds[v.short_name] = tmpds[v.short_name]
 
             # set all radiation data to np.nan in the 0th step
-            if step == 0 and variable in [
-                TOAUpwardShortwaveRadiation,
-                TOAUpwardLongwaveRadiation,
-            ]:
+            if is_radiation_step_zero:
                 ds = ds.rename({Visibility.short_name: variable.short_name})
                 ds[variable.short_name] = np.nan
 
+        assert ds is not None, "No variables were loaded from grib file"
+
         # for pressure levels, need to rename "level" field and downselect
         if self.pressure_levels != [-1]:
             ds = ds.rename({"isobaricInhPa": "level"})
-            ds = ds.sel(dict(level=self.pressure_levels))
+            ds = ds.sel(level=self.pressure_levels)
 
         # for single level, and singular pressure levels, add the level dimension
         if len(self.pressure_levels) == 1:
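
The ds is None sentinel replaces if not len(ds), which counted data variables on an initially empty xr.Dataset(). A minimal sketch of the accumulation pattern with stand-in data:

    import xarray as xr

    ds: xr.Dataset | None = None
    for name, values in [("t", [1.0, 2.0]), ("q", [3.0, 4.0])]:
        tmpds = xr.Dataset({name: ("x", values)})
        if ds is None:
            ds = tmpds  # adopt the first dataset opened
        else:
            ds[name] = tmpds[name]  # copy later variables in

    assert ds is not None, "No variables were loaded"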
@@ -623,15 +586,14 @@
             ds = ds.expand_dims({"level": self.pressure_levels})
 
         else:
-            ds = ds.sel(dict(level=self.pressure_levels))
+            ds = ds.sel(level=self.pressure_levels)
 
         # harmonize variable names
         ds = met.standardize_variables(ds, self.variables)
 
-        if "cachestore" not in kwargs:
-            kwargs["cachestore"] = self.cachestore
+        kwargs.setdefault("cachestore", self.cachestore)
 
-        ds.attrs["met_source"] = type(self).__name__
+        self.set_metadata(ds)
         return met.MetDataset(ds, **kwargs)
 