structuremappingmemory 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125) hide show
  1. sma/__init__.py +5 -0
  2. sma/__main__.py +5 -0
  3. sma/agent/__init__.py +5 -0
  4. sma/agent/adapter_draft.py +217 -0
  5. sma/agent/api.py +67 -0
  6. sma/agent/comparison.py +591 -0
  7. sma/agent/llm.py +280 -0
  8. sma/agent/policies.py +21 -0
  9. sma/agent/service.py +95 -0
  10. sma/cli.py +65 -0
  11. sma/encoders/__init__.py +38 -0
  12. sma/encoders/agentobs.py +27 -0
  13. sma/encoders/base.py +23 -0
  14. sma/encoders/code_treesitter.py +64 -0
  15. sma/encoders/coverage.py +80 -0
  16. sma/encoders/draft_adapter.py +183 -0
  17. sma/encoders/healthcare.py +207 -0
  18. sma/encoders/logs_drain.py +142 -0
  19. sma/encoders/prose_tier1.py +57 -0
  20. sma/encoders/structured.py +57 -0
  21. sma/encoders/traces.py +45 -0
  22. sma/eval/__init__.py +2 -0
  23. sma/eval/agentic/__init__.py +35 -0
  24. sma/eval/agentic/arms/__init__.py +0 -0
  25. sma/eval/agentic/arms/cyber.py +48 -0
  26. sma/eval/agentic/arms/discovery.py +35 -0
  27. sma/eval/agentic/arms/finance.py +38 -0
  28. sma/eval/agentic/arms/legal.py +74 -0
  29. sma/eval/agentic/arms/medicine.py +45 -0
  30. sma/eval/agentic/harness.py +275 -0
  31. sma/eval/agentic/memories.py +308 -0
  32. sma/eval/agentic/metrics.py +82 -0
  33. sma/eval/agentic_qa/__init__.py +27 -0
  34. sma/eval/agentic_qa/agent.py +383 -0
  35. sma/eval/agentic_qa/metrics.py +239 -0
  36. sma/eval/agentic_qa/pools.py +197 -0
  37. sma/eval/arn.py +65 -0
  38. sma/eval/baselines/__init__.py +6 -0
  39. sma/eval/baselines/bge_dense.py +54 -0
  40. sma/eval/baselines/bm25.py +18 -0
  41. sma/eval/baselines/dense.py +42 -0
  42. sma/eval/baselines/hipporag.py +235 -0
  43. sma/eval/baselines/hybrid_rrf.py +30 -0
  44. sma/eval/baselines/longcontext_llm.py +124 -0
  45. sma/eval/baselines/rerank.py +41 -0
  46. sma/eval/baselines/splade.py +77 -0
  47. sma/eval/baselines/wl_kernel.py +163 -0
  48. sma/eval/bugsinpy.py +358 -0
  49. sma/eval/bugsinpy_families.py +164 -0
  50. sma/eval/crossdomain.py +89 -0
  51. sma/eval/diabetes.py +61 -0
  52. sma/eval/drift_env.py +26 -0
  53. sma/eval/drift_metrics.py +24 -0
  54. sma/eval/family_labels.py +167 -0
  55. sma/eval/fraud_elliptic/__init__.py +29 -0
  56. sma/eval/fraud_elliptic/encoder.py +279 -0
  57. sma/eval/fraud_elliptic/eval.py +269 -0
  58. sma/eval/fraud_elliptic/test_encoder.py +123 -0
  59. sma/eval/ieee_cis.py +66 -0
  60. sma/eval/loghub.py +16 -0
  61. sma/eval/loghub_eval.py +480 -0
  62. sma/eval/longmemeval.py +51 -0
  63. sma/eval/memory_backends/__init__.py +2 -0
  64. sma/eval/memory_backends/base.py +22 -0
  65. sma/eval/memory_backends/context_only.py +14 -0
  66. sma/eval/memory_backends/rag_notes.py +17 -0
  67. sma/eval/memory_backends/shared_llm.py +30 -0
  68. sma/eval/memory_backends/sma_memory.py +54 -0
  69. sma/eval/memory_backends/zep_graphiti.py +33 -0
  70. sma/eval/metrics.py +32 -0
  71. sma/eval/ontology_bench.py +219 -0
  72. sma/eval/report.py +573 -0
  73. sma/eval/ssb_eval.py +216 -0
  74. sma/eval/ssb_generator.py +116 -0
  75. sma/eval/stats.py +108 -0
  76. sma/eval/transfer_eval.py +844 -0
  77. sma/index/__init__.py +15 -0
  78. sma/index/ann.py +21 -0
  79. sma/index/content_vectors.py +60 -0
  80. sma/index/inverted.py +63 -0
  81. sma/index/macfac.py +174 -0
  82. sma/ir/__init__.py +22 -0
  83. sma/ir/canon.py +106 -0
  84. sma/ir/schema.py +165 -0
  85. sma/ir/sexpr.py +86 -0
  86. sma/ir/signatures.py +76 -0
  87. sma/match/__init__.py +20 -0
  88. sma/match/conflicts.py +46 -0
  89. sma/match/engine.py +60 -0
  90. sma/match/explain.py +59 -0
  91. sma/match/infer.py +54 -0
  92. sma/match/kernels.py +54 -0
  93. sma/match/mdl.py +30 -0
  94. sma/match/merge_cpsat.py +77 -0
  95. sma/match/merge_greedy.py +15 -0
  96. sma/match/mh.py +177 -0
  97. sma/match/ses.py +84 -0
  98. sma/match/types.py +115 -0
  99. sma/match/verifier.py +27 -0
  100. sma/ontology/__init__.py +45 -0
  101. sma/ontology/attack.py +134 -0
  102. sma/ontology/cpc.py +69 -0
  103. sma/ontology/graph.py +58 -0
  104. sma/ontology/loader.py +262 -0
  105. sma/ontology/mitre_xml.py +67 -0
  106. sma/ontology/mount.py +101 -0
  107. sma/ontology/rdf_loader.py +75 -0
  108. sma/ontology/registry.py +115 -0
  109. sma/ontology/router.py +69 -0
  110. sma/ontology/usgaap.py +73 -0
  111. sma/sage/__init__.py +6 -0
  112. sma/sage/assimilate.py +12 -0
  113. sma/sage/pools.py +105 -0
  114. sma/sage/probabilities.py +10 -0
  115. sma/store/__init__.py +6 -0
  116. sma/store/lmdb_store.py +78 -0
  117. sma/store/registry.py +26 -0
  118. sma/store/wal.py +26 -0
  119. sma/ui/app.py +642 -0
  120. structuremappingmemory-1.0.0.dist-info/METADATA +190 -0
  121. structuremappingmemory-1.0.0.dist-info/RECORD +125 -0
  122. structuremappingmemory-1.0.0.dist-info/WHEEL +5 -0
  123. structuremappingmemory-1.0.0.dist-info/entry_points.txt +2 -0
  124. structuremappingmemory-1.0.0.dist-info/licenses/LICENSE +204 -0
  125. structuremappingmemory-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,124 @@
1
+ """B6 long-context LLM baseline (blueprint section 8.1 B6).
2
+
3
+ Controls for "maybe you don't need retrieval at all": stuff the query session
4
+ plus its top-20 BM25 candidate precedents (labels included) into one prompt
5
+ and ask deepseek-chat (temperature 0, max_tokens 10) to label the query.
6
+
7
+ API key comes from SMA_DEEPSEEK_API_KEY or the repo .env, via the same lookup
8
+ the agent layer uses (sma.agent.llm._env_key). Transport/API errors and
9
+ unparseable replies get exactly one retry; if the retry also fails, the row is marked failed.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import time
15
+
16
+ from sma.agent.llm import DEEPSEEK_BASE_URL, DEEPSEEK_KEY_ENV, DEEPSEEK_MODEL, _env_key
17
+
18
# Character caps applied while assembling the prompt: each precedent text is
# cut to CANDIDATE_CHARS and the query text to QUERY_CHARS.
CANDIDATE_CHARS = 800
QUERY_CHARS = 1600

# System message sent with every request; it restricts the reply vocabulary
# to the two labels the reply parser recognizes.
SYSTEM_PROMPT = (
    "You are an incident triage assistant. You label log sessions as Anomaly or "
    "Normal by analogy to labeled precedent sessions. Reply with exactly one word: "
    "Anomaly or Normal."
)


def build_prompt(query_text: str, precedents: list[tuple[str, str]]) -> str:
    """Assemble the user message for one query session.

    Args:
        query_text: raw text of the session to label; truncated to QUERY_CHARS.
        precedents: ``(label, session_text)`` pairs, already ranked most-similar
            first; every session text is truncated to CANDIDATE_CHARS.

    Returns:
        A newline-joined prompt made of three parts: the query section, a
        numbered list of labeled precedents, and a closing instruction to
        answer with one word.
    """
    header = [
        "Query session (label unknown):",
        query_text[:QUERY_CHARS],
        "",
        "Labeled precedent sessions (retrieved by lexical similarity, most similar first):",
    ]
    numbered = [
        f"[{rank}] ({label}) {body[:CANDIDATE_CHARS]}"
        for rank, (label, body) in enumerate(precedents, start=1)
    ]
    footer = [
        "",
        "Based on these precedents, is the query session Anomaly or Normal? "
        "Answer with exactly one word.",
    ]
    return "\n".join(header + numbered + footer)
44
+
45
+
46
def parse_label(content: str) -> str | None:
    """Map a free-text model reply onto ``'Anomaly'`` or ``'Normal'``.

    Matching is case-insensitive and substring-based:
    - any word built on the stem "anomal" (anomaly, anomalous) counts as Anomaly;
    - "abnormal" also counts as Anomaly.

    When a reply mentions both labels, the one mentioned first wins.

    Returns:
        ``'Anomaly'``, ``'Normal'``, or None when neither label appears.
    """
    # "abnormal" contains "normal" as a substring. Fold it into the anomaly
    # stem first so that a reply like "Abnormal" is not read as Normal.
    lowered = content.lower().replace("abnormal", "anomal")
    first_a = lowered.find("anomal")
    first_n = lowered.find("normal")
    if first_a < 0 and first_n < 0:
        return None
    if first_n < 0:
        return "Anomaly"
    if first_a < 0:
        return "Normal"
    # Ambiguous reply that mentions both labels: the earlier mention wins.
    return "Anomaly" if first_a < first_n else "Normal"
59
+
60
+
61
class LongContextDeepSeek:
    """Long-context LLM classifier that sends one chat-completion request per row.

    Across ``classify`` calls it accumulates:
    - the number of successful HTTP round-trips;
    - the prompt/completion token totals reported by the API;
    - one ``"<query_id>: <reason>"`` message per failed row.
    """

    def __init__(
        self,
        model: str = DEEPSEEK_MODEL,
        api_key: str | None = None,
        timeout: float = 60.0,
        retry_sleep: float = 2.0,
    ):
        """Configure the client.

        Args:
            model: model name sent in each request body.
            api_key: bearer token. When falsy, it is looked up via
                ``_env_key(DEEPSEEK_KEY_ENV)``.
            timeout: per-request timeout in seconds, passed to ``httpx.post``.
            retry_sleep: pause in seconds before the single retry.
        """
        self.model = model
        self.api_key = api_key or _env_key(DEEPSEEK_KEY_ENV)
        self.timeout = timeout
        self.retry_sleep = retry_sleep
        self.calls = 0  # requests that returned a 2xx response with a JSON body
        self.failures: list[str] = []  # "<query_id>: <reason>" per failed row
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0

    def _call_once(self, prompt: str) -> str:
        """POST one chat-completion request and return the stripped reply text.

        Raises:
            httpx errors on transport failure or a non-2xx status.
            An indexing error if the payload has no ``choices[0].message.content``.
        """
        import httpx

        response = httpx.post(
            f"{DEEPSEEK_BASE_URL}/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                "temperature": 0,
                "max_tokens": 10,
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        payload = response.json()
        self.calls += 1
        # Treat a missing or null "usage" object, or null counters inside it,
        # as zero. Otherwise a successful reply would raise here and
        # classify() would record it as a row failure.
        usage = payload.get("usage") or {}
        self.total_prompt_tokens += int(usage.get("prompt_tokens") or 0)
        self.total_completion_tokens += int(usage.get("completion_tokens") or 0)
        return (payload["choices"][0]["message"]["content"] or "").strip()

    def classify(
        self, query_id: str, query_text: str, precedents: list[tuple[str, str]]
    ) -> str | None:
        """Return 'Anomaly'/'Normal', or None when the row failed (after one retry).

        Args:
            query_id: identifier used only in failure messages.
            query_text: the session to label.
            precedents: ranked ``(label, session_text)`` pairs for the prompt.

        A transport/API error and an unparseable reply are both retried once,
        after ``retry_sleep`` seconds. If the second attempt also fails, the
        last error is appended to ``self.failures``. A missing API key fails
        the row immediately, without sending any request.
        """
        if not self.api_key:
            self.failures.append(f"{query_id}: {DEEPSEEK_KEY_ENV} not set")
            return None
        prompt = build_prompt(query_text, precedents)
        last_error = ""
        for attempt in range(2):  # one initial call + one retry
            try:
                content = self._call_once(prompt)
                label = parse_label(content)
                if label is not None:
                    return label
                last_error = f"unparseable reply: {content!r}"
            except Exception as exc:  # noqa: BLE001 - any transport/API error retries once
                last_error = f"{type(exc).__name__}: {exc}"
            if attempt == 0:
                time.sleep(self.retry_sleep)
        self.failures.append(f"{query_id}: {last_error}")
        return None
@@ -0,0 +1,41 @@
1
+ """Cross-encoder reranking (cross-encoder/ms-marco-MiniLM-L-6-v2).
2
+
3
+ Reranks a candidate pool (here: top-20 of Hybrid-RRF) and returns the top-k.
4
+ Raw model outputs are logits; we pass them through a sigmoid so the scores
5
+ are positive and usable by the protocol's weighted top-5 label vote (a
6
+ negative weight would flip a vote, which is not what 'weighted' means there).
7
+ Candidate texts are truncated to ``max_chars`` before pairing; the tokenizer
8
+ truncates to 512 wordpieces anyway, so this only bounds tokenizer cost.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import numpy as np
14
+
15
MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"


class CrossEncoderReranker:
    """Rescore (query, candidate) pairs with a CPU cross-encoder and keep the best."""

    def __init__(self, model_name: str = MODEL_NAME, max_chars: int = 2000):
        """Load the cross-encoder on CPU with a 512-token input limit.

        Args:
            model_name: model id passed to ``CrossEncoder``.
            max_chars: character cap applied to the query and to every
                candidate text before pairing.
        """
        # Imported inside __init__, not at module level.
        from sentence_transformers import CrossEncoder

        self.model = CrossEncoder(model_name, device="cpu", max_length=512)
        self.max_chars = max_chars

    def rerank(
        self,
        query_text: str,
        candidates: list[tuple[str, str]],
        top_k: int = 10,
    ) -> list[tuple[str, float]]:
        """Return the ``top_k`` candidates as ``(doc_id, score)``, best first.

        Args:
            query_text: query to pair with every candidate.
            candidates: ``(doc_id, text)`` pairs to rescore.
            top_k: maximum number of results to return.

        Scores are the sigmoid of the model logits. Ties are broken by
        ascending doc id. An empty candidate list returns an empty list.
        """
        if not candidates:
            return []
        limit = self.max_chars
        clipped_query = query_text[:limit]
        doc_ids = [doc_id for doc_id, _ in candidates]
        pairs = [(clipped_query, body[:limit]) for _, body in candidates]
        raw = np.asarray(
            self.model.predict(pairs, show_progress_bar=False), dtype=np.float64
        )
        # The logistic squash is monotone, so the ranking is unchanged, but
        # every score becomes a non-negative weight in [0, 1].
        probs = 1.0 / (1.0 + np.exp(-raw))
        scored = [(doc_id, float(p)) for doc_id, p in zip(doc_ids, probs)]
        scored.sort(key=lambda item: (-item[1], item[0]))
        return scored[:top_k]
@@ -0,0 +1,77 @@
1
+ """SPLADE learned sparse retrieval (naver/splade-cocondenser-ensembledistil).
2
+
3
+ CPU-only. Standard SPLADE document/query representation:
4
+ rep = max over token positions of log(1 + relu(MLM logits)),
5
+ masked by attention. Documents are batch-encoded once into a scipy CSR
6
+ matrix; query scoring is a sparse dot product. Per-query latency includes
7
+ the query forward pass (same convention as the other neural baselines).
8
+
9
+ Inputs are truncated to ``max_length`` wordpieces (default 256) to keep the
10
+ CPU forward pass tractable on long log sessions; the same truncation applies
11
+ to documents and queries, so no method gets privileged context.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import numpy as np
17
+ from scipy import sparse
18
+
19
MODEL_NAME = "naver/splade-cocondenser-ensembledistil"


class SpladeRetriever:
    """SPLADE learned-sparse retriever over an in-memory document set.

    ``build`` encodes every document once into a sparse matrix. ``retrieve``
    encodes the query and ranks documents by sparse dot product.
    """

    def __init__(
        self,
        model_name: str = MODEL_NAME,
        max_length: int = 256,
        batch_size: int = 8,
    ):
        """Load the tokenizer and the masked-LM model, and switch it to eval mode.

        Args:
            model_name: model id/path passed to ``from_pretrained``.
            max_length: tokenizer truncation length, applied to both
                documents and queries.
            batch_size: number of texts per forward pass in ``encode``.
        """
        # torch/transformers are imported inside __init__, not at module level.
        import torch
        from transformers import AutoModelForMaskedLM, AutoTokenizer

        self.torch = torch
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.model.eval()
        self.max_length = max_length
        self.batch_size = batch_size
        self.doc_ids: list[str] = []
        self.doc_matrix: sparse.csr_matrix | None = None

    def _encode_batch(self, texts: list[str]) -> np.ndarray:
        """Return one dense SPLADE weight row per text, as a (B, V) array."""
        torch = self.torch
        tokens = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        with torch.no_grad():
            logits = self.model(**tokens).logits  # (B, T, V)
        # log(1 + relu(logits)) is >= 0, so zeroing the padded positions with
        # the attention mask removes them from the max-pool over tokens.
        weights = torch.log1p(torch.relu(logits))
        mask = tokens["attention_mask"].unsqueeze(-1)
        reps = (weights * mask).max(dim=1).values  # (B, V)
        return reps.numpy()

    def encode(self, texts: list[str]) -> sparse.csr_matrix:
        """Encode ``texts`` in batches of ``batch_size``, stacking the rows.

        An empty input returns a 0x0 matrix.
        """
        rows = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start : start + self.batch_size]
            rows.append(sparse.csr_matrix(self._encode_batch(batch)))
        return sparse.vstack(rows) if rows else sparse.csr_matrix((0, 0))

    def build(self, documents: list[tuple[str, str]]) -> None:
        """Index ``(doc_id, text)`` pairs, replacing any previous index."""
        self.doc_ids = [doc_id for doc_id, _ in documents]
        self.doc_matrix = self.encode([text for _, text in documents])

    def retrieve(self, query_text: str, k: int = 10) -> list[tuple[str, float]]:
        """Return up to ``k`` ``(doc_id, score)`` pairs, highest dot product first.

        Ties are broken by ascending doc id. Returns an empty list when nothing
        has been indexed.
        """
        # Return early when there is no index, or when the index was built
        # from zero documents. In the second case encode() returned a 0x0
        # matrix, and multiplying it by the query would raise a shape error.
        # Returning early also skips the query forward pass.
        if self.doc_matrix is None or self.doc_matrix.shape[0] == 0:
            return []
        q = sparse.csr_matrix(self._encode_batch([query_text]))
        scores = np.asarray((self.doc_matrix @ q.T).todense()).ravel()
        ranked = sorted(
            zip(self.doc_ids, map(float, scores)), key=lambda row: (-row[1], row[0])
        )
        return ranked[:k]
@@ -0,0 +1,163 @@
1
+ """Weisfeiler-Leman graph-kernel control over SMA's own Tier-0 cases.
2
+
3
+ The internal control for the ladder: does *generic* graph similarity computed
4
+ on the SAME deterministic extraction (sma.encoders get_encoder("logs")) match
5
+ structure-mapping retrieval? If yes, SMA's edge is the extraction; if no, the
6
+ edge is the mapping mathematics.
7
+
8
+ Graph construction (per case):
9
+ - nodes = unique expressions (keyed by canonical s-expression) plus unique
10
+ entities (keyed by (name, type));
11
+ - edges = statement -> argument, position-annotated.
12
+
13
+ Node labels (iteration 0):
14
+ - statements: the functor (exactly what SMA's own content vectors see);
15
+ - entities whose names are arbitrary per-session identifiers (types
16
+ "event", "session") are labeled by TYPE ONLY. This is not a kindness, it
17
+ is a necessity: the session entity touches nearly every statement, so a
18
+ case-unique session name would make every refined label case-unique after
19
+ one iteration and the kernel would degenerate to zero similarity.
20
+ - all other entities (components, event_type tokens, integer counts) keep
21
+ "type:name" -- their names are shared vocabulary, i.e. real content.
22
+
23
+ Refinement: 2 WL iterations; new label = hash(own label, position-sorted child
24
+ labels, sorted parent labels). Similarity = cosine over the concatenated label
25
+ histograms of iterations 0..2 (Shervashidze et al. 2011, WL subtree kernel,
26
+ normalized).
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import hashlib
32
+ from collections import Counter
33
+
34
+ import numpy as np
35
+ from scipy import sparse
36
+
37
+ from sma.ir.schema import Case, Entity, Statement
38
+ from sma.ir.sexpr import dumps_statement
39
+
40
+ # Entity types whose names are arbitrary per-session identifiers.
41
+ _ID_LIKE_TYPES = frozenset({"event", "session"})
42
+
43
+ WL_ITERATIONS = 2
44
+
45
+
46
+ def _entity_label(ent: Entity) -> str:
47
+ if ent.type in _ID_LIKE_TYPES:
48
+ return f"ent:{ent.type}"
49
+ return f"ent:{ent.type}:{ent.name}"
50
+
51
+
52
+ def _hash_label(payload: str, memo: dict[str, str]) -> str:
53
+ cached = memo.get(payload)
54
+ if cached is None:
55
+ cached = hashlib.blake2b(payload.encode("utf-8"), digest_size=8).hexdigest()
56
+ memo[payload] = cached
57
+ return cached
58
+
59
+
60
def wl_histogram(case: Case, iterations: int = WL_ITERATIONS) -> Counter[str]:
    """Return the concatenated WL label histogram (iterations 0..n) for one case.

    Nodes are built as follows:
    - statements are deduplicated by ``dumps_statement``;
    - entities are deduplicated by ``(name, type)``;
    - each statement gets one position-tagged edge per argument.

    Histogram keys are ``"wl<i>:<label>"``. Iteration 0 uses the raw node
    labels. Each later round hashes a node's own label together with its
    sorted child and parent labels.
    """
    # The graph is stored as parallel per-node lists indexed by node id.
    labels0: list[str] = []
    out_edges: list[list[tuple[int, int]]] = []  # node -> [(arg position, child id)]
    in_edges: list[list[int]] = []  # node -> [parent id, one entry per referencing arg]
    statement_ids: dict[str, int] = {}
    entity_ids: dict[tuple[str, str], int] = {}

    def new_node(label: str) -> int:
        node_id = len(labels0)
        labels0.append(label)
        out_edges.append([])
        in_edges.append([])
        return node_id

    def intern(node: Statement | Entity) -> int:
        """Return the node id for ``node``, creating it (and its subtree) if new."""
        if isinstance(node, Entity):
            ent_key = (node.name, node.type)
            if ent_key not in entity_ids:
                entity_ids[ent_key] = new_node(_entity_label(node))
            return entity_ids[ent_key]
        stmt_key = dumps_statement(node)
        if stmt_key in statement_ids:
            return statement_ids[stmt_key]
        node_id = new_node(f"f:{node.functor}")
        statement_ids[stmt_key] = node_id
        for position, argument in enumerate(node.args):
            child_id = intern(argument)
            out_edges[node_id].append((position, child_id))
            in_edges[child_id].append(node_id)
        return node_id

    for statement in case.statements:
        intern(statement)

    memo: dict[str, str] = {}
    histogram: Counter[str] = Counter(f"wl0:{label}" for label in labels0)
    current = list(labels0)
    for round_no in range(1, iterations + 1):
        refined = []
        for node_id, own in enumerate(current):
            child_sig = ",".join(
                sorted(f"c{pos}:{current[c]}" for pos, c in out_edges[node_id])
            )
            parent_sig = ",".join(sorted(f"p:{current[p]}" for p in in_edges[node_id]))
            refined.append(_hash_label(f"{own}|{child_sig}|{parent_sig}", memo))
        current = refined
        histogram.update(f"wl{round_no}:{label}" for label in current)
    return histogram
115
+
116
+
117
class WLKernelRetriever:
    """Cosine-normalized WL subtree kernel retrieval over encoded cases."""

    def __init__(self, iterations: int = WL_ITERATIONS):
        self.iterations = iterations
        self.doc_ids: list[str] = []
        self.feature_index: dict[str, int] = {}  # feature key -> column
        self.doc_matrix: sparse.csr_matrix | None = None  # rows L2-normalized

    def build(self, cases: list[Case]) -> None:
        """Index ``cases`` as L2-normalized WL histogram rows.

        Feature columns are assigned in first-seen order across the corpus.
        """
        histograms = [wl_histogram(c, self.iterations) for c in cases]
        self.doc_ids = [c.case_id for c in cases]
        vocab: dict[str, int] = {}
        for hist in histograms:
            for feat in hist:
                vocab.setdefault(feat, len(vocab))
        self.feature_index = vocab

        entries = [
            (row, vocab[feat], float(count))
            for row, hist in enumerate(histograms)
            for feat, count in hist.items()
        ]
        row_ids = [r for r, _, _ in entries]
        col_ids = [c for _, c, _ in entries]
        values = [v for _, _, v in entries]
        # Keep at least one column so the matrix shape stays valid for an
        # empty vocabulary.
        counts = sparse.csr_matrix(
            (values, (row_ids, col_ids)),
            shape=(len(histograms), max(len(vocab), 1)),
        )
        row_norms = np.asarray(np.sqrt(counts.multiply(counts).sum(axis=1))).ravel()
        row_norms[row_norms == 0] = 1.0  # empty histograms: avoid dividing by zero
        self.doc_matrix = sparse.diags(1.0 / row_norms) @ counts

    def retrieve(self, query_case: Case, k: int = 10) -> list[tuple[str, float]]:
        """Return up to ``k`` ``(case_id, cosine)`` pairs, best first.

        Ties are broken by ascending case id. Returns an empty list before
        ``build`` has been called.
        """
        if self.doc_matrix is None:
            return []
        hist = wl_histogram(query_case, self.iterations)
        # Normalize by the norm of the FULL query histogram, including features
        # absent from the index vocabulary, so the cosine is not inflated.
        q_norm = float(np.sqrt(sum(v * v for v in hist.values()))) or 1.0
        query_vec = np.zeros(self.doc_matrix.shape[1])
        known = [
            (self.feature_index[feat], count)
            for feat, count in hist.items()
            if feat in self.feature_index
        ]
        for col, count in known:
            query_vec[col] = count / q_norm
        scores = self.doc_matrix @ query_vec
        ranked = list(zip(self.doc_ids, map(float, scores)))
        ranked.sort(key=lambda row: (-row[1], row[0]))
        return ranked[:k]