wisent 0.7.701__py3-none-any.whl → 0.7.901__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/core/activations/activation_cache.py +393 -0
- wisent/core/activations/activations.py +3 -3
- wisent/core/activations/activations_collector.py +9 -5
- wisent/core/activations/classifier_inference_strategy.py +12 -11
- wisent/core/activations/extraction_strategy.py +256 -84
- wisent/core/classifiers/classifiers/core/atoms.py +3 -2
- wisent/core/cli/__init__.py +2 -1
- wisent/core/cli/agent/apply_steering.py +5 -7
- wisent/core/cli/agent/train_classifier.py +19 -7
- wisent/core/cli/check_linearity.py +35 -3
- wisent/core/cli/cluster_benchmarks.py +4 -6
- wisent/core/cli/create_steering_vector.py +6 -4
- wisent/core/cli/diagnose_vectors.py +7 -4
- wisent/core/cli/estimate_unified_goodness_time.py +6 -4
- wisent/core/cli/generate_pairs_from_task.py +9 -56
- wisent/core/cli/geometry_search.py +137 -0
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/cli/method_optimizer.py +4 -3
- wisent/core/cli/modify_weights.py +3 -2
- wisent/core/cli/optimize_sample_size.py +1 -1
- wisent/core/cli/optimize_steering.py +14 -16
- wisent/core/cli/optimize_weights.py +2 -1
- wisent/core/cli/preview_pairs.py +203 -0
- wisent/core/cli/steering_method_trainer.py +3 -3
- wisent/core/cli/tasks.py +19 -76
- wisent/core/cli/train_unified_goodness.py +3 -3
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
- wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
- wisent/core/data_loaders/loaders/lm_loader.py +12 -1
- wisent/core/geometry_runner.py +995 -0
- wisent/core/geometry_search_space.py +237 -0
- wisent/core/hyperparameter_optimizer.py +1 -1
- wisent/core/main.py +3 -0
- wisent/core/models/core/atoms.py +5 -3
- wisent/core/models/wisent_model.py +1 -1
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/parser_arguments/check_linearity_parser.py +12 -2
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +2 -2
- wisent/core/parser_arguments/geometry_search_parser.py +61 -0
- wisent/core/parser_arguments/main_parser.py +8 -0
- wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
- wisent/core/steering.py +5 -3
- wisent/core/steering_methods/methods/hyperplane.py +2 -1
- wisent/core/synthetic/generators/nonsense_generator.py +30 -18
- wisent/core/trainers/steering_trainer.py +2 -2
- wisent/core/utils/device.py +27 -27
- wisent/core/utils/layer_combinations.py +70 -0
- wisent/examples/__init__.py +1 -0
- wisent/examples/scripts/__init__.py +1 -0
- wisent/examples/scripts/count_all_benchmarks.py +121 -0
- wisent/examples/scripts/discover_directions.py +469 -0
- wisent/examples/scripts/extract_benchmark_info.py +71 -0
- wisent/examples/scripts/generate_paper_data.py +384 -0
- wisent/examples/scripts/intervention_validation.py +626 -0
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +324 -0
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +92 -0
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +324 -0
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +92 -0
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +92 -0
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +324 -0
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +92 -0
- wisent/examples/scripts/search_all_short_names.py +31 -0
- wisent/examples/scripts/test_all_benchmarks.py +138 -0
- wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
- wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
- wisent/examples/scripts/test_nonsense_baseline.py +261 -0
- wisent/examples/scripts/test_one_benchmark.py +324 -0
- wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
- wisent/examples/scripts/threshold_analysis.py +434 -0
- wisent/examples/scripts/visualization_gallery.py +582 -0
- wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
- wisent/parameters/lm_eval/category_directions.json +137 -0
- wisent/parameters/lm_eval/repair_plan.json +282 -0
- wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
- wisent/parameters/lm_eval/working_benchmarks.json +206 -0
- wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
- wisent/tests/test_detector_accuracy.py +1 -1
- wisent/tests/visualize_geometry.py +1 -1
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/METADATA +1 -1
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/RECORD +329 -295
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/WHEEL +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/top_level.txt +0 -0
|
@@ -4,15 +4,19 @@ Unified extraction strategies for activation collection.
|
|
|
4
4
|
These strategies combine prompt construction and token extraction into a single
|
|
5
5
|
unified approach, based on empirical testing of what actually works.
|
|
6
6
|
|
|
7
|
-
|
|
7
|
+
CHAT STRATEGIES (require chat template - for instruct models):
|
|
8
8
|
- chat_mean: Chat template prompt, mean of answer tokens
|
|
9
9
|
- chat_first: Chat template prompt, first answer token
|
|
10
10
|
- chat_last: Chat template prompt, last token
|
|
11
|
-
- chat_gen_point: Chat template prompt, token before answer (generation decision point)
|
|
12
11
|
- chat_max_norm: Chat template prompt, token with max norm in answer
|
|
13
12
|
- chat_weighted: Chat template prompt, position-weighted mean (earlier tokens weighted more)
|
|
14
13
|
- role_play: "Behave like person who answers Q with A" format, last token
|
|
15
14
|
- mc_balanced: Multiple choice with balanced A/B assignment, last token
|
|
15
|
+
|
|
16
|
+
BASE MODEL STRATEGIES (no chat template - for base models like gemma-2b, gemma-9b):
|
|
17
|
+
- completion_last: Direct Q+A completion, last token
|
|
18
|
+
- completion_mean: Direct Q+A completion, mean of answer tokens
|
|
19
|
+
- mc_completion: Multiple choice without chat template, A/B token
|
|
16
20
|
"""
|
|
17
21
|
|
|
18
22
|
from enum import Enum
|
|
@@ -35,10 +39,7 @@ class ExtractionStrategy(str, Enum):
|
|
|
35
39
|
"""Chat template prompt with Q+A, extract first answer token."""
|
|
36
40
|
|
|
37
41
|
CHAT_LAST = "chat_last"
|
|
38
|
-
"""Chat template prompt with Q+A, extract
|
|
39
|
-
|
|
40
|
-
CHAT_GEN_POINT = "chat_gen_point"
|
|
41
|
-
"""Chat template prompt with Q+A, extract token before answer starts (decision point)."""
|
|
42
|
+
"""Chat template prompt with Q+A, extract EOT token (has seen full answer)."""
|
|
42
43
|
|
|
43
44
|
CHAT_MAX_NORM = "chat_max_norm"
|
|
44
45
|
"""Chat template prompt with Q+A, extract token with max norm in answer region."""
|
|
@@ -47,22 +48,34 @@ class ExtractionStrategy(str, Enum):
|
|
|
47
48
|
"""Chat template prompt with Q+A, position-weighted mean (earlier tokens weighted more)."""
|
|
48
49
|
|
|
49
50
|
ROLE_PLAY = "role_play"
|
|
50
|
-
"""'Behave like person who answers Q with A' format, extract
|
|
51
|
+
"""'Behave like person who answers Q with A' format, extract EOT token."""
|
|
51
52
|
|
|
52
53
|
MC_BALANCED = "mc_balanced"
|
|
53
|
-
"""Multiple choice format with balanced A/B assignment, extract
|
|
54
|
+
"""Multiple choice format with balanced A/B assignment, extract the A/B choice token."""
|
|
55
|
+
|
|
56
|
+
# Base model strategies (no chat template required)
|
|
57
|
+
COMPLETION_LAST = "completion_last"
|
|
58
|
+
"""Direct Q+A completion without chat template, extract last token. For base models."""
|
|
59
|
+
|
|
60
|
+
COMPLETION_MEAN = "completion_mean"
|
|
61
|
+
"""Direct Q+A completion without chat template, extract mean of answer tokens. For base models."""
|
|
62
|
+
|
|
63
|
+
MC_COMPLETION = "mc_completion"
|
|
64
|
+
"""Multiple choice without chat template, extract A/B token. For base models."""
|
|
54
65
|
|
|
55
66
|
@property
|
|
56
67
|
def description(self) -> str:
|
|
57
68
|
descriptions = {
|
|
58
69
|
ExtractionStrategy.CHAT_MEAN: "Chat template with mean of answer tokens",
|
|
59
70
|
ExtractionStrategy.CHAT_FIRST: "Chat template with first answer token",
|
|
60
|
-
ExtractionStrategy.CHAT_LAST: "Chat template with
|
|
61
|
-
ExtractionStrategy.CHAT_GEN_POINT: "Chat template with generation decision point",
|
|
71
|
+
ExtractionStrategy.CHAT_LAST: "Chat template with EOT token",
|
|
62
72
|
ExtractionStrategy.CHAT_MAX_NORM: "Chat template with max-norm answer token",
|
|
63
73
|
ExtractionStrategy.CHAT_WEIGHTED: "Chat template with position-weighted mean",
|
|
64
|
-
ExtractionStrategy.ROLE_PLAY: "Role-playing format with
|
|
65
|
-
ExtractionStrategy.MC_BALANCED: "Balanced multiple choice with
|
|
74
|
+
ExtractionStrategy.ROLE_PLAY: "Role-playing format with EOT token",
|
|
75
|
+
ExtractionStrategy.MC_BALANCED: "Balanced multiple choice with A/B token",
|
|
76
|
+
ExtractionStrategy.COMPLETION_LAST: "Direct completion with last token (base models)",
|
|
77
|
+
ExtractionStrategy.COMPLETION_MEAN: "Direct completion with mean of answer tokens (base models)",
|
|
78
|
+
ExtractionStrategy.MC_COMPLETION: "Multiple choice completion with A/B token (base models)",
|
|
66
79
|
}
|
|
67
80
|
return descriptions.get(self, "Unknown strategy")
|
|
68
81
|
|
|
@@ -75,6 +88,77 @@ class ExtractionStrategy(str, Enum):
|
|
|
75
88
|
def list_all(cls) -> list[str]:
|
|
76
89
|
"""List all strategy names."""
|
|
77
90
|
return [s.value for s in cls]
|
|
91
|
+
|
|
92
|
+
@classmethod
|
|
93
|
+
def for_tokenizer(cls, tokenizer, prefer_mc: bool = False) -> "ExtractionStrategy":
|
|
94
|
+
"""
|
|
95
|
+
Select the appropriate strategy based on whether tokenizer supports chat template.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
tokenizer: The tokenizer to check
|
|
99
|
+
prefer_mc: If True, prefer multiple choice strategies (mc_balanced/mc_completion)
|
|
100
|
+
|
|
101
|
+
Returns:
|
|
102
|
+
Appropriate strategy for the tokenizer type
|
|
103
|
+
"""
|
|
104
|
+
has_chat = hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template"))
|
|
105
|
+
|
|
106
|
+
if has_chat:
|
|
107
|
+
return cls.MC_BALANCED if prefer_mc else cls.CHAT_LAST
|
|
108
|
+
else:
|
|
109
|
+
return cls.MC_COMPLETION if prefer_mc else cls.COMPLETION_LAST
|
|
110
|
+
|
|
111
|
+
@classmethod
|
|
112
|
+
def is_base_model_strategy(cls, strategy: "ExtractionStrategy") -> bool:
|
|
113
|
+
"""Check if a strategy is designed for base models (no chat template)."""
|
|
114
|
+
return strategy in (cls.COMPLETION_LAST, cls.COMPLETION_MEAN, cls.MC_COMPLETION)
|
|
115
|
+
|
|
116
|
+
@classmethod
|
|
117
|
+
def get_equivalent_for_model_type(cls, strategy: "ExtractionStrategy", tokenizer) -> "ExtractionStrategy":
|
|
118
|
+
"""
|
|
119
|
+
Get the equivalent strategy for the given tokenizer type.
|
|
120
|
+
|
|
121
|
+
If strategy requires chat template but tokenizer doesn't have it,
|
|
122
|
+
returns the base model equivalent. And vice versa.
|
|
123
|
+
|
|
124
|
+
Args:
|
|
125
|
+
strategy: The requested strategy
|
|
126
|
+
tokenizer: The tokenizer to check
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
The appropriate strategy for the tokenizer
|
|
130
|
+
"""
|
|
131
|
+
has_chat = hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template"))
|
|
132
|
+
is_base_strategy = cls.is_base_model_strategy(strategy)
|
|
133
|
+
|
|
134
|
+
if has_chat and is_base_strategy:
|
|
135
|
+
# Tokenizer has chat but strategy is for base model - upgrade to chat version
|
|
136
|
+
mapping = {
|
|
137
|
+
cls.COMPLETION_LAST: cls.CHAT_LAST,
|
|
138
|
+
cls.COMPLETION_MEAN: cls.CHAT_MEAN,
|
|
139
|
+
cls.MC_COMPLETION: cls.MC_BALANCED,
|
|
140
|
+
}
|
|
141
|
+
return mapping.get(strategy, strategy)
|
|
142
|
+
|
|
143
|
+
elif not has_chat and not is_base_strategy:
|
|
144
|
+
# Tokenizer is base model but strategy requires chat - downgrade to base version
|
|
145
|
+
mapping = {
|
|
146
|
+
cls.CHAT_LAST: cls.COMPLETION_LAST,
|
|
147
|
+
cls.CHAT_FIRST: cls.COMPLETION_LAST,
|
|
148
|
+
cls.CHAT_MEAN: cls.COMPLETION_MEAN,
|
|
149
|
+
cls.CHAT_MAX_NORM: cls.COMPLETION_LAST,
|
|
150
|
+
cls.CHAT_WEIGHTED: cls.COMPLETION_MEAN,
|
|
151
|
+
cls.ROLE_PLAY: cls.COMPLETION_LAST,
|
|
152
|
+
cls.MC_BALANCED: cls.MC_COMPLETION,
|
|
153
|
+
}
|
|
154
|
+
return mapping.get(strategy, cls.COMPLETION_LAST)
|
|
155
|
+
|
|
156
|
+
return strategy
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def tokenizer_has_chat_template(tokenizer) -> bool:
|
|
160
|
+
"""Check if tokenizer supports chat template."""
|
|
161
|
+
return hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template"))
|
|
78
162
|
|
|
79
163
|
|
|
80
164
|
# Random tokens for role_play strategy (deterministic based on prompt hash)
|
|
@@ -88,6 +172,7 @@ def build_extraction_texts(
|
|
|
88
172
|
tokenizer,
|
|
89
173
|
other_response: Optional[str] = None,
|
|
90
174
|
is_positive: bool = True,
|
|
175
|
+
auto_convert_strategy: bool = True,
|
|
91
176
|
) -> Tuple[str, str, Optional[str]]:
|
|
92
177
|
"""
|
|
93
178
|
Build the full text for activation extraction based on strategy.
|
|
@@ -97,8 +182,9 @@ def build_extraction_texts(
|
|
|
97
182
|
prompt: The user prompt/question
|
|
98
183
|
response: The response to extract activations for
|
|
99
184
|
tokenizer: The tokenizer (needs apply_chat_template for chat strategies)
|
|
100
|
-
other_response: For mc_balanced, the other response option
|
|
101
|
-
is_positive: For mc_balanced, whether 'response' is the positive option
|
|
185
|
+
other_response: For mc_balanced/mc_completion, the other response option
|
|
186
|
+
is_positive: For mc_balanced/mc_completion, whether 'response' is the positive option
|
|
187
|
+
auto_convert_strategy: If True, automatically convert strategy to match tokenizer type
|
|
102
188
|
|
|
103
189
|
Returns:
|
|
104
190
|
Tuple of (full_text, answer_text, prompt_only_text)
|
|
@@ -106,31 +192,40 @@ def build_extraction_texts(
|
|
|
106
192
|
- answer_text: The answer portion (for strategies that need it)
|
|
107
193
|
- prompt_only_text: Prompt without answer (for boundary detection)
|
|
108
194
|
"""
|
|
195
|
+
# Auto-convert strategy if needed
|
|
196
|
+
if auto_convert_strategy:
|
|
197
|
+
original_strategy = strategy
|
|
198
|
+
strategy = ExtractionStrategy.get_equivalent_for_model_type(strategy, tokenizer)
|
|
199
|
+
if strategy != original_strategy:
|
|
200
|
+
import warnings
|
|
201
|
+
warnings.warn(
|
|
202
|
+
f"Strategy {original_strategy.value} not compatible with tokenizer, "
|
|
203
|
+
f"using {strategy.value} instead.",
|
|
204
|
+
UserWarning
|
|
205
|
+
)
|
|
109
206
|
|
|
110
207
|
if strategy in (ExtractionStrategy.CHAT_MEAN, ExtractionStrategy.CHAT_FIRST,
|
|
111
|
-
ExtractionStrategy.CHAT_LAST, ExtractionStrategy.
|
|
112
|
-
ExtractionStrategy.
|
|
208
|
+
ExtractionStrategy.CHAT_LAST, ExtractionStrategy.CHAT_MAX_NORM,
|
|
209
|
+
ExtractionStrategy.CHAT_WEIGHTED):
|
|
113
210
|
# All chat_* strategies use the same prompt construction
|
|
114
|
-
if hasattr(tokenizer, "apply_chat_template"):
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
prompt_only = prompt
|
|
133
|
-
full_text = f"{prompt} {response}"
|
|
211
|
+
if not hasattr(tokenizer, "apply_chat_template"):
|
|
212
|
+
raise ValueError(
|
|
213
|
+
f"Strategy {strategy.value} requires a tokenizer with apply_chat_template. "
|
|
214
|
+
f"Base models without chat templates are not supported for chat_* strategies. "
|
|
215
|
+
f"Use completion_last, completion_mean, or mc_completion instead."
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
prompt_only = tokenizer.apply_chat_template(
|
|
219
|
+
[{"role": "user", "content": prompt}],
|
|
220
|
+
tokenize=False,
|
|
221
|
+
add_generation_prompt=True,
|
|
222
|
+
)
|
|
223
|
+
full_text = tokenizer.apply_chat_template(
|
|
224
|
+
[{"role": "user", "content": prompt},
|
|
225
|
+
{"role": "assistant", "content": response}],
|
|
226
|
+
tokenize=False,
|
|
227
|
+
add_generation_prompt=False,
|
|
228
|
+
)
|
|
134
229
|
|
|
135
230
|
return full_text, response, prompt_only
|
|
136
231
|
|
|
@@ -139,25 +234,23 @@ def build_extraction_texts(
|
|
|
139
234
|
random_token = ROLE_PLAY_TOKENS[hash(prompt) % len(ROLE_PLAY_TOKENS)]
|
|
140
235
|
instruction = f"Behave like a person that would answer {prompt} with {response}"
|
|
141
236
|
|
|
142
|
-
if hasattr(tokenizer, "apply_chat_template"):
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
prompt_only = instruction
|
|
160
|
-
full_text = f"{instruction} {random_token}"
|
|
237
|
+
if not hasattr(tokenizer, "apply_chat_template"):
|
|
238
|
+
raise ValueError(
|
|
239
|
+
f"Strategy {strategy.value} requires a tokenizer with apply_chat_template. "
|
|
240
|
+
f"Use completion_last or mc_completion for base models."
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
prompt_only = tokenizer.apply_chat_template(
|
|
244
|
+
[{"role": "user", "content": instruction}],
|
|
245
|
+
tokenize=False,
|
|
246
|
+
add_generation_prompt=True,
|
|
247
|
+
)
|
|
248
|
+
full_text = tokenizer.apply_chat_template(
|
|
249
|
+
[{"role": "user", "content": instruction},
|
|
250
|
+
{"role": "assistant", "content": random_token}],
|
|
251
|
+
tokenize=False,
|
|
252
|
+
add_generation_prompt=False,
|
|
253
|
+
)
|
|
161
254
|
|
|
162
255
|
return full_text, random_token, prompt_only
|
|
163
256
|
|
|
@@ -188,28 +281,66 @@ def build_extraction_texts(
|
|
|
188
281
|
option_b = response[:200] # negative
|
|
189
282
|
answer = "B"
|
|
190
283
|
|
|
191
|
-
mc_prompt = f"
|
|
284
|
+
mc_prompt = f"Question: {prompt}\n\nWhich is correct?\nA. {option_a}\nB. {option_b}\nAnswer:"
|
|
285
|
+
|
|
286
|
+
if not hasattr(tokenizer, "apply_chat_template"):
|
|
287
|
+
raise ValueError(
|
|
288
|
+
f"Strategy {strategy.value} requires a tokenizer with apply_chat_template. "
|
|
289
|
+
f"Use mc_completion for base models."
|
|
290
|
+
)
|
|
291
|
+
|
|
292
|
+
prompt_only = tokenizer.apply_chat_template(
|
|
293
|
+
[{"role": "user", "content": mc_prompt}],
|
|
294
|
+
tokenize=False,
|
|
295
|
+
add_generation_prompt=True,
|
|
296
|
+
)
|
|
297
|
+
full_text = tokenizer.apply_chat_template(
|
|
298
|
+
[{"role": "user", "content": mc_prompt},
|
|
299
|
+
{"role": "assistant", "content": answer}],
|
|
300
|
+
tokenize=False,
|
|
301
|
+
add_generation_prompt=False,
|
|
302
|
+
)
|
|
303
|
+
|
|
304
|
+
return full_text, answer, prompt_only
|
|
305
|
+
|
|
306
|
+
elif strategy in (ExtractionStrategy.COMPLETION_LAST, ExtractionStrategy.COMPLETION_MEAN):
|
|
307
|
+
# Base model strategies - direct Q+A without chat template
|
|
308
|
+
# Format: "Q: {prompt}\nA: {response}"
|
|
309
|
+
prompt_only = f"Q: {prompt}\nA:"
|
|
310
|
+
full_text = f"Q: {prompt}\nA: {response}"
|
|
311
|
+
return full_text, response, prompt_only
|
|
312
|
+
|
|
313
|
+
elif strategy == ExtractionStrategy.MC_COMPLETION:
|
|
314
|
+
# Multiple choice for base models - no chat template
|
|
315
|
+
if other_response is None:
|
|
316
|
+
raise ValueError("MC_COMPLETION strategy requires other_response")
|
|
317
|
+
|
|
318
|
+
# Deterministic "random" based on prompt - same for both pos and neg of a pair
|
|
319
|
+
pos_goes_in_b = hash(prompt) % 2 == 0
|
|
192
320
|
|
|
193
|
-
if
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
{"role": "assistant", "content": answer}],
|
|
203
|
-
tokenize=False,
|
|
204
|
-
add_generation_prompt=False,
|
|
205
|
-
)
|
|
206
|
-
except (ValueError, KeyError):
|
|
207
|
-
prompt_only = mc_prompt
|
|
208
|
-
full_text = f"{mc_prompt} {answer}"
|
|
321
|
+
if is_positive:
|
|
322
|
+
if pos_goes_in_b:
|
|
323
|
+
option_a = other_response[:200]
|
|
324
|
+
option_b = response[:200]
|
|
325
|
+
answer = "B"
|
|
326
|
+
else:
|
|
327
|
+
option_a = response[:200]
|
|
328
|
+
option_b = other_response[:200]
|
|
329
|
+
answer = "A"
|
|
209
330
|
else:
|
|
210
|
-
|
|
211
|
-
|
|
331
|
+
if pos_goes_in_b:
|
|
332
|
+
option_a = response[:200]
|
|
333
|
+
option_b = other_response[:200]
|
|
334
|
+
answer = "A"
|
|
335
|
+
else:
|
|
336
|
+
option_a = other_response[:200]
|
|
337
|
+
option_b = response[:200]
|
|
338
|
+
answer = "B"
|
|
212
339
|
|
|
340
|
+
mc_prompt = f"Question: {prompt}\n\nWhich is correct?\nA. {option_a}\nB. {option_b}\nAnswer:"
|
|
341
|
+
|
|
342
|
+
prompt_only = mc_prompt
|
|
343
|
+
full_text = f"{mc_prompt} {answer}"
|
|
213
344
|
return full_text, answer, prompt_only
|
|
214
345
|
|
|
215
346
|
else:
|
|
@@ -243,6 +374,7 @@ def extract_activation(
|
|
|
243
374
|
num_answer_tokens = len(answer_tokens)
|
|
244
375
|
|
|
245
376
|
if strategy == ExtractionStrategy.CHAT_LAST:
|
|
377
|
+
# EOT token - has seen the entire answer, best performance
|
|
246
378
|
return hidden_states[-1]
|
|
247
379
|
|
|
248
380
|
elif strategy == ExtractionStrategy.CHAT_FIRST:
|
|
@@ -257,11 +389,6 @@ def extract_activation(
|
|
|
257
389
|
return answer_hidden.mean(dim=0)
|
|
258
390
|
return hidden_states[-1]
|
|
259
391
|
|
|
260
|
-
elif strategy == ExtractionStrategy.CHAT_GEN_POINT:
|
|
261
|
-
# Last token before answer starts (decision point)
|
|
262
|
-
gen_point_idx = max(0, seq_len - num_answer_tokens - 2)
|
|
263
|
-
return hidden_states[gen_point_idx]
|
|
264
|
-
|
|
265
392
|
elif strategy == ExtractionStrategy.CHAT_MAX_NORM:
|
|
266
393
|
# Token with max norm in answer region
|
|
267
394
|
if num_answer_tokens > 0 and seq_len > num_answer_tokens:
|
|
@@ -275,18 +402,36 @@ def extract_activation(
|
|
|
275
402
|
# Position-weighted mean (earlier tokens weighted more)
|
|
276
403
|
if num_answer_tokens > 0 and seq_len > num_answer_tokens:
|
|
277
404
|
answer_hidden = hidden_states[-num_answer_tokens-1:-1]
|
|
278
|
-
weights = torch.exp(-torch.arange(answer_hidden.shape[0], dtype=
|
|
405
|
+
weights = torch.exp(-torch.arange(answer_hidden.shape[0], dtype=answer_hidden.dtype, device=answer_hidden.device) * 0.5)
|
|
279
406
|
weights = weights / weights.sum()
|
|
280
407
|
return (answer_hidden * weights.unsqueeze(1)).sum(dim=0)
|
|
281
408
|
return hidden_states[-1]
|
|
282
409
|
|
|
283
|
-
elif strategy
|
|
284
|
-
#
|
|
410
|
+
elif strategy == ExtractionStrategy.ROLE_PLAY:
|
|
411
|
+
# EOT token - slightly better than answer word (65% vs 64%)
|
|
285
412
|
return hidden_states[-1]
|
|
286
413
|
|
|
287
|
-
|
|
288
|
-
#
|
|
414
|
+
elif strategy == ExtractionStrategy.MC_BALANCED:
|
|
415
|
+
# Answer token (A/B) - better than EOT (64% vs 56%)
|
|
416
|
+
return hidden_states[-2]
|
|
417
|
+
|
|
418
|
+
elif strategy == ExtractionStrategy.COMPLETION_LAST:
|
|
419
|
+
# Last token for base model completion
|
|
289
420
|
return hidden_states[-1]
|
|
421
|
+
|
|
422
|
+
elif strategy == ExtractionStrategy.COMPLETION_MEAN:
|
|
423
|
+
# Mean of answer tokens for base model completion
|
|
424
|
+
if num_answer_tokens > 0 and seq_len > num_answer_tokens:
|
|
425
|
+
answer_hidden = hidden_states[-num_answer_tokens:]
|
|
426
|
+
return answer_hidden.mean(dim=0)
|
|
427
|
+
return hidden_states[-1]
|
|
428
|
+
|
|
429
|
+
elif strategy == ExtractionStrategy.MC_COMPLETION:
|
|
430
|
+
# A/B token for base model MC (last token is the answer)
|
|
431
|
+
return hidden_states[-1]
|
|
432
|
+
|
|
433
|
+
else:
|
|
434
|
+
raise ValueError(f"Unknown extraction strategy: {strategy}")
|
|
290
435
|
|
|
291
436
|
|
|
292
437
|
def add_extraction_strategy_args(parser: argparse.ArgumentParser) -> None:
|
|
@@ -306,3 +451,30 @@ def add_extraction_strategy_args(parser: argparse.ArgumentParser) -> None:
|
|
|
306
451
|
choices=ExtractionStrategy.list_all(),
|
|
307
452
|
help=f"Extraction strategy for activations. Options: {', '.join(ExtractionStrategy.list_all())}. Default: {ExtractionStrategy.default().value}",
|
|
308
453
|
)
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
def get_strategy_for_model(tokenizer, prefer_mc: bool = False) -> ExtractionStrategy:
|
|
457
|
+
"""
|
|
458
|
+
Get the best extraction strategy for a given tokenizer.
|
|
459
|
+
|
|
460
|
+
Automatically detects if tokenizer has chat template and returns
|
|
461
|
+
the appropriate strategy.
|
|
462
|
+
|
|
463
|
+
Args:
|
|
464
|
+
tokenizer: The tokenizer to check
|
|
465
|
+
prefer_mc: If True, prefer multiple choice strategies
|
|
466
|
+
|
|
467
|
+
Returns:
|
|
468
|
+
ExtractionStrategy appropriate for the tokenizer
|
|
469
|
+
|
|
470
|
+
Example:
|
|
471
|
+
>>> from transformers import AutoTokenizer
|
|
472
|
+
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
|
|
473
|
+
>>> strategy = get_strategy_for_model(tokenizer)
|
|
474
|
+
>>> print(strategy) # completion_last (base model)
|
|
475
|
+
|
|
476
|
+
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
|
|
477
|
+
>>> strategy = get_strategy_for_model(tokenizer)
|
|
478
|
+
>>> print(strategy) # chat_last (instruct model)
|
|
479
|
+
"""
|
|
480
|
+
return ExtractionStrategy.for_tokenizer(tokenizer, prefer_mc=prefer_mc)
|
|
@@ -14,6 +14,7 @@ import numpy as np
|
|
|
14
14
|
|
|
15
15
|
from torch.nn.modules.loss import _Loss
|
|
16
16
|
from wisent.core.errors import DuplicateNameError, InvalidRangeError, UnknownTypeError
|
|
17
|
+
from wisent.core.utils.device import preferred_dtype
|
|
17
18
|
|
|
18
19
|
__all__ = [
|
|
19
20
|
"ClassifierTrainConfig",
|
|
@@ -164,13 +165,13 @@ class BaseClassifier(ABC):
|
|
|
164
165
|
self,
|
|
165
166
|
threshold: float = 0.5,
|
|
166
167
|
device: str | None = None,
|
|
167
|
-
dtype: torch.dtype =
|
|
168
|
+
dtype: torch.dtype | None = None,
|
|
168
169
|
) -> None:
|
|
169
170
|
if not 0.0 <= threshold <= 1.0:
|
|
170
171
|
raise InvalidRangeError(param_name="threshold", actual=threshold, min_val=0.0, max_val=1.0)
|
|
171
172
|
self.threshold = threshold
|
|
172
173
|
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
|
173
|
-
self.dtype =
|
|
174
|
+
self.dtype = dtype if dtype is not None else preferred_dtype(self.device)
|
|
174
175
|
self.model = None
|
|
175
176
|
|
|
176
177
|
@abstractmethod
|
wisent/core/cli/__init__.py
CHANGED
|
@@ -22,5 +22,6 @@ from .inference_config_cli import execute_inference_config
|
|
|
22
22
|
from .optimization_cache import execute_optimization_cache
|
|
23
23
|
from .optimize_weights import execute_optimize_weights
|
|
24
24
|
from .optimize import execute_optimize
|
|
25
|
+
from .geometry_search import execute_geometry_search
|
|
25
26
|
|
|
26
|
-
__all__ = ['execute_tasks', 'execute_generate_pairs_from_task', 'execute_generate_pairs', 'execute_diagnose_pairs', 'execute_get_activations', 'execute_diagnose_vectors', 'execute_create_steering_vector', 'execute_generate_vector_from_task', 'execute_generate_vector_from_synthetic', 'execute_optimize_classification', 'execute_optimize_steering', 'execute_optimize_sample_size', 'execute_generate_responses', 'execute_evaluate_responses', 'execute_multi_steer', 'execute_agent', 'execute_modify_weights', 'execute_evaluate_refusal', 'execute_inference_config', 'execute_optimization_cache', 'execute_optimize_weights', 'execute_optimize']
|
|
27
|
+
__all__ = ['execute_tasks', 'execute_generate_pairs_from_task', 'execute_generate_pairs', 'execute_diagnose_pairs', 'execute_get_activations', 'execute_diagnose_vectors', 'execute_create_steering_vector', 'execute_generate_vector_from_task', 'execute_generate_vector_from_synthetic', 'execute_optimize_classification', 'execute_optimize_steering', 'execute_optimize_sample_size', 'execute_generate_responses', 'execute_evaluate_responses', 'execute_multi_steer', 'execute_agent', 'execute_modify_weights', 'execute_evaluate_refusal', 'execute_inference_config', 'execute_optimization_cache', 'execute_optimize_weights', 'execute_optimize', 'execute_geometry_search']
|
|
@@ -19,7 +19,7 @@ def _map_token_aggregation(aggregation_str: str):
|
|
|
19
19
|
|
|
20
20
|
def _map_prompt_strategy(strategy_str: str):
|
|
21
21
|
"""Map string prompt strategy to ExtractionStrategy."""
|
|
22
|
-
|
|
22
|
+
from wisent.core.activations.extraction_strategy import ExtractionStrategy
|
|
23
23
|
|
|
24
24
|
mapping = {
|
|
25
25
|
"chat_template": ExtractionStrategy.CHAT_LAST,
|
|
@@ -111,9 +111,8 @@ def apply_steering_and_evaluate(
|
|
|
111
111
|
|
|
112
112
|
updated_pair = collector.collect(
|
|
113
113
|
pair, strategy=aggregation_strategy,
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
prompt_strategy=prompt_construction_strategy
|
|
114
|
+
layers=target_layers,
|
|
115
|
+
normalize=normalize_layers
|
|
117
116
|
)
|
|
118
117
|
enriched_pairs.append(updated_pair)
|
|
119
118
|
|
|
@@ -174,9 +173,8 @@ def apply_steering_and_evaluate(
|
|
|
174
173
|
|
|
175
174
|
steered_evaluated_pair = collector.collect(
|
|
176
175
|
steered_dummy_pair, strategy=aggregation_strategy,
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
prompt_strategy=prompt_construction_strategy
|
|
176
|
+
layers=target_layers,
|
|
177
|
+
normalize=normalize_layers
|
|
180
178
|
)
|
|
181
179
|
|
|
182
180
|
steered_quality = 0.0
|
|
@@ -1,8 +1,20 @@
|
|
|
1
1
|
"""Train classifier on contrastive pairs for agent."""
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
|
+
import torch
|
|
4
5
|
from wisent.core.classifiers.classifiers.core.atoms import ClassifierTrainReport
|
|
5
6
|
from wisent.core.errors import UnknownTypeError
|
|
7
|
+
from wisent.core.utils.device import preferred_dtype
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _torch_dtype_to_numpy(torch_dtype: torch.dtype):
|
|
11
|
+
"""Convert torch dtype to numpy dtype."""
|
|
12
|
+
mapping = {
|
|
13
|
+
torch.float32: np.float32,
|
|
14
|
+
torch.float16: np.float16,
|
|
15
|
+
torch.bfloat16: np.float32, # numpy doesn't support bfloat16, use float32
|
|
16
|
+
}
|
|
17
|
+
return mapping.get(torch_dtype, np.float32)
|
|
6
18
|
|
|
7
19
|
|
|
8
20
|
def _map_token_aggregation(aggregation_str: str):
|
|
@@ -21,7 +33,7 @@ def _map_token_aggregation(aggregation_str: str):
|
|
|
21
33
|
|
|
22
34
|
def _map_prompt_strategy(strategy_str: str):
|
|
23
35
|
"""Map string prompt strategy to ExtractionStrategy."""
|
|
24
|
-
|
|
36
|
+
from wisent.core.activations.extraction_strategy import ExtractionStrategy
|
|
25
37
|
|
|
26
38
|
mapping = {
|
|
27
39
|
"chat_template": ExtractionStrategy.CHAT_LAST,
|
|
@@ -97,7 +109,7 @@ def train_classifier_on_pairs(
|
|
|
97
109
|
prompt_construction_strategy = _map_prompt_strategy(prompt_strategy)
|
|
98
110
|
|
|
99
111
|
# Collect activations for all pairs
|
|
100
|
-
collector = ActivationCollector(model=model
|
|
112
|
+
collector = ActivationCollector(model=model)
|
|
101
113
|
target_layers = [str(target_layer)]
|
|
102
114
|
layer_key = target_layers[0]
|
|
103
115
|
|
|
@@ -108,9 +120,8 @@ def train_classifier_on_pairs(
|
|
|
108
120
|
|
|
109
121
|
updated_pair = collector.collect(
|
|
110
122
|
pair, strategy=aggregation_strategy,
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
prompt_strategy=prompt_construction_strategy
|
|
123
|
+
layers=[str(target_layer)],
|
|
124
|
+
normalize=normalize_layers
|
|
114
125
|
)
|
|
115
126
|
enriched_training_pairs.append(updated_pair)
|
|
116
127
|
|
|
@@ -133,8 +144,9 @@ def train_classifier_on_pairs(
|
|
|
133
144
|
X_list.append(neg_act.cpu().numpy())
|
|
134
145
|
y_list.append(0.0)
|
|
135
146
|
|
|
136
|
-
|
|
137
|
-
|
|
147
|
+
np_dtype = _torch_dtype_to_numpy(preferred_dtype())
|
|
148
|
+
X_train = np.array(X_list, dtype=np_dtype)
|
|
149
|
+
y_train = np.array(y_list, dtype=np_dtype)
|
|
138
150
|
|
|
139
151
|
print(f" Training data: {X_train.shape[0]} samples, {X_train.shape[1]} features")
|
|
140
152
|
|
|
@@ -31,6 +31,7 @@ def execute_check_linearity(args):
|
|
|
31
31
|
from wisent.core.models.wisent_model import WisentModel
|
|
32
32
|
from wisent.core.contrastive_pairs.core.pair import ContrastivePair
|
|
33
33
|
from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse
|
|
34
|
+
from wisent.core.activations.extraction_strategy import ExtractionStrategy
|
|
34
35
|
|
|
35
36
|
# Build ContrastivePair objects
|
|
36
37
|
pairs = []
|
|
@@ -72,6 +73,10 @@ def execute_check_linearity(args):
|
|
|
72
73
|
if args.layers:
|
|
73
74
|
config.layers_to_test = [int(l) for l in args.layers.split(',')]
|
|
74
75
|
|
|
76
|
+
if args.extraction_strategy:
|
|
77
|
+
config.extraction_strategies = [ExtractionStrategy(args.extraction_strategy)]
|
|
78
|
+
print(f"Using extraction strategy: {args.extraction_strategy}")
|
|
79
|
+
|
|
75
80
|
# Run check
|
|
76
81
|
print("\nRunning linearity check...")
|
|
77
82
|
result = check_linearity(pairs, model, config)
|
|
@@ -110,12 +115,39 @@ def execute_check_linearity(args):
|
|
|
110
115
|
|
|
111
116
|
sorted_results = sorted(result.all_results, key=lambda x: x['linear_score'], reverse=True)
|
|
112
117
|
|
|
113
|
-
print(f"{'Linear':<8} {'d':<8} {'Layer':<6} {'
|
|
114
|
-
print("-" *
|
|
118
|
+
print(f"{'Linear':<8} {'d':<8} {'Layer':<6} {'Strategy':<20} {'Structure':<12} {'Norm'}")
|
|
119
|
+
print("-" * 70)
|
|
115
120
|
|
|
116
121
|
for r in sorted_results[:20]:
|
|
117
122
|
print(f"{r['linear_score']:<8.3f} {r['cohens_d']:<8.2f} {r['layer']:<6} "
|
|
118
|
-
f"{r['
|
|
123
|
+
f"{r['extraction_strategy']:<20} {r['best_structure']:<12} {r['normalize']}")
|
|
124
|
+
|
|
125
|
+
# Show best result for each structure type
|
|
126
|
+
if sorted_results and 'all_structure_scores' in sorted_results[0]:
|
|
127
|
+
print(f"\n{'='*60}")
|
|
128
|
+
print("BEST SCORE PER STRUCTURE TYPE")
|
|
129
|
+
print(f"{'='*60}")
|
|
130
|
+
|
|
131
|
+
# Collect best score for each structure across all configs
|
|
132
|
+
best_per_structure = {}
|
|
133
|
+
for r in result.all_results:
|
|
134
|
+
if 'all_structure_scores' not in r:
|
|
135
|
+
continue
|
|
136
|
+
for struct_name, data in r['all_structure_scores'].items():
|
|
137
|
+
score = data['score']
|
|
138
|
+
if struct_name not in best_per_structure or score > best_per_structure[struct_name]['score']:
|
|
139
|
+
best_per_structure[struct_name] = {
|
|
140
|
+
'score': score,
|
|
141
|
+
'confidence': data['confidence'],
|
|
142
|
+
'layer': r['layer'],
|
|
143
|
+
'strategy': r['extraction_strategy'],
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
print(f"{'Structure':<12} {'Score':<8} {'Conf':<8} {'Layer':<6} {'Strategy'}")
|
|
147
|
+
print("-" * 55)
|
|
148
|
+
sorted_structs = sorted(best_per_structure.items(), key=lambda x: x[1]['score'], reverse=True)
|
|
149
|
+
for name, data in sorted_structs:
|
|
150
|
+
print(f"{name:<12} {data['score']:<8.3f} {data['confidence']:<8.3f} {data['layer']:<6} {data['strategy']}")
|
|
119
151
|
|
|
120
152
|
# Exit code based on verdict
|
|
121
153
|
if result.verdict.value == "linear":
|