falcon-tst 1.0.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,70 @@
1
+ Metadata-Version: 2.4
2
+ Name: falcon-tst
3
+ Version: 1.0.10
4
+ Summary: Python SDK for Falcon Studio prediction API.
5
+ Requires-Python: >=3.9
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: numpy>=1.23.0
8
+ Requires-Dist: requests>=2.31.0
9
+
10
+ # Falcon SDK
11
+
12
+ The official Python SDK for the Falcon-TST Prediction API. Falcon-TST is a family of large-scale time-series foundation models developed by Ant International.
13
+
14
+ ## Install
15
+
16
+ ```bash
17
+ pip install falcon-tst
18
+ ```
19
+
20
+ ## Usage
21
+
22
+ ```python
23
+ import numpy as np
24
+
25
+ from falcontst import FalconClient
26
+
27
+ client = FalconClient()
28
+
29
+ result = client.quantile_predict(
30
+ context=np.array([
31
+ [1.2, 3.4, 5.6, 7.8, 9.0],
32
+ [2.1, 4.3, 6.5, 8.7, 0.9],
33
+ ]),
34
+ prediction_length=3,
35
+ model_name="demo_model",
36
+ group_ids=np.array([0, 0]),
37
+ input_mask=np.array([
38
+ [1, 1, 1, 1, 1],
39
+ [1, 1, 1, 0, 0],
40
+ ]),
41
+ is_multivariate=False,
42
+ )
43
+
44
+ print(result)
45
+ ```
46
+
47
+ Batch prediction uses the same object fields as `quantile_predict`, but sends
48
+ multiple objects to the batch endpoint:
49
+
50
+ ```python
51
+ result = client.batch_predict(
52
+ [
53
+ {
54
+ "context": np.array([[1.2, 3.4, 5.6, 7.8, 9.0]]),
55
+ "prediction_length": 3,
56
+ "model_name": "demo_model",
57
+ "group_ids": np.array([0]),
58
+ "input_mask": np.array([[1, 1, 1, 1, 1]]),
59
+ "is_multivariate": False,
60
+ },
61
+ {
62
+ "context": np.array([[2.1, 4.3, 6.5, 8.7, 0.9]]),
63
+ "prediction_length": 3,
64
+ },
65
+ ]
66
+ )
67
+ ```
68
+
69
+ Each `context` should be a two-dimensional `numpy.ndarray`. For a single time
70
+ series, use `np.array([[1.2, 3.4, 5.6]])` instead of `np.array([1.2, 3.4, 5.6])`.
@@ -0,0 +1,61 @@
1
+ # Falcon SDK
2
+
3
+ The official Python SDK for the Falcon-TST Prediction API. Falcon-TST is a family of large-scale time-series foundation models developed by Ant International.
4
+
5
+ ## Install
6
+
7
+ ```bash
8
+ pip install falcon-tst
9
+ ```
10
+
11
+ ## Usage
12
+
13
+ ```python
14
+ import numpy as np
15
+
16
+ from falcontst import FalconClient
17
+
18
+ client = FalconClient()
19
+
20
+ result = client.quantile_predict(
21
+ context=np.array([
22
+ [1.2, 3.4, 5.6, 7.8, 9.0],
23
+ [2.1, 4.3, 6.5, 8.7, 0.9],
24
+ ]),
25
+ prediction_length=3,
26
+ model_name="demo_model",
27
+ group_ids=np.array([0, 0]),
28
+ input_mask=np.array([
29
+ [1, 1, 1, 1, 1],
30
+ [1, 1, 1, 0, 0],
31
+ ]),
32
+ is_multivariate=False,
33
+ )
34
+
35
+ print(result)
36
+ ```
37
+
38
+ Batch prediction uses the same object fields as `quantile_predict`, but sends
39
+ multiple objects to the batch endpoint:
40
+
41
+ ```python
42
+ result = client.batch_predict(
43
+ [
44
+ {
45
+ "context": np.array([[1.2, 3.4, 5.6, 7.8, 9.0]]),
46
+ "prediction_length": 3,
47
+ "model_name": "demo_model",
48
+ "group_ids": np.array([0]),
49
+ "input_mask": np.array([[1, 1, 1, 1, 1]]),
50
+ "is_multivariate": False,
51
+ },
52
+ {
53
+ "context": np.array([[2.1, 4.3, 6.5, 8.7, 0.9]]),
54
+ "prediction_length": 3,
55
+ },
56
+ ]
57
+ )
58
+ ```
59
+
60
+ Each `context` should be a two-dimensional `numpy.ndarray`. For a single time
61
+ series, use `np.array([[1.2, 3.4, 5.6]])` instead of `np.array([1.2, 3.4, 5.6])`.
@@ -0,0 +1,22 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "falcon-tst"
7
+ version = "1.0.10"
8
+ description = "Python SDK for Falcon Studio prediction API."
9
+ readme = "README.md"
10
+ requires-python = ">=3.9"
11
+ dependencies = [
12
+ "numpy>=1.23.0",
13
+ "requests>=2.31.0",
14
+ ]
15
+
16
+ [tool.setuptools.packages.find]
17
+ where = ["src"]
18
+
19
+ [dependency-groups]
20
+ dev = [
21
+ "twine>=6.2.0",
22
+ ]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,70 @@
1
+ Metadata-Version: 2.4
2
+ Name: falcon-tst
3
+ Version: 1.0.10
4
+ Summary: Python SDK for Falcon Studio prediction API.
5
+ Requires-Python: >=3.9
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: numpy>=1.23.0
8
+ Requires-Dist: requests>=2.31.0
9
+
10
+ # Falcon SDK
11
+
12
+ The official Python SDK for the Falcon-TST Prediction API. Falcon-TST is a family of large-scale time-series foundation models developed by Ant International.
13
+
14
+ ## Install
15
+
16
+ ```bash
17
+ pip install falcon-tst
18
+ ```
19
+
20
+ ## Usage
21
+
22
+ ```python
23
+ import numpy as np
24
+
25
+ from falcontst import FalconClient
26
+
27
+ client = FalconClient()
28
+
29
+ result = client.quantile_predict(
30
+ context=np.array([
31
+ [1.2, 3.4, 5.6, 7.8, 9.0],
32
+ [2.1, 4.3, 6.5, 8.7, 0.9],
33
+ ]),
34
+ prediction_length=3,
35
+ model_name="demo_model",
36
+ group_ids=np.array([0, 0]),
37
+ input_mask=np.array([
38
+ [1, 1, 1, 1, 1],
39
+ [1, 1, 1, 0, 0],
40
+ ]),
41
+ is_multivariate=False,
42
+ )
43
+
44
+ print(result)
45
+ ```
46
+
47
+ Batch prediction uses the same object fields as `quantile_predict`, but sends
48
+ multiple objects to the batch endpoint:
49
+
50
+ ```python
51
+ result = client.batch_predict(
52
+ [
53
+ {
54
+ "context": np.array([[1.2, 3.4, 5.6, 7.8, 9.0]]),
55
+ "prediction_length": 3,
56
+ "model_name": "demo_model",
57
+ "group_ids": np.array([0]),
58
+ "input_mask": np.array([[1, 1, 1, 1, 1]]),
59
+ "is_multivariate": False,
60
+ },
61
+ {
62
+ "context": np.array([[2.1, 4.3, 6.5, 8.7, 0.9]]),
63
+ "prediction_length": 3,
64
+ },
65
+ ]
66
+ )
67
+ ```
68
+
69
+ Each `context` should be a two-dimensional `numpy.ndarray`. For a single time
70
+ series, use `np.array([[1.2, 3.4, 5.6]])` instead of `np.array([1.2, 3.4, 5.6])`.
@@ -0,0 +1,10 @@
1
+ README.md
2
+ pyproject.toml
3
+ src/falcon_tst.egg-info/PKG-INFO
4
+ src/falcon_tst.egg-info/SOURCES.txt
5
+ src/falcon_tst.egg-info/dependency_links.txt
6
+ src/falcon_tst.egg-info/requires.txt
7
+ src/falcon_tst.egg-info/top_level.txt
8
+ src/falcontst/__init__.py
9
+ src/falcontst/client.py
10
+ tests/test_client.py
@@ -0,0 +1,2 @@
1
+ numpy>=1.23.0
2
+ requests>=2.31.0
@@ -0,0 +1 @@
1
+ falcontst
@@ -0,0 +1,3 @@
1
+ from .client import FalconAPIError, FalconClient
2
+
3
+ __all__ = ["FalconAPIError", "FalconClient"]
@@ -0,0 +1,157 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import Any, Mapping, Optional, Sequence, Union
5
+
6
+ import numpy as np
7
+ import requests
8
+
9
+
10
# Default endpoints used when FalconClient is constructed without explicit URLs.
# NOTE(review): the host name contains "-pre"; confirm this is the environment
# SDK users are meant to reach by default.
DEFAULT_PREDICT_URL = (
    "https://falconstudio-pre.antglobal.com"
    "/falconstudio/api/v1/openapi/predict"
)
DEFAULT_BATCH_PREDICT_URL = (
    "https://falconstudio-pre.antglobal.com"
    "/falconstudio/api/v1/openapi/batch-predict"
)


class FalconAPIError(Exception):
    """Raised when Falcon Studio API returns a non-success code."""


class FalconClient:
    """Client for Falcon Studio prediction API.

    Payloads are sent as JSON through ``session.post``. A response is accepted
    only when it decodes to a JSON object whose ``code`` field equals ``200``;
    anything else raises :class:`FalconAPIError`.
    """

    def __init__(
        self,
        endpoint: str = DEFAULT_PREDICT_URL,
        batch_endpoint: str = DEFAULT_BATCH_PREDICT_URL,
        timeout: float = 30.0,
        session: "Optional[requests.Session]" = None,
        verify: Optional[Union[bool, str]] = None,
    ) -> None:
        """Store the connection settings.

        Args:
            endpoint: URL used by :meth:`quantile_predict`.
            batch_endpoint: URL used by :meth:`batch_predict`.
            timeout: Forwarded as ``timeout`` on every POST.
            session: Object whose ``post`` method performs the request. A new
                ``requests.Session`` is created when this is ``None``.
            verify: Forwarded as ``verify`` on every POST, but only when it is
                not ``None``; otherwise the session's own setting applies.
        """
        self.endpoint = endpoint
        self.batch_endpoint = batch_endpoint
        self.timeout = timeout
        # Test for None explicitly: a truthiness test would silently replace a
        # caller-supplied session object that happens to evaluate as falsy.
        self.session = session if session is not None else requests.Session()
        self.verify = verify

    def quantile_predict(
        self,
        context: np.ndarray,
        prediction_length: int,
        model_name: Optional[str] = None,
        group_ids: Optional[np.ndarray] = None,
        input_mask: Optional[np.ndarray] = None,
        is_multivariate: bool = False,
    ) -> Any:
        """Quantile forecast inference.

        Builds a single payload from the arguments and posts it to
        ``self.endpoint``.

        Args:
            context: ``numpy.ndarray`` of input values; NaN entries are sent
                as JSON ``null``.
            prediction_length: Positive integer (built-in or NumPy integer).
            model_name: Included in the payload only when not ``None``.
            group_ids: Included in the payload only when not ``None``.
            input_mask: Included in the payload only when not ``None``.
            is_multivariate: Always included in the payload.

        Returns:
            ``{"prob_prediction": <the response's "data" field>}``.

        Raises:
            TypeError: If an array argument is not a ``numpy.ndarray`` or
                ``prediction_length`` is not an integer.
            ValueError: If ``prediction_length`` is not greater than 0.
            FalconAPIError: If the response is not a successful JSON object.
        """
        payload = self._build_predict_payload(
            context=context,
            prediction_length=prediction_length,
            model_name=model_name,
            group_ids=group_ids,
            input_mask=input_mask,
            is_multivariate=is_multivariate,
        )

        return self._post(self.endpoint, payload)

    def batch_predict(self, objects: Sequence[Mapping[str, Any]]) -> Any:
        """Batch quantile forecast inference.

        Each mapping must provide ``"context"`` and ``"prediction_length"``;
        ``"model_name"``, ``"group_ids"``, ``"input_mask"`` and
        ``"is_multivariate"`` are optional. All items are validated and
        converted before the single POST to ``self.batch_endpoint``.

        Args:
            objects: Sequence of mappings (a ``str``/``bytes`` is rejected).

        Returns:
            ``{"prob_prediction": <the response's "data" field>}``.

        Raises:
            TypeError: If ``objects`` is not a sequence, an item is not a
                mapping, or a field has the wrong type.
            KeyError: If an item lacks ``"context"`` or ``"prediction_length"``.
            ValueError: If a ``prediction_length`` is not greater than 0.
            FalconAPIError: If the response is not a successful JSON object.
        """
        # str and bytes are Sequences too, but never a valid batch.
        if not isinstance(objects, Sequence) or isinstance(objects, (str, bytes)):
            raise TypeError(
                f"Expected objects to be a sequence, got {type(objects).__name__}"
            )

        payload = []
        for item in objects:
            if not isinstance(item, Mapping):
                raise TypeError(
                    f"Expected batch item to be a mapping, got {type(item).__name__}"
                )

            payload.append(
                self._build_predict_payload(
                    context=item["context"],
                    prediction_length=item["prediction_length"],
                    model_name=item.get("model_name"),
                    group_ids=item.get("group_ids"),
                    input_mask=item.get("input_mask"),
                    is_multivariate=item.get("is_multivariate", False),
                )
            )

        return self._post(self.batch_endpoint, payload)

    def _build_predict_payload(
        self,
        context: np.ndarray,
        prediction_length: int,
        model_name: Optional[str] = None,
        group_ids: Optional[np.ndarray] = None,
        input_mask: Optional[np.ndarray] = None,
        is_multivariate: bool = False,
    ) -> dict[str, Any]:
        """Validate the arguments and return the JSON-ready request object.

        ``prediction_length`` is checked before any array is converted, so an
        invalid length is reported first.
        """
        # bool is a subclass of int, so it must be excluded explicitly or
        # ``True`` would be sent as a JSON boolean. NumPy integer scalars are
        # accepted because callers of this SDK work with NumPy data.
        if isinstance(prediction_length, bool) or not isinstance(
            prediction_length, (int, np.integer)
        ):
            raise TypeError(
                f"Expected prediction_length to be int, got {type(prediction_length).__name__}"
            )
        # Normalise to a built-in int: NumPy scalars are not JSON serialisable.
        prediction_length = int(prediction_length)
        if prediction_length <= 0:
            raise ValueError("prediction_length must be greater than 0")

        payload = {
            "context": self._array_to_list(context),
            "prediction_length": prediction_length,
            "is_multivariate": is_multivariate,
        }

        # Optional fields are omitted entirely rather than sent as null.
        if model_name is not None:
            payload["model_name"] = model_name
        if group_ids is not None:
            payload["group_ids"] = self._array_to_list(group_ids)
        if input_mask is not None:
            payload["input_mask"] = self._array_to_list(input_mask)

        return payload

    def _post(self, endpoint: str, payload: Any) -> dict[str, Any]:
        """POST ``payload`` as JSON to ``endpoint`` and return the formatted result.

        ``raise_for_status()`` is called on the response before its body is
        decoded, so whatever that call raises propagates to the caller.
        """
        request_kwargs = {
            "json": payload,
            "timeout": self.timeout,
        }
        # Only pass ``verify`` when configured so the session default is kept.
        if self.verify is not None:
            request_kwargs["verify"] = self.verify

        response = self.session.post(endpoint, **request_kwargs)
        response.raise_for_status()

        return self._format_response(response.json())

    @staticmethod
    def _array_to_list(array: np.ndarray) -> list[Any]:
        """Convert a ``numpy.ndarray`` to nested lists with NaN mapped to ``None``.

        Raises:
            TypeError: If ``array`` is not a ``numpy.ndarray``.
        """
        if not isinstance(array, np.ndarray):
            raise TypeError(f"Expected numpy.ndarray, got {type(array).__name__}")
        return FalconClient._replace_nan(array.tolist())

    @staticmethod
    def _replace_nan(value: Any) -> Any:
        """Recursively replace float NaN with ``None`` inside nested lists.

        NOTE(review): infinities are passed through unchanged; confirm how the
        transport and the server treat them.
        """
        if isinstance(value, list):
            return [FalconClient._replace_nan(item) for item in value]
        if isinstance(value, float) and math.isnan(value):
            return None
        return value

    @staticmethod
    def _format_response(data: Any) -> dict[str, Any]:
        """Check the decoded response and wrap its ``data`` field.

        Raises:
            FalconAPIError: If ``data`` is not a dict or its ``code`` is not 200.
        """
        if not isinstance(data, dict):
            raise FalconAPIError(f"Unexpected response format: {data}")

        code = data.get("code")
        message = data.get("message", "")
        if code != 200:
            raise FalconAPIError(f"Falcon API error: code={code}, message={message}")

        return {"prob_prediction": data.get("data")}
@@ -0,0 +1,238 @@
1
+ import numpy as np
2
+ import pytest
3
+
4
+ from falcontst import FalconAPIError, FalconClient
5
+
6
+
7
class DummyResponse:
    """Stand-in for an HTTP response that always reports success."""

    def __init__(self, data):
        # Canned body handed back verbatim by json().
        self.data = data

    def raise_for_status(self):
        """Do nothing: this fake response never signals an HTTP error."""
        return None

    def json(self):
        """Return the canned body given at construction time."""
        body = self.data
        return body
16
+
17
+
18
class DummySession:
    """Fake session that records every POST and answers with canned data."""

    def __init__(self, data):
        self.data = data
        # One dict per post() call: the URL plus all keyword arguments.
        self.calls = []

    def post(self, url, **kwargs):
        """Record the call and return a DummyResponse wrapping ``self.data``."""
        record = {"url": url}
        record.update(kwargs)
        self.calls.append(record)
        return DummyResponse(self.data)
26
+
27
+
28
def test_quantile_predict_posts_expected_payload_and_returns_json():
    """quantile_predict posts the converted arrays and wraps the response data."""
    fake_session = DummySession(
        {
            "code": 200,
            "message": "success",
            "data": [[[1.6039936542510986, -0.09775638580322266, -1.3819665908813477]]],
        }
    )
    sdk = FalconClient(
        endpoint="https://example.test/predict",
        timeout=5,
        session=fake_session,
    )

    series = np.array([[1.2, 3.4, 5.6, 7.8, 9.0], [2.1, 4.3, 6.5, 8.7, 0.9]])
    mask = np.array([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
    outcome = sdk.quantile_predict(
        context=series,
        prediction_length=3,
        model_name="demo_model",
        group_ids=np.array([0, 0]),
        input_mask=mask,
        is_multivariate=False,
    )

    expected_body = {
        "context": [[1.2, 3.4, 5.6, 7.8, 9.0], [2.1, 4.3, 6.5, 8.7, 0.9]],
        "prediction_length": 3,
        "model_name": "demo_model",
        "group_ids": [0, 0],
        "input_mask": [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]],
        "is_multivariate": False,
    }
    expected_call = {
        "url": "https://example.test/predict",
        "json": expected_body,
        "timeout": 5,
    }
    assert fake_session.calls == [expected_call]

    expected_result = {
        "prob_prediction": [
            [[1.6039936542510986, -0.09775638580322266, -1.3819665908813477]]
        ]
    }
    assert outcome == expected_result
66
+
67
+
68
def test_batch_predict_posts_expected_payload_and_returns_json():
    """batch_predict posts one payload per item and wraps the response data."""
    canned = {
        "code": 200,
        "message": "success",
        "data": [
            [[[1.1, 1.2, 1.3]]],
            [[[2.1, 2.2, 2.3]]],
        ],
    }
    fake_session = DummySession(canned)
    sdk = FalconClient(
        batch_endpoint="https://example.test/batch-predict",
        timeout=5,
        session=fake_session,
    )

    # First item sets every field; second relies on defaults and has a NaN.
    full_item = {
        "context": np.array([[1.2, 3.4, 5.6]]),
        "prediction_length": 3,
        "model_name": "demo_model",
        "group_ids": np.array([0]),
        "input_mask": np.array([[1, 1, 1]]),
        "is_multivariate": False,
    }
    minimal_item = {
        "context": np.array([[2.1, np.nan, 6.5]]),
        "prediction_length": 3,
    }
    outcome = sdk.batch_predict([full_item, minimal_item])

    expected_body = [
        {
            "context": [[1.2, 3.4, 5.6]],
            "prediction_length": 3,
            "model_name": "demo_model",
            "group_ids": [0],
            "input_mask": [[1, 1, 1]],
            "is_multivariate": False,
        },
        {
            "context": [[2.1, None, 6.5]],
            "prediction_length": 3,
            "is_multivariate": False,
        },
    ]
    expected_call = {
        "url": "https://example.test/batch-predict",
        "json": expected_body,
        "timeout": 5,
    }
    assert fake_session.calls == [expected_call]

    expected_result = {
        "prob_prediction": [
            [[[1.1, 1.2, 1.3]]],
            [[[2.1, 2.2, 2.3]]],
        ]
    }
    assert outcome == expected_result
129
+
130
+
131
def test_batch_predict_requires_sequence_of_objects():
    """batch_predict rejects a string batch and a non-mapping batch item."""
    sdk = FalconClient(session=DummySession({"data": []}))

    # Checked in order: the string batch first, then the list-of-lists batch.
    rejected = (
        ("not a batch", "Expected objects to be a sequence"),
        ([["not", "a", "mapping"]], "Expected batch item to be a mapping"),
    )
    for bad_batch, pattern in rejected:
        with pytest.raises(TypeError, match=pattern):
            sdk.batch_predict(bad_batch)
139
+
140
+
141
def test_optional_arguments_are_omitted_when_none():
    """Optional fields left at None do not appear in the posted payload."""
    fake_session = DummySession({"code": 200, "message": "success", "data": [1, 2, 3]})
    sdk = FalconClient(endpoint="https://example.test/predict", session=fake_session)

    sdk.quantile_predict(context=np.array([[1.0, 2.0]]), prediction_length=1)

    sent = fake_session.calls[0]["json"]
    for optional_key in ("model_name", "group_ids", "input_mask"):
        assert optional_key not in sent
154
+
155
+
156
def test_verify_is_passed_when_configured():
    """A client built with verify=False forwards that value to session.post."""
    fake_session = DummySession({"code": 200, "message": "success", "data": [1, 2, 3]})
    sdk = FalconClient(
        endpoint="https://example.test/predict",
        session=fake_session,
        verify=False,
    )

    sdk.quantile_predict(context=np.array([[1.0, 2.0]]), prediction_length=1)

    first_call = fake_session.calls[0]
    assert first_call["verify"] is False
170
+
171
+
172
def test_nan_values_are_sent_as_json_null():
    """NaN entries in every array argument are replaced by None in the payload."""
    fake_session = DummySession({"code": 200, "message": "success", "data": [1, 2, 3]})
    sdk = FalconClient(endpoint="https://example.test/predict", session=fake_session)

    sdk.quantile_predict(
        context=np.array([[1.0, np.nan], [np.nan, 4.0]]),
        prediction_length=1,
        group_ids=np.array([0, np.nan]),
        input_mask=np.array([[1.0, np.nan], [1.0, 0.0]]),
    )

    sent = fake_session.calls[0]["json"]
    expected_fields = {
        "context": [[1.0, None], [None, 4.0]],
        "group_ids": [0.0, None],
        "input_mask": [[1.0, None], [1.0, 0.0]],
    }
    for field, expected in expected_fields.items():
        assert sent[field] == expected
187
+
188
+
189
def test_context_must_be_numpy_array():
    """Passing a plain nested list as context raises TypeError."""
    sdk = FalconClient(session=DummySession({"data": []}))
    plain_list = [[1.0, 2.0]]

    with pytest.raises(TypeError, match="Expected numpy.ndarray"):
        sdk.quantile_predict(context=plain_list, prediction_length=1)
197
+
198
+
199
def test_prediction_length_must_be_int():
    """Passing None as prediction_length raises TypeError."""
    sdk = FalconClient(session=DummySession({"data": []}))
    series = np.array([[1.0, 2.0]])

    with pytest.raises(TypeError, match="Expected prediction_length to be int"):
        sdk.quantile_predict(context=series, prediction_length=None)
207
+
208
+
209
def test_prediction_length_must_be_positive():
    """A prediction_length of 0 raises ValueError."""
    sdk = FalconClient(session=DummySession({"data": []}))
    series = np.array([[1.0, 2.0]])

    with pytest.raises(ValueError, match="prediction_length must be greater than 0"):
        sdk.quantile_predict(context=series, prediction_length=0)
217
+
218
+
219
def test_api_error_is_raised_when_code_is_not_200():
    """A response whose code is 500 raises FalconAPIError with code and message."""
    failing_body = {"code": 500, "message": "model not found", "data": None}
    sdk = FalconClient(session=DummySession(failing_body))

    with pytest.raises(FalconAPIError, match="code=500, message=model not found"):
        sdk.quantile_predict(context=np.array([[1.0, 2.0]]), prediction_length=1)
229
+
230
+
231
def test_unexpected_response_format_raises_api_error():
    """A response body that is a list rather than a dict raises FalconAPIError."""
    non_dict_body = [1, 2, 3]
    sdk = FalconClient(session=DummySession(non_dict_body))

    with pytest.raises(FalconAPIError, match="Unexpected response format"):
        sdk.quantile_predict(context=np.array([[1.0, 2.0]]), prediction_length=1)