cotlab 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. cotlab/__init__.py +3 -0
  2. cotlab/analyse_experiments.py +392 -0
  3. cotlab/analysis/__init__.py +11 -0
  4. cotlab/analysis/cot_parser.py +243 -0
  5. cotlab/analysis/faithfulness_metrics.py +192 -0
  6. cotlab/backends/__init__.py +16 -0
  7. cotlab/backends/base.py +78 -0
  8. cotlab/backends/transformers_backend.py +335 -0
  9. cotlab/backends/vllm_backend.py +227 -0
  10. cotlab/cli.py +83 -0
  11. cotlab/core/__init__.py +34 -0
  12. cotlab/core/base.py +749 -0
  13. cotlab/core/config.py +90 -0
  14. cotlab/core/registry.py +68 -0
  15. cotlab/datasets/__init__.py +45 -0
  16. cotlab/datasets/loaders.py +1889 -0
  17. cotlab/experiment/__init__.py +315 -0
  18. cotlab/experiments/__init__.py +43 -0
  19. cotlab/experiments/activation_compare.py +290 -0
  20. cotlab/experiments/activation_patching.py +1050 -0
  21. cotlab/experiments/attention_analysis.py +885 -0
  22. cotlab/experiments/classification.py +235 -0
  23. cotlab/experiments/composite_shift_detector.py +524 -0
  24. cotlab/experiments/cot_ablation.py +277 -0
  25. cotlab/experiments/cot_faithfulness.py +187 -0
  26. cotlab/experiments/cot_heads.py +208 -0
  27. cotlab/experiments/full_layer_cot.py +232 -0
  28. cotlab/experiments/full_layer_patching.py +225 -0
  29. cotlab/experiments/h_neuron_analysis.py +712 -0
  30. cotlab/experiments/logit_lens.py +439 -0
  31. cotlab/experiments/multi_head_cot.py +220 -0
  32. cotlab/experiments/multi_head_patching.py +229 -0
  33. cotlab/experiments/probing_classifier.py +402 -0
  34. cotlab/experiments/residual_norm_ood.py +413 -0
  35. cotlab/experiments/sae_feature_analysis.py +673 -0
  36. cotlab/experiments/steering_vectors.py +223 -0
  37. cotlab/experiments/sycophancy_heads.py +224 -0
  38. cotlab/logging/__init__.py +5 -0
  39. cotlab/logging/json_logger.py +161 -0
  40. cotlab/main.py +317 -0
  41. cotlab/patching/__init__.py +24 -0
  42. cotlab/patching/cache.py +141 -0
  43. cotlab/patching/hooks.py +558 -0
  44. cotlab/patching/interventions.py +86 -0
  45. cotlab/patching/patcher.py +439 -0
  46. cotlab/patching/sae.py +181 -0
  47. cotlab/prompts/__init__.py +43 -0
  48. cotlab/prompts/cardiology.py +378 -0
  49. cotlab/prompts/histopathology.py +265 -0
  50. cotlab/prompts/length_matched_strategies.py +157 -0
  51. cotlab/prompts/mcq.py +193 -0
  52. cotlab/prompts/neurology.py +353 -0
  53. cotlab/prompts/oncology.py +367 -0
  54. cotlab/prompts/plab.py +162 -0
  55. cotlab/prompts/pubhealthbench.py +82 -0
  56. cotlab/prompts/pubmedqa.py +173 -0
  57. cotlab/prompts/radiology.py +414 -0
  58. cotlab/prompts/strategies.py +939 -0
  59. cotlab/prompts/tcga.py +168 -0
  60. cotlab/runner.py +204 -0
  61. cotlab-0.8.0.dist-info/METADATA +166 -0
  62. cotlab-0.8.0.dist-info/RECORD +65 -0
  63. cotlab-0.8.0.dist-info/WHEEL +4 -0
  64. cotlab-0.8.0.dist-info/entry_points.txt +3 -0
  65. cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,227 @@
1
+ """vLLM backend for high-throughput inference.
2
+
3
+ Supports multiple platforms:
4
+ - CUDA (NVIDIA GPUs) - standard vLLM
5
+ - ROCm (AMD GPUs) - via ROCm Docker or HIP
6
+ - Metal (Apple Silicon) - via vllm-metal plugin
7
+ """
8
+
9
+ import os
10
+ import platform as plat
11
+ from typing import List, Optional
12
+
13
+ import torch
14
+
15
+ from ..core.base import GenerationOutput
16
+ from ..core.registry import Registry
17
+ from .base import InferenceBackend
18
+
19
+ # Fix CUDA forking issue: vLLM's V1 engine uses multiprocessing, but when
20
+ # Hydra/PyTorch have already initialized CUDA, forking fails. Setting spawn
21
+ # method creates fresh Python processes instead.
22
+ os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
23
+
24
+
25
def _is_apple_silicon() -> bool:
    """Detect if running on Apple Silicon.

    Returns:
        True when the OS reports ``Darwin`` and either the machine type is
        ``arm64`` or the processor name is ``arm``; False otherwise.
    """
    if plat.system() != "Darwin":
        return False
    # platform.processor() may return an empty string when the value cannot
    # be determined, so the machine type is checked as well.
    return plat.machine() == "arm64" or plat.processor() == "arm"
28
+
29
+
30
def _detect_platform() -> str:
    """Pick the platform identifier for the current machine.

    Apple Silicon is checked first; otherwise the answer depends on whether
    torch reports an available CUDA device.

    Returns:
        "metal" for Apple Silicon
        "cuda" for NVIDIA/AMD GPUs (ROCm uses CUDA-compatible API)
        "cpu" if no GPU available
    """
    if _is_apple_silicon():
        return "metal"
    return "cuda" if torch.cuda.is_available() else "cpu"
43
+
44
+
45
@Registry.register_backend("vllm")
class VLLMBackend(InferenceBackend):
    """
    High-throughput inference backend using vLLM.

    Best for:
    - Large-scale experiments (1000+ samples)
    - Batch inference
    - When activation access is not needed

    Platforms:
    - CUDA: Standard vLLM (pip install vllm)
    - ROCm: vLLM in ROCm Docker container
    - Metal: vLLM-Metal plugin (install via curl script)

    Note:
        Does NOT support activation extraction or patching.
    """

    def __init__(
        self,
        tensor_parallel_size: int = 1,
        dtype: str = "bfloat16",
        max_model_len: int | None = None,
        trust_remote_code: bool = True,
        quantization: str | None = None,
        gpu_memory_utilization: float = 0.9,
        enforce_eager: bool = False,
        limit_mm_per_prompt: dict | str | None = None,
        platform: str = "auto",
        **kwargs,
    ):
        """Store engine settings and configure the platform environment.

        No model is loaded here; call ``load_model()`` afterwards.

        Args:
            tensor_parallel_size: Forwarded to ``vllm.LLM`` on load.
            dtype: Forwarded to ``vllm.LLM`` on load.
            max_model_len: Forwarded to ``vllm.LLM`` only when not None.
            trust_remote_code: Forwarded to ``vllm.LLM`` on load.
            quantization: Forwarded to ``vllm.LLM`` only when not None.
            gpu_memory_utilization: Forwarded to ``vllm.LLM`` on every
                platform except "metal", where it instead seeds the default
                of the ``VLLM_METAL_MEMORY_FRACTION`` environment variable.
            enforce_eager: Forwarded to ``vllm.LLM`` on load.
            limit_mm_per_prompt: Forwarded to ``vllm.LLM`` only when not None.
            platform: "auto" to detect the platform, or an explicit platform
                name that is used as given.
            **kwargs: Accepted and ignored.
        """
        self.tensor_parallel_size = tensor_parallel_size
        self.dtype = dtype
        self.max_model_len = max_model_len
        self.trust_remote_code = trust_remote_code
        self.quantization = quantization
        self.gpu_memory_utilization = gpu_memory_utilization
        self.enforce_eager = enforce_eager
        self.limit_mm_per_prompt = limit_mm_per_prompt
        self._model = None
        self._model_name: Optional[str] = None

        # Platform detection
        self._platform = _detect_platform() if platform == "auto" else platform
        self._setup_platform_env()

    def _setup_platform_env(self) -> None:
        """Configure environment variables for the detected platform.

        Uses ``setdefault`` so values already present in the environment are
        left untouched. Prints the selected platform.
        """
        if self._platform == "metal":
            # vLLM-Metal configuration
            os.environ.setdefault("VLLM_METAL_MEMORY_FRACTION", str(self.gpu_memory_utilization))
            os.environ.setdefault("VLLM_METAL_USE_MLX", "1")
            os.environ.setdefault("VLLM_METAL_BLOCK_SIZE", "16")
            print("  Platform: Apple Silicon (Metal/MLX)")
        elif self._platform == "cuda":
            print("  Platform: CUDA")
        else:
            print(f"  Platform: {self._platform}")

    def load_model(self, model_name: str, **kwargs) -> None:
        """Load model with vLLM.

        Args:
            model_name: Model identifier passed to ``vllm.LLM`` as ``model``.
            **kwargs: Extra keyword arguments forwarded to ``vllm.LLM``. They
                override the base settings (model, tensor_parallel_size,
                dtype, trust_remote_code, enforce_eager) but are themselves
                overridden by the instance's gpu_memory_utilization,
                max_model_len, quantization and limit_mm_per_prompt when
                those are applied.

        Raises:
            ImportError: If vLLM cannot be imported; the original import
                error is chained as the cause.
        """
        import logging

        try:
            from vllm import LLM
        except ImportError as err:
            if self._platform == "metal":
                raise ImportError(
                    "vLLM not found. On Apple Silicon, install vllm-metal:\n"
                    "curl -fsSL https://raw.githubusercontent.com/vllm-project/vllm-metal/main/install.sh | bash"
                ) from err
            raise ImportError("vLLM not installed. Run: pip install vllm") from err

        # Build LLM kwargs
        llm_kwargs = {
            "model": model_name,
            "tensor_parallel_size": self.tensor_parallel_size,
            "dtype": self.dtype,
            "trust_remote_code": self.trust_remote_code,
            "enforce_eager": self.enforce_eager,
            **kwargs,
        }

        # gpu_memory_utilization only applies to CUDA/ROCm
        if self._platform != "metal":
            llm_kwargs["gpu_memory_utilization"] = self.gpu_memory_utilization

        # Optional settings are only forwarded when explicitly configured.
        if self.max_model_len is not None:
            llm_kwargs["max_model_len"] = self.max_model_len

        if self.quantization is not None:
            llm_kwargs["quantization"] = self.quantization

        if self.limit_mm_per_prompt is not None:
            llm_kwargs["limit_mm_per_prompt"] = self.limit_mm_per_prompt

        # Diagnostic output goes through logging instead of stdout.
        logging.getLogger(__name__).debug("vLLM args: %s", llm_kwargs)
        self._model = LLM(**llm_kwargs)
        self._model_name = model_name

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        system_prompt: Optional[str] = None,
        **kwargs,
    ) -> GenerationOutput:
        """Generate from a single prompt.

        Delegates to ``generate_batch`` with a one-element list and returns
        its only result.

        Args:
            prompt: Prompt text.
            max_new_tokens: Passed to ``SamplingParams`` as ``max_tokens``.
            temperature: Passed to ``SamplingParams``.
            top_p: Passed to ``SamplingParams``.
            system_prompt: Optional text prepended to the prompt.
            **kwargs: Extra keyword arguments for ``SamplingParams``.

        Raises:
            RuntimeError: If no model has been loaded.
        """
        outputs = self.generate_batch(
            [prompt],
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            system_prompt=system_prompt,
            **kwargs,
        )
        return outputs[0]

    def generate_batch(
        self,
        prompts: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        system_prompt: Optional[str] = None,
        **kwargs,
    ) -> List[GenerationOutput]:
        """Generate from multiple prompts efficiently.

        Args:
            prompts: Prompt texts; an empty list yields an empty result.
            max_new_tokens: Passed to ``SamplingParams`` as ``max_tokens``.
            temperature: Passed to ``SamplingParams``.
            top_p: Passed to ``SamplingParams``.
            system_prompt: Optional text prepended to every prompt.
            **kwargs: Extra keyword arguments for ``SamplingParams``.

        Returns:
            One ``GenerationOutput`` per result returned by the engine, built
            from the first completion of each result. ``logprobs`` is always
            None.

        Raises:
            RuntimeError: If no model has been loaded.
        """
        # Check state before importing vllm so a missing model is reported
        # as such, regardless of whether vllm is importable.
        if self._model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        if not prompts:
            return []

        from vllm import SamplingParams

        prompts = self._apply_system_prompt(prompts, system_prompt)

        sampling_params = SamplingParams(
            max_tokens=max_new_tokens, temperature=temperature, top_p=top_p, **kwargs
        )

        outputs = self._model.generate(prompts, sampling_params)

        return [
            GenerationOutput(
                text=output.outputs[0].text,
                tokens=list(output.outputs[0].token_ids),
                logprobs=None,  # Can be enabled in SamplingParams if needed
            )
            for output in outputs
        ]

    @staticmethod
    def _apply_system_prompt(prompts: List[str], system_prompt: Optional[str]) -> List[str]:
        """Prefix each prompt with the stripped system prompt and a blank line.

        Returns ``prompts`` unchanged when ``system_prompt`` is None, empty,
        or whitespace only.
        """
        cleaned = system_prompt.strip() if system_prompt else ""
        if not cleaned:
            return prompts
        return [f"{cleaned}\n\n{prompt}" for prompt in prompts]

    @property
    def platform(self) -> str:
        """Return the detected platform (cuda, metal, cpu)."""
        return self._platform

    @property
    def supports_activations(self) -> bool:
        """vLLM optimizes away intermediate activations."""
        return False

    @property
    def model_name(self) -> Optional[str]:
        """Name passed to the most recent successful ``load_model()`` call, if any."""
        return self._model_name

    def unload(self) -> None:
        """Free GPU memory.

        Drops the reference to the loaded engine, runs a garbage collection
        pass, and then asks torch to release cached CUDA memory on non-Metal
        platforms. Safe to call when no model is loaded.
        """
        import gc

        self._model = None
        # Collect first so memory held by the dropped engine can be released
        # before the CUDA cache is emptied.
        gc.collect()

        # Platform-specific cleanup
        if self._platform != "metal" and torch.cuda.is_available():
            torch.cuda.empty_cache()
cotlab/cli.py ADDED
@@ -0,0 +1,83 @@
1
+ """CLI utilities for CoTLab."""
2
+
3
+ import shutil
4
+ from pathlib import Path
5
+
6
+ import click
7
+
8
+
9
# Root command group; subcommands attach themselves via ``@cli.command()``.
# The docstring is what click shows as the group's help text.
@click.group()
def cli():
    """CoTLab command-line utilities."""
13
+
14
+
15
@cli.command()
@click.argument("model_name")
@click.option(
    "--backend",
    default="vllm",
    type=click.Choice(["vllm", "transformers"]),
    help="Backend: vllm or transformers",
)
@click.option("--output", "-o", help="Output path (default: conf/model/<safe_name>.yaml)")
def template(model_name: str, backend: str, output: str | None):
    """Generate model config from template.
    Use this when you want to pre-create a model config file before the first run.
    This is optional: CoTLab can also auto-generate a model config at runtime when
    you run with `model=org/repo-id`.

    Examples:

    cotlab-template meta-llama/Llama-3.1-8B

    cotlab-template google/gemma-3-12b --backend transformers

    cotlab-template mistralai/Mistral-7B-v0.1 -o conf/model/mistral7b.yaml
    """
    # NOTE: the docstring above is rendered by click as the command's help
    # text, so implementation notes are kept in comments.
    #
    # Steps: locate the backend's default template, choose the output path,
    # confirm before overwriting, copy the template, then substitute the
    # "huggingface/model-name" placeholder with MODEL_NAME.

    # Primary template location: three directory levels above this file.
    cli_file = Path(__file__)
    project_root = cli_file.parent.parent.parent
    relative_template = f"conf/model/_base/{backend}_default.yaml"
    template_path = project_root / relative_template

    if not template_path.is_file():
        # The default output path is relative to the working directory, so
        # look for the template there as well before giving up.
        cwd_template = Path.cwd() / relative_template
        if cwd_template.is_file():
            template_path = cwd_template
        else:
            click.echo(f"Template not found: {template_path}", err=True)
            raise click.Abort()

    # Generate output path
    if not output:
        # Convert model name to safe filename
        # meta-llama/Llama-3.1-8B -> meta_llama_llama_3_1_8b
        safe_name = model_name.replace("/", "_").replace("-", "_").replace(".", "_").lower()
        output = f"conf/model/{safe_name}.yaml"

    output_path = Path(output)

    # Ask before clobbering an existing file. abort=True raises click.Abort
    # on a negative answer, and click reports the abort itself.
    if output_path.exists():
        click.confirm(f"{output} already exists. Overwrite?", abort=True)

    # Ensure parent directory exists
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Copy template (shutil.copy also carries over the permission bits)
    shutil.copy(template_path, output_path)

    # Replace placeholder; the encoding is pinned so the result does not
    # depend on the locale's default encoding.
    content = output_path.read_text(encoding="utf-8")
    content = content.replace("huggingface/model-name", model_name)
    output_path.write_text(content, encoding="utf-8")

    # Success message
    click.echo(f"Created: {output}")
    click.echo(f"Edit {output} to customize parameters")
    click.echo("\nUsage:")
    click.echo(f"  python -m cotlab.main model={output_path.stem}")
80
+
81
+
82
+ if __name__ == "__main__":
83
+ cli()
@@ -0,0 +1,34 @@
1
+ """Core module - base classes and configuration."""
2
+
3
+ from .base import (
4
+ BaseExperiment,
5
+ BasePromptStrategy,
6
+ ExperimentResult,
7
+ GenerationOutput,
8
+ StructuredOutputMixin,
9
+ )
10
+ from .config import (
11
+ BackendConfig,
12
+ Config,
13
+ DatasetConfig,
14
+ ExperimentConfig,
15
+ ModelConfig,
16
+ PromptConfig,
17
+ )
18
+ from .registry import Registry, create_component
19
+
20
+ __all__ = [
21
+ "GenerationOutput",
22
+ "ExperimentResult",
23
+ "BasePromptStrategy",
24
+ "BaseExperiment",
25
+ "StructuredOutputMixin",
26
+ "BackendConfig",
27
+ "ModelConfig",
28
+ "PromptConfig",
29
+ "DatasetConfig",
30
+ "ExperimentConfig",
31
+ "Config",
32
+ "Registry",
33
+ "create_component",
34
+ ]