cotlab 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotlab/__init__.py +3 -0
- cotlab/analyse_experiments.py +392 -0
- cotlab/analysis/__init__.py +11 -0
- cotlab/analysis/cot_parser.py +243 -0
- cotlab/analysis/faithfulness_metrics.py +192 -0
- cotlab/backends/__init__.py +16 -0
- cotlab/backends/base.py +78 -0
- cotlab/backends/transformers_backend.py +335 -0
- cotlab/backends/vllm_backend.py +227 -0
- cotlab/cli.py +83 -0
- cotlab/core/__init__.py +34 -0
- cotlab/core/base.py +749 -0
- cotlab/core/config.py +90 -0
- cotlab/core/registry.py +68 -0
- cotlab/datasets/__init__.py +45 -0
- cotlab/datasets/loaders.py +1889 -0
- cotlab/experiment/__init__.py +315 -0
- cotlab/experiments/__init__.py +43 -0
- cotlab/experiments/activation_compare.py +290 -0
- cotlab/experiments/activation_patching.py +1050 -0
- cotlab/experiments/attention_analysis.py +885 -0
- cotlab/experiments/classification.py +235 -0
- cotlab/experiments/composite_shift_detector.py +524 -0
- cotlab/experiments/cot_ablation.py +277 -0
- cotlab/experiments/cot_faithfulness.py +187 -0
- cotlab/experiments/cot_heads.py +208 -0
- cotlab/experiments/full_layer_cot.py +232 -0
- cotlab/experiments/full_layer_patching.py +225 -0
- cotlab/experiments/h_neuron_analysis.py +712 -0
- cotlab/experiments/logit_lens.py +439 -0
- cotlab/experiments/multi_head_cot.py +220 -0
- cotlab/experiments/multi_head_patching.py +229 -0
- cotlab/experiments/probing_classifier.py +402 -0
- cotlab/experiments/residual_norm_ood.py +413 -0
- cotlab/experiments/sae_feature_analysis.py +673 -0
- cotlab/experiments/steering_vectors.py +223 -0
- cotlab/experiments/sycophancy_heads.py +224 -0
- cotlab/logging/__init__.py +5 -0
- cotlab/logging/json_logger.py +161 -0
- cotlab/main.py +317 -0
- cotlab/patching/__init__.py +24 -0
- cotlab/patching/cache.py +141 -0
- cotlab/patching/hooks.py +558 -0
- cotlab/patching/interventions.py +86 -0
- cotlab/patching/patcher.py +439 -0
- cotlab/patching/sae.py +181 -0
- cotlab/prompts/__init__.py +43 -0
- cotlab/prompts/cardiology.py +378 -0
- cotlab/prompts/histopathology.py +265 -0
- cotlab/prompts/length_matched_strategies.py +157 -0
- cotlab/prompts/mcq.py +193 -0
- cotlab/prompts/neurology.py +353 -0
- cotlab/prompts/oncology.py +367 -0
- cotlab/prompts/plab.py +162 -0
- cotlab/prompts/pubhealthbench.py +82 -0
- cotlab/prompts/pubmedqa.py +173 -0
- cotlab/prompts/radiology.py +414 -0
- cotlab/prompts/strategies.py +939 -0
- cotlab/prompts/tcga.py +168 -0
- cotlab/runner.py +204 -0
- cotlab-0.8.0.dist-info/METADATA +166 -0
- cotlab-0.8.0.dist-info/RECORD +65 -0
- cotlab-0.8.0.dist-info/WHEEL +4 -0
- cotlab-0.8.0.dist-info/entry_points.txt +3 -0
- cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
cotlab/core/config.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Hydra-compatible configuration dataclasses."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import List, Optional
|
|
5
|
+
|
|
6
|
+
from omegaconf import MISSING
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class BackendConfig:
    """Configuration for inference backend.

    Base structured config for the ``backend`` section. Only fields common to
    every backend live here; ``TransformersBackendConfig`` and
    ``VLLMBackendConfig`` subclass it and give ``_target_`` a concrete default.
    """

    # Dotted import path of the backend class to build (the subclasses use
    # e.g. "cotlab.backends.TransformersBackend"). omegaconf MISSING means
    # there is no default: it must come from a subclass or the user's config.
    _target_: str = MISSING
    # Device string for the backend. NOTE(review): presumably a torch device
    # name such as "cuda"/"cpu" — confirm against the backend implementations.
    device: str = "cuda"
    # Weight/compute dtype, given by name (a str) rather than a dtype object.
    dtype: str = "bfloat16"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class TransformersBackendConfig(BackendConfig):
    """Transformers-specific backend config.

    Extends ``BackendConfig`` (device, dtype) and points ``_target_`` at
    ``cotlab.backends.TransformersBackend``.
    """

    # Concrete backend class; overrides the MISSING default from the base.
    _target_: str = "cotlab.backends.TransformersBackend"
    # Toggle for hooks in the backend. NOTE(review): presumably the forward
    # hooks used by the patching/activation experiments — confirm in
    # TransformersBackend.
    enable_hooks: bool = True
    # NOTE(review): presumably forwarded to model/tokenizer loading. If so,
    # defaulting to True lets code shipped with a model repository execute
    # locally — confirm this default is intended.
    trust_remote_code: bool = True
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class VLLMBackendConfig(BackendConfig):
    """vLLM-specific backend config.

    Extends ``BackendConfig`` (device, dtype) and points ``_target_`` at
    ``cotlab.backends.VLLMBackend``.
    """

    # Concrete backend class; overrides the MISSING default from the base.
    _target_: str = "cotlab.backends.VLLMBackend"
    # Number of devices to shard the model across; 1 = no tensor parallelism.
    # NOTE(review): presumably passed straight to the vLLM engine — confirm.
    tensor_parallel_size: int = 1
    # Maximum context length in tokens. NOTE(review): presumably forwarded to
    # the vLLM engine — confirm in VLLMBackend.
    max_model_len: int = 4096
    # NOTE(review): presumably forwarded to model loading. If so, defaulting
    # to True allows code from the model repository to execute — confirm.
    trust_remote_code: bool = True
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class ModelConfig:
    """Configuration for model loading.

    Names the model to load and, judging by the field names, the default
    text-generation settings used with it.
    """

    # Model identifier; omegaconf MISSING, so it must be supplied.
    # NOTE(review): presumably a hub id or local path — confirm in the backends.
    name: str = MISSING
    # Model variant tag. NOTE(review): "4b" looks like a parameter-count
    # label — confirm how the backends consume it.
    variant: str = "4b"
    # Generation settings. NOTE(review): presumably forwarded to the
    # backend's generate call — confirm.
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.9
    do_sample: bool = True
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class PromptConfig:
    """Configuration for prompt strategy.

    Selects the prompt-strategy class through ``_target_`` and gives it a name.
    """

    # Dotted import path of the prompt-strategy class; mandatory (MISSING).
    _target_: str = MISSING
    # Strategy name; mandatory (MISSING).
    name: str = MISSING
    # Optional system-role text; defaults to None (unset).
    # NOTE(review): how None is treated is up to the strategy — confirm.
    system_role: Optional[str] = None
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass
class DatasetConfig:
    """Configuration for dataset loading.

    All three fields default to omegaconf MISSING, so every one of them must
    be provided by the composed config.
    """

    # Dotted import path of the dataset class to build.
    _target_: str = MISSING
    # Dataset name.
    name: str = MISSING
    # Location of the dataset. NOTE(review): presumably a filesystem path —
    # confirm for datasets that may load from elsewhere.
    path: str = MISSING
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@dataclass
class ExperimentConfig:
    """Configuration for an experiment.

    Selects the experiment class through ``_target_`` and carries its
    run-size limit and the names of the tests/metrics to run.
    """

    # Dotted import path of the experiment class; mandatory (MISSING).
    _target_: str = MISSING
    # Experiment name; mandatory (MISSING).
    name: str = MISSING
    # Free-form description; empty by default.
    description: str = ""
    num_samples: Optional[int] = None  # None = use all available samples
    # default_factory=list gives each instance its own list instead of one
    # shared mutable default.
    tests: List[str] = field(default_factory=list)
    metrics: List[str] = field(default_factory=list)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@dataclass
class Config:
    """Root configuration.

    Composes one sub-config per component plus a few run-level switches.
    Each sub-config default is built fresh through ``default_factory``, so
    instances never share nested config objects. Several of those defaults
    leave mandatory fields as MISSING (e.g. ``backend._target_``,
    ``model.name``, ``dataset.path``), so the composed config must fill them.
    """

    # Base BackendConfig by default; a concrete backend config such as
    # TransformersBackendConfig or VLLMBackendConfig is expected in practice.
    backend: BackendConfig = field(default_factory=BackendConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    prompt: PromptConfig = field(default_factory=PromptConfig)
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    experiment: ExperimentConfig = field(default_factory=ExperimentConfig)
    # Random seed. NOTE(review): presumably used to seed RNGs — confirm.
    seed: int = 42
    # Verbose-output toggle.
    verbose: bool = True
    # NOTE(review): by its name, a dry run presumably skips the expensive
    # work — confirm in the runner/main entry point.
    dry_run: bool = False
|
cotlab/core/registry.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Component registry for dynamic instantiation."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, Type
|
|
4
|
+
|
|
5
|
+
from hydra.utils import instantiate
|
|
6
|
+
from omegaconf import DictConfig
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Registry:
    """Name-to-class lookup tables for config-driven components.

    Four class-level tables map a string name to a class: backends, prompt
    strategies, experiments and datasets. Each ``register_*`` classmethod
    returns a class decorator that files the decorated class in the matching
    table and hands the class back unchanged.

    The tables are plain dicts stored on the class object, so registrations
    are process-wide. Registering an existing name again silently replaces
    the earlier entry.
    """

    _backends: Dict[str, Type] = {}
    _prompts: Dict[str, Type] = {}
    _experiments: Dict[str, Type] = {}
    _datasets: Dict[str, Type] = {}

    @classmethod
    def _recorder(cls, table_attr: str, name: str):
        """Build a decorator that stores its argument under ``name``.

        ``table_attr`` names one of the class-level tables. The table is
        looked up on ``cls`` when the decorator is applied, not when this
        factory is called.
        """

        def record(component):
            getattr(cls, table_attr)[name] = component
            return component

        return record

    @classmethod
    def register_backend(cls, name: str):
        """Return a decorator that registers a backend class as ``name``."""
        return cls._recorder("_backends", name)

    @classmethod
    def register_prompt(cls, name: str):
        """Return a decorator that registers a prompt strategy as ``name``."""
        return cls._recorder("_prompts", name)

    @classmethod
    def register_experiment(cls, name: str):
        """Return a decorator that registers an experiment class as ``name``."""
        return cls._recorder("_experiments", name)

    @classmethod
    def register_dataset(cls, name: str):
        """Return a decorator that registers a dataset class as ``name``."""
        return cls._recorder("_datasets", name)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def create_component(cfg: DictConfig) -> Any:
    """Build the object described by a Hydra config node.

    The node's ``_target_`` entry identifies the class to construct. All of
    the work is delegated to ``hydra.utils.instantiate``; nothing is added
    or checked here.

    Args:
        cfg: Config node whose ``_target_`` names the class to build.

    Returns:
        Whatever ``hydra.utils.instantiate`` produces for ``cfg``.
    """
    component = instantiate(cfg)
    return component
|
|
cotlab/datasets/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Datasets module."""
|
|
2
|
+
|
|
3
|
+
from .loaders import (
|
|
4
|
+
BaseDataset,
|
|
5
|
+
CardiologyDataset,
|
|
6
|
+
HistopathologyDataset,
|
|
7
|
+
MARCDataset,
|
|
8
|
+
MedBulletsDataset,
|
|
9
|
+
MedQADataset,
|
|
10
|
+
MMLUMedicalDataset,
|
|
11
|
+
NeurologyDataset,
|
|
12
|
+
OncologyDataset,
|
|
13
|
+
PatchingPairsDataset,
|
|
14
|
+
PediatricsDataset,
|
|
15
|
+
PLABDataset,
|
|
16
|
+
ProbingDiagnosisDataset,
|
|
17
|
+
PubHealthBenchDataset,
|
|
18
|
+
PubMedQADataset,
|
|
19
|
+
RadiologyDataset,
|
|
20
|
+
Sample,
|
|
21
|
+
SyntheticMedicalDataset,
|
|
22
|
+
TCGADataset,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Public API of the datasets package: the Sample and BaseDataset types plus
# the concrete dataset classes imported from ``.loaders`` above. Keep this
# list in sync with that import when adding a dataset.
__all__ = [
    "Sample",
    "BaseDataset",
    "CardiologyDataset",
    "HistopathologyDataset",
    "MARCDataset",
    "MedBulletsDataset",
    "MedQADataset",
    "MMLUMedicalDataset",
    "NeurologyDataset",
    "OncologyDataset",
    "RadiologyDataset",
    "PediatricsDataset",
    "PLABDataset",
    "PubMedQADataset",
    "PubHealthBenchDataset",
    "SyntheticMedicalDataset",
    "PatchingPairsDataset",
    "ProbingDiagnosisDataset",
    "TCGADataset",
]
|