mi_crow-0.1.1.post12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. amber/__init__.py +15 -0
  2. amber/datasets/__init__.py +11 -0
  3. amber/datasets/base_dataset.py +640 -0
  4. amber/datasets/classification_dataset.py +566 -0
  5. amber/datasets/loading_strategy.py +29 -0
  6. amber/datasets/text_dataset.py +488 -0
  7. amber/hooks/__init__.py +20 -0
  8. amber/hooks/controller.py +171 -0
  9. amber/hooks/detector.py +95 -0
  10. amber/hooks/hook.py +218 -0
  11. amber/hooks/implementations/__init__.py +0 -0
  12. amber/hooks/implementations/function_controller.py +93 -0
  13. amber/hooks/implementations/layer_activation_detector.py +96 -0
  14. amber/hooks/implementations/model_input_detector.py +250 -0
  15. amber/hooks/implementations/model_output_detector.py +132 -0
  16. amber/hooks/utils.py +76 -0
  17. amber/language_model/__init__.py +0 -0
  18. amber/language_model/activations.py +479 -0
  19. amber/language_model/context.py +33 -0
  20. amber/language_model/contracts.py +13 -0
  21. amber/language_model/hook_metadata.py +38 -0
  22. amber/language_model/inference.py +525 -0
  23. amber/language_model/initialization.py +126 -0
  24. amber/language_model/language_model.py +390 -0
  25. amber/language_model/layers.py +460 -0
  26. amber/language_model/persistence.py +177 -0
  27. amber/language_model/tokenizer.py +203 -0
  28. amber/language_model/utils.py +97 -0
  29. amber/mechanistic/__init__.py +0 -0
  30. amber/mechanistic/sae/__init__.py +0 -0
  31. amber/mechanistic/sae/autoencoder_context.py +40 -0
  32. amber/mechanistic/sae/concepts/__init__.py +0 -0
  33. amber/mechanistic/sae/concepts/autoencoder_concepts.py +332 -0
  34. amber/mechanistic/sae/concepts/concept_dictionary.py +206 -0
  35. amber/mechanistic/sae/concepts/concept_models.py +9 -0
  36. amber/mechanistic/sae/concepts/input_tracker.py +68 -0
  37. amber/mechanistic/sae/modules/__init__.py +5 -0
  38. amber/mechanistic/sae/modules/l1_sae.py +409 -0
  39. amber/mechanistic/sae/modules/topk_sae.py +459 -0
  40. amber/mechanistic/sae/sae.py +166 -0
  41. amber/mechanistic/sae/sae_trainer.py +604 -0
  42. amber/mechanistic/sae/training/wandb_logger.py +222 -0
  43. amber/store/__init__.py +5 -0
  44. amber/store/local_store.py +437 -0
  45. amber/store/store.py +276 -0
  46. amber/store/store_dataloader.py +124 -0
  47. amber/utils.py +46 -0
  48. mi_crow-0.1.1.post12.dist-info/METADATA +124 -0
  49. mi_crow-0.1.1.post12.dist-info/RECORD +51 -0
  50. mi_crow-0.1.1.post12.dist-info/WHEEL +5 -0
  51. mi_crow-0.1.1.post12.dist-info/top_level.txt +1 -0
amber/__init__.py ADDED
@@ -0,0 +1,15 @@
+ """Amber: helper package for the Engineer Thesis project.
+
+ This module is intentionally minimal. It exists to define the top-level package
+ and to enable code coverage to include the package. Importing it should succeed
+ without side effects.
+ """
+
+ # A tiny bit of executable code to make the package measurable by coverage.
+ PACKAGE_NAME = "amber"
+ __version__ = "0.0.0"
+
+
+ def ping() -> str:
+     """Return a simple response to verify the package is wired correctly."""
+     return "pong"
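A minimal sanity check for this module (a sketch only; it assumes the wheel installs the amber top-level package, as top_level.txt and the RECORD suggest):

# Sketch: verify the top-level package imports cleanly and exposes its tiny API.
import amber

assert amber.PACKAGE_NAME == "amber"
assert amber.ping() == "pong"
print(amber.__version__)  # "0.0.0" in this file, even though the wheel itself is versioned 0.1.1.post12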
amber/datasets/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from amber.datasets.base_dataset import BaseDataset
+ from amber.datasets.text_dataset import TextDataset
+ from amber.datasets.classification_dataset import ClassificationDataset
+ from amber.datasets.loading_strategy import LoadingStrategy
+
+ __all__ = [
+     "BaseDataset",
+     "TextDataset",
+     "ClassificationDataset",
+     "LoadingStrategy",
+ ]
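For orientation, a hedged sketch of how these exports are meant to be combined, based on the BaseDataset factory signatures shown further down in this diff. TextDataset is assumed to inherit those factory methods; LocalStore's constructor arguments and the file name are placeholders, since amber/store/__init__.py and text_dataset.py are not expanded here.

# Sketch under assumptions: exact Store/TextDataset arguments may differ.
from amber.datasets import TextDataset, LoadingStrategy
from amber.store import LocalStore  # assumed export of amber/store/__init__.py

store = LocalStore("./artifacts")          # hypothetical constructor arguments
dataset = TextDataset.from_csv(
    "reviews.csv",                          # placeholder file
    store,
    text_field="text",
    loading_strategy=LoadingStrategy.MEMORY,  # or DISK / STREAMING
)
for batch in dataset.iter_batches(batch_size=32):
    texts = dataset.extract_texts_from_batch(batch)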
amber/datasets/base_dataset.py ADDED
@@ -0,0 +1,640 @@
+ from __future__ import annotations
+
+ import math
+ import random
+ from abc import ABC, abstractmethod
+ from collections import defaultdict
+ from pathlib import Path
+ from typing import Any, Dict, Iterator, List, Optional, Union
+
+ from datasets import Dataset, IterableDataset, load_dataset, load_from_disk
+
+ from amber.datasets.loading_strategy import IndexLike, LoadingStrategy
+ from amber.store.store import Store
+
+
+ class BaseDataset(ABC):
+     """
+     Abstract base class for datasets with support for multiple sources,
+     loading strategies, and Store integration.
+
+     Loading Strategies:
+     - MEMORY: Load entire dataset into memory (fastest random access, highest memory usage)
+     - DISK: Save to disk, read dynamically via memory-mapped Arrow files
+       (supports len/getitem, lower memory usage)
+     - STREAMING: True streaming mode using IterableDataset
+       (lowest memory, no len/getitem support, no stratification and limit support)
+     """
+
+     def __init__(
+         self,
+         ds: Dataset | IterableDataset,
+         store: Store,
+         loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+     ):
+         """
+         Initialize dataset.
+
+         Args:
+             ds: HuggingFace Dataset or IterableDataset
+             store: Store instance for caching/persistence
+             loading_strategy: How to load data (MEMORY, DISK, or STREAMING)
+
+         Raises:
+             ValueError: If store is None, loading_strategy is invalid, or dataset operations fail
+             OSError: If file system operations fail
+         """
+         self._validate_initialization_params(store, loading_strategy)
+
+         self._store = store
+         self._loading_strategy = loading_strategy
+         self._dataset_dir: Path = Path(store.base_path) / store.dataset_prefix
+
+         is_iterable_input = isinstance(ds, IterableDataset)
+
+         if loading_strategy == LoadingStrategy.MEMORY:
+             # MEMORY: Convert to Dataset if needed, save to disk, load fully into memory
+             self._is_iterable = False
+             if is_iterable_input:
+                 ds = Dataset.from_generator(lambda: iter(ds))
+             self._ds = self._save_and_load_dataset(ds, use_memory_mapping=False)
+         elif loading_strategy == LoadingStrategy.DISK:
+             # DISK: Save to disk, use memory-mapped Arrow files (supports len/getitem)
+             self._is_iterable = False
+             if is_iterable_input:
+                 ds = Dataset.from_generator(lambda: iter(ds))
+             self._ds = self._save_and_load_dataset(ds, use_memory_mapping=True)
+         elif loading_strategy == LoadingStrategy.STREAMING:
+             # STREAMING: Convert to IterableDataset, don't save to disk (no len/getitem)
+             if not is_iterable_input:
+                 ds = ds.to_iterable_dataset()
+             self._is_iterable = True
+             self._ds = ds
+             # Don't save to disk for iterable-only mode
+         else:
+             raise ValueError(
+                 f"Unknown loading strategy: {loading_strategy}. Must be one of: {[s.value for s in LoadingStrategy]}"
+             )
+
+     def _validate_initialization_params(self, store: Store, loading_strategy: LoadingStrategy) -> None:
+         """Validate initialization parameters.
+
+         Args:
+             store: Store instance to validate
+             loading_strategy: Loading strategy to validate
+
+         Raises:
+             ValueError: If store is None or loading_strategy is invalid
+         """
+         if store is None:
+             raise ValueError("store cannot be None")
+
+         if not isinstance(loading_strategy, LoadingStrategy):
+             raise ValueError(f"loading_strategy must be a LoadingStrategy enum value, got: {type(loading_strategy)}")
+
+     def _has_valid_dataset_dir(self) -> bool:
+         """Check if dataset directory path is valid (non-empty base_path).
+
+         Returns:
+             True if base_path is not empty, False otherwise
+         """
+         return bool(self._store.base_path and str(self._store.base_path).strip())
+
+     def _save_and_load_dataset(self, ds: Dataset, use_memory_mapping: bool = True) -> Dataset:
+         """Save dataset to disk and load it back (with optional memory mapping).
+
+         Args:
+             ds: Dataset to save and load
+             use_memory_mapping: Whether to use memory mapping (True for DISK)
+
+         Returns:
+             Loaded dataset
+
+         Raises:
+             OSError: If file system operations fail
+             RuntimeError: If dataset operations fail
+         """
+         if self._has_valid_dataset_dir():
+             try:
+                 self._dataset_dir.mkdir(parents=True, exist_ok=True)
+                 ds.save_to_disk(str(self._dataset_dir))
+                 return load_from_disk(str(self._dataset_dir))
+             except OSError as e:
+                 raise OSError(f"Failed to save/load dataset at {self._dataset_dir}. Error: {e}") from e
+             except Exception as e:
+                 raise RuntimeError(f"Failed to process dataset at {self._dataset_dir}. Error: {e}") from e
+         else:
+             return ds
+
+     @classmethod
+     def _postprocess_non_streaming_dataset(
+         cls,
+         ds: Dataset,
+         *,
+         filters: Optional[Dict[str, Any]] = None,
+         limit: Optional[int] = None,
+         stratify_by: Optional[str] = None,
+         stratify_seed: Optional[int] = None,
+         drop_na_columns: Optional[List[str]] = None,
+     ) -> Dataset:
+         """Apply filters, stratified sampling, and limits to an in-memory dataset."""
+
+         if drop_na_columns:
+             ds = cls._drop_na(ds, drop_na_columns)
+
+         if filters:
+             ds = cls._apply_filters(ds, filters)
+
+         limit_applied = False
+         if stratify_by:
+             sample_size = limit if limit is not None else len(ds)
+             if sample_size is not None and sample_size <= 0:
+                 raise ValueError(f"limit must be > 0 when stratifying, got: {sample_size}")
+             ds = cls._stratified_sample(
+                 ds,
+                 stratify_by=stratify_by,
+                 sample_size=sample_size,
+                 seed=stratify_seed,
+             )
+             limit_applied = True
+
+         if limit is not None and not limit_applied:
+             if limit <= 0:
+                 raise ValueError(f"limit must be > 0, got: {limit}")
+             ds = ds.select(range(min(limit, len(ds))))
+
+         return ds
+
+     @staticmethod
+     def _drop_na(ds: Dataset, columns: List[str]) -> Dataset:
+         """Drop rows where any of the specified columns are None or empty string."""
+
+         def _is_valid(example: Dict[str, Any]) -> bool:
+             for col in columns:
+                 val = example.get(col)
+                 if val is None:
+                     return False
+                 if isinstance(val, str) and not val.strip():
+                     return False
+             return True
+
+         return ds.filter(_is_valid)
+
+     @staticmethod
+     def _apply_filters(ds: Dataset, filters: Dict[str, Any]) -> Dataset:
+         """Apply exact-match filters to a Dataset."""
+
+         def _predicate(example: Dict[str, Any]) -> bool:
+             return all(example.get(key) == value for key, value in filters.items())
+
+         return ds.filter(_predicate)
+
+
+     @staticmethod
+     def _stratified_sample(  # noqa: C901
+         ds: Dataset,
+         *,
+         stratify_by: str,
+         sample_size: Optional[int],
+         seed: Optional[int],
+     ) -> Dataset:
+         """Return a stratified sample of the dataset with the requested size."""
+
+         if stratify_by not in ds.column_names:
+             raise ValueError(f"Column '{stratify_by}' not found in dataset columns: {ds.column_names}")
+
+         total_rows = len(ds)
+         if total_rows == 0:
+             return ds
+
+         if sample_size is None:
+             sample_size = total_rows
+
+         sample_size = min(sample_size, total_rows)
+         if sample_size <= 0:
+             raise ValueError("sample_size must be greater than 0 for stratification")
+
+         column_values = ds[stratify_by]
+         label_to_indices: Dict[Any, List[int]] = defaultdict(list)
+         for idx, label in enumerate(column_values):
+             label_to_indices[label].append(idx)
+
+         label_counts = {label: len(indices) for label, indices in label_to_indices.items()}
+         allocations: Dict[Any, int] = {}
+         fractional_parts: List[tuple[float, int, Any]] = []
+
+         allocated_total = 0
+         for order, (label, count) in enumerate(label_counts.items()):
+             exact_allocation = (count / total_rows) * sample_size
+             base_allocation = min(count, int(math.floor(exact_allocation)))
+             allocations[label] = base_allocation
+             allocated_total += base_allocation
+             fractional_parts.append((exact_allocation - base_allocation, order, label))
+
+         remaining = sample_size - allocated_total
+         fractional_parts.sort(key=lambda item: (-item[0], item[1]))
+         for _, _, label in fractional_parts:
+             if remaining <= 0:
+                 break
+             available = label_counts[label] - allocations[label]
+             if available <= 0:
+                 continue
+             take = min(available, remaining)
+             allocations[label] += take
+             remaining -= take
+
+         rng = random.Random(seed)
+         selected_indices: List[int] = []
+         for label, count in allocations.items():
+             if count <= 0:
+                 continue
+             indices = label_to_indices[label]
+             if count >= len(indices):
+                 chosen = list(indices)
+             else:
+                 chosen = rng.sample(indices, count)
+             selected_indices.extend(chosen)
+
+         rng.shuffle(selected_indices)
+         return ds.select(selected_indices)
+
+     @staticmethod
+     def _load_csv_source(
+         source: Union[str, Path],
+         *,
+         delimiter: str,
+         streaming: bool,
+         **kwargs,
+     ) -> Dataset | IterableDataset:
+         """Load a CSV dataset from disk using HuggingFace datasets."""
+
+         p = Path(source)
+         if not p.exists():
+             raise FileNotFoundError(f"CSV file not found: {source}")
+         if not p.is_file():
+             raise ValueError(f"Source must be a file, got: {source}")
+
+         try:
+             return load_dataset(
+                 "csv",
+                 data_files=str(p),
+                 split="train",
+                 delimiter=delimiter,
+                 streaming=streaming,
+                 **kwargs,
+             )
+         except Exception as e:
+             raise RuntimeError(f"Failed to load CSV dataset from {source}. Error: {e}") from e
+
+     @staticmethod
+     def _load_json_source(
+         source: Union[str, Path],
+         *,
+         streaming: bool,
+         **kwargs,
+     ) -> Dataset | IterableDataset:
+         """Load a JSON/JSONL dataset from disk using HuggingFace datasets."""
+
+         p = Path(source)
+         if not p.exists():
+             raise FileNotFoundError(f"JSON file not found: {source}")
+         if not p.is_file():
+             raise ValueError(f"Source must be a file, got: {source}")
+
+         try:
+             return load_dataset(
+                 "json",
+                 data_files=str(p),
+                 split="train",
+                 streaming=streaming,
+                 **kwargs,
+             )
+         except Exception as e:
+             raise RuntimeError(f"Failed to load JSON dataset from {source}. Error: {e}") from e
+
+     def get_batch(self, start: int, batch_size: int) -> List[Any]:
+         """
+         Get a contiguous batch of items.
+
+         Args:
+             start: Starting index
+             batch_size: Number of items to retrieve
+
+         Returns:
+             List of items
+
+         Raises:
+             NotImplementedError: If loading_strategy is STREAMING
+         """
+         if self._loading_strategy == LoadingStrategy.STREAMING:
+             raise NotImplementedError("get_batch not supported for STREAMING datasets. Use iter_batches instead.")
+         if batch_size <= 0:
+             return []
+         end = min(start + batch_size, len(self))
+         if start >= end:
+             return []
+         return self[start:end]
+
+     def head(self, n: int = 5) -> List[Any]:
+         """
+         Get first n items.
+
+         Works for all loading strategies.
+
+         Args:
+             n: Number of items to retrieve (default: 5)
+
+         Returns:
+             List of first n items
+         """
+         if self._loading_strategy == LoadingStrategy.STREAMING:
+             items = []
+             for i, item in enumerate(self.iter_items()):
+                 if i >= n:
+                     break
+                 items.append(item)
+             return items
+         return self[:n]
+
+     def sample(self, n: int = 5) -> List[Any]:
+         """
+         Get n random items from the dataset.
+
+         Works for MEMORY and DISK strategies only.
+
+         Args:
+             n: Number of items to sample
+
+         Returns:
+             List of n randomly sampled items
+
+         Raises:
+             NotImplementedError: If loading_strategy is STREAMING
+         """
+         if self._loading_strategy == LoadingStrategy.STREAMING:
+             raise NotImplementedError(
+                 "sample() not supported for STREAMING datasets. Use iter_items() and sample manually."
+             )
+
+         dataset_len = len(self)
+         if n <= 0:
+             return []
+         if n >= dataset_len:
+             # Return all items in random order
+             indices = list(range(dataset_len))
+             random.shuffle(indices)
+             return [self[i] for i in indices]
+
+         # Sample n random indices
+         indices = random.sample(range(dataset_len), n)
+         # Use __getitem__ with list of indices
+         return self[indices]
+
+     @property
+     def is_streaming(self) -> bool:
+         """Whether this dataset is streaming (DISK or STREAMING)."""
+         return self._loading_strategy in (LoadingStrategy.DISK, LoadingStrategy.STREAMING)
+
+     @abstractmethod
+     def __len__(self) -> int:
+         """Return the number of items in the dataset."""
+         pass
+
+     @abstractmethod
+     def __getitem__(self, idx: IndexLike) -> Any:
+         """Get item(s) by index."""
+         pass
+
+     @abstractmethod
+     def iter_items(self) -> Iterator[Any]:
+         """Iterate over items one by one."""
+         pass
+
+     @abstractmethod
+     def iter_batches(self, batch_size: int) -> Iterator[List[Any]]:
+         """Iterate over items in batches."""
+         pass
+
+     @abstractmethod
+     def extract_texts_from_batch(self, batch: List[Any]) -> List[str]:
+         """Extract text strings from a batch.
+
+         Args:
+             batch: A batch as returned by iter_batches()
+
+         Returns:
+             List of text strings ready for model inference
+         """
+         pass
+
+     @abstractmethod
+     def get_all_texts(self) -> List[str]:
+         """Get all texts from the dataset.
+
+         Returns:
+             List of all text strings in the dataset
+
+         Raises:
+             NotImplementedError: If not supported for streaming datasets
+         """
+         pass
+
+     # --- Factory methods ---
+
+     @classmethod
+     def from_huggingface(
+         cls,
+         repo_id: str,
+         store: Store,
+         *,
+         split: str = "train",
+         loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+         revision: Optional[str] = None,
+         streaming: Optional[bool] = None,
+         filters: Optional[Dict[str, Any]] = None,
+         limit: Optional[int] = None,
+         stratify_by: Optional[str] = None,
+         stratify_seed: Optional[int] = None,
+         **kwargs,
+     ) -> "BaseDataset":
+         """
+         Load dataset from HuggingFace Hub.
+
+         Args:
+             repo_id: HuggingFace dataset repository ID
+             store: Store instance
+             split: Dataset split (e.g., "train", "validation")
+             loading_strategy: Loading strategy (MEMORY, DISK, or STREAMING)
+             revision: Optional git revision/branch/tag
+             streaming: Optional override for streaming (if None, uses loading_strategy)
+             filters: Optional dict of column->value pairs used for exact-match filtering
+             limit: Optional maximum number of rows to keep (applied after filtering/stratification)
+             stratify_by: Optional column to use for stratified sampling (non-streaming only)
+             stratify_seed: Optional RNG seed for deterministic stratification
+             **kwargs: Additional arguments passed to load_dataset
+
+         Returns:
+             BaseDataset instance
+
+         Raises:
+             ValueError: If repo_id is empty or store is None
+             RuntimeError: If dataset loading fails
+         """
+         if not repo_id or not isinstance(repo_id, str) or not repo_id.strip():
+             raise ValueError(f"repo_id must be a non-empty string, got: {repo_id!r}")
+
+         if store is None:
+             raise ValueError("store cannot be None")
+
+         # Determine if we should use streaming for HuggingFace load_dataset
+         use_streaming = streaming if streaming is not None else (loading_strategy == LoadingStrategy.STREAMING)
+
+         if stratify_by and loading_strategy == LoadingStrategy.STREAMING:
+             raise NotImplementedError("Stratification is not supported for STREAMING datasets.")
+
+         try:
+             ds = load_dataset(
+                 path=repo_id,
+                 split=split,
+                 revision=revision,
+                 streaming=use_streaming,
+                 **kwargs,
+             )
+         except Exception as e:
+             raise RuntimeError(
+                 f"Failed to load dataset from HuggingFace Hub: repo_id={repo_id!r}, "
+                 f"split={split!r}, revision={revision!r}. Error: {e}"
+             ) from e
+
+         if use_streaming:
+             if filters or limit or stratify_by:
+                 raise NotImplementedError(
+                     "filters, limit, and stratification are not supported when streaming datasets. "
+                     "Choose MEMORY or DISK loading strategy instead."
+                 )
+         else:
+             ds = cls._postprocess_non_streaming_dataset(
+                 ds,
+                 filters=filters,
+                 limit=limit,
+                 stratify_by=stratify_by,
+                 stratify_seed=stratify_seed,
+             )
+
+         return cls(ds, store=store, loading_strategy=loading_strategy)
+
+     @classmethod
+     def from_csv(
+         cls,
+         source: Union[str, Path],
+         store: Store,
+         *,
+         loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+         text_field: str = "text",
+         delimiter: str = ",",
+         stratify_by: Optional[str] = None,
+         stratify_seed: Optional[int] = None,
+         drop_na_columns: Optional[List[str]] = None,
+         **kwargs,
+     ) -> "BaseDataset":
+         """
+         Load dataset from CSV file.
+
+         Args:
+             source: Path to CSV file
+             store: Store instance
+             loading_strategy: Loading strategy
+             text_field: Name of the column containing text
+             delimiter: CSV delimiter (default: comma)
+             stratify_by: Optional column used for stratified sampling (non-streaming only)
+             stratify_seed: Optional RNG seed for stratified sampling
+             drop_na_columns: Optional list of columns to check for None/empty values
+             **kwargs: Additional arguments passed to load_dataset
+
+         Returns:
+             BaseDataset instance
+
+         Raises:
+             FileNotFoundError: If CSV file doesn't exist
+             ValueError: If store is None or source is invalid
+             RuntimeError: If dataset loading fails
+         """
+         if store is None:
+             raise ValueError("store cannot be None")
+
+         use_streaming = loading_strategy == LoadingStrategy.STREAMING
+         if (stratify_by or drop_na_columns) and use_streaming:
+             raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")
+
+         ds = cls._load_csv_source(
+             source,
+             delimiter=delimiter,
+             streaming=use_streaming,
+             **kwargs,
+         )
+
+         if not use_streaming and (stratify_by or drop_na_columns):
+             ds = cls._postprocess_non_streaming_dataset(
+                 ds,
+                 stratify_by=stratify_by,
+                 stratify_seed=stratify_seed,
+                 drop_na_columns=drop_na_columns,
+             )
+
+         return cls(ds, store=store, loading_strategy=loading_strategy)
+
+     @classmethod
+     def from_json(
+         cls,
+         source: Union[str, Path],
+         store: Store,
+         *,
+         loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+         text_field: str = "text",
+         stratify_by: Optional[str] = None,
+         stratify_seed: Optional[int] = None,
+         drop_na_columns: Optional[List[str]] = None,
+         **kwargs,
+     ) -> "BaseDataset":
+         """
+         Load dataset from JSON or JSONL file.
+
+         Args:
+             source: Path to JSON or JSONL file
+             store: Store instance
+             loading_strategy: Loading strategy
+             text_field: Name of the field containing text (for JSON objects)
+             stratify_by: Optional column used for stratified sampling (non-streaming only)
+             stratify_seed: Optional RNG seed for stratified sampling
+             drop_na_columns: Optional list of columns to check for None/empty values
+             **kwargs: Additional arguments passed to load_dataset
+
+         Returns:
+             BaseDataset instance
+
+         Raises:
+             FileNotFoundError: If JSON file doesn't exist
+             ValueError: If store is None or source is invalid
+             RuntimeError: If dataset loading fails
+         """
+         if store is None:
+             raise ValueError("store cannot be None")
+
+         use_streaming = loading_strategy == LoadingStrategy.STREAMING
+         if (stratify_by or drop_na_columns) and use_streaming:
+             raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")
+
+         ds = cls._load_json_source(
+             source,
+             streaming=use_streaming,
+             **kwargs,
+         )
+
+         if not use_streaming and (stratify_by or drop_na_columns):
+             ds = cls._postprocess_non_streaming_dataset(
+                 ds,
+                 stratify_by=stratify_by,
+                 stratify_seed=stratify_seed,
+                 drop_na_columns=drop_na_columns,
+             )
+
+         return cls(ds, store=store, loading_strategy=loading_strategy)
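The allocation step in _stratified_sample above is a largest-remainder scheme: each label first gets the floor of its proportional share of the sample, then the leftover slots go to the labels with the largest fractional parts. A small standalone sketch of that idea (not the package's code; the label counts and sample size below are made up):

import math

def allocate(label_counts: dict[str, int], sample_size: int) -> dict[str, int]:
    """Largest-remainder allocation, mirroring the idea used in _stratified_sample."""
    total = sum(label_counts.values())
    allocations, fractions = {}, []
    for order, (label, count) in enumerate(label_counts.items()):
        exact = count / total * sample_size          # proportional share
        base = min(count, math.floor(exact))         # guaranteed floor allocation
        allocations[label] = base
        fractions.append((exact - base, order, label))
    remaining = sample_size - sum(allocations.values())
    # Hand out leftover slots to the largest fractional parts first.
    for _, _, label in sorted(fractions, key=lambda item: (-item[0], item[1])):
        if remaining <= 0:
            break
        take = min(label_counts[label] - allocations[label], remaining)
        allocations[label] += take
        remaining -= take
    return allocations

# Example: 1000 rows split 700/200/100, sampling 10 rows -> {'a': 7, 'b': 2, 'c': 1}
print(allocate({"a": 700, "b": 200, "c": 100}, sample_size=10))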