additory 0.1.0a4-py3-none-any.whl → 0.1.1a1-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
Files changed (121)
  1. additory/__init__.py +58 -14
  2. additory/common/__init__.py +31 -147
  3. additory/common/column_selector.py +255 -0
  4. additory/common/distributions.py +286 -613
  5. additory/common/extractors.py +313 -0
  6. additory/common/knn_imputation.py +332 -0
  7. additory/common/result.py +380 -0
  8. additory/common/strategy_parser.py +243 -0
  9. additory/common/unit_conversions.py +338 -0
  10. additory/common/validation.py +283 -103
  11. additory/core/__init__.py +34 -22
  12. additory/core/backend.py +258 -0
  13. additory/core/config.py +177 -305
  14. additory/core/logging.py +230 -24
  15. additory/core/memory_manager.py +157 -495
  16. additory/expressions/__init__.py +2 -23
  17. additory/expressions/compiler.py +457 -0
  18. additory/expressions/engine.py +264 -487
  19. additory/expressions/integrity.py +179 -0
  20. additory/expressions/loader.py +263 -0
  21. additory/expressions/parser.py +363 -167
  22. additory/expressions/resolver.py +274 -0
  23. additory/functions/__init__.py +1 -0
  24. additory/functions/analyze/__init__.py +144 -0
  25. additory/functions/analyze/cardinality.py +58 -0
  26. additory/functions/analyze/correlations.py +66 -0
  27. additory/functions/analyze/distributions.py +53 -0
  28. additory/functions/analyze/duplicates.py +49 -0
  29. additory/functions/analyze/features.py +61 -0
  30. additory/functions/analyze/imputation.py +66 -0
  31. additory/functions/analyze/outliers.py +65 -0
  32. additory/functions/analyze/patterns.py +65 -0
  33. additory/functions/analyze/presets.py +72 -0
  34. additory/functions/analyze/quality.py +59 -0
  35. additory/functions/analyze/timeseries.py +53 -0
  36. additory/functions/analyze/types.py +45 -0
  37. additory/functions/expressions/__init__.py +161 -0
  38. additory/functions/snapshot/__init__.py +82 -0
  39. additory/functions/snapshot/filter.py +119 -0
  40. additory/functions/synthetic/__init__.py +113 -0
  41. additory/functions/synthetic/mode_detector.py +47 -0
  42. additory/functions/synthetic/strategies/__init__.py +1 -0
  43. additory/functions/synthetic/strategies/advanced.py +35 -0
  44. additory/functions/synthetic/strategies/augmentative.py +160 -0
  45. additory/functions/synthetic/strategies/generative.py +168 -0
  46. additory/functions/synthetic/strategies/presets.py +116 -0
  47. additory/functions/to/__init__.py +188 -0
  48. additory/functions/to/lookup.py +351 -0
  49. additory/functions/to/merge.py +189 -0
  50. additory/functions/to/sort.py +91 -0
  51. additory/functions/to/summarize.py +170 -0
  52. additory/functions/transform/__init__.py +140 -0
  53. additory/functions/transform/datetime.py +79 -0
  54. additory/functions/transform/extract.py +85 -0
  55. additory/functions/transform/harmonize.py +105 -0
  56. additory/functions/transform/knn.py +62 -0
  57. additory/functions/transform/onehotencoding.py +68 -0
  58. additory/functions/transform/transpose.py +42 -0
  59. additory-0.1.1a1.dist-info/METADATA +83 -0
  60. additory-0.1.1a1.dist-info/RECORD +62 -0
  61. additory/analysis/__init__.py +0 -48
  62. additory/analysis/cardinality.py +0 -126
  63. additory/analysis/correlations.py +0 -124
  64. additory/analysis/distributions.py +0 -376
  65. additory/analysis/quality.py +0 -158
  66. additory/analysis/scan.py +0 -400
  67. additory/common/backend.py +0 -371
  68. additory/common/column_utils.py +0 -191
  69. additory/common/exceptions.py +0 -62
  70. additory/common/lists.py +0 -229
  71. additory/common/patterns.py +0 -240
  72. additory/common/resolver.py +0 -567
  73. additory/common/sample_data.py +0 -182
  74. additory/core/ast_builder.py +0 -165
  75. additory/core/backends/__init__.py +0 -23
  76. additory/core/backends/arrow_bridge.py +0 -483
  77. additory/core/backends/cudf_bridge.py +0 -355
  78. additory/core/column_positioning.py +0 -358
  79. additory/core/compiler_polars.py +0 -166
  80. additory/core/enhanced_cache_manager.py +0 -1119
  81. additory/core/enhanced_matchers.py +0 -473
  82. additory/core/enhanced_version_manager.py +0 -325
  83. additory/core/executor.py +0 -59
  84. additory/core/integrity_manager.py +0 -477
  85. additory/core/loader.py +0 -190
  86. additory/core/namespace_manager.py +0 -657
  87. additory/core/parser.py +0 -176
  88. additory/core/polars_expression_engine.py +0 -601
  89. additory/core/registry.py +0 -177
  90. additory/core/sample_data_manager.py +0 -492
  91. additory/core/user_namespace.py +0 -751
  92. additory/core/validator.py +0 -27
  93. additory/dynamic_api.py +0 -352
  94. additory/expressions/proxy.py +0 -549
  95. additory/expressions/registry.py +0 -313
  96. additory/expressions/samples.py +0 -492
  97. additory/synthetic/__init__.py +0 -13
  98. additory/synthetic/column_name_resolver.py +0 -149
  99. additory/synthetic/deduce.py +0 -259
  100. additory/synthetic/distributions.py +0 -22
  101. additory/synthetic/forecast.py +0 -1132
  102. additory/synthetic/linked_list_parser.py +0 -415
  103. additory/synthetic/namespace_lookup.py +0 -129
  104. additory/synthetic/smote.py +0 -320
  105. additory/synthetic/strategies.py +0 -926
  106. additory/synthetic/synthesizer.py +0 -713
  107. additory/utilities/__init__.py +0 -53
  108. additory/utilities/encoding.py +0 -600
  109. additory/utilities/games.py +0 -300
  110. additory/utilities/keys.py +0 -8
  111. additory/utilities/lookup.py +0 -103
  112. additory/utilities/matchers.py +0 -216
  113. additory/utilities/resolvers.py +0 -286
  114. additory/utilities/settings.py +0 -167
  115. additory/utilities/units.py +0 -749
  116. additory/utilities/validators.py +0 -153
  117. additory-0.1.0a4.dist-info/METADATA +0 -311
  118. additory-0.1.0a4.dist-info/RECORD +0 -72
  119. additory-0.1.0a4.dist-info/licenses/LICENSE +0 -21
  120. {additory-0.1.0a4.dist-info → additory-0.1.1a1.dist-info}/WHEEL +0 -0
  121. {additory-0.1.0a4.dist-info → additory-0.1.1a1.dist-info}/top_level.txt +0 -0
additory/synthetic/forecast.py
@@ -1,1132 +0,0 @@
- """
- Forecast Strategies for Synthetic Data Generation
-
- Provides time series forecasting capabilities:
- - Linear trend forecasting
- - Polynomial trend forecasting
- - Exponential growth forecasting
- - Moving average forecasting
- """
-
- from typing import Optional, Tuple, List, Any
- import warnings
-
- import numpy as np
-
- from additory.common.exceptions import ValidationError, AugmentError
-
-
- class ForecastMethod:
-     """Supported forecasting methods."""
-     LINEAR = "linear"
-     POLYNOMIAL = "polynomial"
-     EXPONENTIAL = "exponential"
-     MOVING_AVERAGE = "moving_average"
-     SEASONAL = "seasonal"
-     HOLT_WINTERS = "holt_winters"
-     ENSEMBLE = "ensemble"
-     AUTO = "auto"
-
-
- # Minimum data requirements for each method
- MIN_DATA_POINTS = {
-     ForecastMethod.LINEAR: 3,
-     ForecastMethod.POLYNOMIAL: 5,
-     ForecastMethod.EXPONENTIAL: 10,
-     ForecastMethod.MOVING_AVERAGE: 4,
-     ForecastMethod.SEASONAL: 12,  # Need at least one full cycle
-     ForecastMethod.HOLT_WINTERS: 12,  # Need at least one full cycle
-     ForecastMethod.ENSEMBLE: 5,
- }
-
-
- def detect_time_column(df_polars, hint: Optional[str] = None) -> Optional[str]:
-     """
-     Detect time/date column in dataframe.
-
-     Args:
-         df_polars: Polars DataFrame
-         hint: Optional column name hint
-
-     Returns:
-         Column name or None if not found
-
-     Detection logic:
-     1. If hint provided and exists, use it
-     2. Look for datetime dtype columns
-     3. Look for common time column names
-     4. Return None (will use row index)
-     """
-     import polars as pl
-
-     # If hint provided, validate it exists
-     if hint:
-         if hint in df_polars.columns:
-             return hint
-         else:
-             raise ValidationError(f"Specified time column '{hint}' not found in dataframe")
-
-     # Check for datetime columns
-     for col in df_polars.columns:
-         if df_polars[col].dtype in [pl.Date, pl.Datetime, pl.Duration]:
-             return col
-
-     # Check for common time column names
-     time_names = ['date', 'time', 'timestamp', 'datetime', 'period', 'day', 'month', 'year']
-     for col in df_polars.columns:
-         if col.lower() in time_names:
-             return col
-
-     # No time column found
-     return None
-
-
- def validate_forecast_data(
-     df_polars,
-     column: str,
-     method: str,
-     min_points: Optional[int] = None
- ) -> None:
-     """
-     Validate data is suitable for forecasting.
-
-     Args:
-         df_polars: Polars DataFrame
-         column: Column to forecast
-         method: Forecasting method
-         min_points: Minimum required data points (optional)
-
-     Raises:
-         ValidationError: If data is invalid
-     """
-     # Check column exists
-     if column not in df_polars.columns:
-         raise ValidationError(f"Column '{column}' not found in dataframe")
-
-     # Check column is numeric
-     col_data = df_polars[column]
-     if not col_data.dtype.is_numeric():
-         raise ValidationError(
-             f"Column '{column}' must be numeric for forecasting. "
-             f"Got dtype: {col_data.dtype}"
-         )
-
-     # Check minimum data points
-     n_points = len(df_polars)
-     required_min = min_points or MIN_DATA_POINTS.get(method, 3)
-
-     if n_points < required_min:
-         raise ValidationError(
-             f"Insufficient data for {method} forecasting. "
-             f"Need at least {required_min} points, got {n_points}"
-         )
-
-     # Check for null values
-     null_count = col_data.null_count()
-     if null_count > 0:
-         raise ValidationError(
-             f"Column '{column}' contains {null_count} null values. "
-             "Remove nulls before forecasting."
-         )
-
-     # Check for variance (all values identical)
-     values = col_data.to_numpy()
-     if np.std(values) == 0:
-         raise ValidationError(
-             f"Column '{column}' has no variance (all values identical). "
-             "Cannot forecast constant data."
-         )
-
-
- def fit_linear_trend(x: np.ndarray, y: np.ndarray) -> Tuple[float, float, float]:
-     """
-     Fit linear trend using least squares.
-
-     Formula: y = mx + b
-
-     Args:
-         x: Independent variable (time)
-         y: Dependent variable (values to forecast)
-
-     Returns:
-         Tuple of (slope, intercept, r_squared)
-     """
-     # Use numpy polyfit for linear regression
-     coeffs = np.polyfit(x, y, 1)
-     slope, intercept = coeffs[0], coeffs[1]
-
-     # Calculate R²
-     y_pred = slope * x + intercept
-     ss_res = np.sum((y - y_pred) ** 2)
-     ss_tot = np.sum((y - np.mean(y)) ** 2)
-     r_squared = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0.0
-
-     return slope, intercept, r_squared
-
-
- def fit_polynomial_trend(
-     x: np.ndarray,
-     y: np.ndarray,
-     degree: int = 2
- ) -> Tuple[np.ndarray, float]:
-     """
-     Fit polynomial trend using least squares.
-
-     Formula: y = a*x^n + b*x^(n-1) + ... + c
-
-     Args:
-         x: Independent variable (time)
-         y: Dependent variable (values to forecast)
-         degree: Polynomial degree (default: 2)
-
-     Returns:
-         Tuple of (coefficients, r_squared)
-     """
-     # Validate degree
-     if degree < 1:
-         raise ValidationError(f"Polynomial degree must be >= 1, got {degree}")
-     if degree > 10:
-         warnings.warn(f"High polynomial degree ({degree}) may cause overfitting")
-
-     # Fit polynomial
-     coeffs = np.polyfit(x, y, degree)
-
-     # Calculate R²
-     y_pred = np.polyval(coeffs, x)
-     ss_res = np.sum((y - y_pred) ** 2)
-     ss_tot = np.sum((y - np.mean(y)) ** 2)
-     r_squared = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0.0
-
-     return coeffs, r_squared
-
-
- def fit_exponential_trend(x: np.ndarray, y: np.ndarray) -> Tuple[float, float, float]:
-     """
-     Fit exponential trend.
-
-     Formula: y = a * e^(b*x)
-
-     Method: Transform to log space and fit linear
-     log(y) = log(a) + b*x
-
-     Args:
-         x: Independent variable (time)
-         y: Dependent variable (values to forecast)
-
-     Returns:
-         Tuple of (a, b, r_squared)
-     """
-     # Check for non-positive values
-     if np.any(y <= 0):
-         raise ValidationError(
-             "Exponential forecasting requires all positive values. "
-             "Found zero or negative values."
-         )
-
-     # Transform to log space
-     log_y = np.log(y)
-
-     # Fit linear in log space
-     coeffs = np.polyfit(x, log_y, 1)
-     b, log_a = coeffs[0], coeffs[1]
-     a = np.exp(log_a)
-
-     # Calculate R² in original space
-     y_pred = a * np.exp(b * x)
-     ss_res = np.sum((y - y_pred) ** 2)
-     ss_tot = np.sum((y - np.mean(y)) ** 2)
-     r_squared = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0.0
-
-     return a, b, r_squared
-
-
- def calculate_moving_average(y: np.ndarray, window: int) -> float:
-     """
-     Calculate moving average of last 'window' values.
-
-     Args:
-         y: Array of values
-         window: Window size
-
-     Returns:
-         Moving average value
-     """
-     if window < 1:
-         raise ValidationError(f"Window size must be >= 1, got {window}")
-
-     if len(y) < window:
-         raise ValidationError(
-             f"Not enough data for moving average. "
-             f"Need at least {window} points, got {len(y)}"
-         )
-
-     # Take last 'window' values and average
-     return np.mean(y[-window:])
-
-
- def forecast_linear(
-     x: np.ndarray,
-     y: np.ndarray,
-     n_rows: int,
-     warn_threshold: float = 0.5
- ) -> List[float]:
-     """
-     Forecast using linear trend.
-
-     Args:
-         x: Time values
-         y: Data values
-         n_rows: Number of values to forecast
-         warn_threshold: R² threshold for warning
-
-     Returns:
-         List of forecasted values
-     """
-     slope, intercept, r_squared = fit_linear_trend(x, y)
-
-     # Warn if poor fit
-     if r_squared < warn_threshold:
-         warnings.warn(
-             f"Linear fit has low R² = {r_squared:.3f}. "
-             "Consider using a different method or checking your data."
-         )
-
-     # Generate forecast points
-     last_x = x[-1]
-     forecast_x = np.arange(last_x + 1, last_x + n_rows + 1)
-     forecast_y = slope * forecast_x + intercept
-
-     return forecast_y.tolist()
-
-
- def forecast_polynomial(
-     x: np.ndarray,
-     y: np.ndarray,
-     n_rows: int,
-     degree: int = 2,
-     warn_threshold: float = 0.5
- ) -> List[float]:
-     """
-     Forecast using polynomial trend.
-
-     Args:
-         x: Time values
-         y: Data values
-         n_rows: Number of values to forecast
-         degree: Polynomial degree
-         warn_threshold: R² threshold for warning
-
-     Returns:
-         List of forecasted values
-     """
-     coeffs, r_squared = fit_polynomial_trend(x, y, degree)
-
-     # Warn if poor fit
-     if r_squared < warn_threshold:
-         warnings.warn(
-             f"Polynomial fit (degree={degree}) has low R² = {r_squared:.3f}. "
-             "Consider adjusting degree or using a different method."
-         )
-
-     # Generate forecast points
-     last_x = x[-1]
-     forecast_x = np.arange(last_x + 1, last_x + n_rows + 1)
-     forecast_y = np.polyval(coeffs, forecast_x)
-
-     return forecast_y.tolist()
-
-
- def forecast_exponential(
-     x: np.ndarray,
-     y: np.ndarray,
-     n_rows: int,
-     warn_threshold: float = 0.5
- ) -> List[float]:
-     """
-     Forecast using exponential trend.
-
-     Args:
-         x: Time values
-         y: Data values
-         n_rows: Number of values to forecast
-         warn_threshold: R² threshold for warning
-
-     Returns:
-         List of forecasted values
-     """
-     a, b, r_squared = fit_exponential_trend(x, y)
-
-     # Warn if poor fit
-     if r_squared < warn_threshold:
-         warnings.warn(
-             f"Exponential fit has low R² = {r_squared:.3f}. "
-             "Consider using a different method."
-         )
-
-     # Generate forecast points
-     last_x = x[-1]
-     forecast_x = np.arange(last_x + 1, last_x + n_rows + 1)
-     forecast_y = a * np.exp(b * forecast_x)
-
-     return forecast_y.tolist()
-
-
- def forecast_moving_average(
-     y: np.ndarray,
-     n_rows: int,
-     window: int = 3
- ) -> List[float]:
-     """
-     Forecast using moving average.
-
-     Args:
-         y: Data values
-         n_rows: Number of values to forecast
-         window: Window size for moving average
-
-     Returns:
-         List of forecasted values
-     """
-     # Calculate initial moving average
-     ma_value = calculate_moving_average(y, window)
-
-     # Repeat the moving average for all forecast points
-     # Note: This is a simple approach. More sophisticated would be
-     # to recalculate MA as we add forecasted points.
-     return [ma_value] * n_rows
-
-
- def auto_select_method(x: np.ndarray, y: np.ndarray) -> str:
-     """
-     Automatically select best forecasting method based on data.
-
-     Args:
-         x: Time values
-         y: Data values
-
-     Returns:
-         Best method name
-     """
-     methods_to_try = []
-
-     # Always try linear
-     methods_to_try.append(ForecastMethod.LINEAR)
-
-     # Try polynomial if enough data
-     if len(y) >= MIN_DATA_POINTS[ForecastMethod.POLYNOMIAL]:
-         methods_to_try.append(ForecastMethod.POLYNOMIAL)
-
-     # Try exponential if all positive and enough data
-     if np.all(y > 0) and len(y) >= MIN_DATA_POINTS[ForecastMethod.EXPONENTIAL]:
-         methods_to_try.append(ForecastMethod.EXPONENTIAL)
-
-     # Fit each method and compare R²
-     best_method = ForecastMethod.LINEAR
-     best_r_squared = -np.inf
-
-     for method in methods_to_try:
-         try:
-             if method == ForecastMethod.LINEAR:
-                 _, _, r_squared = fit_linear_trend(x, y)
-             elif method == ForecastMethod.POLYNOMIAL:
-                 _, r_squared = fit_polynomial_trend(x, y, degree=2)
-             elif method == ForecastMethod.EXPONENTIAL:
-                 _, _, r_squared = fit_exponential_trend(x, y)
-
-             if r_squared > best_r_squared:
-                 best_r_squared = r_squared
-                 best_method = method
-         except:
-             # Skip methods that fail
-             continue
-
-     return best_method
-
-
- def forecast_values(
-     df_polars,
-     column: str,
-     n_rows: int,
-     method: str = ForecastMethod.LINEAR,
-     time_col: Optional[str] = None,
-     **params
- ) -> List[float]:
-     """
-     Main forecasting function.
-
-     Args:
-         df_polars: Input Polars DataFrame
-         column: Column to forecast
-         n_rows: Number of values to forecast
-         method: Forecasting method (linear, polynomial, exponential, moving_average, auto)
-         time_col: Time column name (auto-detect if None)
-         **params: Method-specific parameters:
-             - degree: For polynomial (default: 2)
-             - window: For moving_average (default: 3)
-             - warn_threshold: R² threshold for warnings (default: 0.5)
-
-     Returns:
-         List of forecasted values
-
-     Raises:
-         ValidationError: If data is invalid
-         AugmentError: If forecasting fails
-     """
-     # Validate data
-     validate_forecast_data(df_polars, column, method)
-
-     # Get data values
-     y = df_polars[column].to_numpy()
-
-     # Detect or use time column
-     time_col_name = detect_time_column(df_polars, time_col)
-
-     if time_col_name:
-         # Use actual time column
-         x = df_polars[time_col_name].to_numpy()
-         # Convert to numeric if needed (e.g., datetime to timestamp)
-         if not np.issubdtype(x.dtype, np.number):
-             # Try to convert to numeric
-             try:
-                 x = np.arange(len(x))  # Fallback to index
-             except:
-                 x = np.arange(len(x))
-     else:
-         # Use row index as time
-         x = np.arange(len(y))
-         warnings.warn(
-             "No time column detected. Using row index as time axis. "
-             "Specify time_col parameter for better results."
-         )
-
-     # Auto-select method if requested
-     if method == ForecastMethod.AUTO or method == "auto":
-         method = auto_select_method(x, y)
-         print(f"Auto-selected forecasting method: {method}")
-
-     # Forecast based on method
-     try:
-         if method == ForecastMethod.LINEAR:
-             return forecast_linear(x, y, n_rows, params.get('warn_threshold', 0.5))
-
-         elif method == ForecastMethod.POLYNOMIAL:
-             degree = params.get('degree', 2)
-             return forecast_polynomial(x, y, n_rows, degree, params.get('warn_threshold', 0.5))
-
-         elif method == ForecastMethod.EXPONENTIAL:
-             return forecast_exponential(x, y, n_rows, params.get('warn_threshold', 0.5))
-
-         elif method == ForecastMethod.MOVING_AVERAGE:
-             window = params.get('window', 3)
-             return forecast_moving_average(y, n_rows, window)
-
-         else:
-             raise ValidationError(
-                 f"Unknown forecasting method: '{method}'. "
-                 f"Supported: linear, polynomial, exponential, moving_average, auto"
-             )
-
-     except Exception as e:
-         if isinstance(e, (ValidationError, AugmentError)):
-             raise
-         raise AugmentError(f"Forecasting failed: {e}")
-
-
-
-
- def detect_seasonality(y: np.ndarray, max_period: int = 12) -> Tuple[int, float]:
-     """
-     Detect seasonal period using autocorrelation.
-
-     Args:
-         y: Time series data
-         max_period: Maximum period to check (default: 12 for monthly data)
-
-     Returns:
-         Tuple of (period, strength)
-         - period: Detected seasonal period (0 if no seasonality)
-         - strength: Strength of seasonality (0-1)
-     """
-     if len(y) < max_period * 2:
-         return 0, 0.0
-
-     # Calculate autocorrelation for different lags
-     autocorr = []
-     mean_y = np.mean(y)
-     var_y = np.var(y)
-
-     if var_y == 0:
-         return 0, 0.0
-
-     for lag in range(1, min(max_period + 1, len(y) // 2)):
-         # Calculate autocorrelation at this lag
-         numerator = np.sum((y[:-lag] - mean_y) * (y[lag:] - mean_y))
-         denominator = len(y[:-lag]) * var_y
-
-         if denominator > 0:
-             acf = numerator / denominator
-             autocorr.append((lag, acf))
-
-     if not autocorr:
-         return 0, 0.0
-
-     # Find the lag with highest positive autocorrelation (excluding lag 1)
-     autocorr_sorted = sorted(autocorr[1:], key=lambda x: x[1], reverse=True)
-
-     if not autocorr_sorted or autocorr_sorted[0][1] < 0.3:
-         # No significant seasonality
-         return 0, 0.0
-
-     period = autocorr_sorted[0][0]
-     strength = autocorr_sorted[0][1]
-
-     return period, strength
-
-
- def decompose_seasonal(y: np.ndarray, period: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
-     """
-     Simple seasonal decomposition: trend + seasonal + residual.
-
-     Uses moving average for trend and averages for seasonal component.
-
-     Args:
-         y: Time series data
-         period: Seasonal period
-
-     Returns:
-         Tuple of (trend, seasonal, residual)
-     """
-     n = len(y)
-
-     # Calculate trend using centered moving average
-     trend = np.zeros(n)
-     half_period = period // 2
-
-     for i in range(n):
-         start = max(0, i - half_period)
-         end = min(n, i + half_period + 1)
-         trend[i] = np.mean(y[start:end])
-
-     # Detrend
-     detrended = y - trend
-
-     # Calculate seasonal component (average for each position in cycle)
-     seasonal = np.zeros(n)
-     seasonal_averages = np.zeros(period)
-
-     for i in range(period):
-         # Get all values at this position in the cycle
-         indices = list(range(i, n, period))
-         if indices:
-             seasonal_averages[i] = np.mean(detrended[indices])
-
-     # Normalize seasonal component to sum to zero
-     seasonal_averages -= np.mean(seasonal_averages)
-
-     # Assign seasonal values
-     for i in range(n):
-         seasonal[i] = seasonal_averages[i % period]
-
-     # Calculate residual
-     residual = y - trend - seasonal
-
-     return trend, seasonal, residual
-
-
- def forecast_seasonal(
-     y: np.ndarray,
-     n_rows: int,
-     period: Optional[int] = None,
-     auto_detect: bool = True
- ) -> List[float]:
-     """
-     Forecast using seasonal decomposition.
-
-     Decomposes series into trend + seasonal + residual, then forecasts each component.
-
-     Args:
-         y: Time series data
-         n_rows: Number of values to forecast
-         period: Seasonal period (e.g., 12 for monthly, 7 for daily)
-             If None and auto_detect=True, will attempt to detect
-         auto_detect: Whether to auto-detect period if not provided
-
-     Returns:
-         List of forecasted values
-
-     Raises:
-         ValidationError: If insufficient data or no seasonality detected
-     """
-     n = len(y)
-
-     # Detect or validate period
-     if period is None:
-         if not auto_detect:
-             raise ValidationError(
-                 "Seasonal forecasting requires a period parameter. "
-                 "Provide period or set auto_detect=True"
-             )
-
-         # Auto-detect seasonality
-         detected_period, strength = detect_seasonality(y)
-
-         if detected_period == 0:
-             raise ValidationError(
-                 "No significant seasonality detected in data. "
-                 "Try a different forecasting method or specify period manually."
-             )
-
-         period = detected_period
-
-         if strength < 0.5:
-             warnings.warn(
-                 f"Weak seasonality detected (strength={strength:.2f}). "
-                 "Results may not be reliable."
-             )
-
-     # Validate period
-     if period < 2:
-         raise ValidationError(f"Period must be >= 2, got {period}")
-
-     if n < period * 2:
-         raise ValidationError(
-             f"Need at least {period * 2} data points for seasonal forecasting "
-             f"with period={period}. Got {n} points."
-         )
-
-     # Decompose series
-     trend, seasonal, residual = decompose_seasonal(y, period)
-
-     # Forecast trend (using linear extrapolation)
-     x = np.arange(n)
-     trend_slope, trend_intercept, _ = fit_linear_trend(x, trend)
-
-     forecast_x = np.arange(n, n + n_rows)
-     forecast_trend = trend_slope * forecast_x + trend_intercept
-
-     # Forecast seasonal component (repeat the pattern)
-     seasonal_pattern = seasonal[-period:]  # Last full cycle
-     forecast_seasonal_comp = np.tile(seasonal_pattern, (n_rows // period) + 1)[:n_rows]
-
-     # Combine forecasts
-     forecast_values = forecast_trend + forecast_seasonal_comp
-
-     return forecast_values.tolist()
-
-
- def forecast_holt_winters(
-     y: np.ndarray,
-     n_rows: int,
-     period: Optional[int] = None,
-     alpha: float = 0.2,
-     beta: float = 0.1,
-     gamma: float = 0.1,
-     auto_detect: bool = True
- ) -> List[float]:
-     """
-     Forecast using Holt-Winters triple exponential smoothing.
-
-     Handles level, trend, and seasonality components.
-
-     Args:
-         y: Time series data
-         n_rows: Number of values to forecast
-         period: Seasonal period (auto-detect if None)
-         alpha: Level smoothing parameter (0-1)
-         beta: Trend smoothing parameter (0-1)
-         gamma: Seasonal smoothing parameter (0-1)
-         auto_detect: Whether to auto-detect period
-
-     Returns:
-         List of forecasted values
-
-     Raises:
-         ValidationError: If parameters invalid
-     """
-     n = len(y)
-
-     # Detect or validate period
-     if period is None:
-         if not auto_detect:
-             raise ValidationError(
-                 "Holt-Winters requires a period parameter. "
-                 "Provide period or set auto_detect=True"
-             )
-
-         detected_period, strength = detect_seasonality(y)
-
-         if detected_period == 0:
-             # No seasonality, use double exponential smoothing (no seasonal component)
-             return forecast_double_exponential(y, n_rows, alpha, beta)
-
-         period = detected_period
-
-         if strength < 0.3:
-             warnings.warn(
-                 f"Weak seasonality detected (strength={strength:.2f}). "
-                 "Consider using a non-seasonal method."
-             )
-
-     # Validate parameters
-     if period < 2:
-         raise ValidationError(f"Period must be >= 2, got {period}")
-
-     if n < period * 2:
-         raise ValidationError(
-             f"Need at least {period * 2} data points for Holt-Winters "
-             f"with period={period}. Got {n} points."
-         )
-
-     for param_name, param_value in [('alpha', alpha), ('beta', beta), ('gamma', gamma)]:
-         if not 0 < param_value < 1:
-             raise ValidationError(
-                 f"{param_name} must be between 0 and 1, got {param_value}"
-             )
-
-     # Initialize components
-     level = np.zeros(n)
-     trend = np.zeros(n)
-     seasonal = np.zeros(n + period)
-
-     # Initial values
-     level[0] = np.mean(y[:period])
-     trend[0] = (np.mean(y[period:2*period]) - np.mean(y[:period])) / period
-
-     # Initial seasonal components (first period)
-     for i in range(period):
-         seasonal[i] = y[i] - level[0]
-
-     # Holt-Winters equations
-     for t in range(1, n):
-         # Level
-         level[t] = alpha * (y[t] - seasonal[t]) + (1 - alpha) * (level[t-1] + trend[t-1])
-
-         # Trend
-         trend[t] = beta * (level[t] - level[t-1]) + (1 - beta) * trend[t-1]
-
-         # Seasonal
-         seasonal[t + period] = gamma * (y[t] - level[t]) + (1 - gamma) * seasonal[t]
-
-     # Forecast
-     forecast = []
-     for i in range(n_rows):
-         # Forecast = level + trend * steps + seasonal
-         forecast_val = level[-1] + trend[-1] * (i + 1) + seasonal[n + (i % period)]
-         forecast.append(forecast_val)
-
-     return forecast
-
-
- def forecast_double_exponential(
-     y: np.ndarray,
-     n_rows: int,
-     alpha: float = 0.2,
-     beta: float = 0.1
- ) -> List[float]:
-     """
-     Forecast using double exponential smoothing (Holt's method).
-
-     Handles level and trend, but no seasonality.
-
-     Args:
-         y: Time series data
-         n_rows: Number of values to forecast
-         alpha: Level smoothing parameter (0-1)
-         beta: Trend smoothing parameter (0-1)
-
-     Returns:
-         List of forecasted values
-     """
-     n = len(y)
-
-     # Initialize
-     level = np.zeros(n)
-     trend = np.zeros(n)
-
-     level[0] = y[0]
-     trend[0] = y[1] - y[0] if n > 1 else 0
-
-     # Double exponential smoothing equations
-     for t in range(1, n):
-         level[t] = alpha * y[t] + (1 - alpha) * (level[t-1] + trend[t-1])
-         trend[t] = beta * (level[t] - level[t-1]) + (1 - beta) * trend[t-1]
-
-     # Forecast
-     forecast = []
-     for i in range(n_rows):
-         forecast_val = level[-1] + trend[-1] * (i + 1)
-         forecast.append(forecast_val)
-
-     return forecast
-
-
- def forecast_ensemble(
-     x: np.ndarray,
-     y: np.ndarray,
-     n_rows: int,
-     methods: Optional[List[str]] = None,
-     weights: Optional[List[float]] = None,
-     auto_weight: bool = True
- ) -> List[float]:
-     """
-     Forecast using weighted ensemble of multiple methods.
-
-     Combines predictions from multiple forecasting methods.
-
-     Args:
-         x: Time values
-         y: Data values
-         n_rows: Number of values to forecast
-         methods: List of methods to ensemble (default: ['linear', 'polynomial', 'exponential'])
-         weights: Weights for each method (auto-calculated if None)
-         auto_weight: Whether to auto-weight based on R² scores
-
-     Returns:
-         List of forecasted values
-     """
-     # Default methods
-     if methods is None:
-         methods = [ForecastMethod.LINEAR, ForecastMethod.POLYNOMIAL]
-         # Add exponential if all positive
-         if np.all(y > 0):
-             methods.append(ForecastMethod.EXPONENTIAL)
-
-     # Validate we have at least 2 methods
-     if len(methods) < 2:
-         raise ValidationError("Ensemble requires at least 2 methods")
-
-     # Generate forecasts and calculate R² for each method
-     forecasts = []
-     r_squared_scores = []
-
-     for method in methods:
-         try:
-             if method == ForecastMethod.LINEAR:
-                 slope, intercept, r2 = fit_linear_trend(x, y)
-                 forecast_x = np.arange(x[-1] + 1, x[-1] + n_rows + 1)
-                 forecast_y = slope * forecast_x + intercept
-                 forecasts.append(forecast_y)
-                 r_squared_scores.append(r2)
-
-             elif method == ForecastMethod.POLYNOMIAL:
-                 coeffs, r2 = fit_polynomial_trend(x, y, degree=2)
-                 forecast_x = np.arange(x[-1] + 1, x[-1] + n_rows + 1)
-                 forecast_y = np.polyval(coeffs, forecast_x)
-                 forecasts.append(forecast_y)
-                 r_squared_scores.append(r2)
-
-             elif method == ForecastMethod.EXPONENTIAL:
-                 if np.all(y > 0):
-                     a, b, r2 = fit_exponential_trend(x, y)
-                     forecast_x = np.arange(x[-1] + 1, x[-1] + n_rows + 1)
-                     forecast_y = a * np.exp(b * forecast_x)
-                     forecasts.append(forecast_y)
-                     r_squared_scores.append(r2)
-                 else:
-                     continue  # Skip if not all positive
-
-             elif method == ForecastMethod.SEASONAL:
-                 # Try seasonal forecast
-                 try:
-                     forecast_y = forecast_seasonal(y, n_rows, period=None, auto_detect=True)
-                     forecasts.append(np.array(forecast_y))
-                     # Estimate R² for seasonal (use decomposition quality)
-                     r_squared_scores.append(0.7)  # Placeholder
-                 except:
-                     continue  # Skip if seasonal fails
-
-         except Exception as e:
-             warnings.warn(f"Method {method} failed in ensemble: {e}")
-             continue
-
-     if not forecasts:
-         raise AugmentError("All ensemble methods failed")
-
-     # Calculate weights
-     if weights is None and auto_weight:
-         # Weight by R² scores (normalized)
-         r_squared_array = np.array(r_squared_scores)
-         # Ensure non-negative
-         r_squared_array = np.maximum(r_squared_array, 0)
-
-         if np.sum(r_squared_array) > 0:
-             weights = r_squared_array / np.sum(r_squared_array)
-         else:
-             # Equal weights if all R² are 0
-             weights = np.ones(len(forecasts)) / len(forecasts)
-
-     elif weights is None:
-         # Equal weights
-         weights = np.ones(len(forecasts)) / len(forecasts)
-
-     else:
-         # Validate provided weights
-         if len(weights) != len(forecasts):
-             raise ValidationError(
-                 f"Number of weights ({len(weights)}) must match number of methods ({len(forecasts)})"
-             )
-
-         # Normalize weights
-         weights = np.array(weights)
-         weights = weights / np.sum(weights)
-
-     # Combine forecasts
-     ensemble_forecast = np.zeros(n_rows)
-     for forecast, weight in zip(forecasts, weights):
-         ensemble_forecast += weight * forecast
-
-     return ensemble_forecast.tolist()
-
-
- def calculate_trend_strength(y: np.ndarray) -> float:
-     """
-     Calculate strength of trend in time series.
-
-     Uses linear regression R² as measure of trend strength.
-
-     Args:
-         y: Time series data
-
-     Returns:
-         Trend strength (0-1)
-     """
-     x = np.arange(len(y))
-     _, _, r_squared = fit_linear_trend(x, y)
-     return r_squared
-
-
- def forecast_values(
-     df_polars,
-     column: str,
-     n_rows: int,
-     method: str = ForecastMethod.LINEAR,
-     time_col: Optional[str] = None,
-     **params
- ) -> List[float]:
-     """
-     Main forecasting function.
-
-     Args:
-         df_polars: Input Polars DataFrame
-         column: Column to forecast
-         n_rows: Number of values to forecast
-         method: Forecasting method (linear, polynomial, exponential, moving_average,
-             seasonal, holt_winters, ensemble, auto)
-         time_col: Time column name (auto-detect if None)
-         **params: Method-specific parameters:
-             - degree: For polynomial (default: 2)
-             - window: For moving_average (default: 3)
-             - period: For seasonal/holt_winters (auto-detect if None)
-             - alpha, beta, gamma: For holt_winters (default: 0.2, 0.1, 0.1)
-             - methods: For ensemble (default: ['linear', 'polynomial'])
-             - weights: For ensemble (auto-calculated if None)
-             - warn_threshold: R² threshold for warnings (default: 0.5)
-
-     Returns:
-         List of forecasted values
-
-     Raises:
-         ValidationError: If data is invalid
-         AugmentError: If forecasting fails
-     """
-     # Validate data
-     validate_forecast_data(df_polars, column, method)
-
-     # Get data values
-     y = df_polars[column].to_numpy()
-
-     # Detect or use time column
-     time_col_name = detect_time_column(df_polars, time_col)
-
-     if time_col_name:
-         # Use actual time column
-         x = df_polars[time_col_name].to_numpy()
-         # Convert to numeric if needed (e.g., datetime to timestamp)
-         if not np.issubdtype(x.dtype, np.number):
-             # Try to convert to numeric
-             try:
-                 x = np.arange(len(x))  # Fallback to index
-             except:
-                 x = np.arange(len(x))
-     else:
-         # Use row index as time
-         x = np.arange(len(y))
-         warnings.warn(
-             "No time column detected. Using row index as time axis. "
-             "Specify time_col parameter for better results."
-         )
-
-     # Auto-select method if requested
-     if method == ForecastMethod.AUTO or method == "auto":
-         # Check trend strength
-         trend_strength = calculate_trend_strength(y)
-
-         # Check seasonality
-         detected_period, seasonal_strength = detect_seasonality(y)
-
-         # Decision logic
-         if seasonal_strength > 0.5 and len(y) >= MIN_DATA_POINTS[ForecastMethod.HOLT_WINTERS]:
-             method = ForecastMethod.HOLT_WINTERS
-         elif seasonal_strength > 0.3 and len(y) >= MIN_DATA_POINTS[ForecastMethod.SEASONAL]:
-             method = ForecastMethod.SEASONAL
-         elif trend_strength > 0.7:
-             # Strong trend, try exponential if positive
-             if np.all(y > 0):
-                 method = ForecastMethod.EXPONENTIAL
-             else:
-                 method = ForecastMethod.POLYNOMIAL
-         elif trend_strength > 0.5:
-             method = ForecastMethod.LINEAR
-         else:
-             # Weak trend, use ensemble
-             method = ForecastMethod.ENSEMBLE
-
-         print(f"Auto-selected forecasting method: {method} "
-               f"(trend={trend_strength:.2f}, seasonal={seasonal_strength:.2f})")
-
-     # Forecast based on method
-     try:
-         if method == ForecastMethod.LINEAR:
-             return forecast_linear(x, y, n_rows, params.get('warn_threshold', 0.5))
-
-         elif method == ForecastMethod.POLYNOMIAL:
-             degree = params.get('degree', 2)
-             return forecast_polynomial(x, y, n_rows, degree, params.get('warn_threshold', 0.5))
-
-         elif method == ForecastMethod.EXPONENTIAL:
-             return forecast_exponential(x, y, n_rows, params.get('warn_threshold', 0.5))
-
-         elif method == ForecastMethod.MOVING_AVERAGE:
-             window = params.get('window', 3)
-             return forecast_moving_average(y, n_rows, window)
-
-         elif method == ForecastMethod.SEASONAL:
-             period = params.get('period', None)
-             auto_detect = params.get('auto_detect', True)
-             return forecast_seasonal(y, n_rows, period, auto_detect)
-
-         elif method == ForecastMethod.HOLT_WINTERS:
-             period = params.get('period', None)
-             alpha = params.get('alpha', 0.2)
-             beta = params.get('beta', 0.1)
-             gamma = params.get('gamma', 0.1)
-             auto_detect = params.get('auto_detect', True)
-             return forecast_holt_winters(y, n_rows, period, alpha, beta, gamma, auto_detect)
-
-         elif method == ForecastMethod.ENSEMBLE:
-             methods = params.get('methods', None)
-             weights = params.get('weights', None)
-             auto_weight = params.get('auto_weight', True)
-             return forecast_ensemble(x, y, n_rows, methods, weights, auto_weight)
-
-         else:
-             raise ValidationError(
-                 f"Unknown forecasting method: '{method}'. "
-                 f"Supported: linear, polynomial, exponential, moving_average, "
-                 f"seasonal, holt_winters, ensemble, auto"
-             )
-
-     except Exception as e:
-         if isinstance(e, (ValidationError, AugmentError)):
-             raise
-         raise AugmentError(f"Forecasting failed: {e}")
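
Note: forecast_values, shown above, was the module's public entry point before its removal in 0.1.1a1. The following is a minimal usage sketch, not code from the package; it assumes the pre-0.1.1a1 module path additory.synthetic.forecast, an installed polars, and illustrative column names and data.

import polars as pl
from additory.synthetic.forecast import forecast_values  # module removed in 0.1.1a1

# Illustrative series with an explicit, numeric time column.
df = pl.DataFrame({
    "month": list(range(1, 25)),
    "sales": [100.0 + 5.0 * i for i in range(24)],
})

# Forecast six future values with a linear trend; method="auto" would instead
# pick a method via the trend/seasonality heuristics in forecast_values.
future = forecast_values(df, column="sales", n_rows=6, method="linear", time_col="month")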