glitchlings 0.4.2__cp312-cp312-macosx_11_0_universal2.whl → 0.4.3__cp312-cp312-macosx_11_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of glitchlings might be problematic. Click here for more details.

Files changed (35) hide show
  1. glitchlings/__init__.py +4 -0
  2. glitchlings/_zoo_rust.cpython-312-darwin.so +0 -0
  3. glitchlings/compat.py +80 -11
  4. glitchlings/config.py +32 -19
  5. glitchlings/config.toml +1 -1
  6. glitchlings/dlc/__init__.py +3 -1
  7. glitchlings/dlc/pytorch.py +216 -0
  8. glitchlings/dlc/pytorch_lightning.py +233 -0
  9. glitchlings/lexicon/__init__.py +5 -15
  10. glitchlings/lexicon/_cache.py +21 -15
  11. glitchlings/lexicon/data/default_vector_cache.json +80 -14
  12. glitchlings/lexicon/vector.py +94 -15
  13. glitchlings/lexicon/wordnet.py +66 -25
  14. glitchlings/main.py +21 -11
  15. glitchlings/zoo/__init__.py +5 -1
  16. glitchlings/zoo/adjax.py +2 -2
  17. glitchlings/zoo/apostrofae.py +128 -0
  18. glitchlings/zoo/assets/__init__.py +0 -0
  19. glitchlings/zoo/assets/apostrofae_pairs.json +32 -0
  20. glitchlings/zoo/core.py +40 -14
  21. glitchlings/zoo/jargoyle.py +44 -34
  22. glitchlings/zoo/redactyl.py +11 -8
  23. glitchlings/zoo/reduple.py +2 -2
  24. glitchlings/zoo/rushmore.py +2 -2
  25. glitchlings/zoo/scannequin.py +2 -2
  26. glitchlings/zoo/typogre.py +5 -2
  27. glitchlings/zoo/zeedub.py +5 -2
  28. {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/METADATA +35 -2
  29. glitchlings-0.4.3.dist-info/RECORD +46 -0
  30. glitchlings/lexicon/graph.py +0 -282
  31. glitchlings-0.4.2.dist-info/RECORD +0 -42
  32. {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/WHEEL +0 -0
  33. {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/entry_points.txt +0 -0
  34. {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/licenses/LICENSE +0 -0
  35. {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,233 @@
1
+ """Integration helpers for PyTorch Lightning data modules."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Iterable, Mapping, Sequence
6
+ from typing import Any, cast
7
+
8
+ from ..compat import get_pytorch_lightning_datamodule, require_pytorch_lightning
9
+ from ..util.adapters import coerce_gaggle
10
+ from ..zoo import Gaggle, Glitchling
11
+ from ..zoo.core import _is_transcript
12
+
13
+
14
+ def _normalise_columns(column: str | Sequence[str]) -> list[str]:
15
+ """Normalise a column specification to a list."""
16
+ if isinstance(column, str):
17
+ return [column]
18
+
19
+ normalised = list(column)
20
+ if not normalised:
21
+ raise ValueError("At least one column must be specified")
22
+ return normalised
23
+
24
+
25
+ def _glitch_value(value: Any, gaggle: Gaggle) -> Any:
26
+ """Apply glitchlings to a value when it contains textual content."""
27
+ if isinstance(value, str) or _is_transcript(value, allow_empty=False, require_all_content=True):
28
+ return gaggle.corrupt(value)
29
+
30
+ if isinstance(value, Sequence) and value and all(isinstance(item, str) for item in value):
31
+ return [gaggle.corrupt(item) for item in value]
32
+
33
+ return value
34
+
35
+
36
+ def _glitch_batch(batch: Any, columns: list[str], gaggle: Gaggle) -> Any:
37
+ """Apply glitchlings to the configured batch columns."""
38
+ if not isinstance(batch, Mapping):
39
+ return batch
40
+
41
+ if hasattr(batch, "copy"):
42
+ mutated = batch.copy()
43
+ else:
44
+ mutated = dict(batch)
45
+
46
+ missing = [column for column in columns if column not in mutated]
47
+ if missing:
48
+ missing_str = ", ".join(sorted(missing))
49
+ raise ValueError(f"Columns not found in batch: {missing_str}")
50
+
51
+ for column in columns:
52
+ mutated[column] = _glitch_value(mutated[column], gaggle)
53
+
54
+ return mutated
55
+
56
+
57
+ def _wrap_dataloader(dataloader: Any, columns: list[str], gaggle: Gaggle) -> Any:
58
+ """Wrap a dataloader so yielded batches are corrupted lazily."""
59
+ if dataloader is None:
60
+ return None
61
+
62
+ if isinstance(dataloader, Mapping):
63
+ mapping_type = cast(type[Any], dataloader.__class__)
64
+ return mapping_type(
65
+ {
66
+ key: _wrap_dataloader(value, columns, gaggle)
67
+ for key, value in dataloader.items()
68
+ }
69
+ )
70
+
71
+ if isinstance(dataloader, list):
72
+ return [_wrap_dataloader(value, columns, gaggle) for value in dataloader]
73
+
74
+ if isinstance(dataloader, tuple):
75
+ return tuple(_wrap_dataloader(value, columns, gaggle) for value in dataloader)
76
+
77
+ if isinstance(dataloader, Sequence) and not isinstance(dataloader, (str, bytes, bytearray)):
78
+ sequence_type = cast(type[Any], dataloader.__class__)
79
+ return sequence_type(
80
+ _wrap_dataloader(value, columns, gaggle) for value in dataloader
81
+ )
82
+
83
+ return _GlitchedDataLoader(dataloader, columns, gaggle)
84
+
85
+
86
+ class _GlitchedDataLoader:
87
+ """Proxy dataloader that glitches batches produced by the wrapped loader."""
88
+
89
+ def __init__(self, dataloader: Any, columns: list[str], gaggle: Gaggle) -> None:
90
+ self._dataloader = dataloader
91
+ self._columns = columns
92
+ self._gaggle = gaggle
93
+
94
+ def __iter__(self) -> Any:
95
+ for batch in self._dataloader:
96
+ yield _glitch_batch(batch, self._columns, self._gaggle)
97
+
98
+ def __len__(self) -> int:
99
+ return len(self._dataloader)
100
+
101
+ def __getattr__(self, attribute: str) -> Any:
102
+ return getattr(self._dataloader, attribute)
103
+
104
+
105
def _glitch_datamodule(
    datamodule: Any,
    glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
    column: str | Sequence[str],
    *,
    seed: int = 151,
) -> Any:
    """Build a proxy datamodule whose batches are run through glitchlings.

    The column specification is normalised first, then ``glitchlings`` is
    coerced into a gaggle with ``seed``; both are handed to the proxy that
    wraps ``datamodule``.
    """
    return _GlitchedLightningDataModule(
        datamodule,
        _normalise_columns(column),
        coerce_gaggle(glitchlings, seed=seed),
    )
117
+
118
+
119
+ class _GlitchedLightningDataModule:
120
+ """Proxy wrapper around a LightningDataModule applying glitchlings to batches."""
121
+
122
+ def __init__(self, base: Any, columns: list[str], gaggle: Gaggle) -> None:
123
+ object.__setattr__(self, "_glitch_base", base)
124
+ object.__setattr__(self, "_glitch_columns", columns)
125
+ object.__setattr__(self, "_glitch_gaggle", gaggle)
126
+
127
+ def __getattr__(self, attribute: str) -> Any:
128
+ return getattr(self._glitch_base, attribute)
129
+
130
+ def __setattr__(self, attribute: str, value: Any) -> None:
131
+ if attribute.startswith("_glitch_"):
132
+ object.__setattr__(self, attribute, value)
133
+ else:
134
+ setattr(self._glitch_base, attribute, value)
135
+
136
+ def __delattr__(self, attribute: str) -> None:
137
+ if attribute.startswith("_glitch_"):
138
+ object.__delattr__(self, attribute)
139
+ else:
140
+ delattr(self._glitch_base, attribute)
141
+
142
+ def __dir__(self) -> list[str]:
143
+ return sorted(set(dir(self.__class__)) | set(dir(self._glitch_base)))
144
+
145
+ # LightningDataModule API -------------------------------------------------
146
+ def prepare_data(self, *args: Any, **kwargs: Any) -> Any:
147
+ return self._glitch_base.prepare_data(*args, **kwargs)
148
+
149
+ def setup(self, *args: Any, **kwargs: Any) -> Any:
150
+ return self._glitch_base.setup(*args, **kwargs)
151
+
152
+ def teardown(self, *args: Any, **kwargs: Any) -> Any:
153
+ return self._glitch_base.teardown(*args, **kwargs)
154
+
155
+ def state_dict(self) -> Mapping[str, Any]:
156
+ state = self._glitch_base.state_dict()
157
+ return cast(Mapping[str, Any], state)
158
+
159
+ def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
160
+ self._glitch_base.load_state_dict(state_dict)
161
+
162
+ def transfer_batch_to_device(self, batch: Any, device: Any, dataloader_idx: int) -> Any:
163
+ return self._glitch_base.transfer_batch_to_device(batch, device, dataloader_idx)
164
+
165
+ def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
166
+ return self._glitch_base.on_before_batch_transfer(batch, dataloader_idx)
167
+
168
+ def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
169
+ return self._glitch_base.on_after_batch_transfer(batch, dataloader_idx)
170
+
171
+ def train_dataloader(self, *args: Any, **kwargs: Any) -> Any:
172
+ loader = self._glitch_base.train_dataloader(*args, **kwargs)
173
+ return _wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)
174
+
175
+ def val_dataloader(self, *args: Any, **kwargs: Any) -> Any:
176
+ loader = self._glitch_base.val_dataloader(*args, **kwargs)
177
+ return _wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)
178
+
179
+ def test_dataloader(self, *args: Any, **kwargs: Any) -> Any:
180
+ loader = self._glitch_base.test_dataloader(*args, **kwargs)
181
+ return _wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)
182
+
183
+ def predict_dataloader(self, *args: Any, **kwargs: Any) -> Any:
184
+ loader = self._glitch_base.predict_dataloader(*args, **kwargs)
185
+ return _wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)
186
+
187
+
188
def _ensure_datamodule_class() -> Any:
    """Return Lightning's ``LightningDataModule`` class patched with ``.glitch``.

    The class is resolved through the compat helpers. A ``glitch`` method is
    attached when the class does not already expose one, and the proxy class
    is rebased onto the Lightning class unless it is already a subclass.

    Raises:
        ModuleNotFoundError: If no ``LightningDataModule`` class can be found.
    """

    cls = get_pytorch_lightning_datamodule()
    if cls is None:  # pragma: no cover - dependency is optional
        lightning = require_pytorch_lightning("pytorch_lightning is not installed")
        cls = getattr(lightning, "LightningDataModule", None)
    if cls is None:
        raise ModuleNotFoundError("pytorch_lightning is not installed")

    already_patched = getattr(cls, "glitch", None) is not None
    if not already_patched:

        def glitch(
            self: Any,
            glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
            *,
            column: str | Sequence[str],
            seed: int = 151,
            **_: Any,
        ) -> Any:
            return _glitch_datamodule(self, glitchlings, column, seed=seed)

        cls.glitch = glitch

    # NOTE(review): CPython rejects ``__bases__`` assignment on a class whose
    # sole base is ``object`` -- confirm the proxy class declares another base.
    if not issubclass(_GlitchedLightningDataModule, cls):
        _GlitchedLightningDataModule.__bases__ = (cls,)

    return cls
216
+
217
+
218
def install() -> None:
    """Attach ``.glitch`` to ``LightningDataModule`` by monkeypatching it.

    Delegates to ``_ensure_datamodule_class`` and discards the class it
    returns.
    """
    _ensure_datamodule_class()
222
+
223
+
224
# Public alias: the patched Lightning class, or ``None`` when the compat
# helper reports that no LightningDataModule class is available.
LightningDataModule: type[Any] | None
_LightningDataModuleAlias = get_pytorch_lightning_datamodule()
if _LightningDataModuleAlias is None:  # pragma: no cover - optional dependency
    LightningDataModule = None
else:
    LightningDataModule = _ensure_datamodule_class()


__all__ = ["LightningDataModule", "install"]
233
+
@@ -91,7 +91,6 @@ class LexiconBackend(Lexicon):
91
91
  """Persist the backend cache to ``path`` and return the destination."""
92
92
 
93
93
 
94
- from .graph import GraphLexicon # noqa: E402
95
94
  from .metrics import ( # noqa: E402
96
95
  compare_lexicons,
97
96
  coverage_ratio,
@@ -100,10 +99,13 @@ from .metrics import ( # noqa: E402
100
99
  )
101
100
  from .vector import VectorLexicon, build_vector_cache # noqa: E402
102
101
 
102
+ _WordNetLexicon: type[LexiconBackend] | None
103
103
  try: # pragma: no cover - optional dependency
104
- from .wordnet import WordNetLexicon
104
+ from .wordnet import WordNetLexicon as _WordNetLexicon
105
105
  except Exception: # pragma: no cover - triggered when nltk unavailable
106
- WordNetLexicon = None # type: ignore[assignment]
106
+ _WordNetLexicon = None
107
+
108
+ WordNetLexicon: type[LexiconBackend] | None = _WordNetLexicon
107
109
 
108
110
 
109
111
  _BACKEND_FACTORIES: dict[str, Callable[[int | None], Lexicon | None]] = {}
@@ -135,16 +137,6 @@ def _vector_backend(seed: int | None) -> Lexicon | None:
135
137
  return VectorLexicon(cache_path=cache_path, seed=seed)
136
138
 
137
139
 
138
- def _graph_backend(seed: int | None) -> Lexicon | None:
139
- config = get_config()
140
- cache_path = config.lexicon.graph_cache
141
- if cache_path is None:
142
- return None
143
- if not cache_path.exists():
144
- return None
145
- return GraphLexicon(cache_path=cache_path, seed=seed)
146
-
147
-
148
140
  def _wordnet_backend(seed: int | None) -> Lexicon | None: # pragma: no cover - optional
149
141
  if WordNetLexicon is None:
150
142
  return None
@@ -156,7 +148,6 @@ def _wordnet_backend(seed: int | None) -> Lexicon | None: # pragma: no cover -
156
148
 
157
149
 
158
150
  register_backend("vector", _vector_backend)
159
- register_backend("graph", _graph_backend)
160
151
  register_backend("wordnet", _wordnet_backend)
161
152
 
162
153
 
@@ -184,7 +175,6 @@ __all__ = [
184
175
  "Lexicon",
185
176
  "LexiconBackend",
186
177
  "VectorLexicon",
187
- "GraphLexicon",
188
178
  "WordNetLexicon",
189
179
  "build_vector_cache",
190
180
  "compare_lexicons",
@@ -6,7 +6,7 @@ import json
6
6
  from dataclasses import dataclass
7
7
  from hashlib import blake2s
8
8
  from pathlib import Path
9
- from typing import Mapping, Sequence
9
+ from typing import Mapping, Sequence, cast
10
10
 
11
11
  CacheEntries = dict[str, list[str]]
12
12
 
@@ -19,7 +19,7 @@ class CacheSnapshot:
19
19
  checksum: str | None = None
20
20
 
21
21
 
22
- def _normalise_entries(payload: Mapping[str, Sequence[str]]) -> CacheEntries:
22
+ def _normalise_entries(payload: Mapping[str, object]) -> CacheEntries:
23
23
  """Convert raw cache payloads into canonical mapping form."""
24
24
  entries: CacheEntries = {}
25
25
  for key, values in payload.items():
@@ -49,27 +49,31 @@ def load_cache(path: Path) -> CacheSnapshot:
49
49
  return CacheSnapshot(entries={}, checksum=None)
50
50
 
51
51
  with path.open("r", encoding="utf8") as handle:
52
- payload = json.load(handle)
52
+ payload_obj = json.load(handle)
53
53
 
54
54
  checksum: str | None = None
55
- entries_payload: Mapping[str, Sequence[str]]
55
+ entries_payload: Mapping[str, object]
56
56
 
57
- if isinstance(payload, Mapping) and "__meta__" in payload and "entries" in payload:
58
- meta = payload["__meta__"]
59
- entries_payload = payload["entries"] # type: ignore[assignment]
60
- if not isinstance(entries_payload, Mapping):
57
+ if not isinstance(payload_obj, Mapping):
58
+ raise RuntimeError("Synonym cache payload must be a mapping of strings to lists.")
59
+
60
+ payload = cast(Mapping[str, object], payload_obj)
61
+
62
+ if "__meta__" in payload and "entries" in payload:
63
+ meta_obj = payload["__meta__"]
64
+ entries_obj = payload["entries"]
65
+ if not isinstance(entries_obj, Mapping):
61
66
  raise RuntimeError("Synonym cache entries must be stored as a mapping.")
62
- if isinstance(meta, Mapping):
63
- raw_checksum = meta.get("checksum")
67
+ entries_payload = cast(Mapping[str, object], entries_obj)
68
+ if isinstance(meta_obj, Mapping):
69
+ raw_checksum = meta_obj.get("checksum")
64
70
  if raw_checksum is not None and not isinstance(raw_checksum, str):
65
71
  raise RuntimeError("Synonym cache checksum must be a string when provided.")
66
- checksum = raw_checksum
72
+ checksum = raw_checksum if isinstance(raw_checksum, str) else None
67
73
  else:
68
74
  raise RuntimeError("Synonym cache metadata must be a mapping.")
69
- elif isinstance(payload, Mapping):
70
- entries_payload = payload # legacy format without metadata
71
75
  else:
72
- raise RuntimeError("Synonym cache payload must be a mapping of strings to lists.")
76
+ entries_payload = payload # legacy format without metadata
73
77
 
74
78
  entries = _normalise_entries(entries_payload)
75
79
  if checksum is not None:
@@ -84,7 +88,9 @@ def load_cache(path: Path) -> CacheSnapshot:
84
88
 
85
89
  def write_cache(path: Path, entries: Mapping[str, Sequence[str]]) -> CacheSnapshot:
86
90
  """Persist ``entries`` to ``path`` with checksum metadata."""
87
- serialisable = {key: list(values) for key, values in sorted(entries.items())}
91
+ serialisable: CacheEntries = {
92
+ key: list(values) for key, values in sorted(entries.items())
93
+ }
88
94
  checksum = compute_checksum(serialisable)
89
95
  payload = {
90
96
  "__meta__": {
@@ -1,16 +1,82 @@
1
1
  {
2
- "sing": ["croon", "warble", "chant", "serenade"],
3
- "happy": ["cheerful", "joyful", "contented", "gleeful"],
4
- "songs": ["tunes", "melodies", "ballads", "airs"],
5
- "quickly": ["rapidly", "swiftly", "speedily", "promptly"],
6
- "text": ["passage", "excerpt", "phrase", "content"],
7
- "words": ["terms", "phrases", "lexemes", "expressions"],
8
- "alpha": ["beta", "gamma", "delta"],
9
- "beta": ["alpha", "gamma", "delta"],
10
- "gamma": ["alpha", "beta", "delta"],
11
- "delta": ["alpha", "beta", "gamma"],
12
- "they": ["these people", "those folks", "those individuals"],
13
- "quick": ["rapid", "swift", "brisk", "prompt"],
14
- "fast": ["rapid", "swift", "quick", "speedy"],
15
- "slow": ["sluggish", "lethargic", "unhurried", "deliberate"]
2
+ "alpha": [
3
+ "beta",
4
+ "gamma",
5
+ "delta"
6
+ ],
7
+ "beta": [
8
+ "alpha",
9
+ "gamma",
10
+ "delta"
11
+ ],
12
+ "delta": [
13
+ "alpha",
14
+ "beta",
15
+ "gamma"
16
+ ],
17
+ "fast": [
18
+ "rapid",
19
+ "swift",
20
+ "speedy",
21
+ "brisk"
22
+ ],
23
+ "gamma": [
24
+ "alpha",
25
+ "beta",
26
+ "delta"
27
+ ],
28
+ "happy": [
29
+ "glad",
30
+ "joyful",
31
+ "content",
32
+ "upbeat"
33
+ ],
34
+ "quick": [
35
+ "swift",
36
+ "rapid",
37
+ "speedy",
38
+ "nimble"
39
+ ],
40
+ "quickly": [
41
+ "swiftly",
42
+ "rapidly",
43
+ "promptly",
44
+ "speedily"
45
+ ],
46
+ "sing": [
47
+ "croon",
48
+ "serenade",
49
+ "vocalize",
50
+ "perform"
51
+ ],
52
+ "slow": [
53
+ "sluggish",
54
+ "leisurely",
55
+ "unhurried",
56
+ "gradual"
57
+ ],
58
+ "songs": [
59
+ "tracks",
60
+ "melodies",
61
+ "ballads",
62
+ "tunes"
63
+ ],
64
+ "text": [
65
+ "passage",
66
+ "copy",
67
+ "script",
68
+ "narrative"
69
+ ],
70
+ "they": [
71
+ "those people",
72
+ "those individuals",
73
+ "the group",
74
+ "those folks"
75
+ ],
76
+ "words": [
77
+ "terms",
78
+ "phrases",
79
+ "lexicon",
80
+ "vocabulary"
81
+ ]
16
82
  }
@@ -4,6 +4,7 @@ from __future__ import annotations
4
4
 
5
5
  import argparse
6
6
  import importlib
7
+ import importlib.util
7
8
  import json
8
9
  import math
9
10
  import sys
@@ -188,6 +189,58 @@ def _load_spacy_language(model_name: str) -> Any:
188
189
  return spacy_module.load(model_name)
189
190
 
190
191
 
192
+ def _load_sentence_transformer(model_name: str) -> Any:
193
+ """Return a ``SentenceTransformer`` instance for ``model_name``."""
194
+
195
+ if importlib.util.find_spec("sentence_transformers") is None:
196
+ raise RuntimeError(
197
+ "sentence-transformers is required for this source; install the 'st' extra."
198
+ )
199
+
200
+ module = importlib.import_module("sentence_transformers")
201
+ try:
202
+ model_cls = getattr(module, "SentenceTransformer")
203
+ except AttributeError as exc: # pragma: no cover - defensive
204
+ raise RuntimeError("sentence-transformers does not expose SentenceTransformer") from exc
205
+
206
+ return model_cls(model_name)
207
+
208
+
209
+ def _build_sentence_transformer_embeddings(
210
+ model_name: str, tokens: Sequence[str]
211
+ ) -> Mapping[str, Sequence[float]]:
212
+ """Return embeddings for ``tokens`` using ``model_name``."""
213
+
214
+ if not tokens:
215
+ return {}
216
+
217
+ model = _load_sentence_transformer(model_name)
218
+
219
+ unique_tokens: list[str] = []
220
+ seen: set[str] = set()
221
+ for token in tokens:
222
+ normalized = token.strip()
223
+ if not normalized or normalized in seen:
224
+ continue
225
+ unique_tokens.append(normalized)
226
+ seen.add(normalized)
227
+
228
+ if not unique_tokens:
229
+ return {}
230
+
231
+ embeddings = model.encode(
232
+ unique_tokens,
233
+ batch_size=64,
234
+ normalize_embeddings=True,
235
+ convert_to_numpy=True,
236
+ )
237
+
238
+ return {
239
+ token: [float(value) for value in vector]
240
+ for token, vector in zip(unique_tokens, embeddings, strict=True)
241
+ }
242
+
243
+
191
244
  def _resolve_source(source: Any | None) -> _Adapter | None:
192
245
  """Return an adapter instance for ``source`` if possible."""
193
246
  if source is None:
@@ -248,6 +301,7 @@ class VectorLexicon(LexiconBackend):
248
301
  case_sensitive: bool = False,
249
302
  seed: int | None = None,
250
303
  ) -> None:
304
+ """Initialise the lexicon with an embedding ``source`` and optional cache."""
251
305
  super().__init__(seed=seed)
252
306
  self._adapter = _resolve_source(source)
253
307
  self._max_neighbors = max(1, max_neighbors)
@@ -350,6 +404,7 @@ class VectorLexicon(LexiconBackend):
350
404
  return synonyms
351
405
 
352
406
  def get_synonyms(self, word: str, pos: str | None = None, n: int = 5) -> list[str]:
407
+ """Return up to ``n`` deterministic synonyms drawn from the embedding cache."""
353
408
  normalized = self._normalize_for_lookup(word)
354
409
  synonyms = self._ensure_cached(original=word, normalized=normalized)
355
410
  return self._deterministic_sample(synonyms, limit=n, word=word, pos=pos)
@@ -390,6 +445,7 @@ class VectorLexicon(LexiconBackend):
390
445
  return target
391
446
 
392
447
  def supports_pos(self, pos: str | None) -> bool:
448
+ """Always return ``True`` because vector sources do not encode POS metadata."""
393
449
  return True
394
450
 
395
451
  def __repr__(self) -> str: # pragma: no cover - debug helper
@@ -452,7 +508,8 @@ def _parse_cli(argv: Sequence[str] | None = None) -> argparse.Namespace:
452
508
  "--source",
453
509
  required=True,
454
510
  help=(
455
- "Vector source specification. Use 'spacy:<model>' for spaCy pipelines "
511
+ "Vector source specification. Use 'spacy:<model>' for spaCy pipelines, "
512
+ "'sentence-transformers:<model>' for HuggingFace checkpoints (requires --tokens), "
456
513
  "or provide a path to a gensim KeyedVectors/word2vec file."
457
514
  ),
458
515
  )
@@ -534,22 +591,44 @@ def main(argv: Sequence[str] | None = None) -> int:
534
591
 
535
592
  normalizer = _identity
536
593
 
537
- source = load_vector_source(args.source)
594
+ tokens_from_file: list[str] | None = None
538
595
  if args.tokens is not None:
539
- token_iter: Iterable[str] = _iter_tokens_from_file(args.tokens)
596
+ tokens_from_file = list(_iter_tokens_from_file(args.tokens))
597
+ if args.limit is not None:
598
+ tokens_from_file = tokens_from_file[: args.limit]
599
+
600
+ source_spec = args.source
601
+ token_iter: Iterable[str]
602
+ if source_spec.startswith("sentence-transformers:"):
603
+ model_name = source_spec.split(":", 1)[1].strip()
604
+ if not model_name:
605
+ model_name = "sentence-transformers/all-mpnet-base-v2"
606
+ if tokens_from_file is None:
607
+ raise SystemExit(
608
+ "Sentence-transformers sources require --tokens to supply a vocabulary."
609
+ )
610
+ source = _build_sentence_transformer_embeddings(model_name, tokens_from_file)
611
+ token_iter = tokens_from_file
540
612
  else:
541
- lexicon = VectorLexicon(
542
- source=source,
543
- max_neighbors=args.max_neighbors,
544
- min_similarity=args.min_similarity,
545
- case_sensitive=args.case_sensitive,
546
- normalizer=normalizer,
547
- seed=args.seed,
548
- )
549
- token_iter = lexicon.iter_vocabulary()
550
-
551
- if args.limit is not None:
552
- token_iter = (token for index, token in enumerate(token_iter) if index < args.limit)
613
+ source = load_vector_source(source_spec)
614
+ if tokens_from_file is not None:
615
+ token_iter = tokens_from_file
616
+ else:
617
+ lexicon = VectorLexicon(
618
+ source=source,
619
+ max_neighbors=args.max_neighbors,
620
+ min_similarity=args.min_similarity,
621
+ case_sensitive=args.case_sensitive,
622
+ normalizer=normalizer,
623
+ seed=args.seed,
624
+ )
625
+ iterator = lexicon.iter_vocabulary()
626
+ if args.limit is not None:
627
+ token_iter = (
628
+ token for index, token in enumerate(iterator) if index < args.limit
629
+ )
630
+ else:
631
+ token_iter = iterator
553
632
 
554
633
  build_vector_cache(
555
634
  source=source,