wisent 0.7.701__py3-none-any.whl → 0.7.1045__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/comparison/__init__.py +1 -0
- wisent/comparison/detect_bos_features.py +275 -0
- wisent/comparison/fgaa.py +465 -0
- wisent/comparison/lora.py +669 -0
- wisent/comparison/lora_dpo.py +592 -0
- wisent/comparison/main.py +444 -0
- wisent/comparison/ours.py +76 -0
- wisent/comparison/sae.py +304 -0
- wisent/comparison/utils.py +381 -0
- wisent/core/activations/activation_cache.py +393 -0
- wisent/core/activations/activations.py +3 -3
- wisent/core/activations/activations_collector.py +12 -7
- wisent/core/activations/classifier_inference_strategy.py +12 -11
- wisent/core/activations/extraction_strategy.py +260 -84
- wisent/core/classifiers/classifiers/core/atoms.py +3 -2
- wisent/core/cli/__init__.py +2 -1
- wisent/core/cli/agent/train_classifier.py +16 -3
- wisent/core/cli/check_linearity.py +35 -3
- wisent/core/cli/cluster_benchmarks.py +4 -6
- wisent/core/cli/create_steering_vector.py +6 -4
- wisent/core/cli/diagnose_vectors.py +7 -4
- wisent/core/cli/estimate_unified_goodness_time.py +6 -4
- wisent/core/cli/generate_pairs_from_task.py +9 -56
- wisent/core/cli/generate_vector_from_task.py +11 -20
- wisent/core/cli/geometry_search.py +137 -0
- wisent/core/cli/get_activations.py +2 -2
- wisent/core/cli/method_optimizer.py +4 -3
- wisent/core/cli/modify_weights.py +3 -2
- wisent/core/cli/optimize_sample_size.py +1 -1
- wisent/core/cli/optimize_steering.py +14 -16
- wisent/core/cli/optimize_weights.py +2 -1
- wisent/core/cli/preview_pairs.py +203 -0
- wisent/core/cli/steering_method_trainer.py +3 -3
- wisent/core/cli/tasks.py +19 -76
- wisent/core/cli/train_unified_goodness.py +3 -3
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
- wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +22 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +10 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +9 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
- wisent/core/data_loaders/loaders/lm_loader.py +12 -1
- wisent/core/geometry_runner.py +995 -0
- wisent/core/geometry_search_space.py +237 -0
- wisent/core/hyperparameter_optimizer.py +1 -1
- wisent/core/main.py +3 -0
- wisent/core/models/core/atoms.py +5 -3
- wisent/core/models/wisent_model.py +1 -1
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/parser_arguments/check_linearity_parser.py +12 -2
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +6 -13
- wisent/core/parser_arguments/geometry_search_parser.py +61 -0
- wisent/core/parser_arguments/get_activations_parser.py +5 -14
- wisent/core/parser_arguments/main_parser.py +8 -0
- wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
- wisent/core/steering.py +5 -3
- wisent/core/steering_methods/methods/hyperplane.py +2 -1
- wisent/core/synthetic/generators/nonsense_generator.py +30 -18
- wisent/core/trainers/steering_trainer.py +2 -2
- wisent/core/utils/device.py +27 -27
- wisent/core/utils/layer_combinations.py +70 -0
- wisent/examples/__init__.py +1 -0
- wisent/examples/scripts/__init__.py +1 -0
- wisent/examples/scripts/count_all_benchmarks.py +121 -0
- wisent/examples/scripts/discover_directions.py +469 -0
- wisent/examples/scripts/extract_benchmark_info.py +71 -0
- wisent/examples/scripts/search_all_short_names.py +31 -0
- wisent/examples/scripts/test_all_benchmarks.py +138 -0
- wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
- wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
- wisent/examples/scripts/test_nonsense_baseline.py +261 -0
- wisent/examples/scripts/test_one_benchmark.py +324 -0
- wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
- wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
- wisent/parameters/lm_eval/category_directions.json +137 -0
- wisent/parameters/lm_eval/repair_plan.json +282 -0
- wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
- wisent/parameters/lm_eval/working_benchmarks.json +206 -0
- wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
- wisent/tests/test_detector_accuracy.py +1 -1
- wisent/tests/visualize_geometry.py +1 -1
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/RECORD +328 -358
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
- wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
- wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cola_pairs.json +0 -8
- wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/2/test_atis_pairs.json +0 -8
- wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babi_pairs.json +0 -8
- wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
wisent/core/activations/extraction_strategy.py
CHANGED
@@ -4,15 +4,19 @@ Unified extraction strategies for activation collection.
 These strategies combine prompt construction and token extraction into a single
 unified approach, based on empirical testing of what actually works.

-…
+CHAT STRATEGIES (require chat template - for instruct models):
 - chat_mean: Chat template prompt, mean of answer tokens
 - chat_first: Chat template prompt, first answer token
 - chat_last: Chat template prompt, last token
-- chat_gen_point: Chat template prompt, token before answer (generation decision point)
 - chat_max_norm: Chat template prompt, token with max norm in answer
 - chat_weighted: Chat template prompt, position-weighted mean (earlier tokens weighted more)
 - role_play: "Behave like person who answers Q with A" format, last token
 - mc_balanced: Multiple choice with balanced A/B assignment, last token
+
+BASE MODEL STRATEGIES (no chat template - for base models like gemma-2b, gemma-9b):
+- completion_last: Direct Q+A completion, last token
+- completion_mean: Direct Q+A completion, mean of answer tokens
+- mc_completion: Multiple choice without chat template, A/B token
 """

 from enum import Enum
@@ -35,10 +39,7 @@ class ExtractionStrategy(str, Enum):
     """Chat template prompt with Q+A, extract first answer token."""

     CHAT_LAST = "chat_last"
-    """Chat template prompt with Q+A, extract last token."""
-
-    CHAT_GEN_POINT = "chat_gen_point"
-    """Chat template prompt with Q+A, extract token before answer starts (decision point)."""
+    """Chat template prompt with Q+A, extract EOT token (has seen full answer)."""

     CHAT_MAX_NORM = "chat_max_norm"
     """Chat template prompt with Q+A, extract token with max norm in answer region."""
@@ -47,22 +48,34 @@ class ExtractionStrategy(str, Enum):
     """Chat template prompt with Q+A, position-weighted mean (earlier tokens weighted more)."""

     ROLE_PLAY = "role_play"
-    """'Behave like person who answers Q with A' format, extract last token."""
+    """'Behave like person who answers Q with A' format, extract EOT token."""

     MC_BALANCED = "mc_balanced"
-    """Multiple choice format with balanced A/B assignment, extract last token."""
+    """Multiple choice format with balanced A/B assignment, extract the A/B choice token."""
+
+    # Base model strategies (no chat template required)
+    COMPLETION_LAST = "completion_last"
+    """Direct Q+A completion without chat template, extract last token. For base models."""
+
+    COMPLETION_MEAN = "completion_mean"
+    """Direct Q+A completion without chat template, extract mean of answer tokens. For base models."""
+
+    MC_COMPLETION = "mc_completion"
+    """Multiple choice without chat template, extract A/B token. For base models."""

     @property
     def description(self) -> str:
         descriptions = {
             ExtractionStrategy.CHAT_MEAN: "Chat template with mean of answer tokens",
             ExtractionStrategy.CHAT_FIRST: "Chat template with first answer token",
-            ExtractionStrategy.CHAT_LAST: "Chat template with last token",
-            ExtractionStrategy.CHAT_GEN_POINT: "Chat template with generation decision point",
+            ExtractionStrategy.CHAT_LAST: "Chat template with EOT token",
             ExtractionStrategy.CHAT_MAX_NORM: "Chat template with max-norm answer token",
             ExtractionStrategy.CHAT_WEIGHTED: "Chat template with position-weighted mean",
-            ExtractionStrategy.ROLE_PLAY: "Role-playing format with last token",
-            ExtractionStrategy.MC_BALANCED: "Balanced multiple choice with last token",
+            ExtractionStrategy.ROLE_PLAY: "Role-playing format with EOT token",
+            ExtractionStrategy.MC_BALANCED: "Balanced multiple choice with A/B token",
+            ExtractionStrategy.COMPLETION_LAST: "Direct completion with last token (base models)",
+            ExtractionStrategy.COMPLETION_MEAN: "Direct completion with mean of answer tokens (base models)",
+            ExtractionStrategy.MC_COMPLETION: "Multiple choice completion with A/B token (base models)",
         }
         return descriptions.get(self, "Unknown strategy")

@@ -75,6 +88,81 @@ class ExtractionStrategy(str, Enum):
     def list_all(cls) -> list[str]:
         """List all strategy names."""
         return [s.value for s in cls]
+
+    @classmethod
+    def for_tokenizer(cls, tokenizer, prefer_mc: bool = False) -> "ExtractionStrategy":
+        """
+        Select the appropriate strategy based on whether tokenizer supports chat template.
+
+        Args:
+            tokenizer: The tokenizer to check
+            prefer_mc: If True, prefer multiple choice strategies (mc_balanced/mc_completion)
+
+        Returns:
+            Appropriate strategy for the tokenizer type
+        """
+        has_chat = (hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template"))
+                    and hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None)
+
+        if has_chat:
+            return cls.MC_BALANCED if prefer_mc else cls.CHAT_LAST
+        else:
+            return cls.MC_COMPLETION if prefer_mc else cls.COMPLETION_LAST
+
+    @classmethod
+    def is_base_model_strategy(cls, strategy: "ExtractionStrategy") -> bool:
+        """Check if a strategy is designed for base models (no chat template)."""
+        return strategy in (cls.COMPLETION_LAST, cls.COMPLETION_MEAN, cls.MC_COMPLETION)
+
+    @classmethod
+    def get_equivalent_for_model_type(cls, strategy: "ExtractionStrategy", tokenizer) -> "ExtractionStrategy":
+        """
+        Get the equivalent strategy for the given tokenizer type.
+
+        If strategy requires chat template but tokenizer doesn't have it,
+        returns the base model equivalent. And vice versa.
+
+        Args:
+            strategy: The requested strategy
+            tokenizer: The tokenizer to check
+
+        Returns:
+            The appropriate strategy for the tokenizer
+        """
+        has_chat = (hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template"))
+                    and hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None)
+        is_base_strategy = cls.is_base_model_strategy(strategy)
+
+        if has_chat and is_base_strategy:
+            # Tokenizer has chat but strategy is for base model - upgrade to chat version
+            mapping = {
+                cls.COMPLETION_LAST: cls.CHAT_LAST,
+                cls.COMPLETION_MEAN: cls.CHAT_MEAN,
+                cls.MC_COMPLETION: cls.MC_BALANCED,
+            }
+            return mapping.get(strategy, strategy)
+
+        elif not has_chat and not is_base_strategy:
+            # Tokenizer is base model but strategy requires chat - downgrade to base version
+            mapping = {
+                cls.CHAT_LAST: cls.COMPLETION_LAST,
+                cls.CHAT_FIRST: cls.COMPLETION_LAST,
+                cls.CHAT_MEAN: cls.COMPLETION_MEAN,
+                cls.CHAT_MAX_NORM: cls.COMPLETION_LAST,
+                cls.CHAT_WEIGHTED: cls.COMPLETION_MEAN,
+                cls.ROLE_PLAY: cls.COMPLETION_LAST,
+                cls.MC_BALANCED: cls.MC_COMPLETION,
+            }
+            return mapping.get(strategy, cls.COMPLETION_LAST)
+
+        return strategy
+
+
+def tokenizer_has_chat_template(tokenizer) -> bool:
+    """Check if tokenizer supports chat template."""
+    has_method = hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template"))
+    has_template = hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None
+    return has_method and has_template


 # Random tokens for role_play strategy (deterministic based on prompt hash)
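Note: the selection logic added above keys entirely off the presence of a callable apply_chat_template plus a non-None chat_template attribute. A minimal sketch of how the new helpers behave; the two stand-in tokenizer classes are illustrative only, not part of the package:

    from wisent.core.activations.extraction_strategy import ExtractionStrategy

    class ChatTok:
        # Stand-in for an instruct-model tokenizer (has a chat template).
        chat_template = "<stub template, just needs to be non-None>"
        def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False):
            return "".join(m["content"] for m in messages)

    class BaseTok:
        # Stand-in for a base-model tokenizer (no chat template).
        chat_template = None

    assert ExtractionStrategy.for_tokenizer(ChatTok()) is ExtractionStrategy.CHAT_LAST
    assert ExtractionStrategy.for_tokenizer(BaseTok(), prefer_mc=True) is ExtractionStrategy.MC_COMPLETION

    # Requesting a chat strategy against a base-model tokenizer downgrades it:
    converted = ExtractionStrategy.get_equivalent_for_model_type(
        ExtractionStrategy.CHAT_MEAN, BaseTok())
    assert converted is ExtractionStrategy.COMPLETION_MEAN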
@@ -88,6 +176,7 @@ def build_extraction_texts(
     tokenizer,
     other_response: Optional[str] = None,
     is_positive: bool = True,
+    auto_convert_strategy: bool = True,
 ) -> Tuple[str, str, Optional[str]]:
     """
     Build the full text for activation extraction based on strategy.
@@ -97,8 +186,9 @@ def build_extraction_texts(
         prompt: The user prompt/question
         response: The response to extract activations for
         tokenizer: The tokenizer (needs apply_chat_template for chat strategies)
-        other_response: For mc_balanced, the other response option
-        is_positive: For mc_balanced, whether 'response' is the positive option
+        other_response: For mc_balanced/mc_completion, the other response option
+        is_positive: For mc_balanced/mc_completion, whether 'response' is the positive option
+        auto_convert_strategy: If True, automatically convert strategy to match tokenizer type

     Returns:
         Tuple of (full_text, answer_text, prompt_only_text)
@@ -106,31 +196,40 @@ def build_extraction_texts(
         - answer_text: The answer portion (for strategies that need it)
         - prompt_only_text: Prompt without answer (for boundary detection)
     """
+    # Auto-convert strategy if needed
+    if auto_convert_strategy:
+        original_strategy = strategy
+        strategy = ExtractionStrategy.get_equivalent_for_model_type(strategy, tokenizer)
+        if strategy != original_strategy:
+            import warnings
+            warnings.warn(
+                f"Strategy {original_strategy.value} not compatible with tokenizer, "
+                f"using {strategy.value} instead.",
+                UserWarning
+            )

     if strategy in (ExtractionStrategy.CHAT_MEAN, ExtractionStrategy.CHAT_FIRST,
-                    ExtractionStrategy.CHAT_LAST, ExtractionStrategy.CHAT_GEN_POINT,
-                    ExtractionStrategy.CHAT_MAX_NORM, ExtractionStrategy.CHAT_WEIGHTED):
+                    ExtractionStrategy.CHAT_LAST, ExtractionStrategy.CHAT_MAX_NORM,
+                    ExtractionStrategy.CHAT_WEIGHTED):
         # All chat_* strategies use the same prompt construction
-        if hasattr(tokenizer, "apply_chat_template"):
-            …
-            prompt_only = prompt
-            full_text = f"{prompt} {response}"
+        if not hasattr(tokenizer, "apply_chat_template"):
+            raise ValueError(
+                f"Strategy {strategy.value} requires a tokenizer with apply_chat_template. "
+                f"Base models without chat templates are not supported for chat_* strategies. "
+                f"Use completion_last, completion_mean, or mc_completion instead."
+            )
+
+        prompt_only = tokenizer.apply_chat_template(
+            [{"role": "user", "content": prompt}],
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+        full_text = tokenizer.apply_chat_template(
+            [{"role": "user", "content": prompt},
+             {"role": "assistant", "content": response}],
+            tokenize=False,
+            add_generation_prompt=False,
+        )

         return full_text, response, prompt_only

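Note: the prompt_only/full_text pair above uses the standard transformers chat-template API. prompt_only ends at the generation point (the assistant header), while full_text additionally contains the answer and the end-of-turn token; their difference is what delimits the answer-token region downstream. A minimal sketch; any instruct model works, the name below is only an example:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    prompt, response = "What is 2+2?", "4"

    # Ends with the assistant header: the point where generation would start.
    prompt_only = tok.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False, add_generation_prompt=True,
    )
    # Additionally contains the answer text and the end-of-turn token.
    full_text = tok.apply_chat_template(
        [{"role": "user", "content": prompt},
         {"role": "assistant", "content": response}],
        tokenize=False, add_generation_prompt=False,
    )
    print(len(full_text) - len(prompt_only))  # size of the answer region, in characters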
@@ -139,25 +238,23 @@ def build_extraction_texts(
         random_token = ROLE_PLAY_TOKENS[hash(prompt) % len(ROLE_PLAY_TOKENS)]
         instruction = f"Behave like a person that would answer {prompt} with {response}"

-        if hasattr(tokenizer, "apply_chat_template"):
-            …
-            prompt_only = instruction
-            full_text = f"{instruction} {random_token}"
+        if not hasattr(tokenizer, "apply_chat_template"):
+            raise ValueError(
+                f"Strategy {strategy.value} requires a tokenizer with apply_chat_template. "
+                f"Use completion_last or mc_completion for base models."
+            )
+
+        prompt_only = tokenizer.apply_chat_template(
+            [{"role": "user", "content": instruction}],
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+        full_text = tokenizer.apply_chat_template(
+            [{"role": "user", "content": instruction},
+             {"role": "assistant", "content": random_token}],
+            tokenize=False,
+            add_generation_prompt=False,
+        )

         return full_text, random_token, prompt_only

@@ -188,28 +285,66 @@ def build_extraction_texts(
             option_b = response[:200]  # negative
             answer = "B"

-        mc_prompt = f"…"
+        mc_prompt = f"Question: {prompt}\n\nWhich is correct?\nA. {option_a}\nB. {option_b}\nAnswer:"
+
+        if not hasattr(tokenizer, "apply_chat_template"):
+            raise ValueError(
+                f"Strategy {strategy.value} requires a tokenizer with apply_chat_template. "
+                f"Use mc_completion for base models."
+            )
+
+        prompt_only = tokenizer.apply_chat_template(
+            [{"role": "user", "content": mc_prompt}],
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+        full_text = tokenizer.apply_chat_template(
+            [{"role": "user", "content": mc_prompt},
+             {"role": "assistant", "content": answer}],
+            tokenize=False,
+            add_generation_prompt=False,
+        )
+
+        return full_text, answer, prompt_only
+
+    elif strategy in (ExtractionStrategy.COMPLETION_LAST, ExtractionStrategy.COMPLETION_MEAN):
+        # Base model strategies - direct Q+A without chat template
+        # Format: "Q: {prompt}\nA: {response}"
+        prompt_only = f"Q: {prompt}\nA:"
+        full_text = f"Q: {prompt}\nA: {response}"
+        return full_text, response, prompt_only
+
+    elif strategy == ExtractionStrategy.MC_COMPLETION:
+        # Multiple choice for base models - no chat template
+        if other_response is None:
+            raise ValueError("MC_COMPLETION strategy requires other_response")
+
+        # Deterministic "random" based on prompt - same for both pos and neg of a pair
+        pos_goes_in_b = hash(prompt) % 2 == 0

-        if …
-            …
-            {"role": "assistant", "content": answer}],
-            tokenize=False,
-            add_generation_prompt=False,
-            )
-        except (ValueError, KeyError):
-            prompt_only = mc_prompt
-            full_text = f"{mc_prompt} {answer}"
+        if is_positive:
+            if pos_goes_in_b:
+                option_a = other_response[:200]
+                option_b = response[:200]
+                answer = "B"
+            else:
+                option_a = response[:200]
+                option_b = other_response[:200]
+                answer = "A"
         else:
-            …
+            if pos_goes_in_b:
+                option_a = response[:200]
+                option_b = other_response[:200]
+                answer = "A"
+            else:
+                option_a = other_response[:200]
+                option_b = response[:200]
+                answer = "B"

+        mc_prompt = f"Question: {prompt}\n\nWhich is correct?\nA. {option_a}\nB. {option_b}\nAnswer:"
+
+        prompt_only = mc_prompt
+        full_text = f"{mc_prompt} {answer}"
         return full_text, answer, prompt_only

     else:
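Note on the deterministic balancing above: hash(prompt) % 2 guarantees that the positive and negative member of one pair see the same A/B layout, so only the labeled answer differs. Python salts str hashes per interpreter run unless PYTHONHASHSEED is fixed, so the layout is stable within one process, not across runs. A sketch of the invariant:

    prompt = "Is the sky blue?"
    pos, neg = "Yes, it is blue.", "No, it is green."

    pos_goes_in_b = hash(prompt) % 2 == 0  # identical for both extractions in one process

    # Positive extraction: response=pos, other_response=neg.
    if pos_goes_in_b:
        layout_pos, ans_pos = (neg[:200], pos[:200]), "B"
    else:
        layout_pos, ans_pos = (pos[:200], neg[:200]), "A"

    # Negative extraction: response=neg, other_response=pos.
    if pos_goes_in_b:
        layout_neg, ans_neg = (neg[:200], pos[:200]), "A"
    else:
        layout_neg, ans_neg = (pos[:200], neg[:200]), "B"

    assert layout_pos == layout_neg  # same option order for both members
    assert ans_pos != ans_neg        # only the labeled answer flips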
@@ -243,6 +378,7 @@ def extract_activation(
     num_answer_tokens = len(answer_tokens)

     if strategy == ExtractionStrategy.CHAT_LAST:
+        # EOT token - has seen the entire answer, best performance
         return hidden_states[-1]

     elif strategy == ExtractionStrategy.CHAT_FIRST:
@@ -257,11 +393,6 @@ def extract_activation(
             return answer_hidden.mean(dim=0)
         return hidden_states[-1]

-    elif strategy == ExtractionStrategy.CHAT_GEN_POINT:
-        # Last token before answer starts (decision point)
-        gen_point_idx = max(0, seq_len - num_answer_tokens - 2)
-        return hidden_states[gen_point_idx]
-
     elif strategy == ExtractionStrategy.CHAT_MAX_NORM:
         # Token with max norm in answer region
         if num_answer_tokens > 0 and seq_len > num_answer_tokens:
@@ -275,18 +406,36 @@ def extract_activation(
         # Position-weighted mean (earlier tokens weighted more)
         if num_answer_tokens > 0 and seq_len > num_answer_tokens:
             answer_hidden = hidden_states[-num_answer_tokens-1:-1]
-            weights = torch.exp(-torch.arange(answer_hidden.shape[0], dtype=…
+            weights = torch.exp(-torch.arange(answer_hidden.shape[0], dtype=answer_hidden.dtype, device=answer_hidden.device) * 0.5)
             weights = weights / weights.sum()
             return (answer_hidden * weights.unsqueeze(1)).sum(dim=0)
         return hidden_states[-1]

-    elif strategy …
-        # …
+    elif strategy == ExtractionStrategy.ROLE_PLAY:
+        # EOT token - slightly better than answer word (65% vs 64%)
         return hidden_states[-1]

-    …
-        # …
+    elif strategy == ExtractionStrategy.MC_BALANCED:
+        # Answer token (A/B) - better than EOT (64% vs 56%)
+        return hidden_states[-2]
+
+    elif strategy == ExtractionStrategy.COMPLETION_LAST:
+        # Last token for base model completion
         return hidden_states[-1]
+
+    elif strategy == ExtractionStrategy.COMPLETION_MEAN:
+        # Mean of answer tokens for base model completion
+        if num_answer_tokens > 0 and seq_len > num_answer_tokens:
+            answer_hidden = hidden_states[-num_answer_tokens:]
+            return answer_hidden.mean(dim=0)
+        return hidden_states[-1]
+
+    elif strategy == ExtractionStrategy.MC_COMPLETION:
+        # A/B token for base model MC (last token is the answer)
+        return hidden_states[-1]
+
+    else:
+        raise ValueError(f"Unknown extraction strategy: {strategy}")


 def add_extraction_strategy_args(parser: argparse.ArgumentParser) -> None:
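Note: the chat_weighted fix above builds the decay weights with the activations' own dtype and device. A CPU-created float32 arange against CUDA half-precision hidden states would otherwise hit a device-mismatch error, and a dtype mismatch silently upcasts the result. A self-contained sketch of the computation with synthetic hidden states:

    import torch

    hidden_states = torch.randn(10, 4, dtype=torch.float16)  # 10 tokens, hidden size 4
    num_answer_tokens = 3

    # Answer region excludes the final EOT position, mirroring the slice above.
    answer_hidden = hidden_states[-num_answer_tokens - 1:-1]

    # Exponentially decaying weights: earlier answer tokens weighted more.
    weights = torch.exp(-torch.arange(answer_hidden.shape[0],
                                      dtype=answer_hidden.dtype,
                                      device=answer_hidden.device) * 0.5)
    weights = weights / weights.sum()

    vec = (answer_hidden * weights.unsqueeze(1)).sum(dim=0)
    print(vec.shape, vec.dtype)  # torch.Size([4]) torch.float16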
@@ -306,3 +455,30 @@ def add_extraction_strategy_args(parser: argparse.ArgumentParser) -> None:
         choices=ExtractionStrategy.list_all(),
         help=f"Extraction strategy for activations. Options: {', '.join(ExtractionStrategy.list_all())}. Default: {ExtractionStrategy.default().value}",
     )
+
+
+def get_strategy_for_model(tokenizer, prefer_mc: bool = False) -> ExtractionStrategy:
+    """
+    Get the best extraction strategy for a given tokenizer.
+
+    Automatically detects if tokenizer has chat template and returns
+    the appropriate strategy.
+
+    Args:
+        tokenizer: The tokenizer to check
+        prefer_mc: If True, prefer multiple choice strategies
+
+    Returns:
+        ExtractionStrategy appropriate for the tokenizer
+
+    Example:
+        >>> from transformers import AutoTokenizer
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
+        >>> strategy = get_strategy_for_model(tokenizer)
+        >>> print(strategy)  # completion_last (base model)
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
+        >>> strategy = get_strategy_for_model(tokenizer)
+        >>> print(strategy)  # chat_last (instruct model)
+    """
+    return ExtractionStrategy.for_tokenizer(tokenizer, prefer_mc=prefer_mc)
wisent/core/classifiers/classifiers/core/atoms.py
CHANGED
@@ -14,6 +14,7 @@ import numpy as np

 from torch.nn.modules.loss import _Loss
 from wisent.core.errors import DuplicateNameError, InvalidRangeError, UnknownTypeError
+from wisent.core.utils.device import preferred_dtype

 __all__ = [
     "ClassifierTrainConfig",
@@ -164,13 +165,13 @@ class BaseClassifier(ABC):
         self,
         threshold: float = 0.5,
         device: str | None = None,
-        dtype: torch.dtype = …,
+        dtype: torch.dtype | None = None,
     ) -> None:
         if not 0.0 <= threshold <= 1.0:
             raise InvalidRangeError(param_name="threshold", actual=threshold, min_val=0.0, max_val=1.0)
         self.threshold = threshold
         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        self.dtype = …
+        self.dtype = dtype if dtype is not None else preferred_dtype(self.device)
         self.model = None

     @abstractmethod
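Note: preferred_dtype lives in wisent/core/utils/device.py (rewritten in this release, +27 -27) and its body is not shown in this diff. A plausible sketch of such a helper, offered purely as an assumption about the policy the BaseClassifier change relies on:

    import torch

    def preferred_dtype(device: str | None = None) -> torch.dtype:
        # Hypothetical sketch - the real wisent.core.utils.device.preferred_dtype
        # may choose differently.
        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        if device.startswith("cuda"):
            # Prefer bfloat16 where the GPU supports it, float16 otherwise.
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        if device.startswith("mps"):
            return torch.float16
        return torch.float32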
wisent/core/cli/__init__.py
CHANGED
@@ -22,5 +22,6 @@ from .inference_config_cli import execute_inference_config
 from .optimization_cache import execute_optimization_cache
 from .optimize_weights import execute_optimize_weights
 from .optimize import execute_optimize
+from .geometry_search import execute_geometry_search

-__all__ = ['execute_tasks', 'execute_generate_pairs_from_task', 'execute_generate_pairs', 'execute_diagnose_pairs', 'execute_get_activations', 'execute_diagnose_vectors', 'execute_create_steering_vector', 'execute_generate_vector_from_task', 'execute_generate_vector_from_synthetic', 'execute_optimize_classification', 'execute_optimize_steering', 'execute_optimize_sample_size', 'execute_generate_responses', 'execute_evaluate_responses', 'execute_multi_steer', 'execute_agent', 'execute_modify_weights', 'execute_evaluate_refusal', 'execute_inference_config', 'execute_optimization_cache', 'execute_optimize_weights', 'execute_optimize']
+__all__ = ['execute_tasks', 'execute_generate_pairs_from_task', 'execute_generate_pairs', 'execute_diagnose_pairs', 'execute_get_activations', 'execute_diagnose_vectors', 'execute_create_steering_vector', 'execute_generate_vector_from_task', 'execute_generate_vector_from_synthetic', 'execute_optimize_classification', 'execute_optimize_steering', 'execute_optimize_sample_size', 'execute_generate_responses', 'execute_evaluate_responses', 'execute_multi_steer', 'execute_agent', 'execute_modify_weights', 'execute_evaluate_refusal', 'execute_inference_config', 'execute_optimization_cache', 'execute_optimize_weights', 'execute_optimize', 'execute_geometry_search']
wisent/core/cli/agent/train_classifier.py
CHANGED
@@ -1,8 +1,20 @@
 """Train classifier on contrastive pairs for agent."""

 import numpy as np
+import torch
 from wisent.core.classifiers.classifiers.core.atoms import ClassifierTrainReport
 from wisent.core.errors import UnknownTypeError
+from wisent.core.utils.device import preferred_dtype
+
+
+def _torch_dtype_to_numpy(torch_dtype: torch.dtype):
+    """Convert torch dtype to numpy dtype."""
+    mapping = {
+        torch.float32: np.float32,
+        torch.float16: np.float16,
+        torch.bfloat16: np.float32,  # numpy doesn't support bfloat16, use float32
+    }
+    return mapping.get(torch_dtype, np.float32)


 def _map_token_aggregation(aggregation_str: str):
@@ -97,7 +109,7 @@ def train_classifier_on_pairs(
     prompt_construction_strategy = _map_prompt_strategy(prompt_strategy)

     # Collect activations for all pairs
-    collector = ActivationCollector(model=model…)
+    collector = ActivationCollector(model=model)
     target_layers = [str(target_layer)]
     layer_key = target_layers[0]

@@ -133,8 +145,9 @@ def train_classifier_on_pairs(
         X_list.append(neg_act.cpu().numpy())
         y_list.append(0.0)

-    …
-    …
+    np_dtype = _torch_dtype_to_numpy(preferred_dtype())
+    X_train = np.array(X_list, dtype=np_dtype)
+    y_train = np.array(y_list, dtype=np_dtype)

     print(f"  Training data: {X_train.shape[0]} samples, {X_train.shape[1]} features")
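Note: the bfloat16 -> float32 row in _torch_dtype_to_numpy exists because NumPy has no bfloat16: calling .numpy() on a bfloat16 tensor raises a TypeError, so bfloat16 activations must be upcast before they can become training arrays. A small sketch of the round-trip:

    import numpy as np
    import torch

    act = torch.randn(8, dtype=torch.bfloat16)

    # act.numpy() would raise: NumPy does not support bfloat16.
    x = act.float().numpy()                     # upcast in torch first
    X_train = np.array([x], dtype=np.float32)   # matches _torch_dtype_to_numpy(torch.bfloat16)

    print(X_train.dtype, X_train.shape)  # float32 (1, 8)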
wisent/core/cli/check_linearity.py
CHANGED
@@ -31,6 +31,7 @@ def execute_check_linearity(args):
     from wisent.core.models.wisent_model import WisentModel
     from wisent.core.contrastive_pairs.core.pair import ContrastivePair
     from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse
+    from wisent.core.activations.extraction_strategy import ExtractionStrategy

     # Build ContrastivePair objects
     pairs = []
@@ -72,6 +73,10 @@ def execute_check_linearity(args):
     if args.layers:
         config.layers_to_test = [int(l) for l in args.layers.split(',')]

+    if args.extraction_strategy:
+        config.extraction_strategies = [ExtractionStrategy(args.extraction_strategy)]
+        print(f"Using extraction strategy: {args.extraction_strategy}")
+
     # Run check
     print("\nRunning linearity check...")
     result = check_linearity(pairs, model, config)
@@ -110,12 +115,39 @@ def execute_check_linearity(args):

     sorted_results = sorted(result.all_results, key=lambda x: x['linear_score'], reverse=True)

-    print(f"{'Linear':<8} {'d':<8} {'Layer':<6} {'…")
-    print("-" * …)
+    print(f"{'Linear':<8} {'d':<8} {'Layer':<6} {'Strategy':<20} {'Structure':<12} {'Norm'}")
+    print("-" * 70)

     for r in sorted_results[:20]:
         print(f"{r['linear_score']:<8.3f} {r['cohens_d']:<8.2f} {r['layer']:<6} "
-              f"{r['…")
+              f"{r['extraction_strategy']:<20} {r['best_structure']:<12} {r['normalize']}")
+
+    # Show best result for each structure type
+    if sorted_results and 'all_structure_scores' in sorted_results[0]:
+        print(f"\n{'='*60}")
+        print("BEST SCORE PER STRUCTURE TYPE")
+        print(f"{'='*60}")
+
+        # Collect best score for each structure across all configs
+        best_per_structure = {}
+        for r in result.all_results:
+            if 'all_structure_scores' not in r:
+                continue
+            for struct_name, data in r['all_structure_scores'].items():
+                score = data['score']
+                if struct_name not in best_per_structure or score > best_per_structure[struct_name]['score']:
+                    best_per_structure[struct_name] = {
+                        'score': score,
+                        'confidence': data['confidence'],
+                        'layer': r['layer'],
+                        'strategy': r['extraction_strategy'],
+                    }
+
+        print(f"{'Structure':<12} {'Score':<8} {'Conf':<8} {'Layer':<6} {'Strategy'}")
+        print("-" * 55)
+        sorted_structs = sorted(best_per_structure.items(), key=lambda x: x[1]['score'], reverse=True)
+        for name, data in sorted_structs:
+            print(f"{name:<12} {data['score']:<8.3f} {data['confidence']:<8.3f} {data['layer']:<6} {data['strategy']}")

     # Exit code based on verdict
     if result.verdict.value == "linear":
wisent/core/cli/cluster_benchmarks.py
CHANGED
@@ -33,7 +33,6 @@ STRATEGIES = [
     "chat_mean",
     "chat_first",
     "chat_last",
-    "chat_gen_point",
     "chat_max_norm",
     "chat_weighted",
     "role_play",
@@ -134,9 +133,9 @@ def get_weighted_mean_answer_act(model, tokenizer, text: str, answer: str, layer
     hidden = outputs.hidden_states[layer][0]
     if num_answer_tokens > 0 and num_answer_tokens < hidden.shape[0]:
         answer_hidden = hidden[-num_answer_tokens-1:-1, :]
-        weights = torch.exp(-torch.arange(answer_hidden.shape[0], dtype=…
+        weights = torch.exp(-torch.arange(answer_hidden.shape[0], dtype=answer_hidden.dtype, device=answer_hidden.device) * 0.5)
         weights = weights / weights.sum()
-        weighted_mean = (answer_hidden * weights.unsqueeze(1)…
+        weighted_mean = (answer_hidden * weights.unsqueeze(1)).sum(dim=0)
         return weighted_mean.cpu().float()
     return hidden[-1].cpu().float()

@@ -156,8 +155,6 @@ def get_activation(model, tokenizer, prompt: str, response: str, layer: int, dev
         return get_first_answer_token_act(model, tokenizer, text, response, layer, device)
     elif strategy == "chat_last":
         return get_last_token_act(model, tokenizer, text, layer, device)
-    elif strategy == "chat_gen_point":
-        return get_generation_point_act(model, tokenizer, text, response, layer, device)
     elif strategy == "chat_max_norm":
         return get_max_norm_answer_act(model, tokenizer, text, response, layer, device)
     elif strategy == "chat_weighted":
@@ -348,7 +345,8 @@ def execute_cluster_benchmarks(args):

     logger.info(f"Loading {model}...")
     tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
-    …
+    from wisent.core.utils.device import device_optimized_dtype
+    dtype = device_optimized_dtype(device)
     llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=dtype, device_map=device, trust_remote_code=True)

     layers = get_layers_to_test(llm)