tabpfn-time-series 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tabpfn_time_series/__init__.py +2 -0
- tabpfn_time_series/predictor.py +5 -7
- tabpfn_time_series/tabpfn_worker.py +119 -17
- {tabpfn_time_series-0.1.2.dist-info → tabpfn_time_series-0.1.3.dist-info}/METADATA +21 -7
- tabpfn_time_series-0.1.3.dist-info/RECORD +11 -0
- tabpfn_time_series-0.1.2.dist-info/RECORD +0 -11
- {tabpfn_time_series-0.1.2.dist-info → tabpfn_time_series-0.1.3.dist-info}/WHEEL +0 -0
- {tabpfn_time_series-0.1.2.dist-info → tabpfn_time_series-0.1.3.dist-info}/licenses/LICENSE.txt +0 -0
tabpfn_time_series/__init__.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
from .feature import DefaultFeatures, FeatureTransformer
|
2
2
|
from .predictor import TabPFNTimeSeriesPredictor, TabPFNMode
|
3
|
+
from .defaults import TABPFN_TS_DEFAULT_QUANTILE_CONFIG
|
3
4
|
|
4
5
|
__version__ = "0.1.0"
|
5
6
|
|
@@ -8,4 +9,5 @@ __all__ = [
|
|
8
9
|
"FeatureTransformer",
|
9
10
|
"TabPFNTimeSeriesPredictor",
|
10
11
|
"TabPFNMode",
|
12
|
+
"TABPFN_TS_DEFAULT_QUANTILE_CONFIG",
|
11
13
|
]
|
tabpfn_time_series/predictor.py
CHANGED
@@ -3,11 +3,8 @@ from enum import Enum
|
|
3
3
|
|
4
4
|
from autogluon.timeseries import TimeSeriesDataFrame
|
5
5
|
|
6
|
-
from tabpfn_time_series.tabpfn_worker import TabPFNClient, LocalTabPFN
|
7
|
-
from tabpfn_time_series.defaults import (
|
8
|
-
TABPFN_TS_DEFAULT_QUANTILE_CONFIG,
|
9
|
-
TABPFN_TS_DEFAULT_CONFIG,
|
10
|
-
)
|
6
|
+
from tabpfn_time_series.tabpfn_worker import TabPFNClient, LocalTabPFN, MockTabPFN
|
7
|
+
from tabpfn_time_series.defaults import TABPFN_TS_DEFAULT_CONFIG
|
11
8
|
|
12
9
|
logger = logging.getLogger(__name__)
|
13
10
|
|
@@ -15,6 +12,7 @@ logger = logging.getLogger(__name__)
|
|
15
12
|
class TabPFNMode(Enum):
|
16
13
|
LOCAL = "tabpfn-local"
|
17
14
|
CLIENT = "tabpfn-client"
|
15
|
+
MOCK = "tabpfn-mock"
|
18
16
|
|
19
17
|
|
20
18
|
class TabPFNTimeSeriesPredictor:
|
@@ -30,6 +28,7 @@ class TabPFNTimeSeriesPredictor:
|
|
30
28
|
worker_mapping = {
|
31
29
|
TabPFNMode.CLIENT: lambda: TabPFNClient(config),
|
32
30
|
TabPFNMode.LOCAL: lambda: LocalTabPFN(config),
|
31
|
+
TabPFNMode.MOCK: lambda: MockTabPFN(config),
|
33
32
|
}
|
34
33
|
self.tabpfn_worker = worker_mapping[tabpfn_mode]()
|
35
34
|
|
@@ -37,7 +36,6 @@ class TabPFNTimeSeriesPredictor:
|
|
37
36
|
self,
|
38
37
|
train_tsdf: TimeSeriesDataFrame, # with features and target
|
39
38
|
test_tsdf: TimeSeriesDataFrame, # with features only
|
40
|
-
quantile_config: list[float] = TABPFN_TS_DEFAULT_QUANTILE_CONFIG,
|
41
39
|
) -> TimeSeriesDataFrame:
|
42
40
|
"""
|
43
41
|
Predict on each time series individually (local forecasting).
|
@@ -47,4 +45,4 @@ class TabPFNTimeSeriesPredictor:
|
|
47
45
|
f"Predicting {len(train_tsdf.item_ids)} time series with config{self.tabpfn_worker.config}"
|
48
46
|
)
|
49
47
|
|
50
|
-
return self.tabpfn_worker.predict(train_tsdf, test_tsdf, quantile_config)
|
48
|
+
return self.tabpfn_worker.predict(train_tsdf, test_tsdf)
|
tabpfn_time_series/tabpfn_worker.py
CHANGED
@@ -2,8 +2,10 @@ import logging
|
|
2
2
|
from abc import ABC, abstractmethod
|
3
3
|
from joblib import Parallel, delayed
|
4
4
|
|
5
|
+
from tqdm import tqdm
|
5
6
|
import pandas as pd
|
6
7
|
import numpy as np
|
8
|
+
import torch
|
7
9
|
from scipy.stats import norm
|
8
10
|
from autogluon.timeseries import TimeSeriesDataFrame
|
9
11
|
|
@@ -26,14 +28,7 @@ class TabPFNWorker(ABC):
|
|
26
28
|
self,
|
27
29
|
train_tsdf: TimeSeriesDataFrame,
|
28
30
|
test_tsdf: TimeSeriesDataFrame,
|
29
|
-
quantile_config: list[float],
|
30
31
|
):
|
31
|
-
if not set(quantile_config).issubset(set(TABPFN_TS_DEFAULT_QUANTILE_CONFIG)):
|
32
|
-
raise NotImplementedError(
|
33
|
-
f"We currently only supports {TABPFN_TS_DEFAULT_QUANTILE_CONFIG} for quantile prediction,"
|
34
|
-
f" but got {quantile_config}."
|
35
|
-
)
|
36
|
-
|
37
32
|
predictions = Parallel(
|
38
33
|
n_jobs=self.num_workers,
|
39
34
|
backend="loky",
|
@@ -42,9 +37,8 @@ class TabPFNWorker(ABC):
|
|
42
37
|
item_id,
|
43
38
|
train_tsdf.loc[item_id],
|
44
39
|
test_tsdf.loc[item_id],
|
45
|
-
quantile_config,
|
46
40
|
)
|
47
|
-
for item_id in train_tsdf.item_ids
|
41
|
+
for item_id in tqdm(train_tsdf.item_ids, desc="Predicting time series")
|
48
42
|
)
|
49
43
|
|
50
44
|
predictions = pd.concat(predictions)
|
@@ -59,8 +53,9 @@ class TabPFNWorker(ABC):
|
|
59
53
|
item_id: str,
|
60
54
|
single_train_tsdf: TimeSeriesDataFrame,
|
61
55
|
single_test_tsdf: TimeSeriesDataFrame,
|
62
|
-
quantile_config: list[float],
|
63
56
|
) -> pd.DataFrame:
|
57
|
+
# logger.debug(f"Predicting on item_id: {item_id}")
|
58
|
+
|
64
59
|
test_index = single_test_tsdf.index
|
65
60
|
train_X, train_y = split_time_series_to_X_y(single_train_tsdf.copy())
|
66
61
|
test_X, _ = split_time_series_to_X_y(single_test_tsdf.copy())
|
@@ -70,7 +65,7 @@ class TabPFNWorker(ABC):
|
|
70
65
|
if train_y_has_constant_value:
|
71
66
|
logger.info("Found time-series with constant target")
|
72
67
|
result = self._predict_on_constant_train_target(
|
73
|
-
single_train_tsdf, single_test_tsdf, quantile_config
|
68
|
+
single_train_tsdf, single_test_tsdf
|
74
69
|
)
|
75
70
|
else:
|
76
71
|
tabpfn = self._get_tabpfn_engine()
|
@@ -81,7 +76,9 @@ class TabPFNWorker(ABC):
|
|
81
76
|
result.update(
|
82
77
|
{
|
83
78
|
q: q_pred
|
84
|
-
for q, q_pred in zip(quantile_config, full_pred["quantiles"])
|
79
|
+
for q, q_pred in zip(
|
80
|
+
TABPFN_TS_DEFAULT_QUANTILE_CONFIG, full_pred["quantiles"]
|
81
|
+
)
|
85
82
|
}
|
86
83
|
)
|
87
84
|
|
@@ -98,7 +95,6 @@ class TabPFNWorker(ABC):
|
|
98
95
|
self,
|
99
96
|
single_train_tsdf: TimeSeriesDataFrame,
|
100
97
|
single_test_tsdf: TimeSeriesDataFrame,
|
101
|
-
quantile_config: list[float],
|
102
98
|
) -> pd.DataFrame:
|
103
99
|
# If train_y is constant, we return the constant value from the training set
|
104
100
|
mean_constant = single_train_tsdf.target.iloc[0]
|
@@ -106,12 +102,14 @@ class TabPFNWorker(ABC):
|
|
106
102
|
|
107
103
|
# For quantile prediction, we assume that the uncertainty follows a standard normal distribution
|
108
104
|
quantile_pred_with_uncertainty = norm.ppf(
|
109
|
-
|
105
|
+
TABPFN_TS_DEFAULT_QUANTILE_CONFIG, loc=mean_constant, scale=1
|
110
106
|
)
|
111
107
|
result.update(
|
112
108
|
{
|
113
109
|
q: np.full(len(single_test_tsdf), v)
|
114
|
-
for q, v in zip(quantile_config, quantile_pred_with_uncertainty)
|
110
|
+
for q, v in zip(
|
111
|
+
TABPFN_TS_DEFAULT_QUANTILE_CONFIG, quantile_pred_with_uncertainty
|
112
|
+
)
|
115
113
|
}
|
116
114
|
)
|
117
115
|
|
@@ -141,8 +139,52 @@ class LocalTabPFN(TabPFNWorker):
|
|
141
139
|
def __init__(
|
142
140
|
self,
|
143
141
|
config: dict = {},
|
142
|
+
num_workers_per_gpu: int = 4, # per GPU
|
143
|
+
):
|
144
|
+
self.num_workers_per_gpu = num_workers_per_gpu
|
145
|
+
|
146
|
+
# Only support GPU for now (inference on CPU takes too long)
|
147
|
+
if not torch.cuda.is_available():
|
148
|
+
raise ValueError("GPU is required for local TabPFN inference")
|
149
|
+
|
150
|
+
super().__init__(
|
151
|
+
config, num_workers=torch.cuda.device_count() * self.num_workers_per_gpu
|
152
|
+
)
|
153
|
+
|
154
|
+
def predict(
|
155
|
+
self,
|
156
|
+
train_tsdf: TimeSeriesDataFrame,
|
157
|
+
test_tsdf: TimeSeriesDataFrame,
|
144
158
|
):
|
145
|
-
|
159
|
+
total_num_workers = torch.cuda.device_count() * self.num_workers_per_gpu
|
160
|
+
|
161
|
+
# Split data into chunks for parallel inference on each GPU
|
162
|
+
# since the time series are of different lengths, we shuffle
|
163
|
+
# the item_ids s.t. the workload is distributed evenly across GPUs
|
164
|
+
# Also, using 'min' since num_workers could be larger than the number of time series
|
165
|
+
np.random.seed(0)
|
166
|
+
item_ids_chunks = np.array_split(
|
167
|
+
np.random.permutation(train_tsdf.item_ids),
|
168
|
+
min(total_num_workers, len(train_tsdf.item_ids)),
|
169
|
+
)
|
170
|
+
|
171
|
+
# Run predictions in parallel
|
172
|
+
predictions = Parallel(n_jobs=len(item_ids_chunks), backend="loky")(
|
173
|
+
delayed(self._prediction_routine_per_gpu)(
|
174
|
+
train_tsdf.loc[chunk],
|
175
|
+
test_tsdf.loc[chunk],
|
176
|
+
gpu_id=i
|
177
|
+
% torch.cuda.device_count(), # Alternate between available GPUs
|
178
|
+
)
|
179
|
+
for i, chunk in enumerate(item_ids_chunks)
|
180
|
+
)
|
181
|
+
|
182
|
+
predictions = pd.concat(predictions)
|
183
|
+
|
184
|
+
# Sort predictions according to original item_ids order
|
185
|
+
predictions = predictions.loc[train_tsdf.item_ids]
|
186
|
+
|
187
|
+
return TimeSeriesDataFrame(predictions)
|
146
188
|
|
147
189
|
def _get_tabpfn_engine(self):
|
148
190
|
from tabpfn import TabPFNRegressor
|
@@ -151,7 +193,67 @@ class LocalTabPFN(TabPFNWorker):
|
|
151
193
|
config = self.config["tabpfn_internal"].copy()
|
152
194
|
config["model_path"] = self._parse_model_path(config["model_path"])
|
153
195
|
|
154
|
-
return TabPFNRegressor(**config)
|
196
|
+
return TabPFNRegressor(**config, random_state=0)
|
155
197
|
|
156
198
|
def _parse_model_path(self, model_name: str) -> str:
|
157
199
|
return f"tabpfn-v2-regressor-{model_name}.ckpt"
|
200
|
+
|
201
|
+
def _prediction_routine_per_gpu(
|
202
|
+
self,
|
203
|
+
train_tsdf: TimeSeriesDataFrame,
|
204
|
+
test_tsdf: TimeSeriesDataFrame,
|
205
|
+
gpu_id: int,
|
206
|
+
):
|
207
|
+
# Set GPU
|
208
|
+
torch.cuda.set_device(gpu_id)
|
209
|
+
|
210
|
+
all_pred = []
|
211
|
+
for item_id in tqdm(train_tsdf.item_ids, desc=f"GPU {gpu_id}:"):
|
212
|
+
predictions = self._prediction_routine(
|
213
|
+
item_id,
|
214
|
+
train_tsdf.loc[item_id],
|
215
|
+
test_tsdf.loc[item_id],
|
216
|
+
)
|
217
|
+
all_pred.append(predictions)
|
218
|
+
|
219
|
+
# Clear GPU cache
|
220
|
+
torch.cuda.empty_cache()
|
221
|
+
|
222
|
+
return pd.concat(all_pred)
|
223
|
+
|
224
|
+
|
225
|
+
class MockTabPFN(TabPFNWorker):
|
226
|
+
"""
|
227
|
+
Mock TabPFN worker that returns random values for predictions.
|
228
|
+
Can be used for testing or debugging.
|
229
|
+
"""
|
230
|
+
|
231
|
+
class MockTabPFNRegressor:
|
232
|
+
TABPFN_QUANTILE = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
233
|
+
|
234
|
+
def __init__(self, *args, **kwargs):
|
235
|
+
pass
|
236
|
+
|
237
|
+
def fit(self, *args, **kwargs):
|
238
|
+
pass
|
239
|
+
|
240
|
+
def predict(self, test_X, output_type="main", **kwargs):
|
241
|
+
if output_type != "main":
|
242
|
+
raise NotImplementedError(
|
243
|
+
"Only main output is supported for mock TabPFN"
|
244
|
+
)
|
245
|
+
|
246
|
+
return {
|
247
|
+
"mean": np.random.rand(len(test_X)),
|
248
|
+
"median": np.random.rand(len(test_X)),
|
249
|
+
"mode": np.random.rand(len(test_X)),
|
250
|
+
"quantiles": [
|
251
|
+
np.random.rand(len(test_X)) for _ in self.TABPFN_QUANTILE
|
252
|
+
],
|
253
|
+
}
|
254
|
+
|
255
|
+
def __init__(self, *args, **kwargs):
|
256
|
+
super().__init__(*args, **kwargs)
|
257
|
+
|
258
|
+
def _get_tabpfn_engine(self):
|
259
|
+
return self.MockTabPFNRegressor()
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: tabpfn_time_series
|
3
|
-
Version: 0.1.2
|
3
|
+
Version: 0.1.3
|
4
4
|
Summary: Zero-shot time series forecasting with TabPFN
|
5
5
|
Project-URL: Homepage, https://github.com/liam-sbhoo/tabpfn-time-series
|
6
6
|
Project-URL: Bug Tracker, https://github.com/liam-sbhoo/tabpfn-time-series/issues
|
@@ -10,11 +10,12 @@ Classifier: License :: OSI Approved :: Apache Software License
|
|
10
10
|
Classifier: Operating System :: OS Independent
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
12
12
|
Requires-Python: >=3.10
|
13
|
-
Requires-Dist: autogluon-timeseries
|
14
|
-
Requires-Dist:
|
15
|
-
Requires-Dist:
|
16
|
-
Requires-Dist:
|
17
|
-
Requires-Dist: tabpfn
|
13
|
+
Requires-Dist: autogluon-timeseries>=1.2
|
14
|
+
Requires-Dist: datasets>=3.3.2
|
15
|
+
Requires-Dist: gluonts>=0.16.0
|
16
|
+
Requires-Dist: pandas<2.2.0,>=2.1.2
|
17
|
+
Requires-Dist: tabpfn-client>=0.1.1
|
18
|
+
Requires-Dist: tabpfn>=2.0.0
|
18
19
|
Requires-Dist: tqdm
|
19
20
|
Provides-Extra: dev
|
20
21
|
Requires-Dist: build; extra == 'dev'
|
@@ -24,13 +25,20 @@ Requires-Dist: ruff; extra == 'dev'
|
|
24
25
|
Requires-Dist: twine; extra == 'dev'
|
25
26
|
Description-Content-Type: text/markdown
|
26
27
|
|
27
|
-
# Time Series Forecasting with TabPFN
|
28
|
+
# Zero-Shot Time Series Forecasting with TabPFN
|
28
29
|
|
30
|
+
[](https://badge.fury.io/py/tabpfn-time-series)
|
29
31
|
[](https://colab.research.google.com/github/liam-sbhoo/tabpfn-time-series/blob/main/demo.ipynb)
|
30
32
|
[](https://discord.com/channels/1285598202732482621/)
|
31
33
|
[](https://arxiv.org/abs/2501.02945)
|
32
34
|
|
35
|
+
## 📌 News
|
36
|
+
- **27-01-2025**: 🚀 Ranked _**1st**_ on [GIFT-EVAL](https://huggingface.co/spaces/Salesforce/GIFT-Eval) benchmark<sup>[1]</sup>!
|
37
|
+
- **10-10-2024**: 🚀 TabPFN-TS [paper](https://arxiv.org/abs/2501.02945) accepted to NeurIPS 2024 [TRL](https://table-representation-learning.github.io/NeurIPS2024/) and [TSALM](https://neurips-time-series-workshop.github.io/) workshops!
|
33
38
|
|
39
|
+
_[1] Last checked on: 10/03/2025_
|
40
|
+
|
41
|
+
## ✨ Introduction
|
34
42
|
We demonstrate that the tabular foundation model **[TabPFN](https://github.com/PriorLabs/TabPFN)**, when paired with minimal featurization, can perform zero-shot time series forecasting. Its performance on point forecasting matches or even slightly outperforms state-of-the-art methods.
|
35
43
|
|
36
44
|
## 📖 How does it work?
|
@@ -59,3 +67,9 @@ On top of that, thanks to **[tabpfn-client](https://github.com/automl/tabpfn-cli
|
|
59
67
|
[](https://colab.research.google.com/github/liam-sbhoo/tabpfn-time-series/blob/main/demo.ipynb)
|
60
68
|
|
61
69
|
The demo should explain it all. 😉
|
70
|
+
|
71
|
+
## 📊 GIFT-EVAL Benchmark
|
72
|
+
|
73
|
+
We have submitted our results to the [GIFT-EVAL](https://huggingface.co/spaces/Salesforce/GIFT-Eval) benchmark. Stay tuned for results!
|
74
|
+
|
75
|
+
For more details regarding the evaluation setup, please refer to [README.md](gift_eval/README.md).
|
@@ -0,0 +1,11 @@
|
|
1
|
+
tabpfn_time_series/__init__.py,sha256=brJLLVOis4tBGOmNk6PCjyk_RaOvFITZgaYChOTVqSo,353
|
2
|
+
tabpfn_time_series/data_preparation.py,sha256=iNW7sAnRkTgmzzOEHBhkkTwm_lQ3p_Q9xgAQ5PbkOts,5416
|
3
|
+
tabpfn_time_series/defaults.py,sha256=u2_JnwxiZ5NNibzyNpsE63KuP3TcmOL1iAP8llZ2rJk,238
|
4
|
+
tabpfn_time_series/feature.py,sha256=_9FxfQfgPOOO1MiT8hB8523eZ3Nc5oKuoY7vcohKZZc,2531
|
5
|
+
tabpfn_time_series/plot.py,sha256=bwSYcWBanzPrUxXKFsbqG8fyGsOJZfgU2v3NsxzTSXo,6571
|
6
|
+
tabpfn_time_series/predictor.py,sha256=JzuV34zERf1XDLacGzSFJb-o077qd7GlKC6lvD62EPk,1457
|
7
|
+
tabpfn_time_series/tabpfn_worker.py,sha256=zvFwg4Dc01_m5emqmVITBr6W_cNZ04tMyntmj40pyPE,8299
|
8
|
+
tabpfn_time_series-0.1.3.dist-info/METADATA,sha256=KQZBVKZgMX4e3uxk2LTCuSwruATLowUmgrP6wbcLMB8,4158
|
9
|
+
tabpfn_time_series-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
10
|
+
tabpfn_time_series-0.1.3.dist-info/licenses/LICENSE.txt,sha256=iwhPL7kIWQG6gyLZZwIMDItGrNgxMDIq9itxkUSMapY,11345
|
11
|
+
tabpfn_time_series-0.1.3.dist-info/RECORD,,
|
@@ -1,11 +0,0 @@
|
|
1
|
-
tabpfn_time_series/__init__.py,sha256=5ruHrmKBQRIZ3WXLA8du4JKttF55ntnI74hkRsHThQ8,256
|
2
|
-
tabpfn_time_series/data_preparation.py,sha256=iNW7sAnRkTgmzzOEHBhkkTwm_lQ3p_Q9xgAQ5PbkOts,5416
|
3
|
-
tabpfn_time_series/defaults.py,sha256=u2_JnwxiZ5NNibzyNpsE63KuP3TcmOL1iAP8llZ2rJk,238
|
4
|
-
tabpfn_time_series/feature.py,sha256=_9FxfQfgPOOO1MiT8hB8523eZ3Nc5oKuoY7vcohKZZc,2531
|
5
|
-
tabpfn_time_series/plot.py,sha256=bwSYcWBanzPrUxXKFsbqG8fyGsOJZfgU2v3NsxzTSXo,6571
|
6
|
-
tabpfn_time_series/predictor.py,sha256=W9JijaxFaR0chfiW7m4RuDQ0wrRcJezDWVwCBEOQDFk,1502
|
7
|
-
tabpfn_time_series/tabpfn_worker.py,sha256=XNpqLEW51PgzrEopNNdtGdYArMCHT4yeBK3BS3z25K0,5021
|
8
|
-
tabpfn_time_series-0.1.2.dist-info/METADATA,sha256=hO69b8GN3GDRIetG4DGtxpdMubc8sm8h_aI2RwEto2U,3285
|
9
|
-
tabpfn_time_series-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
10
|
-
tabpfn_time_series-0.1.2.dist-info/licenses/LICENSE.txt,sha256=iwhPL7kIWQG6gyLZZwIMDItGrNgxMDIq9itxkUSMapY,11345
|
11
|
-
tabpfn_time_series-0.1.2.dist-info/RECORD,,
|
File without changes
|
{tabpfn_time_series-0.1.2.dist-info → tabpfn_time_series-0.1.3.dist-info}/licenses/LICENSE.txt
RENAMED
File without changes
|