wisent 0.5.12__py3-none-any.whl → 0.5.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -1
- wisent/core/activations/__init__.py +26 -0
- wisent/core/activations/activations.py +96 -0
- wisent/core/activations/activations_collector.py +71 -20
- wisent/core/activations/prompt_construction_strategy.py +47 -0
- wisent/core/agent/__init__.py +1 -18
- wisent/core/agent/budget.py +2 -2
- wisent/core/agent/device_benchmarks.py +1 -1
- wisent/core/agent/diagnose/__init__.py +1 -55
- wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
- wisent/core/agent/diagnose/response_diagnostics.py +4 -4
- wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
- wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
- wisent/core/agent/diagnose.py +2 -1
- wisent/core/autonomous_agent.py +10 -2
- wisent/core/benchmark_extractors.py +293 -0
- wisent/core/bigcode_integration.py +20 -7
- wisent/core/branding.py +108 -0
- wisent/core/cli/__init__.py +15 -0
- wisent/core/cli/create_steering_vector.py +138 -0
- wisent/core/cli/evaluate_responses.py +715 -0
- wisent/core/cli/generate_pairs.py +128 -0
- wisent/core/cli/generate_pairs_from_task.py +119 -0
- wisent/core/cli/generate_responses.py +129 -0
- wisent/core/cli/generate_vector_from_synthetic.py +149 -0
- wisent/core/cli/generate_vector_from_task.py +147 -0
- wisent/core/cli/get_activations.py +191 -0
- wisent/core/cli/optimize_classification.py +339 -0
- wisent/core/cli/optimize_steering.py +364 -0
- wisent/core/cli/tasks.py +182 -0
- wisent/core/cli_logger.py +22 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
- wisent/core/data_loaders/__init__.py +235 -0
- wisent/core/data_loaders/loaders/lm_loader.py +2 -2
- wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
- wisent/{cli/data_loaders/data_loader_rotator.py ā core/data_loaders/rotator.py} +1 -1
- wisent/core/download_full_benchmarks.py +79 -2
- wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
- wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
- wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
- wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
- wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
- wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
- wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
- wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
- wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
- wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
- wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
- wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
- wisent/core/{log_likelihoods_evaluator.py ā evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
- wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
- wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
- wisent/{cli/evaluators/evaluator_rotator.py ā core/evaluators/rotator.py} +4 -4
- wisent/core/lm_eval_harness_ground_truth.py +3 -2
- wisent/core/main.py +57 -0
- wisent/core/model_persistence.py +2 -2
- wisent/core/models/wisent_model.py +6 -6
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/optuna/steering/steering_optimization.py +1 -1
- wisent/core/parser_arguments/__init__.py +10 -0
- wisent/core/parser_arguments/agent_parser.py +110 -0
- wisent/core/parser_arguments/configure_model_parser.py +7 -0
- wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
- wisent/core/parser_arguments/evaluate_parser.py +40 -0
- wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
- wisent/core/parser_arguments/full_optimize_parser.py +115 -0
- wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
- wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
- wisent/core/parser_arguments/generate_responses_parser.py +15 -0
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_parser.py +90 -0
- wisent/core/parser_arguments/get_activations_parser.py +90 -0
- wisent/core/parser_arguments/main_parser.py +152 -0
- wisent/core/parser_arguments/model_config_parser.py +59 -0
- wisent/core/parser_arguments/monitor_parser.py +17 -0
- wisent/core/parser_arguments/multi_steer_parser.py +47 -0
- wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
- wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
- wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
- wisent/core/parser_arguments/synthetic_parser.py +93 -0
- wisent/core/parser_arguments/tasks_parser.py +584 -0
- wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
- wisent/core/parser_arguments/utils.py +111 -0
- wisent/core/prompts/core/prompt_formater.py +3 -3
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
- wisent/{cli/steering_methods/steering_rotator.py ā core/steering_methods/rotator.py} +4 -4
- wisent/core/steering_optimizer.py +45 -21
- wisent/{synthetic ā core/synthetic}/cleaners/deduper_cleaner.py +3 -3
- wisent/{synthetic ā core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
- wisent/{synthetic ā core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
- wisent/{synthetic ā core/synthetic}/cleaners/pairs_cleaner.py +5 -5
- wisent/{synthetic ā core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
- wisent/{synthetic ā core/synthetic}/db_instructions/mini_dp.py +1 -1
- wisent/{synthetic ā core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
- wisent/{synthetic ā core/synthetic}/generators/pairs_generator.py +38 -12
- wisent/core/tasks/livecodebench_task.py +4 -103
- wisent/core/timing_calibration.py +1 -1
- {wisent-0.5.12.dist-info ā wisent-0.5.14.dist-info}/METADATA +3 -3
- wisent-0.5.14.dist-info/RECORD +294 -0
- wisent-0.5.14.dist-info/entry_points.txt +2 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
- wisent/classifiers/core/atoms.py +0 -747
- wisent/classifiers/models/logistic.py +0 -29
- wisent/classifiers/models/mlp.py +0 -47
- wisent/cli/classifiers/classifier_rotator.py +0 -137
- wisent/cli/cli_logger.py +0 -142
- wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
- wisent/cli/wisent_cli/commands/listing.py +0 -154
- wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
- wisent/cli/wisent_cli/main.py +0 -93
- wisent/cli/wisent_cli/shell.py +0 -80
- wisent/cli/wisent_cli/ui.py +0 -69
- wisent/cli/wisent_cli/util/aggregations.py +0 -43
- wisent/cli/wisent_cli/util/parsing.py +0 -126
- wisent/cli/wisent_cli/version.py +0 -4
- wisent/opti/methods/__init__.py +0 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent-0.5.12.dist-info/RECORD +0 -220
- /wisent/{benchmarks ā core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
- /wisent/{benchmarks/coding ā core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
- /wisent/{benchmarks/coding/metrics ā core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
- /wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
- /wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
- /wisent/{benchmarks/coding/metrics/core ā core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer ā core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
- /wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
- /wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
- /wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer ā core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
- /wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/providers/core ā core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
- /wisent/{benchmarks/coding/providers/livecodebench ā core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
- /wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/safe_docker ā core/opti}/__init__.py +0 -0
- /wisent/{benchmarks/coding/safe_docker ā core/opti}/core/__init__.py +0 -0
- /wisent/{opti ā core/opti}/core/atoms.py +0 -0
- /wisent/{classifiers ā core/opti/methods}/__init__.py +0 -0
- /wisent/{opti ā core/opti}/methods/opti_classificator.py +0 -0
- /wisent/{opti ā core/opti}/methods/opti_steering.py +0 -0
- /wisent/{classifiers/core ā core/synthetic}/__init__.py +0 -0
- /wisent/{classifiers/models ā core/synthetic/cleaners}/__init__.py +0 -0
- /wisent/{cli ā core/synthetic/cleaners/core}/__init__.py +0 -0
- /wisent/{synthetic ā core/synthetic}/cleaners/core/atoms.py +0 -0
- /wisent/{cli/classifiers ā core/synthetic/cleaners/methods}/__init__.py +0 -0
- /wisent/{cli/data_loaders ā core/synthetic/cleaners/methods/core}/__init__.py +0 -0
- /wisent/{synthetic ā core/synthetic}/cleaners/methods/core/atoms.py +0 -0
- /wisent/{cli/evaluators ā core/synthetic/db_instructions}/__init__.py +0 -0
- /wisent/{cli/steering_methods ā core/synthetic/db_instructions/core}/__init__.py +0 -0
- /wisent/{synthetic ā core/synthetic}/db_instructions/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli ā core/synthetic/generators}/__init__.py +0 -0
- /wisent/{cli/wisent_cli/commands ā core/synthetic/generators/core}/__init__.py +0 -0
- /wisent/{synthetic ā core/synthetic}/generators/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli/util ā core/synthetic/generators/diversities}/__init__.py +0 -0
- /wisent/{opti ā core/synthetic/generators/diversities/core}/__init__.py +0 -0
- /wisent/{synthetic ā core/synthetic}/generators/diversities/core/core.py +0 -0
- /wisent/{opti/core ā core/synthetic/generators/diversities/methods}/__init__.py +0 -0
- {wisent-0.5.12.dist-info ā wisent-0.5.14.dist-info}/WHEEL +0 -0
- {wisent-0.5.12.dist-info ā wisent-0.5.14.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.5.12.dist-info ā wisent-0.5.14.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""Generate pairs command execution logic - synthetic generation."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def execute_generate_pairs(args):
    """Execute the generate-pairs command - generate synthetic contrastive pairs from trait description.

    Pipeline: load the model, build a sampling config whose token budget
    scales with the number of requested pairs, set up a cleaning pipeline
    (a ``DeduperCleaner`` backed by ``SimHashDeduper(threshold_bits=3)``), run
    ``SyntheticContrastivePairsGenerator`` and write the pairs plus generation
    and diversity statistics to ``args.output`` as JSON.

    Args:
        args: Parsed CLI namespace. Reads ``trait``, ``num_pairs``, ``model``,
            ``device``, ``timing``, ``verbose`` and ``output``.

    Any exception raised inside the pipeline is reported on stderr (with a
    traceback when ``args.verbose`` is set) and the process exits with status 1.
    """
    import time

    from wisent.core.models.wisent_model import WisentModel
    from wisent.core.synthetic.generators.pairs_generator import SyntheticContrastivePairsGenerator
    from wisent.core.synthetic.db_instructions.mini_dp import Default_DB_Instructions
    from wisent.core.synthetic.cleaners.pairs_cleaner import PairsCleaner
    from wisent.core.synthetic.cleaners.deduper_cleaner import DeduperCleaner
    from wisent.core.synthetic.generators.diversities.methods.fast_diversity import FastDiversity

    print("\n🎨 Generating synthetic contrastive pairs")
    print(f" Trait: {args.trait}")
    print(f" Number of pairs: {args.num_pairs}")
    print(f" Model: {args.model}")

    try:
        # 1. Load model
        print(f"\n🤖 Loading model '{args.model}'...")
        model = WisentModel(args.model, device=args.device)
        print(f" ✓ Model loaded with {model.num_layers} layers")

        # 2. Set up generation config.
        # Scale max_new_tokens with the number of pairs (~150 tokens per pair
        # plus a 500-token buffer), clamped to the range [2048, 8192].
        estimated_tokens = args.num_pairs * 150 + 500
        max_tokens = max(2048, min(estimated_tokens, 8192))

        generation_config = {
            "max_new_tokens": max_tokens,
            "temperature": 0.9,
            "do_sample": True,
        }

        # 3. Set up cleaning pipeline (near-duplicate removal only)
        print("\n🧹 Setting up cleaning pipeline...")
        from wisent.core.synthetic.cleaners.methods.base_dedupers import SimHashDeduper

        cleaning_steps = [
            DeduperCleaner(deduper=SimHashDeduper(threshold_bits=3)),
        ]
        cleaner = PairsCleaner(steps=cleaning_steps)

        # 4. Set up components
        db_instructions = Default_DB_Instructions()
        diversity = FastDiversity()

        # 5. Create generator
        print("\n⚙️ Initializing generator...")
        generator = SyntheticContrastivePairsGenerator(
            model=model,
            generation_config=generation_config,
            contrastive_set_name=f"synthetic_{args.trait[:20].replace(' ', '_')}",
            trait_description=args.trait,
            trait_label=args.trait[:50],
            db_instructions=db_instructions,
            cleaner=cleaner,
            diversity=diversity,
        )

        # 6. Generate pairs (optionally timed)
        print(f"\n🎯 Generating {args.num_pairs} contrastive pairs...")
        start_time = time.time() if args.timing else None

        pair_set, report = generator.generate(num_pairs=args.num_pairs)

        if start_time is not None:
            elapsed = time.time() - start_time
            print(f" ⏱️ Generation time: {elapsed:.2f}s")

        print(f" ✓ Generated {len(pair_set.pairs)} pairs")

        # 7. Print generation report
        if args.verbose and len(pair_set.pairs) > 0:
            print("\n📊 Generation Report:")
            print(f" Requested: {report.requested}")
            print(f" Kept after dedupe: {report.kept_after_dedupe}")
            print(f" Retries for refusals: {report.retries_for_refusals}")
            if report.diversity:
                print(" Diversity:")
                print(f" • Unique unigrams: {report.diversity.unique_unigrams:.3f}")
                print(f" • Unique bigrams: {report.diversity.unique_bigrams:.3f}")
                print(f" • Avg Jaccard: {report.diversity.avg_jaccard_prompt:.3f}")

        # 8. Convert pairs to dict format for JSON serialization
        print(f"\n💾 Saving pairs to '{args.output}'...")
        pairs_data = [pair.to_dict() for pair in pair_set.pairs]

        # 9. Save to JSON file. abspath keeps dirname non-empty for bare
        # filenames, so os.makedirs always receives a real directory.
        os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
        save_data = {
            'trait_description': args.trait,
            'trait_label': pair_set.task_type,
            'num_pairs': len(pairs_data),
            'generation_config': generation_config,
            'requested': report.requested,
            'kept_after_dedupe': report.kept_after_dedupe,
            'pairs': pairs_data,
        }
        if report.diversity:
            save_data['diversity'] = {
                'unique_unigrams': report.diversity.unique_unigrams,
                'unique_bigrams': report.diversity.unique_bigrams,
                'avg_jaccard': report.diversity.avg_jaccard_prompt,
            }

        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2)

        print(f" ✓ Saved {len(pairs_data)} pairs to: {args.output}")
        print("\n✅ Synthetic pair generation completed successfully!\n")

    except Exception as e:
        print(f"\n❌ Error: {str(e)}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""Generate pairs from task command execution logic."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _load_custom_task(task_name: str, limit: int | None):
|
|
9
|
+
"""Load custom tasks that aren't in lm-eval."""
|
|
10
|
+
if task_name == "livecodebench":
|
|
11
|
+
from wisent.core.tasks.livecodebench_task import LiveCodeBenchTask
|
|
12
|
+
return LiveCodeBenchTask(release_version="release_v1", limit=limit)
|
|
13
|
+
else:
|
|
14
|
+
raise ValueError(
|
|
15
|
+
f"Task '{task_name}' not found in lm-eval or custom tasks. "
|
|
16
|
+
f"Available custom tasks: livecodebench"
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _build_pairs_from_custom_task(task, limit: int | None):
|
|
21
|
+
"""Build contrastive pairs from custom TaskInterface tasks."""
|
|
22
|
+
from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_extractors.livecodebench import (
|
|
23
|
+
LiveCodeBenchExtractor as LiveCodeBenchPairExtractor
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
task_name = task.get_name()
|
|
27
|
+
|
|
28
|
+
if task_name == "livecodebench":
|
|
29
|
+
# Use the contrastive pair extractor for LiveCodeBench
|
|
30
|
+
extractor = LiveCodeBenchPairExtractor()
|
|
31
|
+
# Extract pairs using the task's test_docs interface
|
|
32
|
+
return extractor.extract_contrastive_pairs(task, limit=limit)
|
|
33
|
+
else:
|
|
34
|
+
raise ValueError(f"No contrastive pair extractor configured for custom task: {task_name}")
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def execute_generate_pairs_from_task(args):
    """Execute the generate-pairs-from-task command - load and save contrastive pairs from a task.

    Loads ``args.task_name`` through lm-eval. If lm-eval raises ``KeyError``,
    it falls back to the custom tasks handled by ``_load_custom_task``. It then
    builds contrastive pairs and writes them to ``args.output`` as JSON with
    the keys ``task_name``, ``num_pairs`` and ``pairs``.

    Args:
        args: Parsed CLI namespace. Reads ``task_name``, ``limit``, ``output``
            and ``verbose``.

    Any exception is reported on stderr (with a traceback when
    ``args.verbose`` is set) and the process exits with status 1.
    """
    from wisent.core.data_loaders.loaders.lm_loader import LMEvalDataLoader

    print(f"\n📋 Generating contrastive pairs from task: {args.task_name}")

    if args.limit:
        print(f" Limit: {args.limit} pairs")

    try:
        # 1. Load task data using LMEvalDataLoader
        print(f"\n📚 Loading task '{args.task_name}'...")

        # Try lm-eval first. A KeyError means lm-eval does not know the task,
        # so fall back to our custom tasks.
        loader = LMEvalDataLoader()
        try:
            task_obj = loader.load_lm_eval_task(args.task_name)
        except KeyError:
            print(" ℹ️ Task not found in lm-eval, trying custom tasks...")
            task_obj = _load_custom_task(args.task_name, args.limit)

        # Import the pair generation function
        from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import (
            lm_build_contrastive_pairs,
        )

        # 2. Generate contrastive pairs.
        # lm-eval tasks arrive as a dict keyed by (sub)task name; custom tasks
        # are TaskInterface objects.
        if isinstance(task_obj, dict):
            if len(task_obj) != 1:
                keys = ", ".join(sorted(task_obj.keys()))
                raise ValueError(
                    f"Task '{args.task_name}' returned {len(task_obj)} subtasks ({keys}). "
                    "Specify an explicit subtask, e.g. 'benchmark/subtask'."
                )
            (pairs_task_name, task), = task_obj.items()

            print(" 🔨 Building contrastive pairs...")
            pairs = lm_build_contrastive_pairs(
                task_name=pairs_task_name,
                lm_eval_task=task,
                limit=args.limit,
            )
        else:
            pairs_task_name = args.task_name

            print(" 🔨 Building contrastive pairs...")
            pairs = _build_pairs_from_custom_task(task_obj, args.limit)

        print(f" ✓ Generated {len(pairs)} contrastive pairs")

        # 3. Convert pairs to dict format for JSON serialization
        print(f"\n💾 Saving pairs to '{args.output}'...")
        pairs_data = [pair.to_dict() for pair in pairs]

        # 4. Save to JSON file. abspath keeps dirname non-empty for bare
        # filenames, so os.makedirs always receives a real directory.
        os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump({
                'task_name': pairs_task_name,
                'num_pairs': len(pairs),
                'pairs': pairs_data,
            }, f, indent=2)

        print(f" ✓ Saved {len(pairs)} pairs to: {args.output}")
        print("\n✅ Contrastive pairs generation completed successfully!\n")

    except Exception as e:
        print(f"\n❌ Error: {str(e)}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""Generate responses command execution logic."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def execute_generate_responses(args):
    """
    Execute the generate-responses command.

    Generates model responses to questions from a task and saves them to a file.

    Args:
        args: Parsed CLI namespace. Reads ``task``, ``model``,
            ``num_questions``, ``device``, ``verbose``, ``max_new_tokens``,
            ``temperature``, ``top_p``, ``use_steering`` and ``output``.

    Failure behavior:
        - If the task data cannot be loaded, the error goes to stderr and the
          process exits with status 1.
        - If a single question fails, that entry is written with
          ``generated_response`` = None plus an ``error`` field, and
          generation continues with the next question.
    """
    from wisent.core.models.wisent_model import WisentModel
    from wisent.core.data_loaders.loaders.lm_loader import LMEvalDataLoader

    print(f"\n{'='*80}")
    print("🎯 GENERATING RESPONSES FROM TASK")
    print(f"{'='*80}")
    print(f" Task: {args.task}")
    print(f" Model: {args.model}")
    print(f" Num questions: {args.num_questions}")
    print(f" Device: {args.device or 'auto'}")
    print(f"{'='*80}\n")

    # Load model
    print("📦 Loading model...")
    model = WisentModel(args.model, device=args.device)
    print(" ✓ Model loaded\n")

    # Load task data
    print("📚 Loading task data...")
    loader = LMEvalDataLoader()
    try:
        # NOTE(review): if `limit` caps the documents before the 80/20 split,
        # the test split used below may hold fewer than num_questions pairs.
        # Confirm against _load_one_task.
        result = loader._load_one_task(
            task_name=args.task,
            split_ratio=0.8,
            seed=42,
            limit=args.num_questions,
            training_limit=None,
            testing_limit=None,
        )

        # Use test pairs for generation
        pairs = result['test_qa_pairs'].pairs[:args.num_questions]
        print(f" ✓ Loaded {len(pairs)} question pairs\n")

    except Exception as e:
        print(f" ❌ Failed to load task: {e}", file=sys.stderr)
        sys.exit(1)

    # Generate responses
    print("🤖 Generating responses...\n")
    results = []

    for idx, pair in enumerate(pairs, 1):
        if args.verbose:
            print(f"Question {idx}/{len(pairs)}:")
            print(f" Prompt: {pair.prompt[:100]}...")

        try:
            # Wrap the prompt as a single-turn chat conversation
            messages = [{"role": "user", "content": pair.prompt}]

            responses = model.generate(
                inputs=[messages],
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                do_sample=True,
                use_steering=args.use_steering,
            )
            generated_text = responses[0] if responses else ""

            if args.verbose:
                print(f" Generated: {generated_text[:100]}...")
                print()

            results.append({
                "question_id": idx,
                "prompt": pair.prompt,
                "generated_response": generated_text,
                "positive_reference": pair.positive_response.model_response,
                "negative_reference": pair.negative_response.model_response,
            })

        except Exception as e:
            # Record the failure and continue with the remaining questions.
            print(f" ❌ Error generating response for question {idx}: {e}")
            results.append({
                "question_id": idx,
                "prompt": pair.prompt,
                "generated_response": None,
                "error": str(e),
            })

    # Save results. abspath keeps dirname non-empty for bare filenames such as
    # "responses.json"; os.makedirs("") would raise FileNotFoundError.
    print("\n💾 Saving results...")
    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)

    output_data = {
        "task": args.task,
        "model": args.model,
        "num_questions": len(pairs),
        "generation_params": {
            "max_new_tokens": args.max_new_tokens,
            "temperature": args.temperature,
            "top_p": args.top_p,
            "use_steering": args.use_steering,
        },
        "responses": results,
    }

    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2)

    print(f" ✓ Results saved to: {args.output}\n")

    # Print summary
    num_failed = sum(1 for r in results if 'error' in r)
    print(f"{'='*80}")
    print("✅ GENERATION COMPLETE")
    print(f"{'='*80}")
    print(f" Total questions: {len(results)}")
    print(f" Successful: {len(results) - num_failed}")
    print(f" Failed: {num_failed}")
    print(f"{'='*80}\n")
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""Generate steering vector from synthetic pairs command execution logic - unified pipeline."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
import tempfile
|
|
7
|
+
from argparse import Namespace
|
|
8
|
+
|
|
9
|
+
from wisent.core.cli.generate_pairs import execute_generate_pairs
|
|
10
|
+
from wisent.core.cli.get_activations import execute_get_activations
|
|
11
|
+
from wisent.core.cli.create_steering_vector import execute_create_steering_vector
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _make_temp_path(suffix):
    """Reserve a unique temporary file path and return it.

    ``tempfile.mkstemp`` creates the file and returns an open OS-level
    descriptor; it is closed immediately so no handle leaks. Only the path is
    needed: it is passed as an ``output`` location to the pipeline steps
    (presumably they open/overwrite it themselves).
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    return path


def _remove_paths(*paths):
    """Best-effort deletion of intermediate files.

    Every path is attempted even if an earlier one fails. ``None`` entries and
    files that no longer exist are skipped silently.

    Returns:
        list[str]: one ``"<path>: <error>"`` entry per failed removal (empty
        when nothing went wrong).
    """
    failures = []
    for path in paths:
        if path is None:
            continue
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as err:
            failures.append(f"{path}: {err}")
    return failures


def _safe_filename_part(text):
    """Replace path separators so *text* stays a single file-name component.

    Without this, a value such as ``"honest/harmless"`` would point the
    intermediate file into a sub-directory that does not exist.
    """
    for sep in ("/", "\\"):
        text = text.replace(sep, "_")
    return text


def execute_generate_vector_from_synthetic(args):
    """
    Execute the generate-vector-from-synthetic command - full pipeline in one command.

    Pipeline:
    1. Generate synthetic contrastive pairs for a trait
    2. Collect activations from those pairs
    3. Create steering vectors from the activations

    Args:
        args: Parsed CLI namespace. Reads ``trait``, ``model``, ``num_pairs``,
            ``output``, ``device``, ``similarity_threshold``, ``layers``,
            ``token_aggregation``, ``prompt_strategy``, ``method``,
            ``normalize``, ``intermediate_dir``, ``keep_intermediate``,
            ``verbose`` and ``timing``.

    Any exception raised by a step is reported on stderr and the process exits
    with status 1 via ``sys.exit``. Unless ``keep_intermediate`` is set, the
    temporary intermediate files are removed on every exit path.
    """
    print(f"\n{'='*60}")
    print("🎯 Generating Steering Vector from Synthetic Pairs (Full Pipeline)")
    print(f"{'='*60}")
    print(f" Trait: {args.trait}")
    print(f" Model: {args.model}")
    print(f" Num Pairs: {args.num_pairs}")
    print(f"{'='*60}\n")

    pipeline_start = time.time() if args.timing else None

    # Bound before the ``try`` so the error handler and the ``finally`` safety
    # net can reference them without probing ``locals()``.
    pairs_file = None
    enriched_file = None

    try:
        # Intermediate files live next to the output unless a directory is
        # given; creating it also ensures the output's parent dir exists in
        # the default case.
        if args.intermediate_dir:
            intermediate_dir = args.intermediate_dir
        else:
            intermediate_dir = os.path.dirname(os.path.abspath(args.output))

        os.makedirs(intermediate_dir, exist_ok=True)

        if args.keep_intermediate:
            stem = _safe_filename_part(args.trait.replace(' ', '_'))
            pairs_file = os.path.join(intermediate_dir, f"{stem}_pairs.json")
            enriched_file = os.path.join(intermediate_dir, f"{stem}_pairs_with_activations.json")
        else:
            # Temporary files (system temp dir) that are deleted afterwards.
            pairs_file = _make_temp_path('_pairs.json')
            enriched_file = _make_temp_path('_enriched.json')

        # Step 1: Generate synthetic pairs
        print(f"{'='*60}")
        print("Step 1/3: Generating synthetic contrastive pairs...")
        print(f"{'='*60}\n")

        pairs_args = Namespace(
            trait=args.trait,
            num_pairs=args.num_pairs,
            output=pairs_file,
            model=args.model,
            device=args.device,
            similarity_threshold=args.similarity_threshold,
            verbose=args.verbose,
            timing=args.timing,
        )

        execute_generate_pairs(pairs_args)
        print(f"\n✓ Step 1 complete: Pairs saved to {pairs_file}\n")

        # Step 2: Collect activations
        print(f"{'='*60}")
        print("Step 2/3: Collecting activations from pairs...")
        print(f"{'='*60}\n")

        activations_args = Namespace(
            pairs_file=pairs_file,
            output=enriched_file,
            model=args.model,
            device=args.device,
            layers=args.layers,
            token_aggregation=args.token_aggregation,
            prompt_strategy=args.prompt_strategy,
            verbose=args.verbose,
            timing=args.timing,
        )

        execute_get_activations(activations_args)
        print(f"\n✓ Step 2 complete: Enriched pairs saved to {enriched_file}\n")

        # Step 3: Create steering vector
        print(f"{'='*60}")
        print("Step 3/3: Creating steering vector...")
        print(f"{'='*60}\n")

        vector_args = Namespace(
            enriched_pairs_file=enriched_file,
            output=args.output,
            method=args.method,
            normalize=args.normalize,
            verbose=args.verbose,
            timing=args.timing,
        )

        execute_create_steering_vector(vector_args)
        print(f"\n✓ Step 3 complete: Steering vector saved to {args.output}\n")

        # Clean up intermediate files if not keeping them
        if not args.keep_intermediate:
            if args.verbose:
                print("\n🧹 Cleaning up intermediate files...")
            failures = _remove_paths(pairs_file, enriched_file)
            if args.verbose:
                if failures:
                    print(f" ⚠️ Warning: Could not remove some temporary files: {'; '.join(failures)}")
                else:
                    print(" ✓ Removed temporary files")

        # Final summary
        print(f"\n{'='*60}")
        print("✅ Full Pipeline Completed Successfully!")
        print(f"{'='*60}")
        print(f" Final steering vector: {args.output}")
        if args.keep_intermediate:
            print(f" Intermediate pairs: {pairs_file}")
            print(f" Intermediate enriched: {enriched_file}")
        if pipeline_start is not None:
            total_time = time.time() - pipeline_start
            print(f" ⏱️ Total pipeline time: {total_time:.2f}s")
        print(f"{'='*60}\n")

    except Exception as e:
        print(f"\n❌ Pipeline failed: {e}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)

    finally:
        # Safety net: covers the error path above and also SystemExit /
        # KeyboardInterrupt escaping a step, which bypass ``except Exception``.
        # On success the files are already gone, so this is a no-op.
        if not args.keep_intermediate:
            _remove_paths(pairs_file, enriched_file)
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""Generate steering vector from task command execution logic - unified pipeline."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
import tempfile
|
|
7
|
+
from argparse import Namespace
|
|
8
|
+
|
|
9
|
+
from wisent.core.cli.generate_pairs_from_task import execute_generate_pairs_from_task
|
|
10
|
+
from wisent.core.cli.get_activations import execute_get_activations
|
|
11
|
+
from wisent.core.cli.create_steering_vector import execute_create_steering_vector
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _make_temp_path(suffix):
    """Reserve a unique temporary file path and return it.

    ``tempfile.mkstemp`` creates the file and returns an open OS-level
    descriptor; it is closed immediately so no handle leaks. Only the path is
    needed: it is passed as an ``output`` location to the pipeline steps
    (presumably they open/overwrite it themselves).
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    return path


def _remove_paths(*paths):
    """Best-effort deletion of intermediate files.

    Every path is attempted even if an earlier one fails. ``None`` entries and
    files that no longer exist are skipped silently.

    Returns:
        list[str]: one ``"<path>: <error>"`` entry per failed removal (empty
        when nothing went wrong).
    """
    failures = []
    for path in paths:
        if path is None:
            continue
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as err:
            failures.append(f"{path}: {err}")
    return failures


def _safe_filename_part(text):
    """Replace path separators so *text* stays a single file-name component.

    Without this, a value containing ``/`` would point the intermediate file
    into a sub-directory that does not exist.
    """
    for sep in ("/", "\\"):
        text = text.replace(sep, "_")
    return text


def execute_generate_vector_from_task(args):
    """
    Execute the generate-vector-from-task command - full pipeline in one command.

    Pipeline:
    1. Generate contrastive pairs from lm-eval task
    2. Collect activations from those pairs
    3. Create steering vectors from the activations

    Args:
        args: Parsed CLI namespace. Reads ``task``, ``trait_label``, ``model``,
            ``num_pairs``, ``output``, ``device``, ``layers``,
            ``token_aggregation``, ``prompt_strategy``, ``method``,
            ``normalize``, ``intermediate_dir``, ``keep_intermediate``,
            ``verbose``, ``timing`` and, when present, ``seed``.

    Any exception raised by a step is reported on stderr and the process exits
    with status 1 via ``sys.exit``. Unless ``keep_intermediate`` is set, the
    temporary intermediate files are removed on every exit path.
    """
    print(f"\n{'='*60}")
    print("🎯 Generating Steering Vector from Task (Full Pipeline)")
    print(f"{'='*60}")
    print(f" Task: {args.task}")
    print(f" Trait Label: {args.trait_label}")
    print(f" Model: {args.model}")
    print(f" Num Pairs: {args.num_pairs}")
    print(f"{'='*60}\n")

    pipeline_start = time.time() if args.timing else None

    # Pair sampling used a fixed seed of 42; keep that default but honour an
    # explicit ``seed`` attribute if the CLI parser provides one.
    seed = getattr(args, "seed", None)
    if seed is None:
        seed = 42

    # Bound before the ``try`` so the error handler and the ``finally`` safety
    # net can reference them without probing ``locals()``.
    pairs_file = None
    enriched_file = None

    try:
        # Intermediate files live next to the output unless a directory is
        # given; creating it also ensures the output's parent dir exists in
        # the default case.
        if args.intermediate_dir:
            intermediate_dir = args.intermediate_dir
        else:
            intermediate_dir = os.path.dirname(os.path.abspath(args.output))

        os.makedirs(intermediate_dir, exist_ok=True)

        if args.keep_intermediate:
            stem = _safe_filename_part(f"{args.task}_{args.trait_label}")
            pairs_file = os.path.join(intermediate_dir, f"{stem}_pairs.json")
            enriched_file = os.path.join(intermediate_dir, f"{stem}_pairs_with_activations.json")
        else:
            # Temporary files (system temp dir) that are deleted afterwards.
            pairs_file = _make_temp_path('_pairs.json')
            enriched_file = _make_temp_path('_enriched.json')

        # Step 1: Generate pairs from task
        print(f"{'='*60}")
        print("Step 1/3: Generating contrastive pairs from task...")
        print(f"{'='*60}\n")

        pairs_args = Namespace(
            task_name=args.task,
            limit=args.num_pairs,
            output=pairs_file,
            seed=seed,
            verbose=args.verbose,
        )

        execute_generate_pairs_from_task(pairs_args)
        print(f"\n✓ Step 1 complete: Pairs saved to {pairs_file}\n")

        # Step 2: Collect activations
        print(f"{'='*60}")
        print("Step 2/3: Collecting activations from pairs...")
        print(f"{'='*60}\n")

        activations_args = Namespace(
            pairs_file=pairs_file,
            output=enriched_file,
            model=args.model,
            device=args.device,
            layers=args.layers,
            token_aggregation=args.token_aggregation,
            prompt_strategy=args.prompt_strategy,
            verbose=args.verbose,
            timing=args.timing,
        )

        execute_get_activations(activations_args)
        print(f"\n✓ Step 2 complete: Enriched pairs saved to {enriched_file}\n")

        # Step 3: Create steering vector
        print(f"{'='*60}")
        print("Step 3/3: Creating steering vector...")
        print(f"{'='*60}\n")

        vector_args = Namespace(
            enriched_pairs_file=enriched_file,
            output=args.output,
            method=args.method,
            normalize=args.normalize,
            verbose=args.verbose,
            timing=args.timing,
        )

        execute_create_steering_vector(vector_args)
        print(f"\n✓ Step 3 complete: Steering vector saved to {args.output}\n")

        # Clean up intermediate files if not keeping them
        if not args.keep_intermediate:
            if args.verbose:
                print("\n🧹 Cleaning up intermediate files...")
            failures = _remove_paths(pairs_file, enriched_file)
            if args.verbose:
                if failures:
                    print(f" ⚠️ Warning: Could not remove some temporary files: {'; '.join(failures)}")
                else:
                    print(" ✓ Removed temporary files")

        # Final summary
        print(f"\n{'='*60}")
        print("✅ Full Pipeline Completed Successfully!")
        print(f"{'='*60}")
        print(f" Final steering vector: {args.output}")
        if args.keep_intermediate:
            print(f" Intermediate pairs: {pairs_file}")
            print(f" Intermediate enriched: {enriched_file}")
        if pipeline_start is not None:
            total_time = time.time() - pipeline_start
            print(f" ⏱️ Total pipeline time: {total_time:.2f}s")
        print(f"{'='*60}\n")

    except Exception as e:
        print(f"\n❌ Pipeline failed: {e}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)

    finally:
        # Safety net: covers the error path above and also SystemExit /
        # KeyboardInterrupt escaping a step, which bypass ``except Exception``.
        # On success the files are already gone, so this is a no-op.
        if not args.keep_intermediate:
            _remove_paths(pairs_file, enriched_file)
|