mseep-txtai 9.1.1 (mseep_txtai-9.1.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/agent/tool/function.py
ADDED
@@ -0,0 +1,49 @@
+"""
+Function imports
+"""
+
+from smolagents import Tool
+
+
+class FunctionTool(Tool):
+    """
+    Creates a FunctionTool. A FunctionTool takes descriptive configuration and injects it along with a target function
+    into an LLM prompt.
+    """
+
+    # pylint: disable=W0231
+    def __init__(self, config):
+        """
+        Creates a FunctionTool.
+
+        Args:
+            config: `name`, `description`, `inputs`, `output` and `target` configuration
+        """
+
+        # Tool parameters
+        self.name = config["name"]
+        self.description = config["description"]
+        self.inputs = config["inputs"]
+        self.output_type = config.get("output", config.get("output_type", "any"))
+        self.target = config["target"]
+
+        # pylint: disable=C0103
+        # Skip forward signature validation
+        self.skip_forward_signature_validation = True
+
+        # Validate parameters and initialize tool
+        super().__init__()
+
+    def forward(self, *args, **kwargs):
+        """
+        Runs target function.
+
+        Args:
+            args: positional args
+            kwargs: keyword args
+
+        Returns:
+            result
+        """
+
+        return self.target(*args, **kwargs)
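A FunctionTool is driven entirely by the config dictionary above. The following minimal usage sketch is not part of the package; it assumes the smolagents `inputs` schema of parameter name to type/description mappings, and uses the `txtai.agent.tool.function` module path shown in this diff. The `weather` function and its values are hypothetical.

import txtai
from txtai.agent.tool.function import FunctionTool

def weather(location):
    # Hypothetical target function for illustration
    return f"Sunny in {location}"

tool = FunctionTool({
    "name": "weather",
    "description": "Gets the current weather for a location",
    "inputs": {"location": {"type": "string", "description": "Location to look up"}},
    "output": "string",
    "target": weather,
})

# forward proxies straight to the target function
print(tool.forward("Berlin"))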
txtai/ann/__init__.py
ADDED
txtai/ann/base.py
ADDED
@@ -0,0 +1,153 @@
+"""
+ANN (Approximate Nearest Neighbor) module
+"""
+
+import datetime
+import platform
+
+from ..version import __version__
+
+
+class ANN:
+    """
+    Base class for ANN instances. This class builds vector indexes to support similarity search.
+    The built-in ANN backends store ids and vectors. Content storage is supported via database instances.
+    """
+
+    def __init__(self, config):
+        """
+        Creates a new ANN.
+
+        Args:
+            config: index configuration parameters
+        """
+
+        # ANN index
+        self.backend = None
+
+        # ANN configuration
+        self.config = config
+
+    def load(self, path):
+        """
+        Loads an ANN at path.
+
+        Args:
+            path: path to load ann index
+        """
+
+        raise NotImplementedError
+
+    def index(self, embeddings):
+        """
+        Builds an ANN index.
+
+        Args:
+            embeddings: embeddings array
+        """
+
+        raise NotImplementedError
+
+    def append(self, embeddings):
+        """
+        Append elements to an existing index.
+
+        Args:
+            embeddings: embeddings array
+        """
+
+        raise NotImplementedError
+
+    def delete(self, ids):
+        """
+        Deletes elements from existing index.
+
+        Args:
+            ids: ids to delete
+        """
+
+        raise NotImplementedError
+
+    def search(self, queries, limit):
+        """
+        Searches ANN index for query. Returns topn results.
+
+        Args:
+            queries: queries array
+            limit: maximum results
+
+        Returns:
+            query results
+        """
+
+        raise NotImplementedError
+
+    def count(self):
+        """
+        Number of elements in the ANN index.
+
+        Returns:
+            count
+        """
+
+        raise NotImplementedError
+
+    def save(self, path):
+        """
+        Saves an ANN index at path.
+
+        Args:
+            path: path to save ann index
+        """
+
+        raise NotImplementedError
+
+    def close(self):
+        """
+        Closes this ANN.
+        """
+
+        self.backend = None
+
+    def setting(self, name, default=None):
+        """
+        Looks up backend specific setting.
+
+        Args:
+            name: setting name
+            default: default value when setting not found
+
+        Returns:
+            setting value
+        """
+
+        # Get the backend-specific config object
+        backend = self.config.get(self.config["backend"])
+
+        # Get setting value, set default value if not found
+        setting = backend.get(name) if backend else None
+        return setting if setting else default
+
+    def metadata(self, settings=None):
+        """
+        Adds index build metadata.
+
+        Args:
+            settings: index build settings
+        """
+
+        # ISO 8601 timestamp
+        create = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
+
+        # Set build metadata if this is not an update
+        if settings:
+            self.config["build"] = {
+                "create": create,
+                "python": platform.python_version(),
+                "settings": settings,
+                "system": f"{platform.system()} ({platform.machine()})",
+                "txtai": __version__,
+            }
+
+        # Set last update date
+        self.config["update"] = create
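The base class above defines the full backend contract: load/index/append/delete/search/count/save plus the setting/metadata helpers. A minimal sketch of a custom backend built against this interface follows; the `ExactANN` name and brute-force approach are illustrative, not from the package, and the import path just mirrors the file path in this diff.

import numpy as np

from txtai.ann.base import ANN

class ExactANN(ANN):
    """
    Brute-force inner product search - exact rather than approximate.
    """

    def index(self, embeddings):
        # Store vectors directly; row position is the id
        self.backend = np.array(embeddings)
        self.metadata({"type": "exact"})

    def search(self, queries, limit):
        # Inner product equals cosine similarity on normalized vectors
        scores = np.dot(np.array(queries), self.backend.T)

        # Map results to [(id, score)]
        results = []
        for row in scores:
            ids = np.argsort(-row)[:limit]
            results.append(list(zip(ids.tolist(), row[ids].tolist())))

        return results

    def count(self):
        return self.backend.shape[0]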
txtai/ann/dense/annoy.py
ADDED
@@ -0,0 +1,72 @@
+"""
+Annoy module
+"""
+
+# Conditional import
+try:
+    from annoy import AnnoyIndex
+
+    ANNOY = True
+except ImportError:
+    ANNOY = False
+
+from ..base import ANN
+
+
+# pylint: disable=W0223
+class Annoy(ANN):
+    """
+    Builds an ANN index using the Annoy library.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        if not ANNOY:
+            raise ImportError('Annoy is not available - install "ann" extra to enable')
+
+    def load(self, path):
+        # Load index
+        self.backend = AnnoyIndex(self.config["dimensions"], self.config["metric"])
+        self.backend.load(path)
+
+    def index(self, embeddings):
+        # Inner product is equal to cosine similarity on normalized vectors
+        self.config["metric"] = "dot"
+
+        # Create index
+        self.backend = AnnoyIndex(self.config["dimensions"], self.config["metric"])
+
+        # Add items - position in embeddings is used as the id
+        for x in range(embeddings.shape[0]):
+            self.backend.add_item(x, embeddings[x])
+
+        # Build index
+        ntrees = self.setting("ntrees", 10)
+        self.backend.build(ntrees)
+
+        # Add index build metadata
+        self.metadata({"ntrees": ntrees})
+
+    def search(self, queries, limit):
+        # Lookup search k setting
+        searchk = self.setting("searchk", -1)
+
+        # Annoy doesn't have a built in batch query method
+        results = []
+        for query in queries:
+            # Run the query
+            ids, scores = self.backend.get_nns_by_vector(query, n=limit, search_k=searchk, include_distances=True)
+
+            # Map results to [(id, score)]
+            results.append(list(zip(ids, scores)))
+
+        return results
+
+    def count(self):
+        # Number of items in index
+        return self.backend.get_n_items()
+
+    def save(self, path):
+        # Write index
+        self.backend.save(path)
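Note the `# pylint: disable=W0223` marker above: Annoy indexes are immutable once built, so `append` and `delete` are left unimplemented and fall through to the base class. A usage sketch with illustrative values (not package code) follows; vectors are normalized first, since the backend equates dot product with cosine similarity.

import numpy as np

from txtai.ann.dense.annoy import Annoy

# Random normalized vectors
data = np.random.rand(1000, 64).astype(np.float32)
data /= np.linalg.norm(data, axis=1, keepdims=True)

# The "annoy" key holds backend-specific settings read via setting()
ann = Annoy({"backend": "annoy", "dimensions": 64, "annoy": {"ntrees": 25}})
ann.index(data)

# Top 3 matches for the first vector as (id, score) tuples
print(ann.search(data[:1], 3))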
txtai/ann/dense/factory.py
ADDED
@@ -0,0 +1,76 @@
+"""
+Factory module
+"""
+
+from ...util import Resolver
+
+from .annoy import Annoy
+from .faiss import Faiss
+from .hnsw import HNSW
+from .numpy import NumPy
+from .pgvector import PGVector
+from .sqlite import SQLite
+from .torch import Torch
+
+
+class ANNFactory:
+    """
+    Methods to create ANN indexes.
+    """
+
+    @staticmethod
+    def create(config):
+        """
+        Create an ANN.
+
+        Args:
+            config: index configuration parameters
+
+        Returns:
+            ANN
+        """
+
+        # ANN instance
+        ann = None
+        backend = config.get("backend", "faiss")
+
+        # Create ANN instance
+        if backend == "annoy":
+            ann = Annoy(config)
+        elif backend == "faiss":
+            ann = Faiss(config)
+        elif backend == "hnsw":
+            ann = HNSW(config)
+        elif backend == "numpy":
+            ann = NumPy(config)
+        elif backend == "pgvector":
+            ann = PGVector(config)
+        elif backend == "sqlite":
+            ann = SQLite(config)
+        elif backend == "torch":
+            ann = Torch(config)
+        else:
+            ann = ANNFactory.resolve(backend, config)
+
+        # Store config back
+        config["backend"] = backend
+
+        return ann
+
+    @staticmethod
+    def resolve(backend, config):
+        """
+        Attempt to resolve a custom backend.
+
+        Args:
+            backend: backend class
+            config: index configuration parameters
+
+        Returns:
+            ANN
+        """
+
+        try:
+            return Resolver()(backend)(config)
+        except Exception as e:
+            raise ImportError(f"Unable to resolve ann backend: '{backend}'") from e
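A sketch of how the factory is driven by configuration follows; the custom backend string is a hypothetical example, and the assumption that `Resolver` accepts a full module path is not confirmed by this diff.

from txtai.ann.dense.factory import ANNFactory

# Built-in backend: defaults to "faiss" when no backend is set
ann = ANNFactory.create({"dimensions": 64})

# Any unrecognized backend string goes through resolve(), which is
# assumed to take a full "module.ClassName" path (hypothetical):
# ann = ANNFactory.create({"backend": "mymodule.ExactANN", "dimensions": 64})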
txtai/ann/dense/faiss.py
ADDED
@@ -0,0 +1,233 @@
+"""
+Faiss module
+"""
+
+import math
+import platform
+
+import numpy as np
+
+from faiss import omp_set_num_threads
+from faiss import index_factory, IO_FLAG_MMAP, METRIC_INNER_PRODUCT, read_index, write_index
+from faiss import index_binary_factory, read_index_binary, write_index_binary, IndexBinaryIDMap
+
+from ..base import ANN
+
+if platform.system() == "Darwin":
+    # Workaround for a Faiss issue causing segmentation faults on macOS. See txtai FAQ for more.
+    omp_set_num_threads(1)
+
+
+class Faiss(ANN):
+    """
+    Builds an ANN index using the Faiss library.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        # Scalar quantization
+        quantize = self.config.get("quantize")
+        self.qbits = quantize if quantize and isinstance(quantize, int) and not isinstance(quantize, bool) else None
+
+    def load(self, path):
+        # Get read function
+        readindex = read_index_binary if self.qbits else read_index
+
+        # Load index
+        self.backend = readindex(path, IO_FLAG_MMAP if self.setting("mmap") is True else 0)
+
+    def index(self, embeddings):
+        # Compute model training size
+        train, sample = embeddings, self.setting("sample")
+        if sample:
+            # Get sample for training
+            rng = np.random.default_rng(0)
+            indices = sorted(rng.choice(train.shape[0], int(sample * train.shape[0]), replace=False, shuffle=False))
+            train = train[indices]
+
+        # Configure embeddings index. Inner product is equal to cosine similarity on normalized vectors.
+        params = self.configure(embeddings.shape[0], train.shape[0])
+
+        # Create index
+        self.backend = self.create(embeddings, params)
+
+        # Train model
+        self.backend.train(train)
+
+        # Add embeddings - position in embeddings is used as the id
+        self.backend.add_with_ids(embeddings, np.arange(embeddings.shape[0], dtype=np.int64))
+
+        # Add id offset and index build metadata
+        self.config["offset"] = embeddings.shape[0]
+        self.metadata({"components": params})
+
+    def append(self, embeddings):
+        new = embeddings.shape[0]
+
+        # Append new ids - position in embeddings + existing offset is used as the id
+        self.backend.add_with_ids(embeddings, np.arange(self.config["offset"], self.config["offset"] + new, dtype=np.int64))
+
+        # Update id offset and index metadata
+        self.config["offset"] += new
+        self.metadata()
+
+    def delete(self, ids):
+        # Remove specified ids
+        self.backend.remove_ids(np.array(ids, dtype=np.int64))
+
+    def search(self, queries, limit):
+        # Set nprobe and nflip search parameters
+        self.backend.nprobe = self.nprobe()
+        self.backend.nflip = self.setting("nflip", self.backend.nprobe)
+
+        # Run the query
+        scores, ids = self.backend.search(queries, limit)
+
+        # Map results to [(id, score)]
+        results = []
+        for x, score in enumerate(scores):
+            # Transform scores and add results
+            results.append(list(zip(ids[x].tolist(), self.scores(score))))
+
+        return results
+
+    def count(self):
+        return self.backend.ntotal
+
+    def save(self, path):
+        # Get write function
+        writeindex = write_index_binary if self.qbits else write_index
+
+        # Write index
+        writeindex(self.backend, path)
+
+    def configure(self, count, train):
+        """
+        Configures settings for a new index.
+
+        Args:
+            count: initial number of embeddings rows
+            train: number of rows selected for model training
+
+        Returns:
+            user-specified or generated components setting
+        """
+
+        # Lookup components setting
+        components = self.setting("components")
+
+        if components:
+            # Format and return components string
+            return self.components(components, train)
+
+        # Derive quantization. Prefer backend-specific setting. Fallback to root-level parameter.
+        quantize = self.setting("quantize", self.config.get("quantize"))
+        quantize = 8 if isinstance(quantize, bool) else quantize
+
+        # Get storage setting
+        storage = f"SQ{quantize}" if quantize else "Flat"
+
+        # Small index, use storage directly with IDMap
+        if count <= 5000:
+            return "BFlat" if self.qbits else f"IDMap,{storage}"
+
+        x = self.cells(train)
+        components = f"BIVF{x}" if self.qbits else f"IVF{x},{storage}"
+
+        return components
+
+    def create(self, embeddings, params):
+        """
+        Creates a new index.
+
+        Args:
+            embeddings: embeddings to index
+            params: index parameters
+
+        Returns:
+            new index
+        """
+
+        # Create binary index
+        if self.qbits:
+            index = index_binary_factory(embeddings.shape[1] * 8, params)
+
+            # Wrap with BinaryIDMap, if necessary
+            if any(x in params for x in ["BFlat", "BHNSW"]):
+                index = IndexBinaryIDMap(index)
+
+            return index
+
+        # Create standard float index
+        return index_factory(embeddings.shape[1], params, METRIC_INNER_PRODUCT)
+
+    def cells(self, count):
+        """
+        Calculates the number of IVF cells for an IVF index.
+
+        Args:
+            count: number of embeddings rows
+
+        Returns:
+            number of IVF cells
+        """
+
+        # Calculate number of IVF cells where x = min(4 * sqrt(embeddings count), embeddings count / 39)
+        # Faiss requires at least 39 points per cluster
+        return max(min(round(4 * math.sqrt(count)), int(count / 39)), 1)
+
+    def components(self, components, train):
+        """
+        Formats a components string. This method automatically calculates the optimal number of IVF cells, if omitted.
+
+        Args:
+            components: input components string
+            train: number of rows selected for model training
+
+        Returns:
+            formatted components string
+        """
+
+        # Optimal number of IVF cells
+        x = self.cells(train)
+
+        # Add number of IVF cells, if missing
+        components = [f"IVF{x}" if component == "IVF" else component for component in components.split(",")]
+
+        # Return components string
+        return ",".join(components)
+
+    def nprobe(self):
+        """
+        Gets or derives the nprobe search parameter.
+
+        Returns:
+            nprobe setting
+        """
+
+        # Get size of embeddings index
+        count = self.count()
+
+        default = 6 if count <= 5000 else round(self.cells(count) / 16)
+        return self.setting("nprobe", default)
+
+    def scores(self, scores):
+        """
+        Calculates the index score from the input score. This method returns the hamming score
+        (1.0 - (hamming distance / total number of bits)) for binary indexes and the input
+        scores otherwise.
+
+        Args:
+            scores: input scores
+
+        Returns:
+            index scores
+        """
+
+        # Calculate hamming score, bound between 0.0 - 1.0
+        if self.qbits:
+            return np.clip(1.0 - (scores / (self.config["dimensions"] * 8)), 0.0, 1.0).tolist()
+
+        # Standard scoring
+        return scores.tolist()
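The `configure`/`cells` logic above is easiest to follow with a worked example. This sketch is not package code; it only replays the arithmetic shown in `cells` for a hypothetical 100,000-row index.

import math

count = 100000  # embeddings rows, all used for training here

# cells: min(4 * sqrt(count), count / 39), at least 1
x = max(min(round(4 * math.sqrt(count)), int(count / 39)), 1)
assert x == 1265

# Derived components: "IVF1265,Flat" unquantized, "IVF1265,SQ8" with
# quantize=True (bool maps to 8 bits), and a binary "BIVF1265" when
# quantize is a non-bool int (qbits set). Indexes of 5000 rows or fewer
# skip IVF entirely ("IDMap,Flat" / "BFlat").
print(f"IVF{x},Flat")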
txtai/ann/dense/hnsw.py
ADDED
@@ -0,0 +1,104 @@
+"""
+HNSW module
+"""
+
+import numpy as np
+
+# Conditional import
+try:
+    # pylint: disable=E0611
+    from hnswlib import Index
+
+    HNSWLIB = True
+except ImportError:
+    HNSWLIB = False
+
+from ..base import ANN
+
+
+class HNSW(ANN):
+    """
+    Builds an ANN index using the hnswlib library.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        if not HNSWLIB:
+            raise ImportError('HNSW is not available - install "ann" extra to enable')
+
+    def load(self, path):
+        # Load index
+        self.backend = Index(dim=self.config["dimensions"], space=self.config["metric"])
+        self.backend.load_index(path)
+
+    def index(self, embeddings):
+        # Inner product is equal to cosine similarity on normalized vectors
+        self.config["metric"] = "ip"
+
+        # Lookup index settings
+        efconstruction = self.setting("efconstruction", 200)
+        m = self.setting("m", 16)
+        seed = self.setting("randomseed", 100)
+
+        # Create index
+        self.backend = Index(dim=self.config["dimensions"], space=self.config["metric"])
+        self.backend.init_index(max_elements=embeddings.shape[0], ef_construction=efconstruction, M=m, random_seed=seed)
+
+        # Add items - position in embeddings is used as the id
+        self.backend.add_items(embeddings, np.arange(embeddings.shape[0], dtype=np.int64))
+
+        # Add id offset, delete counter and index build metadata
+        self.config["offset"] = embeddings.shape[0]
+        self.config["deletes"] = 0
+        self.metadata({"efconstruction": efconstruction, "m": m, "seed": seed})
+
+    def append(self, embeddings):
+        new = embeddings.shape[0]
+
+        # Resize index
+        self.backend.resize_index(self.config["offset"] + new)
+
+        # Append new ids - position in embeddings + existing offset is used as the id
+        self.backend.add_items(embeddings, np.arange(self.config["offset"], self.config["offset"] + new, dtype=np.int64))
+
+        # Update id offset and index metadata
+        self.config["offset"] += new
+        self.metadata()
+
+    def delete(self, ids):
+        # Mark elements as deleted to omit from search results
+        for uid in ids:
+            try:
+                self.backend.mark_deleted(uid)
+                self.config["deletes"] += 1
+            except RuntimeError:
+                # Ignore label not found error
+                continue
+
+    def search(self, queries, limit):
+        # Set ef query param
+        ef = self.setting("efsearch")
+        if ef:
+            self.backend.set_ef(ef)
+
+        # Run the query
+        ids, distances = self.backend.knn_query(queries, k=limit)
+
+        # Map results to [(id, score)]
+        results = []
+        for x, distance in enumerate(distances):
+            # Convert distances to similarity scores
+            scores = [1 - d for d in distance.tolist()]
+
+            # Build (id, score) tuples, convert np.int64 to python int
+            results.append(list(zip(ids[x].tolist(), scores)))
+
+        return results
+
+    def count(self):
+        return self.backend.get_current_count() - self.config["deletes"]
+
+    def save(self, path):
+        # Write index
+        self.backend.save_index(path)
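Unlike Annoy, this backend supports incremental updates, and deletes are soft: elements are only marked deleted, which is why `count` subtracts the delete counter. A usage sketch with illustrative values (not package code):

import numpy as np

from txtai.ann.dense.hnsw import HNSW

# Random normalized vectors
data = np.random.rand(500, 32).astype(np.float32)
data /= np.linalg.norm(data, axis=1, keepdims=True)

# The "hnsw" key holds backend-specific settings read via setting()
ann = HNSW({"backend": "hnsw", "dimensions": 32, "hnsw": {"efconstruction": 100, "m": 8}})
ann.index(data)

# Elements are only marked deleted; count() subtracts the delete counter
ann.delete([0, 1])
assert ann.count() == 498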