xinference 0.14.4.post1__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (149)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +5 -39
  4. xinference/client/restful/restful_client.py +3 -24
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/model.py +82 -31
  11. xinference/core/scheduler.py +37 -37
  12. xinference/core/status_guard.py +1 -1
  13. xinference/core/supervisor.py +11 -10
  14. xinference/core/utils.py +80 -22
  15. xinference/core/worker.py +17 -16
  16. xinference/deploy/cmdline.py +8 -16
  17. xinference/deploy/local.py +1 -1
  18. xinference/deploy/supervisor.py +1 -1
  19. xinference/deploy/utils.py +1 -1
  20. xinference/deploy/worker.py +1 -1
  21. xinference/model/audio/cosyvoice.py +86 -41
  22. xinference/model/embedding/core.py +52 -31
  23. xinference/model/image/stable_diffusion/core.py +18 -1
  24. xinference/model/llm/__init__.py +21 -11
  25. xinference/model/llm/llama_cpp/core.py +16 -33
  26. xinference/model/llm/llm_family.json +619 -1297
  27. xinference/model/llm/llm_family.py +31 -52
  28. xinference/model/llm/llm_family_csghub.json +18 -35
  29. xinference/model/llm/llm_family_modelscope.json +573 -1119
  30. xinference/model/llm/lmdeploy/core.py +56 -88
  31. xinference/model/llm/mlx/core.py +46 -69
  32. xinference/model/llm/sglang/core.py +33 -18
  33. xinference/model/llm/transformers/chatglm.py +167 -305
  34. xinference/model/llm/transformers/cogvlm2.py +36 -63
  35. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  36. xinference/model/llm/transformers/core.py +49 -50
  37. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  38. xinference/model/llm/transformers/glm4v.py +55 -111
  39. xinference/model/llm/transformers/intern_vl.py +39 -70
  40. xinference/model/llm/transformers/internlm2.py +32 -54
  41. xinference/model/llm/transformers/minicpmv25.py +22 -55
  42. xinference/model/llm/transformers/minicpmv26.py +158 -68
  43. xinference/model/llm/transformers/omnilmm.py +5 -28
  44. xinference/model/llm/transformers/qwen2_vl.py +208 -0
  45. xinference/model/llm/transformers/qwen_vl.py +34 -86
  46. xinference/model/llm/transformers/utils.py +32 -38
  47. xinference/model/llm/transformers/yi_vl.py +32 -72
  48. xinference/model/llm/utils.py +195 -489
  49. xinference/model/llm/vllm/core.py +153 -100
  50. xinference/model/rerank/core.py +41 -8
  51. xinference/model/rerank/model_spec.json +7 -0
  52. xinference/model/rerank/model_spec_modelscope.json +7 -1
  53. xinference/model/utils.py +1 -31
  54. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  55. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  56. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  57. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  58. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  59. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  60. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  61. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  62. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  63. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  64. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  65. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  66. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  67. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  69. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  70. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  71. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  72. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  73. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  74. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  75. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  76. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  77. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  78. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  79. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  80. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  81. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +34 -0
  82. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  83. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  84. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  85. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  88. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  89. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  90. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  91. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  92. xinference/thirdparty/matcha/VERSION +1 -0
  93. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  94. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  95. xinference/thirdparty/omnilmm/LICENSE +201 -0
  96. xinference/thirdparty/whisper/__init__.py +156 -0
  97. xinference/thirdparty/whisper/__main__.py +3 -0
  98. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  99. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  100. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  101. xinference/thirdparty/whisper/audio.py +157 -0
  102. xinference/thirdparty/whisper/decoding.py +826 -0
  103. xinference/thirdparty/whisper/model.py +314 -0
  104. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  105. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  106. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  107. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  108. xinference/thirdparty/whisper/timing.py +386 -0
  109. xinference/thirdparty/whisper/tokenizer.py +395 -0
  110. xinference/thirdparty/whisper/transcribe.py +605 -0
  111. xinference/thirdparty/whisper/triton_ops.py +109 -0
  112. xinference/thirdparty/whisper/utils.py +316 -0
  113. xinference/thirdparty/whisper/version.py +1 -0
  114. xinference/types.py +7 -49
  115. xinference/web/ui/build/asset-manifest.json +6 -6
  116. xinference/web/ui/build/index.html +1 -1
  117. xinference/web/ui/build/static/css/{main.4bafd904.css → main.632e9148.css} +2 -2
  118. xinference/web/ui/build/static/css/main.632e9148.css.map +1 -0
  119. xinference/web/ui/build/static/js/main.9cfafbd6.js +3 -0
  120. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.9cfafbd6.js.LICENSE.txt} +2 -0
  121. xinference/web/ui/build/static/js/main.9cfafbd6.js.map +1 -0
  122. xinference/web/ui/node_modules/.cache/babel-loader/01d6d198156bacbd436c51435edbd4b2cacd47a79db929105eba30f74b67d48d.json +1 -0
  123. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  124. xinference/web/ui/node_modules/.cache/babel-loader/59eb25f514afcc4fefd1b309d192b2455f1e0aec68a9de598ca4b2333fe2c774.json +1 -0
  125. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  126. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  127. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  128. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  129. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  130. xinference/web/ui/node_modules/.package-lock.json +37 -0
  131. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  132. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  133. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  134. xinference/web/ui/package-lock.json +38 -0
  135. xinference/web/ui/package.json +1 -0
  136. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/METADATA +8 -8
  137. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/RECORD +141 -87
  138. xinference/model/llm/transformers/llama_2.py +0 -108
  139. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  140. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  141. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  142. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  144. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  145. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  146. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/LICENSE +0 -0
  147. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/WHEEL +0 -0
  148. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/entry_points.txt +0 -0
  149. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/top_level.txt +0 -0
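
The hunk below is the new file xinference/thirdparty/whisper/timing.py (+386 lines), vendored from OpenAI's whisper package; it implements word-level timestamp alignment from cross-attention weights via dynamic time warping.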
@@ -0,0 +1,386 @@
+ import itertools
+ import subprocess
+ import warnings
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, List
+
+ import numba
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
+ from .tokenizer import Tokenizer
+
+ if TYPE_CHECKING:
+     from .model import Whisper
+
+
+ def median_filter(x: torch.Tensor, filter_width: int):
+     """Apply a median filter of width `filter_width` along the last dimension of `x`"""
+     pad_width = filter_width // 2
+     if x.shape[-1] <= pad_width:
+         # F.pad requires the padding width to be smaller than the input dimension
+         return x
+
+     if (ndim := x.ndim) <= 2:
+         # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
+         x = x[None, None, :]
+
+     assert (
+         filter_width > 0 and filter_width % 2 == 1
+     ), "`filter_width` should be an odd number"
+
+     result = None
+     x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
+     if x.is_cuda:
+         try:
+             from .triton_ops import median_filter_cuda
+
+             result = median_filter_cuda(x, filter_width)
+         except (RuntimeError, subprocess.CalledProcessError):
+             warnings.warn(
+                 "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
+                 "falling back to a slower median kernel implementation..."
+             )
+
+     if result is None:
+         # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
+         result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
+
+     if ndim <= 2:
+         result = result[0, 0]
+
+     return result
+
+
+ @numba.jit(nopython=True)
+ def backtrace(trace: np.ndarray):
+     i = trace.shape[0] - 1
+     j = trace.shape[1] - 1
+     trace[0, :] = 2
+     trace[:, 0] = 1
+
+     result = []
+     while i > 0 or j > 0:
+         result.append((i - 1, j - 1))
+
+         if trace[i, j] == 0:
+             i -= 1
+             j -= 1
+         elif trace[i, j] == 1:
+             i -= 1
+         elif trace[i, j] == 2:
+             j -= 1
+         else:
+             raise ValueError("Unexpected trace[i, j]")
+
+     result = np.array(result)
+     return result[::-1, :].T
+
+
+ @numba.jit(nopython=True, parallel=True)
+ def dtw_cpu(x: np.ndarray):
+     N, M = x.shape
+     cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
+     trace = -np.ones((N + 1, M + 1), dtype=np.float32)
+
+     cost[0, 0] = 0
+     for j in range(1, M + 1):
+         for i in range(1, N + 1):
+             c0 = cost[i - 1, j - 1]
+             c1 = cost[i - 1, j]
+             c2 = cost[i, j - 1]
+
+             if c0 < c1 and c0 < c2:
+                 c, t = c0, 0
+             elif c1 < c0 and c1 < c2:
+                 c, t = c1, 1
+             else:
+                 c, t = c2, 2
+
+             cost[i, j] = x[i - 1, j - 1] + c
+             trace[i, j] = t
+
+     return backtrace(trace)
+
+
+ def dtw_cuda(x, BLOCK_SIZE=1024):
+     from .triton_ops import dtw_kernel
+
+     M, N = x.shape
+     assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
+
+     x_skew = (
+         F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
+     )
+     x_skew = x_skew.T.contiguous()
+     cost = torch.ones(N + M + 2, M + 2) * np.inf
+     cost[0, 0] = 0
+     cost = cost.cuda()
+     trace = torch.zeros_like(cost, dtype=torch.int32)
+
+     dtw_kernel[(1,)](
+         cost,
+         trace,
+         x_skew,
+         x_skew.stride(0),
+         cost.stride(0),
+         trace.stride(0),
+         N,
+         M,
+         BLOCK_SIZE=BLOCK_SIZE,
+     )
+
+     trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
+         :, : N + 1
+     ]
+     return backtrace(trace.cpu().numpy())
+
+
+ def dtw(x: torch.Tensor) -> np.ndarray:
+     if x.is_cuda:
+         try:
+             return dtw_cuda(x)
+         except (RuntimeError, subprocess.CalledProcessError):
+             warnings.warn(
+                 "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
+                 "falling back to a slower DTW implementation..."
+             )
+
+     return dtw_cpu(x.double().cpu().numpy())
+
+
+ @dataclass
+ class WordTiming:
+     word: str
+     tokens: List[int]
+     start: float
+     end: float
+     probability: float
+
+
+ def find_alignment(
+     model: "Whisper",
+     tokenizer: Tokenizer,
+     text_tokens: List[int],
+     mel: torch.Tensor,
+     num_frames: int,
+     *,
+     medfilt_width: int = 7,
+     qk_scale: float = 1.0,
+ ) -> List[WordTiming]:
+     if len(text_tokens) == 0:
+         return []
+
+     tokens = torch.tensor(
+         [
+             *tokenizer.sot_sequence,
+             tokenizer.no_timestamps,
+             *text_tokens,
+             tokenizer.eot,
+         ]
+     ).to(model.device)
+
+     # install hooks on the cross attention layers to retrieve the attention weights
+     QKs = [None] * model.dims.n_text_layer
+     hooks = [
+         block.cross_attn.register_forward_hook(
+             lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
+         )
+         for i, block in enumerate(model.decoder.blocks)
+     ]
+
+     with torch.no_grad():
+         logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
+         sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
+         token_probs = sampled_logits.softmax(dim=-1)
+         text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
+         text_token_probs = text_token_probs.tolist()
+
+     for hook in hooks:
+         hook.remove()
+
+     # heads * tokens * frames
+     weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
+     weights = weights[:, :, : num_frames // 2]
+     weights = (weights * qk_scale).softmax(dim=-1)
+     std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
+     weights = (weights - mean) / std
+     weights = median_filter(weights, medfilt_width)
+
+     matrix = weights.mean(axis=0)
+     matrix = matrix[len(tokenizer.sot_sequence) : -1]
+     text_indices, time_indices = dtw(-matrix)
+
+     words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
+     if len(word_tokens) <= 1:
+         # return on eot only
+         # >>> np.pad([], (1, 0))
+         # array([0.])
+         # This results in crashes when we lookup jump_times with float, like
+         # IndexError: arrays used as indices must be of integer (or boolean) type
+         return []
+     word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
+
+     jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
+     jump_times = time_indices[jumps] / TOKENS_PER_SECOND
+     start_times = jump_times[word_boundaries[:-1]]
+     end_times = jump_times[word_boundaries[1:]]
+     word_probabilities = [
+         np.mean(text_token_probs[i:j])
+         for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
+     ]
+
+     return [
+         WordTiming(word, tokens, start, end, probability)
+         for word, tokens, start, end, probability in zip(
+             words, word_tokens, start_times, end_times, word_probabilities
+         )
+     ]
+
+
+ def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
+     # merge prepended punctuations
+     i = len(alignment) - 2
+     j = len(alignment) - 1
+     while i >= 0:
+         previous = alignment[i]
+         following = alignment[j]
+         if previous.word.startswith(" ") and previous.word.strip() in prepended:
+             # prepend it to the following word
+             following.word = previous.word + following.word
+             following.tokens = previous.tokens + following.tokens
+             previous.word = ""
+             previous.tokens = []
+         else:
+             j = i
+         i -= 1
+
+     # merge appended punctuations
+     i = 0
+     j = 1
+     while j < len(alignment):
+         previous = alignment[i]
+         following = alignment[j]
+         if not previous.word.endswith(" ") and following.word in appended:
+             # append it to the previous word
+             previous.word = previous.word + following.word
+             previous.tokens = previous.tokens + following.tokens
+             following.word = ""
+             following.tokens = []
+         else:
+             i = j
+         j += 1
+
+
+ def add_word_timestamps(
+     *,
+     segments: List[dict],
+     model: "Whisper",
+     tokenizer: Tokenizer,
+     mel: torch.Tensor,
+     num_frames: int,
+     prepend_punctuations: str = "\"'“¿([{-",
+     append_punctuations: str = "\"'.。,,!!??::”)]}、",
+     last_speech_timestamp: float,
+     **kwargs,
+ ):
+     if len(segments) == 0:
+         return
+
+     text_tokens_per_segment = [
+         [token for token in segment["tokens"] if token < tokenizer.eot]
+         for segment in segments
+     ]
+
+     text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
+     alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
+     word_durations = np.array([t.end - t.start for t in alignment])
+     word_durations = word_durations[word_durations.nonzero()]
+     median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
+     median_duration = min(0.7, float(median_duration))
+     max_duration = median_duration * 2
+
+     # hack: truncate long words at sentence boundaries.
+     # a better segmentation algorithm based on VAD should be able to replace this.
+     if len(word_durations) > 0:
+         sentence_end_marks = ".。!!??"
+         # ensure words at sentence boundaries are not longer than twice the median word duration.
+         for i in range(1, len(alignment)):
+             if alignment[i].end - alignment[i].start > max_duration:
+                 if alignment[i].word in sentence_end_marks:
+                     alignment[i].end = alignment[i].start + max_duration
+                 elif alignment[i - 1].word in sentence_end_marks:
+                     alignment[i].start = alignment[i].end - max_duration
+
+     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
+
+     time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
+     word_index = 0
+
+     for segment, text_tokens in zip(segments, text_tokens_per_segment):
+         saved_tokens = 0
+         words = []
+
+         while word_index < len(alignment) and saved_tokens < len(text_tokens):
+             timing = alignment[word_index]
+
+             if timing.word:
+                 words.append(
+                     dict(
+                         word=timing.word,
+                         start=round(time_offset + timing.start, 2),
+                         end=round(time_offset + timing.end, 2),
+                         probability=timing.probability,
+                     )
+                 )
+
+             saved_tokens += len(timing.tokens)
+             word_index += 1
+
+         # hack: truncate long words at segment boundaries.
+         # a better segmentation algorithm based on VAD should be able to replace this.
+         if len(words) > 0:
+             # ensure the first and second word after a pause is not longer than
+             # twice the median word duration.
+             if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
+                 words[0]["end"] - words[0]["start"] > max_duration
+                 or (
+                     len(words) > 1
+                     and words[1]["end"] - words[0]["start"] > max_duration * 2
+                 )
+             ):
+                 if (
+                     len(words) > 1
+                     and words[1]["end"] - words[1]["start"] > max_duration
+                 ):
+                     boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
+                     words[0]["end"] = words[1]["start"] = boundary
+                 words[0]["start"] = max(0, words[0]["end"] - max_duration)
+
+             # prefer the segment-level start timestamp if the first word is too long.
+             if (
+                 segment["start"] < words[0]["end"]
+                 and segment["start"] - 0.5 > words[0]["start"]
+             ):
+                 words[0]["start"] = max(
+                     0, min(words[0]["end"] - median_duration, segment["start"])
+                 )
+             else:
+                 segment["start"] = words[0]["start"]
+
+             # prefer the segment-level end timestamp if the last word is too long.
+             if (
+                 segment["end"] > words[-1]["start"]
+                 and segment["end"] + 0.5 < words[-1]["end"]
+             ):
+                 words[-1]["end"] = max(
+                     words[-1]["start"] + median_duration, segment["end"]
+                 )
+             else:
+                 segment["end"] = words[-1]["end"]
+
+             last_speech_timestamp = segment["end"]
+
+         segment["words"] = words