wisent 0.7.379__py3-none-any.whl → 0.7.701__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/core/activations/__init__.py +22 -6
- wisent/core/activations/activations.py +21 -39
- wisent/core/activations/activations_collector.py +141 -373
- wisent/core/activations/classifier_inference_strategy.py +194 -0
- wisent/core/activations/core/atoms.py +8 -92
- wisent/core/activations/extraction_strategy.py +308 -0
- wisent/core/agent/diagnose/response_diagnostics.py +3 -3
- wisent/core/agent/diagnose.py +3 -3
- wisent/core/autonomous_agent.py +2 -2
- wisent/core/cli/agent/apply_steering.py +23 -27
- wisent/core/cli/agent/evaluate_response.py +18 -20
- wisent/core/cli/agent/train_classifier.py +18 -20
- wisent/core/cli/cluster_benchmarks.py +472 -0
- wisent/core/cli/create_steering_vector.py +13 -5
- wisent/core/cli/generate_vector_from_task.py +4 -0
- wisent/core/cli/get_activations.py +12 -36
- wisent/core/cli/method_optimizer.py +859 -0
- wisent/core/cli/optimize.py +44 -5
- wisent/core/cli/optimize_classification.py +5 -6
- wisent/core/cli/optimize_sample_size.py +8 -22
- wisent/core/cli/optimize_steering.py +429 -153
- wisent/core/cli/optimize_weights.py +65 -6
- wisent/core/cli/steering_method_trainer.py +5 -4
- wisent/core/cli/steering_search_space.py +20 -15
- wisent/core/cli/tasks.py +14 -43
- wisent/core/cli/train_unified_goodness.py +17 -18
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +1578 -173
- wisent/core/contrastive_pairs/diagnostics/linearity.py +63 -80
- wisent/core/contrastive_pairs/diagnostics/vector_quality.py +6 -5
- wisent/core/contrastive_pairs/huggingface_pairs/hf_extractor_manifest.py +5 -19
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/__init__.py +11 -5
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/apps.py +146 -32
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue.py +2 -2
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/humaneval.py +98 -57
- wisent/core/contrastive_pairs/lm_eval_pairs/group_task_manifests/code_x_glue.py +8 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/group_task_manifests/freebase.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/agieval_aqua_rat.py +129 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/code_x_glue.py +11 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mbpp.py +47 -6
- wisent/core/evaluators/benchmark_specific/apps_evaluator.py +133 -0
- wisent/core/evaluators/benchmark_specific/coding/metrics/evaluator.py +6 -1
- wisent/core/evaluators/benchmark_specific/conala_evaluator.py +31 -168
- wisent/core/evaluators/custom/examples/humanization_coherent.py +89 -35
- wisent/core/evaluators/oracles/truthfulqa_gen_evaluator.py +2 -20
- wisent/core/evaluators/personalization/coherence.py +46 -0
- wisent/core/hyperparameter_optimizer.py +13 -13
- wisent/core/lm_eval_harness_ground_truth.py +7 -11
- wisent/core/main.py +3 -0
- wisent/core/models/wisent_model.py +8 -7
- wisent/core/opti/methods/opti_weights.py +29 -2
- wisent/core/optuna/classifier/activation_generator.py +14 -12
- wisent/core/optuna/steering/steering_optimization.py +14 -9
- wisent/core/parser_arguments/cluster_benchmarks_parser.py +31 -0
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +20 -0
- wisent/core/parser_arguments/main_parser.py +8 -0
- wisent/core/parser_arguments/optimize_steering_parser.py +117 -10
- wisent/core/parser_arguments/optimize_weights_parser.py +6 -0
- wisent/core/parser_arguments/tasks_parser.py +7 -19
- wisent/core/steering_methods/core/atoms.py +1 -2
- wisent/core/steering_methods/methods/caa.py +1 -1
- wisent/core/steering_methods/methods/hyperplane.py +74 -0
- wisent/core/steering_methods/methods/prism.py +1 -2
- wisent/core/steering_methods/methods/pulse.py +39 -8
- wisent/core/steering_methods/methods/titan.py +59 -14
- wisent/core/steering_methods/registry.py +52 -12
- wisent/core/steering_optimizer.py +15 -15
- wisent/core/trainers/steering_trainer.py +9 -18
- wisent/parameters/lm_eval/track_progress_not_lm_eval_tasks.json +19 -70
- wisent/scripts/run_quality_metrics_sweep.sh +22 -27
- wisent/tests/test_aggregation_geometry.py +236 -0
- wisent/tests/test_detector_accuracy.py +163 -0
- wisent/tests/test_geometry_exhaustive.py +1202 -0
- wisent/tests/visualize_geometry.py +255 -61
- {wisent-0.7.379.dist-info → wisent-0.7.701.dist-info}/METADATA +1 -1
- {wisent-0.7.379.dist-info → wisent-0.7.701.dist-info}/RECORD +82 -714
- wisent/core/activations/prompt_construction_strategy.py +0 -47
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue_code_to_text.py +0 -15
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue_code_to_text_go.py +0 -64
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue_code_to_text_java.py +0 -65
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue_code_to_text_javascript.py +0 -65
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue_code_to_text_php.py +0 -65
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue_code_to_text_python.py +0 -65
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue_code_to_text_ruby.py +0 -65
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/freebase.py +0 -99
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/instruct_humaneval.py +0 -180
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/instructhumaneval.py +0 -129
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mbpp.py +0 -142
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/agieval.py +0 -155
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/code2text.py +0 -161
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/codexglue.py +0 -107
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livemathbench.py +0 -155
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/polymath.py +0 -155
- wisent/examples/scripts/results/benchmark_descriptions.json +0 -1244
- wisent/examples/scripts/results/benchmark_evaluation_methods.json +0 -66
- wisent/examples/scripts/results/benchmark_evaluator_mapping.json +0 -2781
- wisent/examples/scripts/results/benchmark_evaluator_mapping_updated.json +0 -30536
- wisent/examples/scripts/results/benchmark_evaluators_clean.json +0 -469
- wisent/examples/scripts/results/benchmark_methods_summary.json +0 -260
- wisent/examples/scripts/results/benchmark_pair_creation_methods.json +0 -66
- wisent/examples/scripts/results/benchmark_pair_totals.json +0 -269
- wisent/examples/scripts/results/benchmark_tags.json +0 -917
- wisent/examples/scripts/results/benchmark_test_summary_nov4.json +0 -71
- wisent/examples/scripts/results/coding_benchmarks_test_code_status.json +0 -150
- wisent/examples/scripts/results/failing_benchmarks.json +0 -946
- wisent/examples/scripts/results/failing_benchmarks_list.json +0 -41
- wisent/examples/scripts/results/failing_benchmarks_test_results.json +0 -945
- wisent/examples/scripts/results/missing_benchmark_tags.json +0 -341
- wisent/examples/scripts/results/test_20_newsgroups_evaluation.json +0 -30
- wisent/examples/scripts/results/test_20_newsgroups_pairs.json +0 -8
- wisent/examples/scripts/results/test_AraDICE_evaluation.json +0 -51
- wisent/examples/scripts/results/test_AraDICE_pairs.json +0 -14
- wisent/examples/scripts/results/test_AraDiCE_boolq_egy/test_AraDiCE_boolq_egy_evaluation.json +0 -30
- wisent/examples/scripts/results/test_AraDiCE_boolq_egy/test_AraDiCE_boolq_egy_pairs.json +0 -8
- wisent/examples/scripts/results/test_ArabCulture_evaluation.json +0 -51
- wisent/examples/scripts/results/test_ArabCulture_pairs.json +0 -14
- wisent/examples/scripts/results/test_Tag_evaluation.json +0 -30
- wisent/examples/scripts/results/test_Tag_pairs.json +0 -8
- wisent/examples/scripts/results/test_aclue_evaluation.json +0 -51
- wisent/examples/scripts/results/test_aclue_pairs.json +0 -14
- wisent/examples/scripts/results/test_acp_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_acp_bench_hard_evaluation.json +0 -51
- wisent/examples/scripts/results/test_acp_bench_hard_pairs.json +0 -14
- wisent/examples/scripts/results/test_acp_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_advanced_ai_risk_evaluation.json +0 -51
- wisent/examples/scripts/results/test_advanced_ai_risk_pairs.json +0 -14
- wisent/examples/scripts/results/test_aexams_evaluation.json +0 -51
- wisent/examples/scripts/results/test_aexams_pairs.json +0 -14
- wisent/examples/scripts/results/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/results/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/results/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/results/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/results/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/results/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/results/test_ag_news_evaluation.json +0 -30
- wisent/examples/scripts/results/test_ag_news_pairs.json +0 -8
- wisent/examples/scripts/results/test_agieval_evaluation.json +0 -51
- wisent/examples/scripts/results/test_agieval_pairs.json +0 -14
- wisent/examples/scripts/results/test_aime2024_evaluation.json +0 -30
- wisent/examples/scripts/results/test_aime2024_pairs.json +0 -8
- wisent/examples/scripts/results/test_aime2025_evaluation.json +0 -30
- wisent/examples/scripts/results/test_aime2025_pairs.json +0 -8
- wisent/examples/scripts/results/test_aime_evaluation.json +0 -30
- wisent/examples/scripts/results/test_aime_pairs.json +0 -8
- wisent/examples/scripts/results/test_anagrams1_evaluation.json +0 -30
- wisent/examples/scripts/results/test_anagrams1_pairs.json +0 -8
- wisent/examples/scripts/results/test_anagrams2_evaluation.json +0 -30
- wisent/examples/scripts/results/test_anagrams2_pairs.json +0 -8
- wisent/examples/scripts/results/test_anli_evaluation.json +0 -30
- wisent/examples/scripts/results/test_anli_pairs.json +0 -8
- wisent/examples/scripts/results/test_apps_evaluation.json +0 -30
- wisent/examples/scripts/results/test_apps_pairs.json +0 -8
- wisent/examples/scripts/results/test_arabic_exams_evaluation.json +0 -30
- wisent/examples/scripts/results/test_arabic_exams_pairs.json +0 -8
- wisent/examples/scripts/results/test_arabic_leaderboard_complete_evaluation.json +0 -51
- wisent/examples/scripts/results/test_arabic_leaderboard_complete_pairs.json +0 -14
- wisent/examples/scripts/results/test_arabic_leaderboard_light_evaluation.json +0 -51
- wisent/examples/scripts/results/test_arabic_leaderboard_light_pairs.json +0 -14
- wisent/examples/scripts/results/test_arabicmmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_arabicmmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_aradice/test_aradice_evaluation.json +0 -51
- wisent/examples/scripts/results/test_aradice/test_aradice_pairs.json +0 -14
- wisent/examples/scripts/results/test_aradice3/test_aradice_evaluation.json +0 -51
- wisent/examples/scripts/results/test_aradice3/test_aradice_pairs.json +0 -14
- wisent/examples/scripts/results/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/results/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/results/test_arc_challenge_evaluation.json +0 -30
- wisent/examples/scripts/results/test_arc_challenge_pairs.json +0 -8
- wisent/examples/scripts/results/test_arc_easy_evaluation.json +0 -30
- wisent/examples/scripts/results/test_arc_easy_pairs.json +0 -8
- wisent/examples/scripts/results/test_argument_topic_evaluation.json +0 -30
- wisent/examples/scripts/results/test_argument_topic_pairs.json +0 -8
- wisent/examples/scripts/results/test_arithmetic_evaluation.json +0 -51
- wisent/examples/scripts/results/test_arithmetic_pairs.json +0 -14
- wisent/examples/scripts/results/test_asdiv_evaluation.json +0 -30
- wisent/examples/scripts/results/test_asdiv_pairs.json +0 -8
- wisent/examples/scripts/results/test_assin_entailment_evaluation.json +0 -30
- wisent/examples/scripts/results/test_assin_entailment_pairs.json +0 -8
- wisent/examples/scripts/results/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/results/test_atis_pairs.json +0 -8
- wisent/examples/scripts/results/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/results/test_babi_pairs.json +0 -8
- wisent/examples/scripts/results/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/results/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/results/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/results/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/results/test_banking77_evaluation.json +0 -30
- wisent/examples/scripts/results/test_banking77_pairs.json +0 -8
- wisent/examples/scripts/results/test_basque/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/results/test_basque-glue_evaluation.json +0 -51
- wisent/examples/scripts/results/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/results/test_basque2/test_basque-glue_evaluation.json +0 -51
- wisent/examples/scripts/results/test_basque2/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/results/test_basque_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_basque_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_basque_glue/test_basque-glue_evaluation.json +0 -51
- wisent/examples/scripts/results/test_basque_glue/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/results/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/results/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/results/test_bbh_evaluation.json +0 -51
- wisent/examples/scripts/results/test_bbh_pairs.json +0 -14
- wisent/examples/scripts/results/test_bbq_evaluation.json +0 -30
- wisent/examples/scripts/results/test_bbq_pairs.json +0 -8
- wisent/examples/scripts/results/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/results/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/results/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/results/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/results/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/results/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/results/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/results/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/results/test_bigbench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_bigbench_pairs.json +0 -14
- wisent/examples/scripts/results/test_blimp_evaluation.json +0 -51
- wisent/examples/scripts/results/test_blimp_pairs.json +0 -14
- wisent/examples/scripts/results/test_boolq/test_boolq_evaluation.json +0 -30
- wisent/examples/scripts/results/test_boolq/test_boolq_pairs.json +0 -8
- wisent/examples/scripts/results/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/results/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/results/test_boolq_evaluation.json +0 -30
- wisent/examples/scripts/results/test_boolq_pairs.json +0 -8
- wisent/examples/scripts/results/test_c4_evaluation.json +0 -30
- wisent/examples/scripts/results/test_c4_pairs.json +0 -8
- wisent/examples/scripts/results/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/results/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/results/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_catalan_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_catalan_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/results/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/results/test_cb_evaluation.json +0 -30
- wisent/examples/scripts/results/test_cb_pairs.json +0 -8
- wisent/examples/scripts/results/test_ceval/test_ceval_evaluation.json +0 -51
- wisent/examples/scripts/results/test_ceval/test_ceval_pairs.json +0 -14
- wisent/examples/scripts/results/test_ceval_accountant/test_ceval-valid_accountant_evaluation.json +0 -30
- wisent/examples/scripts/results/test_ceval_accountant/test_ceval-valid_accountant_pairs.json +0 -8
- wisent/examples/scripts/results/test_ceval_evaluation.json +0 -51
- wisent/examples/scripts/results/test_ceval_pairs.json +0 -14
- wisent/examples/scripts/results/test_ceval_valid/test_ceval_valid_evaluation.json +0 -51
- wisent/examples/scripts/results/test_ceval_valid/test_ceval_valid_pairs.json +0 -14
- wisent/examples/scripts/results/test_chain_of_thought_evaluation.json +0 -51
- wisent/examples/scripts/results/test_chain_of_thought_pairs.json +0 -14
- wisent/examples/scripts/results/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/results/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/results/test_cmmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_cmmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/results/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/results/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/results/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/results/test_codexglue_code_to_text_go_evaluation.json +0 -30
- wisent/examples/scripts/results/test_codexglue_code_to_text_go_pairs.json +0 -8
- wisent/examples/scripts/results/test_codexglue_code_to_text_java_evaluation.json +0 -30
- wisent/examples/scripts/results/test_codexglue_code_to_text_java_pairs.json +0 -8
- wisent/examples/scripts/results/test_codexglue_code_to_text_javascript_evaluation.json +0 -30
- wisent/examples/scripts/results/test_codexglue_code_to_text_javascript_pairs.json +0 -8
- wisent/examples/scripts/results/test_codexglue_code_to_text_php_evaluation.json +0 -30
- wisent/examples/scripts/results/test_codexglue_code_to_text_php_pairs.json +0 -8
- wisent/examples/scripts/results/test_codexglue_code_to_text_python_evaluation.json +0 -30
- wisent/examples/scripts/results/test_codexglue_code_to_text_python_pairs.json +0 -8
- wisent/examples/scripts/results/test_codexglue_code_to_text_ruby_evaluation.json +0 -30
- wisent/examples/scripts/results/test_codexglue_code_to_text_ruby_pairs.json +0 -8
- wisent/examples/scripts/results/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/results/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/results/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/results/test_cola_pairs.json +0 -8
- wisent/examples/scripts/results/test_commonsense_qa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_commonsense_qa_pairs.json +0 -8
- wisent/examples/scripts/results/test_conala_evaluation.json +0 -30
- wisent/examples/scripts/results/test_conala_pairs.json +0 -8
- wisent/examples/scripts/results/test_concode_evaluation.json +0 -30
- wisent/examples/scripts/results/test_concode_pairs.json +0 -8
- wisent/examples/scripts/results/test_copa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_copa_pairs.json +0 -8
- wisent/examples/scripts/results/test_copal_id_evaluation.json +0 -30
- wisent/examples/scripts/results/test_copal_id_pairs.json +0 -8
- wisent/examples/scripts/results/test_coqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_coqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/results/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/results/test_crows_pairs_evaluation.json +0 -51
- wisent/examples/scripts/results/test_crows_pairs_pairs.json +0 -14
- wisent/examples/scripts/results/test_csatqa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_csatqa_pairs.json +0 -14
- wisent/examples/scripts/results/test_cycle_letters_evaluation.json +0 -30
- wisent/examples/scripts/results/test_cycle_letters_pairs.json +0 -8
- wisent/examples/scripts/results/test_darija_bench/test_darija_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_darija_bench/test_darija_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_darija_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_darija_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_darijahellaswag_evaluation.json +0 -30
- wisent/examples/scripts/results/test_darijahellaswag_pairs.json +0 -8
- wisent/examples/scripts/results/test_darijammlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_darijammlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/results/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/results/test_drop_evaluation.json +0 -30
- wisent/examples/scripts/results/test_drop_pairs.json +0 -8
- wisent/examples/scripts/results/test_ds1000_evaluation.json +0 -30
- wisent/examples/scripts/results/test_ds1000_pairs.json +0 -8
- wisent/examples/scripts/results/test_egyhellaswag_evaluation.json +0 -30
- wisent/examples/scripts/results/test_egyhellaswag_pairs.json +0 -8
- wisent/examples/scripts/results/test_egymmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_egymmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/results/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/results/test_eq_bench_evaluation.json +0 -30
- wisent/examples/scripts/results/test_eq_bench_pairs.json +0 -8
- wisent/examples/scripts/results/test_escola_evaluation.json +0 -30
- wisent/examples/scripts/results/test_escola_pairs.json +0 -8
- wisent/examples/scripts/results/test_ethics_cm_evaluation.json +0 -30
- wisent/examples/scripts/results/test_ethics_cm_pairs.json +0 -8
- wisent/examples/scripts/results/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/results/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/results/test_eus_exams/test_eus_exams_evaluation.json +0 -51
- wisent/examples/scripts/results/test_eus_exams/test_eus_exams_pairs.json +0 -14
- wisent/examples/scripts/results/test_eus_exams_es_evaluation.json +0 -51
- wisent/examples/scripts/results/test_eus_exams_es_pairs.json +0 -14
- wisent/examples/scripts/results/test_eus_exams_evaluation.json +0 -51
- wisent/examples/scripts/results/test_eus_exams_pairs.json +0 -14
- wisent/examples/scripts/results/test_eus_proficiency_evaluation.json +0 -30
- wisent/examples/scripts/results/test_eus_proficiency_pairs.json +0 -8
- wisent/examples/scripts/results/test_eus_reading_evaluation.json +0 -30
- wisent/examples/scripts/results/test_eus_reading_pairs.json +0 -8
- wisent/examples/scripts/results/test_eus_trivia_evaluation.json +0 -30
- wisent/examples/scripts/results/test_eus_trivia_pairs.json +0 -8
- wisent/examples/scripts/results/test_evalita-mp_evaluation.json +0 -51
- wisent/examples/scripts/results/test_evalita-mp_pairs.json +0 -14
- wisent/examples/scripts/results/test_evalita-sp_sum_task_fp-small_p1_evaluation.json +0 -30
- wisent/examples/scripts/results/test_evalita-sp_sum_task_fp-small_p1_pairs.json +0 -8
- wisent/examples/scripts/results/test_evalita_LLM_evaluation.json +0 -51
- wisent/examples/scripts/results/test_evalita_LLM_pairs.json +0 -14
- wisent/examples/scripts/results/test_evalita_llm/test_evalita_llm_evaluation.json +0 -51
- wisent/examples/scripts/results/test_evalita_llm/test_evalita_llm_pairs.json +0 -14
- wisent/examples/scripts/results/test_evalita_mp/test_evalita-mp_te_prompt-1_evaluation.json +0 -30
- wisent/examples/scripts/results/test_evalita_mp/test_evalita-mp_te_prompt-1_pairs.json +0 -8
- wisent/examples/scripts/results/test_evalita_mp2/test_evalita_mp_evaluation.json +0 -51
- wisent/examples/scripts/results/test_evalita_mp2/test_evalita_mp_pairs.json +0 -14
- wisent/examples/scripts/results/test_evalita_sp2/test_evalita-sp_sum_task_fp-small_p1_evaluation.json +0 -30
- wisent/examples/scripts/results/test_evalita_sp2/test_evalita-sp_sum_task_fp-small_p1_pairs.json +0 -8
- wisent/examples/scripts/results/test_fda_evaluation.json +0 -30
- wisent/examples/scripts/results/test_fda_pairs.json +0 -8
- wisent/examples/scripts/results/test_financial_tweets_evaluation.json +0 -30
- wisent/examples/scripts/results/test_financial_tweets_pairs.json +0 -8
- wisent/examples/scripts/results/test_fld/test_fld_evaluation.json +0 -30
- wisent/examples/scripts/results/test_fld/test_fld_pairs.json +0 -8
- wisent/examples/scripts/results/test_fld_evaluation.json +0 -30
- wisent/examples/scripts/results/test_fld_fixed/test_fld_evaluation.json +0 -30
- wisent/examples/scripts/results/test_fld_fixed/test_fld_pairs.json +0 -8
- wisent/examples/scripts/results/test_fld_pairs.json +0 -8
- wisent/examples/scripts/results/test_flores_evaluation.json +0 -51
- wisent/examples/scripts/results/test_flores_pairs.json +0 -14
- wisent/examples/scripts/results/test_freebase_evaluation.json +0 -30
- wisent/examples/scripts/results/test_freebase_pairs.json +0 -8
- wisent/examples/scripts/results/test_french_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_french_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_galcola_evaluation.json +0 -30
- wisent/examples/scripts/results/test_galcola_pairs.json +0 -8
- wisent/examples/scripts/results/test_galician_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_galician_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_glianorex_evaluation.json +0 -30
- wisent/examples/scripts/results/test_glianorex_pairs.json +0 -8
- wisent/examples/scripts/results/test_global_mmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_global_mmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_glue_evaluation.json +0 -51
- wisent/examples/scripts/results/test_glue_pairs.json +0 -14
- wisent/examples/scripts/results/test_gpqa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_gpqa_pairs.json +0 -14
- wisent/examples/scripts/results/test_gpt3_translation_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/results/test_gpt3_translation_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/results/test_groundcocoa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_groundcocoa_pairs.json +0 -8
- wisent/examples/scripts/results/test_gsm8k_evaluation.json +0 -30
- wisent/examples/scripts/results/test_gsm8k_pairs.json +0 -8
- wisent/examples/scripts/results/test_haerae_evaluation.json +0 -51
- wisent/examples/scripts/results/test_haerae_pairs.json +0 -14
- wisent/examples/scripts/results/test_headqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_headqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_hellaswag_evaluation.json +0 -30
- wisent/examples/scripts/results/test_hellaswag_pairs.json +0 -8
- wisent/examples/scripts/results/test_hendrycks_ethics_evaluation.json +0 -51
- wisent/examples/scripts/results/test_hendrycks_ethics_pairs.json +0 -14
- wisent/examples/scripts/results/test_hendrycks_math_evaluation.json +0 -51
- wisent/examples/scripts/results/test_hendrycks_math_pairs.json +0 -14
- wisent/examples/scripts/results/test_histoires_morales_evaluation.json +0 -30
- wisent/examples/scripts/results/test_histoires_morales_pairs.json +0 -8
- wisent/examples/scripts/results/test_hmmt_evaluation.json +0 -30
- wisent/examples/scripts/results/test_hmmt_feb_2025_evaluation.json +0 -30
- wisent/examples/scripts/results/test_hmmt_feb_2025_pairs.json +0 -8
- wisent/examples/scripts/results/test_hmmt_pairs.json +0 -8
- wisent/examples/scripts/results/test_hrm8k_evaluation.json +0 -51
- wisent/examples/scripts/results/test_hrm8k_pairs.json +0 -14
- wisent/examples/scripts/results/test_humaneval_evaluation.json +0 -30
- wisent/examples/scripts/results/test_humaneval_pairs.json +0 -8
- wisent/examples/scripts/results/test_humaneval_plus_evaluation.json +0 -30
- wisent/examples/scripts/results/test_humaneval_plus_pairs.json +0 -8
- wisent/examples/scripts/results/test_ifeval_evaluation.json +0 -30
- wisent/examples/scripts/results/test_ifeval_pairs.json +0 -8
- wisent/examples/scripts/results/test_instruct_humaneval/test_instruct_humaneval_evaluation.json +0 -30
- wisent/examples/scripts/results/test_instruct_humaneval/test_instruct_humaneval_pairs.json +0 -8
- wisent/examples/scripts/results/test_instruct_humaneval_evaluation.json +0 -30
- wisent/examples/scripts/results/test_instruct_humaneval_pairs.json +0 -8
- wisent/examples/scripts/results/test_inverse_scaling_evaluation.json +0 -51
- wisent/examples/scripts/results/test_inverse_scaling_hindsight_neglect_10shot_evaluation.json +0 -30
- wisent/examples/scripts/results/test_inverse_scaling_hindsight_neglect_10shot_pairs.json +0 -8
- wisent/examples/scripts/results/test_inverse_scaling_mc/test_inverse_scaling_mc_evaluation.json +0 -51
- wisent/examples/scripts/results/test_inverse_scaling_mc/test_inverse_scaling_mc_pairs.json +0 -14
- wisent/examples/scripts/results/test_inverse_scaling_pairs.json +0 -14
- wisent/examples/scripts/results/test_iwslt2017-ar-en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_iwslt2017-ar-en_pairs.json +0 -8
- wisent/examples/scripts/results/test_iwslt2017-en-ar_evaluation.json +0 -30
- wisent/examples/scripts/results/test_iwslt2017-en-ar_pairs.json +0 -8
- wisent/examples/scripts/results/test_iwslt2017_ar_en/test_iwslt2017-ar-en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_iwslt2017_ar_en/test_iwslt2017-ar-en_pairs.json +0 -8
- wisent/examples/scripts/results/test_iwslt2017_en_ar/test_iwslt2017-en-ar_evaluation.json +0 -30
- wisent/examples/scripts/results/test_iwslt2017_en_ar/test_iwslt2017-en-ar_pairs.json +0 -8
- wisent/examples/scripts/results/test_iwslt2017_group/test_iwslt2017_evaluation.json +0 -30
- wisent/examples/scripts/results/test_iwslt2017_group/test_iwslt2017_pairs.json +0 -8
- wisent/examples/scripts/results/test_japanese_leaderboard_evaluation.json +0 -51
- wisent/examples/scripts/results/test_japanese_leaderboard_pairs.json +0 -14
- wisent/examples/scripts/results/test_jsonschema_bench/test_jsonschema_bench_evaluation.json +0 -30
- wisent/examples/scripts/results/test_jsonschema_bench/test_jsonschema_bench_pairs.json +0 -8
- wisent/examples/scripts/results/test_jsonschema_bench_evaluation.json +0 -30
- wisent/examples/scripts/results/test_jsonschema_bench_final/test_jsonschema_bench_evaluation.json +0 -30
- wisent/examples/scripts/results/test_jsonschema_bench_final/test_jsonschema_bench_pairs.json +0 -8
- wisent/examples/scripts/results/test_jsonschema_bench_pairs.json +0 -8
- wisent/examples/scripts/results/test_kbl_evaluation.json +0 -51
- wisent/examples/scripts/results/test_kbl_fixed/test_kbl_evaluation.json +0 -51
- wisent/examples/scripts/results/test_kbl_fixed/test_kbl_pairs.json +0 -14
- wisent/examples/scripts/results/test_kbl_pairs.json +0 -14
- wisent/examples/scripts/results/test_kmmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_kmmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_kobest_evaluation.json +0 -51
- wisent/examples/scripts/results/test_kobest_pairs.json +0 -14
- wisent/examples/scripts/results/test_kormedmcqa/test_kormedmcqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_kormedmcqa/test_kormedmcqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_kormedmcqa_dentist/test_kormedmcqa_dentist_evaluation.json +0 -30
- wisent/examples/scripts/results/test_kormedmcqa_dentist/test_kormedmcqa_dentist_pairs.json +0 -8
- wisent/examples/scripts/results/test_kormedmcqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_kormedmcqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_lambada_cloze_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lambada_cloze_pairs.json +0 -8
- wisent/examples/scripts/results/test_lambada_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lambada_final/test_lambada_openai_mt_stablelm_en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lambada_final/test_lambada_openai_mt_stablelm_en_pairs.json +0 -8
- wisent/examples/scripts/results/test_lambada_multilingual/test_lambada_multilingual_evaluation.json +0 -51
- wisent/examples/scripts/results/test_lambada_multilingual/test_lambada_multilingual_pairs.json +0 -14
- wisent/examples/scripts/results/test_lambada_multilingual_evaluation.json +0 -51
- wisent/examples/scripts/results/test_lambada_multilingual_pairs.json +0 -14
- wisent/examples/scripts/results/test_lambada_multilingual_stablelm_evaluation.json +0 -51
- wisent/examples/scripts/results/test_lambada_multilingual_stablelm_pairs.json +0 -14
- wisent/examples/scripts/results/test_lambada_openai_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lambada_openai_pairs.json +0 -8
- wisent/examples/scripts/results/test_lambada_pairs.json +0 -8
- wisent/examples/scripts/results/test_lambada_stablelm_en_fixed/test_lambada_openai_mt_stablelm_en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lambada_stablelm_en_fixed/test_lambada_openai_mt_stablelm_en_pairs.json +0 -8
- wisent/examples/scripts/results/test_lambada_stablelm_fixed/test_lambada_openai_mt_stablelm_en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lambada_stablelm_fixed/test_lambada_openai_mt_stablelm_en_pairs.json +0 -8
- wisent/examples/scripts/results/test_lambada_standard_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lambada_standard_pairs.json +0 -8
- wisent/examples/scripts/results/test_leaderboard_evaluation.json +0 -51
- wisent/examples/scripts/results/test_leaderboard_pairs.json +0 -14
- wisent/examples/scripts/results/test_libra/test_libra_evaluation.json +0 -51
- wisent/examples/scripts/results/test_libra/test_libra_pairs.json +0 -14
- wisent/examples/scripts/results/test_libra_evaluation.json +0 -51
- wisent/examples/scripts/results/test_libra_pairs.json +0 -14
- wisent/examples/scripts/results/test_lingoly_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lingoly_pairs.json +0 -8
- wisent/examples/scripts/results/test_livecodebench_evaluation.json +0 -30
- wisent/examples/scripts/results/test_livecodebench_pairs.json +0 -8
- wisent/examples/scripts/results/test_livemathbench_cnmo_en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_livemathbench_cnmo_en_pairs.json +0 -8
- wisent/examples/scripts/results/test_livemathbench_cnmo_zh_evaluation.json +0 -30
- wisent/examples/scripts/results/test_livemathbench_cnmo_zh_pairs.json +0 -8
- wisent/examples/scripts/results/test_llama_evaluation.json +0 -30
- wisent/examples/scripts/results/test_llama_pairs.json +0 -8
- wisent/examples/scripts/results/test_logiqa2_evaluation.json +0 -30
- wisent/examples/scripts/results/test_logiqa2_pairs.json +0 -8
- wisent/examples/scripts/results/test_logiqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_logiqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_m_mmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_m_mmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_mastermind/test_mastermind_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mastermind/test_mastermind_pairs.json +0 -14
- wisent/examples/scripts/results/test_mastermind_24_easy/test_mastermind_24_easy_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mastermind_24_easy/test_mastermind_24_easy_pairs.json +0 -8
- wisent/examples/scripts/results/test_mastermind_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mastermind_pairs.json +0 -14
- wisent/examples/scripts/results/test_math500_evaluation.json +0 -30
- wisent/examples/scripts/results/test_math500_pairs.json +0 -8
- wisent/examples/scripts/results/test_math_evaluation.json +0 -30
- wisent/examples/scripts/results/test_math_pairs.json +0 -8
- wisent/examples/scripts/results/test_mathqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mathqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_mbpp_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mbpp_pairs.json +0 -8
- wisent/examples/scripts/results/test_mbpp_plus_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mbpp_plus_pairs.json +0 -8
- wisent/examples/scripts/results/test_mc_taco_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mc_taco_pairs.json +0 -8
- wisent/examples/scripts/results/test_med_concepts_qa/test_med_concepts_qa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_med_concepts_qa/test_med_concepts_qa_pairs.json +0 -14
- wisent/examples/scripts/results/test_med_concepts_qa_atc_easy/test_med_concepts_qa_atc_easy_evaluation.json +0 -30
- wisent/examples/scripts/results/test_med_concepts_qa_atc_easy/test_med_concepts_qa_atc_easy_pairs.json +0 -8
- wisent/examples/scripts/results/test_med_concepts_qa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_med_concepts_qa_pairs.json +0 -14
- wisent/examples/scripts/results/test_meddialog_evaluation.json +0 -30
- wisent/examples/scripts/results/test_meddialog_pairs.json +0 -8
- wisent/examples/scripts/results/test_meddialog_raw_perplexity/test_meddialog_raw_perplexity_evaluation.json +0 -30
- wisent/examples/scripts/results/test_meddialog_raw_perplexity/test_meddialog_raw_perplexity_pairs.json +0 -8
- wisent/examples/scripts/results/test_mediqa_qa2019_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mediqa_qa2019_pairs.json +0 -8
- wisent/examples/scripts/results/test_medmcqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_medmcqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_medqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_medqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_medtext_evaluation.json +0 -30
- wisent/examples/scripts/results/test_medtext_pairs.json +0 -8
- wisent/examples/scripts/results/test_mela_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mela_pairs.json +0 -14
- wisent/examples/scripts/results/test_meqsum_evaluation.json +0 -30
- wisent/examples/scripts/results/test_meqsum_pairs.json +0 -8
- wisent/examples/scripts/results/test_mercury_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mercury_pairs.json +0 -8
- wisent/examples/scripts/results/test_metabench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_metabench_pairs.json +0 -14
- wisent/examples/scripts/results/test_mgsm_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mgsm_pairs.json +0 -14
- wisent/examples/scripts/results/test_mimic_repsum_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mimic_repsum_pairs.json +0 -8
- wisent/examples/scripts/results/test_minerva_math_evaluation.json +0 -51
- wisent/examples/scripts/results/test_minerva_math_pairs.json +0 -14
- wisent/examples/scripts/results/test_mlqa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mlqa_pairs.json +0 -14
- wisent/examples/scripts/results/test_mmlu-pro-plus_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mmlu-pro-plus_pairs.json +0 -14
- wisent/examples/scripts/results/test_mmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_mmlu_pro_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mmlu_pro_pairs.json +0 -14
- wisent/examples/scripts/results/test_mmlu_prox_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mmlu_prox_pairs.json +0 -14
- wisent/examples/scripts/results/test_mmlusr_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mmlusr_pairs.json +0 -8
- wisent/examples/scripts/results/test_mmmu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mmmu_pairs.json +0 -14
- wisent/examples/scripts/results/test_mnli_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mnli_pairs.json +0 -8
- wisent/examples/scripts/results/test_model_written_evals_evaluation.json +0 -51
- wisent/examples/scripts/results/test_model_written_evals_pairs.json +0 -14
- wisent/examples/scripts/results/test_moral_stories_evaluation.json +0 -30
- wisent/examples/scripts/results/test_moral_stories_pairs.json +0 -8
- wisent/examples/scripts/results/test_mts_dialog_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mts_dialog_pairs.json +0 -8
- wisent/examples/scripts/results/test_multiblimp_evaluation.json +0 -51
- wisent/examples/scripts/results/test_multiblimp_pairs.json +0 -14
- wisent/examples/scripts/results/test_multimedqa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_multimedqa_pairs.json +0 -14
- wisent/examples/scripts/results/test_multipl_e_evaluation.json +0 -30
- wisent/examples/scripts/results/test_multipl_e_pairs.json +0 -8
- wisent/examples/scripts/results/test_mutual_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mutual_pairs.json +0 -8
- wisent/examples/scripts/results/test_non_greedy_robustness_agieval_aqua_rat_evaluation.json +0 -30
- wisent/examples/scripts/results/test_non_greedy_robustness_agieval_aqua_rat_pairs.json +0 -8
- wisent/examples/scripts/results/test_noreval_evaluation.json +0 -51
- wisent/examples/scripts/results/test_noreval_pairs.json +0 -14
- wisent/examples/scripts/results/test_noticia_evaluation.json +0 -30
- wisent/examples/scripts/results/test_noticia_pairs.json +0 -8
- wisent/examples/scripts/results/test_nq_open_evaluation.json +0 -30
- wisent/examples/scripts/results/test_nq_open_pairs.json +0 -8
- wisent/examples/scripts/results/test_olaph_evaluation.json +0 -30
- wisent/examples/scripts/results/test_olaph_pairs.json +0 -8
- wisent/examples/scripts/results/test_openbookqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_openbookqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_openllm_evaluation.json +0 -51
- wisent/examples/scripts/results/test_openllm_pairs.json +0 -14
- wisent/examples/scripts/results/test_option_order_robustness_agieval_aqua_rat_evaluation.json +0 -30
- wisent/examples/scripts/results/test_option_order_robustness_agieval_aqua_rat_pairs.json +0 -8
- wisent/examples/scripts/results/test_paloma_evaluation.json +0 -51
- wisent/examples/scripts/results/test_paloma_pairs.json +0 -14
- wisent/examples/scripts/results/test_passkey/test_passkey_evaluation.json +0 -30
- wisent/examples/scripts/results/test_passkey/test_passkey_pairs.json +0 -8
- wisent/examples/scripts/results/test_paws-x_evaluation.json +0 -51
- wisent/examples/scripts/results/test_paws-x_pairs.json +0 -14
- wisent/examples/scripts/results/test_paws_en/test_paws_en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_paws_en/test_paws_en_pairs.json +0 -8
- wisent/examples/scripts/results/test_penn_treebank_evaluation.json +0 -30
- wisent/examples/scripts/results/test_penn_treebank_pairs.json +0 -8
- wisent/examples/scripts/results/test_pile_10k/test_pile_10k_evaluation.json +0 -30
- wisent/examples/scripts/results/test_pile_10k/test_pile_10k_pairs.json +0 -8
- wisent/examples/scripts/results/test_piqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_piqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_polemo2_evaluation.json +0 -30
- wisent/examples/scripts/results/test_polemo2_pairs.json +0 -8
- wisent/examples/scripts/results/test_polymath_en_high_evaluation.json +0 -30
- wisent/examples/scripts/results/test_polymath_en_high_pairs.json +0 -8
- wisent/examples/scripts/results/test_polymath_en_medium_evaluation.json +0 -30
- wisent/examples/scripts/results/test_polymath_en_medium_pairs.json +0 -8
- wisent/examples/scripts/results/test_polymath_zh_high_evaluation.json +0 -30
- wisent/examples/scripts/results/test_polymath_zh_high_pairs.json +0 -8
- wisent/examples/scripts/results/test_polymath_zh_medium_evaluation.json +0 -30
- wisent/examples/scripts/results/test_polymath_zh_medium_pairs.json +0 -8
- wisent/examples/scripts/results/test_portuguese_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_portuguese_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_prompt_robustness_agieval_aqua_rat/test_prompt_robustness_agieval_aqua_rat_evaluation.json +0 -30
- wisent/examples/scripts/results/test_prompt_robustness_agieval_aqua_rat/test_prompt_robustness_agieval_aqua_rat_pairs.json +0 -8
- wisent/examples/scripts/results/test_prompt_robustness_agieval_aqua_rat_evaluation.json +0 -30
- wisent/examples/scripts/results/test_prompt_robustness_agieval_aqua_rat_pairs.json +0 -8
- wisent/examples/scripts/results/test_prost_evaluation.json +0 -30
- wisent/examples/scripts/results/test_prost_pairs.json +0 -8
- wisent/examples/scripts/results/test_ptb_evaluation.json +0 -30
- wisent/examples/scripts/results/test_ptb_pairs.json +0 -8
- wisent/examples/scripts/results/test_pubmedqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_pubmedqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_pythia_evaluation.json +0 -51
- wisent/examples/scripts/results/test_pythia_pairs.json +0 -14
- wisent/examples/scripts/results/test_qa4mre_evaluation.json +0 -30
- wisent/examples/scripts/results/test_qa4mre_pairs.json +0 -8
- wisent/examples/scripts/results/test_qasper_evaluation.json +0 -30
- wisent/examples/scripts/results/test_qasper_pairs.json +0 -8
- wisent/examples/scripts/results/test_race_evaluation.json +0 -30
- wisent/examples/scripts/results/test_race_pairs.json +0 -8
- wisent/examples/scripts/results/test_realtoxicityprompts_evaluation.json +0 -30
- wisent/examples/scripts/results/test_realtoxicityprompts_pairs.json +0 -8
- wisent/examples/scripts/results/test_recode_evaluation.json +0 -30
- wisent/examples/scripts/results/test_recode_pairs.json +0 -8
- wisent/examples/scripts/results/test_record_evaluation.json +0 -30
- wisent/examples/scripts/results/test_record_pairs.json +0 -8
- wisent/examples/scripts/results/test_ruler_evaluation.json +0 -51
- wisent/examples/scripts/results/test_ruler_pairs.json +0 -14
- wisent/examples/scripts/results/test_sciq_evaluation.json +0 -30
- wisent/examples/scripts/results/test_sciq_pairs.json +0 -8
- wisent/examples/scripts/results/test_score_evaluation.json +0 -51
- wisent/examples/scripts/results/test_score_pairs.json +0 -14
- wisent/examples/scripts/results/test_self_consistency_evaluation.json +0 -30
- wisent/examples/scripts/results/test_self_consistency_pairs.json +0 -8
- wisent/examples/scripts/results/test_siqa/test_siqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_siqa/test_siqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_siqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_siqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_spanish_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_spanish_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_squad2_evaluation.json +0 -30
- wisent/examples/scripts/results/test_squad2_pairs.json +0 -8
- wisent/examples/scripts/results/test_squadv2_evaluation.json +0 -30
- wisent/examples/scripts/results/test_squadv2_pairs.json +0 -8
- wisent/examples/scripts/results/test_super-glue-lm-eval-v1-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/results/test_super-glue-lm-eval-v1-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/results/test_super-glue-lm-eval-v1_evaluation.json +0 -51
- wisent/examples/scripts/results/test_super-glue-lm-eval-v1_pairs.json +0 -14
- wisent/examples/scripts/results/test_swag_evaluation.json +0 -30
- wisent/examples/scripts/results/test_swag_pairs.json +0 -8
- wisent/examples/scripts/results/test_tinyBenchmarks_evaluation.json +0 -51
- wisent/examples/scripts/results/test_tinyBenchmarks_pairs.json +0 -14
- wisent/examples/scripts/results/test_tmmluplus_evaluation.json +0 -51
- wisent/examples/scripts/results/test_tmmluplus_pairs.json +0 -14
- wisent/examples/scripts/results/test_translation_evaluation.json +0 -51
- wisent/examples/scripts/results/test_translation_pairs.json +0 -14
- wisent/examples/scripts/results/test_triviaqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_triviaqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_truthfulqa-multi_evaluation.json +0 -51
- wisent/examples/scripts/results/test_truthfulqa-multi_pairs.json +0 -14
- wisent/examples/scripts/results/test_truthfulqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_truthfulqa_mc1_evaluation.json +0 -30
- wisent/examples/scripts/results/test_truthfulqa_mc1_pairs.json +0 -8
- wisent/examples/scripts/results/test_truthfulqa_mc2_evaluation.json +0 -30
- wisent/examples/scripts/results/test_truthfulqa_mc2_pairs.json +0 -8
- wisent/examples/scripts/results/test_truthfulqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_turkishmmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_turkishmmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_unfair_tos_evaluation.json +0 -30
- wisent/examples/scripts/results/test_unfair_tos_pairs.json +0 -8
- wisent/examples/scripts/results/test_unscramble_evaluation.json +0 -51
- wisent/examples/scripts/results/test_unscramble_pairs.json +0 -14
- wisent/examples/scripts/results/test_webqs_evaluation.json +0 -30
- wisent/examples/scripts/results/test_webqs_pairs.json +0 -8
- wisent/examples/scripts/results/test_wikitext103_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wikitext103_pairs.json +0 -8
- wisent/examples/scripts/results/test_wikitext_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wikitext_pairs.json +0 -8
- wisent/examples/scripts/results/test_winogender_evaluation.json +0 -51
- wisent/examples/scripts/results/test_winogender_pairs.json +0 -14
- wisent/examples/scripts/results/test_winogrande_evaluation.json +0 -30
- wisent/examples/scripts/results/test_winogrande_pairs.json +0 -8
- wisent/examples/scripts/results/test_wmdp_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wmdp_pairs.json +0 -8
- wisent/examples/scripts/results/test_wmt-ro-en-t5-prompt_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wmt-ro-en-t5-prompt_pairs.json +0 -8
- wisent/examples/scripts/results/test_wmt14_en_fr_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wmt14_en_fr_pairs.json +0 -8
- wisent/examples/scripts/results/test_wmt16_en_de_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wmt16_en_de_pairs.json +0 -8
- wisent/examples/scripts/results/test_wmt16_ro_en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wmt16_ro_en_pairs.json +0 -8
- wisent/examples/scripts/results/test_wsc273_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wsc273_pairs.json +0 -8
- wisent/examples/scripts/results/test_xcopa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_xcopa_pairs.json +0 -14
- wisent/examples/scripts/results/test_xnli_eu_evaluation.json +0 -30
- wisent/examples/scripts/results/test_xnli_eu_pairs.json +0 -8
- wisent/examples/scripts/results/test_xnli_evaluation.json +0 -51
- wisent/examples/scripts/results/test_xnli_pairs.json +0 -14
- wisent/examples/scripts/results/test_xquad_evaluation.json +0 -51
- wisent/examples/scripts/results/test_xquad_pairs.json +0 -14
- wisent/examples/scripts/results/test_xstorycloze_evaluation.json +0 -51
- wisent/examples/scripts/results/test_xstorycloze_pairs.json +0 -14
- wisent/examples/scripts/results/test_xsum_evaluation.json +0 -30
- wisent/examples/scripts/results/test_xsum_pairs.json +0 -8
- wisent/examples/scripts/results/test_xwinograd_evaluation.json +0 -51
- wisent/examples/scripts/results/test_xwinograd_pairs.json +0 -14
- wisent/examples/scripts/results/test_yahoo_answers_topics_evaluation.json +0 -30
- wisent/examples/scripts/results/test_yahoo_answers_topics_pairs.json +0 -8
- {wisent-0.7.379.dist-info → wisent-0.7.701.dist-info}/WHEEL +0 -0
- {wisent-0.7.379.dist-info → wisent-0.7.701.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.379.dist-info → wisent-0.7.701.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.379.dist-info → wisent-0.7.701.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Classifier inference strategies for runtime classification.
|
|
3
|
+
|
|
4
|
+
These strategies determine how to extract activations from generated text
|
|
5
|
+
at inference time when classifying responses.
|
|
6
|
+
|
|
7
|
+
Based on empirical testing across 3 models (Llama-3.2-1B, Llama-2-7b, Qwen3-8B)
|
|
8
|
+
and 4 tasks (truthfulqa, happy, left_wing, livecodebench):
|
|
9
|
+
|
|
10
|
+
Results:
|
|
11
|
+
- last_token: 66.3% avg accuracy (94.4% when paired with chat_last training)
|
|
12
|
+
- all_mean: 65.9% avg accuracy
|
|
13
|
+
- all_min: 53.5% avg accuracy
|
|
14
|
+
- all_max: 53.3% avg accuracy
|
|
15
|
+
- first_token: 50.0% avg accuracy (completely useless - BOS token is identical for all inputs)
|
|
16
|
+
|
|
17
|
+
Recommendation: Use LAST_TOKEN (default) - it works best with chat_last training strategy.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from enum import Enum
|
|
21
|
+
from typing import Optional
|
|
22
|
+
import argparse
|
|
23
|
+
import torch
|
|
24
|
+
import numpy as np
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ClassifierInferenceStrategy(str, Enum):
    """
    Ways of turning per-token activations into one classification at inference time.

    Because of the ``str`` mixin every member compares equal to its value,
    e.g. ``ClassifierInferenceStrategy.LAST_TOKEN == "last_token"``.
    """

    # Extract activation from the last token only. Best overall performance.
    LAST_TOKEN = "last_token"

    # Extract activation from the first token only.
    # NOT RECOMMENDED - BOS token has no variance.
    FIRST_TOKEN = "first_token"

    # Classify each token, return mean of all scores.
    ALL_MEAN = "all_mean"

    # Classify each token, return max score (most confident positive).
    ALL_MAX = "all_max"

    # Classify each token, return min score (most confident negative).
    ALL_MIN = "all_min"

    @property
    def description(self) -> str:
        """Short human-readable label for this strategy."""
        labels_by_value = {
            "last_token": "Last token activation (recommended)",
            "first_token": "First token activation (not recommended)",
            "all_mean": "Mean of all token scores",
            "all_max": "Max of all token scores",
            "all_min": "Min of all token scores",
        }
        return labels_by_value.get(self.value, "Unknown strategy")

    @classmethod
    def default(cls) -> "ClassifierInferenceStrategy":
        """Return the default strategy (last_token performs best)."""
        return cls.LAST_TOKEN

    @classmethod
    def list_all(cls) -> list[str]:
        """Return the string value of every strategy, in declaration order."""
        return [member.value for member in cls]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def extract_inference_activation(
    strategy: ClassifierInferenceStrategy,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """
    Reduce per-token hidden states to a single activation vector.

    This function selects or pools *activations*; it does not run a classifier.
    In particular, ALL_MAX and ALL_MIN return the hidden state of the token
    with the largest / smallest norm (``torch.norm`` over the hidden
    dimension), which is not the same thing as the max / min of per-token
    classifier scores.

    Args:
        strategy: The inference strategy to use. Plain string values such as
            ``"last_token"`` also match, because the enum is a ``str`` subclass.
        hidden_states: Hidden states tensor of shape [seq_len, hidden_dim]

    Returns:
        Activation vector of shape [hidden_dim]. Unrecognized strategies fall
        back to the last token.

    Raises:
        ValueError: If ``hidden_states`` contains no tokens.
    """
    # Without this guard ALL_MEAN would silently return a NaN vector and the
    # other strategies would fail with an opaque indexing error.
    if hidden_states.shape[0] == 0:
        raise ValueError("hidden_states must contain at least one token")

    if strategy == ClassifierInferenceStrategy.FIRST_TOKEN:
        return hidden_states[0]

    if strategy == ClassifierInferenceStrategy.ALL_MEAN:
        return hidden_states.mean(dim=0)

    if strategy == ClassifierInferenceStrategy.ALL_MAX:
        # Token with max norm
        token_norms = torch.norm(hidden_states, dim=1)
        return hidden_states[torch.argmax(token_norms)]

    if strategy == ClassifierInferenceStrategy.ALL_MIN:
        # Token with min norm
        token_norms = torch.norm(hidden_states, dim=1)
        return hidden_states[torch.argmin(token_norms)]

    # LAST_TOKEN, and the default fallback for anything unrecognized.
    return hidden_states[-1]
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def get_inference_score(
    classifier,
    hidden_states: torch.Tensor,
    strategy: ClassifierInferenceStrategy,
) -> float:
    """
    Get classifier score using the specified inference strategy.

    For single-token strategies (last_token, first_token), returns the classifier
    probability for that token.

    For all_* strategies, classifies every token in one batched
    ``predict_proba`` call and aggregates the per-token scores with
    mean / max / min.

    Args:
        classifier: A trained classifier with predict_proba method. Its output
            is indexed as ``[row, 1]``, i.e. column 1 is taken as the positive
            class.
        hidden_states: Hidden states tensor of shape [seq_len, hidden_dim]
        strategy: The inference strategy to use. Plain string values such as
            ``"all_mean"`` also match, because the enum is a ``str`` subclass.

    Returns:
        Classification score (probability of positive class). Unrecognized
        strategies fall back to the last-token score.

    Raises:
        ValueError: If ``hidden_states`` contains no tokens.
    """
    # detach() lets tensors that still require grad be converted to NumPy;
    # float() upcasts half/bfloat16, which NumPy cannot represent directly.
    hidden_np = hidden_states.detach().cpu().float().numpy()

    # np.mean over zero scores would silently produce NaN.
    if hidden_np.shape[0] == 0:
        raise ValueError("hidden_states must contain at least one token")

    if strategy == ClassifierInferenceStrategy.FIRST_TOKEN:
        return float(classifier.predict_proba([hidden_np[0]])[0, 1])

    # Compare with == rather than a dict/set lookup: a str-mixin Enum member
    # hashes by its name, so hashing would not match plain string strategies.
    if strategy == ClassifierInferenceStrategy.ALL_MEAN:
        aggregate = np.mean
    elif strategy == ClassifierInferenceStrategy.ALL_MAX:
        aggregate = np.max
    elif strategy == ClassifierInferenceStrategy.ALL_MIN:
        aggregate = np.min
    else:
        aggregate = None

    if aggregate is not None:
        # Score all tokens in a single call instead of one call per token.
        token_scores = np.asarray(classifier.predict_proba(hidden_np))[:, 1]
        return float(aggregate(token_scores))

    # LAST_TOKEN, and the default fallback for anything unrecognized.
    return float(classifier.predict_proba([hidden_np[-1]])[0, 1])
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def get_recommended_inference_strategy(train_strategy) -> ClassifierInferenceStrategy:
    """
    Map a training-time extraction strategy to the inference strategy to pair with it.

    Based on empirical testing:
    - chat_last, role_play, mc_balanced -> last_token (94.4%, 72.4%, 60.2%)
    - chat_mean, chat_weighted, chat_max_norm, chat_first, chat_gen_point -> all_mean

    Args:
        train_strategy: ExtractionStrategy used for training

    Returns:
        LAST_TOKEN for the three strategies listed above, ALL_MEAN for
        everything else.
    """
    # Import here to avoid circular dependency
    from wisent.core.activations.extraction_strategy import ExtractionStrategy

    # Kept as a tuple so membership is decided by ==, exactly like the
    # comparison chain this replaces.
    pairs_with_last_token = (
        ExtractionStrategy.CHAT_LAST,
        ExtractionStrategy.ROLE_PLAY,
        ExtractionStrategy.MC_BALANCED,
    )

    if train_strategy not in pairs_with_last_token:
        return ClassifierInferenceStrategy.ALL_MEAN
    return ClassifierInferenceStrategy.LAST_TOKEN
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def add_classifier_inference_strategy_args(parser: argparse.ArgumentParser) -> None:
    """
    Register the ``--classifier-inference-strategy`` option on ``parser``.

    The option accepts any strategy value from
    ``ClassifierInferenceStrategy.list_all()`` and defaults to the value of
    ``ClassifierInferenceStrategy.default()``.

    Args:
        parser: The argument parser to extend in place.
    """
    strategy_names = ClassifierInferenceStrategy.list_all()
    default_name = ClassifierInferenceStrategy.default().value

    parser.add_argument(
        "--classifier-inference-strategy",
        type=str,
        default=default_name,
        choices=strategy_names,
        help=f"Inference strategy for classifier. Options: {', '.join(strategy_names)}. Default: {default_name}",
    )
|
|
@@ -1,60 +1,16 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from enum import Enum, auto, unique
|
|
4
3
|
from typing import Mapping, Iterator, TypeAlias
|
|
5
4
|
import numpy as np
|
|
6
5
|
import torch
|
|
7
|
-
import sys
|
|
8
6
|
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
# Python 3.10 compatibility
|
|
12
|
-
if sys.version_info >= (3, 11):
|
|
13
|
-
from enum import StrEnum
|
|
14
|
-
else:
|
|
15
|
-
class StrEnum(str, Enum):
|
|
16
|
-
"""StrEnum backport for Python < 3.11"""
|
|
17
|
-
def _generate_next_value_(name, start, count, last_values):
|
|
18
|
-
return name.lower()
|
|
19
|
-
|
|
20
|
-
def __str__(self) -> str:
|
|
21
|
-
return str(self.value)
|
|
22
|
-
|
|
23
|
-
__all__ = ["LayerActivations", "ActivationAggregationStrategy", "ActivationCollector", "LayerName", "LayerActivation", "ActivationMap", "RawActivationMap"]
|
|
7
|
+
__all__ = ["LayerActivations", "LayerName", "LayerActivation", "ActivationMap", "RawActivationMap"]
|
|
24
8
|
|
|
25
9
|
LayerName: TypeAlias = str
|
|
26
10
|
LayerActivation: TypeAlias = torch.Tensor | None
|
|
27
11
|
ActivationMap: TypeAlias = Mapping[LayerName, LayerActivation]
|
|
28
12
|
RawActivationMap: TypeAlias = Mapping[LayerName, torch.Tensor | np.ndarray | None]
|
|
29
13
|
|
|
30
|
-
class _LowerSnakeStrEnum(StrEnum):
|
|
31
|
-
"""StrEnum whose auto() values are lower_snake_case of the member name."""
|
|
32
|
-
def _generate_next_value_(name, start, count, last_values): # type: ignore
|
|
33
|
-
return name.lower()
|
|
34
|
-
|
|
35
|
-
@unique
|
|
36
|
-
class ActivationAggregationStrategy(_LowerSnakeStrEnum):
|
|
37
|
-
"""Strategies for selecting/aggregating tokens in activation extraction.
|
|
38
|
-
"""
|
|
39
|
-
|
|
40
|
-
CHOICE_TOKEN = auto() # target A/B choice tokens (multiple choice)
|
|
41
|
-
CONTINUATION_TOKEN = auto() # first token of the continuation
|
|
42
|
-
LAST_TOKEN = auto() # always use the last token
|
|
43
|
-
FIRST_TOKEN = auto() # always use the first token
|
|
44
|
-
MEAN_POOLING = auto() # mean over all tokens
|
|
45
|
-
MAX_POOLING = auto() # max over all tokens
|
|
46
|
-
|
|
47
|
-
@property
|
|
48
|
-
def description(self) -> str:
|
|
49
|
-
return {
|
|
50
|
-
ActivationAggregationStrategy.CHOICE_TOKEN: "Target A/B choice tokens (multiple choice).",
|
|
51
|
-
ActivationAggregationStrategy.CONTINUATION_TOKEN: "Use the first token of the continuation.",
|
|
52
|
-
ActivationAggregationStrategy.LAST_TOKEN: "Always select the last token.",
|
|
53
|
-
ActivationAggregationStrategy.FIRST_TOKEN: "Always select the first token.",
|
|
54
|
-
ActivationAggregationStrategy.MEAN_POOLING: "Aggregate by mean over all tokens.",
|
|
55
|
-
ActivationAggregationStrategy.MAX_POOLING: "Aggregate by max over all tokens.",
|
|
56
|
-
}[self]
|
|
57
|
-
|
|
58
14
|
|
|
59
15
|
class LayerActivations(Mapping[LayerName, LayerActivation]):
|
|
60
16
|
"""Immutable mapping of layer names to activations.
|
|
@@ -72,8 +28,6 @@ class LayerActivations(Mapping[LayerName, LayerActivation]):
|
|
|
72
28
|
attributes:
|
|
73
29
|
_data:
|
|
74
30
|
internal storage dict. It contains information about layer activations.
|
|
75
|
-
_strategy:
|
|
76
|
-
'ActivationAggregationStrategy' (see below). Indicates how activations were aggregated if applicable.
|
|
77
31
|
|
|
78
32
|
methods:
|
|
79
33
|
'summary()':
|
|
@@ -88,13 +42,11 @@ class LayerActivations(Mapping[LayerName, LayerActivation]):
|
|
|
88
42
|
plain dict (useful for (de)serialization).
|
|
89
43
|
|
|
90
44
|
examples:
|
|
91
|
-
>>> acts = LayerActivations({"layer1": torch.randn(2, 10, 768), "layer2": None}
|
|
45
|
+
>>> acts = LayerActivations({"layer1": torch.randn(2, 10, 768), "layer2": None})
|
|
92
46
|
>>> acts["layer1"].shape
|
|
93
47
|
torch.Size([2, 10, 768])
|
|
94
48
|
>>> acts["layer2"] is None
|
|
95
49
|
True
|
|
96
|
-
>>> acts.activation_aggregation_strategy
|
|
97
|
-
<ActivationAggregationStrategy.MEAN_POOLING: 'mean_pooling'>
|
|
98
50
|
>>> acts.summary()
|
|
99
51
|
{'layer1': {'shape': (2, 10, 768), 'dtype': 'torch.float32', 'device': 'cpu', 'requires_grad': False}, 'layer2': {'shape': None, 'dtype': None, 'device': None, 'requires_grad': None}}
|
|
100
52
|
>>> acts.numpy()
|
|
@@ -104,19 +56,14 @@ class LayerActivations(Mapping[LayerName, LayerActivation]):
|
|
|
104
56
|
layer1: Tensor(shape=(2, 10, 768), dtype=torch.float32, device=cuda:0)
|
|
105
57
|
layer2: None
|
|
106
58
|
)
|
|
107
|
-
>>> acts.detach() # if any tensor required grad
|
|
108
|
-
LayerActivations(
|
|
109
|
-
layer1: Tensor(shape=(2, 10, 768), dtype=torch.float32, device=cpu)
|
|
110
|
-
layer2: None
|
|
111
|
-
)
|
|
112
59
|
|
|
113
60
|
notes:
|
|
114
61
|
- Use 'summary()' or 'numpy()' if you need JSON-serializable content.
|
|
115
62
|
- Keys are strings by convention; enforced by type hints.
|
|
116
63
|
"""
|
|
117
|
-
__slots__ = ("_data",
|
|
64
|
+
__slots__ = ("_data",)
|
|
118
65
|
|
|
119
|
-
def __init__(self, data: RawActivationMap | None = None,
|
|
66
|
+
def __init__(self, data: RawActivationMap | None = None, dtype: torch.dtype | None = None):
|
|
120
67
|
store: dict[LayerName, LayerActivation] = {}
|
|
121
68
|
if data:
|
|
122
69
|
for layer, val in data.items():
|
|
@@ -132,33 +79,6 @@ class LayerActivations(Mapping[LayerName, LayerActivation]):
|
|
|
132
79
|
f"Activations for layer '{layer}' must be torch.Tensor, np.ndarray, or None."
|
|
133
80
|
)
|
|
134
81
|
self._data = store
|
|
135
|
-
self._strategy = self._normalize_strategy(activation_aggregation_strategy)
|
|
136
|
-
|
|
137
|
-
@staticmethod
def _normalize_strategy(
    s: ActivationAggregationStrategy | str | None
) -> ActivationAggregationStrategy | None:
    """Coerce *s* into an ActivationAggregationStrategy member (or None).

    None and existing members are returned unchanged. A string is looked up
    by value; an unrecognised string raises UnknownTypeError listing the
    valid values. Any other type raises TypeError.
    """
    # Pass-through cases: nothing to convert.
    if s is None or isinstance(s, ActivationAggregationStrategy):
        return s

    if not isinstance(s, str):
        raise TypeError(
            "activation_agregation_strategy must be ActivationAggregationStrategy | str | None"
        )

    try:
        return ActivationAggregationStrategy(s)
    except ValueError:
        allowed = [member.value for member in ActivationAggregationStrategy]
        raise UnknownTypeError(
            entity_type="activation_agregation_strategy",
            value=s,
            valid_values=allowed
        )
|
|
158
|
-
|
|
159
|
-
@property
def activation_aggregation_strategy(self) -> ActivationAggregationStrategy | None:
    """Return the aggregation strategy stored on this instance, or None if none was set."""
    return self._strategy
|
|
162
82
|
|
|
163
83
|
def __getitem__(self, key: LayerName) -> LayerActivation:
    """Return the activation stored for layer ``key``.

    The lookup is delegated to the internal dict, so an unknown key raises KeyError.
    """
    return self._data[key]
|
|
@@ -168,10 +88,10 @@ class LayerActivations(Mapping[LayerName, LayerActivation]):
|
|
|
168
88
|
return len(self._data)
|
|
169
89
|
|
|
170
90
|
def summary(self) -> dict[LayerName, dict[str, tuple | str | bool | None]]:
|
|
171
|
-
|
|
172
|
-
shape, dtype, device, requires_grad status
|
|
173
|
-
|
|
174
|
-
out: dict[LayerName, dict[str,
|
|
91
|
+
"""Return a summary of the activations. For each layer, provides
|
|
92
|
+
shape, dtype, device, requires_grad status.
|
|
93
|
+
"""
|
|
94
|
+
out: dict[LayerName, dict[str, tuple | str | bool | None]] = {}
|
|
175
95
|
for k, v in self._data.items():
|
|
176
96
|
if isinstance(v, torch.Tensor):
|
|
177
97
|
out[k] = {
|
|
@@ -182,8 +102,6 @@ class LayerActivations(Mapping[LayerName, LayerActivation]):
|
|
|
182
102
|
}
|
|
183
103
|
else:
|
|
184
104
|
out[k] = {"shape": None, "dtype": None, "device": None, "requires_grad": None}
|
|
185
|
-
|
|
186
|
-
out["_activation_aggregation_strategy"] = {"strategy": self._strategy.value if self._strategy else None}
|
|
187
105
|
return out
|
|
188
106
|
|
|
189
107
|
def numpy(self) -> dict[LayerName, np.ndarray | None]:
|
|
@@ -214,6 +132,4 @@ class LayerActivations(Mapping[LayerName, LayerActivation]):
|
|
|
214
132
|
else:
|
|
215
133
|
lines.append(f" {k}: None")
|
|
216
134
|
lines.append(")")
|
|
217
|
-
lines.append(f" _activation_aggregation_strategy: {self._strategy.value if self._strategy else None}")
|
|
218
|
-
|
|
219
135
|
return "\n".join(lines)
|
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Unified extraction strategies for activation collection.
|
|
3
|
+
|
|
4
|
+
These strategies combine prompt construction and token extraction into a single
|
|
5
|
+
unified approach, based on empirical testing of what actually works.
|
|
6
|
+
|
|
7
|
+
The strategies are:
|
|
8
|
+
- chat_mean: Chat template prompt, mean of answer tokens
|
|
9
|
+
- chat_first: Chat template prompt, first answer token
|
|
10
|
+
- chat_last: Chat template prompt, last token
|
|
11
|
+
- chat_gen_point: Chat template prompt, token before answer (generation decision point)
|
|
12
|
+
- chat_max_norm: Chat template prompt, token with max norm in answer
|
|
13
|
+
- chat_weighted: Chat template prompt, position-weighted mean (earlier tokens weighted more)
|
|
14
|
+
- role_play: "Behave like person who answers Q with A" format, last token
|
|
15
|
+
- mc_balanced: Multiple choice with balanced A/B assignment, last token
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from enum import Enum
|
|
19
|
+
from typing import Tuple, Optional
|
|
20
|
+
import argparse
|
|
21
|
+
import torch
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ExtractionStrategy(str, Enum):
    """
    Closed set of activation-extraction recipes.

    Each member fixes both the prompt format and which token position(s) the
    activation is read from. Supersedes the former separate
    PromptConstructionStrategy and ActivationAggregationStrategy.
    """

    # Chat template prompt with Q+A, mean over the answer tokens.
    CHAT_MEAN = "chat_mean"

    # Chat template prompt with Q+A, first answer token.
    CHAT_FIRST = "chat_first"

    # Chat template prompt with Q+A, last token.
    CHAT_LAST = "chat_last"

    # Chat template prompt with Q+A, token just before the answer starts
    # (the generation decision point).
    CHAT_GEN_POINT = "chat_gen_point"

    # Chat template prompt with Q+A, answer-region token with the largest norm.
    CHAT_MAX_NORM = "chat_max_norm"

    # Chat template prompt with Q+A, position-weighted mean
    # (earlier tokens weighted more).
    CHAT_WEIGHTED = "chat_weighted"

    # "Behave like person who answers Q with A" format, last token.
    ROLE_PLAY = "role_play"

    # Multiple-choice format with balanced A/B assignment, last token.
    MC_BALANCED = "mc_balanced"

    @property
    def description(self) -> str:
        """Short human-readable summary of this strategy."""
        summaries = {
            "chat_mean": "Chat template with mean of answer tokens",
            "chat_first": "Chat template with first answer token",
            "chat_last": "Chat template with last token",
            "chat_gen_point": "Chat template with generation decision point",
            "chat_max_norm": "Chat template with max-norm answer token",
            "chat_weighted": "Chat template with position-weighted mean",
            "role_play": "Role-playing format with last token",
            "mc_balanced": "Balanced multiple choice with last token",
        }
        return summaries.get(self.value, "Unknown strategy")

    @classmethod
    def default(cls) -> "ExtractionStrategy":
        """Return the default strategy (chat_last is most commonly used)."""
        return cls("chat_last")

    @classmethod
    def list_all(cls) -> list[str]:
        """Return the string value of every strategy, in definition order."""
        return [member.value for member in cls.__members__.values()]
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Candidate assistant openers for the role_play strategy. One is picked per
# prompt from a stable digest of the prompt text, so the choice is reproducible
# across interpreter runs.
ROLE_PLAY_TOKENS = ["I", "Well", "The", "Sure", "Let", "That", "It", "This", "My", "To"]

# Maximum number of characters of each option shown in an mc_balanced prompt.
_MC_OPTION_MAX_CHARS = 200


def _stable_prompt_bucket(prompt: str, buckets: int) -> int:
    """Map *prompt* to an integer in ``range(buckets)``, reproducibly.

    The built-in ``hash()`` of a ``str`` is salted per interpreter process
    (unless PYTHONHASHSEED is pinned), so it cannot be used for choices that
    must be identical between runs. A cryptographic digest is used instead.
    """
    import hashlib  # local import keeps this block self-contained

    digest = hashlib.sha256(prompt.encode("utf-8", "surrogatepass")).digest()
    return int.from_bytes(digest[:8], "big") % buckets


def _render_chat_pair(tokenizer, user_content: str, assistant_content: str) -> Tuple[str, str]:
    """Render one user/assistant exchange; return ``(full_text, prompt_only)``.

    Uses ``tokenizer.apply_chat_template`` when the tokenizer has it. If the
    attribute is missing, or either call raises ValueError/KeyError, both
    texts fall back to plain strings: ``prompt_only`` is the user content and
    ``full_text`` is the user content, a space, then the assistant content.
    """
    plain = (f"{user_content} {assistant_content}", user_content)
    if not hasattr(tokenizer, "apply_chat_template"):
        return plain

    user_turn = {"role": "user", "content": user_content}
    try:
        prompt_only = tokenizer.apply_chat_template(
            [user_turn],
            tokenize=False,
            add_generation_prompt=True,
        )
        full_text = tokenizer.apply_chat_template(
            [user_turn, {"role": "assistant", "content": assistant_content}],
            tokenize=False,
            add_generation_prompt=False,
        )
    except (ValueError, KeyError):
        # Fallback for models without a usable chat template.
        return plain
    return full_text, prompt_only


def build_extraction_texts(
    strategy: ExtractionStrategy,
    prompt: str,
    response: str,
    tokenizer,
    other_response: Optional[str] = None,
    is_positive: bool = True,
) -> Tuple[str, str, Optional[str]]:
    """
    Build the full text for activation extraction based on strategy.

    Args:
        strategy: The extraction strategy to use
        prompt: The user prompt/question
        response: The response to extract activations for
        tokenizer: The tokenizer (needs apply_chat_template for chat strategies;
            a plain "prompt response" string is used when it is unavailable)
        other_response: For mc_balanced, the other response option
        is_positive: For mc_balanced, whether 'response' is the positive option

    Returns:
        Tuple of (full_text, answer_text, prompt_only_text)
        - full_text: Complete text to feed to model
        - answer_text: The assistant-side text: the response for chat_*
          strategies, the chosen opener token for role_play, and the option
          letter ("A" or "B") for mc_balanced
        - prompt_only_text: Prompt without answer (for boundary detection)

    Raises:
        ValueError: If mc_balanced is requested without other_response, or the
            strategy is not recognised.
    """
    if strategy in (ExtractionStrategy.CHAT_MEAN, ExtractionStrategy.CHAT_FIRST,
                    ExtractionStrategy.CHAT_LAST, ExtractionStrategy.CHAT_GEN_POINT,
                    ExtractionStrategy.CHAT_MAX_NORM, ExtractionStrategy.CHAT_WEIGHTED):
        # All chat_* strategies share the same prompt construction; they only
        # differ in which token is read afterwards.
        full_text, prompt_only = _render_chat_pair(tokenizer, prompt, response)
        return full_text, response, prompt_only

    if strategy == ExtractionStrategy.ROLE_PLAY:
        # "Behave like a person that would answer Q with A" format; the
        # assistant turn is a single opener word chosen from the prompt digest.
        random_token = ROLE_PLAY_TOKENS[_stable_prompt_bucket(prompt, len(ROLE_PLAY_TOKENS))]
        instruction = f"Behave like a person that would answer {prompt} with {response}"
        full_text, prompt_only = _render_chat_pair(tokenizer, instruction, random_token)
        return full_text, random_token, prompt_only

    if strategy == ExtractionStrategy.MC_BALANCED:
        if other_response is None:
            raise ValueError("MC_BALANCED strategy requires other_response")

        # The slot of the positive option depends only on the prompt, so the
        # positive and negative calls for one pair see the same A/B layout.
        pos_goes_in_b = _stable_prompt_bucket(prompt, 2) == 0
        # 'response' sits in slot B exactly when it is the positive option and
        # positives go in B, or it is the negative option and positives go in A.
        response_in_b = bool(is_positive) == pos_goes_in_b

        this_option = response[:_MC_OPTION_MAX_CHARS]
        that_option = other_response[:_MC_OPTION_MAX_CHARS]
        if response_in_b:
            option_a, option_b, answer = that_option, this_option, "B"
        else:
            option_a, option_b, answer = this_option, that_option, "A"

        # NOTE(review): the question text itself is not part of this prompt,
        # only the two options - confirm that is intended.
        mc_prompt = f"Which is correct?\nA. {option_a}\nB. {option_b}\nAnswer:"
        full_text, prompt_only = _render_chat_pair(tokenizer, mc_prompt, answer)
        return full_text, answer, prompt_only

    raise ValueError(f"Unknown extraction strategy: {strategy}")
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def extract_activation(
    strategy: ExtractionStrategy,
    hidden_states: torch.Tensor,
    answer_text: str,
    tokenizer,
    prompt_len: int,
) -> torch.Tensor:
    """
    Extract the activation vector based on strategy.

    The answer span is located by counting back from the end of the sequence:
    it is the ``num_answer_tokens`` positions immediately before the final
    position (the final position itself is excluded, presumably a single
    end-of-turn token - TODO confirm for templates that append more).

    Args:
        strategy: The extraction strategy
        hidden_states: Hidden states tensor of shape [seq_len, hidden_dim]
        answer_text: The answer text (for computing answer token count)
        tokenizer: The tokenizer; only invoked for strategies that need the
            answer token count (chat_first, chat_mean, chat_gen_point,
            chat_max_norm, chat_weighted)
        prompt_len: Length of prompt in tokens (boundary). Currently unused;
            kept for interface compatibility.

    Returns:
        Activation vector of shape [hidden_dim]
    """
    needs_answer_span = (
        ExtractionStrategy.CHAT_FIRST,
        ExtractionStrategy.CHAT_MEAN,
        ExtractionStrategy.CHAT_GEN_POINT,
        ExtractionStrategy.CHAT_MAX_NORM,
        ExtractionStrategy.CHAT_WEIGHTED,
    )
    if strategy not in needs_answer_span:
        # chat_last, role_play, mc_balanced and any unrecognised strategy all
        # read the last token, so the answer is not tokenized for them.
        return hidden_states[-1]

    seq_len = hidden_states.shape[0]
    num_answer_tokens = len(tokenizer(answer_text, add_special_tokens=False)["input_ids"])

    if strategy == ExtractionStrategy.CHAT_FIRST:
        # First token of the answer span, clamped to the start of the sequence.
        return hidden_states[max(0, seq_len - num_answer_tokens - 1)]

    if strategy == ExtractionStrategy.CHAT_GEN_POINT:
        # Token just before the answer span (generation decision point).
        return hidden_states[max(0, seq_len - num_answer_tokens - 2)]

    # The remaining strategies aggregate over the whole answer span; without a
    # usable span they fall back to the last token.
    if num_answer_tokens <= 0 or seq_len <= num_answer_tokens:
        return hidden_states[-1]
    answer_hidden = hidden_states[-num_answer_tokens - 1:-1]

    if strategy == ExtractionStrategy.CHAT_MEAN:
        return answer_hidden.mean(dim=0)

    if strategy == ExtractionStrategy.CHAT_MAX_NORM:
        norms = torch.norm(answer_hidden, dim=1)
        return answer_hidden[torch.argmax(norms)]

    # CHAT_WEIGHTED: exponentially decaying weights exp(-0.5 * position),
    # normalised to sum to 1, so earlier answer tokens count more.
    # NOTE(review): the weights are float32, so for lower-precision hidden
    # states the product is likely promoted to float32 - confirm callers
    # expect that.
    positions = torch.arange(
        answer_hidden.shape[0], dtype=torch.float32, device=answer_hidden.device
    )
    weights = torch.exp(-positions * 0.5)
    weights = weights / weights.sum()
    return (answer_hidden * weights.unsqueeze(1)).sum(dim=0)
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def add_extraction_strategy_args(parser: argparse.ArgumentParser) -> None:
    """
    Register the --extraction-strategy option on *parser*.

    The option accepts any ExtractionStrategy value and defaults to the value
    of ExtractionStrategy.default().

    Usage:
        parser = argparse.ArgumentParser()
        add_extraction_strategy_args(parser)
        args = parser.parse_args()
        strategy = ExtractionStrategy(args.extraction_strategy)
    """
    strategy_names = ExtractionStrategy.list_all()
    default_name = ExtractionStrategy.default().value
    help_text = (
        f"Extraction strategy for activations. Options: {', '.join(strategy_names)}. "
        f"Default: {default_name}"
    )
    parser.add_argument(
        "--extraction-strategy",
        type=str,
        default=default_name,
        choices=strategy_names,
        help=help_text,
    )
|