ocf-data-sampler 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.

This version of ocf-data-sampler has been flagged as potentially problematic.

--- a/ocf_data_sampler/load/gsp.py
+++ b/ocf_data_sampler/load/gsp.py
@@ -52,9 +52,12 @@ def open_gsp(
     backend_kwargs = {"storage_options": {"anon": True}}
     # Currently only compatible with S3 bucket.

-    ds = xr.open_dataset(zarr_path, engine="zarr", backend_kwargs=backend_kwargs).rename(
-        {"datetime_gmt": "time_utc"},
-    )
+    ds = xr.open_dataset(
+        zarr_path,
+        engine="zarr",
+        chunks=None,
+        backend_kwargs=backend_kwargs,
+    ).rename({"datetime_gmt": "time_utc"})

     if not (ds.gsp_id.isin(df_gsp_loc.index)).all():
         raise ValueError(
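
Note on the change above: in xarray, chunks=None opens the zarr store without wrapping the variables in dask arrays, whereas chunks={} (or "auto") returns dask-backed, lazily chunked variables. A minimal sketch of the difference, assuming a hypothetical local store at "example.zarr" with a datetime_gmt coordinate (the real loader points at an S3 bucket and passes anonymous-access storage options):

    import xarray as xr

    zarr_path = "example.zarr"  # hypothetical local store

    # chunks=None: variables are plain (non-dask) arrays; no dask task graph is built.
    ds_plain = xr.open_dataset(zarr_path, engine="zarr", chunks=None)

    # chunks={}: variables are dask arrays using the on-disk chunking.
    ds_dask = xr.open_dataset(zarr_path, engine="zarr", chunks={})

    # The loader then renames the time coordinate to the project-wide name.
    ds_plain = ds_plain.rename({"datetime_gmt": "time_utc"})
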
--- a/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
+++ b/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
@@ -21,7 +21,7 @@ from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
 from ocf_data_sampler.select import Location, fill_time_periods
 from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
 from ocf_data_sampler.torch_datasets.utils import (
-    channel_dict_to_dataarray,
+    config_normalization_values_to_dicts,
     find_valid_time_periods,
     slice_datasets_by_space,
     slice_datasets_by_time,
@@ -110,11 +110,14 @@ class AbstractPVNetUKDataset(Dataset):
         self.config = config
         self.datasets_dict = datasets_dict

+        # Extract the normalisation values from the config for faster access
+        means_dict, stds_dict = config_normalization_values_to_dicts(config)
+        self.means_dict = means_dict
+        self.stds_dict = stds_dict

-    @staticmethod
     def process_and_combine_datasets(
+        self,
         dataset_dict: dict,
-        config: Configuration,
         t0: pd.Timestamp,
         location: Location,
     ) -> NumpySample:
@@ -122,7 +125,6 @@ class AbstractPVNetUKDataset(Dataset):

         Args:
             dataset_dict: Dictionary of xarray datasets
-            config: Configuration object
             t0: init-time for sample
             location: location of the sample
         """
@@ -134,13 +136,8 @@ class AbstractPVNetUKDataset(Dataset):
            for nwp_key, da_nwp in dataset_dict["nwp"].items():

                # Standardise and convert to NumpyBatch
-
-                da_channel_means = channel_dict_to_dataarray(
-                    config.input_data.nwp[nwp_key].channel_means,
-                )
-                da_channel_stds = channel_dict_to_dataarray(
-                    config.input_data.nwp[nwp_key].channel_stds,
-                )
+                da_channel_means = self.means_dict["nwp"][nwp_key]
+                da_channel_stds = self.stds_dict["nwp"][nwp_key]

                da_nwp = (da_nwp - da_channel_means) / da_channel_stds

@@ -153,15 +150,15 @@ class AbstractPVNetUKDataset(Dataset):
            da_sat = dataset_dict["sat"]

            # Standardise and convert to NumpyBatch
-            da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
-            da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
+            da_channel_means = self.means_dict["sat"]
+            da_channel_stds = self.stds_dict["sat"]

            da_sat = (da_sat - da_channel_means) / da_channel_stds

            numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))

        if "gsp" in dataset_dict:
-            gsp_config = config.input_data.gsp
+            gsp_config = self.config.input_data.gsp
            da_gsp = dataset_dict["gsp"]
            da_gsp = da_gsp / da_gsp.effective_capacity_mwp

@@ -183,13 +180,8 @@ class AbstractPVNetUKDataset(Dataset):
            )

        # Only add solar position if explicitly configured
-        has_solar_config = (
-            hasattr(config.input_data, "solar_position") and
-            config.input_data.solar_position is not None
-        )
-
-        if has_solar_config:
-            solar_config = config.input_data.solar_position
+        if self.config.input_data.solar_position is not None:
+            solar_config = self.config.input_data.solar_position

            # Create datetime range for solar position calculation
            datetimes = pd.date_range(
@@ -264,7 +256,7 @@ class PVNetUKRegionalDataset(AbstractPVNetUKDataset):
        sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
        sample_dict = compute(sample_dict)

-        return self.process_and_combine_datasets(sample_dict, self.config, t0, location)
+        return self.process_and_combine_datasets(sample_dict, t0, location)

    @override
    def __getitem__(self, idx: int) -> NumpySample:
@@ -330,7 +322,6 @@ class PVNetUKConcurrentDataset(AbstractPVNetUKDataset):
            gsp_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
            gsp_numpy_sample = self.process_and_combine_datasets(
                gsp_sample_dict,
-                self.config,
                t0,
                location,
            )
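
The pvnet_uk.py changes above turn process_and_combine_datasets from a static method into an instance method so it can reuse per-channel normalisation DataArrays built once in __init__, rather than rebuilding them from the config for every sample. A minimal sketch of this caching pattern, using a simplified stand-in for the package's Configuration object and made-up channel names:

    import xarray as xr

    def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
        # Same idea as the package utility: map channel name -> value into a 1D DataArray.
        return xr.DataArray(
            list(channel_dict.values()),
            coords={"channel": list(channel_dict.keys())},
        )

    class DatasetSketch:
        """Illustrative only: cache normalisation constants at construction time."""

        def __init__(self, channel_means: dict[str, float], channel_stds: dict[str, float]):
            # Built once here and reused for every sample drawn from the dataset.
            self.da_means = channel_dict_to_dataarray(channel_means)
            self.da_stds = channel_dict_to_dataarray(channel_stds)

        def normalise(self, da: xr.DataArray) -> xr.DataArray:
            # Per-sample work is now plain arithmetic; no DataArray construction.
            return (da - self.da_means) / self.da_stds

    ds = DatasetSketch({"t2m": 283.0, "dswrf": 150.0}, {"t2m": 10.0, "dswrf": 80.0})
    sample = xr.DataArray(
        [[280.0, 120.0]],
        dims=("time", "channel"),
        coords={"channel": ["t2m", "dswrf"]},
    )
    print(ds.normalise(sample))
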
--- a/ocf_data_sampler/torch_datasets/datasets/site.py
+++ b/ocf_data_sampler/torch_datasets/datasets/site.py
@@ -25,7 +25,7 @@ from ocf_data_sampler.select import (
    intersection_of_multiple_dataframes_of_periods,
 )
 from ocf_data_sampler.torch_datasets.utils import (
-    channel_dict_to_dataarray,
+    config_normalization_values_to_dicts,
    find_valid_time_periods,
    slice_datasets_by_space,
    slice_datasets_by_time,
@@ -62,6 +62,8 @@ def process_and_combine_datasets(
    dataset_dict: dict,
    config: Configuration,
    t0: pd.Timestamp,
+    means_dict: dict[str, xr.DataArray | dict[str, xr.DataArray]],
+    stds_dict: dict[str, xr.DataArray | dict[str, xr.DataArray]],
 ) -> NumpySample:
    """Normalise and convert data to numpy arrays.

@@ -69,6 +71,8 @@ def process_and_combine_datasets(
        dataset_dict: Dictionary of xarray datasets
        config: Configuration object
        t0: init-time for sample
+        means_dict: Nested dictionary of mean values for the input data sources
+        stds_dict: Nested dictionary of std values for the input data sources
    """
    numpy_modalities = []

@@ -79,12 +83,8 @@ def process_and_combine_datasets(

            # Standardise and convert to NumpyBatch

-            da_channel_means = channel_dict_to_dataarray(
-                config.input_data.nwp[nwp_key].channel_means,
-            )
-            da_channel_stds = channel_dict_to_dataarray(
-                config.input_data.nwp[nwp_key].channel_stds,
-            )
+            da_channel_means = means_dict["nwp"][nwp_key]
+            da_channel_stds = stds_dict["nwp"][nwp_key]

            da_nwp = (da_nwp - da_channel_means) / da_channel_stds

@@ -97,8 +97,8 @@ def process_and_combine_datasets(
        da_sat = dataset_dict["sat"]

        # Standardise and convert to NumpyBatch
-        da_channel_means = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
-        da_channel_stds = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
+        da_channel_means = means_dict["sat"]
+        da_channel_stds = stds_dict["sat"]

        da_sat = (da_sat - da_channel_means) / da_channel_stds

@@ -109,11 +109,7 @@ def process_and_combine_datasets(
    da_sites = da_sites / da_sites.capacity_kwp

    # Convert to NumpyBatch
-    numpy_modalities.append(
-        convert_site_to_numpy_sample(
-            da_sites,
-        ),
-    )
+    numpy_modalities.append(convert_site_to_numpy_sample(da_sites))

    # add datetime features
    datetimes = pd.DatetimeIndex(da_sites.time_utc.values)
@@ -193,6 +189,11 @@ class SitesDataset(Dataset):
        # Assign coords and indices to self
        self.valid_t0_and_site_ids = valid_t0_and_site_ids

+        # Extract the normalisation values from the config for faster access
+        means_dict, stds_dict = config_normalization_values_to_dicts(config)
+        self.means_dict = means_dict
+        self.stds_dict = stds_dict
+
    def find_valid_t0_and_site_ids(
        self,
        datasets_dict: dict,
@@ -273,7 +274,13 @@ class SitesDataset(Dataset):

        sample_dict = compute(sample_dict)

-        return process_and_combine_datasets(sample_dict, self.config, t0)
+        return process_and_combine_datasets(
+            sample_dict,
+            self.config,
+            t0,
+            self.means_dict,
+            self.stds_dict,
+        )

    def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict:
        """Generate a sample for a given site id and t0.
@@ -332,6 +339,11 @@ class SitesDatasetConcurrent(Dataset):
        # Assign coords and indices to self
        self.valid_t0s = valid_t0s

+        # Extract the normalisation values from the config for faster access
+        means_dict, stds_dict = config_normalization_values_to_dicts(config)
+        self.means_dict = means_dict
+        self.stds_dict = stds_dict
+
    def find_valid_t0s(
        self,
        datasets_dict: dict,
@@ -406,6 +418,8 @@ class SitesDatasetConcurrent(Dataset):
                site_sample_dict,
                self.config,
                t0,
+                self.means_dict,
+                self.stds_dict,
            )
            site_samples.append(site_numpy_sample)

--- a/ocf_data_sampler/torch_datasets/utils/__init__.py
+++ b/ocf_data_sampler/torch_datasets/utils/__init__.py
@@ -1,4 +1,4 @@
-from .channel_dict_to_dataarray import channel_dict_to_dataarray
+from .config_normalization_values_to_dicts import config_normalization_values_to_dicts
 from .merge_and_fill_utils import fill_nans_in_arrays, merge_dicts
 from .valid_time_periods import find_valid_time_periods
 from .spatial_slice_for_dataset import slice_datasets_by_space
--- /dev/null
+++ b/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py
@@ -0,0 +1,57 @@
+"""Utility function for converting channel dictionaries to xarray DataArrays."""
+
+import xarray as xr
+
+from ocf_data_sampler.config import Configuration
+
+
+def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
+    """Converts a dictionary of channel values to a DataArray.
+
+    Args:
+        channel_dict: Dictionary mapping channel names (str) to their values (float).
+
+    Returns:
+        xr.DataArray: A 1D DataArray with channels as coordinates.
+    """
+    return xr.DataArray(
+        list(channel_dict.values()),
+        coords={"channel": list(channel_dict.keys())},
+    )
+
+def config_normalization_values_to_dicts(
+    config: Configuration,
+) -> tuple[dict[str, xr.DataArray | dict[str, xr.DataArray]]]:
+    """Construct DataArrays of mean and std values from the config normalisation constants.
+
+    Args:
+        config: Data configuration.
+
+    Returns:
+        Means dict
+        Stds dict
+    """
+    means_dict = {}
+    stds_dict = {}
+
+    if config.input_data.nwp is not None:
+
+        means_dict["nwp"] = {}
+        stds_dict["nwp"] = {}
+
+        for nwp_key in config.input_data.nwp:
+            # Standardise and convert to NumpyBatch
+
+            means_dict["nwp"][nwp_key] = channel_dict_to_dataarray(
+                config.input_data.nwp[nwp_key].channel_means,
+            )
+            stds_dict["nwp"][nwp_key] = channel_dict_to_dataarray(
+                config.input_data.nwp[nwp_key].channel_stds,
+            )
+
+    if config.input_data.satellite is not None:
+
+        means_dict["sat"] = channel_dict_to_dataarray(config.input_data.satellite.channel_means)
+        stds_dict["sat"] = channel_dict_to_dataarray(config.input_data.satellite.channel_stds)
+
+    return means_dict, stds_dict
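
The new module above keeps channel_dict_to_dataarray and adds config_normalization_values_to_dicts, which walks the configuration once and returns two nested dicts of 1D DataArrays: keyed by "nwp" (then by NWP source) and by "sat". A rough usage sketch, using SimpleNamespace stand-ins for the real pydantic Configuration and made-up source/channel names:

    from types import SimpleNamespace

    from ocf_data_sampler.torch_datasets.utils import config_normalization_values_to_dicts

    # Hypothetical stand-in mirroring only the attributes the helper reads;
    # in practice `config` is a Configuration model loaded from YAML.
    config = SimpleNamespace(
        input_data=SimpleNamespace(
            nwp={
                "ukv": SimpleNamespace(
                    channel_means={"t2m": 283.0, "dswrf": 150.0},
                    channel_stds={"t2m": 10.0, "dswrf": 80.0},
                ),
            },
            satellite=SimpleNamespace(
                channel_means={"IR_016": 0.1},
                channel_stds={"IR_016": 0.2},
            ),
        ),
    )

    means_dict, stds_dict = config_normalization_values_to_dicts(config)
    # means_dict["nwp"]["ukv"] and stds_dict["nwp"]["ukv"] are DataArrays indexed by "channel";
    # means_dict["sat"] and stds_dict["sat"] cover the satellite channels.
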
--- a/ocf_data_sampler-0.5.1.dist-info/METADATA
+++ b/ocf_data_sampler-0.5.3.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ocf-data-sampler
-Version: 0.5.1
+Version: 0.5.3
 Author: James Fulton, Peter Dudfield
 Author-email: Open Climate Fix team <info@openclimatefix.org>
 License: MIT License
--- a/ocf_data_sampler-0.5.1.dist-info/RECORD
+++ b/ocf_data_sampler-0.5.3.dist-info/RECORD
@@ -7,7 +7,7 @@ ocf_data_sampler/config/save.py,sha256=m8SPw5rXjkMm1rByjh3pK5StdBi4e8ysnn3jQopdR
 ocf_data_sampler/data/uk_gsp_locations_20220314.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
 ocf_data_sampler/data/uk_gsp_locations_20250109.csv,sha256=XZISFatnbpO9j8LwaxNKFzQSjs6hcHFsV8a9uDDpy2E,9055334
 ocf_data_sampler/load/__init__.py,sha256=-vQP9g0UOWdVbjEGyVX_ipa7R1btmiETIKAf6aw4d78,201
-ocf_data_sampler/load/gsp.py,sha256=IrTA6z9quN08imKGHJLf8gRktarxn1-utNMNFD0zWQs,2944
+ocf_data_sampler/load/gsp.py,sha256=d30jQWnwFaLj6rKNMHdz1qD8fzF8q--RNnEXT7bGiX0,2981
 ocf_data_sampler/load/load_dataset.py,sha256=K8rWykjII-3g127If7WRRFivzHNx3SshCvZj4uQlf28,2089
 ocf_data_sampler/load/open_tensorstore_zarrs.py,sha256=_RHWe0GmrBSA9s1TH5I9VCMPpeZEsuRuhDt5Vyyx5Fo,2725
 ocf_data_sampler/load/satellite.py,sha256=RylkJz8avxdM5pK_liaTlD1DTboyPMgykXJ4_Ek9WBA,1840
@@ -40,14 +40,14 @@ ocf_data_sampler/select/location.py,sha256=AZvGR8y62opiW7zACGXjoOtBEWRfSLOZIA73O
 ocf_data_sampler/select/select_spatial_slice.py,sha256=Hd4jGRUfIZRoWCirOQZeoLpaUnStB6KyFSTPX69wZLw,8790
 ocf_data_sampler/select/select_time_slice.py,sha256=HeHbwZ0CP03x0-LaJtpbSdtpLufwVTR73p6wH6O_PS8,5513
 ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=o0SsEXXZ6k9iL__5_RN1Sf60lw_eqK91P3UFEHAD2k0,102
-ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=v63goKEMI6UgBPnQCnIbxhFFdwuP_sxgcPYY6iNfGkc,12257
-ocf_data_sampler/torch_datasets/datasets/site.py,sha256=_0A2kRq8B5WL5zWjKxNY9snAl_GwptohUt7c6DDa2AA,14812
+ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=876oLukvb1nLtZQ8HBN3PWfN7urKH2xa45tVar7XrbM,12010
+ocf_data_sampler/torch_datasets/datasets/site.py,sha256=nn6N8daGxllYwCCiFKbCJANTl84NrDRl-nbNGcfXc3U,15429
 ocf_data_sampler/torch_datasets/sample/__init__.py,sha256=GL84vdZl_SjHDGVyh9Uekx2XhPYuZ0dnO3l6f6KXnHI,100
 ocf_data_sampler/torch_datasets/sample/base.py,sha256=cQ1oIyhdmlotejZK8B3Cw6MNvpdnBPD8G_o2h7Ye4Vc,2206
 ocf_data_sampler/torch_datasets/sample/site.py,sha256=40NwNTqjL1WVhPdwe02zDHHfDLG2u_bvCfRCtGAtFc0,1466
 ocf_data_sampler/torch_datasets/sample/uk_regional.py,sha256=Xx5cBYUyaM6PGUWQ76MHT9hwj6IJ7WAOxbpmYFbJGhc,10483
-ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=N7i_hHtWUDiJqsiJoDx4T_QuiYOuvIyulPrn6xEA4TY,309
-ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py,sha256=un2IiyoAmTDIymdeMiPU899_86iCDMD-oIifjHlNyqw,555
+ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=_UHLL_yRzhLJVHi6ROSaSe8TGw80CAhU325uCZj7XkY,331
+ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py,sha256=jS3DkAwOF1W3AQnvsdkBJ1C8Unm93kQbS8hgTCtFv2A,1743
 ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=we7BTxRH7B7jKayDT7YfNyfI3zZClz2Bk-HXKQIokgU,956
 ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py,sha256=Hvz0wHSWMYYamf2oHNiGlzJcM4cAH6pL_7ZEvIBL2dE,1882
 ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py,sha256=8E4a5v9dqr-sZOyBruuO-tjLPBbjtpYtdFY5z23aqnU,4365
@@ -56,7 +56,7 @@ ocf_data_sampler/torch_datasets/utils/validation_utils.py,sha256=YqmT-lExWlI8_ul
 scripts/download_gsp_location_data.py,sha256=rRDXMoqX-RYY4jPdxhdlxJGhWdl6r245F5UARgKV6P4,3121
 scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
 utils/compute_icon_mean_stddev.py,sha256=a1oWMRMnny39rV-dvu8rcx85sb4bXzPFrR1gkUr4Jpg,2296
-ocf_data_sampler-0.5.1.dist-info/METADATA,sha256=sd5ucgDgrjrwa8vImToOUdU3BCWM-fMSsDHTS51p4Zc,12580
-ocf_data_sampler-0.5.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-ocf_data_sampler-0.5.1.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
-ocf_data_sampler-0.5.1.dist-info/RECORD,,
+ocf_data_sampler-0.5.3.dist-info/METADATA,sha256=9gg1K9SNIX6pJ-PXQptutiLU9fo7FsnrKM6vdHbpQYg,12580
+ocf_data_sampler-0.5.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ocf_data_sampler-0.5.3.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
+ocf_data_sampler-0.5.3.dist-info/RECORD,,
--- a/ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""Utility function for converting channel dictionaries to xarray DataArrays."""
-
-import xarray as xr
-
-
-def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
-    """Converts a dictionary of channel values to a DataArray.
-
-    Args:
-        channel_dict: Dictionary mapping channel names (str) to their values (float).
-
-    Returns:
-        xr.DataArray: A 1D DataArray with channels as coordinates.
-    """
-    return xr.DataArray(
-        list(channel_dict.values()),
-        coords={"channel": list(channel_dict.keys())},
-    )