glitchlings 0.2.5__cp312-cp312-win_amd64.whl → 0.9.3__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glitchlings/__init__.py +36 -17
- glitchlings/__main__.py +0 -1
- glitchlings/_zoo_rust/__init__.py +12 -0
- glitchlings/_zoo_rust.cp312-win_amd64.pyd +0 -0
- glitchlings/assets/__init__.py +180 -0
- glitchlings/assets/apostrofae_pairs.json +32 -0
- glitchlings/assets/ekkokin_homophones.json +2014 -0
- glitchlings/assets/hokey_assets.json +193 -0
- glitchlings/assets/lexemes/academic.json +1049 -0
- glitchlings/assets/lexemes/colors.json +1333 -0
- glitchlings/assets/lexemes/corporate.json +716 -0
- glitchlings/assets/lexemes/cyberpunk.json +22 -0
- glitchlings/assets/lexemes/lovecraftian.json +23 -0
- glitchlings/assets/lexemes/synonyms.json +3354 -0
- glitchlings/assets/mim1c_homoglyphs.json.gz.b64 +1064 -0
- glitchlings/assets/pipeline_assets.json +29 -0
- glitchlings/attack/__init__.py +53 -0
- glitchlings/attack/compose.py +299 -0
- glitchlings/attack/core.py +465 -0
- glitchlings/attack/encode.py +114 -0
- glitchlings/attack/metrics.py +104 -0
- glitchlings/attack/metrics_dispatch.py +70 -0
- glitchlings/attack/tokenization.py +157 -0
- glitchlings/auggie.py +283 -0
- glitchlings/compat/__init__.py +9 -0
- glitchlings/compat/loaders.py +355 -0
- glitchlings/compat/types.py +41 -0
- glitchlings/conf/__init__.py +41 -0
- glitchlings/conf/loaders.py +331 -0
- glitchlings/conf/schema.py +156 -0
- glitchlings/conf/types.py +72 -0
- glitchlings/config.toml +2 -0
- glitchlings/constants.py +59 -0
- glitchlings/dev/__init__.py +3 -0
- glitchlings/dev/docs.py +45 -0
- glitchlings/dlc/__init__.py +17 -3
- glitchlings/dlc/_shared.py +296 -0
- glitchlings/dlc/gutenberg.py +400 -0
- glitchlings/dlc/huggingface.py +37 -65
- glitchlings/dlc/prime.py +55 -114
- glitchlings/dlc/pytorch.py +98 -0
- glitchlings/dlc/pytorch_lightning.py +173 -0
- glitchlings/internal/__init__.py +16 -0
- glitchlings/internal/rust.py +159 -0
- glitchlings/internal/rust_ffi.py +432 -0
- glitchlings/main.py +123 -32
- glitchlings/runtime_config.py +24 -0
- glitchlings/util/__init__.py +29 -176
- glitchlings/util/adapters.py +65 -0
- glitchlings/util/keyboards.py +311 -0
- glitchlings/util/transcripts.py +108 -0
- glitchlings/zoo/__init__.py +47 -24
- glitchlings/zoo/assets/__init__.py +29 -0
- glitchlings/zoo/core.py +301 -167
- glitchlings/zoo/core_execution.py +98 -0
- glitchlings/zoo/core_planning.py +451 -0
- glitchlings/zoo/corrupt_dispatch.py +295 -0
- glitchlings/zoo/ekkokin.py +118 -0
- glitchlings/zoo/hokey.py +137 -0
- glitchlings/zoo/jargoyle.py +179 -274
- glitchlings/zoo/mim1c.py +106 -68
- glitchlings/zoo/pedant/__init__.py +107 -0
- glitchlings/zoo/pedant/core.py +105 -0
- glitchlings/zoo/pedant/forms.py +74 -0
- glitchlings/zoo/pedant/stones.py +74 -0
- glitchlings/zoo/redactyl.py +44 -175
- glitchlings/zoo/rng.py +259 -0
- glitchlings/zoo/rushmore.py +359 -116
- glitchlings/zoo/scannequin.py +18 -125
- glitchlings/zoo/transforms.py +386 -0
- glitchlings/zoo/typogre.py +76 -162
- glitchlings/zoo/validation.py +477 -0
- glitchlings/zoo/zeedub.py +33 -86
- glitchlings-0.9.3.dist-info/METADATA +334 -0
- glitchlings-0.9.3.dist-info/RECORD +80 -0
- {glitchlings-0.2.5.dist-info → glitchlings-0.9.3.dist-info}/entry_points.txt +1 -0
- glitchlings/zoo/_ocr_confusions.py +0 -34
- glitchlings/zoo/_rate.py +0 -21
- glitchlings/zoo/reduple.py +0 -169
- glitchlings-0.2.5.dist-info/METADATA +0 -490
- glitchlings-0.2.5.dist-info/RECORD +0 -27
- /glitchlings/{zoo → assets}/ocr_confusions.tsv +0 -0
- {glitchlings-0.2.5.dist-info → glitchlings-0.9.3.dist-info}/WHEEL +0 -0
- {glitchlings-0.2.5.dist-info → glitchlings-0.9.3.dist-info}/licenses/LICENSE +0 -0
- {glitchlings-0.2.5.dist-info → glitchlings-0.9.3.dist-info}/top_level.txt +0 -0
glitchlings/dlc/huggingface.py
CHANGED
|
@@ -5,40 +5,19 @@ from __future__ import annotations
|
|
|
5
5
|
from collections.abc import Iterable, Sequence
|
|
6
6
|
from typing import Any
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
except ModuleNotFoundError as _datasets_error: # pragma: no cover - optional dependency
|
|
11
|
-
_DatasetsDataset = None # type: ignore[assignment]
|
|
12
|
-
else:
|
|
13
|
-
_datasets_error = None
|
|
8
|
+
from ..util.adapters import coerce_gaggle
|
|
9
|
+
from ..zoo import Gaggle, Glitchling
|
|
14
10
|
|
|
15
|
-
from ..zoo import Gaggle, Glitchling, summon
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
def _normalise_columns(column: str | Sequence[str]) -> list[str]:
|
|
19
|
-
"""Normalise a column specification to a list."""
|
|
20
11
|
|
|
12
|
+
def _normalize_columns(column: str | Sequence[str]) -> list[str]:
|
|
13
|
+
"""Normalize a column specification to a list."""
|
|
21
14
|
if isinstance(column, str):
|
|
22
15
|
return [column]
|
|
23
16
|
|
|
24
|
-
|
|
25
|
-
if not
|
|
17
|
+
normalized = list(column)
|
|
18
|
+
if not normalized:
|
|
26
19
|
raise ValueError("At least one column must be specified")
|
|
27
|
-
return
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def _as_gaggle(glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling], seed: int) -> Gaggle:
|
|
31
|
-
"""Coerce any supported glitchling specification into a :class:`Gaggle`."""
|
|
32
|
-
|
|
33
|
-
if isinstance(glitchlings, Gaggle):
|
|
34
|
-
return glitchlings
|
|
35
|
-
|
|
36
|
-
if isinstance(glitchlings, (Glitchling, str)):
|
|
37
|
-
resolved: Iterable[str | Glitchling] = [glitchlings]
|
|
38
|
-
else:
|
|
39
|
-
resolved = glitchlings
|
|
40
|
-
|
|
41
|
-
return summon(list(resolved), seed=seed)
|
|
20
|
+
return normalized
|
|
42
21
|
|
|
43
22
|
|
|
44
23
|
def _glitch_dataset(
|
|
@@ -48,49 +27,42 @@ def _glitch_dataset(
|
|
|
48
27
|
*,
|
|
49
28
|
seed: int = 151,
|
|
50
29
|
) -> Any:
|
|
51
|
-
"""
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
gaggle = _as_gaggle(glitchlings, seed=seed)
|
|
30
|
+
"""Apply glitchlings to the provided dataset columns."""
|
|
31
|
+
columns = _normalize_columns(column)
|
|
32
|
+
gaggle = coerce_gaggle(glitchlings, seed=seed)
|
|
55
33
|
return gaggle.corrupt_dataset(dataset, columns)
|
|
56
34
|
|
|
57
35
|
|
|
58
|
-
def
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
def glitch( # type: ignore[override]
|
|
68
|
-
self: Any,
|
|
69
|
-
glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
|
|
70
|
-
*,
|
|
71
|
-
column: str | Sequence[str],
|
|
72
|
-
seed: int = 151,
|
|
73
|
-
**_: Any,
|
|
74
|
-
) -> Any:
|
|
75
|
-
"""Return a lazily corrupted copy of the dataset."""
|
|
76
|
-
|
|
77
|
-
return _glitch_dataset(self, glitchlings, column, seed=seed)
|
|
78
|
-
|
|
79
|
-
setattr(_DatasetsDataset, "glitch", glitch)
|
|
80
|
-
|
|
81
|
-
return _DatasetsDataset
|
|
82
|
-
|
|
36
|
+
def GlitchedDataset(
|
|
37
|
+
dataset: Any,
|
|
38
|
+
glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
|
|
39
|
+
*,
|
|
40
|
+
column: str | Sequence[str],
|
|
41
|
+
seed: int = 151,
|
|
42
|
+
) -> Any:
|
|
43
|
+
"""Return a lazily corrupted copy of a Hugging Face dataset.
|
|
83
44
|
|
|
84
|
-
|
|
85
|
-
|
|
45
|
+
This function applies glitchlings to the specified columns of a dataset,
|
|
46
|
+
returning a new dataset that lazily corrupts data as it's accessed.
|
|
86
47
|
|
|
87
|
-
|
|
48
|
+
Args:
|
|
49
|
+
dataset: The Hugging Face Dataset to corrupt.
|
|
50
|
+
glitchlings: A glitchling, gaggle, or specification of glitchlings to apply.
|
|
51
|
+
column: The column name (string) or names (sequence of strings) to corrupt.
|
|
52
|
+
seed: RNG seed for deterministic corruption (default: 151).
|
|
88
53
|
|
|
54
|
+
Returns:
|
|
55
|
+
A new dataset with the specified columns corrupted by the glitchlings.
|
|
89
56
|
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
57
|
+
Example:
|
|
58
|
+
>>> from datasets import Dataset
|
|
59
|
+
>>> from glitchlings.dlc.huggingface import GlitchedDataset
|
|
60
|
+
>>> dataset = Dataset.from_dict({"text": ["hello", "world"]})
|
|
61
|
+
>>> corrupted = GlitchedDataset(dataset, "typogre", column="text")
|
|
62
|
+
>>> list(corrupted)
|
|
63
|
+
[{'text': 'helo'}, {'text': 'wrold'}]
|
|
64
|
+
"""
|
|
65
|
+
return _glitch_dataset(dataset, glitchlings, column, seed=seed)
|
|
94
66
|
|
|
95
67
|
|
|
96
|
-
__all__ = ["
|
|
68
|
+
__all__ = ["GlitchedDataset"]
|
glitchlings/dlc/prime.py
CHANGED
|
@@ -3,115 +3,62 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from collections.abc import Iterable, Sequence
|
|
6
|
-
from
|
|
7
|
-
from typing import Any, Callable
|
|
6
|
+
from typing import Any, Callable, Protocol, cast
|
|
8
7
|
|
|
9
|
-
import
|
|
8
|
+
from ..compat.loaders import require_datasets, require_jellyfish, require_verifiers
|
|
9
|
+
from ..util.adapters import coerce_gaggle
|
|
10
|
+
from ..zoo import Gaggle, Glitchling, Mim1c, Typogre # noqa: F401
|
|
11
|
+
from ._shared import resolve_columns as _resolve_columns_shared
|
|
10
12
|
|
|
11
|
-
from jellyfish import damerau_levenshtein_distance
|
|
12
13
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
except ModuleNotFoundError: # pragma: no cover - optional dependency
|
|
16
|
-
Dataset = object # type: ignore[assignment]
|
|
17
|
-
else:
|
|
18
|
-
if Dataset is None: # pragma: no cover - optional dependency
|
|
19
|
-
Dataset = object # type: ignore[assignment]
|
|
14
|
+
class VerifierEnvironment(Protocol):
|
|
15
|
+
"""Minimal interface for verifiers environments."""
|
|
20
16
|
|
|
21
|
-
|
|
17
|
+
dataset: Any
|
|
22
18
|
|
|
23
19
|
|
|
24
|
-
|
|
25
|
-
"""
|
|
26
|
-
|
|
27
|
-
if isinstance(env, str):
|
|
28
|
-
env = vf.load_environment(env)
|
|
29
|
-
|
|
30
|
-
if not isinstance(env, vf.Environment):
|
|
31
|
-
raise TypeError("Invalid environment type")
|
|
32
|
-
|
|
33
|
-
return env
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def _resolve_columns(dataset: Dataset, columns: Sequence[str] | None) -> list[str]:
|
|
37
|
-
"""Identify which dataset columns should be corrupted."""
|
|
20
|
+
class VerifierSingleTurnEnv(Protocol):
|
|
21
|
+
"""Minimal interface for single-turn verifier environments."""
|
|
38
22
|
|
|
39
|
-
|
|
23
|
+
dataset: Any
|
|
24
|
+
rubric: Any
|
|
40
25
|
|
|
41
|
-
if columns is not None:
|
|
42
|
-
missing = sorted(set(columns) - available)
|
|
43
|
-
if missing:
|
|
44
|
-
missing_str = ", ".join(missing)
|
|
45
|
-
raise ValueError(f"Columns not found in dataset: {missing_str}")
|
|
46
|
-
return list(columns)
|
|
47
26
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
27
|
+
vf = require_verifiers("verifiers is not installed; install glitchlings[prime]")
|
|
28
|
+
_jellyfish = require_jellyfish("jellyfish is not installed; install glitchlings[prime]")
|
|
29
|
+
levenshtein_distance = _jellyfish.levenshtein_distance
|
|
51
30
|
|
|
52
|
-
sample = dataset[0] if len(dataset) else {}
|
|
53
|
-
inferred = [
|
|
54
|
-
name
|
|
55
|
-
for name in dataset.column_names
|
|
56
|
-
if isinstance(sample.get(name), str)
|
|
57
|
-
]
|
|
58
31
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
class Difficulty(Enum):
|
|
66
|
-
"""Difficulty levels for tutorial environments."""
|
|
67
|
-
|
|
68
|
-
Easy = 0.25
|
|
69
|
-
Normal = 1.0
|
|
70
|
-
Hard = 1.75
|
|
71
|
-
Extreme = 3
|
|
72
|
-
Impossible = 9
|
|
32
|
+
def _resolve_environment(env: str | VerifierEnvironment) -> VerifierEnvironment:
|
|
33
|
+
"""Return a fully-instantiated verifier environment."""
|
|
34
|
+
if isinstance(env, str):
|
|
35
|
+
env = vf.load_environment(env)
|
|
73
36
|
|
|
37
|
+
if not isinstance(env, cast(type[Any], vf.Environment)):
|
|
38
|
+
raise TypeError("Invalid environment type")
|
|
74
39
|
|
|
75
|
-
|
|
76
|
-
env: vf.Environment | str,
|
|
77
|
-
seed: int = 151,
|
|
78
|
-
difficulty: Difficulty = Difficulty.Normal,
|
|
79
|
-
) -> vf.Environment:
|
|
80
|
-
"""Create a low-corruption environment using tuned defaults."""
|
|
40
|
+
return cast(VerifierEnvironment, env)
|
|
81
41
|
|
|
82
|
-
tuned_mim1c = Mim1c(rate=0.01 * difficulty.value)
|
|
83
|
-
tuned_typogre = Typogre(rate=0.025 * difficulty.value)
|
|
84
42
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
seed=seed,
|
|
89
|
-
)
|
|
43
|
+
def _resolve_columns(dataset: Any, columns: Sequence[str] | None) -> list[str]:
|
|
44
|
+
"""Identify which dataset columns should be corrupted."""
|
|
45
|
+
return _resolve_columns_shared(dataset, columns)
|
|
90
46
|
|
|
91
47
|
|
|
92
48
|
def load_environment(
|
|
93
|
-
env: str |
|
|
49
|
+
env: str | VerifierEnvironment,
|
|
94
50
|
glitchlings: Iterable[str | Glitchling] | Glitchling | str | Gaggle | None = None,
|
|
95
51
|
*,
|
|
96
52
|
seed: int = 151,
|
|
97
53
|
columns: Sequence[str] | None = None,
|
|
98
|
-
) ->
|
|
54
|
+
) -> VerifierEnvironment:
|
|
99
55
|
"""Load an environment and optionally corrupt it with glitchlings."""
|
|
100
|
-
|
|
101
56
|
environment = _resolve_environment(env)
|
|
102
57
|
|
|
103
58
|
if glitchlings is None:
|
|
104
59
|
return environment
|
|
105
60
|
|
|
106
|
-
|
|
107
|
-
gaggle = glitchlings
|
|
108
|
-
else:
|
|
109
|
-
if isinstance(glitchlings, (Glitchling, str)):
|
|
110
|
-
resolved = [glitchlings]
|
|
111
|
-
else:
|
|
112
|
-
resolved = list(glitchlings)
|
|
113
|
-
|
|
114
|
-
gaggle = summon(resolved, seed=seed)
|
|
61
|
+
gaggle = coerce_gaggle(glitchlings, seed=seed)
|
|
115
62
|
|
|
116
63
|
dataset = environment.dataset
|
|
117
64
|
corrupt_columns = _resolve_columns(dataset, columns)
|
|
@@ -125,21 +72,11 @@ def _as_gaggle(
|
|
|
125
72
|
seed: int,
|
|
126
73
|
) -> Gaggle:
|
|
127
74
|
"""Coerce any supported glitchling specification into a :class:`Gaggle`."""
|
|
128
|
-
|
|
129
|
-
if isinstance(glitchlings, Gaggle):
|
|
130
|
-
return glitchlings
|
|
131
|
-
|
|
132
|
-
if isinstance(glitchlings, (Glitchling, str)):
|
|
133
|
-
resolved: Iterable[str | Glitchling] = [glitchlings]
|
|
134
|
-
else:
|
|
135
|
-
resolved = glitchlings
|
|
136
|
-
|
|
137
|
-
return summon(list(resolved), seed=seed)
|
|
75
|
+
return coerce_gaggle(glitchlings, seed=seed)
|
|
138
76
|
|
|
139
77
|
|
|
140
78
|
def _extract_completion_text(completion: Any) -> str:
|
|
141
|
-
"""
|
|
142
|
-
|
|
79
|
+
"""Normalize a completion payload into a plain string."""
|
|
143
80
|
if isinstance(completion, str):
|
|
144
81
|
return completion
|
|
145
82
|
|
|
@@ -152,21 +89,22 @@ def _extract_completion_text(completion: Any) -> str:
|
|
|
152
89
|
return str(completion)
|
|
153
90
|
|
|
154
91
|
|
|
155
|
-
def
|
|
92
|
+
def normalized_edit_distance(
|
|
156
93
|
_: Any,
|
|
157
94
|
completion: Any,
|
|
158
95
|
answer: str,
|
|
159
96
|
) -> float:
|
|
160
|
-
"""Return ``1 - (distance / max_len)`` using
|
|
161
|
-
|
|
97
|
+
"""Return ``1 - (distance / max_len)`` using Levenshtein distance."""
|
|
162
98
|
completion_text = _extract_completion_text(completion)
|
|
163
99
|
target = answer or ""
|
|
164
100
|
denominator = max(len(completion_text), len(target), 1)
|
|
165
|
-
distance =
|
|
101
|
+
distance = cast(int, levenshtein_distance(completion_text, target))
|
|
166
102
|
score = 1.0 - (distance / denominator)
|
|
167
103
|
return max(0.0, min(1.0, score))
|
|
168
104
|
|
|
169
105
|
|
|
106
|
+
symmetric_levenshtein_similarity = normalized_edit_distance
|
|
107
|
+
|
|
170
108
|
DEFAULT_CLEANUP_INSTRUCTIONS = (
|
|
171
109
|
"You are a meticulous copy editor. Restore the provided text to its original form."
|
|
172
110
|
)
|
|
@@ -182,32 +120,34 @@ def echo_chamber(
|
|
|
182
120
|
reward_function: Callable[..., float] | None = None,
|
|
183
121
|
split: str | None = None,
|
|
184
122
|
**load_dataset_kwargs: Any,
|
|
185
|
-
) ->
|
|
123
|
+
) -> VerifierSingleTurnEnv:
|
|
186
124
|
"""Create an Echo Chamber Prime environment from a Hugging Face dataset column.
|
|
187
125
|
|
|
188
126
|
Args:
|
|
189
127
|
dataset_id: Identifier of the Hugging Face dataset to load.
|
|
190
128
|
column: Name of the column whose text should be glitched.
|
|
191
129
|
glitchlings: Glitchling specifiers that will corrupt the prompts.
|
|
192
|
-
seed: RNG seed forwarded to :func:`
|
|
130
|
+
seed: RNG seed forwarded to :func:`glitchlings.util.adapters.coerce_gaggle`.
|
|
193
131
|
instructions: System instructions supplied to the environment prompts.
|
|
194
132
|
reward_function: Optional callable used to score completions. Defaults to
|
|
195
|
-
:func:`
|
|
133
|
+
:func:`symmetric_levenshtein_similarity` when omitted.
|
|
196
134
|
split: Optional dataset split to load.
|
|
197
135
|
**load_dataset_kwargs: Extra keyword arguments forwarded to
|
|
198
136
|
:func:`datasets.load_dataset`.
|
|
199
|
-
"""
|
|
200
137
|
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
138
|
+
"""
|
|
139
|
+
datasets_module = require_datasets("datasets is required to build an echo chamber")
|
|
140
|
+
load_dataset = getattr(datasets_module, "load_dataset", None)
|
|
141
|
+
if load_dataset is None: # pragma: no cover - defensive
|
|
204
142
|
message = "datasets is required to build an echo chamber"
|
|
205
|
-
raise ModuleNotFoundError(message)
|
|
143
|
+
raise ModuleNotFoundError(message)
|
|
206
144
|
|
|
207
|
-
|
|
145
|
+
dataset_dict_cls = getattr(datasets_module, "DatasetDict", dict)
|
|
146
|
+
|
|
147
|
+
hf_dataset: Any
|
|
208
148
|
if split is None:
|
|
209
149
|
hf_dataset = load_dataset(dataset_id, **load_dataset_kwargs)
|
|
210
|
-
if isinstance(hf_dataset,
|
|
150
|
+
if isinstance(hf_dataset, dataset_dict_cls):
|
|
211
151
|
try:
|
|
212
152
|
hf_dataset = next(iter(hf_dataset.values()))
|
|
213
153
|
except StopIteration as exc: # pragma: no cover - defensive
|
|
@@ -215,10 +155,8 @@ def echo_chamber(
|
|
|
215
155
|
else:
|
|
216
156
|
hf_dataset = load_dataset(dataset_id, split=split, **load_dataset_kwargs)
|
|
217
157
|
|
|
218
|
-
if isinstance(hf_dataset,
|
|
219
|
-
raise ValueError(
|
|
220
|
-
"Specify which split to use when the dataset loads as a DatasetDict."
|
|
221
|
-
)
|
|
158
|
+
if isinstance(hf_dataset, dataset_dict_cls):
|
|
159
|
+
raise ValueError("Specify which split to use when the dataset loads as a DatasetDict.")
|
|
222
160
|
|
|
223
161
|
filtered_dataset = hf_dataset.filter(
|
|
224
162
|
lambda row: row.get(column) is not None,
|
|
@@ -242,7 +180,7 @@ def echo_chamber(
|
|
|
242
180
|
)
|
|
243
181
|
|
|
244
182
|
try:
|
|
245
|
-
dataset_length = len(base_dataset)
|
|
183
|
+
dataset_length = len(base_dataset)
|
|
246
184
|
except TypeError:
|
|
247
185
|
preview_rows: list[dict[str, Any]]
|
|
248
186
|
take_fn = getattr(base_dataset, "take", None)
|
|
@@ -269,6 +207,9 @@ def echo_chamber(
|
|
|
269
207
|
gaggle = _as_gaggle(glitchlings, seed=seed)
|
|
270
208
|
glitched_dataset = gaggle.corrupt_dataset(base_dataset, ["prompt"])
|
|
271
209
|
|
|
272
|
-
rubric_func = reward_function or
|
|
210
|
+
rubric_func = reward_function or normalized_edit_distance
|
|
273
211
|
rubric = vf.Rubric(funcs=[rubric_func], weights=[1.0])
|
|
274
|
-
return
|
|
212
|
+
return cast(
|
|
213
|
+
VerifierSingleTurnEnv,
|
|
214
|
+
vf.SingleTurnEnv(dataset=glitched_dataset, rubric=rubric),
|
|
215
|
+
)
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""Integration helpers for PyTorch data loaders."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Iterable, Iterator, Sequence
|
|
6
|
+
from typing import Any, cast
|
|
7
|
+
|
|
8
|
+
from ..util.adapters import coerce_gaggle
|
|
9
|
+
from ..zoo import Gaggle, Glitchling
|
|
10
|
+
from ._shared import corrupt_batch, infer_batch_targets, normalize_column_spec
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _GlitchedDataLoader(Iterable[Any]):
|
|
14
|
+
"""Wrapper that applies glitchlings lazily to each batch from a data loader."""
|
|
15
|
+
|
|
16
|
+
def __init__(
|
|
17
|
+
self,
|
|
18
|
+
dataloader: Any,
|
|
19
|
+
gaggle: Gaggle,
|
|
20
|
+
*,
|
|
21
|
+
columns: list[str | int] | None,
|
|
22
|
+
) -> None:
|
|
23
|
+
self._dataloader = dataloader
|
|
24
|
+
self._gaggle = gaggle
|
|
25
|
+
self._explicit_columns = columns
|
|
26
|
+
self._inferred_columns: list[str | int] | None | _Sentinel = _UNINITIALISED
|
|
27
|
+
|
|
28
|
+
def __iter__(self) -> Iterator[Any]:
|
|
29
|
+
# Reset all glitchling RNGs before each fresh pass for determinism.
|
|
30
|
+
self._gaggle.sort_glitchlings()
|
|
31
|
+
for batch in self._dataloader:
|
|
32
|
+
targets = self._resolve_columns(batch)
|
|
33
|
+
yield corrupt_batch(batch, targets, self._gaggle)
|
|
34
|
+
|
|
35
|
+
def __len__(self) -> int:
|
|
36
|
+
return len(self._dataloader)
|
|
37
|
+
|
|
38
|
+
def __getattr__(self, attribute: str) -> Any:
|
|
39
|
+
return getattr(self._dataloader, attribute)
|
|
40
|
+
|
|
41
|
+
def _resolve_columns(self, batch: Any) -> list[str | int] | None:
|
|
42
|
+
if self._explicit_columns is not None:
|
|
43
|
+
return self._explicit_columns
|
|
44
|
+
|
|
45
|
+
if self._inferred_columns is _UNINITIALISED:
|
|
46
|
+
self._inferred_columns = infer_batch_targets(batch)
|
|
47
|
+
|
|
48
|
+
return cast(list[str | int] | None, self._inferred_columns)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class _Sentinel:
|
|
52
|
+
"""Sentinel type for deferred column inference."""
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
_UNINITIALISED = _Sentinel()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def GlitchedDataLoader(
    dataloader: Any,
    glitchlings: Iterable[str | Glitchling] | Glitchling | str | Gaggle,
    *,
    columns: str | int | Sequence[str | int] | None = None,
    seed: int = 151,
) -> _GlitchedDataLoader:
    """Return a lazily glitched view of a PyTorch DataLoader's batches.

    The returned wrapper iterates ``dataloader`` and corrupts the selected
    columns of each batch as it is yielded. The wrapped loader itself is not
    modified.

    Args:
        dataloader: The PyTorch DataLoader (or any iterable of batches) to wrap.
        glitchlings: A glitchling, gaggle, or specification of glitchlings to
            apply. It is resolved with ``coerce_gaggle`` using ``seed``.
        columns: Column name(s) or index/indices to corrupt. Can be:
            - A single string column name (for dict-like batches)
            - A single integer index (for sequence-like batches)
            - A sequence of column names or indices
            - None to infer target columns from the first batch (default);
              the inferred targets are then reused for later batches.
        seed: RNG seed for deterministic corruption (default: 151).

    Returns:
        A wrapper that yields corrupted batches. ``len()`` and any other
        attribute access are forwarded to ``dataloader``.

    Note:
        With PyTorch's default collation, each yielded batch is already
        collated rather than a raw dataset row. For example, ``batch["text"]``
        is a list of strings and ``batch["label"]`` is a tensor. The exact
        corrupted text depends on the glitchlings and the seed.

    Example:
        >>> from torch.utils.data import DataLoader
        >>> from glitchlings.dlc.pytorch import GlitchedDataLoader
        >>> dataset = [{"text": "hello", "label": 0}]
        >>> loader = DataLoader(dataset)
        >>> glitched = GlitchedDataLoader(loader, "typogre", columns="text")
        >>> for batch in glitched:
        ...     texts = batch["text"]  # collated list of (possibly corrupted) strings
    """
    gaggle = coerce_gaggle(glitchlings, seed=seed)
    normalized = normalize_column_spec(columns)
    return _GlitchedDataLoader(dataloader, gaggle, columns=normalized)


__all__ = ["GlitchedDataLoader"]
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""Integration helpers for PyTorch Lightning data modules."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Iterable, Mapping, Sequence
|
|
6
|
+
from typing import Any, cast
|
|
7
|
+
|
|
8
|
+
from ..compat.loaders import get_pytorch_lightning_datamodule
|
|
9
|
+
from ..util.adapters import coerce_gaggle
|
|
10
|
+
from ..zoo import Gaggle, Glitchling
|
|
11
|
+
from ._shared import normalize_column_spec, wrap_dataloader
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _glitch_datamodule(
    datamodule: Any,
    glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
    column: str | Sequence[str],
    *,
    seed: int = 151,
) -> Any:
    """Wrap *datamodule* in a proxy that corrupts *column* in every batch.

    Args:
        datamodule: The LightningDataModule (or compatible object) to wrap.
        glitchlings: Glitchling specification resolved via ``coerce_gaggle``.
        column: One column name, or a sequence of names, to corrupt.
        seed: Seed forwarded to ``coerce_gaggle``.

    Returns:
        A ``_GlitchedLightningDataModule`` proxy around *datamodule*.

    Raises:
        ValueError: If the column specification normalizes to ``None``.
    """
    target_columns = normalize_column_spec(column)
    if target_columns is None:  # pragma: no cover - defensive
        raise ValueError("At least one column must be specified")

    resolved_gaggle = coerce_gaggle(glitchlings, seed=seed)
    # Lightning batches are treated as mappings keyed by column name. The cast
    # only narrows the normalized spec for the type checker; it has no runtime
    # effect.
    return _GlitchedLightningDataModule(
        datamodule, cast(list[str], target_columns), resolved_gaggle
    )
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def GlitchedLightningDataModule(
    datamodule: Any,
    glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
    *,
    column: str | Sequence[str],
    seed: int = 151,
) -> Any:
    """Wrap a PyTorch Lightning ``LightningDataModule`` so its batches are glitched.

    The train/val/test/predict dataloaders of the returned proxy yield batches
    whose selected columns have been corrupted by the glitchlings. All other
    attribute access is forwarded to ``datamodule``.

    Args:
        datamodule: The LightningDataModule to wrap.
        glitchlings: A glitchling, gaggle, or specification of glitchlings to apply.
        column: The column name (string) or names (sequence of strings) to corrupt.
        seed: RNG seed for deterministic corruption (default: 151).

    Returns:
        A proxy datamodule whose dataloaders yield corrupted batches.

    Example:
        >>> from pytorch_lightning import LightningDataModule
        >>> from glitchlings.dlc.pytorch_lightning import GlitchedLightningDataModule
        >>> class MyDataModule(LightningDataModule):
        ...     def train_dataloader(self):
        ...         return [{"text": "hello", "label": 0}]
        >>> dm = MyDataModule()
        >>> glitched = GlitchedLightningDataModule(dm, "typogre", column="text")
        >>> batches = list(glitched.train_dataloader())
    """
    return _glitch_datamodule(datamodule, glitchlings, column=column, seed=seed)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class _GlitchedLightningDataModule:
|
|
68
|
+
"""Proxy wrapper around a LightningDataModule applying glitchlings to batches."""
|
|
69
|
+
|
|
70
|
+
def __init__(self, base: Any, columns: list[str], gaggle: Gaggle) -> None:
|
|
71
|
+
object.__setattr__(self, "_glitch_base", base)
|
|
72
|
+
object.__setattr__(self, "_glitch_columns", columns)
|
|
73
|
+
object.__setattr__(self, "_glitch_gaggle", gaggle)
|
|
74
|
+
|
|
75
|
+
def __getattr__(self, attribute: str) -> Any:
|
|
76
|
+
return getattr(self._glitch_base, attribute)
|
|
77
|
+
|
|
78
|
+
def __setattr__(self, attribute: str, value: Any) -> None:
|
|
79
|
+
if attribute.startswith("_glitch_"):
|
|
80
|
+
object.__setattr__(self, attribute, value)
|
|
81
|
+
else:
|
|
82
|
+
setattr(self._glitch_base, attribute, value)
|
|
83
|
+
|
|
84
|
+
def __delattr__(self, attribute: str) -> None:
|
|
85
|
+
if attribute.startswith("_glitch_"):
|
|
86
|
+
object.__delattr__(self, attribute)
|
|
87
|
+
else:
|
|
88
|
+
delattr(self._glitch_base, attribute)
|
|
89
|
+
|
|
90
|
+
def __dir__(self) -> list[str]:
|
|
91
|
+
return sorted(set(dir(self.__class__)) | set(dir(self._glitch_base)))
|
|
92
|
+
|
|
93
|
+
# LightningDataModule API -------------------------------------------------
|
|
94
|
+
def prepare_data(self, *args: Any, **kwargs: Any) -> Any:
|
|
95
|
+
return self._glitch_base.prepare_data(*args, **kwargs)
|
|
96
|
+
|
|
97
|
+
def setup(self, *args: Any, **kwargs: Any) -> Any:
|
|
98
|
+
return self._glitch_base.setup(*args, **kwargs)
|
|
99
|
+
|
|
100
|
+
def teardown(self, *args: Any, **kwargs: Any) -> Any:
|
|
101
|
+
return self._glitch_base.teardown(*args, **kwargs)
|
|
102
|
+
|
|
103
|
+
def state_dict(self) -> Mapping[str, Any]:
|
|
104
|
+
state = self._glitch_base.state_dict()
|
|
105
|
+
return cast(Mapping[str, Any], state)
|
|
106
|
+
|
|
107
|
+
def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
|
|
108
|
+
self._glitch_base.load_state_dict(state_dict)
|
|
109
|
+
|
|
110
|
+
def transfer_batch_to_device(self, batch: Any, device: Any, dataloader_idx: int) -> Any:
|
|
111
|
+
return self._glitch_base.transfer_batch_to_device(batch, device, dataloader_idx)
|
|
112
|
+
|
|
113
|
+
def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
|
|
114
|
+
return self._glitch_base.on_before_batch_transfer(batch, dataloader_idx)
|
|
115
|
+
|
|
116
|
+
def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
|
|
117
|
+
return self._glitch_base.on_after_batch_transfer(batch, dataloader_idx)
|
|
118
|
+
|
|
119
|
+
def train_dataloader(self, *args: Any, **kwargs: Any) -> Any:
|
|
120
|
+
loader = self._glitch_base.train_dataloader(*args, **kwargs)
|
|
121
|
+
return wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)
|
|
122
|
+
|
|
123
|
+
def val_dataloader(self, *args: Any, **kwargs: Any) -> Any:
|
|
124
|
+
loader = self._glitch_base.val_dataloader(*args, **kwargs)
|
|
125
|
+
return wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)
|
|
126
|
+
|
|
127
|
+
def test_dataloader(self, *args: Any, **kwargs: Any) -> Any:
|
|
128
|
+
loader = self._glitch_base.test_dataloader(*args, **kwargs)
|
|
129
|
+
return wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)
|
|
130
|
+
|
|
131
|
+
def predict_dataloader(self, *args: Any, **kwargs: Any) -> Any:
|
|
132
|
+
loader = self._glitch_base.predict_dataloader(*args, **kwargs)
|
|
133
|
+
return wrap_dataloader(loader, self._glitch_columns, self._glitch_gaggle)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
# Module initialization: set up inheritance from LightningDataModule if available
|
|
137
|
+
def _setup_inheritance() -> None:
    """Make ``_GlitchedLightningDataModule`` a subclass of ``LightningDataModule``.

    This runs once when the module is imported. When PyTorch Lightning is
    available, the proxy is rebased onto (or rebuilt on top of)
    ``LightningDataModule``. This lets ``isinstance`` checks against that type
    succeed and lets Lightning APIs that require it accept the proxy. When
    Lightning is not installed, the class is left untouched.
    """
    lightning_base = get_pytorch_lightning_datamodule()
    if lightning_base is None:
        # Lightning is unavailable: keep the proxy as a plain object subclass.
        return

    try:
        _GlitchedLightningDataModule.__bases__ = (lightning_base,)
    except TypeError:
        # CPython can refuse __bases__ reassignment, e.g. when the layout or
        # deallocator differs from the current base. In that case, rebuild the
        # class with the same members on top of the Lightning base instead.
        excluded = {"__dict__", "__weakref__"}
        members = {
            key: member
            for key, member in vars(_GlitchedLightningDataModule).items()
            if key not in excluded
        }
        rebuilt = type("_GlitchedLightningDataModule", (lightning_base,), members)
        # Rebind the module global so later proxies use the rebuilt class.
        globals()["_GlitchedLightningDataModule"] = cast(type[Any], rebuilt)


# Rebase the proxy onto LightningDataModule (when installed) at import time.
_setup_inheritance()


__all__ = ["GlitchedLightningDataModule"]
|