wisent 0.7.701__py3-none-any.whl → 0.7.1045__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/comparison/__init__.py +1 -0
- wisent/comparison/detect_bos_features.py +275 -0
- wisent/comparison/fgaa.py +465 -0
- wisent/comparison/lora.py +669 -0
- wisent/comparison/lora_dpo.py +592 -0
- wisent/comparison/main.py +444 -0
- wisent/comparison/ours.py +76 -0
- wisent/comparison/sae.py +304 -0
- wisent/comparison/utils.py +381 -0
- wisent/core/activations/activation_cache.py +393 -0
- wisent/core/activations/activations.py +3 -3
- wisent/core/activations/activations_collector.py +12 -7
- wisent/core/activations/classifier_inference_strategy.py +12 -11
- wisent/core/activations/extraction_strategy.py +260 -84
- wisent/core/classifiers/classifiers/core/atoms.py +3 -2
- wisent/core/cli/__init__.py +2 -1
- wisent/core/cli/agent/train_classifier.py +16 -3
- wisent/core/cli/check_linearity.py +35 -3
- wisent/core/cli/cluster_benchmarks.py +4 -6
- wisent/core/cli/create_steering_vector.py +6 -4
- wisent/core/cli/diagnose_vectors.py +7 -4
- wisent/core/cli/estimate_unified_goodness_time.py +6 -4
- wisent/core/cli/generate_pairs_from_task.py +9 -56
- wisent/core/cli/generate_vector_from_task.py +11 -20
- wisent/core/cli/geometry_search.py +137 -0
- wisent/core/cli/get_activations.py +2 -2
- wisent/core/cli/method_optimizer.py +4 -3
- wisent/core/cli/modify_weights.py +3 -2
- wisent/core/cli/optimize_sample_size.py +1 -1
- wisent/core/cli/optimize_steering.py +14 -16
- wisent/core/cli/optimize_weights.py +2 -1
- wisent/core/cli/preview_pairs.py +203 -0
- wisent/core/cli/steering_method_trainer.py +3 -3
- wisent/core/cli/tasks.py +19 -76
- wisent/core/cli/train_unified_goodness.py +3 -3
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
- wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +22 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +10 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +9 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
- wisent/core/data_loaders/loaders/lm_loader.py +12 -1
- wisent/core/geometry_runner.py +995 -0
- wisent/core/geometry_search_space.py +237 -0
- wisent/core/hyperparameter_optimizer.py +1 -1
- wisent/core/main.py +3 -0
- wisent/core/models/core/atoms.py +5 -3
- wisent/core/models/wisent_model.py +1 -1
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/parser_arguments/check_linearity_parser.py +12 -2
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +6 -13
- wisent/core/parser_arguments/geometry_search_parser.py +61 -0
- wisent/core/parser_arguments/get_activations_parser.py +5 -14
- wisent/core/parser_arguments/main_parser.py +8 -0
- wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
- wisent/core/steering.py +5 -3
- wisent/core/steering_methods/methods/hyperplane.py +2 -1
- wisent/core/synthetic/generators/nonsense_generator.py +30 -18
- wisent/core/trainers/steering_trainer.py +2 -2
- wisent/core/utils/device.py +27 -27
- wisent/core/utils/layer_combinations.py +70 -0
- wisent/examples/__init__.py +1 -0
- wisent/examples/scripts/__init__.py +1 -0
- wisent/examples/scripts/count_all_benchmarks.py +121 -0
- wisent/examples/scripts/discover_directions.py +469 -0
- wisent/examples/scripts/extract_benchmark_info.py +71 -0
- wisent/examples/scripts/search_all_short_names.py +31 -0
- wisent/examples/scripts/test_all_benchmarks.py +138 -0
- wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
- wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
- wisent/examples/scripts/test_nonsense_baseline.py +261 -0
- wisent/examples/scripts/test_one_benchmark.py +324 -0
- wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
- wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
- wisent/parameters/lm_eval/category_directions.json +137 -0
- wisent/parameters/lm_eval/repair_plan.json +282 -0
- wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
- wisent/parameters/lm_eval/working_benchmarks.json +206 -0
- wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
- wisent/tests/test_detector_accuracy.py +1 -1
- wisent/tests/visualize_geometry.py +1 -1
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/RECORD +328 -358
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
- wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
- wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cola_pairs.json +0 -8
- wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/2/test_atis_pairs.json +0 -8
- wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babi_pairs.json +0 -8
- wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,995 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Geometry search runner.
|
|
3
|
+
|
|
4
|
+
Runs geometry tests across the search space using cached activations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import random
|
|
11
|
+
import time
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Dict, List, Optional, Any, Tuple
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
from wisent.core.geometry_search_space import GeometrySearchSpace, GeometrySearchConfig
|
|
20
|
+
from wisent.core.activations.extraction_strategy import ExtractionStrategy
|
|
21
|
+
from wisent.core.activations.activation_cache import (
|
|
22
|
+
ActivationCache,
|
|
23
|
+
CachedActivations,
|
|
24
|
+
collect_and_cache_activations,
|
|
25
|
+
)
|
|
26
|
+
from wisent.core.utils.layer_combinations import get_layer_combinations
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def compute_signal_strength(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
    n_folds: int = 5,
) -> float:
    """
    Compute signal strength using MLP cross-validation accuracy.

    This measures whether there is ANY extractable signal (linear or nonlinear)
    that generalizes to unseen data. Random/nonsense data gives ~0.5.

    A small scikit-learn ``MLPClassifier`` (one hidden layer of 16 units,
    ``max_iter=500``, ``random_state=42``) is scored with ``cross_val_score``
    on the concatenated positive (label 1) and negative (label 0) samples.

    The function is deliberately best-effort: any failure (scikit-learn not
    importable, inputs that cannot be concatenated, errors raised during
    fitting) yields the chance-level value 0.5 instead of propagating. The
    failure is logged as a warning so it can be told apart from a genuine
    "no signal" result.

    Args:
        pos_activations: [N, hidden_dim] positive class activations
        neg_activations: [N, hidden_dim] negative class activations
        n_folds: Number of CV folds; clamped to the size of the smaller class

    Returns:
        Cross-validation accuracy (0.5 = no signal, >0.7 = signal exists).
        0.5 is also returned when either class has fewer than 5 samples,
        when fewer than 2 folds are possible, or when an error occurred.
    """
    import logging

    try:
        from sklearn.neural_network import MLPClassifier
        from sklearn.model_selection import cross_val_score

        n_pos = len(pos_activations)
        n_neg = len(neg_activations)

        if n_pos < 5 or n_neg < 5:
            return 0.5  # Not enough data

        X = torch.cat([pos_activations, neg_activations], dim=0).float().cpu().numpy()
        y = np.array([1] * n_pos + [0] * n_neg)

        # Never request more folds than the smaller class has samples.
        n_folds = min(n_folds, min(n_pos, n_neg))
        if n_folds < 2:
            return 0.5

        clf = MLPClassifier(
            hidden_layer_sizes=(16,),
            max_iter=500,
            random_state=42,
        )
        scores = cross_val_score(clf, X, y, cv=n_folds, scoring='accuracy')
        return float(scores.mean())
    except Exception as exc:
        # Keep the best-effort contract (chance level on failure), but make
        # the failure visible instead of silently masking it as "no signal".
        logging.getLogger(__name__).warning(
            "compute_signal_strength failed; returning chance level 0.5: %s",
            exc,
        )
        return 0.5
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def compute_knn_accuracy(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
    k: int = 10,
    n_folds: int = 5,
) -> float:
    """
    Score class separability with a cross-validated k-nearest-neighbour classifier.

    k-NN makes no linearity assumption, so the mean held-out accuracy reflects
    how well the two classes separate locally.

    Args:
        pos_activations: [N, hidden_dim] positive class activations
        neg_activations: [N, hidden_dim] negative class activations
        k: Number of neighbours used by the classifier
        n_folds: Requested number of CV folds (capped by the smaller class size)

    Returns:
        Mean cross-validation accuracy. 0.5 is also the neutral fallback when
        either class has fewer than k + 1 samples, fewer than 2 folds are
        possible, or anything fails (including scikit-learn being unavailable).
    """
    chance_level = 0.5
    try:
        from sklearn.neighbors import KNeighborsClassifier
        from sklearn.model_selection import cross_val_score

        pos_count = len(pos_activations)
        neg_count = len(neg_activations)
        if min(pos_count, neg_count) < k + 1:
            return chance_level

        features = (
            torch.cat([pos_activations, neg_activations], dim=0)
            .float()
            .cpu()
            .numpy()
        )
        labels = np.array([1] * pos_count + [0] * neg_count)

        # Never ask for more folds than the smaller class has samples.
        folds = min(n_folds, min(pos_count, neg_count))
        if folds < 2:
            return chance_level

        neighbours = KNeighborsClassifier(n_neighbors=k)
        fold_scores = cross_val_score(neighbours, features, labels, cv=folds, scoring='accuracy')
        return float(fold_scores.mean())
    except Exception:
        # Best-effort metric: any failure is reported as chance level.
        return chance_level
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def compute_mmd_rbf(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
) -> float:
    """
    Compute Maximum Mean Discrepancy with an RBF kernel.

    The value is the biased estimate of the squared MMD,
    ``mean(K_pp) + mean(K_nn) - 2 * mean(K_pn)``, with
    ``K(x, y) = exp(-gamma * ||x - y||^2)``. ``gamma`` comes from the median
    heuristic: ``1 / (2 * median(d)^2)`` over all non-zero pairwise distances
    of the pooled data. Higher values mean more distinguishable distributions.

    The pooled pairwise distance matrix is computed once and the three kernel
    blocks are sliced out of it, instead of recomputing distances per block.

    Args:
        pos_activations: [N, hidden_dim] positive class activations
        neg_activations: [N, hidden_dim] negative class activations

    Returns:
        Non-negative MMD value (0 = identical distributions). 0.0 is also
        returned when a class is empty, when every pooled point coincides,
        or when anything fails.
    """
    try:
        from scipy.spatial.distance import cdist

        pos = pos_activations.float().cpu().numpy()
        neg = neg_activations.float().cpu().numpy()

        m = len(pos)
        n = len(neg)
        if m == 0 or n == 0:
            # Block means are undefined for an empty class.
            return 0.0

        all_data = np.vstack([pos, neg])
        dists = cdist(all_data, all_data, 'euclidean')

        nonzero_dists = dists[dists > 0]
        if nonzero_dists.size == 0:
            # Every point coincides, so the two samples are indistinguishable.
            return 0.0

        # Median heuristic for the kernel bandwidth.
        gamma = 1.0 / (2 * np.median(nonzero_dists) ** 2 + 1e-10)

        # Rows/columns [0, m) are positives, [m, m + n) are negatives.
        kernel = np.exp(-gamma * dists ** 2)
        mmd = (
            kernel[:m, :m].mean()
            + kernel[m:, m:].mean()
            - 2 * kernel[:m, m:].mean()
        )

        # Clamp tiny negative values caused by rounding.
        return float(max(0.0, mmd))
    except Exception:
        # Best-effort metric: any failure is reported as "no discrepancy".
        return 0.0
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def estimate_local_intrinsic_dim(X: np.ndarray, k: int = 10) -> float:
    """
    Estimate local intrinsic dimensionality with a k-NN maximum-likelihood estimator.

    For each point the estimate is ``-(k - 1) / sum_j log(T_j / T_k)`` where
    ``T_j`` is the distance to its j-th nearest neighbour (Levina & Bickel, 2004).
    Per-point estimates are capped at the ambient dimension and their median
    is returned.

    Args:
        X: [N, D] data matrix
        k: Number of neighbours used per point

    Returns:
        Median per-point estimate. Falls back to D when there are fewer than
        k + 1 points or no point yields a usable estimate.
    """
    from scipy.spatial.distance import cdist

    if len(X) < k + 1:
        return float(X.shape[1])

    pairwise = cdist(X, X, 'euclidean')
    # A point must not count as its own neighbour.
    np.fill_diagonal(pairwise, np.inf)
    nearest = np.sort(pairwise, axis=1)[:, :k]
    ambient_dim = X.shape[1]

    estimates = []
    for neighbour_dists in nearest:
        radius = neighbour_dists[k - 1]
        if radius < 1e-10:
            # The k nearest neighbours all coincide with the point.
            continue
        # The small offset keeps log() finite when a closer neighbour is a duplicate.
        log_terms = np.log(neighbour_dists[:k - 1] / radius + 1e-10)
        if len(log_terms) == 0:
            continue
        log_total = log_terms.sum()
        if log_total < 0:
            estimates.append(min(-(k - 1) / log_total, ambient_dim))

    if not estimates:
        return float(ambient_dim)
    return float(np.median(estimates))
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def compute_local_intrinsic_dims(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
    k: int = 10,
) -> tuple:
    """
    Estimate the local intrinsic dimension of each class separately.

    A ratio far from 1 suggests the two classes have different geometric
    structure.

    Args:
        pos_activations: [N, hidden_dim] positive class activations
        neg_activations: [N, hidden_dim] negative class activations
        k: Number of neighbours passed to estimate_local_intrinsic_dim

    Returns:
        (local_dim_pos, local_dim_neg, ratio) with ratio = pos / neg.
        (0.0, 0.0, 1.0) is returned if anything fails.
    """
    try:
        pos_np = pos_activations.float().cpu().numpy()
        neg_np = neg_activations.float().cpu().numpy()

        dim_pos, dim_neg = (
            estimate_local_intrinsic_dim(points, k) for points in (pos_np, neg_np)
        )

        # The small offset guards the division against a zero denominator.
        return dim_pos, dim_neg, dim_pos / (dim_neg + 1e-10)
    except Exception:
        # Best-effort metric: report neutral placeholders on failure.
        return 0.0, 0.0, 1.0
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def compute_fisher_per_dimension(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
) -> dict:
    """
    Compute the Fisher ratio of every dimension and summarize the distribution.

    For dimension d the ratio is
    ``(mean_pos - mean_neg)^2 / ((var_pos + var_neg) / 2)`` using population
    variances. Dimensions whose pooled variance is not above 1e-10 get 0.
    All dimensions are processed with vectorized NumPy operations.

    Args:
        pos_activations: [N, hidden_dim] positive class activations
        neg_activations: [N, hidden_dim] negative class activations

    Returns:
        Dict with:
            fisher_max: largest per-dimension ratio
            fisher_gini: Gini coefficient of the ratios (0 when they sum to ~0)
            fisher_top10_ratio: share of the total ratio held by the top 10 dims
            num_dims_fisher_above_1: number of dimensions with ratio > 1
        All entries are zero if anything fails.
    """
    fallback = {
        "fisher_max": 0.0,
        "fisher_gini": 0.0,
        "fisher_top10_ratio": 0.0,
        "num_dims_fisher_above_1": 0,
    }
    try:
        pos = pos_activations.float().cpu().numpy()
        neg = neg_activations.float().cpu().numpy()

        n_dims = pos.shape[1]
        # Only the first n_dims columns of neg are used; a narrower neg is an error.
        neg = neg[:, :n_dims]
        if neg.shape[1] != n_dims:
            return fallback

        between_var = (pos.mean(axis=0) - neg.mean(axis=0)) ** 2
        within_var = (pos.var(axis=0) + neg.var(axis=0)) / 2

        # Dimensions with (near-)zero or undefined pooled variance keep ratio 0.
        fishers = np.zeros(n_dims)
        informative = within_var > 1e-10
        fishers[informative] = between_var[informative] / within_var[informative]

        fisher_max = float(fishers.max())

        # Gini coefficient over the ascending-sorted (non-negative) ratios.
        ascending = np.sort(fishers)
        total = ascending.sum()
        if total > 1e-10:
            n = len(ascending)
            ranks = np.arange(1, n + 1)
            fisher_gini = (2 * np.sum(ranks * ascending) / (n * total)) - (n + 1) / n
        else:
            fisher_gini = 0.0

        # Share of the total ratio carried by the ten strongest dimensions.
        top10_sum = ascending[::-1][:10].sum()
        fisher_top10_ratio = float(top10_sum / (fishers.sum() + 1e-10))

        num_dims_above_1 = int((fishers > 1.0).sum())

        return {
            "fisher_max": fisher_max,
            "fisher_gini": float(fisher_gini),
            "fisher_top10_ratio": fisher_top10_ratio,
            "num_dims_fisher_above_1": num_dims_above_1,
        }
    except Exception:
        # Best-effort metric: report all-zero statistics on failure.
        return fallback
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def compute_density_ratio(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
) -> float:
    """
    Compare how spread out each class is.

    The spread of a class is the mean Euclidean distance between its distinct
    points; the result is the positive spread divided by the negative spread.
    Values far from 1 suggest different local geometries.

    Args:
        pos_activations: [N, hidden_dim] positive class activations
        neg_activations: [N, hidden_dim] negative class activations

    Returns:
        pos mean distance / neg mean distance. 1.0 is returned when either
        class has fewer than 2 points, when the negative spread is ~0, or
        when anything fails.
    """
    neutral = 1.0
    try:
        from scipy.spatial.distance import cdist

        pos_np = pos_activations.float().cpu().numpy()
        neg_np = neg_activations.float().cpu().numpy()

        if min(len(pos_np), len(neg_np)) < 2:
            return neutral

        spreads = []
        for points in (pos_np, neg_np):
            pairwise = cdist(points, points, 'euclidean')
            # Mask self-distances so they do not pull the mean toward zero.
            np.fill_diagonal(pairwise, np.nan)
            spreads.append(np.nanmean(pairwise))
        spread_pos, spread_neg = spreads

        if spread_neg < 1e-10:
            return neutral

        return float(spread_pos / spread_neg)
    except Exception:
        # Best-effort metric: report the neutral ratio on failure.
        return neutral
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def compute_linear_probe_accuracy(
    pos_activations: torch.Tensor,
    neg_activations: torch.Tensor,
    n_folds: int = 5,
) -> float:
    """
    Score linear separability with a cross-validated logistic-regression probe.

    Read together with compute_signal_strength: a high MLP score with a low
    linear score points to a nonlinear signal, while two high scores point to
    a linear one (where CAA should work).

    Args:
        pos_activations: [N, hidden_dim] positive class activations
        neg_activations: [N, hidden_dim] negative class activations
        n_folds: Requested number of CV folds (capped by the smaller class size)

    Returns:
        Mean cross-validation accuracy. 0.5 is also the neutral fallback when
        either class has fewer than 5 samples, fewer than 2 folds are possible,
        or anything fails (including scikit-learn being unavailable).
    """
    chance_level = 0.5
    try:
        from sklearn.linear_model import LogisticRegression
        from sklearn.model_selection import cross_val_score

        pos_count = len(pos_activations)
        neg_count = len(neg_activations)
        if min(pos_count, neg_count) < 5:
            return chance_level

        features = (
            torch.cat([pos_activations, neg_activations], dim=0)
            .float()
            .cpu()
            .numpy()
        )
        labels = np.array([1] * pos_count + [0] * neg_count)

        # Never ask for more folds than the smaller class has samples.
        folds = min(n_folds, min(pos_count, neg_count))
        if folds < 2:
            return chance_level

        probe = LogisticRegression(max_iter=1000, solver='lbfgs')
        fold_scores = cross_val_score(probe, features, labels, cv=folds, scoring='accuracy')
        return float(fold_scores.mean())
    except Exception:
        # Best-effort metric: any failure is reported as "no linear signal".
        return chance_level
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
@dataclass
class GeometryTestResult:
    """Metrics recorded for one geometry test (benchmark, strategy, layer set)."""

    # --- Identity of the tested configuration ---
    benchmark: str
    strategy: str
    layers: List[int]

    # --- Step 1: is there any signal? ---
    signal_strength: float  # cross-validated MLP accuracy; about 0.5 means no signal
    has_signal: bool  # set by compute_geometry_metrics as signal_strength > 0.6

    # --- Step 2: is the signal linear? ---
    linear_probe_accuracy: float  # cross-validated linear-probe accuracy
    is_linear: bool  # has signal, linear accuracy > 0.6 and close to signal_strength

    # --- Nonlinear signal metrics ---
    knn_accuracy_k5: float  # k-NN CV accuracy, k=5
    knn_accuracy_k10: float  # k-NN CV accuracy, k=10
    knn_accuracy_k20: float  # k-NN CV accuracy, k=20
    mmd_rbf: float  # Maximum Mean Discrepancy with an RBF kernel
    local_dim_pos: float  # local intrinsic dimension, positive class
    local_dim_neg: float  # local intrinsic dimension, negative class
    local_dim_ratio: float  # local_dim_pos / local_dim_neg
    fisher_max: float  # largest per-dimension Fisher ratio
    fisher_gini: float  # Gini coefficient of the Fisher ratios (concentration)
    fisher_top10_ratio: float  # share of total Fisher ratio in the top 10 dims
    num_dims_fisher_above_1: int  # count of dimensions with Fisher ratio > 1
    density_ratio: float  # ratio of mean intra-class distances (pos / neg)

    # --- Step 3: geometry (only meaningful when has_signal is True) ---
    # Winning structure: 'linear', 'cone', 'cluster', 'manifold', 'sparse',
    # 'bimodal' or 'orthogonal'.
    best_structure: str
    best_score: float

    # Score of every candidate structure.
    linear_score: float
    cone_score: float
    orthogonal_score: float
    manifold_score: float
    sparse_score: float
    cluster_score: float
    bimodal_score: float

    # --- Per-structure detail metrics ---
    # Linear
    cohens_d: float  # separation quality
    variance_explained: float  # by the primary direction
    within_class_consistency: float

    # Cone
    raw_mean_cosine_similarity: float  # between difference vectors
    positive_correlation_fraction: float  # fraction in the same half-space

    # Orthogonal
    near_zero_fraction: float  # fraction of near-zero correlations

    # Manifold
    pca_top2_variance: float  # variance captured by the top 2 PCs
    local_nonlinearity: float  # curvature measure

    # Sparse
    gini_coefficient: float  # inequality of activations
    active_fraction: float  # fraction of active neurons
    top_10_contribution: float  # contribution of the top 10 neurons

    # Cluster
    best_silhouette: float  # clustering quality
    best_k: int  # optimal number of clusters

    # --- Recommendation ---
    recommended_method: str

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a nested plain dict, grouped by analysis step."""

        def pick(*names: str) -> Dict[str, Any]:
            # Collect the named attributes, keeping the given order.
            return {name: getattr(self, name) for name in names}

        # Identity, Step 1 (signal detection) and Step 2 (linearity check).
        payload: Dict[str, Any] = pick(
            "benchmark",
            "strategy",
            "layers",
            "signal_strength",
            "has_signal",
            "linear_probe_accuracy",
            "is_linear",
        )

        payload["nonlinear_metrics"] = pick(
            "knn_accuracy_k5",
            "knn_accuracy_k10",
            "knn_accuracy_k20",
            "mmd_rbf",
            "local_dim_pos",
            "local_dim_neg",
            "local_dim_ratio",
            "fisher_max",
            "fisher_gini",
            "fisher_top10_ratio",
            "num_dims_fisher_above_1",
            "density_ratio",
        )

        # Step 3: geometry (only meaningful when has_signal is True).
        payload.update(pick("best_structure", "best_score"))
        payload["structure_scores"] = {
            kind: getattr(self, f"{kind}_score")
            for kind in (
                "linear",
                "cone",
                "orthogonal",
                "manifold",
                "sparse",
                "cluster",
                "bimodal",
            )
        }

        payload["linear_details"] = pick(
            "cohens_d",
            "variance_explained",
            "within_class_consistency",
        )
        payload["cone_details"] = pick(
            "raw_mean_cosine_similarity",
            "positive_correlation_fraction",
        )
        payload["orthogonal_details"] = pick("near_zero_fraction")
        payload["manifold_details"] = pick("pca_top2_variance", "local_nonlinearity")
        payload["sparse_details"] = pick(
            "gini_coefficient",
            "active_fraction",
            "top_10_contribution",
        )
        payload["cluster_details"] = pick("best_silhouette", "best_k")

        payload["recommended_method"] = self.recommended_method
        return payload
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
@dataclass
class GeometrySearchResults:
    """Accumulated outcome of a full geometry search, with query helpers."""

    model_name: str
    config: GeometrySearchConfig
    results: List[GeometryTestResult] = field(default_factory=list)

    # Wall-clock timing, in seconds.
    total_time_seconds: float = 0.0
    extraction_time_seconds: float = 0.0
    test_time_seconds: float = 0.0

    # How much of the search space was covered.
    benchmarks_tested: int = 0
    strategies_tested: int = 0
    layer_combos_tested: int = 0

    def add_result(self, result: GeometryTestResult) -> None:
        """Append one test result to the collection."""
        self.results.append(result)

    def get_best_by_linear_score(self, n: int = 10) -> List[GeometryTestResult]:
        """Return the n results with the highest linear score, best first."""
        ranked = sorted(self.results, key=lambda r: r.linear_score, reverse=True)
        return ranked[:n]

    def get_best_by_structure(self, structure: str, n: int = 10) -> List[GeometryTestResult]:
        """Return the n results with the highest ``<structure>_score``, best first.

        Results lacking that attribute are ranked as if their score were 0.0.
        """
        attribute = f"{structure}_score"

        def score_of(entry: GeometryTestResult) -> float:
            return getattr(entry, attribute, 0.0)

        ranked = sorted(self.results, key=score_of, reverse=True)
        return ranked[:n]

    def get_structure_distribution(self) -> Dict[str, int]:
        """Count how many results name each structure as their best one."""
        tally: Dict[str, int] = {}
        for entry in self.results:
            label = entry.best_structure
            tally[label] = tally.get(label, 0) + 1
        return tally

    def get_summary_by_benchmark(self) -> Dict[str, Dict[str, float]]:
        """Summarize linear scores (mean, max, min, count) per benchmark."""
        grouped: Dict[str, List[float]] = {}
        for entry in self.results:
            grouped.setdefault(entry.benchmark, []).append(entry.linear_score)

        summary: Dict[str, Dict[str, float]] = {}
        for bench, scores in grouped.items():
            summary[bench] = {
                "mean": sum(scores) / len(scores),
                "max": max(scores),
                "min": min(scores),
                "count": len(scores),
            }
        return summary

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the search, including every result."""
        serialized_results = [entry.to_dict() for entry in self.results]
        return {
            "model_name": self.model_name,
            "config": self.config.to_dict(),
            "total_time_seconds": self.total_time_seconds,
            "extraction_time_seconds": self.extraction_time_seconds,
            "test_time_seconds": self.test_time_seconds,
            "benchmarks_tested": self.benchmarks_tested,
            "strategies_tested": self.strategies_tested,
            "layer_combos_tested": self.layer_combos_tested,
            "results": serialized_results,
        }

    def save(self, path: str) -> None:
        """Write the dict form of the search to ``path`` as indented JSON."""
        with open(path, "w") as f:
            json.dump(self.to_dict(), f, indent=2)
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
def _failed_geometry_result(
    cached: CachedActivations,
    layers: List[int],
    message: str,
) -> GeometryTestResult:
    """Build the neutral placeholder result used when analysis cannot be completed."""
    return GeometryTestResult(
        benchmark=cached.benchmark,
        strategy=cached.strategy.value,
        layers=layers,
        signal_strength=0.5,
        has_signal=False,
        linear_probe_accuracy=0.5,
        is_linear=False,
        # Nonlinear metrics
        knn_accuracy_k5=0.5,
        knn_accuracy_k10=0.5,
        knn_accuracy_k20=0.5,
        mmd_rbf=0.0,
        local_dim_pos=0.0,
        local_dim_neg=0.0,
        local_dim_ratio=1.0,
        fisher_max=0.0,
        fisher_gini=0.0,
        fisher_top10_ratio=0.0,
        num_dims_fisher_above_1=0,
        density_ratio=1.0,
        # Structure scores
        best_structure="error",
        best_score=0.0,
        linear_score=0.0,
        cone_score=0.0,
        orthogonal_score=0.0,
        manifold_score=0.0,
        sparse_score=0.0,
        cluster_score=0.0,
        bimodal_score=0.0,
        # Per-structure details
        cohens_d=0.0,
        variance_explained=0.0,
        within_class_consistency=0.0,
        raw_mean_cosine_similarity=0.0,
        positive_correlation_fraction=0.0,
        near_zero_fraction=0.0,
        pca_top2_variance=0.0,
        local_nonlinearity=0.0,
        gini_coefficient=0.0,
        active_fraction=0.0,
        top_10_contribution=0.0,
        best_silhouette=0.0,
        best_k=0,
        recommended_method=message,
    )


def _structure_score(all_scores, name: str) -> float:
    """Return the score stored for structure ``name``, or 0.0 if it is absent."""
    if name in all_scores:
        return all_scores[name].score
    return 0.0


def compute_geometry_metrics(
    cached: CachedActivations,
    layers: List[int],
) -> GeometryTestResult:
    """
    Compute geometry metrics for a layer combination from cached activations.

    Activations of the requested layers are concatenated along the feature
    axis and then analysed in three steps: MLP signal strength, linear-probe
    accuracy, and the nonlinear metrics. detect_geometry_structure() supplies
    the structure scores for linear, cone, cluster, manifold, sparse, bimodal
    and orthogonal.

    Args:
        cached: Cached activations with all layers
        layers: Layer indices (0-based) to analyze

    Returns:
        GeometryTestResult with all structure scores. If none of the layers is
        present in the cache, or the analysis raises, a placeholder result is
        returned with best_structure="error" and the reason in
        recommended_method.
    """
    from wisent.core.contrastive_pairs.diagnostics.control_vectors import (
        detect_geometry_structure,
        GeometryAnalysisConfig,
    )

    # Gather positive and negative activations for the requested layers.
    pos_acts_list = []
    neg_acts_list = []

    for layer_idx in layers:
        layer_name = str(layer_idx + 1)  # cache keys are 1-based layer names
        try:
            pos = cached.get_positive_activations(layer_name)  # [num_pairs, hidden_size]
            neg = cached.get_negative_activations(layer_name)  # [num_pairs, hidden_size]
        except (KeyError, IndexError):
            # Layer missing from the cache: skip it. The result still records
            # the requested `layers` list unchanged.
            continue
        pos_acts_list.append(pos)
        neg_acts_list.append(neg)

    if not pos_acts_list:
        return _failed_geometry_result(cached, layers, "error: no activations")

    # Concatenate across layers: [num_pairs, hidden_size * num_layers].
    # Convert to float32 (bf16/float16 can cause dtype mismatches downstream).
    pos_activations = torch.cat(pos_acts_list, dim=-1).float()
    neg_activations = torch.cat(neg_acts_list, dim=-1).float()

    config = GeometryAnalysisConfig(
        num_components=5,
        optimization_steps=50,  # Reduced for speed since we're testing many combos
    )

    try:
        result = detect_geometry_structure(pos_activations, neg_activations, config)

        # Step 1: Compute signal strength (MLP CV accuracy)
        signal_strength = compute_signal_strength(pos_activations, neg_activations)
        has_signal = signal_strength > 0.6

        # Step 2: Compute linear probe accuracy
        linear_probe_accuracy = compute_linear_probe_accuracy(pos_activations, neg_activations)
        # Signal is linear if: has signal AND the linear probe clears 0.6 AND it
        # trails the MLP by less than 0.15.
        is_linear = (
            has_signal
            and linear_probe_accuracy > 0.6
            and (signal_strength - linear_probe_accuracy) < 0.15
        )

        # Step 2b: Compute nonlinear signal metrics
        knn_k5 = compute_knn_accuracy(pos_activations, neg_activations, k=5)
        knn_k10 = compute_knn_accuracy(pos_activations, neg_activations, k=10)
        knn_k20 = compute_knn_accuracy(pos_activations, neg_activations, k=20)
        mmd = compute_mmd_rbf(pos_activations, neg_activations)
        local_dim_pos, local_dim_neg, local_dim_ratio = compute_local_intrinsic_dims(
            pos_activations, neg_activations
        )
        fisher_stats = compute_fisher_per_dimension(pos_activations, neg_activations)
        density_rat = compute_density_ratio(pos_activations, neg_activations)

        # Determine recommendation based on signal analysis
        if not has_signal:
            recommendation = "NO_SIGNAL"
        elif is_linear:
            recommendation = "CAA"  # Linear signal -> use Contrastive Activation Addition
        else:
            recommendation = "NONLINEAR"  # Nonlinear signal -> need different method

        all_scores = result.all_scores

        def get_detail(struct_name: str, key: str, default=0.0):
            # Look up one detail metric, tolerating a missing structure or key.
            if struct_name in all_scores:
                return all_scores[struct_name].details.get(key, default)
            return default

        return GeometryTestResult(
            benchmark=cached.benchmark,
            strategy=cached.strategy.value,
            layers=layers,
            signal_strength=signal_strength,
            has_signal=has_signal,
            linear_probe_accuracy=linear_probe_accuracy,
            is_linear=is_linear,
            # Nonlinear metrics
            knn_accuracy_k5=knn_k5,
            knn_accuracy_k10=knn_k10,
            knn_accuracy_k20=knn_k20,
            mmd_rbf=mmd,
            local_dim_pos=local_dim_pos,
            local_dim_neg=local_dim_neg,
            local_dim_ratio=local_dim_ratio,
            fisher_max=fisher_stats["fisher_max"],
            fisher_gini=fisher_stats["fisher_gini"],
            fisher_top10_ratio=fisher_stats["fisher_top10_ratio"],
            num_dims_fisher_above_1=fisher_stats["num_dims_fisher_above_1"],
            density_ratio=density_rat,
            # Structure scores
            best_structure=result.best_structure.value,
            best_score=result.best_score,
            linear_score=_structure_score(all_scores, "linear"),
            cone_score=_structure_score(all_scores, "cone"),
            orthogonal_score=_structure_score(all_scores, "orthogonal"),
            manifold_score=_structure_score(all_scores, "manifold"),
            sparse_score=_structure_score(all_scores, "sparse"),
            cluster_score=_structure_score(all_scores, "cluster"),
            bimodal_score=_structure_score(all_scores, "bimodal"),
            # Linear details
            cohens_d=get_detail("linear", "cohens_d", 0.0),
            variance_explained=get_detail("linear", "variance_explained", 0.0),
            within_class_consistency=get_detail("linear", "within_class_consistency", 0.0),
            # Cone details
            raw_mean_cosine_similarity=get_detail("cone", "raw_mean_cosine_similarity", 0.0),
            positive_correlation_fraction=get_detail("cone", "positive_correlation_fraction", 0.0),
            # Orthogonal details
            near_zero_fraction=get_detail("orthogonal", "near_zero_fraction", 0.0),
            # Manifold details
            pca_top2_variance=get_detail("manifold", "pca_top2_variance", 0.0),
            local_nonlinearity=get_detail("manifold", "local_nonlinearity", 0.0),
            # Sparse details
            gini_coefficient=get_detail("sparse", "gini_coefficient", 0.0),
            active_fraction=get_detail("sparse", "active_fraction", 0.0),
            top_10_contribution=get_detail("sparse", "top_10_contribution", 0.0),
            # Cluster details
            best_silhouette=get_detail("cluster", "best_silhouette", 0.0),
            best_k=int(get_detail("cluster", "best_k", 2)),
            # Recommendation based on signal analysis
            recommended_method=recommendation,
        )
    except Exception as e:
        # Best-effort: one failing configuration must not abort the whole
        # search, so the error text is carried in the placeholder result.
        return _failed_geometry_result(cached, layers, f"error: {str(e)}")
|
|
837
|
+
|
|
838
|
+
|
|
839
|
+
class GeometryRunner:
    """
    Runs geometry search across the search space.

    Uses activation caching for efficiency:
    1. Extract ALL layers once per (benchmark, strategy)
    2. Test all layer combinations from cache
    """

    def __init__(
        self,
        search_space: GeometrySearchSpace,
        model: "WisentModel",
        cache_dir: Optional[str] = None,
    ):
        """
        Bind the runner to a search space and a model, and open the cache.

        Args:
            search_space: Source of the default benchmarks, strategies and
                config used by ``run`` and ``_load_pairs``.
            model: Model to analyse; ``model_name`` and ``num_layers`` are
                read from it.
            cache_dir: Directory handed to ``ActivationCache``. When omitted
                (or empty) a per-model directory under ``/tmp`` is used, with
                ``/`` in the model name replaced by ``_``.
        """
        self.search_space = search_space
        self.model = model
        self.cache_dir = cache_dir or f"/tmp/wisent_geometry_cache_{model.model_name.replace('/', '_')}"
        self.cache = ActivationCache(self.cache_dir)

    def run(
        self,
        benchmarks: Optional[List[str]] = None,
        strategies: Optional[List[ExtractionStrategy]] = None,
        max_layer_combo_size: Optional[int] = None,
        show_progress: bool = True,
    ) -> GeometrySearchResults:
        """
        Run the geometry search.

        For every (benchmark, strategy) pair the activations are fetched
        (from cache or by extraction) and geometry metrics are computed for
        every layer combination. A pair whose activations cannot be obtained
        is skipped: the exception is caught, reported when ``show_progress``
        is set, and the search continues with the next pair.

        Args:
            benchmarks: Benchmarks to test (default: all from search space).
                An empty list also falls back to the default.
            strategies: Strategies to test (default: all from search space).
                An empty list also falls back to the default.
            max_layer_combo_size: Override max layer combo size. ``None`` or
                ``0`` falls back to the search-space config value.
            show_progress: Print progress

        Returns:
            GeometrySearchResults with all test results, the summary
            counters, and the total / extraction / test timings. Time spent
            in extractions that failed is not included in the extraction
            timing.
        """
        benchmarks = benchmarks or self.search_space.benchmarks
        strategies = strategies or self.search_space.strategies
        max_combo = max_layer_combo_size or self.search_space.config.max_layer_combo_size

        # Get layer combinations
        num_layers = self.model.num_layers
        layer_combos = get_layer_combinations(num_layers, max_combo)

        results = GeometrySearchResults(
            model_name=self.model.model_name,
            config=self.search_space.config,
        )

        start_time = time.time()
        extraction_time = 0.0
        test_time = 0.0

        total_extractions = len(benchmarks) * len(strategies)
        extraction_count = 0

        for benchmark in benchmarks:
            for strategy in strategies:
                extraction_count += 1

                if show_progress:
                    print(f"\n[{extraction_count}/{total_extractions}] {benchmark} / {strategy.value}")

                # Get or create cached activations
                extract_start = time.time()
                try:
                    cached = self._get_cached_activations(benchmark, strategy, show_progress)
                except Exception as e:
                    # Best effort: one failing benchmark/strategy must not
                    # abort the whole search.
                    if show_progress:
                        print(f"  SKIP: {e}")
                    continue
                extraction_time += time.time() - extract_start

                # Test all layer combinations
                test_start = time.time()
                for combo in layer_combos:
                    results.add_result(compute_geometry_metrics(cached, combo))
                test_time += time.time() - test_start

                if show_progress:
                    print(f"  Tested {len(layer_combos)} layer combos")

        # Summary counters depend only on the accumulated results, so they
        # are computed once here instead of rescanning every result after
        # each (benchmark, strategy) pair.
        results.benchmarks_tested = len({r.benchmark for r in results.results})
        results.strategies_tested = len({r.strategy for r in results.results})
        results.layer_combos_tested = len(results.results)

        results.total_time_seconds = time.time() - start_time
        results.extraction_time_seconds = extraction_time
        results.test_time_seconds = test_time

        return results

    def _get_cached_activations(
        self,
        benchmark: str,
        strategy: ExtractionStrategy,
        show_progress: bool = True,
    ) -> CachedActivations:
        """
        Get cached activations, extracting if necessary.

        Args:
            benchmark: Benchmark name used as part of the cache key.
            strategy: Extraction strategy used as part of the cache key.
            show_progress: Print progress

        Returns:
            The cache entry for (model name, benchmark, strategy) when one
            exists; otherwise the result of ``collect_and_cache_activations``
            run on freshly loaded pairs.
        """
        model_name = self.model.model_name

        # Check cache
        if self.cache.has(model_name, benchmark, strategy):
            if show_progress:
                print("  Loading from cache...")
            return self.cache.get(model_name, benchmark, strategy)

        # Need to extract - load pairs first
        if show_progress:
            print("  Loading pairs...")

        pairs = self._load_pairs(benchmark)

        if show_progress:
            print(f"  Extracting activations for {len(pairs)} pairs...")

        return collect_and_cache_activations(
            model=self.model,
            pairs=pairs,
            benchmark=benchmark,
            strategy=strategy,
            cache=self.cache,
            show_progress=show_progress,
        )

    def _load_pairs(self, benchmark: str) -> List:
        """
        Load contrastive pairs for a benchmark.

        At most ``config.pairs_per_benchmark`` pairs are returned. When the
        builder yields more than that, a reproducible random subset is drawn
        using ``config.random_seed``.

        Args:
            benchmark: Benchmark (lm-eval task or group) name.

        Returns:
            List of contrastive pairs.
        """
        # Imported lazily so these packages are only required when pairs
        # actually have to be built.
        from lm_eval.tasks import TaskManager
        from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import lm_build_contrastive_pairs

        config = self.search_space.config
        limit = config.pairs_per_benchmark

        tm = TaskManager()
        try:
            task_dict = tm.load_task_or_group([benchmark])
            task = list(task_dict.values())[0]
        except Exception:
            # NOTE(review): deliberately best-effort - presumably the pair
            # builder can resolve benchmarks that lm-eval cannot load when it
            # is given task=None; confirm against lm_build_contrastive_pairs.
            task = None

        pairs = lm_build_contrastive_pairs(
            benchmark,
            task,
            limit=limit,
        )

        # Random sample if we have more pairs than needed. A private RNG
        # seeded with the configured seed yields the same sample as
        # reseeding the module-level generator would, without resetting the
        # global random state for the rest of the process.
        if len(pairs) > limit:
            pairs = random.Random(config.random_seed).sample(pairs, limit)

        return pairs
|
|
990
|
+
|
|
991
|
+
|
|
992
|
+
# Type hints
|
|
993
|
+
from typing import TYPE_CHECKING
|
|
994
|
+
if TYPE_CHECKING:
|
|
995
|
+
from wisent.core.models.wisent_model import WisentModel
|