flexinference 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flexinference/__init__.py +6 -0
- flexinference/_client.py +156 -0
- flexinference/models.py +2521 -0
- flexinference/py.typed +0 -0
- flexinference-0.1.0.dist-info/METADATA +120 -0
- flexinference-0.1.0.dist-info/RECORD +8 -0
- flexinference-0.1.0.dist-info/WHEEL +4 -0
- flexinference-0.1.0.dist-info/licenses/LICENSE +21 -0
flexinference/_client.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from collections.abc import Iterator, Mapping
|
|
5
|
+
from typing import Any, Literal, cast, overload
|
|
6
|
+
|
|
7
|
+
import httpx
|
|
8
|
+
|
|
9
|
+
from .models import (
|
|
10
|
+
CreateChatCompletionRequest,
|
|
11
|
+
CreateChatCompletionResponse,
|
|
12
|
+
CreateChatCompletionStreamResponse,
|
|
13
|
+
CreateResponse,
|
|
14
|
+
Response,
|
|
15
|
+
ResponseStreamEvent,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Default API root URL; `FlexInference.__init__` uses it when no `base_url` is passed.
DEFAULT_BASE_URL = "https://api.flexinference.com/v1"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FlexInferenceError(Exception):
    """An error returned by FlexInference or by the upstream provider, passed through.

    Attributes:
        status: The HTTP status code supplied by the caller.
        type: Value of ``error.type`` in the response body, if present.
        code: Value of ``error.code`` in the response body, if present.
        param: Value of ``error.param`` in the response body, if present.
    """

    def __init__(self, status: int, body: Mapping[str, Any] | None, fallback: str) -> None:
        """Build the exception from a decoded error body.

        Args:
            status: HTTP status code of the failed request.
            body: Decoded JSON body, or ``None`` when it could not be decoded.
            fallback: Message used when the body carries no ``error.message``.
        """
        # Only an object-shaped ``error`` member is trusted; anything else is
        # treated as "no details available".
        details: Mapping[str, Any] = {}
        if isinstance(body, Mapping):
            candidate = body.get("error")
            if isinstance(candidate, Mapping):
                details = candidate

        # A missing or empty message falls back to the caller-supplied text.
        super().__init__(details.get("message") or fallback)
        self.status = status
        self.type: str | None = details.get("type")
        self.code: str | None = details.get("code")
        self.param: str | None = details.get("param")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class FlexInference:
    """Synchronous HTTP client exposing ``responses`` and ``chat`` endpoint groups.

    Can be used as a context manager; leaving the ``with`` block calls
    :meth:`close`.
    """

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str = DEFAULT_BASE_URL,
        client: httpx.Client | None = None,
    ) -> None:
        """Configure the client.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: API root; trailing slashes are removed.
            client: Optional pre-built ``httpx.Client`` to send requests with.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("FlexInference: `api_key` is required.")

        self._base_url = base_url.rstrip("/")

        if client:
            self._client = client
        else:
            # Inference calls routinely outlast httpx's 5s default timeout, so use
            # a long overall budget while still failing fast on connect.
            self._client = httpx.Client(timeout=httpx.Timeout(600.0, connect=10.0))

        self._headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
            "Accept": "text/event-stream, application/json",
        }

        # Endpoint groups hold a back-reference to this client for transport.
        self.responses = _Responses(self)
        self.chat = _Chat(self)

    def close(self) -> None:
        """Close the underlying ``httpx`` client."""
        self._client.close()

    def __enter__(self) -> FlexInference:
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(self, *exc: object) -> None:
        """Close the client when the ``with`` block exits."""
        self.close()

    def _post_json(self, path: str, payload: Mapping[str, Any]) -> Any:
        """POST ``payload`` as JSON to ``path`` and return the decoded JSON reply.

        Raises:
            FlexInferenceError: If the response status is 400 or above.
        """
        url = f"{self._base_url}{path}"
        resp = self._client.post(url, headers=self._headers, content=json.dumps(payload))
        if resp.status_code < 400:
            return resp.json()
        raise FlexInferenceError(resp.status_code, _safe_json(resp), f"HTTP {resp.status_code}")

    def _post_stream(self, path: str, payload: Mapping[str, Any]) -> Iterator[Any]:
        """POST ``payload`` to ``path`` and yield each decoded server-sent event.

        This is a generator, so the request is only sent once iteration starts.

        Raises:
            FlexInferenceError: If the response status is 400 or above.
        """
        url = f"{self._base_url}{path}"
        encoded = json.dumps(payload)
        with self._client.stream("POST", url, headers=self._headers, content=encoded) as resp:
            if resp.status_code >= 400:
                # A streamed response has no body loaded yet; read it so the
                # error payload can be decoded.
                resp.read()
                raise FlexInferenceError(
                    resp.status_code, _safe_json(resp), f"HTTP {resp.status_code}"
                )
            yield from _parse_sse(resp)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _safe_json(r: httpx.Response) -> Mapping[str, Any] | None:
    """Decode the response body as a JSON object without raising.

    Returns:
        The decoded mapping, or ``None`` when the body is not valid JSON or
        its top-level value is not an object.
    """
    try:
        parsed = r.json()
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None
    if isinstance(parsed, Mapping):
        return parsed
    return None
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _parse_sse(r: httpx.Response) -> Iterator[Any]:
    """Yield the JSON-decoded ``data`` payload of each server-sent event.

    Lines are read from ``r.iter_lines()``. Consecutive ``data:`` lines are
    joined with newlines and dispatched at the next blank line. Lines that do
    not start with ``data:`` (comments, ``event:``, ``id:`` ...) are ignored.
    A payload equal to ``[DONE]`` ends the stream. Data still buffered when the
    stream ends without a final blank line is flushed as a last event.

    Events whose payload is empty or whitespace-only (for example a bare
    ``data:`` heartbeat frame) are skipped rather than passed to
    ``json.loads``, which would otherwise abort the stream with
    ``json.JSONDecodeError``.

    Args:
        r: An open streaming response exposing ``iter_lines()``.

    Yields:
        The decoded JSON value of each event.

    Raises:
        json.JSONDecodeError: If a non-empty payload is not valid JSON.
    """
    done_marker = "[DONE]"
    pending: list[str] = []

    for line in r.iter_lines():
        if line:
            if line.startswith("data:"):
                # Drop the field name and the space(s) separating it from the value.
                pending.append(line[5:].lstrip(" "))
            continue

        # Blank line: event boundary.
        if not pending:
            continue
        payload = "\n".join(pending)
        pending = []
        if payload == done_marker:
            return
        if payload.strip():
            yield json.loads(payload)

    # Stream ended without a terminating blank line: flush what is buffered.
    if pending:
        payload = "\n".join(pending)
        if payload != done_marker and payload.strip():
            yield json.loads(payload)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class _Responses:
    """Endpoint group for ``/responses``, reached as ``FlexInference.responses``."""

    def __init__(self, parent: FlexInference) -> None:
        """Keep a reference to the owning client, which performs the HTTP calls."""
        self._parent = parent

    @overload
    def create(self, body: CreateResponse, *, stream: Literal[False] = False) -> Response: ...

    @overload
    def create(
        self, body: CreateResponse, *, stream: Literal[True]
    ) -> Iterator[ResponseStreamEvent]: ...

    def create(
        self, body: CreateResponse, *, stream: bool = False
    ) -> Response | Iterator[ResponseStreamEvent]:
        """POST ``body`` to ``/responses``.

        ``body`` is copied, not mutated; the ``stream`` argument overwrites any
        ``stream`` key already present in it.

        Args:
            body: Request parameters.
            stream: When true, return an iterator of decoded stream events
                instead of a single decoded response.

        Returns:
            The decoded JSON response, or an iterator of stream events.
        """
        endpoint = "/responses"
        request: dict[str, Any] = dict(body)
        request["stream"] = stream
        if not stream:
            return cast(Response, self._parent._post_json(endpoint, request))
        return self._parent._post_stream(endpoint, request)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class _Chat:
    """Endpoint group reached as ``FlexInference.chat``; holds ``completions``."""

    def __init__(self, parent: FlexInference) -> None:
        """Create the ``completions`` sub-group bound to the owning client."""
        completions = _ChatCompletions(parent)
        self.completions = completions
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class _ChatCompletions:
    """Endpoint group for ``/chat/completions``."""

    def __init__(self, parent: FlexInference) -> None:
        """Keep a reference to the owning client, which performs the HTTP calls."""
        self._parent = parent

    @overload
    def create(
        self, body: CreateChatCompletionRequest, *, stream: Literal[False] = False
    ) -> CreateChatCompletionResponse: ...

    @overload
    def create(
        self, body: CreateChatCompletionRequest, *, stream: Literal[True]
    ) -> Iterator[CreateChatCompletionStreamResponse]: ...

    def create(
        self, body: CreateChatCompletionRequest, *, stream: bool = False
    ) -> CreateChatCompletionResponse | Iterator[CreateChatCompletionStreamResponse]:
        """POST ``body`` to ``/chat/completions``.

        ``body`` is copied, not mutated; the ``stream`` argument overwrites any
        ``stream`` key already present in it.

        Args:
            body: Request parameters.
            stream: When true, return an iterator of decoded stream chunks
                instead of a single decoded completion.

        Returns:
            The decoded JSON completion, or an iterator of stream chunks.
        """
        endpoint = "/chat/completions"
        request: dict[str, Any] = dict(body)
        request["stream"] = stream
        if not stream:
            return cast(
                CreateChatCompletionResponse, self._parent._post_json(endpoint, request)
            )
        return self._parent._post_stream(endpoint, request)
|