spotforecast2-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spotforecast2/.DS_Store +0 -0
- spotforecast2/__init__.py +2 -0
- spotforecast2/data/__init__.py +0 -0
- spotforecast2/data/data.py +130 -0
- spotforecast2/data/fetch_data.py +209 -0
- spotforecast2/exceptions.py +681 -0
- spotforecast2/forecaster/.DS_Store +0 -0
- spotforecast2/forecaster/__init__.py +7 -0
- spotforecast2/forecaster/base.py +448 -0
- spotforecast2/forecaster/metrics.py +527 -0
- spotforecast2/forecaster/recursive/__init__.py +4 -0
- spotforecast2/forecaster/recursive/_forecaster_equivalent_date.py +1075 -0
- spotforecast2/forecaster/recursive/_forecaster_recursive.py +939 -0
- spotforecast2/forecaster/recursive/_warnings.py +15 -0
- spotforecast2/forecaster/utils.py +954 -0
- spotforecast2/model_selection/__init__.py +5 -0
- spotforecast2/model_selection/bayesian_search.py +453 -0
- spotforecast2/model_selection/grid_search.py +314 -0
- spotforecast2/model_selection/random_search.py +151 -0
- spotforecast2/model_selection/split_base.py +357 -0
- spotforecast2/model_selection/split_one_step.py +245 -0
- spotforecast2/model_selection/split_ts_cv.py +634 -0
- spotforecast2/model_selection/utils_common.py +718 -0
- spotforecast2/model_selection/utils_metrics.py +103 -0
- spotforecast2/model_selection/validation.py +685 -0
- spotforecast2/preprocessing/__init__.py +30 -0
- spotforecast2/preprocessing/_binner.py +378 -0
- spotforecast2/preprocessing/_common.py +123 -0
- spotforecast2/preprocessing/_differentiator.py +123 -0
- spotforecast2/preprocessing/_rolling.py +136 -0
- spotforecast2/preprocessing/curate_data.py +254 -0
- spotforecast2/preprocessing/imputation.py +92 -0
- spotforecast2/preprocessing/outlier.py +114 -0
- spotforecast2/preprocessing/split.py +139 -0
- spotforecast2/py.typed +0 -0
- spotforecast2/utils/__init__.py +43 -0
- spotforecast2/utils/convert_to_utc.py +44 -0
- spotforecast2/utils/data_transform.py +208 -0
- spotforecast2/utils/forecaster_config.py +344 -0
- spotforecast2/utils/generate_holiday.py +70 -0
- spotforecast2/utils/validation.py +569 -0
- spotforecast2/weather/__init__.py +0 -0
- spotforecast2/weather/weather_client.py +288 -0
- spotforecast2-0.0.1.dist-info/METADATA +47 -0
- spotforecast2-0.0.1.dist-info/RECORD +46 -0
- spotforecast2-0.0.1.dist-info/WHEEL +4 -0
spotforecast2/preprocessing/_rolling.py
@@ -0,0 +1,136 @@
+import numpy as np
+import pandas as pd
+from typing import List, Any
+from ._common import (
+    _np_mean_jit,
+    _np_std_jit,
+    _np_min_jit,
+    _np_max_jit,
+    _np_sum_jit,
+    _np_median_jit,
+)
+
+
+class RollingFeatures:
+    """
+    Compute rolling features (stats) over a window of the time series.
+    Compatible with scikit-learn transformers API (fit, transform).
+
+    Attributes:
+        stats_funcs (list): List of rolling statistics functions.
+        window_sizes (list): List of window sizes.
+        features_names (list): List of feature names.
+    """
+
+    def __init__(
+        self,
+        stats: str | List[str] | List[Any],
+        window_sizes: int | List[int],
+        features_names: List[str] | None = None,
+    ):
+        """
+        Initialize the rolling features transformer.
+
+        Args:
+            stats (str | List[str] | List[Any]): Rolling statistics to compute.
+            window_sizes (int | List[int]): Window sizes for rolling statistics.
+            features_names (List[str] | None, optional): Names of the features.
+                Defaults to None.
+        """
+        self.stats = stats
+        self.window_sizes = window_sizes
+        self.features_names = features_names
+
+        # Validation and processing logic...
+        self._validate_params()
+
+    def _validate_params(self):
+        """
+        Validate the parameters of the rolling features transformer.
+        """
+        if isinstance(self.window_sizes, int):
+            self.window_sizes = [self.window_sizes]
+
+        if isinstance(self.stats, str):
+            self.stats = [self.stats]
+
+        # Map strings to functions
+        valid_stats = {
+            "mean": _np_mean_jit,
+            "std": _np_std_jit,
+            "min": _np_min_jit,
+            "max": _np_max_jit,
+            "sum": _np_sum_jit,
+            "median": _np_median_jit,
+        }
+
+        self.stats_funcs = []
+        for s in self.stats:
+            if isinstance(s, str):
+                if s not in valid_stats:
+                    raise ValueError(
+                        f"Stat '{s}' not supported. Supported: {list(valid_stats.keys())}"
+                    )
+                self.stats_funcs.append(valid_stats[s])
+            else:
+                self.stats_funcs.append(s)
+
+        if self.features_names is None:
+            self.features_names = []
+            for ws in self.window_sizes:
+                for s in self.stats:
+                    s_name = s if isinstance(s, str) else s.__name__
+                    self.features_names.append(f"roll_{s_name}_{ws}")
+
+    def fit(self, X, y=None):
+        """
+        Fit the rolling features transformer.
+
+        Args:
+            X (np.ndarray): Time series to transform.
+            y (object, optional): Ignored.
+
+        Returns:
+            self: Fitted rolling features transformer.
+        """
+        return self
+
+    def transform(self, X: np.ndarray) -> np.ndarray:
+        """
+        Compute rolling features.
+
+        Args:
+            X (np.ndarray): Time series to transform.
+
+        Returns:
+            np.ndarray: Array with rolling features.
+        """
+        # Assume X is 1D array
+        n_samples = len(X)
+        output = np.full((n_samples, len(self.features_names)), np.nan)
+
+        idx_feature = 0
+        for ws in self.window_sizes:
+            for func in self.stats_funcs:
+                # Naive rolling window loop - can be optimized or use pandas rolling
+                # Using pandas for simplicity and speed if X is convertible
+                series = pd.Series(X)
+                rolled = series.rolling(window=ws).apply(func, raw=True)
+                output[:, idx_feature] = rolled.values
+                idx_feature += 1
+
+        return output
+
+    def transform_batch(self, X: pd.Series) -> pd.DataFrame:
+        """
+        Transform a pandas Series to rolling features DataFrame.
+
+        Args:
+            X (pd.Series): Time series to transform.
+
+        Returns:
+            pd.DataFrame: DataFrame with rolling features.
+        """
+        values = X.to_numpy()
+        transformed = self.transform(values)
+        return pd.DataFrame(transformed, index=X.index, columns=self.features_names)
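For orientation, a minimal usage sketch of the `RollingFeatures` API above. The import path mirrors the file layout in this diff; the string stats route through the `_np_*_jit` helpers from `_common`, which are not shown in the diff and are assumed to behave like their NumPy namesakes.

```python
# Illustrative sketch only: assumes spotforecast2 is installed and that the
# _common JIT helpers accept a 1D window array the way np.mean / np.std do.
import numpy as np
import pandas as pd

from spotforecast2.preprocessing._rolling import RollingFeatures

y = pd.Series(
    np.arange(48, dtype=float),
    index=pd.date_range("2023-01-01", periods=48, freq="h"),
)

roller = RollingFeatures(stats=["mean", "std"], window_sizes=[3, 24])
features = roller.fit(y.to_numpy()).transform_batch(y)
print(list(features.columns))
# ['roll_mean_3', 'roll_std_3', 'roll_mean_24', 'roll_std_24']
```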
spotforecast2/preprocessing/curate_data.py
@@ -0,0 +1,254 @@
+import pandas as pd
+
+
+def get_start_end(
+    data: pd.DataFrame,
+    forecast_horizon: int,
+    verbose: bool = True,
+) -> tuple[str, str, str, str]:
+    """Get start and end date strings for data and covariate ranges.
+    The covariate range is extended by the forecast horizon.
+
+    Args:
+        data (pd.DataFrame):
+            The dataset with a datetime index.
+        forecast_horizon (int):
+            The forecast horizon in hours.
+        verbose (bool):
+            Whether to print the determined date ranges.
+
+    Returns:
+        tuple[str, str, str, str]: (data_start, data_end, covariate_start, covariate_end)
+            Date strings in the format "YYYY-MM-DDTHH:MM" for data and covariate ranges.
+
+    Examples:
+        >>> from spotforecast2.preprocessing.curate_data import get_start_end
+        >>> import pandas as pd
+        >>> date_rng = pd.date_range(start='2023-01-01', end='2023-01-10', freq='h')
+        >>> data = pd.DataFrame(date_rng, columns=['date'])
+        >>> data.set_index('date', inplace=True)
+        >>> start, end, cov_start, cov_end = get_start_end(data, forecast_horizon=24, verbose=False)
+        >>> print(start, end, cov_start, cov_end)
+        2023-01-01T00:00 2023-01-10T00:00 2023-01-01T00:00 2023-01-11T00:00
+    """
+    FORECAST_HORIZON = forecast_horizon
+
+    START = data.index.min().strftime("%Y-%m-%dT%H:%M")
+    END = data.index.max().strftime("%Y-%m-%dT%H:%M")
+    if verbose:
+        print(f"Data range: {START} to {END}")
+    # Define covariate range relative to data range
+    COV_START = START
+    # Extend end date by forecast horizon to include future covariates
+    COV_END = (pd.to_datetime(END) + pd.Timedelta(hours=FORECAST_HORIZON)).strftime(
+        "%Y-%m-%dT%H:%M"
+    )
+    if verbose:
+        print(f"Covariate data range: {COV_START} to {COV_END}")
+    return START, END, COV_START, COV_END
+
+
+def curate_holidays(
+    holiday_df: pd.DataFrame, data: pd.DataFrame, forecast_horizon: int
+):
+    """Checks whether the holiday dataframe has the correct shape.
+
+    Args:
+        holiday_df (pd.DataFrame):
+            DataFrame containing holiday information.
+        data (pd.DataFrame):
+            The main dataset.
+        forecast_horizon (int):
+            The forecast horizon in hours.
+
+    Examples:
+        >>> from spotforecast2.data.fetch_data import fetch_data, fetch_holiday_data
+        >>> from spotforecast2.preprocessing.curate_data import get_start_end, curate_holidays
+        >>> data = fetch_data()
+        >>> START, END, COV_START, COV_END = get_start_end(
+        ...     data=data,
+        ...     forecast_horizon=24,
+        ...     verbose=False
+        ... )
+        >>> holiday_df = fetch_holiday_data(
+        ...     start='2023-01-01T00:00',
+        ...     end='2023-01-10T00:00',
+        ...     tz='UTC',
+        ...     freq='h',
+        ...     country_code='DE',
+        ...     state='NW'
+        ... )
+        >>> FORECAST_HORIZON = 24
+        >>> curate_holidays(holiday_df, data, forecast_horizon=FORECAST_HORIZON)
+
+    Note:
+        The shape check uses an assert whose AssertionError is caught
+        internally; a message is printed instead of the error propagating.
+    """
+    try:
+        assert holiday_df.shape[0] == data.shape[0] + forecast_horizon
+        print("Holiday dataframe has correct shape.")
+    except AssertionError:
+        print("Holiday dataframe has wrong shape.")
+
+
+def curate_weather(weather_df: pd.DataFrame, data: pd.DataFrame, forecast_horizon: int):
+    """Checks whether the weather dataframe has the correct shape.
+
+    Args:
+        weather_df (pd.DataFrame):
+            DataFrame containing weather information.
+        data (pd.DataFrame):
+            The main dataset.
+        forecast_horizon (int):
+            The forecast horizon in hours.
+
+    Examples:
+        >>> from spotforecast2.data.fetch_data import fetch_data, fetch_weather_data
+        >>> from spotforecast2.preprocessing.curate_data import get_start_end, curate_weather
+        >>> data = fetch_data()
+        >>> START, END, COV_START, COV_END = get_start_end(
+        ...     data=data,
+        ...     forecast_horizon=24,
+        ...     verbose=False
+        ... )
+        >>> weather_df = fetch_weather_data(
+        ...     cov_start=COV_START,
+        ...     cov_end=COV_END,
+        ...     tz='UTC',
+        ...     freq='h',
+        ...     latitude=51.5136,
+        ...     longitude=7.4653
+        ... )
+        >>> FORECAST_HORIZON = 24
+        >>> curate_weather(weather_df, data, forecast_horizon=FORECAST_HORIZON)
+
+    Note:
+        The shape check uses an assert whose AssertionError is caught
+        internally; a message is printed instead of the error propagating.
+    """
+    try:
+        assert weather_df.shape[0] == data.shape[0] + forecast_horizon
+        print("Weather dataframe has correct shape.")
+    except AssertionError:
+        print("Weather dataframe has wrong shape.")
+
+
+def basic_ts_checks(data: pd.DataFrame, verbose: bool = False) -> bool:
+    """Checks whether the time series data has a datetime index and is sorted.
+
+    Args:
+        data (pd.DataFrame):
+            The main dataset.
+        verbose (bool):
+            Whether to print additional information.
+
+    Examples:
+        >>> from spotforecast2.data.fetch_data import fetch_data
+        >>> from spotforecast2.preprocessing.curate_data import basic_ts_checks
+        >>> data = fetch_data()
+        >>> basic_ts_checks(data)
+
+    Raises:
+        TypeError:
+            If the index is not a datetime index.
+        ValueError:
+            If the datetime index is not sorted in increasing order or is incomplete.
+
+    Returns:
+        bool: True if the datetime index is valid, sorted, and complete.
+    """
+    # Check if the time series data has a datetime index
+    if not pd.api.types.is_datetime64_any_dtype(data.index):
+        raise TypeError("The index is not a datetime index.")
+
+    # Check if the datetime index is sorted
+    if not data.index.is_monotonic_increasing:
+        raise ValueError("The datetime index is not sorted in increasing order.")
+
+    # Check if the index is complete (no missing timestamps)
+    start_date = data.index.min()
+    end_date = data.index.max()
+    complete_date_range = pd.date_range(
+        start=start_date, end=end_date, freq=data.index.freq
+    )
+    is_index_complete = (data.index == complete_date_range).all()
+
+    if not is_index_complete:
+        raise ValueError(
+            "The datetime index has missing timestamps and is not complete."
+        )
+    if verbose:
+        print(
+            "The time series data has a valid datetime index that is sorted and complete."
+        )
+    return True
+
+
+def agg_and_resample_data(
+    data: pd.DataFrame,
+    rule: str = "h",
+    closed: str = "left",
+    label: str = "left",
+    by="mean",
+    verbose: bool = False,
+) -> pd.DataFrame:
+    """
+    Aggregates and resamples the data to the given frequency (e.g., hourly)
+    by computing the specified aggregation within each time bin.
+
+    Args:
+        data (pd.DataFrame):
+            The dataset with a datetime index.
+        rule (str):
+            The resample rule (e.g., 'h' for hourly, 'D' for daily).
+            Default is 'h', which creates an hourly grid.
+        closed (str):
+            Which side of the bin interval is closed. Default is 'left'.
+            Using `closed="left", label="left"` specifies that a time interval
+            (e.g., 10:00 to 11:00) is labeled with the start timestamp (10:00).
+            For consumption data, a different representation is usually more common:
+            `closed="left", label="right"`, so the interval is labeled with the end
+            timestamp (11:00), since consumption is typically reported after one hour.
+        label (str):
+            Which bin edge label to use. Default is 'left'.
+            See the 'closed' parameter for details on labeling behavior.
+        by (str or callable):
+            Aggregation method to apply (e.g., 'mean', 'sum', 'median').
+            Default is 'mean'.
+            The aggregation adds robustness: if the data were more finely resolved
+            (e.g., quarter-hourly), asfreq would only pick one value (sampling),
+            while .agg("mean") forms the correct average over the hour.
+            If the data is already hourly, .agg doesn't change anything but ensures
+            that no duplicates exist.
+        verbose (bool):
+            Whether to print additional information.
+
+    Returns:
+        pd.DataFrame: Resampled and aggregated dataframe.
+
+    Notes:
+        - resample(rule="h"): Creates an hourly grid
+        - closed/label: Control how time intervals are labeled
+        - .agg({...: by}): Aggregates values within each time bin
+
+    Examples:
+        >>> from spotforecast2.preprocessing.curate_data import agg_and_resample_data
+        >>> import pandas as pd
+        >>> date_rng = pd.date_range(start='2023-01-01', end='2023-01-02', freq='15min')
+        >>> data = pd.DataFrame(date_rng, columns=['date'])
+        >>> data.set_index('date', inplace=True)
+        >>> data['value'] = range(len(data))
+        >>> resampled_data = agg_and_resample_data(data, rule='h', by='mean')
+        >>> print(resampled_data.head())
+    """
+    if verbose:
+        print(f"Original data shape: {data.shape}")
+    # Create aggregation dictionary for all columns
+    agg_dict = {col: by for col in data.columns}
+
+    data = data.resample(rule=rule, closed=closed, label=label).agg(agg_dict)
+    if verbose:
+        print(
+            f"Data resampled with rule='{rule}', closed='{closed}', label='{label}', aggregation='{by}'."
+        )
+        print(f"Resampled data shape: {data.shape}")
+    return data
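To make the `closed`/`label` semantics described in `agg_and_resample_data` concrete, here is a small standalone sketch (plain pandas with toy values; not part of the package):

```python
import pandas as pd

# Quarter-hourly toy series covering the 10:00-11:00 interval
idx = pd.date_range("2023-01-01 10:00", periods=4, freq="15min")
s = pd.Series([1.0, 2.0, 3.0, 4.0], index=idx)

# label="left": the 10:00-11:00 bin is stamped with its start (10:00)
print(s.resample("h", closed="left", label="left").mean())
# 2023-01-01 10:00:00    2.5

# label="right": the same bin is stamped with its end (11:00),
# the convention usually used for consumption data
print(s.resample("h", closed="left", label="right").mean())
# 2023-01-01 11:00:00    2.5
```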
spotforecast2/preprocessing/imputation.py
@@ -0,0 +1,92 @@
+import pandas as pd
+
+
+def custom_weights(index, weights_series: pd.Series) -> float:
+    """
+    Return 0 if index is in or near any gap.
+
+    Args:
+        index (pd.Index):
+            The index to check.
+        weights_series (pd.Series):
+            Series containing weights.
+
+    Returns:
+        float: The weight corresponding to the index (an array of weights
+            if index is a pd.Index).
+
+    Examples:
+        >>> from spotforecast2.data.fetch_data import fetch_data
+        >>> from spotforecast2.preprocessing.imputation import custom_weights, get_missing_weights
+        >>> data = fetch_data()
+        >>> _, missing_weights = get_missing_weights(data, window_size=72, verbose=False)
+        >>> for idx in data.index[:5]:
+        ...     weight = custom_weights(idx, missing_weights)
+        ...     print(f"Index: {idx}, Weight: {weight}")
+    """
+    # do plausibility check
+    if isinstance(index, pd.Index):
+        if not index.isin(weights_series.index).all():
+            raise ValueError("Index not found in weights_series.")
+        return weights_series.loc[index].values
+
+    if index not in weights_series.index:
+        raise ValueError("Index not found in weights_series.")
+    return weights_series.loc[index]
+
+
+def get_missing_weights(
+    data: pd.DataFrame, window_size: int = 72, verbose: bool = False
+) -> tuple[pd.DataFrame, pd.Series]:
+    """
+    Return the imputed DataFrame and a series of sample weights around gaps.
+
+    Args:
+        data (pd.DataFrame):
+            The input dataset.
+        window_size (int):
+            The size of the rolling window to consider for missing values.
+        verbose (bool):
+            Whether to print additional information.
+
+    Returns:
+        tuple[pd.DataFrame, pd.Series]:
+            A tuple containing the forward and backward filled DataFrame and a
+            weight series that is 0.0 for rows in or within window_size rows
+            after a gap and 1.0 elsewhere.
+
+    Examples:
+        >>> from spotforecast2.data.fetch_data import fetch_data
+        >>> from spotforecast2.preprocessing.imputation import get_missing_weights
+        >>> data = fetch_data()
+        >>> filled_data, missing_weights = get_missing_weights(data, window_size=72, verbose=True)
+
+    """
+    # first perform some checks if dataframe has enough data and if window_size is appropriate
+    if data.shape[0] == 0:
+        raise ValueError("Input data is empty.")
+    if window_size <= 0:
+        raise ValueError("window_size must be a positive integer.")
+    if window_size >= data.shape[0]:
+        raise ValueError("window_size must be smaller than the number of rows in data.")
+
+    missing_indices = data.index[data.isnull().any(axis=1)]
+    n_missing = len(missing_indices)
+    if verbose:
+        pct_missing = (n_missing / len(data)) * 100
+        print(f"Number of rows with missing values: {n_missing}")
+        print(f"Percentage of rows with missing values: {pct_missing:.2f}%")
+        print(f"missing_indices: {missing_indices}")
+    data = data.ffill()
+    data = data.bfill()
+
+    is_missing = pd.Series(0, index=data.index)
+    is_missing.loc[missing_indices] = 1
+    # Weight is 0 for a gap and the window_size rows after it, 1 elsewhere
+    weights_series = 1 - is_missing.rolling(window=window_size + 1, min_periods=1).max()
+    if verbose:
+        n_zero_weight = (weights_series == 0).sum()
+        pct_zero_weight = (n_zero_weight / len(data)) * 100
+        print(f"Number of rows with zero weight after processing: {n_zero_weight}")
+        print(f"Percentage of rows with zero weight after processing: {pct_zero_weight:.2f}%")
+    return data, weights_series
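The weighting trick in `get_missing_weights` can be seen in isolation with a short standalone sketch (plain pandas, toy values): a trailing rolling max zeroes the weight for a gap and for the `window_size` rows that follow it.

```python
import pandas as pd

is_missing = pd.Series([0, 0, 1, 0, 0, 0, 0])  # one gap at position 2
window_size = 2

# Trailing rolling max over window_size + 1 rows flags the gap itself plus
# the window_size rows after it; subtracting from 1 turns flags into weights.
weights = 1 - is_missing.rolling(window=window_size + 1, min_periods=1).max()
print(weights.tolist())
# [1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0]
```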
spotforecast2/preprocessing/outlier.py
@@ -0,0 +1,114 @@
+from sklearn.ensemble import IsolationForest
+import numpy as np
+import pandas as pd
+
+
+def mark_outliers(
+    data: pd.DataFrame,
+    contamination: float = 0.1,
+    random_state: int = 1234,
+    verbose: bool = False,
+) -> tuple[pd.DataFrame, np.ndarray]:
+    """Marks outliers as NaN in the dataset using Isolation Forest.
+
+    Args:
+        data (pd.DataFrame):
+            The input dataset.
+        contamination (float):
+            The (estimated) proportion of outliers in the dataset.
+        random_state (int):
+            Random seed for reproducibility. Default is 1234.
+        verbose (bool):
+            Whether to print additional information.
+
+    Returns:
+        tuple[pd.DataFrame, np.ndarray]: A tuple containing the modified dataset
+            with outliers marked as NaN and the outlier labels (-1 for outliers,
+            1 for inliers) of the last column processed.
+
+    Examples:
+        >>> from spotforecast2.data.fetch_data import fetch_data
+        >>> from spotforecast2.preprocessing.outlier import mark_outliers
+        >>> data = fetch_data()
+        >>> cleaned_data, outlier_labels = mark_outliers(data, contamination=0.1, random_state=42, verbose=True)
+    """
+    for col in data.columns:
+        iso = IsolationForest(contamination=contamination, random_state=random_state)
+        # Fit and predict (-1 for outliers, 1 for inliers)
+        outliers = iso.fit_predict(data[[col]])
+
+        # Mark outliers as NaN
+        data.loc[outliers == -1, col] = np.nan
+
+        pct_outliers = (outliers == -1).mean() * 100
+        if verbose:
+            print(
+                f"Column '{col}': Marked {pct_outliers:.4f}% of data points as outliers."
+            )
+    return data, outliers
+
+
+def manual_outlier_removal(
+    data: pd.DataFrame,
+    column: str,
+    lower_threshold: float | None = None,
+    upper_threshold: float | None = None,
+    verbose: bool = False,
+) -> tuple[pd.DataFrame, int]:
+    """Manual outlier removal function.
+
+    Args:
+        data (pd.DataFrame):
+            The input dataset.
+        column (str):
+            The column name in which to perform manual outlier removal.
+        lower_threshold (float | None):
+            The lower threshold below which values are considered outliers.
+            If None, no lower threshold is applied.
+        upper_threshold (float | None):
+            The upper threshold above which values are considered outliers.
+            If None, no upper threshold is applied.
+        verbose (bool):
+            Whether to print additional information.
+
+    Returns:
+        tuple[pd.DataFrame, int]: A tuple containing the modified dataset with
+            outliers marked as NaN and the number of outliers marked.
+
+    Examples:
+        >>> from spotforecast2.data.fetch_data import fetch_data
+        >>> from spotforecast2.preprocessing.outlier import manual_outlier_removal
+        >>> data = fetch_data()
+        >>> data, n_manual_outliers = manual_outlier_removal(
+        ...     data,
+        ...     column='ABC',
+        ...     lower_threshold=50,
+        ...     upper_threshold=700,
+        ...     verbose=True
+        ... )
+    """
+    if lower_threshold is None and upper_threshold is None:
+        if verbose:
+            print(f"No thresholds provided for {column}; no outliers marked.")
+        return data, 0
+
+    if lower_threshold is not None and upper_threshold is not None:
+        mask = (data[column] > upper_threshold) | (data[column] < lower_threshold)
+    elif lower_threshold is not None:
+        mask = data[column] < lower_threshold
+    else:
+        mask = data[column] > upper_threshold
+
+    n_manual_outliers = mask.sum()
+
+    data.loc[mask, column] = np.nan
+
+    if verbose:
+        if lower_threshold is not None and upper_threshold is not None:
+            print(
+                f"Manually marked {n_manual_outliers} values > {upper_threshold} or < {lower_threshold} as outliers in {column}."
+            )
+        elif lower_threshold is not None:
+            print(
+                f"Manually marked {n_manual_outliers} values < {lower_threshold} as outliers in {column}."
+            )
+        else:
+            print(
+                f"Manually marked {n_manual_outliers} values > {upper_threshold} as outliers in {column}."
+            )
+    return data, n_manual_outliers
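And a hedged usage sketch for `mark_outliers` on synthetic data (the column name, spike values, and contamination level are made up for illustration; the `fetch_data` call from the Examples is not required):

```python
import numpy as np
import pandas as pd

from spotforecast2.preprocessing.outlier import mark_outliers

rng = np.random.default_rng(42)
df = pd.DataFrame(
    {"load": rng.normal(100.0, 5.0, size=500)},  # hypothetical column
    index=pd.date_range("2023-01-01", periods=500, freq="h"),
)
df.iloc[::97, 0] = 1000.0  # inject a few obvious spikes

# Note: mark_outliers modifies df in place and also returns it
cleaned, labels = mark_outliers(df, contamination=0.02, verbose=True)
print(int(cleaned["load"].isna().sum()), "values marked as NaN")
```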