xinference-0.16.0-py3-none-any.whl → xinference-0.16.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference has been flagged as possibly problematic.

Files changed (62)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +48 -0
  3. xinference/client/restful/restful_client.py +19 -0
  4. xinference/constants.py +1 -0
  5. xinference/core/chat_interface.py +5 -1
  6. xinference/core/image_interface.py +5 -1
  7. xinference/core/model.py +106 -16
  8. xinference/core/scheduler.py +1 -1
  9. xinference/core/worker.py +3 -1
  10. xinference/deploy/supervisor.py +0 -4
  11. xinference/model/audio/chattts.py +25 -14
  12. xinference/model/audio/core.py +6 -2
  13. xinference/model/audio/model_spec.json +1 -1
  14. xinference/model/audio/model_spec_modelscope.json +1 -1
  15. xinference/model/core.py +3 -1
  16. xinference/model/embedding/core.py +6 -2
  17. xinference/model/embedding/model_spec.json +1 -1
  18. xinference/model/image/core.py +65 -6
  19. xinference/model/image/model_spec.json +24 -3
  20. xinference/model/image/model_spec_modelscope.json +25 -3
  21. xinference/model/image/ocr/__init__.py +13 -0
  22. xinference/model/image/ocr/got_ocr2.py +79 -0
  23. xinference/model/image/scheduler/flux.py +1 -1
  24. xinference/model/image/stable_diffusion/core.py +2 -3
  25. xinference/model/image/stable_diffusion/mlx.py +221 -0
  26. xinference/model/llm/__init__.py +33 -0
  27. xinference/model/llm/core.py +3 -1
  28. xinference/model/llm/llm_family.json +9 -0
  29. xinference/model/llm/llm_family.py +68 -2
  30. xinference/model/llm/llm_family_modelscope.json +11 -0
  31. xinference/model/llm/llm_family_openmind_hub.json +1359 -0
  32. xinference/model/rerank/core.py +9 -1
  33. xinference/model/utils.py +7 -0
  34. xinference/model/video/core.py +6 -2
  35. xinference/thirdparty/mlx/__init__.py +13 -0
  36. xinference/thirdparty/mlx/flux/__init__.py +15 -0
  37. xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
  38. xinference/thirdparty/mlx/flux/clip.py +154 -0
  39. xinference/thirdparty/mlx/flux/datasets.py +75 -0
  40. xinference/thirdparty/mlx/flux/flux.py +247 -0
  41. xinference/thirdparty/mlx/flux/layers.py +302 -0
  42. xinference/thirdparty/mlx/flux/lora.py +76 -0
  43. xinference/thirdparty/mlx/flux/model.py +134 -0
  44. xinference/thirdparty/mlx/flux/sampler.py +56 -0
  45. xinference/thirdparty/mlx/flux/t5.py +244 -0
  46. xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
  47. xinference/thirdparty/mlx/flux/trainer.py +98 -0
  48. xinference/thirdparty/mlx/flux/utils.py +179 -0
  49. xinference/web/ui/build/asset-manifest.json +3 -3
  50. xinference/web/ui/build/index.html +1 -1
  51. xinference/web/ui/build/static/js/{main.f7da0140.js → main.2f269bb3.js} +3 -3
  52. xinference/web/ui/build/static/js/main.2f269bb3.js.map +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +1 -0
  54. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/METADATA +16 -9
  55. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/RECORD +60 -42
  56. xinference/web/ui/build/static/js/main.f7da0140.js.map +0 -1
  57. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
  58. /xinference/web/ui/build/static/js/{main.f7da0140.js.LICENSE.txt → main.2f269bb3.js.LICENSE.txt} +0 -0
  59. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/LICENSE +0 -0
  60. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/WHEEL +0 -0
  61. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/entry_points.txt +0 -0
  62. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-10-18T12:49:02+0800",
+ "date": "2024-11-01T17:56:47+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "5f7dea44832a1c41f887b9a01377191894550057",
- "version": "0.16.0"
+ "full-revisionid": "67e97ab485b539dc7a208825bee0504acc37044e",
+ "version": "0.16.2"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -567,6 +567,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/images/ocr",
+            self.create_ocr,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         # SD WebUI API
         self._router.add_api_route(
             "/sdapi/v1/options",
@@ -1754,6 +1764,44 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def create_ocr(
+        self,
+        model: str = Form(...),
+        image: UploadFile = File(media_type="application/octet-stream"),
+        kwargs: Optional[str] = Form(None),
+    ) -> Response:
+        model_uid = model
+        try:
+            model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            if kwargs is not None:
+                parsed_kwargs = json.loads(kwargs)
+            else:
+                parsed_kwargs = {}
+            im = Image.open(image.file)
+            text = await model_ref.ocr(
+                image=im,
+                **parsed_kwargs,
+            )
+            return Response(content=text, media_type="text/plain")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_flexible_infer(self, request: Request) -> Response:
         payload = await request.json()
 
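The new route accepts a multipart form with a "model" field, an "image" file and an optional JSON-encoded "kwargs" field, and answers with the recognized text as text/plain. A minimal sketch of exercising it directly over HTTP; the endpoint address, file name and model uid below are placeholders, not values taken from this diff:

import json

import requests

# Placeholders: adjust the endpoint and the uid of the launched OCR-capable model.
url = "http://127.0.0.1:9997/v1/images/ocr"
with open("sample.png", "rb") as f:
    files = [
        ("model", (None, "my-ocr-model")),                    # form field: model uid
        ("kwargs", (None, json.dumps({}))),                   # optional extra kwargs as a JSON string
        ("image", ("image", f, "application/octet-stream")),  # the image file itself
    ]
    resp = requests.post(url, files=files)

resp.raise_for_status()
print(resp.text)  # the endpoint responds with plain text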
xinference/client/restful/restful_client.py CHANGED
@@ -369,6 +369,25 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def ocr(self, image: Union[str, bytes], **kwargs):
+        url = f"{self._base_url}/v1/images/ocr"
+        params = {
+            "model": self._model_uid,
+            "kwargs": json.dumps(kwargs),
+        }
+        files: List[Any] = []
+        for key, value in params.items():
+            files.append((key, (None, value)))
+        files.append(("image", ("image", image, "application/octet-stream")))
+        response = requests.post(url, files=files, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to ocr the images, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
+
 
 class RESTfulVideoModelHandle(RESTfulModelHandle):
     def text_to_video(
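With the handle method above, the same call can be made through the Python client. A sketch assuming an image model exposing the "ocr" ability is already launched; the server address and model uid are placeholders:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")   # placeholder endpoint
model = client.get_model("my-ocr-model")   # placeholder uid of a launched OCR model

with open("sample.png", "rb") as f:
    result = model.ocr(f.read())            # forwarded as multipart to /v1/images/ocr
print(result)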
xinference/constants.py CHANGED
@@ -39,6 +39,7 @@ def get_xinference_home() -> str:
         # if user has already set `XINFERENCE_HOME` env, change huggingface and modelscope default download path
         os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(home_path, "huggingface")
         os.environ["MODELSCOPE_CACHE"] = os.path.join(home_path, "modelscope")
+        os.environ["XDG_CACHE_HOME"] = os.path.join(home_path, "openmind_hub")
     # In multi-tenant mode,
     # gradio's temporary files are stored in their respective home directories,
     # to prevent insufficient permissions
xinference/core/chat_interface.py CHANGED
@@ -74,7 +74,11 @@ class GradioInterface:
         # Gradio initiates the queue during a startup event, but since the app has already been
         # started, that event will not run, so manually invoke the startup events.
         # See: https://github.com/gradio-app/gradio/issues/5228
-        interface.startup_events()
+        try:
+            interface.run_startup_events()
+        except AttributeError:
+            # compatibility
+            interface.startup_events()
         favicon_path = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
             os.path.pardir,
xinference/core/image_interface.py CHANGED
@@ -63,7 +63,11 @@ class ImageInterface:
         # Gradio initiates the queue during a startup event, but since the app has already been
         # started, that event will not run, so manually invoke the startup events.
         # See: https://github.com/gradio-app/gradio/issues/5228
-        interface.startup_events()
+        try:
+            interface.run_startup_events()
+        except AttributeError:
+            # compatibility
+            interface.startup_events()
         favicon_path = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
             os.path.pardir,
xinference/core/model.py CHANGED
@@ -17,10 +17,10 @@ import functools
 import inspect
 import json
 import os
+import queue
 import time
 import types
 import uuid
-import weakref
 from asyncio.queues import Queue
 from asyncio.tasks import wait_for
 from concurrent.futures import Future as ConcurrentFuture
@@ -32,7 +32,6 @@ from typing import (
     Callable,
     Dict,
     Generator,
-    Iterator,
     List,
     Optional,
     Union,
@@ -209,9 +208,8 @@ class ModelActor(xo.StatelessActor):
             model_description.to_dict() if model_description else {}
         )
         self._request_limits = request_limits
-
-        self._generators: Dict[str, Union[Iterator, AsyncGenerator]] = {}
-        self._current_generator = lambda: None
+        self._pending_requests: asyncio.Queue = asyncio.Queue()
+        self._handle_pending_requests_task = None
         self._lock = (
             None
             if isinstance(
@@ -237,6 +235,10 @@ class ModelActor(xo.StatelessActor):
     async def __post_create__(self):
         self._loop = asyncio.get_running_loop()
 
+        self._handle_pending_requests_task = asyncio.create_task(
+            self._handle_pending_requests()
+        )
+
         if self.allow_batching():
             from .scheduler import SchedulerActor
 
@@ -474,6 +476,43 @@ class ModelActor(xo.StatelessActor):
         )
         await asyncio.gather(*coros)
 
+    async def _handle_pending_requests(self):
+        logger.info("Start requests handler.")
+        while True:
+            gen, stream_out, stop = await self._pending_requests.get()
+
+            async def _async_wrapper(_gen):
+                try:
+                    # anext is only available for Python >= 3.10
+                    return await _gen.__anext__()  # noqa: F821
+                except StopAsyncIteration:
+                    return stop
+
+            def _wrapper(_gen):
+                # Avoid issue: https://github.com/python/cpython/issues/112182
+                try:
+                    return next(_gen)
+                except StopIteration:
+                    return stop
+
+            while True:
+                try:
+                    if inspect.isgenerator(gen):
+                        r = await asyncio.to_thread(_wrapper, gen)
+                    elif inspect.isasyncgen(gen):
+                        r = await _async_wrapper(gen)
+                    else:
+                        raise Exception(
+                            f"The generator {gen} should be a generator or an async generator, "
+                            f"but a {type(gen)} is got."
+                        )
+                    stream_out.put_nowait(r)
+                    if r is not stop:
+                        continue
+                except Exception:
+                    logger.exception("stream encountered an error.")
+                break
+
     async def _call_wrapper_json(self, fn: Callable, *args, **kwargs):
         return await self._call_wrapper("json", fn, *args, **kwargs)
 
@@ -487,6 +526,13 @@ class ModelActor(xo.StatelessActor):
                 ret = await fn(*args, **kwargs)
             else:
                 ret = await asyncio.to_thread(fn, *args, **kwargs)
+
+            if inspect.isgenerator(ret):
+                gen = self._to_generator(output_type, ret)
+                return gen
+            if inspect.isasyncgen(ret):
+                gen = self._to_async_gen(output_type, ret)
+                return gen
         else:
             async with self._lock:
                 if inspect.iscoroutinefunction(fn):
@@ -494,17 +540,40 @@ class ModelActor(xo.StatelessActor):
                 else:
                     ret = await asyncio.to_thread(fn, *args, **kwargs)
 
-        if self._lock is not None and self._current_generator():
-            raise Exception("Parallel generation is not supported by llama-cpp-python.")
+            stream_out: Union[queue.Queue, asyncio.Queue]
+
+            if inspect.isgenerator(ret):
+                gen = self._to_generator(output_type, ret)
+                stream_out = queue.Queue()
+                stop = object()
+                self._pending_requests.put_nowait((gen, stream_out, stop))
+
+                def _stream_out_generator():
+                    while True:
+                        o = stream_out.get()
+                        if o is stop:
+                            break
+                        else:
+                            yield o
+
+                return _stream_out_generator()
+
+            if inspect.isasyncgen(ret):
+                gen = self._to_async_gen(output_type, ret)
+                stream_out = asyncio.Queue()
+                stop = object()
+                self._pending_requests.put_nowait((gen, stream_out, stop))
+
+                async def _stream_out_async_gen():
+                    while True:
+                        o = await stream_out.get()
+                        if o is stop:
+                            break
+                        else:
+                            yield o
+
+                return _stream_out_async_gen()
 
-        if inspect.isgenerator(ret):
-            gen = self._to_generator(output_type, ret)
-            self._current_generator = weakref.ref(gen)
-            return gen
-        if inspect.isasyncgen(ret):
-            gen = self._to_async_gen(output_type, ret)
-            self._current_generator = weakref.ref(gen)
-            return gen
 
         if output_type == "json":
             return await asyncio.to_thread(json_dumps, ret)
@@ -592,7 +661,6 @@ class ModelActor(xo.StatelessActor):
                 prompt_or_messages, queue, call_ability, *args, **kwargs
             )
             gen = self._to_async_gen("json", ret)
-            self._current_generator = weakref.ref(gen)
             return gen
         else:
             from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
@@ -953,6 +1021,25 @@ class ModelActor(xo.StatelessActor):
                 f"Model {self._model.model_spec} is not for creating image."
             )
 
+    @log_async(
+        logger=logger,
+        ignore_kwargs=["image"],
+    )
+    async def ocr(
+        self,
+        image: "PIL.Image",
+        *args,
+        **kwargs,
+    ):
+        if hasattr(self._model, "ocr"):
+            return await self._call_wrapper_json(
+                self._model.ocr,
+                image,
+                *args,
+                **kwargs,
+            )
+        raise AttributeError(f"Model {self._model.model_spec} is not for ocr.")
+
     @request_limit
     @log_async(logger=logger, ignore_kwargs=["image"])
     async def infer(
@@ -994,3 +1081,6 @@ class ModelActor(xo.StatelessActor):
     async def record_metrics(self, name, op, kwargs):
         worker_ref = await self._get_worker_ref()
         await worker_ref.record_metrics(name, op, kwargs)
+
+    async def get_pending_requests_count(self):
+        return self._pending_requests.qsize()
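The replacement of the weakref-based _current_generator with a _pending_requests queue follows a producer/consumer shape: a single background task drains queued generators into per-request output queues, ending each with a sentinel object, so streaming requests are serialized rather than interleaved on the underlying model. A simplified standalone sketch of that shape, for illustration only (not the xinference code; it keeps just the async-generator branch and omits error handling):

import asyncio


async def handler(pending: asyncio.Queue) -> None:
    # Drain one queued generator at a time into its own output queue.
    while True:
        gen, out, stop = await pending.get()
        async for item in gen:
            out.put_nowait(item)
        out.put_nowait(stop)  # sentinel: this stream is finished


async def stream(pending: asyncio.Queue, gen):
    # Enqueue the generator, then relay items until the sentinel arrives.
    out: asyncio.Queue = asyncio.Queue()
    stop = object()
    pending.put_nowait((gen, out, stop))
    while True:
        item = await out.get()
        if item is stop:
            break
        yield item


async def main() -> None:
    async def tokens(prefix: str, n: int):
        for i in range(n):
            await asyncio.sleep(0.01)
            yield f"{prefix}-{i}"

    pending: asyncio.Queue = asyncio.Queue()
    asyncio.create_task(handler(pending))
    async for chunk in stream(pending, tokens("req", 3)):
        print(chunk)


asyncio.run(main())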
xinference/core/scheduler.py CHANGED
@@ -79,7 +79,7 @@ class InferenceRequest:
         # For tool call
         self.tools = None
         # Currently, for storing tool call streaming results.
-        self.outputs: List[str] = []
+        self.outputs: List[str] = []  # type: ignore
         # inference results,
         # it is a list type because when stream=True,
         # self.completion contains all the results in a decode round.
xinference/core/worker.py CHANGED
@@ -785,7 +785,9 @@ class WorkerActor(xo.StatelessActor):
         peft_model_config: Optional[PeftModelConfig] = None,
         request_limits: Optional[int] = None,
         gpu_idx: Optional[Union[int, List[int]]] = None,
-        download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+        download_hub: Optional[
+            Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+        ] = None,
         model_path: Optional[str] = None,
         **kwargs,
     ):
xinference/deploy/supervisor.py CHANGED
@@ -31,10 +31,6 @@ from .utils import health_check
 
 logger = logging.getLogger(__name__)
 
-from ..model import _install as install_model
-
-install_model()
-
 
 async def _start_supervisor(address: str, logging_conf: Optional[Dict] = None):
     logging.config.dictConfig(logging_conf)  # type: ignore
xinference/model/audio/chattts.py CHANGED
@@ -54,7 +54,11 @@ class ChatTTSModel:
         torch.set_float32_matmul_precision("high")
         self._model = ChatTTS.Chat()
         logger.info("Load ChatTTS model with kwargs: %s", self._kwargs)
-        self._model.load(source="custom", custom_path=self._model_path, **self._kwargs)
+        ok = self._model.load(
+            source="custom", custom_path=self._model_path, **self._kwargs
+        )
+        if not ok:
+            raise Exception(f"The ChatTTS model is not correct: {self._model_path}")
 
     def speech(
         self,
@@ -114,16 +118,15 @@ class ChatTTSModel:
                     last_pos = 0
                     with writer.open():
                         for it in iter:
-                            for itt in it:
-                                for chunk in itt:
-                                    chunk = np.array([chunk]).transpose()
-                                    writer.write_audio_chunk(i, torch.from_numpy(chunk))
-                                    new_last_pos = out.tell()
-                                    if new_last_pos != last_pos:
-                                        out.seek(last_pos)
-                                        encoded_bytes = out.read()
-                                        yield encoded_bytes
-                                        last_pos = new_last_pos
+                            for chunk in it:
+                                chunk = np.array([chunk]).transpose()
+                                writer.write_audio_chunk(i, torch.from_numpy(chunk))
+                                new_last_pos = out.tell()
+                                if new_last_pos != last_pos:
+                                    out.seek(last_pos)
+                                    encoded_bytes = out.read()
+                                    yield encoded_bytes
+                                    last_pos = new_last_pos
 
             return _generator()
         else:
@@ -131,7 +134,15 @@ class ChatTTSModel:
 
             # Save the generated audio
             with BytesIO() as out:
-                torchaudio.save(
-                    out, torch.from_numpy(wavs[0]), 24000, format=response_format
-                )
+                try:
+                    torchaudio.save(
+                        out,
+                        torch.from_numpy(wavs[0]).unsqueeze(0),
+                        24000,
+                        format=response_format,
+                    )
+                except:
+                    torchaudio.save(
+                        out, torch.from_numpy(wavs[0]), 24000, format=response_format
+                    )
                 return out.getvalue()
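The try/except around torchaudio.save covers the case where the generated waveform is one-dimensional: torchaudio expects a (channels, frames) tensor, so a mono 1-D array needs an extra channel dimension first. A tiny illustration with synthetic data (not ChatTTS output):

import numpy as np
import torch
import torchaudio

sr = 24000
# 1 second of a 440 Hz sine wave, shape (24000,)
wav = np.sin(2 * np.pi * 440.0 * np.arange(sr) / sr).astype(np.float32)

# unsqueeze(0) turns the 1-D waveform into the (channels, frames) layout torchaudio expects.
torchaudio.save("tone.wav", torch.from_numpy(wav).unsqueeze(0), sr)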
xinference/model/audio/core.py CHANGED
@@ -100,7 +100,9 @@ def generate_audio_description(
 
 def match_audio(
     model_name: str,
-    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
 ) -> AudioModelFamilyV1:
     from ..utils import download_from_modelscope
     from . import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS
@@ -152,7 +154,9 @@ def create_audio_model_instance(
     devices: List[str],
     model_uid: str,
     model_name: str,
-    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[
xinference/model/audio/model_spec.json CHANGED
@@ -127,7 +127,7 @@
         "model_name": "ChatTTS",
         "model_family": "ChatTTS",
         "model_id": "2Noise/ChatTTS",
-        "model_revision": "ce5913842aebd78e4a01a02d47244b8d62ac4ee3",
+        "model_revision": "3b34118f6d25850440b8901cef3e71c6ef8619c8",
         "model_ability": "text-to-audio",
         "multilingual": true
     },
xinference/model/audio/model_spec_modelscope.json CHANGED
@@ -42,7 +42,7 @@
         "model_name": "ChatTTS",
         "model_family": "ChatTTS",
         "model_hub": "modelscope",
-        "model_id": "pzc163/chatTTS",
+        "model_id": "AI-ModelScope/ChatTTS",
         "model_revision": "master",
         "model_ability": "text-to-audio",
         "multilingual": true
xinference/model/core.py CHANGED
@@ -55,7 +55,9 @@ def create_model_instance(
     model_size_in_billions: Optional[Union[int, str]] = None,
     quantization: Optional[str] = None,
     peft_model_config: Optional[PeftModelConfig] = None,
-    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[Any, ModelDescription]:
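These signature changes thread the new "openmind_hub" value for download_hub through the whole model-loading path (audio, embedding, image, LLM and video). A hedged usage sketch from the client side, assuming launch_model forwards download_hub the same way it already does for "modelscope" and "csghub"; the endpoint and model name are placeholders:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")   # placeholder endpoint
uid = client.launch_model(
    model_name="bge-small-zh-v1.5",        # placeholder model name
    model_type="embedding",
    download_hub="openmind_hub",           # new hub option added in this release
)
print(uid)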
xinference/model/embedding/core.py CHANGED
@@ -433,7 +433,9 @@ class EmbeddingModel:
 
 def match_embedding(
     model_name: str,
-    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
 ) -> EmbeddingModelSpec:
     from ..utils import download_from_modelscope
     from . import BUILTIN_EMBEDDING_MODELS, MODELSCOPE_EMBEDDING_MODELS
@@ -469,7 +471,9 @@ def create_embedding_model_instance(
     devices: List[str],
     model_uid: str,
     model_name: str,
-    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[EmbeddingModel, EmbeddingModelDescription]:
xinference/model/embedding/model_spec.json CHANGED
@@ -233,7 +233,7 @@
     },
     {
         "model_name": "gte-Qwen2",
-        "dimensions": 3584,
+        "dimensions": 4096,
         "max_tokens": 32000,
         "language": ["zh", "en"],
         "model_id": "Alibaba-NLP/gte-Qwen2-7B-instruct",
xinference/model/image/core.py CHANGED
@@ -11,17 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import collections.abc
 import logging
 import os
+import platform
 from collections import defaultdict
-from typing import Dict, List, Literal, Optional, Tuple
+from typing import Dict, List, Literal, Optional, Tuple, Union
 
 from ...constants import XINFERENCE_CACHE_DIR
 from ...types import PeftModelConfig
 from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
+from .ocr.got_ocr2 import GotOCR2Model
 from .stable_diffusion.core import DiffusionModel
+from .stable_diffusion.mlx import MLXDiffusionModel
 
 logger = logging.getLogger(__name__)
 
@@ -45,6 +49,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
     model_hub: str = "huggingface"
     model_ability: Optional[List[str]]
     controlnet: Optional[List["ImageModelFamilyV1"]]
+    default_model_config: Optional[dict] = {}
     default_generate_config: Optional[dict] = {}
 
 
@@ -120,7 +125,9 @@ def generate_image_description(
 
 def match_diffusion(
     model_name: str,
-    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
 ) -> ImageModelFamilyV1:
     from ..utils import download_from_modelscope
     from . import BUILTIN_IMAGE_MODELS, MODELSCOPE_IMAGE_MODELS
@@ -180,17 +187,59 @@ def get_cache_status(
     return valid_model_revision(meta_path, model_spec.model_revision)
 
 
+def create_ocr_model_instance(
+    subpool_addr: str,
+    devices: List[str],
+    model_uid: str,
+    model_spec: ImageModelFamilyV1,
+    model_path: Optional[str] = None,
+    **kwargs,
+) -> Tuple[GotOCR2Model, ImageModelDescription]:
+    if not model_path:
+        model_path = cache(model_spec)
+    model = GotOCR2Model(
+        model_uid,
+        model_path,
+        model_spec=model_spec,
+        **kwargs,
+    )
+    model_description = ImageModelDescription(
+        subpool_addr, devices, model_spec, model_path=model_path
+    )
+    return model, model_description
+
+
 def create_image_model_instance(
     subpool_addr: str,
     devices: List[str],
     model_uid: str,
     model_name: str,
     peft_model_config: Optional[PeftModelConfig] = None,
-    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
     model_path: Optional[str] = None,
     **kwargs,
-) -> Tuple[DiffusionModel, ImageModelDescription]:
+) -> Tuple[
+    Union[DiffusionModel, MLXDiffusionModel, GotOCR2Model], ImageModelDescription
+]:
     model_spec = match_diffusion(model_name, download_hub)
+    if model_spec.model_ability and "ocr" in model_spec.model_ability:
+        return create_ocr_model_instance(
+            subpool_addr=subpool_addr,
+            devices=devices,
+            model_uid=model_uid,
+            model_name=model_name,
+            model_spec=model_spec,
+            model_path=model_path,
+            **kwargs,
+        )
+
+    # use default model config
+    model_default_config = (model_spec.default_model_config or {}).copy()
+    model_default_config.update(kwargs)
+    kwargs = model_default_config
+
     controlnet = kwargs.get("controlnet")
     # Handle controlnet
     if controlnet is not None:
@@ -232,10 +281,20 @@ def create_image_model_instance(
     lora_load_kwargs = None
     lora_fuse_kwargs = None
 
-    model = DiffusionModel(
+    if (
+        platform.system() == "Darwin"
+        and "arm" in platform.machine().lower()
+        and model_name in MLXDiffusionModel.supported_models
+    ):
+        # Mac with M series silicon chips
+        model_cls = MLXDiffusionModel
+    else:
+        model_cls = DiffusionModel  # type: ignore
+
+    model = model_cls(
         model_uid,
         model_path,
-        lora_model_paths=lora_model,
+        lora_model=lora_model,
         lora_load_kwargs=lora_load_kwargs,
         lora_fuse_kwargs=lora_fuse_kwargs,
         model_spec=model_spec,
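The dispatch above only selects the MLX backend on Apple-silicon Macs, and only for model names that MLXDiffusionModel declares support for; the platform gate itself is plain standard library and can be checked in isolation (illustrative only, no xinference import needed):

import platform

# Mirrors the condition in create_image_model_instance: macOS on an ARM (M-series) CPU.
on_apple_silicon = platform.system() == "Darwin" and "arm" in platform.machine().lower()
print("would use MLXDiffusionModel" if on_apple_silicon else "would use DiffusionModel")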