mseep_txtai-9.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/embeddings/search/base.py
@@ -0,0 +1,344 @@
"""
Search module
"""

import logging

from .errors import IndexNotFoundError
from .scan import Scan

# Logging configuration
logger = logging.getLogger(__name__)


class Search:
    """
    Executes a batch search action. A search can be index and/or database driven.
    """

    def __init__(self, embeddings, indexids=False, indexonly=False):
        """
        Creates a new search action.

        Args:
            embeddings: embeddings instance
            indexids: searches return indexids when True, otherwise run standard search
            indexonly: always runs an index search even when a database is available
        """

        self.embeddings = embeddings
        self.indexids = indexids or indexonly
        self.indexonly = indexonly

        # Alias embeddings attributes
        self.ann = embeddings.ann
        self.batchtransform = embeddings.batchtransform
        self.database = embeddings.database
        self.ids = embeddings.ids
        self.indexes = embeddings.indexes
        self.graph = embeddings.graph
        self.query = embeddings.query
        self.scoring = embeddings.scoring if embeddings.issparse() else None

    def __call__(self, queries, limit=None, weights=None, index=None, parameters=None):
        """
        Executes a batch search for queries. This method will run either an index search or an index + database search
        depending on if a database is available.

        Args:
            queries: list of queries
            limit: maximum results
            weights: hybrid score weights
            index: index name
            parameters: list of dicts of named parameters to bind to placeholders

        Returns:
            list of (id, score) per query for index search
            list of dict per query for an index + database search
            list of graph results for a graph index search
        """

        # Default input parameters
        limit = limit if limit else 3
        weights = weights if weights is not None else 0.5

        # Return empty results if there is no database and no indexes
        if not self.ann and not self.scoring and not self.indexes and not self.database:
            return [[]] * len(queries)

        # Default index name if only subindexes set
        if not index and not self.ann and not self.scoring and self.indexes:
            index = self.indexes.default()

        # Graph search
        if self.graph and self.graph.isquery(queries):
            return self.graphsearch(queries, limit, weights, index)

        # Database search
        if not self.indexonly and self.database:
            return self.dbsearch(queries, limit, weights, index, parameters)

        # Default vector index query (sparse, dense or hybrid)
        return self.search(queries, limit, weights, index)

    def search(self, queries, limit, weights, index):
        """
        Executes an index search. When only a sparse index is enabled, this is a keyword search. When only
        a dense index is enabled, this is an ANN search. When both are enabled, this is a hybrid search.

        This method will also query subindexes, if available.

        Args:
            queries: list of queries
            limit: maximum results
            weights: hybrid score weights
            index: index name

        Returns:
            list of (id, score) per query
        """

        # Run against specified subindex
        if index:
            return self.subindex(queries, limit, weights, index)

        # Run against base indexes
        hybrid = self.ann and self.scoring
        dense = self.dense(queries, limit * 10 if hybrid else limit) if self.ann else None
        sparse = self.sparse(queries, limit * 10 if hybrid else limit) if self.scoring else None

        # Combine scores together
        if hybrid:
            # Create weights array if single number passed
            if isinstance(weights, (int, float)):
                weights = [weights, 1 - weights]

            # Create weighted scores
            results = []
            for vectors in zip(dense, sparse):
                uids = {}
                for v, scores in enumerate(vectors):
                    for r, (uid, score) in enumerate(scores if weights[v] > 0 else []):
                        # Initialize score
                        if uid not in uids:
                            uids[uid] = 0.0

                        # Create hybrid score
                        #  - Convex Combination when sparse scores are normalized
                        #  - Reciprocal Rank Fusion (RRF) when sparse scores aren't normalized
                        if self.scoring.isnormalized():
                            uids[uid] += score * weights[v]
                        else:
                            uids[uid] += (1.0 / (r + 1)) * weights[v]

                results.append(sorted(uids.items(), key=lambda x: x[1], reverse=True)[:limit])

            return results

        # Raise an error when no indexes are available
        if not sparse and not dense:
            raise IndexNotFoundError("No indexes available")

        # Return single query results
        return dense if dense else sparse

    def subindex(self, queries, limit, weights, index):
        """
        Executes a subindex search.

        Args:
            queries: list of queries
            limit: maximum results
            weights: hybrid score weights
            index: index name

        Returns:
            list of (id, score) per query
        """

        # Check that index exists
        if not self.indexes or index not in self.indexes:
            raise IndexNotFoundError(f"Index '{index}' not found")

        # Run subindex search
        results = self.indexes[index].batchsearch(queries, limit, weights)
        return self.resolve(results)

    def dense(self, queries, limit):
        """
        Executes a dense vector search with an approximate nearest neighbor index.

        Args:
            queries: list of queries
            limit: maximum results

        Returns:
            list of (id, score) per query
        """

        # Convert queries to embedding vectors
        embeddings = self.batchtransform((None, query, None) for query in queries)

        # Search approximate nearest neighbor index
        results = self.ann.search(embeddings, limit)

        # Require scores to be greater than 0
        results = [[(i, score) for i, score in r if score > 0] for r in results]

        return self.resolve(results)

    def sparse(self, queries, limit):
        """
        Executes a sparse vector search with a sparse keyword or sparse vector index.

        Args:
            queries: list of queries
            limit: maximum results

        Returns:
            list of (id, score) per query
        """

        # Search sparse index
        results = self.scoring.batchsearch(queries, limit)

        # Require scores to be greater than 0
        results = [[(i, score) for i, score in r if score > 0] for r in results]

        return self.resolve(results)

    def resolve(self, results):
        """
        Resolves index ids. This is only executed when content is disabled.

        Args:
            results: results

        Returns:
            results with resolved ids
        """

        # Map indexids to ids if embeddings ids are available
        if not self.indexids and self.ids:
            return [[(self.ids[i], score) for i, score in r] for r in results]

        return results

    def dbsearch(self, queries, limit, weights, index, parameters):
        """
        Executes an index + database search.

        Args:
            queries: list of queries
            limit: maximum results
            weights: default hybrid score weights
            index: default index name
            parameters: list of dicts of named parameters to bind to placeholders

        Returns:
            list of dict per query
        """

        # Parse queries
        queries = self.parse(queries)

        # Override limit with query limit, if applicable
        limit = max(limit, self.limit(queries))

        # Bulk index scan
        scan = Scan(self.search, limit, weights, index)(queries, parameters)

        # Combine index search results with database search results
        results = []
        for x, query in enumerate(queries):
            # Run the database query, get matching bulk searches for current query
            result = self.database.search(
                query, [r for y, r in scan if x == y], limit, parameters[x] if parameters and parameters[x] else None, self.indexids
            )
            results.append(result)

        return results

    def parse(self, queries):
        """
        Parses a list of database queries.

        Args:
            queries: list of queries

        Returns:
            parsed queries
        """

        # Parsed queries
        parsed = []

        for query in queries:
            # Parse query
            parse = self.database.parse(query)

            # Transform query if SQL not parsed and reparse
            if self.query and "select" not in parse:
                # Generate query
                query = self.query(query)
                logger.debug(query)

                # Reparse query
                parse = self.database.parse(query)

            parsed.append(parse)

        return parsed

    def limit(self, queries):
        """
        Parses the largest LIMIT clause from queries.

        Args:
            queries: list of queries

        Returns:
            largest limit number or 0 if not found
        """

        # Override limit with largest limit from database queries
        qlimit = 0
        for query in queries:
            # Parse out qlimit
            l = query.get("limit")
            if l and l.isdigit():
                l = int(l)

            qlimit = l if l and l > qlimit else qlimit

        return qlimit

    def graphsearch(self, queries, limit, weights, index):
        """
        Executes an index + graph search.

        Args:
            queries: list of queries
            limit: maximum results
            weights: default hybrid score weights
            index: default index name

        Returns:
            graph search results
        """

        # Parse queries
        queries = [self.graph.parse(query) for query in queries]

        # Override limit with query limit, if applicable
        limit = max(limit, self.limit(queries))

        # Bulk index scan
        scan = Scan(self.search, limit, weights, index)(queries, None)

        # Combine index search results with graph search
        for x, query in enumerate(queries):
            # Add search results to query
            query["results"] = [r for y, r in scan if x == y]

        return self.graph.batchsearch(queries, limit, self.indexids)
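The hybrid branch of Search.search fuses dense and sparse results in one of two ways, depending on whether sparse scores are normalized. The standalone sketch below replays that fusion logic on made-up result lists; the document ids, scores and weights are illustrative only and are not taken from this diff.

# Minimal sketch of the fusion logic in Search.search; illustration data only
dense = [("doc1", 0.91), ("doc2", 0.80)]   # (id, score) pairs from the ANN index
sparse = [("doc2", 0.75), ("doc3", 0.60)]  # (id, score) pairs from the sparse index
weights = [0.5, 0.5]                       # [dense weight, sparse weight]
normalized = True                          # mirrors scoring.isnormalized()

fused = {}
for v, scores in enumerate([dense, sparse]):
    for r, (uid, score) in enumerate(scores if weights[v] > 0 else []):
        if normalized:
            # Convex combination: weighted sum of normalized scores
            fused[uid] = fused.get(uid, 0.0) + score * weights[v]
        else:
            # Reciprocal Rank Fusion: rank-based contribution, 1 / (rank + 1)
            fused[uid] = fused.get(uid, 0.0) + (1.0 / (r + 1)) * weights[v]

print(sorted(fused.items(), key=lambda x: x[1], reverse=True))
# [('doc2', 0.775), ('doc1', 0.455), ('doc3', 0.3)]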
txtai/embeddings/search/explain.py
@@ -0,0 +1,120 @@
"""
Explain module
"""

import numpy as np


class Explain:
    """
    Explains the importance of each token in an input text element for a query. This method creates n permutations of the input text, where n
    is the number of tokens in the input text. This effectively masks each token to determine its importance to the query.
    """

    def __init__(self, embeddings):
        """
        Creates a new explain action.

        Args:
            embeddings: embeddings instance
        """

        self.embeddings = embeddings
        self.content = embeddings.config.get("content")

        # Alias embeddings attributes
        self.database = embeddings.database

    def __call__(self, queries, texts, limit):
        """
        Explains the importance of each input token in text for a list of queries.

        Args:
            queries: input queries
            texts: optional list of (text|list of tokens), otherwise runs search queries
            limit: optional limit if texts is None

        Returns:
            list of dict per input text per query where higher token scores represent higher importance relative to the query
        """

        # Construct texts elements per query
        texts = self.texts(queries, texts, limit)

        # Explain each query-texts combination
        return [self.explain(query, texts[x]) for x, query in enumerate(queries)]

    def texts(self, queries, texts, limit):
        """
        Constructs lists of dict for each input query.

        Args:
            queries: input queries
            texts: optional list of texts
            limit: optional limit if texts is None

        Returns:
            lists of dict for each input query
        """

        # Calculate similarity scores per query if texts present
        if texts:
            results = []
            for scores in self.embeddings.batchsimilarity(queries, texts):
                results.append([{"id": uid, "text": texts[uid], "score": score} for uid, score in scores])

            return results

        # Query for results if texts is None and content is enabled
        return self.embeddings.batchsearch(queries, limit) if self.content else [[]] * len(queries)

    def explain(self, query, texts):
        """
        Explains the importance of each input token in text for a query.

        Args:
            query: input query
            texts: list of text

        Returns:
            list of {"id": value, "text": value, "score": value, "tokens": value} covering each input text element
        """

        # Explain results
        results = []

        # Parse out similar clauses, if necessary
        if self.database:
            # Parse query
            query = self.database.parse(query)

            # Extract query from similar clause
            query = " ".join([" ".join(clause) for clause in query["similar"]]) if "similar" in query else None

        # Return original texts if query, text or score not present
        if not query or not texts or "score" not in texts[0] or "text" not in texts[0]:
            return texts

        # Calculate result per input text element
        for result in texts:
            text = result["text"]
            tokens = text if isinstance(text, list) else text.split()

            # Create permutations of input text, masking each token to determine importance
            permutations = []
            for i in range(len(tokens)):
                data = tokens.copy()
                data.pop(i)
                permutations.append([" ".join(data)])

            # Calculate similarity for each input text permutation and get score delta as importance
            scores = [(i, result["score"] - np.abs(s)) for i, s in self.embeddings.similarity(query, permutations)]

            # Append tokens to result
            result["tokens"] = [(tokens[i], score) for i, score in sorted(scores, key=lambda x: x[0])]

            # Add data sorted in index order
            results.append(result)

        # Sort score descending and return
        return sorted(results, key=lambda x: x["score"], reverse=True)
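A hedged usage sketch for the Explain action through the public Embeddings API; the model path and sample documents are assumptions. Explain needs content storage enabled to look up matching text when texts is not passed explicitly.

from txtai import Embeddings

# Assumed setup: any dense vector model path works; content=True stores text
# so explain() can retrieve matching documents for the query
embeddings = Embeddings(path="sentence-transformers/all-MiniLM-L6-v2", content=True)
embeddings.index(["Maine man wins $1M from $25 lottery ticket",
                  "US tops 5 million confirmed virus cases"])

# Each result carries a "tokens" list of (token, score) pairs; higher scores mark
# tokens that contribute more to the match for the query
for result in embeddings.explain("feel good story", limit=1):
    print(result["text"])
    print(result["tokens"])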
txtai/embeddings/search/ids.py
@@ -0,0 +1,61 @@
"""
Ids module
"""


class Ids:
    """
    Resolves internal ids for lists of ids.
    """

    def __init__(self, embeddings):
        """
        Create a new ids action.

        Args:
            embeddings: embeddings instance
        """

        self.database = embeddings.database
        self.ids = embeddings.ids

    def __call__(self, ids):
        """
        Resolve internal ids.

        Args:
            ids: ids

        Returns:
            internal ids
        """

        # Resolve ids using database if available, otherwise fallback to embeddings ids
        results = self.database.ids(ids) if self.database else self.scan(ids)

        # Create dict of id: [iids] given there is a one to many relationship
        ids = {}
        for iid, uid in results:
            if uid not in ids:
                ids[uid] = []
            ids[uid].append(iid)

        return ids

    def scan(self, ids):
        """
        Scans embeddings ids array for matches when content is disabled.

        Args:
            ids: search ids

        Returns:
            internal ids
        """

        # Find existing ids
        indices = []
        for uid in ids:
            indices.extend([(index, value) for index, value in enumerate(self.ids) if uid == value])

        return indices
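For reference, a standalone sketch of the grouping step in Ids.__call__: an external id can map to multiple internal index ids, so the action returns a dict of lists. The (internal id, external id) pairs below are made-up illustration data.

# Illustration only: pairs as returned by database.ids() or Ids.scan()
results = [(0, "doc-1"), (3, "doc-1"), (1, "doc-2")]

ids = {}
for iid, uid in results:
    ids.setdefault(uid, []).append(iid)

print(ids)  # {'doc-1': [0, 3], 'doc-2': [1]}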
txtai/embeddings/search/query.py
@@ -0,0 +1,69 @@
"""
Query module
"""

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration


class Query:
    """
    Query translation model.
    """

    def __init__(self, path, prefix=None, maxlength=512):
        """
        Creates a query translation model.

        Args:
            path: path to query model
            prefix: text prefix
            maxlength: max sequence length to generate
        """

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)

        # Default prefix if not provided for T5 models
        if not prefix and isinstance(self.model, T5ForConditionalGeneration):
            prefix = "translate English to SQL: "

        self.prefix = prefix
        self.maxlength = maxlength

    def __call__(self, query):
        """
        Runs query translation model.

        Args:
            query: input query

        Returns:
            transformed query
        """

        # Add prefix, if necessary
        if self.prefix:
            query = f"{self.prefix}{query}"

        # Tokenize and generate text using model
        features = self.tokenizer([query], return_tensors="pt")
        output = self.model.generate(input_ids=features["input_ids"], attention_mask=features["attention_mask"], max_length=self.maxlength)

        # Decode tokens to text
        result = self.tokenizer.decode(output[0], skip_special_tokens=True)

        # Clean and return generated text
        return self.clean(result)

    def clean(self, text):
        """
        Applies a series of rules to clean generated text.

        Args:
            text: input text

        Returns:
            clean text
        """

        return text.replace("$=", "<=")
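A hedged sketch of enabling query translation through embeddings configuration; the dense model path, the NeuML/t5-small-txtsql query model and the sample data are assumptions rather than anything specified in this diff.

from txtai import Embeddings

# Assumed configuration: the query model translates natural language queries into
# txtai SQL before Search.parse runs them against the database
embeddings = Embeddings(
    path="sentence-transformers/all-MiniLM-L6-v2",
    content=True,
    query={"path": "NeuML/t5-small-txtsql"},
)
embeddings.index(["Maine man wins $1M from $25 lottery ticket",
                  "US tops 5 million confirmed virus cases"])

# Translated by Query.__call__ into a SQL query with a similar() clause
print(embeddings.search("feel good story since yesterday"))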