xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +15 -34
- xinference/client/restful/restful_client.py +2 -2
- xinference/core/chat_interface.py +45 -10
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +8 -5
- xinference/core/scheduler.py +1 -2
- xinference/core/worker.py +49 -42
- xinference/deploy/cmdline.py +2 -2
- xinference/deploy/test/test_cmdline.py +7 -7
- xinference/model/audio/chattts.py +24 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +23 -1
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +49 -1
- xinference/model/llm/__init__.py +26 -27
- xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
- xinference/model/llm/llm_family.json +606 -1266
- xinference/model/llm/llm_family.py +16 -139
- xinference/model/llm/llm_family_modelscope.json +276 -313
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/memory.py +9 -9
- xinference/model/llm/sglang/core.py +2 -2
- xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
- xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/{pytorch → transformers}/core.py +3 -10
- xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +540 -0
- xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
- xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
- xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
- xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
- xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
- xinference/model/llm/utils.py +85 -70
- xinference/model/llm/vllm/core.py +110 -11
- xinference/model/utils.py +1 -95
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/internvl/__init__.py +0 -0
- xinference/thirdparty/internvl/conversation.py +393 -0
- xinference/thirdparty/omnilmm/model/utils.py +16 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
- xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
- xinference/locale/utils.py +0 -39
- xinference/locale/zh_CN.json +0 -26
- xinference/model/llm/ggml/tools/__init__.py +0 -15
- xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
- xinference/model/llm/ggml/tools/gguf.py +0 -884
- xinference/model/llm/pytorch/__init__.py +0 -13
- xinference/model/llm/pytorch/baichuan.py +0 -81
- xinference/model/llm/pytorch/falcon.py +0 -138
- xinference/model/llm/pytorch/intern_vl.py +0 -352
- xinference/model/llm/pytorch/vicuna.py +0 -69
- xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
- xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
- /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
- /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
- /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
- /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
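The headline change in this release is text-to-speech support built on Fish Speech: a vendored copy of the upstream project under xinference/thirdparty/fish_speech, the new xinference/model/audio/fish_speech.py implementation, and a matching model_spec.json entry. A minimal, hypothetical smoke test through the RESTful client is sketched below; the model name ("FishSpeech-1.2-SFT") and endpoint are assumptions, so adjust them to whatever your installation actually registers.

from xinference.client import Client

# Assumes a local xinference supervisor on the default port and that the new
# Fish Speech audio model is registered as "FishSpeech-1.2-SFT" (name assumed).
client = Client("http://localhost:9997")
model_uid = client.launch_model(model_name="FishSpeech-1.2-SFT", model_type="audio")
model = client.get_model(model_uid)

# speech() mirrors the OpenAI audio API and returns raw audio bytes.
audio_bytes = model.speech("Fish Speech is now bundled with xinference.")
with open("fish_speech_demo.wav", "wb") as f:
    f.write(audio_bytes)

The diff viewer expands three of the new fish_speech tool modules in full below: tools/api.py, tools/auto_rerank.py and tools/download_models.py.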
xinference/thirdparty/fish_speech/tools/api.py
@@ -0,0 +1,495 @@
+import base64
+import io
+import json
+import queue
+import random
+import sys
+import traceback
+import wave
+from argparse import ArgumentParser
+from http import HTTPStatus
+from pathlib import Path
+from typing import Annotated, Literal, Optional
+
+import numpy as np
+# import pyrootutils
+import soundfile as sf
+import torch
+import torchaudio
+# from kui.asgi import (
+#     Body,
+#     HTTPException,
+#     HttpView,
+#     JSONResponse,
+#     Kui,
+#     OpenAPI,
+#     StreamResponse,
+# )
+# from kui.asgi.routing import MultimethodRoutes
+from loguru import logger
+from pydantic import BaseModel, Field
+
+# pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+# from fish_speech.models.vqgan.lit_module import VQGAN
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from fish_speech.utils import autocast_exclude_mps
+from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
+from tools.llama.generate import (
+    GenerateRequest,
+    GenerateResponse,
+    WrappedGenerateResponse,
+    launch_thread_safe_queue,
+)
+from tools.vqgan.inference import load_model as load_decoder_model
+
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+    buffer = io.BytesIO()
+
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(channels)
+        wav_file.setsampwidth(bit_depth // 8)
+        wav_file.setframerate(sample_rate)
+
+    wav_header_bytes = buffer.getvalue()
+    buffer.close()
+    return wav_header_bytes
+
+
+# Define utils for web server
+# async def http_execption_handler(exc: HTTPException):
+#     return JSONResponse(
+#         dict(
+#             statusCode=exc.status_code,
+#             message=exc.content,
+#             error=HTTPStatus(exc.status_code).phrase,
+#         ),
+#         exc.status_code,
+#         exc.headers,
+#     )
+
+
+async def other_exception_handler(exc: "Exception"):
+    traceback.print_exc()
+
+    status = HTTPStatus.INTERNAL_SERVER_ERROR
+    return JSONResponse(
+        dict(statusCode=status, message=str(exc), error=status.phrase),
+        status,
+    )
+
+
+def load_audio(reference_audio, sr):
+    if len(reference_audio) > 255 or not Path(reference_audio).exists():
+        try:
+            audio_data = base64.b64decode(reference_audio)
+            reference_audio = io.BytesIO(audio_data)
+        except base64.binascii.Error:
+            raise ValueError("Invalid path or base64 string")
+
+    waveform, original_sr = torchaudio.load(
+        reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
+    )
+
+    if waveform.shape[0] > 1:
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+    if original_sr != sr:
+        resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
+        waveform = resampler(waveform)
+
+    audio = waveform.squeeze().numpy()
+    return audio
+
+
+def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
+    if enable_reference_audio and reference_audio is not None:
+        # Load audios, and prepare basic info here
+        reference_audio_content = load_audio(
+            reference_audio, decoder_model.spec_transform.sample_rate
+        )
+
+        audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
+            None, None, :
+        ]
+        audio_lengths = torch.tensor(
+            [audios.shape[2]], device=decoder_model.device, dtype=torch.long
+        )
+        logger.info(
+            f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
+        )
+
+        # VQ Encoder
+        if isinstance(decoder_model, FireflyArchitecture):
+            prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
+
+        logger.info(f"Encoded prompt: {prompt_tokens.shape}")
+    else:
+        prompt_tokens = None
+        logger.info("No reference audio provided")
+
+    return prompt_tokens
+
+
+def decode_vq_tokens(
+    *,
+    decoder_model,
+    codes,
+):
+    feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
+    logger.info(f"VQ features: {codes.shape}")
+
+    if isinstance(decoder_model, FireflyArchitecture):
+        # VQGAN Inference
+        return decoder_model.decode(
+            indices=codes[None],
+            feature_lengths=feature_lengths,
+        ).squeeze()
+
+    raise ValueError(f"Unknown model type: {type(decoder_model)}")
+
+
+# routes = MultimethodRoutes(base_class=HttpView)
+
+
+def get_random_paths(base_path, data, speaker, emotion):
+    if base_path and data and speaker and emotion and (Path(base_path).exists()):
+        if speaker in data and emotion in data[speaker]:
+            files = data[speaker][emotion]
+            lab_files = [f for f in files if f.endswith(".lab")]
+            wav_files = [f for f in files if f.endswith(".wav")]
+
+            if lab_files and wav_files:
+                selected_lab = random.choice(lab_files)
+                selected_wav = random.choice(wav_files)
+
+                lab_path = Path(base_path) / speaker / emotion / selected_lab
+                wav_path = Path(base_path) / speaker / emotion / selected_wav
+                if lab_path.exists() and wav_path.exists():
+                    return lab_path, wav_path
+
+    return None, None
+
+
+def load_json(json_file):
+    if not json_file:
+        logger.info("Not using a json file")
+        return None
+    try:
+        with open(json_file, "r", encoding="utf-8") as file:
+            data = json.load(file)
+    except FileNotFoundError:
+        logger.warning(f"ref json not found: {json_file}")
+        data = None
+    except Exception as e:
+        logger.warning(f"Loading json failed: {e}")
+        data = None
+    return data
+
+
+class InvokeRequest(BaseModel):
+    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
+    reference_text: Optional[str] = None
+    reference_audio: Optional[str] = None
+    max_new_tokens: int = 1024
+    chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
+    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
+    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+    emotion: Optional[str] = None
+    format: Literal["wav", "mp3", "flac"] = "wav"
+    streaming: bool = False
+    ref_json: Optional[str] = "ref_data.json"
+    ref_base: Optional[str] = "ref_data"
+    speaker: Optional[str] = None
+
+
+def get_content_type(audio_format):
+    if audio_format == "wav":
+        return "audio/wav"
+    elif audio_format == "flac":
+        return "audio/flac"
+    elif audio_format == "mp3":
+        return "audio/mpeg"
+    else:
+        return "application/octet-stream"
+
+
+@torch.inference_mode()
+def inference(req: InvokeRequest):
+    # Parse reference audio aka prompt
+    prompt_tokens = None
+
+    ref_data = load_json(req.ref_json)
+    ref_base = req.ref_base
+
+    lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)
+
+    if lab_path and wav_path:
+        with open(lab_path, "r", encoding="utf-8") as lab_file:
+            ref_text = lab_file.read()
+        req.reference_audio = wav_path
+        req.reference_text = ref_text
+        logger.info("ref_path: " + str(wav_path))
+        logger.info("ref_text: " + ref_text)
+
+    # Parse reference audio aka prompt
+    prompt_tokens = encode_reference(
+        decoder_model=decoder_model,
+        reference_audio=req.reference_audio,
+        enable_reference_audio=req.reference_audio is not None,
+    )
+    logger.info(f"ref_text: {req.reference_text}")
+    # LLAMA Inference
+    request = dict(
+        device=decoder_model.device,
+        max_new_tokens=req.max_new_tokens,
+        text=req.text,
+        top_p=req.top_p,
+        repetition_penalty=req.repetition_penalty,
+        temperature=req.temperature,
+        compile=args.compile,
+        iterative_prompt=req.chunk_length > 0,
+        chunk_length=req.chunk_length,
+        max_length=2048,
+        prompt_tokens=prompt_tokens,
+        prompt_text=req.reference_text,
+    )
+
+    response_queue = queue.Queue()
+    llama_queue.put(
+        GenerateRequest(
+            request=request,
+            response_queue=response_queue,
+        )
+    )
+
+    if req.streaming:
+        yield wav_chunk_header()
+
+    segments = []
+    while True:
+        result: WrappedGenerateResponse = response_queue.get()
+        if result.status == "error":
+            raise result.response
+            break
+
+        result: GenerateResponse = result.response
+        if result.action == "next":
+            break
+
+        with autocast_exclude_mps(
+            device_type=decoder_model.device.type, dtype=args.precision
+        ):
+            fake_audios = decode_vq_tokens(
+                decoder_model=decoder_model,
+                codes=result.codes,
+            )
+
+        fake_audios = fake_audios.float().cpu().numpy()
+
+        if req.streaming:
+            yield (fake_audios * 32768).astype(np.int16).tobytes()
+        else:
+            segments.append(fake_audios)
+
+    if req.streaming:
+        return
+
+    if len(segments) == 0:
+        raise HTTPException(
+            HTTPStatus.INTERNAL_SERVER_ERROR,
+            content="No audio generated, please check the input text.",
+        )
+
+    fake_audios = np.concatenate(segments, axis=0)
+    yield fake_audios
+
+
+def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
+    if not use_auto_rerank:
+        # 如果不使用 auto_rerank,直接调用原始的 inference 函数
+        return inference(req)
+
+    zh_model, en_model = load_model()
+    max_attempts = 5
+    best_wer = float("inf")
+    best_audio = None
+
+    for attempt in range(max_attempts):
+        # 调用原始的 inference 函数
+        audio_generator = inference(req)
+        fake_audios = next(audio_generator)
+
+        asr_result = batch_asr(
+            zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
+        )[0]
+        wer = calculate_wer(req.text, asr_result["text"])
+
+        if wer <= 0.1 and not asr_result["huge_gap"]:
+            return fake_audios
+
+        if wer < best_wer:
+            best_wer = wer
+            best_audio = fake_audios
+
+        if attempt == max_attempts - 1:
+            break
+
+    return best_audio
+
+
+async def inference_async(req: InvokeRequest):
+    for chunk in inference(req):
+        yield chunk
+
+
+async def buffer_to_async_generator(buffer):
+    yield buffer
+
+
+# @routes.http.post("/v1/invoke")
+# async def api_invoke_model(
+#     req: Annotated[InvokeRequest, Body(exclusive=True)],
+# ):
+#     """
+#     Invoke model and generate audio
+#     """
+#
+#     if args.max_text_length > 0 and len(req.text) > args.max_text_length:
+#         raise HTTPException(
+#             HTTPStatus.BAD_REQUEST,
+#             content=f"Text is too long, max length is {args.max_text_length}",
+#         )
+#
+#     if req.streaming and req.format != "wav":
+#         raise HTTPException(
+#             HTTPStatus.BAD_REQUEST,
+#             content="Streaming only supports WAV format",
+#         )
+#
+#     if req.streaming:
+#         return StreamResponse(
+#             iterable=inference_async(req),
+#             headers={
+#                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
+#             },
+#             content_type=get_content_type(req.format),
+#         )
+#     else:
+#         fake_audios = next(inference(req))
+#         buffer = io.BytesIO()
+#         sf.write(
+#             buffer,
+#             fake_audios,
+#             decoder_model.spec_transform.sample_rate,
+#             format=req.format,
+#         )
+#
+#         return StreamResponse(
+#             iterable=buffer_to_async_generator(buffer.getvalue()),
+#             headers={
+#                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
+#             },
+#             content_type=get_content_type(req.format),
+#         )
+#
+#
+# @routes.http.post("/v1/health")
+# async def api_health():
+#     """
+#     Health check
+#     """
+#
+#     return JSONResponse({"status": "ok"})
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--llama-checkpoint-path",
+        type=str,
+        default="checkpoints/fish-speech-1.2-sft",
+    )
+    parser.add_argument(
+        "--decoder-checkpoint-path",
+        type=str,
+        default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+    )
+    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--half", action="store_true")
+    parser.add_argument("--compile", action="store_true")
+    parser.add_argument("--max-text-length", type=int, default=0)
+    parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
+    parser.add_argument("--workers", type=int, default=1)
+    parser.add_argument("--use-auto-rerank", type=bool, default=True)
+
+    return parser.parse_args()
+
+
+# Define Kui app
+# openapi = OpenAPI(
+#     {
+#         "title": "Fish Speech API",
+#     },
+# ).routes
+#
+# app = Kui(
+#     routes=routes + openapi[1:],  # Remove the default route
+#     exception_handlers={
+#         HTTPException: http_execption_handler,
+#         Exception: other_exception_handler,
+#     },
+#     cors_config={},
+# )
+
+
+if __name__ == "__main__":
+    import threading
+
+    import uvicorn
+
+    args = parse_args()
+    args.precision = torch.half if args.half else torch.bfloat16
+
+    logger.info("Loading Llama model...")
+    llama_queue = launch_thread_safe_queue(
+        checkpoint_path=args.llama_checkpoint_path,
+        device=args.device,
+        precision=args.precision,
+        compile=args.compile,
+    )
+    logger.info("Llama model loaded, loading VQ-GAN model...")
+
+    decoder_model = load_decoder_model(
+        config_name=args.decoder_config_name,
+        checkpoint_path=args.decoder_checkpoint_path,
+        device=args.device,
+    )
+
+    logger.info("VQ-GAN model loaded, warming up...")
+
+    # Dry run to check if the model is loaded correctly and avoid the first-time latency
+    list(
+        inference(
+            InvokeRequest(
+                text="Hello world.",
+                reference_text=None,
+                reference_audio=None,
+                max_new_tokens=0,
+                top_p=0.7,
+                repetition_penalty=1.2,
+                temperature=0.7,
+                emotion=None,
+                format="wav",
+                ref_base=None,
+                ref_json=None,
+            )
+        )
+    )
+
+    logger.info(f"Warming up done, starting server at http://{args.listen}")
+    host, port = args.listen.split(":")
+    uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")
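In this vendored copy of tools/api.py the upstream Kui web layer (routes, app, /v1/invoke, /v1/health) is commented out, leaving the inference() generator as the usable entry point via the module-level args, llama_queue and decoder_model globals set up in its __main__ block. A minimal, hypothetical sketch of the non-streaming happy path, assuming those globals are already initialised as above:

import soundfile as sf

req = InvokeRequest(text="Hello world.", streaming=False, ref_json=None, ref_base=None)
audio = next(inference(req))  # non-streaming: the generator yields one concatenated numpy array
sf.write("out.wav", audio, decoder_model.spec_transform.sample_rate, format="wav")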
xinference/thirdparty/fish_speech/tools/auto_rerank.py
@@ -0,0 +1,159 @@
+import os
+
+os.environ["MODELSCOPE_CACHE"] = ".cache/"
+
+import string
+import time
+from threading import Lock
+
+import librosa
+import numpy as np
+import opencc
+import torch
+from faster_whisper import WhisperModel
+
+t2s_converter = opencc.OpenCC("t2s")
+
+
+def load_model(*, device="cuda"):
+    model = WhisperModel(
+        "medium",
+        device=device,
+        compute_type="float16",
+        download_root="faster_whisper",
+    )
+    print("faster_whisper loaded!")
+    return model
+
+
+@torch.no_grad()
+def batch_asr_internal(model: WhisperModel, audios, sr):
+    resampled_audios = []
+    for audio in audios:
+
+        if isinstance(audio, np.ndarray):
+            audio = torch.from_numpy(audio).float()
+
+        if audio.dim() > 1:
+            audio = audio.squeeze()
+
+        assert audio.dim() == 1
+        audio_np = audio.numpy()
+        resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
+        resampled_audios.append(resampled_audio)
+
+    trans_results = []
+
+    for resampled_audio in resampled_audios:
+        segments, info = model.transcribe(
+            resampled_audio,
+            language=None,
+            beam_size=5,
+            initial_prompt="Punctuation is needed in any language.",
+        )
+        trans_results.append(list(segments))
+
+    results = []
+    for trans_res, audio in zip(trans_results, audios):
+
+        duration = len(audio) / sr * 1000
+        huge_gap = False
+        max_gap = 0.0
+
+        text = None
+        last_tr = None
+
+        for tr in trans_res:
+            delta = tr.text.strip()
+            if tr.id > 1:
+                max_gap = max(tr.start - last_tr.end, max_gap)
+                text += delta
+            else:
+                text = delta
+
+            last_tr = tr
+            if max_gap > 3.0:
+                huge_gap = True
+                break
+
+        sim_text = t2s_converter.convert(text)
+        results.append(
+            {
+                "text": sim_text,
+                "duration": duration,
+                "huge_gap": huge_gap,
+            }
+        )
+
+    return results
+
+
+global_lock = Lock()
+
+
+def batch_asr(model, audios, sr):
+    return batch_asr_internal(model, audios, sr)
+
+
+def is_chinese(text):
+    return True
+
+
+def calculate_wer(text1, text2, debug=False):
+    chars1 = remove_punctuation(text1)
+    chars2 = remove_punctuation(text2)
+
+    m, n = len(chars1), len(chars2)
+
+    if m > n:
+        chars1, chars2 = chars2, chars1
+        m, n = n, m
+
+    prev = list(range(m + 1))  # row 0 distance: [0, 1, 2, ...]
+    curr = [0] * (m + 1)
+
+    for j in range(1, n + 1):
+        curr[0] = j
+        for i in range(1, m + 1):
+            if chars1[i - 1] == chars2[j - 1]:
+                curr[i] = prev[i - 1]
+            else:
+                curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
+        prev, curr = curr, prev
+
+    edits = prev[m]
+    tot = max(len(chars1), len(chars2))
+    wer = edits / tot
+
+    if debug:
+        print(" gt: ", chars1)
+        print(" pred: ", chars2)
+        print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
+
+    return wer
+
+
+def remove_punctuation(text):
+    chinese_punctuation = (
+        " \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
+        '‛""„‟…‧﹏'
+    )
+    all_punctuation = string.punctuation + chinese_punctuation
+    translator = str.maketrans("", "", all_punctuation)
+    text_without_punctuation = text.translate(translator)
+    return text_without_punctuation
+
+
+if __name__ == "__main__":
+    model = load_model()
+    audios = [
+        librosa.load("44100.wav", sr=44100)[0],
+        librosa.load("lengyue.wav", sr=44100)[0],
+    ]
+    print(np.array(audios[0]))
+    print(batch_asr(model, audios, 44100))
+
+    start_time = time.time()
+    for _ in range(10):
+        print(batch_asr(model, audios, 44100))
+    print("Time taken:", time.time() - start_time)
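calculate_wer() strips punctuation (and whitespace, since the punctuation table above begins with a space character) and then divides a character-level edit distance by the longer string's length; auto_rerank_inference() in tools/api.py accepts a candidate once this score is at or below 0.1 and no huge_gap was flagged. A small worked example under those definitions:

print(calculate_wer("hello world!", "hello world"))  # 0.0: only punctuation/whitespace differ
print(calculate_wer("hello world", "hello word"))    # 1 edit / 10 chars = 0.1, right at the accept threshold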
xinference/thirdparty/fish_speech/tools/download_models.py
@@ -0,0 +1,55 @@
+import os
+
+from huggingface_hub import hf_hub_download
+
+
+# Download
+def check_and_download_files(repo_id, file_list, local_dir):
+    os.makedirs(local_dir, exist_ok=True)
+    for file in file_list:
+        file_path = os.path.join(local_dir, file)
+        if not os.path.exists(file_path):
+            print(f"{file} 不存在,从 Hugging Face 仓库下载...")
+            hf_hub_download(
+                repo_id=repo_id,
+                filename=file,
+                resume_download=True,
+                local_dir=local_dir,
+                local_dir_use_symlinks=False,
+            )
+        else:
+            print(f"{file} 已存在,跳过下载。")
+
+
+# 1st
+repo_id_1 = "fishaudio/fish-speech-1.2-sft"
+local_dir_1 = "./checkpoints/fish-speech-1.2-sft"
+files_1 = [
+    "model.pth",
+    "README.md",
+    "special_tokens_map.json",
+    "tokenizer_config.json",
+    "tokenizer.json",
+    "config.json",
+    "firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+]
+
+# 3rd
+repo_id_3 = "fishaudio/fish-speech-1"
+local_dir_3 = "./"
+files_3 = [
+    "ffmpeg.exe",
+    "ffprobe.exe",
+]
+
+# 4th
+repo_id_4 = "SpicyqSama007/fish-speech-packed"
+local_dir_4 = "./"
+files_4 = [
+    "asr-label-win-x64.exe",
+]
+
+check_and_download_files(repo_id_1, files_1, local_dir_1)
+
+check_and_download_files(repo_id_3, files_3, local_dir_3)
+check_and_download_files(repo_id_4, files_4, local_dir_4)
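check_and_download_files() above only calls hf_hub_download for files that are missing locally (with resume_download=True), so the script is cheap to re-run. A hypothetical re-use of it to fetch just the decoder weights, borrowing the repo and filename from files_1:

check_and_download_files(
    "fishaudio/fish-speech-1.2-sft",
    ["firefly-gan-vq-fsq-4x1024-42hz-generator.pth"],
    "./checkpoints/fish-speech-1.2-sft",
)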