xinference 0.14.4.post1__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (149)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +5 -39
  4. xinference/client/restful/restful_client.py +3 -24
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/model.py +82 -31
  11. xinference/core/scheduler.py +37 -37
  12. xinference/core/status_guard.py +1 -1
  13. xinference/core/supervisor.py +11 -10
  14. xinference/core/utils.py +80 -22
  15. xinference/core/worker.py +17 -16
  16. xinference/deploy/cmdline.py +8 -16
  17. xinference/deploy/local.py +1 -1
  18. xinference/deploy/supervisor.py +1 -1
  19. xinference/deploy/utils.py +1 -1
  20. xinference/deploy/worker.py +1 -1
  21. xinference/model/audio/cosyvoice.py +86 -41
  22. xinference/model/embedding/core.py +52 -31
  23. xinference/model/image/stable_diffusion/core.py +18 -1
  24. xinference/model/llm/__init__.py +21 -11
  25. xinference/model/llm/llama_cpp/core.py +16 -33
  26. xinference/model/llm/llm_family.json +619 -1297
  27. xinference/model/llm/llm_family.py +31 -52
  28. xinference/model/llm/llm_family_csghub.json +18 -35
  29. xinference/model/llm/llm_family_modelscope.json +573 -1119
  30. xinference/model/llm/lmdeploy/core.py +56 -88
  31. xinference/model/llm/mlx/core.py +46 -69
  32. xinference/model/llm/sglang/core.py +33 -18
  33. xinference/model/llm/transformers/chatglm.py +167 -305
  34. xinference/model/llm/transformers/cogvlm2.py +36 -63
  35. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  36. xinference/model/llm/transformers/core.py +49 -50
  37. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  38. xinference/model/llm/transformers/glm4v.py +55 -111
  39. xinference/model/llm/transformers/intern_vl.py +39 -70
  40. xinference/model/llm/transformers/internlm2.py +32 -54
  41. xinference/model/llm/transformers/minicpmv25.py +22 -55
  42. xinference/model/llm/transformers/minicpmv26.py +158 -68
  43. xinference/model/llm/transformers/omnilmm.py +5 -28
  44. xinference/model/llm/transformers/qwen2_vl.py +208 -0
  45. xinference/model/llm/transformers/qwen_vl.py +34 -86
  46. xinference/model/llm/transformers/utils.py +32 -38
  47. xinference/model/llm/transformers/yi_vl.py +32 -72
  48. xinference/model/llm/utils.py +195 -489
  49. xinference/model/llm/vllm/core.py +153 -100
  50. xinference/model/rerank/core.py +41 -8
  51. xinference/model/rerank/model_spec.json +7 -0
  52. xinference/model/rerank/model_spec_modelscope.json +7 -1
  53. xinference/model/utils.py +1 -31
  54. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  55. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  56. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  57. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  58. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  59. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  60. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  61. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  62. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  63. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  64. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  65. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  66. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  67. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  69. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  70. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  71. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  72. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  73. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  74. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  75. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  76. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  77. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  78. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  79. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  80. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  81. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +34 -0
  82. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  83. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  84. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  85. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  88. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  89. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  90. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  91. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  92. xinference/thirdparty/matcha/VERSION +1 -0
  93. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  94. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  95. xinference/thirdparty/omnilmm/LICENSE +201 -0
  96. xinference/thirdparty/whisper/__init__.py +156 -0
  97. xinference/thirdparty/whisper/__main__.py +3 -0
  98. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  99. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  100. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  101. xinference/thirdparty/whisper/audio.py +157 -0
  102. xinference/thirdparty/whisper/decoding.py +826 -0
  103. xinference/thirdparty/whisper/model.py +314 -0
  104. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  105. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  106. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  107. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  108. xinference/thirdparty/whisper/timing.py +386 -0
  109. xinference/thirdparty/whisper/tokenizer.py +395 -0
  110. xinference/thirdparty/whisper/transcribe.py +605 -0
  111. xinference/thirdparty/whisper/triton_ops.py +109 -0
  112. xinference/thirdparty/whisper/utils.py +316 -0
  113. xinference/thirdparty/whisper/version.py +1 -0
  114. xinference/types.py +7 -49
  115. xinference/web/ui/build/asset-manifest.json +6 -6
  116. xinference/web/ui/build/index.html +1 -1
  117. xinference/web/ui/build/static/css/{main.4bafd904.css → main.632e9148.css} +2 -2
  118. xinference/web/ui/build/static/css/main.632e9148.css.map +1 -0
  119. xinference/web/ui/build/static/js/main.9cfafbd6.js +3 -0
  120. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.9cfafbd6.js.LICENSE.txt} +2 -0
  121. xinference/web/ui/build/static/js/main.9cfafbd6.js.map +1 -0
  122. xinference/web/ui/node_modules/.cache/babel-loader/01d6d198156bacbd436c51435edbd4b2cacd47a79db929105eba30f74b67d48d.json +1 -0
  123. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  124. xinference/web/ui/node_modules/.cache/babel-loader/59eb25f514afcc4fefd1b309d192b2455f1e0aec68a9de598ca4b2333fe2c774.json +1 -0
  125. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  126. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  127. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  128. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  129. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  130. xinference/web/ui/node_modules/.package-lock.json +37 -0
  131. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  132. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  133. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  134. xinference/web/ui/package-lock.json +38 -0
  135. xinference/web/ui/package.json +1 -0
  136. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/METADATA +8 -8
  137. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/RECORD +141 -87
  138. xinference/model/llm/transformers/llama_2.py +0 -108
  139. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  140. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  141. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  142. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  144. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  145. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  146. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/LICENSE +0 -0
  147. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/WHEEL +0 -0
  148. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/entry_points.txt +0 -0
  149. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,314 @@
+ import base64
+ import gzip
+ from dataclasses import dataclass
+ from typing import Dict, Iterable, Optional
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+ from .decoding import decode as decode_function
+ from .decoding import detect_language as detect_language_function
+ from .transcribe import transcribe as transcribe_function
+
+
+ @dataclass
+ class ModelDimensions:
+     n_mels: int
+     n_audio_ctx: int
+     n_audio_state: int
+     n_audio_head: int
+     n_audio_layer: int
+     n_vocab: int
+     n_text_ctx: int
+     n_text_state: int
+     n_text_head: int
+     n_text_layer: int
+
+
+ class LayerNorm(nn.LayerNorm):
+     def forward(self, x: Tensor) -> Tensor:
+         return super().forward(x.float()).type(x.dtype)
+
+
+ class Linear(nn.Linear):
+     def forward(self, x: Tensor) -> Tensor:
+         return F.linear(
+             x,
+             self.weight.to(x.dtype),
+             None if self.bias is None else self.bias.to(x.dtype),
+         )
+
+
+ class Conv1d(nn.Conv1d):
+     def _conv_forward(
+         self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
+     ) -> Tensor:
+         return super()._conv_forward(
+             x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
+         )
+
+
+ def sinusoids(length, channels, max_timescale=10000):
+     """Returns sinusoids for positional embedding"""
+     assert channels % 2 == 0
+     log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+     inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+     scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+     return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, n_state: int, n_head: int):
+         super().__init__()
+         self.n_head = n_head
+         self.query = Linear(n_state, n_state)
+         self.key = Linear(n_state, n_state, bias=False)
+         self.value = Linear(n_state, n_state)
+         self.out = Linear(n_state, n_state)
+
+     def forward(
+         self,
+         x: Tensor,
+         xa: Optional[Tensor] = None,
+         mask: Optional[Tensor] = None,
+         kv_cache: Optional[dict] = None,
+     ):
+         q = self.query(x)
+
+         if kv_cache is None or xa is None or self.key not in kv_cache:
+             # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+             # otherwise, perform key/value projections for self- or cross-attention as usual.
+             k = self.key(x if xa is None else xa)
+             v = self.value(x if xa is None else xa)
+         else:
+             # for cross-attention, calculate keys and values once and reuse in subsequent calls.
+             k = kv_cache[self.key]
+             v = kv_cache[self.value]
+
+         wv, qk = self.qkv_attention(q, k, v, mask)
+         return self.out(wv), qk
+
+     def qkv_attention(
+         self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
+     ):
+         n_batch, n_ctx, n_state = q.shape
+         scale = (n_state // self.n_head) ** -0.25
+         q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
+         k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+         v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+
+         qk = q @ k
+         if mask is not None:
+             qk = qk + mask[:n_ctx, :n_ctx]
+         qk = qk.float()
+
+         w = F.softmax(qk, dim=-1).to(q.dtype)
+         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
+         super().__init__()
+
+         self.attn = MultiHeadAttention(n_state, n_head)
+         self.attn_ln = LayerNorm(n_state)
+
+         self.cross_attn = (
+             MultiHeadAttention(n_state, n_head) if cross_attention else None
+         )
+         self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+
+         n_mlp = n_state * 4
+         self.mlp = nn.Sequential(
+             Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
+         )
+         self.mlp_ln = LayerNorm(n_state)
+
+     def forward(
+         self,
+         x: Tensor,
+         xa: Optional[Tensor] = None,
+         mask: Optional[Tensor] = None,
+         kv_cache: Optional[dict] = None,
+     ):
+         x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
+         if self.cross_attn:
+             x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
+         x = x + self.mlp(self.mlp_ln(x))
+         return x
+
+
+ class AudioEncoder(nn.Module):
+     def __init__(
+         self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+     ):
+         super().__init__()
+         self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
+         self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
+         self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
+
+         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+             [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
+         )
+         self.ln_post = LayerNorm(n_state)
+
+     def forward(self, x: Tensor):
+         """
+         x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+             the mel spectrogram of the audio
+         """
+         x = F.gelu(self.conv1(x))
+         x = F.gelu(self.conv2(x))
+         x = x.permute(0, 2, 1)
+
+         assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
+         x = (x + self.positional_embedding).to(x.dtype)
+
+         for block in self.blocks:
+             x = block(x)
+
+         x = self.ln_post(x)
+         return x
+
+
+ class TextDecoder(nn.Module):
+     def __init__(
+         self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+     ):
+         super().__init__()
+
+         self.token_embedding = nn.Embedding(n_vocab, n_state)
+         self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
+
+         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+             [
+                 ResidualAttentionBlock(n_state, n_head, cross_attention=True)
+                 for _ in range(n_layer)
+             ]
+         )
+         self.ln = LayerNorm(n_state)
+
+         mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
+         self.register_buffer("mask", mask, persistent=False)
+
+     def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+         """
+         x : torch.LongTensor, shape = (batch_size, <= n_ctx)
+             the text tokens
+         xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
+             the encoded audio features to be attended on
+         """
+         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+         x = (
+             self.token_embedding(x)
+             + self.positional_embedding[offset : offset + x.shape[-1]]
+         )
+         x = x.to(xa.dtype)
+
+         for block in self.blocks:
+             x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+
+         x = self.ln(x)
+         logits = (
+             x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+         ).float()
+
+         return logits
+
+
+ class Whisper(nn.Module):
+     def __init__(self, dims: ModelDimensions):
+         super().__init__()
+         self.dims = dims
+         self.encoder = AudioEncoder(
+             self.dims.n_mels,
+             self.dims.n_audio_ctx,
+             self.dims.n_audio_state,
+             self.dims.n_audio_head,
+             self.dims.n_audio_layer,
+         )
+         self.decoder = TextDecoder(
+             self.dims.n_vocab,
+             self.dims.n_text_ctx,
+             self.dims.n_text_state,
+             self.dims.n_text_head,
+             self.dims.n_text_layer,
+         )
+         # use the last half among the decoder layers for time alignment by default;
+         # to use a specific set of heads, see `set_alignment_heads()` below.
+         all_heads = torch.zeros(
+             self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
+         )
+         all_heads[self.dims.n_text_layer // 2 :] = True
+         self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+
+     def set_alignment_heads(self, dump: bytes):
+         array = np.frombuffer(
+             gzip.decompress(base64.b85decode(dump)), dtype=bool
+         ).copy()
+         mask = torch.from_numpy(array).reshape(
+             self.dims.n_text_layer, self.dims.n_text_head
+         )
+         self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
+
+     def embed_audio(self, mel: torch.Tensor):
+         return self.encoder(mel)
+
+     def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
+         return self.decoder(tokens, audio_features)
+
+     def forward(
+         self, mel: torch.Tensor, tokens: torch.Tensor
+     ) -> Dict[str, torch.Tensor]:
+         return self.decoder(tokens, self.encoder(mel))
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     @property
+     def is_multilingual(self):
+         return self.dims.n_vocab >= 51865
+
+     @property
+     def num_languages(self):
+         return self.dims.n_vocab - 51765 - int(self.is_multilingual)
+
+     def install_kv_cache_hooks(self, cache: Optional[dict] = None):
+         """
+         The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
+         tensors calculated for the previous positions. This method returns a dictionary that stores
+         all caches, and the necessary hooks for the key and value projection modules that save the
+         intermediate tensors to be reused during later calculations.
+
+         Returns
+         -------
+         cache : Dict[nn.Module, torch.Tensor]
+             A dictionary object mapping the key/value projection modules to its cache
+         hooks : List[RemovableHandle]
+             List of PyTorch RemovableHandle objects to stop the hooks to be called
+         """
+         cache = {**cache} if cache is not None else {}
+         hooks = []
+
+         def save_to_cache(module, _, output):
+             if module not in cache or output.shape[1] > self.dims.n_text_ctx:
+                 # save as-is, for the first token or cross attention
+                 cache[module] = output
+             else:
+                 cache[module] = torch.cat([cache[module], output], dim=1).detach()
+             return cache[module]
+
+         def install_hooks(layer: nn.Module):
+             if isinstance(layer, MultiHeadAttention):
+                 hooks.append(layer.key.register_forward_hook(save_to_cache))
+                 hooks.append(layer.value.register_forward_hook(save_to_cache))
+
+         self.decoder.apply(install_hooks)
+         return cache, hooks
+
+     detect_language = detect_language_function
+     transcribe = transcribe_function
+     decode = decode_function
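
For context on how the vendored module above fits together, here is a minimal sketch (not part of the diff) that instantiates Whisper and runs one decoder step with the kv-cache hooks. The import path assumes the xinference/thirdparty/whisper/model.py location added in this release; the dimension values roughly match the smallest Whisper checkpoint and the token id is illustrative only.

# Illustrative sketch only -- not part of the 0.15.0 diff.
# Weights are randomly initialized here, so the outputs are meaningless.
import torch

from xinference.thirdparty.whisper.model import ModelDimensions, Whisper

# Dimension values roughly matching the smallest Whisper checkpoint (illustrative).
dims = ModelDimensions(
    n_mels=80, n_audio_ctx=1500, n_audio_state=384, n_audio_head=6, n_audio_layer=4,
    n_vocab=51865, n_text_ctx=448, n_text_state=384, n_text_head=6, n_text_layer=4,
)
model = Whisper(dims).eval()

# 30 s of log-mel frames: conv2 has stride 2, so 3000 frames reduce to n_audio_ctx=1500.
mel = torch.randn(1, dims.n_mels, 2 * dims.n_audio_ctx)
tokens = torch.tensor([[50258]])  # a single start-of-transcript token id, illustrative only

with torch.no_grad():
    # Plain forward pass: encode the audio, then score the token sequence.
    logits = model(mel, tokens)  # shape (1, 1, n_vocab)

    # Incremental decoding using the hooks described in install_kv_cache_hooks().
    cache, hooks = model.install_kv_cache_hooks()
    audio_features = model.embed_audio(mel)
    step_logits = model.decoder(tokens, audio_features, kv_cache=cache)
    for hook in hooks:
        hook.remove()

print(logits.shape, step_logits.shape)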
@@ -0,0 +1,2 @@
+ from .basic import BasicTextNormalizer as BasicTextNormalizer
+ from .english import EnglishTextNormalizer as EnglishTextNormalizer
@@ -0,0 +1,76 @@
+ import re
+ import unicodedata
+
+ import regex
+
+ # non-ASCII letters that are not separated by "NFKD" normalization
+ ADDITIONAL_DIACRITICS = {
+     "œ": "oe",
+     "Œ": "OE",
+     "ø": "o",
+     "Ø": "O",
+     "æ": "ae",
+     "Æ": "AE",
+     "ß": "ss",
+     "ẞ": "SS",
+     "đ": "d",
+     "Đ": "D",
+     "ð": "d",
+     "Ð": "D",
+     "þ": "th",
+     "Þ": "th",
+     "ł": "l",
+     "Ł": "L",
+ }
+
+
+ def remove_symbols_and_diacritics(s: str, keep=""):
+     """
+     Replace any other markers, symbols, and punctuations with a space,
+     and drop any diacritics (category 'Mn' and some manual mappings)
+     """
+     return "".join(
+         c
+         if c in keep
+         else ADDITIONAL_DIACRITICS[c]
+         if c in ADDITIONAL_DIACRITICS
+         else ""
+         if unicodedata.category(c) == "Mn"
+         else " "
+         if unicodedata.category(c)[0] in "MSP"
+         else c
+         for c in unicodedata.normalize("NFKD", s)
+     )
+
+
+ def remove_symbols(s: str):
+     """
+     Replace any other markers, symbols, punctuations with a space, keeping diacritics
+     """
+     return "".join(
+         " " if unicodedata.category(c)[0] in "MSP" else c
+         for c in unicodedata.normalize("NFKC", s)
+     )
+
+
+ class BasicTextNormalizer:
+     def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
+         self.clean = (
+             remove_symbols_and_diacritics if remove_diacritics else remove_symbols
+         )
+         self.split_letters = split_letters
+
+     def __call__(self, s: str):
+         s = s.lower()
+         s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
+         s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parenthesis
+         s = self.clean(s).lower()
+
+         if self.split_letters:
+             s = " ".join(regex.findall(r"\X", s, regex.U))
+
+         s = re.sub(
+             r"\s+", " ", s
+         )  # replace any successive whitespace characters with a space
+
+         return s
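
As a quick illustration of what the normalizer above does, here is a hypothetical usage sketch (not part of the diff), assuming the vendored package path xinference/thirdparty/whisper/normalizers added in this release:

# Illustrative sketch only -- not part of the 0.15.0 diff.
from xinference.thirdparty.whisper.normalizers import BasicTextNormalizer

normalizer = BasicTextNormalizer(remove_diacritics=True)
# Bracketed annotations, punctuation, and diacritics are stripped and whitespace
# is collapsed, yielding roughly: "jalapeno it s great"
print(normalizer("Jalapeño [laughter] it's GREAT!"))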