xinference 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_compat.py +22 -2
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +91 -6
- xinference/client/restful/restful_client.py +39 -0
- xinference/core/model.py +41 -13
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/__init__.py +12 -0
- xinference/model/audio/core.py +26 -4
- xinference/model/audio/f5tts.py +195 -0
- xinference/model/audio/fish_speech.py +71 -35
- xinference/model/audio/model_spec.json +88 -0
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/audio/whisper_mlx.py +208 -0
- xinference/model/embedding/core.py +322 -6
- xinference/model/embedding/model_spec.json +8 -1
- xinference/model/embedding/model_spec_modelscope.json +9 -1
- xinference/model/llm/__init__.py +4 -2
- xinference/model/llm/llm_family.json +479 -53
- xinference/model/llm/llm_family_modelscope.json +423 -17
- xinference/model/llm/mlx/core.py +230 -50
- xinference/model/llm/sglang/core.py +2 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/glm_edge_v.py +230 -0
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +23 -1
- xinference/model/llm/vllm/core.py +89 -2
- xinference/thirdparty/f5_tts/__init__.py +0 -0
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/types.py +2 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/METADATA +39 -18
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/RECORD +92 -39
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/LICENSE +0 -0
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/utils.py CHANGED

@@ -156,6 +156,7 @@ def _get_completion(
    finish_reason: Optional[str],
    model_uid: str,
    r: InferenceRequest,
+   completion_tokens: int,
):
    completion_choice = CompletionChoice(
        text=output, index=0, logprobs=None, finish_reason=finish_reason
@@ -170,8 +171,8 @@ def _get_completion(
    )
    completion_usage = CompletionUsage(
        prompt_tokens=len(r.prompt_tokens),
-       completion_tokens=
-       total_tokens=len(r.prompt_tokens) +
+       completion_tokens=completion_tokens,
+       total_tokens=len(r.prompt_tokens) + completion_tokens,
    )
    completion = Completion(
        id=completion_chunk["id"],
@@ -371,7 +372,7 @@ def _batch_inference_one_step_internal(
        r.stopped = stopped
        r.finish_reason = finish_reason

-       if r.stopped and r not in stop_token_mapping
+       if r.stopped and r not in stop_token_mapping:
            stop_token_mapping[r] = _i + 1

        if r.stream:
@@ -446,12 +447,14 @@ def _batch_inference_one_step_internal(
        else:
            # last round, handle non-stream result
            if r.stopped and _i == decode_round - 1:
-               invalid_token_num =
+               invalid_token_num = (
+                   (decode_round - stop_token_mapping[r] + 1)
+                   if r.finish_reason == "stop"
+                   else (decode_round - stop_token_mapping[r])
+               )
                outputs = (
                    tokenizer.decode(
-                       r.new_tokens[
-                       if r.finish_reason == "stop"
-                       else r.new_tokens[:-invalid_token_num],
+                       r.new_tokens[:-invalid_token_num],
                        skip_special_tokens=True,
                        spaces_between_special_tokens=False,
                        clean_up_tokenization_spaces=True,
@@ -460,7 +463,12 @@ def _batch_inference_one_step_internal(
                    else output_mapping[r]
                )
                completion = _get_completion(
-                   outputs,
+                   outputs,
+                   r.chunk_id,
+                   r.finish_reason,
+                   model_uid,
+                   r,
+                   len(r.new_tokens) - invalid_token_num,
                )
                r.completion = [completion]

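The hunks above change how the batched transformers path reports token usage: `_get_completion` now takes an explicit `completion_tokens` value, and the caller subtracts the tokens decoded after a request already stopped (`invalid_token_num`) instead of counting all of `r.new_tokens`. A minimal sketch of that arithmetic with hypothetical numbers (not taken from the diff):

```python
# Hypothetical illustration of the token accounting introduced above.
decode_round = 16          # decode steps executed in this batched round
stop_step = 12             # value recorded in stop_token_mapping[r] when the request stopped
finish_reason = "stop"     # "stop" -> the stop token itself is also discarded
new_tokens = list(range(40))  # stand-in for r.new_tokens

invalid_token_num = (
    (decode_round - stop_step + 1)
    if finish_reason == "stop"
    else (decode_round - stop_step)
)
completion_tokens = len(new_tokens) - invalid_token_num

print(invalid_token_num)   # 5  -> trailing tokens dropped from the decoded output
print(completion_tokens)   # 35 -> reported as CompletionUsage.completion_tokens
```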
xinference/model/llm/utils.py CHANGED

@@ -324,7 +324,10 @@ class ChatModelMixin:
        """
        try:
            if isinstance(c, dict):
-
+               try:
+                   return [(None, c["name"], json.loads(c["arguments"]))]
+               except Exception:
+                   return [(None, c["name"], c["arguments"])]
        except KeyError:
            logger.error("Can't parse glm output: %s", c)
            return [(str(c), None, None)]
@@ -569,6 +572,25 @@ def _decode_image(_url):
            return Image.open(BytesIO(response.content)).convert("RGB")


+def _decode_image_without_rgb(_url):
+    if _url.startswith("data:"):
+        logging.info("Parse url by base64 decoder.")
+        # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+        # e.g. f"data:image/jpeg;base64,{base64_image}"
+        _type, data = _url.split(";")
+        _, ext = _type.split("/")
+        data = data[len("base64,") :]
+        data = base64.b64decode(data.encode("utf-8"))
+        return Image.open(BytesIO(data))
+    else:
+        try:
+            response = requests.get(_url)
+        except requests.exceptions.MissingSchema:
+            return Image.open(_url)
+        else:
+            return Image.open(BytesIO(response.content))
+
+
@typing.no_type_check
def generate_completion_chunk(
    chunk_text: Optional[str],

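The new `_decode_image_without_rgb` helper mirrors the existing `_decode_image` but skips the `.convert("RGB")` call, so alpha channels and palette modes are preserved. A standalone sketch of the same data-URL branch, assuming Pillow is installed (the round-trip image here is synthetic, not from xinference):

```python
import base64
from io import BytesIO

from PIL import Image

# Build a small RGBA test image and wrap it in an OpenAI-style data URL.
buf = BytesIO()
Image.new("RGBA", (4, 4), (255, 0, 0, 128)).save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
url = f"data:image/png;base64,{b64}"

# Same parsing steps as the helper above: split off the media type,
# strip the "base64," prefix, decode, and open without converting to RGB.
_type, data = url.split(";")
_, ext = _type.split("/")
img = Image.open(BytesIO(base64.b64decode(data[len("base64,"):].encode("utf-8"))))
print(img.mode, ext)  # "RGBA png" -- the alpha channel survives
```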
xinference/model/llm/vllm/core.py CHANGED

@@ -69,6 +69,7 @@ class VLLMModelConfig(TypedDict, total=False):
    quantization: Optional[str]
    max_model_len: Optional[int]
    limit_mm_per_prompt: Optional[Dict[str, int]]
+   guided_decoding_backend: Optional[str]


class VLLMGenerateConfig(TypedDict, total=False):
@@ -85,6 +86,15 @@ class VLLMGenerateConfig(TypedDict, total=False):
    stop: Optional[Union[str, List[str]]]
    stream: bool  # non-sampling param, should not be passed to the engine.
    stream_options: Optional[Union[dict, None]]
+   skip_special_tokens: Optional[bool]
+   response_format: Optional[dict]
+   guided_json: Optional[Union[str, dict]]
+   guided_regex: Optional[str]
+   guided_choice: Optional[List[str]]
+   guided_grammar: Optional[str]
+   guided_json_object: Optional[bool]
+   guided_decoding_backend: Optional[str]
+   guided_whitespace_pattern: Optional[str]


try:
@@ -144,6 +154,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
    VLLM_SUPPORTED_CHAT_MODELS.append("qwen2.5-instruct")
    VLLM_SUPPORTED_MODELS.append("qwen2.5-coder")
    VLLM_SUPPORTED_CHAT_MODELS.append("qwen2.5-coder-instruct")
+   VLLM_SUPPORTED_CHAT_MODELS.append("QwQ-32B-Preview")


if VLLM_INSTALLED and vllm.__version__ >= "0.3.2":
@@ -171,6 +182,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.5.3":
if VLLM_INSTALLED and vllm.__version__ > "0.5.3":
    VLLM_SUPPORTED_MODELS.append("llama-3.1")
    VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.1-instruct")
+   VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.3-instruct")

if VLLM_INSTALLED and vllm.__version__ >= "0.6.1":
    VLLM_SUPPORTED_VISION_MODEL_LIST.append("internvl2")
@@ -314,6 +326,7 @@ class VLLMModel(LLM):
        model_config.setdefault("max_num_seqs", 256)
        model_config.setdefault("quantization", None)
        model_config.setdefault("max_model_len", None)
+       model_config.setdefault("guided_decoding_backend", "outlines")

        return model_config

@@ -325,6 +338,22 @@ class VLLMModel(LLM):
            generate_config = {}

        sanitized = VLLMGenerateConfig()
+
+       response_format = generate_config.pop("response_format", None)
+       guided_decoding_backend = generate_config.get("guided_decoding_backend", None)
+       guided_json_object = None
+       guided_json = None
+
+       if response_format is not None:
+           if response_format.get("type") == "json_object":
+               guided_json_object = True
+           elif response_format.get("type") == "json_schema":
+               json_schema = response_format.get("json_schema")
+               assert json_schema is not None
+               guided_json = json_schema.get("json_schema")
+               if guided_decoding_backend is None:
+                   guided_decoding_backend = "outlines"
+
        sanitized.setdefault("lora_name", generate_config.get("lora_name", None))
        sanitized.setdefault("n", generate_config.get("n", 1))
        sanitized.setdefault("best_of", generate_config.get("best_of", None))
@@ -346,6 +375,31 @@ class VLLMModel(LLM):
        sanitized.setdefault(
            "stream_options", generate_config.get("stream_options", None)
        )
+       sanitized.setdefault(
+           "skip_special_tokens", generate_config.get("skip_special_tokens", True)
+       )
+       sanitized.setdefault(
+           "guided_json", generate_config.get("guided_json", guided_json)
+       )
+       sanitized.setdefault("guided_regex", generate_config.get("guided_regex", None))
+       sanitized.setdefault(
+           "guided_choice", generate_config.get("guided_choice", None)
+       )
+       sanitized.setdefault(
+           "guided_grammar", generate_config.get("guided_grammar", None)
+       )
+       sanitized.setdefault(
+           "guided_whitespace_pattern",
+           generate_config.get("guided_whitespace_pattern", None),
+       )
+       sanitized.setdefault(
+           "guided_json_object",
+           generate_config.get("guided_json_object", guided_json_object),
+       )
+       sanitized.setdefault(
+           "guided_decoding_backend",
+           generate_config.get("guided_decoding_backend", guided_decoding_backend),
+       )

        return sanitized

@@ -483,13 +537,46 @@ class VLLMModel(LLM):
            if isinstance(stream_options, dict)
            else False
        )
-
+
+       if VLLM_INSTALLED and vllm.__version__ >= "0.6.3":
+           # guided decoding only available for vllm >= 0.6.3
+           from vllm.sampling_params import GuidedDecodingParams
+
+           guided_options = GuidedDecodingParams.from_optional(
+               json=sanitized_generate_config.pop("guided_json", None),
+               regex=sanitized_generate_config.pop("guided_regex", None),
+               choice=sanitized_generate_config.pop("guided_choice", None),
+               grammar=sanitized_generate_config.pop("guided_grammar", None),
+               json_object=sanitized_generate_config.pop("guided_json_object", None),
+               backend=sanitized_generate_config.pop("guided_decoding_backend", None),
+               whitespace_pattern=sanitized_generate_config.pop(
+                   "guided_whitespace_pattern", None
+               ),
+           )
+
+           sampling_params = SamplingParams(
+               guided_decoding=guided_options, **sanitized_generate_config
+           )
+       else:
+           # ignore generate configs
+           sanitized_generate_config.pop("guided_json", None)
+           sanitized_generate_config.pop("guided_regex", None)
+           sanitized_generate_config.pop("guided_choice", None)
+           sanitized_generate_config.pop("guided_grammar", None)
+           sanitized_generate_config.pop("guided_json_object", None)
+           sanitized_generate_config.pop("guided_decoding_backend", None)
+           sanitized_generate_config.pop("guided_whitespace_pattern", None)
+           sampling_params = SamplingParams(**sanitized_generate_config)
+
        if not request_id:
            request_id = str(uuid.uuid1())

        assert self._engine is not None
        results_generator = self._engine.generate(
-           prompt,
+           prompt,
+           sampling_params,
+           request_id,
+           lora_request,
        )

        async def stream_results() -> AsyncGenerator[CompletionChunk, None]:

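With these changes the vLLM backend accepts an OpenAI-style `response_format` as well as explicit `guided_*` options in `generate_config`, and forwards them to vLLM's `GuidedDecodingParams` when vLLM >= 0.6.3 is installed. A simplified, standalone sketch of the `response_format` handling shown in the sanitize hunk (function name and sample schema are illustrative, not from the diff):

```python
# Simplified re-statement of the response_format branch in the sanitize hunk above.
def resolve_guided_options(generate_config: dict) -> dict:
    response_format = generate_config.pop("response_format", None)
    backend = generate_config.get("guided_decoding_backend")
    guided_json = None
    guided_json_object = None

    if response_format is not None:
        if response_format.get("type") == "json_object":
            guided_json_object = True  # free-form JSON output
        elif response_format.get("type") == "json_schema":
            json_schema = response_format["json_schema"]
            guided_json = json_schema.get("json_schema")  # schema-constrained output
            backend = backend or "outlines"

    return {
        "guided_json": guided_json,
        "guided_json_object": guided_json_object,
        "guided_decoding_backend": backend,
    }


sample = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "person", "json_schema": {"type": "object"}},
    }
}
print(resolve_guided_options(sample))
# {'guided_json': {'type': 'object'}, 'guided_json_object': None,
#  'guided_decoding_backend': 'outlines'}
```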
xinference/thirdparty/f5_tts/__init__.py (file without changes)

xinference/thirdparty/f5_tts/api.py ADDED

import random
import sys
from importlib.resources import files

import soundfile as sf
import tqdm
from cached_path import cached_path

from f5_tts.infer.utils_infer import (
    hop_length,
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
    save_spectrogram,
    transcribe,
    target_sample_rate,
)
from f5_tts.model import DiT, UNetT
from f5_tts.model.utils import seed_everything


class F5TTS:
    def __init__(
        self,
        model_type="F5-TTS",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        vocoder_name="vocos",
        local_path=None,
        device=None,
        hf_cache_dir=None,
    ):
        # Initialize parameters
        self.final_wave = None
        self.target_sample_rate = target_sample_rate
        self.hop_length = hop_length
        self.seed = -1
        self.mel_spec_type = vocoder_name

        # Set device
        if device is not None:
            self.device = device
        else:
            import torch

            self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

        # Load models
        self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
        self.load_ema_model(
            model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
        )

    def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
        self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)

    def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
        if model_type == "F5-TTS":
            if not ckpt_file:
                if mel_spec_type == "vocos":
                    ckpt_file = str(
                        cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
                    )
                elif mel_spec_type == "bigvgan":
                    ckpt_file = str(
                        cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
                    )
            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
            model_cls = DiT
        elif model_type == "E2-TTS":
            if not ckpt_file:
                ckpt_file = str(
                    cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
                )
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            model_cls = UNetT
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        self.ema_model = load_model(
            model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
        )

    def transcribe(self, ref_audio, language=None):
        return transcribe(ref_audio, language)

    def export_wav(self, wav, file_wave, remove_silence=False):
        sf.write(file_wave, wav, self.target_sample_rate)

        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spect, file_spect):
        save_spectrogram(spect, file_spect)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spect=None,
        seed=-1,
    ):
        if seed == -1:
            seed = random.randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed

        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)

        wav, sr, spect = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            self.vocoder,
            self.mel_spec_type,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )

        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)

        if file_spect is not None:
            self.export_spectrogram(spect, file_spect)

        return wav, sr, spect


if __name__ == "__main__":
    f5tts = F5TTS()

    wav, sr, spect = f5tts.infer(
        ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
        ref_text="some call me nature, others call me mother nature.",
        gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
        file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
        file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
        seed=-1,  # random seed = -1
    )

    print("seed :", f5tts.seed)

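The vendored `api.py` above is the upstream F5-TTS programmatic interface; xinference's own wrapper is added separately in `xinference/model/audio/f5tts.py`. A hedged sketch of serving it through the RESTful client, assuming the new spec in `model_spec.json` registers the model under a name like "F5-TTS" and that a server is already running on port 9997:

```python
from xinference.client import RESTfulClient

# Assumes a running server (e.g. `xinference-local --host 0.0.0.0 --port 9997`)
# and that the new audio spec registers this model as "F5-TTS" -- check
# xinference/model/audio/model_spec.json for the exact name.
client = RESTfulClient("http://127.0.0.1:9997")
model_uid = client.launch_model(model_name="F5-TTS", model_type="audio")
model = client.get_model(model_uid)

# OpenAI-style speech endpoint; returns raw audio bytes (container format
# depends on the backend).
audio_bytes = model.speech("Hello from the 1.1.0 release.")
with open("f5tts_sample.audio", "wb") as f:
    f.write(audio_bytes)
```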
xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml ADDED

hydra:
  run:
    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
  name: Emilia_ZH_EN  # dataset name
  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame  # "frame" or "sample"
  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
  num_workers: 16

optim:
  epochs: 15
  learning_rate: 7.5e-5
  num_warmup_updates: 20000  # warmup steps
  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0  # gradient clipping
  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not

model:
  name: E2TTS_Base
  tokenizer: pinyin
  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
  arch:
    dim: 1024
    depth: 24
    heads: 16
    ff_mult: 4
  mel_spec:
    target_sample_rate: 24000
    n_mel_channels: 100
    hop_length: 256
    win_length: 1024
    n_fft: 1024
    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
  vocoder:
    is_local: False  # use local offline ckpt or not
    local_path: None  # local vocoder path

ckpts:
  logger: wandb  # wandb | tensorboard | None
  save_per_updates: 50000  # save checkpoint per steps
  last_per_steps: 5000  # save last checkpoint per steps
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml ADDED

hydra:
  run:
    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
  name: Emilia_ZH_EN
  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame  # "frame" or "sample"
  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
  num_workers: 16

optim:
  epochs: 15
  learning_rate: 7.5e-5
  num_warmup_updates: 20000  # warmup steps
  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0
  bnb_optimizer: False

model:
  name: E2TTS_Small
  tokenizer: pinyin
  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
  arch:
    dim: 768
    depth: 20
    heads: 12
    ff_mult: 4
  mel_spec:
    target_sample_rate: 24000
    n_mel_channels: 100
    hop_length: 256
    win_length: 1024
    n_fft: 1024
    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
  vocoder:
    is_local: False  # use local offline ckpt or not
    local_path: None  # local vocoder path

ckpts:
  logger: wandb  # wandb | tensorboard | None
  save_per_updates: 50000  # save checkpoint per steps
  last_per_steps: 5000  # save last checkpoint per steps
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml ADDED

hydra:
  run:
    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
  name: Emilia_ZH_EN  # dataset name
  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame  # "frame" or "sample"
  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
  num_workers: 16

optim:
  epochs: 15
  learning_rate: 7.5e-5
  num_warmup_updates: 20000  # warmup steps
  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0  # gradient clipping
  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not

model:
  name: F5TTS_Base  # model name
  tokenizer: pinyin  # tokenizer type
  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
  arch:
    dim: 1024
    depth: 22
    heads: 16
    ff_mult: 2
    text_dim: 512
    conv_layers: 4
  mel_spec:
    target_sample_rate: 24000
    n_mel_channels: 100
    hop_length: 256
    win_length: 1024
    n_fft: 1024
    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
  vocoder:
    is_local: False  # use local offline ckpt or not
    local_path: None  # local vocoder path

ckpts:
  logger: wandb  # wandb | tensorboard | None
  save_per_updates: 50000  # save checkpoint per steps
  last_per_steps: 5000  # save last checkpoint per steps
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

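The `model.arch` block in `F5TTS_Base_train.yaml` above carries the same hyperparameters that `api.py` hard-codes for its "F5-TTS" branch (`dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4`). A small sketch that parses the relevant fragment and compares the two, assuming PyYAML is available (the embedded YAML is trimmed to the `arch` section):

```python
import yaml  # PyYAML

# Trimmed fragment of F5TTS_Base_train.yaml (see the full file above).
fragment = """
model:
  name: F5TTS_Base
  arch:
    dim: 1024
    depth: 22
    heads: 16
    ff_mult: 2
    text_dim: 512
    conv_layers: 4
"""

arch = yaml.safe_load(fragment)["model"]["arch"]

# The DiT config used by F5TTS.load_ema_model() in api.py for model_type="F5-TTS".
api_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

assert arch == api_model_cfg
print("training-config arch matches the inference-time model_cfg:", arch)
```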
xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml ADDED

hydra:
  run:
    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
  name: Emilia_ZH_EN
  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame  # "frame" or "sample"
  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
  num_workers: 16

optim:
  epochs: 15
  learning_rate: 7.5e-5
  num_warmup_updates: 20000  # warmup steps
  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0  # gradient clipping
  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not

model:
  name: F5TTS_Small
  tokenizer: pinyin
  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
  arch:
    dim: 768
    depth: 18
    heads: 12
    ff_mult: 2
    text_dim: 512
    conv_layers: 4
  mel_spec:
    target_sample_rate: 24000
    n_mel_channels: 100
    hop_length: 256
    win_length: 1024
    n_fft: 1024
    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
  vocoder:
    is_local: False  # use local offline ckpt or not
    local_path: None  # local vocoder path

ckpts:
  logger: wandb  # wandb | tensorboard | None
  save_per_updates: 50000  # save checkpoint per steps
  last_per_steps: 5000  # save last checkpoint per steps
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

xinference/thirdparty/f5_tts/eval/README.md ADDED

# Evaluation

Install packages for evaluation:

```bash
pip install -e .[eval]
```

## Generating Samples for Evaluation

### Prepare Test Datasets

1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
3. Unzip the downloaded datasets and place them in the `data/` directory.
4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py`
5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`

### Batch Inference for Test Set

To run batch inference for evaluations, execute the following commands:

```bash
# batch inference for evaluations
accelerate config  # if not set before
bash src/f5_tts/eval/eval_infer_batch.sh
```

## Objective Evaluation on Generated Results

### Download Evaluation Model Checkpoints

1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).

Then update in the following scripts with the paths you put evaluation model ckpts to.

### Objective Evaluation

Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
```bash
# Evaluation for Seed-TTS test set
python src/f5_tts/eval/eval_seedtts_testset.py --gen_wav_dir <GEN_WAVE_DIR>

# Evaluation for LibriSpeech-PC test-clean (cross-sentence)
python src/f5_tts/eval/eval_librispeech_test_clean.py --gen_wav_dir <GEN_WAVE_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
```