mseep_txtai-9.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
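The remainder of this diff shows the new txtai/graph modules in full. As a quick sanity check after installing the wheel, the package version can be read back (a minimal sketch; assumes txtai/__init__.py re-exports __version__ from the txtai/version.py entry in the manifest above):

# pip install mseep_txtai-9.1.1-py3-none-any.whl
import txtai

# Assumption: __version__ is re-exported from txtai/version.py at the package root
print(txtai.__version__)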
txtai/graph/factory.py
ADDED
@@ -0,0 +1,61 @@
"""
Factory module
"""

from ..util import Resolver

from .networkx import NetworkX
from .rdbms import RDBMS


class GraphFactory:
    """
    Methods to create graphs.
    """

    @staticmethod
    def create(config):
        """
        Create a Graph.

        Args:
            config: graph configuration

        Returns:
            Graph
        """

        # Graph instance
        graph = None
        backend = config.get("backend", "networkx")

        # Create graph instance
        if backend == "networkx":
            graph = NetworkX(config)
        elif backend == "rdbms":
            graph = RDBMS(config)
        else:
            graph = GraphFactory.resolve(backend, config)

        # Store config back
        config["backend"] = backend

        return graph

    @staticmethod
    def resolve(backend, config):
        """
        Attempt to resolve a custom backend.

        Args:
            backend: backend class
            config: index configuration parameters

        Returns:
            Graph
        """

        try:
            return Resolver()(backend)(config)
        except Exception as e:
            raise ImportError(f"Unable to resolve graph backend: '{backend}'") from e
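For context, a minimal usage sketch of the factory above (assumes the "graph" extra with NetworkX is installed and that txtai/graph/__init__.py re-exports GraphFactory, as its manifest entry suggests):

from txtai.graph import GraphFactory

# Default backend: "networkx" is used when none is configured
graph = GraphFactory.create({})
print(type(graph).__name__)  # NetworkX

# Unknown backends are handed to Resolver; failures surface as ImportError
try:
    GraphFactory.create({"backend": "my.module.CustomGraph"})  # hypothetical path
except ImportError as ex:
    print(ex)  # Unable to resolve graph backend: 'my.module.CustomGraph'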
txtai/graph/networkx.py
ADDED
@@ -0,0 +1,275 @@
"""
NetworkX module
"""

import os

from tempfile import TemporaryDirectory

# Conditional import
try:
    import networkx as nx

    from networkx.algorithms.community import asyn_lpa_communities, greedy_modularity_communities, louvain_partitions
    from networkx.readwrite import json_graph

    NETWORKX = True
except ImportError:
    NETWORKX = False

from ..archive import ArchiveFactory
from ..serialize import SerializeError, SerializeFactory

from .base import Graph
from .query import Query


# pylint: disable=R0904
class NetworkX(Graph):
    """
    Graph instance backed by NetworkX.
    """

    def __init__(self, config):
        super().__init__(config)

        if not NETWORKX:
            raise ImportError('NetworkX is not available - install "graph" extra to enable')

    def create(self):
        return nx.Graph()

    def count(self):
        return self.backend.number_of_nodes()

    def scan(self, attribute=None, data=False):
        # Full graph
        graph = self.backend

        # Filter graph to nodes having a specified attribute
        if attribute:
            graph = nx.subgraph_view(self.backend, filter_node=lambda x: attribute in self.node(x))

        # Return either list of matching ids or tuple of (id, attribute dictionary)
        return graph.nodes(data=True) if data else graph

    def node(self, node):
        return self.backend.nodes.get(node)

    def addnode(self, node, **attrs):
        self.backend.add_node(node, **attrs)

    def addnodes(self, nodes):
        self.backend.add_nodes_from(nodes)

    def removenode(self, node):
        if self.hasnode(node):
            self.backend.remove_node(node)

    def hasnode(self, node):
        return self.backend.has_node(node)

    def attribute(self, node, field):
        return self.node(node).get(field) if self.hasnode(node) else None

    def addattribute(self, node, field, value):
        if self.hasnode(node):
            self.node(node)[field] = value

    def removeattribute(self, node, field):
        return self.node(node).pop(field, None) if self.hasnode(node) else None

    def edgecount(self):
        return self.backend.number_of_edges()

    def edges(self, node):
        edges = self.backend.adj.get(node)
        if edges:
            return dict(sorted(edges.items(), key=lambda x: x[1].get("weight", 0), reverse=True))

        return None

    def addedge(self, source, target, **attrs):
        self.backend.add_edge(source, target, **attrs)

    def addedges(self, edges):
        self.backend.add_edges_from(edges)

    def hasedge(self, source, target=None):
        if target is None:
            edges = self.backend.adj.get(source)
            return len(edges) > 0 if edges else False

        return self.backend.has_edge(source, target)

    def centrality(self):
        rank = nx.degree_centrality(self.backend)
        return dict(sorted(rank.items(), key=lambda x: x[1], reverse=True))

    def pagerank(self):
        rank = nx.pagerank(self.backend, weight="weight")
        return dict(sorted(rank.items(), key=lambda x: x[1], reverse=True))

    def showpath(self, source, target):
        # pylint: disable=E1121
        return nx.shortest_path(self.backend, source, target, self.distance)

    def isquery(self, queries):
        return Query().isquery(queries)

    def parse(self, query):
        return Query().parse(query)

    def search(self, query, limit=None, graph=False):
        # Run graph query
        results = Query()(self, query, limit)

        # Transform into filtered graph
        if graph:
            nodes = set()
            for column in results.values():
                for value in column:
                    if isinstance(value, list):
                        # Path group
                        nodes.update([node for node in value if node and not isinstance(node, dict)])
                    elif isinstance(value, dict):
                        # Nodes by id attribute
                        nodes.update(uid for uid, attr in self.scan(data=True) if attr["id"] == value["id"])
                    elif value is not None:
                        # Single node id
                        nodes.add(value)

            return self.filter(list(nodes))

        # Transform columnar structure into rows
        keys = list(results.keys())
        rows, count = [], len(results[keys[0]])

        for x in range(count):
            rows.append({str(key): results[key][x] for key in keys})

        return rows

    def communities(self, config):
        # Get community detection algorithm
        algorithm = config.get("algorithm")

        if algorithm == "greedy":
            communities = greedy_modularity_communities(self.backend, weight="weight", resolution=config.get("resolution", 100))
        elif algorithm == "lpa":
            communities = asyn_lpa_communities(self.backend, weight="weight", seed=0)
        else:
            communities = self.louvain(config)

        return communities

    def load(self, path):
        try:
            # Load graph data
            data = SerializeFactory.create().load(path)

            # Add data to graph
            self.backend = self.create()
            self.backend.add_nodes_from(data["nodes"])
            self.backend.add_edges_from(data["edges"])

            # Load categories
            self.categories = data.get("categories")

            # Load topics
            self.topics = data.get("topics")

        except SerializeError:
            # Backwards compatible support for legacy TAR format
            self.loadtar(path)

    def save(self, path):
        # Save graph data
        SerializeFactory.create().save(
            {
                "nodes": [(uid, self.node(uid)) for uid in self.scan()],
                "edges": list(self.backend.edges(data=True)),
                "categories": self.categories,
                "topics": self.topics,
            },
            path,
        )

    def loaddict(self, data):
        self.backend = json_graph.node_link_graph(data, name="indexid")
        self.categories, self.topics = data.get("categories"), data.get("topics")

    def savedict(self):
        data = json_graph.node_link_data(self.backend, name="indexid")
        data["categories"] = self.categories
        data["topics"] = self.topics

        return data

    def louvain(self, config):
        """
        Runs the Louvain community detection algorithm.

        Args:
            config: topic configuration

        Returns:
            list of [ids] per community
        """

        # Partition level to use
        level = config.get("level", "best")

        # Run community detection
        results = list(louvain_partitions(self.backend, weight="weight", resolution=config.get("resolution", 100), seed=0))

        # Get partition level (first or best)
        return results[0] if level == "first" else results[-1]

    # pylint: disable=W0613
    def distance(self, source, target, attrs):
        """
        Computes distance between source and target nodes using weight.

        Args:
            source: source node
            target: target node
            attrs: edge attributes

        Returns:
            distance between source and target
        """

        # Distance is 1 - score. Skip minimal distances as they are near duplicates.
        distance = max(1.0 - attrs["weight"], 0.0)
        return distance if distance >= 0.15 else 1.00

    def loadtar(self, path):
        """
        Loads a graph from the legacy TAR file.

        Args:
            path: path to graph
        """

        # Pickle serialization - backwards compatible
        serializer = SerializeFactory.create("pickle")

        # Extract files to temporary directory and load content
        with TemporaryDirectory() as directory:
            # Unpack files
            archive = ArchiveFactory.create(directory)
            archive.load(path, "tar")

            # Load graph backend
            self.backend = serializer.load(f"{directory}/graph")

            # Load categories, if necessary
            path = f"{directory}/categories"
            if os.path.exists(path):
                self.categories = serializer.load(path)

            # Load topics, if necessary
            path = f"{directory}/topics"
            if os.path.exists(path):
                self.topics = serializer.load(path)
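A short sketch of the wrapper in isolation (assumes networkx is installed; backend initialization normally runs through the base Graph class in txtai/graph/base.py, which is not part of this diff, so the backend is assigned directly here):

from txtai.graph.networkx import NetworkX

graph = NetworkX({})
graph.backend = graph.create()  # normally handled by the base Graph class

# Nodes are (id, attributes) tuples; edges carry similarity weights
graph.addnodes([(0, {"id": "doc1"}), (1, {"id": "doc2"}), (2, {"id": "doc3"})])
graph.addedges([(0, 1, {"weight": 0.9}), (1, 2, {"weight": 0.5})])

print(graph.count(), graph.edgecount())  # 3 2
print(graph.centrality())                # node 1 first: {1: 1.0, 0: 0.5, 2: 0.5}
print(graph.showpath(0, 2))              # [0, 1, 2]

Note that showpath() routes through distance(), which remaps near-duplicate edges (weight above 0.85, i.e. distance below 0.15) to a distance of 1.0 so shortest paths prefer informative hops over near duplicates.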
txtai/graph/query.py
ADDED
@@ -0,0 +1,181 @@
"""
Query module
"""

import logging
import re

try:
    from grandcypher import GrandCypher

    GRANDCYPHER = True
except ImportError:
    GRANDCYPHER = False

# Logging configuration
logger = logging.getLogger(__name__)


class Query:
    """
    Runs openCypher graph queries using the GrandCypher library. This class also supports search functions.
    """

    # Similar token
    SIMILAR = "__SIMILAR__"

    def __init__(self):
        """
        Create a new graph query instance.
        """

        if not GRANDCYPHER:
            raise ImportError('GrandCypher is not available - install "graph" extra to enable')

    def __call__(self, graph, query, limit):
        """
        Runs a graph query.

        Args:
            graph: graph instance
            query: graph query, can be a full query string or a parsed query dictionary
            limit: number of results

        Returns:
            results
        """

        # Results by attribute and ids filter
        attributes, uids = None, None

        # Build the query from a parsed query
        if isinstance(query, dict):
            query, attributes, uids = self.build(query)

        # Filter graph, if applicable
        if uids:
            graph = self.filter(graph, attributes, uids)

        # Debug log graph query
        logger.debug(query)

        # Run openCypher query
        return GrandCypher(graph.backend, limit if limit else 3).run(query)

    def isquery(self, queries):
        """
        Checks a list of queries to see if all queries are openCypher queries.

        Args:
            queries: list of queries to check

        Returns:
            True if all queries are openCypher queries
        """

        # Check for required graph query clauses
        return all(query and query.strip().startswith("MATCH ") and "RETURN " in query for query in queries)

    def parse(self, query):
        """
        Parses a graph query. This method supports parsing search functions and replacing them with placeholders.

        Args:
            query: graph query

        Returns:
            parsed query as a dictionary
        """

        # Parameters
        where, limit, nodes, similar = None, None, [], []

        # Parse where clause
        match = re.search(r"where(.+?)return", query, flags=re.DOTALL | re.IGNORECASE)
        if match:
            where = match.group(1).strip()

        # Parse limit clause
        match = re.search(r"limit\s+(\d+)", query, flags=re.DOTALL | re.IGNORECASE)
        if match:
            limit = match.group(1)

        # Parse similar clauses
        for x, match in enumerate(re.finditer(r"similar\((.+?)\)", query, flags=re.DOTALL | re.IGNORECASE)):
            # Replace similar clause with placeholder
            query = query.replace(match.group(0), f"{Query.SIMILAR}{x}")

            # Parse similar clause parameters
            params = [param.strip().replace("'", "").replace('"', "") for param in match.group(1).split(",")]
            nodes.append(params[0])
            similar.append(params[1:])

        # Return parsed query
        return {
            "query": query,
            "where": where,
            "limit": limit,
            "nodes": nodes,
            "similar": similar,
        }

    def build(self, parse):
        """
        Constructs a full query from a parsed query. This method supports substituting placeholders with search results.

        Args:
            parse: parsed query

        Returns:
            graph query
        """

        # Get query. Initialize attributes and uids.
        query, attributes, uids = parse["query"], {}, {}

        # Replace similar clause with id query
        if "results" in parse:
            for x, result in enumerate(parse["results"]):
                # Get query node
                node = parse["nodes"][x]

                # Add similar match attribute
                attribute = f"match_{x}"
                clause = f"{node}.{attribute} > 0"

                # Replace placeholder with search results
                query = query.replace(f"{Query.SIMILAR}{x}", f"{clause}")

                # Add uids and scores
                for uid, score in result:
                    if uid not in uids:
                        uids[uid] = score

                # Add results by attribute matched
                attributes[attribute] = result

        # Return query, results by attribute matched and ids filter
        return query, attributes, uids.items()

    def filter(self, graph, attributes, uids):
        """
        Filters the input graph by uids. This method also adds similar match attributes.

        Args:
            graph: graph instance
            attributes: results by attribute matched
            uids: single list with all matching ids

        Returns:
            filtered graph
        """

        # Filter the graph
        graph = graph.filter(uids)

        # Add similar match attributes
        for attribute, result in attributes.items():
            for uid, score in result:
                graph.addattribute(uid, attribute, score)

        return graph
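A sketch of the parse() flow above, showing how a similar() search function is lifted out of the openCypher text (assumes the grandcypher dependency is installed, since the constructor enforces it):

from txtai.graph.query import Query

query = "MATCH (n) WHERE similar(n, 'machine learning') RETURN n LIMIT 10"
parsed = Query().parse(query)

print(parsed["query"])    # MATCH (n) WHERE __SIMILAR__0 RETURN n LIMIT 10
print(parsed["nodes"])    # ['n']
print(parsed["similar"])  # [['machine learning']]
print(parsed["limit"])    # '10'

A downstream caller is expected to run the extracted similar() parameters through vector search, store the (uid, score) results under parsed["results"], and then build() swaps each __SIMILAR__N placeholder for a match_N attribute filter while filter() annotates the matching nodes with their scores.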
txtai/graph/rdbms.py
ADDED
@@ -0,0 +1,113 @@
"""
RDBMS module
"""

import os

# Conditional import
try:
    from grand import Graph
    from grand.backends import SQLBackend, InMemoryCachedBackend

    from sqlalchemy import create_engine, text, StaticPool
    from sqlalchemy.schema import CreateSchema

    ORM = True
except ImportError:
    ORM = False

from .networkx import NetworkX


class RDBMS(NetworkX):
    """
    Graph instance backed by a relational database.
    """

    def __init__(self, config):
        # Check before super() in case those required libraries are also not available
        if not ORM:
            raise ImportError('RDBMS is not available - install "graph" extra to enable')

        super().__init__(config)

        # Graph and database instances
        self.graph = None
        self.database = None

    def __del__(self):
        if hasattr(self, "database") and self.database:
            self.database.close()

    def create(self):
        # Create graph instance
        self.graph, self.database = self.connect()

        # Clear previous graph, if available
        for table in [self.config.get("nodes", "nodes"), self.config.get("edges", "edges")]:
            self.database.execute(text(f"DELETE FROM {table}"))

        # Return NetworkX compatible backend
        return self.graph.nx

    def scan(self, attribute=None, data=False):
        if attribute:
            for node in self.backend:
                attributes = self.node(node)
                if attribute in attributes:
                    yield (node, attributes) if data else node
        else:
            yield from super().scan(attribute, data)

    def load(self, path):
        # Create graph instance
        self.graph, self.database = self.connect()

        # Store NetworkX compatible backend
        self.backend = self.graph.nx

    def save(self, path):
        self.database.commit()

    def close(self):
        # Parent logic
        super().close()

        # Close database connection
        self.database.close()

    def filter(self, nodes, graph=None):
        return super().filter(nodes, graph if graph else NetworkX(self.config))

    def connect(self):
        """
        Connects to a graph backed by a relational database.

        Returns:
            graph and database instances
        """

        # Keyword arguments for SQLAlchemy
        kwargs = {"poolclass": StaticPool, "echo": False}
        url = self.config.get("url", os.environ.get("GRAPH_URL"))

        # Set default schema, if necessary
        schema = self.config.get("schema")
        if schema:
            # Check that schema exists
            engine = create_engine(url)
            with engine.begin() as connection:
                connection.execute(CreateSchema(schema, if_not_exists=True) if "postgresql" in url else text("SELECT 1"))

            # Set default schema
            kwargs["connect_args"] = {"options": f'-c search_path="{schema}"'} if "postgresql" in url else {}

        backend = SQLBackend(
            db_url=url,
            node_table_name=self.config.get("nodes", "nodes"),
            edge_table_name=self.config.get("edges", "edges"),
            sqlalchemy_kwargs=kwargs,
        )

        # pylint: disable=W0212
        return Graph(backend=InMemoryCachedBackend(backend, maxsize=None)), backend._connection
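Finally, a configuration sketch for the relational backend (the SQLite URL and table names below are illustrative; any SQLAlchemy URL works, and the optional "schema" setting applies to PostgreSQL only):

from txtai.graph import GraphFactory

graph = GraphFactory.create({
    "backend": "rdbms",
    "url": "sqlite:///graph.db",  # or set the GRAPH_URL environment variable
    "nodes": "graph_nodes",       # node table name, defaults to "nodes"
    "edges": "graph_edges",       # edge table name, defaults to "edges"
})

# create() connects and clears the configured tables; load() reconnects to
# existing tables; save() commits the open transaction
graph.backend = graph.create()  # normally driven by the base Graph class

Note that create() issues DELETE FROM on both tables, so pointing it at an existing graph is destructive; load() is the non-destructive path.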