mseep-txtai 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/api/base.py
ADDED
@@ -0,0 +1,159 @@
"""
API module
"""

import json

from .cluster import Cluster

from ..app import Application


class API(Application):
    """
    Base API template. The API is an extended txtai application, adding the ability to cluster API instances together.

    Downstream applications can extend this base template to add/modify functionality.
    """

    def __init__(self, config, loaddata=True):
        super().__init__(config, loaddata)

        # Embeddings cluster
        self.cluster = None
        if self.config.get("cluster"):
            self.cluster = Cluster(self.config["cluster"])

    # pylint: disable=W0221
    def search(self, query, limit=None, weights=None, index=None, parameters=None, graph=False, request=None):
        # When search is invoked via the API, limit is set from the request
        # When search is invoked directly, limit is set using the method parameter
        limit = self.limit(request.query_params.get("limit") if request and hasattr(request, "query_params") else limit)
        weights = self.weights(request.query_params.get("weights") if request and hasattr(request, "query_params") else weights)
        index = request.query_params.get("index") if request and hasattr(request, "query_params") else index
        parameters = request.query_params.get("parameters") if request and hasattr(request, "query_params") else parameters
        graph = request.query_params.get("graph") if request and hasattr(request, "query_params") else graph

        # Decode parameters
        parameters = json.loads(parameters) if parameters and isinstance(parameters, str) else parameters

        if self.cluster:
            return self.cluster.search(query, limit, weights, index, parameters, graph)

        return super().search(query, limit, weights, index, parameters, graph)

    def batchsearch(self, queries, limit=None, weights=None, index=None, parameters=None, graph=False):
        if self.cluster:
            return self.cluster.batchsearch(queries, self.limit(limit), weights, index, parameters, graph)

        return super().batchsearch(queries, limit, weights, index, parameters, graph)

    def add(self, documents):
        """
        Adds a batch of documents for indexing.

        Downstream applications can override this method to also store full documents in an external system.

        Args:
            documents: list of {id: value, text: value}

        Returns:
            unmodified input documents
        """

        if self.cluster:
            self.cluster.add(documents)
        else:
            super().add(documents)

        return documents

    def index(self):
        """
        Builds an embeddings index for previously batched documents.
        """

        if self.cluster:
            self.cluster.index()
        else:
            super().index()

    def upsert(self):
        """
        Runs an embeddings upsert operation for previously batched documents.
        """

        if self.cluster:
            self.cluster.upsert()
        else:
            super().upsert()

    def delete(self, ids):
        """
        Deletes from an embeddings index. Returns list of ids deleted.

        Args:
            ids: list of ids to delete

        Returns:
            ids deleted
        """

        if self.cluster:
            return self.cluster.delete(ids)

        return super().delete(ids)

    def reindex(self, config, function=None):
        """
        Recreates this embeddings index using config. This method only works if document content storage is enabled.

        Args:
            config: new config
            function: optional function to prepare content for indexing
        """

        if self.cluster:
            self.cluster.reindex(config, function)
        else:
            super().reindex(config, function)

    def count(self):
        """
        Total number of elements in this embeddings index.

        Returns:
            number of elements in embeddings index
        """

        if self.cluster:
            return self.cluster.count()

        return super().count()

    def limit(self, limit):
        """
        Parses the number of results to return from the request. Allows range of 1-250, with a default of 10.

        Args:
            limit: limit parameter

        Returns:
            bounded limit
        """

        # Return between 1 and 250 results, defaults to 10
        return max(1, min(250, int(limit) if limit else 10))

    def weights(self, weights):
        """
        Parses the weights parameter from the request.

        Args:
            weights: weights parameter

        Returns:
            weights
        """

        return float(weights) if weights else weights
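The `add` docstring above notes that downstream applications can override it to also store full documents in an external system. A minimal sketch of such a subclass (the class name and the in-memory `store` dict are illustrative, not part of this package):

```python
from txtai.api import API

class CustomAPI(API):
    """
    Example downstream API that also keeps full documents outside the embeddings index.
    """

    def __init__(self, config, loaddata=True):
        super().__init__(config, loaddata)
        self.store = {}

    def add(self, documents):
        documents = list(documents)

        # Keep a full copy of each document in a simple dict store
        for document in documents:
            if isinstance(document, dict) and "id" in document:
                self.store[document["id"]] = document

        # Delegate indexing to the base API, which routes to a cluster or local instance
        return super().add(documents)
```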
txtai/api/cluster.py
ADDED
@@ -0,0 +1,295 @@
"""
Cluster module
"""

import asyncio
import json
import random
import urllib.parse
import zlib

import aiohttp

from ..database.sql import Aggregate


class Cluster:
    """
    Aggregates multiple embeddings shards into a single logical embeddings instance.
    """

    # pylint: disable = W0231
    def __init__(self, config=None):
        """
        Creates a new Cluster.

        Args:
            config: cluster configuration
        """

        # Configuration
        self.config = config

        # Embeddings shard urls
        self.shards = None
        if "shards" in self.config:
            self.shards = self.config["shards"]

        # Query aggregator
        self.aggregate = Aggregate()

    def search(self, query, limit=None, weights=None, index=None, parameters=None, graph=False):
        """
        Finds documents most similar to the input query. This method will run either an index search
        or an index + database search depending on if a database is available.

        Args:
            query: input query
            limit: maximum results
            weights: hybrid score weights, if applicable
            index: index name, if applicable
            parameters: dict of named parameters to bind to placeholders
            graph: return graph results if True

        Returns:
            list of {id: value, score: value} for index search, list of dict for an index + database search
        """

        # Build URL
        action = f"search?query={urllib.parse.quote_plus(query)}"
        if limit:
            action += f"&limit={limit}"
        if weights:
            action += f"&weights={weights}"
        if index:
            action += f"&index={index}"
        if parameters:
            action += f"&parameters={json.dumps(parameters) if isinstance(parameters, dict) else parameters}"
        if graph is not None:
            action += f"&graph={graph}"

        # Run query and flatten results into single results list
        results = []
        for result in self.execute("get", action):
            results.extend(result)

        # Combine aggregate functions and sort
        results = self.aggregate(query, results)

        # Limit results
        return results[: (limit if limit else 10)]

    def batchsearch(self, queries, limit=None, weights=None, index=None, parameters=None, graph=False):
        """
        Finds documents most similar to the input queries. This method will run either an index search
        or an index + database search depending on if a database is available.

        Args:
            queries: input queries
            limit: maximum results
            weights: hybrid score weights, if applicable
            index: index name, if applicable
            parameters: list of dicts of named parameters to bind to placeholders
            graph: return graph results if True

        Returns:
            list of {id: value, score: value} per query for index search, list of dict per query for an index + database search
        """

        # POST parameters
        params = {"queries": queries}
        if limit:
            params["limit"] = limit
        if weights:
            params["weights"] = weights
        if index:
            params["index"] = index
        if parameters:
            params["parameters"] = parameters
        if graph is not None:
            params["graph"] = graph

        # Run query
        batch = self.execute("post", "batchsearch", [params] * len(self.shards))

        # Combine results per query
        results = []
        for x, query in enumerate(queries):
            result = []
            for section in batch:
                result.extend(section[x])

            # Aggregate, sort and limit results
            results.append(self.aggregate(query, result)[: (limit if limit else 10)])

        return results

    def add(self, documents):
        """
        Adds a batch of documents for indexing.

        Args:
            documents: list of {id: value, text: value}
        """

        self.execute("post", "add", self.shard(documents))

    def index(self):
        """
        Builds an embeddings index for previously batched documents.
        """

        self.execute("get", "index")

    def upsert(self):
        """
        Runs an embeddings upsert operation for previously batched documents.
        """

        self.execute("get", "upsert")

    def delete(self, ids):
        """
        Deletes from an embeddings cluster. Returns list of ids deleted.

        Args:
            ids: list of ids to delete

        Returns:
            ids deleted
        """

        return [uid for ids in self.execute("post", "delete", [ids] * len(self.shards)) for uid in ids]

    def reindex(self, config, function=None):
        """
        Recreates this embeddings index using config. This method only works if document content storage is enabled.

        Args:
            config: new config
            function: optional function to prepare content for indexing
        """

        self.execute("post", "reindex", [{"config": config, "function": function}] * len(self.shards))

    def count(self):
        """
        Total number of elements in this embeddings cluster.

        Returns:
            number of elements in embeddings cluster
        """

        return sum(self.execute("get", "count"))

    def shard(self, documents):
        """
        Splits documents into equal sized shards.

        Args:
            documents: input documents

        Returns:
            list of evenly sized shards with the last shard having the remaining elements
        """

        shards = [[] for _ in range(len(self.shards))]
        for document in documents:
            uid = document.get("id") if isinstance(document, dict) else document
            if uid and isinstance(uid, str):
                # Quick int hash of string to help derive shard id
                uid = zlib.adler32(uid.encode("utf-8"))
            elif uid is None:
                # Get random shard id when uid isn't set
                uid = random.randint(0, len(shards) - 1)

            shards[uid % len(self.shards)].append(document)

        return shards

    def execute(self, method, action, data=None):
        """
        Executes a HTTP action asynchronously.

        Args:
            method: get or post
            action: url action to perform
            data: post parameters

        Returns:
            json results if any
        """

        # Get urls
        urls = [f"{shard}/{action}" for shard in self.shards]
        close = False

        # Use existing loop if available, otherwise create one
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            close = True

        try:
            return loop.run_until_complete(self.run(urls, method, data))
        finally:
            # Close loop if it was created in this method
            if close:
                loop.close()

    async def run(self, urls, method, data):
        """
        Runs an async action.

        Args:
            urls: run against this list of urls
            method: get or post
            data: list of data for each url or None

        Returns:
            json results if any
        """

        async with aiohttp.ClientSession(raise_for_status=True) as session:
            tasks = []

            for x, url in enumerate(urls):
                if method == "post":
                    if not data or data[x]:
                        tasks.append(asyncio.ensure_future(self.post(session, url, data[x] if data else None)))
                else:
                    tasks.append(asyncio.ensure_future(self.get(session, url)))

            return await asyncio.gather(*tasks)

    async def get(self, session, url):
        """
        Runs an async HTTP GET request.

        Args:
            session: ClientSession
            url: url

        Returns:
            json results if any
        """

        async with session.get(url) as resp:
            return await resp.json()

    async def post(self, session, url, data):
        """
        Runs an async HTTP POST request.

        Args:
            session: ClientSession
            url: url
            data: data to POST

        Returns:
            json results if any
        """

        async with session.post(url, json=data) as resp:
            return await resp.json()
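As a quick illustration of how `Cluster.shard` routes documents, string ids are hashed with `zlib.adler32` and reduced modulo the shard count, while integer ids are used directly. A standalone sketch (the shard URLs are placeholders):

```python
import zlib

# Placeholder shard URLs; any txtai API endpoints could be listed here
shards = ["http://127.0.0.1:8001", "http://127.0.0.1:8002"]

def shard_for(uid):
    # Mirrors Cluster.shard: hash string ids, use integer ids as-is
    key = zlib.adler32(uid.encode("utf-8")) if isinstance(uid, str) else uid
    return key % len(shards)

print(shard_for("doc-1"), shard_for("doc-2"), shard_for(42))
```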
txtai/api/extension.py
ADDED
@@ -0,0 +1,19 @@
"""
Extension module
"""


class Extension:
    """
    Defines an API extension. API extensions can expose custom pipelines or other custom logic.
    """

    def __call__(self, app):
        """
        Hook to register custom routing logic and/or modify the FastAPI instance.

        Args:
            app: FastAPI application instance
        """

        return
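A minimal sketch of a custom extension that registers an extra route when called with the FastAPI app (the class name and the `/healthz` path are hypothetical; this assumes `Extension` is importable from `txtai.api` as in upstream txtai):

```python
from fastapi import FastAPI

from txtai.api import Extension

class HealthExtension(Extension):
    """
    Example extension that adds a route to the FastAPI instance.
    """

    def __call__(self, app: FastAPI):
        @app.get("/healthz")
        def healthz():
            # Simple liveness check
            return {"status": "ok"}
```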
txtai/api/factory.py
ADDED
@@ -0,0 +1,40 @@
"""
API factory module
"""

from ..util import Resolver


class APIFactory:
    """
    API factory. Creates new API instances.
    """

    @staticmethod
    def get(api):
        """
        Gets a new instance of api class.

        Args:
            api: API instance class

        Returns:
            API
        """

        return Resolver()(api)

    @staticmethod
    def create(config, api):
        """
        Creates a new API instance.

        Args:
            config: API configuration
            api: API instance class

        Returns:
            API instance
        """

        return APIFactory.get(api)(config)
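`Resolver` accepts a dotted class path, so the factory can load the base API class or any downstream subclass by name. A small sketch (the path below points at the `txtai/api/base.py` module shown above; substitute a custom subclass path as needed):

```python
from txtai.api.factory import APIFactory

# Resolve the base API class by its dotted path
cls = APIFactory.get("txtai.api.base.API")
print(cls)

# APIFactory.create(config, "txtai.api.base.API") would instantiate it
# with an application configuration dict
```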
txtai/api/responses/factory.py
ADDED
@@ -0,0 +1,30 @@
"""
Factory module
"""

from .json import JSONResponse
from .messagepack import MessagePackResponse


class ResponseFactory:
    """
    Methods to create Response classes.
    """

    @staticmethod
    def create(request):
        """
        Gets a response class for request using the Accept header.

        Args:
            request: request

        Returns:
            response class
        """

        # Get Accept header
        accept = request.headers.get("Accept")

        # Get response class
        return MessagePackResponse if accept == MessagePackResponse.media_type else JSONResponse
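A quick sketch of the selection logic (the request stand-ins below only need a `headers` mapping; they are not real FastAPI requests):

```python
from txtai.api.responses import ResponseFactory

class MsgPackRequest:
    # Stand-in request exposing only a headers mapping
    headers = {"Accept": "application/msgpack"}

class DefaultRequest:
    headers = {}

print(ResponseFactory.create(MsgPackRequest()))  # MessagePackResponse
print(ResponseFactory.create(DefaultRequest()))  # JSONResponse
```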
txtai/api/responses/json.py
ADDED
@@ -0,0 +1,56 @@
"""
JSON module
"""

import base64
import json

from io import BytesIO
from typing import Any

import fastapi.responses

from PIL.Image import Image


class JSONEncoder(json.JSONEncoder):
    """
    Extended JSONEncoder that serializes images and byte streams as base64 encoded text.
    """

    def default(self, o):
        # Convert Image to BytesIO
        if isinstance(o, Image):
            buffered = BytesIO()
            o.save(buffered, format=o.format, quality="keep")
            o = buffered

        # Unpack bytes from BytesIO
        if isinstance(o, BytesIO):
            o = o.getvalue()

        # Base64 encode bytes instances
        if isinstance(o, bytes):
            return base64.b64encode(o).decode("utf-8")

        # Default handler
        return super().default(o)


class JSONResponse(fastapi.responses.JSONResponse):
    """
    Extended JSONResponse that serializes images and byte streams as base64 encoded text.
    """

    def render(self, content: Any) -> bytes:
        """
        Renders content to the response.

        Args:
            content: input content

        Returns:
            rendered content as bytes
        """

        return json.dumps(content, ensure_ascii=False, allow_nan=False, indent=None, separators=(",", ":"), cls=JSONEncoder).encode("utf-8")
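For example, raw `bytes` values, which the standard library encoder rejects, are rendered as base64 text. A minimal sketch:

```python
import json

from txtai.api.responses.json import JSONEncoder

# bytes (and BytesIO/PIL Image values) serialize as base64 text instead of raising TypeError
payload = json.dumps({"blob": b"hello"}, cls=JSONEncoder)
print(payload)  # {"blob": "aGVsbG8="}
```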
txtai/api/responses/messagepack.py
ADDED
@@ -0,0 +1,51 @@
"""
MessagePack module
"""

from io import BytesIO
from typing import Any

import msgpack

from fastapi import Response
from PIL.Image import Image


class MessagePackResponse(Response):
    """
    Encodes responses with MessagePack.
    """

    media_type = "application/msgpack"

    def render(self, content: Any) -> bytes:
        """
        Renders content to the response.

        Args:
            content: input content

        Returns:
            rendered content as bytes
        """

        return msgpack.packb(content, default=MessagePackEncoder())


class MessagePackEncoder:
    """
    Extended MessagePack encoder that converts images to bytes.
    """

    def __call__(self, o):
        # Convert Image to bytes
        if isinstance(o, Image):
            buffered = BytesIO()
            o.save(buffered, format=o.format, quality="keep")
            o = buffered

        # Get bytes from BytesIO
        if isinstance(o, BytesIO):
            o = o.getvalue()

        return o
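On the client side, a caller can opt into MessagePack responses with the `Accept` header and decode them with `msgpack`. A hedged sketch that assumes a txtai API instance is running locally and that its routes negotiate the response format (see `route.py` below):

```python
import msgpack
import requests

# Assumes a txtai API instance serving on localhost:8000
response = requests.get(
    "http://localhost:8000/search",
    params={"query": "feel good story", "limit": 1},
    headers={"Accept": "application/msgpack"},
)

results = msgpack.unpackb(response.content)
print(results)
```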
txtai/api/route.py
ADDED
@@ -0,0 +1,41 @@
"""
Route module
"""

from fastapi.routing import APIRoute, get_request_handler

from .responses import ResponseFactory


class EncodingAPIRoute(APIRoute):
    """
    Extended APIRoute that encodes responses based on HTTP Accept header.
    """

    def get_route_handler(self):
        """
        Resolves a response class based on the HTTP Accept header.

        Returns:
            route handler function
        """

        async def handler(request):
            route = get_request_handler(
                dependant=self.dependant,
                body_field=self.body_field,
                status_code=self.status_code,
                response_class=ResponseFactory.create(request),
                response_field=self.secure_cloned_response_field,
                response_model_include=self.response_model_include,
                response_model_exclude=self.response_model_exclude,
                response_model_by_alias=self.response_model_by_alias,
                response_model_exclude_unset=self.response_model_exclude_unset,
                response_model_exclude_defaults=self.response_model_exclude_defaults,
                response_model_exclude_none=self.response_model_exclude_none,
                dependency_overrides_provider=self.dependency_overrides_provider,
            )

            return await route(request)

        return handler
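A sketch of how a router can opt into this behavior through FastAPI's `route_class` parameter (the `/ping` route is illustrative only):

```python
from fastapi import APIRouter

from txtai.api.route import EncodingAPIRoute

# Routes created through this router negotiate JSON vs MessagePack per request
router = APIRouter(route_class=EncodingAPIRoute)

@router.get("/ping")
def ping():
    return {"status": "ok"}
```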