xinference 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of xinference might be problematic.

Files changed (104)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +23 -1
  4. xinference/core/model.py +1 -6
  5. xinference/core/utils.py +10 -6
  6. xinference/model/audio/core.py +5 -0
  7. xinference/model/audio/cosyvoice.py +25 -3
  8. xinference/model/audio/f5tts.py +15 -10
  9. xinference/model/audio/f5tts_mlx.py +260 -0
  10. xinference/model/audio/fish_speech.py +35 -111
  11. xinference/model/audio/model_spec.json +19 -3
  12. xinference/model/audio/model_spec_modelscope.json +9 -0
  13. xinference/model/audio/utils.py +32 -0
  14. xinference/model/image/core.py +69 -1
  15. xinference/model/image/model_spec.json +127 -4
  16. xinference/model/image/model_spec_modelscope.json +130 -4
  17. xinference/model/image/stable_diffusion/core.py +45 -13
  18. xinference/model/llm/llm_family.json +47 -0
  19. xinference/model/llm/llm_family.py +15 -36
  20. xinference/model/llm/llm_family_modelscope.json +49 -0
  21. xinference/model/llm/mlx/core.py +68 -13
  22. xinference/model/llm/transformers/core.py +1 -0
  23. xinference/model/llm/transformers/qwen2_vl.py +2 -0
  24. xinference/model/llm/utils.py +1 -0
  25. xinference/model/llm/vllm/core.py +11 -2
  26. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  27. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  28. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  29. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  30. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  31. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  32. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  33. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  34. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  35. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  36. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  37. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  38. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  39. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  40. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  41. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  42. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  43. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  44. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  45. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  46. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  47. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  48. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  49. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  50. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  51. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  52. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  53. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  54. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  55. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  56. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  57. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  58. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  59. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  60. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  61. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  62. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  63. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  64. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  65. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  66. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  67. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  68. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  69. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  70. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  71. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  72. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  73. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  74. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  75. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  76. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  77. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  78. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  79. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  80. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  81. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  82. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  83. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  84. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  85. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  86. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  87. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  88. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  89. xinference/thirdparty/matcha/utils/utils.py +2 -2
  90. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/METADATA +11 -6
  91. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/RECORD +95 -74
  92. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  93. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  94. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  95. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  96. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  97. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  98. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  99. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  100. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  101. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
  102. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
  103. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
  104. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/model/audio/fish_speech.py

```diff
@@ -11,10 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import gc
 import logging
 import os.path
-import queue
 import sys
 from io import BytesIO
 from typing import TYPE_CHECKING, Optional
```
```diff
@@ -60,6 +58,7 @@ class FishSpeechModel:
         self._device = device
         self._llama_queue = None
         self._model = None
+        self._engine = None
         self._kwargs = kwargs
 
     @property
```
```diff
@@ -72,6 +71,7 @@ class FishSpeechModel:
             0, os.path.join(os.path.dirname(__file__), "../../thirdparty/fish_speech")
         )
 
+        from tools.inference_engine import TTSInferenceEngine
         from tools.llama.generate import launch_thread_safe_queue
         from tools.vqgan.inference import load_model as load_decoder_model
 
```
```diff
@@ -81,6 +81,11 @@ class FishSpeechModel:
         if not is_device_available(self._device):
            raise ValueError(f"Device {self._device} is not available!")
 
+        # https://github.com/pytorch/pytorch/issues/129207
+        if self._device == "mps":
+            logger.warning("The Conv1d has bugs on MPS backend, fallback to CPU.")
+            self._device = "cpu"
+
         enable_compile = self._kwargs.get("compile", False)
         precision = self._kwargs.get("precision", torch.bfloat16)
         logger.info("Loading Llama model, compile=%s...", enable_compile)
```
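The new guard works around pytorch issue 129207, where Conv1d produces wrong results on Apple's MPS backend. A minimal standalone sketch of the same policy; `resolve_device` is a hypothetical helper for illustration, not part of this diff:

```python
import torch

def resolve_device(requested: str) -> str:
    # Refuse a device that is not actually usable on this machine.
    if requested == "mps" and not torch.backends.mps.is_available():
        raise ValueError("Device mps is not available!")
    if requested == "mps":
        # Conv1d hits pytorch#129207 on MPS, so load on CPU instead.
        return "cpu"
    return requested
```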
```diff
@@ -102,102 +107,10 @@ class FishSpeechModel:
             device=self._device,
         )
 
-    @torch.inference_mode()
-    def _inference(
-        self,
-        text,
-        enable_reference_audio,
-        reference_audio,
-        reference_text,
-        max_new_tokens,
-        chunk_length,
-        top_p,
-        repetition_penalty,
-        temperature,
-        seed="0",
-        streaming=False,
-    ):
-        from fish_speech.utils import autocast_exclude_mps, set_seed
-        from tools.api import decode_vq_tokens, encode_reference
-        from tools.llama.generate import (
-            GenerateRequest,
-            GenerateResponse,
-            WrappedGenerateResponse,
-        )
-
-        seed = int(seed)
-        if seed != 0:
-            set_seed(seed)
-            logger.warning(f"set seed: {seed}")
-
-        # Parse reference audio aka prompt
-        prompt_tokens = encode_reference(
-            decoder_model=self._model,
-            reference_audio=reference_audio,
-            enable_reference_audio=enable_reference_audio,
-        )
-
-        # LLAMA Inference
-        request = dict(
-            device=self._model.device,
-            max_new_tokens=max_new_tokens,
-            text=text,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            temperature=temperature,
-            compile=self._kwargs.get("compile", False),
-            iterative_prompt=chunk_length > 0,
-            chunk_length=chunk_length,
-            max_length=2048,
-            prompt_tokens=prompt_tokens if enable_reference_audio else None,
-            prompt_text=reference_text if enable_reference_audio else None,
-        )
-
-        response_queue = queue.Queue()
-        self._llama_queue.put(
-            GenerateRequest(
-                request=request,
-                response_queue=response_queue,
-            )
+        self._engine = TTSInferenceEngine(
+            self._llama_queue, self._model, precision, enable_compile
         )
 
-        segments = []
-
-        while True:
-            result: WrappedGenerateResponse = response_queue.get()
-            if result.status == "error":
-                raise result.response
-
-            result: GenerateResponse = result.response
-            if result.action == "next":
-                break
-
-            with autocast_exclude_mps(
-                device_type=self._model.device.type,
-                dtype=self._kwargs.get("precision", torch.bfloat16),
-            ):
-                fake_audios = decode_vq_tokens(
-                    decoder_model=self._model,
-                    codes=result.codes,
-                )
-
-            fake_audios = fake_audios.float().cpu().numpy()
-            segments.append(fake_audios)
-
-            if streaming:
-                yield fake_audios, None, None
-
-        if len(segments) == 0:
-            raise Exception("No audio generated, please check the input text.")
-
-        # No matter streaming or not, we need to return the final audio
-        audio = np.concatenate(segments, axis=0)
-        yield None, (self._model.spec_transform.sample_rate, audio), None
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
-
     def speech(
         self,
         input: str,
```
```diff
@@ -211,22 +124,31 @@ class FishSpeechModel:
         if speed != 1.0:
             logger.warning("Fish speech does not support setting speed: %s.", speed)
         import torchaudio
+        from tools.schema import ServeReferenceAudio, ServeTTSRequest
 
         prompt_speech = kwargs.get("prompt_speech")
         prompt_text = kwargs.get("prompt_text", kwargs.get("reference_text", ""))
-        result = self._inference(
-            text=input,
-            enable_reference_audio=kwargs.get(
-                "enable_reference_audio", prompt_speech is not None
-            ),
-            reference_audio=prompt_speech,
-            reference_text=prompt_text,
-            max_new_tokens=kwargs.get("max_new_tokens", 1024),
-            chunk_length=kwargs.get("chunk_length", 200),
-            top_p=kwargs.get("top_p", 0.7),
-            repetition_penalty=kwargs.get("repetition_penalty", 1.2),
-            temperature=kwargs.get("temperature", 0.7),
-            streaming=stream,
+        if prompt_speech is not None:
+            r = ServeReferenceAudio(audio=prompt_speech, text=prompt_text)
+            references = [r]
+        else:
+            references = []
+
+        assert self._engine is not None
+        result = self._engine.inference(
+            ServeTTSRequest(
+                text=input,
+                references=references,
+                reference_id=kwargs.get("reference_id"),
+                seed=kwargs.get("seed"),
+                max_new_tokens=kwargs.get("max_new_tokens", 1024),
+                chunk_length=kwargs.get("chunk_length", 200),
+                top_p=kwargs.get("top_p", 0.7),
+                repetition_penalty=kwargs.get("repetition_penalty", 1.2),
+                temperature=kwargs.get("temperature", 0.7),
+                streaming=stream,
+                format=response_format,
+            )
         )
 
         if stream:
```
```diff
@@ -242,7 +164,9 @@ class FishSpeechModel:
                 last_pos = 0
                 with writer.open():
                     for chunk in result:
-                        chunk = chunk[0]
+                        if chunk.code == "final":
+                            continue
+                        chunk = chunk.audio[1]
                         if chunk is not None:
                             chunk = chunk.reshape((chunk.shape[0], 1))
                             trans_chunk = torch.from_numpy(chunk)
```
```diff
@@ -257,7 +181,7 @@ class FishSpeechModel:
             return _stream_generator()
         else:
             result = list(result)
-            sample_rate, audio = result[0][1]
+            sample_rate, audio = result[0].audio
             audio = np.array([audio])
 
             # Save the generated audio
```
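Taken together, these hunks delegate generation to fish-speech's upstream `TTSInferenceEngine`: seed handling, reference-audio encoding, and VQ decoding now live in the vendored engine instead of being duplicated in xinference. A hedged sketch of the new call path, using only names visible in this diff; it assumes the vendored `thirdparty/fish_speech` directory is on `sys.path`, and `engine` stands for the `TTSInferenceEngine` built in `load()`:

```python
from tools.schema import ServeReferenceAudio, ServeTTSRequest

request = ServeTTSRequest(
    text="Hello from FishSpeech-1.5",
    references=[],  # or [ServeReferenceAudio(audio=prompt_bytes, text=prompt_text)]
    max_new_tokens=1024,
    chunk_length=200,
    top_p=0.7,
    repetition_penalty=1.2,
    temperature=0.7,
    streaming=False,
    format="wav",
)

for chunk in engine.inference(request):
    if chunk.code == "final":         # sentinel item with no per-segment audio
        continue
    sample_rate, audio = chunk.audio  # audio is a numpy waveform segment
```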
xinference/model/audio/model_spec.json

```diff
@@ -236,10 +236,18 @@
     "multilingual": true
   },
   {
-    "model_name": "FishSpeech-1.4",
+    "model_name": "CosyVoice2-0.5B",
+    "model_family": "CosyVoice",
+    "model_id": "mrfakename/CosyVoice2-0.5B",
+    "model_revision": "5676baabc8a76dc93ef60a88bbd2420deaa2f644",
+    "model_ability": "text-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "FishSpeech-1.5",
     "model_family": "FishAudio",
-    "model_id": "fishaudio/fish-speech-1.4",
-    "model_revision": "069c573759936b35191d3380deb89183c0656f59",
+    "model_id": "fishaudio/fish-speech-1.5",
+    "model_revision": "268b6ec86243dd683bc78dab7e9a6cedf9191f2a",
     "model_ability": "text-to-audio",
     "multilingual": true
   },
```
```diff
@@ -250,5 +258,13 @@
     "model_revision": "4dcc16f297f2ff98a17b3726b16f5de5a5e45672",
     "model_ability": "text-to-audio",
     "multilingual": true
+  },
+  {
+    "model_name": "F5-TTS-MLX",
+    "model_family": "F5-TTS-MLX",
+    "model_id": "lucasnewman/f5-tts-mlx",
+    "model_revision": "7642bb232e3fcacf92c51c786edebb8624da6b93",
+    "model_ability": "text-to-audio",
+    "multilingual": true
   }
 ]
```
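The two hunks above register CosyVoice2-0.5B and F5-TTS-MLX and bump FishSpeech 1.4 to 1.5. A hedged sketch of launching one of the new entries through the RESTful client (the endpoint URL and output path are illustrative, not part of this diff):

```python
from xinference.client import Client

client = Client("http://localhost:9997")
uid = client.launch_model(model_name="FishSpeech-1.5", model_type="audio")
model = client.get_model(uid)

# speech() returns encoded audio bytes (mp3 by default).
audio_bytes = model.speech("Xinference 1.1.1 ships FishSpeech 1.5.")
with open("out.mp3", "wb") as f:
    f.write(audio_bytes)
```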
xinference/model/audio/model_spec_modelscope.json

```diff
@@ -74,6 +74,15 @@
     "model_ability": "text-to-audio",
     "multilingual": true
   },
+  {
+    "model_name": "CosyVoice2-0.5B",
+    "model_family": "CosyVoice",
+    "model_hub": "modelscope",
+    "model_id": "iic/CosyVoice2-0.5B",
+    "model_revision": "master",
+    "model_ability": "text-to-audio",
+    "multilingual": true
+  },
   {
     "model_name": "F5-TTS",
     "model_family": "F5-TTS",
```
xinference/model/audio/utils.py

```diff
@@ -11,8 +11,40 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import io
+
+import numpy as np
+
 from .core import AudioModelFamilyV1
 
 
 def get_model_version(audio_model: AudioModelFamilyV1) -> str:
     return audio_model.model_name
+
+
+def ensure_sample_rate(
+    audio: np.ndarray, old_sample_rate: int, sample_rate: int
+) -> np.ndarray:
+    import soundfile as sf
+    from scipy.signal import resample
+
+    if old_sample_rate != sample_rate:
+        # Calculate the new data length
+        new_length = int(len(audio) * sample_rate / old_sample_rate)
+
+        # Resample the data
+        resampled_data = resample(audio, new_length)
+
+        # Use BytesIO to save the resampled data to memory
+        with io.BytesIO() as buffer:
+            # Write the resampled data to the memory buffer
+            sf.write(buffer, resampled_data, sample_rate, format="WAV")
+
+            # Reset the buffer position to the beginning
+            buffer.seek(0)
+
+            # Read the data from the memory buffer
+            audio, sr = sf.read(buffer, dtype="float32")
+
+    return audio
```
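A quick usage sketch of the new helper (the input array is invented for illustration). Note the round trip through an in-memory WAV, which normalizes the resampled signal back to float32:

```python
import numpy as np
from xinference.model.audio.utils import ensure_sample_rate

one_second = np.zeros(22050, dtype=np.float32)  # 1 s of silence at 22.05 kHz
resampled = ensure_sample_rate(one_second, old_sample_rate=22050, sample_rate=16000)
assert resampled.shape[0] == 16000  # still 1 s of audio, now at 16 kHz
```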
xinference/model/image/core.py

```diff
@@ -22,7 +22,12 @@ from typing import Dict, List, Literal, Optional, Tuple, Union
 from ...constants import XINFERENCE_CACHE_DIR
 from ...types import PeftModelConfig
 from ..core import CacheableModelSpec, ModelDescription
-from ..utils import valid_model_revision
+from ..utils import (
+    IS_NEW_HUGGINGFACE_HUB,
+    retry_download,
+    symlink_local_file,
+    valid_model_revision,
+)
 from .ocr.got_ocr2 import GotOCR2Model
 from .stable_diffusion.core import DiffusionModel
 from .stable_diffusion.mlx import MLXDiffusionModel
```
```diff
@@ -51,6 +56,9 @@ class ImageModelFamilyV1(CacheableModelSpec):
     controlnet: Optional[List["ImageModelFamilyV1"]]
     default_model_config: Optional[dict] = {}
     default_generate_config: Optional[dict] = {}
+    gguf_model_id: Optional[str]
+    gguf_quantizations: Optional[List[str]]
+    gguf_model_file_name_template: Optional[str]
 
 
 class ImageModelDescription(ModelDescription):
```
```diff
@@ -187,6 +195,61 @@ def get_cache_status(
     return valid_model_revision(meta_path, model_spec.model_revision)
 
 
+def cache_gguf(spec: ImageModelFamilyV1, quantization: Optional[str] = None):
+    if not quantization:
+        return
+
+    cache_dir = os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, spec.model_name))
+    if not os.path.exists(cache_dir):
+        os.makedirs(cache_dir, exist_ok=True)
+
+    if not spec.gguf_model_file_name_template:
+        raise NotImplementedError(
+            f"{spec.model_name} does not support GGUF quantization"
+        )
+    if quantization not in (spec.gguf_quantizations or []):
+        raise ValueError(
+            f"Cannot support quantization {quantization}, "
+            f"available quantizations: {spec.gguf_quantizations}"
+        )
+
+    filename = spec.gguf_model_file_name_template.format(quantization=quantization)  # type: ignore
+    full_path = os.path.join(cache_dir, filename)
+
+    if spec.model_hub == "huggingface":
+        import huggingface_hub
+
+        use_symlinks = {}
+        if not IS_NEW_HUGGINGFACE_HUB:
+            use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
+        download_file_path = retry_download(
+            huggingface_hub.hf_hub_download,
+            spec.model_name,
+            None,
+            spec.gguf_model_id,
+            filename=filename,
+            **use_symlinks,
+        )
+        if IS_NEW_HUGGINGFACE_HUB:
+            symlink_local_file(download_file_path, cache_dir, filename)
+    elif spec.model_hub == "modelscope":
+        from modelscope.hub.file_download import model_file_download
+
+        download_file_path = retry_download(
+            model_file_download,
+            spec.model_name,
+            None,
+            spec.gguf_model_id,
+            filename,
+            revision=spec.model_revision,
+        )
+        symlink_local_file(download_file_path, cache_dir, filename)
+    else:
+        raise NotImplementedError
+
+    return full_path
+
+
 def create_ocr_model_instance(
     subpool_addr: str,
     devices: List[str],
```
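To make the new GGUF plumbing concrete: `cache_gguf` resolves the file name from the spec's template, validates the requested quantization against `gguf_quantizations`, downloads the file from `gguf_model_id`, and links it into the cache. A small sketch with values taken from the FLUX.1-dev entry added below (the cache location is deployment-specific):

```python
# Spec fields from model_spec.json (see the hunks below):
#   gguf_model_id                 = "city96/FLUX.1-dev-gguf"
#   gguf_model_file_name_template = "flux1-dev-{quantization}.gguf"
filename = "flux1-dev-{quantization}.gguf".format(quantization="Q4_0")
print(filename)  # flux1-dev-Q4_0.gguf
# cache_gguf(spec, "Q4_0") fetches that file from city96/FLUX.1-dev-gguf and
# returns its full path under <XINFERENCE_CACHE_DIR>/FLUX.1-dev/.
```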
```diff
@@ -219,6 +282,8 @@ def create_image_model_instance(
         Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
     ] = None,
     model_path: Optional[str] = None,
+    gguf_quantization: Optional[str] = None,
+    gguf_model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[
     Union[DiffusionModel, MLXDiffusionModel, GotOCR2Model], ImageModelDescription
```
```diff
@@ -272,6 +337,8 @@ def create_image_model_instance(
     ]
     if not model_path:
         model_path = cache(model_spec)
+    if not gguf_model_path and gguf_quantization:
+        gguf_model_path = cache_gguf(model_spec, gguf_quantization)
     if peft_model_config is not None:
         lora_model = peft_model_config.peft_model
         lora_load_kwargs = peft_model_config.image_lora_load_kwargs
```
```diff
@@ -298,6 +365,7 @@ def create_image_model_instance(
         lora_load_kwargs=lora_load_kwargs,
         lora_fuse_kwargs=lora_fuse_kwargs,
         model_spec=model_spec,
+        gguf_model_path=gguf_model_path,
         **kwargs,
     )
     model_description = ImageModelDescription(
```
xinference/model/image/model_spec.json

```diff
@@ -11,8 +11,24 @@
     ],
     "default_model_config": {
       "quantize": true,
-      "quantize_text_encoder": "text_encoder_2"
-    }
+      "quantize_text_encoder": "text_encoder_2",
+      "torch_dtype": "bfloat16"
+    },
+    "gguf_model_id": "city96/FLUX.1-schnell-gguf",
+    "gguf_quantizations": [
+      "F16",
+      "Q2_K",
+      "Q3_K_S",
+      "Q4_0",
+      "Q4_1",
+      "Q4_K_S",
+      "Q5_0",
+      "Q5_1",
+      "Q5_K_S",
+      "Q6_K",
+      "Q8_0"
+    ],
+    "gguf_model_file_name_template": "flux1-schnell-{quantization}.gguf"
   },
   {
     "model_name": "FLUX.1-dev",
```
```diff
@@ -26,8 +42,24 @@
     ],
     "default_model_config": {
       "quantize": true,
-      "quantize_text_encoder": "text_encoder_2"
-    }
+      "quantize_text_encoder": "text_encoder_2",
+      "torch_dtype": "bfloat16"
+    },
+    "gguf_model_id": "city96/FLUX.1-dev-gguf",
+    "gguf_quantizations": [
+      "F16",
+      "Q2_K",
+      "Q3_K_S",
+      "Q4_0",
+      "Q4_1",
+      "Q4_K_S",
+      "Q5_0",
+      "Q5_1",
+      "Q5_K_S",
+      "Q6_K",
+      "Q8_0"
+    ],
+    "gguf_model_file_name_template": "flux1-dev-{quantization}.gguf"
   },
   {
     "model_name": "sd3-medium",
```
```diff
@@ -44,6 +76,97 @@
       "quantize_text_encoder": "text_encoder_3"
     }
   },
+  {
+    "model_name": "sd3.5-medium",
+    "model_family": "stable_diffusion",
+    "model_id": "stabilityai/stable-diffusion-3.5-medium",
+    "model_revision": "94b13ccbe959c51e8159d91f562c58f29fac971a",
+    "model_ability": [
+      "text2image",
+      "image2image",
+      "inpainting"
+    ],
+    "default_model_config": {
+      "quantize": true,
+      "quantize_text_encoder": "text_encoder_3",
+      "torch_dtype": "bfloat16"
+    },
+    "gguf_model_id": "city96/stable-diffusion-3.5-medium-gguf",
+    "gguf_quantizations": [
+      "F16",
+      "Q3_K_M",
+      "Q3_K_S",
+      "Q4_0",
+      "Q4_1",
+      "Q4_K_M",
+      "Q4_K_S",
+      "Q5_0",
+      "Q5_1",
+      "Q5_K_M",
+      "Q5_K_S",
+      "Q6_K",
+      "Q8_0"
+    ],
+    "gguf_model_file_name_template": "sd3.5_medium-{quantization}.gguf"
+  },
+  {
+    "model_name": "sd3.5-large",
+    "model_family": "stable_diffusion",
+    "model_id": "stabilityai/stable-diffusion-3.5-large",
+    "model_revision": "ceddf0a7fdf2064ea28e2213e3b84e4afa170a0f",
+    "model_ability": [
+      "text2image",
+      "image2image",
+      "inpainting"
+    ],
+    "default_model_config": {
+      "quantize": true,
+      "quantize_text_encoder": "text_encoder_3",
+      "torch_dtype": "bfloat16",
+      "transformer_nf4": true
+    },
+    "gguf_model_id": "city96/stable-diffusion-3.5-large-gguf",
+    "gguf_quantizations": [
+      "F16",
+      "Q4_0",
+      "Q4_1",
+      "Q5_0",
+      "Q5_1",
+      "Q8_0"
+    ],
+    "gguf_model_file_name_template": "sd3.5_large-{quantization}.gguf"
+  },
+  {
+    "model_name": "sd3.5-large-turbo",
+    "model_family": "stable_diffusion",
+    "model_id": "stabilityai/stable-diffusion-3.5-large-turbo",
+    "model_revision": "ec07796fc06b096cc56de9762974a28f4c632eda",
+    "model_ability": [
+      "text2image",
+      "image2image",
+      "inpainting"
+    ],
+    "default_model_config": {
+      "quantize": true,
+      "quantize_text_encoder": "text_encoder_3",
+      "torch_dtype": "bfloat16",
+      "transformer_nf4": true
+    },
+    "default_generate_config": {
+      "guidance_scale": 1.0,
+      "num_inference_steps": 4
+    },
+    "gguf_model_id": "city96/stable-diffusion-3.5-large-turbo-gguf",
+    "gguf_quantizations": [
+      "F16",
+      "Q4_0",
+      "Q4_1",
+      "Q5_0",
+      "Q5_1",
+      "Q8_0"
+    ],
+    "gguf_model_file_name_template": "sd3.5_large_turbo-{quantization}.gguf"
+  },
   {
     "model_name": "sd-turbo",
     "model_family": "stable_diffusion",
```