xinference 0.13.2__py3-none-any.whl → 0.13.3__py3-none-any.whl

This diff shows the changes between publicly available package versions as released to their respective public registries, and is provided for informational purposes only.


Files changed (78)
  1. xinference/__init__.py +0 -1
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +26 -4
  4. xinference/client/restful/restful_client.py +16 -1
  5. xinference/core/chat_interface.py +2 -2
  6. xinference/core/model.py +8 -3
  7. xinference/core/scheduler.py +4 -4
  8. xinference/model/audio/core.py +5 -2
  9. xinference/model/audio/cosyvoice.py +136 -0
  10. xinference/model/audio/model_spec.json +24 -0
  11. xinference/model/audio/model_spec_modelscope.json +27 -0
  12. xinference/model/flexible/launchers/__init__.py +1 -0
  13. xinference/model/flexible/launchers/image_process_launcher.py +70 -0
  14. xinference/model/image/model_spec.json +7 -0
  15. xinference/model/image/stable_diffusion/core.py +6 -1
  16. xinference/model/llm/llm_family.json +802 -82
  17. xinference/model/llm/llm_family_csghub.json +39 -0
  18. xinference/model/llm/llm_family_modelscope.json +295 -47
  19. xinference/model/llm/pytorch/chatglm.py +243 -5
  20. xinference/model/llm/pytorch/cogvlm2.py +1 -1
  21. xinference/model/llm/utils.py +78 -1
  22. xinference/model/llm/vllm/core.py +8 -0
  23. xinference/thirdparty/cosyvoice/__init__.py +0 -0
  24. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  25. xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
  26. xinference/thirdparty/cosyvoice/bin/train.py +136 -0
  27. xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
  28. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
  29. xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
  30. xinference/thirdparty/cosyvoice/cli/model.py +60 -0
  31. xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
  32. xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
  33. xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
  34. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  35. xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
  36. xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
  37. xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
  38. xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
  39. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  40. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
  41. xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
  42. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  43. xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
  44. xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
  45. xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
  46. xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
  47. xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
  48. xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
  49. xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
  50. xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
  51. xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
  52. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
  53. xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
  54. xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  55. xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
  56. xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
  57. xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
  58. xinference/thirdparty/cosyvoice/utils/common.py +103 -0
  59. xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
  60. xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
  61. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
  62. xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
  63. xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
  64. xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
  65. xinference/web/ui/build/asset-manifest.json +3 -3
  66. xinference/web/ui/build/index.html +1 -1
  67. xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
  68. xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
  69. xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
  70. {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/METADATA +16 -8
  71. {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/RECORD +76 -32
  72. xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
  73. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
  74. /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
  75. {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
  76. {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
  77. {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
  78. {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
xinference/__init__.py CHANGED
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from . import _version
 
 __version__ = _version.get_versions()["version"]
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-07-19T19:15:54+0800",
+ "date": "2024-07-26T18:42:50+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "880929cbbc73e5206ca069591b03d9d16dd858bf",
- "version": "0.13.2"
+ "full-revisionid": "aa51ff22dbfb5644554436270deaf57a7ebaf066",
+ "version": "0.13.3"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -130,6 +130,7 @@ class SpeechRequest(BaseModel):
     response_format: Optional[str] = "mp3"
     speed: Optional[float] = 1.0
     stream: Optional[bool] = False
+    kwargs: Optional[str] = None
 
 
 class RegisterModelRequest(BaseModel):
@@ -1309,8 +1310,18 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
-    async def create_speech(self, request: Request) -> Response:
-        body = SpeechRequest.parse_obj(await request.json())
+    async def create_speech(
+        self,
+        request: Request,
+        prompt_speech: Optional[UploadFile] = File(
+            None, media_type="application/octet-stream"
+        ),
+    ) -> Response:
+        if prompt_speech:
+            f = await request.form()
+        else:
+            f = await request.json()
+        body = SpeechRequest.parse_obj(f)
         model_uid = body.model
         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -1324,12 +1335,19 @@
             raise HTTPException(status_code=500, detail=str(e))
 
         try:
+            if body.kwargs is not None:
+                parsed_kwargs = json.loads(body.kwargs)
+            else:
+                parsed_kwargs = {}
+            if prompt_speech is not None:
+                parsed_kwargs["prompt_speech"] = await prompt_speech.read()
             out = await model.speech(
                 input=body.input,
                 voice=body.voice,
                 response_format=body.response_format,
                 speed=body.speed,
                 stream=body.stream,
+                **parsed_kwargs,
             )
             if body.stream:
                 return EventSourceResponse(
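With this change, the speech endpoint accepts either the original JSON body or a multipart form: when a prompt_speech file part is attached, the remaining fields are read from form data, and extra model-specific options travel as a JSON-encoded kwargs field. A minimal sketch of calling the new endpoint directly with requests (the /v1/audio/speech path, host/port, model uid, and file names are assumptions for illustration):

    import json
    import requests

    # Multipart variant: prompt_speech as a file part, other fields as form data.
    with open("reference_voice.wav", "rb") as f:
        resp = requests.post(
            "http://127.0.0.1:9997/v1/audio/speech",
            data={
                "model": "my-cosyvoice",  # placeholder model uid
                "input": "Hello from xinference!",
                "kwargs": json.dumps({"prompt_text": "transcript of the reference clip"}),
            },
            files={"prompt_speech": ("prompt_speech", f, "application/octet-stream")},
        )
    resp.raise_for_status()
    with open("speech.mp3", "wb") as out:
        out.write(resp.content)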
@@ -1626,10 +1644,14 @@
         if body.tools and body.stream:
             is_vllm = await model.is_vllm_backend()
 
-            if not is_vllm or model_family not in QWEN_TOOL_CALL_FAMILY:
+            if not (
+                (is_vllm and model_family in QWEN_TOOL_CALL_FAMILY)
+                or (not is_vllm and model_family in GLM4_TOOL_CALL_FAMILY)
+            ):
                 raise HTTPException(
                     status_code=400,
-                    detail="Streaming support for tool calls is available only when using vLLM backend and Qwen models.",
+                    detail="Streaming support for tool calls is available only when using "
+                    "Qwen models with vLLM backend or GLM4-chat models without vLLM backend.",
                 )
 
         if body.stream:
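The relaxed condition means streaming tool-call requests are now accepted for GLM4-chat models on the default transformers backend, in addition to Qwen models on vLLM. A hedged sketch of what this enables from the Python client (model uid and tool schema are illustrative, and this assumes the client's chat handle accepts a tools argument as in recent releases):

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")
    model = client.get_model("my-glm4-chat")  # placeholder uid, transformers backend

    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]

    # Before 0.13.3 this combination was rejected with HTTP 400 unless the
    # model was a Qwen-family model running on the vLLM backend.
    for chunk in model.chat(
        "What's the weather in Paris?",
        tools=tools,
        generate_config={"stream": True},
    ):
        print(chunk)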
xinference/client/restful/restful_client.py CHANGED
@@ -768,6 +768,8 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
         response_format: str = "mp3",
         speed: float = 1.0,
         stream: bool = False,
+        prompt_speech: Optional[bytes] = None,
+        **kwargs,
     ):
         """
         Generates audio from the input text.
@@ -799,8 +801,21 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
             "response_format": response_format,
             "speed": speed,
             "stream": stream,
+            "kwargs": json.dumps(kwargs),
         }
-        response = requests.post(url, json=params, headers=self.auth_headers)
+        if prompt_speech:
+            files: List[Any] = []
+            files.append(
+                (
+                    "prompt_speech",
+                    ("prompt_speech", prompt_speech, "application/octet-stream"),
+                )
+            )
+            response = requests.post(
+                url, data=params, files=files, headers=self.auth_headers
+            )
+        else:
+            response = requests.post(url, json=params, headers=self.auth_headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to speech the text, detail: {_get_error_string(response)}"
xinference/core/chat_interface.py CHANGED
@@ -428,7 +428,7 @@ class GradioInterface:
             }
 
             hist.append(response_content)
-            return {
+            return {  # type: ignore
                 textbox: response_content,
                 history: hist,
             }
@@ -467,7 +467,7 @@
             }
 
             hist.append(response_content)
-            return {
+            return {  # type: ignore
                 textbox: response_content,
                 history: hist,
             }
xinference/core/model.py CHANGED
@@ -646,7 +646,10 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating translations."
         )
 
-    @log_async(logger=logger)
+    @log_async(
+        logger=logger,
+        args_formatter=lambda _, kwargs: kwargs.pop("prompt_speech", None),
+    )
     @request_limit
     @xo.generator
     async def speech(
@@ -656,6 +659,7 @@
         response_format: str = "mp3",
         speed: float = 1.0,
         stream: bool = False,
+        **kwargs,
     ):
         if hasattr(self._model, "speech"):
             return await self._call_wrapper_binary(
@@ -665,6 +669,7 @@
                 response_format,
                 speed,
                 stream,
+                **kwargs,
             )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating speech."
@@ -735,7 +740,7 @@
         **kwargs,
     ):
         if hasattr(self._model, "inpainting"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.inpainting,
                 image,
                 mask_image,
@@ -758,7 +763,7 @@
         **kwargs,
     ):
         if hasattr(self._model, "infer"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.infer,
                 **kwargs,
             )
xinference/core/scheduler.py CHANGED
@@ -81,7 +81,7 @@ class InferenceRequest:
         self.future_or_queue = future_or_queue
         # Record error message when this request has error.
         # Must set stopped=True when this field is set.
-        self.error_msg: Optional[str] = None
+        self.error_msg: Optional[str] = None  # type: ignore
         # For compatibility. Record some extra parameters for some special cases.
         self.extra_kwargs = {}
 
@@ -295,11 +295,11 @@ class SchedulerActor(xo.StatelessActor):
 
     def __init__(self):
         super().__init__()
-        self._waiting_queue: deque[InferenceRequest] = deque()
-        self._running_queue: deque[InferenceRequest] = deque()
+        self._waiting_queue: deque[InferenceRequest] = deque()  # type: ignore
+        self._running_queue: deque[InferenceRequest] = deque()  # type: ignore
         self._model = None
         self._id_to_req = {}
-        self._abort_req_ids: Set[str] = set()
+        self._abort_req_ids: Set[str] = set()  # type: ignore
         self._isolation = None
 
     async def __post_create__(self):
xinference/model/audio/core.py CHANGED
@@ -20,6 +20,7 @@ from ...constants import XINFERENCE_CACHE_DIR
 from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
+from .cosyvoice import CosyVoiceModel
 from .whisper import WhisperModel
 
 MAX_ATTEMPTS = 3
@@ -150,14 +151,16 @@
     model_name: str,
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
     **kwargs,
-) -> Tuple[Union[WhisperModel, ChatTTSModel], AudioModelDescription]:
+) -> Tuple[Union[WhisperModel, ChatTTSModel, CosyVoiceModel], AudioModelDescription]:
     model_spec = match_audio(model_name, download_hub)
     model_path = cache(model_spec)
-    model: Union[WhisperModel, ChatTTSModel]
+    model: Union[WhisperModel, ChatTTSModel, CosyVoiceModel]
     if model_spec.model_family == "whisper":
         model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "ChatTTS":
         model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "CosyVoice":
+        model = CosyVoiceModel(model_uid, model_path, model_spec, **kwargs)
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
xinference/model/audio/cosyvoice.py ADDED
@@ -0,0 +1,136 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import io
+import logging
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class CosyVoiceModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._kwargs = kwargs
+
+    def load(self):
+        import os
+        import sys
+
+        # The yaml config loaded from model has hard-coded the import paths. please refer to: load_hyperpyyaml
+        sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../thirdparty"))
+
+        from cosyvoice.cli.cosyvoice import CosyVoice
+
+        self._model = CosyVoice(self._model_path)
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        if stream:
+            raise Exception("CosyVoiceModel does not support stream.")
+
+        import torchaudio
+        from cosyvoice.utils.file_utils import load_wav
+
+        prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
+        prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
+        instruct_text: Optional[str] = kwargs.pop("instruct_text", None)
+
+        if "SFT" in self._model_spec.model_name:
+            # inference_sft
+            assert (
+                prompt_speech is None
+            ), "CosyVoice SFT model does not support prompt_speech"
+            assert (
+                prompt_text is None
+            ), "CosyVoice SFT model does not support prompt_text"
+            assert (
+                instruct_text is None
+            ), "CosyVoice SFT model does not support instruct_text"
+        elif "Instruct" in self._model_spec.model_name:
+            # inference_instruct
+            assert (
+                prompt_speech is None
+            ), "CosyVoice Instruct model does not support prompt_speech"
+            assert (
+                prompt_text is None
+            ), "CosyVoice Instruct model does not support prompt_text"
+            assert (
+                instruct_text is not None
+            ), "CosyVoice Instruct model expect a instruct_text"
+        else:
+            # inference_zero_shot
+            # inference_cross_lingual
+            assert prompt_speech is not None, "CosyVoice model expect a prompt_speech"
+            assert (
+                instruct_text is None
+            ), "CosyVoice model does not support instruct_text"
+
+        assert self._model is not None
+        if prompt_speech:
+            assert not voice, "voice can't be set with prompt speech."
+            with io.BytesIO(prompt_speech) as prompt_speech_io:
+                prompt_speech_16k = load_wav(prompt_speech_io, 16000)
+                if prompt_text:
+                    logger.info("CosyVoice inference_zero_shot")
+                    output = self._model.inference_zero_shot(
+                        input, prompt_text, prompt_speech_16k
+                    )
+                else:
+                    logger.info("CosyVoice inference_cross_lingual")
+                    output = self._model.inference_cross_lingual(
+                        input, prompt_speech_16k
+                    )
+        else:
+            available_speakers = self._model.list_avaliable_spks()
+            if not voice:
+                voice = available_speakers[0]
+            else:
+                assert (
+                    voice in available_speakers
+                ), f"Invalid voice {voice}, CosyVoice available speakers: {available_speakers}"
+            if instruct_text:
+                logger.info("CosyVoice inference_instruct")
+                output = self._model.inference_instruct(
+                    input, voice, instruct_text=instruct_text
+                )
+            else:
+                logger.info("CosyVoice inference_sft")
+                output = self._model.inference_sft(input, voice)
+
+        # Save the generated audio
+        with BytesIO() as out:
+            torchaudio.save(out, output["tts_speech"], 22050, format=response_format)
+            return out.getvalue()
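The dispatch above gives each CosyVoice variant a distinct calling convention: SFT models do plain TTS with a built-in speaker, Instruct models additionally take an instruct_text style prompt, and the base 300M model clones a voice from prompt_speech (zero-shot when prompt_text is supplied, cross-lingual otherwise). A hedged end-to-end sketch via the client (the speaker name is illustrative; when voice is empty the model falls back to its first available speaker):

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")

    # SFT variant: plain text-to-speech with a built-in speaker.
    uid = client.launch_model(model_name="CosyVoice-300M-SFT", model_type="audio")
    sft = client.get_model(uid)
    audio = sft.speech("你好，世界", voice="中文女")

    # Instruct variant: same, plus a natural-language style instruction.
    uid = client.launch_model(model_name="CosyVoice-300M-Instruct", model_type="audio")
    instruct = client.get_model(uid)
    audio = instruct.speech(
        "Hello there!",
        voice="中文女",
        instruct_text="Speak slowly, in a warm and gentle tone.",
    )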
xinference/model/audio/model_spec.json CHANGED
@@ -102,5 +102,29 @@
     "model_revision": "ce5913842aebd78e4a01a02d47244b8d62ac4ee3",
     "ability": "text-to-audio",
     "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M",
+    "model_family": "CosyVoice",
+    "model_id": "model-scope/CosyVoice-300M",
+    "model_revision": "ca4e036d2db2aa4731cc1747859a68044b6a4694",
+    "ability": "audio-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-SFT",
+    "model_family": "CosyVoice",
+    "model_id": "model-scope/CosyVoice-300M-SFT",
+    "model_revision": "ab918940c6c134b1fc1f069246e67bad6b66abcb",
+    "ability": "text-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-Instruct",
+    "model_family": "CosyVoice",
+    "model_id": "model-scope/CosyVoice-300M-Instruct",
+    "model_revision": "fb5f676733139f35670bed9b59a77d476b1aa898",
+    "ability": "text-to-audio",
+    "multilingual": true
   }
 ]
xinference/model/audio/model_spec_modelscope.json CHANGED
@@ -16,5 +16,32 @@
     "model_revision": "master",
     "ability": "text-to-audio",
     "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M",
+    "model_family": "CosyVoice",
+    "model_hub": "modelscope",
+    "model_id": "iic/CosyVoice-300M",
+    "model_revision": "master",
+    "ability": "audio-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-SFT",
+    "model_family": "CosyVoice",
+    "model_hub": "modelscope",
+    "model_id": "iic/CosyVoice-300M-SFT",
+    "model_revision": "master",
+    "ability": "text-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-Instruct",
+    "model_family": "CosyVoice",
+    "model_hub": "modelscope",
+    "model_id": "iic/CosyVoice-300M-Instruct",
+    "model_revision": "master",
+    "ability": "text-to-audio",
+    "multilingual": true
   }
 ]
xinference/model/flexible/launchers/__init__.py CHANGED
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .image_process_launcher import launcher as image_process
 from .transformers_launcher import launcher as transformers
xinference/model/flexible/launchers/image_process_launcher.py ADDED
@@ -0,0 +1,70 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+from io import BytesIO
+
+import PIL.Image
+import PIL.ImageOps
+
+from ....types import Image
+from ..core import FlexibleModel, FlexibleModelSpec
+
+
+class ImageRemoveBackgroundModel(FlexibleModel):
+    def infer(self, **kwargs):
+        invert = kwargs.get("invert", False)
+        b64_image: str = kwargs.get("image")  # type: ignore
+        only_mask = kwargs.pop("only_mask", True)
+        image_format = kwargs.pop("image_format", "PNG")
+        if not b64_image:
+            raise ValueError("No image found to remove background")
+        image = base64.b64decode(b64_image)
+
+        try:
+            from rembg import remove
+        except ImportError:
+            error_message = "Failed to import module 'rembg'"
+            installation_guide = [
+                "Please make sure 'rembg' is installed. ",
+                "You can install it by visiting the installation section of the git repo:\n",
+                "https://github.com/danielgatis/rembg?tab=readme-ov-file#installation",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        im = PIL.Image.open(BytesIO(image))
+        om = remove(im, only_mask=only_mask, **kwargs)
+        if invert:
+            om = PIL.ImageOps.invert(om)
+
+        buffered = BytesIO()
+        om.save(buffered, format=image_format)
+        img_str = base64.b64encode(buffered.getvalue()).decode()
+        return Image(url=None, b64_json=img_str)
+
+
+def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> FlexibleModel:
+    task = kwargs.get("task")
+    device = kwargs.get("device")
+
+    if task == "remove_background":
+        return ImageRemoveBackgroundModel(
+            model_uid=model_uid,
+            model_path=model_spec.model_uri,  # type: ignore
+            device=device,
+            config=kwargs,
+        )
+    else:
+        raise ValueError(f"Unknown Task for image processing: {task}")
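The launcher maps the remove_background task onto rembg. A hedged sketch of driving the new model class directly (the constructor arguments mirror what the launcher itself passes; the model path is a placeholder, since rembg fetches its own weights):

    import base64

    from xinference.model.flexible.launchers.image_process_launcher import (
        ImageRemoveBackgroundModel,
    )

    model = ImageRemoveBackgroundModel(
        model_uid="remove-bg-0",
        model_path="/tmp/unused",  # placeholder; rembg manages its own weights
        device=None,
        config={},
    )

    with open("photo.png", "rb") as f:
        b64 = base64.b64encode(f.read()).decode()

    # only_mask=False returns the cut-out image rather than the alpha mask.
    result = model.infer(image=b64, only_mask=False)
    with open("photo_no_bg.png", "wb") as out:
        out.write(base64.b64decode(result.b64_json))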
xinference/model/image/model_spec.json CHANGED
@@ -106,5 +106,12 @@
     "model_id": "stabilityai/stable-diffusion-2-inpainting",
     "model_revision": "81a84f49b15956b60b4272a405ad3daef3da4590",
     "ability": "inpainting"
+  },
+  {
+    "model_name": "stable-diffusion-xl-inpainting",
+    "model_family": "stable_diffusion",
+    "model_id": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+    "model_revision": "115134f363124c53c7d878647567d04daf26e41e",
+    "ability": "inpainting"
   }
 ]
xinference/model/image/stable_diffusion/core.py CHANGED
@@ -94,7 +94,12 @@ class DiffusionModel:
             self._model_path,
             **self._kwargs,
         )
-        self._model = move_model_to_available_device(self._model)
+        if self._kwargs.get("cpu_offload", False):
+            logger.debug("CPU offloading model")
+            self._model.enable_model_cpu_offload()
+        else:
+            logger.debug("Loading model to available device")
+            self._model = move_model_to_available_device(self._model)
         # Recommended if your computer has < 64 GB of RAM
         self._model.enable_attention_slicing()
         self._apply_lora()
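Because cpu_offload is read from the kwargs captured at load time, it can be supplied when launching the image model. A hedged sketch (assuming launch-time kwargs reach the DiffusionModel constructor, as other model kwargs do):

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")

    # With cpu_offload=True the pipeline calls enable_model_cpu_offload()
    # instead of moving the whole model onto the accelerator, trading speed
    # for a much smaller peak VRAM footprint.
    uid = client.launch_model(
        model_name="stable-diffusion-xl-inpainting",
        model_type="image",
        cpu_offload=True,
    )
    model = client.get_model(uid)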