xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +60 -44
- xinference/model/audio/chattts.py +25 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/cosyvoice.py +4 -3
- xinference/model/audio/custom.py +4 -5
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +25 -1
- xinference/model/embedding/custom.py +4 -5
- xinference/model/flexible/core.py +5 -1
- xinference/model/image/custom.py +4 -5
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +66 -3
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +7 -6
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/sglang/core.py +7 -1
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +3 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +13 -1
- xinference/model/llm/vllm/core.py +1 -34
- xinference/model/rerank/custom.py +4 -5
- xinference/model/utils.py +41 -1
- xinference/model/video/core.py +3 -1
- xinference/model/video/diffusers.py +41 -38
- xinference/model/video/model_spec.json +24 -1
- xinference/model/video/model_spec_modelscope.json +25 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/matcha/__init__.py +0 -0
- xinference/thirdparty/matcha/app.py +357 -0
- xinference/thirdparty/matcha/cli.py +419 -0
- xinference/thirdparty/matcha/data/__init__.py +0 -0
- xinference/thirdparty/matcha/data/components/__init__.py +0 -0
- xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
- xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
- xinference/thirdparty/matcha/hifigan/config.py +28 -0
- xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
- xinference/thirdparty/matcha/hifigan/env.py +17 -0
- xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
- xinference/thirdparty/matcha/hifigan/models.py +368 -0
- xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
- xinference/thirdparty/matcha/models/__init__.py +0 -0
- xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
- xinference/thirdparty/matcha/models/components/__init__.py +0 -0
- xinference/thirdparty/matcha/models/components/decoder.py +443 -0
- xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
- xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
- xinference/thirdparty/matcha/models/components/transformer.py +316 -0
- xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
- xinference/thirdparty/matcha/onnx/__init__.py +0 -0
- xinference/thirdparty/matcha/onnx/export.py +181 -0
- xinference/thirdparty/matcha/onnx/infer.py +168 -0
- xinference/thirdparty/matcha/text/__init__.py +53 -0
- xinference/thirdparty/matcha/text/cleaners.py +121 -0
- xinference/thirdparty/matcha/text/numbers.py +71 -0
- xinference/thirdparty/matcha/text/symbols.py +17 -0
- xinference/thirdparty/matcha/train.py +122 -0
- xinference/thirdparty/matcha/utils/__init__.py +5 -0
- xinference/thirdparty/matcha/utils/audio.py +82 -0
- xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
- xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
- xinference/thirdparty/matcha/utils/instantiators.py +56 -0
- xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
- xinference/thirdparty/matcha/utils/model.py +90 -0
- xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
- xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
- xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
- xinference/thirdparty/matcha/utils/pylogger.py +21 -0
- xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
- xinference/thirdparty/matcha/utils/utils.py +259 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import click
|
|
2
|
+
import torch
|
|
3
|
+
from loguru import logger
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@click.command()
@click.argument("model_path")
@click.argument("output_path")
def main(model_path, output_path):
    """Extract the bare ``state_dict`` from a training checkpoint.

    Loads MODEL_PATH (a checkpoint dict containing a ``state_dict`` key, e.g.
    a PyTorch Lightning checkpoint) and writes only the weights to
    OUTPUT_PATH, dropping optimizer state and other training metadata.
    """
    from pathlib import Path

    # Compare resolved paths so "./model.ckpt" vs "model.ckpt" (or a symlink
    # to the same file) is also rejected, not just byte-identical strings.
    if Path(model_path).resolve() == Path(output_path).resolve():
        logger.error("Model path and output path are the same")
        return

    logger.info(f"Loading model from {model_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only run this on
    # checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
    torch.save(state_dict, output_path)
    logger.info(f"Model saved to {output_path}")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import Union
|
|
3
|
+
|
|
4
|
+
from loguru import logger
|
|
5
|
+
from natsort import natsorted
|
|
6
|
+
|
|
7
|
+
# Audio file suffixes (lowercase, with leading dot) — presumably passed as
# the `extensions` filter to list_files() when scanning datasets; confirm at
# call sites.
AUDIO_EXTENSIONS = {
    ".mp3",
    ".wav",
    ".flac",
    ".ogg",
    ".m4a",
    ".wma",
    ".aac",
    ".aiff",
    ".aif",
    ".aifc",
}

# Video file suffixes, same presumed use as AUDIO_EXTENSIONS.
VIDEO_EXTENSIONS = {
    ".mp4",
    ".avi",
}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def list_files(
    path: Union[Path, str],
    extensions: set[str] = None,
    recursive: bool = False,
    sort: bool = True,
) -> list[Path]:
    """List files in a directory.

    Args:
        path (Path): Path to the directory.
        extensions (set, optional): Extensions to filter (e.g. {".wav"}).
            Defaults to None, which matches every regular file.
        recursive (bool, optional): Whether to search recursively. Defaults to False.
        sort (bool, optional): Whether to sort the files naturally. Defaults to True.

    Returns:
        list: List of files.

    Raises:
        FileNotFoundError: If *path* does not exist.
    """

    if isinstance(path, str):
        path = Path(path)

    if not path.exists():
        raise FileNotFoundError(f"Directory {path} does not exist.")

    # Fix 1: honor the `recursive` flag — it was previously ignored and the
    # search was always recursive via rglob.
    glob = path.rglob if recursive else path.glob

    # Fix 2: the documented default `extensions=None` used to raise TypeError
    # when iterated; treat it as "no filter" instead.
    if extensions is None:
        files = [file for file in glob("*") if file.is_file()]
    else:
        files = [file for ext in extensions for file in glob(f"*{ext}")]

    if sort:
        files = natsorted(files)

    return files
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
|
|
59
|
+
"""
|
|
60
|
+
Load a Bert-VITS2 style filelist.
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
files = set()
|
|
64
|
+
results = []
|
|
65
|
+
count_duplicated, count_not_found = 0, 0
|
|
66
|
+
|
|
67
|
+
LANGUAGE_TO_LANGUAGES = {
|
|
68
|
+
"zh": ["zh", "en"],
|
|
69
|
+
"jp": ["jp", "en"],
|
|
70
|
+
"en": ["en"],
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
with open(path, "r", encoding="utf-8") as f:
|
|
74
|
+
for line in f.readlines():
|
|
75
|
+
splits = line.strip().split("|", maxsplit=3)
|
|
76
|
+
if len(splits) != 4:
|
|
77
|
+
logger.warning(f"Invalid line: {line}")
|
|
78
|
+
continue
|
|
79
|
+
|
|
80
|
+
filename, speaker, language, text = splits
|
|
81
|
+
file = Path(filename)
|
|
82
|
+
language = language.strip().lower()
|
|
83
|
+
|
|
84
|
+
if language == "ja":
|
|
85
|
+
language = "jp"
|
|
86
|
+
|
|
87
|
+
assert language in ["zh", "jp", "en"], f"Invalid language {language}"
|
|
88
|
+
languages = LANGUAGE_TO_LANGUAGES[language]
|
|
89
|
+
|
|
90
|
+
if file in files:
|
|
91
|
+
logger.warning(f"Duplicated file: {file}")
|
|
92
|
+
count_duplicated += 1
|
|
93
|
+
continue
|
|
94
|
+
|
|
95
|
+
if not file.exists():
|
|
96
|
+
logger.warning(f"File not found: {file}")
|
|
97
|
+
count_not_found += 1
|
|
98
|
+
continue
|
|
99
|
+
|
|
100
|
+
results.append((file, speaker, languages, text))
|
|
101
|
+
|
|
102
|
+
if count_duplicated > 0:
|
|
103
|
+
logger.warning(f"Total duplicated files: {count_duplicated}")
|
|
104
|
+
|
|
105
|
+
if count_not_found > 0:
|
|
106
|
+
logger.warning(f"Total files not found: {count_not_found}")
|
|
107
|
+
|
|
108
|
+
return results
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def scan_folder(base_path):
    """Index reference audio (.wav) and transcript (.lab) files.

    Expects a layout of ``base_path/<character>/<emotion>/.../file.{wav,lab}``
    and returns a nested mapping ``{character: {emotion: [filenames]}}``.
    Files nested fewer than three levels deep are ignored.
    """
    wav_lab_pairs = {}

    base = Path(base_path)
    for suf in ["wav", "lab"]:
        for f in base.rglob(f"*.{suf}"):
            # Removed leftover debug print(parts) that spammed stdout.
            parts = f.relative_to(base).parts
            # Only files at least character/emotion/... deep are indexed.
            if len(parts) >= 3:
                character, emotion = parts[0], parts[1]
                wav_lab_pairs.setdefault(character, {}).setdefault(
                    emotion, []
                ).append(str(f.name))

    return wav_lab_pairs
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def save_to_json(data, output_file):
    """Serialize *data* to *output_file* as pretty-printed UTF-8 JSON."""
    payload = json.dumps(data, ensure_ascii=False, indent=2)
    with open(output_file, "w", encoding="utf-8") as sink:
        sink.write(payload)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
base_path = "ref_data"
|
|
33
|
+
out_ref_file = "ref_data.json"
|
|
34
|
+
|
|
35
|
+
wav_lab_pairs = scan_folder(base_path)
|
|
36
|
+
save_to_json(wav_lab_pairs, out_ref_file)
|
|
File without changes
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from functools import partial
|
|
6
|
+
from multiprocessing import Pool
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
import click
|
|
10
|
+
import numpy as np
|
|
11
|
+
from loguru import logger
|
|
12
|
+
from tqdm import tqdm
|
|
13
|
+
|
|
14
|
+
from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
|
|
15
|
+
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
|
|
16
|
+
from fish_speech.utils.file import load_filelist
|
|
17
|
+
|
|
18
|
+
# To avoid CPU overload
|
|
19
|
+
os.environ["MKL_NUM_THREADS"] = "1"
|
|
20
|
+
os.environ["OMP_NUM_THREADS"] = "1"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def task_generator_folder(root: Path, text_extension: str):
    """Yield per-directory groups of semantic-token files from a dataset folder.

    Scans *root* recursively for ``.npy`` files and pairs each with the text
    file(s) sharing its stem. Entries are grouped by parent directory; the
    directory name serves as the speaker label.

    Args:
        root: Dataset root to scan.
        text_extension: A single extension (e.g. ".txt") or — despite the
            ``str`` annotation — any iterable of extensions (the isinstance
            check below handles both; the CLI's ``multiple=True`` option
            passes a tuple).

    Yields:
        ``(speaker, [(file, texts), ...], "folder")`` tuples, one per parent
        directory.
    """
    files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
    # Sort for a deterministic grouping and iteration order.
    files = sorted(files)

    grouped_files = defaultdict(list)
    for file in tqdm(files, desc=f"Grouping {root}"):
        p = str(file.parent)  # group key: full parent path
        speaker = file.parent.name  # speaker label: directory name only

        try:
            if isinstance(text_extension, str):
                texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
            else:
                texts = [
                    file.with_suffix(ext).read_text(encoding="utf-8")
                    for ext in text_extension
                ]
        except Exception as e:
            # Missing or undecodable transcript: skip this sample, keep going.
            logger.error(f"Failed to read text {file}: {e}")
            continue

        grouped_files[p].append((speaker, file, texts))

    logger.info(
        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
    )

    for i in grouped_files.values():
        subset = [(f, t) for _, f, t in i]
        # Every entry in a group shares the same parent directory, so the
        # first entry's speaker stands for the whole group.
        yield i[0][0], subset, "folder"
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def task_generator_filelist(filelist):
    """Yield ``(speaker, [(path, [text]), ...], "filelist")`` groups.

    Reads a Bert-VITS2 style filelist and buckets its entries by speaker.
    """
    by_speaker = defaultdict(list)
    for filename, speaker, _, text in load_filelist(filelist):
        by_speaker[speaker].append((Path(filename), [text]))

    logger.info(f"Found {len(by_speaker)} groups in {filelist}")
    yield from (
        (speaker, entries, "filelist") for speaker, entries in by_speaker.items()
    )
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def run_task(task):
    """Build one packed protobuf payload for a (name, subset, source) task.

    Args:
        task: ``(name, subset, source)`` where *subset* is a list of
            ``(file, texts)`` pairs produced by the task generators above.

    Returns:
        The bytes of a ``TextData`` message packed via ``pack_pb_stream``.
    """
    name, subset, source = task

    # Parse the files
    sentences = []
    for file, texts in subset:
        # Semantic tokens are stored next to the source file as "<stem>.npy".
        np_file = file.with_suffix(".npy")
        if np_file.exists() is False:
            logger.warning(f"Can't find {np_file}")
            continue

        new_texts = []

        for text in texts:
            # Simple cleaning: replace { xxx } and < xxx > with space
            text = re.sub(r"\{.*?\}", " ", text)
            text = re.sub(r"<.*?>", " ", text)
            # Collapse whitespace runs left over from the removals above.
            text = re.sub(r"\s+", " ", text)
            new_texts.append(text)

        try:
            semantics = np.load(np_file)
        except Exception as e:
            # Corrupt .npy: skip this sample but keep the rest of the group.
            logger.error(f"Failed to parse {file}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        sentences.append(
            Sentence(
                texts=new_texts,
                # One Semantics message per row of the token array.
                semantics=[Semantics(values=s) for s in semantics],
            )
        )

    # Pack the sentences
    return pack_pb_stream(
        TextData(
            source=source,
            name=name,
            sentences=sentences,
        )
    )
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
@click.option(
    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
)
def main(input, output, num_workers, text_extension, shard_size):
    """Pack .npy semantic tokens and transcripts into sharded .protos files.

    Each --input may be a dataset folder or a filelist; packed protobuf
    streams are written to --output as numbered shards of roughly
    --shard-size MB each.
    """
    generator_fns = []

    for f in input:
        assert f.exists(), f"{f} not found"

        # Folders are scanned recursively; anything else is read as a filelist.
        if f.is_dir():
            generator_fn = task_generator_folder(f, text_extension)
        else:
            generator_fn = task_generator_filelist(f)

        generator_fns.append(generator_fn)

    generator_fn = itertools.chain(*generator_fns)
    output.mkdir(parents=True, exist_ok=True)

    dataset_fp = None  # current shard file handle, opened lazily
    tar_idx = 0  # index of the shard currently being written
    written_size = 0  # bytes written into the current shard so far

    with Pool(num_workers) as p:
        # imap_unordered: shard contents depend on worker completion order.
        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
            if dataset_fp is None:
                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")

            dataset_fp.write(result)
            written_size += len(result)

            # Roll over to a new shard once the size budget is exceeded.
            if written_size > shard_size * 1024 * 1024:
                logger.info(f"Finished writing {tar_idx} shards to {output}")
                dataset_fp.close()
                dataset_fp = None
                written_size = 0
                tar_idx += 1

    if dataset_fp is not None:
        dataset_fp.close()

    # NOTE(review): if the run ended exactly on a shard boundary, only
    # tar_idx shards exist but this reports tar_idx + 1 — confirm intent.
    logger.info(f"Finished writing {tar_idx + 1} shards to {output}")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
# import pyrootutils
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from matplotlib import pyplot as plt
|
|
5
|
+
from transformers import AutoTokenizer
|
|
6
|
+
|
|
7
|
+
# register eval resolver and root
|
|
8
|
+
# pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
9
|
+
|
|
10
|
+
from torch.utils.data import DataLoader
|
|
11
|
+
|
|
12
|
+
from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
|
|
13
|
+
from tools.llama.generate import load_model
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def smooth(
    scalars: list[float], weight: float
) -> list[float]:  # Weight between 0 and 1
    """Exponential moving average, as used by TensorBoard's smoothing slider.

    Args:
        scalars: Raw values, in timestep order.
        weight: Smoothing factor in [0, 1]; 0 returns the input values
            unchanged, values near 1 smooth heavily.

    Returns:
        A list the same length as *scalars* holding the smoothed values.
    """
    if not scalars:
        # Fix: the original indexed scalars[0] unconditionally and raised
        # IndexError on empty input; an empty series smooths to itself.
        return []

    last = scalars[0]  # seed the EMA with the first observation
    smoothed = []
    for point in scalars:
        smoothed_val = last * weight + (1 - weight) * point  # EMA step
        smoothed.append(smoothed_val)
        last = smoothed_val  # anchor the last smoothed value

    return smoothed
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@torch.inference_mode()
def analyze_one_model(loader, config, weight, max_length):
    """Measure per-frame semantic loss of one text2semantic checkpoint.

    Runs up to 10 batches from *loader* through the model named by *config*
    (weights loaded from *weight*) and accumulates the per-frame codebook
    cross-entropy, ignoring padded frames.

    Returns:
        ``(xs, ys, smoothed_ys)``: frame indices with at least one valid
        sample, their mean losses, and an EMA-smoothed copy (weight 0.95).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(
        config,
        weight,
        device,
        torch.bfloat16,
        max_length,
        compile=False,
    )[0]

    current_step = 0
    model.eval()

    # Running per-frame loss sum and the number of valid samples observed at
    # each frame position.
    semantic_loss_sum = torch.zeros(
        max_length,
        dtype=torch.float32,
        device=device,
    )
    counter = torch.zeros(
        max_length,
        dtype=torch.long,
        device=device,
    )

    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}

        labels = batch["labels"]
        outputs = model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )

        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        # Generate labels
        # base_loss is computed but only semantic_loss is accumulated below.
        base_loss = F.cross_entropy(
            token_logits.reshape(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
            reduction="none",
        )

        # Rows 1..num_codebooks of labels hold the codebook targets; .mT
        # moves codebooks to the last axis so the loss averages per frame.
        codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.reshape(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
            reduction="none",
        )

        base_loss = base_loss.reshape(labels[:, 0].shape)
        semantic_loss = semantic_loss.reshape(codebook_labels.shape)

        # Mean loss over codebooks for each frame.
        semantic_loss_frame = semantic_loss.mean(-1)
        # A frame counts as padding when every codebook label is -100.
        pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks

        for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
            semantic_loss_sum[~pad] += loss_sample[~pad]
            counter[~pad] += 1

        current_step += 1
        if current_step == 10:  # cap the evaluation at 10 batches
            break

    # NOTE(review): this moves the last batch's `semantic_loss` to CPU, but
    # the loop below reads `semantic_loss_sum` (still on `device`) —
    # `semantic_loss_sum = semantic_loss_sum.cpu()` was likely intended.
    semantic_loss = semantic_loss.cpu()
    counter = counter.cpu()
    xs, ys = [], []

    for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
        if count > 0:
            xs.append(i)
            ys.append((loss / count).item())  # for better loss visualization

    smoothed_ys = smooth(ys, 0.95)

    # Unload model
    del model
    torch.cuda.empty_cache()

    return xs, ys, smoothed_ys
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def main():
    """Plot per-frame semantic loss curves for several checkpoints.

    Builds one DataLoader over a fixed SFT subset, evaluates each
    (name, config, weight) triple with analyze_one_model, and saves the
    smoothed curves to ``semantic_loss.png``.
    """
    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    max_length = 4096

    ds = AutoAugTextDataset(
        ["data/protos/sft/云天河"],
        tokenizer=tokenizer,
        use_speaker=False,
        interactive_prob=1.0,
        max_length=max_length,
    )

    loader = DataLoader(
        ds,
        batch_size=8,
        collate_fn=TextDataCollator(tokenizer, max_length=max_length),
        num_workers=0,
        shuffle=False,  # keep batches identical across the compared models
    )

    plt.figure(figsize=(10, 5), dpi=200)

    plt.xlabel("Frame")
    plt.ylabel("Loss")
    plt.yscale("log")
    plt.title("Semantic Loss")
    plt.grid(which="both", axis="both")
    plt.xlim(0, max_length)

    # (legend label, model config name, checkpoint path) per curve.
    # NOTE(review): "pertrain" looks like a typo for "pretrain" — it appears
    # verbatim in the plot legend, so it is left unchanged here.
    tests = [
        (
            "pertrain-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
        ),
        (
            "sft-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
        ),
        (
            "sft-large",
            "dual_ar_2_codebook_large",
            "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
        ),
    ]

    for name, config, weight in tests:
        xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
        plt.plot(xs, smoothed_ys, label=name)

    plt.legend()
    plt.savefig("semantic_loss.png")


if __name__ == "__main__":
    main()
|