xinference 0.16.0__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +48 -0
- xinference/client/restful/restful_client.py +19 -0
- xinference/core/chat_interface.py +5 -1
- xinference/core/image_interface.py +5 -1
- xinference/core/model.py +106 -16
- xinference/core/scheduler.py +1 -1
- xinference/deploy/supervisor.py +0 -4
- xinference/model/audio/chattts.py +25 -14
- xinference/model/audio/model_spec.json +1 -1
- xinference/model/audio/model_spec_modelscope.json +1 -1
- xinference/model/embedding/model_spec.json +1 -1
- xinference/model/image/core.py +59 -4
- xinference/model/image/model_spec.json +24 -3
- xinference/model/image/model_spec_modelscope.json +25 -3
- xinference/model/image/ocr/__init__.py +13 -0
- xinference/model/image/ocr/got_ocr2.py +76 -0
- xinference/model/image/scheduler/flux.py +1 -1
- xinference/model/image/stable_diffusion/core.py +2 -3
- xinference/model/image/stable_diffusion/mlx.py +221 -0
- xinference/model/llm/llm_family.json +9 -0
- xinference/model/llm/llm_family_modelscope.json +11 -0
- xinference/thirdparty/mlx/__init__.py +13 -0
- xinference/thirdparty/mlx/flux/__init__.py +15 -0
- xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
- xinference/thirdparty/mlx/flux/clip.py +154 -0
- xinference/thirdparty/mlx/flux/datasets.py +75 -0
- xinference/thirdparty/mlx/flux/flux.py +247 -0
- xinference/thirdparty/mlx/flux/layers.py +302 -0
- xinference/thirdparty/mlx/flux/lora.py +76 -0
- xinference/thirdparty/mlx/flux/model.py +134 -0
- xinference/thirdparty/mlx/flux/sampler.py +56 -0
- xinference/thirdparty/mlx/flux/t5.py +244 -0
- xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
- xinference/thirdparty/mlx/flux/trainer.py +98 -0
- xinference/thirdparty/mlx/flux/utils.py +179 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.f7da0140.js → main.b76aeeb7.js} +3 -3
- xinference/web/ui/build/static/js/main.b76aeeb7.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/32ea2c04cf0bba2761b4883d2c40cc259952c94d2d6bb774e510963ca37aac0a.json +1 -0
- {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/METADATA +15 -8
- {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/RECORD +48 -31
- xinference/web/ui/build/static/js/main.f7da0140.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
- /xinference/web/ui/build/static/js/{main.f7da0140.js.LICENSE.txt → main.b76aeeb7.js.LICENSE.txt} +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/LICENSE +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/WHEEL +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2024-10-
|
|
11
|
+
"date": "2024-10-25T12:51:06+0800",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "0.16.
|
|
14
|
+
"full-revisionid": "d4cd7b15104c16838e3c562cf2d33337e3d38897",
|
|
15
|
+
"version": "0.16.1"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
xinference/api/restful_api.py
CHANGED
|
@@ -567,6 +567,16 @@ class RESTfulAPI:
|
|
|
567
567
|
else None
|
|
568
568
|
),
|
|
569
569
|
)
|
|
570
|
+
self._router.add_api_route(
|
|
571
|
+
"/v1/images/ocr",
|
|
572
|
+
self.create_ocr,
|
|
573
|
+
methods=["POST"],
|
|
574
|
+
dependencies=(
|
|
575
|
+
[Security(self._auth_service, scopes=["models:read"])]
|
|
576
|
+
if self.is_authenticated()
|
|
577
|
+
else None
|
|
578
|
+
),
|
|
579
|
+
)
|
|
570
580
|
# SD WebUI API
|
|
571
581
|
self._router.add_api_route(
|
|
572
582
|
"/sdapi/v1/options",
|
|
@@ -1754,6 +1764,44 @@ class RESTfulAPI:
|
|
|
1754
1764
|
await self._report_error_event(model_uid, str(e))
|
|
1755
1765
|
raise HTTPException(status_code=500, detail=str(e))
|
|
1756
1766
|
|
|
1767
|
+
async def create_ocr(
|
|
1768
|
+
self,
|
|
1769
|
+
model: str = Form(...),
|
|
1770
|
+
image: UploadFile = File(media_type="application/octet-stream"),
|
|
1771
|
+
kwargs: Optional[str] = Form(None),
|
|
1772
|
+
) -> Response:
|
|
1773
|
+
model_uid = model
|
|
1774
|
+
try:
|
|
1775
|
+
model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
|
|
1776
|
+
except ValueError as ve:
|
|
1777
|
+
logger.error(str(ve), exc_info=True)
|
|
1778
|
+
await self._report_error_event(model_uid, str(ve))
|
|
1779
|
+
raise HTTPException(status_code=400, detail=str(ve))
|
|
1780
|
+
except Exception as e:
|
|
1781
|
+
logger.error(e, exc_info=True)
|
|
1782
|
+
await self._report_error_event(model_uid, str(e))
|
|
1783
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
1784
|
+
|
|
1785
|
+
try:
|
|
1786
|
+
if kwargs is not None:
|
|
1787
|
+
parsed_kwargs = json.loads(kwargs)
|
|
1788
|
+
else:
|
|
1789
|
+
parsed_kwargs = {}
|
|
1790
|
+
im = Image.open(image.file)
|
|
1791
|
+
text = await model_ref.ocr(
|
|
1792
|
+
image=im,
|
|
1793
|
+
**parsed_kwargs,
|
|
1794
|
+
)
|
|
1795
|
+
return Response(content=text, media_type="text/plain")
|
|
1796
|
+
except RuntimeError as re:
|
|
1797
|
+
logger.error(re, exc_info=True)
|
|
1798
|
+
await self._report_error_event(model_uid, str(re))
|
|
1799
|
+
raise HTTPException(status_code=400, detail=str(re))
|
|
1800
|
+
except Exception as e:
|
|
1801
|
+
logger.error(e, exc_info=True)
|
|
1802
|
+
await self._report_error_event(model_uid, str(e))
|
|
1803
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
1804
|
+
|
|
1757
1805
|
async def create_flexible_infer(self, request: Request) -> Response:
|
|
1758
1806
|
payload = await request.json()
|
|
1759
1807
|
|
|
xinference/client/restful/restful_client.py
CHANGED
|
@@ -369,6 +369,25 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
|
|
|
369
369
|
response_data = response.json()
|
|
370
370
|
return response_data
|
|
371
371
|
|
|
372
|
+
def ocr(self, image: Union[str, bytes], **kwargs):
|
|
373
|
+
url = f"{self._base_url}/v1/images/ocr"
|
|
374
|
+
params = {
|
|
375
|
+
"model": self._model_uid,
|
|
376
|
+
"kwargs": json.dumps(kwargs),
|
|
377
|
+
}
|
|
378
|
+
files: List[Any] = []
|
|
379
|
+
for key, value in params.items():
|
|
380
|
+
files.append((key, (None, value)))
|
|
381
|
+
files.append(("image", ("image", image, "application/octet-stream")))
|
|
382
|
+
response = requests.post(url, files=files, headers=self.auth_headers)
|
|
383
|
+
if response.status_code != 200:
|
|
384
|
+
raise RuntimeError(
|
|
385
|
+
f"Failed to ocr the images, detail: {_get_error_string(response)}"
|
|
386
|
+
)
|
|
387
|
+
|
|
388
|
+
response_data = response.json()
|
|
389
|
+
return response_data
|
|
390
|
+
|
|
372
391
|
|
|
373
392
|
class RESTfulVideoModelHandle(RESTfulModelHandle):
|
|
374
393
|
def text_to_video(
|
|
xinference/core/chat_interface.py
CHANGED
|
@@ -74,7 +74,11 @@ class GradioInterface:
|
|
|
74
74
|
# Gradio initiates the queue during a startup event, but since the app has already been
|
|
75
75
|
# started, that event will not run, so manually invoke the startup events.
|
|
76
76
|
# See: https://github.com/gradio-app/gradio/issues/5228
|
|
77
|
-
|
|
77
|
+
try:
|
|
78
|
+
interface.run_startup_events()
|
|
79
|
+
except AttributeError:
|
|
80
|
+
# compatibility
|
|
81
|
+
interface.startup_events()
|
|
78
82
|
favicon_path = os.path.join(
|
|
79
83
|
os.path.dirname(os.path.abspath(__file__)),
|
|
80
84
|
os.path.pardir,
|
|
xinference/core/image_interface.py
CHANGED
|
@@ -63,7 +63,11 @@ class ImageInterface:
|
|
|
63
63
|
# Gradio initiates the queue during a startup event, but since the app has already been
|
|
64
64
|
# started, that event will not run, so manually invoke the startup events.
|
|
65
65
|
# See: https://github.com/gradio-app/gradio/issues/5228
|
|
66
|
-
|
|
66
|
+
try:
|
|
67
|
+
interface.run_startup_events()
|
|
68
|
+
except AttributeError:
|
|
69
|
+
# compatibility
|
|
70
|
+
interface.startup_events()
|
|
67
71
|
favicon_path = os.path.join(
|
|
68
72
|
os.path.dirname(os.path.abspath(__file__)),
|
|
69
73
|
os.path.pardir,
|
xinference/core/model.py
CHANGED
|
@@ -17,10 +17,10 @@ import functools
|
|
|
17
17
|
import inspect
|
|
18
18
|
import json
|
|
19
19
|
import os
|
|
20
|
+
import queue
|
|
20
21
|
import time
|
|
21
22
|
import types
|
|
22
23
|
import uuid
|
|
23
|
-
import weakref
|
|
24
24
|
from asyncio.queues import Queue
|
|
25
25
|
from asyncio.tasks import wait_for
|
|
26
26
|
from concurrent.futures import Future as ConcurrentFuture
|
|
@@ -32,7 +32,6 @@ from typing import (
|
|
|
32
32
|
Callable,
|
|
33
33
|
Dict,
|
|
34
34
|
Generator,
|
|
35
|
-
Iterator,
|
|
36
35
|
List,
|
|
37
36
|
Optional,
|
|
38
37
|
Union,
|
|
@@ -209,9 +208,8 @@ class ModelActor(xo.StatelessActor):
|
|
|
209
208
|
model_description.to_dict() if model_description else {}
|
|
210
209
|
)
|
|
211
210
|
self._request_limits = request_limits
|
|
212
|
-
|
|
213
|
-
self.
|
|
214
|
-
self._current_generator = lambda: None
|
|
211
|
+
self._pending_requests: asyncio.Queue = asyncio.Queue()
|
|
212
|
+
self._handle_pending_requests_task = None
|
|
215
213
|
self._lock = (
|
|
216
214
|
None
|
|
217
215
|
if isinstance(
|
|
@@ -237,6 +235,10 @@ class ModelActor(xo.StatelessActor):
|
|
|
237
235
|
async def __post_create__(self):
|
|
238
236
|
self._loop = asyncio.get_running_loop()
|
|
239
237
|
|
|
238
|
+
self._handle_pending_requests_task = asyncio.create_task(
|
|
239
|
+
self._handle_pending_requests()
|
|
240
|
+
)
|
|
241
|
+
|
|
240
242
|
if self.allow_batching():
|
|
241
243
|
from .scheduler import SchedulerActor
|
|
242
244
|
|
|
@@ -474,6 +476,43 @@ class ModelActor(xo.StatelessActor):
|
|
|
474
476
|
)
|
|
475
477
|
await asyncio.gather(*coros)
|
|
476
478
|
|
|
479
|
+
async def _handle_pending_requests(self):
|
|
480
|
+
logger.info("Start requests handler.")
|
|
481
|
+
while True:
|
|
482
|
+
gen, stream_out, stop = await self._pending_requests.get()
|
|
483
|
+
|
|
484
|
+
async def _async_wrapper(_gen):
|
|
485
|
+
try:
|
|
486
|
+
# anext is only available for Python >= 3.10
|
|
487
|
+
return await _gen.__anext__() # noqa: F821
|
|
488
|
+
except StopAsyncIteration:
|
|
489
|
+
return stop
|
|
490
|
+
|
|
491
|
+
def _wrapper(_gen):
|
|
492
|
+
# Avoid issue: https://github.com/python/cpython/issues/112182
|
|
493
|
+
try:
|
|
494
|
+
return next(_gen)
|
|
495
|
+
except StopIteration:
|
|
496
|
+
return stop
|
|
497
|
+
|
|
498
|
+
while True:
|
|
499
|
+
try:
|
|
500
|
+
if inspect.isgenerator(gen):
|
|
501
|
+
r = await asyncio.to_thread(_wrapper, gen)
|
|
502
|
+
elif inspect.isasyncgen(gen):
|
|
503
|
+
r = await _async_wrapper(gen)
|
|
504
|
+
else:
|
|
505
|
+
raise Exception(
|
|
506
|
+
f"The generator {gen} should be a generator or an async generator, "
|
|
507
|
+
f"but a {type(gen)} is got."
|
|
508
|
+
)
|
|
509
|
+
stream_out.put_nowait(r)
|
|
510
|
+
if r is not stop:
|
|
511
|
+
continue
|
|
512
|
+
except Exception:
|
|
513
|
+
logger.exception("stream encountered an error.")
|
|
514
|
+
break
|
|
515
|
+
|
|
477
516
|
async def _call_wrapper_json(self, fn: Callable, *args, **kwargs):
|
|
478
517
|
return await self._call_wrapper("json", fn, *args, **kwargs)
|
|
479
518
|
|
|
@@ -487,6 +526,13 @@ class ModelActor(xo.StatelessActor):
|
|
|
487
526
|
ret = await fn(*args, **kwargs)
|
|
488
527
|
else:
|
|
489
528
|
ret = await asyncio.to_thread(fn, *args, **kwargs)
|
|
529
|
+
|
|
530
|
+
if inspect.isgenerator(ret):
|
|
531
|
+
gen = self._to_generator(output_type, ret)
|
|
532
|
+
return gen
|
|
533
|
+
if inspect.isasyncgen(ret):
|
|
534
|
+
gen = self._to_async_gen(output_type, ret)
|
|
535
|
+
return gen
|
|
490
536
|
else:
|
|
491
537
|
async with self._lock:
|
|
492
538
|
if inspect.iscoroutinefunction(fn):
|
|
@@ -494,17 +540,40 @@ class ModelActor(xo.StatelessActor):
|
|
|
494
540
|
else:
|
|
495
541
|
ret = await asyncio.to_thread(fn, *args, **kwargs)
|
|
496
542
|
|
|
497
|
-
|
|
498
|
-
|
|
543
|
+
stream_out: Union[queue.Queue, asyncio.Queue]
|
|
544
|
+
|
|
545
|
+
if inspect.isgenerator(ret):
|
|
546
|
+
gen = self._to_generator(output_type, ret)
|
|
547
|
+
stream_out = queue.Queue()
|
|
548
|
+
stop = object()
|
|
549
|
+
self._pending_requests.put_nowait((gen, stream_out, stop))
|
|
550
|
+
|
|
551
|
+
def _stream_out_generator():
|
|
552
|
+
while True:
|
|
553
|
+
o = stream_out.get()
|
|
554
|
+
if o is stop:
|
|
555
|
+
break
|
|
556
|
+
else:
|
|
557
|
+
yield o
|
|
558
|
+
|
|
559
|
+
return _stream_out_generator()
|
|
560
|
+
|
|
561
|
+
if inspect.isasyncgen(ret):
|
|
562
|
+
gen = self._to_async_gen(output_type, ret)
|
|
563
|
+
stream_out = asyncio.Queue()
|
|
564
|
+
stop = object()
|
|
565
|
+
self._pending_requests.put_nowait((gen, stream_out, stop))
|
|
566
|
+
|
|
567
|
+
async def _stream_out_async_gen():
|
|
568
|
+
while True:
|
|
569
|
+
o = await stream_out.get()
|
|
570
|
+
if o is stop:
|
|
571
|
+
break
|
|
572
|
+
else:
|
|
573
|
+
yield o
|
|
574
|
+
|
|
575
|
+
return _stream_out_async_gen()
|
|
499
576
|
|
|
500
|
-
if inspect.isgenerator(ret):
|
|
501
|
-
gen = self._to_generator(output_type, ret)
|
|
502
|
-
self._current_generator = weakref.ref(gen)
|
|
503
|
-
return gen
|
|
504
|
-
if inspect.isasyncgen(ret):
|
|
505
|
-
gen = self._to_async_gen(output_type, ret)
|
|
506
|
-
self._current_generator = weakref.ref(gen)
|
|
507
|
-
return gen
|
|
508
577
|
if output_type == "json":
|
|
509
578
|
return await asyncio.to_thread(json_dumps, ret)
|
|
510
579
|
else:
|
|
@@ -592,7 +661,6 @@ class ModelActor(xo.StatelessActor):
|
|
|
592
661
|
prompt_or_messages, queue, call_ability, *args, **kwargs
|
|
593
662
|
)
|
|
594
663
|
gen = self._to_async_gen("json", ret)
|
|
595
|
-
self._current_generator = weakref.ref(gen)
|
|
596
664
|
return gen
|
|
597
665
|
else:
|
|
598
666
|
from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
|
|
@@ -953,6 +1021,25 @@ class ModelActor(xo.StatelessActor):
|
|
|
953
1021
|
f"Model {self._model.model_spec} is not for creating image."
|
|
954
1022
|
)
|
|
955
1023
|
|
|
1024
|
+
@log_async(
|
|
1025
|
+
logger=logger,
|
|
1026
|
+
ignore_kwargs=["image"],
|
|
1027
|
+
)
|
|
1028
|
+
async def ocr(
|
|
1029
|
+
self,
|
|
1030
|
+
image: "PIL.Image",
|
|
1031
|
+
*args,
|
|
1032
|
+
**kwargs,
|
|
1033
|
+
):
|
|
1034
|
+
if hasattr(self._model, "ocr"):
|
|
1035
|
+
return await self._call_wrapper_json(
|
|
1036
|
+
self._model.ocr,
|
|
1037
|
+
image,
|
|
1038
|
+
*args,
|
|
1039
|
+
**kwargs,
|
|
1040
|
+
)
|
|
1041
|
+
raise AttributeError(f"Model {self._model.model_spec} is not for ocr.")
|
|
1042
|
+
|
|
956
1043
|
@request_limit
|
|
957
1044
|
@log_async(logger=logger, ignore_kwargs=["image"])
|
|
958
1045
|
async def infer(
|
|
@@ -994,3 +1081,6 @@ class ModelActor(xo.StatelessActor):
|
|
|
994
1081
|
async def record_metrics(self, name, op, kwargs):
|
|
995
1082
|
worker_ref = await self._get_worker_ref()
|
|
996
1083
|
await worker_ref.record_metrics(name, op, kwargs)
|
|
1084
|
+
|
|
1085
|
+
async def get_pending_requests_count(self):
|
|
1086
|
+
return self._pending_requests.qsize()
|
xinference/core/scheduler.py
CHANGED
|
@@ -79,7 +79,7 @@ class InferenceRequest:
|
|
|
79
79
|
# For tool call
|
|
80
80
|
self.tools = None
|
|
81
81
|
# Currently, for storing tool call streaming results.
|
|
82
|
-
self.outputs: List[str] = []
|
|
82
|
+
self.outputs: List[str] = [] # type: ignore
|
|
83
83
|
# inference results,
|
|
84
84
|
# it is a list type because when stream=True,
|
|
85
85
|
# self.completion contains all the results in a decode round.
|
xinference/deploy/supervisor.py
CHANGED
|
@@ -31,10 +31,6 @@ from .utils import health_check
|
|
|
31
31
|
|
|
32
32
|
logger = logging.getLogger(__name__)
|
|
33
33
|
|
|
34
|
-
from ..model import _install as install_model
|
|
35
|
-
|
|
36
|
-
install_model()
|
|
37
|
-
|
|
38
34
|
|
|
39
35
|
async def _start_supervisor(address: str, logging_conf: Optional[Dict] = None):
|
|
40
36
|
logging.config.dictConfig(logging_conf) # type: ignore
|
|
xinference/model/audio/chattts.py
CHANGED
|
@@ -54,7 +54,11 @@ class ChatTTSModel:
|
|
|
54
54
|
torch.set_float32_matmul_precision("high")
|
|
55
55
|
self._model = ChatTTS.Chat()
|
|
56
56
|
logger.info("Load ChatTTS model with kwargs: %s", self._kwargs)
|
|
57
|
-
self._model.load(
|
|
57
|
+
ok = self._model.load(
|
|
58
|
+
source="custom", custom_path=self._model_path, **self._kwargs
|
|
59
|
+
)
|
|
60
|
+
if not ok:
|
|
61
|
+
raise Exception(f"The ChatTTS model is not correct: {self._model_path}")
|
|
58
62
|
|
|
59
63
|
def speech(
|
|
60
64
|
self,
|
|
@@ -114,16 +118,15 @@ class ChatTTSModel:
|
|
|
114
118
|
last_pos = 0
|
|
115
119
|
with writer.open():
|
|
116
120
|
for it in iter:
|
|
117
|
-
for
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
last_pos = new_last_pos
|
|
121
|
+
for chunk in it:
|
|
122
|
+
chunk = np.array([chunk]).transpose()
|
|
123
|
+
writer.write_audio_chunk(i, torch.from_numpy(chunk))
|
|
124
|
+
new_last_pos = out.tell()
|
|
125
|
+
if new_last_pos != last_pos:
|
|
126
|
+
out.seek(last_pos)
|
|
127
|
+
encoded_bytes = out.read()
|
|
128
|
+
yield encoded_bytes
|
|
129
|
+
last_pos = new_last_pos
|
|
127
130
|
|
|
128
131
|
return _generator()
|
|
129
132
|
else:
|
|
@@ -131,7 +134,15 @@ class ChatTTSModel:
|
|
|
131
134
|
|
|
132
135
|
# Save the generated audio
|
|
133
136
|
with BytesIO() as out:
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
+
try:
|
|
138
|
+
torchaudio.save(
|
|
139
|
+
out,
|
|
140
|
+
torch.from_numpy(wavs[0]).unsqueeze(0),
|
|
141
|
+
24000,
|
|
142
|
+
format=response_format,
|
|
143
|
+
)
|
|
144
|
+
except:
|
|
145
|
+
torchaudio.save(
|
|
146
|
+
out, torch.from_numpy(wavs[0]), 24000, format=response_format
|
|
147
|
+
)
|
|
137
148
|
return out.getvalue()
|
|
@@ -127,7 +127,7 @@
|
|
|
127
127
|
"model_name": "ChatTTS",
|
|
128
128
|
"model_family": "ChatTTS",
|
|
129
129
|
"model_id": "2Noise/ChatTTS",
|
|
130
|
-
"model_revision": "
|
|
130
|
+
"model_revision": "3b34118f6d25850440b8901cef3e71c6ef8619c8",
|
|
131
131
|
"model_ability": "text-to-audio",
|
|
132
132
|
"multilingual": true
|
|
133
133
|
},
|
xinference/model/image/core.py
CHANGED
|
@@ -11,17 +11,21 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
|
|
14
15
|
import collections.abc
|
|
15
16
|
import logging
|
|
16
17
|
import os
|
|
18
|
+
import platform
|
|
17
19
|
from collections import defaultdict
|
|
18
|
-
from typing import Dict, List, Literal, Optional, Tuple
|
|
20
|
+
from typing import Dict, List, Literal, Optional, Tuple, Union
|
|
19
21
|
|
|
20
22
|
from ...constants import XINFERENCE_CACHE_DIR
|
|
21
23
|
from ...types import PeftModelConfig
|
|
22
24
|
from ..core import CacheableModelSpec, ModelDescription
|
|
23
25
|
from ..utils import valid_model_revision
|
|
26
|
+
from .ocr.got_ocr2 import GotOCR2Model
|
|
24
27
|
from .stable_diffusion.core import DiffusionModel
|
|
28
|
+
from .stable_diffusion.mlx import MLXDiffusionModel
|
|
25
29
|
|
|
26
30
|
logger = logging.getLogger(__name__)
|
|
27
31
|
|
|
@@ -45,6 +49,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
|
|
|
45
49
|
model_hub: str = "huggingface"
|
|
46
50
|
model_ability: Optional[List[str]]
|
|
47
51
|
controlnet: Optional[List["ImageModelFamilyV1"]]
|
|
52
|
+
default_model_config: Optional[dict] = {}
|
|
48
53
|
default_generate_config: Optional[dict] = {}
|
|
49
54
|
|
|
50
55
|
|
|
@@ -180,6 +185,28 @@ def get_cache_status(
|
|
|
180
185
|
return valid_model_revision(meta_path, model_spec.model_revision)
|
|
181
186
|
|
|
182
187
|
|
|
188
|
+
def create_ocr_model_instance(
|
|
189
|
+
subpool_addr: str,
|
|
190
|
+
devices: List[str],
|
|
191
|
+
model_uid: str,
|
|
192
|
+
model_spec: ImageModelFamilyV1,
|
|
193
|
+
model_path: Optional[str] = None,
|
|
194
|
+
**kwargs,
|
|
195
|
+
) -> Tuple[GotOCR2Model, ImageModelDescription]:
|
|
196
|
+
if not model_path:
|
|
197
|
+
model_path = cache(model_spec)
|
|
198
|
+
model = GotOCR2Model(
|
|
199
|
+
model_uid,
|
|
200
|
+
model_path,
|
|
201
|
+
model_spec=model_spec,
|
|
202
|
+
**kwargs,
|
|
203
|
+
)
|
|
204
|
+
model_description = ImageModelDescription(
|
|
205
|
+
subpool_addr, devices, model_spec, model_path=model_path
|
|
206
|
+
)
|
|
207
|
+
return model, model_description
|
|
208
|
+
|
|
209
|
+
|
|
183
210
|
def create_image_model_instance(
|
|
184
211
|
subpool_addr: str,
|
|
185
212
|
devices: List[str],
|
|
@@ -189,8 +216,26 @@ def create_image_model_instance(
|
|
|
189
216
|
download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
|
|
190
217
|
model_path: Optional[str] = None,
|
|
191
218
|
**kwargs,
|
|
192
|
-
) -> Tuple[
|
|
219
|
+
) -> Tuple[
|
|
220
|
+
Union[DiffusionModel, MLXDiffusionModel, GotOCR2Model], ImageModelDescription
|
|
221
|
+
]:
|
|
193
222
|
model_spec = match_diffusion(model_name, download_hub)
|
|
223
|
+
if model_spec.model_ability and "ocr" in model_spec.model_ability:
|
|
224
|
+
return create_ocr_model_instance(
|
|
225
|
+
subpool_addr=subpool_addr,
|
|
226
|
+
devices=devices,
|
|
227
|
+
model_uid=model_uid,
|
|
228
|
+
model_name=model_name,
|
|
229
|
+
model_spec=model_spec,
|
|
230
|
+
model_path=model_path,
|
|
231
|
+
**kwargs,
|
|
232
|
+
)
|
|
233
|
+
|
|
234
|
+
# use default model config
|
|
235
|
+
model_default_config = (model_spec.default_model_config or {}).copy()
|
|
236
|
+
model_default_config.update(kwargs)
|
|
237
|
+
kwargs = model_default_config
|
|
238
|
+
|
|
194
239
|
controlnet = kwargs.get("controlnet")
|
|
195
240
|
# Handle controlnet
|
|
196
241
|
if controlnet is not None:
|
|
@@ -232,10 +277,20 @@ def create_image_model_instance(
|
|
|
232
277
|
lora_load_kwargs = None
|
|
233
278
|
lora_fuse_kwargs = None
|
|
234
279
|
|
|
235
|
-
|
|
280
|
+
if (
|
|
281
|
+
platform.system() == "Darwin"
|
|
282
|
+
and "arm" in platform.machine().lower()
|
|
283
|
+
and model_name in MLXDiffusionModel.supported_models
|
|
284
|
+
):
|
|
285
|
+
# Mac with M series silicon chips
|
|
286
|
+
model_cls = MLXDiffusionModel
|
|
287
|
+
else:
|
|
288
|
+
model_cls = DiffusionModel # type: ignore
|
|
289
|
+
|
|
290
|
+
model = model_cls(
|
|
236
291
|
model_uid,
|
|
237
292
|
model_path,
|
|
238
|
-
|
|
293
|
+
lora_model=lora_model,
|
|
239
294
|
lora_load_kwargs=lora_load_kwargs,
|
|
240
295
|
lora_fuse_kwargs=lora_fuse_kwargs,
|
|
241
296
|
model_spec=model_spec,
|
|
xinference/model/image/model_spec.json
CHANGED
|
@@ -8,7 +8,11 @@
|
|
|
8
8
|
"text2image",
|
|
9
9
|
"image2image",
|
|
10
10
|
"inpainting"
|
|
11
|
-
]
|
|
11
|
+
],
|
|
12
|
+
"default_model_config": {
|
|
13
|
+
"quantize": true,
|
|
14
|
+
"quantize_text_encoder": "text_encoder_2"
|
|
15
|
+
}
|
|
12
16
|
},
|
|
13
17
|
{
|
|
14
18
|
"model_name": "FLUX.1-dev",
|
|
@@ -19,7 +23,11 @@
|
|
|
19
23
|
"text2image",
|
|
20
24
|
"image2image",
|
|
21
25
|
"inpainting"
|
|
22
|
-
]
|
|
26
|
+
],
|
|
27
|
+
"default_model_config": {
|
|
28
|
+
"quantize": true,
|
|
29
|
+
"quantize_text_encoder": "text_encoder_2"
|
|
30
|
+
}
|
|
23
31
|
},
|
|
24
32
|
{
|
|
25
33
|
"model_name": "sd3-medium",
|
|
@@ -30,7 +38,11 @@
|
|
|
30
38
|
"text2image",
|
|
31
39
|
"image2image",
|
|
32
40
|
"inpainting"
|
|
33
|
-
]
|
|
41
|
+
],
|
|
42
|
+
"default_model_config": {
|
|
43
|
+
"quantize": true,
|
|
44
|
+
"quantize_text_encoder": "text_encoder_3"
|
|
45
|
+
}
|
|
34
46
|
},
|
|
35
47
|
{
|
|
36
48
|
"model_name": "sd-turbo",
|
|
@@ -178,5 +190,14 @@
|
|
|
178
190
|
"model_ability": [
|
|
179
191
|
"inpainting"
|
|
180
192
|
]
|
|
193
|
+
},
|
|
194
|
+
{
|
|
195
|
+
"model_name": "GOT-OCR2_0",
|
|
196
|
+
"model_family": "ocr",
|
|
197
|
+
"model_id": "stepfun-ai/GOT-OCR2_0",
|
|
198
|
+
"model_revision": "cf6b7386bc89a54f09785612ba74cb12de6fa17c",
|
|
199
|
+
"model_ability": [
|
|
200
|
+
"ocr"
|
|
201
|
+
]
|
|
181
202
|
}
|
|
182
203
|
]
|
|
xinference/model/image/model_spec_modelscope.json
CHANGED
|
@@ -9,7 +9,11 @@
|
|
|
9
9
|
"text2image",
|
|
10
10
|
"image2image",
|
|
11
11
|
"inpainting"
|
|
12
|
-
]
|
|
12
|
+
],
|
|
13
|
+
"default_model_config": {
|
|
14
|
+
"quantize": true,
|
|
15
|
+
"quantize_text_encoder": "text_encoder_2"
|
|
16
|
+
}
|
|
13
17
|
},
|
|
14
18
|
{
|
|
15
19
|
"model_name": "FLUX.1-dev",
|
|
@@ -21,7 +25,11 @@
|
|
|
21
25
|
"text2image",
|
|
22
26
|
"image2image",
|
|
23
27
|
"inpainting"
|
|
24
|
-
]
|
|
28
|
+
],
|
|
29
|
+
"default_model_config": {
|
|
30
|
+
"quantize": true,
|
|
31
|
+
"quantize_text_encoder": "text_encoder_2"
|
|
32
|
+
}
|
|
25
33
|
},
|
|
26
34
|
{
|
|
27
35
|
"model_name": "sd3-medium",
|
|
@@ -33,7 +41,11 @@
|
|
|
33
41
|
"text2image",
|
|
34
42
|
"image2image",
|
|
35
43
|
"inpainting"
|
|
36
|
-
]
|
|
44
|
+
],
|
|
45
|
+
"default_model_config": {
|
|
46
|
+
"quantize": true,
|
|
47
|
+
"quantize_text_encoder": "text_encoder_3"
|
|
48
|
+
}
|
|
37
49
|
},
|
|
38
50
|
{
|
|
39
51
|
"model_name": "sd-turbo",
|
|
@@ -148,5 +160,15 @@
|
|
|
148
160
|
"model_revision": "62134b9d8e703b5d6f74f1534457287a8bba77ef"
|
|
149
161
|
}
|
|
150
162
|
]
|
|
163
|
+
},
|
|
164
|
+
{
|
|
165
|
+
"model_name": "GOT-OCR2_0",
|
|
166
|
+
"model_family": "ocr",
|
|
167
|
+
"model_id": "stepfun-ai/GOT-OCR2_0",
|
|
168
|
+
"model_revision": "master",
|
|
169
|
+
"model_hub": "modelscope",
|
|
170
|
+
"model_ability": [
|
|
171
|
+
"ocr"
|
|
172
|
+
]
|
|
151
173
|
}
|
|
152
174
|
]
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|