ocf-data-sampler 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/load/gsp.py +6 -3
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +14 -23
- ocf_data_sampler/torch_datasets/datasets/site.py +29 -15
- ocf_data_sampler/torch_datasets/utils/__init__.py +1 -1
- ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py +57 -0
- {ocf_data_sampler-0.5.1.dist-info → ocf_data_sampler-0.5.3.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.5.1.dist-info → ocf_data_sampler-0.5.3.dist-info}/RECORD +9 -9
- ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py +0 -18
- {ocf_data_sampler-0.5.1.dist-info → ocf_data_sampler-0.5.3.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.5.1.dist-info → ocf_data_sampler-0.5.3.dist-info}/top_level.txt +0 -0
ocf_data_sampler/load/gsp.py
CHANGED
|
@@ -52,9 +52,12 @@ def open_gsp(
|
|
|
52
52
|
backend_kwargs = {"storage_options": {"anon": True}}
|
|
53
53
|
# Currently only compatible with S3 bucket.
|
|
54
54
|
|
|
55
|
-
ds = xr.open_dataset(
|
|
56
|
-
|
|
57
|
-
|
|
55
|
+
ds = xr.open_dataset(
|
|
56
|
+
zarr_path,
|
|
57
|
+
engine="zarr",
|
|
58
|
+
chunks=None,
|
|
59
|
+
backend_kwargs=backend_kwargs,
|
|
60
|
+
).rename({"datetime_gmt": "time_utc"})
|
|
58
61
|
|
|
59
62
|
if not (ds.gsp_id.isin(df_gsp_loc.index)).all():
|
|
60
63
|
raise ValueError(
|
|
@@ -21,7 +21,7 @@ from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
|
|
|
21
21
|
from ocf_data_sampler.select import Location, fill_time_periods
|
|
22
22
|
from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
|
|
23
23
|
from ocf_data_sampler.torch_datasets.utils import (
|
|
24
|
-
|
|
24
|
+
config_normalization_values_to_dicts,
|
|
25
25
|
find_valid_time_periods,
|
|
26
26
|
slice_datasets_by_space,
|
|
27
27
|
slice_datasets_by_time,
|
|
@@ -110,11 +110,14 @@ class AbstractPVNetUKDataset(Dataset):
|
|
|
110
110
|
self.config = config
|
|
111
111
|
self.datasets_dict = datasets_dict
|
|
112
112
|
|
|
113
|
+
# Extract the normalisation values from the config for faster access
|
|
114
|
+
means_dict, stds_dict = config_normalization_values_to_dicts(config)
|
|
115
|
+
self.means_dict = means_dict
|
|
116
|
+
self.stds_dict = stds_dict
|
|
113
117
|
|
|
114
|
-
@staticmethod
|
|
115
118
|
def process_and_combine_datasets(
|
|
119
|
+
self,
|
|
116
120
|
dataset_dict: dict,
|
|
117
|
-
config: Configuration,
|
|
118
121
|
t0: pd.Timestamp,
|
|
119
122
|
location: Location,
|
|
120
123
|
) -> NumpySample:
|
|
@@ -122,7 +125,6 @@ class AbstractPVNetUKDataset(Dataset):
|
|
|
122
125
|
|
|
123
126
|
Args:
|
|
124
127
|
dataset_dict: Dictionary of xarray datasets
|
|
125
|
-
config: Configuration object
|
|
126
128
|
t0: init-time for sample
|
|
127
129
|
location: location of the sample
|
|
128
130
|
"""
|
|
@@ -134,13 +136,8 @@ class AbstractPVNetUKDataset(Dataset):
|
|
|
134
136
|
for nwp_key, da_nwp in dataset_dict["nwp"].items():
|
|
135
137
|
|
|
136
138
|
# Standardise and convert to NumpyBatch
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
config.input_data.nwp[nwp_key].channel_means,
|
|
140
|
-
)
|
|
141
|
-
da_channel_stds = channel_dict_to_dataarray(
|
|
142
|
-
config.input_data.nwp[nwp_key].channel_stds,
|
|
143
|
-
)
|
|
139
|
+
da_channel_means = self.means_dict["nwp"][nwp_key]
|
|
140
|
+
da_channel_stds = self.stds_dict["nwp"][nwp_key]
|
|
144
141
|
|
|
145
142
|
da_nwp = (da_nwp - da_channel_means) / da_channel_stds
|
|
146
143
|
|
|
@@ -153,15 +150,15 @@ class AbstractPVNetUKDataset(Dataset):
|
|
|
153
150
|
da_sat = dataset_dict["sat"]
|
|
154
151
|
|
|
155
152
|
# Standardise and convert to NumpyBatch
|
|
156
|
-
da_channel_means =
|
|
157
|
-
da_channel_stds =
|
|
153
|
+
da_channel_means = self.means_dict["sat"]
|
|
154
|
+
da_channel_stds = self.stds_dict["sat"]
|
|
158
155
|
|
|
159
156
|
da_sat = (da_sat - da_channel_means) / da_channel_stds
|
|
160
157
|
|
|
161
158
|
numpy_modalities.append(convert_satellite_to_numpy_sample(da_sat))
|
|
162
159
|
|
|
163
160
|
if "gsp" in dataset_dict:
|
|
164
|
-
gsp_config = config.input_data.gsp
|
|
161
|
+
gsp_config = self.config.input_data.gsp
|
|
165
162
|
da_gsp = dataset_dict["gsp"]
|
|
166
163
|
da_gsp = da_gsp / da_gsp.effective_capacity_mwp
|
|
167
164
|
|
|
@@ -183,13 +180,8 @@ class AbstractPVNetUKDataset(Dataset):
|
|
|
183
180
|
)
|
|
184
181
|
|
|
185
182
|
# Only add solar position if explicitly configured
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
config.input_data.solar_position is not None
|
|
189
|
-
)
|
|
190
|
-
|
|
191
|
-
if has_solar_config:
|
|
192
|
-
solar_config = config.input_data.solar_position
|
|
183
|
+
if self.config.input_data.solar_position is not None:
|
|
184
|
+
solar_config = self.config.input_data.solar_position
|
|
193
185
|
|
|
194
186
|
# Create datetime range for solar position calculation
|
|
195
187
|
datetimes = pd.date_range(
|
|
@@ -264,7 +256,7 @@ class PVNetUKRegionalDataset(AbstractPVNetUKDataset):
|
|
|
264
256
|
sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
|
|
265
257
|
sample_dict = compute(sample_dict)
|
|
266
258
|
|
|
267
|
-
return self.process_and_combine_datasets(sample_dict,
|
|
259
|
+
return self.process_and_combine_datasets(sample_dict, t0, location)
|
|
268
260
|
|
|
269
261
|
@override
|
|
270
262
|
def __getitem__(self, idx: int) -> NumpySample:
|
|
@@ -330,7 +322,6 @@ class PVNetUKConcurrentDataset(AbstractPVNetUKDataset):
|
|
|
330
322
|
gsp_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
|
|
331
323
|
gsp_numpy_sample = self.process_and_combine_datasets(
|
|
332
324
|
gsp_sample_dict,
|
|
333
|
-
self.config,
|
|
334
325
|
t0,
|
|
335
326
|
location,
|
|
336
327
|
)
|
|
@@ -25,7 +25,7 @@ from ocf_data_sampler.select import (
|
|
|
25
25
|
intersection_of_multiple_dataframes_of_periods,
|
|
26
26
|
)
|
|
27
27
|
from ocf_data_sampler.torch_datasets.utils import (
|
|
28
|
-
|
|
28
|
+
config_normalization_values_to_dicts,
|
|
29
29
|
find_valid_time_periods,
|
|
30
30
|
slice_datasets_by_space,
|
|
31
31
|
slice_datasets_by_time,
|
|
@@ -62,6 +62,8 @@ def process_and_combine_datasets(
|
|
|
62
62
|
dataset_dict: dict,
|
|
63
63
|
config: Configuration,
|
|
64
64
|
t0: pd.Timestamp,
|
|
65
|
+
means_dict: dict[str, xr.DataArray | dict[str, xr.DataArray]],
|
|
66
|
+
stds_dict: dict[str, xr.DataArray | dict[str, xr.DataArray]],
|
|
65
67
|
) -> NumpySample:
|
|
66
68
|
"""Normalise and convert data to numpy arrays.
|
|
67
69
|
|
|
@@ -69,6 +71,8 @@ def process_and_combine_datasets(
|
|
|
69
71
|
dataset_dict: Dictionary of xarray datasets
|
|
70
72
|
config: Configuration object
|
|
71
73
|
t0: init-time for sample
|
|
74
|
+
means_dict: Nested dictionary of mean values for the input data sources
|
|
75
|
+
stds_dict: Nested dictionary of std values for the input data sources
|
|
72
76
|
"""
|
|
73
77
|
numpy_modalities = []
|
|
74
78
|
|
|
@@ -79,12 +83,8 @@ def process_and_combine_datasets(
|
|
|
79
83
|
|
|
80
84
|
# Standardise and convert to NumpyBatch
|
|
81
85
|
|
|
82
|
-
da_channel_means =
|
|
83
|
-
|
|
84
|
-
)
|
|
85
|
-
da_channel_stds = channel_dict_to_dataarray(
|
|
86
|
-
config.input_data.nwp[nwp_key].channel_stds,
|
|
87
|
-
)
|
|
86
|
+
da_channel_means = means_dict["nwp"][nwp_key]
|
|
87
|
+
da_channel_stds = stds_dict["nwp"][nwp_key]
|
|
88
88
|
|
|
89
89
|
da_nwp = (da_nwp - da_channel_means) / da_channel_stds
|
|
90
90
|
|
|
@@ -97,8 +97,8 @@ def process_and_combine_datasets(
|
|
|
97
97
|
da_sat = dataset_dict["sat"]
|
|
98
98
|
|
|
99
99
|
# Standardise and convert to NumpyBatch
|
|
100
|
-
da_channel_means =
|
|
101
|
-
da_channel_stds =
|
|
100
|
+
da_channel_means = means_dict["sat"]
|
|
101
|
+
da_channel_stds = stds_dict["sat"]
|
|
102
102
|
|
|
103
103
|
da_sat = (da_sat - da_channel_means) / da_channel_stds
|
|
104
104
|
|
|
@@ -109,11 +109,7 @@ def process_and_combine_datasets(
|
|
|
109
109
|
da_sites = da_sites / da_sites.capacity_kwp
|
|
110
110
|
|
|
111
111
|
# Convert to NumpyBatch
|
|
112
|
-
numpy_modalities.append(
|
|
113
|
-
convert_site_to_numpy_sample(
|
|
114
|
-
da_sites,
|
|
115
|
-
),
|
|
116
|
-
)
|
|
112
|
+
numpy_modalities.append(convert_site_to_numpy_sample(da_sites))
|
|
117
113
|
|
|
118
114
|
# add datetime features
|
|
119
115
|
datetimes = pd.DatetimeIndex(da_sites.time_utc.values)
|
|
@@ -193,6 +189,11 @@ class SitesDataset(Dataset):
|
|
|
193
189
|
# Assign coords and indices to self
|
|
194
190
|
self.valid_t0_and_site_ids = valid_t0_and_site_ids
|
|
195
191
|
|
|
192
|
+
# Extract the normalisation values from the config for faster access
|
|
193
|
+
means_dict, stds_dict = config_normalization_values_to_dicts(config)
|
|
194
|
+
self.means_dict = means_dict
|
|
195
|
+
self.stds_dict = stds_dict
|
|
196
|
+
|
|
196
197
|
def find_valid_t0_and_site_ids(
|
|
197
198
|
self,
|
|
198
199
|
datasets_dict: dict,
|
|
@@ -273,7 +274,13 @@ class SitesDataset(Dataset):
|
|
|
273
274
|
|
|
274
275
|
sample_dict = compute(sample_dict)
|
|
275
276
|
|
|
276
|
-
return process_and_combine_datasets(
|
|
277
|
+
return process_and_combine_datasets(
|
|
278
|
+
sample_dict,
|
|
279
|
+
self.config,
|
|
280
|
+
t0,
|
|
281
|
+
self.means_dict,
|
|
282
|
+
self.stds_dict,
|
|
283
|
+
)
|
|
277
284
|
|
|
278
285
|
def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict:
|
|
279
286
|
"""Generate a sample for a given site id and t0.
|
|
@@ -332,6 +339,11 @@ class SitesDatasetConcurrent(Dataset):
|
|
|
332
339
|
# Assign coords and indices to self
|
|
333
340
|
self.valid_t0s = valid_t0s
|
|
334
341
|
|
|
342
|
+
# Extract the normalisation values from the config for faster access
|
|
343
|
+
means_dict, stds_dict = config_normalization_values_to_dicts(config)
|
|
344
|
+
self.means_dict = means_dict
|
|
345
|
+
self.stds_dict = stds_dict
|
|
346
|
+
|
|
335
347
|
def find_valid_t0s(
|
|
336
348
|
self,
|
|
337
349
|
datasets_dict: dict,
|
|
@@ -406,6 +418,8 @@ class SitesDatasetConcurrent(Dataset):
|
|
|
406
418
|
site_sample_dict,
|
|
407
419
|
self.config,
|
|
408
420
|
t0,
|
|
421
|
+
self.means_dict,
|
|
422
|
+
self.stds_dict,
|
|
409
423
|
)
|
|
410
424
|
site_samples.append(site_numpy_sample)
|
|
411
425
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from .
|
|
1
|
+
from .config_normalization_values_to_dicts import config_normalization_values_to_dicts
|
|
2
2
|
from .merge_and_fill_utils import fill_nans_in_arrays, merge_dicts
|
|
3
3
|
from .valid_time_periods import find_valid_time_periods
|
|
4
4
|
from .spatial_slice_for_dataset import slice_datasets_by_space
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Utility function for converting channel dictionaries to xarray DataArrays."""
|
|
2
|
+
|
|
3
|
+
import xarray as xr
|
|
4
|
+
|
|
5
|
+
from ocf_data_sampler.config import Configuration
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
    """Build a DataArray from a mapping of channel names to values.

    Args:
        channel_dict: Mapping from channel name (str) to its value (float).

    Returns:
        xr.DataArray: The mapping's values, in the mapping's iteration order,
            with the channel names attached as the ``channel`` coordinate.
    """
    channel_names = list(channel_dict.keys())
    channel_values = list(channel_dict.values())
    return xr.DataArray(channel_values, coords={"channel": channel_names})
|
|
21
|
+
|
|
22
|
+
def config_normalization_values_to_dicts(
    config: Configuration,
) -> tuple[
    dict[str, xr.DataArray | dict[str, xr.DataArray]],
    dict[str, xr.DataArray | dict[str, xr.DataArray]],
]:
    """Construct DataArrays of mean and std values from the config normalisation constants.

    Building the DataArrays here lets a caller create them once and reuse them,
    rather than converting the config dictionaries again for every sample.

    Args:
        config: Data configuration.

    Returns:
        A ``(means_dict, stds_dict)`` pair with identical layout:

        - ``"nwp"``: present only when ``config.input_data.nwp`` is not None;
          maps each NWP key in the config to that source's per-channel DataArray.
        - ``"sat"``: present only when ``config.input_data.satellite`` is not
          None; the satellite per-channel DataArray.
    """
    means_dict = {}
    stds_dict = {}

    if config.input_data.nwp is not None:
        means_dict["nwp"] = {}
        stds_dict["nwp"] = {}

        for nwp_key in config.input_data.nwp:
            # Look the source config up once and reuse it for both constants
            nwp_config = config.input_data.nwp[nwp_key]
            means_dict["nwp"][nwp_key] = channel_dict_to_dataarray(nwp_config.channel_means)
            stds_dict["nwp"][nwp_key] = channel_dict_to_dataarray(nwp_config.channel_stds)

    if config.input_data.satellite is not None:
        satellite_config = config.input_data.satellite
        means_dict["sat"] = channel_dict_to_dataarray(satellite_config.channel_means)
        stds_dict["sat"] = channel_dict_to_dataarray(satellite_config.channel_stds)

    return means_dict, stds_dict
|
|
@@ -7,7 +7,7 @@ ocf_data_sampler/config/save.py,sha256=m8SPw5rXjkMm1rByjh3pK5StdBi4e8ysnn3jQopdR
|
|
|
7
7
|
ocf_data_sampler/data/uk_gsp_locations_20220314.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
|
|
8
8
|
ocf_data_sampler/data/uk_gsp_locations_20250109.csv,sha256=XZISFatnbpO9j8LwaxNKFzQSjs6hcHFsV8a9uDDpy2E,9055334
|
|
9
9
|
ocf_data_sampler/load/__init__.py,sha256=-vQP9g0UOWdVbjEGyVX_ipa7R1btmiETIKAf6aw4d78,201
|
|
10
|
-
ocf_data_sampler/load/gsp.py,sha256=
|
|
10
|
+
ocf_data_sampler/load/gsp.py,sha256=d30jQWnwFaLj6rKNMHdz1qD8fzF8q--RNnEXT7bGiX0,2981
|
|
11
11
|
ocf_data_sampler/load/load_dataset.py,sha256=K8rWykjII-3g127If7WRRFivzHNx3SshCvZj4uQlf28,2089
|
|
12
12
|
ocf_data_sampler/load/open_tensorstore_zarrs.py,sha256=_RHWe0GmrBSA9s1TH5I9VCMPpeZEsuRuhDt5Vyyx5Fo,2725
|
|
13
13
|
ocf_data_sampler/load/satellite.py,sha256=RylkJz8avxdM5pK_liaTlD1DTboyPMgykXJ4_Ek9WBA,1840
|
|
@@ -40,14 +40,14 @@ ocf_data_sampler/select/location.py,sha256=AZvGR8y62opiW7zACGXjoOtBEWRfSLOZIA73O
|
|
|
40
40
|
ocf_data_sampler/select/select_spatial_slice.py,sha256=Hd4jGRUfIZRoWCirOQZeoLpaUnStB6KyFSTPX69wZLw,8790
|
|
41
41
|
ocf_data_sampler/select/select_time_slice.py,sha256=HeHbwZ0CP03x0-LaJtpbSdtpLufwVTR73p6wH6O_PS8,5513
|
|
42
42
|
ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=o0SsEXXZ6k9iL__5_RN1Sf60lw_eqK91P3UFEHAD2k0,102
|
|
43
|
-
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=
|
|
44
|
-
ocf_data_sampler/torch_datasets/datasets/site.py,sha256=
|
|
43
|
+
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=876oLukvb1nLtZQ8HBN3PWfN7urKH2xa45tVar7XrbM,12010
|
|
44
|
+
ocf_data_sampler/torch_datasets/datasets/site.py,sha256=nn6N8daGxllYwCCiFKbCJANTl84NrDRl-nbNGcfXc3U,15429
|
|
45
45
|
ocf_data_sampler/torch_datasets/sample/__init__.py,sha256=GL84vdZl_SjHDGVyh9Uekx2XhPYuZ0dnO3l6f6KXnHI,100
|
|
46
46
|
ocf_data_sampler/torch_datasets/sample/base.py,sha256=cQ1oIyhdmlotejZK8B3Cw6MNvpdnBPD8G_o2h7Ye4Vc,2206
|
|
47
47
|
ocf_data_sampler/torch_datasets/sample/site.py,sha256=40NwNTqjL1WVhPdwe02zDHHfDLG2u_bvCfRCtGAtFc0,1466
|
|
48
48
|
ocf_data_sampler/torch_datasets/sample/uk_regional.py,sha256=Xx5cBYUyaM6PGUWQ76MHT9hwj6IJ7WAOxbpmYFbJGhc,10483
|
|
49
|
-
ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=
|
|
50
|
-
ocf_data_sampler/torch_datasets/utils/
|
|
49
|
+
ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=_UHLL_yRzhLJVHi6ROSaSe8TGw80CAhU325uCZj7XkY,331
|
|
50
|
+
ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py,sha256=jS3DkAwOF1W3AQnvsdkBJ1C8Unm93kQbS8hgTCtFv2A,1743
|
|
51
51
|
ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=we7BTxRH7B7jKayDT7YfNyfI3zZClz2Bk-HXKQIokgU,956
|
|
52
52
|
ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py,sha256=Hvz0wHSWMYYamf2oHNiGlzJcM4cAH6pL_7ZEvIBL2dE,1882
|
|
53
53
|
ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py,sha256=8E4a5v9dqr-sZOyBruuO-tjLPBbjtpYtdFY5z23aqnU,4365
|
|
@@ -56,7 +56,7 @@ ocf_data_sampler/torch_datasets/utils/validation_utils.py,sha256=YqmT-lExWlI8_ul
|
|
|
56
56
|
scripts/download_gsp_location_data.py,sha256=rRDXMoqX-RYY4jPdxhdlxJGhWdl6r245F5UARgKV6P4,3121
|
|
57
57
|
scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
|
|
58
58
|
utils/compute_icon_mean_stddev.py,sha256=a1oWMRMnny39rV-dvu8rcx85sb4bXzPFrR1gkUr4Jpg,2296
|
|
59
|
-
ocf_data_sampler-0.5.
|
|
60
|
-
ocf_data_sampler-0.5.
|
|
61
|
-
ocf_data_sampler-0.5.
|
|
62
|
-
ocf_data_sampler-0.5.
|
|
59
|
+
ocf_data_sampler-0.5.3.dist-info/METADATA,sha256=9gg1K9SNIX6pJ-PXQptutiLU9fo7FsnrKM6vdHbpQYg,12580
|
|
60
|
+
ocf_data_sampler-0.5.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
61
|
+
ocf_data_sampler-0.5.3.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
|
|
62
|
+
ocf_data_sampler-0.5.3.dist-info/RECORD,,
|
|
@@ -1,18 +0,0 @@
|
|
|
1
|
-
"""Utility function for converting channel dictionaries to xarray DataArrays."""
|
|
2
|
-
|
|
3
|
-
import xarray as xr
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
def channel_dict_to_dataarray(channel_dict: dict[str, float]) -> xr.DataArray:
|
|
7
|
-
"""Converts a dictionary of channel values to a DataArray.
|
|
8
|
-
|
|
9
|
-
Args:
|
|
10
|
-
channel_dict: Dictionary mapping channel names (str) to their values (float).
|
|
11
|
-
|
|
12
|
-
Returns:
|
|
13
|
-
xr.DataArray: A 1D DataArray with channels as coordinates.
|
|
14
|
-
"""
|
|
15
|
-
return xr.DataArray(
|
|
16
|
-
list(channel_dict.values()),
|
|
17
|
-
coords={"channel": list(channel_dict.keys())},
|
|
18
|
-
)
|
|
File without changes
|
|
File without changes
|