structuremappingmemory 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sma/__init__.py +5 -0
- sma/__main__.py +5 -0
- sma/agent/__init__.py +5 -0
- sma/agent/adapter_draft.py +217 -0
- sma/agent/api.py +67 -0
- sma/agent/comparison.py +591 -0
- sma/agent/llm.py +280 -0
- sma/agent/policies.py +21 -0
- sma/agent/service.py +95 -0
- sma/cli.py +65 -0
- sma/encoders/__init__.py +38 -0
- sma/encoders/agentobs.py +27 -0
- sma/encoders/base.py +23 -0
- sma/encoders/code_treesitter.py +64 -0
- sma/encoders/coverage.py +80 -0
- sma/encoders/draft_adapter.py +183 -0
- sma/encoders/healthcare.py +207 -0
- sma/encoders/logs_drain.py +142 -0
- sma/encoders/prose_tier1.py +57 -0
- sma/encoders/structured.py +57 -0
- sma/encoders/traces.py +45 -0
- sma/eval/__init__.py +2 -0
- sma/eval/agentic/__init__.py +35 -0
- sma/eval/agentic/arms/__init__.py +0 -0
- sma/eval/agentic/arms/cyber.py +48 -0
- sma/eval/agentic/arms/discovery.py +35 -0
- sma/eval/agentic/arms/finance.py +38 -0
- sma/eval/agentic/arms/legal.py +74 -0
- sma/eval/agentic/arms/medicine.py +45 -0
- sma/eval/agentic/harness.py +275 -0
- sma/eval/agentic/memories.py +308 -0
- sma/eval/agentic/metrics.py +82 -0
- sma/eval/agentic_qa/__init__.py +27 -0
- sma/eval/agentic_qa/agent.py +383 -0
- sma/eval/agentic_qa/metrics.py +239 -0
- sma/eval/agentic_qa/pools.py +197 -0
- sma/eval/arn.py +65 -0
- sma/eval/baselines/__init__.py +6 -0
- sma/eval/baselines/bge_dense.py +54 -0
- sma/eval/baselines/bm25.py +18 -0
- sma/eval/baselines/dense.py +42 -0
- sma/eval/baselines/hipporag.py +235 -0
- sma/eval/baselines/hybrid_rrf.py +30 -0
- sma/eval/baselines/longcontext_llm.py +124 -0
- sma/eval/baselines/rerank.py +41 -0
- sma/eval/baselines/splade.py +77 -0
- sma/eval/baselines/wl_kernel.py +163 -0
- sma/eval/bugsinpy.py +358 -0
- sma/eval/bugsinpy_families.py +164 -0
- sma/eval/crossdomain.py +89 -0
- sma/eval/diabetes.py +61 -0
- sma/eval/drift_env.py +26 -0
- sma/eval/drift_metrics.py +24 -0
- sma/eval/family_labels.py +167 -0
- sma/eval/fraud_elliptic/__init__.py +29 -0
- sma/eval/fraud_elliptic/encoder.py +279 -0
- sma/eval/fraud_elliptic/eval.py +269 -0
- sma/eval/fraud_elliptic/test_encoder.py +123 -0
- sma/eval/ieee_cis.py +66 -0
- sma/eval/loghub.py +16 -0
- sma/eval/loghub_eval.py +480 -0
- sma/eval/longmemeval.py +51 -0
- sma/eval/memory_backends/__init__.py +2 -0
- sma/eval/memory_backends/base.py +22 -0
- sma/eval/memory_backends/context_only.py +14 -0
- sma/eval/memory_backends/rag_notes.py +17 -0
- sma/eval/memory_backends/shared_llm.py +30 -0
- sma/eval/memory_backends/sma_memory.py +54 -0
- sma/eval/memory_backends/zep_graphiti.py +33 -0
- sma/eval/metrics.py +32 -0
- sma/eval/ontology_bench.py +219 -0
- sma/eval/report.py +573 -0
- sma/eval/ssb_eval.py +216 -0
- sma/eval/ssb_generator.py +116 -0
- sma/eval/stats.py +108 -0
- sma/eval/transfer_eval.py +844 -0
- sma/index/__init__.py +15 -0
- sma/index/ann.py +21 -0
- sma/index/content_vectors.py +60 -0
- sma/index/inverted.py +63 -0
- sma/index/macfac.py +174 -0
- sma/ir/__init__.py +22 -0
- sma/ir/canon.py +106 -0
- sma/ir/schema.py +165 -0
- sma/ir/sexpr.py +86 -0
- sma/ir/signatures.py +76 -0
- sma/match/__init__.py +20 -0
- sma/match/conflicts.py +46 -0
- sma/match/engine.py +60 -0
- sma/match/explain.py +59 -0
- sma/match/infer.py +54 -0
- sma/match/kernels.py +54 -0
- sma/match/mdl.py +30 -0
- sma/match/merge_cpsat.py +77 -0
- sma/match/merge_greedy.py +15 -0
- sma/match/mh.py +177 -0
- sma/match/ses.py +84 -0
- sma/match/types.py +115 -0
- sma/match/verifier.py +27 -0
- sma/ontology/__init__.py +45 -0
- sma/ontology/attack.py +134 -0
- sma/ontology/cpc.py +69 -0
- sma/ontology/graph.py +58 -0
- sma/ontology/loader.py +262 -0
- sma/ontology/mitre_xml.py +67 -0
- sma/ontology/mount.py +101 -0
- sma/ontology/rdf_loader.py +75 -0
- sma/ontology/registry.py +115 -0
- sma/ontology/router.py +69 -0
- sma/ontology/usgaap.py +73 -0
- sma/sage/__init__.py +6 -0
- sma/sage/assimilate.py +12 -0
- sma/sage/pools.py +105 -0
- sma/sage/probabilities.py +10 -0
- sma/store/__init__.py +6 -0
- sma/store/lmdb_store.py +78 -0
- sma/store/registry.py +26 -0
- sma/store/wal.py +26 -0
- sma/ui/app.py +642 -0
- structuremappingmemory-1.0.0.dist-info/METADATA +190 -0
- structuremappingmemory-1.0.0.dist-info/RECORD +125 -0
- structuremappingmemory-1.0.0.dist-info/WHEEL +5 -0
- structuremappingmemory-1.0.0.dist-info/entry_points.txt +2 -0
- structuremappingmemory-1.0.0.dist-info/licenses/LICENSE +204 -0
- structuremappingmemory-1.0.0.dist-info/top_level.txt +1 -0
sma/eval/loghub_eval.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
1
|
+
"""LogHub MVP Diagnostic Evaluation script.
|
|
2
|
+
|
|
3
|
+
Performs stratified sampling on HDFS and BGL datasets, sessionizes logs,
|
|
4
|
+
indexes cases, runs retrieval via SMA, BM25, Dense RAG, and KG-PPR Proxy,
|
|
5
|
+
and outputs performance and latency metrics.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import csv
|
|
11
|
+
import pathlib
|
|
12
|
+
import random
|
|
13
|
+
import time
|
|
14
|
+
import zipfile
|
|
15
|
+
from collections import defaultdict, Counter
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
from sklearn.metrics import f1_score
|
|
19
|
+
|
|
20
|
+
from sma.encoders import get_encoder
|
|
21
|
+
from sma.index.macfac import MacFacIndex
|
|
22
|
+
from sma.match.types import MatchConfig
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def load_hdfs_blocks(hdfs_zip_path: pathlib.Path) -> dict[str, str]:
    """Load block labels from the HDFS anomaly label CSV.

    Reads ``preprocessed/anomaly_label.csv`` from the HDFS zip archive,
    skipping the header row. Rows with fewer than two columns are ignored;
    a block id that appears twice keeps its last label.

    Args:
        hdfs_zip_path: Path to the LogHub ``HDFS_v1.zip`` archive.

    Returns:
        Mapping of block id (first column) to label string (second column).
        An empty CSV yields an empty dict instead of raising ``StopIteration``.
    """
    import io

    labels: dict[str, str] = {}
    with zipfile.ZipFile(hdfs_zip_path, "r") as z:
        with z.open("preprocessed/anomaly_label.csv") as raw, \
                io.TextIOWrapper(raw, encoding="utf-8", newline="") as text:
            # Stream-decode instead of materialising the whole file as one string.
            reader = csv.reader(text)
            next(reader, None)  # Skip header (tolerates an empty file)
            for row in reader:
                if len(row) >= 2:
                    labels[row[0]] = row[1]
    return labels
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def extract_hdfs_sessions(
    hdfs_zip_path: pathlib.Path, sampled_blocks: set[str]
) -> dict[str, list[str]]:
    """Collect the raw HDFS.log lines belonging to each sampled block.

    The log is streamed line by line out of the zip archive. Each line is
    attributed to the first ``blk_<id>`` token it contains and is kept only
    when that block id is in ``sampled_blocks``. Lines keep their line ending.

    Returns:
        A ``defaultdict(list)`` mapping block id to its lines in file order.
    """
    import re

    find_block = re.compile(r"blk_-?\d+").search
    grouped = defaultdict(list)
    with zipfile.ZipFile(hdfs_zip_path, "r") as archive, archive.open("HDFS.log") as stream:
        for raw in stream:
            text = raw.decode("utf-8", errors="ignore")
            hit = find_block(text)
            if hit is None:
                continue
            block_id = hit.group(0)
            if block_id in sampled_blocks:
                grouped[block_id].append(text)
    return grouped
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_hdfs_first_timestamps(hdfs_zip_path: pathlib.Path) -> dict[str, float]:
    """Scan HDFS.log to find the first occurrence time (in seconds) of each block.

    The entire log is scanned (for unbiased temporal stratification). A block's
    time comes from the first line that mentions it: that line's leading
    ``YYMMDD HHMMSS`` stamp is parsed as UTC and returned as a POSIX timestamp.
    If the first line has no parseable stamp, the block gets 0.0.

    The stamp is interpreted as UTC rather than through a naive datetime's
    local-time ``timestamp()``. This makes the values, and therefore the
    downstream sample, identical on every machine and unaffected by DST
    gaps and overlaps.

    Args:
        hdfs_zip_path: Path to the LogHub ``HDFS_v1.zip`` archive.

    Returns:
        Mapping of block id to first-seen POSIX seconds (0.0 if unparseable).
    """
    import re
    from datetime import datetime, timezone

    block_re = re.compile(r"blk_-?\d+")
    block_times: dict[str, float] = {}

    with zipfile.ZipFile(hdfs_zip_path, "r") as z:
        with z.open("HDFS.log") as fh:
            for line_bytes in fh:
                line = line_bytes.decode("utf-8", errors="ignore")
                match = block_re.search(line)
                if not match:
                    continue
                bid = match.group(0)
                if bid in block_times:
                    continue  # only the first occurrence matters
                ts_str = line[:13]
                seconds = 0.0
                # Strict shape check first: strptime alone accepts e.g. space-padded days.
                if (len(ts_str) == 13 and ts_str[6] == " "
                        and ts_str[:6].isdigit() and ts_str[7:].isdigit()):
                    try:
                        dt = datetime.strptime(ts_str, "%y%m%d %H%M%S")
                    except ValueError:
                        pass
                    else:
                        seconds = dt.replace(tzinfo=timezone.utc).timestamp()
                block_times[bid] = seconds
    return block_times
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def sample_hdfs_stratified(
    hdfs_zip_path: pathlib.Path, sample_size: int = 1000, seed: int = 42
) -> list[tuple[str, str, str]]:
    """Draw a label-balanced, time-stratified sample of HDFS block sessions.

    Up to ``sample_size // 2`` blocks are taken from each of the ``"Anomaly"``
    and ``"Normal"`` classes. Within a class, blocks are ordered by first-seen
    timestamp, cut into 5 contiguous bins, and sampled evenly across the bins.
    Any shortfall from rounding is topped up from the unchosen blocks.

    Returns:
        ``(block_id, session_text, label)`` triples, anomalies first. Blocks
        with no extracted log lines are dropped.
    """
    labels = load_hdfs_blocks(hdfs_zip_path)
    first_seen = get_hdfs_first_timestamps(hdfs_zip_path)
    rng = random.Random(seed)

    # Only blocks that actually appear in HDFS.log can be placed in time.
    timed = [blk for blk in labels if blk in first_seen]
    anomalies = [blk for blk in timed if labels[blk] == "Anomaly"]
    normals = [blk for blk in timed if labels[blk] == "Normal"]

    def pick(pool, quota):
        ordered = sorted(pool, key=first_seen.__getitem__)
        if len(ordered) <= quota:
            return ordered
        per_bin = quota // 5
        chosen = [
            blk
            for chunk in np.array_split(ordered, 5)
            for blk in rng.sample(list(chunk), min(len(chunk), per_bin))
        ]
        # Top up the rounding remainder from blocks not yet chosen.
        while len(chosen) < quota:
            leftover = sorted(set(ordered) - set(chosen))
            if not leftover:
                break
            chosen.append(rng.choice(leftover))
        return chosen

    picked_anom = pick(anomalies, sample_size // 2)
    picked_norm = pick(normals, sample_size // 2)
    ordered_ids = picked_anom + picked_norm

    # Extract log texts
    lines_by_block = extract_hdfs_sessions(hdfs_zip_path, set(ordered_ids))
    return [
        (blk, "".join(lines_by_block[blk]), labels[blk])
        for blk in ordered_ids
        if lines_by_block.get(blk)
    ]
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _iter_bgl_lines(bgl_zip_path: pathlib.Path):
    """Yield the decoded lines of ``BGL.log`` from the zip archive, one at a time."""
    with zipfile.ZipFile(bgl_zip_path, "r") as z:
        with z.open("BGL.log") as fh:
            for line_bytes in fh:
                yield line_bytes.decode("utf-8", errors="ignore")


def _parse_bgl_line(line: str):
    """Return ``(alert_label, timestamp, session_key)`` for a BGL line, or None.

    A line is skipped (None) when it has fewer than 5 whitespace-separated
    fields or its second field is not an integer. Sessions group lines into
    60-second windows per node ID (4th field), keyed ``bgl_<node>_<window>``.
    Both sampling passes use this one parser, so their session keys always agree.
    """
    parts = line.split(maxsplit=5)
    if len(parts) < 5:
        return None
    try:
        timestamp = int(parts[1])
    except ValueError:
        return None
    # Group BGL into 60-second windows per Node ID as per blueprint
    return parts[0], timestamp, f"bgl_{parts[3]}_{timestamp // 60}"


def sample_bgl_stratified(
    bgl_zip_path: pathlib.Path, sample_size: int = 1000, seed: int = 42
) -> list[tuple[str, str, str]]:
    """Sessionize and sample stratified BGL logs using two passes to save memory.

    Pass 1 streams the log and records, per session, the line count, whether
    any line carries a non-``"-"`` alert label, and its first timestamp.
    Sessions with fewer than 3 lines are discarded. Up to ``sample_size // 2``
    sessions per class are then drawn across 5 time-ordered bins. Pass 2
    re-streams the log and keeps only the lines of sampled sessions, without
    their alert column.

    Returns:
        ``(session_key, session_text, "Anomaly" | "Normal")`` triples,
        anomalies first.
    """
    # Pass 1: Gather metadata for sessionization and labels
    session_counts = Counter()
    labels = defaultdict(bool)
    timestamps = {}
    for line in _iter_bgl_lines(bgl_zip_path):
        parsed = _parse_bgl_line(line)
        if parsed is None:
            continue
        label, timestamp, session_key = parsed
        session_counts[session_key] += 1
        if label != "-":
            labels[session_key] = True
        timestamps.setdefault(session_key, timestamp)

    # Filter sessions with length >= 3 to avoid tiny cases
    filtered_keys = [k for k, count in session_counts.items() if count >= 3]
    anom_keys = [k for k in filtered_keys if labels[k]]
    norm_keys = [k for k in filtered_keys if not labels[k]]

    rng = random.Random(seed)

    def get_stratified_subset(keys, target_n):
        sorted_keys = sorted(keys, key=lambda k: timestamps[k])
        if len(sorted_keys) <= target_n:
            return sorted_keys
        bins = np.array_split(sorted_keys, 5)
        subset = []
        per_bin = target_n // 5
        for b in bins:
            subset.extend(rng.sample(list(b), min(len(b), per_bin)))
        while len(subset) < target_n and sorted_keys:
            rem = sorted(set(sorted_keys) - set(subset))
            if not rem:
                break
            subset.append(rng.choice(rem))
        return subset

    sampled_anom = get_stratified_subset(anom_keys, sample_size // 2)
    sampled_norm = get_stratified_subset(norm_keys, sample_size // 2)
    sampled_set = set(sampled_anom + sampled_norm)

    # Pass 2: Extract actual lines for the sampled set
    sessions_lines = defaultdict(list)
    for line in _iter_bgl_lines(bgl_zip_path):
        parsed = _parse_bgl_line(line)
        if parsed is None or parsed[2] not in sampled_set:
            continue
        # Drop the leading alert-category column: it is the ground-truth
        # label, not log content. Keeping it leaks labels to every
        # retriever (BGL '-' = normal, anything else = anomaly).
        sessions_lines[parsed[2]].append(line.partition(" ")[2] or line)

    results = []
    for k in sampled_anom + sampled_norm:
        lines = sessions_lines.get(k, [])
        if lines:
            results.append((k, "".join(lines), "Anomaly" if labels[k] else "Normal"))
    return results
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _label_hit_rate(ret_ids: list[str], relevant_ids: set[str], k: int) -> float:
    """Return the share of the top-``k`` retrieved ids that carry the query's label.

    The denominator is ``min(len(relevant_ids), k)``, so a query whose label
    has fewer than ``k`` index cases can still score 1.0. Returns 0.0 when no
    index case carries the query's label.
    """
    hits = len(set(ret_ids[:k]).intersection(relevant_ids))
    denom = min(len(relevant_ids), k)
    return hits / denom if denom > 0 else 0.0


def run_evaluation(
    dataset_name: str,
    sampled_data: list[tuple[str, str, str]],
    output_manifest_rows: list[dict],
    scorer: str = "ses",
) -> list[dict]:
    """Execute four-way evaluation comparison on the sampled dataset.

    1. Every sampled session is appended to ``output_manifest_rows``.
    2. A shuffled copy of the sessions is split 80/20 (fixed seed 101) into
       index and query sets.
    3. Each query is answered by SMA, BM25, Dense RAG and a KG-PPR proxy.
    4. For each method, the top-5 retrieved index cases cast score-weighted
       votes for the query's label.

    Args:
        dataset_name: Used in log output and the ``split`` column.
        sampled_data: ``(session_id, text, label)`` triples. Labels must be
            ``"Anomaly"`` or ``"Normal"`` (the vote tallies only those keys).
            The list itself is not modified.
        output_manifest_rows: Accumulator; one dict per session is appended.
        scorer: Scorer name passed to the SMA ``MatchConfig``.

    Returns:
        One row per method with macro-F1, label hit rate @1/5/10 and p50/p95
        latency in ms. A ``DIAGNOSTIC`` alert row follows any method whose run
        looks collapsed or suspiciously perfect. The result is ``[]`` when the
        split leaves no index case or no query case.
    """
    print(f"\n--- Running evaluation on {dataset_name} ({len(sampled_data)} cases) ---")

    # Save to manifest list (before the split, so every sampled session is recorded)
    for sid, _, label in sampled_data:
        output_manifest_rows.append({
            "dataset": dataset_name,
            "session_id": sid,
            "label": label
        })

    # Split into 80% Index / 20% Query. Shuffle a copy so the caller's list is
    # not reordered as a side effect; the fixed seed keeps the split reproducible.
    shuffled = list(sampled_data)
    random.Random(101).shuffle(shuffled)
    split_idx = int(len(shuffled) * 0.8)
    index_data = shuffled[:split_idx]
    query_data = shuffled[split_idx:]

    if not index_data or not query_data:
        # BM25 cannot be built over an empty corpus, and the averaged metrics
        # below would divide by zero without any queries.
        print(f"Skipping {dataset_name}: need at least one index case and one query case.")
        return []

    # Parse and encode cases
    log_encoder = get_encoder("logs")

    print("Encoding index cases...")
    index_cases = []
    index_docs = []  # List of (case_id, text)
    index_labels = {}
    for sid, text, label in index_data:
        case = log_encoder.encode(text, session_id=sid).case
        index_cases.append(case)
        index_docs.append((case.case_id, text))
        index_labels[case.case_id] = label
    # Ids aligned with index_docs order (BM25 scores / dense embeddings rows).
    doc_ids = [doc_id for doc_id, _ in index_docs]

    print("Encoding query cases...")
    query_cases = []
    query_docs = []
    query_labels = {}
    for sid, text, label in query_data:
        case = log_encoder.encode(text, session_id=sid).case
        query_cases.append(case)
        query_docs.append((case.case_id, text))
        query_labels[case.case_id] = label

    # Build indexes ONCE before the query loop
    # 1. Build SMA MAC/FAC index
    print(f"Building SMA Index (scorer={scorer})...")
    sma_index = MacFacIndex(config=MatchConfig(scorer=scorer))
    sma_index.build(index_cases)

    # 2. Build BM25 Index
    print("Building BM25 Index...")
    from rank_bm25 import BM25Okapi
    tokenized_index = [text.lower().split() for _, text in index_docs]
    bm25_index = BM25Okapi(tokenized_index)

    # 3. Build Dense RAG Index (SentenceTransformers)
    print("Building Dense RAG Index (SentenceTransformers)...")
    from sentence_transformers import SentenceTransformer, util
    dense_model = SentenceTransformer('all-MiniLM-L6-v2')
    index_texts = [text for _, text in index_docs]
    index_embeddings = dense_model.encode(index_texts, convert_to_tensor=True, show_progress_bar=False)

    # 4. Build KG-PPR Proxy index: a multiset of entity names per index case
    print("Building KG-PPR Proxy Index...")
    index_entity_counters = {
        ic.case_id: Counter(e.name for e in ic.entities())
        for ic in index_cases
    }

    # Per-query ranked retrieval for each method, as (case_id, score) pairs.
    def retrieve_sma(q_case, q_text):
        # shortlist=40, fac_budget=20 keeps CPU latency bounded
        results = sma_index.retrieve(q_case, k=10, shortlist=40, fac_budget=20)
        return [(r.case_id, r.ses_n) for r in results]

    def retrieve_bm25(q_case, q_text):
        scores = bm25_index.get_scores(q_text.lower().split())
        ranked = sorted(zip(doc_ids, scores), key=lambda row: (-row[1], row[0]))
        return ranked[:10]

    def retrieve_dense(q_case, q_text):
        query_embedding = dense_model.encode(q_text, convert_to_tensor=True, show_progress_bar=False)
        scores = util.cos_sim(query_embedding, index_embeddings)[0].cpu().tolist()
        ranked = sorted(zip(doc_ids, scores), key=lambda row: (-row[1], row[0]))
        return ranked[:10]

    def retrieve_kg(q_case, q_text):
        # Score = size of the multiset intersection of entity names.
        q_counter = Counter(e.name for e in q_case.entities())
        ranked = sorted(
            (
                (ic_id, float(sum(min(v, counts.get(k, 0)) for k, v in q_counter.items())))
                for ic_id, counts in index_entity_counters.items()
            ),
            key=lambda row: (-row[1], row[0]),
        )
        return ranked[:10]

    def weighted_vote(ranked, top=5):
        # Score-weighted vote of the top hits; defaults to "Normal" when the
        # total vote mass is not positive.
        voting = {"Anomaly": 0.0, "Normal": 0.0}
        for case_id, score in ranked[:top]:
            voting[index_labels[case_id]] += score
        return max(voting, key=voting.get) if sum(voting.values()) > 0 else "Normal"

    retrievers = {
        "SMA": retrieve_sma,
        "BM25": retrieve_bm25,
        "Dense RAG": retrieve_dense,
        "KG-PPR Proxy": retrieve_kg,
    }
    methods = list(retrievers)
    metrics_by_method = {m: {"recalls": [], "preds": [], "latencies": []} for m in methods}

    print("Starting retrieval runs...")
    total_queries = len(query_cases)
    for idx, (q_case, (_, q_text)) in enumerate(zip(query_cases, query_docs), start=1):
        for method, retriever in retrievers.items():
            t0 = time.perf_counter()
            ranked = retriever(q_case, q_text)
            elapsed_ms = (time.perf_counter() - t0) * 1000
            data = metrics_by_method[method]
            data["recalls"].append([case_id for case_id, _ in ranked])
            data["latencies"].append(elapsed_ms)
            data["preds"].append(weighted_vote(ranked))

        if idx % 20 == 0 or idx == total_queries:
            print(f"Processed {idx}/{total_queries} retrieval runs...")

    # Relevant index ids per label, computed once instead of rescanning the
    # whole index for every query of every method.
    relevant_by_label = defaultdict(set)
    for ic in index_cases:
        relevant_by_label[index_labels[ic.case_id]].add(ic.case_id)

    # Calculate final metrics
    triage_rows = []
    true_labels = [query_labels[c.case_id] for c in query_cases]

    for m in methods:
        data = metrics_by_method[m]
        preds = data["preds"]
        recalls = data["recalls"]
        latencies = data["latencies"]

        # F1 Score
        f1 = f1_score(true_labels, preds, average="macro")

        # label_hit_rate @ 1, 5, 10
        r1_list, r5_list, r10_list = [], [], []
        for q_label, ret_ids in zip(true_labels, recalls):
            relevant_ids = relevant_by_label.get(q_label, set())
            r1_list.append(_label_hit_rate(ret_ids, relevant_ids, 1))
            r5_list.append(_label_hit_rate(ret_ids, relevant_ids, 5))
            r10_list.append(_label_hit_rate(ret_ids, relevant_ids, 10))

        r1 = sum(r1_list) / len(r1_list)
        r5 = sum(r5_list) / len(r5_list)
        r10 = sum(r10_list) / len(r10_list)

        # Latency p50, p95
        p50 = np.percentile(latencies, 50)
        p95 = np.percentile(latencies, 95)

        triage_rows.append({
            "dataset": "LogHub",
            "split": f"{dataset_name}_MVP_diagnostic",
            "method": m,
            "macro_f1": f"{f1:.4f}",
            "label_hit_rate@1": f"{r1:.4f}",
            "label_hit_rate@5": f"{r5:.4f}",
            "label_hit_rate@10": f"{r10:.4f}",
            "p50_ms": f"{p50:.3f}",
            "p95_ms": f"{p95:.3f}"
        })

        # Print results
        print(f"Method: {m}")
        print(f"  Macro-F1: {f1:.4f}")
        print(f"  label_hit_rate@1: {r1:.4f}, label_hit_rate@5: {r5:.4f}, label_hit_rate@10: {r10:.4f}")
        print(f"  p50 Latency: {p50:.3f} ms, p95 Latency: {p95:.3f} ms")

        # Diagnostic alerts for collapsed or suspiciously perfect runs
        unique_preds = set(preds)
        is_suspicious = (f1 == 0.0 or f1 == 1.0 or len(unique_preds) <= 1)
        if is_suspicious:
            reason = ""
            if f1 == 0.0:
                reason = "F1 is 0.0: Retrieval collapse or dataset imbalance"
            elif f1 == 1.0:
                reason = "F1 is 1.0: Suspiciously perfect classification - potential data leakage or indexing overlap"
            elif len(unique_preds) <= 1:
                reason = f"Retrieval collapse: predicted only '{list(unique_preds)[0]}' sessions"

            triage_rows.append({
                "dataset": "DIAGNOSTIC",
                "split": f"{dataset_name}_MVP_diagnostic",
                "method": f"{m}_alert",
                "macro_f1": reason,
                "label_hit_rate@1": "ALERT",
                "label_hit_rate@5": "ALERT",
                "label_hit_rate@10": "ALERT",
                "p50_ms": "0.000",
                "p95_ms": "0.000"
            })
            print(f"  [DIAGNOSTIC ALERT] {reason}")

    return triage_rows
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def run_loghub_eval(
    scorer: str = "ses",
    data_dir: str | pathlib.Path = "data/raw/loghub_raw",
    manifest_path: str | pathlib.Path = "reports/loghub_sample_manifest.csv",
    sample_size: int = 1000,
    seed: int = 42,
) -> list[dict]:
    """Execute both HDFS and BGL evaluations and write manifests.

    Args:
        scorer: SMA scorer forwarded to ``run_evaluation``.
        data_dir: Directory that must contain ``HDFS_v1.zip`` and ``BGL.zip``.
        manifest_path: CSV that receives every sampled session
            (``dataset, session_id, label``). Parent directories are created.
        sample_size: Sessions sampled per dataset (split evenly by label).
        seed: Seed for the global ``random`` module and both samplers.

    Returns:
        HDFS metric rows followed by BGL metric rows. Returns ``[]`` and writes
        no manifest when either archive is missing.
    """
    random.seed(seed)

    data_dir = pathlib.Path(data_dir)
    hdfs_zip = data_dir / "HDFS_v1.zip"
    bgl_zip = data_dir / "BGL.zip"

    if not hdfs_zip.exists() or not bgl_zip.exists():
        print("Missing log datasets. Run fetch_datasets.py first.")
        return []

    manifest_rows = []

    # 1. HDFS stratified sampling & evaluation
    print("Sampling HDFS sessions...")
    hdfs_sampled = sample_hdfs_stratified(hdfs_zip, sample_size=sample_size, seed=seed)
    hdfs_rows = run_evaluation("HDFS", hdfs_sampled, manifest_rows, scorer=scorer)

    # 2. BGL stratified sampling & evaluation
    print("Sampling BGL sessions...")
    bgl_sampled = sample_bgl_stratified(bgl_zip, sample_size=sample_size, seed=seed)
    bgl_rows = run_evaluation("BGL", bgl_sampled, manifest_rows, scorer=scorer)

    # Save the sampled manifest for reproducibility
    manifest_path = pathlib.Path(manifest_path)
    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    with manifest_path.open("w", encoding="utf-8", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=["dataset", "session_id", "label"])
        writer.writeheader()
        writer.writerows(manifest_rows)
    print(f"Saved manifest to {manifest_path}")

    return hdfs_rows + bgl_rows
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
if __name__ == "__main__":
|
|
480
|
+
run_loghub_eval()
|
sma/eval/longmemeval.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""LongMemEval loader + answer grader (real agent-memory drift benchmark)."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
import json, pathlib
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
# Question types that LMEInstance.is_drift reports as drift probes.
DRIFT_CATEGORIES = {"temporal-reasoning", "knowledge-update"}


@dataclass(frozen=True)
class Session:
    """One haystack chat session of a LongMemEval instance (immutable)."""

    session_id: str
    date: str  # date string exactly as stored in the dataset
    turns: list[dict]  # turn dicts; the memory backends read their "content" key


@dataclass(frozen=True)
class LMEInstance:
    """A LongMemEval question bundled with its haystack sessions (immutable)."""

    question_id: str
    category: str  # the dataset's "question_type" field
    question: str
    answer: str
    question_date: str
    sessions: tuple[Session, ...]
    answer_session_ids: tuple[str, ...]

    @property
    def is_drift(self) -> bool:
        """True when this question's category is one of DRIFT_CATEGORIES."""
        drift_kinds = DRIFT_CATEGORIES
        return self.category in drift_kinds
|
|
27
|
+
|
|
28
|
+
def load_instances(path: str | pathlib.Path) -> list[LMEInstance]:
    """Parse a LongMemEval JSON file into ``LMEInstance`` objects.

    The file holds a JSON array of records. In each record, the parallel
    ``haystack_session_ids`` / ``haystack_dates`` / ``haystack_sessions``
    lists are zipped into ``Session`` objects; ``zip`` stops at the shortest
    list. ``question_date`` and ``answer_session_ids`` are optional (default
    ``""`` and ``()``), and ``answer`` is coerced to ``str``.
    """
    records = json.loads(pathlib.Path(path).read_text(encoding="utf-8"))

    def to_instance(rec: dict) -> LMEInstance:
        haystack = zip(
            rec["haystack_session_ids"], rec["haystack_dates"], rec["haystack_sessions"])
        return LMEInstance(
            question_id=rec["question_id"],
            category=rec["question_type"],
            question=rec["question"],
            answer=str(rec["answer"]),
            question_date=rec.get("question_date", ""),
            sessions=tuple(Session(sid, day, turns) for sid, day, turns in haystack),
            answer_session_ids=tuple(rec.get("answer_session_ids", [])),
        )

    return [to_instance(rec) for rec in records]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def grade_answer(prediction: str, gold: str) -> float:
    """Score *prediction* against *gold* by normalized substring containment.

    Both strings are lower-cased and their whitespace runs collapsed to single
    spaces. The score is 1.0 when the normalized gold answer is non-empty and
    occurs inside the normalized prediction, otherwise 0.0.

    (The official LongMemEval grader uses an LLM judge; this deterministic
    proxy serves unit tests and the smoke run, and the battery can swap in
    the LLM judge.)
    """
    def normalize(text: str) -> str:
        return " ".join(text.lower().split())

    haystack = normalize(prediction)
    needle = normalize(gold)
    return float(bool(needle) and needle in haystack)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Common interface for the four drift-experiment memory variants."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from sma.eval.longmemeval import Session
|
|
6
|
+
|
|
7
|
+
@dataclass
class QueryResult:
    """What a memory backend returns for one question."""

    answer: str
    # Memory items the backend surfaced for this question (fresh list per instance).
    retrieved: list[str] = field(default_factory=list)
    # True when the backend believes the queried fact changed (drift).
    drift_flagged: bool = False
12
|
+
|
|
13
|
+
class MemoryBackend(ABC):
    """Abstract memory variant for the drift experiments.

    The shared backbone (DeepSeek orchestrator + extractor) is injected by the
    harness. Concrete backends implement the three hooks below.
    """

    name: str = "base"  # short identifier; subclasses override it

    @abstractmethod
    def reset(self) -> None:
        """Clear the backend's stored state."""

    @abstractmethod
    def ingest(self, session: Session) -> None:
        """Add one chat session to memory."""

    @abstractmethod
    def query(self, question: str) -> QueryResult:
        """Answer *question* from memory."""
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from .base import MemoryBackend, QueryResult
|
|
3
|
+
from .shared_llm import answer_from
|
|
4
|
+
|
|
5
|
+
class ContextOnly(MemoryBackend):
    """Baseline that keeps every turn verbatim and gives all of them to the LLM."""

    name = "context-only"

    def __init__(self, llm):
        self.llm = llm
        self.turns: list[str] = []

    def reset(self):
        self.turns = []

    def ingest(self, session):
        stamp = session.date
        for turn in session.turns:
            # Prefix with the session date so the answerer can tell older facts from newer ones.
            self.turns.append(f"[{stamp}] {turn['content']}")

    def query(self, question):
        reply = answer_from(self.llm, question, self.turns)
        return QueryResult(answer=reply, retrieved=self.turns.copy())
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from .base import MemoryBackend, QueryResult
|
|
3
|
+
from .shared_llm import extract_facts, answer_from
|
|
4
|
+
|
|
5
|
+
class RagNotes(MemoryBackend):
    """LLM-written notes, retrieved by token overlap (a faithful simple RAG)."""

    name = "rag-notes"

    def __init__(self, llm, k: int = 5):
        self.llm = llm
        self.k = k
        self.notes: list[str] = []

    def reset(self):
        self.notes = []

    def ingest(self, session):
        for turn in session.turns:
            facts = extract_facts(self.llm, turn["content"])
            self.notes.extend(facts)

    def query(self, question):
        query_tokens = set(question.lower().split())

        def overlap(note: str) -> int:
            return len(query_tokens & set(note.lower().split()))

        # Python's sort is stable even with reverse=True: ties keep insertion order.
        best = sorted(self.notes, key=overlap, reverse=True)[: self.k]
        reply = answer_from(self.llm, question, best)
        return QueryResult(answer=reply, retrieved=best)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Extraction + answering shared by ALL variants (extraction held constant)."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
_EXTRACT_SYS = ("Extract the durable user facts from the message as a JSON "
|
|
6
|
+
"array of short strings. Only facts that could be asked about "
|
|
7
|
+
"later. No commentary.")
|
|
8
|
+
_ANSWER_SYS = ("Answer the question using ONLY the provided memory items. "
|
|
9
|
+
"If the memory contradicts itself, prefer the most recent. "
|
|
10
|
+
"Answer concisely; if unknown, say 'unknown'.")
|
|
11
|
+
|
|
12
|
+
def extract_facts(llm, message: str) -> list[str]:
|
|
13
|
+
"""Call llm to extract durable facts from a chat message; returns list of strings."""
|
|
14
|
+
out = llm.complete(
|
|
15
|
+
[{"role": "system", "content": _EXTRACT_SYS},
|
|
16
|
+
{"role": "user", "content": message}], max_tokens=300)
|
|
17
|
+
try:
|
|
18
|
+
facts = json.loads(out)
|
|
19
|
+
return [str(f) for f in facts] if isinstance(facts, list) else []
|
|
20
|
+
except (json.JSONDecodeError, TypeError):
|
|
21
|
+
return []
|
|
22
|
+
|
|
23
|
+
def answer_from(llm, question: str, retrieved: list[str]) -> str:
|
|
24
|
+
"""Answer question using only the provided retrieved memory items."""
|
|
25
|
+
mem = "\n".join(f"- {r}" for r in retrieved) or "(no memory)"
|
|
26
|
+
out = llm.complete(
|
|
27
|
+
[{"role": "system", "content": _ANSWER_SYS},
|
|
28
|
+
{"role": "user", "content": f"Memory:\n{mem}\n\nQuestion: {question}"}],
|
|
29
|
+
max_tokens=120)
|
|
30
|
+
return out.strip()
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""SMA memory: each turn's extracted facts are re-encoded into the case store
|
|
2
|
+
(re-derived from the conversation, never from prior generations); retrieval is
|
|
3
|
+
structural; SAGE flags expectation-violations as drift."""
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
from .base import MemoryBackend, QueryResult
|
|
6
|
+
from .shared_llm import extract_facts, answer_from
|
|
7
|
+
from sma.index.macfac import MacFacIndex
|
|
8
|
+
from sma.ir.schema import make_case, stmt
|
|
9
|
+
from sma.sage.pools import SagePool
|
|
10
|
+
from sma.match.types import MatchConfig
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _fact_to_case(fact: str):
    """Turn a fact string into a one-statement SMA case.

    With three or more whitespace tokens the statement is
    ``stmt(first, second, rest-joined-by-spaces)``, presumably read as
    subject / predicate / object. Shorter inputs become
    ``stmt("fact", *tokens)``, and blank input becomes ``stmt("fact", "empty")``.
    """
    tokens = fact.split()
    if len(tokens) < 3:
        args = tokens if tokens else ["empty"]
        return make_case([stmt("fact", *args)])
    head, relation, *tail = tokens
    return make_case([stmt(head, relation, " ".join(tail))])
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SmaMemory(MemoryBackend):
    """SMA memory backend: structural re-derivation per turn + SAGE drift detection."""

    name = "sma"

    def __init__(self, llm, k: int = 5):
        self.llm = llm
        self.k = k
        # Create the index, SAGE pool and text map up front. Previously they
        # only existed after reset(), so ingest()/query() on a fresh instance
        # raised AttributeError; ContextOnly and RagNotes also initialise
        # their stores in __init__.
        self.reset()

    def reset(self):
        """Drop all stored facts and start a fresh index and SAGE pool."""
        self.index = MacFacIndex(config=MatchConfig())
        self.pool = SagePool("drift", assimilation_threshold=0.2)
        self.texts: dict[str, str] = {}  # case_id -> original fact text
        self.last_violation = 0.0

    def ingest(self, session) -> None:
        """Extract facts from every turn and add each as its own SMA case.

        Each case's expectation violation is measured against the pool before
        the case is assimilated, so it reflects surprise relative to prior facts.
        """
        for t in session.turns:
            for fact in extract_facts(self.llm, t["content"]):
                case = _fact_to_case(fact)
                self.last_violation = self.pool.expectation_violation(case)
                self.index.add(case)
                self.pool.assimilate(case)
                self.texts[case.case_id] = fact

    def query(self, question: str) -> QueryResult:
        """Retrieve the top-k structurally similar facts and answer from them.

        NOTE(review): ``drift_flagged`` reflects the violation of the most
        recently ingested fact, not anything about *question*. Confirm that
        this is the intended drift signal.
        """
        qcase = _fact_to_case(question)
        results = self.index.retrieve(qcase, k=self.k, shortlist=50, fac_budget=20)
        retrieved = [self.texts.get(r.case_id, "") for r in results]
        ans = answer_from(self.llm, question, [r for r in retrieved if r])
        return QueryResult(
            answer=ans,
            retrieved=retrieved,
            drift_flagged=self.last_violation > 0.5,
        )
|