xinference 0.14.4.post1__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference has been flagged as potentially problematic; consult the registry's advisory page for details.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +5 -39
- xinference/client/restful/restful_client.py +3 -24
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/model.py +82 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +11 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/stable_diffusion/core.py +18 -1
- xinference/model/llm/__init__.py +21 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +619 -1297
- xinference/model/llm/llm_family.py +31 -52
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +573 -1119
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +33 -18
- xinference/model/llm/transformers/chatglm.py +167 -305
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +49 -50
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_vl.py +208 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +195 -489
- xinference/model/llm/vllm/core.py +153 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +34 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +7 -49
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.632e9148.css} +2 -2
- xinference/web/ui/build/static/css/main.632e9148.css.map +1 -0
- xinference/web/ui/build/static/js/main.9cfafbd6.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.9cfafbd6.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.9cfafbd6.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/01d6d198156bacbd436c51435edbd4b2cacd47a79db929105eba30f74b67d48d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/59eb25f514afcc4fefd1b309d192b2455f1e0aec68a9de598ca4b2333fe2c774.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/METADATA +8 -8
- {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/RECORD +141 -87
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import os
|
|
3
|
+
import string
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from functools import cached_property, lru_cache
|
|
6
|
+
from typing import Dict, List, Optional, Tuple
|
|
7
|
+
|
|
8
|
+
import tiktoken
|
|
9
|
+
|
|
10
|
+
LANGUAGES = {
|
|
11
|
+
"en": "english",
|
|
12
|
+
"zh": "chinese",
|
|
13
|
+
"de": "german",
|
|
14
|
+
"es": "spanish",
|
|
15
|
+
"ru": "russian",
|
|
16
|
+
"ko": "korean",
|
|
17
|
+
"fr": "french",
|
|
18
|
+
"ja": "japanese",
|
|
19
|
+
"pt": "portuguese",
|
|
20
|
+
"tr": "turkish",
|
|
21
|
+
"pl": "polish",
|
|
22
|
+
"ca": "catalan",
|
|
23
|
+
"nl": "dutch",
|
|
24
|
+
"ar": "arabic",
|
|
25
|
+
"sv": "swedish",
|
|
26
|
+
"it": "italian",
|
|
27
|
+
"id": "indonesian",
|
|
28
|
+
"hi": "hindi",
|
|
29
|
+
"fi": "finnish",
|
|
30
|
+
"vi": "vietnamese",
|
|
31
|
+
"he": "hebrew",
|
|
32
|
+
"uk": "ukrainian",
|
|
33
|
+
"el": "greek",
|
|
34
|
+
"ms": "malay",
|
|
35
|
+
"cs": "czech",
|
|
36
|
+
"ro": "romanian",
|
|
37
|
+
"da": "danish",
|
|
38
|
+
"hu": "hungarian",
|
|
39
|
+
"ta": "tamil",
|
|
40
|
+
"no": "norwegian",
|
|
41
|
+
"th": "thai",
|
|
42
|
+
"ur": "urdu",
|
|
43
|
+
"hr": "croatian",
|
|
44
|
+
"bg": "bulgarian",
|
|
45
|
+
"lt": "lithuanian",
|
|
46
|
+
"la": "latin",
|
|
47
|
+
"mi": "maori",
|
|
48
|
+
"ml": "malayalam",
|
|
49
|
+
"cy": "welsh",
|
|
50
|
+
"sk": "slovak",
|
|
51
|
+
"te": "telugu",
|
|
52
|
+
"fa": "persian",
|
|
53
|
+
"lv": "latvian",
|
|
54
|
+
"bn": "bengali",
|
|
55
|
+
"sr": "serbian",
|
|
56
|
+
"az": "azerbaijani",
|
|
57
|
+
"sl": "slovenian",
|
|
58
|
+
"kn": "kannada",
|
|
59
|
+
"et": "estonian",
|
|
60
|
+
"mk": "macedonian",
|
|
61
|
+
"br": "breton",
|
|
62
|
+
"eu": "basque",
|
|
63
|
+
"is": "icelandic",
|
|
64
|
+
"hy": "armenian",
|
|
65
|
+
"ne": "nepali",
|
|
66
|
+
"mn": "mongolian",
|
|
67
|
+
"bs": "bosnian",
|
|
68
|
+
"kk": "kazakh",
|
|
69
|
+
"sq": "albanian",
|
|
70
|
+
"sw": "swahili",
|
|
71
|
+
"gl": "galician",
|
|
72
|
+
"mr": "marathi",
|
|
73
|
+
"pa": "punjabi",
|
|
74
|
+
"si": "sinhala",
|
|
75
|
+
"km": "khmer",
|
|
76
|
+
"sn": "shona",
|
|
77
|
+
"yo": "yoruba",
|
|
78
|
+
"so": "somali",
|
|
79
|
+
"af": "afrikaans",
|
|
80
|
+
"oc": "occitan",
|
|
81
|
+
"ka": "georgian",
|
|
82
|
+
"be": "belarusian",
|
|
83
|
+
"tg": "tajik",
|
|
84
|
+
"sd": "sindhi",
|
|
85
|
+
"gu": "gujarati",
|
|
86
|
+
"am": "amharic",
|
|
87
|
+
"yi": "yiddish",
|
|
88
|
+
"lo": "lao",
|
|
89
|
+
"uz": "uzbek",
|
|
90
|
+
"fo": "faroese",
|
|
91
|
+
"ht": "haitian creole",
|
|
92
|
+
"ps": "pashto",
|
|
93
|
+
"tk": "turkmen",
|
|
94
|
+
"nn": "nynorsk",
|
|
95
|
+
"mt": "maltese",
|
|
96
|
+
"sa": "sanskrit",
|
|
97
|
+
"lb": "luxembourgish",
|
|
98
|
+
"my": "myanmar",
|
|
99
|
+
"bo": "tibetan",
|
|
100
|
+
"tl": "tagalog",
|
|
101
|
+
"mg": "malagasy",
|
|
102
|
+
"as": "assamese",
|
|
103
|
+
"tt": "tatar",
|
|
104
|
+
"haw": "hawaiian",
|
|
105
|
+
"ln": "lingala",
|
|
106
|
+
"ha": "hausa",
|
|
107
|
+
"ba": "bashkir",
|
|
108
|
+
"jw": "javanese",
|
|
109
|
+
"su": "sundanese",
|
|
110
|
+
"yue": "cantonese",
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
# language code lookup by name, with a few language aliases
# (the inverse of LANGUAGES, extended with common alternative names that
# normalize to the same code, e.g. "burmese" -> "my", "mandarin" -> "zh")
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@dataclass
class Tokenizer:
    """A thin wrapper around `tiktoken` providing quick access to special tokens"""

    encoding: tiktoken.Encoding  # the underlying BPE encoding
    num_languages: int  # how many entries of LANGUAGES this model supports
    language: Optional[str] = None  # language code, e.g. "en" (None when not configured)
    task: Optional[str] = None  # "transcribe" or "translate" (None when not configured)
    sot_sequence: Tuple[int, ...] = ()  # filled in by __post_init__
    special_tokens: Dict[str, int] = field(default_factory=dict)  # special token string -> id

    def __post_init__(self):
        # Cache every special token's id for O(1) lookup by its surface string.
        for special in self.encoding.special_tokens_set:
            special_token = self.encoding.encode_single_token(special)
            self.special_tokens[special] = special_token

        sot: int = self.special_tokens["<|startoftranscript|>"]
        translate: int = self.special_tokens["<|translate|>"]
        transcribe: int = self.special_tokens["<|transcribe|>"]

        # Language tokens directly follow <|startoftranscript|> in LANGUAGES
        # order, so a language token's id is sot + 1 + its index in LANGUAGES.
        langs = tuple(LANGUAGES.keys())[: self.num_languages]
        sot_sequence = [sot]
        if self.language is not None:
            sot_sequence.append(sot + 1 + langs.index(self.language))
        if self.task is not None:
            task_token: int = transcribe if self.task == "transcribe" else translate
            sot_sequence.append(task_token)

        self.sot_sequence = tuple(sot_sequence)

    def encode(self, text, **kwargs):
        """Encode `text` into token ids (delegates to the tiktoken encoding)."""
        return self.encoding.encode(text, **kwargs)

    def decode(self, token_ids: List[int], **kwargs) -> str:
        """Decode `token_ids` to text, silently dropping timestamp tokens."""
        # Timestamp tokens occupy ids >= timestamp_begin; filter them out.
        token_ids = [t for t in token_ids if t < self.timestamp_begin]
        return self.encoding.decode(token_ids, **kwargs)

    def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
        """
        Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
        This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
        """
        return self.encoding.decode(token_ids, **kwargs)

    @cached_property
    def eot(self) -> int:
        # end-of-transcript token id (tiktoken's eot_token)
        return self.encoding.eot_token

    @cached_property
    def transcribe(self) -> int:
        return self.special_tokens["<|transcribe|>"]

    @cached_property
    def translate(self) -> int:
        return self.special_tokens["<|translate|>"]

    @cached_property
    def sot(self) -> int:
        return self.special_tokens["<|startoftranscript|>"]

    @cached_property
    def sot_lm(self) -> int:
        return self.special_tokens["<|startoflm|>"]

    @cached_property
    def sot_prev(self) -> int:
        return self.special_tokens["<|startofprev|>"]

    @cached_property
    def no_speech(self) -> int:
        return self.special_tokens["<|nospeech|>"]

    @cached_property
    def no_timestamps(self) -> int:
        return self.special_tokens["<|notimestamps|>"]

    @cached_property
    def timestamp_begin(self) -> int:
        # id of the first timestamp token; all timestamp ids are >= this value
        return self.special_tokens["<|0.00|>"]

    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")

        return self.to_language_token(self.language)

    def to_language_token(self, language):
        """Return the special-token id for `<|{language}|>`, or raise KeyError."""
        # NOTE(review): a falsy (0) token id would fall through to the KeyError;
        # special-token ids are appended after the base vocabulary (see
        # get_encoding), so id 0 does not occur in practice.
        if token := self.special_tokens.get(f"<|{language}|>", None):
            return token

        raise KeyError(f"Language {language} not found in tokenizer.")

    @cached_property
    def all_language_tokens(self) -> Tuple[int, ...]:
        # ids of all language tokens, i.e. special tokens whose inner text
        # (after stripping the "<|" / "|>" delimiters) is a LANGUAGES code
        result = []
        for token, token_id in self.special_tokens.items():
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)[: self.num_languages]

    @cached_property
    def all_language_codes(self) -> Tuple[str, ...]:
        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)

    @cached_property
    def sot_sequence_including_notimestamps(self) -> Tuple[int, ...]:
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @cached_property
    def non_speech_tokens(self) -> Tuple[int, ...]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.encoding.encode(symbol),
                self.encoding.encode(" " + symbol),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))

    def split_to_word_tokens(self, tokens: List[int]):
        """Split `tokens` into (words, word_tokens), choosing the split strategy by language."""
        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
            # These languages don't typically use spaces, so it is difficult to split words
            # without morpheme analysis. Here, we instead split words at any
            # position where the tokens are decoded as valid unicode points
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(self, tokens: List[int]):
        """Group `tokens` at boundaries where the partial decode is valid unicode."""
        decoded_full = self.decode_with_timestamps(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)

            # A U+FFFD in the partial decode means we are mid-codepoint, so keep
            # accumulating — unless the full decode also has U+FFFD at that same
            # position, in which case the replacement char is genuinely in the text.
            if (
                replacement_char not in decoded
                or decoded_full[unicode_offset + decoded.index(replacement_char)]
                == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens

    def split_tokens_on_spaces(self, tokens: List[int]):
        """Group `tokens` into space-delimited words (built on the unicode split)."""
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            # start a new word on special tokens, leading spaces, bare punctuation,
            # or at the very beginning; otherwise glue the subword onto the last word
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    """Build the `tiktoken.Encoding` for a Whisper vocabulary.

    Loads the base BPE ranks from ``assets/{name}.tiktoken`` (one
    ``<base64-token> <rank>`` pair per line) and registers the special tokens
    — end/start markers, one token per language, task markers, and 1501
    timestamp tokens at 0.02 s increments — immediately after the base
    vocabulary, in order.

    Args:
        name: basename of the vocabulary file under the ``assets`` directory
            ("gpt2" for English-only models, "multilingual" otherwise).
        num_languages: number of language tokens to register.

    Returns:
        A `tiktoken.Encoding` with the special tokens registered.

    Raises:
        FileNotFoundError: if the vocabulary asset does not exist.
    """
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    # Fix: open the vocabulary file in a context manager so the handle is
    # closed deterministically (the original left the file object to the GC).
    with open(vocab_path) as vocab_file:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in vocab_file if line)
        }
    n_vocab = len(ranks)
    special_tokens = {}

    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    # Special tokens take consecutive ids starting right after the base vocab.
    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> Tokenizer:
    """Return a (cached) `Tokenizer` for a Whisper model.

    `language` may be a code ("en") or a full/alias name ("english",
    "mandarin"); names are normalized to codes via TO_LANGUAGE_CODE.
    For multilingual models the "multilingual" vocabulary is used and
    language/task default to "en"/"transcribe"; for English-only models the
    "gpt2" vocabulary is used and language/task are forced to None.

    Raises:
        ValueError: if `language` is not a known code or name.
    """
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            # not a code — try to resolve a full name or alias to its code
            if language not in TO_LANGUAGE_CODE:
                raise ValueError(f"Unsupported language: {language}")
            language = TO_LANGUAGE_CODE[language]

    if not multilingual:
        # English-only models carry neither language nor task tokens
        encoding_name, language, task = "gpt2", None, None
    else:
        encoding_name = "multilingual"
        language = language or "en"
        task = task or "transcribe"

    return Tokenizer(
        encoding=get_encoding(name=encoding_name, num_languages=num_languages),
        num_languages=num_languages,
        language=language,
        task=task,
    )
|