xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +209 -40
- xinference/client/restful/restful_client.py +7 -26
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/image_interface.py +28 -0
- xinference/core/model.py +110 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +17 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/audio/fish_speech.py +9 -9
- xinference/model/audio/model_spec.json +9 -9
- xinference/model/audio/whisper.py +4 -1
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/core.py +2 -1
- xinference/model/image/model_spec.json +16 -4
- xinference/model/image/model_spec_modelscope.json +16 -4
- xinference/model/image/sdapi.py +136 -0
- xinference/model/image/stable_diffusion/core.py +164 -19
- xinference/model/llm/__init__.py +29 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +1011 -1296
- xinference/model/llm/llm_family.py +34 -53
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +981 -1122
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +36 -18
- xinference/model/llm/transformers/chatglm.py +168 -306
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +55 -50
- xinference/model/llm/transformers/deepseek_v2.py +340 -0
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_audio.py +168 -0
- xinference/model/llm/transformers/qwen2_vl.py +234 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +280 -554
- xinference/model/llm/vllm/core.py +161 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
- xinference/thirdparty/fish_speech/tools/api.py +79 -134
- xinference/thirdparty/fish_speech/tools/commons.py +35 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
- xinference/thirdparty/fish_speech/tools/file.py +17 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
- xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
- xinference/thirdparty/fish_speech/tools/webui.py +12 -146
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +14 -53
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
- xinference/web/ui/build/static/js/main.754740c0.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
- xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0

xinference/thirdparty/whisper/model.py
@@ -0,0 +1,314 @@
+import base64
+import gzip
+from dataclasses import dataclass
+from typing import Dict, Iterable, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from .decoding import decode as decode_function
+from .decoding import detect_language as detect_language_function
+from .transcribe import transcribe as transcribe_function
+
+
+@dataclass
+class ModelDimensions:
+    n_mels: int
+    n_audio_ctx: int
+    n_audio_state: int
+    n_audio_head: int
+    n_audio_layer: int
+    n_vocab: int
+    n_text_ctx: int
+    n_text_state: int
+    n_text_head: int
+    n_text_layer: int
+
+
+class LayerNorm(nn.LayerNorm):
+    def forward(self, x: Tensor) -> Tensor:
+        return super().forward(x.float()).type(x.dtype)
+
+
+class Linear(nn.Linear):
+    def forward(self, x: Tensor) -> Tensor:
+        return F.linear(
+            x,
+            self.weight.to(x.dtype),
+            None if self.bias is None else self.bias.to(x.dtype),
+        )
+
+
+class Conv1d(nn.Conv1d):
+    def _conv_forward(
+        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
+    ) -> Tensor:
+        return super()._conv_forward(
+            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
+        )
+
+
+def sinusoids(length, channels, max_timescale=10000):
+    """Returns sinusoids for positional embedding"""
+    assert channels % 2 == 0
+    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, n_state: int, n_head: int):
+        super().__init__()
+        self.n_head = n_head
+        self.query = Linear(n_state, n_state)
+        self.key = Linear(n_state, n_state, bias=False)
+        self.value = Linear(n_state, n_state)
+        self.out = Linear(n_state, n_state)
+
+    def forward(
+        self,
+        x: Tensor,
+        xa: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
+        kv_cache: Optional[dict] = None,
+    ):
+        q = self.query(x)
+
+        if kv_cache is None or xa is None or self.key not in kv_cache:
+            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+            # otherwise, perform key/value projections for self- or cross-attention as usual.
+            k = self.key(x if xa is None else xa)
+            v = self.value(x if xa is None else xa)
+        else:
+            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
+            k = kv_cache[self.key]
+            v = kv_cache[self.value]
+
+        wv, qk = self.qkv_attention(q, k, v, mask)
+        return self.out(wv), qk
+
+    def qkv_attention(
+        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
+    ):
+        n_batch, n_ctx, n_state = q.shape
+        scale = (n_state // self.n_head) ** -0.25
+        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
+        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+
+        qk = q @ k
+        if mask is not None:
+            qk = qk + mask[:n_ctx, :n_ctx]
+        qk = qk.float()
+
+        w = F.softmax(qk, dim=-1).to(q.dtype)
+        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
+        super().__init__()
+
+        self.attn = MultiHeadAttention(n_state, n_head)
+        self.attn_ln = LayerNorm(n_state)
+
+        self.cross_attn = (
+            MultiHeadAttention(n_state, n_head) if cross_attention else None
+        )
+        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+
+        n_mlp = n_state * 4
+        self.mlp = nn.Sequential(
+            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
+        )
+        self.mlp_ln = LayerNorm(n_state)
+
+    def forward(
+        self,
+        x: Tensor,
+        xa: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
+        kv_cache: Optional[dict] = None,
+    ):
+        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
+        if self.cross_attn:
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
+        x = x + self.mlp(self.mlp_ln(x))
+        return x
+
+
+class AudioEncoder(nn.Module):
+    def __init__(
+        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+    ):
+        super().__init__()
+        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
+        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
+        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
+
+        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
+        )
+        self.ln_post = LayerNorm(n_state)
+
+    def forward(self, x: Tensor):
+        """
+        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+            the mel spectrogram of the audio
+        """
+        x = F.gelu(self.conv1(x))
+        x = F.gelu(self.conv2(x))
+        x = x.permute(0, 2, 1)
+
+        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
+        x = (x + self.positional_embedding).to(x.dtype)
+
+        for block in self.blocks:
+            x = block(x)
+
+        x = self.ln_post(x)
+        return x
+
+
+class TextDecoder(nn.Module):
+    def __init__(
+        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+    ):
+        super().__init__()
+
+        self.token_embedding = nn.Embedding(n_vocab, n_state)
+        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
+
+        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+            [
+                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
+                for _ in range(n_layer)
+            ]
+        )
+        self.ln = LayerNorm(n_state)
+
+        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
+        self.register_buffer("mask", mask, persistent=False)
+
+    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+        """
+        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
+            the text tokens
+        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
+            the encoded audio features to be attended on
+        """
+        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+        x = (
+            self.token_embedding(x)
+            + self.positional_embedding[offset : offset + x.shape[-1]]
+        )
+        x = x.to(xa.dtype)
+
+        for block in self.blocks:
+            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+
+        x = self.ln(x)
+        logits = (
+            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+        ).float()
+
+        return logits
+
+
+class Whisper(nn.Module):
+    def __init__(self, dims: ModelDimensions):
+        super().__init__()
+        self.dims = dims
+        self.encoder = AudioEncoder(
+            self.dims.n_mels,
+            self.dims.n_audio_ctx,
+            self.dims.n_audio_state,
+            self.dims.n_audio_head,
+            self.dims.n_audio_layer,
+        )
+        self.decoder = TextDecoder(
+            self.dims.n_vocab,
+            self.dims.n_text_ctx,
+            self.dims.n_text_state,
+            self.dims.n_text_head,
+            self.dims.n_text_layer,
+        )
+        # use the last half among the decoder layers for time alignment by default;
+        # to use a specific set of heads, see `set_alignment_heads()` below.
+        all_heads = torch.zeros(
+            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
+        )
+        all_heads[self.dims.n_text_layer // 2 :] = True
+        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+
+    def set_alignment_heads(self, dump: bytes):
+        array = np.frombuffer(
+            gzip.decompress(base64.b85decode(dump)), dtype=bool
+        ).copy()
+        mask = torch.from_numpy(array).reshape(
+            self.dims.n_text_layer, self.dims.n_text_head
+        )
+        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
+
+    def embed_audio(self, mel: torch.Tensor):
+        return self.encoder(mel)
+
+    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
+        return self.decoder(tokens, audio_features)
+
+    def forward(
+        self, mel: torch.Tensor, tokens: torch.Tensor
+    ) -> Dict[str, torch.Tensor]:
+        return self.decoder(tokens, self.encoder(mel))
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @property
+    def is_multilingual(self):
+        return self.dims.n_vocab >= 51865
+
+    @property
+    def num_languages(self):
+        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
+
+    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
+        """
+        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
+        tensors calculated for the previous positions. This method returns a dictionary that stores
+        all caches, and the necessary hooks for the key and value projection modules that save the
+        intermediate tensors to be reused during later calculations.
+
+        Returns
+        -------
+        cache : Dict[nn.Module, torch.Tensor]
+            A dictionary object mapping the key/value projection modules to its cache
+        hooks : List[RemovableHandle]
+            List of PyTorch RemovableHandle objects to stop the hooks to be called
+        """
+        cache = {**cache} if cache is not None else {}
+        hooks = []
+
+        def save_to_cache(module, _, output):
+            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
+                # save as-is, for the first token or cross attention
+                cache[module] = output
+            else:
+                cache[module] = torch.cat([cache[module], output], dim=1).detach()
+            return cache[module]
+
+        def install_hooks(layer: nn.Module):
+            if isinstance(layer, MultiHeadAttention):
+                hooks.append(layer.key.register_forward_hook(save_to_cache))
+                hooks.append(layer.value.register_forward_hook(save_to_cache))
+
+        self.decoder.apply(install_hooks)
+        return cache, hooks
+
+    detect_language = detect_language_function
+    transcribe = transcribe_function
+    decode = decode_function
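For orientation, here is a minimal sketch (not part of the diff) of how the classes in the newly vendored model.py fit together. The import path follows the vendored layout listed above, the ModelDimensions values are illustrative placeholders rather than a real checkpoint configuration, and running it assumes the vendored module's own dependencies are importable.

import torch

from xinference.thirdparty.whisper.model import ModelDimensions, Whisper

# Illustrative placeholder dimensions, not a real checkpoint config.
dims = ModelDimensions(
    n_mels=80,
    n_audio_ctx=1500,
    n_audio_state=384,
    n_audio_head=6,
    n_audio_layer=4,
    n_vocab=51865,
    n_text_ctx=448,
    n_text_state=384,
    n_text_head=6,
    n_text_layer=4,
)
model = Whisper(dims).eval()

# The encoder's stride-2 conv halves the time axis, so the mel input needs
# 2 * n_audio_ctx frames for the positional-embedding assert to pass.
mel = torch.randn(1, dims.n_mels, 2 * dims.n_audio_ctx)
tokens = torch.tensor([[50258]])  # any in-vocabulary token id

with torch.no_grad():
    audio_features = model.embed_audio(mel)        # (1, n_audio_ctx, n_audio_state)
    logits = model.logits(tokens, audio_features)  # (1, 1, n_vocab)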
xinference/thirdparty/whisper/normalizers/basic.py
@@ -0,0 +1,76 @@
+import re
+import unicodedata
+
+import regex
+
+# non-ASCII letters that are not separated by "NFKD" normalization
+ADDITIONAL_DIACRITICS = {
+    "œ": "oe",
+    "Œ": "OE",
+    "ø": "o",
+    "Ø": "O",
+    "æ": "ae",
+    "Æ": "AE",
+    "ß": "ss",
+    "ẞ": "SS",
+    "đ": "d",
+    "Đ": "D",
+    "ð": "d",
+    "Ð": "D",
+    "þ": "th",
+    "Þ": "th",
+    "ł": "l",
+    "Ł": "L",
+}
+
+
+def remove_symbols_and_diacritics(s: str, keep=""):
+    """
+    Replace any other markers, symbols, and punctuations with a space,
+    and drop any diacritics (category 'Mn' and some manual mappings)
+    """
+    return "".join(
+        c
+        if c in keep
+        else ADDITIONAL_DIACRITICS[c]
+        if c in ADDITIONAL_DIACRITICS
+        else ""
+        if unicodedata.category(c) == "Mn"
+        else " "
+        if unicodedata.category(c)[0] in "MSP"
+        else c
+        for c in unicodedata.normalize("NFKD", s)
+    )
+
+
+def remove_symbols(s: str):
+    """
+    Replace any other markers, symbols, punctuations with a space, keeping diacritics
+    """
+    return "".join(
+        " " if unicodedata.category(c)[0] in "MSP" else c
+        for c in unicodedata.normalize("NFKC", s)
+    )
+
+
+class BasicTextNormalizer:
+    def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
+        self.clean = (
+            remove_symbols_and_diacritics if remove_diacritics else remove_symbols
+        )
+        self.split_letters = split_letters
+
+    def __call__(self, s: str):
+        s = s.lower()
+        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
+        s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parenthesis
+        s = self.clean(s).lower()
+
+        if self.split_letters:
+            s = " ".join(regex.findall(r"\X", s, regex.U))
+
+        s = re.sub(
+            r"\s+", " ", s
+        )  # replace any successive whitespace characters with a space
+
+        return s
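A short usage sketch (again not part of the diff) for the vendored BasicTextNormalizer above; the sample strings are arbitrary, the import path assumes the vendored layout, and the commented outputs follow directly from the code.

from xinference.thirdparty.whisper.normalizers.basic import BasicTextNormalizer

normalizer = BasicTextNormalizer()
# Bracketed and parenthesized spans are dropped, punctuation/symbols become
# spaces, and whitespace runs collapse to single spaces (nothing is stripped,
# so a trailing space can remain).
print(normalizer("Hello, World! (aside) [noise]"))  # -> "hello world "

# remove_diacritics=True routes through remove_symbols_and_diacritics(), which
# also folds the ADDITIONAL_DIACRITICS table, e.g. "ø" -> "o".
print(BasicTextNormalizer(remove_diacritics=True)("Smørrebrød!"))  # -> "smorrebrod "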