tabpfn-time-series 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,18 @@
1
+ from pathlib import Path
2
+
3
+ try:
4
+ import tomllib # Python 3.11+
5
+ except ImportError:
6
+ import tomli as tomllib # Python <3.11, requires 'tomli' package
7
+
8
+ with (Path(__file__).parent.parent / "pyproject.toml").open("rb") as f:
9
+ __version__ = tomllib.load(f)["project"]["version"]
10
+
11
+
1
12
  from .features import FeatureTransformer
2
13
  from .predictor import TabPFNTimeSeriesPredictor, TabPFNMode
3
14
  from .defaults import TABPFN_TS_DEFAULT_QUANTILE_CONFIG
4
15
 
5
- __version__ = "0.1.0"
6
-
7
16
  __all__ = [
8
17
  "FeatureTransformer",
9
18
  "TabPFNTimeSeriesPredictor",
@@ -1,7 +1,5 @@
1
1
  TABPFN_TS_DEFAULT_QUANTILE_CONFIG = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
2
2
  TABPFN_TS_DEFAULT_CONFIG = {
3
- "tabpfn_internal": {
4
- "model_path": "2noar4o2",
5
- },
3
+ "tabpfn_internal": {"model_path": "tabpfn-v2-regressor-2noar4o2.ckpt"},
6
4
  "tabpfn_output_selection": "median", # mean or median
7
5
  }
@@ -31,9 +31,9 @@ class FeatureTransformer:
31
31
  train_tsdf = tsdf.iloc[: len(train_tsdf)]
32
32
  test_tsdf = tsdf.iloc[len(train_tsdf) :]
33
33
 
34
- assert (
35
- not train_tsdf[target_column].isna().any()
36
- ), "All target values in train_tsdf should be non-NaN"
34
+ assert not train_tsdf[target_column].isna().any(), (
35
+ "All target values in train_tsdf should be non-NaN"
36
+ )
37
37
  assert test_tsdf[target_column].isna().all()
38
38
 
39
39
  return train_tsdf, test_tsdf
@@ -122,18 +122,39 @@ class TabPFNClient(TabPFNWorker):
122
122
  config: dict = {},
123
123
  num_workers: int = 2,
124
124
  ):
125
- super().__init__(config, num_workers)
126
-
127
125
  # Initialize the TabPFN client (e.g. sign up, login, etc.)
128
126
  from tabpfn_client import init
129
127
 
130
128
  init()
131
129
 
130
+ # Parse the model name (only needed for TabPFNClient)
131
+ config = config.copy()
132
+ config["tabpfn_internal"]["model_path"] = self._parse_model_name(
133
+ config["tabpfn_internal"]["model_path"]
134
+ )
135
+
136
+ super().__init__(config, num_workers)
137
+
132
138
  def _get_tabpfn_engine(self):
133
139
  from tabpfn_client import TabPFNRegressor
134
140
 
135
141
  return TabPFNRegressor(**self.config["tabpfn_internal"])
136
142
 
143
+ def _parse_model_name(self, model_name: str) -> str:
144
+ from tabpfn_client import TabPFNRegressor
145
+
146
+ available_models = TabPFNRegressor.list_available_models()
147
+
148
+ for m in available_models:
149
+ # Model names from tabpfn_client are abbreviated
150
+ # e.g. "tabpfn-v2-regressor-2noar4o2.ckpt" -> "2noar4o2"
151
+ if m in model_name:
152
+ return m
153
+ raise ValueError(
154
+ f"Model {model_name} not found. Available models: {available_models}."
155
+ "Note that model names from tabpfn_client are abbreviated (e.g. 'tabpfn-v2-regressor-2noar4o2.ckpt' -> '2noar4o2')"
156
+ )
157
+
137
158
 
138
159
  class LocalTabPFN(TabPFNWorker):
139
160
  def __init__(
@@ -151,6 +172,9 @@ class LocalTabPFN(TabPFNWorker):
151
172
  config, num_workers=torch.cuda.device_count() * self.num_workers_per_gpu
152
173
  )
153
174
 
175
+ # Download the model specified in the config
176
+ self._download_model(self.config["tabpfn_internal"]["model_path"])
177
+
154
178
  def predict(
155
179
  self,
156
180
  train_tsdf: TimeSeriesDataFrame,
@@ -186,17 +210,29 @@ class LocalTabPFN(TabPFNWorker):
186
210
 
187
211
  return TimeSeriesDataFrame(predictions)
188
212
 
189
- def _get_tabpfn_engine(self):
190
- from tabpfn import TabPFNRegressor
213
+ @staticmethod
214
+ def _download_model(model_name: str):
215
+ from tabpfn.model.loading import resolve_model_path, download_model
191
216
 
192
- if "model_path" in self.config["tabpfn_internal"]:
193
- config = self.config["tabpfn_internal"].copy()
194
- config["model_path"] = self._parse_model_path(config["model_path"])
217
+ # Resolve the model path
218
+ # If the model path is not specified, this resolves to the default model path
219
+ model_path, _, model_name, which = resolve_model_path(
220
+ model_name,
221
+ which="regressor",
222
+ )
195
223
 
196
- return TabPFNRegressor(**config, random_state=0)
224
+ if not model_path.exists():
225
+ download_model(
226
+ to=model_path,
227
+ which=which,
228
+ version="v2",
229
+ model_name=model_name,
230
+ )
231
+
232
+ def _get_tabpfn_engine(self):
233
+ from tabpfn import TabPFNRegressor
197
234
 
198
- def _parse_model_path(self, model_name: str) -> str:
199
- return f"tabpfn-v2-regressor-{model_name}.ckpt"
235
+ return TabPFNRegressor(**self.config["tabpfn_internal"], random_state=0)
200
236
 
201
237
  def _prediction_routine_per_gpu(
202
238
  self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tabpfn_time_series
3
- Version: 1.0.0
3
+ Version: 1.0.1
4
4
  Summary: Zero-shot time series forecasting with TabPFNv2
5
5
  Project-URL: Homepage, https://github.com/liam-sbhoo/tabpfn-time-series
6
6
  Project-URL: Bug Tracker, https://github.com/liam-sbhoo/tabpfn-time-series/issues
@@ -11,19 +11,22 @@ Classifier: Operating System :: OS Independent
11
11
  Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10
13
13
  Requires-Dist: autogluon-timeseries>=1.2
14
- Requires-Dist: datasets>=3.3.2
14
+ Requires-Dist: datasets>=4.0
15
15
  Requires-Dist: gluonts>=0.16.0
16
16
  Requires-Dist: pandas<2.2.0,>=2.1.2
17
17
  Requires-Dist: python-dotenv>=1.1.0
18
18
  Requires-Dist: pyyaml>=6.0.1
19
19
  Requires-Dist: tabpfn-client>=0.1.7
20
20
  Requires-Dist: tabpfn>=2.0.9
21
+ Requires-Dist: tomli>=2.2.1
21
22
  Requires-Dist: tqdm
22
23
  Provides-Extra: dev
23
24
  Requires-Dist: build; extra == 'dev'
25
+ Requires-Dist: ipykernel>=6.29.5; extra == 'dev'
24
26
  Requires-Dist: jupyter; extra == 'dev'
25
27
  Requires-Dist: pre-commit; extra == 'dev'
26
- Requires-Dist: ruff; extra == 'dev'
28
+ Requires-Dist: pytest; extra == 'dev'
29
+ Requires-Dist: ruff~=0.12.0; extra == 'dev'
27
30
  Requires-Dist: submitit>=1.5.2; extra == 'dev'
28
31
  Requires-Dist: twine; extra == 'dev'
29
32
  Requires-Dist: wandb>=0.19.8; extra == 'dev'
@@ -34,7 +37,7 @@ Description-Content-Type: text/markdown
34
37
  > Zero-Shot Time Series Forecasting with TabPFNv2
35
38
 
36
39
  [![PyPI version](https://badge.fury.io/py/tabpfn-time-series.svg)](https://badge.fury.io/py/tabpfn-time-series)
37
- [![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/liam-sbhoo/tabpfn-time-series/blob/main/demo.ipynb)
40
+ [![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PriorLabs/tabpfn-time-series/blob/main/demo.ipynb)
38
41
  [![Discord](https://img.shields.io/discord/1285598202732482621?color=7289da&label=Discord&logo=discord&logoColor=ffffff)](https://discord.com/channels/1285598202732482621/)
39
42
  [![arXiv](https://img.shields.io/badge/arXiv-2501.02945-b31b1b.svg)](https://arxiv.org/abs/2501.02945v3)
40
43
 
@@ -63,15 +66,33 @@ Concretely, we:
63
66
  For more details, please refer to our [paper](https://arxiv.org/abs/2501.02945v3).
64
67
  <!-- and our [poster](docs/tabpfn-ts-neurips-poster.pdf) (presented at NeurIPS 2024 TRL and TSALM workshops). -->
65
68
 
66
- ## 👉 **Why gives us a try?**
69
+ ## 👉 **Why give us a try?**
67
70
  - **Zero-shot forecasting**: this method is extremely fast and requires no training, making it highly accessible for experimenting with your own problems.
68
71
  - **Point and probabilistic forecasting**: it provides accurate point forecasts as well as probabilistic forecasts.
69
72
  - **Support for exogenous variables**: if you have exogenous variables, this method can seamlessly incorporate them into the forecasting model.
70
73
 
71
- On top of that, thanks to **[tabpfn-client](https://github.com/automl/tabpfn-client)** from **[Prior Labs](https://priorlabs.ai)**, you wont even need your own GPU to run fast inference with TabPFNv2. 😉 We have included `tabpfn-client` as the default engine in our implementation.
74
+ On top of that, thanks to **[tabpfn-client](https://github.com/automl/tabpfn-client)** from **[Prior Labs](https://priorlabs.ai)**, you won't even need your own GPU to run fast inference with TabPFNv2. 😉 We have included `tabpfn-client` as the default engine in our implementation.
75
+
76
+ ## ⚙️ Installation
77
+
78
+ You can install the package via pip:
79
+
80
+ ```bash
81
+ pip install tabpfn-time-series
82
+ ```
83
+
84
+ ### For Developers
85
+
86
+ To install the package in editable mode with all development dependencies, run the following command in your terminal:
87
+
88
+ ```bash
89
+ pip install -e ".[dev]"
90
+ # or with uv
91
+ uv pip install -e ".[dev]"
92
+ ```
72
93
 
73
94
  ## How to use it?
74
95
 
75
- [![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/liam-sbhoo/tabpfn-time-series/blob/main/demo.ipynb)
96
+ [![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PriorLabs/tabpfn-time-series/blob/main/demo.ipynb)
76
97
 
77
98
  The demo should explain it all. 😉
@@ -1,15 +1,15 @@
1
- tabpfn_time_series/__init__.py,sha256=3XGvQieVbONwhVtn1rITet6HNiTMWQTxHm2xLlGI5ew,314
1
+ tabpfn_time_series/__init__.py,sha256=cuoY5WBdBZM9RUM8bX4-Gyf5M9uQ-905lkQIDSyojTw,578
2
2
  tabpfn_time_series/data_preparation.py,sha256=iNW7sAnRkTgmzzOEHBhkkTwm_lQ3p_Q9xgAQ5PbkOts,5416
3
- tabpfn_time_series/defaults.py,sha256=u2_JnwxiZ5NNibzyNpsE63KuP3TcmOL1iAP8llZ2rJk,238
3
+ tabpfn_time_series/defaults.py,sha256=ki1y38FR4zmbHWgRjcryA5T88GzNMwhlZC-sTRjuK2U,248
4
4
  tabpfn_time_series/plot.py,sha256=bwSYcWBanzPrUxXKFsbqG8fyGsOJZfgU2v3NsxzTSXo,6571
5
5
  tabpfn_time_series/predictor.py,sha256=JzuV34zERf1XDLacGzSFJb-o077qd7GlKC6lvD62EPk,1457
6
- tabpfn_time_series/tabpfn_worker.py,sha256=zvFwg4Dc01_m5emqmVITBr6W_cNZ04tMyntmj40pyPE,8299
6
+ tabpfn_time_series/tabpfn_worker.py,sha256=NeYPX1XcPe3jF5yFPMUR_3Lq5hw4RpxVFTD7z_ikahA,9608
7
7
  tabpfn_time_series/features/__init__.py,sha256=lzdZWkEfntfg3ZHqNNbfbg-3o_VIzju0tebdRu3AzF4,421
8
8
  tabpfn_time_series/features/auto_features.py,sha256=3OqqY2h7umcoLjLx4hOXypLTjwzrMtd6cQKTNi83vrU,11561
9
9
  tabpfn_time_series/features/basic_features.py,sha256=OV3B__S30-CX88vGjwYQDWqAbJajQw80PxcnvJVUbm4,2955
10
10
  tabpfn_time_series/features/feature_generator_base.py,sha256=jtySWLJyX4E31v6CbX44EHa8cdz7OMyauf4ltNEQeAQ,534
11
- tabpfn_time_series/features/feature_transformer.py,sha256=mUsbnPUhJ4lPcnGWk8Ag1hgCOE1V5I0iQRT4VFgQEso,1763
12
- tabpfn_time_series-1.0.0.dist-info/METADATA,sha256=CvXqIOHNTKyd-zpCednsqa3FloPk6lFJ4ISG0eSEWx4,4434
13
- tabpfn_time_series-1.0.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
14
- tabpfn_time_series-1.0.0.dist-info/licenses/LICENSE.txt,sha256=iwhPL7kIWQG6gyLZZwIMDItGrNgxMDIq9itxkUSMapY,11345
15
- tabpfn_time_series-1.0.0.dist-info/RECORD,,
11
+ tabpfn_time_series/features/feature_transformer.py,sha256=JgPjAUiNBJGfx5PUOFVOZc467AW89Fe-JE08dKQ0AjY,1763
12
+ tabpfn_time_series-1.0.1.dist-info/METADATA,sha256=mSPGoTbkVFjvJswthF0DAUvql7NgHD9wRzVkHu1ia1I,4873
13
+ tabpfn_time_series-1.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
14
+ tabpfn_time_series-1.0.1.dist-info/licenses/LICENSE.txt,sha256=iwhPL7kIWQG6gyLZZwIMDItGrNgxMDIq9itxkUSMapY,11345
15
+ tabpfn_time_series-1.0.1.dist-info/RECORD,,