xinference 0.14.4.post1-py3-none-any.whl → 0.15.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference has been flagged as potentially problematic by the registry diff service.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +209 -40
- xinference/client/restful/restful_client.py +7 -26
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/image_interface.py +28 -0
- xinference/core/model.py +110 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +17 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/audio/fish_speech.py +9 -9
- xinference/model/audio/model_spec.json +9 -9
- xinference/model/audio/whisper.py +4 -1
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/core.py +2 -1
- xinference/model/image/model_spec.json +16 -4
- xinference/model/image/model_spec_modelscope.json +16 -4
- xinference/model/image/sdapi.py +136 -0
- xinference/model/image/stable_diffusion/core.py +164 -19
- xinference/model/llm/__init__.py +29 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +1011 -1296
- xinference/model/llm/llm_family.py +34 -53
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +981 -1122
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +36 -18
- xinference/model/llm/transformers/chatglm.py +168 -306
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +55 -50
- xinference/model/llm/transformers/deepseek_v2.py +340 -0
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_audio.py +168 -0
- xinference/model/llm/transformers/qwen2_vl.py +234 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +280 -554
- xinference/model/llm/vllm/core.py +161 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
- xinference/thirdparty/fish_speech/tools/api.py +79 -134
- xinference/thirdparty/fish_speech/tools/commons.py +35 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
- xinference/thirdparty/fish_speech/tools/file.py +17 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
- xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
- xinference/thirdparty/fish_speech/tools/webui.py +12 -146
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +14 -53
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
- xinference/web/ui/build/static/js/main.754740c0.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
- xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/whisper/timing.py (new file)

@@ -0,0 +1,386 @@
+import itertools
+import subprocess
+import warnings
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List
+
+import numba
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
+from .tokenizer import Tokenizer
+
+if TYPE_CHECKING:
+    from .model import Whisper
+
+
+def median_filter(x: torch.Tensor, filter_width: int):
+    """Apply a median filter of width `filter_width` along the last dimension of `x`"""
+    pad_width = filter_width // 2
+    if x.shape[-1] <= pad_width:
+        # F.pad requires the padding width to be smaller than the input dimension
+        return x
+
+    if (ndim := x.ndim) <= 2:
+        # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
+        x = x[None, None, :]
+
+    assert (
+        filter_width > 0 and filter_width % 2 == 1
+    ), "`filter_width` should be an odd number"
+
+    result = None
+    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
+    if x.is_cuda:
+        try:
+            from .triton_ops import median_filter_cuda
+
+            result = median_filter_cuda(x, filter_width)
+        except (RuntimeError, subprocess.CalledProcessError):
+            warnings.warn(
+                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
+                "falling back to a slower median kernel implementation..."
+            )
+
+    if result is None:
+        # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
+        result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
+
+    if ndim <= 2:
+        result = result[0, 0]
+
+    return result
+
+
+@numba.jit(nopython=True)
+def backtrace(trace: np.ndarray):
+    i = trace.shape[0] - 1
+    j = trace.shape[1] - 1
+    trace[0, :] = 2
+    trace[:, 0] = 1
+
+    result = []
+    while i > 0 or j > 0:
+        result.append((i - 1, j - 1))
+
+        if trace[i, j] == 0:
+            i -= 1
+            j -= 1
+        elif trace[i, j] == 1:
+            i -= 1
+        elif trace[i, j] == 2:
+            j -= 1
+        else:
+            raise ValueError("Unexpected trace[i, j]")
+
+    result = np.array(result)
+    return result[::-1, :].T
+
+
+@numba.jit(nopython=True, parallel=True)
+def dtw_cpu(x: np.ndarray):
+    N, M = x.shape
+    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
+    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
+
+    cost[0, 0] = 0
+    for j in range(1, M + 1):
+        for i in range(1, N + 1):
+            c0 = cost[i - 1, j - 1]
+            c1 = cost[i - 1, j]
+            c2 = cost[i, j - 1]
+
+            if c0 < c1 and c0 < c2:
+                c, t = c0, 0
+            elif c1 < c0 and c1 < c2:
+                c, t = c1, 1
+            else:
+                c, t = c2, 2
+
+            cost[i, j] = x[i - 1, j - 1] + c
+            trace[i, j] = t
+
+    return backtrace(trace)
+
+
+def dtw_cuda(x, BLOCK_SIZE=1024):
+    from .triton_ops import dtw_kernel
+
+    M, N = x.shape
+    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
+
+    x_skew = (
+        F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
+    )
+    x_skew = x_skew.T.contiguous()
+    cost = torch.ones(N + M + 2, M + 2) * np.inf
+    cost[0, 0] = 0
+    cost = cost.cuda()
+    trace = torch.zeros_like(cost, dtype=torch.int32)
+
+    dtw_kernel[(1,)](
+        cost,
+        trace,
+        x_skew,
+        x_skew.stride(0),
+        cost.stride(0),
+        trace.stride(0),
+        N,
+        M,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+
+    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
+        :, : N + 1
+    ]
+    return backtrace(trace.cpu().numpy())
+
+
+def dtw(x: torch.Tensor) -> np.ndarray:
+    if x.is_cuda:
+        try:
+            return dtw_cuda(x)
+        except (RuntimeError, subprocess.CalledProcessError):
+            warnings.warn(
+                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
+                "falling back to a slower DTW implementation..."
+            )
+
+    return dtw_cpu(x.double().cpu().numpy())
+
+
+@dataclass
+class WordTiming:
+    word: str
+    tokens: List[int]
+    start: float
+    end: float
+    probability: float
+
+
+def find_alignment(
+    model: "Whisper",
+    tokenizer: Tokenizer,
+    text_tokens: List[int],
+    mel: torch.Tensor,
+    num_frames: int,
+    *,
+    medfilt_width: int = 7,
+    qk_scale: float = 1.0,
+) -> List[WordTiming]:
+    if len(text_tokens) == 0:
+        return []
+
+    tokens = torch.tensor(
+        [
+            *tokenizer.sot_sequence,
+            tokenizer.no_timestamps,
+            *text_tokens,
+            tokenizer.eot,
+        ]
+    ).to(model.device)
+
+    # install hooks on the cross attention layers to retrieve the attention weights
+    QKs = [None] * model.dims.n_text_layer
+    hooks = [
+        block.cross_attn.register_forward_hook(
+            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
+        )
+        for i, block in enumerate(model.decoder.blocks)
+    ]
+
+    with torch.no_grad():
+        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
+        sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
+        token_probs = sampled_logits.softmax(dim=-1)
+        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
+        text_token_probs = text_token_probs.tolist()
+
+    for hook in hooks:
+        hook.remove()
+
+    # heads * tokens * frames
+    weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
+    weights = weights[:, :, : num_frames // 2]
+    weights = (weights * qk_scale).softmax(dim=-1)
+    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
+    weights = (weights - mean) / std
+    weights = median_filter(weights, medfilt_width)
+
+    matrix = weights.mean(axis=0)
+    matrix = matrix[len(tokenizer.sot_sequence) : -1]
+    text_indices, time_indices = dtw(-matrix)
+
+    words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
+    if len(word_tokens) <= 1:
+        # return on eot only
+        # >>> np.pad([], (1, 0))
+        # array([0.])
+        # This results in crashes when we lookup jump_times with float, like
+        # IndexError: arrays used as indices must be of integer (or boolean) type
+        return []
+    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
+
+    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
+    jump_times = time_indices[jumps] / TOKENS_PER_SECOND
+    start_times = jump_times[word_boundaries[:-1]]
+    end_times = jump_times[word_boundaries[1:]]
+    word_probabilities = [
+        np.mean(text_token_probs[i:j])
+        for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
+    ]
+
+    return [
+        WordTiming(word, tokens, start, end, probability)
+        for word, tokens, start, end, probability in zip(
+            words, word_tokens, start_times, end_times, word_probabilities
+        )
+    ]
+
+
+def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
+    # merge prepended punctuations
+    i = len(alignment) - 2
+    j = len(alignment) - 1
+    while i >= 0:
+        previous = alignment[i]
+        following = alignment[j]
+        if previous.word.startswith(" ") and previous.word.strip() in prepended:
+            # prepend it to the following word
+            following.word = previous.word + following.word
+            following.tokens = previous.tokens + following.tokens
+            previous.word = ""
+            previous.tokens = []
+        else:
+            j = i
+        i -= 1
+
+    # merge appended punctuations
+    i = 0
+    j = 1
+    while j < len(alignment):
+        previous = alignment[i]
+        following = alignment[j]
+        if not previous.word.endswith(" ") and following.word in appended:
+            # append it to the previous word
+            previous.word = previous.word + following.word
+            previous.tokens = previous.tokens + following.tokens
+            following.word = ""
+            following.tokens = []
+        else:
+            i = j
+        j += 1
+
+
+def add_word_timestamps(
+    *,
+    segments: List[dict],
+    model: "Whisper",
+    tokenizer: Tokenizer,
+    mel: torch.Tensor,
+    num_frames: int,
+    prepend_punctuations: str = "\"'“¿([{-",
+    append_punctuations: str = "\"'.。,,!!??::”)]}、",
+    last_speech_timestamp: float,
+    **kwargs,
+):
+    if len(segments) == 0:
+        return
+
+    text_tokens_per_segment = [
+        [token for token in segment["tokens"] if token < tokenizer.eot]
+        for segment in segments
+    ]
+
+    text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
+    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
+    word_durations = np.array([t.end - t.start for t in alignment])
+    word_durations = word_durations[word_durations.nonzero()]
+    median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
+    median_duration = min(0.7, float(median_duration))
+    max_duration = median_duration * 2
+
+    # hack: truncate long words at sentence boundaries.
+    # a better segmentation algorithm based on VAD should be able to replace this.
+    if len(word_durations) > 0:
+        sentence_end_marks = ".。!!??"
+        # ensure words at sentence boundaries are not longer than twice the median word duration.
+        for i in range(1, len(alignment)):
+            if alignment[i].end - alignment[i].start > max_duration:
+                if alignment[i].word in sentence_end_marks:
+                    alignment[i].end = alignment[i].start + max_duration
+                elif alignment[i - 1].word in sentence_end_marks:
+                    alignment[i].start = alignment[i].end - max_duration
+
+    merge_punctuations(alignment, prepend_punctuations, append_punctuations)
+
+    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
+    word_index = 0
+
+    for segment, text_tokens in zip(segments, text_tokens_per_segment):
+        saved_tokens = 0
+        words = []
+
+        while word_index < len(alignment) and saved_tokens < len(text_tokens):
+            timing = alignment[word_index]
+
+            if timing.word:
+                words.append(
+                    dict(
+                        word=timing.word,
+                        start=round(time_offset + timing.start, 2),
+                        end=round(time_offset + timing.end, 2),
+                        probability=timing.probability,
+                    )
+                )
+
+            saved_tokens += len(timing.tokens)
+            word_index += 1
+
+        # hack: truncate long words at segment boundaries.
+        # a better segmentation algorithm based on VAD should be able to replace this.
+        if len(words) > 0:
+            # ensure the first and second word after a pause is not longer than
+            # twice the median word duration.
+            if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
+                words[0]["end"] - words[0]["start"] > max_duration
+                or (
+                    len(words) > 1
+                    and words[1]["end"] - words[0]["start"] > max_duration * 2
+                )
+            ):
+                if (
+                    len(words) > 1
+                    and words[1]["end"] - words[1]["start"] > max_duration
+                ):
+                    boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
+                    words[0]["end"] = words[1]["start"] = boundary
+                words[0]["start"] = max(0, words[0]["end"] - max_duration)
+
+            # prefer the segment-level start timestamp if the first word is too long.
+            if (
+                segment["start"] < words[0]["end"]
+                and segment["start"] - 0.5 > words[0]["start"]
+            ):
+                words[0]["start"] = max(
+                    0, min(words[0]["end"] - median_duration, segment["start"])
+                )
+            else:
+                segment["start"] = words[0]["start"]
+
+            # prefer the segment-level end timestamp if the last word is too long.
+            if (
+                segment["end"] > words[-1]["start"]
+                and segment["end"] + 0.5 < words[-1]["end"]
+            ):
+                words[-1]["end"] = max(
+                    words[-1]["start"] + median_duration, segment["end"]
+                )
+            else:
+                segment["end"] = words[-1]["end"]
+
+            last_speech_timestamp = segment["end"]
+
+        segment["words"] = words
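For readers skimming the diff above: the core of the newly vendored timing.py is the dynamic-time-warping pass (dtw_cpu plus backtrace) that maps token indices onto audio-frame indices. Below is a minimal, dependency-light sketch of that same recurrence in plain NumPy, without the numba JIT or the Triton CUDA path. The name toy_dtw and the 3x5 toy cost matrix are invented for this illustration and are not part of the xinference package.

import numpy as np


def toy_dtw(x: np.ndarray):
    """Return (row_indices, col_indices) of the minimum-cost monotonic path through x."""
    N, M = x.shape
    cost = np.full((N + 1, M + 1), np.inf, dtype=np.float32)
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0, c1, c2 = cost[i - 1, j - 1], cost[i - 1, j], cost[i, j - 1]
            if c0 < c1 and c0 < c2:
                c, t = c0, 0  # diagonal step: advance both token and frame
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1  # advance the token (row) index only
            else:
                c, t = c2, 2  # advance the frame (column) index only
            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    # walk back from the bottom-right corner, as backtrace() does in the diff above
    trace[0, :] = 2
    trace[:, 0] = 1
    i, j = N, M
    path = []
    while i > 0 or j > 0:
        path.append((i - 1, j - 1))
        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1:
            i -= 1
        else:
            j -= 1
    return np.array(path)[::-1].T


# toy cost matrix: 3 text tokens x 5 audio frames; in the real code this is the
# negated, median-filtered cross-attention matrix, so lower values are cheaper
cost = np.array(
    [
        [0.1, 0.9, 0.9, 0.9, 0.9],
        [0.9, 0.2, 0.1, 0.9, 0.9],
        [0.9, 0.9, 0.9, 0.1, 0.2],
    ],
    dtype=np.float32,
)
text_idx, time_idx = toy_dtw(cost)
print(text_idx)  # [0 1 1 2 2] -> which token each path step belongs to
print(time_idx)  # [0 1 2 3 4] -> which frame each path step belongs to

On this toy input the path assigns frame 0 to token 0, frames 1-2 to token 1, and frames 3-4 to token 2. In the vendored timing.py, the jump positions along text_indices are divided by TOKENS_PER_SECOND and then sliced at word boundaries to produce the per-word start and end times returned by find_alignment.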