xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (191)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +60 -44
  6. xinference/model/audio/chattts.py +25 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/cosyvoice.py +4 -3
  9. xinference/model/audio/custom.py +4 -5
  10. xinference/model/audio/fish_speech.py +228 -0
  11. xinference/model/audio/model_spec.json +8 -0
  12. xinference/model/embedding/core.py +25 -1
  13. xinference/model/embedding/custom.py +4 -5
  14. xinference/model/flexible/core.py +5 -1
  15. xinference/model/image/custom.py +4 -5
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +66 -3
  19. xinference/model/llm/__init__.py +6 -0
  20. xinference/model/llm/llm_family.json +54 -9
  21. xinference/model/llm/llm_family.py +7 -6
  22. xinference/model/llm/llm_family_modelscope.json +56 -10
  23. xinference/model/llm/lmdeploy/__init__.py +0 -0
  24. xinference/model/llm/lmdeploy/core.py +557 -0
  25. xinference/model/llm/sglang/core.py +7 -1
  26. xinference/model/llm/transformers/cogvlm2.py +4 -45
  27. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  28. xinference/model/llm/transformers/core.py +3 -0
  29. xinference/model/llm/transformers/glm4v.py +2 -23
  30. xinference/model/llm/transformers/intern_vl.py +94 -11
  31. xinference/model/llm/transformers/minicpmv25.py +2 -23
  32. xinference/model/llm/transformers/minicpmv26.py +2 -22
  33. xinference/model/llm/transformers/yi_vl.py +2 -24
  34. xinference/model/llm/utils.py +13 -1
  35. xinference/model/llm/vllm/core.py +1 -34
  36. xinference/model/rerank/custom.py +4 -5
  37. xinference/model/utils.py +41 -1
  38. xinference/model/video/core.py +3 -1
  39. xinference/model/video/diffusers.py +41 -38
  40. xinference/model/video/model_spec.json +24 -1
  41. xinference/model/video/model_spec_modelscope.json +25 -1
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/matcha/__init__.py +0 -0
  137. xinference/thirdparty/matcha/app.py +357 -0
  138. xinference/thirdparty/matcha/cli.py +419 -0
  139. xinference/thirdparty/matcha/data/__init__.py +0 -0
  140. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  141. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  142. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  143. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  144. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  145. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  146. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  147. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  148. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  149. xinference/thirdparty/matcha/models/__init__.py +0 -0
  150. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  151. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  152. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  153. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  154. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  155. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  156. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  157. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  158. xinference/thirdparty/matcha/onnx/export.py +181 -0
  159. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  160. xinference/thirdparty/matcha/text/__init__.py +53 -0
  161. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  162. xinference/thirdparty/matcha/text/numbers.py +71 -0
  163. xinference/thirdparty/matcha/text/symbols.py +17 -0
  164. xinference/thirdparty/matcha/train.py +122 -0
  165. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  166. xinference/thirdparty/matcha/utils/audio.py +82 -0
  167. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  168. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  169. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  170. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  171. xinference/thirdparty/matcha/utils/model.py +90 -0
  172. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  173. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  174. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  175. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  176. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  177. xinference/thirdparty/matcha/utils/utils.py +259 -0
  178. xinference/web/ui/build/asset-manifest.json +3 -3
  179. xinference/web/ui/build/index.html +1 -1
  180. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  181. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  182. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  183. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
  184. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
  185. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  186. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  187. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  188. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  189. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  190. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  191. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/xinference/thirdparty/fish_speech/fish_speech/webui/manage.py
@@ -0,0 +1,1237 @@
+ from __future__ import annotations
+
+ import datetime
+ import html
+ import json
+ import os
+ import platform
+ import shutil
+ import signal
+ import subprocess
+ import sys
+ from pathlib import Path
+
+ import gradio as gr
+ import psutil
+ import yaml
+ from loguru import logger
+ from tqdm import tqdm
+
+ PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
+ sys.path.insert(0, "")
+ print(sys.path)
+ cur_work_dir = Path(os.getcwd()).resolve()
+ print("You are in ", str(cur_work_dir))
+
+ from fish_speech.i18n import i18n
+ from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
+
+ config_path = cur_work_dir / "fish_speech" / "configs"
+ vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
+ llama_yml_path = config_path / "text2semantic_finetune.yaml"
+
+ env = os.environ.copy()
+ env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
+
+ seafoam = Seafoam()
+
+
+ def build_html_error_message(error):
+     return f"""
+     <div style="color: red; font-weight: bold;">
+         {html.escape(error)}
+     </div>
+     """
+
+
+ def build_html_ok_message(msg):
+     return f"""
+     <div style="color: green; font-weight: bold;">
+         {html.escape(msg)}
+     </div>
+     """
+
+
+ def build_html_href(link, desc, msg):
+     return f"""
+     <span style="color: green; font-weight: bold; display: inline-block">
+         {html.escape(msg)}
+         <a href="{link}">{desc}</a>
+     </span>
+     """
+
+
+ def load_data_in_raw(path):
+     with open(path, "r", encoding="utf-8") as file:
+         data = file.read()
+     return str(data)
+
+
+ def kill_proc_tree(pid, including_parent=True):
+     try:
+         parent = psutil.Process(pid)
+     except psutil.NoSuchProcess:
+         # Process already terminated
+         return
+
+     children = parent.children(recursive=True)
+     for child in children:
+         try:
+             os.kill(child.pid, signal.SIGTERM)  # or signal.SIGKILL
+         except OSError:
+             pass
+     if including_parent:
+         try:
+             os.kill(parent.pid, signal.SIGTERM)  # or signal.SIGKILL
+         except OSError:
+             pass
+
+
+ system = platform.system()
+ p_label = None
+ p_infer = None
+ p_tensorboard = None
+
+
+ def kill_process(pid):
+     if system == "Windows":
+         cmd = "taskkill /t /f /pid %s" % pid
+         # os.system(cmd)
+         subprocess.run(cmd)
+     else:
+         kill_proc_tree(pid)
+
+
+ def change_label(if_label):
+     global p_label
+     if if_label == True and p_label is None:
+         url = "http://localhost:3000"
+         remote_url = "https://text-labeler.pages.dev/"
+         try:
+             p_label = subprocess.Popen(
+                 [
+                     (
+                         "asr-label-linux-x64"
+                         if sys.platform == "linux"
+                         else "asr-label-win-x64.exe"
+                     )
+                 ]
+             )
+         except FileNotFoundError:
+             logger.warning("asr-label execution not found!")
+
+         yield build_html_href(
+             link=remote_url,
+             desc=i18n("Optional online ver"),
+             msg=i18n("Opened labeler in browser"),
+         )
+
+     elif if_label == False and p_label is not None:
+         kill_process(p_label.pid)
+         p_label = None
+         yield build_html_ok_message("Nothing")
+
+
+ def clean_infer_cache():
+     import tempfile
+
+     temp_dir = Path(tempfile.gettempdir())
+     gradio_dir = str(temp_dir / "gradio")
+     try:
+         shutil.rmtree(gradio_dir)
+         logger.info(f"Deleted cached audios: {gradio_dir}")
+     except PermissionError:
+         logger.info(f"Permission denied: Unable to delete {gradio_dir}")
+     except FileNotFoundError:
+         logger.info(f"{gradio_dir} was not found")
+     except Exception as e:
+         logger.info(f"An error occurred: {e}")
+
+
+ def change_infer(
+     if_infer,
+     host,
+     port,
+     infer_decoder_model,
+     infer_decoder_config,
+     infer_llama_model,
+     infer_compile,
+ ):
+     global p_infer
+     if if_infer == True and p_infer == None:
+         env = os.environ.copy()
+
+         env["GRADIO_SERVER_NAME"] = host
+         env["GRADIO_SERVER_PORT"] = port
+         # Start the second process
+         url = f"http://{host}:{port}"
+         yield build_html_ok_message(
+             i18n("Inferring interface is launched at {}").format(url)
+         )
+
+         clean_infer_cache()
+
+         p_infer = subprocess.Popen(
+             [
+                 PYTHON,
+                 "tools/webui.py",
+                 "--decoder-checkpoint-path",
+                 infer_decoder_model,
+                 "--decoder-config-name",
+                 infer_decoder_config,
+                 "--llama-checkpoint-path",
+                 infer_llama_model,
+             ]
+             + (["--compile"] if infer_compile == "Yes" else []),
+             env=env,
+         )
+
+     elif if_infer == False and p_infer is not None:
+         kill_process(p_infer.pid)
+         p_infer = None
+         yield build_html_error_message(i18n("Infer interface is closed"))
+
+
+ js = load_data_in_raw("fish_speech/webui/js/animate.js")
+ css = load_data_in_raw("fish_speech/webui/css/style.css")
+
+ data_pre_output = (cur_work_dir / "data").resolve()
+ default_model_output = (cur_work_dir / "results").resolve()
+ default_filelist = data_pre_output / "detect.list"
+ data_pre_output.mkdir(parents=True, exist_ok=True)
+
+ items = []
+ dict_items = {}
+
+
+ def load_yaml_data_in_fact(yml_path):
+     with open(yml_path, "r", encoding="utf-8") as file:
+         yml = yaml.safe_load(file)
+     return yml
+
+
+ def write_yaml_data_in_fact(yml, yml_path):
+     with open(yml_path, "w", encoding="utf-8") as file:
+         yaml.safe_dump(yml, file, allow_unicode=True)
+     return yml
+
+
+ def generate_tree(directory, depth=0, max_depth=None, prefix=""):
+     if max_depth is not None and depth > max_depth:
+         return ""
+
+     tree_str = ""
+     files = []
+     directories = []
+     for item in os.listdir(directory):
+         if os.path.isdir(os.path.join(directory, item)):
+             directories.append(item)
+         else:
+             files.append(item)
+
+     entries = directories + files
+     for i, entry in enumerate(entries):
+         connector = "├── " if i < len(entries) - 1 else "└── "
+         tree_str += f"{prefix}{connector}{entry}<br />"
+         if i < len(directories):
+             extension = "│   " if i < len(entries) - 1 else "    "
+             tree_str += generate_tree(
+                 os.path.join(directory, entry),
+                 depth + 1,
+                 max_depth,
+                 prefix=prefix + extension,
+             )
+     return tree_str
+
+
+ def new_explorer(data_path, max_depth):
+     return gr.Markdown(
+         elem_classes=["scrollable-component"],
+         value=generate_tree(data_path, max_depth=max_depth),
+     )
+
+
+ def add_item(
+     folder: str,
+     method: str,
+     label_lang: str,
+     if_initial_prompt: bool,
+     initial_prompt: str | None,
+ ):
+     folder = folder.strip(" ").strip('"')
+
+     folder_path = Path(folder)
+
+     if folder and folder not in items and data_pre_output not in folder_path.parents:
+         if folder_path.is_dir():
+             items.append(folder)
+             dict_items[folder] = dict(
+                 type="folder",
+                 method=method,
+                 label_lang=label_lang,
+                 initial_prompt=initial_prompt if if_initial_prompt else None,
+             )
+         elif folder:
+             err = folder
+             return gr.Checkboxgroup(choices=items), build_html_error_message(
+                 i18n("Invalid path: {}").format(err)
+             )
+
+     formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
+     logger.info("After Adding: " + formatted_data)
+     gr.Info(formatted_data)
+     return gr.Checkboxgroup(choices=items), build_html_ok_message(
+         i18n("Added path successfully!")
+     )
+
+
+ def remove_items(selected_items):
+     global items, dict_items
+     to_remove = [item for item in items if item in selected_items]
+     for item in to_remove:
+         del dict_items[item]
+     items = [item for item in items if item in dict_items.keys()]
+     formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
+     logger.info(formatted_data)
+     gr.Warning("After Removing: " + formatted_data)
+     return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
+         i18n("Removed path successfully!")
+     )
+
+
+ def show_selected(options):
+     selected_options = ", ".join(options)
+
+     if options:
+         return i18n("Selected: {}").format(selected_options)
+     else:
+         return i18n("No selected options")
+
+
+ from pydub import AudioSegment
+
+
+ def convert_to_mono_in_place(audio_path: Path):
+     audio = AudioSegment.from_file(audio_path)
+     if audio.channels > 1:
+         mono_audio = audio.set_channels(1)
+         mono_audio.export(audio_path, format=audio_path.suffix[1:])
+         logger.info(f"Convert {audio_path} successfully")
+
+
+ def list_copy(list_file_path, method):
+     wav_root = data_pre_output
+     lst = []
+     with list_file_path.open("r", encoding="utf-8") as file:
+         for line in tqdm(file, desc="Processing audio/transcript"):
+             wav_path, speaker_name, language, text = line.strip().split("|")
+             original_wav_path = Path(wav_path)
+             target_wav_path = (
+                 wav_root / original_wav_path.parent.name / original_wav_path.name
+             )
+             lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")
+             if target_wav_path.is_file():
+                 continue
+             target_wav_path.parent.mkdir(parents=True, exist_ok=True)
+             if method == i18n("Copy"):
+                 shutil.copy(original_wav_path, target_wav_path)
+             else:
+                 shutil.move(original_wav_path, target_wav_path.parent)
+             convert_to_mono_in_place(target_wav_path)
+             original_lab_path = original_wav_path.with_suffix(".lab")
+             target_lab_path = (
+                 wav_root
+                 / original_wav_path.parent.name
+                 / original_wav_path.with_suffix(".lab").name
+             )
+             if target_lab_path.is_file():
+                 continue
+             if method == i18n("Copy"):
+                 shutil.copy(original_lab_path, target_lab_path)
+             else:
+                 shutil.move(original_lab_path, target_lab_path.parent)
+
+     if method == i18n("Move"):
+         with list_file_path.open("w", encoding="utf-8") as file:
+             file.writelines("\n".join(lst))
+
+     del lst
+     return build_html_ok_message(i18n("Use filelist"))
+
+
+ def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
+     global dict_items
+     data_path = Path(data_path)
+     gr.Warning("Pre-processing begins...")
+     for item, content in dict_items.items():
+         item_path = Path(item)
+         tar_path = data_path / item_path.name
+
+         if content["type"] == "folder" and item_path.is_dir():
+             if content["method"] == i18n("Copy"):
+                 os.makedirs(tar_path, exist_ok=True)
+                 shutil.copytree(
+                     src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
+                 )
+             elif not tar_path.is_dir():
+                 shutil.move(src=str(item_path), dst=str(tar_path))
+
+             for suf in ["wav", "flac", "mp3"]:
+                 for audio_path in tar_path.glob(f"**/*.{suf}"):
+                     convert_to_mono_in_place(audio_path)
+
+             cur_lang = content["label_lang"]
+             initial_prompt = content["initial_prompt"]
+
+             transcribe_cmd = [
+                 PYTHON,
+                 "tools/whisper_asr.py",
+                 "--model-size",
+                 label_model,
+                 "--device",
+                 label_device,
+                 "--audio-dir",
+                 tar_path,
+                 "--save-dir",
+                 tar_path,
+                 "--language",
+                 cur_lang,
+             ]
+
+             if initial_prompt is not None:
+                 transcribe_cmd += ["--initial-prompt", initial_prompt]
+
+             if cur_lang != "IGNORE":
+                 try:
+                     gr.Warning("Begin To Transcribe")
+                     subprocess.run(
+                         transcribe_cmd,
+                         env=env,
+                     )
+                 except Exception:
+                     print("Transcription error occurred")
+
+         elif content["type"] == "file" and item_path.is_file():
+             list_copy(item_path, content["method"])
+
+     return build_html_ok_message(i18n("Move files successfully")), new_explorer(
+         data_path, max_depth=max_depth
+     )
+
+
+ def generate_folder_name():
+     now = datetime.datetime.now()
+     folder_name = now.strftime("%Y%m%d_%H%M%S")
+     return folder_name
+
+
+ def train_process(
+     data_path: str,
+     option: str,
+     # llama config
+     llama_ckpt,
+     llama_base_config,
+     llama_lr,
+     llama_maxsteps,
+     llama_data_num_workers,
+     llama_data_batch_size,
+     llama_data_max_length,
+     llama_precision,
+     llama_check_interval,
+     llama_grad_batches,
+     llama_use_speaker,
+     llama_use_lora,
+ ):
+
+     backend = "nccl" if sys.platform == "linux" else "gloo"
+
+     new_project = generate_folder_name()
+     print("New Project Name: ", new_project)
+
+     if option == "VQGAN":
+         msg = "Skipped VQGAN Training."
+         gr.Warning(msg)
+         logger.info(msg)
+
+     if option == "LLAMA":
+         msg = "LLAMA Training begins..."
+         gr.Warning(msg)
+         logger.info(msg)
+         subprocess.run(
+             [
+                 PYTHON,
+                 "tools/vqgan/extract_vq.py",
+                 str(data_pre_output),
+                 "--num-workers",
+                 "1",
+                 "--batch-size",
+                 "16",
+                 "--config-name",
+                 "firefly_gan_vq",
+                 "--checkpoint-path",
+                 "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+             ]
+         )
+
+         subprocess.run(
+             [
+                 PYTHON,
+                 "tools/llama/build_dataset.py",
+                 "--input",
+                 str(data_pre_output),
+                 "--text-extension",
+                 ".lab",
+                 "--num-workers",
+                 "16",
+             ]
+         )
+         ckpt_path = "checkpoints/fish-speech-1.2-sft/model.pth"
+         lora_prefix = "lora_" if llama_use_lora else ""
+         llama_name = lora_prefix + "text2semantic_" + new_project
+         latest = next(
+             iter(
+                 sorted(
+                     [
+                         str(p.relative_to("results"))
+                         for p in Path("results").glob(lora_prefix + "text2sem*/")
+                     ],
+                     reverse=True,
+                 )
+             ),
+             llama_name,
+         )
+         project = (
+             llama_name
+             if llama_ckpt == i18n("new")
+             else (
+                 latest
+                 if llama_ckpt == i18n("latest")
+                 else Path(llama_ckpt).relative_to("results")
+             )
+         )
+         logger.info(project)
+
+         if llama_check_interval > llama_maxsteps:
+             llama_check_interval = llama_maxsteps
+
+         train_cmd = [
+             PYTHON,
+             "fish_speech/train.py",
+             "--config-name",
+             "text2semantic_finetune",
+             f"project={project}",
+             f"trainer.strategy.process_group_backend={backend}",
+             f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
+             f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
+             f"model.optimizer.lr={llama_lr}",
+             f"trainer.max_steps={llama_maxsteps}",
+             f"data.num_workers={llama_data_num_workers}",
+             f"data.batch_size={llama_data_batch_size}",
+             f"max_length={llama_data_max_length}",
+             f"trainer.precision={llama_precision}",
+             f"trainer.val_check_interval={llama_check_interval}",
+             f"trainer.accumulate_grad_batches={llama_grad_batches}",
+             f"train_dataset.interactive_prob={llama_use_speaker}",
+         ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
+         logger.info(train_cmd)
+         subprocess.run(train_cmd)
+
+     return build_html_ok_message(i18n("Training stopped"))
+
+
+ def tensorboard_process(
+     if_tensorboard: bool,
+     tensorboard_dir: str,
+     host: str,
+     port: str,
+ ):
+     global p_tensorboard
+     if if_tensorboard == True and p_tensorboard == None:
+         url = f"http://{host}:{port}"
+         yield build_html_ok_message(
+             i18n("Tensorboard interface is launched at {}").format(url)
+         )
+         prefix = ["tensorboard"]
+         if Path("fishenv").exists():
+             prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"]
+
+         p_tensorboard = subprocess.Popen(
+             prefix
+             + [
+                 "--logdir",
+                 tensorboard_dir,
+                 "--host",
+                 host,
+                 "--port",
+                 port,
+                 "--reload_interval",
+                 "120",
+             ]
+         )
+     elif if_tensorboard == False and p_tensorboard != None:
+         kill_process(p_tensorboard.pid)
+         p_tensorboard = None
+         yield build_html_error_message(i18n("Tensorboard interface is closed"))
+
+
+ def fresh_tb_dir():
+     return gr.Dropdown(
+         choices=[str(p) for p in Path("results").glob("**/tensorboard/")]
+     )
+
+
+ def list_decoder_models():
+     paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")]
+     if not paths:
+         logger.warning("No decoder model found")
+     return paths
+
+
+ def list_llama_models():
+     choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
+     choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
+     choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
+     choices = sorted(choices, reverse=True)
+     if not choices:
+         logger.warning("No LLaMA model found")
+     return choices
+
+
+ def list_lora_llama_models():
+     choices = sorted(
+         [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True
+     )
+     if not choices:
+         logger.warning("No LoRA LLaMA model found")
+     return choices
+
+
+ def fresh_decoder_model():
+     return gr.Dropdown(choices=list_decoder_models())
+
+
+ def fresh_llama_ckpt(llama_use_lora):
+     return gr.Dropdown(
+         choices=[i18n("latest"), i18n("new")]
+         + (
+             [str(p) for p in Path("results").glob("text2sem*/")]
+             if not llama_use_lora
+             else [str(p) for p in Path("results").glob("lora_*/")]
+         )
+     )
+
+
+ def fresh_llama_model():
+     return gr.Dropdown(choices=list_llama_models())
+
+
+ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
+     if (
+         lora_weight is None
+         or not Path(lora_weight).exists()
+         or not Path(llama_weight).exists()
+     ):
+         return build_html_error_message(
+             i18n(
+                 "Path error, please check the model file exists in the corresponding path"
+             )
+         )
+     gr.Warning("Merging begins...")
+     merge_cmd = [
+         PYTHON,
+         "tools/llama/merge_lora.py",
+         "--lora-config",
+         "r_8_alpha_16",
+         "--lora-weight",
+         lora_weight,
+         "--output",
+         llama_lora_output + "_" + generate_folder_name(),
+     ]
+     logger.info(merge_cmd)
+     subprocess.run(merge_cmd)
+     return build_html_ok_message(i18n("Merge successfully"))
+
+
+ def llama_quantify(llama_weight, quantify_mode):
+     if llama_weight is None or not Path(llama_weight).exists():
+         return build_html_error_message(
+             i18n(
+                 "Path error, please check the model file exists in the corresponding path"
+             )
+         )
+
+     gr.Warning("Quantifying begins...")
+
+     now = generate_folder_name()
+     quantify_cmd = [
+         PYTHON,
+         "tools/llama/quantize.py",
+         "--checkpoint-path",
+         llama_weight,
+         "--mode",
+         quantify_mode,
+         "--timestamp",
+         now,
+     ]
+     logger.info(quantify_cmd)
+     subprocess.run(quantify_cmd)
+     if quantify_mode == "int8":
+         quantize_path = str(
+             Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}"
+         )
+     else:
+         quantize_path = str(
+             Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}"
+         )
+     return build_html_ok_message(
+         i18n("Quantify successfully") + f"Path: {quantize_path}"
+     )
+
+
+ init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
+ init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
+
+ with gr.Blocks(
+     head="<style>\n" + css + "\n</style>",
+     js=js,
+     theme=seafoam,
+     analytics_enabled=False,
+     title="Fish Speech",
+ ) as demo:
+     with gr.Row():
+         with gr.Column():
+             with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")):
+                 with gr.Row():
+                     textbox = gr.Textbox(
+                         label="\U0000270F "
+                         + i18n("Input Audio & Source Path for Transcription"),
+                         info=i18n("Speaker is identified by the folder name"),
+                         interactive=True,
+                     )
+                 with gr.Row(equal_height=False):
+                     with gr.Column():
+                         output_radio = gr.Radio(
+                             label="\U0001F4C1 "
+                             + i18n("Select source file processing method"),
+                             choices=[i18n("Copy"), i18n("Move")],
+                             value=i18n("Copy"),
+                             interactive=True,
+                         )
+                     with gr.Column():
+                         error = gr.HTML(label=i18n("Error Message"))
+                         if_label = gr.Checkbox(
+                             label=i18n("Open Labeler WebUI"), scale=0, show_label=True
+                         )
+
+                 with gr.Row():
+                     label_device = gr.Dropdown(
+                         label=i18n("Labeling Device"),
+                         info=i18n(
+                             "It is recommended to use CUDA, if you have low configuration, use CPU"
+                         ),
+                         choices=["cpu", "cuda"],
+                         value="cuda",
+                         interactive=True,
+                     )
+                     label_model = gr.Dropdown(
+                         label=i18n("Whisper Model"),
+                         info=i18n("Faster Whisper, Up to 5g GPU memory usage"),
+                         choices=["large-v3", "medium"],
+                         value="large-v3",
+                         interactive=True,
+                     )
+                     label_radio = gr.Dropdown(
+                         label=i18n("Optional Label Language"),
+                         info=i18n(
+                             "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format"
+                         ),
+                         choices=[
+                             (i18n("Chinese"), "zh"),
+                             (i18n("English"), "en"),
+                             (i18n("Japanese"), "ja"),
+                             (i18n("Disabled"), "IGNORE"),
+                             (i18n("auto"), "auto"),
+                         ],
+                         value="IGNORE",
+                         interactive=True,
+                     )
+
+                 with gr.Row():
+                     if_initial_prompt = gr.Checkbox(
+                         value=False,
+                         label=i18n("Enable Initial Prompt"),
+                         min_width=120,
+                         scale=0,
+                     )
+                     initial_prompt = gr.Textbox(
+                         label=i18n("Initial Prompt"),
+                         info=i18n(
+                             "Initial prompt can provide contextual or vocabulary-specific guidance to the model."
+                         ),
+                         placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.",
+                         interactive=False,
+                     )
+
+                 with gr.Row():
+                     add_button = gr.Button(
+                         "\U000027A1 " + i18n("Add to Processing Area"),
+                         variant="primary",
+                     )
+                     remove_button = gr.Button(
+                         "\U000026D4 " + i18n("Remove Selected Data")
+                     )
+
+             with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")):
+                 with gr.Row():
+                     model_type_radio = gr.Radio(
+                         label=i18n(
+                             "Select the model to be trained (Depending on the Tab page you are on)"
+                         ),
+                         interactive=False,
+                         choices=["VQGAN", "LLAMA"],
+                         value="VQGAN",
+                     )
+                 with gr.Row():
+                     with gr.Tabs():
+                         with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
+                             gr.HTML("You don't need to train this model!")
+
+                         with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page:
+                             with gr.Row(equal_height=False):
+                                 llama_use_lora = gr.Checkbox(
+                                     label=i18n("Use LoRA"),
+                                     info=i18n(
+                                         "Use LoRA can save GPU memory, but may reduce the quality of the model"
+                                     ),
+                                     value=True,
+                                     interactive=True,
+                                 )
+                                 llama_ckpt = gr.Dropdown(
+                                     label=i18n("Select LLAMA ckpt"),
+                                     choices=[i18n("latest"), i18n("new")]
+                                     + [
+                                         str(p)
+                                         for p in Path("results").glob("text2sem*/")
+                                     ]
+                                     + [str(p) for p in Path("results").glob("lora*/")],
+                                     value=i18n("latest"),
+                                     interactive=True,
+                                 )
+                             with gr.Row(equal_height=False):
+                                 llama_lr_slider = gr.Slider(
+                                     label=i18n("Initial Learning Rate"),
+                                     info=i18n(
+                                         "lr smaller -> usually train slower but more stable"
+                                     ),
+                                     interactive=True,
+                                     minimum=1e-5,
+                                     maximum=1e-4,
+                                     step=1e-5,
+                                     value=5e-5,
+                                 )
+                                 llama_maxsteps_slider = gr.Slider(
+                                     label=i18n("Maximum Training Steps"),
+                                     info=i18n(
+                                         "recommend: max_steps = num_audios // batch_size * (2 to 5)"
+                                     ),
+                                     interactive=True,
+                                     minimum=1,
+                                     maximum=10000,
+                                     step=1,
+                                     value=50,
+                                 )
+                             with gr.Row(equal_height=False):
+                                 llama_base_config = gr.Dropdown(
+                                     label=i18n("Model Size"),
+                                     choices=[
+                                         "text2semantic_finetune",
+                                     ],
+                                     value="text2semantic_finetune",
+                                 )
+                                 llama_data_num_workers_slider = gr.Slider(
+                                     label=i18n("Number of Workers"),
+                                     minimum=1,
+                                     maximum=32,
+                                     step=1,
+                                     value=4,
+                                 )
+                             with gr.Row(equal_height=False):
+                                 llama_data_batch_size_slider = gr.Slider(
+                                     label=i18n("Batch Size"),
+                                     interactive=True,
+                                     minimum=1,
+                                     maximum=32,
+                                     step=1,
+                                     value=4,
+                                 )
+                                 llama_data_max_length_slider = gr.Slider(
+                                     label=i18n("Maximum Length per Sample"),
+                                     interactive=True,
+                                     minimum=1024,
+                                     maximum=4096,
+                                     step=128,
+                                     value=1024,
+                                 )
+                             with gr.Row(equal_height=False):
+                                 llama_precision_dropdown = gr.Dropdown(
+                                     label=i18n("Precision"),
+                                     info=i18n(
+                                         "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
+                                     ),
+                                     interactive=True,
+                                     choices=["32", "bf16-true", "16-mixed"],
+                                     value="bf16-true",
+                                 )
+                                 llama_check_interval_slider = gr.Slider(
+                                     label=i18n("Save model every n steps"),
+                                     info=i18n(
+                                         "make sure that it's not greater than max_steps"
+                                     ),
+                                     interactive=True,
+                                     minimum=1,
+                                     maximum=1000,
+                                     step=1,
+                                     value=50,
+                                 )
+                             with gr.Row(equal_height=False):
+                                 llama_grad_batches = gr.Slider(
+                                     label=i18n("Accumulate Gradient Batches"),
+                                     interactive=True,
+                                     minimum=1,
+                                     maximum=20,
+                                     step=1,
+                                     value=init_llama_yml["trainer"][
+                                         "accumulate_grad_batches"
+                                     ],
+                                 )
+                                 llama_use_speaker = gr.Slider(
+                                     label=i18n(
+                                         "Probability of applying Speaker Condition"
+                                     ),
+                                     interactive=True,
+                                     minimum=0.1,
+                                     maximum=1.0,
+                                     step=0.05,
+                                     value=init_llama_yml["train_dataset"][
+                                         "interactive_prob"
+                                     ],
+                                 )
+
+                         with gr.Tab(label=i18n("Merge LoRA"), id=4):
+                             with gr.Row(equal_height=False):
+                                 llama_weight = gr.Dropdown(
+                                     label=i18n("Base LLAMA Model"),
+                                     info=i18n(
+                                         "Type the path or select from the dropdown"
+                                     ),
+                                     choices=[
+                                         "checkpoints/fish-speech-1.2-sft/model.pth",
+                                     ],
+                                     value="checkpoints/fish-speech-1.2-sft/model.pth",
+                                     allow_custom_value=True,
+                                     interactive=True,
+                                 )
+                             with gr.Row(equal_height=False):
+                                 lora_weight = gr.Dropdown(
+                                     label=i18n("LoRA Model to be merged"),
+                                     info=i18n(
+                                         "Type the path or select from the dropdown"
+                                     ),
+                                     choices=[
+                                         str(p)
+                                         for p in Path("results").glob("lora*/**/*.ckpt")
+                                     ],
+                                     allow_custom_value=True,
+                                     interactive=True,
+                                 )
+                                 lora_llama_config = gr.Dropdown(
+                                     label=i18n("LLAMA Model Config"),
+                                     info=i18n(
+                                         "Type the path or select from the dropdown"
+                                     ),
+                                     choices=[
+                                         "text2semantic_finetune",
+                                     ],
+                                     value="text2semantic_finetune",
+                                     allow_custom_value=True,
+                                 )
+                             with gr.Row(equal_height=False):
+                                 llama_lora_output = gr.Dropdown(
+                                     label=i18n("Output Path"),
+                                     info=i18n(
+                                         "Type the path or select from the dropdown"
+                                     ),
+                                     value="checkpoints/merged",
+                                     choices=["checkpoints/merged"],
+                                     allow_custom_value=True,
+                                     interactive=True,
+                                 )
+                             with gr.Row(equal_height=False):
+                                 llama_lora_merge_btn = gr.Button(
+                                     value=i18n("Merge"), variant="primary"
+                                 )
+
+                         with gr.Tab(label=i18n("Model Quantization"), id=5):
+                             with gr.Row(equal_height=False):
+                                 llama_weight_to_quantify = gr.Dropdown(
+                                     label=i18n("Base LLAMA Model"),
+                                     info=i18n(
+                                         "Type the path or select from the dropdown"
+                                     ),
+                                     choices=list_llama_models(),
+                                     value="checkpoints/fish-speech-1.2-sft",
+                                     allow_custom_value=True,
+                                     interactive=True,
+                                 )
+                                 quantify_mode = gr.Dropdown(
+                                     label=i18n("Post-quantification Precision"),
+                                     info=i18n(
+                                         "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase"
+                                     ),
+                                     choices=["int8", "int4"],
+                                     value="int8",
+                                     allow_custom_value=False,
+                                     interactive=True,
+                                 )
+                             with gr.Row(equal_height=False):
+                                 llama_quantify_btn = gr.Button(
+                                     value=i18n("Quantify"), variant="primary"
+                                 )
+
+                         with gr.Tab(label="Tensorboard", id=6):
+                             with gr.Row(equal_height=False):
+                                 tb_host = gr.Textbox(
+                                     label=i18n("Tensorboard Host"), value="127.0.0.1"
+                                 )
+                                 tb_port = gr.Textbox(
+                                     label=i18n("Tensorboard Port"), value="11451"
+                                 )
+                             with gr.Row(equal_height=False):
+                                 tb_dir = gr.Dropdown(
+                                     label=i18n("Tensorboard Log Path"),
+                                     allow_custom_value=True,
+                                     choices=[
+                                         str(p)
+                                         for p in Path("results").glob("**/tensorboard/")
+                                     ],
+                                 )
+                             with gr.Row(equal_height=False):
+                                 if_tb = gr.Checkbox(
+                                     label=i18n("Open Tensorboard"),
+                                 )
+
+             with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")):
+                 with gr.Column():
+                     with gr.Row():
+                         with gr.Accordion(
+                             label="\U0001F5A5 "
+                             + i18n("Inference Server Configuration"),
+                             open=False,
+                         ):
+                             with gr.Row():
+                                 infer_host_textbox = gr.Textbox(
+                                     label=i18n("WebUI Host"), value="127.0.0.1"
+                                 )
+                                 infer_port_textbox = gr.Textbox(
+                                     label=i18n("WebUI Port"), value="7862"
+                                 )
+                             with gr.Row():
+                                 infer_decoder_model = gr.Dropdown(
+                                     label=i18n("Decoder Model Path"),
+                                     info=i18n(
+                                         "Type the path or select from the dropdown"
+                                     ),
+                                     choices=list_decoder_models(),
+                                     value="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+                                     allow_custom_value=True,
+                                 )
+                                 infer_decoder_config = gr.Dropdown(
+                                     label=i18n("Decoder Model Config"),
+                                     info=i18n("Changing with the Model Path"),
+                                     value="firefly_gan_vq",
+                                     choices=[
+                                         "firefly_gan_vq",
+                                     ],
+                                     allow_custom_value=True,
+                                 )
+                             with gr.Row():
+                                 infer_llama_model = gr.Dropdown(
+                                     label=i18n("LLAMA Model Path"),
+                                     info=i18n(
+                                         "Type the path or select from the dropdown"
+                                     ),
+                                     value="checkpoints/fish-speech-1.2-sft",
+                                     choices=list_llama_models(),
+                                     allow_custom_value=True,
+                                 )
+
+                             with gr.Row():
+                                 infer_compile = gr.Radio(
+                                     label=i18n("Compile Model"),
+                                     info=i18n(
+                                         "Compile the model can significantly reduce the inference time, but will increase cold start time"
+                                     ),
+                                     choices=["Yes", "No"],
+                                     value=(
+                                         "Yes" if (sys.platform == "linux") else "No"
+                                     ),
+                                     interactive=is_module_installed("triton"),
+                                 )
+
+                     with gr.Row():
+                         infer_checkbox = gr.Checkbox(
+                             label=i18n("Open Inference Server")
+                         )
+                         infer_error = gr.HTML(label=i18n("Inference Server Error"))
+
+         with gr.Column():
+             train_error = gr.HTML(label=i18n("Training Error"))
+             checkbox_group = gr.CheckboxGroup(
+                 label="\U0001F4CA " + i18n("Data Source"),
+                 info=i18n(
+                     "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list."
+                 ),
+                 elem_classes=["data_src"],
+             )
+             train_box = gr.Textbox(
+                 label=i18n("Data Preprocessing Path"),
+                 value=str(data_pre_output),
+                 interactive=False,
+             )
+             model_box = gr.Textbox(
+                 label="\U0001F4BE " + i18n("Model Output Path"),
+                 value=str(default_model_output),
+                 interactive=False,
+             )
+
+             with gr.Accordion(
+                 i18n(
+                     "View the status of the preprocessing folder (use the slider to control the depth of the tree)"
+                 ),
+                 elem_classes=["scrollable-component"],
+                 elem_id="file_accordion",
+             ):
+                 tree_slider = gr.Slider(
+                     minimum=0,
+                     maximum=3,
+                     value=0,
+                     step=1,
+                     show_label=False,
+                     container=False,
+                 )
+                 file_markdown = new_explorer(str(data_pre_output), 0)
+             with gr.Row(equal_height=False):
+                 admit_btn = gr.Button(
+                     "\U00002705 " + i18n("File Preprocessing"),
+                     variant="primary",
+                 )
+                 fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
+                 help_button = gr.Button("\U00002753", scale=0, min_width=80)  # question
+                 train_btn = gr.Button(i18n("Start Training"), variant="primary")
+
+     footer = load_data_in_raw("fish_speech/webui/html/footer.html")
+     footer = footer.format(
+         versions=versions_html(),
+         api_docs="https://speech.fish.audio/inference/#http-api",
+     )
+     gr.HTML(footer, elem_id="footer")
+     vqgan_page.select(lambda: "VQGAN", None, model_type_radio)
+     llama_page.select(lambda: "LLAMA", None, model_type_radio)
+     add_button.click(
+         fn=add_item,
+         inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt],
+         outputs=[checkbox_group, error],
+     )
+     remove_button.click(
+         fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
+     )
+     checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
+     help_button.click(
+         fn=None,
+         js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
+         'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
+     )
+     if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
+     if_initial_prompt.change(
+         fn=lambda x: gr.Textbox(value="", interactive=x),
+         inputs=[if_initial_prompt],
+         outputs=[initial_prompt],
+     )
+     train_btn.click(
+         fn=train_process,
+         inputs=[
+             train_box,
+             model_type_radio,
+             # llama config
+             llama_ckpt,
+             llama_base_config,
+             llama_lr_slider,
+             llama_maxsteps_slider,
+             llama_data_num_workers_slider,
+             llama_data_batch_size_slider,
+             llama_data_max_length_slider,
+             llama_precision_dropdown,
+             llama_check_interval_slider,
+             llama_grad_batches,
+             llama_use_speaker,
+             llama_use_lora,
+         ],
+         outputs=[train_error],
+     )
+     if_tb.change(
+         fn=tensorboard_process,
+         inputs=[if_tb, tb_dir, tb_host, tb_port],
+         outputs=[train_error],
+     )
+     tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
+     infer_decoder_model.change(
+         fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
+     )
+     infer_llama_model.change(
+         fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model]
+     )
+     llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight])
+     admit_btn.click(
+         fn=check_files,
+         inputs=[train_box, tree_slider, label_model, label_device],
+         outputs=[error, file_markdown],
+     )
+     fresh_btn.click(
+         fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
+     )
+     llama_use_lora.change(
+         fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
+     )
+     llama_ckpt.change(
+         fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
+     )
+     lora_weight.change(
+         fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
+         inputs=[],
+         outputs=[lora_weight],
+     )
+     llama_lora_merge_btn.click(
+         fn=llama_lora_merge,
+         inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
+         outputs=[train_error],
+     )
+     llama_quantify_btn.click(
+         fn=llama_quantify,
+         inputs=[llama_weight_to_quantify, quantify_mode],
+         outputs=[train_error],
+     )
+     infer_checkbox.change(
+         fn=change_infer,
+         inputs=[
+             infer_checkbox,
+             infer_host_textbox,
+             infer_port_textbox,
+             infer_decoder_model,
+             infer_decoder_config,
+             infer_llama_model,
+             infer_compile,
+         ],
+         outputs=[infer_error],
+     )
+
+ demo.launch(inbrowser=True)