wisent 0.7.701__py3-none-any.whl → 0.7.901__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/core/activations/activation_cache.py +393 -0
- wisent/core/activations/activations.py +3 -3
- wisent/core/activations/activations_collector.py +9 -5
- wisent/core/activations/classifier_inference_strategy.py +12 -11
- wisent/core/activations/extraction_strategy.py +256 -84
- wisent/core/classifiers/classifiers/core/atoms.py +3 -2
- wisent/core/cli/__init__.py +2 -1
- wisent/core/cli/agent/apply_steering.py +5 -7
- wisent/core/cli/agent/train_classifier.py +19 -7
- wisent/core/cli/check_linearity.py +35 -3
- wisent/core/cli/cluster_benchmarks.py +4 -6
- wisent/core/cli/create_steering_vector.py +6 -4
- wisent/core/cli/diagnose_vectors.py +7 -4
- wisent/core/cli/estimate_unified_goodness_time.py +6 -4
- wisent/core/cli/generate_pairs_from_task.py +9 -56
- wisent/core/cli/geometry_search.py +137 -0
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/cli/method_optimizer.py +4 -3
- wisent/core/cli/modify_weights.py +3 -2
- wisent/core/cli/optimize_sample_size.py +1 -1
- wisent/core/cli/optimize_steering.py +14 -16
- wisent/core/cli/optimize_weights.py +2 -1
- wisent/core/cli/preview_pairs.py +203 -0
- wisent/core/cli/steering_method_trainer.py +3 -3
- wisent/core/cli/tasks.py +19 -76
- wisent/core/cli/train_unified_goodness.py +3 -3
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
- wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
- wisent/core/data_loaders/loaders/lm_loader.py +12 -1
- wisent/core/geometry_runner.py +995 -0
- wisent/core/geometry_search_space.py +237 -0
- wisent/core/hyperparameter_optimizer.py +1 -1
- wisent/core/main.py +3 -0
- wisent/core/models/core/atoms.py +5 -3
- wisent/core/models/wisent_model.py +1 -1
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/parser_arguments/check_linearity_parser.py +12 -2
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +2 -2
- wisent/core/parser_arguments/geometry_search_parser.py +61 -0
- wisent/core/parser_arguments/main_parser.py +8 -0
- wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
- wisent/core/steering.py +5 -3
- wisent/core/steering_methods/methods/hyperplane.py +2 -1
- wisent/core/synthetic/generators/nonsense_generator.py +30 -18
- wisent/core/trainers/steering_trainer.py +2 -2
- wisent/core/utils/device.py +27 -27
- wisent/core/utils/layer_combinations.py +70 -0
- wisent/examples/__init__.py +1 -0
- wisent/examples/scripts/__init__.py +1 -0
- wisent/examples/scripts/count_all_benchmarks.py +121 -0
- wisent/examples/scripts/discover_directions.py +469 -0
- wisent/examples/scripts/extract_benchmark_info.py +71 -0
- wisent/examples/scripts/generate_paper_data.py +384 -0
- wisent/examples/scripts/intervention_validation.py +626 -0
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +324 -0
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +92 -0
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +324 -0
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +92 -0
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +92 -0
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +324 -0
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +92 -0
- wisent/examples/scripts/search_all_short_names.py +31 -0
- wisent/examples/scripts/test_all_benchmarks.py +138 -0
- wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
- wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
- wisent/examples/scripts/test_nonsense_baseline.py +261 -0
- wisent/examples/scripts/test_one_benchmark.py +324 -0
- wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
- wisent/examples/scripts/threshold_analysis.py +434 -0
- wisent/examples/scripts/visualization_gallery.py +582 -0
- wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
- wisent/parameters/lm_eval/category_directions.json +137 -0
- wisent/parameters/lm_eval/repair_plan.json +282 -0
- wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
- wisent/parameters/lm_eval/working_benchmarks.json +206 -0
- wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
- wisent/tests/test_detector_accuracy.py +1 -1
- wisent/tests/visualize_geometry.py +1 -1
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/METADATA +1 -1
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/RECORD +329 -295
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/WHEEL +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/top_level.txt +0 -0
wisent/__init__.py CHANGED

wisent/core/activations/activation_cache.py ADDED
@@ -0,0 +1,393 @@
+"""
+Activation cache for geometry search.
+
+Caches activations for ALL layers once per (benchmark, strategy) pair.
+Layer combinations are then tested from cache without re-extraction.
+"""
+
+from __future__ import annotations
+
+import hashlib
+import json
+import os
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Any
+import torch
+
+from wisent.core.activations.core.atoms import LayerActivations, RawActivationMap
+from wisent.core.activations.extraction_strategy import ExtractionStrategy
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.utils.device import resolve_default_device
+
+
+@dataclass
+class CachedActivations:
+    """
+    Cached activations for a single (benchmark, strategy) pair.
+
+    Contains activations for ALL layers for all pairs.
+    Layer combinations can be extracted without re-running the model.
+    """
+    benchmark: str
+    strategy: ExtractionStrategy
+    model_name: str
+    num_layers: int
+
+    # List of (positive_activations, negative_activations) per pair
+    # Each activation is a dict: layer_name -> tensor [hidden_size]
+    pair_activations: List[Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]] = field(default_factory=list)
+
+    # Metadata
+    num_pairs: int = 0
+    hidden_size: int = 0
+
+    def add_pair(self, positive: LayerActivations, negative: LayerActivations) -> None:
+        """Add activations for a contrastive pair."""
+        pos_dict = {k: v.clone() for k, v in positive.items() if v is not None}
+        neg_dict = {k: v.clone() for k, v in negative.items() if v is not None}
+        self.pair_activations.append((pos_dict, neg_dict))
+        self.num_pairs = len(self.pair_activations)
+
+        # Infer hidden size from first tensor
+        if self.hidden_size == 0 and pos_dict:
+            first_tensor = next(iter(pos_dict.values()))
+            self.hidden_size = first_tensor.shape[-1]
+
+    def get_layer_subset(self, layers: List[int]) -> "CachedActivations":
+        """
+        Get a new CachedActivations with only the specified layers.
+
+        Args:
+            layers: List of layer indices (0-based)
+
+        Returns:
+            New CachedActivations with only the specified layers
+        """
+        layer_names = [str(l) for l in layers]
+
+        new_pairs = []
+        for pos_dict, neg_dict in self.pair_activations:
+            new_pos = {k: v for k, v in pos_dict.items() if k in layer_names}
+            new_neg = {k: v for k, v in neg_dict.items() if k in layer_names}
+            new_pairs.append((new_pos, new_neg))
+
+        result = CachedActivations(
+            benchmark=self.benchmark,
+            strategy=self.strategy,
+            model_name=self.model_name,
+            num_layers=len(layers),
+            hidden_size=self.hidden_size,
+        )
+        result.pair_activations = new_pairs
+        result.num_pairs = len(new_pairs)
+        return result
+
+    def get_available_layers(self) -> List[str]:
+        """Get list of available layer names."""
+        if not self.pair_activations:
+            return []
+        return list(self.pair_activations[0][0].keys())
+
+    def get_positive_activations(self, layer: int | str) -> torch.Tensor:
+        """
+        Get stacked positive activations for a single layer.
+
+        Args:
+            layer: Layer index (int) or layer name (str)
+
+        Returns:
+            Tensor of shape [num_pairs, hidden_size]
+        """
+        layer_name = str(layer)
+        tensors = [pos[layer_name] for pos, _ in self.pair_activations if layer_name in pos]
+        if not tensors:
+            raise KeyError(f"Layer {layer_name} not found. Available: {self.get_available_layers()}")
+        return torch.stack(tensors, dim=0)
+
+    def get_negative_activations(self, layer: int | str) -> torch.Tensor:
+        """
+        Get stacked negative activations for a single layer.
+
+        Args:
+            layer: Layer index (int) or layer name (str)
+
+        Returns:
+            Tensor of shape [num_pairs, hidden_size]
+        """
+        layer_name = str(layer)
+        tensors = [neg[layer_name] for _, neg in self.pair_activations if layer_name in neg]
+        if not tensors:
+            raise KeyError(f"Layer {layer_name} not found. Available: {self.get_available_layers()}")
+        return torch.stack(tensors, dim=0)
+
+    def get_diff_activations(self, layer: int | str) -> torch.Tensor:
+        """
+        Get positive - negative activation differences for a layer.
+
+        Args:
+            layer: Layer index (int) or layer name (str)
+
+        Returns:
+            Tensor of shape [num_pairs, hidden_size]
+        """
+        return self.get_positive_activations(layer) - self.get_negative_activations(layer)
+
+    def get_all_layers_diff(self) -> Dict[str, torch.Tensor]:
+        """
+        Get activation differences for all layers.
+
+        Returns:
+            Dict mapping layer_name -> tensor [num_pairs, hidden_size]
+        """
+        result = {}
+        if not self.pair_activations:
+            return result
+
+        # Get layer names from first pair
+        layer_names = list(self.pair_activations[0][0].keys())
+        for layer_name in layer_names:
+            pos_tensors = []
+            neg_tensors = []
+            for pos, neg in self.pair_activations:
+                if layer_name in pos and layer_name in neg:
+                    pos_tensors.append(pos[layer_name])
+                    neg_tensors.append(neg[layer_name])
+            if pos_tensors:
+                result[layer_name] = torch.stack(pos_tensors) - torch.stack(neg_tensors)
+        return result
+
+    def to_device(self, device: str) -> "CachedActivations":
+        """Move all tensors to a device."""
+        new_pairs = []
+        for pos, neg in self.pair_activations:
+            new_pos = {k: v.to(device) for k, v in pos.items()}
+            new_neg = {k: v.to(device) for k, v in neg.items()}
+            new_pairs.append((new_pos, new_neg))
+
+        result = CachedActivations(
+            benchmark=self.benchmark,
+            strategy=self.strategy,
+            model_name=self.model_name,
+            num_layers=self.num_layers,
+            hidden_size=self.hidden_size,
+        )
+        result.pair_activations = new_pairs
+        result.num_pairs = self.num_pairs
+        return result
+
+
+class ActivationCache:
+    """
+    Disk-backed cache for activations.
+
+    Saves/loads activations per (model, benchmark, strategy) tuple.
+    """
+
+    def __init__(self, cache_dir: str):
+        self.cache_dir = Path(cache_dir)
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        self._memory_cache: Dict[str, CachedActivations] = {}
+
+    def _get_cache_key(self, model_name: str, benchmark: str, strategy: ExtractionStrategy) -> str:
+        """Generate a unique cache key."""
+        key_str = f"{model_name}_{benchmark}_{strategy.value}"
+        return hashlib.md5(key_str.encode()).hexdigest()[:16]
+
+    def _get_cache_path(self, cache_key: str) -> Path:
+        """Get path for a cache file."""
+        return self.cache_dir / f"{cache_key}.pt"
+
+    def _get_metadata_path(self, cache_key: str) -> Path:
+        """Get path for cache metadata."""
+        return self.cache_dir / f"{cache_key}.json"
+
+    def has(self, model_name: str, benchmark: str, strategy: ExtractionStrategy) -> bool:
+        """Check if activations are cached."""
+        key = self._get_cache_key(model_name, benchmark, strategy)
+        if key in self._memory_cache:
+            return True
+        return self._get_cache_path(key).exists()
+
+    def get(
+        self,
+        model_name: str,
+        benchmark: str,
+        strategy: ExtractionStrategy,
+        load_to_memory: bool = True,
+    ) -> Optional[CachedActivations]:
+        """
+        Get cached activations if they exist.
+
+        Args:
+            model_name: Model identifier
+            benchmark: Benchmark name
+            strategy: Extraction strategy
+            load_to_memory: If True, keep in memory cache after loading
+
+        Returns:
+            CachedActivations or None if not cached
+        """
+        key = self._get_cache_key(model_name, benchmark, strategy)
+
+        # Check memory cache first
+        if key in self._memory_cache:
+            return self._memory_cache[key]
+
+        # Check disk cache
+        cache_path = self._get_cache_path(key)
+        if not cache_path.exists():
+            return None
+
+        # Load from disk
+        data = torch.load(cache_path, map_location=resolve_default_device(), weights_only=False)
+
+        cached = CachedActivations(
+            benchmark=data["benchmark"],
+            strategy=ExtractionStrategy(data["strategy"]),
+            model_name=data["model_name"],
+            num_layers=data["num_layers"],
+            hidden_size=data["hidden_size"],
+        )
+        cached.pair_activations = data["pair_activations"]
+        cached.num_pairs = data["num_pairs"]
+
+        if load_to_memory:
+            self._memory_cache[key] = cached
+
+        return cached
+
+    def put(
+        self,
+        cached: CachedActivations,
+        save_to_disk: bool = True,
+    ) -> None:
+        """
+        Store cached activations.
+
+        Args:
+            cached: CachedActivations to store
+            save_to_disk: If True, persist to disk
+        """
+        key = self._get_cache_key(cached.model_name, cached.benchmark, cached.strategy)
+
+        # Store in memory
+        self._memory_cache[key] = cached
+
+        if save_to_disk:
+            # Save to disk
+            data = {
+                "benchmark": cached.benchmark,
+                "strategy": cached.strategy.value,
+                "model_name": cached.model_name,
+                "num_layers": cached.num_layers,
+                "hidden_size": cached.hidden_size,
+                "num_pairs": cached.num_pairs,
+                "pair_activations": cached.pair_activations,
+            }
+            torch.save(data, self._get_cache_path(key))
+
+            # Save metadata as JSON
+            metadata = {
+                "benchmark": cached.benchmark,
+                "strategy": cached.strategy.value,
+                "model_name": cached.model_name,
+                "num_layers": cached.num_layers,
+                "hidden_size": cached.hidden_size,
+                "num_pairs": cached.num_pairs,
+            }
+            with open(self._get_metadata_path(key), "w") as f:
+                json.dump(metadata, f, indent=2)
+
+    def clear_memory(self) -> None:
+        """Clear the in-memory cache."""
+        self._memory_cache.clear()
+
+    def list_cached(self) -> List[Dict[str, Any]]:
+        """List all cached activations."""
+        result = []
+        for meta_path in self.cache_dir.glob("*.json"):
+            with open(meta_path) as f:
+                result.append(json.load(f))
+        return result
+
+    def get_cache_size_bytes(self) -> int:
+        """Get total size of cache on disk."""
+        total = 0
+        for path in self.cache_dir.glob("*.pt"):
+            total += path.stat().st_size
+        return total
+
+
+def collect_and_cache_activations(
+    model: "WisentModel",
+    pairs: List[ContrastivePair],
+    benchmark: str,
+    strategy: ExtractionStrategy,
+    cache: Optional[ActivationCache] = None,
+    cache_dir: Optional[str] = None,
+    show_progress: bool = True,
+) -> CachedActivations:
+    """
+    Collect activations for all pairs and all layers, then cache.
+
+    Args:
+        model: WisentModel instance
+        pairs: List of contrastive pairs
+        benchmark: Benchmark name
+        strategy: Extraction strategy
+        cache: Optional existing cache to use
+        cache_dir: Cache directory (used if cache not provided)
+        show_progress: Print progress
+
+    Returns:
+        CachedActivations with all layers for all pairs
+    """
+    from wisent.core.activations.activations_collector import ActivationCollector
+
+    # Check cache first
+    if cache is None and cache_dir:
+        cache = ActivationCache(cache_dir)
+
+    if cache and cache.has(model.model_name, benchmark, strategy):
+        if show_progress:
+            print(f"Loading cached activations for {benchmark}/{strategy.value}")
+        return cache.get(model.model_name, benchmark, strategy)
+
+    # Collect activations for ALL layers (preserve model's native dtype)
+    collector = ActivationCollector(model=model)
+
+    cached = CachedActivations(
+        benchmark=benchmark,
+        strategy=strategy,
+        model_name=model.model_name,
+        num_layers=model.num_layers,
+    )
+
+    for i, pair in enumerate(pairs):
+        if show_progress and i % 10 == 0:
+            print(f"Collecting activations: {i+1}/{len(pairs)}", end="\r", flush=True)
+
+        # Collect ALL layers (layers=None)
+        updated = collector.collect(pair, strategy=strategy, layers=None)
+        cached.add_pair(
+            updated.positive_response.layers_activations,
+            updated.negative_response.layers_activations,
+        )
+
+    if show_progress:
+        print(f"Collected activations: {len(pairs)}/{len(pairs)} pairs, {cached.num_layers} layers")
+
+    # Cache the result
+    if cache:
+        cache.put(cached)
+        if show_progress:
+            print(f"Cached to {cache.cache_dir}")
+
+    return cached
+
+
+# Type hint for WisentModel (avoid circular import)
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from wisent.core.models.wisent_model import WisentModel
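A minimal usage sketch of the new cache module, based only on the API visible in the hunk above; how a WisentModel and the contrastive pairs are constructed is not part of this diff, so load_wisent_model and load_contrastive_pairs are hypothetical placeholders rather than functions from the package.

from wisent.core.activations.activation_cache import ActivationCache, collect_and_cache_activations
from wisent.core.activations.extraction_strategy import ExtractionStrategy

model = load_wisent_model("meta-llama/Llama-3.2-1B")    # hypothetical helper, not part of wisent
pairs = load_contrastive_pairs("truthfulqa")            # hypothetical helper, not part of wisent

cache = ActivationCache(cache_dir="./activation_cache")

# The first call runs the model over every pair and persists all layers to disk;
# later calls with the same (model, benchmark, strategy) are served from the cache.
cached = collect_and_cache_activations(
    model=model,
    pairs=pairs,
    benchmark="truthfulqa",
    strategy=ExtractionStrategy.CHAT_LAST,
    cache=cache,
)

# Layer combinations can then be evaluated without re-running the model.
subset = cached.get_layer_subset([8, 12])
diffs = subset.get_diff_activations(8)    # tensor of shape [num_pairs, hidden_size]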
wisent/core/activations/activations.py CHANGED
@@ -50,7 +50,7 @@ class Activations:
             features = tensor.mean(dim=1).squeeze(0)
         elif strategy in (ExtractionStrategy.CHAT_LAST, ExtractionStrategy.ROLE_PLAY, ExtractionStrategy.MC_BALANCED):
             features = tensor[:, -1, :].squeeze(0)
-        elif strategy
+        elif strategy == ExtractionStrategy.CHAT_FIRST:
             features = tensor[:, 0, :].squeeze(0)
         elif strategy == ExtractionStrategy.CHAT_MAX_NORM:
             norms = torch.norm(tensor, dim=2)
@@ -58,11 +58,11 @@
             features = tensor[0, max_idx[0], :]
         elif strategy == ExtractionStrategy.CHAT_WEIGHTED:
             seq_len = tensor.shape[1]
-            weights = torch.exp(-torch.arange(seq_len, dtype=
+            weights = torch.exp(-torch.arange(seq_len, dtype=tensor.dtype, device=tensor.device) * 0.5)
             weights = weights / weights.sum()
             features = (tensor * weights.unsqueeze(0).unsqueeze(2)).sum(dim=1).squeeze(0)
         else:
-
+            raise InvalidValueError(param="extraction_strategy", reason=f"Unknown extraction strategy: {strategy}")
 
         return features
 
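As a quick illustration of what the repaired CHAT_WEIGHTED branch computes, the snippet below applies the same exponential position weighting to a toy tensor; the shapes are arbitrary and chosen only for the example.

import torch

tensor = torch.randn(1, 4, 8)    # [batch, seq_len, hidden]; toy sizes for illustration
seq_len = tensor.shape[1]

# Same formula as the fixed line above: exponentially decaying weights over token
# positions, normalised to sum to 1, so earlier tokens contribute the most.
weights = torch.exp(-torch.arange(seq_len, dtype=tensor.dtype, device=tensor.device) * 0.5)
weights = weights / weights.sum()
features = (tensor * weights.unsqueeze(0).unsqueeze(2)).sum(dim=1).squeeze(0)

print(weights)           # approximately [0.455, 0.276, 0.167, 0.102]
print(features.shape)    # torch.Size([8])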
wisent/core/activations/activations_collector.py CHANGED
@@ -25,11 +25,11 @@ class ActivationCollector:
 
     Args:
         model: WisentModel instance
-        store_device: Device to store collected activations on (default "cpu")
-        dtype: Optional torch.dtype to cast activations to
+        store_device: Device to store collected activations on (default: "cpu" to avoid GPU OOM)
+        dtype: Optional torch.dtype to cast activations to
 
     Example:
-        >>> collector = ActivationCollector(model=my_model
+        >>> collector = ActivationCollector(model=my_model)
         >>> updated_pair = collector.collect(
         ...     pair,
         ...     strategy=ExtractionStrategy.CHAT_LAST,
@@ -37,7 +37,7 @@ class ActivationCollector:
         ... )
         >>> pos_acts = updated_pair.positive_response.layers_activations
        >>> pos_acts.summary()
-        {'8': {'shape': (2048,),
+        {'8': {'shape': (2048,), ...}, '12': {...}}
     """
 
    model: "WisentModel"
@@ -220,8 +220,12 @@ class ActivationCollector:
                 value = h.mean(dim=0)
             elif strategy == ExtractionStrategy.CHAT_FIRST:
                 value = h[0]
-
+            elif strategy in (ExtractionStrategy.CHAT_LAST, ExtractionStrategy.ROLE_PLAY,
+                              ExtractionStrategy.MC_BALANCED,
+                              ExtractionStrategy.CHAT_MAX_NORM, ExtractionStrategy.CHAT_WEIGHTED):
                 value = h[-1]
+            else:
+                raise ValueError(f"Unsupported strategy for batched collection: {strategy}")
 
             collected[name] = value.to(self.store_device)
 
wisent/core/activations/classifier_inference_strategy.py CHANGED
@@ -8,13 +8,16 @@ Based on empirical testing across 3 models (Llama-3.2-1B, Llama-2-7b, Qwen3-8B)
 and 4 tasks (truthfulqa, happy, left_wing, livecodebench):
 
 Results:
-- last_token:
-- all_mean:
-- all_min:
--
-- first_token: 50.0% avg accuracy (completely useless - BOS token is identical for all inputs)
+- last_token: Best performer (77% with chat_last training on truthfulqa)
+- all_mean: Poor (~50%) - dominated by shared prompt tokens
+- all_max/all_min: Poor (~50%)
+- first_token: BROKEN (50%) - BOS token is identical for all inputs
 
 Recommendation: Use LAST_TOKEN (default) - it works best with chat_last training strategy.
+
+IMPORTANT: These strategies operate on the FULL sequence (prompt + response).
+At inference time, we typically don't know where the answer starts, so we
+can only use strategies that work on the whole sequence.
 """
 
 from enum import Enum
@@ -102,8 +105,7 @@ def extract_inference_activation(
         return hidden_states[torch.argmin(norms)]
 
     else:
-
-        return hidden_states[-1]
+        raise ValueError(f"Unknown classifier inference strategy: {strategy}")
 
 
 def get_inference_score(
@@ -152,8 +154,7 @@ def get_inference_score(
     elif strategy == ClassifierInferenceStrategy.ALL_MIN:
         return float(np.min(all_scores))
 
-
-    return float(classifier.predict_proba([hidden_np[-1]])[0, 1])
+    raise ValueError(f"Unknown classifier inference strategy: {strategy}")
 
 
 def get_recommended_inference_strategy(train_strategy) -> ClassifierInferenceStrategy:
@@ -161,8 +162,8 @@ def get_recommended_inference_strategy(train_strategy) -> ClassifierInferenceStrategy:
     Get the recommended inference strategy for a given training strategy.
 
     Based on empirical testing:
-    - chat_last, role_play, mc_balanced -> last_token
-    - chat_mean, chat_weighted, chat_max_norm, chat_first
+    - chat_last, role_play, mc_balanced -> last_token
+    - chat_mean, chat_weighted, chat_max_norm, chat_first -> all_mean
 
     Args:
         train_strategy: ExtractionStrategy used for training
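For completeness, a hedged sketch of how the recommended mapping described in that docstring would be consumed; the import path is inferred from the file list at the top of this diff rather than from package documentation.

from wisent.core.activations.classifier_inference_strategy import get_recommended_inference_strategy
from wisent.core.activations.extraction_strategy import ExtractionStrategy

# Per the updated docstring: chat_last / role_play / mc_balanced training -> last-token
# inference, while chat_mean / chat_weighted / chat_max_norm / chat_first -> all_mean.
inference_strategy = get_recommended_inference_strategy(ExtractionStrategy.CHAT_LAST)
print(inference_strategy)    # expected to be the last-token inference strategy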