hypertrees-forecasting 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hypertrees/__init__.py +1 -0
- hypertrees/models/HyperTreeAR.py +631 -0
- hypertrees/models/HyperTreeETS.py +1011 -0
- hypertrees/models/HyperTreeNetAR.py +916 -0
- hypertrees/models/HyperTreeSTL.py +806 -0
- hypertrees/models/__init__.py +6 -0
- hypertrees/models/mlp.py +79 -0
- hypertrees/utils.py +456 -0
- hypertrees_forecasting-0.1.0.dist-info/METADATA +427 -0
- hypertrees_forecasting-0.1.0.dist-info/RECORD +14 -0
- hypertrees_forecasting-0.1.0.dist-info/WHEEL +5 -0
- hypertrees_forecasting-0.1.0.dist-info/licenses/LICENSE +228 -0
- hypertrees_forecasting-0.1.0.dist-info/top_level.txt +1 -0
- hypertrees_forecasting-0.1.0.dist-info/zip-safe +1 -0
hypertrees/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Forecasting with Hyper-Trees"""
|
|
@@ -0,0 +1,631 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
from torch.autograd import grad as autograd
|
|
8
|
+
import lightgbm as lgb
|
|
9
|
+
from typing import Tuple, Callable, Optional
|
|
10
|
+
import time
|
|
11
|
+
from ..utils import CustomLogger
|
|
12
|
+
lgb.register_logger(CustomLogger())
|
|
13
|
+
|
|
14
|
+
from ..utils import TimeSeriesPreprocessor, prepare_datasets, TrainingResult, validate_series_order, GaussNewtonHessian
|
|
15
|
+
|
|
16
|
+
class HyperTreeAR:
|
|
17
|
+
"""
|
|
18
|
+
Class that implements a Hyper-Tree-AR(p) model for time series forecasting.
|
|
19
|
+
|
|
20
|
+
The Hyper-Tree-AR(p) model extends traditional autoregressive models by allowing
|
|
21
|
+
the AR coefficients to be time-varying and estimated by gradient boosted trees.
|
|
22
|
+
This creates a non-linear, adaptive autoregressive model that can capture complex
|
|
23
|
+
temporal dependencies.
|
|
24
|
+
|
|
25
|
+
Key features:
|
|
26
|
+
- Combines tree-based models (LightGBM) with autoregressive time series modeling
|
|
27
|
+
- Allows AR coefficients to vary based on features
|
|
28
|
+
- Provides AR coefficients that can vary over time
|
|
29
|
+
|
|
30
|
+
Use this model when:
|
|
31
|
+
- You have relevant features that might influence the autoregressive structure
|
|
32
|
+
- You want more flexibility than traditional AR models
|
|
33
|
+
|
|
34
|
+
Example usage:
|
|
35
|
+
```python
|
|
36
|
+
# Imports
|
|
37
|
+
from hypertrees.models import HyperTreeAR
|
|
38
|
+
import pandas as pd
|
|
39
|
+
import matplotlib.pyplot as plt
|
|
40
|
+
|
|
41
|
+
# Initialize model
|
|
42
|
+
lag_p = 12
|
|
43
|
+
frequency = 'M'
|
|
44
|
+
fcst_h = 12
|
|
45
|
+
model = HyperTreeAR(p=lag_p, freq=frequency, fcst_h=fcst_h)
|
|
46
|
+
|
|
47
|
+
# Data
|
|
48
|
+
# The data needs to have the following columns: 'date', 'series_id', 'value'. All other columns are automatically treated as features.
|
|
49
|
+
# You don't have to add lag-values yourself, this happens automatically during training.
|
|
50
|
+
df = pd.read_csv('https://datasets-nixtla.s3.amazonaws.com/air-passengers.csv', parse_dates=['ds'])
|
|
51
|
+
df.rename(columns={'unique_id': 'series_id', 'ds': 'date', 'y': 'value'}, inplace=True)
|
|
52
|
+
df['month'] = df['date'].dt.month
|
|
53
|
+
df["quarter"] = df['date'].dt.quarter
|
|
54
|
+
test = df.tail(fcst_h)
|
|
55
|
+
train = df.drop(test.index)
|
|
56
|
+
|
|
57
|
+
# Train model
|
|
58
|
+
model.train(
|
|
59
|
+
lgb_params={'learning_rate': 0.1},
|
|
60
|
+
num_iterations=100,
|
|
61
|
+
train_data=train
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
# Generate forecasts
|
|
65
|
+
forecasts = model.forecast(test_data=test)
|
|
66
|
+
|
|
67
|
+
# Plot results
|
|
68
|
+
datasets = [
|
|
69
|
+
(df, 'date', 'value', 'Actual', '#2E86AB', '-'),
|
|
70
|
+
(forecasts, 'date', 'fcst', 'Forecast', '#F18F01', '--')
|
|
71
|
+
]
|
|
72
|
+
|
|
73
|
+
for data, x_col, y_col, label, color, style in datasets:
|
|
74
|
+
plt.plot(data[x_col], data[y_col], label=label, color=color,
|
|
75
|
+
linestyle=style, linewidth=2, alpha=0.8)
|
|
76
|
+
|
|
77
|
+
plt.title('AirPassengers - Forecast', fontsize=14)
|
|
78
|
+
plt.legend(frameon=True, fancybox=True)
|
|
79
|
+
plt.grid(True, alpha=0.3)
|
|
80
|
+
plt.tight_layout()
|
|
81
|
+
```
|
|
82
|
+
"""
|
|
83
|
+
|
|
84
|
+
def __init__(
    self,
    p: int = 2,
    freq: str = "M",
    fcst_h: int = 1,
    loss_fn: Callable = nn.MSELoss(),
    hessian_method: str = "exact",
    n_hessian_probes: int = 5,
):
    """
    Initialize the Hyper-Tree-AR(p) model.

    Arguments
    ----------
    p : int
        Maximum number of AR(p) lags. Must be a positive integer.
    freq : str
        Frequency of the time series (e.g., 'D' for daily, 'M' for monthly,
        'Q' for quarterly, 'Y' for yearly).
    fcst_h : int
        Forecast horizon (number of periods to forecast ahead). Must be a
        positive integer.
    loss_fn : Callable
        Loss function for optimization. Must be a PyTorch loss function
        (an ``nn.Module`` instance). Default is MSE loss.
    hessian_method : str
        Method for computing the Hessian diagonal. Options:
        - "exact": Exact diagonal Hessian via per-parameter second-order autograd.
        - "gn": Gauss-Newton approximation estimated via Hutchinson probing.
    n_hessian_probes : int
        Number of Hutchinson probes for Gauss-Newton Hessian diagonal estimation.
        Only used when hessian_method="gn". Default is 5.

    Raises
    ------
    ValueError
        If ``p``, ``fcst_h`` or ``n_hessian_probes`` is not a positive integer,
        or ``hessian_method`` is not one of the supported options.
    TypeError
        If ``freq`` is not a string or ``loss_fn`` is not an ``nn.Module``.
    """
    # Local import: only needed for the integer checks below.
    import numbers

    # Validate inputs. numbers.Integral also accepts NumPy integer scalars, so
    # callers passing e.g. np.int64 keep working, while floats such as 2.5 are
    # rejected here instead of failing later inside train().
    if not isinstance(p, numbers.Integral) or p <= 0:
        raise ValueError("Parameter 'p' must be a positive integer.")
    if not isinstance(fcst_h, numbers.Integral) or fcst_h <= 0:
        raise ValueError("Forecast horizon 'fcst_h' must be a positive integer.")
    if not isinstance(freq, str):
        raise TypeError("freq must be a string.")
    if not isinstance(loss_fn, nn.Module):
        raise TypeError("loss_fn must be a PyTorch loss function.")
    if hessian_method not in ("exact", "gn"):
        raise ValueError("hessian_method must be either 'exact' or 'gn'.")
    if not isinstance(n_hessian_probes, int) or n_hessian_probes <= 0:
        raise ValueError("n_hessian_probes must be a positive integer.")

    if hessian_method == "gn" and not isinstance(loss_fn, nn.MSELoss):
        # stacklevel=2 attributes the warning to the code constructing the model.
        warnings.warn(
            f"Loss {type(loss_fn).__name__} is not nn.MSELoss. The Gauss-Newton "
            "Hessian requires a twice-differentiable loss; non-smooth losses "
            "(e.g., L1Loss, quantile loss, HuberLoss/SmoothL1Loss outside the quadratic "
            "region) have zero or undefined second derivatives at kinks, "
            "causing degenerate Hessians.",
            stacklevel=2,
        )

    self.p = p
    self.freq = freq
    self.fcst_h = fcst_h
    self.loss_fn = loss_fn
    self.loss_name = self.loss_fn.__class__.__name__
    self.dtype = torch.float32
    self.model = None
    self.features = None  # Stores feature names after training
    self.is_trained = False  # Flag to track if model has been trained
    self.dataset_references = {}  # Store references to LightGBM datasets
    self.hessian_method = hessian_method
    self.n_hessian_probes = n_hessian_probes
    self._iter_count = 0  # Incremented on every objective_fn call
    self._fit = None  # Fitted values stashed for the Gauss-Newton estimator
    self._target = None  # Targets stashed for the Gauss-Newton estimator

    # Bind Hessian computation strategy
    if hessian_method == "exact":
        self.calculate_gradients_and_hessians = self._calculate_gradients_and_hessians_exact
    else:
        self._gn_hessian = GaussNewtonHessian(loss_fn, n_hessian_probes, self.dtype)
        self.calculate_gradients_and_hessians = self._calculate_gradients_and_hessians_gn
|
|
164
|
+
|
|
165
|
+
def objective_fn(self, predt: np.ndarray, data: lgb.Dataset) -> Tuple[np.ndarray, np.ndarray]:
    """
    Custom LightGBM objective: gradients and hessians of the PyTorch loss.

    Each call bumps the internal iteration counter, turns the dataset labels
    into a column tensor, evaluates the loss on the training lags with
    gradient tracking enabled, and delegates to the Hessian strategy that was
    bound in ``__init__``.

    Parameters
    ----------
    predt : np.ndarray
        Raw outputs from LightGBM, representing the AR coefficients.
    data : lgb.Dataset
        LightGBM dataset containing the target values.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Gradients and hessians for LightGBM optimization.
    """
    self._iter_count += 1

    labels = data.get_label().reshape(-1, 1)
    y_true = torch.tensor(labels, dtype=self.dtype)
    ar_params, loss_value = self.get_params_loss(
        predt, y_true, self.lags_train, requires_grad=True
    )

    return self.calculate_gradients_and_hessians(loss_value, ar_params)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def eval_fn(self, predt: np.ndarray, eval_data: lgb.Dataset) -> Tuple[str, float, bool]:
    """
    Custom evaluation function for evaluating forecast accuracy on an evaluation dataset.

    The dataset is identified through ``self.dataset_references`` (keyed by
    ``id(eval_data)``) so that the matching lag tensor is used. Datasets that
    are not registered fall back to the training lags and trigger a warning.

    Parameters
    ----------
    predt : np.ndarray
        Raw outputs from LightGBM.
    eval_data : lgb.Dataset
        LightGBM dataset containing the evaluation data.

    Returns
    -------
    Tuple[str, float, bool]
        Name of the metric, value of the metric, and whether to maximize it
        (always False: lower loss is better).
    """
    # Use appropriate lags based on dataset name
    dataset_name = self.dataset_references.get(id(eval_data), "unknown")
    if dataset_name == "validation":
        lags = self.lags_eval
    else:
        lags = self.lags_train
        if dataset_name != "train":
            # Default to training lags if unknown
            warnings.warn("Unknown dataset in eval_fn. Using training lags.")

    # Calculate loss
    is_higher_better = False  # Lower loss is better, so we don't maximize
    target = torch.tensor(eval_data.get_label().reshape(-1, 1), dtype=self.dtype)
    _, loss = self.get_params_loss(predt, target, lags)

    return self.loss_name, loss.item(), is_higher_better
|
|
229
|
+
|
|
230
|
+
def get_params_loss(
    self,
    predt: np.ndarray,
    target: torch.Tensor,
    lags: torch.Tensor = None,
    requires_grad: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Transform LightGBM outputs into AR parameters and calculate loss.

    This function:
    1. Reshapes the raw outputs into AR parameters
    2. Multiplies these parameters with the lag values
    3. Computes the forecast by summing the weighted lags
    4. Calculates the loss between forecasts and actual values

    Parameters
    ----------
    predt : np.ndarray
        Raw outputs from LightGBM. Reshaped column-major to (-1, p).
    target : torch.Tensor
        Target values (actual time series values).
    lags : torch.Tensor
        Lagged values of the time series. Required despite the ``None``
        default; must broadcast against the (-1, p) parameter matrix.
    requires_grad : bool
        Whether to compute gradients (True during training).

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        Parameters tensor and loss value.

    Raises
    ------
    TypeError
        If ``lags`` is None.
    """
    if lags is None:
        # Same exception type the multiplication below would raise, but with
        # a message that names the actual problem.
        raise TypeError("lags must be provided to compute the AR forecast.")

    # Reshape outputs into parameter matrix (samples × n_params)
    # The 'F' order means Fortran-style ordering (column-major)
    params = nn.Parameter(
        torch.tensor(
            predt.reshape(-1, self.p, order="F"),
            dtype=self.dtype
        ),
        requires_grad=requires_grad
    )

    # Forward pass: Compute forecasts by multiplying parameters with lags and summing
    fcst = torch.sum(params * lags, dim=1, dtype=self.dtype).unsqueeze(1)

    # Calculate loss between forecasts and actual values
    loss = self.loss_fn(fcst, target)

    # The Gauss-Newton estimator reads the stashed fit/target. Only a call made
    # with requires_grad=True carries a graph it can differentiate, so calls
    # without it (evaluation) neither overwrite nor retain tensors here.
    if self.hessian_method == "gn" and requires_grad:
        self._fit = fcst
        self._target = target

    return params, loss
|
|
284
|
+
|
|
285
|
+
def _calculate_gradients_and_hessians_exact(self, loss: torch.Tensor, params: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
    """
    Exact diagonal Hessian via per-parameter second-order autograd.

    Parameters
    ----------
    loss : torch.Tensor
        Scalar loss whose graph includes ``params``.
    params : torch.Tensor
        Parameter matrix with ``self.p`` columns and ``requires_grad=True``.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Gradient and Hessian-diagonal arrays, each flattened column-major.
    """
    # torch.autograd.grad with create_graph=True yields a differentiable
    # gradient without writing to params.grad, which avoids the reference
    # cycle that loss.backward(create_graph=True) creates.
    grad = autograd(loss, params, create_graph=True)[0]

    # For each parameter column i, differentiate the summed gradient column
    # again and keep only column i of the result. This equals the Hessian
    # diagonal when the loss has no cross-sample interaction terms.
    hess_cols = [
        autograd(grad[:, i].sum(), params, retain_graph=True)[0][:, i:(i + 1)]
        for i in range(self.p)
    ]

    grad_np = grad.cpu().detach().numpy().ravel(order="F")
    hess_np = torch.cat(hess_cols, dim=1).cpu().detach().numpy().ravel(order="F")

    return grad_np, hess_np
|
|
299
|
+
|
|
300
|
+
def _calculate_gradients_and_hessians_gn(self, loss: torch.Tensor, params: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
    """
    Gauss-Newton Hessian diagonal estimated via Hutchinson probing.

    The first-order gradient comes from autograd; the Hessian diagonal is
    delegated to ``self._gn_hessian.estimate`` using the fit/target stashed on
    the instance and a generator seeded with the current iteration counter.
    The stashed tensors are released afterwards.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Gradient and Hessian-diagonal arrays, each flattened column-major.
    """
    first_order = autograd(loss, params, retain_graph=True)[0]

    # Seed from the iteration counter so each boosting round uses its own,
    # reproducible stream of probes.
    generator = torch.Generator().manual_seed(self._iter_count)
    diag = self._gn_hessian.estimate(self._fit, self._target, params, generator)

    # Drop the stashed tensors once they have been consumed.
    self._fit = None
    self._target = None

    flat_grad = first_order.cpu().detach().numpy().ravel(order="F")
    flat_hess = diag.cpu().detach().numpy().ravel(order="F")
    return flat_grad, flat_hess
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def train(
    self,
    lgb_params: dict = None,
    num_iterations: int = 100,
    train_data: pd.DataFrame = None,
    validation: bool = False,
    early_stopping_round: Optional[int] = None,
    seed: int = 123,
    verbose: int = -1,
    deterministic: bool = True,
) -> TrainingResult:
    """
    Train the Hyper-Tree-AR model on time series data.

    This method:
    1. Preprocesses the time series data to create lag features
    2. Sets up LightGBM datasets
    3. Trains the model using gradient boosting

    The training data must contain columns:
    - 'series_id': Identifier for each time series
    - 'date': Timestamp for each observation
    - 'value': Target value to forecast
    - Additional feature columns used for forecasting

    Parameters
    ----------
    lgb_params : dict
        LightGBM parameters like 'learning_rate', 'num_leaves', etc.
    num_iterations : int
        Number of boosting rounds for training
    train_data : pd.DataFrame
        Training data containing series_id, date, value and feature columns
    validation : bool
        If True, a validation set will be created for evaluation. It splits the last fcst_h values of each
        series for validation.
    early_stopping_round : int, optional
        If provided, training will stop if the validation loss does not improve for this many rounds.
        Must be given if and only if ``validation`` is True.
    seed : int
        Random seed for reproducibility
    verbose : int
        Verbosity level for LightGBM training
    deterministic : bool
        If True, sets LightGBM's ``deterministic`` and ``force_row_wise`` parameters to ensure
        reproducible results. May slow down training. See
        https://lightgbm.readthedocs.io/en/latest/Parameters.html#deterministic

    Returns
    -------
    TrainingResult
        Object containing evaluation results and training information.
        ``best_iteration`` is the 0-based best round when LightGBM recorded
        one (early stopping); otherwise it is ``num_iterations``.

    Raises
    ------
    ValueError, TypeError
        If an argument fails validation.
    RuntimeError
        If preprocessing or LightGBM training fails; the original exception
        is attached as ``__cause__``.
    """
    # Validate inputs
    if train_data is None:
        raise ValueError("train_data must be provided.")
    if lgb_params is None:
        raise ValueError("lgb_params must be provided.")
    if not isinstance(train_data, pd.DataFrame):
        raise TypeError("train_data must be a pandas DataFrame.")
    if not isinstance(lgb_params, dict):
        raise TypeError("lgb_params must be a dictionary.")
    if not isinstance(num_iterations, int) or num_iterations <= 0:
        raise ValueError("num_iterations must be a positive integer.")
    if not isinstance(seed, int):
        raise TypeError("seed must be an integer.")
    if not isinstance(verbose, int):
        raise TypeError("verbose must be an integer.")
    if early_stopping_round is not None and (not isinstance(early_stopping_round, int) or early_stopping_round <= 0):
        raise ValueError("early_stopping_round must be a positive integer.")
    if not isinstance(validation, bool):
        raise TypeError("validation must be a boolean.")
    if not isinstance(deterministic, bool):
        raise TypeError("deterministic must be a boolean.")
    if early_stopping_round is not None and not validation:
        raise ValueError("early_stopping_round can only be used when validation is True.")
    if validation and early_stopping_round is None:
        raise ValueError("early_stopping_round must be provided when validation is True.")

    if deterministic:
        # Build a new dict so the caller's lgb_params is not mutated.
        lgb_params = {**lgb_params, "deterministic": True, "force_row_wise": True}

    # Check required columns
    required_columns = ['series_id', 'date', 'value']
    for col in required_columns:
        if col not in train_data.columns:
            raise ValueError(f"Required column '{col}' not found in training data.")

    # Validate row ordering: each series must be a contiguous block with
    # monotonic dates so the training reshape and fcst_lags extraction align.
    validate_series_order(train_data, name="train_data")

    # General model parameters
    self.lgb_params = {
        "num_class": self.p,
        "objective": self.objective_fn,
        "metric": "None",
        "random_seed": seed,
        "verbose": verbose
    }

    # Update with user-provided LightGBM parameters
    self.lgb_params.update(lgb_params)

    # Reset state for re-training
    self._iter_count = 0
    self._fit = None
    self._target = None
    self.model = None
    self.dataset_references = {}
    self.is_trained = False
    self.features = None

    try:
        # Initialize TimeSeriesPreprocessor for creating lagged dataframe
        preprocessor = TimeSeriesPreprocessor(
            freq=self.freq,
            lags=[i for i in range(1, self.p + 1)],
        )

        # Process full dataset to create lagged dataframe
        full_ts = preprocessor.create_lags(train_data)
        full_dict = preprocessor.extract(full_ts)

        # Store feature names for later use
        self.features = full_dict["features"].columns.tolist()

        # Prepare datasets
        (valid_sets,
         valid_names,
         callbacks,
         evals_result,
         lags_train,
         lags_eval,
         self.dataset_references) = (
            prepare_datasets(
                full_ts=full_ts,
                preprocessor=preprocessor,
                fcst_h=self.fcst_h,
                dtype=self.dtype,
                validation=validation,
                early_stopping_round=early_stopping_round
            )
        )

        # Store lagged values for training and evaluation
        self.lags_train = lags_train
        self.lags_eval = lags_eval

        # Store lagged train values to be used in the forecast method:
        # the last p values of each series, most recent first.
        self.fcst_lags = (
            train_data.groupby(["series_id"], sort=False)
            .apply(lambda x: x["value"][-self.p:][::-1].values)
            .to_dict()
        )

        # Train LightGBM model
        start_time = time.time()
        self.model = lgb.train(
            self.lgb_params,
            valid_sets[0],
            num_boost_round=num_iterations,
            feval=self.eval_fn if validation else None,
            valid_sets=valid_sets,
            valid_names=valid_names,
            callbacks=callbacks
        )
        training_time = time.time() - start_time

        # Set trained flag to True
        self.is_trained = True

        # The booster always carries a best_iteration attribute; it is only
        # positive (1-based) when early stopping recorded a best round.
        # Otherwise fall back to the number of requested rounds instead of
        # reporting a negative index.
        recorded_best = getattr(self.model, "best_iteration", 0) or 0
        best_iteration = recorded_best - 1 if recorded_best > 0 else num_iterations

        # Return results
        result = TrainingResult(
            train_metrics=evals_result["train"] if validation else {"loss": []},
            validation_metrics=evals_result["validation"] if validation else None,
            best_iteration=best_iteration,
            training_time=training_time
        )

        return result

    except Exception as e:
        self.is_trained = False
        raise RuntimeError(f"Training failed: {str(e)}") from e
|
|
498
|
+
|
|
499
|
+
def forecast(
    self,
    test_data: pd.DataFrame,
    type: str = "forecast"
) -> pd.DataFrame:
    """
    Generate forecasts using the trained model.

    This method:
    1. Uses the trained model to forecast AR coefficients for each test point
    2. Recursively generates forecasts using the forecasted AR coefficients

    The forecasting process implements an autoregressive model where:
    y_t = φ₁(x)y_{t-1} + φ₂(x)y_{t-2} + ... + φₚ(x)y_{t-p}

    However, unlike traditional AR models, the φ(x) coefficients are not constant
    but determined by the LightGBM model based on features x.

    Parameters
    ----------
    test_data : pd.DataFrame
        Test data for which to generate forecasts. Must contain 'series_id',
        'date' and the same feature columns used during training, and must
        cover exactly the series seen in training.
    type : str
        Type of forecast to generate. Options:
        - "forecast": Generate forecasted values (requires exactly fcst_h
          rows per series)
        - "parameters": Return the AR(p) coefficients used for forecasting

    Returns
    -------
    pd.DataFrame
        Forecasted data with columns:
        - series_id: Identifier for each time series
        - date: Forecast date/time
        - fcst: Forecasted value (if type="forecast")
        - model: Model name identifier
        - AR(i): AR coefficient values (if type="parameters")

    Raises
    ------
    RuntimeError
        If the model is not trained, or prediction fails (the original
        exception is attached as ``__cause__``).
    ValueError
        If ``type`` is invalid or ``test_data`` fails validation.
    """
    # Check if model is trained
    if not self.is_trained or self.model is None:
        raise RuntimeError("Model has not been trained. Call train() before forecasting.")

    # Validate type parameter up front, before it drives any branching below
    if type not in ["forecast", "parameters"]:
        raise ValueError("Parameter 'type' must be either 'forecast' or 'parameters'")

    # Validate input data
    required_cols = ['series_id', 'date']
    for col in required_cols:
        if col not in test_data.columns:
            raise ValueError(f"Required column '{col}' not found in test_data")

    # Validate row ordering: each series must be a contiguous block with
    # monotonic dates so the forecast reshape aligns forecasts with lags.
    validate_series_order(test_data, name="test_data")

    # Validate series IDs match training data
    test_series_ids = test_data["series_id"].unique()
    train_series_ids = set(self.fcst_lags.keys())
    missing = set(test_series_ids) - train_series_ids
    extra = train_series_ids - set(test_series_ids)
    if missing or extra:
        parts = []
        if missing:
            parts.append(f"Missing series in training: {missing}")
        if extra:
            parts.append(f"Extra series not in test_data: {extra}")
        raise ValueError(". ".join(parts))

    # Validate rows per series matches fcst_h (forecast only; parameters
    # can be requested for arbitrary-length input).
    if type == "forecast":
        rows_per_series = test_data.groupby("series_id", sort=False).size()
        bad = rows_per_series[rows_per_series != self.fcst_h]
        if not bad.empty:
            raise ValueError(
                f"Each series must have exactly fcst_h={self.fcst_h} rows in test_data. "
                f"Series with wrong counts: {bad.to_dict()}"
            )

    # Check that all features used during training exist in test_data
    missing_features = [f for f in self.features if f not in test_data.columns]
    if missing_features:
        raise ValueError(f"Missing features in test_data: {missing_features}")

    try:
        # LightGBM may return 1D (column-major) or 2D depending on version/objective.
        # Normalize to (n_test, p) once so both output types read the
        # coefficients with the same layout.
        params_fcst = np.asarray(self.model.predict(test_data[self.features]))
        if params_fcst.ndim == 1:
            params_fcst = params_fcst.reshape(-1, self.p, order="F")

        if type == "forecast":
            # Rows are series-contiguous (validated above), so a row-major
            # reshape yields (n_series, fcst_h, n_params).
            n_series_test = len(test_series_ids)
            params_fcst = params_fcst.reshape(n_series_test, self.fcst_h, self.p)

            # Reconstruct lags array in the same order as test data
            lags = np.array([self.fcst_lags[series_id] for series_id in test_series_ids])

            # Generate multi-step forecasts
            forecasts = []
            for h in range(self.fcst_h):
                # Compute next value using AR equation: y_t = φ₁y_{t-1} + φ₂y_{t-2} + ... + φₚy_{t-p}
                next_val = np.sum(params_fcst[:, h, :] * lags, axis=1).reshape(-1, 1)
                forecasts.append(next_val)

                # Update lags for next step by adding new forecast and removing oldest lag
                lags = np.concatenate([next_val, lags[:, :-1]], axis=1)

            # hstack gives (n_series, fcst_h); flattening row-major restores
            # the series-contiguous row order of test_data.
            out_df = pd.DataFrame({
                "series_id": test_data["series_id"].to_numpy().flatten(),
                "date": test_data["date"].to_numpy().flatten(),
                "fcst": np.hstack(forecasts).flatten(),
                "model": f"Hyper-Tree-AR({self.p})",
            })

        else:
            out_df = pd.DataFrame({
                "series_id": test_data["series_id"].to_numpy().flatten(),
                "date": test_data["date"].to_numpy().flatten(),
                "model": f"Hyper-Tree-AR({self.p})",
            })
            # Add AR parameters to the dataframe
            for i in range(self.p):
                out_df[f"AR({i + 1})"] = params_fcst[:, i].flatten()

        return out_df

    except Exception as e:
        raise RuntimeError(f"Forecasting not successful: {str(e)}") from e
|