mseep_txtai-9.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/data/texts.py
ADDED
@@ -0,0 +1,68 @@
"""
Texts module
"""

from itertools import chain

from .base import Data


class Texts(Data):
    """
    Tokenizes text datasets as input for training language models.
    """

    def __init__(self, tokenizer, columns, maxlength):
        """
        Creates a new instance for tokenizing Texts training data.

        Args:
            tokenizer: model tokenizer
            columns: tuple of columns to use for text
            maxlength: maximum sequence length
        """

        super().__init__(tokenizer, columns, maxlength)

        # Standardize columns
        if not self.columns:
            self.columns = ("text", None)

    def process(self, data):
        # Column keys
        text1, text2 = self.columns

        # Tokenizer inputs can be single string or string pair, depending on task
        text = (data[text1], data[text2]) if text2 else (data[text1],)

        # Tokenize text and add label
        inputs = self.tokenizer(*text, return_special_tokens_mask=True)

        # Concat and return tokenized inputs
        return self.concat(inputs)

    def concat(self, inputs):
        """
        Concatenates tokenized text into chunks of maxlength.

        Args:
            inputs: tokenized input

        Returns:
            Chunks of tokenized text each with a size of maxlength
        """

        # Concatenate tokenized text
        concat = {k: list(chain(*inputs[k])) for k in inputs.keys()}

        # Calculate total length
        length = len(concat[list(inputs.keys())[0]])

        # Ensure total is multiple of maxlength, drop last incomplete chunk
        if length >= self.maxlength:
            length = (length // self.maxlength) * self.maxlength

        # Split into chunks of maxlength
        result = {k: [v[x : x + self.maxlength] for x in range(0, length, self.maxlength)] for k, v in concat.items()}

        return result
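The chunking in `Texts.concat` can be traced with a small standalone sketch using toy token ids instead of a real tokenizer (the values and `maxlength` below are illustrative): batches are flattened, the total is truncated to a multiple of `maxlength`, and the incomplete tail is dropped.

```python
# Minimal sketch of the Texts.concat chunking logic with toy data (not part of the package)
from itertools import chain

inputs = {"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8, 9]]}  # three tokenized texts
maxlength = 4

# Flatten batches into one long sequence per key
concat = {k: list(chain(*v)) for k, v in inputs.items()}

# Truncate to a multiple of maxlength, dropping the incomplete tail
length = len(concat["input_ids"])
if length >= maxlength:
    length = (length // maxlength) * maxlength

# Split into fixed-size chunks
chunks = {k: [v[x : x + maxlength] for x in range(0, length, maxlength)] for k, v in concat.items()}
print(chunks)  # {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]]}
```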
txtai/data/tokens.py
ADDED
@@ -0,0 +1,28 @@
"""
Tokens module
"""

import torch


class Tokens(torch.utils.data.Dataset):
    """
    Default dataset used to hold tokenized data.
    """

    def __init__(self, columns):
        self.data = []

        # Map column-oriented data to rows
        for column in columns:
            for x, value in enumerate(columns[column]):
                if len(self.data) <= x:
                    self.data.append({})

                self.data[x][column] = value

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]
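The column-to-row mapping in `Tokens.__init__` can be illustrated with a toy batch; the column names below mimic typical tokenizer output and are not taken from the package.

```python
# Toy illustration of mapping column-oriented tokenizer output to per-row dicts,
# mirroring the loop in Tokens.__init__ (values are made up)
columns = {
    "input_ids": [[1, 2, 3], [4, 5]],
    "attention_mask": [[1, 1, 1], [1, 1]],
}

data = []
for column in columns:
    for x, value in enumerate(columns[column]):
        if len(data) <= x:
            data.append({})
        data[x][column] = value

print(len(data))  # 2 rows, one per original text
print(data[0])    # {'input_ids': [1, 2, 3], 'attention_mask': [1, 1, 1]}
```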
txtai/database/__init__.py
ADDED
@@ -0,0 +1,14 @@
"""
Database imports
"""

from .base import Database
from .client import Client
from .duckdb import DuckDB
from .embedded import Embedded
from .encoder import *
from .factory import DatabaseFactory
from .rdbms import RDBMS
from .schema import *
from .sqlite import SQLite
from .sql import *
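These exports are normally consumed through `DatabaseFactory`, which picks a backend from the `content` setting. A minimal sketch, assuming the standard txtai factory behavior; the configuration value is illustrative and other settings (columns, functions, etc.) are omitted.

```python
# Hedged sketch: create a content database via the factory; "sqlite" is an
# assumed/illustrative backend keyword, not taken from this diff
from txtai.database import DatabaseFactory

database = DatabaseFactory.create({"content": "sqlite"})
print(type(database).__name__ if database else None)
```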
txtai/database/base.py
ADDED
@@ -0,0 +1,342 @@
"""
Database module
"""

import logging
import types

from .encoder import EncoderFactory
from .sql import SQL, SQLError, Token

# Logging configuration
logger = logging.getLogger(__name__)


class Database:
    """
    Base class for database instances. This class encapsulates a content database used for
    storing field content as dicts and objects. The database instance works in conjunction
    with a vector index to execute SQL-driven similarity search.
    """

    def __init__(self, config):
        """
        Creates a new Database.

        Args:
            config: database configuration
        """

        # Initialize configuration
        self.configure(config)

    def load(self, path):
        """
        Loads a database path.

        Args:
            path: database url
        """

        raise NotImplementedError

    def insert(self, documents, index=0):
        """
        Inserts documents into the database.

        Args:
            documents: list of documents to save
            index: indexid offset, used for internal ids
        """

        raise NotImplementedError

    def delete(self, ids):
        """
        Deletes documents from database.

        Args:
            ids: ids to delete
        """

        raise NotImplementedError

    def reindex(self, config):
        """
        Reindexes internal database content and streams results back. This method must renumber indexids
        sequentially as deletes could have caused indexid gaps.

        Args:
            config: new configuration
        """

        raise NotImplementedError

    def save(self, path):
        """
        Saves a database at path.

        Args:
            path: path to write database
        """

        raise NotImplementedError

    def close(self):
        """
        Closes this database.
        """

        raise NotImplementedError

    def ids(self, ids):
        """
        Retrieves the internal indexids for a list of ids. Multiple indexids may be present for an id in cases
        where data is segmented.

        Args:
            ids: list of document ids

        Returns:
            list of (indexid, id)
        """

        raise NotImplementedError

    def count(self):
        """
        Retrieves the count of this database instance.

        Returns:
            total database count
        """

        raise NotImplementedError

    def search(self, query, similarity=None, limit=None, parameters=None, indexids=False):
        """
        Runs a search against the database. Supports the following methods:

          1. Standard similarity query. This mode retrieves content for the ids in the similarity results
          2. Similarity query as SQL. This mode will combine similarity results and database results into
             a single result set. Similarity queries are set via the SIMILAR() function.
          3. SQL with no similarity query. This mode runs a SQL query and retrieves the results without similarity queries.

        Example queries:
            "natural language processing" - standard similarity only query
            "select * from txtai where similar('natural language processing')" - similarity query as SQL
            "select * from txtai where similar('nlp') and entry > '2021-01-01'" - similarity query with additional SQL clauses
            "select id, text, score from txtai where similar('nlp')" - similarity query with additional SQL column selections
            "select * from txtai where entry > '2021-01-01'" - database only query

        Args:
            query: input query
            similarity: similarity results as [(indexid, score)]
            limit: maximum number of results to return
            parameters: dict of named parameters to bind to placeholders

        Returns:
            query results as a list of dicts
            list of ([indexid, score]) if indexids is True
        """

        # Parse query if necessary
        if isinstance(query, str):
            query = self.parse(query)

        # Add in similar results
        where = query.get("where")

        if "select" in query and similarity:
            for x in range(len(similarity)):
                token = f"{Token.SIMILAR_TOKEN}{x}"
                if where and token in where:
                    where = where.replace(token, self.embed(similarity, x))

        elif similarity:
            # Not a SQL query, load similarity results, if any
            where = self.embed(similarity, 0)

        # Save where
        query["where"] = where

        # Run query
        return self.query(query, limit, parameters, indexids)

    def parse(self, query):
        """
        Parses a query into query components.

        Args:
            query: input query

        Returns:
            dict of parsed query components
        """

        return self.sql(query)

    def resolve(self, name, alias=None):
        """
        Resolves a query column name with the database column name. This method also builds alias expressions
        if alias is set.

        Args:
            name: query column name
            alias: alias name, defaults to None

        Returns:
            database column name
        """

        raise NotImplementedError

    def embed(self, similarity, batch):
        """
        Embeds similarity query results into a database query.

        Args:
            similarity: similarity results as [(indexid, score)]
            batch: batch id
        """

        raise NotImplementedError

    def query(self, query, limit, parameters, indexids):
        """
        Executes query against database.

        Args:
            query: input query
            limit: maximum number of results to return
            parameters: dict of named parameters to bind to placeholders
            indexids: results are returned as [(indexid, score)] regardless of select clause parameters if True

        Returns:
            query results
        """

        raise NotImplementedError

    def configure(self, config):
        """
        Initialize configuration.

        Args:
            config: configuration
        """

        # Database configuration
        self.config = config

        # SQL parser
        self.sql = SQL(self)

        # Load objects encoder
        encoder = self.config.get("objects")
        self.encoder = EncoderFactory.create(encoder) if encoder else None

        # Transform columns
        columns = config.get("columns", {})
        self.text = columns.get("text", "text")
        self.object = columns.get("object", "object")

        # Custom functions and expressions
        self.functions, self.expressions = None, None

        # Load custom functions
        self.registerfunctions(self.config)

        # Load custom expressions
        self.registerexpressions(self.config)

    def registerfunctions(self, config):
        """
        Register custom functions. This method stores the function details for underlying
        database implementations to handle.

        Args:
            config: database configuration
        """

        inputs = config.get("functions") if config else None
        if inputs:
            functions = []
            for fn in inputs:
                name, argcount = None, -1

                # Optional function configuration
                if isinstance(fn, dict):
                    name, argcount, fn = fn.get("name"), fn.get("argcount", -1), fn["function"]

                # Determine if this is a callable object or a function
                if not isinstance(fn, types.FunctionType) and hasattr(fn, "__call__"):
                    name = name if name else fn.__class__.__name__.lower()
                    fn = fn.__call__
                else:
                    name = name if name else fn.__name__.lower()

                # Store function details
                functions.append((name, argcount, fn))

            # pylint: disable=W0201
            self.functions = functions

    def registerexpressions(self, config):
        """
        Register custom expressions. This method parses and resolves expressions for later use in SQL queries.

        Args:
            config: database configuration
        """

        inputs = config.get("expressions") if config else None
        if inputs:
            expressions = {}
            for entry in inputs:
                name = entry.get("name")
                expression = entry.get("expression")
                if name and expression:
                    expressions[name] = self.sql.snippet(expression)

            # pylint: disable=W0201
            self.expressions = expressions

    def execute(self, function, *args):
        """
        Executes a user query. This method has common error handling logic.

        Args:
            function: database execute function
            args: function arguments

        Returns:
            result of function(args)
        """

        try:
            # Debug log SQL
            logger.debug(" ".join(["%s"] * len(args)), *args)

            return function(*args)
        except Exception as e:
            raise SQLError(e) from None

    def setting(self, name, default=None):
        """
        Looks up database specific setting.

        Args:
            name: setting name
            default: default value when setting not found

        Returns:
            setting value
        """

        # Get the database-specific config object
        database = self.config.get(self.config["content"])

        # Get setting value, set default value if not found
        setting = database.get(name) if database else None
        return setting if setting else default
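In practice, the three search modes described in `Database.search` are driven through the `Embeddings` API with content storage enabled rather than by calling the database class directly. A usage sketch assuming the standard txtai interface; the model path and sample data are illustrative.

```python
# Hedged usage sketch for the similar()/SQL search modes documented above
from txtai import Embeddings

# content=True enables the document database alongside the vector index
embeddings = Embeddings(path="sentence-transformers/all-MiniLM-L6-v2", content=True)
embeddings.index([(0, {"text": "Natural language processing", "entry": "2024-01-01"}, None)])

# 1. Standard similarity query
print(embeddings.search("natural language processing"))

# 2. Similarity query as SQL via similar()
print(embeddings.search("select id, text, score from txtai where similar('nlp')"))

# 3. SQL-only query, no similarity clause
print(embeddings.search("select * from txtai where entry >= '2024-01-01'"))
```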
txtai/database/client.py
ADDED
@@ -0,0 +1,227 @@
"""
Client module
"""

import os
import time

# Conditional import
try:
    from sqlalchemy import StaticPool, Text, cast, create_engine, insert, text as textsql
    from sqlalchemy.orm import Session, aliased
    from sqlalchemy.schema import CreateSchema

    from .schema import Base, Batch, Document, Object, Section, SectionBase, Score

    ORM = True
except ImportError:
    ORM = False

from .rdbms import RDBMS


class Client(RDBMS):
    """
    Database client instance. This class connects to an external database using SQLAlchemy. It supports any database
    that is supported by SQLAlchemy (PostgreSQL, MariaDB, etc) and has JSON support.
    """

    def __init__(self, config):
        """
        Creates a new Database.

        Args:
            config: database configuration parameters
        """

        super().__init__(config)

        if not ORM:
            raise ImportError('SQLAlchemy is not available - install "database" extra to enable')

        # SQLAlchemy parameters
        self.engine, self.dbconnection = None, None

    def save(self, path):
        # Commit session and database connection
        super().save(path)

        if self.dbconnection:
            self.dbconnection.commit()

    def close(self):
        super().close()

        # Dispose of engine, which also closes dbconnection
        if self.engine:
            self.engine.dispose()

    def reindexstart(self):
        # Working table name
        name = f"rebuild{round(time.time() * 1000)}"

        # Create working table metadata
        type("Rebuild", (SectionBase,), {"__tablename__": name})
        Base.metadata.tables[name].create(self.dbconnection)

        return name

    def reindexend(self, name):
        # Remove table object from metadata
        Base.metadata.remove(Base.metadata.tables[name])

    def jsonprefix(self):
        # JSON column prefix
        return "cast("

    def jsoncolumn(self, name):
        # Alias documents table
        d = aliased(Document, name="d")

        # Build JSON column expression for column
        return str(cast(d.data[name].as_string(), Text).compile(dialect=self.engine.dialect, compile_kwargs={"literal_binds": True}))

    def createtables(self):
        # Create tables
        Base.metadata.create_all(self.dbconnection, checkfirst=True)

        # Clear existing data - table schema is created upon connecting to database
        for table in ["sections", "documents", "objects"]:
            self.cursor.execute(f"DELETE FROM {table}")

    def finalize(self):
        # Flush cached objects
        self.connection.flush()

    def insertdocument(self, uid, data, tags, entry):
        self.connection.add(Document(id=uid, data=data, tags=tags, entry=entry))

    def insertobject(self, uid, data, tags, entry):
        self.connection.add(Object(id=uid, object=data, tags=tags, entry=entry))

    def insertsection(self, index, uid, text, tags, entry):
        # Save text section
        self.connection.add(Section(indexid=index, id=uid, text=text, tags=tags, entry=entry))

    def createbatch(self):
        # Create temporary batch table, if necessary
        Base.metadata.tables["batch"].create(self.dbconnection, checkfirst=True)

    def insertbatch(self, indexids, ids, batch):
        if indexids:
            self.connection.execute(insert(Batch), [{"indexid": i, "batch": batch} for i in indexids])
        if ids:
            self.connection.execute(insert(Batch), [{"id": str(uid), "batch": batch} for uid in ids])

    def createscores(self):
        # Create temporary scores table, if necessary
        Base.metadata.tables["scores"].create(self.dbconnection, checkfirst=True)

    def insertscores(self, scores):
        # Average scores by id
        if scores:
            self.connection.execute(insert(Score), [{"indexid": i, "score": sum(s) / len(s)} for i, s in scores.items()])

    def connect(self, path=None):
        # Connection URL
        content = self.config.get("content")

        # Read ENV variable, if necessary
        content = os.environ.get("CLIENT_URL") if content == "client" else content

        # Create engine using database URL
        self.engine = create_engine(content, poolclass=StaticPool, echo=False, json_serializer=lambda x: x)
        self.dbconnection = self.engine.connect()

        # Create database session
        database = Session(self.dbconnection)

        # Set default schema, if necessary
        schema = self.config.get("schema")
        if schema:
            with self.engine.begin():
                self.sqldialect(database, CreateSchema(schema, if_not_exists=True))

            self.sqldialect(database, textsql("SET search_path TO :schema"), {"schema": schema})

        return database

    def getcursor(self):
        return Cursor(self.connection)

    def rows(self):
        return self.cursor

    def addfunctions(self):
        return

    def sqldialect(self, database, sql, parameters=None):
        """
        Executes a SQL statement based on the current SQL dialect.

        Args:
            database: current database
            sql: SQL to execute
            parameters: optional bind parameters
        """

        args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (textsql("SELECT 1"),)
        database.execute(*args)


class Cursor:
    """
    Implements basic compatibility with the Python DB-API.
    """

    def __init__(self, connection):
        self.connection = connection
        self.result = None

    def __iter__(self):
        return self.result

    def execute(self, statement, parameters=None):
        """
        Executes statement.

        Args:
            statement: statement to execute
            parameters: optional dictionary with bind parameters
        """

        if isinstance(statement, str):
            statement = textsql(statement)

        self.result = self.connection.execute(statement, parameters)

    def fetchall(self):
        """
        Fetches all rows from the current result.

        Returns:
            all rows from current result
        """

        return self.result.all() if self.result else None

    def fetchone(self):
        """
        Fetches first row from current result.

        Returns:
            first row from current result
        """

        return self.result.first() if self.result else None

    @property
    def description(self):
        """
        Returns columns for current result.

        Returns:
            list of columns
        """

        return [(key,) for key in self.result.keys()] if self.result else None
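Based on `connect()` above, the client backend is selected by pointing the `content` setting at a SQLAlchemy URL, or at the literal string `client` with the URL read from the `CLIENT_URL` environment variable; an optional `schema` setting drives the `search_path` logic. A configuration sketch with placeholder URL, credentials and schema name, assuming the standard `Embeddings` configuration surface.

```python
# Hedged configuration sketch for the Client database backend
import os

from txtai import Embeddings

# Option 1: pass the SQLAlchemy URL directly as the content setting
embeddings = Embeddings(content="postgresql+psycopg2://user:pass@localhost/txtai")

# Option 2: use the "client" keyword and resolve the URL from CLIENT_URL;
# "schema" is assumed to map to the schema lookup in connect() above
os.environ["CLIENT_URL"] = "postgresql+psycopg2://user:pass@localhost/txtai"
embeddings = Embeddings(content="client", schema="txtai")
```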