tabpfn-time-series 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tabpfn_time_series/predictor.py +6 -1
- tabpfn_time_series/tabpfn_worker.py +87 -21
- {tabpfn_time_series-1.0.2.dist-info → tabpfn_time_series-1.0.4.dist-info}/METADATA +2 -1
- {tabpfn_time_series-1.0.2.dist-info → tabpfn_time_series-1.0.4.dist-info}/RECORD +6 -6
- {tabpfn_time_series-1.0.2.dist-info → tabpfn_time_series-1.0.4.dist-info}/WHEEL +0 -0
- {tabpfn_time_series-1.0.2.dist-info → tabpfn_time_series-1.0.4.dist-info}/licenses/LICENSE.txt +0 -0
tabpfn_time_series/predictor.py
CHANGED
@@ -2,7 +2,11 @@ import logging
|
|
2
2
|
from enum import Enum
|
3
3
|
|
4
4
|
from tabpfn_time_series.ts_dataframe import TimeSeriesDataFrame
|
5
|
-
from tabpfn_time_series.tabpfn_worker import TabPFNClient, LocalTabPFN, MockTabPFN
|
5
|
+
from tabpfn_time_series.tabpfn_worker import (
|
6
|
+
TabPFNClient,
|
7
|
+
LocalTabPFN,
|
8
|
+
MockTabPFN,
|
9
|
+
)
|
6
10
|
from tabpfn_time_series.defaults import TABPFN_TS_DEFAULT_CONFIG
|
7
11
|
|
8
12
|
logger = logging.getLogger(__name__)
|
@@ -29,6 +33,7 @@ class TabPFNTimeSeriesPredictor:
|
|
29
33
|
TabPFNMode.LOCAL: lambda: LocalTabPFN(config),
|
30
34
|
TabPFNMode.MOCK: lambda: MockTabPFN(config),
|
31
35
|
}
|
36
|
+
self.tabpfn_mode = tabpfn_mode
|
32
37
|
self.tabpfn_worker = worker_mapping[tabpfn_mode]()
|
33
38
|
|
34
39
|
def predict(
|
tabpfn_time_series/tabpfn_worker.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
|
+
import contextvars
|
1
2
|
import logging
|
2
3
|
from abc import ABC, abstractmethod
|
3
|
-
from joblib import Parallel, delayed
|
4
|
+
from joblib import Parallel, delayed, parallel_config
|
5
|
+
import backoff
|
4
6
|
|
5
7
|
from tqdm import tqdm
|
6
8
|
import pandas as pd
|
@@ -14,6 +16,9 @@ from tabpfn_time_series.defaults import TABPFN_TS_DEFAULT_QUANTILE_CONFIG
|
|
14
16
|
|
15
17
|
logger = logging.getLogger(__name__)
|
16
18
|
|
19
|
+
# Per-call attempt counter, isolated per thread & task
|
20
|
+
_retry_attempts = contextvars.ContextVar("predict_attempts", default=0)
|
21
|
+
|
17
22
|
|
18
23
|
class TabPFNWorker(ABC):
|
19
24
|
def __init__(
|
@@ -29,24 +34,7 @@ class TabPFNWorker(ABC):
|
|
29
34
|
train_tsdf: TimeSeriesDataFrame,
|
30
35
|
test_tsdf: TimeSeriesDataFrame,
|
31
36
|
):
|
32
|
-
predictions = Parallel(
|
33
|
-
n_jobs=self.num_workers,
|
34
|
-
backend="loky",
|
35
|
-
)(
|
36
|
-
delayed(self._prediction_routine)(
|
37
|
-
item_id,
|
38
|
-
train_tsdf.loc[item_id],
|
39
|
-
test_tsdf.loc[item_id],
|
40
|
-
)
|
41
|
-
for item_id in tqdm(train_tsdf.item_ids, desc="Predicting time series")
|
42
|
-
)
|
43
|
-
|
44
|
-
predictions = pd.concat(predictions)
|
45
|
-
|
46
|
-
# Sort predictions according to original item_ids order (important for MASE and WQL calculation)
|
47
|
-
predictions = predictions.loc[train_tsdf.item_ids]
|
48
|
-
|
49
|
-
return TimeSeriesDataFrame(predictions)
|
37
|
+
raise NotImplementedError("Predict method must be implemented in subclass")
|
50
38
|
|
51
39
|
def _prediction_routine(
|
52
40
|
self,
|
@@ -54,8 +42,6 @@ class TabPFNWorker(ABC):
|
|
54
42
|
single_train_tsdf: TimeSeriesDataFrame,
|
55
43
|
single_test_tsdf: TimeSeriesDataFrame,
|
56
44
|
) -> pd.DataFrame:
|
57
|
-
# logger.debug(f"Predicting on item_id: {item_id}")
|
58
|
-
|
59
45
|
test_index = single_test_tsdf.index
|
60
46
|
train_X, train_y = split_time_series_to_X_y(single_train_tsdf.copy())
|
61
47
|
test_X, _ = split_time_series_to_X_y(single_test_tsdf.copy())
|
@@ -116,6 +102,39 @@ class TabPFNWorker(ABC):
|
|
116
102
|
return result
|
117
103
|
|
118
104
|
|
105
|
+
def _reset_attempts(_details=None):
|
106
|
+
"""Convenience function to reset the attempt counter."""
|
107
|
+
_retry_attempts.set(0)
|
108
|
+
|
109
|
+
|
110
|
+
def _predict_giveup_mixed(exc: Exception) -> bool:
|
111
|
+
"""Determine whether to give up on a prediction call or not.
|
112
|
+
|
113
|
+
Returns:
|
114
|
+
True if the prediction call should be given up on, False otherwise.
|
115
|
+
"""
|
116
|
+
if _is_tabpfn_gcs_429(exc):
|
117
|
+
return False
|
118
|
+
|
119
|
+
# Stop after first retry for non-429
|
120
|
+
return _retry_attempts.get() >= 2
|
121
|
+
|
122
|
+
|
123
|
+
def _is_tabpfn_gcs_429(err: Exception) -> bool:
|
124
|
+
"""Determine if an error is a 429 error raised from TabPFN API
|
125
|
+
and relates to GCS 429 errors.
|
126
|
+
|
127
|
+
Returns:
|
128
|
+
True if the error is a 429 error raised from TabPFN API.
|
129
|
+
"""
|
130
|
+
markers = (
|
131
|
+
"TooManyRequests: 429",
|
132
|
+
"rateLimitExceeded",
|
133
|
+
"cloud.google.com/storage/docs/gcs429",
|
134
|
+
)
|
135
|
+
return any(m in str(err) for m in markers)
|
136
|
+
|
137
|
+
|
119
138
|
class TabPFNClient(TabPFNWorker):
|
120
139
|
def __init__(
|
121
140
|
self,
|
@@ -135,6 +154,53 @@ class TabPFNClient(TabPFNWorker):
|
|
135
154
|
|
136
155
|
super().__init__(config, num_workers)
|
137
156
|
|
157
|
+
def predict(
|
158
|
+
self,
|
159
|
+
train_tsdf: TimeSeriesDataFrame,
|
160
|
+
test_tsdf: TimeSeriesDataFrame,
|
161
|
+
):
|
162
|
+
# Run the predictions in parallel
|
163
|
+
with parallel_config(backend="threading"):
|
164
|
+
results = Parallel(
|
165
|
+
n_jobs=self.num_workers,
|
166
|
+
)(
|
167
|
+
delayed(self._prediction_routine)(
|
168
|
+
item_id,
|
169
|
+
train_tsdf.loc[item_id],
|
170
|
+
test_tsdf.loc[item_id],
|
171
|
+
)
|
172
|
+
for item_id in tqdm(train_tsdf.item_ids, desc="Predicting time series")
|
173
|
+
)
|
174
|
+
|
175
|
+
# Convert list to DataFrame
|
176
|
+
predictions = pd.concat(results)
|
177
|
+
|
178
|
+
# Sort predictions according to original item_ids order (important for MASE and WQL calculation)
|
179
|
+
predictions = predictions.loc[train_tsdf.item_ids]
|
180
|
+
|
181
|
+
return TimeSeriesDataFrame(predictions)
|
182
|
+
|
183
|
+
@backoff.on_exception(
|
184
|
+
backoff.expo,
|
185
|
+
Exception,
|
186
|
+
base=1,
|
187
|
+
factor=2,
|
188
|
+
max_tries=5,
|
189
|
+
jitter=backoff.full_jitter,
|
190
|
+
giveup=_predict_giveup_mixed,
|
191
|
+
on_success=_reset_attempts,
|
192
|
+
)
|
193
|
+
def _prediction_routine(
|
194
|
+
self,
|
195
|
+
item_id: str,
|
196
|
+
single_train_tsdf: TimeSeriesDataFrame,
|
197
|
+
single_test_tsdf: TimeSeriesDataFrame,
|
198
|
+
) -> pd.DataFrame:
|
199
|
+
# Increment attempt count at start of each try
|
200
|
+
_retry_attempts.set(_retry_attempts.get() + 1)
|
201
|
+
|
202
|
+
return super()._prediction_routine(item_id, single_train_tsdf, single_test_tsdf)
|
203
|
+
|
138
204
|
def _get_tabpfn_engine(self):
|
139
205
|
from tabpfn_client import TabPFNRegressor
|
140
206
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: tabpfn_time_series
|
3
|
-
Version: 1.0.2
|
3
|
+
Version: 1.0.4
|
4
4
|
Summary: Zero-shot time series forecasting with TabPFNv2
|
5
5
|
Project-URL: Homepage, https://github.com/liam-sbhoo/tabpfn-time-series
|
6
6
|
Project-URL: Bug Tracker, https://github.com/liam-sbhoo/tabpfn-time-series/issues
|
@@ -10,6 +10,7 @@ Classifier: License :: OSI Approved :: Apache Software License
|
|
10
10
|
Classifier: Operating System :: OS Independent
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
12
12
|
Requires-Python: >=3.10
|
13
|
+
Requires-Dist: backoff>=2.2.1
|
13
14
|
Requires-Dist: datasets>=4.0
|
14
15
|
Requires-Dist: gluonts>=0.16.0
|
15
16
|
Requires-Dist: pandas<2.2.0,>=2.1.2
|
@@ -2,15 +2,15 @@ tabpfn_time_series/__init__.py,sha256=XJXSKqWp3AF9mAaWi-4KCgHQG7NzNTaBkLOYOMxvhS
|
|
2
2
|
tabpfn_time_series/data_preparation.py,sha256=wWjSaKgV9KqKonMtSuDbYnW59ixflrScKIP_HSJ_MlA,5427
|
3
3
|
tabpfn_time_series/defaults.py,sha256=ki1y38FR4zmbHWgRjcryA5T88GzNMwhlZC-sTRjuK2U,248
|
4
4
|
tabpfn_time_series/plot.py,sha256=UXgLR2S94vi-vv1ArQKI6uYl_QwSAwAau5jFzGmQ7hw,6582
|
5
|
-
tabpfn_time_series/predictor.py,sha256=
|
6
|
-
tabpfn_time_series/tabpfn_worker.py,sha256=
|
5
|
+
tabpfn_time_series/predictor.py,sha256=2wnBAHfU5sOSJoHUm65Ej8tJjA4jGYP8yHvWeo1MzyA,1523
|
6
|
+
tabpfn_time_series/tabpfn_worker.py,sha256=k6td4Ml0E3Xr1gERze-S0kyvBB6q_hbLMzvSurdaSp0,11589
|
7
7
|
tabpfn_time_series/ts_dataframe.py,sha256=X94mssw_mSFedjplG55hjwTzKj8mM3VwWynveX3fegA,52834
|
8
8
|
tabpfn_time_series/features/__init__.py,sha256=lzdZWkEfntfg3ZHqNNbfbg-3o_VIzju0tebdRu3AzF4,421
|
9
9
|
tabpfn_time_series/features/auto_features.py,sha256=3OqqY2h7umcoLjLx4hOXypLTjwzrMtd6cQKTNi83vrU,11561
|
10
10
|
tabpfn_time_series/features/basic_features.py,sha256=OV3B__S30-CX88vGjwYQDWqAbJajQw80PxcnvJVUbm4,2955
|
11
11
|
tabpfn_time_series/features/feature_generator_base.py,sha256=jtySWLJyX4E31v6CbX44EHa8cdz7OMyauf4ltNEQeAQ,534
|
12
12
|
tabpfn_time_series/features/feature_transformer.py,sha256=JzxswTGRGlt00QoYFyvAILlUVD68njdvoU3v-phnyi8,1774
|
13
|
-
tabpfn_time_series-1.0.
|
14
|
-
tabpfn_time_series-1.0.
|
15
|
-
tabpfn_time_series-1.0.
|
16
|
-
tabpfn_time_series-1.0.
|
13
|
+
tabpfn_time_series-1.0.4.dist-info/METADATA,sha256=n7fvApkQVQYw_N4aFC6wuOaMxNsr5Lqa9l7nEyhuL6g,4947
|
14
|
+
tabpfn_time_series-1.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
15
|
+
tabpfn_time_series-1.0.4.dist-info/licenses/LICENSE.txt,sha256=iwhPL7kIWQG6gyLZZwIMDItGrNgxMDIq9itxkUSMapY,11345
|
16
|
+
tabpfn_time_series-1.0.4.dist-info/RECORD,,
|
File without changes
|
{tabpfn_time_series-1.0.2.dist-info → tabpfn_time_series-1.0.4.dist-info}/licenses/LICENSE.txt
RENAMED
File without changes
|