ocf-data-sampler 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ocf-data-sampler might be problematic.
- ocf_data_sampler/load/gsp.py +1 -1
- ocf_data_sampler/select/dropout.py +10 -20
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +11 -13
- ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py +19 -23
- {ocf_data_sampler-0.2.20.dist-info → ocf_data_sampler-0.2.22.dist-info}/METADATA +1 -1
- {ocf_data_sampler-0.2.20.dist-info → ocf_data_sampler-0.2.22.dist-info}/RECORD +8 -8
- {ocf_data_sampler-0.2.20.dist-info → ocf_data_sampler-0.2.22.dist-info}/WHEEL +1 -1
- {ocf_data_sampler-0.2.20.dist-info → ocf_data_sampler-0.2.22.dist-info}/top_level.txt +0 -0
ocf_data_sampler/load/gsp.py
CHANGED
@@ -48,7 +48,7 @@ def open_gsp(zarr_path: str, boundaries_version: str = "20220314") -> xr.DataArr
 
     if not (ds.gsp_id.isin(df_gsp_loc.index)).all():
         raise ValueError(
-            "Some GSP IDs in the GSP generation data are available in the locations file.",
+            "Some GSP IDs in the GSP generation data are not available in the locations file.",
         )
 
     # Select the locations by the GSP IDs in the generation data
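The change above is only to the wording of the error message: the check raises when some GSP IDs in the generation data have no entry in the locations file, and the old message said the opposite. A minimal standalone sketch of the same isin-style membership check, using made-up IDs and a toy locations table rather than the real CSV shipped with the package (and printing instead of raising):

import pandas as pd
import xarray as xr

# Toy stand-ins for the real GSP generation data and locations table
ds = xr.DataArray([1.0, 2.0, 3.0], dims="gsp_id", coords={"gsp_id": [1, 2, 999]})
df_gsp_loc = pd.DataFrame({"x_osgb": [0, 0], "y_osgb": [0, 0]}, index=[1, 2])

# Same membership test as open_gsp: every GSP ID must have a known location
if not ds.gsp_id.isin(df_gsp_loc.index).all():
    print("Some GSP IDs in the GSP generation data are not available in the locations file.")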
ocf_data_sampler/select/dropout.py
CHANGED
@@ -9,19 +9,22 @@ import pandas as pd
 import xarray as xr
 
 
-def draw_dropout_time(
+def apply_sampled_dropout_time(
     t0: pd.Timestamp,
     dropout_timedeltas: list[pd.Timedelta],
     dropout_frac: float,
-
-
+    da: xr.DataArray,
+) -> xr.DataArray:
+    """Randomly pick a dropout time from a list of timedeltas and apply dropout time to the data.
 
     Args:
         t0: The forecast init-time
         dropout_timedeltas: List of timedeltas relative to t0 to pick from
         dropout_frac: Probability that dropout will be applied.
            This should be between 0 and 1 inclusive
+        da: Xarray DataArray with 'time_utc' coordinate
     """
+    # sample dropout time
     if dropout_frac > 0 and len(dropout_timedeltas) == 0:
         raise ValueError("To apply dropout, dropout_timedeltas must be provided")
 
@@ -37,21 +40,8 @@ def draw_dropout_time(
     else:
         dropout_time = t0 + np.random.choice(dropout_timedeltas)
 
-
-
-
-def apply_dropout_time(
-    ds: xr.DataArray,
-    dropout_time: pd.Timestamp | None,
-) -> xr.DataArray:
-    """Apply dropout time to the data.
-
-    Args:
-        ds: Xarray DataArray with 'time_utc' coordinate
-        dropout_time: Time after which data is set to NaN
-    """
+    # apply dropout time
     if dropout_time is None:
-        return
-
-
-    return ds.where(ds.time_utc <= dropout_time)
+        return da
+    # This replaces the times after the dropout with NaNs
+    return da.where(da.time_utc <= dropout_time)
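With this change the previous two-step API (draw a dropout time, then apply it) becomes a single call that samples a dropout time and masks everything after it. A rough standalone sketch of the same behaviour on a toy DataArray; the helper below is illustrative rather than the library function itself, and the np.random.uniform() probability check is an assumption since only the else branch of the sampling logic is visible in this diff:

import numpy as np
import pandas as pd
import xarray as xr


def sampled_dropout(t0, dropout_timedeltas, dropout_frac, da):
    """Toy version of the merged helper: maybe pick a dropout time, then mask after it."""
    if dropout_frac > 0 and len(dropout_timedeltas) == 0:
        raise ValueError("To apply dropout, dropout_timedeltas must be provided")

    # Assumed sampling step: with probability (1 - dropout_frac) no dropout is applied
    if np.random.uniform() >= dropout_frac:
        dropout_time = None
    else:
        dropout_time = t0 + np.random.choice(dropout_timedeltas)

    if dropout_time is None:
        return da
    # Times after the dropout time become NaN
    return da.where(da.time_utc <= dropout_time)


times = pd.date_range("2024-06-01 10:00", periods=6, freq="30min")
da = xr.DataArray(np.arange(6.0), dims="time_utc", coords={"time_utc": times})

out = sampled_dropout(
    t0=pd.Timestamp("2024-06-01 12:00"),
    dropout_timedeltas=[pd.Timedelta("-30min"), pd.Timedelta("-60min")],
    dropout_frac=1.0,  # always apply dropout so the masking is visible
    da=da,
)
print(out)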
ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
CHANGED
@@ -1,6 +1,5 @@
 """Torch dataset for UK PVNet."""
 
-import numpy as np
 import pandas as pd
 import xarray as xr
 from torch.utils.data import Dataset
@@ -257,22 +256,12 @@ class PVNetUKRegionalDataset(AbstractPVNetUKDataset):
         # Construct a lookup for locations - useful for users to construct sample by GSP ID
         location_lookup = {loc.id: loc for loc in self.locations}
 
-        # Construct indices for sampling
-        t_index, loc_index = np.meshgrid(
-            np.arange(len(self.valid_t0_times)),
-            np.arange(len(self.locations)),
-        )
-
-        # Make array of all possible (t0, location) coordinates. Each row is a single coordinate
-        index_pairs = np.stack((t_index.ravel(), loc_index.ravel())).T
-
         # Assign coords and indices to self
         self.location_lookup = location_lookup
-        self.index_pairs = index_pairs
 
     @override
     def __len__(self) -> int:
-        return len(self.index_pairs)
+        return len(self.locations)*len(self.valid_t0_times)
 
     def _get_sample(self, t0: pd.Timestamp, location: Location) -> NumpySample:
         """Generate the PVNet sample for given coordinates.
@@ -290,7 +279,16 @@ class PVNetUKRegionalDataset(AbstractPVNetUKDataset):
     @override
     def __getitem__(self, idx: int) -> NumpySample:
         # Get the coordinates of the sample
-        t_index, loc_index = self.index_pairs[idx]
+
+        if idx >= len(self):
+            raise ValueError(f"Index {idx} out of range for dataset of length {len(self)}")
+
+        # t_index will be between 0 and len(self.valid_t0_times)-1
+        t_index = idx % len(self.valid_t0_times)
+
+        # For each location, there are len(self.valid_t0_times) possible samples
+        loc_index = idx // len(self.valid_t0_times)
+
         location = self.locations[loc_index]
         t0 = self.valid_t0_times[t_index]
 
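The dataset no longer materialises an index_pairs array with np.meshgrid: __len__ is now simply len(self.locations) * len(self.valid_t0_times), and __getitem__ recovers the (t0, location) pair from a flat index with modulo and integer division. A quick sanity check of that mapping with plain integers (no dataset objects involved; the counts here are arbitrary stand-ins):

# Sanity-check the flat-index <-> (t_index, loc_index) mapping used by __getitem__
n_t0s = 4   # stand-in for len(self.valid_t0_times)
n_locs = 3  # stand-in for len(self.locations)

pairs = []
for idx in range(n_locs * n_t0s):  # same count as the new __len__
    t_index = idx % n_t0s          # cycles through 0 .. n_t0s - 1
    loc_index = idx // n_t0s       # each location owns n_t0s consecutive indices
    pairs.append((t_index, loc_index))

# Every (t0, location) combination appears exactly once
assert len(set(pairs)) == n_locs * n_t0s
print(pairs[:5])  # [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1)]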
ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py
CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
 import xarray as xr
 
 from ocf_data_sampler.config import Configuration
-from ocf_data_sampler.select.dropout import apply_dropout_time, draw_dropout_time
+from ocf_data_sampler.select.dropout import apply_sampled_dropout_time
 from ocf_data_sampler.select.select_time_slice import select_time_slice, select_time_slice_nwp
 from ocf_data_sampler.utils import minutes
 
@@ -51,17 +51,12 @@ def slice_datasets_by_time(
             interval_end=minutes(sat_config.interval_end_minutes),
         )
 
-        #
-        sat_dropout_time = draw_dropout_time(
+        # Apply the randomly sampled dropout
+        sliced_datasets_dict["sat"] = apply_sampled_dropout_time(
             t0,
             dropout_timedeltas=minutes(sat_config.dropout_timedeltas_minutes),
             dropout_frac=sat_config.dropout_fraction,
-        )
-
-        # Apply the dropout
-        sliced_datasets_dict["sat"] = apply_dropout_time(
-            sliced_datasets_dict["sat"],
-            sat_dropout_time,
+            da=sliced_datasets_dict["sat"],
         )
 
     if "gsp" in datasets_dict:
@@ -76,15 +71,11 @@ def slice_datasets_by_time(
         )
 
         # Dropout on the past GSP, but not the future GSP
-        gsp_dropout_time = draw_dropout_time(
+        da_gsp_past = apply_sampled_dropout_time(
             t0,
             dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
             dropout_frac=gsp_config.dropout_fraction,
-        )
-
-        da_gsp_past = apply_dropout_time(
-            da_gsp_past,
-            gsp_dropout_time,
+            da=da_gsp_past,
         )
 
         da_gsp_future = select_time_slice(
@@ -100,25 +91,30 @@
     if "site" in datasets_dict:
         site_config = config.input_data.site
 
-        sliced_datasets_dict["site"] = select_time_slice(
+        da_site_past = select_time_slice(
            datasets_dict["site"],
            t0,
            time_resolution=minutes(site_config.time_resolution_minutes),
            interval_start=minutes(site_config.interval_start_minutes),
-            interval_end=minutes(site_config.interval_end_minutes),
+            interval_end=minutes(0),
         )
 
-        #
-        site_dropout_time = draw_dropout_time(
+        # Apply the randomly sampled dropout on the past site not the future
+        da_site_past = apply_sampled_dropout_time(
            t0,
            dropout_timedeltas=minutes(site_config.dropout_timedeltas_minutes),
            dropout_frac=site_config.dropout_fraction,
+            da=da_site_past,
         )
 
-        # Apply the dropout
-        sliced_datasets_dict["site"] = apply_dropout_time(
-            sliced_datasets_dict["site"],
-            site_dropout_time,
+        da_site_future = select_time_slice(
+            datasets_dict["site"],
+            t0,
+            time_resolution=minutes(site_config.time_resolution_minutes),
+            interval_start=minutes(site_config.time_resolution_minutes),
+            interval_end=minutes(site_config.interval_end_minutes),
         )
 
+        sliced_datasets_dict["site"] = xr.concat([da_site_past, da_site_future], dim="time_utc")
+
     return sliced_datasets_dict
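For site data the time slice is now built in two halves: a past slice ending at t0 (interval_end=minutes(0)) that receives dropout, and a future slice starting one time-resolution step after t0, concatenated back together along time_utc. A minimal sketch of that split-and-concat pattern on a toy DataArray; plain .sel slicing stands in here for the library's select_time_slice:

import numpy as np
import pandas as pd
import xarray as xr

t0 = pd.Timestamp("2024-06-01 12:00")
step = pd.Timedelta("30min")  # stand-in for the site time resolution

times = pd.date_range("2024-06-01 10:00", "2024-06-01 14:00", freq="30min")
da = xr.DataArray(np.arange(len(times), dtype=float), dims="time_utc", coords={"time_utc": times})

# Past slice: up to and including t0 (the spirit of interval_end=minutes(0))
da_site_past = da.sel(time_utc=slice(None, t0))

# Dropout would be applied to da_site_past here, not to the future slice

# Future slice: starts one time-resolution step after t0
da_site_future = da.sel(time_utc=slice(t0 + step, None))

# Recombine along time_utc, as the new site branch does
da_site = xr.concat([da_site_past, da_site_future], dim="time_utc")
assert (da_site.time_utc.values == da.time_utc.values).all()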
{ocf_data_sampler-0.2.20.dist-info → ocf_data_sampler-0.2.22.dist-info}/RECORD
CHANGED
@@ -7,7 +7,7 @@ ocf_data_sampler/config/save.py,sha256=m8SPw5rXjkMm1rByjh3pK5StdBi4e8ysnn3jQopdR
 ocf_data_sampler/data/uk_gsp_locations_20220314.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
 ocf_data_sampler/data/uk_gsp_locations_20250109.csv,sha256=XZISFatnbpO9j8LwaxNKFzQSjs6hcHFsV8a9uDDpy2E,9055334
 ocf_data_sampler/load/__init__.py,sha256=-vQP9g0UOWdVbjEGyVX_ipa7R1btmiETIKAf6aw4d78,201
-ocf_data_sampler/load/gsp.py,sha256=
+ocf_data_sampler/load/gsp.py,sha256=YsIlj-LBUbREHNi78JMppOM1NbSkOe4kvtIrTwDx_JQ,1888
 ocf_data_sampler/load/load_dataset.py,sha256=wSXPUQKgGRM6HC-yBXQ2IcDBQDckOSllmbGnhqikFMQ,2055
 ocf_data_sampler/load/satellite.py,sha256=E7Ln7Y60Qr1RTV-_R71YoxXQM-Ca7Y1faIo3oKB2eFk,2292
 ocf_data_sampler/load/site.py,sha256=zOzlWk6pYZBB5daqG8URGksmDXWKrkutUvN8uALAIh8,1468
@@ -31,7 +31,7 @@ ocf_data_sampler/numpy_sample/satellite.py,sha256=RaYzYIcB1AmDrKeiqSpn4QVfBH-QMe
 ocf_data_sampler/numpy_sample/site.py,sha256=zfYBjK3CJrIaKH1QdKXU7gwOxTqONt527y3nJ9TRnwc,1325
 ocf_data_sampler/numpy_sample/sun_position.py,sha256=5tt-zNm6aRuZMsxZPaAxyg7HeikswfZCeHWXTHuO2K0,1555
 ocf_data_sampler/select/__init__.py,sha256=mK7Wu_-j9IXGTYrOuDf5yDDuU5a306b0iGKTAooNg_s,210
-ocf_data_sampler/select/dropout.py,sha256=
+ocf_data_sampler/select/dropout.py,sha256=9gPyDF7bGmvSoMjMPu1j0gTZFHNFqsT3ToIo9mFNA00,1565
 ocf_data_sampler/select/fill_time_periods.py,sha256=TlGxp1xiAqnhdWfLy0pv3FuZc00dtimjWdLzr4JoTGA,865
 ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=8lkWsV5i7iLCVGqQ-PGZbvWxsz3wBvLO70GSf6WeR0k,11363
 ocf_data_sampler/select/geospatial.py,sha256=CDExkl36eZOKmdJPzUr_K0Wn3axHqv5nYo-EkSiINcc,5032
@@ -39,7 +39,7 @@ ocf_data_sampler/select/location.py,sha256=AZvGR8y62opiW7zACGXjoOtBEWRfSLOZIA73O
 ocf_data_sampler/select/select_spatial_slice.py,sha256=liAqIa-Amj58pOqx5r16i99HURj9oQ41j7gnPgRDQP4,8201
 ocf_data_sampler/select/select_time_slice.py,sha256=HeHbwZ0CP03x0-LaJtpbSdtpLufwVTR73p6wH6O_PS8,5513
 ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=jfJSFcR0eO1AqeH7S3KnGjsBqVZT5w3oyi784PUR6Q0,146
-ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=
+ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=cd4IyzYu8rMFgLHRXqYpnOIAZe4Yl21YdLmDQw45F7o,12545
 ocf_data_sampler/torch_datasets/datasets/site.py,sha256=nRUlhXQQGVrTuBmE1QnwXAUsPTXz0dsezlQjwK71jIQ,17641
 ocf_data_sampler/torch_datasets/sample/__init__.py,sha256=GL84vdZl_SjHDGVyh9Uekx2XhPYuZ0dnO3l6f6KXnHI,100
 ocf_data_sampler/torch_datasets/sample/base.py,sha256=cQ1oIyhdmlotejZK8B3Cw6MNvpdnBPD8G_o2h7Ye4Vc,2206
@@ -49,13 +49,13 @@ ocf_data_sampler/torch_datasets/utils/__init__.py,sha256=N7i_hHtWUDiJqsiJoDx4T_Q
 ocf_data_sampler/torch_datasets/utils/channel_dict_to_dataarray.py,sha256=un2IiyoAmTDIymdeMiPU899_86iCDMD-oIifjHlNyqw,555
 ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py,sha256=we7BTxRH7B7jKayDT7YfNyfI3zZClz2Bk-HXKQIokgU,956
 ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py,sha256=Hvz0wHSWMYYamf2oHNiGlzJcM4cAH6pL_7ZEvIBL2dE,1882
-ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py,sha256=
+ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py,sha256=8E4a5v9dqr-sZOyBruuO-tjLPBbjtpYtdFY5z23aqnU,4365
 ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=xcy75cVxl0WrglnX5YUAFjXXlO2GwEBHWyqo8TDuiOA,4714
 ocf_data_sampler/torch_datasets/utils/validation_utils.py,sha256=YqmT-lExWlI8_ul3l0EP73Ik002fStr_bhsZh9mQqEU,4735
 scripts/download_gsp_location_data.py,sha256=rRDXMoqX-RYY4jPdxhdlxJGhWdl6r245F5UARgKV6P4,3121
 scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
 utils/compute_icon_mean_stddev.py,sha256=a1oWMRMnny39rV-dvu8rcx85sb4bXzPFrR1gkUr4Jpg,2296
-ocf_data_sampler-0.2.
-ocf_data_sampler-0.2.
-ocf_data_sampler-0.2.
-ocf_data_sampler-0.2.
+ocf_data_sampler-0.2.22.dist-info/METADATA,sha256=b5ruyqiy7iyNfAWznS1zENPC2fMNGv8uKYfzZI5ch1E,11581
+ocf_data_sampler-0.2.22.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+ocf_data_sampler-0.2.22.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
+ocf_data_sampler-0.2.22.dist-info/RECORD,,
{ocf_data_sampler-0.2.20.dist-info → ocf_data_sampler-0.2.22.dist-info}/top_level.txt
File without changes