ocf-data-sampler 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

ocf_data_sampler/sample/base.py CHANGED
@@ -73,3 +73,26 @@ def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
73
73
  elif np.issubdtype(v.dtype, np.number):
74
74
  batch[k] = torch.as_tensor(v)
75
75
  return batch
76
+
77
+
78
+ def copy_batch_to_device(batch: dict, device: torch.device) -> dict:
79
+ """
80
+ Moves tensor leaves in a nested dict to a new device.
81
+
82
+ Args:
83
+ batch: Nested dict with tensors to move.
84
+ device: Device to move tensors to.
85
+
86
+ Returns:
87
+ A dict with tensors moved to the new device.
88
+ """
89
+ batch_copy = {}
90
+
91
+ for k, v in batch.items():
92
+ if isinstance(v, dict):
93
+ batch_copy[k] = copy_batch_to_device(v, device)
94
+ elif isinstance(v, torch.Tensor):
95
+ batch_copy[k] = v.to(device)
96
+ else:
97
+ batch_copy[k] = v
98
+ return batch_copy
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py CHANGED
@@ -186,9 +186,8 @@ class PVNetUKRegionalDataset(Dataset):
186
186
  gsp_ids: List of GSP IDs to create samples for. Defaults to all
187
187
  """
188
188
 
189
- config = load_yaml_configuration(config_filename)
190
-
191
- # Validate channels for NWP and satellite data
189
+ # config = load_yaml_configuration(config_filename)
190
+ config: Configuration = load_yaml_configuration(config_filename)
192
191
  validate_nwp_channels(config)
193
192
  validate_satellite_channels(config)
194
193
 
ocf_data_sampler/torch_datasets/datasets/site.py CHANGED
@@ -20,7 +20,6 @@ from ocf_data_sampler.select import (
20
20
  from ocf_data_sampler.utils import minutes
21
21
  from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
22
22
  from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import merge_dicts, fill_nans_in_arrays
23
- from ocf_data_sampler.torch_datasets.utils.validate_channels import validate_nwp_channels
24
23
 
25
24
  from ocf_data_sampler.numpy_sample import (
26
25
  convert_site_to_numpy_sample,
@@ -30,8 +29,12 @@ from ocf_data_sampler.numpy_sample import (
30
29
  make_sun_position_numpy_sample,
31
30
  )
32
31
  from ocf_data_sampler.numpy_sample import NWPSampleKey
33
- from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
32
+ from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
34
33
 
34
+ from ocf_data_sampler.torch_datasets.utils.validate_channels import (
35
+ validate_nwp_channels,
36
+ validate_satellite_channels,
37
+ )
35
38
 
36
39
  xr.set_options(keep_attrs=True)
37
40
 
@@ -52,9 +55,8 @@ class SitesDataset(Dataset):
52
55
  """
53
56
 
54
57
  config: Configuration = load_yaml_configuration(config_filename)
55
-
56
- # Validate NWP channels
57
58
  validate_nwp_channels(config)
59
+ validate_satellite_channels(config)
58
60
 
59
61
  datasets_dict = get_dataset_dict(config.input_data)
60
62
 
@@ -237,8 +239,10 @@ class SitesDataset(Dataset):
237
239
  data_arrays.append((f"nwp-{provider}", da_nwp))
238
240
 
239
241
  if "sat" in dataset_dict:
240
- # TODO add some satellite normalisation
241
242
  da_sat = dataset_dict["sat"]
243
+
244
+ # Standardise
245
+ da_sat = (da_sat - RSS_MEAN) / RSS_STD
242
246
  data_arrays.append(("satellite", da_sat))
243
247
 
244
248
  if "site" in dataset_dict:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: ocf_data_sampler
3
- Version: 0.1.8
3
+ Version: 0.1.10
4
4
  Summary: Sample from weather data for renewable energy prediction
5
5
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
6
  Author-email: info@openclimatefix.org
@@ -27,7 +27,7 @@ ocf_data_sampler/numpy_sample/satellite.py,sha256=8OaTvkPjzSjotcdKsa6BKmmlBKDBun
27
27
  ocf_data_sampler/numpy_sample/site.py,sha256=I-cAXCOF0SDdm5Hx43lFqYZ3jh61kltLQK-fc4_nNu0,1314
28
28
  ocf_data_sampler/numpy_sample/sun_position.py,sha256=UklhucCxCT6GMlAhCWL6c4cfWrdc1cWgegrYaqUoHOY,1611
29
29
  ocf_data_sampler/sample/__init__.py,sha256=02CM7E5nKkGiYbVW-kvzjNd4RaqGuHCkDChtmDBDUoA,248
30
- ocf_data_sampler/sample/base.py,sha256=qeKuWyyO8M4QX6QDbItioeCiss0fG05NXRtf0TCMQSc,2246
30
+ ocf_data_sampler/sample/base.py,sha256=q3wpqoW4JXRmzfar6ed7UMn1nxBxSJXNvMLJmHXy1dw,2856
31
31
  ocf_data_sampler/sample/site.py,sha256=0BvDXs0kxTjUq7kWpeoITK_uN4uE0w1IvEFXZUoKOb0,2507
32
32
  ocf_data_sampler/sample/uk_regional.py,sha256=D1A6nQB1PYCmxb3FzU9gqbNufQfx__wcprcDm50jCJw,4381
33
33
  ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
@@ -41,12 +41,12 @@ ocf_data_sampler/select/select_time_slice.py,sha256=9M-yvDv9K77XfEys_OIR31_aVB56
41
41
  ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=3tRrMBXr7s4CnClbVSIq7hpls3H4Y3qYTDwswcxCCCE,1763
42
42
  ocf_data_sampler/select/time_slice_for_dataset.py,sha256=Z7pOiilSHScxmBKZNG18K5J-S4ifdXXAYGZoHRHD3AY,4324
43
43
  ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=jfJSFcR0eO1AqeH7S3KnGjsBqVZT5w3oyi784PUR6Q0,146
44
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=N85duDyEm6LIYgYIpLhrpxHddMIcvFosuZg8rzIztwE,12267
45
- ocf_data_sampler/torch_datasets/datasets/site.py,sha256=L_4w967ZxPjd7vHRkPtj7ZSmamEShKRT28j9_f-enJY,16228
44
+ ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=ZgfvVCcEU3dj3RoY0zdBdKGppC7Wm81qecqB17gYTmE,12286
45
+ ocf_data_sampler/torch_datasets/datasets/site.py,sha256=_uHmqg-VJu-MHgXc5JFDX1noPfH6E8nY4XhQmsrOav4,16325
46
46
  ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
47
47
  ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
48
48
  ocf_data_sampler/torch_datasets/utils/validate_channels.py,sha256=u2EpiFAKAOHpmvINhOUJCT8Vbc-cle6qJ3YNVse4yLs,2884
49
- scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
49
+ scripts/refactor_site.py,sha256=xaJGxt2_WObIPrPAnRiOMMB68r-5Q51jWRx409AcscM,1747
50
50
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
51
51
  tests/conftest.py,sha256=RlC7YYtBLipUzFS1tQxela1SgHCxSpReUKEJ4429PwQ,7689
52
52
  tests/config/test_config.py,sha256=VQjNiucIk5VnPQdGA6Mr-RNd9CwGI06AiikChTHrcnY,3969
@@ -68,15 +68,15 @@ tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM
68
68
  tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
69
69
  tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
70
70
  tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4FrI2vtxUw,9686
71
- tests/test_sample/test_base.py,sha256=CkqKCZbrq3Vb4T7bOwPh3_0p8OTl0LfSLNBctYC_jag,4199
71
+ tests/test_sample/test_base.py,sha256=sD9NZghYQWbkAcQP9YXypWZowqYkO3xeNMH-_mEoD5I,4833
72
72
  tests/test_sample/test_site_sample.py,sha256=Gln-Or060cUWvA7Q7c1vsthgCttOAM2z9yBI9zUIrDw,6238
73
73
  tests/test_sample/test_uk_regional_sample.py,sha256=gkeQWC2wC757jKJz_QBmDMFQjn3R54q_tEo948yyxCY,4840
74
74
  tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
75
75
  tests/torch_datasets/test_pvnet_uk.py,sha256=F0D-DugFgVtt8G1q7lylmPLrOZj6H6YPNd9s_6Wn_yM,5594
76
76
  tests/torch_datasets/test_site.py,sha256=t57vAR_RRWcbG_kEFk6VrFCYzVxwFG6qJKBnRHF02fM,7000
77
77
  tests/torch_datasets/test_validate_channels_utils.py,sha256=Rzdweu98j1of45jCOUrSiBtyPlf-dDaCceulf0H7ml8,2921
78
- ocf_data_sampler-0.1.8.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
79
- ocf_data_sampler-0.1.8.dist-info/METADATA,sha256=hWohmy0-J73u-uy3MPEG0_tuprAXOh32hX8WyIDPqaU,12173
80
- ocf_data_sampler-0.1.8.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
81
- ocf_data_sampler-0.1.8.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
82
- ocf_data_sampler-0.1.8.dist-info/RECORD,,
78
+ ocf_data_sampler-0.1.10.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
79
+ ocf_data_sampler-0.1.10.dist-info/METADATA,sha256=HDEoz2xG-Qw23Rz2Wcms5_w1p3hbrRK1MBIJjwA6WrA,12174
80
+ ocf_data_sampler-0.1.10.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
81
+ ocf_data_sampler-0.1.10.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
82
+ ocf_data_sampler-0.1.10.dist-info/RECORD,,
scripts/refactor_site.py CHANGED
@@ -1,5 +1,5 @@
1
1
  """ Helper functions for refactoring legacy site data """
2
-
2
+ import xarray as xr
3
3
 
4
4
  def legacy_format(data_ds, metadata_df):
5
5
  """This formats old legacy data to the new format.
tests/test_sample/test_base.py CHANGED
@@ -9,7 +9,8 @@ import numpy as np
9
9
  from pathlib import Path
10
10
  from ocf_data_sampler.sample.base import (
11
11
  SampleBase,
12
- batch_to_tensor
12
+ batch_to_tensor,
13
+ copy_batch_to_device
13
14
  )
14
15
 
15
16
  class TestSample(SampleBase):
@@ -145,3 +146,19 @@ def test_batch_to_tensor_multidimensional():
145
146
  assert tensor_batch['matrix'].shape == (2, 2)
146
147
  assert tensor_batch['tensor'].shape == (2, 2, 2)
147
148
  assert torch.equal(tensor_batch['matrix'], torch.tensor([[1, 2], [3, 4]]))
149
+
150
+
151
+ def test_copy_batch_to_device():
152
+ """ Test moving tensors to a different device """
153
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154
+ batch = {
155
+ 'tensor_data': torch.tensor([1, 2, 3]),
156
+ 'nested': {
157
+ 'matrix': torch.tensor([[1, 2], [3, 4]])
158
+ },
159
+ 'non_tensor': 'unchanged'
160
+ }
161
+ moved_batch = copy_batch_to_device(batch, device)
162
+ assert moved_batch['tensor_data'].device == device
163
+ assert moved_batch['nested']['matrix'].device == device
164
+ assert moved_batch['non_tensor'] == 'unchanged' # Non-tensors should remain unchanged