structuremappingmemory 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sma/__init__.py +5 -0
- sma/__main__.py +5 -0
- sma/agent/__init__.py +5 -0
- sma/agent/adapter_draft.py +217 -0
- sma/agent/api.py +67 -0
- sma/agent/comparison.py +591 -0
- sma/agent/llm.py +280 -0
- sma/agent/policies.py +21 -0
- sma/agent/service.py +95 -0
- sma/cli.py +65 -0
- sma/encoders/__init__.py +38 -0
- sma/encoders/agentobs.py +27 -0
- sma/encoders/base.py +23 -0
- sma/encoders/code_treesitter.py +64 -0
- sma/encoders/coverage.py +80 -0
- sma/encoders/draft_adapter.py +183 -0
- sma/encoders/healthcare.py +207 -0
- sma/encoders/logs_drain.py +142 -0
- sma/encoders/prose_tier1.py +57 -0
- sma/encoders/structured.py +57 -0
- sma/encoders/traces.py +45 -0
- sma/eval/__init__.py +2 -0
- sma/eval/agentic/__init__.py +35 -0
- sma/eval/agentic/arms/__init__.py +0 -0
- sma/eval/agentic/arms/cyber.py +48 -0
- sma/eval/agentic/arms/discovery.py +35 -0
- sma/eval/agentic/arms/finance.py +38 -0
- sma/eval/agentic/arms/legal.py +74 -0
- sma/eval/agentic/arms/medicine.py +45 -0
- sma/eval/agentic/harness.py +275 -0
- sma/eval/agentic/memories.py +308 -0
- sma/eval/agentic/metrics.py +82 -0
- sma/eval/agentic_qa/__init__.py +27 -0
- sma/eval/agentic_qa/agent.py +383 -0
- sma/eval/agentic_qa/metrics.py +239 -0
- sma/eval/agentic_qa/pools.py +197 -0
- sma/eval/arn.py +65 -0
- sma/eval/baselines/__init__.py +6 -0
- sma/eval/baselines/bge_dense.py +54 -0
- sma/eval/baselines/bm25.py +18 -0
- sma/eval/baselines/dense.py +42 -0
- sma/eval/baselines/hipporag.py +235 -0
- sma/eval/baselines/hybrid_rrf.py +30 -0
- sma/eval/baselines/longcontext_llm.py +124 -0
- sma/eval/baselines/rerank.py +41 -0
- sma/eval/baselines/splade.py +77 -0
- sma/eval/baselines/wl_kernel.py +163 -0
- sma/eval/bugsinpy.py +358 -0
- sma/eval/bugsinpy_families.py +164 -0
- sma/eval/crossdomain.py +89 -0
- sma/eval/diabetes.py +61 -0
- sma/eval/drift_env.py +26 -0
- sma/eval/drift_metrics.py +24 -0
- sma/eval/family_labels.py +167 -0
- sma/eval/fraud_elliptic/__init__.py +29 -0
- sma/eval/fraud_elliptic/encoder.py +279 -0
- sma/eval/fraud_elliptic/eval.py +269 -0
- sma/eval/fraud_elliptic/test_encoder.py +123 -0
- sma/eval/ieee_cis.py +66 -0
- sma/eval/loghub.py +16 -0
- sma/eval/loghub_eval.py +480 -0
- sma/eval/longmemeval.py +51 -0
- sma/eval/memory_backends/__init__.py +2 -0
- sma/eval/memory_backends/base.py +22 -0
- sma/eval/memory_backends/context_only.py +14 -0
- sma/eval/memory_backends/rag_notes.py +17 -0
- sma/eval/memory_backends/shared_llm.py +30 -0
- sma/eval/memory_backends/sma_memory.py +54 -0
- sma/eval/memory_backends/zep_graphiti.py +33 -0
- sma/eval/metrics.py +32 -0
- sma/eval/ontology_bench.py +219 -0
- sma/eval/report.py +573 -0
- sma/eval/ssb_eval.py +216 -0
- sma/eval/ssb_generator.py +116 -0
- sma/eval/stats.py +108 -0
- sma/eval/transfer_eval.py +844 -0
- sma/index/__init__.py +15 -0
- sma/index/ann.py +21 -0
- sma/index/content_vectors.py +60 -0
- sma/index/inverted.py +63 -0
- sma/index/macfac.py +174 -0
- sma/ir/__init__.py +22 -0
- sma/ir/canon.py +106 -0
- sma/ir/schema.py +165 -0
- sma/ir/sexpr.py +86 -0
- sma/ir/signatures.py +76 -0
- sma/match/__init__.py +20 -0
- sma/match/conflicts.py +46 -0
- sma/match/engine.py +60 -0
- sma/match/explain.py +59 -0
- sma/match/infer.py +54 -0
- sma/match/kernels.py +54 -0
- sma/match/mdl.py +30 -0
- sma/match/merge_cpsat.py +77 -0
- sma/match/merge_greedy.py +15 -0
- sma/match/mh.py +177 -0
- sma/match/ses.py +84 -0
- sma/match/types.py +115 -0
- sma/match/verifier.py +27 -0
- sma/ontology/__init__.py +45 -0
- sma/ontology/attack.py +134 -0
- sma/ontology/cpc.py +69 -0
- sma/ontology/graph.py +58 -0
- sma/ontology/loader.py +262 -0
- sma/ontology/mitre_xml.py +67 -0
- sma/ontology/mount.py +101 -0
- sma/ontology/rdf_loader.py +75 -0
- sma/ontology/registry.py +115 -0
- sma/ontology/router.py +69 -0
- sma/ontology/usgaap.py +73 -0
- sma/sage/__init__.py +6 -0
- sma/sage/assimilate.py +12 -0
- sma/sage/pools.py +105 -0
- sma/sage/probabilities.py +10 -0
- sma/store/__init__.py +6 -0
- sma/store/lmdb_store.py +78 -0
- sma/store/registry.py +26 -0
- sma/store/wal.py +26 -0
- sma/ui/app.py +642 -0
- structuremappingmemory-1.0.0.dist-info/METADATA +190 -0
- structuremappingmemory-1.0.0.dist-info/RECORD +125 -0
- structuremappingmemory-1.0.0.dist-info/WHEEL +5 -0
- structuremappingmemory-1.0.0.dist-info/entry_points.txt +2 -0
- structuremappingmemory-1.0.0.dist-info/licenses/LICENSE +204 -0
- structuremappingmemory-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""B6 long-context LLM baseline (blueprint section 8.1 B6).
|
|
2
|
+
|
|
3
|
+
Controls for "maybe you don't need retrieval at all": stuff the query session
|
|
4
|
+
plus its top-20 BM25 candidate precedents (labels included) into one prompt
|
|
5
|
+
and ask deepseek-chat (temperature 0, max_tokens 10) to label the query.
|
|
6
|
+
|
|
7
|
+
API key comes from SMA_DEEPSEEK_API_KEY or the repo .env, via the same lookup
|
|
8
|
+
the agent layer uses (sma.agent.llm._env_key). Errors get exactly one retry;
|
|
9
|
+
a second failure (or an unparseable reply) marks the row failed.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import time
|
|
15
|
+
|
|
16
|
+
from sma.agent.llm import DEEPSEEK_BASE_URL, DEEPSEEK_KEY_ENV, DEEPSEEK_MODEL, _env_key
|
|
17
|
+
|
|
18
|
+
# Maximum number of characters of each precedent session text copied into the prompt.
CANDIDATE_CHARS = 800
|
|
19
|
+
# Maximum number of characters of the query session text copied into the prompt.
QUERY_CHARS = 1600
|
|
20
|
+
|
|
21
|
+
# System message sent with every chat-completion request. It tells the model
# to answer with exactly one word, "Anomaly" or "Normal".
SYSTEM_PROMPT = (
    "You are an incident triage assistant. "
    "You label log sessions as Anomaly or Normal by analogy to labeled precedent sessions. "
    "Reply with exactly one word: Anomaly or Normal."
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def build_prompt(query_text: str, precedents: list[tuple[str, str]]) -> str:
    """Assemble the user message for one query session.

    Args:
        query_text: Raw text of the session to label; only the first
            ``QUERY_CHARS`` characters are included.
        precedents: ``(label, session_text)`` pairs, already ranked. Each text
            is cut to ``CANDIDATE_CHARS`` characters and numbered from 1.

    Returns:
        The newline-joined prompt: query section, numbered precedents, a blank
        line, and the closing one-word question.
    """
    header = [
        "Query session (label unknown):",
        query_text[:QUERY_CHARS],
        "",
        "Labeled precedent sessions (retrieved by lexical similarity, most similar first):",
    ]
    numbered = [
        f"[{rank}] ({label}) {body[:CANDIDATE_CHARS]}"
        for rank, (label, body) in enumerate(precedents, start=1)
    ]
    footer = [
        "",
        "Based on these precedents, is the query session Anomaly or Normal? "
        "Answer with exactly one word.",
    ]
    return "\n".join(header + numbered + footer)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def parse_label(content: str) -> str | None:
    """Map a free-text model reply onto ``"Anomaly"`` or ``"Normal"``.

    The reply is lower-cased and scanned for the stems ``anomal`` (covers
    "anomaly", "anomalous", ...) and ``normal``. When both occur, the one that
    appears first wins. A stem only counts at the start of a word, so
    "abnormal" or "subnormal" is no longer misread as "Normal"; such replies
    are treated as unparseable.

    Args:
        content: Raw reply text from the model.

    Returns:
        ``"Anomaly"``, ``"Normal"``, or ``None`` when neither stem is found.
    """
    import re  # function-scope import keeps this block self-contained

    # The lookbehind rejects a stem that is glued to a preceding letter
    # (e.g. the "normal" inside "abnormal").
    match = re.search(r"(?<![a-z])(anomal|normal)", content.lower())
    if match is None:
        return None
    return "Anomaly" if match.group(1) == "anomal" else "Normal"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class LongContextDeepSeek:
    """Long-context LLM baseline: label a query session from in-prompt precedents.

    Keeps running counters (``calls``, token totals) and a ``failures`` list of
    ``"<query_id>: <reason>"`` strings for rows that could not be labeled.
    """

    def __init__(
        self,
        model: str = DEEPSEEK_MODEL,
        api_key: str | None = None,
        timeout: float = 60.0,
        retry_sleep: float = 2.0,
    ):
        """Store the request settings and zero the bookkeeping counters.

        Args:
            model: Model name sent in the request body.
            api_key: Bearer token; when falsy, ``_env_key(DEEPSEEK_KEY_ENV)``
                is used instead.
            timeout: Timeout passed to the HTTP request.
            retry_sleep: Seconds to sleep before the single retry.
        """
        self.model = model
        self.api_key = api_key or _env_key(DEEPSEEK_KEY_ENV)
        self.timeout = timeout
        self.retry_sleep = retry_sleep
        # Bookkeeping: successful calls, failed rows, cumulative token usage.
        self.calls = 0
        self.failures: list[str] = []
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0

    def _call_once(self, prompt: str) -> str:
        """Send one chat-completion request and return the stripped reply text.

        Raises whatever ``httpx`` raises for transport or HTTP-status errors,
        and ``KeyError``/``IndexError`` if the response lacks the expected
        ``choices[0].message.content`` path. Counters are updated only after
        the response body has been decoded.
        """
        import httpx

        request_body = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            "temperature": 0,
            "max_tokens": 10,
        }
        response = httpx.post(
            f"{DEEPSEEK_BASE_URL}/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json=request_body,
            timeout=self.timeout,
        )
        response.raise_for_status()
        payload = response.json()

        self.calls += 1
        usage = payload.get("usage", {})
        self.total_prompt_tokens += int(usage.get("prompt_tokens", 0))
        self.total_completion_tokens += int(usage.get("completion_tokens", 0))

        # A null content field is treated as an empty reply.
        message_text = payload["choices"][0]["message"]["content"] or ""
        return message_text.strip()

    def classify(
        self, query_id: str, query_text: str, precedents: list[tuple[str, str]]
    ) -> str | None:
        """Label one query session.

        Returns 'Anomaly' or 'Normal'. Returns None and records a failure when
        no API key is configured or when both attempts (initial call plus one
        retry) end in an error or an unparseable reply.
        """
        if not self.api_key:
            self.failures.append(f"{query_id}: {DEEPSEEK_KEY_ENV} not set")
            return None

        prompt = build_prompt(query_text, precedents)
        failure_reason = ""
        for is_first_attempt in (True, False):
            try:
                reply = self._call_once(prompt)
                verdict = parse_label(reply)
                if verdict is not None:
                    return verdict
                failure_reason = f"unparseable reply: {reply!r}"
            except Exception as exc:  # noqa: BLE001 - any transport/API error retries once
                failure_reason = f"{type(exc).__name__}: {exc}"
            if is_first_attempt:
                # Pause only between the first attempt and the retry.
                time.sleep(self.retry_sleep)

        self.failures.append(f"{query_id}: {failure_reason}")
        return None
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Cross-encoder reranking (cross-encoder/ms-marco-MiniLM-L-6-v2).
|
|
2
|
+
|
|
3
|
+
Reranks a candidate pool (here: top-20 of Hybrid-RRF) and returns the top-k.
|
|
4
|
+
Raw model outputs are logits; we pass them through a sigmoid so the scores
|
|
5
|
+
are positive and usable by the protocol's weighted top-5 label vote (a
|
|
6
|
+
negative weight would flip a vote, which is not what 'weighted' means there).
|
|
7
|
+
Candidate texts are truncated to ``max_chars`` before pairing; the tokenizer
|
|
8
|
+
truncates to 512 wordpieces anyway, so this only bounds tokenizer cost.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
# Default model identifier that CrossEncoderReranker passes to CrossEncoder.
MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CrossEncoderReranker:
    """Rerank a candidate pool with a cross-encoder and return sigmoid scores."""

    def __init__(self, model_name: str = MODEL_NAME, max_chars: int = 2000):
        """Load the cross-encoder on CPU.

        Args:
            model_name: Model identifier passed to ``CrossEncoder``.
            max_chars: Character cap applied to the query and to every
                candidate text before they are paired.
        """
        from sentence_transformers import CrossEncoder

        self.max_chars = max_chars
        self.model = CrossEncoder(model_name, device="cpu", max_length=512)

    def rerank(
        self,
        query_text: str,
        candidates: list[tuple[str, str]],
        top_k: int = 10,
    ) -> list[tuple[str, float]]:
        """Score ``(doc_id, text)`` candidates against the query.

        Returns up to ``top_k`` ``(doc_id, score)`` pairs, highest score first,
        with ties broken by ascending ``doc_id``. Scores are the sigmoid of the
        model outputs. An empty candidate list yields ``[]`` without calling
        the model.
        """
        if not candidates:
            return []

        limit = self.max_chars
        clipped_query = query_text[:limit]
        doc_ids = []
        pairs = []
        for doc_id, text in candidates:
            doc_ids.append(doc_id)
            pairs.append((clipped_query, text[:limit]))

        raw = np.asarray(
            self.model.predict(pairs, show_progress_bar=False), dtype=np.float64
        )
        # Sigmoid keeps every score positive.
        probabilities = 1.0 / (1.0 + np.exp(-raw))

        scored = [(doc_id, float(p)) for doc_id, p in zip(doc_ids, probabilities)]
        scored.sort(key=lambda item: (-item[1], item[0]))
        return scored[:top_k]
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""SPLADE learned sparse retrieval (naver/splade-cocondenser-ensembledistil).
|
|
2
|
+
|
|
3
|
+
CPU-only. Standard SPLADE document/query representation:
|
|
4
|
+
rep = max over token positions of log(1 + relu(MLM logits)),
|
|
5
|
+
masked by attention. Documents are batch-encoded once into a scipy CSR
|
|
6
|
+
matrix; query scoring is a sparse dot product. Per-query latency includes
|
|
7
|
+
the query forward pass (same convention as the other neural baselines).
|
|
8
|
+
|
|
9
|
+
Inputs are truncated to ``max_length`` wordpieces (default 256) to keep the
|
|
10
|
+
CPU forward pass tractable on long log sessions; the same truncation applies
|
|
11
|
+
to documents and queries, so no method gets privileged context.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
from scipy import sparse
|
|
18
|
+
|
|
19
|
+
# Default model identifier that SpladeRetriever passes to from_pretrained.
MODEL_NAME = "naver/splade-cocondenser-ensembledistil"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SpladeRetriever:
    """SPLADE sparse retriever: encode documents once, score queries by dot product.

    ``build`` stores one sparse row per document in ``doc_matrix``;
    ``retrieve`` encodes the query with the same model and ranks documents by
    the sparse dot product.
    """

    def __init__(
        self,
        model_name: str = MODEL_NAME,
        max_length: int = 256,
        batch_size: int = 8,
    ):
        """Load tokenizer and masked-LM model and put the model in eval mode.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            max_length: Token cap applied to documents and queries alike.
            batch_size: Number of texts per forward pass in ``encode``.
        """
        import torch
        from transformers import AutoModelForMaskedLM, AutoTokenizer

        self.torch = torch
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.model.eval()
        self.max_length = max_length
        self.batch_size = batch_size
        self.doc_ids: list[str] = []
        self.doc_matrix: sparse.csr_matrix | None = None

    def _encode_batch(self, texts: list[str]) -> np.ndarray:
        """Return the dense (batch, vocab) SPLADE representation of ``texts``."""
        torch = self.torch
        tokens = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        with torch.no_grad():
            logits = self.model(**tokens).logits  # (B, T, V)
        # log(1 + relu(logits)), max-pooled over valid token positions.
        # The weights are non-negative, so zeroing padded positions with the
        # attention mask means padding can never win the max.
        weights = torch.log1p(torch.relu(logits))
        mask = tokens["attention_mask"].unsqueeze(-1)
        reps = (weights * mask).max(dim=1).values  # (B, V)
        return reps.numpy()

    def encode(self, texts: list[str]) -> sparse.csr_matrix:
        """Encode ``texts`` in batches and stack the rows into one sparse matrix.

        An empty input yields an empty ``(0, 0)`` matrix.
        """
        rows = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start : start + self.batch_size]
            rows.append(sparse.csr_matrix(self._encode_batch(batch)))
        return sparse.vstack(rows) if rows else sparse.csr_matrix((0, 0))

    def build(self, documents: list[tuple[str, str]]) -> None:
        """Index ``(doc_id, text)`` pairs, replacing any previous index."""
        self.doc_ids = [doc_id for doc_id, _ in documents]
        self.doc_matrix = self.encode([text for _, text in documents])

    def retrieve(self, query_text: str, k: int = 10) -> list[tuple[str, float]]:
        """Return the top-``k`` ``(doc_id, score)`` pairs for ``query_text``.

        Results are ordered by descending score, ties broken by ascending
        ``doc_id``. Returns ``[]`` when nothing has been indexed. That covers
        an index built from an empty document list, whose ``(0, 0)`` matrix
        cannot be multiplied by the ``(V, 1)`` query vector.
        """
        if self.doc_matrix is None or not self.doc_ids:
            return []
        q = sparse.csr_matrix(self._encode_batch([query_text]))
        scores = np.asarray((self.doc_matrix @ q.T).todense()).ravel()
        ranked = sorted(
            zip(self.doc_ids, map(float, scores)), key=lambda row: (-row[1], row[0])
        )
        return ranked[:k]
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""Weisfeiler-Leman graph-kernel control over SMA's own Tier-0 cases.
|
|
2
|
+
|
|
3
|
+
The internal control for the ladder: does *generic* graph similarity computed
|
|
4
|
+
on the SAME deterministic extraction (sma.encoders get_encoder("logs")) match
|
|
5
|
+
structure-mapping retrieval? If yes, SMA's edge is the extraction; if no, the
|
|
6
|
+
edge is the mapping mathematics.
|
|
7
|
+
|
|
8
|
+
Graph construction (per case):
|
|
9
|
+
- nodes = unique expressions (keyed by canonical s-expression) plus unique
|
|
10
|
+
entities (keyed by (name, type));
|
|
11
|
+
- edges = statement -> argument, position-annotated.
|
|
12
|
+
|
|
13
|
+
Node labels (iteration 0):
|
|
14
|
+
- statements: the functor (exactly what SMA's own content vectors see);
|
|
15
|
+
- entities whose names are arbitrary per-session identifiers (types
|
|
16
|
+
"event", "session") are labeled by TYPE ONLY. This is not a kindness, it
|
|
17
|
+
is a necessity: the session entity touches nearly every statement, so a
|
|
18
|
+
case-unique session name would make every refined label case-unique after
|
|
19
|
+
one iteration and the kernel would degenerate to zero similarity.
|
|
20
|
+
- all other entities (components, event_type tokens, integer counts) keep
|
|
21
|
+
"type:name" -- their names are shared vocabulary, i.e. real content.
|
|
22
|
+
|
|
23
|
+
Refinement: 2 WL iterations; new label = hash(own label, position-sorted child
|
|
24
|
+
labels, sorted parent labels). Similarity = cosine over the concatenated label
|
|
25
|
+
histograms of iterations 0..2 (Shervashidze et al. 2011, WL subtree kernel,
|
|
26
|
+
normalized).
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
import hashlib
|
|
32
|
+
from collections import Counter
|
|
33
|
+
|
|
34
|
+
import numpy as np
|
|
35
|
+
from scipy import sparse
|
|
36
|
+
|
|
37
|
+
from sma.ir.schema import Case, Entity, Statement
|
|
38
|
+
from sma.ir.sexpr import dumps_statement
|
|
39
|
+
|
|
40
|
+
# Entity types whose names are arbitrary per-session identifiers.
|
|
41
|
+
_ID_LIKE_TYPES = frozenset({"event", "session"})
|
|
42
|
+
|
|
43
|
+
# Default number of WL refinement rounds, used by wl_histogram and WLKernelRetriever.
WL_ITERATIONS = 2
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _entity_label(ent: Entity) -> str:
    """Return the iteration-0 WL label for an entity node.

    Entities whose type is in ``_ID_LIKE_TYPES`` are labeled ``ent:<type>``;
    every other entity is labeled ``ent:<type>:<name>``.
    """
    label = f"ent:{ent.type}"
    if ent.type not in _ID_LIKE_TYPES:
        label = f"{label}:{ent.name}"
    return label
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _hash_label(payload: str, memo: dict[str, str]) -> str:
|
|
53
|
+
cached = memo.get(payload)
|
|
54
|
+
if cached is None:
|
|
55
|
+
cached = hashlib.blake2b(payload.encode("utf-8"), digest_size=8).hexdigest()
|
|
56
|
+
memo[payload] = cached
|
|
57
|
+
return cached
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def wl_histogram(case: Case, iterations: int = WL_ITERATIONS) -> Counter[str]:
    """Build the concatenated WL label histogram (rounds 0..``iterations``).

    Nodes are the unique statements (keyed by ``dumps_statement``) and unique
    entities (keyed by ``(name, type)``) reachable from ``case.statements``.
    Each statement has one position-tagged edge to every argument. Histogram
    keys are ``wl<round>:<label>``, so labels from different rounds never
    collide.
    """
    # --- graph construction ---------------------------------------------
    base_labels: list[str] = []
    out_edges: list[list[tuple[int, int]]] = []  # node -> [(arg_pos, child_node)]
    in_edges: list[list[int]] = []  # node -> [parent_node, ...]
    seen_statements: dict[str, int] = {}
    seen_entities: dict[tuple[str, str], int] = {}

    def new_node(label: str) -> int:
        node_id = len(base_labels)
        base_labels.append(label)
        out_edges.append([])
        in_edges.append([])
        return node_id

    def intern(item: Statement | Entity) -> int:
        """Return the node id for ``item``, creating node and edges on first sight."""
        if isinstance(item, Entity):
            ent_key = (item.name, item.type)
            node_id = seen_entities.get(ent_key)
            if node_id is None:
                node_id = seen_entities[ent_key] = new_node(_entity_label(item))
            return node_id

        stmt_key = dumps_statement(item)
        node_id = seen_statements.get(stmt_key)
        if node_id is None:
            node_id = new_node(f"f:{item.functor}")
            # Register before descending so the statement owns its id first.
            seen_statements[stmt_key] = node_id
            for position, argument in enumerate(item.args):
                child_id = intern(argument)
                out_edges[node_id].append((position, child_id))
                in_edges[child_id].append(node_id)
        return node_id

    for top_level in case.statements:
        intern(top_level)

    # --- WL refinement -----------------------------------------------------
    digest_cache: dict[str, str] = {}
    current = list(base_labels)
    histogram: Counter[str] = Counter()
    histogram.update(f"wl0:{label}" for label in current)

    for round_no in range(1, iterations + 1):
        refined = []
        for node_id, own_label in enumerate(current):
            child_sig = sorted(
                f"c{position}:{current[child_id]}"
                for position, child_id in out_edges[node_id]
            )
            parent_sig = sorted(f"p:{current[pid]}" for pid in in_edges[node_id])
            payload = "|".join((own_label, ",".join(child_sig), ",".join(parent_sig)))
            refined.append(_hash_label(payload, digest_cache))
        current = refined
        histogram.update(f"wl{round_no}:{label}" for label in current)

    return histogram
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class WLKernelRetriever:
    """Retrieve cases by cosine similarity of their WL label histograms."""

    def __init__(self, iterations: int = WL_ITERATIONS):
        """Create an empty index that refines labels for ``iterations`` rounds."""
        self.iterations = iterations
        self.doc_ids: list[str] = []
        self.feature_index: dict[str, int] = {}
        self.doc_matrix: sparse.csr_matrix | None = None  # rows L2-normalized

    def build(self, cases: list[Case]) -> None:
        """Index ``cases``: one L2-normalized histogram row per case.

        Replaces any previous index. Feature columns are assigned in
        first-seen order across the case histograms.
        """
        histograms = [wl_histogram(case, self.iterations) for case in cases]
        self.doc_ids = [case.case_id for case in cases]

        vocab: dict[str, int] = {}
        for hist in histograms:
            for feat in hist:
                vocab.setdefault(feat, len(vocab))
        self.feature_index = vocab

        row_ids: list[int] = []
        col_ids: list[int] = []
        values: list[float] = []
        for row, hist in enumerate(histograms):
            row_ids.extend([row] * len(hist))
            col_ids.extend(vocab[feat] for feat in hist)
            values.extend(float(count) for count in hist.values())

        # Keep at least one column so the matrix shape stays valid when the
        # vocabulary is empty.
        counts = sparse.csr_matrix(
            (values, (row_ids, col_ids)),
            shape=(len(histograms), max(len(vocab), 1)),
        )
        row_norms = np.asarray(np.sqrt(counts.multiply(counts).sum(axis=1))).ravel()
        # All-zero rows keep a norm of 1 so the division below is safe.
        row_norms[row_norms == 0] = 1.0
        self.doc_matrix = sparse.diags(1.0 / row_norms) @ counts

    def retrieve(self, query_case: Case, k: int = 10) -> list[tuple[str, float]]:
        """Return the top-``k`` ``(case_id, cosine)`` pairs for ``query_case``.

        Results are ordered by descending score, ties broken by ascending
        ``case_id``. Returns ``[]`` if ``build`` has not been called.
        """
        if self.doc_matrix is None:
            return []

        hist = wl_histogram(query_case, self.iterations)
        # The norm is taken over the FULL query histogram, including features
        # the index has never seen, so the cosine is not inflated.
        q_norm = float(np.sqrt(sum(v * v for v in hist.values()))) or 1.0

        vocab = self.feature_index
        query_vec = np.zeros(self.doc_matrix.shape[1])
        for feat, count in hist.items():
            if feat in vocab:
                query_vec[vocab[feat]] = count / q_norm

        similarities = self.doc_matrix @ query_vec
        ranked = sorted(
            zip(self.doc_ids, map(float, similarities)),
            key=lambda pair: (-pair[1], pair[0]),
        )
        return ranked[:k]
|