spotforecast2 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. spotforecast2/.DS_Store +0 -0
  2. spotforecast2/__init__.py +2 -0
  3. spotforecast2/data/__init__.py +0 -0
  4. spotforecast2/data/data.py +130 -0
  5. spotforecast2/data/fetch_data.py +209 -0
  6. spotforecast2/exceptions.py +681 -0
  7. spotforecast2/forecaster/.DS_Store +0 -0
  8. spotforecast2/forecaster/__init__.py +7 -0
  9. spotforecast2/forecaster/base.py +448 -0
  10. spotforecast2/forecaster/metrics.py +527 -0
  11. spotforecast2/forecaster/recursive/__init__.py +4 -0
  12. spotforecast2/forecaster/recursive/_forecaster_equivalent_date.py +1075 -0
  13. spotforecast2/forecaster/recursive/_forecaster_recursive.py +939 -0
  14. spotforecast2/forecaster/recursive/_warnings.py +15 -0
  15. spotforecast2/forecaster/utils.py +954 -0
  16. spotforecast2/model_selection/__init__.py +5 -0
  17. spotforecast2/model_selection/bayesian_search.py +453 -0
  18. spotforecast2/model_selection/grid_search.py +314 -0
  19. spotforecast2/model_selection/random_search.py +151 -0
  20. spotforecast2/model_selection/split_base.py +357 -0
  21. spotforecast2/model_selection/split_one_step.py +245 -0
  22. spotforecast2/model_selection/split_ts_cv.py +634 -0
  23. spotforecast2/model_selection/utils_common.py +718 -0
  24. spotforecast2/model_selection/utils_metrics.py +103 -0
  25. spotforecast2/model_selection/validation.py +685 -0
  26. spotforecast2/preprocessing/__init__.py +30 -0
  27. spotforecast2/preprocessing/_binner.py +378 -0
  28. spotforecast2/preprocessing/_common.py +123 -0
  29. spotforecast2/preprocessing/_differentiator.py +123 -0
  30. spotforecast2/preprocessing/_rolling.py +136 -0
  31. spotforecast2/preprocessing/curate_data.py +254 -0
  32. spotforecast2/preprocessing/imputation.py +92 -0
  33. spotforecast2/preprocessing/outlier.py +114 -0
  34. spotforecast2/preprocessing/split.py +139 -0
  35. spotforecast2/py.typed +0 -0
  36. spotforecast2/utils/__init__.py +43 -0
  37. spotforecast2/utils/convert_to_utc.py +44 -0
  38. spotforecast2/utils/data_transform.py +208 -0
  39. spotforecast2/utils/forecaster_config.py +344 -0
  40. spotforecast2/utils/generate_holiday.py +70 -0
  41. spotforecast2/utils/validation.py +569 -0
  42. spotforecast2/weather/__init__.py +0 -0
  43. spotforecast2/weather/weather_client.py +288 -0
  44. spotforecast2-0.0.1.dist-info/METADATA +47 -0
  45. spotforecast2-0.0.1.dist-info/RECORD +46 -0
  46. spotforecast2-0.0.1.dist-info/WHEEL +4 -0
spotforecast2/preprocessing/_rolling.py
@@ -0,0 +1,136 @@
+ import numpy as np
+ import pandas as pd
+ from typing import List, Any
+ from ._common import (
+     _np_mean_jit,
+     _np_std_jit,
+     _np_min_jit,
+     _np_max_jit,
+     _np_sum_jit,
+     _np_median_jit,
+ )
+
+
+ class RollingFeatures:
+     """
+     Compute rolling features (stats) over a window of the time series.
+     Compatible with the scikit-learn transformer API (fit, transform).
+
+     Attributes:
+         stats_funcs (list): List of rolling statistics functions.
+         window_sizes (list): List of window sizes.
+         features_names (list): List of feature names.
+     """
+
+     def __init__(
+         self,
+         stats: str | List[str] | List[Any],
+         window_sizes: int | List[int],
+         features_names: List[str] | None = None,
+     ):
+         """
+         Initialize the rolling features transformer.
+
+         Args:
+             stats (str | List[str] | List[Any]): Rolling statistics to compute.
+             window_sizes (int | List[int]): Window sizes for rolling statistics.
+             features_names (List[str] | None, optional): Names of the features.
+                 Defaults to None.
+         """
+         self.stats = stats
+         self.window_sizes = window_sizes
+         self.features_names = features_names
+
+         # Validation and processing logic...
+         self._validate_params()
+
+     def _validate_params(self):
+         """
+         Validate the parameters of the rolling features transformer.
+         """
+         if isinstance(self.window_sizes, int):
+             self.window_sizes = [self.window_sizes]
+
+         if isinstance(self.stats, str):
+             self.stats = [self.stats]
+
+         # Map strings to functions
+         valid_stats = {
+             "mean": _np_mean_jit,
+             "std": _np_std_jit,
+             "min": _np_min_jit,
+             "max": _np_max_jit,
+             "sum": _np_sum_jit,
+             "median": _np_median_jit,
+         }
+
+         self.stats_funcs = []
+         for s in self.stats:
+             if isinstance(s, str):
+                 if s not in valid_stats:
+                     raise ValueError(
+                         f"Stat '{s}' not supported. Supported: {list(valid_stats.keys())}"
+                     )
+                 self.stats_funcs.append(valid_stats[s])
+             else:
+                 self.stats_funcs.append(s)
+
+         if self.features_names is None:
+             self.features_names = []
+             for ws in self.window_sizes:
+                 for s in self.stats:
+                     s_name = s if isinstance(s, str) else s.__name__
+                     self.features_names.append(f"roll_{s_name}_{ws}")
+
+     def fit(self, X, y=None):
+         """
+         Fit the rolling features transformer.
+
+         Args:
+             X (np.ndarray): Time series to transform.
+             y (object, optional): Ignored.
+
+         Returns:
+             self: Fitted rolling features transformer.
+         """
+         return self
+
+     def transform(self, X: np.ndarray) -> np.ndarray:
+         """
+         Compute rolling features.
+
+         Args:
+             X (np.ndarray): Time series to transform.
+
+         Returns:
+             np.ndarray: Array with rolling features.
+         """
+         # Assume X is a 1D array
+         n_samples = len(X)
+         output = np.full((n_samples, len(self.features_names)), np.nan)
+
+         idx_feature = 0
+         for ws in self.window_sizes:
+             for func in self.stats_funcs:
+                 # Naive rolling window loop - can be optimized; uses pandas
+                 # rolling for simplicity and speed if X is convertible
+                 series = pd.Series(X)
+                 rolled = series.rolling(window=ws).apply(func, raw=True)
+                 output[:, idx_feature] = rolled.values
+                 idx_feature += 1
+
+         return output
+
+     def transform_batch(self, X: pd.Series) -> pd.DataFrame:
+         """
+         Transform a pandas Series to a DataFrame of rolling features.
+
+         Args:
+             X (pd.Series): Time series to transform.
+
+         Returns:
+             pd.DataFrame: DataFrame with rolling features.
+         """
+         values = X.to_numpy()
+         transformed = self.transform(values)
+         return pd.DataFrame(transformed, index=X.index, columns=self.features_names)
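
A minimal usage sketch for RollingFeatures (not part of the diff). It imports from the private module path listed above and assumes the jitted helpers in _common reduce a raw window array to a scalar, as their names suggest:

import numpy as np
import pandas as pd
from spotforecast2.preprocessing._rolling import RollingFeatures

# Hourly toy series with 100 points.
y = pd.Series(
    np.arange(100, dtype=float),
    index=pd.date_range("2023-01-01", periods=100, freq="h"),
)

roller = RollingFeatures(stats=["mean", "std"], window_sizes=[24, 48])
features = roller.fit(y).transform_batch(y)

# Columns are ordered window-first: roll_mean_24, roll_std_24,
# roll_mean_48, roll_std_48; the first window-1 rows of each are NaN.
print(features.columns.tolist())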
spotforecast2/preprocessing/curate_data.py
@@ -0,0 +1,254 @@
+ import pandas as pd
+
+
+ def get_start_end(
+     data: pd.DataFrame,
+     forecast_horizon: int,
+     verbose: bool = True,
+ ) -> tuple[str, str, str, str]:
+     """Get start and end date strings for the data and covariate ranges.
+     The covariate range is extended by the forecast horizon.
+
+     Args:
+         data (pd.DataFrame):
+             The dataset with a datetime index.
+         forecast_horizon (int):
+             The forecast horizon in hours.
+         verbose (bool):
+             Whether to print the determined date ranges.
+
+     Returns:
+         tuple[str, str, str, str]: (data_start, data_end, covariate_start, covariate_end)
+             Date strings in the format "YYYY-MM-DDTHH:MM" for the data and covariate ranges.
+
+     Examples:
+         >>> from spotforecast2.preprocessing.curate_data import get_start_end
+         >>> import pandas as pd
+         >>> date_rng = pd.date_range(start='2023-01-01', end='2023-01-10', freq='h')
+         >>> data = pd.DataFrame(date_rng, columns=['date'])
+         >>> data.set_index('date', inplace=True)
+         >>> start, end, cov_start, cov_end = get_start_end(data, forecast_horizon=24, verbose=False)
+         >>> print(start, end, cov_start, cov_end)
+         2023-01-01T00:00 2023-01-10T00:00 2023-01-01T00:00 2023-01-11T00:00
+     """
+     FORECAST_HORIZON = forecast_horizon
+
+     START = data.index.min().strftime("%Y-%m-%dT%H:%M")
+     END = data.index.max().strftime("%Y-%m-%dT%H:%M")
+     if verbose:
+         print(f"Data range: {START} to {END}")
+     # Define covariate range relative to data range
+     COV_START = START
+     # Extend end date by forecast horizon to include future covariates
+     COV_END = (pd.to_datetime(END) + pd.Timedelta(hours=FORECAST_HORIZON)).strftime(
+         "%Y-%m-%dT%H:%M"
+     )
+     if verbose:
+         print(f"Covariate data range: {COV_START} to {COV_END}")
+     return START, END, COV_START, COV_END
+
+
+ def curate_holidays(
+     holiday_df: pd.DataFrame, data: pd.DataFrame, forecast_horizon: int
+ ):
+     """Checks whether the holiday dataframe has the correct shape,
+     i.e. data.shape[0] + forecast_horizon rows.
+
+     Args:
+         holiday_df (pd.DataFrame):
+             DataFrame containing holiday information.
+         data (pd.DataFrame):
+             The main dataset.
+         forecast_horizon (int):
+             The forecast horizon in hours.
+
+     Examples:
+         >>> from spotforecast2.data.fetch_data import fetch_data, fetch_holiday_data
+         >>> from spotforecast2.preprocessing.curate_data import get_start_end, curate_holidays
+         >>> data = fetch_data()
+         >>> START, END, COV_START, COV_END = get_start_end(
+         ...     data=data,
+         ...     forecast_horizon=24,
+         ...     verbose=False
+         ... )
+         >>> holiday_df = fetch_holiday_data(
+         ...     start='2023-01-01T00:00',
+         ...     end='2023-01-10T00:00',
+         ...     tz='UTC',
+         ...     freq='h',
+         ...     country_code='DE',
+         ...     state='NW'
+         ... )
+         >>> FORECAST_HORIZON = 24
+         >>> curate_holidays(holiday_df, data, forecast_horizon=FORECAST_HORIZON)
+
+     Note:
+         The shape check is wrapped in try/except, so a shape mismatch prints
+         a warning instead of raising an AssertionError.
+     """
+     try:
+         assert holiday_df.shape[0] == data.shape[0] + forecast_horizon
+         print("Holiday dataframe has correct shape.")
+     except AssertionError:
+         print("Holiday dataframe has wrong shape.")
+
+
+ def curate_weather(weather_df: pd.DataFrame, data: pd.DataFrame, forecast_horizon: int):
+     """Checks whether the weather dataframe has the correct shape,
+     i.e. data.shape[0] + forecast_horizon rows.
+
+     Args:
+         weather_df (pd.DataFrame):
+             DataFrame containing weather information.
+         data (pd.DataFrame):
+             The main dataset.
+         forecast_horizon (int):
+             The forecast horizon in hours.
+
+     Examples:
+         >>> from spotforecast2.data.fetch_data import fetch_data, fetch_weather_data
+         >>> from spotforecast2.preprocessing.curate_data import get_start_end, curate_weather
+         >>> data = fetch_data()
+         >>> START, END, COV_START, COV_END = get_start_end(
+         ...     data=data,
+         ...     forecast_horizon=24,
+         ...     verbose=False
+         ... )
+         >>> weather_df = fetch_weather_data(
+         ...     cov_start=COV_START,
+         ...     cov_end=COV_END,
+         ...     tz='UTC',
+         ...     freq='h',
+         ...     latitude=51.5136,
+         ...     longitude=7.4653
+         ... )
+         >>> FORECAST_HORIZON = 24
+         >>> curate_weather(weather_df, data, forecast_horizon=FORECAST_HORIZON)
+
+     Note:
+         The shape check is wrapped in try/except, so a shape mismatch prints
+         a warning instead of raising an AssertionError.
+     """
+     try:
+         assert weather_df.shape[0] == data.shape[0] + forecast_horizon
+         print("Weather dataframe has correct shape.")
+     except AssertionError:
+         print("Weather dataframe has wrong shape.")
+
+
+ def basic_ts_checks(data: pd.DataFrame, verbose: bool = False) -> bool:
+     """Checks that the time series data has a datetime index that is sorted
+     and complete.
+
+     Args:
+         data (pd.DataFrame):
+             The main dataset.
+         verbose (bool):
+             Whether to print additional information.
+
+     Examples:
+         >>> from spotforecast2.data.fetch_data import fetch_data
+         >>> from spotforecast2.preprocessing.curate_data import basic_ts_checks
+         >>> data = fetch_data()
+         >>> basic_ts_checks(data)
+         True
+
+     Raises:
+         TypeError:
+             If the index is not a datetime index.
+         ValueError:
+             If the datetime index is not sorted in increasing order or is incomplete.
+
+     Returns:
+         bool: True if the datetime index is valid, sorted, and complete.
+     """
+     # Check if the time series data has a datetime index
+     if not pd.api.types.is_datetime64_any_dtype(data.index):
+         raise TypeError("The index is not a datetime index.")
+
+     # Check if the datetime index is sorted
+     if not data.index.is_monotonic_increasing:
+         raise ValueError("The datetime index is not sorted in increasing order.")
+
+     # Check if the index is complete (no missing timestamps)
+     start_date = data.index.min()
+     end_date = data.index.max()
+     complete_date_range = pd.date_range(
+         start=start_date, end=end_date, freq=data.index.freq
+     )
+     is_index_complete = (data.index == complete_date_range).all()
+
+     if not is_index_complete:
+         raise ValueError(
+             "The datetime index has missing timestamps and is not complete."
+         )
+     if verbose:
+         print(
+             "The time series data has a valid datetime index that is sorted and complete."
+         )
+     return True
+
+
+ def agg_and_resample_data(
+     data: pd.DataFrame,
+     rule: str = "h",
+     closed: str = "left",
+     label: str = "left",
+     by="mean",
+     verbose: bool = False,
+ ) -> pd.DataFrame:
+     """
+     Resamples the data to the given frequency (e.g., hourly) and aggregates
+     the values within each bin (e.g., the mean for each hour).
+
+     Args:
+         data (pd.DataFrame):
+             The dataset with a datetime index.
+         rule (str):
+             The resample rule (e.g., 'h' for hourly, 'D' for daily).
+             Default is 'h', which creates an hourly grid.
+         closed (str):
+             Which side of the bin interval is closed. Default is 'left'.
+             Using closed="left", label="left" labels a time interval
+             (e.g., 10:00 to 11:00) with its start timestamp (10:00).
+             For consumption data, a different convention is more common:
+             closed="left", label="right", which labels the interval with its
+             end timestamp (11:00), since consumption is typically reported
+             after the hour has elapsed.
+         label (str):
+             Which bin edge label to use. Default is 'left'.
+             See the 'closed' parameter for details on labeling behavior.
+         by (str or callable):
+             Aggregation method to apply (e.g., 'mean', 'sum', 'median').
+             Default is 'mean'.
+             The aggregation adds robustness: if the data has a finer
+             resolution (e.g., quarter-hourly), asfreq would merely sample one
+             value per bin, whereas .agg("mean") computes the correct average
+             over the hour. If the data is already hourly, .agg changes
+             nothing but guarantees that no duplicate timestamps remain.
+         verbose (bool):
+             Whether to print additional information.
+
+     Returns:
+         pd.DataFrame: Resampled and aggregated dataframe.
+
+     Notes:
+         - resample(rule="h"): Creates an hourly grid
+         - closed/label: Control how time intervals are labeled
+         - .agg({...: by}): Aggregates values within each time bin
+
+     Examples:
+         >>> from spotforecast2.preprocessing.curate_data import agg_and_resample_data
+         >>> import pandas as pd
+         >>> date_rng = pd.date_range(start='2023-01-01', end='2023-01-02', freq='15T')
+         >>> data = pd.DataFrame(date_rng, columns=['date'])
+         >>> data.set_index('date', inplace=True)
+         >>> data['value'] = range(len(data))
+         >>> resampled_data = agg_and_resample_data(data, rule='h', by='mean')
+         >>> print(resampled_data.head())
+     """
+     if verbose:
+         print(f"Original data shape: {data.shape}")
+     # Create aggregation dictionary for all columns
+     agg_dict = {col: by for col in data.columns}
+
+     data = data.resample(rule=rule, closed=closed, label=label).agg(agg_dict)
+     if verbose:
+         print(
+             f"Data resampled with rule='{rule}', closed='{closed}', label='{label}', aggregation='{by}'."
+         )
+         print(f"Resampled data shape: {data.shape}")
+     return data
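
The closed/label convention documented in agg_and_resample_data is easiest to see on a toy series; a small sketch with plain pandas, independent of the package:

import pandas as pd

# Quarter-hourly values 0..7 covering two full hours.
s = pd.Series(
    range(8),
    index=pd.date_range("2023-01-01 10:00", periods=8, freq="15min"),
)

# label="left": the 10:00-11:00 bin is stamped 10:00.
print(s.resample("h", closed="left", label="left").mean())
# 10:00 -> mean(0, 1, 2, 3) = 1.5; 11:00 -> mean(4, 5, 6, 7) = 5.5

# label="right": the same bin is stamped 11:00, the usual convention
# for consumption reported at the end of each hour.
print(s.resample("h", closed="left", label="right").mean())
# 11:00 -> 1.5; 12:00 -> 5.5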
spotforecast2/preprocessing/imputation.py
@@ -0,0 +1,92 @@
+ import numpy as np
+ import pandas as pd
+
+
+ def custom_weights(index, weights_series: pd.Series) -> float | np.ndarray:
+     """
+     Return 0 if the index is in or near any gap.
+
+     Args:
+         index (pd.Index or scalar):
+             The index (or single index label) to look up.
+         weights_series (pd.Series):
+             Series containing the weights.
+
+     Returns:
+         float | np.ndarray: The weight(s) corresponding to the index.
+
+     Examples:
+         >>> from spotforecast2.data.fetch_data import fetch_data
+         >>> from spotforecast2.preprocessing.imputation import custom_weights, get_missing_weights
+         >>> data = fetch_data()
+         >>> _, missing_weights = get_missing_weights(data, window_size=72, verbose=False)
+         >>> for idx in data.index[:5]:
+         ...     weight = custom_weights(idx, missing_weights)
+         ...     print(f"Index: {idx}, Weight: {weight}")
+     """
+     # Plausibility check: the requested index must exist in weights_series
+     if isinstance(index, pd.Index):
+         if not index.isin(weights_series.index).all():
+             raise ValueError("Index not found in weights_series.")
+         return weights_series.loc[index].values
+
+     if index not in weights_series.index:
+         raise ValueError("Index not found in weights_series.")
+     return weights_series.loc[index]
+
+
+ def get_missing_weights(
+     data: pd.DataFrame, window_size: int = 72, verbose: bool = False
+ ) -> tuple[pd.DataFrame, pd.Series]:
+     """
+     Return the imputed DataFrame and a series indicating missing weights.
+
+     Args:
+         data (pd.DataFrame):
+             The input dataset.
+         window_size (int):
+             The size of the rolling window to consider for missing values.
+         verbose (bool):
+             Whether to print additional information.
+
+     Returns:
+         tuple[pd.DataFrame, pd.Series]:
+             A tuple containing the forward- and backward-filled DataFrame and
+             a boolean series where True indicates missing weights.
+
+     Examples:
+         >>> from spotforecast2.data.fetch_data import fetch_data
+         >>> from spotforecast2.preprocessing.imputation import get_missing_weights
+         >>> data = fetch_data()
+         >>> filled_data, missing_weights = get_missing_weights(data, window_size=72, verbose=True)
+     """
+     # First check that the dataframe has enough data and that window_size is appropriate
+     if data.shape[0] == 0:
+         raise ValueError("Input data is empty.")
+     if window_size <= 0:
+         raise ValueError("window_size must be a positive integer.")
+     if window_size >= data.shape[0]:
+         raise ValueError("window_size must be smaller than the number of rows in data.")
+
+     missing_indices = data.index[data.isnull().any(axis=1)]
+     n_missing = len(missing_indices)
+     if verbose:
+         pct_missing = (n_missing / len(data)) * 100
+         print(f"Number of rows with missing values: {n_missing}")
+         print(f"Percentage of rows with missing values: {pct_missing:.2f}%")
+         print(f"missing_indices: {missing_indices}")
+     data = data.ffill()
+     data = data.bfill()
+
+     # Weight is 0 at a gap and for the window_size rows after it, 1 elsewhere
+     is_missing = pd.Series(0, index=data.index)
+     is_missing.loc[missing_indices] = 1
+     weights_series = 1 - is_missing.rolling(window=window_size + 1, min_periods=1).max()
+     if verbose:
+         n_missing_after = weights_series.isna().sum()
+         pct_missing_after = (n_missing_after / len(data)) * 100
+         print(
+             f"Number of rows with missing weights after processing: {n_missing_after}"
+         )
+         print(
+             f"Percentage of rows with missing weights after processing: {pct_missing_after:.2f}%"
+         )
+     return data, weights_series.isna()
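
To make the weighting scheme concrete, here is a standalone sketch of the same rolling-max construction used in get_missing_weights, with window_size=2 for brevity and no package imports:

import pandas as pd

# Ten hourly rows with one missing value at position 3.
idx = pd.date_range("2023-01-01", periods=10, freq="h")
df = pd.DataFrame({"load": range(10)}, index=idx, dtype=float)
df.iloc[3, 0] = None

# A 1 marks a row that was missing before filling...
is_missing = pd.Series(0, index=df.index)
is_missing.loc[df.index[df["load"].isnull()]] = 1

# ...and the rolling max spreads that 1 over the next window_size rows,
# so the weight is 0 at the gap and for the 2 rows after it, 1 elsewhere.
weights = 1 - is_missing.rolling(window=3, min_periods=1).max()
print(weights.tolist())  # 1.0 everywhere except positions 3-5, which are 0.0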
spotforecast2/preprocessing/outlier.py
@@ -0,0 +1,114 @@
+ from sklearn.ensemble import IsolationForest
+ import numpy as np
+ import pandas as pd
+
+
+ def mark_outliers(
+     data: pd.DataFrame,
+     contamination: float = 0.1,
+     random_state: int = 1234,
+     verbose: bool = False,
+ ) -> tuple[pd.DataFrame, np.ndarray]:
+     """Marks outliers as NaN in the dataset using Isolation Forest.
+     A separate Isolation Forest is fitted per column, and the input
+     dataframe is modified in place.
+
+     Args:
+         data (pd.DataFrame):
+             The input dataset.
+         contamination (float):
+             The (estimated) proportion of outliers in the dataset.
+         random_state (int):
+             Random seed for reproducibility. Default is 1234.
+         verbose (bool):
+             Whether to print additional information.
+
+     Returns:
+         tuple[pd.DataFrame, np.ndarray]: A tuple containing the modified
+             dataset with outliers marked as NaN and the outlier labels
+             (-1 for outliers, 1 for inliers) of the last processed column.
+
+     Examples:
+         >>> from spotforecast2.data.fetch_data import fetch_data
+         >>> from spotforecast2.preprocessing.outlier import mark_outliers
+         >>> data = fetch_data()
+         >>> cleaned_data, outlier_labels = mark_outliers(data, contamination=0.1, random_state=42, verbose=True)
+     """
+     for col in data.columns:
+         iso = IsolationForest(contamination=contamination, random_state=random_state)
+         # Fit and predict (-1 for outliers, 1 for inliers)
+         outliers = iso.fit_predict(data[[col]])
+
+         # Mark outliers as NaN
+         data.loc[outliers == -1, col] = np.nan
+
+         if verbose:
+             pct_outliers = (outliers == -1).mean() * 100
+             print(
+                 f"Column '{col}': Marked {pct_outliers:.4f}% of data points as outliers."
+             )
+     return data, outliers
+
+
+ def manual_outlier_removal(
+     data: pd.DataFrame,
+     column: str,
+     lower_threshold: float | None = None,
+     upper_threshold: float | None = None,
+     verbose: bool = False,
+ ) -> tuple[pd.DataFrame, int]:
+     """Manually marks values outside the given thresholds as NaN in one column.
+     The input dataframe is modified in place.
+
+     Args:
+         data (pd.DataFrame):
+             The input dataset.
+         column (str):
+             The column name in which to perform manual outlier removal.
+         lower_threshold (float | None):
+             The lower threshold below which values are considered outliers.
+             If None, no lower threshold is applied.
+         upper_threshold (float | None):
+             The upper threshold above which values are considered outliers.
+             If None, no upper threshold is applied.
+         verbose (bool):
+             Whether to print additional information.
+
+     Returns:
+         tuple[pd.DataFrame, int]: A tuple containing the modified dataset with
+             outliers marked as NaN and the number of outliers marked.
+
+     Examples:
+         >>> from spotforecast2.data.fetch_data import fetch_data
+         >>> from spotforecast2.preprocessing.outlier import manual_outlier_removal
+         >>> data = fetch_data()
+         >>> data, n_manual_outliers = manual_outlier_removal(
+         ...     data,
+         ...     column='ABC',
+         ...     lower_threshold=50,
+         ...     upper_threshold=700,
+         ...     verbose=True
+         ... )
+     """
+     if lower_threshold is None and upper_threshold is None:
+         if verbose:
+             print(f"No thresholds provided for {column}; no outliers marked.")
+         return data, 0
+
+     if lower_threshold is not None and upper_threshold is not None:
+         mask = (data[column] > upper_threshold) | (data[column] < lower_threshold)
+     elif lower_threshold is not None:
+         mask = data[column] < lower_threshold
+     else:
+         mask = data[column] > upper_threshold
+
+     n_manual_outliers = mask.sum()
+
+     data.loc[mask, column] = np.nan
+
+     if verbose:
+         if lower_threshold is not None and upper_threshold is not None:
+             print(
+                 f"Manually marked {n_manual_outliers} values > {upper_threshold} or < {lower_threshold} as outliers in {column}."
+             )
+         elif lower_threshold is not None:
+             print(
+                 f"Manually marked {n_manual_outliers} values < {lower_threshold} as outliers in {column}."
+             )
+         else:
+             print(
+                 f"Manually marked {n_manual_outliers} values > {upper_threshold} as outliers in {column}."
+             )
+     return data, n_manual_outliers
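
A usage sketch for the two outlier passes on synthetic data (not part of the diff). Both functions mutate the frame they receive, hence the copies, and IsolationForest rejects NaN input, so the statistical pass should run before any step that writes NaN into a column:

import numpy as np
import pandas as pd
from spotforecast2.preprocessing.outlier import mark_outliers, manual_outlier_removal

rng = np.random.default_rng(0)
df = pd.DataFrame(
    {"load": rng.normal(100.0, 5.0, 500)},
    index=pd.date_range("2023-01-01", periods=500, freq="h"),
)
df.iloc[10, 0] = 10_000.0  # inject an obvious spike

# Statistical pass on a copy: roughly contamination * len(df) points become NaN.
marked, labels = mark_outliers(df.copy(), contamination=0.01, verbose=True)

# Threshold pass on another copy: only the spike lies outside [50, 700].
cleaned, n_manual = manual_outlier_removal(
    df.copy(), column="load", lower_threshold=50, upper_threshold=700, verbose=True
)
print(n_manual)  # 1 (the injected spike)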