falcon-tst 1.0.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- falcon_tst-1.0.10/PKG-INFO +70 -0
- falcon_tst-1.0.10/README.md +61 -0
- falcon_tst-1.0.10/pyproject.toml +22 -0
- falcon_tst-1.0.10/setup.cfg +4 -0
- falcon_tst-1.0.10/src/falcon_tst.egg-info/PKG-INFO +70 -0
- falcon_tst-1.0.10/src/falcon_tst.egg-info/SOURCES.txt +10 -0
- falcon_tst-1.0.10/src/falcon_tst.egg-info/dependency_links.txt +1 -0
- falcon_tst-1.0.10/src/falcon_tst.egg-info/requires.txt +2 -0
- falcon_tst-1.0.10/src/falcon_tst.egg-info/top_level.txt +1 -0
- falcon_tst-1.0.10/src/falcontst/__init__.py +3 -0
- falcon_tst-1.0.10/src/falcontst/client.py +157 -0
- falcon_tst-1.0.10/tests/test_client.py +238 -0
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: falcon-tst
|
|
3
|
+
Version: 1.0.10
|
|
4
|
+
Summary: Python SDK for Falcon Studio prediction API.
|
|
5
|
+
Requires-Python: >=3.9
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: numpy>=1.23.0
|
|
8
|
+
Requires-Dist: requests>=2.31.0
|
|
9
|
+
|
|
10
|
+
# Falcon SDK
|
|
11
|
+
|
|
12
|
+
The official Python SDK for the Falcon-TST Prediction API. Falcon-TST is a family of large-scale time-series foundation models developed by Ant International.
|
|
13
|
+
|
|
14
|
+
## Install
|
|
15
|
+
|
|
16
|
+
```bash
|
|
17
|
+
pip install falcon-tst
|
|
18
|
+
```
|
|
19
|
+
|
|
20
|
+
## Usage
|
|
21
|
+
|
|
22
|
+
```python
|
|
23
|
+
import numpy as np
|
|
24
|
+
|
|
25
|
+
from falcontst import FalconClient
|
|
26
|
+
|
|
27
|
+
client = FalconClient()
|
|
28
|
+
|
|
29
|
+
result = client.quantile_predict(
|
|
30
|
+
context=np.array([
|
|
31
|
+
[1.2, 3.4, 5.6, 7.8, 9.0],
|
|
32
|
+
[2.1, 4.3, 6.5, 8.7, 0.9],
|
|
33
|
+
]),
|
|
34
|
+
prediction_length=3,
|
|
35
|
+
model_name="demo_model",
|
|
36
|
+
group_ids=np.array([0, 0]),
|
|
37
|
+
input_mask=np.array([
|
|
38
|
+
[1, 1, 1, 1, 1],
|
|
39
|
+
[1, 1, 1, 0, 0],
|
|
40
|
+
]),
|
|
41
|
+
is_multivariate=False,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
print(result)
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
Batch prediction takes a list of dicts keyed by `quantile_predict`'s argument names (only `context` and `prediction_length` are required) and sends
|
|
48
|
+
multiple objects to the batch endpoint:
|
|
49
|
+
|
|
50
|
+
```python
|
|
51
|
+
result = client.batch_predict(
|
|
52
|
+
[
|
|
53
|
+
{
|
|
54
|
+
"context": np.array([[1.2, 3.4, 5.6, 7.8, 9.0]]),
|
|
55
|
+
"prediction_length": 3,
|
|
56
|
+
"model_name": "demo_model",
|
|
57
|
+
"group_ids": np.array([0]),
|
|
58
|
+
"input_mask": np.array([[1, 1, 1, 1, 1]]),
|
|
59
|
+
"is_multivariate": False,
|
|
60
|
+
},
|
|
61
|
+
{
|
|
62
|
+
"context": np.array([[2.1, 4.3, 6.5, 8.7, 0.9]]),
|
|
63
|
+
"prediction_length": 3,
|
|
64
|
+
},
|
|
65
|
+
]
|
|
66
|
+
)
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
Each `context` should be a two-dimensional `numpy.ndarray`. For a single time
|
|
70
|
+
series, use `np.array([[1.2, 3.4, 5.6]])` instead of `np.array([1.2, 3.4, 5.6])`.
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
# Falcon SDK
|
|
2
|
+
|
|
3
|
+
The official Python SDK for the Falcon-TST Prediction API. Falcon-TST is a family of large-scale time-series foundation models developed by Ant International.
|
|
4
|
+
|
|
5
|
+
## Install
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install falcon-tst
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## Usage
|
|
12
|
+
|
|
13
|
+
```python
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
from falcontst import FalconClient
|
|
17
|
+
|
|
18
|
+
client = FalconClient()
|
|
19
|
+
|
|
20
|
+
result = client.quantile_predict(
|
|
21
|
+
context=np.array([
|
|
22
|
+
[1.2, 3.4, 5.6, 7.8, 9.0],
|
|
23
|
+
[2.1, 4.3, 6.5, 8.7, 0.9],
|
|
24
|
+
]),
|
|
25
|
+
prediction_length=3,
|
|
26
|
+
model_name="demo_model",
|
|
27
|
+
group_ids=np.array([0, 0]),
|
|
28
|
+
input_mask=np.array([
|
|
29
|
+
[1, 1, 1, 1, 1],
|
|
30
|
+
[1, 1, 1, 0, 0],
|
|
31
|
+
]),
|
|
32
|
+
is_multivariate=False,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
print(result)
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
Batch prediction takes a list of dicts keyed by `quantile_predict`'s argument names (only `context` and `prediction_length` are required) and sends
|
|
39
|
+
multiple objects to the batch endpoint:
|
|
40
|
+
|
|
41
|
+
```python
|
|
42
|
+
result = client.batch_predict(
|
|
43
|
+
[
|
|
44
|
+
{
|
|
45
|
+
"context": np.array([[1.2, 3.4, 5.6, 7.8, 9.0]]),
|
|
46
|
+
"prediction_length": 3,
|
|
47
|
+
"model_name": "demo_model",
|
|
48
|
+
"group_ids": np.array([0]),
|
|
49
|
+
"input_mask": np.array([[1, 1, 1, 1, 1]]),
|
|
50
|
+
"is_multivariate": False,
|
|
51
|
+
},
|
|
52
|
+
{
|
|
53
|
+
"context": np.array([[2.1, 4.3, 6.5, 8.7, 0.9]]),
|
|
54
|
+
"prediction_length": 3,
|
|
55
|
+
},
|
|
56
|
+
]
|
|
57
|
+
)
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
Each `context` should be a two-dimensional `numpy.ndarray`. For a single time
|
|
61
|
+
series, use `np.array([[1.2, 3.4, 5.6]])` instead of `np.array([1.2, 3.4, 5.6])`.
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=68"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "falcon-tst"
|
|
7
|
+
version = "1.0.10"
|
|
8
|
+
description = "Python SDK for Falcon Studio prediction API."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.9"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"numpy>=1.23.0",
|
|
13
|
+
"requests>=2.31.0",
|
|
14
|
+
]
|
|
15
|
+
|
|
16
|
+
[tool.setuptools.packages.find]
|
|
17
|
+
where = ["src"]
|
|
18
|
+
|
|
19
|
+
[dependency-groups]
|
|
20
|
+
dev = [
|
|
21
|
+
"twine>=6.2.0",
|
|
22
|
+
]
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: falcon-tst
|
|
3
|
+
Version: 1.0.10
|
|
4
|
+
Summary: Python SDK for Falcon Studio prediction API.
|
|
5
|
+
Requires-Python: >=3.9
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: numpy>=1.23.0
|
|
8
|
+
Requires-Dist: requests>=2.31.0
|
|
9
|
+
|
|
10
|
+
# Falcon SDK
|
|
11
|
+
|
|
12
|
+
The official Python SDK for the Falcon-TST Prediction API. Falcon-TST is a family of large-scale time-series foundation models developed by Ant International.
|
|
13
|
+
|
|
14
|
+
## Install
|
|
15
|
+
|
|
16
|
+
```bash
|
|
17
|
+
pip install falcon-tst
|
|
18
|
+
```
|
|
19
|
+
|
|
20
|
+
## Usage
|
|
21
|
+
|
|
22
|
+
```python
|
|
23
|
+
import numpy as np
|
|
24
|
+
|
|
25
|
+
from falcontst import FalconClient
|
|
26
|
+
|
|
27
|
+
client = FalconClient()
|
|
28
|
+
|
|
29
|
+
result = client.quantile_predict(
|
|
30
|
+
context=np.array([
|
|
31
|
+
[1.2, 3.4, 5.6, 7.8, 9.0],
|
|
32
|
+
[2.1, 4.3, 6.5, 8.7, 0.9],
|
|
33
|
+
]),
|
|
34
|
+
prediction_length=3,
|
|
35
|
+
model_name="demo_model",
|
|
36
|
+
group_ids=np.array([0, 0]),
|
|
37
|
+
input_mask=np.array([
|
|
38
|
+
[1, 1, 1, 1, 1],
|
|
39
|
+
[1, 1, 1, 0, 0],
|
|
40
|
+
]),
|
|
41
|
+
is_multivariate=False,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
print(result)
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
Batch prediction takes a list of dicts keyed by `quantile_predict`'s argument names (only `context` and `prediction_length` are required) and sends
|
|
48
|
+
multiple objects to the batch endpoint:
|
|
49
|
+
|
|
50
|
+
```python
|
|
51
|
+
result = client.batch_predict(
|
|
52
|
+
[
|
|
53
|
+
{
|
|
54
|
+
"context": np.array([[1.2, 3.4, 5.6, 7.8, 9.0]]),
|
|
55
|
+
"prediction_length": 3,
|
|
56
|
+
"model_name": "demo_model",
|
|
57
|
+
"group_ids": np.array([0]),
|
|
58
|
+
"input_mask": np.array([[1, 1, 1, 1, 1]]),
|
|
59
|
+
"is_multivariate": False,
|
|
60
|
+
},
|
|
61
|
+
{
|
|
62
|
+
"context": np.array([[2.1, 4.3, 6.5, 8.7, 0.9]]),
|
|
63
|
+
"prediction_length": 3,
|
|
64
|
+
},
|
|
65
|
+
]
|
|
66
|
+
)
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
Each `context` should be a two-dimensional `numpy.ndarray`. For a single time
|
|
70
|
+
series, use `np.array([[1.2, 3.4, 5.6]])` instead of `np.array([1.2, 3.4, 5.6])`.
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
src/falcon_tst.egg-info/PKG-INFO
|
|
4
|
+
src/falcon_tst.egg-info/SOURCES.txt
|
|
5
|
+
src/falcon_tst.egg-info/dependency_links.txt
|
|
6
|
+
src/falcon_tst.egg-info/requires.txt
|
|
7
|
+
src/falcon_tst.egg-info/top_level.txt
|
|
8
|
+
src/falcontst/__init__.py
|
|
9
|
+
src/falcontst/client.py
|
|
10
|
+
tests/test_client.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
falcontst
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from typing import Any, Mapping, Optional, Sequence, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import requests
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Shared prefix of the Falcon Studio OpenAPI routes; both default endpoints
# below are derived from it so they cannot drift apart.
# NOTE(review): the host name contains "-pre", which often denotes a
# pre-production environment -- confirm this is the intended default.
_OPENAPI_BASE_URL = "https://falconstudio-pre.antglobal.com/falconstudio/api/v1/openapi"

# Single-request forecast endpoint (default for FalconClient.endpoint).
DEFAULT_PREDICT_URL = f"{_OPENAPI_BASE_URL}/predict"
# Batch forecast endpoint (default for FalconClient.batch_endpoint).
DEFAULT_BATCH_PREDICT_URL = f"{_OPENAPI_BASE_URL}/batch-predict"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class FalconAPIError(Exception):
    """Error signalling that the Falcon Studio API did not report success.

    Raised by the client when a response body is not a JSON object, or when
    its ``code`` field is anything other than 200.
    """
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class FalconClient:
    """Client for the Falcon Studio prediction API.

    ``quantile_predict`` sends one forecast request to ``endpoint``;
    ``batch_predict`` sends a list of requests to ``batch_endpoint`` in a
    single POST. NumPy arrays are converted to nested lists with NaN mapped to
    ``None`` (JSON ``null``), and successful responses are returned as
    ``{"prob_prediction": <response "data" field>}``.

    The client can be used as a context manager. On exit it closes the HTTP
    session only if it created that session itself; a session supplied by the
    caller is left open for the caller to manage.
    """

    def __init__(
        self,
        endpoint: str = DEFAULT_PREDICT_URL,
        batch_endpoint: str = DEFAULT_BATCH_PREDICT_URL,
        timeout: float = 30.0,
        session: Optional[requests.Session] = None,
        verify: Optional[Union[bool, str]] = None,
    ) -> None:
        """Configure the client.

        Args:
            endpoint: URL used by :meth:`quantile_predict`.
            batch_endpoint: URL used by :meth:`batch_predict`.
            timeout: Passed as ``timeout`` to every ``session.post`` call.
            session: HTTP session to reuse. A new ``requests.Session`` is
                created (and owned by this client) when omitted.
            verify: Forwarded as ``verify`` to ``session.post`` when not
                ``None``; otherwise the session's own setting applies.
        """
        self.endpoint = endpoint
        self.batch_endpoint = batch_endpoint
        self.timeout = timeout
        # Compare with None explicitly: a session-like object that happens to
        # be falsy (e.g. defines __len__) must not be silently replaced.
        self._owns_session = session is None
        self.session = requests.Session() if session is None else session
        self.verify = verify

    def close(self) -> None:
        """Close the underlying HTTP session if this client created it."""
        if self._owns_session:
            self.session.close()

    def __enter__(self) -> FalconClient:
        return self

    def __exit__(self, *exc_info: Any) -> None:
        self.close()

    def quantile_predict(
        self,
        context: np.ndarray,
        prediction_length: int,
        model_name: Optional[str] = None,
        group_ids: Optional[np.ndarray] = None,
        input_mask: Optional[np.ndarray] = None,
        is_multivariate: bool = False,
    ) -> dict[str, Any]:
        """Request a quantile forecast for a single input.

        Args:
            context: Input series as a ``numpy.ndarray``. The README documents
                a 2-D array; the shape is not validated client-side.
            prediction_length: Positive integer number of steps to forecast
                (Python ``int`` or NumPy integer; ``bool`` is rejected).
            model_name: Optional model identifier; omitted from the payload
                when ``None``.
            group_ids: Optional ``numpy.ndarray``; omitted when ``None``.
            input_mask: Optional ``numpy.ndarray``; omitted when ``None``.
            is_multivariate: Sent unchanged in the payload.

        Returns:
            ``{"prob_prediction": ...}`` holding the response's ``data`` field.

        Raises:
            TypeError: If an array argument is not a ``numpy.ndarray`` or
                ``prediction_length`` is not an integer.
            ValueError: If ``prediction_length`` is not positive.
            requests.HTTPError: If the HTTP status indicates failure.
            FalconAPIError: If the response body is not a JSON object or
                reports a ``code`` other than 200.
        """
        payload = self._build_predict_payload(
            context=context,
            prediction_length=prediction_length,
            model_name=model_name,
            group_ids=group_ids,
            input_mask=input_mask,
            is_multivariate=is_multivariate,
        )

        return self._post(self.endpoint, payload)

    def batch_predict(self, objects: Sequence[Mapping[str, Any]]) -> dict[str, Any]:
        """Request quantile forecasts for several inputs in one POST.

        Args:
            objects: Sequence of mappings whose keys mirror the keyword
                arguments of :meth:`quantile_predict`. ``context`` and
                ``prediction_length`` are required; the others are optional.

        Returns:
            ``{"prob_prediction": ...}`` holding the response's ``data`` field.

        Raises:
            TypeError: If ``objects`` is not a non-string sequence, an item is
                not a mapping, or an item fails the argument checks of
                :meth:`quantile_predict`.
            KeyError: If an item lacks ``context`` or ``prediction_length``.
            ValueError, requests.HTTPError, FalconAPIError: As for
                :meth:`quantile_predict`.
        """
        # str/bytes are Sequences too, but never a meaningful batch.
        if not isinstance(objects, Sequence) or isinstance(objects, (str, bytes)):
            raise TypeError(
                f"Expected objects to be a sequence, got {type(objects).__name__}"
            )

        payload = []
        for index, item in enumerate(objects):
            if not isinstance(item, Mapping):
                raise TypeError(
                    f"Expected batch item to be a mapping, got {type(item).__name__}"
                )

            missing = [key for key in ("context", "prediction_length") if key not in item]
            if missing:
                # Still a KeyError (as the bare item[...] lookup raised before),
                # but naming the item makes large batches debuggable.
                raise KeyError(
                    f"Batch item {index} is missing required key(s): {', '.join(missing)}"
                )

            payload.append(
                self._build_predict_payload(
                    context=item["context"],
                    prediction_length=item["prediction_length"],
                    model_name=item.get("model_name"),
                    group_ids=item.get("group_ids"),
                    input_mask=item.get("input_mask"),
                    is_multivariate=item.get("is_multivariate", False),
                )
            )

        return self._post(self.batch_endpoint, payload)

    def _build_predict_payload(
        self,
        context: np.ndarray,
        prediction_length: int,
        model_name: Optional[str] = None,
        group_ids: Optional[np.ndarray] = None,
        input_mask: Optional[np.ndarray] = None,
        is_multivariate: bool = False,
    ) -> dict[str, Any]:
        """Validate one request's arguments and build its JSON-ready dict.

        Optional fields are only included when they are not ``None``.
        """
        # bool is a subclass of int, so it must be rejected explicitly.
        # NumPy integer scalars are accepted and converted, because the JSON
        # encoder cannot serialize them.
        if isinstance(prediction_length, bool) or not isinstance(
            prediction_length, (int, np.integer)
        ):
            raise TypeError(
                f"Expected prediction_length to be int, got {type(prediction_length).__name__}"
            )
        prediction_length = int(prediction_length)
        if prediction_length <= 0:
            raise ValueError("prediction_length must be greater than 0")

        payload = {
            "context": self._array_to_list(context),
            "prediction_length": prediction_length,
            "is_multivariate": is_multivariate,
        }

        if model_name is not None:
            payload["model_name"] = model_name
        if group_ids is not None:
            payload["group_ids"] = self._array_to_list(group_ids)
        if input_mask is not None:
            payload["input_mask"] = self._array_to_list(input_mask)

        return payload

    def _post(self, endpoint: str, payload: Any) -> dict[str, Any]:
        """POST ``payload`` as JSON to ``endpoint`` and unwrap the response.

        Raises:
            requests.HTTPError: Via ``raise_for_status`` on an HTTP failure.
            FalconAPIError: From :meth:`_format_response`.
        """
        request_kwargs = {
            "json": payload,
            "timeout": self.timeout,
        }
        # Only override the session's TLS verification when configured.
        if self.verify is not None:
            request_kwargs["verify"] = self.verify

        response = self.session.post(endpoint, **request_kwargs)
        response.raise_for_status()

        return self._format_response(response.json())

    @staticmethod
    def _array_to_list(array: np.ndarray) -> list[Any]:
        """Convert a ``numpy.ndarray`` to nested lists with NaN as ``None``.

        Raises:
            TypeError: If ``array`` is not a ``numpy.ndarray``.
        """
        if not isinstance(array, np.ndarray):
            raise TypeError(f"Expected numpy.ndarray, got {type(array).__name__}")
        return FalconClient._replace_nan(array.tolist())

    @staticmethod
    def _replace_nan(value: Any) -> Any:
        """Recursively replace float NaN with ``None`` so it serializes as ``null``."""
        if isinstance(value, list):
            return [FalconClient._replace_nan(item) for item in value]
        if isinstance(value, float) and math.isnan(value):
            return None
        return value

    @staticmethod
    def _format_response(data: Any) -> dict[str, Any]:
        """Check the API envelope and return ``{"prob_prediction": data["data"]}``.

        Raises:
            FalconAPIError: If ``data`` is not a dict or its ``code`` is not 200.
        """
        if not isinstance(data, dict):
            raise FalconAPIError(f"Unexpected response format: {data}")

        code = data.get("code")
        message = data.get("message", "")
        if code != 200:
            raise FalconAPIError(f"Falcon API error: code={code}, message={message}")

        return {"prob_prediction": data.get("data")}
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pytest
|
|
3
|
+
|
|
4
|
+
from falcontst import FalconAPIError, FalconClient
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class DummyResponse:
    """Minimal stand-in for ``requests.Response`` returning a canned body."""

    def __init__(self, data):
        # Body handed back verbatim by json().
        self.data = data

    def raise_for_status(self):
        """Pretend every request succeeded: never raise."""
        return None

    def json(self):
        """Return the canned body exactly as it was given."""
        return self.data
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class DummySession:
    """Fake ``requests.Session`` that records each POST instead of sending it."""

    def __init__(self, data):
        self.data = data  # body every returned DummyResponse will carry
        self.calls = []  # one dict per post(): {"url": ..., **kwargs}

    def post(self, url, **kwargs):
        """Record the call and answer with a canned DummyResponse."""
        recorded = {"url": url}
        recorded.update(kwargs)
        self.calls.append(recorded)
        return DummyResponse(self.data)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def test_quantile_predict_posts_expected_payload_and_returns_json():
    """quantile_predict posts the serialized arrays once and unwraps ``data``."""
    forecast = [[[1.6039936542510986, -0.09775638580322266, -1.3819665908813477]]]
    session = DummySession({"code": 200, "message": "success", "data": forecast})
    client = FalconClient(endpoint="https://example.test/predict", timeout=5, session=session)

    context_rows = [[1.2, 3.4, 5.6, 7.8, 9.0], [2.1, 4.3, 6.5, 8.7, 0.9]]
    mask_rows = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]

    result = client.quantile_predict(
        context=np.array(context_rows),
        prediction_length=3,
        model_name="demo_model",
        group_ids=np.array([0, 0]),
        input_mask=np.array(mask_rows),
        is_multivariate=False,
    )

    expected_body = {
        "context": context_rows,
        "prediction_length": 3,
        "model_name": "demo_model",
        "group_ids": [0, 0],
        "input_mask": mask_rows,
        "is_multivariate": False,
    }
    expected_call = {"url": "https://example.test/predict", "json": expected_body, "timeout": 5}

    assert session.calls == [expected_call]
    assert result == {"prob_prediction": forecast}
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def test_batch_predict_posts_expected_payload_and_returns_json():
    """batch_predict serializes every item, fills defaults and maps NaN to None."""
    batch_forecast = [
        [[[1.1, 1.2, 1.3]]],
        [[[2.1, 2.2, 2.3]]],
    ]
    session = DummySession({"code": 200, "message": "success", "data": batch_forecast})
    client = FalconClient(
        batch_endpoint="https://example.test/batch-predict",
        timeout=5,
        session=session,
    )

    full_item = {
        "context": np.array([[1.2, 3.4, 5.6]]),
        "prediction_length": 3,
        "model_name": "demo_model",
        "group_ids": np.array([0]),
        "input_mask": np.array([[1, 1, 1]]),
        "is_multivariate": False,
    }
    minimal_item = {
        "context": np.array([[2.1, np.nan, 6.5]]),
        "prediction_length": 3,
    }

    result = client.batch_predict([full_item, minimal_item])

    expected_items = [
        {
            "context": [[1.2, 3.4, 5.6]],
            "prediction_length": 3,
            "model_name": "demo_model",
            "group_ids": [0],
            "input_mask": [[1, 1, 1]],
            "is_multivariate": False,
        },
        {
            "context": [[2.1, None, 6.5]],
            "prediction_length": 3,
            "is_multivariate": False,
        },
    ]
    expected_call = {
        "url": "https://example.test/batch-predict",
        "json": expected_items,
        "timeout": 5,
    }

    assert session.calls == [expected_call]
    assert result == {"prob_prediction": batch_forecast}
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def test_batch_predict_requires_sequence_of_objects():
    """Strings are rejected as batches, and every batch item must be a mapping."""
    client = FalconClient(session=DummySession({"data": []}))

    invalid_batches = [
        ("not a batch", "Expected objects to be a sequence"),
        ([["not", "a", "mapping"]], "Expected batch item to be a mapping"),
    ]
    for objects, expected_message in invalid_batches:
        with pytest.raises(TypeError, match=expected_message):
            client.batch_predict(objects)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def test_optional_arguments_are_omitted_when_none():
    """Unset optional arguments do not appear as keys in the request body."""
    session = DummySession({"code": 200, "message": "success", "data": [1, 2, 3]})
    client = FalconClient(endpoint="https://example.test/predict", session=session)

    client.quantile_predict(context=np.array([[1.0, 2.0]]), prediction_length=1)

    sent_body = session.calls[0]["json"]
    for optional_key in ("model_name", "group_ids", "input_mask"):
        assert optional_key not in sent_body
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def test_verify_is_passed_when_configured():
    """An explicit ``verify`` setting is forwarded to ``session.post``."""
    session = DummySession({"code": 200, "message": "success", "data": [1, 2, 3]})
    client = FalconClient(endpoint="https://example.test/predict", session=session, verify=False)

    client.quantile_predict(context=np.array([[1.0, 2.0]]), prediction_length=1)

    first_call = session.calls[0]
    assert first_call["verify"] is False
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def test_nan_values_are_sent_as_json_null():
    """NaN entries in every array argument are serialized as JSON null (None)."""
    session = DummySession({"code": 200, "message": "success", "data": [1, 2, 3]})
    client = FalconClient(endpoint="https://example.test/predict", session=session)

    client.quantile_predict(
        context=np.array([[1.0, np.nan], [np.nan, 4.0]]),
        prediction_length=1,
        group_ids=np.array([0, np.nan]),
        input_mask=np.array([[1.0, np.nan], [1.0, 0.0]]),
    )

    expected_fields = {
        "context": [[1.0, None], [None, 4.0]],
        "group_ids": [0.0, None],
        "input_mask": [[1.0, None], [1.0, 0.0]],
    }
    sent_body = session.calls[0]["json"]
    for field, expected_value in expected_fields.items():
        assert sent_body[field] == expected_value
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def test_context_must_be_numpy_array():
    """A plain nested list is rejected; ``context`` must be a numpy.ndarray."""
    client = FalconClient(session=DummySession({"data": []}))
    plain_list_context = [[1.0, 2.0]]

    with pytest.raises(TypeError, match="Expected numpy.ndarray"):
        client.quantile_predict(context=plain_list_context, prediction_length=1)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def test_prediction_length_must_be_int():
    """A non-integer ``prediction_length`` (here ``None``) raises TypeError."""
    client = FalconClient(session=DummySession({"data": []}))
    series = np.array([[1.0, 2.0]])

    with pytest.raises(TypeError, match="Expected prediction_length to be int"):
        client.quantile_predict(context=series, prediction_length=None)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def test_prediction_length_must_be_positive():
    """A ``prediction_length`` of zero is rejected with ValueError."""
    client = FalconClient(session=DummySession({"data": []}))
    series = np.array([[1.0, 2.0]])

    with pytest.raises(ValueError, match="prediction_length must be greater than 0"):
        client.quantile_predict(context=series, prediction_length=0)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def test_api_error_is_raised_when_code_is_not_200():
    """A non-200 ``code`` in the body raises FalconAPIError with code and message."""
    error_body = {"code": 500, "message": "model not found", "data": None}
    client = FalconClient(session=DummySession(error_body))

    with pytest.raises(FalconAPIError, match="code=500, message=model not found"):
        client.quantile_predict(context=np.array([[1.0, 2.0]]), prediction_length=1)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def test_unexpected_response_format_raises_api_error():
    """A JSON body that is not an object (here a list) raises FalconAPIError."""
    non_object_body = [1, 2, 3]
    client = FalconClient(session=DummySession(non_object_body))

    with pytest.raises(FalconAPIError, match="Unexpected response format"):
        client.quantile_predict(context=np.array([[1.0, 2.0]]), prediction_length=1)
|