xinference 0.15.3__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/__init__.py +0 -4
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +29 -2
- xinference/client/restful/restful_client.py +10 -0
- xinference/constants.py +7 -3
- xinference/core/image_interface.py +76 -23
- xinference/core/model.py +158 -46
- xinference/core/progress_tracker.py +187 -0
- xinference/core/scheduler.py +10 -7
- xinference/core/supervisor.py +11 -0
- xinference/core/utils.py +9 -0
- xinference/core/worker.py +1 -0
- xinference/deploy/supervisor.py +4 -0
- xinference/model/__init__.py +4 -0
- xinference/model/audio/chattts.py +2 -1
- xinference/model/audio/core.py +0 -2
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/image/core.py +6 -7
- xinference/model/image/scheduler/__init__.py +13 -0
- xinference/model/image/scheduler/flux.py +533 -0
- xinference/model/image/sdapi.py +35 -4
- xinference/model/image/stable_diffusion/core.py +215 -110
- xinference/model/image/utils.py +39 -3
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +185 -17
- xinference/model/llm/llm_family_modelscope.json +124 -12
- xinference/model/llm/transformers/chatglm.py +104 -0
- xinference/model/llm/transformers/cogvlm2.py +2 -1
- xinference/model/llm/transformers/cogvlm2_video.py +2 -0
- xinference/model/llm/transformers/core.py +43 -113
- xinference/model/llm/transformers/deepseek_v2.py +0 -226
- xinference/model/llm/transformers/deepseek_vl.py +2 -0
- xinference/model/llm/transformers/glm4v.py +2 -1
- xinference/model/llm/transformers/intern_vl.py +2 -0
- xinference/model/llm/transformers/internlm2.py +3 -95
- xinference/model/llm/transformers/minicpmv25.py +2 -0
- xinference/model/llm/transformers/minicpmv26.py +2 -0
- xinference/model/llm/transformers/omnilmm.py +2 -0
- xinference/model/llm/transformers/opt.py +68 -0
- xinference/model/llm/transformers/qwen2_audio.py +11 -4
- xinference/model/llm/transformers/qwen2_vl.py +2 -28
- xinference/model/llm/transformers/qwen_vl.py +2 -1
- xinference/model/llm/transformers/utils.py +36 -283
- xinference/model/llm/transformers/yi_vl.py +2 -0
- xinference/model/llm/utils.py +60 -16
- xinference/model/llm/vllm/core.py +68 -9
- xinference/model/llm/vllm/utils.py +0 -1
- xinference/model/utils.py +7 -4
- xinference/model/video/core.py +0 -2
- xinference/utils.py +2 -3
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.e51a356d.js → main.f7da0140.js} +3 -3
- xinference/web/ui/build/static/js/main.f7da0140.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/METADATA +38 -6
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/RECORD +63 -59
- xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
- /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.f7da0140.js.LICENSE.txt} +0 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/LICENSE +0 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/WHEEL +0 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/top_level.txt +0 -0
xinference/__init__.py CHANGED

```diff
@@ -26,13 +26,9 @@ except:
 def _install():
     from xoscar.backends.router import Router
 
-    from .model import _install as install_model
-
     default_router = Router.get_instance_or_empty()
     Router.set_instance(default_router)
 
-    install_model()
-
 
 _install()
 del _install
```
xinference/_version.py CHANGED

```diff
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-…",
+ "date": "2024-10-18T12:49:02+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "…",
- "version": "0.15.3"
+ "full-revisionid": "5f7dea44832a1c41f887b9a01377191894550057",
+ "version": "0.16.0"
 }
 ''' # END VERSION_JSON
 
```
xinference/api/restful_api.py CHANGED

```diff
@@ -524,6 +524,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/requests/{request_id}/progress",
+            self.get_progress,
+            methods=["get"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/images/generations",
             self.create_images,
@@ -1486,6 +1496,17 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def get_progress(self, request_id: str) -> JSONResponse:
+        try:
+            supervisor_ref = await self._get_supervisor_ref()
+            result = {"progress": await supervisor_ref.get_progress(request_id)}
+            return JSONResponse(content=result)
+        except KeyError as e:
+            raise HTTPException(status_code=400, detail=str(e))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_images(self, request: Request) -> Response:
         body = TextToImageRequest.parse_obj(await request.json())
         model_uid = body.model
@@ -1853,10 +1874,16 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
-        from ..model.llm.utils import …
+        from ..model.llm.utils import (
+            GLM4_TOOL_CALL_FAMILY,
+            LLAMA3_TOOL_CALL_FAMILY,
+            QWEN_TOOL_CALL_FAMILY,
+        )
 
         model_family = desc.get("model_family", "")
-        function_call_models = …
+        function_call_models = (
+            QWEN_TOOL_CALL_FAMILY + GLM4_TOOL_CALL_FAMILY + LLAMA3_TOOL_CALL_FAMILY
+        )
 
         if model_family not in function_call_models:
             if body.tools:
```
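Once registered, the route can be polled with a plain HTTP GET and returns a JSON body shaped by the `get_progress` handler above. A minimal sketch (the host, port, and request_id values are illustrative assumptions, not values from this diff):

```python
# Minimal sketch of polling the new endpoint; host, port, and request_id
# are illustrative assumptions.
import requests

base_url = "http://127.0.0.1:9997"  # assumption: default local endpoint
request_id = "my-request-id"  # the id passed along with the generation call

resp = requests.get(f"{base_url}/v1/requests/{request_id}/progress")
if resp.status_code == 200:
    print(resp.json()["progress"])  # a float reported by the supervisor
else:
    # 400 for an unknown request_id (KeyError), 500 for anything else
    print(resp.status_code, resp.text)
```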
xinference/client/restful/restful_client.py CHANGED

```diff
@@ -1385,6 +1385,16 @@ class Client:
         response_json = response.json()
         return response_json
 
+    def get_progress(self, request_id: str):
+        url = f"{self.base_url}/v1/requests/{request_id}/progress"
+        response = requests.get(url, headers=self._headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get progress, detail: {_get_error_string(response)}"
+            )
+        response_json = response.json()
+        return response_json
+
     def abort_cluster(self):
         url = f"{self.base_url}/v1/clusters"
         response = requests.delete(url, headers=self._headers)
```
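The client method above is a thin wrapper over that endpoint. A hedged sketch of the intended polling pattern, assuming a running server and an image model already launched with model uid "sd" (both assumptions); it mirrors the background-thread loop added to image_interface.py further down:

```python
# Sketch only: the endpoint and the "sd" model uid are assumptions.
import threading
import time
import uuid

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model = client.get_model("sd")
request_id = str(uuid.uuid4())

# Run the blocking generation call in a background thread so the main
# thread is free to poll progress under the same request_id.
t = threading.Thread(
    target=model.text_to_image,
    kwargs=dict(prompt="a red panda", request_id=request_id),
)
t.start()
while t.is_alive():
    try:
        print(client.get_progress(request_id))  # e.g. {"progress": 0.42}
    except RuntimeError:
        pass  # request not registered with the tracker yet
    time.sleep(1)
t.join()
```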
xinference/constants.py CHANGED

```diff
@@ -27,7 +27,8 @@ XINFERENCE_ENV_HEALTH_CHECK_INTERVAL = "XINFERENCE_HEALTH_CHECK_INTERVAL"
 XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT = "XINFERENCE_HEALTH_CHECK_TIMEOUT"
 XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
 XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS"
-
+XINFERENCE_ENV_DOWNLOAD_MAX_ATTEMPTS = "XINFERENCE_DOWNLOAD_MAX_ATTEMPTS"
+XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE = "XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE"
 
 
 def get_xinference_home() -> str:
@@ -79,6 +80,9 @@ XINFERENCE_DISABLE_HEALTH_CHECK = bool(
 XINFERENCE_DISABLE_METRICS = bool(
     int(os.environ.get(XINFERENCE_ENV_DISABLE_METRICS, 0))
 )
-
-
+XINFERENCE_DOWNLOAD_MAX_ATTEMPTS = int(
+    os.environ.get(XINFERENCE_ENV_DOWNLOAD_MAX_ATTEMPTS, 3)
+)
+XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE = os.environ.get(
+    XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE, None
 )
```
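Both knobs are read once when constants.py is imported, so they must be set before xinference starts. A short sketch (the values are illustrative; per core/model.py later in this diff, the batching size must exactly match the size argument of every batched request, and only the FLUX.1 models qualify):

```python
# Sketch: set the new knobs before xinference is imported or launched.
import os

# Retry model downloads up to 5 times instead of the default 3.
os.environ["XINFERENCE_DOWNLOAD_MAX_ATTEMPTS"] = "5"

# Opt in to text-to-image batching. The value stays a raw string and every
# batched request must ask for exactly this size (see core/model.py below).
os.environ["XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE"] = "1024*1024"
```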
xinference/core/image_interface.py CHANGED

```diff
@@ -16,6 +16,9 @@ import base64
 import io
 import logging
 import os
+import threading
+import time
+import uuid
 from typing import Dict, List, Optional, Union
 
 import gradio as gr
@@ -84,6 +87,7 @@ class ImageInterface:
         num_inference_steps: int,
         negative_prompt: Optional[str] = None,
         sampler_name: Optional[str] = None,
+        progress=gr.Progress(),
     ) -> PIL.Image.Image:
         from ..client import RESTfulClient
 
@@ -99,19 +103,43 @@ class ImageInterface:
         )
         sampler_name = None if sampler_name == "default" else sampler_name
 
-        response = model.text_to_image(…)
+        response = None
+        exc = None
+        request_id = str(uuid.uuid4())
+
+        def run_in_thread():
+            nonlocal exc, response
+            try:
+                response = model.text_to_image(
+                    request_id=request_id,
+                    prompt=prompt,
+                    n=n,
+                    size=size,
+                    num_inference_steps=num_inference_steps,
+                    guidance_scale=guidance_scale,
+                    negative_prompt=negative_prompt,
+                    sampler_name=sampler_name,
+                    response_format="b64_json",
+                )
+            except Exception as e:
+                exc = e
+
+        t = threading.Thread(target=run_in_thread)
+        t.start()
+        while t.is_alive():
+            try:
+                cur_progress = client.get_progress(request_id)["progress"]
+            except (KeyError, RuntimeError):
+                cur_progress = 0.0
+
+            progress(cur_progress, desc="Generating images")
+            time.sleep(1)
+
+        if exc:
+            raise exc
 
         images = []
-        for image_dict in response["data"]:
+        for image_dict in response["data"]:  # type: ignore
             assert image_dict["b64_json"] is not None
             image_data = base64.b64decode(image_dict["b64_json"])
             image = PIL.Image.open(io.BytesIO(image_data))
@@ -184,6 +212,7 @@ class ImageInterface:
         num_inference_steps: int,
         padding_image_to_multiple: int,
         sampler_name: Optional[str] = None,
+        progress=gr.Progress(),
     ) -> PIL.Image.Image:
         from ..client import RESTfulClient
 
@@ -205,20 +234,44 @@ class ImageInterface:
         bio = io.BytesIO()
         image.save(bio, format="png")
 
-        response = model.image_to_image(…)
+        response = None
+        exc = None
+        request_id = str(uuid.uuid4())
+
+        def run_in_thread():
+            nonlocal exc, response
+            try:
+                response = model.image_to_image(
+                    request_id=request_id,
+                    prompt=prompt,
+                    negative_prompt=negative_prompt,
+                    n=n,
+                    image=bio.getvalue(),
+                    size=size,
+                    response_format="b64_json",
+                    num_inference_steps=num_inference_steps,
+                    padding_image_to_multiple=padding_image_to_multiple,
+                    sampler_name=sampler_name,
+                )
+            except Exception as e:
+                exc = e
+
+        t = threading.Thread(target=run_in_thread)
+        t.start()
+        while t.is_alive():
+            try:
+                cur_progress = client.get_progress(request_id)["progress"]
+            except (KeyError, RuntimeError):
+                cur_progress = 0.0
+
+            progress(cur_progress, desc="Generating images")
+            time.sleep(1)
+
+        if exc:
+            raise exc
 
         images = []
-        for image_dict in response["data"]:
+        for image_dict in response["data"]:  # type: ignore
             assert image_dict["b64_json"] is not None
             image_data = base64.b64decode(image_dict["b64_json"])
             image = PIL.Image.open(io.BytesIO(image_data))
```
xinference/core/model.py CHANGED

```diff
@@ -41,9 +41,10 @@ from typing import (
 import sse_starlette.sse
 import xoscar as xo
 
-from ..constants import …
+from ..constants import XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE
 
 if TYPE_CHECKING:
+    from .progress_tracker import ProgressTrackerActor
     from .worker import WorkerActor
     from ..model.llm.core import LLM
     from ..model.core import ModelDescription
@@ -73,6 +74,8 @@ XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = [
     "MiniCPM-V-2.6",
 ]
 
+XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS = ["FLUX.1-dev", "FLUX.1-schnell"]
+
 
 def request_limit(fn):
     """
@@ -152,6 +155,16 @@
                 f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
             )
 
+        if self.allow_batching_for_text_to_image():
+            try:
+                assert self._text_to_image_scheduler_ref is not None
+                await xo.destroy_actor(self._text_to_image_scheduler_ref)
+                del self._text_to_image_scheduler_ref
+            except Exception as e:
+                logger.debug(
+                    f"Destroy text_to_image scheduler actor failed, address: {self.address}, error: {e}"
+                )
+
         if hasattr(self._model, "stop") and callable(self._model.stop):
             self._model.stop()
 
@@ -177,6 +190,7 @@
 
     def __init__(
         self,
+        supervisor_address: str,
         worker_address: str,
         model: "LLM",
         model_description: Optional["ModelDescription"] = None,
@@ -188,6 +202,7 @@
         from ..model.llm.transformers.core import PytorchModel
         from ..model.llm.vllm.core import VLLMModel
 
+        self._supervisor_address = supervisor_address
         self._worker_address = worker_address
         self._model = model
         self._model_description = (
@@ -205,6 +220,7 @@
             else asyncio.locks.Lock()
         )
         self._worker_ref = None
+        self._progress_tracker_ref = None
         self._serve_count = 0
         self._metrics_labels = {
             "type": self._model_description.get("model_type", "unknown"),
@@ -216,6 +232,7 @@
         self._loop: Optional[asyncio.AbstractEventLoop] = None
 
         self._scheduler_ref = None
+        self._text_to_image_scheduler_ref = None
 
     async def __post_create__(self):
        self._loop = asyncio.get_running_loop()
@@ -229,6 +246,15 @@
                 uid=SchedulerActor.gen_uid(self.model_uid(), self._model.rep_id),
             )
 
+        if self.allow_batching_for_text_to_image():
+            from ..model.image.scheduler.flux import FluxBatchSchedulerActor
+
+            self._text_to_image_scheduler_ref = await xo.create_actor(
+                FluxBatchSchedulerActor,
+                address=self.address,
+                uid=FluxBatchSchedulerActor.gen_uid(self.model_uid()),
+            )
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):
@@ -275,6 +301,28 @@
         )
         return self._worker_ref
 
+    async def _get_progress_tracker_ref(
+        self,
+    ) -> xo.ActorRefType["ProgressTrackerActor"]:
+        from .progress_tracker import ProgressTrackerActor
+
+        if self._progress_tracker_ref is None:
+            self._progress_tracker_ref = await xo.actor_ref(
+                address=self._supervisor_address, uid=ProgressTrackerActor.default_uid()
+            )
+        return self._progress_tracker_ref
+
+    async def _get_progressor(self, request_id: str):
+        from .progress_tracker import Progressor
+
+        progressor = Progressor(
+            request_id,
+            await self._get_progress_tracker_ref(),
+            asyncio.get_running_loop(),
+        )
+        await progressor.start()
+        return progressor
+
     def is_vllm_backend(self) -> bool:
         from ..model.llm.vllm.core import VLLMModel
 
@@ -285,10 +333,8 @@
 
         model_ability = self._model_description.get("model_ability", [])
 
-        condition = …
-            …
-        )
-        if condition and "vision" in model_ability:
+        condition = isinstance(self._model, PytorchModel)
+        if condition and ("vision" in model_ability or "audio" in model_ability):
             if (
                 self._model.model_family.model_name
                 in XINFERENCE_BATCHING_ALLOWED_VISION_MODELS
@@ -305,6 +351,26 @@
             return False
         return condition
 
+    def allow_batching_for_text_to_image(self) -> bool:
+        from ..model.image.stable_diffusion.core import DiffusionModel
+
+        condition = XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE is not None and isinstance(
+            self._model, DiffusionModel
+        )
+
+        if condition:
+            model_name = self._model._model_spec.model_name  # type: ignore
+            if model_name in XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS:
+                return True
+            else:
+                logger.warning(
+                    f"Currently for image models with text_to_image ability, "
+                    f"xinference only supports {', '.join(XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS)} for batching. "
+                    f"Your model {model_name} is disqualified."
+                )
+                return False
+        return condition
+
     async def load(self):
         self._model.load()
         if self.allow_batching():
@@ -312,6 +378,11 @@
             logger.debug(
                 f"Batching enabled for model: {self.model_uid()}, max_num_seqs: {self._model.get_max_num_seqs()}"
             )
+        if self.allow_batching_for_text_to_image():
+            await self._text_to_image_scheduler_ref.set_model(self._model)
+            logger.debug(
+                f"Batching enabled for model: {self.model_uid()}, max_num_images: {self._model.get_max_num_images_for_batching()}"
+            )
 
     def model_uid(self):
         return (
@@ -591,12 +662,16 @@
         )
 
     async def abort_request(self, request_id: str) -> str:
-        from .… import AbortRequestMessage
+        from .utils import AbortRequestMessage
 
         if self.allow_batching():
             if self._scheduler_ref is None:
                 return AbortRequestMessage.NOT_FOUND.name
             return await self._scheduler_ref.abort_request(request_id)
+        elif self.allow_batching_for_text_to_image():
+            if self._text_to_image_scheduler_ref is None:
+                return AbortRequestMessage.NOT_FOUND.name
+            return await self._text_to_image_scheduler_ref.abort_request(request_id)
         return AbortRequestMessage.NO_OP.name
 
     @request_limit
@@ -721,6 +796,22 @@
             f"Model {self._model.model_spec} is not for creating speech."
         )
 
+    async def handle_image_batching_request(self, unique_id, *args, **kwargs):
+        size = args[2]
+        if XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE != size:
+            raise RuntimeError(
+                f"The image size: {size} of text_to_image for batching "
+                f"must be the same as the environment variable: {XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE} you set."
+            )
+        assert self._loop is not None
+        future = ConcurrentFuture()
+        await self._text_to_image_scheduler_ref.add_request(
+            unique_id, future, *args, **kwargs
+        )
+        fut = asyncio.wrap_future(future, loop=self._loop)
+        result = await fut
+        return await asyncio.to_thread(json_dumps, result)
+
     @request_limit
     @log_async(logger=logger)
     async def text_to_image(
@@ -732,17 +823,26 @@
         *args,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         if hasattr(self._model, "text_to_image"):
-            return await self._call_wrapper_json(
-                self._model.text_to_image,
-                prompt,
-                n,
-                size,
-                response_format,
-                *args,
-                **kwargs,
-            )
+            if self.allow_batching_for_text_to_image():
+                unique_id = kwargs.pop("request_id", None)
+                return await self.handle_image_batching_request(
+                    unique_id, prompt, n, size, response_format, *args, **kwargs
+                )
+            else:
+                progressor = kwargs["progressor"] = await self._get_progressor(
+                    kwargs.pop("request_id", None)
+                )
+                with progressor:
+                    return await self._call_wrapper_json(
+                        self._model.text_to_image,
+                        prompt,
+                        n,
+                        size,
+                        response_format,
+                        *args,
+                        **kwargs,
+                    )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating image."
         )
@@ -753,12 +853,15 @@
         self,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         if hasattr(self._model, "txt2img"):
-            return await self._call_wrapper_json(
-                self._model.txt2img,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.txt2img,
+                    **kwargs,
+                )
         raise AttributeError(f"Model {self._model.model_spec} is not for txt2img.")
 
     @log_async(
@@ -776,19 +879,22 @@
         *args,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         kwargs["negative_prompt"] = negative_prompt
         if hasattr(self._model, "image_to_image"):
-            return await self._call_wrapper_json(
-                self._model.image_to_image,
-                image,
-                prompt,
-                n,
-                size,
-                response_format,
-                *args,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.image_to_image,
+                    image,
+                    prompt,
+                    n,
+                    size,
+                    response_format,
+                    *args,
+                    **kwargs,
+                )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating image."
         )
@@ -799,12 +905,15 @@
         self,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         if hasattr(self._model, "img2img"):
-            return await self._call_wrapper_json(
-                self._model.img2img,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
            )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.img2img,
+                    **kwargs,
+                )
         raise AttributeError(f"Model {self._model.model_spec} is not for img2img.")
 
     @log_async(
@@ -823,20 +932,23 @@
         *args,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
+        kwargs["negative_prompt"] = negative_prompt
         if hasattr(self._model, "inpainting"):
-            return await self._call_wrapper_json(
-                self._model.inpainting,
-                image,
-                mask_image,
-                prompt,
-                negative_prompt,
-                n,
-                size,
-                response_format,
-                *args,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.inpainting,
+                    image,
+                    mask_image,
+                    prompt,
+                    n,
+                    size,
+                    response_format,
+                    *args,
+                    **kwargs,
+                )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating image."
         )
```