xinference 0.13.1__py3-none-any.whl → 0.13.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (82)
  1. xinference/__init__.py +0 -1
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +99 -5
  4. xinference/client/restful/restful_client.py +98 -1
  5. xinference/core/chat_interface.py +2 -2
  6. xinference/core/model.py +85 -26
  7. xinference/core/scheduler.py +4 -4
  8. xinference/model/audio/chattts.py +40 -8
  9. xinference/model/audio/core.py +5 -2
  10. xinference/model/audio/cosyvoice.py +136 -0
  11. xinference/model/audio/model_spec.json +24 -0
  12. xinference/model/audio/model_spec_modelscope.json +27 -0
  13. xinference/model/flexible/launchers/__init__.py +1 -0
  14. xinference/model/flexible/launchers/image_process_launcher.py +70 -0
  15. xinference/model/image/core.py +3 -0
  16. xinference/model/image/model_spec.json +21 -0
  17. xinference/model/image/stable_diffusion/core.py +49 -7
  18. xinference/model/llm/llm_family.json +1065 -106
  19. xinference/model/llm/llm_family.py +26 -6
  20. xinference/model/llm/llm_family_csghub.json +39 -0
  21. xinference/model/llm/llm_family_modelscope.json +460 -47
  22. xinference/model/llm/pytorch/chatglm.py +243 -5
  23. xinference/model/llm/pytorch/cogvlm2.py +1 -1
  24. xinference/model/llm/sglang/core.py +7 -2
  25. xinference/model/llm/utils.py +78 -1
  26. xinference/model/llm/vllm/core.py +11 -0
  27. xinference/thirdparty/cosyvoice/__init__.py +0 -0
  28. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  29. xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
  30. xinference/thirdparty/cosyvoice/bin/train.py +136 -0
  31. xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
  32. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
  33. xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
  34. xinference/thirdparty/cosyvoice/cli/model.py +60 -0
  35. xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
  36. xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
  37. xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
  38. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  39. xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
  40. xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
  41. xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
  42. xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
  43. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  44. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
  45. xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
  46. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  47. xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
  48. xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
  49. xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
  50. xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
  51. xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
  52. xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
  53. xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
  54. xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
  55. xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
  56. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
  57. xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
  58. xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  59. xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
  60. xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
  61. xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
  62. xinference/thirdparty/cosyvoice/utils/common.py +103 -0
  63. xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
  64. xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
  65. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
  66. xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
  67. xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
  68. xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
  69. xinference/web/ui/build/asset-manifest.json +3 -3
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
  72. xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
  74. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/METADATA +18 -8
  75. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/RECORD +80 -36
  76. xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
  78. /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
  79. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
  80. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
  81. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
  82. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ from ...constants import XINFERENCE_CACHE_DIR
 from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
+from .cosyvoice import CosyVoiceModel
 from .whisper import WhisperModel
 
 MAX_ATTEMPTS = 3
@@ -150,14 +151,16 @@ def create_audio_model_instance(
     model_name: str,
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
     **kwargs,
-) -> Tuple[Union[WhisperModel, ChatTTSModel], AudioModelDescription]:
+) -> Tuple[Union[WhisperModel, ChatTTSModel, CosyVoiceModel], AudioModelDescription]:
     model_spec = match_audio(model_name, download_hub)
     model_path = cache(model_spec)
-    model: Union[WhisperModel, ChatTTSModel]
+    model: Union[WhisperModel, ChatTTSModel, CosyVoiceModel]
     if model_spec.model_family == "whisper":
         model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "ChatTTS":
         model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "CosyVoice":
+        model = CosyVoiceModel(model_uid, model_path, model_spec, **kwargs)
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
@@ -0,0 +1,136 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import io
+import logging
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class CosyVoiceModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._kwargs = kwargs
+
+    def load(self):
+        import os
+        import sys
+
+        # The yaml config shipped with the model hard-codes import paths; see load_hyperpyyaml.
+        sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../thirdparty"))
+
+        from cosyvoice.cli.cosyvoice import CosyVoice
+
+        self._model = CosyVoice(self._model_path)
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        if stream:
+            raise Exception("CosyVoiceModel does not support stream.")
+
+        import torchaudio
+        from cosyvoice.utils.file_utils import load_wav
+
+        prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
+        prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
+        instruct_text: Optional[str] = kwargs.pop("instruct_text", None)
+
+        if "SFT" in self._model_spec.model_name:
+            # inference_sft
+            assert (
+                prompt_speech is None
+            ), "CosyVoice SFT model does not support prompt_speech"
+            assert (
+                prompt_text is None
+            ), "CosyVoice SFT model does not support prompt_text"
+            assert (
+                instruct_text is None
+            ), "CosyVoice SFT model does not support instruct_text"
+        elif "Instruct" in self._model_spec.model_name:
+            # inference_instruct
+            assert (
+                prompt_speech is None
+            ), "CosyVoice Instruct model does not support prompt_speech"
+            assert (
+                prompt_text is None
+            ), "CosyVoice Instruct model does not support prompt_text"
+            assert (
+                instruct_text is not None
+            ), "CosyVoice Instruct model expects an instruct_text"
+        else:
+            # inference_zero_shot
+            # inference_cross_lingual
+            assert prompt_speech is not None, "CosyVoice model expects a prompt_speech"
+            assert (
+                instruct_text is None
+            ), "CosyVoice model does not support instruct_text"
+
+        assert self._model is not None
+        if prompt_speech:
+            assert not voice, "voice can't be set with prompt speech."
+            with io.BytesIO(prompt_speech) as prompt_speech_io:
+                prompt_speech_16k = load_wav(prompt_speech_io, 16000)
+                if prompt_text:
+                    logger.info("CosyVoice inference_zero_shot")
+                    output = self._model.inference_zero_shot(
+                        input, prompt_text, prompt_speech_16k
+                    )
+                else:
+                    logger.info("CosyVoice inference_cross_lingual")
+                    output = self._model.inference_cross_lingual(
+                        input, prompt_speech_16k
+                    )
+        else:
+            available_speakers = self._model.list_avaliable_spks()  # upstream method name (sic)
+            if not voice:
+                voice = available_speakers[0]
+            else:
+                assert (
+                    voice in available_speakers
+                ), f"Invalid voice {voice}, CosyVoice available speakers: {available_speakers}"
+            if instruct_text:
+                logger.info("CosyVoice inference_instruct")
+                output = self._model.inference_instruct(
+                    input, voice, instruct_text=instruct_text
+                )
+            else:
+                logger.info("CosyVoice inference_sft")
+                output = self._model.inference_sft(input, voice)
+
+        # Save the generated audio
+        with BytesIO() as out:
+            torchaudio.save(out, output["tts_speech"], 22050, format=response_format)
+            return out.getvalue()
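The speech() method added above dispatches on the checkpoint name: SFT checkpoints run inference_sft, Instruct checkpoints run inference_instruct, and the base checkpoint runs inference_zero_shot or inference_cross_lingual depending on whether prompt_text accompanies prompt_speech. A minimal sketch of calling the class directly, assuming a downloaded model directory and a populated spec (both placeholders):

    spec = ...  # an AudioModelFamilyV1 for CosyVoice-300M-SFT (placeholder)
    model = CosyVoiceModel("demo-uid", "/path/to/CosyVoice-300M-SFT", spec)
    model.load()
    # No prompt kwargs on an SFT checkpoint -> inference_sft with a built-in speaker.
    mp3_bytes = model.speech("Hello, world.", voice="")
    # On the base CosyVoice-300M, pass 16 kHz WAV bytes to clone a voice; adding
    # prompt_text selects inference_zero_shot instead of inference_cross_lingual:
    # model.speech("Hello.", voice="", prompt_speech=wav_bytes, prompt_text="...")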
@@ -102,5 +102,29 @@
     "model_revision": "ce5913842aebd78e4a01a02d47244b8d62ac4ee3",
     "ability": "text-to-audio",
     "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M",
+    "model_family": "CosyVoice",
+    "model_id": "model-scope/CosyVoice-300M",
+    "model_revision": "ca4e036d2db2aa4731cc1747859a68044b6a4694",
+    "ability": "audio-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-SFT",
+    "model_family": "CosyVoice",
+    "model_id": "model-scope/CosyVoice-300M-SFT",
+    "model_revision": "ab918940c6c134b1fc1f069246e67bad6b66abcb",
+    "ability": "text-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-Instruct",
+    "model_family": "CosyVoice",
+    "model_id": "model-scope/CosyVoice-300M-Instruct",
+    "model_revision": "fb5f676733139f35670bed9b59a77d476b1aa898",
+    "ability": "text-to-audio",
+    "multilingual": true
   }
 ]
@@ -16,5 +16,32 @@
     "model_revision": "master",
     "ability": "text-to-audio",
     "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M",
+    "model_family": "CosyVoice",
+    "model_hub": "modelscope",
+    "model_id": "iic/CosyVoice-300M",
+    "model_revision": "master",
+    "ability": "audio-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-SFT",
+    "model_family": "CosyVoice",
+    "model_hub": "modelscope",
+    "model_id": "iic/CosyVoice-300M-SFT",
+    "model_revision": "master",
+    "ability": "text-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-Instruct",
+    "model_family": "CosyVoice",
+    "model_hub": "modelscope",
+    "model_id": "iic/CosyVoice-300M-Instruct",
+    "model_revision": "master",
+    "ability": "text-to-audio",
+    "multilingual": true
   }
 ]
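Both spec files register the same three model names, so match_audio() can resolve them from either hub via the download_hub argument shown earlier. A hedged sketch of launching one through the Python client (endpoint and speaker name are placeholders; the call shape follows the existing audio client API):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    uid = client.launch_model(model_name="CosyVoice-300M-SFT", model_type="audio")
    mp3_bytes = client.get_model(uid).speech(input="你好，世界。", voice="中文女")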
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .image_process_launcher import launcher as image_process
 from .transformers_launcher import launcher as transformers
@@ -0,0 +1,70 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+from io import BytesIO
+
+import PIL.Image
+import PIL.ImageOps
+
+from ....types import Image
+from ..core import FlexibleModel, FlexibleModelSpec
+
+
+class ImageRemoveBackgroundModel(FlexibleModel):
+    def infer(self, **kwargs):
+        invert = kwargs.get("invert", False)
+        b64_image: str = kwargs.get("image")  # type: ignore
+        only_mask = kwargs.pop("only_mask", True)
+        image_format = kwargs.pop("image_format", "PNG")
+        if not b64_image:
+            raise ValueError("No image found to remove background")
+        image = base64.b64decode(b64_image)
+
+        try:
+            from rembg import remove
+        except ImportError:
+            error_message = "Failed to import module 'rembg'"
+            installation_guide = [
+                "Please make sure 'rembg' is installed. ",
+                "You can install it by visiting the installation section of the git repo:\n",
+                "https://github.com/danielgatis/rembg?tab=readme-ov-file#installation",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        im = PIL.Image.open(BytesIO(image))
+        om = remove(im, only_mask=only_mask, **kwargs)
+        if invert:
+            om = PIL.ImageOps.invert(om)
+
+        buffered = BytesIO()
+        om.save(buffered, format=image_format)
+        img_str = base64.b64encode(buffered.getvalue()).decode()
+        return Image(url=None, b64_json=img_str)
+
+
+def launcher(model_uid: str, model_spec: FlexibleModelSpec, **kwargs) -> FlexibleModel:
+    task = kwargs.get("task")
+    device = kwargs.get("device")
+
+    if task == "remove_background":
+        return ImageRemoveBackgroundModel(
+            model_uid=model_uid,
+            model_path=model_spec.model_uri,  # type: ignore
+            device=device,
+            config=kwargs,
+        )
+    else:
+        raise ValueError(f"Unknown Task for image processing: {task}")
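A hedged sketch of driving the new launcher directly; the spec value is a placeholder, and infer() takes the base64 payload the class expects:

    import base64
    from xinference.model.flexible.launchers import image_process

    spec = ...  # a FlexibleModelSpec with model_uri set (placeholder)
    model = image_process(
        model_uid="rembg-demo", model_spec=spec, task="remove_background", device="cpu"
    )
    with open("cat.png", "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    result = model.infer(image=b64, only_mask=False)  # Image with b64_json populated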
@@ -45,6 +45,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
     model_id: str
     model_revision: str
     model_hub: str = "huggingface"
+    ability: Optional[str]
     controlnet: Optional[List["ImageModelFamilyV1"]]
 
 
@@ -71,6 +72,7 @@ class ImageModelDescription(ModelDescription):
             "model_name": self._model_spec.model_name,
             "model_family": self._model_spec.model_family,
             "model_revision": self._model_spec.model_revision,
+            "ability": self._model_spec.ability,
             "controlnet": controlnet,
         }
 
@@ -234,6 +236,7 @@ def create_image_model_instance(
         lora_model_paths=lora_model,
         lora_load_kwargs=lora_load_kwargs,
         lora_fuse_kwargs=lora_fuse_kwargs,
+        ability=model_spec.ability,
         **kwargs,
     )
     model_description = ImageModelDescription(
@@ -92,5 +92,26 @@
         "model_revision": "62134b9d8e703b5d6f74f1534457287a8bba77ef"
       }
     ]
+  },
+  {
+    "model_name": "stable-diffusion-inpainting",
+    "model_family": "stable_diffusion",
+    "model_id": "runwayml/stable-diffusion-inpainting",
+    "model_revision": "51388a731f57604945fddd703ecb5c50e8e7b49d",
+    "ability": "inpainting"
+  },
+  {
+    "model_name": "stable-diffusion-2-inpainting",
+    "model_family": "stable_diffusion",
+    "model_id": "stabilityai/stable-diffusion-2-inpainting",
+    "model_revision": "81a84f49b15956b60b4272a405ad3daef3da4590",
+    "ability": "inpainting"
+  },
+  {
+    "model_name": "stable-diffusion-xl-inpainting",
+    "model_family": "stable_diffusion",
+    "model_id": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+    "model_revision": "115134f363124c53c7d878647567d04daf26e41e",
+    "ability": "inpainting"
   }
 ]
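The new "ability": "inpainting" field on these entries is what steers DiffusionModel.load() toward AutoPipelineForInpainting in the core.py changes below; launching is otherwise unchanged. A hedged one-liner, reusing the client from the earlier sketch:

    uid = client.launch_model(model_name="stable-diffusion-inpainting", model_type="image")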
@@ -16,6 +16,7 @@ import base64
 import logging
 import os
 import re
+import sys
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
@@ -39,6 +40,7 @@ class DiffusionModel:
         lora_model: Optional[List[LoRA]] = None,
         lora_load_kwargs: Optional[Dict] = None,
         lora_fuse_kwargs: Optional[Dict] = None,
+        ability: Optional[str] = None,
         **kwargs,
     ):
         self._model_uid = model_uid
@@ -48,6 +50,7 @@ class DiffusionModel:
         self._lora_model = lora_model
         self._lora_load_kwargs = lora_load_kwargs or {}
         self._lora_fuse_kwargs = lora_fuse_kwargs or {}
+        self._ability = ability
         self._kwargs = kwargs
 
     def _apply_lora(self):
@@ -64,8 +67,14 @@ class DiffusionModel:
         logger.info(f"Successfully loaded the LoRA for model {self._model_uid}.")
 
     def load(self):
-        # import torch
-        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        if self._ability in [None, "text2image", "image2image"]:
+            from diffusers import AutoPipelineForText2Image as AutoPipelineModel
+        elif self._ability == "inpainting":
+            from diffusers import AutoPipelineForInpainting as AutoPipelineModel
+        else:
+            raise ValueError(f"Unknown ability: {self._ability}")
 
         controlnet = self._kwargs.get("controlnet")
         if controlnet is not None:
@@ -74,14 +83,23 @@ class DiffusionModel:
             logger.debug("Loading controlnet %s", controlnet)
             self._kwargs["controlnet"] = ControlNetModel.from_pretrained(controlnet)
 
-        self._model = AutoPipelineForText2Image.from_pretrained(
+        torch_dtype = self._kwargs.get("torch_dtype")
+        if sys.platform != "darwin" and torch_dtype is None:
+            # The following params crash on Mac M2
+            self._kwargs["torch_dtype"] = torch.float16
+            self._kwargs["use_safetensors"] = True
+
+        logger.debug("Loading model %s", AutoPipelineModel)
+        self._model = AutoPipelineModel.from_pretrained(
             self._model_path,
             **self._kwargs,
-            # The following params crash on Mac M2
-            # torch_dtype=torch.float16,
-            # use_safetensors=True,
         )
-        self._model = move_model_to_available_device(self._model)
+        if self._kwargs.get("cpu_offload", False):
+            logger.debug("CPU offloading model")
+            self._model.enable_model_cpu_offload()
+        else:
+            logger.debug("Loading model to available device")
+            self._model = move_model_to_available_device(self._model)
         # Recommended if your computer has < 64 GB of RAM
         self._model.enable_attention_slicing()
         self._apply_lora()
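Since cpu_offload is read from the same **kwargs bag as torch_dtype, it can be supplied at launch time; a hedged example, again reusing the client sketch from above:

    uid = client.launch_model(
        model_name="stable-diffusion-inpainting", model_type="image", cpu_offload=True
    )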
@@ -174,3 +192,27 @@ class DiffusionModel:
             response_format=response_format,
             **kwargs,
         )
+
+    def inpainting(
+        self,
+        image: bytes,
+        mask_image: bytes,
+        prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        n: int = 1,
+        size: str = "1024*1024",
+        response_format: str = "url",
+        **kwargs,
+    ):
+        width, height = map(int, re.split(r"[^\d]+", size))
+        return self._call_model(
+            image=image,
+            mask_image=mask_image,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=height,
+            width=width,
+            num_images_per_prompt=n,
+            response_format=response_format,
+            **kwargs,
+        )
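A sketch of invoking the new method with the signature added above; the image files and prompt are placeholders, and diffusion_model stands for a loaded DiffusionModel instance:

    with open("base.png", "rb") as f, open("mask.png", "rb") as m:
        images = diffusion_model.inpainting(
            image=f.read(),
            mask_image=m.read(),
            prompt="a red sofa in a bright room",
            n=1,
            size="512*512",
            response_format="b64_json",
        )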