ocf-data-sampler 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic. Click here for more details.
- ocf_data_sampler/config/model.py +25 -23
- ocf_data_sampler/load/satellite.py +21 -29
- ocf_data_sampler/load/site.py +1 -1
- ocf_data_sampler/numpy_sample/gsp.py +6 -2
- ocf_data_sampler/numpy_sample/nwp.py +7 -13
- ocf_data_sampler/numpy_sample/satellite.py +11 -8
- ocf_data_sampler/numpy_sample/site.py +6 -2
- ocf_data_sampler/numpy_sample/sun_position.py +9 -10
- ocf_data_sampler/sample/__init__.py +0 -7
- ocf_data_sampler/sample/base.py +16 -35
- ocf_data_sampler/sample/site.py +28 -65
- ocf_data_sampler/sample/uk_regional.py +52 -97
- ocf_data_sampler/select/dropout.py +38 -25
- ocf_data_sampler/select/fill_time_periods.py +3 -1
- ocf_data_sampler/select/find_contiguous_time_periods.py +0 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.11.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.11.dist-info}/RECORD +27 -27
- tests/config/test_config.py +3 -3
- tests/conftest.py +33 -0
- tests/numpy_sample/test_nwp.py +3 -42
- tests/select/test_dropout.py +7 -13
- tests/test_sample/test_site_sample.py +5 -35
- tests/test_sample/test_uk_regional_sample.py +8 -35
- tests/torch_datasets/test_pvnet_uk.py +6 -19
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.11.dist-info}/LICENSE +0 -0
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.11.dist-info}/WHEEL +0 -0
- {ocf_data_sampler-0.1.10.dist-info → ocf_data_sampler-0.1.11.dist-info}/top_level.txt +0 -0
ocf_data_sampler/config/model.py
CHANGED
|
@@ -49,31 +49,34 @@ class TimeWindowMixin(Base):
|
|
|
49
49
|
...,
|
|
50
50
|
description="Data interval ends at `t0 + interval_end_minutes`",
|
|
51
51
|
)
|
|
52
|
-
|
|
52
|
+
|
|
53
53
|
@model_validator(mode='after')
|
|
54
|
-
def
|
|
55
|
-
|
|
56
|
-
|
|
54
|
+
def validate_intervals(cls, values):
|
|
55
|
+
start = values.interval_start_minutes
|
|
56
|
+
end = values.interval_end_minutes
|
|
57
|
+
resolution = values.time_resolution_minutes
|
|
58
|
+
if start > end:
|
|
59
|
+
raise ValueError(
|
|
60
|
+
f"interval_start_minutes ({start}) must be <= interval_end_minutes ({end})"
|
|
61
|
+
)
|
|
62
|
+
if (start % resolution != 0):
|
|
63
|
+
raise ValueError(
|
|
64
|
+
f"interval_start_minutes ({start}) must be divisible "
|
|
65
|
+
f"by time_resolution_minutes ({resolution})"
|
|
66
|
+
)
|
|
67
|
+
if (end % resolution != 0):
|
|
68
|
+
raise ValueError(
|
|
69
|
+
f"interval_end_minutes ({end}) must be divisible "
|
|
70
|
+
f"by time_resolution_minutes ({resolution})"
|
|
71
|
+
)
|
|
57
72
|
return values
|
|
58
73
|
|
|
59
|
-
@field_validator("interval_start_minutes")
|
|
60
|
-
def interval_start_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
|
|
61
|
-
if v % info.data["time_resolution_minutes"] != 0:
|
|
62
|
-
raise ValueError("interval_start_minutes must be divisible by time_resolution_minutes")
|
|
63
|
-
return v
|
|
64
|
-
|
|
65
|
-
@field_validator("interval_end_minutes")
|
|
66
|
-
def interval_end_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
|
|
67
|
-
if v % info.data["time_resolution_minutes"] != 0:
|
|
68
|
-
raise ValueError("interval_end_minutes must be divisible by time_resolution_minutes")
|
|
69
|
-
return v
|
|
70
|
-
|
|
71
74
|
|
|
72
75
|
class DropoutMixin(Base):
|
|
73
76
|
"""Mixin class, to add dropout minutes"""
|
|
74
77
|
|
|
75
|
-
dropout_timedeltas_minutes:
|
|
76
|
-
default=
|
|
78
|
+
dropout_timedeltas_minutes: List[int] = Field(
|
|
79
|
+
default=[],
|
|
77
80
|
description="List of possible minutes before t0 where data availability may start. Must be "
|
|
78
81
|
"negative or zero.",
|
|
79
82
|
)
|
|
@@ -88,18 +91,17 @@ class DropoutMixin(Base):
|
|
|
88
91
|
@field_validator("dropout_timedeltas_minutes")
|
|
89
92
|
def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
|
|
90
93
|
"""Validate 'dropout_timedeltas_minutes'"""
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
assert m <= 0, "Dropout timedeltas must be negative"
|
|
94
|
+
for m in v:
|
|
95
|
+
assert m <= 0, "Dropout timedeltas must be negative"
|
|
94
96
|
return v
|
|
95
97
|
|
|
96
98
|
@model_validator(mode="after")
|
|
97
99
|
def dropout_instructions_consistent(self) -> Self:
|
|
98
100
|
if self.dropout_fraction == 0:
|
|
99
|
-
if self.dropout_timedeltas_minutes
|
|
101
|
+
if self.dropout_timedeltas_minutes != []:
|
|
100
102
|
raise ValueError("To use dropout timedeltas dropout fraction should be > 0")
|
|
101
103
|
else:
|
|
102
|
-
if self.dropout_timedeltas_minutes
|
|
104
|
+
if self.dropout_timedeltas_minutes == []:
|
|
103
105
|
raise ValueError("To dropout fraction > 0 requires a list of dropout timedeltas")
|
|
104
106
|
return self
|
|
105
107
|
|
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
"""Satellite loader"""
|
|
2
2
|
|
|
3
|
-
import subprocess
|
|
4
|
-
|
|
5
3
|
import xarray as xr
|
|
6
4
|
from ocf_data_sampler.load.utils import (
|
|
7
5
|
check_time_unique_increasing,
|
|
@@ -10,63 +8,59 @@ from ocf_data_sampler.load.utils import (
|
|
|
10
8
|
)
|
|
11
9
|
|
|
12
10
|
|
|
13
|
-
def
|
|
14
|
-
"""Helper function to open a
|
|
11
|
+
def get_single_sat_data(zarr_path: str) -> xr.Dataset:
|
|
12
|
+
"""Helper function to open a zarr from either a local or GCP path
|
|
15
13
|
|
|
16
14
|
Args:
|
|
17
|
-
zarr_path:
|
|
18
|
-
GCS paths (gs://)
|
|
15
|
+
zarr_path: path to a zarr file. Wildcards (*) are supported only for local paths
|
|
16
|
+
GCS paths (gs://) do not support wildcards
|
|
19
17
|
|
|
20
18
|
Returns:
|
|
21
|
-
An xarray Dataset containing satellite data
|
|
19
|
+
An xarray Dataset containing satellite data
|
|
22
20
|
|
|
23
21
|
Raises:
|
|
24
|
-
ValueError: If a wildcard (*) is used in a GCS (gs://) path
|
|
22
|
+
ValueError: If a wildcard (*) is used in a GCS (gs://) path
|
|
25
23
|
"""
|
|
26
24
|
|
|
27
|
-
# These kwargs are used if the path contains "*"
|
|
28
|
-
openmf_kwargs = dict(
|
|
29
|
-
engine="zarr",
|
|
30
|
-
concat_dim="time",
|
|
31
|
-
combine="nested",
|
|
32
|
-
chunks="auto",
|
|
33
|
-
join="override",
|
|
34
|
-
)
|
|
35
|
-
|
|
36
25
|
# Raise an error if a wildcard is used in a GCP path
|
|
37
26
|
if "gs://" in str(zarr_path) and "*" in str(zarr_path):
|
|
38
|
-
raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs
|
|
27
|
+
raise ValueError("Wildcard (*) paths are not supported for GCP (gs://) URLs")
|
|
39
28
|
|
|
40
29
|
# Handle multi-file dataset for local paths
|
|
41
30
|
if "*" in str(zarr_path):
|
|
42
|
-
ds = xr.open_mfdataset(
|
|
31
|
+
ds = xr.open_mfdataset(
|
|
32
|
+
zarr_path,
|
|
33
|
+
engine="zarr",
|
|
34
|
+
concat_dim="time",
|
|
35
|
+
combine="nested",
|
|
36
|
+
chunks="auto",
|
|
37
|
+
join="override",
|
|
38
|
+
)
|
|
39
|
+
check_time_unique_increasing(ds.time)
|
|
43
40
|
else:
|
|
44
41
|
ds = xr.open_dataset(zarr_path, engine="zarr", chunks="auto")
|
|
45
42
|
|
|
46
|
-
# Ensure time is unique and sorted
|
|
47
|
-
ds = ds.drop_duplicates("time").sortby("time")
|
|
48
|
-
|
|
49
43
|
return ds
|
|
50
44
|
|
|
51
45
|
|
|
52
46
|
def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
|
|
53
|
-
"""Lazily opens the
|
|
47
|
+
"""Lazily opens the zarr store
|
|
54
48
|
|
|
55
49
|
Args:
|
|
56
|
-
zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL,
|
|
57
|
-
|
|
50
|
+
zarr_path: Cloud URL or local path pattern, or list of these. If GCS URL,
|
|
51
|
+
it must start with 'gs://'
|
|
58
52
|
"""
|
|
59
53
|
|
|
60
54
|
# Open the data
|
|
61
55
|
if isinstance(zarr_path, (list, tuple)):
|
|
62
56
|
ds = xr.combine_nested(
|
|
63
|
-
[
|
|
57
|
+
[get_single_sat_data(path) for path in zarr_path],
|
|
64
58
|
concat_dim="time",
|
|
65
59
|
combine_attrs="override",
|
|
66
60
|
join="override",
|
|
67
61
|
)
|
|
68
62
|
else:
|
|
69
|
-
ds =
|
|
63
|
+
ds = get_single_sat_data(zarr_path)
|
|
70
64
|
|
|
71
65
|
ds = ds.rename(
|
|
72
66
|
{
|
|
@@ -76,9 +70,7 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
|
|
|
76
70
|
)
|
|
77
71
|
|
|
78
72
|
check_time_unique_increasing(ds.time_utc)
|
|
79
|
-
|
|
80
73
|
ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary")
|
|
81
|
-
|
|
82
74
|
ds = ds.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")
|
|
83
75
|
|
|
84
76
|
# TODO: should we control the dtype of the DataArray?
|
ocf_data_sampler/load/site.py
CHANGED
|
@@ -20,7 +20,7 @@ def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArra
|
|
|
20
20
|
|
|
21
21
|
assert metadata_df.index.is_unique
|
|
22
22
|
|
|
23
|
-
# Ensure metadata aligns with the site_id dimension in
|
|
23
|
+
# Ensure metadata aligns with the site_id dimension in generation_ds
|
|
24
24
|
metadata_df = metadata_df.reindex(generation_ds.site_id.values)
|
|
25
25
|
|
|
26
26
|
# Assign coordinates to the Dataset using the aligned metadata
|
|
@@ -18,9 +18,13 @@ class GSPSampleKey:
|
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
def convert_gsp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
|
|
21
|
-
"""Convert from Xarray to NumpySample
|
|
21
|
+
"""Convert from Xarray to NumpySample
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
da: Xarray DataArray containing GSP data
|
|
25
|
+
t0_idx: Index of the t0 timestamp in the time dimension of the GSP data
|
|
26
|
+
"""
|
|
22
27
|
|
|
23
|
-
# Extract values from the DataArray
|
|
24
28
|
sample = {
|
|
25
29
|
GSPSampleKey.gsp: da.values,
|
|
26
30
|
GSPSampleKey.nominal_capacity_mwp: da.isel(time_utc=0)["nominal_capacity_mwp"].values,
|
|
@@ -12,30 +12,24 @@ class NWPSampleKey:
|
|
|
12
12
|
step = 'nwp_step'
|
|
13
13
|
target_time_utc = 'nwp_target_time_utc'
|
|
14
14
|
t0_idx = 'nwp_t0_idx'
|
|
15
|
-
y_osgb = 'nwp_y_osgb'
|
|
16
|
-
x_osgb = 'nwp_x_osgb'
|
|
17
|
-
|
|
18
15
|
|
|
19
16
|
|
|
20
17
|
def convert_nwp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
|
|
21
|
-
"""Convert from Xarray to NWP NumpySample
|
|
18
|
+
"""Convert from Xarray to NWP NumpySample
|
|
19
|
+
|
|
20
|
+
Args:
|
|
21
|
+
da: Xarray DataArray containing NWP data
|
|
22
|
+
t0_idx: Index of the t0 timestamp in the time dimension of the NWP
|
|
23
|
+
"""
|
|
22
24
|
|
|
23
|
-
# Create example and add t if available
|
|
24
25
|
sample = {
|
|
25
26
|
NWPSampleKey.nwp: da.values,
|
|
26
27
|
NWPSampleKey.channel_names: da.channel.values,
|
|
27
28
|
NWPSampleKey.init_time_utc: da.init_time_utc.values.astype(float),
|
|
28
29
|
NWPSampleKey.step: (da.step.values / pd.Timedelta("1h")).astype(int),
|
|
30
|
+
NWPSampleKey.target_time_utc: da.target_time_utc.values.astype(float),
|
|
29
31
|
}
|
|
30
32
|
|
|
31
|
-
if "target_time_utc" in da.coords:
|
|
32
|
-
sample[NWPSampleKey.target_time_utc] = da.target_time_utc.values.astype(float)
|
|
33
|
-
|
|
34
|
-
# TODO: Do we need this at all? Especially since it is only present in UKV data
|
|
35
|
-
for sample_key, dataset_key in ((NWPSampleKey.y_osgb, "y_osgb"),(NWPSampleKey.x_osgb, "x_osgb"),):
|
|
36
|
-
if dataset_key in da.coords:
|
|
37
|
-
sample[sample_key] = da[dataset_key].values
|
|
38
|
-
|
|
39
33
|
if t0_idx is not None:
|
|
40
34
|
sample[NWPSampleKey.t0_idx] = t0_idx
|
|
41
35
|
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Convert Satellite to NumpySample"""
|
|
2
|
+
|
|
2
3
|
import xarray as xr
|
|
3
4
|
|
|
4
5
|
|
|
@@ -12,19 +13,21 @@ class SatelliteSampleKey:
|
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
def convert_satellite_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
|
|
15
|
-
"""Convert from Xarray to NumpySample
|
|
16
|
+
"""Convert from Xarray to NumpySample
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
da: xarray DataArray containing satellite data
|
|
20
|
+
t0_idx: Index of the t0 timestamp in the time dimension of the satellite data
|
|
21
|
+
"""
|
|
16
22
|
sample = {
|
|
17
23
|
SatelliteSampleKey.satellite_actual: da.values,
|
|
18
24
|
SatelliteSampleKey.time_utc: da.time_utc.values.astype(float),
|
|
25
|
+
SatelliteSampleKey.x_geostationary: da.x_geostationary.values,
|
|
26
|
+
SatelliteSampleKey.y_geostationary: da.y_geostationary.values,
|
|
19
27
|
}
|
|
20
28
|
|
|
21
|
-
for sample_key, dataset_key in (
|
|
22
|
-
(SatelliteSampleKey.x_geostationary, "x_geostationary"),
|
|
23
|
-
(SatelliteSampleKey.y_geostationary, "y_geostationary"),
|
|
24
|
-
):
|
|
25
|
-
sample[sample_key] = da[dataset_key].values
|
|
26
|
-
|
|
27
29
|
if t0_idx is not None:
|
|
28
30
|
sample[SatelliteSampleKey.t0_idx] = t0_idx
|
|
29
31
|
|
|
30
|
-
return sample
|
|
32
|
+
return sample
|
|
33
|
+
|
|
@@ -18,9 +18,13 @@ class SiteSampleKey:
|
|
|
18
18
|
time_cos = "site_time_cos"
|
|
19
19
|
|
|
20
20
|
def convert_site_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) -> dict:
|
|
21
|
-
"""Convert from Xarray to NumpySample
|
|
21
|
+
"""Convert from Xarray to NumpySample
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
da: xarray DataArray containing site data
|
|
25
|
+
t0_idx: Index of the t0 timestamp in the time dimension of the site data
|
|
26
|
+
"""
|
|
22
27
|
|
|
23
|
-
# Extract values from the DataArray
|
|
24
28
|
sample = {
|
|
25
29
|
SiteSampleKey.generation: da.values,
|
|
26
30
|
SiteSampleKey.capacity_kwp: da.isel(time_utc=0)["capacity_kwp"].values,
|
|
@@ -27,16 +27,15 @@ def calculate_azimuth_and_elevation(
|
|
|
27
27
|
latitude=lat,
|
|
28
28
|
method='nrel_numpy'
|
|
29
29
|
)
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
return azimuth, elevation
|
|
30
|
+
|
|
31
|
+
return solpos["azimuth"].values, solpos["elevation"].values
|
|
33
32
|
|
|
34
33
|
|
|
35
34
|
def make_sun_position_numpy_sample(
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
35
|
+
datetimes: pd.DatetimeIndex,
|
|
36
|
+
lon: float,
|
|
37
|
+
lat: float,
|
|
38
|
+
key_prefix: str = "gsp"
|
|
40
39
|
) -> dict:
|
|
41
40
|
"""Creates NumpySample with standardized solar coordinates
|
|
42
41
|
|
|
@@ -44,22 +43,22 @@ def make_sun_position_numpy_sample(
|
|
|
44
43
|
datetimes: The datetimes to calculate solar angles for
|
|
45
44
|
lon: The longitude
|
|
46
45
|
lat: The latitude
|
|
46
|
+
key_prefix: The prefix to add to the keys in the NumpySample
|
|
47
47
|
"""
|
|
48
48
|
|
|
49
49
|
azimuth, elevation = calculate_azimuth_and_elevation(datetimes, lon, lat)
|
|
50
50
|
|
|
51
51
|
# Normalise
|
|
52
|
-
|
|
53
52
|
# Azimuth is in range [0, 360] degrees
|
|
54
53
|
azimuth = azimuth / 360
|
|
55
54
|
|
|
56
55
|
# Elevation is in range [-90, 90] degrees
|
|
57
56
|
elevation = elevation / 180 + 0.5
|
|
58
|
-
|
|
57
|
+
|
|
59
58
|
# Make NumpySample
|
|
60
59
|
sun_numpy_sample = {
|
|
61
60
|
key_prefix + "_solar_azimuth": azimuth,
|
|
62
61
|
key_prefix + "_solar_elevation": elevation,
|
|
63
62
|
}
|
|
64
63
|
|
|
65
|
-
return sun_numpy_sample
|
|
64
|
+
return sun_numpy_sample
|
ocf_data_sampler/sample/base.py
CHANGED
|
@@ -1,23 +1,15 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Base class definition - abstract
|
|
3
|
-
Handling of both flat and nested structures - consideration for NWP
|
|
4
|
-
"""
|
|
1
|
+
""" Base class for handling flat/nested data structures with NWP consideration """
|
|
5
2
|
|
|
6
|
-
import logging
|
|
7
3
|
import numpy as np
|
|
8
4
|
import torch
|
|
9
|
-
import xarray as xr
|
|
10
5
|
|
|
11
|
-
from
|
|
12
|
-
from typing import Any, Dict, Optional, Union, TypeAlias
|
|
6
|
+
from typing import TypeAlias
|
|
13
7
|
from abc import ABC, abstractmethod
|
|
14
8
|
|
|
15
9
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
NumpyBatch: TypeAlias = Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
|
|
20
|
-
TensorBatch: TypeAlias = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]
|
|
10
|
+
NumpySample: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
|
|
11
|
+
NumpyBatch: TypeAlias = dict[str, np.ndarray | dict[str, np.ndarray]]
|
|
12
|
+
TensorBatch: TypeAlias = dict[str, torch.Tensor | dict[str, torch.Tensor]]
|
|
21
13
|
|
|
22
14
|
|
|
23
15
|
class SampleBase(ABC):
|
|
@@ -26,43 +18,33 @@ class SampleBase(ABC):
|
|
|
26
18
|
Provides core data storage functionality
|
|
27
19
|
"""
|
|
28
20
|
|
|
29
|
-
def __init__(self, data: Optional[Union[NumpySample, xr.Dataset]] = None):
|
|
30
|
-
""" Initialise data container """
|
|
31
|
-
logger.debug("Initialising SampleBase instance")
|
|
32
|
-
self._data = data
|
|
33
|
-
|
|
34
21
|
@abstractmethod
|
|
35
22
|
def to_numpy(self) -> NumpySample:
|
|
36
|
-
"""
|
|
23
|
+
"""Convert sample data to numpy format"""
|
|
37
24
|
raise NotImplementedError
|
|
38
25
|
|
|
39
26
|
@abstractmethod
|
|
40
|
-
def plot(self
|
|
41
|
-
""" Abstract method for plotting """
|
|
27
|
+
def plot(self) -> None:
|
|
42
28
|
raise NotImplementedError
|
|
43
29
|
|
|
44
30
|
@abstractmethod
|
|
45
|
-
def save(self, path:
|
|
46
|
-
""" Abstract method for saving sample data """
|
|
31
|
+
def save(self, path: str) -> None:
|
|
47
32
|
raise NotImplementedError
|
|
48
33
|
|
|
49
34
|
@classmethod
|
|
50
35
|
@abstractmethod
|
|
51
|
-
def load(cls, path:
|
|
52
|
-
""" Abstract class method for loading sample data """
|
|
36
|
+
def load(cls, path: str) -> 'SampleBase':
|
|
53
37
|
raise NotImplementedError
|
|
54
38
|
|
|
55
39
|
|
|
56
40
|
def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
|
|
57
41
|
"""
|
|
58
|
-
|
|
42
|
+
Recursively converts numpy arrays in nested dict to torch tensors
|
|
59
43
|
Args:
|
|
60
44
|
batch: NumpyBatch with data in numpy arrays
|
|
61
45
|
Returns:
|
|
62
46
|
TensorBatch with data in torch tensors
|
|
63
47
|
"""
|
|
64
|
-
if not batch:
|
|
65
|
-
raise ValueError("Cannot convert empty batch to tensors")
|
|
66
48
|
|
|
67
49
|
for k, v in batch.items():
|
|
68
50
|
if isinstance(v, dict):
|
|
@@ -75,16 +57,15 @@ def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
|
|
|
75
57
|
return batch
|
|
76
58
|
|
|
77
59
|
|
|
78
|
-
def copy_batch_to_device(batch:
|
|
79
|
-
"""
|
|
80
|
-
Moves tensor leaves in a nested dict to a new device.
|
|
60
|
+
def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatch:
|
|
61
|
+
"""Recursively copies tensors in nested dict to specified device
|
|
81
62
|
|
|
82
63
|
Args:
|
|
83
|
-
batch: Nested dict with tensors to move
|
|
84
|
-
device: Device to move tensors to
|
|
85
|
-
|
|
64
|
+
batch: Nested dict with tensors to move
|
|
65
|
+
device: Device to move tensors to
|
|
66
|
+
|
|
86
67
|
Returns:
|
|
87
|
-
A dict with tensors moved to the new device
|
|
68
|
+
A dict with tensors moved to the new device
|
|
88
69
|
"""
|
|
89
70
|
batch_copy = {}
|
|
90
71
|
|
ocf_data_sampler/sample/site.py
CHANGED
|
@@ -1,81 +1,44 @@
|
|
|
1
|
-
"""
|
|
2
|
-
PVNet - Site sample / dataset implementation
|
|
3
|
-
"""
|
|
1
|
+
"""PVNet Site sample implementation for netCDF data handling and conversion"""
|
|
4
2
|
|
|
5
|
-
import logging
|
|
6
3
|
import xarray as xr
|
|
7
|
-
import numpy as np
|
|
8
4
|
|
|
9
|
-
from
|
|
10
|
-
from typing import Dict, Any, Union
|
|
5
|
+
from typing_extensions import override
|
|
11
6
|
|
|
12
|
-
from ocf_data_sampler.sample.base import SampleBase
|
|
7
|
+
from ocf_data_sampler.sample.base import SampleBase, NumpySample
|
|
13
8
|
from ocf_data_sampler.torch_datasets.datasets.site import convert_netcdf_to_numpy_sample
|
|
14
9
|
|
|
15
10
|
|
|
16
|
-
logger = logging.getLogger(__name__)
|
|
17
|
-
|
|
18
|
-
|
|
19
11
|
class SiteSample(SampleBase):
|
|
20
|
-
"""
|
|
12
|
+
"""Handles PVNet site specific netCDF operations"""
|
|
21
13
|
|
|
22
|
-
def __init__(self):
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
def to_numpy(self) -> Dict[str, Any]:
|
|
28
|
-
""" Convert sample numpy arrays - netCDF conversion """
|
|
29
|
-
logger.debug("Converting site sample to numpy format")
|
|
14
|
+
def __init__(self, data: xr.Dataset):
|
|
15
|
+
|
|
16
|
+
if not isinstance(data, xr.Dataset):
|
|
17
|
+
raise TypeError(f"Data must be xarray Dataset - Found type {type(data)}")
|
|
30
18
|
|
|
31
|
-
|
|
32
|
-
if not isinstance(self._data, xr.Dataset):
|
|
33
|
-
raise TypeError("Data must be xarray Dataset")
|
|
34
|
-
|
|
35
|
-
numpy_data = convert_netcdf_to_numpy_sample(self._data)
|
|
19
|
+
self._data = data
|
|
36
20
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
except Exception as e:
|
|
41
|
-
logger.error(f"Error converting to numpy: {str(e)}")
|
|
42
|
-
raise
|
|
21
|
+
@override
|
|
22
|
+
def to_numpy(self) -> NumpySample:
|
|
23
|
+
return convert_netcdf_to_numpy_sample(self._data)
|
|
43
24
|
|
|
44
|
-
def save(self, path:
|
|
45
|
-
"""
|
|
46
|
-
logger.debug(f"Saving SiteSample to {path}")
|
|
47
|
-
path = Path(path)
|
|
25
|
+
def save(self, path: str) -> None:
|
|
26
|
+
"""Save site sample data as netCDF
|
|
48
27
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
if not isinstance(self._data, xr.Dataset):
|
|
54
|
-
raise TypeError("Data must be xarray Dataset for saving")
|
|
55
|
-
|
|
56
|
-
self._data.to_netcdf(
|
|
57
|
-
path,
|
|
58
|
-
mode="w",
|
|
59
|
-
engine="h5netcdf"
|
|
60
|
-
)
|
|
61
|
-
logger.debug(f"Successfully saved SiteSample - {path}")
|
|
28
|
+
Args:
|
|
29
|
+
path: Path to save the netCDF file
|
|
30
|
+
"""
|
|
31
|
+
self._data.to_netcdf(path, mode="w", engine="h5netcdf")
|
|
62
32
|
|
|
63
33
|
@classmethod
|
|
64
|
-
def load(cls, path: str) ->
|
|
65
|
-
"""
|
|
66
|
-
logger.debug(f"Loading SiteSample from {path}")
|
|
67
|
-
path = Path(path)
|
|
68
|
-
|
|
69
|
-
if path.suffix != '.nc':
|
|
70
|
-
logger.error(f"Invalid file format - {path.suffix}")
|
|
71
|
-
raise ValueError("Only .nc format is supported")
|
|
34
|
+
def load(cls, path: str) -> 'SiteSample':
|
|
35
|
+
"""Load site sample data from netCDF
|
|
72
36
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
return
|
|
77
|
-
|
|
78
|
-
#
|
|
79
|
-
def plot(self
|
|
80
|
-
"
|
|
81
|
-
pass
|
|
37
|
+
Args:
|
|
38
|
+
path: Path to load the netCDF file from
|
|
39
|
+
"""
|
|
40
|
+
return cls(xr.open_dataset(path))
|
|
41
|
+
|
|
42
|
+
# TODO - placeholder for now
|
|
43
|
+
def plot(self) -> None:
|
|
44
|
+
raise NotImplementedError("Plotting not yet implemented for SiteSample")
|