ocf-data-sampler 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

@@ -5,25 +5,34 @@ Handling of both flat and nested structures - consideration for NWP
5
5
 
6
6
  import logging
7
7
  import numpy as np
8
+ import torch
9
+ import xarray as xr
8
10
 
9
11
  from pathlib import Path
10
- from typing import Any, Dict, Optional, Union
12
+ from typing import Any, Dict, Optional, Union, TypeAlias
11
13
  from abc import ABC, abstractmethod
12
14
 
15
+
13
16
  logger = logging.getLogger(__name__)
14
17
 
18
+ NumpySample: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
19
+ NumpyBatch: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
20
+ TensorBatch: TypeAlias = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]
21
+
22
+
15
23
  class SampleBase(ABC):
16
24
  """
17
25
  Abstract base class for all sample types
18
26
  Provides core data storage functionality
19
27
  """
20
28
 
21
- def __init__(self):
29
+ def __init__(self, data: Optional[Union[NumpySample, xr.Dataset]] = None):
22
30
  """ Initialise data container """
23
31
  logger.debug("Initialising SampleBase instance")
32
+ self._data = data
24
33
 
25
34
  @abstractmethod
26
- def to_numpy(self) -> Dict[str, Any]:
35
+ def to_numpy(self) -> NumpySample:
27
36
  """ Convert data to a numpy array representation """
28
37
  raise NotImplementedError
29
38
 
@@ -42,3 +51,25 @@ class SampleBase(ABC):
42
51
  def load(cls, path: Union[str, Path]) -> 'SampleBase':
43
52
  """ Abstract class method for loading sample data """
44
53
  raise NotImplementedError
54
+
55
+
56
+ def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
57
+ """
58
+ Moves ndarrays in a nested dict to torch tensors
59
+ Args:
60
+ batch: NumpyBatch with data in numpy arrays
61
+ Returns:
62
+ TensorBatch with data in torch tensors
63
+ """
64
+ if not batch:
65
+ raise ValueError("Cannot convert empty batch to tensors")
66
+
67
+ for k, v in batch.items():
68
+ if isinstance(v, dict):
69
+ batch[k] = batch_to_tensor(v)
70
+ elif isinstance(v, np.ndarray):
71
+ if v.dtype == np.bool_:
72
+ batch[k] = torch.tensor(v, dtype=torch.bool)
73
+ elif np.issubdtype(v.dtype, np.number):
74
+ batch[k] = torch.as_tensor(v)
75
+ return batch
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: ocf_data_sampler
3
- Version: 0.1.6
3
+ Version: 0.1.7
4
4
  Summary: Sample from weather data for renewable energy prediction
5
5
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
6
  Author-email: info@openclimatefix.org
@@ -27,7 +27,7 @@ ocf_data_sampler/numpy_sample/satellite.py,sha256=8OaTvkPjzSjotcdKsa6BKmmlBKDBun
27
27
  ocf_data_sampler/numpy_sample/site.py,sha256=I-cAXCOF0SDdm5Hx43lFqYZ3jh61kltLQK-fc4_nNu0,1314
28
28
  ocf_data_sampler/numpy_sample/sun_position.py,sha256=UklhucCxCT6GMlAhCWL6c4cfWrdc1cWgegrYaqUoHOY,1611
29
29
  ocf_data_sampler/sample/__init__.py,sha256=02CM7E5nKkGiYbVW-kvzjNd4RaqGuHCkDChtmDBDUoA,248
30
- ocf_data_sampler/sample/base.py,sha256=4U78tczCRsKMDwU4HkD20nyGyYjIBSZV5neF2mT--2M,1197
30
+ ocf_data_sampler/sample/base.py,sha256=qeKuWyyO8M4QX6QDbItioeCiss0fG05NXRtf0TCMQSc,2246
31
31
  ocf_data_sampler/sample/site.py,sha256=0BvDXs0kxTjUq7kWpeoITK_uN4uE0w1IvEFXZUoKOb0,2507
32
32
  ocf_data_sampler/sample/uk_regional.py,sha256=D1A6nQB1PYCmxb3FzU9gqbNufQfx__wcprcDm50jCJw,4381
33
33
  ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
@@ -67,14 +67,14 @@ tests/select/test_find_contiguous_time_periods.py,sha256=kOga_V7er5We7ewMARXaKdM
67
67
  tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
68
68
  tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
69
69
  tests/select/test_select_time_slice.py,sha256=nYrdlmZlGEygJKiE26bADiluNPN1qt5kD4FrI2vtxUw,9686
70
- tests/test_sample/test_base.py,sha256=ljtB38MmscTGN6OvUgclBceNnfx6m7AN8iHYDml9XW4,2189
70
+ tests/test_sample/test_base.py,sha256=CkqKCZbrq3Vb4T7bOwPh3_0p8OTl0LfSLNBctYC_jag,4199
71
71
  tests/test_sample/test_site_sample.py,sha256=Gln-Or060cUWvA7Q7c1vsthgCttOAM2z9yBI9zUIrDw,6238
72
72
  tests/test_sample/test_uk_regional_sample.py,sha256=gkeQWC2wC757jKJz_QBmDMFQjn3R54q_tEo948yyxCY,4840
73
73
  tests/torch_datasets/test_merge_and_fill_utils.py,sha256=GtuQg82BM1eHQjT7Ik1x1zaVcuc7KJO4_NC9stXsd4s,1123
74
74
  tests/torch_datasets/test_pvnet_uk.py,sha256=loueo7PUUYJVda3-vBn3bQIC_zgrTAThfx-GTDcBOZg,5596
75
75
  tests/torch_datasets/test_site.py,sha256=t57vAR_RRWcbG_kEFk6VrFCYzVxwFG6qJKBnRHF02fM,7000
76
- ocf_data_sampler-0.1.6.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
77
- ocf_data_sampler-0.1.6.dist-info/METADATA,sha256=qltSR8dsD54ufCfXXFFYYLY_l_1saBWGaxwzZDIaJoU,12173
78
- ocf_data_sampler-0.1.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
79
- ocf_data_sampler-0.1.6.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
80
- ocf_data_sampler-0.1.6.dist-info/RECORD,,
76
+ ocf_data_sampler-0.1.7.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
77
+ ocf_data_sampler-0.1.7.dist-info/METADATA,sha256=8SbL1qjkmeFDYdv1_hHBL9jxbSpt4aFCpx70rEEPeb0,12173
78
+ ocf_data_sampler-0.1.7.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
79
+ ocf_data_sampler-0.1.7.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
80
+ ocf_data_sampler-0.1.7.dist-info/RECORD,,
@@ -3,11 +3,14 @@ Base class testing - SampleBase
3
3
  """
4
4
 
5
5
  import pytest
6
+ import torch
6
7
  import numpy as np
7
8
 
8
9
  from pathlib import Path
9
- from ocf_data_sampler.sample.base import SampleBase
10
-
10
+ from ocf_data_sampler.sample.base import (
11
+ SampleBase,
12
+ batch_to_tensor
13
+ )
11
14
 
12
15
  class TestSample(SampleBase):
13
16
  """
@@ -84,3 +87,61 @@ def test_sample_base_to_numpy():
84
87
  assert isinstance(numpy_data, dict)
85
88
  assert all(isinstance(value, np.ndarray) for value in numpy_data.values())
86
89
  assert np.array_equal(numpy_data['list_data'], np.array([1, 2, 3]))
90
+
91
+
92
+ def test_batch_to_tensor_nested():
93
+ """ Test nested dictionary conversion """
94
+ batch = {
95
+ 'outer': {
96
+ 'inner': np.array([1, 2, 3])
97
+ }
98
+ }
99
+ tensor_batch = batch_to_tensor(batch)
100
+
101
+ assert torch.equal(tensor_batch['outer']['inner'], torch.tensor([1, 2, 3]))
102
+
103
+
104
+ def test_batch_to_tensor_mixed_types():
105
+ """ Test handling of mixed data types """
106
+ batch = {
107
+ 'tensor_data': np.array([1, 2, 3]),
108
+ 'string_data': 'not_a_tensor',
109
+ 'nested': {
110
+ 'numbers': np.array([4, 5, 6]),
111
+ 'text': 'still_not_a_tensor'
112
+ }
113
+ }
114
+ tensor_batch = batch_to_tensor(batch)
115
+
116
+ assert isinstance(tensor_batch['tensor_data'], torch.Tensor)
117
+ assert isinstance(tensor_batch['string_data'], str)
118
+ assert isinstance(tensor_batch['nested']['numbers'], torch.Tensor)
119
+ assert isinstance(tensor_batch['nested']['text'], str)
120
+
121
+
122
+ def test_batch_to_tensor_different_dtypes():
123
+ """ Test conversion of arrays with different dtypes """
124
+ batch = {
125
+ 'float_data': np.array([1.0, 2.0, 3.0], dtype=np.float32),
126
+ 'int_data': np.array([1, 2, 3], dtype=np.int64),
127
+ 'bool_data': np.array([True, False, True], dtype=np.bool_)
128
+ }
129
+ tensor_batch = batch_to_tensor(batch)
130
+
131
+ assert isinstance(tensor_batch['bool_data'], torch.Tensor)
132
+ assert tensor_batch['float_data'].dtype == torch.float32
133
+ assert tensor_batch['int_data'].dtype == torch.int64
134
+ assert tensor_batch['bool_data'].dtype == torch.bool
135
+
136
+
137
+ def test_batch_to_tensor_multidimensional():
138
+ """ Test conversion of multidimensional arrays """
139
+ batch = {
140
+ 'matrix': np.array([[1, 2], [3, 4]]),
141
+ 'tensor': np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
142
+ }
143
+ tensor_batch = batch_to_tensor(batch)
144
+
145
+ assert tensor_batch['matrix'].shape == (2, 2)
146
+ assert tensor_batch['tensor'].shape == (2, 2, 2)
147
+ assert torch.equal(tensor_batch['matrix'], torch.tensor([[1, 2], [3, 4]]))