mseep_txtai-9.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/ann/dense/numpy.py
ADDED
@@ -0,0 +1,164 @@
"""
NumPy module
"""

import numpy as np

from ...serialize import SerializeFactory

from ..base import ANN


class NumPy(ANN):
    """
    Builds an ANN index backed by a NumPy array.
    """

    def __init__(self, config):
        super().__init__(config)

        # Array function definitions
        self.all, self.cat, self.dot, self.zeros = np.all, np.concatenate, np.dot, np.zeros
        self.argsort, self.xor, self.clip = np.argsort, np.bitwise_xor, np.clip

        # Scalar quantization
        quantize = self.config.get("quantize")
        self.qbits = quantize if quantize and isinstance(quantize, int) and not isinstance(quantize, bool) else None

    def load(self, path):
        # Load array from file
        try:
            self.backend = self.tensor(np.load(path, allow_pickle=False))
        except ValueError:
            # Backwards compatible support for previously pickled data
            self.backend = self.tensor(SerializeFactory.create("pickle").load(path))

    def index(self, embeddings):
        # Create index
        self.backend = self.tensor(embeddings)

        # Add id offset and index build metadata
        self.config["offset"] = embeddings.shape[0]
        self.metadata(self.settings())

    def append(self, embeddings):
        # Append new data to array
        self.backend = self.cat((self.backend, self.tensor(embeddings)), axis=0)

        # Update id offset and index metadata
        self.config["offset"] += embeddings.shape[0]
        self.metadata()

    def delete(self, ids):
        # Filter any index greater than size of array
        ids = [x for x in ids if x < self.backend.shape[0]]

        # Clear specified ids
        self.backend[ids] = self.tensor(self.zeros((len(ids), self.backend.shape[1])))

    def search(self, queries, limit):
        if self.qbits:
            # Calculate hamming score for integer vectors
            scores = self.hammingscore(queries)
        else:
            # Dot product on normalized vectors is equal to cosine similarity
            scores = self.dot(self.tensor(queries), self.backend.T)

        # Get topn ids
        ids = self.argsort(-scores)[:, :limit]

        # Map results to [(id, score)]
        results = []
        for x, score in enumerate(scores):
            # Add results
            results.append(list(zip(ids[x].tolist(), score[ids[x]].tolist())))

        return results

    def count(self):
        # Get count of non-zero rows (ignores deleted rows)
        return self.backend[~self.all(self.backend == 0, axis=1)].shape[0]

    def save(self, path):
        # Save array to file. Use stream to prevent ".npy" suffix being added.
        with open(path, "wb") as handle:
            np.save(handle, self.numpy(self.backend), allow_pickle=False)

    def tensor(self, array):
        """
        Handles backend-specific code such as loading to a GPU device.

        Args:
            array: data array

        Returns:
            array with backend-specific logic applied
        """

        return array

    def numpy(self, array):
        """
        Handles backend-specific code to convert an array to numpy.

        Args:
            array: data array

        Returns:
            numpy array
        """

        return array

    def totype(self, array, dtype):
        """
        Casts array to dtype.

        Args:
            array: input array
            dtype: dtype

        Returns:
            array cast as dtype
        """

        return np.int64(array) if dtype == np.int64 else array

    def settings(self):
        """
        Returns settings for this array.

        Returns:
            dict
        """

        return {"numpy": np.__version__}

    def hammingscore(self, queries):
        """
        Calculates a hamming distance score.

        This is defined as:

            score = 1.0 - (hamming distance / total number of bits)

        Args:
            queries: queries array

        Returns:
            scores
        """

        # Build table of number of bits for each distinct uint8 value
        table = 1 << np.arange(8)
        table = self.tensor(np.array([np.count_nonzero(x & table) for x in np.arange(256)]))

        # Number of different bits
        delta = self.xor(self.tensor(queries[:, None]), self.backend)

        # Cast to long array
        delta = self.totype(delta, np.int64)

        # Calculate score as 1.0 - percentage of different bits
        # Bound score from 0 to 1
        return self.clip(1.0 - (table[delta].sum(axis=2) / (self.config["dimensions"] * 8)), 0.0, 1.0)
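The hammingscore method above is the core of binary-quantized search: vectors are bit-packed into uint8 arrays, XOR exposes the differing bits, and a 256-entry popcount lookup table converts each byte to a bit count. The following standalone sketch reproduces that calculation outside txtai; the arrays and variable names are illustrative, not part of the package.

import numpy as np

# Two indexed vectors and one query, 16 bits each, packed into 2 uint8 bytes
# (index.shape[1] plays the role of config["dimensions"] above)
index = np.packbits(np.random.randint(0, 2, (2, 16)), axis=1)
query = np.packbits(np.random.randint(0, 2, (1, 16)), axis=1)

# Popcount lookup table for every possible byte value
table = np.array([bin(x).count("1") for x in range(256)])

# XOR exposes differing bits; sum popcounts and normalize to a 0-1 score
delta = np.bitwise_xor(query[:, None], index)
scores = 1.0 - table[delta].sum(axis=2) / (index.shape[1] * 8)

print(scores)  # shape (1, 2): one score per indexed vector, higher is more similar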
txtai/ann/dense/pgvector.py
ADDED
@@ -0,0 +1,323 @@
"""
PGVector module
"""

import os

import numpy as np

# Conditional import
try:
    from pgvector.sqlalchemy import BIT, HALFVEC, VECTOR

    from sqlalchemy import create_engine, delete, func, text, Column, Index, Integer, MetaData, StaticPool, Table
    from sqlalchemy.orm import Session
    from sqlalchemy.schema import CreateSchema

    PGVECTOR = True
except ImportError:
    PGVECTOR = False

from ..base import ANN


# pylint: disable=R0904
class PGVector(ANN):
    """
    Builds an ANN index backed by a Postgres database.
    """

    def __init__(self, config):
        super().__init__(config)

        if not PGVECTOR:
            raise ImportError('PGVector is not available - install "ann" extra to enable')

        # Database connection
        self.engine, self.database, self.connection, self.table = None, None, None, None

        # Scalar quantization
        quantize = self.config.get("quantize")
        self.qbits = quantize if quantize and isinstance(quantize, int) and not isinstance(quantize, bool) else None

    def load(self, path):
        # Initialize tables
        self.initialize()

    def index(self, embeddings):
        # Initialize tables
        self.initialize(recreate=True)

        # Prepare embeddings and insert rows
        self.database.execute(self.table.insert(), [{"indexid": x, "embedding": self.prepare(row)} for x, row in enumerate(embeddings)])

        # Create index
        self.createindex()

        # Add id offset and index build metadata
        self.config["offset"] = embeddings.shape[0]
        self.metadata(self.settings())

    def append(self, embeddings):
        # Prepare embeddings and insert rows
        self.database.execute(
            self.table.insert(), [{"indexid": x + self.config["offset"], "embedding": self.prepare(row)} for x, row in enumerate(embeddings)]
        )

        # Update id offset and index metadata
        self.config["offset"] += embeddings.shape[0]
        self.metadata()

    def delete(self, ids):
        self.database.execute(delete(self.table).where(self.table.c["indexid"].in_(ids)))

    def search(self, queries, limit):
        results = []
        for query in queries:
            # Run query
            query = self.database.query(self.table.c["indexid"], self.query(query)).order_by("score").limit(limit)

            # Calculate and collect scores
            results.append([(indexid, self.score(score)) for indexid, score in query])

        return results

    def count(self):
        # pylint: disable=E1102
        return self.database.query(func.count(self.table.c["indexid"])).scalar()

    def save(self, path):
        # Commit session and connection
        self.database.commit()
        self.connection.commit()

    def close(self):
        # Parent logic
        super().close()

        # Close database connection
        if self.database:
            self.database.close()
            self.engine.dispose()

    def initialize(self, recreate=False):
        """
        Initializes a new database session.

        Args:
            recreate: Recreates the database tables if True
        """

        # Connect to database
        self.connect()

        # Set the database schema
        self.schema()

        # Table name
        table = self.setting("table", self.defaulttable())

        # Create vectors table object
        self.table = Table(table, MetaData(), Column("indexid", Integer, primary_key=True, autoincrement=False), Column("embedding", self.column()))

        # Drop table, if necessary
        if recreate:
            self.table.drop(self.connection, checkfirst=True)

        # Create table, if necessary
        self.table.create(self.connection, checkfirst=True)

    def createindex(self):
        """
        Creates an index with the current settings.
        """

        # Table name
        table = self.setting("table", self.defaulttable())

        # Create ANN index - inner product is equal to cosine similarity on normalized vectors
        index = Index(
            f"{table}-index",
            self.table.c["embedding"],
            postgresql_using="hnsw",
            postgresql_with=self.settings(),
            postgresql_ops={"embedding": self.operation()},
        )

        # Create or recreate index
        index.drop(self.connection, checkfirst=True)
        index.create(self.connection, checkfirst=True)

    def connect(self):
        """
        Establishes a database connection. Cleans up any existing database connection first.
        """

        # Close existing connection
        if self.database:
            self.close()

        # Create engine
        self.engine = create_engine(self.url(), poolclass=StaticPool, echo=False)
        self.connection = self.engine.connect()

        # Start database session
        self.database = Session(self.connection)

        # Initialize pgvector extension
        self.sqldialect(text("CREATE EXTENSION IF NOT EXISTS vector"))

    def schema(self):
        """
        Sets the database schema, if available.
        """

        # Set default schema, if necessary
        schema = self.setting("schema")
        if schema:
            with self.engine.begin():
                self.sqldialect(CreateSchema(schema, if_not_exists=True))

            self.sqldialect(text("SET search_path TO :schema,public"), {"schema": schema})

    def settings(self):
        """
        Returns settings for this index.

        Returns:
            dict
        """

        return {"m": self.setting("m", 16), "ef_construction": self.setting("efconstruction", 200)}

    def sqldialect(self, sql, parameters=None):
        """
        Executes a SQL statement based on the current SQL dialect.

        Args:
            sql: SQL to execute
            parameters: optional bind parameters
        """

        args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (text("SELECT 1"),)
        self.database.execute(*args)

    def defaulttable(self):
        """
        Returns the default table name.

        Returns:
            default table name
        """

        return "vectors"

    def url(self):
        """
        Reads the database url parameter.

        Returns:
            database url
        """

        return self.setting("url", os.environ.get("ANN_URL"))

    def column(self):
        """
        Gets embedding column for the current settings.

        Returns:
            embedding column definition
        """

        if self.qbits:
            # If quantization is set, always return BIT vectors
            return BIT(self.config["dimensions"] * 8)

        if self.setting("precision") == "half":
            # 16-bit HALF precision vectors
            return HALFVEC(self.config["dimensions"])

        # Default is 32-bit FULL precision vectors
        return VECTOR(self.config["dimensions"])

    def operation(self):
        """
        Gets the index operation for the current settings.

        Returns:
            index operation
        """

        if self.qbits:
            # If quantization is set, always return BIT vectors
            return "bit_hamming_ops"

        if self.setting("precision") == "half":
            # 16-bit HALF precision vectors
            return "halfvec_ip_ops"

        # Default is 32-bit FULL precision vectors
        return "vector_ip_ops"

    def prepare(self, data):
        """
        Prepares data for the embeddings column. This method returns a bit string for bit vectors and
        the input data unmodified for float vectors.

        Args:
            data: input data

        Returns:
            data ready for the embeddings column
        """

        # Transform to a bit string when vector quantization is enabled
        if self.qbits:
            return "".join(np.where(np.unpackbits(data), "1", "0"))

        # Return original data
        return data

    def query(self, query):
        """
        Creates a query statement from an input query. This method uses hamming distance for bit vectors and
        max_inner_product for float vectors.

        Args:
            query: input query

        Returns:
            query statement
        """

        # Prepare query embeddings
        query = self.prepare(query)

        # Bit vector query
        if self.qbits:
            return self.table.c["embedding"].hamming_distance(query).label("score")

        # Float vector query
        return self.table.c["embedding"].max_inner_product(query).label("score")

    def score(self, score):
        """
        Calculates the index score from the input score. This method returns the hamming score
        (1.0 - (hamming distance / total number of bits)) for bit vectors and the negated score
        for float vectors.

        Args:
            score: input score

        Returns:
            index score
        """

        # Calculate hamming score as 1.0 - (hamming distance / total number of bits)
        # Bound score from 0 to 1
        if self.qbits:
            return min(max(0.0, 1.0 - (score / (self.config["dimensions"] * 8))), 1.0)

        # pgvector returns negative inner product since Postgres only supports ASC order index scans on operators
        return -score
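As a rough end-to-end illustration of this backend, a minimal sketch follows. It assumes a running Postgres instance with the pgvector extension available and normalized input vectors; the connection URL, database name and the "backend"/"pgvector" config keys shown are assumptions for illustration (the URL can also be supplied via the ANN_URL environment variable, per url() above).

import numpy as np

from txtai.ann.dense.pgvector import PGVector

# Illustrative config; keys and URL are assumptions, not verified against txtai docs
config = {
    "backend": "pgvector",
    "pgvector": {"url": "postgresql+psycopg2://user:pass@localhost/vectordb"},
    "dimensions": 4,
}

ann = PGVector(config)

# Index two normalized vectors, then search with the first as the query
embeddings = np.array([[0.5, 0.5, 0.5, 0.5], [1.0, 0.0, 0.0, 0.0]], dtype=np.float32)
ann.index(embeddings)
print(ann.search(embeddings[:1], 2))

# save() commits the session; the path argument is unused by this backend
ann.save(None)
ann.close()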