xinference 0.15.4__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference has been flagged as potentially problematic; see the package's registry page for details.
- xinference/__init__.py +0 -4
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +48 -0
- xinference/client/restful/restful_client.py +19 -0
- xinference/constants.py +4 -4
- xinference/core/chat_interface.py +5 -1
- xinference/core/image_interface.py +5 -1
- xinference/core/model.py +195 -34
- xinference/core/scheduler.py +10 -7
- xinference/core/utils.py +9 -0
- xinference/model/__init__.py +4 -0
- xinference/model/audio/chattts.py +25 -14
- xinference/model/audio/model_spec.json +1 -1
- xinference/model/audio/model_spec_modelscope.json +1 -1
- xinference/model/embedding/model_spec.json +1 -1
- xinference/model/image/core.py +59 -4
- xinference/model/image/model_spec.json +24 -3
- xinference/model/image/model_spec_modelscope.json +25 -3
- xinference/model/image/ocr/__init__.py +13 -0
- xinference/model/image/ocr/got_ocr2.py +76 -0
- xinference/model/image/scheduler/__init__.py +13 -0
- xinference/model/image/scheduler/flux.py +533 -0
- xinference/model/image/stable_diffusion/core.py +8 -34
- xinference/model/image/stable_diffusion/mlx.py +221 -0
- xinference/model/image/utils.py +39 -3
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +178 -1
- xinference/model/llm/llm_family_modelscope.json +119 -0
- xinference/model/llm/transformers/chatglm.py +104 -0
- xinference/model/llm/transformers/core.py +37 -111
- xinference/model/llm/transformers/deepseek_v2.py +0 -226
- xinference/model/llm/transformers/internlm2.py +3 -95
- xinference/model/llm/transformers/opt.py +68 -0
- xinference/model/llm/transformers/utils.py +4 -284
- xinference/model/llm/utils.py +2 -2
- xinference/model/llm/vllm/core.py +16 -1
- xinference/thirdparty/mlx/__init__.py +13 -0
- xinference/thirdparty/mlx/flux/__init__.py +15 -0
- xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
- xinference/thirdparty/mlx/flux/clip.py +154 -0
- xinference/thirdparty/mlx/flux/datasets.py +75 -0
- xinference/thirdparty/mlx/flux/flux.py +247 -0
- xinference/thirdparty/mlx/flux/layers.py +302 -0
- xinference/thirdparty/mlx/flux/lora.py +76 -0
- xinference/thirdparty/mlx/flux/model.py +134 -0
- xinference/thirdparty/mlx/flux/sampler.py +56 -0
- xinference/thirdparty/mlx/flux/t5.py +244 -0
- xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
- xinference/thirdparty/mlx/flux/trainer.py +98 -0
- xinference/thirdparty/mlx/flux/utils.py +179 -0
- xinference/utils.py +2 -3
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.e51a356d.js → main.b76aeeb7.js} +3 -3
- xinference/web/ui/build/static/js/main.b76aeeb7.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/32ea2c04cf0bba2761b4883d2c40cc259952c94d2d6bb774e510963ca37aac0a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/METADATA +49 -10
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/RECORD +64 -44
- xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
- /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.b76aeeb7.js.LICENSE.txt} +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/LICENSE +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/WHEEL +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/top_level.txt +0 -0
xinference/__init__.py
CHANGED
|
@@ -26,13 +26,9 @@ except:
|
|
|
26
26
|
def _install():
|
|
27
27
|
from xoscar.backends.router import Router
|
|
28
28
|
|
|
29
|
-
from .model import _install as install_model
|
|
30
|
-
|
|
31
29
|
default_router = Router.get_instance_or_empty()
|
|
32
30
|
Router.set_instance(default_router)
|
|
33
31
|
|
|
34
|
-
install_model()
|
|
35
|
-
|
|
36
32
|
|
|
37
33
|
_install()
|
|
38
34
|
del _install
|
xinference/_version.py
CHANGED
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2024-10-
|
|
11
|
+
"date": "2024-10-25T12:51:06+0800",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "0.
|
|
14
|
+
"full-revisionid": "d4cd7b15104c16838e3c562cf2d33337e3d38897",
|
|
15
|
+
"version": "0.16.1"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
xinference/api/restful_api.py
CHANGED
|
@@ -567,6 +567,16 @@ class RESTfulAPI:
|
|
|
567
567
|
else None
|
|
568
568
|
),
|
|
569
569
|
)
|
|
570
|
+
self._router.add_api_route(
|
|
571
|
+
"/v1/images/ocr",
|
|
572
|
+
self.create_ocr,
|
|
573
|
+
methods=["POST"],
|
|
574
|
+
dependencies=(
|
|
575
|
+
[Security(self._auth_service, scopes=["models:read"])]
|
|
576
|
+
if self.is_authenticated()
|
|
577
|
+
else None
|
|
578
|
+
),
|
|
579
|
+
)
|
|
570
580
|
# SD WebUI API
|
|
571
581
|
self._router.add_api_route(
|
|
572
582
|
"/sdapi/v1/options",
|
|
@@ -1754,6 +1764,44 @@ class RESTfulAPI:
|
|
|
1754
1764
|
await self._report_error_event(model_uid, str(e))
|
|
1755
1765
|
raise HTTPException(status_code=500, detail=str(e))
|
|
1756
1766
|
|
|
1767
|
+
async def create_ocr(
|
|
1768
|
+
self,
|
|
1769
|
+
model: str = Form(...),
|
|
1770
|
+
image: UploadFile = File(media_type="application/octet-stream"),
|
|
1771
|
+
kwargs: Optional[str] = Form(None),
|
|
1772
|
+
) -> Response:
|
|
1773
|
+
model_uid = model
|
|
1774
|
+
try:
|
|
1775
|
+
model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
|
|
1776
|
+
except ValueError as ve:
|
|
1777
|
+
logger.error(str(ve), exc_info=True)
|
|
1778
|
+
await self._report_error_event(model_uid, str(ve))
|
|
1779
|
+
raise HTTPException(status_code=400, detail=str(ve))
|
|
1780
|
+
except Exception as e:
|
|
1781
|
+
logger.error(e, exc_info=True)
|
|
1782
|
+
await self._report_error_event(model_uid, str(e))
|
|
1783
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
1784
|
+
|
|
1785
|
+
try:
|
|
1786
|
+
if kwargs is not None:
|
|
1787
|
+
parsed_kwargs = json.loads(kwargs)
|
|
1788
|
+
else:
|
|
1789
|
+
parsed_kwargs = {}
|
|
1790
|
+
im = Image.open(image.file)
|
|
1791
|
+
text = await model_ref.ocr(
|
|
1792
|
+
image=im,
|
|
1793
|
+
**parsed_kwargs,
|
|
1794
|
+
)
|
|
1795
|
+
return Response(content=text, media_type="text/plain")
|
|
1796
|
+
except RuntimeError as re:
|
|
1797
|
+
logger.error(re, exc_info=True)
|
|
1798
|
+
await self._report_error_event(model_uid, str(re))
|
|
1799
|
+
raise HTTPException(status_code=400, detail=str(re))
|
|
1800
|
+
except Exception as e:
|
|
1801
|
+
logger.error(e, exc_info=True)
|
|
1802
|
+
await self._report_error_event(model_uid, str(e))
|
|
1803
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
1804
|
+
|
|
1757
1805
|
async def create_flexible_infer(self, request: Request) -> Response:
|
|
1758
1806
|
payload = await request.json()
|
|
1759
1807
|
|
|
@@ -369,6 +369,25 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
|
|
|
369
369
|
response_data = response.json()
|
|
370
370
|
return response_data
|
|
371
371
|
|
|
372
|
+
def ocr(self, image: Union[str, bytes], **kwargs):
|
|
373
|
+
url = f"{self._base_url}/v1/images/ocr"
|
|
374
|
+
params = {
|
|
375
|
+
"model": self._model_uid,
|
|
376
|
+
"kwargs": json.dumps(kwargs),
|
|
377
|
+
}
|
|
378
|
+
files: List[Any] = []
|
|
379
|
+
for key, value in params.items():
|
|
380
|
+
files.append((key, (None, value)))
|
|
381
|
+
files.append(("image", ("image", image, "application/octet-stream")))
|
|
382
|
+
response = requests.post(url, files=files, headers=self.auth_headers)
|
|
383
|
+
if response.status_code != 200:
|
|
384
|
+
raise RuntimeError(
|
|
385
|
+
f"Failed to ocr the images, detail: {_get_error_string(response)}"
|
|
386
|
+
)
|
|
387
|
+
|
|
388
|
+
response_data = response.json()
|
|
389
|
+
return response_data
|
|
390
|
+
|
|
372
391
|
|
|
373
392
|
class RESTfulVideoModelHandle(RESTfulModelHandle):
|
|
374
393
|
def text_to_video(
|
xinference/constants.py
CHANGED
|
@@ -27,8 +27,8 @@ XINFERENCE_ENV_HEALTH_CHECK_INTERVAL = "XINFERENCE_HEALTH_CHECK_INTERVAL"
|
|
|
27
27
|
XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT = "XINFERENCE_HEALTH_CHECK_TIMEOUT"
|
|
28
28
|
XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
|
|
29
29
|
XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS"
|
|
30
|
-
XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING = "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"
|
|
31
30
|
XINFERENCE_ENV_DOWNLOAD_MAX_ATTEMPTS = "XINFERENCE_DOWNLOAD_MAX_ATTEMPTS"
|
|
31
|
+
XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE = "XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE"
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
def get_xinference_home() -> str:
|
|
@@ -80,9 +80,9 @@ XINFERENCE_DISABLE_HEALTH_CHECK = bool(
|
|
|
80
80
|
XINFERENCE_DISABLE_METRICS = bool(
|
|
81
81
|
int(os.environ.get(XINFERENCE_ENV_DISABLE_METRICS, 0))
|
|
82
82
|
)
|
|
83
|
-
XINFERENCE_TRANSFORMERS_ENABLE_BATCHING = bool(
|
|
84
|
-
int(os.environ.get(XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING, 0))
|
|
85
|
-
)
|
|
86
83
|
XINFERENCE_DOWNLOAD_MAX_ATTEMPTS = int(
|
|
87
84
|
os.environ.get(XINFERENCE_ENV_DOWNLOAD_MAX_ATTEMPTS, 3)
|
|
88
85
|
)
|
|
86
|
+
XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE = os.environ.get(
|
|
87
|
+
XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE, None
|
|
88
|
+
)
|
|
@@ -74,7 +74,11 @@ class GradioInterface:
|
|
|
74
74
|
# Gradio initiates the queue during a startup event, but since the app has already been
|
|
75
75
|
# started, that event will not run, so manually invoke the startup events.
|
|
76
76
|
# See: https://github.com/gradio-app/gradio/issues/5228
|
|
77
|
-
|
|
77
|
+
try:
|
|
78
|
+
interface.run_startup_events()
|
|
79
|
+
except AttributeError:
|
|
80
|
+
# compatibility
|
|
81
|
+
interface.startup_events()
|
|
78
82
|
favicon_path = os.path.join(
|
|
79
83
|
os.path.dirname(os.path.abspath(__file__)),
|
|
80
84
|
os.path.pardir,
|
|
@@ -63,7 +63,11 @@ class ImageInterface:
|
|
|
63
63
|
# Gradio initiates the queue during a startup event, but since the app has already been
|
|
64
64
|
# started, that event will not run, so manually invoke the startup events.
|
|
65
65
|
# See: https://github.com/gradio-app/gradio/issues/5228
|
|
66
|
-
|
|
66
|
+
try:
|
|
67
|
+
interface.run_startup_events()
|
|
68
|
+
except AttributeError:
|
|
69
|
+
# compatibility
|
|
70
|
+
interface.startup_events()
|
|
67
71
|
favicon_path = os.path.join(
|
|
68
72
|
os.path.dirname(os.path.abspath(__file__)),
|
|
69
73
|
os.path.pardir,
|
xinference/core/model.py
CHANGED
|
@@ -17,10 +17,10 @@ import functools
|
|
|
17
17
|
import inspect
|
|
18
18
|
import json
|
|
19
19
|
import os
|
|
20
|
+
import queue
|
|
20
21
|
import time
|
|
21
22
|
import types
|
|
22
23
|
import uuid
|
|
23
|
-
import weakref
|
|
24
24
|
from asyncio.queues import Queue
|
|
25
25
|
from asyncio.tasks import wait_for
|
|
26
26
|
from concurrent.futures import Future as ConcurrentFuture
|
|
@@ -32,7 +32,6 @@ from typing import (
|
|
|
32
32
|
Callable,
|
|
33
33
|
Dict,
|
|
34
34
|
Generator,
|
|
35
|
-
Iterator,
|
|
36
35
|
List,
|
|
37
36
|
Optional,
|
|
38
37
|
Union,
|
|
@@ -41,7 +40,7 @@ from typing import (
|
|
|
41
40
|
import sse_starlette.sse
|
|
42
41
|
import xoscar as xo
|
|
43
42
|
|
|
44
|
-
from ..constants import
|
|
43
|
+
from ..constants import XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE
|
|
45
44
|
|
|
46
45
|
if TYPE_CHECKING:
|
|
47
46
|
from .progress_tracker import ProgressTrackerActor
|
|
@@ -74,6 +73,8 @@ XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = [
|
|
|
74
73
|
"MiniCPM-V-2.6",
|
|
75
74
|
]
|
|
76
75
|
|
|
76
|
+
XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS = ["FLUX.1-dev", "FLUX.1-schnell"]
|
|
77
|
+
|
|
77
78
|
|
|
78
79
|
def request_limit(fn):
|
|
79
80
|
"""
|
|
@@ -153,6 +154,16 @@ class ModelActor(xo.StatelessActor):
|
|
|
153
154
|
f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
|
|
154
155
|
)
|
|
155
156
|
|
|
157
|
+
if self.allow_batching_for_text_to_image():
|
|
158
|
+
try:
|
|
159
|
+
assert self._text_to_image_scheduler_ref is not None
|
|
160
|
+
await xo.destroy_actor(self._text_to_image_scheduler_ref)
|
|
161
|
+
del self._text_to_image_scheduler_ref
|
|
162
|
+
except Exception as e:
|
|
163
|
+
logger.debug(
|
|
164
|
+
f"Destroy text_to_image scheduler actor failed, address: {self.address}, error: {e}"
|
|
165
|
+
)
|
|
166
|
+
|
|
156
167
|
if hasattr(self._model, "stop") and callable(self._model.stop):
|
|
157
168
|
self._model.stop()
|
|
158
169
|
|
|
@@ -197,9 +208,8 @@ class ModelActor(xo.StatelessActor):
|
|
|
197
208
|
model_description.to_dict() if model_description else {}
|
|
198
209
|
)
|
|
199
210
|
self._request_limits = request_limits
|
|
200
|
-
|
|
201
|
-
self.
|
|
202
|
-
self._current_generator = lambda: None
|
|
211
|
+
self._pending_requests: asyncio.Queue = asyncio.Queue()
|
|
212
|
+
self._handle_pending_requests_task = None
|
|
203
213
|
self._lock = (
|
|
204
214
|
None
|
|
205
215
|
if isinstance(
|
|
@@ -220,10 +230,15 @@ class ModelActor(xo.StatelessActor):
|
|
|
220
230
|
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
|
221
231
|
|
|
222
232
|
self._scheduler_ref = None
|
|
233
|
+
self._text_to_image_scheduler_ref = None
|
|
223
234
|
|
|
224
235
|
async def __post_create__(self):
|
|
225
236
|
self._loop = asyncio.get_running_loop()
|
|
226
237
|
|
|
238
|
+
self._handle_pending_requests_task = asyncio.create_task(
|
|
239
|
+
self._handle_pending_requests()
|
|
240
|
+
)
|
|
241
|
+
|
|
227
242
|
if self.allow_batching():
|
|
228
243
|
from .scheduler import SchedulerActor
|
|
229
244
|
|
|
@@ -233,6 +248,15 @@ class ModelActor(xo.StatelessActor):
|
|
|
233
248
|
uid=SchedulerActor.gen_uid(self.model_uid(), self._model.rep_id),
|
|
234
249
|
)
|
|
235
250
|
|
|
251
|
+
if self.allow_batching_for_text_to_image():
|
|
252
|
+
from ..model.image.scheduler.flux import FluxBatchSchedulerActor
|
|
253
|
+
|
|
254
|
+
self._text_to_image_scheduler_ref = await xo.create_actor(
|
|
255
|
+
FluxBatchSchedulerActor,
|
|
256
|
+
address=self.address,
|
|
257
|
+
uid=FluxBatchSchedulerActor.gen_uid(self.model_uid()),
|
|
258
|
+
)
|
|
259
|
+
|
|
236
260
|
async def _record_completion_metrics(
|
|
237
261
|
self, duration, completion_tokens, prompt_tokens
|
|
238
262
|
):
|
|
@@ -311,10 +335,8 @@ class ModelActor(xo.StatelessActor):
|
|
|
311
335
|
|
|
312
336
|
model_ability = self._model_description.get("model_ability", [])
|
|
313
337
|
|
|
314
|
-
condition =
|
|
315
|
-
|
|
316
|
-
)
|
|
317
|
-
if condition and "vision" in model_ability:
|
|
338
|
+
condition = isinstance(self._model, PytorchModel)
|
|
339
|
+
if condition and ("vision" in model_ability or "audio" in model_ability):
|
|
318
340
|
if (
|
|
319
341
|
self._model.model_family.model_name
|
|
320
342
|
in XINFERENCE_BATCHING_ALLOWED_VISION_MODELS
|
|
@@ -331,6 +353,26 @@ class ModelActor(xo.StatelessActor):
|
|
|
331
353
|
return False
|
|
332
354
|
return condition
|
|
333
355
|
|
|
356
|
+
def allow_batching_for_text_to_image(self) -> bool:
|
|
357
|
+
from ..model.image.stable_diffusion.core import DiffusionModel
|
|
358
|
+
|
|
359
|
+
condition = XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE is not None and isinstance(
|
|
360
|
+
self._model, DiffusionModel
|
|
361
|
+
)
|
|
362
|
+
|
|
363
|
+
if condition:
|
|
364
|
+
model_name = self._model._model_spec.model_name # type: ignore
|
|
365
|
+
if model_name in XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS:
|
|
366
|
+
return True
|
|
367
|
+
else:
|
|
368
|
+
logger.warning(
|
|
369
|
+
f"Currently for image models with text_to_image ability, "
|
|
370
|
+
f"xinference only supports {', '.join(XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS)} for batching. "
|
|
371
|
+
f"Your model {model_name} is disqualified."
|
|
372
|
+
)
|
|
373
|
+
return False
|
|
374
|
+
return condition
|
|
375
|
+
|
|
334
376
|
async def load(self):
|
|
335
377
|
self._model.load()
|
|
336
378
|
if self.allow_batching():
|
|
@@ -338,6 +380,11 @@ class ModelActor(xo.StatelessActor):
|
|
|
338
380
|
logger.debug(
|
|
339
381
|
f"Batching enabled for model: {self.model_uid()}, max_num_seqs: {self._model.get_max_num_seqs()}"
|
|
340
382
|
)
|
|
383
|
+
if self.allow_batching_for_text_to_image():
|
|
384
|
+
await self._text_to_image_scheduler_ref.set_model(self._model)
|
|
385
|
+
logger.debug(
|
|
386
|
+
f"Batching enabled for model: {self.model_uid()}, max_num_images: {self._model.get_max_num_images_for_batching()}"
|
|
387
|
+
)
|
|
341
388
|
|
|
342
389
|
def model_uid(self):
|
|
343
390
|
return (
|
|
@@ -429,6 +476,43 @@ class ModelActor(xo.StatelessActor):
|
|
|
429
476
|
)
|
|
430
477
|
await asyncio.gather(*coros)
|
|
431
478
|
|
|
479
|
+
async def _handle_pending_requests(self):
|
|
480
|
+
logger.info("Start requests handler.")
|
|
481
|
+
while True:
|
|
482
|
+
gen, stream_out, stop = await self._pending_requests.get()
|
|
483
|
+
|
|
484
|
+
async def _async_wrapper(_gen):
|
|
485
|
+
try:
|
|
486
|
+
# anext is only available for Python >= 3.10
|
|
487
|
+
return await _gen.__anext__() # noqa: F821
|
|
488
|
+
except StopAsyncIteration:
|
|
489
|
+
return stop
|
|
490
|
+
|
|
491
|
+
def _wrapper(_gen):
|
|
492
|
+
# Avoid issue: https://github.com/python/cpython/issues/112182
|
|
493
|
+
try:
|
|
494
|
+
return next(_gen)
|
|
495
|
+
except StopIteration:
|
|
496
|
+
return stop
|
|
497
|
+
|
|
498
|
+
while True:
|
|
499
|
+
try:
|
|
500
|
+
if inspect.isgenerator(gen):
|
|
501
|
+
r = await asyncio.to_thread(_wrapper, gen)
|
|
502
|
+
elif inspect.isasyncgen(gen):
|
|
503
|
+
r = await _async_wrapper(gen)
|
|
504
|
+
else:
|
|
505
|
+
raise Exception(
|
|
506
|
+
f"The generator {gen} should be a generator or an async generator, "
|
|
507
|
+
f"but a {type(gen)} is got."
|
|
508
|
+
)
|
|
509
|
+
stream_out.put_nowait(r)
|
|
510
|
+
if r is not stop:
|
|
511
|
+
continue
|
|
512
|
+
except Exception:
|
|
513
|
+
logger.exception("stream encountered an error.")
|
|
514
|
+
break
|
|
515
|
+
|
|
432
516
|
async def _call_wrapper_json(self, fn: Callable, *args, **kwargs):
|
|
433
517
|
return await self._call_wrapper("json", fn, *args, **kwargs)
|
|
434
518
|
|
|
@@ -442,6 +526,13 @@ class ModelActor(xo.StatelessActor):
|
|
|
442
526
|
ret = await fn(*args, **kwargs)
|
|
443
527
|
else:
|
|
444
528
|
ret = await asyncio.to_thread(fn, *args, **kwargs)
|
|
529
|
+
|
|
530
|
+
if inspect.isgenerator(ret):
|
|
531
|
+
gen = self._to_generator(output_type, ret)
|
|
532
|
+
return gen
|
|
533
|
+
if inspect.isasyncgen(ret):
|
|
534
|
+
gen = self._to_async_gen(output_type, ret)
|
|
535
|
+
return gen
|
|
445
536
|
else:
|
|
446
537
|
async with self._lock:
|
|
447
538
|
if inspect.iscoroutinefunction(fn):
|
|
@@ -449,17 +540,40 @@ class ModelActor(xo.StatelessActor):
|
|
|
449
540
|
else:
|
|
450
541
|
ret = await asyncio.to_thread(fn, *args, **kwargs)
|
|
451
542
|
|
|
452
|
-
|
|
453
|
-
|
|
543
|
+
stream_out: Union[queue.Queue, asyncio.Queue]
|
|
544
|
+
|
|
545
|
+
if inspect.isgenerator(ret):
|
|
546
|
+
gen = self._to_generator(output_type, ret)
|
|
547
|
+
stream_out = queue.Queue()
|
|
548
|
+
stop = object()
|
|
549
|
+
self._pending_requests.put_nowait((gen, stream_out, stop))
|
|
550
|
+
|
|
551
|
+
def _stream_out_generator():
|
|
552
|
+
while True:
|
|
553
|
+
o = stream_out.get()
|
|
554
|
+
if o is stop:
|
|
555
|
+
break
|
|
556
|
+
else:
|
|
557
|
+
yield o
|
|
558
|
+
|
|
559
|
+
return _stream_out_generator()
|
|
560
|
+
|
|
561
|
+
if inspect.isasyncgen(ret):
|
|
562
|
+
gen = self._to_async_gen(output_type, ret)
|
|
563
|
+
stream_out = asyncio.Queue()
|
|
564
|
+
stop = object()
|
|
565
|
+
self._pending_requests.put_nowait((gen, stream_out, stop))
|
|
566
|
+
|
|
567
|
+
async def _stream_out_async_gen():
|
|
568
|
+
while True:
|
|
569
|
+
o = await stream_out.get()
|
|
570
|
+
if o is stop:
|
|
571
|
+
break
|
|
572
|
+
else:
|
|
573
|
+
yield o
|
|
574
|
+
|
|
575
|
+
return _stream_out_async_gen()
|
|
454
576
|
|
|
455
|
-
if inspect.isgenerator(ret):
|
|
456
|
-
gen = self._to_generator(output_type, ret)
|
|
457
|
-
self._current_generator = weakref.ref(gen)
|
|
458
|
-
return gen
|
|
459
|
-
if inspect.isasyncgen(ret):
|
|
460
|
-
gen = self._to_async_gen(output_type, ret)
|
|
461
|
-
self._current_generator = weakref.ref(gen)
|
|
462
|
-
return gen
|
|
463
577
|
if output_type == "json":
|
|
464
578
|
return await asyncio.to_thread(json_dumps, ret)
|
|
465
579
|
else:
|
|
@@ -547,7 +661,6 @@ class ModelActor(xo.StatelessActor):
|
|
|
547
661
|
prompt_or_messages, queue, call_ability, *args, **kwargs
|
|
548
662
|
)
|
|
549
663
|
gen = self._to_async_gen("json", ret)
|
|
550
|
-
self._current_generator = weakref.ref(gen)
|
|
551
664
|
return gen
|
|
552
665
|
else:
|
|
553
666
|
from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
|
|
@@ -617,12 +730,16 @@ class ModelActor(xo.StatelessActor):
|
|
|
617
730
|
)
|
|
618
731
|
|
|
619
732
|
async def abort_request(self, request_id: str) -> str:
|
|
620
|
-
from .
|
|
733
|
+
from .utils import AbortRequestMessage
|
|
621
734
|
|
|
622
735
|
if self.allow_batching():
|
|
623
736
|
if self._scheduler_ref is None:
|
|
624
737
|
return AbortRequestMessage.NOT_FOUND.name
|
|
625
738
|
return await self._scheduler_ref.abort_request(request_id)
|
|
739
|
+
elif self.allow_batching_for_text_to_image():
|
|
740
|
+
if self._text_to_image_scheduler_ref is None:
|
|
741
|
+
return AbortRequestMessage.NOT_FOUND.name
|
|
742
|
+
return await self._text_to_image_scheduler_ref.abort_request(request_id)
|
|
626
743
|
return AbortRequestMessage.NO_OP.name
|
|
627
744
|
|
|
628
745
|
@request_limit
|
|
@@ -747,6 +864,22 @@ class ModelActor(xo.StatelessActor):
|
|
|
747
864
|
f"Model {self._model.model_spec} is not for creating speech."
|
|
748
865
|
)
|
|
749
866
|
|
|
867
|
+
async def handle_image_batching_request(self, unique_id, *args, **kwargs):
|
|
868
|
+
size = args[2]
|
|
869
|
+
if XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE != size:
|
|
870
|
+
raise RuntimeError(
|
|
871
|
+
f"The image size: {size} of text_to_image for batching "
|
|
872
|
+
f"must be the same as the environment variable: {XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE} you set."
|
|
873
|
+
)
|
|
874
|
+
assert self._loop is not None
|
|
875
|
+
future = ConcurrentFuture()
|
|
876
|
+
await self._text_to_image_scheduler_ref.add_request(
|
|
877
|
+
unique_id, future, *args, **kwargs
|
|
878
|
+
)
|
|
879
|
+
fut = asyncio.wrap_future(future, loop=self._loop)
|
|
880
|
+
result = await fut
|
|
881
|
+
return await asyncio.to_thread(json_dumps, result)
|
|
882
|
+
|
|
750
883
|
@request_limit
|
|
751
884
|
@log_async(logger=logger)
|
|
752
885
|
async def text_to_image(
|
|
@@ -759,19 +892,25 @@ class ModelActor(xo.StatelessActor):
|
|
|
759
892
|
**kwargs,
|
|
760
893
|
):
|
|
761
894
|
if hasattr(self._model, "text_to_image"):
|
|
762
|
-
|
|
763
|
-
kwargs.pop("request_id", None)
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
size,
|
|
771
|
-
response_format,
|
|
772
|
-
*args,
|
|
773
|
-
**kwargs,
|
|
895
|
+
if self.allow_batching_for_text_to_image():
|
|
896
|
+
unique_id = kwargs.pop("request_id", None)
|
|
897
|
+
return await self.handle_image_batching_request(
|
|
898
|
+
unique_id, prompt, n, size, response_format, *args, **kwargs
|
|
899
|
+
)
|
|
900
|
+
else:
|
|
901
|
+
progressor = kwargs["progressor"] = await self._get_progressor(
|
|
902
|
+
kwargs.pop("request_id", None)
|
|
774
903
|
)
|
|
904
|
+
with progressor:
|
|
905
|
+
return await self._call_wrapper_json(
|
|
906
|
+
self._model.text_to_image,
|
|
907
|
+
prompt,
|
|
908
|
+
n,
|
|
909
|
+
size,
|
|
910
|
+
response_format,
|
|
911
|
+
*args,
|
|
912
|
+
**kwargs,
|
|
913
|
+
)
|
|
775
914
|
raise AttributeError(
|
|
776
915
|
f"Model {self._model.model_spec} is not for creating image."
|
|
777
916
|
)
|
|
@@ -882,6 +1021,25 @@ class ModelActor(xo.StatelessActor):
|
|
|
882
1021
|
f"Model {self._model.model_spec} is not for creating image."
|
|
883
1022
|
)
|
|
884
1023
|
|
|
1024
|
+
@log_async(
|
|
1025
|
+
logger=logger,
|
|
1026
|
+
ignore_kwargs=["image"],
|
|
1027
|
+
)
|
|
1028
|
+
async def ocr(
|
|
1029
|
+
self,
|
|
1030
|
+
image: "PIL.Image",
|
|
1031
|
+
*args,
|
|
1032
|
+
**kwargs,
|
|
1033
|
+
):
|
|
1034
|
+
if hasattr(self._model, "ocr"):
|
|
1035
|
+
return await self._call_wrapper_json(
|
|
1036
|
+
self._model.ocr,
|
|
1037
|
+
image,
|
|
1038
|
+
*args,
|
|
1039
|
+
**kwargs,
|
|
1040
|
+
)
|
|
1041
|
+
raise AttributeError(f"Model {self._model.model_spec} is not for ocr.")
|
|
1042
|
+
|
|
885
1043
|
@request_limit
|
|
886
1044
|
@log_async(logger=logger, ignore_kwargs=["image"])
|
|
887
1045
|
async def infer(
|
|
@@ -923,3 +1081,6 @@ class ModelActor(xo.StatelessActor):
|
|
|
923
1081
|
async def record_metrics(self, name, op, kwargs):
|
|
924
1082
|
worker_ref = await self._get_worker_ref()
|
|
925
1083
|
await worker_ref.record_metrics(name, op, kwargs)
|
|
1084
|
+
|
|
1085
|
+
async def get_pending_requests_count(self):
|
|
1086
|
+
return self._pending_requests.qsize()
|
xinference/core/scheduler.py
CHANGED
|
@@ -17,11 +17,12 @@ import functools
|
|
|
17
17
|
import logging
|
|
18
18
|
import uuid
|
|
19
19
|
from collections import deque
|
|
20
|
-
from enum import Enum
|
|
21
20
|
from typing import Dict, List, Optional, Set, Tuple, Union
|
|
22
21
|
|
|
23
22
|
import xoscar as xo
|
|
24
23
|
|
|
24
|
+
from .utils import AbortRequestMessage
|
|
25
|
+
|
|
25
26
|
logger = logging.getLogger(__name__)
|
|
26
27
|
|
|
27
28
|
XINFERENCE_STREAMING_DONE_FLAG = "<XINFERENCE_STREAMING_DONE>"
|
|
@@ -30,12 +31,6 @@ XINFERENCE_STREAMING_ABORT_FLAG = "<XINFERENCE_STREAMING_ABORT>"
|
|
|
30
31
|
XINFERENCE_NON_STREAMING_ABORT_FLAG = "<XINFERENCE_NON_STREAMING_ABORT>"
|
|
31
32
|
|
|
32
33
|
|
|
33
|
-
class AbortRequestMessage(Enum):
|
|
34
|
-
NOT_FOUND = 1
|
|
35
|
-
DONE = 2
|
|
36
|
-
NO_OP = 3
|
|
37
|
-
|
|
38
|
-
|
|
39
34
|
class InferenceRequest:
|
|
40
35
|
def __init__(
|
|
41
36
|
self,
|
|
@@ -81,6 +76,10 @@ class InferenceRequest:
|
|
|
81
76
|
self.padding_len = 0
|
|
82
77
|
# Use in stream mode
|
|
83
78
|
self.last_output_length = 0
|
|
79
|
+
# For tool call
|
|
80
|
+
self.tools = None
|
|
81
|
+
# Currently, for storing tool call streaming results.
|
|
82
|
+
self.outputs: List[str] = [] # type: ignore
|
|
84
83
|
# inference results,
|
|
85
84
|
# it is a list type because when stream=True,
|
|
86
85
|
# self.completion contains all the results in a decode round.
|
|
@@ -112,6 +111,10 @@ class InferenceRequest:
|
|
|
112
111
|
"""
|
|
113
112
|
return self._prompt
|
|
114
113
|
|
|
114
|
+
@prompt.setter
|
|
115
|
+
def prompt(self, value: str):
|
|
116
|
+
self._prompt = value
|
|
117
|
+
|
|
115
118
|
@property
|
|
116
119
|
def call_ability(self):
|
|
117
120
|
return self._call_ability
|
xinference/core/utils.py
CHANGED
|
@@ -16,6 +16,7 @@ import os
|
|
|
16
16
|
import random
|
|
17
17
|
import string
|
|
18
18
|
import uuid
|
|
19
|
+
from enum import Enum
|
|
19
20
|
from typing import Dict, Generator, List, Optional, Tuple, Union
|
|
20
21
|
|
|
21
22
|
import orjson
|
|
@@ -27,6 +28,12 @@ from ..constants import XINFERENCE_LOG_ARG_MAX_LENGTH
|
|
|
27
28
|
logger = logging.getLogger(__name__)
|
|
28
29
|
|
|
29
30
|
|
|
31
|
+
class AbortRequestMessage(Enum):
|
|
32
|
+
NOT_FOUND = 1
|
|
33
|
+
DONE = 2
|
|
34
|
+
NO_OP = 3
|
|
35
|
+
|
|
36
|
+
|
|
30
37
|
def truncate_log_arg(arg) -> str:
|
|
31
38
|
s = str(arg)
|
|
32
39
|
if len(s) > XINFERENCE_LOG_ARG_MAX_LENGTH:
|
|
@@ -51,6 +58,8 @@ def log_async(
|
|
|
51
58
|
request_id_str = kwargs.get("request_id", "")
|
|
52
59
|
if not request_id_str:
|
|
53
60
|
request_id_str = uuid.uuid1()
|
|
61
|
+
if func_name == "text_to_image":
|
|
62
|
+
kwargs["request_id"] = request_id_str
|
|
54
63
|
request_id_str = f"[request {request_id_str}]"
|
|
55
64
|
formatted_args = ",".join(map(truncate_log_arg, args))
|
|
56
65
|
formatted_kwargs = ",".join(
|