glitchlings 0.4.1-cp312-cp312-win_amd64.whl → 0.4.3-cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of glitchlings might be problematic.
- glitchlings/__init__.py +30 -17
- glitchlings/__main__.py +0 -1
- glitchlings/_zoo_rust.cp312-win_amd64.pyd +0 -0
- glitchlings/compat.py +284 -0
- glitchlings/config.py +164 -34
- glitchlings/config.toml +1 -1
- glitchlings/dlc/__init__.py +3 -1
- glitchlings/dlc/_shared.py +68 -0
- glitchlings/dlc/huggingface.py +26 -41
- glitchlings/dlc/prime.py +64 -101
- glitchlings/dlc/pytorch.py +216 -0
- glitchlings/dlc/pytorch_lightning.py +233 -0
- glitchlings/lexicon/__init__.py +12 -33
- glitchlings/lexicon/_cache.py +21 -22
- glitchlings/lexicon/data/default_vector_cache.json +80 -14
- glitchlings/lexicon/metrics.py +1 -8
- glitchlings/lexicon/vector.py +109 -49
- glitchlings/lexicon/wordnet.py +89 -49
- glitchlings/main.py +30 -24
- glitchlings/util/__init__.py +18 -4
- glitchlings/util/adapters.py +27 -0
- glitchlings/zoo/__init__.py +26 -15
- glitchlings/zoo/_ocr_confusions.py +1 -3
- glitchlings/zoo/_rate.py +1 -4
- glitchlings/zoo/_sampling.py +0 -1
- glitchlings/zoo/_text_utils.py +1 -5
- glitchlings/zoo/adjax.py +2 -4
- glitchlings/zoo/apostrofae.py +128 -0
- glitchlings/zoo/assets/__init__.py +0 -0
- glitchlings/zoo/assets/apostrofae_pairs.json +32 -0
- glitchlings/zoo/core.py +152 -87
- glitchlings/zoo/jargoyle.py +50 -45
- glitchlings/zoo/mim1c.py +11 -10
- glitchlings/zoo/redactyl.py +16 -16
- glitchlings/zoo/reduple.py +5 -3
- glitchlings/zoo/rushmore.py +4 -10
- glitchlings/zoo/scannequin.py +7 -6
- glitchlings/zoo/typogre.py +8 -9
- glitchlings/zoo/zeedub.py +6 -3
- {glitchlings-0.4.1.dist-info → glitchlings-0.4.3.dist-info}/METADATA +101 -4
- glitchlings-0.4.3.dist-info/RECORD +46 -0
- glitchlings/lexicon/graph.py +0 -290
- glitchlings-0.4.1.dist-info/RECORD +0 -39
- {glitchlings-0.4.1.dist-info → glitchlings-0.4.3.dist-info}/WHEEL +0 -0
- {glitchlings-0.4.1.dist-info → glitchlings-0.4.3.dist-info}/entry_points.txt +0 -0
- {glitchlings-0.4.1.dist-info → glitchlings-0.4.3.dist-info}/licenses/LICENSE +0 -0
- {glitchlings-0.4.1.dist-info → glitchlings-0.4.3.dist-info}/top_level.txt +0 -0
glitchlings/dlc/__init__.py
CHANGED
@@ -1,5 +1,7 @@
 """Optional DLC integrations for Glitchlings."""
 
 from .huggingface import install as install_huggingface
+from .pytorch import install as install_pytorch
+from .pytorch_lightning import install as install_pytorch_lightning
 
-__all__ = ["install_huggingface"]
+__all__ = ["install_huggingface", "install_pytorch", "install_pytorch_lightning"]
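With the expanded __all__, each DLC installer can be imported and applied on its own. A minimal usage sketch (assuming the optional datasets and torch extras are installed):

    from glitchlings.dlc import install_huggingface, install_pytorch

    install_huggingface()  # patches datasets.Dataset with a .glitch method
    install_pytorch()      # patches torch.utils.data.DataLoader with a .glitch method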
glitchlings/dlc/_shared.py
ADDED
@@ -0,0 +1,68 @@
+"""Shared utilities for DLC integrations."""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Sequence
+from typing import Any
+
+
+def resolve_environment(
+    env: Any,
+    *,
+    loader: Callable[[str], Any],
+    environment_type: type[Any],
+) -> Any:
+    """Return a fully-instantiated verifier environment."""
+    if isinstance(env, str):
+        env = loader(env)
+
+    if not isinstance(env, environment_type):
+        raise TypeError("Invalid environment type")
+
+    return env
+
+
+def resolve_columns(dataset: Any, columns: Sequence[str] | None) -> list[str]:
+    """Identify which dataset columns should be corrupted."""
+    available = set(getattr(dataset, "column_names", ()))
+
+    if columns is not None:
+        missing = sorted(set(columns) - available)
+        if missing:
+            missing_str = ", ".join(missing)
+            raise ValueError(f"Columns not found in dataset: {missing_str}")
+        return list(columns)
+
+    for candidate in ("prompt", "question"):
+        if candidate in available:
+            return [candidate]
+
+    try:
+        dataset_length = len(dataset)
+    except TypeError:
+        preview_rows: list[dict[str, Any]]
+        take_fn = getattr(dataset, "take", None)
+        if callable(take_fn):
+            preview_rows = list(take_fn(1))
+        else:
+            iterator = iter(dataset)
+            try:
+                first_row = next(iterator)
+            except StopIteration:
+                preview_rows = []
+            else:
+                preview_rows = [first_row]
+        sample = dict(preview_rows[0]) if preview_rows else {}
+    else:
+        sample = dataset[0] if dataset_length else {}
+    inferred = [
+        name for name in getattr(dataset, "column_names", ()) if isinstance(sample.get(name), str)
+    ]
+
+    if inferred:
+        return inferred
+
+    raise ValueError("Unable to determine which dataset columns to corrupt.")
+
+
+__all__ = ["resolve_columns", "resolve_environment"]
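resolve_columns falls back in order: explicit column names, then a preferred "prompt" or "question" column, then whichever columns hold strings in a sampled row. An illustrative sketch against a hypothetical stand-in object rather than a real datasets.Dataset (assuming the glitchlings DLC modules import cleanly):

    from glitchlings.dlc._shared import resolve_columns

    class FakeDataset:
        # Minimal stand-in exposing only the attributes resolve_columns inspects.
        column_names = ["question", "label"]
        _rows = [{"question": "What is 2 + 2?", "label": 4}]

        def __len__(self):
            return len(self._rows)

        def __getitem__(self, index):
            return self._rows[index]

    print(resolve_columns(FakeDataset(), None))       # ['question'] - preferred name wins
    print(resolve_columns(FakeDataset(), ["label"]))  # ['label'] - explicit columns win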
glitchlings/dlc/huggingface.py
CHANGED
@@ -3,21 +3,15 @@
 from __future__ import annotations
 
 from collections.abc import Iterable, Sequence
-from typing import Any
+from typing import Any, cast
 
-
-
-
-    _DatasetsDataset = None  # type: ignore[assignment]
-else:
-    _datasets_error = None
-
-from ..zoo import Gaggle, Glitchling, summon
+from ..compat import datasets, get_datasets_dataset, require_datasets
+from ..util.adapters import coerce_gaggle
+from ..zoo import Gaggle, Glitchling
 
 
 def _normalise_columns(column: str | Sequence[str]) -> list[str]:
     """Normalise a column specification to a list."""
-
     if isinstance(column, str):
         return [column]
 
@@ -27,20 +21,6 @@ def _normalise_columns(column: str | Sequence[str]) -> list[str]:
     return normalised
 
 
-def _as_gaggle(glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling], seed: int) -> Gaggle:
-    """Coerce any supported glitchling specification into a :class:`Gaggle`."""
-
-    if isinstance(glitchlings, Gaggle):
-        return glitchlings
-
-    if isinstance(glitchlings, (Glitchling, str)):
-        resolved: Iterable[str | Glitchling] = [glitchlings]
-    else:
-        resolved = glitchlings
-
-    return summon(list(resolved), seed=seed)
-
-
 def _glitch_dataset(
     dataset: Any,
     glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
@@ -48,23 +28,28 @@ def _glitch_dataset(
     *,
     seed: int = 151,
 ) -> Any:
-    """
-
+    """Apply glitchlings to the provided dataset columns."""
     columns = _normalise_columns(column)
-    gaggle =
+    gaggle = coerce_gaggle(glitchlings, seed=seed)
     return gaggle.corrupt_dataset(dataset, columns)
 
 
 def _ensure_dataset_class() -> Any:
     """Return the Hugging Face :class:`~datasets.Dataset` patched with ``.glitch``."""
-
-    if
-
-
-
-
-
+    dataset_cls = get_datasets_dataset()
+    if dataset_cls is None:  # pragma: no cover - datasets is an install-time dependency
+        require_datasets("datasets is not installed")
+        dataset_cls = get_datasets_dataset()
+    if dataset_cls is None:
+        message = "datasets is not installed"
+        error = datasets.error
+        if error is not None:
+            raise ModuleNotFoundError(message) from error
+        raise ModuleNotFoundError(message)
+
+    if getattr(dataset_cls, "glitch", None) is None:
+
+        def glitch(
             self: Any,
             glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
             *,
@@ -73,24 +58,24 @@ def _ensure_dataset_class() -> Any:
             **_: Any,
         ) -> Any:
            """Return a lazily corrupted copy of the dataset."""
-
            return _glitch_dataset(self, glitchlings, column, seed=seed)
 
-    setattr(
+        setattr(dataset_cls, "glitch", glitch)
 
-    return
+    return cast(type[Any], dataset_cls)
 
 
 def install() -> None:
     """Monkeypatch the Hugging Face :class:`~datasets.Dataset` with ``.glitch``."""
-
     _ensure_dataset_class()
 
 
-
+Dataset: type[Any] | None
+_DatasetAlias = get_datasets_dataset()
+if _DatasetAlias is not None:
     Dataset = _ensure_dataset_class()
 else:  # pragma: no cover - datasets is an install-time dependency
-    Dataset = None
+    Dataset = None
 
 
 __all__ = ["Dataset", "install"]
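Once install() has run, the patched .glitch returns a lazily corrupted copy of a dataset. A usage sketch; the dataset id and column name are illustrative assumptions, not anything fixed by this diff:

    from datasets import load_dataset

    from glitchlings.dlc import install_huggingface
    from glitchlings.zoo import Typogre

    install_huggingface()

    # "ag_news" and its "text" column are only placeholders for any text dataset.
    ds = load_dataset("ag_news", split="train[:100]")
    noisy = ds.glitch(Typogre(rate=0.05), column="text", seed=151)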
glitchlings/dlc/prime.py
CHANGED
@@ -4,79 +4,60 @@ from __future__ import annotations
 
 from collections.abc import Iterable, Sequence
 from enum import Enum
-from typing import Any, Callable
+from typing import Any, Callable, Protocol, cast
 
-import
+from ..compat import require_datasets, require_jellyfish, require_verifiers
+from ..util.adapters import coerce_gaggle
+from ..zoo import Gaggle, Glitchling, Mim1c, Typogre
+from ._shared import resolve_columns as _resolve_columns_shared
+from ._shared import resolve_environment as _resolve_environment_shared
 
-from jellyfish import damerau_levenshtein_distance
 
-
-
-except ModuleNotFoundError:  # pragma: no cover - optional dependency
-    Dataset = object  # type: ignore[assignment]
-else:
-    if Dataset is None:  # pragma: no cover - optional dependency
-        Dataset = object  # type: ignore[assignment]
+class VerifierEnvironment(Protocol):
+    """Minimal interface for verifiers environments."""
 
-
+    dataset: Any
 
 
-
-    """
-
-    if isinstance(env, str):
-        env = vf.load_environment(env)
+class VerifierSingleTurnEnv(Protocol):
+    """Minimal interface for single-turn verifier environments."""
 
-
-
+    dataset: Any
+    rubric: Any
 
-    return env
 
+vf = require_verifiers("verifiers is not installed; install glitchlings[prime]")
+_jellyfish = require_jellyfish("jellyfish is not installed; install glitchlings[prime]")
+damerau_levenshtein_distance = _jellyfish.damerau_levenshtein_distance
 
-
-
-
-
+try:
+    from .huggingface import Dataset as _HuggingFaceDataset
+except ModuleNotFoundError:  # pragma: no cover - optional dependency
+    _HuggingFaceDataset = None
+else:
+    if _HuggingFaceDataset is None:  # pragma: no cover - optional dependency
+        _HuggingFaceDataset = None
 
-
-
-
-
-
-    return list(columns)
+Dataset: type[Any]
+if _HuggingFaceDataset is None:
+    Dataset = object
+else:
+    Dataset = _HuggingFaceDataset
 
-    for candidate in ("prompt", "question"):
-        if candidate in available:
-            return [candidate]
 
-
-
-
-
-
-
-
-
-        iterator = iter(dataset)
-        try:
-            first_row = next(iterator)
-        except StopIteration:
-            preview_rows = []
-        else:
-            preview_rows = [first_row]
-        sample = dict(preview_rows[0]) if preview_rows else {}
-    else:
-        sample = dataset[0] if dataset_length else {}
-    inferred = [
-        name
-        for name in dataset.column_names
-        if isinstance(sample.get(name), str)
-    ]
+def _resolve_environment(env: str | VerifierEnvironment) -> VerifierEnvironment:
+    """Return a fully-instantiated verifier environment."""
+    resolved = _resolve_environment_shared(
+        env,
+        loader=vf.load_environment,
+        environment_type=cast(type[Any], vf.Environment),
+    )
+    return cast(VerifierEnvironment, resolved)
 
-    if inferred:
-        return inferred
 
-
+def _resolve_columns(dataset: Any, columns: Sequence[str] | None) -> list[str]:
+    """Identify which dataset columns should be corrupted."""
+    return _resolve_columns_shared(dataset, columns)
 
 
 class Difficulty(Enum):
@@ -90,12 +71,11 @@ class Difficulty(Enum):
 
 
 def tutorial_level(
-    env:
+    env: VerifierEnvironment | str,
     seed: int = 151,
     difficulty: Difficulty = Difficulty.Normal,
-) ->
+) -> VerifierEnvironment:
     """Create a low-corruption environment using tuned defaults."""
-
     tuned_mim1c = Mim1c(rate=0.01 * difficulty.value)
     tuned_typogre = Typogre(rate=0.025 * difficulty.value)
 
@@ -107,28 +87,19 @@ def tutorial_level(
 
 
 def load_environment(
-    env: str |
+    env: str | VerifierEnvironment,
     glitchlings: Iterable[str | Glitchling] | Glitchling | str | Gaggle | None = None,
     *,
     seed: int = 151,
     columns: Sequence[str] | None = None,
-) ->
+) -> VerifierEnvironment:
     """Load an environment and optionally corrupt it with glitchlings."""
-
     environment = _resolve_environment(env)
 
     if glitchlings is None:
         return environment
 
-
-        gaggle = glitchlings
-    else:
-        if isinstance(glitchlings, (Glitchling, str)):
-            resolved = [glitchlings]
-        else:
-            resolved = list(glitchlings)
-
-        gaggle = summon(resolved, seed=seed)
+    gaggle = coerce_gaggle(glitchlings, seed=seed)
 
     dataset = environment.dataset
     corrupt_columns = _resolve_columns(dataset, columns)
@@ -142,21 +113,11 @@ def _as_gaggle(
     seed: int,
 ) -> Gaggle:
     """Coerce any supported glitchling specification into a :class:`Gaggle`."""
-
-    if isinstance(glitchlings, Gaggle):
-        return glitchlings
-
-    if isinstance(glitchlings, (Glitchling, str)):
-        resolved: Iterable[str | Glitchling] = [glitchlings]
-    else:
-        resolved = glitchlings
-
-    return summon(list(resolved), seed=seed)
+    return coerce_gaggle(glitchlings, seed=seed)
 
 
 def _extract_completion_text(completion: Any) -> str:
     """Normalise a completion payload into a plain string."""
-
     if isinstance(completion, str):
         return completion
 
@@ -175,11 +136,10 @@ def symmetric_damerau_levenshtein_similarity(
     answer: str,
 ) -> float:
     """Return ``1 - (distance / max_len)`` using Damerau-Levenshtein distance."""
-
     completion_text = _extract_completion_text(completion)
     target = answer or ""
     denominator = max(len(completion_text), len(target), 1)
-    distance = damerau_levenshtein_distance(completion_text, target)
+    distance = cast(int, damerau_levenshtein_distance(completion_text, target))
     score = 1.0 - (distance / denominator)
     return max(0.0, min(1.0, score))
 
@@ -199,32 +159,34 @@ def echo_chamber(
     reward_function: Callable[..., float] | None = None,
     split: str | None = None,
     **load_dataset_kwargs: Any,
-) ->
+) -> VerifierSingleTurnEnv:
     """Create an Echo Chamber Prime environment from a Hugging Face dataset column.
 
     Args:
         dataset_id: Identifier of the Hugging Face dataset to load.
         column: Name of the column whose text should be glitched.
         glitchlings: Glitchling specifiers that will corrupt the prompts.
-        seed: RNG seed forwarded to :func:
+        seed: RNG seed forwarded to :func:`glitchlings.util.adapters.coerce_gaggle`.
         instructions: System instructions supplied to the environment prompts.
         reward_function: Optional callable used to score completions. Defaults to
            :func:`symmetric_damerau_levenshtein_similarity` when omitted.
        split: Optional dataset split to load.
        **load_dataset_kwargs: Extra keyword arguments forwarded to
            :func:`datasets.load_dataset`.
-    """
 
-
-
-
+    """
+    datasets_module = require_datasets("datasets is required to build an echo chamber")
+    load_dataset = getattr(datasets_module, "load_dataset", None)
+    if load_dataset is None:  # pragma: no cover - defensive
        message = "datasets is required to build an echo chamber"
-        raise ModuleNotFoundError(message)
+        raise ModuleNotFoundError(message)
 
-
+    dataset_dict_cls = getattr(datasets_module, "DatasetDict", dict)
+
+    hf_dataset: Any
     if split is None:
         hf_dataset = load_dataset(dataset_id, **load_dataset_kwargs)
-        if isinstance(hf_dataset,
+        if isinstance(hf_dataset, dataset_dict_cls):
             try:
                 hf_dataset = next(iter(hf_dataset.values()))
             except StopIteration as exc:  # pragma: no cover - defensive
@@ -232,10 +194,8 @@ def echo_chamber(
     else:
         hf_dataset = load_dataset(dataset_id, split=split, **load_dataset_kwargs)
 
-    if isinstance(hf_dataset,
-        raise ValueError(
-            "Specify which split to use when the dataset loads as a DatasetDict."
-        )
+    if isinstance(hf_dataset, dataset_dict_cls):
+        raise ValueError("Specify which split to use when the dataset loads as a DatasetDict.")
 
     filtered_dataset = hf_dataset.filter(
         lambda row: row.get(column) is not None,
@@ -259,7 +219,7 @@
     )
 
     try:
-        dataset_length = len(base_dataset)
+        dataset_length = len(base_dataset)
     except TypeError:
         preview_rows: list[dict[str, Any]]
         take_fn = getattr(base_dataset, "take", None)
@@ -288,4 +248,7 @@
 
     rubric_func = reward_function or symmetric_damerau_levenshtein_similarity
     rubric = vf.Rubric(funcs=[rubric_func], weights=[1.0])
-    return
+    return cast(
+        VerifierSingleTurnEnv,
+        vf.SingleTurnEnv(dataset=glitched_dataset, rubric=rubric),
+    )
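For a sense of the reward scale used by echo_chamber, symmetric_damerau_levenshtein_similarity clamps 1 - distance / max_len to the range [0, 1]. A worked example, assuming the prime extras (verifiers, jellyfish) are installed:

    from glitchlings.dlc.prime import symmetric_damerau_levenshtein_similarity

    # damerau_levenshtein_distance("kitten", "sitting") is 3 and max_len is 7,
    # so the score is 1 - 3/7, roughly 0.571.
    score = symmetric_damerau_levenshtein_similarity(completion="kitten", answer="sitting")
    print(round(score, 3))  # 0.571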
glitchlings/dlc/pytorch.py
ADDED
@@ -0,0 +1,216 @@
+"""Integration helpers for PyTorch data loaders."""
+
+from __future__ import annotations
+
+from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence
+from typing import Any, cast
+
+from ..compat import get_torch_dataloader, require_torch
+from ..compat import torch as _torch_dependency
+from ..util.adapters import coerce_gaggle
+from ..zoo import Gaggle, Glitchling
+from ..zoo.core import _is_transcript
+
+
+def _normalise_columns(columns: str | int | Sequence[str | int] | None) -> list[str | int] | None:
+    """Normalise a column specification into a list of keys or indices."""
+    if columns is None:
+        return None
+
+    if isinstance(columns, (str, int)):
+        return [columns]
+
+    normalised = list(columns)
+    if not normalised:
+        raise ValueError("At least one column must be specified")
+    return normalised
+
+
+def _is_textual_candidate(value: Any) -> bool:
+    """Return ``True`` when ``value`` looks like text that glitchlings can corrupt."""
+    if isinstance(value, str):
+        return True
+
+    if _is_transcript(value, allow_empty=False, require_all_content=True):
+        return True
+
+    if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray, str)):
+        if not value:
+            return False
+        if all(isinstance(item, str) for item in value):
+            return True
+        if _is_transcript(list(value), allow_empty=False, require_all_content=True):
+            return True
+
+    return False
+
+
+def _corrupt_text(value: Any, gaggle: Gaggle) -> Any:
+    """Return ``value`` with glitchlings applied when possible."""
+    if isinstance(value, str):
+        return gaggle.corrupt(value)
+
+    if _is_transcript(value, allow_empty=True):
+        return gaggle.corrupt(value)
+
+    if isinstance(value, list) and value and all(isinstance(item, str) for item in value):
+        return [gaggle.corrupt(item) for item in value]
+
+    if isinstance(value, tuple) and value and all(isinstance(item, str) for item in value):
+        return tuple(gaggle.corrupt(item) for item in value)
+
+    return value
+
+
+def _apply_to_batch(batch: Any, targets: list[str | int] | None, gaggle: Gaggle) -> Any:
+    """Return ``batch`` with glitchlings applied to the specified ``targets``."""
+    if targets is None:
+        return _corrupt_text(batch, gaggle)
+
+    if isinstance(batch, Mapping):
+        mutated = cast(MutableMapping[str, Any], dict(batch))
+        for key in targets:
+            if not isinstance(key, str):
+                raise TypeError("Mapping batches require string column names")
+            if key not in mutated:
+                raise ValueError(f"Column '{key}' not found in DataLoader batch")
+            mutated[key] = _corrupt_text(mutated[key], gaggle)
+        return mutated
+
+    if isinstance(batch, Sequence) and not isinstance(batch, (bytes, bytearray, str)):
+        mutated_sequence = list(batch)
+        for index in targets:
+            if not isinstance(index, int):
+                raise TypeError("Sequence batches require integer column indices")
+            try:
+                mutated_sequence[index] = _corrupt_text(mutated_sequence[index], gaggle)
+            except IndexError as exc:  # pragma: no cover - defensive
+                raise IndexError("Column index out of range for DataLoader batch") from exc
+        if isinstance(batch, tuple):
+            return tuple(mutated_sequence)
+        return mutated_sequence
+
+    raise TypeError("Unsupported DataLoader batch type for glitching")
+
+
+def _infer_targets(batch: Any) -> list[str | int] | None:
+    """Infer which fields should be glitched from a representative ``batch``."""
+    if isinstance(batch, Mapping):
+        inferred = [key for key, value in batch.items() if _is_textual_candidate(value)]
+        if inferred:
+            return inferred
+        raise ValueError("Unable to infer which mapping columns contain text")
+
+    if isinstance(batch, Sequence) and not isinstance(batch, (bytes, bytearray, str)):
+        inferred_indices: list[str | int] = [
+            idx for idx, value in enumerate(batch) if _is_textual_candidate(value)
+        ]
+        if inferred_indices:
+            return inferred_indices
+        raise ValueError("Unable to infer which sequence indices contain text")
+
+    if _is_textual_candidate(batch):
+        return None
+
+    raise TypeError("Unsupported DataLoader batch type for glitching")
+
+
+class _GlitchedDataLoader(Iterable[Any]):
+    """Wrapper that applies glitchlings lazily to each batch from a data loader."""
+
+    def __init__(
+        self,
+        dataloader: Any,
+        gaggle: Gaggle,
+        *,
+        columns: list[str | int] | None,
+    ) -> None:
+        self._dataloader = dataloader
+        self._gaggle = gaggle
+        self._explicit_columns = columns
+        self._inferred_columns: list[str | int] | None | _Sentinel = _UNINITIALISED
+
+    def __iter__(self) -> Iterator[Any]:
+        # Reset all glitchling RNGs before each fresh pass for determinism.
+        self._gaggle.sort_glitchlings()
+        for batch in self._dataloader:
+            targets = self._resolve_columns(batch)
+            yield _apply_to_batch(batch, targets, self._gaggle)
+
+    def __len__(self) -> int:
+        return len(self._dataloader)
+
+    def __getattr__(self, attribute: str) -> Any:
+        return getattr(self._dataloader, attribute)
+
+    def _resolve_columns(self, batch: Any) -> list[str | int] | None:
+        if self._explicit_columns is not None:
+            return self._explicit_columns
+
+        if self._inferred_columns is _UNINITIALISED:
+            self._inferred_columns = _infer_targets(batch)
+
+        return cast(list[str | int] | None, self._inferred_columns)
+
+
+class _Sentinel:
+    """Sentinel type for deferred column inference."""
+
+
+_UNINITIALISED = _Sentinel()
+
+
+def _ensure_dataloader_class() -> type[Any]:
+    """Return :class:`torch.utils.data.DataLoader` patched with ``.glitch``."""
+    dataloader_cls = get_torch_dataloader()
+    if dataloader_cls is None:
+        require_torch("torch is not installed; install glitchlings[torch]")
+        dataloader_cls = get_torch_dataloader()
+    if dataloader_cls is None:  # pragma: no cover - defensive
+        message = "torch.utils.data.DataLoader is not available"
+        error = _torch_dependency.error
+        if error is not None:
+            raise ModuleNotFoundError(message) from error
+        raise ModuleNotFoundError(message)
+
+    if getattr(dataloader_cls, "glitch", None) is None:
+
+        def glitch(
+            self: Any,
+            glitchlings: Iterable[str | Glitchling] | Glitchling | str | Gaggle,
+            *,
+            columns: str | int | Sequence[str | int] | None = None,
+            seed: int = 151,
+        ) -> _GlitchedDataLoader:
+            """Return a lazily glitched view of the loader's batches."""
+            gaggle = coerce_gaggle(glitchlings, seed=seed)
+            normalised = _normalise_columns(columns)
+            return _GlitchedDataLoader(self, gaggle, columns=normalised)
+
+        setattr(dataloader_cls, "glitch", glitch)
+
+    return cast(type[Any], dataloader_cls)
+
+
+def _optional_dataloader_class() -> type[Any] | None:
+    """Return the PyTorch :class:`~torch.utils.data.DataLoader` when importable."""
+    dataloader_cls = get_torch_dataloader()
+    if dataloader_cls is None:
+        return None
+    return cast(type[Any], dataloader_cls)
+
+
+def install() -> None:
+    """Monkeypatch PyTorch's :class:`~torch.utils.data.DataLoader` with ``.glitch``."""
+    _ensure_dataloader_class()
+
+
+DataLoader: type[Any] | None
+_DataLoaderAlias = _optional_dataloader_class()
+if _DataLoaderAlias is not None:
+    DataLoader = _ensure_dataloader_class()
+else:  # pragma: no cover - torch is an optional dependency
+    DataLoader = None
+
+
+__all__ = ["DataLoader", "install"]
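A usage sketch for the patched DataLoader.glitch; the toy dataset and the glitchling choice are illustrative assumptions, and torch must be installed:

    from torch.utils.data import DataLoader

    from glitchlings.dlc import install_pytorch
    from glitchlings.zoo import Typogre

    install_pytorch()

    rows = [{"text": "the quick brown fox", "label": 0}] * 8
    loader = DataLoader(rows, batch_size=4)

    # Each mapping batch is corrupted lazily as it is yielded; only the "text"
    # column is touched because columns="text" is passed explicitly.
    for batch in loader.glitch(Typogre(rate=0.1), columns="text", seed=151):
        print(batch["text"])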