mseep-txtai 9.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251) hide show
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
@@ -0,0 +1,49 @@
1
+ """
2
+ Function imports
3
+ """
4
+
5
+ from smolagents import Tool
6
+
7
+
8
class FunctionTool(Tool):
    """
    Tool that wraps a target function. Descriptive configuration is stored on the tool and injected, along with
    the target function, into an LLM prompt.
    """

    # pylint: disable=W0231
    def __init__(self, config):
        """
        Builds a FunctionTool from configuration.

        Args:
            config: `name`, `description`, `inputs`, `output` and `target` configuration
        """

        # Output type lookup order: `output`, then `output_type`, then "any"
        fallback = config.get("output_type", "any")

        # Tool parameters - these are set before the parent constructor runs
        self.name = config["name"]
        self.description = config["description"]
        self.inputs = config["inputs"]
        self.output_type = config.get("output", fallback)
        self.target = config["target"]

        # pylint: disable=C0103
        # Skip forward signature validation
        self.skip_forward_signature_validation = True

        # Validate parameters and initialize tool
        super().__init__()

    def forward(self, *args, **kwargs):
        """
        Delegates to the target function.

        Args:
            args: positional args
            kwargs: keyword args

        Returns:
            result of the target function
        """

        target = self.target
        return target(*args, **kwargs)
txtai/ann/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """
2
+ ANN imports
3
+ """
4
+
5
+ from .base import ANN
6
+ from .dense import *
7
+ from .sparse import *
txtai/ann/base.py ADDED
@@ -0,0 +1,153 @@
1
+ """
2
+ ANN (Approximate Nearest Neighbor) module
3
+ """
4
+
5
+ import datetime
6
+ import platform
7
+
8
+ from ..version import __version__
9
+
10
+
11
class ANN:
    """
    Base class for ANN (Approximate Nearest Neighbor) instances. Subclasses build vector indexes to support
    similarity search. The built-in ANN backends store ids and vectors. Content storage is supported via
    database instances.
    """

    def __init__(self, config):
        """
        Creates a new ANN.

        Args:
            config: index configuration parameters
        """

        # ANN configuration
        self.config = config

        # ANN index, populated by load/index in subclasses
        self.backend = None

    def load(self, path):
        """
        Loads an ANN at path.

        Args:
            path: path to load ann index
        """

        raise NotImplementedError

    def index(self, embeddings):
        """
        Builds an ANN index.

        Args:
            embeddings: embeddings array
        """

        raise NotImplementedError

    def append(self, embeddings):
        """
        Append elements to an existing index.

        Args:
            embeddings: embeddings array
        """

        raise NotImplementedError

    def delete(self, ids):
        """
        Deletes elements from existing index.

        Args:
            ids: ids to delete
        """

        raise NotImplementedError

    def search(self, queries, limit):
        """
        Searches ANN index for query. Returns topn results.

        Args:
            queries: queries array
            limit: maximum results

        Returns:
            query results
        """

        raise NotImplementedError

    def count(self):
        """
        Number of elements in the ANN index.

        Returns:
            count
        """

        raise NotImplementedError

    def save(self, path):
        """
        Saves an ANN index at path.

        Args:
            path: path to save ann index
        """

        raise NotImplementedError

    def close(self):
        """
        Closes this ANN by dropping the reference to the backend index.
        """

        self.backend = None

    def setting(self, name, default=None):
        """
        Looks up a backend specific setting. Backend settings are read from the config entry keyed by the
        configured backend name.

        Note that any falsy stored value (None, 0, False, empty string) is treated as not set and the
        default is returned instead.

        Args:
            name: setting name
            default: default value when setting not found

        Returns:
            setting value
        """

        # Backend-specific config section, keyed by backend name
        options = self.config.get(self.config["backend"])
        if not options:
            return default

        # Fall back to the default when the value is missing or falsy
        value = options.get(name)
        if value:
            return value

        return default

    def metadata(self, settings=None):
        """
        Adds index build metadata. Build details are only written when settings are provided, the last
        update timestamp is always refreshed.

        Args:
            settings: index build settings
        """

        # ISO 8601 timestamp (UTC)
        now = datetime.datetime.now(datetime.timezone.utc)
        timestamp = now.strftime("%Y-%m-%dT%H:%M:%SZ")

        # Set build metadata if this is not an update
        if settings:
            build = {
                "create": timestamp,
                "python": platform.python_version(),
                "settings": settings,
                "system": f"{platform.system()} ({platform.machine()})",
                "txtai": __version__,
            }
            self.config["build"] = build

        # Set last update date
        self.config["update"] = timestamp
@@ -0,0 +1,11 @@
1
+ """
2
+ Dense ANN imports
3
+ """
4
+
5
+ from .annoy import Annoy
6
+ from .factory import ANNFactory
7
+ from .faiss import Faiss
8
+ from .hnsw import HNSW
9
+ from .numpy import NumPy
10
+ from .pgvector import PGVector
11
+ from .torch import Torch
@@ -0,0 +1,72 @@
1
+ """
2
+ Annoy module
3
+ """
4
+
5
+ # Conditional import
6
+ try:
7
+ from annoy import AnnoyIndex
8
+
9
+ ANNOY = True
10
+ except ImportError:
11
+ ANNOY = False
12
+
13
+ from ..base import ANN
14
+
15
+
16
+ # pylint: disable=W0223
17
class Annoy(ANN):
    """
    ANN index backed by the Annoy library.
    """

    def __init__(self, config):
        """
        Creates a new Annoy ANN.

        Args:
            config: index configuration parameters

        Raises:
            ImportError: if the annoy library is not installed
        """

        super().__init__(config)

        if not ANNOY:
            raise ImportError('Annoy is not available - install "ann" extra to enable')

    def load(self, path):
        """
        Loads an Annoy index at path using the configured dimensions and metric.

        Args:
            path: path to load ann index
        """

        # Load index
        index = AnnoyIndex(self.config["dimensions"], self.config["metric"])
        index.load(path)
        self.backend = index

    def index(self, embeddings):
        """
        Builds a new Annoy index.

        Args:
            embeddings: embeddings array
        """

        # Inner product is equal to cosine similarity on normalized vectors
        self.config["metric"] = "dot"

        # Create index
        self.backend = AnnoyIndex(self.config["dimensions"], self.config["metric"])

        # Add items - position in embeddings is used as the id
        for uid, vector in enumerate(embeddings):
            self.backend.add_item(uid, vector)

        # Build index
        trees = self.setting("ntrees", 10)
        self.backend.build(trees)

        # Add index build metadata
        self.metadata({"ntrees": trees})

    def search(self, queries, limit):
        """
        Runs each query against the index.

        Args:
            queries: queries array
            limit: maximum results

        Returns:
            list of [(id, score)] per query
        """

        # Lookup search k setting
        searchk = self.setting("searchk", -1)

        # Annoy doesn't have a built in batch query method, run queries one at a time
        # Each (ids, scores) result is mapped to [(id, score)]
        return [
            list(zip(*self.backend.get_nns_by_vector(query, n=limit, search_k=searchk, include_distances=True)))
            for query in queries
        ]

    def count(self):
        """
        Number of items in the index.

        Returns:
            count
        """

        return self.backend.get_n_items()

    def save(self, path):
        """
        Writes the index to path.

        Args:
            path: path to save ann index
        """

        self.backend.save(path)
@@ -0,0 +1,76 @@
1
+ """
2
+ Factory module
3
+ """
4
+
5
+ from ...util import Resolver
6
+
7
+ from .annoy import Annoy
8
+ from .faiss import Faiss
9
+ from .hnsw import HNSW
10
+ from .numpy import NumPy
11
+ from .pgvector import PGVector
12
+ from .sqlite import SQLite
13
+ from .torch import Torch
14
+
15
+
16
class ANNFactory:
    """
    Methods to create ANN indexes.
    """

    @staticmethod
    def create(config):
        """
        Create an ANN. The backend defaults to "faiss" when not set. Names that don't match a built-in
        backend are resolved as a custom backend class.

        Args:
            config: index configuration parameters

        Returns:
            ANN
        """

        # Built-in backends
        builtin = {
            "annoy": Annoy,
            "faiss": Faiss,
            "hnsw": HNSW,
            "numpy": NumPy,
            "pgvector": PGVector,
            "sqlite": SQLite,
            "torch": Torch,
        }

        name = config.get("backend", "faiss")

        # Create ANN instance, fall back to resolving a custom backend
        factory = builtin.get(name)
        instance = factory(config) if factory else ANNFactory.resolve(name, config)

        # Store config back
        config["backend"] = name

        return instance

    @staticmethod
    def resolve(backend, config):
        """
        Attempt to resolve a custom backend.

        Args:
            backend: backend class
            config: index configuration parameters

        Returns:
            ANN

        Raises:
            ImportError: if the backend can't be resolved or created
        """

        try:
            resolver = Resolver()
            return resolver(backend)(config)
        except Exception as e:
            raise ImportError(f"Unable to resolve ann backend: '{backend}'") from e
@@ -0,0 +1,233 @@
1
+ """
2
+ Faiss module
3
+ """
4
+
5
+ import math
6
+ import platform
7
+
8
+ import numpy as np
9
+
10
+ from faiss import omp_set_num_threads
11
+ from faiss import index_factory, IO_FLAG_MMAP, METRIC_INNER_PRODUCT, read_index, write_index
12
+ from faiss import index_binary_factory, read_index_binary, write_index_binary, IndexBinaryIDMap
13
+
14
+ from ..base import ANN
15
+
16
+ if platform.system() == "Darwin":
17
+ # Workaround for a Faiss issue causing segmentation faults on macOS. See txtai FAQ for more.
18
+ omp_set_num_threads(1)
19
+
20
+
21
class Faiss(ANN):
    """
    Builds an ANN index using the Faiss library.
    """

    def __init__(self, config):
        """
        Creates a new Faiss ANN.

        Args:
            config: index configuration parameters
        """

        super().__init__(config)

        # Scalar quantization - a non-boolean integer root-level quantize setting selects a binary index
        quantize = self.config.get("quantize")
        self.qbits = quantize if quantize and isinstance(quantize, int) and not isinstance(quantize, bool) else None

    def load(self, path):
        """
        Loads a Faiss index at path. The index is memory mapped when the `mmap` backend setting is True.

        Args:
            path: path to load ann index
        """

        # Get read function
        readindex = read_index_binary if self.qbits else read_index

        # Load index
        self.backend = readindex(path, IO_FLAG_MMAP if self.setting("mmap") is True else 0)

    def index(self, embeddings):
        """
        Builds a new Faiss index. When the `sample` backend setting is set, only that fraction of rows
        is used for model training.

        Args:
            embeddings: embeddings array
        """

        # Compute model training size
        train, sample = embeddings, self.setting("sample")
        if sample:
            # Get sample for training, fixed seed keeps the selection reproducible
            rng = np.random.default_rng(0)
            indices = sorted(rng.choice(train.shape[0], int(sample * train.shape[0]), replace=False, shuffle=False))
            train = train[indices]

        # Configure embeddings index. Inner product is equal to cosine similarity on normalized vectors.
        params = self.configure(embeddings.shape[0], train.shape[0])

        # Create index
        self.backend = self.create(embeddings, params)

        # Train model
        self.backend.train(train)

        # Add embeddings - position in embeddings is used as the id
        self.backend.add_with_ids(embeddings, np.arange(embeddings.shape[0], dtype=np.int64))

        # Add id offset and index build metadata
        self.config["offset"] = embeddings.shape[0]
        self.metadata({"components": params})

    def append(self, embeddings):
        """
        Appends elements to an existing index.

        Args:
            embeddings: embeddings array
        """

        new = embeddings.shape[0]
        offset = self.config["offset"]

        # Append new ids - position in embeddings + existing offset is used as the id
        self.backend.add_with_ids(embeddings, np.arange(offset, offset + new, dtype=np.int64))

        # Update id offset and index metadata
        self.config["offset"] += new
        self.metadata()

    def delete(self, ids):
        """
        Removes ids from the index.

        Args:
            ids: ids to delete
        """

        self.backend.remove_ids(np.array(ids, dtype=np.int64))

    def search(self, queries, limit):
        """
        Searches the index.

        Args:
            queries: queries array
            limit: maximum results

        Returns:
            list of [(id, score)] per query
        """

        # Set nprobe and nflip search parameters
        self.backend.nprobe = self.nprobe()
        self.backend.nflip = self.setting("nflip", self.backend.nprobe)

        # Run the query
        scores, ids = self.backend.search(queries, limit)

        # Transform scores and map results to [(id, score)]
        return [list(zip(uids.tolist(), self.scores(score))) for uids, score in zip(ids, scores)]

    def count(self):
        """
        Number of elements in the index.

        Returns:
            count
        """

        return self.backend.ntotal

    def save(self, path):
        """
        Writes the index to path.

        Args:
            path: path to save ann index
        """

        # Get write function
        writeindex = write_index_binary if self.qbits else write_index

        # Write index
        writeindex(self.backend, path)

    def configure(self, count, train):
        """
        Configures settings for a new index.

        Args:
            count: initial number of embeddings rows
            train: number of rows selected for model training

        Returns:
            user-specified or generated components setting
        """

        # Lookup components setting
        components = self.setting("components")

        if components:
            # Format and return components string
            return self.components(components, train)

        # Derive quantization. Prefer backend-specific setting. Fallback to root-level parameter.
        # Only True maps to 8-bit precision. An explicit False must leave quantization disabled,
        # mapping every boolean to 8 would enable SQ8 storage when quantize is False.
        quantize = self.setting("quantize", self.config.get("quantize"))
        quantize = 8 if quantize is True else quantize

        # Get storage setting
        storage = f"SQ{quantize}" if quantize else "Flat"

        # Small index, use storage directly with IDMap
        if count <= 5000:
            return "BFlat" if self.qbits else f"IDMap,{storage}"

        x = self.cells(train)
        return f"BIVF{x}" if self.qbits else f"IVF{x},{storage}"

    def create(self, embeddings, params):
        """
        Creates a new index.

        Args:
            embeddings: embeddings to index
            params: index parameters

        Returns:
            new index
        """

        # Create binary index
        if self.qbits:
            index = index_binary_factory(embeddings.shape[1] * 8, params)

            # Wrap with BinaryIDMap, if necessary
            if any(x in params for x in ["BFlat", "BHNSW"]):
                index = IndexBinaryIDMap(index)

            return index

        # Create standard float index
        return index_factory(embeddings.shape[1], params, METRIC_INNER_PRODUCT)

    def cells(self, count):
        """
        Calculates the number of IVF cells for an IVF index.

        Args:
            count: number of embeddings rows

        Returns:
            number of IVF cells
        """

        # Calculate number of IVF cells where x = min(4 * sqrt(embeddings count), embeddings count / 39)
        # Faiss requires at least 39 points per cluster. Result is never less than 1.
        return max(min(round(4 * math.sqrt(count)), int(count / 39)), 1)

    def components(self, components, train):
        """
        Formats a components string. This method automatically calculates the optimal number of IVF cells, if omitted.

        Args:
            components: input components string
            train: number of rows selected for model training

        Returns:
            formatted components string
        """

        # Optimal number of IVF cells
        x = self.cells(train)

        # Add number of IVF cells, if missing
        return ",".join(f"IVF{x}" if component == "IVF" else component for component in components.split(","))

    def nprobe(self):
        """
        Gets or derives the nprobe search parameter. When not set, the default is 6 for indexes with up to
        5000 rows, otherwise the number of IVF cells for the current index size divided by 16.

        Returns:
            nprobe setting
        """

        # Get size of embeddings index
        count = self.count()

        default = 6 if count <= 5000 else round(self.cells(count) / 16)
        return self.setting("nprobe", default)

    def scores(self, scores):
        """
        Calculates the index score from the input score. This method returns the hamming score
        (1.0 - (hamming distance / total number of bits)) for binary indexes and the input
        scores otherwise.

        Args:
            scores: input scores

        Returns:
            index scores
        """

        # Calculate hamming score, bound between 0.0 - 1.0
        if self.qbits:
            return np.clip(1.0 - (scores / (self.config["dimensions"] * 8)), 0.0, 1.0).tolist()

        # Standard scoring
        return scores.tolist()
@@ -0,0 +1,104 @@
1
+ """
2
+ HNSW module
3
+ """
4
+
5
+ import numpy as np
6
+
7
+ # Conditional import
8
+ try:
9
+ # pylint: disable=E0611
10
+ from hnswlib import Index
11
+
12
+ HNSWLIB = True
13
+ except ImportError:
14
+ HNSWLIB = False
15
+
16
+ from ..base import ANN
17
+
18
+
19
class HNSW(ANN):
    """
    ANN index backed by the hnswlib library.
    """

    def __init__(self, config):
        """
        Creates a new HNSW ANN.

        Args:
            config: index configuration parameters

        Raises:
            ImportError: if the hnswlib library is not installed
        """

        super().__init__(config)

        if not HNSWLIB:
            raise ImportError('HNSW is not available - install "ann" extra to enable')

    def load(self, path):
        """
        Loads an index at path using the configured dimensions and metric.

        Args:
            path: path to load ann index
        """

        # Load index
        index = Index(dim=self.config["dimensions"], space=self.config["metric"])
        index.load_index(path)
        self.backend = index

    def index(self, embeddings):
        """
        Builds a new index.

        Args:
            embeddings: embeddings array
        """

        total = embeddings.shape[0]

        # Inner product is equal to cosine similarity on normalized vectors
        self.config["metric"] = "ip"

        # Lookup index settings
        efconstruction = self.setting("efconstruction", 200)
        m = self.setting("m", 16)
        seed = self.setting("randomseed", 100)

        # Create index
        self.backend = Index(dim=self.config["dimensions"], space=self.config["metric"])
        self.backend.init_index(max_elements=total, ef_construction=efconstruction, M=m, random_seed=seed)

        # Add items - position in embeddings is used as the id
        self.backend.add_items(embeddings, np.arange(total, dtype=np.int64))

        # Add id offset, delete counter and index build metadata
        self.config["offset"] = total
        self.config["deletes"] = 0
        self.metadata({"efconstruction": efconstruction, "m": m, "seed": seed})

    def append(self, embeddings):
        """
        Appends elements to an existing index. The index is resized to fit the new rows first.

        Args:
            embeddings: embeddings array
        """

        added = embeddings.shape[0]
        start = self.config["offset"]

        # Resize index
        self.backend.resize_index(start + added)

        # Append new ids - position in embeddings + existing offset is used as the id
        self.backend.add_items(embeddings, np.arange(start, start + added, dtype=np.int64))

        # Update id offset and index metadata
        self.config["offset"] += added
        self.metadata()

    def delete(self, ids):
        """
        Marks elements as deleted, which omits them from search results. The delete counter
        is only incremented for ids that were successfully marked.

        Args:
            ids: ids to delete
        """

        for uid in ids:
            try:
                self.backend.mark_deleted(uid)
            except RuntimeError:
                # Ignore label not found error
                continue
            else:
                self.config["deletes"] += 1

    def search(self, queries, limit):
        """
        Searches the index.

        Args:
            queries: queries array
            limit: maximum results

        Returns:
            list of [(id, score)] per query
        """

        # Set ef query param, when configured
        efsearch = self.setting("efsearch")
        if efsearch:
            self.backend.set_ef(efsearch)

        # Run the query
        labels, distances = self.backend.knn_query(queries, k=limit)

        # Map results to [(id, score)]
        # Distances are converted to similarity scores and np.int64 ids are converted to python ints
        return [list(zip(uids.tolist(), [1 - d for d in distance.tolist()])) for uids, distance in zip(labels, distances)]

    def count(self):
        """
        Number of elements in the index, excluding elements marked as deleted.

        Returns:
            count
        """

        return self.backend.get_current_count() - self.config["deletes"]

    def save(self, path):
        """
        Writes the index to path.

        Args:
            path: path to save ann index
        """

        self.backend.save_index(path)