mseep_txtai-9.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251)
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
txtai/graph/topics.py ADDED
@@ -0,0 +1,166 @@
+"""
+Topics module
+"""
+
+from ..pipeline import Tokenizer
+from ..scoring import ScoringFactory
+
+
+class Topics:
+    """
+    Topic modeling using community detection.
+    """
+
+    def __init__(self, config):
+        """
+        Creates a new Topics instance.
+
+        Args:
+            config: topic configuration
+        """
+
+        self.config = config if config else {}
+        self.tokenizer = Tokenizer(stopwords=True)
+
+        # Additional stopwords to ignore when building topic names
+        self.stopwords = set()
+        if "stopwords" in self.config:
+            self.stopwords.update(self.config["stopwords"])
+
+    def __call__(self, graph):
+        """
+        Runs topic modeling for input graph.
+
+        Args:
+            graph: Graph instance
+
+        Returns:
+            dictionary of {topic name: [ids]}
+        """
+
+        # Detect communities
+        communities = graph.communities(self.config)
+
+        # Sort by community size, largest to smallest
+        communities = sorted(communities, key=len, reverse=True)
+
+        # Calculate centrality of graph
+        centrality = graph.centrality()
+
+        # Score communities and generate topn terms
+        topics = [self.score(graph, x, community, centrality) for x, community in enumerate(communities)]
+
+        # Merge duplicate topics and return
+        return self.merge(topics)
+
+    def score(self, graph, index, community, centrality):
+        """
+        Scores a community of nodes and generates the topn terms in the community.
+
+        Args:
+            graph: Graph instance
+            index: community index
+            community: community of nodes
+            centrality: node centrality scores
+
+        Returns:
+            (topn topic terms, topic ids sorted by score descending)
+        """
+
+        # Tokenize input and build scoring index
+        scoring = ScoringFactory.create({"method": self.config.get("labels", "bm25"), "terms": True})
+        scoring.index(((node, self.tokenize(graph, node), None) for node in community))
+
+        # Check if scoring index has data
+        if scoring.idf:
+            # Sort by most commonly occurring terms (i.e. lowest score)
+            idf = sorted(scoring.idf, key=scoring.idf.get)
+
+            # Term count for generating topic labels
+            topn = self.config.get("terms", 4)
+
+            # Get topn terms
+            terms = self.topn(idf, topn)
+
+            # Sort community by score descending
+            community = [uid for uid, _ in scoring.search(terms, len(community))]
+        else:
+            # No text found for topic, generate topic name
+            terms = ["topic", str(index)]
+
+            # Sort community by centrality scores
+            community = sorted(community, key=lambda x: centrality[x], reverse=True)
+
+        return (terms, community)
+
+    def tokenize(self, graph, node):
+        """
+        Tokenizes node text.
+
+        Args:
+            graph: Graph instance
+            node: node id
+
+        Returns:
+            list of node tokens
+        """
+
+        text = graph.attribute(node, "text")
+        return self.tokenizer(text) if text else []
+
+    def topn(self, terms, n):
+        """
+        Gets topn terms.
+
+        Args:
+            terms: list of terms
+            n: topn
+
+        Returns:
+            topn terms
+        """
+
+        topn = []
+
+        for term in terms:
+            # Add terms that pass tokenization rules
+            if self.tokenizer(term) and term not in self.stopwords:
+                topn.append(term)
+
+            # Break once topn terms collected
+            if len(topn) == n:
+                break
+
+        return topn
+
+    def merge(self, topics):
+        """
+        Merges duplicate topics
+
+        Args:
+            topics: list of (topn terms, topic ids)
+
+        Returns:
+            dictionary of {topic name: [ids]}
+        """
+
+        merge, termslist = {}, {}
+
+        for terms, uids in topics:
+            # Use topic terms as key
+            key = frozenset(terms)
+
+            # Add key to merged topics, if necessary
+            if key not in merge:
+                merge[key], termslist[key] = [], terms
+
+            # Merge communities
+            merge[key].extend(uids)
+
+        # Sort communities largest to smallest since the order could have changed with merges
+        results = {}
+        for k, v in sorted(merge.items(), key=lambda x: len(x[1]), reverse=True):
+            # Create composite string key using topic terms and store ids
+            results["_".join(termslist[k])] = v
+
+        return results
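For reference, a minimal sketch of how this class is typically driven. It assumes an already-indexed txtai Graph instance (anything exposing the communities, centrality and attribute methods used above); the configuration keys shown (labels, terms, stopwords) are the ones the class reads.

    from txtai.graph.topics import Topics

    # Assumes `graph` is an indexed txtai Graph instance providing
    # communities(), centrality() and attribute(node, "text")
    topics = Topics({"labels": "bm25", "terms": 4, "stopwords": ["data"]})

    # Dictionary of {topic name: [node ids sorted by relevance]}
    for name, uids in topics(graph).items():
        print(name, len(uids))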
txtai/models/__init__.py ADDED
@@ -0,0 +1,9 @@
+"""
+Models imports
+"""
+
+from .models import Models
+from .onnx import OnnxModel
+from .pooling import *
+from .registry import Registry
+from .tokendetection import TokenDetection
txtai/models/models.py ADDED
@@ -0,0 +1,268 @@
+"""
+Models module
+"""
+
+import os
+
+import torch
+
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForQuestionAnswering,
+    AutoModelForSeq2SeqLM,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+)
+from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
+
+from .onnx import OnnxModel
+
+
+class Models:
+    """
+    Utility methods for working with machine learning models
+    """
+
+    @staticmethod
+    def checklength(config, tokenizer):
+        """
+        Checks the length for a Hugging Face Transformers tokenizer using a Hugging Face Transformers config. Copies the
+        max_position_embeddings parameter if the tokenizer has no max_length set. This helps with backwards compatibility
+        with older tokenizers.
+
+        Args:
+            config: transformers config
+            tokenizer: transformers tokenizer
+        """
+
+        # Unpack nested config, handles passing model directly
+        if hasattr(config, "config"):
+            config = config.config
+
+        if (
+            hasattr(config, "max_position_embeddings")
+            and tokenizer
+            and hasattr(tokenizer, "model_max_length")
+            and tokenizer.model_max_length == int(1e30)
+        ):
+            tokenizer.model_max_length = config.max_position_embeddings
+
+    @staticmethod
+    def maxlength(config, tokenizer):
+        """
+        Gets the best max length to use for generate calls. This method will return config.max_length if it's set. Otherwise, it will return
+        tokenizer.model_max_length.
+
+        Args:
+            config: transformers config
+            tokenizer: transformers tokenizer
+        """
+
+        # Unpack nested config, handles passing model directly
+        if hasattr(config, "config"):
+            config = config.config
+
+        # Get non-defaulted fields
+        keys = config.to_diff_dict()
+
+        # Use config.max_length if not set to default value, else use tokenizer.model_max_length if available
+        return config.max_length if "max_length" in keys or not hasattr(tokenizer, "model_max_length") else tokenizer.model_max_length
+
+    @staticmethod
+    def deviceid(gpu):
+        """
+        Translates input gpu argument into a device id.
+
+        Args:
+            gpu: True/False if GPU should be enabled, also supports a device id/string/instance
+
+        Returns:
+            device id
+        """
+
+        # Return if this is already a torch device
+        # pylint: disable=E1101
+        if isinstance(gpu, torch.device):
+            return gpu
+
+        # Always return -1 if gpu is None or an accelerator device is unavailable
+        if gpu is None or not Models.hasaccelerator():
+            return -1
+
+        # Default to device 0 if gpu is True and not otherwise specified
+        if isinstance(gpu, bool):
+            return 0 if gpu else -1
+
+        # Return gpu as device id if gpu flag is an int
+        return int(gpu)
+
+    @staticmethod
+    def device(deviceid):
+        """
+        Gets a tensor device.
+
+        Args:
+            deviceid: device id
+
+        Returns:
+            tensor device
+        """
+
+        # Torch device
+        # pylint: disable=E1101
+        return deviceid if isinstance(deviceid, torch.device) else torch.device(Models.reference(deviceid))
+
+    @staticmethod
+    def reference(deviceid):
+        """
+        Gets a tensor device reference.
+
+        Args:
+            deviceid: device id
+
+        Returns:
+            device reference
+        """
+
+        return (
+            deviceid
+            if isinstance(deviceid, str)
+            else (
+                "cpu"
+                if deviceid < 0
+                else f"cuda:{deviceid}" if torch.cuda.is_available() else "mps" if Models.hasmpsdevice() else Models.finddevice()
+            )
+        )
+
+    @staticmethod
+    def acceleratorcount():
+        """
+        Gets the number of accelerator devices available.
+
+        Returns:
+            number of accelerators available
+        """
+
+        return max(torch.cuda.device_count(), int(Models.hasaccelerator()))
+
+    @staticmethod
+    def hasaccelerator():
+        """
+        Checks if there is an accelerator device available.
+
+        Returns:
+            True if an accelerator device is available, False otherwise
+        """
+
+        return torch.cuda.is_available() or Models.hasmpsdevice() or bool(Models.finddevice())
+
+    @staticmethod
+    def hasmpsdevice():
+        """
+        Checks if there is a MPS device available.
+
+        Returns:
+            True if a MPS device is available, False otherwise
+        """
+
+        return os.environ.get("PYTORCH_MPS_DISABLE") != "1" and torch.backends.mps.is_available()
+
+    @staticmethod
+    def finddevice():
+        """
+        Attempts to find an alternative accelerator device.
+
+        Returns:
+            name of first alternative accelerator available or None if not found
+        """
+
+        return next((device for device in ["xpu"] if hasattr(torch, device) and getattr(torch, device).is_available()), None)
+
+    @staticmethod
+    def load(path, config=None, task="default", modelargs=None):
+        """
+        Loads a machine learning model. Handles multiple model frameworks (ONNX, Transformers).
+
+        Args:
+            path: path to model
+            config: path to model configuration
+            task: task name used to lookup model type
+
+        Returns:
+            machine learning model
+        """
+
+        # Detect ONNX models
+        if isinstance(path, bytes) or (isinstance(path, str) and os.path.isfile(path)):
+            return OnnxModel(path, config)
+
+        # Return path, if path isn't a string
+        if not isinstance(path, str):
+            return path
+
+        # Transformer models
+        models = {
+            "default": AutoModel.from_pretrained,
+            "question-answering": AutoModelForQuestionAnswering.from_pretrained,
+            "summarization": AutoModelForSeq2SeqLM.from_pretrained,
+            "text-classification": AutoModelForSequenceClassification.from_pretrained,
+            "zero-shot-classification": AutoModelForSequenceClassification.from_pretrained,
+        }
+
+        # Pass modelargs as keyword arguments
+        modelargs = modelargs if modelargs else {}
+
+        # Load model for supported tasks. Return path for unsupported tasks.
+        return models[task](path, **modelargs) if task in models else path
+
+    @staticmethod
+    def tokenizer(path, **kwargs):
+        """
+        Loads a tokenizer from path.
+
+        Args:
+            path: path to tokenizer
+            kwargs: optional additional keyword arguments
+
+        Returns:
+            tokenizer
+        """
+
+        return AutoTokenizer.from_pretrained(path, **kwargs) if isinstance(path, str) else path
+
+    @staticmethod
+    def task(path, **kwargs):
+        """
+        Attempts to detect the model task from path.
+
+        Args:
+            path: path to model
+            kwargs: optional additional keyword arguments
+
+        Returns:
+            inferred model task
+        """
+
+        # Get model configuration
+        config = None
+        if isinstance(path, (list, tuple)) and hasattr(path[0], "config"):
+            config = path[0].config
+        elif isinstance(path, str):
+            config = AutoConfig.from_pretrained(path, **kwargs)
+
+        # Attempt to resolve task using configuration
+        task = None
+        if config:
+            architecture = config.architectures[0] if config.architectures else None
+            if architecture:
+                if architecture in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
+                    task = "vision"
+                elif any(x for x in ["LMHead", "CausalLM"] if x in architecture):
+                    task = "language-generation"
+                elif "QuestionAnswering" in architecture:
+                    task = "question-answering"
+                elif "ConditionalGeneration" in architecture:
+                    task = "sequence-sequence"
+
+        return task
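A short usage sketch tying the utilities above together; the model name is an arbitrary public Hugging Face model used for illustration.

    from txtai.models import Models

    # Resolve a device id: True selects the first accelerator, -1 means CPU
    deviceid = Models.deviceid(True)
    device = Models.device(deviceid)

    # Load a Transformers model and tokenizer, then backfill tokenizer max length
    model = Models.load("bert-base-uncased")
    tokenizer = Models.tokenizer("bert-base-uncased")
    Models.checklength(model, tokenizer)

    # Move model to the resolved device
    model = model.to(device)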
txtai/models/onnx.py ADDED
@@ -0,0 +1,133 @@
+"""
+ONNX module
+"""
+
+# Conditional import
+try:
+    import onnxruntime as ort
+
+    ONNX_RUNTIME = True
+except ImportError:
+    ONNX_RUNTIME = False
+
+import numpy as np
+import torch
+
+from transformers import AutoConfig
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import SequenceClassifierOutput
+from transformers.modeling_utils import PreTrainedModel
+
+from .registry import Registry
+
+
+# pylint: disable=W0223
+class OnnxModel(PreTrainedModel):
+    """
+    Provides a Transformers/PyTorch compatible interface for ONNX models. Handles casting inputs
+    and outputs with minimal to no copying of data.
+    """
+
+    def __init__(self, model, config=None):
+        """
+        Creates a new OnnxModel.
+
+        Args:
+            model: path to model or InferenceSession
+            config: path to model configuration
+        """
+
+        if not ONNX_RUNTIME:
+            raise ImportError('onnxruntime is not available - install "model" extra to enable')
+
+        super().__init__(AutoConfig.from_pretrained(config) if config else OnnxConfig())
+
+        # Create ONNX session
+        self.model = ort.InferenceSession(model, ort.SessionOptions(), self.providers())
+
+        # Add references for this class to supported AutoModel classes
+        Registry.register(self)
+
+    @property
+    def device(self):
+        """
+        Returns model device id.
+
+        Returns:
+            model device id
+        """
+
+        return -1
+
+    def providers(self):
+        """
+        Returns a list of available and usable providers.
+
+        Returns:
+            list of available and usable providers
+        """
+
+        # Create list of providers, prefer CUDA provider if available
+        # CUDA provider only available if GPU is available and onnxruntime-gpu installed
+        if torch.cuda.is_available() and "CUDAExecutionProvider" in ort.get_available_providers():
+            return ["CUDAExecutionProvider", "CPUExecutionProvider"]
+
+        # Default when CUDA provider isn't available
+        return ["CPUExecutionProvider"]
+
+    def forward(self, **inputs):
+        """
+        Runs inputs through an ONNX model and returns outputs. This method handles casting inputs
+        and outputs between torch tensors and numpy arrays as shared memory (no copy).
+
+        Args:
+            inputs: model inputs
+
+        Returns:
+            model outputs
+        """
+
+        inputs = self.parse(inputs)
+
+        # Run inputs through ONNX model
+        results = self.model.run(None, inputs)
+
+        # pylint: disable=E1101
+        # Detect if logits is an output and return classifier output in that case
+        if any(x.name for x in self.model.get_outputs() if x.name == "logits"):
+            return SequenceClassifierOutput(logits=torch.from_numpy(np.array(results[0])))
+
+        return torch.from_numpy(np.array(results))
+
+    def parse(self, inputs):
+        """
+        Parse model inputs and handle converting to ONNX compatible inputs.
+
+        Args:
+            inputs: model inputs
+
+        Returns:
+            ONNX compatible model inputs
+        """
+
+        features = {}
+
+        # Select features from inputs
+        for key in ["input_ids", "attention_mask", "token_type_ids"]:
+            if key in inputs:
+                value = inputs[key]
+
+                # Cast torch tensors to numpy
+                if hasattr(value, "cpu"):
+                    value = value.cpu().numpy()
+
+                # Cast to numpy array if not already one
+                features[key] = np.asarray(value)
+
+        return features
+
+
+class OnnxConfig(PretrainedConfig):
+    """
+    Configuration for ONNX models.
+    """
txtai/models/pooling/__init__.py ADDED
@@ -0,0 +1,9 @@
+"""
+Pooling imports
+"""
+
+from .base import Pooling
+from .cls import ClsPooling
+from .factory import PoolingFactory
+from .late import LatePooling
+from .mean import MeanPooling
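These pooling classes collapse per-token transformer outputs into a single vector per input. As background, a self-contained sketch of the masked mean computation behind mean pooling; this illustrates the technique and is not a copy of the MeanPooling implementation:

    import torch

    def meanpool(tokens, mask):
        """
        Averages token embeddings, ignoring padding positions.

        Args:
            tokens: token embeddings (batch, seqlen, dim)
            mask: attention mask (batch, seqlen)

        Returns:
            sentence embeddings (batch, dim)
        """

        mask = mask.unsqueeze(-1).float()
        return (tokens * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)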