ocf-data-sampler 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/sample/base.py +34 -3
- {ocf_data_sampler-0.1.6.dist-info → ocf_data_sampler-0.1.7.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.1.6.dist-info → ocf_data_sampler-0.1.7.dist-info}/RECORD +7 -7
- tests/test_sample/test_base.py +63 -2
- {ocf_data_sampler-0.1.6.dist-info → ocf_data_sampler-0.1.7.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.1.6.dist-info → ocf_data_sampler-0.1.7.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.1.6.dist-info → ocf_data_sampler-0.1.7.dist-info}/top_level.txt +0 -0
ocf_data_sampler/sample/base.py
CHANGED
|
@@ -5,25 +5,34 @@ Handling of both flat and nested structures - consideration for NWP
|
|
|
5
5
|
|
|
6
6
|
import logging
|
|
7
7
|
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
import xarray as xr
|
|
8
10
|
|
|
9
11
|
from pathlib import Path
|
|
10
|
-
from typing import Any, Dict, Optional, Union
|
|
12
|
+
from typing import Any, Dict, Optional, Union, TypeAlias
|
|
11
13
|
from abc import ABC, abstractmethod
|
|
12
14
|
|
|
15
|
+
|
|
13
16
|
logger = logging.getLogger(__name__)
|
|
14
17
|
|
|
18
|
+
NumpySample: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
|
|
19
|
+
NumpyBatch: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
|
|
20
|
+
TensorBatch: TypeAlias = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]
|
|
21
|
+
|
|
22
|
+
|
|
15
23
|
class SampleBase(ABC):
|
|
16
24
|
"""
|
|
17
25
|
Abstract base class for all sample types
|
|
18
26
|
Provides core data storage functionality
|
|
19
27
|
"""
|
|
20
28
|
|
|
21
|
-
def __init__(self):
|
|
29
|
+
def __init__(self, data: Optional[Union[NumpySample, xr.Dataset]] = None):
|
|
22
30
|
""" Initialise data container """
|
|
23
31
|
logger.debug("Initialising SampleBase instance")
|
|
32
|
+
self._data = data
|
|
24
33
|
|
|
25
34
|
@abstractmethod
|
|
26
|
-
def to_numpy(self) ->
|
|
35
|
+
def to_numpy(self) -> NumpySample:
|
|
27
36
|
""" Convert data to a numpy array representation """
|
|
28
37
|
raise NotImplementedError
|
|
29
38
|
|
|
@@ -42,3 +51,25 @@ class SampleBase(ABC):
|
|
|
42
51
|
def load(cls, path: Union[str, Path]) -> 'SampleBase':
|
|
43
52
|
""" Abstract class method for loading sample data """
|
|
44
53
|
raise NotImplementedError
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
|
|
57
|
+
"""
|
|
58
|
+
Moves ndarrays in a nested dict to torch tensors
|
|
59
|
+
Args:
|
|
60
|
+
batch: NumpyBatch with data in numpy arrays
|
|
61
|
+
Returns:
|
|
62
|
+
TensorBatch with data in torch tensors
|
|
63
|
+
"""
|
|
64
|
+
if not batch:
|
|
65
|
+
raise ValueError("Cannot convert empty batch to tensors")
|
|
66
|
+
|
|
67
|
+
for k, v in batch.items():
|
|
68
|
+
if isinstance(v, dict):
|
|
69
|
+
batch[k] = batch_to_tensor(v)
|
|
70
|
+
elif isinstance(v, np.ndarray):
|
|
71
|
+
if v.dtype == np.bool_:
|
|
72
|
+
batch[k] = torch.tensor(v, dtype=torch.bool)
|
|
73
|
+
elif np.issubdtype(v.dtype, np.number):
|
|
74
|
+
batch[k] = torch.as_tensor(v)
|
|
75
|
+
return batch
|
|
@@ -27,7 +27,7 @@ ocf_data_sampler/numpy_sample/satellite.py,sha256=8OaTvkPjzSjotcdKsa6BKmmlBKDBun
|
|
|
27
27
|
ocf_data_sampler/numpy_sample/site.py,sha256=I-cAXCOF0SDdm5Hx43lFqYZ3jh61kltLQK-fc4_nNu0,1314
|
|
28
28
|
ocf_data_sampler/numpy_sample/sun_position.py,sha256=UklhucCxCT6GMlAhCWL6c4cfWrdc1cWgegrYaqUoHOY,1611
|
|
29
29
|
ocf_data_sampler/sample/__init__.py,sha256=02CM7E5nKkGiYbVW-kvzjNd4RaqGuHCkDChtmDBDUoA,248
|
|
30
|
-
ocf_data_sampler/sample/base.py,sha256=
|
|
30
|
+
ocf_data_sampler/sample/base.py,sha256=qeKuWyyO8M4QX6QDbItioeCiss0fG05NXRtf0TCMQSc,2246
|
|
31
31
|
ocf_data_sampler/sample/site.py,sha256=0BvDXs0kxTjUq7kWpeoITK_uN4uE0w1IvEFXZUoKOb0,2507
|
|
32
32
|
ocf_data_sampler/sample/uk_regional.py,sha256=D1A6nQB1PYCmxb3FzU9gqbNufQfx__wcprcDm50jCJw,4381
|
|
33
33
|
ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
|
|
@@ -67,14 +67,14 @@ tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM
|
|
|
67
67
|
tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
|
|
68
68
|
tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
|
|
69
69
|
tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4FrI2vtxUw,9686
|
|
70
|
-
tests/test_sample/test_base.py,sha256=
|
|
70
|
+
tests/test_sample/test_base.py,sha256=CkqKCZbrq3Vb4T7bOwPh3_0p8OTl0LfSLNBctYC_jag,4199
|
|
71
71
|
tests/test_sample/test_site_sample.py,sha256=Gln-Or060cUWvA7Q7c1vsthgCttOAM2z9yBI9zUIrDw,6238
|
|
72
72
|
tests/test_sample/test_uk_regional_sample.py,sha256=gkeQWC2wC757jKJz_QBmDMFQjn3R54q_tEo948yyxCY,4840
|
|
73
73
|
tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
|
|
74
74
|
tests/torch_datasets/test_pvnet_uk.py,sha256=loueo7PUUYJVda3-vBn3bQIC_zgrTAThfx-GTDcBOZg,5596
|
|
75
75
|
tests/torch_datasets/test_site.py,sha256=t57vAR_RRWcbG_kEFk6VrFCYzVxwFG6qJKBnRHF02fM,7000
|
|
76
|
-
ocf_data_sampler-0.1.6.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
|
|
77
|
-
ocf_data_sampler-0.1.
|
|
78
|
-
ocf_data_sampler-0.1.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
79
|
-
ocf_data_sampler-0.1.6.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
|
|
80
|
-
ocf_data_sampler-0.1.6.dist-info/RECORD,,
|
|
76
|
+
ocf_data_sampler-0.1.7.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
|
|
77
|
+
ocf_data_sampler-0.1.7.dist-info/METADATA,sha256=8SbL1qjkmeFDYdv1_hHBL9jxbSpt4aFCpx70rEEPeb0,12173
|
|
78
|
+
ocf_data_sampler-0.1.7.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
79
|
+
ocf_data_sampler-0.1.7.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
|
|
80
|
+
ocf_data_sampler-0.1.7.dist-info/RECORD,,
|
tests/test_sample/test_base.py
CHANGED
|
@@ -3,11 +3,14 @@ Base class testing - SampleBase
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
import pytest
|
|
6
|
+
import torch
|
|
6
7
|
import numpy as np
|
|
7
8
|
|
|
8
9
|
from pathlib import Path
|
|
9
|
-
from ocf_data_sampler.sample.base import SampleBase
|
|
10
|
-
|
|
10
|
+
from ocf_data_sampler.sample.base import (
|
|
11
|
+
SampleBase,
|
|
12
|
+
batch_to_tensor
|
|
13
|
+
)
|
|
11
14
|
|
|
12
15
|
class TestSample(SampleBase):
|
|
13
16
|
"""
|
|
@@ -84,3 +87,61 @@ def test_sample_base_to_numpy():
|
|
|
84
87
|
assert isinstance(numpy_data, dict)
|
|
85
88
|
assert all(isinstance(value, np.ndarray) for value in numpy_data.values())
|
|
86
89
|
assert np.array_equal(numpy_data['list_data'], np.array([1, 2, 3]))
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def test_batch_to_tensor_nested():
|
|
93
|
+
""" Test nested dictionary conversion """
|
|
94
|
+
batch = {
|
|
95
|
+
'outer': {
|
|
96
|
+
'inner': np.array([1, 2, 3])
|
|
97
|
+
}
|
|
98
|
+
}
|
|
99
|
+
tensor_batch = batch_to_tensor(batch)
|
|
100
|
+
|
|
101
|
+
assert torch.equal(tensor_batch['outer']['inner'], torch.tensor([1, 2, 3]))
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def test_batch_to_tensor_mixed_types():
|
|
105
|
+
""" Test handling of mixed data types """
|
|
106
|
+
batch = {
|
|
107
|
+
'tensor_data': np.array([1, 2, 3]),
|
|
108
|
+
'string_data': 'not_a_tensor',
|
|
109
|
+
'nested': {
|
|
110
|
+
'numbers': np.array([4, 5, 6]),
|
|
111
|
+
'text': 'still_not_a_tensor'
|
|
112
|
+
}
|
|
113
|
+
}
|
|
114
|
+
tensor_batch = batch_to_tensor(batch)
|
|
115
|
+
|
|
116
|
+
assert isinstance(tensor_batch['tensor_data'], torch.Tensor)
|
|
117
|
+
assert isinstance(tensor_batch['string_data'], str)
|
|
118
|
+
assert isinstance(tensor_batch['nested']['numbers'], torch.Tensor)
|
|
119
|
+
assert isinstance(tensor_batch['nested']['text'], str)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def test_batch_to_tensor_different_dtypes():
|
|
123
|
+
""" Test conversion of arrays with different dtypes """
|
|
124
|
+
batch = {
|
|
125
|
+
'float_data': np.array([1.0, 2.0, 3.0], dtype=np.float32),
|
|
126
|
+
'int_data': np.array([1, 2, 3], dtype=np.int64),
|
|
127
|
+
'bool_data': np.array([True, False, True], dtype=np.bool_)
|
|
128
|
+
}
|
|
129
|
+
tensor_batch = batch_to_tensor(batch)
|
|
130
|
+
|
|
131
|
+
assert isinstance(tensor_batch['bool_data'], torch.Tensor)
|
|
132
|
+
assert tensor_batch['float_data'].dtype == torch.float32
|
|
133
|
+
assert tensor_batch['int_data'].dtype == torch.int64
|
|
134
|
+
assert tensor_batch['bool_data'].dtype == torch.bool
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def test_batch_to_tensor_multidimensional():
|
|
138
|
+
""" Test conversion of multidimensional arrays """
|
|
139
|
+
batch = {
|
|
140
|
+
'matrix': np.array([[1, 2], [3, 4]]),
|
|
141
|
+
'tensor': np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
|
|
142
|
+
}
|
|
143
|
+
tensor_batch = batch_to_tensor(batch)
|
|
144
|
+
|
|
145
|
+
assert tensor_batch['matrix'].shape == (2, 2)
|
|
146
|
+
assert tensor_batch['tensor'].shape == (2, 2, 2)
|
|
147
|
+
assert torch.equal(tensor_batch['matrix'], torch.tensor([[1, 2], [3, 4]]))
|
|
File without changes
|
|
File without changes
|
|
File without changes
|