xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (194)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +15 -34
  3. xinference/client/restful/restful_client.py +2 -2
  4. xinference/core/chat_interface.py +45 -10
  5. xinference/core/image_interface.py +9 -0
  6. xinference/core/model.py +8 -5
  7. xinference/core/scheduler.py +1 -2
  8. xinference/core/worker.py +49 -42
  9. xinference/deploy/cmdline.py +2 -2
  10. xinference/deploy/test/test_cmdline.py +7 -7
  11. xinference/model/audio/chattts.py +24 -9
  12. xinference/model/audio/core.py +8 -2
  13. xinference/model/audio/fish_speech.py +228 -0
  14. xinference/model/audio/model_spec.json +8 -0
  15. xinference/model/embedding/core.py +23 -1
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +49 -1
  19. xinference/model/llm/__init__.py +26 -27
  20. xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
  21. xinference/model/llm/llm_family.json +606 -1266
  22. xinference/model/llm/llm_family.py +16 -139
  23. xinference/model/llm/llm_family_modelscope.json +276 -313
  24. xinference/model/llm/lmdeploy/__init__.py +0 -0
  25. xinference/model/llm/lmdeploy/core.py +557 -0
  26. xinference/model/llm/memory.py +9 -9
  27. xinference/model/llm/sglang/core.py +2 -2
  28. xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
  29. xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
  30. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  31. xinference/model/llm/{pytorch → transformers}/core.py +3 -10
  32. xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
  33. xinference/model/llm/transformers/intern_vl.py +540 -0
  34. xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
  35. xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
  36. xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
  37. xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
  38. xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
  39. xinference/model/llm/utils.py +85 -70
  40. xinference/model/llm/vllm/core.py +110 -11
  41. xinference/model/utils.py +1 -95
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/internvl/__init__.py +0 -0
  137. xinference/thirdparty/internvl/conversation.py +393 -0
  138. xinference/thirdparty/omnilmm/model/utils.py +16 -1
  139. xinference/web/ui/build/asset-manifest.json +3 -3
  140. xinference/web/ui/build/index.html +1 -1
  141. xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
  142. xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  144. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
  145. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
  146. xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
  147. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
  148. xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
  149. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
  150. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
  151. xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
  153. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
  154. xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
  155. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
  156. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
  157. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
  158. xinference/locale/utils.py +0 -39
  159. xinference/locale/zh_CN.json +0 -26
  160. xinference/model/llm/ggml/tools/__init__.py +0 -15
  161. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
  162. xinference/model/llm/ggml/tools/gguf.py +0 -884
  163. xinference/model/llm/pytorch/__init__.py +0 -13
  164. xinference/model/llm/pytorch/baichuan.py +0 -81
  165. xinference/model/llm/pytorch/falcon.py +0 -138
  166. xinference/model/llm/pytorch/intern_vl.py +0 -352
  167. xinference/model/llm/pytorch/vicuna.py +0 -69
  168. xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
  169. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
  170. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  171. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
  172. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
  173. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
  174. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
  175. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
  176. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
  177. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
  178. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
  179. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
  180. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
  181. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
  182. /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
  183. /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
  184. /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
  185. /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
  186. /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
  187. /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
  188. /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
  189. /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
  190. /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  191. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/tools/api.py
@@ -0,0 +1,495 @@
+ import base64
+ import io
+ import json
+ import queue
+ import random
+ import sys
+ import traceback
+ import wave
+ from argparse import ArgumentParser
+ from http import HTTPStatus
+ from pathlib import Path
+ from typing import Annotated, Literal, Optional
+
+ import numpy as np
+ # import pyrootutils
+ import soundfile as sf
+ import torch
+ import torchaudio
+ # from kui.asgi import (
+ #     Body,
+ #     HTTPException,
+ #     HttpView,
+ #     JSONResponse,
+ #     Kui,
+ #     OpenAPI,
+ #     StreamResponse,
+ # )
+ # from kui.asgi.routing import MultimethodRoutes
+ from loguru import logger
+ from pydantic import BaseModel, Field
+
+ # pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+ # from fish_speech.models.vqgan.lit_module import VQGAN
+ from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+ from fish_speech.utils import autocast_exclude_mps
+ from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
+ from tools.llama.generate import (
+     GenerateRequest,
+     GenerateResponse,
+     WrappedGenerateResponse,
+     launch_thread_safe_queue,
+ )
+ from tools.vqgan.inference import load_model as load_decoder_model
+
+
+ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+     buffer = io.BytesIO()
+
+     with wave.open(buffer, "wb") as wav_file:
+         wav_file.setnchannels(channels)
+         wav_file.setsampwidth(bit_depth // 8)
+         wav_file.setframerate(sample_rate)
+
+     wav_header_bytes = buffer.getvalue()
+     buffer.close()
+     return wav_header_bytes
+
+
+ # Define utils for web server
+ # async def http_execption_handler(exc: HTTPException):
+ #     return JSONResponse(
+ #         dict(
+ #             statusCode=exc.status_code,
+ #             message=exc.content,
+ #             error=HTTPStatus(exc.status_code).phrase,
+ #         ),
+ #         exc.status_code,
+ #         exc.headers,
+ #     )
+
+
+ async def other_exception_handler(exc: "Exception"):
+     traceback.print_exc()
+
+     status = HTTPStatus.INTERNAL_SERVER_ERROR
+     return JSONResponse(
+         dict(statusCode=status, message=str(exc), error=status.phrase),
+         status,
+     )
+
+
+ def load_audio(reference_audio, sr):
+     if len(reference_audio) > 255 or not Path(reference_audio).exists():
+         try:
+             audio_data = base64.b64decode(reference_audio)
+             reference_audio = io.BytesIO(audio_data)
+         except base64.binascii.Error:
+             raise ValueError("Invalid path or base64 string")
+
+     waveform, original_sr = torchaudio.load(
+         reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
+     )
+
+     if waveform.shape[0] > 1:
+         waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+     if original_sr != sr:
+         resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
+         waveform = resampler(waveform)
+
+     audio = waveform.squeeze().numpy()
+     return audio
+
+
+ def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
+     if enable_reference_audio and reference_audio is not None:
+         # Load audios, and prepare basic info here
+         reference_audio_content = load_audio(
+             reference_audio, decoder_model.spec_transform.sample_rate
+         )
+
+         audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
+             None, None, :
+         ]
+         audio_lengths = torch.tensor(
+             [audios.shape[2]], device=decoder_model.device, dtype=torch.long
+         )
+         logger.info(
+             f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
+         )
+
+         # VQ Encoder
+         if isinstance(decoder_model, FireflyArchitecture):
+             prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
+
+         logger.info(f"Encoded prompt: {prompt_tokens.shape}")
+     else:
+         prompt_tokens = None
+         logger.info("No reference audio provided")
+
+     return prompt_tokens
+
+
+ def decode_vq_tokens(
+     *,
+     decoder_model,
+     codes,
+ ):
+     feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
+     logger.info(f"VQ features: {codes.shape}")
+
+     if isinstance(decoder_model, FireflyArchitecture):
+         # VQGAN Inference
+         return decoder_model.decode(
+             indices=codes[None],
+             feature_lengths=feature_lengths,
+         ).squeeze()
+
+     raise ValueError(f"Unknown model type: {type(decoder_model)}")
+
+
+ # routes = MultimethodRoutes(base_class=HttpView)
+
+
+ def get_random_paths(base_path, data, speaker, emotion):
+     if base_path and data and speaker and emotion and (Path(base_path).exists()):
+         if speaker in data and emotion in data[speaker]:
+             files = data[speaker][emotion]
+             lab_files = [f for f in files if f.endswith(".lab")]
+             wav_files = [f for f in files if f.endswith(".wav")]
+
+             if lab_files and wav_files:
+                 selected_lab = random.choice(lab_files)
+                 selected_wav = random.choice(wav_files)
+
+                 lab_path = Path(base_path) / speaker / emotion / selected_lab
+                 wav_path = Path(base_path) / speaker / emotion / selected_wav
+                 if lab_path.exists() and wav_path.exists():
+                     return lab_path, wav_path
+
+     return None, None
+
+
+ def load_json(json_file):
+     if not json_file:
+         logger.info("Not using a json file")
+         return None
+     try:
+         with open(json_file, "r", encoding="utf-8") as file:
+             data = json.load(file)
+     except FileNotFoundError:
+         logger.warning(f"ref json not found: {json_file}")
+         data = None
+     except Exception as e:
+         logger.warning(f"Loading json failed: {e}")
+         data = None
+     return data
+
+
+ class InvokeRequest(BaseModel):
+     text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
+     reference_text: Optional[str] = None
+     reference_audio: Optional[str] = None
+     max_new_tokens: int = 1024
+     chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
+     top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+     repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
+     temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+     emotion: Optional[str] = None
+     format: Literal["wav", "mp3", "flac"] = "wav"
+     streaming: bool = False
+     ref_json: Optional[str] = "ref_data.json"
+     ref_base: Optional[str] = "ref_data"
+     speaker: Optional[str] = None
+
+
+ def get_content_type(audio_format):
+     if audio_format == "wav":
+         return "audio/wav"
+     elif audio_format == "flac":
+         return "audio/flac"
+     elif audio_format == "mp3":
+         return "audio/mpeg"
+     else:
+         return "application/octet-stream"
+
+
+ @torch.inference_mode()
+ def inference(req: InvokeRequest):
+     # Parse reference audio aka prompt
+     prompt_tokens = None
+
+     ref_data = load_json(req.ref_json)
+     ref_base = req.ref_base
+
+     lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)
+
+     if lab_path and wav_path:
+         with open(lab_path, "r", encoding="utf-8") as lab_file:
+             ref_text = lab_file.read()
+         req.reference_audio = wav_path
+         req.reference_text = ref_text
+         logger.info("ref_path: " + str(wav_path))
+         logger.info("ref_text: " + ref_text)
+
+     # Parse reference audio aka prompt
+     prompt_tokens = encode_reference(
+         decoder_model=decoder_model,
+         reference_audio=req.reference_audio,
+         enable_reference_audio=req.reference_audio is not None,
+     )
+     logger.info(f"ref_text: {req.reference_text}")
+     # LLAMA Inference
+     request = dict(
+         device=decoder_model.device,
+         max_new_tokens=req.max_new_tokens,
+         text=req.text,
+         top_p=req.top_p,
+         repetition_penalty=req.repetition_penalty,
+         temperature=req.temperature,
+         compile=args.compile,
+         iterative_prompt=req.chunk_length > 0,
+         chunk_length=req.chunk_length,
+         max_length=2048,
+         prompt_tokens=prompt_tokens,
+         prompt_text=req.reference_text,
+     )
+
+     response_queue = queue.Queue()
+     llama_queue.put(
+         GenerateRequest(
+             request=request,
+             response_queue=response_queue,
+         )
+     )
+
+     if req.streaming:
+         yield wav_chunk_header()
+
+     segments = []
+     while True:
+         result: WrappedGenerateResponse = response_queue.get()
+         if result.status == "error":
+             raise result.response
+             break
+
+         result: GenerateResponse = result.response
+         if result.action == "next":
+             break
+
+         with autocast_exclude_mps(
+             device_type=decoder_model.device.type, dtype=args.precision
+         ):
+             fake_audios = decode_vq_tokens(
+                 decoder_model=decoder_model,
+                 codes=result.codes,
+             )
+
+         fake_audios = fake_audios.float().cpu().numpy()
+
+         if req.streaming:
+             yield (fake_audios * 32768).astype(np.int16).tobytes()
+         else:
+             segments.append(fake_audios)
+
+     if req.streaming:
+         return
+
+     if len(segments) == 0:
+         raise HTTPException(
+             HTTPStatus.INTERNAL_SERVER_ERROR,
+             content="No audio generated, please check the input text.",
+         )
+
+     fake_audios = np.concatenate(segments, axis=0)
+     yield fake_audios
+
+
+ def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
+     if not use_auto_rerank:
+         # If auto_rerank is not used, call the original inference function directly
+         return inference(req)
+
+     zh_model, en_model = load_model()
+     max_attempts = 5
+     best_wer = float("inf")
+     best_audio = None
+
+     for attempt in range(max_attempts):
+         # Call the original inference function
+         audio_generator = inference(req)
+         fake_audios = next(audio_generator)
+
+         asr_result = batch_asr(
+             zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
+         )[0]
+         wer = calculate_wer(req.text, asr_result["text"])
+
+         if wer <= 0.1 and not asr_result["huge_gap"]:
+             return fake_audios
+
+         if wer < best_wer:
+             best_wer = wer
+             best_audio = fake_audios
+
+         if attempt == max_attempts - 1:
+             break
+
+     return best_audio
+
+
+ async def inference_async(req: InvokeRequest):
+     for chunk in inference(req):
+         yield chunk
+
+
+ async def buffer_to_async_generator(buffer):
+     yield buffer
+
+
+ # @routes.http.post("/v1/invoke")
+ # async def api_invoke_model(
+ #     req: Annotated[InvokeRequest, Body(exclusive=True)],
+ # ):
+ #     """
+ #     Invoke model and generate audio
+ #     """
+ #
+ #     if args.max_text_length > 0 and len(req.text) > args.max_text_length:
+ #         raise HTTPException(
+ #             HTTPStatus.BAD_REQUEST,
+ #             content=f"Text is too long, max length is {args.max_text_length}",
+ #         )
+ #
+ #     if req.streaming and req.format != "wav":
+ #         raise HTTPException(
+ #             HTTPStatus.BAD_REQUEST,
+ #             content="Streaming only supports WAV format",
+ #         )
+ #
+ #     if req.streaming:
+ #         return StreamResponse(
+ #             iterable=inference_async(req),
+ #             headers={
+ #                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
+ #             },
+ #             content_type=get_content_type(req.format),
+ #         )
+ #     else:
+ #         fake_audios = next(inference(req))
+ #         buffer = io.BytesIO()
+ #         sf.write(
+ #             buffer,
+ #             fake_audios,
+ #             decoder_model.spec_transform.sample_rate,
+ #             format=req.format,
+ #         )
+ #
+ #         return StreamResponse(
+ #             iterable=buffer_to_async_generator(buffer.getvalue()),
+ #             headers={
+ #                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
+ #             },
+ #             content_type=get_content_type(req.format),
+ #         )
+ #
+ #
+ # @routes.http.post("/v1/health")
+ # async def api_health():
+ #     """
+ #     Health check
+ #     """
+ #
+ #     return JSONResponse({"status": "ok"})
+
+
+ def parse_args():
+     parser = ArgumentParser()
+     parser.add_argument(
+         "--llama-checkpoint-path",
+         type=str,
+         default="checkpoints/fish-speech-1.2-sft",
+     )
+     parser.add_argument(
+         "--decoder-checkpoint-path",
+         type=str,
+         default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+     )
+     parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+     parser.add_argument("--device", type=str, default="cuda")
+     parser.add_argument("--half", action="store_true")
+     parser.add_argument("--compile", action="store_true")
+     parser.add_argument("--max-text-length", type=int, default=0)
+     parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
+     parser.add_argument("--workers", type=int, default=1)
+     parser.add_argument("--use-auto-rerank", type=bool, default=True)
+
+     return parser.parse_args()
+
+
+ # Define Kui app
+ # openapi = OpenAPI(
+ #     {
+ #         "title": "Fish Speech API",
+ #     },
+ # ).routes
+ #
+ # app = Kui(
+ #     routes=routes + openapi[1:],  # Remove the default route
+ #     exception_handlers={
+ #         HTTPException: http_execption_handler,
+ #         Exception: other_exception_handler,
+ #     },
+ #     cors_config={},
+ # )
+
+
+ if __name__ == "__main__":
+     import threading
+
+     import uvicorn
+
+     args = parse_args()
+     args.precision = torch.half if args.half else torch.bfloat16
+
+     logger.info("Loading Llama model...")
+     llama_queue = launch_thread_safe_queue(
+         checkpoint_path=args.llama_checkpoint_path,
+         device=args.device,
+         precision=args.precision,
+         compile=args.compile,
+     )
+     logger.info("Llama model loaded, loading VQ-GAN model...")
+
+     decoder_model = load_decoder_model(
+         config_name=args.decoder_config_name,
+         checkpoint_path=args.decoder_checkpoint_path,
+         device=args.device,
+     )
+
+     logger.info("VQ-GAN model loaded, warming up...")
+
+     # Dry run to check if the model is loaded correctly and avoid the first-time latency
+     list(
+         inference(
+             InvokeRequest(
+                 text="Hello world.",
+                 reference_text=None,
+                 reference_audio=None,
+                 max_new_tokens=0,
+                 top_p=0.7,
+                 repetition_penalty=1.2,
+                 temperature=0.7,
+                 emotion=None,
+                 format="wav",
+                 ref_base=None,
+                 ref_json=None,
+             )
+         )
+     )
+
+     logger.info(f"Warming up done, starting server at http://{args.listen}")
+     host, port = args.listen.split(":")
+     uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")
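
Note on the hunk above: the kui routes ship commented out in this vendored copy, which suggests xinference calls `inference()` and friends directly (see xinference/model/audio/fish_speech.py in the file list) rather than over HTTP. As a minimal client sketch, assuming the `/v1/invoke` route were re-enabled as in upstream fish-speech and the server in this file were listening on the default 127.0.0.1:8000 (the `requests` call and output file name here are illustrative, not part of the diff):

    import requests

    # Field names mirror the InvokeRequest model above.
    payload = {
        "text": "Hello world.",
        "format": "wav",
        "streaming": False,
        "ref_json": None,
        "ref_base": None,
    }
    resp = requests.post("http://127.0.0.1:8000/v1/invoke", json=payload)
    resp.raise_for_status()
    with open("audio.wav", "wb") as f:
        f.write(resp.content)  # non-streaming responses carry the encoded audio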
xinference/thirdparty/fish_speech/tools/auto_rerank.py
@@ -0,0 +1,159 @@
+ import os
+
+ os.environ["MODELSCOPE_CACHE"] = ".cache/"
+
+ import string
+ import time
+ from threading import Lock
+
+ import librosa
+ import numpy as np
+ import opencc
+ import torch
+ from faster_whisper import WhisperModel
+
+ t2s_converter = opencc.OpenCC("t2s")
+
+
+ def load_model(*, device="cuda"):
+     model = WhisperModel(
+         "medium",
+         device=device,
+         compute_type="float16",
+         download_root="faster_whisper",
+     )
+     print("faster_whisper loaded!")
+     return model
+
+
+ @torch.no_grad()
+ def batch_asr_internal(model: WhisperModel, audios, sr):
+     resampled_audios = []
+     for audio in audios:
+
+         if isinstance(audio, np.ndarray):
+             audio = torch.from_numpy(audio).float()
+
+         if audio.dim() > 1:
+             audio = audio.squeeze()
+
+         assert audio.dim() == 1
+         audio_np = audio.numpy()
+         resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
+         resampled_audios.append(resampled_audio)
+
+     trans_results = []
+
+     for resampled_audio in resampled_audios:
+         segments, info = model.transcribe(
+             resampled_audio,
+             language=None,
+             beam_size=5,
+             initial_prompt="Punctuation is needed in any language.",
+         )
+         trans_results.append(list(segments))
+
+     results = []
+     for trans_res, audio in zip(trans_results, audios):
+
+         duration = len(audio) / sr * 1000
+         huge_gap = False
+         max_gap = 0.0
+
+         text = None
+         last_tr = None
+
+         for tr in trans_res:
+             delta = tr.text.strip()
+             if tr.id > 1:
+                 max_gap = max(tr.start - last_tr.end, max_gap)
+                 text += delta
+             else:
+                 text = delta
+
+             last_tr = tr
+             if max_gap > 3.0:
+                 huge_gap = True
+                 break
+
+         sim_text = t2s_converter.convert(text)
+         results.append(
+             {
+                 "text": sim_text,
+                 "duration": duration,
+                 "huge_gap": huge_gap,
+             }
+         )
+
+     return results
+
+
+ global_lock = Lock()
+
+
+ def batch_asr(model, audios, sr):
+     return batch_asr_internal(model, audios, sr)
+
+
+ def is_chinese(text):
+     return True
+
+
+ def calculate_wer(text1, text2, debug=False):
+     chars1 = remove_punctuation(text1)
+     chars2 = remove_punctuation(text2)
+
+     m, n = len(chars1), len(chars2)
+
+     if m > n:
+         chars1, chars2 = chars2, chars1
+         m, n = n, m
+
+     prev = list(range(m + 1))  # row 0 distance: [0, 1, 2, ...]
+     curr = [0] * (m + 1)
+
+     for j in range(1, n + 1):
+         curr[0] = j
+         for i in range(1, m + 1):
+             if chars1[i - 1] == chars2[j - 1]:
+                 curr[i] = prev[i - 1]
+             else:
+                 curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
+         prev, curr = curr, prev
+
+     edits = prev[m]
+     tot = max(len(chars1), len(chars2))
+     wer = edits / tot
+
+     if debug:
+         print(" gt: ", chars1)
+         print(" pred: ", chars2)
+         print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
+
+     return wer
+
+
+ def remove_punctuation(text):
+     chinese_punctuation = (
+         " \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
+         '‛""„‟…‧﹏'
+     )
+     all_punctuation = string.punctuation + chinese_punctuation
+     translator = str.maketrans("", "", all_punctuation)
+     text_without_punctuation = text.translate(translator)
+     return text_without_punctuation
+
+
+ if __name__ == "__main__":
+     model = load_model()
+     audios = [
+         librosa.load("44100.wav", sr=44100)[0],
+         librosa.load("lengyue.wav", sr=44100)[0],
+     ]
+     print(np.array(audios[0]))
+     print(batch_asr(model, audios, 44100))
+
+     start_time = time.time()
+     for _ in range(10):
+         print(batch_asr(model, audios, 44100))
+     print("Time taken:", time.time() - start_time)
xinference/thirdparty/fish_speech/tools/download_models.py
@@ -0,0 +1,55 @@
+ import os
+
+ from huggingface_hub import hf_hub_download
+
+
+ # Download
+ def check_and_download_files(repo_id, file_list, local_dir):
+     os.makedirs(local_dir, exist_ok=True)
+     for file in file_list:
+         file_path = os.path.join(local_dir, file)
+         if not os.path.exists(file_path):
+             print(f"{file} does not exist, downloading from the Hugging Face repo...")
+             hf_hub_download(
+                 repo_id=repo_id,
+                 filename=file,
+                 resume_download=True,
+                 local_dir=local_dir,
+                 local_dir_use_symlinks=False,
+             )
+         else:
+             print(f"{file} already exists, skipping download.")
+
+
+ # 1st
+ repo_id_1 = "fishaudio/fish-speech-1.2-sft"
+ local_dir_1 = "./checkpoints/fish-speech-1.2-sft"
+ files_1 = [
+     "model.pth",
+     "README.md",
+     "special_tokens_map.json",
+     "tokenizer_config.json",
+     "tokenizer.json",
+     "config.json",
+     "firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+ ]
+
+ # 3rd
+ repo_id_3 = "fishaudio/fish-speech-1"
+ local_dir_3 = "./"
+ files_3 = [
+     "ffmpeg.exe",
+     "ffprobe.exe",
+ ]
+
+ # 4th
+ repo_id_4 = "SpicyqSama007/fish-speech-packed"
+ local_dir_4 = "./"
+ files_4 = [
+     "asr-label-win-x64.exe",
+ ]
+
+ check_and_download_files(repo_id_1, files_1, local_dir_1)
+
+ check_and_download_files(repo_id_3, files_3, local_dir_3)
+ check_and_download_files(repo_id_4, files_4, local_dir_4)
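
As an aside (not part of the diff), newer huggingface_hub releases deprecate `resume_download`; a roughly equivalent sketch of the per-file loop above on a recent huggingface_hub would use `snapshot_download` with `allow_patterns`, which likewise avoids re-fetching files already present:

    from huggingface_hub import snapshot_download

    # Fetch only the listed files into local_dir (illustrative subset).
    snapshot_download(
        repo_id="fishaudio/fish-speech-1.2-sft",
        allow_patterns=["model.pth", "tokenizer.json", "config.json"],
        local_dir="./checkpoints/fish-speech-1.2-sft",
    )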