mseep-txtai 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
@@ -0,0 +1,98 @@
|
|
1
|
+
"""
|
2
|
+
Statement module
|
3
|
+
"""
|
4
|
+
|
5
|
+
|
6
|
+
class Statement:
    """
    Collection of the SQL statement templates used by the standard database schema.

    Templates containing "?" take bound parameters. Templates containing "%s" are
    completed with string formatting before being executed.
    """

    # -- batch: temporary table holding batches of ids --
    CREATE_BATCH = """
        CREATE TEMP TABLE IF NOT EXISTS batch (
            indexid INTEGER,
            id TEXT,
            batch INTEGER
        )
    """

    DELETE_BATCH = "DELETE FROM batch"
    INSERT_BATCH_INDEXID = "INSERT INTO batch (indexid, batch) VALUES (?, ?)"
    INSERT_BATCH_ID = "INSERT INTO batch (id, batch) VALUES (?, ?)"

    # -- scores: temporary table joined against sections to attach similarity scores --
    CREATE_SCORES = """
        CREATE TEMP TABLE IF NOT EXISTS scores (
            indexid INTEGER PRIMARY KEY,
            score REAL
        )
    """

    DELETE_SCORES = "DELETE FROM scores"
    INSERT_SCORE = "INSERT INTO scores VALUES (?, ?)"

    # -- documents: full content, keyed by id --
    CREATE_DOCUMENTS = """
        CREATE TABLE IF NOT EXISTS documents (
            id TEXT PRIMARY KEY,
            data JSON,
            tags TEXT,
            entry DATETIME
        )
    """

    INSERT_DOCUMENT = "INSERT OR REPLACE INTO documents VALUES (?, ?, ?, ?)"
    DELETE_DOCUMENTS = "DELETE FROM documents WHERE id IN (SELECT id FROM batch)"

    # -- objects: binary content, keyed by id --
    CREATE_OBJECTS = """
        CREATE TABLE IF NOT EXISTS objects (
            id TEXT PRIMARY KEY,
            object BLOB,
            tags TEXT,
            entry DATETIME
        )
    """

    INSERT_OBJECT = "INSERT OR REPLACE INTO objects VALUES (?, ?, ?, ?)"
    DELETE_OBJECTS = "DELETE FROM objects WHERE id IN (SELECT id FROM batch)"

    # -- sections: section text, keyed by indexid. The table name is a %s placeholder --
    CREATE_SECTIONS = """
        CREATE TABLE IF NOT EXISTS %s (
            indexid INTEGER PRIMARY KEY,
            id TEXT,
            text TEXT,
            tags TEXT,
            entry DATETIME
        )
    """

    CREATE_SECTIONS_INDEX = "CREATE INDEX section_id ON sections(id)"
    INSERT_SECTION = "INSERT INTO sections VALUES (?, ?, ?, ?, ?)"
    DELETE_SECTIONS = "DELETE FROM sections WHERE id IN (SELECT id FROM batch)"

    # Copies sections into the table named by the first %s. The subquery renumbers indexid
    # as (number of rows with indexid <= this row's indexid) - 1. The second %s supplies the
    # expression selected as the text column.
    COPY_SECTIONS = (
        "INSERT INTO %s SELECT (select count(*) - 1 from sections s1 where s.indexid >= s1.indexid) indexid, "
        "s.id, %s AS text, s.tags, s.entry FROM sections s LEFT JOIN documents d ON s.id = d.id ORDER BY indexid"
    )

    # Reads rows from the table named by %s joined with documents and objects, ordered by indexid
    STREAM_SECTIONS = (
        "SELECT s.id, s.text, data, object, s.tags FROM %s s "
        "LEFT JOIN documents d ON s.id = d.id "
        "LEFT JOIN objects o ON s.id = o.id ORDER BY indexid"
    )

    DROP_SECTIONS = "DROP TABLE sections"
    RENAME_SECTIONS = "ALTER TABLE %s RENAME TO sections"

    # -- lookup queries --
    SELECT_IDS = "SELECT indexid, id FROM sections WHERE id in (SELECT id FROM batch)"
    COUNT_IDS = "SELECT count(indexid) FROM sections"

    # -- partial sql clauses, completed with string formatting --
    TABLE_CLAUSE = (
        "SELECT %s FROM sections s "
        "LEFT JOIN documents d ON s.id = d.id "
        "LEFT JOIN objects o ON s.id = o.id "
        "LEFT JOIN scores sc ON s.indexid = sc.indexid"
    )

    IDS_CLAUSE = "s.indexid in (SELECT indexid from batch WHERE batch=%s)"
|
@@ -0,0 +1,178 @@
|
|
1
|
+
"""
|
2
|
+
Aggregate module
|
3
|
+
"""
|
4
|
+
|
5
|
+
import itertools
|
6
|
+
import operator
|
7
|
+
|
8
|
+
from .base import SQL
|
9
|
+
|
10
|
+
|
11
|
+
class Aggregate(SQL):
    """
    Aggregates partial results from queries. Partial results come from queries when working with sharded indexes.
    """

    def __init__(self, database=None):
        """
        Creates a new aggregate query parser.

        Args:
            database: database instance passed through to the SQL parser, if any
        """

        # Always return token lists as this method requires them
        super().__init__(database, True)

    def __call__(self, query, results):
        """
        Analyzes query results, combines aggregate function results and applies ordering.

        Args:
            query: input query
            results: query results, a list of dicts keyed by column name

        Returns:
            aggregated query results
        """

        # Parse query
        query = super().__call__(query)

        # Check if this is a SQL query. Column names are read from the first row, so empty
        # results skip this branch and fall through to the default sort, which handles them.
        if "select" in query and results:
            # Get list of unique and aggregate columns. If no aggregate columns or order by found, skip
            columns = list(results[0].keys())
            aggcolumns = self.aggcolumns(columns)
            if aggcolumns or query["orderby"]:
                # Merge aggregate columns
                if aggcolumns:
                    results = self.aggregate(query, results, columns, aggcolumns)

                # Sort results and return
                return self.orderby(query, results) if query["orderby"] else self.defaultsort(results)

        # Otherwise, run default sort
        return self.defaultsort(results)

    def aggcolumns(self, columns):
        """
        Filters columns for columns that have an aggregate function call.

        Args:
            columns: list of columns

        Returns:
            dict of {lower-cased column name: function that merges a list of partial values}
        """

        aggregates = {}
        for column in columns:
            column = column.lower()
            if column.startswith(("count(", "sum(", "total(")):
                # Partial counts and sums are merged by adding them together
                aggregates[column] = sum
            elif column.startswith("max("):
                aggregates[column] = max
            elif column.startswith("min("):
                aggregates[column] = min
            elif column.startswith("avg("):
                # Unweighted mean of the partial averages
                aggregates[column] = lambda x: sum(x) / len(x)

        return aggregates

    def aggregate(self, query, results, columns, aggcolumns):
        """
        Merges aggregate columns in results.

        Args:
            query: input query
            results: query results
            columns: list of select columns
            aggcolumns: dict of aggregate columns to merge functions, as returned by aggcolumns

        Returns:
            results with aggregates merged
        """

        # Group data, if necessary
        groups = self.groupby(query, results, columns) if query["groupby"] else [results]

        # Compute column values
        rows = []
        for group in groups:
            # Calculate/copy column values
            row = {}
            for column in columns:
                # aggcolumns() lower-cases its keys, so also look up the lower-cased name.
                # Otherwise a column such as "COUNT(*)" would never be merged.
                function = aggcolumns.get(column, aggcolumns.get(column.lower()))
                if function:
                    # Skip None partial values: sum/max/min raise TypeError on None.
                    # When every partial value is None, the merged value is None.
                    values = [r[column] for r in group if r[column] is not None]
                    row[column] = function(values) if values else None
                else:
                    # Non aggregate column value repeat, use first value
                    row[column] = group[0][column]

            # Add row using original query columns
            rows.append(row)

        return rows

    def groupby(self, query, results, columns):
        """
        Groups results using query group by clause.

        Args:
            query: input query
            results: query results
            columns: list of select columns

        Returns:
            results grouped using group by clause
        """

        groupby = [column for column in columns if column.lower() in query["groupby"]]
        if groupby:
            # itertools.groupby only merges adjacent rows, sort by the same key first
            key = operator.itemgetter(*groupby)
            return [list(value) for _, value in itertools.groupby(sorted(results, key=key), key)]

        return [results]

    def orderby(self, query, results):
        """
        Applies an order by clause to results.

        Args:
            query: input query
            results: query results

        Returns:
            results ordered using order by clause
        """

        # Apply clauses last to first. sorted() is stable, so the first clause ends up as the primary sort key.
        for clause in query["orderby"][::-1]:
            reverse = False
            lowered = clause.lower()
            if lowered.endswith((" asc", " desc")):
                reverse = lowered.endswith(" desc")

                # Remove only the trailing direction keyword, the expression itself may contain spaces
                clause = clause.rsplit(" ", 1)[0]

            # Order by columns must be in select clause
            if clause in query["select"]:
                results = sorted(results, key=operator.itemgetter(clause), reverse=reverse)

        return results

    def defaultsort(self, results):
        """
        Default sorting algorithm for results. Sorts by score descending, if available.

        Args:
            results: query results

        Returns:
            results ordered by score descending
        """

        # Sort standard query using score column, if present
        if results and "score" in results[0]:
            return sorted(results, key=lambda x: x["score"], reverse=True)

        return results
|
@@ -0,0 +1,189 @@
|
|
1
|
+
"""
|
2
|
+
SQL module
|
3
|
+
"""
|
4
|
+
|
5
|
+
from io import StringIO
|
6
|
+
from shlex import shlex
|
7
|
+
|
8
|
+
from .expression import Expression
|
9
|
+
|
10
|
+
|
11
|
+
class SQL:
    """
    Translates txtai SQL statements into database native queries.
    """

    # List of clauses to parse
    CLAUSES = ["select", "from", "where", "group", "having", "order", "limit", "offset"]

    def __init__(self, database=None, tolist=False):
        """
        Creates a new SQL query parser.

        Args:
            database: database instance that provides resolver callback, if any
            tolist: outputs expression lists if True, expression text otherwise, defaults to False
        """

        # Expression parser
        self.expression = Expression(database.resolve if database else self.defaultresolve, tolist)

    def __call__(self, query):
        """
        Parses an input SQL query and normalizes column names in the query clauses. This method will also embed
        similarity search placeholders into the query.

        Args:
            query: input query

        Returns:
            {clause name: clause text}
        """

        clauses = None
        if self.issql(query):
            # Ignore multiple statements
            query = query.split(";")[0]

            # Tokenize query
            tokens, positions = self.tokenize(query)

            # Alias clauses and similar queries
            aliases, similar = {}, []

            # Parse SQL clauses
            clauses = {
                "select": self.parse(tokens, positions, "select", alias=True, aliases=aliases),
                "where": self.parse(tokens, positions, "where", aliases=aliases, similar=similar),
                "groupby": self.parse(tokens, positions, "group", offset=2, aliases=aliases),
                "having": self.parse(tokens, positions, "having", aliases=aliases),
                "orderby": self.parse(tokens, positions, "order", offset=2, aliases=aliases),
                "limit": self.parse(tokens, positions, "limit", aliases=aliases),
                "offset": self.parse(tokens, positions, "offset", aliases=aliases),
            }

            # Add parsed similar queries, if any
            if similar:
                clauses["similar"] = similar

        # Return clauses, default to full query if this is not a SQL query
        return clauses if clauses else {"similar": [[query]]}

    # pylint: disable=W0613
    def defaultresolve(self, name, alias=None):
        """
        Default resolve function. Performs no processing, only returns name.

        Args:
            name: query column name
            alias: alias name, defaults to None

        Returns:
            name
        """

        return name

    def issql(self, query):
        """
        Detects if this is a SQL query.

        Args:
            query: input query

        Returns:
            True if this is a valid SQL query, False otherwise
        """

        if isinstance(query, str):
            # Reduce query to a lower-cased single line
            query = query.lower().strip(";")
            for char in ("\n", "\r", "\t"):
                query = query.replace(char, " ")

            # Strip surrounding whitespace, then any statement terminator that was followed by whitespace.
            # Without the second step "select * from txtai;\n" keeps its trailing ";" and is not detected.
            query = query.strip().rstrip("; ")

            # Detect if this is a valid txtai SQL statement
            return query.startswith("select ") and (" from txtai " in query or query.endswith(" from txtai"))

        return False

    def snippet(self, text):
        """
        Parses a partial SQL snippet.

        Args:
            text: SQL snippet

        Returns:
            parsed snippet
        """

        tokens, _ = self.tokenize(text)
        return self.expression(tokens)

    def tokenize(self, query):
        """
        Tokenizes SQL query into tokens.

        Args:
            query: input query

        Returns:
            (tokenized query, token positions)
        """

        # Build a simple SQL lexer
        #  - Punctuation chars are parsed as standalone tokens which helps identify operators
        #  - Add additional wordchars to prevent splitting on those values
        #  - Disable comments
        lexer = shlex(StringIO(query), punctuation_chars="=!<>+-*/%|")
        lexer.wordchars += ":@#"
        lexer.commenters = ""
        tokens = list(lexer)

        # Identify sql clause token positions
        positions = {}

        # Get position of clause keywords, only the first occurrence of each keyword is stored
        for x, token in enumerate(tokens):
            t = token.lower()
            if t in positions or t not in SQL.CLAUSES:
                continue

            # For multi-term clauses, validate next token matches as well
            if t in ("group", "order") and not (x + 1 < len(tokens) and tokens[x + 1].lower() == "by"):
                continue

            positions[t] = x

        return (tokens, positions)

    def parse(self, tokens, positions, name, offset=1, alias=False, aliases=None, similar=None):
        """
        Runs query column name to database column name mappings for clauses. This method will also
        parse SIMILAR() function calls, extract parameters for those calls and leave a placeholder
        to be filled in with similarity results.

        Args:
            tokens: query tokens
            positions: token positions - used to locate the start of sql clauses
            name: current query clause name
            offset: how many tokens are in the clause name
            alias: True if terms in the clause should be aliased (i.e. column as alias)
            aliases: dict of generated aliases, if present these tokens should NOT be resolved
            similar: list where parsed similar clauses should be stored

        Returns:
            formatted clause, None if the clause is not in the query
        """

        clause = None
        if name in positions:
            # Find the next clause token. Only clauses listed after this one in CLAUSES are considered.
            end = [positions.get(x, len(tokens)) for x in SQL.CLAUSES[SQL.CLAUSES.index(name) + 1 :]]
            end = min(end) if end else len(tokens)

            # Start after current clause token and end before next clause or end of string
            clause = tokens[positions[name] + offset : end]

            # Parse and resolve parameters
            clause = self.expression(clause, alias, aliases, similar)

        return clause
|
184
|
+
|
185
|
+
|
186
|
+
class SQLError(Exception):
    """
    Exception type for errors caused by a user-supplied SQL query.
    """
|