mseep-txtai 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. The information is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/graph/topics.py
ADDED
@@ -0,0 +1,166 @@
"""
Topics module
"""

from ..pipeline import Tokenizer
from ..scoring import ScoringFactory


class Topics:
    """
    Topic modeling using community detection.
    """

    def __init__(self, config):
        """
        Creates a new Topics instance.

        Args:
            config: topic configuration
        """

        self.config = config if config else {}
        self.tokenizer = Tokenizer(stopwords=True)

        # Additional stopwords to ignore when building topic names
        self.stopwords = set()
        if "stopwords" in self.config:
            self.stopwords.update(self.config["stopwords"])

    def __call__(self, graph):
        """
        Runs topic modeling for input graph.

        Args:
            graph: Graph instance

        Returns:
            dictionary of {topic name: [ids]}
        """

        # Detect communities
        communities = graph.communities(self.config)

        # Sort by community size, largest to smallest
        communities = sorted(communities, key=len, reverse=True)

        # Calculate centrality of graph
        centrality = graph.centrality()

        # Score communities and generate topn terms
        topics = [self.score(graph, x, community, centrality) for x, community in enumerate(communities)]

        # Merge duplicate topics and return
        return self.merge(topics)

    def score(self, graph, index, community, centrality):
        """
        Scores a community of nodes and generates the topn terms in the community.

        Args:
            graph: Graph instance
            index: community index
            community: community of nodes
            centrality: node centrality scores

        Returns:
            (topn topic terms, topic ids sorted by score descending)
        """

        # Tokenize input and build scoring index
        scoring = ScoringFactory.create({"method": self.config.get("labels", "bm25"), "terms": True})
        scoring.index(((node, self.tokenize(graph, node), None) for node in community))

        # Check if scoring index has data
        if scoring.idf:
            # Sort by most commonly occurring terms (i.e. lowest score)
            idf = sorted(scoring.idf, key=scoring.idf.get)

            # Term count for generating topic labels
            topn = self.config.get("terms", 4)

            # Get topn terms
            terms = self.topn(idf, topn)

            # Sort community by score descending
            community = [uid for uid, _ in scoring.search(terms, len(community))]
        else:
            # No text found for topic, generate topic name
            terms = ["topic", str(index)]

            # Sort community by centrality scores
            community = sorted(community, key=lambda x: centrality[x], reverse=True)

        return (terms, community)

    def tokenize(self, graph, node):
        """
        Tokenizes node text.

        Args:
            graph: Graph instance
            node: node id

        Returns:
            list of node tokens
        """

        text = graph.attribute(node, "text")
        return self.tokenizer(text) if text else []

    def topn(self, terms, n):
        """
        Gets topn terms.

        Args:
            terms: list of terms
            n: topn

        Returns:
            topn terms
        """

        topn = []

        for term in terms:
            # Add terms that pass tokenization rules
            if self.tokenizer(term) and term not in self.stopwords:
                topn.append(term)

            # Break once topn terms collected
            if len(topn) == n:
                break

        return topn

    def merge(self, topics):
        """
        Merges duplicate topics.

        Args:
            topics: list of (topn terms, topic ids)

        Returns:
            dictionary of {topic name: [ids]}
        """

        merge, termslist = {}, {}

        for terms, uids in topics:
            # Use topic terms as key
            key = frozenset(terms)

            # Add key to merged topics, if necessary
            if key not in merge:
                merge[key], termslist[key] = [], terms

            # Merge communities
            merge[key].extend(uids)

        # Sort communities largest to smallest since the order could have changed with merges
        results = {}
        for k, v in sorted(merge.items(), key=lambda x: len(x[1]), reverse=True):
            # Create composite string key using topic terms and store ids
            results["_".join(termslist[k])] = v

        return results
txtai/models/__init__.py
ADDED
txtai/models/models.py
ADDED
@@ -0,0 +1,268 @@
"""
Models module
"""

import os

import torch

from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES

from .onnx import OnnxModel


class Models:
    """
    Utility methods for working with machine learning models
    """

    @staticmethod
    def checklength(config, tokenizer):
        """
        Checks the length for a Hugging Face Transformers tokenizer using a Hugging Face Transformers config. Copies the
        max_position_embeddings parameter if the tokenizer has no max_length set. This helps with backwards compatibility
        with older tokenizers.

        Args:
            config: transformers config
            tokenizer: transformers tokenizer
        """

        # Unpack nested config, handles passing model directly
        if hasattr(config, "config"):
            config = config.config

        if (
            hasattr(config, "max_position_embeddings")
            and tokenizer
            and hasattr(tokenizer, "model_max_length")
            and tokenizer.model_max_length == int(1e30)
        ):
            tokenizer.model_max_length = config.max_position_embeddings

    @staticmethod
    def maxlength(config, tokenizer):
        """
        Gets the best max length to use for generate calls. This method will return config.max_length if it's set. Otherwise, it will return
        tokenizer.model_max_length.

        Args:
            config: transformers config
            tokenizer: transformers tokenizer
        """

        # Unpack nested config, handles passing model directly
        if hasattr(config, "config"):
            config = config.config

        # Get non-defaulted fields
        keys = config.to_diff_dict()

        # Use config.max_length if not set to default value, else use tokenizer.model_max_length if available
        return config.max_length if "max_length" in keys or not hasattr(tokenizer, "model_max_length") else tokenizer.model_max_length

    @staticmethod
    def deviceid(gpu):
        """
        Translates input gpu argument into a device id.

        Args:
            gpu: True/False if GPU should be enabled, also supports a device id/string/instance

        Returns:
            device id
        """

        # Return if this is already a torch device
        # pylint: disable=E1101
        if isinstance(gpu, torch.device):
            return gpu

        # Always return -1 if gpu is None or an accelerator device is unavailable
        if gpu is None or not Models.hasaccelerator():
            return -1

        # Default to device 0 if gpu is True and not otherwise specified
        if isinstance(gpu, bool):
            return 0 if gpu else -1

        # Return gpu as device id if gpu flag is an int
        return int(gpu)

    @staticmethod
    def device(deviceid):
        """
        Gets a tensor device.

        Args:
            deviceid: device id

        Returns:
            tensor device
        """

        # Torch device
        # pylint: disable=E1101
        return deviceid if isinstance(deviceid, torch.device) else torch.device(Models.reference(deviceid))

    @staticmethod
    def reference(deviceid):
        """
        Gets a tensor device reference.

        Args:
            deviceid: device id

        Returns:
            device reference
        """

        return (
            deviceid
            if isinstance(deviceid, str)
            else (
                "cpu"
                if deviceid < 0
                else f"cuda:{deviceid}" if torch.cuda.is_available() else "mps" if Models.hasmpsdevice() else Models.finddevice()
            )
        )

    @staticmethod
    def acceleratorcount():
        """
        Gets the number of accelerator devices available.

        Returns:
            number of accelerators available
        """

        return max(torch.cuda.device_count(), int(Models.hasaccelerator()))

    @staticmethod
    def hasaccelerator():
        """
        Checks if there is an accelerator device available.

        Returns:
            True if an accelerator device is available, False otherwise
        """

        return torch.cuda.is_available() or Models.hasmpsdevice() or bool(Models.finddevice())

    @staticmethod
    def hasmpsdevice():
        """
        Checks if there is a MPS device available.

        Returns:
            True if a MPS device is available, False otherwise
        """

        return os.environ.get("PYTORCH_MPS_DISABLE") != "1" and torch.backends.mps.is_available()

    @staticmethod
    def finddevice():
        """
        Attempts to find an alternative accelerator device.

        Returns:
            name of first alternative accelerator available or None if not found
        """

        return next((device for device in ["xpu"] if hasattr(torch, device) and getattr(torch, device).is_available()), None)

    @staticmethod
    def load(path, config=None, task="default", modelargs=None):
        """
        Loads a machine learning model. Handles multiple model frameworks (ONNX, Transformers).

        Args:
            path: path to model
            config: path to model configuration
            task: task name used to lookup model type
            modelargs: optional dictionary of additional keyword arguments to pass to the model

        Returns:
            machine learning model
        """

        # Detect ONNX models
        if isinstance(path, bytes) or (isinstance(path, str) and os.path.isfile(path)):
            return OnnxModel(path, config)

        # Return path, if path isn't a string
        if not isinstance(path, str):
            return path

        # Transformer models
        models = {
            "default": AutoModel.from_pretrained,
            "question-answering": AutoModelForQuestionAnswering.from_pretrained,
            "summarization": AutoModelForSeq2SeqLM.from_pretrained,
            "text-classification": AutoModelForSequenceClassification.from_pretrained,
            "zero-shot-classification": AutoModelForSequenceClassification.from_pretrained,
        }

        # Pass modelargs as keyword arguments
        modelargs = modelargs if modelargs else {}

        # Load model for supported tasks. Return path for unsupported tasks.
        return models[task](path, **modelargs) if task in models else path

    @staticmethod
    def tokenizer(path, **kwargs):
        """
        Loads a tokenizer from path.

        Args:
            path: path to tokenizer
            kwargs: optional additional keyword arguments

        Returns:
            tokenizer
        """

        return AutoTokenizer.from_pretrained(path, **kwargs) if isinstance(path, str) else path

    @staticmethod
    def task(path, **kwargs):
        """
        Attempts to detect the model task from path.

        Args:
            path: path to model
            kwargs: optional additional keyword arguments

        Returns:
            inferred model task
        """

        # Get model configuration
        config = None
        if isinstance(path, (list, tuple)) and hasattr(path[0], "config"):
            config = path[0].config
        elif isinstance(path, str):
            config = AutoConfig.from_pretrained(path, **kwargs)

        # Attempt to resolve task using configuration
        task = None
        if config:
            architecture = config.architectures[0] if config.architectures else None
            if architecture:
                if architecture in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
                    task = "vision"
                elif any(x for x in ["LMHead", "CausalLM"] if x in architecture):
                    task = "language-generation"
                elif "QuestionAnswering" in architecture:
                    task = "question-answering"
                elif "ConditionalGeneration" in architecture:
                    task = "sequence-sequence"

        return task
txtai/models/onnx.py
ADDED
@@ -0,0 +1,133 @@
"""
ONNX module
"""

# Conditional import
try:
    import onnxruntime as ort

    ONNX_RUNTIME = True
except ImportError:
    ONNX_RUNTIME = False

import numpy as np
import torch

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_utils import PreTrainedModel

from .registry import Registry


# pylint: disable=W0223
class OnnxModel(PreTrainedModel):
    """
    Provides a Transformers/PyTorch compatible interface for ONNX models. Handles casting inputs
    and outputs with minimal to no copying of data.
    """

    def __init__(self, model, config=None):
        """
        Creates a new OnnxModel.

        Args:
            model: path to model or InferenceSession
            config: path to model configuration
        """

        if not ONNX_RUNTIME:
            raise ImportError('onnxruntime is not available - install "model" extra to enable')

        super().__init__(AutoConfig.from_pretrained(config) if config else OnnxConfig())

        # Create ONNX session
        self.model = ort.InferenceSession(model, ort.SessionOptions(), self.providers())

        # Add references for this class to supported AutoModel classes
        Registry.register(self)

    @property
    def device(self):
        """
        Returns model device id.

        Returns:
            model device id
        """

        return -1

    def providers(self):
        """
        Returns a list of available and usable providers.

        Returns:
            list of available and usable providers
        """

        # Create list of providers, prefer CUDA provider if available
        # CUDA provider only available if GPU is available and onnxruntime-gpu installed
        if torch.cuda.is_available() and "CUDAExecutionProvider" in ort.get_available_providers():
            return ["CUDAExecutionProvider", "CPUExecutionProvider"]

        # Default when CUDA provider isn't available
        return ["CPUExecutionProvider"]

    def forward(self, **inputs):
        """
        Runs inputs through an ONNX model and returns outputs. This method handles casting inputs
        and outputs between torch tensors and numpy arrays as shared memory (no copy).

        Args:
            inputs: model inputs

        Returns:
            model outputs
        """

        inputs = self.parse(inputs)

        # Run inputs through ONNX model
        results = self.model.run(None, inputs)

        # pylint: disable=E1101
        # Detect if logits is an output and return classifier output in that case
        if any(x.name for x in self.model.get_outputs() if x.name == "logits"):
            return SequenceClassifierOutput(logits=torch.from_numpy(np.array(results[0])))

        return torch.from_numpy(np.array(results))

    def parse(self, inputs):
        """
        Parse model inputs and handle converting to ONNX compatible inputs.

        Args:
            inputs: model inputs

        Returns:
            ONNX compatible model inputs
        """

        features = {}

        # Select features from inputs
        for key in ["input_ids", "attention_mask", "token_type_ids"]:
            if key in inputs:
                value = inputs[key]

                # Cast torch tensors to numpy
                if hasattr(value, "cpu"):
                    value = value.cpu().numpy()

                # Cast to numpy array if not already one
                features[key] = np.asarray(value)

        return features


class OnnxConfig(PretrainedConfig):
    """
    Configuration for ONNX models.
    """