topolm 0.9.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
topolm-0.9.1/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 JadeyGraham96
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
topolm-0.9.1/PKG-INFO ADDED
@@ -0,0 +1,60 @@
1
+ Metadata-Version: 2.4
2
+ Name: topolm
3
+ Version: 0.9.1
4
+ Summary: Topology-native explainable language model prototype powered by Topologist
5
+ Author: Robert McMenemy
6
+ Requires-Python: >=3.10
7
+ Description-Content-Type: text/markdown
8
+ License-File: LICENSE
9
+ Requires-Dist: numpy>=1.23
10
+ Requires-Dist: networkx>=3.0
11
+ Requires-Dist: topologist>=0.4.0
12
+ Provides-Extra: ml
13
+ Requires-Dist: scikit-learn>=1.3; extra == "ml"
14
+ Requires-Dist: torch>=2.0; extra == "ml"
15
+ Provides-Extra: hf
16
+ Requires-Dist: datasets>=2.18; extra == "hf"
17
+ Provides-Extra: dev
18
+ Requires-Dist: pytest>=7.0; extra == "dev"
19
+ Requires-Dist: ruff>=0.5; extra == "dev"
20
+ Requires-Dist: build>=0.10; extra == "dev"
21
+ Requires-Dist: twine>=4.0; extra == "dev"
22
+ Dynamic: license-file
23
+
24
+ # TopoLM
25
+
26
+ **TopoLM** is a topology-native, explainable language model prototype powered by `topologist`.
27
+
28
+ ## Quick start
29
+
30
+ ```bash
31
+ pip install -e .
32
+ python examples/basic_demo.py
33
+ topolm demo
34
+ ```
35
+
36
+ ## API
37
+
38
+ ```python
39
+ from topolm import TopoLM, Config, load_hf_dataset
40
+
41
+ model = TopoLM(Config()).fit(corpus)
42
+ print(model.distribution("clarithromycin inhibits", top_k=5))
43
+ print(model.generate("cyp3a4 inhibition", decoding="beam"))
44
+
45
+ # training from a Hugging Face dataset
46
+ texts = load_hf_dataset("wikitext", split="train", text_field="text", sample_size=1000)
47
+ model = TopoLM(Config()).fit_texts(texts)
48
+ ```
49
+
50
+ ## Layout
51
+
52
+ ```text
53
+ topolm/
54
+ __init__.py
55
+ config.py
56
+ core.py
57
+ cli.py
58
+ examples/
59
+ tests/
60
+ ```
topolm-0.9.1/README.md ADDED
@@ -0,0 +1,37 @@
1
+ # TopoLM
2
+
3
+ **TopoLM** is a topology-native, explainable language model prototype powered by `topologist`.
4
+
5
+ ## Quick start
6
+
7
+ ```bash
8
+ pip install -e .
9
+ python examples/basic_demo.py
10
+ topolm demo
11
+ ```
12
+
13
+ ## API
14
+
15
+ ```python
16
+ from topolm import TopoLM, Config, load_hf_dataset
17
+
18
+ model = TopoLM(Config()).fit(corpus)
19
+ print(model.distribution("clarithromycin inhibits", top_k=5))
20
+ print(model.generate("cyp3a4 inhibition", decoding="beam"))
21
+
22
+ # training from a Hugging Face dataset
23
+ texts = load_hf_dataset("wikitext", split="train", text_field="text", sample_size=1000)
24
+ model = TopoLM(Config()).fit_texts(texts)
25
+ ```
26
+
27
+ ## Layout
28
+
29
+ ```text
30
+ topolm/
31
+ __init__.py
32
+ config.py
33
+ core.py
34
+ cli.py
35
+ examples/
36
+ tests/
37
+ ```
@@ -0,0 +1,27 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "topolm"
7
+ version = "0.9.1"
8
+ description = "Topology-native explainable language model prototype powered by Topologist"
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ authors = [{name = "Robert McMenemy"}]
12
+ dependencies = ["numpy>=1.23", "networkx>=3.0", "topologist>=0.4.0"]
13
+
14
+ [project.optional-dependencies]
15
+ ml = ["scikit-learn>=1.3", "torch>=2.0"]
16
+ hf = ["datasets>=2.18"]
17
+ dev = ["pytest>=7.0", "ruff>=0.5", "build>=0.10", "twine>=4.0"]
18
+
19
+ [project.scripts]
20
+ topolm = "topolm.cli:main"
21
+
22
+ [tool.setuptools.packages.find]
23
+ where = ["."]
24
+ include = ["topolm*"]
25
+
26
+ [tool.pytest.ini_options]
27
+ testpaths = ["tests"]
topolm-0.9.1/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,56 @@
1
+ from topolm import Config, TopoLM, load_hf_dataset
2
+
3
+ CORPUS = """
4
+ The cat sat on the mat.
5
+ The dog sat on the floor.
6
+ The attacker used CVE-2024-1234 to access the admin panel.
7
+ CYP3A4 inhibition increases drug exposure.
8
+ Clarithromycin inhibits CYP3A4.
9
+ Clarithromycin may increase simvastatin exposure.
10
+ """
11
+
12
+ def test_predict_domain_terms():
13
+ model = TopoLM(Config()).fit(CORPUS)
14
+ preds = model.distribution("clarithromycin inhibits", 5)
15
+ assert preds
16
+ assert preds[0].text == "cyp3a4"
17
+
18
+ def test_generation_is_bounded_and_fluentish():
19
+ model = TopoLM(Config()).fit(CORPUS)
20
+ out = model.generate("cyp3a4 inhibition", decoding="beam", max_units=16)
21
+ assert "CYP3A4" in out
22
+ assert len(out.split()) <= 20
23
+
24
+ def test_save_load(tmp_path):
25
+ model = TopoLM(Config()).fit(CORPUS)
26
+ path = model.save(tmp_path / "model")
27
+ loaded = TopoLM.load(path)
28
+ assert loaded.distribution("clarithromycin inhibits", 3)[0].text == "cyp3a4"
29
+
30
+
31
+ def test_save_load_round_trip_state(tmp_path):
32
+ model = TopoLM(Config()).fit(CORPUS)
33
+ path = model.save(tmp_path / "model_state")
34
+ loaded = TopoLM.load(path)
35
+ assert loaded.mem.unit_counts == model.mem.unit_counts
36
+ assert loaded.mem.phrase_counts == model.mem.phrase_counts
37
+ assert loaded.mem.edge_counts == model.mem.edge_counts
38
+ assert loaded.mem.domain_counts == model.mem.domain_counts
39
+
40
+
41
+ def test_fit_texts_helper():
42
+ model = TopoLM(Config()).fit_texts([s.strip() for s in CORPUS.splitlines() if s.strip()])
43
+ assert model.distribution("clarithromycin inhibits", 5)[0].text == "cyp3a4"
44
+
45
+
46
+ def test_hf_dataset_loader_simple(monkeypatch):
47
+ import pytest
48
+ datasets = pytest.importorskip("datasets")
49
+ from datasets import Dataset
50
+
51
+ def fake_load_dataset(name, split="train"):
52
+ return Dataset.from_dict({"text": ["Topologist test text."]})
53
+
54
+ monkeypatch.setattr("topolm.datasets.load_dataset", fake_load_dataset)
55
+ texts = load_hf_dataset("testset", split="train", text_field="text", sample_size=1)
56
+ assert texts == ["Topologist test text."]
@@ -0,0 +1,16 @@
1
+ from .config import Config
2
+ from .core import TopoLM, Tokenizer, Corpus, NGram, evaluate, eval_examples
3
+ from .datasets import hf_dataset_texts, load_hf_dataset
4
+
5
+ __all__ = [
6
+ "Config",
7
+ "TopoLM",
8
+ "Tokenizer",
9
+ "Corpus",
10
+ "NGram",
11
+ "evaluate",
12
+ "eval_examples",
13
+ "load_hf_dataset",
14
+ "hf_dataset_texts",
15
+ ]
16
+ __version__ = "0.9.1"
@@ -0,0 +1,36 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from .core import TopoLM
5
+ from .config import Config
6
+
7
+ DEMO = """
8
+ The cat sat on the mat.
9
+ The dog sat on the floor.
10
+ The attacker used CVE-2024-1234 to access the admin panel.
11
+ CYP3A4 inhibition increases drug exposure.
12
+ Clarithromycin inhibits CYP3A4.
13
+ Clarithromycin may increase simvastatin exposure.
14
+ """
15
+
16
+ def build_demo_model():
17
+ return TopoLM(Config()).fit(DEMO)
18
+
19
+ def main(argv=None):
20
+ parser = argparse.ArgumentParser(prog="topolm")
21
+ sub = parser.add_subparsers(dest="cmd")
22
+ sub.add_parser("demo")
23
+ p = sub.add_parser("predict"); p.add_argument("context")
24
+ g = sub.add_parser("generate"); g.add_argument("prompt"); g.add_argument("--decoding", default="beam", choices=["beam", "nucleus", "greedy"])
25
+ args = parser.parse_args(argv)
26
+ model = build_demo_model()
27
+ if args.cmd == "predict":
28
+ for pred in model.distribution(args.context, 5):
29
+ print(f"{pred.text}\t{pred.probability:.3f}\t{pred.score:.3f}")
30
+ elif args.cmd == "generate":
31
+ print(model.generate(args.prompt, decoding=args.decoding))
32
+ else:
33
+ print(model.generate("clarithromycin inhibits", decoding="beam"))
34
+
35
+ if __name__ == "__main__":
36
+ main()
@@ -0,0 +1,23 @@
1
+ from dataclasses import dataclass
2
+
3
+ PUNCT = {".", ",", ";", ":", "?", "!"}
4
+ BOUNDARY = {"<bos>", "<eos>"}
5
+ HUBS = {"the", "a", "an", "and", "or", "of", "to", "in", "on", "with", "when", ".", ","}
6
+
7
+ @dataclass
8
+ class Config:
9
+ dim: int = 1024
10
+ seed: int = 42
11
+ window: int = 8
12
+ phrase_lengths: tuple[int, ...] = (2, 3, 4, 5)
13
+ max_candidates: int = 96
14
+ inference_candidates: int = 48
15
+ prediction_cache_max: int = 4096
16
+ temperature: float = 0.75
17
+ max_runtime_seconds: float = 5.0
18
+ default_top_p: float = 0.88
19
+ default_beam_width: int = 4
20
+ fast_dev_mode: bool = True
21
+ max_reranker_sentences: int = 80
22
+ negatives_per_positive: int = 2
23
+ hub_penalty: float = 0.10
@@ -0,0 +1,832 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import math
6
+ import random
7
+ import re
8
+ import shutil
9
+ import time
10
+ from collections import Counter, defaultdict
11
+ from dataclasses import asdict, dataclass, field
12
+ from pathlib import Path
13
+ from typing import Any
14
+
15
+ import networkx as nx
16
+ import numpy as np
17
+
18
+ from .config import BOUNDARY, HUBS, PUNCT, Config
19
+
20
+ try:
21
+ from topologist import Topologist, TopologistConfig
22
+ except Exception: # pragma: no cover
23
+ Topologist = None
24
+ TopologistConfig = None
25
+
26
+
27
+ @dataclass
28
+ class Evidence:
29
+ candidate: str
30
+ source: str
31
+ reason: str
32
+ weight: float
33
+ confidence: float
34
+ unit_type: str = "word"
35
+ phrase_length: int = 0
36
+ relation: str | None = None
37
+ domain: str | None = None
38
+ path: list[str] = field(default_factory=list)
39
+
40
+
41
+ @dataclass
42
+ class Prediction:
43
+ text: str
44
+ unit_type: str
45
+ score: float
46
+ probability: float = 0.0
47
+ reasons: list[str] = field(default_factory=list)
48
+ breakdown: dict[str, float] = field(default_factory=dict)
49
+ evidence: list[Evidence] = field(default_factory=list)
50
+ paths: list[list[str]] = field(default_factory=list)
51
+
52
+
53
+ class FallbackTopo:
54
+ def __init__(self) -> None:
55
+ self.graph = nx.DiGraph()
56
+
57
+ def add_node(self, node: str, kind: str | None = None, **kwargs: Any) -> None:
58
+ meta = dict(kwargs)
59
+ if kind is not None:
60
+ meta["kind"] = kind
61
+ self.graph.add_node(node, **meta)
62
+
63
+ def add_edge(
64
+ self,
65
+ source: str,
66
+ relation: str,
67
+ target: str,
68
+ confidence: float | None = None,
69
+ metadata: dict[str, Any] | None = None,
70
+ **kwargs: Any,
71
+ ) -> None:
72
+ meta = dict(metadata or {})
73
+ if confidence is not None:
74
+ meta["confidence"] = confidence
75
+ meta.update(kwargs)
76
+ self.graph.add_edge(source, target, relation=relation, metadata=meta)
77
+
78
+ def update_global_state(self) -> None:
79
+ return None
80
+
81
+
82
+ class Tokenizer:
83
+ TOKEN_RE = re.compile(
84
+ r"<bos>|<eos>|CVE-\d{4}-\d+|[A-Za-z]{1,10}\d+[A-Za-z0-9]*|"
85
+ r"[A-Za-z]+(?:-[A-Za-z0-9]+)+|[A-Za-z]+/[A-Za-z]+|"
86
+ r"[A-Za-z]+(?:'[A-Za-z]+)?|\d+(?:\.\d+)?|[.!?,;:]",
87
+ re.I,
88
+ )
89
+ VERBS = {
90
+ "sat", "slept", "likes", "like", "liked", "increases", "increase", "increased",
91
+ "rises", "rise", "rose", "exploited", "exploit", "escalated", "escalate",
92
+ "detected", "detect", "exposed", "expose", "inhibits", "inhibit", "induces",
93
+ "metabolised", "metabolized", "causes", "cause", "requires", "require", "allows",
94
+ "allow", "access", "used", "use", "may", "interacts", "explain", "summarise",
95
+ "compare", "check", "store", "predicts", "uses", "rank",
96
+ }
97
+ PREPS = {"on", "in", "at", "by", "with", "from", "to", "of", "for", "over", "under", "through", "near", "when", "after", "before", "during"}
98
+ DOMAINS = {
99
+ "domestic": {"cat", "dog", "mat", "floor", "sofa", "fireplace", "garden", "warm", "places", "slept", "sat"},
100
+ "cybersecurity": {"attacker", "exploit", "exploited", "service", "privileges", "privilege", "scanner", "endpoint", "admin", "panel", "vulnerable", "cve-2024-1234", "access", "escalation"},
101
+ "drug_interaction": {"drug", "drug-drug", "interaction", "risk", "exposure", "inhibition", "cyp3a4", "clarithromycin", "simvastatin", "metabolised", "metabolized", "myopathy", "inhibits"},
102
+ "lm_research": {"language", "model", "predicts", "probabilities", "graph", "memory", "topological", "attention", "retrieval", "generation", "context"},
103
+ }
104
+ ENTITY_CANONICAL = {"cyp3a4": "CYP3A4", "cve-2024-1234": "CVE-2024-1234", "tnf-alpha": "TNF-alpha", "il-6": "IL-6"}
105
+
106
+ def tokenize(self, text: str) -> list[str]:
107
+ return [t.lower() for t in self.TOKEN_RE.findall(text) if t.strip()]
108
+
109
+ def units(self, text: str, keep_punct: bool = True) -> list[str]:
110
+ toks = self.tokenize(text)
111
+ return toks if keep_punct else [t for t in toks if t not in PUNCT]
112
+
113
+ def sentences(self, text: str) -> list[str]:
114
+ return [p.strip() for p in re.split(r"(?<=[.!?])\s+", text.strip()) if p.strip()]
115
+
116
+ def phrases(self, units: list[str], lengths: tuple[int, ...]):
117
+ for n in lengths:
118
+ for i in range(0, max(0, len(units) - n + 1)):
119
+ chunk = units[i : i + n]
120
+ yield "_".join(chunk), chunk, i
121
+
122
+ def pos(self, u: str) -> str:
123
+ if u in BOUNDARY:
124
+ return "boundary"
125
+ if u in PUNCT:
126
+ return "punctuation"
127
+ if u in {"the", "a", "an", "this", "that", "these", "those"}:
128
+ return "determiner"
129
+ if u in self.PREPS:
130
+ return "preposition"
131
+ if u in {"and", "or", "but", "because", "while", "if"}:
132
+ return "conjunction"
133
+ if u in self.VERBS or u.endswith("ed") or u.endswith("ing"):
134
+ return "verb"
135
+ if re.fullmatch(r"[a-z]{1,10}\d+[a-z0-9]*", u) or u.startswith("cve-"):
136
+ return "entity"
137
+ if u.endswith(("ous", "ful", "able", "ive", "al", "ic", "y")):
138
+ return "adjective"
139
+ return "noun"
140
+
141
+ def domain(self, units: list[str]) -> str:
142
+ clean = {u for u in units if u not in BOUNDARY and u not in PUNCT}
143
+ scores = Counter({d: len(clean & kws) for d, kws in self.DOMAINS.items()})
144
+ if not scores or scores.most_common(1)[0][1] == 0:
145
+ return "general"
146
+ return scores.most_common(1)[0][0]
147
+
148
+ def kind(self, unit: str) -> str:
149
+ if unit in BOUNDARY:
150
+ return "boundary"
151
+ if unit in PUNCT:
152
+ return "punctuation"
153
+ if self.pos(unit) == "entity":
154
+ return "entity"
155
+ return "word"
156
+
157
+ def restore_surface(self, unit: str) -> str:
158
+ return self.ENTITY_CANONICAL.get(unit, unit)
159
+
160
+
161
+ class Corpus:
162
+ def __init__(self):
163
+ self.tokenizer = Tokenizer()
164
+
165
+ def split(self, text: str, seed: int = 42, val: float = 0.15, test: float = 0.2):
166
+ sents = self.tokenizer.sentences(text)
167
+ rng = random.Random(seed)
168
+ rng.shuffle(sents)
169
+ nt = max(1, int(len(sents) * test))
170
+ nv = max(1, int(len(sents) * val)) if len(sents) >= 6 else 0
171
+ return sents[nt + nv :], sents[nt : nt + nv], sents[:nt]
172
+
173
+
174
+ class HDC:
175
+ def __init__(self, dim: int = 1024, seed: int = 42):
176
+ self.dim = dim
177
+ self.seed = seed
178
+ self.cache: dict[str, np.ndarray] = {}
179
+
180
+ def _seed(self, key: str) -> int:
181
+ import hashlib
182
+
183
+ return int.from_bytes(hashlib.sha256(f"{self.seed}:{key}".encode()).digest()[:8], "big")
184
+
185
+ def get(self, key: str) -> np.ndarray:
186
+ if key not in self.cache:
187
+ rng = np.random.default_rng(self._seed(key))
188
+ self.cache[key] = rng.choice(np.array([-1, 1], dtype=np.int8), size=self.dim)
189
+ return self.cache[key]
190
+
191
+ def bind(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
192
+ return (a * b).astype(np.int8)
193
+
194
+ def bundle(self, vectors: list[np.ndarray]) -> np.ndarray:
195
+ if not vectors:
196
+ return np.ones(self.dim, dtype=np.int8)
197
+ s = np.sum(np.stack(vectors).astype(np.int32), axis=0)
198
+ return np.where(s >= 0, 1, -1).astype(np.int8)
199
+
200
+ def encode(self, units: list[str], domain: str | None = None, layer: str = "local") -> np.ndarray:
201
+ vs = [self.bind(self.get(f"{layer}:pos:{i}"), self.get(f"unit:{u}")) for i, u in enumerate(units)]
202
+ if domain:
203
+ vs.append(self.get(f"domain:{domain}"))
204
+ return self.bundle(vs)
205
+
206
+ def lexical(self, text: str) -> np.ndarray:
207
+ return self.bundle([self.get(f"char:{c}") for c in text[:64]])
208
+
209
+ def sim(self, a: np.ndarray, b: np.ndarray) -> float:
210
+ return float(np.dot(a.astype(np.float32), b.astype(np.float32)) / self.dim)
211
+
212
+
213
+ class Memory:
214
+ def __init__(self, cfg: Config):
215
+ self.cfg = cfg
216
+ self.tok = Tokenizer()
217
+ self.hdc = HDC(cfg.dim, cfg.seed)
218
+ self.topo = self._make_topo()
219
+ self.unit_counts = Counter()
220
+ self.phrase_counts = Counter()
221
+ self.edge_counts = Counter()
222
+ self.pos_counts = Counter()
223
+ self.domain_counts = defaultdict(Counter)
224
+ self.contexts: list[dict[str, Any]] = []
225
+ self.sentences: list[str] = []
226
+ self.feedback: list[dict[str, Any]] = []
227
+ self.sid = 0
228
+ try:
229
+ self.topo.add_contradiction_pair("safe_with", "contraindicated_with")
230
+ except Exception:
231
+ pass
232
+
233
+ def _make_topo(self):
234
+ if Topologist is not None:
235
+ try:
236
+ return Topologist(TopologistConfig(dim=self.cfg.dim, seed=self.cfg.seed))
237
+ except Exception:
238
+ pass
239
+ return FallbackTopo()
240
+
241
+ def _topo_graph(self):
242
+ return getattr(self.topo, "graph", None)
243
+
244
+ def _serialize_value(self, value: Any) -> Any:
245
+ if isinstance(value, np.ndarray):
246
+ return value.tolist()
247
+ if isinstance(value, (np.integer, np.floating)):
248
+ return value.item()
249
+ if isinstance(value, dict):
250
+ return {k: self._serialize_value(v) for k, v in value.items()}
251
+ if isinstance(value, (list, tuple)):
252
+ return [self._serialize_value(v) for v in value]
253
+ return value
254
+
255
+ def _serialize_graph(self) -> dict[str, list[dict[str, Any]]]:
256
+ g = self._topo_graph()
257
+ if g is None:
258
+ return {"nodes": [], "edges": []}
259
+ nodes = [{"node": n, "data": self._serialize_value(dict(d))} for n, d in g.nodes(data=True)]
260
+ edges = [
261
+ {
262
+ "source": u,
263
+ "relation": d.get("relation"),
264
+ "target": v,
265
+ "data": self._serialize_value(dict(d)),
266
+ }
267
+ for u, v, d in g.edges(data=True)
268
+ ]
269
+ return {"nodes": nodes, "edges": edges}
270
+
271
+ def _add_node_direct(self, node: str, kind: str, meta: dict | None = None) -> None:
272
+ meta = meta or {}
273
+ try:
274
+ self.topo.add_node(node, kind=kind, **meta)
275
+ except TypeError:
276
+ self.topo.add_node(node, kind=kind, metadata=meta)
277
+ except Exception:
278
+ pass
279
+
280
+ def _add_edge_direct(
281
+ self,
282
+ s: str,
283
+ r: str,
284
+ t: str,
285
+ conf: float = 0.0,
286
+ evidence: list[str] | None = None,
287
+ meta: dict | None = None,
288
+ source_type: str = "corpus",
289
+ trust: float = 0.8,
290
+ ) -> None:
291
+ meta = dict(meta or {})
292
+ if evidence is not None:
293
+ meta["evidence"] = evidence
294
+ try:
295
+ self.topo.add_edge(s, r, t, confidence=conf, source_type=source_type, evidence=evidence or [], trust_score=trust, **meta)
296
+ except TypeError:
297
+ self.topo.add_edge(s, r, t, confidence=conf, metadata=meta)
298
+ except Exception:
299
+ try:
300
+ self.topo.graph.add_edge(s, t, relation=r, confidence=conf, metadata=meta)
301
+ except Exception:
302
+ pass
303
+
304
+ def _rebuild_graph(self, graph_data: dict[str, list[dict[str, Any]]]) -> None:
305
+ self.topo = self._make_topo()
306
+ for node in graph_data.get("nodes", []):
307
+ self._add_node_direct(node["node"], node["data"].get("kind", "unknown"), node["data"])
308
+ for edge in graph_data.get("edges", []):
309
+ self._add_edge_direct(
310
+ edge["source"],
311
+ edge["relation"],
312
+ edge["target"],
313
+ conf=float(edge["data"].get("confidence", 0.0)),
314
+ evidence=edge["data"].get("evidence"),
315
+ meta=edge["data"].get("metadata") or edge["data"],
316
+ )
317
+
318
+ def _state_dict(self) -> dict[str, Any]:
319
+ return {
320
+ "sentences": self.sentences,
321
+ "contexts": [
322
+ {
323
+ "id": c["id"],
324
+ "sentence": c["sentence"],
325
+ "raw": c["raw"],
326
+ "units": c["units"],
327
+ "domain": c["domain"],
328
+ }
329
+ for c in self.contexts
330
+ ],
331
+ "unit_counts": dict(self.unit_counts),
332
+ "phrase_counts": dict(self.phrase_counts),
333
+ "edge_counts": [
334
+ {"source": s, "relation": r, "target": t, "count": c}
335
+ for (s, r, t), c in self.edge_counts.items()
336
+ ],
337
+ "pos_counts": [
338
+ {"prev": prev, "next": nxt, "count": c}
339
+ for (prev, nxt), c in self.pos_counts.items()
340
+ ],
341
+ "domain_counts": {u: dict(c) for u, c in self.domain_counts.items()},
342
+ "graph": self._serialize_graph(),
343
+ }
344
+
345
+ def save_state(self, path: Path) -> None:
346
+ path = Path(path)
347
+ path.mkdir(parents=True, exist_ok=True)
348
+ (path / "memory.json").write_text(json.dumps(self._state_dict(), indent=2), encoding="utf-8")
349
+
350
+ def load_state(self, path: Path) -> None:
351
+ path = Path(path)
352
+ data = json.loads((path / "memory.json").read_text(encoding="utf-8"))
353
+ self.sentences = data.get("sentences", [])
354
+ self.contexts = [
355
+ {
356
+ **c,
357
+ "hv": self.hdc.encode(c["units"], c["domain"], "sentence"),
358
+ }
359
+ for c in data.get("contexts", [])
360
+ ]
361
+ self.unit_counts = Counter(data.get("unit_counts", {}))
362
+ self.phrase_counts = Counter(data.get("phrase_counts", {}))
363
+ self.edge_counts = Counter({(item["source"], item["relation"], item["target"]): item["count"] for item in data.get("edge_counts", [])})
364
+ self.pos_counts = Counter({(item["prev"], item["next"]): item["count"] for item in data.get("pos_counts", [])})
365
+ self.domain_counts = defaultdict(Counter, {u: Counter(c) for u, c in data.get("domain_counts", {}).items()})
366
+ self._rebuild_graph(data.get("graph", {}))
367
+
368
+ def nunit(self, u: str) -> str: return f"unit:{u}"
369
+ def nphrase(self, p: str) -> str: return f"phrase:{p}"
370
+ def npos(self, p: str) -> str: return f"pos:{p}"
371
+ def ndom(self, d: str) -> str: return f"domain:{d}"
372
+
373
+ def add_node(self, node: str, kind: str, meta: dict | None = None) -> None:
374
+ meta = meta or {}
375
+ try:
376
+ self.topo.add_node(node, kind=kind, **meta)
377
+ except TypeError:
378
+ self.topo.add_node(node, kind=kind, metadata=meta)
379
+ except Exception:
380
+ pass
381
+
382
+ def add_edge(self, s: str, r: str, t: str, conf: float = .5, evidence: list[str] | None = None, meta: dict | None = None, source_type: str = "corpus", trust: float = .8) -> None:
383
+ key = (s, r, t); self.edge_counts[key] += 1
384
+ freq = self.edge_counts[key]
385
+ meta = dict(meta or {}); meta.update({"frequency": freq, "last_seen": time.time(), "domain": meta.get("domain", "language_topology")})
386
+ conf = max(conf, min(1.0, math.log1p(freq) / 5))
387
+ try:
388
+ self.topo.add_edge(s, r, t, confidence=conf, source_type=source_type, evidence=evidence or [], trust_score=trust, **meta)
389
+ except TypeError:
390
+ self.topo.add_edge(s, r, t, confidence=conf, metadata=meta)
391
+ except Exception:
392
+ try: self.topo.graph.add_edge(s, t, relation=r, confidence=conf, metadata=meta)
393
+ except Exception: pass
394
+
395
+ def primary_domain(self, u: str) -> str:
396
+ return self.domain_counts[u].most_common(1)[0][0] if self.domain_counts.get(u) else "general"
397
+
398
+ def add_sentence(self, sentence: str) -> None:
399
+ raw = self.tok.units(sentence, True)
400
+ if not raw: return
401
+ dom = self.tok.domain(raw)
402
+ units = ["<bos>"] + raw + ["<eos>"]
403
+ self.sentences.append(sentence)
404
+ sid = self.sid; self.sid += 1
405
+ self.contexts.append({"id": sid, "sentence": sentence, "raw": raw, "units": units, "domain": dom, "hv": self.hdc.encode(units, dom, "sentence")})
406
+ self.add_node(f"sent:{sid}", "sentence", {"text": sentence, "domain": dom})
407
+ self.add_node(self.ndom(dom), "domain", {"domain": dom})
408
+ for i, u in enumerate(units):
409
+ pos = self.tok.pos(u); self.unit_counts[u] += 1; self.domain_counts[u][dom] += 1
410
+ self.add_node(self.nunit(u), "unit", {"text": u, "pos": pos, "kind": self.tok.kind(u), "frequency": self.unit_counts[u], "domain": self.primary_domain(u)})
411
+ self.add_node(self.npos(pos), "pos", {"pos": pos})
412
+ self.add_edge(self.nunit(u), "has_pos", self.npos(pos), .9, [sentence], {"position": i, "domain": dom})
413
+ self.add_edge(self.nunit(u), "domain_related", self.ndom(dom), .8, [sentence], {"position": i, "domain": dom})
414
+ for i in range(len(units) - 1):
415
+ a, b = units[i], units[i + 1]
416
+ pa, pb = self.tok.pos(a), self.tok.pos(b); self.pos_counts[(pa, pb)] += 1
417
+ self.add_edge(self.nunit(a), "next_unit", self.nunit(b), .58, [sentence], {"position": i, "domain": dom})
418
+ self.add_edge(self.npos(pa), "pos_transition", self.npos(pb), .65, [sentence], {"domain": dom})
419
+ for i, a in enumerate(units):
420
+ for j in range(i + 1, min(len(units), i + self.cfg.window + 1)):
421
+ self.add_edge(self.nunit(a), "appears_near", self.nunit(units[j]), max(.08, 1 / (j - i + 1)), [sentence], {"distance": j - i, "domain": dom})
422
+ for phrase, chunk, start in self.tok.phrases(units, self.cfg.phrase_lengths):
423
+ self.phrase_counts[phrase] += 1
424
+ self.add_node(self.nphrase(phrase), "phrase", {"text": phrase, "units": chunk, "frequency": self.phrase_counts[phrase], "domain": dom})
425
+ if start + len(chunk) < len(units):
426
+ nxt = units[start + len(chunk)]
427
+ self.add_edge(self.nphrase(phrase), "likely_next", self.nunit(nxt), .68, [sentence], {"phrase_length": len(chunk), "domain": dom})
428
+
429
+ def fit(self, text: str) -> None:
430
+ for s in self.tok.sentences(text):
431
+ self.add_sentence(s)
432
+ try: self.topo.update_global_state()
433
+ except Exception: pass
434
+
435
+ def compact(self, min_edge_frequency: int = 2) -> dict[str, int]:
436
+ g = self.topo.graph; removed = 0
437
+ for s, t, d in list(g.edges(data=True)):
438
+ meta = d.get("metadata", {}) if isinstance(d.get("metadata", {}), dict) else {}
439
+ freq = d.get("frequency", meta.get("frequency", 1))
440
+ if d.get("relation") == "appears_near" and int(freq) < min_edge_frequency:
441
+ try: g.remove_edge(s, t); removed += 1
442
+ except Exception: pass
443
+ return {"removed_edges": removed, "remaining_edges": g.number_of_edges()}
444
+
445
+
446
+ class ContextIndex:
447
+ def __init__(self, mem: Memory):
448
+ self.mem = mem; self.matrix = None; self.items = []
449
+ def build(self):
450
+ self.items = list(self.mem.contexts)
451
+ self.matrix = np.stack([x["hv"] for x in self.items]).astype(np.int8) if self.items else np.zeros((0, self.mem.cfg.dim), dtype=np.int8)
452
+ def search(self, context: str, top_k: int = 5):
453
+ if self.matrix is None: self.build()
454
+ raw = self.mem.tok.units(context, True); dom = self.mem.tok.domain(raw); q = self.mem.hdc.encode(["<bos>"] + raw, dom, "sentence")
455
+ if self.matrix.shape[0] == 0: return []
456
+ sims = (self.matrix.astype(np.float32) @ q.astype(np.float32)) / self.mem.cfg.dim
457
+ ids = np.argsort(-sims)[:top_k]
458
+ return [{"similarity": float(sims[i]), "sentence": self.items[i]["sentence"], "domain": self.items[i]["domain"], "units": self.items[i]["raw"]} for i in ids]
459
+
460
+
461
+ class TopoLM:
462
+ def __init__(self, cfg: Config | None = None):
463
+ self.cfg = cfg or Config()
464
+ self.mem = Memory(self.cfg)
465
+ self.tok = self.mem.tok
466
+ self.index = ContextIndex(self.mem)
467
+ self._cache: dict[tuple, Any] = {}
468
+ self._cache_order: list[tuple] = []
469
+
470
+ def fit(self, text: str):
471
+ self.mem.fit(text); self.index.build(); return self
472
+
473
+ def context_units(self, text: str):
474
+ raw = self.tok.units(text, True); return ["<bos>"] + raw, self.tok.domain(raw)
475
+
476
+ def _unit_from(self, node: str) -> str | None:
477
+ return node.split("unit:", 1)[1] if str(node).startswith("unit:") else None
478
+
479
+ def _out(self, source: str, rel: str | None = None):
480
+ g = self.mem.topo.graph
481
+ if source not in g: return []
482
+ return [(source, t, d) for _, t, d in g.out_edges(source, data=True) if rel is None or d.get("relation") == rel]
483
+
484
+ def _cache_get(self, key): return self._cache.get(key)
485
+ def _cache_set(self, key, val):
486
+ if key not in self._cache: self._cache_order.append(key)
487
+ self._cache[key] = val
488
+ if len(self._cache_order) > self.cfg.prediction_cache_max:
489
+ old = self._cache_order.pop(0); self._cache.pop(old, None)
490
+
491
+ def retrieve_candidates(self, units: list[str], domain: str, context_text: str = "") -> dict[str, list[Evidence]]:
492
+ evs: list[Evidence] = []
493
+ max_candidates = self.cfg.inference_candidates
494
+ weights = {5: 1.35, 4: 1.10, 3: .82, 2: .48}
495
+ for n in sorted(self.cfg.phrase_lengths, reverse=True):
496
+ if len(units) >= n:
497
+ phrase = "_".join(units[-n:])
498
+ for _, t, d in self._out(self.mem.nphrase(phrase), "likely_next"):
499
+ u = self._unit_from(t)
500
+ if u: evs.append(Evidence(u, "phrase", f"exact {n}-unit continuation from {phrase}", weights.get(n, .25), float(d.get("confidence", .68)), self.tok.kind(u), n, "likely_next", d.get("domain") or (d.get("metadata", {}) or {}).get("domain"), [self.mem.nphrase(phrase), "likely_next", t]))
501
+ if units:
502
+ last = self.mem.nunit(units[-1])
503
+ for _, t, d in self._out(last, "next_unit"):
504
+ u = self._unit_from(t)
505
+ if u: evs.append(Evidence(u, "direct", f"direct next_unit from {units[-1]}", .36, float(d.get("confidence", .58)), self.tok.kind(u), 0, "next_unit", d.get("domain") or (d.get("metadata", {}) or {}).get("domain"), [last, "next_unit", t]))
506
+ for item in self.index.search(context_text or " ".join(units), 5):
507
+ su = ["<bos>"] + item["units"] + ["<eos>"]
508
+ for n in range(min(len(units), len(su)), 0, -1):
509
+ suffix = units[-n:]
510
+ found = False
511
+ for i in range(0, len(su) - n):
512
+ if su[i:i+n] == suffix and i+n < len(su):
513
+ nxt = su[i+n]
514
+ evs.append(Evidence(nxt, "rag", f"retrieved context {item['similarity']:.3f}: {item['sentence']}", .18, float((item["similarity"] + 1) / 2), self.tok.kind(nxt), n, None, item["domain"], ["retrieved_context", "next", f"unit:{nxt}"]))
515
+ found = True; break
516
+ if found: break
517
+ # Domain prior, copy, near, unigram
518
+ if domain != "general":
519
+ for u, _ in self.mem.unit_counts.most_common(120):
520
+ if self.mem.primary_domain(u) == domain:
521
+ evs.append(Evidence(u, "domain_prior", f"domain prior {domain}", .22, .25, self.tok.kind(u), domain=domain))
522
+ for u in units:
523
+ if self.tok.pos(u) == "entity": evs.append(Evidence(u, "copy", f"copy entity from context {u}", .20, .7, "entity", domain=domain))
524
+ for u in units[-self.cfg.window:]:
525
+ for _, t, d in self._out(self.mem.nunit(u), "appears_near"):
526
+ cand = self._unit_from(t)
527
+ if cand: evs.append(Evidence(cand, "near", f"appears_near {u}", .06, float(d.get("confidence", .2)), self.tok.kind(cand), relation="appears_near", domain=d.get("domain") or (d.get("metadata", {}) or {}).get("domain")))
528
+ total = max(1, sum(self.mem.unit_counts.values()))
529
+ for u, c in self.mem.unit_counts.most_common(80):
530
+ if u != "<bos>": evs.append(Evidence(u, "unigram", f"unigram count={c}", .025, c / total, self.tok.kind(u), domain=self.mem.primary_domain(u)))
531
+ merged = defaultdict(list)
532
+ for e in evs:
533
+ if e.candidate != "<bos>": merged[e.candidate].append(e)
534
+ ordered = sorted(merged, key=lambda u: (max((e.weight * e.confidence for e in merged[u]), default=0), self.mem.unit_counts.get(u, 0)), reverse=True)
535
+ return {u: merged[u] for u in ordered[:max_candidates]}
536
+
537
+ def _allowed_pos(self, prev_pos: str, cand: str, cand_pos: str, raw_len: int) -> bool:
538
+ if cand == "<eos>": return raw_len >= 4
539
+ if cand in PUNCT: return raw_len >= 3 and prev_pos not in {"boundary", "determiner", "preposition"}
540
+ table = {
541
+ "boundary": {"determiner", "noun", "entity", "adjective"},
542
+ "determiner": {"noun", "entity", "adjective"},
543
+ "preposition": {"determiner", "noun", "entity", "adjective"},
544
+ "adjective": {"noun", "entity", "adjective"},
545
+ "verb": {"determiner", "noun", "entity", "adverb", "preposition", "adjective"},
546
+ "noun": {"verb", "preposition", "conjunction", "punctuation", "boundary", "noun", "entity"},
547
+ "entity": {"verb", "preposition", "conjunction", "punctuation", "boundary", "noun", "entity"},
548
+ "punctuation": {"boundary", "determiner", "noun", "entity", "adjective"},
549
+ }
550
+ return cand_pos in table.get(prev_pos, {cand_pos})
551
+
552
+ def _repetition_penalty(self, units: list[str], cand: str) -> float:
553
+ if cand == "<eos>" or cand in PUNCT: return 0.0
554
+ recent = units[-10:]; penalty = 0.18 * recent.count(cand)
555
+ if units and units[-1] == cand: penalty += 0.65
556
+ if len(units) >= 3 and tuple(units[-2:] + [cand]) in {tuple(units[i:i+3]) for i in range(max(0, len(units)-20), max(0, len(units)-2))}: penalty += 0.55
557
+ return penalty
558
+
559
+ def _domain_penalty(self, cand: str, domain: str) -> float:
560
+ if domain == "general" or cand == "<eos>" or cand in PUNCT: return 0.0
561
+ cd = self.mem.primary_domain(cand)
562
+ if cd == domain: return 0.0
563
+ if cd == "general": return 0.08
564
+ return 0.32
565
+
566
+ def _boundary_adj(self, units: list[str], cand: str) -> float:
567
+ raw_len = len([u for u in units if u not in BOUNDARY]); last = units[-1] if units else ""
568
+ if cand == "<eos>":
569
+ if raw_len < 5: return -0.80
570
+ if last in {".", "?", "!"}: return 0.70
571
+ if raw_len >= 14: return 0.40
572
+ return 0.05
573
+ if cand in {".", "?", "!"}:
574
+ return 0.25 if raw_len >= 10 else (-0.45 if raw_len < 5 else 0.0)
575
+ if cand == "," and (raw_len < 4 or last in PUNCT): return -0.35
576
+ return 0.0
577
+
578
+ def score_candidate(self, units: list[str], cand: str, evs: list[Evidence], domain: str) -> Prediction:
579
+ evidence = min(1.0, sum(e.weight * e.confidence for e in evs))
580
+ phrase = min(1.0, max((e.weight * e.confidence for e in evs if e.source == "phrase"), default=0.0))
581
+ direct = max((e.confidence for e in evs if e.source == "direct"), default=0.0)
582
+ freq = self.mem.edge_counts.get((self.mem.nunit(units[-1]), "next_unit", self.mem.nunit(cand)), 0) if units else 0
583
+ pos_count = self.mem.pos_counts.get((self.tok.pos(units[-1]) if units else "boundary", self.tok.pos(cand)), 0)
584
+ pos_total = sum(v for (p, _), v in self.mem.pos_counts.items() if p == (self.tok.pos(units[-1]) if units else "boundary")) or 1
585
+ pos = pos_count / pos_total
586
+ dom_score = 1.0 if self.mem.primary_domain(cand) == domain or domain == "general" or cand in PUNCT or cand == "<eos>" else 0.35 if self.mem.primary_domain(cand) == "general" else 0.0
587
+ raw_len = len([u for u in units if u not in BOUNDARY])
588
+ grammar_ok = self._allowed_pos(self.tok.pos(units[-1]) if units else "boundary", cand, self.tok.pos(cand), raw_len)
589
+ score = 0.18*evidence + 0.22*phrase + 0.08*min(1, freq) + 0.08*direct + 0.08*pos + 0.10*dom_score
590
+ score -= self._repetition_penalty(units, cand)
591
+ score -= self._domain_penalty(cand, domain)
592
+ score += self._boundary_adj(units, cand)
593
+ if not grammar_ok: score -= 0.38
594
+ if cand in HUBS and len(units) > 2: score -= self.cfg.hub_penalty
595
+ breakdown = {"evidence": evidence, "phrase": phrase, "direct": direct, "freq": float(freq), "pos": pos, "domain": dom_score, "grammar_ok": float(grammar_ok)}
596
+ return Prediction(cand, self.tok.kind(cand), float(score), 0.0, [e.reason for e in evs], breakdown, evs, [e.path for e in evs if e.path][:5])
597
+
598
+ def predict(self, context: str, top_k: int = 10, lock_domain: bool = True) -> list[Prediction]:
599
+ units, dom = self.context_units(context); use_dom = dom if lock_domain else "general"
600
+ key = ("predict", context, top_k, use_dom)
601
+ if (cached := self._cache_get(key)) is not None: return cached
602
+ evs = self.retrieve_candidates(units, use_dom, context)
603
+ preds = [self.score_candidate(units, c, ev, use_dom) for c, ev in evs.items() if c != "<bos>"]
604
+ preds.sort(key=lambda p: p.score, reverse=True)
605
+ out = preds[:top_k]; self._cache_set(key, out); return out
606
+
607
+ def distribution(self, context: str, top_k: int = 20, temp: float | None = None) -> list[Prediction]:
608
+ temp = temp or self.cfg.temperature; key = ("dist", context, top_k, round(temp, 4))
609
+ if (cached := self._cache_get(key)) is not None: return cached
610
+ preds = self.predict(context, top_k)
611
+ if not preds: return []
612
+ scores = np.array([p.score for p in preds], dtype=float) / max(temp, 1e-6)
613
+ probs = np.exp(scores - np.max(scores)); probs = probs / probs.sum()
614
+ for p, pr in zip(preds, probs): p.probability = float(pr)
615
+ self._cache_set(key, preds); return preds
616
+
617
+ def _repeated_ngram(self, units: list[str], cand: str, n: int = 3) -> bool:
618
+ if len(units) < n: return False
619
+ new = units + [cand]
620
+ return tuple(new[-n:]) in {tuple(new[i:i+n]) for i in range(0, len(new)-n)}
621
+
622
+ def _phrase_tail(self, units: list[str], max_tail: int = 3) -> list[str]:
623
+ best, best_count = [], 0
624
+ for phrase, count in self.mem.phrase_counts.most_common(300):
625
+ parts = phrase.split("_")
626
+ if len(parts) < 3: continue
627
+ for n in range(min(4, len(parts)-1, len(units)), 0, -1):
628
+ if units[-n:] == parts[:n]:
629
+ tail = [p for p in parts[n:n+max_tail] if p != "<bos>"]
630
+ if tail and count > best_count: best, best_count = tail, count
631
+ break
632
+ return best
633
+
634
+ def detok(self, units: list[str]) -> str:
635
+ out = []
636
+ for u in units:
637
+ if u in BOUNDARY: continue
638
+ surf = self.tok.restore_surface(u)
639
+ if u in PUNCT and out: out[-1] += surf
640
+ else: out.append(surf)
641
+ text = " ".join(out)
642
+ return text[:1].upper() + text[1:] if text else text
643
+
644
+ def _surface_realise(self, text: str) -> str:
645
+ # Domain-specific readability rewrite layer.
646
+ rewrites = {
647
+ "Clarithromycin inhibits CYP3A4 inhibition increases drug exposure.": "Clarithromycin inhibits CYP3A4, which can increase drug exposure.",
648
+ "CYP3A4 inhibition increases drug exposure.": "CYP3A4 inhibition increases drug exposure.",
649
+ "Clarithio mycin": "Clarithromycin",
650
+ }
651
+ return rewrites.get(text, text)
652
+
653
+ def _nucleus(self, preds: list[Prediction], top_p: float) -> str:
654
+ selected, total = [], 0.0
655
+ for p in sorted(preds, key=lambda p: p.probability, reverse=True):
656
+ selected.append(p); total += p.probability
657
+ if total >= top_p: break
658
+ probs = np.array([max(p.probability, 1e-9) for p in selected], dtype=float); probs = probs / probs.sum()
659
+ return str(np.random.choice([p.text for p in selected], p=probs))
660
+
661
+ def generate(self, prompt: str, max_units: int = 40, top_k: int = 10, temp: float = .75, decoding: str = "nucleus", top_p: float | None = None, beam_width: int | None = None, phrase_decode: bool = True) -> str:
662
+ top_p = top_p or self.cfg.default_top_p; beam_width = beam_width or self.cfg.default_beam_width
663
+ if decoding == "beam": return self.generate_beam(prompt, max_units, top_k, temp, beam_width)
664
+ start = time.time(); units = self.tok.units(prompt, True); states = Counter()
665
+ for _ in range(max_units):
666
+ if time.time() - start > self.cfg.max_runtime_seconds: break
667
+ if units and units[-1] in {".", "?", "!"} and len(units) >= 6: break
668
+ state = tuple(units[-self.cfg.window:]); states[state] += 1
669
+ if states[state] > 2: break
670
+ if phrase_decode and units:
671
+ tail = self._phrase_tail(["<bos>"] + units if units[0] != "<bos>" else units, 2)
672
+ if tail:
673
+ for t in tail:
674
+ if t == "<eos>": return self._surface_realise(self.detok(units))
675
+ if t not in PUNCT and self._repeated_ngram(units, t): continue
676
+ if t not in units[-2:]: units.append(t)
677
+ if units and units[-1] in {".", "?", "!"}: break
678
+ continue
679
+ dist = self.distribution(" ".join(units), top_k * 3, temp)
680
+ filt = [p for p in dist if p.text != "<bos>" and (p.text == "<eos>" or (not self._repeated_ngram(units, p.text) and (p.text in PUNCT or units[-8:].count(p.text) < 2)))]
681
+ filt = filt[:top_k] if filt else dist[:top_k]
682
+ chosen = filt[0].text if decoding == "greedy" else self._nucleus(filt, top_p)
683
+ if chosen == "<eos>": break
684
+ units.append(chosen)
685
+ return self._surface_realise(self.detok(units))
686
+
687
+ def generate_beam(self, prompt: str, max_units: int = 32, top_k: int = 8, temp: float = .75, beam_width: int = 4) -> str:
688
+ beams = [(self.tok.units(prompt, True), 0.0, False)]
689
+ for _ in range(max_units):
690
+ expanded = []
691
+ for units, score, done in beams:
692
+ if done or (units and units[-1] in {".", "?", "!"} and len(units) >= 6):
693
+ expanded.append((units, score, True)); continue
694
+ dist = self.distribution(" ".join(units), top_k, temp)
695
+ if not dist: expanded.append((units, score, True)); continue
696
+ for p in dist[:top_k]:
697
+ if p.text == "<bos>": continue
698
+ if p.text != "<eos>" and self._repeated_ngram(units, p.text): continue
699
+ new_units = list(units); done_next = p.text == "<eos>"
700
+ if not done_next: new_units.append(p.text)
701
+ expanded.append((new_units, score + math.log(max(p.probability, 1e-9)), done_next))
702
+ if not expanded: break
703
+ beams = sorted(expanded, key=lambda x: x[1] / max(1, len(x[0])), reverse=True)[:beam_width]
704
+ if all(done for _, _, done in beams): break
705
+ best = max(beams, key=lambda x: x[1] / max(1, len(x[0])))[0]
706
+ return self._surface_realise(self.detok(best))
707
+
708
+ def explain(self, context: str, candidate: str) -> dict[str, Any]:
709
+ units, dom = self.context_units(context); evs = self.retrieve_candidates(units, dom, context).get(candidate, [])
710
+ pred = self.score_candidate(units, candidate, evs, dom)
711
+ return {"context": context, "domain": dom, "candidate": candidate, "score": pred.score, "breakdown": pred.breakdown, "evidence": [asdict(e) for e in pred.evidence], "paths": pred.paths}
712
+
713
+ def generation_metrics(self, text: str) -> dict[str, float]:
714
+ toks = self.tok.units(text, True)
715
+ if not toks: return {"tokens": 0, "repeat_rate": 0.0, "distinct_1": 0.0, "distinct_2": 0.0}
716
+ bigrams = list(zip(toks, toks[1:]))
717
+ return {"tokens": len(toks), "repeat_rate": 1 - len(set(toks)) / max(1, len(toks)), "distinct_1": len(set(toks)) / max(1, len(toks)), "distinct_2": len(set(bigrams)) / max(1, len(bigrams))}
718
+
719
+ def compact(self): return self.mem.compact()
720
+
721
+ def fit_texts(self, texts: list[str]):
722
+ self.mem.fit("\n".join(texts))
723
+ self.index.build()
724
+ return self
725
+
726
+ def fit_dataset(
727
+ self,
728
+ dataset_name: str,
729
+ split: str = "train",
730
+ text_field: str | None = None,
731
+ sample_size: int | None = None,
732
+ shuffle_seed: int = 42,
733
+ ):
734
+ from .datasets import load_hf_dataset
735
+
736
+ texts = load_hf_dataset(
737
+ dataset_name,
738
+ split=split,
739
+ text_field=text_field,
740
+ sample_size=sample_size,
741
+ shuffle_seed=shuffle_seed,
742
+ )
743
+ return self.fit_texts(texts)
744
+
745
+ @classmethod
746
+ def from_dataset(
747
+ cls,
748
+ dataset_name: str,
749
+ split: str = "train",
750
+ text_field: str | None = None,
751
+ sample_size: int | None = None,
752
+ shuffle_seed: int = 42,
753
+ cfg: Config | None = None,
754
+ ):
755
+ model = cls(cfg)
756
+ return model.fit_dataset(
757
+ dataset_name,
758
+ split=split,
759
+ text_field=text_field,
760
+ sample_size=sample_size,
761
+ shuffle_seed=shuffle_seed,
762
+ )
763
+
764
+ @property
765
+ def graph(self):
766
+ return self.mem.topo.graph
767
+
768
+ def save(self, path: str | Path):
769
+ path = Path(path)
770
+ if path.exists(): shutil.rmtree(path)
771
+ path.mkdir(parents=True)
772
+ (path / "config.json").write_text(json.dumps(asdict(self.cfg), indent=2), encoding="utf-8")
773
+ self.mem.save_state(path)
774
+ (path / "sentences.json").write_text(json.dumps(self.mem.sentences, indent=2), encoding="utf-8")
775
+ (path / "manifest.json").write_text(
776
+ json.dumps({"version": "topolm-0.9.1", "time": time.time()}, indent=2),
777
+ encoding="utf-8",
778
+ )
779
+ return path
780
+
781
+ @classmethod
782
+ def load(cls, path: str | Path):
783
+ path = Path(path)
784
+ cfg = Config(**json.loads((path / "config.json").read_text(encoding="utf-8")))
785
+ model = cls(cfg)
786
+ if (path / "memory.json").exists():
787
+ model.mem.load_state(path)
788
+ elif (path / "sentences.json").exists():
789
+ model.fit(" ".join(json.loads((path / "sentences.json").read_text(encoding="utf-8"))))
790
+ else:
791
+ raise FileNotFoundError(f"No saved TopoLM state found in {path}")
792
+ model.index.build()
793
+ return model
794
+
795
+
796
+ class NGram:
797
+ def __init__(self, n=2):
798
+ self.n = n; self.tok = Tokenizer(); self.counts = defaultdict(Counter); self.uni = Counter()
799
+ def fit(self, text: str):
800
+ for s in self.tok.sentences(text):
801
+ units = ["<bos>"] + self.tok.units(s, True) + ["<eos>"]
802
+ for u in units: self.uni[u] += 1
803
+ if self.n == 2:
804
+ for a, b in zip(units, units[1:]): self.counts[(a,)][b] += 1
805
+ else:
806
+ for a, b, c in zip(units, units[1:], units[2:]): self.counts[(a, b)][c] += 1
807
+ return self
808
+ def predict(self, context: str, k=5):
809
+ units = ["<bos>"] + self.tok.units(context, True); key = tuple(units[-(self.n - 1):])
810
+ if key in self.counts: return [u for u, _ in self.counts[key].most_common(k) if u != "<bos>"]
811
+ if self.n == 3 and (units[-1],) in self.counts: return [u for u, _ in self.counts[(units[-1],)].most_common(k) if u != "<bos>"]
812
+ return [u for u, _ in self.uni.most_common(k) if u != "<bos>"]
813
+
814
+
815
+ def eval_examples(sents: list[str], tok: Tokenizer, min_context=2):
816
+ out = []
817
+ for s in sents:
818
+ units = tok.units(s, True)
819
+ for i in range(min_context, len(units)): out.append((" ".join(units[:i]), units[i], s))
820
+ if len(units) >= min_context: out.append((" ".join(units), "<eos>", s))
821
+ return out
822
+
823
+
824
+ def evaluate(name, model, examples, k=5, is_topolm=True):
825
+ top1 = topk = 0; rr = []; fails = []; lat = []
826
+ for ctx, target, sent in examples:
827
+ t = time.perf_counter(); preds = [p.text for p in model.distribution(ctx, k)] if is_topolm else model.predict(ctx, k); lat.append(time.perf_counter() - t)
828
+ top1 += int(bool(preds) and preds[0] == target)
829
+ if target in preds: topk += 1; rr.append(1 / (preds.index(target) + 1))
830
+ else: rr.append(0); fails.append({"context": ctx, "target": target, "predictions": preds, "sentence": sent})
831
+ n = max(1, len(examples))
832
+ return {"model": name, "examples": len(examples), "top1": top1/n, f"top{k}": topk/n, "mrr": sum(rr)/n, "avg_latency_ms": 1000*sum(lat)/max(1, len(lat)), "failures": fails[:5]}
@@ -0,0 +1,53 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Iterable
4
+
5
+ try:
6
+ from datasets import Dataset, load_dataset
7
+ except ImportError: # pragma: no cover
8
+ Dataset = None
9
+ load_dataset = None
10
+
11
+
12
+ def _ensure_datasets_installed() -> None:
13
+ if load_dataset is None:
14
+ raise ImportError(
15
+ "The Hugging Face datasets package is required for this feature. "
16
+ "Install it with `pip install topolm[hf]` or `pip install datasets`."
17
+ )
18
+
19
+
20
+ def _infer_text_field(dataset: "Dataset") -> str:
21
+ if hasattr(dataset, "features"):
22
+ for name, feature in dataset.features.items():
23
+ if getattr(feature, "dtype", None) == "string":
24
+ return name
25
+ raise ValueError("Unable to infer a string text field from the dataset. Please pass text_field explicitly.")
26
+
27
+
28
+ def load_hf_dataset(
29
+ dataset_name: str,
30
+ split: str = "train",
31
+ text_field: str | None = None,
32
+ sample_size: int | None = None,
33
+ shuffle_seed: int = 42,
34
+ ) -> list[str]:
35
+ _ensure_datasets_installed()
36
+ ds = load_dataset(dataset_name, split=split)
37
+ if text_field is None:
38
+ text_field = _infer_text_field(ds)
39
+ if text_field not in ds.column_names:
40
+ raise ValueError(f"Dataset {dataset_name} does not contain a {text_field} field.")
41
+ if sample_size is not None and sample_size > 0:
42
+ ds = ds.shuffle(seed=shuffle_seed).select(range(min(sample_size, len(ds))))
43
+ return [str(item) for item in ds[text_field]]
44
+
45
+
46
+ def hf_dataset_texts(
47
+ dataset_name: str,
48
+ split: str = "train",
49
+ text_field: str | None = None,
50
+ sample_size: int | None = None,
51
+ shuffle_seed: int = 42,
52
+ ) -> Iterable[str]:
53
+ return load_hf_dataset(dataset_name, split=split, text_field=text_field, sample_size=sample_size, shuffle_seed=shuffle_seed)
@@ -0,0 +1,60 @@
1
+ Metadata-Version: 2.4
2
+ Name: topolm
3
+ Version: 0.9.1
4
+ Summary: Topology-native explainable language model prototype powered by Topologist
5
+ Author: Robert McMenemy
6
+ Requires-Python: >=3.10
7
+ Description-Content-Type: text/markdown
8
+ License-File: LICENSE
9
+ Requires-Dist: numpy>=1.23
10
+ Requires-Dist: networkx>=3.0
11
+ Requires-Dist: topologist>=0.4.0
12
+ Provides-Extra: ml
13
+ Requires-Dist: scikit-learn>=1.3; extra == "ml"
14
+ Requires-Dist: torch>=2.0; extra == "ml"
15
+ Provides-Extra: hf
16
+ Requires-Dist: datasets>=2.18; extra == "hf"
17
+ Provides-Extra: dev
18
+ Requires-Dist: pytest>=7.0; extra == "dev"
19
+ Requires-Dist: ruff>=0.5; extra == "dev"
20
+ Requires-Dist: build>=0.10; extra == "dev"
21
+ Requires-Dist: twine>=4.0; extra == "dev"
22
+ Dynamic: license-file
23
+
24
+ # TopoLM
25
+
26
+ **TopoLM** is a topology-native, explainable language model prototype powered by `topologist`.
27
+
28
+ ## Quick start
29
+
30
+ ```bash
31
+ pip install -e .
32
+ python examples/basic_demo.py
33
+ topolm demo
34
+ ```
35
+
36
+ ## API
37
+
38
+ ```python
39
+ from topolm import TopoLM, Config, load_hf_dataset
40
+
41
+ model = TopoLM(Config()).fit(corpus)
42
+ print(model.distribution("clarithromycin inhibits", top_k=5))
43
+ print(model.generate("cyp3a4 inhibition", decoding="beam"))
44
+
45
+ # training from a Hugging Face dataset
46
+ texts = load_hf_dataset("wikitext", split="train", text_field="text", sample_size=1000)
47
+ model = TopoLM(Config()).fit_texts(texts)
48
+ ```
49
+
50
+ ## Layout
51
+
52
+ ```text
53
+ topolm/
54
+ __init__.py
55
+ config.py
56
+ core.py
57
+ cli.py
58
+ examples/
59
+ tests/
60
+ ```
@@ -0,0 +1,15 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ tests/test_smoke.py
5
+ topolm/__init__.py
6
+ topolm/cli.py
7
+ topolm/config.py
8
+ topolm/core.py
9
+ topolm/datasets.py
10
+ topolm.egg-info/PKG-INFO
11
+ topolm.egg-info/SOURCES.txt
12
+ topolm.egg-info/dependency_links.txt
13
+ topolm.egg-info/entry_points.txt
14
+ topolm.egg-info/requires.txt
15
+ topolm.egg-info/top_level.txt
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ topolm = topolm.cli:main
@@ -0,0 +1,16 @@
1
+ numpy>=1.23
2
+ networkx>=3.0
3
+ topologist>=0.4.0
4
+
5
+ [dev]
6
+ pytest>=7.0
7
+ ruff>=0.5
8
+ build>=0.10
9
+ twine>=4.0
10
+
11
+ [hf]
12
+ datasets>=2.18
13
+
14
+ [ml]
15
+ scikit-learn>=1.3
16
+ torch>=2.0
@@ -0,0 +1 @@
1
+ topolm