xinference 0.16.0__py3-none-any.whl → 0.16.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +48 -0
- xinference/client/restful/restful_client.py +19 -0
- xinference/constants.py +1 -0
- xinference/core/chat_interface.py +5 -1
- xinference/core/image_interface.py +5 -1
- xinference/core/model.py +106 -16
- xinference/core/scheduler.py +1 -1
- xinference/core/worker.py +3 -1
- xinference/deploy/supervisor.py +0 -4
- xinference/model/audio/chattts.py +25 -14
- xinference/model/audio/core.py +6 -2
- xinference/model/audio/model_spec.json +1 -1
- xinference/model/audio/model_spec_modelscope.json +1 -1
- xinference/model/core.py +3 -1
- xinference/model/embedding/core.py +6 -2
- xinference/model/embedding/model_spec.json +1 -1
- xinference/model/image/core.py +65 -6
- xinference/model/image/model_spec.json +24 -3
- xinference/model/image/model_spec_modelscope.json +25 -3
- xinference/model/image/ocr/__init__.py +13 -0
- xinference/model/image/ocr/got_ocr2.py +79 -0
- xinference/model/image/scheduler/flux.py +1 -1
- xinference/model/image/stable_diffusion/core.py +2 -3
- xinference/model/image/stable_diffusion/mlx.py +221 -0
- xinference/model/llm/__init__.py +33 -0
- xinference/model/llm/core.py +3 -1
- xinference/model/llm/llm_family.json +9 -0
- xinference/model/llm/llm_family.py +68 -2
- xinference/model/llm/llm_family_modelscope.json +11 -0
- xinference/model/llm/llm_family_openmind_hub.json +1359 -0
- xinference/model/rerank/core.py +9 -1
- xinference/model/utils.py +7 -0
- xinference/model/video/core.py +6 -2
- xinference/thirdparty/mlx/__init__.py +13 -0
- xinference/thirdparty/mlx/flux/__init__.py +15 -0
- xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
- xinference/thirdparty/mlx/flux/clip.py +154 -0
- xinference/thirdparty/mlx/flux/datasets.py +75 -0
- xinference/thirdparty/mlx/flux/flux.py +247 -0
- xinference/thirdparty/mlx/flux/layers.py +302 -0
- xinference/thirdparty/mlx/flux/lora.py +76 -0
- xinference/thirdparty/mlx/flux/model.py +134 -0
- xinference/thirdparty/mlx/flux/sampler.py +56 -0
- xinference/thirdparty/mlx/flux/t5.py +244 -0
- xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
- xinference/thirdparty/mlx/flux/trainer.py +98 -0
- xinference/thirdparty/mlx/flux/utils.py +179 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.f7da0140.js → main.2f269bb3.js} +3 -3
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +1 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/METADATA +16 -9
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/RECORD +60 -42
- xinference/web/ui/build/static/js/main.f7da0140.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
- /xinference/web/ui/build/static/js/{main.f7da0140.js.LICENSE.txt → main.2f269bb3.js.LICENSE.txt} +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/LICENSE +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/WHEEL +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
-    "date": "2024-…",
+    "date": "2024-11-01T17:56:47+0800",
     "dirty": false,
     "error": null,
-    "full-revisionid": "…",
-    "version": "0.16.0"
+    "full-revisionid": "67e97ab485b539dc7a208825bee0504acc37044e",
+    "version": "0.16.2"
 }
 '''  # END VERSION_JSON
 
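The bumped metadata above is what the package reports at runtime. A quick post-upgrade sanity check, assuming the usual versioneer layout where the package exposes `__version__` from this module:

    # Hypothetical sanity check after upgrading to 0.16.2.
    import xinference

    assert xinference.__version__ == "0.16.2"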
xinference/api/restful_api.py
CHANGED
@@ -567,6 +567,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/images/ocr",
+            self.create_ocr,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         # SD WebUI API
         self._router.add_api_route(
             "/sdapi/v1/options",
@@ -1754,6 +1764,44 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def create_ocr(
+        self,
+        model: str = Form(...),
+        image: UploadFile = File(media_type="application/octet-stream"),
+        kwargs: Optional[str] = Form(None),
+    ) -> Response:
+        model_uid = model
+        try:
+            model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            if kwargs is not None:
+                parsed_kwargs = json.loads(kwargs)
+            else:
+                parsed_kwargs = {}
+            im = Image.open(image.file)
+            text = await model_ref.ocr(
+                image=im,
+                **parsed_kwargs,
+            )
+            return Response(content=text, media_type="text/plain")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_flexible_infer(self, request: Request) -> Response:
         payload = await request.json()
 
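The new handler accepts a multipart form with a `model` field, an `image` file, and an optional JSON-encoded `kwargs` field, and returns the recognized text as `text/plain`. A minimal sketch of calling it directly with `requests`; the server URL, model UID, and file name are placeholders:

    import json
    import requests

    with open("receipt.png", "rb") as f:
        files = [
            ("model", (None, "got-ocr2_0")),       # form field: target model UID
            ("kwargs", (None, json.dumps({}))),    # optional JSON-encoded kwargs
            ("image", ("image", f, "application/octet-stream")),
        ]
        resp = requests.post("http://localhost:9997/v1/images/ocr", files=files)

    resp.raise_for_status()
    print(resp.text)  # recognized text, returned as text/plain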
xinference/client/restful/restful_client.py
CHANGED
@@ -369,6 +369,25 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def ocr(self, image: Union[str, bytes], **kwargs):
+        url = f"{self._base_url}/v1/images/ocr"
+        params = {
+            "model": self._model_uid,
+            "kwargs": json.dumps(kwargs),
+        }
+        files: List[Any] = []
+        for key, value in params.items():
+            files.append((key, (None, value)))
+        files.append(("image", ("image", image, "application/octet-stream")))
+        response = requests.post(url, files=files, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to ocr the images, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
+
 
 class RESTfulVideoModelHandle(RESTfulModelHandle):
     def text_to_video(
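The handle method mirrors that contract from the client side. A usage sketch, assuming the model was launched beforehand and that `get_model` returns a `RESTfulImageModelHandle` (host/port and model UID are placeholders). Note the asymmetry visible in the two hunks: the server responds with `text/plain` while the handle parses the response as JSON.

    from xinference.client import Client

    client = Client("http://localhost:9997")
    model = client.get_model("got-ocr2_0")  # placeholder model UID

    with open("receipt.png", "rb") as f:
        result = model.ocr(f.read())
    print(result)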
xinference/constants.py
CHANGED
@@ -39,6 +39,7 @@ def get_xinference_home() -> str:
         # if user has already set `XINFERENCE_HOME` env, change huggingface and modelscope default download path
         os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(home_path, "huggingface")
         os.environ["MODELSCOPE_CACHE"] = os.path.join(home_path, "modelscope")
+        os.environ["XDG_CACHE_HOME"] = os.path.join(home_path, "openmind_hub")
     # In multi-tenant mode,
     # gradio's temporary files are stored in their respective home directories,
     # to prevent insufficient permissions
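With this change, setting `XINFERENCE_HOME` also redirects openmind_hub downloads (which honor `XDG_CACHE_HOME`) alongside the Hugging Face and ModelScope caches. A sketch of observing the effect, assuming `XINFERENCE_HOME` is read when `xinference.constants` is imported, as the function above suggests; the path is a placeholder:

    import os

    os.environ["XINFERENCE_HOME"] = "/data/xinference"  # placeholder path

    import xinference.constants  # noqa: E402  (triggers get_xinference_home())

    print(os.environ["XDG_CACHE_HOME"])  # expected: /data/xinference/openmind_hub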
xinference/core/chat_interface.py
CHANGED
@@ -74,7 +74,11 @@ class GradioInterface:
         # Gradio initiates the queue during a startup event, but since the app has already been
         # started, that event will not run, so manually invoke the startup events.
         # See: https://github.com/gradio-app/gradio/issues/5228
-        interface.startup_events()
+        try:
+            interface.run_startup_events()
+        except AttributeError:
+            # compatibility
+            interface.startup_events()
         favicon_path = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
             os.path.pardir,
xinference/core/image_interface.py
CHANGED
@@ -63,7 +63,11 @@ class ImageInterface:
         # Gradio initiates the queue during a startup event, but since the app has already been
         # started, that event will not run, so manually invoke the startup events.
         # See: https://github.com/gradio-app/gradio/issues/5228
-        interface.startup_events()
+        try:
+            interface.run_startup_events()
+        except AttributeError:
+            # compatibility
+            interface.startup_events()
         favicon_path = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
             os.path.pardir,
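Both interfaces carry the same guard because newer gradio releases renamed `startup_events` to `run_startup_events`; the try/except keeps one code path working across versions. The equivalent feature-detection idiom, as a standalone sketch (`interface` stands for any mounted gradio Blocks app):

    def run_gradio_startup(interface) -> None:
        # Prefer the newer gradio API; fall back for older releases.
        startup = getattr(interface, "run_startup_events", None)
        if startup is None:
            startup = interface.startup_events  # older gradio
        startup()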
xinference/core/model.py
CHANGED
@@ -17,10 +17,10 @@ import functools
 import inspect
 import json
 import os
+import queue
 import time
 import types
 import uuid
-import weakref
 from asyncio.queues import Queue
 from asyncio.tasks import wait_for
 from concurrent.futures import Future as ConcurrentFuture
@@ -32,7 +32,6 @@ from typing import (
     Callable,
     Dict,
     Generator,
-    Iterator,
     List,
     Optional,
     Union,
@@ -209,9 +208,8 @@ class ModelActor(xo.StatelessActor):
             model_description.to_dict() if model_description else {}
         )
         self._request_limits = request_limits
-
-        self.…
-        self._current_generator = lambda: None
+        self._pending_requests: asyncio.Queue = asyncio.Queue()
+        self._handle_pending_requests_task = None
         self._lock = (
             None
             if isinstance(
@@ -237,6 +235,10 @@
     async def __post_create__(self):
         self._loop = asyncio.get_running_loop()
 
+        self._handle_pending_requests_task = asyncio.create_task(
+            self._handle_pending_requests()
+        )
+
         if self.allow_batching():
             from .scheduler import SchedulerActor
 
@@ -474,6 +476,43 @@
         )
         await asyncio.gather(*coros)
 
+    async def _handle_pending_requests(self):
+        logger.info("Start requests handler.")
+        while True:
+            gen, stream_out, stop = await self._pending_requests.get()
+
+            async def _async_wrapper(_gen):
+                try:
+                    # anext is only available for Python >= 3.10
+                    return await _gen.__anext__()  # noqa: F821
+                except StopAsyncIteration:
+                    return stop
+
+            def _wrapper(_gen):
+                # Avoid issue: https://github.com/python/cpython/issues/112182
+                try:
+                    return next(_gen)
+                except StopIteration:
+                    return stop
+
+            while True:
+                try:
+                    if inspect.isgenerator(gen):
+                        r = await asyncio.to_thread(_wrapper, gen)
+                    elif inspect.isasyncgen(gen):
+                        r = await _async_wrapper(gen)
+                    else:
+                        raise Exception(
+                            f"The generator {gen} should be a generator or an async generator, "
+                            f"but a {type(gen)} is got."
+                        )
+                    stream_out.put_nowait(r)
+                    if r is not stop:
+                        continue
+                except Exception:
+                    logger.exception("stream encountered an error.")
+                break
+
     async def _call_wrapper_json(self, fn: Callable, *args, **kwargs):
         return await self._call_wrapper("json", fn, *args, **kwargs)
 
@@ -487,6 +526,13 @@
                 ret = await fn(*args, **kwargs)
             else:
                 ret = await asyncio.to_thread(fn, *args, **kwargs)
+
+            if inspect.isgenerator(ret):
+                gen = self._to_generator(output_type, ret)
+                return gen
+            if inspect.isasyncgen(ret):
+                gen = self._to_async_gen(output_type, ret)
+                return gen
         else:
             async with self._lock:
                 if inspect.iscoroutinefunction(fn):
@@ -494,17 +540,40 @@
                 else:
                     ret = await asyncio.to_thread(fn, *args, **kwargs)
 
-
-
+                stream_out: Union[queue.Queue, asyncio.Queue]
+
+                if inspect.isgenerator(ret):
+                    gen = self._to_generator(output_type, ret)
+                    stream_out = queue.Queue()
+                    stop = object()
+                    self._pending_requests.put_nowait((gen, stream_out, stop))
+
+                    def _stream_out_generator():
+                        while True:
+                            o = stream_out.get()
+                            if o is stop:
+                                break
+                            else:
+                                yield o
+
+                    return _stream_out_generator()
+
+                if inspect.isasyncgen(ret):
+                    gen = self._to_async_gen(output_type, ret)
+                    stream_out = asyncio.Queue()
+                    stop = object()
+                    self._pending_requests.put_nowait((gen, stream_out, stop))
+
+                    async def _stream_out_async_gen():
+                        while True:
+                            o = await stream_out.get()
+                            if o is stop:
+                                break
+                            else:
+                                yield o
+
+                    return _stream_out_async_gen()
 
-        if inspect.isgenerator(ret):
-            gen = self._to_generator(output_type, ret)
-            self._current_generator = weakref.ref(gen)
-            return gen
-        if inspect.isasyncgen(ret):
-            gen = self._to_async_gen(output_type, ret)
-            self._current_generator = weakref.ref(gen)
-            return gen
         if output_type == "json":
            return await asyncio.to_thread(json_dumps, ret)
         else:
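Taken together, these hunks replace the old `_current_generator` weakref with an explicit handoff: `_call_wrapper` enqueues `(generator, per-request queue, sentinel)` and returns a thin wrapper that drains the queue, while the single `_handle_pending_requests` task advances generators one at a time. A self-contained sketch of the pattern (async-generator case only; toy names, no xoscar involved):

    import asyncio

    async def handle_pending(pending: asyncio.Queue) -> None:
        # One long-lived task advances every submitted generator in order.
        while True:
            gen, out, stop = await pending.get()
            while True:
                try:
                    item = await gen.__anext__()
                except StopAsyncIteration:
                    item = stop
                out.put_nowait(item)
                if item is stop:
                    break

    def submit(pending: asyncio.Queue, gen):
        # Mirrors _call_wrapper: enqueue the real generator, hand back a proxy.
        out: asyncio.Queue = asyncio.Queue()
        stop = object()
        pending.put_nowait((gen, out, stop))

        async def proxy():
            while True:
                o = await out.get()
                if o is stop:
                    break
                yield o

        return proxy()

    async def main() -> None:
        async def numbers():
            for i in range(3):
                yield i

        pending: asyncio.Queue = asyncio.Queue()
        worker = asyncio.create_task(handle_pending(pending))
        print([x async for x in submit(pending, numbers())])  # -> [0, 1, 2]
        worker.cancel()

    asyncio.run(main())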
@@ -592,7 +661,6 @@
                 prompt_or_messages, queue, call_ability, *args, **kwargs
             )
             gen = self._to_async_gen("json", ret)
-            self._current_generator = weakref.ref(gen)
             return gen
         else:
             from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
@@ -953,6 +1021,25 @@
                 f"Model {self._model.model_spec} is not for creating image."
             )
 
+    @log_async(
+        logger=logger,
+        ignore_kwargs=["image"],
+    )
+    async def ocr(
+        self,
+        image: "PIL.Image",
+        *args,
+        **kwargs,
+    ):
+        if hasattr(self._model, "ocr"):
+            return await self._call_wrapper_json(
+                self._model.ocr,
+                image,
+                *args,
+                **kwargs,
+            )
+        raise AttributeError(f"Model {self._model.model_spec} is not for ocr.")
+
     @request_limit
     @log_async(logger=logger, ignore_kwargs=["image"])
     async def infer(
@@ -994,3 +1081,6 @@
     async def record_metrics(self, name, op, kwargs):
         worker_ref = await self._get_worker_ref()
         await worker_ref.record_metrics(name, op, kwargs)
+
+    async def get_pending_requests_count(self):
+        return self._pending_requests.qsize()
xinference/core/scheduler.py
CHANGED
@@ -79,7 +79,7 @@ class InferenceRequest:
         # For tool call
         self.tools = None
         # Currently, for storing tool call streaming results.
-        self.outputs: List[str] = []
+        self.outputs: List[str] = []  # type: ignore
         # inference results,
         # it is a list type because when stream=True,
         # self.completion contains all the results in a decode round.
xinference/core/worker.py
CHANGED
@@ -785,7 +785,9 @@ class WorkerActor(xo.StatelessActor):
         peft_model_config: Optional[PeftModelConfig] = None,
         request_limits: Optional[int] = None,
         gpu_idx: Optional[Union[int, List[int]]] = None,
-        download_hub: Optional[…
+        download_hub: Optional[
+            Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+        ] = None,
         model_path: Optional[str] = None,
         **kwargs,
     ):
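`"openmind_hub"` joins the accepted download hubs on the worker's launch path (the same `Literal` widening recurs across the model subpackages below). A hedged usage sketch, assuming the RESTful client forwards `download_hub` through to `launch_builtin_model`; the model name is a placeholder:

    from xinference.client import Client

    client = Client("http://localhost:9997")
    model_uid = client.launch_model(
        model_name="qwen2-instruct",   # placeholder model name
        model_type="LLM",
        download_hub="openmind_hub",   # newly accepted value
    )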
xinference/deploy/supervisor.py
CHANGED
@@ -31,10 +31,6 @@ from .utils import health_check
 
 logger = logging.getLogger(__name__)
 
-from ..model import _install as install_model
-
-install_model()
-
 
 async def _start_supervisor(address: str, logging_conf: Optional[Dict] = None):
     logging.config.dictConfig(logging_conf)  # type: ignore
xinference/model/audio/chattts.py
CHANGED
@@ -54,7 +54,11 @@ class ChatTTSModel:
         torch.set_float32_matmul_precision("high")
         self._model = ChatTTS.Chat()
         logger.info("Load ChatTTS model with kwargs: %s", self._kwargs)
-        self._model.load(…
+        ok = self._model.load(
+            source="custom", custom_path=self._model_path, **self._kwargs
+        )
+        if not ok:
+            raise Exception(f"The ChatTTS model is not correct: {self._model_path}")
 
     def speech(
         self,
@@ -114,16 +118,15 @@
                 last_pos = 0
                 with writer.open():
                     for it in iter:
-                        for …
-                        …
-                        last_pos = new_last_pos
+                        for chunk in it:
+                            chunk = np.array([chunk]).transpose()
+                            writer.write_audio_chunk(i, torch.from_numpy(chunk))
+                            new_last_pos = out.tell()
+                            if new_last_pos != last_pos:
+                                out.seek(last_pos)
+                                encoded_bytes = out.read()
+                                yield encoded_bytes
+                                last_pos = new_last_pos
 
             return _generator()
         else:
@@ -131,7 +134,15 @@
 
             # Save the generated audio
             with BytesIO() as out:
-                …
+                try:
+                    torchaudio.save(
+                        out,
+                        torch.from_numpy(wavs[0]).unsqueeze(0),
+                        24000,
+                        format=response_format,
+                    )
+                except:
+                    torchaudio.save(
+                        out, torch.from_numpy(wavs[0]), 24000, format=response_format
+                    )
                 return out.getvalue()
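A sketch of exercising the patched ChatTTS paths end to end; the host/port and model UID are placeholders, and the `speech()` signature is assumed to mirror the OpenAI-style audio API:

    from xinference.client import Client

    client = Client("http://localhost:9997")
    model = client.get_model("ChatTTS")  # placeholder model UID
    audio = model.speech("Hello from xinference!", response_format="mp3")
    with open("hello.mp3", "wb") as f:
        f.write(audio)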
xinference/model/audio/core.py
CHANGED
@@ -100,7 +100,9 @@ def generate_audio_description(
 
 def match_audio(
     model_name: str,
-    download_hub: Optional[…
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
 ) -> AudioModelFamilyV1:
     from ..utils import download_from_modelscope
     from . import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS
@@ -152,7 +154,9 @@ def create_audio_model_instance(
     devices: List[str],
     model_uid: str,
     model_name: str,
-    download_hub: Optional[…
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[
xinference/model/audio/model_spec.json
CHANGED
@@ -127,7 +127,7 @@
     "model_name": "ChatTTS",
     "model_family": "ChatTTS",
     "model_id": "2Noise/ChatTTS",
-    "model_revision": "…",
+    "model_revision": "3b34118f6d25850440b8901cef3e71c6ef8619c8",
     "model_ability": "text-to-audio",
     "multilingual": true
 },
xinference/model/core.py
CHANGED
@@ -55,7 +55,9 @@ def create_model_instance(
     model_size_in_billions: Optional[Union[int, str]] = None,
     quantization: Optional[str] = None,
     peft_model_config: Optional[PeftModelConfig] = None,
-    download_hub: Optional[…
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[Any, ModelDescription]:
xinference/model/embedding/core.py
CHANGED
@@ -433,7 +433,9 @@ class EmbeddingModel:
 
 def match_embedding(
     model_name: str,
-    download_hub: Optional[…
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
 ) -> EmbeddingModelSpec:
     from ..utils import download_from_modelscope
     from . import BUILTIN_EMBEDDING_MODELS, MODELSCOPE_EMBEDDING_MODELS
@@ -469,7 +471,9 @@ def create_embedding_model_instance(
     devices: List[str],
     model_uid: str,
     model_name: str,
-    download_hub: Optional[…
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[EmbeddingModel, EmbeddingModelDescription]:
xinference/model/image/core.py
CHANGED
@@ -11,17 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import collections.abc
 import logging
 import os
+import platform
 from collections import defaultdict
-from typing import Dict, List, Literal, Optional, Tuple
+from typing import Dict, List, Literal, Optional, Tuple, Union
 
 from ...constants import XINFERENCE_CACHE_DIR
 from ...types import PeftModelConfig
 from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
+from .ocr.got_ocr2 import GotOCR2Model
 from .stable_diffusion.core import DiffusionModel
+from .stable_diffusion.mlx import MLXDiffusionModel
 
 logger = logging.getLogger(__name__)
 
@@ -45,6 +49,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
     model_hub: str = "huggingface"
     model_ability: Optional[List[str]]
     controlnet: Optional[List["ImageModelFamilyV1"]]
+    default_model_config: Optional[dict] = {}
     default_generate_config: Optional[dict] = {}
 
 
@@ -120,7 +125,9 @@ def generate_image_description(
 
 def match_diffusion(
     model_name: str,
-    download_hub: Optional[…
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
 ) -> ImageModelFamilyV1:
     from ..utils import download_from_modelscope
     from . import BUILTIN_IMAGE_MODELS, MODELSCOPE_IMAGE_MODELS
@@ -180,17 +187,59 @@
     return valid_model_revision(meta_path, model_spec.model_revision)
 
 
+def create_ocr_model_instance(
+    subpool_addr: str,
+    devices: List[str],
+    model_uid: str,
+    model_spec: ImageModelFamilyV1,
+    model_path: Optional[str] = None,
+    **kwargs,
+) -> Tuple[GotOCR2Model, ImageModelDescription]:
+    if not model_path:
+        model_path = cache(model_spec)
+    model = GotOCR2Model(
+        model_uid,
+        model_path,
+        model_spec=model_spec,
+        **kwargs,
+    )
+    model_description = ImageModelDescription(
+        subpool_addr, devices, model_spec, model_path=model_path
+    )
+    return model, model_description
+
+
 def create_image_model_instance(
     subpool_addr: str,
     devices: List[str],
     model_uid: str,
     model_name: str,
     peft_model_config: Optional[PeftModelConfig] = None,
-    download_hub: Optional[…
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
     model_path: Optional[str] = None,
     **kwargs,
-) -> Tuple[…
+) -> Tuple[
+    Union[DiffusionModel, MLXDiffusionModel, GotOCR2Model], ImageModelDescription
+]:
     model_spec = match_diffusion(model_name, download_hub)
+    if model_spec.model_ability and "ocr" in model_spec.model_ability:
+        return create_ocr_model_instance(
+            subpool_addr=subpool_addr,
+            devices=devices,
+            model_uid=model_uid,
+            model_name=model_name,
+            model_spec=model_spec,
+            model_path=model_path,
+            **kwargs,
+        )
+
+    # use default model config
+    model_default_config = (model_spec.default_model_config or {}).copy()
+    model_default_config.update(kwargs)
+    kwargs = model_default_config
+
     controlnet = kwargs.get("controlnet")
     # Handle controlnet
     if controlnet is not None:
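The merge order above means spec-level defaults seed the kwargs and anything the caller passes wins. The same semantics in isolation (the keys here are illustrative, not actual spec fields):

    default_model_config = {"quantize": True, "cfg_scale": 7.5}  # illustrative
    caller_kwargs = {"cfg_scale": 4.0}

    merged = default_model_config.copy()
    merged.update(caller_kwargs)
    assert merged == {"quantize": True, "cfg_scale": 4.0}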
@@ -232,10 +281,20 @@
     lora_load_kwargs = None
     lora_fuse_kwargs = None
 
-    …
+    if (
+        platform.system() == "Darwin"
+        and "arm" in platform.machine().lower()
+        and model_name in MLXDiffusionModel.supported_models
+    ):
+        # Mac with M series silicon chips
+        model_cls = MLXDiffusionModel
+    else:
+        model_cls = DiffusionModel  # type: ignore
+
+    model = model_cls(
         model_uid,
         model_path,
-        …
+        lora_model=lora_model,
         lora_load_kwargs=lora_load_kwargs,
         lora_fuse_kwargs=lora_fuse_kwargs,
         model_spec=model_spec,