mseep-txtai 9.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251) hide show
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
@@ -0,0 +1,98 @@
1
+ """
2
+ Statement module
3
+ """
4
+
5
+
6
class Statement:
    """
    Catalog of the SQL statements used to create, populate and query the standard database schema.

    Statements containing %s are templates, the caller substitutes a table name, column expression
    or batch number before execution. Statements containing ? use database parameter binding.
    """

    # Working table that holds batches of index ids / document ids
    CREATE_BATCH = """
        CREATE TEMP TABLE IF NOT EXISTS batch (
            indexid INTEGER,
            id TEXT,
            batch INTEGER
        )
    """

    DELETE_BATCH = "DELETE FROM batch"
    INSERT_BATCH_INDEXID = "INSERT INTO batch (indexid, batch) VALUES (?, ?)"
    INSERT_BATCH_ID = "INSERT INTO batch (id, batch) VALUES (?, ?)"

    # Working table that holds similarity scores, joined on indexid
    CREATE_SCORES = """
        CREATE TEMP TABLE IF NOT EXISTS scores (
            indexid INTEGER PRIMARY KEY,
            score REAL
        )
    """

    DELETE_SCORES = "DELETE FROM scores"
    INSERT_SCORE = "INSERT INTO scores VALUES (?, ?)"

    # Documents table - full content
    CREATE_DOCUMENTS = """
        CREATE TABLE IF NOT EXISTS documents (
            id TEXT PRIMARY KEY,
            data JSON,
            tags TEXT,
            entry DATETIME
        )
    """

    INSERT_DOCUMENT = "INSERT OR REPLACE INTO documents VALUES (?, ?, ?, ?)"
    DELETE_DOCUMENTS = "DELETE FROM documents WHERE id IN (SELECT id FROM batch)"

    # Objects table - binary content
    CREATE_OBJECTS = """
        CREATE TABLE IF NOT EXISTS objects (
            id TEXT PRIMARY KEY,
            object BLOB,
            tags TEXT,
            entry DATETIME
        )
    """

    INSERT_OBJECT = "INSERT OR REPLACE INTO objects VALUES (?, ?, ?, ?)"
    DELETE_OBJECTS = "DELETE FROM objects WHERE id IN (SELECT id FROM batch)"

    # Sections table - section text. The table name is a template parameter.
    CREATE_SECTIONS = """
        CREATE TABLE IF NOT EXISTS %s (
            indexid INTEGER PRIMARY KEY,
            id TEXT,
            text TEXT,
            tags TEXT,
            entry DATETIME
        )
    """

    CREATE_SECTIONS_INDEX = "CREATE INDEX section_id ON sections(id)"
    INSERT_SECTION = "INSERT INTO sections VALUES (?, ?, ?, ?, ?)"
    DELETE_SECTIONS = "DELETE FROM sections WHERE id IN (SELECT id FROM batch)"

    # Copies sections into the table named by the first parameter. The subquery renumbers indexid as
    # the count of rows with a lower or equal indexid, minus one. The second parameter is the text expression.
    COPY_SECTIONS = (
        "INSERT INTO %s SELECT (select count(*) - 1 from sections s1 where s.indexid >= s1.indexid) indexid, "
        "s.id, %s AS text, s.tags, s.entry FROM sections s LEFT JOIN documents d ON s.id = d.id ORDER BY indexid"
    )

    # Reads sections along with any joined document data and object content, ordered by indexid
    STREAM_SECTIONS = (
        "SELECT s.id, s.text, data, object, s.tags FROM %s s "
        "LEFT JOIN documents d ON s.id = d.id "
        "LEFT JOIN objects o ON s.id = o.id ORDER BY indexid"
    )

    DROP_SECTIONS = "DROP TABLE sections"
    RENAME_SECTIONS = "ALTER TABLE %s RENAME TO sections"

    # Id queries
    SELECT_IDS = "SELECT indexid, id FROM sections WHERE id in (SELECT id FROM batch)"
    COUNT_IDS = "SELECT count(indexid) FROM sections"

    # Query fragments - combined by the caller into full statements
    TABLE_CLAUSE = (
        "SELECT %s FROM sections s "
        "LEFT JOIN documents d ON s.id = d.id "
        "LEFT JOIN objects o ON s.id = o.id "
        "LEFT JOIN scores sc ON s.indexid = sc.indexid"
    )

    IDS_CLAUSE = "s.indexid in (SELECT indexid from batch WHERE batch=%s)"
@@ -0,0 +1,8 @@
1
+ """
2
+ SQL imports
3
+ """
4
+
5
+ from .aggregate import Aggregate
6
+ from .base import SQL, SQLError
7
+ from .expression import Expression
8
+ from .token import Token
@@ -0,0 +1,178 @@
1
+ """
2
+ Aggregate module
3
+ """
4
+
5
+ import itertools
6
+ import operator
7
+
8
+ from .base import SQL
9
+
10
+
11
class Aggregate(SQL):
    """
    Aggregates partial results from queries. Partial results come from queries when working with sharded indexes.
    """

    def __init__(self, database=None):
        """
        Creates a new aggregator.

        Args:
            database: database instance that provides resolver callback, if any
        """

        # Always return token lists as this method requires them
        super().__init__(database, True)

    def __call__(self, query, results):
        """
        Analyzes query results, combines aggregate function results and applies ordering.

        Args:
            query: input query
            results: query results, a list of dicts keyed by column name

        Returns:
            aggregated query results
        """

        # Parse query
        query = super().__call__(query)

        # Check if this is a SQL query. Empty results have no columns to inspect and nothing to merge.
        if results and "select" in query:
            # Get list of unique and aggregate columns. If no aggregate columns or order by found, skip
            columns = list(results[0].keys())
            aggcolumns = self.aggcolumns(columns)
            if aggcolumns or query["orderby"]:
                # Merge aggregate columns
                if aggcolumns:
                    results = self.aggregate(query, results, columns, aggcolumns)

                # Sort results and return
                return self.orderby(query, results) if query["orderby"] else self.defaultsort(results)

        # Otherwise, run default sort
        return self.defaultsort(results)

    def aggcolumns(self, columns):
        """
        Filters columns for columns that have an aggregate function call.

        Function names are matched case-insensitively. The returned mapping is keyed by the column
        name exactly as it appears in columns, which allows direct lookups against result rows.

        Args:
            columns: list of columns

        Returns:
            {column name: function that merges a list of partial values}
        """

        aggregates = {}
        for column in columns:
            name = column.lower()
            if name.startswith(("count(", "sum(", "total(")):
                aggregates[column] = sum
            elif name.startswith("max("):
                aggregates[column] = max
            elif name.startswith("min("):
                aggregates[column] = min
            elif name.startswith("avg("):
                # NOTE: this is the unweighted mean of the partial averages, row counts are not available here
                aggregates[column] = lambda x: sum(x) / len(x)

        return aggregates

    def aggregate(self, query, results, columns, aggcolumns):
        """
        Merges aggregate columns in results.

        Args:
            query: input query
            results: query results
            columns: list of select columns
            aggcolumns: aggregate columns, as returned by aggcolumns

        Returns:
            results with aggregates merged, one row per group
        """

        # Group data, if necessary
        groups = self.groupby(query, results, columns) if query["groupby"] else [results]

        # Compute column values
        rows = []
        for group in groups:
            row = {}
            for column in columns:
                if column in aggcolumns:
                    # Calculate aggregate value
                    row[column] = aggcolumns[column]([r[column] for r in group])
                else:
                    # Non aggregate column value repeat, use first value
                    row[column] = group[0][column]

            # Add row using original query columns
            rows.append(row)

        return rows

    def groupby(self, query, results, columns):
        """
        Groups results using query group by clause.

        Args:
            query: input query
            results: query results
            columns: list of select columns

        Returns:
            results grouped using group by clause
        """

        # Match group by terms against result columns without regard to case
        terms = {term.lower() for term in query["groupby"]}
        groupby = [column for column in columns if column.lower() in terms]
        if groupby:
            # itertools.groupby only merges adjacent rows, sort by the same key first
            key = operator.itemgetter(*groupby)
            return [list(value) for _, value in itertools.groupby(sorted(results, key=key), key)]

        return [results]

    def orderby(self, query, results):
        """
        Applies an order by clause to results.

        Args:
            query: input query
            results: query results

        Returns:
            results ordered using order by clause
        """

        # Apply clauses last to first. Sorts are stable, which leaves the first clause as the primary ordering.
        for clause in query["orderby"][::-1]:
            # Strip the trailing direction keyword only, the remaining expression can have multiple tokens
            reverse = False
            if clause.lower().endswith(" asc"):
                clause = clause.rsplit(" ", 1)[0]
            elif clause.lower().endswith(" desc"):
                clause = clause.rsplit(" ", 1)[0]
                reverse = True

            # Order by columns must be in select clause
            if clause in query["select"]:
                results = sorted(results, key=operator.itemgetter(clause), reverse=reverse)

        return results

    def defaultsort(self, results):
        """
        Default sorting algorithm for results. Sorts by score descending, if available.

        Args:
            results: query results

        Returns:
            results ordered by score descending
        """

        # Sort standard query using score column, if present
        if results and "score" in results[0]:
            return sorted(results, key=lambda x: x["score"], reverse=True)

        return results
@@ -0,0 +1,189 @@
1
+ """
2
+ SQL module
3
+ """
4
+
5
+ from io import StringIO
6
+ from shlex import shlex
7
+
8
+ from .expression import Expression
9
+
10
+
11
class SQL:
    """
    Translates txtai SQL statements into database native queries.
    """

    # List of clauses to parse
    CLAUSES = ["select", "from", "where", "group", "having", "order", "limit", "offset"]

    def __init__(self, database=None, tolist=False):
        """
        Creates a new SQL query parser.

        Args:
            database: database instance that provides resolver callback, if any
            tolist: outputs expression lists if True, expression text otherwise, defaults to False
        """

        # Expression parser
        self.expression = Expression(database.resolve if database else self.defaultresolve, tolist)

    def __call__(self, query):
        """
        Parses an input SQL query and normalizes column names in the query clauses. This method will also embed
        similarity search placeholders into the query.

        Args:
            query: input query

        Returns:
            {clause name: clause text}
        """

        clauses = None
        if self.issql(query):
            # Tokenize query
            try:
                tokens, positions = self.tokenize(query)

                # Ignore multiple statements. Truncating on the separator token, instead of the raw text,
                # keeps semicolons that are part of a quoted literal.
                if ";" in tokens:
                    end = tokens.index(";")
                    tokens = tokens[:end]
                    positions = {name: index for name, index in positions.items() if index < end}
            except ValueError:
                # Full text failed to tokenize (i.e. unbalanced quotes), fall back to parsing the
                # text before the first semicolon
                tokens, positions = self.tokenize(query.split(";")[0])

            # Alias clauses and similar queries
            aliases, similar = {}, []

            # Parse SQL clauses
            clauses = {
                "select": self.parse(tokens, positions, "select", alias=True, aliases=aliases),
                "where": self.parse(tokens, positions, "where", aliases=aliases, similar=similar),
                "groupby": self.parse(tokens, positions, "group", offset=2, aliases=aliases),
                "having": self.parse(tokens, positions, "having", aliases=aliases),
                "orderby": self.parse(tokens, positions, "order", offset=2, aliases=aliases),
                "limit": self.parse(tokens, positions, "limit", aliases=aliases),
                "offset": self.parse(tokens, positions, "offset", aliases=aliases),
            }

            # Add parsed similar queries, if any
            if similar:
                clauses["similar"] = similar

        # Return clauses, default to full query if this is not a SQL query
        return clauses if clauses else {"similar": [[query]]}

    # pylint: disable=W0613
    def defaultresolve(self, name, alias=None):
        """
        Default resolve function. Performs no processing, only returns name.

        Args:
            name: query column name
            alias: alias name, defaults to None

        Returns:
            name
        """

        return name

    def issql(self, query):
        """
        Detects if this is a SQL query.

        Args:
            query: input query

        Returns:
            True if this is a valid SQL query, False otherwise
        """

        if isinstance(query, str):
            # Reduce query to a lower-cased single line stripped of leading/trailing whitespace
            query = query.lower().strip(";").replace("\n", " ").replace("\t", " ").strip()

            # A statement terminator followed by whitespace survives the strip above, remove it here
            query = query.rstrip("; ")

            # Detect if this is a valid txtai SQL statement
            return query.startswith("select ") and (" from txtai " in query or query.endswith(" from txtai"))

        return False

    def snippet(self, text):
        """
        Parses a partial SQL snippet.

        Args:
            text: SQL snippet

        Returns:
            parsed snippet
        """

        tokens, _ = self.tokenize(text)
        return self.expression(tokens)

    def tokenize(self, query):
        """
        Tokenizes SQL query into tokens.

        Args:
            query: input query

        Returns:
            (tokenized query, token positions)
        """

        # Build a simple SQL lexer
        #  - Punctuation chars are parsed as standalone tokens which helps identify operators
        #  - Add additional wordchars to prevent splitting on those values
        #  - Disable comments
        lexer = shlex(StringIO(query), punctuation_chars="=!<>+-*/%|")
        lexer.wordchars += ":@#"
        lexer.commenters = ""
        tokens = list(lexer)

        # Identify sql clause token positions, only the first occurrence of each keyword is stored
        positions = {}
        for x, token in enumerate(tokens):
            t = token.lower()
            if t in positions or t not in SQL.CLAUSES:
                continue

            # For multi-term clauses, validate next token matches as well
            if t in ("group", "order") and not (x + 1 < len(tokens) and tokens[x + 1].lower() == "by"):
                continue

            positions[t] = x

        return (tokens, positions)

    def parse(self, tokens, positions, name, offset=1, alias=False, aliases=None, similar=None):
        """
        Runs query column name to database column name mappings for clauses. This method will also
        parse SIMILAR() function calls, extract parameters for those calls and leave a placeholder
        to be filled in with similarity results.

        Args:
            tokens: query tokens
            positions: token positions - used to locate the start of sql clauses
            name: current query clause name
            offset: how many tokens are in the clause name
            alias: True if terms in the clause should be aliased (i.e. column as alias)
            aliases: dict of generated aliases, if present these tokens should NOT be resolved
            similar: list where parsed similar clauses should be stored

        Returns:
            formatted clause, None if the clause is not in the query
        """

        clause = None
        if name in positions:
            # Find the next clause token
            end = [positions.get(x, len(tokens)) for x in SQL.CLAUSES[SQL.CLAUSES.index(name) + 1 :]]
            end = min(end) if end else len(tokens)

            # Start after current clause token and end before next clause or end of string
            clause = tokens[positions[name] + offset : end]

            # Parse and resolve parameters
            clause = self.expression(clause, alias, aliases, similar)

        return clause
184
+
185
+
186
class SQLError(Exception):
    """
    Exception type for errors that originate from a user supplied SQL query.
    """