wisent 0.7.701__py3-none-any.whl → 0.7.901__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/core/activations/activation_cache.py +393 -0
- wisent/core/activations/activations.py +3 -3
- wisent/core/activations/activations_collector.py +9 -5
- wisent/core/activations/classifier_inference_strategy.py +12 -11
- wisent/core/activations/extraction_strategy.py +256 -84
- wisent/core/classifiers/classifiers/core/atoms.py +3 -2
- wisent/core/cli/__init__.py +2 -1
- wisent/core/cli/agent/apply_steering.py +5 -7
- wisent/core/cli/agent/train_classifier.py +19 -7
- wisent/core/cli/check_linearity.py +35 -3
- wisent/core/cli/cluster_benchmarks.py +4 -6
- wisent/core/cli/create_steering_vector.py +6 -4
- wisent/core/cli/diagnose_vectors.py +7 -4
- wisent/core/cli/estimate_unified_goodness_time.py +6 -4
- wisent/core/cli/generate_pairs_from_task.py +9 -56
- wisent/core/cli/geometry_search.py +137 -0
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/cli/method_optimizer.py +4 -3
- wisent/core/cli/modify_weights.py +3 -2
- wisent/core/cli/optimize_sample_size.py +1 -1
- wisent/core/cli/optimize_steering.py +14 -16
- wisent/core/cli/optimize_weights.py +2 -1
- wisent/core/cli/preview_pairs.py +203 -0
- wisent/core/cli/steering_method_trainer.py +3 -3
- wisent/core/cli/tasks.py +19 -76
- wisent/core/cli/train_unified_goodness.py +3 -3
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
- wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
- wisent/core/data_loaders/loaders/lm_loader.py +12 -1
- wisent/core/geometry_runner.py +995 -0
- wisent/core/geometry_search_space.py +237 -0
- wisent/core/hyperparameter_optimizer.py +1 -1
- wisent/core/main.py +3 -0
- wisent/core/models/core/atoms.py +5 -3
- wisent/core/models/wisent_model.py +1 -1
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/parser_arguments/check_linearity_parser.py +12 -2
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +2 -2
- wisent/core/parser_arguments/geometry_search_parser.py +61 -0
- wisent/core/parser_arguments/main_parser.py +8 -0
- wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
- wisent/core/steering.py +5 -3
- wisent/core/steering_methods/methods/hyperplane.py +2 -1
- wisent/core/synthetic/generators/nonsense_generator.py +30 -18
- wisent/core/trainers/steering_trainer.py +2 -2
- wisent/core/utils/device.py +27 -27
- wisent/core/utils/layer_combinations.py +70 -0
- wisent/examples/__init__.py +1 -0
- wisent/examples/scripts/__init__.py +1 -0
- wisent/examples/scripts/count_all_benchmarks.py +121 -0
- wisent/examples/scripts/discover_directions.py +469 -0
- wisent/examples/scripts/extract_benchmark_info.py +71 -0
- wisent/examples/scripts/generate_paper_data.py +384 -0
- wisent/examples/scripts/intervention_validation.py +626 -0
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +324 -0
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +92 -0
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +324 -0
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +92 -0
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +92 -0
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +324 -0
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +92 -0
- wisent/examples/scripts/search_all_short_names.py +31 -0
- wisent/examples/scripts/test_all_benchmarks.py +138 -0
- wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
- wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
- wisent/examples/scripts/test_nonsense_baseline.py +261 -0
- wisent/examples/scripts/test_one_benchmark.py +324 -0
- wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
- wisent/examples/scripts/threshold_analysis.py +434 -0
- wisent/examples/scripts/visualization_gallery.py +582 -0
- wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
- wisent/parameters/lm_eval/category_directions.json +137 -0
- wisent/parameters/lm_eval/repair_plan.json +282 -0
- wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
- wisent/parameters/lm_eval/working_benchmarks.json +206 -0
- wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
- wisent/tests/test_detector_accuracy.py +1 -1
- wisent/tests/visualize_geometry.py +1 -1
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/METADATA +1 -1
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/RECORD +329 -295
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/WHEEL +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.901.dist-info}/top_level.txt +0 -0

wisent/core/geometry_search_space.py
ADDED

@@ -0,0 +1,237 @@
+"""
+Configuration for geometry search space.
+
+Defines all parameters to search over when testing if a unified "goodness"
+direction exists across benchmarks.
+
+Strategy:
+- Extract activations for ALL layers once per (benchmark, strategy) pair
+- Cache activations to disk/memory
+- Test all layer combinations from cached activations (fast, just tensor math)
+- This reduces extraction time from O(layer_combos) to O(1) per benchmark
+"""
+
+from dataclasses import dataclass, field
+from typing import List, Optional, Dict, Any
+from enum import Enum
+from pathlib import Path
+import json
+
+from wisent.core.activations.extraction_strategy import ExtractionStrategy
+from wisent.core.utils.layer_combinations import get_layer_combinations
+from wisent.core.benchmark_registry import get_all_benchmarks
+from wisent.core.activations.activation_cache import ActivationCache, CachedActivations
+
+
+@dataclass
+class GeometrySearchConfig:
+    """Configuration for a single geometry search run."""
+
+    # Pairs settings
+    pairs_per_benchmark: int = 50
+    random_seed: int = 42
+
+    # Layer settings
+    max_layer_combo_size: int = 3
+
+    # Caching
+    cache_activations: bool = True
+    cache_dir: Optional[str] = None
+
+    # Estimation
+    estimated_time_per_extraction_seconds: float = 120.0  # ~2 min per (benchmark, strategy)
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "pairs_per_benchmark": self.pairs_per_benchmark,
+            "random_seed": self.random_seed,
+            "max_layer_combo_size": self.max_layer_combo_size,
+            "cache_activations": self.cache_activations,
+            "cache_dir": self.cache_dir,
+            "estimated_time_per_extraction_seconds": self.estimated_time_per_extraction_seconds,
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "GeometrySearchConfig":
+        return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})
+
+
+class GeometrySearchSpace:
+    """
+    Search space configuration for geometry testing.
+
+    Combines:
+    - Models to test
+    - Extraction strategies
+    - Layer combinations
+    - Benchmarks
+
+    With activation caching:
+    - Extract ALL layers once per (benchmark, strategy)
+    - Test layer combinations from cache (no re-extraction needed)
+    """
+
+    # Default models to test
+    DEFAULT_MODELS = [
+        "meta-llama/Llama-3.2-1B-Instruct",
+        "meta-llama/Llama-2-7b-chat-hf",
+        "Qwen/Qwen3-8B",
+        "openai/gpt-oss-20b",
+    ]
+
+    # Extraction strategies for instruct models
+    INSTRUCT_STRATEGIES = [
+        ExtractionStrategy.CHAT_MEAN,
+        ExtractionStrategy.CHAT_FIRST,
+        ExtractionStrategy.CHAT_LAST,
+        ExtractionStrategy.CHAT_MAX_NORM,
+        ExtractionStrategy.CHAT_WEIGHTED,
+        ExtractionStrategy.ROLE_PLAY,
+        ExtractionStrategy.MC_BALANCED,
+    ]
+
+    # Extraction strategies for base models
+    BASE_STRATEGIES = [
+        ExtractionStrategy.COMPLETION_LAST,
+        ExtractionStrategy.COMPLETION_MEAN,
+        ExtractionStrategy.MC_COMPLETION,
+    ]
+
+    def __init__(
+        self,
+        models: Optional[List[str]] = None,
+        strategies: Optional[List[ExtractionStrategy]] = None,
+        benchmarks: Optional[List[str]] = None,
+        config: Optional[GeometrySearchConfig] = None,
+    ):
+        """
+        Initialize the search space.
+
+        Args:
+            models: List of model names to test. Defaults to DEFAULT_MODELS.
+            strategies: List of extraction strategies. Defaults to INSTRUCT_STRATEGIES.
+            benchmarks: List of benchmarks. Defaults to all available benchmarks.
+            config: Search configuration (pairs, caching, etc.)
+        """
+        self.models = models or self.DEFAULT_MODELS
+        self.strategies = strategies or self.INSTRUCT_STRATEGIES
+        self.benchmarks = benchmarks or get_all_benchmarks()
+        self.config = config or GeometrySearchConfig()
+
+    def get_layer_combinations_for_model(self, model_name: str, num_layers: int) -> List[List[int]]:
+        """
+        Get all layer combinations to test for a given model.
+
+        Args:
+            model_name: Name of the model
+            num_layers: Number of layers in the model
+
+        Returns:
+            List of layer combinations
+        """
+        return get_layer_combinations(num_layers, self.config.max_layer_combo_size)
+
+    def get_extraction_count(self) -> int:
+        """
+        Calculate number of activation extractions needed (with caching).
+
+        With caching, we extract ALL layers once per (benchmark, strategy).
+        Layer combinations are tested from cache without re-extraction.
+
+        Returns:
+            Number of (benchmark, strategy) pairs = extraction operations
+        """
+        return len(self.benchmarks) * len(self.strategies)
+
+    def get_total_configurations(self, num_layers: int) -> int:
+        """
+        Calculate total number of configurations to test.
+
+        Total = strategies * layer_combos * benchmarks
+        (Layer combos are tested from cached activations)
+        """
+        from wisent.core.utils.layer_combinations import get_layer_combinations_count
+
+        layer_combos = get_layer_combinations_count(num_layers, self.config.max_layer_combo_size)
+        return len(self.strategies) * layer_combos * len(self.benchmarks)
+
+    def estimate_time_hours(self) -> float:
+        """
+        Estimate total time for geometry search (per model).
+
+        With caching:
+        - Extract once per (benchmark, strategy)
+        - Layer combo testing is fast (from cache)
+
+        Returns:
+            Estimated hours per model
+        """
+        extractions = self.get_extraction_count()
+        seconds = extractions * self.config.estimated_time_per_extraction_seconds
+        return seconds / 3600
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serialize to dictionary."""
+        return {
+            "models": self.models,
+            "strategies": [s.value for s in self.strategies],
+            "benchmarks": self.benchmarks,
+            "config": self.config.to_dict(),
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "GeometrySearchSpace":
+        """Deserialize from dictionary."""
+        strategies = [ExtractionStrategy(s) for s in data.get("strategies", [])]
+        config = GeometrySearchConfig.from_dict(data.get("config", {}))
+        return cls(
+            models=data.get("models"),
+            strategies=strategies if strategies else None,
+            benchmarks=data.get("benchmarks"),
+            config=config,
+        )
+
+    def summary(self) -> str:
+        """Return a human-readable summary of the search space."""
+        lines = [
+            "Geometry Search Space:",
+            f"  Models: {len(self.models)}",
+            f"  Strategies: {len(self.strategies)}",
+            f"  Benchmarks: {len(self.benchmarks)}",
+            f"  Pairs per benchmark: {self.config.pairs_per_benchmark}",
+            f"  Max layer combo size: {self.config.max_layer_combo_size}",
+            f"  Cache activations: {self.config.cache_activations}",
+            f"",
+            f"  Extractions needed (per model): {self.get_extraction_count()}",
+            f"  Estimated time (per model): {self.estimate_time_hours():.1f} hours",
+        ]
+        return "\n".join(lines)
+
+    def save(self, path: str) -> None:
+        """Save search space to JSON file."""
+        with open(path, "w") as f:
+            json.dump(self.to_dict(), f, indent=2)
+
+    @classmethod
+    def load(cls, path: str) -> "GeometrySearchSpace":
+        """Load search space from JSON file."""
+        with open(path) as f:
+            return cls.from_dict(json.load(f))
+
+
+# Default search space instance
+DEFAULT_SEARCH_SPACE = GeometrySearchSpace()
+
+
+if __name__ == "__main__":
+    # Print summary of default search space
+    space = GeometrySearchSpace()
+    print(space.summary())
+    print()
+
+    # Example with 16 layers (Llama-3.2-1B)
+    num_layers = 16
+    layer_combos = space.get_layer_combinations_for_model("test", num_layers)
+    print(f"For a {num_layers}-layer model:")
+    print(f"  Layer combinations: {len(layer_combos)}")
+    print(f"  Total configs to test: {space.get_total_configurations(num_layers)}")
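
For orientation, a minimal usage sketch of the new search-space module follows. The class and method names come from the diff above; the concrete model list, benchmark names, and file path are illustrative only.

# Hypothetical driver for GeometrySearchSpace (illustrative values, not part of the package)
from wisent.core.geometry_search_space import GeometrySearchSpace, GeometrySearchConfig

config = GeometrySearchConfig(pairs_per_benchmark=25, max_layer_combo_size=2)
space = GeometrySearchSpace(
    models=["meta-llama/Llama-3.2-1B-Instruct"],
    benchmarks=["boolq", "piqa"],          # any benchmarks known to the registry
    config=config,
)

print(space.summary())                     # human-readable overview
print(space.estimate_time_hours())         # extractions * 120 s / 3600
space.save("search_space.json")            # round-trips via to_dict()/from_dict()
restored = GeometrySearchSpace.load("search_space.json")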

wisent/core/hyperparameter_optimizer.py
CHANGED

@@ -370,7 +370,7 @@ class HyperparameterOptimizer:
         prompt_strategy = prompt_strategy_map.get(prompt_construction_strategy, ExtractionStrategy.CHAT_LAST)
 
         # Create activation collector
-        collector = ActivationCollector(model=model
+        collector = ActivationCollector(model=model)
         layer_str = str(layer)
 
         # Collect activations for training pairs

wisent/core/main.py
CHANGED

@@ -13,6 +13,7 @@ from wisent.core.cli import execute_tasks, execute_generate_pairs_from_task, exe
 from wisent.core.cli.train_unified_goodness import execute_train_unified_goodness
 from wisent.core.cli.check_linearity import execute_check_linearity
 from wisent.core.cli.cluster_benchmarks import execute_cluster_benchmarks
+from wisent.core.cli.geometry_search import execute_geometry_search
 
 
 def _should_show_banner() -> bool:
@@ -95,6 +96,8 @@ def main():
         execute_check_linearity(args)
     elif args.command == 'cluster-benchmarks':
         execute_cluster_benchmarks(args)
+    elif args.command == 'geometry-search':
+        execute_geometry_search(args)
     else:
         print(f"\n✗ Command '{args.command}' is not yet implemented")
         sys.exit(1)

wisent/core/models/core/atoms.py
CHANGED

@@ -7,6 +7,7 @@ import torch
 from typing import Mapping
 
 from wisent.core.errors import InvalidValueError, InvalidRangeError
+from wisent.core.utils.device import preferred_dtype
 
 if TYPE_CHECKING:
     from wisent.core.activations.core.atoms import RawActivationMap
@@ -213,12 +214,13 @@ class SteeringPlan:
         """
         if n < 0:
             raise InvalidRangeError(param_name="n", actual=n, min_val=0)
+        dtype = preferred_dtype()
         if n == 0:
-            return torch.empty(0, dtype=
+            return torch.empty(0, dtype=dtype)
         if weights is None:
-            return torch.full((n,), 1.0 / n, dtype=
+            return torch.full((n,), 1.0 / n, dtype=dtype)
 
-        w = torch.as_tensor(weights, dtype=
+        w = torch.as_tensor(weights, dtype=dtype)
         if w.numel() != n:
             raise InvalidValueError(param_name="weights length", actual=w.numel(), expected=f"{n} (number of activation maps)")
         s = float(w.sum())

wisent/core/models/wisent_model.py
CHANGED

@@ -89,7 +89,7 @@ class WisentModel:
             optional preloaded model (skips from_pretrained if provided).
         """
         self.model_name = model_name
-        self.device = device or
+        self.device = resolve_default_device() if device is None or device == "auto" else device
 
         # Determine appropriate dtype and settings for the device
         load_kwargs = {
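
Several of the hunks in this release swap hard-coded devices and dtypes for resolve_default_device() and preferred_dtype() from wisent.core.utils.device. That module changed in this release (+27 -27) but its body is not expanded here, so the following is only a sketch of what such helpers typically do, under stated assumptions.

# Plausible shape of the device/dtype helpers; the real implementations live in
# wisent/core/utils/device.py, which this diff does not show.
import torch

def resolve_default_device() -> str:
    # Assumption: prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def preferred_dtype() -> torch.dtype:
    # Assumption: half precision on accelerators, float32 on CPU.
    return torch.float32 if resolve_default_device() == "cpu" else torch.float16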

wisent/core/optuna/classifier/optuna_classifier_optimizer.py
CHANGED

@@ -17,7 +17,7 @@ from optuna.pruners import MedianPruner
 from optuna.samplers import TPESampler
 
 from wisent.core.classifier.classifier import Classifier
-from wisent.core.utils.device import resolve_default_device
+from wisent.core.utils.device import resolve_default_device, preferred_dtype
 from wisent.core.errors import NoActivationDataError, ClassifierCreationError
 
 from .activation_generator import ActivationData, ActivationGenerator, GenerationConfig
@@ -44,7 +44,7 @@ def get_model_dtype(model) -> torch.dtype:
         return next(model_params).dtype
     except StopIteration:
         # Fallback if no parameters found
-        return
+        return preferred_dtype()
 
 
 logger = logging.getLogger(__name__)

wisent/core/parser_arguments/check_linearity_parser.py
CHANGED

@@ -1,5 +1,7 @@
 """Parser for check-linearity command."""
 
+from wisent.core.activations.extraction_strategy import ExtractionStrategy
+
 
 def setup_check_linearity_parser(parser):
     """Set up the check-linearity command parser."""
@@ -9,6 +11,14 @@ def setup_check_linearity_parser(parser):
         help='Path to JSON file containing contrastive pairs'
     )
 
+    parser.add_argument(
+        '--extraction-strategy',
+        type=str,
+        default=None,
+        choices=ExtractionStrategy.list_all(),
+        help=f'Extraction strategy to use. If not specified, tests multiple strategies. Options: {", ".join(ExtractionStrategy.list_all())}'
+    )
+
     parser.add_argument(
         '--model',
         type=str,
@@ -19,8 +29,8 @@ def setup_check_linearity_parser(parser):
     parser.add_argument(
         '--device',
         type=str,
-        default='
-        help='Device to run model on (cuda, mps, cpu)'
+        default='auto',
+        help='Device to run model on (auto, cuda, mps, cpu)'
    )
 
     parser.add_argument(

wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py
CHANGED

@@ -40,8 +40,8 @@ def setup_generate_vector_from_synthetic_parser(parser: argparse.ArgumentParser)
     parser.add_argument(
         "--device",
         type=str,
-        default="
-        help="Device to use (e.g., 'cpu', 'cuda', 'cuda:0')"
+        default="auto",
+        help="Device to use (e.g., 'auto', 'cpu', 'cuda', 'cuda:0', 'mps')"
    )
 
     # Pair generation

wisent/core/parser_arguments/generate_vector_from_task_parser.py
CHANGED

@@ -46,8 +46,8 @@ def setup_generate_vector_from_task_parser(parser: argparse.ArgumentParser) -> N
     parser.add_argument(
         "--device",
         type=str,
-        default="
-        help="Device to use (e.g., 'cpu', 'cuda', 'cuda:0')"
+        default="auto",
+        help="Device to use (e.g., 'auto', 'cpu', 'cuda', 'cuda:0', 'mps')"
    )
 
     # Pair generation

wisent/core/parser_arguments/geometry_search_parser.py
ADDED

@@ -0,0 +1,61 @@
+"""Parser for geometry-search command."""
+
+import argparse
+
+
+def setup_geometry_search_parser(parser: argparse.ArgumentParser) -> None:
+    """Set up the geometry-search command parser."""
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=True,
+        help="Model name or path (e.g., meta-llama/Llama-3.2-1B-Instruct)",
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        default="/home/ubuntu/output/geometry_results.json",
+        help="Output path for results JSON",
+    )
+    parser.add_argument(
+        "--pairs-per-benchmark",
+        type=int,
+        default=50,
+        help="Number of pairs to sample per benchmark (default: 50)",
+    )
+    parser.add_argument(
+        "--max-layer-combo-size",
+        type=int,
+        default=3,
+        help="Maximum layers in combination (default: 3 = individual + pairs + triplets)",
+    )
+    parser.add_argument(
+        "--strategies",
+        type=str,
+        default=None,
+        help="Comma-separated list of strategies (default: all 7)",
+    )
+    parser.add_argument(
+        "--benchmarks",
+        type=str,
+        default=None,
+        help="Comma-separated list of benchmarks, or path to .txt file (default: all)",
+    )
+    parser.add_argument(
+        "--cache-dir",
+        type=str,
+        default=None,
+        help="Directory for activation cache (default: /tmp/wisent_geometry_cache_<model>)",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="Random seed for reproducibility (default: 42)",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="auto",
+        help="Device for model (auto/cuda/mps/cpu, default: auto)",
+    )
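
These flags plug into the standard argparse subcommand flow wired up in main_parser.py below. A minimal sketch of exercising the new parser in isolation (argument names come from the diff; the argv values are illustrative):

import argparse
from wisent.core.parser_arguments.geometry_search_parser import setup_geometry_search_parser

parser = argparse.ArgumentParser(prog="wisent geometry-search")
setup_geometry_search_parser(parser)
args = parser.parse_args([
    "--model", "meta-llama/Llama-3.2-1B-Instruct",
    "--pairs-per-benchmark", "25",
    "--benchmarks", "boolq,piqa",
    "--device", "auto",
])
print(args.model, args.pairs_per_benchmark, args.device)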

wisent/core/parser_arguments/main_parser.py
CHANGED

@@ -40,6 +40,7 @@ from wisent.core.parser_arguments.train_unified_goodness_parser import setup_tra
 from wisent.core.parser_arguments.optimize_parser import setup_optimize_parser
 from wisent.core.parser_arguments.check_linearity_parser import setup_check_linearity_parser
 from wisent.core.parser_arguments.cluster_benchmarks_parser import setup_cluster_benchmarks_parser
+from wisent.core.parser_arguments.geometry_search_parser import setup_geometry_search_parser
 
 
 def setup_parser() -> argparse.ArgumentParser:
@@ -225,4 +226,11 @@ def setup_parser() -> argparse.ArgumentParser:
     )
     setup_cluster_benchmarks_parser(cluster_benchmarks_parser)
 
+    # Geometry search command - search for unified goodness direction across all benchmarks
+    geometry_search_parser = subparsers.add_parser(
+        "geometry-search",
+        help="Search for unified goodness direction across benchmarks (analyzes structure: linear/cone/orthogonal)"
+    )
+    setup_geometry_search_parser(geometry_search_parser)
+
     return parser

wisent/core/parser_arguments/train_unified_goodness_parser.py
CHANGED

@@ -32,8 +32,8 @@ def setup_train_unified_goodness_parser(parser: argparse.ArgumentParser) -> None
     parser.add_argument(
         "--device",
         type=str,
-        default="
-        help="Device to use (e.g., 'cpu', 'cuda', 'cuda:0')"
+        default="auto",
+        help="Device to use (e.g., 'auto', 'cpu', 'cuda', 'cuda:0', 'mps')"
    )
 
     # Benchmark selection

wisent/core/steering.py
CHANGED

@@ -477,11 +477,13 @@ class SteeringMethod:
         # Get prediction from steering method
         prediction = self.predict_proba(activation)
 
-        # Convert to tensor for loss computation
+        # Convert to tensor for loss computation (use activation's dtype)
         if not isinstance(prediction, torch.Tensor):
-
+            from wisent.core.utils.device import preferred_dtype
+            pred_dtype = activation.dtype if isinstance(activation, torch.Tensor) else preferred_dtype()
+            prediction = torch.tensor(prediction, dtype=pred_dtype, device=self.device)
 
-        target = torch.tensor(label, dtype=
+        target = torch.tensor(label, dtype=prediction.dtype, device=self.device)
 
         # Binary cross-entropy loss
         loss = F.binary_cross_entropy_with_logits(prediction.unsqueeze(0), target.unsqueeze(0))

wisent/core/steering_methods/methods/hyperplane.py
CHANGED

@@ -6,6 +6,7 @@ import numpy as np
 
 from wisent.core.steering_methods.core.atoms import PerLayerBaseSteeringMethod
 from wisent.core.errors import InsufficientDataError
+from wisent.core.utils.device import preferred_dtype
 
 __all__ = [
     "HyperplaneMethod",
@@ -61,7 +62,7 @@ class HyperplaneMethod(PerLayerBaseSteeringMethod):
         clf.fit(X, y)
 
         # Use classifier weights as steering vector
-        v = torch.tensor(clf.coef_[0], dtype=
+        v = torch.tensor(clf.coef_[0], dtype=preferred_dtype())
 
         if bool(self.kwargs.get("normalize", True)):
             v = self._safe_l2_normalize(v)

wisent/core/synthetic/generators/nonsense_generator.py
CHANGED

@@ -16,16 +16,6 @@
 class ProgrammaticNonsenseGenerator:
     """Generate nonsense contrastive pairs programmatically without using LLM."""
 
-    # Word list for word salad mode
-    WORD_LIST = [
-        "purple", "elephant", "calculator", "yesterday", "moon", "basket", "thinking",
-        "telephone", "mountain", "running", "quickly", "tomorrow", "happiness", "keyboard",
-        "window", "dancing", "coffee", "planet", "singing", "computer", "orange", "flying",
-        "bicycle", "dream", "ocean", "pencil", "laughing", "cloud", "table", "walking",
-        "music", "river", "chair", "jumping", "sun", "book", "swimming", "star", "door",
-        "cooking", "tree", "writing", "sky", "flower", "playing", "rain", "paper", "sleeping"
-    ]
-
     def __init__(
         self,
         nonsense_mode: str,
@@ -46,6 +36,18 @@ class ProgrammaticNonsenseGenerator:
         self.contrastive_set_name = contrastive_set_name
         self.trait_label = trait_label
         self.trait_description = trait_description
+        self._valid_words = None
+
+    def set_tokenizer(self, tokenizer) -> None:
+        """Extract valid words from tokenizer vocabulary."""
+        vocab = tokenizer.get_vocab()
+        valid_words = []
+        for token, token_id in vocab.items():
+            decoded = tokenizer.decode([token_id])
+            clean = decoded.strip()
+            if clean.isalpha() and len(clean) > 1 and len(clean) < 15:
+                valid_words.append(clean)
+        self._valid_words = list(set(valid_words))
 
     def generate(self, num_pairs: int = 10) -> ContrastivePairSet:
         """
@@ -108,11 +110,14 @@ class ProgrammaticNonsenseGenerator:
 
     def _generate_repetitive(self) -> str:
         """Generate pathologically repetitive text."""
+        if self._valid_words is None:
+            raise ValueError("Tokenizer must be set. Call set_tokenizer() first.")
+
         # Pick a random word or phrase
         choices = [
             random.choice(string.ascii_lowercase),  # Single letter
-            random.choice(self.
-            ' '.join(random.sample(self.
+            random.choice(self._valid_words),  # Single word
+            ' '.join(random.sample(self._valid_words, 2)),  # Two-word phrase
         ]
         unit = random.choice(choices)
 
@@ -121,13 +126,20 @@ class ProgrammaticNonsenseGenerator:
         return ' '.join([unit] * repetitions)
 
     def _generate_word_salad(self) -> str:
-        """Generate word salad (
-        num_words = random.randint(
-
-
+        """Generate word salad (random tokens from tokenizer vocabulary)."""
+        num_words = random.randint(3, 10)
+
+        if self._valid_words is not None:
+            words = random.choices(self._valid_words, k=num_words)
+            return ' '.join(words)
+
+        raise ValueError("Tokenizer must be set to generate word salad. Call set_tokenizer() first.")
 
     def _generate_mixed(self) -> str:
         """Generate mixed nonsense (combination of all types)."""
+        if self._valid_words is None:
+            raise ValueError("Tokenizer must be set. Call set_tokenizer() first.")
+
         components = []
 
         # Add 2-4 different types of nonsense
@@ -140,11 +152,11 @@ class ProgrammaticNonsenseGenerator:
                 length = random.randint(5, 15)
                 components.append(''.join(random.choices(string.ascii_lowercase, k=length)))
             elif mode == 'repetitive':
-                word = random.choice(self.
+                word = random.choice(self._valid_words)
                 reps = random.randint(3, 6)
                 components.append(' '.join([word] * reps))
             else:  # word_salad
                 num_words = random.randint(3, 6)
-                components.append(' '.join(random.choices(self.
+                components.append(' '.join(random.choices(self._valid_words, k=num_words)))
 
         return ' '.join(components)
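
The generator now draws its word pool from the model's own tokenizer instead of a fixed WORD_LIST. A hedged sketch of the new flow follows; only set_tokenizer() and the nonsense_mode/trait fields are visible in the diff, so the exact constructor signature and argument values here are assumptions.

from transformers import AutoTokenizer
from wisent.core.synthetic.generators.nonsense_generator import ProgrammaticNonsenseGenerator

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
gen = ProgrammaticNonsenseGenerator(
    nonsense_mode="word_salad",
    contrastive_set_name="nonsense_baseline",   # illustrative values
    trait_label="nonsense",
    trait_description="incoherent text",
)
gen.set_tokenizer(tokenizer)       # builds the _valid_words pool from the vocab
pair_set = gen.generate(num_pairs=10)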

wisent/core/trainers/steering_trainer.py
CHANGED

@@ -48,8 +48,8 @@ class WisentSteeringTrainer(BaseSteeringTrainer):
         model: WisentModel to use for activation collection.
         pair_set: ContrastivePairSet with pairs to use for collection and training.
         steering_method: BaseSteeringMethod instance to use for training.
-        store_device: Device to store collected activations on (default "cpu").
-        dtype: Optional torch.dtype to cast collected activations to
+        store_device: Device to store collected activations on (default: "cpu" to avoid GPU OOM).
+        dtype: Optional torch.dtype to cast collected activations to.
     """
 
     model: WisentModel