mseep-txtai 9.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (251)
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
txtai/data/texts.py ADDED
@@ -0,0 +1,68 @@
1
+ """
2
+ Texts module
3
+ """
4
+
5
+ from itertools import chain
6
+
7
+ from .base import Data
8
+
9
+
10
class Texts(Data):
    """
    Tokenizes text datasets as input for training language models.
    """

    def __init__(self, tokenizer, columns, maxlength):
        """
        Creates a new instance for tokenizing Texts training data.

        Args:
            tokenizer: model tokenizer
            columns: tuple of columns to use for text
            maxlength: maximum sequence length
        """

        super().__init__(tokenizer, columns, maxlength)

        # Standardize columns - default to a single "text" column with no paired second column
        if not self.columns:
            self.columns = ("text", None)

    def process(self, data):
        """
        Tokenizes a batch of data and packs the tokens into chunks of maxlength.

        Args:
            data: batch of input data, indexed by column name

        Returns:
            dict of tokenizer output fields, each split into chunks of maxlength
        """

        # Column keys
        text1, text2 = self.columns

        # Tokenizer inputs can be single string or string pair, depending on task
        text = (data[text1], data[text2]) if text2 else (data[text1],)

        # Tokenize text, also requesting the special tokens mask (no labels are added here)
        inputs = self.tokenizer(*text, return_special_tokens_mask=True)

        # Concat and return tokenized inputs
        return self.concat(inputs)

    def concat(self, inputs):
        """
        Concatenates tokenized text into chunks of maxlength.

        Args:
            inputs: tokenized input, mapping each field name to a batch of token sequences

        Returns:
            Chunks of tokenized text each with a size of maxlength. When the total length is
            below maxlength, a single shorter chunk is returned. Empty input returns an empty dict.
        """

        # Flatten each field's batch of sequences into one continuous sequence
        concat = {k: list(chain.from_iterable(inputs[k])) for k in inputs.keys()}

        # Nothing to chunk
        if not concat:
            return {}

        # Calculate total length using the first field
        length = len(next(iter(concat.values())))

        # Ensure total is multiple of maxlength, drop last incomplete chunk
        if length >= self.maxlength:
            length = (length // self.maxlength) * self.maxlength

        # Split into chunks of maxlength
        result = {k: [v[x : x + self.maxlength] for x in range(0, length, self.maxlength)] for k, v in concat.items()}

        return result
txtai/data/tokens.py ADDED
@@ -0,0 +1,28 @@
1
+ """
2
+ Tokens module
3
+ """
4
+
5
+ import torch
6
+
7
+
8
class Tokens(torch.utils.data.Dataset):
    """
    Map-style dataset of tokenized rows, built by pivoting column-oriented data into row dicts.
    """

    def __init__(self, columns):
        """
        Pivots column-oriented data into a list of rows.

        Args:
            columns: mapping of column name to a sequence of values. Row i collects the i-th
                     value of every column that has one, so columns may differ in length.
        """

        rows = []
        for name in columns:
            for position, value in enumerate(columns[name]):
                # First column to reach this position opens a new row
                if position == len(rows):
                    rows.append({})

                rows[position][name] = value

        self.data = rows

    def __len__(self):
        # Number of pivoted rows
        return len(self.data)

    def __getitem__(self, index):
        # Row dict stored at index
        return self.data[index]
txtai/database/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ """
2
+ Database imports
3
+ """
4
+
5
+ from .base import Database
6
+ from .client import Client
7
+ from .duckdb import DuckDB
8
+ from .embedded import Embedded
9
+ from .encoder import *
10
+ from .factory import DatabaseFactory
11
+ from .rdbms import RDBMS
12
+ from .schema import *
13
+ from .sqlite import SQLite
14
+ from .sql import *
txtai/database/base.py ADDED
@@ -0,0 +1,342 @@
1
+ """
2
+ Database module
3
+ """
4
+
5
+ import logging
6
+ import types
7
+
8
+ from .encoder import EncoderFactory
9
+ from .sql import SQL, SQLError, Token
10
+
11
+ # Logging configuration
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class Database:
    """
    Base class for database instances. This class encapsulates a content database used for
    storing field content as dicts and objects. The database instance works in conjunction
    with a vector index to execute SQL-driven similarity search.

    Storage-specific methods raise NotImplementedError here and are implemented by subclasses.
    This base class provides query parsing, embedding of similarity results into queries,
    custom function/expression registration and database-specific setting lookups.
    """

    def __init__(self, config):
        """
        Creates a new Database.

        Args:
            config: database configuration
        """

        # Initialize configuration
        self.configure(config)

    def load(self, path):
        """
        Loads a database path.

        Args:
            path: database url
        """

        raise NotImplementedError

    def insert(self, documents, index=0):
        """
        Inserts documents into the database.

        Args:
            documents: list of documents to save
            index: indexid offset, used for internal ids
        """

        raise NotImplementedError

    def delete(self, ids):
        """
        Deletes documents from database.

        Args:
            ids: ids to delete
        """

        raise NotImplementedError

    def reindex(self, config):
        """
        Reindexes internal database content and streams results back. This method must renumber indexids
        sequentially as deletes could have caused indexid gaps.

        Args:
            config: new configuration
        """

        raise NotImplementedError

    def save(self, path):
        """
        Saves a database at path.

        Args:
            path: path to write database
        """

        raise NotImplementedError

    def close(self):
        """
        Closes this database.
        """

        raise NotImplementedError

    def ids(self, ids):
        """
        Retrieves the internal indexids for a list of ids. Multiple indexids may be present for an id in cases
        where data is segmented.

        Args:
            ids: list of document ids

        Returns:
            list of (indexid, id)
        """

        raise NotImplementedError

    def count(self):
        """
        Retrieves the count of this database instance.

        Returns:
            total database count
        """

        raise NotImplementedError

    def search(self, query, similarity=None, limit=None, parameters=None, indexids=False):
        """
        Runs a search against the database. Supports the following methods:

          1. Standard similarity query. This mode retrieves content for the ids in the similarity results
          2. Similarity query as SQL. This mode will combine similarity results and database results into
             a single result set. Similarity queries are set via the SIMILAR() function.
          3. SQL with no similarity query. This mode runs a SQL query and retrieves the results without similarity queries.

        Example queries:
          "natural language processing" - standard similarity only query
          "select * from txtai where similar('natural language processing')" - similarity query as SQL
          "select * from txtai where similar('nlp') and entry > '2021-01-01'" - similarity query with additional SQL clauses
          "select id, text, score from txtai where similar('nlp')" - similarity query with additional SQL column selections
          "select * from txtai where entry > '2021-01-01'" - database only query

        Args:
            query: input query, either a query string or an already parsed query dict. A parsed dict
                   is updated in place with the resolved where clause.
            similarity: list of similarity results, indexed by similar() clause position. Entry x
                        replaces similarity placeholder x in the where clause. For non-SQL queries
                        only entry 0 is used.
            limit: maximum number of results to return
            parameters: dict of named parameters to bind to placeholders
            indexids: if True, results are returned as [(indexid, score)] regardless of select clause

        Returns:
            query results as a list of dicts
            list of ([indexid, score]) if indexids is True
        """

        # Parse query if necessary
        if isinstance(query, str):
            query = self.parse(query)

        # Add in similar results
        where = query.get("where")

        if "select" in query and similarity:
            # Replace each similar() placeholder with the SQL that embeds its results
            for x in range(len(similarity)):
                token = f"{Token.SIMILAR_TOKEN}{x}"
                if where and token in where:
                    where = where.replace(token, self.embed(similarity, x))

        elif similarity:
            # Not a SQL query, load similarity results, if any
            where = self.embed(similarity, 0)

        # Save where
        query["where"] = where

        # Run query
        return self.query(query, limit, parameters, indexids)

    def parse(self, query):
        """
        Parses a query into query components.

        Args:
            query: input query

        Returns:
            dict of parsed query components
        """

        return self.sql(query)

    def resolve(self, name, alias=None):
        """
        Resolves a query column name with the database column name. This method also builds alias expressions
        if alias is set.

        Args:
            name: query column name
            alias: alias name, defaults to None

        Returns:
            database column name
        """

        raise NotImplementedError

    def embed(self, similarity, batch):
        """
        Embeds similarity query results into a database query.

        Args:
            similarity: list of similarity results, indexed by similar() clause position
            batch: batch id, the position within similarity to embed

        Returns:
            SQL text that replaces the similarity placeholder in the where clause
        """

        raise NotImplementedError

    def query(self, query, limit, parameters, indexids):
        """
        Executes query against database.

        Args:
            query: input query
            limit: maximum number of results to return
            parameters: dict of named parameters to bind to placeholders
            indexids: results are returned as [(indexid, score)] regardless of select clause parameters if True

        Returns:
            query results
        """

        raise NotImplementedError

    def configure(self, config):
        """
        Initialize configuration.

        Args:
            config: configuration
        """

        # Database configuration
        self.config = config

        # SQL parser
        self.sql = SQL(self)

        # Load objects encoder
        encoder = self.config.get("objects")
        self.encoder = EncoderFactory.create(encoder) if encoder else None

        # Transform columns
        columns = self.config.get("columns", {})
        self.text = columns.get("text", "text")
        self.object = columns.get("object", "object")

        # Custom functions and expressions
        self.functions, self.expressions = None, None

        # Load custom functions
        self.registerfunctions(self.config)

        # Load custom expressions
        self.registerexpressions(self.config)

    def registerfunctions(self, config):
        """
        Register custom functions. This method stores the function details for underlying
        database implementations to handle.

        Each configured function is either a callable or a dict with "function" and optional
        "name" / "argcount" keys. Plain functions, builtins and bound methods are named after
        the function itself. Other callable objects are named after their class and registered
        through their __call__ method. Names are lowercased.

        Args:
            config: database configuration
        """

        inputs = config.get("functions") if config else None
        if inputs:
            functions = []
            for fn in inputs:
                name, argcount = None, -1

                # Optional function configuration
                if isinstance(fn, dict):
                    name, argcount, fn = fn.get("name"), fn.get("argcount", -1), fn["function"]

                # Function-like callables carry their own __name__. Only other callable objects
                # (class instances, partials) fall back to the class name and __call__.
                functionlike = isinstance(fn, (types.FunctionType, types.BuiltinFunctionType, types.MethodType))
                if not functionlike and hasattr(fn, "__call__"):
                    name = name if name else fn.__class__.__name__.lower()
                    fn = fn.__call__
                else:
                    name = name if name else fn.__name__.lower()

                # Store function details
                functions.append((name, argcount, fn))

            # pylint: disable=W0201
            self.functions = functions

    def registerexpressions(self, config):
        """
        Register custom expressions. This method parses and resolves expressions for later use in SQL queries.

        Args:
            config: database configuration
        """

        inputs = config.get("expressions") if config else None
        if inputs:
            expressions = {}
            for entry in inputs:
                name = entry.get("name")
                expression = entry.get("expression")

                # Entries missing a name or expression are skipped
                if name and expression:
                    expressions[name] = self.sql.snippet(expression)

            # pylint: disable=W0201
            self.expressions = expressions

    def execute(self, function, *args):
        """
        Executes a user query. This method has common error handling logic.

        Args:
            function: database execute function
            args: function arguments

        Returns:
            result of function(args)

        Raises:
            SQLError: wraps any exception raised by function
        """

        try:
            # Debug log SQL
            logger.debug(" ".join(["%s"] * len(args)), *args)

            return function(*args)
        except Exception as e:
            raise SQLError(e) from None

    def setting(self, name, default=None):
        """
        Looks up database specific setting.

        Args:
            name: setting name
            default: default value when setting not found

        Returns:
            setting value. Explicitly configured falsy values (False, 0, "") are returned as is.
            Only missing settings (None) fall back to default.
        """

        # Get the database-specific config object, keyed by the content setting
        database = self.config.get(self.config.get("content"))

        # Get setting value, set default value only if not found
        setting = database.get(name) if database else None
        return setting if setting is not None else default
txtai/database/client.py ADDED
@@ -0,0 +1,227 @@
1
+ """
2
+ Client module
3
+ """
4
+
5
+ import os
6
+ import time
7
+
8
+ # Conditional import
9
+ try:
10
+ from sqlalchemy import StaticPool, Text, cast, create_engine, insert, text as textsql
11
+ from sqlalchemy.orm import Session, aliased
12
+ from sqlalchemy.schema import CreateSchema
13
+
14
+ from .schema import Base, Batch, Document, Object, Section, SectionBase, Score
15
+
16
+ ORM = True
17
+ except ImportError:
18
+ ORM = False
19
+
20
+ from .rdbms import RDBMS
21
+
22
+
23
class Client(RDBMS):
    """
    Database client instance. This class connects to an external database using SQLAlchemy. It supports any database
    that is supported by SQLAlchemy (PostgreSQL, MariaDB, etc) and has JSON support.
    """

    def __init__(self, config):
        """
        Creates a new Database.

        Args:
            config: database configuration parameters

        Raises:
            ImportError: if SQLAlchemy is not installed
        """

        # Fail fast, before any parent initialization work, when the ORM dependency is missing
        if not ORM:
            raise ImportError('SQLAlchemy is not available - install "database" extra to enable')

        super().__init__(config)

        # SQLAlchemy parameters
        self.engine, self.dbconnection = None, None

    def save(self, path):
        # Commit session and database connection
        super().save(path)

        if self.dbconnection:
            self.dbconnection.commit()

    def close(self):
        super().close()

        # Dispose of engine, which also closes dbconnection
        if self.engine:
            self.engine.dispose()

    def reindexstart(self):
        # Working table name, made unique with a millisecond timestamp
        name = f"rebuild{round(time.time() * 1000)}"

        # Create working table metadata. Declaring the class makes the table available in Base.metadata.
        type("Rebuild", (SectionBase,), {"__tablename__": name})
        Base.metadata.tables[name].create(self.dbconnection)

        return name

    def reindexend(self, name):
        # Remove table object from metadata
        Base.metadata.remove(Base.metadata.tables[name])

    def jsonprefix(self):
        # JSON column prefix
        return "cast("

    def jsoncolumn(self, name):
        # Alias documents table
        d = aliased(Document, name="d")

        # Build JSON column expression for column, rendered as literal SQL text for the current dialect
        return str(cast(d.data[name].as_string(), Text).compile(dialect=self.engine.dialect, compile_kwargs={"literal_binds": True}))

    def createtables(self):
        # Create tables, skipping any that already exist
        Base.metadata.create_all(self.dbconnection, checkfirst=True)

        # Clear existing data - table schema is created upon connecting to database
        for table in ["sections", "documents", "objects"]:
            self.cursor.execute(f"DELETE FROM {table}")

    def finalize(self):
        # Flush cached objects
        self.connection.flush()

    def insertdocument(self, uid, data, tags, entry):
        self.connection.add(Document(id=uid, data=data, tags=tags, entry=entry))

    def insertobject(self, uid, data, tags, entry):
        self.connection.add(Object(id=uid, object=data, tags=tags, entry=entry))

    def insertsection(self, index, uid, text, tags, entry):
        # Save text section
        self.connection.add(Section(indexid=index, id=uid, text=text, tags=tags, entry=entry))

    def createbatch(self):
        # Create temporary batch table, if necessary
        Base.metadata.tables["batch"].create(self.dbconnection, checkfirst=True)

    def insertbatch(self, indexids, ids, batch):
        # Bulk insert indexids and/or ids (stored as strings) tagged with the batch id
        if indexids:
            self.connection.execute(insert(Batch), [{"indexid": i, "batch": batch} for i in indexids])
        if ids:
            self.connection.execute(insert(Batch), [{"id": str(uid), "batch": batch} for uid in ids])

    def createscores(self):
        # Create temporary scores table, if necessary
        Base.metadata.tables["scores"].create(self.dbconnection, checkfirst=True)

    def insertscores(self, scores):
        # Average scores by id
        if scores:
            self.connection.execute(insert(Score), [{"indexid": i, "score": sum(s) / len(s)} for i, s in scores.items()])

    def connect(self, path=None):
        # Connection URL
        content = self.config.get("content")

        # Read ENV variable, if necessary. NOTE(review): when CLIENT_URL is unset, create_engine receives None.
        content = os.environ.get("CLIENT_URL") if content == "client" else content

        # Create engine using database URL. The JSON serializer passes values through unchanged.
        self.engine = create_engine(content, poolclass=StaticPool, echo=False, json_serializer=lambda x: x)
        self.dbconnection = self.engine.connect()

        # Create database session
        database = Session(self.dbconnection)

        # Set default schema, if necessary (only executed on PostgreSQL, see sqldialect)
        schema = self.config.get("schema")
        if schema:
            with self.engine.begin():
                self.sqldialect(database, CreateSchema(schema, if_not_exists=True))

            self.sqldialect(database, textsql("SET search_path TO :schema"), {"schema": schema})

        return database

    def getcursor(self):
        return Cursor(self.connection)

    def rows(self):
        return self.cursor

    def addfunctions(self):
        # Custom functions are not registered with external databases
        return

    def sqldialect(self, database, sql, parameters=None):
        """
        Executes a SQL statement based on the current SQL dialect. The statement only runs on PostgreSQL;
        other dialects execute a no-op "SELECT 1" instead.

        Args:
            database: current database
            sql: SQL to execute
            parameters: optional bind parameters
        """

        args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (textsql("SELECT 1"),)
        database.execute(*args)
170
+
171
+
172
class Cursor:
    """
    Thin adapter exposing a subset of the Python DB-API cursor interface (execute, fetchall,
    fetchone, description, iteration) on top of a SQLAlchemy connection or session.
    """

    def __init__(self, connection):
        # Executor for statements, plus the result of the most recent execute call
        self.connection, self.result = connection, None

    def __iter__(self):
        # Iteration is delegated to the most recent result object as-is
        return self.result

    def execute(self, statement, parameters=None):
        """
        Runs a statement and keeps its result for later fetches.

        Args:
            statement: SQL string or executable statement; strings are wrapped as text clauses
            parameters: optional dictionary with bind parameters
        """

        executable = textsql(statement) if isinstance(statement, str) else statement
        self.result = self.connection.execute(executable, parameters)

    def fetchall(self):
        """
        Reads every row of the most recent result.

        Returns:
            list of rows, or None when nothing has been executed
        """

        if not self.result:
            return None

        return self.result.all()

    def fetchone(self):
        """
        Reads the first row of the most recent result.

        Returns:
            first row, or None when nothing has been executed
        """

        if not self.result:
            return None

        return self.result.first()

    @property
    def description(self):
        """
        Column metadata for the most recent result. Only the name field of the DB-API description is populated.

        Returns:
            list of one-element tuples holding each column name, or None when nothing has been executed
        """

        if not self.result:
            return None

        # zip over a single iterable yields one-element tuples: (name,)
        return list(zip(self.result.keys()))