wisent-tools 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,16 @@
1
+ Metadata-Version: 2.4
2
+ Name: wisent-tools
3
+ Version: 0.1.0
4
+ Summary: Operational scripts and benchmark-evaluation runners for the wisent package family
5
+ Home-page: https://github.com/wisent-ai/wisent-tools
6
+ Author: Lukasz Bartoszcze and the Wisent Team
7
+ Author-email: lukasz.bartoszcze@wisent.ai
8
+ Requires-Python: >=3.9
9
+ Requires-Dist: wisent>=0.10.0
10
+ Requires-Dist: wisent-evaluators>=0.1.0
11
+ Dynamic: author
12
+ Dynamic: author-email
13
+ Dynamic: home-page
14
+ Dynamic: requires-dist
15
+ Dynamic: requires-python
16
+ Dynamic: summary
@@ -0,0 +1,11 @@
1
+ # wisent-tools
2
+
3
+ Operational scripts split out of wisent-open-source. Provides `wisent.scripts` —
4
+ benchmark-evaluation runners (aime, apps, conala, livemathbench, math, polymath),
5
+ extract helpers, fix utilities.
6
+
7
+ ## Install
8
+
9
+ ```
10
+ pip install wisent-tools
11
+ ```
@@ -0,0 +1,3 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,14 @@
1
from setuptools import find_packages, setup

# Static project metadata (mirrored in PKG-INFO).
_METADATA = dict(
    name="wisent-tools",
    version="0.1.0",
    author="Lukasz Bartoszcze and the Wisent Team",
    author_email="lukasz.bartoszcze@wisent.ai",
    description="Operational scripts and benchmark-evaluation runners for the wisent package family",
    url="https://github.com/wisent-ai/wisent-tools",
)

setup(
    **_METADATA,
    # Only the shared ``wisent`` namespace package and its subpackages.
    packages=find_packages(include=["wisent", "wisent.*"]),
    python_requires=">=3.9",
    install_requires=["wisent>=0.10.0", "wisent-evaluators>=0.1.0"],
    include_package_data=True,
    # Ship the shell helpers that live next to the Python scripts.
    package_data={"wisent": ["scripts/*.sh"]},
)
@@ -0,0 +1,15 @@
1
+ """Namespace bootstrap shared with wisent-core and sibling packages.
2
+
3
+ Uses pkgutil.extend_path so all wisent-* packages merge at import time
4
+ even though wisent-core ships a regular (non-PEP-420) package.
5
+ """
6
+ import os
7
+ import pkgutil
8
+
9
# Merge the ``wisent`` directories found on every sys.path entry into this
# package's search path.
__path__ = pkgutil.extend_path(__path__, __name__)

# Additionally search every visible subdirectory of this package directory
# (names not starting with '.' or '_'), in sorted order, so that modules placed
# inside those subdirectories are also resolvable directly under ``wisent``.
_base = os.path.dirname(__file__)
for _entry in sorted(os.listdir(_base)):
    _path = os.path.join(_base, _entry)
    if _entry.startswith(('.', '_')) or not os.path.isdir(_path):
        continue
    __path__.append(_path)
@@ -0,0 +1 @@
1
+ """Wisent scripts for activation extraction and data processing."""
@@ -0,0 +1 @@
1
+ """Extracted helpers for files exceeding 300-line limit."""
@@ -0,0 +1,199 @@
1
+ """
2
+ Benchmark extraction and main entry point for extract_all_missing.
3
+
4
+ Split from extract_all_missing.py to meet 300-line limit.
5
+ """
6
+
7
+ import argparse
8
+ import sys
9
+ import time
10
+
11
+ import psycopg2
12
+ import torch
13
+
14
+ from wisent.core.utils.config_tools.constants import RECURSION_INITIAL_DEPTH
15
+
16
+ from wisent.scripts.extract_all_missing import (
17
+ hidden_states_to_bytes,
18
+ get_conn,
19
+ reset_conn,
20
+ batch_create_activations,
21
+ get_missing_benchmarks,
22
+ )
23
+
24
+
25
def extract_benchmark(model, tokenizer, model_id: int, benchmark_name: str, set_id: int,
                      device: str, num_layers: int, batch_size: int,
                      db_connect_wait_s: int, max_retries: int,
                      log_interval: int):
    """Extract activations for a single benchmark using EXISTING pairs from database.

    Only pairs that have no ``Activation`` row for ``model_id`` yet are
    processed. For every such pair the last-token hidden state of each layer
    1..``num_layers`` is stored twice (positive and negative example) with
    extraction strategy ``"chat_last"``.

    Args:
        model: causal LM called with ``output_hidden_states=True``.
        tokenizer: tokenizer matching ``model``.
        model_id: ``Model.id`` the activations belong to.
        benchmark_name: used only for log messages.
        set_id: ``ContrastivePairSet.id`` whose pairs are processed.
        device: device the tokenized inputs are moved to.
        num_layers: number of layers to store per example.
        batch_size: number of pairs whose rows are inserted per DB round trip.
        db_connect_wait_s: connection timeout passed to ``get_conn``.
        max_retries: retry budget for each batch insert.
        log_interval: print progress whenever the processed-pair count is a
            multiple of this value (and at the end).

    Returns:
        Number of pairs processed in this call (0 when nothing was missing).

    Raises:
        ValueError: if ``batch_size`` or ``log_interval`` is smaller than 1.
    """
    # Validate up front instead of failing with range()/modulo errors mid-run.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    conn = get_conn(db_connect_wait_s)
    cur = conn.cursor()
    try:
        # Get pairs that DON'T already have activations for this model
        cur.execute('''
            SELECT cp.id, cp."positiveExample", cp."negativeExample"
            FROM "ContrastivePair" cp
            WHERE cp."setId" = %s
            AND NOT EXISTS (
                SELECT 1 FROM "Activation" a
                WHERE a."contrastivePairId" = cp.id AND a."modelId" = %s
            )
            ORDER BY cp.id
        ''', (set_id, model_id))
        db_pairs = cur.fetchall()
    finally:
        # Close the cursor even when the query fails.
        cur.close()

    if not db_pairs:
        print(f" All pairs already extracted for {benchmark_name}", flush=True)
        return 0

    total = len(db_pairs)
    print(f" Extracting {total} pairs (skipping already extracted)...", flush=True)
    # Accept "cuda" as well as explicit indices such as "cuda:0".
    is_cuda = str(device).startswith("cuda")
    extracted = 0

    def get_hidden_states(text):
        # Single unpadded sequence -> one forward pass with all hidden states.
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=tokenizer.model_max_length)
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.inference_mode():
            out = model(**enc, output_hidden_states=True, use_cache=False)
        # Index 0 is skipped (in Hugging Face models it is the embedding
        # output); keep the last token of batch row 0 for each remaining layer.
        return [out.hidden_states[i][0, -1, :] for i in range(1, len(out.hidden_states))]

    # Process in batches to reduce DB round trips
    for batch_start in range(0, total, batch_size):
        batch_end = min(batch_start + batch_size, total)
        activations_batch = []

        for pair_id, pos_text, neg_text in db_pairs[batch_start:batch_end]:
            pos_hidden = get_hidden_states(pos_text)
            neg_hidden = get_hidden_states(neg_text)

            # Collect all layers for this pair; layers are stored 1-based.
            for layer_idx in range(num_layers):
                layer_num = layer_idx + 1
                pos_bytes = hidden_states_to_bytes(pos_hidden[layer_idx])
                neg_bytes = hidden_states_to_bytes(neg_hidden[layer_idx])
                neuron_count = pos_hidden[layer_idx].shape[0]

                activations_batch.append((
                    model_id, pair_id, set_id, layer_num, neuron_count,
                    "chat_last", psycopg2.Binary(pos_bytes), True
                ))
                activations_batch.append((
                    model_id, pair_id, set_id, layer_num, neuron_count,
                    "chat_last", psycopg2.Binary(neg_bytes), False
                ))

            del pos_hidden, neg_hidden
            extracted += 1

        # Batch insert all activations for this batch of pairs
        batch_create_activations(activations_batch, max_retries=max_retries, db_connect_wait_s=db_connect_wait_s)

        if batch_end % log_interval == RECURSION_INITIAL_DEPTH or batch_end == total:
            print(f" Processed {batch_end}/{total} pairs", flush=True)

        if is_cuda:
            torch.cuda.empty_cache()

    return extracted
104
+
105
+
106
def main():
    """CLI entry point: extract ``chat_last`` activations for incomplete benchmarks.

    Loads the model, looks up its ``Model.id`` by ``huggingFaceId`` and then
    extracts either the single ``--benchmark`` or every benchmark reported by
    ``get_missing_benchmarks``. Exits with status 1 when the model or the
    requested benchmark is not in the database. The shared DB connection is
    closed (``reset_conn``) on every exit path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="Model name (e.g., meta-llama/Llama-3.2-1B-Instruct)")
    parser.add_argument("--device", required=True, help="Device (cuda/mps/cpu)")
    parser.add_argument("--batch-size", type=int, required=True, help="Batch size for extraction (number of pairs per DB round trip)")
    parser.add_argument("--benchmark", default=None, help="Single benchmark to extract (optional)")
    parser.add_argument("--db-connect-wait-s", type=int, required=True, help="Database connection wait seconds")
    parser.add_argument("--max-retries", type=int, required=True, help="Maximum retry attempts for DB operations")
    parser.add_argument("--log-interval", type=int, required=True, help="Progress logging interval")
    args = parser.parse_args()

    from transformers import AutoTokenizer, AutoModelForCausalLM

    print(f"Loading model {args.model}...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

    if args.device == "mps":
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )
        model = model.to("mps")
    else:
        # Let accelerate place the model automatically only for CUDA. For any
        # other device (e.g. "cpu") load it there directly: extract_benchmark
        # moves inputs with ``.to(args.device)``, so weights must live there too.
        device_map = "auto" if str(args.device).startswith("cuda") else args.device
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            torch_dtype="auto",
            device_map=device_map,
            trust_remote_code=True,
        )
    model.eval()

    num_layers = model.config.num_hidden_layers
    print(f"Model loaded: {num_layers} layers", flush=True)

    try:
        conn = get_conn(args.db_connect_wait_s)
        cur = conn.cursor()
        try:
            # Get model ID
            cur.execute('SELECT id FROM "Model" WHERE "huggingFaceId" = %s', (args.model,))
            result = cur.fetchone()
        finally:
            cur.close()
        if not result:
            print(f"ERROR: Model {args.model} not found in database", flush=True)
            sys.exit(1)
        model_id = result[0]
        print(f"Model ID: {model_id}", flush=True)

        if args.benchmark:
            # Extract single benchmark
            conn = get_conn(args.db_connect_wait_s)
            cur = conn.cursor()
            try:
                cur.execute('SELECT id FROM "ContrastivePairSet" WHERE name = %s', (args.benchmark,))
                result = cur.fetchone()
            finally:
                cur.close()
            if not result:
                print(f"ERROR: Benchmark {args.benchmark} not found", flush=True)
                sys.exit(1)
            set_id = result[0]

            print(f"Extracting single benchmark: {args.benchmark}", flush=True)
            extracted = extract_benchmark(model, tokenizer, model_id, args.benchmark, set_id,
                                          args.device, num_layers, args.batch_size,
                                          db_connect_wait_s=args.db_connect_wait_s, max_retries=args.max_retries,
                                          log_interval=args.log_interval)
            print(f"Done! Extracted {extracted} pairs", flush=True)
            return

        # Extract all incomplete benchmarks
        missing = get_missing_benchmarks(get_conn(args.db_connect_wait_s), model_id, log_interval=args.log_interval)
        print(f"Found {len(missing)} incomplete benchmarks to extract", flush=True)

        if not missing:
            print("All benchmarks are complete!", flush=True)
            return

        total_extracted = 0
        for i, (set_id, benchmark_name, pairs_needed) in enumerate(missing):
            print(f"\n[{i+1}/{len(missing)}] {benchmark_name} ({pairs_needed} pairs needed)", flush=True)
            start = time.time()

            extracted = extract_benchmark(model, tokenizer, model_id, benchmark_name, set_id,
                                          args.device, num_layers, args.batch_size,
                                          db_connect_wait_s=args.db_connect_wait_s, max_retries=args.max_retries,
                                          log_interval=args.log_interval)

            total_extracted += extracted
            elapsed = time.time() - start
            print(f" Extracted {extracted} pairs in {elapsed:.1f}s", flush=True)

        print(f"\n{'='*60}", flush=True)
        print(f"COMPLETE! Total extracted: {total_extracted} pairs across {len(missing)} benchmarks", flush=True)
    finally:
        # Close the shared connection on every exit path (including sys.exit).
        reset_conn()
@@ -0,0 +1,117 @@
1
+ """Database connection management for extract_raw_activations."""
2
+
3
+ from __future__ import annotations
4
+ import os
5
+ import psycopg2
6
+
7
+
8
# Connection URL from the environment. Everything after the first '?' (the
# query string) is dropped before use.
DATABASE_URL = os.environ.get("DATABASE_URL")
if DATABASE_URL:
    DATABASE_URL = DATABASE_URL.partition('?')[0]

# Fail fast at import time when no usable URL is configured.
if not DATABASE_URL:
    raise RuntimeError("DATABASE_URL environment variable is required")

# Shared connection handed out by get_conn(); None until first use.
_db_conn = None

# Keyword arguments for psycopg2.connect(), preserved from the original
# extract_raw_activations.py (timeout key spelled as in the original source).
_CONN_KW = {
    "connect_" + "timeout": 30,
    "keepalives": 1,
    "keepalives_idle": 30,
    "keepalives_interval": 10,
    "keepalives_count": 5,
}
25
+
26
+
27
def get_db_connection():
    """Open and return a brand-new autocommit connection to ``DATABASE_URL``.

    Supabase pooler URLs pointing at port 6543 are rewritten to port 5432
    before connecting; connection options come from ``_CONN_KW``.
    """
    target = (
        DATABASE_URL.replace(":6543", ":5432")
        if "pooler.supabase.com:6543" in DATABASE_URL
        else DATABASE_URL
    )
    connection = psycopg2.connect(target, **_CONN_KW)
    connection.autocommit = True
    return connection
35
+
36
+
37
def get_conn():
    """Return the shared module-level connection, (re)opening it when needed.

    An existing connection is probed with ``SELECT 1``; if the probe fails for
    any reason the old connection is closed (errors ignored) and replaced.
    """
    global _db_conn
    if _db_conn is None:
        _db_conn = get_db_connection()
        return _db_conn

    try:
        probe = _db_conn.cursor()
        probe.execute("SELECT 1")
        probe.close()
    except Exception:
        print(" [Reconnecting to DB...]", flush=True)
        try:
            _db_conn.close()
        except Exception:
            pass  # best effort: the connection is being discarded anyway
        _db_conn = get_db_connection()
    return _db_conn
55
+
56
+
57
def reset_conn():
    """Drop the shared connection so the next get_conn() call opens a new one.

    Closing is best effort: any error raised by ``close()`` is ignored.
    """
    global _db_conn
    if _db_conn is None:
        return
    try:
        _db_conn.close()
    except Exception:
        pass
    _db_conn = None
66
+
67
+
68
def get_or_create_model(conn, model_name: str, num_layers: int) -> int:
    """Return the ``Model.id`` for ``model_name``, inserting a row if none exists.

    New rows use the last ``/``-separated component of ``model_name`` as the
    display name, ``num_layers // 2`` as ``optimalLayer``, ``'user'`` /
    ``'assistant'`` as chat tags, ``'system'`` as owner and are public.

    Args:
        conn: open DB-API connection (psycopg2 in this script).
        model_name: Hugging Face id, e.g. ``"org/model"``.
        num_layers: number of transformer layers of the model.

    Returns:
        The existing or newly created model id.
    """
    cur = conn.cursor()
    try:
        cur.execute('SELECT id FROM "Model" WHERE "huggingFaceId" = %s', (model_name,))
        result = cur.fetchone()
        if result:
            return result[0]
        optimal_layer = num_layers // 2
        cur.execute('''
            INSERT INTO "Model" ("name", "huggingFaceId", "userTag", "assistantTag", "userId", "isPublic", "numLayers", "optimalLayer", "createdAt", "updatedAt")
            VALUES (%s, %s, 'user', 'assistant', 'system', true, %s, %s, NOW(), NOW())
            RETURNING id
        ''', (model_name.split('/')[-1], model_name, num_layers, optimal_layer))
        model_id = cur.fetchone()[0]
        conn.commit()
        return model_id
    finally:
        # Release the cursor on success and on query failure alike.
        cur.close()
86
+
87
+
88
def get_missing_benchmarks(conn, model_id: int, num_layers: int) -> list:
    """Get list of benchmarks missing raw activations for this model.

    A prompt format counts as complete for a benchmark when it has at least
    95% of ``pair_count * num_layers * 2`` RawActivation rows. A benchmark is
    reported as missing as soon as one of its formats is incomplete, so the
    remaining formats of that benchmark are not queried.

    Args:
        conn: open DB-API connection.
        model_id: ``Model.id`` to check.
        num_layers: layers stored per example.

    Returns:
        List of ``(set_id, name, pair_count)`` tuples ordered by benchmark name.
    """
    prompt_formats = ('chat', 'mc_balanced', 'role_play')
    cur = conn.cursor()
    try:
        cur.execute('''
            SELECT cps.id, cps.name, COUNT(cp.id) as pair_count
            FROM "ContrastivePairSet" cps
            INNER JOIN "ContrastivePair" cp ON cp."setId" = cps.id
            GROUP BY cps.id, cps.name
            HAVING COUNT(cp.id) > 0
            ORDER BY cps.name
        ''')
        benchmarks = cur.fetchall()
        missing = []
        for set_id, name, pair_count in benchmarks:
            expected_per_format = pair_count * num_layers * 2
            threshold = int(expected_per_format * 0.95)
            for fmt in prompt_formats:
                cur.execute('''
                    SELECT COUNT(*) FROM "RawActivation"
                    WHERE "modelId" = %s AND "contrastivePairSetId" = %s AND "promptFormat" = %s
                ''', (model_id, set_id, fmt))
                if cur.fetchone()[0] < threshold:
                    # One incomplete format is enough to schedule this benchmark.
                    missing.append((set_id, name, pair_count))
                    break
    finally:
        cur.close()
    print(f"Found {len(benchmarks)} benchmarks, {len(benchmarks) - len(missing)} complete, {len(missing)} need extraction", flush=True)
    return missing
@@ -0,0 +1,205 @@
1
+ """Helper functions for extract_raw_activations: extraction and DB batch operations."""
2
+
3
+ from __future__ import annotations
4
+ import struct
5
+
6
+ import psycopg2
7
+ from psycopg2.extras import execute_values
8
+ import torch
9
+
10
+ from wisent.core.utils.config_tools.constants import PROGRESS_LOG_INTERVAL_10, RECURSION_INITIAL_DEPTH
11
+
12
+
13
def hidden_states_to_bytes(hidden_states: torch.Tensor) -> bytes:
    """Convert hidden_states tensor to bytes (float32).

    The tensor is detached, moved to CPU, converted to float32 and flattened
    in row-major order; the result is the native-byte-order float32 buffer
    (byte-identical to ``struct.pack(f'{n}f', *values)``), produced without
    creating one Python float per element.
    """
    # flatten() yields a contiguous row-major copy for non-contiguous inputs;
    # detach() allows tensors that still require grad.
    return hidden_states.detach().cpu().float().flatten().numpy().tobytes()
17
+
18
+
19
def get_batch_size(model_config) -> int:
    """Choose an extraction batch size from the model size in billions of parameters.

    Uses ``model_config.num_parameters`` when that attribute exists; otherwise
    estimates ``12 * hidden_size**2 * num_hidden_layers / 1e9``.
    Mapping: < 2 -> 10, < 3 -> 5, < 5 -> 2, otherwise 1.
    """
    # NOTE(review): a present ``num_parameters`` is compared directly against
    # billions; if a config stores a raw count this always yields 1 — confirm.
    size_b = getattr(model_config, 'num_parameters', None)
    if size_b is None:
        h = model_config.hidden_size
        size_b = (12 * h * h * model_config.num_hidden_layers) / 1e9

    for upper_bound, chosen in ((2, 10), (3, 5), (5, 2)):
        if size_b < upper_bound:
            return chosen
    return 1
35
+
36
+
37
def check_pair_fully_extracted(get_conn_fn, model_id: int, pair_id: int,
                               num_layers: int, formats: list) -> bool:
    """Check if a pair has all raw activations for all formats.

    Expected rows = ``num_layers * 2 * len(formats)`` (positive + negative per
    layer and format). Only rows whose ``promptFormat`` is one of ``formats``
    are counted, so rows stored under other prompt formats cannot make a pair
    look complete while a requested format is still missing.

    Any error while querying (e.g. a dropped connection) yields ``False``, i.e.
    the pair is treated as not yet extracted.

    Returns:
        True if at least the expected number of matching rows exist.
    """
    if not formats:
        # Nothing is required, so the pair is trivially complete.
        return True
    expected_count = num_layers * 2 * len(formats)
    try:
        cur = get_conn_fn().cursor()
        try:
            # ``IN %s`` with a tuple is adapted by psycopg2 to a literal list.
            cur.execute('''
                SELECT COUNT(*) FROM "RawActivation"
                WHERE "modelId" = %s AND "contrastivePairId" = %s
                AND "promptFormat" IN %s
            ''', (model_id, pair_id, tuple(formats)))
            actual_count = cur.fetchone()[0]
        finally:
            cur.close()
    except Exception:
        return False
    return actual_count >= expected_count
53
+
54
+
55
def batch_create_raw_activations(get_conn_fn, reset_conn_fn, activations_data: list, max_retries: int, batch_size: int = None):
    """Batch insert multiple RawActivation records.

    Each element of ``activations_data`` is an 11-tuple matching the column
    list of the INSERT below (``createdAt`` is filled with ``NOW()``). Rows
    that conflict on (modelId, contrastivePairId, layer, isPositive,
    promptFormat) are skipped, so re-running an insert is harmless.

    Args:
        get_conn_fn: callable returning an open connection.
        reset_conn_fn: callable that drops the current connection; invoked
            after every failed attempt.
        activations_data: rows to insert; an empty list is a no-op.
        max_retries: attempts per chunk (must be >= 1).
        batch_size: rows per chunk. ``None`` (the default) sends all rows as
            one chunk; ``execute_values`` still pages them internally.

    Raises:
        ValueError: if ``max_retries`` or ``batch_size`` is smaller than 1.
        psycopg2.OperationalError, psycopg2.InterfaceError,
        psycopg2.errors.QueryCanceled: re-raised after the last failed attempt.
    """
    if not activations_data:
        return
    if max_retries < 1:
        # Otherwise the retry loop never runs and the rows are silently dropped.
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    if batch_size is None:
        batch_size = len(activations_data)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    for i in range(0, len(activations_data), batch_size):
        batch = activations_data[i:i + batch_size]

        for attempt in range(max_retries):
            try:
                conn = get_conn_fn()
                cur = conn.cursor()
                execute_values(cur, '''
                    INSERT INTO "RawActivation"
                    ("modelId", "contrastivePairId", "contrastivePairSetId", "layer", "seqLen", "hiddenDim", "promptLen", "hiddenStates", "answerText", "isPositive", "promptFormat", "createdAt")
                    VALUES %s
                    ON CONFLICT ("modelId", "contrastivePairId", layer, "isPositive", "promptFormat") DO NOTHING
                ''', batch, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())")
                cur.close()
                break
            except (psycopg2.OperationalError, psycopg2.InterfaceError, psycopg2.errors.QueryCanceled) as e:
                print(f" [DB batch error attempt {attempt+1}/{max_retries}: {e}]", flush=True)
                reset_conn_fn()
                if attempt == max_retries - 1:
                    raise
83
+
84
+
85
def _split_pair_texts(pos_example, neg_example):
    """Split stored pair texts into ``(prompt, positive_answer, negative_answer)``.

    Each text is split at its LAST blank line (``"\\n\\n"``). A positive
    example without a blank line is treated as prompt-only (empty answer); a
    negative example without one is used whole as the answer. The prompt is
    always taken from the positive example.
    """
    if "\n\n" in pos_example:
        prompt, pos = pos_example.rsplit("\n\n", 1)
    else:
        prompt, pos = pos_example, ""
    if "\n\n" in neg_example:
        neg = neg_example.rsplit("\n\n", 1)[1]
    else:
        neg = neg_example
    return prompt, pos, neg


def extract_benchmark(model, tokenizer, model_id: int, benchmark_name: str, set_id: int,
                      num_layers: int, device: str, get_conn_fn, reset_conn_fn, max_retries: int, log_interval: int):
    """Extract raw activations for a single benchmark.

    For every ContrastivePair of ``set_id`` that is not yet fully extracted,
    the positive and negative texts are rendered in three prompt formats
    (chat, mc_balanced, role_play) via ``build_extraction_texts`` and the
    full-sequence hidden states of layers 1..``num_layers`` are stored as
    RawActivation rows (one row per layer, polarity and format).

    Args:
        model: causal LM; if it has an ``_actual_device`` attribute, inputs are
            moved there instead of to ``device``.
        tokenizer: tokenizer matching ``model``.
        model_id: ``Model.id`` the rows belong to.
        benchmark_name: used only for log messages.
        set_id: ``ContrastivePairSet.id`` whose pairs are processed.
        num_layers: number of layers stored per example.
        device: fallback input device.
        get_conn_fn / reset_conn_fn: connection accessor / reset callables.
        max_retries: retry budget for each insert.
        log_interval: progress/skip messages are printed every
            ``log_interval`` pairs.

    Returns:
        Number of pairs extracted in this call (already complete pairs are
        skipped and not counted).

    Raises:
        ValueError: if ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    print(" [EXTRACT] Importing extraction strategy...", flush=True)
    from wisent.core.primitives.model_interface.core.activations import ExtractionStrategy, build_extraction_texts
    print(" [EXTRACT] Extraction strategy imported", flush=True)

    actual_device = getattr(model, '_actual_device', device)
    print(f" [EXTRACT] Using device: {actual_device}", flush=True)

    print(" [EXTRACT] Fetching pairs from database...", flush=True)
    conn = get_conn_fn()
    cur = conn.cursor()
    try:
        cur.execute('''
            SELECT id, "positiveExample", "negativeExample", category
            FROM "ContrastivePair"
            WHERE "setId" = %s
            ORDER BY id
        ''', (set_id,))
        db_pairs = cur.fetchall()
    finally:
        cur.close()
    print(f" [EXTRACT] Fetched {len(db_pairs)} pairs from database", flush=True)

    if not db_pairs:
        print(f" No pairs in database for {benchmark_name}", flush=True)
        return 0

    print(f" Processing {len(db_pairs)} pairs with 3 formats...", flush=True)

    all_prompt_formats = [
        ("chat", ExtractionStrategy.CHAT_LAST),
        ("mc_balanced", ExtractionStrategy.MC_BALANCED),
        ("role_play", ExtractionStrategy.ROLE_PLAY),
    ]
    format_names = [name for name, _ in all_prompt_formats]
    # Accept "cuda" as well as explicit indices such as "cuda:0".
    is_cuda = str(actual_device).startswith("cuda")

    def get_hidden_states(text):
        # The extraction texts are already fully formatted, so no extra
        # special tokens are added here.
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=tokenizer.model_max_length, add_special_tokens=False)
        enc = {k: v.to(actual_device) for k, v in enc.items()}
        with torch.inference_mode():
            out = model(**enc, output_hidden_states=True, use_cache=False)
        # Skip index 0 and drop the batch dimension of the single sequence.
        return [out.hidden_states[i].squeeze(0) for i in range(1, len(out.hidden_states))]

    extracted = 0
    skipped = 0

    for pair_idx, (pair_id, pos_example, neg_example, _category) in enumerate(db_pairs):
        if pair_idx == 0:
            print(f" [EXTRACT] Processing first pair (id={pair_id})...", flush=True)

        prompt, pos, neg = _split_pair_texts(pos_example, neg_example)

        if check_pair_fully_extracted(get_conn_fn, model_id, pair_id, num_layers, format_names):
            skipped += 1
            if skipped % log_interval == RECURSION_INITIAL_DEPTH:
                print(f" [skipped {skipped} already-extracted pairs]", flush=True)
            continue

        activations_batch = []

        for prompt_format, strategy in all_prompt_formats:
            try:
                if strategy == ExtractionStrategy.MC_BALANCED:
                    pos_text, pos_answer, pos_prompt_only = build_extraction_texts(
                        strategy, prompt, pos, tokenizer, other_response=neg, is_positive=True)
                    neg_text, neg_answer, neg_prompt_only = build_extraction_texts(
                        strategy, prompt, neg, tokenizer, other_response=pos, is_positive=False)
                else:
                    pos_text, pos_answer, pos_prompt_only = build_extraction_texts(strategy, prompt, pos, tokenizer)
                    neg_text, neg_answer, neg_prompt_only = build_extraction_texts(strategy, prompt, neg, tokenizer)
            except Exception as e:
                # Skip only this format; the other formats are still stored.
                print(f" Error building texts for {prompt_format}: {e}", flush=True)
                continue

            pos_prompt_len = len(tokenizer(pos_prompt_only, add_special_tokens=False)["input_ids"]) if pos_prompt_only else 0
            neg_prompt_len = len(tokenizer(neg_prompt_only, add_special_tokens=False)["input_ids"]) if neg_prompt_only else 0

            pos_hidden = get_hidden_states(pos_text)
            neg_hidden = get_hidden_states(neg_text)

            for layer_idx in range(num_layers):
                layer_num = layer_idx + 1
                pos_layer = pos_hidden[layer_idx]
                neg_layer = neg_hidden[layer_idx]

                activations_batch.append((
                    model_id, pair_id, set_id, layer_num,
                    pos_layer.shape[0], pos_layer.shape[1],
                    pos_prompt_len, psycopg2.Binary(hidden_states_to_bytes(pos_layer)), pos_answer, True, prompt_format
                ))
                activations_batch.append((
                    model_id, pair_id, set_id, layer_num,
                    neg_layer.shape[0], neg_layer.shape[1],
                    neg_prompt_len, psycopg2.Binary(hidden_states_to_bytes(neg_layer)), neg_answer, False, prompt_format
                ))

            del pos_hidden, neg_hidden

        # Start the insert from a fresh connection (presumably to avoid a
        # connection that went stale during the forward passes — confirm).
        reset_conn_fn()
        # batch_size must be passed explicitly: the insert helper rejects None
        # in earlier versions. One chunk per pair keeps the previous intent.
        batch_create_raw_activations(get_conn_fn, reset_conn_fn, activations_batch,
                                     max_retries=max_retries,
                                     batch_size=len(activations_batch) or 1)
        extracted += 1

        if (pair_idx + 1) % log_interval == 0:
            print(f" Processed {pair_idx + 1}/{len(db_pairs)} pairs", flush=True)

        if is_cuda:
            torch.cuda.empty_cache()

    print(f" Done: extracted {extracted}, skipped {skipped}", flush=True)
    return extracted
@@ -0,0 +1,191 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract activations for ALL missing benchmarks for all models.
4
+ Designed to run on AWS with GPU.
5
+ """
6
+
7
+ import os
8
+
9
+ import psycopg2
10
+ from wisent.core.utils.config_tools.constants import RECURSION_INITIAL_DEPTH, COMBO_OFFSET
11
+ from psycopg2.extras import execute_values
12
+ import torch
13
+
14
# Connection URL from the environment. Everything after the first '?' (the
# query string) is dropped before use.
DATABASE_URL = os.environ.get("DATABASE_URL")
if DATABASE_URL:
    DATABASE_URL = DATABASE_URL.partition('?')[0]

# Fail fast at import time when no usable URL is configured.
if not DATABASE_URL:
    raise RuntimeError("DATABASE_URL environment variable is required")

# Shared connection handed out by get_conn(); None until first use.
_db_conn = None
22
+
23
+
24
def hidden_states_to_bytes(hidden_states: torch.Tensor) -> bytes:
    """Convert hidden_states tensor to bytes (float32) using numpy for speed.

    The tensor is detached (so tensors that require grad are accepted), moved
    to CPU and converted to float32. The bytes are the row-major,
    native-byte-order float32 buffer. ``.float()`` already yields float32, so
    no extra numpy ``astype`` copy is made.
    """
    return hidden_states.detach().cpu().float().numpy().tobytes()
29
+
30
+
31
def get_db_connection(db_connect_wait_s: int):
    """Open and return a brand-new autocommit connection to ``DATABASE_URL``.

    Supabase pooler URLs pointing at port 6543 are rewritten to port 5432.
    ``db_connect_wait_s`` is the connection timeout in seconds; TCP keepalives
    are enabled with fixed idle/interval/count settings.
    """
    target = (
        DATABASE_URL.replace(":6543", ":5432")
        if "pooler.supabase.com:6543" in DATABASE_URL
        else DATABASE_URL
    )
    options = {
        "connect_" + "timeout": db_connect_wait_s,
        "keepalives": 1,
        "keepalives_idle": 30,
        "keepalives_interval": 10,
        "keepalives_count": 5,
    }
    connection = psycopg2.connect(target, **options)
    connection.autocommit = True
    return connection
46
+
47
+
48
def get_conn(db_connect_wait_s: int):
    """Return the shared module-level connection, (re)opening it when needed.

    An existing connection is probed with ``SELECT 1``; if the probe fails for
    any reason the old connection is closed (errors ignored) and replaced.
    """
    global _db_conn
    if _db_conn is None:
        _db_conn = get_db_connection(db_connect_wait_s)
        return _db_conn

    try:
        probe = _db_conn.cursor()
        probe.execute("SELECT 1")
        probe.close()
    except Exception:
        print(" [Reconnecting to DB...]", flush=True)
        try:
            _db_conn.close()
        except Exception:
            pass  # best effort: the connection is being discarded anyway
        _db_conn = get_db_connection(db_connect_wait_s)
    return _db_conn
66
+
67
+
68
def reset_conn():
    """Drop the shared connection so the next get_conn() call opens a new one.

    Closing is best effort: any error raised by ``close()`` is ignored.
    """
    global _db_conn
    if _db_conn is None:
        return
    try:
        _db_conn.close()
    except Exception:
        pass
    _db_conn = None
77
+
78
+
79
def get_missing_benchmarks(conn, model_id: int, log_interval: int) -> list:
    """Return benchmarks whose pairs are not all extracted for ``model_id``.

    A benchmark is incomplete when the number of distinct pairs with at least
    one ``Activation`` row for the model is below the number of pairs stored
    for it. Extracted pairs are counted with one query per benchmark (kept
    separate to avoid a single long-running query).

    Returns:
        List of ``(set_id, name, pairs_needed)`` tuples, ordered by benchmark name.
    """
    cur = conn.cursor()

    # Step 1: every benchmark that has pairs, with its pair count.
    print(" Fetching benchmark pair counts...", flush=True)
    cur.execute('''
        SELECT cps.id, cps.name, COUNT(cp.id) as total_pairs
        FROM "ContrastivePairSet" cps
        INNER JOIN "ContrastivePair" cp ON cp."setId" = cps.id
        GROUP BY cps.id, cps.name
        HAVING COUNT(cp.id) > 0
        ORDER BY cps.name
    ''')
    benchmarks = cur.fetchall()
    total_benchmarks = len(benchmarks)
    print(f" Found {total_benchmarks} benchmarks with pairs", flush=True)

    # Step 2: per benchmark, how many distinct pairs already have activations.
    missing = []
    complete = 0
    for idx, row in enumerate(benchmarks):
        set_id, name, total_pairs = row
        cur.execute('''
            SELECT COUNT(DISTINCT "contrastivePairId")
            FROM "Activation"
            WHERE "contrastivePairSetId" = %s AND "modelId" = %s
        ''', (set_id, model_id))
        remaining = total_pairs - cur.fetchone()[0]

        if remaining > 0:
            missing.append((set_id, name, remaining))
        else:
            complete += 1

        if (idx + COMBO_OFFSET) % log_interval == RECURSION_INITIAL_DEPTH:
            print(f" Checked {idx + 1}/{total_benchmarks} benchmarks...", flush=True)

    cur.close()
    print(f"Found {total_benchmarks} benchmarks with pairs: {complete} complete, {len(missing)} need more extraction", flush=True)
    return missing
125
+
126
+
127
def get_or_create_pair(conn, set_id: int, prompt: str, positive: str, negative: str, pair_idx: int, db_text_field_max_length: int) -> int:
    """Return the id of pair ``pair_idx`` in set ``set_id``, creating it if absent.

    Pairs are identified by ``category = "pair_<pair_idx>"``. New pairs store
    ``"<prompt>\\n\\n<answer>"`` for each side, truncated to
    ``db_text_field_max_length`` characters.
    """
    category = f"pair_{pair_idx}"
    cur = conn.cursor()

    cur.execute('''
        SELECT id FROM "ContrastivePair"
        WHERE "setId" = %s AND category = %s
    ''', (set_id, category))
    found = cur.fetchone()
    if found:
        cur.close()
        return found[0]

    def clip(answer):
        return f"{prompt}\n\n{answer}"[:db_text_field_max_length]

    cur.execute('''
        INSERT INTO "ContrastivePair" ("setId", "positiveExample", "negativeExample", "category", "createdAt", "updatedAt")
        VALUES (%s, %s, %s, %s, NOW(), NOW())
        RETURNING id
    ''', (set_id, clip(positive), clip(negative), category))
    new_id = cur.fetchone()[0]
    conn.commit()
    cur.close()
    return new_id
152
+
153
+
154
def batch_create_activations(activations_data: list, max_retries: int, db_connect_wait_s: int):
    """Batch insert multiple Activation records with retry logic.

    activations_data is a list of tuples:
    (model_id, pair_id, set_id, layer, neuron_count, strategy, activation_bytes, is_positive)

    ``userId`` is set to ``'system'`` and both timestamps to ``NOW()``;
    conflicting rows are skipped (``ON CONFLICT DO NOTHING``).

    Raises:
        ValueError: if ``max_retries`` is smaller than 1 (otherwise nothing
            would be inserted and no error reported).
        psycopg2.OperationalError, psycopg2.InterfaceError,
        psycopg2.errors.QueryCanceled: re-raised after the last failed attempt.
    """
    if not activations_data:
        return
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    for attempt in range(max_retries):
        try:
            conn = get_conn(db_connect_wait_s)
            cur = conn.cursor()

            execute_values(cur, '''
                INSERT INTO "Activation"
                ("modelId", "contrastivePairId", "contrastivePairSetId", "layer", "neuronCount",
                 "extractionStrategy", "activationData", "isPositive", "userId", "createdAt", "updatedAt")
                VALUES %s
                ON CONFLICT DO NOTHING
            ''', activations_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, 'system', NOW(), NOW())")
            cur.close()
            return
        except (psycopg2.OperationalError, psycopg2.InterfaceError, psycopg2.errors.QueryCanceled) as e:
            print(f" [DB error attempt {attempt+1}/{max_retries}: {e}]", flush=True)
            # Force a fresh connection for the next attempt.
            reset_conn()
            if attempt == max_retries - 1:
                raise
182
+
183
+
184
# ``extract_benchmark`` and ``main`` live in
# ``wisent.scripts._helpers.extract_all_missing_helpers`` (split out to meet the
# 300-line limit). That module imports this one at its top, so importing it
# eagerly here forms an import cycle. Running this file as ``__main__``, or
# importing the helpers module first, then fails with "cannot import name ...
# from partially initialized module". The names are resolved lazily instead
# (module-level ``__getattr__``, PEP 562), which keeps
# ``from wisent.scripts.extract_all_missing import main`` working.
_LAZY_HELPER_NAMES = ("extract_benchmark", "main")


def __getattr__(name):
    """Lazily resolve ``extract_benchmark`` / ``main`` from the helpers module."""
    if name in _LAZY_HELPER_NAMES:
        from wisent.scripts._helpers import extract_all_missing_helpers as _helpers
        return getattr(_helpers, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


if __name__ == "__main__":
    # Imported here (not at module level) so the helpers module can import
    # this file without hitting a partially initialized module.
    from wisent.scripts._helpers.extract_all_missing_helpers import main as _helpers_main

    _helpers_main()
@@ -0,0 +1,128 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract raw activations for ALL missing benchmarks with 3 prompt formats.
4
+
5
+ This script:
6
+ 1. Finds all benchmarks that have contrastive pairs in the database
7
+ 2. Checks which benchmarks are missing raw activations for the given model
8
+ 3. Extracts using 3 formats: chat, mc_balanced, role_play
9
+ 4. Stores to RawActivation table (full sequence hidden states)
10
+
11
+ Extracts every pair stored for each benchmark; pairs that already have all raw activations are skipped.
12
+
13
+ Usage:
14
+ python3 -m wisent.scripts.extract_raw_activations --model meta-llama/Llama-3.2-1B-Instruct --device cuda --max-retries 3 --log-interval 10
15
+ python3 -m wisent.scripts.extract_raw_activations --model Qwen/Qwen3-8B --benchmark knowledge_qa/mmlu --device cuda --max-retries 3 --log-interval 10
16
+ """
17
+
18
+ import argparse
19
+ import os
20
+ import sys
21
+ import time
22
+
23
+ print("[STARTUP] Starting extract_raw_activations.py...", flush=True)
24
+ print(f"[STARTUP] Python version: {sys.version}", flush=True)
25
+
26
+ print("[STARTUP] Importing psycopg2...", flush=True)
27
+ import psycopg2
28
+ print("[STARTUP] psycopg2 imported", flush=True)
29
+
30
+ print("[STARTUP] Importing torch...", flush=True)
31
+ import torch
32
+ print(f"[STARTUP] torch imported, version: {torch.__version__}, CUDA available: {torch.cuda.is_available()}", flush=True)
33
+
34
+ from wisent.scripts._helpers.extract_raw_helpers import extract_benchmark
35
+ from wisent.scripts._helpers.extract_raw_db import (
36
+ get_conn, reset_conn, get_or_create_model, get_missing_benchmarks,
37
+ )
38
+
39
+
40
def main():
    """CLI entry point: extract raw activations for missing benchmarks in 3 formats.

    Loads the tokenizer and model, resolves (or creates) the ``Model`` row,
    then extracts either the benchmark named by ``--benchmark`` or every
    benchmark reported incomplete by ``get_missing_benchmarks``. Exits with
    status 1 if ``--benchmark`` names an unknown ContrastivePairSet. The shared
    DB connection is closed on every exit path.
    """
    print("[MAIN] Parsing arguments...", flush=True)
    parser = argparse.ArgumentParser(description="Extract raw activations for all missing benchmarks with 3 formats")
    parser.add_argument("--model", required=True, help="Model name (e.g., meta-llama/Llama-3.2-1B-Instruct)")
    parser.add_argument("--device", required=True, help="Device (cuda/mps/cpu)")
    parser.add_argument("--benchmark", default=None, help="Single benchmark to extract (optional)")
    parser.add_argument("--max-retries", type=int, required=True, help="Maximum retry attempts for DB operations")
    parser.add_argument("--log-interval", type=int, required=True, help="Progress logging interval")
    args = parser.parse_args()
    print(f"[MAIN] Args: model={args.model}, device={args.device}, benchmark={args.benchmark}", flush=True)

    print("[MAIN] Importing transformers...", flush=True)
    from transformers import AutoTokenizer, AutoModelForCausalLM
    print("[MAIN] transformers imported", flush=True)

    print(f"[MAIN] Loading tokenizer for {args.model}...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    print(f"[MAIN] Tokenizer loaded, vocab_size={tokenizer.vocab_size}", flush=True)

    print(f"[MAIN] Loading model {args.model}...", flush=True)
    if args.device == "mps":
        model = AutoModelForCausalLM.from_pretrained(
            args.model, torch_dtype=torch.float32, trust_remote_code=True)
        model = model.to("mps")
        actual_device = "mps"
    else:
        num_gpus = torch.cuda.device_count()
        print(f"[MAIN] Detected {num_gpus} GPUs", flush=True)
        # Shard with device_map="auto" only when CUDA was requested and several
        # GPUs exist; otherwise load onto the requested device, so that
        # "--device cpu" stays on the CPU even on a multi-GPU host.
        wants_cuda = str(args.device).startswith("cuda")
        use_device_map = "auto" if wants_cuda and num_gpus > 1 else args.device
        print(f"[MAIN] Using device_map={use_device_map}", flush=True)
        model = AutoModelForCausalLM.from_pretrained(
            args.model, torch_dtype="auto", device_map=use_device_map, trust_remote_code=True)
        actual_device = next(model.parameters()).device
        print(f"[MAIN] Model device: {actual_device}", flush=True)
    model.eval()

    num_layers = model.config.num_hidden_layers
    print(f"[MAIN] Model loaded: {num_layers} layers, device={actual_device}", flush=True)

    # Store actual device for use in extraction
    model._actual_device = str(actual_device)

    print("[MAIN] Connecting to database...", flush=True)
    conn = get_conn()
    print("[MAIN] Database connected", flush=True)

    try:
        model_id = get_or_create_model(conn, args.model, num_layers)
        print(f"[MAIN] Model ID: {model_id}", flush=True)

        if args.benchmark:
            cur = conn.cursor()
            try:
                cur.execute('SELECT id FROM "ContrastivePairSet" WHERE name = %s', (args.benchmark,))
                result = cur.fetchone()
            finally:
                cur.close()
            if not result:
                print(f"ERROR: Benchmark {args.benchmark} not found", flush=True)
                # Non-zero status so callers/automation can detect the failure.
                sys.exit(1)
            set_id = result[0]

            print(f"\nExtracting single benchmark: {args.benchmark}", flush=True)
            extracted = extract_benchmark(model, tokenizer, model_id, args.benchmark, set_id,
                                          num_layers, args.device, get_conn, reset_conn, max_retries=args.max_retries, log_interval=args.log_interval)
            print(f"\nDone! Extracted {extracted} pairs", flush=True)
            return

        missing = get_missing_benchmarks(conn, model_id, num_layers)
        print(f"\nFound {len(missing)} benchmarks needing extraction", flush=True)

        if not missing:
            print("All benchmarks are fully extracted!", flush=True)
            return

        total_extracted = 0
        for i, (set_id, benchmark_name, pair_count) in enumerate(missing):
            print(f"\n[{i+1}/{len(missing)}] {benchmark_name} ({pair_count} pairs in DB)", flush=True)
            start = time.time()

            extracted = extract_benchmark(model, tokenizer, model_id, benchmark_name, set_id,
                                          num_layers, args.device, get_conn, reset_conn, max_retries=args.max_retries, log_interval=args.log_interval)

            total_extracted += extracted
            elapsed = time.time() - start
            print(f" Completed in {elapsed:.1f}s", flush=True)

        print(f"\n{'='*60}", flush=True)
        print(f"COMPLETE! Total extracted: {total_extracted} pairs across {len(missing)} benchmarks", flush=True)
    finally:
        # Close the shared connection on every exit path (including sys.exit).
        reset_conn()


if __name__ == "__main__":
    main()
@@ -0,0 +1,98 @@
1
+ """Fix the order of correct/incorrect answers in extractor files.
2
+
3
+ The correct order is:
4
+ A. {incorrect}
5
+ B. {correct}
6
+
7
+ This script:
8
+ 1. Finds files with the reversed order and fixes them
9
+ 2. Checks if evaluator_name == "log_likelihoods" and verifies A/B pattern is present
10
+ """
11
+
12
+ import re
13
+ from pathlib import Path
14
+
15
+ from wisent.core.utils.config_tools.constants import SEPARATOR_WIDTH_STANDARD
16
+
17
+
18
def fix_extractor_order(base_path: "Path | None" = None) -> "dict[str, list[Path]]":
    """Find and fix extractors with incorrect A/B order.

    Scans the ``lm_task_extractors`` and ``hf_task_extractors`` directories
    (non-recursively, skipping ``__init__.py``) for extractor modules whose
    prompt template lists ``{correct}`` as option A and ``{incorrect}`` as
    option B, and rewrites that template in place so option A is
    ``{incorrect}`` and option B is ``{correct}``. Separately, modules that set
    ``evaluator_name`` to ``"log_likelihood"``/``"log_likelihoods"`` but contain
    neither ordering of the A/B template are collected for manual review.
    A summary report is printed to stdout.

    Args:
        base_path: Root ``contrastive_pairs`` directory to scan. When omitted,
            defaults to ``<this script>/../../core/contrastive_pairs`` (the
            original hard-wired location).

    Returns:
        A dict with two lists of ``Path`` objects:
        ``"fixed"`` – files that were rewritten, and
        ``"missing_ab"`` – log-likelihood extractors without an A/B template.
    """
    if base_path is None:
        base_path = Path(__file__).parent.parent / "core" / "contrastive_pairs"
    base_path = Path(base_path)

    # Directories to search
    search_dirs = [
        base_path / "lm_eval_pairs" / "lm_task_extractors",
        base_path / "huggingface_pairs" / "hf_task_extractors",
    ]

    # The templates live inside Python string literals in the extractor
    # sources, so the "\n" being matched is the two characters backslash + "n",
    # not an actual newline.
    # Incorrect order: correct first, incorrect second.
    incorrect_order = re.compile(r'\\nA\. \{correct\}\\nB\. \{incorrect\}')
    # Correct order: incorrect first, correct second.
    correct_order = re.compile(r'\\nA\. \{incorrect\}\\nB\. \{correct\}')
    # In a re.sub replacement, "\\" emits a single literal backslash.
    correct_replacement = r'\\nA. {incorrect}\\nB. {correct}'
    # log_likelihood / log_likelihoods evaluator declaration.
    log_likelihood_evaluator = re.compile(
        r'evaluator_name\s*=\s*["\']log_likelihood[s]?["\']'
    )

    files_with_incorrect_order = []
    log_likelihood_missing_ab = []

    for search_dir in search_dirs:
        if not search_dir.exists():
            print(f"Directory not found: {search_dir}")
            continue

        for py_file in search_dir.glob("*.py"):
            if py_file.name == "__init__.py":
                continue

            # Work on raw bytes decoded as UTF-8: this avoids the
            # locale-dependent default encoding and, unlike read_text /
            # write_text, preserves the file's original line endings.
            content = py_file.read_bytes().decode("utf-8")

            # Fix the order (subn tells us whether anything matched).
            fixed_content, n_fixed = incorrect_order.subn(correct_replacement, content)
            if n_fixed:
                files_with_incorrect_order.append(py_file)
                py_file.write_bytes(fixed_content.encode("utf-8"))

            # Judge A/B presence on the pre-fix content: either ordering counts.
            has_ab_pattern = n_fixed > 0 or correct_order.search(content)
            if log_likelihood_evaluator.search(content) and not has_ab_pattern:
                log_likelihood_missing_ab.append(py_file)

    # Report results
    print("=" * SEPARATOR_WIDTH_STANDARD)
    print("EXTRACTOR ORDER FIX REPORT")
    print("=" * SEPARATOR_WIDTH_STANDARD)

    print(f"\n1. Files with incorrect order (A.correct/B.incorrect -> fixed): {len(files_with_incorrect_order)}")
    if files_with_incorrect_order:
        print("\n Fixed files:")
        for f in sorted(files_with_incorrect_order):
            print(f" - {f.name}")

    print(f"\n2. Files with log_likelihoods evaluator but MISSING A/B pattern: {len(log_likelihood_missing_ab)}")
    if log_likelihood_missing_ab:
        print("\n Missing A/B pattern:")
        for f in sorted(log_likelihood_missing_ab):
            print(f" - {f.name}")

    return {
        "fixed": files_with_incorrect_order,
        "missing_ab": log_likelihood_missing_ab,
    }


if __name__ == "__main__":
    fix_extractor_order()
@@ -0,0 +1,210 @@
1
#!/bin/bash
# Quality-metrics sweep across several benchmarks.
#
# For each benchmark (and each synthetic steering type) this runs the wisent
# steering-optimization pipeline and keeps its results, so quality metrics can
# later be correlated with steering effectiveness (delta).
#
# Output, written to $OUTPUT_DIR (default /home/ubuntu/output):
#   <name>_log.txt       - captured stdout/stderr of each run
#   <name>_metrics.json  - copy of the results JSON produced by each run
# combine_all_results (from _helpers/sweep_helpers.sh) is run at the end.
#
# Behaviour:
#   - save_intermediate_results is called after every benchmark
#     (per the original notes it pushes to GCS - see sweep_helpers.sh)
#   - runs whose results already exist are skipped, so the sweep can resume
#   - a failing benchmark is recorded and the sweep moves on
#
# Usage:
#   ./run_quality_metrics_sweep.sh

# No `set -e` on purpose: a single failing benchmark must not stop the sweep.
set -uo pipefail

# Configuration - every value can be overridden from the environment.
MODEL="${MODEL:-Qwen/Qwen2.5-0.5B-Instruct}"
OUTPUT_DIR="${OUTPUT_DIR:-/home/ubuntu/output}"
# NOTE(review): LAYER_RANGE is only echoed in the banner; the wisent commands
# in this script do not receive it. Confirm whether a layer flag is intended.
LAYER_RANGE="${LAYER_RANGE:-0-23}"
GCS_BUCKET="${GCS_BUCKET:-wisent-images-bucket}"

# Resume bookkeeping file (presumably used by the sourced helpers).
PROGRESS_FILE="$OUTPUT_DIR/.sweep_progress"

# Pull in save_intermediate_results, is_benchmark_completed, etc. from the
# helper file that sits next to this script.
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=_helpers/sweep_helpers.sh
source "$SCRIPT_DIR/_helpers/sweep_helpers.sh"

# Benchmarks whose answer pairs have a meaningful correct/incorrect split.
BENCHMARKS=(
    gsm8k arc_easy arc_challenge hellaswag winogrande
    truthfulqa_mc1 piqa boolq openbookqa livecodebench
)

# Synthetic controls used to check which metrics really predict steering:
#   british - British vs American English: expect good metrics AND an effect
#   random  - random pairs: expect bad metrics AND no effect
SYNTHETIC_TYPES=(british random)

printf '%s\n' \
    "==========================================" \
    "Quality Metrics Sweep" \
    "==========================================" \
    "Model: $MODEL" \
    "Output: $OUTPUT_DIR" \
    "Layer range: $LAYER_RANGE" \
    "Benchmarks: ${BENCHMARKS[*]}" \
    "Synthetic types: ${SYNTHETIC_TYPES[*]}" \
    "=========================================="
+
66
# Print the most recently modified regular file below directory $1 whose name
# matches the find(1) pattern $2; print nothing when there is no match or the
# directory does not exist. Replaces `find ... | head -1`, which returns
# whichever match the filesystem lists first and can therefore pick up a stale
# results file left behind by an earlier attempt.
_sweep_newest_file() {
    local dir="$1" pattern="$2" newest="" candidate
    while IFS= read -r -d '' candidate; do
        if [[ -z "$newest" || "$candidate" -nt "$newest" ]]; then
            newest="$candidate"
        fi
    done < <(find "$dir" -name "$pattern" -type f -print0 2>/dev/null)
    printf '%s' "$newest"
}

# Create output directory
mkdir -p "$OUTPUT_DIR"

# ==========================================
# Part 1: Run optimization for each BENCHMARK task
# ==========================================
echo ""
echo "=========================================="
echo "Part 1: Benchmark Tasks"
echo "=========================================="

FAILED_BENCHMARKS=()
COMPLETED_BENCHMARKS=()

for BENCHMARK in "${BENCHMARKS[@]}"; do
    echo ""
    echo "=========================================="
    echo "Running: $BENCHMARK"
    echo "=========================================="

    # Skip if already completed (resume support)
    if is_benchmark_completed "$BENCHMARK"; then
        echo "SKIPPING: $BENCHMARK already completed (found ${BENCHMARK}_metrics.json)"
        COMPLETED_BENCHMARKS+=("$BENCHMARK")
        continue
    fi

    BENCHMARK_START=$(date +%s)

    # Run the optimization using wisent CLI with baseline comparison.
    # With pipefail (set at the top of the script) the `if` fails when wisent
    # (or tee) fails, not only when tee does.
    if wisent optimize-steering comprehensive "$MODEL" \
        --tasks "$BENCHMARK" \
        --compute-baseline \
        --device cuda \
        --output-dir "$OUTPUT_DIR/$BENCHMARK" \
        2>&1 | tee "$OUTPUT_DIR/${BENCHMARK}_log.txt"; then

        BENCHMARK_END=$(date +%s)
        DURATION=$((BENCHMARK_END - BENCHMARK_START))
        echo "Completed $BENCHMARK in ${DURATION}s"

        # Copy the newest results file (ignores stale files from earlier attempts)
        RESULTS_FILE=$(_sweep_newest_file "$OUTPUT_DIR/$BENCHMARK" "steering_comprehensive_*.json")

        if [ -n "$RESULTS_FILE" ]; then
            echo "Results saved to: $RESULTS_FILE"
            cp "$RESULTS_FILE" "$OUTPUT_DIR/${BENCHMARK}_metrics.json"
            mark_benchmark_completed "$BENCHMARK"
            COMPLETED_BENCHMARKS+=("$BENCHMARK")
        else
            echo "WARNING: No results file found for $BENCHMARK"
            FAILED_BENCHMARKS+=("$BENCHMARK")
        fi
    else
        echo "ERROR: $BENCHMARK failed"
        FAILED_BENCHMARKS+=("$BENCHMARK")
    fi

    # Save intermediate results after each benchmark
    save_intermediate_results
done

# ==========================================
# Part 2: Run SYNTHETIC steering (british, random)
# These use the `personalization` subcommand with --trait
# ==========================================
echo ""
echo "=========================================="
echo "Part 2: Synthetic Steering Validation"
echo "=========================================="

for SYNTHETIC_TYPE in "${SYNTHETIC_TYPES[@]}"; do
    echo ""
    echo "=========================================="
    echo "Running synthetic: $SYNTHETIC_TYPE"
    echo "=========================================="

    # Skip if already completed
    if [ -f "$OUTPUT_DIR/synthetic_${SYNTHETIC_TYPE}_metrics.json" ]; then
        echo "SKIPPING: synthetic_$SYNTHETIC_TYPE already completed"
        COMPLETED_BENCHMARKS+=("synthetic_$SYNTHETIC_TYPE")
        continue
    fi

    SYNTHETIC_START=$(date +%s)
    SYNTHETIC_DIR="$OUTPUT_DIR/synthetic_${SYNTHETIC_TYPE}"

    # Run the optimization with personalization task
    if wisent optimize-steering personalization \
        --model "$MODEL" \
        --trait "$SYNTHETIC_TYPE" \
        --num-pairs 50 \
        --output-dir "$SYNTHETIC_DIR" \
        --device cuda \
        2>&1 | tee "$OUTPUT_DIR/synthetic_${SYNTHETIC_TYPE}_log.txt"; then

        SYNTHETIC_END=$(date +%s)
        DURATION=$((SYNTHETIC_END - SYNTHETIC_START))
        echo "Completed synthetic $SYNTHETIC_TYPE in ${DURATION}s"

        # Use the newest JSON in the run directory rather than an arbitrary one
        RESULTS_FILE=$(_sweep_newest_file "$SYNTHETIC_DIR" "*.json")

        if [ -n "$RESULTS_FILE" ]; then
            echo "Results saved to: $RESULTS_FILE"
            cp "$RESULTS_FILE" "$OUTPUT_DIR/synthetic_${SYNTHETIC_TYPE}_metrics.json"
            COMPLETED_BENCHMARKS+=("synthetic_$SYNTHETIC_TYPE")
        else
            echo "WARNING: No results file found for synthetic_$SYNTHETIC_TYPE"
            FAILED_BENCHMARKS+=("synthetic_$SYNTHETIC_TYPE")
        fi
    else
        echo "ERROR: synthetic_$SYNTHETIC_TYPE failed"
        FAILED_BENCHMARKS+=("synthetic_$SYNTHETIC_TYPE")
    fi

    # Save intermediate results after each synthetic
    save_intermediate_results
done
185
+
186
# ------------------------------------------------------------------
# Part 3: merge everything, print a summary, and push the final upload.
# combine_all_results and upload_final_to_gcs come from the sourced
# _helpers/sweep_helpers.sh.
# ------------------------------------------------------------------
printf '%s\n' "" \
    "==========================================" \
    "Combining Results" \
    "=========================================="

combine_all_results

printf '%s\n' "" \
    "==========================================" \
    "Sweep Complete!" \
    "==========================================" \
    "Results in: $OUTPUT_DIR"
ls -la "$OUTPUT_DIR"/*.json 2>/dev/null || echo "No JSON files found"

# Empty arrays fall back to "none" (also keeps `set -u` happy).
printf '%s\n' "" \
    "Completed benchmarks: ${COMPLETED_BENCHMARKS[*]:-none}" \
    "Failed benchmarks: ${FAILED_BENCHMARKS[*]:-none}" \
    ""

upload_final_to_gcs

echo "Done!"
@@ -0,0 +1,16 @@
1
+ Metadata-Version: 2.4
2
+ Name: wisent-tools
3
+ Version: 0.1.0
4
+ Summary: Operational scripts and benchmark-evaluation runners for the wisent package family
5
+ Home-page: https://github.com/wisent-ai/wisent-tools
6
+ Author: Lukasz Bartoszcze and the Wisent Team
7
+ Author-email: lukasz.bartoszcze@wisent.ai
8
+ Requires-Python: >=3.9
9
+ Requires-Dist: wisent>=0.10.0
10
+ Requires-Dist: wisent-evaluators>=0.1.0
11
+ Dynamic: author
12
+ Dynamic: author-email
13
+ Dynamic: home-page
14
+ Dynamic: requires-dist
15
+ Dynamic: requires-python
16
+ Dynamic: summary
@@ -0,0 +1,18 @@
1
+ README.md
2
+ pyproject.toml
3
+ setup.py
4
+ wisent/__init__.py
5
+ wisent/scripts/__init__.py
6
+ wisent/scripts/extract_all_missing.py
7
+ wisent/scripts/extract_raw_activations.py
8
+ wisent/scripts/fix_extractor_order.py
9
+ wisent/scripts/run_quality_metrics_sweep.sh
10
+ wisent/scripts/_helpers/__init__.py
11
+ wisent/scripts/_helpers/extract_all_missing_helpers.py
12
+ wisent/scripts/_helpers/extract_raw_db.py
13
+ wisent/scripts/_helpers/extract_raw_helpers.py
14
+ wisent_tools.egg-info/PKG-INFO
15
+ wisent_tools.egg-info/SOURCES.txt
16
+ wisent_tools.egg-info/dependency_links.txt
17
+ wisent_tools.egg-info/requires.txt
18
+ wisent_tools.egg-info/top_level.txt
@@ -0,0 +1,2 @@
1
+ wisent>=0.10.0
2
+ wisent-evaluators>=0.1.0