mi-crow 0.1.1.post12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. amber/__init__.py +15 -0
  2. amber/datasets/__init__.py +11 -0
  3. amber/datasets/base_dataset.py +640 -0
  4. amber/datasets/classification_dataset.py +566 -0
  5. amber/datasets/loading_strategy.py +29 -0
  6. amber/datasets/text_dataset.py +488 -0
  7. amber/hooks/__init__.py +20 -0
  8. amber/hooks/controller.py +171 -0
  9. amber/hooks/detector.py +95 -0
  10. amber/hooks/hook.py +218 -0
  11. amber/hooks/implementations/__init__.py +0 -0
  12. amber/hooks/implementations/function_controller.py +93 -0
  13. amber/hooks/implementations/layer_activation_detector.py +96 -0
  14. amber/hooks/implementations/model_input_detector.py +250 -0
  15. amber/hooks/implementations/model_output_detector.py +132 -0
  16. amber/hooks/utils.py +76 -0
  17. amber/language_model/__init__.py +0 -0
  18. amber/language_model/activations.py +479 -0
  19. amber/language_model/context.py +33 -0
  20. amber/language_model/contracts.py +13 -0
  21. amber/language_model/hook_metadata.py +38 -0
  22. amber/language_model/inference.py +525 -0
  23. amber/language_model/initialization.py +126 -0
  24. amber/language_model/language_model.py +390 -0
  25. amber/language_model/layers.py +460 -0
  26. amber/language_model/persistence.py +177 -0
  27. amber/language_model/tokenizer.py +203 -0
  28. amber/language_model/utils.py +97 -0
  29. amber/mechanistic/__init__.py +0 -0
  30. amber/mechanistic/sae/__init__.py +0 -0
  31. amber/mechanistic/sae/autoencoder_context.py +40 -0
  32. amber/mechanistic/sae/concepts/__init__.py +0 -0
  33. amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
  34. amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
  35. amber/mechanistic/sae/concepts/concept_models.py +9 -0
  36. amber/mechanistic/sae/concepts/input_tracker.py +68 -0
  37. amber/mechanistic/sae/modules/__init__.py +5 -0
  38. amber/mechanistic/sae/modules/l1_sae.py +409 -0
  39. amber/mechanistic/sae/modules/topk_sae.py +459 -0
  40. amber/mechanistic/sae/sae.py +166 -0
  41. amber/mechanistic/sae/sae_trainer.py +604 -0
  42. amber/mechanistic/sae/training/wandb_logger.py +222 -0
  43. amber/store/__init__.py +5 -0
  44. amber/store/local_store.py +437 -0
  45. amber/store/store.py +276 -0
  46. amber/store/store_dataloader.py +124 -0
  47. amber/utils.py +46 -0
  48. mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
  49. mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
  50. mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
  51. mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
amber/datasets/classification_dataset.py
@@ -0,0 +1,566 @@
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
+
+ from datasets import Dataset, IterableDataset, load_dataset
+
+ from amber.datasets.base_dataset import BaseDataset
+ from amber.datasets.loading_strategy import IndexLike, LoadingStrategy
+ from amber.store.store import Store
+
+
+ class ClassificationDataset(BaseDataset):
+     """
+     Classification dataset with text and category/label columns.
+     Each item is a dict with 'text' and label column(s) as keys.
+     Supports single or multiple label columns.
+     """
+
+     def __init__(
+         self,
+         ds: Dataset | IterableDataset,
+         store: Store,
+         loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+         text_field: str = "text",
+         category_field: Union[str, List[str]] = "category",
+     ):
+         """
+         Initialize classification dataset.
+
+         Args:
+             ds: HuggingFace Dataset or IterableDataset
+             store: Store instance
+             loading_strategy: Loading strategy
+             text_field: Name of the column containing text
+             category_field: Name(s) of the column(s) containing category/label.
+                 Can be a single string or a list of strings for multiple labels.
+
+         Raises:
+             ValueError: If text_field or category_field is empty, or fields not found in dataset
+         """
+         self._validate_text_field(text_field)
+
+         # Normalize category_field to list
+         if isinstance(category_field, str):
+             self._category_fields = [category_field]
+         else:
+             self._category_fields = list(category_field)
+
+         self._validate_category_fields(self._category_fields)
+
+         # Validate dataset
+         is_iterable = isinstance(ds, IterableDataset)
+         if not is_iterable:
+             if text_field not in ds.column_names:
+                 raise ValueError(f"Dataset must have a '{text_field}' column; got columns: {ds.column_names}")
+             for cat_field in self._category_fields:
+                 if cat_field not in ds.column_names:
+                     raise ValueError(f"Dataset must have a '{cat_field}' column; got columns: {ds.column_names}")
+             # Set format with all required columns
+             format_columns = [text_field] + self._category_fields
+             ds.set_format("python", columns=format_columns)
+
+         self._text_field = text_field
+         self._category_field = category_field  # Keep original for backward compatibility
+         super().__init__(ds, store=store, loading_strategy=loading_strategy)
+
+     def _validate_text_field(self, text_field: str) -> None:
+         """Validate text_field parameter.
+
+         Args:
+             text_field: Text field name to validate
+
+         Raises:
+             ValueError: If text_field is empty or not a string
+         """
+         if not text_field or not isinstance(text_field, str) or not text_field.strip():
+             raise ValueError(f"text_field must be a non-empty string, got: {text_field!r}")
+
+     def _validate_category_fields(self, category_fields: List[str]) -> None:
+         """Validate category_fields parameter.
+
+         Args:
+             category_fields: List of category field names to validate
+
+         Raises:
+             ValueError: If category_fields is empty or contains invalid values
+         """
+         if not category_fields:
+             raise ValueError("category_field cannot be empty")
+
+         for cat_field in category_fields:
+             if not cat_field or not isinstance(cat_field, str) or not cat_field.strip():
+                 raise ValueError(f"All category fields must be non-empty strings, got invalid field: {cat_field!r}")
+
+     def _extract_item_from_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
+         """Extract item (text + categories) from a dataset row.
+
+         Args:
+             row: Dataset row dictionary
+
+         Returns:
+             Dictionary with 'text' and category fields as keys
+
+         Raises:
+             ValueError: If required fields are not found in row
+         """
+         if self._text_field in row:
+             text = row[self._text_field]
+         elif "text" in row:
+             text = row["text"]
+         else:
+             raise ValueError(
+                 f"Text field '{self._text_field}' or 'text' not found in dataset row. "
+                 f"Available fields: {list(row.keys())}"
+             )
+
+         item = {"text": text}
+         for cat_field in self._category_fields:
+             category = row.get(cat_field)
+             if category is None:
+                 raise ValueError(
+                     f"Category field '{cat_field}' not found in dataset row. Available fields: {list(row.keys())}"
+                 )
+             item[cat_field] = category
+
+         return item
+
+     def __len__(self) -> int:
+         """
+         Return the number of items in the dataset.
+
+         Raises:
+             NotImplementedError: If loading_strategy is STREAMING
+         """
+         if self._loading_strategy == LoadingStrategy.STREAMING:
+             raise NotImplementedError("len() not supported for STREAMING datasets")
+         return self._ds.num_rows
+
+     def __getitem__(self, idx: IndexLike) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
+         """
+         Get item(s) by index. Returns dict with 'text' and label column(s) as keys.
+
+         For single label: {"text": "...", "category": "..."}
+         For multiple labels: {"text": "...", "label1": "...", "label2": "..."}
+
+         Args:
+             idx: Index (int), slice, or sequence of indices
+
+         Returns:
+             Single item dict or list of item dicts
+
+         Raises:
+             NotImplementedError: If loading_strategy is STREAMING
+             IndexError: If index is out of bounds
+             ValueError: If dataset is empty
+         """
+         if self._loading_strategy == LoadingStrategy.STREAMING:
+             raise NotImplementedError(
+                 "Indexing not supported for STREAMING datasets. Use iter_items or iter_batches."
+             )
+
+         dataset_len = len(self)
+         if dataset_len == 0:
+             raise ValueError("Cannot index into empty dataset")
+
+         if isinstance(idx, int):
+             if idx < 0:
+                 idx = dataset_len + idx
+             if idx < 0 or idx >= dataset_len:
+                 raise IndexError(f"Index {idx} out of bounds for dataset of length {dataset_len}")
+             row = self._ds[idx]
+             return self._extract_item_from_row(row)
+
+         if isinstance(idx, slice):
+             start, stop, step = idx.indices(dataset_len)
+             if step != 1:
+                 indices = list(range(start, stop, step))
+                 selected = self._ds.select(indices)
+             else:
+                 selected = self._ds.select(range(start, stop))
+             return [self._extract_item_from_row(row) for row in selected]
+
+         if isinstance(idx, Sequence):
+             # Validate all indices are in bounds
+             invalid_indices = [i for i in idx if not (0 <= i < dataset_len)]
+             if invalid_indices:
+                 raise IndexError(f"Indices out of bounds: {invalid_indices} (dataset length: {dataset_len})")
+             selected = self._ds.select(list(idx))
+             return [self._extract_item_from_row(row) for row in selected]
+
+         raise TypeError(f"Invalid index type: {type(idx)}")
+
+     def iter_items(self) -> Iterator[Dict[str, Any]]:
+         """
+         Iterate over items one by one. Yields dict with 'text' and label column(s) as keys.
+
+         For single label: {"text": "...", "category_column_1": "..."}
+         For multiple labels: {"text": "...", "category_column_1": "...", "category_column_2": "..."}
+
+         Yields:
+             Item dictionaries with text and category fields
+
+         Raises:
+             ValueError: If required fields are not found in any row
+         """
+         for row in self._ds:
+             yield self._extract_item_from_row(row)
+
+     def iter_batches(self, batch_size: int) -> Iterator[List[Dict[str, Any]]]:
+         """
+         Iterate over items in batches. Each batch is a list of dicts with 'text' and label column(s) as keys.
+
+         For single label: [{"text": "...", "category_column_1": "..."}, ...]
+         For multiple labels: [{"text": "...", "category_column_1": "...", "category_column_2": "..."}, ...]
+
+         Args:
+             batch_size: Number of items per batch
+
+         Yields:
+             Lists of item dictionaries (batches)
+
+         Raises:
+             ValueError: If batch_size <= 0 or required fields are not found in any row
+         """
+         if batch_size <= 0:
+             raise ValueError(f"batch_size must be > 0, got: {batch_size}")
+
+         if self._loading_strategy == LoadingStrategy.STREAMING:
+             batch = []
+             for row in self._ds:
+                 batch.append(self._extract_item_from_row(row))
+                 if len(batch) >= batch_size:
+                     yield batch
+                     batch = []
+             if batch:
+                 yield batch
+         else:
+             # Use select to get batches with proper format
+             for i in range(0, len(self), batch_size):
+                 end = min(i + batch_size, len(self))
+                 batch_list = self[i:end]
+                 yield batch_list
+
+     def get_categories(self) -> Union[List[Any], Dict[str, List[Any]]]:  # noqa: C901
+         """
+         Get unique categories in the dataset, excluding None values.
+
+         Returns:
+             - For single label column: List of unique category values
+             - For multiple label columns: Dict mapping column name to list of unique categories
+
+         Raises:
+             NotImplementedError: If loading_strategy is STREAMING and dataset is large
+         """
+         if len(self._category_fields) == 1:
+             # Single label: return list for backward compatibility
+             cat_field = self._category_fields[0]
+             if self._loading_strategy == LoadingStrategy.STREAMING:
+                 categories = set()
+                 for item in self.iter_items():
+                     cat = item[cat_field]
+                     if cat is not None:
+                         categories.add(cat)
+                 return sorted(list(categories))  # noqa: C414
+             categories = [cat for cat in set(self._ds[cat_field]) if cat is not None]
+             return sorted(categories)
+         else:
+             # Multiple labels: return dict
+             result = {}
+             if self._loading_strategy == LoadingStrategy.STREAMING:
+                 # Collect categories from all items
+                 category_sets = {field: set() for field in self._category_fields}
+                 for item in self.iter_items():
+                     for field in self._category_fields:
+                         cat = item[field]
+                         if cat is not None:
+                             category_sets[field].add(cat)
+                 for field in self._category_fields:
+                     result[field] = sorted(list(category_sets[field]))  # noqa: C414
+             else:
+                 # Use direct column access
+                 for field in self._category_fields:
+                     categories = [cat for cat in set(self._ds[field]) if cat is not None]
+                     result[field] = sorted(categories)
+             return result
+
+     def extract_texts_from_batch(self, batch: List[Dict[str, Any]]) -> List[Optional[str]]:
+         """Extract text strings from a batch of classification items.
+
+         Args:
+             batch: List of dicts with 'text' and category fields
+
+         Returns:
+             List of text strings from the batch
+
+         Raises:
+             ValueError: If 'text' key is not found in any batch item
+         """
+         texts = []
+         for item in batch:
+             if "text" not in item:
+                 raise ValueError(f"'text' key not found in batch item. Available keys: {list(item.keys())}")
+             texts.append(item["text"])
+         return texts
+
+     def get_all_texts(self) -> List[Optional[str]]:
+         """Get all texts from the dataset.
+
+         Returns:
+             List of all text strings
+
+         Raises:
+             NotImplementedError: If loading_strategy is STREAMING and dataset is very large
+         """
+         if self._loading_strategy == LoadingStrategy.STREAMING:
+             return [item["text"] for item in self.iter_items()]
+         return list(self._ds[self._text_field])
+
+     def get_categories_for_texts(self, texts: List[Optional[str]]) -> Union[List[Any], List[Dict[str, Any]]]:
+         """
+         Get categories for given texts (if texts match dataset texts).
+
+         Args:
+             texts: List of text strings to look up
+
+         Returns:
+             - For single label column: List of category values (one per text)
+             - For multiple label columns: List of dicts with label columns as keys
+
+         Raises:
+             NotImplementedError: If loading_strategy is STREAMING
+             ValueError: If texts list is empty
+         """
+         if self._loading_strategy == LoadingStrategy.STREAMING:
+             raise NotImplementedError("get_categories_for_texts not supported for STREAMING datasets")
+
+         if not texts:
+             raise ValueError("texts list cannot be empty")
+
+         if len(self._category_fields) == 1:
+             # Single label: return list for backward compatibility
+             cat_field = self._category_fields[0]
+             text_to_category = {row[self._text_field]: row[cat_field] for row in self._ds}
+             return [text_to_category.get(text) for text in texts]
+         else:
+             # Multiple labels: return list of dicts
+             text_to_categories = {
+                 row[self._text_field]: {field: row[field] for field in self._category_fields} for row in self._ds
+             }
+             return [text_to_categories.get(text) for text in texts]
+
+     @classmethod
+     def from_huggingface(
+         cls,
+         repo_id: str,
+         store: Store,
+         *,
+         split: str = "train",
+         loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+         revision: Optional[str] = None,
+         text_field: str = "text",
+         category_field: Union[str, List[str]] = "category",
+         filters: Optional[Dict[str, Any]] = None,
+         limit: Optional[int] = None,
+         stratify_by: Optional[str] = None,
+         stratify_seed: Optional[int] = None,
+         streaming: Optional[bool] = None,
+         drop_na: bool = False,
+         **kwargs,
+     ) -> "ClassificationDataset":
+         """
+         Load classification dataset from HuggingFace Hub.
+
+         Args:
+             repo_id: HuggingFace dataset repository ID
+             store: Store instance
+             split: Dataset split
+             loading_strategy: Loading strategy
+             revision: Optional git revision
+             text_field: Name of the column containing text
+             category_field: Name(s) of the column(s) containing category/label
+             filters: Optional filters to apply (dict of column: value)
+             limit: Optional limit on number of rows
+             stratify_by: Optional column used for stratified sampling (non-streaming only)
+             stratify_seed: Optional RNG seed for stratified sampling
+             streaming: Optional override for streaming
+             drop_na: Whether to drop rows with None/empty text or categories
+             **kwargs: Additional arguments for load_dataset
+
+         Returns:
+             ClassificationDataset instance
+
+         Raises:
+             ValueError: If parameters are invalid
+             RuntimeError: If dataset loading fails
+         """
+         use_streaming = streaming if streaming is not None else (loading_strategy == LoadingStrategy.STREAMING)
+
+         if (stratify_by or drop_na) and use_streaming:
+             raise NotImplementedError(
+                 "Stratification and drop_na are not supported for streaming datasets. Use MEMORY or DISK."
+             )
+
+         try:
+             ds = load_dataset(
+                 path=repo_id,
+                 split=split,
+                 revision=revision,
+                 streaming=use_streaming,
+                 **kwargs,
+             )
+
+             if use_streaming:
+                 if filters or limit:
+                     raise NotImplementedError(
+                         "filters and limit are not supported when streaming datasets. Choose MEMORY or DISK."
+                     )
+             else:
+                 drop_na_columns = None
+                 if drop_na:
+                     cat_fields = [category_field] if isinstance(category_field, str) else category_field
+                     drop_na_columns = [text_field] + list(cat_fields)
+
+                 ds = cls._postprocess_non_streaming_dataset(
+                     ds,
+                     filters=filters,
+                     limit=limit,
+                     stratify_by=stratify_by,
+                     stratify_seed=stratify_seed,
+                     drop_na_columns=drop_na_columns,
+                 )
+         except Exception as e:
+             raise RuntimeError(
+                 f"Failed to load classification dataset from HuggingFace Hub: "
+                 f"repo_id={repo_id!r}, split={split!r}, text_field={text_field!r}, "
+                 f"category_field={category_field!r}. Error: {e}"
+             ) from e
+
+         return cls(
+             ds,
+             store=store,
+             loading_strategy=loading_strategy,
+             text_field=text_field,
+             category_field=category_field,
+         )
+
+     @classmethod
+     def from_csv(
+         cls,
+         source: Union[str, Path],
+         store: Store,
+         *,
+         loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+         text_field: str = "text",
+         category_field: Union[str, List[str]] = "category",
+         delimiter: str = ",",
+         stratify_by: Optional[str] = None,
+         stratify_seed: Optional[int] = None,
+         drop_na: bool = False,
+         **kwargs,
+     ) -> "ClassificationDataset":
+         """
+         Load classification dataset from CSV file.
+
+         Args:
+             source: Path to CSV file
+             store: Store instance
+             loading_strategy: Loading strategy
+             text_field: Name of the column containing text
+             category_field: Name(s) of the column(s) containing category/label
+             delimiter: CSV delimiter (default: comma)
+             stratify_by: Optional column used for stratified sampling
+             stratify_seed: Optional RNG seed for stratified sampling
+             drop_na: Whether to drop rows with None/empty text or categories
+             **kwargs: Additional arguments for load_dataset
+
+         Returns:
+             ClassificationDataset instance
+
+         Raises:
+             FileNotFoundError: If CSV file doesn't exist
+             RuntimeError: If dataset loading fails
+         """
+         drop_na_columns = None
+         if drop_na:
+             cat_fields = [category_field] if isinstance(category_field, str) else category_field
+             drop_na_columns = [text_field] + list(cat_fields)
+
+         dataset = super().from_csv(
+             source,
+             store=store,
+             loading_strategy=loading_strategy,
+             text_field=text_field,
+             delimiter=delimiter,
+             stratify_by=stratify_by,
+             stratify_seed=stratify_seed,
+             drop_na_columns=drop_na_columns,
+             **kwargs,
+         )
+         return cls(
+             dataset._ds,
+             store=store,
+             loading_strategy=loading_strategy,
+             text_field=text_field,
+             category_field=category_field,
+         )
+
+     @classmethod
+     def from_json(
+         cls,
+         source: Union[str, Path],
+         store: Store,
+         *,
+         loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+         text_field: str = "text",
+         category_field: Union[str, List[str]] = "category",
+         stratify_by: Optional[str] = None,
+         stratify_seed: Optional[int] = None,
+         drop_na: bool = False,
+         **kwargs,
+     ) -> "ClassificationDataset":
+         """
+         Load classification dataset from JSON/JSONL file.
+
+         Args:
+             source: Path to JSON or JSONL file
+             store: Store instance
+             loading_strategy: Loading strategy
+             text_field: Name of the field containing text
+             category_field: Name(s) of the field(s) containing category/label
+             stratify_by: Optional column used for stratified sampling
+             stratify_seed: Optional RNG seed for stratified sampling
+             drop_na: Whether to drop rows with None/empty text or categories
+             **kwargs: Additional arguments for load_dataset
+
+         Returns:
+             ClassificationDataset instance
+
+         Raises:
+             FileNotFoundError: If JSON file doesn't exist
+             RuntimeError: If dataset loading fails
+         """
+         drop_na_columns = None
+         if drop_na:
+             cat_fields = [category_field] if isinstance(category_field, str) else category_field
+             drop_na_columns = [text_field] + list(cat_fields)
+
+         dataset = super().from_json(
+             source,
+             store=store,
+             loading_strategy=loading_strategy,
+             text_field=text_field,
+             stratify_by=stratify_by,
+             stratify_seed=stratify_seed,
+             drop_na_columns=drop_na_columns,
+             **kwargs,
+         )
+         return cls(
+             dataset._ds,
+             store=store,
+             loading_strategy=loading_strategy,
+             text_field=text_field,
+             category_field=category_field,
+         )
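
For orientation, a minimal usage sketch of ClassificationDataset follows. It assumes a Store can be built from amber.store.local_store (LocalStore appears in the file list above, but its constructor arguments are not shown in this diff and are guessed here), and it uses an example Hub dataset with a "text" column and a "label" column.

    from amber.datasets.classification_dataset import ClassificationDataset
    from amber.datasets.loading_strategy import LoadingStrategy
    from amber.store.local_store import LocalStore  # path from the file list; constructor args assumed

    store = LocalStore("./store")  # assumption: a directory-backed Store

    # Load a labelled text dataset from the HuggingFace Hub into memory.
    dataset = ClassificationDataset.from_huggingface(
        "ag_news",                 # example repo with "text" and "label" columns
        store,
        split="train",
        loading_strategy=LoadingStrategy.MEMORY,
        text_field="text",
        category_field="label",
        limit=1000,
    )

    print(len(dataset))              # row count (MEMORY/DISK only)
    print(dataset[0])                # {"text": "...", "label": ...}
    print(dataset.get_categories())  # sorted unique label values

    # Batched iteration works under every loading strategy, including STREAMING.
    for batch in dataset.iter_batches(batch_size=32):
        texts = dataset.extract_texts_from_batch(batch)

from_csv and from_json follow the same pattern, taking a local path plus the same text_field and category_field arguments.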
amber/datasets/loading_strategy.py
@@ -0,0 +1,29 @@
+ from __future__ import annotations
+
+ from enum import Enum
+ from typing import Union, Sequence, TypeAlias
+
+
+ class LoadingStrategy(Enum):
+     """
+     Strategy for loading dataset data.
+
+     Choose the best strategy for your use case:
+
+     - MEMORY: Load entire dataset into memory (fastest random access, highest memory usage)
+       Best for: Small datasets that fit in memory, when you need fast random access
+
+     - DISK: Save to disk, read dynamically via memory-mapped Arrow files
+       (supports len/getitem, lower memory usage)
+       Best for: Large datasets that don't fit in memory, when you need random access
+
+     - STREAMING: True streaming mode using IterableDataset (lowest memory, no len/getitem support)
+       Best for: Very large datasets, when you only need sequential iteration
+     """
+     MEMORY = "memory"  # Load all into memory (fastest random access, highest memory usage)
+     DISK = "disk"  # Save to disk, read dynamically via memory-mapped Arrow files (supports len/getitem, lower memory usage)
+     STREAMING = "streaming"  # True streaming mode using IterableDataset (lowest memory, no len/getitem support)
+
+
+ IndexLike: TypeAlias = Union[int, slice, Sequence[int]]
+
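
As a quick illustration of how the strategies above differ in practice (same assumed LocalStore as in the earlier sketch): MEMORY and DISK datasets support len() and indexing with any IndexLike value, while STREAMING deliberately raises NotImplementedError for both and only supports iteration.

    from amber.datasets.classification_dataset import ClassificationDataset
    from amber.datasets.loading_strategy import LoadingStrategy
    from amber.store.local_store import LocalStore  # constructor args assumed, as before

    store = LocalStore("./store")

    streamed = ClassificationDataset.from_huggingface(
        "ag_news",                 # example repo id
        store,
        split="train",
        loading_strategy=LoadingStrategy.STREAMING,
        category_field="label",
    )

    # Random access is unavailable when streaming ...
    try:
        _ = streamed[0]
    except NotImplementedError:
        pass  # expected; use iter_items()/iter_batches() instead

    # ... but sequential iteration works with minimal memory.
    for item in streamed.iter_items():
        print(item["text"], item["label"])
        break

    # With MEMORY or DISK, IndexLike covers ints, slices, and index sequences:
    # dataset[0], dataset[10:20], dataset[[0, 2, 5]]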