xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (194) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +15 -34
  3. xinference/client/restful/restful_client.py +2 -2
  4. xinference/core/chat_interface.py +45 -10
  5. xinference/core/image_interface.py +9 -0
  6. xinference/core/model.py +8 -5
  7. xinference/core/scheduler.py +1 -2
  8. xinference/core/worker.py +49 -42
  9. xinference/deploy/cmdline.py +2 -2
  10. xinference/deploy/test/test_cmdline.py +7 -7
  11. xinference/model/audio/chattts.py +24 -9
  12. xinference/model/audio/core.py +8 -2
  13. xinference/model/audio/fish_speech.py +228 -0
  14. xinference/model/audio/model_spec.json +8 -0
  15. xinference/model/embedding/core.py +23 -1
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +49 -1
  19. xinference/model/llm/__init__.py +26 -27
  20. xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
  21. xinference/model/llm/llm_family.json +606 -1266
  22. xinference/model/llm/llm_family.py +16 -139
  23. xinference/model/llm/llm_family_modelscope.json +276 -313
  24. xinference/model/llm/lmdeploy/__init__.py +0 -0
  25. xinference/model/llm/lmdeploy/core.py +557 -0
  26. xinference/model/llm/memory.py +9 -9
  27. xinference/model/llm/sglang/core.py +2 -2
  28. xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
  29. xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
  30. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  31. xinference/model/llm/{pytorch → transformers}/core.py +3 -10
  32. xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
  33. xinference/model/llm/transformers/intern_vl.py +540 -0
  34. xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
  35. xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
  36. xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
  37. xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
  38. xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
  39. xinference/model/llm/utils.py +85 -70
  40. xinference/model/llm/vllm/core.py +110 -11
  41. xinference/model/utils.py +1 -95
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/internvl/__init__.py +0 -0
  137. xinference/thirdparty/internvl/conversation.py +393 -0
  138. xinference/thirdparty/omnilmm/model/utils.py +16 -1
  139. xinference/web/ui/build/asset-manifest.json +3 -3
  140. xinference/web/ui/build/index.html +1 -1
  141. xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
  142. xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  144. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
  145. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
  146. xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
  147. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
  148. xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
  149. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
  150. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
  151. xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
  153. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
  154. xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
  155. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
  156. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
  157. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
  158. xinference/locale/utils.py +0 -39
  159. xinference/locale/zh_CN.json +0 -26
  160. xinference/model/llm/ggml/tools/__init__.py +0 -15
  161. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
  162. xinference/model/llm/ggml/tools/gguf.py +0 -884
  163. xinference/model/llm/pytorch/__init__.py +0 -13
  164. xinference/model/llm/pytorch/baichuan.py +0 -81
  165. xinference/model/llm/pytorch/falcon.py +0 -138
  166. xinference/model/llm/pytorch/intern_vl.py +0 -352
  167. xinference/model/llm/pytorch/vicuna.py +0 -69
  168. xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
  169. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
  170. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  171. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
  172. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
  173. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
  174. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
  175. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
  176. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
  177. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
  178. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
  179. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
  180. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
  181. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
  182. /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
  183. /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
  184. /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
  185. /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
  186. /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
  187. /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
  188. /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
  189. /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
  190. /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  191. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,21 @@
1
+ import click
2
+ import torch
3
+ from loguru import logger
4
+
5
+
6
@click.command()
@click.argument("model_path")
@click.argument("output_path")
def main(model_path, output_path):
    """Strip a training checkpoint down to its bare ``state_dict``.

    Loads the checkpoint at *model_path* onto the CPU and re-saves only its
    ``state_dict`` entry to *output_path*, dropping optimizer state and other
    trainer bookkeeping.
    """
    # Refuse to overwrite the input checkpoint in place.
    if model_path == output_path:
        logger.error("Model path and output path are the same")
        return

    logger.info(f"Loading model from {model_path}")
    checkpoint = torch.load(model_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    torch.save(state_dict, output_path)
    logger.info(f"Model saved to {output_path}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,108 @@
1
+ from pathlib import Path
2
+ from typing import Union
3
+
4
+ from loguru import logger
5
+ from natsort import natsorted
6
+
7
# Recognised audio file suffixes (lower-case, leading dot included).
AUDIO_EXTENSIONS = {
    ".aac",
    ".aif",
    ".aifc",
    ".aiff",
    ".flac",
    ".m4a",
    ".mp3",
    ".ogg",
    ".wav",
    ".wma",
}

# Recognised video file suffixes.
VIDEO_EXTENSIONS = {
    ".avi",
    ".mp4",
}
24
+
25
+
26
+ def list_files(
27
+ path: Union[Path, str],
28
+ extensions: set[str] = None,
29
+ recursive: bool = False,
30
+ sort: bool = True,
31
+ ) -> list[Path]:
32
+ """List files in a directory.
33
+
34
+ Args:
35
+ path (Path): Path to the directory.
36
+ extensions (set, optional): Extensions to filter. Defaults to None.
37
+ recursive (bool, optional): Whether to search recursively. Defaults to False.
38
+ sort (bool, optional): Whether to sort the files. Defaults to True.
39
+
40
+ Returns:
41
+ list: List of files.
42
+ """
43
+
44
+ if isinstance(path, str):
45
+ path = Path(path)
46
+
47
+ if not path.exists():
48
+ raise FileNotFoundError(f"Directory {path} does not exist.")
49
+
50
+ files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
51
+
52
+ if sort:
53
+ files = natsorted(files)
54
+
55
+ return files
56
+
57
+
58
+ def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
59
+ """
60
+ Load a Bert-VITS2 style filelist.
61
+ """
62
+
63
+ files = set()
64
+ results = []
65
+ count_duplicated, count_not_found = 0, 0
66
+
67
+ LANGUAGE_TO_LANGUAGES = {
68
+ "zh": ["zh", "en"],
69
+ "jp": ["jp", "en"],
70
+ "en": ["en"],
71
+ }
72
+
73
+ with open(path, "r", encoding="utf-8") as f:
74
+ for line in f.readlines():
75
+ splits = line.strip().split("|", maxsplit=3)
76
+ if len(splits) != 4:
77
+ logger.warning(f"Invalid line: {line}")
78
+ continue
79
+
80
+ filename, speaker, language, text = splits
81
+ file = Path(filename)
82
+ language = language.strip().lower()
83
+
84
+ if language == "ja":
85
+ language = "jp"
86
+
87
+ assert language in ["zh", "jp", "en"], f"Invalid language {language}"
88
+ languages = LANGUAGE_TO_LANGUAGES[language]
89
+
90
+ if file in files:
91
+ logger.warning(f"Duplicated file: {file}")
92
+ count_duplicated += 1
93
+ continue
94
+
95
+ if not file.exists():
96
+ logger.warning(f"File not found: {file}")
97
+ count_not_found += 1
98
+ continue
99
+
100
+ results.append((file, speaker, languages, text))
101
+
102
+ if count_duplicated > 0:
103
+ logger.warning(f"Total duplicated files: {count_duplicated}")
104
+
105
+ if count_not_found > 0:
106
+ logger.warning(f"Total files not found: {count_not_found}")
107
+
108
+ return results
@@ -0,0 +1,36 @@
1
+ import json
2
+ from pathlib import Path
3
+
4
+
5
def scan_folder(base_path):
    """Index reference ``.wav``/``.lab`` files under *base_path*.

    Expects a layout of ``base_path/<character>/<emotion>/.../<file>``;
    entries nested fewer than three levels deep are ignored.

    Returns:
        dict: ``{character: {emotion: [file names]}}``.
    """
    wav_lab_pairs = {}

    base = Path(base_path)
    for suf in ["wav", "lab"]:
        for f in base.rglob(f"*.{suf}"):
            relative_path = f.relative_to(base)
            parts = relative_path.parts
            # BUGFIX: removed a leftover debug `print(parts)` that polluted
            # stdout on every scanned file.
            # Need at least character/emotion/file to classify the entry.
            if len(parts) >= 3:
                character = parts[0]
                emotion = parts[1]

                # setdefault replaces the original nested membership checks.
                wav_lab_pairs.setdefault(character, {}).setdefault(
                    emotion, []
                ).append(str(f.name))

    return wav_lab_pairs
25
+
26
+
27
def save_to_json(data, output_file):
    """Serialise *data* to *output_file* as pretty-printed UTF-8 JSON."""
    with open(output_file, "w", encoding="utf-8") as fp:
        json.dump(data, fp, ensure_ascii=False, indent=2)
30
+
31
+
32
# BUGFIX: guard the script body so importing this module no longer scans the
# disk and writes ref_data.json as a side effect.
if __name__ == "__main__":
    base_path = "ref_data"
    out_ref_file = "ref_data.json"

    wav_lab_pairs = scan_folder(base_path)
    save_to_json(wav_lab_pairs, out_ref_file)
@@ -0,0 +1,169 @@
1
+ import itertools
2
+ import os
3
+ import re
4
+ from collections import defaultdict
5
+ from functools import partial
6
+ from multiprocessing import Pool
7
+ from pathlib import Path
8
+
9
+ import click
10
+ import numpy as np
11
+ from loguru import logger
12
+ from tqdm import tqdm
13
+
14
+ from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
15
+ from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
16
+ from fish_speech.utils.file import load_filelist
17
+
18
+ # To avoid CPU overload
19
+ os.environ["MKL_NUM_THREADS"] = "1"
20
+ os.environ["OMP_NUM_THREADS"] = "1"
21
+
22
+
23
def task_generator_folder(root: Path, text_extension: str):
    """Yield dataset groups discovered under a folder.

    Groups every ``*.npy`` file under *root* by its parent directory (the
    directory name is taken as the speaker) and reads the matching transcript
    file(s). *text_extension* may be a single suffix or an iterable of
    suffixes. Yields ``(speaker, [(npy_file, [texts]), ...], "folder")``.
    """
    files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
    files = sorted(files)

    grouped_files = defaultdict(list)
    for file in tqdm(files, desc=f"Grouping {root}"):
        group_key = str(file.parent)
        speaker = file.parent.name

        # Normalise to a list of suffixes (isinstance cannot raise, so it is
        # safe to hoist out of the try block).
        if isinstance(text_extension, str):
            suffixes = [text_extension]
        else:
            suffixes = text_extension

        try:
            texts = [
                file.with_suffix(suffix).read_text(encoding="utf-8")
                for suffix in suffixes
            ]
        except Exception as e:
            logger.error(f"Failed to read text {file}: {e}")
            continue

        grouped_files[group_key].append((speaker, file, texts))

    logger.info(
        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
    )

    for members in grouped_files.values():
        subset = [(f, t) for _, f, t in members]
        yield members[0][0], subset, "folder"
53
+
54
+
55
def task_generator_filelist(filelist):
    """Yield dataset groups from a Bert-VITS2 style filelist.

    Yields ``(speaker, [(file, [text]), ...], "filelist")`` per speaker.
    """
    by_speaker = defaultdict(list)
    for entry in load_filelist(filelist):
        filename, speaker = entry[0], entry[1]
        text = entry[3]
        by_speaker[speaker].append((Path(filename), [text]))

    logger.info(f"Found {len(by_speaker)} groups in {filelist}")
    for speaker, values in by_speaker.items():
        yield speaker, values, "filelist"
63
+
64
+
65
def run_task(task):
    """Parse one ``(name, subset, source)`` task into a packed TextData stream.

    For each ``(file, texts)`` pair the matching ``.npy`` semantics are
    loaded and the transcripts are lightly cleaned; missing or unreadable
    files are skipped with a warning.
    """
    name, subset, source = task

    # Parse the files
    sentences = []
    for file, texts in subset:
        np_file = file.with_suffix(".npy")
        if not np_file.exists():
            logger.warning(f"Can't find {np_file}")
            continue

        # Simple cleaning: replace { xxx } and < xxx > with space,
        # then collapse whitespace runs.
        new_texts = []
        for text in texts:
            for pattern in (r"\{.*?\}", r"<.*?>", r"\s+"):
                text = re.sub(pattern, " ", text)
            new_texts.append(text)

        try:
            semantics = np.load(np_file)
        except Exception as e:
            logger.error(f"Failed to parse {file}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        sentences.append(
            Sentence(
                texts=new_texts,
                semantics=[Semantics(values=s) for s in semantics],
            )
        )

    # Pack the sentences
    packed = TextData(
        source=source,
        name=name,
        sentences=sentences,
    )
    return pack_pb_stream(packed)
109
+
110
+
111
@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
@click.option(
    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
)
def main(input, output, num_workers, text_extension, shard_size):
    """Pack one or more datasets (folders or filelists) into size-capped
    ``NNNNNNNN.protos`` shards under *output*."""
    generator_fns = []

    for f in input:
        assert f.exists(), f"{f} not found"

        if f.is_dir():
            generator_fn = task_generator_folder(f, text_extension)
        else:
            generator_fn = task_generator_filelist(f)

        generator_fns.append(generator_fn)

    generator_fn = itertools.chain(*generator_fns)
    output.mkdir(parents=True, exist_ok=True)

    dataset_fp = None
    tar_idx = 0  # index of the shard currently being written
    written_size = 0  # bytes written to the current shard so far

    with Pool(num_workers) as p:
        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
            if dataset_fp is None:
                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")

            dataset_fp.write(result)
            written_size += len(result)

            # Roll over to a new shard once the size cap is exceeded.
            if written_size > shard_size * 1024 * 1024:
                dataset_fp.close()
                dataset_fp = None
                written_size = 0
                tar_idx += 1
                # BUGFIX: log after the increment — the original logged the
                # pre-increment index, reporting "0 shards" for the first one.
                logger.info(f"Finished writing {tar_idx} shards to {output}")

    if dataset_fp is not None:
        dataset_fp.close()
        tar_idx += 1

    # BUGFIX: report the actual shard count — the original always logged
    # tar_idx + 1, over-counting by one whenever the last shard had been
    # closed exactly at the size cap.
    logger.info(f"Finished writing {tar_idx} shards to {output}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,171 @@
1
+ # import pyrootutils
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from matplotlib import pyplot as plt
5
+ from transformers import AutoTokenizer
6
+
7
+ # register eval resolver and root
8
+ # pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
9
+
10
+ from torch.utils.data import DataLoader
11
+
12
+ from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
13
+ from tools.llama.generate import load_model
14
+
15
+
16
def smooth(
    scalars: list[float], weight: float
) -> list[float]:  # Weight between 0 and 1
    """TensorBoard-style exponential moving average smoothing.

    Args:
        scalars: Raw values, one per timestep.
        weight: Smoothing factor in [0, 1]; 0 returns the input unchanged,
            values close to 1 smooth more aggressively.

    Returns:
        A list the same length as *scalars* containing the smoothed values.
    """
    # BUGFIX: the original indexed scalars[0] unconditionally and raised
    # IndexError on an empty input.
    if not scalars:
        return []

    last = scalars[0]  # First value in the plot (first timestep)
    smoothed = list()
    for point in scalars:
        smoothed_val = last * weight + (1 - weight) * point  # Calculate smoothed value
        smoothed.append(smoothed_val)  # Save it
        last = smoothed_val  # Anchor the last smoothed value

    return smoothed
27
+
28
+
29
@torch.inference_mode()
def analyze_one_model(loader, config, weight, max_length):
    """Measure per-frame semantic loss for a single checkpoint.

    Runs up to 10 batches from *loader* through the model described by
    *config*/*weight*, accumulating the mean codebook cross-entropy at each
    frame position while ignoring padded positions.

    Returns:
        (xs, ys, smoothed_ys): frame indices that had at least one
        observation, their average loss, and an EMA-smoothed copy of ys.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(
        config,
        weight,
        device,
        torch.bfloat16,
        max_length,
        compile=False,
    )[0]

    current_step = 0
    model.eval()

    # Running per-frame loss total and number of non-pad observations.
    semantic_loss_sum = torch.zeros(
        max_length,
        dtype=torch.float32,
        device=device,
    )
    counter = torch.zeros(
        max_length,
        dtype=torch.long,
        device=device,
    )

    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}

        labels = batch["labels"]
        outputs = model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )

        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        # Generate labels
        base_loss = F.cross_entropy(
            token_logits.reshape(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
            reduction="none",
        )

        codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.reshape(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
            reduction="none",
        )

        base_loss = base_loss.reshape(labels[:, 0].shape)
        semantic_loss = semantic_loss.reshape(codebook_labels.shape)

        semantic_loss_frame = semantic_loss.mean(-1)
        # A frame is padding when every codebook label equals -100.
        pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks

        for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
            semantic_loss_sum[~pad] += loss_sample[~pad]
            counter[~pad] += 1

        current_step += 1
        if current_step == 10:
            break

    # BUGFIX: the original moved `semantic_loss` (the last batch's losses,
    # unused below) to the CPU and left `semantic_loss_sum` on `device`;
    # on CUDA the `loss / count` below then mixed devices and failed.
    semantic_loss_sum = semantic_loss_sum.cpu()
    counter = counter.cpu()
    xs, ys = [], []

    for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
        if count > 0:
            xs.append(i)
            ys.append((loss / count).item())  # for better loss visualization

    smoothed_ys = smooth(ys, 0.95)

    # Unload model
    del model
    torch.cuda.empty_cache()

    return xs, ys, smoothed_ys
113
+
114
+
115
def main():
    """Plot smoothed per-frame semantic loss for several checkpoints.

    Builds a fixed evaluation dataset, evaluates each configured checkpoint
    with :func:`analyze_one_model`, and saves the comparison plot to
    ``semantic_loss.png``.
    """
    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    max_length = 4096

    dataset = AutoAugTextDataset(
        ["data/protos/sft/云天河"],
        tokenizer=tokenizer,
        use_speaker=False,
        interactive_prob=1.0,
        max_length=max_length,
    )

    loader = DataLoader(
        dataset,
        batch_size=8,
        collate_fn=TextDataCollator(tokenizer, max_length=max_length),
        num_workers=0,
        shuffle=False,
    )

    # Set up the figure before plotting any curves.
    plt.figure(figsize=(10, 5), dpi=200)
    plt.xlabel("Frame")
    plt.ylabel("Loss")
    plt.yscale("log")
    plt.title("Semantic Loss")
    plt.grid(which="both", axis="both")
    plt.xlim(0, max_length)

    # (label, model config name, checkpoint path)
    checkpoints = [
        (
            "pertrain-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
        ),
        (
            "sft-medium",
            "dual_ar_2_codebook_medium",
            "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
        ),
        (
            "sft-large",
            "dual_ar_2_codebook_large",
            "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
        ),
    ]

    for label, config, weight in checkpoints:
        xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
        plt.plot(xs, smoothed_ys, label=label)

    plt.legend()
    plt.savefig("semantic_loss.png")


if __name__ == "__main__":
    main()