xinference 0.16.0__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (50)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +48 -0
  3. xinference/client/restful/restful_client.py +19 -0
  4. xinference/core/chat_interface.py +5 -1
  5. xinference/core/image_interface.py +5 -1
  6. xinference/core/model.py +106 -16
  7. xinference/core/scheduler.py +1 -1
  8. xinference/deploy/supervisor.py +0 -4
  9. xinference/model/audio/chattts.py +25 -14
  10. xinference/model/audio/model_spec.json +1 -1
  11. xinference/model/audio/model_spec_modelscope.json +1 -1
  12. xinference/model/embedding/model_spec.json +1 -1
  13. xinference/model/image/core.py +59 -4
  14. xinference/model/image/model_spec.json +24 -3
  15. xinference/model/image/model_spec_modelscope.json +25 -3
  16. xinference/model/image/ocr/__init__.py +13 -0
  17. xinference/model/image/ocr/got_ocr2.py +76 -0
  18. xinference/model/image/scheduler/flux.py +1 -1
  19. xinference/model/image/stable_diffusion/core.py +2 -3
  20. xinference/model/image/stable_diffusion/mlx.py +221 -0
  21. xinference/model/llm/llm_family.json +9 -0
  22. xinference/model/llm/llm_family_modelscope.json +11 -0
  23. xinference/thirdparty/mlx/__init__.py +13 -0
  24. xinference/thirdparty/mlx/flux/__init__.py +15 -0
  25. xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
  26. xinference/thirdparty/mlx/flux/clip.py +154 -0
  27. xinference/thirdparty/mlx/flux/datasets.py +75 -0
  28. xinference/thirdparty/mlx/flux/flux.py +247 -0
  29. xinference/thirdparty/mlx/flux/layers.py +302 -0
  30. xinference/thirdparty/mlx/flux/lora.py +76 -0
  31. xinference/thirdparty/mlx/flux/model.py +134 -0
  32. xinference/thirdparty/mlx/flux/sampler.py +56 -0
  33. xinference/thirdparty/mlx/flux/t5.py +244 -0
  34. xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
  35. xinference/thirdparty/mlx/flux/trainer.py +98 -0
  36. xinference/thirdparty/mlx/flux/utils.py +179 -0
  37. xinference/web/ui/build/asset-manifest.json +3 -3
  38. xinference/web/ui/build/index.html +1 -1
  39. xinference/web/ui/build/static/js/{main.f7da0140.js → main.b76aeeb7.js} +3 -3
  40. xinference/web/ui/build/static/js/main.b76aeeb7.js.map +1 -0
  41. xinference/web/ui/node_modules/.cache/babel-loader/32ea2c04cf0bba2761b4883d2c40cc259952c94d2d6bb774e510963ca37aac0a.json +1 -0
  42. {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/METADATA +15 -8
  43. {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/RECORD +48 -31
  44. xinference/web/ui/build/static/js/main.f7da0140.js.map +0 -1
  45. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
  46. /xinference/web/ui/build/static/js/{main.f7da0140.js.LICENSE.txt → main.b76aeeb7.js.LICENSE.txt} +0 -0
  47. {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/LICENSE +0 -0
  48. {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/WHEEL +0 -0
  49. {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/entry_points.txt +0 -0
  50. {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-10-18T12:49:02+0800",
+ "date": "2024-10-25T12:51:06+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "5f7dea44832a1c41f887b9a01377191894550057",
- "version": "0.16.0"
+ "full-revisionid": "d4cd7b15104c16838e3c562cf2d33337e3d38897",
+ "version": "0.16.1"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -567,6 +567,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/images/ocr",
+            self.create_ocr,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         # SD WebUI API
         self._router.add_api_route(
             "/sdapi/v1/options",
@@ -1754,6 +1764,44 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def create_ocr(
+        self,
+        model: str = Form(...),
+        image: UploadFile = File(media_type="application/octet-stream"),
+        kwargs: Optional[str] = Form(None),
+    ) -> Response:
+        model_uid = model
+        try:
+            model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            if kwargs is not None:
+                parsed_kwargs = json.loads(kwargs)
+            else:
+                parsed_kwargs = {}
+            im = Image.open(image.file)
+            text = await model_ref.ocr(
+                image=im,
+                **parsed_kwargs,
+            )
+            return Response(content=text, media_type="text/plain")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_flexible_infer(self, request: Request) -> Response:
         payload = await request.json()
 
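Note: the new endpoint takes multipart form data (a `model` field, an optional JSON-encoded `kwargs` field, and the `image` file) and returns plain text. A minimal sketch of a call with `requests`, assuming a local server on the default port; the model UID `my-got-ocr2` and file name are placeholders:

    import json

    import requests

    url = "http://127.0.0.1:9997/v1/images/ocr"  # assumed local server, default port
    with open("page.png", "rb") as f:            # "page.png" is a placeholder
        files = [
            ("model", (None, "my-got-ocr2")),    # plain form field: UID of a launched OCR model
            ("kwargs", (None, json.dumps({}))),  # optional extra kwargs, JSON-encoded
            ("image", ("image", f, "application/octet-stream")),
        ]
        resp = requests.post(url, files=files)
    resp.raise_for_status()
    print(resp.text)  # create_ocr responds with media_type="text/plain"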
xinference/client/restful/restful_client.py CHANGED
@@ -369,6 +369,25 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def ocr(self, image: Union[str, bytes], **kwargs):
+        url = f"{self._base_url}/v1/images/ocr"
+        params = {
+            "model": self._model_uid,
+            "kwargs": json.dumps(kwargs),
+        }
+        files: List[Any] = []
+        for key, value in params.items():
+            files.append((key, (None, value)))
+        files.append(("image", ("image", image, "application/octet-stream")))
+        response = requests.post(url, files=files, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to ocr the images, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
+
 
 class RESTfulVideoModelHandle(RESTfulModelHandle):
     def text_to_video(
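Note: `ocr` mirrors the other `RESTfulImageModelHandle` helpers, sending the image as multipart data and the remaining kwargs as a JSON form field. A hedged usage sketch; the server address, model UID, and file name are assumptions:

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")  # assumes a running local server
    model = client.get_model("my-got-ocr2")   # handle of a launched OCR-capable image model
    with open("receipt.png", "rb") as f:
        text = model.ocr(f.read())            # extra kwargs are JSON-encoded and forwarded
    print(text)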
xinference/core/chat_interface.py CHANGED
@@ -74,7 +74,11 @@ class GradioInterface:
         # Gradio initiates the queue during a startup event, but since the app has already been
         # started, that event will not run, so manually invoke the startup events.
         # See: https://github.com/gradio-app/gradio/issues/5228
-        interface.startup_events()
+        try:
+            interface.run_startup_events()
+        except AttributeError:
+            # compatibility
+            interface.startup_events()
         favicon_path = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
             os.path.pardir,
xinference/core/image_interface.py CHANGED
@@ -63,7 +63,11 @@ class ImageInterface:
         # Gradio initiates the queue during a startup event, but since the app has already been
         # started, that event will not run, so manually invoke the startup events.
         # See: https://github.com/gradio-app/gradio/issues/5228
-        interface.startup_events()
+        try:
+            interface.run_startup_events()
+        except AttributeError:
+            # compatibility
+            interface.startup_events()
         favicon_path = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
             os.path.pardir,
xinference/core/model.py CHANGED
@@ -17,10 +17,10 @@ import functools
 import inspect
 import json
 import os
+import queue
 import time
 import types
 import uuid
-import weakref
 from asyncio.queues import Queue
 from asyncio.tasks import wait_for
 from concurrent.futures import Future as ConcurrentFuture
@@ -32,7 +32,6 @@ from typing import (
     Callable,
     Dict,
     Generator,
-    Iterator,
     List,
     Optional,
     Union,
@@ -209,9 +208,8 @@ class ModelActor(xo.StatelessActor):
             model_description.to_dict() if model_description else {}
         )
         self._request_limits = request_limits
-
-        self._generators: Dict[str, Union[Iterator, AsyncGenerator]] = {}
-        self._current_generator = lambda: None
+        self._pending_requests: asyncio.Queue = asyncio.Queue()
+        self._handle_pending_requests_task = None
         self._lock = (
             None
             if isinstance(
@@ -237,6 +235,10 @@
     async def __post_create__(self):
         self._loop = asyncio.get_running_loop()
 
+        self._handle_pending_requests_task = asyncio.create_task(
+            self._handle_pending_requests()
+        )
+
         if self.allow_batching():
             from .scheduler import SchedulerActor
 
@@ -474,6 +476,43 @@
             )
             await asyncio.gather(*coros)
 
+    async def _handle_pending_requests(self):
+        logger.info("Start requests handler.")
+        while True:
+            gen, stream_out, stop = await self._pending_requests.get()
+
+            async def _async_wrapper(_gen):
+                try:
+                    # anext is only available for Python >= 3.10
+                    return await _gen.__anext__()  # noqa: F821
+                except StopAsyncIteration:
+                    return stop
+
+            def _wrapper(_gen):
+                # Avoid issue: https://github.com/python/cpython/issues/112182
+                try:
+                    return next(_gen)
+                except StopIteration:
+                    return stop
+
+            while True:
+                try:
+                    if inspect.isgenerator(gen):
+                        r = await asyncio.to_thread(_wrapper, gen)
+                    elif inspect.isasyncgen(gen):
+                        r = await _async_wrapper(gen)
+                    else:
+                        raise Exception(
+                            f"The generator {gen} should be a generator or an async generator, "
+                            f"but a {type(gen)} is got."
+                        )
+                    stream_out.put_nowait(r)
+                    if r is not stop:
+                        continue
+                except Exception:
+                    logger.exception("stream encountered an error.")
+                break
+
     async def _call_wrapper_json(self, fn: Callable, *args, **kwargs):
         return await self._call_wrapper("json", fn, *args, **kwargs)
 
@@ -487,6 +526,13 @@
                 ret = await fn(*args, **kwargs)
             else:
                 ret = await asyncio.to_thread(fn, *args, **kwargs)
+
+            if inspect.isgenerator(ret):
+                gen = self._to_generator(output_type, ret)
+                return gen
+            if inspect.isasyncgen(ret):
+                gen = self._to_async_gen(output_type, ret)
+                return gen
         else:
             async with self._lock:
                 if inspect.iscoroutinefunction(fn):
@@ -494,17 +540,40 @@
                 else:
                     ret = await asyncio.to_thread(fn, *args, **kwargs)
 
-        if self._lock is not None and self._current_generator():
-            raise Exception("Parallel generation is not supported by llama-cpp-python.")
+            stream_out: Union[queue.Queue, asyncio.Queue]
+
+            if inspect.isgenerator(ret):
+                gen = self._to_generator(output_type, ret)
+                stream_out = queue.Queue()
+                stop = object()
+                self._pending_requests.put_nowait((gen, stream_out, stop))
+
+                def _stream_out_generator():
+                    while True:
+                        o = stream_out.get()
+                        if o is stop:
+                            break
+                        else:
+                            yield o
+
+                return _stream_out_generator()
+
+            if inspect.isasyncgen(ret):
+                gen = self._to_async_gen(output_type, ret)
+                stream_out = asyncio.Queue()
+                stop = object()
+                self._pending_requests.put_nowait((gen, stream_out, stop))
+
+                async def _stream_out_async_gen():
+                    while True:
+                        o = await stream_out.get()
+                        if o is stop:
+                            break
+                        else:
+                            yield o
+
+                return _stream_out_async_gen()
 
-        if inspect.isgenerator(ret):
-            gen = self._to_generator(output_type, ret)
-            self._current_generator = weakref.ref(gen)
-            return gen
-        if inspect.isasyncgen(ret):
-            gen = self._to_async_gen(output_type, ret)
-            self._current_generator = weakref.ref(gen)
-            return gen
         if output_type == "json":
             return await asyncio.to_thread(json_dumps, ret)
         else:
@@ -592,7 +661,6 @@
                 prompt_or_messages, queue, call_ability, *args, **kwargs
             )
             gen = self._to_async_gen("json", ret)
-            self._current_generator = weakref.ref(gen)
             return gen
         else:
             from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
@@ -953,6 +1021,25 @@
                 f"Model {self._model.model_spec} is not for creating image."
             )
 
+    @log_async(
+        logger=logger,
+        ignore_kwargs=["image"],
+    )
+    async def ocr(
+        self,
+        image: "PIL.Image",
+        *args,
+        **kwargs,
+    ):
+        if hasattr(self._model, "ocr"):
+            return await self._call_wrapper_json(
+                self._model.ocr,
+                image,
+                *args,
+                **kwargs,
+            )
+        raise AttributeError(f"Model {self._model.model_spec} is not for ocr.")
+
     @request_limit
     @log_async(logger=logger, ignore_kwargs=["image"])
     async def infer(
@@ -994,3 +1081,6 @@
     async def record_metrics(self, name, op, kwargs):
         worker_ref = await self._get_worker_ref()
         await worker_ref.record_metrics(name, op, kwargs)
+
+    async def get_pending_requests_count(self):
+        return self._pending_requests.qsize()
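Note: this replaces the old single-`_current_generator` bookkeeping with a producer/consumer scheme: every streaming call enqueues a `(generator, output queue, stop sentinel)` triple, and one background task drains the generators and forwards items, so single-threaded backends no longer have to reject parallel generation outright. A standalone sketch of the sentinel pattern, with illustrative names rather than the actual `ModelActor` code:

    import asyncio

    async def handler(pending: asyncio.Queue) -> None:
        # Serialize streams: drain one generator at a time into its output queue.
        while True:
            gen, out, stop = await pending.get()
            async for item in gen:
                out.put_nowait(item)
            out.put_nowait(stop)  # tell the consumer this stream is finished

    async def main() -> None:
        pending: asyncio.Queue = asyncio.Queue()
        handler_task = asyncio.create_task(handler(pending))

        async def numbers():
            for i in range(3):
                yield i

        out: asyncio.Queue = asyncio.Queue()
        stop = object()  # unique sentinel object; cannot collide with real items
        pending.put_nowait((numbers(), out, stop))

        while (item := await out.get()) is not stop:
            print(item)  # prints 0, 1, 2

    asyncio.run(main())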
xinference/core/scheduler.py CHANGED
@@ -79,7 +79,7 @@ class InferenceRequest:
         # For tool call
         self.tools = None
         # Currently, for storing tool call streaming results.
-        self.outputs: List[str] = []
+        self.outputs: List[str] = []  # type: ignore
         # inference results,
         # it is a list type because when stream=True,
         # self.completion contains all the results in a decode round.
xinference/deploy/supervisor.py CHANGED
@@ -31,10 +31,6 @@ from .utils import health_check
 
 logger = logging.getLogger(__name__)
 
-from ..model import _install as install_model
-
-install_model()
-
 
 async def _start_supervisor(address: str, logging_conf: Optional[Dict] = None):
     logging.config.dictConfig(logging_conf)  # type: ignore
xinference/model/audio/chattts.py CHANGED
@@ -54,7 +54,11 @@ class ChatTTSModel:
         torch.set_float32_matmul_precision("high")
         self._model = ChatTTS.Chat()
         logger.info("Load ChatTTS model with kwargs: %s", self._kwargs)
-        self._model.load(source="custom", custom_path=self._model_path, **self._kwargs)
+        ok = self._model.load(
+            source="custom", custom_path=self._model_path, **self._kwargs
+        )
+        if not ok:
+            raise Exception(f"The ChatTTS model is not correct: {self._model_path}")
 
     def speech(
         self,
@@ -114,16 +118,15 @@ class ChatTTSModel:
                     last_pos = 0
                     with writer.open():
                         for it in iter:
-                            for itt in it:
-                                for chunk in itt:
-                                    chunk = np.array([chunk]).transpose()
-                                    writer.write_audio_chunk(i, torch.from_numpy(chunk))
-                                    new_last_pos = out.tell()
-                                    if new_last_pos != last_pos:
-                                        out.seek(last_pos)
-                                        encoded_bytes = out.read()
-                                        yield encoded_bytes
-                                        last_pos = new_last_pos
+                            for chunk in it:
+                                chunk = np.array([chunk]).transpose()
+                                writer.write_audio_chunk(i, torch.from_numpy(chunk))
+                                new_last_pos = out.tell()
+                                if new_last_pos != last_pos:
+                                    out.seek(last_pos)
+                                    encoded_bytes = out.read()
+                                    yield encoded_bytes
+                                    last_pos = new_last_pos
 
             return _generator()
         else:
@@ -131,7 +134,15 @@
 
         # Save the generated audio
         with BytesIO() as out:
-            torchaudio.save(
-                out, torch.from_numpy(wavs[0]), 24000, format=response_format
-            )
+            try:
+                torchaudio.save(
+                    out,
+                    torch.from_numpy(wavs[0]).unsqueeze(0),
+                    24000,
+                    format=response_format,
+                )
+            except:
+                torchaudio.save(
+                    out, torch.from_numpy(wavs[0]), 24000, format=response_format
+                )
             return out.getvalue()
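Note: `torchaudio.save` expects a 2-D `(channels, frames)` tensor, while newer ChatTTS builds return a 1-D waveform, hence the `unsqueeze(0)` with a bare fallback for older output shapes. A shape-only illustration with dummy data:

    import torch

    wav = torch.rand(24000)    # 1-D mono waveform: (frames,)
    wav2d = wav.unsqueeze(0)   # (channels, frames) == (1, 24000), as torchaudio.save expects
    assert wav2d.shape == (1, 24000)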
xinference/model/audio/model_spec.json CHANGED
@@ -127,7 +127,7 @@
     "model_name": "ChatTTS",
     "model_family": "ChatTTS",
     "model_id": "2Noise/ChatTTS",
-    "model_revision": "ce5913842aebd78e4a01a02d47244b8d62ac4ee3",
+    "model_revision": "3b34118f6d25850440b8901cef3e71c6ef8619c8",
     "model_ability": "text-to-audio",
     "multilingual": true
   },
xinference/model/audio/model_spec_modelscope.json CHANGED
@@ -42,7 +42,7 @@
     "model_name": "ChatTTS",
     "model_family": "ChatTTS",
     "model_hub": "modelscope",
-    "model_id": "pzc163/chatTTS",
+    "model_id": "AI-ModelScope/ChatTTS",
     "model_revision": "master",
     "model_ability": "text-to-audio",
     "multilingual": true
xinference/model/embedding/model_spec.json CHANGED
@@ -233,7 +233,7 @@
   },
   {
     "model_name": "gte-Qwen2",
-    "dimensions": 3584,
+    "dimensions": 4096,
     "max_tokens": 32000,
     "language": ["zh", "en"],
     "model_id": "Alibaba-NLP/gte-Qwen2-7B-instruct",
xinference/model/image/core.py CHANGED
@@ -11,17 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import collections.abc
 import logging
 import os
+import platform
 from collections import defaultdict
-from typing import Dict, List, Literal, Optional, Tuple
+from typing import Dict, List, Literal, Optional, Tuple, Union
 
 from ...constants import XINFERENCE_CACHE_DIR
 from ...types import PeftModelConfig
 from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
+from .ocr.got_ocr2 import GotOCR2Model
 from .stable_diffusion.core import DiffusionModel
+from .stable_diffusion.mlx import MLXDiffusionModel
 
 logger = logging.getLogger(__name__)
 
@@ -45,6 +49,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
     model_hub: str = "huggingface"
     model_ability: Optional[List[str]]
     controlnet: Optional[List["ImageModelFamilyV1"]]
+    default_model_config: Optional[dict] = {}
     default_generate_config: Optional[dict] = {}
 
 
@@ -180,6 +185,28 @@ def get_cache_status(
     return valid_model_revision(meta_path, model_spec.model_revision)
 
 
+def create_ocr_model_instance(
+    subpool_addr: str,
+    devices: List[str],
+    model_uid: str,
+    model_spec: ImageModelFamilyV1,
+    model_path: Optional[str] = None,
+    **kwargs,
+) -> Tuple[GotOCR2Model, ImageModelDescription]:
+    if not model_path:
+        model_path = cache(model_spec)
+    model = GotOCR2Model(
+        model_uid,
+        model_path,
+        model_spec=model_spec,
+        **kwargs,
+    )
+    model_description = ImageModelDescription(
+        subpool_addr, devices, model_spec, model_path=model_path
+    )
+    return model, model_description
+
+
 def create_image_model_instance(
     subpool_addr: str,
     devices: List[str],
@@ -189,8 +216,26 @@ def create_image_model_instance(
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
     model_path: Optional[str] = None,
     **kwargs,
-) -> Tuple[DiffusionModel, ImageModelDescription]:
+) -> Tuple[
+    Union[DiffusionModel, MLXDiffusionModel, GotOCR2Model], ImageModelDescription
+]:
     model_spec = match_diffusion(model_name, download_hub)
+    if model_spec.model_ability and "ocr" in model_spec.model_ability:
+        return create_ocr_model_instance(
+            subpool_addr=subpool_addr,
+            devices=devices,
+            model_uid=model_uid,
+            model_name=model_name,
+            model_spec=model_spec,
+            model_path=model_path,
+            **kwargs,
+        )
+
+    # use default model config
+    model_default_config = (model_spec.default_model_config or {}).copy()
+    model_default_config.update(kwargs)
+    kwargs = model_default_config
+
     controlnet = kwargs.get("controlnet")
     # Handle controlnet
     if controlnet is not None:
@@ -232,10 +277,20 @@ def create_image_model_instance(
         lora_load_kwargs = None
         lora_fuse_kwargs = None
 
-    model = DiffusionModel(
+    if (
+        platform.system() == "Darwin"
+        and "arm" in platform.machine().lower()
+        and model_name in MLXDiffusionModel.supported_models
+    ):
+        # Mac with M series silicon chips
+        model_cls = MLXDiffusionModel
+    else:
+        model_cls = DiffusionModel  # type: ignore
+
+    model = model_cls(
         model_uid,
         model_path,
-        lora_model_paths=lora_model,
+        lora_model=lora_model,
         lora_load_kwargs=lora_load_kwargs,
        lora_fuse_kwargs=lora_fuse_kwargs,
        model_spec=model_spec,
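Note: the merge order gives user arguments precedence: the spec's `default_model_config` is copied first and then overwritten by launch-time kwargs. A two-line illustration of that precedence (values hypothetical):

    defaults = {"quantize": True, "quantize_text_encoder": "text_encoder_2"}
    kwargs = {"quantize": False}   # user-supplied at launch time

    merged = defaults.copy()
    merged.update(kwargs)          # launch kwargs win over spec defaults
    assert merged["quantize"] is False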
xinference/model/image/model_spec.json CHANGED
@@ -8,7 +8,11 @@
       "text2image",
       "image2image",
       "inpainting"
-    ]
+    ],
+    "default_model_config": {
+      "quantize": true,
+      "quantize_text_encoder": "text_encoder_2"
+    }
   },
   {
     "model_name": "FLUX.1-dev",
@@ -19,7 +23,11 @@
       "text2image",
       "image2image",
       "inpainting"
-    ]
+    ],
+    "default_model_config": {
+      "quantize": true,
+      "quantize_text_encoder": "text_encoder_2"
+    }
   },
   {
     "model_name": "sd3-medium",
@@ -30,7 +38,11 @@
       "text2image",
       "image2image",
       "inpainting"
-    ]
+    ],
+    "default_model_config": {
+      "quantize": true,
+      "quantize_text_encoder": "text_encoder_3"
+    }
   },
   {
     "model_name": "sd-turbo",
@@ -178,5 +190,14 @@
     "model_ability": [
       "inpainting"
     ]
+  },
+  {
+    "model_name": "GOT-OCR2_0",
+    "model_family": "ocr",
+    "model_id": "stepfun-ai/GOT-OCR2_0",
+    "model_revision": "cf6b7386bc89a54f09785612ba74cb12de6fa17c",
+    "model_ability": [
+      "ocr"
+    ]
   }
 ]
xinference/model/image/model_spec_modelscope.json CHANGED
@@ -9,7 +9,11 @@
       "text2image",
       "image2image",
       "inpainting"
-    ]
+    ],
+    "default_model_config": {
+      "quantize": true,
+      "quantize_text_encoder": "text_encoder_2"
+    }
   },
   {
     "model_name": "FLUX.1-dev",
@@ -21,7 +25,11 @@
       "text2image",
       "image2image",
       "inpainting"
-    ]
+    ],
+    "default_model_config": {
+      "quantize": true,
+      "quantize_text_encoder": "text_encoder_2"
+    }
   },
   {
     "model_name": "sd3-medium",
@@ -33,7 +41,11 @@
       "text2image",
       "image2image",
       "inpainting"
-    ]
+    ],
+    "default_model_config": {
+      "quantize": true,
+      "quantize_text_encoder": "text_encoder_3"
+    }
   },
   {
     "model_name": "sd-turbo",
@@ -148,5 +160,15 @@
         "model_revision": "62134b9d8e703b5d6f74f1534457287a8bba77ef"
       }
     ]
+  },
+  {
+    "model_name": "GOT-OCR2_0",
+    "model_family": "ocr",
+    "model_id": "stepfun-ai/GOT-OCR2_0",
+    "model_revision": "master",
+    "model_hub": "modelscope",
+    "model_ability": [
+      "ocr"
+    ]
   }
 ]
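Note: with these spec entries in place, GOT-OCR2_0 should be launchable like any other image model. A hedged end-to-end sketch; the address and file name are placeholders:

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    uid = client.launch_model(model_name="GOT-OCR2_0", model_type="image")
    model = client.get_model(uid)
    with open("scan.png", "rb") as f:
        print(model.ocr(f.read()))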
xinference/model/image/ocr/__init__.py ADDED
@@ -0,0 +1,13 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.