ocf-data-sampler 0.1.0.tar.gz → 0.1.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic; see the package's registry page for more details.

Files changed (89):
  1. {ocf_data_sampler-0.1.0/ocf_data_sampler.egg-info → ocf_data_sampler-0.1.2}/PKG-INFO +1 -1
  2. ocf_data_sampler-0.1.2/ocf_data_sampler/numpy_sample/collate.py +64 -0
  3. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/sample/uk_regional.py +3 -1
  4. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/select/fill_time_periods.py +1 -1
  5. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/select/time_slice_for_dataset.py +16 -13
  6. ocf_data_sampler-0.1.2/ocf_data_sampler/torch_datasets/datasets/__init__.py +6 -0
  7. ocf_data_sampler-0.1.0/ocf_data_sampler/torch_datasets/datasets/pvnet_uk_regional.py → ocf_data_sampler-0.1.2/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +114 -16
  8. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2/ocf_data_sampler.egg-info}/PKG-INFO +1 -1
  9. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler.egg-info/SOURCES.txt +2 -3
  10. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/pyproject.toml +1 -1
  11. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/conftest.py +69 -70
  12. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/load/test_load_satellite.py +3 -3
  13. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/numpy_sample/test_collate.py +4 -9
  14. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/torch_datasets/test_merge_and_fill_utils.py +0 -2
  15. ocf_data_sampler-0.1.2/tests/torch_datasets/test_pvnet_uk.py +166 -0
  16. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/torch_datasets/test_site.py +47 -36
  17. ocf_data_sampler-0.1.0/ocf_data_sampler/numpy_sample/collate.py +0 -75
  18. ocf_data_sampler-0.1.0/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -11
  19. ocf_data_sampler-0.1.0/tests/torch_datasets/conftest.py +0 -18
  20. ocf_data_sampler-0.1.0/tests/torch_datasets/test_pvnet_uk_regional.py +0 -136
  21. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/LICENSE +0 -0
  22. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/MANIFEST.in +0 -0
  23. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/README.md +0 -0
  24. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/__init__.py +0 -0
  25. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/config/__init__.py +0 -0
  26. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/config/load.py +0 -0
  27. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/config/model.py +0 -0
  28. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/config/save.py +0 -0
  29. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/constants.py +0 -0
  30. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/data/uk_gsp_locations.csv +0 -0
  31. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/load/__init__.py +0 -0
  32. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/load/gsp.py +0 -0
  33. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/load/load_dataset.py +0 -0
  34. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/load/nwp/__init__.py +0 -0
  35. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/load/nwp/nwp.py +0 -0
  36. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
  37. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
  38. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
  39. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/load/nwp/providers/utils.py +0 -0
  40. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/load/satellite.py +0 -0
  41. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/load/site.py +0 -0
  42. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/load/utils.py +0 -0
  43. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/numpy_sample/__init__.py +0 -0
  44. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/numpy_sample/datetime_features.py +0 -0
  45. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/numpy_sample/gsp.py +0 -0
  46. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/numpy_sample/nwp.py +0 -0
  47. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/numpy_sample/satellite.py +0 -0
  48. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/numpy_sample/site.py +0 -0
  49. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/numpy_sample/sun_position.py +0 -0
  50. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/sample/__init__.py +0 -0
  51. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/sample/base.py +0 -0
  52. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/sample/site.py +0 -0
  53. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/select/__init__.py +0 -0
  54. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/select/dropout.py +0 -0
  55. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -0
  56. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/select/geospatial.py +0 -0
  57. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/select/location.py +0 -0
  58. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/select/select_spatial_slice.py +0 -0
  59. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/select/select_time_slice.py +0 -0
  60. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/select/spatial_slice_for_dataset.py +0 -0
  61. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/torch_datasets/datasets/site.py +0 -0
  62. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -0
  63. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
  64. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler/utils.py +0 -0
  65. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
  66. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler.egg-info/requires.txt +0 -0
  67. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/ocf_data_sampler.egg-info/top_level.txt +0 -0
  68. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/scripts/refactor_site.py +0 -0
  69. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/setup.cfg +0 -0
  70. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/__init__.py +0 -0
  71. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/config/test_config.py +0 -0
  72. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/config/test_save.py +0 -0
  73. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/load/test_load_gsp.py +0 -0
  74. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/load/test_load_nwp.py +0 -0
  75. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/load/test_load_sites.py +0 -0
  76. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/numpy_sample/test_datetime_features.py +0 -0
  77. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/numpy_sample/test_gsp.py +0 -0
  78. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/numpy_sample/test_nwp.py +0 -0
  79. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/numpy_sample/test_satellite.py +0 -0
  80. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/numpy_sample/test_sun_position.py +0 -0
  81. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/select/test_dropout.py +0 -0
  82. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/select/test_fill_time_periods.py +0 -0
  83. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/select/test_find_contiguous_time_periods.py +0 -0
  84. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/select/test_location.py +0 -0
  85. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/select/test_select_spatial_slice.py +0 -0
  86. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/select/test_select_time_slice.py +0 -0
  87. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/test_sample/test_base.py +0 -0
  88. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/test_sample/test_site_sample.py +0 -0
  89. {ocf_data_sampler-0.1.0 → ocf_data_sampler-0.1.2}/tests/test_sample/test_uk_regional_sample.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: ocf_data_sampler
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: Sample from weather data for renewable energy prediction
5
5
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
6
  Author-email: info@openclimatefix.org
@@ -0,0 +1,64 @@
1
+ import numpy as np
2
+
3
+
4
+ def stack_np_samples_into_batch(dict_list: list[dict]) -> dict:
5
+ """Stacks list of dict samples into a dict where all samples are joined along a new axis
6
+
7
+ Args:
8
+ dict_list: A list of dict-like samples to stack
9
+
10
+ Returns:
11
+ Dict of the samples stacked with new batch dimension on axis 0
12
+ """
13
+
14
+ batch = {}
15
+
16
+ keys = list(dict_list[0].keys())
17
+
18
+ for key in keys:
19
+ # NWP is nested so treat separately
20
+ if key == "nwp":
21
+ batch["nwp"] = {}
22
+
23
+ # Unpack NWP provider keys
24
+ nwp_providers = list(dict_list[0]["nwp"].keys())
25
+
26
+ for nwp_provider in nwp_providers:
27
+ # Keys can be different for different NWPs
28
+ nwp_keys = list(dict_list[0]["nwp"][nwp_provider].keys())
29
+
30
+ # Create dict to store NWP batch for this provider
31
+ nwp_provider_batch = {}
32
+
33
+ for nwp_key in nwp_keys:
34
+ # Stack values under each NWP key for this provider
35
+ nwp_provider_batch[nwp_key] = stack_data_list(
36
+ [d["nwp"][nwp_provider][nwp_key] for d in dict_list],
37
+ nwp_key,
38
+ )
39
+
40
+ batch["nwp"][nwp_provider] = nwp_provider_batch
41
+
42
+ else:
43
+ batch[key] = stack_data_list([d[key] for d in dict_list], key)
44
+
45
+ return batch
46
+
47
+
48
+ def _key_is_constant(key: str):
49
+ return key.endswith("t0_idx") or key.endswith("channel_names")
50
+
51
+
52
+ def stack_data_list(data_list: list, key: str):
53
+ """Stack a sequence of data elements along a new axis
54
+
55
+ Args:
56
+ data_list: List of data elements to combine
57
+ key: string identifying the data type
58
+ """
59
+ if _key_is_constant(key):
60
+ # These are always the same for all examples.
61
+ return data_list[0]
62
+ else:
63
+ return np.stack(data_list)
64
+
@@ -65,7 +65,9 @@ class UKRegionalSample(SampleBase):
65
65
  raise ValueError(f"Only .pt format is supported: {path.suffix}")
66
66
 
67
67
  instance = cls()
68
- instance._data = torch.load(path)
68
+ # TODO: We should move away from using torch.load(..., weights_only=False)
69
+ # This is not recommended
70
+ instance._data = torch.load(path, weights_only=False)
69
71
  logger.debug(f"Successfully loaded UKRegionalSample from {path}")
70
72
  return instance
71
73
 
@@ -4,7 +4,7 @@ import pandas as pd
4
4
  import numpy as np
5
5
 
6
6
 
7
- def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta):
7
+ def fill_time_periods(time_periods: pd.DataFrame, freq: pd.Timedelta) -> pd.DatetimeIndex:
8
8
  start_dts = pd.to_datetime(time_periods["start_dt"].values).ceil(freq)
9
9
  end_dts = pd.to_datetime(time_periods["end_dt"].values)
10
10
  date_ranges = [pd.date_range(start_dt, end_dt, freq=freq) for start_dt, end_dt in zip(start_dts, end_dts)]
@@ -1,5 +1,6 @@
1
1
  """ Slice datasets by time"""
2
2
  import pandas as pd
3
+ import xarray as xr
3
4
 
4
5
  from ocf_data_sampler.config import Configuration
5
6
  from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
@@ -64,16 +65,8 @@ def slice_datasets_by_time(
64
65
 
65
66
  if "gsp" in datasets_dict:
66
67
  gsp_config = config.input_data.gsp
67
-
68
- sliced_datasets_dict["gsp_future"] = select_time_slice(
69
- datasets_dict["gsp"],
70
- t0,
71
- sample_period_duration=minutes(gsp_config.time_resolution_minutes),
72
- interval_start=minutes(gsp_config.time_resolution_minutes),
73
- interval_end=minutes(gsp_config.interval_end_minutes),
74
- )
75
-
76
- sliced_datasets_dict["gsp"] = select_time_slice(
68
+
69
+ da_gsp_past = select_time_slice(
77
70
  datasets_dict["gsp"],
78
71
  t0,
79
72
  sample_period_duration=minutes(gsp_config.time_resolution_minutes),
@@ -81,17 +74,27 @@ def slice_datasets_by_time(
81
74
  interval_end=minutes(0),
82
75
  )
83
76
 
84
- # Dropout on the GSP, but not the future GSP
77
+ # Dropout on the past GSP, but not the future GSP
85
78
  gsp_dropout_time = draw_dropout_time(
86
79
  t0,
87
80
  dropout_timedeltas=minutes(gsp_config.dropout_timedeltas_minutes),
88
81
  dropout_frac=gsp_config.dropout_fraction,
89
82
  )
90
83
 
91
- sliced_datasets_dict["gsp"] = apply_dropout_time(
92
- sliced_datasets_dict["gsp"],
84
+ da_gsp_past = apply_dropout_time(
85
+ da_gsp_past,
93
86
  gsp_dropout_time
94
87
  )
88
+
89
+ da_gsp_future = select_time_slice(
90
+ datasets_dict["gsp"],
91
+ t0,
92
+ sample_period_duration=minutes(gsp_config.time_resolution_minutes),
93
+ interval_start=minutes(gsp_config.time_resolution_minutes),
94
+ interval_end=minutes(gsp_config.interval_end_minutes),
95
+ )
96
+
97
+ sliced_datasets_dict["gsp"] = xr.concat([da_gsp_past, da_gsp_future], dim="time_utc")
95
98
 
96
99
  if "site" in datasets_dict:
97
100
  site_config = config.input_data.site
@@ -0,0 +1,6 @@
1
+ from .pvnet_uk import PVNetUKRegionalDataset, PVNetUKConcurrentDataset
2
+
3
+ from .site import (
4
+ convert_netcdf_to_numpy_sample,
5
+ SitesDataset
6
+ )
@@ -1,15 +1,20 @@
1
- """Torch dataset for PVNet"""
1
+ """Torch dataset for UK PVNet"""
2
+
3
+ import pkg_resources
2
4
 
3
5
  import numpy as np
4
6
  import pandas as pd
5
- import pkg_resources
6
7
  import xarray as xr
7
8
  from torch.utils.data import Dataset
8
9
  from ocf_data_sampler.config import Configuration, load_yaml_configuration
9
10
  from ocf_data_sampler.load.load_dataset import get_dataset_dict
10
- from ocf_data_sampler.select import fill_time_periods, Location, slice_datasets_by_space, slice_datasets_by_time
11
+ from ocf_data_sampler.select import (
12
+ fill_time_periods,
13
+ Location,
14
+ slice_datasets_by_space,
15
+ slice_datasets_by_time,
16
+ )
11
17
  from ocf_data_sampler.utils import minutes
12
- from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
13
18
  from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS, RSS_MEAN, RSS_STD
14
19
  from ocf_data_sampler.numpy_sample import (
15
20
  convert_nwp_to_numpy_sample,
@@ -17,13 +22,16 @@ from ocf_data_sampler.numpy_sample import (
17
22
  convert_gsp_to_numpy_sample,
18
23
  make_sun_position_numpy_sample,
19
24
  )
25
+ from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
26
+ from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
27
+ from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
28
+ from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
29
+ from ocf_data_sampler.torch_datasets.utils.valid_time_periods import find_valid_time_periods
20
30
  from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
21
31
  merge_dicts,
22
32
  fill_nans_in_arrays,
23
33
  )
24
- from ocf_data_sampler.numpy_sample.gsp import GSPSampleKey
25
- from ocf_data_sampler.numpy_sample.nwp import NWPSampleKey
26
- from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
34
+
27
35
 
28
36
  xr.set_options(keep_attrs=True)
29
37
 
@@ -65,9 +73,10 @@ def process_and_combine_datasets(
65
73
  gsp_config = config.input_data.gsp
66
74
 
67
75
  if "gsp" in dataset_dict:
68
- da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
76
+ da_gsp = dataset_dict["gsp"]
69
77
  da_gsp = da_gsp / da_gsp.effective_capacity_mwp
70
-
78
+
79
+ # Convert to NumpyBatch
71
80
  numpy_modalities.append(
72
81
  convert_gsp_to_numpy_sample(
73
82
  da_gsp,
@@ -105,6 +114,7 @@ def process_and_combine_datasets(
105
114
 
106
115
  return combined_sample
107
116
 
117
+
108
118
  def compute(xarray_dict: dict) -> dict:
109
119
  """Eagerly load a nested dictionary of xarray DataArrays"""
110
120
  for k, v in xarray_dict.items():
@@ -114,10 +124,8 @@ def compute(xarray_dict: dict) -> dict:
114
124
  xarray_dict[k] = v.compute(scheduler="single-threaded")
115
125
  return xarray_dict
116
126
 
117
- def find_valid_t0_times(
118
- datasets_dict: dict,
119
- config: Configuration,
120
- ):
127
+
128
+ def find_valid_t0_times(datasets_dict: dict, config: Configuration) -> pd.DatetimeIndex:
121
129
  """Find the t0 times where all of the requested input data is available
122
130
 
123
131
  Args:
@@ -167,7 +175,7 @@ class PVNetUKRegionalDataset(Dataset):
167
175
  self,
168
176
  config_filename: str,
169
177
  start_time: str | None = None,
170
- end_time: str| None = None,
178
+ end_time: str | None = None,
171
179
  gsp_ids: list[int] | None = None,
172
180
  ):
173
181
  """A torch Dataset for creating PVNet UK GSP samples
@@ -253,7 +261,7 @@ class PVNetUKRegionalDataset(Dataset):
253
261
  def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> dict:
254
262
  """Generate a sample for the given coordinates.
255
263
 
256
- Useful for users to generate samples by GSP ID.
264
+ Useful for users to generate specific samples.
257
265
 
258
266
  Args:
259
267
  t0: init-time for sample
@@ -265,4 +273,94 @@ class PVNetUKRegionalDataset(Dataset):
265
273
 
266
274
  location = self.location_lookup[gsp_id]
267
275
 
268
- return self._get_sample(t0, location)
276
+ return self._get_sample(t0, location)
277
+
278
+
279
+ class PVNetUKConcurrentDataset(Dataset):
280
+ def __init__(
281
+ self,
282
+ config_filename: str,
283
+ start_time: str | None = None,
284
+ end_time: str | None = None,
285
+ gsp_ids: list[int] | None = None,
286
+ ):
287
+ """A torch Dataset for creating concurrent samples of PVNet UK regional data
288
+
289
+ Each concurrent sample includes the data from all GSPs for a single t0 time
290
+
291
+ Args:
292
+ config_filename: Path to the configuration file
293
+ start_time: Limit the init-times to be after this
294
+ end_time: Limit the init-times to be before this
295
+ gsp_ids: List of all GSP IDs included in each sample. Defaults to all
296
+ """
297
+
298
+ config = load_yaml_configuration(config_filename)
299
+
300
+ datasets_dict = get_dataset_dict(config)
301
+
302
+ # Get t0 times where all input data is available
303
+ valid_t0_times = find_valid_t0_times(datasets_dict, config)
304
+
305
+ # Filter t0 times to given range
306
+ if start_time is not None:
307
+ valid_t0_times = valid_t0_times[valid_t0_times>=pd.Timestamp(start_time)]
308
+
309
+ if end_time is not None:
310
+ valid_t0_times = valid_t0_times[valid_t0_times<=pd.Timestamp(end_time)]
311
+
312
+ # Construct list of locations to sample from
313
+ locations = get_gsp_locations(gsp_ids)
314
+
315
+ # Assign coords and indices to self
316
+ self.valid_t0_times = valid_t0_times
317
+ self.locations = locations
318
+
319
+ # Assign config and input data to self
320
+ self.datasets_dict = datasets_dict
321
+ self.config = config
322
+
323
+
324
+ def __len__(self):
325
+ return len(self.valid_t0_times)
326
+
327
+
328
+ def _get_sample(self, t0: pd.Timestamp) -> dict:
329
+ """Generate a concurrent PVNet sample for given init-time
330
+
331
+ Args:
332
+ t0: init-time for sample
333
+ """
334
+ # Slice by time then load to avoid loading the data multiple times from disk
335
+ sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
336
+ sample_dict = compute(sample_dict)
337
+
338
+ gsp_samples = []
339
+
340
+ # Prepare sample for each GSP
341
+ for location in self.locations:
342
+ gsp_sample_dict = slice_datasets_by_space(sample_dict, location, self.config)
343
+ gsp_numpy_sample = process_and_combine_datasets(
344
+ gsp_sample_dict, self.config, t0, location
345
+ )
346
+ gsp_samples.append(gsp_numpy_sample)
347
+
348
+ # Stack GSP samples
349
+ return stack_np_samples_into_batch(gsp_samples)
350
+
351
+
352
+ def __getitem__(self, idx):
353
+ return self._get_sample(self.valid_t0_times[idx])
354
+
355
+
356
+ def get_sample(self, t0: pd.Timestamp) -> dict:
357
+ """Generate a sample for the given init-time.
358
+
359
+ Useful for users to generate specific samples.
360
+
361
+ Args:
362
+ t0: init-time for sample
363
+ """
364
+ # Check data is availablle for init-time t0
365
+ assert t0 in self.valid_t0_times
366
+ return self._get_sample(t0)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: ocf_data_sampler
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: Sample from weather data for renewable energy prediction
5
5
  Author: James Fulton, Peter Dudfield, and the Open Climate Fix team
6
6
  Author-email: info@openclimatefix.org
@@ -50,7 +50,7 @@ ocf_data_sampler/select/select_time_slice.py
50
50
  ocf_data_sampler/select/spatial_slice_for_dataset.py
51
51
  ocf_data_sampler/select/time_slice_for_dataset.py
52
52
  ocf_data_sampler/torch_datasets/datasets/__init__.py
53
- ocf_data_sampler/torch_datasets/datasets/pvnet_uk_regional.py
53
+ ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py
54
54
  ocf_data_sampler/torch_datasets/datasets/site.py
55
55
  ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py
56
56
  ocf_data_sampler/torch_datasets/utils/valid_time_periods.py
@@ -78,7 +78,6 @@ tests/select/test_select_time_slice.py
78
78
  tests/test_sample/test_base.py
79
79
  tests/test_sample/test_site_sample.py
80
80
  tests/test_sample/test_uk_regional_sample.py
81
- tests/torch_datasets/conftest.py
82
81
  tests/torch_datasets/test_merge_and_fill_utils.py
83
- tests/torch_datasets/test_pvnet_uk_regional.py
82
+ tests/torch_datasets/test_pvnet_uk.py
84
83
  tests/torch_datasets/test_site.py
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "ocf_data_sampler"
7
- version = "0.1.0"
7
+ version = "0.1.2"
8
8
  license = { file = "LICENSE" }
9
9
  readme = "README.md"
10
10
  description = "Sample from weather data for renewable energy prediction"
@@ -1,14 +1,15 @@
1
+ import pytest
2
+
1
3
  import os
2
4
  import numpy as np
3
5
  import pandas as pd
4
- import pytest
5
6
  import xarray as xr
6
- import tempfile
7
- from typing import Generator
7
+ import dask.array
8
8
 
9
9
  from ocf_data_sampler.config.model import Site
10
10
  from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
11
11
 
12
+
12
13
  _top_test_directory = os.path.dirname(os.path.realpath(__file__))
13
14
 
14
15
  @pytest.fixture()
@@ -18,40 +19,27 @@ def test_config_filename():
18
19
 
19
20
  @pytest.fixture(scope="session")
20
21
  def config_filename():
21
- return f"{os.path.dirname(os.path.abspath(__file__))}/test_data/configs/pvnet_test_config.yaml"
22
+ return f"{_top_test_directory}/test_data/configs/pvnet_test_config.yaml"
22
23
 
23
24
 
24
25
  @pytest.fixture(scope="session")
25
- def sat_zarr_path():
26
-
27
- # Load dataset which only contains coordinates, but no data
28
- ds = xr.open_zarr(
29
- f"{os.path.dirname(os.path.abspath(__file__))}/test_data/non_hrv_shell.zarr.zip"
30
- ).compute()
31
-
32
- # Add time coord
33
- ds = ds.assign_coords(time=pd.date_range("2023-01-01 00:00", "2023-01-02 23:55", freq="5min"))
34
-
35
- # Add data to dataset
36
- ds["data"] = xr.DataArray(
37
- np.zeros([len(ds[c]) for c in ds.coords], dtype=np.float32),
38
- coords=ds.coords,
39
- )
40
-
41
- # Transpose to variables, time, y, x (just in case)
42
- ds = ds.transpose("variable", "time", "y_geostationary", "x_geostationary")
26
+ def session_tmp_path(tmp_path_factory):
27
+ return tmp_path_factory.mktemp("data")
43
28
 
44
- # add 100,000 to x_geostationary, this to make sure the fix index is within the satellite image
45
- ds["x_geostationary"] = ds["x_geostationary"] - 200_000
46
29
 
47
- # Add some NaNs
48
- ds["data"].values[:, :, 0, 0] = np.nan
49
-
50
- # make sure channel values are strings
51
- ds["variable"] = ds["variable"].astype(str)
52
-
53
- # add data attrs area
54
- ds["data"].attrs["area"] = (
30
+ @pytest.fixture(scope="session")
31
+ def sat_zarr_path(session_tmp_path):
32
+
33
+ # Define coords for satellite-like dataset
34
+ variables = [
35
+ 'IR_016', 'IR_039', 'IR_087', 'IR_097', 'IR_108', 'IR_120',
36
+ 'IR_134', 'VIS006', 'VIS008', 'WV_062', 'WV_073',
37
+ ]
38
+ x = np.linspace(start=15002, stop=-1824245, num=100)
39
+ y = np.linspace(start=4191563, stop=5304712, num=100)
40
+ times = pd.date_range("2023-01-01 00:00", "2023-01-01 23:55", freq="5min")
41
+
42
+ area_string = (
55
43
  """msg_seviri_rss_3km:
56
44
  description: MSG SEVIRI Rapid Scanning Service area definition with 3 km resolution
57
45
  projection:
@@ -73,16 +61,31 @@ def sat_zarr_path():
73
61
  units: m
74
62
  """
75
63
  )
76
-
77
- # Specifiy chunking
78
- ds = ds.chunk({"time": 10, "variable": -1, "y_geostationary": -1, "x_geostationary": -1})
64
+
65
+ # Create satellite-like data with some NaNs
66
+ data = dask.array.zeros(
67
+ shape=(len(variables), len(times), len(y), len(x)),
68
+ chunks=(-1, 10, -1, -1),
69
+ dtype=np.float32
70
+ )
71
+ data [:, 10, :, :] = np.nan
72
+
73
+ ds = xr.DataArray(
74
+ data=data,
75
+ coords=dict(
76
+ variable=variables,
77
+ time=times,
78
+ y_geostationary=y,
79
+ x_geostationary=x,
80
+ ),
81
+ attrs=dict(area=area_string),
82
+ ).to_dataset(name="data")
79
83
 
80
84
  # Save temporarily as a zarr
81
- with tempfile.TemporaryDirectory() as tmpdir:
82
- zarr_path = f"{tmpdir}/test_sat.zarr"
83
- ds.to_zarr(zarr_path)
85
+ zarr_path = session_tmp_path / "test_sat.zarr"
86
+ ds.to_zarr(zarr_path)
84
87
 
85
- yield zarr_path
88
+ yield zarr_path
86
89
 
87
90
 
88
91
  @pytest.fixture(scope="session")
@@ -112,7 +115,7 @@ def ds_nwp_ukv():
112
115
 
113
116
 
114
117
  @pytest.fixture(scope="session")
115
- def nwp_ukv_zarr_path(ds_nwp_ukv):
118
+ def nwp_ukv_zarr_path(session_tmp_path, ds_nwp_ukv):
116
119
  ds = ds_nwp_ukv.chunk(
117
120
  {
118
121
  "init_time": 1,
@@ -122,10 +125,9 @@ def nwp_ukv_zarr_path(ds_nwp_ukv):
122
125
  "y": 50,
123
126
  }
124
127
  )
125
- with tempfile.TemporaryDirectory() as tmpdir:
126
- filename = tmpdir + "/ukv_nwp.zarr"
127
- ds.to_zarr(filename)
128
- yield filename
128
+ zarr_path = session_tmp_path / "ukv_nwp.zarr"
129
+ ds.to_zarr(zarr_path)
130
+ yield zarr_path
129
131
 
130
132
 
131
133
  @pytest.fixture(scope="session")
@@ -155,7 +157,7 @@ def ds_nwp_ecmwf():
155
157
 
156
158
 
157
159
  @pytest.fixture(scope="session")
158
- def nwp_ecmwf_zarr_path(ds_nwp_ecmwf):
160
+ def nwp_ecmwf_zarr_path(session_tmp_path, ds_nwp_ecmwf):
159
161
  ds = ds_nwp_ecmwf.chunk(
160
162
  {
161
163
  "init_time": 1,
@@ -165,10 +167,10 @@ def nwp_ecmwf_zarr_path(ds_nwp_ecmwf):
165
167
  "latitude": 50,
166
168
  }
167
169
  )
168
- with tempfile.TemporaryDirectory() as tmpdir:
169
- filename = tmpdir + "/ukv_ecmwf.zarr"
170
- ds.to_zarr(filename)
171
- yield filename
170
+
171
+ zarr_path = session_tmp_path / "ukv_ecmwf.zarr"
172
+ ds.to_zarr(zarr_path)
173
+ yield zarr_path
172
174
 
173
175
 
174
176
  @pytest.fixture(scope="session")
@@ -201,7 +203,7 @@ def ds_uk_gsp():
201
203
 
202
204
 
203
205
  @pytest.fixture(scope="session")
204
- def data_sites() -> Generator[Site, None, None]:
206
+ def data_sites(session_tmp_path) -> Site:
205
207
  """
206
208
  Make fake data for sites
207
209
  Returns: filename for netcdf file, and csv metadata
@@ -245,30 +247,27 @@ def data_sites() -> Generator[Site, None, None]:
245
247
  "generation_kw": da_gen,
246
248
  })
247
249
 
248
- with tempfile.TemporaryDirectory() as tmpdir:
249
- filename = tmpdir + "/sites.netcdf"
250
- filename_csv = tmpdir + "/sites_metadata.csv"
251
- generation.to_netcdf(filename)
252
- meta_df.to_csv(filename_csv)
253
-
254
- site = Site(
255
- file_path=filename,
256
- metadata_file_path=filename_csv,
257
- interval_start_minutes=-30,
258
- interval_end_minutes=60,
259
- time_resolution_minutes=30,
260
- )
250
+ filename = f"{session_tmp_path}/sites.netcdf"
251
+ filename_csv = f"{session_tmp_path}/sites_metadata.csv"
252
+ generation.to_netcdf(filename)
253
+ meta_df.to_csv(filename_csv)
254
+
255
+ site = Site(
256
+ file_path=filename,
257
+ metadata_file_path=filename_csv,
258
+ interval_start_minutes=-30,
259
+ interval_end_minutes=60,
260
+ time_resolution_minutes=30,
261
+ )
261
262
 
262
- yield site
263
+ yield site
263
264
 
264
265
 
265
266
  @pytest.fixture(scope="session")
266
- def uk_gsp_zarr_path(ds_uk_gsp):
267
-
268
- with tempfile.TemporaryDirectory() as tmpdir:
269
- filename = tmpdir + "/uk_gsp.zarr"
270
- ds_uk_gsp.to_zarr(filename)
271
- yield filename
267
+ def uk_gsp_zarr_path(session_tmp_path, ds_uk_gsp):
268
+ zarr_path = session_tmp_path / "uk_gsp.zarr"
269
+ ds_uk_gsp.to_zarr(zarr_path)
270
+ yield zarr_path
272
271
 
273
272
 
274
273
  @pytest.fixture()
@@ -8,10 +8,10 @@ def test_open_satellite(sat_zarr_path):
8
8
 
9
9
  assert isinstance(da, xr.DataArray)
10
10
  assert da.dims == ("time_utc", "channel", "x_geostationary", "y_geostationary")
11
- # 576 is 2 days of data at 5 minutes intervals, 12 * 24 * 2
11
+ # 288 is 1 days of data at 5 minutes intervals, 12 * 24
12
12
  # There are 11 channels
13
- # There are 49 x 20 pixels
14
- assert da.shape == (576, 11, 49, 20)
13
+ # There are 100 x 100 pixels
14
+ assert da.shape == (288, 11, 100, 100)
15
15
  assert np.issubdtype(da.dtype, np.number)
16
16
 
17
17
 
@@ -1,17 +1,12 @@
1
- from ocf_data_sampler.numpy_sample import GSPSampleKey, SatelliteSampleKey
2
1
  from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
3
- from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
2
+ from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
4
3
 
5
4
 
6
- def test_pvnet(pvnet_config_filename):
5
+ def test_stack_np_samples_into_batch(pvnet_config_filename):
7
6
 
8
7
  # Create dataset object
9
8
  dataset = PVNetUKRegionalDataset(pvnet_config_filename)
10
9
 
11
- assert len(dataset.locations) == 317
12
- assert len(dataset.valid_t0_times) == 39
13
- assert len(dataset) == 317 * 39
14
-
15
10
  # Generate 2 samples
16
11
  sample1 = dataset[0]
17
12
  sample2 = dataset[1]
@@ -22,5 +17,5 @@ def test_pvnet(pvnet_config_filename):
22
17
  assert "nwp" in batch
23
18
  assert isinstance(batch["nwp"], dict)
24
19
  assert "ukv" in batch["nwp"]
25
- assert GSPSampleKey.gsp in batch
26
- assert SatelliteSampleKey.satellite_actual in batch
20
+ assert "gsp" in batch
21
+ assert "satellite_actual" in batch
@@ -33,9 +33,7 @@ def test_fill_nans_in_arrays():
33
33
 
34
34
  result = fill_nans_in_arrays(nested_dict)
35
35
 
36
- assert not np.isnan(result["array1"]).any()
37
36
  assert np.array_equal(result["array1"], np.array([1.0, 0.0, 3.0, 0.0]))
38
- assert not np.isnan(result["nested"]["array2"]).any()
39
37
  assert np.array_equal(result["nested"]["array2"], np.array([0.0, 2.0, 0.0, 4.0]))
40
38
  assert result["string_key"] == "not_an_array"
41
39