mi_crow-0.1.1.post12-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amber/__init__.py +15 -0
- amber/datasets/__init__.py +11 -0
- amber/datasets/base_dataset.py +640 -0
- amber/datasets/classification_dataset.py +566 -0
- amber/datasets/loading_strategy.py +29 -0
- amber/datasets/text_dataset.py +488 -0
- amber/hooks/__init__.py +20 -0
- amber/hooks/controller.py +171 -0
- amber/hooks/detector.py +95 -0
- amber/hooks/hook.py +218 -0
- amber/hooks/implementations/__init__.py +0 -0
- amber/hooks/implementations/function_controller.py +93 -0
- amber/hooks/implementations/layer_activation_detector.py +96 -0
- amber/hooks/implementations/model_input_detector.py +250 -0
- amber/hooks/implementations/model_output_detector.py +132 -0
- amber/hooks/utils.py +76 -0
- amber/language_model/__init__.py +0 -0
- amber/language_model/activations.py +479 -0
- amber/language_model/context.py +33 -0
- amber/language_model/contracts.py +13 -0
- amber/language_model/hook_metadata.py +38 -0
- amber/language_model/inference.py +525 -0
- amber/language_model/initialization.py +126 -0
- amber/language_model/language_model.py +390 -0
- amber/language_model/layers.py +460 -0
- amber/language_model/persistence.py +177 -0
- amber/language_model/tokenizer.py +203 -0
- amber/language_model/utils.py +97 -0
- amber/mechanistic/__init__.py +0 -0
- amber/mechanistic/sae/__init__.py +0 -0
- amber/mechanistic/sae/autoencoder_context.py +40 -0
- amber/mechanistic/sae/concepts/__init__.py +0 -0
- amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
- amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
- amber/mechanistic/sae/concepts/concept_models.py +9 -0
- amber/mechanistic/sae/concepts/input_tracker.py +68 -0
- amber/mechanistic/sae/modules/__init__.py +5 -0
- amber/mechanistic/sae/modules/l1_sae.py +409 -0
- amber/mechanistic/sae/modules/topk_sae.py +459 -0
- amber/mechanistic/sae/sae.py +166 -0
- amber/mechanistic/sae/sae_trainer.py +604 -0
- amber/mechanistic/sae/training/wandb_logger.py +222 -0
- amber/store/__init__.py +5 -0
- amber/store/local_store.py +437 -0
- amber/store/store.py +276 -0
- amber/store/store_dataloader.py +124 -0
- amber/utils.py +46 -0
- mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
- mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
- mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
- mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
amber/__init__.py
ADDED
@@ -0,0 +1,15 @@
+"""Amber: helper package for the Engineer Thesis project.
+
+This module is intentionally minimal. It exists to define the top-level package
+and to enable code coverage to include the package. Importing it should succeed
+without side effects.
+"""
+
+# A tiny bit of executable code to make the package measurable by coverage.
+PACKAGE_NAME = "amber"
+__version__ = "0.0.0"
+
+
+def ping() -> str:
+    """Return a simple response to verify the package is wired correctly."""
+    return "pong"
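A minimal smoke test of the top-level package (hypothetical usage, not part of the wheel) might look like:

    import amber

    # Import succeeds without side effects and the package is wired correctly.
    assert amber.PACKAGE_NAME == "amber"
    assert amber.ping() == "pong"

Note that __version__ is hard-coded to "0.0.0" and does not track the wheel version (0.1.1.post12).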
amber/datasets/__init__.py
ADDED
@@ -0,0 +1,11 @@
+from amber.datasets.base_dataset import BaseDataset
+from amber.datasets.text_dataset import TextDataset
+from amber.datasets.classification_dataset import ClassificationDataset
+from amber.datasets.loading_strategy import LoadingStrategy
+
+__all__ = [
+    "BaseDataset",
+    "TextDataset",
+    "ClassificationDataset",
+    "LoadingStrategy",
+]
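The subpackage re-exports its public classes, so downstream code can import them flat (hypothetical usage; the LoadingStrategy members are inferred from base_dataset.py below):

    from amber.datasets import ClassificationDataset, LoadingStrategy

    strategy = LoadingStrategy.DISK  # memory-mapped access, see the BaseDataset docstring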
amber/datasets/base_dataset.py
ADDED
@@ -0,0 +1,640 @@
+from __future__ import annotations
+
+import math
+import random
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional, Union
+
+from datasets import Dataset, IterableDataset, load_dataset, load_from_disk
+
+from amber.datasets.loading_strategy import IndexLike, LoadingStrategy
+from amber.store.store import Store
+
+
+class BaseDataset(ABC):
+    """
+    Abstract base class for datasets with support for multiple sources,
+    loading strategies, and Store integration.
+
+    Loading Strategies:
+    - MEMORY: Load entire dataset into memory (fastest random access, highest memory usage)
+    - DISK: Save to disk, read dynamically via memory-mapped Arrow files
+      (supports len/getitem, lower memory usage)
+    - STREAMING: True streaming mode using IterableDataset
+      (lowest memory, no len/getitem support, no stratification or limit support)
+    """
+
+    def __init__(
+        self,
+        ds: Dataset | IterableDataset,
+        store: Store,
+        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+    ):
+        """
+        Initialize dataset.
+
+        Args:
+            ds: HuggingFace Dataset or IterableDataset
+            store: Store instance for caching/persistence
+            loading_strategy: How to load data (MEMORY, DISK, or STREAMING)
+
+        Raises:
+            ValueError: If store is None, loading_strategy is invalid, or dataset operations fail
+            OSError: If file system operations fail
+        """
+        self._validate_initialization_params(store, loading_strategy)
+
+        self._store = store
+        self._loading_strategy = loading_strategy
+        self._dataset_dir: Path = Path(store.base_path) / store.dataset_prefix
+
+        is_iterable_input = isinstance(ds, IterableDataset)
+
+        if loading_strategy == LoadingStrategy.MEMORY:
+            # MEMORY: Convert to Dataset if needed, save to disk, load fully into memory
+            self._is_iterable = False
+            if is_iterable_input:
+                ds = Dataset.from_generator(lambda: iter(ds))
+            self._ds = self._save_and_load_dataset(ds, use_memory_mapping=False)
+        elif loading_strategy == LoadingStrategy.DISK:
+            # DISK: Save to disk, use memory-mapped Arrow files (supports len/getitem)
+            self._is_iterable = False
+            if is_iterable_input:
+                ds = Dataset.from_generator(lambda: iter(ds))
+            self._ds = self._save_and_load_dataset(ds, use_memory_mapping=True)
+        elif loading_strategy == LoadingStrategy.STREAMING:
+            # STREAMING: Convert to IterableDataset, don't save to disk (no len/getitem)
+            if not is_iterable_input:
+                ds = ds.to_iterable_dataset()
+            self._is_iterable = True
+            self._ds = ds
+            # Don't save to disk for iterable-only mode
+        else:
+            raise ValueError(
+                f"Unknown loading strategy: {loading_strategy}. Must be one of: {[s.value for s in LoadingStrategy]}"
+            )
+
+    def _validate_initialization_params(self, store: Store, loading_strategy: LoadingStrategy) -> None:
+        """Validate initialization parameters.
+
+        Args:
+            store: Store instance to validate
+            loading_strategy: Loading strategy to validate
+
+        Raises:
+            ValueError: If store is None or loading_strategy is invalid
+        """
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        if not isinstance(loading_strategy, LoadingStrategy):
+            raise ValueError(f"loading_strategy must be a LoadingStrategy enum value, got: {type(loading_strategy)}")
+
+    def _has_valid_dataset_dir(self) -> bool:
+        """Check if dataset directory path is valid (non-empty base_path).
+
+        Returns:
+            True if base_path is not empty, False otherwise
+        """
+        return bool(self._store.base_path and str(self._store.base_path).strip())
+
+    def _save_and_load_dataset(self, ds: Dataset, use_memory_mapping: bool = True) -> Dataset:
+        """Save dataset to disk and load it back (with optional memory mapping).
+
+        Args:
+            ds: Dataset to save and load
+            use_memory_mapping: Whether to use memory mapping (True for DISK)
+
+        Returns:
+            Loaded dataset
+
+        Raises:
+            OSError: If file system operations fail
+            RuntimeError: If dataset operations fail
+        """
+        if self._has_valid_dataset_dir():
+            try:
+                self._dataset_dir.mkdir(parents=True, exist_ok=True)
+                ds.save_to_disk(str(self._dataset_dir))
+                return load_from_disk(str(self._dataset_dir))
+            except OSError as e:
+                raise OSError(f"Failed to save/load dataset at {self._dataset_dir}. Error: {e}") from e
+            except Exception as e:
+                raise RuntimeError(f"Failed to process dataset at {self._dataset_dir}. Error: {e}") from e
+        else:
+            return ds
+
+    @classmethod
+    def _postprocess_non_streaming_dataset(
+        cls,
+        ds: Dataset,
+        *,
+        filters: Optional[Dict[str, Any]] = None,
+        limit: Optional[int] = None,
+        stratify_by: Optional[str] = None,
+        stratify_seed: Optional[int] = None,
+        drop_na_columns: Optional[List[str]] = None,
+    ) -> Dataset:
+        """Apply filters, stratified sampling, and limits to an in-memory dataset."""
+
+        if drop_na_columns:
+            ds = cls._drop_na(ds, drop_na_columns)
+
+        if filters:
+            ds = cls._apply_filters(ds, filters)
+
+        limit_applied = False
+        if stratify_by:
+            sample_size = limit if limit is not None else len(ds)
+            if sample_size is not None and sample_size <= 0:
+                raise ValueError(f"limit must be > 0 when stratifying, got: {sample_size}")
+            ds = cls._stratified_sample(
+                ds,
+                stratify_by=stratify_by,
+                sample_size=sample_size,
+                seed=stratify_seed,
+            )
+            limit_applied = True
+
+        if limit is not None and not limit_applied:
+            if limit <= 0:
+                raise ValueError(f"limit must be > 0, got: {limit}")
+            ds = ds.select(range(min(limit, len(ds))))
+
+        return ds
+
+    @staticmethod
+    def _drop_na(ds: Dataset, columns: List[str]) -> Dataset:
+        """Drop rows where any of the specified columns are None or empty string."""
+
+        def _is_valid(example: Dict[str, Any]) -> bool:
+            for col in columns:
+                val = example.get(col)
+                if val is None:
+                    return False
+                if isinstance(val, str) and not val.strip():
+                    return False
+            return True
+
+        return ds.filter(_is_valid)
+
+    @staticmethod
+    def _apply_filters(ds: Dataset, filters: Dict[str, Any]) -> Dataset:
+        """Apply exact-match filters to a Dataset."""
+
+        def _predicate(example: Dict[str, Any]) -> bool:
+            return all(example.get(key) == value for key, value in filters.items())
+
+        return ds.filter(_predicate)
+
+    @staticmethod
+    def _stratified_sample(  # noqa: C901
+        ds: Dataset,
+        *,
+        stratify_by: str,
+        sample_size: Optional[int],
+        seed: Optional[int],
+    ) -> Dataset:
+        """Return a stratified sample of the dataset with the requested size."""
+
+        if stratify_by not in ds.column_names:
+            raise ValueError(f"Column '{stratify_by}' not found in dataset columns: {ds.column_names}")
+
+        total_rows = len(ds)
+        if total_rows == 0:
+            return ds
+
+        if sample_size is None:
+            sample_size = total_rows
+
+        sample_size = min(sample_size, total_rows)
+        if sample_size <= 0:
+            raise ValueError("sample_size must be greater than 0 for stratification")
+
+        column_values = ds[stratify_by]
+        label_to_indices: Dict[Any, List[int]] = defaultdict(list)
+        for idx, label in enumerate(column_values):
+            label_to_indices[label].append(idx)
+
+        label_counts = {label: len(indices) for label, indices in label_to_indices.items()}
+        allocations: Dict[Any, int] = {}
+        fractional_parts: List[tuple[float, int, Any]] = []
+
+        allocated_total = 0
+        for order, (label, count) in enumerate(label_counts.items()):
+            exact_allocation = (count / total_rows) * sample_size
+            base_allocation = min(count, int(math.floor(exact_allocation)))
+            allocations[label] = base_allocation
+            allocated_total += base_allocation
+            fractional_parts.append((exact_allocation - base_allocation, order, label))
+
+        remaining = sample_size - allocated_total
+        fractional_parts.sort(key=lambda item: (-item[0], item[1]))
+        for _, _, label in fractional_parts:
+            if remaining <= 0:
+                break
+            available = label_counts[label] - allocations[label]
+            if available <= 0:
+                continue
+            take = min(available, remaining)
+            allocations[label] += take
+            remaining -= take
+
+        rng = random.Random(seed)
+        selected_indices: List[int] = []
+        for label, count in allocations.items():
+            if count <= 0:
+                continue
+            indices = label_to_indices[label]
+            if count >= len(indices):
+                chosen = list(indices)
+            else:
+                chosen = rng.sample(indices, count)
+            selected_indices.extend(chosen)
+
+        rng.shuffle(selected_indices)
+        return ds.select(selected_indices)
+
+    @staticmethod
+    def _load_csv_source(
+        source: Union[str, Path],
+        *,
+        delimiter: str,
+        streaming: bool,
+        **kwargs,
+    ) -> Dataset | IterableDataset:
+        """Load a CSV dataset from disk using HuggingFace datasets."""
+
+        p = Path(source)
+        if not p.exists():
+            raise FileNotFoundError(f"CSV file not found: {source}")
+        if not p.is_file():
+            raise ValueError(f"Source must be a file, got: {source}")
+
+        try:
+            return load_dataset(
+                "csv",
+                data_files=str(p),
+                split="train",
+                delimiter=delimiter,
+                streaming=streaming,
+                **kwargs,
+            )
+        except Exception as e:
+            raise RuntimeError(f"Failed to load CSV dataset from {source}. Error: {e}") from e
+
+    @staticmethod
+    def _load_json_source(
+        source: Union[str, Path],
+        *,
+        streaming: bool,
+        **kwargs,
+    ) -> Dataset | IterableDataset:
+        """Load a JSON/JSONL dataset from disk using HuggingFace datasets."""
+
+        p = Path(source)
+        if not p.exists():
+            raise FileNotFoundError(f"JSON file not found: {source}")
+        if not p.is_file():
+            raise ValueError(f"Source must be a file, got: {source}")
+
+        try:
+            return load_dataset(
+                "json",
+                data_files=str(p),
+                split="train",
+                streaming=streaming,
+                **kwargs,
+            )
+        except Exception as e:
+            raise RuntimeError(f"Failed to load JSON dataset from {source}. Error: {e}") from e
+
+    def get_batch(self, start: int, batch_size: int) -> List[Any]:
+        """
+        Get a contiguous batch of items.
+
+        Args:
+            start: Starting index
+            batch_size: Number of items to retrieve
+
+        Returns:
+            List of items
+
+        Raises:
+            NotImplementedError: If loading_strategy is STREAMING
+        """
+        if self._loading_strategy == LoadingStrategy.STREAMING:
+            raise NotImplementedError("get_batch not supported for STREAMING datasets. Use iter_batches instead.")
+        if batch_size <= 0:
+            return []
+        end = min(start + batch_size, len(self))
+        if start >= end:
+            return []
+        return self[start:end]
+
+    def head(self, n: int = 5) -> List[Any]:
+        """
+        Get first n items.
+
+        Works for all loading strategies.
+
+        Args:
+            n: Number of items to retrieve (default: 5)
+
+        Returns:
+            List of first n items
+        """
+        if self._loading_strategy == LoadingStrategy.STREAMING:
+            items = []
+            for i, item in enumerate(self.iter_items()):
+                if i >= n:
+                    break
+                items.append(item)
+            return items
+        return self[:n]
+
+    def sample(self, n: int = 5) -> List[Any]:
+        """
+        Get n random items from the dataset.
+
+        Works for MEMORY and DISK strategies only.
+
+        Args:
+            n: Number of items to sample
+
+        Returns:
+            List of n randomly sampled items
+
+        Raises:
+            NotImplementedError: If loading_strategy is STREAMING
+        """
+        if self._loading_strategy == LoadingStrategy.STREAMING:
+            raise NotImplementedError(
+                "sample() not supported for STREAMING datasets. Use iter_items() and sample manually."
+            )
+
+        dataset_len = len(self)
+        if n <= 0:
+            return []
+        if n >= dataset_len:
+            # Return all items in random order
+            indices = list(range(dataset_len))
+            random.shuffle(indices)
+            return [self[i] for i in indices]
+
+        # Sample n random indices
+        indices = random.sample(range(dataset_len), n)
+        # Use __getitem__ with list of indices
+        return self[indices]
+
+    @property
+    def is_streaming(self) -> bool:
+        """Whether this dataset is streaming (DISK or STREAMING)."""
+        return self._loading_strategy in (LoadingStrategy.DISK, LoadingStrategy.STREAMING)
+
+    @abstractmethod
+    def __len__(self) -> int:
+        """Return the number of items in the dataset."""
+        pass
+
+    @abstractmethod
+    def __getitem__(self, idx: IndexLike) -> Any:
+        """Get item(s) by index."""
+        pass
+
+    @abstractmethod
+    def iter_items(self) -> Iterator[Any]:
+        """Iterate over items one by one."""
+        pass
+
+    @abstractmethod
+    def iter_batches(self, batch_size: int) -> Iterator[List[Any]]:
+        """Iterate over items in batches."""
+        pass
+
+    @abstractmethod
+    def extract_texts_from_batch(self, batch: List[Any]) -> List[str]:
+        """Extract text strings from a batch.
+
+        Args:
+            batch: A batch as returned by iter_batches()
+
+        Returns:
+            List of text strings ready for model inference
+        """
+        pass
+
+    @abstractmethod
+    def get_all_texts(self) -> List[str]:
+        """Get all texts from the dataset.
+
+        Returns:
+            List of all text strings in the dataset
+
+        Raises:
+            NotImplementedError: If not supported for streaming datasets
+        """
+        pass
+
+    # --- Factory methods ---
+
+    @classmethod
+    def from_huggingface(
+        cls,
+        repo_id: str,
+        store: Store,
+        *,
+        split: str = "train",
+        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+        revision: Optional[str] = None,
+        streaming: Optional[bool] = None,
+        filters: Optional[Dict[str, Any]] = None,
+        limit: Optional[int] = None,
+        stratify_by: Optional[str] = None,
+        stratify_seed: Optional[int] = None,
+        **kwargs,
+    ) -> "BaseDataset":
+        """
+        Load dataset from HuggingFace Hub.
+
+        Args:
+            repo_id: HuggingFace dataset repository ID
+            store: Store instance
+            split: Dataset split (e.g., "train", "validation")
+            loading_strategy: Loading strategy (MEMORY, DISK, or STREAMING)
+            revision: Optional git revision/branch/tag
+            streaming: Optional override for streaming (if None, uses loading_strategy)
+            filters: Optional dict of column->value pairs used for exact-match filtering
+            limit: Optional maximum number of rows to keep (applied after filtering/stratification)
+            stratify_by: Optional column to use for stratified sampling (non-streaming only)
+            stratify_seed: Optional RNG seed for deterministic stratification
+            **kwargs: Additional arguments passed to load_dataset
+
+        Returns:
+            BaseDataset instance
+
+        Raises:
+            ValueError: If repo_id is empty or store is None
+            RuntimeError: If dataset loading fails
+        """
+        if not repo_id or not isinstance(repo_id, str) or not repo_id.strip():
+            raise ValueError(f"repo_id must be a non-empty string, got: {repo_id!r}")
+
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        # Determine if we should use streaming for HuggingFace load_dataset
+        use_streaming = streaming if streaming is not None else (loading_strategy == LoadingStrategy.STREAMING)
+
+        if stratify_by and loading_strategy == LoadingStrategy.STREAMING:
+            raise NotImplementedError("Stratification is not supported for STREAMING datasets.")
+
+        try:
+            ds = load_dataset(
+                path=repo_id,
+                split=split,
+                revision=revision,
+                streaming=use_streaming,
+                **kwargs,
+            )
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to load dataset from HuggingFace Hub: repo_id={repo_id!r}, "
+                f"split={split!r}, revision={revision!r}. Error: {e}"
+            ) from e
+
+        if use_streaming:
+            if filters or limit or stratify_by:
+                raise NotImplementedError(
+                    "filters, limit, and stratification are not supported when streaming datasets. "
+                    "Choose MEMORY or DISK loading strategy instead."
+                )
+        else:
+            ds = cls._postprocess_non_streaming_dataset(
+                ds,
+                filters=filters,
+                limit=limit,
+                stratify_by=stratify_by,
+                stratify_seed=stratify_seed,
+            )
+
+        return cls(ds, store=store, loading_strategy=loading_strategy)
+
+    @classmethod
+    def from_csv(
+        cls,
+        source: Union[str, Path],
+        store: Store,
+        *,
+        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+        text_field: str = "text",
+        delimiter: str = ",",
+        stratify_by: Optional[str] = None,
+        stratify_seed: Optional[int] = None,
+        drop_na_columns: Optional[List[str]] = None,
+        **kwargs,
+    ) -> "BaseDataset":
+        """
+        Load dataset from CSV file.
+
+        Args:
+            source: Path to CSV file
+            store: Store instance
+            loading_strategy: Loading strategy
+            text_field: Name of the column containing text
+            delimiter: CSV delimiter (default: comma)
+            stratify_by: Optional column used for stratified sampling (non-streaming only)
+            stratify_seed: Optional RNG seed for stratified sampling
+            drop_na_columns: Optional list of columns to check for None/empty values
+            **kwargs: Additional arguments passed to load_dataset
+
+        Returns:
+            BaseDataset instance
+
+        Raises:
+            FileNotFoundError: If CSV file doesn't exist
+            ValueError: If store is None or source is invalid
+            RuntimeError: If dataset loading fails
+        """
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        use_streaming = loading_strategy == LoadingStrategy.STREAMING
+        if (stratify_by or drop_na_columns) and use_streaming:
+            raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")
+
+        ds = cls._load_csv_source(
+            source,
+            delimiter=delimiter,
+            streaming=use_streaming,
+            **kwargs,
+        )
+
+        if not use_streaming and (stratify_by or drop_na_columns):
+            ds = cls._postprocess_non_streaming_dataset(
+                ds,
+                stratify_by=stratify_by,
+                stratify_seed=stratify_seed,
+                drop_na_columns=drop_na_columns,
+            )
+
+        return cls(ds, store=store, loading_strategy=loading_strategy)
+
+    @classmethod
+    def from_json(
+        cls,
+        source: Union[str, Path],
+        store: Store,
+        *,
+        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+        text_field: str = "text",
+        stratify_by: Optional[str] = None,
+        stratify_seed: Optional[int] = None,
+        drop_na_columns: Optional[List[str]] = None,
+        **kwargs,
+    ) -> "BaseDataset":
+        """
+        Load dataset from JSON or JSONL file.
+
+        Args:
+            source: Path to JSON or JSONL file
+            store: Store instance
+            loading_strategy: Loading strategy
+            text_field: Name of the field containing text (for JSON objects)
+            stratify_by: Optional column used for stratified sampling (non-streaming only)
+            stratify_seed: Optional RNG seed for stratified sampling
+            drop_na_columns: Optional list of columns to check for None/empty values
+            **kwargs: Additional arguments passed to load_dataset
+
+        Returns:
+            BaseDataset instance
+
+        Raises:
+            FileNotFoundError: If JSON file doesn't exist
+            ValueError: If store is None or source is invalid
+            RuntimeError: If dataset loading fails
+        """
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        use_streaming = loading_strategy == LoadingStrategy.STREAMING
+        if (stratify_by or drop_na_columns) and use_streaming:
+            raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")
+
+        ds = cls._load_json_source(
+            source,
+            streaming=use_streaming,
+            **kwargs,
+        )
+
+        if not use_streaming and (stratify_by or drop_na_columns):
+            ds = cls._postprocess_non_streaming_dataset(
+                ds,
+                stratify_by=stratify_by,
+                stratify_seed=stratify_seed,
+                drop_na_columns=drop_na_columns,
+            )
+
+        return cls(ds, store=store, loading_strategy=loading_strategy)
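To tie the factory methods and loading strategies together, here is a hedged usage sketch. It assumes that TextDataset (amber/datasets/text_dataset.py, not shown in this diff) implements the abstract methods, and that amber.store exposes a LocalStore (amber/store/local_store.py is listed in the wheel but its contents are not shown), so the store construction below is hypothetical:

    from amber.datasets import TextDataset, LoadingStrategy
    from amber.store import LocalStore  # assumed export; see amber/store/__init__.py

    store = LocalStore("/tmp/amber-store")  # hypothetical constructor arguments

    # MEMORY (default): full random access; stratified sampling keeps label
    # proportions via the largest-remainder allocation in _stratified_sample.
    ds = TextDataset.from_csv(
        "reviews.csv",
        store,
        text_field="text",
        stratify_by="label",
        stratify_seed=42,
        drop_na_columns=["text"],
    )
    print(len(ds), ds.head(3))

    # STREAMING: lowest memory; get_batch(), sample(), filters, limit, and
    # stratification all raise NotImplementedError, so iterate in batches.
    stream = TextDataset.from_huggingface(
        "imdb",
        store,
        loading_strategy=LoadingStrategy.STREAMING,
    )
    for batch in stream.iter_batches(batch_size=32):
        texts = stream.extract_texts_from_batch(batch)
        break

One design note grounded in the code above: every non-streaming dataset is round-tripped through save_to_disk/load_from_disk into the Store's dataset directory, so MEMORY and DISK differ only in whether the reload is fully materialized or memory-mapped.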