flexinference 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,6 @@
1
+ from __future__ import annotations
2
+
3
+ from . import models
4
+ from ._client import DEFAULT_BASE_URL, FlexInference, FlexInferenceError
5
+
6
# Names exported by ``from flexinference import *``: the default API root,
# the client class, its error type, and the ``models`` submodule.
__all__ = ["DEFAULT_BASE_URL", "FlexInference", "FlexInferenceError", "models"]
@@ -0,0 +1,156 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from collections.abc import Iterator, Mapping
5
+ from typing import Any, Literal, cast, overload
6
+
7
+ import httpx
8
+
9
+ from .models import (
10
+ CreateChatCompletionRequest,
11
+ CreateChatCompletionResponse,
12
+ CreateChatCompletionStreamResponse,
13
+ CreateResponse,
14
+ Response,
15
+ ResponseStreamEvent,
16
+ )
17
+
18
# API root used when ``FlexInference(base_url=...)`` is not supplied.
DEFAULT_BASE_URL = "https://api.flexinference.com/v1"
19
+
20
+
21
class FlexInferenceError(Exception):
    """An error returned by FlexInference or by the upstream provider, passed through.

    The exception message is taken from the body's ``error`` field when possible:
    either an object ``{"error": {"message": ..., "type": ..., "code": ..., "param": ...}}``
    or a plain string ``{"error": "..."}``. When no usable message is present,
    ``fallback`` is used instead.

    Attributes:
        status: HTTP status code of the failed response.
        body: The body mapping passed in (may be ``None``).
        type: ``error.type`` from the body, or ``None``.
        code: ``error.code`` from the body, or ``None``.
        param: ``error.param`` from the body, or ``None``.
    """

    def __init__(self, status: int, body: Mapping[str, Any] | None, fallback: str) -> None:
        raw = body.get("error") if isinstance(body, Mapping) else None
        if isinstance(raw, str):
            # Tolerate ``{"error": "<message>"}`` bodies, where the error is a bare
            # string rather than an object; otherwise its text would be lost.
            err: Mapping[str, Any] = {"message": raw}
        elif isinstance(raw, Mapping):
            err = raw
        else:
            err = {}
        # An empty or missing message falls back to the caller-supplied text.
        super().__init__(err.get("message") or fallback)
        self.status = status
        self.body = body
        self.type: str | None = err.get("type")
        self.code: str | None = err.get("code")
        self.param: str | None = err.get("param")
32
+
33
+
34
class FlexInference:
    """Synchronous client for the FlexInference API.

    Resources are exposed as ``client.responses`` and ``client.chat.completions``.
    Every request is a JSON ``POST`` carrying ``Authorization: Bearer <api_key>``.
    The client can be used as a context manager; leaving the ``with`` block
    calls :meth:`close`.

    Args:
        api_key: Bearer token for the ``Authorization`` header. Must be non-empty.
        base_url: API root. A trailing ``/`` is stripped before request paths are
            appended.
        client: Optional pre-configured ``httpx.Client``. When omitted, one is
            created with a 600 s read/write/pool timeout and a 10 s connect timeout.

    Raises:
        ValueError: If ``api_key`` is empty.
    """

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str = DEFAULT_BASE_URL,
        client: httpx.Client | None = None,
    ) -> None:
        if not api_key:
            raise ValueError("FlexInference: `api_key` is required.")
        self._base_url = base_url.rstrip("/")
        if client is None:
            # Inference calls routinely run far longer than httpx's 5s default; match the
            # OpenAI SDK's generous read budget while keeping a short connect timeout.
            client = httpx.Client(timeout=httpx.Timeout(600.0, connect=10.0))
        self._client = client
        # Sent on every request; Accept covers both SSE streams and plain JSON replies.
        self._headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
            "Accept": "text/event-stream, application/json",
        }
        self.responses = _Responses(self)
        self.chat = _Chat(self)

    def close(self) -> None:
        """Close the underlying ``httpx.Client``.

        This also closes a client supplied via ``client=``, so that client must
        not be reused afterwards.
        """
        self._client.close()

    def __enter__(self) -> FlexInference:
        return self

    def __exit__(self, *exc: object) -> None:
        self.close()

    def _post_json(self, path: str, payload: Mapping[str, Any]) -> Any:
        """POST ``payload`` as JSON to ``base_url + path`` and return the decoded body.

        Raises:
            FlexInferenceError: If the status is >= 400, or if a successful
                response body is not valid JSON.
        """
        r = self._client.post(
            f"{self._base_url}{path}", headers=self._headers, content=json.dumps(payload)
        )
        if r.status_code >= 400:
            raise FlexInferenceError(r.status_code, _safe_json(r), f"HTTP {r.status_code}")
        try:
            return r.json()
        except ValueError as exc:  # json.JSONDecodeError is a ValueError subclass
            # Report a malformed success body (such as an HTML page) through the SDK's
            # own error type instead of leaking a bare decoding error to callers.
            raise FlexInferenceError(
                r.status_code, None, f"HTTP {r.status_code}: response body is not valid JSON"
            ) from exc

    def _post_stream(self, path: str, payload: Mapping[str, Any]) -> Iterator[Any]:
        """POST ``payload`` and yield each decoded server-sent-event ``data`` payload.

        This is a generator function. The request is only sent when iteration
        starts, so an HTTP error (status >= 400) is raised as
        :class:`FlexInferenceError` on the first ``next()``, not when the generator
        is created. The response stays open until the stream is exhausted or the
        generator is closed.
        """
        with self._client.stream(
            "POST", f"{self._base_url}{path}", headers=self._headers, content=json.dumps(payload)
        ) as r:
            if r.status_code >= 400:
                # Streamed bodies are not loaded automatically; read the body so the
                # error JSON can be parsed.
                r.read()
                raise FlexInferenceError(r.status_code, _safe_json(r), f"HTTP {r.status_code}")
            yield from _parse_sse(r)
81
+
82
+
83
+ def _safe_json(r: httpx.Response) -> Mapping[str, Any] | None:
84
+ try:
85
+ data = r.json()
86
+ except (json.JSONDecodeError, ValueError):
87
+ return None
88
+ return data if isinstance(data, Mapping) else None
89
+
90
+
91
+ def _parse_sse(r: httpx.Response) -> Iterator[Any]:
92
+ data_lines: list[str] = []
93
+ for line in r.iter_lines():
94
+ if line == "":
95
+ if data_lines:
96
+ data = "\n".join(data_lines)
97
+ data_lines = []
98
+ if data == "[DONE]":
99
+ return
100
+ yield json.loads(data)
101
+ continue
102
+ if line.startswith("data:"):
103
+ data_lines.append(line[5:].lstrip(" "))
104
+ if data_lines:
105
+ data = "\n".join(data_lines)
106
+ if data != "[DONE]":
107
+ yield json.loads(data)
108
+
109
+
110
+ class _Responses:
111
+ def __init__(self, parent: FlexInference) -> None:
112
+ self._parent = parent
113
+
114
+ @overload
115
+ def create(self, body: CreateResponse, *, stream: Literal[False] = False) -> Response: ...
116
+ @overload
117
+ def create(
118
+ self, body: CreateResponse, *, stream: Literal[True]
119
+ ) -> Iterator[ResponseStreamEvent]: ...
120
+ def create(
121
+ self, body: CreateResponse, *, stream: bool = False
122
+ ) -> Response | Iterator[ResponseStreamEvent]:
123
+ payload = dict(body)
124
+ payload["stream"] = stream
125
+ if stream:
126
+ return self._parent._post_stream("/responses", payload)
127
+ return cast(Response, self._parent._post_json("/responses", payload))
128
+
129
+
130
class _Chat:
    """Namespace object behind ``client.chat``; holds the ``completions`` resource."""

    def __init__(self, parent: FlexInference) -> None:
        completions_resource = _ChatCompletions(parent)
        self.completions = completions_resource
133
+
134
+
135
+ class _ChatCompletions:
136
+ def __init__(self, parent: FlexInference) -> None:
137
+ self._parent = parent
138
+
139
+ @overload
140
+ def create(
141
+ self, body: CreateChatCompletionRequest, *, stream: Literal[False] = False
142
+ ) -> CreateChatCompletionResponse: ...
143
+ @overload
144
+ def create(
145
+ self, body: CreateChatCompletionRequest, *, stream: Literal[True]
146
+ ) -> Iterator[CreateChatCompletionStreamResponse]: ...
147
+ def create(
148
+ self, body: CreateChatCompletionRequest, *, stream: bool = False
149
+ ) -> CreateChatCompletionResponse | Iterator[CreateChatCompletionStreamResponse]:
150
+ payload = dict(body)
151
+ payload["stream"] = stream
152
+ if stream:
153
+ return self._parent._post_stream("/chat/completions", payload)
154
+ return cast(
155
+ CreateChatCompletionResponse, self._parent._post_json("/chat/completions", payload)
156
+ )