ocf-data-sampler 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

@@ -73,3 +73,26 @@ def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
73
73
  elif np.issubdtype(v.dtype, np.number):
74
74
  batch[k] = torch.as_tensor(v)
75
75
  return batch
76
+
77
+
78
def copy_batch_to_device(batch: dict, device: torch.device) -> dict:
    """
    Moves tensor leaves in a nested dict to a new device.

    The input dict is never mutated: a fresh dict is built at every nesting
    level. Values that are neither dicts nor tensors are placed in the result
    as the very same object (they are not copied).

    Note that `torch.Tensor.to` returns the original tensor when it already
    lives on the target device, so in that case the returned dict shares
    tensor objects with the input.

    Args:
        batch: Nested dict with tensors to move.
        device: Device to move tensors to.

    Returns:
        A dict with tensors moved to the new device.
    """
    batch_copy = {}

    for k, v in batch.items():
        if isinstance(v, dict):
            # Recurse so tensors at any nesting depth are moved
            batch_copy[k] = copy_batch_to_device(v, device)
        elif isinstance(v, torch.Tensor):
            batch_copy[k] = v.to(device)
        else:
            # Non-tensor leaves (e.g. strings) are passed through as-is
            batch_copy[k] = v
    return batch_copy
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: ocf_data_sampler
3
- Version: 0.1.8
3
+ Version: 0.1.9
4
4
  Summary: Sample from weather data for renewable energy prediction
5
5
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
6
  Author-email: info@openclimatefix.org
@@ -27,7 +27,7 @@ ocf_data_sampler/numpy_sample/satellite.py,sha256=8OaTvkPjzSjotcdKsa6BKmmlBKDBun
27
27
  ocf_data_sampler/numpy_sample/site.py,sha256=I-cAXCOF0SDdm5Hx43lFqYZ3jh61kltLQK-fc4_nNu0,1314
28
28
  ocf_data_sampler/numpy_sample/sun_position.py,sha256=UklhucCxCT6GMlAhCWL6c4cfWrdc1cWgegrYaqUoHOY,1611
29
29
  ocf_data_sampler/sample/__init__.py,sha256=02CM7E5nKkGiYbVW-kvzjNd4RaqGuHCkDChtmDBDUoA,248
30
- ocf_data_sampler/sample/base.py,sha256=qeKuWyyO8M4QX6QDbItioeCiss0fG05NXRtf0TCMQSc,2246
30
+ ocf_data_sampler/sample/base.py,sha256=q3wpqoW4JXRmzfar6ed7UMn1nxBxSJXNvMLJmHXy1dw,2856
31
31
  ocf_data_sampler/sample/site.py,sha256=0BvDXs0kxTjUq7kWpeoITK_uN4uE0w1IvEFXZUoKOb0,2507
32
32
  ocf_data_sampler/sample/uk_regional.py,sha256=D1A6nQB1PYCmxb3FzU9gqbNufQfx__wcprcDm50jCJw,4381
33
33
  ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
@@ -46,7 +46,7 @@ ocf_data_sampler/torch_datasets/datasets/site.py,sha256=L_4w967ZxPjd7vHRkPtj7ZSm
46
46
  ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
47
47
  ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
48
48
  ocf_data_sampler/torch_datasets/utils/validate_channels.py,sha256=u2EpiFAKAOHpmvINhOUJCT8Vbc-cle6qJ3YNVse4yLs,2884
49
- scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
49
+ scripts/refactor_site.py,sha256=xaJGxt2_WObIPrPAnRiOMMB68r-5Q51jWRx409AcscM,1747
50
50
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
51
51
  tests/conftest.py,sha256=RlC7YYtBLipUzFS1tQxela1SgHCxSpReUKEJ4429PwQ,7689
52
52
  tests/config/test_config.py,sha256=VQjNiucIk5VnPQdGA6Mr-RNd9CwGI06AiikChTHrcnY,3969
@@ -68,15 +68,15 @@ tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM
68
68
  tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
69
69
  tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
70
70
  tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4FrI2vtxUw,9686
71
- tests/test_sample/test_base.py,sha256=CkqKCZbrq3Vb4T7bOwPh3_0p8OTl0LfSLNBctYC_jag,4199
71
+ tests/test_sample/test_base.py,sha256=sD9NZghYQWbkAcQP9YXypWZowqYkO3xeNMH-_mEoD5I,4833
72
72
  tests/test_sample/test_site_sample.py,sha256=Gln-Or060cUWvA7Q7c1vsthgCttOAM2z9yBI9zUIrDw,6238
73
73
  tests/test_sample/test_uk_regional_sample.py,sha256=gkeQWC2wC757jKJz_QBmDMFQjn3R54q_tEo948yyxCY,4840
74
74
  tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
75
75
  tests/torch_datasets/test_pvnet_uk.py,sha256=F0D-DugFgVtt8G1q7lylmPLrOZj6H6YPNd9s_6Wn_yM,5594
76
76
  tests/torch_datasets/test_site.py,sha256=t57vAR_RRWcbG_kEFk6VrFCYzVxwFG6qJKBnRHF02fM,7000
77
77
  tests/torch_datasets/test_validate_channels_utils.py,sha256=Rzdweu98j1of45jCOUrSiBtyPlf-dDaCceulf0H7ml8,2921
78
- ocf_data_sampler-0.1.8.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
79
- ocf_data_sampler-0.1.8.dist-info/METADATA,sha256=hWohmy0-J73u-uy3MPEG0_tuprAXOh32hX8WyIDPqaU,12173
80
- ocf_data_sampler-0.1.8.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
81
- ocf_data_sampler-0.1.8.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
82
- ocf_data_sampler-0.1.8.dist-info/RECORD,,
78
+ ocf_data_sampler-0.1.9.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
79
+ ocf_data_sampler-0.1.9.dist-info/METADATA,sha256=Lfu8Yrj4CSlqPzGhk0iDy5r5zCLd5REnGAlVcFuKuow,12173
80
+ ocf_data_sampler-0.1.9.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
81
+ ocf_data_sampler-0.1.9.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
82
+ ocf_data_sampler-0.1.9.dist-info/RECORD,,
scripts/refactor_site.py CHANGED
@@ -1,5 +1,5 @@
1
1
  """ Helper functions for refactoring legacy site data """
2
-
2
+ import xarray as xr
3
3
 
4
4
  def legacy_format(data_ds, metadata_df):
5
5
  """This formats old legacy data to the new format.
@@ -9,7 +9,8 @@ import numpy as np
9
9
  from pathlib import Path
10
10
  from ocf_data_sampler.sample.base import (
11
11
  SampleBase,
12
- batch_to_tensor
12
+ batch_to_tensor,
13
+ copy_batch_to_device
13
14
  )
14
15
 
15
16
  class TestSample(SampleBase):
@@ -145,3 +146,19 @@ def test_batch_to_tensor_multidimensional():
145
146
  assert tensor_batch['matrix'].shape == (2, 2)
146
147
  assert tensor_batch['tensor'].shape == (2, 2, 2)
147
148
  assert torch.equal(tensor_batch['matrix'], torch.tensor([[1, 2], [3, 4]]))
149
+
150
+
151
def test_copy_batch_to_device():
    """ Test moving tensors to a different device """
    # Use a GPU when one is present; otherwise this is a CPU -> CPU move
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = {
        'tensor_data': torch.tensor([1, 2, 3]),
        'nested': {
            'matrix': torch.tensor([[1, 2], [3, 4]])
        },
        'non_tensor': 'unchanged'
    }
    moved_batch = copy_batch_to_device(batch, device)
    # Compare device *types*: torch.device("cuda") carries no index, while a
    # tensor moved to it reports "cuda:0", so a direct == would fail on GPU
    assert moved_batch['tensor_data'].device.type == device.type
    assert moved_batch['nested']['matrix'].device.type == device.type
    assert moved_batch['non_tensor'] == 'unchanged'  # Non-tensors should remain unchanged