ocf-data-sampler 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/sample/base.py +23 -0
- {ocf_data_sampler-0.1.8.dist-info → ocf_data_sampler-0.1.9.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.1.8.dist-info → ocf_data_sampler-0.1.9.dist-info}/RECORD +8 -8
- scripts/refactor_site.py +1 -1
- tests/test_sample/test_base.py +18 -1
- {ocf_data_sampler-0.1.8.dist-info → ocf_data_sampler-0.1.9.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.1.8.dist-info → ocf_data_sampler-0.1.9.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.1.8.dist-info → ocf_data_sampler-0.1.9.dist-info}/top_level.txt +0 -0
ocf_data_sampler/sample/base.py
CHANGED
|
@@ -73,3 +73,26 @@ def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
|
|
|
73
73
|
elif np.issubdtype(v.dtype, np.number):
|
|
74
74
|
batch[k] = torch.as_tensor(v)
|
|
75
75
|
return batch
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def copy_batch_to_device(batch: dict, device: torch.device) -> dict:
    """
    Build a copy of a nested dict with every tensor leaf placed on ``device``.

    Sub-dicts are rebuilt recursively, tensors are sent through
    ``Tensor.to(device)``, and every other value is carried over by reference.
    The input dict itself is never modified. Note that ``Tensor.to`` returns the
    original tensor object (no data copy) when it is already on ``device``.

    Args:
        batch: Nested dict with tensors to move.
        device: Device to move tensors to.

    Returns:
        A new dict with the same keys, in the same order, whose tensor leaves
        live on ``device``.
    """

    def _relocate(value):
        # Dicts are tested first so nested containers are always rebuilt
        # rather than shared with the caller's batch.
        if isinstance(value, dict):
            return copy_batch_to_device(value, device)
        if isinstance(value, torch.Tensor):
            return value.to(device)
        return value

    return {key: _relocate(value) for key, value in batch.items()}
|
|
@@ -27,7 +27,7 @@ ocf_data_sampler/numpy_sample/satellite.py,sha256=8OaTvkPjzSjotcdKsa6BKmmlBKDBun
|
|
|
27
27
|
ocf_data_sampler/numpy_sample/site.py,sha256=I-cAXCOF0SDdm5Hx43lFqYZ3jh61kltLQK-fc4_nNu0,1314
|
|
28
28
|
ocf_data_sampler/numpy_sample/sun_position.py,sha256=UklhucCxCT6GMlAhCWL6c4cfWrdc1cWgegrYaqUoHOY,1611
|
|
29
29
|
ocf_data_sampler/sample/__init__.py,sha256=02CM7E5nKkGiYbVW-kvzjNd4RaqGuHCkDChtmDBDUoA,248
|
|
30
|
-
ocf_data_sampler/sample/base.py,sha256=
|
|
30
|
+
ocf_data_sampler/sample/base.py,sha256=q3wpqoW4JXRmzfar6ed7UMn1nxBxSJXNvMLJmHXy1dw,2856
|
|
31
31
|
ocf_data_sampler/sample/site.py,sha256=0BvDXs0kxTjUq7kWpeoITK_uN4uE0w1IvEFXZUoKOb0,2507
|
|
32
32
|
ocf_data_sampler/sample/uk_regional.py,sha256=D1A6nQB1PYCmxb3FzU9gqbNufQfx__wcprcDm50jCJw,4381
|
|
33
33
|
ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
|
|
@@ -46,7 +46,7 @@ ocf_data_sampler/torch_datasets/datasets/site.py,sha256=L_4w967ZxPjd7vHRkPtj7ZSm
|
|
|
46
46
|
ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=hIbekql64eXsNDFIoEc--GWxwdVWrh2qKegdOi70Bow,874
|
|
47
47
|
ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=Qo65qUHtle_bW5tLTYr7empHTRv-lpjvfx_6GNJj3Xg,4371
|
|
48
48
|
ocf_data_sampler/torch_datasets/utils/validate_channels.py,sha256=u2EpiFAKAOHpmvINhOUJCT8Vbc-cle6qJ3YNVse4yLs,2884
|
|
49
|
-
scripts/refactor_site.py,sha256=
|
|
49
|
+
scripts/refactor_site.py,sha256=xaJGxt2_WObIPrPAnRiOMMB68r-5Q51jWRx409AcscM,1747
|
|
50
50
|
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
51
51
|
tests/conftest.py,sha256=RlC7YYtBLipUzFS1tQxela1SgHCxSpReUKEJ4429PwQ,7689
|
|
52
52
|
tests/config/test_config.py,sha256=VQjNiucIk5VnPQdGA6Mr-RNd9CwGI06AiikChTHrcnY,3969
|
|
@@ -68,15 +68,15 @@ tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM
|
|
|
68
68
|
tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
|
|
69
69
|
tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
|
|
70
70
|
tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4FrI2vtxUw,9686
|
|
71
|
-
tests/test_sample/test_base.py,sha256=
|
|
71
|
+
tests/test_sample/test_base.py,sha256=sD9NZghYQWbkAcQP9YXypWZowqYkO3xeNMH-_mEoD5I,4833
|
|
72
72
|
tests/test_sample/test_site_sample.py,sha256=Gln-Or060cUWvA7Q7c1vsthgCttOAM2z9yBI9zUIrDw,6238
|
|
73
73
|
tests/test_sample/test_uk_regional_sample.py,sha256=gkeQWC2wC757jKJz_QBmDMFQjn3R54q_tEo948yyxCY,4840
|
|
74
74
|
tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
|
|
75
75
|
tests/torch_datasets/test_pvnet_uk.py,sha256=F0D-DugFgVtt8G1q7lylmPLrOZj6H6YPNd9s_6Wn_yM,5594
|
|
76
76
|
tests/torch_datasets/test_site.py,sha256=t57vAR_RRWcbG_kEFk6VrFCYzVxwFG6qJKBnRHF02fM,7000
|
|
77
77
|
tests/torch_datasets/test_validate_channels_utils.py,sha256=Rzdweu98j1of45jCOUrSiBtyPlf-dDaCceulf0H7ml8,2921
|
|
78
|
-
ocf_data_sampler-0.1.
|
|
79
|
-
ocf_data_sampler-0.1.
|
|
80
|
-
ocf_data_sampler-0.1.
|
|
81
|
-
ocf_data_sampler-0.1.
|
|
82
|
-
ocf_data_sampler-0.1.
|
|
78
|
+
ocf_data_sampler-0.1.9.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
|
|
79
|
+
ocf_data_sampler-0.1.9.dist-info/METADATA,sha256=Lfu8Yrj4CSlqPzGhk0iDy5r5zCLd5REnGAlVcFuKuow,12173
|
|
80
|
+
ocf_data_sampler-0.1.9.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
81
|
+
ocf_data_sampler-0.1.9.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
|
|
82
|
+
ocf_data_sampler-0.1.9.dist-info/RECORD,,
|
scripts/refactor_site.py
CHANGED
tests/test_sample/test_base.py
CHANGED
|
@@ -9,7 +9,8 @@ import numpy as np
|
|
|
9
9
|
from pathlib import Path
|
|
10
10
|
from ocf_data_sampler.sample.base import (
|
|
11
11
|
SampleBase,
|
|
12
|
-
batch_to_tensor
|
|
12
|
+
batch_to_tensor,
|
|
13
|
+
copy_batch_to_device
|
|
13
14
|
)
|
|
14
15
|
|
|
15
16
|
class TestSample(SampleBase):
|
|
@@ -145,3 +146,19 @@ def test_batch_to_tensor_multidimensional():
|
|
|
145
146
|
assert tensor_batch['matrix'].shape == (2, 2)
|
|
146
147
|
assert tensor_batch['tensor'].shape == (2, 2, 2)
|
|
147
148
|
assert torch.equal(tensor_batch['matrix'], torch.tensor([[1, 2], [3, 4]]))
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def test_copy_batch_to_device():
    """Test moving tensors to a different device while preserving values and non-tensors."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = {
        'tensor_data': torch.tensor([1, 2, 3]),
        'nested': {
            'matrix': torch.tensor([[1, 2], [3, 4]])
        },
        'non_tensor': 'unchanged'
    }
    moved_batch = copy_batch_to_device(batch, device)

    # Compare device *types*, not devices: torch.device("cuda") has index None,
    # whereas a tensor moved there reports "cuda:0", and those two devices do
    # not compare equal. A direct `==` would fail on every GPU machine.
    assert moved_batch['tensor_data'].device.type == device.type
    assert moved_batch['nested']['matrix'].device.type == device.type
    assert moved_batch['non_tensor'] == 'unchanged'  # Non-tensors should remain unchanged

    # Tensor values must survive the move (bring them back to CPU to compare).
    assert torch.equal(moved_batch['tensor_data'].cpu(), batch['tensor_data'])
    assert torch.equal(moved_batch['nested']['matrix'].cpu(), batch['nested']['matrix'])

    # The function builds new dicts rather than mutating the input batch.
    assert moved_batch is not batch
    assert moved_batch['nested'] is not batch['nested']
|
|
File without changes
|
|
File without changes
|
|
File without changes
|