mseep-txtai 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
@@ -0,0 +1,298 @@
|
|
1
|
+
"""
|
2
|
+
Translation module
|
3
|
+
"""
|
4
|
+
|
5
|
+
# Conditional import
|
6
|
+
try:
|
7
|
+
from staticvectors import StaticVectors
|
8
|
+
|
9
|
+
STATICVECTORS = True
|
10
|
+
except ImportError:
|
11
|
+
STATICVECTORS = False
|
12
|
+
|
13
|
+
from huggingface_hub.hf_api import HfApi
|
14
|
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
15
|
+
|
16
|
+
from ...models import Models
|
17
|
+
from ..hfmodel import HFModel
|
18
|
+
|
19
|
+
|
20
|
+
class Translation(HFModel):
    """
    Translates text from source language into target language.
    """

    # Default language detection model
    DEFAULT_LANG_DETECT = "neuml/language-id-quantized"

    def __init__(self, path=None, quantize=False, gpu=True, batch=64, langdetect=None, findmodels=True):
        """
        Constructs a new language translation pipeline.

        Args:
            path: optional path to model, accepts Hugging Face model hub id or local path,
                  uses "facebook/m2m100_418M" if not provided
            quantize: if model should be quantized, defaults to False
            gpu: True/False if GPU should be enabled, also supports a GPU device id
            batch: batch size used to incrementally process content
            langdetect: set a custom language detection function, method must take a list of strings and return
                        language codes for each. A string value is used as the path of the default detector's
                        model. Uses the default language detector if not provided.
            findmodels: True/False if the Hugging Face Hub will be searched for source-target translation models
        """

        # Call parent constructor
        super().__init__(path if path else "facebook/m2m100_418M", quantize, gpu, batch)

        # Language detection - default detector model is lazily loaded on first use
        self.detector = None
        self.langdetect = langdetect
        self.findmodels = findmodels

        # Cache of loaded (model, tokenizer) pairs keyed by model path
        self.models = {}

        # Set of available Hugging Face Hub model ids, lazily loaded on first model lookup
        self.ids = None

    def __call__(self, texts, target="en", source=None, showmodels=False):
        """
        Translates text from source language into target language.

        This method supports texts as a string or a list. If the input is a string,
        the return type is string. If text is a list, the return type is a list.

        Args:
            texts: text|list
            target: target language code, defaults to "en"
            source: source language code, detects language if not provided
            showmodels: if True, each result is a (translated text, source language, model path) tuple

        Returns:
            translated text if texts is a string, otherwise a list of translated text
        """

        values = [texts] if not isinstance(texts, list) else texts

        # Nothing to translate - avoid running (and possibly loading) the language detector
        if not values:
            return []

        # Detect source languages
        languages = self.detect(values) if not source else [source] * len(values)

        # Group (index, text) pairs by source language so each language is translated with one model
        langdict = {}
        for x, lang in enumerate(languages):
            langdict.setdefault(lang, []).append((x, values[x]))

        results = {}
        for language, inputs in langdict.items():
            # Translate text in batches
            outputs = []
            for chunk in self.batch([text for _, text in inputs], self.batchsize):
                outputs.extend(self.translate(chunk, language, target, showmodels))

            # Store output value at the original input position
            for y, (x, _) in enumerate(inputs):
                if showmodels:
                    model, op = outputs[y]
                    results[x] = (op.strip(), language, model)
                else:
                    results[x] = outputs[y].strip()

        # Return results in same order as input
        results = [results[x] for x in sorted(results)]
        return results[0] if isinstance(texts, str) else results

    def modelids(self):
        """
        Runs a query to get a list of available language models from the Hugging Face API.

        Returns:
            set of source-target language model ids, empty when findmodels is disabled
        """

        ids = [x.id for x in HfApi().list_models(author="Helsinki-NLP")] if self.findmodels else []
        return set(ids)

    def detect(self, texts):
        """
        Detects the language for each element in texts.

        Uses the default detector when langdetect is unset or a string (model path), otherwise
        calls the custom langdetect function.

        Args:
            texts: list of text

        Returns:
            list of languages
        """

        # Default detector
        if not self.langdetect or isinstance(self.langdetect, str):
            return self.defaultdetect(texts)

        # Call external language detector
        return self.langdetect(texts)

    def defaultdetect(self, texts):
        """
        Default language detection model.

        Args:
            texts: list of text

        Returns:
            list of languages

        Raises:
            ImportError: if staticvectors is not installed
        """

        if not self.detector:
            if not STATICVECTORS:
                raise ImportError('Language detection is not available - install "pipeline" extra to enable')

            # Get model path
            path = self.langdetect if self.langdetect else Translation.DEFAULT_LANG_DETECT

            # Load language detection model
            self.detector = StaticVectors(path)

        # Transform texts to format expected by language detection model. "\r\n" must be replaced
        # before "\n", otherwise the "\n" replacement leaves a stray "\r" behind.
        texts = [x.lower().replace("\r\n", " ").replace("\n", " ") for x in texts]

        # Detect languages - keep the top prediction label for each text
        return [x[0][0] for x in self.detector.predict(texts)]

    def translate(self, texts, source, target, showmodels=False):
        """
        Translates text from source to target language.

        Args:
            texts: list of text
            source: source language code
            target: target language code
            showmodels: if True, each result is a (model path, translated text) tuple

        Returns:
            list of translated text, or list of (model path, translated text) tuples when showmodels is True
        """

        # Return original if already in target language. No model is used, so the model path is None.
        # Results must still be tuples when showmodels is set since callers unpack them.
        if source == target:
            return [(None, text) for text in texts] if showmodels else texts

        # Load model and tokenizer
        path, model, tokenizer = self.lookup(source, target)

        model.to(self.device)
        maxlength = Models.maxlength(model, tokenizer)

        with self.context():
            if hasattr(tokenizer, "lang_code_to_id"):
                # Tokenizer exposes language ids: map codes by prefix match, set the source language
                # and force the target language as the first generated token
                source = self.langid(tokenizer.lang_code_to_id, source)
                target = self.langid(tokenizer.lang_code_to_id, target)

                tokenizer.src_lang = source
                tokens, indices = self.tokenize(tokenizer, texts)

                translated = model.generate(**tokens, forced_bos_token_id=tokenizer.lang_code_to_id[target], max_length=maxlength)
            else:
                tokens, indices = self.tokenize(tokenizer, texts)
                translated = model.generate(**tokens, max_length=maxlength)

        # Decode translations
        translated = tokenizer.batch_decode(translated, skip_special_tokens=True)

        # Combine translations - handle splits on large text from tokenizer. indices maps each
        # tokenized element back to its input position; consecutive equal positions are merged.
        results, last = [], -1
        for x, i in enumerate(indices):
            if i == last:
                results[-1] += translated[x]
            else:
                results.append(translated[x])

            last = i

        # Attach the model path after merging so split elements concatenate text, not tuples
        return [(path, text) for text in results] if showmodels else results

    def lookup(self, source, target):
        """
        Retrieves a translation model for source->target language. This method caches each model loaded.

        Args:
            source: source language code
            target: target language code

        Returns:
            (model path, model, tokenizer)
        """

        # Determine best translation model to use, load if necessary and return
        path = self.modelpath(source, target)
        if path not in self.models:
            self.models[path] = self.load(path)

        return (path,) + self.models[path]

    def modelpath(self, source, target):
        """
        Derives a translation model path given source and target languages.

        Tries a direct Helsinki-NLP opus-mt source-target model, then the multi-language to English
        model when target is English and findmodels is enabled, otherwise the pipeline model path.

        Args:
            source: source language code
            target: target language code

        Returns:
            model path
        """

        # Lazy load model ids
        if self.ids is None:
            self.ids = self.modelids()

        # First try direct model
        template = "Helsinki-NLP/opus-mt-%s-%s"
        path = template % (source, target)
        if path in self.ids:
            return path

        # Use multi-language - english model
        if self.findmodels and target == "en":
            return template % ("mul", target)

        # Default model if no suitable model found
        return self.path

    def load(self, path):
        """
        Loads a model specified by path.

        Args:
            path: model path

        Returns:
            (model, tokenizer)
        """

        model = AutoModelForSeq2SeqLM.from_pretrained(path)
        tokenizer = AutoTokenizer.from_pretrained(path)

        # Apply model initialization routines
        model = self.prepare(model)

        return (model, tokenizer)

    def langid(self, languages, target):
        """
        Searches language codes for a prefix match on target.

        Args:
            languages: iterable of language codes (dict keys are iterated)
            target: target language code

        Returns:
            first language code starting with target or None if no match found
        """

        for lang in languages:
            if lang.startswith(target):
                return lang

        return None
|
@@ -0,0 +1,196 @@
|
|
1
|
+
"""
|
2
|
+
Hugging Face Transformers ONNX export module
|
3
|
+
"""
|
4
|
+
|
5
|
+
from collections import OrderedDict
|
6
|
+
from io import BytesIO
|
7
|
+
from itertools import chain
|
8
|
+
from tempfile import NamedTemporaryFile
|
9
|
+
|
10
|
+
# Conditional import
|
11
|
+
try:
|
12
|
+
from onnxruntime.quantization import quantize_dynamic
|
13
|
+
|
14
|
+
ONNX_RUNTIME = True
|
15
|
+
except ImportError:
|
16
|
+
ONNX_RUNTIME = False
|
17
|
+
|
18
|
+
from torch import nn
|
19
|
+
from torch.onnx import export
|
20
|
+
|
21
|
+
from transformers import AutoModel, AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoTokenizer
|
22
|
+
|
23
|
+
from ...models import PoolingFactory
|
24
|
+
from ..tensors import Tensors
|
25
|
+
|
26
|
+
|
27
|
+
class HFOnnx(Tensors):
    """
    Exports a Hugging Face Transformer model to ONNX.
    """

    def __call__(self, path, task="default", output=None, quantize=False, opset=14):
        """
        Exports a Hugging Face Transformer model to ONNX.

        Args:
            path: path to model, accepts Hugging Face model hub id, local path or (model, tokenizer) tuple
            task: optional model task or category, determines the model type and outputs, defaults to export hidden state
            output: optional output model path, defaults to return byte array if None
            quantize: if model should be quantized (requires onnxruntime to be installed), defaults to False
            opset: onnx opset, defaults to 14

        Returns:
            path to model output or model as bytes depending on output parameter

        Raises:
            ImportError: if quantize is True and onnxruntime is not installed
            KeyError: if task is not a supported task
        """

        # Fail fast before running a potentially long export that can't be quantized
        if quantize and not ONNX_RUNTIME:
            raise ImportError('onnxruntime is not available - install "pipeline" extra to enable')

        inputs, outputs, model = self.parameters(task)

        if isinstance(path, (list, tuple)):
            model, tokenizer = path
            model = model.cpu()
        else:
            model = model(path)
            tokenizer = AutoTokenizer.from_pretrained(path)

        # Generate dummy inputs
        dummy = dict(tokenizer(["test inputs"], return_tensors="pt"))

        # Only declare inputs the tokenizer actually produces (some tokenizers don't return token_type_ids),
        # otherwise input names and dynamic axes reference inputs missing from the exported graph
        inputs = OrderedDict((name, axes) for name, axes in inputs.items() if name in dummy)

        # Default to BytesIO if no output file provided
        output = output if output else BytesIO()

        # Export model to ONNX
        export(
            model,
            (dummy,),
            output,
            opset_version=opset,
            do_constant_folding=True,
            input_names=list(inputs.keys()),
            output_names=list(outputs.keys()),
            dynamic_axes=dict(chain(inputs.items(), outputs.items())),
        )

        # Quantize model
        if quantize:
            output = self.quantization(output)

        if isinstance(output, BytesIO):
            # Reset stream and return bytes
            output.seek(0)
            output = output.read()

        return output

    def quantization(self, output):
        """
        Quantizes an ONNX model.

        Args:
            output: path to ONNX model or BytesIO with model data

        Returns:
            quantized model as file path (quantized in place) or bytes when output is a BytesIO
        """

        # Local import, only needed to clean up the temporary file
        import os

        if not isinstance(output, BytesIO):
            # Quantize model file in place
            quantize_dynamic(output, output, extra_options={"MatMulConstBOnly": False})
            return output

        # quantize_dynamic operates on file paths, write the in-memory model to a temporary file
        with NamedTemporaryFile(suffix=".quant", delete=False) as tmpfile:
            tmpfile.write(output.getbuffer())
            temp = tmpfile.name

        try:
            # Quantize model
            quantize_dynamic(temp, temp, extra_options={"MatMulConstBOnly": False})

            # Read quantized model back to bytes
            with open(temp, "rb") as f:
                return f.read()
        finally:
            # Always remove the temporary file so repeated exports don't accumulate files on disk
            os.remove(temp)

    def parameters(self, task):
        """
        Defines inputs and outputs for an ONNX model.

        Args:
            task: task name used to lookup model configuration

        Returns:
            (inputs, outputs, model function)

        Raises:
            KeyError: if task is not a supported task
        """

        inputs = OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )

        config = {
            "default": (OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}), AutoModel.from_pretrained),
            # Device -1 is passed to the pooling model, presumably CPU - TODO confirm against PoolingFactory
            "pooling": (OrderedDict({"embeddings": {0: "batch", 1: "sequence"}}), lambda x: PoolingOnnx(x, -1)),
            "question-answering": (
                OrderedDict(
                    {
                        "start_logits": {0: "batch", 1: "sequence"},
                        "end_logits": {0: "batch", 1: "sequence"},
                    }
                ),
                AutoModelForQuestionAnswering.from_pretrained,
            ),
            "text-classification": (OrderedDict({"logits": {0: "batch"}}), AutoModelForSequenceClassification.from_pretrained),
        }

        # Aliases
        config["zero-shot-classification"] = config["text-classification"]

        return (inputs,) + config[task]
|
157
|
+
|
158
|
+
|
159
|
+
class PoolingOnnx(nn.Module):
    """
    Wraps a pooling model built by PoolingFactory and exposes explicitly named forward
    arguments, which ONNX export needs in order to name the model inputs.
    """

    def __init__(self, path, device):
        """
        Builds the wrapped pooling model.

        Args:
            path: path to model, accepts Hugging Face model hub id or local path
            device: tensor device id
        """

        super().__init__()

        # Pooling method is selected by PoolingFactory from this configuration
        config = {"path": path, "device": device}
        self.model = PoolingFactory.create(config)

    # pylint: disable=W0221
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        """
        Runs the named inputs through the wrapped pooling model.

        Args:
            input_ids: input ids produced by the tokenizer
            attention_mask: attention mask produced by the tokenizer
            token_type_ids: optional token type ids, forwarded only when provided

        Returns:
            outputs of the wrapped pooling model's forward method
        """

        # Not every model accepts token_type_ids, so it is only passed through when present
        extra = {} if token_type_ids is None else {"token_type_ids": token_type_ids}
        return self.model.forward(input_ids=input_ids, attention_mask=attention_mask, **extra)
|