ocf-data-sampler 0.2.8__tar.gz → 0.2.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

Files changed (67) hide show
  1. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/PKG-INFO +2 -1
  2. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/select/__init__.py +0 -2
  3. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +165 -185
  4. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/torch_datasets/datasets/site.py +55 -59
  5. ocf_data_sampler-0.2.9/ocf_data_sampler/torch_datasets/sample/__init__.py +3 -0
  6. {ocf_data_sampler-0.2.8/ocf_data_sampler → ocf_data_sampler-0.2.9/ocf_data_sampler/torch_datasets}/sample/site.py +2 -1
  7. {ocf_data_sampler-0.2.8/ocf_data_sampler → ocf_data_sampler-0.2.9/ocf_data_sampler/torch_datasets}/sample/uk_regional.py +2 -1
  8. ocf_data_sampler-0.2.9/ocf_data_sampler/torch_datasets/utils/__init__.py +5 -0
  9. ocf_data_sampler-0.2.9/ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +18 -0
  10. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler.egg-info/PKG-INFO +2 -1
  11. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler.egg-info/SOURCES.txt +6 -6
  12. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler.egg-info/requires.txt +1 -0
  13. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/pyproject.toml +1 -0
  14. ocf_data_sampler-0.2.8/ocf_data_sampler/sample/__init__.py +0 -3
  15. ocf_data_sampler-0.2.8/ocf_data_sampler/torch_datasets/utils/__init__.py +0 -3
  16. ocf_data_sampler-0.2.8/ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +0 -11
  17. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/LICENSE +0 -0
  18. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/README.md +0 -0
  19. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/__init__.py +0 -0
  20. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/config/__init__.py +0 -0
  21. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/config/load.py +0 -0
  22. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/config/model.py +0 -0
  23. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/config/save.py +0 -0
  24. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/data/uk_gsp_locations.csv +0 -0
  25. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/__init__.py +0 -0
  26. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/gsp.py +0 -0
  27. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/load_dataset.py +0 -0
  28. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/nwp/__init__.py +0 -0
  29. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/nwp/nwp.py +0 -0
  30. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
  31. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/nwp/providers/cloudcasting.py +0 -0
  32. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
  33. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/nwp/providers/gfs.py +0 -0
  34. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/nwp/providers/icon.py +0 -0
  35. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
  36. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
  37. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/satellite.py +0 -0
  38. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/site.py +0 -0
  39. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/load/utils.py +0 -0
  40. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/numpy_sample/__init__.py +0 -0
  41. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/numpy_sample/collate.py +0 -0
  42. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/numpy_sample/common_types.py +0 -0
  43. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/numpy_sample/datetime_features.py +0 -0
  44. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/numpy_sample/gsp.py +0 -0
  45. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/numpy_sample/nwp.py +0 -0
  46. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/numpy_sample/satellite.py +0 -0
  47. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/numpy_sample/site.py +0 -0
  48. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/numpy_sample/sun_position.py +0 -0
  49. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/select/dropout.py +0 -0
  50. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/select/fill_time_periods.py +0 -0
  51. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -0
  52. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/select/geospatial.py +0 -0
  53. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/select/location.py +0 -0
  54. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
  55. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/select/select_time_slice.py +0 -0
  56. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -0
  57. {ocf_data_sampler-0.2.8/ocf_data_sampler → ocf_data_sampler-0.2.9/ocf_data_sampler/torch_datasets}/sample/base.py +0 -0
  58. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -0
  59. {ocf_data_sampler-0.2.8/ocf_data_sampler/select → ocf_data_sampler-0.2.9/ocf_data_sampler/torch_datasets/utils}/spatial_slice_for_dataset.py +0 -0
  60. {ocf_data_sampler-0.2.8/ocf_data_sampler/select → ocf_data_sampler-0.2.9/ocf_data_sampler/torch_datasets/utils}/time_slice_for_dataset.py +0 -0
  61. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
  62. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler/utils.py +0 -0
  63. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
  64. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/ocf_data_sampler.egg-info/top_level.txt +0 -0
  65. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/scripts/refactor_site.py +0 -0
  66. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/setup.cfg +0 -0
  67. {ocf_data_sampler-0.2.8 → ocf_data_sampler-0.2.9}/utils/compute_icon_mean_stddev.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ocf-data-sampler
3
- Version: 0.2.8
3
+ Version: 0.2.9
4
4
  Author: James Fulton, Peter Dudfield
5
5
  Author-email: Open Climate Fix team <info@openclimatefix.org>
6
6
  License: MIT License
@@ -35,6 +35,7 @@ Requires-Dist: numpy
35
35
  Requires-Dist: pandas
36
36
  Requires-Dist: xarray
37
37
  Requires-Dist: zarr==2.18.3
38
+ Requires-Dist: numcodecs<0.16
38
39
  Requires-Dist: dask
39
40
  Requires-Dist: matplotlib
40
41
  Requires-Dist: ocf_blosc2
@@ -4,5 +4,3 @@ from .find_contiguous_time_periods import (
4
4
  intersection_of_multiple_dataframes_of_periods,
5
5
  )
6
6
  from .location import Location
7
- from .spatial_slice_for_dataset import slice_datasets_by_space
8
- from .time_slice_for_dataset import slice_datasets_by_time
@@ -17,16 +17,17 @@ from ocf_data_sampler.numpy_sample import (
17
17
  make_sun_position_numpy_sample,
18
18
  )
19
19
  from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
20
+ from ocf_data_sampler.numpy_sample.common_types import NumpyBatch, NumpySample
20
21
  from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
21
22
  from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
22
- from ocf_data_sampler.select import (
23
- Location,
24
- fill_time_periods,
23
+ from ocf_data_sampler.select import Location, fill_time_periods
24
+ from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
25
+ from ocf_data_sampler.torch_datasets.utils import (
26
+ channel_dict_to_dataarray,
27
+ find_valid_time_periods,
25
28
  slice_datasets_by_space,
26
29
  slice_datasets_by_time,
27
30
  )
28
- from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
29
- from ocf_data_sampler.torch_datasets.utils import channel_dict_to_dataarray, find_valid_time_periods
30
31
  from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
31
32
  fill_nans_in_arrays,
32
33
  merge_dicts,
@@ -36,99 +37,6 @@ from ocf_data_sampler.utils import minutes
36
37
  xr.set_options(keep_attrs=True)
37
38
 
38
39
 
39
- def process_and_combine_datasets(
40
- dataset_dict: dict,
41
- config: Configuration,
42
- t0: pd.Timestamp,
43
- location: Location,
44
- ) -> dict:
45
- """Normalise and convert data to numpy arrays."""
46
- numpy_modalities = []
47
-
48
- if "nwp" in dataset_dict:
49
- nwp_numpy_modalities = {}
50
-
51
- for nwp_key, da_nwp in dataset_dict["nwp"].items():
52
-
53
- # Standardise and convert to NumpyBatch
54
-
55
- da_channel_means = channel_dict_to_dataarray(
56
- config.input_data.nwp[nwp_key].channel_means,
57
- )
58
- da_channel_stds = channel_dict_to_dataarray(
59
- config.input_data.nwp[nwp_key].channel_stds,
60
- )
61
-
62
- da_nwp = (da_nwp - da_channel_means) / da_channel_stds
63
-
64
- nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
65
-
66
- # Combine the NWPs into NumpyBatch
67
- numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
68
-
69
- if "sat" in dataset_dict:
70
- da_sat = dataset_dict["sat"]
71
-
72
- # Standardise and convert to NumpyBatch
73
- da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
74
- da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
75
-
76
- da_sat = (da_sat - da_channel_means) / da_channel_stds
77
-
78
- numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
79
-
80
- if "gsp" in dataset_dict:
81
- gsp_config = config.input_data.gsp
82
- da_gsp = dataset_dict["gsp"]
83
- da_gsp = da_gsp / da_gsp.effective_capacity_mwp
84
-
85
- # Convert to NumpyBatch
86
- numpy_modalities.append(
87
- convert_gsp_to_numpy_sample(
88
- da_gsp,
89
- t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes,
90
- ),
91
- )
92
-
93
- # Add GSP location data
94
- numpy_modalities.append(
95
- {
96
- GSPSampleKey.gsp_id: location.id,
97
- GSPSampleKey.x_osgb: location.x,
98
- GSPSampleKey.y_osgb: location.y,
99
- },
100
- )
101
-
102
- # Only add solar position if explicitly configured
103
- has_solar_config = (
104
- hasattr(config.input_data, "solar_position") and
105
- config.input_data.solar_position is not None
106
- )
107
-
108
- if has_solar_config:
109
- solar_config = config.input_data.solar_position
110
-
111
- # Create datetime range for solar position calculation
112
- datetimes = pd.date_range(
113
- t0 + minutes(solar_config.interval_start_minutes),
114
- t0 + minutes(solar_config.interval_end_minutes),
115
- freq=minutes(solar_config.time_resolution_minutes),
116
- )
117
-
118
- # Convert OSGB coordinates to lon/lat
119
- lon, lat = osgb_to_lon_lat(location.x, location.y)
120
-
121
- # Calculate solar positions and add to modalities
122
- solar_positions = make_sun_position_numpy_sample(datetimes, lon, lat)
123
- numpy_modalities.append(solar_positions)
124
-
125
- # Combine all the modalities and fill NaNs
126
- combined_sample = merge_dicts(numpy_modalities)
127
- combined_sample = fill_nans_in_arrays(combined_sample)
128
-
129
- return combined_sample
130
-
131
-
132
40
  def compute(xarray_dict: dict) -> dict:
133
41
  """Eagerly load a nested dictionary of xarray DataArrays."""
134
42
  for k, v in xarray_dict.items():
@@ -139,25 +47,12 @@ def compute(xarray_dict: dict) -> dict:
139
47
  return xarray_dict
140
48
 
141
49
 
142
- def find_valid_t0_times(datasets_dict: dict, config: Configuration) -> pd.DatetimeIndex:
143
- """Find the t0 times where all of the requested input data is available.
50
+ def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
51
+ """Get list of locations of all GSPs.
144
52
 
145
53
  Args:
146
- datasets_dict: A dictionary of input datasets
147
- config: Configuration file
54
+ gsp_ids: List of GSP IDs to include. Defaults to all
148
55
  """
149
- valid_time_periods = find_valid_time_periods(datasets_dict, config)
150
-
151
- # Fill out the contiguous time periods to get the t0 times
152
- valid_t0_times = fill_time_periods(
153
- valid_time_periods,
154
- freq=minutes(config.input_data.gsp.time_resolution_minutes),
155
- )
156
- return valid_t0_times
157
-
158
-
159
- def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
160
- """Get list of locations of all GSPs."""
161
56
  if gsp_ids is None:
162
57
  gsp_ids = list(range(1, 318))
163
58
 
@@ -181,8 +76,8 @@ def get_gsp_locations(gsp_ids: list[int] | None = None) -> list[Location]:
181
76
  return locations
182
77
 
183
78
 
184
- class PVNetUKRegionalDataset(Dataset):
185
- """A torch Dataset for creating PVNet UK regional samples."""
79
+ class AbstractPVNetUKDataset(Dataset):
80
+ """Abstract class for PVNet UK datasets."""
186
81
 
187
82
  def __init__(
188
83
  self,
@@ -191,7 +86,7 @@ class PVNetUKRegionalDataset(Dataset):
191
86
  end_time: str | None = None,
192
87
  gsp_ids: list[int] | None = None,
193
88
  ) -> None:
194
- """A torch Dataset for creating PVNet UK GSP samples.
89
+ """A torch Dataset for creating PVNet UK samples.
195
90
 
196
91
  Args:
197
92
  config_filename: Path to the configuration file
@@ -199,13 +94,11 @@ class PVNetUKRegionalDataset(Dataset):
199
94
  end_time: Limit the init-times to be before this
200
95
  gsp_ids: List of GSP IDs to create samples for. Defaults to all
201
96
  """
202
- # config = load_yaml_configuration(config_filename)
203
- config: Configuration = load_yaml_configuration(config_filename)
204
-
97
+ config = load_yaml_configuration(config_filename)
205
98
  datasets_dict = get_dataset_dict(config.input_data)
206
99
 
207
100
  # Get t0 times where all input data is available
208
- valid_t0_times = find_valid_t0_times(datasets_dict, config)
101
+ valid_t0_times = self.find_valid_t0_times(datasets_dict, config)
209
102
 
210
103
  # Filter t0 times to given range
211
104
  if start_time is not None:
@@ -215,35 +108,167 @@ class PVNetUKRegionalDataset(Dataset):
215
108
  valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]
216
109
 
217
110
  # Construct list of locations to sample from
218
- locations = get_gsp_locations(gsp_ids)
111
+ self.locations = get_gsp_locations(gsp_ids)
112
+ self.valid_t0_times = valid_t0_times
113
+
114
+ # Assign config and input data to self
115
+ self.config = config
116
+ self.datasets_dict = datasets_dict
117
+
118
+
119
+ @staticmethod
120
+ def process_and_combine_datasets(
121
+ dataset_dict: dict,
122
+ config: Configuration,
123
+ t0: pd.Timestamp,
124
+ location: Location,
125
+ ) -> NumpySample:
126
+ """Normalise and convert data to numpy arrays.
127
+
128
+ Args:
129
+ dataset_dict: Dictionary of xarray datasets
130
+ config: Configuration object
131
+ t0: init-time for sample
132
+ location: location of the sample
133
+ """
134
+ numpy_modalities = []
135
+
136
+ if "nwp" in dataset_dict:
137
+ nwp_numpy_modalities = {}
138
+
139
+ for nwp_key, da_nwp in dataset_dict["nwp"].items():
140
+
141
+ # Standardise and convert to NumpyBatch
142
+
143
+ da_channel_means = channel_dict_to_dataarray(
144
+ config.input_data.nwp[nwp_key].channel_means,
145
+ )
146
+ da_channel_stds = channel_dict_to_dataarray(
147
+ config.input_data.nwp[nwp_key].channel_stds,
148
+ )
149
+
150
+ da_nwp = (da_nwp - da_channel_means) / da_channel_stds
151
+
152
+ nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_sample(da_nwp)
153
+
154
+ # Combine the NWPs into NumpyBatch
155
+ numpy_modalities.append({NWPSampleKey.nwp: nwp_numpy_modalities})
156
+
157
+ if "sat" in dataset_dict:
158
+ da_sat = dataset_dict["sat"]
159
+
160
+ # Standardise and convert to NumpyBatch
161
+ da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
162
+ da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
163
+
164
+ da_sat = (da_sat - da_channel_means) / da_channel_stds
165
+
166
+ numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
167
+
168
+ if "gsp" in dataset_dict:
169
+ gsp_config = config.input_data.gsp
170
+ da_gsp = dataset_dict["gsp"]
171
+ da_gsp = da_gsp / da_gsp.effective_capacity_mwp
172
+
173
+ # Convert to NumpyBatch
174
+ numpy_modalities.append(
175
+ convert_gsp_to_numpy_sample(
176
+ da_gsp,
177
+ t0_idx=-gsp_config.interval_start_minutes / gsp_config.time_resolution_minutes,
178
+ ),
179
+ )
180
+
181
+ # Add GSP location data
182
+ numpy_modalities.append(
183
+ {
184
+ GSPSampleKey.gsp_id: location.id,
185
+ GSPSampleKey.x_osgb: location.x,
186
+ GSPSampleKey.y_osgb: location.y,
187
+ },
188
+ )
189
+
190
+ # Only add solar position if explicitly configured
191
+ has_solar_config = (
192
+ hasattr(config.input_data, "solar_position") and
193
+ config.input_data.solar_position is not None
194
+ )
195
+
196
+ if has_solar_config:
197
+ solar_config = config.input_data.solar_position
198
+
199
+ # Create datetime range for solar position calculation
200
+ datetimes = pd.date_range(
201
+ t0 + minutes(solar_config.interval_start_minutes),
202
+ t0 + minutes(solar_config.interval_end_minutes),
203
+ freq=minutes(solar_config.time_resolution_minutes),
204
+ )
205
+
206
+ # Convert OSGB coordinates to lon/lat
207
+ lon, lat = osgb_to_lon_lat(location.x, location.y)
208
+
209
+ # Calculate solar positions and add to modalities
210
+ numpy_modalities.append(make_sun_position_numpy_sample(datetimes, lon, lat))
211
+
212
+ # Combine all the modalities and fill NaNs
213
+ combined_sample = merge_dicts(numpy_modalities)
214
+ combined_sample = fill_nans_in_arrays(combined_sample)
215
+
216
+ return combined_sample
217
+
218
+ @staticmethod
219
+ def find_valid_t0_times(datasets_dict: dict, config: Configuration) -> pd.DatetimeIndex:
220
+ """Find the t0 times where all of the requested input data is available.
221
+
222
+ Args:
223
+ datasets_dict: A dictionary of input datasets
224
+ config: Configuration file
225
+ """
226
+ valid_time_periods = find_valid_time_periods(datasets_dict, config)
227
+
228
+ # Fill out the contiguous time periods to get the t0 times
229
+ valid_t0_times = fill_time_periods(
230
+ valid_time_periods,
231
+ freq=minutes(config.input_data.gsp.time_resolution_minutes),
232
+ )
233
+ return valid_t0_times
234
+
235
+
236
+
237
+ class PVNetUKRegionalDataset(AbstractPVNetUKDataset):
238
+ """A torch Dataset for creating PVNet UK regional samples."""
239
+
240
+ @override
241
+ def __init__(
242
+ self,
243
+ config_filename: str,
244
+ start_time: str | None = None,
245
+ end_time: str | None = None,
246
+ gsp_ids: list[int] | None = None,
247
+ ) -> None:
248
+
249
+ super().__init__(config_filename, start_time, end_time, gsp_ids)
219
250
 
220
251
  # Construct a lookup for locations - useful for users to construct sample by GSP ID
221
- location_lookup = {loc.id: loc for loc in locations}
252
+ location_lookup = {loc.id: loc for loc in self.locations}
222
253
 
223
254
  # Construct indices for sampling
224
255
  t_index, loc_index = np.meshgrid(
225
- np.arange(len(valid_t0_times)),
226
- np.arange(len(locations)),
256
+ np.arange(len(self.valid_t0_times)),
257
+ np.arange(len(self.locations)),
227
258
  )
228
259
 
229
260
  # Make array of all possible (t0, location) coordinates. Each row is a single coordinate
230
261
  index_pairs = np.stack((t_index.ravel(), loc_index.ravel())).T
231
262
 
232
263
  # Assign coords and indices to self
233
- self.valid_t0_times = valid_t0_times
234
- self.locations = locations
235
264
  self.location_lookup = location_lookup
236
265
  self.index_pairs = index_pairs
237
266
 
238
- # Assign config and input data to self
239
- self.datasets_dict = datasets_dict
240
- self.config = config
241
-
242
267
  @override
243
268
  def __len__(self) -> int:
244
269
  return len(self.index_pairs)
245
270
 
246
- def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
271
+ def _get_sample(self, t0: pd.Timestamp, location: Location) -> NumpySample:
247
272
  """Generate the PVNet sample for given coordinates.
248
273
 
249
274
  Args:
@@ -254,21 +279,18 @@ class PVNetUKRegionalDataset(Dataset):
254
279
  sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
255
280
  sample_dict = compute(sample_dict)
256
281
 
257
- sample = process_and_combine_datasets(sample_dict, self.config, t0, location)
258
-
259
- return sample
282
+ return self.process_and_combine_datasets(sample_dict, self.config, t0, location)
260
283
 
261
284
  @override
262
- def __getitem__(self, idx: int) -> dict:
285
+ def __getitem__(self, idx: int) -> NumpySample:
263
286
  # Get the coordinates of the sample
264
287
  t_index, loc_index = self.index_pairs[idx]
265
288
  location = self.locations[loc_index]
266
289
  t0 = self.valid_t0_times[t_index]
267
290
 
268
- # Generate the sample
269
291
  return self._get_sample(t0, location)
270
292
 
271
- def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> dict:
293
+ def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> NumpySample:
272
294
  """Generate a sample for the given coordinates.
273
295
 
274
296
  Useful for users to generate specific samples.
@@ -288,56 +310,14 @@ class PVNetUKRegionalDataset(Dataset):
288
310
  return self._get_sample(t0, location)
289
311
 
290
312
 
291
- class PVNetUKConcurrentDataset(Dataset):
313
+ class PVNetUKConcurrentDataset(AbstractPVNetUKDataset):
292
314
  """A torch Dataset for creating concurrent PVNet UK regional samples."""
293
315
 
294
- def __init__(
295
- self,
296
- config_filename: str,
297
- start_time: str | None = None,
298
- end_time: str | None = None,
299
- gsp_ids: list[int] | None = None,
300
- ) -> None:
301
- """A torch Dataset for creating concurrent samples of PVNet UK regional data.
302
-
303
- Each concurrent sample includes the data from all GSPs for a single t0 time
304
-
305
- Args:
306
- config_filename: Path to the configuration file
307
- start_time: Limit the init-times to be after this
308
- end_time: Limit the init-times to be before this
309
- gsp_ids: List of all GSP IDs included in each sample. Defaults to all
310
- """
311
- config = load_yaml_configuration(config_filename)
312
-
313
- datasets_dict = get_dataset_dict(config.input_data)
314
-
315
- # Get t0 times where all input data is available
316
- valid_t0_times = find_valid_t0_times(datasets_dict, config)
317
-
318
- # Filter t0 times to given range
319
- if start_time is not None:
320
- valid_t0_times = valid_t0_times[valid_t0_times >= pd.Timestamp(start_time)]
321
-
322
- if end_time is not None:
323
- valid_t0_times = valid_t0_times[valid_t0_times <= pd.Timestamp(end_time)]
324
-
325
- # Construct list of locations to sample from
326
- locations = get_gsp_locations(gsp_ids)
327
-
328
- # Assign coords and indices to self
329
- self.valid_t0_times = valid_t0_times
330
- self.locations = locations
331
-
332
- # Assign config and input data to self
333
- self.datasets_dict = datasets_dict
334
- self.config = config
335
-
336
316
  @override
337
317
  def __len__(self) -> int:
338
318
  return len(self.valid_t0_times)
339
319
 
340
- def _get_sample(self, t0: pd.Timestamp) -> dict:
320
+ def _get_sample(self, t0: pd.Timestamp) -> NumpyBatch:
341
321
  """Generate a concurrent PVNet sample for given init-time.
342
322
 
343
323
  Args:
@@ -352,7 +332,7 @@ class PVNetUKConcurrentDataset(Dataset):
352
332
  # Prepare sample for each GSP
353
333
  for location in self.locations:
354
334
  gsp_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
355
- gsp_numpy_sample = process_and_combine_datasets(
335
+ gsp_numpy_sample = self.process_and_combine_datasets(
356
336
  gsp_sample_dict,
357
337
  self.config,
358
338
  t0,
@@ -364,10 +344,10 @@ class PVNetUKConcurrentDataset(Dataset):
364
344
  return stack_np_samples_into_batch(gsp_samples)
365
345
 
366
346
  @override
367
- def __getitem__(self, idx: int) -> dict:
347
+ def __getitem__(self, idx: int) -> NumpyBatch:
368
348
  return self._get_sample(self.valid_t0_times[idx])
369
349
 
370
- def get_sample(self, t0: pd.Timestamp) -> dict:
350
+ def get_sample(self, t0: pd.Timestamp) -> NumpyBatch:
371
351
  """Generate a sample for the given init-time.
372
352
 
373
353
  Useful for users to generate specific samples.
@@ -1,14 +1,12 @@
1
1
  """Torch dataset for sites."""
2
2
 
3
- import logging
4
-
5
3
  import numpy as np
6
4
  import pandas as pd
7
5
  import xarray as xr
8
6
  from torch.utils.data import Dataset
9
7
  from typing_extensions import override
10
8
 
11
- from ocf_data_sampler.config import Configuration, load_yaml_configuration
9
+ from ocf_data_sampler.config import load_yaml_configuration
12
10
  from ocf_data_sampler.load.load_dataset import get_dataset_dict
13
11
  from ocf_data_sampler.numpy_sample import (
14
12
  NWPSampleKey,
@@ -18,15 +16,19 @@ from ocf_data_sampler.numpy_sample import (
18
16
  make_datetime_numpy_dict,
19
17
  make_sun_position_numpy_sample,
20
18
  )
19
+ from ocf_data_sampler.numpy_sample.common_types import NumpySample
21
20
  from ocf_data_sampler.select import (
22
21
  Location,
23
22
  fill_time_periods,
24
23
  find_contiguous_t0_periods,
25
24
  intersection_of_multiple_dataframes_of_periods,
25
+ )
26
+ from ocf_data_sampler.torch_datasets.utils import (
27
+ channel_dict_to_dataarray,
28
+ find_valid_time_periods,
26
29
  slice_datasets_by_space,
27
30
  slice_datasets_by_time,
28
31
  )
29
- from ocf_data_sampler.torch_datasets.utils import channel_dict_to_dataarray, find_valid_time_periods
30
32
  from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
31
33
  fill_nans_in_arrays,
32
34
  merge_dicts,
@@ -52,7 +54,7 @@ class SitesDataset(Dataset):
52
54
  start_time: Limit the init-times to be after this
53
55
  end_time: Limit the init-times to be before this
54
56
  """
55
- config: Configuration = load_yaml_configuration(config_filename)
57
+ config = load_yaml_configuration(config_filename)
56
58
  datasets_dict = get_dataset_dict(config.input_data)
57
59
 
58
60
  # Assign config and input data to self
@@ -61,6 +63,7 @@ class SitesDataset(Dataset):
61
63
 
62
64
  # get all locations
63
65
  self.locations = self.get_locations(datasets_dict["site"])
66
+ self.location_lookup = {loc.id: loc for loc in self.locations}
64
67
 
65
68
  # Get t0 times where all input data is available
66
69
  valid_t0_and_site_ids = self.find_valid_t0_and_site_ids(datasets_dict)
@@ -89,7 +92,7 @@ class SitesDataset(Dataset):
89
92
  t0, site_id = self.valid_t0_and_site_ids.iloc[idx]
90
93
 
91
94
  # get location from site id
92
- location = self.get_location_from_site_id(site_id)
95
+ location = self.location_lookup[site_id]
93
96
 
94
97
  # Generate the sample
95
98
  return self._get_sample(t0, location)
@@ -105,8 +108,7 @@ class SitesDataset(Dataset):
105
108
  sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
106
109
 
107
110
  sample = self.process_and_combine_site_sample_dict(sample_dict, t0)
108
- sample = sample.compute()
109
- return sample
111
+ return sample.compute()
110
112
 
111
113
  def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict:
112
114
  """Generate a sample for a given site id and t0.
@@ -117,22 +119,10 @@ class SitesDataset(Dataset):
117
119
  t0: init-time for sample
118
120
  site_id: site id as int
119
121
  """
120
- location = self.get_location_from_site_id(site_id)
122
+ location = self.location_lookup[site_id]
121
123
 
122
124
  return self._get_sample(t0, location)
123
125
 
124
- def get_location_from_site_id(self, site_id: int) -> Location:
125
- """Get location from system id."""
126
- locations = [loc for loc in self.locations if loc.id == site_id]
127
- if len(locations) == 0:
128
- raise ValueError(f"Location not found for site_id {site_id}")
129
-
130
- if len(locations) > 1:
131
- logging.warning(
132
- f"Multiple locations found for site_id {site_id}, but will take the first",
133
- )
134
-
135
- return locations[0]
136
126
 
137
127
  def find_valid_t0_and_site_ids(
138
128
  self,
@@ -148,24 +138,21 @@ class SitesDataset(Dataset):
148
138
  datasets_dict: A dictionary of input datasets
149
139
  config: Configuration file
150
140
  """
151
- # 1. Get valid time period for nwp and satellite
141
+ # Get valid time period for nwp and satellite
152
142
  datasets_without_site = {k: v for k, v in datasets_dict.items() if k != "site"}
153
143
  valid_time_periods = find_valid_time_periods(datasets_without_site, self.config)
154
144
 
155
- # 2. Now lets loop over each location in system id and find the valid periods
156
- # Should we have a different option if there are not nans
145
+ # Loop over each location in system id and obtain valid periods
157
146
  sites = datasets_dict["site"]
158
147
  site_ids = sites.site_id.values
159
148
  site_config = self.config.input_data.site
160
149
  valid_t0_and_site_ids = []
161
150
  for site_id in site_ids:
162
151
  site = sites.sel(site_id=site_id)
163
-
164
- # drop any nan values
165
- # not sure this is right?
152
+ # Drop NaN values
166
153
  site = site.dropna(dim="time_utc")
167
154
 
168
- # Get the valid time periods for this location
155
+ # Obtain valid time periods for this location
169
156
  time_periods = find_contiguous_t0_periods(
170
157
  pd.DatetimeIndex(site["time_utc"]),
171
158
  time_resolution=minutes(site_config.time_resolution_minutes),
@@ -176,7 +163,7 @@ class SitesDataset(Dataset):
176
163
  [valid_time_periods, time_periods],
177
164
  )
178
165
 
179
- # Fill out the contiguous time periods to get the t0 times
166
+ # Fill out contiguous time periods to get t0 times
180
167
  valid_t0_times_per_site = fill_time_periods(
181
168
  valid_time_periods_per_site,
182
169
  freq=minutes(site_config.time_resolution_minutes),
@@ -188,12 +175,15 @@ class SitesDataset(Dataset):
188
175
 
189
176
  valid_t0_and_site_ids = pd.concat(valid_t0_and_site_ids)
190
177
  valid_t0_and_site_ids.index.name = "t0"
191
- valid_t0_and_site_ids.reset_index(inplace=True)
178
+ return valid_t0_and_site_ids.reset_index()
192
179
 
193
- return valid_t0_and_site_ids
194
180
 
195
181
  def get_locations(self, site_xr: xr.Dataset) -> list[Location]:
196
- """Get list of locations of all sites."""
182
+ """Get list of locations of all sites.
183
+
184
+ Args:
185
+ site_xr: xarray Dataset of site data
186
+ """
197
187
  locations = []
198
188
  for site_id in site_xr.site_id.values:
199
189
  site = site_xr.sel(site_id=site_id)
@@ -220,7 +210,6 @@ class SitesDataset(Dataset):
220
210
 
221
211
  Returns:
222
212
  xr.Dataset: A merged Dataset with nans filled in.
223
-
224
213
  """
225
214
  data_arrays = []
226
215
 
@@ -228,7 +217,6 @@ class SitesDataset(Dataset):
228
217
  for nwp_key, da_nwp in dataset_dict["nwp"].items():
229
218
  provider = self.config.input_data.nwp[nwp_key].provider
230
219
 
231
- # Standardise
232
220
  da_channel_means = channel_dict_to_dataarray(
233
221
  self.config.input_data.nwp[nwp_key].channel_means,
234
222
  )
@@ -237,7 +225,6 @@ class SitesDataset(Dataset):
237
225
  )
238
226
 
239
227
  da_nwp = (da_nwp - da_channel_means) / da_channel_stds
240
-
241
228
  data_arrays.append((f"nwp-{provider}", da_nwp))
242
229
 
243
230
  if "sat" in dataset_dict:
@@ -251,11 +238,9 @@ class SitesDataset(Dataset):
251
238
  )
252
239
 
253
240
  da_sat = (da_sat - da_channel_means) / da_channel_stds
254
-
255
241
  data_arrays.append(("satellite", da_sat))
256
242
 
257
243
  if "site" in dataset_dict:
258
- # site_config = config.input_data.site
259
244
  da_sites = dataset_dict["site"]
260
245
  da_sites = da_sites / da_sites.capacity_kwp
261
246
  data_arrays.append(("site", da_sites))
@@ -372,12 +357,16 @@ class SitesDataset(Dataset):
372
357
 
373
358
 
374
359
  def convert_netcdf_to_numpy_sample(ds: xr.Dataset) -> dict:
375
- """Convert a netcdf dataset to a numpy sample."""
360
+ """Convert a netcdf dataset to a numpy sample.
361
+
362
+ Args:
363
+ ds: xarray Dataset
364
+ """
376
365
  # convert the single dataset to a dict of arrays
377
366
  sample_dict = convert_from_dataset_to_dict_datasets(ds)
378
367
 
379
368
  if "satellite" in sample_dict:
380
- # rename satellite to satellite actual # TODO this could be improves
369
+ # rename satellite to sat # TODO this could be improved
381
370
  sample_dict["sat"] = sample_dict.pop("satellite")
382
371
 
383
372
  # process and combine the datasets
@@ -408,43 +397,52 @@ def convert_from_dataset_to_dict_datasets(combined_dataset: xr.Dataset) -> dict[
408
397
  The uncombined datasets as a dict of xr.Datasets
409
398
  """
410
399
  # Split into datasets by splitting by the prefix added in combine_to_netcdf
411
- datasets = {}
400
+ datasets: dict[str, xr.DataArray] = {}
401
+
412
402
  # Go through each data variable and split it into a dataset
413
403
  for key, dataset in combined_dataset.items():
414
- # If 'key_' doesn't exist in a dim or coordinate, remove it
415
- dataset_dims = list(dataset.coords)
416
- for dim in dataset_dims:
404
+ # If 'key__' doesn't exist in a dim or coordinate, remove it
405
+ for dim in list(dataset.coords):
417
406
  if f"{key}__" not in dim:
418
- dataset: xr.Dataset = dataset.drop(dim)
407
+ dataset = dataset.drop_vars(dim)
419
408
  dataset = dataset.rename(
420
409
  {dim: dim.split(f"{key}__")[1] for dim in dataset.dims if dim not in dataset.coords},
421
410
  )
422
- dataset: xr.Dataset = dataset.rename(
411
+ dataset = dataset.rename(
423
412
  {coord: coord.split(f"{key}__")[1] for coord in dataset.coords},
424
413
  )
425
414
  # Split the dataset by the prefix
426
415
  datasets[key] = dataset
427
416
 
428
417
  # Unflatten any NWP data
429
- datasets = nest_nwp_source_dict(datasets, sep="-")
430
- return datasets
418
+ return nest_nwp_source_dict(datasets, sep="-")
419
+
431
420
 
421
+ def nest_nwp_source_dict(
422
+ dataset_dict: dict[xr.Dataset],
423
+ sep: str = "-",
424
+ ) -> dict[str, xr.Dataset | dict[xr.Dataset]]:
425
+ """Re-nest a dictionary where the NWP values are nested under keys 'nwp-<key>'.
432
426
 
433
- def nest_nwp_source_dict(d: dict, sep: str = "/") -> dict:
434
- """Re-nest a dictionary where the NWP values are nested under keys 'nwp/<key>'."""
427
+ Args:
428
+ dataset_dict: Dictionary of datasets
429
+ sep: Separator to use to nest NWP keys
430
+ """
435
431
  nwp_prefix = f"nwp{sep}"
436
- new_dict = {k: v for k, v in d.items() if not k.startswith(nwp_prefix)}
437
- nwp_keys = [k for k in d if k.startswith(nwp_prefix)]
432
+ new_dict = {k: v for k, v in dataset_dict.items() if not k.startswith(nwp_prefix)}
433
+ nwp_keys = [k for k in dataset_dict if k.startswith(nwp_prefix)]
438
434
  if len(nwp_keys) > 0:
439
- nwp_subdict = {k.removeprefix(nwp_prefix): d[k] for k in nwp_keys}
435
+ nwp_subdict = {k.removeprefix(nwp_prefix): dataset_dict[k] for k in nwp_keys}
440
436
  new_dict["nwp"] = nwp_subdict
441
437
  return new_dict
442
438
 
443
439
 
444
- def convert_to_numpy_and_combine(
445
- dataset_dict: dict,
446
- ) -> dict:
447
- """Convert input data in a dict to numpy arrays."""
440
+ def convert_to_numpy_and_combine(dataset_dict: dict[xr.Dataset]) -> NumpySample:
441
+ """Convert input data in a dict to numpy arrays.
442
+
443
+ Args:
444
+ dataset_dict: Dictionary of xarray Datasets
445
+ """
448
446
  numpy_modalities = []
449
447
 
450
448
  if "nwp" in dataset_dict:
@@ -474,9 +472,7 @@ def convert_to_numpy_and_combine(
474
472
 
475
473
  # Combine all the modalities and fill NaNs
476
474
  combined_sample = merge_dicts(numpy_modalities)
477
- combined_sample = fill_nans_in_arrays(combined_sample)
478
-
479
- return combined_sample
475
+ return fill_nans_in_arrays(combined_sample)
480
476
 
481
477
 
482
478
  def coarsen_data(xr_data: xr.Dataset, coarsen_to_deg: float = 0.1) -> xr.Dataset:
@@ -0,0 +1,3 @@
1
+ from .base import SampleBase
2
+ from .uk_regional import UKRegionalSample
3
+ from .site import SiteSample
@@ -4,9 +4,10 @@ import xarray as xr
4
4
  from typing_extensions import override
5
5
 
6
6
  from ocf_data_sampler.numpy_sample.common_types import NumpySample
7
- from ocf_data_sampler.sample.base import SampleBase
8
7
  from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample
9
8
 
9
+ from .base import SampleBase
10
+
10
11
 
11
12
  class SiteSample(SampleBase):
12
13
  """Handles PVNet site specific netCDF operations."""
@@ -9,7 +9,8 @@ from ocf_data_sampler.numpy_sample import (
9
9
  SatelliteSampleKey,
10
10
  )
11
11
  from ocf_data_sampler.numpy_sample.common_types import NumpySample
12
- from ocf_data_sampler.sample.base import SampleBase
12
+
13
+ from .base import SampleBase
13
14
 
14
15
 
15
16
  class UKRegionalSample(SampleBase):
@@ -0,0 +1,5 @@
1
+ from .channel_dict_to_dataarray import channel_dict_to_dataarray
2
+ from .merge_and_fill_utils import fill_nans_in_arrays, merge_dicts
3
+ from .valid_time_periods import find_valid_time_periods
4
+ from .spatial_slice_for_dataset import slice_datasets_by_space
5
+ from .time_slice_for_dataset import slice_datasets_by_time
@@ -0,0 +1,18 @@
1
+ """Utility function for converting channel dictionaries to xarray DataArrays."""
2
+
3
+ import xarray as xr
4
+
5
+
6
+ def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
7
+ """Converts a dictionary of channel values to a DataArray.
8
+
9
+ Args:
10
+ channel_dict: Dictionary mapping channel names (str) to their values (float).
11
+
12
+ Returns:
13
+ xr.DataArray: A 1D DataArray with channels as coordinates.
14
+ """
15
+ return xr.DataArray(
16
+ list(channel_dict.values()),
17
+ coords={"channel": list(channel_dict.keys())},
18
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ocf-data-sampler
3
- Version: 0.2.8
3
+ Version: 0.2.9
4
4
  Author: James Fulton, Peter Dudfield
5
5
  Author-email: Open Climate Fix team <info@openclimatefix.org>
6
6
  License: MIT License
@@ -35,6 +35,7 @@ Requires-Dist: numpy
35
35
  Requires-Dist: pandas
36
36
  Requires-Dist: xarray
37
37
  Requires-Dist: zarr==2.18.3
38
+ Requires-Dist: numcodecs<0.16
38
39
  Requires-Dist: dask
39
40
  Requires-Dist: matplotlib
40
41
  Requires-Dist: ocf_blosc2
@@ -37,10 +37,6 @@ ocf_data_sampler/numpy_sample/nwp.py
37
37
  ocf_data_sampler/numpy_sample/satellite.py
38
38
  ocf_data_sampler/numpy_sample/site.py
39
39
  ocf_data_sampler/numpy_sample/sun_position.py
40
- ocf_data_sampler/sample/__init__.py
41
- ocf_data_sampler/sample/base.py
42
- ocf_data_sampler/sample/site.py
43
- ocf_data_sampler/sample/uk_regional.py
44
40
  ocf_data_sampler/select/__init__.py
45
41
  ocf_data_sampler/select/dropout.py
46
42
  ocf_data_sampler/select/fill_time_periods.py
@@ -49,14 +45,18 @@ ocf_data_sampler/select/geospatial.py
49
45
  ocf_data_sampler/select/location.py
50
46
  ocf_data_sampler/select/select_spatial_slice.py
51
47
  ocf_data_sampler/select/select_time_slice.py
52
- ocf_data_sampler/select/spatial_slice_for_dataset.py
53
- ocf_data_sampler/select/time_slice_for_dataset.py
54
48
  ocf_data_sampler/torch_datasets/datasets/__init__.py
55
49
  ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
56
50
  ocf_data_sampler/torch_datasets/datasets/site.py
51
+ ocf_data_sampler/torch_datasets/sample/__init__.py
52
+ ocf_data_sampler/torch_datasets/sample/base.py
53
+ ocf_data_sampler/torch_datasets/sample/site.py
54
+ ocf_data_sampler/torch_datasets/sample/uk_regional.py
57
55
  ocf_data_sampler/torch_datasets/utils/__init__.py
58
56
  ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py
59
57
  ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py
58
+ ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py
59
+ ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py
60
60
  ocf_data_sampler/torch_datasets/utils/valid_time_periods.py
61
61
  scripts/refactor_site.py
62
62
  utils/compute_icon_mean_stddev.py
@@ -3,6 +3,7 @@ numpy
3
3
  pandas
4
4
  xarray
5
5
  zarr==2.18.3
6
+ numcodecs<0.16
6
7
  dask
7
8
  matplotlib
8
9
  ocf_blosc2
@@ -26,6 +26,7 @@ dependencies = [
26
26
  "pandas",
27
27
  "xarray",
28
28
  "zarr==2.18.3",
29
+ "numcodecs<0.16",
29
30
  "dask",
30
31
  "matplotlib",
31
32
  "ocf_blosc2",
@@ -1,3 +0,0 @@
1
- from ocf_data_sampler.sample.base import SampleBase
2
- from ocf_data_sampler.sample.uk_regional import UKRegionalSample
3
- from ocf_data_sampler.sample.site import SiteSample
@@ -1,3 +0,0 @@
1
- from .channel_dict_to_dataarray import channel_dict_to_dataarray
2
- from .merge_and_fill_utils import fill_nans_in_arrays, merge_dicts
3
- from .valid_time_periods import find_valid_time_periods
@@ -1,11 +0,0 @@
1
- """Converts a dictionary of channel values to a DataArray."""
2
-
3
- import xarray as xr
4
-
5
-
6
- def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
7
- """Converts a dictionary of channel values to a DataArray."""
8
- return xr.DataArray(
9
- list(channel_dict.values()),
10
- coords={"channel": list(channel_dict.keys())},
11
- )