xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (194)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +209 -40
  4. xinference/client/restful/restful_client.py +7 -26
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/image_interface.py +28 -0
  11. xinference/core/model.py +110 -31
  12. xinference/core/scheduler.py +37 -37
  13. xinference/core/status_guard.py +1 -1
  14. xinference/core/supervisor.py +17 -10
  15. xinference/core/utils.py +80 -22
  16. xinference/core/worker.py +17 -16
  17. xinference/deploy/cmdline.py +8 -16
  18. xinference/deploy/local.py +1 -1
  19. xinference/deploy/supervisor.py +1 -1
  20. xinference/deploy/utils.py +1 -1
  21. xinference/deploy/worker.py +1 -1
  22. xinference/model/audio/cosyvoice.py +86 -41
  23. xinference/model/audio/fish_speech.py +9 -9
  24. xinference/model/audio/model_spec.json +9 -9
  25. xinference/model/audio/whisper.py +4 -1
  26. xinference/model/embedding/core.py +52 -31
  27. xinference/model/image/core.py +2 -1
  28. xinference/model/image/model_spec.json +16 -4
  29. xinference/model/image/model_spec_modelscope.json +16 -4
  30. xinference/model/image/sdapi.py +136 -0
  31. xinference/model/image/stable_diffusion/core.py +164 -19
  32. xinference/model/llm/__init__.py +29 -11
  33. xinference/model/llm/llama_cpp/core.py +16 -33
  34. xinference/model/llm/llm_family.json +1011 -1296
  35. xinference/model/llm/llm_family.py +34 -53
  36. xinference/model/llm/llm_family_csghub.json +18 -35
  37. xinference/model/llm/llm_family_modelscope.json +981 -1122
  38. xinference/model/llm/lmdeploy/core.py +56 -88
  39. xinference/model/llm/mlx/core.py +46 -69
  40. xinference/model/llm/sglang/core.py +36 -18
  41. xinference/model/llm/transformers/chatglm.py +168 -306
  42. xinference/model/llm/transformers/cogvlm2.py +36 -63
  43. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  44. xinference/model/llm/transformers/core.py +55 -50
  45. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  46. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  47. xinference/model/llm/transformers/glm4v.py +55 -111
  48. xinference/model/llm/transformers/intern_vl.py +39 -70
  49. xinference/model/llm/transformers/internlm2.py +32 -54
  50. xinference/model/llm/transformers/minicpmv25.py +22 -55
  51. xinference/model/llm/transformers/minicpmv26.py +158 -68
  52. xinference/model/llm/transformers/omnilmm.py +5 -28
  53. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  54. xinference/model/llm/transformers/qwen2_vl.py +234 -0
  55. xinference/model/llm/transformers/qwen_vl.py +34 -86
  56. xinference/model/llm/transformers/utils.py +32 -38
  57. xinference/model/llm/transformers/yi_vl.py +32 -72
  58. xinference/model/llm/utils.py +280 -554
  59. xinference/model/llm/vllm/core.py +161 -100
  60. xinference/model/rerank/core.py +41 -8
  61. xinference/model/rerank/model_spec.json +7 -0
  62. xinference/model/rerank/model_spec_modelscope.json +7 -1
  63. xinference/model/utils.py +1 -31
  64. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  65. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  66. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  67. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  68. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  69. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  70. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  71. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  72. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  73. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  74. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  77. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  78. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  79. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  80. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  81. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  82. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  83. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  84. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  85. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  86. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  87. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  88. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  89. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  90. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  91. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
  92. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  93. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  94. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  95. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  96. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  97. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  98. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  99. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  100. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  101. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  102. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  103. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  104. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  105. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  107. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  108. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  109. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  110. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  111. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  112. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  113. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  114. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  115. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  116. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  117. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  118. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  122. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  123. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  124. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  126. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  127. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  128. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  129. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  130. xinference/thirdparty/matcha/VERSION +1 -0
  131. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  132. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  133. xinference/thirdparty/omnilmm/LICENSE +201 -0
  134. xinference/thirdparty/whisper/__init__.py +156 -0
  135. xinference/thirdparty/whisper/__main__.py +3 -0
  136. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  137. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  138. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  139. xinference/thirdparty/whisper/audio.py +157 -0
  140. xinference/thirdparty/whisper/decoding.py +826 -0
  141. xinference/thirdparty/whisper/model.py +314 -0
  142. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  143. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  144. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  145. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  146. xinference/thirdparty/whisper/timing.py +386 -0
  147. xinference/thirdparty/whisper/tokenizer.py +395 -0
  148. xinference/thirdparty/whisper/transcribe.py +605 -0
  149. xinference/thirdparty/whisper/triton_ops.py +109 -0
  150. xinference/thirdparty/whisper/utils.py +316 -0
  151. xinference/thirdparty/whisper/version.py +1 -0
  152. xinference/types.py +14 -53
  153. xinference/web/ui/build/asset-manifest.json +6 -6
  154. xinference/web/ui/build/index.html +1 -1
  155. xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
  156. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  157. xinference/web/ui/build/static/js/main.754740c0.js +3 -0
  158. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
  159. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  160. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  161. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  162. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  163. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  164. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  165. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  166. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  167. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  168. xinference/web/ui/node_modules/.package-lock.json +37 -0
  169. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  170. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  171. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  172. xinference/web/ui/package-lock.json +38 -0
  173. xinference/web/ui/package.json +1 -0
  174. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
  175. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
  176. xinference/model/llm/transformers/llama_2.py +0 -108
  177. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  178. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  179. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  180. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  181. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  182. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  183. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  184. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  185. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  186. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  187. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  188. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  189. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  190. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  191. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/model/embedding/core.py

@@ -19,6 +19,7 @@ from collections import defaultdict
 from typing import Dict, List, Literal, Optional, Tuple, Union, no_type_check
 
 import numpy as np
+import torch
 
 from ...device_utils import empty_cache
 from ...types import Embedding, EmbeddingData, EmbeddingUsage
@@ -34,7 +35,11 @@ EMBEDDING_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
 EMBEDDING_EMPTY_CACHE_COUNT = int(
     os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT", "10")
 )
+EMBEDDING_EMPTY_CACHE_TOKENS = int(
+    os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_TOKENS", "8192")
+)
 assert EMBEDDING_EMPTY_CACHE_COUNT > 0
+assert EMBEDDING_EMPTY_CACHE_TOKENS > 0
 
 
 def get_embedding_model_descriptions():
@@ -149,6 +154,25 @@ class EmbeddingModel:
             def to(self, *args, **kwargs):
                 pass
 
+        torch_dtype = None
+        if torch_dtype_str := self._kwargs.get("torch_dtype"):
+            try:
+                torch_dtype = getattr(torch, torch_dtype_str)
+                if torch_dtype not in [
+                    torch.float16,
+                    torch.float32,
+                    torch.bfloat16,
+                ]:
+                    logger.warning(
+                        f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
+                    )
+                    torch_dtype = torch.float32
+            except AttributeError:
+                logger.warning(
+                    f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
+                )
+                torch_dtype = torch.float32
+
         from ..utils import patch_trust_remote_code
 
         patch_trust_remote_code()
@@ -156,42 +180,21 @@ class EmbeddingModel:
             "gte" in self._model_spec.model_name.lower()
             and "qwen2" in self._model_spec.model_name.lower()
         ):
-            import torch
-
-            torch_dtype_str = self._kwargs.get("torch_dtype")
-            if torch_dtype_str is not None:
-                try:
-                    torch_dtype = getattr(torch, torch_dtype_str)
-                    if torch_dtype not in [
-                        torch.float16,
-                        torch.float32,
-                        torch.bfloat16,
-                    ]:
-                        logger.warning(
-                            f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
-                        )
-                        torch_dtype = torch.float32
-                except AttributeError:
-                    logger.warning(
-                        f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
-                    )
-                    torch_dtype = torch.float32
-            else:
-                torch_dtype = "auto"
+            model_kwargs = {"device_map": "auto"}
+            if torch_dtype:
+                model_kwargs["torch_dtype"] = torch_dtype
             self._model = XSentenceTransformer(
                 self._model_path,
                 device=self._device,
-                model_kwargs={"device_map": "auto", "torch_dtype": torch_dtype},
+                model_kwargs=model_kwargs,
             )
         else:
-            self._model = SentenceTransformer(self._model_path, device=self._device)
+            model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
+            self._model = SentenceTransformer(
+                self._model_path, device=self._device, model_kwargs=model_kwargs
+            )
 
     def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
-        self._counter += 1
-        if self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0:
-            logger.debug("Empty embedding cache.")
-            gc.collect()
-            empty_cache()
         from sentence_transformers import SentenceTransformer
 
         kwargs.setdefault("normalize_embeddings", True)
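The two hunks above hoist the torch_dtype handling out of the gte-Qwen2 branch so that any SentenceTransformer-based embedding model can honor a torch_dtype string passed at load time; the resolution itself is a getattr lookup with a float32 fallback. A minimal standalone sketch of that logic (the helper name is illustrative, not part of xinference):

    import torch

    def resolve_torch_dtype(name):
        # Map a dtype string such as "float16" onto a torch dtype; unknown names
        # and anything other than float16/float32/bfloat16 fall back to float32.
        if not name:
            return None
        dtype = getattr(torch, name, None)
        if dtype not in (torch.float16, torch.float32, torch.bfloat16):
            return torch.float32
        return dtype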
@@ -309,7 +312,9 @@ class EmbeddingModel:
             features = model.tokenize(sentences_batch)
             features = batch_to_device(features, device)
             features.update(extra_features)
-            all_token_nums += sum([len(f) for f in features])
+            # when batching, the attention mask 1 means there is a token
+            # thus we just sum up it to get the total number of tokens
+            all_token_nums += features["attention_mask"].sum().item()
 
             with torch.no_grad():
                 out_features = model.forward(features)
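The token-accounting fix above replaces a sum over the feature dict (which never reflected the real token count) with a sum over the attention mask, which counts exactly the non-padding tokens in the batch. For example:

    import torch

    # two padded sequences with 3 and 2 real tokens respectively
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    token_count = attention_mask.sum().item()  # 5; padding positions are excluded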
@@ -393,13 +398,29 @@ class EmbeddingModel:
         usage = EmbeddingUsage(
             prompt_tokens=all_token_nums, total_tokens=all_token_nums
         )
-        return Embedding(
+        result = Embedding(
             object="list",
             model=self._model_uid,
             data=embedding_list,
             usage=usage,
         )
 
+        # clean cache if possible
+        self._counter += 1
+        if (
+            self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0
+            or all_token_nums >= EMBEDDING_EMPTY_CACHE_TOKENS
+        ):
+            logger.debug(
+                "Empty embedding cache, calling count %s, all_token_nums %s",
+                self._counter,
+                all_token_nums,
+            )
+            gc.collect()
+            empty_cache()
+
+        return result
+
 
 def match_embedding(
     model_name: str,
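With these changes the embedding cache is no longer emptied before the request; it is emptied after the result is built, either every XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT calls or whenever a single request processed at least XINFERENCE_EMBEDDING_EMPTY_CACHE_TOKENS tokens. A condensed sketch of just the trigger condition (standalone, illustrative helper name):

    import gc
    import os

    EMPTY_CACHE_COUNT = int(os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT", "10"))
    EMPTY_CACHE_TOKENS = int(os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_TOKENS", "8192"))

    def maybe_empty_cache(counter, all_token_nums):
        # Periodic cleanup by call count, or immediate cleanup after a large request;
        # the real code also calls the device-specific empty_cache() afterwards.
        if counter % EMPTY_CACHE_COUNT == 0 or all_token_nums >= EMPTY_CACHE_TOKENS:
            gc.collect()
            return True
        return False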
xinference/model/image/core.py

@@ -47,6 +47,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
     model_hub: str = "huggingface"
     model_ability: Optional[List[str]]
     controlnet: Optional[List["ImageModelFamilyV1"]]
+    default_generate_config: Optional[dict] = {}
 
 
 class ImageModelDescription(ModelDescription):
@@ -238,7 +239,7 @@ def create_image_model_instance(
         lora_model_paths=lora_model,
         lora_load_kwargs=lora_load_kwargs,
         lora_fuse_kwargs=lora_fuse_kwargs,
-        abilities=model_spec.model_ability,
+        model_spec=model_spec,
         **kwargs,
     )
     model_description = ImageModelDescription(
xinference/model/image/model_spec.json

@@ -5,7 +5,9 @@
         "model_id": "black-forest-labs/FLUX.1-schnell",
         "model_revision": "768d12a373ed5cc9ef9a9dea7504dc09fcc14842",
         "model_ability": [
-            "text2image"
+            "text2image",
+            "image2image",
+            "inpainting"
         ]
     },
     {
@@ -14,7 +16,9 @@
         "model_id": "black-forest-labs/FLUX.1-dev",
         "model_revision": "01aa605f2c300568dd6515476f04565a954fcb59",
         "model_ability": [
-            "text2image"
+            "text2image",
+            "image2image",
+            "inpainting"
         ]
     },
     {
@@ -35,7 +39,11 @@
         "model_revision": "1681ed09e0cff58eeb41e878a49893228b78b94c",
         "model_ability": [
             "text2image"
-        ]
+        ],
+        "default_generate_config": {
+            "guidance_scale": 0.0,
+            "num_inference_steps": 1
+        }
     },
     {
         "model_name": "sdxl-turbo",
@@ -44,7 +52,11 @@
         "model_revision": "f4b0486b498f84668e828044de1d0c8ba486e05b",
         "model_ability": [
             "text2image"
-        ]
+        ],
+        "default_generate_config": {
+            "guidance_scale": 0.0,
+            "num_inference_steps": 1
+        }
     },
     {
         "model_name": "stable-diffusion-v1.5",
xinference/model/image/model_spec_modelscope.json

@@ -6,7 +6,9 @@
         "model_id": "AI-ModelScope/FLUX.1-schnell",
         "model_revision": "master",
         "model_ability": [
-            "text2image"
+            "text2image",
+            "image2image",
+            "inpainting"
         ]
     },
     {
@@ -16,7 +18,9 @@
         "model_id": "AI-ModelScope/FLUX.1-dev",
         "model_revision": "master",
         "model_ability": [
-            "text2image"
+            "text2image",
+            "image2image",
+            "inpainting"
         ]
     },
     {
@@ -39,7 +43,11 @@
         "model_revision": "master",
         "model_ability": [
             "text2image"
-        ]
+        ],
+        "default_generate_config": {
+            "guidance_scale": 0.0,
+            "num_inference_steps": 1
+        }
     },
     {
         "model_name": "sdxl-turbo",
@@ -49,7 +57,11 @@
         "model_revision": "master",
         "model_ability": [
             "text2image"
-        ]
+        ],
+        "default_generate_config": {
+            "guidance_scale": 0.0,
+            "num_inference_steps": 1
+        }
     },
     {
         "model_name": "stable-diffusion-v1.5",
xinference/model/image/sdapi.py

@@ -0,0 +1,136 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import io
+import warnings
+
+from PIL import Image
+
+
+class SDAPIToDiffusersConverter:
+    txt2img_identical_args = {
+        "prompt",
+        "negative_prompt",
+        "seed",
+        "width",
+        "height",
+        "sampler_name",
+    }
+    txt2img_arg_mapping = {
+        "steps": "num_inference_steps",
+        "cfg_scale": "guidance_scale",
+        # "denoising_strength": "strength",
+    }
+    img2img_identical_args = {
+        "prompt",
+        "negative_prompt",
+        "seed",
+        "width",
+        "height",
+        "sampler_name",
+    }
+    img2img_arg_mapping = {
+        "init_images": "image",
+        "steps": "num_inference_steps",
+        "cfg_scale": "guidance_scale",
+        "denoising_strength": "strength",
+    }
+
+    @staticmethod
+    def convert_to_diffusers(sd_type: str, params: dict) -> dict:
+        diffusers_params = {}
+
+        identical_args = getattr(SDAPIToDiffusersConverter, f"{sd_type}_identical_args")
+        mapping_args = getattr(SDAPIToDiffusersConverter, f"{sd_type}_arg_mapping")
+        for param, value in params.items():
+            if param in identical_args:
+                diffusers_params[param] = value
+            elif param in mapping_args:
+                diffusers_params[mapping_args[param]] = value
+            else:
+                raise ValueError(f"Unknown arg: {param}")
+
+        return diffusers_params
+
+    @staticmethod
+    def get_available_args(sd_type: str) -> set:
+        identical_args = getattr(SDAPIToDiffusersConverter, f"{sd_type}_identical_args")
+        mapping_args = getattr(SDAPIToDiffusersConverter, f"{sd_type}_arg_mapping")
+        return identical_args.union(mapping_args)
+
+
+class SDAPIDiffusionModelMixin:
+    @staticmethod
+    def _check_kwargs(sd_type: str, kwargs: dict):
+        available_args = SDAPIToDiffusersConverter.get_available_args(sd_type)
+        unknown_args = []
+        available_kwargs = {}
+        for arg, value in kwargs.items():
+            if arg in available_args:
+                available_kwargs[arg] = value
+            else:
+                unknown_args.append(arg)
+        if unknown_args:
+            warnings.warn(
+                f"Some args are not supported for now and will be ignored: {unknown_args}"
+            )
+
+        converted_kwargs = SDAPIToDiffusersConverter.convert_to_diffusers(
+            sd_type, available_kwargs
+        )
+
+        width, height = converted_kwargs.pop("width", None), converted_kwargs.pop(
+            "height", None
+        )
+        if width and height:
+            converted_kwargs["size"] = f"{width}*{height}"
+
+        return converted_kwargs
+
+    def txt2img(self, **kwargs):
+        converted_kwargs = self._check_kwargs("txt2img", kwargs)
+        result = self.text_to_image(response_format="b64_json", **converted_kwargs)  # type: ignore
+
+        # convert to SD API result
+        return {
+            "images": [r["b64_json"] for r in result["data"]],
+            "info": {"created": result["created"]},
+            "parameters": {},
+        }
+
+    @staticmethod
+    def _decode_b64_img(img_str: str) -> Image:
+        # img_str in a format: "data:image/png;base64," + raw_b64_img(image)
+        f, data = img_str.split(",", 1)
+        f, encode_type = f.split(";", 1)
+        assert encode_type == "base64"
+        f = f.split("/", 1)[1]
+        b = base64.b64decode(data)
+        return Image.open(io.BytesIO(b), formats=[f])
+
+    def img2img(self, **kwargs):
+        init_images = kwargs.pop("init_images", [])
+        kwargs["init_images"] = [self._decode_b64_img(i) for i in init_images]
+        clip_skip = kwargs.get("override_settings", {}).get("clip_skip")
+        converted_kwargs = self._check_kwargs("img2img", kwargs)
+        if clip_skip:
+            converted_kwargs["clip_skip"] = clip_skip
+        result = self.image_to_image(response_format="b64_json", **converted_kwargs)  # type: ignore
+
+        # convert to SD API result
+        return {
+            "images": [r["b64_json"] for r in result["data"]],
+            "info": {"created": result["created"]},
+            "parameters": {},
+        }
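The new sdapi.py gives DiffusionModel an Automatic1111-compatible txt2img/img2img surface by translating SD web UI request fields into diffusers keyword arguments and converting the result back into the SD API response shape. A small usage sketch of the converter defined above (the payload values are illustrative):

    params = {"prompt": "a cat", "steps": 4, "cfg_scale": 0.0, "sampler_name": "Euler a"}
    diffusers_kwargs = SDAPIToDiffusersConverter.convert_to_diffusers("txt2img", params)
    # -> {"prompt": "a cat", "num_inference_steps": 4, "guidance_scale": 0.0,
    #     "sampler_name": "Euler a"}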
xinference/model/image/stable_diffusion/core.py

@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import base64
+import contextlib
+import inspect
 import logging
 import os
 import re
@@ -22,19 +24,43 @@ import uuid
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from io import BytesIO
-from typing import Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import PIL.Image
+import torch
 from PIL import ImageOps
 
 from ....constants import XINFERENCE_IMAGE_DIR
 from ....device_utils import move_model_to_available_device
 from ....types import Image, ImageList, LoRA
+from ..sdapi import SDAPIDiffusionModelMixin
 
-logger = logging.getLogger(__name__)
+if TYPE_CHECKING:
+    from ..core import ImageModelFamilyV1
 
+logger = logging.getLogger(__name__)
 
-class DiffusionModel:
+SAMPLING_METHODS = [
+    "default",
+    "DPM++ 2M",
+    "DPM++ 2M Karras",
+    "DPM++ 2M SDE",
+    "DPM++ 2M SDE Karras",
+    "DPM++ SDE",
+    "DPM++ SDE Karras",
+    "DPM2",
+    "DPM2 Karras",
+    "DPM2 a",
+    "DPM2 a Karras",
+    "Euler",
+    "Euler a",
+    "Heun",
+    "LMS",
+    "LMS Karras",
+]
+
+
+class DiffusionModel(SDAPIDiffusionModelMixin):
     def __init__(
         self,
         model_uid: str,
@@ -43,7 +69,7 @@ class DiffusionModel:
         lora_model: Optional[List[LoRA]] = None,
         lora_load_kwargs: Optional[Dict] = None,
         lora_fuse_kwargs: Optional[Dict] = None,
-        abilities: Optional[List[str]] = None,
+        model_spec: Optional["ImageModelFamilyV1"] = None,
         **kwargs,
     ):
         self._model_uid = model_uid
@@ -59,7 +85,8 @@ class DiffusionModel:
         self._lora_model = lora_model
         self._lora_load_kwargs = lora_load_kwargs or {}
         self._lora_fuse_kwargs = lora_fuse_kwargs or {}
-        self._abilities = abilities or []
+        self._model_spec = model_spec
+        self._abilities = model_spec.model_ability or []  # type: ignore
         self._kwargs = kwargs
 
     @property
@@ -80,8 +107,6 @@ class DiffusionModel:
         logger.info(f"Successfully loaded the LoRA for model {self._model_uid}.")
 
     def load(self):
-        import torch
-
         if "text2image" in self._abilities or "image2image" in self._abilities:
             from diffusers import AutoPipelineForText2Image as AutoPipelineModel
         elif "inpainting" in self._abilities:
@@ -143,7 +168,9 @@ class DiffusionModel:
             self._kwargs[text_encoder_name] = text_encoder
             self._kwargs["device_map"] = "balanced"
 
-        logger.debug("Loading model %s", AutoPipelineModel)
+        logger.debug(
+            "Loading model from %s, kwargs: %s", self._model_path, self._kwargs
+        )
         self._model = AutoPipelineModel.from_pretrained(
             self._model_path,
             **self._kwargs,
@@ -158,6 +185,89 @@ class DiffusionModel:
             self._model.enable_attention_slicing()
         self._apply_lora()
 
+    @staticmethod
+    def _get_scheduler(model: Any, sampler_name: str):
+        if not sampler_name:
+            return
+
+        assert model is not None
+
+        import diffusers
+
+        # see https://github.com/huggingface/diffusers/issues/4167
+        # to get A1111 <> Diffusers Scheduler mapping
+        if sampler_name == "DPM++ 2M":
+            return diffusers.DPMSolverMultistepScheduler.from_config(
+                model.scheduler.config
+            )
+        elif sampler_name == "DPM++ 2M Karras":
+            return diffusers.DPMSolverMultistepScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        elif sampler_name == "DPM++ 2M SDE":
+            return diffusers.DPMSolverMultistepScheduler.from_config(
+                model.scheduler.config, algorithm_type="sde-dpmsolver++"
+            )
+        elif sampler_name == "DPM++ 2M SDE Karras":
+            return diffusers.DPMSolverMultistepScheduler.from_config(
+                model.scheduler.config,
+                algorithm_type="sde-dpmsolver++",
+                use_karras_sigmas=True,
+            )
+        elif sampler_name == "DPM++ SDE":
+            return diffusers.DPMSolverSinglestepScheduler.from_config(
+                model.scheduler.config
+            )
+        elif sampler_name == "DPM++ SDE Karras":
+            return diffusers.DPMSolverSinglestepScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        elif sampler_name == "DPM2":
+            return diffusers.KDPM2DiscreteScheduler.from_config(model.scheduler.config)
+        elif sampler_name == "DPM2 Karras":
+            return diffusers.KDPM2DiscreteScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        elif sampler_name == "DPM2 a":
+            return diffusers.KDPM2AncestralDiscreteScheduler.from_config(
+                model.scheduler.config
+            )
+        elif sampler_name == "DPM2 a Karras":
+            return diffusers.KDPM2AncestralDiscreteScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        elif sampler_name == "Euler":
+            return diffusers.EulerDiscreteScheduler.from_config(model.scheduler.config)
+        elif sampler_name == "Euler a":
+            return diffusers.EulerAncestralDiscreteScheduler.from_config(
+                model.scheduler.config
+            )
+        elif sampler_name == "Heun":
+            return diffusers.HeunDiscreteScheduler.from_config(model.scheduler.config)
+        elif sampler_name == "LMS":
+            return diffusers.LMSDiscreteScheduler.from_config(model.scheduler.config)
+        elif sampler_name == "LMS Karras":
+            return diffusers.LMSDiscreteScheduler.from_config(
+                model.scheduler.config, use_karras_sigmas=True
+            )
+        else:
+            raise ValueError(f"Unknown sampler: {sampler_name}")
+
+    @staticmethod
+    @contextlib.contextmanager
+    def _reset_when_done(model: Any, sampler_name: str):
+        assert model is not None
+        scheduler = DiffusionModel._get_scheduler(model, sampler_name)
+        if scheduler:
+            default_scheduler = model.scheduler
+            model.scheduler = scheduler
+            try:
+                yield
+            finally:
+                model.scheduler = default_scheduler
+        else:
+            yield
+
     def _call_model(
         self,
         response_format: str,
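The _get_scheduler/_reset_when_done pair above maps Automatic1111 sampler names onto diffusers schedulers and installs the chosen scheduler only for the duration of a single generation call. The same pattern in isolation, assuming a loaded diffusers pipeline (the helper name is illustrative):

    from diffusers import DPMSolverMultistepScheduler

    def generate_with_dpmpp_2m(pipe, prompt, use_karras=True):
        # Temporarily swap in a DPM++ 2M (Karras) scheduler, then restore the original.
        original = pipe.scheduler
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config, use_karras_sigmas=use_karras
        )
        try:
            return pipe(prompt).images
        finally:
            pipe.scheduler = original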
@@ -168,13 +278,27 @@ class DiffusionModel:
 
         from ....device_utils import empty_cache
 
-        logger.debug(
-            "stable diffusion args: %s",
-            kwargs,
-        )
         model = model if model is not None else self._model
+        is_padded = kwargs.pop("is_padded", None)
+        origin_size = kwargs.pop("origin_size", None)
+        seed = kwargs.pop("seed", None)
+        if seed is not None:
+            kwargs["generator"] = generator = torch.Generator(device=self._model.device)  # type: ignore
+            if seed != -1:
+                kwargs["generator"] = generator.manual_seed(seed)
+        sampler_name = kwargs.pop("sampler_name", None)
         assert callable(model)
-        images = model(**kwargs).images
+        with self._reset_when_done(model, sampler_name):
+            logger.debug("stable diffusion args: %s, model: %s", kwargs, model)
+            images = model(**kwargs).images
+
+        # revert padding if padded
+        if is_padded and origin_size:
+            new_images = []
+            x, y = origin_size
+            for img in images:
+                new_images.append(img.crop((0, 0, x, y)))
+            images = new_images
 
         # clean cache
         gc.collect()
198
322
 
199
323
  with ThreadPoolExecutor() as executor:
200
324
  results = list(map(partial(executor.submit, _gen_base64_image), images)) # type: ignore
201
- image_list = [Image(url=None, b64_json=s.result()) for s in results]
325
+ image_list = [Image(url=None, b64_json=s.result()) for s in results] # type: ignore
202
326
  return ImageList(created=int(time.time()), data=image_list)
203
327
  else:
204
328
  raise ValueError(f"Unsupported response format: {response_format}")
@@ -220,14 +344,16 @@ class DiffusionModel:
220
344
  # References:
221
345
  # https://huggingface.co/docs/diffusers/main/en/api/pipelines/controlnet_sdxl
222
346
  width, height = map(int, re.split(r"[^\d]+", size))
223
- self._filter_kwargs(kwargs)
347
+ generate_kwargs = self._model_spec.default_generate_config.copy() # type: ignore
348
+ generate_kwargs.update({k: v for k, v in kwargs.items() if v is not None})
349
+ self._filter_kwargs(generate_kwargs)
224
350
  return self._call_model(
225
351
  prompt=prompt,
226
352
  height=height,
227
353
  width=width,
228
354
  num_images_per_prompt=n,
229
355
  response_format=response_format,
230
- **kwargs,
356
+ **generate_kwargs,
231
357
  )
232
358
 
233
359
  @staticmethod
@@ -265,6 +391,9 @@ class DiffusionModel:
265
391
  if padding_image_to_multiple := kwargs.pop("padding_image_to_multiple", None):
266
392
  # Model like SD3 image to image requires image's height and width is times of 16
267
393
  # padding the image if specified
394
+ origin_x, origin_y = image.size
395
+ kwargs["origin_size"] = (origin_x, origin_y)
396
+ kwargs["is_padded"] = True
268
397
  image = self.pad_to_multiple(image, multiple=int(padding_image_to_multiple))
269
398
 
270
399
  if size:
@@ -273,12 +402,24 @@ class DiffusionModel:
273
402
  width, height = image.size
274
403
  kwargs["width"] = width
275
404
  kwargs["height"] = height
276
-
405
+ else:
406
+ # SD3 image2image cannot accept width and height
407
+ parameters = inspect.signature(model.__call__).parameters # type: ignore
408
+ allow_width_height = False
409
+ for param in parameters.values():
410
+ if param.kind == inspect.Parameter.VAR_KEYWORD:
411
+ allow_width_height = True
412
+ break
413
+ if "width" in parameters or "height" in parameters:
414
+ allow_width_height = True
415
+ if allow_width_height:
416
+ kwargs["width"], kwargs["height"] = image.size
417
+
418
+ kwargs["negative_prompt"] = negative_prompt
277
419
  self._filter_kwargs(kwargs)
278
420
  return self._call_model(
279
421
  image=image,
280
422
  prompt=prompt,
281
- negative_prompt=negative_prompt,
282
423
  num_images_per_prompt=n,
283
424
  response_format=response_format,
284
425
  model=model,
@@ -318,6 +459,9 @@ class DiffusionModel:
318
459
  if padding_image_to_multiple := kwargs.pop("padding_image_to_multiple", None):
319
460
  # Model like SD3 inpainting requires image's height and width is times of 16
320
461
  # padding the image if specified
462
+ origin_x, origin_y = image.size
463
+ kwargs["origin_size"] = (origin_x, origin_y)
464
+ kwargs["is_padded"] = True
321
465
  image = self.pad_to_multiple(image, multiple=int(padding_image_to_multiple))
322
466
  mask_image = self.pad_to_multiple(
323
467
  mask_image, multiple=int(padding_image_to_multiple)
@@ -325,11 +469,12 @@ class DiffusionModel:
325
469
  # calculate actual image size after padding
326
470
  width, height = image.size
327
471
 
472
+ kwargs["negative_prompt"] = negative_prompt
473
+ self._filter_kwargs(kwargs)
328
474
  return self._call_model(
329
475
  image=image,
330
476
  mask_image=mask_image,
331
477
  prompt=prompt,
332
- negative_prompt=negative_prompt,
333
478
  height=height,
334
479
  width=width,
335
480
  num_images_per_prompt=n,