mi_crow-1.0.0-py3-none-any.whl → mi_crow-1.0.0.post1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mi_crow/datasets/base_dataset.py +71 -1
- mi_crow/datasets/classification_dataset.py +136 -30
- mi_crow/datasets/text_dataset.py +165 -24
- mi_crow/hooks/controller.py +12 -7
- mi_crow/hooks/implementations/layer_activation_detector.py +30 -34
- mi_crow/hooks/implementations/model_input_detector.py +87 -87
- mi_crow/hooks/implementations/model_output_detector.py +43 -42
- mi_crow/hooks/utils.py +74 -0
- mi_crow/language_model/activations.py +174 -77
- mi_crow/language_model/device_manager.py +119 -0
- mi_crow/language_model/inference.py +18 -5
- mi_crow/language_model/initialization.py +10 -6
- mi_crow/language_model/language_model.py +67 -97
- mi_crow/language_model/layers.py +16 -13
- mi_crow/language_model/persistence.py +4 -2
- mi_crow/language_model/utils.py +5 -5
- mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py +157 -95
- mi_crow/mechanistic/sae/concepts/concept_dictionary.py +12 -2
- mi_crow/mechanistic/sae/concepts/text_heap.py +161 -0
- mi_crow/mechanistic/sae/modules/topk_sae.py +29 -22
- mi_crow/mechanistic/sae/sae.py +3 -1
- mi_crow/mechanistic/sae/sae_trainer.py +362 -29
- mi_crow/store/local_store.py +11 -5
- mi_crow/store/store.py +34 -1
- {mi_crow-1.0.0.dist-info → mi_crow-1.0.0.post1.dist-info}/METADATA +2 -1
- {mi_crow-1.0.0.dist-info → mi_crow-1.0.0.post1.dist-info}/RECORD +28 -26
- {mi_crow-1.0.0.dist-info → mi_crow-1.0.0.post1.dist-info}/WHEEL +0 -0
- {mi_crow-1.0.0.dist-info → mi_crow-1.0.0.post1.dist-info}/top_level.txt +0 -0
mi_crow/datasets/base_dataset.py
CHANGED
@@ -114,11 +114,14 @@ class BaseDataset(ABC):
             OSError: If file system operations fail
             RuntimeError: If dataset operations fail
         """
+        if len(ds) == 0:
+            return ds
+
         if self._has_valid_dataset_dir():
             try:
                 self._dataset_dir.mkdir(parents=True, exist_ok=True)
                 ds.save_to_disk(str(self._dataset_dir))
-                return load_from_disk(str(self._dataset_dir))
+                return load_from_disk(str(self._dataset_dir), keep_in_memory=not use_memory_mapping)
             except OSError as e:
                 raise OSError(f"Failed to save/load dataset at {self._dataset_dir}. Error: {e}") from e
             except Exception as e:
@@ -522,6 +525,73 @@ class BaseDataset(ABC):
 
         return cls(ds, store=store, loading_strategy=loading_strategy)
 
+    @classmethod
+    def from_disk(
+        cls,
+        store: Store,
+        *,
+        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+        **kwargs: Any,
+    ) -> "BaseDataset":
+        """
+        Load dataset from already-saved Arrow files on disk.
+
+        Use this when you've previously saved a dataset and want to reload it
+        without re-downloading from HuggingFace or re-applying transformations.
+
+        Args:
+            store: Store instance pointing to where the dataset was saved
+                (dataset will be loaded from store.base_path/store.dataset_prefix/)
+            loading_strategy: Loading strategy (MEMORY or DISK only, not STREAMING)
+            **kwargs: Additional arguments (for subclass compatibility)
+
+        Returns:
+            BaseDataset instance loaded from disk
+
+        Raises:
+            ValueError: If store is None or loading_strategy is STREAMING
+            FileNotFoundError: If dataset directory doesn't exist
+            RuntimeError: If dataset loading fails
+
+        Example:
+            # First: save dataset
+            dataset_store = LocalStore("store/my_dataset")
+            dataset = ClassificationDataset.from_huggingface(..., store=dataset_store)
+            # Dataset saved to: store/my_dataset/datasets/*.arrow
+
+            # Later: reload from disk
+            dataset_store = LocalStore("store/my_dataset")
+            dataset = ClassificationDataset.from_disk(store=dataset_store)
+        """
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        if loading_strategy == LoadingStrategy.STREAMING:
+            raise ValueError("STREAMING loading strategy not supported for from_disk(). Use MEMORY or DISK.")
+
+        dataset_dir = Path(store.base_path) / store.dataset_prefix
+
+        if not dataset_dir.exists():
+            raise FileNotFoundError(
+                f"Dataset directory not found: {dataset_dir}. "
+                f"Make sure you've previously saved a dataset to this store location."
+            )
+
+        # Verify it's a valid Arrow dataset directory
+        arrow_files = list(dataset_dir.glob("*.arrow"))
+        if not arrow_files:
+            raise FileNotFoundError(
+                f"No Arrow files found in {dataset_dir}. Directory exists but doesn't contain a valid dataset."
+            )
+
+        try:
+            use_memory_mapping = loading_strategy == LoadingStrategy.DISK
+            ds = load_from_disk(str(dataset_dir), keep_in_memory=not use_memory_mapping)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load dataset from {dataset_dir}. Error: {e}") from e
+
+        return cls(ds, store=store, loading_strategy=loading_strategy)
+
     @classmethod
     def from_csv(
         cls,
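The new `from_disk` classmethod closes the loop with the existing save path: `from_huggingface(..., store=...)` persists Arrow files under `store.base_path/store.dataset_prefix/`, and `from_disk` reloads them, memory-mapping the files when `LoadingStrategy.DISK` is requested (`keep_in_memory=False`). A minimal round-trip sketch based on the docstring examples; the import paths mirror the file layout in the list above:

    from mi_crow.datasets.loading_strategy import LoadingStrategy
    from mi_crow.datasets.text_dataset import TextDataset
    from mi_crow.store.local_store import LocalStore

    # First run: download from HuggingFace and persist Arrow files to the store
    store = LocalStore("store/my_texts")
    dataset = TextDataset.from_huggingface("wikipedia", store=store, limit=1000)

    # Later runs: reload from disk; DISK memory-maps instead of loading into RAM
    store = LocalStore("store/my_texts")
    dataset = TextDataset.from_disk(store=store, loading_strategy=LoadingStrategy.DISK)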
mi_crow/datasets/classification_dataset.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 from pathlib import Path
 from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
 
-from datasets import Dataset, IterableDataset, load_dataset
+from datasets import Dataset, IterableDataset, load_dataset, load_from_disk
 
 from mi_crow.datasets.base_dataset import BaseDataset
 from mi_crow.datasets.loading_strategy import IndexLike, LoadingStrategy
@@ -117,8 +117,7 @@ class ClassificationDataset(BaseDataset):
 
         item = {"text": text}
         for cat_field in self._category_fields:
-
-            if category is None:
+            if cat_field not in row:
                 raise ValueError(
                     f"Category field '{cat_field}' not found in dataset row. Available fields: {list(row.keys())}"
                 )
@@ -157,9 +156,7 @@ class ClassificationDataset(BaseDataset):
             ValueError: If dataset is empty
         """
         if self._loading_strategy == LoadingStrategy.STREAMING:
-            raise NotImplementedError(
-                "Indexing not supported for STREAMING datasets. Use iter_items or iter_batches."
-            )
+            raise NotImplementedError("Indexing not supported for STREAMING datasets. Use iter_items or iter_batches.")
 
         dataset_len = len(self)
         if dataset_len == 0:
@@ -446,6 +443,89 @@ class ClassificationDataset(BaseDataset):
             category_field=category_field,
         )
 
+    @classmethod
+    def from_disk(
+        cls,
+        store: Store,
+        *,
+        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+        text_field: str = "text",
+        category_field: Union[str, List[str]] = "category",
+    ) -> "ClassificationDataset":
+        """
+        Load classification dataset from already-saved Arrow files on disk.
+
+        Use this when you've previously saved a dataset and want to reload it
+        without re-downloading from HuggingFace or re-applying transformations.
+
+        Args:
+            store: Store instance pointing to where the dataset was saved
+            loading_strategy: Loading strategy (MEMORY or DISK only)
+            text_field: Name of the column containing text
+            category_field: Name(s) of the column(s) containing category/label
+
+        Returns:
+            ClassificationDataset instance loaded from disk
+
+        Raises:
+            FileNotFoundError: If dataset directory doesn't exist or contains no Arrow files
+            ValueError: If required fields are not in the loaded dataset
+
+        Example:
+            # First: save dataset
+            dataset_store = LocalStore("store/wgmix_test")
+            dataset = ClassificationDataset.from_huggingface(
+                "allenai/wildguardmix",
+                store=dataset_store,
+                limit=100
+            )
+            # Dataset saved to: store/wgmix_test/datasets/*.arrow
+
+            # Later: reload from disk
+            dataset_store = LocalStore("store/wgmix_test")
+            dataset = ClassificationDataset.from_disk(
+                store=dataset_store,
+                text_field="prompt",
+                category_field="prompt_harm_label"
+            )
+        """
+
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        if loading_strategy == LoadingStrategy.STREAMING:
+            raise ValueError("STREAMING loading strategy not supported for from_disk(). Use MEMORY or DISK.")
+
+        dataset_dir = Path(store.base_path) / store.dataset_prefix
+
+        if not dataset_dir.exists():
+            raise FileNotFoundError(
+                f"Dataset directory not found: {dataset_dir}. "
+                f"Make sure you've previously saved a dataset to this store location."
+            )
+
+        # Verify it's a valid Arrow dataset directory
+        arrow_files = list(dataset_dir.glob("*.arrow"))
+        if not arrow_files:
+            raise FileNotFoundError(
+                f"No Arrow files found in {dataset_dir}. Directory exists but doesn't contain a valid dataset."
+            )
+
+        try:
+            use_memory_mapping = loading_strategy == LoadingStrategy.DISK
+            ds = load_from_disk(str(dataset_dir), keep_in_memory=not use_memory_mapping)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load dataset from {dataset_dir}. Error: {e}") from e
+
+        # Create ClassificationDataset with the loaded dataset and field names
+        return cls(
+            ds,
+            store=store,
+            loading_strategy=loading_strategy,
+            text_field=text_field,
+            category_field=category_field,
+        )
+
     @classmethod
     def from_csv(
         cls,
@@ -483,24 +563,37 @@ class ClassificationDataset(BaseDataset):
             FileNotFoundError: If CSV file doesn't exist
             RuntimeError: If dataset loading fails
         """
-
-
-        cat_fields = [category_field] if isinstance(category_field, str) else category_field
-        drop_na_columns = [text_field] + list(cat_fields)
+        if store is None:
+            raise ValueError("store cannot be None")
 
-
+        use_streaming = loading_strategy == LoadingStrategy.STREAMING
+        if (stratify_by or drop_na) and use_streaming:
+            raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")
+
+        # Load CSV using parent's static method
+        ds = cls._load_csv_source(
             source,
-            store=store,
-            loading_strategy=loading_strategy,
-            text_field=text_field,
             delimiter=delimiter,
-            stratify_by=stratify_by,
-            stratify_seed=stratify_seed,
-            drop_na_columns=drop_na_columns,
+            streaming=use_streaming,
             **kwargs,
         )
+
+        # Apply postprocessing if not streaming
+        if not use_streaming and (stratify_by or drop_na):
+            drop_na_columns = None
+            if drop_na:
+                cat_fields = [category_field] if isinstance(category_field, str) else category_field
+                drop_na_columns = [text_field] + list(cat_fields)
+
+            ds = cls._postprocess_non_streaming_dataset(
+                ds,
+                stratify_by=stratify_by,
+                stratify_seed=stratify_seed,
+                drop_na_columns=drop_na_columns,
+            )
+
         return cls(
-
+            ds,
             store=store,
             loading_strategy=loading_strategy,
             text_field=text_field,
@@ -542,23 +635,36 @@ class ClassificationDataset(BaseDataset):
             FileNotFoundError: If JSON file doesn't exist
             RuntimeError: If dataset loading fails
         """
-
-
-
-
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        use_streaming = loading_strategy == LoadingStrategy.STREAMING
+        if (stratify_by or drop_na) and use_streaming:
+            raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")
 
-
+        # Load JSON using parent's static method
+        ds = cls._load_json_source(
             source,
-            store=store,
-            loading_strategy=loading_strategy,
-            text_field=text_field,
-            stratify_by=stratify_by,
-            stratify_seed=stratify_seed,
-            drop_na_columns=drop_na_columns,
+            streaming=use_streaming,
             **kwargs,
         )
+
+        # Apply postprocessing if not streaming
+        if not use_streaming and (stratify_by or drop_na):
+            drop_na_columns = None
+            if drop_na:
+                cat_fields = [category_field] if isinstance(category_field, str) else category_field
+                drop_na_columns = [text_field] + list(cat_fields)
+
+            ds = cls._postprocess_non_streaming_dataset(
+                ds,
+                stratify_by=stratify_by,
+                stratify_seed=stratify_seed,
+                drop_na_columns=drop_na_columns,
+            )
+
         return cls(
-
+            ds,
             store=store,
             loading_strategy=loading_strategy,
             text_field=text_field,
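`from_csv` and `from_json` were restructured into the same two-phase flow: `_load_csv_source`/`_load_json_source` only reads the file (optionally streaming), then `_postprocess_non_streaming_dataset` handles stratification and NA-dropping, which STREAMING now rejects up front with `NotImplementedError`. A hedged usage sketch; the file path and column names are illustrative, not taken from the package:

    from mi_crow.datasets.classification_dataset import ClassificationDataset
    from mi_crow.datasets.loading_strategy import LoadingStrategy
    from mi_crow.store.local_store import LocalStore

    store = LocalStore("store/csv_demo")   # hypothetical store location
    dataset = ClassificationDataset.from_csv(
        "data/labeled.csv",                # hypothetical CSV file
        store=store,
        loading_strategy=LoadingStrategy.MEMORY,
        text_field="text",
        category_field="label",            # illustrative column name
        stratify_by="label",
        drop_na=True,                      # drops rows with NA in text/category columns
    )
    # With LoadingStrategy.STREAMING, stratify_by/drop_na raise NotImplementedError.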
mi_crow/datasets/text_dataset.py
CHANGED
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
+import random
 from pathlib import Path
 from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
 
-from datasets import Dataset, IterableDataset, load_dataset
+from datasets import Dataset, IterableDataset, load_dataset, load_from_disk
 
 from mi_crow.datasets.base_dataset import BaseDataset
 from mi_crow.datasets.loading_strategy import IndexLike, LoadingStrategy
@@ -20,7 +21,7 @@ class TextDataset(BaseDataset):
         self,
         ds: Dataset | IterableDataset,
         store: Store,
-        loading_strategy: LoadingStrategy = LoadingStrategy.
+        loading_strategy: LoadingStrategy = LoadingStrategy.DISK,
         text_field: str = "text",
     ):
         """
@@ -115,9 +116,7 @@ class TextDataset(BaseDataset):
             ValueError: If dataset is empty
         """
         if self._loading_strategy == LoadingStrategy.STREAMING:
-            raise NotImplementedError(
-                "Indexing not supported for STREAMING datasets. Use iter_items or iter_batches."
-            )
+            raise NotImplementedError("Indexing not supported for STREAMING datasets. Use iter_items or iter_batches.")
 
         dataset_len = len(self)
         if dataset_len == 0:
@@ -217,6 +216,48 @@ class TextDataset(BaseDataset):
             return list(self.iter_items())
         return list(self._ds["text"])
 
+    def random_sample(self, n: int, seed: Optional[int] = None) -> "TextDataset":
+        """Create a new TextDataset with n randomly sampled items.
+
+        Args:
+            n: Number of items to sample
+            seed: Optional random seed for reproducibility
+
+        Returns:
+            New TextDataset instance with sampled items
+
+        Raises:
+            NotImplementedError: If loading_strategy is STREAMING
+            ValueError: If n <= 0
+        """
+        if self._loading_strategy == LoadingStrategy.STREAMING:
+            raise NotImplementedError(
+                "random_sample() not supported for STREAMING datasets. Use iter_items() and sample manually."
+            )
+
+        if n <= 0:
+            raise ValueError(f"n must be > 0, got: {n}")
+
+        dataset_len = len(self)
+        if n >= dataset_len:
+            if seed is not None:
+                random.seed(seed)
+            indices = list(range(dataset_len))
+            random.shuffle(indices)
+            sampled_ds = self._ds.select(indices)
+        else:
+            if seed is not None:
+                random.seed(seed)
+            indices = random.sample(range(dataset_len), n)
+            sampled_ds = self._ds.select(indices)
+
+        return TextDataset(
+            sampled_ds,
+            store=self._store,
+            loading_strategy=self._loading_strategy,
+            text_field=self._text_field,
+        )
+
     @classmethod
     def from_huggingface(
         cls,
@@ -300,6 +341,81 @@ class TextDataset(BaseDataset):
 
         return cls(ds, store=store, loading_strategy=loading_strategy, text_field=text_field)
 
+    @classmethod
+    def from_disk(
+        cls,
+        store: Store,
+        *,
+        loading_strategy: LoadingStrategy = LoadingStrategy.MEMORY,
+        text_field: str = "text",
+    ) -> "TextDataset":
+        """
+        Load text dataset from already-saved Arrow files on disk.
+
+        Use this when you've previously saved a dataset and want to reload it
+        without re-downloading from HuggingFace or re-applying transformations.
+
+        Args:
+            store: Store instance pointing to where the dataset was saved
+            loading_strategy: Loading strategy (MEMORY or DISK only)
+            text_field: Name of the column containing text
+
+        Returns:
+            TextDataset instance loaded from disk
+
+        Raises:
+            FileNotFoundError: If dataset directory doesn't exist or contains no Arrow files
+
+        Example:
+            # First: save dataset
+            dataset_store = LocalStore("store/my_texts")
+            dataset = TextDataset.from_huggingface(
+                "wikipedia",
+                store=dataset_store,
+                limit=1000
+            )
+            # Dataset saved to: store/my_texts/datasets/*.arrow
+
+            # Later: reload from disk
+            dataset_store = LocalStore("store/my_texts")
+            dataset = TextDataset.from_disk(store=dataset_store)
+        """
+
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        if loading_strategy == LoadingStrategy.STREAMING:
+            raise ValueError("STREAMING loading strategy not supported for from_disk(). Use MEMORY or DISK.")
+
+        dataset_dir = Path(store.base_path) / store.dataset_prefix
+
+        if not dataset_dir.exists():
+            raise FileNotFoundError(
+                f"Dataset directory not found: {dataset_dir}. "
+                f"Make sure you've previously saved a dataset to this store location."
+            )
+
+        # Verify it's a valid Arrow dataset directory
+        arrow_files = list(dataset_dir.glob("*.arrow"))
+        if not arrow_files:
+            raise FileNotFoundError(
+                f"No Arrow files found in {dataset_dir}. Directory exists but doesn't contain a valid dataset."
+            )
+
+        try:
+            use_memory_mapping = loading_strategy == LoadingStrategy.DISK
+            ds = load_from_disk(str(dataset_dir), keep_in_memory=not use_memory_mapping)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load dataset from {dataset_dir}. Error: {e}") from e
+
+        # Create TextDataset with the loaded dataset and field name
+        return cls(
+            ds,
+            store=store,
+            loading_strategy=loading_strategy,
+            text_field=text_field,
+        )
+
     @classmethod
     def from_csv(
         cls,
@@ -335,20 +451,33 @@ class TextDataset(BaseDataset):
             FileNotFoundError: If CSV file doesn't exist
             RuntimeError: If dataset loading fails
         """
-
-
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        use_streaming = loading_strategy == LoadingStrategy.STREAMING
+        if (stratify_by or drop_na) and use_streaming:
+            raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")
+
+        # Load CSV using parent's static method
+        ds = cls._load_csv_source(
             source,
-            store=store,
-            loading_strategy=loading_strategy,
-            text_field=text_field,
             delimiter=delimiter,
-            stratify_by=stratify_by,
-            stratify_seed=stratify_seed,
-            drop_na_columns=drop_na_columns,
+            streaming=use_streaming,
             **kwargs,
         )
+
+        # Apply postprocessing if not streaming
+        if not use_streaming and (stratify_by or drop_na):
+            drop_na_columns = [text_field] if drop_na else None
+            ds = cls._postprocess_non_streaming_dataset(
+                ds,
+                stratify_by=stratify_by,
+                stratify_seed=stratify_seed,
+                drop_na_columns=drop_na_columns,
+            )
+
         return cls(
-
+            ds,
             store=store,
             loading_strategy=loading_strategy,
             text_field=text_field,
@@ -387,20 +516,32 @@ class TextDataset(BaseDataset):
             FileNotFoundError: If JSON file doesn't exist
             RuntimeError: If dataset loading fails
         """
-
-
+        if store is None:
+            raise ValueError("store cannot be None")
+
+        use_streaming = loading_strategy == LoadingStrategy.STREAMING
+        if (stratify_by or drop_na) and use_streaming:
+            raise NotImplementedError("Stratification and drop_na are not supported for STREAMING datasets.")
+
+        # Load JSON using parent's static method
+        ds = cls._load_json_source(
            source,
-
-            loading_strategy=loading_strategy,
-            text_field=text_field,
-            stratify_by=stratify_by,
-            stratify_seed=stratify_seed,
-            drop_na_columns=drop_na_columns,
+            streaming=use_streaming,
             **kwargs,
         )
-
+
+        # Apply postprocessing if not streaming
+        if not use_streaming and (stratify_by or drop_na):
+            drop_na_columns = [text_field] if drop_na else None
+            ds = cls._postprocess_non_streaming_dataset(
+                ds,
+                stratify_by=stratify_by,
+                stratify_seed=stratify_seed,
+                drop_na_columns=drop_na_columns,
+            )
+
         return cls(
-
+            ds,
             store=store,
             loading_strategy=loading_strategy,
             text_field=text_field,
mi_crow/hooks/controller.py
CHANGED
@@ -7,7 +7,11 @@ import torch
 import torch.nn as nn
 
 from mi_crow.hooks.hook import Hook, HookType, HOOK_FUNCTION_INPUT, HOOK_FUNCTION_OUTPUT
-from mi_crow.hooks.utils import
+from mi_crow.hooks.utils import (
+    extract_tensor_from_input,
+    extract_tensor_from_output,
+    apply_modification_to_output
+)
 from mi_crow.utils import get_logger
 
 if TYPE_CHECKING:
@@ -86,13 +90,14 @@ class Controller(Hook):
         if output_tensor is None:
             return
 
-        # Extract input tensor if available for modify_activations
         input_tensor = extract_tensor_from_input(input)
-
-
-
-
-
+        modified_tensor = self.modify_activations(module, input_tensor, output_tensor)
+
+        if modified_tensor is not None and isinstance(modified_tensor, torch.Tensor):
+            target_device = None
+            if self.context is not None and hasattr(self.context, 'device') and self.context.device:
+                target_device = torch.device(self.context.device)
+            apply_modification_to_output(output, modified_tensor, target_device=target_device)
 
     def _hook_fn(
         self,
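The restored hook body makes the control flow explicit: extract the input tensor, call `modify_activations(module, input_tensor, output_tensor)`, and, if a tensor is returned, write it back into the module output via `apply_modification_to_output`, moving it to `context.device` when one is set. A minimal subclass sketch; that subclasses customize behavior by overriding `modify_activations` is inferred from this call site, not stated in the diff:

    from typing import Optional

    import torch
    import torch.nn as nn

    from mi_crow.hooks.controller import Controller

    class ScalingController(Controller):
        """Hypothetical controller that damps a layer's activations."""

        def modify_activations(
            self,
            module: nn.Module,
            input_tensor: Optional[torch.Tensor],
            output_tensor: torch.Tensor,
        ) -> torch.Tensor:
            # Returning a tensor causes the hook to replace the module output
            # (apply_modification_to_output moves it to context.device if set).
            return output_tensor * 0.5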