additory-0.1.0a4-py3-none-any.whl → additory-0.1.1a1-py3-none-any.whl
- additory/__init__.py +58 -14
- additory/common/__init__.py +31 -147
- additory/common/column_selector.py +255 -0
- additory/common/distributions.py +286 -613
- additory/common/extractors.py +313 -0
- additory/common/knn_imputation.py +332 -0
- additory/common/result.py +380 -0
- additory/common/strategy_parser.py +243 -0
- additory/common/unit_conversions.py +338 -0
- additory/common/validation.py +283 -103
- additory/core/__init__.py +34 -22
- additory/core/backend.py +258 -0
- additory/core/config.py +177 -305
- additory/core/logging.py +230 -24
- additory/core/memory_manager.py +157 -495
- additory/expressions/__init__.py +2 -23
- additory/expressions/compiler.py +457 -0
- additory/expressions/engine.py +264 -487
- additory/expressions/integrity.py +179 -0
- additory/expressions/loader.py +263 -0
- additory/expressions/parser.py +363 -167
- additory/expressions/resolver.py +274 -0
- additory/functions/__init__.py +1 -0
- additory/functions/analyze/__init__.py +144 -0
- additory/functions/analyze/cardinality.py +58 -0
- additory/functions/analyze/correlations.py +66 -0
- additory/functions/analyze/distributions.py +53 -0
- additory/functions/analyze/duplicates.py +49 -0
- additory/functions/analyze/features.py +61 -0
- additory/functions/analyze/imputation.py +66 -0
- additory/functions/analyze/outliers.py +65 -0
- additory/functions/analyze/patterns.py +65 -0
- additory/functions/analyze/presets.py +72 -0
- additory/functions/analyze/quality.py +59 -0
- additory/functions/analyze/timeseries.py +53 -0
- additory/functions/analyze/types.py +45 -0
- additory/functions/expressions/__init__.py +161 -0
- additory/functions/snapshot/__init__.py +82 -0
- additory/functions/snapshot/filter.py +119 -0
- additory/functions/synthetic/__init__.py +113 -0
- additory/functions/synthetic/mode_detector.py +47 -0
- additory/functions/synthetic/strategies/__init__.py +1 -0
- additory/functions/synthetic/strategies/advanced.py +35 -0
- additory/functions/synthetic/strategies/augmentative.py +160 -0
- additory/functions/synthetic/strategies/generative.py +168 -0
- additory/functions/synthetic/strategies/presets.py +116 -0
- additory/functions/to/__init__.py +188 -0
- additory/functions/to/lookup.py +351 -0
- additory/functions/to/merge.py +189 -0
- additory/functions/to/sort.py +91 -0
- additory/functions/to/summarize.py +170 -0
- additory/functions/transform/__init__.py +140 -0
- additory/functions/transform/datetime.py +79 -0
- additory/functions/transform/extract.py +85 -0
- additory/functions/transform/harmonize.py +105 -0
- additory/functions/transform/knn.py +62 -0
- additory/functions/transform/onehotencoding.py +68 -0
- additory/functions/transform/transpose.py +42 -0
- additory-0.1.1a1.dist-info/METADATA +83 -0
- additory-0.1.1a1.dist-info/RECORD +62 -0
- additory/analysis/__init__.py +0 -48
- additory/analysis/cardinality.py +0 -126
- additory/analysis/correlations.py +0 -124
- additory/analysis/distributions.py +0 -376
- additory/analysis/quality.py +0 -158
- additory/analysis/scan.py +0 -400
- additory/common/backend.py +0 -371
- additory/common/column_utils.py +0 -191
- additory/common/exceptions.py +0 -62
- additory/common/lists.py +0 -229
- additory/common/patterns.py +0 -240
- additory/common/resolver.py +0 -567
- additory/common/sample_data.py +0 -182
- additory/core/ast_builder.py +0 -165
- additory/core/backends/__init__.py +0 -23
- additory/core/backends/arrow_bridge.py +0 -483
- additory/core/backends/cudf_bridge.py +0 -355
- additory/core/column_positioning.py +0 -358
- additory/core/compiler_polars.py +0 -166
- additory/core/enhanced_cache_manager.py +0 -1119
- additory/core/enhanced_matchers.py +0 -473
- additory/core/enhanced_version_manager.py +0 -325
- additory/core/executor.py +0 -59
- additory/core/integrity_manager.py +0 -477
- additory/core/loader.py +0 -190
- additory/core/namespace_manager.py +0 -657
- additory/core/parser.py +0 -176
- additory/core/polars_expression_engine.py +0 -601
- additory/core/registry.py +0 -177
- additory/core/sample_data_manager.py +0 -492
- additory/core/user_namespace.py +0 -751
- additory/core/validator.py +0 -27
- additory/dynamic_api.py +0 -352
- additory/expressions/proxy.py +0 -549
- additory/expressions/registry.py +0 -313
- additory/expressions/samples.py +0 -492
- additory/synthetic/__init__.py +0 -13
- additory/synthetic/column_name_resolver.py +0 -149
- additory/synthetic/deduce.py +0 -259
- additory/synthetic/distributions.py +0 -22
- additory/synthetic/forecast.py +0 -1132
- additory/synthetic/linked_list_parser.py +0 -415
- additory/synthetic/namespace_lookup.py +0 -129
- additory/synthetic/smote.py +0 -320
- additory/synthetic/strategies.py +0 -926
- additory/synthetic/synthesizer.py +0 -713
- additory/utilities/__init__.py +0 -53
- additory/utilities/encoding.py +0 -600
- additory/utilities/games.py +0 -300
- additory/utilities/keys.py +0 -8
- additory/utilities/lookup.py +0 -103
- additory/utilities/matchers.py +0 -216
- additory/utilities/resolvers.py +0 -286
- additory/utilities/settings.py +0 -167
- additory/utilities/units.py +0 -749
- additory/utilities/validators.py +0 -153
- additory-0.1.0a4.dist-info/METADATA +0 -311
- additory-0.1.0a4.dist-info/RECORD +0 -72
- additory-0.1.0a4.dist-info/licenses/LICENSE +0 -21
- {additory-0.1.0a4.dist-info → additory-0.1.1a1.dist-info}/WHEEL +0 -0
- {additory-0.1.0a4.dist-info → additory-0.1.1a1.dist-info}/top_level.txt +0 -0
additory/synthetic/forecast.py
DELETED
@@ -1,1132 +0,0 @@
"""
Forecast Strategies for Synthetic Data Generation

Provides time series forecasting capabilities:
- Linear trend forecasting
- Polynomial trend forecasting
- Exponential growth forecasting
- Moving average forecasting
"""

from typing import Optional, Tuple, List, Any
import warnings

import numpy as np

from additory.common.exceptions import ValidationError, AugmentError


class ForecastMethod:
    """Supported forecasting methods."""
    LINEAR = "linear"
    POLYNOMIAL = "polynomial"
    EXPONENTIAL = "exponential"
    MOVING_AVERAGE = "moving_average"
    SEASONAL = "seasonal"
    HOLT_WINTERS = "holt_winters"
    ENSEMBLE = "ensemble"
    AUTO = "auto"


# Minimum data requirements for each method
MIN_DATA_POINTS = {
    ForecastMethod.LINEAR: 3,
    ForecastMethod.POLYNOMIAL: 5,
    ForecastMethod.EXPONENTIAL: 10,
    ForecastMethod.MOVING_AVERAGE: 4,
    ForecastMethod.SEASONAL: 12,  # Need at least one full cycle
    ForecastMethod.HOLT_WINTERS: 12,  # Need at least one full cycle
    ForecastMethod.ENSEMBLE: 5,
}


def detect_time_column(df_polars, hint: Optional[str] = None) -> Optional[str]:
    """
    Detect time/date column in dataframe.

    Args:
        df_polars: Polars DataFrame
        hint: Optional column name hint

    Returns:
        Column name or None if not found

    Detection logic:
    1. If hint provided and exists, use it
    2. Look for datetime dtype columns
    3. Look for common time column names
    4. Return None (will use row index)
    """
    import polars as pl

    # If hint provided, validate it exists
    if hint:
        if hint in df_polars.columns:
            return hint
        else:
            raise ValidationError(f"Specified time column '{hint}' not found in dataframe")

    # Check for datetime columns
    for col in df_polars.columns:
        if df_polars[col].dtype in [pl.Date, pl.Datetime, pl.Duration]:
            return col

    # Check for common time column names
    time_names = ['date', 'time', 'timestamp', 'datetime', 'period', 'day', 'month', 'year']
    for col in df_polars.columns:
        if col.lower() in time_names:
            return col

    # No time column found
    return None


def validate_forecast_data(
    df_polars,
    column: str,
    method: str,
    min_points: Optional[int] = None
) -> None:
    """
    Validate data is suitable for forecasting.

    Args:
        df_polars: Polars DataFrame
        column: Column to forecast
        method: Forecasting method
        min_points: Minimum required data points (optional)

    Raises:
        ValidationError: If data is invalid
    """
    # Check column exists
    if column not in df_polars.columns:
        raise ValidationError(f"Column '{column}' not found in dataframe")

    # Check column is numeric
    col_data = df_polars[column]
    if not col_data.dtype.is_numeric():
        raise ValidationError(
            f"Column '{column}' must be numeric for forecasting. "
            f"Got dtype: {col_data.dtype}"
        )

    # Check minimum data points
    n_points = len(df_polars)
    required_min = min_points or MIN_DATA_POINTS.get(method, 3)

    if n_points < required_min:
        raise ValidationError(
            f"Insufficient data for {method} forecasting. "
            f"Need at least {required_min} points, got {n_points}"
        )

    # Check for null values
    null_count = col_data.null_count()
    if null_count > 0:
        raise ValidationError(
            f"Column '{column}' contains {null_count} null values. "
            "Remove nulls before forecasting."
        )

    # Check for variance (all values identical)
    values = col_data.to_numpy()
    if np.std(values) == 0:
        raise ValidationError(
            f"Column '{column}' has no variance (all values identical). "
            "Cannot forecast constant data."
        )


def fit_linear_trend(x: np.ndarray, y: np.ndarray) -> Tuple[float, float, float]:
    """
    Fit linear trend using least squares.

    Formula: y = mx + b

    Args:
        x: Independent variable (time)
        y: Dependent variable (values to forecast)

    Returns:
        Tuple of (slope, intercept, r_squared)
    """
    # Use numpy polyfit for linear regression
    coeffs = np.polyfit(x, y, 1)
    slope, intercept = coeffs[0], coeffs[1]

    # Calculate R²
    y_pred = slope * x + intercept
    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    r_squared = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0.0

    return slope, intercept, r_squared


def fit_polynomial_trend(
    x: np.ndarray,
    y: np.ndarray,
    degree: int = 2
) -> Tuple[np.ndarray, float]:
    """
    Fit polynomial trend using least squares.

    Formula: y = a*x^n + b*x^(n-1) + ... + c

    Args:
        x: Independent variable (time)
        y: Dependent variable (values to forecast)
        degree: Polynomial degree (default: 2)

    Returns:
        Tuple of (coefficients, r_squared)
    """
    # Validate degree
    if degree < 1:
        raise ValidationError(f"Polynomial degree must be >= 1, got {degree}")
    if degree > 10:
        warnings.warn(f"High polynomial degree ({degree}) may cause overfitting")

    # Fit polynomial
    coeffs = np.polyfit(x, y, degree)

    # Calculate R²
    y_pred = np.polyval(coeffs, x)
    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    r_squared = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0.0

    return coeffs, r_squared


def fit_exponential_trend(x: np.ndarray, y: np.ndarray) -> Tuple[float, float, float]:
    """
    Fit exponential trend.

    Formula: y = a * e^(b*x)

    Method: Transform to log space and fit linear
        log(y) = log(a) + b*x

    Args:
        x: Independent variable (time)
        y: Dependent variable (values to forecast)

    Returns:
        Tuple of (a, b, r_squared)
    """
    # Check for non-positive values
    if np.any(y <= 0):
        raise ValidationError(
            "Exponential forecasting requires all positive values. "
            "Found zero or negative values."
        )

    # Transform to log space
    log_y = np.log(y)

    # Fit linear in log space
    coeffs = np.polyfit(x, log_y, 1)
    b, log_a = coeffs[0], coeffs[1]
    a = np.exp(log_a)

    # Calculate R² in original space
    y_pred = a * np.exp(b * x)
    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    r_squared = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0.0

    return a, b, r_squared


def calculate_moving_average(y: np.ndarray, window: int) -> float:
    """
    Calculate moving average of last 'window' values.

    Args:
        y: Array of values
        window: Window size

    Returns:
        Moving average value
    """
    if window < 1:
        raise ValidationError(f"Window size must be >= 1, got {window}")

    if len(y) < window:
        raise ValidationError(
            f"Not enough data for moving average. "
            f"Need at least {window} points, got {len(y)}"
        )

    # Take last 'window' values and average
    return np.mean(y[-window:])


def forecast_linear(
    x: np.ndarray,
    y: np.ndarray,
    n_rows: int,
    warn_threshold: float = 0.5
) -> List[float]:
    """
    Forecast using linear trend.

    Args:
        x: Time values
        y: Data values
        n_rows: Number of values to forecast
        warn_threshold: R² threshold for warning

    Returns:
        List of forecasted values
    """
    slope, intercept, r_squared = fit_linear_trend(x, y)

    # Warn if poor fit
    if r_squared < warn_threshold:
        warnings.warn(
            f"Linear fit has low R² = {r_squared:.3f}. "
            "Consider using a different method or checking your data."
        )

    # Generate forecast points
    last_x = x[-1]
    forecast_x = np.arange(last_x + 1, last_x + n_rows + 1)
    forecast_y = slope * forecast_x + intercept

    return forecast_y.tolist()


def forecast_polynomial(
    x: np.ndarray,
    y: np.ndarray,
    n_rows: int,
    degree: int = 2,
    warn_threshold: float = 0.5
) -> List[float]:
    """
    Forecast using polynomial trend.

    Args:
        x: Time values
        y: Data values
        n_rows: Number of values to forecast
        degree: Polynomial degree
        warn_threshold: R² threshold for warning

    Returns:
        List of forecasted values
    """
    coeffs, r_squared = fit_polynomial_trend(x, y, degree)

    # Warn if poor fit
    if r_squared < warn_threshold:
        warnings.warn(
            f"Polynomial fit (degree={degree}) has low R² = {r_squared:.3f}. "
            "Consider adjusting degree or using a different method."
        )

    # Generate forecast points
    last_x = x[-1]
    forecast_x = np.arange(last_x + 1, last_x + n_rows + 1)
    forecast_y = np.polyval(coeffs, forecast_x)

    return forecast_y.tolist()


def forecast_exponential(
    x: np.ndarray,
    y: np.ndarray,
    n_rows: int,
    warn_threshold: float = 0.5
) -> List[float]:
    """
    Forecast using exponential trend.

    Args:
        x: Time values
        y: Data values
        n_rows: Number of values to forecast
        warn_threshold: R² threshold for warning

    Returns:
        List of forecasted values
    """
    a, b, r_squared = fit_exponential_trend(x, y)

    # Warn if poor fit
    if r_squared < warn_threshold:
        warnings.warn(
            f"Exponential fit has low R² = {r_squared:.3f}. "
            "Consider using a different method."
        )

    # Generate forecast points
    last_x = x[-1]
    forecast_x = np.arange(last_x + 1, last_x + n_rows + 1)
    forecast_y = a * np.exp(b * forecast_x)

    return forecast_y.tolist()


def forecast_moving_average(
    y: np.ndarray,
    n_rows: int,
    window: int = 3
) -> List[float]:
    """
    Forecast using moving average.

    Args:
        y: Data values
        n_rows: Number of values to forecast
        window: Window size for moving average

    Returns:
        List of forecasted values
    """
    # Calculate initial moving average
    ma_value = calculate_moving_average(y, window)

    # Repeat the moving average for all forecast points
    # Note: This is a simple approach. More sophisticated would be
    # to recalculate MA as we add forecasted points.
    return [ma_value] * n_rows


def auto_select_method(x: np.ndarray, y: np.ndarray) -> str:
    """
    Automatically select best forecasting method based on data.

    Args:
        x: Time values
        y: Data values

    Returns:
        Best method name
    """
    methods_to_try = []

    # Always try linear
    methods_to_try.append(ForecastMethod.LINEAR)

    # Try polynomial if enough data
    if len(y) >= MIN_DATA_POINTS[ForecastMethod.POLYNOMIAL]:
        methods_to_try.append(ForecastMethod.POLYNOMIAL)

    # Try exponential if all positive and enough data
    if np.all(y > 0) and len(y) >= MIN_DATA_POINTS[ForecastMethod.EXPONENTIAL]:
        methods_to_try.append(ForecastMethod.EXPONENTIAL)

    # Fit each method and compare R²
    best_method = ForecastMethod.LINEAR
    best_r_squared = -np.inf

    for method in methods_to_try:
        try:
            if method == ForecastMethod.LINEAR:
                _, _, r_squared = fit_linear_trend(x, y)
            elif method == ForecastMethod.POLYNOMIAL:
                _, r_squared = fit_polynomial_trend(x, y, degree=2)
            elif method == ForecastMethod.EXPONENTIAL:
                _, _, r_squared = fit_exponential_trend(x, y)

            if r_squared > best_r_squared:
                best_r_squared = r_squared
                best_method = method
        except:
            # Skip methods that fail
            continue

    return best_method


def forecast_values(
    df_polars,
    column: str,
    n_rows: int,
    method: str = ForecastMethod.LINEAR,
    time_col: Optional[str] = None,
    **params
) -> List[float]:
    """
    Main forecasting function.

    Args:
        df_polars: Input Polars DataFrame
        column: Column to forecast
        n_rows: Number of values to forecast
        method: Forecasting method (linear, polynomial, exponential, moving_average, auto)
        time_col: Time column name (auto-detect if None)
        **params: Method-specific parameters:
            - degree: For polynomial (default: 2)
            - window: For moving_average (default: 3)
            - warn_threshold: R² threshold for warnings (default: 0.5)

    Returns:
        List of forecasted values

    Raises:
        ValidationError: If data is invalid
        AugmentError: If forecasting fails
    """
    # Validate data
    validate_forecast_data(df_polars, column, method)

    # Get data values
    y = df_polars[column].to_numpy()

    # Detect or use time column
    time_col_name = detect_time_column(df_polars, time_col)

    if time_col_name:
        # Use actual time column
        x = df_polars[time_col_name].to_numpy()
        # Convert to numeric if needed (e.g., datetime to timestamp)
        if not np.issubdtype(x.dtype, np.number):
            # Try to convert to numeric
            try:
                x = np.arange(len(x))  # Fallback to index
            except:
                x = np.arange(len(x))
    else:
        # Use row index as time
        x = np.arange(len(y))
        warnings.warn(
            "No time column detected. Using row index as time axis. "
            "Specify time_col parameter for better results."
        )

    # Auto-select method if requested
    if method == ForecastMethod.AUTO or method == "auto":
        method = auto_select_method(x, y)
        print(f"Auto-selected forecasting method: {method}")

    # Forecast based on method
    try:
        if method == ForecastMethod.LINEAR:
            return forecast_linear(x, y, n_rows, params.get('warn_threshold', 0.5))

        elif method == ForecastMethod.POLYNOMIAL:
            degree = params.get('degree', 2)
            return forecast_polynomial(x, y, n_rows, degree, params.get('warn_threshold', 0.5))

        elif method == ForecastMethod.EXPONENTIAL:
            return forecast_exponential(x, y, n_rows, params.get('warn_threshold', 0.5))

        elif method == ForecastMethod.MOVING_AVERAGE:
            window = params.get('window', 3)
            return forecast_moving_average(y, n_rows, window)

        else:
            raise ValidationError(
                f"Unknown forecasting method: '{method}'. "
                f"Supported: linear, polynomial, exponential, moving_average, auto"
            )

    except Exception as e:
        if isinstance(e, (ValidationError, AugmentError)):
            raise
        raise AugmentError(f"Forecasting failed: {e}")




def detect_seasonality(y: np.ndarray, max_period: int = 12) -> Tuple[int, float]:
    """
    Detect seasonal period using autocorrelation.

    Args:
        y: Time series data
        max_period: Maximum period to check (default: 12 for monthly data)

    Returns:
        Tuple of (period, strength)
        - period: Detected seasonal period (0 if no seasonality)
        - strength: Strength of seasonality (0-1)
    """
    if len(y) < max_period * 2:
        return 0, 0.0

    # Calculate autocorrelation for different lags
    autocorr = []
    mean_y = np.mean(y)
    var_y = np.var(y)

    if var_y == 0:
        return 0, 0.0

    for lag in range(1, min(max_period + 1, len(y) // 2)):
        # Calculate autocorrelation at this lag
        numerator = np.sum((y[:-lag] - mean_y) * (y[lag:] - mean_y))
        denominator = len(y[:-lag]) * var_y

        if denominator > 0:
            acf = numerator / denominator
            autocorr.append((lag, acf))

    if not autocorr:
        return 0, 0.0

    # Find the lag with highest positive autocorrelation (excluding lag 1)
    autocorr_sorted = sorted(autocorr[1:], key=lambda x: x[1], reverse=True)

    if not autocorr_sorted or autocorr_sorted[0][1] < 0.3:
        # No significant seasonality
        return 0, 0.0

    period = autocorr_sorted[0][0]
    strength = autocorr_sorted[0][1]

    return period, strength


def decompose_seasonal(y: np.ndarray, period: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Simple seasonal decomposition: trend + seasonal + residual.

    Uses moving average for trend and averages for seasonal component.

    Args:
        y: Time series data
        period: Seasonal period

    Returns:
        Tuple of (trend, seasonal, residual)
    """
    n = len(y)

    # Calculate trend using centered moving average
    trend = np.zeros(n)
    half_period = period // 2

    for i in range(n):
        start = max(0, i - half_period)
        end = min(n, i + half_period + 1)
        trend[i] = np.mean(y[start:end])

    # Detrend
    detrended = y - trend

    # Calculate seasonal component (average for each position in cycle)
    seasonal = np.zeros(n)
    seasonal_averages = np.zeros(period)

    for i in range(period):
        # Get all values at this position in the cycle
        indices = list(range(i, n, period))
        if indices:
            seasonal_averages[i] = np.mean(detrended[indices])

    # Normalize seasonal component to sum to zero
    seasonal_averages -= np.mean(seasonal_averages)

    # Assign seasonal values
    for i in range(n):
        seasonal[i] = seasonal_averages[i % period]

    # Calculate residual
    residual = y - trend - seasonal

    return trend, seasonal, residual


def forecast_seasonal(
    y: np.ndarray,
    n_rows: int,
    period: Optional[int] = None,
    auto_detect: bool = True
) -> List[float]:
    """
    Forecast using seasonal decomposition.

    Decomposes series into trend + seasonal + residual, then forecasts each component.

    Args:
        y: Time series data
        n_rows: Number of values to forecast
        period: Seasonal period (e.g., 12 for monthly, 7 for daily)
            If None and auto_detect=True, will attempt to detect
        auto_detect: Whether to auto-detect period if not provided

    Returns:
        List of forecasted values

    Raises:
        ValidationError: If insufficient data or no seasonality detected
    """
    n = len(y)

    # Detect or validate period
    if period is None:
        if not auto_detect:
            raise ValidationError(
                "Seasonal forecasting requires a period parameter. "
                "Provide period or set auto_detect=True"
            )

        # Auto-detect seasonality
        detected_period, strength = detect_seasonality(y)

        if detected_period == 0:
            raise ValidationError(
                "No significant seasonality detected in data. "
                "Try a different forecasting method or specify period manually."
            )

        period = detected_period

        if strength < 0.5:
            warnings.warn(
                f"Weak seasonality detected (strength={strength:.2f}). "
                "Results may not be reliable."
            )

    # Validate period
    if period < 2:
        raise ValidationError(f"Period must be >= 2, got {period}")

    if n < period * 2:
        raise ValidationError(
            f"Need at least {period * 2} data points for seasonal forecasting "
            f"with period={period}. Got {n} points."
        )

    # Decompose series
    trend, seasonal, residual = decompose_seasonal(y, period)

    # Forecast trend (using linear extrapolation)
    x = np.arange(n)
    trend_slope, trend_intercept, _ = fit_linear_trend(x, trend)

    forecast_x = np.arange(n, n + n_rows)
    forecast_trend = trend_slope * forecast_x + trend_intercept

    # Forecast seasonal component (repeat the pattern)
    seasonal_pattern = seasonal[-period:]  # Last full cycle
    forecast_seasonal_comp = np.tile(seasonal_pattern, (n_rows // period) + 1)[:n_rows]

    # Combine forecasts
    forecast_values = forecast_trend + forecast_seasonal_comp

    return forecast_values.tolist()


def forecast_holt_winters(
    y: np.ndarray,
    n_rows: int,
    period: Optional[int] = None,
    alpha: float = 0.2,
    beta: float = 0.1,
    gamma: float = 0.1,
    auto_detect: bool = True
) -> List[float]:
    """
    Forecast using Holt-Winters triple exponential smoothing.

    Handles level, trend, and seasonality components.

    Args:
        y: Time series data
        n_rows: Number of values to forecast
        period: Seasonal period (auto-detect if None)
        alpha: Level smoothing parameter (0-1)
        beta: Trend smoothing parameter (0-1)
        gamma: Seasonal smoothing parameter (0-1)
        auto_detect: Whether to auto-detect period

    Returns:
        List of forecasted values

    Raises:
        ValidationError: If parameters invalid
    """
    n = len(y)

    # Detect or validate period
    if period is None:
        if not auto_detect:
            raise ValidationError(
                "Holt-Winters requires a period parameter. "
                "Provide period or set auto_detect=True"
            )

        detected_period, strength = detect_seasonality(y)

        if detected_period == 0:
            # No seasonality, use double exponential smoothing (no seasonal component)
            return forecast_double_exponential(y, n_rows, alpha, beta)

        period = detected_period

        if strength < 0.3:
            warnings.warn(
                f"Weak seasonality detected (strength={strength:.2f}). "
                "Consider using a non-seasonal method."
            )

    # Validate parameters
    if period < 2:
        raise ValidationError(f"Period must be >= 2, got {period}")

    if n < period * 2:
        raise ValidationError(
            f"Need at least {period * 2} data points for Holt-Winters "
            f"with period={period}. Got {n} points."
        )

    for param_name, param_value in [('alpha', alpha), ('beta', beta), ('gamma', gamma)]:
        if not 0 < param_value < 1:
            raise ValidationError(
                f"{param_name} must be between 0 and 1, got {param_value}"
            )

    # Initialize components
    level = np.zeros(n)
    trend = np.zeros(n)
    seasonal = np.zeros(n + period)

    # Initial values
    level[0] = np.mean(y[:period])
    trend[0] = (np.mean(y[period:2*period]) - np.mean(y[:period])) / period

    # Initial seasonal components (first period)
    for i in range(period):
        seasonal[i] = y[i] - level[0]

    # Holt-Winters equations
    for t in range(1, n):
        # Level
        level[t] = alpha * (y[t] - seasonal[t]) + (1 - alpha) * (level[t-1] + trend[t-1])

        # Trend
        trend[t] = beta * (level[t] - level[t-1]) + (1 - beta) * trend[t-1]

        # Seasonal
        seasonal[t + period] = gamma * (y[t] - level[t]) + (1 - gamma) * seasonal[t]

    # Forecast
    forecast = []
    for i in range(n_rows):
        # Forecast = level + trend * steps + seasonal
        forecast_val = level[-1] + trend[-1] * (i + 1) + seasonal[n + (i % period)]
        forecast.append(forecast_val)

    return forecast


def forecast_double_exponential(
    y: np.ndarray,
    n_rows: int,
    alpha: float = 0.2,
    beta: float = 0.1
) -> List[float]:
    """
    Forecast using double exponential smoothing (Holt's method).

    Handles level and trend, but no seasonality.

    Args:
        y: Time series data
        n_rows: Number of values to forecast
        alpha: Level smoothing parameter (0-1)
        beta: Trend smoothing parameter (0-1)

    Returns:
        List of forecasted values
    """
    n = len(y)

    # Initialize
    level = np.zeros(n)
    trend = np.zeros(n)

    level[0] = y[0]
    trend[0] = y[1] - y[0] if n > 1 else 0

    # Double exponential smoothing equations
    for t in range(1, n):
        level[t] = alpha * y[t] + (1 - alpha) * (level[t-1] + trend[t-1])
        trend[t] = beta * (level[t] - level[t-1]) + (1 - beta) * trend[t-1]

    # Forecast
    forecast = []
    for i in range(n_rows):
        forecast_val = level[-1] + trend[-1] * (i + 1)
        forecast.append(forecast_val)

    return forecast


def forecast_ensemble(
    x: np.ndarray,
    y: np.ndarray,
    n_rows: int,
    methods: Optional[List[str]] = None,
    weights: Optional[List[float]] = None,
    auto_weight: bool = True
) -> List[float]:
    """
    Forecast using weighted ensemble of multiple methods.

    Combines predictions from multiple forecasting methods.

    Args:
        x: Time values
        y: Data values
        n_rows: Number of values to forecast
        methods: List of methods to ensemble (default: ['linear', 'polynomial', 'exponential'])
        weights: Weights for each method (auto-calculated if None)
        auto_weight: Whether to auto-weight based on R² scores

    Returns:
        List of forecasted values
    """
    # Default methods
    if methods is None:
        methods = [ForecastMethod.LINEAR, ForecastMethod.POLYNOMIAL]
        # Add exponential if all positive
        if np.all(y > 0):
            methods.append(ForecastMethod.EXPONENTIAL)

    # Validate we have at least 2 methods
    if len(methods) < 2:
        raise ValidationError("Ensemble requires at least 2 methods")

    # Generate forecasts and calculate R² for each method
    forecasts = []
    r_squared_scores = []

    for method in methods:
        try:
            if method == ForecastMethod.LINEAR:
                slope, intercept, r2 = fit_linear_trend(x, y)
                forecast_x = np.arange(x[-1] + 1, x[-1] + n_rows + 1)
                forecast_y = slope * forecast_x + intercept
                forecasts.append(forecast_y)
                r_squared_scores.append(r2)

            elif method == ForecastMethod.POLYNOMIAL:
                coeffs, r2 = fit_polynomial_trend(x, y, degree=2)
                forecast_x = np.arange(x[-1] + 1, x[-1] + n_rows + 1)
                forecast_y = np.polyval(coeffs, forecast_x)
                forecasts.append(forecast_y)
                r_squared_scores.append(r2)

            elif method == ForecastMethod.EXPONENTIAL:
                if np.all(y > 0):
                    a, b, r2 = fit_exponential_trend(x, y)
                    forecast_x = np.arange(x[-1] + 1, x[-1] + n_rows + 1)
                    forecast_y = a * np.exp(b * forecast_x)
                    forecasts.append(forecast_y)
                    r_squared_scores.append(r2)
                else:
                    continue  # Skip if not all positive

            elif method == ForecastMethod.SEASONAL:
                # Try seasonal forecast
                try:
                    forecast_y = forecast_seasonal(y, n_rows, period=None, auto_detect=True)
                    forecasts.append(np.array(forecast_y))
                    # Estimate R² for seasonal (use decomposition quality)
                    r_squared_scores.append(0.7)  # Placeholder
                except:
                    continue  # Skip if seasonal fails

        except Exception as e:
            warnings.warn(f"Method {method} failed in ensemble: {e}")
            continue

    if not forecasts:
        raise AugmentError("All ensemble methods failed")

    # Calculate weights
    if weights is None and auto_weight:
        # Weight by R² scores (normalized)
        r_squared_array = np.array(r_squared_scores)
        # Ensure non-negative
        r_squared_array = np.maximum(r_squared_array, 0)

        if np.sum(r_squared_array) > 0:
            weights = r_squared_array / np.sum(r_squared_array)
        else:
            # Equal weights if all R² are 0
            weights = np.ones(len(forecasts)) / len(forecasts)

    elif weights is None:
        # Equal weights
        weights = np.ones(len(forecasts)) / len(forecasts)

    else:
        # Validate provided weights
        if len(weights) != len(forecasts):
            raise ValidationError(
                f"Number of weights ({len(weights)}) must match number of methods ({len(forecasts)})"
            )

        # Normalize weights
        weights = np.array(weights)
        weights = weights / np.sum(weights)

    # Combine forecasts
    ensemble_forecast = np.zeros(n_rows)
    for forecast, weight in zip(forecasts, weights):
        ensemble_forecast += weight * forecast

    return ensemble_forecast.tolist()


def calculate_trend_strength(y: np.ndarray) -> float:
    """
    Calculate strength of trend in time series.

    Uses linear regression R² as measure of trend strength.

    Args:
        y: Time series data

    Returns:
        Trend strength (0-1)
    """
    x = np.arange(len(y))
    _, _, r_squared = fit_linear_trend(x, y)
    return r_squared


def forecast_values(
    df_polars,
    column: str,
    n_rows: int,
    method: str = ForecastMethod.LINEAR,
    time_col: Optional[str] = None,
    **params
) -> List[float]:
    """
    Main forecasting function.

    Args:
        df_polars: Input Polars DataFrame
        column: Column to forecast
        n_rows: Number of values to forecast
        method: Forecasting method (linear, polynomial, exponential, moving_average,
            seasonal, holt_winters, ensemble, auto)
        time_col: Time column name (auto-detect if None)
        **params: Method-specific parameters:
            - degree: For polynomial (default: 2)
            - window: For moving_average (default: 3)
            - period: For seasonal/holt_winters (auto-detect if None)
            - alpha, beta, gamma: For holt_winters (default: 0.2, 0.1, 0.1)
            - methods: For ensemble (default: ['linear', 'polynomial'])
            - weights: For ensemble (auto-calculated if None)
            - warn_threshold: R² threshold for warnings (default: 0.5)

    Returns:
        List of forecasted values

    Raises:
        ValidationError: If data is invalid
        AugmentError: If forecasting fails
    """
    # Validate data
    validate_forecast_data(df_polars, column, method)

    # Get data values
    y = df_polars[column].to_numpy()

    # Detect or use time column
    time_col_name = detect_time_column(df_polars, time_col)

    if time_col_name:
        # Use actual time column
        x = df_polars[time_col_name].to_numpy()
        # Convert to numeric if needed (e.g., datetime to timestamp)
        if not np.issubdtype(x.dtype, np.number):
            # Try to convert to numeric
            try:
                x = np.arange(len(x))  # Fallback to index
            except:
                x = np.arange(len(x))
    else:
        # Use row index as time
        x = np.arange(len(y))
        warnings.warn(
            "No time column detected. Using row index as time axis. "
            "Specify time_col parameter for better results."
        )

    # Auto-select method if requested
    if method == ForecastMethod.AUTO or method == "auto":
        # Check trend strength
        trend_strength = calculate_trend_strength(y)

        # Check seasonality
        detected_period, seasonal_strength = detect_seasonality(y)

        # Decision logic
        if seasonal_strength > 0.5 and len(y) >= MIN_DATA_POINTS[ForecastMethod.HOLT_WINTERS]:
            method = ForecastMethod.HOLT_WINTERS
        elif seasonal_strength > 0.3 and len(y) >= MIN_DATA_POINTS[ForecastMethod.SEASONAL]:
            method = ForecastMethod.SEASONAL
        elif trend_strength > 0.7:
            # Strong trend, try exponential if positive
            if np.all(y > 0):
                method = ForecastMethod.EXPONENTIAL
            else:
                method = ForecastMethod.POLYNOMIAL
        elif trend_strength > 0.5:
            method = ForecastMethod.LINEAR
        else:
            # Weak trend, use ensemble
            method = ForecastMethod.ENSEMBLE

        print(f"Auto-selected forecasting method: {method} "
              f"(trend={trend_strength:.2f}, seasonal={seasonal_strength:.2f})")

    # Forecast based on method
    try:
        if method == ForecastMethod.LINEAR:
            return forecast_linear(x, y, n_rows, params.get('warn_threshold', 0.5))

        elif method == ForecastMethod.POLYNOMIAL:
            degree = params.get('degree', 2)
            return forecast_polynomial(x, y, n_rows, degree, params.get('warn_threshold', 0.5))

        elif method == ForecastMethod.EXPONENTIAL:
            return forecast_exponential(x, y, n_rows, params.get('warn_threshold', 0.5))

        elif method == ForecastMethod.MOVING_AVERAGE:
            window = params.get('window', 3)
            return forecast_moving_average(y, n_rows, window)

        elif method == ForecastMethod.SEASONAL:
            period = params.get('period', None)
            auto_detect = params.get('auto_detect', True)
            return forecast_seasonal(y, n_rows, period, auto_detect)

        elif method == ForecastMethod.HOLT_WINTERS:
            period = params.get('period', None)
            alpha = params.get('alpha', 0.2)
            beta = params.get('beta', 0.1)
            gamma = params.get('gamma', 0.1)
            auto_detect = params.get('auto_detect', True)
            return forecast_holt_winters(y, n_rows, period, alpha, beta, gamma, auto_detect)

        elif method == ForecastMethod.ENSEMBLE:
            methods = params.get('methods', None)
            weights = params.get('weights', None)
            auto_weight = params.get('auto_weight', True)
            return forecast_ensemble(x, y, n_rows, methods, weights, auto_weight)

        else:
            raise ValidationError(
                f"Unknown forecasting method: '{method}'. "
                f"Supported: linear, polynomial, exponential, moving_average, "
                f"seasonal, holt_winters, ensemble, auto"
            )

    except Exception as e:
        if isinstance(e, (ValidationError, AugmentError)):
            raise
        raise AugmentError(f"Forecasting failed: {e}")