xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (194)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +209 -40
  4. xinference/client/restful/restful_client.py +7 -26
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/image_interface.py +28 -0
  11. xinference/core/model.py +110 -31
  12. xinference/core/scheduler.py +37 -37
  13. xinference/core/status_guard.py +1 -1
  14. xinference/core/supervisor.py +17 -10
  15. xinference/core/utils.py +80 -22
  16. xinference/core/worker.py +17 -16
  17. xinference/deploy/cmdline.py +8 -16
  18. xinference/deploy/local.py +1 -1
  19. xinference/deploy/supervisor.py +1 -1
  20. xinference/deploy/utils.py +1 -1
  21. xinference/deploy/worker.py +1 -1
  22. xinference/model/audio/cosyvoice.py +86 -41
  23. xinference/model/audio/fish_speech.py +9 -9
  24. xinference/model/audio/model_spec.json +9 -9
  25. xinference/model/audio/whisper.py +4 -1
  26. xinference/model/embedding/core.py +52 -31
  27. xinference/model/image/core.py +2 -1
  28. xinference/model/image/model_spec.json +16 -4
  29. xinference/model/image/model_spec_modelscope.json +16 -4
  30. xinference/model/image/sdapi.py +136 -0
  31. xinference/model/image/stable_diffusion/core.py +164 -19
  32. xinference/model/llm/__init__.py +29 -11
  33. xinference/model/llm/llama_cpp/core.py +16 -33
  34. xinference/model/llm/llm_family.json +1011 -1296
  35. xinference/model/llm/llm_family.py +34 -53
  36. xinference/model/llm/llm_family_csghub.json +18 -35
  37. xinference/model/llm/llm_family_modelscope.json +981 -1122
  38. xinference/model/llm/lmdeploy/core.py +56 -88
  39. xinference/model/llm/mlx/core.py +46 -69
  40. xinference/model/llm/sglang/core.py +36 -18
  41. xinference/model/llm/transformers/chatglm.py +168 -306
  42. xinference/model/llm/transformers/cogvlm2.py +36 -63
  43. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  44. xinference/model/llm/transformers/core.py +55 -50
  45. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  46. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  47. xinference/model/llm/transformers/glm4v.py +55 -111
  48. xinference/model/llm/transformers/intern_vl.py +39 -70
  49. xinference/model/llm/transformers/internlm2.py +32 -54
  50. xinference/model/llm/transformers/minicpmv25.py +22 -55
  51. xinference/model/llm/transformers/minicpmv26.py +158 -68
  52. xinference/model/llm/transformers/omnilmm.py +5 -28
  53. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  54. xinference/model/llm/transformers/qwen2_vl.py +234 -0
  55. xinference/model/llm/transformers/qwen_vl.py +34 -86
  56. xinference/model/llm/transformers/utils.py +32 -38
  57. xinference/model/llm/transformers/yi_vl.py +32 -72
  58. xinference/model/llm/utils.py +280 -554
  59. xinference/model/llm/vllm/core.py +161 -100
  60. xinference/model/rerank/core.py +41 -8
  61. xinference/model/rerank/model_spec.json +7 -0
  62. xinference/model/rerank/model_spec_modelscope.json +7 -1
  63. xinference/model/utils.py +1 -31
  64. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  65. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  66. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  67. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  68. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  69. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  70. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  71. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  72. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  73. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  74. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  77. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  78. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  79. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  80. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  81. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  82. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  83. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  84. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  85. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  86. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  87. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  88. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  89. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  90. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  91. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
  92. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  93. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  94. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  95. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  96. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  97. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  98. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  99. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  100. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  101. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  102. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  103. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  104. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  105. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  107. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  108. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  109. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  110. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  111. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  112. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  113. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  114. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  115. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  116. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  117. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  118. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  122. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  123. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  124. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  126. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  127. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  128. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  129. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  130. xinference/thirdparty/matcha/VERSION +1 -0
  131. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  132. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  133. xinference/thirdparty/omnilmm/LICENSE +201 -0
  134. xinference/thirdparty/whisper/__init__.py +156 -0
  135. xinference/thirdparty/whisper/__main__.py +3 -0
  136. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  137. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  138. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  139. xinference/thirdparty/whisper/audio.py +157 -0
  140. xinference/thirdparty/whisper/decoding.py +826 -0
  141. xinference/thirdparty/whisper/model.py +314 -0
  142. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  143. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  144. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  145. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  146. xinference/thirdparty/whisper/timing.py +386 -0
  147. xinference/thirdparty/whisper/tokenizer.py +395 -0
  148. xinference/thirdparty/whisper/transcribe.py +605 -0
  149. xinference/thirdparty/whisper/triton_ops.py +109 -0
  150. xinference/thirdparty/whisper/utils.py +316 -0
  151. xinference/thirdparty/whisper/version.py +1 -0
  152. xinference/types.py +14 -53
  153. xinference/web/ui/build/asset-manifest.json +6 -6
  154. xinference/web/ui/build/index.html +1 -1
  155. xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
  156. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  157. xinference/web/ui/build/static/js/main.754740c0.js +3 -0
  158. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
  159. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  160. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  161. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  162. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  163. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  164. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  165. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  166. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  167. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  168. xinference/web/ui/node_modules/.package-lock.json +37 -0
  169. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  170. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  171. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  172. xinference/web/ui/package-lock.json +38 -0
  173. xinference/web/ui/package.json +1 -0
  174. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
  175. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
  176. xinference/model/llm/transformers/llama_2.py +0 -108
  177. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  178. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  179. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  180. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  181. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  182. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  183. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  184. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  185. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  186. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  187. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  188. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  189. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  190. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  191. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/whisper/transcribe.py
@@ -0,0 +1,605 @@
+ import argparse
+ import os
+ import traceback
+ import warnings
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ import tqdm
+
+ from .audio import (
+     FRAMES_PER_SECOND,
+     HOP_LENGTH,
+     N_FRAMES,
+     N_SAMPLES,
+     SAMPLE_RATE,
+     log_mel_spectrogram,
+     pad_or_trim,
+ )
+ from .decoding import DecodingOptions, DecodingResult
+ from .timing import add_word_timestamps
+ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
+ from .utils import (
+     exact_div,
+     format_timestamp,
+     get_end,
+     get_writer,
+     make_safe,
+     optional_float,
+     optional_int,
+     str2bool,
+ )
+
+ if TYPE_CHECKING:
+     from .model import Whisper
+
+
+ def transcribe(
+     model: "Whisper",
+     audio: Union[str, np.ndarray, torch.Tensor],
+     *,
+     verbose: Optional[bool] = None,
+     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
+     compression_ratio_threshold: Optional[float] = 2.4,
+     logprob_threshold: Optional[float] = -1.0,
+     no_speech_threshold: Optional[float] = 0.6,
+     condition_on_previous_text: bool = True,
+     initial_prompt: Optional[str] = None,
+     word_timestamps: bool = False,
+     prepend_punctuations: str = "\"'“¿([{-",
+     append_punctuations: str = "\"'.。,,!!??::”)]}、",
+     clip_timestamps: Union[str, List[float]] = "0",
+     hallucination_silence_threshold: Optional[float] = None,
+     **decode_options,
+ ):
+     """
+     Transcribe an audio file using Whisper
+
+     Parameters
+     ----------
+     model: Whisper
+         The Whisper model instance
+
+     audio: Union[str, np.ndarray, torch.Tensor]
+         The path to the audio file to open, or the audio waveform
+
+     verbose: bool
+         Whether to display the text being decoded to the console. If True, displays all the details,
+         If False, displays minimal details. If None, does not display anything
+
+     temperature: Union[float, Tuple[float, ...]]
+         Temperature for sampling. It can be a tuple of temperatures, which will be successively used
+         upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
+
+     compression_ratio_threshold: float
+         If the gzip compression ratio is above this value, treat as failed
+
+     logprob_threshold: float
+         If the average log probability over sampled tokens is below this value, treat as failed
+
+     no_speech_threshold: float
+         If the no_speech probability is higher than this value AND the average log probability
+         over sampled tokens is below `logprob_threshold`, consider the segment as silent
+
+     condition_on_previous_text: bool
+         if True, the previous output of the model is provided as a prompt for the next window;
+         disabling may make the text inconsistent across windows, but the model becomes less prone to
+         getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
+
+     word_timestamps: bool
+         Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
+         and include the timestamps for each word in each segment.
+
+     prepend_punctuations: str
+         If word_timestamps is True, merge these punctuation symbols with the next word
+
+     append_punctuations: str
+         If word_timestamps is True, merge these punctuation symbols with the previous word
+
+     initial_prompt: Optional[str]
+         Optional text to provide as a prompt for the first window. This can be used to provide, or
+         "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
+         to make it more likely to predict those word correctly.
+
+     decode_options: dict
+         Keyword arguments to construct `DecodingOptions` instances
+
+     clip_timestamps: Union[str, List[float]]
+         Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
+         The last end timestamp defaults to the end of the file.
+
+     hallucination_silence_threshold: Optional[float]
+         When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
+         when a possible hallucination is detected
+
+     Returns
+     -------
+     A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
+     the spoken language ("language"), which is detected when `decode_options["language"]` is None.
+     """
+     dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
+     if model.device == torch.device("cpu"):
+         if torch.cuda.is_available():
+             warnings.warn("Performing inference on CPU when CUDA is available")
+         if dtype == torch.float16:
+             warnings.warn("FP16 is not supported on CPU; using FP32 instead")
+             dtype = torch.float32
+
+     if dtype == torch.float32:
+         decode_options["fp16"] = False
+
+     # Pad 30-seconds of silence to the input audio, for slicing
+     mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
+     content_frames = mel.shape[-1] - N_FRAMES
+     content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
+
+     if decode_options.get("language", None) is None:
+         if not model.is_multilingual:
+             decode_options["language"] = "en"
+         else:
+             if verbose:
+                 print(
+                     "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
+                 )
+             mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
+             _, probs = model.detect_language(mel_segment)
+             decode_options["language"] = max(probs, key=probs.get)
+             if verbose is not None:
+                 print(
+                     f"Detected language: {LANGUAGES[decode_options['language']].title()}"
+                 )
+
+     language: str = decode_options["language"]
+     task: str = decode_options.get("task", "transcribe")
+     tokenizer = get_tokenizer(
+         model.is_multilingual,
+         num_languages=model.num_languages,
+         language=language,
+         task=task,
+     )
+
+     if isinstance(clip_timestamps, str):
+         clip_timestamps = [
+             float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
+         ]
+     seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
+     if len(seek_points) == 0:
+         seek_points.append(0)
+     if len(seek_points) % 2 == 1:
+         seek_points.append(content_frames)
+     seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
+
+     punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
+
+     if word_timestamps and task == "translate":
+         warnings.warn("Word-level timestamps on translations may not be reliable.")
+
+     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
+         temperatures = (
+             [temperature] if isinstance(temperature, (int, float)) else temperature
+         )
+         decode_result = None
+
+         for t in temperatures:
+             kwargs = {**decode_options}
+             if t > 0:
+                 # disable beam_size and patience when t > 0
+                 kwargs.pop("beam_size", None)
+                 kwargs.pop("patience", None)
+             else:
+                 # disable best_of when t == 0
+                 kwargs.pop("best_of", None)
+
+             options = DecodingOptions(**kwargs, temperature=t)
+             decode_result = model.decode(segment, options)
+
+             needs_fallback = False
+             if (
+                 compression_ratio_threshold is not None
+                 and decode_result.compression_ratio > compression_ratio_threshold
+             ):
+                 needs_fallback = True  # too repetitive
+             if (
+                 logprob_threshold is not None
+                 and decode_result.avg_logprob < logprob_threshold
+             ):
+                 needs_fallback = True  # average log probability is too low
+             if (
+                 no_speech_threshold is not None
+                 and decode_result.no_speech_prob > no_speech_threshold
+             ):
+                 needs_fallback = False  # silence
+             if not needs_fallback:
+                 break
+
+         return decode_result
+
+     clip_idx = 0
+     seek = seek_clips[clip_idx][0]
+     input_stride = exact_div(
+         N_FRAMES, model.dims.n_audio_ctx
+     )  # mel frames per output token: 2
+     time_precision = (
+         input_stride * HOP_LENGTH / SAMPLE_RATE
+     )  # time per output token: 0.02 (seconds)
+     all_tokens = []
+     all_segments = []
+     prompt_reset_since = 0
+
+     if initial_prompt is not None:
+         initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
+         all_tokens.extend(initial_prompt_tokens)
+     else:
+         initial_prompt_tokens = []
+
+     def new_segment(
+         *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
+     ):
+         tokens = tokens.tolist()
+         text_tokens = [token for token in tokens if token < tokenizer.eot]
+         return {
+             "seek": seek,
+             "start": start,
+             "end": end,
+             "text": tokenizer.decode(text_tokens),
+             "tokens": tokens,
+             "temperature": result.temperature,
+             "avg_logprob": result.avg_logprob,
+             "compression_ratio": result.compression_ratio,
+             "no_speech_prob": result.no_speech_prob,
+         }
+
+     # show the progress bar when verbose is False (if True, transcribed text will be printed)
+     with tqdm.tqdm(
+         total=content_frames, unit="frames", disable=verbose is not False
+     ) as pbar:
+         last_speech_timestamp = 0.0
+         # NOTE: This loop is obscurely flattened to make the diff readable.
+         # A later commit should turn this into a simpler nested loop.
+         # for seek_clip_start, seek_clip_end in seek_clips:
+         #     while seek < seek_clip_end
+         while clip_idx < len(seek_clips):
+             seek_clip_start, seek_clip_end = seek_clips[clip_idx]
+             if seek < seek_clip_start:
+                 seek = seek_clip_start
+             if seek >= seek_clip_end:
+                 clip_idx += 1
+                 if clip_idx < len(seek_clips):
+                     seek = seek_clips[clip_idx][0]
+                 continue
+             time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+             window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
+             segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
+             mel_segment = mel[:, seek : seek + segment_size]
+             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
+             mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
+
+             decode_options["prompt"] = all_tokens[prompt_reset_since:]
+             result: DecodingResult = decode_with_fallback(mel_segment)
+             tokens = torch.tensor(result.tokens)
+
+             if no_speech_threshold is not None:
+                 # no voice activity check
+                 should_skip = result.no_speech_prob > no_speech_threshold
+                 if (
+                     logprob_threshold is not None
+                     and result.avg_logprob > logprob_threshold
+                 ):
+                     # don't skip if the logprob is high enough, despite the no_speech_prob
+                     should_skip = False
+
+                 if should_skip:
+                     seek += segment_size  # fast-forward to the next segment boundary
+                     continue
+
+             previous_seek = seek
+             current_segments = []
+
+             # anomalous words are very long/short/improbable
+             def word_anomaly_score(word: dict) -> float:
+                 probability = word.get("probability", 0.0)
+                 duration = word["end"] - word["start"]
+                 score = 0.0
+                 if probability < 0.15:
+                     score += 1.0
+                 if duration < 0.133:
+                     score += (0.133 - duration) * 15
+                 if duration > 2.0:
+                     score += duration - 2.0
+                 return score
+
+             def is_segment_anomaly(segment: Optional[dict]) -> bool:
+                 if segment is None or not segment["words"]:
+                     return False
+                 words = [w for w in segment["words"] if w["word"] not in punctuation]
+                 words = words[:8]
+                 score = sum(word_anomaly_score(w) for w in words)
+                 return score >= 3 or score + 0.01 >= len(words)
+
+             def next_words_segment(segments: List[dict]) -> Optional[dict]:
+                 return next((s for s in segments if s["words"]), None)
+
+             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
+             single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
+
+             consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
+             consecutive.add_(1)
+             if len(consecutive) > 0:
+                 # if the output contains two consecutive timestamp tokens
+                 slices = consecutive.tolist()
+                 if single_timestamp_ending:
+                     slices.append(len(tokens))
+
+                 last_slice = 0
+                 for current_slice in slices:
+                     sliced_tokens = tokens[last_slice:current_slice]
+                     start_timestamp_pos = (
+                         sliced_tokens[0].item() - tokenizer.timestamp_begin
+                     )
+                     end_timestamp_pos = (
+                         sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                     )
+                     current_segments.append(
+                         new_segment(
+                             start=time_offset + start_timestamp_pos * time_precision,
+                             end=time_offset + end_timestamp_pos * time_precision,
+                             tokens=sliced_tokens,
+                             result=result,
+                         )
+                     )
+                     last_slice = current_slice
+
+                 if single_timestamp_ending:
+                     # single timestamp at the end means no speech after the last timestamp.
+                     seek += segment_size
+                 else:
+                     # otherwise, ignore the unfinished segment and seek to the last timestamp
+                     last_timestamp_pos = (
+                         tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+                     )
+                     seek += last_timestamp_pos * input_stride
+             else:
+                 duration = segment_duration
+                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
+                 if (
+                     len(timestamps) > 0
+                     and timestamps[-1].item() != tokenizer.timestamp_begin
+                 ):
+                     # no consecutive timestamps but it has a timestamp; use the last one.
+                     last_timestamp_pos = (
+                         timestamps[-1].item() - tokenizer.timestamp_begin
+                     )
+                     duration = last_timestamp_pos * time_precision
+
+                 current_segments.append(
+                     new_segment(
+                         start=time_offset,
+                         end=time_offset + duration,
+                         tokens=tokens,
+                         result=result,
+                     )
+                 )
+                 seek += segment_size
+
+             if word_timestamps:
+                 add_word_timestamps(
+                     segments=current_segments,
+                     model=model,
+                     tokenizer=tokenizer,
+                     mel=mel_segment,
+                     num_frames=segment_size,
+                     prepend_punctuations=prepend_punctuations,
+                     append_punctuations=append_punctuations,
+                     last_speech_timestamp=last_speech_timestamp,
+                 )
+
+                 if not single_timestamp_ending:
+                     last_word_end = get_end(current_segments)
+                     if last_word_end is not None and last_word_end > time_offset:
+                         seek = round(last_word_end * FRAMES_PER_SECOND)
+
+                 # skip silence before possible hallucinations
+                 if hallucination_silence_threshold is not None:
+                     threshold = hallucination_silence_threshold
+                     if not single_timestamp_ending:
+                         last_word_end = get_end(current_segments)
+                         if last_word_end is not None and last_word_end > time_offset:
+                             remaining_duration = window_end_time - last_word_end
+                             if remaining_duration > threshold:
+                                 seek = round(last_word_end * FRAMES_PER_SECOND)
+                             else:
+                                 seek = previous_seek + segment_size
+
+                     # if first segment might be a hallucination, skip leading silence
+                     first_segment = next_words_segment(current_segments)
+                     if first_segment is not None and is_segment_anomaly(first_segment):
+                         gap = first_segment["start"] - time_offset
+                         if gap > threshold:
+                             seek = previous_seek + round(gap * FRAMES_PER_SECOND)
+                             continue
+
+                     # skip silence before any possible hallucination that is surrounded
+                     # by silence or more hallucinations
+                     hal_last_end = last_speech_timestamp
+                     for si in range(len(current_segments)):
+                         segment = current_segments[si]
+                         if not segment["words"]:
+                             continue
+                         if is_segment_anomaly(segment):
+                             next_segment = next_words_segment(
+                                 current_segments[si + 1 :]
+                             )
+                             if next_segment is not None:
+                                 hal_next_start = next_segment["words"][0]["start"]
+                             else:
+                                 hal_next_start = time_offset + segment_duration
+                             silence_before = (
+                                 segment["start"] - hal_last_end > threshold
+                                 or segment["start"] < threshold
+                                 or segment["start"] - time_offset < 2.0
+                             )
+                             silence_after = (
+                                 hal_next_start - segment["end"] > threshold
+                                 or is_segment_anomaly(next_segment)
+                                 or window_end_time - segment["end"] < 2.0
+                             )
+                             if silence_before and silence_after:
+                                 seek = round(
+                                     max(time_offset + 1, segment["start"])
+                                     * FRAMES_PER_SECOND
+                                 )
+                                 if content_duration - segment["end"] < threshold:
+                                     seek = content_frames
+                                 current_segments[si:] = []
+                                 break
+                         hal_last_end = segment["end"]
+
+                 last_word_end = get_end(current_segments)
+                 if last_word_end is not None:
+                     last_speech_timestamp = last_word_end
+
+             if verbose:
+                 for segment in current_segments:
+                     start, end, text = segment["start"], segment["end"], segment["text"]
+                     line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
+                     print(make_safe(line))
+
+             # if a segment is instantaneous or does not contain text, clear it
+             for i, segment in enumerate(current_segments):
+                 if segment["start"] == segment["end"] or segment["text"].strip() == "":
+                     segment["text"] = ""
+                     segment["tokens"] = []
+                     segment["words"] = []
+
+             all_segments.extend(
+                 [
+                     {"id": i, **segment}
+                     for i, segment in enumerate(
+                         current_segments, start=len(all_segments)
+                     )
+                 ]
+             )
+             all_tokens.extend(
+                 [token for segment in current_segments for token in segment["tokens"]]
+             )
+
+             if not condition_on_previous_text or result.temperature > 0.5:
+                 # do not feed the prompt tokens if a high temperature was used
+                 prompt_reset_since = len(all_tokens)
+
+             # update progress bar
+             pbar.update(min(content_frames, seek) - previous_seek)
+
+     return dict(
+         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
+         segments=all_segments,
+         language=language,
+     )
+
+
+ def cli():
+     from . import available_models
+
+     def valid_model_name(name):
+         if name in available_models() or os.path.exists(name):
+             return name
+         raise ValueError(
+             f"model should be one of {available_models()} or path to a model checkpoint"
+         )
+
+     # fmt: off
+     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
+     parser.add_argument("--model", default="small", type=valid_model_name, help="name of the Whisper model to use")
+     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
+     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
+     parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
+     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
+
+     parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
+     parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
+
+     parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
+     parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
+     parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
+     parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
+     parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
+
+     parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
+     parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
+     parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
+     parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
+
+     parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
+     parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
+     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
+     parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
+     parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
+     parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
+     parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
+     parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
+     parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
+     parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
+     parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
+     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
+     parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
+     parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
+     # fmt: on
+
+     args = parser.parse_args().__dict__
+     model_name: str = args.pop("model")
+     model_dir: str = args.pop("model_dir")
+     output_dir: str = args.pop("output_dir")
+     output_format: str = args.pop("output_format")
+     device: str = args.pop("device")
+     os.makedirs(output_dir, exist_ok=True)
+
+     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
+         if args["language"] is not None:
+             warnings.warn(
+                 f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
+             )
+         args["language"] = "en"
+
+     temperature = args.pop("temperature")
+     if (increment := args.pop("temperature_increment_on_fallback")) is not None:
+         temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
+     else:
+         temperature = [temperature]
+
+     if (threads := args.pop("threads")) > 0:
+         torch.set_num_threads(threads)
+
+     from . import load_model
+
+     model = load_model(model_name, device=device, download_root=model_dir)
+
+     writer = get_writer(output_format, output_dir)
+     word_options = [
+         "highlight_words",
+         "max_line_count",
+         "max_line_width",
+         "max_words_per_line",
+     ]
+     if not args["word_timestamps"]:
+         for option in word_options:
+             if args[option]:
+                 parser.error(f"--{option} requires --word_timestamps True")
+     if args["max_line_count"] and not args["max_line_width"]:
+         warnings.warn("--max_line_count has no effect without --max_line_width")
+     if args["max_words_per_line"] and args["max_line_width"]:
+         warnings.warn("--max_words_per_line has no effect with --max_line_width")
+     writer_args = {arg: args.pop(arg) for arg in word_options}
+     for audio_path in args.pop("audio"):
+         try:
+             result = transcribe(model, audio_path, temperature=temperature, **args)
+             writer(result, audio_path, **writer_args)
+         except Exception as e:
+             traceback.print_exc()
+             print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
+
+
+ if __name__ == "__main__":
+     cli()
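
For orientation, the vendored module above mirrors the upstream openai-whisper API, so the new transcribe() can be called directly in Python. A minimal usage sketch, assuming the vendored package path xinference.thirdparty.whisper is importable from the installed wheel, that model weights can be fetched as upstream whisper does, and that a local audio.wav exists (all of these are illustrative assumptions, not part of the diff):

    from xinference.thirdparty.whisper import load_model
    from xinference.thirdparty.whisper.transcribe import transcribe

    # Load a checkpoint by name, as cli() does above; weights are downloaded
    # on first use in upstream whisper.
    model = load_model("small")

    # transcribe() returns a dict with "text", "segments", and the detected "language".
    result = transcribe(model, "audio.wav", word_timestamps=False)
    print(result["text"])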