mseep_txtai-9.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251)
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
txtai/graph/factory.py ADDED
@@ -0,0 +1,61 @@
+ """
+ Factory module
+ """
+
+ from ..util import Resolver
+
+ from .networkx import NetworkX
+ from .rdbms import RDBMS
+
+
+ class GraphFactory:
+     """
+     Methods to create graphs.
+     """
+
+     @staticmethod
+     def create(config):
+         """
+         Create a Graph.
+
+         Args:
+             config: graph configuration
+
+         Returns:
+             Graph
+         """
+
+         # Graph instance
+         graph = None
+         backend = config.get("backend", "networkx")
+
+         # Create graph instance
+         if backend == "networkx":
+             graph = NetworkX(config)
+         elif backend == "rdbms":
+             graph = RDBMS(config)
+         else:
+             graph = GraphFactory.resolve(backend, config)
+
+         # Store config back
+         config["backend"] = backend
+
+         return graph
+
+     @staticmethod
+     def resolve(backend, config):
+         """
+         Attempt to resolve a custom backend.
+
+         Args:
+             backend: backend class
+             config: index configuration parameters
+
+         Returns:
+             Graph
+         """
+
+         try:
+             return Resolver()(backend)(config)
+         except Exception as e:
+             raise ImportError(f"Unable to resolve graph backend: '{backend}'") from e
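For orientation, a minimal usage sketch of the factory above. The "networkx" and "rdbms" backends are the two built-ins handled explicitly; any other "backend" value is treated as a class path and resolved dynamically. Note that `mymodule.CustomGraph` below is a hypothetical placeholder, not part of the package, and the import assumes GraphFactory is exported from txtai/graph/__init__.py.

from txtai.graph import GraphFactory

# Built-in backend ("networkx" is also the default when "backend" is omitted)
graph = GraphFactory.create({"backend": "networkx"})

# Custom backend resolved via Resolver from a full class path
# (mymodule.CustomGraph is hypothetical)
graph = GraphFactory.create({"backend": "mymodule.CustomGraph"})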
txtai/graph/networkx.py ADDED
@@ -0,0 +1,275 @@
+ """
+ NetworkX module
+ """
+
+ import os
+
+ from tempfile import TemporaryDirectory
+
+ # Conditional import
+ try:
+     import networkx as nx
+
+     from networkx.algorithms.community import asyn_lpa_communities, greedy_modularity_communities, louvain_partitions
+     from networkx.readwrite import json_graph
+
+     NETWORKX = True
+ except ImportError:
+     NETWORKX = False
+
+ from ..archive import ArchiveFactory
+ from ..serialize import SerializeError, SerializeFactory
+
+ from .base import Graph
+ from .query import Query
+
+
+ # pylint: disable=R0904
+ class NetworkX(Graph):
+     """
+     Graph instance backed by NetworkX.
+     """
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         if not NETWORKX:
+             raise ImportError('NetworkX is not available - install "graph" extra to enable')
+
+     def create(self):
+         return nx.Graph()
+
+     def count(self):
+         return self.backend.number_of_nodes()
+
+     def scan(self, attribute=None, data=False):
+         # Full graph
+         graph = self.backend
+
+         # Filter graph to nodes having a specified attribute
+         if attribute:
+             graph = nx.subgraph_view(self.backend, filter_node=lambda x: attribute in self.node(x))
+
+         # Return either list of matching ids or tuple of (id, attribute dictionary)
+         return graph.nodes(data=True) if data else graph
+
+     def node(self, node):
+         return self.backend.nodes.get(node)
+
+     def addnode(self, node, **attrs):
+         self.backend.add_node(node, **attrs)
+
+     def addnodes(self, nodes):
+         self.backend.add_nodes_from(nodes)
+
+     def removenode(self, node):
+         if self.hasnode(node):
+             self.backend.remove_node(node)
+
+     def hasnode(self, node):
+         return self.backend.has_node(node)
+
+     def attribute(self, node, field):
+         return self.node(node).get(field) if self.hasnode(node) else None
+
+     def addattribute(self, node, field, value):
+         if self.hasnode(node):
+             self.node(node)[field] = value
+
+     def removeattribute(self, node, field):
+         return self.node(node).pop(field, None) if self.hasnode(node) else None
+
+     def edgecount(self):
+         return self.backend.number_of_edges()
+
+     def edges(self, node):
+         edges = self.backend.adj.get(node)
+         if edges:
+             return dict(sorted(edges.items(), key=lambda x: x[1].get("weight", 0), reverse=True))
+
+         return None
+
+     def addedge(self, source, target, **attrs):
+         self.backend.add_edge(source, target, **attrs)
+
+     def addedges(self, edges):
+         self.backend.add_edges_from(edges)
+
+     def hasedge(self, source, target=None):
+         if target is None:
+             edges = self.backend.adj.get(source)
+             return len(edges) > 0 if edges else False
+
+         return self.backend.has_edge(source, target)
+
+     def centrality(self):
+         rank = nx.degree_centrality(self.backend)
+         return dict(sorted(rank.items(), key=lambda x: x[1], reverse=True))
+
+     def pagerank(self):
+         rank = nx.pagerank(self.backend, weight="weight")
+         return dict(sorted(rank.items(), key=lambda x: x[1], reverse=True))
+
+     def showpath(self, source, target):
+         # pylint: disable=E1121
+         return nx.shortest_path(self.backend, source, target, self.distance)
+
+     def isquery(self, queries):
+         return Query().isquery(queries)
+
+     def parse(self, query):
+         return Query().parse(query)
+
+     def search(self, query, limit=None, graph=False):
+         # Run graph query
+         results = Query()(self, query, limit)
+
+         # Transform into filtered graph
+         if graph:
+             nodes = set()
+             for column in results.values():
+                 for value in column:
+                     if isinstance(value, list):
+                         # Path group
+                         nodes.update([node for node in value if node and not isinstance(node, dict)])
+                     elif isinstance(value, dict):
+                         # Nodes by id attribute
+                         nodes.update(uid for uid, attr in self.scan(data=True) if attr["id"] == value["id"])
+                     elif value is not None:
+                         # Single node id
+                         nodes.add(value)
+
+             return self.filter(list(nodes))
+
+         # Transform columnar structure into rows
+         keys = list(results.keys())
+         rows, count = [], len(results[keys[0]])
+
+         for x in range(count):
+             rows.append({str(key): results[key][x] for key in keys})
+
+         return rows
+
+     def communities(self, config):
+         # Get community detection algorithm
+         algorithm = config.get("algorithm")
+
+         if algorithm == "greedy":
+             communities = greedy_modularity_communities(self.backend, weight="weight", resolution=config.get("resolution", 100))
+         elif algorithm == "lpa":
+             communities = asyn_lpa_communities(self.backend, weight="weight", seed=0)
+         else:
+             communities = self.louvain(config)
+
+         return communities
+
+     def load(self, path):
+         try:
+             # Load graph data
+             data = SerializeFactory.create().load(path)
+
+             # Add data to graph
+             self.backend = self.create()
+             self.backend.add_nodes_from(data["nodes"])
+             self.backend.add_edges_from(data["edges"])
+
+             # Load categories
+             self.categories = data.get("categories")
+
+             # Load topics
+             self.topics = data.get("topics")
+
+         except SerializeError:
+             # Backwards compatible support for legacy TAR format
+             self.loadtar(path)
+
+     def save(self, path):
+         # Save graph data
+         SerializeFactory.create().save(
+             {
+                 "nodes": [(uid, self.node(uid)) for uid in self.scan()],
+                 "edges": list(self.backend.edges(data=True)),
+                 "categories": self.categories,
+                 "topics": self.topics,
+             },
+             path,
+         )
+
+     def loaddict(self, data):
+         self.backend = json_graph.node_link_graph(data, name="indexid")
+         self.categories, self.topics = data.get("categories"), data.get("topics")
+
+     def savedict(self):
+         data = json_graph.node_link_data(self.backend, name="indexid")
+         data["categories"] = self.categories
+         data["topics"] = self.topics
+
+         return data
+
+     def louvain(self, config):
+         """
+         Runs the Louvain community detection algorithm.
+
+         Args:
+             config: topic configuration
+
+         Returns:
+             list of [ids] per community
+         """
+
+         # Partition level to use
+         level = config.get("level", "best")
+
+         # Run community detection
+         results = list(louvain_partitions(self.backend, weight="weight", resolution=config.get("resolution", 100), seed=0))
+
+         # Get partition level (first or best)
+         return results[0] if level == "first" else results[-1]
+
+     # pylint: disable=W0613
+     def distance(self, source, target, attrs):
+         """
+         Computes distance between source and target nodes using weight.
+
+         Args:
+             source: source node
+             target: target node
+             attrs: edge attributes
+
+         Returns:
+             distance between source and target
+         """
+
+         # Distance is 1 - score. Skip minimal distances as they are near duplicates.
+         distance = max(1.0 - attrs["weight"], 0.0)
+         return distance if distance >= 0.15 else 1.00
+
+     def loadtar(self, path):
+         """
+         Loads a graph from the legacy TAR file.
+
+         Args:
+             path: path to graph
+         """
+
+         # Pickle serialization - backwards compatible
+         serializer = SerializeFactory.create("pickle")
+
+         # Extract files to temporary directory and load content
+         with TemporaryDirectory() as directory:
+             # Unpack files
+             archive = ArchiveFactory.create(directory)
+             archive.load(path, "tar")
+
+             # Load graph backend
+             self.backend = serializer.load(f"{directory}/graph")
+
+             # Load categories, if necessary
+             path = f"{directory}/categories"
+             if os.path.exists(path):
+                 self.categories = serializer.load(path)
+
+             # Load topics, if necessary
+             path = f"{directory}/topics"
+             if os.path.exists(path):
+                 self.topics = serializer.load(path)
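One detail worth calling out in the class above: showpath() computes shortest paths with the custom distance() method, which inverts edge weights (similarity scores) and pushes near-duplicate edges to the maximum distance so paths route through meaningfully distinct nodes. A standalone sketch of that rule, using the constants from the source:

def distance(weight):
    # Distance is 1 - score; near duplicates (distance < 0.15)
    # are penalized to the maximum distance of 1.0
    d = max(1.0 - weight, 0.0)
    return d if d >= 0.15 else 1.00

assert distance(0.60) == 0.40   # normal edge: 1 - 0.60
assert distance(0.95) == 1.00   # near duplicate, penalized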
txtai/graph/query.py ADDED
@@ -0,0 +1,181 @@
+ """
+ Query module
+ """
+
+ import logging
+ import re
+
+ try:
+     from grandcypher import GrandCypher
+
+     GRANDCYPHER = True
+ except ImportError:
+     GRANDCYPHER = False
+
+ # Logging configuration
+ logger = logging.getLogger(__name__)
+
+
+ class Query:
+     """
+     Runs openCypher graph queries using the GrandCypher library. This class also supports search functions.
+     """
+
+     # Similar token
+     SIMILAR = "__SIMILAR__"
+
+     def __init__(self):
+         """
+         Create a new graph query instance.
+         """
+
+         if not GRANDCYPHER:
+             raise ImportError('GrandCypher is not available - install "graph" extra to enable')
+
+     def __call__(self, graph, query, limit):
+         """
+         Runs a graph query.
+
+         Args:
+             graph: graph instance
+             query: graph query, can be a full query string or a parsed query dictionary
+             limit: number of results
+
+         Returns:
+             results
+         """
+
+         # Results by attribute and ids filter
+         attributes, uids = None, None
+
+         # Build the query from a parsed query
+         if isinstance(query, dict):
+             query, attributes, uids = self.build(query)
+
+         # Filter graph, if applicable
+         if uids:
+             graph = self.filter(graph, attributes, uids)
+
+         # Debug log graph query
+         logger.debug(query)
+
+         # Run openCypher query
+         return GrandCypher(graph.backend, limit if limit else 3).run(query)
+
+     def isquery(self, queries):
+         """
+         Checks a list of queries to see if all queries are openCypher queries.
+
+         Args:
+             queries: list of queries to check
+
+         Returns:
+             True if all queries are openCypher queries
+         """
+
+         # Check for required graph query clauses
+         return all(query and query.strip().startswith("MATCH ") and "RETURN " in query for query in queries)
+
+     def parse(self, query):
+         """
+         Parses a graph query. This method supports parsing search functions and replacing them with placeholders.
+
+         Args:
+             query: graph query
+
+         Returns:
+             parsed query as a dictionary
+         """
+
+         # Parameters
+         where, limit, nodes, similar = None, None, [], []
+
+         # Parse where clause
+         match = re.search(r"where(.+?)return", query, flags=re.DOTALL | re.IGNORECASE)
+         if match:
+             where = match.group(1).strip()
+
+         # Parse limit clause
+         match = re.search(r"limit\s+(\d+)", query, flags=re.DOTALL | re.IGNORECASE)
+         if match:
+             limit = match.group(1)
+
+         # Parse similar clauses
+         for x, match in enumerate(re.finditer(r"similar\((.+?)\)", query, flags=re.DOTALL | re.IGNORECASE)):
+             # Replace similar clause with placeholder
+             query = query.replace(match.group(0), f"{Query.SIMILAR}{x}")
+
+             # Parse similar clause parameters
+             params = [param.strip().replace("'", "").replace('"', "") for param in match.group(1).split(",")]
+             nodes.append(params[0])
+             similar.append(params[1:])
+
+         # Return parsed query
+         return {
+             "query": query,
+             "where": where,
+             "limit": limit,
+             "nodes": nodes,
+             "similar": similar,
+         }
+
+     def build(self, parse):
+         """
+         Constructs a full query from a parsed query. This method supports substituting placeholders with search results.
+
+         Args:
+             parse: parsed query
+
+         Returns:
+             graph query
+         """
+
+         # Get query. Initialize attributes and uids.
+         query, attributes, uids = parse["query"], {}, {}
+
+         # Replace similar clause with id query
+         if "results" in parse:
+             for x, result in enumerate(parse["results"]):
+                 # Get query node
+                 node = parse["nodes"][x]
+
+                 # Add similar match attribute
+                 attribute = f"match_{x}"
+                 clause = f"{node}.{attribute} > 0"
+
+                 # Replace placeholder with search results
+                 query = query.replace(f"{Query.SIMILAR}{x}", f"{clause}")
+
+                 # Add uids and scores
+                 for uid, score in result:
+                     if uid not in uids:
+                         uids[uid] = score
+
+                 # Add results by attribute matched
+                 attributes[attribute] = result
+
+         # Return query, results by attribute matched and ids filter
+         return query, attributes, uids.items()
+
+     def filter(self, graph, attributes, uids):
+         """
+         Filters the input graph by uids. This method also adds similar match attributes.
+
+         Args:
+             graph: graph instance
+             attributes: results by attribute matched
+             uids: single list with all matching ids
+
+         Returns:
+             filtered graph
+         """
+
+         # Filter the graph
+         graph = graph.filter(uids)
+
+         # Add similar match attributes
+         for attribute, result in attributes.items():
+             for uid, score in result:
+                 graph.addattribute(uid, attribute, score)
+
+         return graph
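To make the placeholder substitution concrete, here is what parse() produces for a query containing a similar() search function (a sketch assuming the grandcypher dependency is installed, since the constructor raises ImportError otherwise):

parsed = Query().parse(
    "MATCH (n) WHERE similar(n, 'machine learning') RETURN n LIMIT 5"
)

# parsed["query"]   -> "MATCH (n) WHERE __SIMILAR__0 RETURN n LIMIT 5"
# parsed["nodes"]   -> ["n"]
# parsed["similar"] -> [["machine learning"]]
# parsed["limit"]   -> "5"

Embeddings search results for each similar() clause are later attached under a "results" key, which build() turns into match_N attribute clauses and an ids filter.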
txtai/graph/rdbms.py ADDED
@@ -0,0 +1,113 @@
+ """
+ RDBMS module
+ """
+
+ import os
+
+ # Conditional import
+ try:
+     from grand import Graph
+     from grand.backends import SQLBackend, InMemoryCachedBackend
+
+     from sqlalchemy import create_engine, text, StaticPool
+     from sqlalchemy.schema import CreateSchema
+
+     ORM = True
+ except ImportError:
+     ORM = False
+
+ from .networkx import NetworkX
+
+
+ class RDBMS(NetworkX):
+     """
+     Graph instance backed by a relational database.
+     """
+
+     def __init__(self, config):
+         # Check before super() in case those required libraries are also not available
+         if not ORM:
+             raise ImportError('RDBMS is not available - install "graph" extra to enable')
+
+         super().__init__(config)
+
+         # Graph and database instances
+         self.graph = None
+         self.database = None
+
+     def __del__(self):
+         if hasattr(self, "database") and self.database:
+             self.database.close()
+
+     def create(self):
+         # Create graph instance
+         self.graph, self.database = self.connect()
+
+         # Clear previous graph, if available
+         for table in [self.config.get("nodes", "nodes"), self.config.get("edges", "edges")]:
+             self.database.execute(text(f"DELETE FROM {table}"))
+
+         # Return NetworkX compatible backend
+         return self.graph.nx
+
+     def scan(self, attribute=None, data=False):
+         if attribute:
+             for node in self.backend:
+                 attributes = self.node(node)
+                 if attribute in attributes:
+                     yield (node, attributes) if data else node
+         else:
+             yield from super().scan(attribute, data)
+
+     def load(self, path):
+         # Create graph instance
+         self.graph, self.database = self.connect()
+
+         # Store NetworkX compatible backend
+         self.backend = self.graph.nx
+
+     def save(self, path):
+         self.database.commit()
+
+     def close(self):
+         # Parent logic
+         super().close()
+
+         # Close database connection
+         self.database.close()
+
+     def filter(self, nodes, graph=None):
+         return super().filter(nodes, graph if graph else NetworkX(self.config))
+
+     def connect(self):
+         """
+         Connects to a graph backed by a relational database.
+
+         Returns:
+             Graph database instance
+         """
+
+         # Keyword arguments for SQLAlchemy
+         kwargs = {"poolclass": StaticPool, "echo": False}
+         url = self.config.get("url", os.environ.get("GRAPH_URL"))
+
+         # Set default schema, if necessary
+         schema = self.config.get("schema")
+         if schema:
+             # Check that schema exists
+             engine = create_engine(url)
+             with engine.begin() as connection:
+                 connection.execute(CreateSchema(schema, if_not_exists=True) if "postgresql" in url else text("SELECT 1"))
+
+             # Set default schema
+             kwargs["connect_args"] = {"options": f'-c search_path="{schema}"'} if "postgresql" in url else {}
+
+         backend = SQLBackend(
+             db_url=url,
+             node_table_name=self.config.get("nodes", "nodes"),
+             edge_table_name=self.config.get("edges", "edges"),
+             sqlalchemy_kwargs=kwargs,
+         )
+
+         # pylint: disable=W0212
+         return Graph(backend=InMemoryCachedBackend(backend, maxsize=None)), backend._connection
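For reference, a minimal configuration sketch for this backend. The "url", "nodes" and "edges" keys map directly to the connect() logic above (GRAPH_URL is read from the environment when "url" is absent); the SQLite URL is an illustrative value, and a PostgreSQL URL may additionally set "schema".

from txtai.graph import GraphFactory

graph = GraphFactory.create({
    "backend": "rdbms",
    "url": "sqlite:///graph.db",  # or a postgresql:// URL
    "nodes": "nodes",             # node table name (default "nodes")
    "edges": "edges",             # edge table name (default "edges")
})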