wisent 0.5.11__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -1
- wisent/core/activations/__init__.py +26 -0
- wisent/core/activations/activations.py +96 -0
- wisent/core/activations/activations_collector.py +71 -20
- wisent/core/activations/prompt_construction_strategy.py +47 -0
- wisent/core/agent/budget.py +2 -2
- wisent/core/agent/device_benchmarks.py +1 -1
- wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
- wisent/core/agent/diagnose/response_diagnostics.py +4 -4
- wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
- wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
- wisent/core/agent/diagnose.py +2 -1
- wisent/core/autonomous_agent.py +10 -2
- wisent/core/benchmark_extractors.py +293 -0
- wisent/core/bigcode_integration.py +20 -7
- wisent/core/branding.py +108 -0
- wisent/core/cli/__init__.py +15 -0
- wisent/core/cli/create_steering_vector.py +138 -0
- wisent/core/cli/evaluate_responses.py +715 -0
- wisent/core/cli/generate_pairs.py +128 -0
- wisent/core/cli/generate_pairs_from_task.py +119 -0
- wisent/core/cli/generate_responses.py +129 -0
- wisent/core/cli/generate_vector_from_synthetic.py +149 -0
- wisent/core/cli/generate_vector_from_task.py +147 -0
- wisent/core/cli/get_activations.py +191 -0
- wisent/core/cli/optimize_classification.py +339 -0
- wisent/core/cli/optimize_steering.py +364 -0
- wisent/core/cli/tasks.py +182 -0
- wisent/core/cli_logger.py +22 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
- wisent/core/data_loaders/__init__.py +235 -0
- wisent/core/data_loaders/loaders/lm_loader.py +2 -2
- wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
- wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
- wisent/core/download_full_benchmarks.py +79 -2
- wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
- wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
- wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
- wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
- wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
- wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
- wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
- wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
- wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
- wisent/core/lm_eval_harness_ground_truth.py +3 -2
- wisent/core/main.py +57 -0
- wisent/core/model_persistence.py +2 -2
- wisent/core/models/wisent_model.py +8 -6
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/optuna/steering/steering_optimization.py +1 -1
- wisent/core/parser_arguments/__init__.py +10 -0
- wisent/core/parser_arguments/agent_parser.py +110 -0
- wisent/core/parser_arguments/configure_model_parser.py +7 -0
- wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
- wisent/core/parser_arguments/evaluate_parser.py +40 -0
- wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
- wisent/core/parser_arguments/full_optimize_parser.py +115 -0
- wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
- wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
- wisent/core/parser_arguments/generate_responses_parser.py +15 -0
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_parser.py +90 -0
- wisent/core/parser_arguments/get_activations_parser.py +90 -0
- wisent/core/parser_arguments/main_parser.py +152 -0
- wisent/core/parser_arguments/model_config_parser.py +59 -0
- wisent/core/parser_arguments/monitor_parser.py +17 -0
- wisent/core/parser_arguments/multi_steer_parser.py +47 -0
- wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
- wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
- wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
- wisent/core/parser_arguments/synthetic_parser.py +93 -0
- wisent/core/parser_arguments/tasks_parser.py +584 -0
- wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
- wisent/core/parser_arguments/utils.py +111 -0
- wisent/core/prompts/core/prompt_formater.py +3 -3
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
- wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
- wisent/core/steering_optimizer.py +45 -21
- wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
- wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
- wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
- wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
- wisent/core/tasks/livecodebench_task.py +4 -103
- wisent/core/timing_calibration.py +1 -1
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/METADATA +3 -3
- wisent-0.5.13.dist-info/RECORD +294 -0
- wisent-0.5.13.dist-info/entry_points.txt +2 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
- wisent/classifiers/core/atoms.py +0 -747
- wisent/classifiers/models/logistic.py +0 -29
- wisent/classifiers/models/mlp.py +0 -47
- wisent/cli/classifiers/classifier_rotator.py +0 -137
- wisent/cli/cli_logger.py +0 -142
- wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
- wisent/cli/wisent_cli/commands/listing.py +0 -154
- wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
- wisent/cli/wisent_cli/main.py +0 -93
- wisent/cli/wisent_cli/shell.py +0 -80
- wisent/cli/wisent_cli/ui.py +0 -69
- wisent/cli/wisent_cli/util/aggregations.py +0 -43
- wisent/cli/wisent_cli/util/parsing.py +0 -126
- wisent/cli/wisent_cli/version.py +0 -4
- wisent/opti/methods/__init__.py +0 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent-0.5.11.dist-info/RECORD +0 -220
- /wisent/{benchmarks → core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
- /wisent/{benchmarks/coding → core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
- /wisent/{benchmarks/coding/metrics → core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
- /wisent/{benchmarks/coding/metrics/core → core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/providers/core → core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
- /wisent/{benchmarks/coding/providers/livecodebench → core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/__init__.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/core/__init__.py +0 -0
- /wisent/{opti → core/opti}/core/atoms.py +0 -0
- /wisent/{classifiers → core/opti/methods}/__init__.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_classificator.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_steering.py +0 -0
- /wisent/{classifiers/core → core/synthetic}/__init__.py +0 -0
- /wisent/{classifiers/models → core/synthetic/cleaners}/__init__.py +0 -0
- /wisent/{cli → core/synthetic/cleaners/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/core/atoms.py +0 -0
- /wisent/{cli/classifiers → core/synthetic/cleaners/methods}/__init__.py +0 -0
- /wisent/{cli/data_loaders → core/synthetic/cleaners/methods/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/methods/core/atoms.py +0 -0
- /wisent/{cli/evaluators → core/synthetic/db_instructions}/__init__.py +0 -0
- /wisent/{cli/steering_methods → core/synthetic/db_instructions/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/db_instructions/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli → core/synthetic/generators}/__init__.py +0 -0
- /wisent/{cli/wisent_cli/commands → core/synthetic/generators/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli/util → core/synthetic/generators/diversities}/__init__.py +0 -0
- /wisent/{opti → core/synthetic/generators/diversities/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/diversities/core/core.py +0 -0
- /wisent/{opti/core → core/synthetic/generators/diversities/methods}/__init__.py +0 -0
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/WHEEL +0 -0
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""Get activations command execution logic."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import time
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def execute_get_activations(args):
    """Execute the get-activations command - load pairs and collect activations.

    Pipeline:
        1. Load contrastive pairs from ``args.pairs_file`` (JSON; either a dict
           with a ``pairs`` key or a bare list of pair dicts).
        2. Load the model named by ``args.model``.
        3. Resolve which layers to collect (default: the single middle layer).
        4. Resolve token-aggregation and prompt-construction strategies.
        5. Re-collect activations for every pair and serialize the enriched
           pairs (activations as nested lists) to ``args.output``.

    Args:
        args: Parsed CLI namespace. Fields read: ``pairs_file``, ``model``,
            ``device``, ``layers``, ``limit``, ``token_aggregation``,
            ``prompt_strategy``, ``output``, ``timing``, ``verbose``.

    Exits the process with status 1 on any failure.
    """
    # Imports are deferred so the CLI parser stays fast and the heavyweight
    # model stack loads only when this command actually runs.
    from wisent.core.models.wisent_model import WisentModel
    from wisent.core.activations.activations_collector import ActivationCollector
    from wisent.core.activations.core.atoms import ActivationAggregationStrategy
    from wisent.core.activations.prompt_construction_strategy import PromptConstructionStrategy
    from wisent.core.contrastive_pairs.core.pair import ContrastivePair
    from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse
    from wisent.core.contrastive_pairs.core.set import ContrastivePairSet

    print(f"\n🎨 Collecting activations from contrastive pairs")
    print(f" Input file: {args.pairs_file}")
    print(f" Model: {args.model}")

    start_time = time.time() if args.timing else None

    try:
        # 1. Load pairs from JSON
        print(f"\n📂 Loading contrastive pairs...")
        if not os.path.exists(args.pairs_file):
            raise FileNotFoundError(f"Pairs file not found: {args.pairs_file}")

        with open(args.pairs_file, 'r') as f:
            data = json.load(f)

        # Handle both formats: dict with 'pairs' key or direct list
        if isinstance(data, dict):
            pairs_list = data.get('pairs', [])
            task_name = data.get('task_name', 'unknown')
            trait_label = data.get('trait_label', task_name)
        else:
            pairs_list = data
            task_name = 'unknown'
            trait_label = 'unknown'

        # Apply limit if specified
        if args.limit:
            pairs_list = pairs_list[:args.limit]

        print(f" ✓ Loaded {len(pairs_list)} pairs")

        # 2. Load model
        print(f"\n🤖 Loading model '{args.model}'...")
        model = WisentModel(args.model, device=args.device)
        print(f" ✓ Model loaded with {model.num_layers} layers")

        # 3. Determine layers to collect
        if args.layers is None:
            # Default: use middle layer
            layers = [model.num_layers // 2]
        elif args.layers.lower() == 'all':
            # Layers are 1-indexed here — TODO confirm against the collector API.
            layers = list(range(1, model.num_layers + 1))
        else:
            layers = [int(l.strip()) for l in args.layers.split(',')]

        # Convert to strings for API
        layer_strs = [str(l) for l in layers]

        print(f"\n🎯 Collecting activations from {len(layers)} layer(s): {layers}")

        # 4. Set up aggregation strategy
        aggregation_map = {
            'average': 'MEAN_POOLING',
            'final': 'LAST_TOKEN',
            'first': 'FIRST_TOKEN',
            'max': 'MAX_POOLING',
            # BUG FIX: 'min' previously mapped to 'MAX_POOLING', so requesting
            # min-pooling silently produced max-pooling.
            'min': 'MIN_POOLING',
        }
        aggregation_key = aggregation_map.get(args.token_aggregation.lower(), 'MEAN_POOLING')
        # NOTE(review): assumes the enum defines MIN_POOLING — fall back to
        # MEAN_POOLING instead of crashing if it does not; confirm against
        # wisent.core.activations.core.atoms.
        if aggregation_key not in ActivationAggregationStrategy.__members__:
            aggregation_key = 'MEAN_POOLING'
        aggregation_strategy = ActivationAggregationStrategy[aggregation_key]

        # 5. Map prompt strategy string to enum
        prompt_strategy_map = {
            'chat_template': PromptConstructionStrategy.CHAT_TEMPLATE,
            'direct_completion': PromptConstructionStrategy.DIRECT_COMPLETION,
            'instruction_following': PromptConstructionStrategy.INSTRUCTION_FOLLOWING,
            'multiple_choice': PromptConstructionStrategy.MULTIPLE_CHOICE,
            'role_playing': PromptConstructionStrategy.ROLE_PLAYING,
        }
        prompt_strategy = prompt_strategy_map.get(args.prompt_strategy.lower(), PromptConstructionStrategy.CHAT_TEMPLATE)

        print(f" Token aggregation: {args.token_aggregation} ({aggregation_key})")
        print(f" Prompt strategy: {args.prompt_strategy}")

        # 6. Create pair set and reconstruct pairs from their JSON dicts
        pair_set = ContrastivePairSet(name=task_name, task_type=trait_label)

        for pair_data in pairs_list:
            pair = ContrastivePair(
                prompt=pair_data['prompt'],
                positive_response=PositiveResponse(
                    model_response=pair_data['positive_response']['model_response']
                ),
                negative_response=NegativeResponse(
                    model_response=pair_data['negative_response']['model_response']
                ),
                label=pair_data.get('label', trait_label),
                trait_description=pair_data.get('trait_description', ''),
            )
            pair_set.add(pair)

        # 7. Collect activations; keep results on CPU so large runs don't
        # accumulate on the accelerator.
        print(f"\n⚡ Collecting activations...")
        collector = ActivationCollector(model=model, store_device="cpu")

        enriched_pairs = []
        for i, pair in enumerate(pair_set.pairs):
            if args.verbose:
                print(f" Processing pair {i+1}/{len(pair_set.pairs)}...")

            # Collect activations for all requested layers at once
            updated_pair = collector.collect_for_pair(
                pair,
                layers=layer_strs,
                aggregation=aggregation_strategy,
                return_full_sequence=False,
                normalize_layers=False,
                prompt_strategy=prompt_strategy
            )

            enriched_pairs.append(updated_pair)

        print(f" ✓ Collected activations for {len(enriched_pairs)} pairs")

        # 8. Convert to a JSON-serializable structure
        print(f"\n💾 Saving enriched pairs to '{args.output}'...")
        output_data = {
            'task_name': task_name,
            'trait_label': trait_label,
            'model': args.model,
            'layers': layers,
            'token_aggregation': args.token_aggregation,
            'num_pairs': len(enriched_pairs),
            'pairs': []
        }

        for pair in enriched_pairs:
            pair_dict = {
                'prompt': pair.prompt,
                'positive_response': {
                    'model_response': pair.positive_response.model_response,
                    'layers_activations': {}
                },
                'negative_response': {
                    'model_response': pair.negative_response.model_response,
                    'layers_activations': {}
                },
                'label': pair.label,
                'trait_description': pair.trait_description,
            }

            # Convert activation tensors to nested lists for JSON serialization
            if pair.positive_response.layers_activations:
                for layer_str, act in pair.positive_response.layers_activations.items():
                    if act is not None:
                        pair_dict['positive_response']['layers_activations'][layer_str] = act.cpu().tolist()

            if pair.negative_response.layers_activations:
                for layer_str, act in pair.negative_response.layers_activations.items():
                    if act is not None:
                        pair_dict['negative_response']['layers_activations'][layer_str] = act.cpu().tolist()

            output_data['pairs'].append(pair_dict)

        # 9. Save to file (create parent directories as needed)
        os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
        with open(args.output, 'w') as f:
            json.dump(output_data, f, indent=2)

        print(f" ✓ Saved enriched pairs to: {args.output}")

        if args.timing:
            elapsed = time.time() - start_time
            print(f" ⏱️ Total time: {elapsed:.2f}s")

        print(f"\n✅ Activation collection completed successfully!\n")

    except Exception as e:
        # Top-level CLI boundary: report, optionally show the traceback,
        # and exit non-zero so shell pipelines see the failure.
        print(f"\n❌ Error: {str(e)}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)
|
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
"""Classification optimization command execution logic."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import json
|
|
5
|
+
import time
|
|
6
|
+
from typing import List, Dict, Any
|
|
7
|
+
|
|
8
|
+
def execute_optimize_classification(args):
|
|
9
|
+
"""
|
|
10
|
+
Execute the optimize-classification command.
|
|
11
|
+
|
|
12
|
+
Optimizes classification parameters across all available tasks:
|
|
13
|
+
- Finds best layer for each task
|
|
14
|
+
- Finds best token aggregation method
|
|
15
|
+
- Finds best detection threshold
|
|
16
|
+
- Saves trained classifiers
|
|
17
|
+
|
|
18
|
+
EFFICIENCY: Collects raw activations ONCE, then applies different aggregation strategies
|
|
19
|
+
to the cached activations without re-running the model.
|
|
20
|
+
"""
|
|
21
|
+
from wisent.core.models.wisent_model import WisentModel
|
|
22
|
+
from wisent.core.data_loaders.loaders.lm_loader import LMEvalDataLoader
|
|
23
|
+
from wisent.core.activations.activations_collector import ActivationCollector
|
|
24
|
+
from wisent.core.activations.core.atoms import ActivationAggregationStrategy
|
|
25
|
+
from wisent.core.classifiers.classifiers.models.logistic import LogisticClassifier
|
|
26
|
+
from wisent.core.classifiers.classifiers.core.atoms import ClassifierTrainConfig
|
|
27
|
+
import numpy as np
|
|
28
|
+
import torch
|
|
29
|
+
|
|
30
|
+
print(f"\n{'='*80}")
|
|
31
|
+
print(f"🔍 CLASSIFICATION PARAMETER OPTIMIZATION")
|
|
32
|
+
print(f"{'='*80}")
|
|
33
|
+
print(f" Model: {args.model}")
|
|
34
|
+
print(f" Limit per task: {args.limit}")
|
|
35
|
+
print(f" Optimization metric: {args.optimization_metric}")
|
|
36
|
+
print(f" Device: {args.device or 'auto'}")
|
|
37
|
+
print(f"{'='*80}\n")
|
|
38
|
+
|
|
39
|
+
# 1. Load model
|
|
40
|
+
print(f"📦 Loading model...")
|
|
41
|
+
model = WisentModel(args.model, device=args.device)
|
|
42
|
+
total_layers = model.num_layers
|
|
43
|
+
print(f" ✓ Model loaded with {total_layers} layers\n")
|
|
44
|
+
|
|
45
|
+
# 2. Determine layer range
|
|
46
|
+
if args.layer_range:
|
|
47
|
+
start, end = map(int, args.layer_range.split('-'))
|
|
48
|
+
layers_to_test = list(range(start, end + 1))
|
|
49
|
+
else:
|
|
50
|
+
# Test middle layers by default (more informative)
|
|
51
|
+
start_layer = total_layers // 3
|
|
52
|
+
end_layer = (2 * total_layers) // 3
|
|
53
|
+
layers_to_test = list(range(start_layer, end_layer + 1))
|
|
54
|
+
|
|
55
|
+
print(f"🎯 Testing layers: {layers_to_test[0]} to {layers_to_test[-1]} ({len(layers_to_test)} layers)")
|
|
56
|
+
print(f"🔄 Aggregation methods: {', '.join(args.aggregation_methods)}")
|
|
57
|
+
print(f"📊 Thresholds: {args.threshold_range}\n")
|
|
58
|
+
|
|
59
|
+
# 3. Get list of tasks to optimize
|
|
60
|
+
task_list = [
|
|
61
|
+
"arc_easy", "arc_challenge", "hellaswag",
|
|
62
|
+
"winogrande", "gsm8k"
|
|
63
|
+
]
|
|
64
|
+
|
|
65
|
+
print(f"📋 Optimizing {len(task_list)} tasks\n")
|
|
66
|
+
|
|
67
|
+
# 4. Initialize data loader
|
|
68
|
+
loader = LMEvalDataLoader()
|
|
69
|
+
|
|
70
|
+
# 5. Results storage
|
|
71
|
+
all_results = {}
|
|
72
|
+
classifiers_saved = {}
|
|
73
|
+
|
|
74
|
+
# 6. Process each task
|
|
75
|
+
for task_idx, task_name in enumerate(task_list, 1):
|
|
76
|
+
print(f"\n{'='*80}")
|
|
77
|
+
print(f"Task {task_idx}/{len(task_list)}: {task_name}")
|
|
78
|
+
print(f"{'='*80}")
|
|
79
|
+
|
|
80
|
+
task_start_time = time.time()
|
|
81
|
+
|
|
82
|
+
try:
|
|
83
|
+
# Load task data
|
|
84
|
+
print(f" 📊 Loading data...")
|
|
85
|
+
result = loader._load_one_task(
|
|
86
|
+
task_name=task_name,
|
|
87
|
+
split_ratio=0.8,
|
|
88
|
+
seed=42,
|
|
89
|
+
limit=args.limit,
|
|
90
|
+
training_limit=None,
|
|
91
|
+
testing_limit=None
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
train_pairs = result['train_qa_pairs']
|
|
95
|
+
test_pairs = result['test_qa_pairs']
|
|
96
|
+
|
|
97
|
+
print(f" ✓ Loaded {len(train_pairs.pairs)} train, {len(test_pairs.pairs)} test pairs")
|
|
98
|
+
|
|
99
|
+
# STEP 1: Collect raw activations ONCE for all layers (full sequence)
|
|
100
|
+
print(f" 🧠 Collecting raw activations (once per pair)...")
|
|
101
|
+
collector = ActivationCollector(model=model, store_device="cpu")
|
|
102
|
+
|
|
103
|
+
# Cache structure: train_cache[pair_idx][layer_str] = {pos: tensor, neg: tensor, pos_tokens: int, neg_tokens: int}
|
|
104
|
+
train_cache = {}
|
|
105
|
+
test_cache = {}
|
|
106
|
+
|
|
107
|
+
layer_strs = [str(l) for l in layers_to_test]
|
|
108
|
+
|
|
109
|
+
# Collect training activations with full sequence
|
|
110
|
+
for pair_idx, pair in enumerate(train_pairs.pairs):
|
|
111
|
+
updated_pair = collector.collect_for_pair(
|
|
112
|
+
pair,
|
|
113
|
+
layers=layer_strs,
|
|
114
|
+
aggregation=None, # Get raw activations without aggregation
|
|
115
|
+
return_full_sequence=True, # Get all token positions
|
|
116
|
+
normalize_layers=False
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
train_cache[pair_idx] = {}
|
|
120
|
+
for layer_str in layer_strs:
|
|
121
|
+
train_cache[pair_idx][layer_str] = {
|
|
122
|
+
'pos': updated_pair.positive_response.layers_activations.get(layer_str),
|
|
123
|
+
'neg': updated_pair.negative_response.layers_activations.get(layer_str),
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
# Collect test activations
|
|
127
|
+
for pair_idx, pair in enumerate(test_pairs.pairs):
|
|
128
|
+
updated_pair = collector.collect_for_pair(
|
|
129
|
+
pair,
|
|
130
|
+
layers=layer_strs,
|
|
131
|
+
aggregation=None,
|
|
132
|
+
return_full_sequence=True,
|
|
133
|
+
normalize_layers=False
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
test_cache[pair_idx] = {}
|
|
137
|
+
for layer_str in layer_strs:
|
|
138
|
+
test_cache[pair_idx][layer_str] = {
|
|
139
|
+
'pos': updated_pair.positive_response.layers_activations.get(layer_str),
|
|
140
|
+
'neg': updated_pair.negative_response.layers_activations.get(layer_str),
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
print(f" ✓ Cached activations for {len(train_cache)} train and {len(test_cache)} test pairs")
|
|
144
|
+
|
|
145
|
+
# STEP 2: Apply different aggregation strategies to cached activations
|
|
146
|
+
print(f" 🔍 Testing {len(layers_to_test) * len(args.aggregation_methods)} layer/aggregation combinations...")
|
|
147
|
+
|
|
148
|
+
# Aggregation functions
|
|
149
|
+
def aggregate_activations(raw_acts, method):
    """Collapse a raw activation tensor into a single hidden-state vector.

    Args:
        raw_acts: Activation tensor, or ``None``. Supported shapes:
            1D ``(hidden_dim,)`` — treated as already aggregated and
            returned unchanged; 2D ``(seq_len, hidden_dim)`` — reduced
            over the sequence axis; higher rank — flattened to 2D over
            the last dimension, then re-aggregated.
        method: Aggregation strategy; one of ``'average'``, ``'final'``,
            ``'first'``, ``'max'``, ``'min'``.

    Returns:
        A 1D tensor of shape ``(hidden_dim,)``, or ``None`` when
        ``raw_acts`` is ``None`` or empty (callers skip ``None`` results).

    Raises:
        ValueError: If ``method`` is not a recognized strategy.
    """
    if raw_acts is None or raw_acts.numel() == 0:
        return None

    if raw_acts.ndim == 1:
        # Already a single vector (pre-aggregated upstream); nothing to do.
        return raw_acts

    if raw_acts.ndim > 2:
        # Flatten any leading dims (batch, etc.) into one sequence axis and
        # recurse. reshape() handles non-contiguous tensors (e.g. permuted
        # activations) where view() would raise a RuntimeError.
        return aggregate_activations(raw_acts.reshape(-1, raw_acts.shape[-1]), method)

    # 2D (seq_len, hidden_dim): reduce over the sequence axis (dim=0).
    if method == 'average':
        return raw_acts.mean(dim=0)
    if method == 'final':
        return raw_acts[-1]
    if method == 'first':
        return raw_acts[0]
    if method == 'max':
        return raw_acts.max(dim=0)[0]
    if method == 'min':
        return raw_acts.min(dim=0)[0]

    # Previously an unknown method fell through the if/elif chain and
    # implicitly returned None, which downstream code treats as "skip this
    # pair" — a typo'd method name silently emptied the dataset. Fail fast.
    raise ValueError(f"Unknown aggregation method: {method!r}")
|
|
172
|
+
|
|
173
|
+
best_score = -1
|
|
174
|
+
best_config = None
|
|
175
|
+
best_classifier = None
|
|
176
|
+
|
|
177
|
+
combinations_tested = 0
|
|
178
|
+
total_combinations = len(layers_to_test) * len(args.aggregation_methods)
|
|
179
|
+
|
|
180
|
+
for layer in layers_to_test:
|
|
181
|
+
layer_str = str(layer)
|
|
182
|
+
|
|
183
|
+
for agg_method in args.aggregation_methods:
|
|
184
|
+
# Apply aggregation to cached activations
|
|
185
|
+
train_pos_acts = []
|
|
186
|
+
train_neg_acts = []
|
|
187
|
+
|
|
188
|
+
for pair_idx in train_cache:
|
|
189
|
+
pos_raw = train_cache[pair_idx][layer_str]['pos']
|
|
190
|
+
neg_raw = train_cache[pair_idx][layer_str]['neg']
|
|
191
|
+
|
|
192
|
+
pos_agg = aggregate_activations(pos_raw, agg_method)
|
|
193
|
+
neg_agg = aggregate_activations(neg_raw, agg_method)
|
|
194
|
+
|
|
195
|
+
if pos_agg is not None:
|
|
196
|
+
train_pos_acts.append(pos_agg.cpu().numpy())
|
|
197
|
+
if neg_agg is not None:
|
|
198
|
+
train_neg_acts.append(neg_agg.cpu().numpy())
|
|
199
|
+
|
|
200
|
+
if len(train_pos_acts) == 0 or len(train_neg_acts) == 0:
|
|
201
|
+
combinations_tested += 1
|
|
202
|
+
continue
|
|
203
|
+
|
|
204
|
+
# Prepare training data
|
|
205
|
+
X_train_pos = np.array(train_pos_acts)
|
|
206
|
+
X_train_neg = np.array(train_neg_acts)
|
|
207
|
+
X_train = np.vstack([X_train_pos, X_train_neg])
|
|
208
|
+
y_train = np.array([1] * len(train_pos_acts) + [0] * len(train_neg_acts))
|
|
209
|
+
|
|
210
|
+
# Train classifier
|
|
211
|
+
classifier = LogisticClassifier(threshold=0.5, device="cpu")
|
|
212
|
+
|
|
213
|
+
config = ClassifierTrainConfig(
|
|
214
|
+
test_size=0.2,
|
|
215
|
+
batch_size=32,
|
|
216
|
+
num_epochs=30,
|
|
217
|
+
learning_rate=0.001,
|
|
218
|
+
monitor="f1",
|
|
219
|
+
random_state=42
|
|
220
|
+
)
|
|
221
|
+
|
|
222
|
+
report = classifier.fit(
|
|
223
|
+
torch.tensor(X_train, dtype=torch.float32),
|
|
224
|
+
torch.tensor(y_train, dtype=torch.float32),
|
|
225
|
+
config=config
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
# Apply aggregation to test set
|
|
229
|
+
test_pos_acts = []
|
|
230
|
+
test_neg_acts = []
|
|
231
|
+
|
|
232
|
+
for pair_idx in test_cache:
|
|
233
|
+
pos_raw = test_cache[pair_idx][layer_str]['pos']
|
|
234
|
+
neg_raw = test_cache[pair_idx][layer_str]['neg']
|
|
235
|
+
|
|
236
|
+
pos_agg = aggregate_activations(pos_raw, agg_method)
|
|
237
|
+
neg_agg = aggregate_activations(neg_raw, agg_method)
|
|
238
|
+
|
|
239
|
+
if pos_agg is not None:
|
|
240
|
+
test_pos_acts.append(pos_agg.cpu().numpy())
|
|
241
|
+
if neg_agg is not None:
|
|
242
|
+
test_neg_acts.append(neg_agg.cpu().numpy())
|
|
243
|
+
|
|
244
|
+
if len(test_pos_acts) == 0 or len(test_neg_acts) == 0:
|
|
245
|
+
combinations_tested += 1
|
|
246
|
+
continue
|
|
247
|
+
|
|
248
|
+
X_test_pos = np.array(test_pos_acts)
|
|
249
|
+
X_test_neg = np.array(test_neg_acts)
|
|
250
|
+
X_test = np.vstack([X_test_pos, X_test_neg])
|
|
251
|
+
y_test = np.array([1] * len(test_pos_acts) + [0] * len(test_neg_acts))
|
|
252
|
+
|
|
253
|
+
# Get predictions
|
|
254
|
+
y_pred_proba = np.array(classifier.predict_proba(X_test))
|
|
255
|
+
|
|
256
|
+
# Test different thresholds
|
|
257
|
+
for threshold in args.threshold_range:
|
|
258
|
+
y_pred = (y_pred_proba > threshold).astype(int)
|
|
259
|
+
|
|
260
|
+
# Calculate metrics
|
|
261
|
+
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
|
262
|
+
|
|
263
|
+
accuracy = accuracy_score(y_test, y_pred)
|
|
264
|
+
f1 = f1_score(y_test, y_pred, zero_division=0)
|
|
265
|
+
precision = precision_score(y_test, y_pred, zero_division=0)
|
|
266
|
+
recall = recall_score(y_test, y_pred, zero_division=0)
|
|
267
|
+
|
|
268
|
+
# Choose metric based on args
|
|
269
|
+
metric_value = {
|
|
270
|
+
'f1': f1,
|
|
271
|
+
'accuracy': accuracy,
|
|
272
|
+
'precision': precision,
|
|
273
|
+
'recall': recall
|
|
274
|
+
}[args.optimization_metric]
|
|
275
|
+
|
|
276
|
+
if metric_value > best_score:
|
|
277
|
+
best_score = metric_value
|
|
278
|
+
best_config = {
|
|
279
|
+
'layer': layer,
|
|
280
|
+
'aggregation': agg_method,
|
|
281
|
+
'threshold': threshold,
|
|
282
|
+
'accuracy': float(accuracy),
|
|
283
|
+
'f1': float(f1),
|
|
284
|
+
'precision': float(precision),
|
|
285
|
+
'recall': float(recall)
|
|
286
|
+
}
|
|
287
|
+
best_classifier = classifier
|
|
288
|
+
|
|
289
|
+
combinations_tested += 1
|
|
290
|
+
print(f" Progress: {combinations_tested}/{total_combinations} combinations tested", end='\r')
|
|
291
|
+
|
|
292
|
+
print(f"\n ✅ Best config: layer={best_config['layer']}, agg={best_config['aggregation']}, thresh={best_config['threshold']:.2f}")
|
|
293
|
+
print(f" Metrics: acc={best_config['accuracy']:.3f}, f1={best_config['f1']:.3f}, prec={best_config['precision']:.3f}, rec={best_config['recall']:.3f}")
|
|
294
|
+
|
|
295
|
+
all_results[task_name] = best_config
|
|
296
|
+
|
|
297
|
+
# Note: Classifier saving disabled due to missing .save() method
|
|
298
|
+
# Can be enabled once proper serialization is implemented
|
|
299
|
+
|
|
300
|
+
task_time = time.time() - task_start_time
|
|
301
|
+
print(f" ⏱️ Task completed in {task_time:.1f}s")
|
|
302
|
+
|
|
303
|
+
except Exception as e:
|
|
304
|
+
print(f" ❌ Failed to optimize {task_name}: {e}")
|
|
305
|
+
import traceback
|
|
306
|
+
traceback.print_exc()
|
|
307
|
+
continue
|
|
308
|
+
|
|
309
|
+
# 7. Save results
|
|
310
|
+
print(f"\n{'='*80}")
|
|
311
|
+
print(f"📊 OPTIMIZATION COMPLETE")
|
|
312
|
+
print(f"{'='*80}\n")
|
|
313
|
+
|
|
314
|
+
results_file = args.results_file or f"./optimization_results/classification_results.json"
|
|
315
|
+
import os
|
|
316
|
+
os.makedirs(os.path.dirname(results_file) if os.path.dirname(results_file) else ".", exist_ok=True)
|
|
317
|
+
|
|
318
|
+
output_data = {
|
|
319
|
+
'model': args.model,
|
|
320
|
+
'optimization_metric': args.optimization_metric,
|
|
321
|
+
'layer_range': f"{layers_to_test[0]}-{layers_to_test[-1]}",
|
|
322
|
+
'aggregation_methods': args.aggregation_methods,
|
|
323
|
+
'threshold_range': args.threshold_range,
|
|
324
|
+
'tasks': all_results,
|
|
325
|
+
'classifiers_saved': classifiers_saved
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
with open(results_file, 'w') as f:
|
|
329
|
+
json.dump(output_data, f, indent=2)
|
|
330
|
+
|
|
331
|
+
print(f"✅ Results saved to: {results_file}\n")
|
|
332
|
+
|
|
333
|
+
# Print summary
|
|
334
|
+
print("📋 SUMMARY BY TASK:")
|
|
335
|
+
print("-" * 80)
|
|
336
|
+
for task_name, config in all_results.items():
|
|
337
|
+
print(f" {task_name:20s} | Layer: {config['layer']:2d} | Agg: {config['aggregation']:8s} | Thresh: {config['threshold']:.2f} | F1: {config['f1']:.3f}")
|
|
338
|
+
print("-" * 80 + "\n")
|
|
339
|
+
|