cotlab 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. cotlab/__init__.py +3 -0
  2. cotlab/analyse_experiments.py +392 -0
  3. cotlab/analysis/__init__.py +11 -0
  4. cotlab/analysis/cot_parser.py +243 -0
  5. cotlab/analysis/faithfulness_metrics.py +192 -0
  6. cotlab/backends/__init__.py +16 -0
  7. cotlab/backends/base.py +78 -0
  8. cotlab/backends/transformers_backend.py +335 -0
  9. cotlab/backends/vllm_backend.py +227 -0
  10. cotlab/cli.py +83 -0
  11. cotlab/core/__init__.py +34 -0
  12. cotlab/core/base.py +749 -0
  13. cotlab/core/config.py +90 -0
  14. cotlab/core/registry.py +68 -0
  15. cotlab/datasets/__init__.py +45 -0
  16. cotlab/datasets/loaders.py +1889 -0
  17. cotlab/experiment/__init__.py +315 -0
  18. cotlab/experiments/__init__.py +43 -0
  19. cotlab/experiments/activation_compare.py +290 -0
  20. cotlab/experiments/activation_patching.py +1050 -0
  21. cotlab/experiments/attention_analysis.py +885 -0
  22. cotlab/experiments/classification.py +235 -0
  23. cotlab/experiments/composite_shift_detector.py +524 -0
  24. cotlab/experiments/cot_ablation.py +277 -0
  25. cotlab/experiments/cot_faithfulness.py +187 -0
  26. cotlab/experiments/cot_heads.py +208 -0
  27. cotlab/experiments/full_layer_cot.py +232 -0
  28. cotlab/experiments/full_layer_patching.py +225 -0
  29. cotlab/experiments/h_neuron_analysis.py +712 -0
  30. cotlab/experiments/logit_lens.py +439 -0
  31. cotlab/experiments/multi_head_cot.py +220 -0
  32. cotlab/experiments/multi_head_patching.py +229 -0
  33. cotlab/experiments/probing_classifier.py +402 -0
  34. cotlab/experiments/residual_norm_ood.py +413 -0
  35. cotlab/experiments/sae_feature_analysis.py +673 -0
  36. cotlab/experiments/steering_vectors.py +223 -0
  37. cotlab/experiments/sycophancy_heads.py +224 -0
  38. cotlab/logging/__init__.py +5 -0
  39. cotlab/logging/json_logger.py +161 -0
  40. cotlab/main.py +317 -0
  41. cotlab/patching/__init__.py +24 -0
  42. cotlab/patching/cache.py +141 -0
  43. cotlab/patching/hooks.py +558 -0
  44. cotlab/patching/interventions.py +86 -0
  45. cotlab/patching/patcher.py +439 -0
  46. cotlab/patching/sae.py +181 -0
  47. cotlab/prompts/__init__.py +43 -0
  48. cotlab/prompts/cardiology.py +378 -0
  49. cotlab/prompts/histopathology.py +265 -0
  50. cotlab/prompts/length_matched_strategies.py +157 -0
  51. cotlab/prompts/mcq.py +193 -0
  52. cotlab/prompts/neurology.py +353 -0
  53. cotlab/prompts/oncology.py +367 -0
  54. cotlab/prompts/plab.py +162 -0
  55. cotlab/prompts/pubhealthbench.py +82 -0
  56. cotlab/prompts/pubmedqa.py +173 -0
  57. cotlab/prompts/radiology.py +414 -0
  58. cotlab/prompts/strategies.py +939 -0
  59. cotlab/prompts/tcga.py +168 -0
  60. cotlab/runner.py +204 -0
  61. cotlab-0.8.0.dist-info/METADATA +166 -0
  62. cotlab-0.8.0.dist-info/RECORD +65 -0
  63. cotlab-0.8.0.dist-info/WHEEL +4 -0
  64. cotlab-0.8.0.dist-info/entry_points.txt +3 -0
  65. cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
cotlab/core/config.py ADDED
@@ -0,0 +1,90 @@
1
+ """Hydra-compatible configuration dataclasses."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import List, Optional
5
+
6
+ from omegaconf import MISSING
7
+
8
+
9
@dataclass
class BackendConfig:
    """Settings shared by every inference backend.

    Attributes:
        _target_: Import path of the backend class to build. It has no
            default (``MISSING``), so a concrete value must be supplied.
        device: Device string; defaults to ``"cuda"``.
        dtype: Dtype name as a string; defaults to ``"bfloat16"``.
    """

    # Mandatory: left as MISSING so the config is incomplete until it is set.
    _target_: str = MISSING
    device: str = "cuda"
    dtype: str = "bfloat16"
16
+
17
+
18
@dataclass
class TransformersBackendConfig(BackendConfig):
    """Backend settings for the Transformers backend.

    Inherits ``device`` and ``dtype`` from ``BackendConfig`` and pins
    ``_target_`` to the Transformers backend class.

    Attributes:
        _target_: Fixed to ``"cotlab.backends.TransformersBackend"``.
        enable_hooks: Hook toggle; on by default.
        trust_remote_code: Defaults to ``True``.
    """

    _target_: str = "cotlab.backends.TransformersBackend"
    enable_hooks: bool = True
    # NOTE(review): presumably forwarded to the model loader; if so, a True
    # default allows repository-supplied code to run -- confirm this is intended.
    trust_remote_code: bool = True
25
+
26
+
27
@dataclass
class VLLMBackendConfig(BackendConfig):
    """Backend settings for the vLLM backend.

    Inherits ``device`` and ``dtype`` from ``BackendConfig`` and pins
    ``_target_`` to the vLLM backend class.

    Attributes:
        _target_: Fixed to ``"cotlab.backends.VLLMBackend"``.
        tensor_parallel_size: Defaults to ``1``.
        max_model_len: Defaults to ``4096``.
        trust_remote_code: Defaults to ``True``.
    """

    _target_: str = "cotlab.backends.VLLMBackend"
    tensor_parallel_size: int = 1
    max_model_len: int = 4096
    # NOTE(review): presumably forwarded to the model loader; if so, a True
    # default allows repository-supplied code to run -- confirm this is intended.
    trust_remote_code: bool = True
35
+
36
+
37
@dataclass
class ModelConfig:
    """Which model to load and the generation settings that go with it.

    Attributes:
        name: Model name. It has no default (``MISSING``), so it must be
            supplied.
        variant: Variant label; defaults to ``"4b"``.
        max_new_tokens: Defaults to ``512``.
        temperature: Defaults to ``0.7``.
        top_p: Defaults to ``0.9``.
        do_sample: Defaults to ``True``.
    """

    # Mandatory: left as MISSING so the config is incomplete until it is set.
    name: str = MISSING
    variant: str = "4b"

    # Generation-related defaults (presumably passed to the backend's
    # generate call -- verify against the backend implementations).
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.9
    do_sample: bool = True
47
+
48
+
49
@dataclass
class PromptConfig:
    """Selects and parameterises the prompt strategy.

    Attributes:
        _target_: Import path of the prompt strategy class; mandatory.
        name: Name of the prompt strategy; mandatory.
        system_role: Optional system-role text; ``None`` when not set.
    """

    # Both identifiers are mandatory and have no usable default.
    _target_: str = MISSING
    name: str = MISSING
    system_role: Optional[str] = None
56
+
57
+
58
@dataclass
class DatasetConfig:
    """Selects the dataset to load.

    All three fields are mandatory (``MISSING``).

    Attributes:
        _target_: Import path of the dataset class.
        name: Name of the dataset.
        path: Dataset path, given as a string.
    """

    _target_: str = MISSING
    name: str = MISSING
    path: str = MISSING
65
+
66
+
67
@dataclass
class ExperimentConfig:
    """Describes a single experiment run.

    Attributes:
        _target_: Import path of the experiment class; mandatory.
        name: Name of the experiment; mandatory.
        description: Free-text description; empty by default.
        num_samples: How many samples to use. ``None`` means use every
            available sample.
        tests: List of test names; empty by default.
        metrics: List of metric names; empty by default.
    """

    _target_: str = MISSING
    name: str = MISSING
    description: str = ""
    # None = no cap: run on all available samples.
    num_samples: Optional[int] = None
    # default_factory gives each instance its own fresh list.
    tests: List[str] = field(default_factory=list)
    metrics: List[str] = field(default_factory=list)
77
+
78
+
79
@dataclass
class Config:
    """Top-level configuration that groups every section.

    Each section defaults to a freshly constructed instance of its own
    config class, so mandatory fields inside those sections stay ``MISSING``
    until they are filled in.

    Attributes:
        backend: Inference backend section.
        model: Model section.
        prompt: Prompt strategy section.
        dataset: Dataset section.
        experiment: Experiment section.
        seed: Integer seed; defaults to ``42``.
        verbose: Verbosity flag; defaults to ``True``.
        dry_run: Dry-run flag; defaults to ``False``.
    """

    # Nested sections: one new instance per Config via default_factory.
    backend: BackendConfig = field(default_factory=BackendConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    prompt: PromptConfig = field(default_factory=PromptConfig)
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    experiment: ExperimentConfig = field(default_factory=ExperimentConfig)

    # Run-wide scalar options.
    seed: int = 42
    verbose: bool = True
    dry_run: bool = False
cotlab/core/registry.py ADDED
@@ -0,0 +1,68 @@
1
+ """Component registry for dynamic instantiation."""
2
+
3
+ from typing import Any, Dict, Type
4
+
5
+ from hydra.utils import instantiate
6
+ from omegaconf import DictConfig
7
+
8
+
9
class Registry:
    """Name-to-class lookup tables for the pluggable component kinds.

    Four separate tables are kept as class attributes: backends, prompt
    strategies, experiments and datasets. Each ``register_*`` classmethod
    returns a class decorator that stores the decorated class under ``name``
    in the matching table and then returns the class unchanged. Registering
    a name that is already present replaces the earlier entry.
    """

    # The tables are class attributes, so all users of Registry share them.
    _backends: Dict[str, Type] = {}
    _prompts: Dict[str, Type] = {}
    _experiments: Dict[str, Type] = {}
    _datasets: Dict[str, Type] = {}

    @classmethod
    def _registrar(cls, table_attr: str, key: str):
        """Build a decorator that files a class under ``key`` in ``table_attr``."""

        def store(component):
            # Look the table up when the decorator is applied, not when it is
            # created, so the write always lands in the current table object.
            getattr(cls, table_attr)[key] = component
            return component

        return store

    @classmethod
    def register_backend(cls, name: str):
        """Return a class decorator that records a backend under ``name``."""
        return cls._registrar("_backends", name)

    @classmethod
    def register_prompt(cls, name: str):
        """Return a class decorator that records a prompt strategy under ``name``."""
        return cls._registrar("_prompts", name)

    @classmethod
    def register_experiment(cls, name: str):
        """Return a class decorator that records an experiment under ``name``."""
        return cls._registrar("_experiments", name)

    @classmethod
    def register_dataset(cls, name: str):
        """Return a class decorator that records a dataset under ``name``."""
        return cls._registrar("_datasets", name)
56
+
57
+
58
def create_component(cfg: DictConfig) -> Any:
    """Build the object described by a Hydra config.

    This is a thin wrapper that passes ``cfg`` straight to
    ``hydra.utils.instantiate``.

    Args:
        cfg: DictConfig whose ``_target_`` names the class to construct.

    Returns:
        The instantiated component.
    """
    component = instantiate(cfg)
    return component
cotlab/datasets/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ """Datasets module."""
2
+
3
+ from .loaders import (
4
+ BaseDataset,
5
+ CardiologyDataset,
6
+ HistopathologyDataset,
7
+ MARCDataset,
8
+ MedBulletsDataset,
9
+ MedQADataset,
10
+ MMLUMedicalDataset,
11
+ NeurologyDataset,
12
+ OncologyDataset,
13
+ PatchingPairsDataset,
14
+ PediatricsDataset,
15
+ PLABDataset,
16
+ ProbingDiagnosisDataset,
17
+ PubHealthBenchDataset,
18
+ PubMedQADataset,
19
+ RadiologyDataset,
20
+ Sample,
21
+ SyntheticMedicalDataset,
22
+ TCGADataset,
23
+ )
24
+
25
# Public API of the datasets package: every name here is re-exported from
# the .loaders module imported above.
__all__ = [
    "Sample",
    "BaseDataset",
    "CardiologyDataset",
    "HistopathologyDataset",
    "MARCDataset",
    "MedBulletsDataset",
    "MedQADataset",
    "MMLUMedicalDataset",
    "NeurologyDataset",
    "OncologyDataset",
    "RadiologyDataset",
    "PediatricsDataset",
    "PLABDataset",
    "PubMedQADataset",
    "PubHealthBenchDataset",
    "SyntheticMedicalDataset",
    "PatchingPairsDataset",
    "ProbingDiagnosisDataset",
    "TCGADataset",
]