glitchlings-0.4.2-cp311-cp311-macosx_11_0_universal2.whl → glitchlings-0.4.3-cp311-cp311-macosx_11_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of glitchlings might be problematic.
- glitchlings/__init__.py +4 -0
- glitchlings/_zoo_rust.cpython-311-darwin.so +0 -0
- glitchlings/compat.py +80 -11
- glitchlings/config.py +32 -19
- glitchlings/config.toml +1 -1
- glitchlings/dlc/__init__.py +3 -1
- glitchlings/dlc/pytorch.py +216 -0
- glitchlings/dlc/pytorch_lightning.py +233 -0
- glitchlings/lexicon/__init__.py +5 -15
- glitchlings/lexicon/_cache.py +21 -15
- glitchlings/lexicon/data/default_vector_cache.json +80 -14
- glitchlings/lexicon/vector.py +94 -15
- glitchlings/lexicon/wordnet.py +66 -25
- glitchlings/main.py +21 -11
- glitchlings/zoo/__init__.py +5 -1
- glitchlings/zoo/adjax.py +2 -2
- glitchlings/zoo/apostrofae.py +128 -0
- glitchlings/zoo/assets/__init__.py +0 -0
- glitchlings/zoo/assets/apostrofae_pairs.json +32 -0
- glitchlings/zoo/core.py +40 -14
- glitchlings/zoo/jargoyle.py +44 -34
- glitchlings/zoo/redactyl.py +11 -8
- glitchlings/zoo/reduple.py +2 -2
- glitchlings/zoo/rushmore.py +2 -2
- glitchlings/zoo/scannequin.py +2 -2
- glitchlings/zoo/typogre.py +5 -2
- glitchlings/zoo/zeedub.py +5 -2
- {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/METADATA +35 -2
- glitchlings-0.4.3.dist-info/RECORD +46 -0
- glitchlings/lexicon/graph.py +0 -282
- glitchlings-0.4.2.dist-info/RECORD +0 -42
- {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/WHEEL +0 -0
- {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/entry_points.txt +0 -0
- {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/licenses/LICENSE +0 -0
- {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/top_level.txt +0 -0
glitchlings/dlc/pytorch_lightning.py
ADDED
@@ -0,0 +1,233 @@
+"""Integration helpers for PyTorch Lightning data modules."""
+
+from __future__ import annotations
+
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Any, cast
+
+from ..compat import get_pytorch_lightning_datamodule, require_pytorch_lightning
+from ..util.adapters import coerce_gaggle
+from ..zoo import Gaggle, Glitchling
+from ..zoo.core import _is_transcript
+
+
+def _normalise_columns(column: str | Sequence[str]) -> list[str]:
+    """Normalise a column specification to a list."""
+    if isinstance(column, str):
+        return [column]
+
+    normalised = list(column)
+    if not normalised:
+        raise ValueError("At least one column must be specified")
+    return normalised
+
+
+def _glitch_value(value: Any, gaggle: Gaggle) -> Any:
+    """Apply glitchlings to a value when it contains textual content."""
+    if isinstance(value, str) or _is_transcript(value, allow_empty=False, require_all_content=True):
+        return gaggle.corrupt(value)
+
+    if isinstance(value, Sequence) and value and all(isinstance(item, str) for item in value):
+        return [gaggle.corrupt(item) for item in value]
+
+    return value
+
+
+def _glitch_batch(batch: Any, columns: list[str], gaggle: Gaggle) -> Any:
+    """Apply glitchlings to the configured batch columns."""
+    if not isinstance(batch, Mapping):
+        return batch
+
+    if hasattr(batch, "copy"):
+        mutated = batch.copy()
+    else:
+        mutated = dict(batch)
+
+    missing = [column for column in columns if column not in mutated]
+    if missing:
+        missing_str = ", ".join(sorted(missing))
+        raise ValueError(f"Columns not found in batch: {missing_str}")
+
+    for column in columns:
+        mutated[column] = _glitch_value(mutated[column], gaggle)
+
+    return mutated
+
+
+def _wrap_dataloader(dataloader: Any, columns: list[str], gaggle: Gaggle) -> Any:
+    """Wrap a dataloader so yielded batches are corrupted lazily."""
+    if dataloader is None:
+        return None
+
+    if isinstance(dataloader, Mapping):
+        mapping_type = cast(type[Any], dataloader.__class__)
+        return mapping_type(
+            {
+                key: _wrap_dataloader(value, columns, gaggle)
+                for key, value in dataloader.items()
+            }
+        )
+
+    if isinstance(dataloader, list):
+        return [_wrap_dataloader(value, columns, gaggle) for value in dataloader]
+
+    if isinstance(dataloader, tuple):
+        return tuple(_wrap_dataloader(value, columns, gaggle) for value in dataloader)
+
+    if isinstance(dataloader, Sequence) and not isinstance(dataloader, (str, bytes, bytearray)):
+        sequence_type = cast(type[Any], dataloader.__class__)
+        return sequence_type(
+            _wrap_dataloader(value, columns, gaggle) for value in dataloader
+        )
+
+    return _GlitchedDataLoader(dataloader, columns, gaggle)
+
+
+class _GlitchedDataLoader:
+    """Proxy dataloader that glitches batches produced by the wrapped loader."""
+
+    def __init__(self, dataloader: Any, columns: list[str], gaggle: Gaggle) -> None:
+        self._dataloader = dataloader
+        self._columns = columns
+        self._gaggle = gaggle
+
+    def __iter__(self) -> Any:
+        for batch in self._dataloader:
+            yield _glitch_batch(batch, self._columns, self._gaggle)
+
+    def __len__(self) -> int:
+        return len(self._dataloader)
+
+    def __getattr__(self, attribute: str) -> Any:
+        return getattr(self._dataloader, attribute)
+
+
+def _glitch_datamodule(
+    datamodule: Any,
+    glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
+    column: str | Sequence[str],
+    *,
+    seed: int = 151,
+) -> Any:
+    """Return a proxy that applies glitchlings to batches from the datamodule."""
+
+    columns = _normalise_columns(column)
+    gaggle = coerce_gaggle(glitchlings, seed=seed)
+    return _GlitchedLightningDataModule(datamodule, columns, gaggle)
+
+
+class _GlitchedLightningDataModule:
+    """Proxy wrapper around a LightningDataModule applying glitchlings to batches."""
+
+    def __init__(self, base: Any, columns: list[str], gaggle: Gaggle) -> None:
+        object.__setattr__(self, "_glitch_base", base)
+        object.__setattr__(self, "_glitch_columns", columns)
+        object.__setattr__(self, "_glitch_gaggle", gaggle)
+
+    def __getattr__(self, attribute: str) -> Any:
+        return getattr(self._glitch_base, attribute)
+
+    def __setattr__(self, attribute: str, value: Any) -> None:
+        if attribute.startswith("_glitch_"):
+            object.__setattr__(self, attribute, value)
+        else:
+            setattr(self._glitch_base, attribute, value)
+
+    def __delattr__(self, attribute: str) -> None:
+        if attribute.startswith("_glitch_"):
+            object.__delattr__(self, attribute)
+        else:
+            delattr(self._glitch_base, attribute)
+
+    def __dir__(self) -> list[str]:
+        return sorted(set(dir(self.__class__)) | set(dir(self._glitch_base)))
+
+    # LightningDataModule API -------------------------------------------------
+    def prepare_data(self, *args: Any, **kwargs: Any) -> Any:
+        return self._glitch_base.prepare_data(*args, **kwargs)
+
+    def setup(self, *args: Any, **kwargs: Any) -> Any:
+        return self._glitch_base.setup(*args, **kwargs)
+
+    def teardown(self, *args: Any, **kwargs: Any) -> Any:
+        return self._glitch_base.teardown(*args, **kwargs)
+
+    def state_dict(self) -> Mapping[str, Any]:
+        state = self._glitch_base.state_dict()
+        return cast(Mapping[str, Any], state)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
+        self._glitch_base.load_state_dict(state_dict)
+
+    def transfer_batch_to_device(self, batch: Any, device: Any, dataloader_idx: int) -> Any:
+        return self._glitch_base.transfer_batch_to_device(batch, device, dataloader_idx)
+
+    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
+        return self._glitch_base.on_before_batch_transfer(batch, dataloader_idx)
+
+    def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
+        return self._glitch_base.on_after_batch_transfer(batch, dataloader_idx)
+
+    def train_dataloader(self, *args: Any, **kwargs: Any) -> Any:
+        loader = self._glitch_base.train_dataloader(*args, **kwargs)
+        return _wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)
+
+    def val_dataloader(self, *args: Any, **kwargs: Any) -> Any:
+        loader = self._glitch_base.val_dataloader(*args, **kwargs)
+        return _wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)
+
+    def test_dataloader(self, *args: Any, **kwargs: Any) -> Any:
+        loader = self._glitch_base.test_dataloader(*args, **kwargs)
+        return _wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)
+
+    def predict_dataloader(self, *args: Any, **kwargs: Any) -> Any:
+        loader = self._glitch_base.predict_dataloader(*args, **kwargs)
+        return _wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)
+
+
+def _ensure_datamodule_class() -> Any:
+    """Return the Lightning ``LightningDataModule`` patched with ``.glitch``."""
+
+    datamodule_cls = get_pytorch_lightning_datamodule()
+    if datamodule_cls is None:  # pragma: no cover - dependency is optional
+        module = require_pytorch_lightning("pytorch_lightning is not installed")
+        datamodule_cls = getattr(module, "LightningDataModule", None)
+        if datamodule_cls is None:
+            raise ModuleNotFoundError("pytorch_lightning is not installed")
+
+    if getattr(datamodule_cls, "glitch", None) is None:
+
+        def glitch(
+            self: Any,
+            glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
+            *,
+            column: str | Sequence[str],
+            seed: int = 151,
+            **_: Any,
+        ) -> Any:
+            return _glitch_datamodule(self, glitchlings, column, seed=seed)
+
+        setattr(datamodule_cls, "glitch", glitch)
+
+    if not issubclass(_GlitchedLightningDataModule, datamodule_cls):
+        _GlitchedLightningDataModule.__bases__ = (datamodule_cls,)
+
+    return datamodule_cls
+
+
+def install() -> None:
+    """Monkeypatch ``LightningDataModule`` with ``.glitch``."""
+
+    _ensure_datamodule_class()
+
+
+LightningDataModule: type[Any] | None
+_LightningDataModuleAlias = get_pytorch_lightning_datamodule()
+if _LightningDataModuleAlias is not None:
+    LightningDataModule = _ensure_datamodule_class()
+else:  # pragma: no cover - optional dependency
+    LightningDataModule = None
+
+
+__all__ = ["LightningDataModule", "install"]
+
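The new module gives every LightningDataModule a lazy corruption hook. A minimal usage sketch, assuming pytorch_lightning is installed; MyDataModule and the "text" column are hypothetical, while the .glitch() signature, the string glitchling name ("typogre" matches the zoo module listed above), and the default seed of 151 are taken from the diff:

from glitchlings.dlc.pytorch_lightning import install

install()  # patches LightningDataModule with the .glitch() method shown above

datamodule = MyDataModule()  # hypothetical user datamodule
# Returns a proxy datamodule; its train/val/test/predict dataloaders
# corrupt the "text" column of each yielded batch lazily.
glitched = datamodule.glitch("typogre", column="text", seed=151)

Because the proxy forwards every other attribute to the wrapped datamodule, it can be handed to a Trainer anywhere the original was used.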
glitchlings/lexicon/__init__.py
CHANGED
@@ -91,7 +91,6 @@ class LexiconBackend(Lexicon):
     """Persist the backend cache to ``path`` and return the destination."""


-from .graph import GraphLexicon  # noqa: E402
 from .metrics import (  # noqa: E402
     compare_lexicons,
     coverage_ratio,
@@ -100,10 +99,13 @@ from .metrics import (  # noqa: E402
 )
 from .vector import VectorLexicon, build_vector_cache  # noqa: E402

+_WordNetLexicon: type[LexiconBackend] | None
 try:  # pragma: no cover - optional dependency
-    from .wordnet import WordNetLexicon
+    from .wordnet import WordNetLexicon as _WordNetLexicon
 except Exception:  # pragma: no cover - triggered when nltk unavailable
-    WordNetLexicon = None
+    _WordNetLexicon = None
+
+WordNetLexicon: type[LexiconBackend] | None = _WordNetLexicon


 _BACKEND_FACTORIES: dict[str, Callable[[int | None], Lexicon | None]] = {}
@@ -135,16 +137,6 @@ def _vector_backend(seed: int | None) -> Lexicon | None:
     return VectorLexicon(cache_path=cache_path, seed=seed)


-def _graph_backend(seed: int | None) -> Lexicon | None:
-    config = get_config()
-    cache_path = config.lexicon.graph_cache
-    if cache_path is None:
-        return None
-    if not cache_path.exists():
-        return None
-    return GraphLexicon(cache_path=cache_path, seed=seed)
-
-
 def _wordnet_backend(seed: int | None) -> Lexicon | None:  # pragma: no cover - optional
     if WordNetLexicon is None:
         return None
@@ -156,7 +148,6 @@ def _wordnet_backend(seed: int | None) -> Lexicon | None:  # pragma: no cover -


 register_backend("vector", _vector_backend)
-register_backend("graph", _graph_backend)
 register_backend("wordnet", _wordnet_backend)


@@ -184,7 +175,6 @@ __all__ = [
     "Lexicon",
     "LexiconBackend",
     "VectorLexicon",
-    "GraphLexicon",
     "WordNetLexicon",
     "build_vector_cache",
    "compare_lexicons",
glitchlings/lexicon/_cache.py
CHANGED
@@ -6,7 +6,7 @@ import json
 from dataclasses import dataclass
 from hashlib import blake2s
 from pathlib import Path
-from typing import Mapping, Sequence
+from typing import Mapping, Sequence, cast

 CacheEntries = dict[str, list[str]]

@@ -19,7 +19,7 @@ class CacheSnapshot:
     checksum: str | None = None


-def _normalise_entries(payload: Mapping[str,
+def _normalise_entries(payload: Mapping[str, object]) -> CacheEntries:
     """Convert raw cache payloads into canonical mapping form."""
     entries: CacheEntries = {}
     for key, values in payload.items():
@@ -49,27 +49,31 @@ def load_cache(path: Path) -> CacheSnapshot:
         return CacheSnapshot(entries={}, checksum=None)

     with path.open("r", encoding="utf8") as handle:
-        payload = json.load(handle)
+        payload_obj = json.load(handle)

     checksum: str | None = None
-    entries_payload: Mapping[str,
+    entries_payload: Mapping[str, object]

-    if isinstance(
-
-
-
+    if not isinstance(payload_obj, Mapping):
+        raise RuntimeError("Synonym cache payload must be a mapping of strings to lists.")
+
+    payload = cast(Mapping[str, object], payload_obj)
+
+    if "__meta__" in payload and "entries" in payload:
+        meta_obj = payload["__meta__"]
+        entries_obj = payload["entries"]
+        if not isinstance(entries_obj, Mapping):
             raise RuntimeError("Synonym cache entries must be stored as a mapping.")
-
-
+        entries_payload = cast(Mapping[str, object], entries_obj)
+        if isinstance(meta_obj, Mapping):
+            raw_checksum = meta_obj.get("checksum")
             if raw_checksum is not None and not isinstance(raw_checksum, str):
                 raise RuntimeError("Synonym cache checksum must be a string when provided.")
-            checksum = raw_checksum
+            checksum = raw_checksum if isinstance(raw_checksum, str) else None
         else:
             raise RuntimeError("Synonym cache metadata must be a mapping.")
-    elif isinstance(payload, Mapping):
-        entries_payload = payload  # legacy format without metadata
     else:
-
+        entries_payload = payload  # legacy format without metadata

     entries = _normalise_entries(entries_payload)
     if checksum is not None:
@@ -84,7 +88,9 @@ def load_cache(path: Path) -> CacheSnapshot:

 def write_cache(path: Path, entries: Mapping[str, Sequence[str]]) -> CacheSnapshot:
     """Persist ``entries`` to ``path`` with checksum metadata."""
-    serialisable = {
+    serialisable: CacheEntries = {
+        key: list(values) for key, values in sorted(entries.items())
+    }
     checksum = compute_checksum(serialisable)
     payload = {
         "__meta__": {
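For orientation, the checksummed payload shape that load_cache() now validates looks roughly like the following; the values are illustrative, and the checksum is a blake2s digest computed by compute_checksum over the sorted entries, as the hunks above show:

payload = {
    "__meta__": {"checksum": "<blake2s hex digest>"},
    "entries": {
        "quick": ["swift", "rapid", "speedy", "nimble"],
    },
}

A bare mapping without the "__meta__" wrapper is still accepted as the legacy format.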
glitchlings/lexicon/data/default_vector_cache.json
CHANGED
@@ -1,16 +1,82 @@
 {
-  "
-
-
-
-
-  "
-
-
-
-
-  "
-
-
-
+  "alpha": [
+    "beta",
+    "gamma",
+    "delta"
+  ],
+  "beta": [
+    "alpha",
+    "gamma",
+    "delta"
+  ],
+  "delta": [
+    "alpha",
+    "beta",
+    "gamma"
+  ],
+  "fast": [
+    "rapid",
+    "swift",
+    "speedy",
+    "brisk"
+  ],
+  "gamma": [
+    "alpha",
+    "beta",
+    "delta"
+  ],
+  "happy": [
+    "glad",
+    "joyful",
+    "content",
+    "upbeat"
+  ],
+  "quick": [
+    "swift",
+    "rapid",
+    "speedy",
+    "nimble"
+  ],
+  "quickly": [
+    "swiftly",
+    "rapidly",
+    "promptly",
+    "speedily"
+  ],
+  "sing": [
+    "croon",
+    "serenade",
+    "vocalize",
+    "perform"
+  ],
+  "slow": [
+    "sluggish",
+    "leisurely",
+    "unhurried",
+    "gradual"
+  ],
+  "songs": [
+    "tracks",
+    "melodies",
+    "ballads",
+    "tunes"
+  ],
+  "text": [
+    "passage",
+    "copy",
+    "script",
+    "narrative"
+  ],
+  "they": [
+    "those people",
+    "those individuals",
+    "the group",
+    "those folks"
+  ],
+  "words": [
+    "terms",
+    "phrases",
+    "lexicon",
+    "vocabulary"
+  ]
 }
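A quick sketch of how the refreshed default cache is consumed, assuming the packaged JSON above is passed as cache_path; the constructor and get_synonyms signatures are taken from vector.py below:

from pathlib import Path
from glitchlings.lexicon import VectorLexicon

cache = Path("glitchlings/lexicon/data/default_vector_cache.json")
lexicon = VectorLexicon(cache_path=cache, seed=151)
# Deterministic subset of ["swift", "rapid", "speedy", "nimble"]
lexicon.get_synonyms("quick", n=3)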
glitchlings/lexicon/vector.py
CHANGED
@@ -4,6 +4,7 @@ from __future__ import annotations

 import argparse
 import importlib
+import importlib.util
 import json
 import math
 import sys
@@ -188,6 +189,58 @@ def _load_spacy_language(model_name: str) -> Any:
     return spacy_module.load(model_name)


+def _load_sentence_transformer(model_name: str) -> Any:
+    """Return a ``SentenceTransformer`` instance for ``model_name``."""
+
+    if importlib.util.find_spec("sentence_transformers") is None:
+        raise RuntimeError(
+            "sentence-transformers is required for this source; install the 'st' extra."
+        )
+
+    module = importlib.import_module("sentence_transformers")
+    try:
+        model_cls = getattr(module, "SentenceTransformer")
+    except AttributeError as exc:  # pragma: no cover - defensive
+        raise RuntimeError("sentence-transformers does not expose SentenceTransformer") from exc
+
+    return model_cls(model_name)
+
+
+def _build_sentence_transformer_embeddings(
+    model_name: str, tokens: Sequence[str]
+) -> Mapping[str, Sequence[float]]:
+    """Return embeddings for ``tokens`` using ``model_name``."""
+
+    if not tokens:
+        return {}
+
+    model = _load_sentence_transformer(model_name)
+
+    unique_tokens: list[str] = []
+    seen: set[str] = set()
+    for token in tokens:
+        normalized = token.strip()
+        if not normalized or normalized in seen:
+            continue
+        unique_tokens.append(normalized)
+        seen.add(normalized)
+
+    if not unique_tokens:
+        return {}
+
+    embeddings = model.encode(
+        unique_tokens,
+        batch_size=64,
+        normalize_embeddings=True,
+        convert_to_numpy=True,
+    )
+
+    return {
+        token: [float(value) for value in vector]
+        for token, vector in zip(unique_tokens, embeddings, strict=True)
+    }
+
+
 def _resolve_source(source: Any | None) -> _Adapter | None:
     """Return an adapter instance for ``source`` if possible."""
     if source is None:
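In isolation, the new helper strips and deduplicates tokens before encoding, so a sketch like the following (hypothetical model name and tokens; sentence-transformers must be installed) yields one normalised vector per unique token:

tokens = ["quick", "swift", " quick "]  # whitespace variants collapse to one entry
embeddings = _build_sentence_transformer_embeddings(
    "sentence-transformers/all-MiniLM-L6-v2", tokens
)
# -> {"quick": [...], "swift": [...]}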
@@ -248,6 +301,7 @@ class VectorLexicon(LexiconBackend):
         case_sensitive: bool = False,
         seed: int | None = None,
     ) -> None:
+        """Initialise the lexicon with an embedding ``source`` and optional cache."""
         super().__init__(seed=seed)
         self._adapter = _resolve_source(source)
         self._max_neighbors = max(1, max_neighbors)
@@ -350,6 +404,7 @@ class VectorLexicon(LexiconBackend):
         return synonyms

     def get_synonyms(self, word: str, pos: str | None = None, n: int = 5) -> list[str]:
+        """Return up to ``n`` deterministic synonyms drawn from the embedding cache."""
         normalized = self._normalize_for_lookup(word)
         synonyms = self._ensure_cached(original=word, normalized=normalized)
         return self._deterministic_sample(synonyms, limit=n, word=word, pos=pos)
@@ -390,6 +445,7 @@ class VectorLexicon(LexiconBackend):
         return target

     def supports_pos(self, pos: str | None) -> bool:
+        """Always return ``True`` because vector sources do not encode POS metadata."""
         return True

     def __repr__(self) -> str:  # pragma: no cover - debug helper
@@ -452,7 +508,8 @@ def _parse_cli(argv: Sequence[str] | None = None) -> argparse.Namespace:
         "--source",
         required=True,
         help=(
-            "Vector source specification. Use 'spacy:<model>' for spaCy pipelines "
+            "Vector source specification. Use 'spacy:<model>' for spaCy pipelines, "
+            "'sentence-transformers:<model>' for HuggingFace checkpoints (requires --tokens), "
             "or provide a path to a gensim KeyedVectors/word2vec file."
         ),
     )
@@ -534,22 +591,44 @@ def main(argv: Sequence[str] | None = None) -> int:

     normalizer = _identity

-
+    tokens_from_file: list[str] | None = None
     if args.tokens is not None:
-
+        tokens_from_file = list(_iter_tokens_from_file(args.tokens))
+        if args.limit is not None:
+            tokens_from_file = tokens_from_file[: args.limit]
+
+    source_spec = args.source
+    token_iter: Iterable[str]
+    if source_spec.startswith("sentence-transformers:"):
+        model_name = source_spec.split(":", 1)[1].strip()
+        if not model_name:
+            model_name = "sentence-transformers/all-mpnet-base-v2"
+        if tokens_from_file is None:
+            raise SystemExit(
+                "Sentence-transformers sources require --tokens to supply a vocabulary."
+            )
+        source = _build_sentence_transformer_embeddings(model_name, tokens_from_file)
+        token_iter = tokens_from_file
     else:
-
-
-
-
-
-
-
-
-
-
-
-
+        source = load_vector_source(source_spec)
+        if tokens_from_file is not None:
+            token_iter = tokens_from_file
+        else:
+            lexicon = VectorLexicon(
+                source=source,
+                max_neighbors=args.max_neighbors,
+                min_similarity=args.min_similarity,
+                case_sensitive=args.case_sensitive,
+                normalizer=normalizer,
+                seed=args.seed,
+            )
+            iterator = lexicon.iter_vocabulary()
+            if args.limit is not None:
+                token_iter = (
+                    token for index, token in enumerate(iterator) if index < args.limit
+                )
+            else:
+                token_iter = iterator

     build_vector_cache(
         source=source,
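Assuming the module is runnable as a script (it defines main(argv) and _parse_cli(), and the --limit/--seed flag spellings are inferred from the args attributes above), a cache build against a sentence-transformers source would be driven roughly like this:

from glitchlings.lexicon.vector import main

main([
    "--source", "sentence-transformers:all-mpnet-base-v2",
    "--tokens", "vocab.txt",
    "--limit", "10000",
    "--seed", "151",
])

The --tokens file is mandatory for this source, since a sentence-transformers checkpoint carries no vocabulary of its own.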