ocf-data-sampler 0.5.11__tar.gz → 0.6.2__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (78)
  1. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/PKG-INFO +7 -3
  2. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/README.md +5 -1
  3. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/config/model.py +27 -41
  4. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/gsp.py +4 -2
  5. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/nwp.py +0 -1
  6. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/utils.py +1 -1
  7. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/open_xarray_tensorstore.py +26 -7
  8. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/satellite.py +1 -1
  9. ocf_data_sampler-0.6.2/ocf_data_sampler/load/site.py +73 -0
  10. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/nwp.py +1 -1
  11. ocf_data_sampler-0.6.2/ocf_data_sampler/select/diff_channels.py +25 -0
  12. ocf_data_sampler-0.6.2/ocf_data_sampler/select/dropout.py +59 -0
  13. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/select/location.py +2 -2
  14. ocf_data_sampler-0.6.2/ocf_data_sampler/select/select_spatial_slice.py +110 -0
  15. ocf_data_sampler-0.6.2/ocf_data_sampler/select/select_time_slice.py +107 -0
  16. ocf_data_sampler-0.6.2/ocf_data_sampler/torch_datasets/datasets/picklecache.py +33 -0
  17. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py +17 -15
  18. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/datasets/site.py +27 -38
  19. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/utils/__init__.py +2 -1
  20. ocf_data_sampler-0.6.2/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py +59 -0
  21. ocf_data_sampler-0.6.2/ocf_data_sampler/torch_datasets/utils/diff_nwp_data.py +20 -0
  22. ocf_data_sampler-0.6.2/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +50 -0
  23. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/utils/time_slice_for_dataset.py +22 -30
  24. ocf_data_sampler-0.5.11/ocf_data_sampler/torch_datasets/sample/base.py → ocf_data_sampler-0.6.2/ocf_data_sampler/torch_datasets/utils/torch_batch_utils.py +2 -29
  25. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/utils.py +18 -6
  26. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler.egg-info/PKG-INFO +7 -3
  27. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler.egg-info/SOURCES.txt +4 -4
  28. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/pyproject.toml +1 -1
  29. ocf_data_sampler-0.5.11/ocf_data_sampler/load/site.py +0 -59
  30. ocf_data_sampler-0.5.11/ocf_data_sampler/select/dropout.py +0 -61
  31. ocf_data_sampler-0.5.11/ocf_data_sampler/select/select_spatial_slice.py +0 -216
  32. ocf_data_sampler-0.5.11/ocf_data_sampler/select/select_time_slice.py +0 -143
  33. ocf_data_sampler-0.5.11/ocf_data_sampler/torch_datasets/sample/__init__.py +0 -3
  34. ocf_data_sampler-0.5.11/ocf_data_sampler/torch_datasets/sample/site.py +0 -48
  35. ocf_data_sampler-0.5.11/ocf_data_sampler/torch_datasets/sample/uk_regional.py +0 -262
  36. ocf_data_sampler-0.5.11/ocf_data_sampler/torch_datasets/utils/config_normalization_values_to_dicts.py +0 -57
  37. ocf_data_sampler-0.5.11/ocf_data_sampler/torch_datasets/utils/merge_and_fill_utils.py +0 -29
  38. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/LICENSE +0 -0
  39. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/__init__.py +0 -0
  40. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/config/__init__.py +0 -0
  41. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/config/load.py +0 -0
  42. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/config/save.py +0 -0
  43. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/data/uk_gsp_locations_20220314.csv +0 -0
  44. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/data/uk_gsp_locations_20250109.csv +0 -0
  45. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/__init__.py +0 -0
  46. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/load_dataset.py +0 -0
  47. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/__init__.py +0 -0
  48. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/__init__.py +0 -0
  49. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/cloudcasting.py +0 -0
  50. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/ecmwf.py +0 -0
  51. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/gfs.py +0 -0
  52. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/icon.py +0 -0
  53. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/nwp/providers/ukv.py +0 -0
  54. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/load/utils.py +0 -0
  55. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/__init__.py +0 -0
  56. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/collate.py +0 -0
  57. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/common_types.py +0 -0
  58. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/datetime_features.py +0 -0
  59. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/gsp.py +0 -0
  60. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/satellite.py +0 -0
  61. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/site.py +0 -0
  62. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/numpy_sample/sun_position.py +0 -0
  63. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/select/__init__.py +0 -0
  64. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/select/fill_time_periods.py +0 -0
  65. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/select/find_contiguous_time_periods.py +0 -0
  66. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/select/geospatial.py +0 -0
  67. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/datasets/__init__.py +0 -0
  68. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/utils/add_alterate_coordinate_projections.py +0 -0
  69. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/utils/spatial_slice_for_dataset.py +0 -0
  70. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/utils/valid_time_periods.py +0 -0
  71. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler/torch_datasets/utils/validation_utils.py +0 -0
  72. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler.egg-info/dependency_links.txt +0 -0
  73. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler.egg-info/requires.txt +0 -0
  74. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/ocf_data_sampler.egg-info/top_level.txt +0 -0
  75. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/scripts/download_gsp_location_data.py +0 -0
  76. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/scripts/refactor_site.py +0 -0
  77. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/setup.cfg +0 -0
  78. {ocf_data_sampler-0.5.11 → ocf_data_sampler-0.6.2}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ocf-data-sampler
- Version: 0.5.11
+ Version: 0.6.2
  Author: James Fulton, Peter Dudfield
  Author-email: Open Climate Fix team <info@openclimatefix.org>
  License: MIT License
@@ -28,7 +28,7 @@ License: MIT License
  Project-URL: repository, https://github.com/openclimatefix/ocf-data-sampler
  Classifier: Programming Language :: Python :: 3
  Classifier: License :: OSI Approved :: MIT License
- Requires-Python: >=3.11
+ Requires-Python: <3.14,>=3.11
  Description-Content-Type: text/markdown
  Requires-Dist: torch
  Requires-Dist: numpy
@@ -50,7 +50,7 @@ Requires-Dist: zarr>=3
  # ocf-data-sampler

  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
- [![All Contributors](https://img.shields.io/badge/all_contributors-14-orange.svg?style=flat-square)](#contributors-)
+ [![All Contributors](https://img.shields.io/badge/all_contributors-16-orange.svg?style=flat-square)](#contributors-)
  <!-- ALL-CONTRIBUTORS-BADGE:END -->

  [![tags badge](https://img.shields.io/github/v/tag/openclimatefix/ocf-data-sampler?include_prereleases&sort=semver&color=FFAC5F)](https://github.com/openclimatefix/ocf-data-sampler/tags)
@@ -137,6 +137,10 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
  <td align="center" valign="top" width="14.28%"><a href="https://drona-gyawali.github.io/"><img src="https://avatars.githubusercontent.com/u/170401554?v=4?s=100" width="100px;" alt="Dorna Raj Gyawali"/><br /><sub><b>Dorna Raj Gyawali</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=drona-gyawali" title="Code">💻</a></td>
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/adnanhashmi25"><img src="https://avatars.githubusercontent.com/u/55550094?v=4?s=100" width="100px;" alt="Adnan Hashmi"/><br /><sub><b>Adnan Hashmi</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=adnanhashmi25" title="Code">💻</a></td>
  </tr>
+ <tr>
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/utsav-pal"><img src="https://avatars.githubusercontent.com/u/159793156?v=4?s=100" width="100px;" alt="utsav-pal"/><br /><sub><b>utsav-pal</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=utsav-pal" title="Code">💻</a></td>
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/zaryab-ali"><img src="https://avatars.githubusercontent.com/u/85732412?v=4?s=100" width="100px;" alt="zaryab-ali"/><br /><sub><b>zaryab-ali</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=zaryab-ali" title="Code">💻</a></td>
+ </tr>
  </tbody>
  </table>

@@ -1,7 +1,7 @@
  # ocf-data-sampler

  <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
- [![All Contributors](https://img.shields.io/badge/all_contributors-14-orange.svg?style=flat-square)](#contributors-)
+ [![All Contributors](https://img.shields.io/badge/all_contributors-16-orange.svg?style=flat-square)](#contributors-)
  <!-- ALL-CONTRIBUTORS-BADGE:END -->

  [![tags badge](https://img.shields.io/github/v/tag/openclimatefix/ocf-data-sampler?include_prereleases&sort=semver&color=FFAC5F)](https://github.com/openclimatefix/ocf-data-sampler/tags)
@@ -88,6 +88,10 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
  <td align="center" valign="top" width="14.28%"><a href="https://drona-gyawali.github.io/"><img src="https://avatars.githubusercontent.com/u/170401554?v=4?s=100" width="100px;" alt="Dorna Raj Gyawali"/><br /><sub><b>Dorna Raj Gyawali</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=drona-gyawali" title="Code">💻</a></td>
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/adnanhashmi25"><img src="https://avatars.githubusercontent.com/u/55550094?v=4?s=100" width="100px;" alt="Adnan Hashmi"/><br /><sub><b>Adnan Hashmi</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=adnanhashmi25" title="Code">💻</a></td>
  </tr>
+ <tr>
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/utsav-pal"><img src="https://avatars.githubusercontent.com/u/159793156?v=4?s=100" width="100px;" alt="utsav-pal"/><br /><sub><b>utsav-pal</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=utsav-pal" title="Code">💻</a></td>
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/zaryab-ali"><img src="https://avatars.githubusercontent.com/u/85732412?v=4?s=100" width="100px;" alt="zaryab-ali"/><br /><sub><b>zaryab-ali</b></sub></a><br /><a href="https://github.com/openclimatefix/ocf-data-sampler/commits?author=zaryab-ali" title="Code">💻</a></td>
+ </tr>
  </tbody>
  </table>

@@ -7,7 +7,7 @@ Prefix with a protocol like s3:// to read from alternative filesystems.
  from collections.abc import Iterator
  from typing import Literal

- from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
+ from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
  from typing_extensions import override

  NWP_PROVIDERS = [
@@ -23,10 +23,7 @@ NWP_PROVIDERS = [
  class Base(BaseModel):
      """Pydantic Base model where no extras can be added."""

-     class Config:
-         """Config class."""
-
-         extra = "forbid"  # forbid use of extra kwargs
+     model_config = ConfigDict(extra="forbid")


  class General(Base):
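
The hunk above moves the package from pydantic v1's nested `class Config` to the v2 `model_config = ConfigDict(...)` idiom. A minimal, self-contained sketch of the behaviour (the `Example` model below is illustrative, not from the package):

    from pydantic import BaseModel, ConfigDict, ValidationError

    class Base(BaseModel):
        # pydantic v2 style: replaces the v1 nested `class Config` with extra = "forbid"
        model_config = ConfigDict(extra="forbid")

    class Example(Base):
        name: str = "ok"

    try:
        Example(name="ok", unexpected_kwarg=1)
    except ValidationError as err:
        print(err)  # extra inputs are not permitted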
@@ -90,12 +87,17 @@ class DropoutMixin(Base):
          "negative or zero.",
      )

-     dropout_fraction: float|list[float] = Field(
-         default=0,
+     dropout_fraction: float | list[float] = Field(
+         default=0.0,
          description="Either a float(Chance of dropout being applied to each sample) or a list of "
          "floats (probability that dropout of the corresponding timedelta is applied)",
      )

+     dropout_value: float = Field(
+         default=0.0,
+         description="The value to use for dropped out values. "
+         "Idea is to use -1, but to be backwards comptaible we've put the default as 0")
+
      @field_validator("dropout_timedeltas_minutes")
      def dropout_timedeltas_minutes_negative(cls, v: list[int]) -> list[int]:
          """Validate 'dropout_timedeltas_minutes'."""
@@ -106,31 +108,22 @@ class DropoutMixin(Base):


      @field_validator("dropout_fraction")
-     def dropout_fractions(cls, dropout_frac: float|list[float]) -> float|list[float]:
+     def dropout_fractions(cls, dropout_frac: float | list[float]) -> float | list[float]:
          """Validate 'dropout_frac'."""
-         from math import isclose
-         if isinstance(dropout_frac, float):
-             if not (dropout_frac <= 1):
-                 raise ValueError("Input should be less than or equal to 1")
-             elif not (dropout_frac >= 0):
-                 raise ValueError("Input should be greater than or equal to 0")
+         if isinstance(dropout_frac, float | int):
+             if not (0<= dropout_frac <= 1):
+                 raise ValueError("Dropout fractions must be in range [0, 1]")

          elif isinstance(dropout_frac, list):
              if not dropout_frac:
                  raise ValueError("List cannot be empty")

-             if not all(isinstance(i, float) for i in dropout_frac):
-                 raise ValueError("All elements in the list must be floats")
-
              if not all(0 <= i <= 1 for i in dropout_frac):
-                 raise ValueError("Each float in the list must be between 0 and 1")
-
-             if not isclose(sum(dropout_frac), 1.0, rel_tol=1e-9):
-                 raise ValueError("Sum of all floats in the list must be 1.0")
+                 raise ValueError("All dropout fractions must be in range [0, 1]")

+             if not (0 <= sum(dropout_frac) <= 1):
+                 raise ValueError("The sum of dropout fractions must be in range [0, 1]")

-         else:
-             raise TypeError("Must be either a float or a list of floats")
          return dropout_frac


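The rewritten validator relaxes the old rules: a scalar `dropout_fraction` must lie in [0, 1], and a list must be non-empty with each entry and the total in [0, 1]; the previous requirement that the list sums to exactly 1.0 is gone, so any remainder can be treated as the chance of no dropout. A standalone sketch of that logic (not an import from the package):

    def validate_dropout_fraction(dropout_frac):
        if isinstance(dropout_frac, (float, int)):
            if not (0 <= dropout_frac <= 1):
                raise ValueError("Dropout fractions must be in range [0, 1]")
        elif isinstance(dropout_frac, list):
            if not dropout_frac:
                raise ValueError("List cannot be empty")
            if not all(0 <= i <= 1 for i in dropout_frac):
                raise ValueError("All dropout fractions must be in range [0, 1]")
            if not (0 <= sum(dropout_frac) <= 1):
                raise ValueError("The sum of dropout fractions must be in range [0, 1]")
        return dropout_frac

    validate_dropout_fraction(0.1)         # scalar chance of dropout
    validate_dropout_fraction([0.2, 0.3])  # sums to 0.5: valid now, rejected before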
@@ -172,23 +165,6 @@ class NormalisationConstantsMixin(Base):
      """Normalisation constants for multiple channels."""
      normalisation_constants: dict[str, NormalisationValues]

-     @property
-     def channel_means(self) -> dict[str, float]:
-         """Return the channel means."""
-         return {
-             channel: norm_values.mean
-             for channel, norm_values in self.normalisation_constants.items()
-         }
-
-
-     @property
-     def channel_stds(self) -> dict[str, float]:
-         """Return the channel standard deviations."""
-         return {
-             channel: norm_values.std
-             for channel, norm_values in self.normalisation_constants.items()
-         }
-

  class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin, NormalisationConstantsMixin):
      """Satellite configuration model."""
@@ -363,9 +339,19 @@ class InputData(Base):
      site: Site | None = None
      solar_position: SolarPosition | None = None

+     @model_validator(mode="after")
+     def check_site_or_gsp(self) -> "InputData":
+         """Ensure that either `site` or `gsp` is provided in the input data."""
+         if self.site is None and self.gsp is None:
+             raise ValueError(
+                 "You must provide either `site` or `gsp` in the `input_data`",
+             )
+
+         return self
+

  class Configuration(Base):
      """Configuration model for the dataset."""

      general: General = General()
-     input_data: InputData = InputData()
+     input_data: InputData = Field(default_factory=InputData)
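
Switching to `Field(default_factory=InputData)` defers construction of the default until a `Configuration` is actually built, which matters now that `InputData` can raise. A hedged sketch of the new constraint, assuming the remaining `InputData` fields stay optional (pydantic's ValidationError subclasses ValueError, so this catch works either way):

    from ocf_data_sampler.config.model import InputData

    try:
        InputData()  # neither `site` nor `gsp` supplied
    except ValueError as err:
        print(err)  # You must provide either `site` or `gsp` in the `input_data`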
@@ -32,7 +32,7 @@ def open_gsp(
      boundaries_version: str = "20220314",
      public: bool = False,
  ) -> xr.DataArray:
-     """Open the GSP data and validates its data types.
+     """Open and eagerly load the GSP data and validates its data types.

      Args:
          zarr_path: Path to the GSP zarr data
@@ -93,4 +93,6 @@ def open_gsp(
              dtype = gsp_da.coords[coord].dtype
              raise TypeError(f"{coord} should be {expected_dtype.__name__}, not {dtype}")

-     return gsp_da
+     # Below we load the data eagerly into memory - this makes the dataset faster to sample from, but
+     # at the cost of a little extra memory usage
+     return gsp_da.compute()
@@ -29,7 +29,6 @@ def _validate_nwp_data(data_array: xr.DataArray, provider: str) -> None:
      common_expected_dtypes = {
          "init_time_utc": np.datetime64,
          "step": np.timedelta64,
-         "channel": (np.str_, np.object_),
      }

      geographic_spatial_dtypes = {
@@ -75,7 +75,7 @@ def _tensostore_open_zarr_paths(zarr_path: str | list[str], time_dim: str) -> xr
          zarr_path = sorted(glob(zarr_path))

      if isinstance(zarr_path, list | tuple):
-         ds = open_zarrs(zarr_path, concat_dim=time_dim).sortby(time_dim)
+         ds = open_zarrs(zarr_path, concat_dim=time_dim, data_source="nwp").sortby(time_dim)
      else:
          ds = open_zarr(zarr_path)
      return ds
@@ -14,6 +14,7 @@ References:
  [2] https://www.apache.org/licenses/LICENSE-2.0
  """

+ import logging
  import os.path
  import re

@@ -26,6 +27,7 @@ from xarray_tensorstore import (
      _TensorStoreAdapter,
  )

+ logger = logging.getLogger(__name__)

  def _zarr_spec_from_path(path: str, zarr_format: int) -> ...:
      if re.match(r"\w+\://", path):  # path is a URI
@@ -127,6 +129,7 @@ def open_zarrs(
      concat_dim: str,
      context: ts.Context | None = None,
      mask_and_scale: bool = True,
+     data_source: str = "unknown",
  ) -> xr.Dataset:
      """Open multiple zarrs with TensorStore.

@@ -135,6 +138,7 @@ def open_zarrs(
          concat_dim: Dimension along which to concatenate the data variables.
          context: TensorStore context.
          mask_and_scale: Whether to mask and scale the data.
+         data_source: Which data source is being opened. Used for warning context.

      Returns:
          Concatenated Dataset with all data variables opened via TensorStore.
@@ -143,13 +147,28 @@ def open_zarrs(
          context = ts.Context()

      ds_list = [xr.open_zarr(p, mask_and_scale=mask_and_scale, decode_timedelta=True) for p in paths]
-     ds = xr.concat(
-         ds_list,
-         dim=concat_dim,
-         data_vars="minimal",
-         compat="equals",
-         combine_attrs="drop_conflicts",
-     )
+     try:
+         ds = xr.concat(
+             ds_list,
+             dim=concat_dim,
+             data_vars="minimal",
+             compat="equals",
+             combine_attrs="drop_conflicts",
+             join="exact",
+         )
+     except ValueError:
+         logger.warning(f"Coordinate mismatch found in {data_source} input data. "
+                        f"The coordinates will be overwritten! "
+                        f"This might be fine for satellite data. "
+                        f"Proceed with caution.")
+         ds = xr.concat(
+             ds_list,
+             dim=concat_dim,
+             data_vars="minimal",
+             compat="equals",
+             combine_attrs="drop_conflicts",
+             join="override",
+         )

      if mask_and_scale:
          _raise_if_mask_and_scale_used_for_data_vars(ds)
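
The retry logic above hinges on xarray's `join` modes: `join="exact"` fails fast when coordinates disagree, while `join="override"` keeps the first dataset's coordinates and stacks anyway. A toy demonstration (the datasets here are made up):

    import xarray as xr

    a = xr.Dataset({"v": ("x", [1.0, 2.0])}, coords={"x": [0.0, 1.0], "t": 0})
    b = xr.Dataset({"v": ("x", [3.0, 4.0])}, coords={"x": [0.0, 1.001], "t": 1})

    try:
        xr.concat([a, b], dim="t", join="exact")  # raises: "x" coords differ
    except ValueError:
        # "override" keeps the first dataset's "x" coords and stacks anyway
        ds = xr.concat([a, b], dim="t", join="override")
        print(ds["x"].values)  # [0. 1.]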
@@ -19,7 +19,7 @@ def open_sat_data(zarr_path: str | list[str]) -> xr.DataArray:
      """
      # Open the data
      if isinstance(zarr_path, list | tuple):
-         ds = open_zarrs(zarr_path, concat_dim="time")
+         ds = open_zarrs(zarr_path, concat_dim="time", data_source="satellite")
      else:
          ds = open_zarr(zarr_path)

@@ -0,0 +1,73 @@
+ """Funcitons for loading site data."""
+
+ import numpy as np
+ import pandas as pd
+ import xarray as xr
+
+
+ def open_site(generation_file_path: str, metadata_file_path: str) -> xr.DataArray:
+     """Open a site's generation data and metadata.
+
+     Args:
+         generation_file_path: Path to the site generation netcdf data
+         metadata_file_path: Path to the site csv metadata
+
+     Returns:
+         xr.DataArray: The opened site generation data
+     """
+     generation_ds = xr.open_dataset(generation_file_path)
+     metadata_df = pd.read_csv(metadata_file_path, index_col="site_id")
+
+     if not metadata_df.index.is_unique:
+         raise ValueError("site_id is not unique in metadata")
+
+     # Ensure metadata aligns with the site_id dimension in generation_ds
+     metadata_df = metadata_df.reindex(generation_ds.site_id.values)
+
+     # Assign coordinates to the Dataset using the aligned metadata
+     # Check if variable capacity was passed with the generation data
+     # If not assign static capacity from metadata
+     if hasattr(generation_ds,"capacity_kwp"):
+         generation_ds = generation_ds.assign_coords(
+             latitude=(metadata_df.latitude.to_xarray()),
+             longitude=(metadata_df.longitude.to_xarray()),
+             capacity_kwp=generation_ds.capacity_kwp,
+         )
+     else:
+         generation_ds = generation_ds.assign_coords(
+             latitude=(metadata_df.latitude.to_xarray()),
+             longitude=(metadata_df.longitude.to_xarray()),
+             capacity_kwp=(metadata_df.capacity_kwp.to_xarray()),
+         )
+
+     # Sanity checks, to prevent inf or negative values
+     # Note NaNs are allowed in generation_kw as can have non overlapping time periods for sites
+     if np.isinf(generation_ds.generation_kw.values).all():
+         raise ValueError("generation_kw contains infinite (+/- inf) values")
+     if not (generation_ds.capacity_kwp.values > 0).all():
+         raise ValueError("capacity_kwp contains non-positive values")
+
+     site_da = generation_ds.generation_kw
+
+     # Validate data types directly in loading function
+     if not np.issubdtype(site_da.dtype, np.floating):
+         raise TypeError(f"Generation data should be float, not {site_da.dtype}")
+
+
+     coord_dtypes = {
+         "time_utc": (np.datetime64,),
+         "site_id": (np.integer,),
+         "capacity_kwp": (np.integer, np.floating),
+         "latitude": (np.floating,),
+         "longitude": (np.floating,),
+     }
+     for coord, expected_dtypes in coord_dtypes.items():
+         if not any(np.issubdtype(site_da.coords[coord].dtype, dt) for dt in expected_dtypes):
+             dtype = site_da.coords[coord].dtype
+             allowed = ", ".join(dt.__name__ for dt in expected_dtypes)
+             raise TypeError(f"{coord} should be one of ({allowed}), not {dtype}")
+
+     # Load the data eagerly into memory by calling compute
+     # this makes the dataset faster to sample from, but
+     # at the cost of a little extra memory usage
+     return site_da.compute()
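
Hypothetical usage of the new loader; both file paths below are placeholders rather than files shipped with the package:

    from ocf_data_sampler.load.site import open_site

    site_da = open_site(
        generation_file_path="site_generation.nc",  # netCDF containing generation_kw
        metadata_file_path="site_metadata.csv",     # CSV with a unique site_id column
    )
    print(site_da.dims)          # e.g. ("time_utc", "site_id")
    print(site_da.capacity_kwp)  # capacity coordinate merged in from the metadata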
@@ -28,7 +28,7 @@ def convert_nwp_to_numpy_sample(da: xr.DataArray, t0_idx: int | None = None) ->
          NWPSampleKey.channel_names: da.channel.values,
          NWPSampleKey.init_time_utc: da.init_time_utc.values.astype(float),
          NWPSampleKey.step: (da.step.values / 3600).astype(int),
-         NWPSampleKey.target_time_utc: da.target_time_utc.values.astype(float),
+         NWPSampleKey.target_time_utc: (da.init_time_utc.values + da.step.values).astype(float),
      }

      if t0_idx is not None:
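
The replaced line derives target time as init time plus step, so the sample no longer depends on a stored `target_time_utc` coordinate. The arithmetic in plain numpy, with toy values:

    import numpy as np

    init_time = np.array(["2024-01-01T00:00"], dtype="datetime64[ns]")
    step = np.array([np.timedelta64(3, "h")])  # forecast step of 3 hours
    target_time = init_time + step
    print(target_time)  # ["2024-01-01T03:00:00.000000000"]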
@@ -0,0 +1,25 @@
+ """Takes the diff along the step axis for a given set of channels."""
+
+ import numpy as np
+ import xarray as xr
+
+
+ def diff_channels(da: xr.DataArray, accum_channels: list[str]) -> xr.DataArray:
+     """Perform in-place diff of the given channels of the DataArray in the steps dimension.
+
+     Args:
+         da: The DataArray to slice from
+         accum_channels: Channels which are accumulated and need to be differenced
+     """
+     if da.dims[:2] != ("step", "channel"):
+         raise ValueError("This function assumes the first two dimensions are step then channel")
+
+     all_channels = da.channel.values
+     accum_channel_inds = [i for i, c in enumerate(all_channels) if c in accum_channels]
+
+     # Make a copy of the values to avoid changing the underlying numpy array
+     vals = da.values.copy()
+     vals[:-1, accum_channel_inds] = np.diff(vals[:, accum_channel_inds], axis=0)
+     da.values = vals
+
+     return da.isel(step=slice(0, -1))
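
Hypothetical usage on a toy array; "dswrf" stands in here for an accumulated channel. The diff converts cumulative values into per-step increments and drops the final step:

    import numpy as np
    import xarray as xr
    from ocf_data_sampler.select.diff_channels import diff_channels

    da = xr.DataArray(
        np.array([[0.0, 10.0], [1.0, 30.0], [2.0, 60.0]]),  # 2nd channel is cumulative
        dims=("step", "channel"),
        coords={"channel": ["t2m", "dswrf"]},
    )
    out = diff_channels(da, accum_channels=["dswrf"])
    print(out.sel(channel="dswrf").values)  # [20. 30.] - per-step increments
    print(out.sizes["step"])                # 2: the last step is dropped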
@@ -0,0 +1,59 @@
+ """Functions for simulating dropout in time series data.
+
+ This is used for the following types of data: GSP, Satellite and Site
+ This is not used for NWP
+ """
+
+ import numpy as np
+ import pandas as pd
+ import xarray as xr
+
+
+ def apply_history_dropout(
+     t0: pd.Timestamp,
+     dropout_timedeltas: list[pd.Timedelta],
+     dropout_frac: float | list[float],
+     da: xr.DataArray,
+ ) -> xr.DataArray:
+     """Apply randomly sampled dropout to the historical part of some sequence data.
+
+     Dropped out data is replaced with NaNs
+
+     Args:
+         t0: The forecast init-time.
+         dropout_timedeltas: List of timedeltas relative to t0 to pick from
+         dropout_frac: The probabilit(ies) that each dropout timedelta will be applied. This should
+             be between 0 and 1 inclusive.
+         da: Xarray DataArray with 'time_utc' coordinate
+     """
+     if len(dropout_timedeltas)==0:
+         return da
+
+     if isinstance(dropout_frac, float | int):
+
+         if not (0<=dropout_frac<=1):
+             raise ValueError("`dropout_frac` must be in range [0, 1]")
+
+         # Create list with equal chance for all dropout timedeltas
+         n = len(dropout_timedeltas)
+         dropout_frac = [dropout_frac/n for _ in range(n)]
+     else:
+         if not 0<=sum(dropout_frac)<=1:
+             raise ValueError("The sum of `dropout_frac` must be in range [0, 1]")
+         if len(dropout_timedeltas)!=len(dropout_frac):
+             raise ValueError("`dropout_timedeltas` and `dropout_frac` must have the same length")
+
+         dropout_frac = [*dropout_frac]  # Make copy of the list so we can append to it
+
+     dropout_timedeltas = [*dropout_timedeltas]  # Make copy of the list so we can append to it
+
+     # Add chance of no dropout
+     dropout_frac.append(1-sum(dropout_frac))
+     dropout_timedeltas.append(None)
+
+     timedelta_choice = np.random.choice(dropout_timedeltas, p=dropout_frac)
+
+     if timedelta_choice is None:
+         return da
+     else:
+         return da.where((da.time_utc <= timedelta_choice + t0) | (da.time_utc > t0))
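
Hypothetical usage on a toy series: with total probability 0.5, all history after a randomly chosen cut-off (30 or 60 minutes before t0) becomes NaN, while timestamps after t0 are always kept:

    import numpy as np
    import pandas as pd
    import xarray as xr
    from ocf_data_sampler.select.dropout import apply_history_dropout

    times = pd.date_range("2024-01-01 10:00", periods=7, freq="30min")
    da = xr.DataArray(np.arange(7.0), dims="time_utc", coords={"time_utc": times})

    da_dropped = apply_history_dropout(
        t0=pd.Timestamp("2024-01-01 12:00"),
        dropout_timedeltas=[pd.Timedelta("-30min"), pd.Timedelta("-60min")],
        dropout_frac=0.5,  # split equally: 0.25 per timedelta, 0.5 chance of no dropout
        da=da,
    )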
@@ -37,9 +37,9 @@ class Location:
              return self._projections[coord_system]
          else:
              raise ValueError(
-                 "Requested the coodinate in {coord_system}. This has not yet been added. "
+                 f"Requested the coodinate in {coord_system}. This has not yet been added. "
                  "The current available coordinate systems are "
-                 f"{list(self.self._projections.keys())}",
+                 f"{list(self._projections.keys())}",
              )

      def add_coord_system(self, x: float, y: float, coord_system: int) -> None:
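
The one-character fix above is easy to miss: without the `f` prefix the braces are printed literally instead of interpolated (and the second bug, `self.self._projections`, would have raised an AttributeError before the message was ever built):

    coord_system = "osgb"
    print("Requested the coodinate in {coord_system}.")   # braces kept verbatim
    print(f"Requested the coodinate in {coord_system}.")  # Requested the coodinate in osgb.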
@@ -0,0 +1,110 @@
+ """Select spatial slices."""
+
+ import numpy as np
+ import xarray as xr
+
+ from ocf_data_sampler.select.geospatial import find_coord_system
+ from ocf_data_sampler.select.location import Location
+
+
+ def _get_pixel_index_location(da: xr.DataArray, location: Location) -> tuple[int, int]:
+     """Find pixel index location closest to given Location.
+
+     Args:
+         da: The xarray DataArray.
+         location: The Location object representing the point of interest.
+
+     Returns:
+         The pixel indices.
+
+     Raises:
+         ValueError: If the location is outside the bounds of the DataArray.
+     """
+     target_coords, x_dim, y_dim = find_coord_system(da)
+
+     x, y = location.in_coord_system(target_coords)
+
+     x_vals = da[x_dim].values
+     y_vals = da[y_dim].values
+
+     # Check that requested point lies within the data
+     if not (x_vals[0] < x < x_vals[-1]):
+         raise ValueError(
+             f"{x} is not in the interval {x_vals[0]}: {x_vals[-1]}",
+         )
+     if not (y_vals[0] < y < y_vals[-1]):
+         raise ValueError(
+             f"{y} is not in the interval {y_vals[0]}: {y_vals[-1]}",
+         )
+
+     closest_x = np.argmin(np.abs(x_vals - x))
+     closest_y = np.argmin(np.abs(y_vals - y))
+
+     return closest_x, closest_y
+
+
+ def select_spatial_slice_pixels(
+     da: xr.DataArray,
+     location: Location,
+     width_pixels: int,
+     height_pixels: int,
+ ) -> xr.DataArray:
+     """Select spatial slice based off pixels from location point of interest.
+
+     Args:
+         da: xarray DataArray to slice from
+         location: Location of interest that will be the center of the returned slice
+         height_pixels: Height of the slice in pixels
+         width_pixels: Width of the slice in pixels
+
+     Returns:
+         The selected DataArray slice.
+
+     Raises:
+         ValueError: If the dimensions are not even or the slice is not allowed
+             when padding is required.
+     """
+     if (width_pixels % 2) != 0:
+         raise ValueError("Width must be an even number")
+     if (height_pixels % 2) != 0:
+         raise ValueError("Height must be an even number")
+
+     _, x_dim, y_dim = find_coord_system(da)
+     center_idx_x, center_idx_y = _get_pixel_index_location(da, location)
+
+     half_width = width_pixels // 2
+     half_height = height_pixels // 2
+
+     left_idx = int(center_idx_x - half_width)
+     right_idx = int(center_idx_x + half_width)
+     bottom_idx = int(center_idx_y - half_height)
+     top_idx = int(center_idx_y + half_height)
+
+     data_width_pixels = len(da[x_dim])
+     data_height_pixels = len(da[y_dim])
+
+     # Padding checks
+     slice_unavailable = (
+         left_idx < 0
+         or right_idx > data_width_pixels
+         or bottom_idx < 0
+         or top_idx > data_height_pixels
+     )
+
+     if slice_unavailable:
+         issues = []
+         if left_idx < 0:
+             issues.append(f"left_idx ({left_idx}) < 0")
+         if right_idx > data_width_pixels:
+             issues.append(f"right_idx ({right_idx}) > data_width_pixels ({data_width_pixels})")
+         if bottom_idx < 0:
+             issues.append(f"bottom_idx ({bottom_idx}) < 0")
+         if top_idx > data_height_pixels:
+             issues.append(f"top_idx ({top_idx}) > data_height_pixels ({data_height_pixels})")
+         issue_details = "\n - ".join(issues)
+         raise ValueError(f"Window for location {location} not available: \n - {issue_details}")
+
+     # Standard selection - without padding
+     da = da.isel({x_dim: slice(left_idx, right_idx), y_dim: slice(bottom_idx, top_idx)})
+
+     return da
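
A standalone illustration of the centring arithmetic used above: find the nearest pixel index to the requested coordinate, then take half the (even) window width on each side:

    import numpy as np

    x_vals = np.arange(0.0, 100.0)  # toy 1-D coordinate values
    x = 42.3                        # requested x position
    center_idx = int(np.argmin(np.abs(x_vals - x)))  # nearest pixel: 42
    half_width = 10 // 2            # width_pixels must be even
    left_idx, right_idx = center_idx - half_width, center_idx + half_width
    print(left_idx, right_idx)      # 37 47 -> slice(37, 47) selects 10 pixels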