ocf-data-sampler 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ocf-data-sampler might be problematic. Click here for more details.

@@ -90,11 +90,10 @@ class DropoutMixin(Base):
90
90
  "negative or zero.",
91
91
  )
92
92
 
93
- dropout_fraction: float = Field(
93
+ dropout_fraction: float|list[float] = Field(
94
94
  default=0,
95
- description="Chance of dropout being applied to each sample",
96
- ge=0,
97
- le=1,
95
+ description="Either a float(Chance of dropout being applied to each sample) or a list of "
96
+ "floats (probability that dropout of the corresponding timedelta is applied)",
98
97
  )
99
98
 
100
99
  @field_validator("dropout_timedeltas_minutes")
@@ -105,6 +104,36 @@ class DropoutMixin(Base):
105
104
  raise ValueError("Dropout timedeltas must be negative")
106
105
  return v
107
106
 
107
+
108
+ @field_validator("dropout_fraction")
109
+ def dropout_fractions(cls, dropout_frac: float|list[float]) -> float|list[float]:
110
+ """Validate 'dropout_frac'."""
111
+ from math import isclose
112
+ if isinstance(dropout_frac, float):
113
+ if not (dropout_frac <= 1):
114
+ raise ValueError("Input should be less than or equal to 1")
115
+ elif not (dropout_frac >= 0):
116
+ raise ValueError("Input should be greater than or equal to 0")
117
+
118
+ elif isinstance(dropout_frac, list):
119
+ if not dropout_frac:
120
+ raise ValueError("List cannot be empty")
121
+
122
+ if not all(isinstance(i, float) for i in dropout_frac):
123
+ raise ValueError("All elements in the list must be floats")
124
+
125
+ if not all(0 <= i <= 1 for i in dropout_frac):
126
+ raise ValueError("Each float in the list must be between 0 and 1")
127
+
128
+ if not isclose(sum(dropout_frac), 1.0, rel_tol=1e-9):
129
+ raise ValueError("Sum of all floats in the list must be 1.0")
130
+
131
+
132
+ else:
133
+ raise TypeError("Must be either a float or a list of floats")
134
+ return dropout_frac
135
+
136
+
108
137
  @model_validator(mode="after")
109
138
  def dropout_instructions_consistent(self) -> "DropoutMixin":
110
139
  """Validator for dropout instructions."""
@@ -12,7 +12,7 @@ import xarray as xr
12
12
  def apply_sampled_dropout_time(
13
13
  t0: pd.Timestamp,
14
14
  dropout_timedeltas: list[pd.Timedelta],
15
- dropout_frac: float,
15
+ dropout_frac: float|list[float],
16
16
  da: xr.DataArray,
17
17
  ) -> xr.DataArray:
18
18
  """Randomly pick a dropout time from a list of timedeltas and apply dropout time to the data.
@@ -20,28 +20,42 @@ def apply_sampled_dropout_time(
20
20
  Args:
21
21
  t0: The forecast init-time
22
22
  dropout_timedeltas: List of timedeltas relative to t0 to pick from
23
- dropout_frac: Probability that dropout will be applied.
24
- This should be between 0 and 1 inclusive
23
+ dropout_frac: Either a probability that dropout will be applied.
24
+ This should be between 0 and 1 inclusive.
25
+ Or a list of probabilities for each of the corresponding timedeltas
25
26
  da: Xarray DataArray with 'time_utc' coordinate
26
27
  """
27
- # sample dropout time
28
- if dropout_frac > 0 and len(dropout_timedeltas) == 0:
29
- raise ValueError("To apply dropout, dropout_timedeltas must be provided")
28
+ if isinstance(dropout_frac, list):
29
+ # checking if len match
30
+ if len(dropout_frac) != len(dropout_timedeltas):
31
+ raise ValueError("Lengths of dropout_frac and dropout_timedeltas should match")
30
32
 
31
- for t in dropout_timedeltas:
32
- if t > pd.Timedelta("0min"):
33
- raise ValueError("Dropout timedeltas must be negative")
34
33
 
35
- if not (0 <= dropout_frac <= 1):
36
- raise ValueError("dropout_frac must be between 0 and 1 inclusive")
37
34
 
38
- if (len(dropout_timedeltas) == 0) or (np.random.uniform() >= dropout_frac):
39
- dropout_time = None
35
+
36
+ dropout_time = t0 + np.random.choice(dropout_timedeltas,p=dropout_frac)
37
+
38
+ return da.where(da.time_utc <= dropout_time)
39
+
40
+
41
+
42
+ # old logic
40
43
  else:
41
- dropout_time = t0 + np.random.choice(dropout_timedeltas)
44
+ # sample dropout time
45
+ if dropout_frac > 0 and len(dropout_timedeltas) == 0:
46
+ raise ValueError("To apply dropout, dropout_timedeltas must be provided")
47
+
48
+
49
+ if not (0 <= dropout_frac <= 1):
50
+ raise ValueError("dropout_frac must be between 0 and 1 inclusive")
51
+
52
+ if (len(dropout_timedeltas) == 0) or (np.random.uniform() >= dropout_frac):
53
+ dropout_time = None
54
+ else:
55
+ dropout_time = t0 + np.random.choice(dropout_timedeltas)
42
56
 
43
- # apply dropout time
44
- if dropout_time is None:
45
- return da
46
- # This replaces the times after the dropout with NaNs
47
- return da.where(da.time_utc <= dropout_time)
57
+ # apply dropout time
58
+ if dropout_time is None:
59
+ return da
60
+ # This replaces the times after the dropout with NaNs
61
+ return da.where(da.time_utc <= dropout_time)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ocf-data-sampler
3
- Version: 0.3.0
3
+ Version: 0.3.1
4
4
  Author: James Fulton, Peter Dudfield
5
5
  Author-email: Open Climate Fix team <info@openclimatefix.org>
6
6
  License: MIT License
@@ -2,7 +2,7 @@ ocf_data_sampler/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,
2
2
  ocf_data_sampler/utils.py,sha256=2NEl70ySdTpr0pbLRk4LGklvXe1Nv1hun9XKcDw7-44,610
3
3
  ocf_data_sampler/config/__init__.py,sha256=O29mbH0XG2gIY1g3BaveGCnpBO2SFqdu-qzJ7a6evl0,223
4
4
  ocf_data_sampler/config/load.py,sha256=LL-7wemI8o4KPkx35j-wQ3HjsMvDgqXr7G46IcASfnU,632
5
- ocf_data_sampler/config/model.py,sha256=xX2PPywEFGYDsx_j9DX1GlwMRq3ovJR-mhmysMt_mO0,11116
5
+ ocf_data_sampler/config/model.py,sha256=Jss8UDJAaQIBDr9megX2pERoT0ocFmwLNFC8pCWN6VA,12386
6
6
  ocf_data_sampler/config/save.py,sha256=m8SPw5rXjkMm1rByjh3pK5StdBi4e8ysnn3jQopdRaI,1064
7
7
  ocf_data_sampler/data/uk_gsp_locations_20220314.csv,sha256=RSh7DRh55E3n8lVAaWXGTaXXHevZZtI58td4d4DhGos,10415772
8
8
  ocf_data_sampler/data/uk_gsp_locations_20250109.csv,sha256=XZISFatnbpO9j8LwaxNKFzQSjs6hcHFsV8a9uDDpy2E,9055334
@@ -32,7 +32,7 @@ ocf_data_sampler/numpy_sample/satellite.py,sha256=RaYzYIcB1AmDrKeiqSpn4QVfBH-QMe
32
32
  ocf_data_sampler/numpy_sample/site.py,sha256=zfYBjK3CJrIaKH1QdKXU7gwOxTqONt527y3nJ9TRnwc,1325
33
33
  ocf_data_sampler/numpy_sample/sun_position.py,sha256=5tt-zNm6aRuZMsxZPaAxyg7HeikswfZCeHWXTHuO2K0,1555
34
34
  ocf_data_sampler/select/__init__.py,sha256=mK7Wu_-j9IXGTYrOuDf5yDDuU5a306b0iGKTAooNg_s,210
35
- ocf_data_sampler/select/dropout.py,sha256=9gPyDF7bGmvSoMjMPu1j0gTZFHNFqsT3ToIo9mFNA00,1565
35
+ ocf_data_sampler/select/dropout.py,sha256=BYpv8L771faPOyN7SdIJ5cwkpDve-ohClj95jjsHmjg,1973
36
36
  ocf_data_sampler/select/fill_time_periods.py,sha256=TlGxp1xiAqnhdWfLy0pv3FuZc00dtimjWdLzr4JoTGA,865
37
37
  ocf_data_sampler/select/find_contiguous_time_periods.py,sha256=etkr6LuB7zxkfzWJ6SgHiULdRuFzFlq5bOUNd257Qx4,11545
38
38
  ocf_data_sampler/select/geospatial.py,sha256=CDExkl36eZOKmdJPzUr_K0Wn3axHqv5nYo-EkSiINcc,5032
@@ -56,7 +56,7 @@ ocf_data_sampler/torch_datasets/utils/validation_utils.py,sha256=YqmT-lExWlI8_ul
56
56
  scripts/download_gsp_location_data.py,sha256=rRDXMoqX-RYY4jPdxhdlxJGhWdl6r245F5UARgKV6P4,3121
57
57
  scripts/refactor_site.py,sha256=skzvsPP0Cn9yTKndzkilyNcGz4DZ88ctvCJ0XrBdc2A,3135
58
58
  utils/compute_icon_mean_stddev.py,sha256=a1oWMRMnny39rV-dvu8rcx85sb4bXzPFrR1gkUr4Jpg,2296
59
- ocf_data_sampler-0.3.0.dist-info/METADATA,sha256=Kq7LhwcpxOpfu4S4NOq-JHFJYI7eeeuxPleNPx6UMLE,12224
60
- ocf_data_sampler-0.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
61
- ocf_data_sampler-0.3.0.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
62
- ocf_data_sampler-0.3.0.dist-info/RECORD,,
59
+ ocf_data_sampler-0.3.1.dist-info/METADATA,sha256=pQpPqmpTlUiZnPY1Q_xZr1Z-GrKSATG_P77YYHpWm6Y,12224
60
+ ocf_data_sampler-0.3.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
61
+ ocf_data_sampler-0.3.1.dist-info/top_level.txt,sha256=LEFU4Uk-PEo72QGLAfnVZIUEm37Q8mKuMeg_Xk-p33g,31
62
+ ocf_data_sampler-0.3.1.dist-info/RECORD,,