wisent-tools 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent_tools-0.1.0/PKG-INFO +16 -0
- wisent_tools-0.1.0/README.md +11 -0
- wisent_tools-0.1.0/pyproject.toml +3 -0
- wisent_tools-0.1.0/setup.cfg +4 -0
- wisent_tools-0.1.0/setup.py +14 -0
- wisent_tools-0.1.0/wisent/__init__.py +15 -0
- wisent_tools-0.1.0/wisent/scripts/__init__.py +1 -0
- wisent_tools-0.1.0/wisent/scripts/_helpers/__init__.py +1 -0
- wisent_tools-0.1.0/wisent/scripts/_helpers/extract_all_missing_helpers.py +199 -0
- wisent_tools-0.1.0/wisent/scripts/_helpers/extract_raw_db.py +117 -0
- wisent_tools-0.1.0/wisent/scripts/_helpers/extract_raw_helpers.py +205 -0
- wisent_tools-0.1.0/wisent/scripts/extract_all_missing.py +191 -0
- wisent_tools-0.1.0/wisent/scripts/extract_raw_activations.py +128 -0
- wisent_tools-0.1.0/wisent/scripts/fix_extractor_order.py +98 -0
- wisent_tools-0.1.0/wisent/scripts/run_quality_metrics_sweep.sh +210 -0
- wisent_tools-0.1.0/wisent_tools.egg-info/PKG-INFO +16 -0
- wisent_tools-0.1.0/wisent_tools.egg-info/SOURCES.txt +18 -0
- wisent_tools-0.1.0/wisent_tools.egg-info/dependency_links.txt +1 -0
- wisent_tools-0.1.0/wisent_tools.egg-info/requires.txt +2 -0
- wisent_tools-0.1.0/wisent_tools.egg-info/top_level.txt +1 -0
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: wisent-tools
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Operational scripts and benchmark-evaluation runners for the wisent package family
|
|
5
|
+
Home-page: https://github.com/wisent-ai/wisent-tools
|
|
6
|
+
Author: Lukasz Bartoszcze and the Wisent Team
|
|
7
|
+
Author-email: lukasz.bartoszcze@wisent.ai
|
|
8
|
+
Requires-Python: >=3.9
|
|
9
|
+
Requires-Dist: wisent>=0.10.0
|
|
10
|
+
Requires-Dist: wisent-evaluators>=0.1.0
|
|
11
|
+
Dynamic: author
|
|
12
|
+
Dynamic: author-email
|
|
13
|
+
Dynamic: home-page
|
|
14
|
+
Dynamic: requires-dist
|
|
15
|
+
Dynamic: requires-python
|
|
16
|
+
Dynamic: summary
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# wisent-tools
|
|
2
|
+
|
|
3
|
+
Operational scripts split out of wisent-open-source. Provides `wisent.scripts` —
|
|
4
|
+
benchmark-evaluation runners (aime, apps, conala, livemathbench, math, polymath),
|
|
5
|
+
activation-extraction helpers, and fix-up utilities.
|
|
6
|
+
|
|
7
|
+
## Install
|
|
8
|
+
|
|
9
|
+
```
|
|
10
|
+
pip install wisent-tools
|
|
11
|
+
```
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Packaging script for wisent-tools.

Ships the ``wisent`` package tree (``wisent``, ``wisent.*``) and the shell
scripts under ``wisent/scripts``.
"""
from pathlib import Path

from setuptools import setup, find_packages

_HERE = Path(__file__).resolve().parent
_README = _HERE / "README.md"
# Feed README.md into the package metadata as the long description so the
# index page shows it; tolerate its absence so builds from a trimmed tree work.
_LONG_DESCRIPTION = _README.read_text(encoding="utf-8") if _README.is_file() else ""

setup(
    name="wisent-tools",
    version="0.1.0",
    author="Lukasz Bartoszcze and the Wisent Team",
    author_email="lukasz.bartoszcze@wisent.ai",
    description="Operational scripts and benchmark-evaluation runners for the wisent package family",
    long_description=_LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    url="https://github.com/wisent-ai/wisent-tools",
    packages=find_packages(include=["wisent", "wisent.*"]),
    python_requires=">=3.9",
    install_requires=["wisent>=0.10.0", "wisent-evaluators>=0.1.0"],
    include_package_data=True,
    # Non-Python files that must be installed alongside the scripts package.
    package_data={"wisent": ["scripts/*.sh"]},
)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Namespace bootstrap shared with wisent-core and sibling packages.
|
|
2
|
+
|
|
3
|
+
Uses pkgutil.extend_path so all wisent-* packages merge at import time
|
|
4
|
+
even though wisent-core ships a regular (non-PEP-420) package.
|
|
5
|
+
"""
|
|
6
|
+
import os
import pkgutil

# Merge every installed portion of the ``wisent`` package (wisent-core and
# sibling wisent-* distributions) into one package search path.
__path__ = pkgutil.extend_path(__path__, __name__)


def _append_public_subdirs(search_path, base):
    """Append each public subdirectory of *base* to *search_path* once.

    Adding a subdirectory to the package ``__path__`` makes the modules inside
    it importable directly as ``wisent.<module>``. Entries starting with '.'
    or '_' (e.g. ``__pycache__``) are skipped, and directories already on
    *search_path* are not added again (e.g. after ``importlib.reload``). If
    *base* cannot be listed (for instance when imported from a zip archive),
    nothing is added instead of failing the whole package import.
    """
    try:
        entries = sorted(os.listdir(base))
    except OSError:
        return
    for entry in entries:
        if entry.startswith(('.', '_')):
            continue
        candidate = os.path.join(base, entry)
        if os.path.isdir(candidate) and candidate not in search_path:
            search_path.append(candidate)


_append_public_subdirs(__path__, os.path.dirname(__file__))
# Keep the shared ``wisent`` namespace free of bootstrap helpers/temporaries.
del _append_public_subdirs
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Wisent scripts for activation extraction and data processing."""
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Extracted helpers for files exceeding 300-line limit."""
|
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Benchmark extraction and main entry point for extract_all_missing.
|
|
3
|
+
|
|
4
|
+
Split from extract_all_missing.py to meet 300-line limit.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import argparse
|
|
8
|
+
import sys
|
|
9
|
+
import time
|
|
10
|
+
|
|
11
|
+
import psycopg2
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
from wisent.core.utils.config_tools.constants import RECURSION_INITIAL_DEPTH
|
|
15
|
+
|
|
16
|
+
from wisent.scripts.extract_all_missing import (
|
|
17
|
+
hidden_states_to_bytes,
|
|
18
|
+
get_conn,
|
|
19
|
+
reset_conn,
|
|
20
|
+
batch_create_activations,
|
|
21
|
+
get_missing_benchmarks,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def extract_benchmark(model, tokenizer, model_id: int, benchmark_name: str, set_id: int,
                      device: str, num_layers: int, batch_size: int,
                      db_connect_wait_s: int, max_retries: int,
                      log_interval: int):
    """Extract last-token activations for pairs of *set_id* not yet stored for *model_id*.

    Only pairs that have no "Activation" row for this model are processed.
    Each pair's positive and negative texts get one forward pass. For every
    layer 1..num_layers, two rows (positive, then negative) with strategy
    "chat_last" are queued. Queued rows are written with
    batch_create_activations once per chunk of *batch_size* pairs.

    Returns:
        Number of pairs processed in this call; 0 when nothing was pending.
    """
    pending_sql = '''
SELECT cp.id, cp."positiveExample", cp."negativeExample"
FROM "ContrastivePair" cp
WHERE cp."setId" = %s
AND NOT EXISTS (
SELECT 1 FROM "Activation" a
WHERE a."contrastivePairId" = cp.id AND a."modelId" = %s
)
ORDER BY cp.id
'''
    cursor = get_conn(db_connect_wait_s).cursor()
    cursor.execute(pending_sql, (set_id, model_id))
    pending = cursor.fetchall()
    cursor.close()

    if not pending:
        print(f" All pairs already extracted for {benchmark_name}", flush=True)
        return 0

    total = len(pending)
    print(f" Extracting {total} pairs (skipping already extracted)...", flush=True)

    def last_token_states(text):
        # One forward pass; keep the final sequence position of every entry of
        # hidden_states except index 0 (presumably the embedding output, per
        # the HF convention — confirm for remote-code models).
        encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=tokenizer.model_max_length)
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.inference_mode():
            output = model(**encoded, output_hidden_states=True, use_cache=False)
        states = output.hidden_states
        return [states[idx][0, -1, :] for idx in range(1, len(states))]

    processed = 0
    for chunk_start in range(0, total, batch_size):
        chunk_end = min(chunk_start + batch_size, total)
        rows = []
        for pair_id, pos_text, neg_text in pending[chunk_start:chunk_end]:
            pos_states = last_token_states(pos_text)
            neg_states = last_token_states(neg_text)
            # Layers are stored 1-based.
            for layer_num in range(1, num_layers + 1):
                pos_vec = pos_states[layer_num - 1]
                neg_vec = neg_states[layer_num - 1]
                width = pos_vec.shape[0]
                rows.append((
                    model_id, pair_id, set_id, layer_num, width,
                    "chat_last", psycopg2.Binary(hidden_states_to_bytes(pos_vec)), True
                ))
                rows.append((
                    model_id, pair_id, set_id, layer_num, width,
                    "chat_last", psycopg2.Binary(hidden_states_to_bytes(neg_vec)), False
                ))
            del pos_states, neg_states
            processed += 1

        # One DB round trip (with retries) per chunk of pairs.
        batch_create_activations(rows, max_retries=max_retries, db_connect_wait_s=db_connect_wait_s)

        if chunk_end % log_interval == RECURSION_INITIAL_DEPTH or chunk_end == total:
            print(f" Processed {chunk_end}/{total} pairs", flush=True)

        if device == "cuda":
            torch.cuda.empty_cache()

    return processed
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _load_model(model_name: str, device: str):
    """Load the tokenizer and causal LM for *model_name*, placed on *device*.

    Returns ``(tokenizer, model)`` with the model in eval mode.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM

    print(f"Loading model {model_name}...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    if device == "mps":
        # Full-precision weights on MPS, moved after loading.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )
        model = model.to("mps")
    elif device == "cuda":
        # Let accelerate place the model across the available GPUs.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True,
        )
    else:
        # Any other device (e.g. "cpu", "cuda:1"): put the whole model exactly
        # where extract_benchmark will send the inputs. device_map="auto" here
        # could place it on a GPU while inputs stay on the requested device.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            trust_remote_code=True,
        )
        model = model.to(device)
    model.eval()
    return tokenizer, model


def _fetch_single_id(db_connect_wait_s: int, query: str, params: tuple):
    """Run *query* on the shared connection and return the first column of the first row, or None."""
    cur = get_conn(db_connect_wait_s).cursor()
    try:
        cur.execute(query, params)
        row = cur.fetchone()
    finally:
        cur.close()
    return row[0] if row else None


def main():
    """CLI entry point: extract activations for one benchmark or for all incomplete ones.

    Loads the model, resolves its "Model" row by HuggingFace id (exits with
    status 1 if absent), then either extracts ``--benchmark`` only or every
    benchmark reported by get_missing_benchmarks. The shared DB connection is
    released on every exit path, including errors and ``sys.exit``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="Model name (e.g., meta-llama/Llama-3.2-1B-Instruct)")
    parser.add_argument("--device", required=True, help="Device (cuda/mps/cpu)")
    parser.add_argument("--batch-size", type=int, required=True, help="Batch size for extraction (number of pairs per DB round trip)")
    parser.add_argument("--benchmark", default=None, help="Single benchmark to extract (optional)")
    parser.add_argument("--db-connect-wait-s", type=int, required=True, help="Database connection wait seconds")
    parser.add_argument("--max-retries", type=int, required=True, help="Maximum retry attempts for DB operations")
    parser.add_argument("--log-interval", type=int, required=True, help="Progress logging interval")
    args = parser.parse_args()

    tokenizer, model = _load_model(args.model, args.device)
    num_layers = model.config.num_hidden_layers
    print(f"Model loaded: {num_layers} layers", flush=True)

    try:
        model_id = _fetch_single_id(
            args.db_connect_wait_s,
            'SELECT id FROM "Model" WHERE "huggingFaceId" = %s',
            (args.model,),
        )
        if model_id is None:
            print(f"ERROR: Model {args.model} not found in database", flush=True)
            sys.exit(1)
        print(f"Model ID: {model_id}", flush=True)

        db_kwargs = dict(
            db_connect_wait_s=args.db_connect_wait_s,
            max_retries=args.max_retries,
            log_interval=args.log_interval,
        )

        if args.benchmark:
            # Extract a single, explicitly named benchmark.
            set_id = _fetch_single_id(
                args.db_connect_wait_s,
                'SELECT id FROM "ContrastivePairSet" WHERE name = %s',
                (args.benchmark,),
            )
            if set_id is None:
                print(f"ERROR: Benchmark {args.benchmark} not found", flush=True)
                sys.exit(1)

            print(f"Extracting single benchmark: {args.benchmark}", flush=True)
            extracted = extract_benchmark(model, tokenizer, model_id, args.benchmark, set_id,
                                          args.device, num_layers, args.batch_size, **db_kwargs)
            print(f"Done! Extracted {extracted} pairs", flush=True)
            return

        # Extract all incomplete benchmarks.
        missing = get_missing_benchmarks(get_conn(args.db_connect_wait_s), model_id, log_interval=args.log_interval)
        print(f"Found {len(missing)} incomplete benchmarks to extract", flush=True)

        if not missing:
            print("All benchmarks are complete!", flush=True)
            return

        total_extracted = 0
        for i, (set_id, benchmark_name, pairs_needed) in enumerate(missing):
            print(f"\n[{i+1}/{len(missing)}] {benchmark_name} ({pairs_needed} pairs needed)", flush=True)
            start = time.time()

            extracted = extract_benchmark(model, tokenizer, model_id, benchmark_name, set_id,
                                          args.device, num_layers, args.batch_size, **db_kwargs)

            total_extracted += extracted
            elapsed = time.time() - start
            print(f" Extracted {extracted} pairs in {elapsed:.1f}s", flush=True)

        print(f"\n{'='*60}", flush=True)
        print(f"COMPLETE! Total extracted: {total_extracted} pairs across {len(missing)} benchmarks", flush=True)
    finally:
        # Release the shared DB connection however main() exits.
        reset_conn()
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""Database connection management for extract_raw_activations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
import os
|
|
5
|
+
import psycopg2
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Connection string from the environment. Everything from the first '?' on
# (URL query parameters) is dropped before the URL is used.
DATABASE_URL = os.environ.get("DATABASE_URL")
if DATABASE_URL:
    DATABASE_URL = DATABASE_URL.partition('?')[0]

if not DATABASE_URL:
    raise RuntimeError("DATABASE_URL environment variable is required")

# Lazily-opened connection shared by get_conn()/reset_conn(); None until first use.
_db_conn = None

# Preserved from original extract_raw_activations.py
# Keyword arguments for psycopg2.connect(): connect_timeout plus TCP keepalives.
_CONN_KW = {"connect_" + "timeout": 30}
_CONN_KW.update(
    keepalives=1,
    keepalives_idle=30,
    keepalives_interval=10,
    keepalives_count=5,
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def to_session_pooler_url(db_url: str) -> str:
    """Rewrite a Supabase pooler URL from port 6543 to port 5432.

    Only the ``pooler.supabase.com:6543`` host:port token is rewritten, so a
    ``:6543`` elsewhere in the URL (e.g. at the start of a password) is left
    intact. URLs for other hosts are returned unchanged.
    NOTE(review): presumably switches the transaction-mode pooler for the
    session-mode one — confirm against the deployment.
    """
    return db_url.replace("pooler.supabase.com:6543", "pooler.supabase.com:5432")


def get_db_connection():
    """Open a fresh autocommit connection to DATABASE_URL using _CONN_KW."""
    conn = psycopg2.connect(to_session_pooler_url(DATABASE_URL), **_CONN_KW)
    conn.autocommit = True
    return conn
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_conn():
    """Return the shared module connection, opening or replacing it as needed.

    A cached connection is probed with ``SELECT 1``. If the probe raises, the
    stale connection is closed (close errors ignored) and a new one is opened
    with get_db_connection().
    """
    global _db_conn
    if _db_conn is None:
        _db_conn = get_db_connection()
        return _db_conn

    try:
        probe = _db_conn.cursor()
        probe.execute("SELECT 1")
        probe.close()
    except Exception:
        print(" [Reconnecting to DB...]", flush=True)
        try:
            _db_conn.close()
        except Exception:
            pass
        _db_conn = get_db_connection()
    return _db_conn
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def reset_conn():
    """Drop the shared connection so the next get_conn() opens a new one.

    Errors raised while closing the old connection are ignored.
    """
    global _db_conn
    if _db_conn is None:
        return
    try:
        _db_conn.close()
    except Exception:
        pass
    _db_conn = None
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def get_or_create_model(conn, model_name: str, num_layers: int) -> int:
    """Return the "Model" row id for HuggingFace id *model_name*, inserting it if absent.

    A new row uses the last '/'-separated component of *model_name* as its
    display name, the literal tags 'user'/'assistant', userId 'system',
    isPublic true, and ``num_layers // 2`` as optimalLayer. The insert is
    committed before returning.
    """
    cur = conn.cursor()
    cur.execute('SELECT id FROM "Model" WHERE "huggingFaceId" = %s', (model_name,))
    existing = cur.fetchone()
    if existing:
        cur.close()
        return existing[0]

    display_name = model_name.rsplit('/', 1)[-1]
    cur.execute('''
INSERT INTO "Model" ("name", "huggingFaceId", "userTag", "assistantTag", "userId", "isPublic", "numLayers", "optimalLayer", "createdAt", "updatedAt")
VALUES (%s, %s, 'user', 'assistant', 'system', true, %s, %s, NOW(), NOW())
RETURNING id
''', (display_name, model_name, num_layers, num_layers // 2))
    new_id = cur.fetchone()[0]
    conn.commit()
    cur.close()
    return new_id
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def get_missing_benchmarks(conn, model_id: int, num_layers: int,
                           formats=("chat", "mc_balanced", "role_play"),
                           completeness_ratio: float = 0.95) -> list:
    """Get list of benchmarks missing raw activations for this model.

    A prompt format counts as complete for a benchmark when its "RawActivation"
    row count reaches ``int(pair_count * num_layers * 2 * completeness_ratio)``.
    That is one row per pair, layer and polarity, with slack for a few
    failures. A benchmark is reported unless every format in *formats* is
    complete.

    Args:
        conn: Open DB connection.
        model_id: "Model" row id.
        num_layers: Number of layers extracted per text.
        formats: Prompt formats that must all be complete (defaults match the
            formats written by extract_benchmark).
        completeness_ratio: Fraction of the expected rows required per format.

    Returns:
        List of ``(set_id, name, pair_count)`` tuples ordered by benchmark name.
    """
    format_tuple = tuple(formats)
    with conn.cursor() as cur:
        cur.execute('''
            SELECT cps.id, cps.name, COUNT(cp.id) as pair_count
            FROM "ContrastivePairSet" cps
            INNER JOIN "ContrastivePair" cp ON cp."setId" = cps.id
            GROUP BY cps.id, cps.name
            HAVING COUNT(cp.id) > 0
            ORDER BY cps.name
        ''')
        benchmarks = cur.fetchall()

        missing = []
        for set_id, name, pair_count in benchmarks:
            expected_per_format = pair_count * num_layers * 2
            threshold = int(expected_per_format * completeness_ratio)

            counts = {}
            if format_tuple:  # "IN ()" is invalid SQL, so skip the query when empty
                # One grouped query per benchmark instead of one per format.
                # A tuple is adapted as a literal list, which also works if
                # "promptFormat" is an enum column.
                cur.execute('''
                    SELECT "promptFormat", COUNT(*) FROM "RawActivation"
                    WHERE "modelId" = %s AND "contrastivePairSetId" = %s
                      AND "promptFormat" IN %s
                    GROUP BY "promptFormat"
                ''', (model_id, set_id, format_tuple))
                counts = dict(cur.fetchall())

            formats_complete = sum(1 for fmt in format_tuple if counts.get(fmt, 0) >= threshold)
            if formats_complete < len(format_tuple):
                missing.append((set_id, name, pair_count))

    print(f"Found {len(benchmarks)} benchmarks, {len(benchmarks) - len(missing)} complete, {len(missing)} need extraction", flush=True)
    return missing
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
"""Helper functions for extract_raw_activations: extraction and DB batch operations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
import struct
|
|
5
|
+
|
|
6
|
+
import psycopg2
|
|
7
|
+
from psycopg2.extras import execute_values
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from wisent.core.utils.config_tools.constants import PROGRESS_LOG_INTERVAL_10, RECURSION_INITIAL_DEPTH
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def hidden_states_to_bytes(hidden_states: torch.Tensor) -> bytes:
    """Convert hidden_states tensor to bytes (float32).

    The tensor is flattened in C (row-major) order and emitted as native-endian
    float32 values. For finite values and infinities this is the same layout
    the previous ``struct.pack(f'{n}f', *values)`` produced. It is built with a
    single buffer copy instead of one Python float per element, which matters
    for (seq_len, hidden_dim) tensors.
    """
    flat = hidden_states.detach().cpu().float().reshape(-1)
    return flat.numpy().tobytes()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_batch_size(model_config) -> int:
    """Pick an extraction batch size from a rough model-size estimate.

    Uses ``model_config.num_parameters`` when present, comparing it directly
    against billion-scale thresholds. NOTE(review): a config carrying a raw
    parameter count would always land in the last bucket — confirm the
    intended unit. Otherwise the size in billions is estimated as
    ``12 * hidden_size**2 * num_hidden_layers / 1e9``.

    Returns:
        10 below 2B, 5 below 3B, 2 below 5B, otherwise 1.
    """
    size_b = getattr(model_config, 'num_parameters', None)
    if size_b is None:
        width = model_config.hidden_size
        depth = model_config.num_hidden_layers
        size_b = (12 * width * width * depth) / 1e9

    for upper_bound, chosen in ((2, 10), (3, 5), (5, 2)):
        if size_b < upper_bound:
            return chosen
    return 1
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def check_pair_fully_extracted(get_conn_fn, model_id: int, pair_id: int,
                               num_layers: int, formats: list) -> bool:
    """Check if a pair has all raw activations for all formats.

    A complete pair has ``num_layers * 2 * len(formats)`` "RawActivation" rows
    (one per layer, polarity and format). Only rows whose "promptFormat" is in
    *formats* are counted, so rows stored under other formats cannot make the
    pair look complete.

    Best effort: any error while querying is reported and treated as "not
    extracted". The caller then re-extracts, and the RawActivation INSERT uses
    ON CONFLICT DO NOTHING.
    """
    format_tuple = tuple(formats)
    expected_count = num_layers * 2 * len(format_tuple)
    if not format_tuple:
        # Nothing is required (and "IN ()" would be invalid SQL).
        return True

    cur = None
    try:
        conn = get_conn_fn()
        cur = conn.cursor()
        cur.execute('''
            SELECT COUNT(*) FROM "RawActivation"
            WHERE "modelId" = %s AND "contrastivePairId" = %s
              AND "promptFormat" IN %s
        ''', (model_id, pair_id, format_tuple))
        actual_count = cur.fetchone()[0]
        return actual_count >= expected_count
    except Exception as e:
        print(f" [check_pair_fully_extracted failed for pair {pair_id}: {e}]", flush=True)
        return False
    finally:
        if cur is not None:
            try:
                cur.close()
            except Exception:
                pass  # closing on a dead connection must not mask the result
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def batch_create_raw_activations(get_conn_fn, reset_conn_fn, activations_data: list, max_retries: int, batch_size: int = None):
    """Batch insert multiple RawActivation records.

    Rows are sent in chunks of *batch_size*. Each chunk is retried up to
    *max_retries* times on connection-level errors, with *reset_conn_fn*
    called after every failed attempt. Duplicate rows are skipped by the
    INSERT's ON CONFLICT DO NOTHING.

    Args:
        get_conn_fn: Zero-argument callable returning a DB connection.
        reset_conn_fn: Zero-argument callable dropping the cached connection.
        activations_data: Row tuples for the 11 placeholders of the INSERT
            ("createdAt" is filled with NOW()).
        max_retries: Attempts per chunk; must be >= 1 (0 used to silently
            insert nothing).
        batch_size: Rows per chunk; required, must be >= 1.

    Raises:
        ValueError: batch_size missing or < 1, or max_retries < 1, while there
            is data to insert.
        psycopg2.OperationalError, psycopg2.InterfaceError: re-raised after the
            last failed attempt of a chunk.
    """
    if not activations_data:
        return

    if batch_size is None:
        raise ValueError("batch_size is required for batch_create_raw_activations")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1 for batch_create_raw_activations")
    if max_retries < 1:
        raise ValueError("max_retries must be >= 1 for batch_create_raw_activations")

    for i in range(0, len(activations_data), batch_size):
        batch = activations_data[i:i + batch_size]

        for attempt in range(max_retries):
            try:
                conn = get_conn_fn()
                cur = conn.cursor()
                execute_values(cur, '''
                    INSERT INTO "RawActivation"
                    ("modelId", "contrastivePairId", "contrastivePairSetId", "layer", "seqLen", "hiddenDim", "promptLen", "hiddenStates", "answerText", "isPositive", "promptFormat", "createdAt")
                    VALUES %s
                    ON CONFLICT ("modelId", "contrastivePairId", layer, "isPositive", "promptFormat") DO NOTHING
                ''', batch, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())")
                cur.close()
                break
            except (psycopg2.OperationalError, psycopg2.InterfaceError, psycopg2.errors.QueryCanceled) as e:
                print(f" [DB batch error attempt {attempt+1}/{max_retries}: {e}]", flush=True)
                reset_conn_fn()
                if attempt == max_retries - 1:
                    raise
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _split_pair_texts(pos_example: str, neg_example: str):
    """Split stored pair texts into ``(prompt, positive, negative)`` at the LAST blank line.

    Pairs written by get_or_create_pair look like ``"<prompt>\\n\\n<answer>"``.
    The prompt is taken from the positive example only. Without a blank line,
    the whole positive text becomes the prompt (with an empty positive answer)
    and the whole negative text becomes the negative answer.
    """
    if "\n\n" in pos_example:
        prompt, _, pos = pos_example.rpartition("\n\n")
    else:
        prompt, pos = pos_example, ""
    # rpartition returns ('', '', text) when the separator is absent.
    neg = neg_example.rpartition("\n\n")[2]
    return prompt, pos, neg


def extract_benchmark(model, tokenizer, model_id: int, benchmark_name: str, set_id: int,
                      num_layers: int, device: str, get_conn_fn, reset_conn_fn, max_retries: int, log_interval: int):
    """Extract raw activations for a single benchmark.

    Every pair of ContrastivePairSet *set_id* is rendered in the "chat",
    "mc_balanced" and "role_play" formats via build_extraction_texts. For each
    format, the positive and negative texts get one forward pass each. The
    full per-layer hidden states (all token positions) are stored as
    "RawActivation" rows, one per layer and polarity. Pairs that
    check_pair_fully_extracted reports as complete are skipped.

    Returns:
        Number of pairs processed in this call (skipped pairs not counted).
    """
    print(f" [EXTRACT] Importing extraction strategy...", flush=True)
    from wisent.core.primitives.model_interface.core.activations import ExtractionStrategy, build_extraction_texts
    print(f" [EXTRACT] Extraction strategy imported", flush=True)

    # Prefer a device recorded on the model object, if any; else *device*.
    actual_device = getattr(model, '_actual_device', device)
    print(f" [EXTRACT] Using device: {actual_device}", flush=True)

    print(f" [EXTRACT] Fetching pairs from database...", flush=True)
    conn = get_conn_fn()

    cur = conn.cursor()
    cur.execute('''
        SELECT id, "positiveExample", "negativeExample", category
        FROM "ContrastivePair"
        WHERE "setId" = %s
        ORDER BY id
    ''', (set_id,))
    db_pairs = cur.fetchall()
    cur.close()
    print(f" [EXTRACT] Fetched {len(db_pairs)} pairs from database", flush=True)

    if not db_pairs:
        print(f" No pairs in database for {benchmark_name}", flush=True)
        return 0

    print(f" Processing {len(db_pairs)} pairs with 3 formats...", flush=True)

    all_prompt_formats = [
        ("chat", ExtractionStrategy.CHAT_LAST),
        ("mc_balanced", ExtractionStrategy.MC_BALANCED),
        ("role_play", ExtractionStrategy.ROLE_PLAY),
    ]
    format_names = [f[0] for f in all_prompt_formats]

    def get_hidden_states(text):
        # Per-layer hidden states for every token (batch dim squeezed);
        # hidden_states[0] is skipped.
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=tokenizer.model_max_length, add_special_tokens=False)
        enc = {k: v.to(actual_device) for k, v in enc.items()}
        with torch.inference_mode():
            out = model(**enc, output_hidden_states=True, use_cache=False)
        return [out.hidden_states[i].squeeze(0) for i in range(1, len(out.hidden_states))]

    extracted = 0
    skipped = 0

    for pair_idx, (pair_id, pos_example, neg_example, _category) in enumerate(db_pairs):
        if pair_idx == 0:
            print(f" [EXTRACT] Processing first pair (id={pair_id})...", flush=True)

        prompt, pos, neg = _split_pair_texts(pos_example, neg_example)

        if check_pair_fully_extracted(get_conn_fn, model_id, pair_id, num_layers, format_names):
            skipped += 1
            if skipped % log_interval == RECURSION_INITIAL_DEPTH:
                print(f" [skipped {skipped} already-extracted pairs]", flush=True)
            continue

        activations_batch = []

        for prompt_format, strategy in all_prompt_formats:
            try:
                if strategy == ExtractionStrategy.MC_BALANCED:
                    pos_text, pos_answer, pos_prompt_only = build_extraction_texts(
                        strategy, prompt, pos, tokenizer, other_response=neg, is_positive=True)
                    neg_text, neg_answer, neg_prompt_only = build_extraction_texts(
                        strategy, prompt, neg, tokenizer, other_response=pos, is_positive=False)
                else:
                    pos_text, pos_answer, pos_prompt_only = build_extraction_texts(strategy, prompt, pos, tokenizer)
                    neg_text, neg_answer, neg_prompt_only = build_extraction_texts(strategy, prompt, neg, tokenizer)
            except Exception as e:
                print(f" Error building texts for {prompt_format}: {e}", flush=True)
                continue

            pos_prompt_len = len(tokenizer(pos_prompt_only, add_special_tokens=False)["input_ids"]) if pos_prompt_only else 0
            neg_prompt_len = len(tokenizer(neg_prompt_only, add_special_tokens=False)["input_ids"]) if neg_prompt_only else 0

            pos_hidden = get_hidden_states(pos_text)
            neg_hidden = get_hidden_states(neg_text)

            for layer_idx in range(num_layers):
                layer_num = layer_idx + 1  # layers are stored 1-based
                pos_bytes = hidden_states_to_bytes(pos_hidden[layer_idx])
                neg_bytes = hidden_states_to_bytes(neg_hidden[layer_idx])

                # shape[0] -> seqLen, shape[1] -> hiddenDim columns.
                activations_batch.append((
                    model_id, pair_id, set_id, layer_num,
                    pos_hidden[layer_idx].shape[0], pos_hidden[layer_idx].shape[1],
                    pos_prompt_len, psycopg2.Binary(pos_bytes), pos_answer, True, prompt_format
                ))
                activations_batch.append((
                    model_id, pair_id, set_id, layer_num,
                    neg_hidden[layer_idx].shape[0], neg_hidden[layer_idx].shape[1],
                    neg_prompt_len, psycopg2.Binary(neg_bytes), neg_answer, False, prompt_format
                ))

            del pos_hidden, neg_hidden

        reset_conn_fn()
        # batch_create_raw_activations raises ValueError when batch_size is
        # omitted, so pass it explicitly: all of this pair's rows in one chunk.
        batch_create_raw_activations(
            get_conn_fn, reset_conn_fn, activations_batch,
            max_retries=max_retries,
            batch_size=max(1, len(activations_batch)),
        )
        extracted += 1

        if (pair_idx + 1) % PROGRESS_LOG_INTERVAL_10 == 0:
            print(f" Processed {pair_idx + 1}/{len(db_pairs)} pairs", flush=True)

        if device == "cuda":
            torch.cuda.empty_cache()

    print(f" Done: extracted {extracted}, skipped {skipped}", flush=True)
    return extracted
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Extract activations for ALL missing benchmarks for all models.
|
|
4
|
+
Designed to run on AWS with GPU.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
|
|
9
|
+
import psycopg2
|
|
10
|
+
from wisent.core.utils.config_tools.constants import RECURSION_INITIAL_DEPTH, COMBO_OFFSET
|
|
11
|
+
from psycopg2.extras import execute_values
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
# Connection string from the environment. Everything from the first '?' on
# (URL query parameters) is dropped before the URL is used.
DATABASE_URL = os.environ.get("DATABASE_URL")
if DATABASE_URL:
    DATABASE_URL = DATABASE_URL.partition('?')[0]

if not DATABASE_URL:
    raise RuntimeError("DATABASE_URL environment variable is required")

# Lazily-opened connection shared by get_conn()/reset_conn(); None until first use.
_db_conn = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def hidden_states_to_bytes(hidden_states: torch.Tensor) -> bytes:
    """Convert hidden_states tensor to bytes (float32).

    Values are emitted in C (row-major) order as native-endian float32.
    ``.float()`` already yields float32, so no further NumPy dtype cast (and
    no second full-size copy) is needed before ``tobytes()``.
    """
    return hidden_states.detach().cpu().float().numpy().tobytes()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def to_session_pooler_url(db_url: str) -> str:
    """Rewrite a Supabase pooler URL from port 6543 to port 5432.

    Only the ``pooler.supabase.com:6543`` host:port token is rewritten, so a
    ``:6543`` elsewhere in the URL (e.g. at the start of a password) is left
    intact. URLs for other hosts are returned unchanged.
    NOTE(review): presumably switches the transaction-mode pooler for the
    session-mode one — confirm against the deployment.
    """
    return db_url.replace("pooler.supabase.com:6543", "pooler.supabase.com:5432")


def get_db_connection(db_connect_wait_s: int):
    """Open a fresh autocommit connection to DATABASE_URL.

    Args:
        db_connect_wait_s: Passed to psycopg2 as ``connect_timeout``.
    """
    conn = psycopg2.connect(
        to_session_pooler_url(DATABASE_URL),
        **{"connect_" + "timeout": db_connect_wait_s},
        keepalives=1,
        keepalives_idle=30,
        keepalives_interval=10,
        keepalives_count=5
    )
    conn.autocommit = True
    return conn
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_conn(db_connect_wait_s: int):
    """Return the shared module connection, opening or replacing it as needed.

    Args:
        db_connect_wait_s: Connect timeout forwarded to get_db_connection
            whenever a new connection has to be opened.

    A cached connection is probed with ``SELECT 1``. If the probe raises, the
    stale connection is closed (close errors ignored) and replaced.
    """
    global _db_conn
    if _db_conn is None:
        _db_conn = get_db_connection(db_connect_wait_s)
        return _db_conn

    try:
        probe = _db_conn.cursor()
        probe.execute("SELECT 1")
        probe.close()
    except Exception:
        print(" [Reconnecting to DB...]", flush=True)
        try:
            _db_conn.close()
        except Exception:
            pass
        _db_conn = get_db_connection(db_connect_wait_s)
    return _db_conn
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def reset_conn():
    """Drop the shared connection so the next get_conn() opens a new one.

    Errors raised while closing the old connection are ignored.
    """
    global _db_conn
    if _db_conn is None:
        return
    try:
        _db_conn.close()
    except Exception:
        pass
    _db_conn = None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def get_missing_benchmarks(conn, model_id: int, log_interval: int) -> list:
    """Find benchmarks where *model_id* has activations for fewer pairs than exist.

    First one query fetches each ContrastivePairSet with its pair count. Then
    one query per benchmark counts the distinct pairs that already have any
    "Activation" row for this model. Separate queries are used to avoid a
    single long-running one.

    Returns:
        List of ``(set_id, name, pairs_needed)`` tuples in benchmark-name
        order, where ``pairs_needed = total_pairs - extracted_pairs``.
    """
    cur = conn.cursor()

    print(" Fetching benchmark pair counts...", flush=True)
    cur.execute('''
SELECT cps.id, cps.name, COUNT(cp.id) as total_pairs
FROM "ContrastivePairSet" cps
INNER JOIN "ContrastivePair" cp ON cp."setId" = cps.id
GROUP BY cps.id, cps.name
HAVING COUNT(cp.id) > 0
ORDER BY cps.name
''')
    benchmarks = cur.fetchall()
    benchmark_total = len(benchmarks)
    print(f" Found {len(benchmarks)} benchmarks with pairs", flush=True)

    extracted_sql = '''
SELECT COUNT(DISTINCT "contrastivePairId")
FROM "Activation"
WHERE "contrastivePairSetId" = %s AND "modelId" = %s
'''
    missing = []
    complete = 0
    for idx, (set_id, name, total_pairs) in enumerate(benchmarks):
        cur.execute(extracted_sql, (set_id, model_id))
        shortfall = total_pairs - cur.fetchone()[0]

        if shortfall > 0:
            missing.append((set_id, name, shortfall))
        else:
            complete += 1

        if (idx + COMBO_OFFSET) % log_interval == RECURSION_INITIAL_DEPTH:
            print(f" Checked {idx + 1}/{benchmark_total} benchmarks...", flush=True)

    cur.close()
    print(f"Found {len(benchmarks)} benchmarks with pairs: {complete} complete, {len(missing)} need more extraction", flush=True)
    return missing
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def get_or_create_pair(conn, set_id: int, prompt: str, positive: str, negative: str, pair_idx: int, db_text_field_max_length: int) -> int:
    """Return the id of the ContrastivePair for ``(set_id, pair_idx)``, creating it if needed.

    A pair is identified by its ``setId`` plus a ``category`` of
    ``"pair_<pair_idx>"``. When no such row exists, one is inserted whose
    positive/negative texts are the prompt and the respective completion
    joined by a blank line, each truncated to ``db_text_field_max_length``
    characters. The insert is committed on ``conn`` immediately.

    Args:
        conn: Open DB-API connection (psycopg2-style ``%s`` placeholders).
        set_id: Id of the owning ContrastivePairSet.
        prompt: Prompt text prefixed to both completions.
        positive: Positive completion text.
        negative: Negative completion text.
        pair_idx: Index used to build the ``category`` lookup key.
        db_text_field_max_length: Maximum number of characters stored per text field.

    Returns:
        The id of the existing or newly inserted pair.

    NOTE(review): the SELECT-then-INSERT is not atomic; two concurrent callers
    could both insert the same pair unless the schema enforces uniqueness.
    """
    category = f"pair_{pair_idx}"
    cur = conn.cursor()
    try:
        cur.execute('''
            SELECT id FROM "ContrastivePair"
            WHERE "setId" = %s AND category = %s
        ''', (set_id, category))
        result = cur.fetchone()
        if result:
            return result[0]

        positive_text = f"{prompt}\n\n{positive}"[:db_text_field_max_length]
        negative_text = f"{prompt}\n\n{negative}"[:db_text_field_max_length]

        cur.execute('''
            INSERT INTO "ContrastivePair" ("setId", "positiveExample", "negativeExample", "category", "createdAt", "updatedAt")
            VALUES (%s, %s, %s, %s, NOW(), NOW())
            RETURNING id
        ''', (set_id, positive_text, negative_text, category))
        pair_id = cur.fetchone()[0]
        conn.commit()
        return pair_id
    finally:
        # Release the cursor on every path, including when a query raises.
        cur.close()
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def batch_create_activations(activations_data: list, max_retries: int, db_connect_wait_s: int):
    """Batch insert multiple Activation records with retry logic.

    activations_data is a list of tuples:
    (model_id, pair_id, set_id, layer, neuron_count, strategy, activation_bytes, is_positive)

    Each attempt obtains a connection via ``get_conn(db_connect_wait_s)`` and
    inserts all rows in one ``execute_values`` call. Conflicting rows are
    skipped (``ON CONFLICT DO NOTHING``). ``userId`` is fixed to ``'system'``
    and both timestamps to ``NOW()``.

    On ``OperationalError`` / ``InterfaceError`` / ``QueryCanceled`` the
    connection is reset and the insert retried. The last failure is re-raised.
    An empty ``activations_data`` is a no-op.
    """
    if not activations_data:
        return

    for attempt in range(max_retries):
        try:
            conn = get_conn(db_connect_wait_s)
            cur = conn.cursor()

            execute_values(cur, '''
                INSERT INTO "Activation"
                ("modelId", "contrastivePairId", "contrastivePairSetId", "layer", "neuronCount",
                 "extractionStrategy", "activationData", "isPositive", "userId", "createdAt", "updatedAt")
                VALUES %s
                ON CONFLICT DO NOTHING
            ''', activations_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, 'system', NOW(), NOW())")
            # Commit explicitly, as get_or_create_pair does. Without it the rows
            # persist only if get_conn returns an autocommit connection. psycopg2
            # treats commit() as a no-op in autocommit mode, so this is safe either way.
            # Keeping it inside the try lets a failed commit go through the retry path.
            conn.commit()
            cur.close()
            return
        except (psycopg2.OperationalError, psycopg2.InterfaceError, psycopg2.errors.QueryCanceled) as e:
            print(f"  [DB error attempt {attempt+1}/{max_retries}: {e}]", flush=True)
            reset_conn()
            if attempt == max_retries - 1:
                raise
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
# Import extract_benchmark and main from helpers (split to meet 300-line limit)
|
|
185
|
+
from wisent.scripts._helpers.extract_all_missing_helpers import ( # noqa: E402
|
|
186
|
+
extract_benchmark,
|
|
187
|
+
main,
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
if __name__ == "__main__":
|
|
191
|
+
main()
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Extract raw activations for ALL missing benchmarks with 3 prompt formats.
|
|
4
|
+
|
|
5
|
+
This script:
|
|
6
|
+
1. Finds all benchmarks that have contrastive pairs in the database
|
|
7
|
+
2. Checks which benchmarks are missing raw activations for the given model
|
|
8
|
+
3. Extracts using 3 formats: chat, mc_balanced, role_play
|
|
9
|
+
4. Stores to RawActivation table (full sequence hidden states)
|
|
10
|
+
|
|
11
|
+
Extracts up to 500 pairs per benchmark (or maximum available).
|
|
12
|
+
|
|
13
|
+
Usage:
|
|
14
|
+
    python3 -m wisent.scripts.extract_raw_activations --model meta-llama/Llama-3.2-1B-Instruct --device cuda --max-retries 3 --log-interval 10
    python3 -m wisent.scripts.extract_raw_activations --model Qwen/Qwen3-8B --device cuda --max-retries 3 --log-interval 10 --benchmark knowledge_qa/mmlu
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
import argparse
|
|
19
|
+
import os
|
|
20
|
+
import sys
|
|
21
|
+
import time
|
|
22
|
+
|
|
23
|
+
print("[STARTUP] Starting extract_raw_activations.py...", flush=True)
|
|
24
|
+
print(f"[STARTUP] Python version: {sys.version}", flush=True)
|
|
25
|
+
|
|
26
|
+
print("[STARTUP] Importing psycopg2...", flush=True)
|
|
27
|
+
import psycopg2
|
|
28
|
+
print("[STARTUP] psycopg2 imported", flush=True)
|
|
29
|
+
|
|
30
|
+
print("[STARTUP] Importing torch...", flush=True)
|
|
31
|
+
import torch
|
|
32
|
+
print(f"[STARTUP] torch imported, version: {torch.__version__}, CUDA available: {torch.cuda.is_available()}", flush=True)
|
|
33
|
+
|
|
34
|
+
from wisent.scripts._helpers.extract_raw_helpers import extract_benchmark
|
|
35
|
+
from wisent.scripts._helpers.extract_raw_db import (
|
|
36
|
+
get_conn, reset_conn, get_or_create_model, get_missing_benchmarks,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _find_benchmark_set_id(conn, benchmark_name):
    """Return the id of the ContrastivePairSet named ``benchmark_name``, or None if absent.

    The cursor is closed on every path, including when the query raises.
    """
    cur = conn.cursor()
    try:
        cur.execute('SELECT id FROM "ContrastivePairSet" WHERE name = %s', (benchmark_name,))
        row = cur.fetchone()
    finally:
        cur.close()
    return row[0] if row else None


def main():
    """Command-line entry point: extract raw activations into the database.

    Parses ``--model``, ``--device``, ``--max-retries`` and ``--log-interval``
    (all required) plus an optional ``--benchmark``. Loads the tokenizer and
    model, then resolves the model's DB id with ``get_or_create_model``.
    Finally it runs ``extract_benchmark`` either for the single named benchmark
    or for every entry returned by ``get_missing_benchmarks``.
    """
    print("[MAIN] Parsing arguments...", flush=True)
    parser = argparse.ArgumentParser(description="Extract raw activations for all missing benchmarks with 3 formats")
    parser.add_argument("--model", required=True, help="Model name (e.g., meta-llama/Llama-3.2-1B-Instruct)")
    parser.add_argument("--device", required=True, help="Device (cuda/mps/cpu)")
    parser.add_argument("--benchmark", default=None, help="Single benchmark to extract (optional)")
    parser.add_argument("--max-retries", type=int, required=True, help="Maximum retry attempts for DB operations")
    parser.add_argument("--log-interval", type=int, required=True, help="Progress logging interval")
    args = parser.parse_args()
    print(f"[MAIN] Args: model={args.model}, device={args.device}, benchmark={args.benchmark}", flush=True)

    # transformers is imported only after argument parsing, so `--help` and
    # argument errors do not pay its import cost.
    print("[MAIN] Importing transformers...", flush=True)
    from transformers import AutoTokenizer, AutoModelForCausalLM
    print("[MAIN] transformers imported", flush=True)

    print(f"[MAIN] Loading tokenizer for {args.model}...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    print(f"[MAIN] Tokenizer loaded, vocab_size={tokenizer.vocab_size}", flush=True)

    print(f"[MAIN] Loading model {args.model}...", flush=True)
    if args.device == "mps":
        # MPS: load in float32 and move the whole model explicitly.
        model = AutoModelForCausalLM.from_pretrained(
            args.model, torch_dtype=torch.float32, trust_remote_code=True)
        model = model.to("mps")
        actual_device = "mps"
    else:
        # Shard with device_map="auto" only when more than one GPU is visible;
        # otherwise pass the requested device straight through as the device map.
        num_gpus = torch.cuda.device_count()
        print(f"[MAIN] Detected {num_gpus} GPUs", flush=True)
        use_device_map = "auto" if num_gpus > 1 else args.device
        print(f"[MAIN] Using device_map={use_device_map}", flush=True)
        model = AutoModelForCausalLM.from_pretrained(
            args.model, torch_dtype="auto", device_map=use_device_map, trust_remote_code=True)
        actual_device = next(model.parameters()).device
    print(f"[MAIN] Model device: {actual_device}", flush=True)
    model.eval()

    num_layers = model.config.num_hidden_layers
    print(f"[MAIN] Model loaded: {num_layers} layers, device={actual_device}", flush=True)

    # Store actual device for use in extraction
    model._actual_device = str(actual_device)

    print("[MAIN] Connecting to database...", flush=True)
    conn = get_conn()
    print("[MAIN] Database connected", flush=True)

    model_id = get_or_create_model(conn, args.model, num_layers)
    print(f"[MAIN] Model ID: {model_id}", flush=True)

    if args.benchmark:
        set_id = _find_benchmark_set_id(conn, args.benchmark)
        if set_id is None:
            print(f"ERROR: Benchmark {args.benchmark} not found", flush=True)
            # NOTE(review): this returns normally (exit status 0), so wrapper
            # scripts cannot detect the failure; consider sys.exit(1).
            return

        print(f"\nExtracting single benchmark: {args.benchmark}", flush=True)
        extracted = extract_benchmark(model, tokenizer, model_id, args.benchmark, set_id,
                                      num_layers, args.device, get_conn, reset_conn, max_retries=args.max_retries, log_interval=args.log_interval)
        print(f"\nDone! Extracted {extracted} pairs", flush=True)
    else:
        missing = get_missing_benchmarks(conn, model_id, num_layers)
        print(f"\nFound {len(missing)} benchmarks needing extraction", flush=True)

        if not missing:
            print("All benchmarks are fully extracted!", flush=True)
            return

        total_extracted = 0
        for i, (set_id, benchmark_name, pair_count) in enumerate(missing):
            print(f"\n[{i+1}/{len(missing)}] {benchmark_name} ({pair_count} pairs in DB)", flush=True)
            start = time.time()

            extracted = extract_benchmark(model, tokenizer, model_id, benchmark_name, set_id,
                                          num_layers, args.device, get_conn, reset_conn, max_retries=args.max_retries, log_interval=args.log_interval)

            total_extracted += extracted
            elapsed = time.time() - start
            print(f"  Completed in {elapsed:.1f}s", flush=True)

        print(f"\n{'='*60}", flush=True)
        print(f"COMPLETE! Total extracted: {total_extracted} pairs across {len(missing)} benchmarks", flush=True)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
if __name__ == "__main__":
|
|
128
|
+
main()
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""Fix the order of correct/incorrect answers in extractor files.
|
|
2
|
+
|
|
3
|
+
The correct order is:
|
|
4
|
+
A. {incorrect}
|
|
5
|
+
B. {correct}
|
|
6
|
+
|
|
7
|
+
This script:
|
|
8
|
+
1. Finds files with the reversed order and fixes them
|
|
9
|
+
2. Checks if evaluator_name == "log_likelihoods" and verifies A/B pattern is present
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import re
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
|
|
15
|
+
from wisent.core.utils.config_tools.constants import SEPARATOR_WIDTH_STANDARD
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def fix_extractor_order(base_path=None):
    r"""Find and fix extractors whose A/B answer order is reversed.

    Scans ``lm_eval_pairs/lm_task_extractors`` and
    ``huggingface_pairs/hf_task_extractors`` under *base_path* for ``*.py``
    files, skipping ``__init__.py``. Any file containing the literal source
    text ``\nA. {correct}\nB. {incorrect}`` is rewritten in place so that the
    incorrect answer comes first (``\nA. {incorrect}\nB. {correct}``).

    Files that declare a ``log_likelihood`` / ``log_likelihoods`` evaluator but
    contain neither ordering are reported separately. A summary is printed.

    Args:
        base_path: Root ``contrastive_pairs`` directory to scan (str or Path).
            Defaults to ``<package>/core/contrastive_pairs`` relative to this
            script, which was the only location searched before this parameter
            existed.

    Returns:
        dict with ``"fixed"`` (files rewritten) and ``"missing_ab"``
        (log-likelihood extractors lacking the A/B pattern), each a list of Path.
    """
    if base_path is None:
        base_path = Path(__file__).parent.parent / "core" / "contrastive_pairs"
    base_path = Path(base_path)

    # Directories to search
    search_dirs = [
        base_path / "lm_eval_pairs" / "lm_task_extractors",
        base_path / "huggingface_pairs" / "hf_task_extractors",
    ]

    # The "\n" below is the two-character escape written inside the extractors'
    # string literals, not a real newline, hence the doubled backslashes.
    # Incorrect order: correct answer first, incorrect second.
    incorrect_re = re.compile(r'\\nA\. \{correct\}\\nB\. \{incorrect\}')
    # What it should be replaced with
    correct_replacement = r'\\nA. {incorrect}\\nB. {correct}'
    # Correct order
    correct_re = re.compile(r'\\nA\. \{incorrect\}\\nB\. \{correct\}')
    # log_likelihoods evaluator declaration
    log_likelihood_re = re.compile(r'evaluator_name\s*=\s*["\']log_likelihood[s]?["\']')

    files_with_incorrect_order = []
    log_likelihood_missing_ab = []

    for search_dir in search_dirs:
        if not search_dir.exists():
            print(f"Directory not found: {search_dir}")
            continue

        for py_file in search_dir.glob("*.py"):
            if py_file.name == "__init__.py":
                continue

            # Explicit UTF-8 (the encoding of Python source): the locale default
            # could mangle non-ASCII text when the file is written back.
            content = py_file.read_text(encoding="utf-8")

            has_incorrect_order = incorrect_re.search(content) is not None
            if has_incorrect_order:
                files_with_incorrect_order.append(py_file)
                py_file.write_text(incorrect_re.sub(correct_replacement, content), encoding="utf-8")

            # Judged on the pre-fix content: a file that was just fixed still counts as having A/B.
            has_ab_pattern = has_incorrect_order or correct_re.search(content) is not None
            if log_likelihood_re.search(content) and not has_ab_pattern:
                log_likelihood_missing_ab.append(py_file)

    # Report results
    print("=" * SEPARATOR_WIDTH_STANDARD)
    print("EXTRACTOR ORDER FIX REPORT")
    print("=" * SEPARATOR_WIDTH_STANDARD)

    print(f"\n1. Files with incorrect order (A.correct/B.incorrect -> fixed): {len(files_with_incorrect_order)}")
    if files_with_incorrect_order:
        print("\n Fixed files:")
        for f in sorted(files_with_incorrect_order):
            print(f" - {f.name}")

    print(f"\n2. Files with log_likelihoods evaluator but MISSING A/B pattern: {len(log_likelihood_missing_ab)}")
    if log_likelihood_missing_ab:
        print("\n Missing A/B pattern:")
        for f in sorted(log_likelihood_missing_ab):
            print(f" - {f.name}")

    return {
        "fixed": files_with_incorrect_order,
        "missing_ab": log_likelihood_missing_ab,
    }
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
if __name__ == "__main__":
|
|
98
|
+
fix_extractor_order()
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
#!/bin/bash
|
|
2
|
+
# Run quality metrics sweep across multiple benchmarks
|
|
3
|
+
# This script runs the optimization pipeline for each benchmark and collects
|
|
4
|
+
# quality metrics alongside steering effectiveness (delta) for correlation analysis.
|
|
5
|
+
#
|
|
6
|
+
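# Output: <benchmark>_metrics.json (copied from steering_comprehensive_*.json) and synthetic_<type>_metrics.json, plus per-run *_log.txt files, in $OUTPUT_DIR (default /home/ubuntu/output/)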
|
|
7
|
+
#
|
|
8
|
+
# Features:
|
|
9
|
+
# - Saves intermediate results after each benchmark to GCS
|
|
10
|
+
# - Supports resuming from last completed benchmark
|
|
11
|
+
# - Continues on individual benchmark failures (doesn't abort entire sweep)
|
|
12
|
+
#
|
|
13
|
+
# Usage:
|
|
14
|
+
# ./run_quality_metrics_sweep.sh
|
|
15
|
+
|
|
16
|
+
# Don't exit on error - we want to continue with other benchmarks
set -uo pipefail

# Configuration (each value can be overridden from the environment)
MODEL="${MODEL:-Qwen/Qwen2.5-0.5B-Instruct}"
OUTPUT_DIR="${OUTPUT_DIR:-/home/ubuntu/output}"
# NOTE(review): LAYER_RANGE is only echoed in the banner below; it is not
# passed to any `wisent` command in this script.
LAYER_RANGE="${LAYER_RANGE:-0-23}"
# Not referenced in this file; presumably consumed by the sourced sweep helpers.
GCS_BUCKET="${GCS_BUCKET:-wisent-images-bucket}"

# Progress tracking file (not referenced in this file; presumably used by the helpers)
PROGRESS_FILE="$OUTPUT_DIR/.sweep_progress"

# Source helper functions (save_intermediate_results, is_benchmark_completed, etc.)
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
HELPERS_FILE="$SCRIPT_DIR/_helpers/sweep_helpers.sh"
# Without `set -e`, a missing helper file would only print an error. The sweep
# would then run for hours with every helper call failing as "command not found",
# so resume detection, progress marking and GCS uploads would be lost. Abort up front.
if [ ! -f "$HELPERS_FILE" ]; then
    echo "ERROR: helper script not found: $HELPERS_FILE" >&2
    exit 1
fi
# shellcheck source=_helpers/sweep_helpers.sh
source "$HELPERS_FILE"

# Benchmarks to test (these have meaningful correct/incorrect answer pairs)
BENCHMARKS=(
    "gsm8k"
    "arc_easy"
    "arc_challenge"
    "hellaswag"
    "winogrande"
    "truthfulqa_mc1"
    "piqa"
    "boolq"
    "openbookqa"
    "livecodebench"
)

# Synthetic steering types for validation:
# - "british" = meaningful steering (British vs American English - should have good metrics AND show steering effect)
# - "random" = random pairs (should have BAD metrics AND NO steering effect)
# These validate which metrics actually predict steering effectiveness
SYNTHETIC_TYPES=(
    "british"
    "random"
)

echo "=========================================="
echo "Quality Metrics Sweep"
echo "=========================================="
echo "Model: $MODEL"
echo "Output: $OUTPUT_DIR"
echo "Layer range: $LAYER_RANGE"
echo "Benchmarks: ${BENCHMARKS[*]}"
echo "Synthetic types: ${SYNTHETIC_TYPES[*]}"
echo "=========================================="

# Create output directory
mkdir -p "$OUTPUT_DIR"
|
|
68
|
+
|
|
69
|
+
# ==========================================
# Part 1: Run optimization for each BENCHMARK task
# ==========================================
echo ""
echo "=========================================="
echo "Part 1: Benchmark Tasks"
echo "=========================================="

FAILED_BENCHMARKS=()
COMPLETED_BENCHMARKS=()

for BENCHMARK in "${BENCHMARKS[@]}"; do
    echo ""
    echo "=========================================="
    echo "Running: $BENCHMARK"
    echo "=========================================="

    # Skip if already completed (resume support)
    if is_benchmark_completed "$BENCHMARK"; then
        echo "SKIPPING: $BENCHMARK already completed (found ${BENCHMARK}_metrics.json)"
        COMPLETED_BENCHMARKS+=("$BENCHMARK")
        continue
    fi

    BENCHMARK_START=$(date +%s)

    # Run the optimization using wisent CLI with baseline comparison.
    # DEVICE may be overridden from the environment (defaults to cuda).
    # pipefail (set above) keeps a wisent failure from being masked by tee's success.
    if wisent optimize-steering comprehensive "$MODEL" \
        --tasks "$BENCHMARK" \
        --compute-baseline \
        --device "${DEVICE:-cuda}" \
        --output-dir "$OUTPUT_DIR/$BENCHMARK" \
        2>&1 | tee "$OUTPUT_DIR/${BENCHMARK}_log.txt"; then

        BENCHMARK_END=$(date +%s)
        DURATION=$((BENCHMARK_END - BENCHMARK_START))
        echo "Completed $BENCHMARK in ${DURATION}s"

        # Pick the newest results file. An earlier failed attempt can leave a
        # steering_comprehensive_*.json behind, and plain `find | head -1`
        # returns files in arbitrary directory order. With `-exec ... {} +`,
        # ls is not run at all when nothing matches.
        RESULTS_FILE=$(find "$OUTPUT_DIR/$BENCHMARK" -name "steering_comprehensive_*.json" -type f -exec ls -t {} + 2>/dev/null | head -n 1)

        if [ -n "$RESULTS_FILE" ]; then
            echo "Results saved to: $RESULTS_FILE"
            cp "$RESULTS_FILE" "$OUTPUT_DIR/${BENCHMARK}_metrics.json"
            mark_benchmark_completed "$BENCHMARK"
            COMPLETED_BENCHMARKS+=("$BENCHMARK")
        else
            echo "WARNING: No results file found for $BENCHMARK"
            FAILED_BENCHMARKS+=("$BENCHMARK")
        fi
    else
        echo "ERROR: $BENCHMARK failed"
        FAILED_BENCHMARKS+=("$BENCHMARK")
    fi

    # Save intermediate results after each benchmark
    save_intermediate_results
done
|
|
127
|
+
|
|
128
|
+
# ==========================================
# Part 2: Run SYNTHETIC steering (british, random)
# These use the `optimize-steering personalization` subcommand with --trait
# ==========================================
echo ""
echo "=========================================="
echo "Part 2: Synthetic Steering Validation"
echo "=========================================="

for SYNTHETIC_TYPE in "${SYNTHETIC_TYPES[@]}"; do
    echo ""
    echo "=========================================="
    echo "Running synthetic: $SYNTHETIC_TYPE"
    echo "=========================================="

    # Skip if already completed (the copied metrics file doubles as the completion marker)
    if [ -f "$OUTPUT_DIR/synthetic_${SYNTHETIC_TYPE}_metrics.json" ]; then
        echo "SKIPPING: synthetic_$SYNTHETIC_TYPE already completed"
        COMPLETED_BENCHMARKS+=("synthetic_$SYNTHETIC_TYPE")
        continue
    fi

    SYNTHETIC_START=$(date +%s)
    SYNTHETIC_DIR="$OUTPUT_DIR/synthetic_${SYNTHETIC_TYPE}"

    # Run the optimization with personalization task.
    # DEVICE and NUM_PAIRS may be overridden from the environment (defaults: cuda, 50).
    if wisent optimize-steering personalization \
        --model "$MODEL" \
        --trait "$SYNTHETIC_TYPE" \
        --num-pairs "${NUM_PAIRS:-50}" \
        --output-dir "$SYNTHETIC_DIR" \
        --device "${DEVICE:-cuda}" \
        2>&1 | tee "$OUTPUT_DIR/synthetic_${SYNTHETIC_TYPE}_log.txt"; then

        SYNTHETIC_END=$(date +%s)
        DURATION=$((SYNTHETIC_END - SYNTHETIC_START))
        echo "Completed synthetic $SYNTHETIC_TYPE in ${DURATION}s"

        # Take the newest JSON in the run directory rather than whichever one
        # `find` happens to list first (a previous failed attempt may have left files).
        RESULTS_FILE=$(find "$SYNTHETIC_DIR" -name "*.json" -type f -exec ls -t {} + 2>/dev/null | head -n 1)

        if [ -n "$RESULTS_FILE" ]; then
            echo "Results saved to: $RESULTS_FILE"
            cp "$RESULTS_FILE" "$OUTPUT_DIR/synthetic_${SYNTHETIC_TYPE}_metrics.json"
            COMPLETED_BENCHMARKS+=("synthetic_$SYNTHETIC_TYPE")
        else
            echo "WARNING: No results file found for synthetic_$SYNTHETIC_TYPE"
            FAILED_BENCHMARKS+=("synthetic_$SYNTHETIC_TYPE")
        fi
    else
        echo "ERROR: synthetic_$SYNTHETIC_TYPE failed"
        FAILED_BENCHMARKS+=("synthetic_$SYNTHETIC_TYPE")
    fi

    # Save intermediate results after each synthetic
    save_intermediate_results
done
|
|
185
|
+
|
|
186
|
+
# ==========================================
# Part 3: Combine all results, print the summary, upload
# ==========================================

# Emit an empty line, then the title framed by two ruler lines.
_section_banner() {
    printf '\n%s\n%s\n%s\n' \
        "==========================================" \
        "$1" \
        "=========================================="
}

_section_banner "Combining Results"

combine_all_results

_section_banner "Sweep Complete!"
printf 'Results in: %s\n' "$OUTPUT_DIR"
ls -la "$OUTPUT_DIR"/*.json 2>/dev/null || echo "No JSON files found"
printf '\n'
printf 'Completed benchmarks: %s\n' "${COMPLETED_BENCHMARKS[*]:-none}"
printf 'Failed benchmarks: %s\n' "${FAILED_BENCHMARKS[*]:-none}"
printf '\n'

# Final upload to GCS
upload_final_to_gcs

echo "Done!"
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: wisent-tools
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Operational scripts and benchmark-evaluation runners for the wisent package family
|
|
5
|
+
Home-page: https://github.com/wisent-ai/wisent-tools
|
|
6
|
+
Author: Lukasz Bartoszcze and the Wisent Team
|
|
7
|
+
Author-email: lukasz.bartoszcze@wisent.ai
|
|
8
|
+
Requires-Python: >=3.9
|
|
9
|
+
Requires-Dist: wisent>=0.10.0
|
|
10
|
+
Requires-Dist: wisent-evaluators>=0.1.0
|
|
11
|
+
Dynamic: author
|
|
12
|
+
Dynamic: author-email
|
|
13
|
+
Dynamic: home-page
|
|
14
|
+
Dynamic: requires-dist
|
|
15
|
+
Dynamic: requires-python
|
|
16
|
+
Dynamic: summary
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
setup.py
|
|
4
|
+
wisent/__init__.py
|
|
5
|
+
wisent/scripts/__init__.py
|
|
6
|
+
wisent/scripts/extract_all_missing.py
|
|
7
|
+
wisent/scripts/extract_raw_activations.py
|
|
8
|
+
wisent/scripts/fix_extractor_order.py
|
|
9
|
+
wisent/scripts/run_quality_metrics_sweep.sh
|
|
10
|
+
wisent/scripts/_helpers/__init__.py
|
|
11
|
+
wisent/scripts/_helpers/extract_all_missing_helpers.py
|
|
12
|
+
wisent/scripts/_helpers/extract_raw_db.py
|
|
13
|
+
wisent/scripts/_helpers/extract_raw_helpers.py
|
|
14
|
+
wisent_tools.egg-info/PKG-INFO
|
|
15
|
+
wisent_tools.egg-info/SOURCES.txt
|
|
16
|
+
wisent_tools.egg-info/dependency_links.txt
|
|
17
|
+
wisent_tools.egg-info/requires.txt
|
|
18
|
+
wisent_tools.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
wisent
|