mseep-txtai 9.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251) hide show
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
@@ -0,0 +1,298 @@
1
+ """
2
+ Translation module
3
+ """
4
+
5
+ # Conditional import
6
+ try:
7
+ from staticvectors import StaticVectors
8
+
9
+ STATICVECTORS = True
10
+ except ImportError:
11
+ STATICVECTORS = False
12
+
13
+ from huggingface_hub.hf_api import HfApi
14
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
15
+
16
+ from ...models import Models
17
+ from ..hfmodel import HFModel
18
+
19
+
20
class Translation(HFModel):
    """
    Translates text from source language into target language.

    Source languages are detected automatically unless provided. A translation model is
    selected per source-target language pair and cached after the first load.
    """

    # Default language detection model
    DEFAULT_LANG_DETECT = "neuml/language-id-quantized"

    def __init__(self, path=None, quantize=False, gpu=True, batch=64, langdetect=None, findmodels=True):
        """
        Constructs a new language translation pipeline.

        Args:
            path: optional path to model, accepts Hugging Face model hub id or local path,
                  uses default model for task if not provided
            quantize: if model should be quantized, defaults to False
            gpu: True/False if GPU should be enabled, also supports a GPU device id
            batch: batch size used to incrementally process content
            langdetect: set a custom language detection function, method must take a list of strings and return
                        language codes for each. A string is treated as the path of the model loaded by the
                        default language detector. Uses default language detector if not provided.
            findmodels: True/False if the Hugging Face Hub will be searched for source-target translation models
        """

        # Call parent constructor
        super().__init__(path if path else "facebook/m2m100_418M", quantize, gpu, batch)

        # Language detection - the default detector is created on first use
        self.detector = None
        self.langdetect = langdetect
        self.findmodels = findmodels

        # Language models - cache of path -> (model, tokenizer) and lazily loaded set of hub model ids
        self.models = {}
        self.ids = None

    def __call__(self, texts, target="en", source=None, showmodels=False):
        """
        Translates text from source language into target language.

        This method supports texts as a string or a list. If the input is a string,
        the return type is string. If text is a list, the return type is a list.

        Args:
            texts: text|list
            target: target language code, defaults to "en"
            source: source language code, detects language if not provided
            showmodels: if True, each result is a (translation, source language, model path) tuple
                        instead of only the translated text. The model path is None when the text
                        was already in the target language and no model was run.

        Returns:
            list of translated text
        """

        values = [texts] if not isinstance(texts, list) else texts

        # Nothing to translate - skip loading the language detector and translation models
        if not values:
            return []

        # Detect source languages
        languages = self.detect(values) if not source else [source] * len(values)

        # Build a dict from language to list of (index, text)
        langdict = {}
        for x, lang in enumerate(languages):
            langdict.setdefault(lang, []).append((x, values[x]))

        results = {}
        for language, inputs in langdict.items():
            # Translate text in batches
            outputs = []
            for chunk in self.batch([text for _, text in inputs], self.batchsize):
                outputs.extend(self.translate(chunk, language, target, showmodels))

            # Store output value keyed by original input index
            for y, (x, _) in enumerate(inputs):
                if showmodels:
                    model, op = outputs[y]
                    results[x] = (op.strip(), language, model)
                else:
                    results[x] = outputs[y].strip()

        # Return results in same order as input
        results = [results[x] for x in sorted(results)]
        return results[0] if isinstance(texts, str) else results

    def modelids(self):
        """
        Runs a query to get a list of available language models from the Hugging Face API.
        No query is run when findmodels is disabled.

        Returns:
            set of source-target language model ids
        """

        ids = [x.id for x in HfApi().list_models(author="Helsinki-NLP")] if self.findmodels else []
        return set(ids)

    def detect(self, texts):
        """
        Detects the language for each element in texts.

        Args:
            texts: list of text

        Returns:
            list of languages
        """

        # Default detector - also used when langdetect is a model path string
        if not self.langdetect or isinstance(self.langdetect, str):
            return self.defaultdetect(texts)

        # Call external language detector
        return self.langdetect(texts)

    def defaultdetect(self, texts):
        """
        Default language detection model.

        Args:
            texts: list of text

        Returns:
            list of languages

        Raises:
            ImportError: if the staticvectors package is not installed
        """

        if not self.detector:
            if not STATICVECTORS:
                raise ImportError('Language detection is not available - install "pipeline" extra to enable')

            # Get model path
            path = self.langdetect if self.langdetect else Translation.DEFAULT_LANG_DETECT

            # Load language detection model
            self.detector = StaticVectors(path)

        # Transform texts to format expected by language detection model.
        # "\r\n" must be replaced before "\n", otherwise the "\n" replacement consumes the
        # line feed first and a stray "\r" is left in the text.
        texts = [x.lower().replace("\r\n", " ").replace("\n", " ") for x in texts]

        # Detect languages - take the label of the first prediction for each text
        return [x[0][0] for x in self.detector.predict(texts)]

    def translate(self, texts, source, target, showmodels=False):
        """
        Translates text from source to target language.

        Args:
            texts: list of text
            source: source language code
            target: target language code
            showmodels: if True, returns (model path, translated text) tuples instead of text

        Returns:
            list of translated text, or list of (model path, translated text) when showmodels is True
        """

        # Return original if already in target language. The return shape must still match
        # showmodels since the caller unpacks (model, text) pairs - no model was used here.
        if source == target:
            return [(None, text) for text in texts] if showmodels else texts

        # Load model and tokenizer
        path, model, tokenizer = self.lookup(source, target)

        model.to(self.device)
        indices = None
        maxlength = Models.maxlength(model, tokenizer)

        with self.context():
            if hasattr(tokenizer, "lang_code_to_id"):
                # Multilingual tokenizer - map language codes to the codes known by the tokenizer
                source = self.langid(tokenizer.lang_code_to_id, source)
                target = self.langid(tokenizer.lang_code_to_id, target)

                tokenizer.src_lang = source
                tokens, indices = self.tokenize(tokenizer, texts)

                translated = model.generate(**tokens, forced_bos_token_id=tokenizer.lang_code_to_id[target], max_length=maxlength)
            else:
                tokens, indices = self.tokenize(tokenizer, texts)
                translated = model.generate(**tokens, max_length=maxlength)

        # Decode translations
        translated = tokenizer.batch_decode(translated, skip_special_tokens=True)

        # Combine translations - handle splits on large text from tokenizer.
        # Text is merged before being paired with the model path, concatenating
        # (path, text) tuples directly would produce malformed 4-element tuples.
        merged, last = [], -1
        for x, i in enumerate(indices):
            if i == last:
                merged[-1] += translated[x]
            else:
                merged.append(translated[x])

            last = i

        return [(path, text) for text in merged] if showmodels else merged

    def lookup(self, source, target):
        """
        Retrieves a translation model for source->target language. This method caches each model loaded.

        Args:
            source: source language code
            target: target language code

        Returns:
            (path, model, tokenizer)
        """

        # Determine best translation model to use, load if necessary and return
        path = self.modelpath(source, target)
        if path not in self.models:
            self.models[path] = self.load(path)

        return (path,) + self.models[path]

    def modelpath(self, source, target):
        """
        Derives a translation model path given source and target languages.

        Args:
            source: source language code
            target: target language code

        Returns:
            model path
        """

        # Lazy load model ids
        if self.ids is None:
            self.ids = self.modelids()

        # First try direct model
        template = "Helsinki-NLP/opus-mt-%s-%s"
        path = template % (source, target)
        if path in self.ids:
            return path

        # Use multi-language - english model
        if self.findmodels and target == "en":
            return template % ("mul", target)

        # Default model if no suitable model found
        return self.path

    def load(self, path):
        """
        Loads a model specified by path.

        Args:
            path: model path

        Returns:
            (model, tokenizer)
        """

        model = AutoModelForSeq2SeqLM.from_pretrained(path)
        tokenizer = AutoTokenizer.from_pretrained(path)

        # Apply model initialization routines
        model = self.prepare(model)

        return (model, tokenizer)

    def langid(self, languages, target):
        """
        Searches languages for a prefix match on target.

        Args:
            languages: iterable of language codes
            target: target language code

        Returns:
            first language code starting with target or None if no match found
        """

        for lang in languages:
            if lang.startswith(target):
                return lang

        return None
@@ -0,0 +1,7 @@
1
+ """
2
+ Train imports
3
+ """
4
+
5
+ from .hfonnx import HFOnnx
6
+ from .hftrainer import HFTrainer
7
+ from .mlonnx import MLOnnx
@@ -0,0 +1,196 @@
1
+ """
2
+ Hugging Face Transformers ONNX export module
3
+ """
4
+
5
+ from collections import OrderedDict
6
+ from io import BytesIO
7
+ from itertools import chain
8
+ from tempfile import NamedTemporaryFile
9
+
10
+ # Conditional import
11
+ try:
12
+ from onnxruntime.quantization import quantize_dynamic
13
+
14
+ ONNX_RUNTIME = True
15
+ except ImportError:
16
+ ONNX_RUNTIME = False
17
+
18
+ from torch import nn
19
+ from torch.onnx import export
20
+
21
+ from transformers import AutoModel, AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoTokenizer
22
+
23
+ from ...models import PoolingFactory
24
+ from ..tensors import Tensors
25
+
26
+
27
class HFOnnx(Tensors):
    """
    Exports a Hugging Face Transformer model to ONNX.
    """

    def __call__(self, path, task="default", output=None, quantize=False, opset=14):
        """
        Exports a Hugging Face Transformer model to ONNX.

        Args:
            path: path to model, accepts Hugging Face model hub id, local path or (model, tokenizer) tuple
            task: optional model task or category, determines the model type and outputs, defaults to export hidden state
            output: optional output model path, defaults to return byte array if None
            quantize: if model should be quantized (requires onnxruntime to be installed), defaults to False
            opset: onnx opset, defaults to 14

        Returns:
            path to model output or model as bytes depending on output parameter

        Raises:
            ImportError: if quantize is set and onnxruntime is not installed
        """

        inputs, outputs, loader = self.parameters(task)

        # Fail before running the export when quantization was requested but is not possible
        if quantize and not ONNX_RUNTIME:
            raise ImportError('onnxruntime is not available - install "pipeline" extra to enable')

        if isinstance(path, (list, tuple)):
            # Model already loaded - export runs on CPU
            model, tokenizer = path
            model = model.cpu()
        else:
            model = loader(path)
            tokenizer = AutoTokenizer.from_pretrained(path)

        # Generate dummy inputs
        dummy = dict(tokenizer(["test inputs"], return_tensors="pt"))

        # Default to BytesIO if no output file provided
        output = output if output else BytesIO()

        # Export model to ONNX
        export(
            model,
            (dummy,),
            output,
            opset_version=opset,
            do_constant_folding=True,
            input_names=list(inputs.keys()),
            output_names=list(outputs.keys()),
            dynamic_axes=dict(chain(inputs.items(), outputs.items())),
        )

        # Quantize model
        if quantize:
            output = self.quantization(output)

        if isinstance(output, BytesIO):
            # Reset stream and return bytes
            output.seek(0)
            output = output.read()

        return output

    def quantization(self, output):
        """
        Quantizes an ONNX model.

        Args:
            output: path to ONNX model or BytesIO with model data

        Returns:
            quantized model as file path or bytes
        """

        # Function-scope import, only needed to clean up the temporary file
        import os

        temp = None
        try:
            if isinstance(output, BytesIO):
                # quantize_dynamic works on file paths, write in-memory model to a temporary file
                with NamedTemporaryFile(suffix=".quant", delete=False) as tmpfile:
                    temp = tmpfile.name
                    tmpfile.write(output.getbuffer())

                output = temp

            # Quantize model in place
            quantize_dynamic(output, output, extra_options={"MatMulConstBOnly": False})

            # Read file back to bytes if temp file was created
            if temp:
                with open(temp, "rb") as f:
                    output = f.read()
        finally:
            # Temporary file is created with delete=False, remove it explicitly
            if temp and os.path.exists(temp):
                os.remove(temp)

        return output

    def parameters(self, task):
        """
        Defines inputs and outputs for an ONNX model.

        Args:
            task: task name used to lookup model configuration

        Returns:
            (inputs, outputs, model function)

        Raises:
            KeyError: if task is not a supported task name
        """

        # Input names mapped to their dynamic axes
        inputs = OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )

        # Task name mapped to (output names with dynamic axes, function that loads a model from a path)
        config = {
            "default": (OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}), AutoModel.from_pretrained),
            "pooling": (OrderedDict({"embeddings": {0: "batch", 1: "sequence"}}), lambda x: PoolingOnnx(x, -1)),
            "question-answering": (
                OrderedDict(
                    {
                        "start_logits": {0: "batch", 1: "sequence"},
                        "end_logits": {0: "batch", 1: "sequence"},
                    }
                ),
                AutoModelForQuestionAnswering.from_pretrained,
            ),
            "text-classification": (OrderedDict({"logits": {0: "batch"}}), AutoModelForSequenceClassification.from_pretrained),
        }

        # Aliases
        config["zero-shot-classification"] = config["text-classification"]

        return (inputs,) + config[task]
158
+
159
class PoolingOnnx(nn.Module):
    """
    Extends Pooling methods to name inputs to model, which is required to export to ONNX.
    """

    def __init__(self, path, device):
        """
        Creates a new PoolingOnnx instance.

        Args:
            path: path to model, accepts Hugging Face model hub id or local path
            device: tensor device id
        """

        super().__init__()

        # Create pooling method based on configuration
        self.model = PoolingFactory.create({"path": path, "device": device})

    # pylint: disable=W0221
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        """
        Runs inputs through pooling model and returns outputs.

        Args:
            input_ids: input ids passed to the pooling model
            attention_mask: attention mask passed to the pooling model
            token_type_ids: optional token type ids, only passed to the pooling model when not None

        Returns:
            model outputs
        """

        # Build list of arguments dynamically since some models take token_type_ids
        # and others don't
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            inputs["token_type_ids"] = token_type_ids

        return self.model.forward(**inputs)