glitchlings 0.4.5__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of glitchlings might be problematic. Click here for more details.
- glitchlings/__init__.py +71 -0
- glitchlings/__main__.py +8 -0
- glitchlings/_zoo_rust.cp311-win_amd64.pyd +0 -0
- glitchlings/compat.py +282 -0
- glitchlings/config.py +386 -0
- glitchlings/config.toml +3 -0
- glitchlings/data/__init__.py +1 -0
- glitchlings/data/hokey_assets.json +193 -0
- glitchlings/dlc/__init__.py +7 -0
- glitchlings/dlc/_shared.py +153 -0
- glitchlings/dlc/huggingface.py +81 -0
- glitchlings/dlc/prime.py +254 -0
- glitchlings/dlc/pytorch.py +166 -0
- glitchlings/dlc/pytorch_lightning.py +209 -0
- glitchlings/lexicon/__init__.py +192 -0
- glitchlings/lexicon/_cache.py +108 -0
- glitchlings/lexicon/data/default_vector_cache.json +82 -0
- glitchlings/lexicon/metrics.py +162 -0
- glitchlings/lexicon/vector.py +652 -0
- glitchlings/lexicon/wordnet.py +228 -0
- glitchlings/main.py +364 -0
- glitchlings/util/__init__.py +195 -0
- glitchlings/util/adapters.py +27 -0
- glitchlings/util/hokey_generator.py +144 -0
- glitchlings/util/stretch_locator.py +140 -0
- glitchlings/util/stretchability.py +375 -0
- glitchlings/zoo/__init__.py +172 -0
- glitchlings/zoo/_ocr_confusions.py +32 -0
- glitchlings/zoo/_rate.py +131 -0
- glitchlings/zoo/_rust_extensions.py +143 -0
- glitchlings/zoo/_sampling.py +54 -0
- glitchlings/zoo/_text_utils.py +100 -0
- glitchlings/zoo/adjax.py +128 -0
- glitchlings/zoo/apostrofae.py +127 -0
- glitchlings/zoo/assets/__init__.py +0 -0
- glitchlings/zoo/assets/apostrofae_pairs.json +32 -0
- glitchlings/zoo/core.py +582 -0
- glitchlings/zoo/hokey.py +173 -0
- glitchlings/zoo/jargoyle.py +335 -0
- glitchlings/zoo/mim1c.py +109 -0
- glitchlings/zoo/ocr_confusions.tsv +30 -0
- glitchlings/zoo/redactyl.py +193 -0
- glitchlings/zoo/reduple.py +148 -0
- glitchlings/zoo/rushmore.py +153 -0
- glitchlings/zoo/scannequin.py +171 -0
- glitchlings/zoo/typogre.py +231 -0
- glitchlings/zoo/zeedub.py +185 -0
- glitchlings-0.4.5.dist-info/METADATA +648 -0
- glitchlings-0.4.5.dist-info/RECORD +53 -0
- glitchlings-0.4.5.dist-info/WHEEL +5 -0
- glitchlings-0.4.5.dist-info/entry_points.txt +2 -0
- glitchlings-0.4.5.dist-info/licenses/LICENSE +201 -0
- glitchlings-0.4.5.dist-info/top_level.txt +1 -0
|
glitchlings/dlc/_shared.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
"""Shared utilities for DLC integrations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable, Sequence
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from ..zoo.core import Gaggle, _is_transcript
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def resolve_environment(
    env: Any,
    *,
    loader: Callable[[str], Any],
    environment_type: type[Any],
) -> Any:
    """Materialise ``env`` into an instance of ``environment_type``.

    Args:
        env: Either an environment object or a string identifier that
            ``loader`` turns into one.
        loader: Called with ``env`` only when ``env`` is a string.
        environment_type: Class the resolved environment must be an instance of.

    Returns:
        The resolved environment object.

    Raises:
        TypeError: If the resolved object is not an ``environment_type``.
    """
    resolved = loader(env) if isinstance(env, str) else env
    if isinstance(resolved, environment_type):
        return resolved
    raise TypeError("Invalid environment type")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _first_row(dataset: Any) -> Any:
    """Return the first row of ``dataset``, or ``{}`` when it has no rows.

    Uses ``len``/indexing when the dataset supports ``len``; otherwise (e.g. a
    streaming dataset) peeks at one row via ``take(1)`` or plain iteration.
    """
    try:
        dataset_length = len(dataset)
    except TypeError:
        preview_rows: list[Any]
        take_fn = getattr(dataset, "take", None)
        if callable(take_fn):
            preview_rows = list(take_fn(1))
        else:
            iterator = iter(dataset)
            try:
                first_row = next(iterator)
            except StopIteration:
                preview_rows = []
            else:
                preview_rows = [first_row]
        return dict(preview_rows[0]) if preview_rows else {}
    return dataset[0] if dataset_length else {}


def resolve_columns(dataset: Any, columns: Sequence[str] | None) -> list[str]:
    """Identify which dataset columns should be corrupted.

    Resolution order:

    1. If ``columns`` is given, validate it against the dataset's
       ``column_names`` (when known) and return it. A bare string is treated
       as one column name, not as a sequence of characters.
    2. Otherwise return ``["prompt"]`` or ``["question"]`` if present.
    3. Otherwise return every column whose value in the first row is a ``str``.

    Args:
        dataset: Dataset-like object; ``column_names``, ``__len__``,
            ``__getitem__``, ``take`` and iteration are used when available.
        columns: Explicit column names, or ``None`` to infer them.

    Returns:
        The list of column names to corrupt.

    Raises:
        ValueError: If explicit columns are missing from the dataset, or no
            textual column can be inferred.
    """
    # ``column_names`` may be absent (treated as no columns) or ``None`` when a
    # streaming dataset's features are unknown; only the latter skips validation.
    raw_names = getattr(dataset, "column_names", ())
    names_known = raw_names is not None
    column_names: list[Any] = list(raw_names) if names_known else []

    if columns is not None:
        requested = [columns] if isinstance(columns, str) else list(columns)
        if names_known:
            missing = sorted(set(requested) - set(column_names))
            if missing:
                missing_str = ", ".join(missing)
                raise ValueError(f"Columns not found in dataset: {missing_str}")
        return requested

    sample: Any = None
    if not names_known:
        # Without declared columns, fall back to the keys of the first row.
        sample = _first_row(dataset)
        column_names = list(sample)

    for candidate in ("prompt", "question"):
        if candidate in column_names:
            return [candidate]

    if sample is None:
        sample = _first_row(dataset)
    inferred = [name for name in column_names if isinstance(sample.get(name), str)]
    if inferred:
        return inferred

    raise ValueError("Unable to determine which dataset columns to corrupt.")
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def normalize_column_spec(
    columns: str | int | Sequence[str | int] | None,
) -> list[str | int] | None:
    """Coerce a column specification into a list of column keys or indices.

    Args:
        columns: ``None``, a single column name/index, or a sequence of them.

    Returns:
        ``None`` when ``columns`` is ``None``; otherwise a new list holding the
        requested column identifiers.

    Raises:
        ValueError: If ``columns`` is an empty sequence.
    """
    if columns is None:
        return None
    if isinstance(columns, (str, int)):
        # A bare name or index denotes one column, not an iterable of them.
        return [columns]
    collected = [*columns]
    if collected:
        return collected
    raise ValueError("At least one column must be specified")
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def is_textual_candidate(value: Any) -> bool:
    """Report whether ``value`` looks like text a glitchling could corrupt.

    Accepted shapes: a ``str``; anything accepted by
    ``_is_transcript(value, allow_empty=False, require_all_content=True)``;
    or a non-empty, non-bytes sequence that is either all strings or accepted
    by that same transcript check once converted to a list.

    Args:
        value: Candidate value taken from a batch or dataset row.

    Returns:
        ``True`` when ``value`` matches one of the shapes above.
    """
    if isinstance(value, str):
        return True
    if _is_transcript(value, allow_empty=False, require_all_content=True):
        return True

    is_sequence = isinstance(value, Sequence) and not isinstance(
        value, (bytes, bytearray, str)
    )
    if not is_sequence or not value:
        return False

    if all(isinstance(item, str) for item in value):
        return True
    return bool(_is_transcript(list(value), allow_empty=False, require_all_content=True))
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def corrupt_text_value(value: Any, gaggle: Gaggle) -> Any:
    """Apply ``gaggle`` to ``value`` when it has a supported textual shape.

    Args:
        value: A string, a value accepted by
            ``_is_transcript(value, allow_empty=True)``, or a non-empty
            list/tuple of strings.
        gaggle: The gaggle whose ``corrupt`` method is applied.

    Returns:
        The corrupted value. Lists of strings come back as lists and tuples as
        tuples; any other value is returned unchanged.
    """
    if isinstance(value, str) or _is_transcript(value, allow_empty=True):
        return gaggle.corrupt(value)

    if isinstance(value, (list, tuple)) and value and all(isinstance(item, str) for item in value):
        corrupted = [gaggle.corrupt(item) for item in value]
        return corrupted if isinstance(value, list) else tuple(corrupted)

    return value
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# Public helpers shared by the DLC integration modules.
__all__ = [
    "corrupt_text_value", "is_textual_candidate", "normalize_column_spec",
    "resolve_columns", "resolve_environment",
]
|
|
glitchlings/dlc/huggingface.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Integration helpers for the Hugging Face datasets library."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Iterable, Sequence
|
|
6
|
+
from typing import Any, cast
|
|
7
|
+
|
|
8
|
+
from ..compat import datasets, get_datasets_dataset, require_datasets
|
|
9
|
+
from ..util.adapters import coerce_gaggle
|
|
10
|
+
from ..zoo import Gaggle, Glitchling
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _normalize_columns(column: str | Sequence[str]) -> list[str]:
|
|
14
|
+
"""Normalize a column specification to a list."""
|
|
15
|
+
if isinstance(column, str):
|
|
16
|
+
return [column]
|
|
17
|
+
|
|
18
|
+
normalized = list(column)
|
|
19
|
+
if not normalized:
|
|
20
|
+
raise ValueError("At least one column must be specified")
|
|
21
|
+
return normalized
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _glitch_dataset(
    dataset: Any,
    glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
    column: str | Sequence[str],
    *,
    seed: int = 151,
) -> Any:
    """Corrupt ``column`` of ``dataset`` with the requested glitchlings.

    The column spec is normalised first, so an empty sequence raises
    ``ValueError`` before any gaggle is built. The glitchlings are then
    coerced into a seeded gaggle whose ``corrupt_dataset`` produces the result.
    """
    target_columns = _normalize_columns(column)
    return coerce_gaggle(glitchlings, seed=seed).corrupt_dataset(dataset, target_columns)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _ensure_dataset_class() -> Any:
    """Fetch Hugging Face's ``Dataset`` class, adding a ``.glitch`` method if absent.

    Returns:
        The ``datasets.Dataset`` class, patched in place.

    Raises:
        ModuleNotFoundError: If the class is still unavailable after
            ``require_datasets`` has been consulted.
    """
    dataset_cls = get_datasets_dataset()
    if dataset_cls is None:  # pragma: no cover - datasets is an install-time dependency
        require_datasets("datasets is not installed")
        dataset_cls = get_datasets_dataset()
    if dataset_cls is None:
        cause = datasets.error
        message = "datasets is not installed"
        if cause is None:
            raise ModuleNotFoundError(message)
        raise ModuleNotFoundError(message) from cause

    if getattr(dataset_cls, "glitch", None) is None:

        def glitch(
            self: Any,
            glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
            *,
            column: str | Sequence[str],
            seed: int = 151,
            **_: Any,
        ) -> Any:
            """Return a lazily corrupted copy of the dataset.

            Extra keyword arguments are accepted and ignored.
            """
            return _glitch_dataset(self, glitchlings, column, seed=seed)

        setattr(dataset_cls, "glitch", glitch)

    return cast(type[Any], dataset_cls)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def install() -> None:
    """Attach the ``.glitch`` method to Hugging Face's ``Dataset`` class.

    This is a thin entry point around ``_ensure_dataset_class``; the patched
    class it returns is discarded.
    """
    _ensure_dataset_class()
    return None
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Re-export of the patched ``datasets.Dataset`` class, or ``None`` when the
# class cannot be obtained.
Dataset: type[Any] | None
_DatasetAlias = get_datasets_dataset()
if _DatasetAlias is None:  # pragma: no cover - datasets is an install-time dependency
    Dataset = None
else:
    Dataset = _ensure_dataset_class()
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# Public surface of the Hugging Face integration.
__all__ = [
    "Dataset",
    "install",
]
|
glitchlings/dlc/prime.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
"""Integration helpers for the optional verifiers prime DLC."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Iterable, Sequence
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import Any, Callable, Protocol, cast
|
|
8
|
+
|
|
9
|
+
from ..compat import require_datasets, require_jellyfish, require_verifiers
|
|
10
|
+
from ..util.adapters import coerce_gaggle
|
|
11
|
+
from ..zoo import Gaggle, Glitchling, Mim1c, Typogre
|
|
12
|
+
from ._shared import resolve_columns as _resolve_columns_shared
|
|
13
|
+
from ._shared import resolve_environment as _resolve_environment_shared
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class VerifierEnvironment(Protocol):
    """Structural type for a verifiers environment.

    This module only relies on the ``dataset`` attribute: ``load_environment``
    reads it and assigns a corrupted replacement back.
    """

    # The environment's dataset; its concrete type is left open.
    dataset: Any
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class VerifierSingleTurnEnv(Protocol):
    """Structural type for a single-turn verifiers environment.

    ``echo_chamber`` builds one from a corrupted dataset and a rubric.
    """

    # Prompt/answer rows presented to the model.
    dataset: Any
    # Scoring object built from the reward function(s).
    rubric: Any
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Optional "prime" extras. The ``require_*`` helpers receive an install hint
# (presumably used in the error they raise when the package is missing).
vf = require_verifiers("verifiers is not installed; install glitchlings[prime]")
_jellyfish = require_jellyfish("jellyfish is not installed; install glitchlings[prime]")

# Module-level alias so scoring code does not repeat the attribute lookup.
damerau_levenshtein_distance = getattr(_jellyfish, "damerau_levenshtein_distance")
|
|
32
|
+
|
|
33
|
+
try:
    from .huggingface import Dataset as _HuggingFaceDataset
except ModuleNotFoundError:  # pragma: no cover - optional dependency
    _HuggingFaceDataset = None

# ``huggingface.Dataset`` may itself be ``None``; in that case (or when the
# import fails) expose ``object`` instead of the patched Dataset class.
Dataset: type[Any] = object if _HuggingFaceDataset is None else _HuggingFaceDataset
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _resolve_environment(env: str | VerifierEnvironment) -> VerifierEnvironment:
    """Turn ``env`` into a verifiers ``Environment`` instance.

    Strings are loaded via ``vf.load_environment``. The loaded (or passed-in)
    object must be an instance of ``vf.Environment``; otherwise the shared
    resolver raises ``TypeError``.
    """
    loader = vf.load_environment
    environment_cls = cast(type[Any], vf.Environment)
    resolved = _resolve_environment_shared(env, loader=loader, environment_type=environment_cls)
    return cast(VerifierEnvironment, resolved)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _resolve_columns(dataset: Any, columns: Sequence[str] | None) -> list[str]:
    """Pick the dataset columns to corrupt by delegating to the shared DLC helper."""
    resolved = _resolve_columns_shared(dataset, columns)
    return resolved
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class Difficulty(Enum):
    """Difficulty levels for tutorial environments.

    Each member's value multiplies the base corruption rates used by
    :func:`tutorial_level`.
    """

    Easy = 0.25  # quarter of the base rates
    Normal = 1.0  # base rates unchanged
    Hard = 1.75
    Extreme = 3
    Impossible = 9
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def tutorial_level(
    env: VerifierEnvironment | str,
    seed: int = 151,
    difficulty: Difficulty = Difficulty.Normal,
) -> VerifierEnvironment:
    """Create a low-corruption environment using tuned defaults.

    Builds a ``Mim1c`` at base rate 0.01 and a ``Typogre`` at base rate 0.025,
    both scaled by ``difficulty.value``, and passes them to
    :func:`load_environment` with ``seed``.
    """
    scale = difficulty.value
    glitchlings = [Mim1c(rate=0.01 * scale), Typogre(rate=0.025 * scale)]
    return load_environment(env, glitchlings=glitchlings, seed=seed)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def load_environment(
    env: str | VerifierEnvironment,
    glitchlings: Iterable[str | Glitchling] | Glitchling | str | Gaggle | None = None,
    *,
    seed: int = 151,
    columns: Sequence[str] | None = None,
) -> VerifierEnvironment:
    """Load an environment and optionally corrupt its dataset with glitchlings.

    Args:
        env: Environment instance or identifier for ``vf.load_environment``.
        glitchlings: Glitchling specification; ``None`` leaves the environment
            untouched.
        seed: Seed forwarded to ``coerce_gaggle``.
        columns: Columns to corrupt; inferred by ``_resolve_columns`` when
            ``None``.

    Returns:
        The resolved environment. When glitchlings are given, its ``dataset``
        attribute is replaced in place with the corrupted dataset.
    """
    environment = _resolve_environment(env)
    if glitchlings is not None:
        gaggle = coerce_gaggle(glitchlings, seed=seed)
        dataset = environment.dataset
        target_columns = _resolve_columns(dataset, columns)
        environment.dataset = gaggle.corrupt_dataset(dataset, target_columns)
    return environment
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _as_gaggle(
    glitchlings: Iterable[str | Glitchling] | Glitchling | str | Gaggle,
    *,
    seed: int,
) -> Gaggle:
    """Build a seeded :class:`Gaggle` from any accepted glitchling specification."""
    gaggle = coerce_gaggle(glitchlings, seed=seed)
    return gaggle
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _extract_completion_text(completion: Any) -> str:
|
|
120
|
+
"""Normalize a completion payload into a plain string."""
|
|
121
|
+
if isinstance(completion, str):
|
|
122
|
+
return completion
|
|
123
|
+
|
|
124
|
+
if isinstance(completion, list) and completion:
|
|
125
|
+
first = completion[0]
|
|
126
|
+
if isinstance(first, dict) and "content" in first:
|
|
127
|
+
return str(first["content"])
|
|
128
|
+
return str(first)
|
|
129
|
+
|
|
130
|
+
return str(completion)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def symmetric_damerau_levenshtein_similarity(
    _: Any,
    completion: Any,
    answer: str,
) -> float:
    """Score ``completion`` against ``answer`` as ``1 - distance / max_len``.

    ``distance`` is the Damerau-Levenshtein edit distance, and ``max_len`` is
    the longer of the two strings with a floor of 1, so two empty strings
    score 1.0. The first positional argument is ignored, and the result is
    clamped to ``[0.0, 1.0]``.
    """
    text = _extract_completion_text(completion)
    reference = answer or ""
    longest = max(len(text), len(reference), 1)
    edits = cast(int, damerau_levenshtein_distance(text, reference))
    return min(1.0, max(0.0, 1.0 - edits / longest))
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# Default system prompt used by ``echo_chamber``.
DEFAULT_CLEANUP_INSTRUCTIONS = (
    "You are a meticulous copy editor. "
    "Restore the provided text to its original form."
)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _dataset_has_rows(dataset: Any) -> bool:
    """Return ``True`` when ``dataset`` yields at least one row.

    Uses ``len`` when supported; otherwise (e.g. streaming datasets) peeks at
    a single row via ``take(1)`` or plain iteration.
    """
    try:
        return len(dataset) > 0
    except TypeError:
        pass
    take_fn = getattr(dataset, "take", None)
    if callable(take_fn):
        return bool(list(take_fn(1)))
    try:
        next(iter(dataset))
    except StopIteration:
        return False
    return True


def echo_chamber(
    dataset_id: str,
    column: str,
    glitchlings: Iterable[str | Glitchling] | Glitchling | str | Gaggle,
    *,
    seed: int = 151,
    instructions: str = DEFAULT_CLEANUP_INSTRUCTIONS,
    reward_function: Callable[..., float] | None = None,
    split: str | None = None,
    **load_dataset_kwargs: Any,
) -> VerifierSingleTurnEnv:
    """Create an Echo Chamber Prime environment from a Hugging Face dataset column.

    Each non-null row of ``column`` becomes a prompt (system ``instructions``
    plus the text) whose ``answer`` is the original text. The ``prompt`` column
    is then corrupted with the glitchlings, and completions are scored by
    ``reward_function``.

    Args:
        dataset_id: Identifier of the Hugging Face dataset to load.
        column: Name of the column whose text should be glitched.
        glitchlings: Glitchling specifiers that will corrupt the prompts.
        seed: RNG seed forwarded to :func:`glitchlings.util.adapters.coerce_gaggle`.
        instructions: System instructions supplied to the environment prompts.
        reward_function: Optional callable used to score completions. Defaults to
            :func:`symmetric_damerau_levenshtein_similarity` when omitted.
        split: Optional dataset split to load. When omitted and the dataset
            loads as a dict of splits, the first split is used.
        **load_dataset_kwargs: Extra keyword arguments forwarded to
            :func:`datasets.load_dataset` (e.g. ``streaming=True``).

    Returns:
        A ``vf.SingleTurnEnv`` over the corrupted prompts.

    Raises:
        ModuleNotFoundError: If ``datasets.load_dataset`` is unavailable.
        ValueError: If the dataset has no splits, still loads as a dict of
            splits, lacks ``column``, or yields no non-null entries.
    """
    datasets_module = require_datasets("datasets is required to build an echo chamber")
    load_dataset = getattr(datasets_module, "load_dataset", None)
    if load_dataset is None:  # pragma: no cover - defensive
        message = "datasets is required to build an echo chamber"
        raise ModuleNotFoundError(message)

    # Recognise both map-style and streaming split containers; fall back to
    # ``dict`` if neither class is exported.
    dict_types = tuple(
        cls
        for cls in (
            getattr(datasets_module, "DatasetDict", None),
            getattr(datasets_module, "IterableDatasetDict", None),
        )
        if cls is not None
    ) or (dict,)

    hf_dataset: Any
    if split is None:
        hf_dataset = load_dataset(dataset_id, **load_dataset_kwargs)
        if isinstance(hf_dataset, dict_types):
            try:
                hf_dataset = next(iter(hf_dataset.values()))
            except StopIteration as exc:  # pragma: no cover - defensive
                raise ValueError("The specified dataset does not contain any splits") from exc
    else:
        hf_dataset = load_dataset(dataset_id, split=split, **load_dataset_kwargs)

    if isinstance(hf_dataset, dict_types):
        raise ValueError("Specify which split to use when the dataset loads as a DatasetDict.")

    # Streaming ``IterableDataset.filter``/``.map`` do not accept
    # ``load_from_cache_file``, so only pass it to map-style datasets.
    iterable_cls = getattr(datasets_module, "IterableDataset", None)
    is_streaming = iterable_cls is not None and isinstance(hf_dataset, iterable_cls)
    cache_kwargs: dict[str, Any] = {} if is_streaming else {"load_from_cache_file": False}

    # ``column_names`` can be ``None`` for streaming datasets with unknown
    # features; validate the column only when the names are known.
    declared_columns = getattr(hf_dataset, "column_names", None)
    if declared_columns is not None and column not in declared_columns:
        raise ValueError(f"Column '{column}' not found in dataset '{dataset_id}'.")

    filtered_dataset = hf_dataset.filter(
        lambda row: row.get(column) is not None,
        **cache_kwargs,
    )

    source_column_names = getattr(filtered_dataset, "column_names", None)

    def _build_prompt(row: dict[str, Any]) -> dict[str, Any]:
        text = str(row[column])
        prompt = [
            {"role": "system", "content": instructions},
            {"role": "user", "content": f"Corrupted text:\n{text}"},
        ]
        return {"prompt": prompt, "answer": text}

    base_dataset = filtered_dataset.map(
        _build_prompt,
        remove_columns=None if source_column_names is None else list(source_column_names),
        **cache_kwargs,
    )

    if not _dataset_has_rows(base_dataset):
        raise ValueError(
            f"Column '{column}' did not yield any textual entries in dataset '{dataset_id}'."
        )

    gaggle = _as_gaggle(glitchlings, seed=seed)
    glitched_dataset = gaggle.corrupt_dataset(base_dataset, ["prompt"])

    rubric_func = reward_function or symmetric_damerau_levenshtein_similarity
    rubric = vf.Rubric(funcs=[rubric_func], weights=[1.0])
    return cast(
        VerifierSingleTurnEnv,
        vf.SingleTurnEnv(dataset=glitched_dataset, rubric=rubric),
    )
|
|
glitchlings/dlc/pytorch.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
"""Integration helpers for PyTorch data loaders."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence
|
|
6
|
+
from typing import Any, cast
|
|
7
|
+
|
|
8
|
+
from ..compat import get_torch_dataloader, require_torch
|
|
9
|
+
from ..compat import torch as _torch_dependency
|
|
10
|
+
from ..util.adapters import coerce_gaggle
|
|
11
|
+
from ..zoo import Gaggle, Glitchling
|
|
12
|
+
from ._shared import corrupt_text_value, is_textual_candidate, normalize_column_spec
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _apply_to_batch(batch: Any, targets: list[str | int] | None, gaggle: Gaggle) -> Any:
|
|
16
|
+
"""Return ``batch`` with glitchlings applied to the specified ``targets``."""
|
|
17
|
+
if targets is None:
|
|
18
|
+
return corrupt_text_value(batch, gaggle)
|
|
19
|
+
|
|
20
|
+
if isinstance(batch, Mapping):
|
|
21
|
+
mutated = cast(MutableMapping[str, Any], dict(batch))
|
|
22
|
+
for key in targets:
|
|
23
|
+
if not isinstance(key, str):
|
|
24
|
+
raise TypeError("Mapping batches require string column names")
|
|
25
|
+
if key not in mutated:
|
|
26
|
+
raise ValueError(f"Column '{key}' not found in DataLoader batch")
|
|
27
|
+
mutated[key] = corrupt_text_value(mutated[key], gaggle)
|
|
28
|
+
return mutated
|
|
29
|
+
|
|
30
|
+
if isinstance(batch, Sequence) and not isinstance(batch, (bytes, bytearray, str)):
|
|
31
|
+
mutated_sequence = list(batch)
|
|
32
|
+
for index in targets:
|
|
33
|
+
if not isinstance(index, int):
|
|
34
|
+
raise TypeError("Sequence batches require integer column indices")
|
|
35
|
+
try:
|
|
36
|
+
mutated_sequence[index] = corrupt_text_value(mutated_sequence[index], gaggle)
|
|
37
|
+
except IndexError as exc: # pragma: no cover - defensive
|
|
38
|
+
raise IndexError("Column index out of range for DataLoader batch") from exc
|
|
39
|
+
if isinstance(batch, tuple):
|
|
40
|
+
return tuple(mutated_sequence)
|
|
41
|
+
return mutated_sequence
|
|
42
|
+
|
|
43
|
+
raise TypeError("Unsupported DataLoader batch type for glitching")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _infer_targets(batch: Any) -> list[str | int] | None:
|
|
47
|
+
"""Infer which fields should be glitched from a representative ``batch``."""
|
|
48
|
+
if isinstance(batch, Mapping):
|
|
49
|
+
inferred = [key for key, value in batch.items() if is_textual_candidate(value)]
|
|
50
|
+
if inferred:
|
|
51
|
+
return inferred
|
|
52
|
+
raise ValueError("Unable to infer which mapping columns contain text")
|
|
53
|
+
|
|
54
|
+
if isinstance(batch, Sequence) and not isinstance(batch, (bytes, bytearray, str)):
|
|
55
|
+
inferred_indices: list[str | int] = [
|
|
56
|
+
idx for idx, value in enumerate(batch) if is_textual_candidate(value)
|
|
57
|
+
]
|
|
58
|
+
if inferred_indices:
|
|
59
|
+
return inferred_indices
|
|
60
|
+
raise ValueError("Unable to infer which sequence indices contain text")
|
|
61
|
+
|
|
62
|
+
if is_textual_candidate(batch):
|
|
63
|
+
return None
|
|
64
|
+
|
|
65
|
+
raise TypeError("Unsupported DataLoader batch type for glitching")
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class _GlitchedDataLoader(Iterable[Any]):
|
|
69
|
+
"""Wrapper that applies glitchlings lazily to each batch from a data loader."""
|
|
70
|
+
|
|
71
|
+
def __init__(
|
|
72
|
+
self,
|
|
73
|
+
dataloader: Any,
|
|
74
|
+
gaggle: Gaggle,
|
|
75
|
+
*,
|
|
76
|
+
columns: list[str | int] | None,
|
|
77
|
+
) -> None:
|
|
78
|
+
self._dataloader = dataloader
|
|
79
|
+
self._gaggle = gaggle
|
|
80
|
+
self._explicit_columns = columns
|
|
81
|
+
self._inferred_columns: list[str | int] | None | _Sentinel = _UNINITIALISED
|
|
82
|
+
|
|
83
|
+
def __iter__(self) -> Iterator[Any]:
|
|
84
|
+
# Reset all glitchling RNGs before each fresh pass for determinism.
|
|
85
|
+
self._gaggle.sort_glitchlings()
|
|
86
|
+
for batch in self._dataloader:
|
|
87
|
+
targets = self._resolve_columns(batch)
|
|
88
|
+
yield _apply_to_batch(batch, targets, self._gaggle)
|
|
89
|
+
|
|
90
|
+
def __len__(self) -> int:
|
|
91
|
+
return len(self._dataloader)
|
|
92
|
+
|
|
93
|
+
def __getattr__(self, attribute: str) -> Any:
|
|
94
|
+
return getattr(self._dataloader, attribute)
|
|
95
|
+
|
|
96
|
+
def _resolve_columns(self, batch: Any) -> list[str | int] | None:
|
|
97
|
+
if self._explicit_columns is not None:
|
|
98
|
+
return self._explicit_columns
|
|
99
|
+
|
|
100
|
+
if self._inferred_columns is _UNINITIALISED:
|
|
101
|
+
self._inferred_columns = _infer_targets(batch)
|
|
102
|
+
|
|
103
|
+
return cast(list[str | int] | None, self._inferred_columns)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class _Sentinel:
|
|
107
|
+
"""Sentinel type for deferred column inference."""
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
# Shared marker stored by _GlitchedDataLoader until it has inferred its columns.
_UNINITIALISED = _Sentinel()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _ensure_dataloader_class() -> type[Any]:
    """Fetch PyTorch's ``DataLoader`` class, adding a ``.glitch`` method if absent.

    Returns:
        The ``torch.utils.data.DataLoader`` class, patched in place.

    Raises:
        ModuleNotFoundError: If the class is still unavailable after
            ``require_torch`` has been consulted.
    """
    dataloader_cls = get_torch_dataloader()
    if dataloader_cls is None:
        require_torch("torch is not installed; install glitchlings[torch]")
        dataloader_cls = get_torch_dataloader()
    if dataloader_cls is None:  # pragma: no cover - defensive
        cause = _torch_dependency.error
        message = "torch.utils.data.DataLoader is not available"
        if cause is None:
            raise ModuleNotFoundError(message)
        raise ModuleNotFoundError(message) from cause

    if getattr(dataloader_cls, "glitch", None) is None:

        def glitch(
            self: Any,
            glitchlings: Iterable[str | Glitchling] | Glitchling | str | Gaggle,
            *,
            columns: str | int | Sequence[str | int] | None = None,
            seed: int = 151,
        ) -> _GlitchedDataLoader:
            """Wrap this loader so its batches are glitched lazily during iteration."""
            return _GlitchedDataLoader(
                self,
                coerce_gaggle(glitchlings, seed=seed),
                columns=normalize_column_spec(columns),
            )

        setattr(dataloader_cls, "glitch", glitch)

    return cast(type[Any], dataloader_cls)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _optional_dataloader_class() -> type[Any] | None:
    """Return PyTorch's ``DataLoader`` class if ``get_torch_dataloader`` finds it, else ``None``."""
    dataloader_cls = get_torch_dataloader()
    return None if dataloader_cls is None else cast(type[Any], dataloader_cls)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def install() -> None:
    """Attach the ``.glitch`` method to PyTorch's ``DataLoader`` class.

    This is a thin entry point around ``_ensure_dataloader_class``; the patched
    class it returns is discarded.
    """
    _ensure_dataloader_class()
    return None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# Re-export of the patched ``torch.utils.data.DataLoader`` class, or ``None``
# when torch is not importable.
DataLoader: type[Any] | None
_DataLoaderAlias = _optional_dataloader_class()
if _DataLoaderAlias is None:  # pragma: no cover - torch is an optional dependency
    DataLoader = None
else:
    DataLoader = _ensure_dataloader_class()
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
# Public surface of the PyTorch integration.
__all__ = [
    "DataLoader",
    "install",
]
|