mseep-txtai 9.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251)
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
@@ -0,0 +1,164 @@
1
+ """
2
+ NumPy module
3
+ """
4
+
5
+ import numpy as np
6
+
7
+ from ...serialize import SerializeFactory
8
+
9
+ from ..base import ANN
10
+
11
+
12
class NumPy(ANN):
    """
    Builds an ANN index backed by a NumPy array.

    Array operations are called through instance attributes (self.all, self.cat, self.dot, ...) and the
    tensor/numpy/totype hooks, which lets an alternate array backend replace them.
    """

    def __init__(self, config):
        super().__init__(config)

        # Array function definitions
        self.all, self.cat, self.dot, self.zeros = np.all, np.concatenate, np.dot, np.zeros
        self.argsort, self.xor, self.clip = np.argsort, np.bitwise_xor, np.clip

        # Scalar quantization - only a non-boolean integer enables hamming (bit vector) scoring
        quantize = self.config.get("quantize")
        self.qbits = quantize if quantize and isinstance(quantize, int) and not isinstance(quantize, bool) else None

        # Lookup table of set bits per uint8 value, built lazily by hammingscore and reused across searches
        self.bitcounts = None

    def load(self, path):
        # Load array from file. allow_pickle=False refuses pickled content in the .npy data.
        try:
            self.backend = self.tensor(np.load(path, allow_pickle=False))
        except ValueError:
            # Backwards compatible support for previously pickled data.
            # NOTE(security): this fallback deserializes pickle data - only load indexes from trusted sources.
            self.backend = self.tensor(SerializeFactory.create("pickle").load(path))

    def index(self, embeddings):
        # Create index
        self.backend = self.tensor(embeddings)

        # Add id offset and index build metadata
        self.config["offset"] = embeddings.shape[0]
        self.metadata(self.settings())

    def append(self, embeddings):
        # Append new data to array
        self.backend = self.cat((self.backend, self.tensor(embeddings)), axis=0)

        # Update id offset and index metadata
        self.config["offset"] += embeddings.shape[0]
        self.metadata()

    def delete(self, ids):
        # Filter any index greater than size of array
        ids = [x for x in ids if x < self.backend.shape[0]]

        # Clear specified ids - deleted rows are represented as all-zero rows
        self.backend[ids] = self.tensor(self.zeros((len(ids), self.backend.shape[1])))

    def search(self, queries, limit):
        if self.qbits:
            # Calculate hamming score for integer vectors
            scores = self.hammingscore(queries)
        else:
            # Dot product on normalized vectors is equal to cosine similarity
            scores = self.dot(self.tensor(queries), self.backend.T)

        # Get topn ids (highest scores first)
        ids = self.argsort(-scores)[:, :limit]

        # Map results to [(id, score)]
        results = []
        for x, score in enumerate(scores):
            # Add results
            results.append(list(zip(ids[x].tolist(), score[ids[x]].tolist())))

        return results

    def count(self):
        # Count rows that are not all zeros (delete zeroes out rows, so this ignores deleted rows).
        # Summing the row mask avoids materializing a copy of every live row just to read its length.
        return int((~self.all(self.backend == 0, axis=1)).sum())

    def save(self, path):
        # Save array to file. Use stream to prevent ".npy" suffix being added.
        with open(path, "wb") as handle:
            np.save(handle, self.numpy(self.backend), allow_pickle=False)

    def tensor(self, array):
        """
        Handles backend-specific code such as loading to a GPU device.

        Args:
            array: data array

        Returns:
            array with backend-specific logic applied
        """

        return array

    def numpy(self, array):
        """
        Handles backend-specific code to convert an array to numpy

        Args:
            array: data array

        Returns:
            numpy array
        """

        return array

    def totype(self, array, dtype):
        """
        Casts array to dtype. Only np.int64 is handled; any other dtype returns the array unchanged.

        Args:
            array: input array
            dtype: dtype

        Returns:
            array cast as dtype
        """

        return np.int64(array) if dtype == np.int64 else array

    def settings(self):
        """
        Returns settings for this array.

        Returns:
            dict
        """

        return {"numpy": np.__version__}

    def hammingscore(self, queries):
        """
        Calculates a hamming distance score.

        This is defined as:

          score = 1.0 - (hamming distance / total number of bits)

        Args:
            queries: queries array

        Returns:
            scores, bounded to [0, 1]
        """

        # Table of number of set bits for each distinct uint8 value. It is constant, so build it once and reuse it.
        if getattr(self, "bitcounts", None) is None:
            bits = 1 << np.arange(8)
            self.bitcounts = self.tensor(np.array([np.count_nonzero(x & bits) for x in np.arange(256)]))

        table = self.bitcounts

        # Number of different bits - broadcasts each query against every stored row
        delta = self.xor(self.tensor(queries[:, None]), self.backend)

        # Cast to long array so the values can index the lookup table
        delta = self.totype(delta, np.int64)

        # Calculate score as 1.0 - percentage of different bits
        # Bound score from 0 to 1
        return self.clip(1.0 - (table[delta].sum(axis=2) / (self.config["dimensions"] * 8)), 0.0, 1.0)
@@ -0,0 +1,323 @@
1
+ """
2
+ PGVector module
3
+ """
4
+
5
+ import os
6
+
7
+ import numpy as np
8
+
9
+ # Conditional import
10
+ try:
11
+ from pgvector.sqlalchemy import BIT, HALFVEC, VECTOR
12
+
13
+ from sqlalchemy import create_engine, delete, func, text, Column, Index, Integer, MetaData, StaticPool, Table
14
+ from sqlalchemy.orm import Session
15
+ from sqlalchemy.schema import CreateSchema
16
+
17
+ PGVECTOR = True
18
+ except ImportError:
19
+ PGVECTOR = False
20
+
21
+ from ..base import ANN
22
+
23
+
24
+ # pylint: disable=R0904
25
class PGVector(ANN):
    """
    Builds an ANN index backed by a Postgres database.
    """

    def __init__(self, config):
        super().__init__(config)

        if not PGVECTOR:
            raise ImportError('PGVector is not available - install "ann" extra to enable')

        # Database connection
        self.engine, self.database, self.connection, self.table = None, None, None, None

        # Scalar quantization - only a non-boolean integer enables bit vectors
        quantize = self.config.get("quantize")
        self.qbits = quantize if quantize and isinstance(quantize, int) and not isinstance(quantize, bool) else None

    def load(self, path):
        # Vectors are stored in the database (path is not used). Connect and create tables if missing.
        self.initialize()

    def index(self, embeddings):
        # Initialize tables, dropping any existing table
        self.initialize(recreate=True)

        # Prepare embeddings and insert rows, ids start at 0
        self.insertrows(embeddings, 0)

        # Create index
        self.createindex()

        # Add id offset and index build metadata
        self.config["offset"] = embeddings.shape[0]
        self.metadata(self.settings())

    def append(self, embeddings):
        # Prepare embeddings and insert rows, ids continue from the current offset
        self.insertrows(embeddings, self.config["offset"])

        # Update id offset and index metadata
        self.config["offset"] += embeddings.shape[0]
        self.metadata()

    def delete(self, ids):
        self.database.execute(delete(self.table).where(self.table.c["indexid"].in_(ids)))

    def search(self, queries, limit):
        results = []
        for query in queries:
            # Run query - ascending order on the "score" expression returns best matches first
            query = self.database.query(self.table.c["indexid"], self.query(query)).order_by("score").limit(limit)

            # Calculate and collect scores
            results.append([(indexid, self.score(score)) for indexid, score in query])

        return results

    def count(self):
        # pylint: disable=E1102
        return self.database.query(func.count(self.table.c["indexid"])).scalar()

    def save(self, path):
        # Commit session and connection
        self.database.commit()
        self.connection.commit()

    def close(self):
        # Parent logic
        super().close()

        # Close the session, the connection it is bound to and the engine. Each handle is checked
        # separately so a partially completed connect() (e.g. engine created, session not) is still cleaned up.
        if self.database is not None:
            self.database.close()

        if self.connection is not None:
            self.connection.close()

        if self.engine is not None:
            self.engine.dispose()

        # Clear handles so repeated close calls don't operate on already closed resources
        self.engine, self.database, self.connection = None, None, None

    def insertrows(self, embeddings, offset):
        """
        Prepares and inserts embeddings rows. Row x is stored with indexid x + offset. When there are no rows,
        no statement is executed instead of issuing an insert with an empty parameter list.

        Args:
            embeddings: embeddings rows
            offset: indexid of the first row
        """

        rows = [{"indexid": x + offset, "embedding": self.prepare(row)} for x, row in enumerate(embeddings)]
        if rows:
            self.database.execute(self.table.insert(), rows)

    def initialize(self, recreate=False):
        """
        Initializes a new database session.

        Args:
            recreate: Recreates the database tables if True
        """

        # Connect to database
        self.connect()

        # Set the database schema
        self.schema()

        # Table name
        table = self.setting("table", self.defaulttable())

        # Create vectors table object
        self.table = Table(table, MetaData(), Column("indexid", Integer, primary_key=True, autoincrement=False), Column("embedding", self.column()))

        # Drop table, if necessary
        if recreate:
            self.table.drop(self.connection, checkfirst=True)

        # Create table, if necessary
        self.table.create(self.connection, checkfirst=True)

    def createindex(self):
        """
        Creates an index with the current settings.
        """

        # Table name
        table = self.setting("table", self.defaulttable())

        # Create ANN index - inner product is equal to cosine similarity on normalized vectors
        index = Index(
            f"{table}-index",
            self.table.c["embedding"],
            postgresql_using="hnsw",
            postgresql_with=self.settings(),
            postgresql_ops={"embedding": self.operation()},
        )

        # Create or recreate index
        index.drop(self.connection, checkfirst=True)
        index.create(self.connection, checkfirst=True)

    def connect(self):
        """
        Establishes a database connection. Cleans up any existing database connection first.
        """

        # Close existing connection
        if self.database:
            self.close()

        # Create engine
        self.engine = create_engine(self.url(), poolclass=StaticPool, echo=False)
        self.connection = self.engine.connect()

        # Start database session
        self.database = Session(self.connection)

        # Initialize pgvector extension
        self.sqldialect(text("CREATE EXTENSION IF NOT EXISTS vector"))

    def schema(self):
        """
        Sets the database schema, if available.
        """

        # Set default schema, if necessary
        schema = self.setting("schema")
        if schema:
            # NOTE(review): the CREATE SCHEMA statement runs on the session, not on the engine.begin() connection.
            # With StaticPool both presumably share one DBAPI connection, so the commit on block exit also
            # commits the schema creation - confirm before restructuring this.
            with self.engine.begin():
                self.sqldialect(CreateSchema(schema, if_not_exists=True))

            self.sqldialect(text("SET search_path TO :schema,public"), {"schema": schema})

    def settings(self):
        """
        Returns settings for this index.

        Returns:
            dict
        """

        return {"m": self.setting("m", 16), "ef_construction": self.setting("efconstruction", 200)}

    def sqldialect(self, sql, parameters=None):
        """
        Executes a SQL statement based on the current SQL dialect. Non-postgresql dialects run a no-op SELECT 1.

        Args:
            sql: SQL to execute
            parameters: optional bind parameters
        """

        args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (text("SELECT 1"),)
        self.database.execute(*args)

    def defaulttable(self):
        """
        Returns the default table name.

        Returns:
            default table name
        """

        return "vectors"

    def url(self):
        """
        Reads the database url parameter, falling back to the ANN_URL environment variable.

        Returns:
            database url
        """

        return self.setting("url", os.environ.get("ANN_URL"))

    def column(self):
        """
        Gets embedding column for the current settings.

        Returns:
            embedding column definition
        """

        if self.qbits:
            # If quantization is set, always return BIT vectors
            return BIT(self.config["dimensions"] * 8)

        if self.setting("precision") == "half":
            # 16-bit HALF precision vectors
            return HALFVEC(self.config["dimensions"])

        # Default is full 32-bit FULL precision vectors
        return VECTOR(self.config["dimensions"])

    def operation(self):
        """
        Gets the index operation for the current settings.

        Returns:
            index operation
        """

        if self.qbits:
            # If quantization is set, always return BIT vectors
            return "bit_hamming_ops"

        if self.setting("precision") == "half":
            # 16-bit HALF precision vectors
            return "halfvec_ip_ops"

        # Default is full 32-bit FULL precision vectors
        return "vector_ip_ops"

    def prepare(self, data):
        """
        Prepares data for the embeddings column. This method returns a bit string for bit vectors and
        the input data unmodified for float vectors.

        Args:
            data: input data

        Returns:
            data ready for the embeddings column
        """

        # Transform to a bit string when vector quantization is enabled
        if self.qbits:
            return "".join(np.where(np.unpackbits(data), "1", "0"))

        # Return original data
        return data

    def query(self, query):
        """
        Creates a query statement from an input query. This method uses hamming distance for bit vectors and
        the max_inner_product for float vectors.

        Args:
            query: input query

        Returns:
            query statement
        """

        # Prepare query embeddings
        query = self.prepare(query)

        # Bit vector query
        if self.qbits:
            return self.table.c["embedding"].hamming_distance(query).label("score")

        # Float vector query
        return self.table.c["embedding"].max_inner_product(query).label("score")

    def score(self, score):
        """
        Calculates the index score from the input score. This method returns the hamming score
        (1.0 - (hamming distance / total number of bits)) for bit vectors and the -score for
        float vectors.

        Args:
            score: input score

        Returns:
            index score
        """

        # Calculate hamming score as 1.0 - (hamming distance / total number of bits)
        # Bound score from 0 to 1
        if self.qbits:
            return min(max(0.0, 1.0 - (score / (self.config["dimensions"] * 8))), 1.0)

        # pgvector returns negative inner product since Postgres only supports ASC order index scans on operators
        return -score