kgzip 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kgzip/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from kgzip import compat # noqa: F401 — applies nest_asyncio if needed
2
+ from kgzip._version import __version__
3
+ from kgzip.exceptions import KGZipError
4
+ from kgzip.serialize import to_compact, to_triples
5
+ from kgzip.store import KGZipStore
6
+
7
+ __all__ = ["KGZipStore", "KGZipError", "to_triples", "to_compact", "__version__"]
kgzip/_version.py ADDED
@@ -0,0 +1,3 @@
1
+ """Single source of truth for the package version."""
2
+
3
+ __version__ = "0.1.0"
@@ -0,0 +1,5 @@
1
+ from kgzip.benchmark.medical_kg import generate_medical_kg
2
+ from kgzip.benchmark.harness import time_query
3
+ from kgzip.benchmark.comparison import run_benchmark
4
+
5
+ __all__ = ["generate_medical_kg", "time_query", "run_benchmark"]
@@ -0,0 +1,155 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ import shutil
7
+ import tempfile
8
+ import time
9
+ from collections import deque
10
+ from typing import List
11
+
12
+ import networkx as nx
13
+
14
+ from kgzip.benchmark.harness import time_query
15
+ from kgzip.models import (
16
+ BenchmarkReport,
17
+ DecisionConfig,
18
+ KGZipConfig,
19
+ StorageConfig,
20
+ )
21
+ from kgzip.store import KGZipStore
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ # Number of timed runs per method (kept small to keep tests fast)
26
+ _BENCH_RUNS = 5
27
+
28
+
29
+ def _bfs_subgraph(G: nx.DiGraph, seed_nodes: List[str], depth: int) -> set:
30
+ visited: set = set()
31
+ queue: deque = deque()
32
+ seen: set = set()
33
+ for n in seed_nodes:
34
+ if n in G:
35
+ queue.append((n, 0))
36
+ seen.add(n)
37
+ while queue:
38
+ node, d = queue.popleft()
39
+ visited.add(node)
40
+ if d < depth:
41
+ for nbr in G.successors(node):
42
+ if nbr not in seen:
43
+ seen.add(nbr)
44
+ queue.append((nbr, d + 1))
45
+ return visited
46
+
47
+
48
def run_benchmark(
    graph: nx.DiGraph,
    query_node_ids: List[str],
    depth: int = 2,
) -> BenchmarkReport:
    """Benchmark a full-graph-load traversal against a KGZip capsule query.

    Baseline: the graph is written once as node-link JSON; every timed run
    re-reads that file, rebuilds the graph and runs a BFS from
    *query_node_ids*. This models a naive store that reads all data on every
    access.

    KGZip: the graph is compressed once into a temporary store; every timed
    run issues ``store.query`` for the same seeds and depth.

    Both temporary directories are removed before returning, whether or not
    the benchmark succeeds. A summary is printed to stdout.

    Returns:
        A BenchmarkReport holding the compression time, mean query latencies,
        their ratio, the capsule count and the two on-disk sizes.
    """
    std_workdir = tempfile.mkdtemp(prefix="kgzip_bench_std_")
    kgz_workdir = tempfile.mkdtemp(prefix="kgzip_bench_kgz_")

    try:
        # --- Baseline: one JSON dump of the whole graph on disk ---
        node_link_payload = nx.node_link_data(graph)
        json_path = os.path.join(std_workdir, "graph.json")
        with open(json_path, "w", encoding="utf-8") as sink:
            json.dump(node_link_payload, sink)
        json_size_bytes = os.path.getsize(json_path)

        def _load_and_traverse() -> set:
            with open(json_path, encoding="utf-8") as source:
                loaded = json.load(source)
            rebuilt = nx.node_link_graph(loaded)
            return _bfs_subgraph(rebuilt, query_node_ids, depth)

        # One untimed call first so cold-start costs do not skew the samples.
        _load_and_traverse()
        baseline_timing = time_query(_load_and_traverse, runs=_BENCH_RUNS)

        # --- KGZip: a small capsule cap forces several capsules, so a query
        # only has to touch a fraction of the store ---
        decision_cfg = DecisionConfig(
            max_capsule_nodes=100,
            min_capsule_nodes=5,
            random_seed=42,
        )
        storage_cfg = StorageConfig(
            base_path=kgz_workdir,
            compression="zstd",
            compression_level=3,
        )
        store = KGZipStore(
            kgz_workdir,
            KGZipConfig(decision=decision_cfg, storage=storage_cfg),
        )

        compress_started = time.monotonic()
        store_ref = store.compress(graph)
        compress_time_ms = (time.monotonic() - compress_started) * 1000.0

        def _capsule_query():
            return store.query(query_node_ids, depth=depth)

        # Untimed warm-up, mirroring the baseline.
        _capsule_query()
        capsule_timing = time_query(_capsule_query, runs=_BENCH_RUNS)

        # Report 0.0 rather than dividing by a zero mean.
        if capsule_timing["mean_ms"] > 0.0:
            speedup_ratio = baseline_timing["mean_ms"] / capsule_timing["mean_ms"]
        else:
            speedup_ratio = 0.0

        report = BenchmarkReport(
            compress_time_ms=compress_time_ms,
            standard_mean_ms=baseline_timing["mean_ms"],
            kgzip_mean_ms=capsule_timing["mean_ms"],
            speedup_ratio=speedup_ratio,
            capsule_count=store_ref.capsule_count,
            store_size_bytes=store_ref.total_bytes,
            original_size_estimate_bytes=json_size_bytes,
        )

        _print_summary(report, graph)
        return report

    finally:
        for workdir in (std_workdir, kgz_workdir):
            shutil.rmtree(workdir, ignore_errors=True)
136
+
137
+
138
def _print_summary(report: BenchmarkReport, graph: nx.DiGraph) -> None:
    """Print the benchmark report as a framed, human-readable block on stdout."""
    rule = "=" * 60
    # Clamp the denominator to 1 so a zero-byte baseline cannot divide by zero.
    baseline_bytes = max(report.original_size_estimate_bytes, 1)
    reduction_pct = 100 * (1 - report.store_size_bytes / baseline_bytes)

    lines = [
        "",
        rule,
        "KGZip Benchmark Results",
        rule,
        f"Graph: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges",
        f"Capsules: {report.capsule_count}",
        f"Compress time: {report.compress_time_ms:.1f} ms",
        "",
        f"Query latency (mean over {_BENCH_RUNS} runs):",
        f" Standard BFS: {report.standard_mean_ms:.2f} ms",
        f" KGZip: {report.kgzip_mean_ms:.2f} ms",
        f" Speedup: {report.speedup_ratio:.2f}x",
        "",
        "Storage:",
        f" Original est: {report.original_size_estimate_bytes:,} bytes",
        f" KGZip store: {report.store_size_bytes:,} bytes",
        f" Reduction: {reduction_pct:.1f}%",
        rule,
        "",  # trailing newline after the closing rule
    ]
    print("\n".join(lines))
@@ -0,0 +1,22 @@
1
+ from __future__ import annotations
2
+
3
+ import statistics
4
+ import time
5
+ from typing import Any, Callable, Dict
6
+
7
+
8
def time_query(fn: Callable, *args: Any, runs: int = 10) -> Dict[str, float]:
    """Time ``fn(*args)`` over *runs* calls and return wall-time statistics.

    Args:
        fn: Callable to measure; its return value is discarded.
        *args: Positional arguments forwarded to *fn* on every call.
        runs: Number of timed calls. Must be at least 1.

    Returns:
        A dict with the keys ``mean_ms``, ``median_ms``, ``min_ms``,
        ``max_ms`` and ``std_ms`` (sample standard deviation; ``0.0`` when
        only one run was made). All values are in milliseconds.

    Raises:
        statistics.StatisticsError: If *runs* is smaller than 1. This is a
            ``ValueError`` subclass and the same type the statistics module
            raises for an empty sample.
    """
    if runs < 1:
        raise statistics.StatisticsError(f"runs must be at least 1, got {runs}")

    # perf_counter is the highest-resolution clock available for measuring
    # short durations; time.monotonic can be too coarse for millisecond-scale
    # queries on some platforms.
    clock = time.perf_counter
    samples_ms: list = []
    for _ in range(runs):
        started = clock()
        fn(*args)
        samples_ms.append((clock() - started) * 1000.0)

    return {
        "mean_ms": statistics.mean(samples_ms),
        "median_ms": statistics.median(samples_ms),
        "min_ms": min(samples_ms),
        "max_ms": max(samples_ms),
        # stdev needs at least two samples.
        "std_ms": statistics.stdev(samples_ms) if len(samples_ms) > 1 else 0.0,
    }
@@ -0,0 +1,71 @@
1
+ from __future__ import annotations
2
+
3
+ import random
4
+
5
+ import networkx as nx
6
+
7
+
8
def generate_medical_kg(
    n_drugs: int = 50,
    n_diseases: int = 100,
    n_genes: int = 200,
    n_proteins: int = 150,
    seed: int = 42,
) -> nx.DiGraph:
    """Build a synthetic, seeded medical knowledge graph.

    Nodes are named ``Drug_i``, ``Disease_i``, ``Gene_i`` and ``Protein_i``
    and carry ``node_type`` and ``name`` attributes. Edges carry a
    ``relation`` attribute and are drawn at random per source node:

        Drug    --treats-->           Disease  (3-6 per drug)
        Disease --associated_with-->  Gene     (4-8 per disease)
        Gene    --encodes-->          Protein  (1-3 per gene)
        Drug    --targets-->          Protein  (2-4 per drug)
        Disease --comorbid_with-->    Disease  (disjoint random pairs)

    With the default sizes this yields 500 nodes and roughly 1500 edges.
    The same *seed* always produces the same graph.
    """
    rng = random.Random(seed)
    G = nx.DiGraph()

    # Insertion order matters: it fixes both node order in the graph and the
    # order in which the RNG is consumed below.
    families = {
        "drug": [f"Drug_{i}" for i in range(n_drugs)],
        "disease": [f"Disease_{i}" for i in range(n_diseases)],
        "gene": [f"Gene_{i}" for i in range(n_genes)],
        "protein": [f"Protein_{i}" for i in range(n_proteins)],
    }
    for node_type, members in families.items():
        for member in members:
            G.add_node(member, node_type=node_type, name=member)

    # (source family, target family, relation, min fan-out, max fan-out)
    fan_out_rules = (
        ("drug", "disease", "treats", 3, 6),
        ("disease", "gene", "associated_with", 4, 8),
        ("gene", "protein", "encodes", 1, 3),
        ("drug", "protein", "targets", 2, 4),
    )
    for source_type, target_type, relation, low, high in fan_out_rules:
        candidates = families[target_type]
        for source in families[source_type]:
            wanted = rng.randint(low, high)
            # Never ask for more targets than exist.
            picks = rng.sample(candidates, min(wanted, len(candidates)))
            for target in picks:
                G.add_edge(source, target, relation=relation)

    # Sparse comorbidity: sample an even-sized pool of diseases and link
    # consecutive pairs, so each sampled disease gets at most one such edge.
    diseases = families["disease"]
    pool_size = min(int(n_diseases * 0.5) * 2, len(diseases))
    pool = rng.sample(diseases, pool_size)
    for left, right in zip(pool[0::2], pool[1::2]):
        G.add_edge(left, right, relation="comorbid_with")

    return G
@@ -0,0 +1,226 @@
1
+ """Query-quality evaluation harness for KGZip.
2
+
3
+ KGZip retrieves at *capsule* granularity: a depth-k query loads the seed's capsule
4
+ (plus neighbour capsules for deeper queries) and returns their full contents. That
5
+ is not the same as the exact k-hop neighbourhood of the seed. This harness measures
6
+ how well capsule retrieval approximates the true neighbourhood, across graph sizes.
7
+
8
+ Metrics (averaged over sampled seeds, per depth):
9
+ node_recall = |returned ∩ true| / |true| (did we get the real neighbourhood?)
10
+ node_precision = |returned ∩ true| / |returned| (how much extra did we pull in?)
11
+ edge_recall / edge_precision — same, over edge (src,dst,relation) signatures
12
+ perfect_recall_frac — fraction of seeds with node_recall == 1.0
13
+
14
+ A separate full-graph check confirms losslessness: querying every node returns the
15
+ entire original graph (recall == precision == 1.0).
16
+
17
+ Run directly:
18
+ python -m kgzip.benchmark.quality # default scales
19
+ python -m kgzip.benchmark.quality 100 1000 10000 # custom node counts
20
+ """
21
+ from __future__ import annotations
22
+
23
+ import random
24
+ import tempfile
25
+ import time
26
+ from collections import deque
27
+ from typing import Dict, List, Set, Tuple
28
+
29
+ import networkx as nx
30
+
31
+ from kgzip.benchmark.medical_kg import generate_medical_kg
32
+ from kgzip.models import DecisionConfig, KGZipConfig, StorageConfig
33
+ from kgzip.store import KGZipStore
34
+
35
+ DEFAULT_SCALES = [100, 1_000, 10_000, 50_000, 100_000]
36
+ DEFAULT_DEPTHS = [1, 2]
37
+ DEFAULT_N_SEEDS = 20
38
+
39
+ _BASE_TOTAL = 500 # generate_medical_kg default node total (50+100+200+150)
40
+
41
+
42
def _scaled_medical_kg(n_nodes: int, seed: int = 42) -> nx.DiGraph:
    """Return a synthetic medical KG with roughly *n_nodes* nodes.

    The default 50/100/200/150 drug/disease/gene/protein mix is scaled by
    ``n_nodes / _BASE_TOTAL``; every category keeps at least one node.
    """
    factor = n_nodes / _BASE_TOTAL

    def _scaled(base_count: int) -> int:
        return max(1, round(base_count * factor))

    return generate_medical_kg(
        n_drugs=_scaled(50),
        n_diseases=_scaled(100),
        n_genes=_scaled(200),
        n_proteins=_scaled(150),
        seed=seed,
    )
52
+
53
+
54
+ def _true_neighbourhood(
55
+ G: nx.DiGraph, seed: str, depth: int
56
+ ) -> Tuple[Set[str], Set[Tuple[str, str, str]]]:
57
+ """Undirected k-hop neighbourhood of *seed*: node set and edge signature set."""
58
+ visited: Set[str] = {seed}
59
+ frontier = {seed}
60
+ for _ in range(depth):
61
+ nxt: Set[str] = set()
62
+ for n in frontier:
63
+ nxt.update(G.predecessors(n))
64
+ nxt.update(G.successors(n))
65
+ nxt -= visited
66
+ visited |= nxt
67
+ frontier = nxt
68
+ if not frontier:
69
+ break
70
+ edges: Set[Tuple[str, str, str]] = set()
71
+ for u, v, data in G.edges(data=True):
72
+ if u in visited and v in visited:
73
+ edges.add((u, v, data.get("relation", "related_to")))
74
+ return visited, edges
75
+
76
+
77
+ def _safe_div(a: float, b: float) -> float:
78
+ return a / b if b else 1.0
79
+
80
+
81
def evaluate_scale(
    n_nodes: int,
    depths: List[int] = None,
    n_seeds: int = DEFAULT_N_SEEDS,
    seed: int = 42,
) -> Dict:
    """Compress a scaled synthetic KG and measure query quality at each depth.

    Args:
        n_nodes: Approximate node count of the synthetic graph to generate.
        depths: Query depths to evaluate; falls back to ``DEFAULT_DEPTHS``
            when ``None`` or empty.
        n_seeds: Number of seed nodes to sample (capped at the node count).
            Must be at least 1.
        seed: Seed for graph generation, the decision config and seed sampling.

    Returns:
        A dict with ``requested_nodes``, ``actual_nodes``, ``actual_edges``,
        ``capsule_count``, ``store_bytes``, ``compress_ms``, ``per_depth``
        (depth -> averaged metrics) and ``lossless_full_query`` (``True`` or
        ``False``, or ``None`` when the check was skipped).

    Raises:
        ValueError: If *n_seeds* is smaller than 1.
    """
    # shutil is not part of this module's import block, so import it locally.
    import shutil

    # Fail before the expensive compress step; with zero seeds the averages
    # below would divide by zero.
    if n_seeds < 1:
        raise ValueError(f"n_seeds must be at least 1, got {n_seeds}")

    depths = depths or DEFAULT_DEPTHS
    rng = random.Random(seed)

    G = _scaled_medical_kg(n_nodes, seed=seed)
    actual_nodes = G.number_of_nodes()
    actual_edges = G.number_of_edges()

    store_dir = tempfile.mkdtemp(prefix="kgzip_quality_")
    try:
        config = KGZipConfig(
            decision=DecisionConfig(random_seed=seed),
            storage=StorageConfig(
                base_path=store_dir, compression="zstd", compression_level=3
            ),
        )
        store = KGZipStore(store_dir, config)

        # perf_counter: highest-resolution clock for short durations.
        t0 = time.perf_counter()
        ref = store.compress(G)
        compress_ms = (time.perf_counter() - t0) * 1000.0

        all_ids = [str(n) for n in G.nodes()]
        seeds = rng.sample(all_ids, min(n_seeds, len(all_ids)))
        seed_count = len(seeds)

        per_depth: Dict[int, Dict[str, float]] = {}
        for depth in depths:
            node_recall_sum = node_prec_sum = 0.0
            edge_recall_sum = edge_prec_sum = 0.0
            perfect = 0
            latency_ms = 0.0
            for s in seeds:
                # Only the query itself is timed, not the ground-truth BFS.
                t = time.perf_counter()
                result = store.query([s], depth=depth)
                latency_ms += (time.perf_counter() - t) * 1000.0

                ret_nodes = {n.id for n in result.subgraph.nodes}
                ret_edges = {
                    (e.src, e.dst, e.relation) for e in result.subgraph.edges
                }
                true_nodes, true_edges = _true_neighbourhood(G, s, depth)

                inter_n = len(ret_nodes & true_nodes)
                inter_e = len(ret_edges & true_edges)
                node_recall = _safe_div(inter_n, len(true_nodes))
                node_recall_sum += node_recall
                node_prec_sum += _safe_div(inter_n, len(ret_nodes))
                edge_recall_sum += _safe_div(inter_e, len(true_edges))
                edge_prec_sum += _safe_div(inter_e, len(ret_edges))
                # Small tolerance so float rounding cannot hide a perfect recall.
                if node_recall >= 0.999:
                    perfect += 1

            per_depth[depth] = {
                "node_recall": node_recall_sum / seed_count,
                "node_precision": node_prec_sum / seed_count,
                "edge_recall": edge_recall_sum / seed_count,
                "edge_precision": edge_prec_sum / seed_count,
                "perfect_recall_frac": perfect / seed_count,
                "mean_latency_ms": latency_ms / seed_count,
            }

        # Losslessness check: query every node, expect the whole graph back.
        # This is only meaningful when the store fits under the per-query
        # capsule cap; above that a single query is intentionally capped, so
        # the result is marked n/a (None) rather than a false fail.
        from kgzip.query.interface import _MAX_CAPSULES_PER_QUERY

        if ref.capsule_count <= _MAX_CAPSULES_PER_QUERY:
            full = store.query(all_ids, depth=1)
            full_nodes = {n.id for n in full.subgraph.nodes}
            full_edges = {(e.src, e.dst, e.relation) for e in full.subgraph.edges}
            orig_edges = {
                (str(u), str(v), d.get("relation", "related_to"))
                for u, v, d in G.edges(data=True)
            }
            lossless = (full_nodes == set(all_ids)) and (full_edges == orig_edges)
        else:
            lossless = None  # n/a: exceeds the per-query capsule cap

        return {
            "requested_nodes": n_nodes,
            "actual_nodes": actual_nodes,
            "actual_edges": actual_edges,
            "capsule_count": ref.capsule_count,
            "store_bytes": ref.total_bytes,
            "compress_ms": compress_ms,
            "per_depth": per_depth,
            "lossless_full_query": lossless,
        }
    finally:
        # mkdtemp never cleans up after itself; without this every evaluated
        # scale would leave a full store behind in the temp directory.
        shutil.rmtree(store_dir, ignore_errors=True)
167
+
168
+
169
def run_quality_eval(
    scales: List[int] = None,
    depths: List[int] = None,
    n_seeds: int = DEFAULT_N_SEEDS,
    seed: int = 42,
) -> List[Dict]:
    """Evaluate query quality at every scale, printing progress and a summary.

    ``scales`` and ``depths`` fall back to ``DEFAULT_SCALES`` and
    ``DEFAULT_DEPTHS`` when ``None`` or empty. Returns one report dict per
    scale, in the order the scales were given.
    """
    if not scales:
        scales = DEFAULT_SCALES
    if not depths:
        depths = DEFAULT_DEPTHS

    collected: List[Dict] = []
    for target_nodes in scales:
        # Flush so progress is visible before the long-running evaluation.
        print(f"\n[quality] evaluating ~{target_nodes:,} nodes ...", flush=True)
        scale_report = evaluate_scale(
            target_nodes, depths=depths, n_seeds=n_seeds, seed=seed
        )
        collected.append(scale_report)
        _print_scale(scale_report)

    _print_summary(collected, depths)
    return collected
185
+
186
+
187
+ def _print_scale(rep: Dict) -> None:
188
+ print(
189
+ f" nodes={rep['actual_nodes']:,} edges={rep['actual_edges']:,} "
190
+ f"capsules={rep['capsule_count']} "
191
+ f"store={rep['store_bytes']:,}B compress={rep['compress_ms']:.0f}ms "
192
+ f"lossless_full={rep['lossless_full_query']}"
193
+ )
194
+ for depth, m in rep["per_depth"].items():
195
+ print(
196
+ f" depth={depth}: "
197
+ f"node_recall={m['node_recall']:.2f} node_prec={m['node_precision']:.2f} "
198
+ f"edge_recall={m['edge_recall']:.2f} edge_prec={m['edge_precision']:.2f} "
199
+ f"perfect_recall={m['perfect_recall_frac']:.0%} "
200
+ f"lat={m['mean_latency_ms']:.2f}ms"
201
+ )
202
+
203
+
204
+ def _print_summary(reports: List[Dict], depths: List[int]) -> None:
205
+ print(f"\n{'='*84}\nKGZip query-quality summary\n{'='*84}")
206
+ header = f"{'nodes':>9} {'caps':>5} {'depth':>5} {'n_rec':>6} {'n_prec':>7} {'e_rec':>6} {'e_prec':>7} {'lat_ms':>7}"
207
+ print(header)
208
+ print("-" * len(header))
209
+ for rep in reports:
210
+ for depth in depths:
211
+ m = rep["per_depth"][depth]
212
+ print(
213
+ f"{rep['actual_nodes']:>9,} {rep['capsule_count']:>5} {depth:>5} "
214
+ f"{m['node_recall']:>6.2f} {m['node_precision']:>7.2f} "
215
+ f"{m['edge_recall']:>6.2f} {m['edge_precision']:>7.2f} {m['mean_latency_ms']:>7.2f}"
216
+ )
217
+ all_lossless = all(r["lossless_full_query"] for r in reports)
218
+ print("-" * len(header))
219
+ print(f"full-graph query lossless at every scale: {all_lossless}")
220
+
221
+
222
if __name__ == "__main__":
    import sys

    # Positional arguments are node counts; with none given, pass None so the
    # default scales are used.
    cli_scales = [int(token) for token in sys.argv[1:]]
    run_quality_eval(scales=cli_scales or None)
kgzip/compat.py ADDED
@@ -0,0 +1,77 @@
1
+ """asyncio compatibility helpers for KGZip.
2
+
3
+ KGZip exposes a synchronous public API but decodes capsules with asyncio under the
4
+ hood. This module provides :func:`run_coro`, which drives a coroutine to completion
5
+ from synchronous code in *both* plain scripts and environments that already have a
6
+ running event loop (Jupyter, or being called from inside async code).
7
+
8
+ It deliberately avoids the deprecated ``asyncio.get_event_loop()`` pattern, which
9
+ raises/warns on Python 3.10+ when there is no current loop.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import asyncio
14
+ import logging
15
+ import threading
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ _NEST_APPLIED = False
20
+
21
+ # A dedicated event loop per thread, reused across calls. Reuse avoids the
22
+ # per-call setup/teardown cost of asyncio.run(); thread-local storage keeps it
23
+ # safe to call query() concurrently from multiple threads (each thread drives
24
+ # its own loop, and a loop is not safe to share across threads).
25
+ _thread_local = threading.local()
26
+
27
+
28
def _ensure_nest_asyncio() -> None:
    """Apply nest_asyncio once so a running loop can be re-entered. Idempotent.

    A failure to import or apply nest_asyncio is non-fatal: it is logged and
    the function returns normally. ``_NEST_APPLIED`` then stays ``False``, so
    the next call tries again.
    """
    global _NEST_APPLIED
    if _NEST_APPLIED:
        return
    try:
        import nest_asyncio

        nest_asyncio.apply()
    except Exception:  # pragma: no cover - nest_asyncio is a declared dependency
        # Non-fatal: only matters when called from within a running loop.
        # Log instead of swallowing silently so a later "event loop is already
        # running" failure can be traced back to its cause.
        logger.warning(
            "nest_asyncio could not be applied; re-entering a running event "
            "loop will not work",
            exc_info=True,
        )
        return
    _NEST_APPLIED = True
    logger.debug("nest_asyncio applied for re-entrant event loop support")
42
+
43
+
44
def run_coro(coro):
    """Drive *coro* to completion from synchronous code and return its result.

    - If an event loop is already running in this thread, nest_asyncio is
      applied (once) and the coroutine is run on that loop.
    - Otherwise the coroutine runs on an event loop stored in thread-local
      storage, which is created on first use and reused by later calls from
      the same thread. A closed loop is replaced by a fresh one.
    """
    try:
        active_loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread; fall through to the private loop.
        pass
    else:
        _ensure_nest_asyncio()
        return active_loop.run_until_complete(coro)

    own_loop = getattr(_thread_local, "loop", None)
    if own_loop is None or own_loop.is_closed():
        own_loop = asyncio.new_event_loop()
        _thread_local.loop = own_loop
    return own_loop.run_until_complete(coro)
66
+
67
+
68
+ def _apply_on_import() -> None:
69
+ """If imported while a loop is already running (e.g. Jupyter), patch it now."""
70
+ try:
71
+ asyncio.get_running_loop()
72
+ except RuntimeError:
73
+ return # no running loop at import time — nothing to do
74
+ _ensure_nest_asyncio()
75
+
76
+
77
+ _apply_on_import()
@@ -0,0 +1,101 @@
1
+ from __future__ import annotations
2
+
3
+ import uuid
4
+ from datetime import datetime, timezone
5
+ from typing import List, Set
6
+
7
+ from kgzip.decision.community import detect_communities
8
+ from kgzip.decision.gcs_scorer import estimate_score
9
+ from kgzip.decision.mode_selector import select_mode
10
+ from kgzip.decision.profiler import profile
11
+ from kgzip.decision.spectral import compute_fingerprint
12
+ from kgzip.exceptions import EmptyGraphError
13
+ from kgzip.models import (
14
+ CapsulePlan,
15
+ DecisionConfig,
16
+ DecisionPlan,
17
+ NormalizedGraph,
18
+ PlanMeta,
19
+ )
20
+
21
+ __all__ = [
22
+ "profile",
23
+ "detect_communities",
24
+ "compute_fingerprint",
25
+ "build_plan",
26
+ ]
27
+
28
+
29
def build_plan(
    graph: NormalizedGraph, config: DecisionConfig = None
) -> DecisionPlan:
    """Partition *graph* into communities and describe one capsule per community.

    Args:
        graph: The normalized graph to plan for.
        config: Decision settings; a default ``DecisionConfig()`` is used
            when ``None``.

    Returns:
        A DecisionPlan with one CapsulePlan per detected community plus
        plan-level metadata (capsule count, mean GCS score, modularity and a
        UTC timestamp).

    Raises:
        EmptyGraphError: If ``graph.meta.node_count`` is 0.
    """
    if config is None:
        config = DecisionConfig()

    if graph.meta.node_count == 0:
        raise EmptyGraphError("Cannot build plan for empty graph")

    # The profile result is not used below; the call is kept as in the original.
    profile(graph)
    communities: List[Set[str]] = detect_communities(graph, config)

    # node_id -> index of the community that contains it
    membership: dict = {
        node_id: comm_idx
        for comm_idx, members in enumerate(communities)
        for node_id in members
    }

    # An edge whose endpoints sit in different communities marks both
    # endpoints as boundary nodes and both communities as neighbours.
    # Communities are tracked by index here; capsule ids are assigned afterwards.
    boundary_sets: List[Set[str]] = [set() for _ in communities]
    neighbor_sets: List[Set[int]] = [set() for _ in communities]
    for edge in graph.edges:
        src_comm = membership.get(edge.src)
        dst_comm = membership.get(edge.dst)
        if src_comm is None or dst_comm is None or src_comm == dst_comm:
            continue
        boundary_sets[src_comm].add(edge.src)
        boundary_sets[dst_comm].add(edge.dst)
        neighbor_sets[src_comm].add(dst_comm)
        neighbor_sets[dst_comm].add(src_comm)

    capsule_ids = [str(uuid.uuid4()) for _ in communities]

    capsules: List[CapsulePlan] = []
    for comm_idx, members in enumerate(communities):
        encoding_mode = select_mode(members, graph, config)
        fingerprint = compute_fingerprint(members, graph, config.spectral_k)
        gcs_score = estimate_score(members, graph)
        plan = CapsulePlan(
            capsule_id=capsule_ids[comm_idx],
            node_ids=set(members),
            boundary_nodes=boundary_sets[comm_idx],
            encoding_mode=encoding_mode,
            gcs_score=gcs_score,
            spectral_fp=fingerprint,
            neighbor_capsules=[capsule_ids[j] for j in neighbor_sets[comm_idx]],
        )
        capsules.append(plan)

    # Modularity of the partition on the undirected graph. Any failure
    # (including a missing `community` package) is reported as -1.0.
    try:
        import community as louvain_community

        undirected = graph.to_networkx().to_undirected()
        # The membership map is exactly the node -> community partition.
        modularity = louvain_community.modularity(membership, undirected)
    except Exception:
        modularity = -1.0

    if capsules:
        mean_gcs = sum(c.gcs_score for c in capsules) / len(capsules)
    else:
        mean_gcs = 0.0

    plan_meta = PlanMeta(
        total_capsules=len(capsules),
        mean_gcs=mean_gcs,
        community_modularity=modularity,
        planned_at=datetime.now(timezone.utc).isoformat(),
    )
    return DecisionPlan(capsules=capsules, plan_meta=plan_meta)