xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (191)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +60 -44
  6. xinference/model/audio/chattts.py +25 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/cosyvoice.py +4 -3
  9. xinference/model/audio/custom.py +4 -5
  10. xinference/model/audio/fish_speech.py +228 -0
  11. xinference/model/audio/model_spec.json +8 -0
  12. xinference/model/embedding/core.py +25 -1
  13. xinference/model/embedding/custom.py +4 -5
  14. xinference/model/flexible/core.py +5 -1
  15. xinference/model/image/custom.py +4 -5
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +66 -3
  19. xinference/model/llm/__init__.py +6 -0
  20. xinference/model/llm/llm_family.json +54 -9
  21. xinference/model/llm/llm_family.py +7 -6
  22. xinference/model/llm/llm_family_modelscope.json +56 -10
  23. xinference/model/llm/lmdeploy/__init__.py +0 -0
  24. xinference/model/llm/lmdeploy/core.py +557 -0
  25. xinference/model/llm/sglang/core.py +7 -1
  26. xinference/model/llm/transformers/cogvlm2.py +4 -45
  27. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  28. xinference/model/llm/transformers/core.py +3 -0
  29. xinference/model/llm/transformers/glm4v.py +2 -23
  30. xinference/model/llm/transformers/intern_vl.py +94 -11
  31. xinference/model/llm/transformers/minicpmv25.py +2 -23
  32. xinference/model/llm/transformers/minicpmv26.py +2 -22
  33. xinference/model/llm/transformers/yi_vl.py +2 -24
  34. xinference/model/llm/utils.py +13 -1
  35. xinference/model/llm/vllm/core.py +1 -34
  36. xinference/model/rerank/custom.py +4 -5
  37. xinference/model/utils.py +41 -1
  38. xinference/model/video/core.py +3 -1
  39. xinference/model/video/diffusers.py +41 -38
  40. xinference/model/video/model_spec.json +24 -1
  41. xinference/model/video/model_spec_modelscope.json +25 -1
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/matcha/__init__.py +0 -0
  137. xinference/thirdparty/matcha/app.py +357 -0
  138. xinference/thirdparty/matcha/cli.py +419 -0
  139. xinference/thirdparty/matcha/data/__init__.py +0 -0
  140. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  141. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  142. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  143. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  144. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  145. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  146. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  147. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  148. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  149. xinference/thirdparty/matcha/models/__init__.py +0 -0
  150. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  151. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  152. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  153. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  154. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  155. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  156. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  157. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  158. xinference/thirdparty/matcha/onnx/export.py +181 -0
  159. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  160. xinference/thirdparty/matcha/text/__init__.py +53 -0
  161. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  162. xinference/thirdparty/matcha/text/numbers.py +71 -0
  163. xinference/thirdparty/matcha/text/symbols.py +17 -0
  164. xinference/thirdparty/matcha/train.py +122 -0
  165. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  166. xinference/thirdparty/matcha/utils/audio.py +82 -0
  167. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  168. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  169. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  170. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  171. xinference/thirdparty/matcha/utils/model.py +90 -0
  172. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  173. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  174. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  175. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  176. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  177. xinference/thirdparty/matcha/utils/utils.py +259 -0
  178. xinference/web/ui/build/asset-manifest.json +3 -3
  179. xinference/web/ui/build/index.html +1 -1
  180. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  181. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  182. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  183. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
  184. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
  185. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  186. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  187. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  188. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  189. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  190. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  191. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py (new file, +573 -0)
@@ -0,0 +1,573 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import copy
+import json
+import logging
+import os.path
+import random
+import re
+import string
+import time
+
+import numpy as np
+import torch
+from funasr.download.download_model_from_hub import download_model
+from funasr.download.file import download_from_url
+from funasr.register import tables
+from funasr.train_utils.load_pretrained_model import load_pretrained_model
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import export_utils, misc
+from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
+from funasr.utils.misc import deep_update
+from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
+from tqdm import tqdm
+
+from .vad_utils import merge_vad, slice_padding_audio_samples
+
+try:
+    from funasr.models.campplus.cluster_backend import ClusterBackend
+    from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
+except:
+    pass
+
+
+def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
+    """ """
+    data_list = []
+    key_list = []
+    filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
+
+    chars = string.ascii_letters + string.digits
+    if isinstance(data_in, str):
+        if data_in.startswith("http://") or data_in.startswith("https://"):  # url
+            data_in = download_from_url(data_in)
+
+    if isinstance(data_in, str) and os.path.exists(
+        data_in
+    ):  # wav_path; filelist: wav.scp, file.jsonl;text.txt;
+        _, file_extension = os.path.splitext(data_in)
+        file_extension = file_extension.lower()
+        if file_extension in filelist:  # filelist: wav.scp, file.jsonl;text.txt;
+            with open(data_in, encoding="utf-8") as fin:
+                for line in fin:
+                    key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                    if data_in.endswith(
+                        ".jsonl"
+                    ):  # file.jsonl: json.dumps({"source": data})
+                        lines = json.loads(line.strip())
+                        data = lines["source"]
+                        key = data["key"] if "key" in data else key
+                    else:  # filelist, wav.scp, text.txt: id \t data or data
+                        lines = line.strip().split(maxsplit=1)
+                        data = lines[1] if len(lines) > 1 else lines[0]
+                        key = lines[0] if len(lines) > 1 else key
+
+                    data_list.append(data)
+                    key_list.append(key)
+        else:
+            if key is None:
+                # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                key = misc.extract_filename_without_extension(data_in)
+            data_list = [data_in]
+            key_list = [key]
+    elif isinstance(data_in, (list, tuple)):
+        if data_type is not None and isinstance(
+            data_type, (list, tuple)
+        ):  # mutiple inputs
+            data_list_tmp = []
+            for data_in_i, data_type_i in zip(data_in, data_type):
+                key_list, data_list_i = prepare_data_iterator(
+                    data_in=data_in_i, data_type=data_type_i
+                )
+                data_list_tmp.append(data_list_i)
+            data_list = []
+            for item in zip(*data_list_tmp):
+                data_list.append(item)
+        else:
+            # [audio sample point, fbank, text]
+            data_list = data_in
+            key_list = []
+            for data_i in data_in:
+                if isinstance(data_i, str) and os.path.exists(data_i):
+                    key = misc.extract_filename_without_extension(data_i)
+                else:
+                    if key is None:
+                        key = "rand_key_" + "".join(
+                            random.choice(chars) for _ in range(13)
+                        )
+                key_list.append(key)
+
+    else:  # raw text; audio sample point, fbank; bytes
+        if isinstance(data_in, bytes):  # audio bytes
+            data_in = load_bytes(data_in)
+        if key is None:
+            key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+        data_list = [data_in]
+        key_list = [key]
+
+    return key_list, data_list
+
+
+class AutoModel:
+
+    def __init__(self, **kwargs):
+
+        try:
+            from funasr.utils.version_checker import check_for_update
+
+            print(
+                "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
+            )
+            check_for_update(disable=kwargs.get("disable_update", False))
+        except:
+            pass
+
+        log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+        logging.basicConfig(level=log_level)
+
+        model, kwargs = self.build_model(**kwargs)
+
+        # if vad_model is not None, build vad model else None
+        vad_model = kwargs.get("vad_model", None)
+        vad_kwargs = (
+            {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
+        )
+        if vad_model is not None:
+            logging.info("Building VAD model.")
+            vad_kwargs["model"] = vad_model
+            vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
+            vad_kwargs["device"] = kwargs["device"]
+            vad_model, vad_kwargs = self.build_model(**vad_kwargs)
+
+        # if punc_model is not None, build punc model else None
+        punc_model = kwargs.get("punc_model", None)
+        punc_kwargs = (
+            {}
+            if kwargs.get("punc_kwargs", {}) is None
+            else kwargs.get("punc_kwargs", {})
+        )
+        if punc_model is not None:
+            logging.info("Building punc model.")
+            punc_kwargs["model"] = punc_model
+            punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
+            punc_kwargs["device"] = kwargs["device"]
+            punc_model, punc_kwargs = self.build_model(**punc_kwargs)
+
+        # if spk_model is not None, build spk model else None
+        spk_model = kwargs.get("spk_model", None)
+        spk_kwargs = (
+            {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
+        )
+        if spk_model is not None:
+            logging.info("Building SPK model.")
+            spk_kwargs["model"] = spk_model
+            spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
+            spk_kwargs["device"] = kwargs["device"]
+            spk_model, spk_kwargs = self.build_model(**spk_kwargs)
+            self.cb_model = ClusterBackend().to(kwargs["device"])
+            spk_mode = kwargs.get("spk_mode", "punc_segment")
+            if spk_mode not in ["default", "vad_segment", "punc_segment"]:
+                logging.error(
+                    "spk_mode should be one of default, vad_segment and punc_segment."
+                )
+            self.spk_mode = spk_mode
+
+        self.kwargs = kwargs
+        self.model = model
+        self.vad_model = vad_model
+        self.vad_kwargs = vad_kwargs
+        self.punc_model = punc_model
+        self.punc_kwargs = punc_kwargs
+        self.spk_model = spk_model
+        self.spk_kwargs = spk_kwargs
+        self.model_path = kwargs.get("model_path")
+
+    @staticmethod
+    def build_model(**kwargs):
+        assert "model" in kwargs
+        if "model_conf" not in kwargs:
+            logging.info(
+                "download models from model hub: {}".format(kwargs.get("hub", "ms"))
+            )
+            kwargs = download_model(**kwargs)
+
+        set_all_random_seed(kwargs.get("seed", 0))
+
+        device = kwargs.get("device", "cuda")
+        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
+            device = "cpu"
+            kwargs["batch_size"] = 1
+        kwargs["device"] = device
+
+        torch.set_num_threads(kwargs.get("ncpu", 4))
+
+        # build tokenizer
+        tokenizer = kwargs.get("tokenizer", None)
+        if tokenizer is not None:
+            tokenizer_class = tables.tokenizer_classes.get(tokenizer)
+            tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
+            kwargs["token_list"] = (
+                tokenizer.token_list if hasattr(tokenizer, "token_list") else None
+            )
+            kwargs["token_list"] = (
+                tokenizer.get_vocab()
+                if hasattr(tokenizer, "get_vocab")
+                else kwargs["token_list"]
+            )
+            vocab_size = (
+                len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
+            )
+            if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
+                vocab_size = tokenizer.get_vocab_size()
+        else:
+            vocab_size = -1
+        kwargs["tokenizer"] = tokenizer
+
+        # build frontend
+        frontend = kwargs.get("frontend", None)
+        kwargs["input_size"] = None
+        if frontend is not None:
+            frontend_class = tables.frontend_classes.get(frontend)
+            frontend = frontend_class(**kwargs.get("frontend_conf", {}))
+            kwargs["input_size"] = (
+                frontend.output_size() if hasattr(frontend, "output_size") else None
+            )
+        kwargs["frontend"] = frontend
+        # build model
+        model_class = tables.model_classes.get(kwargs["model"])
+        assert model_class is not None, f'{kwargs["model"]} is not registered'
+        model_conf = {}
+        deep_update(model_conf, kwargs.get("model_conf", {}))
+        deep_update(model_conf, kwargs)
+        model = model_class(**model_conf, vocab_size=vocab_size)
+
+        # init_param
+        init_param = kwargs.get("init_param", None)
+        if init_param is not None:
+            if os.path.exists(init_param):
+                logging.info(f"Loading pretrained params from {init_param}")
+                load_pretrained_model(
+                    model=model,
+                    path=init_param,
+                    ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
+                    oss_bucket=kwargs.get("oss_bucket", None),
+                    scope_map=kwargs.get("scope_map", []),
+                    excludes=kwargs.get("excludes", None),
+                )
+            else:
+                print(f"error, init_param does not exist!: {init_param}")
+
+        # fp16
+        if kwargs.get("fp16", False):
+            model.to(torch.float16)
+        elif kwargs.get("bf16", False):
+            model.to(torch.bfloat16)
+        model.to(device)
+
+        if not kwargs.get("disable_log", True):
+            tables.print()
+
+        return model, kwargs
+
+    def __call__(self, *args, **cfg):
+        kwargs = self.kwargs
+        deep_update(kwargs, cfg)
+        res = self.model(*args, kwargs)
+        return res
+
+    def generate(self, input, input_len=None, **cfg):
+        if self.vad_model is None:
+            return self.inference(input, input_len=input_len, **cfg)
+
+        else:
+            return self.inference_with_vad(input, input_len=input_len, **cfg)
+
+    def inference(
+        self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
+    ):
+        kwargs = self.kwargs if kwargs is None else kwargs
+        if "cache" in kwargs:
+            kwargs.pop("cache")
+        deep_update(kwargs, cfg)
+        model = self.model if model is None else model
+        model.eval()
+
+        batch_size = kwargs.get("batch_size", 1)
+        # if kwargs.get("device", "cpu") == "cpu":
+        #     batch_size = 1
+
+        key_list, data_list = prepare_data_iterator(
+            input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
+        )
+
+        speed_stats = {}
+        asr_result_list = []
+        num_samples = len(data_list)
+        disable_pbar = self.kwargs.get("disable_pbar", False)
+        pbar = (
+            tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
+            if not disable_pbar
+            else None
+        )
+        time_speech_total = 0.0
+        time_escape_total = 0.0
+        for beg_idx in range(0, num_samples, batch_size):
+            end_idx = min(num_samples, beg_idx + batch_size)
+            data_batch = data_list[beg_idx:end_idx]
+            key_batch = key_list[beg_idx:end_idx]
+            batch = {"data_in": data_batch, "key": key_batch}
+
+            if (end_idx - beg_idx) == 1 and kwargs.get(
+                "data_type", None
+            ) == "fbank":  # fbank
+                batch["data_in"] = data_batch[0]
+                batch["data_lengths"] = input_len
+
+            time1 = time.perf_counter()
+            with torch.no_grad():
+                res = model.inference(**batch, **kwargs)
+                if isinstance(res, (list, tuple)):
+                    results = res[0] if len(res) > 0 else [{"text": ""}]
+                    meta_data = res[1] if len(res) > 1 else {}
+            time2 = time.perf_counter()
+
+            asr_result_list.extend(results)
+
+            # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
+            batch_data_time = meta_data.get("batch_data_time", -1)
+            time_escape = time2 - time1
+            speed_stats["load_data"] = meta_data.get("load_data", 0.0)
+            speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
+            speed_stats["forward"] = f"{time_escape:0.3f}"
+            speed_stats["batch_size"] = f"{len(results)}"
+            speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
+            description = f"{speed_stats}, "
+            if pbar:
+                pbar.update(end_idx - beg_idx)
+                pbar.set_description(description)
+            time_speech_total += batch_data_time
+            time_escape_total += time_escape
+
+        if pbar:
+            # pbar.update(1)
+            pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
+        torch.cuda.empty_cache()
+        return asr_result_list
+
+    def vad(self, input, input_len=None, **cfg):
+        kwargs = self.kwargs
+        # step.1: compute the vad model
+        deep_update(self.vad_kwargs, cfg)
+        beg_vad = time.time()
+        res = self.inference(
+            input,
+            input_len=input_len,
+            model=self.vad_model,
+            kwargs=self.vad_kwargs,
+            **cfg,
+        )
+        end_vad = time.time()
+        # FIX(gcf): concat the vad clips for sense vocie model for better aed
+        if cfg.get("merge_vad", False):
+            for i in range(len(res)):
+                res[i]["value"] = merge_vad(
+                    res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
+                )
+        elapsed = end_vad - beg_vad
+        return elapsed, res
+
+    def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
+
+        kwargs = self.kwargs
+
+        # step.2 compute asr model
+        model = self.model
+        deep_update(kwargs, cfg)
+        batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
+        batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
+        kwargs["batch_size"] = batch_size
+
+        key_list, data_list = prepare_data_iterator(
+            input, input_len=input_len, data_type=kwargs.get("data_type", None)
+        )
+        results_ret_list = []
+        time_speech_total_all_samples = 1e-6
+
+        beg_total = time.time()
+        pbar_total = (
+            tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
+            if not kwargs.get("disable_pbar", False)
+            else None
+        )
+
+        for i in range(len(vad_res)):
+            key = vad_res[i]["key"]
+            vadsegments = vad_res[i]["value"]
+            input_i = data_list[i]
+            fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
+            speech = load_audio_text_image_video(
+                input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
+            )
+            speech_lengths = len(speech)
+            n = len(vadsegments)
+            data_with_index = [(vadsegments[i], i) for i in range(n)]
+            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
+            results_sorted = []
+
+            if not len(sorted_data):
+                results_ret_list.append({"key": key, "text": "", "timestamp": []})
+                logging.info("decoding, utt: {}, empty speech".format(key))
+                continue
+
+            if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
+                batch_size = max(
+                    batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
+                )
+
+            if kwargs["device"] == "cpu":
+                batch_size = 0
+
+            beg_idx = 0
+            beg_asr_total = time.time()
+            time_speech_total_per_sample = speech_lengths / 16000
+            time_speech_total_all_samples += time_speech_total_per_sample
+
+            # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
+
+            all_segments = []
+            max_len_in_batch = 0
+            end_idx = 1
+
+            for j, _ in enumerate(range(0, n)):
+                # pbar_sample.update(1)
+                sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
+                potential_batch_length = max(max_len_in_batch, sample_length) * (
+                    j + 1 - beg_idx
+                )
+                # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
+                if (
+                    j < n - 1
+                    and sample_length < batch_size_threshold_ms
+                    and potential_batch_length < batch_size
+                ):
+                    max_len_in_batch = max(max_len_in_batch, sample_length)
+                    end_idx += 1
+                    continue
+
+                speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
+                    speech, speech_lengths, sorted_data[beg_idx:end_idx]
+                )
+                results = self.inference(
+                    speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
+                )
+
+                for _b in range(len(speech_j)):
+                    results[_b]["interval"] = intervals[_b]
+
+                if self.spk_model is not None:
+                    # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
+                    for _b in range(len(speech_j)):
+                        vad_segments = [
+                            [
+                                sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
+                                sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
+                                np.array(speech_j[_b]),
+                            ]
+                        ]
+                        segments = sv_chunk(vad_segments)
+                        all_segments.extend(segments)
+                        speech_b = [i[2] for i in segments]
+                        spk_res = self.inference(
+                            speech_b,
+                            input_len=None,
+                            model=self.spk_model,
+                            kwargs=kwargs,
+                            **cfg,
+                        )
+                        results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
+
+                beg_idx = end_idx
+                end_idx += 1
+                max_len_in_batch = sample_length
+                if len(results) < 1:
+                    continue
+                results_sorted.extend(results)
+
+            # end_asr_total = time.time()
+            # time_escape_total_per_sample = end_asr_total - beg_asr_total
+            # pbar_sample.update(1)
+            # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+            #                             f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
+            #                             f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
+
+            restored_data = [0] * n
+            for j in range(n):
+                index = sorted_data[j][1]
+                cur = results_sorted[j]
+                pattern = r"<\|([^|]+)\|>"
+                emotion_string = re.findall(pattern, cur["text"])
+                cur["text"] = re.sub(pattern, "", cur["text"])
+                cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
+                if self.punc_model is not None and len(cur["text"].strip()) > 0:
+                    deep_update(self.punc_kwargs, cfg)
+                    punc_res = self.inference(
+                        cur["text"],
+                        model=self.punc_model,
+                        kwargs=self.punc_kwargs,
+                        **cfg,
+                    )
+                    cur["text"] = punc_res[0]["text"]
+
+                restored_data[index] = cur
+
+            end_asr_total = time.time()
+            time_escape_total_per_sample = end_asr_total - beg_asr_total
+            if pbar_total:
+                pbar_total.update(1)
+                pbar_total.set_description(
+                    f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+                    f"time_speech: {time_speech_total_per_sample: 0.3f}, "
+                    f"time_escape: {time_escape_total_per_sample:0.3f}"
+                )
+
+        # end_total = time.time()
+        # time_escape_total_all_samples = end_total - beg_total
+        # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
+        #       f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
+        #       f"time_escape_all: {time_escape_total_all_samples:0.3f}")
+        return restored_data
+
+    def export(self, input=None, **cfg):
+        """
+
+        :param input:
+        :param type:
+        :param quantize:
+        :param fallback_num:
+        :param calib_num:
+        :param opset_version:
+        :param cfg:
+        :return:
+        """
+
+        device = cfg.get("device", "cpu")
+        model = self.model.to(device=device)
+        kwargs = self.kwargs
+        deep_update(kwargs, cfg)
+        kwargs["device"] = device
+        del kwargs["model"]
+        model.eval()
+
+        type = kwargs.get("type", "onnx")
+
+        key_list, data_list = prepare_data_iterator(
+            input, input_len=None, data_type=kwargs.get("data_type", None), key=None
+        )
+
+        with torch.no_grad():
+            export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
+
+        return export_dir
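
Note on the hunk above: it is a vendored, lightly adapted copy of FunASR's AutoModel. Unlike the upstream class it exposes the VAD and ASR passes as separate steps, vad() returning (elapsed_seconds, segments) and inference_with_vadres() decoding those segments and restoring their original order; generate() still refers to an inference_with_vad method that this copy does not define, so callers are expected to use the split methods directly. A minimal usage sketch, assuming a funasr-style setup; the model identifiers below are illustrative assumptions, not taken from this diff:

    from xinference.thirdparty.fish_speech.tools.sensevoice.auto_model import AutoModel

    # Hypothetical model ids: any ASR/VAD pair registered with funasr should work here.
    am = AutoModel(
        model="iic/SenseVoiceSmall",
        vad_model="fsmn-vad",
        device="cuda",
        disable_update=True,
    )

    # Step 1: VAD pass; merge_vad concatenates short clips (see the FIX(gcf) comment above).
    elapsed, vad_res = am.vad(input="speech.wav", merge_vad=True)

    # Step 2: ASR pass over the VAD segments; results come back in original segment order,
    # with emotion tags split out of the text into the "emo" field.
    results = am.inference_with_vadres(input="speech.wav", vad_res=vad_res)
    print(results[0]["text"], results[0]["emo"])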