xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of xinference has been flagged as a potentially problematic release.

Files changed (194)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +15 -34
  3. xinference/client/restful/restful_client.py +2 -2
  4. xinference/core/chat_interface.py +45 -10
  5. xinference/core/image_interface.py +9 -0
  6. xinference/core/model.py +8 -5
  7. xinference/core/scheduler.py +1 -2
  8. xinference/core/worker.py +49 -42
  9. xinference/deploy/cmdline.py +2 -2
  10. xinference/deploy/test/test_cmdline.py +7 -7
  11. xinference/model/audio/chattts.py +24 -9
  12. xinference/model/audio/core.py +8 -2
  13. xinference/model/audio/fish_speech.py +228 -0
  14. xinference/model/audio/model_spec.json +8 -0
  15. xinference/model/embedding/core.py +23 -1
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +49 -1
  19. xinference/model/llm/__init__.py +26 -27
  20. xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
  21. xinference/model/llm/llm_family.json +606 -1266
  22. xinference/model/llm/llm_family.py +16 -139
  23. xinference/model/llm/llm_family_modelscope.json +276 -313
  24. xinference/model/llm/lmdeploy/__init__.py +0 -0
  25. xinference/model/llm/lmdeploy/core.py +557 -0
  26. xinference/model/llm/memory.py +9 -9
  27. xinference/model/llm/sglang/core.py +2 -2
  28. xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
  29. xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
  30. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  31. xinference/model/llm/{pytorch → transformers}/core.py +3 -10
  32. xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
  33. xinference/model/llm/transformers/intern_vl.py +540 -0
  34. xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
  35. xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
  36. xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
  37. xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
  38. xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
  39. xinference/model/llm/utils.py +85 -70
  40. xinference/model/llm/vllm/core.py +110 -11
  41. xinference/model/utils.py +1 -95
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/internvl/__init__.py +0 -0
  137. xinference/thirdparty/internvl/conversation.py +393 -0
  138. xinference/thirdparty/omnilmm/model/utils.py +16 -1
  139. xinference/web/ui/build/asset-manifest.json +3 -3
  140. xinference/web/ui/build/index.html +1 -1
  141. xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
  142. xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  144. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
  145. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
  146. xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
  147. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
  148. xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
  149. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
  150. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
  151. xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
  153. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
  154. xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
  155. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
  156. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
  157. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
  158. xinference/locale/utils.py +0 -39
  159. xinference/locale/zh_CN.json +0 -26
  160. xinference/model/llm/ggml/tools/__init__.py +0 -15
  161. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
  162. xinference/model/llm/ggml/tools/gguf.py +0 -884
  163. xinference/model/llm/pytorch/__init__.py +0 -13
  164. xinference/model/llm/pytorch/baichuan.py +0 -81
  165. xinference/model/llm/pytorch/falcon.py +0 -138
  166. xinference/model/llm/pytorch/intern_vl.py +0 -352
  167. xinference/model/llm/pytorch/vicuna.py +0 -69
  168. xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
  169. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
  170. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  171. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
  172. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
  173. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
  174. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
  175. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
  176. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
  177. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
  178. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
  179. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
  180. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
  181. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
  182. /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
  183. /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
  184. /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
  185. /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
  186. /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
  187. /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
  188. /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
  189. /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
  190. /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  191. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
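
Taken together, the list shows this release's main structural changes: the "pytorch" LLM backend package is renamed to "transformers", "ggml" becomes "llama_cpp", a new "lmdeploy" backend is added, the legacy baichuan/falcon/vicuna wrappers are removed, and a vendored copy of fish_speech lands under xinference/thirdparty/ to back the new fish_speech audio model. As a hedged sketch of how that model might be driven through xinference's RESTful client (the model name below is an assumption for illustration, not confirmed by this diff; the real identifier is defined in the model_spec.json entries added above):

    from xinference.client import Client

    client = Client("http://localhost:9997")  # assumes a locally running xinference endpoint

    # "FishSpeech-1.2-SFT" is a hypothetical model name used for illustration;
    # check xinference/model/audio/model_spec.json in this release for the real one.
    model_uid = client.launch_model(model_name="FishSpeech-1.2-SFT", model_type="audio")
    model = client.get_model(model_uid)

    # Audio handles expose speech() for text-to-speech, returning raw audio bytes.
    audio_bytes = model.speech("Hello from the vendored fish_speech backend.")
    with open("out.wav", "wb") as f:
        f.write(audio_bytes)

The two largest new fish_speech tool scripts are reproduced in full below.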
xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py (new file)
@@ -0,0 +1,227 @@
+ import os
+ import subprocess as sp
+ import sys
+ import time
+ from datetime import timedelta
+ from functools import lru_cache
+ from pathlib import Path
+ from random import Random
+
+ import click
+ import numpy as np
+ import torch
+ import torchaudio
+ from hydra import compose, initialize
+ from hydra.utils import instantiate
+ from lightning import LightningModule
+ from loguru import logger
+ from omegaconf import OmegaConf
+
+ from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
+ # register eval resolver
+ OmegaConf.register_new_resolver("eval", eval)
+ # This file is used to convert the audio files to text files using the Whisper model.
+ # It's mainly used to generate the training data for the VQ model.
+
+
+ RANK = int(os.environ.get("SLURM_PROCID", 0))
+ WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
+
+ logger_format = (
+     "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
+     "<level>{level: <8}</level> | "
+     "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
+     "{extra[rank]} - <level>{message}</level>"
+ )
+ logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
+ logger.remove()
+ logger.add(sys.stderr, format=logger_format)
+
+
+ @lru_cache(maxsize=1)
+ def get_model(
+     config_name: str = "firefly_gan_vq",
+     checkpoint_path: str = "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+     device: str | torch.device = "cuda",
+ ):
+     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+         cfg = compose(config_name=config_name)
+
+     model = instantiate(cfg)
+     state_dict = torch.load(
+         checkpoint_path,
+         map_location=device,
+     )
+     if "state_dict" in state_dict:
+         state_dict = state_dict["state_dict"]
+
+     if any("generator" in k for k in state_dict):
+         state_dict = {
+             k.replace("generator.", ""): v
+             for k, v in state_dict.items()
+             if "generator." in k
+         }
+
+     model.load_state_dict(state_dict, strict=False)
+     model.eval()
+     model.to(device)
+
+     logger.info(f"Loaded model")
+     return model
+
+
+ @torch.inference_mode()
+ def process_batch(files: list[Path], model) -> float:
+     wavs = []
+     audio_lengths = []
+     new_files = []
+     max_length = total_time = 0
+
+     for file in files:
+         try:
+             wav, sr = torchaudio.load(
+                 str(file), backend="sox" if sys.platform == "linux" else "soundfile"
+             )  # Need to install libsox-dev
+         except Exception as e:
+             logger.error(f"Error reading {file}: {e}")
+             continue
+
+         if wav.shape[0] > 1:
+             wav = wav.mean(dim=0, keepdim=True)
+
+         wav = torchaudio.functional.resample(
+             wav.cuda(), sr, model.spec_transform.sample_rate
+         )[0]
+         total_time += len(wav) / model.spec_transform.sample_rate
+         max_length = max(max_length, len(wav))
+
+         wavs.append(wav)
+         audio_lengths.append(len(wav))
+         new_files.append(file)
+
+     files = new_files
+
+     # Pad to max length
+     for i, wav in enumerate(wavs):
+         wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
+
+     audios = torch.stack(wavs, dim=0)[:, None]
+     audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
+
+     # Calculate lengths
+     indices, feature_lengths = model.encode(audios, audio_lengths)
+
+     # Save to disk
+     outputs = indices.cpu().numpy()
+
+     for file, length, feature, audio_length in zip(
+         files, feature_lengths, outputs, audio_lengths
+     ):
+         feature = feature[:, :length]
+
+         # (T,)
+         with open(file.with_suffix(".npy"), "wb") as f:
+             np.save(f, feature)
+
+     return total_time
+
+
+ @click.command()
+ @click.argument("folder")
+ @click.option("--num-workers", default=1)
+ @click.option("--config-name", default="firefly_gan_vq")
+ @click.option(
+     "--checkpoint-path",
+     default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+ )
+ @click.option("--batch-size", default=64)
+ @click.option("--filelist", default=None, type=Path)
+ def main(
+     folder: str,
+     num_workers: int,
+     config_name: str,
+     checkpoint_path: str,
+     batch_size: int,
+     filelist: Path,
+ ):
+     if num_workers > 1 and WORLD_SIZE != num_workers:
+         assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
+
+         logger.info(f"Spawning {num_workers} workers")
+
+         if torch.cuda.is_available():
+             visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+             if visible_devices is None:
+                 visible_devices = list(range(torch.cuda.device_count()))
+             else:
+                 visible_devices = visible_devices.split(",")
+         else:
+             # Set to empty string to avoid using GPU
+             visible_devices = [""]
+
+         processes = []
+         for i in range(num_workers):
+             env = os.environ.copy()
+             env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
+             env["SLURM_PROCID"] = str(i)
+             env["SLURM_NTASKS"] = str(num_workers)
+
+             processes.append(
+                 sp.Popen(
+                     [sys.executable] + sys.argv.copy(),
+                     env=env,
+                 )
+             )
+
+         for p in processes:
+             p.wait()
+
+         logger.info(f"All workers finished")
+         return
+
+     # This is a worker
+     logger.info(f"Starting worker")
+     if filelist:
+         files = [i[0] for i in load_filelist(filelist)]
+     else:
+         files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
+
+     print(f"Found {len(files)} files")
+     files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
+
+     total_files = len(files)
+     files = files[RANK::WORLD_SIZE]
+     logger.info(f"Processing {len(files)}/{total_files} files")
+
+     # Batch processing
+     total_time = 0
+     begin_time = time.time()
+     processed_files = 0
+     model = get_model(config_name, checkpoint_path)
+
+     for n_batch, idx in enumerate(range(0, len(files), batch_size)):
+         batch = files[idx : idx + batch_size]
+         batch_time = process_batch(batch, model)
+
+         total_time += batch_time
+         processed_files += len(batch)
+
+         if (n_batch + 1) % 10 == 0:
+             eta = (
+                 (time.time() - begin_time)
+                 / processed_files
+                 * (len(files) - processed_files)
+             )
+             logger.info(
+                 f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
+                 + f"ETA: {timedelta(seconds=round(eta))}s"
+             )
+
+     logger.info(
+         f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
+     )
+
+
+ if __name__ == "__main__":
+     main()
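
A detail worth pulling out of the hunk above is how it distributes work: each worker (whether spawned by this script or by SLURM) takes the strided slice files[RANK::WORLD_SIZE], so the shards are disjoint and roughly balanced without any coordinator process. A minimal standalone sketch of that slicing, with made-up file names:

    # Round-robin sharding as in extract_vq.py: worker `rank` handles every
    # WORLD_SIZE-th file starting at its own offset, so shards never overlap.
    files = [f"clip_{i:02d}.wav" for i in range(10)]  # hypothetical inputs
    WORLD_SIZE = 3

    for rank in range(WORLD_SIZE):
        print(rank, files[rank::WORLD_SIZE])
    # 0 ['clip_00.wav', 'clip_03.wav', 'clip_06.wav', 'clip_09.wav']
    # 1 ['clip_01.wav', 'clip_04.wav', 'clip_07.wav']
    # 2 ['clip_02.wav', 'clip_05.wav', 'clip_08.wav']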
xinference/thirdparty/fish_speech/tools/vqgan/inference.py (new file)
@@ -0,0 +1,120 @@
+ from pathlib import Path
+
+ import click
+ import hydra
+ import numpy as np
+ import soundfile as sf
+ import torch
+ import torchaudio
+ from hydra import compose, initialize
+ from hydra.utils import instantiate
+ from loguru import logger
+ from omegaconf import OmegaConf
+
+ from tools.file import AUDIO_EXTENSIONS
+
+ # register eval resolver
+ OmegaConf.register_new_resolver("eval", eval)
+
+
+ def load_model(config_name, checkpoint_path, device="cuda"):
+     hydra.core.global_hydra.GlobalHydra.instance().clear()
+     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+         cfg = compose(config_name=config_name)
+
+     model = instantiate(cfg)
+     state_dict = torch.load(
+         checkpoint_path,
+         map_location=device,
+     )
+     if "state_dict" in state_dict:
+         state_dict = state_dict["state_dict"]
+
+     if any("generator" in k for k in state_dict):
+         state_dict = {
+             k.replace("generator.", ""): v
+             for k, v in state_dict.items()
+             if "generator." in k
+         }
+
+     result = model.load_state_dict(state_dict, strict=False)
+     model.eval()
+     model.to(device)
+
+     logger.info(f"Loaded model: {result}")
+     return model
+
+
+ @torch.no_grad()
+ @click.command()
+ @click.option(
+     "--input-path",
+     "-i",
+     default="test.wav",
+     type=click.Path(exists=True, path_type=Path),
+ )
+ @click.option(
+     "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
+ )
+ @click.option("--config-name", default="firefly_gan_vq")
+ @click.option(
+     "--checkpoint-path",
+     default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+ )
+ @click.option(
+     "--device",
+     "-d",
+     default="cuda",
+ )
+ def main(input_path, output_path, config_name, checkpoint_path, device):
+     model = load_model(config_name, checkpoint_path, device=device)
+
+     if input_path.suffix in AUDIO_EXTENSIONS:
+         logger.info(f"Processing in-place reconstruction of {input_path}")
+
+         # Load audio
+         audio, sr = torchaudio.load(str(input_path))
+         if audio.shape[0] > 1:
+             audio = audio.mean(0, keepdim=True)
+         audio = torchaudio.functional.resample(
+             audio, sr, model.spec_transform.sample_rate
+         )
+
+         audios = audio[None].to(device)
+         logger.info(
+             f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
+         )
+
+         # VQ Encoder
+         audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
+         indices = model.encode(audios, audio_lengths)[0][0]
+
+         logger.info(f"Generated indices of shape {indices.shape}")
+
+         # Save indices
+         np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
+     elif input_path.suffix == ".npy":
+         logger.info(f"Processing precomputed indices from {input_path}")
+         indices = np.load(input_path)
+         indices = torch.from_numpy(indices).to(device).long()
+         assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
+     else:
+         raise ValueError(f"Unknown input type: {input_path}")
+
+     # Restore
+     feature_lengths = torch.tensor([indices.shape[1]], device=device)
+     fake_audios = model.decode(indices=indices[None], feature_lengths=feature_lengths)
+     audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
+
+     logger.info(
+         f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
+     )
+
+     # Save audio
+     fake_audio = fake_audios[0, 0].float().cpu().numpy()
+     sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
+     logger.info(f"Saved audio to {output_path}")
+
+
+ if __name__ == "__main__":
+     main()
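
The two scripts share a simple on-disk contract: VQ indices are written as a 2-D integer array (codebooks x frames) in a .npy file, which the ".npy" branch of inference.py reloads and decodes. A hedged sketch of that round trip with stand-in data (the 4x1024 shape is inferred from the "fsq-4x1024" checkpoint name; the decode call is left commented because it needs the real model):

    import numpy as np
    import torch

    # Stand-in for what the encode branch writes: 4 codebooks x 100 frames.
    indices = torch.randint(0, 1024, (4, 100), dtype=torch.long)
    np.save("fake.npy", indices.numpy())

    # What the ".npy" branch reloads and validates before decoding.
    loaded = torch.from_numpy(np.load("fake.npy")).long()
    assert loaded.ndim == 2, "inference.py expects 2-D (codebooks, frames) indices"

    feature_lengths = torch.tensor([loaded.shape[1]])
    # fake_audios = model.decode(indices=loaded[None], feature_lengths=feature_lengths)
    # ...where `model` comes from load_model(...) in the hunk above.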