tabpfn-time-series 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tabpfn_time_series/__init__.py +1 -1
- tabpfn_time_series/defaults.py +6 -4
- tabpfn_time_series/predictor.py +9 -6
- tabpfn_time_series/tabpfn_worker.py +32 -99
- {tabpfn_time_series-0.1.0.dist-info → tabpfn_time_series-0.1.2.dist-info}/METADATA +9 -7
- tabpfn_time_series-0.1.2.dist-info/RECORD +11 -0
- tabpfn_time_series-0.1.0.dist-info/RECORD +0 -11
- {tabpfn_time_series-0.1.0.dist-info → tabpfn_time_series-0.1.2.dist-info}/WHEEL +0 -0
- {tabpfn_time_series-0.1.0.dist-info → tabpfn_time_series-0.1.2.dist-info}/licenses/LICENSE.txt +0 -0
tabpfn_time_series/__init__.py
CHANGED
tabpfn_time_series/defaults.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
"
|
4
|
-
|
1
|
+
TABPFN_TS_DEFAULT_QUANTILE_CONFIG = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
2
|
+
TABPFN_TS_DEFAULT_CONFIG = {
|
3
|
+
"tabpfn_internal": {
|
4
|
+
"model_path": "2noar4o2",
|
5
|
+
},
|
6
|
+
"tabpfn_output_selection": "median", # mean or median
|
5
7
|
}
|
tabpfn_time_series/predictor.py
CHANGED
@@ -4,7 +4,10 @@ from enum import Enum
|
|
4
4
|
from autogluon.timeseries import TimeSeriesDataFrame
|
5
5
|
|
6
6
|
from tabpfn_time_series.tabpfn_worker import TabPFNClient, LocalTabPFN
|
7
|
-
from tabpfn_time_series.defaults import
|
7
|
+
from tabpfn_time_series.defaults import (
|
8
|
+
TABPFN_TS_DEFAULT_QUANTILE_CONFIG,
|
9
|
+
TABPFN_TS_DEFAULT_CONFIG,
|
10
|
+
)
|
8
11
|
|
9
12
|
logger = logging.getLogger(__name__)
|
10
13
|
|
@@ -22,11 +25,11 @@ class TabPFNTimeSeriesPredictor:
|
|
22
25
|
def __init__(
|
23
26
|
self,
|
24
27
|
tabpfn_mode: TabPFNMode = TabPFNMode.CLIENT,
|
25
|
-
|
28
|
+
config: dict = TABPFN_TS_DEFAULT_CONFIG,
|
26
29
|
) -> None:
|
27
30
|
worker_mapping = {
|
28
|
-
TabPFNMode.CLIENT: lambda: TabPFNClient(
|
29
|
-
TabPFNMode.LOCAL: lambda: LocalTabPFN(
|
31
|
+
TabPFNMode.CLIENT: lambda: TabPFNClient(config),
|
32
|
+
TabPFNMode.LOCAL: lambda: LocalTabPFN(config),
|
30
33
|
}
|
31
34
|
self.tabpfn_worker = worker_mapping[tabpfn_mode]()
|
32
35
|
|
@@ -34,14 +37,14 @@ class TabPFNTimeSeriesPredictor:
|
|
34
37
|
self,
|
35
38
|
train_tsdf: TimeSeriesDataFrame, # with features and target
|
36
39
|
test_tsdf: TimeSeriesDataFrame, # with features only
|
37
|
-
quantile_config: list[float] =
|
40
|
+
quantile_config: list[float] = TABPFN_TS_DEFAULT_QUANTILE_CONFIG,
|
38
41
|
) -> TimeSeriesDataFrame:
|
39
42
|
"""
|
40
43
|
Predict on each time series individually (local forecasting).
|
41
44
|
"""
|
42
45
|
|
43
46
|
logger.info(
|
44
|
-
f"Predicting {len(train_tsdf.item_ids)} time series with config{self.tabpfn_worker.
|
47
|
+
f"Predicting {len(train_tsdf.item_ids)} time series with config{self.tabpfn_worker.config}"
|
45
48
|
)
|
46
49
|
|
47
50
|
return self.tabpfn_worker.predict(train_tsdf, test_tsdf, quantile_config)
|
@@ -8,7 +8,7 @@ from scipy.stats import norm
|
|
8
8
|
from autogluon.timeseries import TimeSeriesDataFrame
|
9
9
|
|
10
10
|
from tabpfn_time_series.data_preparation import split_time_series_to_X_y
|
11
|
-
from tabpfn_time_series.defaults import
|
11
|
+
from tabpfn_time_series.defaults import TABPFN_TS_DEFAULT_QUANTILE_CONFIG
|
12
12
|
|
13
13
|
logger = logging.getLogger(__name__)
|
14
14
|
|
@@ -16,10 +16,10 @@ logger = logging.getLogger(__name__)
|
|
16
16
|
class TabPFNWorker(ABC):
|
17
17
|
def __init__(
|
18
18
|
self,
|
19
|
-
|
19
|
+
config: dict = {},
|
20
20
|
num_workers: int = 1,
|
21
21
|
):
|
22
|
-
self.
|
22
|
+
self.config = config
|
23
23
|
self.num_workers = num_workers
|
24
24
|
|
25
25
|
def predict(
|
@@ -28,6 +28,12 @@ class TabPFNWorker(ABC):
|
|
28
28
|
test_tsdf: TimeSeriesDataFrame,
|
29
29
|
quantile_config: list[float],
|
30
30
|
):
|
31
|
+
if not set(quantile_config).issubset(set(TABPFN_TS_DEFAULT_QUANTILE_CONFIG)):
|
32
|
+
raise NotImplementedError(
|
33
|
+
f"We currently only supports {TABPFN_TS_DEFAULT_QUANTILE_CONFIG} for quantile prediction,"
|
34
|
+
f" but got {quantile_config}."
|
35
|
+
)
|
36
|
+
|
31
37
|
predictions = Parallel(
|
32
38
|
n_jobs=self.num_workers,
|
33
39
|
backend="loky",
|
@@ -67,12 +73,16 @@ class TabPFNWorker(ABC):
|
|
67
73
|
single_train_tsdf, single_test_tsdf, quantile_config
|
68
74
|
)
|
69
75
|
else:
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
+
tabpfn = self._get_tabpfn_engine()
|
77
|
+
tabpfn.fit(train_X, train_y)
|
78
|
+
full_pred = tabpfn.predict(test_X, output_type="main")
|
79
|
+
|
80
|
+
result = {"target": full_pred[self.config["tabpfn_output_selection"]]}
|
81
|
+
result.update(
|
82
|
+
{
|
83
|
+
q: q_pred
|
84
|
+
for q, q_pred in zip(quantile_config, full_pred["quantiles"])
|
85
|
+
}
|
76
86
|
)
|
77
87
|
|
78
88
|
result = pd.DataFrame(result, index=test_index)
|
@@ -81,13 +91,7 @@ class TabPFNWorker(ABC):
|
|
81
91
|
return result
|
82
92
|
|
83
93
|
@abstractmethod
|
84
|
-
def
|
85
|
-
self,
|
86
|
-
train_X: pd.DataFrame,
|
87
|
-
train_y: pd.Series,
|
88
|
-
test_X: pd.DataFrame,
|
89
|
-
quantile_config: list[float],
|
90
|
-
) -> pd.DataFrame:
|
94
|
+
def _get_tabpfn_engine(self):
|
91
95
|
pass
|
92
96
|
|
93
97
|
def _predict_on_constant_train_target(
|
@@ -117,108 +121,37 @@ class TabPFNWorker(ABC):
|
|
117
121
|
class TabPFNClient(TabPFNWorker):
|
118
122
|
def __init__(
|
119
123
|
self,
|
120
|
-
|
124
|
+
config: dict = {},
|
121
125
|
num_workers: int = 2,
|
122
126
|
):
|
123
|
-
super().__init__(
|
127
|
+
super().__init__(config, num_workers)
|
124
128
|
|
125
129
|
# Initialize the TabPFN client (e.g. sign up, login, etc.)
|
126
130
|
from tabpfn_client import init
|
127
131
|
|
128
132
|
init()
|
129
133
|
|
130
|
-
def
|
131
|
-
self,
|
132
|
-
train_tsdf: TimeSeriesDataFrame,
|
133
|
-
test_tsdf: TimeSeriesDataFrame,
|
134
|
-
quantile_config: list[float],
|
135
|
-
):
|
136
|
-
if not set(quantile_config).issubset(set(TABPFN_DEFAULT_QUANTILE_CONFIG)):
|
137
|
-
raise NotImplementedError(
|
138
|
-
f"TabPFNClient currently only supports {TABPFN_DEFAULT_QUANTILE_CONFIG} for quantile prediction,"
|
139
|
-
f" but got {quantile_config}."
|
140
|
-
)
|
141
|
-
|
142
|
-
return super().predict(train_tsdf, test_tsdf, quantile_config)
|
143
|
-
|
144
|
-
def _worker_specific_prediction_routine(
|
145
|
-
self,
|
146
|
-
train_X: pd.DataFrame,
|
147
|
-
train_y: pd.Series,
|
148
|
-
test_X: pd.DataFrame,
|
149
|
-
quantile_config: list[float],
|
150
|
-
) -> pd.DataFrame:
|
134
|
+
def _get_tabpfn_engine(self):
|
151
135
|
from tabpfn_client import TabPFNRegressor
|
152
136
|
|
153
|
-
|
154
|
-
tabpfn.fit(train_X, train_y)
|
155
|
-
full_pred = tabpfn.predict_full(test_X)
|
156
|
-
|
157
|
-
result = {"target": full_pred[self._get_optimization_mode()]}
|
158
|
-
result.update({q: full_pred[f"quantile_{q:.2f}"] for q in quantile_config})
|
159
|
-
|
160
|
-
return result
|
161
|
-
|
162
|
-
def _get_optimization_mode(self):
|
163
|
-
if (
|
164
|
-
"optimize_metric" not in self.tabpfn_config
|
165
|
-
or self.tabpfn_config["optimize_metric"] is None
|
166
|
-
):
|
167
|
-
return "mean"
|
168
|
-
elif self.tabpfn_config["optimize_metric"] in ["rmse", "mse", "r2", "mean"]:
|
169
|
-
return "mean"
|
170
|
-
elif self.tabpfn_config["optimize_metric"] in ["mae", "median"]:
|
171
|
-
return "median"
|
172
|
-
elif self.tabpfn_config["optimize_metric"] in ["mode", "exact_match"]:
|
173
|
-
return "mode"
|
174
|
-
else:
|
175
|
-
raise ValueError(f"Unknown metric {self.tabpfn_config['optimize_metric']}")
|
137
|
+
return TabPFNRegressor(**self.config["tabpfn_internal"])
|
176
138
|
|
177
139
|
|
178
140
|
class LocalTabPFN(TabPFNWorker):
|
179
141
|
def __init__(
|
180
142
|
self,
|
181
|
-
|
143
|
+
config: dict = {},
|
182
144
|
):
|
183
|
-
|
184
|
-
if "model" in tabpfn_config:
|
185
|
-
config = tabpfn_config.copy()
|
186
|
-
config["model_path"] = self._parse_model_path(config["model"])
|
187
|
-
del config["model"]
|
188
|
-
tabpfn_config = config
|
189
|
-
|
190
|
-
super().__init__(tabpfn_config, num_workers=1)
|
145
|
+
super().__init__(config, num_workers=1)
|
191
146
|
|
192
|
-
def
|
193
|
-
self,
|
194
|
-
train_X: pd.DataFrame,
|
195
|
-
train_y: pd.Series,
|
196
|
-
test_X: pd.DataFrame,
|
197
|
-
quantile_config: list[float],
|
198
|
-
) -> pd.DataFrame:
|
147
|
+
def _get_tabpfn_engine(self):
|
199
148
|
from tabpfn import TabPFNRegressor
|
200
149
|
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
result = {"target": full_pred[tabpfn.get_optimization_mode()]}
|
206
|
-
if set(quantile_config).issubset(set(TABPFN_DEFAULT_QUANTILE_CONFIG)):
|
207
|
-
result.update({q: full_pred[f"quantile_{q:.2f}"] for q in quantile_config})
|
208
|
-
else:
|
209
|
-
import torch
|
150
|
+
if "model_path" in self.config["tabpfn_internal"]:
|
151
|
+
config = self.config["tabpfn_internal"].copy()
|
152
|
+
config["model_path"] = self._parse_model_path(config["model_path"])
|
210
153
|
|
211
|
-
|
212
|
-
logits = torch.tensor(full_pred["logits"])
|
213
|
-
result.update({q: criterion.icdf(logits, q) for q in quantile_config})
|
214
|
-
|
215
|
-
return result
|
154
|
+
return TabPFNRegressor(**config)
|
216
155
|
|
217
156
|
def _parse_model_path(self, model_name: str) -> str:
|
218
|
-
|
219
|
-
import importlib.util
|
220
|
-
|
221
|
-
tabpfn_path = Path(importlib.util.find_spec("tabpfn").origin).parent
|
222
|
-
return str(
|
223
|
-
tabpfn_path / "model_cache" / f"model_hans_regression_{model_name}.ckpt"
|
224
|
-
)
|
157
|
+
return f"tabpfn-v2-regressor-{model_name}.ckpt"
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: tabpfn_time_series
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.2
|
4
4
|
Summary: Zero-shot time series forecasting with TabPFN
|
5
5
|
Project-URL: Homepage, https://github.com/liam-sbhoo/tabpfn-time-series
|
6
6
|
Project-URL: Bug Tracker, https://github.com/liam-sbhoo/tabpfn-time-series/issues
|
@@ -10,13 +10,15 @@ Classifier: License :: OSI Approved :: Apache Software License
|
|
10
10
|
Classifier: Operating System :: OS Independent
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
12
12
|
Requires-Python: >=3.10
|
13
|
-
Requires-Dist: autogluon-timeseries
|
14
|
-
Requires-Dist: gluonts
|
13
|
+
Requires-Dist: autogluon-timeseries==1.2
|
14
|
+
Requires-Dist: gluonts==0.16.0
|
15
15
|
Requires-Dist: pandas
|
16
|
-
Requires-Dist: tabpfn-client
|
16
|
+
Requires-Dist: tabpfn-client==0.1.1
|
17
|
+
Requires-Dist: tabpfn==2.0.0
|
17
18
|
Requires-Dist: tqdm
|
18
19
|
Provides-Extra: dev
|
19
20
|
Requires-Dist: build; extra == 'dev'
|
21
|
+
Requires-Dist: jupyter; extra == 'dev'
|
20
22
|
Requires-Dist: pre-commit; extra == 'dev'
|
21
23
|
Requires-Dist: ruff; extra == 'dev'
|
22
24
|
Requires-Dist: twine; extra == 'dev'
|
@@ -26,10 +28,10 @@ Description-Content-Type: text/markdown
|
|
26
28
|
|
27
29
|
[Open in Colab](https://colab.research.google.com/github/liam-sbhoo/tabpfn-time-series/blob/main/demo.ipynb)
|
28
30
|
[Discord](https://discord.com/channels/1285598202732482621/)
|
29
|
-
[](https://arxiv.org/abs/2501.02945)
|
30
32
|
|
31
33
|
|
32
|
-
We demonstrate that the tabular foundation model **TabPFN**, when paired with minimal featurization, can perform zero-shot time series forecasting. Its performance on point forecasting matches or even slightly outperforms state-of-the-art methods.
|
34
|
+
We demonstrate that the tabular foundation model **[TabPFN](https://github.com/PriorLabs/TabPFN)**, when paired with minimal featurization, can perform zero-shot time series forecasting. Its performance on point forecasting matches or even slightly outperforms state-of-the-art methods.
|
33
35
|
|
34
36
|
## 📖 How does it work?
|
35
37
|
|
@@ -50,7 +52,7 @@ For more details, please refer to our [paper](https://arxiv.org/abs/2501.02945)
|
|
50
52
|
- **Point and probabilistic forecasting**: it provides accurate point forecasts as well as probabilistic forecasts.
|
51
53
|
- **Support for exogenous variables**: if you have exogenous variables, this method can seamlessly incorporate them into the forecasting model.
|
52
54
|
|
53
|
-
On top of that, thanks to [tabpfn-client](https://github.com/automl/tabpfn-client) from [Prior Labs](https://priorlabs.ai)
|
55
|
+
On top of that, thanks to **[tabpfn-client](https://github.com/automl/tabpfn-client)** from **[Prior Labs](https://priorlabs.ai)**, you won’t even need your own GPU to run fast inference with TabPFN. 😉 We have included `tabpfn-client` as the default engine in our implementation.
|
54
56
|
|
55
57
|
## How to use it?
|
56
58
|
|
@@ -0,0 +1,11 @@
|
|
1
|
+
tabpfn_time_series/__init__.py,sha256=5ruHrmKBQRIZ3WXLA8du4JKttF55ntnI74hkRsHThQ8,256
|
2
|
+
tabpfn_time_series/data_preparation.py,sha256=iNW7sAnRkTgmzzOEHBhkkTwm_lQ3p_Q9xgAQ5PbkOts,5416
|
3
|
+
tabpfn_time_series/defaults.py,sha256=u2_JnwxiZ5NNibzyNpsE63KuP3TcmOL1iAP8llZ2rJk,238
|
4
|
+
tabpfn_time_series/feature.py,sha256=_9FxfQfgPOOO1MiT8hB8523eZ3Nc5oKuoY7vcohKZZc,2531
|
5
|
+
tabpfn_time_series/plot.py,sha256=bwSYcWBanzPrUxXKFsbqG8fyGsOJZfgU2v3NsxzTSXo,6571
|
6
|
+
tabpfn_time_series/predictor.py,sha256=W9JijaxFaR0chfiW7m4RuDQ0wrRcJezDWVwCBEOQDFk,1502
|
7
|
+
tabpfn_time_series/tabpfn_worker.py,sha256=XNpqLEW51PgzrEopNNdtGdYArMCHT4yeBK3BS3z25K0,5021
|
8
|
+
tabpfn_time_series-0.1.2.dist-info/METADATA,sha256=hO69b8GN3GDRIetG4DGtxpdMubc8sm8h_aI2RwEto2U,3285
|
9
|
+
tabpfn_time_series-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
10
|
+
tabpfn_time_series-0.1.2.dist-info/licenses/LICENSE.txt,sha256=iwhPL7kIWQG6gyLZZwIMDItGrNgxMDIq9itxkUSMapY,11345
|
11
|
+
tabpfn_time_series-0.1.2.dist-info/RECORD,,
|
@@ -1,11 +0,0 @@
|
|
1
|
-
tabpfn_time_series/__init__.py,sha256=atVNap8tQLI0t5COkkCop-wbY3y1FdVxIMfvCf6VDsQ,257
|
2
|
-
tabpfn_time_series/data_preparation.py,sha256=iNW7sAnRkTgmzzOEHBhkkTwm_lQ3p_Q9xgAQ5PbkOts,5416
|
3
|
-
tabpfn_time_series/defaults.py,sha256=C9HiD7Zm0BzVfE9e2f8nhpiPQSYx79hWozvzb-93L40,165
|
4
|
-
tabpfn_time_series/feature.py,sha256=_9FxfQfgPOOO1MiT8hB8523eZ3Nc5oKuoY7vcohKZZc,2531
|
5
|
-
tabpfn_time_series/plot.py,sha256=bwSYcWBanzPrUxXKFsbqG8fyGsOJZfgU2v3NsxzTSXo,6571
|
6
|
-
tabpfn_time_series/predictor.py,sha256=YfJIe8KsyzkwgX4EFAHR8dDp-mqSv9WK88_qO_EXlws,1505
|
7
|
-
tabpfn_time_series/tabpfn_worker.py,sha256=3xInPzzQtmIBPjbc_5TaQsX3-Bl3WOlxttqj3KZlC9Q,7395
|
8
|
-
tabpfn_time_series-0.1.0.dist-info/METADATA,sha256=3lPtIH1qAR58k6Y-19ZPN8fEEtNG1UKkW5qKZ-jh8e4,3147
|
9
|
-
tabpfn_time_series-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
10
|
-
tabpfn_time_series-0.1.0.dist-info/licenses/LICENSE.txt,sha256=iwhPL7kIWQG6gyLZZwIMDItGrNgxMDIq9itxkUSMapY,11345
|
11
|
-
tabpfn_time_series-0.1.0.dist-info/RECORD,,
|
File without changes
|
{tabpfn_time_series-0.1.0.dist-info → tabpfn_time_series-0.1.2.dist-info}/licenses/LICENSE.txt
RENAMED
File without changes
|