mseep-txtai 9.1.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251)
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
txtai/ann/sparse/ivfsparse.py ADDED
@@ -0,0 +1,377 @@
+ """
+ IVFSparse module
+ """
+
+ import math
+ import os
+
+ from multiprocessing.pool import ThreadPool
+
+ import numpy as np
+
+ # Conditional import
+ try:
+     from scipy.sparse import csr_matrix, vstack
+     from scipy.sparse.linalg import norm
+     from sklearn.cluster import MiniBatchKMeans
+     from sklearn.metrics import pairwise_distances_argmin_min
+     from sklearn.utils.extmath import safe_sparse_dot
+
+     IVFSPARSE = True
+ except ImportError:
+     IVFSPARSE = False
+
+ from ...serialize import SerializeFactory
+ from ...util import SparseArray
+ from ..base import ANN
+
+
+ class IVFSparse(ANN):
+     """
+     Inverted file (IVF) index with flat vector file storage and sparse array support.
+
+     IVFSparse builds an IVF index and enables approximate nearest neighbor (ANN) search.
+
+     This index is modeled after Faiss and supports many of the same parameters.
+
+     See this link for more: https://github.com/facebookresearch/faiss/wiki/Faster-search
+     """
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         if not IVFSPARSE:
+             raise ImportError('IVFSparse is not available - install "ann" extra to enable')
+
+         # Cluster centroids, if computed
+         self.centroids = None
+
+         # Cluster id mapping
+         self.ids = None
+
+         # Cluster data blocks - can be a single block with no computed centroids
+         self.blocks = None
+
+         # Deleted ids
+         self.deletes = None
+
+     def index(self, embeddings):
+         # Compute model training size
+         train, sample = embeddings, self.setting("sample")
+         if sample:
+             # Get sample for training
+             rng = np.random.default_rng(0)
+             indices = sorted(rng.choice(train.shape[0], int(sample * train.shape[0]), replace=False, shuffle=False))
+             train = train[indices]
+
+         # Get number of clusters. Note that final number of clusters could be lower due to filtering duplicate centroids
+         # and pruning of small clusters
+         clusters = self.nlist(embeddings.shape[0], train.shape[0])
+
+         # Build cluster centroids if approximate search is enabled
+         # A single cluster performs exact search
+         self.centroids = self.build(train, clusters) if clusters > 1 else None
+
+         # Sort into clusters
+         ids = self.aggregate(embeddings)
+
+         # Prune small clusters (less than minpoints parameter) and rebuild
+         indices = sorted(k for k, v in ids.items() if len(v) >= self.minpoints())
+         if len(indices) > 0 and len(ids) > 1 and len(indices) != len(ids.keys()):
+             self.centroids = self.centroids[indices]
+             ids = self.aggregate(embeddings)
+
+         # Sort clusters by id
+         self.ids = dict(sorted(ids.items(), key=lambda x: x[0]))
+
+         # Create cluster data blocks
+         self.blocks = {k: embeddings[v] for k, v in self.ids.items()}
+
+         # Calculate block max summary vectors and use as centroids
+         self.centroids = vstack([csr_matrix(x.max(axis=0)) for x in self.blocks.values()]) if self.centroids is not None else None
+
+         # Initialize deletes
+         self.deletes = []
+
+         # Add id offset and index build metadata
+         self.config["offset"] = embeddings.shape[0]
+         self.metadata({"clusters": len(self.blocks)})
+
+     def append(self, embeddings):
+         # Get offset
+         offset = self.size()
+
+         # Sort into clusters and merge
+         for cluster, ids in self.aggregate(embeddings).items():
+             # Add new ids
+             self.ids[cluster].extend([x + offset for x in ids])
+
+             # Add new data
+             self.blocks[cluster] = vstack([self.blocks[cluster], embeddings[ids]])
+
+         # Update id offset and index metadata
+         self.config["offset"] += embeddings.shape[0]
+         self.metadata()
+
+     def delete(self, ids):
+         # Set index ids as deleted
+         self.deletes.extend(ids)
+
+     def search(self, queries, limit):
+         results = []
+
+         # Calculate number of threads using a thread batch size of 32
+         threads = queries.shape[0] // 32
+         threads = min(max(threads, 1), os.cpu_count())
+
+         # Approximate search
+         blockids = self.topn(queries, self.centroids, self.nprobe())[0] if self.centroids is not None else None
+
+         # This method is able to run as multiple threads due to a number of numpy/scipy method calls that drop the GIL.
+         results = []
+         with ThreadPool(threads) as pool:
+             for result in pool.starmap(self.scan, [(x, limit, blockids[i] if blockids is not None else None) for i, x in enumerate(queries)]):
+                 results.append(result)
+
+         return results
+
+     def count(self):
+         return self.size() - len(self.deletes)
+
+     def load(self, path):
+         # Create streaming serializer and limit read size to a byte at a time to ensure
+         # only msgpack data is consumed
+         serializer = SerializeFactory.create("msgpack", streaming=True, read_size=1)
+
+         with open(path, "rb") as f:
+             # Read header
+             unpacker = serializer.loadstream(f)
+             header = next(unpacker)
+
+             # Read cluster centroids, if available
+             self.centroids = SparseArray().load(f) if header["centroids"] else None
+
+             # Read cluster ids
+             self.ids = dict(next(unpacker))
+
+             # Read cluster data blocks
+             self.blocks = {}
+             for key in self.ids:
+                 self.blocks[key] = SparseArray().load(f)
+
+             # Read deletes
+             self.deletes = next(unpacker)
+
+     def save(self, path):
+         # IVFSparse storage format:
+         # - header msgpack
+         # - centroids sparse array (optional based on header parameters)
+         # - cluster ids msgpack
+         # - cluster data blocks list of sparse arrays
+         # - deletes msgpack
+
+         # Create message pack serializer
+         serializer = SerializeFactory.create("msgpack")
+
+         with open(path, "wb") as f:
+             # Write header
+             serializer.savestream({"centroids": self.centroids is not None, "count": self.count(), "blocks": len(self.blocks)}, f)
+
+             # Write cluster centroids, if available
+             if self.centroids is not None:
+                 SparseArray().save(f, self.centroids)
+
+             # Write cluster id mapping
+             serializer.savestream(list(self.ids.items()), f)
+
+             # Write cluster data blocks
+             for block in self.blocks.values():
+                 SparseArray().save(f, block)
+
+             # Write deletes
+             serializer.savestream(self.deletes, f)
+
+     def build(self, train, clusters):
+         """
+         Builds a k-means cluster to calculate centroid points for aggregating data blocks.
+
+         Args:
+             train: training data
+             clusters: number of clusters to create
+
+         Returns:
+             cluster centroids
+         """
+
+         # Select top n most important features that contribute to L2 vector norm
+         indices = np.argsort(-norm(train, axis=0))[: self.setting("nfeatures", 25)]
+         data = train[:, indices]
+         data = train
+
+         # Cluster data using k-means
+         kmeans = MiniBatchKMeans(n_clusters=clusters, random_state=0, n_init=5).fit(data)
+
+         # Find closest points to each cluster center and use those as centroids
+         indices, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, data, metric="l2")
+
+         # Filter out duplicate centroids and return cluster centroids
+         return train[np.unique(indices)]
+
+     def aggregate(self, data):
+         """
+         Aggregates input data array into clusters. This method sorts each data element into the
+         cluster with the highest L2 similarity centroid.
+
+         Args:
+             data: input data
+
+         Returns:
+             {cluster, ids}
+         """
+
+         # Exact search when only a single cluster
+         if self.centroids is None:
+             return {0: list(range(data.shape[0]))}
+
+         # Map data to closest centroids
+         indices, _ = pairwise_distances_argmin_min(data, self.centroids, metric="l2")
+
+         # Sort into clusters
+         ids = {}
+         for x, cluster in enumerate(indices.tolist()):
+             if cluster not in ids:
+                 ids[cluster] = []
+
+             # Save id
+             ids[cluster].append(x)
+
+         return ids
+
+     def topn(self, queries, data, limit, deletes=None):
+         """
+         Gets the top n most similar data elements for query.
+
+         Args:
+             queries: queries array
+             data: data array
+             limit: top n
+             deletes: optional list of deletes to filter from results
+
+         Returns:
+             list of matching (indices, scores)
+         """
+
+         # Dot product similarity
+         scores = safe_sparse_dot(queries, data.T, dense_output=True)
+
+         # Clear deletes
+         if deletes is not None:
+             scores[:, deletes] = 0
+
+         # Get top n matching indices and scores
+         indices = np.argpartition(-scores, limit if limit < scores.shape[0] else scores.shape[0] - 1)[:, :limit]
+         scores = np.take_along_axis(scores, indices, axis=1)
+
+         return indices, scores
+
+     def scan(self, query, limit, blockids):
+         """
+         Scans a list of blocks for top n ids that match query.
+
+         Args:
+             query: input query
+             limit: top n
+             blockids: block ids to scan
+
+         Returns:
+             list of (id, scores)
+         """
+
+         if self.centroids is not None:
+             # Stack into single ids list
+             ids = np.concatenate([self.ids[x] for x in blockids if x in self.ids])
+
+             # Stack data rows
+             data = vstack([self.blocks[x] for x in blockids if x in self.blocks])
+         else:
+             # Exact search
+             ids, data = np.array(self.ids[0]), self.blocks[0]
+
+         # Get deletes
+         deletes = np.argwhere(np.isin(ids, self.deletes)).ravel()
+
+         # Calculate similarity
+         indices, scores = self.topn(query, data, limit, deletes)
+         indices, scores = indices[0], scores[0]
+
+         # Map data ids and return
+         return list(zip(ids[indices].tolist(), scores.tolist()))
+
+     def nlist(self, count, train):
+         """
+         Calculates the number of clusters for this IVFSparse index. Note that the final number of clusters
+         could be lower as duplicate cluster centroids are filtered out.
+
+         Args:
+             count: initial dataset size
+             train: number of rows used to train
+
+         Returns:
+             number of clusters
+         """
+
+         # Get data size
+         default = 1 if count <= 5000 else self.cells(train)
+
+         # Number of clusters to create
+         return self.setting("nlist", default)
+
+     def nprobe(self):
+         """
+         Gets or derives the nprobe search parameter.
+
+         Returns:
+             nprobe setting
+         """
+
+         # Get size of embeddings index
+         size = self.size()
+
+         default = 6 if size <= 5000 else self.cells(size) // 16
+         return self.setting("nprobe", default)
+
+     def cells(self, count):
+         """
+         Calculates the number of IVF cells for an IVFSparse index.
+
+         Args:
+             count: number of rows
+
+         Returns:
+             number of IVF cells
+         """
+
+         # Calculate number of IVF cells where x = min(4 * sqrt(count), count / minpoints)
+         return max(min(round(4 * math.sqrt(count)), int(count / self.minpoints())), 1)
+
+     def size(self):
+         """
+         Gets the total size of this index including deletes.
+
+         Returns:
+             size
+         """
+
+         return sum(len(x) for x in self.ids.values())
+
+     def minpoints(self):
+         """
+         Gets the minimum number of points per cluster.
+
+         Returns:
+             minimum points per cluster
+         """
+
+         # Minimum number of points per cluster
+         # Match faiss default that requires at least 39 points per cluster
+         return self.setting("minpoints", 39)
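
For orientation, here is a minimal usage sketch of IVFSparse (not part of the package). It assumes the "ann" extra (scipy, scikit-learn) is installed, that txtai.ann.sparse exports IVFSparse, and that the ANN base class accepts a plain config dict; the nlist/nprobe values are illustrative.

    from scipy.sparse import random as sparse_random

    from txtai.ann.sparse import IVFSparse

    # Random CSR matrix standing in for sparse embeddings (1000 rows x 5000 columns)
    embeddings = sparse_random(1000, 5000, density=0.01, format="csr", random_state=0)

    # nlist/nprobe/minpoints/sample are optional - nlist(), nprobe() and minpoints() derive defaults
    ann = IVFSparse({"dimensions": embeddings.shape[1], "nlist": 16, "nprobe": 4})
    ann.index(embeddings)

    # Top 3 (id, score) matches per query, using the first two rows as queries
    print(ann.search(embeddings[:2], 3))
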
txtai/ann/sparse/pgsparse.py ADDED
@@ -0,0 +1,56 @@
+ """
+ PGSparse module
+ """
+
+ import os
+
+ import numpy as np
+
+ # Conditional import
+ try:
+     from pgvector import SparseVector
+     from pgvector.sqlalchemy import SPARSEVEC
+
+     PGSPARSE = True
+ except ImportError:
+     PGSPARSE = False
+
+ from ..dense import PGVector
+
+
+ class PGSparse(PGVector):
+     """
+     Builds a Sparse ANN index backed by a Postgres database.
+     """
+
+     def __init__(self, config):
+         if not PGSPARSE:
+             raise ImportError('PGSparse is not available - install "ann" extra to enable')
+
+         super().__init__(config)
+
+         # Quantization not supported
+         self.qbits = None
+
+     def defaulttable(self):
+         return "svectors"
+
+     def url(self):
+         return self.setting("url", os.environ.get("SCORING_URL", os.environ.get("ANN_URL")))
+
+     def column(self):
+         return SPARSEVEC(self.config["dimensions"])
+
+     def operation(self):
+         return "sparsevec_ip_ops"
+
+     def prepare(self, data):
+         # pgvector only allows 1000 non-zero values for sparse vectors
+         # Trim to top 1000 values, if necessary
+         if data.count_nonzero() > 1000:
+             value = -np.sort(-data[0, :].data)[1000]
+             data.data = np.where(data.data > value, data.data, 0)
+             data.eliminate_zeros()
+
+         # Wrap as sparse vector
+         return SparseVector(data)
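
The 1000 non-zero value cap enforced in prepare() above can be illustrated standalone (scipy assumed, no database involved): the threshold is the 1001st largest value and only strictly larger values survive.

    import numpy as np
    from scipy.sparse import random as sparse_random

    # Row vector with roughly 2500 non-zero values
    data = sparse_random(1, 5000, density=0.5, format="csr", random_state=0)

    if data.count_nonzero() > 1000:
        # Value of the 1001st largest entry; keep only strictly larger values
        value = -np.sort(-data[0, :].data)[1000]
        data.data = np.where(data.data > value, data.data, 0)
        data.eliminate_zeros()

    print(data.count_nonzero())  # at most 1000 (barring ties)
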
txtai/api/__init__.py ADDED
@@ -0,0 +1,18 @@
+ """
+ API imports
+ """
+
+ # Conditional import
+ try:
+     from .authorization import Authorization
+     from .application import app, start
+     from .base import API
+     from .cluster import Cluster
+     from .extension import Extension
+     from .factory import APIFactory
+     from .responses import *
+     from .routers import *
+     from .route import EncodingAPIRoute
+ except ImportError as missing:
+     # pylint: disable=W0707
+     raise ImportError('API is not available - install "api" extra to enable') from missing
txtai/api/application.py ADDED
@@ -0,0 +1,134 @@
+ """
+ FastAPI application module
+ """
+
+ import inspect
+ import os
+ import sys
+
+ from fastapi import APIRouter, Depends, FastAPI
+ from fastapi_mcp import FastApiMCP
+ from httpx import AsyncClient
+
+ from .authorization import Authorization
+ from .base import API
+ from .factory import APIFactory
+
+ from ..app import Application
+
+
+ def get():
+     """
+     Returns global API instance.
+
+     Returns:
+         API instance
+     """
+
+     return INSTANCE
+
+
+ def create():
+     """
+     Creates a FastAPI instance.
+     """
+
+     # Application dependencies
+     dependencies = []
+
+     # Default implementation of token authorization
+     token = os.environ.get("TOKEN")
+     if token:
+         dependencies.append(Depends(Authorization(token)))
+
+     # Add custom dependencies
+     deps = os.environ.get("DEPENDENCIES")
+     if deps:
+         for dep in deps.split(","):
+             # Create and add dependency
+             dep = APIFactory.get(dep.strip())()
+             dependencies.append(Depends(dep))
+
+     # Create FastAPI application
+     return FastAPI(lifespan=lifespan, dependencies=dependencies if dependencies else None)
+
+
+ def apirouters():
+     """
+     Lists available APIRouters.
+
+     Returns:
+         {router name: router}
+     """
+
+     # Get handle to api module
+     api = sys.modules[".".join(__name__.split(".")[:-1])]
+
+     available = {}
+     for name, rclass in inspect.getmembers(api, inspect.ismodule):
+         if hasattr(rclass, "router") and isinstance(rclass.router, APIRouter):
+             available[name.lower()] = rclass.router
+
+     return available
+
+
+ def lifespan(application):
+     """
+     FastAPI lifespan event handler.
+
+     Args:
+         application: FastAPI application to initialize
+     """
+
+     # pylint: disable=W0603
+     global INSTANCE
+
+     # Load YAML settings
+     config = Application.read(os.environ.get("CONFIG"))
+
+     # Instantiate API instance
+     api = os.environ.get("API_CLASS")
+     INSTANCE = APIFactory.create(config, api) if api else API(config)
+
+     # Get all known routers
+     routers = apirouters()
+
+     # Conditionally add routes based on configuration
+     for name, router in routers.items():
+         if name in config:
+             application.include_router(router)
+
+     # Special case for embeddings clusters
+     if "cluster" in config and "embeddings" not in config:
+         application.include_router(routers["embeddings"])
+
+     # Special case to add similarity instance for embeddings
+     if "embeddings" in config and "similarity" not in config:
+         application.include_router(routers["similarity"])
+
+     # Execute extensions if present
+     extensions = os.environ.get("EXTENSIONS")
+     if extensions:
+         for extension in extensions.split(","):
+             # Create instance and execute extension
+             extension = APIFactory.get(extension.strip())()
+             extension(application)
+
+     # Add Model Context Protocol (MCP) service, if applicable
+     if config.get("mcp"):
+         mcp = FastApiMCP(application, http_client=AsyncClient(timeout=100))
+         mcp.mount()
+
+     yield
+
+
+ def start():
+     """
+     Runs application lifespan handler.
+     """
+
+     list(lifespan(app))
+
+
+ # FastAPI instance and txtai API instance
+ app, INSTANCE = create(), None
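
A minimal sketch of serving this application (not from the package; it assumes uvicorn is installed and a config.yml exists). CONFIG points lifespan() at the YAML settings and TOKEN, if set, activates the Authorization dependency added in create().

    import os

    os.environ["CONFIG"] = "config.yml"             # hypothetical YAML file, e.g. with an "embeddings" section
    os.environ["TOKEN"] = "<sha-256 of api token>"  # optional, see authorization.py below

    import uvicorn

    # Importing txtai.api builds the FastAPI app; lifespan() runs on startup
    uvicorn.run("txtai.api:app", host="0.0.0.0", port=8000)
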
txtai/api/authorization.py ADDED
@@ -0,0 +1,53 @@
+ """
+ Authorization module
+ """
+
+ import hashlib
+ import os
+
+ from fastapi import Header, HTTPException
+
+
+ class Authorization:
+     """
+     Basic token authorization.
+     """
+
+     def __init__(self, token=None):
+         """
+         Creates a new Authorization instance.
+
+         Args:
+             token: SHA-256 hash of token to check
+         """
+
+         self.token = token if token else os.environ.get("TOKEN")
+
+     def __call__(self, authorization: str = Header(default=None)):
+         """
+         Validates authorization header is present and equal to current token.
+
+         Args:
+             authorization: authorization header
+         """
+
+         if not authorization or self.token != self.digest(authorization):
+             raise HTTPException(status_code=401, detail="Invalid Authorization Token")
+
+     def digest(self, authorization):
+         """
+         Computes a SHA-256 hash for input authorization token.
+
+         Args:
+             authorization: authorization header
+
+         Returns:
+             SHA-256 hash of authorization token
+         """
+
+         # Replace Bearer prefix
+         prefix = "Bearer "
+         token = authorization[len(prefix) :] if authorization.startswith(prefix) else authorization
+
+         # Compute SHA-256 hash
+         return hashlib.sha256(token.encode("utf-8")).hexdigest()
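
To tie this together, a small sketch (not from the package) of how a deployment might use it: the server-side TOKEN environment variable (or the token passed to Authorization) holds the SHA-256 hash, while clients send the plaintext token as a Bearer header.

    import hashlib

    plaintext = "my-secret-token"  # hypothetical client token

    # Value to place in the TOKEN environment variable
    print(hashlib.sha256(plaintext.encode("utf-8")).hexdigest())

    # Client side (requests assumed), against whatever routes the API config enables:
    # import requests
    # requests.get("http://localhost:8000/search",
    #              params={"query": "test"},
    #              headers={"Authorization": f"Bearer {plaintext}"})
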