wisent 0.7.701__py3-none-any.whl → 0.7.901__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/core/activations/activation_cache.py +393 -0
- wisent/core/activations/activations.py +3 -3
- wisent/core/activations/activations_collector.py +9 -5
- wisent/core/activations/classifier_inference_strategy.py +12 -11
- wisent/core/activations/extraction_strategy.py +256 -84
- wisent/core/classifiers/classifiers/core/atoms.py +3 -2
- wisent/core/cli/__init__.py +2 -1
- wisent/core/cli/agent/apply_steering.py +5 -7
- wisent/core/cli/agent/train_classifier.py +19 -7
- wisent/core/cli/check_linearity.py +35 -3
- wisent/core/cli/cluster_benchmarks.py +4 -6
- wisent/core/cli/create_steering_vector.py +6 -4
- wisent/core/cli/diagnose_vectors.py +7 -4
- wisent/core/cli/estimate_unified_goodness_time.py +6 -4
- wisent/core/cli/generate_pairs_from_task.py +9 -56
- wisent/core/cli/geometry_search.py +137 -0
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/cli/method_optimizer.py +4 -3
- wisent/core/cli/modify_weights.py +3 -2
- wisent/core/cli/optimize_sample_size.py +1 -1
- wisent/core/cli/optimize_steering.py +14 -16
- wisent/core/cli/optimize_weights.py +2 -1
- wisent/core/cli/preview_pairs.py +203 -0
- wisent/core/cli/steering_method_trainer.py +3 -3
- wisent/core/cli/tasks.py +19 -76
- wisent/core/cli/train_unified_goodness.py +3 -3
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
- wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
- wisent/core/data_loaders/loaders/lm_loader.py +12 -1
- wisent/core/geometry_runner.py +995 -0
- wisent/core/geometry_search_space.py +237 -0
- wisent/core/hyperparameter_optimizer.py +1 -1
- wisent/core/main.py +3 -0
- wisent/core/models/core/atoms.py +5 -3
- wisent/core/models/wisent_model.py +1 -1
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/parser_arguments/check_linearity_parser.py +12 -2
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +2 -2
- wisent/core/parser_arguments/geometry_search_parser.py +61 -0
- wisent/core/parser_arguments/main_parser.py +8 -0
- wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
- wisent/core/steering.py +5 -3
- wisent/core/steering_methods/methods/hyperplane.py +2 -1
- wisent/core/synthetic/generators/nonsense_generator.py +30 -18
- wisent/core/trainers/steering_trainer.py +2 -2
- wisent/core/utils/device.py +27 -27
- wisent/core/utils/layer_combinations.py +70 -0
- wisent/examples/__init__.py +1 -0
- wisent/examples/scripts/__init__.py +1 -0
- wisent/examples/scripts/count_all_benchmarks.py +121 -0
- wisent/examples/scripts/discover_directions.py +469 -0
- wisent/examples/scripts/extract_benchmark_info.py +71 -0
- wisent/examples/scripts/generate_paper_data.py +384 -0
- wisent/examples/scripts/intervention_validation.py +626 -0
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +324 -0
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +92 -0
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +324 -0
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +92 -0
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +92 -0
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +324 -0
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +92 -0
- wisent/examples/scripts/search_all_short_names.py +31 -0
- wisent/examples/scripts/test_all_benchmarks.py +138 -0
- wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
- wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
- wisent/examples/scripts/test_nonsense_baseline.py +261 -0
- wisent/examples/scripts/test_one_benchmark.py +324 -0
- wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
- wisent/examples/scripts/threshold_analysis.py +434 -0
- wisent/examples/scripts/visualization_gallery.py +582 -0
- wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
- wisent/parameters/lm_eval/category_directions.json +137 -0
- wisent/parameters/lm_eval/repair_plan.json +282 -0
- wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
- wisent/parameters/lm_eval/working_benchmarks.json +206 -0
- wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
- wisent/tests/test_detector_accuracy.py +1 -1
- wisent/tests/visualize_geometry.py +1 -1
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/METADATA +1 -1
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/RECORD +329 -295
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/WHEEL +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
"""Preview contrastive pairs from benchmarks with different extraction strategies."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import json
|
|
5
|
+
import argparse
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class _PreviewTokenizer:
    """Minimal tokenizer stand-in used only to render preview texts.

    ``apply_chat_template`` wraps a single message as a user turn, or two
    messages as a user turn followed by an assistant turn, using simple
    ``<|user|>`` / ``<|assistant|>`` / ``<|end|>`` markers; any other message
    count falls back to ``str(messages)``. Calling the object returns
    whitespace-split "tokens" under ``input_ids``.
    """

    def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False):
        if len(messages) == 1:
            return f"<|user|>\n{messages[0]['content']}\n<|assistant|>\n"
        elif len(messages) == 2:
            return f"<|user|>\n{messages[0]['content']}\n<|assistant|>\n{messages[1]['content']}<|end|>"
        return str(messages)

    def __call__(self, text, add_special_tokens=False):
        return {"input_ids": text.split()}


def _truncate(text, max_chars):
    """Return the first ``max_chars`` characters of ``text``, plus ``...`` if anything was cut."""
    return f"{text[:max_chars]}{'...' if len(text) > max_chars else ''}"


def _build_positive_texts(strategy, pair, tokenizer):
    """Render the positive response of ``pair`` under ``strategy``.

    Args:
        strategy: An ``ExtractionStrategy`` member.
        pair: A contrastive pair exposing ``prompt``,
            ``positive_response.model_response`` and
            ``negative_response.model_response``.
        tokenizer: Object passed through to ``build_extraction_texts``.

    Returns:
        ``(full_text, answer)`` as produced by ``build_extraction_texts``
        (its third return value is discarded).
    """
    from wisent.core.activations.extraction_strategy import (
        ExtractionStrategy,
        build_extraction_texts,
    )

    mc_kwargs = {}
    if strategy in (ExtractionStrategy.MC_BALANCED, ExtractionStrategy.MC_COMPLETION):
        # Multiple-choice strategies are also given the negative response.
        mc_kwargs = {
            "other_response": pair.negative_response.model_response,
            "is_positive": True,
        }
    full_text, answer, _prompt_only = build_extraction_texts(
        strategy,
        pair.prompt,
        pair.positive_response.model_response,
        tokenizer,
        auto_convert_strategy=False,
        **mc_kwargs,
    )
    return full_text, answer


def execute_preview_pairs(args):
    """Preview contrastive pairs from a benchmark with different strategies applied.

    Loads up to ``args.limit`` pairs (5 when falsy) for ``args.task_name`` and
    prints each pair's raw prompt / correct / incorrect text. For every name in
    ``args.strategies`` it also prints the positive-response text built by
    ``build_extraction_texts`` with a mock tokenizer. If ``args.output`` is set,
    the untruncated results are written there as UTF-8 JSON.

    Exits with status 1 if loading the task raises, or if the lm-eval task
    resolves to a group with more than one subtask.
    """
    from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import (
        lm_build_contrastive_pairs,
    )
    from wisent.core.contrastive_pairs.huggingface_pairs.hf_extractor_manifest import HF_EXTRACTORS
    from wisent.core.activations.extraction_strategy import ExtractionStrategy

    task_name = args.task_name
    limit = args.limit or 5
    strategies = args.strategies or ['chat_last', 'mc_balanced', 'completion_last']

    print(f"\n{'='*80}")
    print(f"Preview Contrastive Pairs: {task_name}")
    print(f"{'='*80}")

    # Load pairs
    print(f"\nLoading {limit} pairs from '{task_name}'...")

    try:
        hf_task_names = {name.lower() for name in HF_EXTRACTORS.keys()}
        if task_name.lower() in hf_task_names:
            # HF-extractor tasks are built without an lm-eval task object.
            pairs = lm_build_contrastive_pairs(
                task_name=task_name,
                lm_eval_task=None,
                limit=limit,
            )
        else:
            from wisent.core.data_loaders.loaders.lm_loader import LMEvalDataLoader

            task_obj = LMEvalDataLoader().load_lm_eval_task(task_name)
            if isinstance(task_obj, dict):
                # A dict of subtasks can only be previewed when it holds exactly one.
                if len(task_obj) != 1:
                    keys = ", ".join(sorted(task_obj.keys()))
                    print(f"Task '{task_name}' has subtasks: {keys}", file=sys.stderr)
                    print("Please specify a subtask.", file=sys.stderr)
                    sys.exit(1)
                (task_name, task), = task_obj.items()
            else:
                task = task_obj

            pairs = lm_build_contrastive_pairs(
                task_name=task_name,
                lm_eval_task=task,
                limit=limit,
            )

        print(f"Loaded {len(pairs)} pairs\n")

    except Exception as e:
        print(f"Error loading task: {e}", file=sys.stderr)
        sys.exit(1)

    tokenizer = _PreviewTokenizer()

    # Show pairs with strategies; keep each result so the JSON export can reuse
    # it instead of rebuilding every text a second time.
    formatted_by_pair = []
    for i, pair in enumerate(pairs):
        print(f"\n{'='*80}")
        print(f"PAIR {i+1}/{len(pairs)}")
        print(f"{'='*80}")

        print("\n--- RAW DATA (from extractor) ---")
        print(f"Prompt: {_truncate(pair.prompt, 300)}")
        print(f"Correct: {_truncate(pair.positive_response.model_response, 100)}")
        print(f"Incorrect: {_truncate(pair.negative_response.model_response, 100)}")

        formatted = {}
        for strategy_name in strategies:
            try:
                strategy = ExtractionStrategy(strategy_name)
            except ValueError as e:
                print(f"\n--- {strategy_name.upper()} --- (invalid strategy)")
                formatted[strategy_name] = {"error": str(e)}
                continue

            print(f"\n--- {strategy_name.upper()} ---")

            try:
                full_text, answer = _build_positive_texts(strategy, pair, tokenizer)
            except Exception as e:
                print(f" Error: {e}")
                formatted[strategy_name] = {"error": str(e)}
                continue

            print("Full text (positive):")
            print(f" {_truncate(full_text, 400)}")
            print(f"Answer token: {answer}")
            formatted[strategy_name] = {
                "full_text": full_text,
                "answer": answer,
            }
        formatted_by_pair.append(formatted)

    # Summary
    print(f"\n{'='*80}")
    print("SUMMARY")
    print(f"{'='*80}")
    print(f"Task: {task_name}")
    print(f"Pairs shown: {len(pairs)}")
    print(f"Strategies: {', '.join(strategies)}")
    print()

    # Save to JSON if requested
    if args.output:
        output_data = {
            "task_name": task_name,
            "num_pairs": len(pairs),
            "strategies": strategies,
            "pairs": [
                {
                    "raw": {
                        "prompt": pair.prompt,
                        "correct": pair.positive_response.model_response,
                        "incorrect": pair.negative_response.model_response,
                    },
                    "formatted": formatted,
                }
                for pair, formatted in zip(pairs, formatted_by_pair)
            ],
        }

        # Write real UTF-8 instead of \u escapes so non-Latin benchmark text stays readable.
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"Saved to: {args.output}")
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def main():
|
|
189
|
+
parser = argparse.ArgumentParser(description="Preview contrastive pairs with different strategies")
|
|
190
|
+
parser.add_argument("task_name", help="Task/benchmark name (e.g., boolq, mmlu, hellaswag)")
|
|
191
|
+
parser.add_argument("--limit", "-n", type=int, default=5, help="Number of pairs to show (default: 5)")
|
|
192
|
+
parser.add_argument("--strategies", "-s", nargs="+",
|
|
193
|
+
default=["chat_last", "mc_balanced", "completion_last"],
|
|
194
|
+
help="Strategies to preview")
|
|
195
|
+
parser.add_argument("--output", "-o", help="Save to JSON file")
|
|
196
|
+
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
|
|
197
|
+
|
|
198
|
+
args = parser.parse_args()
|
|
199
|
+
execute_preview_pairs(args)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
if __name__ == "__main__":
|
|
203
|
+
main()
|
|
@@ -156,7 +156,7 @@ def collect_activations_for_pair_set(
|
|
|
156
156
|
Returns:
|
|
157
157
|
Updated ContrastivePairSet with activations attached
|
|
158
158
|
"""
|
|
159
|
-
collector = ActivationCollector(model=model
|
|
159
|
+
collector = ActivationCollector(model=model)
|
|
160
160
|
|
|
161
161
|
updated_pairs = []
|
|
162
162
|
for pair in pair_set.pairs:
|
|
@@ -320,7 +320,7 @@ class UnifiedSteeringTrainer:
|
|
|
320
320
|
@property
|
|
321
321
|
def collector(self) -> ActivationCollector:
|
|
322
322
|
if self._collector is None:
|
|
323
|
-
self._collector = ActivationCollector(model=self.model
|
|
323
|
+
self._collector = ActivationCollector(model=self.model)
|
|
324
324
|
return self._collector
|
|
325
325
|
|
|
326
326
|
def train_for_layer(
|
|
@@ -595,7 +595,7 @@ def get_optimal_steering_plan(
|
|
|
595
595
|
method_name = config["method"]
|
|
596
596
|
|
|
597
597
|
# Collect activations for the optimal layer
|
|
598
|
-
collector = ActivationCollector(model=model
|
|
598
|
+
collector = ActivationCollector(model=model)
|
|
599
599
|
layer_str = str(layer)
|
|
600
600
|
|
|
601
601
|
pos_acts = []
|
wisent/core/cli/tasks.py
CHANGED
|
@@ -414,7 +414,7 @@ def execute_tasks(args):
|
|
|
414
414
|
print(f"\n🧠 Extracting activations from layer {layer}...")
|
|
415
415
|
|
|
416
416
|
# 5. Collect activations for all pairs
|
|
417
|
-
collector = ActivationCollector(model=model
|
|
417
|
+
collector = ActivationCollector(model=model)
|
|
418
418
|
|
|
419
419
|
# Get extraction strategy from args (already an ExtractionStrategy value string)
|
|
420
420
|
extraction_strategy = ExtractionStrategy(getattr(args, 'extraction_strategy', 'chat_last'))
|
|
@@ -581,13 +581,6 @@ def execute_tasks(args):
|
|
|
581
581
|
expected = pair.positive_response.model_response
|
|
582
582
|
choices = [pair.negative_response.model_response, pair.positive_response.model_response]
|
|
583
583
|
|
|
584
|
-
# Extract test_code from pair metadata for coding tasks
|
|
585
|
-
test_code = None
|
|
586
|
-
starter_code = None
|
|
587
|
-
if hasattr(pair, 'metadata') and pair.metadata:
|
|
588
|
-
test_code = pair.metadata.get('test_code')
|
|
589
|
-
starter_code = pair.metadata.get('starter_code')
|
|
590
|
-
|
|
591
584
|
# Generate response from unsteered model
|
|
592
585
|
messages = [{"role": "user", "content": question}]
|
|
593
586
|
|
|
@@ -597,6 +590,7 @@ def execute_tasks(args):
|
|
|
597
590
|
)[0]
|
|
598
591
|
|
|
599
592
|
# Evaluate the response using Wisent evaluator
|
|
593
|
+
# Pass all pair metadata to evaluator - each evaluator uses what it needs
|
|
600
594
|
eval_kwargs = {
|
|
601
595
|
'response': response,
|
|
602
596
|
'expected': expected,
|
|
@@ -605,16 +599,16 @@ def execute_tasks(args):
|
|
|
605
599
|
'choices': choices,
|
|
606
600
|
'task_name': task_name,
|
|
607
601
|
}
|
|
608
|
-
# Add
|
|
609
|
-
if
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
602
|
+
# Add all pair metadata to eval_kwargs (test_code, correct_answers, etc.)
|
|
603
|
+
if hasattr(pair, 'metadata') and pair.metadata:
|
|
604
|
+
for key, value in pair.metadata.items():
|
|
605
|
+
if value is not None and key not in eval_kwargs:
|
|
606
|
+
eval_kwargs[key] = value
|
|
613
607
|
eval_result = evaluator.evaluate(**eval_kwargs)
|
|
614
608
|
|
|
615
609
|
# Get activation for this generation
|
|
616
610
|
# Use ActivationCollector to collect activations from the generated text
|
|
617
|
-
gen_collector = ActivationCollector(model=model
|
|
611
|
+
gen_collector = ActivationCollector(model=model)
|
|
618
612
|
# Create a pair with the generated response
|
|
619
613
|
from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse
|
|
620
614
|
from wisent.core.contrastive_pairs.core.pair import ContrastivePair
|
|
@@ -631,56 +625,20 @@ def execute_tasks(args):
|
|
|
631
625
|
# Collect activation - ActivationCollector will re-run the model with prompt+response
|
|
632
626
|
# First, collect with full sequence to get token-by-token activations
|
|
633
627
|
collected_full = gen_collector.collect(
|
|
634
|
-
temp_pair, strategy=
|
|
635
|
-
return_full_sequence=True,
|
|
636
|
-
normalize_layers=False,
|
|
637
|
-
prompt_strategy=prompt_strategy
|
|
628
|
+
temp_pair, strategy=extraction_strategy,
|
|
638
629
|
)
|
|
639
630
|
|
|
640
631
|
# Access the collected activations
|
|
641
632
|
import torch
|
|
642
633
|
if collected_full.positive_response.layers_activations:
|
|
643
|
-
|
|
644
|
-
if layer_str in
|
|
645
|
-
|
|
646
|
-
if
|
|
647
|
-
#
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
if
|
|
651
|
-
activation_agg = activation_full_seq.mean(dim=0)
|
|
652
|
-
elif aggregation_strategy.name == 'LAST_TOKEN':
|
|
653
|
-
activation_agg = activation_full_seq[-1]
|
|
654
|
-
elif aggregation_strategy.name == 'FIRST_TOKEN':
|
|
655
|
-
activation_agg = activation_full_seq[0]
|
|
656
|
-
elif aggregation_strategy.name == 'MAX_POOLING':
|
|
657
|
-
activation_agg = activation_full_seq.max(dim=0)[0]
|
|
658
|
-
elif aggregation_strategy.name == 'MIN_POOLING':
|
|
659
|
-
activation_agg = activation_full_seq.min(dim=0)[0]
|
|
660
|
-
else:
|
|
661
|
-
# Default to mean
|
|
662
|
-
activation_agg = activation_full_seq.mean(dim=0)
|
|
663
|
-
|
|
664
|
-
# Compute per-token classifier scores first (needed for max_score aggregation)
|
|
665
|
-
token_scores = []
|
|
666
|
-
for token_idx in range(activation_full_seq.shape[0]):
|
|
667
|
-
token_act = activation_full_seq[token_idx].unsqueeze(0).float()
|
|
668
|
-
token_proba_result = classifier.predict_proba(token_act)
|
|
669
|
-
token_proba = token_proba_result if isinstance(token_proba_result, float) else token_proba_result[0]
|
|
670
|
-
token_scores.append(float(token_proba))
|
|
671
|
-
|
|
672
|
-
# Get classifier prediction - either from aggregated vector or min token score
|
|
673
|
-
if use_max_token_score:
|
|
674
|
-
# Use MINIMUM token score as response-level classification
|
|
675
|
-
# Score = P(TRUTHFUL), so min score = most suspicious token
|
|
676
|
-
# If ANY token has low truthfulness probability, flag the response
|
|
677
|
-
pred_proba = min(token_scores) if token_scores else 0.5
|
|
678
|
-
else:
|
|
679
|
-
# Standard: classify aggregated activation vector
|
|
680
|
-
act_tensor = activation_agg.unsqueeze(0).float()
|
|
681
|
-
pred_proba_result = classifier.predict_proba(act_tensor)
|
|
682
|
-
pred_proba = pred_proba_result if isinstance(pred_proba_result, float) else pred_proba_result[0]
|
|
683
|
-
|
|
634
|
+
layer_activations = collected_full.positive_response.layers_activations
|
|
635
|
+
if layer_str in layer_activations:
|
|
636
|
+
activation = layer_activations[layer_str]
|
|
637
|
+
if activation is not None and isinstance(activation, torch.Tensor):
|
|
638
|
+
# activation shape: (hidden_dim,) - already aggregated by extraction strategy
|
|
639
|
+
act_tensor = activation.unsqueeze(0).float()
|
|
640
|
+
pred_proba_result = classifier.predict_proba(act_tensor)
|
|
641
|
+
pred_proba = pred_proba_result if isinstance(pred_proba_result, float) else pred_proba_result[0]
|
|
684
642
|
pred_label = int(pred_proba > args.detection_threshold)
|
|
685
643
|
|
|
686
644
|
# Update detection stats
|
|
@@ -753,14 +711,6 @@ def execute_tasks(args):
|
|
|
753
711
|
# Ground truth from evaluator
|
|
754
712
|
ground_truth = 1 if eval_result.ground_truth == "TRUTHFUL" else 0
|
|
755
713
|
|
|
756
|
-
# token_scores = P(TRUTHFUL) for each token
|
|
757
|
-
# min_token_score = most suspicious token (lowest P(TRUTHFUL))
|
|
758
|
-
# max_token_score = most confident token (highest P(TRUTHFUL))
|
|
759
|
-
min_token_score = min(token_scores) if token_scores else 0.0
|
|
760
|
-
min_token_idx = token_scores.index(min_token_score) if token_scores else -1
|
|
761
|
-
max_token_score = max(token_scores) if token_scores else 0.0
|
|
762
|
-
max_token_idx = token_scores.index(max_token_score) if token_scores else -1
|
|
763
|
-
|
|
764
714
|
generation_results.append({
|
|
765
715
|
'question': question,
|
|
766
716
|
'response': response,
|
|
@@ -770,13 +720,6 @@ def execute_tasks(args):
|
|
|
770
720
|
'classifier_pred': pred_label,
|
|
771
721
|
'classifier_proba': float(pred_proba),
|
|
772
722
|
'correct': pred_label == ground_truth,
|
|
773
|
-
'token_scores': token_scores, # Per-token P(TRUTHFUL) probabilities
|
|
774
|
-
'min_token_score': min_token_score, # Most suspicious token - lowest P(TRUTHFUL)
|
|
775
|
-
'min_token_idx': min_token_idx, # Index of most suspicious token
|
|
776
|
-
'max_token_score': max_token_score, # Most confident token - highest P(TRUTHFUL) (kept for backward compat)
|
|
777
|
-
'max_token_idx': max_token_idx, # Index of most confident token
|
|
778
|
-
'num_tokens': len(token_scores),
|
|
779
|
-
'aggregation_method': 'max_score' if use_max_token_score else args.token_aggregation,
|
|
780
723
|
'quality_score': quality_score,
|
|
781
724
|
'issue_detected': issue_detected,
|
|
782
725
|
'detection_type': detection_type,
|
|
@@ -852,7 +795,7 @@ def execute_tasks(args):
|
|
|
852
795
|
classifier_type=args.classifier_type,
|
|
853
796
|
training_accuracy=report.final.accuracy,
|
|
854
797
|
training_samples=len(X),
|
|
855
|
-
token_aggregation=
|
|
798
|
+
token_aggregation=extraction_strategy.value,
|
|
856
799
|
detection_threshold=args.detection_threshold
|
|
857
800
|
)
|
|
858
801
|
|
|
@@ -884,7 +827,7 @@ def execute_tasks(args):
|
|
|
884
827
|
'task': args.task_names,
|
|
885
828
|
'model': args.model,
|
|
886
829
|
'layer': layer,
|
|
887
|
-
'aggregation':
|
|
830
|
+
'aggregation': extraction_strategy.value,
|
|
888
831
|
'threshold': args.detection_threshold,
|
|
889
832
|
'num_generations': len(generation_results),
|
|
890
833
|
'detection_stats': detection_stats,
|
|
@@ -325,11 +325,11 @@ def execute_train_unified_goodness(args):
|
|
|
325
325
|
'final': ExtractionStrategy.CHAT_LAST,
|
|
326
326
|
'first': ExtractionStrategy.CHAT_FIRST,
|
|
327
327
|
'max': ExtractionStrategy.CHAT_MAX_NORM,
|
|
328
|
-
'continuation': ExtractionStrategy.
|
|
328
|
+
'continuation': ExtractionStrategy.CHAT_FIRST, # First answer token
|
|
329
329
|
}
|
|
330
330
|
aggregation_strategy = aggregation_map.get(
|
|
331
331
|
args.token_aggregation,
|
|
332
|
-
ExtractionStrategy.
|
|
332
|
+
ExtractionStrategy.CHAT_LAST
|
|
333
333
|
)
|
|
334
334
|
|
|
335
335
|
# Map prompt strategy
|
|
@@ -353,7 +353,7 @@ def execute_train_unified_goodness(args):
|
|
|
353
353
|
negative_activations = activations_checkpoint['negative_activations']
|
|
354
354
|
print(f" ✓ Loaded activations from checkpoint ({len(positive_activations[layers[0]])} pairs)")
|
|
355
355
|
else:
|
|
356
|
-
collector = ActivationCollector(model=model
|
|
356
|
+
collector = ActivationCollector(model=model)
|
|
357
357
|
|
|
358
358
|
# Collect activations for all training pairs using batched processing
|
|
359
359
|
positive_activations = {layer: [] for layer in layers}
|
|
@@ -95,7 +95,7 @@ def run_control_vector_diagnostics(
|
|
|
95
95
|
)
|
|
96
96
|
continue
|
|
97
97
|
|
|
98
|
-
flat = detached.to(
|
|
98
|
+
flat = detached.to(device="cpu").reshape(-1)
|
|
99
99
|
|
|
100
100
|
if not torch.isfinite(flat).all():
|
|
101
101
|
non_finite = (~torch.isfinite(flat)).sum().item()
|
|
@@ -1549,7 +1549,7 @@ def _detect_sparse_structure(
|
|
|
1549
1549
|
sorted_abs = abs_diff.sort().values
|
|
1550
1550
|
n = len(sorted_abs)
|
|
1551
1551
|
cumsum = sorted_abs.cumsum(0)
|
|
1552
|
-
gini = (2 * torch.arange(1, n + 1, dtype=
|
|
1552
|
+
gini = (2 * torch.arange(1, n + 1, dtype=sorted_abs.dtype, device=sorted_abs.device) @ sorted_abs - (n + 1) * sorted_abs.sum()) / (n * sorted_abs.sum() + 1e-10)
|
|
1553
1553
|
|
|
1554
1554
|
# Sparse score: high if few dimensions are active
|
|
1555
1555
|
sparse_score = 0.4 * (1 - float(l1_l2_ratio)) + 0.3 * (1 - float(active_fraction)) + 0.3 * float(gini)
|
|
@@ -1632,11 +1632,11 @@ def _compute_dip_statistic(data: torch.Tensor) -> float:
|
|
|
1632
1632
|
return 0.0
|
|
1633
1633
|
|
|
1634
1634
|
# Empirical CDF
|
|
1635
|
-
ecdf = torch.arange(1, n + 1, dtype=
|
|
1635
|
+
ecdf = torch.arange(1, n + 1, dtype=sorted_data.dtype, device=sorted_data.device) / n
|
|
1636
1636
|
|
|
1637
1637
|
# Greatest convex minorant and least concave majorant
|
|
1638
1638
|
# Simplified: measure deviation from uniform
|
|
1639
|
-
uniform = torch.linspace(0, 1, n)
|
|
1639
|
+
uniform = torch.linspace(0, 1, n, dtype=sorted_data.dtype, device=sorted_data.device)
|
|
1640
1640
|
|
|
1641
1641
|
# Kolmogorov-Smirnov like statistic
|
|
1642
1642
|
ks_stat = (ecdf - uniform).abs().max()
|
|
@@ -188,6 +188,12 @@ def check_linearity(
|
|
|
188
188
|
linear_score = result.all_scores["linear"].score
|
|
189
189
|
linear_details = result.all_scores["linear"].details
|
|
190
190
|
|
|
191
|
+
# Include all structure scores
|
|
192
|
+
structure_scores = {
|
|
193
|
+
name: {"score": score.score, "confidence": score.confidence}
|
|
194
|
+
for name, score in result.all_scores.items()
|
|
195
|
+
}
|
|
196
|
+
|
|
191
197
|
all_results.append({
|
|
192
198
|
"extraction_strategy": strategy.value,
|
|
193
199
|
"normalize": normalize,
|
|
@@ -196,6 +202,7 @@ def check_linearity(
|
|
|
196
202
|
"cohens_d": linear_details.get("cohens_d", 0),
|
|
197
203
|
"variance_explained": linear_details.get("variance_explained", 0),
|
|
198
204
|
"best_structure": result.best_structure.value,
|
|
205
|
+
"all_structure_scores": structure_scores,
|
|
199
206
|
})
|
|
200
207
|
|
|
201
208
|
if not all_results:
|