ocf-data-sampler 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of ocf-data-sampler might be problematic.
Files changed (35)
  1. ocf_data_sampler/config/model.py +34 -0
  2. ocf_data_sampler/load/load_dataset.py +55 -0
  3. ocf_data_sampler/load/nwp/providers/ecmwf.py +5 -2
  4. ocf_data_sampler/load/site.py +30 -0
  5. ocf_data_sampler/numpy_batch/__init__.py +4 -3
  6. ocf_data_sampler/numpy_batch/gsp.py +12 -12
  7. ocf_data_sampler/numpy_batch/nwp.py +14 -14
  8. ocf_data_sampler/numpy_batch/satellite.py +8 -8
  9. ocf_data_sampler/numpy_batch/site.py +29 -0
  10. ocf_data_sampler/select/__init__.py +8 -1
  11. ocf_data_sampler/select/dropout.py +2 -1
  12. ocf_data_sampler/select/geospatial.py +43 -1
  13. ocf_data_sampler/select/select_spatial_slice.py +8 -2
  14. ocf_data_sampler/select/spatial_slice_for_dataset.py +53 -0
  15. ocf_data_sampler/select/time_slice_for_dataset.py +124 -0
  16. ocf_data_sampler/time_functions.py +11 -0
  17. ocf_data_sampler/torch_datasets/process_and_combine.py +153 -0
  18. ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +8 -418
  19. ocf_data_sampler/torch_datasets/site.py +196 -0
  20. ocf_data_sampler/torch_datasets/valid_time_periods.py +108 -0
  21. {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/METADATA +1 -1
  22. ocf_data_sampler-0.0.25.dist-info/RECORD +66 -0
  23. {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/WHEEL +1 -1
  24. {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/top_level.txt +1 -0
  25. scripts/refactor_site.py +50 -0
  26. tests/conftest.py +62 -0
  27. tests/load/test_load_sites.py +14 -0
  28. tests/numpy_batch/test_gsp.py +1 -2
  29. tests/numpy_batch/test_nwp.py +1 -3
  30. tests/numpy_batch/test_satellite.py +1 -3
  31. tests/numpy_batch/test_sun_position.py +7 -7
  32. tests/torch_datasets/test_pvnet_uk_regional.py +4 -6
  33. tests/torch_datasets/test_site.py +85 -0
  34. ocf_data_sampler-0.0.23.dist-info/RECORD +0 -54
  35. {ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/LICENSE +0 -0
ocf_data_sampler/torch_datasets/site.py ADDED
@@ -0,0 +1,196 @@
+ """Torch dataset for sites"""
+ import logging
+
+ import pandas as pd
+ import xarray as xr
+ from torch.utils.data import Dataset
+
+ from ocf_data_sampler.config import Configuration, load_yaml_configuration
+ from ocf_data_sampler.load.load_dataset import get_dataset_dict
+ from ocf_data_sampler.select import (
+     Location,
+     fill_time_periods,
+     find_contiguous_t0_periods,
+     intersection_of_multiple_dataframes_of_periods,
+     slice_datasets_by_time, slice_datasets_by_space
+ )
+ from ocf_data_sampler.time_functions import minutes
+ from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
+ from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods
+
+ xr.set_options(keep_attrs=True)
+
+
+ def find_valid_t0_and_site_ids(
+     datasets_dict: dict,
+     config: Configuration,
+ ) -> pd.DataFrame:
+     """Find the t0 times where all of the requested input data is available
+
+     The idea is to
+     1. Get valid time period for nwp and satellite
+     2. For each site location, find valid periods for that location
+
+     Args:
+         datasets_dict: A dictionary of input datasets
+         config: Configuration file
+     """
+
+     # 1. Get valid time period for nwp and satellite
+     datasets_nwp_and_sat_dict = {"nwp": datasets_dict["nwp"], "sat": datasets_dict["sat"]}
+     valid_time_periods = find_valid_time_periods(datasets_nwp_and_sat_dict, config)
+
+     # 2. Now lets loop over each location in system id and find the valid periods
+     # Should we have a different option if there are not nans
+     sites = datasets_dict["site"]
+     site_ids = sites.site_id.values
+     site_config = config.input_data.site
+     valid_t0_and_site_ids = []
+     for site_id in site_ids:
+         site = sites.sel(site_id=site_id)
+
+         # drop any nan values
+         # not sure this is right?
+         site = site.dropna(dim='time_utc')
+
+         # Get the valid time periods for this location
+         time_periods = find_contiguous_t0_periods(
+             pd.DatetimeIndex(site["time_utc"]),
+             sample_period_duration=minutes(site_config.time_resolution_minutes),
+             history_duration=minutes(site_config.history_minutes),
+             forecast_duration=minutes(site_config.forecast_minutes),
+         )
+         valid_time_periods_per_site = intersection_of_multiple_dataframes_of_periods(
+             [valid_time_periods, time_periods]
+         )
+
+         # Fill out the contiguous time periods to get the t0 times
+         valid_t0_times_per_site = fill_time_periods(
+             valid_time_periods_per_site,
+             freq=minutes(site_config.time_resolution_minutes)
+         )
+
+         valid_t0_per_site = pd.DataFrame(index=valid_t0_times_per_site)
+         valid_t0_per_site['site_id'] = site_id
+         valid_t0_and_site_ids.append(valid_t0_per_site)
+
+     valid_t0_and_site_ids = pd.concat(valid_t0_and_site_ids)
+     valid_t0_and_site_ids.index.name = 't0'
+     valid_t0_and_site_ids.reset_index(inplace=True)
+
+     return valid_t0_and_site_ids
+
+
+ def get_locations(site_xr: xr.Dataset):
+     """Get list of locations of all sites"""
+
+     locations = []
+     for site_id in site_xr.site_id.values:
+         site = site_xr.sel(site_id=site_id)
+         location = Location(
+             id=site_id,
+             x=site.longitude.values,
+             y=site.latitude.values,
+             coordinate_system="lon_lat"
+         )
+         locations.append(location)
+
+     return locations
+
+
+ class SitesDataset(Dataset):
+     def __init__(
+         self,
+         config_filename: str,
+         start_time: str | None = None,
+         end_time: str | None = None,
+     ):
+         """A torch Dataset for creating PVNet Site samples
+
+         Args:
+             config_filename: Path to the configuration file
+             start_time: Limit the init-times to be after this
+             end_time: Limit the init-times to be before this
+         """
+
+         config = load_yaml_configuration(config_filename)
+
+         datasets_dict = get_dataset_dict(config)
+
+         # get all locations
+         self.locations = get_locations(datasets_dict['site'])
+
+         # Get t0 times where all input data is available
+         valid_t0_and_site_ids = find_valid_t0_and_site_ids(datasets_dict, config)
+
+         # Filter t0 times to given range
+         if start_time is not None:
+             valid_t0_and_site_ids \
+                 = valid_t0_and_site_ids[valid_t0_and_site_ids['t0'] >= pd.Timestamp(start_time)]
+
+         if end_time is not None:
+             valid_t0_and_site_ids \
+                 = valid_t0_and_site_ids[valid_t0_and_site_ids['t0'] <= pd.Timestamp(end_time)]
+
+
+         # Assign coords and indices to self
+         self.valid_t0_and_site_ids = valid_t0_and_site_ids
+
+         # Assign config and input data to self
+         self.datasets_dict = datasets_dict
+         self.config = config
+
+     def __len__(self):
+         return len(self.valid_t0_and_site_ids)
+
+     def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
+         """Generate the PVNet sample for given coordinates
+
+         Args:
+             t0: init-time for sample
+             location: location for sample
+         """
+         sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
+         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
+         sample_dict = compute(sample_dict)
+
+         sample = process_and_combine_datasets(sample_dict, self.config, t0, location, sun_position_key='site')
+
+         return sample
+
+     def get_location_from_site_id(self, site_id):
+         """Get location from system id"""
+
+         locations = [loc for loc in self.locations if loc.id == site_id]
+         if len(locations) == 0:
+             raise ValueError(f"Location not found for site_id {site_id}")
+
+         if len(locations) > 1:
+             logging.warning(f"Multiple locations found for site_id {site_id}, but will take the first")
+
+         return locations[0]
+
+     def __getitem__(self, idx):
+
+         # Get the coordinates of the sample
+         t0, site_id = self.valid_t0_and_site_ids.iloc[idx]
+
+         # get location from site id
+         location = self.get_location_from_site_id(site_id)
+
+         # Generate the sample
+         return self._get_sample(t0, location)
+
+     def get_sample(self, t0: pd.Timestamp, site_id: int) -> dict:
+         """Generate a sample for a given site id and t0.
+
+         Useful for users to generate samples by t0 and site id
+
+         Args:
+             t0: init-time for sample
+             site_id: site id as int
+         """
+
+         location = self.get_location_from_site_id(site_id)
+
+         return self._get_sample(t0, location)
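
For orientation, a minimal usage sketch of the new SitesDataset; the YAML path and timestamps below are hypothetical placeholders, not part of the package:

    import pandas as pd
    from ocf_data_sampler.torch_datasets.site import SitesDataset

    # "site_config.yaml" is a hypothetical configuration file path
    dataset = SitesDataset("site_config.yaml", start_time="2023-01-01", end_time="2023-01-02")

    print(len(dataset))  # number of valid (t0, site_id) pairs
    sample = dataset[0]  # indexed access via __getitem__

    # or build a sample for a chosen site and init-time directly
    sample = dataset.get_sample(t0=pd.Timestamp("2023-01-01 12:00"), site_id=1)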
ocf_data_sampler/torch_datasets/valid_time_periods.py ADDED
@@ -0,0 +1,108 @@
+ import numpy as np
+ import pandas as pd
+
+ from ocf_data_sampler.config import Configuration
+ from ocf_data_sampler.select.find_contiguous_time_periods import find_contiguous_t0_periods_nwp, \
+     find_contiguous_t0_periods, intersection_of_multiple_dataframes_of_periods
+ from ocf_data_sampler.time_functions import minutes
+
+
+ def find_valid_time_periods(
+     datasets_dict: dict,
+     config: Configuration,
+ ):
+     """Find the t0 times where all of the requested input data is available
+
+     Args:
+         datasets_dict: A dictionary of input datasets
+         config: Configuration file
+     """
+
+     assert set(datasets_dict.keys()).issubset({"nwp", "sat", "gsp"})
+
+     contiguous_time_periods: dict[str: pd.DataFrame] = {}  # Used to store contiguous time periods from each data source
+
+     if "nwp" in datasets_dict:
+         for nwp_key, nwp_config in config.input_data.nwp.items():
+
+             da = datasets_dict["nwp"][nwp_key]
+
+             if nwp_config.dropout_timedeltas_minutes is None:
+                 max_dropout = minutes(0)
+             else:
+                 max_dropout = minutes(np.max(np.abs(nwp_config.dropout_timedeltas_minutes)))
+
+             if nwp_config.max_staleness_minutes is None:
+                 max_staleness = None
+             else:
+                 max_staleness = minutes(nwp_config.max_staleness_minutes)
+
+             # The last step of the forecast is lost if we have to diff channels
+             if len(nwp_config.nwp_accum_channels) > 0:
+                 end_buffer = minutes(nwp_config.time_resolution_minutes)
+             else:
+                 end_buffer = minutes(0)
+
+             # This is the max staleness we can use considering the max step of the input data
+             max_possible_staleness = (
+                 pd.Timedelta(da["step"].max().item())
+                 - minutes(nwp_config.forecast_minutes)
+                 - end_buffer
+             )
+
+             # Default to use max possible staleness unless specified in config
+             if max_staleness is None:
+                 max_staleness = max_possible_staleness
+             else:
+                 # Make sure the max acceptable staleness isn't longer than the max possible
+                 assert max_staleness <= max_possible_staleness
+
+             time_periods = find_contiguous_t0_periods_nwp(
+                 datetimes=pd.DatetimeIndex(da["init_time_utc"]),
+                 history_duration=minutes(nwp_config.history_minutes),
+                 max_staleness=max_staleness,
+                 max_dropout=max_dropout,
+             )
+
+             contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods
+
+     if "sat" in datasets_dict:
+         sat_config = config.input_data.satellite
+
+         time_periods = find_contiguous_t0_periods(
+             pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]),
+             sample_period_duration=minutes(sat_config.time_resolution_minutes),
+             history_duration=minutes(sat_config.history_minutes),
+             forecast_duration=minutes(sat_config.forecast_minutes),
+         )
+
+         contiguous_time_periods['sat'] = time_periods
+
+     if "gsp" in datasets_dict:
+         gsp_config = config.input_data.gsp
+
+         time_periods = find_contiguous_t0_periods(
+             pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
+             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
+             history_duration=minutes(gsp_config.history_minutes),
+             forecast_duration=minutes(gsp_config.forecast_minutes),
+         )
+
+         contiguous_time_periods['gsp'] = time_periods
+
+     # just get the values (not the keys)
+     contiguous_time_periods_values = list(contiguous_time_periods.values())
+
+     # Find joint overlapping contiguous time periods
+     if len(contiguous_time_periods_values) > 1:
+         valid_time_periods = intersection_of_multiple_dataframes_of_periods(
+             contiguous_time_periods_values
+         )
+     else:
+         valid_time_periods = contiguous_time_periods_values[0]
+
+     # check there are some valid time periods
+     if len(valid_time_periods) == 0:
+         raise ValueError(f"No valid time periods found, {contiguous_time_periods=}")
+
+     return valid_time_periods
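
The staleness bookkeeping above is easiest to follow with concrete numbers; a small worked sketch, with invented values, where minutes mirrors ocf_data_sampler.time_functions.minutes:

    import pandas as pd

    def minutes(m):  # mirrors ocf_data_sampler.time_functions.minutes
        return pd.Timedelta(minutes=m)

    max_step = pd.Timedelta("48h")   # longest step present in the NWP data
    forecast = minutes(12 * 60)      # forecast_minutes = 720
    end_buffer = minutes(60)         # one step lost to diffing accumulated channels

    # An init-time may be this stale and still cover the full forecast horizon
    max_possible_staleness = max_step - forecast - end_buffer
    print(max_possible_staleness)    # 1 days 11:00:00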
{ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ocf_data_sampler
- Version: 0.0.23
+ Version: 0.0.25
  Summary: Sample from weather data for renewable energy prediction
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
  Author-email: info@openclimatefix.org
ocf_data_sampler-0.0.25.dist-info/RECORD ADDED
@@ -0,0 +1,66 @@
+ ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+ ocf_data_sampler/constants.py,sha256=tUwHrsGShqIn5Izze4i32_xB6X0v67rvQwIYB-P5PJQ,3355
+ ocf_data_sampler/time_functions.py,sha256=R6ZlVEe6h4UlJeUW7paZYAMWveOv9MTjMsoISCwnsiE,284
+ ocf_data_sampler/config/__init__.py,sha256=YXnAkgHViHB26hSsjiv32b6EbpG-A1kKTkARJf0_RkY,212
+ ocf_data_sampler/config/load.py,sha256=4f7vPHAIAmd-55tPxoIzn7F_TI_ue4NxkDcLPoVWl0g,943
+ ocf_data_sampler/config/model.py,sha256=5GO8SF_4iOZhCAyIJyENSl0dnDRIWrURgqwslrVWke8,9462
+ ocf_data_sampler/config/save.py,sha256=wKdctbv0dxIIiQtcRHLRxpWQVhEFQ_FCWg-oNaRLIps,1093
+ ocf_data_sampler/data/uk_gsp_locations.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
+ ocf_data_sampler/load/__init__.py,sha256=MjgfxilTzyz1RYFoBEeAXmE9hyjknLvdmlHPmlAoiQY,44
+ ocf_data_sampler/load/gsp.py,sha256=Gcr1JVUOPKhFRDCSHtfPDjxx0BtyyEhXrZvGEKLPJ5I,759
+ ocf_data_sampler/load/load_dataset.py,sha256=R4RAIVLVx6CHA6Qs61kD9sx834I_GMGAn6G7ZgwFMUA,1627
+ ocf_data_sampler/load/satellite.py,sha256=3KlA1fx4SwxdzM-jC1WRaONXO0D6m0WxORnEnwUnZrA,2967
+ ocf_data_sampler/load/site.py,sha256=ROif2XXIIgBz-JOOiHymTq1CMXswJ3AzENU9DJmYpcU,782
+ ocf_data_sampler/load/utils.py,sha256=EQGvVWlGMoSOdbDYuMfVAa0v6wmAOPmHIAemdrTB5v4,1406
+ ocf_data_sampler/load/nwp/__init__.py,sha256=SmcrnbygO5xtCKmGR4wtHrj-HI7nOAvnAtfuvRufBGQ,25
+ ocf_data_sampler/load/nwp/nwp.py,sha256=O4QnajEZem8BvBgTcYYDBhRhgqPYuJkolHmpMRmrXEA,610
+ ocf_data_sampler/load/nwp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ ocf_data_sampler/load/nwp/providers/ecmwf.py,sha256=2iR1Iy542lo51rC6XFLV-3pbUE68dWjlHa6TVJzx3ac,1280
+ ocf_data_sampler/load/nwp/providers/ukv.py,sha256=79Bm7q-K_GJPYMy62SUIZbRWRF4-tIaB1dYPEgLD9vo,1207
+ ocf_data_sampler/load/nwp/providers/utils.py,sha256=Sy2exG1wpXLLhMXYdsfR-DZMR3txG1_bBmBdchlc-yA,848
+ ocf_data_sampler/numpy_batch/__init__.py,sha256=8MgRF29rK9bKP4b4iHakaoGwBKUcjWZ-VFKjCcq53QA,336
+ ocf_data_sampler/numpy_batch/gsp.py,sha256=QjQ25JmtufvdiSsxUkBTPhxouYGWPnnWze8pXr_aBno,960
+ ocf_data_sampler/numpy_batch/nwp.py,sha256=dAehfRo5DL2Yb20ifHHl5cU1QOrm3ZOpQmN39fSUOw8,1255
+ ocf_data_sampler/numpy_batch/satellite.py,sha256=3NoE_ElzMHwO60apqJeFAwI6J7eIxD0OWTyAVl-uJi8,903
+ ocf_data_sampler/numpy_batch/site.py,sha256=lJYMEot50UgSBnSOgADQMjUhky1YyWKYqwNsisyYV6w,789
+ ocf_data_sampler/numpy_batch/sun_position.py,sha256=zw2bjtcjsm_tvKk0r_MZmgfYUJLHuLjLly2sMjwP3XI,1606
+ ocf_data_sampler/select/__init__.py,sha256=E4AJulEbO2K-o0UlG1fgaEteuf_1ZFjHTvrotXSb4YU,332
+ ocf_data_sampler/select/dropout.py,sha256=HCx5Wzk8Oh2Z9vV94Jy-ALJsHtGduwvMaQOleQXp5z0,1142
+ ocf_data_sampler/select/fill_time_periods.py,sha256=iTtMjIPFYG5xtUYYedAFBLjTWWUa7t7WQ0-yksWf0-E,440
+ ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=6ioB8LeFpFNBMgKDxrgG3zqzNjkBF_jlV9yye2ZYT2E,11925
+ ocf_data_sampler/select/geospatial.py,sha256=4xL-9y674jjoaXeqE52NHCHVfknciE4OEGsZtn9DvP4,4911
+ ocf_data_sampler/select/location.py,sha256=26Y5ZjfFngShBwXieuWSoOA-RLaRzci4TTmcDk3Wg7U,2015
+ ocf_data_sampler/select/select_spatial_slice.py,sha256=WNxwur9Q5oetvogATw8-hNejDuEwrXHzuZIovFDjNJA,11488
+ ocf_data_sampler/select/select_time_slice.py,sha256=41cch1fQr59fZgv7UHsNGc3OvoynrixT3bmr3_1d7cU,6628
+ ocf_data_sampler/select/spatial_slice_for_dataset.py,sha256=Nrc3j8DR5MM4BPPp9IQwaIMpoyOkc6AADMnfOjg-170,1791
+ ocf_data_sampler/select/time_slice_for_dataset.py,sha256=A9fxvurbM0JSRkrjyg5Lr70_Mj6t5OO7HFqHUZel9q4,4220
+ ocf_data_sampler/torch_datasets/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+ ocf_data_sampler/torch_datasets/process_and_combine.py,sha256=Lovc2UM3-HgUy2BoQEIr0gQTz3USW6ACRWo-iTgxjHs,4993
+ ocf_data_sampler/torch_datasets/pvnet_uk_regional.py,sha256=TpHALGU7hpo3iLbvD0nkoY6zu94Vq99W1V1qSGEcIW8,5552
+ ocf_data_sampler/torch_datasets/site.py,sha256=1k0fWXYwAAIWG5DX_j3tgNfY8gglfPGLNzNlZd8EnJs,6631
+ ocf_data_sampler/torch_datasets/valid_time_periods.py,sha256=dNJkBH5wdsFUjoFSmthU3yTqar6OPE77WsRQUebm-PY,4163
+ scripts/refactor_site.py,sha256=asZ27hQ4IyXgCCUaFJqcz1ObBNcV2W3ywqHBpSXA_fc,1728
+ tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tests/conftest.py,sha256=ZRktySCynj3NBbFRR4EFNLRLFMErkQsC-qQlmQzhbRg,7360
+ tests/config/test_config.py,sha256=G_PD_pXib0zdRBPUIn0jjwJ9VyoKaO_TanLN1Mh5Ca4,5055
+ tests/load/test_load_gsp.py,sha256=aT_nqaSXmUTcdHzuTT7AmXJr3R31k4OEN-Fv3eLxlQE,424
+ tests/load/test_load_nwp.py,sha256=3qyyDkB1q9t3tyAwogfotNrxqUOpXXimco1CImoEWGg,753
+ tests/load/test_load_satellite.py,sha256=STX5AqqmOAgUgE9R1xyq_sM3P1b8NKdGjO-hDhayfxM,524
+ tests/load/test_load_sites.py,sha256=T9lSEnGPI8FQISudVYHHNTHeplNS62Vrx48jaZ6J_Jo,364
+ tests/numpy_batch/test_gsp.py,sha256=VANXV32K8aLX4dCdhCUnDorJmyNN-Bjc7Wc1N-RzWEk,548
+ tests/numpy_batch/test_nwp.py,sha256=Fnj7cR-VR2Z0kMu8SrgnIayjxWnPWrYFjWSjMmnrh4Y,1445
+ tests/numpy_batch/test_satellite.py,sha256=8a4ZwMLpsOmYKmwI1oW_su_hwkCNYMEJAEfa0dbsx1k,1179
+ tests/numpy_batch/test_sun_position.py,sha256=FYQ7KtlN0V5LlEjgI-cKjTMtGHUCxiMvxkRYTdMAgEE,2485
+ tests/select/test_dropout.py,sha256=kiycl7RxAQYMCZJlokmx6Da5h_oBpSs8Is8pmSW4gOU,2413
+ tests/select/test_fill_time_periods.py,sha256=o59f2YRe5b0vJrG3B0aYZkYeHnpNk4s6EJxdXZluNQg,907
+ tests/select/test_find_contiguous_time_periods.py,sha256=G6tJRJd0DMfH9EdfzlKWsmfTbtMwOf3w-2filjJzuIQ,5998
+ tests/select/test_location.py,sha256=_WZk2FPYeJ-nIfCJS6Sp_yaVEEo7m31DmMFoZzgyCts,2712
+ tests/select/test_select_spatial_slice.py,sha256=7EX9b6g-pMdACQx3yefjs5do2s-Rho2UmKevV4oglsU,5147
+ tests/select/test_select_time_slice.py,sha256=XC1J3DBBDnt81jcba5u-Hnd0yKv8GIQErLm-OECV6rs,10147
+ tests/torch_datasets/test_pvnet_uk_regional.py,sha256=u3taw6p3oozM0_7cEEhCYbImAQPRldRhpruqSyV08Vg,2675
+ tests/torch_datasets/test_site.py,sha256=5hdUP64neCDWEo2NMSd-MhbpuQjQvD6NOvhZ1DlMmo8,2733
+ ocf_data_sampler-0.0.25.dist-info/LICENSE,sha256=F-Q3UFCR-BECSocV55BFDpn4YKxve9PKrm-lTt6o_Tg,1073
+ ocf_data_sampler-0.0.25.dist-info/METADATA,sha256=p3SKEM4gRy0Z4LTcRWlgTrpjQ-QV89ar69tM9EwhudU,5269
+ ocf_data_sampler-0.0.25.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+ ocf_data_sampler-0.0.25.dist-info/top_level.txt,sha256=Faob6N6cFdPc5eUpCTYcXgCaNhi4XLLteUL5W5ayYmg,31
+ ocf_data_sampler-0.0.25.dist-info/RECORD,,
{ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.2.0)
+ Generator: setuptools (75.3.0)
  Root-Is-Purelib: true
  Tag: py3-none-any

{ocf_data_sampler-0.0.23.dist-info → ocf_data_sampler-0.0.25.dist-info}/top_level.txt CHANGED
@@ -1,2 +1,3 @@
  ocf_data_sampler
+ scripts
  tests
1
+ """ Helper functions for refactoring legacy site data """
2
+
3
+
4
+ def legacy_format(data_ds, metadata_df):
5
+ """This formats old legacy data to the new format.
6
+
7
+ 1. This renames the columns in the metadata
8
+ 2. Re-formats the site data from data variables named by the site_id to
9
+ a data array with a site_id dimension. Also adds capacity_kwp to the dataset as a time series for each site_id
10
+ """
11
+
12
+ if "system_id" in metadata_df.columns:
13
+ metadata_df["site_id"] = metadata_df["system_id"]
14
+
15
+ if "capacity_megawatts" in metadata_df.columns:
16
+ metadata_df["capacity_kwp"] = metadata_df["capacity_megawatts"] * 1000
17
+
18
+ # only site data has the site_id as data variables.
19
+ # We want to join them all together and create another coordinate called site_id
20
+ if "0" in data_ds:
21
+ gen_df = data_ds.to_dataframe()
22
+ gen_da = xr.DataArray(
23
+ data=gen_df.values,
24
+ coords=(
25
+ ("time_utc", gen_df.index.values),
26
+ ("site_id", metadata_df["site_id"]),
27
+ ),
28
+ name="generation_kw",
29
+ )
30
+
31
+ capacity_df = gen_df
32
+ for col in capacity_df.columns:
33
+ capacity_df[col] = metadata_df[metadata_df["site_id"].astype(str) == col][
34
+ "capacity_kwp"
35
+ ].iloc[0]
36
+ capacity_da = xr.DataArray(
37
+ data=capacity_df.values,
38
+ coords=(
39
+ ("time_utc", gen_df.index.values),
40
+ ("site_id", metadata_df["site_id"]),
41
+ ),
42
+ name="capacity_kwp",
43
+ )
44
+ data_ds = xr.Dataset(
45
+ {
46
+ "generation_kw": gen_da,
47
+ "capacity_kwp": capacity_da,
48
+ }
49
+ )
50
+ return data_ds
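
Note that the script as released references xr without importing it, so legacy_format only runs when pasted into a module that imports xarray. A self-contained sketch of calling it under that assumption; the legacy inputs are invented to match the column names the function expects:

    import pandas as pd
    import xarray as xr  # missing from the released script, but required by legacy_format

    # Legacy layout: one data variable per site, named by the stringified site id
    times = pd.date_range("2023-01-01", periods=4, freq="30min")
    data_ds = xr.Dataset({"0": ("time_utc", [1.0, 2.0, 3.0, 4.0])}, coords={"time_utc": times})
    metadata_df = pd.DataFrame({"system_id": [0], "capacity_megawatts": [0.5]})

    new_ds = legacy_format(data_ds, metadata_df)  # assumes the function is defined alongside
    print(new_ds["capacity_kwp"].values)  # 0.5 MW -> 500.0 kW at every timestamp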
tests/conftest.py CHANGED
@@ -6,6 +6,8 @@ import pytest
  import xarray as xr
  import tempfile

+ from ocf_data_sampler.config.model import Site
+
  _top_test_directory = os.path.dirname(os.path.realpath(__file__))

  @pytest.fixture()
@@ -197,6 +199,66 @@ def ds_uk_gsp():
      })


+ @pytest.fixture(scope="session")
+ def data_sites() -> Site:
+     """
+     Make fake data for sites
+     Returns: filename for netcdf file, and csv metadata
+     """
+     times = pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="30min")
+     site_ids = list(range(0,10))
+     capacity_kwp_1d = np.array([0.1,1.1,4,6,8,9,15,2,3,4])
+     # these are quite specific for the fake satellite data
+     longitude = np.arange(-4, -3, 0.1)
+     latitude = np.arange(51, 52, 0.1)
+
+     generation = np.random.uniform(0, 200, size=(len(times), len(site_ids))).astype(np.float32)
+
+     # repeat capacity in new dims len(times) times
+     capacity_kwp = (np.tile(capacity_kwp_1d, len(times))).reshape(len(times),10)
+
+     coords = (
+         ("time_utc", times),
+         ("site_id", site_ids),
+     )
+
+     da_cap = xr.DataArray(
+         capacity_kwp,
+         coords=coords,
+     )
+
+     da_gen = xr.DataArray(
+         generation,
+         coords=coords,
+     )
+
+     # metadata
+     meta_df = pd.DataFrame(columns=[], data = [])
+     meta_df['site_id'] = site_ids
+     meta_df['capacity_kwp'] = capacity_kwp_1d
+     meta_df['longitude'] = longitude
+     meta_df['latitude'] = latitude
+
+     generation = xr.Dataset({
+         "capacity_kwp": da_cap,
+         "generation_kw": da_gen,
+     })
+
+     with tempfile.TemporaryDirectory() as tmpdir:
+         filename = tmpdir + "/sites.netcdf"
+         filename_csv = tmpdir + "/sites_metadata.csv"
+         generation.to_netcdf(filename)
+         meta_df.to_csv(filename_csv)
+
+         site = Site(file_path=filename,
+                     metadata_file_path=filename_csv,
+                     time_resolution_minutes=30,
+                     forecast_minutes=60,
+                     history_minutes=30)
+
+         yield site
+
+
  @pytest.fixture(scope="session")
  def uk_gsp_zarr_path(ds_uk_gsp):

tests/load/test_load_sites.py ADDED
@@ -0,0 +1,14 @@
+ from ocf_data_sampler.load.site import open_site
+ import xarray as xr
+
+
+ def test_open_site(data_sites):
+     da = open_site(data_sites)
+
+     assert isinstance(da, xr.DataArray)
+     assert da.dims == ("time_utc", "site_id")
+
+     assert "capacity_kwp" in da.coords
+     assert "latitude" in da.coords
+     assert "longitude" in da.coords
+     assert da.shape == (49, 10)
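
The asserted (49, 10) shape follows directly from the data_sites fixture above: 24 hours sampled half-hourly with inclusive endpoints gives 49 timestamps for the 10 sites, as a quick standalone check confirms:

    import pandas as pd
    len(pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="30min"))  # 49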
tests/numpy_batch/test_gsp.py CHANGED
@@ -1,7 +1,6 @@
  from ocf_data_sampler.load.gsp import open_gsp

- from ocf_data_sampler.numpy_batch import convert_gsp_to_numpy_batch
- from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
+ from ocf_data_sampler.numpy_batch import convert_gsp_to_numpy_batch, GSPBatchKey

  def test_convert_gsp_to_numpy_batch(uk_gsp_zarr_path):

tests/numpy_batch/test_nwp.py CHANGED
@@ -4,9 +4,7 @@ import xarray as xr

  import pytest

- from ocf_data_sampler.numpy_batch import convert_nwp_to_numpy_batch
-
- from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
+ from ocf_data_sampler.numpy_batch import convert_nwp_to_numpy_batch, NWPBatchKey

  @pytest.fixture(scope="module")
  def da_nwp_like():
tests/numpy_batch/test_satellite.py CHANGED
@@ -5,9 +5,7 @@ import xarray as xr

  import pytest

- from ocf_data_sampler.numpy_batch import convert_satellite_to_numpy_batch
-
- from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
+ from ocf_data_sampler.numpy_batch import convert_satellite_to_numpy_batch, SatelliteBatchKey


  @pytest.fixture(scope="module")
tests/numpy_batch/test_sun_position.py CHANGED
@@ -6,7 +6,7 @@ from ocf_data_sampler.numpy_batch.sun_position import (
      calculate_azimuth_and_elevation, make_sun_position_numpy_batch
  )

- from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
+ from ocf_data_sampler.numpy_batch import GSPBatchKey


  @pytest.mark.parametrize("lat", [0, 5, 10, 23.5])
@@ -71,11 +71,11 @@ def test_make_sun_position_numpy_batch():

      batch = make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix="gsp")

-     assert GSPBatchKey.gsp_solar_elevation in batch
-     assert GSPBatchKey.gsp_solar_azimuth in batch
+     assert GSPBatchKey.solar_elevation in batch
+     assert GSPBatchKey.solar_azimuth in batch

      # The solar coords are normalised in the function
-     assert (batch[GSPBatchKey.gsp_solar_elevation]>=0).all()
-     assert (batch[GSPBatchKey.gsp_solar_elevation]<=1).all()
-     assert (batch[GSPBatchKey.gsp_solar_azimuth]>=0).all()
-     assert (batch[GSPBatchKey.gsp_solar_azimuth]<=1).all()
+     assert (batch[GSPBatchKey.solar_elevation]>=0).all()
+     assert (batch[GSPBatchKey.solar_elevation]<=1).all()
+     assert (batch[GSPBatchKey.solar_azimuth]>=0).all()
+     assert (batch[GSPBatchKey.solar_azimuth]<=1).all()
tests/torch_datasets/test_pvnet_uk_regional.py CHANGED
@@ -3,9 +3,7 @@ import tempfile

  from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
  from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
- from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
- from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
- from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
+ from ocf_data_sampler.numpy_batch import NWPBatchKey, GSPBatchKey, SatelliteBatchKey


  @pytest.fixture()
@@ -39,7 +37,7 @@ def test_pvnet(pvnet_config_filename):

      for key in [
          NWPBatchKey.nwp, SatelliteBatchKey.satellite_actual, GSPBatchKey.gsp,
-         GSPBatchKey.gsp_solar_azimuth, GSPBatchKey.gsp_solar_elevation,
+         GSPBatchKey.solar_azimuth, GSPBatchKey.solar_elevation,
      ]:
          assert key in sample

@@ -54,8 +52,8 @@ def test_pvnet(pvnet_config_filename):
      # 3 hours of 30 minute data (inclusive)
      assert sample[GSPBatchKey.gsp].shape == (7,)
      # Solar angles have same shape as GSP data
-     assert sample[GSPBatchKey.gsp_solar_azimuth].shape == (7,)
-     assert sample[GSPBatchKey.gsp_solar_elevation].shape == (7,)
+     assert sample[GSPBatchKey.solar_azimuth].shape == (7,)
+     assert sample[GSPBatchKey.solar_elevation].shape == (7,)

  def test_pvnet_no_gsp(pvnet_config_filename):