wisent 0.5.11__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -1
- wisent/core/activations/__init__.py +26 -0
- wisent/core/activations/activations.py +96 -0
- wisent/core/activations/activations_collector.py +71 -20
- wisent/core/activations/prompt_construction_strategy.py +47 -0
- wisent/core/agent/budget.py +2 -2
- wisent/core/agent/device_benchmarks.py +1 -1
- wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
- wisent/core/agent/diagnose/response_diagnostics.py +4 -4
- wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
- wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
- wisent/core/agent/diagnose.py +2 -1
- wisent/core/autonomous_agent.py +10 -2
- wisent/core/benchmark_extractors.py +293 -0
- wisent/core/bigcode_integration.py +20 -7
- wisent/core/branding.py +108 -0
- wisent/core/cli/__init__.py +15 -0
- wisent/core/cli/create_steering_vector.py +138 -0
- wisent/core/cli/evaluate_responses.py +715 -0
- wisent/core/cli/generate_pairs.py +128 -0
- wisent/core/cli/generate_pairs_from_task.py +119 -0
- wisent/core/cli/generate_responses.py +129 -0
- wisent/core/cli/generate_vector_from_synthetic.py +149 -0
- wisent/core/cli/generate_vector_from_task.py +147 -0
- wisent/core/cli/get_activations.py +191 -0
- wisent/core/cli/optimize_classification.py +339 -0
- wisent/core/cli/optimize_steering.py +364 -0
- wisent/core/cli/tasks.py +182 -0
- wisent/core/cli_logger.py +22 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
- wisent/core/data_loaders/__init__.py +235 -0
- wisent/core/data_loaders/loaders/lm_loader.py +2 -2
- wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
- wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
- wisent/core/download_full_benchmarks.py +79 -2
- wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
- wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
- wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
- wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
- wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
- wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
- wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
- wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
- wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
- wisent/core/lm_eval_harness_ground_truth.py +3 -2
- wisent/core/main.py +57 -0
- wisent/core/model_persistence.py +2 -2
- wisent/core/models/wisent_model.py +8 -6
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/optuna/steering/steering_optimization.py +1 -1
- wisent/core/parser_arguments/__init__.py +10 -0
- wisent/core/parser_arguments/agent_parser.py +110 -0
- wisent/core/parser_arguments/configure_model_parser.py +7 -0
- wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
- wisent/core/parser_arguments/evaluate_parser.py +40 -0
- wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
- wisent/core/parser_arguments/full_optimize_parser.py +115 -0
- wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
- wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
- wisent/core/parser_arguments/generate_responses_parser.py +15 -0
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_parser.py +90 -0
- wisent/core/parser_arguments/get_activations_parser.py +90 -0
- wisent/core/parser_arguments/main_parser.py +152 -0
- wisent/core/parser_arguments/model_config_parser.py +59 -0
- wisent/core/parser_arguments/monitor_parser.py +17 -0
- wisent/core/parser_arguments/multi_steer_parser.py +47 -0
- wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
- wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
- wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
- wisent/core/parser_arguments/synthetic_parser.py +93 -0
- wisent/core/parser_arguments/tasks_parser.py +584 -0
- wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
- wisent/core/parser_arguments/utils.py +111 -0
- wisent/core/prompts/core/prompt_formater.py +3 -3
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
- wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
- wisent/core/steering_optimizer.py +45 -21
- wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
- wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
- wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
- wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
- wisent/core/tasks/livecodebench_task.py +4 -103
- wisent/core/timing_calibration.py +1 -1
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/METADATA +3 -3
- wisent-0.5.13.dist-info/RECORD +294 -0
- wisent-0.5.13.dist-info/entry_points.txt +2 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
- wisent/classifiers/core/atoms.py +0 -747
- wisent/classifiers/models/logistic.py +0 -29
- wisent/classifiers/models/mlp.py +0 -47
- wisent/cli/classifiers/classifier_rotator.py +0 -137
- wisent/cli/cli_logger.py +0 -142
- wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
- wisent/cli/wisent_cli/commands/listing.py +0 -154
- wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
- wisent/cli/wisent_cli/main.py +0 -93
- wisent/cli/wisent_cli/shell.py +0 -80
- wisent/cli/wisent_cli/ui.py +0 -69
- wisent/cli/wisent_cli/util/aggregations.py +0 -43
- wisent/cli/wisent_cli/util/parsing.py +0 -126
- wisent/cli/wisent_cli/version.py +0 -4
- wisent/opti/methods/__init__.py +0 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent-0.5.11.dist-info/RECORD +0 -220
- /wisent/{benchmarks → core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
- /wisent/{benchmarks/coding → core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
- /wisent/{benchmarks/coding/metrics → core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
- /wisent/{benchmarks/coding/metrics/core → core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/providers/core → core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
- /wisent/{benchmarks/coding/providers/livecodebench → core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/__init__.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/core/__init__.py +0 -0
- /wisent/{opti → core/opti}/core/atoms.py +0 -0
- /wisent/{classifiers → core/opti/methods}/__init__.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_classificator.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_steering.py +0 -0
- /wisent/{classifiers/core → core/synthetic}/__init__.py +0 -0
- /wisent/{classifiers/models → core/synthetic/cleaners}/__init__.py +0 -0
- /wisent/{cli → core/synthetic/cleaners/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/core/atoms.py +0 -0
- /wisent/{cli/classifiers → core/synthetic/cleaners/methods}/__init__.py +0 -0
- /wisent/{cli/data_loaders → core/synthetic/cleaners/methods/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/methods/core/atoms.py +0 -0
- /wisent/{cli/evaluators → core/synthetic/db_instructions}/__init__.py +0 -0
- /wisent/{cli/steering_methods → core/synthetic/db_instructions/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/db_instructions/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli → core/synthetic/generators}/__init__.py +0 -0
- /wisent/{cli/wisent_cli/commands → core/synthetic/generators/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli/util → core/synthetic/generators/diversities}/__init__.py +0 -0
- /wisent/{opti → core/synthetic/generators/diversities/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/diversities/core/core.py +0 -0
- /wisent/{opti/core → core/synthetic/generators/diversities/methods}/__init__.py +0 -0
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/WHEEL +0 -0
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Answer extractors for TaskInterface benchmarks.
|
|
3
|
+
|
|
4
|
+
These extractors parse model outputs to extract answers for validation.
|
|
5
|
+
Different from LMEvalBenchmarkExtractor which creates contrastive pairs.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from typing import Optional
|
|
10
|
+
import re
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BenchmarkExtractor(ABC):
    """Abstract interface for pulling an answer out of generated model text.

    Subclasses implement :meth:`extract_answer`. The helpers defined here
    provide a default comparison: case-insensitive equality after trimming
    surrounding whitespace.
    """

    @abstractmethod
    def extract_answer(self, text: str) -> Optional[str]:
        """Return the answer found in ``text``; supplied by each subclass."""
        ...

    def normalize_answer(self, answer: str) -> str:
        """Lower-case and strip ``answer``; ``None`` maps to the empty string."""
        return "" if answer is None else answer.lower().strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        """Return True when both answers are equal after normalization."""
        normalized_prediction = self.normalize_answer(predicted)
        normalized_expectation = self.normalize_answer(expected)
        return normalized_prediction == normalized_expectation
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class GSM8KExtractor(BenchmarkExtractor):
    """Extractor for GSM8K and math tasks whose answer is a single number."""

    @staticmethod
    def _clean_number(raw: str) -> Optional[str]:
        """Tidy a regex-captured number; return None if it holds no digit.

        Thousands separators are removed and trailing ``.``/``-`` characters
        (sentence punctuation swallowed by the permissive character-class
        capture) are stripped, so ``"1,234."`` becomes ``"1234"``. A capture
        made only of punctuation (e.g. ``"..."``) is rejected so the caller
        can move on to the next strategy.
        """
        cleaned = raw.replace(',', '').rstrip('.-')
        return cleaned if any(ch.isdigit() for ch in cleaned) else None

    def extract_answer(self, text: str) -> Optional[str]:
        """Extract the numerical answer from ``text``.

        Strategies are tried in order and the first usable result wins:

        1. a JSON object containing ``"final_answer"``;
        2. the GSM8K marker ``#### <number>``;
        3. phrases such as "the answer is <number>" or a trailing ``= <number>``;
        4. the last number that appears anywhere in the text.

        Args:
            text: Raw model output.

        Returns:
            The number as a string without thousands separators, or ``None``
            when ``text`` is empty or contains no number.
        """
        if not text:
            return None

        # Local import: json is only needed by this strategy.
        import json

        # Strategy 1: JSON format {"final_answer": "123"}
        json_match = re.search(r'\{[^}]*"final_answer"[^}]*\}', text)
        if json_match:
            try:
                data = json.loads(json_match.group(0))
            except ValueError:
                # Not valid JSON (JSONDecodeError is a ValueError); best effort,
                # so fall through to the textual strategies.
                data = None
            if isinstance(data, dict):
                answer = data.get("final_answer")
                # Compare against None so a legitimate answer of 0 is kept.
                if answer is not None:
                    # Remove commas and non-numeric characters except decimal and minus
                    answer = re.sub(r'[^\d.\-]', '', str(answer))
                    if answer and answer.replace('.', '').replace('-', '').isdigit():
                        return answer

        # Strategy 2: GSM8K format "#### 123"
        hash_match = re.search(r'####\s*([\d,.\-]+)', text)
        if hash_match:
            answer = self._clean_number(hash_match.group(1))
            if answer is not None:
                return answer

        # Strategy 3: "The answer is 123" or "The final answer is 123"
        answer_patterns = [
            r'(?:the\s+)?(?:final\s+)?answer\s+is\s*:?\s*([\d,.\-]+)',
            r'(?:therefore|thus|so),?\s+(?:the\s+)?(?:final\s+)?answer\s+is\s*:?\s*([\d,.\-]+)',
            r'=\s*([\d,.\-]+)\s*$',
        ]
        for pattern in answer_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                answer = self._clean_number(match.group(1))
                if answer is not None:
                    return answer

        # Strategy 4: Last number in text (fallback). The first alternative
        # keeps properly grouped thousands separators together ("1,234" is one
        # number, not "1" and "234"); the second is the plain-number form.
        numbers = re.findall(
            r'-?\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|-?\d+(?:\.\d+)?',
            text,
        )
        if numbers:
            return numbers[-1].replace(',', '')

        return None

    def check_answer(self, predicted: str, expected: str) -> bool:
        """Compare numerical answers with an absolute tolerance of 1e-6.

        Thousands separators are ignored on both sides. When either value
        cannot be parsed as a float, the comparison falls back to normalized
        string equality.

        Args:
            predicted: Extracted answer, or ``None`` if extraction failed.
            expected: Reference answer.

        Returns:
            True if the answers match, False otherwise (always False when
            ``predicted`` is ``None``).
        """
        if predicted is None:
            return False
        try:
            pred_float = float(str(predicted).replace(',', ''))
            expected_float = float(str(expected).replace(',', ''))
            return abs(pred_float - expected_float) < 1e-6
        except (ValueError, TypeError):
            return self.normalize_answer(predicted) == self.normalize_answer(expected)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class LiveCodeBenchExtractor(BenchmarkExtractor):
    """Extractor that pulls source code out of a model response."""

    def extract_answer(self, text: str) -> Optional[str]:
        """Return the code contained in ``text``, or ``None`` if none is found.

        Fenced and inline code blocks are preferred (the longest match of the
        first fence style that matches wins), then a ``def`` or ``class``
        definition up to the first blank line, and finally the whole text when
        it contains a code-like keyword.
        """
        if not text:
            return None

        # Strategy 1: markdown code blocks, most specific fence first.
        fence_flags = re.DOTALL | re.IGNORECASE
        for fence_pattern in (r'```python\s*(.*?)```', r'```\s*(.*?)```', r'`([^`]+)`'):
            blocks = re.findall(fence_pattern, text, fence_flags)
            if not blocks:
                continue
            # Keep the longest block found for this fence style.
            return max(blocks, key=len).strip()

        # Strategies 2 and 3: a function definition, then a class definition,
        # each captured up to the first blank line or the end of the text.
        for definition_pattern in (
            r'(def\s+\w+.*?)(?:\n\n|\Z)',
            r'(class\s+\w+.*?)(?:\n\n|\Z)',
        ):
            found = re.search(definition_pattern, text, re.DOTALL)
            if found:
                return found.group(1).strip()

        # Strategy 4: hand back everything if it looks like code at all.
        if any(marker in text for marker in ('def ', 'class ', 'import ')):
            return text.strip()

        return None

    def check_answer(self, predicted: str, expected: str) -> bool:
        """Basic check for coding tasks: the extracted code is not empty.

        ``expected`` is not consulted; only the presence of non-whitespace
        code in ``predicted`` is verified.
        """
        return predicted is not None and bool(predicted.strip())
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class HLEExtractor(BenchmarkExtractor):
    """Extractor for HLE tasks (free-form or multiple-choice answers)."""

    def extract_answer(self, text: str) -> Optional[str]:
        """Extract the answer from an HLE response.

        Strategies are tried in order and the first usable result wins:

        1. a JSON object with a non-empty ``"answer"`` value;
        2. an ``Answer: X`` line (case-insensitive);
        3. the first standalone capital letter A-D;
        4. the first line of the text.

        Args:
            text: Raw model output.

        Returns:
            The extracted answer, or ``None`` when ``text`` is empty or its
            first line is blank and no other strategy matched.
        """
        if not text:
            return None

        # Local import: json is only needed by this strategy.
        import json

        # Strategy 1: JSON format {"answer": "X"}
        json_match = re.search(r'\{[^}]*"answer"[^}]*\}', text)
        if json_match:
            try:
                data = json.loads(json_match.group(0))
            except ValueError:
                # Not valid JSON (JSONDecodeError is a ValueError); best effort,
                # so fall through to the textual strategies.
                data = None
            if isinstance(data, dict):
                answer = data.get("answer")
                # A missing, null or blank value is not an answer: returning
                # "" or "None" here would hide what the later strategies find.
                if answer is not None and str(answer).strip():
                    return str(answer)

        # Strategy 2: "Answer: X" format
        answer_match = re.search(r'answer\s*:\s*(.+?)(?:\n|$)', text, re.IGNORECASE)
        if answer_match:
            return answer_match.group(1).strip()

        # Strategy 3: Multiple choice (A, B, C, D)
        # NOTE: this also matches the capitalized English article "A".
        mc_match = re.search(r'\b([A-D])\b', text)
        if mc_match:
            return mc_match.group(1)

        # Strategy 4: First line (fallback)
        first_line = text.split('\n')[0].strip()
        if first_line:
            return first_line

        return None

    def check_answer(self, predicted: str, expected: str) -> bool:
        """Compare answers with case-insensitive substring matching.

        The answers match when either normalized string contains the other.
        An empty string is a substring of every string, so empty values are
        handled separately: an empty prediction only matches an empty
        expectation.

        Args:
            predicted: Extracted answer, or ``None`` if extraction failed.
            expected: Reference answer.

        Returns:
            True if the answers match, False otherwise (always False when
            ``predicted`` is ``None``).
        """
        if predicted is None:
            return False
        pred_norm = self.normalize_answer(predicted)
        exp_norm = self.normalize_answer(expected)
        if not pred_norm or not exp_norm:
            return pred_norm == exp_norm
        return exp_norm in pred_norm or pred_norm in exp_norm
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class SuperGPQAExtractor(BenchmarkExtractor):
    """Extractor for SuperGPQA science tasks (multiple choice, letters A-D)."""

    def extract_answer(self, text: str) -> Optional[str]:
        """Extract the chosen option letter from a SuperGPQA response.

        Strategies are tried in order and the first match wins:

        1. ``Answer: A`` (case-insensitive, letter must stand alone);
        2. a bracketed letter such as ``(A)`` or ``[A]``;
        3. the first standalone capital letter A-D;
        4. the first character of the text if it is A-D in either case.

        Args:
            text: Raw model output.

        Returns:
            The upper-case option letter, or ``None`` if none was found.
        """
        if not text:
            return None

        # Strategy 1: "Answer: A" format. The trailing \b keeps the first
        # letter of an ordinary word ("Answer: Because ...") from being read
        # as an option.
        answer_match = re.search(r'answer\s*:\s*([A-D])\b', text, re.IGNORECASE)
        if answer_match:
            return answer_match.group(1).upper()

        # Strategy 2: (A) or [A] format
        bracket_match = re.search(r'[\(\[]\s*([A-D])\s*[\)\]]', text, re.IGNORECASE)
        if bracket_match:
            return bracket_match.group(1).upper()

        # Strategy 3: Standalone letter (capital only)
        letter_match = re.search(r'\b([A-D])\b', text)
        if letter_match:
            return letter_match.group(1).upper()

        # Strategy 4: First character if it's A-D
        stripped = text.strip()
        first_char = stripped[0].upper() if stripped else None
        if first_char in ['A', 'B', 'C', 'D']:
            return first_char

        return None

    def check_answer(self, predicted: str, expected: str) -> bool:
        """Compare answers (case-insensitive letter comparison).

        Returns False when ``predicted`` is ``None``.
        """
        if predicted is None:
            return False
        return self.normalize_answer(predicted) == self.normalize_answer(expected)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
# Registry mapping task names to extractors
|
|
224
|
+
_EXTRACTOR_REGISTRY = {}
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _populate_registry():
    """Fill ``_EXTRACTOR_REGISTRY`` with one extractor instance per task name."""
    # Each entry pairs an extractor class with the task names it serves. A
    # fresh instance is created for every task name.
    task_groups = (
        # Math tasks use GSM8KExtractor
        (GSM8KExtractor, (
            "gsm8k", "math", "math500", "hendrycks_math",
            "aime", "aime2024", "aime2025",
            "hmmt", "hmmt_feb_2025",
            "polymath", "polymath_en_medium", "polymath_zh_medium",
            "polymath_en_high", "polymath_zh_high",
            "livemathbench", "livemathbench_cnmo_en", "livemathbench_cnmo_zh",
        )),
        # Coding tasks use LiveCodeBenchExtractor
        (LiveCodeBenchExtractor, (
            "livecodebench", "humaneval", "mbpp", "humaneval_plus", "mbpp_plus",
            "instructhumaneval", "apps", "ds1000",
            "multiple_py", "multiple_js", "multiple_java", "multiple_cpp", "multiple_rs", "multiple_go",
            "conala", "concode", "mercury", "recode",
            "codexglue_code_to_text_python", "codexglue_code_to_text_go",
            "codexglue_code_to_text_ruby", "codexglue_code_to_text_java",
            "codexglue_code_to_text_javascript", "codexglue_code_to_text_php",
        )),
        # HLE tasks use HLEExtractor
        (HLEExtractor, ("hle", "hle_exact_match", "hle_multiple_choice")),
        # Science tasks use SuperGPQAExtractor
        (SuperGPQAExtractor, (
            "supergpqa", "supergpqa_physics", "supergpqa_chemistry", "supergpqa_biology",
        )),
        # QA tasks use the generic extractor (HLEExtractor works well for these)
        (HLEExtractor, ("truthfulqa_mc1", "mmlu", "squad2")),
    )

    for extractor_cls, task_names in task_groups:
        for task_name in task_names:
            _EXTRACTOR_REGISTRY[task_name] = extractor_cls()
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
# Populate on module load
|
|
273
|
+
_populate_registry()
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def get_extractor(task_name: str) -> BenchmarkExtractor:
    """Return the extractor to use for ``task_name``.

    An exact (case-sensitive) registry hit returns the registered instance.
    Otherwise the lower-cased name is scanned for keywords, in the order
    listed below, and a new extractor of the first matching kind is returned;
    names that match nothing get an ``HLEExtractor``.
    """
    try:
        return _EXTRACTOR_REGISTRY[task_name]
    except KeyError:
        pass

    # Fallback: keyword groups are checked top to bottom, first hit wins.
    keyword_rules = (
        (("math", "aime", "hmmt", "gsm", "arithmetic"), GSM8KExtractor),
        (("code", "human", "mbpp", "programming"), LiveCodeBenchExtractor),
        (("hle",), HLEExtractor),
        (("gpqa", "science", "physics", "chemistry", "biology"), SuperGPQAExtractor),
    )
    lowered_name = task_name.lower()
    for keywords, extractor_cls in keyword_rules:
        if any(keyword in lowered_name for keyword in keywords):
            return extractor_cls()

    # Default fallback
    return HLEExtractor()
|
|
@@ -325,9 +325,12 @@ class BigCodeEvaluator:
|
|
|
325
325
|
|
|
326
326
|
def _execute_in_docker(self, sample: Dict, generation: str, task_name: str) -> Dict:
|
|
327
327
|
"""Execute code in Docker container."""
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
328
|
+
raise NotImplementedError(
|
|
329
|
+
"Docker execution is not yet implemented. "
|
|
330
|
+
"To execute code safely, please implement Docker containerization for code execution, "
|
|
331
|
+
"or use the subprocess executor by setting docker_executor=None (note: subprocess execution "
|
|
332
|
+
"is less secure and should only be used in trusted environments)."
|
|
333
|
+
)
|
|
331
334
|
|
|
332
335
|
def _execute_in_subprocess(self, sample: Dict, generation: str, task_name: str) -> Dict:
|
|
333
336
|
"""Execute code in subprocess (less secure)."""
|
|
@@ -540,10 +543,20 @@ def {expected_name}(*args, **kwargs):
|
|
|
540
543
|
return total_passed / total_samples if total_samples > 0 else 0.0
|
|
541
544
|
|
|
542
545
|
def _evaluate_text_generation(self, task: BigCodeTask, generations: List[str]) -> List[float]:
|
|
543
|
-
"""Evaluate text generation tasks (e.g., code-to-text).
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
546
|
+
"""Evaluate text generation tasks (e.g., code-to-text).
|
|
547
|
+
|
|
548
|
+
Raises:
|
|
549
|
+
NotImplementedError: BLEU scoring is not yet implemented
|
|
550
|
+
"""
|
|
551
|
+
raise NotImplementedError(
|
|
552
|
+
"BLEU scoring for text generation tasks is not yet implemented. "
|
|
553
|
+
"This requires:\n"
|
|
554
|
+
" 1. Reference text extraction from task samples\n"
|
|
555
|
+
" 2. BLEU score computation (using sacrebleu or similar library)\n"
|
|
556
|
+
" 3. Proper tokenization and n-gram matching\n"
|
|
557
|
+
"Please implement BLEU scoring or use alternative evaluation metrics "
|
|
558
|
+
"(exact match, F1 score) for code generation tasks."
|
|
559
|
+
)
|
|
547
560
|
|
|
548
561
|
|
|
549
562
|
# Main interface for BigCode integration
|
wisent/core/branding.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Wisent project branding assets (ASCII art logo, project description).
|
|
3
|
+
|
|
4
|
+
This module provides reusable branding components that can be used by:
|
|
5
|
+
- CLI interfaces
|
|
6
|
+
- Documentation
|
|
7
|
+
- Web interfaces
|
|
8
|
+
- Any other presentation layer
|
|
9
|
+
|
|
10
|
+
The ASCII art represents a Wisent (European bison), the project's namesake.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
# Public API of this module; print_banner is a documented public helper
# defined below, so it is exported alongside render_banner.
__all__ = ["WISENT_ASCII_LOGO", "PROJECT_TAGLINE", "render_banner", "print_banner", "get_logo"]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
WISENT_ASCII_LOGO = """
|
|
19
|
+
................. .:--++*##%%%%##**+=-:. .................
|
|
20
|
+
.. .:=*%@@@@@@@%%%%%%%@@@@@@%*=:. ..
|
|
21
|
+
. .-*%@@@%#+=-::.........:-=+#%@@@%*=. .
|
|
22
|
+
. -*%@@@#=:. .:=*%@@@*-. .
|
|
23
|
+
. .-#@@@*=. .-*@@@#-. .
|
|
24
|
+
. :#@@@*: :+%@@#- .
|
|
25
|
+
. .+@@@*: :+@@@+. .
|
|
26
|
+
. .*@@@@%*=:. -%@@#: .
|
|
27
|
+
. .#@@#=*%@@@%*-:. .#@@%: .
|
|
28
|
+
..*@@%. .-+#@@@@#+-:. .*@@%..
|
|
29
|
+
.=@@@- :-+#@@@@%*=:. .%@@*.
|
|
30
|
+
:#@@+ .:-+#@@@@%#+=:. -@@@-
|
|
31
|
+
=@@@: .-=*%@@@@%#+=:.. .#@@+
|
|
32
|
+
+@@@*=:. .:-+*%@@@@%#*=-:.. *@@+
|
|
33
|
+
+@@@@@@#+-.. .:-=*#@@@@@%#*+--.. +@@+
|
|
34
|
+
+@@#-+%@@@%: .:-=*#%@@@@@%#*+=-:.*@@+
|
|
35
|
+
=@@%. .=@@@: ..:-=+#%%@@@@@%@@@+
|
|
36
|
+
:%@@= :@@@- ..::-=+#@@@=
|
|
37
|
+
.+@@%. .#@@* +@@#:
|
|
38
|
+
..%@@*. =@@@: =@@@-.
|
|
39
|
+
. :%@@*..#@@#. .:.. =@@@= .
|
|
40
|
+
. :%@@*.:%@@*. :#@@%#*+=-::..+@@@= .
|
|
41
|
+
. :#@@%-:%@@#: .+@@@#%%@@@@@@%%@@%- .
|
|
42
|
+
. .+@@@*=#@@%- .=%@@%=...::-=#@@@@*. .
|
|
43
|
+
. :*@@@#%@@@*: .=%@@@+. .:*%@@#- .
|
|
44
|
+
. :+%@@@@@@@*-. :=*@@@%+. .-+%@@@*-. .
|
|
45
|
+
. .=*%@@@@@@#+:.:-+#@@@%*-. .:-+#%@@@#+: .
|
|
46
|
+
. .-+#%@@@@@@@@@@@@#*+**#@@@@@%*=:. .
|
|
47
|
+
.............. ..-=+*#%%%@@@@@@@@%%#*=-:. ..............
|
|
48
|
+
................... ....:::::::::.... ...................
|
|
49
|
+
""".strip()
|
|
50
|
+
|
|
51
|
+
PROJECT_TAGLINE = "Steering vectors & activation tooling"
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_logo(width: int = 48) -> str:
    """
    Get the Wisent ASCII logo, optionally centered to a specific width.

    The logo is centered as a single block: every row receives the same left
    padding, derived from the widest row. Centering each row independently
    would shift rows of different lengths by different amounts and skew the
    artwork.

    Args:
        width: Total width to center the logo within (default: 48). When the
            logo is wider than ``width`` no padding is added.

    Returns:
        ASCII art logo as a string, one row per line, each row padded with
        spaces to at least ``width`` characters.
    """
    rows = WISENT_ASCII_LOGO.splitlines()
    art_width = max((len(row) for row in rows), default=0)
    # One shared indent keeps the rows aligned relative to each other.
    left_pad = " " * max(0, (width - art_width) // 2)
    return "\n".join(f"{left_pad}{row}".ljust(width) for row in rows)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def render_banner(title: str, width: int = 48, use_color: bool = True) -> str:
    """
    Build the banner text: the logo followed by a "title / tagline" line.

    Args:
        title: Title text to display (e.g., "Wisent-Guard CLI")
        width: Width passed to ``get_logo`` for centering the logo (default: 48)
        use_color: When True, wrap the logo in green and the title in bold
            green using ANSI escape sequences (default: True)

    Returns:
        The banner as a single string; both the logo and the title line are
        terminated by a newline.
    """
    art = get_logo(width)

    # Plain variant first: no escape sequences at all.
    if not use_color:
        return f"{art}\n" + f"{title} ā {PROJECT_TAGLINE}\n"

    # ANSI SGR sequences: green foreground, bold, and reset.
    green, bold, reset = "\x1b[32m", "\x1b[1m", "\x1b[0m"
    logo_part = f"{green}{art}{reset}\n"
    title_part = f"{bold}{green}{title}{reset} ā {PROJECT_TAGLINE}\n"
    return logo_part + title_part
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def print_banner(title: str, width: int = 48, use_color: bool = True) -> None:
    """
    Write the rendered banner to standard output.

    Args:
        title: Title text to display (e.g., "Wisent-Guard CLI")
        width: Width for centering (default: 48)
        use_color: Whether to use ANSI color codes (default: True)
    """
    banner_text = render_banner(title, width, use_color)
    print(banner_text)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
if __name__ == "__main__":
    # Running this module directly prints a sample banner.
    print_banner(title="Wisent-Guard")
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""CLI execution logic for Wisent commands."""
|
|
2
|
+
|
|
3
|
+
from .tasks import execute_tasks
|
|
4
|
+
from .generate_pairs_from_task import execute_generate_pairs_from_task
|
|
5
|
+
from .generate_pairs import execute_generate_pairs
|
|
6
|
+
from .get_activations import execute_get_activations
|
|
7
|
+
from .create_steering_vector import execute_create_steering_vector
|
|
8
|
+
from .generate_vector_from_task import execute_generate_vector_from_task
|
|
9
|
+
from .generate_vector_from_synthetic import execute_generate_vector_from_synthetic
|
|
10
|
+
from .optimize_classification import execute_optimize_classification
|
|
11
|
+
from .optimize_steering import execute_optimize_steering
|
|
12
|
+
from .generate_responses import execute_generate_responses
|
|
13
|
+
from .evaluate_responses import execute_evaluate_responses
|
|
14
|
+
|
|
15
|
+
# Command entry points re-exported by the CLI package, one per submodule.
__all__ = [
    "execute_tasks",
    "execute_generate_pairs_from_task",
    "execute_generate_pairs",
    "execute_get_activations",
    "execute_create_steering_vector",
    "execute_generate_vector_from_task",
    "execute_generate_vector_from_synthetic",
    "execute_optimize_classification",
    "execute_optimize_steering",
    "execute_generate_responses",
    "execute_evaluate_responses",
]
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""Create steering vector command execution logic."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import time
|
|
7
|
+
import torch
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _collect_layer_activations(pairs_list):
    """Group the activations stored on each pair by layer key and polarity.

    Args:
        pairs_list: Iterable of pair dicts. Each must contain
            ``'positive_response'`` and ``'negative_response'`` mappings whose
            optional ``'layers_activations'`` entry maps a layer key to a list
            of numbers, or to ``None`` when nothing was stored for that layer.

    Returns:
        A ``defaultdict`` shaped like
        ``{layer_key: {"positive": [tensor, ...], "negative": [tensor, ...]}}``
        where every activation has been converted to a ``float32`` tensor.

    Raises:
        KeyError: If a pair lacks ``'positive_response'`` or
            ``'negative_response'``.
    """
    grouped = defaultdict(lambda: {"positive": [], "negative": []})
    for pair in pairs_list:
        for polarity in ("positive", "negative"):
            per_layer = pair[f"{polarity}_response"].get("layers_activations", {})
            for layer_str, activation_list in per_layer.items():
                # ``None`` marks a layer with no stored activation; skip it.
                if activation_list is not None:
                    grouped[layer_str][polarity].append(
                        torch.tensor(activation_list, dtype=torch.float32)
                    )
    return grouped


def execute_create_steering_vector(args):
    """Execute the create-steering-vector command.

    Loads enriched pairs from a JSON file, groups their activations by layer,
    trains one steering vector per layer with the selected method, writes the
    vectors (plus metadata) to a JSON file and prints per-layer statistics.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``enriched_pairs_file``, ``method``, ``normalize``, ``output``,
            ``verbose`` and ``timing``.

    Returns:
        None. Progress is printed to stdout.

    Note:
        Any exception raised while processing is reported on stderr (with a
        traceback when ``args.verbose`` is set) and the process exits with
        status 1 via ``sys.exit``.
    """
    from wisent.core.steering_methods.methods.caa import CAAMethod

    print("\nšÆ Creating steering vectors from enriched pairs")
    print(f" Input file: {args.enriched_pairs_file}")
    print(f" Method: {args.method}")

    start_time = time.time() if args.timing else None

    try:
        # 1. Load enriched pairs from JSON
        print("\nš Loading enriched pairs...")
        if not os.path.exists(args.enriched_pairs_file):
            raise FileNotFoundError(f"Enriched pairs file not found: {args.enriched_pairs_file}")

        # JSON text is UTF-8; do not rely on the platform default encoding.
        with open(args.enriched_pairs_file, 'r', encoding="utf-8") as f:
            data = json.load(f)

        # Extract metadata (missing fields fall back to placeholders)
        trait_label = data.get('trait_label', 'unknown')
        model = data.get('model', 'unknown')
        layers = data.get('layers', [])
        token_aggregation = data.get('token_aggregation', 'unknown')
        pairs_list = data.get('pairs', [])

        print(f" ā Loaded {len(pairs_list)} pairs")
        print(f" ā Model: {model}")
        print(f" ā Layers: {layers}")
        print(f" ā Token aggregation: {token_aggregation}")

        # 2. Organize activations by layer
        print("\nš Organizing activations by layer...")

        # Structure: {layer_str: {"positive": [tensors], "negative": [tensors]}}
        layer_activations = _collect_layer_activations(pairs_list)

        # Layer keys are strings in the JSON; sort them numerically.
        available_layers = sorted(layer_activations, key=int)
        print(f" ā Found activations for {len(available_layers)} layers: {available_layers}")

        # 3. Create steering method instance
        print(f"\nš§ Initializing {args.method.upper()} steering method...")

        if args.method == "caa":
            # NOTE(review): the options dict is passed through a keyword
            # literally named ``kwargs`` - confirm against CAAMethod's signature.
            method = CAAMethod(kwargs={"normalize": args.normalize})
        else:
            raise ValueError(f"Unknown method: {args.method}")

        print(f" ā Method initialized (normalize={args.normalize})")

        # 4. Generate steering vectors for each layer
        print("\nā” Generating steering vectors...")
        steering_vectors = {}

        for layer_str in available_layers:
            pos_list = layer_activations[layer_str]["positive"]
            neg_list = layer_activations[layer_str]["negative"]

            if args.verbose:
                print(f" Processing layer {layer_str}: {len(pos_list)} positive, {len(neg_list)} negative")

            # A layer needs activations of both polarities to be trained.
            if not pos_list or not neg_list:
                print(f" ā ļø Skipping layer {layer_str}: missing activations")
                continue

            # Generate steering vector for this layer
            vector = method.train_for_layer(pos_list, neg_list)
            steering_vectors[layer_str] = vector.tolist()  # Convert to list for JSON

        print(f" ā Generated {len(steering_vectors)} steering vectors")

        # 5. Save steering vectors to JSON
        print(f"\nš¾ Saving steering vectors to '{args.output}'...")
        output_data = {
            'trait_label': trait_label,
            'model': model,
            'method': args.method,
            'normalize': args.normalize,
            'token_aggregation': token_aggregation,
            'num_pairs': len(pairs_list),
            'layers': list(steering_vectors.keys()),
            'steering_vectors': steering_vectors,
            'metadata': {
                'source_file': args.enriched_pairs_file,
                'creation_time': time.strftime('%Y-%m-%d %H:%M:%S'),
            },
        }

        # abspath() guarantees a non-empty directory even for a bare file name.
        os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
        with open(args.output, 'w', encoding="utf-8") as f:
            json.dump(output_data, f, indent=2)

        print(f" ā Saved steering vectors to: {args.output}")

        # 6. Display statistics
        print("\nš Steering Vector Statistics:")
        for layer_str in sorted(steering_vectors, key=int):
            vector = torch.tensor(steering_vectors[layer_str])
            norm = torch.linalg.norm(vector).item()
            print(f" Layer {layer_str}: dim={len(vector)}, norm={norm:.4f}")

        if args.timing:
            elapsed = time.time() - start_time
            print(f" ā±ļø Total time: {elapsed:.2f}s")

        # Single-line literal ending in a real newline (previously an escaped
        # backslash printed the two characters "\n" instead).
        print("\nā Steering vector creation completed successfully!\n")

    except Exception as e:
        # Top-level CLI boundary: report the failure and exit non-zero.
        print(f"\nā Error: {str(e)}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)
|