mseep-txtai 9.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251) hide show
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
@@ -0,0 +1,344 @@
1
+ """
2
+ Search module
3
+ """
4
+
5
+ import logging
6
+
7
+ from .errors import IndexNotFoundError
8
+ from .scan import Scan
9
+
10
+ # Logging configuration
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class Search:
    """
    Executes a batch search action. A search can be both index and/or database driven.
    """

    def __init__(self, embeddings, indexids=False, indexonly=False):
        """
        Creates a new search action.

        Args:
            embeddings: embeddings instance
            indexids: searches return indexids when True, otherwise run standard search
            indexonly: always runs an index search even when a database is available
        """

        self.embeddings = embeddings

        # An index only search always returns index ids
        self.indexids = indexids or indexonly
        self.indexonly = indexonly

        # Alias embeddings attributes
        self.ann = embeddings.ann
        self.batchtransform = embeddings.batchtransform
        self.database = embeddings.database
        self.ids = embeddings.ids
        self.indexes = embeddings.indexes
        self.graph = embeddings.graph
        self.query = embeddings.query
        self.scoring = embeddings.scoring if embeddings.issparse() else None

    def __call__(self, queries, limit=None, weights=None, index=None, parameters=None):
        """
        Executes a batch search for queries. This method will run either an index search or an index + database search
        depending on if a database is available.

        Args:
            queries: list of queries
            limit: maximum results, defaults to 3
            weights: hybrid score weights, defaults to 0.5
            index: index name
            parameters: list of dicts of named parameters to bind to placeholders

        Returns:
            list of (id, score) per query for index search
            list of dict per query for an index + database search
            list of graph results for a graph index search
        """

        # Default input parameters
        limit = limit if limit else 3
        weights = weights if weights is not None else 0.5

        # Return empty results if there is no database and indexes
        if not self.ann and not self.scoring and not self.indexes and not self.database:
            # Build a distinct list per query, a repeated list would share a single object across all queries
            return [[] for _ in queries]

        # Default index name if only subindexes set
        if not index and not self.ann and not self.scoring and self.indexes:
            index = self.indexes.default()

        # Graph search
        if self.graph and self.graph.isquery(queries):
            return self.graphsearch(queries, limit, weights, index)

        # Database search
        if not self.indexonly and self.database:
            return self.dbsearch(queries, limit, weights, index, parameters)

        # Default vector index query (sparse, dense or hybrid)
        return self.search(queries, limit, weights, index)

    def search(self, queries, limit, weights, index):
        """
        Executes an index search. When only a sparse index is enabled, this is a keyword search. When only
        a dense index is enabled, this is an ann search. When both are enabled, this is a hybrid search.

        This method will also query subindexes, if available.

        Args:
            queries: list of queries
            limit: maximum results
            weights: hybrid score weights
            index: index name

        Returns:
            list of (id, score) per query

        Raises:
            IndexNotFoundError: when neither a dense nor a sparse index produced results to return
        """

        # Run against specified subindex
        if index:
            return self.subindex(queries, limit, weights, index)

        # Run against base indexes. Hybrid search requests 10x the limit from each index before merging.
        hybrid = self.ann and self.scoring
        candidates = limit * 10 if hybrid else limit
        dense = self.dense(queries, candidates) if self.ann else None
        sparse = self.sparse(queries, candidates) if self.scoring else None

        # Combine scores together
        if hybrid:
            return self._combine(dense, sparse, limit, weights)

        # Raise an error when no indexes are available
        if not sparse and not dense:
            raise IndexNotFoundError("No indexes available")

        # Return single query results
        return dense if dense else sparse

    def _combine(self, dense, sparse, limit, weights):
        """
        Merges dense and sparse results into a single list of hybrid scores per query.

        Args:
            dense: list of (id, score) per query from the dense index
            sparse: list of (id, score) per query from the sparse index
            limit: maximum results
            weights: single number or [dense weight, sparse weight]

        Returns:
            list of (id, score) per query, sorted by score descending
        """

        # Create weights array if single number passed
        if isinstance(weights, (int, float)):
            weights = [weights, 1 - weights]

        # Scoring method is the same for every result, check it once
        # - Convex Combination when sparse scores are normalized
        # - Reciprocal Rank Fusion (RRF) when sparse scores aren't normalized
        normalized = self.scoring.isnormalized()

        # Create weighted scores
        results = []
        for vectors in zip(dense, sparse):
            uids = {}
            for v, scores in enumerate(vectors):
                # Skip result sets that don't have a positive weight
                if not weights[v] > 0:
                    continue

                for rank, (uid, score) in enumerate(scores):
                    value = score if normalized else 1.0 / (rank + 1)
                    uids[uid] = uids.get(uid, 0.0) + value * weights[v]

            results.append(sorted(uids.items(), key=lambda x: x[1], reverse=True)[:limit])

        return results

    def subindex(self, queries, limit, weights, index):
        """
        Executes a subindex search.

        Args:
            queries: list of queries
            limit: maximum results
            weights: hybrid score weights
            index: index name

        Returns:
            list of (id, score) per query

        Raises:
            IndexNotFoundError: when index is not a configured subindex
        """

        # Check that index exists
        if not self.indexes or index not in self.indexes:
            raise IndexNotFoundError(f"Index '{index}' not found")

        # Run subindex search
        results = self.indexes[index].batchsearch(queries, limit, weights)
        return self.resolve(results)

    def dense(self, queries, limit):
        """
        Executes a dense vector search with an approximate nearest neighbor index.

        Args:
            queries: list of queries
            limit: maximum results

        Returns:
            list of (id, score) per query
        """

        # Convert queries to embedding vectors
        embeddings = self.batchtransform((None, query, None) for query in queries)

        # Search approximate nearest neighbor index
        results = self.ann.search(embeddings, limit)

        # Require scores to be greater than 0
        results = [[(i, score) for i, score in r if score > 0] for r in results]

        return self.resolve(results)

    def sparse(self, queries, limit):
        """
        Executes a sparse vector search with a sparse keyword or sparse vector index.

        Args:
            queries: list of queries
            limit: maximum results

        Returns:
            list of (id, score) per query
        """

        # Search sparse index
        results = self.scoring.batchsearch(queries, limit)

        # Require scores to be greater than 0
        results = [[(i, score) for i, score in r if score > 0] for r in results]

        return self.resolve(results)

    def resolve(self, results):
        """
        Resolves index ids. This is only executed when content is disabled.

        Args:
            results: results

        Returns:
            results with resolved ids
        """

        # Map indexids to ids if embeddings ids are available
        if not self.indexids and self.ids:
            return [[(self.ids[i], score) for i, score in r] for r in results]

        return results

    def dbsearch(self, queries, limit, weights, index, parameters):
        """
        Executes an index + database search.

        Args:
            queries: list of queries
            limit: maximum results
            weights: default hybrid score weights
            index: default index name
            parameters: list of dicts of named parameters to bind to placeholders

        Returns:
            list of dict per query
        """

        # Parse queries
        queries = self.parse(queries)

        # Override limit with query limit, if applicable
        limit = max(limit, self.limit(queries))

        # Bulk index scan
        scan = Scan(self.search, limit, weights, index)(queries, parameters)

        # Combine index search results with database search results
        results = []
        for x, query in enumerate(queries):
            # Named parameters for the current query, if any
            params = parameters[x] if parameters and parameters[x] else None

            # Run the database query, get matching bulk searches for current query
            result = self.database.search(query, [r for y, r in scan if x == y], limit, params, self.indexids)
            results.append(result)

        return results

    def parse(self, queries):
        """
        Parses a list of database queries.

        Args:
            queries: list of queries

        Returns:
            parsed queries
        """

        # Parsed queries
        parsed = []

        for query in queries:
            # Parse query
            parse = self.database.parse(query)

            # Transform query if SQL not parsed and reparse
            if self.query and "select" not in parse:
                # Generate query
                query = self.query(query)
                logger.debug(query)

                # Reparse query
                parse = self.database.parse(query)

            parsed.append(parse)

        return parsed

    def limit(self, queries):
        """
        Parses the largest LIMIT clause from queries. Limits that are not a whole number, such as
        a bind placeholder, are ignored.

        Args:
            queries: list of parsed queries

        Returns:
            largest limit number or 0 if not found
        """

        # Override limit with largest limit from database queries
        qlimit = 0
        for query in queries:
            # Parse out query limit
            value = query.get("limit")

            if isinstance(value, str):
                # Only digit strings can be compared, skip everything else
                value = int(value) if value.isdigit() else None
            elif not isinstance(value, int):
                value = None

            if value and value > qlimit:
                qlimit = value

        return qlimit

    def graphsearch(self, queries, limit, weights, index):
        """
        Executes an index + graph search.

        Args:
            queries: list of queries
            limit: maximum results
            weights: default hybrid score weights
            index: default index name

        Returns:
            graph search results
        """

        # Parse queries
        queries = [self.graph.parse(query) for query in queries]

        # Override limit with query limit, if applicable
        limit = max(limit, self.limit(queries))

        # Bulk index scan
        scan = Scan(self.search, limit, weights, index)(queries, None)

        # Combine index search results with graph queries
        for x, query in enumerate(queries):
            # Add search results to query
            query["results"] = [r for y, r in scan if x == y]

        return self.graph.batchsearch(queries, limit, self.indexids)
@@ -0,0 +1,9 @@
1
+ """
2
+ Errors module
3
+ """
4
+
5
+
6
class IndexNotFoundError(Exception):
    """
    Error signaling that an embeddings query could not locate the index it needs.
    """
@@ -0,0 +1,120 @@
1
+ """
2
+ Explain module
3
+ """
4
+
5
+ import numpy as np
6
+
7
+
8
class Explain:
    """
    Explains the importance of each token in an input text element for a query. This method creates n permutations of the input text, where n
    is the number of tokens in the input text. This effectively masks each token to determine its importance to the query.
    """

    def __init__(self, embeddings):
        """
        Creates a new explain action.

        Args:
            embeddings: embeddings instance
        """

        self.embeddings = embeddings
        self.content = embeddings.config.get("content")

        # Alias embeddings attributes
        self.database = embeddings.database

    def __call__(self, queries, texts, limit):
        """
        Explains the importance of each input token in text for a list of queries.

        Args:
            queries: input queries
            texts: optional list of (text|list of tokens), otherwise runs search queries
            limit: optional limit if texts is None

        Returns:
            list of dict per input text per query where a higher token score represents higher importance relative to the query
        """

        # Construct texts elements per query
        texts = self.texts(queries, texts, limit)

        # Explain each query-texts combination
        return [self.explain(query, texts[x]) for x, query in enumerate(queries)]

    def texts(self, queries, texts, limit):
        """
        Constructs lists of dict for each input query.

        Args:
            queries: input queries
            texts: optional list of texts
            limit: optional limit if texts is None

        Returns:
            lists of dict for each input query
        """

        # Calculate similarity scores per query if texts present
        if texts:
            results = []
            for scores in self.embeddings.batchsimilarity(queries, texts):
                results.append([{"id": uid, "text": texts[uid], "score": score} for uid, score in scores])

            return results

        # Query for results if texts is None and content is enabled
        if self.content:
            return self.embeddings.batchsearch(queries, limit)

        # Build a distinct list per query, a repeated list would share a single object across all queries
        return [[] for _ in queries]

    def explain(self, query, texts):
        """
        Explains the importance of each input token in text for a query. Each element in texts
        is updated in place with a "tokens" entry.

        Args:
            query: input query
            texts: list of dict with "text" and "score" entries

        Returns:
            list of {"id": value, "text": value, "score": value, "tokens": value} covering each input text element
        """

        # Parse out similar clauses, if necessary
        if self.database:
            # Parse query
            query = self.database.parse(query)

            # Extract query from similar clause
            query = " ".join(" ".join(clause) for clause in query["similar"]) if "similar" in query else None

        # Return original texts if query, text or score not present
        if not query or not texts or "score" not in texts[0] or "text" not in texts[0]:
            return texts

        # Calculate result per input text element
        results = []
        for result in texts:
            text = result["text"]
            tokens = text if isinstance(text, list) else text.split()

            # Create permutations of input text, masking each token to determine importance
            permutations = [[" ".join(tokens[:i] + tokens[i + 1 :])] for i in range(len(tokens))]

            # Calculate similarity for each input text permutation and get score delta as importance
            # Skip the similarity call when there are no tokens to mask
            scores = []
            if permutations:
                scores = [(i, result["score"] - np.abs(s)) for i, s in self.embeddings.similarity(query, permutations)]

            # Append tokens to result, sorted in token order
            result["tokens"] = [(tokens[i], score) for i, score in sorted(scores, key=lambda x: x[0])]

            results.append(result)

        # Sort score descending and return
        return sorted(results, key=lambda x: x["score"], reverse=True)
@@ -0,0 +1,61 @@
1
+ """
2
+ Ids module
3
+ """
4
+
5
+
6
class Ids:
    """
    Resolves internal ids for lists of ids.
    """

    def __init__(self, embeddings):
        """
        Create a new ids action.

        Args:
            embeddings: embeddings instance
        """

        self.database = embeddings.database
        self.ids = embeddings.ids

    def __call__(self, ids):
        """
        Resolve internal ids.

        Args:
            ids: ids

        Returns:
            dict of id to list of internal ids
        """

        # Resolve ids using database if available, otherwise fallback to embeddings ids
        results = self.database.ids(ids) if self.database else self.scan(ids)

        # Create dict of id: [iids] given there is a one to many relationship
        resolved = {}
        for iid, uid in results:
            resolved.setdefault(uid, []).append(iid)

        return resolved

    def scan(self, ids):
        """
        Scans embeddings ids array for matches when content is disabled.

        Args:
            ids: search ids

        Returns:
            list of (internal id, id), grouped in the order of the input ids
        """

        # Input ids are read twice, once to build the lookup set and once to order results
        ids = list(ids)
        wanted = set(ids)

        # Single pass over the embeddings ids array, collecting matches per id
        matches = {}
        for index, value in enumerate(self.ids):
            if value in wanted:
                matches.setdefault(value, []).append((index, value))

        # Emit matches in input id order, repeated input ids repeat their matches
        indices = []
        for uid in ids:
            indices.extend(matches.get(uid, []))

        return indices
@@ -0,0 +1,69 @@
1
+ """
2
+ Query module
3
+ """
4
+
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
6
+
7
+
8
class Query:
    """
    Query translation backed by a sequence to sequence transformers model.
    """

    def __init__(self, path, prefix=None, maxlength=512):
        """
        Loads the tokenizer and model used for query translation.

        Args:
            path: path to query model
            prefix: text prepended to each query before it is passed to the model
            maxlength: max sequence length to generate
        """

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)

        # T5 models fall back to a default prefix when one isn't provided
        usedefault = not prefix and isinstance(self.model, T5ForConditionalGeneration)

        self.prefix = "translate English to SQL: " if usedefault else prefix
        self.maxlength = maxlength

    def __call__(self, query):
        """
        Translates a query with the model.

        Args:
            query: input query

        Returns:
            transformed query
        """

        # Prepend prefix when one is set
        text = f"{self.prefix}{query}" if self.prefix else query

        # Tokenize as a batch of one and generate output tokens
        inputs = self.tokenizer([text], return_tensors="pt")
        generated = self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=self.maxlength)

        # Decode the first generated sequence back to text
        decoded = self.tokenizer.decode(generated[0], skip_special_tokens=True)

        # Apply cleaning rules before returning
        return self.clean(decoded)

    def clean(self, text):
        """
        Cleans generated text by applying a set of replacement rules.

        Args:
            text: input text

        Returns:
            clean text
        """

        # Rewrite each "$=" in the generated text as "<="
        cleaned = text.replace("$=", "<=")
        return cleaned