tabpfn-time-series 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,7 +5,7 @@ __version__ = "0.1.0"
5
5
 
6
6
  __all__ = [
7
7
  "DefaultFeatures",
8
- "FeatureTransformer",
8
+ "FeatureTransformer",
9
9
  "TabPFNTimeSeriesPredictor",
10
10
  "TabPFNMode",
11
11
  ]
@@ -1,5 +1,7 @@
1
- TABPFN_DEFAULT_QUANTILE_CONFIG = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
2
- TABPFN_DEFAULT_CONFIG = {
3
- "model": "2noar4o2",
4
- "optimize_metric": "median",
1
+ TABPFN_TS_DEFAULT_QUANTILE_CONFIG = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
2
+ TABPFN_TS_DEFAULT_CONFIG = {
3
+ "tabpfn_internal": {
4
+ "model_path": "2noar4o2",
5
+ },
6
+ "tabpfn_output_selection": "median", # mean or median
5
7
  }
@@ -4,7 +4,10 @@ from enum import Enum
4
4
  from autogluon.timeseries import TimeSeriesDataFrame
5
5
 
6
6
  from tabpfn_time_series.tabpfn_worker import TabPFNClient, LocalTabPFN
7
- from tabpfn_time_series.defaults import TABPFN_DEFAULT_QUANTILE_CONFIG, TABPFN_DEFAULT_CONFIG
7
+ from tabpfn_time_series.defaults import (
8
+ TABPFN_TS_DEFAULT_QUANTILE_CONFIG,
9
+ TABPFN_TS_DEFAULT_CONFIG,
10
+ )
8
11
 
9
12
  logger = logging.getLogger(__name__)
10
13
 
@@ -22,11 +25,11 @@ class TabPFNTimeSeriesPredictor:
22
25
  def __init__(
23
26
  self,
24
27
  tabpfn_mode: TabPFNMode = TabPFNMode.CLIENT,
25
- tabpfn_config: dict = TABPFN_DEFAULT_CONFIG,
28
+ config: dict = TABPFN_TS_DEFAULT_CONFIG,
26
29
  ) -> None:
27
30
  worker_mapping = {
28
- TabPFNMode.CLIENT: lambda: TabPFNClient(tabpfn_config),
29
- TabPFNMode.LOCAL: lambda: LocalTabPFN(tabpfn_config),
31
+ TabPFNMode.CLIENT: lambda: TabPFNClient(config),
32
+ TabPFNMode.LOCAL: lambda: LocalTabPFN(config),
30
33
  }
31
34
  self.tabpfn_worker = worker_mapping[tabpfn_mode]()
32
35
 
@@ -34,14 +37,14 @@ class TabPFNTimeSeriesPredictor:
34
37
  self,
35
38
  train_tsdf: TimeSeriesDataFrame, # with features and target
36
39
  test_tsdf: TimeSeriesDataFrame, # with features only
37
- quantile_config: list[float] = TABPFN_DEFAULT_QUANTILE_CONFIG,
40
+ quantile_config: list[float] = TABPFN_TS_DEFAULT_QUANTILE_CONFIG,
38
41
  ) -> TimeSeriesDataFrame:
39
42
  """
40
43
  Predict on each time series individually (local forecasting).
41
44
  """
42
45
 
43
46
  logger.info(
44
- f"Predicting {len(train_tsdf.item_ids)} time series with config{self.tabpfn_worker.tabpfn_config}"
47
+ f"Predicting {len(train_tsdf.item_ids)} time series with config{self.tabpfn_worker.config}"
45
48
  )
46
49
 
47
50
  return self.tabpfn_worker.predict(train_tsdf, test_tsdf, quantile_config)
@@ -8,7 +8,7 @@ from scipy.stats import norm
8
8
  from autogluon.timeseries import TimeSeriesDataFrame
9
9
 
10
10
  from tabpfn_time_series.data_preparation import split_time_series_to_X_y
11
- from tabpfn_time_series.defaults import TABPFN_DEFAULT_QUANTILE_CONFIG
11
+ from tabpfn_time_series.defaults import TABPFN_TS_DEFAULT_QUANTILE_CONFIG
12
12
 
13
13
  logger = logging.getLogger(__name__)
14
14
 
@@ -16,10 +16,10 @@ logger = logging.getLogger(__name__)
16
16
  class TabPFNWorker(ABC):
17
17
  def __init__(
18
18
  self,
19
- tabpfn_config: dict = {},
19
+ config: dict = {},
20
20
  num_workers: int = 1,
21
21
  ):
22
- self.tabpfn_config = tabpfn_config
22
+ self.config = config
23
23
  self.num_workers = num_workers
24
24
 
25
25
  def predict(
@@ -28,6 +28,12 @@ class TabPFNWorker(ABC):
28
28
  test_tsdf: TimeSeriesDataFrame,
29
29
  quantile_config: list[float],
30
30
  ):
31
+ if not set(quantile_config).issubset(set(TABPFN_TS_DEFAULT_QUANTILE_CONFIG)):
32
+ raise NotImplementedError(
33
+ f"We currently only supports {TABPFN_TS_DEFAULT_QUANTILE_CONFIG} for quantile prediction,"
34
+ f" but got {quantile_config}."
35
+ )
36
+
31
37
  predictions = Parallel(
32
38
  n_jobs=self.num_workers,
33
39
  backend="loky",
@@ -67,12 +73,16 @@ class TabPFNWorker(ABC):
67
73
  single_train_tsdf, single_test_tsdf, quantile_config
68
74
  )
69
75
  else:
70
- # Call worker-specific prediction routine
71
- result = self._worker_specific_prediction_routine(
72
- train_X,
73
- train_y,
74
- test_X,
75
- quantile_config,
76
+ tabpfn = self._get_tabpfn_engine()
77
+ tabpfn.fit(train_X, train_y)
78
+ full_pred = tabpfn.predict(test_X, output_type="main")
79
+
80
+ result = {"target": full_pred[self.config["tabpfn_output_selection"]]}
81
+ result.update(
82
+ {
83
+ q: q_pred
84
+ for q, q_pred in zip(quantile_config, full_pred["quantiles"])
85
+ }
76
86
  )
77
87
 
78
88
  result = pd.DataFrame(result, index=test_index)
@@ -81,13 +91,7 @@ class TabPFNWorker(ABC):
81
91
  return result
82
92
 
83
93
  @abstractmethod
84
- def _worker_specific_prediction_routine(
85
- self,
86
- train_X: pd.DataFrame,
87
- train_y: pd.Series,
88
- test_X: pd.DataFrame,
89
- quantile_config: list[float],
90
- ) -> pd.DataFrame:
94
+ def _get_tabpfn_engine(self):
91
95
  pass
92
96
 
93
97
  def _predict_on_constant_train_target(
@@ -117,108 +121,37 @@ class TabPFNWorker(ABC):
117
121
  class TabPFNClient(TabPFNWorker):
118
122
  def __init__(
119
123
  self,
120
- tabpfn_config: dict = {},
124
+ config: dict = {},
121
125
  num_workers: int = 2,
122
126
  ):
123
- super().__init__(tabpfn_config, num_workers)
127
+ super().__init__(config, num_workers)
124
128
 
125
129
  # Initialize the TabPFN client (e.g. sign up, login, etc.)
126
130
  from tabpfn_client import init
127
131
 
128
132
  init()
129
133
 
130
- def predict(
131
- self,
132
- train_tsdf: TimeSeriesDataFrame,
133
- test_tsdf: TimeSeriesDataFrame,
134
- quantile_config: list[float],
135
- ):
136
- if not set(quantile_config).issubset(set(TABPFN_DEFAULT_QUANTILE_CONFIG)):
137
- raise NotImplementedError(
138
- f"TabPFNClient currently only supports {TABPFN_DEFAULT_QUANTILE_CONFIG} for quantile prediction,"
139
- f" but got {quantile_config}."
140
- )
141
-
142
- return super().predict(train_tsdf, test_tsdf, quantile_config)
143
-
144
- def _worker_specific_prediction_routine(
145
- self,
146
- train_X: pd.DataFrame,
147
- train_y: pd.Series,
148
- test_X: pd.DataFrame,
149
- quantile_config: list[float],
150
- ) -> pd.DataFrame:
134
+ def _get_tabpfn_engine(self):
151
135
  from tabpfn_client import TabPFNRegressor
152
136
 
153
- tabpfn = TabPFNRegressor(**self.tabpfn_config)
154
- tabpfn.fit(train_X, train_y)
155
- full_pred = tabpfn.predict_full(test_X)
156
-
157
- result = {"target": full_pred[self._get_optimization_mode()]}
158
- result.update({q: full_pred[f"quantile_{q:.2f}"] for q in quantile_config})
159
-
160
- return result
161
-
162
- def _get_optimization_mode(self):
163
- if (
164
- "optimize_metric" not in self.tabpfn_config
165
- or self.tabpfn_config["optimize_metric"] is None
166
- ):
167
- return "mean"
168
- elif self.tabpfn_config["optimize_metric"] in ["rmse", "mse", "r2", "mean"]:
169
- return "mean"
170
- elif self.tabpfn_config["optimize_metric"] in ["mae", "median"]:
171
- return "median"
172
- elif self.tabpfn_config["optimize_metric"] in ["mode", "exact_match"]:
173
- return "mode"
174
- else:
175
- raise ValueError(f"Unknown metric {self.tabpfn_config['optimize_metric']}")
137
+ return TabPFNRegressor(**self.config["tabpfn_internal"])
176
138
 
177
139
 
178
140
  class LocalTabPFN(TabPFNWorker):
179
141
  def __init__(
180
142
  self,
181
- tabpfn_config: dict = {},
143
+ config: dict = {},
182
144
  ):
183
- # Local TabPFN has a different interface for declaring the model
184
- if "model" in tabpfn_config:
185
- config = tabpfn_config.copy()
186
- config["model_path"] = self._parse_model_path(config["model"])
187
- del config["model"]
188
- tabpfn_config = config
189
-
190
- super().__init__(tabpfn_config, num_workers=1)
145
+ super().__init__(config, num_workers=1)
191
146
 
192
- def _worker_specific_prediction_routine(
193
- self,
194
- train_X: pd.DataFrame,
195
- train_y: pd.Series,
196
- test_X: pd.DataFrame,
197
- quantile_config: list[float],
198
- ) -> pd.DataFrame:
147
+ def _get_tabpfn_engine(self):
199
148
  from tabpfn import TabPFNRegressor
200
149
 
201
- tabpfn = TabPFNRegressor(**self.tabpfn_config)
202
- tabpfn.fit(train_X, train_y)
203
- full_pred = tabpfn.predict_full(test_X)
204
-
205
- result = {"target": full_pred[tabpfn.get_optimization_mode()]}
206
- if set(quantile_config).issubset(set(TABPFN_DEFAULT_QUANTILE_CONFIG)):
207
- result.update({q: full_pred[f"quantile_{q:.2f}"] for q in quantile_config})
208
- else:
209
- import torch
150
+ if "model_path" in self.config["tabpfn_internal"]:
151
+ config = self.config["tabpfn_internal"].copy()
152
+ config["model_path"] = self._parse_model_path(config["model_path"])
210
153
 
211
- criterion = full_pred["criterion"]
212
- logits = torch.tensor(full_pred["logits"])
213
- result.update({q: criterion.icdf(logits, q) for q in quantile_config})
214
-
215
- return result
154
+ return TabPFNRegressor(**config)
216
155
 
217
156
  def _parse_model_path(self, model_name: str) -> str:
218
- from pathlib import Path
219
- import importlib.util
220
-
221
- tabpfn_path = Path(importlib.util.find_spec("tabpfn").origin).parent
222
- return str(
223
- tabpfn_path / "model_cache" / f"model_hans_regression_{model_name}.ckpt"
224
- )
157
+ return f"tabpfn-v2-regressor-{model_name}.ckpt"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tabpfn_time_series
3
- Version: 0.1.1
3
+ Version: 0.1.2
4
4
  Summary: Zero-shot time series forecasting with TabPFN
5
5
  Project-URL: Homepage, https://github.com/liam-sbhoo/tabpfn-time-series
6
6
  Project-URL: Bug Tracker, https://github.com/liam-sbhoo/tabpfn-time-series/issues
@@ -10,13 +10,15 @@ Classifier: License :: OSI Approved :: Apache Software License
10
10
  Classifier: Operating System :: OS Independent
11
11
  Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10
13
- Requires-Dist: autogluon-timeseries
14
- Requires-Dist: gluonts
13
+ Requires-Dist: autogluon-timeseries==1.2
14
+ Requires-Dist: gluonts==0.16.0
15
15
  Requires-Dist: pandas
16
- Requires-Dist: tabpfn-client
16
+ Requires-Dist: tabpfn-client==0.1.1
17
+ Requires-Dist: tabpfn==2.0.0
17
18
  Requires-Dist: tqdm
18
19
  Provides-Extra: dev
19
20
  Requires-Dist: build; extra == 'dev'
21
+ Requires-Dist: jupyter; extra == 'dev'
20
22
  Requires-Dist: pre-commit; extra == 'dev'
21
23
  Requires-Dist: ruff; extra == 'dev'
22
24
  Requires-Dist: twine; extra == 'dev'
@@ -26,10 +28,10 @@ Description-Content-Type: text/markdown
26
28
 
27
29
  [![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/liam-sbhoo/tabpfn-time-series/blob/main/demo.ipynb)
28
30
  [![Discord](https://img.shields.io/discord/1285598202732482621?color=7289da&label=Discord&logo=discord&logoColor=ffffff)](https://discord.com/channels/1285598202732482621/)
29
- [![arXiv](https://img.shields.io/badge/arXiv-<INDEX>-<COLOR>.svg)](https://arxiv.org/abs/2501.02945)
31
+ [![arXiv](https://img.shields.io/badge/arXiv-2501.02945-<COLOR>.svg)](https://arxiv.org/abs/2501.02945)
30
32
 
31
33
 
32
- We demonstrate that the tabular foundation model **TabPFN**, when paired with minimal featurization, can perform zero-shot time series forecasting. Its performance on point forecasting matches or even slightly outperforms state-of-the-art methods.
34
+ We demonstrate that the tabular foundation model **[TabPFN](https://github.com/PriorLabs/TabPFN)**, when paired with minimal featurization, can perform zero-shot time series forecasting. Its performance on point forecasting matches or even slightly outperforms state-of-the-art methods.
33
35
 
34
36
  ## 📖 How does it work?
35
37
 
@@ -50,7 +52,7 @@ For more details, please refer to our [paper](https://arxiv.org/abs/2501.02945)
50
52
  - **Point and probabilistic forecasting**: it provides accurate point forecasts as well as probabilistic forecasts.
51
53
  - **Support for exogenous variables**: if you have exogenous variables, this method can seemlessly incorporate them into the forecasting model.
52
54
 
53
- On top of that, thanks to [tabpfn-client](https://github.com/automl/tabpfn-client) from [Prior Labs](https://priorlabs.ai), you won’t even need your own GPU to run fast inference with TabPFN. 😉 We have included `tabpfn-client` as the default engine in our implementation.
55
+ On top of that, thanks to **[tabpfn-client](https://github.com/automl/tabpfn-client)** from **[Prior Labs](https://priorlabs.ai)**, you won’t even need your own GPU to run fast inference with TabPFN. 😉 We have included `tabpfn-client` as the default engine in our implementation.
54
56
 
55
57
  ## How to use it?
56
58
 
@@ -0,0 +1,11 @@
1
+ tabpfn_time_series/__init__.py,sha256=5ruHrmKBQRIZ3WXLA8du4JKttF55ntnI74hkRsHThQ8,256
2
+ tabpfn_time_series/data_preparation.py,sha256=iNW7sAnRkTgmzzOEHBhkkTwm_lQ3p_Q9xgAQ5PbkOts,5416
3
+ tabpfn_time_series/defaults.py,sha256=u2_JnwxiZ5NNibzyNpsE63KuP3TcmOL1iAP8llZ2rJk,238
4
+ tabpfn_time_series/feature.py,sha256=_9FxfQfgPOOO1MiT8hB8523eZ3Nc5oKuoY7vcohKZZc,2531
5
+ tabpfn_time_series/plot.py,sha256=bwSYcWBanzPrUxXKFsbqG8fyGsOJZfgU2v3NsxzTSXo,6571
6
+ tabpfn_time_series/predictor.py,sha256=W9JijaxFaR0chfiW7m4RuDQ0wrRcJezDWVwCBEOQDFk,1502
7
+ tabpfn_time_series/tabpfn_worker.py,sha256=XNpqLEW51PgzrEopNNdtGdYArMCHT4yeBK3BS3z25K0,5021
8
+ tabpfn_time_series-0.1.2.dist-info/METADATA,sha256=hO69b8GN3GDRIetG4DGtxpdMubc8sm8h_aI2RwEto2U,3285
9
+ tabpfn_time_series-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
+ tabpfn_time_series-0.1.2.dist-info/licenses/LICENSE.txt,sha256=iwhPL7kIWQG6gyLZZwIMDItGrNgxMDIq9itxkUSMapY,11345
11
+ tabpfn_time_series-0.1.2.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- tabpfn_time_series/__init__.py,sha256=atVNap8tQLI0t5COkkCop-wbY3y1FdVxIMfvCf6VDsQ,257
2
- tabpfn_time_series/data_preparation.py,sha256=iNW7sAnRkTgmzzOEHBhkkTwm_lQ3p_Q9xgAQ5PbkOts,5416
3
- tabpfn_time_series/defaults.py,sha256=C9HiD7Zm0BzVfE9e2f8nhpiPQSYx79hWozvzb-93L40,165
4
- tabpfn_time_series/feature.py,sha256=_9FxfQfgPOOO1MiT8hB8523eZ3Nc5oKuoY7vcohKZZc,2531
5
- tabpfn_time_series/plot.py,sha256=bwSYcWBanzPrUxXKFsbqG8fyGsOJZfgU2v3NsxzTSXo,6571
6
- tabpfn_time_series/predictor.py,sha256=YfJIe8KsyzkwgX4EFAHR8dDp-mqSv9WK88_qO_EXlws,1505
7
- tabpfn_time_series/tabpfn_worker.py,sha256=3xInPzzQtmIBPjbc_5TaQsX3-Bl3WOlxttqj3KZlC9Q,7395
8
- tabpfn_time_series-0.1.1.dist-info/METADATA,sha256=EK_9xSO0-EiE-YzlGietFOwOdD7gt5zFI5gqkzo1IMk,3147
9
- tabpfn_time_series-0.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
- tabpfn_time_series-0.1.1.dist-info/licenses/LICENSE.txt,sha256=iwhPL7kIWQG6gyLZZwIMDItGrNgxMDIq9itxkUSMapY,11345
11
- tabpfn_time_series-0.1.1.dist-info/RECORD,,