wisent 0.5.12__py3-none-any.whl → 0.5.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -1
- wisent/core/activations/__init__.py +26 -0
- wisent/core/activations/activations.py +96 -0
- wisent/core/activations/activations_collector.py +71 -20
- wisent/core/activations/prompt_construction_strategy.py +47 -0
- wisent/core/agent/__init__.py +1 -18
- wisent/core/agent/budget.py +2 -2
- wisent/core/agent/device_benchmarks.py +1 -1
- wisent/core/agent/diagnose/__init__.py +1 -55
- wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
- wisent/core/agent/diagnose/response_diagnostics.py +4 -4
- wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
- wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
- wisent/core/agent/diagnose.py +2 -1
- wisent/core/autonomous_agent.py +10 -2
- wisent/core/benchmark_extractors.py +293 -0
- wisent/core/bigcode_integration.py +20 -7
- wisent/core/branding.py +108 -0
- wisent/core/cli/__init__.py +15 -0
- wisent/core/cli/create_steering_vector.py +138 -0
- wisent/core/cli/evaluate_responses.py +715 -0
- wisent/core/cli/generate_pairs.py +128 -0
- wisent/core/cli/generate_pairs_from_task.py +119 -0
- wisent/core/cli/generate_responses.py +129 -0
- wisent/core/cli/generate_vector_from_synthetic.py +149 -0
- wisent/core/cli/generate_vector_from_task.py +147 -0
- wisent/core/cli/get_activations.py +191 -0
- wisent/core/cli/optimize_classification.py +339 -0
- wisent/core/cli/optimize_steering.py +364 -0
- wisent/core/cli/tasks.py +182 -0
- wisent/core/cli_logger.py +22 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
- wisent/core/data_loaders/__init__.py +235 -0
- wisent/core/data_loaders/loaders/lm_loader.py +2 -2
- wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
- wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
- wisent/core/download_full_benchmarks.py +79 -2
- wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
- wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
- wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
- wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
- wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
- wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
- wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
- wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
- wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
- wisent/core/lm_eval_harness_ground_truth.py +3 -2
- wisent/core/main.py +57 -0
- wisent/core/model_persistence.py +2 -2
- wisent/core/models/wisent_model.py +6 -6
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/optuna/steering/steering_optimization.py +1 -1
- wisent/core/parser_arguments/__init__.py +10 -0
- wisent/core/parser_arguments/agent_parser.py +110 -0
- wisent/core/parser_arguments/configure_model_parser.py +7 -0
- wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
- wisent/core/parser_arguments/evaluate_parser.py +40 -0
- wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
- wisent/core/parser_arguments/full_optimize_parser.py +115 -0
- wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
- wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
- wisent/core/parser_arguments/generate_responses_parser.py +15 -0
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_parser.py +90 -0
- wisent/core/parser_arguments/get_activations_parser.py +90 -0
- wisent/core/parser_arguments/main_parser.py +152 -0
- wisent/core/parser_arguments/model_config_parser.py +59 -0
- wisent/core/parser_arguments/monitor_parser.py +17 -0
- wisent/core/parser_arguments/multi_steer_parser.py +47 -0
- wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
- wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
- wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
- wisent/core/parser_arguments/synthetic_parser.py +93 -0
- wisent/core/parser_arguments/tasks_parser.py +584 -0
- wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
- wisent/core/parser_arguments/utils.py +111 -0
- wisent/core/prompts/core/prompt_formater.py +3 -3
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
- wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
- wisent/core/steering_optimizer.py +45 -21
- wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
- wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
- wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
- wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
- wisent/core/tasks/livecodebench_task.py +4 -103
- wisent/core/timing_calibration.py +1 -1
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/METADATA +3 -3
- wisent-0.5.14.dist-info/RECORD +294 -0
- wisent-0.5.14.dist-info/entry_points.txt +2 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
- wisent/classifiers/core/atoms.py +0 -747
- wisent/classifiers/models/logistic.py +0 -29
- wisent/classifiers/models/mlp.py +0 -47
- wisent/cli/classifiers/classifier_rotator.py +0 -137
- wisent/cli/cli_logger.py +0 -142
- wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
- wisent/cli/wisent_cli/commands/listing.py +0 -154
- wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
- wisent/cli/wisent_cli/main.py +0 -93
- wisent/cli/wisent_cli/shell.py +0 -80
- wisent/cli/wisent_cli/ui.py +0 -69
- wisent/cli/wisent_cli/util/aggregations.py +0 -43
- wisent/cli/wisent_cli/util/parsing.py +0 -126
- wisent/cli/wisent_cli/version.py +0 -4
- wisent/opti/methods/__init__.py +0 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent-0.5.12.dist-info/RECORD +0 -220
- /wisent/{benchmarks ā core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
- /wisent/{benchmarks/coding ā core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
- /wisent/{benchmarks/coding/metrics ā core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
- /wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
- /wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
- /wisent/{benchmarks/coding/metrics/core ā core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer ā core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
- /wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
- /wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
- /wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer ā core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
- /wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/providers/core ā core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
- /wisent/{benchmarks/coding/providers/livecodebench ā core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
- /wisent/{benchmarks ā core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/safe_docker ā core/opti}/__init__.py +0 -0
- /wisent/{benchmarks/coding/safe_docker ā core/opti}/core/__init__.py +0 -0
- /wisent/{opti ā core/opti}/core/atoms.py +0 -0
- /wisent/{classifiers ā core/opti/methods}/__init__.py +0 -0
- /wisent/{opti ā core/opti}/methods/opti_classificator.py +0 -0
- /wisent/{opti ā core/opti}/methods/opti_steering.py +0 -0
- /wisent/{classifiers/core ā core/synthetic}/__init__.py +0 -0
- /wisent/{classifiers/models ā core/synthetic/cleaners}/__init__.py +0 -0
- /wisent/{cli ā core/synthetic/cleaners/core}/__init__.py +0 -0
- /wisent/{synthetic ā core/synthetic}/cleaners/core/atoms.py +0 -0
- /wisent/{cli/classifiers ā core/synthetic/cleaners/methods}/__init__.py +0 -0
- /wisent/{cli/data_loaders ā core/synthetic/cleaners/methods/core}/__init__.py +0 -0
- /wisent/{synthetic ā core/synthetic}/cleaners/methods/core/atoms.py +0 -0
- /wisent/{cli/evaluators ā core/synthetic/db_instructions}/__init__.py +0 -0
- /wisent/{cli/steering_methods ā core/synthetic/db_instructions/core}/__init__.py +0 -0
- /wisent/{synthetic ā core/synthetic}/db_instructions/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli ā core/synthetic/generators}/__init__.py +0 -0
- /wisent/{cli/wisent_cli/commands ā core/synthetic/generators/core}/__init__.py +0 -0
- /wisent/{synthetic ā core/synthetic}/generators/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli/util ā core/synthetic/generators/diversities}/__init__.py +0 -0
- /wisent/{opti ā core/synthetic/generators/diversities/core}/__init__.py +0 -0
- /wisent/{synthetic ā core/synthetic}/generators/diversities/core/core.py +0 -0
- /wisent/{opti/core ā core/synthetic/generators/diversities/methods}/__init__.py +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/WHEEL +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,364 @@
|
|
|
1
|
+
"""Steering optimization command execution logic with full strategy optimization."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import json
|
|
5
|
+
import time
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
def execute_optimize_steering(args):
    """
    Execute the optimize-steering command.

    Supports multiple subcommands:
    - comprehensive: Run comprehensive steering optimization
    - compare-methods: Compare different steering methods
    - optimize-layer: Find optimal steering layer
    - optimize-strength: Find optimal steering strength
    - auto: Automatically optimize based on classification config

    The requested action (``args.steering_action``) is validated *before* the
    model is loaded, so a missing or unknown action exits with status 1
    without paying for the model load. For a valid action, a ``WisentModel``
    (``args.model`` on ``args.device``) and an ``LMEvalDataLoader`` are
    created and passed to the matching ``execute_*`` handler.

    Raises:
        SystemExit: with code 1 when no action or an unknown action is given.
    """
    valid_actions = ('comprehensive', 'compare-methods', 'optimize-layer', 'optimize-strength', 'auto')

    # Check which subcommand was called (argparse leaves it unset/None when omitted)
    action = getattr(args, 'steering_action', None)
    if action is None:
        print("\nā No steering optimization action specified")
        print("Available actions: comprehensive, compare-methods, optimize-layer, optimize-strength, auto")
        sys.exit(1)
    if action not in valid_actions:
        # Reject before the expensive model load below.
        print(f"\nā Unknown steering action: {action}")
        sys.exit(1)

    from wisent.core.models.wisent_model import WisentModel
    from wisent.core.data_loaders.loaders.lm_loader import LMEvalDataLoader

    print(f"\n{'='*80}")
    print(f"šÆ STEERING PARAMETER OPTIMIZATION: {action.upper()}")
    print(f"{'='*80}")
    print(f" Model: {args.model}")
    print(f" Device: {args.device or 'auto'}")
    print(f"{'='*80}\n")

    # Load model
    print(f"š¦ Loading model...")
    model = WisentModel(args.model, device=args.device)
    print(f" ā Model loaded with {model.num_layers} layers\n")

    # Initialize data loader
    loader = LMEvalDataLoader()

    # Handlers are looked up at call time, after the whole module has been defined.
    handlers = {
        'comprehensive': execute_comprehensive,
        'compare-methods': execute_compare_methods,
        'optimize-layer': execute_optimize_layer,
        'optimize-strength': execute_optimize_strength,
        'auto': execute_auto,
    }
    handlers[action](args, model, loader)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def execute_comprehensive(args, model, loader):
    """Execute comprehensive steering optimization over layer, strength and strategy.

    For each task in ``args.tasks`` (default: arc_easy, hellaswag, winogrande,
    gsm8k) the task is split 80/20 with seed 42. A CAA steering vector is
    trained per candidate layer from the training pairs, and every
    (layer, strength, strategy) combination is scored on the test pairs. The
    best configuration per task is written to
    ``./optimization_results/steering_comprehensive_<model>.json`` and a
    summary table is printed.

    Activations and the CAA vector depend only on the layer. They are
    therefore computed once per layer and reused for every strength/strategy
    pair, rather than being recollected for each combination.

    NOTE(review): the score is an activation-alignment proxy, not a generation
    evaluation. The same ``scale * v`` is added to both the positive and the
    negative activation, so ``dot(pos + scale*v, v) > dot(neg + scale*v, v)``
    reduces to ``dot(pos, v) > dot(neg, v)``. Strength and strategy therefore
    do not change the score (up to float rounding), and effectively only the
    layer is being selected. Only CAA is evaluated, whatever ``args.methods``
    lists.

    Args:
        args: CLI namespace; reads ``tasks``, ``methods``, ``limit``,
            ``max_time_per_task`` (minutes), ``verbose`` and ``model``.
        model: model handed to ``ActivationCollector``.
        loader: data loader providing ``_load_one_task``.
    """
    import itertools
    import os
    import traceback

    from wisent.core.steering_methods.methods.caa import CAAMethod
    from wisent.core.activations.activations_collector import ActivationCollector
    from wisent.core.activations.core.atoms import ActivationAggregationStrategy

    print(f"š Running comprehensive steering optimization...")
    print(f" Optimizing: Layer, Strength, AND Steering Strategy")

    # Determine tasks to optimize
    task_list = args.tasks if args.tasks else ["arc_easy", "hellaswag", "winogrande", "gsm8k"]

    print(f" Tasks: {', '.join(task_list)}")
    print(f" Methods: {', '.join(args.methods)}")
    print(f" Limit: {args.limit} samples per task")
    print(f" Time limit: {args.max_time_per_task} minutes per task\n")

    all_results = {}

    # Steering parameters to test
    layers_to_test = [8, 9, 10, 11, 12]
    strengths_to_test = [0.5, 1.0, 1.5, 2.0]
    strategies_to_test = ["last_only", "first_only", "all_equal", "exponential_decay"]
    aggregation = ActivationAggregationStrategy.MEAN_POOLING
    time_budget_s = args.max_time_per_task * 60

    for task_idx, task_name in enumerate(task_list, 1):
        print(f"\n{'='*80}")
        print(f"Task {task_idx}/{len(task_list)}: {task_name}")
        print(f"{'='*80}")

        task_start_time = time.time()

        try:
            # Load task data
            print(f" š Loading task data...")
            result = loader._load_one_task(
                task_name=task_name,
                split_ratio=0.8,
                seed=42,
                limit=args.limit,
                training_limit=None,
                testing_limit=None,
            )

            train_pairs = result['train_qa_pairs']
            test_pairs = result['test_qa_pairs']

            print(f" ā Loaded {len(train_pairs.pairs)} train, {len(test_pairs.pairs)} test pairs")

            print(f"\n š Testing CAA method across layers, strengths, AND strategies...")
            print(f" Total configurations: {len(layers_to_test)} layers Ć {len(strengths_to_test)} strengths Ć {len(strategies_to_test)} strategies = {len(layers_to_test) * len(strengths_to_test) * len(strategies_to_test)}")

            # -inf rather than 0 so that a configuration scoring exactly 0.0 is
            # still recorded instead of falling through to the placeholder defaults.
            best_score = float("-inf")
            best_config = None
            method_results = {}
            configs_tested = 0
            time_limit_hit = False
            collector = ActivationCollector(model=model, store_device="cpu")

            for layer in layers_to_test:
                if time.time() - task_start_time > time_budget_s:
                    time_limit_hit = True
                    break

                layer_str = str(layer)
                try:
                    # Step 1: train a CAA vector for this layer (independent of strength/strategy)
                    train_acts = _collect_layer_activations(collector, train_pairs.pairs, layer_str, aggregation)
                    pos_acts = [pos for pos, _ in train_acts if pos is not None]
                    neg_acts = [neg for _, neg in train_acts if neg is not None]
                    if not pos_acts or not neg_acts:
                        continue

                    caa_method = CAAMethod(kwargs={"normalize": True})
                    steering_vector = caa_method.train_for_layer(pos_acts, neg_acts)

                    # Test activations are also layer-only; keep pairs where both sides exist.
                    test_acts = [
                        (pos, neg)
                        for pos, neg in _collect_layer_activations(collector, test_pairs.pairs, layer_str, aggregation)
                        if pos is not None and neg is not None
                    ]
                except Exception as e:
                    if args.verbose:
                        print(f" Error at layer={layer}: {e}")
                    continue

                # Step 2: score every strength/strategy combination for this layer
                for strength, strategy in itertools.product(strengths_to_test, strategies_to_test):
                    if time.time() - task_start_time > time_budget_s:
                        time_limit_hit = True
                        break

                    configs_tested += 1
                    try:
                        strategy_weight = get_strategy_weight(strategy, position=0.5)  # Mid-position for evaluation
                        avg_score = _alignment_accuracy(test_acts, steering_vector, strength * strategy_weight)
                    except Exception as e:
                        if args.verbose:
                            print(f" Error at layer={layer}, strength={strength}, strategy={strategy}: {e}")
                        continue

                    if avg_score is not None and avg_score > best_score:
                        best_score = avg_score
                        best_config = {
                            'layer': layer,
                            'strength': strength,
                            'strategy': strategy,
                            'accuracy': avg_score,
                        }

                    if configs_tested % 10 == 0 and args.verbose:
                        print(f" Tested {configs_tested} configurations...", end='\r')

                if time_limit_hit:
                    break

            if time_limit_hit:
                # Reported once; all remaining configurations for this task are skipped.
                print(f" ā° Time limit reached")

            if best_config:
                print(f"\n ✅ Best configuration found:")
                print(f" Method: CAA")
                print(f" Layer: {best_config['layer']}")
                print(f" Strength: {best_config['strength']}")
                print(f" Strategy: {best_config['strategy']} ā")
                print(f" Accuracy: {best_config['accuracy']:.3f}")

                method_results['CAA'] = {
                    'optimal_layer': best_config['layer'],
                    'optimal_strength': best_config['strength'],
                    'optimal_strategy': best_config['strategy'],
                    'accuracy': best_config['accuracy'],
                    'f1': best_config['accuracy'],
                }
            else:
                print(f"\n ā ļø No valid configuration found")
                # Placeholder defaults (not measured) so the result schema stays complete.
                method_results['CAA'] = {
                    'optimal_layer': 10,
                    'optimal_strength': 1.0,
                    'optimal_strategy': 'last_only',
                    'accuracy': 0.5,
                    'f1': 0.5,
                }

            all_results[task_name] = {
                'methods': method_results,
                'best_method': 'CAA',
                'best_layer': method_results['CAA']['optimal_layer'],
                'best_strength': method_results['CAA']['optimal_strength'],
                'best_strategy': method_results['CAA']['optimal_strategy'],
            }

            task_time = time.time() - task_start_time
            print(f"\n ā±ļø Task completed in {task_time:.1f}s (tested {configs_tested} configurations)")

        except Exception as e:
            print(f" ā Failed to optimize {task_name}: {e}")
            traceback.print_exc()
            continue

    # Save results
    print(f"\n{'='*80}")
    print(f"š COMPREHENSIVE OPTIMIZATION COMPLETE")
    print(f"{'='*80}\n")

    results_file = f"./optimization_results/steering_comprehensive_{args.model.replace('/', '_')}.json"
    os.makedirs(os.path.dirname(results_file), exist_ok=True)

    output_data = {
        'model': args.model,
        'tasks': all_results,
        'methods_tested': args.methods,
        'limit': args.limit,
        'optimization_dimensions': ['layer', 'strength', 'strategy'],
    }

    with open(results_file, 'w') as f:
        json.dump(output_data, f, indent=2)

    print(f"✅ Results saved to: {results_file}\n")

    # Print summary
    print("š SUMMARY BY TASK:")
    print("-" * 100)
    for task_name, config in all_results.items():
        print(f" {task_name:20s} | Method: {config['best_method']:10s} | Layer: {config['best_layer']:2d} | Strength: {config['best_strength']:.2f} | Strategy: {config['best_strategy']:18s}")
    print("-" * 100 + "\n")


def _response_layer_activation(response, layer_str):
    """Return ``response.layers_activations[layer_str]``, or None when it is missing."""
    activations = response.layers_activations
    if activations and layer_str in activations:
        return activations[layer_str]
    return None


def _collect_layer_activations(collector, pairs, layer_str, aggregation):
    """Collect (positive, negative) activations at one layer for every pair.

    Either element of a tuple is None when the collector produced no
    activation for that side of the pair.
    """
    collected = []
    for pair in pairs:
        updated_pair = collector.collect_for_pair(
            pair,
            layers=[layer_str],
            aggregation=aggregation,
            return_full_sequence=False,
            normalize_layers=False,
        )
        collected.append((
            _response_layer_activation(updated_pair.positive_response, layer_str),
            _response_layer_activation(updated_pair.negative_response, layer_str),
        ))
    return collected


def _alignment_accuracy(test_acts, steering_vector, scale):
    """Return the fraction of (pos, neg) pairs won by the positive activation.

    A pair is won when the steered positive activation projects further onto
    ``steering_vector`` than the steered negative one. Returns None when
    ``test_acts`` is empty.
    """
    import torch

    if not test_acts:
        return None
    direction = steering_vector.flatten()
    wins = 0
    for pos_act, neg_act in test_acts:
        pos_score = torch.dot((pos_act + scale * steering_vector).flatten(), direction).item()
        neg_score = torch.dot((neg_act + scale * steering_vector).flatten(), direction).item()
        if pos_score > neg_score:
            wins += 1
    return wins / len(test_acts)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def get_strategy_weight(strategy: str, position: float) -> float:
    """
    Map a steering strategy and a relative token position to a weight.

    Args:
        strategy: Name of the steering strategy. Unrecognised names fall back
            to a constant weight of 1.0 (same as "all_equal").
        position: Relative token position, 0.0 for the first token and 1.0 for
            the last.

    Returns:
        Multiplier applied to the steering vector. The exponential strategies
        use a fixed rate of 3.
    """
    weight_by_strategy = {
        "last_only": lambda pos: 1.0 if pos >= 0.9 else 0.0,
        "first_only": lambda pos: 1.0 if pos <= 0.1 else 0.0,
        "all_equal": lambda pos: 1.0,
        "exponential_decay": lambda pos: np.exp(-3.0 * pos),  # Decay rate of 3
        "exponential_growth": lambda pos: np.exp(3.0 * pos),
        "linear_decay": lambda pos: 1.0 - pos,
        "linear_growth": lambda pos: pos,
    }

    weight_fn = weight_by_strategy.get(strategy)
    if weight_fn is None:
        return 1.0  # Default to all_equal
    return weight_fn(position)
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def execute_compare_methods(args, model, loader):
    """Load a task and report the planned steering-method comparison.

    Only the data-loading step is implemented. The task ``args.task`` is
    loaded with a fixed 0.8 split ratio and seed 42, and the number of
    training pairs is printed. No methods are trained or scored yet; only CAA
    exists, so a not-implemented notice is printed instead.

    Args:
        args: CLI namespace; reads ``task``, ``methods`` and ``limit``.
        model: loaded model (currently unused).
        loader: data loader providing ``_load_one_task``.
    """
    print(f"š Comparing steering methods for task: {args.task}\n")
    print(f" Methods: {', '.join(args.methods)}")
    print(f" Limit: {args.limit} samples\n")

    result = loader._load_one_task(
        task_name=args.task,
        split_ratio=0.8,
        seed=42,
        limit=args.limit,
        training_limit=None,
        testing_limit=None,
    )

    # This literal was previously split across two lines (SyntaxError); it is now one string.
    print(f"✅ Loaded {len(result['train_qa_pairs'].pairs)} train pairs\n")
    print("ā ļø Full method comparison requires implementation of HPR, DAC, BiPO, KSteering")
    print(" Currently only CAA is fully implemented")
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def execute_optimize_layer(args, model, loader):
    """Announce a layer-optimization run (placeholder; no search is done yet).

    Echoes ``args.task``, ``args.method`` and ``args.strength`` and then
    prints a not-implemented notice. ``model`` and ``loader`` are accepted
    so the handler signature matches its siblings, but they are unused.
    """
    method_name = args.method
    messages = [
        f"šÆ Optimizing steering layer for task: {args.task}\n",
        f" Method: {method_name}",
        f" Strength: {args.strength}\n",
        "ā ļø Layer optimization not yet fully implemented",
        f" This would optimize layer for {method_name} method",
    ]
    for message in messages:
        print(message)
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
def execute_optimize_strength(args, model, loader):
    """Announce a strength-optimization run (placeholder; no search is done yet).

    Echoes ``args.task``, ``args.method`` and the first two entries of
    ``args.strength_range`` as the range bounds, then prints a
    not-implemented notice. ``model`` and ``loader`` are unused.
    """
    method_name = args.method
    range_low = args.strength_range[0]
    range_high = args.strength_range[1]
    messages = [
        f"šŖ Optimizing steering strength for task: {args.task}\n",
        f" Method: {method_name}",
        f" Strength range: {range_low} to {range_high}\n",
        "ā ļø Strength optimization not yet fully implemented",
        f" This would optimize strength for {method_name} method",
    ]
    for message in messages:
        print(message)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def execute_auto(args, model, loader):
    """Announce an automatic steering-optimization run (placeholder).

    Prints the requested methods (comma-joined ``args.methods``) and
    ``args.strength_range`` as-is, followed by a not-implemented notice.
    ``model`` and ``loader`` are unused.
    """
    methods_label = ', '.join(args.methods)
    for text in (
        f"š¤ Running automatic steering optimization...\n",
        f" Methods: {methods_label}",
        f" Strength range: {args.strength_range}\n",
        "ā ļø Auto optimization not yet fully implemented",
        " This would use classification results to guide steering optimization",
    ):
        print(text)
|
|
364
|
+
|
wisent/core/cli/tasks.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""Tasks command execution logic."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import os
|
|
5
|
+
import json
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def execute_tasks(args):
    """Execute the tasks command - train a classifier on a benchmark task.

    Steps:
      1. Load QA pairs for ``args.task_names`` with ``LMEvalDataLoader``; only
         the ``'train_qa_pairs'`` split is used.
      2. Load ``args.model`` on ``args.device`` as a ``WisentModel``.
      3. Collect activations at ``args.layer`` for the positive and the
         negative response of every pair, aggregated per
         ``args.token_aggregation``.
      4. Train a ``'logistic'`` or ``'mlp'`` classifier (positives labelled 1,
         negatives labelled 0) and print the final metrics.
      5. Optionally save the classifier (``args.save_classifier``) and the
         training report as JSON under ``args.output``.

    Any exception raised inside the pipeline is printed to stderr (with a
    traceback when ``args.verbose`` is set) and the process exits with
    status 1.

    Args:
        args: Parsed CLI namespace. Attributes read: task_names, model, layer,
            classifier_type, split_ratio, seed, limit, training_limit,
            testing_limit, device, token_aggregation, detection_threshold,
            save_classifier, output, verbose.
    """
    from wisent.core.data_loaders.loaders.lm_loader import LMEvalDataLoader
    from wisent.core.models.wisent_model import WisentModel
    from wisent.core.activations.activations_collector import ActivationCollector
    from wisent.core.activations.core.atoms import ActivationAggregationStrategy
    from wisent.core.classifiers.classifiers.models.logistic import LogisticClassifier
    from wisent.core.classifiers.classifiers.models.mlp import MLPClassifier
    from wisent.core.classifiers.classifiers.core.atoms import ClassifierTrainConfig
    from wisent.core.model_persistence import ModelPersistence, create_classifier_metadata

    print(f"\n🎯 Starting classifier training on task: {args.task_names}")
    print(f" Model: {args.model}")
    print(f" Layer: {args.layer}")
    print(f" Classifier type: {args.classifier_type}")

    try:
        # 1. Load task data using LMEvalDataLoader
        print(f"\n📊 Loading task '{args.task_names}'...")
        loader = LMEvalDataLoader()
        result = loader._load_one_task(
            task_name=args.task_names,
            split_ratio=args.split_ratio,
            seed=args.seed,
            limit=args.limit,
            training_limit=args.training_limit,
            testing_limit=args.testing_limit,
        )

        # Use training pairs for classifier training
        pair_set = result['train_qa_pairs']
        num_pairs = len(pair_set.pairs)
        print(f" ✓ Loaded {num_pairs} training pairs")

        # 2. Load model
        print(f"\n🤖 Loading model '{args.model}'...")
        model = WisentModel(args.model, device=args.device)
        print(f" ✓ Model loaded with {model.num_layers} layers")

        # 3. Parse layer specification (the CLI may hand it over as a string)
        layer = int(args.layer) if isinstance(args.layer, str) else args.layer
        print(f"\n🧠 Extracting activations from layer {layer}...")

        # 4. Collect activations for all pairs
        collector = ActivationCollector(model=model, store_device="cpu")

        # Map parser values to ActivationAggregationStrategy member names
        aggregation_map = {
            'average': 'MEAN_POOLING',
            'final': 'LAST_TOKEN',
            'first': 'FIRST_TOKEN',
            'max': 'MAX_POOLING',
            'min': 'MAX_POOLING',  # no MIN strategy in the map; fall back to MAX_POOLING
        }
        requested_aggregation = args.token_aggregation.lower()
        aggregation_key = aggregation_map.get(requested_aggregation, 'MEAN_POOLING')
        if requested_aggregation == 'min' or requested_aggregation not in aggregation_map:
            # Make the substitution visible instead of silently changing the strategy.
            print(
                f" ⚠️ Token aggregation '{args.token_aggregation}' is not supported "
                f"directly; using {aggregation_key}",
                file=sys.stderr,
            )
        aggregation_strategy = ActivationAggregationStrategy[aggregation_key]

        positive_activations = []
        negative_activations = []

        # The collector keys layer activations by the layer's string form
        layer_str = str(layer)

        for i, pair in enumerate(pair_set.pairs):
            if i % 10 == 0:
                print(f" Processing pair {i+1}/{num_pairs}...", end='\r')

            # A single call fills activations for BOTH responses of the pair
            updated_pair = collector.collect_for_pair(
                pair,
                layers=[layer_str],
                aggregation=aggregation_strategy,
                return_full_sequence=False,
                normalize_layers=False,
            )

            # Route each response's activation at the requested layer to its bucket
            for response, bucket in (
                (updated_pair.positive_response, positive_activations),
                (updated_pair.negative_response, negative_activations),
            ):
                layer_acts = response.layers_activations
                if layer_acts and layer_str in layer_acts:
                    act = layer_acts[layer_str]
                    if act is not None:
                        bucket.append(act.cpu().numpy())

        print(f"\n ✓ Collected {len(positive_activations)} positive and {len(negative_activations)} negative activations")

        # A binary classifier needs at least one example of each class; fail
        # with a clear message instead of an opaque numpy/shape error below.
        if not positive_activations or not negative_activations:
            raise ValueError(
                f"Need at least one positive and one negative activation at layer {layer} "
                f"to train a classifier (got {len(positive_activations)} positive, "
                f"{len(negative_activations)} negative)"
            )

        # 5. Prepare training data: positives -> 1, negatives -> 0
        print("\n🎯 Preparing training data...")
        X_positive = np.array(positive_activations)
        X_negative = np.array(negative_activations)
        X = np.vstack([X_positive, X_negative])
        y = np.array([1] * len(positive_activations) + [0] * len(negative_activations))

        num_positive = int(np.count_nonzero(y == 1))
        num_negative = int(np.count_nonzero(y == 0))
        print(f" Training set: {X.shape[0]} samples, {X.shape[1]} features")
        print(f" Positive samples: {num_positive}, Negative samples: {num_negative}")

        # 6. Create and train classifier
        print(f"\n🏋️ Training {args.classifier_type} classifier...")
        if args.classifier_type == 'logistic':
            classifier = LogisticClassifier(threshold=args.detection_threshold, device=args.device)
        elif args.classifier_type == 'mlp':
            classifier = MLPClassifier(threshold=args.detection_threshold, device=args.device)
        else:
            raise ValueError(f"Unknown classifier type: {args.classifier_type}")

        # NOTE(review): the pairs are already the loader's training split;
        # test_size presumably carves a further holdout from them inside
        # fit() — confirm against ClassifierTrainConfig.
        train_config = ClassifierTrainConfig(
            test_size=1.0 - args.split_ratio,
            num_epochs=50,
            batch_size=32,
            learning_rate=1e-3,
            monitor='f1',
            random_state=args.seed,
        )

        report = classifier.fit(X, y, config=train_config)

        # 7. Print results
        print("\n📊 Training completed!")
        print(f" Best epoch: {report.best_epoch}/{report.epochs_ran}")
        print(" Final metrics:")
        print(f" • Accuracy: {report.final.accuracy:.4f}")
        print(f" • Precision: {report.final.precision:.4f}")
        print(f" • Recall: {report.final.recall:.4f}")
        print(f" • F1 Score: {report.final.f1:.4f}")
        print(f" • AUC: {report.final.auc:.4f}")

        # 8. Save classifier if requested
        if args.save_classifier:
            print(f"\n💾 Saving classifier to '{args.save_classifier}'...")

            metadata = create_classifier_metadata(
                model_name=args.model,
                task_name=args.task_names,
                layer=layer,
                classifier_type=args.classifier_type,
                training_accuracy=report.final.accuracy,
                training_samples=len(X),
                token_aggregation=args.token_aggregation,
                detection_threshold=args.detection_threshold,
            )

            save_path = ModelPersistence.save_classifier(
                classifier=classifier,
                layer=layer,
                save_path=args.save_classifier,
                metadata=metadata,
            )
            print(f" ✓ Classifier saved to: {save_path}")

        # 9. Save output artifacts if requested
        if args.output:
            print(f"\n📁 Saving artifacts to '{args.output}'...")
            os.makedirs(args.output, exist_ok=True)

            report_path = os.path.join(args.output, 'training_report.json')
            with open(report_path, 'w') as f:
                json.dump(report.asdict(), f, indent=2)
            print(f" ✓ Training report saved to: {report_path}")

        print("\n✅ Task completed successfully!\n")

    except Exception as e:
        # Top-level CLI boundary: report, optionally show the traceback, exit non-zero.
        print(f"\n❌ Error: {str(e)}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Simple CLI logger replacement for removed wisent.cli module."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def setup_logger(name: str) -> logging.Logger:
    """Return the logger named *name*, configuring it on first use.

    If the logger has no handlers yet, one ``logging.StreamHandler`` with a
    ``"<asctime> - <name> - <levelname> - <message>"`` formatter is attached
    and the level is set to INFO. A logger that already has handlers is
    returned as-is, so repeated calls never stack duplicate handlers.
    """
    log = logging.getLogger(name)
    if log.handlers:
        return log

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    log.addHandler(stream_handler)
    log.setLevel(logging.INFO)
    return log
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def bind(logger: logging.Logger, **kwargs) -> logging.Logger:
    """Compatibility shim for attaching context fields to *logger*.

    Context binding is not implemented: the keyword arguments are accepted
    and discarded, and the same logger object is returned unchanged so that
    existing ``bind(...)`` call sites keep working.
    """
    del kwargs  # intentionally ignored; no context is attached
    return logger
|
|
@@ -68,6 +68,7 @@ class LMEvalBenchmarkExtractor(ABC):
|
|
|
68
68
|
cls,
|
|
69
69
|
lm_eval_task_data: ConfigurableTask,
|
|
70
70
|
limit: int | None = None,
|
|
71
|
+
preferred_doc: str | None = None,
|
|
71
72
|
) -> list[dict[str, Any]]:
|
|
72
73
|
"""
|
|
73
74
|
Load labeled documents from the most appropriate split with a clear
|
|
@@ -75,6 +76,9 @@ class LMEvalBenchmarkExtractor(ABC):
|
|
|
75
76
|
|
|
76
77
|
validation ā test ā train ā fewshot
|
|
77
78
|
|
|
79
|
+
If preferred_doc is provided, that source will be tried first before
|
|
80
|
+
falling back to the default order.
|
|
81
|
+
|
|
78
82
|
If none are available, attempts a dataset fallback using
|
|
79
83
|
'datasets.load_dataset' with the task's declared metadata
|
|
80
84
|
(e.g., 'dataset_path'/'dataset_name', 'dataset_config_name',
|
|
@@ -86,6 +90,10 @@ class LMEvalBenchmarkExtractor(ABC):
|
|
|
86
90
|
limit:
|
|
87
91
|
Optional maximum number of documents to return.
|
|
88
92
|
Values <= 0 are treated as "no limit".
|
|
93
|
+
preferred_doc:
|
|
94
|
+
Optional preferred document source. Valid values:
|
|
95
|
+
"validation", "test", "training", "fewshot".
|
|
96
|
+
If provided, this source will be tried first.
|
|
89
97
|
|
|
90
98
|
returns:
|
|
91
99
|
A list of document dictionaries.
|
|
@@ -98,17 +106,35 @@ class LMEvalBenchmarkExtractor(ABC):
|
|
|
98
106
|
"""
|
|
99
107
|
max_items = cls._normalize_limit(limit)
|
|
100
108
|
|
|
101
|
-
|
|
109
|
+
# Map preferred_doc string to the tuple format
|
|
110
|
+
doc_source_map = {
|
|
111
|
+
"validation": ("has_validation_docs", "validation_docs"),
|
|
112
|
+
"test": ("has_test_docs", "test_docs"),
|
|
113
|
+
"training": ("has_training_docs", "training_docs"),
|
|
114
|
+
"fewshot": ("has_fewshot_docs", "fewshot_docs"),
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
# Build preferred_sources based on preferred_doc
|
|
118
|
+
default_order: Sequence[tuple[str, str]] = (
|
|
102
119
|
("has_validation_docs", "validation_docs"),
|
|
103
120
|
("has_test_docs", "test_docs"),
|
|
104
121
|
("has_training_docs", "training_docs"),
|
|
105
122
|
("has_fewshot_docs", "fewshot_docs"),
|
|
106
123
|
)
|
|
107
124
|
|
|
125
|
+
if preferred_doc and preferred_doc in doc_source_map:
|
|
126
|
+
# Put preferred source first, then other sources
|
|
127
|
+
preferred_source = doc_source_map[preferred_doc]
|
|
128
|
+
other_sources = [s for s in default_order if s != preferred_source]
|
|
129
|
+
preferred_sources = (preferred_source,) + tuple(other_sources)
|
|
130
|
+
else:
|
|
131
|
+
preferred_sources = default_order
|
|
132
|
+
|
|
108
133
|
for has_method, docs_method in preferred_sources:
|
|
109
134
|
if cls._has_true(lm_eval_task_data, has_method) and cls._has_callable(
|
|
110
135
|
lm_eval_task_data, docs_method
|
|
111
136
|
):
|
|
137
|
+
print(f"loaded from {docs_method}")
|
|
112
138
|
docs_iter = getattr(lm_eval_task_data, docs_method)()
|
|
113
139
|
docs_list = cls._coerce_docs_to_dicts(docs_iter, max_items)
|
|
114
140
|
if docs_list:
|