xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference has been flagged as possibly problematic.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +209 -40
- xinference/client/restful/restful_client.py +7 -26
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/image_interface.py +28 -0
- xinference/core/model.py +110 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +17 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/audio/fish_speech.py +9 -9
- xinference/model/audio/model_spec.json +9 -9
- xinference/model/audio/whisper.py +4 -1
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/core.py +2 -1
- xinference/model/image/model_spec.json +16 -4
- xinference/model/image/model_spec_modelscope.json +16 -4
- xinference/model/image/sdapi.py +136 -0
- xinference/model/image/stable_diffusion/core.py +164 -19
- xinference/model/llm/__init__.py +29 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +1011 -1296
- xinference/model/llm/llm_family.py +34 -53
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +981 -1122
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +36 -18
- xinference/model/llm/transformers/chatglm.py +168 -306
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +55 -50
- xinference/model/llm/transformers/deepseek_v2.py +340 -0
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_audio.py +168 -0
- xinference/model/llm/transformers/qwen2_vl.py +234 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +280 -554
- xinference/model/llm/vllm/core.py +161 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
- xinference/thirdparty/fish_speech/tools/api.py +79 -134
- xinference/thirdparty/fish_speech/tools/commons.py +35 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
- xinference/thirdparty/fish_speech/tools/file.py +17 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
- xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
- xinference/thirdparty/fish_speech/tools/webui.py +12 -146
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +14 -53
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
- xinference/web/ui/build/static/js/main.754740c0.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
- xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
@@ -19,6 +19,7 @@ from collections import defaultdict
 from typing import Dict, List, Literal, Optional, Tuple, Union, no_type_check
 
 import numpy as np
+import torch
 
 from ...device_utils import empty_cache
 from ...types import Embedding, EmbeddingData, EmbeddingUsage
@@ -34,7 +35,11 @@ EMBEDDING_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
 EMBEDDING_EMPTY_CACHE_COUNT = int(
     os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT", "10")
 )
+EMBEDDING_EMPTY_CACHE_TOKENS = int(
+    os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_TOKENS", "8192")
+)
 assert EMBEDDING_EMPTY_CACHE_COUNT > 0
+assert EMBEDDING_EMPTY_CACHE_TOKENS > 0
 
 
 def get_embedding_model_descriptions():
@@ -149,6 +154,25 @@ class EmbeddingModel:
             def to(self, *args, **kwargs):
                 pass
 
+        torch_dtype = None
+        if torch_dtype_str := self._kwargs.get("torch_dtype"):
+            try:
+                torch_dtype = getattr(torch, torch_dtype_str)
+                if torch_dtype not in [
+                    torch.float16,
+                    torch.float32,
+                    torch.bfloat16,
+                ]:
+                    logger.warning(
+                        f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
+                    )
+                    torch_dtype = torch.float32
+            except AttributeError:
+                logger.warning(
+                    f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
+                )
+                torch_dtype = torch.float32
+
         from ..utils import patch_trust_remote_code
 
         patch_trust_remote_code()
@@ -156,42 +180,21 @@ class EmbeddingModel:
             "gte" in self._model_spec.model_name.lower()
             and "qwen2" in self._model_spec.model_name.lower()
         ):
-
-
-
-            if torch_dtype_str is not None:
-                try:
-                    torch_dtype = getattr(torch, torch_dtype_str)
-                    if torch_dtype not in [
-                        torch.float16,
-                        torch.float32,
-                        torch.bfloat16,
-                    ]:
-                        logger.warning(
-                            f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
-                        )
-                        torch_dtype = torch.float32
-                except AttributeError:
-                    logger.warning(
-                        f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
-                    )
-                    torch_dtype = torch.float32
-            else:
-                torch_dtype = "auto"
+            model_kwargs = {"device_map": "auto"}
+            if torch_dtype:
+                model_kwargs["torch_dtype"] = torch_dtype
             self._model = XSentenceTransformer(
                 self._model_path,
                 device=self._device,
-                model_kwargs=
+                model_kwargs=model_kwargs,
             )
         else:
-
+            model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
+            self._model = SentenceTransformer(
+                self._model_path, device=self._device, model_kwargs=model_kwargs
+            )
 
     def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
-        self._counter += 1
-        if self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0:
-            logger.debug("Empty embedding cache.")
-            gc.collect()
-            empty_cache()
         from sentence_transformers import SentenceTransformer
 
         kwargs.setdefault("normalize_embeddings", True)
@@ -309,7 +312,9 @@ class EmbeddingModel:
             features = model.tokenize(sentences_batch)
             features = batch_to_device(features, device)
             features.update(extra_features)
-
+            # when batching, the attention mask 1 means there is a token
+            # thus we just sum up it to get the total number of tokens
+            all_token_nums += features["attention_mask"].sum().item()
 
             with torch.no_grad():
                 out_features = model.forward(features)
@@ -393,13 +398,29 @@ class EmbeddingModel:
         usage = EmbeddingUsage(
             prompt_tokens=all_token_nums, total_tokens=all_token_nums
         )
-
+        result = Embedding(
             object="list",
             model=self._model_uid,
             data=embedding_list,
            usage=usage,
        )
 
+        # clean cache if possible
+        self._counter += 1
+        if (
+            self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0
+            or all_token_nums >= EMBEDDING_EMPTY_CACHE_TOKENS
+        ):
+            logger.debug(
+                "Empty embedding cache, calling count %s, all_token_nums %s",
+                self._counter,
+                all_token_nums,
+            )
+            gc.collect()
+            empty_cache()
+
+        return result
+
 
 def match_embedding(
     model_name: str,
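Note: the per-call cache eviction removed from the top of create_embedding moves below the result construction and gains a token-count trigger controlled by the new XINFERENCE_EMBEDDING_EMPTY_CACHE_TOKENS variable. A minimal standalone sketch of the rule (the _CacheGuard class and maybe_empty_cache name are illustrative only; in xinference the counter lives on EmbeddingModel and the trigger also calls empty_cache() from device_utils):

    import gc
    import os

    EMBEDDING_EMPTY_CACHE_COUNT = int(os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT", "10"))
    EMBEDDING_EMPTY_CACHE_TOKENS = int(os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_TOKENS", "8192"))

    class _CacheGuard:
        def __init__(self):
            self._counter = 0

        def maybe_empty_cache(self, all_token_nums: int) -> bool:
            # Evict either every N calls or whenever a single request was large.
            self._counter += 1
            if (
                self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0
                or all_token_nums >= EMBEDDING_EMPTY_CACHE_TOKENS
            ):
                gc.collect()
                return True
            return False

Compared with 0.14.4.post1, a single oversized batch can now trigger a cleanup on its own instead of waiting for the every-N-calls counter.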
xinference/model/image/core.py
CHANGED
@@ -47,6 +47,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
     model_hub: str = "huggingface"
     model_ability: Optional[List[str]]
     controlnet: Optional[List["ImageModelFamilyV1"]]
+    default_generate_config: Optional[dict] = {}
 
 
 class ImageModelDescription(ModelDescription):
@@ -238,7 +239,7 @@ def create_image_model_instance(
         lora_model_paths=lora_model,
         lora_load_kwargs=lora_load_kwargs,
         lora_fuse_kwargs=lora_fuse_kwargs,
-
+        model_spec=model_spec,
         **kwargs,
     )
     model_description = ImageModelDescription(
xinference/model/image/model_spec.json
CHANGED
@@ -5,7 +5,9 @@
         "model_id": "black-forest-labs/FLUX.1-schnell",
         "model_revision": "768d12a373ed5cc9ef9a9dea7504dc09fcc14842",
         "model_ability": [
-            "text2image"
+            "text2image",
+            "image2image",
+            "inpainting"
         ]
     },
     {
@@ -14,7 +16,9 @@
         "model_id": "black-forest-labs/FLUX.1-dev",
         "model_revision": "01aa605f2c300568dd6515476f04565a954fcb59",
         "model_ability": [
-            "text2image"
+            "text2image",
+            "image2image",
+            "inpainting"
         ]
     },
     {
@@ -35,7 +39,11 @@
         "model_revision": "1681ed09e0cff58eeb41e878a49893228b78b94c",
         "model_ability": [
             "text2image"
-        ]
+        ],
+        "default_generate_config": {
+            "guidance_scale": 0.0,
+            "num_inference_steps": 1
+        }
     },
     {
         "model_name": "sdxl-turbo",
@@ -44,7 +52,11 @@
         "model_revision": "f4b0486b498f84668e828044de1d0c8ba486e05b",
         "model_ability": [
             "text2image"
-        ]
+        ],
+        "default_generate_config": {
+            "guidance_scale": 0.0,
+            "num_inference_steps": 1
+        }
     },
     {
         "model_name": "stable-diffusion-v1.5",
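Note: the new default_generate_config entries only apply when the caller does not override them; text_to_image (see the stable_diffusion/core.py hunks below) copies the spec defaults and then layers any non-None request kwargs on top. A small sketch of that precedence, using a made-up request payload:

    # Defaults from the sd-turbo / sdxl-turbo entries above.
    default_generate_config = {"guidance_scale": 0.0, "num_inference_steps": 1}

    def merge_generate_kwargs(defaults: dict, request_kwargs: dict) -> dict:
        # Spec-level defaults first, then any non-None request value wins,
        # mirroring the .copy()/.update() pair in the diff.
        merged = defaults.copy()
        merged.update({k: v for k, v in request_kwargs.items() if v is not None})
        return merged

    print(merge_generate_kwargs(default_generate_config, {"num_inference_steps": 4, "seed": None}))
    # {'guidance_scale': 0.0, 'num_inference_steps': 4}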
xinference/model/image/model_spec_modelscope.json
CHANGED
@@ -6,7 +6,9 @@
         "model_id": "AI-ModelScope/FLUX.1-schnell",
         "model_revision": "master",
         "model_ability": [
-            "text2image"
+            "text2image",
+            "image2image",
+            "inpainting"
         ]
     },
     {
@@ -16,7 +18,9 @@
         "model_id": "AI-ModelScope/FLUX.1-dev",
         "model_revision": "master",
         "model_ability": [
-            "text2image"
+            "text2image",
+            "image2image",
+            "inpainting"
        ]
     },
     {
@@ -39,7 +43,11 @@
         "model_revision": "master",
         "model_ability": [
             "text2image"
-        ]
+        ],
+        "default_generate_config": {
+            "guidance_scale": 0.0,
+            "num_inference_steps": 1
+        }
     },
     {
         "model_name": "sdxl-turbo",
@@ -49,7 +57,11 @@
         "model_revision": "master",
         "model_ability": [
             "text2image"
-        ]
+        ],
+        "default_generate_config": {
+            "guidance_scale": 0.0,
+            "num_inference_steps": 1
+        }
     },
     {
         "model_name": "stable-diffusion-v1.5",
xinference/model/image/sdapi.py
ADDED
@@ -0,0 +1,136 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import io
+import warnings
+
+from PIL import Image
+
+
+class SDAPIToDiffusersConverter:
+    txt2img_identical_args = {
+        "prompt",
+        "negative_prompt",
+        "seed",
+        "width",
+        "height",
+        "sampler_name",
+    }
+    txt2img_arg_mapping = {
+        "steps": "num_inference_steps",
+        "cfg_scale": "guidance_scale",
+        # "denoising_strength": "strength",
+    }
+    img2img_identical_args = {
+        "prompt",
+        "negative_prompt",
+        "seed",
+        "width",
+        "height",
+        "sampler_name",
+    }
+    img2img_arg_mapping = {
+        "init_images": "image",
+        "steps": "num_inference_steps",
+        "cfg_scale": "guidance_scale",
+        "denoising_strength": "strength",
+    }
+
+    @staticmethod
+    def convert_to_diffusers(sd_type: str, params: dict) -> dict:
+        diffusers_params = {}
+
+        identical_args = getattr(SDAPIToDiffusersConverter, f"{sd_type}_identical_args")
+        mapping_args = getattr(SDAPIToDiffusersConverter, f"{sd_type}_arg_mapping")
+        for param, value in params.items():
+            if param in identical_args:
+                diffusers_params[param] = value
+            elif param in mapping_args:
+                diffusers_params[mapping_args[param]] = value
+            else:
+                raise ValueError(f"Unknown arg: {param}")
+
+        return diffusers_params
+
+    @staticmethod
+    def get_available_args(sd_type: str) -> set:
+        identical_args = getattr(SDAPIToDiffusersConverter, f"{sd_type}_identical_args")
+        mapping_args = getattr(SDAPIToDiffusersConverter, f"{sd_type}_arg_mapping")
+        return identical_args.union(mapping_args)
+
+
+class SDAPIDiffusionModelMixin:
+    @staticmethod
+    def _check_kwargs(sd_type: str, kwargs: dict):
+        available_args = SDAPIToDiffusersConverter.get_available_args(sd_type)
+        unknown_args = []
+        available_kwargs = {}
+        for arg, value in kwargs.items():
+            if arg in available_args:
+                available_kwargs[arg] = value
+            else:
+                unknown_args.append(arg)
+        if unknown_args:
+            warnings.warn(
+                f"Some args are not supported for now and will be ignored: {unknown_args}"
+            )
+
+        converted_kwargs = SDAPIToDiffusersConverter.convert_to_diffusers(
+            sd_type, available_kwargs
+        )
+
+        width, height = converted_kwargs.pop("width", None), converted_kwargs.pop(
+            "height", None
+        )
+        if width and height:
+            converted_kwargs["size"] = f"{width}*{height}"
+
+        return converted_kwargs
+
+    def txt2img(self, **kwargs):
+        converted_kwargs = self._check_kwargs("txt2img", kwargs)
+        result = self.text_to_image(response_format="b64_json", **converted_kwargs)  # type: ignore
+
+        # convert to SD API result
+        return {
+            "images": [r["b64_json"] for r in result["data"]],
+            "info": {"created": result["created"]},
+            "parameters": {},
+        }
+
+    @staticmethod
+    def _decode_b64_img(img_str: str) -> Image:
+        # img_str in a format: "data:image/png;base64," + raw_b64_img(image)
+        f, data = img_str.split(",", 1)
+        f, encode_type = f.split(";", 1)
+        assert encode_type == "base64"
+        f = f.split("/", 1)[1]
+        b = base64.b64decode(data)
+        return Image.open(io.BytesIO(b), formats=[f])
+
+    def img2img(self, **kwargs):
+        init_images = kwargs.pop("init_images", [])
+        kwargs["init_images"] = [self._decode_b64_img(i) for i in init_images]
+        clip_skip = kwargs.get("override_settings", {}).get("clip_skip")
+        converted_kwargs = self._check_kwargs("img2img", kwargs)
+        if clip_skip:
+            converted_kwargs["clip_skip"] = clip_skip
+        result = self.image_to_image(response_format="b64_json", **converted_kwargs)  # type: ignore
+
+        # convert to SD API result
+        return {
+            "images": [r["b64_json"] for r in result["data"]],
+            "info": {"created": result["created"]},
+            "parameters": {},
+        }
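Note: the txt2img/img2img names mirror the Automatic1111 SD WebUI API, so this converter presumably backs SD-WebUI-style requests by screening a payload against the known argument names and renaming them to diffusers keywords. A quick illustration (payload values are made up; assumes the wheel is installed so the module path from the file list above is importable):

    from xinference.model.image.sdapi import SDAPIToDiffusersConverter

    payload = {
        "prompt": "a lighthouse at dusk",
        "negative_prompt": "blurry",
        "steps": 30,       # renamed to num_inference_steps
        "cfg_scale": 7.0,  # renamed to guidance_scale
        "seed": 42,
    }
    print(SDAPIToDiffusersConverter.convert_to_diffusers("txt2img", payload))
    # {'prompt': 'a lighthouse at dusk', 'negative_prompt': 'blurry',
    #  'num_inference_steps': 30, 'guidance_scale': 7.0, 'seed': 42}

Any key outside the identical or mapped sets raises ValueError in convert_to_diffusers, while _check_kwargs drops unknown keys with a warning before converting.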
xinference/model/image/stable_diffusion/core.py
CHANGED
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import base64
+import contextlib
+import inspect
 import logging
 import os
 import re
@@ -22,19 +24,43 @@ import uuid
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from io import BytesIO
-from typing import Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import PIL.Image
+import torch
 from PIL import ImageOps
 
 from ....constants import XINFERENCE_IMAGE_DIR
 from ....device_utils import move_model_to_available_device
 from ....types import Image, ImageList, LoRA
+from ..sdapi import SDAPIDiffusionModelMixin
 
-
+if TYPE_CHECKING:
+    from ..core import ImageModelFamilyV1
 
+logger = logging.getLogger(__name__)
 
-
+SAMPLING_METHODS = [
+    "default",
+    "DPM++ 2M",
+    "DPM++ 2M Karras",
+    "DPM++ 2M SDE",
+    "DPM++ 2M SDE Karras",
+    "DPM++ SDE",
+    "DPM++ SDE Karras",
+    "DPM2",
+    "DPM2 Karras",
+    "DPM2 a",
+    "DPM2 a Karras",
+    "Euler",
+    "Euler a",
+    "Heun",
+    "LMS",
+    "LMS Karras",
+]
+
+
+class DiffusionModel(SDAPIDiffusionModelMixin):
     def __init__(
         self,
         model_uid: str,
@@ -43,7 +69,7 @@ class DiffusionModel:
         lora_model: Optional[List[LoRA]] = None,
         lora_load_kwargs: Optional[Dict] = None,
         lora_fuse_kwargs: Optional[Dict] = None,
-
+        model_spec: Optional["ImageModelFamilyV1"] = None,
         **kwargs,
     ):
         self._model_uid = model_uid
@@ -59,7 +85,8 @@ class DiffusionModel:
         self._lora_model = lora_model
         self._lora_load_kwargs = lora_load_kwargs or {}
         self._lora_fuse_kwargs = lora_fuse_kwargs or {}
-        self.
+        self._model_spec = model_spec
+        self._abilities = model_spec.model_ability or []  # type: ignore
         self._kwargs = kwargs
 
     @property
@@ -80,8 +107,6 @@ class DiffusionModel:
         logger.info(f"Successfully loaded the LoRA for model {self._model_uid}.")
 
     def load(self):
-        import torch
-
         if "text2image" in self._abilities or "image2image" in self._abilities:
             from diffusers import AutoPipelineForText2Image as AutoPipelineModel
         elif "inpainting" in self._abilities:
@@ -143,7 +168,9 @@ class DiffusionModel:
             self._kwargs[text_encoder_name] = text_encoder
             self._kwargs["device_map"] = "balanced"
 
-        logger.debug(
+        logger.debug(
+            "Loading model from %s, kwargs: %s", self._model_path, self._kwargs
+        )
         self._model = AutoPipelineModel.from_pretrained(
             self._model_path,
             **self._kwargs,
@@ -158,6 +185,89 @@ class DiffusionModel:
         self._model.enable_attention_slicing()
         self._apply_lora()
 
+    @staticmethod
+    def _get_scheduler(model: Any, sampler_name: str):
+        if not sampler_name:
+            return
+
+        assert model is not None
+
+        import diffusers
+
+        # see https://github.com/huggingface/diffusers/issues/4167
+        # to get A1111 <> Diffusers Scheduler mapping
+        if sampler_name == "DPM++ 2M":
+            return diffusers.DPMSolverMultistepScheduler.from_config(
+                model.scheduler.config
+            )
+        elif sampler_name == "DPM++ 2M Karras":
+            return diffusers.DPMSolverMultistepScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        elif sampler_name == "DPM++ 2M SDE":
+            return diffusers.DPMSolverMultistepScheduler.from_config(
+                model.scheduler.config, algorithm_type="sde-dpmsolver++"
+            )
+        elif sampler_name == "DPM++ 2M SDE Karras":
+            return diffusers.DPMSolverMultistepScheduler.from_config(
+                model.scheduler.config,
+                algorithm_type="sde-dpmsolver++",
+                use_karras_sigmas=True,
+            )
+        elif sampler_name == "DPM++ SDE":
+            return diffusers.DPMSolverSinglestepScheduler.from_config(
+                model.scheduler.config
+            )
+        elif sampler_name == "DPM++ SDE Karras":
+            return diffusers.DPMSolverSinglestepScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        elif sampler_name == "DPM2":
+            return diffusers.KDPM2DiscreteScheduler.from_config(model.scheduler.config)
+        elif sampler_name == "DPM2 Karras":
+            return diffusers.KDPM2DiscreteScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        elif sampler_name == "DPM2 a":
+            return diffusers.KDPM2AncestralDiscreteScheduler.from_config(
+                model.scheduler.config
+            )
+        elif sampler_name == "DPM2 a Karras":
+            return diffusers.KDPM2AncestralDiscreteScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        elif sampler_name == "Euler":
+            return diffusers.EulerDiscreteScheduler.from_config(model.scheduler.config)
+        elif sampler_name == "Euler a":
+            return diffusers.EulerAncestralDiscreteScheduler.from_config(
+                model.scheduler.config
+            )
+        elif sampler_name == "Heun":
+            return diffusers.HeunDiscreteScheduler.from_config(model.scheduler.config)
+        elif sampler_name == "LMS":
+            return diffusers.LMSDiscreteScheduler.from_config(model.scheduler.config)
+        elif sampler_name == "LMS Karras":
+            return diffusers.LMSDiscreteScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        else:
+            raise ValueError(f"Unknown sampler: {sampler_name}")
+
+    @staticmethod
+    @contextlib.contextmanager
+    def _reset_when_done(model: Any, sampler_name: str):
+        assert model is not None
+        scheduler = DiffusionModel._get_scheduler(model, sampler_name)
+        if scheduler:
+            default_scheduler = model.scheduler
+            model.scheduler = scheduler
+            try:
+                yield
+            finally:
+                model.scheduler = default_scheduler
+        else:
+            yield
+
     def _call_model(
         self,
         response_format: str,
@@ -168,13 +278,27 @@ class DiffusionModel:
 
         from ....device_utils import empty_cache
 
-        logger.debug(
-            "stable diffusion args: %s",
-            kwargs,
-        )
         model = model if model is not None else self._model
+        is_padded = kwargs.pop("is_padded", None)
+        origin_size = kwargs.pop("origin_size", None)
+        seed = kwargs.pop("seed", None)
+        if seed is not None:
+            kwargs["generator"] = generator = torch.Generator(device=self._model.device)  # type: ignore
+            if seed != -1:
+                kwargs["generator"] = generator.manual_seed(seed)
+        sampler_name = kwargs.pop("sampler_name", None)
         assert callable(model)
-
+        with self._reset_when_done(model, sampler_name):
+            logger.debug("stable diffusion args: %s, model: %s", kwargs, model)
+            images = model(**kwargs).images
+
+        # revert padding if padded
+        if is_padded and origin_size:
+            new_images = []
+            x, y = origin_size
+            for img in images:
+                new_images.append(img.crop((0, 0, x, y)))
+            images = new_images
 
         # clean cache
         gc.collect()
@@ -198,7 +322,7 @@ class DiffusionModel:
 
             with ThreadPoolExecutor() as executor:
                 results = list(map(partial(executor.submit, _gen_base64_image), images))  # type: ignore
-                image_list = [Image(url=None, b64_json=s.result()) for s in results]
+                image_list = [Image(url=None, b64_json=s.result()) for s in results]  # type: ignore
             return ImageList(created=int(time.time()), data=image_list)
         else:
             raise ValueError(f"Unsupported response format: {response_format}")
@@ -220,14 +344,16 @@ class DiffusionModel:
         # References:
         # https://huggingface.co/docs/diffusers/main/en/api/pipelines/controlnet_sdxl
         width, height = map(int, re.split(r"[^\d]+", size))
-        self.
+        generate_kwargs = self._model_spec.default_generate_config.copy()  # type: ignore
+        generate_kwargs.update({k: v for k, v in kwargs.items() if v is not None})
+        self._filter_kwargs(generate_kwargs)
         return self._call_model(
             prompt=prompt,
             height=height,
             width=width,
             num_images_per_prompt=n,
             response_format=response_format,
-            **
+            **generate_kwargs,
         )
 
     @staticmethod
@@ -265,6 +391,9 @@ class DiffusionModel:
         if padding_image_to_multiple := kwargs.pop("padding_image_to_multiple", None):
             # Model like SD3 image to image requires image's height and width is times of 16
             # padding the image if specified
+            origin_x, origin_y = image.size
+            kwargs["origin_size"] = (origin_x, origin_y)
+            kwargs["is_padded"] = True
             image = self.pad_to_multiple(image, multiple=int(padding_image_to_multiple))
 
         if size:
@@ -273,12 +402,24 @@ class DiffusionModel:
             width, height = image.size
             kwargs["width"] = width
             kwargs["height"] = height
-
+        else:
+            # SD3 image2image cannot accept width and height
+            parameters = inspect.signature(model.__call__).parameters  # type: ignore
+            allow_width_height = False
+            for param in parameters.values():
+                if param.kind == inspect.Parameter.VAR_KEYWORD:
+                    allow_width_height = True
+                    break
+            if "width" in parameters or "height" in parameters:
+                allow_width_height = True
+            if allow_width_height:
+                kwargs["width"], kwargs["height"] = image.size
+
+        kwargs["negative_prompt"] = negative_prompt
         self._filter_kwargs(kwargs)
         return self._call_model(
             image=image,
             prompt=prompt,
-            negative_prompt=negative_prompt,
             num_images_per_prompt=n,
             response_format=response_format,
             model=model,
@@ -318,6 +459,9 @@ class DiffusionModel:
         if padding_image_to_multiple := kwargs.pop("padding_image_to_multiple", None):
             # Model like SD3 inpainting requires image's height and width is times of 16
             # padding the image if specified
+            origin_x, origin_y = image.size
+            kwargs["origin_size"] = (origin_x, origin_y)
+            kwargs["is_padded"] = True
             image = self.pad_to_multiple(image, multiple=int(padding_image_to_multiple))
             mask_image = self.pad_to_multiple(
                 mask_image, multiple=int(padding_image_to_multiple)
@@ -325,11 +469,12 @@ class DiffusionModel:
             # calculate actual image size after padding
             width, height = image.size
 
+        kwargs["negative_prompt"] = negative_prompt
+        self._filter_kwargs(kwargs)
         return self._call_model(
             image=image,
             mask_image=mask_image,
             prompt=prompt,
-            negative_prompt=negative_prompt,
             height=height,
             width=width,
             num_images_per_prompt=n,