xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (194)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +209 -40
  4. xinference/client/restful/restful_client.py +7 -26
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/image_interface.py +28 -0
  11. xinference/core/model.py +110 -31
  12. xinference/core/scheduler.py +37 -37
  13. xinference/core/status_guard.py +1 -1
  14. xinference/core/supervisor.py +17 -10
  15. xinference/core/utils.py +80 -22
  16. xinference/core/worker.py +17 -16
  17. xinference/deploy/cmdline.py +8 -16
  18. xinference/deploy/local.py +1 -1
  19. xinference/deploy/supervisor.py +1 -1
  20. xinference/deploy/utils.py +1 -1
  21. xinference/deploy/worker.py +1 -1
  22. xinference/model/audio/cosyvoice.py +86 -41
  23. xinference/model/audio/fish_speech.py +9 -9
  24. xinference/model/audio/model_spec.json +9 -9
  25. xinference/model/audio/whisper.py +4 -1
  26. xinference/model/embedding/core.py +52 -31
  27. xinference/model/image/core.py +2 -1
  28. xinference/model/image/model_spec.json +16 -4
  29. xinference/model/image/model_spec_modelscope.json +16 -4
  30. xinference/model/image/sdapi.py +136 -0
  31. xinference/model/image/stable_diffusion/core.py +164 -19
  32. xinference/model/llm/__init__.py +29 -11
  33. xinference/model/llm/llama_cpp/core.py +16 -33
  34. xinference/model/llm/llm_family.json +1011 -1296
  35. xinference/model/llm/llm_family.py +34 -53
  36. xinference/model/llm/llm_family_csghub.json +18 -35
  37. xinference/model/llm/llm_family_modelscope.json +981 -1122
  38. xinference/model/llm/lmdeploy/core.py +56 -88
  39. xinference/model/llm/mlx/core.py +46 -69
  40. xinference/model/llm/sglang/core.py +36 -18
  41. xinference/model/llm/transformers/chatglm.py +168 -306
  42. xinference/model/llm/transformers/cogvlm2.py +36 -63
  43. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  44. xinference/model/llm/transformers/core.py +55 -50
  45. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  46. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  47. xinference/model/llm/transformers/glm4v.py +55 -111
  48. xinference/model/llm/transformers/intern_vl.py +39 -70
  49. xinference/model/llm/transformers/internlm2.py +32 -54
  50. xinference/model/llm/transformers/minicpmv25.py +22 -55
  51. xinference/model/llm/transformers/minicpmv26.py +158 -68
  52. xinference/model/llm/transformers/omnilmm.py +5 -28
  53. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  54. xinference/model/llm/transformers/qwen2_vl.py +234 -0
  55. xinference/model/llm/transformers/qwen_vl.py +34 -86
  56. xinference/model/llm/transformers/utils.py +32 -38
  57. xinference/model/llm/transformers/yi_vl.py +32 -72
  58. xinference/model/llm/utils.py +280 -554
  59. xinference/model/llm/vllm/core.py +161 -100
  60. xinference/model/rerank/core.py +41 -8
  61. xinference/model/rerank/model_spec.json +7 -0
  62. xinference/model/rerank/model_spec_modelscope.json +7 -1
  63. xinference/model/utils.py +1 -31
  64. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  65. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  66. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  67. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  68. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  69. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  70. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  71. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  72. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  73. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  74. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  77. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  78. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  79. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  80. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  81. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  82. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  83. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  84. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  85. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  86. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  87. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  88. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  89. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  90. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  91. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
  92. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  93. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  94. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  95. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  96. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  97. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  98. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  99. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  100. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  101. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  102. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  103. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  104. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  105. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  107. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  108. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  109. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  110. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  111. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  112. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  113. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  114. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  115. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  116. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  117. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  118. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  122. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  123. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  124. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  126. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  127. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  128. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  129. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  130. xinference/thirdparty/matcha/VERSION +1 -0
  131. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  132. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  133. xinference/thirdparty/omnilmm/LICENSE +201 -0
  134. xinference/thirdparty/whisper/__init__.py +156 -0
  135. xinference/thirdparty/whisper/__main__.py +3 -0
  136. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  137. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  138. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  139. xinference/thirdparty/whisper/audio.py +157 -0
  140. xinference/thirdparty/whisper/decoding.py +826 -0
  141. xinference/thirdparty/whisper/model.py +314 -0
  142. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  143. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  144. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  145. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  146. xinference/thirdparty/whisper/timing.py +386 -0
  147. xinference/thirdparty/whisper/tokenizer.py +395 -0
  148. xinference/thirdparty/whisper/transcribe.py +605 -0
  149. xinference/thirdparty/whisper/triton_ops.py +109 -0
  150. xinference/thirdparty/whisper/utils.py +316 -0
  151. xinference/thirdparty/whisper/version.py +1 -0
  152. xinference/types.py +14 -53
  153. xinference/web/ui/build/asset-manifest.json +6 -6
  154. xinference/web/ui/build/index.html +1 -1
  155. xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
  156. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  157. xinference/web/ui/build/static/js/main.754740c0.js +3 -0
  158. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
  159. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  160. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  161. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  162. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  163. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  164. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  165. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  166. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  167. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  168. xinference/web/ui/node_modules/.package-lock.json +37 -0
  169. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  170. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  171. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  172. xinference/web/ui/package-lock.json +38 -0
  173. xinference/web/ui/package.json +1 -0
  174. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
  175. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
  176. xinference/model/llm/transformers/llama_2.py +0 -108
  177. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  178. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  179. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  180. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  181. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  182. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  183. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  184. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  185. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  186. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  187. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  188. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  189. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  190. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  191. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,826 @@
1
+ from dataclasses import dataclass, field, replace
2
+ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from torch.distributions import Categorical
9
+
10
+ from .audio import CHUNK_LENGTH
11
+ from .tokenizer import Tokenizer, get_tokenizer
12
+ from .utils import compression_ratio
13
+
14
+ if TYPE_CHECKING:
15
+ from .model import Whisper
16
+
17
+
18
@torch.no_grad()
def detect_language(
    model: "Whisper", mel: Tensor, tokenizer: Optional[Tokenizer] = None
) -> Tuple[Tensor, List[dict]]:
    """
    Detect the spoken language in the audio, and return the ids of the most probable
    language tokens along with the probability distribution over all language tokens.
    This is performed outside the main decode loop in order to not interfere with kv-caching.

    Parameters
    ----------
    model : Whisper
        the model whose encoder and `logits` method are used for the forward pass
    mel : Tensor
        the audio input; a 2-dimensional tensor is treated as a single (unbatched) input.
        If its last two dimensions equal (n_audio_ctx, n_audio_state) it is used as
        already-encoded audio features, otherwise it is passed through `model.encoder`.
    tokenizer : Optional[Tokenizer]
        the tokenizer to use; when None, one is built from the model's
        `is_multilingual` and `num_languages` attributes

    Returns
    -------
    language_tokens : Tensor, shape = (n_audio,)
        ids of the most probable language tokens, which appears after the startoftranscript token.
        For a single (2-dimensional) input this is the first element only.
    language_probs : List[Dict[str, float]], length = n_audio
        list of dictionaries containing the probability distribution over all languages.
        For a single (2-dimensional) input this is one dictionary rather than a list.

    Raises
    ------
    ValueError
        if the tokenizer has no language or its language token is not part of the sot sequence
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(
            model.is_multilingual, num_languages=model.num_languages
        )
    if (
        tokenizer.language is None
        or tokenizer.language_token not in tokenizer.sot_sequence
    ):
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
        )

    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    # skip encoder forward pass if already-encoded audio features were given
    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
        mel = model.encoder(mel)

    # forward pass using a single token, startoftranscript
    n_audio = mel.shape[0]
    x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]

    # collect detected languages; suppress all non-language tokens.
    # The mask is created on the same device as the logits so that the masked
    # assignment below does not mix devices.
    mask = torch.ones(logits.shape[-1], dtype=torch.bool, device=logits.device)
    mask[list(tokenizer.all_language_tokens)] = False
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
        }
        for i in range(n_audio)
    ]

    # undo the batching that was added for a single input
    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs
78
+
79
+
80
@dataclass(frozen=True)
class DecodingOptions:
    # Immutable bundle of settings that control a single decoding run.

    # --- task selection -------------------------------------------------
    # "transcribe" keeps the source language (X->X); "translate" produces English (X->English)
    task: str = "transcribe"

    # language of the audio; when None the detected language is used instead
    language: Optional[str] = None

    # --- sampling -------------------------------------------------------
    temperature: float = 0.0
    # upper bound on the number of tokens to sample
    sample_len: Optional[int] = None
    # how many independent sample trajectories to draw, if t > 0
    best_of: Optional[int] = None
    # how many beams to use in beam search, if t == 0
    beam_size: Optional[int] = None
    # beam search patience (arxiv:2204.05424)
    patience: Optional[float] = None

    # --- ranking --------------------------------------------------------
    # the "alpha" of Google NMT, or None for plain length normalisation; used when
    # ranking generations to choose which beam or best-of-N sample is returned
    length_penalty: Optional[float] = None

    # --- prompting ------------------------------------------------------
    # text or token ids supplied as the prompt or the prefix; more info at
    # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
    # prompt: the previous context
    prompt: Optional[Union[str, List[int]]] = None
    # prefix: prepended to the current context
    prefix: Optional[Union[str, List[int]]] = None

    # --- token suppression ----------------------------------------------
    # token ids (as an iterable or a comma-separated string) to suppress;
    # "-1" suppresses the symbol set defined by `tokenizer.non_speech_tokens()`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
    # when True, blank outputs are suppressed
    suppress_blank: bool = True

    # --- timestamps -----------------------------------------------------
    # when True, <|notimestamps|> is used so only text tokens are sampled
    without_timestamps: bool = False
    max_initial_timestamp: Optional[float] = 1.0

    # --- implementation details -----------------------------------------
    # run most of the calculation in fp16
    fp16: bool = True
115
+
116
+
117
@dataclass(frozen=True)
class DecodingResult:
    # Immutable record describing the outcome of decoding one audio input.

    # required fields
    audio_features: Tensor
    language: str

    # optional fields; the float metrics default to NaN until they are filled in
    language_probs: Optional[Dict[str, float]] = None
    tokens: List[int] = field(default_factory=list)  # fresh list per instance
    text: str = ""
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan
128
+
129
+
130
class Inference:
    """Interface for the decoder forward pass and its key-value cache handling."""

    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        """Run the decoder forward and return the logits for each token"""
        raise NotImplementedError

    def rearrange_kv_cache(self, source_indices) -> None:
        """Reorder the key-value cache so that it follows the updated beams"""
        raise NotImplementedError

    def cleanup_caching(self) -> None:
        """Release resources or hooks once decoding has finished (no-op by default)"""
        return None
142
+
143
+
144
class PyTorchInference(Inference):
    """Inference implementation that drives the model's decoder with a kv-cache."""

    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
        self.hooks = []

        # the key and value projection modules of every decoder block; they are
        # the keys under which cached tensors are reordered in rearrange_kv_cache
        decoder_blocks = self.model.decoder.blocks
        keys = [blk.attn.key for blk in decoder_blocks]
        values = [blk.attn.value for blk in decoder_blocks]
        self.kv_modules = keys + values

    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        """Return decoder logits, installing the kv-cache hooks while the cache is empty."""
        if not self.kv_cache:
            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

        # after the first forward pass only the newest token has to be fed in
        past_first_pass = tokens.shape[-1] > self.initial_token_length
        decoder_input = tokens[:, -1:] if past_first_pass else tokens

        return self.model.decoder(
            decoder_input, audio_features, kv_cache=self.kv_cache
        )

    def cleanup_caching(self):
        """Remove every installed hook and reset the cache and hook list."""
        for installed in self.hooks:
            installed.remove()

        self.kv_cache = {}
        self.hooks = []

    def rearrange_kv_cache(self, source_indices):
        """Reindex the cached keys/values by `source_indices` unless it is the identity order."""
        unchanged_order = list(range(len(source_indices)))
        if source_indices != unchanged_order:
            for kv_module in self.kv_modules:
                # keep only the selected sequences in the key/value cache
                cached = self.kv_cache[kv_module]
                self.kv_cache[kv_module] = cached[source_indices].detach()
177
+
178
+
179
class SequenceRanker:
    """Interface for choosing one sample out of each group of candidates."""

    def rank(
        self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
    ) -> List[int]:
        """
        Take groups of samples together with their cumulative log probabilities and
        return, for every group, the index of the sample chosen as the final result
        """
        raise NotImplementedError
188
+
189
+
190
class MaximumLikelihoodRanker(SequenceRanker):
    """
    Pick the sample whose log probability is highest after applying a penalty:
    either plain length normalization or the length penalty of the Google NMT paper
    """

    def __init__(self, length_penalty: Optional[float]):
        self.length_penalty = length_penalty

    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
        """Return, per group, the index of the sample with the best penalized score."""

        def penalty_for(length):
            if self.length_penalty is None:
                # simple length normalization
                return length
            # from the Google NMT paper
            return ((5 + length) / 6) ** self.length_penalty

        winners = []
        for group_logprobs, group_tokens in zip(sum_logprobs, tokens):
            penalized = [
                logprob / penalty_for(len(sample))
                for logprob, sample in zip(group_logprobs, group_tokens)
            ]
            # index of the highest-scoring sequence in this group
            winners.append(np.argmax(penalized))
        return winners
214
+
215
+
216
class TokenDecoder:
    """Interface for strategies that pick the next token during decoding."""

    def reset(self):
        """Set up any stateful variables before a new sequence is decoded"""
        return None

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """Choose the next token from the current trace and logits

        Parameters
        ----------
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            every token in the context so far, prefix and sot_sequence tokens included

        logits : Tensor, shape = (n_batch, vocab_size)
            logits of the probability distribution at the current step, per token

        sum_logprobs : Tensor, shape = (n_batch)
            cumulative log probability of each sequence

        Returns
        -------
        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
            the input tokens with the selected next token appended

        completed : bool
            True once every sequence has reached the end of text

        """
        raise NotImplementedError

    def finalize(
        self, tokens: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
        """Finish the search and hand back the final candidate sequences

        Parameters
        ----------
        tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
            every token in the context so far, prefix and sot_sequence included

        sum_logprobs : Tensor, shape = (n_audio, n_group)
            cumulative log probability of each sequence

        Returns
        -------
        tokens : Sequence[Sequence[Tensor]], length = n_audio
            for each audio input, a sequence of Tensors holding candidate token sequences

        sum_logprobs : List[List[float]], length = n_audio
            the cumulative log probabilities matching the candidates above

        """
        raise NotImplementedError
270
+
271
+
272
class GreedyDecoder(TokenDecoder):
    """Token decoder that takes the argmax at temperature 0 and samples otherwise."""

    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature
        self.eot = eot

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """Append one token per sequence and accumulate its log probability in place."""
        next_tokens = (
            logits.argmax(dim=-1)
            if self.temperature == 0
            else Categorical(logits=logits / self.temperature).sample()
        )

        # log probability of the chosen token in every row
        logprobs = F.log_softmax(logits.float(), dim=-1)
        row_index = torch.arange(logprobs.shape[0])
        current_logprobs = logprobs[row_index, next_tokens]

        # rows whose last token is already EOT contribute nothing further;
        # sum_logprobs is modified in place
        previous_tokens = tokens[:, -1]
        sum_logprobs += current_logprobs * (previous_tokens != self.eot)

        # rows that already ended keep emitting EOT
        next_tokens[previous_tokens == self.eot] = self.eot
        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

        completed = (tokens[:, -1] == self.eot).all()
        return tokens, completed

    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
        """Pad one EOT onto every sequence and return the log probabilities as lists."""
        # guarantees that each sequence ends with at least one EOT token
        padded = F.pad(tokens, (0, 1), value=self.eot)
        return padded, sum_logprobs.tolist()
299
+
300
+
301
class BeamSearchDecoder(TokenDecoder):
    """Token decoder that keeps `beam_size` candidate sequences per audio input."""

    def __init__(
        self,
        beam_size: int,
        eot: int,
        inference: Inference,
        patience: Optional[float] = None,
    ):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        # a falsy patience (None or 0) falls back to 1.0
        self.patience = patience or 1.0
        self.max_candidates: int = round(beam_size * self.patience)
        self.finished_sequences = None

        assert (
            self.max_candidates > 0
        ), f"Invalid beam size ({beam_size}) or patience ({patience})"

    def reset(self):
        """Forget the finished sequences collected so far."""
        self.finished_sequences = None

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """Advance every beam by one token; `sum_logprobs` is overwritten in place."""
        beam_size = self.beam_size
        if tokens.shape[0] % beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []
        for audio_index in range(n_audio):
            scores, sources, finished = {}, {}, {}

            # STEP 1: cumulative log probabilities of the possible candidates
            first_row = audio_index * beam_size
            for idx in range(first_row, first_row + beam_size):
                prefix = tokens[idx].tolist()
                top_logprobs, top_tokens = logprobs[idx].topk(beam_size + 1)
                for logprob, token in zip(top_logprobs, top_tokens):
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = (sum_logprobs[idx] + logprob).item()
                    sources[sequence] = idx

            # STEP 2: rank the candidates; keep the top beam_size unfinished
            # sequences for this audio and set aside those ending in EOT
            saved = 0
            ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
            for sequence, score in ranked:
                if sequence[-1] == self.eot:
                    finished[sequence] = score
                    continue

                # the slot being filled is the next free row of the new batch
                sum_logprobs[len(next_tokens)] = score
                next_tokens.append(sequence)
                source_indices.append(sources[sequence])

                saved += 1
                if saved == beam_size:
                    break

            finished_sequences.append(finished)

        tokens = torch.tensor(next_tokens, device=tokens.device)
        self.inference.rearrange_kv_cache(source_indices)

        # merge the newly finished sequences into self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(
            self.finished_sequences, finished_sequences
        ):
            best_first = sorted(
                newly_finished, key=newly_finished.get, reverse=True
            )
            for seq in best_first:
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # completed once every audio has collected enough samples
        completed = all(
            len(sequences) >= self.max_candidates
            for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
        """Return the finished sequences per audio, topped up with unfinished ones if needed."""
        # gather all finished sequences (patience included); unfinished ones are
        # added when fewer than beam_size have finished
        sum_logprobs = sum_logprobs.cpu()
        for i, sequences in enumerate(self.finished_sequences):
            if len(sequences) >= self.beam_size:
                continue  # enough sequences are finished for this audio

            # walk the unfinished beams in reverse argsort order
            for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                candidate = preceding_tokens[i, j].tolist() + [self.eot]
                sequences[tuple(candidate)] = sum_logprobs[i][j].item()
                if len(sequences) >= self.beam_size:
                    break

        final_tokens: List[List[Tensor]] = []
        final_logprobs: List[List[float]] = []
        for sequences in self.finished_sequences:
            final_tokens.append([torch.tensor(seq) for seq in sequences.keys()])
            final_logprobs.append(list(sequences.values()))
        return final_tokens, final_logprobs
405
+
406
+
407
class LogitFilter:
    """Interface for in-place filters applied to the logits at each step."""

    def apply(self, logits: Tensor, tokens: Tensor) -> None:
        """Filter or mask the logits, modifying them in-place

        Parameters
        ----------
        logits : Tensor, shape = (n_batch, vocab_size)
            logits of the probability distribution at the current step, per token

        tokens : Tensor, shape = (n_batch, current_sequence_length)
            every token in the context so far, prefix and sot_sequence tokens included

        """
        raise NotImplementedError
421
+
422
+
423
class SuppressBlank(LogitFilter):
    """Masks the encoded " " tokens and EOT when the context length equals `sample_begin`."""

    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits: Tensor, tokens: Tensor):
        """Set the blocked columns of `logits` to -inf in place when the length check matches."""
        if tokens.shape[1] != self.sample_begin:
            return

        blocked = self.tokenizer.encode(" ") + [self.tokenizer.eot]
        logits[:, blocked] = -np.inf
431
+
432
+
433
class SuppressTokens(LogitFilter):
    """Masks a fixed set of token ids at every step."""

    def __init__(self, suppress_tokens: Sequence[int]):
        # copied into a list, which is later used as a column index
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits: Tensor, tokens: Tensor):
        """Set the suppressed columns of `logits` to -inf in place; `tokens` is not used."""
        suppressed_columns = self.suppress_tokens
        logits[:, suppressed_columns] = -np.inf
439
+
440
+
441
class ApplyTimestampRules(LogitFilter):
    """Masks logits in place so that sampled timestamp tokens follow the pairing rules."""

    def __init__(
        self,
        tokenizer: Tokenizer,
        sample_begin: int,
        max_initial_timestamp_index: Optional[int],
    ):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits: Tensor, tokens: Tensor):
        """Apply the timestamp constraints to `logits` in place, row by row."""
        tokenizer = self.tokenizer
        timestamp_begin = tokenizer.timestamp_begin
        n_rows = tokens.shape[0]

        # <|notimestamps|> is handled by without_timestamps, so suppress it here
        if tokenizer.no_timestamps is not None:
            logits[:, tokenizer.no_timestamps] = -np.inf

        # timestamps must come in pairs, except directly before EOT
        for k in range(n_rows):
            sampled_tokens = tokens[k, self.sample_begin :]
            seq = sampled_tokens.tolist()
            last_was_timestamp = bool(seq) and seq[-1] >= timestamp_begin
            # fewer than two sampled tokens counts as "penultimate was a timestamp"
            penultimate_was_timestamp = (
                len(seq) < 2 or seq[-2] >= timestamp_begin
            )

            if last_was_timestamp and penultimate_was_timestamp:
                # has to be non-timestamp
                logits[k, timestamp_begin:] = -np.inf
            elif last_was_timestamp:
                # cannot be normal text tokens
                logits[k, : tokenizer.eot] = -np.inf

            timestamps = sampled_tokens[sampled_tokens.ge(timestamp_begin)]
            if timestamps.numel() == 0:
                continue

            # timestamps shouldn't decrease: forbid timestamp tokens smaller than the
            # last one; also force every segment to have a nonzero length, which
            # prevents infinite looping
            opening_timestamp = last_was_timestamp and not penultimate_was_timestamp
            timestamp_last = (
                timestamps[-1] if opening_timestamp else timestamps[-1] + 1
            )
            logits[k, timestamp_begin:timestamp_last] = -np.inf

        if tokens.shape[1] == self.sample_begin:
            # no non-timestamp tokens may be generated at the beginning
            logits[:, :timestamp_begin] = -np.inf

            # apply the `max_initial_timestamp` option
            if self.max_initial_timestamp_index is not None:
                last_allowed = timestamp_begin + self.max_initial_timestamp_index
                logits[:, last_allowed + 1 :] = -np.inf

        # when the summed probability of all timestamps beats every single text
        # token, force a timestamp to be sampled
        logprobs = F.log_softmax(logits.float(), dim=-1)
        for k in range(n_rows):
            row = logprobs[k]
            timestamp_logprob = row[timestamp_begin:].logsumexp(dim=-1)
            max_text_token_logprob = row[:timestamp_begin].max()
            if timestamp_logprob > max_text_token_logprob:
                logits[k, :timestamp_begin] = -np.inf
506
+
507
+
508
class DecodingTask:
    """Decode a batch of audio segments with a Whisper model under fixed options.

    Construction wires together everything ``run`` drives: the tokenizer, the
    initial token sequence, the ``inference`` wrapper, the sequence ranker, the
    token decoder (beam search or greedy) and the list of logit filters.
    """

    inference: Inference
    sequence_ranker: SequenceRanker
    decoder: TokenDecoder
    logit_filters: List[LogitFilter]

    def __init__(self, model: "Whisper", options: DecodingOptions):
        """Build every decoding component from ``model`` and ``options``.

        Raises
        ------
        ValueError
            if ``options`` fails one of the checks in ``_verify_options``
        """
        self.model = model

        # default to "en" when no language was requested
        language = options.language or "en"
        tokenizer = get_tokenizer(
            model.is_multilingual,
            num_languages=model.num_languages,
            language=language,
            task=options.task,
        )
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)

        self.n_group: int = options.beam_size or options.best_of or 1
        self.n_ctx: int = model.dims.n_text_ctx
        self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2

        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
        if self.options.without_timestamps:
            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
        self.sample_begin: int = len(self.initial_tokens)
        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

        # inference: implements the forward pass through the decoder, including kv caching
        self.inference = PyTorchInference(model, len(self.initial_tokens))

        # sequence ranker: implements how to rank a group of sampled sequences
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

        # decoder: implements how to select the next tokens, given the autoregressive distribution
        if options.beam_size is not None:
            self.decoder = BeamSearchDecoder(
                options.beam_size, tokenizer.eot, self.inference, options.patience
            )
        else:
            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

        # logit filters: applies various rules to suppress or penalize certain tokens
        self.logit_filters = []
        if self.options.suppress_blank:
            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens:
            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
        if not options.without_timestamps:
            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
            max_initial_timestamp_index = None
            if options.max_initial_timestamp:
                max_initial_timestamp_index = round(
                    self.options.max_initial_timestamp / precision
                )
            self.logit_filters.append(
                ApplyTimestampRules(
                    tokenizer, self.sample_begin, max_initial_timestamp_index
                )
            )

    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
        """Reject contradictory option combinations and return ``options`` unchanged.

        Raises
        ------
        ValueError
            if ``beam_size`` and ``best_of`` are both given, if ``best_of`` is
            given with ``temperature == 0``, if ``patience`` is given without
            ``beam_size``, or if ``length_penalty`` lies outside ``[0, 1]``
        """
        if options.beam_size is not None and options.best_of is not None:
            raise ValueError("beam_size and best_of can't be given together")
        if options.temperature == 0:
            if options.best_of is not None:
                raise ValueError("best_of with greedy sampling (T=0) is not compatible")
        if options.patience is not None and options.beam_size is None:
            raise ValueError("patience requires beam_size to be given")
        if options.length_penalty is not None and not (
            0 <= options.length_penalty <= 1
        ):
            raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

        return options

    def _get_initial_tokens(self) -> Tuple[int]:
        """Return the token sequence every decoded row starts with.

        The result is ``sot_sequence`` followed by the optional prefix; when a
        prompt is given, ``sot_prev`` and the (tail of the) prompt are placed in
        front of it. String prefixes/prompts are stripped, given one leading
        space and encoded with the tokenizer; non-strings are used as token
        lists directly.
        """
        tokens = list(self.sot_sequence)

        if prefix := self.options.prefix:
            prefix_tokens = (
                self.tokenizer.encode(" " + prefix.strip())
                if isinstance(prefix, str)
                else prefix
            )
            if self.sample_len is not None:
                max_prefix_len = self.n_ctx // 2 - self.sample_len
                # NOTE(review): when max_prefix_len is 0 (the case for the default
                # sample_len of n_ctx // 2) the slice [-0:] keeps the whole prefix
                # rather than truncating it -- confirm this is intended.
                prefix_tokens = prefix_tokens[-max_prefix_len:]
            tokens = tokens + prefix_tokens

        if prompt := self.options.prompt:
            prompt_tokens = (
                self.tokenizer.encode(" " + prompt.strip())
                if isinstance(prompt, str)
                else prompt
            )
            # keep at most the last (n_ctx // 2 - 1) prompt tokens
            tokens = (
                [self.tokenizer.sot_prev]
                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                + tokens
            )

        return tuple(tokens)

    def _get_suppress_tokens(self) -> Tuple[int]:
        """Return the sorted, de-duplicated token ids to suppress.

        ``options.suppress_tokens`` may be ``None``, a comma-separated string of
        ints, or any iterable of ints. A ``-1`` entry is replaced by the
        tokenizer's ``non_speech_tokens``. The task/SOT special tokens (and
        ``no_speech`` when the tokenizer defines it) are always added.

        The caller's container is never modified: the ids are copied into a
        fresh list before anything is appended.
        """
        requested = self.options.suppress_tokens

        if requested is None:
            suppress_tokens = []
        elif isinstance(requested, str):
            # blank entries are skipped, so an empty string yields an empty list
            suppress_tokens = [int(t) for t in requested.split(",") if t.strip()]
        else:
            # copy so the extend() calls below cannot leak into the caller's options
            suppress_tokens = list(requested)

        if -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)

        suppress_tokens.extend(
            [
                self.tokenizer.transcribe,
                self.tokenizer.translate,
                self.tokenizer.sot,
                self.tokenizer.sot_prev,
                self.tokenizer.sot_lm,
            ]
        )
        if self.tokenizer.no_speech is not None:
            # no-speech probability is collected separately
            suppress_tokens.append(self.tokenizer.no_speech)

        return tuple(sorted(set(suppress_tokens)))

    def _get_audio_features(self, mel: Tensor):
        """Return encoder features for ``mel``, running the encoder only if needed.

        If the last two dimensions of ``mel`` already equal
        ``(n_audio_ctx, n_audio_state)`` it is treated as pre-encoded features
        and returned as is (after the optional fp16 cast).

        Raises
        ------
        TypeError
            if the resulting features are not float16 (``options.fp16`` set) or
            float32 (``options.fp16`` unset)
        """
        if self.options.fp16:
            mel = mel.half()

        if mel.shape[-2:] == (
            self.model.dims.n_audio_ctx,
            self.model.dims.n_audio_state,
        ):
            # encoded audio features are given; skip audio encoding
            audio_features = mel
        else:
            audio_features = self.model.encoder(mel)

        if audio_features.dtype != (
            torch.float16 if self.options.fp16 else torch.float32
        ):
            # raise (not return) so the mismatch is reported here instead of
            # the exception object being used as if it were the features
            raise TypeError(
                f"audio_features has an incorrect dtype: {audio_features.dtype}"
            )

        return audio_features

    def _detect_language(self, audio_features: Tensor, tokens: Tensor):
        """Return ``(languages, lang_probs)`` with one language per audio row.

        Detection runs only when no language was given or the task is
        ``"lang_id"``; otherwise the configured language is repeated and
        ``lang_probs`` is ``None``. When no language was given, the detected
        language tokens are written into ``tokens`` in place, right after the
        SOT position.
        """
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(
                audio_features, self.tokenizer
            )
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            if self.options.language is None:
                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens

        return languages, lang_probs

    def _main_loop(self, audio_features: Tensor, tokens: Tensor):
        """Sample up to ``sample_len`` tokens and return the decoding state.

        Returns
        -------
        (tokens, sum_logprobs, no_speech_probs)
            ``no_speech_probs`` holds one value per row of ``tokens``; the values
            stay NaN when the tokenizer has no ``no_speech`` token.
        """
        n_batch = tokens.shape[0]
        sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch

        try:
            for i in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                if (
                    i == 0 and self.tokenizer.no_speech is not None
                ):  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]

                # apply the logit filters, e.g. for suppressing or applying penalty to
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            # always release the inference caches, even if sampling raised
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs

    @torch.no_grad()
    def run(self, mel: Tensor) -> List[DecodingResult]:
        """Decode every segment in ``mel`` and return one ``DecodingResult`` each.

        For the ``"lang_id"`` task only language detection is performed and the
        results carry ``audio_features``, ``language`` and ``language_probs``.

        Raises
        ------
        TypeError
            if the audio features have an unexpected dtype
        RuntimeError
            if the per-segment result lists end up with different lengths
        """
        self.decoder.reset()
        tokenizer: Tokenizer = self.tokenizer
        n_audio: int = mel.shape[0]

        audio_features: Tensor = self._get_audio_features(mel)  # encoder forward pass
        tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)

        # detect language if requested, overwriting the language token
        languages, language_probs = self._detect_language(audio_features, tokens)
        if self.options.task == "lang_id":
            return [
                DecodingResult(
                    audio_features=features, language=language, language_probs=probs
                )
                for features, language, probs in zip(
                    audio_features, languages, language_probs
                )
            ]

        # repeat text tensors by the group size, for beam search or best-of-n sampling
        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

        # call the main sampling loop
        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

        # reshape the tensors to have (n_audio, n_group) as the first two dimensions
        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]
        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

        tokens = tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

        # get the final candidates for each group, and slice between the first sampled token and EOT
        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        tokens: List[List[Tensor]] = [
            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
            for s in tokens
        ]

        # select the top-ranked sample in each group
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
        # the "+ 1" keeps the division defined for an empty token list
        avg_logprobs: List[float] = [
            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
        ]

        fields = (
            texts,
            languages,
            tokens,
            audio_features,
            avg_logprobs,
            no_speech_probs,
        )
        if len(set(map(len, fields))) != 1:
            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

        return [
            DecodingResult(
                audio_features=features,
                language=language,
                tokens=tokens,
                text=text,
                avg_logprob=avg_logprob,
                no_speech_prob=no_speech_prob,
                temperature=self.options.temperature,
                compression_ratio=compression_ratio(text),
            )
            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
                *fields
            )
        ]
790
+
791
+
792
@torch.no_grad()
def decode(
    model: "Whisper",
    mel: Tensor,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Decode one or more 30-second audio segments given as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance used for decoding

    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
        the Mel spectrogram(s); a 2-D tensor is treated as a single segment

    options: DecodingOptions
        the dataclass holding every option for decoding 30-second segments;
        keyword arguments, if any, override the matching fields on a copy

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        a single `DecodingResult` for a 2-D input, otherwise a list with one
        `DecodingResult` per segment
    """
    is_single = mel.ndim == 2
    if is_single:
        # add the batch dimension that DecodingTask.run indexes
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    task = DecodingTask(model, options)
    results = task.run(mel)

    if is_single:
        return results[0]
    return results