mseep-txtai 9.1.1 (mseep_txtai-9.1.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/ann/sparse/ivfsparse.py
ADDED
@@ -0,0 +1,377 @@
"""
IVFSparse module
"""

import math
import os

from multiprocessing.pool import ThreadPool

import numpy as np

# Conditional import
try:
    from scipy.sparse import csr_matrix, vstack
    from scipy.sparse.linalg import norm
    from sklearn.cluster import MiniBatchKMeans
    from sklearn.metrics import pairwise_distances_argmin_min
    from sklearn.utils.extmath import safe_sparse_dot

    IVFSPARSE = True
except ImportError:
    IVFSPARSE = False

from ...serialize import SerializeFactory
from ...util import SparseArray
from ..base import ANN


class IVFSparse(ANN):
    """
    Inverted file (IVF) index with flat vector file storage and sparse array support.

    IVFSparse builds an IVF index and enables approximate nearest neighbor (ANN) search.

    This index is modeled after Faiss and supports many of the same parameters.

    See this link for more: https://github.com/facebookresearch/faiss/wiki/Faster-search
    """

    def __init__(self, config):
        super().__init__(config)

        if not IVFSPARSE:
            raise ImportError('IVFSparse is not available - install "ann" extra to enable')

        # Cluster centroids, if computed
        self.centroids = None

        # Cluster id mapping
        self.ids = None

        # Cluster data blocks - can be a single block with no computed centroids
        self.blocks = None

        # Deleted ids
        self.deletes = None

    def index(self, embeddings):
        # Compute model training size
        train, sample = embeddings, self.setting("sample")
        if sample:
            # Get sample for training
            rng = np.random.default_rng(0)
            indices = sorted(rng.choice(train.shape[0], int(sample * train.shape[0]), replace=False, shuffle=False))
            train = train[indices]

        # Get number of clusters. Note that final number of clusters could be lower due to filtering duplicate centroids
        # and pruning of small clusters
        clusters = self.nlist(embeddings.shape[0], train.shape[0])

        # Build cluster centroids if approximate search is enabled
        # A single cluster performs exact search
        self.centroids = self.build(train, clusters) if clusters > 1 else None

        # Sort into clusters
        ids = self.aggregate(embeddings)

        # Prune small clusters (less than minpoints parameter) and rebuild
        indices = sorted(k for k, v in ids.items() if len(v) >= self.minpoints())
        if len(indices) > 0 and len(ids) > 1 and len(indices) != len(ids.keys()):
            self.centroids = self.centroids[indices]
            ids = self.aggregate(embeddings)

        # Sort clusters by id
        self.ids = dict(sorted(ids.items(), key=lambda x: x[0]))

        # Create cluster data blocks
        self.blocks = {k: embeddings[v] for k, v in self.ids.items()}

        # Calculate block max summary vectors and use as centroids
        self.centroids = vstack([csr_matrix(x.max(axis=0)) for x in self.blocks.values()]) if self.centroids is not None else None

        # Initialize deletes
        self.deletes = []

        # Add id offset and index build metadata
        self.config["offset"] = embeddings.shape[0]
        self.metadata({"clusters": len(self.blocks)})

    def append(self, embeddings):
        # Get offset
        offset = self.size()

        # Sort into clusters and merge
        for cluster, ids in self.aggregate(embeddings).items():
            # Add new ids
            self.ids[cluster].extend([x + offset for x in ids])

            # Add new data
            self.blocks[cluster] = vstack([self.blocks[cluster], embeddings[ids]])

        # Update id offset and index metadata
        self.config["offset"] += embeddings.shape[0]
        self.metadata()

    def delete(self, ids):
        # Set index ids as deleted
        self.deletes.extend(ids)

    def search(self, queries, limit):
        results = []

        # Calculate number of threads using a thread batch size of 32
        threads = queries.shape[0] // 32
        threads = min(max(threads, 1), os.cpu_count())

        # Approximate search
        blockids = self.topn(queries, self.centroids, self.nprobe())[0] if self.centroids is not None else None

        # This method is able to run as multiple threads due to a number of numpy/scipy method calls that drop the GIL.
        results = []
        with ThreadPool(threads) as pool:
            for result in pool.starmap(self.scan, [(x, limit, blockids[i] if blockids is not None else None) for i, x in enumerate(queries)]):
                results.append(result)

        return results

    def count(self):
        return self.size() - len(self.deletes)

    def load(self, path):
        # Create streaming serializer and limit read size to a byte at a time to ensure
        # only msgpack data is consumed
        serializer = SerializeFactory.create("msgpack", streaming=True, read_size=1)

        with open(path, "rb") as f:
            # Read header
            unpacker = serializer.loadstream(f)
            header = next(unpacker)

            # Read cluster centroids, if available
            self.centroids = SparseArray().load(f) if header["centroids"] else None

            # Read cluster ids
            self.ids = dict(next(unpacker))

            # Read cluster data blocks
            self.blocks = {}
            for key in self.ids:
                self.blocks[key] = SparseArray().load(f)

            # Read deletes
            self.deletes = next(unpacker)

    def save(self, path):
        # IVFSparse storage format:
        # - header msgpack
        # - centroids sparse array (optional based on header parameters)
        # - cluster ids msgpack
        # - cluster data blocks list of sparse arrays
        # - deletes msgpack

        # Create message pack serializer
        serializer = SerializeFactory.create("msgpack")

        with open(path, "wb") as f:
            # Write header
            serializer.savestream({"centroids": self.centroids is not None, "count": self.count(), "blocks": len(self.blocks)}, f)

            # Write cluster centroids, if available
            if self.centroids is not None:
                SparseArray().save(f, self.centroids)

            # Write cluster id mapping
            serializer.savestream(list(self.ids.items()), f)

            # Write cluster data blocks
            for block in self.blocks.values():
                SparseArray().save(f, block)

            # Write deletes
            serializer.savestream(self.deletes, f)

    def build(self, train, clusters):
        """
        Builds a k-means cluster to calculate centroid points for aggregating data blocks.

        Args:
            train: training data
            clusters: number of clusters to create

        Returns:
            cluster centroids
        """

        # Select top n most important features that contribute to L2 vector norm
        indices = np.argsort(-norm(train, axis=0))[: self.setting("nfeatures", 25)]
        data = train[:, indices]
        data = train

        # Cluster data using k-means
        kmeans = MiniBatchKMeans(n_clusters=clusters, random_state=0, n_init=5).fit(data)

        # Find closest points to each cluster center and use those as centroids
        indices, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, data, metric="l2")

        # Filter out duplicate centroids and return cluster centroids
        return train[np.unique(indices)]

    def aggregate(self, data):
        """
        Aggregates input data array into clusters. This method sorts each data element into the
        cluster with the highest L2 similarity centroid.

        Args:
            data: input data

        Returns:
            {cluster, ids}
        """

        # Exact search when only a single cluster
        if self.centroids is None:
            return {0: list(range(data.shape[0]))}

        # Map data to closest centroids
        indices, _ = pairwise_distances_argmin_min(data, self.centroids, metric="l2")

        # Sort into clusters
        ids = {}
        for x, cluster in enumerate(indices.tolist()):
            if cluster not in ids:
                ids[cluster] = []

            # Save id
            ids[cluster].append(x)

        return ids

    def topn(self, queries, data, limit, deletes=None):
        """
        Gets the top n most similar data elements for query.

        Args:
            queries: queries array
            data: data array
            limit: top n
            deletes: optional list of deletes to filter from results

        Returns:
            list of matching (indices, scores)
        """

        # Dot product similarity
        scores = safe_sparse_dot(queries, data.T, dense_output=True)

        # Clear deletes
        if deletes is not None:
            scores[:, deletes] = 0

        # Get top n matching indices and scores
        indices = np.argpartition(-scores, limit if limit < scores.shape[0] else scores.shape[0] - 1)[:, :limit]
        scores = np.take_along_axis(scores, indices, axis=1)

        return indices, scores

    def scan(self, query, limit, blockids):
        """
        Scans a list of blocks for top n ids that match query.

        Args:
            query: input query
            limit: top n
            blockids: block ids to scan

        Returns:
            list of (id, scores)
        """

        if self.centroids is not None:
            # Stack into single ids list
            ids = np.concatenate([self.ids[x] for x in blockids if x in self.ids])

            # Stack data rows
            data = vstack([self.blocks[x] for x in blockids if x in self.blocks])
        else:
            # Exact search
            ids, data = np.array(self.ids[0]), self.blocks[0]

        # Get deletes
        deletes = np.argwhere(np.isin(ids, self.deletes)).ravel()

        # Calculate similarity
        indices, scores = self.topn(query, data, limit, deletes)
        indices, scores = indices[0], scores[0]

        # Map data ids and return
        return list(zip(ids[indices].tolist(), scores.tolist()))

    def nlist(self, count, train):
        """
        Calculates the number of clusters for this IVFSparse index. Note that the final number of clusters
        could be lower as duplicate cluster centroids are filtered out.

        Args:
            count: initial dataset size
            train: number of rows used to train

        Returns:
            number of clusters
        """

        # Get data size
        default = 1 if count <= 5000 else self.cells(train)

        # Number of clusters to create
        return self.setting("nlist", default)

    def nprobe(self):
        """
        Gets or derives the nprobe search parameter.

        Returns:
            nprobe setting
        """

        # Get size of embeddings index
        size = self.size()

        default = 6 if size <= 5000 else self.cells(size) // 16
        return self.setting("nprobe", default)

    def cells(self, count):
        """
        Calculates the number of IVF cells for an IVFSparse index.

        Args:
            count: number of rows

        Returns:
            number of IVF cells
        """

        # Calculate number of IVF cells where x = min(4 * sqrt(count), count / minpoints)
        return max(min(round(4 * math.sqrt(count)), int(count / self.minpoints())), 1)

    def size(self):
        """
        Gets the total size of this index including deletes.

        Returns:
            size
        """

        return sum(len(x) for x in self.ids.values())

    def minpoints(self):
        """
        Gets the minimum number of points per cluster.

        Returns:
            minimum points per cluster
        """

        # Minimum number of points per cluster
        # Match faiss default that requires at least 39 points per cluster
        return self.setting("minpoints", 39)
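For context, a minimal usage sketch of the index above follows. It is illustrative only: it assumes the "ann" extra (scipy/scikit-learn) is installed and that the base ANN class accepts a plain config dict, as suggested by the __init__ shown above; the data is random and the config values are placeholders.

# Illustrative sketch: build and query an IVFSparse index directly.
# Assumes the "ann" extra is installed and that a bare config dict is enough
# for the base ANN constructor; this is not part of the packaged code.
import numpy as np
from scipy.sparse import csr_matrix

from txtai.ann.sparse.ivfsparse import IVFSparse

# Random sparse term-weight vectors: 100 rows x 500 dimensions, ~5% non-zero
rng = np.random.default_rng(0)
dense = rng.random((100, 500))
dense[dense < 0.95] = 0
vectors = csr_matrix(dense)

# With only 100 rows, nlist defaults to 1, so a single exact-search block is built
ann = IVFSparse({"backend": "ivfsparse"})
ann.index(vectors)

# Search with the first row; each result is a list of (id, score) tuples
print(ann.search(vectors[:1], 3))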
txtai/ann/sparse/pgsparse.py
ADDED
@@ -0,0 +1,56 @@
"""
PGSparse module
"""

import os

import numpy as np

# Conditional import
try:
    from pgvector import SparseVector
    from pgvector.sqlalchemy import SPARSEVEC

    PGSPARSE = True
except ImportError:
    PGSPARSE = False

from ..dense import PGVector


class PGSparse(PGVector):
    """
    Builds a Sparse ANN index backed by a Postgres database.
    """

    def __init__(self, config):
        if not PGSPARSE:
            raise ImportError('PGSparse is not available - install "ann" extra to enable')

        super().__init__(config)

        # Quantization not supported
        self.qbits = None

    def defaulttable(self):
        return "svectors"

    def url(self):
        return self.setting("url", os.environ.get("SCORING_URL", os.environ.get("ANN_URL")))

    def column(self):
        return SPARSEVEC(self.config["dimensions"])

    def operation(self):
        return "sparsevec_ip_ops"

    def prepare(self, data):
        # pgvector only allows 1000 non-zero values for sparse vectors
        # Trim to top 1000 values, if necessary
        if data.count_nonzero() > 1000:
            value = -np.sort(-data[0, :].data)[1000]
            data.data = np.where(data.data > value, data.data, 0)
            data.eliminate_zeros()

        # Wrap as sparse vector
        return SparseVector(data)
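The trimming rule in prepare() above can be illustrated standalone with scipy only; this hypothetical snippet omits the database and the pgvector SparseVector wrapping and exists purely to show the top-1000 cutoff.

# Illustrative sketch of the top-1000 trimming rule from PGSparse.prepare(),
# using scipy only (no Postgres, no pgvector); values are synthetic.
import numpy as np
from scipy.sparse import random as sparse_random

# A single sparse row with roughly 2500 non-zero values
row = sparse_random(1, 5000, density=0.5, format="csr", random_state=0)

if row.count_nonzero() > 1000:
    # Value of the 1001st largest entry; entries at or below it are dropped
    value = -np.sort(-row[0, :].data)[1000]
    row.data = np.where(row.data > value, row.data, 0)
    row.eliminate_zeros()

print(row.count_nonzero())  # at most 1000 non-zero values remain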
txtai/api/__init__.py
ADDED
@@ -0,0 +1,18 @@
"""
API imports
"""

# Conditional import
try:
    from .authorization import Authorization
    from .application import app, start
    from .base import API
    from .cluster import Cluster
    from .extension import Extension
    from .factory import APIFactory
    from .responses import *
    from .routers import *
    from .route import EncodingAPIRoute
except ImportError as missing:
    # pylint: disable=W0707
    raise ImportError('API is not available - install "api" extra to enable') from missing
txtai/api/application.py
ADDED
@@ -0,0 +1,134 @@
"""
FastAPI application module
"""

import inspect
import os
import sys

from fastapi import APIRouter, Depends, FastAPI
from fastapi_mcp import FastApiMCP
from httpx import AsyncClient

from .authorization import Authorization
from .base import API
from .factory import APIFactory

from ..app import Application


def get():
    """
    Returns global API instance.

    Returns:
        API instance
    """

    return INSTANCE


def create():
    """
    Creates a FastAPI instance.
    """

    # Application dependencies
    dependencies = []

    # Default implementation of token authorization
    token = os.environ.get("TOKEN")
    if token:
        dependencies.append(Depends(Authorization(token)))

    # Add custom dependencies
    deps = os.environ.get("DEPENDENCIES")
    if deps:
        for dep in deps.split(","):
            # Create and add dependency
            dep = APIFactory.get(dep.strip())()
            dependencies.append(Depends(dep))

    # Create FastAPI application
    return FastAPI(lifespan=lifespan, dependencies=dependencies if dependencies else None)


def apirouters():
    """
    Lists available APIRouters.

    Returns:
        {router name: router}
    """

    # Get handle to api module
    api = sys.modules[".".join(__name__.split(".")[:-1])]

    available = {}
    for name, rclass in inspect.getmembers(api, inspect.ismodule):
        if hasattr(rclass, "router") and isinstance(rclass.router, APIRouter):
            available[name.lower()] = rclass.router

    return available


def lifespan(application):
    """
    FastAPI lifespan event handler.

    Args:
        application: FastAPI application to initialize
    """

    # pylint: disable=W0603
    global INSTANCE

    # Load YAML settings
    config = Application.read(os.environ.get("CONFIG"))

    # Instantiate API instance
    api = os.environ.get("API_CLASS")
    INSTANCE = APIFactory.create(config, api) if api else API(config)

    # Get all known routers
    routers = apirouters()

    # Conditionally add routes based on configuration
    for name, router in routers.items():
        if name in config:
            application.include_router(router)

    # Special case for embeddings clusters
    if "cluster" in config and "embeddings" not in config:
        application.include_router(routers["embeddings"])

    # Special case to add similarity instance for embeddings
    if "embeddings" in config and "similarity" not in config:
        application.include_router(routers["similarity"])

    # Execute extensions if present
    extensions = os.environ.get("EXTENSIONS")
    if extensions:
        for extension in extensions.split(","):
            # Create instance and execute extension
            extension = APIFactory.get(extension.strip())()
            extension(application)

    # Add Model Context Protocol (MCP) service, if applicable
    if config.get("mcp"):
        mcp = FastApiMCP(application, http_client=AsyncClient(timeout=100))
        mcp.mount()

    yield


def start():
    """
    Runs application lifespan handler.
    """

    list(lifespan(app))


# FastAPI instance txtai API instances
app, INSTANCE = create(), None
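As a usage sketch of the module above: the lifespan handler reads YAML settings from the CONFIG environment variable and mounts a router for each configured component, and create() wires in the token Authorization dependency when TOKEN is set. The file name, values and server choice (uvicorn) below are illustrative assumptions, not part of the package.

# Illustrative sketch: serve the txtai API with uvicorn (an assumed ASGI server).
# CONFIG points at a YAML settings file read in lifespan(); TOKEN optionally
# enables the Authorization dependency created in create().
import os

import uvicorn

os.environ["CONFIG"] = "config.yml"                       # placeholder path
# os.environ["TOKEN"] = "<sha-256 hash of the API token>"  # optional

# txtai.api exposes the module-level FastAPI instance as "app"
uvicorn.run("txtai.api:app", host="0.0.0.0", port=8000)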
txtai/api/authorization.py
ADDED
@@ -0,0 +1,53 @@
"""
Authorization module
"""

import hashlib
import os

from fastapi import Header, HTTPException


class Authorization:
    """
    Basic token authorization.
    """

    def __init__(self, token=None):
        """
        Creates a new Authorization instance.

        Args:
            token: SHA-256 hash of token to check
        """

        self.token = token if token else os.environ.get("TOKEN")

    def __call__(self, authorization: str = Header(default=None)):
        """
        Validates authorization header is present and equal to current token.

        Args:
            authorization: authorization header
        """

        if not authorization or self.token != self.digest(authorization):
            raise HTTPException(status_code=401, detail="Invalid Authorization Token")

    def digest(self, authorization):
        """
        Computes a SHA-256 hash for input authorization token.

        Args:
            authorization: authorization header

        Returns:
            SHA-256 hash of authorization token
        """

        # Replace Bearer prefix
        prefix = "Bearer "
        token = authorization[len(prefix) :] if authorization.startswith(prefix) else authorization

        # Compute SHA-256 hash
        return hashlib.sha256(token.encode("utf-8")).hexdigest()
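To round out the Authorization class, a hypothetical client-side sketch: the server-side TOKEN value holds the SHA-256 hash, while clients send the raw token as a Bearer header that digest() hashes and compares. The token string, host, endpoint and the use of requests are illustrative assumptions.

# Illustrative sketch: derive the server TOKEN value and call a protected endpoint.
# The token string, URL and endpoint path are placeholder assumptions.
import hashlib

import requests

token = "my-secret-token"

# Value to export as TOKEN on the server (SHA-256 hash of the raw token)
print(hashlib.sha256(token.encode("utf-8")).hexdigest())

# Clients send the raw token; Authorization.digest() hashes and compares it
response = requests.get(
    "http://localhost:8000/search",
    params={"query": "test"},
    headers={"Authorization": f"Bearer {token}"},
)
print(response.status_code)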