glitchlings 0.4.0__cp312-cp312-macosx_11_0_universal2.whl → 0.4.2__cp312-cp312-macosx_11_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of glitchlings might be problematic; see the registry's advisory for this release for more details.
- glitchlings/__init__.py +26 -17
- glitchlings/__main__.py +0 -1
- glitchlings/_zoo_rust.cpython-312-darwin.so +0 -0
- glitchlings/compat.py +215 -0
- glitchlings/config.py +136 -19
- glitchlings/dlc/_shared.py +68 -0
- glitchlings/dlc/huggingface.py +26 -41
- glitchlings/dlc/prime.py +64 -101
- glitchlings/lexicon/__init__.py +26 -19
- glitchlings/lexicon/_cache.py +104 -0
- glitchlings/lexicon/graph.py +18 -39
- glitchlings/lexicon/metrics.py +1 -8
- glitchlings/lexicon/vector.py +29 -67
- glitchlings/lexicon/wordnet.py +39 -30
- glitchlings/main.py +9 -13
- glitchlings/util/__init__.py +18 -4
- glitchlings/util/adapters.py +27 -0
- glitchlings/zoo/__init__.py +21 -14
- glitchlings/zoo/_ocr_confusions.py +1 -3
- glitchlings/zoo/_rate.py +1 -4
- glitchlings/zoo/_sampling.py +0 -1
- glitchlings/zoo/_text_utils.py +1 -5
- glitchlings/zoo/adjax.py +0 -2
- glitchlings/zoo/core.py +185 -56
- glitchlings/zoo/jargoyle.py +9 -14
- glitchlings/zoo/mim1c.py +11 -10
- glitchlings/zoo/redactyl.py +5 -8
- glitchlings/zoo/reduple.py +3 -1
- glitchlings/zoo/rushmore.py +2 -8
- glitchlings/zoo/scannequin.py +5 -4
- glitchlings/zoo/typogre.py +3 -7
- glitchlings/zoo/zeedub.py +2 -2
- {glitchlings-0.4.0.dist-info → glitchlings-0.4.2.dist-info}/METADATA +68 -4
- glitchlings-0.4.2.dist-info/RECORD +42 -0
- glitchlings-0.4.0.dist-info/RECORD +0 -38
- {glitchlings-0.4.0.dist-info → glitchlings-0.4.2.dist-info}/WHEEL +0 -0
- {glitchlings-0.4.0.dist-info → glitchlings-0.4.2.dist-info}/entry_points.txt +0 -0
- {glitchlings-0.4.0.dist-info → glitchlings-0.4.2.dist-info}/licenses/LICENSE +0 -0
- {glitchlings-0.4.0.dist-info → glitchlings-0.4.2.dist-info}/top_level.txt +0 -0
glitchlings/dlc/huggingface.py
CHANGED
|
@@ -3,21 +3,15 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from collections.abc import Iterable, Sequence
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any, cast
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
_DatasetsDataset = None # type: ignore[assignment]
|
|
12
|
-
else:
|
|
13
|
-
_datasets_error = None
|
|
14
|
-
|
|
15
|
-
from ..zoo import Gaggle, Glitchling, summon
|
|
8
|
+
from ..compat import datasets, get_datasets_dataset, require_datasets
|
|
9
|
+
from ..util.adapters import coerce_gaggle
|
|
10
|
+
from ..zoo import Gaggle, Glitchling
|
|
16
11
|
|
|
17
12
|
|
|
18
13
|
def _normalise_columns(column: str | Sequence[str]) -> list[str]:
|
|
19
14
|
"""Normalise a column specification to a list."""
|
|
20
|
-
|
|
21
15
|
if isinstance(column, str):
|
|
22
16
|
return [column]
|
|
23
17
|
|
|
@@ -27,20 +21,6 @@ def _normalise_columns(column: str | Sequence[str]) -> list[str]:
|
|
|
27
21
|
return normalised
|
|
28
22
|
|
|
29
23
|
|
|
30
|
-
def _as_gaggle(glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling], seed: int) -> Gaggle:
|
|
31
|
-
"""Coerce any supported glitchling specification into a :class:`Gaggle`."""
|
|
32
|
-
|
|
33
|
-
if isinstance(glitchlings, Gaggle):
|
|
34
|
-
return glitchlings
|
|
35
|
-
|
|
36
|
-
if isinstance(glitchlings, (Glitchling, str)):
|
|
37
|
-
resolved: Iterable[str | Glitchling] = [glitchlings]
|
|
38
|
-
else:
|
|
39
|
-
resolved = glitchlings
|
|
40
|
-
|
|
41
|
-
return summon(list(resolved), seed=seed)
|
|
42
|
-
|
|
43
|
-
|
|
44
24
|
def _glitch_dataset(
|
|
45
25
|
dataset: Any,
|
|
46
26
|
glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
|
|
@@ -48,23 +28,28 @@ def _glitch_dataset(
|
|
|
48
28
|
*,
|
|
49
29
|
seed: int = 151,
|
|
50
30
|
) -> Any:
|
|
51
|
-
"""
|
|
52
|
-
|
|
31
|
+
"""Apply glitchlings to the provided dataset columns."""
|
|
53
32
|
columns = _normalise_columns(column)
|
|
54
|
-
gaggle =
|
|
33
|
+
gaggle = coerce_gaggle(glitchlings, seed=seed)
|
|
55
34
|
return gaggle.corrupt_dataset(dataset, columns)
|
|
56
35
|
|
|
57
36
|
|
|
58
37
|
def _ensure_dataset_class() -> Any:
|
|
59
38
|
"""Return the Hugging Face :class:`~datasets.Dataset` patched with ``.glitch``."""
|
|
60
|
-
|
|
61
|
-
if
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
39
|
+
dataset_cls = get_datasets_dataset()
|
|
40
|
+
if dataset_cls is None: # pragma: no cover - datasets is an install-time dependency
|
|
41
|
+
require_datasets("datasets is not installed")
|
|
42
|
+
dataset_cls = get_datasets_dataset()
|
|
43
|
+
if dataset_cls is None:
|
|
44
|
+
message = "datasets is not installed"
|
|
45
|
+
error = datasets.error
|
|
46
|
+
if error is not None:
|
|
47
|
+
raise ModuleNotFoundError(message) from error
|
|
48
|
+
raise ModuleNotFoundError(message)
|
|
49
|
+
|
|
50
|
+
if getattr(dataset_cls, "glitch", None) is None:
|
|
51
|
+
|
|
52
|
+
def glitch(
|
|
68
53
|
self: Any,
|
|
69
54
|
glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
|
|
70
55
|
*,
|
|
@@ -73,24 +58,24 @@ def _ensure_dataset_class() -> Any:
|
|
|
73
58
|
**_: Any,
|
|
74
59
|
) -> Any:
|
|
75
60
|
"""Return a lazily corrupted copy of the dataset."""
|
|
76
|
-
|
|
77
61
|
return _glitch_dataset(self, glitchlings, column, seed=seed)
|
|
78
62
|
|
|
79
|
-
setattr(
|
|
63
|
+
setattr(dataset_cls, "glitch", glitch)
|
|
80
64
|
|
|
81
|
-
return
|
|
65
|
+
return cast(type[Any], dataset_cls)
|
|
82
66
|
|
|
83
67
|
|
|
84
68
|
def install() -> None:
|
|
85
69
|
"""Monkeypatch the Hugging Face :class:`~datasets.Dataset` with ``.glitch``."""
|
|
86
|
-
|
|
87
70
|
_ensure_dataset_class()
|
|
88
71
|
|
|
89
72
|
|
|
90
|
-
|
|
73
|
+
Dataset: type[Any] | None
|
|
74
|
+
_DatasetAlias = get_datasets_dataset()
|
|
75
|
+
if _DatasetAlias is not None:
|
|
91
76
|
Dataset = _ensure_dataset_class()
|
|
92
77
|
else: # pragma: no cover - datasets is an install-time dependency
|
|
93
|
-
Dataset = None
|
|
78
|
+
Dataset = None
|
|
94
79
|
|
|
95
80
|
|
|
96
81
|
__all__ = ["Dataset", "install"]
|
glitchlings/dlc/prime.py
CHANGED
|
@@ -4,79 +4,60 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
from collections.abc import Iterable, Sequence
|
|
6
6
|
from enum import Enum
|
|
7
|
-
from typing import Any, Callable
|
|
7
|
+
from typing import Any, Callable, Protocol, cast
|
|
8
8
|
|
|
9
|
-
import
|
|
9
|
+
from ..compat import require_datasets, require_jellyfish, require_verifiers
|
|
10
|
+
from ..util.adapters import coerce_gaggle
|
|
11
|
+
from ..zoo import Gaggle, Glitchling, Mim1c, Typogre
|
|
12
|
+
from ._shared import resolve_columns as _resolve_columns_shared
|
|
13
|
+
from ._shared import resolve_environment as _resolve_environment_shared
|
|
10
14
|
|
|
11
|
-
from jellyfish import damerau_levenshtein_distance
|
|
12
15
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
except ModuleNotFoundError: # pragma: no cover - optional dependency
|
|
16
|
-
Dataset = object # type: ignore[assignment]
|
|
17
|
-
else:
|
|
18
|
-
if Dataset is None: # pragma: no cover - optional dependency
|
|
19
|
-
Dataset = object # type: ignore[assignment]
|
|
16
|
+
class VerifierEnvironment(Protocol):
|
|
17
|
+
"""Minimal interface for verifiers environments."""
|
|
20
18
|
|
|
21
|
-
|
|
19
|
+
dataset: Any
|
|
22
20
|
|
|
23
21
|
|
|
24
|
-
|
|
25
|
-
"""
|
|
26
|
-
|
|
27
|
-
if isinstance(env, str):
|
|
28
|
-
env = vf.load_environment(env)
|
|
22
|
+
class VerifierSingleTurnEnv(Protocol):
|
|
23
|
+
"""Minimal interface for single-turn verifier environments."""
|
|
29
24
|
|
|
30
|
-
|
|
31
|
-
|
|
25
|
+
dataset: Any
|
|
26
|
+
rubric: Any
|
|
32
27
|
|
|
33
|
-
return env
|
|
34
28
|
|
|
29
|
+
vf = require_verifiers("verifiers is not installed; install glitchlings[prime]")
|
|
30
|
+
_jellyfish = require_jellyfish("jellyfish is not installed; install glitchlings[prime]")
|
|
31
|
+
damerau_levenshtein_distance = _jellyfish.damerau_levenshtein_distance
|
|
35
32
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
33
|
+
try:
|
|
34
|
+
from .huggingface import Dataset as _HuggingFaceDataset
|
|
35
|
+
except ModuleNotFoundError: # pragma: no cover - optional dependency
|
|
36
|
+
_HuggingFaceDataset = None
|
|
37
|
+
else:
|
|
38
|
+
if _HuggingFaceDataset is None: # pragma: no cover - optional dependency
|
|
39
|
+
_HuggingFaceDataset = None
|
|
40
40
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
return list(columns)
|
|
41
|
+
Dataset: type[Any]
|
|
42
|
+
if _HuggingFaceDataset is None:
|
|
43
|
+
Dataset = object
|
|
44
|
+
else:
|
|
45
|
+
Dataset = _HuggingFaceDataset
|
|
47
46
|
|
|
48
|
-
for candidate in ("prompt", "question"):
|
|
49
|
-
if candidate in available:
|
|
50
|
-
return [candidate]
|
|
51
47
|
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
iterator = iter(dataset)
|
|
61
|
-
try:
|
|
62
|
-
first_row = next(iterator)
|
|
63
|
-
except StopIteration:
|
|
64
|
-
preview_rows = []
|
|
65
|
-
else:
|
|
66
|
-
preview_rows = [first_row]
|
|
67
|
-
sample = dict(preview_rows[0]) if preview_rows else {}
|
|
68
|
-
else:
|
|
69
|
-
sample = dataset[0] if dataset_length else {}
|
|
70
|
-
inferred = [
|
|
71
|
-
name
|
|
72
|
-
for name in dataset.column_names
|
|
73
|
-
if isinstance(sample.get(name), str)
|
|
74
|
-
]
|
|
48
|
+
def _resolve_environment(env: str | VerifierEnvironment) -> VerifierEnvironment:
    """Turn ``env`` into a loaded verifier environment.

    The work is delegated to the shared resolver imported from ``._shared``.
    This wrapper supplies ``vf.load_environment`` as the loader and
    ``vf.Environment`` as the expected environment type.
    """
    loader = vf.load_environment
    environment_type = cast(type[Any], vf.Environment)
    outcome = _resolve_environment_shared(env, loader=loader, environment_type=environment_type)
    return cast(VerifierEnvironment, outcome)
|
|
75
56
|
|
|
76
|
-
if inferred:
|
|
77
|
-
return inferred
|
|
78
57
|
|
|
79
|
-
|
|
58
|
+
def _resolve_columns(dataset: Any, columns: Sequence[str] | None) -> list[str]:
    """Pick the dataset columns that should receive glitchling corruption.

    This is a thin wrapper around the shared ``resolve_columns`` helper from
    ``._shared``. Both arguments are forwarded unchanged.
    """
    selected = _resolve_columns_shared(dataset, columns)
    return selected
|
|
80
61
|
|
|
81
62
|
|
|
82
63
|
class Difficulty(Enum):
|
|
@@ -90,12 +71,11 @@ class Difficulty(Enum):
|
|
|
90
71
|
|
|
91
72
|
|
|
92
73
|
def tutorial_level(
|
|
93
|
-
env:
|
|
74
|
+
env: VerifierEnvironment | str,
|
|
94
75
|
seed: int = 151,
|
|
95
76
|
difficulty: Difficulty = Difficulty.Normal,
|
|
96
|
-
) ->
|
|
77
|
+
) -> VerifierEnvironment:
|
|
97
78
|
"""Create a low-corruption environment using tuned defaults."""
|
|
98
|
-
|
|
99
79
|
tuned_mim1c = Mim1c(rate=0.01 * difficulty.value)
|
|
100
80
|
tuned_typogre = Typogre(rate=0.025 * difficulty.value)
|
|
101
81
|
|
|
@@ -107,28 +87,19 @@ def tutorial_level(
|
|
|
107
87
|
|
|
108
88
|
|
|
109
89
|
def load_environment(
|
|
110
|
-
env: str |
|
|
90
|
+
env: str | VerifierEnvironment,
|
|
111
91
|
glitchlings: Iterable[str | Glitchling] | Glitchling | str | Gaggle | None = None,
|
|
112
92
|
*,
|
|
113
93
|
seed: int = 151,
|
|
114
94
|
columns: Sequence[str] | None = None,
|
|
115
|
-
) ->
|
|
95
|
+
) -> VerifierEnvironment:
|
|
116
96
|
"""Load an environment and optionally corrupt it with glitchlings."""
|
|
117
|
-
|
|
118
97
|
environment = _resolve_environment(env)
|
|
119
98
|
|
|
120
99
|
if glitchlings is None:
|
|
121
100
|
return environment
|
|
122
101
|
|
|
123
|
-
|
|
124
|
-
gaggle = glitchlings
|
|
125
|
-
else:
|
|
126
|
-
if isinstance(glitchlings, (Glitchling, str)):
|
|
127
|
-
resolved = [glitchlings]
|
|
128
|
-
else:
|
|
129
|
-
resolved = list(glitchlings)
|
|
130
|
-
|
|
131
|
-
gaggle = summon(resolved, seed=seed)
|
|
102
|
+
gaggle = coerce_gaggle(glitchlings, seed=seed)
|
|
132
103
|
|
|
133
104
|
dataset = environment.dataset
|
|
134
105
|
corrupt_columns = _resolve_columns(dataset, columns)
|
|
@@ -142,21 +113,11 @@ def _as_gaggle(
|
|
|
142
113
|
seed: int,
|
|
143
114
|
) -> Gaggle:
|
|
144
115
|
"""Coerce any supported glitchling specification into a :class:`Gaggle`."""
|
|
145
|
-
|
|
146
|
-
if isinstance(glitchlings, Gaggle):
|
|
147
|
-
return glitchlings
|
|
148
|
-
|
|
149
|
-
if isinstance(glitchlings, (Glitchling, str)):
|
|
150
|
-
resolved: Iterable[str | Glitchling] = [glitchlings]
|
|
151
|
-
else:
|
|
152
|
-
resolved = glitchlings
|
|
153
|
-
|
|
154
|
-
return summon(list(resolved), seed=seed)
|
|
116
|
+
return coerce_gaggle(glitchlings, seed=seed)
|
|
155
117
|
|
|
156
118
|
|
|
157
119
|
def _extract_completion_text(completion: Any) -> str:
|
|
158
120
|
"""Normalise a completion payload into a plain string."""
|
|
159
|
-
|
|
160
121
|
if isinstance(completion, str):
|
|
161
122
|
return completion
|
|
162
123
|
|
|
@@ -175,11 +136,10 @@ def symmetric_damerau_levenshtein_similarity(
|
|
|
175
136
|
answer: str,
|
|
176
137
|
) -> float:
|
|
177
138
|
"""Return ``1 - (distance / max_len)`` using Damerau-Levenshtein distance."""
|
|
178
|
-
|
|
179
139
|
completion_text = _extract_completion_text(completion)
|
|
180
140
|
target = answer or ""
|
|
181
141
|
denominator = max(len(completion_text), len(target), 1)
|
|
182
|
-
distance = damerau_levenshtein_distance(completion_text, target)
|
|
142
|
+
distance = cast(int, damerau_levenshtein_distance(completion_text, target))
|
|
183
143
|
score = 1.0 - (distance / denominator)
|
|
184
144
|
return max(0.0, min(1.0, score))
|
|
185
145
|
|
|
@@ -199,32 +159,34 @@ def echo_chamber(
|
|
|
199
159
|
reward_function: Callable[..., float] | None = None,
|
|
200
160
|
split: str | None = None,
|
|
201
161
|
**load_dataset_kwargs: Any,
|
|
202
|
-
) ->
|
|
162
|
+
) -> VerifierSingleTurnEnv:
|
|
203
163
|
"""Create an Echo Chamber Prime environment from a Hugging Face dataset column.
|
|
204
164
|
|
|
205
165
|
Args:
|
|
206
166
|
dataset_id: Identifier of the Hugging Face dataset to load.
|
|
207
167
|
column: Name of the column whose text should be glitched.
|
|
208
168
|
glitchlings: Glitchling specifiers that will corrupt the prompts.
|
|
209
|
-
seed: RNG seed forwarded to :func:`
|
|
169
|
+
seed: RNG seed forwarded to :func:`glitchlings.util.adapters.coerce_gaggle`.
|
|
210
170
|
instructions: System instructions supplied to the environment prompts.
|
|
211
171
|
reward_function: Optional callable used to score completions. Defaults to
|
|
212
172
|
:func:`symmetric_damerau_levenshtein_similarity` when omitted.
|
|
213
173
|
split: Optional dataset split to load.
|
|
214
174
|
**load_dataset_kwargs: Extra keyword arguments forwarded to
|
|
215
175
|
:func:`datasets.load_dataset`.
|
|
216
|
-
"""
|
|
217
176
|
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
177
|
+
"""
|
|
178
|
+
datasets_module = require_datasets("datasets is required to build an echo chamber")
|
|
179
|
+
load_dataset = getattr(datasets_module, "load_dataset", None)
|
|
180
|
+
if load_dataset is None: # pragma: no cover - defensive
|
|
221
181
|
message = "datasets is required to build an echo chamber"
|
|
222
|
-
raise ModuleNotFoundError(message)
|
|
182
|
+
raise ModuleNotFoundError(message)
|
|
223
183
|
|
|
224
|
-
|
|
184
|
+
dataset_dict_cls = getattr(datasets_module, "DatasetDict", dict)
|
|
185
|
+
|
|
186
|
+
hf_dataset: Any
|
|
225
187
|
if split is None:
|
|
226
188
|
hf_dataset = load_dataset(dataset_id, **load_dataset_kwargs)
|
|
227
|
-
if isinstance(hf_dataset,
|
|
189
|
+
if isinstance(hf_dataset, dataset_dict_cls):
|
|
228
190
|
try:
|
|
229
191
|
hf_dataset = next(iter(hf_dataset.values()))
|
|
230
192
|
except StopIteration as exc: # pragma: no cover - defensive
|
|
@@ -232,10 +194,8 @@ def echo_chamber(
|
|
|
232
194
|
else:
|
|
233
195
|
hf_dataset = load_dataset(dataset_id, split=split, **load_dataset_kwargs)
|
|
234
196
|
|
|
235
|
-
if isinstance(hf_dataset,
|
|
236
|
-
raise ValueError(
|
|
237
|
-
"Specify which split to use when the dataset loads as a DatasetDict."
|
|
238
|
-
)
|
|
197
|
+
if isinstance(hf_dataset, dataset_dict_cls):
|
|
198
|
+
raise ValueError("Specify which split to use when the dataset loads as a DatasetDict.")
|
|
239
199
|
|
|
240
200
|
filtered_dataset = hf_dataset.filter(
|
|
241
201
|
lambda row: row.get(column) is not None,
|
|
@@ -259,7 +219,7 @@ def echo_chamber(
|
|
|
259
219
|
)
|
|
260
220
|
|
|
261
221
|
try:
|
|
262
|
-
dataset_length = len(base_dataset)
|
|
222
|
+
dataset_length = len(base_dataset)
|
|
263
223
|
except TypeError:
|
|
264
224
|
preview_rows: list[dict[str, Any]]
|
|
265
225
|
take_fn = getattr(base_dataset, "take", None)
|
|
@@ -288,4 +248,7 @@ def echo_chamber(
|
|
|
288
248
|
|
|
289
249
|
rubric_func = reward_function or symmetric_damerau_levenshtein_similarity
|
|
290
250
|
rubric = vf.Rubric(funcs=[rubric_func], weights=[1.0])
|
|
291
|
-
return
|
|
251
|
+
return cast(
|
|
252
|
+
VerifierSingleTurnEnv,
|
|
253
|
+
vf.SingleTurnEnv(dataset=glitched_dataset, rubric=rubric),
|
|
254
|
+
)
|
glitchlings/lexicon/__init__.py
CHANGED
|
@@ -2,13 +2,16 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import random
|
|
5
6
|
from abc import ABC, abstractmethod
|
|
6
7
|
from hashlib import blake2s
|
|
7
|
-
import
|
|
8
|
+
from pathlib import Path
|
|
8
9
|
from typing import Callable, Iterable
|
|
9
10
|
|
|
10
11
|
from glitchlings.config import get_config
|
|
11
12
|
|
|
13
|
+
from ._cache import CacheEntries, CacheSnapshot
|
|
14
|
+
|
|
12
15
|
|
|
13
16
|
class Lexicon(ABC):
|
|
14
17
|
"""Abstract interface describing synonym lookup backends.
|
|
@@ -19,6 +22,7 @@ class Lexicon(ABC):
|
|
|
19
22
|
Optional integer used to derive deterministic random number generators
|
|
20
23
|
for synonym sampling. Identical seeds guarantee reproducible results for
|
|
21
24
|
the same word/part-of-speech queries.
|
|
25
|
+
|
|
22
26
|
"""
|
|
23
27
|
|
|
24
28
|
def __init__(self, *, seed: int | None = None) -> None:
|
|
@@ -27,17 +31,14 @@ class Lexicon(ABC):
|
|
|
27
31
|
@property
|
|
28
32
|
def seed(self) -> int | None:
|
|
29
33
|
"""Return the current base seed used for deterministic sampling."""
|
|
30
|
-
|
|
31
34
|
return self._seed
|
|
32
35
|
|
|
33
36
|
def reseed(self, seed: int | None) -> None:
|
|
34
37
|
"""Update the base seed driving deterministic synonym sampling."""
|
|
35
|
-
|
|
36
38
|
self._seed = seed
|
|
37
39
|
|
|
38
40
|
def _derive_rng(self, word: str, pos: str | None) -> random.Random:
|
|
39
41
|
"""Return an RNG derived from the base seed, word, and POS tag."""
|
|
40
|
-
|
|
41
42
|
seed_material = blake2s(digest_size=8)
|
|
42
43
|
seed_material.update(word.lower().encode("utf8"))
|
|
43
44
|
if pos is not None:
|
|
@@ -51,7 +52,6 @@ class Lexicon(ABC):
|
|
|
51
52
|
self, values: Iterable[str], *, limit: int, word: str, pos: str | None
|
|
52
53
|
) -> list[str]:
|
|
53
54
|
"""Return up to ``limit`` values sampled deterministically."""
|
|
54
|
-
|
|
55
55
|
if limit <= 0:
|
|
56
56
|
return []
|
|
57
57
|
|
|
@@ -65,28 +65,40 @@ class Lexicon(ABC):
|
|
|
65
65
|
return [items[index] for index in indices]
|
|
66
66
|
|
|
67
67
|
@abstractmethod
|
|
68
|
-
def get_synonyms(
|
|
69
|
-
self, word: str, pos: str | None = None, n: int = 5
|
|
70
|
-
) -> list[str]:
|
|
68
|
+
def get_synonyms(self, word: str, pos: str | None = None, n: int = 5) -> list[str]:
|
|
71
69
|
"""Return up to ``n`` synonyms for ``word`` constrained by ``pos``."""
|
|
72
70
|
|
|
73
71
|
def supports_pos(self, pos: str | None) -> bool:
|
|
74
72
|
"""Return ``True`` when the backend can service ``pos`` queries."""
|
|
75
|
-
|
|
76
73
|
return True
|
|
77
74
|
|
|
78
75
|
def __repr__(self) -> str: # pragma: no cover - trivial representation
|
|
79
76
|
return f"{self.__class__.__name__}(seed={self._seed!r})"
|
|
80
77
|
|
|
81
78
|
|
|
82
|
-
|
|
83
|
-
|
|
79
|
+
class LexiconBackend(Lexicon):
    """Lexicon that can also load and persist its synonym cache.

    Concrete subclasses must provide :meth:`load_cache`, which returns a
    validated :class:`CacheSnapshot` read from a path. They must also provide
    :meth:`save_cache`, which writes the current cache back out.
    """

    # Class-level alias exposing the canonical cache mapping type to subclasses.
    Cache = CacheEntries

    @classmethod
    @abstractmethod
    def load_cache(cls, path: str | Path) -> CacheSnapshot:
        """Read the cache stored at ``path`` and return it as a validated snapshot."""

    @abstractmethod
    def save_cache(self, path: str | Path | None = None) -> Path | None:
        """Write the backend cache to ``path`` and report where it was written.

        ``path`` may be omitted (``None``). How that case is handled, and when
        ``None`` is returned, is left to each implementation.
        """
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
from .graph import GraphLexicon # noqa: E402
|
|
95
|
+
from .metrics import ( # noqa: E402
|
|
84
96
|
compare_lexicons,
|
|
85
97
|
coverage_ratio,
|
|
86
98
|
mean_cosine_similarity,
|
|
87
99
|
synonym_diversity,
|
|
88
100
|
)
|
|
89
|
-
from .vector import VectorLexicon, build_vector_cache
|
|
101
|
+
from .vector import VectorLexicon, build_vector_cache # noqa: E402
|
|
90
102
|
|
|
91
103
|
try: # pragma: no cover - optional dependency
|
|
92
104
|
from .wordnet import WordNetLexicon
|
|
@@ -97,24 +109,19 @@ except Exception: # pragma: no cover - triggered when nltk unavailable
|
|
|
97
109
|
_BACKEND_FACTORIES: dict[str, Callable[[int | None], Lexicon | None]] = {}
|
|
98
110
|
|
|
99
111
|
|
|
100
|
-
def register_backend(
|
|
101
|
-
name: str, factory: Callable[[int | None], Lexicon | None]
|
|
102
|
-
) -> None:
|
|
112
|
+
def register_backend(name: str, factory: Callable[[int | None], Lexicon | None]) -> None:
|
|
103
113
|
"""Register ``factory`` for ``name`` so it can be selected via config."""
|
|
104
|
-
|
|
105
114
|
normalized = name.lower()
|
|
106
115
|
_BACKEND_FACTORIES[normalized] = factory
|
|
107
116
|
|
|
108
117
|
|
|
109
118
|
def unregister_backend(name: str) -> None:
|
|
110
119
|
"""Remove a previously registered backend."""
|
|
111
|
-
|
|
112
120
|
_BACKEND_FACTORIES.pop(name.lower(), None)
|
|
113
121
|
|
|
114
122
|
|
|
115
123
|
def available_backends() -> list[str]:
|
|
116
124
|
"""Return the names of registered lexicon factories."""
|
|
117
|
-
|
|
118
125
|
return sorted(_BACKEND_FACTORIES)
|
|
119
126
|
|
|
120
127
|
|
|
@@ -155,7 +162,6 @@ register_backend("wordnet", _wordnet_backend)
|
|
|
155
162
|
|
|
156
163
|
def get_default_lexicon(seed: int | None = None) -> Lexicon:
|
|
157
164
|
"""Return the first available lexicon according to configuration priority."""
|
|
158
|
-
|
|
159
165
|
config = get_config()
|
|
160
166
|
attempts: list[str] = []
|
|
161
167
|
for name in config.lexicon.priority:
|
|
@@ -176,6 +182,7 @@ def get_default_lexicon(seed: int | None = None) -> Lexicon:
|
|
|
176
182
|
|
|
177
183
|
__all__ = [
|
|
178
184
|
"Lexicon",
|
|
185
|
+
"LexiconBackend",
|
|
179
186
|
"VectorLexicon",
|
|
180
187
|
"GraphLexicon",
|
|
181
188
|
"WordNetLexicon",
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""Shared cache helpers for lexicon backends."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from hashlib import blake2s
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Mapping, Sequence
|
|
10
|
+
|
|
11
|
+
# Canonical in-memory cache shape: string key -> list of synonym strings.
CacheEntries = dict[str, list[str]]


@dataclass(frozen=True)
class CacheSnapshot:
    """Immutable pairing of materialised cache entries with their checksum.

    Instances are frozen, so fields cannot be reassigned after construction.
    """

    # Mapping from cache key to its stored synonym list.
    entries: CacheEntries
    # Integrity digest recorded alongside the entries; ``None`` when absent.
    checksum: str | None = None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _normalise_entries(payload: Mapping[str, Sequence[str]]) -> CacheEntries:
    """Convert a raw cache payload into the canonical ``CacheEntries`` form.

    Args:
        payload: Mapping decoded from a cache file. Keys must be strings.
            Values must be non-string sequences, and each element is coerced
            to ``str``.

    Returns:
        A new ``dict`` mapping each key to a fresh ``list`` of strings.

    Raises:
        RuntimeError: If a key is not a string, or a value is not a sequence.
            A bare ``str``/``bytes`` value is rejected rather than being split
            into individual characters.
    """
    entries: CacheEntries = {}
    for key, values in payload.items():
        if not isinstance(key, str):
            raise RuntimeError("Synonym cache keys must be strings.")
        # ``str``/``bytes`` are themselves ``Sequence`` instances. Without this
        # guard, a malformed value such as ``"synonym"`` would silently become
        # ``["s", "y", "n", ...]`` instead of being reported.
        if isinstance(values, (str, bytes, bytearray)) or not isinstance(values, Sequence):
            raise RuntimeError("Synonym cache values must be sequences of strings.")
        entries[key] = [str(value) for value in values]
    return entries
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _canonical_json(entries: Mapping[str, Sequence[str]]) -> str:
    """Serialise ``entries`` to compact, key-ordered JSON.

    Keys are sorted and whitespace-free separators are used, so equal inputs
    always yield byte-identical text that can be used as checksum material.
    """
    ordered: dict[str, list[str]] = {}
    for name, synonyms in sorted(entries.items(), key=lambda item: item[0]):
        ordered[name] = list(synonyms)
    return json.dumps(
        ordered,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
    )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def compute_checksum(entries: Mapping[str, Sequence[str]]) -> str:
    """Hash the canonical JSON form of ``entries`` and return it as hex.

    A 16-byte BLAKE2s digest is used, so the result is a 32-character
    hexadecimal string.
    """
    hasher = blake2s(digest_size=16)
    hasher.update(_canonical_json(entries).encode("utf8"))
    return hasher.hexdigest()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def load_cache(path: str | Path) -> CacheSnapshot:
    """Load a synonym cache from ``path`` and verify its checksum when present.

    Two on-disk layouts are accepted:

    * the current layout, a JSON object with a ``"__meta__"`` mapping (which
      may carry a ``"checksum"`` string) and an ``"entries"`` mapping; and
    * the legacy layout, a bare JSON object mapping keys to synonym lists.

    Args:
        path: Location of the cache file. ``str`` paths are accepted in
            addition to :class:`~pathlib.Path`.

    Returns:
        A :class:`CacheSnapshot` with the normalised entries and the stored
        checksum. The checksum is ``None`` for legacy files or when none was
        recorded. A missing file yields an empty snapshot.

    Raises:
        RuntimeError: If the payload structure is invalid, or if the stored
            checksum does not match the entries.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    cache_path = Path(path)
    if not cache_path.exists():
        return CacheSnapshot(entries={}, checksum=None)

    with cache_path.open("r", encoding="utf8") as handle:
        payload = json.load(handle)

    checksum: str | None = None
    entries_payload: Mapping[str, Sequence[str]]

    if isinstance(payload, Mapping) and "__meta__" in payload and "entries" in payload:
        meta = payload["__meta__"]
        entries_payload = payload["entries"]  # type: ignore[assignment]
        if not isinstance(entries_payload, Mapping):
            raise RuntimeError("Synonym cache entries must be stored as a mapping.")
        if not isinstance(meta, Mapping):
            raise RuntimeError("Synonym cache metadata must be a mapping.")
        raw_checksum = meta.get("checksum")
        if raw_checksum is not None and not isinstance(raw_checksum, str):
            raise RuntimeError("Synonym cache checksum must be a string when provided.")
        checksum = raw_checksum
    elif isinstance(payload, Mapping):
        entries_payload = payload  # legacy format without metadata
    else:
        raise RuntimeError("Synonym cache payload must be a mapping of strings to lists.")

    entries = _normalise_entries(entries_payload)
    # The checksum is recomputed over the *normalised* entries, so writers must
    # hash the same string-coerced form for the comparison to succeed.
    if checksum is not None:
        expected = compute_checksum(entries)
        if checksum != expected:
            raise RuntimeError(
                "Synonym cache checksum mismatch; the cache file appears to be corrupted."
            )

    return CacheSnapshot(entries=entries, checksum=checksum)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def write_cache(path: str | Path, entries: Mapping[str, Sequence[str]]) -> CacheSnapshot:
    """Persist ``entries`` to ``path`` as JSON with checksum metadata.

    The written document has a ``"__meta__"`` object (checksum and entry count)
    and an ``"entries"`` object. Parent directories are created as needed.

    Every synonym is coerced to ``str`` before hashing and writing. This mirrors
    the ``str()`` coercion applied when a cache is read back. Without it, a
    non-string element would hash differently on the two sides, and the freshly
    written file would fail its own checksum check on reload.

    Args:
        path: Destination file. ``str`` paths are accepted in addition to
            :class:`~pathlib.Path`.
        entries: Mapping of cache keys to synonym sequences.

    Returns:
        A :class:`CacheSnapshot` of exactly what was written, including the
        checksum.
    """
    destination = Path(path)
    serialisable = {
        key: [str(value) for value in values] for key, values in sorted(entries.items())
    }
    checksum = compute_checksum(serialisable)
    payload = {
        "__meta__": {
            "checksum": checksum,
            "entries": len(serialisable),
        },
        "entries": serialisable,
    }
    destination.parent.mkdir(parents=True, exist_ok=True)

    with destination.open("w", encoding="utf8") as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2, sort_keys=True)

    return CacheSnapshot(entries=serialisable, checksum=checksum)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# Public surface of the shared lexicon cache helpers; private helpers such as
# _normalise_entries and _canonical_json are intentionally not exported.
__all__ = ["CacheEntries", "CacheSnapshot", "compute_checksum", "load_cache", "write_cache"]
|