topolm 0.9.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- topolm-0.9.1/LICENSE +21 -0
- topolm-0.9.1/PKG-INFO +60 -0
- topolm-0.9.1/README.md +37 -0
- topolm-0.9.1/pyproject.toml +27 -0
- topolm-0.9.1/setup.cfg +4 -0
- topolm-0.9.1/tests/test_smoke.py +56 -0
- topolm-0.9.1/topolm/__init__.py +16 -0
- topolm-0.9.1/topolm/cli.py +36 -0
- topolm-0.9.1/topolm/config.py +23 -0
- topolm-0.9.1/topolm/core.py +832 -0
- topolm-0.9.1/topolm/datasets.py +53 -0
- topolm-0.9.1/topolm.egg-info/PKG-INFO +60 -0
- topolm-0.9.1/topolm.egg-info/SOURCES.txt +15 -0
- topolm-0.9.1/topolm.egg-info/dependency_links.txt +1 -0
- topolm-0.9.1/topolm.egg-info/entry_points.txt +2 -0
- topolm-0.9.1/topolm.egg-info/requires.txt +16 -0
- topolm-0.9.1/topolm.egg-info/top_level.txt +1 -0
topolm-0.9.1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 JadeyGraham96
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
topolm-0.9.1/PKG-INFO
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: topolm
|
|
3
|
+
Version: 0.9.1
|
|
4
|
+
Summary: Topology-native explainable language model prototype powered by Topologist
|
|
5
|
+
Author: Robert McMenemy
|
|
6
|
+
Requires-Python: >=3.10
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
License-File: LICENSE
|
|
9
|
+
Requires-Dist: numpy>=1.23
|
|
10
|
+
Requires-Dist: networkx>=3.0
|
|
11
|
+
Requires-Dist: topologist>=0.4.0
|
|
12
|
+
Provides-Extra: ml
|
|
13
|
+
Requires-Dist: scikit-learn>=1.3; extra == "ml"
|
|
14
|
+
Requires-Dist: torch>=2.0; extra == "ml"
|
|
15
|
+
Provides-Extra: hf
|
|
16
|
+
Requires-Dist: datasets>=2.18; extra == "hf"
|
|
17
|
+
Provides-Extra: dev
|
|
18
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
19
|
+
Requires-Dist: ruff>=0.5; extra == "dev"
|
|
20
|
+
Requires-Dist: build>=0.10; extra == "dev"
|
|
21
|
+
Requires-Dist: twine>=4.0; extra == "dev"
|
|
22
|
+
Dynamic: license-file
|
|
23
|
+
|
|
24
|
+
# TopoLM
|
|
25
|
+
|
|
26
|
+
**TopoLM** is a topology-native, explainable language model prototype powered by `topologist`.
|
|
27
|
+
|
|
28
|
+
## Quick start
|
|
29
|
+
|
|
30
|
+
```bash
|
|
31
|
+
pip install -e .
|
|
32
|
+
python examples/basic_demo.py
|
|
33
|
+
topolm demo
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
## API
|
|
37
|
+
|
|
38
|
+
```python
|
|
39
|
+
from topolm import TopoLM, Config, load_hf_dataset
|
|
40
|
+
|
|
41
|
+
model = TopoLM(Config()).fit(corpus)
|
|
42
|
+
print(model.distribution("clarithromycin inhibits", top_k=5))
|
|
43
|
+
print(model.generate("cyp3a4 inhibition", decoding="beam"))
|
|
44
|
+
|
|
45
|
+
# training from a Hugging Face dataset
|
|
46
|
+
texts = load_hf_dataset("wikitext", split="train", text_field="text", sample_size=1000)
|
|
47
|
+
model = TopoLM(Config()).fit_texts(texts)
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
## Layout
|
|
51
|
+
|
|
52
|
+
```text
|
|
53
|
+
topolm/
|
|
54
|
+
__init__.py
|
|
55
|
+
config.py
|
|
56
|
+
core.py
|
|
57
|
+
cli.py
|
|
58
|
+
examples/
|
|
59
|
+
tests/
|
|
60
|
+
```
|
topolm-0.9.1/README.md
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# TopoLM
|
|
2
|
+
|
|
3
|
+
**TopoLM** is a topology-native, explainable language model prototype powered by `topologist`.
|
|
4
|
+
|
|
5
|
+
## Quick start
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install -e .
|
|
9
|
+
python examples/basic_demo.py
|
|
10
|
+
topolm demo
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
## API
|
|
14
|
+
|
|
15
|
+
```python
|
|
16
|
+
from topolm import TopoLM, Config, load_hf_dataset
|
|
17
|
+
|
|
18
|
+
model = TopoLM(Config()).fit(corpus)
|
|
19
|
+
print(model.distribution("clarithromycin inhibits", top_k=5))
|
|
20
|
+
print(model.generate("cyp3a4 inhibition", decoding="beam"))
|
|
21
|
+
|
|
22
|
+
# training from a Hugging Face dataset
|
|
23
|
+
texts = load_hf_dataset("wikitext", split="train", text_field="text", sample_size=1000)
|
|
24
|
+
model = TopoLM(Config()).fit_texts(texts)
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
## Layout
|
|
28
|
+
|
|
29
|
+
```text
|
|
30
|
+
topolm/
|
|
31
|
+
__init__.py
|
|
32
|
+
config.py
|
|
33
|
+
core.py
|
|
34
|
+
cli.py
|
|
35
|
+
examples/
|
|
36
|
+
tests/
|
|
37
|
+
```
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=68", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "topolm"
|
|
7
|
+
version = "0.9.1"
|
|
8
|
+
description = "Topology-native explainable language model prototype powered by Topologist"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.10"
|
|
11
|
+
authors = [{name = "Robert McMenemy"}]
|
|
12
|
+
dependencies = ["numpy>=1.23", "networkx>=3.0", "topologist>=0.4.0"]
|
|
13
|
+
|
|
14
|
+
[project.optional-dependencies]
|
|
15
|
+
ml = ["scikit-learn>=1.3", "torch>=2.0"]
|
|
16
|
+
hf = ["datasets>=2.18"]
|
|
17
|
+
dev = ["pytest>=7.0", "ruff>=0.5", "build>=0.10", "twine>=4.0"]
|
|
18
|
+
|
|
19
|
+
[project.scripts]
|
|
20
|
+
topolm = "topolm.cli:main"
|
|
21
|
+
|
|
22
|
+
[tool.setuptools.packages.find]
|
|
23
|
+
where = ["."]
|
|
24
|
+
include = ["topolm*"]
|
|
25
|
+
|
|
26
|
+
[tool.pytest.ini_options]
|
|
27
|
+
testpaths = ["tests"]
|
topolm-0.9.1/setup.cfg
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from topolm import Config, TopoLM, load_hf_dataset
|
|
2
|
+
|
|
3
|
+
CORPUS = """
|
|
4
|
+
The cat sat on the mat.
|
|
5
|
+
The dog sat on the floor.
|
|
6
|
+
The attacker used CVE-2024-1234 to access the admin panel.
|
|
7
|
+
CYP3A4 inhibition increases drug exposure.
|
|
8
|
+
Clarithromycin inhibits CYP3A4.
|
|
9
|
+
Clarithromycin may increase simvastatin exposure.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
def test_predict_domain_terms():
    """The best continuation of "clarithromycin inhibits" is the CYP3A4 unit."""
    fitted = TopoLM(Config()).fit(CORPUS)
    ranked = fitted.distribution("clarithromycin inhibits", 5)
    # A non-empty ranking is required before inspecting its head.
    assert ranked
    assert ranked[0].text == "cyp3a4"
|
|
17
|
+
|
|
18
|
+
def test_generation_is_bounded_and_fluentish():
    """Beam generation keeps the canonical entity casing and stays short."""
    fitted = TopoLM(Config()).fit(CORPUS)
    generated = fitted.generate("cyp3a4 inhibition", decoding="beam", max_units=16)
    assert "CYP3A4" in generated
    # 16 units were requested; allow a little slack in whitespace-separated words.
    word_count = len(generated.split())
    assert word_count <= 20
|
|
23
|
+
|
|
24
|
+
def test_save_load(tmp_path):
    """A model reloaded from disk still ranks "cyp3a4" first."""
    original = TopoLM(Config()).fit(CORPUS)
    saved_at = original.save(tmp_path / "model")
    restored = TopoLM.load(saved_at)
    top = restored.distribution("clarithromycin inhibits", 3)[0]
    assert top.text == "cyp3a4"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def test_save_load_round_trip_state(tmp_path):
    """Saving then loading reproduces the memory's counters exactly."""
    original = TopoLM(Config()).fit(CORPUS)
    saved_at = original.save(tmp_path / "model_state")
    restored = TopoLM.load(saved_at)
    # Checked one attribute at a time, in the same order as before.
    for counter_name in ("unit_counts", "phrase_counts", "edge_counts", "domain_counts"):
        assert getattr(restored.mem, counter_name) == getattr(original.mem, counter_name)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def test_fit_texts_helper():
    """``fit_texts`` accepts a list of pre-split, non-empty sentences."""
    stripped = (line.strip() for line in CORPUS.splitlines())
    sentences = [line for line in stripped if line]
    fitted = TopoLM(Config()).fit_texts(sentences)
    best = fitted.distribution("clarithromycin inhibits", 5)[0]
    assert best.text == "cyp3a4"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def test_hf_dataset_loader_simple(monkeypatch):
    """``load_hf_dataset`` returns the text column of the (patched) dataset."""
    import pytest

    # Skip the test entirely when the optional ``datasets`` extra is absent.
    pytest.importorskip("datasets")
    from datasets import Dataset

    def fake_load_dataset(name, split="train"):
        # Stand-in for the network-backed loader: one row, one text field.
        rows = {"text": ["Topologist test text."]}
        return Dataset.from_dict(rows)

    monkeypatch.setattr("topolm.datasets.load_dataset", fake_load_dataset)
    loaded = load_hf_dataset("testset", split="train", text_field="text", sample_size=1)
    assert loaded == ["Topologist test text."]
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from .config import Config
|
|
2
|
+
from .core import TopoLM, Tokenizer, Corpus, NGram, evaluate, eval_examples
|
|
3
|
+
from .datasets import hf_dataset_texts, load_hf_dataset
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"Config",
|
|
7
|
+
"TopoLM",
|
|
8
|
+
"Tokenizer",
|
|
9
|
+
"Corpus",
|
|
10
|
+
"NGram",
|
|
11
|
+
"evaluate",
|
|
12
|
+
"eval_examples",
|
|
13
|
+
"load_hf_dataset",
|
|
14
|
+
"hf_dataset_texts",
|
|
15
|
+
]
|
|
16
|
+
__version__ = "0.9.1"
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
from .core import TopoLM
|
|
5
|
+
from .config import Config
|
|
6
|
+
|
|
7
|
+
DEMO = """
|
|
8
|
+
The cat sat on the mat.
|
|
9
|
+
The dog sat on the floor.
|
|
10
|
+
The attacker used CVE-2024-1234 to access the admin panel.
|
|
11
|
+
CYP3A4 inhibition increases drug exposure.
|
|
12
|
+
Clarithromycin inhibits CYP3A4.
|
|
13
|
+
Clarithromycin may increase simvastatin exposure.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
def build_demo_model():
    """Return a TopoLM fitted on the built-in ``DEMO`` corpus with default settings."""
    demo_model = TopoLM(Config())
    return demo_model.fit(DEMO)
|
|
18
|
+
|
|
19
|
+
def main(argv=None):
    """Command-line entry point for the ``topolm`` script.

    Sub-commands:
      * ``predict CONTEXT`` - print the top five continuations as
        tab-separated ``text``, ``probability`` and ``score`` columns.
      * ``generate PROMPT [--decoding beam|nucleus|greedy]`` - print generated text.
      * ``demo`` or no sub-command - print a beam generation for a fixed prompt.

    ``argv`` is forwarded to ``argparse``; ``None`` means "use ``sys.argv``".
    """
    parser = argparse.ArgumentParser(prog="topolm")
    commands = parser.add_subparsers(dest="cmd")
    commands.add_parser("demo")

    predict_cmd = commands.add_parser("predict")
    predict_cmd.add_argument("context")

    generate_cmd = commands.add_parser("generate")
    generate_cmd.add_argument("prompt")
    generate_cmd.add_argument("--decoding", default="beam", choices=["beam", "nucleus", "greedy"])

    args = parser.parse_args(argv)
    # The demo model is fitted on every invocation, whatever the sub-command.
    model = build_demo_model()

    if args.cmd == "predict":
        for pred in model.distribution(args.context, 5):
            print(f"{pred.text}\t{pred.probability:.3f}\t{pred.score:.3f}")
        return

    if args.cmd == "generate":
        print(model.generate(args.prompt, decoding=args.decoding))
        return

    print(model.generate("clarithromycin inhibits", decoding="beam"))
|
|
34
|
+
|
|
35
|
+
if __name__ == "__main__":
|
|
36
|
+
main()
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
PUNCT = {".", ",", ";", ":", "?", "!"}
|
|
4
|
+
BOUNDARY = {"<bos>", "<eos>"}
|
|
5
|
+
HUBS = {"the", "a", "an", "and", "or", "of", "to", "in", "on", "with", "when", ".", ","}
|
|
6
|
+
|
|
7
|
+
@dataclass
class Config:
    """Settings bundle for a TopoLM model; every field has a default.

    Field order is part of the interface because the generated ``__init__``
    accepts the fields positionally.
    """

    # Hypervector width and base seed.
    dim: int = 1024
    seed: int = 42

    # Co-occurrence reach (in units) and the n-gram sizes recorded as phrases.
    window: int = 8
    phrase_lengths: tuple[int, ...] = (2, 3, 4, 5)

    # Candidate-set and cache sizing.
    max_candidates: int = 96
    inference_candidates: int = 48
    prediction_cache_max: int = 4096

    # Decoding defaults.
    temperature: float = 0.75
    max_runtime_seconds: float = 5.0
    default_top_p: float = 0.88
    default_beam_width: int = 4

    # Training / reranking knobs.
    fast_dev_mode: bool = True
    max_reranker_sentences: int = 80
    negatives_per_positive: int = 2
    hub_penalty: float = 0.10
|
|
@@ -0,0 +1,832 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import json
|
|
5
|
+
import math
|
|
6
|
+
import random
|
|
7
|
+
import re
|
|
8
|
+
import shutil
|
|
9
|
+
import time
|
|
10
|
+
from collections import Counter, defaultdict
|
|
11
|
+
from dataclasses import asdict, dataclass, field
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
import networkx as nx
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
from .config import BOUNDARY, HUBS, PUNCT, Config
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
from topologist import Topologist, TopologistConfig
|
|
22
|
+
except Exception: # pragma: no cover
|
|
23
|
+
Topologist = None
|
|
24
|
+
TopologistConfig = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class Evidence:
    """One piece of support for a candidate continuation.

    The first five fields are required; the rest default to "plain word, no
    phrase, no relation/domain, empty path".
    """

    # Required fields.
    candidate: str
    source: str
    reason: str
    weight: float
    confidence: float

    # Optional descriptors.
    unit_type: str = "word"
    phrase_length: int = 0
    relation: str | None = None
    domain: str | None = None
    # A fresh list per instance, never shared between Evidence objects.
    path: list[str] = field(default_factory=list)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
class Prediction:
    """A scored candidate together with its explanation material.

    ``text``, ``unit_type`` and ``score`` are required. ``probability``
    defaults to ``0.0`` and each container field starts as its own empty
    container.
    """

    text: str
    unit_type: str
    score: float
    probability: float = 0.0

    # Explanation payload; default factories keep instances independent.
    reasons: list[str] = field(default_factory=list)
    breakdown: dict[str, float] = field(default_factory=dict)
    evidence: list[Evidence] = field(default_factory=list)
    paths: list[list[str]] = field(default_factory=list)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class FallbackTopo:
    """Minimal graph store exposing ``add_node`` / ``add_edge`` / ``update_global_state``.

    State lives in a single ``networkx.DiGraph`` available as ``self.graph``.
    """

    def __init__(self) -> None:
        # NOTE(review): a DiGraph keeps one edge per (source, target) pair, so
        # adding a second relation between the same two nodes replaces the
        # attributes of the first - confirm this is intended.
        self.graph = nx.DiGraph()

    def add_node(self, node: str, kind: str | None = None, **kwargs: Any) -> None:
        """Insert or update ``node``; ``kind`` is stored as an attribute when given."""
        attrs = {**kwargs}
        if kind is not None:
            attrs["kind"] = kind
        self.graph.add_node(node, **attrs)

    def add_edge(
        self,
        source: str,
        relation: str,
        target: str,
        confidence: float | None = None,
        metadata: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Insert or update the edge ``source -> target``.

        The edge carries two attributes: ``relation`` and ``metadata``. The
        metadata dict is a copy of ``metadata`` with ``confidence`` added when
        it is not ``None``, then overlaid with any extra keyword arguments.
        """
        attrs = dict(metadata) if metadata else {}
        if confidence is not None:
            attrs["confidence"] = confidence
        attrs.update(kwargs)
        self.graph.add_edge(source, target, relation=relation, metadata=attrs)

    def update_global_state(self) -> None:
        """No-op; there is no derived state to refresh."""
        return None
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class Tokenizer:
    """Regex tokenizer with rule-based part-of-speech, domain and kind tagging."""

    # Alternatives are tried left to right: boundary markers, CVE ids,
    # letter+digit identifiers, hyphenated words, slash pairs, plain words
    # (optionally with an apostrophe suffix), numbers, then punctuation.
    TOKEN_RE = re.compile(
        r"<bos>|<eos>|CVE-\d{4}-\d+|[A-Za-z]{1,10}\d+[A-Za-z0-9]*|"
        r"[A-Za-z]+(?:-[A-Za-z0-9]+)+|[A-Za-z]+/[A-Za-z]+|"
        r"[A-Za-z]+(?:'[A-Za-z]+)?|\d+(?:\.\d+)?|[.!?,;:]",
        re.I,
    )
    VERBS = {
        "sat", "slept", "likes", "like", "liked", "increases", "increase", "increased",
        "rises", "rise", "rose", "exploited", "exploit", "escalated", "escalate",
        "detected", "detect", "exposed", "expose", "inhibits", "inhibit", "induces",
        "metabolised", "metabolized", "causes", "cause", "requires", "require", "allows",
        "allow", "access", "used", "use", "may", "interacts", "explain", "summarise",
        "compare", "check", "store", "predicts", "uses", "rank",
    }
    PREPS = {
        "on", "in", "at", "by", "with", "from", "to", "of", "for", "over", "under",
        "through", "near", "when", "after", "before", "during",
    }
    DOMAINS = {
        "domestic": {
            "cat", "dog", "mat", "floor", "sofa", "fireplace", "garden", "warm",
            "places", "slept", "sat",
        },
        "cybersecurity": {
            "attacker", "exploit", "exploited", "service", "privileges", "privilege",
            "scanner", "endpoint", "admin", "panel", "vulnerable", "cve-2024-1234",
            "access", "escalation",
        },
        "drug_interaction": {
            "drug", "drug-drug", "interaction", "risk", "exposure", "inhibition",
            "cyp3a4", "clarithromycin", "simvastatin", "metabolised", "metabolized",
            "myopathy", "inhibits",
        },
        "lm_research": {
            "language", "model", "predicts", "probabilities", "graph", "memory",
            "topological", "attention", "retrieval", "generation", "context",
        },
    }
    ENTITY_CANONICAL = {
        "cyp3a4": "CYP3A4",
        "cve-2024-1234": "CVE-2024-1234",
        "tnf-alpha": "TNF-alpha",
        "il-6": "IL-6",
    }

    def tokenize(self, text: str) -> list[str]:
        """Return the lower-cased ``TOKEN_RE`` matches found in ``text``."""
        matches = self.TOKEN_RE.findall(text)
        return [match.lower() for match in matches if match.strip()]

    def units(self, text: str, keep_punct: bool = True) -> list[str]:
        """Tokenize ``text``; drop ``PUNCT`` members unless ``keep_punct`` is true."""
        toks = self.tokenize(text)
        if keep_punct:
            return toks
        return [tok for tok in toks if tok not in PUNCT]

    def sentences(self, text: str) -> list[str]:
        """Split on whitespace that follows ``.``, ``!`` or ``?``; skip empty pieces."""
        pieces = re.split(r"(?<=[.!?])\s+", text.strip())
        stripped = (piece.strip() for piece in pieces)
        return [piece for piece in stripped if piece]

    def phrases(self, units: list[str], lengths: tuple[int, ...]):
        """Yield ``(joined, chunk, start)`` for each n-gram, grouped by length.

        ``joined`` is the chunk's units joined with ``_``. Lengths longer than
        ``units`` produce nothing.
        """
        for size in lengths:
            # A non-positive stop value simply gives an empty range.
            for start in range(len(units) - size + 1):
                chunk = units[start : start + size]
                yield "_".join(chunk), chunk, start

    def pos(self, u: str) -> str:
        """Return a coarse part-of-speech tag for the unit ``u``.

        Closed-class lookups are checked first, in a fixed order, followed by
        the verb, entity and adjective heuristics; ``"noun"`` is the default.
        """
        closed_classes = (
            (BOUNDARY, "boundary"),
            (PUNCT, "punctuation"),
            ({"the", "a", "an", "this", "that", "these", "those"}, "determiner"),
            (self.PREPS, "preposition"),
            ({"and", "or", "but", "because", "while", "if"}, "conjunction"),
        )
        for members, tag in closed_classes:
            if u in members:
                return tag
        if u in self.VERBS or u.endswith(("ed", "ing")):
            return "verb"
        if re.fullmatch(r"[a-z]{1,10}\d+[a-z0-9]*", u) or u.startswith("cve-"):
            return "entity"
        if u.endswith(("ous", "ful", "able", "ive", "al", "ic", "y")):
            return "adjective"
        return "noun"

    def domain(self, units: list[str]) -> str:
        """Return the ``DOMAINS`` key sharing the most keywords with ``units``.

        Boundary markers and punctuation are ignored. ``"general"`` is returned
        when no domain keyword occurs at all.
        """
        content = {u for u in units if not (u in BOUNDARY or u in PUNCT)}
        overlap = Counter(
            {name: len(content & keywords) for name, keywords in self.DOMAINS.items()}
        )
        best = overlap.most_common(1)
        if not best or best[0][1] == 0:
            return "general"
        return best[0][0]

    def kind(self, unit: str) -> str:
        """Classify ``unit`` as ``boundary``, ``punctuation``, ``entity`` or ``word``."""
        tag = self.pos(unit)
        if tag in ("boundary", "punctuation", "entity"):
            return tag
        return "word"

    def restore_surface(self, unit: str) -> str:
        """Map a lower-cased entity to its canonical casing, or return it unchanged."""
        return self.ENTITY_CANONICAL.get(unit, unit)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class Corpus:
    """Sentence-level train / validation / test splitting."""

    def __init__(self):
        self.tokenizer = Tokenizer()

    def split(self, text: str, seed: int = 42, val: float = 0.15, test: float = 0.2):
        """Shuffle the sentences of ``text`` and return ``(train, val, test)``.

        The shuffle uses ``random.Random(seed)``. The test slice holds at least
        one sentence; the validation slice is empty for fewer than six
        sentences and otherwise holds at least one.
        """
        shuffled = self.tokenizer.sentences(text)
        random.Random(seed).shuffle(shuffled)

        total = len(shuffled)
        n_test = max(1, int(total * test))
        n_val = max(1, int(total * val)) if total >= 6 else 0

        # Layout after shuffling: [ test | validation | train ].
        val_end = n_test + n_val
        return shuffled[val_end:], shuffled[n_test:val_end], shuffled[:n_test]
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class HDC:
    """Bipolar (+1 / -1) hypervector helper with per-key deterministic vectors."""

    def __init__(self, dim: int = 1024, seed: int = 42):
        self.dim = dim
        self.seed = seed
        # key -> vector; filled lazily by ``get``.
        self.cache: dict[str, np.ndarray] = {}

    def _seed(self, key: str) -> int:
        """Derive a 64-bit RNG seed from ``self.seed`` and ``key`` via SHA-256."""
        import hashlib

        digest = hashlib.sha256(f"{self.seed}:{key}".encode()).digest()
        return int.from_bytes(digest[:8], "big")

    def get(self, key: str) -> np.ndarray:
        """Return the cached vector for ``key``, creating it on first use."""
        cached = self.cache.get(key)
        if cached is None:
            rng = np.random.default_rng(self._seed(key))
            cached = rng.choice(np.array([-1, 1], dtype=np.int8), size=self.dim)
            self.cache[key] = cached
        return cached

    def bind(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Element-wise product of ``a`` and ``b`` as ``int8``."""
        product = a * b
        return product.astype(np.int8)

    def bundle(self, vectors: list[np.ndarray]) -> np.ndarray:
        """Majority vote across ``vectors``; ties and empty input resolve to +1."""
        if not vectors:
            return np.ones(self.dim, dtype=np.int8)
        # Sum in int32 so that many int8 vectors cannot overflow.
        totals = np.sum(np.stack(vectors).astype(np.int32), axis=0)
        return np.where(totals >= 0, 1, -1).astype(np.int8)

    def encode(self, units: list[str], domain: str | None = None, layer: str = "local") -> np.ndarray:
        """Bundle position-bound unit vectors, plus a domain vector when given."""
        bound = []
        for index, unit in enumerate(units):
            position_vec = self.get(f"{layer}:pos:{index}")
            unit_vec = self.get(f"unit:{unit}")
            bound.append(self.bind(position_vec, unit_vec))
        if domain:
            bound.append(self.get(f"domain:{domain}"))
        return self.bundle(bound)

    def lexical(self, text: str) -> np.ndarray:
        """Bundle the character vectors of the first 64 characters of ``text``."""
        char_vectors = [self.get(f"char:{ch}") for ch in text[:64]]
        return self.bundle(char_vectors)

    def sim(self, a: np.ndarray, b: np.ndarray) -> float:
        """Dot product of ``a`` and ``b`` divided by ``self.dim``."""
        dot = np.dot(a.astype(np.float32), b.astype(np.float32))
        return float(dot / self.dim)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class Memory:
|
|
214
|
+
def __init__(self, cfg: Config):
|
|
215
|
+
self.cfg = cfg
|
|
216
|
+
self.tok = Tokenizer()
|
|
217
|
+
self.hdc = HDC(cfg.dim, cfg.seed)
|
|
218
|
+
self.topo = self._make_topo()
|
|
219
|
+
self.unit_counts = Counter()
|
|
220
|
+
self.phrase_counts = Counter()
|
|
221
|
+
self.edge_counts = Counter()
|
|
222
|
+
self.pos_counts = Counter()
|
|
223
|
+
self.domain_counts = defaultdict(Counter)
|
|
224
|
+
self.contexts: list[dict[str, Any]] = []
|
|
225
|
+
self.sentences: list[str] = []
|
|
226
|
+
self.feedback: list[dict[str, Any]] = []
|
|
227
|
+
self.sid = 0
|
|
228
|
+
try:
|
|
229
|
+
self.topo.add_contradiction_pair("safe_with", "contraindicated_with")
|
|
230
|
+
except Exception:
|
|
231
|
+
pass
|
|
232
|
+
|
|
233
|
+
def _make_topo(self):
|
|
234
|
+
if Topologist is not None:
|
|
235
|
+
try:
|
|
236
|
+
return Topologist(TopologistConfig(dim=self.cfg.dim, seed=self.cfg.seed))
|
|
237
|
+
except Exception:
|
|
238
|
+
pass
|
|
239
|
+
return FallbackTopo()
|
|
240
|
+
|
|
241
|
+
def _topo_graph(self):
|
|
242
|
+
return getattr(self.topo, "graph", None)
|
|
243
|
+
|
|
244
|
+
def _serialize_value(self, value: Any) -> Any:
|
|
245
|
+
if isinstance(value, np.ndarray):
|
|
246
|
+
return value.tolist()
|
|
247
|
+
if isinstance(value, (np.integer, np.floating)):
|
|
248
|
+
return value.item()
|
|
249
|
+
if isinstance(value, dict):
|
|
250
|
+
return {k: self._serialize_value(v) for k, v in value.items()}
|
|
251
|
+
if isinstance(value, (list, tuple)):
|
|
252
|
+
return [self._serialize_value(v) for v in value]
|
|
253
|
+
return value
|
|
254
|
+
|
|
255
|
+
def _serialize_graph(self) -> dict[str, list[dict[str, Any]]]:
|
|
256
|
+
g = self._topo_graph()
|
|
257
|
+
if g is None:
|
|
258
|
+
return {"nodes": [], "edges": []}
|
|
259
|
+
nodes = [{"node": n, "data": self._serialize_value(dict(d))} for n, d in g.nodes(data=True)]
|
|
260
|
+
edges = [
|
|
261
|
+
{
|
|
262
|
+
"source": u,
|
|
263
|
+
"relation": d.get("relation"),
|
|
264
|
+
"target": v,
|
|
265
|
+
"data": self._serialize_value(dict(d)),
|
|
266
|
+
}
|
|
267
|
+
for u, v, d in g.edges(data=True)
|
|
268
|
+
]
|
|
269
|
+
return {"nodes": nodes, "edges": edges}
|
|
270
|
+
|
|
271
|
+
def _add_node_direct(self, node: str, kind: str, meta: dict | None = None) -> None:
|
|
272
|
+
meta = meta or {}
|
|
273
|
+
try:
|
|
274
|
+
self.topo.add_node(node, kind=kind, **meta)
|
|
275
|
+
except TypeError:
|
|
276
|
+
self.topo.add_node(node, kind=kind, metadata=meta)
|
|
277
|
+
except Exception:
|
|
278
|
+
pass
|
|
279
|
+
|
|
280
|
+
def _add_edge_direct(
|
|
281
|
+
self,
|
|
282
|
+
s: str,
|
|
283
|
+
r: str,
|
|
284
|
+
t: str,
|
|
285
|
+
conf: float = 0.0,
|
|
286
|
+
evidence: list[str] | None = None,
|
|
287
|
+
meta: dict | None = None,
|
|
288
|
+
source_type: str = "corpus",
|
|
289
|
+
trust: float = 0.8,
|
|
290
|
+
) -> None:
|
|
291
|
+
meta = dict(meta or {})
|
|
292
|
+
if evidence is not None:
|
|
293
|
+
meta["evidence"] = evidence
|
|
294
|
+
try:
|
|
295
|
+
self.topo.add_edge(s, r, t, confidence=conf, source_type=source_type, evidence=evidence or [], trust_score=trust, **meta)
|
|
296
|
+
except TypeError:
|
|
297
|
+
self.topo.add_edge(s, r, t, confidence=conf, metadata=meta)
|
|
298
|
+
except Exception:
|
|
299
|
+
try:
|
|
300
|
+
self.topo.graph.add_edge(s, t, relation=r, confidence=conf, metadata=meta)
|
|
301
|
+
except Exception:
|
|
302
|
+
pass
|
|
303
|
+
|
|
304
|
+
def _rebuild_graph(self, graph_data: dict[str, list[dict[str, Any]]]) -> None:
|
|
305
|
+
self.topo = self._make_topo()
|
|
306
|
+
for node in graph_data.get("nodes", []):
|
|
307
|
+
self._add_node_direct(node["node"], node["data"].get("kind", "unknown"), node["data"])
|
|
308
|
+
for edge in graph_data.get("edges", []):
|
|
309
|
+
self._add_edge_direct(
|
|
310
|
+
edge["source"],
|
|
311
|
+
edge["relation"],
|
|
312
|
+
edge["target"],
|
|
313
|
+
conf=float(edge["data"].get("confidence", 0.0)),
|
|
314
|
+
evidence=edge["data"].get("evidence"),
|
|
315
|
+
meta=edge["data"].get("metadata") or edge["data"],
|
|
316
|
+
)
|
|
317
|
+
|
|
318
|
+
def _state_dict(self) -> dict[str, Any]:
|
|
319
|
+
return {
|
|
320
|
+
"sentences": self.sentences,
|
|
321
|
+
"contexts": [
|
|
322
|
+
{
|
|
323
|
+
"id": c["id"],
|
|
324
|
+
"sentence": c["sentence"],
|
|
325
|
+
"raw": c["raw"],
|
|
326
|
+
"units": c["units"],
|
|
327
|
+
"domain": c["domain"],
|
|
328
|
+
}
|
|
329
|
+
for c in self.contexts
|
|
330
|
+
],
|
|
331
|
+
"unit_counts": dict(self.unit_counts),
|
|
332
|
+
"phrase_counts": dict(self.phrase_counts),
|
|
333
|
+
"edge_counts": [
|
|
334
|
+
{"source": s, "relation": r, "target": t, "count": c}
|
|
335
|
+
for (s, r, t), c in self.edge_counts.items()
|
|
336
|
+
],
|
|
337
|
+
"pos_counts": [
|
|
338
|
+
{"prev": prev, "next": nxt, "count": c}
|
|
339
|
+
for (prev, nxt), c in self.pos_counts.items()
|
|
340
|
+
],
|
|
341
|
+
"domain_counts": {u: dict(c) for u, c in self.domain_counts.items()},
|
|
342
|
+
"graph": self._serialize_graph(),
|
|
343
|
+
}
|
|
344
|
+
|
|
345
|
+
def save_state(self, path: Path) -> None:
|
|
346
|
+
path = Path(path)
|
|
347
|
+
path.mkdir(parents=True, exist_ok=True)
|
|
348
|
+
(path / "memory.json").write_text(json.dumps(self._state_dict(), indent=2), encoding="utf-8")
|
|
349
|
+
|
|
350
|
+
def load_state(self, path: Path) -> None:
|
|
351
|
+
path = Path(path)
|
|
352
|
+
data = json.loads((path / "memory.json").read_text(encoding="utf-8"))
|
|
353
|
+
self.sentences = data.get("sentences", [])
|
|
354
|
+
self.contexts = [
|
|
355
|
+
{
|
|
356
|
+
**c,
|
|
357
|
+
"hv": self.hdc.encode(c["units"], c["domain"], "sentence"),
|
|
358
|
+
}
|
|
359
|
+
for c in data.get("contexts", [])
|
|
360
|
+
]
|
|
361
|
+
self.unit_counts = Counter(data.get("unit_counts", {}))
|
|
362
|
+
self.phrase_counts = Counter(data.get("phrase_counts", {}))
|
|
363
|
+
self.edge_counts = Counter({(item["source"], item["relation"], item["target"]): item["count"] for item in data.get("edge_counts", [])})
|
|
364
|
+
self.pos_counts = Counter({(item["prev"], item["next"]): item["count"] for item in data.get("pos_counts", [])})
|
|
365
|
+
self.domain_counts = defaultdict(Counter, {u: Counter(c) for u, c in data.get("domain_counts", {}).items()})
|
|
366
|
+
self._rebuild_graph(data.get("graph", {}))
|
|
367
|
+
|
|
368
|
+
def nunit(self, u: str) -> str: return f"unit:{u}"
|
|
369
|
+
def nphrase(self, p: str) -> str: return f"phrase:{p}"
|
|
370
|
+
def npos(self, p: str) -> str: return f"pos:{p}"
|
|
371
|
+
def ndom(self, d: str) -> str: return f"domain:{d}"
|
|
372
|
+
|
|
373
|
+
def add_node(self, node: str, kind: str, meta: dict | None = None) -> None:
|
|
374
|
+
meta = meta or {}
|
|
375
|
+
try:
|
|
376
|
+
self.topo.add_node(node, kind=kind, **meta)
|
|
377
|
+
except TypeError:
|
|
378
|
+
self.topo.add_node(node, kind=kind, metadata=meta)
|
|
379
|
+
except Exception:
|
|
380
|
+
pass
|
|
381
|
+
|
|
382
|
+
def add_edge(self, s: str, r: str, t: str, conf: float = .5, evidence: list[str] | None = None, meta: dict | None = None, source_type: str = "corpus", trust: float = .8) -> None:
|
|
383
|
+
key = (s, r, t); self.edge_counts[key] += 1
|
|
384
|
+
freq = self.edge_counts[key]
|
|
385
|
+
meta = dict(meta or {}); meta.update({"frequency": freq, "last_seen": time.time(), "domain": meta.get("domain", "language_topology")})
|
|
386
|
+
conf = max(conf, min(1.0, math.log1p(freq) / 5))
|
|
387
|
+
try:
|
|
388
|
+
self.topo.add_edge(s, r, t, confidence=conf, source_type=source_type, evidence=evidence or [], trust_score=trust, **meta)
|
|
389
|
+
except TypeError:
|
|
390
|
+
self.topo.add_edge(s, r, t, confidence=conf, metadata=meta)
|
|
391
|
+
except Exception:
|
|
392
|
+
try: self.topo.graph.add_edge(s, t, relation=r, confidence=conf, metadata=meta)
|
|
393
|
+
except Exception: pass
|
|
394
|
+
|
|
395
|
+
def primary_domain(self, u: str) -> str:
|
|
396
|
+
return self.domain_counts[u].most_common(1)[0][0] if self.domain_counts.get(u) else "general"
|
|
397
|
+
|
|
398
|
+
def add_sentence(self, sentence: str) -> None:
|
|
399
|
+
raw = self.tok.units(sentence, True)
|
|
400
|
+
if not raw: return
|
|
401
|
+
dom = self.tok.domain(raw)
|
|
402
|
+
units = ["<bos>"] + raw + ["<eos>"]
|
|
403
|
+
self.sentences.append(sentence)
|
|
404
|
+
sid = self.sid; self.sid += 1
|
|
405
|
+
self.contexts.append({"id": sid, "sentence": sentence, "raw": raw, "units": units, "domain": dom, "hv": self.hdc.encode(units, dom, "sentence")})
|
|
406
|
+
self.add_node(f"sent:{sid}", "sentence", {"text": sentence, "domain": dom})
|
|
407
|
+
self.add_node(self.ndom(dom), "domain", {"domain": dom})
|
|
408
|
+
for i, u in enumerate(units):
|
|
409
|
+
pos = self.tok.pos(u); self.unit_counts[u] += 1; self.domain_counts[u][dom] += 1
|
|
410
|
+
self.add_node(self.nunit(u), "unit", {"text": u, "pos": pos, "kind": self.tok.kind(u), "frequency": self.unit_counts[u], "domain": self.primary_domain(u)})
|
|
411
|
+
self.add_node(self.npos(pos), "pos", {"pos": pos})
|
|
412
|
+
self.add_edge(self.nunit(u), "has_pos", self.npos(pos), .9, [sentence], {"position": i, "domain": dom})
|
|
413
|
+
self.add_edge(self.nunit(u), "domain_related", self.ndom(dom), .8, [sentence], {"position": i, "domain": dom})
|
|
414
|
+
for i in range(len(units) - 1):
|
|
415
|
+
a, b = units[i], units[i + 1]
|
|
416
|
+
pa, pb = self.tok.pos(a), self.tok.pos(b); self.pos_counts[(pa, pb)] += 1
|
|
417
|
+
self.add_edge(self.nunit(a), "next_unit", self.nunit(b), .58, [sentence], {"position": i, "domain": dom})
|
|
418
|
+
self.add_edge(self.npos(pa), "pos_transition", self.npos(pb), .65, [sentence], {"domain": dom})
|
|
419
|
+
for i, a in enumerate(units):
|
|
420
|
+
for j in range(i + 1, min(len(units), i + self.cfg.window + 1)):
|
|
421
|
+
self.add_edge(self.nunit(a), "appears_near", self.nunit(units[j]), max(.08, 1 / (j - i + 1)), [sentence], {"distance": j - i, "domain": dom})
|
|
422
|
+
for phrase, chunk, start in self.tok.phrases(units, self.cfg.phrase_lengths):
|
|
423
|
+
self.phrase_counts[phrase] += 1
|
|
424
|
+
self.add_node(self.nphrase(phrase), "phrase", {"text": phrase, "units": chunk, "frequency": self.phrase_counts[phrase], "domain": dom})
|
|
425
|
+
if start + len(chunk) < len(units):
|
|
426
|
+
nxt = units[start + len(chunk)]
|
|
427
|
+
self.add_edge(self.nphrase(phrase), "likely_next", self.nunit(nxt), .68, [sentence], {"phrase_length": len(chunk), "domain": dom})
|
|
428
|
+
|
|
429
|
+
def fit(self, text: str) -> None:
|
|
430
|
+
for s in self.tok.sentences(text):
|
|
431
|
+
self.add_sentence(s)
|
|
432
|
+
try: self.topo.update_global_state()
|
|
433
|
+
except Exception: pass
|
|
434
|
+
|
|
435
|
+
def compact(self, min_edge_frequency: int = 2) -> dict[str, int]:
    """Remove ``appears_near`` edges seen fewer than *min_edge_frequency* times.

    An edge's frequency is read from its ``frequency`` attribute, then from
    ``metadata["frequency"]``, and defaults to 1. Returns the number of
    removed edges and the number of edges left in the graph.
    """
    graph = self.topo.graph
    dropped = 0
    # Snapshot the edges first: the graph is modified while we walk them.
    for src, dst, attrs in list(graph.edges(data=True)):
        raw_meta = attrs.get("metadata", {})
        meta = raw_meta if isinstance(raw_meta, dict) else {}
        count = attrs.get("frequency", meta.get("frequency", 1))
        if attrs.get("relation") != "appears_near":
            continue
        if int(count) >= min_edge_frequency:
            continue
        try:
            graph.remove_edge(src, dst)
            dropped += 1
        except Exception:
            pass
    return {"removed_edges": dropped, "remaining_edges": graph.number_of_edges()}
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
class ContextIndex:
    """Dense similarity index over the context records held by a memory object."""

    def __init__(self, mem: Memory):
        self.mem = mem
        self.matrix = None  # built on demand by the first search()
        self.items = []

    def build(self):
        """Snapshot ``mem.contexts`` and stack each record's ``hv`` into an int8 matrix."""
        self.items = list(self.mem.contexts)
        if self.items:
            stacked = np.stack([record["hv"] for record in self.items])
            self.matrix = stacked.astype(np.int8)
        else:
            # No contexts yet: keep a (0, dim) matrix so search() can test its shape.
            self.matrix = np.zeros((0, self.mem.cfg.dim), dtype=np.int8)

    def search(self, context: str, top_k: int = 5):
        """Return up to *top_k* stored contexts ranked by similarity to *context*.

        Similarity is the dot product of each stored vector with the encoded
        query, divided by ``cfg.dim``.
        """
        if self.matrix is None:
            self.build()
        memory = self.mem
        raw = memory.tok.units(context, True)
        dom = memory.tok.domain(raw)
        query = memory.hdc.encode(["<bos>"] + raw, dom, "sentence")
        if self.matrix.shape[0] == 0:
            return []
        sims = (self.matrix.astype(np.float32) @ query.astype(np.float32)) / memory.cfg.dim
        hits = []
        for i in np.argsort(-sims)[:top_k]:
            record = self.items[i]
            hits.append(
                {
                    "similarity": float(sims[i]),
                    "sentence": record["sentence"],
                    "domain": record["domain"],
                    "units": record["raw"],
                }
            )
        return hits
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
class TopoLM:
|
|
462
|
+
def __init__(self, cfg: Config | None = None):
    """Create the memory, its tokenizer alias, the context index and an empty prediction cache."""
    self.cfg = cfg or Config()
    memory = Memory(self.cfg)
    self.mem = memory
    self.tok = memory.tok
    self.index = ContextIndex(memory)
    # Prediction cache plus the insertion order used for eviction.
    self._cache: dict[tuple, Any] = {}
    self._cache_order: list[tuple] = []
|
|
469
|
+
|
|
470
|
+
def fit(self, text: str):
    """Fit the memory on *text*, rebuild the context index and return ``self``.

    Cached predictions are discarded because they were computed against the
    memory as it was before this call.
    """
    self.mem.fit(text)
    self.index.build()
    # Without this, predict()/distribution() keep serving pre-fit results.
    self._cache.clear()
    self._cache_order.clear()
    return self
|
|
472
|
+
|
|
473
|
+
def context_units(self, text: str):
    """Tokenise *text* and return ``(units prefixed with "<bos>", domain)``."""
    tokens = self.tok.units(text, True)
    prefixed = ["<bos>"] + tokens
    # The domain is derived from the raw tokens, without the "<bos>" marker.
    return prefixed, self.tok.domain(tokens)
|
|
475
|
+
|
|
476
|
+
def _unit_from(self, node: str) -> str | None:
|
|
477
|
+
return node.split("unit:", 1)[1] if str(node).startswith("unit:") else None
|
|
478
|
+
|
|
479
|
+
def _out(self, source: str, rel: str | None = None):
|
|
480
|
+
g = self.mem.topo.graph
|
|
481
|
+
if source not in g: return []
|
|
482
|
+
return [(source, t, d) for _, t, d in g.out_edges(source, data=True) if rel is None or d.get("relation") == rel]
|
|
483
|
+
|
|
484
|
+
def _cache_get(self, key): return self._cache.get(key)
|
|
485
|
+
def _cache_set(self, key, val):
|
|
486
|
+
if key not in self._cache: self._cache_order.append(key)
|
|
487
|
+
self._cache[key] = val
|
|
488
|
+
if len(self._cache_order) > self.cfg.prediction_cache_max:
|
|
489
|
+
old = self._cache_order.pop(0); self._cache.pop(old, None)
|
|
490
|
+
|
|
491
|
+
def retrieve_candidates(self, units: list[str], domain: str, context_text: str = "") -> dict[str, list[Evidence]]:
    """Gather evidence for possible next units and keep the strongest candidates.

    Evidence is collected from seven sources, tagged "phrase", "direct",
    "rag", "domain_prior", "copy", "near" and "unigram". It is grouped by
    candidate unit ("<bos>" is dropped) and ranked by the best single
    ``weight * confidence`` of each candidate, with the unit's corpus count
    as tie-breaker. At most ``cfg.inference_candidates`` candidates are
    returned, in ranked order.

    NOTE(review): block nesting was reconstructed from a listing with
    stripped indentation — confirm the retrieval and copy loops sit at
    function level in the packaged file.
    """
    evs: list[Evidence] = []
    max_candidates = self.cfg.inference_candidates
    # Weight per matched phrase length; other lengths use .25 via weights.get below.
    weights = {5: 1.35, 4: 1.10, 3: .82, 2: .48}
    # Phrase continuations: try the longest configured context suffix first.
    for n in sorted(self.cfg.phrase_lengths, reverse=True):
        if len(units) >= n:
            phrase = "_".join(units[-n:])
            for _, t, d in self._out(self.mem.nphrase(phrase), "likely_next"):
                u = self._unit_from(t)
                if u: evs.append(Evidence(u, "phrase", f"exact {n}-unit continuation from {phrase}", weights.get(n, .25), float(d.get("confidence", .68)), self.tok.kind(u), n, "likely_next", d.get("domain") or (d.get("metadata", {}) or {}).get("domain"), [self.mem.nphrase(phrase), "likely_next", t]))
    # Direct successors of the last context unit.
    if units:
        last = self.mem.nunit(units[-1])
        for _, t, d in self._out(last, "next_unit"):
            u = self._unit_from(t)
            if u: evs.append(Evidence(u, "direct", f"direct next_unit from {units[-1]}", .36, float(d.get("confidence", .58)), self.tok.kind(u), 0, "next_unit", d.get("domain") or (d.get("metadata", {}) or {}).get("domain"), [last, "next_unit", t]))
    # Retrieved contexts: for each of the 5 most similar stored sentences, find
    # the longest suffix of `units` that occurs in it and propose the unit that
    # follows its first occurrence.
    for item in self.index.search(context_text or " ".join(units), 5):
        su = ["<bos>"] + item["units"] + ["<eos>"]
        for n in range(min(len(units), len(su)), 0, -1):
            suffix = units[-n:]
            found = False
            for i in range(0, len(su) - n):
                if su[i:i+n] == suffix and i+n < len(su):
                    nxt = su[i+n]
                    # Similarity is mapped to a confidence with (similarity + 1) / 2.
                    evs.append(Evidence(nxt, "rag", f"retrieved context {item['similarity']:.3f}: {item['sentence']}", .18, float((item["similarity"] + 1) / 2), self.tok.kind(nxt), n, None, item["domain"], ["retrieved_context", "next", f"unit:{nxt}"]))
                    found = True; break
            if found: break
    # Domain prior: among the 120 most frequent units, those whose primary
    # domain equals the requested one (skipped for the "general" domain).
    if domain != "general":
        for u, _ in self.mem.unit_counts.most_common(120):
            if self.mem.primary_domain(u) == domain:
                evs.append(Evidence(u, "domain_prior", f"domain prior {domain}", .22, .25, self.tok.kind(u), domain=domain))
    # Copy: context units that the tokenizer tags as "entity".
    for u in units:
        if self.tok.pos(u) == "entity": evs.append(Evidence(u, "copy", f"copy entity from context {u}", .20, .7, "entity", domain=domain))
    # Near: units linked by appears_near to any of the last cfg.window context units.
    for u in units[-self.cfg.window:]:
        for _, t, d in self._out(self.mem.nunit(u), "appears_near"):
            cand = self._unit_from(t)
            if cand: evs.append(Evidence(cand, "near", f"appears_near {u}", .06, float(d.get("confidence", .2)), self.tok.kind(cand), relation="appears_near", domain=d.get("domain") or (d.get("metadata", {}) or {}).get("domain")))
    # Unigram fallback: the 80 most frequent units, confidence = relative frequency.
    total = max(1, sum(self.mem.unit_counts.values()))
    for u, c in self.mem.unit_counts.most_common(80):
        if u != "<bos>": evs.append(Evidence(u, "unigram", f"unigram count={c}", .025, c / total, self.tok.kind(u), domain=self.mem.primary_domain(u)))
    # Group evidence per candidate, preserving the order in which it was found.
    merged = defaultdict(list)
    for e in evs:
        if e.candidate != "<bos>": merged[e.candidate].append(e)
    ordered = sorted(merged, key=lambda u: (max((e.weight * e.confidence for e in merged[u]), default=0), self.mem.unit_counts.get(u, 0)), reverse=True)
    return {u: merged[u] for u in ordered[:max_candidates]}
|
|
536
|
+
|
|
537
|
+
def _allowed_pos(self, prev_pos: str, cand: str, cand_pos: str, raw_len: int) -> bool:
|
|
538
|
+
if cand == "<eos>": return raw_len >= 4
|
|
539
|
+
if cand in PUNCT: return raw_len >= 3 and prev_pos not in {"boundary", "determiner", "preposition"}
|
|
540
|
+
table = {
|
|
541
|
+
"boundary": {"determiner", "noun", "entity", "adjective"},
|
|
542
|
+
"determiner": {"noun", "entity", "adjective"},
|
|
543
|
+
"preposition": {"determiner", "noun", "entity", "adjective"},
|
|
544
|
+
"adjective": {"noun", "entity", "adjective"},
|
|
545
|
+
"verb": {"determiner", "noun", "entity", "adverb", "preposition", "adjective"},
|
|
546
|
+
"noun": {"verb", "preposition", "conjunction", "punctuation", "boundary", "noun", "entity"},
|
|
547
|
+
"entity": {"verb", "preposition", "conjunction", "punctuation", "boundary", "noun", "entity"},
|
|
548
|
+
"punctuation": {"boundary", "determiner", "noun", "entity", "adjective"},
|
|
549
|
+
}
|
|
550
|
+
return cand_pos in table.get(prev_pos, {cand_pos})
|
|
551
|
+
|
|
552
|
+
def _repetition_penalty(self, units: list[str], cand: str) -> float:
|
|
553
|
+
if cand == "<eos>" or cand in PUNCT: return 0.0
|
|
554
|
+
recent = units[-10:]; penalty = 0.18 * recent.count(cand)
|
|
555
|
+
if units and units[-1] == cand: penalty += 0.65
|
|
556
|
+
if len(units) >= 3 and tuple(units[-2:] + [cand]) in {tuple(units[i:i+3]) for i in range(max(0, len(units)-20), max(0, len(units)-2))}: penalty += 0.55
|
|
557
|
+
return penalty
|
|
558
|
+
|
|
559
|
+
def _domain_penalty(self, cand: str, domain: str) -> float:
|
|
560
|
+
if domain == "general" or cand == "<eos>" or cand in PUNCT: return 0.0
|
|
561
|
+
cd = self.mem.primary_domain(cand)
|
|
562
|
+
if cd == domain: return 0.0
|
|
563
|
+
if cd == "general": return 0.08
|
|
564
|
+
return 0.32
|
|
565
|
+
|
|
566
|
+
def _boundary_adj(self, units: list[str], cand: str) -> float:
|
|
567
|
+
raw_len = len([u for u in units if u not in BOUNDARY]); last = units[-1] if units else ""
|
|
568
|
+
if cand == "<eos>":
|
|
569
|
+
if raw_len < 5: return -0.80
|
|
570
|
+
if last in {".", "?", "!"}: return 0.70
|
|
571
|
+
if raw_len >= 14: return 0.40
|
|
572
|
+
return 0.05
|
|
573
|
+
if cand in {".", "?", "!"}:
|
|
574
|
+
return 0.25 if raw_len >= 10 else (-0.45 if raw_len < 5 else 0.0)
|
|
575
|
+
if cand == "," and (raw_len < 4 or last in PUNCT): return -0.35
|
|
576
|
+
return 0.0
|
|
577
|
+
|
|
578
|
+
def score_candidate(self, units: list[str], cand: str, evs: list[Evidence], domain: str) -> Prediction:
    """Combine the evidence for *cand* into one score and wrap it in a ``Prediction``.

    The score is a weighted sum of evidence mass, best phrase evidence,
    observed next-unit frequency, direct-edge confidence, POS-transition
    share and domain agreement. Repetition, domain, grammar and hub
    penalties are then subtracted and the boundary adjustment is added.
    ``probability`` is left at 0.0.
    """
    # Loop-invariant lookups, computed once. The previous POS was formerly
    # recomputed for every entry of pos_counts inside the sum below.
    prev_pos = self.tok.pos(units[-1]) if units else "boundary"
    cand_pos = self.tok.pos(cand)
    cand_domain = self.mem.primary_domain(cand)

    evidence = min(1.0, sum(e.weight * e.confidence for e in evs))
    phrase = min(1.0, max((e.weight * e.confidence for e in evs if e.source == "phrase"), default=0.0))
    direct = max((e.confidence for e in evs if e.source == "direct"), default=0.0)
    freq = self.mem.edge_counts.get((self.mem.nunit(units[-1]), "next_unit", self.mem.nunit(cand)), 0) if units else 0

    # Share of the transitions out of prev_pos that lead to cand_pos.
    pos_count = self.mem.pos_counts.get((prev_pos, cand_pos), 0)
    pos_total = sum(v for (p, _), v in self.mem.pos_counts.items() if p == prev_pos) or 1
    pos = pos_count / pos_total

    if cand_domain == domain or domain == "general" or cand in PUNCT or cand == "<eos>":
        dom_score = 1.0
    elif cand_domain == "general":
        dom_score = 0.35
    else:
        dom_score = 0.0

    raw_len = len([u for u in units if u not in BOUNDARY])
    grammar_ok = self._allowed_pos(prev_pos, cand, cand_pos, raw_len)

    score = 0.18*evidence + 0.22*phrase + 0.08*min(1, freq) + 0.08*direct + 0.08*pos + 0.10*dom_score
    score -= self._repetition_penalty(units, cand)
    score -= self._domain_penalty(cand, domain)
    score += self._boundary_adj(units, cand)
    if not grammar_ok:
        score -= 0.38
    if cand in HUBS and len(units) > 2:
        score -= self.cfg.hub_penalty

    breakdown = {"evidence": evidence, "phrase": phrase, "direct": direct, "freq": float(freq), "pos": pos, "domain": dom_score, "grammar_ok": float(grammar_ok)}
    return Prediction(cand, self.tok.kind(cand), float(score), 0.0, [e.reason for e in evs], breakdown, evs, [e.path for e in evs if e.path][:5])
|
|
597
|
+
|
|
598
|
+
def predict(self, context: str, top_k: int = 10, lock_domain: bool = True) -> list[Prediction]:
    """Return the *top_k* highest-scoring next-unit predictions for *context*.

    With ``lock_domain`` the domain detected from the context is used for
    retrieval and scoring, otherwise "general". Results are cached per
    ``(context, top_k, domain)``.
    """
    units, dom = self.context_units(context)
    use_dom = dom if lock_domain else "general"
    key = ("predict", context, top_k, use_dom)
    cached = self._cache_get(key)
    if cached is not None:
        return cached
    candidates = self.retrieve_candidates(units, use_dom, context)
    scored = []
    for cand, evidence in candidates.items():
        if cand == "<bos>":
            continue
        scored.append(self.score_candidate(units, cand, evidence, use_dom))
    # Stable descending sort: ties keep retrieval order.
    ranked = sorted(scored, key=lambda p: p.score, reverse=True)
    best = ranked[:top_k]
    self._cache_set(key, best)
    return best
|
|
606
|
+
|
|
607
|
+
def distribution(self, context: str, top_k: int = 20, temp: float | None = None) -> list[Prediction]:
    """Return the top predictions for *context* with softmax probabilities attached.

    Scores from ``predict`` are divided by the temperature (floored at 1e-6)
    and normalised with a max-shifted softmax. ``temp=None`` selects
    ``cfg.temperature``; an explicit ``0`` is honoured and gives a nearly
    one-hot distribution. Results are cached per
    ``(context, top_k, temperature)``.

    Probabilities are written to shallow copies of the predictions. The
    objects returned by ``predict`` are shared with its cache, so mutating
    them let a later call at another temperature overwrite the
    probabilities of an earlier cached distribution.
    """
    import copy

    if temp is None:
        temp = self.cfg.temperature
    key = ("dist", context, top_k, round(temp, 4))
    if (cached := self._cache_get(key)) is not None:
        return cached
    preds = self.predict(context, top_k)
    if not preds:
        return []
    scores = np.array([p.score for p in preds], dtype=float) / max(temp, 1e-6)
    # Subtract the max before exponentiating for numerical stability.
    probs = np.exp(scores - np.max(scores))
    probs = probs / probs.sum()
    out = []
    for pred, prob in zip(preds, probs):
        own = copy.copy(pred)
        own.probability = float(prob)
        out.append(own)
    self._cache_set(key, out)
    return out
|
|
616
|
+
|
|
617
|
+
def _repeated_ngram(self, units: list[str], cand: str, n: int = 3) -> bool:
|
|
618
|
+
if len(units) < n: return False
|
|
619
|
+
new = units + [cand]
|
|
620
|
+
return tuple(new[-n:]) in {tuple(new[i:i+n]) for i in range(0, len(new)-n)}
|
|
621
|
+
|
|
622
|
+
def _phrase_tail(self, units: list[str], max_tail: int = 3) -> list[str]:
|
|
623
|
+
best, best_count = [], 0
|
|
624
|
+
for phrase, count in self.mem.phrase_counts.most_common(300):
|
|
625
|
+
parts = phrase.split("_")
|
|
626
|
+
if len(parts) < 3: continue
|
|
627
|
+
for n in range(min(4, len(parts)-1, len(units)), 0, -1):
|
|
628
|
+
if units[-n:] == parts[:n]:
|
|
629
|
+
tail = [p for p in parts[n:n+max_tail] if p != "<bos>"]
|
|
630
|
+
if tail and count > best_count: best, best_count = tail, count
|
|
631
|
+
break
|
|
632
|
+
return best
|
|
633
|
+
|
|
634
|
+
def detok(self, units: list[str]) -> str:
    """Join units back into text, gluing punctuation to the previous word and capitalising the first letter."""
    pieces = []
    for unit in units:
        if unit in BOUNDARY:
            continue
        surface = self.tok.restore_surface(unit)
        if unit in PUNCT and pieces:
            # Punctuation attaches to the preceding piece without a space.
            pieces[-1] += surface
        else:
            pieces.append(surface)
    joined = " ".join(pieces)
    if not joined:
        return joined
    return joined[:1].upper() + joined[1:]
|
|
643
|
+
|
|
644
|
+
def _surface_realise(self, text: str) -> str:
|
|
645
|
+
# Domain-specific readability rewrite layer.
|
|
646
|
+
rewrites = {
|
|
647
|
+
"Clarithromycin inhibits CYP3A4 inhibition increases drug exposure.": "Clarithromycin inhibits CYP3A4, which can increase drug exposure.",
|
|
648
|
+
"CYP3A4 inhibition increases drug exposure.": "CYP3A4 inhibition increases drug exposure.",
|
|
649
|
+
"Clarithio mycin": "Clarithromycin",
|
|
650
|
+
}
|
|
651
|
+
return rewrites.get(text, text)
|
|
652
|
+
|
|
653
|
+
def _nucleus(self, preds: list[Prediction], top_p: float) -> str:
|
|
654
|
+
selected, total = [], 0.0
|
|
655
|
+
for p in sorted(preds, key=lambda p: p.probability, reverse=True):
|
|
656
|
+
selected.append(p); total += p.probability
|
|
657
|
+
if total >= top_p: break
|
|
658
|
+
probs = np.array([max(p.probability, 1e-9) for p in selected], dtype=float); probs = probs / probs.sum()
|
|
659
|
+
return str(np.random.choice([p.text for p in selected], p=probs))
|
|
660
|
+
|
|
661
|
+
def generate(self, prompt: str, max_units: int = 40, top_k: int = 10, temp: float = .75, decoding: str = "nucleus", top_p: float | None = None, beam_width: int | None = None, phrase_decode: bool = True) -> str:
    """Extend *prompt* unit by unit and return the detokenised text.

    ``decoding="beam"`` delegates to ``generate_beam``; ``"greedy"`` takes
    the top filtered prediction; any other value samples with ``_nucleus``.
    With ``phrase_decode`` a known phrase continuation is appended instead
    of sampling whenever the context runs into one.

    Generation stops after *max_units* steps, when the runtime budget
    ``cfg.max_runtime_seconds`` is exceeded, once a sentence terminator ends
    a context of at least 6 units, when the same trailing window has been
    seen more than twice, when "<eos>" is chosen, or when the model offers
    no candidates at all. The last case used to raise.
    """
    top_p = top_p or self.cfg.default_top_p
    beam_width = beam_width or self.cfg.default_beam_width
    if decoding == "beam":
        return self.generate_beam(prompt, max_units, top_k, temp, beam_width)

    terminators = {".", "?", "!"}
    start = time.time()
    units = self.tok.units(prompt, True)
    states = Counter()
    for _ in range(max_units):
        if time.time() - start > self.cfg.max_runtime_seconds:
            break
        if units and units[-1] in terminators and len(units) >= 6:
            break
        # Loop guard: give up once the same trailing window recurs.
        state = tuple(units[-self.cfg.window:])
        states[state] += 1
        if states[state] > 2:
            break

        if phrase_decode and units:
            tail = self._phrase_tail(["<bos>"] + units if units[0] != "<bos>" else units, 2)
            if tail:
                for t in tail:
                    if t == "<eos>":
                        return self._surface_realise(self.detok(units))
                    if t not in PUNCT and self._repeated_ngram(units, t):
                        continue
                    if t not in units[-2:]:
                        units.append(t)
                    if units and units[-1] in terminators:
                        break
                # The phrase tail replaces sampling for this step.
                continue

        dist = self.distribution(" ".join(units), top_k * 3, temp)
        filt = [
            p for p in dist
            if p.text != "<bos>"
            and (
                p.text == "<eos>"
                or (
                    not self._repeated_ngram(units, p.text)
                    and (p.text in PUNCT or units[-8:].count(p.text) < 2)
                )
            )
        ]
        # If filtering removed everything, fall back to the unfiltered head.
        filt = filt[:top_k] if filt else dist[:top_k]
        if not filt:
            # No candidates (e.g. an unfitted model): stop cleanly, as
            # generate_beam does, instead of indexing or sampling an empty list.
            break
        chosen = filt[0].text if decoding == "greedy" else self._nucleus(filt, top_p)
        if chosen == "<eos>":
            break
        units.append(chosen)
    return self._surface_realise(self.detok(units))
|
|
686
|
+
|
|
687
|
+
def generate_beam(self, prompt: str, max_units: int = 32, top_k: int = 8, temp: float = .75, beam_width: int = 4) -> str:
    """Generate text with beam search ranked by length-normalised log-probability.

    A beam is finished when "<eos>" is chosen, when a sentence terminator
    ends a context of at least 6 units, or when the model offers no
    continuation. Candidates that would repeat an n-gram are not expanded.
    """
    terminators = {".", "?", "!"}

    def normalised(beam):
        tokens, log_prob, _ = beam
        return log_prob / max(1, len(tokens))

    beams = [(self.tok.units(prompt, True), 0.0, False)]
    for _ in range(max_units):
        expanded = []
        for tokens, log_prob, finished in beams:
            ends_sentence = tokens and tokens[-1] in terminators and len(tokens) >= 6
            if finished or ends_sentence:
                expanded.append((tokens, log_prob, True))
                continue
            dist = self.distribution(" ".join(tokens), top_k, temp)
            if not dist:
                expanded.append((tokens, log_prob, True))
                continue
            for pred in dist[:top_k]:
                if pred.text == "<bos>":
                    continue
                is_eos = pred.text == "<eos>"
                if not is_eos and self._repeated_ngram(tokens, pred.text):
                    continue
                # "<eos>" closes the beam without being appended.
                grown = list(tokens) if is_eos else list(tokens) + [pred.text]
                step = math.log(max(pred.probability, 1e-9))
                expanded.append((grown, log_prob + step, is_eos))
        if not expanded:
            break
        beams = sorted(expanded, key=normalised, reverse=True)[:beam_width]
        if all(finished for _, _, finished in beams):
            break
    winner = max(beams, key=normalised)[0]
    return self._surface_realise(self.detok(winner))
|
|
707
|
+
|
|
708
|
+
def explain(self, context: str, candidate: str) -> dict[str, Any]:
    """Score *candidate* as the next unit after *context* and return the full breakdown.

    The candidate is scored with whatever evidence retrieval found for it,
    which is an empty list when it was not retrieved at all.
    """
    units, dom = self.context_units(context)
    retrieved = self.retrieve_candidates(units, dom, context)
    evs = retrieved.get(candidate, [])
    pred = self.score_candidate(units, candidate, evs, dom)
    report = {
        "context": context,
        "domain": dom,
        "candidate": candidate,
        "score": pred.score,
        "breakdown": pred.breakdown,
    }
    report["evidence"] = [asdict(e) for e in pred.evidence]
    report["paths"] = pred.paths
    return report
|
|
712
|
+
|
|
713
|
+
def generation_metrics(self, text: str) -> dict[str, float]:
    """Return token count, repeat rate and distinct-1 / distinct-2 ratios for *text*."""
    toks = self.tok.units(text, True)
    if not toks:
        return {"tokens": 0, "repeat_rate": 0.0, "distinct_1": 0.0, "distinct_2": 0.0}
    bigrams = list(zip(toks, toks[1:]))
    unique_share = len(set(toks)) / max(1, len(toks))
    metrics = {"tokens": len(toks)}
    metrics["repeat_rate"] = 1 - unique_share
    metrics["distinct_1"] = unique_share
    # max(1, ...) keeps a single-token text from dividing by zero.
    metrics["distinct_2"] = len(set(bigrams)) / max(1, len(bigrams))
    return metrics
|
|
718
|
+
|
|
719
|
+
def compact(self):
    """Compact the memory graph and return the memory's removal report.

    Cached predictions are discarded, since they may rest on edges that the
    compaction has just removed.
    """
    report = self.mem.compact()
    self._cache.clear()
    self._cache_order.clear()
    return report
|
|
720
|
+
|
|
721
|
+
def fit_texts(self, texts: list[str]):
    """Fit on several texts at once (joined with newlines) and return ``self``.

    Cached predictions are discarded because they were computed against the
    memory as it was before this call.
    """
    self.mem.fit("\n".join(texts))
    self.index.build()
    # Without this, predict()/distribution() keep serving pre-fit results.
    self._cache.clear()
    self._cache_order.clear()
    return self
|
|
725
|
+
|
|
726
|
+
def fit_dataset(
    self,
    dataset_name: str,
    split: str = "train",
    text_field: str | None = None,
    sample_size: int | None = None,
    shuffle_seed: int = 42,
):
    """Load texts with ``load_hf_dataset`` and fit on them; returns what ``fit_texts`` returns."""
    # Imported here so the datasets helper is only touched when this feature is used.
    from .datasets import load_hf_dataset

    options = {
        "split": split,
        "text_field": text_field,
        "sample_size": sample_size,
        "shuffle_seed": shuffle_seed,
    }
    return self.fit_texts(load_hf_dataset(dataset_name, **options))
|
|
744
|
+
|
|
745
|
+
@classmethod
def from_dataset(
    cls,
    dataset_name: str,
    split: str = "train",
    text_field: str | None = None,
    sample_size: int | None = None,
    shuffle_seed: int = 42,
    cfg: Config | None = None,
):
    """Alternate constructor: build a model from *cfg* and fit it on a dataset."""
    return cls(cfg).fit_dataset(
        dataset_name,
        split=split,
        text_field=text_field,
        sample_size=sample_size,
        shuffle_seed=shuffle_seed,
    )
|
|
763
|
+
|
|
764
|
+
@property
def graph(self):
    """The graph object held by the memory's topology (``self.mem.topo.graph``)."""
    topology = self.mem.topo
    return topology.graph
|
|
767
|
+
|
|
768
|
+
def save(self, path: str | Path):
    """Write the model into the directory *path* and return it as a ``Path``.

    An existing *path* is removed with ``shutil.rmtree`` first, so anything
    already stored there is lost. Written in order: ``config.json``, the
    memory state (via ``mem.save_state``), ``sentences.json`` and
    ``manifest.json``.
    """
    target = Path(path)
    if target.exists():
        shutil.rmtree(target)
    target.mkdir(parents=True)

    def dump(name, payload):
        (target / name).write_text(json.dumps(payload, indent=2), encoding="utf-8")

    dump("config.json", asdict(self.cfg))
    self.mem.save_state(target)
    dump("sentences.json", self.mem.sentences)
    dump("manifest.json", {"version": "topolm-0.9.1", "time": time.time()})
    return target
|
|
780
|
+
|
|
781
|
+
@classmethod
def load(cls, path: str | Path):
    """Rebuild a model from a directory written by ``save``.

    ``memory.json`` is preferred: when present, the memory state is loaded
    from the directory. Otherwise the model is refitted from
    ``sentences.json``.

    Raises:
        FileNotFoundError: neither ``memory.json`` nor ``sentences.json`` exists.
    """
    root = Path(path)
    raw_cfg = (root / "config.json").read_text(encoding="utf-8")
    model = cls(Config(**json.loads(raw_cfg)))
    memory_file = root / "memory.json"
    sentences_file = root / "sentences.json"
    if memory_file.exists():
        model.mem.load_state(root)
    elif sentences_file.exists():
        sentences = json.loads(sentences_file.read_text(encoding="utf-8"))
        model.fit(" ".join(sentences))
    else:
        raise FileNotFoundError(f"No saved TopoLM state found in {root}")
    model.index.build()
    return model
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
class NGram:
    """Count-based n-gram baseline: bigram for ``n == 2``, trigram for other values of ``n``."""

    def __init__(self, n=2):
        self.n = n
        self.tok = Tokenizer()
        self.counts = defaultdict(Counter)  # context tuple -> Counter of next units
        self.uni = Counter()  # unigram counts, the last-resort fallback

    def fit(self, text: str):
        """Accumulate counts from every sentence of *text* and return ``self``.

        Bigram counts are always recorded. A trigram model previously stored
        only two-unit contexts, so the single-unit back-off in ``predict``
        could never match. One- and two-unit keys are tuples of different
        lengths and cannot collide.
        """
        for s in self.tok.sentences(text):
            units = ["<bos>"] + self.tok.units(s, True) + ["<eos>"]
            self.uni.update(units)
            for a, b in zip(units, units[1:]):
                self.counts[(a,)][b] += 1
            if self.n != 2:
                for a, b, c in zip(units, units[1:], units[2:]):
                    self.counts[(a, b)][c] += 1
        return self

    def predict(self, context: str, k=5):
        """Return up to *k* likely next units, never "<bos>".

        Looks up the last ``n - 1`` units. A trigram model backs off to the
        last single unit, and everything finally falls back to unigram
        frequency.
        """
        units = ["<bos>"] + self.tok.units(context, True)
        key = tuple(units[-(self.n - 1):])
        if key in self.counts:
            table = self.counts[key]
        elif self.n == 3 and (units[-1],) in self.counts:
            table = self.counts[(units[-1],)]
        else:
            table = self.uni
        return [u for u, _ in table.most_common(k) if u != "<bos>"]
|
|
813
|
+
|
|
814
|
+
|
|
815
|
+
def eval_examples(sents: list[str], tok: Tokenizer, min_context=2):
    """Turn sentences into ``(context, target, sentence)`` next-unit examples.

    Each position from *min_context* onward yields one example. A sentence
    with at least *min_context* units also yields a final example whose
    target is "<eos>".
    """
    examples = []
    for sentence in sents:
        units = tok.units(sentence, True)
        examples.extend(
            (" ".join(units[:cut]), units[cut], sentence)
            for cut in range(min_context, len(units))
        )
        if len(units) >= min_context:
            examples.append((" ".join(units), "<eos>", sentence))
    return examples
|
|
822
|
+
|
|
823
|
+
|
|
824
|
+
def evaluate(name, model, examples, k=5, is_topolm=True):
    """Measure top-1, top-k, MRR and average latency of *model* over *examples*.

    With ``is_topolm`` predictions come from ``model.distribution(ctx, k)``
    (their ``text``), otherwise from ``model.predict(ctx, k)``. Up to five
    missed examples are returned under ``"failures"``.
    """
    hits_at_1 = 0
    hits_at_k = 0
    reciprocal_ranks = []
    misses = []
    latencies = []
    for ctx, target, sent in examples:
        started = time.perf_counter()
        if is_topolm:
            preds = [p.text for p in model.distribution(ctx, k)]
        else:
            preds = model.predict(ctx, k)
        latencies.append(time.perf_counter() - started)
        if preds and preds[0] == target:
            hits_at_1 += 1
        if target in preds:
            hits_at_k += 1
            reciprocal_ranks.append(1 / (preds.index(target) + 1))
        else:
            reciprocal_ranks.append(0)
            misses.append({"context": ctx, "target": target, "predictions": preds, "sentence": sent})
    # max(1, ...) keeps an empty example list from dividing by zero.
    total = max(1, len(examples))
    return {
        "model": name,
        "examples": len(examples),
        "top1": hits_at_1 / total,
        f"top{k}": hits_at_k / total,
        "mrr": sum(reciprocal_ranks) / total,
        "avg_latency_ms": 1000 * sum(latencies) / max(1, len(latencies)),
        "failures": misses[:5],
    }
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Iterable
|
|
4
|
+
|
|
5
|
+
try:
|
|
6
|
+
from datasets import Dataset, load_dataset
|
|
7
|
+
except ImportError: # pragma: no cover
|
|
8
|
+
Dataset = None
|
|
9
|
+
load_dataset = None
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _ensure_datasets_installed() -> None:
    """Raise ``ImportError`` with install instructions when ``load_dataset`` is unavailable."""
    if load_dataset is not None:
        return
    message = (
        "The Hugging Face datasets package is required for this feature. "
        "Install it with `pip install topolm[hf]` or `pip install datasets`."
    )
    raise ImportError(message)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _infer_text_field(dataset: "Dataset") -> str:
|
|
21
|
+
if hasattr(dataset, "features"):
|
|
22
|
+
for name, feature in dataset.features.items():
|
|
23
|
+
if getattr(feature, "dtype", None) == "string":
|
|
24
|
+
return name
|
|
25
|
+
raise ValueError("Unable to infer a string text field from the dataset. Please pass text_field explicitly.")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def load_hf_dataset(
    dataset_name: str,
    split: str = "train",
    text_field: str | None = None,
    sample_size: int | None = None,
    shuffle_seed: int = 42,
) -> list[str]:
    """Load one split of a Hugging Face dataset and return a text column as strings.

    When *text_field* is omitted the first string-typed feature is used. A
    positive *sample_size* shuffles with *shuffle_seed* and keeps at most
    that many rows; ``None`` or a non-positive value keeps every row.

    Raises:
        ImportError: the ``datasets`` package is not installed.
        ValueError: no text field can be inferred, or the field is not a column.
    """
    _ensure_datasets_installed()
    ds = load_dataset(dataset_name, split=split)
    field = _infer_text_field(ds) if text_field is None else text_field
    if field not in ds.column_names:
        raise ValueError(f"Dataset {dataset_name} does not contain a {field} field.")
    wants_sample = sample_size is not None and sample_size > 0
    if wants_sample:
        shuffled = ds.shuffle(seed=shuffle_seed)
        ds = shuffled.select(range(min(sample_size, len(ds))))
    return list(map(str, ds[field]))
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def hf_dataset_texts(
    dataset_name: str,
    split: str = "train",
    text_field: str | None = None,
    sample_size: int | None = None,
    shuffle_seed: int = 42,
) -> Iterable[str]:
    """Alias for ``load_hf_dataset``: forwards every argument and returns its result."""
    texts = load_hf_dataset(
        dataset_name,
        split=split,
        text_field=text_field,
        sample_size=sample_size,
        shuffle_seed=shuffle_seed,
    )
    return texts
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: topolm
|
|
3
|
+
Version: 0.9.1
|
|
4
|
+
Summary: Topology-native explainable language model prototype powered by Topologist
|
|
5
|
+
Author: Robert McMenemy
|
|
6
|
+
Requires-Python: >=3.10
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
License-File: LICENSE
|
|
9
|
+
Requires-Dist: numpy>=1.23
|
|
10
|
+
Requires-Dist: networkx>=3.0
|
|
11
|
+
Requires-Dist: topologist>=0.4.0
|
|
12
|
+
Provides-Extra: ml
|
|
13
|
+
Requires-Dist: scikit-learn>=1.3; extra == "ml"
|
|
14
|
+
Requires-Dist: torch>=2.0; extra == "ml"
|
|
15
|
+
Provides-Extra: hf
|
|
16
|
+
Requires-Dist: datasets>=2.18; extra == "hf"
|
|
17
|
+
Provides-Extra: dev
|
|
18
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
19
|
+
Requires-Dist: ruff>=0.5; extra == "dev"
|
|
20
|
+
Requires-Dist: build>=0.10; extra == "dev"
|
|
21
|
+
Requires-Dist: twine>=4.0; extra == "dev"
|
|
22
|
+
Dynamic: license-file
|
|
23
|
+
|
|
24
|
+
# TopoLM
|
|
25
|
+
|
|
26
|
+
**TopoLM** is a topology-native, explainable language model prototype powered by `topologist`.
|
|
27
|
+
|
|
28
|
+
## Quick start
|
|
29
|
+
|
|
30
|
+
```bash
|
|
31
|
+
pip install -e .
|
|
32
|
+
python examples/basic_demo.py
|
|
33
|
+
topolm demo
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
## API
|
|
37
|
+
|
|
38
|
+
```python
|
|
39
|
+
from topolm import TopoLM, Config, load_hf_dataset
|
|
40
|
+
|
|
41
|
+
model = TopoLM(Config()).fit(corpus)
|
|
42
|
+
print(model.distribution("clarithromycin inhibits", top_k=5))
|
|
43
|
+
print(model.generate("cyp3a4 inhibition", decoding="beam"))
|
|
44
|
+
|
|
45
|
+
# training from a Hugging Face dataset
|
|
46
|
+
texts = load_hf_dataset("wikitext", split="train", text_field="text", sample_size=1000)
|
|
47
|
+
model = TopoLM(Config()).fit_texts(texts)
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
## Layout
|
|
51
|
+
|
|
52
|
+
```text
|
|
53
|
+
topolm/
|
|
54
|
+
__init__.py
|
|
55
|
+
config.py
|
|
56
|
+
core.py
|
|
57
|
+
cli.py
|
|
58
|
+
examples/
|
|
59
|
+
tests/
|
|
60
|
+
```
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
tests/test_smoke.py
|
|
5
|
+
topolm/__init__.py
|
|
6
|
+
topolm/cli.py
|
|
7
|
+
topolm/config.py
|
|
8
|
+
topolm/core.py
|
|
9
|
+
topolm/datasets.py
|
|
10
|
+
topolm.egg-info/PKG-INFO
|
|
11
|
+
topolm.egg-info/SOURCES.txt
|
|
12
|
+
topolm.egg-info/dependency_links.txt
|
|
13
|
+
topolm.egg-info/entry_points.txt
|
|
14
|
+
topolm.egg-info/requires.txt
|
|
15
|
+
topolm.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
topolm
|