glitchlings 0.4.4__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of glitchlings might be problematic.
- glitchlings/__init__.py +67 -0
- glitchlings/__main__.py +8 -0
- glitchlings/_zoo_rust.cp313-win_amd64.pyd +0 -0
- glitchlings/compat.py +284 -0
- glitchlings/config.py +388 -0
- glitchlings/config.toml +3 -0
- glitchlings/dlc/__init__.py +7 -0
- glitchlings/dlc/_shared.py +153 -0
- glitchlings/dlc/huggingface.py +81 -0
- glitchlings/dlc/prime.py +254 -0
- glitchlings/dlc/pytorch.py +166 -0
- glitchlings/dlc/pytorch_lightning.py +215 -0
- glitchlings/lexicon/__init__.py +192 -0
- glitchlings/lexicon/_cache.py +110 -0
- glitchlings/lexicon/data/default_vector_cache.json +82 -0
- glitchlings/lexicon/metrics.py +162 -0
- glitchlings/lexicon/vector.py +651 -0
- glitchlings/lexicon/wordnet.py +232 -0
- glitchlings/main.py +364 -0
- glitchlings/util/__init__.py +195 -0
- glitchlings/util/adapters.py +27 -0
- glitchlings/zoo/__init__.py +168 -0
- glitchlings/zoo/_ocr_confusions.py +32 -0
- glitchlings/zoo/_rate.py +131 -0
- glitchlings/zoo/_rust_extensions.py +143 -0
- glitchlings/zoo/_sampling.py +54 -0
- glitchlings/zoo/_text_utils.py +100 -0
- glitchlings/zoo/adjax.py +128 -0
- glitchlings/zoo/apostrofae.py +127 -0
- glitchlings/zoo/assets/__init__.py +0 -0
- glitchlings/zoo/assets/apostrofae_pairs.json +32 -0
- glitchlings/zoo/core.py +582 -0
- glitchlings/zoo/jargoyle.py +335 -0
- glitchlings/zoo/mim1c.py +109 -0
- glitchlings/zoo/ocr_confusions.tsv +30 -0
- glitchlings/zoo/redactyl.py +193 -0
- glitchlings/zoo/reduple.py +148 -0
- glitchlings/zoo/rushmore.py +153 -0
- glitchlings/zoo/scannequin.py +171 -0
- glitchlings/zoo/typogre.py +231 -0
- glitchlings/zoo/zeedub.py +185 -0
- glitchlings-0.4.4.dist-info/METADATA +627 -0
- glitchlings-0.4.4.dist-info/RECORD +47 -0
- glitchlings-0.4.4.dist-info/WHEEL +5 -0
- glitchlings-0.4.4.dist-info/entry_points.txt +2 -0
- glitchlings-0.4.4.dist-info/licenses/LICENSE +201 -0
- glitchlings-0.4.4.dist-info/top_level.txt +1 -0
glitchlings/dlc/prime.py
ADDED
@@ -0,0 +1,254 @@
"""Integration helpers for the optional verifiers prime DLC."""

from __future__ import annotations

from collections.abc import Iterable, Sequence
from enum import Enum
from typing import Any, Callable, Protocol, cast

from ..compat import require_datasets, require_jellyfish, require_verifiers
from ..util.adapters import coerce_gaggle
from ..zoo import Gaggle, Glitchling, Mim1c, Typogre
from ._shared import resolve_columns as _resolve_columns_shared
from ._shared import resolve_environment as _resolve_environment_shared


class VerifierEnvironment(Protocol):
    """Minimal interface for verifiers environments."""

    dataset: Any


class VerifierSingleTurnEnv(Protocol):
    """Minimal interface for single-turn verifier environments."""

    dataset: Any
    rubric: Any


vf = require_verifiers("verifiers is not installed; install glitchlings[prime]")
_jellyfish = require_jellyfish("jellyfish is not installed; install glitchlings[prime]")
damerau_levenshtein_distance = _jellyfish.damerau_levenshtein_distance

try:
    from .huggingface import Dataset as _HuggingFaceDataset
except ModuleNotFoundError:  # pragma: no cover - optional dependency
    _HuggingFaceDataset = None
else:
    if _HuggingFaceDataset is None:  # pragma: no cover - optional dependency
        _HuggingFaceDataset = None

Dataset: type[Any]
if _HuggingFaceDataset is None:
    Dataset = object
else:
    Dataset = _HuggingFaceDataset


def _resolve_environment(env: str | VerifierEnvironment) -> VerifierEnvironment:
    """Return a fully-instantiated verifier environment."""
    resolved = _resolve_environment_shared(
        env,
        loader=vf.load_environment,
        environment_type=cast(type[Any], vf.Environment),
    )
    return cast(VerifierEnvironment, resolved)


def _resolve_columns(dataset: Any, columns: Sequence[str] | None) -> list[str]:
    """Identify which dataset columns should be corrupted."""
    return _resolve_columns_shared(dataset, columns)


class Difficulty(Enum):
    """Difficulty levels for tutorial environments."""

    Easy = 0.25
    Normal = 1.0
    Hard = 1.75
    Extreme = 3
    Impossible = 9


def tutorial_level(
    env: VerifierEnvironment | str,
    seed: int = 151,
    difficulty: Difficulty = Difficulty.Normal,
) -> VerifierEnvironment:
    """Create a low-corruption environment using tuned defaults."""
    tuned_mim1c = Mim1c(rate=0.01 * difficulty.value)
    tuned_typogre = Typogre(rate=0.025 * difficulty.value)

    return load_environment(
        env,
        glitchlings=[tuned_mim1c, tuned_typogre],
        seed=seed,
    )


def load_environment(
    env: str | VerifierEnvironment,
    glitchlings: Iterable[str | Glitchling] | Glitchling | str | Gaggle | None = None,
    *,
    seed: int = 151,
    columns: Sequence[str] | None = None,
) -> VerifierEnvironment:
    """Load an environment and optionally corrupt it with glitchlings."""
    environment = _resolve_environment(env)

    if glitchlings is None:
        return environment

    gaggle = coerce_gaggle(glitchlings, seed=seed)

    dataset = environment.dataset
    corrupt_columns = _resolve_columns(dataset, columns)
    environment.dataset = gaggle.corrupt_dataset(dataset, corrupt_columns)
    return environment


def _as_gaggle(
    glitchlings: Iterable[str | Glitchling] | Glitchling | str | Gaggle,
    *,
    seed: int,
) -> Gaggle:
    """Coerce any supported glitchling specification into a :class:`Gaggle`."""
    return coerce_gaggle(glitchlings, seed=seed)


def _extract_completion_text(completion: Any) -> str:
    """Normalise a completion payload into a plain string."""
    if isinstance(completion, str):
        return completion

    if isinstance(completion, list) and completion:
        first = completion[0]
        if isinstance(first, dict) and "content" in first:
            return str(first["content"])
        return str(first)

    return str(completion)


def symmetric_damerau_levenshtein_similarity(
    _: Any,
    completion: Any,
    answer: str,
) -> float:
    """Return ``1 - (distance / max_len)`` using Damerau-Levenshtein distance."""
    completion_text = _extract_completion_text(completion)
    target = answer or ""
    denominator = max(len(completion_text), len(target), 1)
    distance = cast(int, damerau_levenshtein_distance(completion_text, target))
    score = 1.0 - (distance / denominator)
    return max(0.0, min(1.0, score))


DEFAULT_CLEANUP_INSTRUCTIONS = (
    "You are a meticulous copy editor. Restore the provided text to its original form."
)


def echo_chamber(
    dataset_id: str,
    column: str,
    glitchlings: Iterable[str | Glitchling] | Glitchling | str | Gaggle,
    *,
    seed: int = 151,
    instructions: str = DEFAULT_CLEANUP_INSTRUCTIONS,
    reward_function: Callable[..., float] | None = None,
    split: str | None = None,
    **load_dataset_kwargs: Any,
) -> VerifierSingleTurnEnv:
    """Create an Echo Chamber Prime environment from a Hugging Face dataset column.

    Args:
        dataset_id: Identifier of the Hugging Face dataset to load.
        column: Name of the column whose text should be glitched.
        glitchlings: Glitchling specifiers that will corrupt the prompts.
        seed: RNG seed forwarded to :func:`glitchlings.util.adapters.coerce_gaggle`.
        instructions: System instructions supplied to the environment prompts.
        reward_function: Optional callable used to score completions. Defaults to
            :func:`symmetric_damerau_levenshtein_similarity` when omitted.
        split: Optional dataset split to load.
        **load_dataset_kwargs: Extra keyword arguments forwarded to
            :func:`datasets.load_dataset`.

    """
    datasets_module = require_datasets("datasets is required to build an echo chamber")
    load_dataset = getattr(datasets_module, "load_dataset", None)
    if load_dataset is None:  # pragma: no cover - defensive
        message = "datasets is required to build an echo chamber"
        raise ModuleNotFoundError(message)

    dataset_dict_cls = getattr(datasets_module, "DatasetDict", dict)

    hf_dataset: Any
    if split is None:
        hf_dataset = load_dataset(dataset_id, **load_dataset_kwargs)
        if isinstance(hf_dataset, dataset_dict_cls):
            try:
                hf_dataset = next(iter(hf_dataset.values()))
            except StopIteration as exc:  # pragma: no cover - defensive
                raise ValueError("The specified dataset does not contain any splits") from exc
    else:
        hf_dataset = load_dataset(dataset_id, split=split, **load_dataset_kwargs)

    if isinstance(hf_dataset, dataset_dict_cls):
        raise ValueError("Specify which split to use when the dataset loads as a DatasetDict.")

    filtered_dataset = hf_dataset.filter(
        lambda row: row.get(column) is not None,
        load_from_cache_file=False,
    )

    source_column_names = list(filtered_dataset.column_names)

    def _build_prompt(row: dict[str, Any]) -> dict[str, Any]:
        text = str(row[column])
        prompt = [
            {"role": "system", "content": instructions},
            {"role": "user", "content": f"Corrupted text:\n{text}"},
        ]
        return {"prompt": prompt, "answer": text}

    base_dataset = filtered_dataset.map(
        _build_prompt,
        remove_columns=source_column_names,
        load_from_cache_file=False,
    )

    try:
        dataset_length = len(base_dataset)
    except TypeError:
        preview_rows: list[dict[str, Any]]
        take_fn = getattr(base_dataset, "take", None)
        if callable(take_fn):
            preview_rows = list(take_fn(1))
        else:
            iterator = iter(base_dataset)
            try:
                first_row = next(iterator)
            except StopIteration:
                preview_rows = []
            else:
                preview_rows = [first_row]
        if not preview_rows:
            raise ValueError(
                f"Column '{column}' did not yield any textual entries in dataset '{dataset_id}'."
            )
    else:
        if dataset_length == 0:
            raise ValueError(
                f"Column '{column}' did not yield any textual entries in dataset '{dataset_id}'."
            )

    gaggle = _as_gaggle(glitchlings, seed=seed)
    glitched_dataset = gaggle.corrupt_dataset(base_dataset, ["prompt"])

    rubric_func = reward_function or symmetric_damerau_levenshtein_similarity
    rubric = vf.Rubric(funcs=[rubric_func], weights=[1.0])
    return cast(
        VerifierSingleTurnEnv,
        vf.SingleTurnEnv(dataset=glitched_dataset, rubric=rubric),
    )
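The module above is driven through `load_environment`, `tutorial_level`, and `echo_chamber`. The sketch below is not part of the diff; it assumes the `glitchlings[prime]` extra is installed and uses "example-env", "user/some-dataset", and "text" as placeholder identifiers.

# Minimal usage sketch (illustrative only; identifiers are placeholders).
from glitchlings.dlc.prime import Difficulty, echo_chamber, load_environment, tutorial_level
from glitchlings.zoo import Mim1c, Typogre

# Corrupt an existing verifiers environment with explicit glitchlings.
env = load_environment("example-env", glitchlings=[Typogre(rate=0.05), Mim1c(rate=0.01)])

# Or use the tuned Mim1c/Typogre presets, scaled by a Difficulty multiplier.
tutorial_env = tutorial_level("example-env", difficulty=Difficulty.Hard)

# Build a single-turn "restore the corrupted text" task from one dataset column.
restoration = echo_chamber("user/some-dataset", "text", Typogre(rate=0.05), split="train")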
glitchlings/dlc/pytorch.py
ADDED
@@ -0,0 +1,166 @@
"""Integration helpers for PyTorch data loaders."""

from __future__ import annotations

from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence
from typing import Any, cast

from ..compat import get_torch_dataloader, require_torch
from ..compat import torch as _torch_dependency
from ..util.adapters import coerce_gaggle
from ..zoo import Gaggle, Glitchling
from ._shared import corrupt_text_value, is_textual_candidate, normalise_column_spec


def _apply_to_batch(batch: Any, targets: list[str | int] | None, gaggle: Gaggle) -> Any:
    """Return ``batch`` with glitchlings applied to the specified ``targets``."""
    if targets is None:
        return corrupt_text_value(batch, gaggle)

    if isinstance(batch, Mapping):
        mutated = cast(MutableMapping[str, Any], dict(batch))
        for key in targets:
            if not isinstance(key, str):
                raise TypeError("Mapping batches require string column names")
            if key not in mutated:
                raise ValueError(f"Column '{key}' not found in DataLoader batch")
            mutated[key] = corrupt_text_value(mutated[key], gaggle)
        return mutated

    if isinstance(batch, Sequence) and not isinstance(batch, (bytes, bytearray, str)):
        mutated_sequence = list(batch)
        for index in targets:
            if not isinstance(index, int):
                raise TypeError("Sequence batches require integer column indices")
            try:
                mutated_sequence[index] = corrupt_text_value(mutated_sequence[index], gaggle)
            except IndexError as exc:  # pragma: no cover - defensive
                raise IndexError("Column index out of range for DataLoader batch") from exc
        if isinstance(batch, tuple):
            return tuple(mutated_sequence)
        return mutated_sequence

    raise TypeError("Unsupported DataLoader batch type for glitching")


def _infer_targets(batch: Any) -> list[str | int] | None:
    """Infer which fields should be glitched from a representative ``batch``."""
    if isinstance(batch, Mapping):
        inferred = [key for key, value in batch.items() if is_textual_candidate(value)]
        if inferred:
            return inferred
        raise ValueError("Unable to infer which mapping columns contain text")

    if isinstance(batch, Sequence) and not isinstance(batch, (bytes, bytearray, str)):
        inferred_indices: list[str | int] = [
            idx for idx, value in enumerate(batch) if is_textual_candidate(value)
        ]
        if inferred_indices:
            return inferred_indices
        raise ValueError("Unable to infer which sequence indices contain text")

    if is_textual_candidate(batch):
        return None

    raise TypeError("Unsupported DataLoader batch type for glitching")


class _GlitchedDataLoader(Iterable[Any]):
    """Wrapper that applies glitchlings lazily to each batch from a data loader."""

    def __init__(
        self,
        dataloader: Any,
        gaggle: Gaggle,
        *,
        columns: list[str | int] | None,
    ) -> None:
        self._dataloader = dataloader
        self._gaggle = gaggle
        self._explicit_columns = columns
        self._inferred_columns: list[str | int] | None | _Sentinel = _UNINITIALISED

    def __iter__(self) -> Iterator[Any]:
        # Reset all glitchling RNGs before each fresh pass for determinism.
        self._gaggle.sort_glitchlings()
        for batch in self._dataloader:
            targets = self._resolve_columns(batch)
            yield _apply_to_batch(batch, targets, self._gaggle)

    def __len__(self) -> int:
        return len(self._dataloader)

    def __getattr__(self, attribute: str) -> Any:
        return getattr(self._dataloader, attribute)

    def _resolve_columns(self, batch: Any) -> list[str | int] | None:
        if self._explicit_columns is not None:
            return self._explicit_columns

        if self._inferred_columns is _UNINITIALISED:
            self._inferred_columns = _infer_targets(batch)

        return cast(list[str | int] | None, self._inferred_columns)


class _Sentinel:
    """Sentinel type for deferred column inference."""


_UNINITIALISED = _Sentinel()


def _ensure_dataloader_class() -> type[Any]:
    """Return :class:`torch.utils.data.DataLoader` patched with ``.glitch``."""
    dataloader_cls = get_torch_dataloader()
    if dataloader_cls is None:
        require_torch("torch is not installed; install glitchlings[torch]")
        dataloader_cls = get_torch_dataloader()
    if dataloader_cls is None:  # pragma: no cover - defensive
        message = "torch.utils.data.DataLoader is not available"
        error = _torch_dependency.error
        if error is not None:
            raise ModuleNotFoundError(message) from error
        raise ModuleNotFoundError(message)

    if getattr(dataloader_cls, "glitch", None) is None:

        def glitch(
            self: Any,
            glitchlings: Iterable[str | Glitchling] | Glitchling | str | Gaggle,
            *,
            columns: str | int | Sequence[str | int] | None = None,
            seed: int = 151,
        ) -> _GlitchedDataLoader:
            """Return a lazily glitched view of the loader's batches."""
            gaggle = coerce_gaggle(glitchlings, seed=seed)
            normalised = normalise_column_spec(columns)
            return _GlitchedDataLoader(self, gaggle, columns=normalised)

        setattr(dataloader_cls, "glitch", glitch)

    return cast(type[Any], dataloader_cls)


def _optional_dataloader_class() -> type[Any] | None:
    """Return the PyTorch :class:`~torch.utils.data.DataLoader` when importable."""
    dataloader_cls = get_torch_dataloader()
    if dataloader_cls is None:
        return None
    return cast(type[Any], dataloader_cls)


def install() -> None:
    """Monkeypatch PyTorch's :class:`~torch.utils.data.DataLoader` with ``.glitch``."""
    _ensure_dataloader_class()


DataLoader: type[Any] | None
_DataLoaderAlias = _optional_dataloader_class()
if _DataLoaderAlias is not None:
    DataLoader = _ensure_dataloader_class()
else:  # pragma: no cover - torch is an optional dependency
    DataLoader = None


__all__ = ["DataLoader", "install"]
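Once `install()` has patched the DataLoader class, batches can be corrupted lazily through `.glitch`. The sketch below is illustrative rather than part of the release; it assumes torch is installed via the `glitchlings[torch]` extra and that `corrupt_text_value` accepts the list-of-strings fields produced by default collation.

# Illustrative sketch only: glitch the "text" field of each DataLoader batch.
from torch.utils.data import DataLoader

from glitchlings.dlc import pytorch as glitchlings_pytorch
from glitchlings.zoo import Typogre

glitchlings_pytorch.install()  # adds DataLoader.glitch

rows = [{"text": "hello world", "label": 0}, {"text": "lorem ipsum", "label": 1}]
loader = DataLoader(rows, batch_size=2)

# Returns a _GlitchedDataLoader wrapper; corruption happens per batch during iteration.
glitched = loader.glitch(Typogre(rate=0.05), columns="text", seed=151)
for batch in glitched:
    print(batch["text"])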
glitchlings/dlc/pytorch_lightning.py
ADDED
@@ -0,0 +1,215 @@
"""Integration helpers for PyTorch Lightning data modules."""

from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
from typing import Any, cast

from ..compat import get_pytorch_lightning_datamodule, require_pytorch_lightning
from ..util.adapters import coerce_gaggle
from ..zoo import Gaggle, Glitchling
from ._shared import corrupt_text_value, normalise_column_spec


def _glitch_batch(batch: Any, columns: list[str], gaggle: Gaggle) -> Any:
    """Apply glitchlings to the configured batch columns."""
    if not isinstance(batch, Mapping):
        return batch

    if hasattr(batch, "copy"):
        mutated = batch.copy()
    else:
        mutated = dict(batch)

    missing = [column for column in columns if column not in mutated]
    if missing:
        missing_str = ", ".join(sorted(missing))
        raise ValueError(f"Columns not found in batch: {missing_str}")

    for column in columns:
        mutated[column] = corrupt_text_value(mutated[column], gaggle)

    return mutated


def _wrap_dataloader(dataloader: Any, columns: list[str], gaggle: Gaggle) -> Any:
    """Wrap a dataloader so yielded batches are corrupted lazily."""
    if dataloader is None:
        return None

    if isinstance(dataloader, Mapping):
        mapping_type = cast(type[Any], dataloader.__class__)
        return mapping_type(
            {
                key: _wrap_dataloader(value, columns, gaggle)
                for key, value in dataloader.items()
            }
        )

    if isinstance(dataloader, list):
        return [_wrap_dataloader(value, columns, gaggle) for value in dataloader]

    if isinstance(dataloader, tuple):
        return tuple(_wrap_dataloader(value, columns, gaggle) for value in dataloader)

    if isinstance(dataloader, Sequence) and not isinstance(dataloader, (str, bytes, bytearray)):
        sequence_type = cast(type[Any], dataloader.__class__)
        return sequence_type(
            _wrap_dataloader(value, columns, gaggle) for value in dataloader
        )

    return _GlitchedDataLoader(dataloader, columns, gaggle)


class _GlitchedDataLoader:
    """Proxy dataloader that glitches batches produced by the wrapped loader."""

    def __init__(self, dataloader: Any, columns: list[str], gaggle: Gaggle) -> None:
        self._dataloader = dataloader
        self._columns = columns
        self._gaggle = gaggle

    def __iter__(self) -> Any:
        for batch in self._dataloader:
            yield _glitch_batch(batch, self._columns, self._gaggle)

    def __len__(self) -> int:
        return len(self._dataloader)

    def __getattr__(self, attribute: str) -> Any:
        return getattr(self._dataloader, attribute)


def _glitch_datamodule(
    datamodule: Any,
    glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
    column: str | Sequence[str],
    *,
    seed: int = 151,
) -> Any:
    """Return a proxy that applies glitchlings to batches from the datamodule."""

    columns = normalise_column_spec(column)
    if columns is None:  # pragma: no cover - defensive
        raise ValueError("At least one column must be specified")
    # Lightning datamodules only support string column names (mapping keys)
    columns_str = cast(list[str], columns)
    gaggle = coerce_gaggle(glitchlings, seed=seed)
    return _GlitchedLightningDataModule(datamodule, columns_str, gaggle)


class _GlitchedLightningDataModule:
    """Proxy wrapper around a LightningDataModule applying glitchlings to batches."""

    def __init__(self, base: Any, columns: list[str], gaggle: Gaggle) -> None:
        object.__setattr__(self, "_glitch_base", base)
        object.__setattr__(self, "_glitch_columns", columns)
        object.__setattr__(self, "_glitch_gaggle", gaggle)

    def __getattr__(self, attribute: str) -> Any:
        return getattr(self._glitch_base, attribute)

    def __setattr__(self, attribute: str, value: Any) -> None:
        if attribute.startswith("_glitch_"):
            object.__setattr__(self, attribute, value)
        else:
            setattr(self._glitch_base, attribute, value)

    def __delattr__(self, attribute: str) -> None:
        if attribute.startswith("_glitch_"):
            object.__delattr__(self, attribute)
        else:
            delattr(self._glitch_base, attribute)

    def __dir__(self) -> list[str]:
        return sorted(set(dir(self.__class__)) | set(dir(self._glitch_base)))

    # LightningDataModule API -------------------------------------------------
    def prepare_data(self, *args: Any, **kwargs: Any) -> Any:
        return self._glitch_base.prepare_data(*args, **kwargs)

    def setup(self, *args: Any, **kwargs: Any) -> Any:
        return self._glitch_base.setup(*args, **kwargs)

    def teardown(self, *args: Any, **kwargs: Any) -> Any:
        return self._glitch_base.teardown(*args, **kwargs)

    def state_dict(self) -> Mapping[str, Any]:
        state = self._glitch_base.state_dict()
        return cast(Mapping[str, Any], state)

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        self._glitch_base.load_state_dict(state_dict)

    def transfer_batch_to_device(self, batch: Any, device: Any, dataloader_idx: int) -> Any:
        return self._glitch_base.transfer_batch_to_device(batch, device, dataloader_idx)

    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        return self._glitch_base.on_before_batch_transfer(batch, dataloader_idx)

    def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        return self._glitch_base.on_after_batch_transfer(batch, dataloader_idx)

    def train_dataloader(self, *args: Any, **kwargs: Any) -> Any:
        loader = self._glitch_base.train_dataloader(*args, **kwargs)
        return _wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)

    def val_dataloader(self, *args: Any, **kwargs: Any) -> Any:
        loader = self._glitch_base.val_dataloader(*args, **kwargs)
        return _wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)

    def test_dataloader(self, *args: Any, **kwargs: Any) -> Any:
        loader = self._glitch_base.test_dataloader(*args, **kwargs)
        return _wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)

    def predict_dataloader(self, *args: Any, **kwargs: Any) -> Any:
        loader = self._glitch_base.predict_dataloader(*args, **kwargs)
        return _wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)


def _ensure_datamodule_class() -> Any:
    """Return the Lightning ``LightningDataModule`` patched with ``.glitch``."""

    datamodule_cls = get_pytorch_lightning_datamodule()
    if datamodule_cls is None:  # pragma: no cover - dependency is optional
        module = require_pytorch_lightning("pytorch_lightning is not installed")
        datamodule_cls = getattr(module, "LightningDataModule", None)
        if datamodule_cls is None:
            raise ModuleNotFoundError("pytorch_lightning is not installed")

    if getattr(datamodule_cls, "glitch", None) is None:

        def glitch(
            self: Any,
            glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
            *,
            column: str | Sequence[str],
            seed: int = 151,
            **_: Any,
        ) -> Any:
            return _glitch_datamodule(self, glitchlings, column, seed=seed)

        setattr(datamodule_cls, "glitch", glitch)

    if not issubclass(_GlitchedLightningDataModule, datamodule_cls):
        _GlitchedLightningDataModule.__bases__ = (datamodule_cls,)

    return datamodule_cls


def install() -> None:
    """Monkeypatch ``LightningDataModule`` with ``.glitch``."""

    _ensure_datamodule_class()


LightningDataModule: type[Any] | None
_LightningDataModuleAlias = get_pytorch_lightning_datamodule()
if _LightningDataModuleAlias is not None:
    LightningDataModule = _ensure_datamodule_class()
else:  # pragma: no cover - optional dependency
    LightningDataModule = None


__all__ = ["LightningDataModule", "install"]
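As with the DataLoader integration, `install()` exposes a `.glitch` method on `LightningDataModule`, and the returned proxy reroutes every `*_dataloader()` hook through `_wrap_dataloader`. A hypothetical sketch follows, assuming pytorch_lightning is installed and `MyTextDataModule` is a placeholder for a user-defined datamodule whose batches are mappings with a "text" key.

# Hypothetical sketch; MyTextDataModule is a placeholder LightningDataModule subclass.
from glitchlings.dlc import pytorch_lightning as glitchlings_pl
from glitchlings.zoo import Typogre

glitchlings_pl.install()  # adds LightningDataModule.glitch

datamodule = MyTextDataModule()  # placeholder
noisy = datamodule.glitch(Typogre(rate=0.05), column="text", seed=151)

# `noisy` proxies the original datamodule; its train/val/test/predict dataloaders
# yield batches whose "text" column is corrupted lazily as they are iterated.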