structuremappingmemory 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125) hide show
  1. sma/__init__.py +5 -0
  2. sma/__main__.py +5 -0
  3. sma/agent/__init__.py +5 -0
  4. sma/agent/adapter_draft.py +217 -0
  5. sma/agent/api.py +67 -0
  6. sma/agent/comparison.py +591 -0
  7. sma/agent/llm.py +280 -0
  8. sma/agent/policies.py +21 -0
  9. sma/agent/service.py +95 -0
  10. sma/cli.py +65 -0
  11. sma/encoders/__init__.py +38 -0
  12. sma/encoders/agentobs.py +27 -0
  13. sma/encoders/base.py +23 -0
  14. sma/encoders/code_treesitter.py +64 -0
  15. sma/encoders/coverage.py +80 -0
  16. sma/encoders/draft_adapter.py +183 -0
  17. sma/encoders/healthcare.py +207 -0
  18. sma/encoders/logs_drain.py +142 -0
  19. sma/encoders/prose_tier1.py +57 -0
  20. sma/encoders/structured.py +57 -0
  21. sma/encoders/traces.py +45 -0
  22. sma/eval/__init__.py +2 -0
  23. sma/eval/agentic/__init__.py +35 -0
  24. sma/eval/agentic/arms/__init__.py +0 -0
  25. sma/eval/agentic/arms/cyber.py +48 -0
  26. sma/eval/agentic/arms/discovery.py +35 -0
  27. sma/eval/agentic/arms/finance.py +38 -0
  28. sma/eval/agentic/arms/legal.py +74 -0
  29. sma/eval/agentic/arms/medicine.py +45 -0
  30. sma/eval/agentic/harness.py +275 -0
  31. sma/eval/agentic/memories.py +308 -0
  32. sma/eval/agentic/metrics.py +82 -0
  33. sma/eval/agentic_qa/__init__.py +27 -0
  34. sma/eval/agentic_qa/agent.py +383 -0
  35. sma/eval/agentic_qa/metrics.py +239 -0
  36. sma/eval/agentic_qa/pools.py +197 -0
  37. sma/eval/arn.py +65 -0
  38. sma/eval/baselines/__init__.py +6 -0
  39. sma/eval/baselines/bge_dense.py +54 -0
  40. sma/eval/baselines/bm25.py +18 -0
  41. sma/eval/baselines/dense.py +42 -0
  42. sma/eval/baselines/hipporag.py +235 -0
  43. sma/eval/baselines/hybrid_rrf.py +30 -0
  44. sma/eval/baselines/longcontext_llm.py +124 -0
  45. sma/eval/baselines/rerank.py +41 -0
  46. sma/eval/baselines/splade.py +77 -0
  47. sma/eval/baselines/wl_kernel.py +163 -0
  48. sma/eval/bugsinpy.py +358 -0
  49. sma/eval/bugsinpy_families.py +164 -0
  50. sma/eval/crossdomain.py +89 -0
  51. sma/eval/diabetes.py +61 -0
  52. sma/eval/drift_env.py +26 -0
  53. sma/eval/drift_metrics.py +24 -0
  54. sma/eval/family_labels.py +167 -0
  55. sma/eval/fraud_elliptic/__init__.py +29 -0
  56. sma/eval/fraud_elliptic/encoder.py +279 -0
  57. sma/eval/fraud_elliptic/eval.py +269 -0
  58. sma/eval/fraud_elliptic/test_encoder.py +123 -0
  59. sma/eval/ieee_cis.py +66 -0
  60. sma/eval/loghub.py +16 -0
  61. sma/eval/loghub_eval.py +480 -0
  62. sma/eval/longmemeval.py +51 -0
  63. sma/eval/memory_backends/__init__.py +2 -0
  64. sma/eval/memory_backends/base.py +22 -0
  65. sma/eval/memory_backends/context_only.py +14 -0
  66. sma/eval/memory_backends/rag_notes.py +17 -0
  67. sma/eval/memory_backends/shared_llm.py +30 -0
  68. sma/eval/memory_backends/sma_memory.py +54 -0
  69. sma/eval/memory_backends/zep_graphiti.py +33 -0
  70. sma/eval/metrics.py +32 -0
  71. sma/eval/ontology_bench.py +219 -0
  72. sma/eval/report.py +573 -0
  73. sma/eval/ssb_eval.py +216 -0
  74. sma/eval/ssb_generator.py +116 -0
  75. sma/eval/stats.py +108 -0
  76. sma/eval/transfer_eval.py +844 -0
  77. sma/index/__init__.py +15 -0
  78. sma/index/ann.py +21 -0
  79. sma/index/content_vectors.py +60 -0
  80. sma/index/inverted.py +63 -0
  81. sma/index/macfac.py +174 -0
  82. sma/ir/__init__.py +22 -0
  83. sma/ir/canon.py +106 -0
  84. sma/ir/schema.py +165 -0
  85. sma/ir/sexpr.py +86 -0
  86. sma/ir/signatures.py +76 -0
  87. sma/match/__init__.py +20 -0
  88. sma/match/conflicts.py +46 -0
  89. sma/match/engine.py +60 -0
  90. sma/match/explain.py +59 -0
  91. sma/match/infer.py +54 -0
  92. sma/match/kernels.py +54 -0
  93. sma/match/mdl.py +30 -0
  94. sma/match/merge_cpsat.py +77 -0
  95. sma/match/merge_greedy.py +15 -0
  96. sma/match/mh.py +177 -0
  97. sma/match/ses.py +84 -0
  98. sma/match/types.py +115 -0
  99. sma/match/verifier.py +27 -0
  100. sma/ontology/__init__.py +45 -0
  101. sma/ontology/attack.py +134 -0
  102. sma/ontology/cpc.py +69 -0
  103. sma/ontology/graph.py +58 -0
  104. sma/ontology/loader.py +262 -0
  105. sma/ontology/mitre_xml.py +67 -0
  106. sma/ontology/mount.py +101 -0
  107. sma/ontology/rdf_loader.py +75 -0
  108. sma/ontology/registry.py +115 -0
  109. sma/ontology/router.py +69 -0
  110. sma/ontology/usgaap.py +73 -0
  111. sma/sage/__init__.py +6 -0
  112. sma/sage/assimilate.py +12 -0
  113. sma/sage/pools.py +105 -0
  114. sma/sage/probabilities.py +10 -0
  115. sma/store/__init__.py +6 -0
  116. sma/store/lmdb_store.py +78 -0
  117. sma/store/registry.py +26 -0
  118. sma/store/wal.py +26 -0
  119. sma/ui/app.py +642 -0
  120. structuremappingmemory-1.0.0.dist-info/METADATA +190 -0
  121. structuremappingmemory-1.0.0.dist-info/RECORD +125 -0
  122. structuremappingmemory-1.0.0.dist-info/WHEEL +5 -0
  123. structuremappingmemory-1.0.0.dist-info/entry_points.txt +2 -0
  124. structuremappingmemory-1.0.0.dist-info/licenses/LICENSE +204 -0
  125. structuremappingmemory-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,844 @@
1
+ """Cross-system transfer evaluation (blueprint section 8.3, task T2-b).
2
+
3
+ Indexes incidents from one log system and queries with incidents from a
4
+ DIFFERENT system: HDFS->OpenStack and BGL->Thunderbird. Vocabularies differ
5
+ across systems but failure motifs recur, so this is the unseen-concept test
6
+ on real data. Compares the four retrieval methods from loghub_eval
7
+ (SMA, BM25, Dense RAG, KG-PPR Proxy) plus HippoRAG (B5, deterministic
8
+ adaptation) with weighted vote, label_hit_rate@k
9
+ and latency metrics, but WITHOUT an 80/20 split: the index set comes
10
+ entirely from system A and the query set entirely from system B.
11
+
12
+ Run as: python3 -u -m sma.eval.transfer_eval [--scorer ses|mdl]
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import csv
19
+ import gzip
20
+ import hashlib
21
+ import pathlib
22
+ import random
23
+ import re
24
+ import tarfile
25
+ import time
26
+ import zipfile
27
+ from collections import defaultdict, Counter
28
+
29
+ import numpy as np
30
+ from sklearn.metrics import f1_score
31
+
32
+ from sma.encoders import get_encoder
33
+ from sma.index.macfac import MacFacIndex
34
+ from sma.match.types import MatchConfig
35
+ from sma.eval.loghub_eval import sample_hdfs_stratified, sample_bgl_stratified
36
+
37
# md5 that a complete Thunderbird.tar.gz must have. An earlier download was
# corrupt, so evaluation is refused for any archive whose digest differs.
THUNDERBIRD_MD5 = "0891b048df2919dc78c99c4428686b44"

# Thunderbird is very large (~30GB uncompressed, ~211M lines). To keep runs
# tractable, both streaming passes stop after the first 20 million lines;
# the split name "thunderbird_first20M" records that cap.
THUNDERBIRD_LINE_CAP = 20_000_000

# Spirit (Sandia supercomputer, Oliner & Stearley 2007) is the held-out
# transfer target. It comes from USENIX CFDR hpc4/spirit2.gz and is NOT part
# of the LogHub Zenodo records (see data/manifests/datasets.json
# source_note). It shares the alert-flag format family of BGL/Thunderbird,
# and its md5 is verified before every evaluation, like Thunderbird.
SPIRIT_MD5 = "ba6271c4f454bc21634b19c406d9769c"

# Spirit is ~37GB uncompressed (~272M lines). It uses the same tractability
# cap as Thunderbird: both streaming passes stop after the first 20 million
# lines, recorded in the split name "spirit_first20M".
SPIRIT_LINE_CAP = 20_000_000

# Captures the 36-character VM instance id that follows "instance: " in an
# OpenStack log line.
OPENSTACK_INSTANCE_RE = re.compile(r"instance: ([0-9a-f-]{36})")
59
+
60
+
61
def get_stratified_subset(keys, target_n, sort_key, rng):
    """Sample ``target_n`` keys stratified over 5 temporal bins.

    The keys are sorted with ``sort_key`` and cut into 5 contiguous bins
    (the first ``len(keys) % 5`` bins hold one extra key, the same sizes
    ``numpy.array_split`` produces). ``target_n // 5`` keys are drawn from
    every bin; any shortfall is then topped up from the keys that were not
    drawn yet. The original comment states this is the same scheme as the
    nested helpers in loghub_eval.

    Args:
        keys: Iterable of hashable session keys.
        target_n: Number of keys wanted.
        sort_key: Key function giving the temporal order of a session key.
        rng: A ``random.Random``-like object providing ``sample``.

    Returns:
        All keys in sorted order when there are at most ``target_n`` of
        them, otherwise a list of ``target_n`` distinct sampled keys.
    """
    sorted_keys = sorted(keys, key=sort_key)
    if len(sorted_keys) <= target_n:
        return sorted_keys

    n_bins = 5
    per_bin = target_n // n_bins
    base_size, extra = divmod(len(sorted_keys), n_bins)

    # Slice the bins in plain Python rather than via numpy: the keys keep
    # their native type (no np.str_ leaking to callers) and non-scalar keys
    # such as tuples are supported.
    subset = []
    start = 0
    for bin_idx in range(n_bins):
        size = base_size + (1 if bin_idx < extra else 0)
        bin_keys = sorted_keys[start:start + size]
        start += size
        subset.extend(rng.sample(bin_keys, min(len(bin_keys), per_bin)))

    shortfall = target_n - len(subset)
    if shortfall > 0:
        # Top up from the leftovers in their sorted order. Drawing from a
        # list built out of a set would depend on hash iteration order,
        # which for strings varies between processes and would defeat the
        # seeded rng.
        chosen = set(subset)
        remaining = list(
            dict.fromkeys(k for k in sorted_keys if k not in chosen)
        )
        subset.extend(rng.sample(remaining, min(shortfall, len(remaining))))
    return subset
78
+
79
+
80
def sample_openstack(
    path: pathlib.Path, sample_size: int = 200, seed: int = 42
) -> list[tuple[str, str, str]]:
    """Sessionize OpenStack logs by VM instance id and draw a stratified sample.

    The LogHub OpenStack archive holds openstack_normal1.log,
    openstack_normal2.log and openstack_abnormal.log. Lines are grouped into
    sessions by the instance id found as "[instance: <uuid>]". A session is
    labeled Anomaly iff any of its lines comes from the abnormal log (the
    instance-id sets of normal and abnormal runs are stated to be disjoint).

    Returns a list of (session_key, concatenated_text, label) tuples,
    anomalous sessions first.
    """
    log_members = (
        ("openstack_normal1.log", "Normal"),
        ("openstack_normal2.log", "Normal"),
        ("openstack_abnormal.log", "Anomaly"),
    )

    def scan_archive():
        """Yield (line number, member label, session key, text) for every
        line carrying an instance id; the counter runs over ALL lines of all
        members, matching or not."""
        position = 0
        with tarfile.open(path, "r:gz") as archive:
            for member_name, member_label in log_members:
                with archive.extractfile(member_name) as handle:
                    for raw in handle:
                        position += 1
                        text = raw.decode("utf-8", errors="ignore")
                        found = OPENSTACK_INSTANCE_RE.search(text)
                        if found is None:
                            continue
                        yield (
                            position,
                            member_label,
                            f"openstack_{found.group(1)}",
                            text,
                        )

    # First pass: per-session line counts, labels and first-seen position.
    session_counts = Counter()
    labels = {}
    first_seen = {}
    for position, member_label, key, _text in scan_archive():
        session_counts[key] += 1
        if member_label == "Anomaly":
            labels[key] = "Anomaly"
        else:
            labels.setdefault(key, "Normal")
        # The files are time-ordered, so the first-seen line index is a
        # monotone stand-in for the first timestamp; it is only used for the
        # 5-bin temporal stratification.
        first_seen.setdefault(key, position)

    # Sessions shorter than 3 lines are dropped (BGL convention).
    anom_keys = []
    norm_keys = []
    for key, count in session_counts.items():
        if count < 3:
            continue
        if labels[key] == "Anomaly":
            anom_keys.append(key)
        elif labels[key] == "Normal":
            norm_keys.append(key)

    rng = random.Random(seed)
    half = sample_size // 2
    sampled_anom = get_stratified_subset(anom_keys, half, first_seen.__getitem__, rng)
    sampled_norm = get_stratified_subset(norm_keys, half, first_seen.__getitem__, rng)
    ordered_sample = sampled_anom + sampled_norm
    wanted = set(ordered_sample)

    # Second pass: collect the text of the sampled sessions only.
    sessions_lines = defaultdict(list)
    for _position, _label, key, text in scan_archive():
        if key not in wanted:
            continue
        # Strip the leading source-filename column: the file a line came
        # from encodes the session label (normal vs abnormal run), so keeping
        # it would leak labels into the text like the BGL alert column.
        stripped = text.partition(" ")[2]
        sessions_lines[key].append(stripped if stripped else text)

    return [
        (key, "".join(sessions_lines[key]), labels[key])
        for key in ordered_sample
        if sessions_lines.get(key)
    ]
161
+
162
+
163
def check_thunderbird(path: pathlib.Path) -> str | None:
    """Verify that Thunderbird.tar.gz is present and matches its checksum.

    Args:
        path: Location of the Thunderbird archive.

    Returns:
        None when the file exists and its md5 equals THUNDERBIRD_MD5,
        otherwise a human-readable reason to skip the BGL->Thunderbird pair.
    """
    if not path.exists():
        return f"{path} is missing (download may still be in progress)"
    if not path.is_file():
        # A directory (or other non-regular entry) cannot be hashed; report
        # it as a skip reason instead of letting open() raise.
        return f"{path} is not a regular file"

    # md5 is used here purely as a download-integrity check.
    digest = hashlib.md5()
    with path.open("rb") as fh:
        # Hash in 1 MiB chunks so the archive is never fully loaded.
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)
    actual = digest.hexdigest()
    if actual != THUNDERBIRD_MD5:
        return (
            f"{path} md5 mismatch: expected {THUNDERBIRD_MD5}, got {actual} "
            "(file incomplete or corrupt; a previous copy was corrupt too)"
        )
    return None
179
+
180
+
181
def sample_thunderbird(
    path: pathlib.Path, sample_size: int = 200, seed: int = 42
) -> list[tuple[str, str, str]]:
    """Sessionize and sample stratified Thunderbird logs (BGL-like format).

    Streams the tar.gz in two passes without extracting to disk or holding
    all lines in memory. Lines are grouped per node into 60-second windows
    and sessions with fewer than 3 lines are discarded. The first
    whitespace-separated field is the ground-truth label column ("-" =
    normal, anything else = alert category) and is STRIPPED from the
    extracted text to avoid label leakage. Both passes are capped at the
    first THUNDERBIRD_LINE_CAP (20M) lines for tractability (the split name
    "thunderbird_first20M" records the cap).

    Args:
        path: Location of Thunderbird.tar.gz; verified with
            check_thunderbird before any sampling.
        sample_size: Total sessions wanted; half anomalous, half normal.
        seed: Seed for the sampling rng.

    Returns:
        A list of (session_key, text, "Anomaly" | "Normal") tuples, or an
        empty list when the archive fails verification.
    """
    skip_reason = check_thunderbird(path)
    if skip_reason:
        print(f"Skipping Thunderbird sampling: {skip_reason}")
        return []

    def stream_lines(tb_path):
        """Yield decoded lines of the first log member, capped at 20M."""
        with tarfile.open(tb_path, "r|gz") as tar:
            for member in tar:
                fh = tar.extractfile(member)
                if fh is None:
                    continue
                for line_no, line_bytes in enumerate(fh, start=1):
                    if line_no > THUNDERBIRD_LINE_CAP:
                        return
                    yield line_bytes.decode("utf-8", errors="ignore")
                return  # only the first (log) member matters

    def iter_records(tb_path):
        """Yield (session_key, label, timestamp, line) per parseable line.

        Both passes go through this single parser so they can never
        disagree on which lines belong to which session."""
        for line in stream_lines(tb_path):
            parts = line.split(maxsplit=5)
            if len(parts) < 5:
                continue
            try:
                timestamp = int(parts[1])
            except ValueError:
                continue
            # 60-second windows per node, like BGL.
            session_key = f"tbird_{parts[3]}_{timestamp // 60}"
            yield session_key, parts[0], timestamp, line

    # Pass 1: gather metadata for sessionization and labels.
    session_counts = Counter()
    anomalous = set()
    timestamps = {}
    for session_key, label, timestamp, _line in iter_records(path):
        session_counts[session_key] += 1
        if label != "-":
            anomalous.add(session_key)
        timestamps.setdefault(session_key, timestamp)

    # Keep sessions with length >= 3 to avoid tiny cases.
    filtered_keys = [k for k, count in session_counts.items() if count >= 3]
    anom_keys = [k for k in filtered_keys if k in anomalous]
    norm_keys = [k for k in filtered_keys if k not in anomalous]

    rng = random.Random(seed)
    sampled_anom = get_stratified_subset(
        anom_keys, sample_size // 2, lambda k: timestamps[k], rng
    )
    sampled_norm = get_stratified_subset(
        norm_keys, sample_size // 2, lambda k: timestamps[k], rng
    )
    sampled_set = set(sampled_anom + sampled_norm)

    # Pass 2: extract actual lines for the sampled set.
    sessions_lines = defaultdict(list)
    for session_key, _label, _timestamp, line in iter_records(path):
        if session_key not in sampled_set:
            continue
        # Drop the leading alert-category column: it is the ground-truth
        # label, not log content. The column is removed with the same
        # whitespace splitting used for parsing; cutting at the first
        # literal space instead would keep the label on a line that starts
        # with whitespace. iter_records guarantees at least 5 fields, so
        # index 1 always exists, and the remainder keeps its newline.
        sessions_lines[session_key].append(line.split(maxsplit=1)[1])

    results = []
    for k in sampled_anom + sampled_norm:
        lines = sessions_lines.get(k, [])
        if lines:
            results.append(
                (k, "".join(lines), "Anomaly" if k in anomalous else "Normal")
            )
    return results
276
+
277
+
278
def check_spirit(path: pathlib.Path) -> str | None:
    """Verify that spirit2.gz is present and matches its checksum.

    Mirrors check_thunderbird.

    Args:
        path: Location of the Spirit gzip file.

    Returns:
        None when the file exists and its md5 equals SPIRIT_MD5, otherwise
        a human-readable reason to skip Spirit pairs.
    """
    if not path.exists():
        return f"{path} is missing (download may still be in progress)"
    if not path.is_file():
        # A directory (or other non-regular entry) cannot be hashed; report
        # it as a skip reason instead of letting open() raise.
        return f"{path} is not a regular file"

    # md5 is used here purely as a download-integrity check.
    digest = hashlib.md5()
    with path.open("rb") as fh:
        # Hash in 1 MiB chunks so the file is never fully loaded.
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)
    actual = digest.hexdigest()
    if actual != SPIRIT_MD5:
        return (
            f"{path} md5 mismatch: expected {SPIRIT_MD5}, got {actual} "
            "(file incomplete or corrupt)"
        )
    return None
294
+
295
+
296
def sample_spirit(
    path: pathlib.Path, sample_size: int = 200, seed: int = 42
) -> list[tuple[str, str, str]]:
    """Sessionize and sample stratified Spirit logs (BGL/Thunderbird family).

    Modeled on sample_thunderbird: streams the plain gzip in two passes
    without extracting to disk or holding all lines in memory. Lines are
    grouped per node into 60-second windows with a >=3 line minimum. The
    first whitespace-separated field is the ground-truth alert label column
    ("-" = normal, anything else = alert category, e.g. R_HDA_NR) and is
    STRIPPED from the extracted text to avoid label leakage. Both passes are
    capped at the first SPIRIT_LINE_CAP (20M) lines for tractability (the
    split name "spirit_first20M" records the cap).

    Spirit line format (per the original note, verified against CFDR
    hpc4/spirit2.gz):
        LABEL EPOCH DATE NODE Month Day HH:MM:SS src daemon[pid]: message
    so parts[0]=label, parts[1]=epoch seconds, parts[3]=node id - identical
    field positions to Thunderbird.

    Args:
        path: Location of spirit2.gz; verified with check_spirit first.
        sample_size: Total sessions wanted; half anomalous, half normal.
        seed: Seed for the sampling rng.

    Returns:
        A list of (session_key, text, "Anomaly" | "Normal") tuples, or an
        empty list when the file fails verification.
    """
    skip_reason = check_spirit(path)
    if skip_reason:
        print(f"Skipping Spirit sampling: {skip_reason}")
        return []

    def stream_lines(sp_path):
        """Yield decoded lines of the gzipped log, capped at 20M."""
        with gzip.open(sp_path, "rb") as fh:
            for line_no, line_bytes in enumerate(fh, start=1):
                if line_no > SPIRIT_LINE_CAP:
                    return
                yield line_bytes.decode("utf-8", errors="ignore")

    def iter_records(sp_path):
        """Yield (session_key, label, timestamp, line) per parseable line.

        Both passes go through this single parser so they can never
        disagree on which lines belong to which session."""
        for line in stream_lines(sp_path):
            parts = line.split(maxsplit=5)
            if len(parts) < 5:
                continue
            try:
                timestamp = int(parts[1])
            except ValueError:
                continue
            # 60-second windows per node, like BGL/Thunderbird.
            session_key = f"spirit_{parts[3]}_{timestamp // 60}"
            yield session_key, parts[0], timestamp, line

    # Pass 1: gather metadata for sessionization and labels.
    session_counts = Counter()
    anomalous = set()
    timestamps = {}
    for session_key, label, timestamp, _line in iter_records(path):
        session_counts[session_key] += 1
        if label != "-":
            anomalous.add(session_key)
        timestamps.setdefault(session_key, timestamp)

    # Keep sessions with length >= 3 to avoid tiny cases.
    filtered_keys = [k for k, count in session_counts.items() if count >= 3]
    anom_keys = [k for k in filtered_keys if k in anomalous]
    norm_keys = [k for k in filtered_keys if k not in anomalous]
    print(
        f"Spirit (first {SPIRIT_LINE_CAP // 1_000_000}M lines): "
        f"{len(session_counts)} sessions, {len(filtered_keys)} with >=3 lines "
        f"({len(anom_keys)} anomalous / {len(norm_keys)} normal)"
    )

    rng = random.Random(seed)
    sampled_anom = get_stratified_subset(
        anom_keys, sample_size // 2, lambda k: timestamps[k], rng
    )
    sampled_norm = get_stratified_subset(
        norm_keys, sample_size // 2, lambda k: timestamps[k], rng
    )
    sampled_set = set(sampled_anom + sampled_norm)

    # Pass 2: extract actual lines for the sampled set.
    sessions_lines = defaultdict(list)
    for session_key, _label, _timestamp, line in iter_records(path):
        if session_key not in sampled_set:
            continue
        # Drop the leading alert-category column: it is the ground-truth
        # label, not log content. The column is removed with the same
        # whitespace splitting used for parsing; cutting at the first
        # literal space instead would keep the label on a line that starts
        # with whitespace. iter_records guarantees at least 5 fields, so
        # index 1 always exists, and the remainder keeps its newline.
        sessions_lines[session_key].append(line.split(maxsplit=1)[1])

    results = []
    for k in sampled_anom + sampled_norm:
        lines = sessions_lines.get(k, [])
        if lines:
            results.append(
                (k, "".join(lines), "Anomaly" if k in anomalous else "Normal")
            )
    sampled_counts = Counter(label for _, _, label in results)
    print(
        f"Spirit sample: {len(results)} sessions "
        f"({sampled_counts.get('Anomaly', 0)} Anomaly / "
        f"{sampled_counts.get('Normal', 0)} Normal)"
    )
    return results
403
+
404
+
405
def run_transfer(
    index_data: list[tuple[str, str, str]],
    query_data: list[tuple[str, str, str]],
    pair_name: str,
    scorer: str = "ses",
    normalization: str = "max",
    per_query_rows: list[dict] | None = None,
) -> list[dict]:
    """Execute the cross-system transfer comparison for one system pair.

    Adapted from loghub_eval.run_evaluation but WITHOUT the 80/20 split:
    index_data is the full index set (system A), query_data the full query
    set (system B). Seven retrievers are compared: SMA, BM25, Dense RAG,
    KG-PPR Proxy, HippoRAG, Hybrid-RRF and Hybrid+Rerank. Each query's label
    is predicted by a score-weighted vote over the top-5 retrieved cases.

    Args:
        index_data: (session_id, text, label) tuples to index.
        query_data: (session_id, text, label) tuples used as queries.
        pair_name: Display name of the pair; combined with ``scorer`` into
            the split string.
        scorer: Passed to MatchConfig for the SMA index.
        normalization: Passed to MatchConfig for the SMA index.
        per_query_rows: If a list, one dict per (query, method) is appended
            to it -- query_id, true/pred label (the macro-F1 inputs) and
            per-query hit@{1,5,10} -- so callers such as
            scripts/confirmatory_battery.py can run paired per-query
            statistics. Returned summary rows are unchanged.

    Returns:
        One summary row per method (plus a DIAGNOSTIC row for collapsed or
        suspiciously perfect runs), or an empty list when either input set
        is empty.
    """
    split_name = f"{pair_name}[{scorer}]"
    print(
        f"\n--- Running transfer evaluation {split_name} "
        f"({len(index_data)} index / {len(query_data)} query cases) ---"
    )

    # The samplers return [] when an archive is skipped. With an empty side
    # there is nothing to index or to average over (the means below would
    # divide by zero), so return before loading any model.
    if not index_data or not query_data:
        print(f"Skipping {split_name}: empty index or query set")
        return []

    # Parse and encode cases
    log_encoder = get_encoder("logs")

    def encode_all(rows):
        """Encode (sid, text, label) rows into cases, (case_id, text) docs
        and a case_id -> label map."""
        cases, docs, labels = [], [], {}
        for sid, text, label in rows:
            case = log_encoder.encode(text, session_id=sid).case
            cases.append(case)
            docs.append((case.case_id, text))
            labels[case.case_id] = label
        return cases, docs, labels

    print("Encoding index cases...")
    index_cases, index_docs, index_labels = encode_all(index_data)

    print("Encoding query cases...")
    query_cases, query_docs, query_labels = encode_all(query_data)

    # Defined before the retriever closures that read it.
    doc_ids = [doc_id for doc_id, _ in index_docs]

    # Build indexes ONCE before the query loop
    # 1. Build SMA MAC/FAC index
    print(f"Building SMA Index (scorer={scorer})...")
    sma_index = MacFacIndex(config=MatchConfig(scorer=scorer, normalization=normalization))
    sma_index.build(index_cases)

    # 2. Build BM25 Index
    print("Building BM25 Index...")
    from rank_bm25 import BM25Okapi
    tokenized_index = [text.lower().split() for _, text in index_docs]
    bm25_index = BM25Okapi(tokenized_index)

    # 3. Build Dense RAG Index (SentenceTransformers)
    print("Building Dense RAG Index (SentenceTransformers)...")
    from sentence_transformers import SentenceTransformer, util
    dense_model = SentenceTransformer('all-MiniLM-L6-v2')
    index_texts = [text for _, text in index_docs]
    index_embeddings = dense_model.encode(index_texts, convert_to_tensor=True, show_progress_bar=False)

    # 4. Build KG-PPR Proxy index
    print("Building KG-PPR Proxy Index...")
    index_entity_counters = {
        ic.case_id: Counter(e.name for e in ic.entities())
        for ic in index_cases
    }

    # 5. Build HippoRAG index (B5: phrase graph + Personalized PageRank)
    print("Building HippoRAG Index...")
    from sma.eval.baselines.hipporag import HippoRAGRetriever
    hipporag_index = HippoRAGRetriever()
    hipporag_index.build(index_docs)

    # 6. Enterprise hybrid stack (B6): RRF(BM25 + dense) and a cross-encoder
    # rerank over the fused top-20 pool - the production RAG posture.
    print("Loading cross-encoder reranker (Hybrid+Rerank)...")
    from sma.eval.baselines.hybrid_rrf import rrf_fuse
    from sma.eval.baselines.rerank import CrossEncoderReranker
    reranker = CrossEncoderReranker()
    index_text_by_id = dict(index_docs)

    def rank_by_score(scores, k):
        """Top-k (case_id, score) pairs, best score first, ties by id."""
        return sorted(zip(doc_ids, scores), key=lambda row: (-row[1], row[0]))[:k]

    def _bm25_ranked(q_text, k):
        return rank_by_score(bm25_index.get_scores(q_text.lower().split()), k)

    def _dense_ranked(q_text, k):
        query_embedding = dense_model.encode(q_text, convert_to_tensor=True, show_progress_bar=False)
        scores = util.cos_sim(query_embedding, index_embeddings)[0].cpu().tolist()
        return rank_by_score(scores, k)

    # Per-query ranked retrieval for each method, as (case_id, score) pairs.
    def retrieve_sma(q_case, q_text):
        # shortlist=40, fac_budget=20 keeps CPU latency bounded
        results = sma_index.retrieve(q_case, k=10, shortlist=40, fac_budget=20)
        return [(r.case_id, r.ses_n) for r in results]

    def retrieve_bm25(q_case, q_text):
        return _bm25_ranked(q_text, 10)

    def retrieve_dense(q_case, q_text):
        return _dense_ranked(q_text, 10)

    def retrieve_kg(q_case, q_text):
        # Score = multiset overlap of entity names between query and case.
        q_counter = Counter(e.name for e in q_case.entities())
        ranked = sorted(
            (
                (ic_id, float(sum(min(v, counts.get(k, 0)) for k, v in q_counter.items())))
                for ic_id, counts in index_entity_counters.items()
            ),
            key=lambda row: (-row[1], row[0]),
        )
        return ranked[:10]

    def retrieve_hipporag(q_case, q_text):
        return hipporag_index.retrieve(q_text, k=10)

    def retrieve_hybrid_rrf(q_case, q_text):
        return rrf_fuse([_bm25_ranked(q_text, 20), _dense_ranked(q_text, 20)], top_k=10)

    def retrieve_hybrid_rerank(q_case, q_text):
        pool = rrf_fuse([_bm25_ranked(q_text, 20), _dense_ranked(q_text, 20)], top_k=20)
        candidates = [(cid, index_text_by_id[cid]) for cid, _ in pool]
        return reranker.rerank(q_text, candidates, top_k=10)

    def weighted_vote(ranked, top=5):
        """Predict the label whose top-``top`` neighbours carry most score.

        Falls back to "Normal" when the summed scores are not positive.
        NOTE(review): a retriever that returns negative scores would mostly
        land in this fallback - confirm each baseline's score range.
        """
        voting = {"Anomaly": 0.0, "Normal": 0.0}
        for case_id, score in ranked[:top]:
            voting[index_labels[case_id]] += score
        return max(voting, key=voting.get) if sum(voting.values()) > 0 else "Normal"

    retrievers = {
        "SMA": retrieve_sma,
        "BM25": retrieve_bm25,
        "Dense RAG": retrieve_dense,
        "KG-PPR Proxy": retrieve_kg,
        "HippoRAG": retrieve_hipporag,
        "Hybrid-RRF": retrieve_hybrid_rrf,
        "Hybrid+Rerank": retrieve_hybrid_rerank,
    }
    methods = list(retrievers)
    metrics_by_method = {m: {"recalls": [], "preds": [], "latencies": []} for m in methods}

    print("Starting retrieval runs...")
    total_queries = len(query_cases)
    for idx, (q_case, (q_case_id, q_text)) in enumerate(zip(query_cases, query_docs), start=1):
        for method, retriever in retrievers.items():
            # Only the retrieval call itself is timed.
            t0 = time.perf_counter()
            ranked = retriever(q_case, q_text)
            elapsed_ms = (time.perf_counter() - t0) * 1000
            data = metrics_by_method[method]
            data["recalls"].append([case_id for case_id, _ in ranked])
            data["latencies"].append(elapsed_ms)
            data["preds"].append(weighted_vote(ranked))

        if idx % 20 == 0 or idx == total_queries:
            print(f"Processed {idx}/{total_queries} retrieval runs...")

    # Calculate final metrics
    transfer_rows = []
    true_labels = [query_labels[c.case_id] for c in query_cases]

    # Relevant index cases (same label as the query) depend only on the
    # label, so group them once instead of once per (method, query).
    relevant_by_label = defaultdict(set)
    for case_id, label in index_labels.items():
        relevant_by_label[label].add(case_id)
    no_relevant = frozenset()

    def hit_rate_at(ret_ids, relevant_ids, k):
        """Retrieved relevant cases in the top k / min(#relevant, k)."""
        denom = min(len(relevant_ids), k)
        if denom == 0:
            return 0.0
        return len(set(ret_ids[:k]).intersection(relevant_ids)) / denom

    for m in methods:
        data = metrics_by_method[m]
        preds = data["preds"]
        recalls = data["recalls"]
        latencies = data["latencies"]

        # F1 Score
        f1 = f1_score(true_labels, preds, average="macro")

        # label_hit_rate @ 1, 5, 10
        r1_list = []
        r5_list = []
        r10_list = []
        for q_idx, q_case in enumerate(query_cases):
            q_label = query_labels[q_case.case_id]
            ret_ids = recalls[q_idx]
            relevant_ids = relevant_by_label.get(q_label, no_relevant)

            r1_list.append(hit_rate_at(ret_ids, relevant_ids, 1))
            r5_list.append(hit_rate_at(ret_ids, relevant_ids, 5))
            r10_list.append(hit_rate_at(ret_ids, relevant_ids, 10))

            if per_query_rows is not None:
                per_query_rows.append({
                    "split": split_name,
                    "method": m,
                    "query_id": q_case.case_id,
                    "true_label": q_label,
                    "pred_label": preds[q_idx],
                    "hit@1": r1_list[-1],
                    "hit@5": r5_list[-1],
                    "hit@10": r10_list[-1],
                })

        r1 = sum(r1_list) / len(r1_list)
        r5 = sum(r5_list) / len(r5_list)
        r10 = sum(r10_list) / len(r10_list)

        # Latency p50, p95
        p50 = np.percentile(latencies, 50)
        p95 = np.percentile(latencies, 95)

        transfer_rows.append({
            "dataset": "LogHub",
            "split": split_name,
            "method": m,
            "macro_f1": f"{f1:.4f}",
            "label_hit_rate@1": f"{r1:.4f}",
            "label_hit_rate@5": f"{r5:.4f}",
            "label_hit_rate@10": f"{r10:.4f}",
            "p50_ms": f"{p50:.3f}",
            "p95_ms": f"{p95:.3f}"
        })

        # Print results
        print(f"Method: {m}")
        print(f"  Macro-F1: {f1:.4f}")
        print(f"  label_hit_rate@1: {r1:.4f}, label_hit_rate@5: {r5:.4f}, label_hit_rate@10: {r10:.4f}")
        print(f"  p50 Latency: {p50:.3f} ms, p95 Latency: {p95:.3f} ms")

        # Diagnostic alerts for collapsed or suspiciously perfect runs
        unique_preds = set(preds)
        is_suspicious = (f1 == 0.0 or f1 == 1.0 or len(unique_preds) <= 1)
        if is_suspicious:
            reason = ""
            if f1 == 0.0:
                reason = "F1 is 0.0: Retrieval collapse or dataset imbalance"
            elif f1 == 1.0:
                reason = "F1 is 1.0: Suspiciously perfect classification - potential data leakage or indexing overlap"
            elif len(unique_preds) <= 1:
                reason = f"Retrieval collapse: predicted only '{list(unique_preds)[0]}' sessions"

            transfer_rows.append({
                "dataset": "DIAGNOSTIC",
                "split": split_name,
                "method": f"{m}_alert",
                "macro_f1": reason,
                "label_hit_rate@1": "ALERT",
                "label_hit_rate@5": "ALERT",
                "label_hit_rate@10": "ALERT",
                "p50_ms": "0.000",
                "p95_ms": "0.000"
            })
            print(f"  [DIAGNOSTIC ALERT] {reason}")

    return transfer_rows
673
+
674
+
675
def append_transfer_rows(
    rows: list[dict], out_path: str | pathlib.Path = "reports/transfer_metrics.csv"
) -> None:
    """Append metric rows to a transfer metrics CSV (triage schema).

    Does nothing when ``rows`` is empty. Otherwise the parent directory is
    created if needed and the rows are appended; the header line is written
    first when the file does not exist yet or is still empty.

    Args:
        rows: Dicts keyed by the schema columns listed in ``fieldnames``.
        out_path: Target CSV; defaults to reports/transfer_metrics.csv
            (the original behavior).
    """
    if not rows:
        return
    out_path = pathlib.Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fieldnames = [
        "dataset", "split", "method", "macro_f1",
        "label_hit_rate@1", "label_hit_rate@5", "label_hit_rate@10",
        "p50_ms", "p95_ms",
    ]
    # A zero-byte file (pre-created, or left by an interrupted run) has no
    # header yet, so it is treated like a missing file.
    write_header = not out_path.exists() or out_path.stat().st_size == 0
    with out_path.open("a", encoding="utf-8", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerows(rows)
    print(f"Appended {len(rows)} rows to {out_path}")
697
+
698
+
699
# Systems that may appear on either side of a --pairs "A->B" spec, keyed by
# the name written in the spec. Each value is a 4-tuple:
#   (archive filename, sampler function, display name used in the split
#    string, integrity-check function or None when no check is needed).
SYSTEMS = {
    "HDFS": (
        "HDFS_v1.zip",
        sample_hdfs_stratified,
        "HDFS",
        None,
    ),
    "BGL": (
        "BGL.zip",
        sample_bgl_stratified,
        "BGL",
        None,
    ),
    "OpenStack": (
        "OpenStack.tar.gz",
        sample_openstack,
        "OpenStack",
        None,
    ),
    "Thunderbird": (
        "Thunderbird.tar.gz",
        sample_thunderbird,
        "thunderbird_first20M",
        check_thunderbird,
    ),
    "Spirit": (
        "spirit2.gz",
        sample_spirit,
        "spirit_first20M",
        check_spirit,
    ),
}
712
+
713
+
714
def run_named_pairs(pairs_spec, scorer, seed, index_size, query_size, out_path,
                    normalization="max", per_query_rows=None):
    """Evaluate every "A->B" transfer pair named in ``pairs_spec``.

    ``pairs_spec`` is a comma-separated list such as
    "BGL->Spirit,HDFS->Spirit". For each pair, system A is sampled with
    ``index_size`` and system B with ``query_size`` (both using ``seed``),
    the pair is handed to run_transfer, and the resulting summary rows are
    collected. Malformed specs, unknown systems and pairs with an empty
    sample are reported on stdout and skipped.

    ``scorer``, ``normalization`` and ``per_query_rows`` are forwarded to
    run_transfer unchanged. All collected rows are appended to ``out_path``
    via append_transfer_rows and also returned. This is the entry point
    behind --pairs; main()'s default pairs do not go through it.
    """
    raw_dir = pathlib.Path("data/raw/loghub_raw")
    collected_rows = []
    # Keyed by (system, size, seed) so a system reused across pairs with the
    # same sample size is only sampled (or reported as skipped) once.
    sample_cache = {}

    def sample_system(name, size):
        cache_key = (name, size, seed)
        if cache_key in sample_cache:
            return sample_cache[cache_key]

        archive_name, sampler, _display, integrity_check = SYSTEMS[name]
        archive = raw_dir / archive_name
        sessions = []
        if archive.exists():
            # A truthy result from the check is the reason to skip.
            skip_reason = integrity_check(archive) if integrity_check else None
            if skip_reason:
                print(f"Skipping {name}: {skip_reason}")
            else:
                print(f"Sampling {name} sessions (size={size}, seed={seed})...")
                sessions = sampler(archive, sample_size=size, seed=seed)
                label_counts = Counter(label for _, _, label in sessions)
                print(
                    f"{name} class counts: {label_counts.get('Anomaly', 0)} Anomaly / "
                    f"{label_counts.get('Normal', 0)} Normal"
                )
        else:
            print(f"Skipping {name}: {archive} is missing. Run fetch_datasets.py first.")

        # Empty results are cached too, so a skip is not re-reported.
        sample_cache[cache_key] = sessions
        return sessions

    for raw_spec in pairs_spec.split(","):
        pair = raw_spec.strip()
        if not pair:
            continue

        source, arrow, target = pair.partition("->")
        if not arrow:
            print(f"Skipping malformed pair spec '{pair}' (expected 'A->B').")
            continue
        source = source.strip()
        target = target.strip()

        if not (source in SYSTEMS and target in SYSTEMS):
            known = ", ".join(SYSTEMS)
            print(f"Skipping pair '{pair}': unknown system (known: {known}).")
            continue

        # Both sides are sampled before the emptiness check so each system's
        # own skip/sampling message is printed.
        index_data = sample_system(source, index_size)
        query_data = sample_system(target, query_size)
        if not index_data or not query_data:
            print(f"Skipping pair '{pair}': empty index or query sample.")
            continue

        source_display = SYSTEMS[source][2]
        target_display = SYSTEMS[target][2]
        pair_name = f"{source_display}->{target_display}[seed{seed}]"
        pair_rows = run_transfer(
            index_data,
            query_data,
            pair_name,
            scorer=scorer,
            normalization=normalization,
            per_query_rows=per_query_rows,
        )
        collected_rows.extend(pair_rows)

    append_transfer_rows(collected_rows, out_path)
    return collected_rows
770
+
771
+
772
def main() -> None:
    """Command-line entry point for the cross-system transfer evaluation.

    Parses the CLI options, seeds the ``random`` module with ``--seed``, and
    then either:

    * delegates to run_named_pairs when ``--pairs`` is given, or
    * runs the two default pairs (HDFS->OpenStack and
      BGL->thunderbird_first20M), skipping a pair with a message when one of
      its archives is missing or check_thunderbird reports a reason to skip,
      and appends the collected rows to ``--out``.

    ``--scorer`` and ``--normalization`` are forwarded to run_transfer on
    both paths.
    """
    parser = argparse.ArgumentParser(description="Cross-system transfer evaluation (T2-b)")
    parser.add_argument("--scorer", choices=["ses", "mdl", "surprisal"], default="ses")
    parser.add_argument("--normalization", choices=["max", "min", "sqrt", "target"], default="max")
    parser.add_argument("--index-size", type=int, default=800,
                        help="stratified sessions to index from system A")
    parser.add_argument("--query-size", type=int, default=200,
                        help="stratified sessions to query from system B")
    parser.add_argument("--pairs", default=None,
                        help="comma-separated 'A->B' pairs to run instead of the "
                             "default HDFS->OpenStack and BGL->Thunderbird pairs, "
                             "e.g. 'BGL->Spirit,HDFS->Spirit'")
    parser.add_argument("--seed", type=int, default=42,
                        help="sampling seed threaded into both samplers")
    parser.add_argument("--out", default="reports/transfer_metrics.csv",
                        help="CSV path to append metric rows to")
    args = parser.parse_args()

    random.seed(args.seed)

    if args.pairs:
        run_named_pairs(
            args.pairs, args.scorer, args.seed,
            normalization=args.normalization,
            index_size=args.index_size, query_size=args.query_size, out_path=args.out,
        )
        return

    raw_dir = pathlib.Path("data/raw/loghub_raw")
    hdfs_zip = raw_dir / "HDFS_v1.zip"
    bgl_zip = raw_dir / "BGL.zip"
    openstack_tar = raw_dir / "OpenStack.tar.gz"
    thunderbird_tar = raw_dir / "Thunderbird.tar.gz"

    all_rows = []

    # Pair 1: HDFS -> OpenStack
    if not hdfs_zip.exists():
        print(f"Skipping HDFS->OpenStack: {hdfs_zip} is missing. Run fetch_datasets.py first.")
    elif not openstack_tar.exists():
        print(f"Skipping HDFS->OpenStack: {openstack_tar} is missing. Run fetch_datasets.py first.")
    else:
        print("Sampling HDFS sessions (index set)...")
        hdfs_index = sample_hdfs_stratified(hdfs_zip, sample_size=args.index_size, seed=args.seed)
        print("Sampling OpenStack sessions (query set)...")
        openstack_query = sample_openstack(openstack_tar, sample_size=args.query_size, seed=args.seed)
        # Forward --normalization here as well; previously the flag was
        # parsed but silently ignored unless --pairs was used.
        # NOTE(review): assumes run_transfer's own default normalization is
        # "max" (the CLI default), so default invocations are unaffected.
        all_rows.extend(
            run_transfer(hdfs_index, openstack_query, "HDFS->OpenStack",
                         scorer=args.scorer, normalization=args.normalization)
        )

    # Pair 2: BGL -> Thunderbird (first 20M lines, see THUNDERBIRD_LINE_CAP)
    tbird_skip = check_thunderbird(thunderbird_tar)
    if not bgl_zip.exists():
        print(f"Skipping BGL->Thunderbird: {bgl_zip} is missing. Run fetch_datasets.py first.")
    elif tbird_skip:
        print(f"Skipping BGL->Thunderbird: {tbird_skip}")
    else:
        print("Sampling BGL sessions (index set)...")
        bgl_index = sample_bgl_stratified(bgl_zip, sample_size=args.index_size, seed=args.seed)
        print("Sampling Thunderbird sessions (query set, first 20M lines)...")
        tbird_query = sample_thunderbird(thunderbird_tar, sample_size=args.query_size, seed=args.seed)
        if tbird_query:
            all_rows.extend(
                run_transfer(bgl_index, tbird_query, "BGL->thunderbird_first20M",
                             scorer=args.scorer, normalization=args.normalization)
            )
        else:
            print("Skipping BGL->Thunderbird: no Thunderbird sessions sampled.")

    append_transfer_rows(all_rows, args.out)
841
+
842
+
843
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()