mseep-txtai 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
@@ -0,0 +1,303 @@
|
|
1
|
+
"""
|
2
|
+
SQLite module
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
import sqlite3
|
7
|
+
|
8
|
+
# Conditional import
|
9
|
+
try:
|
10
|
+
import sqlite_vec
|
11
|
+
|
12
|
+
SQLITEVEC = True
|
13
|
+
except ImportError:
|
14
|
+
SQLITEVEC = False
|
15
|
+
|
16
|
+
from ..base import ANN
|
17
|
+
|
18
|
+
|
19
|
+
class SQLite(ANN):
    """
    Builds an ANN index backed by a SQLite database.

    Vectors are stored in a sqlite-vec `vec0` virtual table keyed by an integer indexid. Optional
    quantization stores vectors as bit vectors (quantize=1) or int8 vectors (quantize=8) instead of
    float32 vectors.
    """

    def __init__(self, config):
        """
        Creates a new SQLite ANN index.

        Args:
            config: index configuration parameters

        Raises:
            ImportError: if the sqlite-vec package is not installed
        """

        super().__init__(config)

        if not SQLITEVEC:
            raise ImportError('sqlite-vec is not available - install "ann" extra to enable')

        # Database parameters. The connection is opened lazily by database().
        self.connection, self.cursor, self.path = None, None, ""

        # Quantization setting. True selects int8 (8); other truthy values are converted with int()
        # (1 = binary, 8 = int8, anything else = float32, see tablesql). Falsy values disable
        # quantization. The `is True` check keeps an explicit False from enabling int8 quantization.
        quantize = self.setting("quantize")
        if quantize is True:
            self.quantize = 8
        elif quantize:
            self.quantize = int(quantize)
        else:
            self.quantize = None

    def load(self, path):
        """
        Sets the database path. The connection is opened on first use by database().

        Args:
            path: path to database file
        """

        self.path = path

    def index(self, embeddings):
        """
        Builds a new index, clearing any existing vectors first.

        Args:
            embeddings: array of embeddings, one vector per row; row i is stored with indexid i
        """

        # Initialize tables
        self.initialize(recreate=True)

        # Add vectors
        self.database().executemany(self.insertsql(), enumerate(embeddings))

        # Add id offset and index build metadata
        self.config["offset"] = embeddings.shape[0]
        self.metadata(self.settings())

    def append(self, embeddings):
        """
        Adds vectors to an existing index. Ids continue from config["offset"].

        Args:
            embeddings: array of embeddings, one vector per row
        """

        offset = self.config["offset"]
        self.database().executemany(self.insertsql(), [(x + offset, row) for x, row in enumerate(embeddings)])

        self.config["offset"] += embeddings.shape[0]
        self.metadata()

    def delete(self, ids):
        """
        Deletes vectors by indexid.

        Args:
            ids: iterable of indexids to delete
        """

        self.database().executemany(self.deletesql(), [(x,) for x in ids])

    def search(self, queries, limit):
        """
        Runs a k-nearest-neighbor search for each query.

        Args:
            queries: iterable of query vectors
            limit: maximum number of results per query

        Returns:
            list with one list of (indexid, score) rows per query, where score is 1 - distance
        """

        results = []
        for query in queries:
            # Execute query
            self.database().execute(self.searchsql(), [query, limit])

            # Add query results
            results.append(list(self.database()))

        return results

    def count(self):
        """
        Counts the number of vectors in the index.

        Returns:
            number of rows in the vectors table
        """

        self.database().execute(self.countsql())
        return self.cursor.fetchone()[0]

    def save(self, path):
        """
        Saves the index to path.

        Args:
            path: target database path
        """

        # Temporary database
        if not self.path:
            # Save temporary database
            self.connection.commit()

            # Copy data from current to new
            connection = self.copy(path)

            # Close temporary database
            self.connection.close()

            # Point connection to new connection
            self.connection = connection
            self.cursor = self.connection.cursor()
            self.path = path

        # Paths are equal, commit changes
        elif self.path == path:
            self.connection.commit()

        # New path is different from current path, copy data and continue using current connection
        else:
            self.copy(path).close()

    def close(self):
        """
        Closes the index and its database connection, if open.
        """

        # Parent logic
        super().close()

        # Close database connection and drop the cursor bound to it
        if self.connection:
            self.connection.close()
            self.connection = None
            self.cursor = None

    def initialize(self, recreate=False):
        """
        Initializes a new database session.

        Args:
            recreate: Recreates the database tables if True
        """

        # Create table
        self.database().execute(self.tablesql())

        # Clear data
        if recreate:
            self.database().execute(self.tosql("DELETE FROM {table}"))

    def settings(self):
        """
        Returns settings for this index.

        Returns:
            dict with the SQLite and sqlite-vec versions
        """

        sqlite, sqlitevec = self.database().execute("SELECT sqlite_version(), vec_version()").fetchone()

        return {"sqlite": sqlite, "sqlite-vec": sqlitevec}

    def database(self):
        """
        Gets the current database cursor. Creates a new connection
        if there isn't one.

        Returns:
            cursor
        """

        if not self.connection:
            self.connection = self.connect(self.path)
            self.cursor = self.connection.cursor()

        return self.cursor

    def connect(self, path):
        """
        Creates a new database connection with the sqlite-vec extension loaded.

        Args:
            path: path to database file

        Returns:
            database connection
        """

        # Create connection
        connection = sqlite3.connect(path, check_same_thread=False)

        # Load sqlite-vec extension. Extension loading is disabled again even if loading fails.
        connection.enable_load_extension(True)
        try:
            sqlite_vec.load(connection)
        finally:
            connection.enable_load_extension(False)

        return connection

    def copy(self, path):
        """
        Copies content from the current database into target.

        Args:
            path: target database path

        Returns:
            new database connection
        """

        # Delete existing file, if necessary
        if os.path.exists(path):
            os.remove(path)

        # Create new connection
        connection = self.connect(path)

        if self.connection.in_transaction:
            # Initialize connection
            connection.execute(self.tablesql())

            # The backup call will hang if there are uncommitted changes, need to copy over
            # with iterdump (which is much slower)
            insert = self.tosql('insert into "{table}"')
            for sql in self.connection.iterdump():
                if insert in sql.lower():
                    connection.execute(sql)

            # Commit copied rows. Callers may close the returned connection right away
            # (save does for a different path), and closing without a commit discards the inserts.
            connection.commit()
        else:
            # Database is up to date, can do a more efficient copy with SQLite C API
            self.connection.backup(connection)

        return connection

    def tablesql(self):
        """
        Builds a CREATE table statement for table.

        Returns:
            CREATE TABLE
        """

        # Binary quantization
        if self.quantize == 1:
            embedding = f"embedding BIT[{self.config['dimensions']}]"

        # INT8 quantization
        elif self.quantize == 8:
            embedding = f"embedding INT8[{self.config['dimensions']}] distance=cosine"

        # Standard FLOAT32
        else:
            embedding = f"embedding FLOAT[{self.config['dimensions']}] distance=cosine"

        # Return CREATE TABLE sql
        return self.tosql(("CREATE VIRTUAL TABLE IF NOT EXISTS {table} USING vec0" "(indexid INTEGER PRIMARY KEY, " f"{embedding})"))

    def insertsql(self):
        """
        Creates an INSERT SQL statement.

        Returns:
            INSERT
        """

        return self.tosql(f"INSERT INTO {{table}}(indexid, embedding) VALUES (?, {self.embeddingsql()})")

    def deletesql(self):
        """
        Creates a DELETE SQL statement.

        Returns:
            DELETE
        """

        return self.tosql("DELETE FROM {table} WHERE indexid = ?")

    def searchsql(self):
        """
        Creates a SELECT SQL statement for search.

        Returns:
            SELECT
        """

        return self.tosql(("SELECT indexid, 1 - distance FROM {table} " f"WHERE embedding MATCH {self.embeddingsql()} AND k = ? ORDER BY distance"))

    def countsql(self):
        """
        Creates a SELECT COUNT statement.

        Returns:
            SELECT COUNT
        """

        return self.tosql("SELECT count(indexid) FROM {table}")

    def embeddingsql(self):
        """
        Creates an embeddings column SQL snippet.

        Returns:
            embeddings column SQL
        """

        # Binary quantization
        if self.quantize == 1:
            embedding = "vec_quantize_binary(?)"

        # INT8 quantization
        elif self.quantize == 8:
            embedding = "vec_quantize_int8(?, 'unit')"

        # Standard FLOAT32
        else:
            embedding = "?"

        return embedding

    def tosql(self, sql):
        """
        Creates a SQL statement substituting in the configured table name.

        Args:
            sql: SQL statement with a {table} parameter

        Returns:
            fully resolved SQL statement
        """

        table = self.setting("table", "vectors")
        return sql.format(table=table)
|
txtai/ann/dense/torch.py
ADDED
@@ -0,0 +1,38 @@
|
|
1
|
+
"""
|
2
|
+
PyTorch module
|
3
|
+
"""
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
import torch
|
7
|
+
|
8
|
+
from .numpy import NumPy
|
9
|
+
|
10
|
+
|
11
|
+
class Torch(NumPy):
    """
    Builds an ANN index backed by a PyTorch array.

    Inherits the index logic from the NumPy index and swaps in torch array primitives and
    tensor conversion helpers.
    """

    def __init__(self, config):
        """
        Creates a new PyTorch ANN index.

        Args:
            config: index configuration parameters
        """

        super().__init__(config)

        # Bind array primitives to their torch equivalents
        # NOTE(review): presumably consumed by the NumPy parent class - confirm there
        self.all = torch.all
        self.cat = torch.cat
        self.dot = torch.mm
        self.zeros = torch.zeros
        self.argsort = torch.argsort
        self.xor = torch.bitwise_xor
        self.clip = torch.clip

    def tensor(self, array):
        """
        Converts array to a torch Tensor, moving it to the GPU when CUDA is available.

        Args:
            array: NumPy array or torch Tensor

        Returns:
            torch Tensor
        """

        data = torch.from_numpy(array) if isinstance(array, np.ndarray) else array

        if torch.cuda.is_available():
            data = data.cuda()

        return data

    def numpy(self, array):
        """
        Moves a Tensor to the CPU and converts it to a NumPy array.

        Args:
            array: torch Tensor

        Returns:
            NumPy array
        """

        host = array.cpu()
        return host.numpy()

    def totype(self, array, dtype):
        """
        Casts array to a 64-bit integer Tensor when dtype is np.int64, otherwise returns it unchanged.

        Args:
            array: torch Tensor
            dtype: target NumPy dtype

        Returns:
            torch Tensor
        """

        if dtype == np.int64:
            return array.long()

        return array

    def settings(self):
        """
        Returns settings for this index.

        Returns:
            dict with the installed torch version
        """

        version = torch.__version__
        return {"torch": version}
|
@@ -0,0 +1,61 @@
|
|
1
|
+
"""
|
2
|
+
Factory module
|
3
|
+
"""
|
4
|
+
|
5
|
+
from ...util import Resolver
|
6
|
+
|
7
|
+
from .ivfsparse import IVFSparse
|
8
|
+
from .pgsparse import PGSparse
|
9
|
+
|
10
|
+
|
11
|
+
class SparseANNFactory:
    """
    Methods to create Sparse ANN indexes.
    """

    @staticmethod
    def create(config):
        """
        Creates a Sparse ANN instance for the configured backend.

        The backend is read from config["backend"] and defaults to "ivfsparse". The built-in
        backends are "ivfsparse" and "pgsparse"; any other value is treated as a custom backend
        and passed to resolve(). After the instance is created, the backend name is written
        back to config["backend"].

        Args:
            config: index configuration parameters

        Returns:
            Sparse ANN
        """

        backend = config.get("backend", "ivfsparse")

        # Built-in backends. Constructors are wrapped in lambdas so each class is only referenced
        # when it is the one selected.
        known = (
            ("ivfsparse", lambda: IVFSparse(config)),
            ("pgsparse", lambda: PGSparse(config)),
        )

        builder = next((build for name, build in known if backend == name), None)
        instance = builder() if builder is not None else SparseANNFactory.resolve(backend, config)

        # Record the backend that was used
        config["backend"] = backend

        return instance

    @staticmethod
    def resolve(backend, config):
        """
        Attempts to load and instantiate a custom backend.

        Args:
            backend: backend class reference passed to Resolver
            config: index configuration parameters

        Returns:
            ANN

        Raises:
            ImportError: if the backend can't be resolved or instantiated
        """

        try:
            backendclass = Resolver()(backend)
            return backendclass(config)
        except Exception as error:
            raise ImportError(f"Unable to resolve sparse ann backend: '{backend}'") from error
|