tabpfn-time-series 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,11 @@ import logging
2
2
  from enum import Enum
3
3
 
4
4
  from tabpfn_time_series.ts_dataframe import TimeSeriesDataFrame
5
- from tabpfn_time_series.tabpfn_worker import TabPFNClient, LocalTabPFN, MockTabPFN
5
+ from tabpfn_time_series.tabpfn_worker import (
6
+ TabPFNClient,
7
+ LocalTabPFN,
8
+ MockTabPFN,
9
+ )
6
10
  from tabpfn_time_series.defaults import TABPFN_TS_DEFAULT_CONFIG
7
11
 
8
12
  logger = logging.getLogger(__name__)
@@ -29,6 +33,7 @@ class TabPFNTimeSeriesPredictor:
29
33
  TabPFNMode.LOCAL: lambda: LocalTabPFN(config),
30
34
  TabPFNMode.MOCK: lambda: MockTabPFN(config),
31
35
  }
36
+ self.tabpfn_mode = tabpfn_mode
32
37
  self.tabpfn_worker = worker_mapping[tabpfn_mode]()
33
38
 
34
39
  def predict(
@@ -1,6 +1,8 @@
1
+ import contextvars
1
2
  import logging
2
3
  from abc import ABC, abstractmethod
3
- from joblib import Parallel, delayed
4
+ from joblib import Parallel, delayed, parallel_config
5
+ import backoff
4
6
 
5
7
  from tqdm import tqdm
6
8
  import pandas as pd
@@ -14,6 +16,9 @@ from tabpfn_time_series.defaults import TABPFN_TS_DEFAULT_QUANTILE_CONFIG
14
16
 
15
17
  logger = logging.getLogger(__name__)
16
18
 
19
+ # Per-call attempt counter, isolated per thread & task
20
+ _retry_attempts = contextvars.ContextVar("predict_attempts", default=0)
21
+
17
22
 
18
23
  class TabPFNWorker(ABC):
19
24
  def __init__(
@@ -29,24 +34,7 @@ class TabPFNWorker(ABC):
29
34
  train_tsdf: TimeSeriesDataFrame,
30
35
  test_tsdf: TimeSeriesDataFrame,
31
36
  ):
32
- predictions = Parallel(
33
- n_jobs=self.num_workers,
34
- backend="loky",
35
- )(
36
- delayed(self._prediction_routine)(
37
- item_id,
38
- train_tsdf.loc[item_id],
39
- test_tsdf.loc[item_id],
40
- )
41
- for item_id in tqdm(train_tsdf.item_ids, desc="Predicting time series")
42
- )
43
-
44
- predictions = pd.concat(predictions)
45
-
46
- # Sort predictions according to original item_ids order (important for MASE and WQL calculation)
47
- predictions = predictions.loc[train_tsdf.item_ids]
48
-
49
- return TimeSeriesDataFrame(predictions)
37
+ raise NotImplementedError("Predict method must be implemented in subclass")
50
38
 
51
39
  def _prediction_routine(
52
40
  self,
@@ -54,8 +42,6 @@ class TabPFNWorker(ABC):
54
42
  single_train_tsdf: TimeSeriesDataFrame,
55
43
  single_test_tsdf: TimeSeriesDataFrame,
56
44
  ) -> pd.DataFrame:
57
- # logger.debug(f"Predicting on item_id: {item_id}")
58
-
59
45
  test_index = single_test_tsdf.index
60
46
  train_X, train_y = split_time_series_to_X_y(single_train_tsdf.copy())
61
47
  test_X, _ = split_time_series_to_X_y(single_test_tsdf.copy())
@@ -116,6 +102,39 @@ class TabPFNWorker(ABC):
116
102
  return result
117
103
 
118
104
 
105
+ def _reset_attempts(_details=None):
106
+ """Convenience function to reset the attempt counter."""
107
+ _retry_attempts.set(0)
108
+
109
+
110
+ def _predict_giveup_mixed(exc: Exception) -> bool:
111
+ """Determine whether to give up on a prediction call or not.
112
+
113
+ Returns:
114
+ True if the prediction call should be given up on, False otherwise.
115
+ """
116
+ if _is_tabpfn_gcs_429(exc):
117
+ return False
118
+
119
+ # Stop after first retry for non-429
120
+ return _retry_attempts.get() >= 2
121
+
122
+
123
+ def _is_tabpfn_gcs_429(err: Exception) -> bool:
124
+ """Determine if an error is a 429 error raised from TabPFN API
125
+ and relates to GCS 429 errors.
126
+
127
+ Returns:
128
+ True if the error is a 429 error raised from TabPFN API.
129
+ """
130
+ markers = (
131
+ "TooManyRequests: 429",
132
+ "rateLimitExceeded",
133
+ "cloud.google.com/storage/docs/gcs429",
134
+ )
135
+ return any(m in str(err) for m in markers)
136
+
137
+
119
138
  class TabPFNClient(TabPFNWorker):
120
139
  def __init__(
121
140
  self,
@@ -135,6 +154,53 @@ class TabPFNClient(TabPFNWorker):
135
154
 
136
155
  super().__init__(config, num_workers)
137
156
 
157
+ def predict(
158
+ self,
159
+ train_tsdf: TimeSeriesDataFrame,
160
+ test_tsdf: TimeSeriesDataFrame,
161
+ ):
162
+ # Run the predictions in parallel
163
+ with parallel_config(backend="threading"):
164
+ results = Parallel(
165
+ n_jobs=self.num_workers,
166
+ )(
167
+ delayed(self._prediction_routine)(
168
+ item_id,
169
+ train_tsdf.loc[item_id],
170
+ test_tsdf.loc[item_id],
171
+ )
172
+ for item_id in tqdm(train_tsdf.item_ids, desc="Predicting time series")
173
+ )
174
+
175
+ # Convert list to DataFrame
176
+ predictions = pd.concat(results)
177
+
178
+ # Sort predictions according to original item_ids order (important for MASE and WQL calculation)
179
+ predictions = predictions.loc[train_tsdf.item_ids]
180
+
181
+ return TimeSeriesDataFrame(predictions)
182
+
183
+ @backoff.on_exception(
184
+ backoff.expo,
185
+ Exception,
186
+ base=1,
187
+ factor=2,
188
+ max_tries=5,
189
+ jitter=backoff.full_jitter,
190
+ giveup=_predict_giveup_mixed,
191
+ on_success=_reset_attempts,
192
+ )
193
+ def _prediction_routine(
194
+ self,
195
+ item_id: str,
196
+ single_train_tsdf: TimeSeriesDataFrame,
197
+ single_test_tsdf: TimeSeriesDataFrame,
198
+ ) -> pd.DataFrame:
199
+ # Increment attempt count at start of each try
200
+ _retry_attempts.set(_retry_attempts.get() + 1)
201
+
202
+ return super()._prediction_routine(item_id, single_train_tsdf, single_test_tsdf)
203
+
138
204
  def _get_tabpfn_engine(self):
139
205
  from tabpfn_client import TabPFNRegressor
140
206
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tabpfn_time_series
3
- Version: 1.0.2
3
+ Version: 1.0.4
4
4
  Summary: Zero-shot time series forecasting with TabPFNv2
5
5
  Project-URL: Homepage, https://github.com/liam-sbhoo/tabpfn-time-series
6
6
  Project-URL: Bug Tracker, https://github.com/liam-sbhoo/tabpfn-time-series/issues
@@ -10,6 +10,7 @@ Classifier: License :: OSI Approved :: Apache Software License
10
10
  Classifier: Operating System :: OS Independent
11
11
  Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10
13
+ Requires-Dist: backoff>=2.2.1
13
14
  Requires-Dist: datasets>=4.0
14
15
  Requires-Dist: gluonts>=0.16.0
15
16
  Requires-Dist: pandas<2.2.0,>=2.1.2
@@ -2,15 +2,15 @@ tabpfn_time_series/__init__.py,sha256=XJXSKqWp3AF9mAaWi-4KCgHQG7NzNTaBkLOYOMxvhS
2
2
  tabpfn_time_series/data_preparation.py,sha256=wWjSaKgV9KqKonMtSuDbYnW59ixflrScKIP_HSJ_MlA,5427
3
3
  tabpfn_time_series/defaults.py,sha256=ki1y38FR4zmbHWgRjcryA5T88GzNMwhlZC-sTRjuK2U,248
4
4
  tabpfn_time_series/plot.py,sha256=UXgLR2S94vi-vv1ArQKI6uYl_QwSAwAau5jFzGmQ7hw,6582
5
- tabpfn_time_series/predictor.py,sha256=UDhvH7reB3HAuxOyggE4yl2ntQXivvxgXCN_RhoddHc,1467
6
- tabpfn_time_series/tabpfn_worker.py,sha256=ZlJrU0O1dxfKh_As5Le4phm0P4RCDeXKpqp9X-h5bQs,9619
5
+ tabpfn_time_series/predictor.py,sha256=2wnBAHfU5sOSJoHUm65Ej8tJjA4jGYP8yHvWeo1MzyA,1523
6
+ tabpfn_time_series/tabpfn_worker.py,sha256=k6td4Ml0E3Xr1gERze-S0kyvBB6q_hbLMzvSurdaSp0,11589
7
7
  tabpfn_time_series/ts_dataframe.py,sha256=X94mssw_mSFedjplG55hjwTzKj8mM3VwWynveX3fegA,52834
8
8
  tabpfn_time_series/features/__init__.py,sha256=lzdZWkEfntfg3ZHqNNbfbg-3o_VIzju0tebdRu3AzF4,421
9
9
  tabpfn_time_series/features/auto_features.py,sha256=3OqqY2h7umcoLjLx4hOXypLTjwzrMtd6cQKTNi83vrU,11561
10
10
  tabpfn_time_series/features/basic_features.py,sha256=OV3B__S30-CX88vGjwYQDWqAbJajQw80PxcnvJVUbm4,2955
11
11
  tabpfn_time_series/features/feature_generator_base.py,sha256=jtySWLJyX4E31v6CbX44EHa8cdz7OMyauf4ltNEQeAQ,534
12
12
  tabpfn_time_series/features/feature_transformer.py,sha256=JzxswTGRGlt00QoYFyvAILlUVD68njdvoU3v-phnyi8,1774
13
- tabpfn_time_series-1.0.2.dist-info/METADATA,sha256=tnwSlCc7EscY_9utB6ouU5FUlQulml4wIz18rE6QwJA,4917
14
- tabpfn_time_series-1.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
- tabpfn_time_series-1.0.2.dist-info/licenses/LICENSE.txt,sha256=iwhPL7kIWQG6gyLZZwIMDItGrNgxMDIq9itxkUSMapY,11345
16
- tabpfn_time_series-1.0.2.dist-info/RECORD,,
13
+ tabpfn_time_series-1.0.4.dist-info/METADATA,sha256=n7fvApkQVQYw_N4aFC6wuOaMxNsr5Lqa9l7nEyhuL6g,4947
14
+ tabpfn_time_series-1.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
+ tabpfn_time_series-1.0.4.dist-info/licenses/LICENSE.txt,sha256=iwhPL7kIWQG6gyLZZwIMDItGrNgxMDIq9itxkUSMapY,11345
16
+ tabpfn_time_series-1.0.4.dist-info/RECORD,,