cotlab 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotlab/__init__.py +3 -0
- cotlab/analyse_experiments.py +392 -0
- cotlab/analysis/__init__.py +11 -0
- cotlab/analysis/cot_parser.py +243 -0
- cotlab/analysis/faithfulness_metrics.py +192 -0
- cotlab/backends/__init__.py +16 -0
- cotlab/backends/base.py +78 -0
- cotlab/backends/transformers_backend.py +335 -0
- cotlab/backends/vllm_backend.py +227 -0
- cotlab/cli.py +83 -0
- cotlab/core/__init__.py +34 -0
- cotlab/core/base.py +749 -0
- cotlab/core/config.py +90 -0
- cotlab/core/registry.py +68 -0
- cotlab/datasets/__init__.py +45 -0
- cotlab/datasets/loaders.py +1889 -0
- cotlab/experiment/__init__.py +315 -0
- cotlab/experiments/__init__.py +43 -0
- cotlab/experiments/activation_compare.py +290 -0
- cotlab/experiments/activation_patching.py +1050 -0
- cotlab/experiments/attention_analysis.py +885 -0
- cotlab/experiments/classification.py +235 -0
- cotlab/experiments/composite_shift_detector.py +524 -0
- cotlab/experiments/cot_ablation.py +277 -0
- cotlab/experiments/cot_faithfulness.py +187 -0
- cotlab/experiments/cot_heads.py +208 -0
- cotlab/experiments/full_layer_cot.py +232 -0
- cotlab/experiments/full_layer_patching.py +225 -0
- cotlab/experiments/h_neuron_analysis.py +712 -0
- cotlab/experiments/logit_lens.py +439 -0
- cotlab/experiments/multi_head_cot.py +220 -0
- cotlab/experiments/multi_head_patching.py +229 -0
- cotlab/experiments/probing_classifier.py +402 -0
- cotlab/experiments/residual_norm_ood.py +413 -0
- cotlab/experiments/sae_feature_analysis.py +673 -0
- cotlab/experiments/steering_vectors.py +223 -0
- cotlab/experiments/sycophancy_heads.py +224 -0
- cotlab/logging/__init__.py +5 -0
- cotlab/logging/json_logger.py +161 -0
- cotlab/main.py +317 -0
- cotlab/patching/__init__.py +24 -0
- cotlab/patching/cache.py +141 -0
- cotlab/patching/hooks.py +558 -0
- cotlab/patching/interventions.py +86 -0
- cotlab/patching/patcher.py +439 -0
- cotlab/patching/sae.py +181 -0
- cotlab/prompts/__init__.py +43 -0
- cotlab/prompts/cardiology.py +378 -0
- cotlab/prompts/histopathology.py +265 -0
- cotlab/prompts/length_matched_strategies.py +157 -0
- cotlab/prompts/mcq.py +193 -0
- cotlab/prompts/neurology.py +353 -0
- cotlab/prompts/oncology.py +367 -0
- cotlab/prompts/plab.py +162 -0
- cotlab/prompts/pubhealthbench.py +82 -0
- cotlab/prompts/pubmedqa.py +173 -0
- cotlab/prompts/radiology.py +414 -0
- cotlab/prompts/strategies.py +939 -0
- cotlab/prompts/tcga.py +168 -0
- cotlab/runner.py +204 -0
- cotlab-0.8.0.dist-info/METADATA +166 -0
- cotlab-0.8.0.dist-info/RECORD +65 -0
- cotlab-0.8.0.dist-info/WHEEL +4 -0
- cotlab-0.8.0.dist-info/entry_points.txt +3 -0
- cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
"""vLLM backend for high-throughput inference.
|
|
2
|
+
|
|
3
|
+
Supports multiple platforms:
|
|
4
|
+
- CUDA (NVIDIA GPUs) - standard vLLM
|
|
5
|
+
- ROCm (AMD GPUs) - via ROCm Docker or HIP
|
|
6
|
+
- Metal (Apple Silicon) - via vllm-metal plugin
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import platform as plat
|
|
11
|
+
from typing import List, Optional
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
|
|
15
|
+
from ..core.base import GenerationOutput
|
|
16
|
+
from ..core.registry import Registry
|
|
17
|
+
from .base import InferenceBackend
|
|
18
|
+
|
|
19
|
+
# vLLM's V1 engine launches its workers through multiprocessing. When CUDA has
# already been initialized in this process (for example by Hydra/PyTorch),
# forked workers fail, so default to "spawn", which starts fresh Python
# processes instead. A value already present in the environment is kept.
if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _is_apple_silicon() -> bool:
|
|
26
|
+
"""Detect if running on Apple Silicon."""
|
|
27
|
+
return plat.system() == "Darwin" and plat.processor() == "arm"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _detect_platform() -> str:
    """Choose the accelerator platform to target.

    Apple Silicon is checked first; otherwise CUDA availability as reported
    by PyTorch decides between GPU and CPU.

    Returns:
        "metal" on Apple Silicon,
        "cuda" when ``torch.cuda.is_available()`` is true (NVIDIA, or AMD via
        ROCm's CUDA-compatible API),
        "cpu" when no GPU is available.
    """
    if not _is_apple_silicon():
        return "cuda" if torch.cuda.is_available() else "cpu"
    return "metal"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@Registry.register_backend("vllm")
class VLLMBackend(InferenceBackend):
    """
    High-throughput inference backend using vLLM.

    Best for:
    - Large-scale experiments (1000+ samples)
    - Batch inference
    - When activation access is not needed

    Platforms:
    - CUDA: Standard vLLM (pip install vllm)
    - ROCm: vLLM in ROCm Docker container
    - Metal: vLLM-Metal plugin (install via curl script)

    Note:
    Does NOT support activation extraction or patching.
    """

    def __init__(
        self,
        tensor_parallel_size: int = 1,
        dtype: str = "bfloat16",
        max_model_len: int | None = None,
        trust_remote_code: bool = True,
        quantization: str | None = None,
        gpu_memory_utilization: float = 0.9,
        enforce_eager: bool = False,
        limit_mm_per_prompt: dict | str | None = None,
        platform: str = "auto",
        **kwargs,
    ):
        """Store engine settings and configure the platform environment.

        No model is loaded here; call ``load_model()`` afterwards.

        Args:
            tensor_parallel_size: Forwarded to ``vllm.LLM``.
            dtype: Forwarded to ``vllm.LLM``.
            max_model_len: Forwarded to ``vllm.LLM`` only when not None.
            trust_remote_code: Forwarded to ``vllm.LLM``.
            quantization: Forwarded to ``vllm.LLM`` only when not None.
            gpu_memory_utilization: Forwarded to ``vllm.LLM`` on non-Metal
                platforms; on Metal it seeds ``VLLM_METAL_MEMORY_FRACTION``.
            enforce_eager: Forwarded to ``vllm.LLM``.
            limit_mm_per_prompt: Forwarded unchanged to ``vllm.LLM`` when not
                None.
            platform: "auto" to detect the platform, or an explicit platform
                name that is used as given (no validation is performed).
            **kwargs: Accepted for call-site compatibility; not used by the
                code in this class.
        """
        self.tensor_parallel_size = tensor_parallel_size
        self.dtype = dtype
        self.max_model_len = max_model_len
        self.trust_remote_code = trust_remote_code
        self.quantization = quantization
        self.gpu_memory_utilization = gpu_memory_utilization
        self.enforce_eager = enforce_eager
        self.limit_mm_per_prompt = limit_mm_per_prompt
        self._model = None
        self._model_name: Optional[str] = None

        # Platform detection
        self._platform = _detect_platform() if platform == "auto" else platform
        self._setup_platform_env()

    def _setup_platform_env(self) -> None:
        """Configure environment variables for the detected platform.

        Only Metal sets variables; ``setdefault`` keeps any value the user
        already exported. Every branch prints the selected platform.
        """
        if self._platform == "metal":
            # vLLM-Metal configuration
            os.environ.setdefault("VLLM_METAL_MEMORY_FRACTION", str(self.gpu_memory_utilization))
            os.environ.setdefault("VLLM_METAL_USE_MLX", "1")
            os.environ.setdefault("VLLM_METAL_BLOCK_SIZE", "16")
            print(" Platform: Apple Silicon (Metal/MLX)")
        elif self._platform == "cuda":
            print(" Platform: CUDA")
        else:
            print(f" Platform: {self._platform}")

    def load_model(self, model_name: str, **kwargs) -> None:
        """Load a model with vLLM.

        Args:
            model_name: Model identifier passed to ``vllm.LLM`` as ``model``.
            **kwargs: Extra ``vllm.LLM`` arguments. They override the base
                settings (model, tensor_parallel_size, dtype,
                trust_remote_code, enforce_eager) but are themselves
                overridden by the optional settings applied afterwards.

        Raises:
            ImportError: If vLLM cannot be imported; the message contains
                platform-specific install instructions.
        """
        try:
            from vllm import LLM
        except ImportError as err:
            # Chain explicitly so the underlying import failure stays visible.
            if self._platform == "metal":
                raise ImportError(
                    "vLLM not found. On Apple Silicon, install vllm-metal:\n"
                    "curl -fsSL https://raw.githubusercontent.com/vllm-project/vllm-metal/main/install.sh | bash"
                ) from err
            raise ImportError("vLLM not installed. Run: pip install vllm") from err

        # Build LLM kwargs
        llm_kwargs = {
            "model": model_name,
            "tensor_parallel_size": self.tensor_parallel_size,
            "dtype": self.dtype,
            "trust_remote_code": self.trust_remote_code,
            "enforce_eager": self.enforce_eager,
            **kwargs,
        }

        # gpu_memory_utilization only applies to CUDA/ROCm
        if self._platform != "metal":
            llm_kwargs["gpu_memory_utilization"] = self.gpu_memory_utilization

        if self.max_model_len is not None:
            llm_kwargs["max_model_len"] = self.max_model_len

        if self.quantization is not None:
            llm_kwargs["quantization"] = self.quantization

        if self.limit_mm_per_prompt is not None:
            llm_kwargs["limit_mm_per_prompt"] = self.limit_mm_per_prompt

        print(f"DEBUG: VLLM args: {llm_kwargs}")
        self._model = LLM(**llm_kwargs)
        self._model_name = model_name

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        system_prompt: Optional[str] = None,
        **kwargs,
    ) -> GenerationOutput:
        """Generate from a single prompt.

        Delegates to ``generate_batch`` with a one-element list and returns
        its first result.
        """
        outputs = self.generate_batch(
            [prompt],
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            system_prompt=system_prompt,
            **kwargs,
        )
        return outputs[0]

    def generate_batch(
        self,
        prompts: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        system_prompt: Optional[str] = None,
        **kwargs,
    ) -> List[GenerationOutput]:
        """Generate from multiple prompts efficiently.

        Args:
            prompts: Prompts to complete.
            max_new_tokens: Passed to ``SamplingParams`` as ``max_tokens``.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            system_prompt: Optional text prepended to every prompt, separated
                by a blank line.
            **kwargs: Extra ``SamplingParams`` arguments.

        Returns:
            One ``GenerationOutput`` per vLLM result, built from the first
            completion's text and token ids; ``logprobs`` is always None.

        Raises:
            RuntimeError: If no model has been loaded.
        """
        # Check state before importing vLLM so a missing model is reported as
        # such rather than being masked by an import failure.
        if self._model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        from vllm import SamplingParams

        prompts = self._apply_system_prompt(prompts, system_prompt)

        sampling_params = SamplingParams(
            max_tokens=max_new_tokens, temperature=temperature, top_p=top_p, **kwargs
        )

        outputs = self._model.generate(prompts, sampling_params)

        return [
            GenerationOutput(
                text=output.outputs[0].text,
                tokens=list(output.outputs[0].token_ids),
                logprobs=None,  # Can be enabled in SamplingParams if needed
            )
            for output in outputs
        ]

    @staticmethod
    def _apply_system_prompt(prompts: List[str], system_prompt: Optional[str]) -> List[str]:
        """Prefix each prompt with the stripped system prompt.

        Returns ``prompts`` unchanged when ``system_prompt`` is None, empty,
        or whitespace-only.
        """
        cleaned = system_prompt.strip() if system_prompt else ""
        if not cleaned:
            return prompts
        return [f"{cleaned}\n\n{prompt}" for prompt in prompts]

    @property
    def platform(self) -> str:
        """Return the detected platform (cuda, metal, cpu)."""
        return self._platform

    @property
    def supports_activations(self) -> bool:
        """vLLM optimizes away intermediate activations."""
        return False

    @property
    def model_name(self) -> Optional[str]:
        """Return the name passed to the last successful ``load_model()``, if any."""
        return self._model_name

    def unload(self) -> None:
        """Free GPU memory.

        Drops the reference to the engine, runs a garbage collection pass so
        the engine object can actually be reclaimed, then releases PyTorch's
        cached CUDA blocks on non-Metal platforms.
        """
        import gc

        if self._model is not None:
            self._model = None
            # Collect before emptying the cache: memory still referenced by
            # uncollected objects cannot be returned by empty_cache().
            gc.collect()

        # Platform-specific cleanup
        if self._platform != "metal" and torch.cuda.is_available():
            torch.cuda.empty_cache()
|
cotlab/cli.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""CLI utilities for CoTLab."""
|
|
2
|
+
|
|
3
|
+
import shutil
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import click
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@click.group()
def cli():
    """CoTLab command-line utilities."""
    # Group entry point only: subcommands attach via @cli.command().
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@cli.command()
@click.argument("model_name")
@click.option(
    "--backend",
    default="vllm",
    type=click.Choice(["vllm", "transformers"]),
    help="Backend: vllm or transformers",
)
@click.option("--output", "-o", help="Output path (default: conf/model/<safe_name>.yaml)")
def template(model_name: str, backend: str, output: str | None):
    """Generate model config from template.
    Use this when you want to pre-create a model config file before the first run.
    This is optional: CoTLab can also auto-generate a model config at runtime when
    you run with `model=org/repo-id`.

    Examples:

    cotlab-template meta-llama/Llama-3.1-8B

    cotlab-template google/gemma-3-12b --backend transformers

    cotlab-template mistralai/Mistral-7B-v0.1 -o conf/model/mistral7b.yaml
    """
    # NOTE: the docstring above is rendered by click as the command's help
    # text, so its wording is part of the CLI output.

    # Project root is taken to be three directory levels above this file.
    # NOTE(review): confirm this still holds for an installed wheel, where
    # conf/ may not sit next to the package.
    project_root = Path(__file__).parent.parent.parent

    # Locate the backend-specific base template.
    template_path = project_root / f"conf/model/_base/{backend}_default.yaml"

    if not template_path.exists():
        click.echo(f"Template not found: {template_path}", err=True)
        raise click.Abort()

    # Derive a default output path when none was given.
    if not output:
        # Convert model name to safe filename
        # meta-llama/Llama-3.1-8B -> meta_llama_llama_3_1_8b
        safe_name = model_name.replace("/", "_").replace("-", "_").replace(".", "_").lower()
        output = f"conf/model/{safe_name}.yaml"

    output_path = Path(output)

    # Never overwrite an existing config without confirmation.
    if output_path.exists():
        if not click.confirm(f"{output} already exists. Overwrite?"):
            click.echo("Aborted.")
            raise click.Abort()

    # Ensure parent directory exists
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Copy template
    shutil.copy(template_path, output_path)

    # Replace the placeholder model id. UTF-8 is explicit so the YAML is read
    # and written identically regardless of the platform's locale encoding.
    content = output_path.read_text(encoding="utf-8")
    content = content.replace("huggingface/model-name", model_name)
    output_path.write_text(content, encoding="utf-8")

    # Success message
    click.echo(f"Created: {output}")
    click.echo(f"Edit {output} to customize parameters")
    click.echo("\nUsage:")
    click.echo(f" python -m cotlab.main model={output_path.stem}")
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
if __name__ == "__main__":
|
|
83
|
+
cli()
|
cotlab/core/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Core module - base classes and configuration."""
|
|
2
|
+
|
|
3
|
+
from .base import (
|
|
4
|
+
BaseExperiment,
|
|
5
|
+
BasePromptStrategy,
|
|
6
|
+
ExperimentResult,
|
|
7
|
+
GenerationOutput,
|
|
8
|
+
StructuredOutputMixin,
|
|
9
|
+
)
|
|
10
|
+
from .config import (
|
|
11
|
+
BackendConfig,
|
|
12
|
+
Config,
|
|
13
|
+
DatasetConfig,
|
|
14
|
+
ExperimentConfig,
|
|
15
|
+
ModelConfig,
|
|
16
|
+
PromptConfig,
|
|
17
|
+
)
|
|
18
|
+
from .registry import Registry, create_component
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
"GenerationOutput",
|
|
22
|
+
"ExperimentResult",
|
|
23
|
+
"BasePromptStrategy",
|
|
24
|
+
"BaseExperiment",
|
|
25
|
+
"StructuredOutputMixin",
|
|
26
|
+
"BackendConfig",
|
|
27
|
+
"ModelConfig",
|
|
28
|
+
"PromptConfig",
|
|
29
|
+
"DatasetConfig",
|
|
30
|
+
"ExperimentConfig",
|
|
31
|
+
"Config",
|
|
32
|
+
"Registry",
|
|
33
|
+
"create_component",
|
|
34
|
+
]
|