ocf-data-sampler 0.2.38.tar.gz → 0.3.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. See the registry's advisory page for this release for more details.

Files changed (70)
  1. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/PKG-INFO +2 -1
  2. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/config/model.py +33 -4
  3. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/load/load_dataset.py +1 -1
  4. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/load/nwp/providers/cloudcasting.py +1 -1
  5. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/load/nwp/providers/ecmwf.py +1 -1
  6. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/load/nwp/providers/gfs.py +6 -1
  7. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/load/nwp/providers/icon.py +1 -1
  8. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/load/nwp/providers/ukv.py +1 -1
  9. ocf_data_sampler-0.3.1/ocf_data_sampler/load/nwp/providers/utils.py +83 -0
  10. ocf_data_sampler-0.3.1/ocf_data_sampler/load/open_tensorstore_zarrs.py +92 -0
  11. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/load/satellite.py +6 -40
  12. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/load/utils.py +1 -1
  13. ocf_data_sampler-0.3.1/ocf_data_sampler/select/dropout.py +61 -0
  14. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +2 -0
  15. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler.egg-info/PKG-INFO +2 -1
  16. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler.egg-info/SOURCES.txt +1 -0
  17. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler.egg-info/requires.txt +1 -0
  18. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/pyproject.toml +1 -0
  19. ocf_data_sampler-0.2.38/ocf_data_sampler/load/nwp/providers/utils.py +0 -43
  20. ocf_data_sampler-0.2.38/ocf_data_sampler/select/dropout.py +0 -47
  21. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/LICENSE +0 -0
  22. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/README.md +0 -0
  23. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/__init__.py +0 -0
  24. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/config/__init__.py +0 -0
  25. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/config/load.py +0 -0
  26. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/config/save.py +0 -0
  27. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/data/uk_gsp_locations_20220314.csv +0 -0
  28. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/data/uk_gsp_locations_20250109.csv +0 -0
  29. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/load/__init__.py +0 -0
  30. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/load/gsp.py +0 -0
  31. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/load/nwp/__init__.py +0 -0
  32. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/load/nwp/nwp.py +0 -0
  33. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
  34. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/load/site.py +0 -0
  35. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/numpy_sample/__init__.py +0 -0
  36. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/numpy_sample/collate.py +0 -0
  37. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/numpy_sample/common_types.py +0 -0
  38. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/numpy_sample/datetime_features.py +0 -0
  39. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/numpy_sample/gsp.py +0 -0
  40. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/numpy_sample/nwp.py +0 -0
  41. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/numpy_sample/satellite.py +0 -0
  42. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/numpy_sample/site.py +0 -0
  43. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/numpy_sample/sun_position.py +0 -0
  44. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/select/__init__.py +0 -0
  45. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/select/fill_time_periods.py +0 -0
  46. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -0
  47. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/select/geospatial.py +0 -0
  48. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/select/location.py +0 -0
  49. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
  50. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/select/select_time_slice.py +0 -0
  51. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -0
  52. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/torch_datasets/datasets/site.py +0 -0
  53. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/torch_datasets/sample/__init__.py +0 -0
  54. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/torch_datasets/sample/base.py +0 -0
  55. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/torch_datasets/sample/site.py +0 -0
  56. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/torch_datasets/sample/uk_regional.py +0 -0
  57. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/torch_datasets/utils/__init__.py +0 -0
  58. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +0 -0
  59. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -0
  60. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py +0 -0
  61. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py +0 -0
  62. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
  63. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/torch_datasets/utils/validation_utils.py +0 -0
  64. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler/utils.py +0 -0
  65. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
  66. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/ocf_data_sampler.egg-info/top_level.txt +0 -0
  67. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/scripts/download_gsp_location_data.py +0 -0
  68. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/scripts/refactor_site.py +0 -0
  69. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/setup.cfg +0 -0
  70. {ocf_data_sampler-0.2.38 → ocf_data_sampler-0.3.1}/utils/compute_icon_mean_stddev.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ocf-data-sampler
3
- Version: 0.2.38
3
+ Version: 0.3.1
4
4
  Author: James Fulton, Peter Dudfield
5
5
  Author-email: Open Climate Fix team <info@openclimatefix.org>
6
6
  License: MIT License
@@ -44,6 +44,7 @@ Requires-Dist: pyproj
44
44
  Requires-Dist: pyaml_env
45
45
  Requires-Dist: pyresample
46
46
  Requires-Dist: h5netcdf
47
+ Requires-Dist: xarray-tensorstore==0.1.5
47
48
 
48
49
  # ocf-data-sampler
49
50
 
@@ -90,11 +90,10 @@ class DropoutMixin(Base):
90
90
  "negative or zero.",
91
91
  )
92
92
 
93
- dropout_fraction: float = Field(
93
+ dropout_fraction: float|list[float] = Field(
94
94
  default=0,
95
- description="Chance of dropout being applied to each sample",
96
- ge=0,
97
- le=1,
95
+ description="Either a float(Chance of dropout being applied to each sample) or a list of "
96
+ "floats (probability that dropout of the corresponding timedelta is applied)",
98
97
  )
99
98
 
100
99
  @field_validator("dropout_timedeltas_minutes")
@@ -105,6 +104,36 @@ class DropoutMixin(Base):
105
104
  raise ValueError("Dropout timedeltas must be negative")
106
105
  return v
107
106
 
107
+
108
+ @field_validator("dropout_fraction")
109
+ def dropout_fractions(cls, dropout_frac: float|list[float]) -> float|list[float]:
110
+ """Validate 'dropout_frac'."""
111
+ from math import isclose
112
+ if isinstance(dropout_frac, float):
113
+ if not (dropout_frac <= 1):
114
+ raise ValueError("Input should be less than or equal to 1")
115
+ elif not (dropout_frac >= 0):
116
+ raise ValueError("Input should be greater than or equal to 0")
117
+
118
+ elif isinstance(dropout_frac, list):
119
+ if not dropout_frac:
120
+ raise ValueError("List cannot be empty")
121
+
122
+ if not all(isinstance(i, float) for i in dropout_frac):
123
+ raise ValueError("All elements in the list must be floats")
124
+
125
+ if not all(0 <= i <= 1 for i in dropout_frac):
126
+ raise ValueError("Each float in the list must be between 0 and 1")
127
+
128
+ if not isclose(sum(dropout_frac), 1.0, rel_tol=1e-9):
129
+ raise ValueError("Sum of all floats in the list must be 1.0")
130
+
131
+
132
+ else:
133
+ raise TypeError("Must be either a float or a list of floats")
134
+ return dropout_frac
135
+
136
+
108
137
  @model_validator(mode="after")
109
138
  def dropout_instructions_consistent(self) -> "DropoutMixin":
110
139
  """Validator for dropout instructions."""
@@ -25,7 +25,7 @@ def get_dataset_dict(
25
25
  zarr_path=input_config.gsp.zarr_path,
26
26
  boundaries_version=input_config.gsp.boundaries_version,
27
27
  public=input_config.gsp.public,
28
- ).compute()
28
+ )
29
29
 
30
30
  if gsp_ids is None:
31
31
  # Remove national (gsp_id=0)
@@ -28,7 +28,7 @@ def open_cloudcasting(zarr_path: str | list[str]) -> xr.DataArray:
28
28
  [3] https://github.com/openclimatefix/sat_pred
29
29
  """
30
30
  # Open the data
31
- ds = open_zarr_paths(zarr_path)
31
+ ds = open_zarr_paths(zarr_path, backend="tensorstore")
32
32
 
33
33
  # Rename
34
34
  ds = ds.rename(
@@ -19,7 +19,7 @@ def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
19
19
  Returns:
20
20
  Xarray DataArray of the NWP data
21
21
  """
22
- ds = open_zarr_paths(zarr_path)
22
+ ds = open_zarr_paths(zarr_path, backend="tensorstore")
23
23
 
24
24
  # LEGACY SUPPORT - rename variable to channel if it exists
25
25
  ds = ds.rename({"init_time": "init_time_utc", "variable": "channel"})
@@ -23,7 +23,12 @@ def open_gfs(zarr_path: str | list[str], public: bool = False) -> xr.DataArray:
23
23
  _log.info("Loading NWP GFS data")
24
24
 
25
25
  # Open data
26
- gfs: xr.Dataset = open_zarr_paths(zarr_path, time_dim="init_time_utc", public=public)
26
+ gfs: xr.Dataset = open_zarr_paths(
27
+ zarr_path,
28
+ time_dim="init_time_utc",
29
+ public=public,
30
+ backend="dask",
31
+ )
27
32
  nwp: xr.DataArray = gfs.to_array(dim="channel")
28
33
 
29
34
  del gfs
@@ -20,7 +20,7 @@ def open_icon_eu(zarr_path: str | list[str]) -> xr.DataArray:
20
20
  Xarray DataArray of the NWP data
21
21
  """
22
22
  # Open and check initially
23
- ds = open_zarr_paths(zarr_path, time_dim="init_time_utc")
23
+ ds = open_zarr_paths(zarr_path, time_dim="init_time_utc", backend="dask")
24
24
 
25
25
  if "icon_eu_data" in ds.data_vars:
26
26
  nwp = ds["icon_eu_data"]
@@ -19,7 +19,7 @@ def open_ukv(zarr_path: str | list[str]) -> xr.DataArray:
19
19
  Returns:
20
20
  Xarray DataArray of the NWP data
21
21
  """
22
- ds = open_zarr_paths(zarr_path)
22
+ ds = open_zarr_paths(zarr_path, backend="tensorstore")
23
23
 
24
24
  ds = ds.rename(
25
25
  {
@@ -0,0 +1,83 @@
1
+ """Utility functions for the NWP data processing."""
2
+
3
+ from glob import glob
4
+
5
+ import xarray as xr
6
+ from xarray_tensorstore import open_zarr
7
+
8
+ from ocf_data_sampler.load.open_tensorstore_zarrs import open_zarrs
9
+
10
+
11
+ def open_zarr_paths(
12
+ zarr_path: str | list[str],
13
+ time_dim: str = "init_time",
14
+ public: bool = False,
15
+ backend: str = "dask",
16
+ ) -> xr.Dataset:
17
+ """Opens the NWP data.
18
+
19
+ Args:
20
+ zarr_path: Path to the zarr(s) to open
21
+ time_dim: Name of the time dimension
22
+ public: Whether the data is public or private. Only available for the dask backend.
23
+ backend: The xarray backend to use.
24
+
25
+ Returns:
26
+ The opened Xarray Dataset
27
+ """
28
+ if backend not in ["dask", "tensorstore"]:
29
+ raise ValueError(
30
+ f"Unsupported backend: {backend}. Supported backends are 'dask' and 'tensorstore'.",
31
+ )
32
+
33
+ if public and backend == "tensorstore":
34
+ raise ValueError("Public data is only supported with the 'dask' backend.")
35
+
36
+ if backend == "tensorstore":
37
+ ds = _tensostore_open_zarr_paths(zarr_path, time_dim)
38
+
39
+ elif backend == "dask":
40
+ ds = _dask_open_zarr_paths(zarr_path, time_dim, public)
41
+
42
+ return ds
43
+
44
+
45
+ def _dask_open_zarr_paths(zarr_path: str | list[str], time_dim: str, public: bool) -> xr.Dataset:
46
+ general_kwargs = {
47
+ "engine": "zarr",
48
+ "chunks": "auto",
49
+ "decode_timedelta": True,
50
+ }
51
+
52
+ if public:
53
+ # note this only works for s3 zarr paths at the moment
54
+ general_kwargs["storage_options"] = {"anon": True}
55
+
56
+ if isinstance(zarr_path, list | tuple) or "*" in str(zarr_path): # Multi-file dataset
57
+ ds = xr.open_mfdataset(
58
+ zarr_path,
59
+ concat_dim=time_dim,
60
+ combine="nested",
61
+ **general_kwargs,
62
+ ).sortby(time_dim)
63
+ else:
64
+ ds = xr.open_dataset(
65
+ zarr_path,
66
+ consolidated=True,
67
+ mode="r",
68
+ **general_kwargs,
69
+ )
70
+ return ds
71
+
72
+
73
+ def _tensostore_open_zarr_paths(zarr_path: str | list[str], time_dim: str) -> xr.Dataset:
74
+
75
+ if "*" in str(zarr_path):
76
+ zarr_path = sorted(glob(zarr_path))
77
+
78
+ if isinstance(zarr_path, list | tuple):
79
+ ds = open_zarrs(zarr_path, concat_dim=time_dim).sortby(time_dim)
80
+ else:
81
+ ds = open_zarr(zarr_path)
82
+ return ds
83
+
@@ -0,0 +1,92 @@
1
+ """Open multiple zarrs with TensorStore.
2
+
3
+ This extendds the functionality of xarray_tensorstore to open multiple zarr stores
4
+ """
5
+
6
+ import os
7
+
8
+ import tensorstore as ts
9
+ import xarray as xr
10
+ from xarray_tensorstore import (
11
+ _raise_if_mask_and_scale_used_for_data_vars,
12
+ _TensorStoreAdapter,
13
+ _zarr_spec_from_path,
14
+ )
15
+
16
+
17
+ def tensorstore_open_multi_zarrs(
18
+ paths: list[str],
19
+ data_vars: list[str],
20
+ concat_axes: list[int],
21
+ context: ts.Context,
22
+ write: bool,
23
+ ) -> dict[str, ts.TensorStore]:
24
+ """Open multiple zarrs with TensorStore.
25
+
26
+ Args:
27
+ paths: List of paths to zarr stores.
28
+ data_vars: List of data variable names to open.
29
+ concat_axes: List of axes along which to concatenate the data variables.
30
+ context: TensorStore context.
31
+ write: Whether to open the stores for writing.
32
+ """
33
+ arrays_list = []
34
+ for path in paths:
35
+ specs = {k: _zarr_spec_from_path(os.path.join(path, k)) for k in data_vars}
36
+ array_futures = {
37
+ k: ts.open(spec, read=True, write=write, context=context)
38
+ for k, spec in specs.items()
39
+ }
40
+ arrays_list.append({k: v.result() for k, v in array_futures.items()})
41
+
42
+ arrays = {}
43
+ for k, axis in zip(data_vars, concat_axes, strict=False):
44
+ datasets = [d[k] for d in arrays_list]
45
+ arrays[k] = ts.concat(datasets, axis=axis)
46
+
47
+ return arrays
48
+
49
+
50
+ def open_zarrs(
51
+ paths: list[str],
52
+ concat_dim: str,
53
+ *,
54
+ context: ts.Context | None = None,
55
+ mask_and_scale: bool = True,
56
+ write: bool = False,
57
+ ) -> xr.Dataset:
58
+ """Open multiple zarrs with TensorStore.
59
+
60
+ Args:
61
+ paths: List of paths to zarr stores.
62
+ concat_dim: Dimension along which to concatenate the data variables.
63
+ context: TensorStore context.
64
+ mask_and_scale: Whether to mask and scale the data.
65
+ write: Whether to open the stores for writing.
66
+ """
67
+ if context is None:
68
+ context = ts.Context()
69
+
70
+ ds = xr.open_mfdataset(
71
+ paths,
72
+ concat_dim=concat_dim,
73
+ combine="nested",
74
+ mask_and_scale=mask_and_scale,
75
+ decode_timedelta=True,
76
+ )
77
+
78
+ if mask_and_scale:
79
+ # Data variables get replaced below with _TensorStoreAdapter arrays, which
80
+ # don't get masked or scaled. Raising an error avoids surprising users with
81
+ # incorrect data values.
82
+ _raise_if_mask_and_scale_used_for_data_vars(ds)
83
+
84
+ data_vars = list(ds.data_vars)
85
+
86
+ concat_axes = [ds[v].dims.index(concat_dim) for v in data_vars]
87
+
88
+ arrays = tensorstore_open_multi_zarrs(paths, data_vars, concat_axes, context, write)
89
+
90
+ new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()}
91
+
92
+ return ds.copy(data=new_data)
@@ -1,6 +1,7 @@
1
1
  """Satellite loader."""
2
2
  import numpy as np
3
3
  import xarray as xr
4
+ from xarray_tensorstore import open_zarr
4
5
 
5
6
  from ocf_data_sampler.load.utils import (
6
7
  check_time_unique_increasing,
@@ -8,39 +9,7 @@ from ocf_data_sampler.load.utils import (
8
9
  make_spatial_coords_increasing,
9
10
  )
10
11
 
11
-
12
- def get_single_sat_data(zarr_path: str) -> xr.Dataset:
13
- """Helper function to open a zarr from either a local or GCP path.
14
-
15
- Args:
16
- zarr_path: path to a zarr file. Wildcards (*) are supported only for local paths
17
- GCS paths (gs://) do not support wildcards
18
-
19
- Returns:
20
- An xarray Dataset containing satellite data
21
-
22
- Raises:
23
- ValueError: If a wildcard (*) is used in a GCS (gs://) path
24
- """
25
- # Raise an error if a wildcard is used in a GCP path
26
- if "gs://" in str(zarr_path) and "*" in str(zarr_path):
27
- raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs")
28
-
29
- # Handle multi-file dataset for local paths
30
- if "*" in str(zarr_path):
31
- ds = xr.open_mfdataset(
32
- zarr_path,
33
- engine="zarr",
34
- concat_dim="time",
35
- combine="nested",
36
- chunks="auto",
37
- join="override",
38
- )
39
- check_time_unique_increasing(ds.time)
40
- else:
41
- ds = xr.open_dataset(zarr_path, engine="zarr", chunks="auto")
42
-
43
- return ds
12
+ from .open_tensorstore_zarrs import open_zarrs
44
13
 
45
14
 
46
15
  def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
@@ -52,14 +21,11 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
52
21
  """
53
22
  # Open the data
54
23
  if isinstance(zarr_path, list | tuple):
55
- ds = xr.combine_nested(
56
- [get_single_sat_data(path) for path in zarr_path],
57
- concat_dim="time",
58
- combine_attrs="override",
59
- join="override",
60
- )
24
+ ds = open_zarrs(zarr_path, concat_dim="time")
61
25
  else:
62
- ds = get_single_sat_data(zarr_path)
26
+ ds = open_zarr(zarr_path)
27
+
28
+ check_time_unique_increasing(ds.time)
63
29
 
64
30
  ds = ds.rename(
65
31
  {
@@ -47,7 +47,7 @@ def get_xr_data_array_from_xr_dataset(ds: xr.Dataset) -> xr.DataArray:
47
47
  Args:
48
48
  ds: xr.Dataset to extract xr.DataArray from
49
49
  """
50
- datavars = list(ds.var())
50
+ datavars = list(ds.data_vars)
51
51
  if len(datavars) != 1:
52
52
  raise ValueError("Cannot open as xr.DataArray: dataset contains multiple variables")
53
53
  return ds[datavars[0]]
@@ -0,0 +1,61 @@
1
+ """Functions for simulating dropout in time series data.
2
+
3
+ This is used for the following types of data: GSP, Satellite and Site
4
+ This is not used for NWP
5
+ """
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import xarray as xr
10
+
11
+
12
+ def apply_sampled_dropout_time(
13
+ t0: pd.Timestamp,
14
+ dropout_timedeltas: list[pd.Timedelta],
15
+ dropout_frac: float|list[float],
16
+ da: xr.DataArray,
17
+ ) -> xr.DataArray:
18
+ """Randomly pick a dropout time from a list of timedeltas and apply dropout time to the data.
19
+
20
+ Args:
21
+ t0: The forecast init-time
22
+ dropout_timedeltas: List of timedeltas relative to t0 to pick from
23
+ dropout_frac: Either a probability that dropout will be applied.
24
+ This should be between 0 and 1 inclusive.
25
+ Or a list of probabilities for each of the corresponding timedeltas
26
+ da: Xarray DataArray with 'time_utc' coordinate
27
+ """
28
+ if isinstance(dropout_frac, list):
29
+ # checking if len match
30
+ if len(dropout_frac) != len(dropout_timedeltas):
31
+ raise ValueError("Lengths of dropout_frac and dropout_timedeltas should match")
32
+
33
+
34
+
35
+
36
+ dropout_time = t0 + np.random.choice(dropout_timedeltas,p=dropout_frac)
37
+
38
+ return da.where(da.time_utc <= dropout_time)
39
+
40
+
41
+
42
+ # old logic
43
+ else:
44
+ # sample dropout time
45
+ if dropout_frac > 0 and len(dropout_timedeltas) == 0:
46
+ raise ValueError("To apply dropout, dropout_timedeltas must be provided")
47
+
48
+
49
+ if not (0 <= dropout_frac <= 1):
50
+ raise ValueError("dropout_frac must be between 0 and 1 inclusive")
51
+
52
+ if (len(dropout_timedeltas) == 0) or (np.random.uniform() >= dropout_frac):
53
+ dropout_time = None
54
+ else:
55
+ dropout_time = t0 + np.random.choice(dropout_timedeltas)
56
+
57
+ # apply dropout time
58
+ if dropout_time is None:
59
+ return da
60
+ # This replaces the times after the dropout with NaNs
61
+ return da.where(da.time_utc <= dropout_time)
@@ -270,6 +270,8 @@ class PVNetUKRegionalDataset(AbstractPVNetUKDataset):
270
270
  def __getitem__(self, idx: int) -> NumpySample:
271
271
  # Get the coordinates of the sample
272
272
 
273
+ idx = int(idx)
274
+
273
275
  if idx >= len(self):
274
276
  raise ValueError(f"Index {idx} out of range for dataset of length {len(self)}")
275
277
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ocf-data-sampler
3
- Version: 0.2.38
3
+ Version: 0.3.1
4
4
  Author: James Fulton, Peter Dudfield
5
5
  Author-email: Open Climate Fix team <info@openclimatefix.org>
6
6
  License: MIT License
@@ -44,6 +44,7 @@ Requires-Dist: pyproj
44
44
  Requires-Dist: pyaml_env
45
45
  Requires-Dist: pyresample
46
46
  Requires-Dist: h5netcdf
47
+ Requires-Dist: xarray-tensorstore==0.1.5
47
48
 
48
49
  # ocf-data-sampler
49
50
 
@@ -17,6 +17,7 @@ ocf_data_sampler/data/uk_gsp_locations_20250109.csv
17
17
  ocf_data_sampler/load/__init__.py
18
18
  ocf_data_sampler/load/gsp.py
19
19
  ocf_data_sampler/load/load_dataset.py
20
+ ocf_data_sampler/load/open_tensorstore_zarrs.py
20
21
  ocf_data_sampler/load/satellite.py
21
22
  ocf_data_sampler/load/site.py
22
23
  ocf_data_sampler/load/utils.py
@@ -12,3 +12,4 @@ pyproj
12
12
  pyaml_env
13
13
  pyresample
14
14
  h5netcdf
15
+ xarray-tensorstore==0.1.5
@@ -35,6 +35,7 @@ dependencies = [
35
35
  "pyaml_env",
36
36
  "pyresample",
37
37
  "h5netcdf",
38
+ "xarray-tensorstore==0.1.5",
38
39
  ]
39
40
 
40
41
  [dependency-groups]
@@ -1,43 +0,0 @@
1
- """Utility functions for the NWP data processing."""
2
-
3
- import xarray as xr
4
-
5
-
6
- def open_zarr_paths(
7
- zarr_path: str | list[str], time_dim: str = "init_time", public: bool = False,
8
- ) -> xr.Dataset:
9
- """Opens the NWP data.
10
-
11
- Args:
12
- zarr_path: Path to the zarr(s) to open
13
- time_dim: Name of the time dimension
14
- public: Whether the data is public or private
15
-
16
- Returns:
17
- The opened Xarray Dataset
18
- """
19
- general_kwargs = {
20
- "engine": "zarr",
21
- "chunks": "auto",
22
- "decode_timedelta": True,
23
- }
24
-
25
- if public:
26
- # note this only works for s3 zarr paths at the moment
27
- general_kwargs["storage_options"] = {"anon": True}
28
-
29
- if type(zarr_path) in [list, tuple] or "*" in str(zarr_path): # Multi-file dataset
30
- ds = xr.open_mfdataset(
31
- zarr_path,
32
- concat_dim=time_dim,
33
- combine="nested",
34
- **general_kwargs,
35
- ).sortby(time_dim)
36
- else:
37
- ds = xr.open_dataset(
38
- zarr_path,
39
- consolidated=True,
40
- mode="r",
41
- **general_kwargs,
42
- )
43
- return ds
@@ -1,47 +0,0 @@
1
- """Functions for simulating dropout in time series data.
2
-
3
- This is used for the following types of data: GSP, Satellite and Site
4
- This is not used for NWP
5
- """
6
-
7
- import numpy as np
8
- import pandas as pd
9
- import xarray as xr
10
-
11
-
12
- def apply_sampled_dropout_time(
13
- t0: pd.Timestamp,
14
- dropout_timedeltas: list[pd.Timedelta],
15
- dropout_frac: float,
16
- da: xr.DataArray,
17
- ) -> xr.DataArray:
18
- """Randomly pick a dropout time from a list of timedeltas and apply dropout time to the data.
19
-
20
- Args:
21
- t0: The forecast init-time
22
- dropout_timedeltas: List of timedeltas relative to t0 to pick from
23
- dropout_frac: Probability that dropout will be applied.
24
- This should be between 0 and 1 inclusive
25
- da: Xarray DataArray with 'time_utc' coordinate
26
- """
27
- # sample dropout time
28
- if dropout_frac > 0 and len(dropout_timedeltas) == 0:
29
- raise ValueError("To apply dropout, dropout_timedeltas must be provided")
30
-
31
- for t in dropout_timedeltas:
32
- if t > pd.Timedelta("0min"):
33
- raise ValueError("Dropout timedeltas must be negative")
34
-
35
- if not (0 <= dropout_frac <= 1):
36
- raise ValueError("dropout_frac must be between 0 and 1 inclusive")
37
-
38
- if (len(dropout_timedeltas) == 0) or (np.random.uniform() >= dropout_frac):
39
- dropout_time = None
40
- else:
41
- dropout_time = t0 + np.random.choice(dropout_timedeltas)
42
-
43
- # apply dropout time
44
- if dropout_time is None:
45
- return da
46
- # This replaces the times after the dropout with NaNs
47
- return da.where(da.time_utc <= dropout_time)