graph-seeder 1.0.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graph_seeder/GraphSeeder.py +47 -0
- graph_seeder/SubgraphExtractor.py +377 -0
- graph_seeder/configs/dbpedia_default.json +59 -0
- graph_seeder/configs/default.json +47 -0
- graph_seeder/configs/europeana_default.json +50 -0
- graph_seeder/configs/pgxlod_default.json +47 -0
- graph_seeder/configs/wikidata_default.json +70 -0
- graph_seeder/densification/GraphConnector.py +113 -0
- graph_seeder/extraction/BFS/BFS.py +192 -0
- graph_seeder/extraction/ExtractionStrategy.py +70 -0
- graph_seeder/extraction/Hop/HopExpansion.py +92 -0
- graph_seeder/utils/ConsoleUI.py +273 -0
- graph_seeder/utils/Factory.py +64 -0
- graph_seeder/utils/GraphExporter.py +84 -0
- graph_seeder/utils/GraphStatistics.py +32 -0
- graph_seeder/utils/URIManager.py +95 -0
- graph_seeder/utils/utils.py +217 -0
- graph_seeder/wrapper/NeighborhoodWrapper.py +47 -0
- graph_seeder/wrapper/hashmap/HashMapWrapper.py +124 -0
- graph_seeder/wrapper/sparql/BaseClient.py +23 -0
- graph_seeder/wrapper/sparql/GraphWrapper.py +269 -0
- graph_seeder/wrapper/sparql/SparqlQueryBuilder.py +175 -0
- graph_seeder/wrapper/sparql/client/SparqlClient.py +118 -0
- graph_seeder/wrapper/sparql/client/TurtleClient.py +47 -0
- graph_seeder-1.0.0.dev0.dist-info/METADATA +191 -0
- graph_seeder-1.0.0.dev0.dist-info/RECORD +28 -0
- graph_seeder-1.0.0.dev0.dist-info/WHEEL +4 -0
- graph_seeder-1.0.0.dev0.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
{
|
|
2
|
+
"data": {
|
|
3
|
+
"input_path": "seed.csv",
|
|
4
|
+
"output_format": "csv",
|
|
5
|
+
"output_path": "output/result"
|
|
6
|
+
},
|
|
7
|
+
"client": {
|
|
8
|
+
"type": "SPARQL",
|
|
9
|
+
"user_agent": "YOUR_PROJECT_NAME (contact: YOUR_EMAIL)",
|
|
10
|
+
"endpoint": "https://query.wikidata.org/sparql",
|
|
11
|
+
"request_delay": 1,
|
|
12
|
+
"retry_attempts": 3,
|
|
13
|
+
"retry_delay": 3.0,
|
|
14
|
+
"rate_limit_wait": 60.0,
|
|
15
|
+
"timeout": 40.0
|
|
16
|
+
},
|
|
17
|
+
"graph_filters": {
|
|
18
|
+
"include_uri_prefixes": [
|
|
19
|
+
"http://www.wikidata.org/entity/Q"
|
|
20
|
+
],
|
|
21
|
+
"exclude_uri_prefixes": [],
|
|
22
|
+
"exclude_nodes": [
|
|
23
|
+
"http://www.wikidata.org/entity/Q5",
|
|
24
|
+
"http://www.wikidata.org/entity/Q30",
|
|
25
|
+
"http://www.wikidata.org/entity/Q145",
|
|
26
|
+
"http://www.wikidata.org/entity/Q215627",
|
|
27
|
+
"http://www.wikidata.org/entity/Q43229",
|
|
28
|
+
"http://www.wikidata.org/entity/Q6256",
|
|
29
|
+
"http://www.wikidata.org/entity/Q11424",
|
|
30
|
+
"http://www.wikidata.org/entity/Q17"
|
|
31
|
+
],
|
|
32
|
+
"exclude_properties": [
|
|
33
|
+
"http://www.wikidata.org/prop/direct/P31",
|
|
34
|
+
"http://www.wikidata.org/prop/direct/P279",
|
|
35
|
+
"http://www.wikidata.org/prop/direct/P361",
|
|
36
|
+
"http://www.wikidata.org/prop/direct/P527",
|
|
37
|
+
"http://www.wikidata.org/prop/direct/P155",
|
|
38
|
+
"http://www.wikidata.org/prop/direct/P156",
|
|
39
|
+
"http://www.wikidata.org/prop/direct/P21",
|
|
40
|
+
"http://www.wikidata.org/prop/direct/P17",
|
|
41
|
+
"http://www.wikidata.org/prop/direct/P27",
|
|
42
|
+
"http://www.wikidata.org/prop/direct/P1412",
|
|
43
|
+
"http://www.wikidata.org/prop/direct/P407"
|
|
44
|
+
],
|
|
45
|
+
"namespaces": {
|
|
46
|
+
"wdt": "http://www.wikidata.org/prop/direct/",
|
|
47
|
+
"wd": "http://www.wikidata.org/entity/"
|
|
48
|
+
}
|
|
49
|
+
},
|
|
50
|
+
"extraction": {
|
|
51
|
+
"strategy": "bfs",
|
|
52
|
+
"batch_size": 15,
|
|
53
|
+
"max_hops": 6,
|
|
54
|
+
"hub_pagination_threshold": 70000,
|
|
55
|
+
"max_neighbors_threshold": 300000,
|
|
56
|
+
"hub_pairs_batch_size": 100,
|
|
57
|
+
"min_triplets_per_property": 2,
|
|
58
|
+
"check_seeds_validity": false,
|
|
59
|
+
"check_hub_seeds": false,
|
|
60
|
+
"keep_hub_seeds": null
|
|
61
|
+
},
|
|
62
|
+
"densification": {
|
|
63
|
+
"mode": "most_connected",
|
|
64
|
+
"skip_densification": false
|
|
65
|
+
},
|
|
66
|
+
"debug": {
|
|
67
|
+
"debug_enabled": false,
|
|
68
|
+
"request_logging": false
|
|
69
|
+
}
|
|
70
|
+
}
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import networkx as nx
|
|
3
|
+
import random
|
|
4
|
+
import itertools
|
|
5
|
+
from graph_seeder.extraction.BFS.BFS import BidirectionalBFS
|
|
6
|
+
from graph_seeder.utils.ConsoleUI import ConsoleUI
|
|
7
|
+
from graph_seeder.utils.URIManager import URIManager
|
|
8
|
+
from graph_seeder.utils.utils import get_connected_components
|
|
9
|
+
from graph_seeder.utils.Factory import ComponentFactory
|
|
10
|
+
from graph_seeder.wrapper.NeighborhoodWrapper import NeighborhoodWrapper
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger("subgraph")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class GraphConnector:
    """Connect the disconnected components of a triplet set via BFS paths.

    On every round the connected components are recomputed from the current
    triplet list. The first pair of components that both contain at least one
    seed, and that has not already failed, is selected; one representative
    seed is picked per component and the BFS strategy is asked for a path
    between them. The loop ends when a single component remains or when no
    untried pair is left.
    """

    def __init__(
        self,
        wrapper: NeighborhoodWrapper,
        uri_manager: URIManager,
        graph: nx.MultiGraph,
        ui: ConsoleUI,
        config: dict,
    ):
        """Store collaborators and build the BFS strategy used for densification.

        Args:
            wrapper: Neighborhood provider handed to the BFS strategy.
            uri_manager: URI helper handed to the BFS strategy.
            graph: Graph shared with the BFS strategy (loaded via ``load_graph``).
            ui: Console UI used to create the progress bar.
            config: Full configuration dictionary. It is not modified.
        """
        self.wrapper = wrapper
        self.uri_manager = uri_manager
        self.graph = graph
        self.ui = ui

        # Copy the top level AND the "extraction" section before forcing the
        # strategy: a plain dict(config) shares the nested dict, so writing the
        # strategy through it would silently rewrite the caller's configuration.
        self.bfs_config = dict(config)
        self.bfs_config["extraction"] = {
            **config.get("extraction", {}),
            "strategy": "bfs",
        }

        self.mode = (
            config.get("densification", {}).get("mode", "most_connected").lower()
        )
        self.bfs: BidirectionalBFS = ComponentFactory.create_strategy(
            wrapper, uri_manager, self.bfs_config
        )
        self.bfs.load_graph(graph)

    def _pick_representative(self, comp_seeds: list[str]) -> str:
        """Return the seed that represents a component, according to ``self.mode``.

        ``most_connected`` picks the seed with the highest degree in
        ``self.graph``, ``random`` picks uniformly at random, and any other
        mode falls back to the first seed of the list.
        """
        if self.mode == "most_connected":
            return max(comp_seeds, key=lambda n: self.graph.degree(n))
        elif self.mode == "random":
            return random.choice(comp_seeds)
        else:
            return comp_seeds[0]

    def _next_candidate_pair(
        self,
        components: list,
        found_seeds: set[str],
        failed_component_pairs: set,
    ) -> tuple[tuple, tuple[str, ...], tuple[str, ...]] | None:
        """Return ``(pair_id, seeds_a, seeds_b)`` for the first untried pair, or None.

        A pair is eligible when both components contain at least one seed and
        its identifier is not in ``failed_component_pairs``. Seeds are sorted so
        the identifier does not depend on set iteration order.
        """
        for ca, cb in itertools.combinations(components, 2):
            seeds_a = found_seeds & ca
            seeds_b = found_seeds & cb

            if not seeds_a or not seeds_b:
                continue

            sorted_a = tuple(sorted(seeds_a))
            sorted_b = tuple(sorted(seeds_b))
            pair_id = tuple(sorted([sorted_a, sorted_b]))

            if pair_id not in failed_component_pairs:
                return pair_id, sorted_a, sorted_b

        return None

    def connect(
        self, found_seeds: set[str], triplets: list[tuple[str, str, str]]
    ) -> list[tuple[str, str, str]]:
        """Try to merge all components of ``triplets`` by adding BFS paths.

        Args:
            found_seeds: Seed nodes; only components containing at least one
                of them are considered for connection.
            triplets: Current ``(subject, predicate, object)`` triplets. The
                list itself is not modified.

        Returns:
            A new list holding the input triplets followed by every path
            triplet found during densification.
        """
        new_triplets: list[tuple[str, str, str]] = list(triplets)
        failed_component_pairs: set[tuple] = set()

        with self.ui.create_progress_bar() as progress:
            task = progress.add_task("Components densification", total=None)

            while True:
                components = get_connected_components(new_triplets)

                if len(components) <= 1:
                    # Close the indeterminate bar at its current count.
                    current_completed = progress.tasks[0].completed
                    progress.update(task, total=current_completed)
                    logger.info("[green]✓[/] All components are now connected.")
                    break

                candidate = self._next_candidate_pair(
                    components, found_seeds, failed_component_pairs
                )
                if candidate is None:
                    current_completed = progress.tasks[0].completed
                    progress.update(task, total=current_completed)

                    logger.warning(
                        "[red]✗[/] No valid pairs left to connect. Densification aborted."
                    )
                    break

                pair_id, sorted_a, sorted_b = candidate

                # Pick from the sorted seeds so ties and the "first seed"
                # fallback are reproducible across runs.
                source = self._pick_representative(list(sorted_a))
                target = self._pick_representative(list(sorted_b))

                path_triplets = self.bfs.execute_task(
                    [source, target],
                    progress,
                    task,
                )

                if path_triplets:
                    new_triplets.extend(path_triplets)
                    for s, p, o in path_triplets:
                        # The BFS works on this same graph object, so path edges
                        # normally exist already; adding them again would create
                        # attribute-less parallel edges and inflate the degrees
                        # used by the "most_connected" mode.
                        if not self.graph.has_edge(s, o):
                            self.graph.add_edge(
                                s, o, predicate=p, original_subj=s, original_obj=o
                            )
                else:
                    # Record the failure under the SAME identifier used for the
                    # lookup; otherwise the pair is never recognised as failed
                    # and the loop retries it forever.
                    failed_component_pairs.add(pair_id)

        return new_triplets
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from graph_seeder.utils.URIManager import URIManager
|
|
3
|
+
from graph_seeder.wrapper.NeighborhoodWrapper import NeighborhoodWrapper
|
|
4
|
+
from graph_seeder.extraction.ExtractionStrategy import ExtractionStrategy
|
|
5
|
+
import networkx as nx
|
|
6
|
+
|
|
7
|
+
logger = logging.getLogger("subgraph")
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BidirectionalBFS(ExtractionStrategy):
    """Bidirectional BFS over a knowledge graph.

    The graph is grown on demand: a node's neighborhood is only requested from
    the wrapper when a BFS frontier reaches it and it has not been explored by
    this instance before. Triplets touching a configured excluded node are
    dropped before they enter the graph.

    NOTE(review): hub detection is not performed in this class; if it exists
    it presumably lives in the wrapper — confirm.
    """

    def __init__(
        self,
        wrapper: NeighborhoodWrapper,
        uri_manager: URIManager,
        cfg: dict,
    ) -> None:
        """Initialise the search state from the configuration.

        Args:
            wrapper: Provider queried through ``get_neighborhood``.
            uri_manager: Used to compress URIs for log and result messages.
            cfg: Configuration; reads ``graph_filters.exclude_nodes`` and
                ``extraction.max_hops`` (a missing or ``None`` value means no
                hop limit).
        """
        super().__init__(wrapper, uri_manager, cfg)
        self.uri_manager = uri_manager
        self.cfg = cfg

        # Nodes whose neighborhood has been completely fetched by this instance.
        self.explored_nodes: set[str] = set()
        self._has_path: bool = False

        # Tolerate a missing section, like HopExpansion does.
        self._excluded_nodes = set(
            cfg.get("graph_filters", {}).get("exclude_nodes", [])
        )

        max_hops_val = cfg.get("extraction", {}).get("max_hops")
        self.max_hops = max_hops_val if max_hops_val is not None else float("inf")

    def format_progress_description(self, nodes: list[str]) -> str:
        """Return ``"<source> → <target>"`` built from the compressed URIs."""
        source_uri = self.uri_manager.compress_uri(nodes[0])
        target_uri = self.uri_manager.compress_uri(nodes[1])
        return f"{source_uri} → {target_uri}"

    def format_start_message(
        self,
        nodes: list[str],
    ) -> str:
        """Return the message logged when a path extraction starts."""
        return f"Extracting path for {self.format_progress_description(nodes)}"

    def extract(self, nodes: list[str]) -> tuple[list[tuple[str, str, str]], str]:
        """Extract a path connecting exactly two seed nodes.

        Returns:
            Tuple of (list of path triplets, result message).

        Raises:
            ValueError: If ``nodes`` does not contain exactly two entries.
        """
        if len(nodes) != 2:
            raise ValueError(
                "BidirectionalBFS extraction requires exactly 2 seed nodes."
            )

        source, target = nodes
        return self._find_path(source, target)

    def _find_path(
        self, source: str, target: str
    ) -> tuple[list[tuple[str, str, str]], str]:
        """Find a path between two nodes within the configured hop limit.

        Returns:
            Tuple of (path_triplets, result message). The triplet list is empty
            when source and target are identical, when no path is found, or
            when the path found is longer than ``max_hops``.
        """
        self.graph.add_node(source)
        self.graph.add_node(target)

        source_uri = self.uri_manager.compress_uri(source)
        target_uri = self.uri_manager.compress_uri(target)

        if source == target:
            return [], f"[red]✗[/] Source and target are the same node: {source_uri}"

        q_src: set[str] = {source}
        q_tgt: set[str] = {target}
        visited_src: set[str] = {source}
        visited_tgt: set[str] = {target}
        p_src = p_tgt = 0
        self._has_path = False

        result_message = None

        while q_src and q_tgt and p_src + p_tgt <= self.max_hops:
            # We expand the smaller frontier to balance the search
            if len(q_src) <= len(q_tgt):
                logger.info(
                    f"Expanding from source {source_uri!r} at depth {p_src} "
                    f"({len(q_src)} nodes)"
                )
                q_src = self._expand_level(q_src, visited_src, visited_tgt)
                visited_src.update(q_src)
                p_src += 1
            else:
                logger.info(
                    f"Expanding from target {target_uri!r} at depth {p_tgt} "
                    f"({len(q_tgt)} nodes)"
                )
                q_tgt = self._expand_level(q_tgt, visited_tgt, visited_src)
                visited_tgt.update(q_tgt)
                p_tgt += 1

            if self._has_path:
                break

            if p_src + p_tgt >= self.max_hops:
                result_message = f"[yellow]✗[/] Path not found for {source_uri} → {target_uri} : Max hops limit ({self.max_hops}) reached before finding a connection."
                break

        if not self._has_path:
            if not result_message:
                # The loop ended because one frontier became empty.
                result_message = f"[yellow]✗[/] Path not found for {source_uri} → {target_uri} : No path exists (search space exhausted, nodes may be isolated due to filters)."
            return [], result_message

        path_triplets = self._extract_path_triplets(source, target)

        if len(path_triplets) > self.max_hops:
            result_message = f"[yellow]✗[/] Path found for {source_uri} → {target_uri} but discarded: length {len(path_triplets)} exceeds max_hops ({self.max_hops})."
            logger.warning(result_message)
            return [], result_message

        return (
            path_triplets,
            f"[green]✓[/] Path found for {source_uri} → {target_uri} : {len(path_triplets)} hops.",
        )

    def _expand_level(
        self,
        current_level: set[str],
        nodes_visited: set[str],
        visited_other_side: set[str],
    ) -> set[str]:
        """Expand one BFS frontier level using cached and remote neighborhoods.

        Args:
            current_level: Frontier nodes to expand.
            nodes_visited: Nodes already visited from this side; they are not
                added to the next level.
            visited_other_side: Nodes visited from the opposite side; reaching
                any of them sets ``self._has_path``.

        Returns:
            The set of newly reached nodes (possibly partial when the two
            searches meet before every batch has been processed).
        """
        next_level: set[str] = set()
        nodes_to_query: list[str] = []

        for node in current_level:
            if node in self.explored_nodes:
                # Already fetched earlier: reuse the edges stored in the graph.
                for neighbor in self.graph.neighbors(node):
                    if neighbor not in nodes_visited:
                        next_level.add(neighbor)
            else:
                nodes_to_query.append(node)

        if next_level & visited_other_side:
            self._has_path = True
            return next_level

        if not nodes_to_query:
            return next_level

        for node in nodes_to_query:
            self.graph.add_node(node)

        for triplets in self.wrapper.get_neighborhood(nodes_to_query):
            for subj, predicate, obj in triplets:
                if subj in self._excluded_nodes or obj in self._excluded_nodes:
                    continue

                self.graph.add_edge(
                    subj, obj, predicate=predicate, original_subj=subj, original_obj=obj
                )

                for n in (subj, obj):
                    if n not in nodes_visited:
                        next_level.add(n)

            if next_level & visited_other_side:
                self._has_path = True
                break
        else:
            # Mark nodes as explored only once every batch has been consumed.
            # After an early break (or an exception) some of them may never
            # have been fetched; flagging them would make later searches on
            # this reused instance trust an incomplete cached neighborhood.
            # They will simply be queried again if reached later.
            self.explored_nodes.update(nodes_to_query)

        return next_level

    def _extract_path_triplets(
        self, source: str, target: str
    ) -> list[tuple[str, str, str]]:
        """Build a triple sequence for the shortest path currently in the graph.

        For each consecutive node pair the first parallel edge is used; edges
        without stored attributes fall back to ``"unknown_property"`` and to
        the traversal order of the pair.
        """
        path_nodes: list[str] = nx.shortest_path(
            self.graph, source=source, target=target
        )
        triplets: list[tuple[str, str, str]] = []

        for u, v in zip(path_nodes, path_nodes[1:]):
            edges = self.graph[u][v]
            edge_data = edges[next(iter(edges))]
            predicate = edge_data.get("predicate", "unknown_property")
            subj = edge_data.get("original_subj", u)
            obj = edge_data.get("original_obj", v)
            triplets.append((subj, predicate, obj))

        return triplets
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from networkx import MultiGraph
|
|
3
|
+
from graph_seeder.utils.URIManager import URIManager
|
|
4
|
+
from graph_seeder.wrapper.NeighborhoodWrapper import NeighborhoodWrapper
|
|
5
|
+
from rich.progress import (
|
|
6
|
+
Progress,
|
|
7
|
+
TaskID,
|
|
8
|
+
)
|
|
9
|
+
import logging
|
|
10
|
+
from time import time
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger("subgraph")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ExtractionStrategy(ABC):
    """Abstract base for strategies that extract a subgraph from seed nodes."""

    def __init__(
        self, wrapper: NeighborhoodWrapper, uri_manager: URIManager, config: dict
    ):
        """Keep the collaborators and start from an empty ``MultiGraph``."""
        self.wrapper, self.uri_manager, self.config = wrapper, uri_manager, config
        self.graph: MultiGraph = MultiGraph()

    @abstractmethod
    def extract(self, nodes: list[str]) -> tuple[list[tuple[str, str, str]], str]:
        """Extract a subgraph for the given seed nodes.

        Returns:
            Tuple of (list of triplets, result message)"""

    @abstractmethod
    def format_progress_description(self, nodes: list[str]) -> str:
        """Short text shown in the progress bar (e.g. ``A -> B``, or ``A`` alone)."""

    @abstractmethod
    def format_start_message(
        self,
        nodes: list[str],
    ) -> str:
        """Full message logged when the extraction of a row starts."""

    def execute_task(
        self, nodes: list[str], progress: Progress, task: TaskID
    ) -> list[tuple[str, str, str]]:
        """Run ``extract`` for ``nodes`` while reporting to the progress bar.

        The task description is updated, start and result messages are logged
        together with the elapsed time (as a warning when no triplet was
        returned), and the task is advanced by one step. Exceptions raised by
        ``extract`` are not caught here.

        Returns:
            The triplets returned by ``extract``.
        """
        bar_label = self.format_progress_description(nodes)
        opening_line = self.format_start_message(nodes)

        progress.update(task, description=f"[cyan]{bar_label}[/]")
        logger.info(f"[bold blue]Starting:[/] {opening_line}")

        began_at = time()
        outcome = self.extract(nodes)
        elapsed = time() - began_at

        found, summary = outcome

        # An empty result is reported at warning level, anything else at info.
        emit = logger.info if found else logger.warning
        emit(f"{summary} - Took {elapsed:.2f} sec\n")

        progress.advance(task)
        return found

    def load_graph(self, graph: MultiGraph) -> None:
        """Replace the internal graph with an existing one used during extraction."""
        self.graph = graph
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from graph_seeder.extraction.ExtractionStrategy import ExtractionStrategy
|
|
3
|
+
|
|
4
|
+
from graph_seeder.utils.URIManager import URIManager
|
|
5
|
+
from graph_seeder.wrapper.NeighborhoodWrapper import NeighborhoodWrapper
|
|
6
|
+
|
|
7
|
+
logger = logging.getLogger("subgraph")
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class HopExpansion(ExtractionStrategy):
    """Simple expansion strategy that expands each node level by level up to max_hops."""

    def __init__(
        self, wrapper: NeighborhoodWrapper, uri_manager: URIManager, config: dict
    ):
        """Read the hop limit and the excluded nodes from the configuration.

        Args:
            wrapper: Provider queried through ``get_neighborhood``.
            uri_manager: Used to compress URIs for progress messages.
            config: Configuration; ``extraction.max_hops`` is required,
                ``graph_filters.exclude_nodes`` is optional.
        """
        super().__init__(wrapper, uri_manager, config)

        self.max_hops = config["extraction"]["max_hops"]
        self.excluded_nodes = set(
            config.get("graph_filters", {}).get("exclude_nodes", [])
        )

    def format_progress_description(self, nodes: list[str]) -> str:
        """Return the compressed URI of the first seed node."""
        return self.uri_manager.compress_uri(nodes[0])

    def format_start_message(
        self,
        nodes: list[str],
    ) -> str:
        """Return the message logged when an expansion starts."""
        return f"Expanding {self.max_hops} hops for {self.format_progress_description(nodes)}"

    def extract(self, nodes: list[str]) -> tuple[list[tuple[str, str, str]], str]:
        """
        Extract a subgraph by expanding from the given seed nodes up to max_hops.
        Returns:
            Tuple of (list of triplets, result message)
        """
        triplets = self._expand(nodes)

        if not triplets:
            return [], "[yellow]✗[/] No triplets found within the specified hop limit."

        return (
            triplets,
            f"[green]✓[/] Extracted {len(triplets)} triplets within {self.max_hops} hops.",
        )

    def _expand(self, nodes: list[str]) -> list[tuple[str, str, str]]:
        """
        Expand a list of nodes radially up to max_hops.

        Seeds listed in the excluded nodes are ignored, and triplets touching
        an excluded node are neither returned nor stored in the graph.

        Returns a list of all discovered unique triplets (empty when every
        seed is excluded).
        """
        valid_nodes = [n for n in nodes if n not in self.excluded_nodes]
        if not valid_nodes:
            return []

        visited_nodes: set[str] = set(valid_nodes)
        current_level_nodes: set[str] = set(valid_nodes)
        self.graph.add_nodes_from(valid_nodes)
        all_triplets: set[tuple[str, str, str]] = set()

        for hop in range(self.max_hops):
            logger.info(
                f"Expanding hop {hop + 1}/{self.max_hops} with {len(current_level_nodes)} nodes..."
            )
            next_level_nodes: set[str] = set()

            for triplets in self.wrapper.get_neighborhood(list(current_level_nodes)):
                for subj, pred, obj in triplets:
                    # Filter BEFORE touching the graph: adding the edge first
                    # would leak excluded nodes into the graph even though
                    # their triplets are dropped from the result.
                    if subj in self.excluded_nodes or obj in self.excluded_nodes:
                        continue

                    # Keying by predicate collapses repeated triplets. The
                    # attributes mirror what BidirectionalBFS stores on its
                    # edges, so a path extracted from this graph keeps its
                    # predicate and direction.
                    self.graph.add_edge(
                        subj,
                        obj,
                        key=pred,
                        predicate=pred,
                        original_subj=subj,
                        original_obj=obj,
                    )

                    all_triplets.add((subj, pred, obj))

                    if subj not in visited_nodes:
                        next_level_nodes.add(subj)
                    if obj not in visited_nodes:
                        next_level_nodes.add(obj)

            if not next_level_nodes:
                logger.info("No more nodes to expand. Graph is fully explored.")
                break

            visited_nodes.update(next_level_nodes)
            current_level_nodes = next_level_nodes
            self.graph.add_nodes_from(current_level_nodes)
            logger.info(
                f"Hop {hop + 1} expansion complete. Discovered {len(all_triplets)} unique triplets"
            )

        return list(all_triplets)
|