xinference 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference has been flagged as possibly problematic.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +23 -1
- xinference/core/model.py +1 -6
- xinference/core/utils.py +10 -6
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +15 -10
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +35 -111
- xinference/model/audio/model_spec.json +19 -3
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +127 -4
- xinference/model/image/model_spec_modelscope.json +130 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/llm_family.json +47 -0
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +49 -0
- xinference/model/llm/mlx/core.py +68 -13
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/qwen2_vl.py +2 -0
- xinference/model/llm/utils.py +1 -0
- xinference/model/llm/vllm/core.py +11 -2
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/METADATA +11 -6
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/RECORD +95 -74
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/model/audio/fish_speech.py
CHANGED

@@ -11,10 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import gc
 import logging
 import os.path
-import queue
 import sys
 from io import BytesIO
 from typing import TYPE_CHECKING, Optional
@@ -60,6 +58,7 @@ class FishSpeechModel:
         self._device = device
         self._llama_queue = None
         self._model = None
+        self._engine = None
         self._kwargs = kwargs

     @property
@@ -72,6 +71,7 @@ class FishSpeechModel:
             0, os.path.join(os.path.dirname(__file__), "../../thirdparty/fish_speech")
         )

+        from tools.inference_engine import TTSInferenceEngine
         from tools.llama.generate import launch_thread_safe_queue
         from tools.vqgan.inference import load_model as load_decoder_model

@@ -81,6 +81,11 @@ class FishSpeechModel:
         if not is_device_available(self._device):
             raise ValueError(f"Device {self._device} is not available!")

+        # https://github.com/pytorch/pytorch/issues/129207
+        if self._device == "mps":
+            logger.warning("The Conv1d has bugs on MPS backend, fallback to CPU.")
+            self._device = "cpu"
+
         enable_compile = self._kwargs.get("compile", False)
         precision = self._kwargs.get("precision", torch.bfloat16)
         logger.info("Loading Llama model, compile=%s...", enable_compile)
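The added guard tracks pytorch/pytorch#129207, where Conv1d produces incorrect results on the MPS backend. The same fallback pattern generalizes to any known-bad backend; a minimal sketch, where pick_device is an illustrative helper and not xinference API:

import torch

def pick_device(requested: str) -> str:
    # Illustrative: downgrade to CPU when the requested backend is
    # unavailable or known to miscompute an op the model depends on.
    if requested == "mps":
        # Conv1d is broken on MPS (pytorch/pytorch#129207).
        return "cpu"
    if requested == "cuda" and not torch.cuda.is_available():
        return "cpu"
    return requested

print(pick_device("mps"))  # -> cpu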
@@ -102,102 +107,10 @@ class FishSpeechModel:
             device=self._device,
         )

-
-    def _inference(
-        self,
-        text,
-        enable_reference_audio,
-        reference_audio,
-        reference_text,
-        max_new_tokens,
-        chunk_length,
-        top_p,
-        repetition_penalty,
-        temperature,
-        seed="0",
-        streaming=False,
-    ):
-        from fish_speech.utils import autocast_exclude_mps, set_seed
-        from tools.api import decode_vq_tokens, encode_reference
-        from tools.llama.generate import (
-            GenerateRequest,
-            GenerateResponse,
-            WrappedGenerateResponse,
-        )
-
-        seed = int(seed)
-        if seed != 0:
-            set_seed(seed)
-            logger.warning(f"set seed: {seed}")
-
-        # Parse reference audio aka prompt
-        prompt_tokens = encode_reference(
-            decoder_model=self._model,
-            reference_audio=reference_audio,
-            enable_reference_audio=enable_reference_audio,
-        )
-
-        # LLAMA Inference
-        request = dict(
-            device=self._model.device,
-            max_new_tokens=max_new_tokens,
-            text=text,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            temperature=temperature,
-            compile=self._kwargs.get("compile", False),
-            iterative_prompt=chunk_length > 0,
-            chunk_length=chunk_length,
-            max_length=2048,
-            prompt_tokens=prompt_tokens if enable_reference_audio else None,
-            prompt_text=reference_text if enable_reference_audio else None,
-        )
-
-        response_queue = queue.Queue()
-        self._llama_queue.put(
-            GenerateRequest(
-                request=request,
-                response_queue=response_queue,
-            )
+        self._engine = TTSInferenceEngine(
+            self._llama_queue, self._model, precision, enable_compile
         )
-
-        segments = []
-
-        while True:
-            result: WrappedGenerateResponse = response_queue.get()
-            if result.status == "error":
-                raise result.response
-
-            result: GenerateResponse = result.response
-            if result.action == "next":
-                break
-
-            with autocast_exclude_mps(
-                device_type=self._model.device.type,
-                dtype=self._kwargs.get("precision", torch.bfloat16),
-            ):
-                fake_audios = decode_vq_tokens(
-                    decoder_model=self._model,
-                    codes=result.codes,
-                )
-
-            fake_audios = fake_audios.float().cpu().numpy()
-            segments.append(fake_audios)
-
-            if streaming:
-                yield fake_audios, None, None
-
-        if len(segments) == 0:
-            raise Exception("No audio generated, please check the input text.")
-
-        # No matter streaming or not, we need to return the final audio
-        audio = np.concatenate(segments, axis=0)
-        yield None, (self._model.spec_transform.sample_rate, audio), None
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
-
     def speech(
         self,
         input: str,
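This hunk removes roughly ninety lines of hand-rolled queue plumbing (building a request dict, waiting on a response queue, decoding VQ tokens chunk by chunk) and delegates all of it to fish_speech's TTSInferenceEngine. The toy sketch below uses stand-in classes rather than fish_speech's real ones; it shows the pattern that makes the call site shrink to a constructor plus a single inference() call once an engine object owns the worker queue.

import queue
import threading
from dataclasses import dataclass, field

@dataclass
class GenerateRequest:
    text: str
    response_queue: "queue.Queue[str]" = field(default_factory=queue.Queue)

class InferenceEngine:
    # Toy stand-in for TTSInferenceEngine: it owns the worker queue, so
    # callers submit one request object instead of driving the protocol.
    def __init__(self, work_queue: "queue.Queue[GenerateRequest]") -> None:
        self._queue = work_queue

    def inference(self, text: str) -> str:
        request = GenerateRequest(text)
        self._queue.put(request)
        return request.response_queue.get()

def worker(work_queue: "queue.Queue[GenerateRequest]") -> None:
    while True:
        request = work_queue.get()
        request.response_queue.put(request.text.upper())  # stand-in "synthesis"

q: "queue.Queue[GenerateRequest]" = queue.Queue()
threading.Thread(target=worker, args=(q,), daemon=True).start()
print(InferenceEngine(q).inference("hello"))  # HELLO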
@@ -211,22 +124,31 @@ class FishSpeechModel:
         if speed != 1.0:
             logger.warning("Fish speech does not support setting speed: %s.", speed)
         import torchaudio
+        from tools.schema import ServeReferenceAudio, ServeTTSRequest

         prompt_speech = kwargs.get("prompt_speech")
         prompt_text = kwargs.get("prompt_text", kwargs.get("reference_text", ""))
-        result = self._inference(
-            text=input,
-            …
+        if prompt_speech is not None:
+            r = ServeReferenceAudio(audio=prompt_speech, text=prompt_text)
+            references = [r]
+        else:
+            references = []
+
+        assert self._engine is not None
+        result = self._engine.inference(
+            ServeTTSRequest(
+                text=input,
+                references=references,
+                reference_id=kwargs.get("reference_id"),
+                seed=kwargs.get("seed"),
+                max_new_tokens=kwargs.get("max_new_tokens", 1024),
+                chunk_length=kwargs.get("chunk_length", 200),
+                top_p=kwargs.get("top_p", 0.7),
+                repetition_penalty=kwargs.get("repetition_penalty", 1.2),
+                temperature=kwargs.get("temperature", 0.7),
+                streaming=stream,
+                format=response_format,
+            )
         )

         if stream:
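With the request funneled through ServeTTSRequest, the per-request knobs (temperature, top_p, chunk_length, and so on) pass straight through speech()'s kwargs. A usage sketch, assuming a local xinference server on the default port and that speech() returns the encoded audio bytes:

from xinference.client import Client

client = Client("http://localhost:9997")  # assumes a running local server
uid = client.launch_model(model_name="FishSpeech-1.5", model_type="audio")
model = client.get_model(uid)

# These kwargs flow through speech() into ServeTTSRequest per the diff.
audio = model.speech(
    "Hello from FishSpeech.",
    response_format="mp3",
    temperature=0.7,
    top_p=0.7,
    chunk_length=200,
)
with open("out.mp3", "wb") as f:
    f.write(audio)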
@@ -242,7 +164,9 @@ class FishSpeechModel:
             last_pos = 0
             with writer.open():
                 for chunk in result:
-                    chunk = chunk[0]
+                    if chunk.code == "final":
+                        continue
+                    chunk = chunk.audio[1]
                     if chunk is not None:
                         chunk = chunk.reshape((chunk.shape[0], 1))
                         trans_chunk = torch.from_numpy(chunk)
@@ -257,7 +181,7 @@
             return _stream_generator()
         else:
             result = list(result)
-            sample_rate, audio = result[0]
+            sample_rate, audio = result[0].audio
             audio = np.array([audio])

             # Save the generated audio
xinference/model/audio/model_spec.json
CHANGED

@@ -236,10 +236,18 @@
             "multilingual": true
         },
         {
-            "model_name": "…
+            "model_name": "CosyVoice2-0.5B",
+            "model_family": "CosyVoice",
+            "model_id": "mrfakename/CosyVoice2-0.5B",
+            "model_revision": "5676baabc8a76dc93ef60a88bbd2420deaa2f644",
+            "model_ability": "text-to-audio",
+            "multilingual": true
+        },
+        {
+            "model_name": "FishSpeech-1.5",
             "model_family": "FishAudio",
-            "model_id": "fishaudio/fish-speech-1.…
-            "model_revision": "…
+            "model_id": "fishaudio/fish-speech-1.5",
+            "model_revision": "268b6ec86243dd683bc78dab7e9a6cedf9191f2a",
             "model_ability": "text-to-audio",
             "multilingual": true
         },
@@ -250,5 +258,13 @@
             "model_revision": "4dcc16f297f2ff98a17b3726b16f5de5a5e45672",
             "model_ability": "text-to-audio",
             "multilingual": true
+        },
+        {
+            "model_name": "F5-TTS-MLX",
+            "model_family": "F5-TTS-MLX",
+            "model_id": "lucasnewman/f5-tts-mlx",
+            "model_revision": "7642bb232e3fcacf92c51c786edebb8624da6b93",
+            "model_ability": "text-to-audio",
+            "multilingual": true
         }
     ]
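Both new entries become launchable by name; a short sketch with the Python client, assuming a running server:

from xinference.client import Client

client = Client("http://localhost:9997")
client.launch_model(model_name="CosyVoice2-0.5B", model_type="audio")
client.launch_model(model_name="F5-TTS-MLX", model_type="audio")  # MLX port for Apple silicon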
xinference/model/audio/model_spec_modelscope.json
CHANGED

@@ -74,6 +74,15 @@
             "model_ability": "text-to-audio",
             "multilingual": true
         },
+        {
+            "model_name": "CosyVoice2-0.5B",
+            "model_family": "CosyVoice",
+            "model_hub": "modelscope",
+            "model_id": "iic/CosyVoice2-0.5B",
+            "model_revision": "master",
+            "model_ability": "text-to-audio",
+            "multilingual": true
+        },
         {
             "model_name": "F5-TTS",
             "model_family": "F5-TTS",
xinference/model/audio/utils.py
CHANGED
@@ -11,8 +11,40 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import io
+
+import numpy as np
+
 from .core import AudioModelFamilyV1


 def get_model_version(audio_model: AudioModelFamilyV1) -> str:
     return audio_model.model_name
+
+
+def ensure_sample_rate(
+    audio: np.ndarray, old_sample_rate: int, sample_rate: int
+) -> np.ndarray:
+    import soundfile as sf
+    from scipy.signal import resample
+
+    if old_sample_rate != sample_rate:
+        # Calculate the new data length
+        new_length = int(len(audio) * sample_rate / old_sample_rate)
+
+        # Resample the data
+        resampled_data = resample(audio, new_length)
+
+        # Use BytesIO to save the resampled data to memory
+        with io.BytesIO() as buffer:
+            # Write the resampled data to the memory buffer
+            sf.write(buffer, resampled_data, sample_rate, format="WAV")
+
+            # Reset the buffer position to the beginning
+            buffer.seek(0)
+
+            # Read the data from the memory buffer
+            audio, sr = sf.read(buffer, dtype="float32")
+
+    return audio
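A usage sketch for the new helper (requires scipy and soundfile; the signal here is illustrative):

import numpy as np

from xinference.model.audio.utils import ensure_sample_rate

sr_in, sr_out = 22050, 16000
t = np.linspace(0, 1.0, sr_in, endpoint=False)
tone = np.sin(2 * np.pi * 440 * t).astype(np.float32)  # 1 s of A4 at 22.05 kHz

resampled = ensure_sample_rate(tone, sr_in, sr_out)
print(len(tone), "->", len(resampled))  # 22050 -> 16000

Round-tripping through an in-memory WAV is a deliberate choice: soundfile reads the resampled signal back as a normalized float32 array regardless of what scipy's resample produced.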
xinference/model/image/core.py
CHANGED
@@ -22,7 +22,12 @@ from typing import Dict, List, Literal, Optional, Tuple, Union
 from ...constants import XINFERENCE_CACHE_DIR
 from ...types import PeftModelConfig
 from ..core import CacheableModelSpec, ModelDescription
-from ..utils import …
+from ..utils import (
+    IS_NEW_HUGGINGFACE_HUB,
+    retry_download,
+    symlink_local_file,
+    valid_model_revision,
+)
 from .ocr.got_ocr2 import GotOCR2Model
 from .stable_diffusion.core import DiffusionModel
 from .stable_diffusion.mlx import MLXDiffusionModel
@@ -51,6 +56,9 @@ class ImageModelFamilyV1(CacheableModelSpec):
     controlnet: Optional[List["ImageModelFamilyV1"]]
     default_model_config: Optional[dict] = {}
     default_generate_config: Optional[dict] = {}
+    gguf_model_id: Optional[str]
+    gguf_quantizations: Optional[List[str]]
+    gguf_model_file_name_template: Optional[str]


 class ImageModelDescription(ModelDescription):
@@ -187,6 +195,61 @@ def get_cache_status(
     return valid_model_revision(meta_path, model_spec.model_revision)


+def cache_gguf(spec: ImageModelFamilyV1, quantization: Optional[str] = None):
+    if not quantization:
+        return
+
+    cache_dir = os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, spec.model_name))
+    if not os.path.exists(cache_dir):
+        os.makedirs(cache_dir, exist_ok=True)
+
+    if not spec.gguf_model_file_name_template:
+        raise NotImplementedError(
+            f"{spec.model_name} does not support GGUF quantization"
+        )
+    if quantization not in (spec.gguf_quantizations or []):
+        raise ValueError(
+            f"Cannot support quantization {quantization}, "
+            f"available quantizations: {spec.gguf_quantizations}"
+        )
+
+    filename = spec.gguf_model_file_name_template.format(quantization=quantization)  # type: ignore
+    full_path = os.path.join(cache_dir, filename)
+
+    if spec.model_hub == "huggingface":
+        import huggingface_hub
+
+        use_symlinks = {}
+        if not IS_NEW_HUGGINGFACE_HUB:
+            use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
+        download_file_path = retry_download(
+            huggingface_hub.hf_hub_download,
+            spec.model_name,
+            None,
+            spec.gguf_model_id,
+            filename=filename,
+            **use_symlinks,
+        )
+        if IS_NEW_HUGGINGFACE_HUB:
+            symlink_local_file(download_file_path, cache_dir, filename)
+    elif spec.model_hub == "modelscope":
+        from modelscope.hub.file_download import model_file_download
+
+        download_file_path = retry_download(
+            model_file_download,
+            spec.model_name,
+            None,
+            spec.gguf_model_id,
+            filename,
+            revision=spec.model_revision,
+        )
+        symlink_local_file(download_file_path, cache_dir, filename)
+    else:
+        raise NotImplementedError
+
+    return full_path
+
+
 def create_ocr_model_instance(
     subpool_addr: str,
     devices: List[str],
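Stripped of the download plumbing, cache_gguf's validation and naming logic reduces to the following; gguf_filename is an illustrative standalone helper, not part of xinference:

from typing import List, Optional

def gguf_filename(
    template: Optional[str], quantizations: Optional[List[str]], quantization: str
) -> str:
    # Mirrors cache_gguf's checks: the spec must define a filename template,
    # and the requested quantization must be listed in the spec.
    if not template:
        raise NotImplementedError("model does not support GGUF quantization")
    if quantization not in (quantizations or []):
        raise ValueError(
            f"Cannot support quantization {quantization}, "
            f"available quantizations: {quantizations}"
        )
    return template.format(quantization=quantization)

print(gguf_filename("flux1-schnell-{quantization}.gguf", ["Q4_0", "Q8_0"], "Q4_0"))
# flux1-schnell-Q4_0.gguf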
@@ -219,6 +282,8 @@ def create_image_model_instance(
         Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
     ] = None,
     model_path: Optional[str] = None,
+    gguf_quantization: Optional[str] = None,
+    gguf_model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[
     Union[DiffusionModel, MLXDiffusionModel, GotOCR2Model], ImageModelDescription
@@ -272,6 +337,8 @@ def create_image_model_instance(
     ]
     if not model_path:
         model_path = cache(model_spec)
+    if not gguf_model_path and gguf_quantization:
+        gguf_model_path = cache_gguf(model_spec, gguf_quantization)
     if peft_model_config is not None:
         lora_model = peft_model_config.peft_model
         lora_load_kwargs = peft_model_config.image_lora_load_kwargs
@@ -298,6 +365,7 @@ def create_image_model_instance(
         lora_load_kwargs=lora_load_kwargs,
         lora_fuse_kwargs=lora_fuse_kwargs,
         model_spec=model_spec,
+        gguf_model_path=gguf_model_path,
         **kwargs,
     )
     model_description = ImageModelDescription(
xinference/model/image/model_spec.json
CHANGED

@@ -11,8 +11,24 @@
         ],
         "default_model_config": {
             "quantize": true,
-            "quantize_text_encoder": "text_encoder_2"
-        },
+            "quantize_text_encoder": "text_encoder_2",
+            "torch_dtype": "bfloat16"
+        },
+        "gguf_model_id": "city96/FLUX.1-schnell-gguf",
+        "gguf_quantizations": [
+            "F16",
+            "Q2_K",
+            "Q3_K_S",
+            "Q4_0",
+            "Q4_1",
+            "Q4_K_S",
+            "Q5_0",
+            "Q5_1",
+            "Q5_K_S",
+            "Q6_K",
+            "Q8_0"
+        ],
+        "gguf_model_file_name_template": "flux1-schnell-{quantization}.gguf"
     },
     {
         "model_name": "FLUX.1-dev",
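For FLUX.1-schnell at Q4_0, the template expands to flux1-schnell-Q4_0.gguf, and cache_gguf fetches that single file from the gguf_model_id repository; roughly equivalent to:

from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="city96/FLUX.1-schnell-gguf",
    filename="flux1-schnell-Q4_0.gguf",
)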
|
@@ -26,8 +42,24 @@
|
|
|
26
42
|
],
|
|
27
43
|
"default_model_config": {
|
|
28
44
|
"quantize": true,
|
|
29
|
-
"quantize_text_encoder": "text_encoder_2"
|
|
30
|
-
|
|
45
|
+
"quantize_text_encoder": "text_encoder_2",
|
|
46
|
+
"torch_dtype": "bfloat16"
|
|
47
|
+
},
|
|
48
|
+
"gguf_model_id": "city96/FLUX.1-dev-gguf",
|
|
49
|
+
"gguf_quantizations": [
|
|
50
|
+
"F16",
|
|
51
|
+
"Q2_K",
|
|
52
|
+
"Q3_K_S",
|
|
53
|
+
"Q4_0",
|
|
54
|
+
"Q4_1",
|
|
55
|
+
"Q4_K_S",
|
|
56
|
+
"Q5_0",
|
|
57
|
+
"Q5_1",
|
|
58
|
+
"Q5_K_S",
|
|
59
|
+
"Q6_K",
|
|
60
|
+
"Q8_0"
|
|
61
|
+
],
|
|
62
|
+
"gguf_model_file_name_template": "flux1-dev-{quantization}.gguf"
|
|
31
63
|
},
|
|
32
64
|
{
|
|
33
65
|
"model_name": "sd3-medium",
|
|
@@ -44,6 +76,97 @@
             "quantize_text_encoder": "text_encoder_3"
         }
     },
+    {
+        "model_name": "sd3.5-medium",
+        "model_family": "stable_diffusion",
+        "model_id": "stabilityai/stable-diffusion-3.5-medium",
+        "model_revision": "94b13ccbe959c51e8159d91f562c58f29fac971a",
+        "model_ability": [
+            "text2image",
+            "image2image",
+            "inpainting"
+        ],
+        "default_model_config": {
+            "quantize": true,
+            "quantize_text_encoder": "text_encoder_3",
+            "torch_dtype": "bfloat16"
+        },
+        "gguf_model_id": "city96/stable-diffusion-3.5-medium-gguf",
+        "gguf_quantizations": [
+            "F16",
+            "Q3_K_M",
+            "Q3_K_S",
+            "Q4_0",
+            "Q4_1",
+            "Q4_K_M",
+            "Q4_K_S",
+            "Q5_0",
+            "Q5_1",
+            "Q5_K_M",
+            "Q5_K_S",
+            "Q6_K",
+            "Q8_0"
+        ],
+        "gguf_model_file_name_template": "sd3.5_medium-{quantization}.gguf"
+    },
+    {
+        "model_name": "sd3.5-large",
+        "model_family": "stable_diffusion",
+        "model_id": "stabilityai/stable-diffusion-3.5-large",
+        "model_revision": "ceddf0a7fdf2064ea28e2213e3b84e4afa170a0f",
+        "model_ability": [
+            "text2image",
+            "image2image",
+            "inpainting"
+        ],
+        "default_model_config": {
+            "quantize": true,
+            "quantize_text_encoder": "text_encoder_3",
+            "torch_dtype": "bfloat16",
+            "transformer_nf4": true
+        },
+        "gguf_model_id": "city96/stable-diffusion-3.5-large-gguf",
+        "gguf_quantizations": [
+            "F16",
+            "Q4_0",
+            "Q4_1",
+            "Q5_0",
+            "Q5_1",
+            "Q8_0"
+        ],
+        "gguf_model_file_name_template": "sd3.5_large-{quantization}.gguf"
+    },
+    {
+        "model_name": "sd3.5-large-turbo",
+        "model_family": "stable_diffusion",
+        "model_id": "stabilityai/stable-diffusion-3.5-large-turbo",
+        "model_revision": "ec07796fc06b096cc56de9762974a28f4c632eda",
+        "model_ability": [
+            "text2image",
+            "image2image",
+            "inpainting"
+        ],
+        "default_model_config": {
+            "quantize": true,
+            "quantize_text_encoder": "text_encoder_3",
+            "torch_dtype": "bfloat16",
+            "transformer_nf4": true
+        },
+        "default_generate_config": {
+            "guidance_scale": 1.0,
+            "num_inference_steps": 4
+        },
+        "gguf_model_id": "city96/stable-diffusion-3.5-large-turbo-gguf",
+        "gguf_quantizations": [
+            "F16",
+            "Q4_0",
+            "Q4_1",
+            "Q5_0",
+            "Q5_1",
+            "Q8_0"
+        ],
+        "gguf_model_file_name_template": "sd3.5_large_turbo-{quantization}.gguf"
+    },
     {
         "model_name": "sd-turbo",
         "model_family": "stable_diffusion",
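Putting the pieces together: gguf_quantization reaches cache_gguf through create_image_model_instance, so a quantized transformer can be requested at launch time. A sketch with the Python client, assuming the extra kwarg is forwarded to model creation as wired in this diff:

from xinference.client import Client

client = Client("http://localhost:9997")
uid = client.launch_model(
    model_name="sd3.5-medium",
    model_type="image",
    gguf_quantization="Q4_K_M",  # must be one of the spec's gguf_quantizations
)
model = client.get_model(uid)
result = model.text_to_image("an origami crane, studio lighting")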