mi-crow 1.0.0__py3-none-any.whl → 1.0.0.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -114,11 +114,14 @@ class BaseDataset(ABC):
             OSError: If file system operations fail
             RuntimeError: If dataset operations fail
         """
+        if len(ds) == 0:
+            return ds
+
         if self._has_valid_dataset_dir():
             try:
                 self._dataset_dir.mkdir(parents=True, exist_ok=True)
                 ds.save_to_disk(str(self._dataset_dir))
-                return load_from_disk(str(self._dataset_dir))
+                return load_from_disk(str(self._dataset_dir), keep_in_memory=not use_memory_mapping)
             except OSError as e:
                 raise OSError(f"Failed to save/load dataset at {self._dataset_dir}. Error: {e}") from e
             except Exception as e:
@@ -522,6 +525,73 @@ class BaseDataset(ABC):
 
         return cls(ds, store=store, loading_strategy=loading_strategy)
 
+    @classmethod
+    def from_disk(
+        cls,
+        store: Store,
+        *,
+        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+        **kwargs: Any,
+    ) -> "BaseDataset":
+        """
+        Load dataset from already-saved Arrow files on disk.
+
+        Use this when you've previously saved a dataset and want to reload it
+        without re-downloading from HuggingFace or re-applying transformations.
+
+        Args:
+            store: Store instance pointing to where the dataset was saved
+                (dataset will be loaded from store.base_path/store.dataset_prefix/)
+            loading_strategy: Loading strategy (MEMORY or DISK only, not STREAMING)
+            **kwargs: Additional arguments (for subclass compatibility)
+
+        Returns:
+            BaseDataset instance loaded from disk
+
+        Raises:
+            ValueError: If store is None or loading_strategy is STREAMING
+            FileNotFoundError: If dataset directory doesn't exist
+            RuntimeError: If dataset loading fails
+
+        Example:
+            # First: save dataset
+            dataset_store = LocalStore("store/my_dataset")
+            dataset = ClassificationDataset.from_huggingface(..., store=dataset_store)
+            # Dataset saved to: store/my_dataset/datasets/*.arrow
+
+            # Later: reload from disk
+            dataset_store = LocalStore("store/my_dataset")
+            dataset = ClassificationDataset.from_disk(store=dataset_store)
+        """
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        if loading_strategy == LoadingStrategy.STREAMING:
+            raise ValueError("STREAMING loading strategy not supported for from_disk(). Use MEMORY or DISK.")
+
+        dataset_dir = Path(store.base_path) / store.dataset_prefix
+
+        if not dataset_dir.exists():
+            raise FileNotFoundError(
+                f"Dataset directory not found: {dataset_dir}. "
+                f"Make sure you've previously saved a dataset to this store location."
+            )
+
+        # Verify it's a valid Arrow dataset directory
+        arrow_files = list(dataset_dir.glob("*.arrow"))
+        if not arrow_files:
+            raise FileNotFoundError(
+                f"No Arrow files found in {dataset_dir}. Directory exists but doesn't contain a valid dataset."
+            )
+
+        try:
+            use_memory_mapping = loading_strategy == LoadingStrategy.DISK
+            ds = load_from_disk(str(dataset_dir), keep_in_memory=not use_memory_mapping)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load dataset from {dataset_dir}. Error: {e}") from e
+
+        return cls(ds, store=store, loading_strategy=loading_strategy)
+
     @classmethod
     def from_csv(
         cls,
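Context for the `keep_in_memory=not use_memory_mapping` change above: Hugging Face `datasets.load_from_disk` accepts a `keep_in_memory` flag, so the library's LoadingStrategy maps onto it directly, with DISK keeping the Arrow files memory-mapped and MEMORY copying the table into RAM. A minimal sketch of that distinction (the path below is hypothetical, not part of the package):

    from datasets import load_from_disk

    # Hypothetical directory previously produced by Dataset.save_to_disk(...)
    dataset_dir = "store/my_dataset/datasets"

    # DISK-style loading: Arrow files stay memory-mapped, keeping RAM usage low
    ds_mapped = load_from_disk(dataset_dir, keep_in_memory=False)

    # MEMORY-style loading: the whole table is materialised in RAM for faster random access
    ds_in_ram = load_from_disk(dataset_dir, keep_in_memory=True)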
@@ -3,7 +3,7 @@ from __future__ import annotations
 from pathlib import Path
 from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
 
-from datasets import Dataset, IterableDataset, load_dataset
+from datasets import Dataset, IterableDataset, load_dataset, load_from_disk
 
 from mi_crow.datasets.base_dataset import BaseDataset
 from mi_crow.datasets.loading_strategy import IndexLike, LoadingStrategy
@@ -117,8 +117,7 @@ class ClassificationDataset(BaseDataset):
 
         item = {"text": text}
         for cat_field in self._category_fields:
-            category = row.get(cat_field)
-            if category is None:
+            if cat_field not in row:
                 raise ValueError(
                     f"Category field '{cat_field}' not found in dataset row. Available fields: {list(row.keys())}"
                 )
@@ -157,9 +156,7 @@ class ClassificationDataset(BaseDataset):
             ValueError: If dataset is empty
         """
         if self._loading_strategy == LoadingStrategy.STREAMING:
-            raise NotImplementedError(
-                "Indexing not supported for STREAMING datasets. Use iter_items or iter_batches."
-            )
+            raise NotImplementedError("Indexing not supported for STREAMING datasets. Use iter_items or iter_batches.")
 
         dataset_len = len(self)
         if dataset_len == 0:
@@ -446,6 +443,89 @@ class ClassificationDataset(BaseDataset):
             category_field=category_field,
         )
 
+    @classmethod
+    def from_disk(
+        cls,
+        store: Store,
+        *,
+        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+        text_field: str = "text",
+        category_field: Union[str, List[str]] = "category",
+    ) -> "ClassificationDataset":
+        """
+        Load classification dataset from already-saved Arrow files on disk.
+
+        Use this when you've previously saved a dataset and want to reload it
+        without re-downloading from HuggingFace or re-applying transformations.
+
+        Args:
+            store: Store instance pointing to where the dataset was saved
+            loading_strategy: Loading strategy (MEMORY or DISK only)
+            text_field: Name of the column containing text
+            category_field: Name(s) of the column(s) containing category/label
+
+        Returns:
+            ClassificationDataset instance loaded from disk
+
+        Raises:
+            FileNotFoundError: If dataset directory doesn't exist or contains no Arrow files
+            ValueError: If required fields are not in the loaded dataset
+
+        Example:
+            # First: save dataset
+            dataset_store = LocalStore("store/wgmix_test")
+            dataset = ClassificationDataset.from_huggingface(
+                "allenai/wildguardmix",
+                store=dataset_store,
+                limit=100
+            )
+            # Dataset saved to: store/wgmix_test/datasets/*.arrow
+
+            # Later: reload from disk
+            dataset_store = LocalStore("store/wgmix_test")
+            dataset = ClassificationDataset.from_disk(
+                store=dataset_store,
+                text_field="prompt",
+                category_field="prompt_harm_label"
+            )
+        """
+
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        if loading_strategy == LoadingStrategy.STREAMING:
+            raise ValueError("STREAMING loading strategy not supported for from_disk(). Use MEMORY or DISK.")
+
+        dataset_dir = Path(store.base_path) / store.dataset_prefix
+
+        if not dataset_dir.exists():
+            raise FileNotFoundError(
+                f"Dataset directory not found: {dataset_dir}. "
+                f"Make sure you've previously saved a dataset to this store location."
+            )
+
+        # Verify it's a valid Arrow dataset directory
+        arrow_files = list(dataset_dir.glob("*.arrow"))
+        if not arrow_files:
+            raise FileNotFoundError(
+                f"No Arrow files found in {dataset_dir}. Directory exists but doesn't contain a valid dataset."
+            )
+
+        try:
+            use_memory_mapping = loading_strategy == LoadingStrategy.DISK
+            ds = load_from_disk(str(dataset_dir), keep_in_memory=not use_memory_mapping)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load dataset from {dataset_dir}. Error: {e}") from e
+
+        # Create ClassificationDataset with the loaded dataset and field names
+        return cls(
+            ds,
+            store=store,
+            loading_strategy=loading_strategy,
+            text_field=text_field,
+            category_field=category_field,
+        )
+
     @classmethod
     def from_csv(
         cls,
@@ -483,24 +563,37 @@ class ClassificationDataset(BaseDataset):
             FileNotFoundError: If CSV file doesn't exist
             RuntimeError: If dataset loading fails
         """
-        drop_na_columns = None
-        if drop_na:
-            cat_fields = [category_field] if isinstance(category_field, str) else category_field
-            drop_na_columns = [text_field] + list(cat_fields)
+        if store is None:
+            raise ValueError("store cannot be None")
 
-        dataset = super().from_csv(
+        use_streaming = loading_strategy == LoadingStrategy.STREAMING
+        if (stratify_by or drop_na) and use_streaming:
+            raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")
+
+        # Load CSV using parent's static method
+        ds = cls._load_csv_source(
             source,
-            store=store,
-            loading_strategy=loading_strategy,
-            text_field=text_field,
             delimiter=delimiter,
-            stratify_by=stratify_by,
-            stratify_seed=stratify_seed,
-            drop_na_columns=drop_na_columns,
+            streaming=use_streaming,
             **kwargs,
         )
+
+        # Apply postprocessing if not streaming
+        if not use_streaming and (stratify_by or drop_na):
+            drop_na_columns = None
+            if drop_na:
+                cat_fields = [category_field] if isinstance(category_field, str) else category_field
+                drop_na_columns = [text_field] + list(cat_fields)
+
+            ds = cls._postprocess_non_streaming_dataset(
+                ds,
+                stratify_by=stratify_by,
+                stratify_seed=stratify_seed,
+                drop_na_columns=drop_na_columns,
+            )
+
         return cls(
-            dataset._ds,
+            ds,
             store=store,
             loading_strategy=loading_strategy,
             text_field=text_field,
@@ -542,23 +635,36 @@ class ClassificationDataset(BaseDataset):
             FileNotFoundError: If JSON file doesn't exist
             RuntimeError: If dataset loading fails
         """
-        drop_na_columns = None
-        if drop_na:
-            cat_fields = [category_field] if isinstance(category_field, str) else category_field
-            drop_na_columns = [text_field] + list(cat_fields)
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        use_streaming = loading_strategy == LoadingStrategy.STREAMING
+        if (stratify_by or drop_na) and use_streaming:
+            raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")
 
-        dataset = super().from_json(
+        # Load JSON using parent's static method
+        ds = cls._load_json_source(
             source,
-            store=store,
-            loading_strategy=loading_strategy,
-            text_field=text_field,
-            stratify_by=stratify_by,
-            stratify_seed=stratify_seed,
-            drop_na_columns=drop_na_columns,
+            streaming=use_streaming,
             **kwargs,
         )
+
+        # Apply postprocessing if not streaming
+        if not use_streaming and (stratify_by or drop_na):
+            drop_na_columns = None
+            if drop_na:
+                cat_fields = [category_field] if isinstance(category_field, str) else category_field
+                drop_na_columns = [text_field] + list(cat_fields)
+
+            ds = cls._postprocess_non_streaming_dataset(
+                ds,
+                stratify_by=stratify_by,
+                stratify_seed=stratify_seed,
+                drop_na_columns=drop_na_columns,
+            )
+
         return cls(
-            dataset._ds,
+            ds,
             store=store,
             loading_strategy=loading_strategy,
             text_field=text_field,
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
+import random
 from pathlib import Path
 from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
 
-from datasets import Dataset, IterableDataset, load_dataset
+from datasets import Dataset, IterableDataset, load_dataset, load_from_disk
 
 from mi_crow.datasets.base_dataset import BaseDataset
 from mi_crow.datasets.loading_strategy import IndexLike, LoadingStrategy
@@ -20,7 +21,7 @@ class TextDataset(BaseDataset):
         self,
         ds: Dataset | IterableDataset,
         store: Store,
-        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+        loading_strategy: LoadingStrategy = LoadingStrategy.DISK,
         text_field: str = "text",
     ):
         """
@@ -115,9 +116,7 @@ class TextDataset(BaseDataset):
             ValueError: If dataset is empty
         """
         if self._loading_strategy == LoadingStrategy.STREAMING:
-            raise NotImplementedError(
-                "Indexing not supported for STREAMING datasets. Use iter_items or iter_batches."
-            )
+            raise NotImplementedError("Indexing not supported for STREAMING datasets. Use iter_items or iter_batches.")
 
         dataset_len = len(self)
         if dataset_len == 0:
@@ -217,6 +216,48 @@ class TextDataset(BaseDataset):
             return list(self.iter_items())
         return list(self._ds["text"])
 
+    def random_sample(self, n: int, seed: Optional[int] = None) -> "TextDataset":
+        """Create a new TextDataset with n randomly sampled items.
+
+        Args:
+            n: Number of items to sample
+            seed: Optional random seed for reproducibility
+
+        Returns:
+            New TextDataset instance with sampled items
+
+        Raises:
+            NotImplementedError: If loading_strategy is STREAMING
+            ValueError: If n <= 0
+        """
+        if self._loading_strategy == LoadingStrategy.STREAMING:
+            raise NotImplementedError(
+                "random_sample() not supported for STREAMING datasets. Use iter_items() and sample manually."
+            )
+
+        if n <= 0:
+            raise ValueError(f"n must be > 0, got: {n}")
+
+        dataset_len = len(self)
+        if n >= dataset_len:
+            if seed is not None:
+                random.seed(seed)
+            indices = list(range(dataset_len))
+            random.shuffle(indices)
+            sampled_ds = self._ds.select(indices)
+        else:
+            if seed is not None:
+                random.seed(seed)
+            indices = random.sample(range(dataset_len), n)
+            sampled_ds = self._ds.select(indices)
+
+        return TextDataset(
+            sampled_ds,
+            store=self._store,
+            loading_strategy=self._loading_strategy,
+            text_field=self._text_field,
+        )
+
     @classmethod
     def from_huggingface(
         cls,
@@ -300,6 +341,81 @@ class TextDataset(BaseDataset):
 
         return cls(ds, store=store, loading_strategy=loading_strategy, text_field=text_field)
 
+    @classmethod
+    def from_disk(
+        cls,
+        store: Store,
+        *,
+        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+        text_field: str = "text",
+    ) -> "TextDataset":
+        """
+        Load text dataset from already-saved Arrow files on disk.
+
+        Use this when you've previously saved a dataset and want to reload it
+        without re-downloading from HuggingFace or re-applying transformations.
+
+        Args:
+            store: Store instance pointing to where the dataset was saved
+            loading_strategy: Loading strategy (MEMORY or DISK only)
+            text_field: Name of the column containing text
+
+        Returns:
+            TextDataset instance loaded from disk
+
+        Raises:
+            FileNotFoundError: If dataset directory doesn't exist or contains no Arrow files
+
+        Example:
+            # First: save dataset
+            dataset_store = LocalStore("store/my_texts")
+            dataset = TextDataset.from_huggingface(
+                "wikipedia",
+                store=dataset_store,
+                limit=1000
+            )
+            # Dataset saved to: store/my_texts/datasets/*.arrow
+
+            # Later: reload from disk
+            dataset_store = LocalStore("store/my_texts")
+            dataset = TextDataset.from_disk(store=dataset_store)
+        """
+
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        if loading_strategy == LoadingStrategy.STREAMING:
+            raise ValueError("STREAMING loading strategy not supported for from_disk(). Use MEMORY or DISK.")
+
+        dataset_dir = Path(store.base_path) / store.dataset_prefix
+
+        if not dataset_dir.exists():
+            raise FileNotFoundError(
+                f"Dataset directory not found: {dataset_dir}. "
+                f"Make sure you've previously saved a dataset to this store location."
+            )
+
+        # Verify it's a valid Arrow dataset directory
+        arrow_files = list(dataset_dir.glob("*.arrow"))
+        if not arrow_files:
+            raise FileNotFoundError(
+                f"No Arrow files found in {dataset_dir}. Directory exists but doesn't contain a valid dataset."
+            )
+
+        try:
+            use_memory_mapping = loading_strategy == LoadingStrategy.DISK
+            ds = load_from_disk(str(dataset_dir), keep_in_memory=not use_memory_mapping)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load dataset from {dataset_dir}. Error: {e}") from e
+
+        # Create TextDataset with the loaded dataset and field name
+        return cls(
+            ds,
+            store=store,
+            loading_strategy=loading_strategy,
+            text_field=text_field,
+        )
+
     @classmethod
     def from_csv(
         cls,
@@ -335,20 +451,33 @@ class TextDataset(BaseDataset):
             FileNotFoundError: If CSV file doesn't exist
             RuntimeError: If dataset loading fails
         """
-        drop_na_columns = [text_field] if drop_na else None
-        dataset = super().from_csv(
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        use_streaming = loading_strategy == LoadingStrategy.STREAMING
+        if (stratify_by or drop_na) and use_streaming:
+            raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")
+
+        # Load CSV using parent's static method
+        ds = cls._load_csv_source(
             source,
-            store=store,
-            loading_strategy=loading_strategy,
-            text_field=text_field,
             delimiter=delimiter,
-            stratify_by=stratify_by,
-            stratify_seed=stratify_seed,
-            drop_na_columns=drop_na_columns,
+            streaming=use_streaming,
             **kwargs,
         )
+
+        # Apply postprocessing if not streaming
+        if not use_streaming and (stratify_by or drop_na):
+            drop_na_columns = [text_field] if drop_na else None
+            ds = cls._postprocess_non_streaming_dataset(
+                ds,
+                stratify_by=stratify_by,
+                stratify_seed=stratify_seed,
+                drop_na_columns=drop_na_columns,
+            )
+
         return cls(
-            dataset._ds,
+            ds,
             store=store,
             loading_strategy=loading_strategy,
             text_field=text_field,
@@ -387,20 +516,32 @@ class TextDataset(BaseDataset):
             FileNotFoundError: If JSON file doesn't exist
             RuntimeError: If dataset loading fails
         """
-        drop_na_columns = [text_field] if drop_na else None
-        dataset = super().from_json(
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        use_streaming = loading_strategy == LoadingStrategy.STREAMING
+        if (stratify_by or drop_na) and use_streaming:
+            raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")
+
+        # Load JSON using parent's static method
+        ds = cls._load_json_source(
             source,
-            store=store,
-            loading_strategy=loading_strategy,
-            text_field=text_field,
-            stratify_by=stratify_by,
-            stratify_seed=stratify_seed,
-            drop_na_columns=drop_na_columns,
+            streaming=use_streaming,
             **kwargs,
         )
-        # Re-initialize with text_field
+
+        # Apply postprocessing if not streaming
+        if not use_streaming and (stratify_by or drop_na):
+            drop_na_columns = [text_field] if drop_na else None
+            ds = cls._postprocess_non_streaming_dataset(
+                ds,
+                stratify_by=stratify_by,
+                stratify_seed=stratify_seed,
+                drop_na_columns=drop_na_columns,
+            )
+
         return cls(
-            dataset._ds,
+            ds,
             store=store,
             loading_strategy=loading_strategy,
             text_field=text_field,
@@ -7,7 +7,11 @@ import torch
 import torch.nn as nn
 
 from mi_crow.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
-from mi_crow.hooks.utils import extract_tensor_from_input, extract_tensor_from_output
+from mi_crow.hooks.utils import (
+    extract_tensor_from_input,
+    extract_tensor_from_output,
+    apply_modification_to_output
+)
 from mi_crow.utils import get_logger
 
 if TYPE_CHECKING:
@@ -86,13 +90,14 @@ class Controller(Hook):
         if output_tensor is None:
             return
 
-        # Extract input tensor if available for modify_activations
         input_tensor = extract_tensor_from_input(input)
-
-        # Note: forward hooks can't modify output in PyTorch, but we call modify_activations
-        # for consistency. The actual modification happens via the hook mechanism.
-        # We still call it so controllers can capture/process activations.
-        self.modify_activations(module, input_tensor, output_tensor)
+        modified_tensor = self.modify_activations(module, input_tensor, output_tensor)
+
+        if modified_tensor is not None and isinstance(modified_tensor, torch.Tensor):
+            target_device = None
+            if self.context is not None and hasattr(self.context, 'device') and self.context.device:
+                target_device = torch.device(self.context.device)
+            apply_modification_to_output(output, modified_tensor, target_device=target_device)
 
     def _hook_fn(
         self,
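The Controller change above stops discarding the return value of `modify_activations` and writes it back into the module output via `apply_modification_to_output`, presumably by mutating the existing output structure in place (that helper's implementation is not shown in this diff). The removed comment claimed forward hooks cannot modify outputs, but standard PyTorch behaviour is that a forward hook receiving `(module, input, output)` can alter the output either in place or by returning a replacement value. A minimal sketch of that mechanism, independent of mi_crow (the hook and module below are illustrative only):

    import torch
    import torch.nn as nn

    def scaling_hook(module: nn.Module, inputs: tuple, output: torch.Tensor) -> torch.Tensor:
        # Returning a non-None value from a forward hook replaces the module's output
        return output * 2.0

    layer = nn.Linear(4, 4)
    handle = layer.register_forward_hook(scaling_hook)

    x = torch.randn(1, 4)
    y = layer(x)  # y is the scaled output produced by the hook
    handle.remove()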