xinference 1.7.0.post1__py3-none-any.whl → 1.7.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +3 -4
- xinference/client/__init__.py +2 -0
- xinference/client/common.py +49 -2
- xinference/client/handlers.py +18 -0
- xinference/client/restful/async_restful_client.py +1760 -0
- xinference/client/restful/restful_client.py +74 -78
- xinference/core/media_interface.py +3 -1
- xinference/core/model.py +5 -4
- xinference/core/supervisor.py +10 -5
- xinference/core/worker.py +15 -14
- xinference/deploy/local.py +51 -9
- xinference/deploy/worker.py +5 -3
- xinference/device_utils.py +22 -3
- xinference/model/audio/fish_speech.py +23 -34
- xinference/model/audio/model_spec.json +4 -2
- xinference/model/audio/model_spec_modelscope.json +4 -2
- xinference/model/audio/utils.py +2 -2
- xinference/model/core.py +1 -0
- xinference/model/embedding/__init__.py +8 -8
- xinference/model/embedding/custom.py +6 -1
- xinference/model/embedding/embed_family.py +0 -41
- xinference/model/embedding/model_spec.json +10 -1
- xinference/model/embedding/model_spec_modelscope.json +10 -1
- xinference/model/embedding/sentence_transformers/core.py +30 -15
- xinference/model/flexible/core.py +1 -1
- xinference/model/flexible/launchers/__init__.py +2 -0
- xinference/model/flexible/launchers/image_process_launcher.py +1 -1
- xinference/model/flexible/launchers/modelscope_launcher.py +47 -0
- xinference/model/flexible/launchers/transformers_launcher.py +5 -5
- xinference/model/flexible/launchers/yolo_launcher.py +62 -0
- xinference/model/llm/__init__.py +7 -0
- xinference/model/llm/core.py +18 -1
- xinference/model/llm/llama_cpp/core.py +1 -1
- xinference/model/llm/llm_family.json +41 -1
- xinference/model/llm/llm_family.py +6 -0
- xinference/model/llm/llm_family_modelscope.json +43 -1
- xinference/model/llm/mlx/core.py +271 -18
- xinference/model/llm/mlx/distributed_models/__init__.py +13 -0
- xinference/model/llm/mlx/distributed_models/core.py +164 -0
- xinference/model/llm/mlx/distributed_models/deepseek_v3.py +75 -0
- xinference/model/llm/mlx/distributed_models/qwen2.py +82 -0
- xinference/model/llm/mlx/distributed_models/qwen3.py +82 -0
- xinference/model/llm/mlx/distributed_models/qwen3_moe.py +76 -0
- xinference/model/llm/reasoning_parser.py +12 -6
- xinference/model/llm/sglang/core.py +8 -4
- xinference/model/llm/transformers/chatglm.py +4 -1
- xinference/model/llm/transformers/core.py +4 -2
- xinference/model/llm/transformers/multimodal/cogagent.py +10 -4
- xinference/model/llm/transformers/multimodal/intern_vl.py +1 -1
- xinference/model/llm/utils.py +36 -17
- xinference/model/llm/vllm/core.py +142 -34
- xinference/model/llm/vllm/distributed_executor.py +96 -21
- xinference/model/llm/vllm/xavier/transfer.py +2 -2
- xinference/model/rerank/core.py +16 -9
- xinference/model/rerank/model_spec.json +3 -3
- xinference/model/rerank/model_spec_modelscope.json +3 -3
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.9b12b7f9.js +3 -0
- xinference/web/ui/build/static/js/main.9b12b7f9.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0fd4820d93f99509e80d8702dc3f6f8272424acab5608fa7c0e82cb1d3250a87.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1460361af6975e63576708039f1cb732faf9c672d97c494d4055fc6331460be0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4efd8dda58fda83ed9546bf2f587df67f8d98e639117bee2d9326a9a1d9bebb2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5b2dafe5aa9e1105e0244a2b6751807342fa86aa0144b4e84d947a1686102715.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f75545479c17fdfe2a00235fa4a0e9da1ae95e6b3caafba87ded92de6b0240e4.json +1 -0
- xinference/web/ui/src/locales/en.json +3 -0
- xinference/web/ui/src/locales/ja.json +3 -0
- xinference/web/ui/src/locales/ko.json +3 -0
- xinference/web/ui/src/locales/zh.json +3 -0
- {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/METADATA +4 -3
- {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/RECORD +77 -67
- xinference/web/ui/build/static/js/main.8a9e3ba0.js +0 -3
- xinference/web/ui/build/static/js/main.8a9e3ba0.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/34cfbfb7836e136ba3261cfd411cc554bf99ba24b35dcceebeaa4f008cb3c9dc.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c5c7c2cd1b863ce41adff2c4737bba06eef3a1acf28288cb83d992060f6b8923.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/cc97b49285d7717c63374766c789141a4329a04582ab32756d7e0e614d4c5c7f.json +0 -1
- /xinference/web/ui/build/static/js/{main.8a9e3ba0.js.LICENSE.txt → main.9b12b7f9.js.LICENSE.txt} +0 -0
- {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/WHEEL +0 -0
- {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/entry_points.txt +0 -0
- {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.7.0.post1.dist-info → xinference-1.7.1.post1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1760 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import asyncio
|
|
15
|
+
import json
|
|
16
|
+
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union
|
|
17
|
+
|
|
18
|
+
import aiohttp
|
|
19
|
+
|
|
20
|
+
from ..common import async_streaming_response_iterator, convert_float_to_int_or_str
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from ...types import (
|
|
24
|
+
ChatCompletion,
|
|
25
|
+
ChatCompletionChunk,
|
|
26
|
+
Completion,
|
|
27
|
+
CompletionChunk,
|
|
28
|
+
Embedding,
|
|
29
|
+
ImageList,
|
|
30
|
+
PytorchGenerateConfig,
|
|
31
|
+
VideoList,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _filter_params(params: Dict[Any, Any]) -> Dict[Any, Any]:
|
|
36
|
+
filtered = {}
|
|
37
|
+
for key, value in params.items():
|
|
38
|
+
if value is not None:
|
|
39
|
+
filtered[key] = value
|
|
40
|
+
return filtered
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
async def _get_error_string(response: aiohttp.ClientResponse) -> str:
    """
    Build a human-readable error message from a failed response, then release it.

    The server-provided ``detail`` field of a JSON body is preferred. Only when
    that is unavailable (empty body, non-JSON body, or no ``detail`` key) is the
    generic HTTP status text from ``raise_for_status()`` used. ``"Unknown error"``
    is the last resort. The response is released before returning in every
    case, including when building the message fails.

    Parameters
    ----------
    response: aiohttp.ClientResponse
        The (usually non-200) response to describe.

    Returns
    -------
    str
        The error description.
    """
    try:
        result: Optional[str] = None
        try:
            if response.content:
                json_data = await response.json()
                detail = json_data["detail"]
                if detail is not None:
                    # Return early with the server's own message. Previously it
                    # was overwritten by the generic status text below.
                    result = str(detail)
        except Exception:
            # Best effort: the body may be empty, not JSON, or lack "detail".
            pass
        if result is None:
            try:
                response.raise_for_status()
            except aiohttp.ClientError as e:
                result = str(e)
        return result if result is not None else "Unknown error"
    finally:
        await _release_response(response)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
async def _release_response(response: aiohttp.ClientResponse):
    """
    Call ``release()`` on ``response`` and then await its ``wait_for_close()``.
    """
    resp = response
    resp.release()
    await resp.wait_for_close()
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class AsyncRESTfulModelHandle:
    """
    An async model interface (for the async RESTful client) which provides type
    hints that make it much easier to use xinference programmatically.

    Each handle owns its own ``aiohttp.ClientSession``. Call ``await close()``
    when the handle is no longer needed. ``__del__`` only performs best-effort
    cleanup.
    """

    def __init__(self, model_uid: str, base_url: str, auth_headers: Dict):
        self._model_uid = model_uid
        self._base_url = base_url
        self.auth_headers = auth_headers
        # force_close=True: connections are closed after each request instead
        # of being kept alive for reuse.
        self.session = aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(force_close=True)
        )

    async def close(self):
        """Close the AsyncRESTfulModelHandle session."""
        if self.session:
            await self.session.close()
            self.session = None

    def __del__(self):
        # ``session`` may be missing if ``__init__`` failed part-way.
        session = getattr(self, "session", None)
        if session is None or getattr(session, "closed", False):
            return
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None
        if running_loop is not None:
            # run_until_complete() raises on a loop that is already running,
            # so schedule the close on that loop instead.
            self.session = None
            running_loop.create_task(session.close())
            return
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        loop.run_until_complete(self.close())
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class AsyncRESTfulEmbeddingModelHandle(AsyncRESTfulModelHandle):
    async def create_embedding(
        self, input: Union[str, List[str]], **kwargs
    ) -> "Embedding":
        """
        Embed ``input`` through the ``/v1/embeddings`` RESTful endpoint.

        Parameters
        ----------
        input: Union[str, List[str]]
            The text to embed. Pass a list to embed several inputs in one request.
        **kwargs
            Extra fields merged into the top level of the JSON request body.
            They win over ``model``/``input`` on a key collision.

        Returns
        -------
        Embedding
            The server's decoded JSON reply.

        Raises
        ------
        RuntimeError
            If the server answers with a status other than 200. The message
            carries the server-side error detail.
        """
        payload = {"model": self._model_uid, "input": input, **kwargs}
        resp = await self.session.post(
            f"{self._base_url}/v1/embeddings",
            json=payload,
            headers=self.auth_headers,
        )
        if resp.status == 200:
            result = await resp.json()
            await _release_response(resp)
            return result
        raise RuntimeError(
            f"Failed to create the embeddings, detail: {await _get_error_string(resp)}"
        )

    async def convert_ids_to_tokens(
        self, input: Union[List, List[List]], **kwargs
    ) -> List[str]:
        """
        Decode token IDs into readable tokens via ``/v1/convert_ids_to_tokens``.

        Parameters
        ----------
        input: Union[List, List[List]]
            One sequence of token IDs, or a list of such sequences to decode
            several at once.
        **kwargs
            Extra fields merged into the top level of the JSON request body.

        Returns
        -------
        list
            The server's decoded JSON reply (the decoded tokens).

        Raises
        ------
        RuntimeError
            If the server answers with a status other than 200.
        """
        payload = {"model": self._model_uid, "input": input, **kwargs}
        resp = await self.session.post(
            f"{self._base_url}/v1/convert_ids_to_tokens",
            json=payload,
            headers=self.auth_headers,
        )
        if resp.status == 200:
            result = await resp.json()
            await _release_response(resp)
            return result
        raise RuntimeError(
            f"Failed to decode token ids, detail: {await _get_error_string(resp)}"
        )
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class AsyncRESTfulRerankModelHandle(AsyncRESTfulModelHandle):
    async def rerank(
        self,
        documents: List[str],
        query: str,
        top_n: Optional[int] = None,
        max_chunks_per_doc: Optional[int] = None,
        return_documents: Optional[bool] = None,
        return_len: Optional[bool] = None,
        **kwargs,
    ):
        """
        Score ``documents`` by relevance to ``query`` via ``/v1/rerank``.

        Parameters
        ----------
        documents: List[str]
            The documents to rerank.
        query: str
            The search query.
        top_n: Optional[int]
            How many results to return. The server returns all when omitted.
        max_chunks_per_doc: Optional[int]
            Upper bound on the number of chunks derived from one document.
        return_documents: Optional[bool]
            Whether the documents themselves are echoed back.
        return_len: Optional[bool]
            Whether token lengths are returned.
        **kwargs
            Extra options. They are sent twice: JSON-encoded under the
            ``kwargs`` key and merged into the top level of the body.

        Returns
        -------
        Scores
            The server's decoded JSON reply.

        Raises
        ------
        RuntimeError
            If the server answers with a status other than 200.
        """
        payload: Dict[str, Any] = dict(
            model=self._model_uid,
            documents=documents,
            query=query,
            top_n=top_n,
            max_chunks_per_doc=max_chunks_per_doc,
            return_documents=return_documents,
            return_len=return_len,
            kwargs=json.dumps(kwargs),
        )
        payload.update(kwargs)
        resp = await self.session.post(
            f"{self._base_url}/v1/rerank",
            json=payload,
            headers=self.auth_headers,
        )
        if resp.status == 200:
            result = await resp.json()
            await _release_response(resp)
            return result
        raise RuntimeError(
            f"Failed to rerank documents, detail: {await _get_error_string(resp)}"
        )
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
class AsyncRESTfulImageModelHandle(AsyncRESTfulModelHandle):
    @staticmethod
    def _build_multipart_form(
        params: Dict[str, Any], uploads: List[Any]
    ) -> aiohttp.FormData:
        """
        Build an aiohttp multipart body from plain form fields and file uploads.

        ``aiohttp.ClientSession.post`` has no ``files=`` argument (that is a
        ``requests`` API), so multipart requests must be sent as ``FormData``.

        Parameters
        ----------
        params: Dict[str, Any]
            Plain form fields. Non-``str`` values are sent as ``str(value)``.
        uploads: List[Any]
            ``(field_name, filename, content)`` triples sent as
            ``application/octet-stream`` file parts.
        """
        form = aiohttp.FormData()
        for key, value in params.items():
            form.add_field(key, value if isinstance(value, str) else str(value))
        for field_name, filename, content in uploads:
            form.add_field(
                field_name,
                content,
                filename=filename,
                content_type="application/octet-stream",
            )
        return form

    async def text_to_image(
        self,
        prompt: str,
        n: int = 1,
        size: str = "1024*1024",
        response_format: str = "url",
        **kwargs,
    ) -> "ImageList":
        """
        Create images from a text prompt via ``/v1/images/generations``.

        Parameters
        ----------
        prompt: str
            The prompt that guides image generation.
        n: int, defaults to 1
            The number of images to generate.
        size: str, defaults to ``"1024*1024"``
            The ``width*height`` in pixels of the generated images.
        response_format: str, defaults to ``"url"``
            How images are returned. Must be one of ``url`` or ``b64_json``.
        **kwargs
            Extra model options, sent JSON-encoded in the ``kwargs`` field.

        Returns
        -------
        ImageList
            The server's decoded JSON reply (a list of image objects).

        Raises
        ------
        RuntimeError
            If the server answers with a status other than 200.
        """
        url = f"{self._base_url}/v1/images/generations"
        request_body = {
            "model": self._model_uid,
            "prompt": prompt,
            "n": n,
            "size": size,
            "response_format": response_format,
            "kwargs": json.dumps(kwargs),
        }
        response = await self.session.post(
            url, json=request_body, headers=self.auth_headers
        )
        if response.status != 200:
            raise RuntimeError(
                f"Failed to create the images, detail: {await _get_error_string(response)}"
            )

        response_data = await response.json()
        await _release_response(response)
        return response_data

    async def image_to_image(
        self,
        image: Union[str, bytes],
        prompt: str,
        negative_prompt: Optional[str] = None,
        n: int = 1,
        size: Optional[str] = None,
        response_format: str = "url",
        **kwargs,
    ) -> "ImageList":
        """
        Create images from an input image and a prompt via ``/v1/images/variations``.

        Parameters
        ----------
        image: Union[str, bytes]
            The input image content, uploaded as the multipart file field ``image``.
        prompt: str
            The prompt that guides image generation.
        negative_prompt: Optional[str], defaults to None
            What the generation should avoid. Omitted from the request when None.
        n: int, defaults to 1
            The number of images to generate.
        size: Optional[str], defaults to None
            The ``width*height`` in pixels of the output. Omitted when None.
        response_format: str, defaults to ``"url"``
            How images are returned. Must be one of ``url`` or ``b64_json``.
        **kwargs
            Extra model options, sent JSON-encoded in the ``kwargs`` field.

        Returns
        -------
        ImageList
            The server's decoded JSON reply (a list of image objects).

        Raises
        ------
        RuntimeError
            If the server answers with a status other than 200.
        """
        url = f"{self._base_url}/v1/images/variations"
        params = _filter_params(
            {
                "model": self._model_uid,
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "n": n,
                "size": size,
                "response_format": response_format,
                "kwargs": json.dumps(kwargs),
            }
        )
        form = self._build_multipart_form(params, [("image", "image", image)])
        response = await self.session.post(url, data=form, headers=self.auth_headers)
        if response.status != 200:
            raise RuntimeError(
                f"Failed to variants the images, detail: {await _get_error_string(response)}"
            )

        response_data = await response.json()
        await _release_response(response)
        return response_data

    async def inpainting(
        self,
        image: Union[str, bytes],
        mask_image: Union[str, bytes],
        prompt: str,
        negative_prompt: Optional[str] = None,
        n: int = 1,
        size: Optional[str] = None,
        response_format: str = "url",
        **kwargs,
    ) -> "ImageList":
        """
        Inpaint an image guided by a prompt via ``/v1/images/inpainting``.

        Parameters
        ----------
        image: Union[str, bytes]
            The image to inpaint, uploaded as the multipart file field ``image``.
        mask_image: Union[str, bytes]
            The mask, uploaded as the multipart file field ``mask_image``.
            White pixels are repainted and black pixels are preserved.
        prompt: str
            The prompt that guides the repainting.
        negative_prompt: Optional[str], defaults to None
            What the generation should avoid. Omitted from the request when None.
        n: int, defaults to 1
            The number of images to generate.
        size: Optional[str], defaults to None
            The ``width*height`` in pixels of the output. Omitted when None.
        response_format: str, defaults to ``"url"``
            How images are returned. Must be one of ``url`` or ``b64_json``.
        **kwargs
            Extra model options, sent JSON-encoded in the ``kwargs`` field.

        Returns
        -------
        ImageList
            The server's decoded JSON reply (a list of image objects).

        Raises
        ------
        RuntimeError
            If the server answers with a status other than 200.
        """
        url = f"{self._base_url}/v1/images/inpainting"
        params = _filter_params(
            {
                "model": self._model_uid,
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "n": n,
                "size": size,
                "response_format": response_format,
                "kwargs": json.dumps(kwargs),
            }
        )
        form = self._build_multipart_form(
            params,
            [
                ("image", "image", image),
                ("mask_image", "mask_image", mask_image),
            ],
        )
        response = await self.session.post(url, data=form, headers=self.auth_headers)
        if response.status != 200:
            raise RuntimeError(
                f"Failed to inpaint the images, detail: {await _get_error_string(response)}"
            )

        response_data = await response.json()
        await _release_response(response)
        return response_data

    async def ocr(self, image: Union[str, bytes], **kwargs):
        """
        Run OCR on ``image`` via ``/v1/images/ocr``.

        Parameters
        ----------
        image: Union[str, bytes]
            The image content, uploaded as the multipart file field ``image``.
        **kwargs
            Extra model options, sent JSON-encoded in the ``kwargs`` field.

        Returns
        -------
        The server's decoded JSON reply.

        Raises
        ------
        RuntimeError
            If the server answers with a status other than 200.
        """
        url = f"{self._base_url}/v1/images/ocr"
        params = _filter_params(
            {
                "model": self._model_uid,
                "kwargs": json.dumps(kwargs),
            }
        )
        form = self._build_multipart_form(params, [("image", "image", image)])
        response = await self.session.post(url, data=form, headers=self.auth_headers)
        if response.status != 200:
            raise RuntimeError(
                f"Failed to ocr the images, detail: {await _get_error_string(response)}"
            )

        response_data = await response.json()
        await _release_response(response)
        return response_data
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
class AsyncRESTfulVideoModelHandle(AsyncRESTfulModelHandle):
    @staticmethod
    def _build_multipart_form(
        params: Dict[str, Any], uploads: List[Any]
    ) -> aiohttp.FormData:
        """
        Build an aiohttp multipart body from plain form fields and file uploads.

        ``aiohttp.ClientSession.post`` has no ``files=`` argument (that is a
        ``requests`` API), so multipart requests must be sent as ``FormData``.

        Parameters
        ----------
        params: Dict[str, Any]
            Plain form fields. Non-``str`` values are sent as ``str(value)``.
        uploads: List[Any]
            ``(field_name, filename, content)`` triples sent as
            ``application/octet-stream`` file parts.
        """
        form = aiohttp.FormData()
        for key, value in params.items():
            form.add_field(key, value if isinstance(value, str) else str(value))
        for field_name, filename, content in uploads:
            form.add_field(
                field_name,
                content,
                filename=filename,
                content_type="application/octet-stream",
            )
        return form

    async def text_to_video(
        self,
        prompt: str,
        n: int = 1,
        **kwargs,
    ) -> "VideoList":
        """
        Create videos from a text prompt via ``/v1/video/generations``.

        Parameters
        ----------
        prompt: str
            The prompt that guides video generation.
        n: int, defaults to 1
            The number of videos to generate.
        **kwargs
            Extra model options, sent JSON-encoded in the ``kwargs`` field.

        Returns
        -------
        VideoList
            The server's decoded JSON reply (a list of video objects).

        Raises
        ------
        RuntimeError
            If the server answers with a status other than 200.
        """
        url = f"{self._base_url}/v1/video/generations"
        request_body = {
            "model": self._model_uid,
            "prompt": prompt,
            "n": n,
            "kwargs": json.dumps(kwargs),
        }
        response = await self.session.post(
            url, json=request_body, headers=self.auth_headers
        )
        if response.status != 200:
            raise RuntimeError(
                f"Failed to create the video, detail: {await _get_error_string(response)}"
            )

        response_data = await response.json()
        await _release_response(response)
        return response_data

    async def image_to_video(
        self,
        image: Union[str, bytes],
        prompt: str,
        negative_prompt: Optional[str] = None,
        n: int = 1,
        **kwargs,
    ) -> "VideoList":
        """
        Create videos from an input image and a prompt via
        ``/v1/video/generations/image``.

        Parameters
        ----------
        image: Union[str, bytes]
            The conditioning image, uploaded as the multipart file field ``image``.
        prompt: str
            The prompt that guides video generation.
        negative_prompt: Optional[str], defaults to None
            What the generation should avoid. Omitted from the request when None.
        n: int, defaults to 1
            The number of videos to generate.
        **kwargs
            Extra model options, sent JSON-encoded in the ``kwargs`` field.

        Returns
        -------
        VideoList
            The server's decoded JSON reply (a list of video objects).

        Raises
        ------
        RuntimeError
            If the server answers with a status other than 200.
        """
        url = f"{self._base_url}/v1/video/generations/image"
        params = _filter_params(
            {
                "model": self._model_uid,
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "n": n,
                "kwargs": json.dumps(kwargs),
            }
        )
        form = self._build_multipart_form(params, [("image", "image", image)])
        response = await self.session.post(url, data=form, headers=self.auth_headers)
        if response.status != 200:
            raise RuntimeError(
                f"Failed to create the video from image, detail: {await _get_error_string(response)}"
            )

        response_data = await response.json()
        await _release_response(response)
        return response_data

    async def flf_to_video(
        self,
        first_frame: Union[str, bytes],
        last_frame: Union[str, bytes],
        prompt: str,
        negative_prompt: Optional[str] = None,
        n: int = 1,
        **kwargs,
    ) -> "VideoList":
        """
        Create videos from a first frame, a last frame and a prompt via
        ``/v1/video/generations/flf``.

        Parameters
        ----------
        first_frame: Union[str, bytes]
            The first frame, uploaded as the multipart file field ``first_frame``.
        last_frame: Union[str, bytes]
            The last frame, uploaded as the multipart file field ``last_frame``.
        prompt: str
            The prompt that guides video generation.
        negative_prompt: Optional[str], defaults to None
            What the generation should avoid. Omitted from the request when None.
        n: int, defaults to 1
            The number of videos to generate.
        **kwargs
            Extra model options, sent JSON-encoded in the ``kwargs`` field.

        Returns
        -------
        VideoList
            The server's decoded JSON reply (a list of video objects).

        Raises
        ------
        RuntimeError
            If the server answers with a status other than 200.
        """
        url = f"{self._base_url}/v1/video/generations/flf"
        params = _filter_params(
            {
                "model": self._model_uid,
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "n": n,
                "kwargs": json.dumps(kwargs),
            }
        )
        # Both frames are uploaded with the part filename "image", as before.
        form = self._build_multipart_form(
            params,
            [
                ("first_frame", "image", first_frame),
                ("last_frame", "image", last_frame),
            ],
        )
        response = await self.session.post(url, data=form, headers=self.auth_headers)
        if response.status != 200:
            raise RuntimeError(
                f"Failed to create the video from image, detail: {await _get_error_string(response)}"
            )

        response_data = await response.json()
        await _release_response(response)
        return response_data
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
class AsyncRESTfulGenerateModelHandle(AsyncRESTfulModelHandle):
    async def generate(
        self,
        prompt: str,
        generate_config: Optional["PytorchGenerateConfig"] = None,
    ) -> Union["Completion", AsyncIterator["CompletionChunk"]]:
        """
        Request a text completion for ``prompt`` from ``/v1/completions``.

        Parameters
        ----------
        prompt: str
            The user's input text.
        generate_config: Optional["PytorchGenerateConfig"]
            Extra generation options. Every entry is copied into the top level
            of the JSON request body. A truthy ``"stream"`` entry switches the
            call to streaming mode.

        Returns
        -------
        Union["Completion", AsyncIterator["CompletionChunk"]]
            When streaming, an async iterator over the response stream produced by
            ``async_streaming_response_iterator``. Otherwise the decoded JSON
            completion.

        Raises
        ------
        RuntimeError
            If the server answers with a status other than 200. The message
            carries the server-side error detail.
        """
        body: Dict[str, Any] = {"model": self._model_uid, "prompt": prompt}
        if generate_config is not None:
            body.update(generate_config.items())

        wants_stream = (
            False if not generate_config else bool(generate_config.get("stream"))
        )

        resp = await self.session.post(
            f"{self._base_url}/v1/completions",
            json=body,
            headers=self.auth_headers,
        )
        if resp.status != 200:
            raise RuntimeError(
                f"Failed to generate completion, detail: {await _get_error_string(resp)}"
            )

        if wants_stream:
            return async_streaming_response_iterator(resp.content)

        completion = await resp.json()
        await _release_response(resp)
        return completion
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
class AsyncRESTfulChatModelHandle(AsyncRESTfulGenerateModelHandle):
    """Async RESTful handle for LLMs exposing the ``chat`` ability."""

    async def chat(
        self,
        messages: List[Dict],
        tools: Optional[List[Dict]] = None,
        generate_config: Optional["PytorchGenerateConfig"] = None,
    ) -> Union["ChatCompletion", AsyncIterator["ChatCompletionChunk"]]:
        """
        Send a conversation to the model's ``/v1/chat/completions`` endpoint.

        Parameters
        ----------
        messages: List[Dict]
            The conversation so far, as a list of message dicts.
        tools: Optional[List[Dict]]
            Tool definitions for the model; left out of the request when None.
        generate_config: Optional["PytorchGenerateConfig"]
            Extra generation options. Every key is copied into the top level of
            the request body; ``generate_config["stream"]`` selects streaming.

        Returns
        -------
        Union["ChatCompletion", AsyncIterator["ChatCompletionChunk"]]
            The decoded JSON completion when not streaming, otherwise an async
            iterator over the streamed chunks.

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status; the message carries the
            server-provided error detail.
        """
        endpoint = f"{self._base_url}/v1/chat/completions"

        body: Dict[str, Any] = {"model": self._model_uid, "messages": messages}
        if tools is not None:
            body["tools"] = tools
        if generate_config is not None:
            # Feed (key, value) pairs so any object with .items() is accepted.
            body.update(generate_config.items())

        is_streaming = bool(generate_config and generate_config.get("stream"))
        resp = await self.session.post(
            endpoint, json=body, headers=self.auth_headers
        )

        if resp.status != 200:
            detail = await _get_error_string(resp)
            raise RuntimeError(
                f"Failed to generate chat completion, detail: {detail}"
            )

        if is_streaming:
            # The iterator consumes the still-open response body.
            return async_streaming_response_iterator(resp.content)

        completion = await resp.json()
        await _release_response(resp)
        return completion
|
|
716
|
+
|
|
717
|
+
|
|
718
|
+
class AsyncRESTfulAudioModelHandle(AsyncRESTfulModelHandle):
    """Async RESTful handle for audio models (transcription, translation, speech)."""

    @staticmethod
    def _build_form_data(params: Dict[str, Any], files: List[Any]) -> "aiohttp.FormData":
        """
        Pack form fields and file parts into an ``aiohttp.FormData`` body.

        aiohttp's ``ClientSession.post`` has no ``files=`` keyword (that is the
        ``requests`` API), so multipart uploads must be expressed as FormData.
        Encoding mirrors ``requests``: ``None`` values are skipped, list/tuple
        values become repeated fields, and everything else is sent as ``str(value)``.

        Parameters
        ----------
        params: Dict[str, Any]
            Plain form fields.
        files: List[Any]
            ``(field_name, (filename, content, content_type))`` tuples.

        Returns
        -------
        aiohttp.FormData
            The multipart body to pass as ``data=``.
        """
        form = aiohttp.FormData()
        for key, value in params.items():
            values = value if isinstance(value, (list, tuple)) else [value]
            for item in values:
                if item is not None:
                    form.add_field(key, str(item))
        for field_name, (filename, content, content_type) in files:
            form.add_field(
                field_name, content, filename=filename, content_type=content_type
            )
        return form

    @staticmethod
    async def _iter_chunks_then_release(response, chunk_size: int = 1024):
        """
        Yield the response body in ``chunk_size``-byte pieces.

        The response is released once iteration finishes or the generator is
        closed, so the body stays readable for the whole iteration.
        """
        try:
            async for chunk in response.content.iter_chunked(chunk_size):
                yield chunk
        finally:
            await _release_response(response)

    async def transcriptions(
        self,
        audio: bytes,
        language: Optional[str] = None,
        prompt: Optional[str] = None,
        response_format: Optional[str] = "json",
        temperature: Optional[float] = 0,
        timestamp_granularities: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Transcribes audio into the input language.

        Parameters
        ----------

        audio: bytes
            The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg,
            mpga, m4a, ogg, wav, or webm.
        language: Optional[str]
            The language of the input audio. Supplying the input language in ISO-639-1
            (https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format will improve accuracy and latency.
        prompt: Optional[str]
            An optional text to guide the model's style or continue a previous audio segment.
            The prompt should match the audio language.
        response_format: Optional[str], defaults to json
            The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
        temperature: Optional[float], defaults to 0
            The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random,
            while lower values like 0.2 will make it more focused and deterministic.
            If set to 0, the model will use log probability to automatically increase the temperature
            until certain thresholds are hit.
        timestamp_granularities: Optional[List[str]], default is None.
            The timestamp granularities to populate for this transcription. response_format must be set verbose_json
            to use timestamp granularities. Either or both of these options are supported: word, or segment.
            Note: There is no additional latency for segment timestamps, but generating word timestamps incurs
            additional latency.
        **kwargs:
            Extra model options, sent JSON-encoded in the ``kwargs`` form field.

        Returns
        -------
            The transcribed text (the decoded JSON response body).

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status.
        """
        url = f"{self._base_url}/v1/audio/transcriptions"
        params = {
            "model": self._model_uid,
            "language": language,
            "prompt": prompt,
            "response_format": response_format,
            "temperature": temperature,
            "timestamp_granularities[]": timestamp_granularities,
            "kwargs": json.dumps(kwargs),
        }
        params = _filter_params(params)
        files: List[Any] = [("file", ("file", audio, "application/octet-stream"))]
        response = await self.session.post(
            url,
            data=self._build_form_data(params, files),
            headers=self.auth_headers,
        )
        if response.status != 200:
            raise RuntimeError(
                f"Failed to transcribe the audio, detail: {await _get_error_string(response)}"
            )

        response_data = await response.json()
        await _release_response(response)
        return response_data

    async def translations(
        self,
        audio: bytes,
        language: Optional[str] = None,
        prompt: Optional[str] = None,
        response_format: Optional[str] = "json",
        temperature: Optional[float] = 0,
        timestamp_granularities: Optional[List[str]] = None,
    ):
        """
        Translates audio into English.

        Parameters
        ----------

        audio: bytes
            The audio file object (not file name) to translate, in one of these formats: flac, mp3, mp4, mpeg,
            mpga, m4a, ogg, wav, or webm.
        language: Optional[str]
            The language of the input audio. Supplying the input language in ISO-639-1
            (https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format will improve accuracy and latency.
        prompt: Optional[str]
            An optional text to guide the model's style or continue a previous audio segment.
            The prompt should match the audio language.
        response_format: Optional[str], defaults to json
            The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
        temperature: Optional[float], defaults to 0
            The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random,
            while lower values like 0.2 will make it more focused and deterministic.
            If set to 0, the model will use log probability to automatically increase the temperature
            until certain thresholds are hit.
        timestamp_granularities: Optional[List[str]], default is None.
            The timestamp granularities to populate for this translation. response_format must be set verbose_json
            to use timestamp granularities. Either or both of these options are supported: word, or segment.
            Note: There is no additional latency for segment timestamps, but generating word timestamps incurs
            additional latency.

        Returns
        -------
            The translated text (the decoded JSON response body).

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status.
        """
        url = f"{self._base_url}/v1/audio/translations"
        params = {
            "model": self._model_uid,
            "language": language,
            "prompt": prompt,
            "response_format": response_format,
            "temperature": temperature,
            "timestamp_granularities[]": timestamp_granularities,
        }
        params = _filter_params(params)
        files: List[Any] = [("file", ("file", audio, "application/octet-stream"))]
        response = await self.session.post(
            url,
            data=self._build_form_data(params, files),
            headers=self.auth_headers,
        )
        if response.status != 200:
            raise RuntimeError(
                f"Failed to translate the audio, detail: {await _get_error_string(response)}"
            )

        response_data = await response.json()
        await _release_response(response)
        return response_data

    async def speech(
        self,
        input: str,
        voice: str = "",
        response_format: str = "mp3",
        speed: float = 1.0,
        stream: bool = False,
        prompt_speech: Optional[bytes] = None,
        prompt_latent: Optional[bytes] = None,
        **kwargs,
    ):
        """
        Generates audio from the input text.

        Parameters
        ----------

        input: str
            The text to generate audio for. The maximum length is 4096 characters.
        voice: str
            The voice to use when generating the audio.
        response_format: str
            The format to audio in.
        speed: float
            The speed of the generated audio.
        stream: bool
            Use stream or not.
        prompt_speech: bytes
            The audio bytes to be provided to the model (sent as a multipart file part).
        prompt_latent: bytes
            The latent bytes to be provided to the model (sent as a multipart file part).
        **kwargs:
            Extra model options, sent JSON-encoded in the ``kwargs`` field.

        Returns
        -------
        bytes or AsyncIterator[bytes]
            The generated audio binary when ``stream`` is False; otherwise an
            async iterator of 1024-byte chunks that releases the response when
            it is exhausted or closed.

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status.
        """
        url = f"{self._base_url}/v1/audio/speech"
        params = {
            "model": self._model_uid,
            "input": input,
            "voice": voice,
            "response_format": response_format,
            "speed": speed,
            "stream": stream,
            "kwargs": json.dumps(kwargs),
        }
        params = _filter_params(params)
        files: List[Any] = []
        if prompt_speech:
            files.append(
                (
                    "prompt_speech",
                    ("prompt_speech", prompt_speech, "application/octet-stream"),
                )
            )
        if prompt_latent:
            files.append(
                (
                    "prompt_latent",
                    ("prompt_latent", prompt_latent, "application/octet-stream"),
                )
            )
        if files:
            # Multipart upload when reference audio/latents are attached.
            response = await self.session.post(
                url,
                data=self._build_form_data(params, files),
                headers=self.auth_headers,
            )
        else:
            response = await self.session.post(
                url, json=params, headers=self.auth_headers
            )
        if response.status != 200:
            raise RuntimeError(
                f"Failed to speech the text, detail: {await _get_error_string(response)}"
            )

        if stream:
            # Releasing here would close the body before the caller reads it;
            # the generator releases it after iteration instead.
            return self._iter_chunks_then_release(response)
        try:
            # Read the full body before releasing the response.
            return await response.read()
        finally:
            await _release_response(response)
|
|
932
|
+
|
|
933
|
+
|
|
934
|
+
class AsyncRESTfulFlexibleModelHandle(AsyncRESTfulModelHandle):
    """Async RESTful handle for ``flexible`` models."""

    async def infer(
        self,
        **kwargs,
    ):
        """
        Call flexible model.

        Parameters
        ----------

        kwargs: dict
            The inference arguments; merged into the JSON body next to ``model``.


        Returns
        -------
        bytes
            The raw inference result body.

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status.
        """
        url = f"{self._base_url}/v1/flexible/infers"
        params: Dict = {
            "model": self._model_uid,
        }
        params.update(kwargs)

        response = await self.session.post(url, json=params, headers=self.auth_headers)
        if response.status != 200:
            raise RuntimeError(
                f"Failed to predict, detail: {await _get_error_string(response)}"
            )
        try:
            # Read the body first: after release the content stream is unusable,
            # and callers are documented to receive bytes, not a StreamReader.
            return await response.read()
        finally:
            await _release_response(response)
|
|
967
|
+
|
|
968
|
+
|
|
969
|
+
class AsyncClient:
|
|
970
|
+
def __init__(self, base_url, api_key: Optional[str] = None):
    """
    Create an asynchronous client for a Xinference server.

    Parameters
    ----------
    base_url:
        Server root URL; every endpoint path is appended to it.
    api_key: Optional[str]
        Bearer token stored in the request headers, but only when the
        cluster reports that authentication is enabled.
    """
    self.base_url = base_url
    self._headers: Dict[str, str] = {}
    self._cluster_authed = False
    # force_close=True: the connector closes the socket after each response
    # instead of keeping it alive for reuse.
    connector = aiohttp.TCPConnector(force_close=True)
    self.session = aiohttp.ClientSession(connector=connector)
    # Blocking probe of /v1/cluster/auth; updates self._cluster_authed.
    self._check_cluster_authenticated()
    if self._cluster_authed and api_key is not None:
        self._headers["Authorization"] = f"Bearer {api_key}"
|
|
980
|
+
|
|
981
|
+
async def close(self):
    """Close the underlying aiohttp session; later calls are no-ops."""
    session = self.session
    if not session:
        return
    await session.close()
    self.session = None
|
|
986
|
+
|
|
987
|
+
def __del__(self):
    """
    Best-effort close of the aiohttp session when the client is garbage collected.

    Prefer ``await client.close()``; this finalizer only covers callers that
    forget to close the client.
    """
    # getattr: __init__ may have raised before ``session`` was assigned.
    session = getattr(self, "session", None)
    if not session or getattr(session, "closed", False):
        return
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None
    if running_loop is not None:
        # run_until_complete() raises on an already-running loop (leaving the
        # close() coroutine un-awaited), so schedule the close instead.
        running_loop.create_task(self.close())
        return
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    loop.run_until_complete(self.close())
|
|
995
|
+
|
|
996
|
+
def _set_token(self, token: Optional[str]):
    """Store ``token`` as the bearer credential, but only when the cluster enforces auth."""
    if self._cluster_authed and token is not None:
        self._headers["Authorization"] = f"Bearer {token}"
|
|
1000
|
+
|
|
1001
|
+
def _get_token(self) -> Optional[str]:
    """Return the token in the ``Authorization`` header (``Bearer `` stripped), or None if unset."""
    if "Authorization" not in self._headers:
        return None
    return str(self._headers["Authorization"]).replace("Bearer ", "")
|
|
1007
|
+
|
|
1008
|
+
def _check_cluster_authenticated(self):
    """
    Query ``/v1/cluster/auth`` and record whether the cluster enforces authentication.

    Runs synchronously via ``requests`` because it is called from ``__init__``.
    A 404 (older servers without the endpoint) is treated as "no auth".

    Raises
    ------
    RuntimeError
        If the endpoint answers with any status other than 200 or 404.
    """
    import requests

    url = f"{self.base_url}/v1/cluster/auth"
    # Context manager closes the session's connection pool; the body is
    # already fully read because stream=True is not used.
    with requests.Session() as session:
        response = session.get(url)
    # compatible with old version of xinference
    if response.status_code == 404:
        self._cluster_authed = False
        return
    response_data = response.json()
    if response.status_code != 200:
        raise RuntimeError(
            f"Failed to get cluster information, detail: {response_data['detail']}"
        )
    self._cluster_authed = bool(response_data["auth"])
|
|
1025
|
+
|
|
1026
|
+
async def vllm_models(self) -> Dict[str, Any]:
    """
    Fetch ``/v1/models/vllm-supported`` and return its decoded JSON body.

    Raises
    ------
    RuntimeError
        If the server answers with a non-200 status, or the body is not valid JSON.
    """
    url = f"{self.base_url}/v1/models/vllm-supported"
    response = await self.session.get(url, headers=self._headers)
    if response.status != 200:
        # Shared error helper, as the other methods use; indexing
        # response.json()["detail"] raises on non-JSON error bodies.
        detail = await _get_error_string(response)
        await _release_response(response)
        raise RuntimeError(f"Failed to get cluster information, detail: {detail}")

    try:
        return await response.json()
    except Exception as e:
        raise RuntimeError(f"Error parsing JSON response: {e}") from e
    finally:
        await _release_response(response)
|
|
1042
|
+
|
|
1043
|
+
async def login(self, username: str, password: str):
    """
    Exchange credentials for a bearer token and store it in the client headers.

    Does nothing when the cluster reported that authentication is disabled.

    Parameters
    ----------
    username: str
        Account name.
    password: str
        Account password.

    Raises
    ------
    RuntimeError
        If ``/token`` answers with a non-200 status.
    """
    if not self._cluster_authed:
        return
    url = f"{self.base_url}/token"

    payload = {"username": username, "password": password}

    response = await self.session.post(url, json=payload)
    if response.status != 200:
        # Shared error helper, as the other methods use; indexing
        # response.json()["detail"] raises on non-JSON error bodies.
        detail = await _get_error_string(response)
        await _release_response(response)
        raise RuntimeError(f"Failed to login, detail: {detail}")

    response_data = await response.json()
    await _release_response(response)
    # Only bearer token for now
    access_token = response_data["access_token"]
    self._headers["Authorization"] = f"Bearer {access_token}"
|
|
1063
|
+
|
|
1064
|
+
async def list_models(self) -> Dict[str, Dict[str, Any]]:
    """
    List the models currently running on the server.

    Returns
    -------
    Dict[str, Dict[str, Any]]
        Each entry of the response's ``data`` list, keyed by its ``id``.

    Raises
    ------
    RuntimeError
        If the server answers with a non-200 status.
    """
    resp = await self.session.get(f"{self.base_url}/v1/models", headers=self._headers)
    if resp.status != 200:
        raise RuntimeError(
            f"Failed to list model, detail: {await _get_error_string(resp)}"
        )

    body = await resp.json()
    await _release_response(resp)
    models: Dict[str, Dict[str, Any]] = {}
    for entry in body["data"]:
        models[entry["id"]] = entry
    return models
|
|
1087
|
+
|
|
1088
|
+
async def launch_model(
    self,
    model_name: str,
    model_type: str = "LLM",
    model_engine: Optional[str] = None,
    model_uid: Optional[str] = None,
    model_size_in_billions: Optional[Union[int, str, float]] = None,
    model_format: Optional[str] = None,
    quantization: Optional[str] = None,
    replica: int = 1,
    n_worker: int = 1,
    n_gpu: Optional[Union[int, str]] = "auto",
    peft_model_config: Optional[Dict] = None,
    request_limits: Optional[int] = None,
    worker_ip: Optional[str] = None,
    gpu_idx: Optional[Union[int, List[int]]] = None,
    model_path: Optional[str] = None,
    **kwargs,
) -> str:
    """
    Launch the model based on the parameters on the server via RESTful APIs.

    Parameters
    ----------
    model_name: str
        The name of model.
    model_type: str
        type of model.
    model_engine: Optional[str]
        Specify the inference engine of the model when launching LLM.
    model_uid: str
        UID of model, auto generate a UUID if is None.
    model_size_in_billions: Optional[Union[int, str, float]]
        The size (in billions) of the model. Floats are converted to int or str
        before sending.
    model_format: Optional[str]
        The format of the model.
    quantization: Optional[str]
        The quantization of model.
    replica: Optional[int]
        The replica of model, default is 1.
    n_worker: int
        Number of workers to run.
    n_gpu: Optional[Union[int, str]],
        The number of GPUs used by the model, default is "auto". If n_worker>1, means number of GPUs per worker.
        ``n_gpu=None`` means cpu only, ``n_gpu=auto`` lets the system automatically determine the best number of GPUs to use.
    peft_model_config: Optional[Dict]
        - "lora_list": A List of PEFT (Parameter-Efficient Fine-Tuning) model and path.
        - "image_lora_load_kwargs": A Dict of lora load parameters for image model
        - "image_lora_fuse_kwargs": A Dict of lora fuse parameters for image model
    request_limits: Optional[int]
        The number of request limits for this model, default is None.
        ``request_limits=None`` means no limits for this model.
    worker_ip: Optional[str]
        Specify the worker ip where the model is located in a distributed scenario.
    gpu_idx: Optional[Union[int, List[int]]]
        Specify the GPU index where the model is located.
    model_path: Optional[str]
        Model path, if gguf format, should be the file path, otherwise, should be directory of the model.
    **kwargs:
        Any other parameters been specified. e.g. multimodal_projector for multimodal inference with the llama.cpp backend.
        ``wait_ready`` (default True) is consumed here rather than sent in the body:
        when falsy, the request carries the query parameter ``wait_ready=false``
        (presumably so the server returns before the model is ready).

    Returns
    -------
    str
        The unique model_uid for the launched model.

    Raises
    ------
    RuntimeError
        If the server answers with a non-200 status.
    """

    url = f"{self.base_url}/v1/models"

    # convert float to int or string since the RESTful API does not accept float.
    if isinstance(model_size_in_billions, float):
        model_size_in_billions = convert_float_to_int_or_str(model_size_in_billions)

    payload = {
        "model_uid": model_uid,
        "model_name": model_name,
        "model_engine": model_engine,
        "peft_model_config": peft_model_config,
        "model_type": model_type,
        "model_size_in_billions": model_size_in_billions,
        "model_format": model_format,
        "quantization": quantization,
        "replica": replica,
        "n_worker": n_worker,
        "n_gpu": n_gpu,
        "request_limits": request_limits,
        "worker_ip": worker_ip,
        "gpu_idx": gpu_idx,
        "model_path": model_path,
    }

    wait_ready = kwargs.pop("wait_ready", True)

    for key, value in kwargs.items():
        payload[str(key)] = value

    if wait_ready:
        response = await self.session.post(url, json=payload, headers=self._headers)
    else:
        # aiohttp builds query strings with yarl, which raises TypeError for
        # bool values, so the flag must be sent as a string.
        response = await self.session.post(
            url, json=payload, headers=self._headers, params={"wait_ready": "false"}
        )
    if response.status != 200:
        raise RuntimeError(
            f"Failed to launch model, detail: {await _get_error_string(response)}"
        )

    response_data = await response.json()
    await _release_response(response)
    return response_data["model_uid"]
|
|
1199
|
+
|
|
1200
|
+
async def terminate_model(self, model_uid: str):
    """
    Stop a model that is running on the server.

    Parameters
    ----------
    model_uid: str
        UID of the model to terminate.

    Raises
    ------
    RuntimeError
        If the server rejects the request (non-200); the message includes the
        server's error detail.
    """
    endpoint = f"{self.base_url}/v1/models/{model_uid}"
    resp = await self.session.delete(endpoint, headers=self._headers)
    if resp.status == 200:
        await _release_response(resp)
        return
    raise RuntimeError(
        f"Failed to terminate model, detail: {await _get_error_string(resp)}"
    )
|
|
1224
|
+
|
|
1225
|
+
async def get_launch_model_progress(self, model_uid: str) -> dict:
    """
    Query how far the launch of a model has progressed.

    Parameters
    ----------
    model_uid: str
        UID of the model being launched.

    Returns
    -------
    dict
        The decoded JSON body of ``/v1/models/{model_uid}/progress``.

    Raises
    ------
    RuntimeError
        If the server answers with a non-200 status.
    """
    endpoint = f"{self.base_url}/v1/models/{model_uid}/progress"
    resp = await self.session.get(endpoint, headers=self._headers)
    if resp.status != 200:
        raise RuntimeError(
            f"Fail to get model launching progress, detail: {await _get_error_string(resp)}"
        )
    progress = await resp.json()
    await _release_response(resp)
    return progress
|
|
1254
|
+
|
|
1255
|
+
async def cancel_launch_model(self, model_uid: str):
    """
    Ask the server to abort an in-progress model launch.

    Parameters
    ----------
    model_uid: str
        UID of the model whose launch should be cancelled.

    Raises
    ------
    RuntimeError
        If the server answers with a non-200 status.
    """
    endpoint = f"{self.base_url}/v1/models/{model_uid}/cancel"
    resp = await self.session.post(endpoint, headers=self._headers)
    if resp.status == 200:
        await _release_response(resp)
        return
    raise RuntimeError(
        f"Fail to cancel launching model, detail: {await _get_error_string(resp)}"
    )
|
|
1277
|
+
|
|
1278
|
+
async def get_instance_info(self, model_name: str, model_uid: str):
    """
    Return the decoded JSON from ``/v1/models/instances`` filtered by name and UID.

    Raises
    ------
    RuntimeError
        If the server answers with a non-200 status.
    """
    query = {"model_name": model_name, "model_uid": model_uid}
    resp = await self.session.get(
        f"{self.base_url}/v1/models/instances",
        headers=self._headers,
        params=query,
    )
    if resp.status != 200:
        raise RuntimeError("Failed to get instance info")
    info = await resp.json()
    await _release_response(resp)
    return info
|
|
1290
|
+
|
|
1291
|
+
async def _get_supervisor_internal_address(self):
    """
    Return the decoded JSON body of ``/v1/address``.

    Raises
    ------
    RuntimeError
        If the server answers with a non-200 status.
    """
    resp = await self.session.get(f"{self.base_url}/v1/address", headers=self._headers)
    if resp.status != 200:
        raise RuntimeError("Failed to get supervisor internal address")
    address = await resp.json()
    await _release_response(resp)
    return address
|
|
1299
|
+
|
|
1300
|
+
async def get_model(self, model_uid: str) -> AsyncRESTfulModelHandle:
    """
    Build a client-side handle for a model already running on the server.

    The handle class is picked from the description returned by
    ``GET /v1/models/{model_uid}``.

    Parameters
    ----------
    model_uid: str
        UID of the running model.

    Returns
    -------
    AsyncRESTfulModelHandle
        - ``AsyncRESTfulChatModelHandle`` for an LLM whose abilities include "chat";
        - ``AsyncRESTfulGenerateModelHandle`` for an LLM with "generate" (and no "chat");
        - the embedding / image / rerank / audio / video / flexible handle for
          the matching ``model_type``.

    Raises
    ------
    RuntimeError
        If the description request returns a non-200 status.
    ValueError
        If the LLM ability set or the model type is not recognised.
    """
    resp = await self.session.get(
        f"{self.base_url}/v1/models/{model_uid}", headers=self._headers
    )
    if resp.status != 200:
        raise RuntimeError(
            f"Failed to get the model description, detail: {await _get_error_string(resp)}"
        )
    desc = await resp.json()
    await _release_response(resp)

    model_type = desc["model_type"]
    if model_type == "LLM":
        abilities = desc["model_ability"]
        if "chat" in abilities:
            handle_cls = AsyncRESTfulChatModelHandle
        elif "generate" in abilities:
            handle_cls = AsyncRESTfulGenerateModelHandle
        else:
            raise ValueError(f"Unrecognized model ability: {desc['model_ability']}")
    else:
        handle_cls = {
            "embedding": AsyncRESTfulEmbeddingModelHandle,
            "image": AsyncRESTfulImageModelHandle,
            "rerank": AsyncRESTfulRerankModelHandle,
            "audio": AsyncRESTfulAudioModelHandle,
            "video": AsyncRESTfulVideoModelHandle,
            "flexible": AsyncRESTfulFlexibleModelHandle,
        }.get(model_type)
        if handle_cls is None:
            raise ValueError(f"Unknown model type:{desc['model_type']}")
    return handle_cls(model_uid, self.base_url, auth_headers=self._headers)
|
|
1369
|
+
|
|
1370
|
+
async def describe_model(self, model_uid: str):
    """
    Fetch the server's description of a running model.

    Parameters
    ----------
    model_uid: str
        UID of the model to describe.

    Returns
    -------
    dict
        The decoded JSON from ``GET /v1/models/{model_uid}``. Documented keys:
        "model_type", "model_name", "model_lang", "model_ability",
        "model_description", "model_format", "model_size_in_billions",
        "quantization", "revision" and "context_length" (the maximum text
        length the model accepts, input and output combined).

    Raises
    ------
    RuntimeError
        If the server answers with a non-200 status.
    """
    resp = await self.session.get(
        f"{self.base_url}/v1/models/{model_uid}", headers=self._headers
    )
    if resp.status != 200:
        raise RuntimeError(
            f"Failed to get the model description, detail: {await _get_error_string(resp)}"
        )
    description = await resp.json()
    await _release_response(resp)
    return description
|
|
1421
|
+
|
|
1422
|
+
async def register_model(
    self,
    model_type: str,
    model: str,
    persist: bool,
    worker_ip: Optional[str] = None,
):
    """
    Register a custom model.

    Parameters
    ----------
    model_type: str
        The type of model.
    model: str
        The model definition. (refer to: https://inference.readthedocs.io/en/latest/models/custom.html)
    persist: bool
        Sent as the ``persist`` field of the registration request body.
        Presumably controls whether the server stores the registration so it
        survives a restart -- confirm against the server implementation.
    worker_ip: Optional[str]
        The IP address of the worker on which the model is running.
        ``None`` is sent unchanged in the request body.

    Returns
    -------
    Any
        The decoded JSON body of the server's response.

    Raises
    ------
    RuntimeError
        Report failure to register the custom model. Provide details of failure through error message.
    """
    url = f"{self.base_url}/v1/model_registrations/{model_type}"
    request_body = {"model": model, "worker_ip": worker_ip, "persist": persist}
    response = await self.session.post(
        url, json=request_body, headers=self._headers
    )
    if response.status != 200:
        raise RuntimeError(
            f"Failed to register model, detail: {await _get_error_string(response)}"
        )

    response_data = await response.json()
    # The body has been consumed; hand the response to the file's release
    # helper (presumably frees the underlying connection).
    await _release_response(response)
    return response_data
|
|
1461
|
+
|
|
1462
|
+
async def unregister_model(self, model_type: str, model_name: str):
    """
    Unregister a custom model.

    Parameters
    ----------
    model_type: str
        The type of model.
    model_name: str
        The name of the model.

    Returns
    -------
    Any
        The decoded JSON body of the server's response.

    Raises
    ------
    RuntimeError
        Report failure to unregister the custom model. Provide details of failure through error message.
    """
    url = f"{self.base_url}/v1/model_registrations/{model_type}/{model_name}"
    response = await self.session.delete(url, headers=self._headers)
    if response.status != 200:
        # Name the operation that actually failed (a DELETE / unregister).
        raise RuntimeError(
            f"Failed to unregister model, detail: {await _get_error_string(response)}"
        )

    response_data = await response.json()
    await _release_response(response)
    return response_data
|
|
1488
|
+
|
|
1489
|
+
async def list_model_registrations(self, model_type: str) -> List[Dict[str, Any]]:
    """
    Fetch every model registration of one type from the server.

    Parameters
    ----------
    model_type: str
        Model type whose registrations are requested; it is used as the last
        path segment of ``/v1/model_registrations/{model_type}``.

    Returns
    -------
    List[Dict[str, Any]]
        The decoded JSON body of the server's response.

    Raises
    ------
    RuntimeError
        If the server answers with any status other than 200; the message
        carries the error detail extracted from the response.
    """
    endpoint = f"{self.base_url}/v1/model_registrations/{model_type}"
    resp = await self.session.get(endpoint, headers=self._headers)
    if resp.status == 200:
        registrations = await resp.json()
        await _release_response(resp)
        return registrations
    detail = await _get_error_string(resp)
    raise RuntimeError(f"Failed to list model registration, detail: {detail}")
|
|
1519
|
+
|
|
1520
|
+
async def list_cached_models(
    self, model_name: Optional[str] = None, worker_ip: Optional[str] = None
) -> List[Dict[Any, Any]]:
    """
    Fetch the models cached on the server, optionally narrowed down.

    Parameters
    ----------
    model_name: Optional[str]
        Restrict the listing to this model name. It is passed as a query
        parameter after ``_filter_params`` has been applied.
    worker_ip: Optional[str]
        In a distributed deployment, the worker whose cache should be listed.

    Returns
    -------
    List[Dict[Any, Any]]
        The value stored under the ``"list"`` key of the server's JSON
        response, or ``None`` if that key is absent.

    Raises
    ------
    RuntimeError
        If the server answers with any status other than 200, including the
        error detail taken from the response.
    """
    query = _filter_params({"model_name": model_name, "worker_ip": worker_ip})
    resp = await self.session.get(
        f"{self.base_url}/v1/cache/models", headers=self._headers, params=query
    )
    if resp.status == 200:
        payload = await resp.json()
        await _release_response(resp)
        return payload.get("list")
    detail = await _get_error_string(resp)
    raise RuntimeError(f"Failed to list cached model, detail: {detail}")
|
|
1558
|
+
|
|
1559
|
+
async def list_deletable_models(
    self, model_version: str, worker_ip: Optional[str] = None
) -> Dict[str, Any]:
    """
    Get the cached files of a model version that could be deleted from the server.

    Parameters
    ----------
    model_version: str
        The version of the model.
    worker_ip: Optional[str]
        Specify the worker ip where the model is located in a distributed scenario.

    Returns
    -------
    Dict[str, Any]
        The decoded JSON body of the server's response. Presumably it maps
        model names to their cached file locations -- confirm against the
        ``/v1/cache/models/files`` endpoint.

    Raises
    ------
    RuntimeError
        Raised when the request fails, including the reason for the failure.
    """
    url = f"{self.base_url}/v1/cache/models/files"
    params = {
        "model_version": model_version,
        "worker_ip": worker_ip,
    }
    params = _filter_params(params)
    response = await self.session.get(url, headers=self._headers, params=params)
    if response.status != 200:
        raise RuntimeError(
            f"Failed to get paths by model name, detail: {await _get_error_string(response)}"
        )

    response_data = await response.json()
    await _release_response(response)
    return response_data
|
|
1590
|
+
|
|
1591
|
+
async def confirm_and_remove_model(
    self, model_version: str, worker_ip: Optional[str] = None
) -> bool:
    """
    Remove the cached files of the given model version on the server.

    Parameters
    ----------
    model_version: str
        The version of the model.
    worker_ip: Optional[str]
        Specify the worker ip where the model is located in a distributed scenario.

    Returns
    -------
    bool
        The ``"result"`` field of the server's JSON response, or ``False``
        when the response does not contain that field.

    Raises
    ------
    RuntimeError
        Raised when the request fails, including the reason for the failure.
    """
    url = f"{self.base_url}/v1/cache/models"
    params = {
        "model_version": model_version,
        "worker_ip": worker_ip,
    }
    params = _filter_params(params)
    response = await self.session.delete(url, headers=self._headers, params=params)
    if response.status != 200:
        raise RuntimeError(
            f"Failed to remove cached models, detail: {await _get_error_string(response)}"
        )

    response_data = await response.json()
    await _release_response(response)
    return response_data.get("result", False)
|
|
1622
|
+
|
|
1623
|
+
async def get_model_registration(
    self, model_type: str, model_name: str
) -> Dict[str, Any]:
    """
    Get the registration of one model, identified by model type and name, from the server.

    Parameters
    ----------
    model_type: str
        The type of the model.
    model_name: str
        The name of the model.

    Returns
    -------
    Dict[str, Any]
        The decoded JSON body of the server's response describing that
        single registration.

    Raises
    ------
    RuntimeError
        Raised when the request fails, including the reason for the failure.
    """
    url = f"{self.base_url}/v1/model_registrations/{model_type}/{model_name}"
    response = await self.session.get(url, headers=self._headers)
    if response.status != 200:
        raise RuntimeError(
            f"Failed to list model registration, detail: {await _get_error_string(response)}"
        )

    response_data = await response.json()
    await _release_response(response)
    return response_data
|
|
1651
|
+
|
|
1652
|
+
async def query_engine_by_model_name(
    self, model_name: str, model_type: Optional[str] = "LLM"
):
    """
    Look up which engines, and with what parameters, can serve a registered model.

    Parameters
    ----------
    model_name: str
        The name of the model.
    model_type: Optional[str]
        Model type, ``"LLM"`` by default. A falsy value (``None`` or ``""``)
        queries ``/v1/engines/{model_name}`` without a type segment.

    Returns
    -------
    Dict[str, List[Dict[str, Any]]]
        The decoded JSON body of the server's response.

    Raises
    ------
    RuntimeError
        If the server answers with any status other than 200.
    """
    # Only insert the type segment when a type was actually given.
    path = f"{model_type}/{model_name}" if model_type else model_name
    endpoint = f"{self.base_url}/v1/engines/{path}"
    resp = await self.session.get(endpoint, headers=self._headers)
    if resp.status == 200:
        engines = await resp.json()
        await _release_response(resp)
        return engines
    detail = await _get_error_string(resp)
    raise RuntimeError(
        f"Failed to query engine parameters by model name, detail: {detail}"
    )
|
|
1682
|
+
|
|
1683
|
+
async def abort_request(
    self, model_uid: str, request_id: str, block_duration: int = 30
):
    """
    Ask the server to abort a previously submitted request.

    Per the server contract, a request that has already finished or cannot
    be found makes this a no-op. At present this is only supported when
    batching is enabled for models on the transformers backend.

    Parameters
    ----------
    model_uid: str
        Model uid.
    request_id: str
        Request id.
    block_duration: int
        How long the request id stays marked as aborted. With 0 the abort is
        immediate, and it may have no effect if it reaches the server before
        the request itself does.

    Returns
    -------
    Dict
        The decoded JSON body of the server's response (documented as an
        empty dict).

    Raises
    ------
    RuntimeError
        If the server answers with any status other than 200.
    """
    endpoint = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort"
    payload = {"block_duration": block_duration}
    resp = await self.session.post(endpoint, headers=self._headers, json=payload)
    if resp.status == 200:
        outcome = await resp.json()
        await _release_response(resp)
        return outcome
    detail = await _get_error_string(resp)
    raise RuntimeError(f"Failed to abort request, detail: {detail}")
|
|
1717
|
+
|
|
1718
|
+
async def get_workers_info(self):
    """
    Fetch information about the cluster's workers from ``/v1/workers``.

    Returns
    -------
    Any
        The decoded JSON body of the server's response.

    Raises
    ------
    RuntimeError
        If the server answers with any status other than 200.
    """
    resp = await self.session.get(f"{self.base_url}/v1/workers", headers=self._headers)
    if resp.status == 200:
        workers = await resp.json()
        await _release_response(resp)
        return workers
    detail = await _get_error_string(resp)
    raise RuntimeError(f"Failed to get workers info, detail: {detail}")
|
|
1728
|
+
|
|
1729
|
+
async def get_supervisor_info(self):
    """
    Fetch information about the supervisor from ``/v1/supervisor``.

    Returns
    -------
    Any
        The decoded JSON body of the server's response.

    Raises
    ------
    RuntimeError
        If the server answers with any status other than 200.
    """
    endpoint = f"{self.base_url}/v1/supervisor"
    resp = await self.session.get(endpoint, headers=self._headers)
    if resp.status == 200:
        info = await resp.json()
        await _release_response(resp)
        return info
    detail = await _get_error_string(resp)
    raise RuntimeError(f"Failed to get supervisor info, detail: {detail}")
|
|
1739
|
+
|
|
1740
|
+
async def get_progress(self, request_id: str):
    """
    Fetch the progress of a request from ``/v1/requests/{request_id}/progress``.

    Parameters
    ----------
    request_id: str
        Id of the request whose progress is queried.

    Returns
    -------
    Any
        The decoded JSON body of the server's response.

    Raises
    ------
    RuntimeError
        If the server answers with any status other than 200.
    """
    endpoint = f"{self.base_url}/v1/requests/{request_id}/progress"
    resp = await self.session.get(endpoint, headers=self._headers)
    if resp.status == 200:
        progress = await resp.json()
        await _release_response(resp)
        return progress
    detail = await _get_error_string(resp)
    raise RuntimeError(f"Failed to get progress, detail: {detail}")
|
|
1750
|
+
|
|
1751
|
+
async def abort_cluster(self):
    """
    Ask the server to abort the cluster by sending ``DELETE /v1/clusters``.

    Returns
    -------
    Any
        The decoded JSON body of the server's response.

    Raises
    ------
    RuntimeError
        If the server answers with any status other than 200.
    """
    endpoint = f"{self.base_url}/v1/clusters"
    resp = await self.session.delete(endpoint, headers=self._headers)
    if resp.status == 200:
        outcome = await resp.json()
        await _release_response(resp)
        return outcome
    detail = await _get_error_string(resp)
    raise RuntimeError(f"Failed to abort cluster, detail: {detail}")
|