glitchlings 0.4.2__cp312-cp312-manylinux_2_28_x86_64.whl → 0.4.3__cp312-cp312-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of glitchlings might be problematic. Click here for more details.
- glitchlings/__init__.py +4 -0
- glitchlings/_zoo_rust.cpython-312-x86_64-linux-gnu.so +0 -0
- glitchlings/compat.py +80 -11
- glitchlings/config.py +32 -19
- glitchlings/config.toml +1 -1
- glitchlings/dlc/__init__.py +3 -1
- glitchlings/dlc/pytorch.py +216 -0
- glitchlings/dlc/pytorch_lightning.py +233 -0
- glitchlings/lexicon/__init__.py +5 -15
- glitchlings/lexicon/_cache.py +21 -15
- glitchlings/lexicon/data/default_vector_cache.json +80 -14
- glitchlings/lexicon/vector.py +94 -15
- glitchlings/lexicon/wordnet.py +66 -25
- glitchlings/main.py +21 -11
- glitchlings/zoo/__init__.py +5 -1
- glitchlings/zoo/adjax.py +2 -2
- glitchlings/zoo/apostrofae.py +128 -0
- glitchlings/zoo/assets/__init__.py +0 -0
- glitchlings/zoo/assets/apostrofae_pairs.json +32 -0
- glitchlings/zoo/core.py +40 -14
- glitchlings/zoo/jargoyle.py +44 -34
- glitchlings/zoo/redactyl.py +11 -8
- glitchlings/zoo/reduple.py +2 -2
- glitchlings/zoo/rushmore.py +2 -2
- glitchlings/zoo/scannequin.py +2 -2
- glitchlings/zoo/typogre.py +5 -2
- glitchlings/zoo/zeedub.py +5 -2
- {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/METADATA +35 -2
- glitchlings-0.4.3.dist-info/RECORD +46 -0
- glitchlings/lexicon/graph.py +0 -282
- glitchlings-0.4.2.dist-info/RECORD +0 -42
- {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/WHEEL +0 -0
- {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/entry_points.txt +0 -0
- {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/licenses/LICENSE +0 -0
- {glitchlings-0.4.2.dist-info → glitchlings-0.4.3.dist-info}/top_level.txt +0 -0
glitchlings/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ from .config import AttackConfig, build_gaggle, load_attack_config
|
|
|
2
2
|
from .util import SAMPLE_TEXT
|
|
3
3
|
from .zoo import (
|
|
4
4
|
Adjax,
|
|
5
|
+
Apostrofae,
|
|
5
6
|
Gaggle,
|
|
6
7
|
Glitchling,
|
|
7
8
|
Jargoyle,
|
|
@@ -13,6 +14,7 @@ from .zoo import (
|
|
|
13
14
|
Typogre,
|
|
14
15
|
Zeedub,
|
|
15
16
|
adjax,
|
|
17
|
+
apostrofae,
|
|
16
18
|
is_rust_pipeline_enabled,
|
|
17
19
|
is_rust_pipeline_supported,
|
|
18
20
|
jargoyle,
|
|
@@ -38,6 +40,8 @@ __all__ = [
|
|
|
38
40
|
"jargoyle",
|
|
39
41
|
"Adjax",
|
|
40
42
|
"adjax",
|
|
43
|
+
"Apostrofae",
|
|
44
|
+
"apostrofae",
|
|
41
45
|
"Redactyl",
|
|
42
46
|
"redactyl",
|
|
43
47
|
"Reduple",
|
|
Binary file
|
glitchlings/compat.py
CHANGED
|
@@ -6,16 +6,50 @@ import re
|
|
|
6
6
|
from dataclasses import dataclass
|
|
7
7
|
from importlib import import_module, metadata
|
|
8
8
|
from types import ModuleType
|
|
9
|
-
from typing import Any, Iterable
|
|
9
|
+
from typing import Any, Callable, Iterable, Protocol, cast
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class _MissingSentinel:
|
|
13
|
+
__slots__ = ()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
_MISSING = _MissingSentinel()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class _MarkerProtocol(Protocol):
|
|
20
|
+
def evaluate(self, environment: dict[str, str]) -> bool:
|
|
21
|
+
...
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class _RequirementProtocol(Protocol):
|
|
25
|
+
marker: _MarkerProtocol | None
|
|
26
|
+
name: str
|
|
27
|
+
|
|
28
|
+
def __init__(self, requirement: str) -> None:
|
|
29
|
+
...
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
try: # pragma: no cover - packaging is bundled with modern Python environments
|
|
33
|
+
from packaging.markers import default_environment as _default_environment
|
|
34
|
+
except ModuleNotFoundError: # pragma: no cover - fallback when packaging missing
|
|
35
|
+
_default_environment = None
|
|
10
36
|
|
|
11
37
|
try: # pragma: no cover - packaging is bundled with modern Python environments
|
|
12
|
-
from packaging.
|
|
13
|
-
from packaging.requirements import Requirement
|
|
38
|
+
from packaging.requirements import Requirement as _RequirementClass
|
|
14
39
|
except ModuleNotFoundError: # pragma: no cover - fallback when packaging missing
|
|
15
|
-
|
|
16
|
-
default_environment = None # type: ignore[assignment]
|
|
40
|
+
_RequirementClass = None
|
|
17
41
|
|
|
18
|
-
|
|
42
|
+
# Public alias: ``None`` when the private import placeholder is ``None``,
# otherwise the imported callable viewed as ``() -> dict[str, str]``.
default_environment: Callable[[], dict[str, str]] | None = (
    None
    if _default_environment is None
    else cast(Callable[[], dict[str, str]], _default_environment)
)
|
|
47
|
+
|
|
48
|
+
# Public alias: ``None`` when the private import placeholder is ``None``,
# otherwise the imported class viewed through ``_RequirementProtocol``.
Requirement: type[_RequirementProtocol] | None = (
    None if _RequirementClass is None else cast(type[_RequirementProtocol], _RequirementClass)
)
|
|
19
53
|
|
|
20
54
|
|
|
21
55
|
@dataclass
|
|
@@ -23,7 +57,7 @@ class OptionalDependency:
|
|
|
23
57
|
"""Lazily import an optional dependency and retain the import error."""
|
|
24
58
|
|
|
25
59
|
module_name: str
|
|
26
|
-
_cached: ModuleType |
|
|
60
|
+
_cached: ModuleType | None | _MissingSentinel = _MISSING
|
|
27
61
|
_error: ModuleNotFoundError | None = None
|
|
28
62
|
|
|
29
63
|
def _attempt_import(self) -> ModuleType | None:
|
|
@@ -40,11 +74,12 @@ class OptionalDependency:
|
|
|
40
74
|
|
|
41
75
|
def get(self) -> ModuleType | None:
|
|
42
76
|
"""Return the imported module or ``None`` when unavailable."""
|
|
43
|
-
|
|
77
|
+
cached = self._cached
|
|
78
|
+
if isinstance(cached, _MissingSentinel):
|
|
44
79
|
return self._attempt_import()
|
|
45
|
-
if
|
|
80
|
+
if cached is None:
|
|
46
81
|
return None
|
|
47
|
-
return
|
|
82
|
+
return cached
|
|
48
83
|
|
|
49
84
|
def load(self) -> ModuleType:
|
|
50
85
|
"""Return the dependency, raising the original import error when absent."""
|
|
@@ -87,16 +122,18 @@ class OptionalDependency:
|
|
|
87
122
|
return self._error
|
|
88
123
|
|
|
89
124
|
|
|
125
|
+
pytorch_lightning = OptionalDependency("pytorch_lightning")
|
|
90
126
|
datasets = OptionalDependency("datasets")
|
|
91
127
|
verifiers = OptionalDependency("verifiers")
|
|
92
128
|
jellyfish = OptionalDependency("jellyfish")
|
|
93
129
|
jsonschema = OptionalDependency("jsonschema")
|
|
94
130
|
nltk = OptionalDependency("nltk")
|
|
131
|
+
torch = OptionalDependency("torch")
|
|
95
132
|
|
|
96
133
|
|
|
97
134
|
def reset_optional_dependencies() -> None:
|
|
98
135
|
"""Clear cached optional dependency imports (used by tests)."""
|
|
99
|
-
for dependency in (datasets, verifiers, jellyfish, jsonschema, nltk):
|
|
136
|
+
for dependency in (pytorch_lightning, datasets, verifiers, jellyfish, jsonschema, nltk, torch):
|
|
100
137
|
dependency.reset()
|
|
101
138
|
|
|
102
139
|
|
|
@@ -110,6 +147,16 @@ def require_datasets(message: str = "datasets is not installed") -> ModuleType:
|
|
|
110
147
|
return datasets.require(message)
|
|
111
148
|
|
|
112
149
|
|
|
150
|
+
def get_pytorch_lightning_datamodule() -> Any | None:
    """Look up ``LightningDataModule`` on the optional ``pytorch_lightning`` dependency.

    Returns whatever ``pytorch_lightning.attr("LightningDataModule")`` yields.
    """
    datamodule_cls = pytorch_lightning.attr("LightningDataModule")
    return datamodule_cls
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def require_pytorch_lightning(message: str = "pytorch_lightning is not installed") -> ModuleType:
    """Return the PyTorch Lightning module via ``pytorch_lightning.require``.

    ``message`` is forwarded unchanged to the dependency's ``require`` call.
    """
    module = pytorch_lightning.require(message)
    return module
|
|
158
|
+
|
|
159
|
+
|
|
113
160
|
def require_verifiers(message: str = "verifiers is not installed") -> ModuleType:
|
|
114
161
|
"""Ensure the verifiers dependency is present."""
|
|
115
162
|
return verifiers.require(message)
|
|
@@ -120,6 +167,28 @@ def require_jellyfish(message: str = "jellyfish is not installed") -> ModuleType
|
|
|
120
167
|
return jellyfish.require(message)
|
|
121
168
|
|
|
122
169
|
|
|
170
|
+
def require_torch(message: str = "torch is not installed") -> ModuleType:
    """Return the PyTorch module via ``torch.require``.

    ``message`` is forwarded unchanged to the dependency's ``require`` call.
    """
    module = torch.require(message)
    return module
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def get_torch_dataloader() -> Any | None:
    """Resolve ``torch.utils.data.DataLoader`` from the optional torch dependency.

    Returns ``None`` when torch is unavailable or when any attribute along the
    ``utils`` -> ``data`` -> ``DataLoader`` chain is missing.
    """
    current: Any = torch.get()
    # Walk the attribute chain one hop at a time, bailing out on the first gap.
    for attribute in ("utils", "data", "DataLoader"):
        if current is None:
            return None
        current = getattr(current, attribute, None)
    return current
|
|
190
|
+
|
|
191
|
+
|
|
123
192
|
def get_installed_extras(
|
|
124
193
|
extras: Iterable[str] | None = None,
|
|
125
194
|
*,
|
glitchlings/config.py
CHANGED
|
@@ -2,29 +2,46 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import importlib
|
|
5
6
|
import os
|
|
6
7
|
import warnings
|
|
7
8
|
from dataclasses import dataclass, field
|
|
8
9
|
from io import TextIOBase
|
|
9
10
|
from pathlib import Path
|
|
10
|
-
from typing import TYPE_CHECKING, Any, Mapping, Sequence
|
|
11
|
+
from typing import IO, TYPE_CHECKING, Any, Mapping, Protocol, Sequence, cast
|
|
12
|
+
|
|
13
|
+
from glitchlings.compat import jsonschema
|
|
11
14
|
|
|
12
15
|
try: # Python 3.11+
|
|
13
|
-
import tomllib
|
|
16
|
+
import tomllib as _tomllib
|
|
14
17
|
except ModuleNotFoundError: # pragma: no cover - Python < 3.11
|
|
15
|
-
|
|
18
|
+
_tomllib = importlib.import_module("tomli")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class _TomllibModule(Protocol):
|
|
22
|
+
def load(self, fp: IO[bytes]) -> Any:
|
|
23
|
+
...
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
tomllib = cast(_TomllibModule, _tomllib)
|
|
27
|
+
|
|
16
28
|
|
|
17
|
-
|
|
29
|
+
class _YamlModule(Protocol):
|
|
30
|
+
YAMLError: type[Exception]
|
|
18
31
|
|
|
19
|
-
|
|
32
|
+
def safe_load(self, stream: str) -> Any:
|
|
33
|
+
...
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
yaml = cast(_YamlModule, importlib.import_module("yaml"))
|
|
20
37
|
|
|
21
38
|
if TYPE_CHECKING: # pragma: no cover - typing only
|
|
22
|
-
from .zoo import Glitchling
|
|
39
|
+
from .zoo import Gaggle, Glitchling
|
|
23
40
|
|
|
24
41
|
|
|
25
42
|
CONFIG_ENV_VAR = "GLITCHLINGS_CONFIG"
|
|
26
43
|
DEFAULT_CONFIG_PATH = Path(__file__).with_name("config.toml")
|
|
27
|
-
DEFAULT_LEXICON_PRIORITY = ["vector", "
|
|
44
|
+
DEFAULT_LEXICON_PRIORITY = ["vector", "wordnet"]
|
|
28
45
|
DEFAULT_ATTACK_SEED = 151
|
|
29
46
|
|
|
30
47
|
ATTACK_CONFIG_SCHEMA: dict[str, Any] = {
|
|
@@ -72,7 +89,6 @@ class LexiconConfig:
|
|
|
72
89
|
|
|
73
90
|
priority: list[str] = field(default_factory=lambda: list(DEFAULT_LEXICON_PRIORITY))
|
|
74
91
|
vector_cache: Path | None = None
|
|
75
|
-
graph_cache: Path | None = None
|
|
76
92
|
|
|
77
93
|
|
|
78
94
|
@dataclass(slots=True)
|
|
@@ -127,15 +143,9 @@ def _load_runtime_config() -> RuntimeConfig:
|
|
|
127
143
|
lexicon_section.get("vector_cache"),
|
|
128
144
|
base=path.parent,
|
|
129
145
|
)
|
|
130
|
-
graph_cache = _resolve_optional_path(
|
|
131
|
-
lexicon_section.get("graph_cache"),
|
|
132
|
-
base=path.parent,
|
|
133
|
-
)
|
|
134
|
-
|
|
135
146
|
lexicon_config = LexiconConfig(
|
|
136
147
|
priority=normalized_priority,
|
|
137
148
|
vector_cache=vector_cache,
|
|
138
|
-
graph_cache=graph_cache,
|
|
139
149
|
)
|
|
140
150
|
|
|
141
151
|
return RuntimeConfig(lexicon=lexicon_config, path=path)
|
|
@@ -154,7 +164,10 @@ def _read_toml(path: Path) -> dict[str, Any]:
|
|
|
154
164
|
return {}
|
|
155
165
|
raise FileNotFoundError(f"Configuration file '{path}' not found.")
|
|
156
166
|
with path.open("rb") as handle:
|
|
157
|
-
|
|
167
|
+
loaded = tomllib.load(handle)
|
|
168
|
+
if isinstance(loaded, Mapping):
|
|
169
|
+
return dict(loaded)
|
|
170
|
+
raise ValueError(f"Configuration file '{path}' must contain a top-level mapping.")
|
|
158
171
|
|
|
159
172
|
|
|
160
173
|
def _validate_runtime_config_data(data: Any, *, source: Path) -> Mapping[str, Any]:
|
|
@@ -173,13 +186,13 @@ def _validate_runtime_config_data(data: Any, *, source: Path) -> Mapping[str, An
|
|
|
173
186
|
if not isinstance(lexicon_section, Mapping):
|
|
174
187
|
raise ValueError("Configuration 'lexicon' section must be a table.")
|
|
175
188
|
|
|
176
|
-
allowed_lexicon_keys = {"priority", "vector_cache"
|
|
189
|
+
allowed_lexicon_keys = {"priority", "vector_cache"}
|
|
177
190
|
unexpected_keys = [str(key) for key in lexicon_section if key not in allowed_lexicon_keys]
|
|
178
191
|
if unexpected_keys:
|
|
179
192
|
extras = ", ".join(sorted(unexpected_keys))
|
|
180
193
|
raise ValueError(f"Unknown lexicon settings: {extras}.")
|
|
181
194
|
|
|
182
|
-
for key in ("vector_cache",
|
|
195
|
+
for key in ("vector_cache",):
|
|
183
196
|
value = lexicon_section.get(key)
|
|
184
197
|
if value is not None and not isinstance(value, (str, os.PathLike)):
|
|
185
198
|
raise ValueError(f"lexicon.{key} must be a path or string when provided.")
|
|
@@ -287,7 +300,7 @@ def parse_attack_config(data: Any, *, source: str = "<config>") -> AttackConfig:
|
|
|
287
300
|
return AttackConfig(glitchlings=glitchlings, seed=seed)
|
|
288
301
|
|
|
289
302
|
|
|
290
|
-
def build_gaggle(config: AttackConfig, *, seed_override: int | None = None):
|
|
303
|
+
def build_gaggle(config: AttackConfig, *, seed_override: int | None = None) -> "Gaggle":
|
|
291
304
|
"""Instantiate a ``Gaggle`` according to ``config``."""
|
|
292
305
|
from .zoo import Gaggle # Imported lazily to avoid circular dependencies
|
|
293
306
|
|
|
@@ -305,7 +318,7 @@ def _load_yaml(text: str, label: str) -> Any:
|
|
|
305
318
|
raise ValueError(f"Failed to parse attack configuration '{label}': {exc}") from exc
|
|
306
319
|
|
|
307
320
|
|
|
308
|
-
def _build_glitchling(entry: Any, source: str, index: int):
|
|
321
|
+
def _build_glitchling(entry: Any, source: str, index: int) -> "Glitchling":
|
|
309
322
|
from .zoo import get_glitchling_class, parse_glitchling_spec
|
|
310
323
|
|
|
311
324
|
if isinstance(entry, str):
|
glitchlings/config.toml
CHANGED
glitchlings/dlc/__init__.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
"""Optional DLC integrations for Glitchlings."""
|
|
2
2
|
|
|
3
3
|
from .huggingface import install as install_huggingface
|
|
4
|
+
from .pytorch import install as install_pytorch
|
|
5
|
+
from .pytorch_lightning import install as install_pytorch_lightning
|
|
4
6
|
|
|
5
|
-
__all__ = ["install_huggingface"]
|
|
7
|
+
__all__ = ["install_huggingface", "install_pytorch", "install_pytorch_lightning"]
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""Integration helpers for PyTorch data loaders."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence
|
|
6
|
+
from typing import Any, cast
|
|
7
|
+
|
|
8
|
+
from ..compat import get_torch_dataloader, require_torch
|
|
9
|
+
from ..compat import torch as _torch_dependency
|
|
10
|
+
from ..util.adapters import coerce_gaggle
|
|
11
|
+
from ..zoo import Gaggle, Glitchling
|
|
12
|
+
from ..zoo.core import _is_transcript
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _normalise_columns(columns: str | int | Sequence[str | int] | None) -> list[str | int] | None:
|
|
16
|
+
"""Normalise a column specification into a list of keys or indices."""
|
|
17
|
+
if columns is None:
|
|
18
|
+
return None
|
|
19
|
+
|
|
20
|
+
if isinstance(columns, (str, int)):
|
|
21
|
+
return [columns]
|
|
22
|
+
|
|
23
|
+
normalised = list(columns)
|
|
24
|
+
if not normalised:
|
|
25
|
+
raise ValueError("At least one column must be specified")
|
|
26
|
+
return normalised
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _is_textual_candidate(value: Any) -> bool:
    """Report whether ``value`` looks like text that glitchlings can corrupt.

    Accepted shapes: a ``str``; anything ``_is_transcript`` accepts (non-empty,
    all entries with content); or a non-empty, non-bytes sequence that is either
    all strings or a transcript once converted to a list.
    """
    if isinstance(value, str):
        return True
    if _is_transcript(value, allow_empty=False, require_all_content=True):
        return True

    # Everything below only applies to non-empty, non-bytes sequences.
    if isinstance(value, (bytes, bytearray)) or not isinstance(value, Sequence):
        return False
    if not value:
        return False

    if all(isinstance(item, str) for item in value):
        return True
    return bool(_is_transcript(list(value), allow_empty=False, require_all_content=True))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _corrupt_text(value: Any, gaggle: Gaggle) -> Any:
|
|
49
|
+
"""Return ``value`` with glitchlings applied when possible."""
|
|
50
|
+
if isinstance(value, str):
|
|
51
|
+
return gaggle.corrupt(value)
|
|
52
|
+
|
|
53
|
+
if _is_transcript(value, allow_empty=True):
|
|
54
|
+
return gaggle.corrupt(value)
|
|
55
|
+
|
|
56
|
+
if isinstance(value, list) and value and all(isinstance(item, str) for item in value):
|
|
57
|
+
return [gaggle.corrupt(item) for item in value]
|
|
58
|
+
|
|
59
|
+
if isinstance(value, tuple) and value and all(isinstance(item, str) for item in value):
|
|
60
|
+
return tuple(gaggle.corrupt(item) for item in value)
|
|
61
|
+
|
|
62
|
+
return value
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _apply_to_batch(batch: Any, targets: list[str | int] | None, gaggle: Gaggle) -> Any:
|
|
66
|
+
"""Return ``batch`` with glitchlings applied to the specified ``targets``."""
|
|
67
|
+
if targets is None:
|
|
68
|
+
return _corrupt_text(batch, gaggle)
|
|
69
|
+
|
|
70
|
+
if isinstance(batch, Mapping):
|
|
71
|
+
mutated = cast(MutableMapping[str, Any], dict(batch))
|
|
72
|
+
for key in targets:
|
|
73
|
+
if not isinstance(key, str):
|
|
74
|
+
raise TypeError("Mapping batches require string column names")
|
|
75
|
+
if key not in mutated:
|
|
76
|
+
raise ValueError(f"Column '{key}' not found in DataLoader batch")
|
|
77
|
+
mutated[key] = _corrupt_text(mutated[key], gaggle)
|
|
78
|
+
return mutated
|
|
79
|
+
|
|
80
|
+
if isinstance(batch, Sequence) and not isinstance(batch, (bytes, bytearray, str)):
|
|
81
|
+
mutated_sequence = list(batch)
|
|
82
|
+
for index in targets:
|
|
83
|
+
if not isinstance(index, int):
|
|
84
|
+
raise TypeError("Sequence batches require integer column indices")
|
|
85
|
+
try:
|
|
86
|
+
mutated_sequence[index] = _corrupt_text(mutated_sequence[index], gaggle)
|
|
87
|
+
except IndexError as exc: # pragma: no cover - defensive
|
|
88
|
+
raise IndexError("Column index out of range for DataLoader batch") from exc
|
|
89
|
+
if isinstance(batch, tuple):
|
|
90
|
+
return tuple(mutated_sequence)
|
|
91
|
+
return mutated_sequence
|
|
92
|
+
|
|
93
|
+
raise TypeError("Unsupported DataLoader batch type for glitching")
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _infer_targets(batch: Any) -> list[str | int] | None:
|
|
97
|
+
"""Infer which fields should be glitched from a representative ``batch``."""
|
|
98
|
+
if isinstance(batch, Mapping):
|
|
99
|
+
inferred = [key for key, value in batch.items() if _is_textual_candidate(value)]
|
|
100
|
+
if inferred:
|
|
101
|
+
return inferred
|
|
102
|
+
raise ValueError("Unable to infer which mapping columns contain text")
|
|
103
|
+
|
|
104
|
+
if isinstance(batch, Sequence) and not isinstance(batch, (bytes, bytearray, str)):
|
|
105
|
+
inferred_indices: list[str | int] = [
|
|
106
|
+
idx for idx, value in enumerate(batch) if _is_textual_candidate(value)
|
|
107
|
+
]
|
|
108
|
+
if inferred_indices:
|
|
109
|
+
return inferred_indices
|
|
110
|
+
raise ValueError("Unable to infer which sequence indices contain text")
|
|
111
|
+
|
|
112
|
+
if _is_textual_candidate(batch):
|
|
113
|
+
return None
|
|
114
|
+
|
|
115
|
+
raise TypeError("Unsupported DataLoader batch type for glitching")
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class _GlitchedDataLoader(Iterable[Any]):
|
|
119
|
+
"""Wrapper that applies glitchlings lazily to each batch from a data loader."""
|
|
120
|
+
|
|
121
|
+
def __init__(
|
|
122
|
+
self,
|
|
123
|
+
dataloader: Any,
|
|
124
|
+
gaggle: Gaggle,
|
|
125
|
+
*,
|
|
126
|
+
columns: list[str | int] | None,
|
|
127
|
+
) -> None:
|
|
128
|
+
self._dataloader = dataloader
|
|
129
|
+
self._gaggle = gaggle
|
|
130
|
+
self._explicit_columns = columns
|
|
131
|
+
self._inferred_columns: list[str | int] | None | _Sentinel = _UNINITIALISED
|
|
132
|
+
|
|
133
|
+
def __iter__(self) -> Iterator[Any]:
|
|
134
|
+
# Reset all glitchling RNGs before each fresh pass for determinism.
|
|
135
|
+
self._gaggle.sort_glitchlings()
|
|
136
|
+
for batch in self._dataloader:
|
|
137
|
+
targets = self._resolve_columns(batch)
|
|
138
|
+
yield _apply_to_batch(batch, targets, self._gaggle)
|
|
139
|
+
|
|
140
|
+
def __len__(self) -> int:
|
|
141
|
+
return len(self._dataloader)
|
|
142
|
+
|
|
143
|
+
def __getattr__(self, attribute: str) -> Any:
|
|
144
|
+
return getattr(self._dataloader, attribute)
|
|
145
|
+
|
|
146
|
+
def _resolve_columns(self, batch: Any) -> list[str | int] | None:
|
|
147
|
+
if self._explicit_columns is not None:
|
|
148
|
+
return self._explicit_columns
|
|
149
|
+
|
|
150
|
+
if self._inferred_columns is _UNINITIALISED:
|
|
151
|
+
self._inferred_columns = _infer_targets(batch)
|
|
152
|
+
|
|
153
|
+
return cast(list[str | int] | None, self._inferred_columns)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class _Sentinel:
|
|
157
|
+
"""Sentinel type for deferred column inference."""
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
_UNINITIALISED = _Sentinel()
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _ensure_dataloader_class() -> type[Any]:
    """Return :class:`torch.utils.data.DataLoader` with a ``.glitch`` method attached.

    The method is only installed when the class has no non-``None`` ``glitch``
    attribute yet, so repeated calls leave an existing one in place.

    Raises:
        ModuleNotFoundError: ``DataLoader`` is still unavailable after
            ``require_torch`` returned.
    """
    loader_cls = get_torch_dataloader()
    if loader_cls is None:
        # NOTE(review): presumably raises when torch is missing — confirm in compat.
        require_torch("torch is not installed; install glitchlings[torch]")
        loader_cls = get_torch_dataloader()
        if loader_cls is None:  # pragma: no cover - defensive
            message = "torch.utils.data.DataLoader is not available"
            cause = _torch_dependency.error
            if cause is None:
                raise ModuleNotFoundError(message)
            raise ModuleNotFoundError(message) from cause

    if getattr(loader_cls, "glitch", None) is not None:
        return cast(type[Any], loader_cls)

    def glitch(
        self: Any,
        glitchlings: Iterable[str | Glitchling] | Glitchling | str | Gaggle,
        *,
        columns: str | int | Sequence[str | int] | None = None,
        seed: int = 151,
    ) -> _GlitchedDataLoader:
        """Return a lazily glitched view of the loader's batches."""
        return _GlitchedDataLoader(
            self,
            coerce_gaggle(glitchlings, seed=seed),
            columns=_normalise_columns(columns),
        )

    setattr(loader_cls, "glitch", glitch)
    return cast(type[Any], loader_cls)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _optional_dataloader_class() -> type[Any] | None:
    """Return the result of ``get_torch_dataloader()`` typed as a class, or ``None``."""
    candidate = get_torch_dataloader()
    return None if candidate is None else cast(type[Any], candidate)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def install() -> None:
    """Attach ``.glitch`` to PyTorch's :class:`~torch.utils.data.DataLoader`.

    Delegates to ``_ensure_dataloader_class`` and discards its return value.
    """
    _ensure_dataloader_class()
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
# Import-time export: the patched ``DataLoader`` class when it can be resolved,
# otherwise ``None`` (torch is an optional dependency).
_DataLoaderAlias = _optional_dataloader_class()
DataLoader: type[Any] | None = (
    None if _DataLoaderAlias is None else _ensure_dataloader_class()
)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
__all__ = ["DataLoader", "install"]
|