glitchlings 0.4.4__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of glitchlings might be problematic. Click here for more details.
- glitchlings/__init__.py +67 -0
- glitchlings/__main__.py +8 -0
- glitchlings/_zoo_rust.cp313-win_amd64.pyd +0 -0
- glitchlings/compat.py +284 -0
- glitchlings/config.py +388 -0
- glitchlings/config.toml +3 -0
- glitchlings/dlc/__init__.py +7 -0
- glitchlings/dlc/_shared.py +153 -0
- glitchlings/dlc/huggingface.py +81 -0
- glitchlings/dlc/prime.py +254 -0
- glitchlings/dlc/pytorch.py +166 -0
- glitchlings/dlc/pytorch_lightning.py +215 -0
- glitchlings/lexicon/__init__.py +192 -0
- glitchlings/lexicon/_cache.py +110 -0
- glitchlings/lexicon/data/default_vector_cache.json +82 -0
- glitchlings/lexicon/metrics.py +162 -0
- glitchlings/lexicon/vector.py +651 -0
- glitchlings/lexicon/wordnet.py +232 -0
- glitchlings/main.py +364 -0
- glitchlings/util/__init__.py +195 -0
- glitchlings/util/adapters.py +27 -0
- glitchlings/zoo/__init__.py +168 -0
- glitchlings/zoo/_ocr_confusions.py +32 -0
- glitchlings/zoo/_rate.py +131 -0
- glitchlings/zoo/_rust_extensions.py +143 -0
- glitchlings/zoo/_sampling.py +54 -0
- glitchlings/zoo/_text_utils.py +100 -0
- glitchlings/zoo/adjax.py +128 -0
- glitchlings/zoo/apostrofae.py +127 -0
- glitchlings/zoo/assets/__init__.py +0 -0
- glitchlings/zoo/assets/apostrofae_pairs.json +32 -0
- glitchlings/zoo/core.py +582 -0
- glitchlings/zoo/jargoyle.py +335 -0
- glitchlings/zoo/mim1c.py +109 -0
- glitchlings/zoo/ocr_confusions.tsv +30 -0
- glitchlings/zoo/redactyl.py +193 -0
- glitchlings/zoo/reduple.py +148 -0
- glitchlings/zoo/rushmore.py +153 -0
- glitchlings/zoo/scannequin.py +171 -0
- glitchlings/zoo/typogre.py +231 -0
- glitchlings/zoo/zeedub.py +185 -0
- glitchlings-0.4.4.dist-info/METADATA +627 -0
- glitchlings-0.4.4.dist-info/RECORD +47 -0
- glitchlings-0.4.4.dist-info/WHEEL +5 -0
- glitchlings-0.4.4.dist-info/entry_points.txt +2 -0
- glitchlings-0.4.4.dist-info/licenses/LICENSE +201 -0
- glitchlings-0.4.4.dist-info/top_level.txt +1 -0
glitchlings/config.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
"""Configuration utilities for runtime behaviour and declarative attack setups."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import importlib
|
|
6
|
+
import os
|
|
7
|
+
import warnings
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from io import TextIOBase
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import IO, TYPE_CHECKING, Any, Mapping, Protocol, Sequence, cast
|
|
12
|
+
|
|
13
|
+
from glitchlings.compat import jsonschema
|
|
14
|
+
|
|
15
|
+
try: # Python 3.11+
|
|
16
|
+
import tomllib as _tomllib
|
|
17
|
+
except ModuleNotFoundError: # pragma: no cover - Python < 3.11
|
|
18
|
+
_tomllib = importlib.import_module("tomli")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class _TomllibModule(Protocol):
    """Structural type for the TOML parser module bound to ``tomllib`` in this file."""

    def load(self, fp: IO[bytes]) -> Any:
        """Parse TOML from the binary file object ``fp`` and return the result."""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
tomllib = cast(_TomllibModule, _tomllib)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class _YamlModule(Protocol):
    """Structural type covering the parts of the ``yaml`` module used in this file."""

    # Exception class exposed by the YAML module; it is caught around ``safe_load``.
    YAMLError: type[Exception]

    def safe_load(self, stream: str) -> Any:
        """Parse the YAML text in ``stream`` and return the resulting object."""
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
yaml = cast(_YamlModule, importlib.import_module("yaml"))
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING: # pragma: no cover - typing only
|
|
39
|
+
from .zoo import Gaggle, Glitchling
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
CONFIG_ENV_VAR = "GLITCHLINGS_CONFIG"
|
|
43
|
+
DEFAULT_CONFIG_PATH = Path(__file__).with_name("config.toml")
|
|
44
|
+
DEFAULT_LEXICON_PRIORITY = ["vector", "wordnet"]
|
|
45
|
+
DEFAULT_ATTACK_SEED = 151
|
|
46
|
+
|
|
47
|
+
# JSON Schema for an attack configuration: a mapping with a non-empty
# ``glitchlings`` array and an optional integer ``seed``. Each array item is
# either a non-empty string or an object that requires ``name`` or ``type``.
ATTACK_CONFIG_SCHEMA: dict[str, Any] = {
    "type": "object",
    "required": ["glitchlings"],
    "properties": {
        "glitchlings": {
            "type": "array",
            "minItems": 1,
            "items": {
                "anyOf": [
                    {"type": "string", "minLength": 1},
                    # The two object variants differ only in which key is required.
                    *(
                        {
                            "type": "object",
                            "required": [required_key],
                            "properties": {
                                "name": {"type": "string", "minLength": 1},
                                "type": {"type": "string", "minLength": 1},
                                "parameters": {"type": "object"},
                            },
                            "additionalProperties": True,
                        }
                        for required_key in ("name", "type")
                    ),
                ]
            },
        },
        "seed": {"type": "integer"},
    },
    "additionalProperties": False,
}
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@dataclass(slots=True)
class LexiconConfig:
    """Settings held for the ``lexicon`` section of the runtime configuration."""

    # Ordered lexicon identifiers; each instance gets its own copy of the default list.
    priority: list[str] = field(default_factory=lambda: [*DEFAULT_LEXICON_PRIORITY])
    # Path of the vector cache, or ``None`` when none is configured.
    vector_cache: Path | None = None
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclass(slots=True)
|
|
95
|
+
class RuntimeConfig:
|
|
96
|
+
"""Top-level runtime configuration loaded from ``config.toml``."""
|
|
97
|
+
|
|
98
|
+
lexicon: LexiconConfig
|
|
99
|
+
path: Path
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
_CONFIG: RuntimeConfig | None = None
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def reset_config() -> None:
    """Clear the module-level cached runtime configuration."""
    globals()["_CONFIG"] = None
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def reload_config() -> RuntimeConfig:
    """Discard the cached configuration and return a freshly loaded one."""
    reset_config()
    refreshed = get_config()
    return refreshed
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def get_config() -> RuntimeConfig:
    """Return the runtime configuration, loading and caching it on first use."""
    global _CONFIG
    cached = _CONFIG
    if cached is None:
        # First access since import or reset: load once and remember the result.
        cached = _CONFIG = _load_runtime_config()
    return cached
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _load_runtime_config() -> RuntimeConfig:
    """Build a ``RuntimeConfig`` from the resolved TOML configuration file.

    Raises:
        ValueError: If ``lexicon.priority`` is a string/bytes or not a sequence,
            or if any of its entries stringifies to an empty string.
    """
    config_path = _resolve_config_path()
    validated = _validate_runtime_config_data(_read_toml(config_path), source=config_path)
    lexicon_table = validated.get("lexicon", {})

    raw_priority = lexicon_table.get("priority", DEFAULT_LEXICON_PRIORITY)
    # A bare string is itself a Sequence, so it has to be rejected explicitly.
    if isinstance(raw_priority, (str, bytes)) or not isinstance(raw_priority, Sequence):
        raise ValueError("lexicon.priority must be a sequence of strings.")

    ordered_names = []
    for raw_entry in raw_priority:
        # Entries are coerced with str(); only an empty result is rejected.
        entry_text = str(raw_entry)
        if not entry_text:
            raise ValueError("lexicon.priority entries must be non-empty strings.")
        ordered_names.append(entry_text)

    # A relative cache path is resolved against the config file's directory.
    cache_path = _resolve_optional_path(
        lexicon_table.get("vector_cache"),
        base=config_path.parent,
    )
    return RuntimeConfig(
        lexicon=LexiconConfig(priority=ordered_names, vector_cache=cache_path),
        path=config_path,
    )
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _resolve_config_path() -> Path:
    """Return the path named by the ``CONFIG_ENV_VAR`` variable when it is set and
    non-empty, otherwise ``DEFAULT_CONFIG_PATH``."""
    env_value = os.environ.get(CONFIG_ENV_VAR)
    return Path(env_value) if env_value else DEFAULT_CONFIG_PATH
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _read_toml(path: Path) -> dict[str, Any]:
    """Read ``path`` as TOML and return its top-level table as a ``dict``.

    A missing file yields an empty dict only when ``path`` is the default
    configuration path; any other missing path is an error.

    Raises:
        FileNotFoundError: If a non-default ``path`` does not exist.
        ValueError: If the parsed document is not a mapping.
    """
    if path.exists():
        with path.open("rb") as stream:
            parsed = tomllib.load(stream)
        if not isinstance(parsed, Mapping):
            raise ValueError(f"Configuration file '{path}' must contain a top-level mapping.")
        return dict(parsed)

    if path != DEFAULT_CONFIG_PATH:
        raise FileNotFoundError(f"Configuration file '{path}' not found.")
    return {}
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _validate_runtime_config_data(data: Any, *, source: Path) -> Mapping[str, Any]:
    """Validate the parsed runtime configuration and return it unchanged.

    Args:
        data: Parsed configuration; ``None`` is treated as an empty configuration.
        source: Path of the configuration file, used in error messages.

    Returns:
        ``{}`` when ``data`` is ``None``, otherwise ``data`` itself.

    Raises:
        ValueError: If ``data`` is not a mapping, contains a section other than
            ``lexicon``, if ``lexicon`` is not a mapping or has keys other than
            ``priority``/``vector_cache``, or if ``vector_cache`` is neither a
            string nor an ``os.PathLike``.
    """
    if data is None:
        return {}
    if not isinstance(data, Mapping):
        raise ValueError(f"Configuration file '{source}' must contain a top-level mapping.")

    unknown_sections = sorted(str(name) for name in data if name not in {"lexicon"})
    if unknown_sections:
        extras = ", ".join(unknown_sections)
        raise ValueError(f"Configuration file '{source}' has unsupported sections: {extras}.")

    lexicon_table = data.get("lexicon", {})
    if not isinstance(lexicon_table, Mapping):
        raise ValueError("Configuration 'lexicon' section must be a table.")

    known_keys = {"priority", "vector_cache"}
    unknown_keys = sorted(str(name) for name in lexicon_table if name not in known_keys)
    if unknown_keys:
        extras = ", ".join(unknown_keys)
        raise ValueError(f"Unknown lexicon settings: {extras}.")

    # Only ``vector_cache`` is type-checked here.
    key = "vector_cache"
    cache_value = lexicon_table.get(key)
    if not (cache_value is None or isinstance(cache_value, (str, os.PathLike))):
        raise ValueError(f"lexicon.{key} must be a path or string when provided.")

    return data
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _resolve_optional_path(value: Any, *, base: Path) -> Path | None:
    """Turn an optional path setting into a ``Path``.

    Args:
        value: ``None``, an empty string, a string, or an ``os.PathLike``.
        base: Directory that relative paths are resolved against.

    Returns:
        ``None`` when ``value`` is ``None`` or ``""``; an absolute ``value``
        unchanged; otherwise ``base / value`` after ``Path.resolve()``.
    """
    if value in (None, ""):
        return None

    # Honour the os.PathLike protocol: str() of an arbitrary path-like object is
    # not guaranteed to be its path, whereas os.fsdecode() goes through
    # __fspath__ (and decodes a bytes result). Other values keep the str() fallback.
    raw_path = os.fsdecode(value) if isinstance(value, os.PathLike) else str(value)
    candidate = Path(raw_path)
    if not candidate.is_absolute():
        candidate = (base / candidate).resolve()
    return candidate
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
@dataclass(slots=True)
|
|
214
|
+
class AttackConfig:
|
|
215
|
+
"""Structured representation of a glitchling roster loaded from YAML."""
|
|
216
|
+
|
|
217
|
+
glitchlings: list["Glitchling"]
|
|
218
|
+
seed: int | None = None
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def load_attack_config(
    source: str | Path | TextIOBase,
    *,
    encoding: str = "utf-8",
) -> AttackConfig:
    """Read YAML from a path or text stream and parse it into an ``AttackConfig``.

    Args:
        source: A filesystem path (``str`` or ``Path``) or a text stream.
        encoding: Encoding used when ``source`` is a path.

    Raises:
        ValueError: If the path does not exist.
        TypeError: If ``source`` is neither a path nor a text stream.
    """
    if isinstance(source, TextIOBase):
        # Streams are labelled by their ``name`` attribute when they have one.
        label = getattr(source, "name", "<stream>")
        text = source.read()
    elif isinstance(source, (str, Path)):
        config_path = Path(source)
        label = str(config_path)
        try:
            text = config_path.read_text(encoding=encoding)
        except FileNotFoundError as exc:
            raise ValueError(f"Attack configuration '{label}' was not found.") from exc
    else:
        raise TypeError("Attack configuration source must be a path or text stream.")

    return parse_attack_config(_load_yaml(text, label), source=label)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _validate_attack_config_schema(data: Any, *, source: str) -> Mapping[str, Any]:
    """Check that ``data`` has the shape of an attack configuration.

    Plain-Python structural checks run first. When ``jsonschema.get()`` returns
    a module, ``data`` is additionally validated against ``ATTACK_CONFIG_SCHEMA``.

    Args:
        data: Parsed configuration data.
        source: Label identifying the configuration in error messages.

    Returns:
        ``data`` unchanged once every check has passed.

    Raises:
        ValueError: If ``data`` is empty, is not a mapping, has keys other than
            ``glitchlings``/``seed``, lacks ``glitchlings``, has a non-sequence
            ``glitchlings``, a non-integer ``seed``, a mapping entry without a
            usable ``name``/``type``, non-mapping ``parameters``, or fails
            schema validation.
    """
    if data is None:
        raise ValueError(f"Attack configuration '{source}' is empty.")
    if not isinstance(data, Mapping):
        raise ValueError(f"Attack configuration '{source}' must be a mapping.")

    # Keys are not guaranteed to be strings (YAML allows e.g. integer keys);
    # stringify them so sorting/joining reports the problem instead of raising
    # TypeError.
    unexpected = [str(key) for key in data if key not in {"glitchlings", "seed"}]
    if unexpected:
        extras = ", ".join(sorted(unexpected))
        raise ValueError(f"Attack configuration '{source}' has unsupported fields: {extras}.")

    if "glitchlings" not in data:
        raise ValueError(f"Attack configuration '{source}' must define 'glitchlings'.")

    raw_glitchlings = data["glitchlings"]
    if not isinstance(raw_glitchlings, Sequence) or isinstance(raw_glitchlings, (str, bytes)):
        raise ValueError(f"'glitchlings' in '{source}' must be a sequence.")

    seed = data.get("seed")
    # bool is a subclass of int; reject it explicitly so ``seed: true`` is not
    # accepted as an integer seed.
    if seed is not None and (isinstance(seed, bool) or not isinstance(seed, int)):
        raise ValueError(f"Seed in '{source}' must be an integer if provided.")

    for index, entry in enumerate(raw_glitchlings, start=1):
        if not isinstance(entry, Mapping):
            # Non-mapping entries are not checked here.
            continue
        name_candidate = entry.get("name") or entry.get("type")
        if not isinstance(name_candidate, str) or not name_candidate.strip():
            raise ValueError(f"{source}: glitchling #{index} is missing a 'name'.")
        parameters = entry.get("parameters")
        if parameters is not None and not isinstance(parameters, Mapping):
            raise ValueError(
                f"{source}: glitchling '{name_candidate}' parameters must be a mapping."
            )

    schema_module = jsonschema.get()
    if schema_module is not None:
        try:
            schema_module.validate(instance=data, schema=ATTACK_CONFIG_SCHEMA)
        except schema_module.exceptions.ValidationError as exc:  # pragma: no cover - optional dep
            message = exc.message
            raise ValueError(f"Attack configuration '{source}' is invalid: {message}") from exc

    return data
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def parse_attack_config(data: Any, *, source: str = "<config>") -> AttackConfig:
    """Validate ``data`` and turn it into an ``AttackConfig``.

    Args:
        data: Parsed configuration data.
        source: Label used in error messages.
    """
    validated = _validate_attack_config_schema(data, source=source)

    # Entries are numbered from 1 so error messages match the listing order.
    roster: list["Glitchling"] = [
        _build_glitchling(raw_entry, source, position)
        for position, raw_entry in enumerate(validated["glitchlings"], start=1)
    ]
    return AttackConfig(glitchlings=roster, seed=validated.get("seed"))
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def build_gaggle(config: AttackConfig, *, seed_override: int | None = None) -> "Gaggle":
    """Create a ``Gaggle`` from ``config``.

    The seed is ``seed_override`` when given, else ``config.seed``, else
    ``DEFAULT_ATTACK_SEED``.
    """
    from .zoo import Gaggle  # Imported lazily to avoid circular dependencies

    if seed_override is not None:
        effective_seed = seed_override
    elif config.seed is not None:
        effective_seed = config.seed
    else:
        effective_seed = DEFAULT_ATTACK_SEED

    return Gaggle(config.glitchlings, seed=effective_seed)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _load_yaml(text: str, label: str) -> Any:
    """Parse ``text`` with ``yaml.safe_load``, re-raising YAML errors as ``ValueError``.

    Args:
        text: YAML document to parse.
        label: Name of the configuration, used in the error message.
    """
    try:
        parsed = yaml.safe_load(text)
    except yaml.YAMLError as exc:
        raise ValueError(f"Failed to parse attack configuration '{label}': {exc}") from exc
    return parsed
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def _build_glitchling(entry: Any, source: str, index: int) -> "Glitchling":
    """Create one glitchling from a roster entry.

    Args:
        entry: Either a specification string (handed to
            ``parse_glitchling_spec``) or a mapping with ``name`` (or the
            legacy ``type``) and optional ``parameters``. When ``parameters``
            is absent, every other key of the mapping is passed to the
            glitchling class as a keyword argument.
        source: Label of the configuration, used as the error-message prefix.
        index: Position of the entry, used in error messages.

    Raises:
        ValueError: If the entry is neither a string nor a mapping, has no
            usable name, has non-mapping ``parameters``, names an unknown
            glitchling, or cannot be instantiated with the given arguments.
    """
    if isinstance(entry, str):
        # Project imports happen at the point of use so that malformed entries
        # are rejected without importing ``.zoo``.
        from .zoo import parse_glitchling_spec

        try:
            return parse_glitchling_spec(entry)
        except ValueError as exc:
            raise ValueError(f"{source}: glitchling #{index}: {exc}") from exc

    if not isinstance(entry, Mapping):
        raise ValueError(f"{source}: glitchling #{index} must be a string or mapping.")

    name_value = entry.get("name")
    if name_value is None:
        # Fall back to the legacy ``type`` key, warning only when it is present.
        name_value = entry.get("type")
        if name_value is not None:
            warnings.warn(
                f"{source}: glitchling #{index} uses 'type'; prefer 'name'.",
                DeprecationWarning,
                stacklevel=2,
            )

    if not isinstance(name_value, str) or not name_value.strip():
        raise ValueError(f"{source}: glitchling #{index} is missing a 'name'.")

    parameters = entry.get("parameters")
    if parameters is None:
        # No explicit ``parameters``: the remaining keys become the arguments.
        kwargs = {
            key: value
            for key, value in entry.items()
            if key not in {"name", "type", "parameters"}
        }
    elif isinstance(parameters, Mapping):
        kwargs = dict(parameters)
    else:
        raise ValueError(f"{source}: glitchling '{name_value}' parameters must be a mapping.")

    from .zoo import get_glitchling_class

    try:
        glitchling_type = get_glitchling_class(name_value)
    except ValueError as exc:
        raise ValueError(f"{source}: glitchling #{index}: {exc}") from exc

    try:
        return glitchling_type(**kwargs)
    except TypeError as exc:
        raise ValueError(
            f"{source}: glitchling #{index}: failed to instantiate '{name_value}': {exc}"
        ) from exc
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
__all__ = [
|
|
376
|
+
"AttackConfig",
|
|
377
|
+
"DEFAULT_ATTACK_SEED",
|
|
378
|
+
"DEFAULT_CONFIG_PATH",
|
|
379
|
+
"DEFAULT_LEXICON_PRIORITY",
|
|
380
|
+
"RuntimeConfig",
|
|
381
|
+
"LexiconConfig",
|
|
382
|
+
"build_gaggle",
|
|
383
|
+
"get_config",
|
|
384
|
+
"load_attack_config",
|
|
385
|
+
"parse_attack_config",
|
|
386
|
+
"reload_config",
|
|
387
|
+
"reset_config",
|
|
388
|
+
]
|
glitchlings/config.toml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"""Optional DLC integrations for Glitchlings."""
|
|
2
|
+
|
|
3
|
+
from .huggingface import install as install_huggingface
|
|
4
|
+
from .pytorch import install as install_pytorch
|
|
5
|
+
from .pytorch_lightning import install as install_pytorch_lightning
|
|
6
|
+
|
|
7
|
+
__all__ = ["install_huggingface", "install_pytorch", "install_pytorch_lightning"]
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
"""Shared utilities for DLC integrations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable, Sequence
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from ..zoo.core import Gaggle, _is_transcript
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def resolve_environment(
    env: Any,
    *,
    loader: Callable[[str], Any],
    environment_type: type[Any],
) -> Any:
    """Resolve ``env`` to an instance of ``environment_type``.

    Args:
        env: An environment object, or a string that is passed to ``loader``.
        loader: Callable that turns a string into an environment.
        environment_type: Class the resolved environment must be an instance of.

    Raises:
        TypeError: If the resolved object is not an ``environment_type``.
    """
    resolved = loader(env) if isinstance(env, str) else env
    if isinstance(resolved, environment_type):
        return resolved
    raise TypeError("Invalid environment type")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def resolve_columns(dataset: Any, columns: Sequence[str] | None) -> list[str]:
    """Identify which dataset columns should be corrupted.

    Args:
        dataset: Object that may expose ``column_names`` and supports either
            ``len()`` with indexing, ``take(n)``, or plain iteration.
        columns: Explicit column names, or ``None`` to infer them.

    Returns:
        ``columns`` as a list when given; otherwise ``["prompt"]`` or
        ``["question"]`` if such a column exists; otherwise every known column
        whose value in the first row is a string.

    Raises:
        ValueError: If an explicit column is not among the dataset's column
            names, or if no column can be inferred.
    """
    # ``column_names`` may be present but ``None``; treat that like "no known
    # columns" instead of letting ``set(None)`` raise TypeError.
    column_names = getattr(dataset, "column_names", None) or ()
    available = set(column_names)

    if columns is not None:
        missing = sorted(set(columns) - available)
        if missing:
            missing_str = ", ".join(missing)
            raise ValueError(f"Columns not found in dataset: {missing_str}")
        return list(columns)

    for candidate in ("prompt", "question"):
        if candidate in available:
            return [candidate]

    try:
        dataset_length = len(dataset)
    except TypeError:
        # Unsized dataset: peek at a single row via ``take`` when available,
        # otherwise via the iterator protocol.
        preview_rows: list[dict[str, Any]]
        take_fn = getattr(dataset, "take", None)
        if callable(take_fn):
            preview_rows = list(take_fn(1))
        else:
            iterator = iter(dataset)
            try:
                first_row = next(iterator)
            except StopIteration:
                preview_rows = []
            else:
                preview_rows = [first_row]
        sample = dict(preview_rows[0]) if preview_rows else {}
    else:
        sample = dataset[0] if dataset_length else {}

    inferred = [name for name in column_names if isinstance(sample.get(name), str)]
    if inferred:
        return inferred

    raise ValueError("Unable to determine which dataset columns to corrupt.")
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def normalise_column_spec(
    columns: str | int | Sequence[str | int] | None,
) -> list[str | int] | None:
    """Convert a column specification into a list of keys or indices.

    Args:
        columns: A single key or index, a sequence of them, or ``None``.

    Returns:
        ``None`` for ``None``, a one-element list for a single ``str``/``int``,
        otherwise the sequence copied into a list.

    Raises:
        ValueError: If an empty sequence is provided.
    """
    if columns is None:
        return None

    if not isinstance(columns, (str, int)):
        entries = list(columns)
        if entries:
            return entries
        raise ValueError("At least one column must be specified")

    return [columns]
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def is_textual_candidate(value: Any) -> bool:
    """Report whether ``value`` is something glitchlings can corrupt as text.

    Args:
        value: The value to inspect.

    Returns:
        ``True`` for a string, for a value ``_is_transcript`` accepts (non-empty,
        all content required), or for a non-empty non-bytes sequence that is
        either all strings or an accepted transcript once copied into a list;
        ``False`` otherwise.
    """
    if isinstance(value, str):
        return True

    if _is_transcript(value, allow_empty=False, require_all_content=True):
        return True

    # Byte strings are sequences too, but they are never treated as text here.
    if isinstance(value, (bytes, bytearray, str)) or not isinstance(value, Sequence):
        return False

    if not value:
        return False
    if all(isinstance(item, str) for item in value):
        return True
    if _is_transcript(list(value), allow_empty=False, require_all_content=True):
        return True
    return False
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def corrupt_text_value(value: Any, gaggle: Gaggle) -> Any:
    """Apply ``gaggle`` to ``value`` when it is a supported textual shape.

    Args:
        value: A string, a transcript, or a list/tuple of strings.
        gaggle: The gaggle whose ``corrupt`` method is applied.

    Returns:
        The corrupted string or transcript; a list (or tuple) of corrupted
        items for a non-empty list (or tuple) of strings; otherwise ``value``
        unchanged.
    """
    # ``or`` short-circuits, so strings never reach ``_is_transcript``.
    if isinstance(value, str) or _is_transcript(value, allow_empty=True):
        return gaggle.corrupt(value)

    if isinstance(value, (list, tuple)) and value and all(isinstance(item, str) for item in value):
        corrupted = [gaggle.corrupt(item) for item in value]
        # Lists stay lists; tuples come back as tuples.
        return corrupted if isinstance(value, list) else tuple(corrupted)

    return value
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
__all__ = [
|
|
148
|
+
"corrupt_text_value",
|
|
149
|
+
"is_textual_candidate",
|
|
150
|
+
"normalise_column_spec",
|
|
151
|
+
"resolve_columns",
|
|
152
|
+
"resolve_environment",
|
|
153
|
+
]
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Integration helpers for the Hugging Face datasets library."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Iterable, Sequence
|
|
6
|
+
from typing import Any, cast
|
|
7
|
+
|
|
8
|
+
from ..compat import datasets, get_datasets_dataset, require_datasets
|
|
9
|
+
from ..util.adapters import coerce_gaggle
|
|
10
|
+
from ..zoo import Gaggle, Glitchling
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _normalise_columns(column: str | Sequence[str]) -> list[str]:
    """Return ``column`` as a list of column names.

    Raises:
        ValueError: If ``column`` is an empty sequence.
    """
    if not isinstance(column, str):
        names = list(column)
        if names:
            return names
        raise ValueError("At least one column must be specified")
    return [column]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _glitch_dataset(
    dataset: Any,
    glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
    column: str | Sequence[str],
    *,
    seed: int = 151,
) -> Any:
    """Corrupt the given column(s) of ``dataset``.

    ``glitchlings`` is coerced to a gaggle with ``coerce_gaggle`` (using
    ``seed``) and the result of its ``corrupt_dataset`` call is returned.
    """
    target_columns = _normalise_columns(column)
    return coerce_gaggle(glitchlings, seed=seed).corrupt_dataset(dataset, target_columns)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _ensure_dataset_class() -> Any:
    """Return the ``datasets`` ``Dataset`` class, adding a ``glitch`` method if it has none.

    Raises:
        ModuleNotFoundError: If the dataset class is still unavailable after
            ``require_datasets`` has been called; chained from ``datasets.error``
            when that is set.
    """
    dataset_cls = get_datasets_dataset()
    if dataset_cls is None:  # pragma: no cover - datasets is an install-time dependency
        require_datasets("datasets is not installed")
        dataset_cls = get_datasets_dataset()
        if dataset_cls is None:
            message = "datasets is not installed"
            import_error = datasets.error
            if import_error is None:
                raise ModuleNotFoundError(message)
            raise ModuleNotFoundError(message) from import_error

    # Patch only when the class has no usable ``glitch`` attribute yet.
    if getattr(dataset_cls, "glitch", None) is None:

        def glitch(
            self: Any,
            glitchlings: Glitchling | Gaggle | str | Iterable[str | Glitchling],
            *,
            column: str | Sequence[str],
            seed: int = 151,
            **_: Any,
        ) -> Any:
            """Return a lazily corrupted copy of the dataset."""
            return _glitch_dataset(self, glitchlings, column, seed=seed)

        dataset_cls.glitch = glitch

    return cast(type[Any], dataset_cls)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def install() -> None:
    """Ensure the Hugging Face ``Dataset`` class carries the ``.glitch`` method."""
    # The patched class is returned by the helper but not needed here.
    _ensure_dataset_class()
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
Dataset: type[Any] | None
|
|
74
|
+
_DatasetAlias = get_datasets_dataset()
|
|
75
|
+
if _DatasetAlias is not None:
|
|
76
|
+
Dataset = _ensure_dataset_class()
|
|
77
|
+
else: # pragma: no cover - datasets is an install-time dependency
|
|
78
|
+
Dataset = None
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
__all__ = ["Dataset", "install"]
|