ocf-data-sampler 0.5.7__py3-none-any.whl → 0.5.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py CHANGED
@@ -30,7 +30,7 @@ from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
     fill_nans_in_arrays,
     merge_dicts,
 )
-from ocf_data_sampler.utils import compute, minutes
+from ocf_data_sampler.utils import minutes, tensorstore_compute
 
 xr.set_options(keep_attrs=True)
 
@@ -254,7 +254,7 @@ class PVNetUKRegionalDataset(AbstractPVNetUKDataset):
         """
         sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
-        sample_dict = compute(sample_dict)
+        sample_dict = tensorstore_compute(sample_dict)
 
         return self.process_and_combine_datasets(sample_dict, t0, location)
 
@@ -313,7 +313,7 @@ class PVNetUKConcurrentDataset(AbstractPVNetUKDataset):
         """
         # Slice by time then load to avoid loading the data multiple times from disk
         sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
-        sample_dict = compute(sample_dict)
+        sample_dict = tensorstore_compute(sample_dict)
 
         gsp_samples = []
 
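In both dataset classes the pattern is slice-lazily-then-load-once: the slice_* helpers only narrow the lazy arrays, and the single eager-load call afterwards is what actually reads bytes from disk. A minimal illustration of that split using plain xarray (the Zarr path and the slice values here are hypothetical, not taken from the package):

import xarray as xr

ds = xr.open_zarr("nwp.zarr")          # lazy: only metadata is read
subset = ds.isel(init_time_utc=0)      # still lazy: no data read yet
subset = subset.isel(step=slice(0, 6)) # still lazy
loaded = subset.compute()              # one eager read of just this slice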
ocf_data_sampler/torch_datasets/datasets/site.py CHANGED
@@ -34,7 +34,7 @@ from ocf_data_sampler.torch_datasets.utils.merge_and_fill_utils import (
     fill_nans_in_arrays,
     merge_dicts,
 )
-from ocf_data_sampler.utils import compute, minutes
+from ocf_data_sampler.utils import minutes, tensorstore_compute
 
 xr.set_options(keep_attrs=True)
 
@@ -272,7 +272,7 @@ class SitesDataset(Dataset):
         sample_dict = slice_datasets_by_space(self.datasets_dict, location, self.config)
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
 
-        sample_dict = compute(sample_dict)
+        sample_dict = tensorstore_compute(sample_dict)
 
         return process_and_combine_datasets(
             sample_dict,
@@ -408,7 +408,7 @@ class SitesDatasetConcurrent(Dataset):
         """
         # slice by time first as we want to keep all site id info
         sample_dict = slice_datasets_by_time(self.datasets_dict, t0, self.config)
-        sample_dict = compute(sample_dict)
+        sample_dict = tensorstore_compute(sample_dict)
 
         site_samples = []
 
ocf_data_sampler/utils.py CHANGED
@@ -1,6 +1,7 @@
 """Miscellaneous helper functions."""
 
 import pandas as pd
+from xarray_tensorstore import read
 
 
 def minutes(minutes: int | list[float]) -> pd.Timedelta | pd.TimedeltaIndex:
@@ -11,11 +12,26 @@ def minutes(minutes: int | list[float]) -> pd.Timedelta | pd.TimedeltaIndex:
     """
     return pd.to_timedelta(minutes, unit="m")
 
+
 def compute(xarray_dict: dict) -> dict:
     """Eagerly load a nested dictionary of xarray DataArrays."""
     for k, v in xarray_dict.items():
         if isinstance(v, dict):
             xarray_dict[k] = compute(v)
         else:
-            xarray_dict[k] = v.compute(scheduler="single-threaded")
+            xarray_dict[k] = v.compute()
     return xarray_dict
+
+
+def tensorstore_compute(xarray_dict: dict) -> dict:
+    """Eagerly read and load a nested dictionary of xarray-tensorstore DataArrays."""
+    # Kick off the tensorstore async reading
+    for k, v in xarray_dict.items():
+        if isinstance(v, dict):
+            xarray_dict[k] = tensorstore_compute(v)
+        else:
+            xarray_dict[k] = read(v)
+
+    # Running the compute function will wait until all arrays have been read
+    return compute(xarray_dict)
+
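For context on the new helper: xarray-tensorstore's read() starts the underlying reads asynchronously and returns immediately, so tensorstore_compute can launch the reads for every leaf array before the final compute() pass blocks on the results. A minimal usage sketch, assuming a Zarr store that the tensorstore backend can open (the path, variable name, and dictionary keys are made up for illustration):

import xarray_tensorstore
from ocf_data_sampler.utils import tensorstore_compute

# Open lazily via the tensorstore backend (hypothetical path)
ds = xarray_tensorstore.open_zarr("path/to/nwp.zarr")

# A nested dict shaped like the sample_dict the datasets build internally
sample_dict = {"nwp": {"ukv": ds["temperature"]}}

# Every leaf array starts reading concurrently; compute() then waits for all of them
sample_dict = tensorstore_compute(sample_dict)

This would also explain the change to compute() itself: the arrays handed to it are presumably tensorstore-backed rather than dask-backed after read(), so the dask-specific scheduler="single-threaded" argument is dropped.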
ocf_data_sampler-0.5.9.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ocf-data-sampler
-Version: 0.5.7
+Version: 0.5.9
 Author: James Fulton, Peter Dudfield
 Author-email: Open Climate Fix team <info@openclimatefix.org>
 License: MIT License
ocf_data_sampler-0.5.9.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
 ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
-ocf_data_sampler/utils.py,sha256=2NEl70ySdTpr0pbLRk4LGklvXe1Nv1hun9XKcDw7-44,610
+ocf_data_sampler/utils.py,sha256=0Wlx7SNOJE5ZNs_F3m-fkKsF58haF-IxDxcQD2hKN34,1088
 ocf_data_sampler/config/__init__.py,sha256=O29mbH0XG2gIY1g3BaveGCnpBO2SFqdu-qzJ7a6evl0,223
 ocf_data_sampler/config/load.py,sha256=LL-7wemI8o4KPkx35j-wQ3HjsMvDgqXr7G46IcASfnU,632
 ocf_data_sampler/config/model.py,sha256=Jss8UDJAaQIBDr9megX2pERoT0ocFmwLNFC8pCWN6VA,12386
@@ -40,8 +40,8 @@ ocf_data_sampler/select/location.py,sha256=AZvGR8y62opiW7zACGXjoOtBEWRfSLOZIA73O
 ocf_data_sampler/select/select_spatial_slice.py,sha256=Hd4jGRUfIZRoWCirOQZeoLpaUnStB6KyFSTPX69wZLw,8790
 ocf_data_sampler/select/select_time_slice.py,sha256=HeHbwZ0CP03x0-LaJtpbSdtpLufwVTR73p6wH6O_PS8,5513
 ocf_data_sampler/torch_datasets/datasets/__init__.py,sha256=o0SsEXXZ6k9iL__5_RN1Sf60lw_eqK91P3UFEHAD2k0,102
-ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=876oLukvb1nLtZQ8HBN3PWfN7urKH2xa45tVar7XrbM,12010
-ocf_data_sampler/torch_datasets/datasets/site.py,sha256=nn6N8daGxllYwCCiFKbCJANTl84NrDRl-nbNGcfXc3U,15429
+ocf_data_sampler/torch_datasets/datasets/pvnet_uk.py,sha256=qbyvTOZZNcGioeH-DDoJmSf_KLRidiuBQRnrvZXD6ts,12046
+ocf_data_sampler/torch_datasets/datasets/site.py,sha256=_FUV_KDe5k7acAmjE9Z2kYgxCFJZrLjziaZssIi1ipg,15465
 ocf_data_sampler/torch_datasets/sample/__init__.py,sha256=GL84vdZl_SjHDGVyh9Uekx2XhPYuZ0dnO3l6f6KXnHI,100
 ocf_data_sampler/torch_datasets/sample/base.py,sha256=cQ1oIyhdmlotejZK8B3Cw6MNvpdnBPD8G_o2h7Ye4Vc,2206
 ocf_data_sampler/torch_datasets/sample/site.py,sha256=40NwNTqjL1WVhPdwe02zDHHfDLG2u_bvCfRCtGAtFc0,1466
@@ -55,8 +55,7 @@ ocf_data_sampler/torch_datasets/utils/valid_time_periods.py,sha256=xcy75cVxl0Wrg
 ocf_data_sampler/torch_datasets/utils/validation_utils.py,sha256=YqmT-lExWlI8_ul3l0EP73Ik002fStr_bhsZh9mQqEU,4735
 scripts/download_gsp_location_data.py,sha256=rRDXMoqX-RYY4jPdxhdlxJGhWdl6r245F5UARgKV6P4,3121
 scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
-utils/compute_icon_mean_stddev.py,sha256=a1oWMRMnny39rV-dvu8rcx85sb4bXzPFrR1gkUr4Jpg,2296
-ocf_data_sampler-0.5.7.dist-info/METADATA,sha256=Nu2RLYiLYyU6nkLu8g__Q8EPFIgYMLu5cZLcLXAckXs,12816
-ocf_data_sampler-0.5.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-ocf_data_sampler-0.5.7.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
-ocf_data_sampler-0.5.7.dist-info/RECORD,,
+ocf_data_sampler-0.5.9.dist-info/METADATA,sha256=LUgQmrakbDwIEfeP_3IojePDYDdvm15iUtftl5o8Rps,12816
+ocf_data_sampler-0.5.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ocf_data_sampler-0.5.9.dist-info/top_level.txt,sha256=deUxqmsONNAGZDNbsntbXH7BRA1MqWaUeAJrCo6q_xA,25
+ocf_data_sampler-0.5.9.dist-info/RECORD,,
ocf_data_sampler-0.5.9.dist-info/top_level.txt CHANGED
@@ -1,3 +1,2 @@
 ocf_data_sampler
 scripts
-utils
utils/compute_icon_mean_stddev.py DELETED
@@ -1,72 +0,0 @@
-"""Script to compute normalisation constants from NWP data."""
-
-import argparse
-import glob
-import logging
-
-import numpy as np
-import xarray as xr
-
-from ocf_data_sampler.load.nwp.providers.icon import open_icon_eu
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# Add argument parser
-parser = argparse.ArgumentParser(description="Compute normalization constants from NWP data")
-parser.add_argument("--data-path", type=str, required=True,
-                    help='Path pattern to zarr files (e.g., "/path/to/data/*.zarr.zip")')
-parser.add_argument("--n-samples", type=int, default=2000,
-                    help="Number of random samples to use (default: 2000)")
-
-args = parser.parse_args()
-
-zarr_files = glob.glob(args.data_path)
-n_samples = args.n_samples
-
-ds = open_icon_eu(zarr_files)
-
-n_init_times = ds.sizes["init_time_utc"]
-n_lats = ds.sizes["latitude"]
-n_longs = ds.sizes["longitude"]
-n_steps = ds.sizes["step"]
-
-random_init_times = np.random.choice(n_init_times, size=n_samples, replace=True)
-random_lats = np.random.choice(n_lats, size=n_samples, replace=True)
-random_longs = np.random.choice(n_longs, size=n_samples, replace=True)
-random_steps = np.random.choice(n_steps, size=n_samples, replace=True)
-
-samples = []
-for i in range(n_samples):
-    sample = ds.isel(init_time_utc=random_init_times[i],
-                     latitude=random_lats[i],
-                     longitude=random_longs[i],
-                     step=random_steps[i])
-    samples.append(sample)
-
-samples_stack = xr.concat(samples, dim="samples")
-
-
-available_channels = samples_stack.channel.values.tolist()
-logger.info("Available channels: %s", available_channels)
-
-ICON_EU_MEAN = {}
-ICON_EU_STD = {}
-
-for var in available_channels:
-    if var not in available_channels:
-        logger.warning("Variable '%s' not found in the channel coordinate; skipping.", var)
-        continue
-    var_data = samples_stack.sel(channel=var)
-    var_mean = float(var_data.mean().compute())
-    var_std = float(var_data.std().compute())
-
-    ICON_EU_MEAN[var] = var_mean
-    ICON_EU_STD[var] = var_std
-
-    logger.info("Processed %s: mean=%.4f, std=%.4f", var, var_mean, var_std)
-
-logger.info("\nMean values:\n%s", ICON_EU_MEAN)
-logger.info("\nStandard deviations:\n%s", ICON_EU_STD)
-