mi-crow 0.1.1.post12__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- amber/__init__.py +15 -0
- amber/datasets/__init__.py +11 -0
- amber/datasets/base_dataset.py +640 -0
- amber/datasets/classification_dataset.py +566 -0
- amber/datasets/loading_strategy.py +29 -0
- amber/datasets/text_dataset.py +488 -0
- amber/hooks/__init__.py +20 -0
- amber/hooks/controller.py +171 -0
- amber/hooks/detector.py +95 -0
- amber/hooks/hook.py +218 -0
- amber/hooks/implementations/__init__.py +0 -0
- amber/hooks/implementations/function_controller.py +93 -0
- amber/hooks/implementations/layer_activation_detector.py +96 -0
- amber/hooks/implementations/model_input_detector.py +250 -0
- amber/hooks/implementations/model_output_detector.py +132 -0
- amber/hooks/utils.py +76 -0
- amber/language_model/__init__.py +0 -0
- amber/language_model/activations.py +479 -0
- amber/language_model/context.py +33 -0
- amber/language_model/contracts.py +13 -0
- amber/language_model/hook_metadata.py +38 -0
- amber/language_model/inference.py +525 -0
- amber/language_model/initialization.py +126 -0
- amber/language_model/language_model.py +390 -0
- amber/language_model/layers.py +460 -0
- amber/language_model/persistence.py +177 -0
- amber/language_model/tokenizer.py +203 -0
- amber/language_model/utils.py +97 -0
- amber/mechanistic/__init__.py +0 -0
- amber/mechanistic/sae/__init__.py +0 -0
- amber/mechanistic/sae/autoencoder_context.py +40 -0
- amber/mechanistic/sae/concepts/__init__.py +0 -0
- amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
- amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
- amber/mechanistic/sae/concepts/concept_models.py +9 -0
- amber/mechanistic/sae/concepts/input_tracker.py +68 -0
- amber/mechanistic/sae/modules/__init__.py +5 -0
- amber/mechanistic/sae/modules/l1_sae.py +409 -0
- amber/mechanistic/sae/modules/topk_sae.py +459 -0
- amber/mechanistic/sae/sae.py +166 -0
- amber/mechanistic/sae/sae_trainer.py +604 -0
- amber/mechanistic/sae/training/wandb_logger.py +222 -0
- amber/store/__init__.py +5 -0
- amber/store/local_store.py +437 -0
- amber/store/store.py +276 -0
- amber/store/store_dataloader.py +124 -0
- amber/utils.py +46 -0
- mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
- mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
- mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
- mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
@@ -0,0 +1,566 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
+
+from datasets import Dataset, IterableDataset, load_dataset
+
+from amber.datasets.base_dataset import BaseDataset
+from amber.datasets.loading_strategy import IndexLike, LoadingStrategy
+from amber.store.store import Store
+
+
+class ClassificationDataset(BaseDataset):
+    """
+    Classification dataset with text and category/label columns.
+    Each item is a dict with 'text' and label column(s) as keys.
+    Supports single or multiple label columns.
+    """
+
+    def __init__(
+        self,
+        ds: Dataset | IterableDataset,
+        store: Store,
+        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+        text_field: str = "text",
+        category_field: Union[str, List[str]] = "category",
+    ):
+        """
+        Initialize classification dataset.
+
+        Args:
+            ds: HuggingFace Dataset or IterableDataset
+            store: Store instance
+            loading_strategy: Loading strategy
+            text_field: Name of the column containing text
+            category_field: Name(s) of the column(s) containing category/label.
+                Can be a single string or a list of strings for multiple labels.
+
+        Raises:
+            ValueError: If text_field or category_field is empty, or fields not found in dataset
+        """
+        self._validate_text_field(text_field)
+
+        # Normalize category_field to list
+        if isinstance(category_field, str):
+            self._category_fields = [category_field]
+        else:
+            self._category_fields = list(category_field)
+
+        self._validate_category_fields(self._category_fields)
+
+        # Validate dataset
+        is_iterable = isinstance(ds, IterableDataset)
+        if not is_iterable:
+            if text_field not in ds.column_names:
+                raise ValueError(f"Dataset must have a '{text_field}' column; got columns: {ds.column_names}")
+            for cat_field in self._category_fields:
+                if cat_field not in ds.column_names:
+                    raise ValueError(f"Dataset must have a '{cat_field}' column; got columns: {ds.column_names}")
+            # Set format with all required columns
+            format_columns = [text_field] + self._category_fields
+            ds.set_format("python", columns=format_columns)
+
+        self._text_field = text_field
+        self._category_field = category_field  # Keep original for backward compatibility
+        super().__init__(ds, store=store, loading_strategy=loading_strategy)
+
+    def _validate_text_field(self, text_field: str) -> None:
+        """Validate text_field parameter.
+
+        Args:
+            text_field: Text field name to validate
+
+        Raises:
+            ValueError: If text_field is empty or not a string
+        """
+        if not text_field or not isinstance(text_field, str) or not text_field.strip():
+            raise ValueError(f"text_field must be a non-empty string, got: {text_field!r}")
+
+    def _validate_category_fields(self, category_fields: List[str]) -> None:
+        """Validate category_fields parameter.
+
+        Args:
+            category_fields: List of category field names to validate
+
+        Raises:
+            ValueError: If category_fields is empty or contains invalid values
+        """
+        if not category_fields:
+            raise ValueError("category_field cannot be empty")
+
+        for cat_field in category_fields:
+            if not cat_field or not isinstance(cat_field, str) or not cat_field.strip():
+                raise ValueError(f"All category fields must be non-empty strings, got invalid field: {cat_field!r}")
+
+    def _extract_item_from_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
+        """Extract item (text + categories) from a dataset row.
+
+        Args:
+            row: Dataset row dictionary
+
+        Returns:
+            Dictionary with 'text' and category fields as keys
+
+        Raises:
+            ValueError: If required fields are not found in row
+        """
+        if self._text_field in row:
+            text = row[self._text_field]
+        elif "text" in row:
+            text = row["text"]
+        else:
+            raise ValueError(
+                f"Text field '{self._text_field}' or 'text' not found in dataset row. "
+                f"Available fields: {list(row.keys())}"
+            )
+
+        item = {"text": text}
+        for cat_field in self._category_fields:
+            if cat_field not in row:
+                raise ValueError(
+                    f"Category field '{cat_field}' not found in dataset row. "
+                    f"Available fields: {list(row.keys())}"
+                )
+            category = row.get(cat_field)  # Potentially None; only a missing key is an error
+            item[cat_field] = category
+
+        return item
+
+    def __len__(self) -> int:
+        """
+        Return the number of items in the dataset.
+
+        Raises:
+            NotImplementedError: If loading_strategy is STREAMING
+        """
+        if self._loading_strategy == LoadingStrategy.STREAMING:
+            raise NotImplementedError("len() not supported for STREAMING datasets")
+        return self._ds.num_rows
+
+    def __getitem__(self, idx: IndexLike) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
+        """
+        Get item(s) by index. Returns dict with 'text' and label column(s) as keys.
+
+        For single label: {"text": "...", "category": "..."}
+        For multiple labels: {"text": "...", "label1": "...", "label2": "..."}
+
+        Args:
+            idx: Index (int), slice, or sequence of indices
+
+        Returns:
+            Single item dict or list of item dicts
+
+        Raises:
+            NotImplementedError: If loading_strategy is STREAMING
+            IndexError: If index is out of bounds
+            ValueError: If dataset is empty
+        """
+        if self._loading_strategy == LoadingStrategy.STREAMING:
+            raise NotImplementedError(
+                "Indexing not supported for STREAMING datasets. Use iter_items or iter_batches."
+            )
+
+        dataset_len = len(self)
+        if dataset_len == 0:
+            raise ValueError("Cannot index into empty dataset")
+
+        if isinstance(idx, int):
+            if idx < 0:
+                idx = dataset_len + idx
+            if idx < 0 or idx >= dataset_len:
+                raise IndexError(f"Index {idx} out of bounds for dataset of length {dataset_len}")
+            row = self._ds[idx]
+            return self._extract_item_from_row(row)
+
+        if isinstance(idx, slice):
+            start, stop, step = idx.indices(dataset_len)
+            if step != 1:
+                indices = list(range(start, stop, step))
+                selected = self._ds.select(indices)
+            else:
+                selected = self._ds.select(range(start, stop))
+            return [self._extract_item_from_row(row) for row in selected]
+
+        if isinstance(idx, Sequence):
+            # Validate all indices are in bounds
+            invalid_indices = [i for i in idx if not (0 <= i < dataset_len)]
+            if invalid_indices:
+                raise IndexError(f"Indices out of bounds: {invalid_indices} (dataset length: {dataset_len})")
+            selected = self._ds.select(list(idx))
+            return [self._extract_item_from_row(row) for row in selected]
+
+        raise TypeError(f"Invalid index type: {type(idx)}")
+
+    def iter_items(self) -> Iterator[Dict[str, Any]]:
+        """
+        Iterate over items one by one. Yields dict with 'text' and label column(s) as keys.
+
+        For single label: {"text": "...", "category_column_1": "..."}
+        For multiple labels: {"text": "...", "category_column_1": "...", "category_column_2": "..."}
+
+        Yields:
+            Item dictionaries with text and category fields
+
+        Raises:
+            ValueError: If required fields are not found in any row
+        """
+        for row in self._ds:
+            yield self._extract_item_from_row(row)
+
+    def iter_batches(self, batch_size: int) -> Iterator[List[Dict[str, Any]]]:
+        """
+        Iterate over items in batches. Each batch is a list of dicts with 'text' and label column(s) as keys.
+
+        For single label: [{"text": "...", "category_column_1": "..."}, ...]
+        For multiple labels: [{"text": "...", "category_column_1": "...", "category_column_2": "..."}, ...]
+
+        Args:
+            batch_size: Number of items per batch
+
+        Yields:
+            Lists of item dictionaries (batches)
+
+        Raises:
+            ValueError: If batch_size <= 0 or required fields are not found in any row
+        """
+        if batch_size <= 0:
+            raise ValueError(f"batch_size must be > 0, got: {batch_size}")
+
+        if self._loading_strategy == LoadingStrategy.STREAMING:
+            batch = []
+            for row in self._ds:
+                batch.append(self._extract_item_from_row(row))
+                if len(batch) >= batch_size:
+                    yield batch
+                    batch = []
+            if batch:
+                yield batch
+        else:
+            # Non-streaming: slice into batches (slicing delegates to select via __getitem__)
+            for i in range(0, len(self), batch_size):
+                end = min(i + batch_size, len(self))
+                batch_list = self[i:end]
+                yield batch_list
+
+    def get_categories(self) -> Union[List[Any], Dict[str, List[Any]]]:  # noqa: C901
+        """
+        Get unique categories in the dataset, excluding None values.
+
+        Returns:
+            - For single label column: List of unique category values
+            - For multiple label columns: Dict mapping column name to list of unique categories
+
+        Note:
+            For STREAMING datasets this requires a full pass over the data.
+        """
+        if len(self._category_fields) == 1:
+            # Single label: return list for backward compatibility
+            cat_field = self._category_fields[0]
+            if self._loading_strategy == LoadingStrategy.STREAMING:
+                categories = set()
+                for item in self.iter_items():
+                    cat = item[cat_field]
+                    if cat is not None:
+                        categories.add(cat)
+                return sorted(list(categories))  # noqa: C414
+            categories = [cat for cat in set(self._ds[cat_field]) if cat is not None]
+            return sorted(categories)
+        else:
+            # Multiple labels: return dict
+            result = {}
+            if self._loading_strategy == LoadingStrategy.STREAMING:
+                # Collect categories from all items
+                category_sets = {field: set() for field in self._category_fields}
+                for item in self.iter_items():
+                    for field in self._category_fields:
+                        cat = item[field]
+                        if cat is not None:
+                            category_sets[field].add(cat)
+                for field in self._category_fields:
+                    result[field] = sorted(list(category_sets[field]))  # noqa: C414
+            else:
+                # Use direct column access
+                for field in self._category_fields:
+                    categories = [cat for cat in set(self._ds[field]) if cat is not None]
+                    result[field] = sorted(categories)
+            return result
+
+    def extract_texts_from_batch(self, batch: List[Dict[str, Any]]) -> List[Optional[str]]:
+        """Extract text strings from a batch of classification items.
+
+        Args:
+            batch: List of dicts with 'text' and category fields
+
+        Returns:
+            List of text strings from the batch
+
+        Raises:
+            ValueError: If 'text' key is not found in any batch item
+        """
+        texts = []
+        for item in batch:
+            if "text" not in item:
+                raise ValueError(f"'text' key not found in batch item. Available keys: {list(item.keys())}")
+            texts.append(item["text"])
+        return texts
+
+    def get_all_texts(self) -> List[Optional[str]]:
+        """Get all texts from the dataset.
+
+        Returns:
+            List of all text strings
+
+        Note:
+            For STREAMING datasets this requires a full pass over the data.
+        """
+        if self._loading_strategy == LoadingStrategy.STREAMING:
+            return [item["text"] for item in self.iter_items()]
+        return list(self._ds[self._text_field])
+
+    def get_categories_for_texts(self, texts: List[Optional[str]]) -> Union[List[Any], List[Dict[str, Any]]]:
+        """
+        Get categories for given texts (if texts match dataset texts).
+
+        Args:
+            texts: List of text strings to look up
+
+        Returns:
+            - For single label column: List of category values (one per text)
+            - For multiple label columns: List of dicts with label columns as keys
+
+        Raises:
+            NotImplementedError: If loading_strategy is STREAMING
+            ValueError: If texts list is empty
+        """
+        if self._loading_strategy == LoadingStrategy.STREAMING:
+            raise NotImplementedError("get_categories_for_texts not supported for STREAMING datasets")
+
+        if not texts:
+            raise ValueError("texts list cannot be empty")
+
+        if len(self._category_fields) == 1:
+            # Single label: return list for backward compatibility
+            cat_field = self._category_fields[0]
+            text_to_category = {row[self._text_field]: row[cat_field] for row in self._ds}
+            return [text_to_category.get(text) for text in texts]
+        else:
+            # Multiple labels: return list of dicts
+            text_to_categories = {
+                row[self._text_field]: {field: row[field] for field in self._category_fields} for row in self._ds
+            }
+            return [text_to_categories.get(text) for text in texts]
+
+    @classmethod
+    def from_huggingface(
+        cls,
+        repo_id: str,
+        store: Store,
+        *,
+        split: str = "train",
+        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+        revision: Optional[str] = None,
+        text_field: str = "text",
+        category_field: Union[str, List[str]] = "category",
+        filters: Optional[Dict[str, Any]] = None,
+        limit: Optional[int] = None,
+        stratify_by: Optional[str] = None,
+        stratify_seed: Optional[int] = None,
+        streaming: Optional[bool] = None,
+        drop_na: bool = False,
+        **kwargs,
+    ) -> "ClassificationDataset":
+        """
+        Load classification dataset from HuggingFace Hub.
+
+        Args:
+            repo_id: HuggingFace dataset repository ID
+            store: Store instance
+            split: Dataset split
+            loading_strategy: Loading strategy
+            revision: Optional git revision
+            text_field: Name of the column containing text
+            category_field: Name(s) of the column(s) containing category/label
+            filters: Optional filters to apply (dict of column: value)
+            limit: Optional limit on number of rows
+            stratify_by: Optional column used for stratified sampling (non-streaming only)
+            stratify_seed: Optional RNG seed for stratified sampling
+            streaming: Optional override for streaming
+            drop_na: Whether to drop rows with None/empty text or categories
+            **kwargs: Additional arguments for load_dataset
+
+        Returns:
+            ClassificationDataset instance
+
+        Raises:
+            ValueError: If parameters are invalid
+            RuntimeError: If dataset loading fails
+        """
+        use_streaming = streaming if streaming is not None else (loading_strategy == LoadingStrategy.STREAMING)
+
+        if (stratify_by or drop_na) and use_streaming:
+            raise NotImplementedError(
+                "Stratification and drop_na are not supported for streaming datasets. Use MEMORY or DISK."
+            )
+
+        try:
+            ds = load_dataset(
+                path=repo_id,
+                split=split,
+                revision=revision,
+                streaming=use_streaming,
+                **kwargs,
+            )
+
+            if use_streaming:
+                if filters or limit:
+                    raise NotImplementedError(
+                        "filters and limit are not supported when streaming datasets. Choose MEMORY or DISK."
+                    )
+            else:
+                drop_na_columns = None
+                if drop_na:
+                    cat_fields = [category_field] if isinstance(category_field, str) else category_field
+                    drop_na_columns = [text_field] + list(cat_fields)
+
+                ds = cls._postprocess_non_streaming_dataset(
+                    ds,
+                    filters=filters,
+                    limit=limit,
+                    stratify_by=stratify_by,
+                    stratify_seed=stratify_seed,
+                    drop_na_columns=drop_na_columns,
+                )
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to load classification dataset from HuggingFace Hub: "
+                f"repo_id={repo_id!r}, split={split!r}, text_field={text_field!r}, "
+                f"category_field={category_field!r}. Error: {e}"
+            ) from e
+
+        return cls(
+            ds,
+            store=store,
+            loading_strategy=loading_strategy,
+            text_field=text_field,
+            category_field=category_field,
+        )
+
+    @classmethod
+    def from_csv(
+        cls,
+        source: Union[str, Path],
+        store: Store,
+        *,
+        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+        text_field: str = "text",
+        category_field: Union[str, List[str]] = "category",
+        delimiter: str = ",",
+        stratify_by: Optional[str] = None,
+        stratify_seed: Optional[int] = None,
+        drop_na: bool = False,
+        **kwargs,
+    ) -> "ClassificationDataset":
+        """
+        Load classification dataset from CSV file.
+
+        Args:
+            source: Path to CSV file
+            store: Store instance
+            loading_strategy: Loading strategy
+            text_field: Name of the column containing text
+            category_field: Name(s) of the column(s) containing category/label
+            delimiter: CSV delimiter (default: comma)
+            stratify_by: Optional column used for stratified sampling
+            stratify_seed: Optional RNG seed for stratified sampling
+            drop_na: Whether to drop rows with None/empty text or categories
+            **kwargs: Additional arguments for load_dataset
+
+        Returns:
+            ClassificationDataset instance
+
+        Raises:
+            FileNotFoundError: If CSV file doesn't exist
+            RuntimeError: If dataset loading fails
+        """
+        drop_na_columns = None
+        if drop_na:
+            cat_fields = [category_field] if isinstance(category_field, str) else category_field
+            drop_na_columns = [text_field] + list(cat_fields)
+
+        dataset = super().from_csv(
+            source,
+            store=store,
+            loading_strategy=loading_strategy,
+            text_field=text_field,
+            delimiter=delimiter,
+            stratify_by=stratify_by,
+            stratify_seed=stratify_seed,
+            drop_na_columns=drop_na_columns,
+            **kwargs,
+        )
+        return cls(
+            dataset._ds,
+            store=store,
+            loading_strategy=loading_strategy,
+            text_field=text_field,
+            category_field=category_field,
+        )
+
+    @classmethod
+    def from_json(
+        cls,
+        source: Union[str, Path],
+        store: Store,
+        *,
+        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+        text_field: str = "text",
+        category_field: Union[str, List[str]] = "category",
+        stratify_by: Optional[str] = None,
+        stratify_seed: Optional[int] = None,
+        drop_na: bool = False,
+        **kwargs,
+    ) -> "ClassificationDataset":
+        """
+        Load classification dataset from JSON/JSONL file.
+
+        Args:
+            source: Path to JSON or JSONL file
+            store: Store instance
+            loading_strategy: Loading strategy
+            text_field: Name of the field containing text
+            category_field: Name(s) of the field(s) containing category/label
+            stratify_by: Optional column used for stratified sampling
+            stratify_seed: Optional RNG seed for stratified sampling
+            drop_na: Whether to drop rows with None/empty text or categories
+            **kwargs: Additional arguments for load_dataset
+
+        Returns:
+            ClassificationDataset instance
+
+        Raises:
+            FileNotFoundError: If JSON file doesn't exist
+            RuntimeError: If dataset loading fails
+        """
+        drop_na_columns = None
+        if drop_na:
+            cat_fields = [category_field] if isinstance(category_field, str) else category_field
+            drop_na_columns = [text_field] + list(cat_fields)
+
+        dataset = super().from_json(
+            source,
+            store=store,
+            loading_strategy=loading_strategy,
+            text_field=text_field,
+            stratify_by=stratify_by,
+            stratify_seed=stratify_seed,
+            drop_na_columns=drop_na_columns,
+            **kwargs,
+        )
+        return cls(
+            dataset._ds,
+            store=store,
+            loading_strategy=loading_strategy,
+            text_field=text_field,
+            category_field=category_field,
+        )
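
For orientation, a minimal usage sketch of the `ClassificationDataset` API added above (not part of the package diff; the `LocalStore` constructor argument is an assumption, since this diff lists `amber/store/local_store.py` but does not show its signature):

```python
from datasets import Dataset

from amber.datasets.classification_dataset import ClassificationDataset
from amber.datasets.loading_strategy import LoadingStrategy
from amber.store.local_store import LocalStore

# Toy in-memory dataset with a single label column.
ds = Dataset.from_dict({
    "text": ["great film", "dull plot", "superb acting"],
    "category": ["pos", "neg", "pos"],
})
store = LocalStore("/tmp/amber-store")  # assumption: takes a base directory

dataset = ClassificationDataset(ds, store=store, loading_strategy=LoadingStrategy.MEMORY)
print(len(dataset))              # 3
print(dataset[0])                # {'text': 'great film', 'category': 'pos'}
print(dataset.get_categories())  # ['neg', 'pos']

for batch in dataset.iter_batches(batch_size=2):
    print(dataset.extract_texts_from_batch(batch))
```

Passing a list such as `category_field=["category", "source"]` would instead make each item carry both label columns, and `get_categories()` would return a dict keyed by column name.
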
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import Union, Sequence, TypeAlias
+
+
+class LoadingStrategy(Enum):
+    """
+    Strategy for loading dataset data.
+
+    Choose the best strategy for your use case:
+
+    - MEMORY: Load entire dataset into memory (fastest random access, highest memory usage)
+      Best for: Small datasets that fit in memory, when you need fast random access
+
+    - DISK: Save to disk, read dynamically via memory-mapped Arrow files
+      (supports len/getitem, lower memory usage)
+      Best for: Large datasets that don't fit in memory, when you need random access
+
+    - STREAMING: True streaming mode using IterableDataset (lowest memory, no len/getitem support)
+      Best for: Very large datasets, when you only need sequential iteration
+    """
+    MEMORY = "memory"  # Load all into memory (fastest random access, highest memory usage)
+    DISK = "disk"  # Save to disk, read dynamically via memory-mapped Arrow files (supports len/getitem, lower memory usage)
+    STREAMING = "streaming"  # True streaming mode using IterableDataset (lowest memory, no len/getitem support)
+
+
+IndexLike: TypeAlias = Union[int, slice, Sequence[int]]
+
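
Reading the two hunks together, a sketch of how the strategies differ at the call sites, using the `from_huggingface` constructor from the first hunk (again not part of the diff; `store` stands in for any `Store` instance such as the hypothetical `LocalStore` above, and `ag_news` is simply a public dataset with `text` and `label` columns):

```python
from amber.datasets.classification_dataset import ClassificationDataset
from amber.datasets.loading_strategy import LoadingStrategy

# MEMORY (and DISK) keep len() and indexing available.
mem_ds = ClassificationDataset.from_huggingface(
    "ag_news",
    store=store,
    loading_strategy=LoadingStrategy.MEMORY,
    category_field="label",
    limit=1_000,                 # filters/limit are only applied off-stream
)
first = mem_ds[0]                # {'text': ..., 'label': ...}

# STREAMING trades random access for constant memory; len() and
# indexing raise NotImplementedError, so iterate instead.
stream_ds = ClassificationDataset.from_huggingface(
    "ag_news",
    store=store,
    loading_strategy=LoadingStrategy.STREAMING,
    category_field="label",
)
for batch in stream_ds.iter_batches(batch_size=32):
    ...
```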