structuremappingmemory 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125) hide show
  1. sma/__init__.py +5 -0
  2. sma/__main__.py +5 -0
  3. sma/agent/__init__.py +5 -0
  4. sma/agent/adapter_draft.py +217 -0
  5. sma/agent/api.py +67 -0
  6. sma/agent/comparison.py +591 -0
  7. sma/agent/llm.py +280 -0
  8. sma/agent/policies.py +21 -0
  9. sma/agent/service.py +95 -0
  10. sma/cli.py +65 -0
  11. sma/encoders/__init__.py +38 -0
  12. sma/encoders/agentobs.py +27 -0
  13. sma/encoders/base.py +23 -0
  14. sma/encoders/code_treesitter.py +64 -0
  15. sma/encoders/coverage.py +80 -0
  16. sma/encoders/draft_adapter.py +183 -0
  17. sma/encoders/healthcare.py +207 -0
  18. sma/encoders/logs_drain.py +142 -0
  19. sma/encoders/prose_tier1.py +57 -0
  20. sma/encoders/structured.py +57 -0
  21. sma/encoders/traces.py +45 -0
  22. sma/eval/__init__.py +2 -0
  23. sma/eval/agentic/__init__.py +35 -0
  24. sma/eval/agentic/arms/__init__.py +0 -0
  25. sma/eval/agentic/arms/cyber.py +48 -0
  26. sma/eval/agentic/arms/discovery.py +35 -0
  27. sma/eval/agentic/arms/finance.py +38 -0
  28. sma/eval/agentic/arms/legal.py +74 -0
  29. sma/eval/agentic/arms/medicine.py +45 -0
  30. sma/eval/agentic/harness.py +275 -0
  31. sma/eval/agentic/memories.py +308 -0
  32. sma/eval/agentic/metrics.py +82 -0
  33. sma/eval/agentic_qa/__init__.py +27 -0
  34. sma/eval/agentic_qa/agent.py +383 -0
  35. sma/eval/agentic_qa/metrics.py +239 -0
  36. sma/eval/agentic_qa/pools.py +197 -0
  37. sma/eval/arn.py +65 -0
  38. sma/eval/baselines/__init__.py +6 -0
  39. sma/eval/baselines/bge_dense.py +54 -0
  40. sma/eval/baselines/bm25.py +18 -0
  41. sma/eval/baselines/dense.py +42 -0
  42. sma/eval/baselines/hipporag.py +235 -0
  43. sma/eval/baselines/hybrid_rrf.py +30 -0
  44. sma/eval/baselines/longcontext_llm.py +124 -0
  45. sma/eval/baselines/rerank.py +41 -0
  46. sma/eval/baselines/splade.py +77 -0
  47. sma/eval/baselines/wl_kernel.py +163 -0
  48. sma/eval/bugsinpy.py +358 -0
  49. sma/eval/bugsinpy_families.py +164 -0
  50. sma/eval/crossdomain.py +89 -0
  51. sma/eval/diabetes.py +61 -0
  52. sma/eval/drift_env.py +26 -0
  53. sma/eval/drift_metrics.py +24 -0
  54. sma/eval/family_labels.py +167 -0
  55. sma/eval/fraud_elliptic/__init__.py +29 -0
  56. sma/eval/fraud_elliptic/encoder.py +279 -0
  57. sma/eval/fraud_elliptic/eval.py +269 -0
  58. sma/eval/fraud_elliptic/test_encoder.py +123 -0
  59. sma/eval/ieee_cis.py +66 -0
  60. sma/eval/loghub.py +16 -0
  61. sma/eval/loghub_eval.py +480 -0
  62. sma/eval/longmemeval.py +51 -0
  63. sma/eval/memory_backends/__init__.py +2 -0
  64. sma/eval/memory_backends/base.py +22 -0
  65. sma/eval/memory_backends/context_only.py +14 -0
  66. sma/eval/memory_backends/rag_notes.py +17 -0
  67. sma/eval/memory_backends/shared_llm.py +30 -0
  68. sma/eval/memory_backends/sma_memory.py +54 -0
  69. sma/eval/memory_backends/zep_graphiti.py +33 -0
  70. sma/eval/metrics.py +32 -0
  71. sma/eval/ontology_bench.py +219 -0
  72. sma/eval/report.py +573 -0
  73. sma/eval/ssb_eval.py +216 -0
  74. sma/eval/ssb_generator.py +116 -0
  75. sma/eval/stats.py +108 -0
  76. sma/eval/transfer_eval.py +844 -0
  77. sma/index/__init__.py +15 -0
  78. sma/index/ann.py +21 -0
  79. sma/index/content_vectors.py +60 -0
  80. sma/index/inverted.py +63 -0
  81. sma/index/macfac.py +174 -0
  82. sma/ir/__init__.py +22 -0
  83. sma/ir/canon.py +106 -0
  84. sma/ir/schema.py +165 -0
  85. sma/ir/sexpr.py +86 -0
  86. sma/ir/signatures.py +76 -0
  87. sma/match/__init__.py +20 -0
  88. sma/match/conflicts.py +46 -0
  89. sma/match/engine.py +60 -0
  90. sma/match/explain.py +59 -0
  91. sma/match/infer.py +54 -0
  92. sma/match/kernels.py +54 -0
  93. sma/match/mdl.py +30 -0
  94. sma/match/merge_cpsat.py +77 -0
  95. sma/match/merge_greedy.py +15 -0
  96. sma/match/mh.py +177 -0
  97. sma/match/ses.py +84 -0
  98. sma/match/types.py +115 -0
  99. sma/match/verifier.py +27 -0
  100. sma/ontology/__init__.py +45 -0
  101. sma/ontology/attack.py +134 -0
  102. sma/ontology/cpc.py +69 -0
  103. sma/ontology/graph.py +58 -0
  104. sma/ontology/loader.py +262 -0
  105. sma/ontology/mitre_xml.py +67 -0
  106. sma/ontology/mount.py +101 -0
  107. sma/ontology/rdf_loader.py +75 -0
  108. sma/ontology/registry.py +115 -0
  109. sma/ontology/router.py +69 -0
  110. sma/ontology/usgaap.py +73 -0
  111. sma/sage/__init__.py +6 -0
  112. sma/sage/assimilate.py +12 -0
  113. sma/sage/pools.py +105 -0
  114. sma/sage/probabilities.py +10 -0
  115. sma/store/__init__.py +6 -0
  116. sma/store/lmdb_store.py +78 -0
  117. sma/store/registry.py +26 -0
  118. sma/store/wal.py +26 -0
  119. sma/ui/app.py +642 -0
  120. structuremappingmemory-1.0.0.dist-info/METADATA +190 -0
  121. structuremappingmemory-1.0.0.dist-info/RECORD +125 -0
  122. structuremappingmemory-1.0.0.dist-info/WHEEL +5 -0
  123. structuremappingmemory-1.0.0.dist-info/entry_points.txt +2 -0
  124. structuremappingmemory-1.0.0.dist-info/licenses/LICENSE +204 -0
  125. structuremappingmemory-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,480 @@
1
+ """LogHub MVP Diagnostic Evaluation script.
2
+
3
+ Performs stratified sampling on HDFS and BGL datasets, sessionizes logs,
4
+ indexes cases, runs retrieval via SMA, BM25, Dense RAG, and KG-PPR Proxy,
5
+ and outputs performance and latency metrics.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import csv
11
+ import pathlib
12
+ import random
13
+ import time
14
+ import zipfile
15
+ from collections import defaultdict, Counter
16
+
17
+ import numpy as np
18
+ from sklearn.metrics import f1_score
19
+
20
+ from sma.encoders import get_encoder
21
+ from sma.index.macfac import MacFacIndex
22
+ from sma.match.types import MatchConfig
23
+
24
+
25
def load_hdfs_blocks(hdfs_zip_path: pathlib.Path) -> dict[str, str]:
    """Load block labels from the HDFS anomaly label CSV.

    Reads ``preprocessed/anomaly_label.csv`` from the archive, skips its
    first (header) row and maps the first column of every remaining row to
    the second column. Rows with fewer than two columns are ignored.

    Args:
        hdfs_zip_path: Path to the HDFS zip archive.

    Returns:
        Mapping from first-column value to second-column value. Empty when
        the CSV contains no data rows (including a completely empty file).
    """
    import io

    labels: dict[str, str] = {}
    with zipfile.ZipFile(hdfs_zip_path, "r") as z:
        with z.open("preprocessed/anomaly_label.csv") as fh:
            # Decode lazily instead of materialising the whole member in
            # memory; newline="" leaves line-ending handling to the csv module.
            with io.TextIOWrapper(fh, encoding="utf-8", newline="") as text:
                reader = csv.reader(text)
                # A default avoids StopIteration on an empty file.
                next(reader, None)
                for row in reader:
                    if len(row) >= 2:
                        labels[row[0]] = row[1]
    return labels
36
+
37
+
38
def extract_hdfs_sessions(
    hdfs_zip_path: pathlib.Path, sampled_blocks: set[str]
) -> dict[str, list[str]]:
    """Collect the ``HDFS.log`` lines that belong to the requested blocks.

    The log is streamed line by line. A line is attributed to the first
    ``blk_<id>`` token found in it and kept only when that token is in
    ``sampled_blocks``.

    Args:
        hdfs_zip_path: Path to the HDFS zip archive containing ``HDFS.log``.
        sampled_blocks: Block ids whose lines should be kept.

    Returns:
        A ``defaultdict(list)`` mapping block id to its decoded lines (line
        terminators included) in file order.
    """
    import re

    find_block = re.compile(r"blk_-?\d+").search
    grouped = defaultdict(list)

    with zipfile.ZipFile(hdfs_zip_path, "r") as archive:
        with archive.open("HDFS.log") as stream:
            for raw in stream:
                text = raw.decode("utf-8", errors="ignore")
                hit = find_block(text)
                if not hit:
                    continue
                block_id = hit.group(0)
                if block_id not in sampled_blocks:
                    continue
                grouped[block_id].append(text)
    return grouped
56
+
57
+
58
def get_hdfs_first_timestamps(hdfs_zip_path: pathlib.Path) -> dict[str, float]:
    """Scan HDFS.log to find the first occurrence time (in seconds) of each block.

    Every line is attributed to the first ``blk_<id>`` token it contains.
    Only the first line seen for a block is used. Its leading 13 characters
    are parsed as ``%y%m%d %H%M%S``.

    The parsed wall-clock time is interpreted as UTC so the resulting values
    (and any ordering derived from them) do not depend on the timezone or
    DST rules of the machine running the scan.

    Args:
        hdfs_zip_path: Path to the HDFS zip archive containing ``HDFS.log``.

    Returns:
        Mapping from block id to POSIX seconds of its first line, or ``0.0``
        when that line's prefix is not a valid timestamp.
    """
    import re
    from datetime import datetime, timezone

    block_re = re.compile(r"blk_-?\d+")
    block_times: dict[str, float] = {}

    # Scan the entire HDFS log stream for unbiased temporal stratification
    with zipfile.ZipFile(hdfs_zip_path, "r") as z:
        with z.open("HDFS.log") as fh:
            for line_bytes in fh:
                line = line_bytes.decode("utf-8", errors="ignore")
                match = block_re.search(line)
                if match is None:
                    continue
                bid = match.group(0)
                if bid in block_times:
                    continue

                stamp = 0.0
                ts_str = line[:13]
                # Cheap shape check first so that strptime (and its costly
                # exception path) is only reached for plausible prefixes.
                if (
                    len(ts_str) == 13
                    and ts_str[6] == " "
                    and ts_str[:6].isdigit()
                    and ts_str[7:].isdigit()
                ):
                    try:
                        parsed = datetime.strptime(ts_str, "%y%m%d %H%M%S")
                    except ValueError:
                        pass
                    else:
                        stamp = parsed.replace(tzinfo=timezone.utc).timestamp()
                block_times[bid] = stamp
    return block_times
84
+
85
+
86
def sample_hdfs_stratified(
    hdfs_zip_path: pathlib.Path, sample_size: int = 1000, seed: int = 42
) -> list[tuple[str, str, str]]:
    """Sample stratified HDFS block sessions.

    Blocks are split by label ("Anomaly" / "Normal"); up to
    ``sample_size // 2`` blocks are drawn from each class. Within a class
    the blocks are ordered by first-seen timestamp, cut into five
    contiguous bins, and an equal share is drawn from every bin with a
    ``random.Random(seed)`` generator; any shortfall is filled from the
    blocks not yet chosen.

    Args:
        hdfs_zip_path: Path to the HDFS zip archive.
        sample_size: Total number of sessions requested across both classes.
        seed: Seed for the sampling generator.

    Returns:
        ``(block_id, session_text, label)`` tuples, anomalies first. Blocks
        for which no log lines were extracted are omitted.
    """
    labels = load_hdfs_blocks(hdfs_zip_path)
    block_times = get_hdfs_first_timestamps(hdfs_zip_path)

    # Filter to blocks that we found timestamps for
    valid_blocks = [b for b in labels if b in block_times]

    anom_blocks = [b for b in valid_blocks if labels[b] == "Anomaly"]
    norm_blocks = [b for b in valid_blocks if labels[b] == "Normal"]

    rng = random.Random(seed)
    n_bins = 5

    def get_stratified_subset(blocks, target_n):
        # Sort by timestamp
        sorted_blocks = sorted(blocks, key=lambda b: block_times[b])
        if len(sorted_blocks) <= target_n:
            return sorted_blocks

        # Contiguous bins sized like numpy.array_split: the first
        # ``extra`` bins get one additional element. Slicing the list keeps
        # ids as plain ``str`` rather than numpy scalars.
        base, extra = divmod(len(sorted_blocks), n_bins)
        per_bin = target_n // n_bins
        subset = []
        start = 0
        for bin_idx in range(n_bins):
            size = base + 1 if bin_idx < extra else base
            chunk = sorted_blocks[start:start + size]
            start += size
            subset.extend(rng.sample(chunk, min(len(chunk), per_bin)))

        # Fill remainder: build the sorted pool of unused ids once and
        # shrink it as ids are picked.
        missing = target_n - len(subset)
        if missing > 0:
            rem = sorted(set(sorted_blocks) - set(subset))
            for _ in range(min(missing, len(rem))):
                pick = rng.choice(rem)
                rem.remove(pick)
                subset.append(pick)
        return subset

    sampled_anom = get_stratified_subset(anom_blocks, sample_size // 2)
    sampled_norm = get_stratified_subset(norm_blocks, sample_size // 2)
    sampled_set = set(sampled_anom + sampled_norm)

    # Extract log texts
    block_lines = extract_hdfs_sessions(hdfs_zip_path, sampled_set)

    results = []
    for bid in sampled_anom + sampled_norm:
        lines = block_lines.get(bid, [])
        if lines:
            results.append((bid, "".join(lines), labels[bid]))
    return results
134
+
135
+
136
def sample_bgl_stratified(
    bgl_zip_path: pathlib.Path, sample_size: int = 1000, seed: int = 42
) -> list[tuple[str, str, str]]:
    """Sessionize and sample stratified BGL logs using two passes to save memory.

    Pass 1 streams ``BGL.log`` and groups lines into sessions keyed by the
    fourth whitespace-separated field and the 60-second window of the
    second field (parsed as an integer timestamp). A session is anomalous
    when any of its lines has a first field other than ``"-"``. Sessions
    with fewer than three lines are discarded.

    Up to ``sample_size // 2`` sessions are then drawn per class: sessions
    are ordered by first-seen timestamp, cut into five contiguous bins, an
    equal share is drawn from each bin with ``random.Random(seed)``, and
    any shortfall is filled from the sessions not yet chosen.

    Pass 2 streams the log again and collects the text of the sampled
    sessions only.

    Args:
        bgl_zip_path: Path to the BGL zip archive containing ``BGL.log``.
        sample_size: Total number of sessions requested across both classes.
        seed: Seed for the sampling generator.

    Returns:
        ``(session_key, session_text, label)`` tuples with label
        ``"Anomaly"`` or ``"Normal"``, anomalies first.
    """
    n_bins = 5

    def parse_line(line):
        """Return (label, timestamp, session_key) or None for unusable lines."""
        parts = line.split(maxsplit=5)
        if len(parts) < 5:
            return None
        try:
            timestamp = int(parts[1])
        except ValueError:
            return None
        # Group BGL into 60-second windows per Node ID as per blueprint
        return parts[0], timestamp, f"bgl_{parts[3]}_{timestamp // 60}"

    # Pass 1: Gather metadata for sessionization and labels
    session_counts = Counter()
    anomalous = set()
    timestamps = {}

    with zipfile.ZipFile(bgl_zip_path, "r") as z:
        with z.open("BGL.log") as fh:
            for line_bytes in fh:
                parsed = parse_line(line_bytes.decode("utf-8", errors="ignore"))
                if parsed is None:
                    continue
                label, timestamp, session_key = parsed
                session_counts[session_key] += 1
                if label != "-":
                    anomalous.add(session_key)
                timestamps.setdefault(session_key, timestamp)

    # Filter sessions with length >= 3 to avoid tiny cases
    filtered_keys = [k for k, count in session_counts.items() if count >= 3]
    anom_keys = [k for k in filtered_keys if k in anomalous]
    norm_keys = [k for k in filtered_keys if k not in anomalous]

    rng = random.Random(seed)

    def get_stratified_subset(keys, target_n):
        sorted_keys = sorted(keys, key=lambda k: timestamps[k])
        if len(sorted_keys) <= target_n:
            return sorted_keys

        # Contiguous bins sized like numpy.array_split: the first
        # ``extra`` bins get one additional element. Slicing the list keeps
        # keys as plain ``str`` rather than numpy scalars.
        base, extra = divmod(len(sorted_keys), n_bins)
        per_bin = target_n // n_bins
        subset = []
        start = 0
        for bin_idx in range(n_bins):
            size = base + 1 if bin_idx < extra else base
            chunk = sorted_keys[start:start + size]
            start += size
            subset.extend(rng.sample(chunk, min(len(chunk), per_bin)))

        # Fill remainder from a pool that is sorted once and shrunk per pick.
        missing = target_n - len(subset)
        if missing > 0:
            rem = sorted(set(sorted_keys) - set(subset))
            for _ in range(min(missing, len(rem))):
                pick = rng.choice(rem)
                rem.remove(pick)
                subset.append(pick)
        return subset

    sampled_anom = get_stratified_subset(anom_keys, sample_size // 2)
    sampled_norm = get_stratified_subset(norm_keys, sample_size // 2)
    sampled_set = set(sampled_anom + sampled_norm)

    # Pass 2: Extract actual lines for the sampled set
    sessions_lines = defaultdict(list)
    with zipfile.ZipFile(bgl_zip_path, "r") as z:
        with z.open("BGL.log") as fh:
            for line_bytes in fh:
                line = line_bytes.decode("utf-8", errors="ignore")
                parsed = parse_line(line)
                if parsed is None or parsed[2] not in sampled_set:
                    continue
                # Drop the leading alert-category column: it is the ground-truth
                # label, not log content. Keeping it leaks labels to every
                # retriever (BGL '-' = normal, anything else = anomaly).
                sessions_lines[parsed[2]].append(line.partition(" ")[2] or line)

    results = []
    for k in sampled_anom + sampled_norm:
        lines = sessions_lines.get(k, [])
        if lines:
            results.append((k, "".join(lines), "Anomaly" if k in anomalous else "Normal"))
    return results
223
+
224
+
225
def run_evaluation(
    dataset_name: str,
    sampled_data: list[tuple[str, str, str]],
    output_manifest_rows: list[dict],
    scorer: str = "ses",
) -> list[dict]:
    """Execute four-way evaluation comparison on the sampled dataset.

    The sessions are recorded in the manifest, shuffled with a fixed seed
    and split 80/20 into index and query sets. Every query is run through
    four retrievers (SMA MAC/FAC, BM25, dense SentenceTransformer cosine
    similarity, and an entity-overlap "KG-PPR Proxy"); each retriever's
    top-5 results vote, weighted by score, on the query label.

    Args:
        dataset_name: Name used in log output and in the ``split`` column.
        sampled_data: ``(session_id, text, label)`` tuples with label
            ``"Anomaly"`` or ``"Normal"``. The list is not modified.
        output_manifest_rows: List that receives one manifest dict per
            session (mutated in place).
        scorer: Scorer name passed to ``MatchConfig``.

    Returns:
        One metrics row per method (macro-F1, label hit rate at 1/5/10,
        p50/p95 latency in ms), plus a ``DIAGNOSTIC`` row for any method
        whose F1 is exactly 0 or 1 or whose predictions collapsed to a
        single label. Returns an empty list when the split leaves no index
        cases or no query cases.
    """
    print(f"\n--- Running evaluation on {dataset_name} ({len(sampled_data)} cases) ---")

    # Save to manifest list
    for sid, _, label in sampled_data:
        output_manifest_rows.append({
            "dataset": dataset_name,
            "session_id": sid,
            "label": label
        })

    # Split into 80% Index / 20% Query. Shuffle a copy so the caller's list
    # keeps its original order.
    shuffled = list(sampled_data)
    random.Random(101).shuffle(shuffled)
    split_idx = int(len(shuffled) * 0.8)
    index_data = shuffled[:split_idx]
    query_data = shuffled[split_idx:]

    if not index_data or not query_data:
        # Nothing to index or nothing to query: the metrics below would
        # divide by zero, so report no rows instead.
        print(f"Skipping {dataset_name}: not enough cases for an index/query split.")
        return []

    # Parse and encode cases
    log_encoder = get_encoder("logs")

    print("Encoding index cases...")
    index_cases = []
    index_docs = []  # List of (case_id, text)
    index_labels = {}
    for sid, text, label in index_data:
        case = log_encoder.encode(text, session_id=sid).case
        index_cases.append(case)
        index_docs.append((case.case_id, text))
        index_labels[case.case_id] = label

    print("Encoding query cases...")
    query_cases = []
    query_docs = []
    query_labels = {}
    for sid, text, label in query_data:
        case = log_encoder.encode(text, session_id=sid).case
        query_cases.append(case)
        query_docs.append((case.case_id, text))
        query_labels[case.case_id] = label

    # Document ids in index order; bound before the retrievers that use it.
    doc_ids = [doc_id for doc_id, _ in index_docs]

    # Build indexes ONCE before the query loop
    # 1. Build SMA MAC/FAC index
    print(f"Building SMA Index (scorer={scorer})...")
    sma_index = MacFacIndex(config=MatchConfig(scorer=scorer))
    sma_index.build(index_cases)

    # 2. Build BM25 Index
    print("Building BM25 Index...")
    from rank_bm25 import BM25Okapi
    tokenized_index = [text.lower().split() for _, text in index_docs]
    bm25_index = BM25Okapi(tokenized_index)

    # 3. Build Dense RAG Index (SentenceTransformers)
    print("Building Dense RAG Index (SentenceTransformers)...")
    from sentence_transformers import SentenceTransformer, util
    dense_model = SentenceTransformer('all-MiniLM-L6-v2')
    index_texts = [text for _, text in index_docs]
    index_embeddings = dense_model.encode(index_texts, convert_to_tensor=True, show_progress_bar=False)

    # 4. Build KG-PPR Proxy index
    print("Building KG-PPR Proxy Index...")
    index_entity_counters = {
        ic.case_id: Counter(e.name for e in ic.entities())
        for ic in index_cases
    }

    # Per-query ranked retrieval for each method, as (case_id, score) pairs.
    def retrieve_sma(q_case, q_text):
        # shortlist=40, fac_budget=20 keeps CPU latency bounded
        results = sma_index.retrieve(q_case, k=10, shortlist=40, fac_budget=20)
        return [(r.case_id, r.ses_n) for r in results]

    def retrieve_bm25(q_case, q_text):
        scores = bm25_index.get_scores(q_text.lower().split())
        ranked = sorted(zip(doc_ids, scores), key=lambda row: (-row[1], row[0]))
        return ranked[:10]

    def retrieve_dense(q_case, q_text):
        query_embedding = dense_model.encode(q_text, convert_to_tensor=True, show_progress_bar=False)
        scores = util.cos_sim(query_embedding, index_embeddings)[0].cpu().tolist()
        ranked = sorted(zip(doc_ids, scores), key=lambda row: (-row[1], row[0]))
        return ranked[:10]

    def retrieve_kg(q_case, q_text):
        # Score = multiset overlap between query and index entity names.
        q_counter = Counter(e.name for e in q_case.entities())
        ranked = sorted(
            (
                (ic_id, float(sum(min(v, counts.get(k, 0)) for k, v in q_counter.items())))
                for ic_id, counts in index_entity_counters.items()
            ),
            key=lambda row: (-row[1], row[0]),
        )
        return ranked[:10]

    def weighted_vote(ranked, top=5):
        # Score-weighted vote over the top results; "Normal" when the
        # summed votes are not positive.
        voting = {"Anomaly": 0.0, "Normal": 0.0}
        for case_id, score in ranked[:top]:
            voting[index_labels[case_id]] += score
        return max(voting, key=voting.get) if sum(voting.values()) > 0 else "Normal"

    retrievers = {
        "SMA": retrieve_sma,
        "BM25": retrieve_bm25,
        "Dense RAG": retrieve_dense,
        "KG-PPR Proxy": retrieve_kg,
    }
    methods = list(retrievers)
    metrics_by_method = {m: {"recalls": [], "preds": [], "latencies": []} for m in methods}

    print("Starting retrieval runs...")
    total_queries = len(query_cases)
    for idx, (q_case, (q_case_id, q_text)) in enumerate(zip(query_cases, query_docs), start=1):
        for method, retriever in retrievers.items():
            t0 = time.perf_counter()
            ranked = retriever(q_case, q_text)
            elapsed_ms = (time.perf_counter() - t0) * 1000
            data = metrics_by_method[method]
            data["recalls"].append([case_id for case_id, _ in ranked])
            data["latencies"].append(elapsed_ms)
            data["preds"].append(weighted_vote(ranked))

        if idx % 20 == 0 or idx == total_queries:
            print(f"Processed {idx}/{total_queries} retrieval runs...")

    # Calculate final metrics
    triage_rows = []
    true_labels = [query_labels[c.case_id] for c in query_cases]

    # Index case ids grouped by label, built once instead of once per
    # (method, query) pair.
    ids_by_label = defaultdict(set)
    for ic in index_cases:
        ids_by_label[index_labels[ic.case_id]].add(ic.case_id)
    no_ids = frozenset()

    for m in methods:
        data = metrics_by_method[m]
        preds = data["preds"]
        recalls = data["recalls"]
        latencies = data["latencies"]

        # F1 Score
        f1 = f1_score(true_labels, preds, average="macro")

        # label_hit_rate @ 1, 5, 10
        # Hit rate at k = count of retrieved relevant / min(relevant_ids, k)
        hit_rates = {1: [], 5: [], 10: []}
        for q_idx, q_case in enumerate(query_cases):
            relevant_ids = ids_by_label.get(query_labels[q_case.case_id], no_ids)
            ret_ids = recalls[q_idx]
            for k, bucket in hit_rates.items():
                hits = len(relevant_ids.intersection(ret_ids[:k]))
                denom = min(len(relevant_ids), k)
                bucket.append(hits / denom if denom > 0 else 0.0)

        r1 = sum(hit_rates[1]) / len(hit_rates[1])
        r5 = sum(hit_rates[5]) / len(hit_rates[5])
        r10 = sum(hit_rates[10]) / len(hit_rates[10])

        # Latency p50, p95
        p50 = np.percentile(latencies, 50)
        p95 = np.percentile(latencies, 95)

        triage_rows.append({
            "dataset": "LogHub",
            "split": f"{dataset_name}_MVP_diagnostic",
            "method": m,
            "macro_f1": f"{f1:.4f}",
            "label_hit_rate@1": f"{r1:.4f}",
            "label_hit_rate@5": f"{r5:.4f}",
            "label_hit_rate@10": f"{r10:.4f}",
            "p50_ms": f"{p50:.3f}",
            "p95_ms": f"{p95:.3f}"
        })

        # Print results
        print(f"Method: {m}")
        print(f"  Macro-F1: {f1:.4f}")
        print(f"  label_hit_rate@1: {r1:.4f}, label_hit_rate@5: {r5:.4f}, label_hit_rate@10: {r10:.4f}")
        print(f"  p50 Latency: {p50:.3f} ms, p95 Latency: {p95:.3f} ms")

        # Diagnostic alerts for collapsed or suspiciously perfect runs
        unique_preds = set(preds)
        is_suspicious = (f1 == 0.0 or f1 == 1.0 or len(unique_preds) <= 1)
        if is_suspicious:
            reason = ""
            if f1 == 0.0:
                reason = "F1 is 0.0: Retrieval collapse or dataset imbalance"
            elif f1 == 1.0:
                reason = "F1 is 1.0: Suspiciously perfect classification - potential data leakage or indexing overlap"
            elif len(unique_preds) <= 1:
                reason = f"Retrieval collapse: predicted only '{list(unique_preds)[0]}' sessions"

            triage_rows.append({
                "dataset": "DIAGNOSTIC",
                "split": f"{dataset_name}_MVP_diagnostic",
                "method": f"{m}_alert",
                "macro_f1": reason,
                "label_hit_rate@1": "ALERT",
                "label_hit_rate@5": "ALERT",
                "label_hit_rate@10": "ALERT",
                "p50_ms": "0.000",
                "p95_ms": "0.000"
            })
            print(f"  [DIAGNOSTIC ALERT] {reason}")

    return triage_rows
442
+
443
+
444
def run_loghub_eval(scorer: str = "ses") -> list[dict]:
    """Run the HDFS and BGL evaluations and persist the sample manifest.

    Seeds the global ``random`` module, checks that both raw archives are
    present, evaluates HDFS and then BGL with 1000 sampled sessions each,
    and writes every sampled session to
    ``reports/loghub_sample_manifest.csv``.

    Args:
        scorer: Scorer name forwarded to ``run_evaluation``.

    Returns:
        HDFS result rows followed by BGL result rows, or an empty list when
        either archive is missing.
    """
    random.seed(42)

    hdfs_zip = pathlib.Path("data/raw/loghub_raw/HDFS_v1.zip")
    bgl_zip = pathlib.Path("data/raw/loghub_raw/BGL.zip")

    archives_present = hdfs_zip.exists() and bgl_zip.exists()
    if not archives_present:
        print("Missing log datasets. Run fetch_datasets.py first.")
        return []

    manifest_rows = []

    # HDFS first ...
    print("Sampling HDFS sessions...")
    hdfs_rows = run_evaluation(
        "HDFS",
        sample_hdfs_stratified(hdfs_zip, sample_size=1000, seed=42),
        manifest_rows,
        scorer=scorer,
    )

    # ... then BGL, appending to the same manifest.
    print("Sampling BGL sessions...")
    bgl_rows = run_evaluation(
        "BGL",
        sample_bgl_stratified(bgl_zip, sample_size=1000, seed=42),
        manifest_rows,
        scorer=scorer,
    )

    # Persist which sessions were sampled so a run can be reproduced.
    manifest_path = pathlib.Path("reports/loghub_sample_manifest.csv")
    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    with manifest_path.open("w", encoding="utf-8", newline="") as out:
        manifest = csv.DictWriter(out, fieldnames=["dataset", "session_id", "label"])
        manifest.writeheader()
        manifest.writerows(manifest_rows)
    print(f"Saved manifest to {manifest_path}")

    return hdfs_rows + bgl_rows


if __name__ == "__main__":
    run_loghub_eval()
@@ -0,0 +1,51 @@
1
+ """LongMemEval loader + answer grader (real agent-memory drift benchmark)."""
2
+ from __future__ import annotations
3
+ import json, pathlib
4
+ from dataclasses import dataclass
5
+
6
+ DRIFT_CATEGORIES = {"knowledge-update", "temporal-reasoning"}
7
+
8
@dataclass(frozen=True)
class Session:
    """One haystack session: an id, a date string and its list of turns."""

    # Identifier taken from the instance's "haystack_session_ids" list.
    session_id: str
    # Date string taken from the instance's "haystack_dates" list.
    date: str
    # Turn dicts taken from the instance's "haystack_sessions" list.
    turns: list[dict]
13
+
14
@dataclass(frozen=True)
class LMEInstance:
    """A single LongMemEval question together with its haystack sessions."""

    question_id: str
    # Raw "question_type" value of the record.
    category: str
    question: str
    # Gold answer, stringified by the loader.
    answer: str
    question_date: str
    sessions: tuple[Session, ...]
    answer_session_ids: tuple[str, ...]

    @property
    def is_drift(self) -> bool:
        """True when this instance's category is one of ``DRIFT_CATEGORIES``."""
        category_is_drift = self.category in DRIFT_CATEGORIES
        return category_is_drift
27
+
28
def load_instances(path: str | pathlib.Path) -> list[LMEInstance]:
    """Read a LongMemEval JSON file and return its records as ``LMEInstance`` objects.

    The file must contain a JSON array of records. Each record's
    ``haystack_session_ids``, ``haystack_dates`` and ``haystack_sessions``
    lists are zipped into ``Session`` objects; ``question_date`` and
    ``answer_session_ids`` are optional and default to ``""`` and ``()``.
    """
    records = json.loads(pathlib.Path(path).read_text(encoding="utf-8"))

    instances: list[LMEInstance] = []
    for record in records:
        haystack = zip(
            record["haystack_session_ids"],
            record["haystack_dates"],
            record["haystack_sessions"],
        )
        sessions = tuple(
            Session(session_id, date, turns) for session_id, date, turns in haystack
        )
        instance = LMEInstance(
            question_id=record["question_id"],
            category=record["question_type"],
            question=record["question"],
            # Gold answers may be non-string JSON values; normalise to text.
            answer=str(record["answer"]),
            question_date=record.get("question_date", ""),
            sessions=sessions,
            answer_session_ids=tuple(record.get("answer_session_ids", [])),
        )
        instances.append(instance)
    return instances
43
+
44
+
45
def grade_answer(prediction: str, gold: str) -> float:
    """Score a prediction by lenient, normalized substring containment.

    Both strings are lower-cased and their whitespace runs collapsed to a
    single space; the score is 1.0 when the non-empty gold text occurs
    inside the prediction and 0.0 otherwise.

    (The official grader uses an LLM judge; this deterministic proxy is used
    for unit tests and the smoke run. The battery can swap in the LLM judge.)
    """
    def normalize(text: str) -> str:
        return " ".join(text.lower().split())

    haystack = normalize(prediction)
    needle = normalize(gold)
    if not needle:
        return 0.0
    if needle in haystack:
        return 1.0
    return 0.0
@@ -0,0 +1,2 @@
1
+ from .base import MemoryBackend, QueryResult
2
+ __all__ = ["MemoryBackend", "QueryResult"]
@@ -0,0 +1,22 @@
1
+ """Common interface for the four drift-experiment memory variants."""
2
+ from __future__ import annotations
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass, field
5
+ from sma.eval.longmemeval import Session
6
+
7
@dataclass
class QueryResult:
    """Outcome of a memory query: the answer plus what was retrieved."""

    # Answer text produced for the question.
    answer: str
    # Memory items that were retrieved for the query (fresh list per instance).
    retrieved: list[str] = field(default_factory=list)
    # backend believes the queried fact changed
    drift_flagged: bool = False
12
+
13
class MemoryBackend(ABC):
    """Shared backbone (DeepSeek orchestrator + extractor) is injected by the harness.

    Concrete backends implement ``reset``, ``ingest`` and ``query``.
    """

    # Identifier of the backend variant; subclasses override it.
    name: str = "base"

    @abstractmethod
    def reset(self) -> None:
        """Reset the backend's stored memory."""

    @abstractmethod
    def ingest(self, session: Session) -> None:
        """Take in one session."""

    @abstractmethod
    def query(self, question: str) -> QueryResult:
        """Answer ``question`` and return a ``QueryResult``."""
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from .base import MemoryBackend, QueryResult
3
+ from .shared_llm import answer_from
4
+
5
class ContextOnly(MemoryBackend):
    """Backend that keeps every ingested turn verbatim and answers from all of them."""

    name = "context-only"

    def __init__(self, llm):
        self.llm = llm
        self.turns: list[str] = []

    def reset(self):
        """Forget all stored turns."""
        self.turns = []

    def ingest(self, session):
        """Store each turn's content, prefixed with the session date."""
        self.turns.extend(
            f"[{session.date}] {turn['content']}" for turn in session.turns
        )

    def query(self, question):
        """Answer from the full turn history; ``retrieved`` is a copy of it."""
        return QueryResult(
            answer=answer_from(self.llm, question, self.turns),
            retrieved=list(self.turns),
        )
@@ -0,0 +1,17 @@
1
+ from __future__ import annotations
2
+ from .base import MemoryBackend, QueryResult
3
+ from .shared_llm import extract_facts, answer_from
4
+
5
class RagNotes(MemoryBackend):
    """LLM-written notes, retrieved by token overlap (a faithful simple RAG)."""

    name = "rag-notes"

    def __init__(self, llm, k: int = 5):
        self.llm = llm
        self.k = k
        self.notes: list[str] = []

    def reset(self):
        """Forget all stored notes."""
        self.notes = []

    def ingest(self, session):
        """Extract facts from every turn's content and store them as notes."""
        for turn in session.turns:
            facts = extract_facts(self.llm, turn["content"])
            self.notes.extend(facts)

    def query(self, question):
        """Answer from the ``k`` notes sharing the most lower-cased tokens with the question."""
        question_tokens = set(question.lower().split())

        def overlap(note):
            return len(question_tokens & set(note.lower().split()))

        # A stable descending sort keeps insertion order among ties.
        top = sorted(self.notes, key=overlap, reverse=True)[: self.k]
        answer = answer_from(self.llm, question, top)
        return QueryResult(answer=answer, retrieved=top)
@@ -0,0 +1,30 @@
1
+ """Extraction + answering shared by ALL variants (extraction held constant)."""
2
+ from __future__ import annotations
3
+ import json
4
+
5
_EXTRACT_SYS = ("Extract the durable user facts from the message as a JSON "
                "array of short strings. Only facts that could be asked about "
                "later. No commentary.")
_ANSWER_SYS = ("Answer the question using ONLY the provided memory items. "
               "If the memory contradicts itself, prefer the most recent. "
               "Answer concisely; if unknown, say 'unknown'.")


def extract_facts(llm, message: str) -> list[str]:
    """Call llm to extract durable facts from a chat message; returns list of strings.

    The reply is parsed as JSON. When the whole reply is not valid JSON
    (for example because the array is wrapped in a Markdown code fence or
    surrounded by prose), the text between the first ``[`` and the last
    ``]`` is parsed instead, so a well-formed array is not silently lost.

    Args:
        llm: Object exposing ``complete(messages, max_tokens=...)``.
        message: The chat message to extract facts from.

    Returns:
        The array elements converted with ``str``; an empty list when no
        JSON array can be recovered or the parsed value is not a list.
    """
    out = llm.complete(
        [{"role": "system", "content": _EXTRACT_SYS},
         {"role": "user", "content": message}], max_tokens=300)

    candidates = [out]
    if isinstance(out, str):
        start, end = out.find("["), out.rfind("]")
        if 0 <= start < end:
            # Fallback candidate: the outermost bracketed span of the reply.
            candidates.append(out[start:end + 1])

    for text in candidates:
        try:
            facts = json.loads(text)
        except (json.JSONDecodeError, TypeError):
            continue
        # First candidate that parses decides the result.
        return [str(f) for f in facts] if isinstance(facts, list) else []
    return []
22
+
23
def answer_from(llm, question: str, retrieved: list[str]) -> str:
    """Ask the llm to answer ``question`` from the given memory items only.

    The items are rendered as a dash-prefixed bullet list (or the
    placeholder ``(no memory)`` when there are none) and sent with the
    shared answering system prompt. The reply is returned stripped of
    surrounding whitespace.
    """
    bullets = [f"- {item}" for item in retrieved]
    if bullets:
        mem = "\n".join(bullets)
    else:
        mem = "(no memory)"

    messages = [
        {"role": "system", "content": _ANSWER_SYS},
        {"role": "user", "content": f"Memory:\n{mem}\n\nQuestion: {question}"},
    ]
    reply = llm.complete(messages, max_tokens=120)
    return reply.strip()
@@ -0,0 +1,54 @@
1
+ """SMA memory: each turn's extracted facts are re-encoded into the case store
2
+ (re-derived from the conversation, never from prior generations); retrieval is
3
+ structural; SAGE flags expectation-violations as drift."""
4
+ from __future__ import annotations
5
+ from .base import MemoryBackend, QueryResult
6
+ from .shared_llm import extract_facts, answer_from
7
+ from sma.index.macfac import MacFacIndex
8
+ from sma.ir.schema import make_case, stmt
9
+ from sma.sage.pools import SagePool
10
+ from sma.match.types import MatchConfig
11
+
12
+
13
def _fact_to_case(fact: str):
    """Turn a whitespace-tokenised fact string into a one-statement case.

    With three or more tokens the statement is built from the first token,
    the second token, and the remaining tokens joined by spaces. Shorter
    inputs become a ``"fact"`` statement over the available tokens, or over
    the single placeholder ``"empty"`` when there are none.
    """
    words = fact.split()
    if len(words) < 3:
        operands = words if words else ["empty"]
        return make_case([stmt("fact", *operands)])
    first, second, *rest = words
    return make_case([stmt(first, second, " ".join(rest))])
18
+
19
+
20
class SmaMemory(MemoryBackend):
    """SMA memory backend: structural re-derivation per turn + SAGE drift detection.

    Every extracted fact is converted to a case, scored against the SAGE
    pool for expectation violation, added to the MAC/FAC index and
    assimilated into the pool. Queries retrieve cases structurally and
    answer from the fact texts behind them.
    """

    name = "sma"

    def __init__(self, llm, k: int = 5):
        """Store the llm and top-k setting and start from an empty memory.

        Args:
            llm: Object passed to the shared extraction/answering helpers.
            k: Number of cases to retrieve per query.
        """
        self.llm = llm
        self.k = k
        # Create index/pool/texts right away so ingest() and query() work
        # even when the caller has not invoked reset() first.
        self.reset()

    def reset(self):
        """Replace the index, SAGE pool and text table with fresh empty ones."""
        self.index = MacFacIndex(config=MatchConfig())
        self.pool = SagePool("drift", assimilation_threshold=0.2)
        # case_id -> original fact text, used to render retrieved cases.
        self.texts: dict[str, str] = {}
        self.last_violation = 0.0

    def ingest(self, session) -> None:
        """Extract facts from each turn and store them as cases."""
        for t in session.turns:
            for fact in extract_facts(self.llm, t["content"]):
                case = _fact_to_case(fact)
                # Score the violation before the case is assimilated.
                self.last_violation = self.pool.expectation_violation(case)
                self.index.add(case)
                self.pool.assimilate(case)
                self.texts[case.case_id] = fact

    def query(self, question: str) -> QueryResult:
        """Retrieve up to ``k`` cases for the question and answer from their texts.

        ``drift_flagged`` reflects the expectation violation of the most
        recently ingested fact (above 0.5), not a score computed for this
        particular question.
        """
        qcase = _fact_to_case(question)
        results = self.index.retrieve(qcase, k=self.k, shortlist=50, fac_budget=20)
        # Cases without a known text yield "" and are withheld from the llm.
        retrieved = [self.texts.get(r.case_id, "") for r in results]
        ans = answer_from(self.llm, question, [r for r in retrieved if r])
        return QueryResult(
            answer=ans,
            retrieved=retrieved,
            drift_flagged=self.last_violation > 0.5,
        )
+ )