wisent 0.7.379__py3-none-any.whl → 0.7.901__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/core/activations/__init__.py +22 -6
- wisent/core/activations/activation_cache.py +393 -0
- wisent/core/activations/activations.py +22 -40
- wisent/core/activations/activations_collector.py +145 -373
- wisent/core/activations/classifier_inference_strategy.py +195 -0
- wisent/core/activations/core/atoms.py +8 -92
- wisent/core/activations/extraction_strategy.py +480 -0
- wisent/core/agent/diagnose/response_diagnostics.py +3 -3
- wisent/core/agent/diagnose.py +3 -3
- wisent/core/autonomous_agent.py +2 -2
- wisent/core/classifiers/classifiers/core/atoms.py +3 -2
- wisent/core/cli/__init__.py +2 -1
- wisent/core/cli/agent/apply_steering.py +25 -31
- wisent/core/cli/agent/evaluate_response.py +18 -20
- wisent/core/cli/agent/train_classifier.py +36 -26
- wisent/core/cli/check_linearity.py +35 -3
- wisent/core/cli/cluster_benchmarks.py +470 -0
- wisent/core/cli/create_steering_vector.py +19 -9
- wisent/core/cli/diagnose_vectors.py +7 -4
- wisent/core/cli/estimate_unified_goodness_time.py +6 -4
- wisent/core/cli/generate_pairs_from_task.py +9 -56
- wisent/core/cli/generate_vector_from_task.py +4 -0
- wisent/core/cli/geometry_search.py +137 -0
- wisent/core/cli/get_activations.py +13 -37
- wisent/core/cli/method_optimizer.py +860 -0
- wisent/core/cli/modify_weights.py +3 -2
- wisent/core/cli/optimize.py +44 -5
- wisent/core/cli/optimize_classification.py +5 -6
- wisent/core/cli/optimize_sample_size.py +9 -23
- wisent/core/cli/optimize_steering.py +433 -159
- wisent/core/cli/optimize_weights.py +67 -7
- wisent/core/cli/preview_pairs.py +203 -0
- wisent/core/cli/steering_method_trainer.py +8 -7
- wisent/core/cli/steering_search_space.py +20 -15
- wisent/core/cli/tasks.py +31 -117
- wisent/core/cli/train_unified_goodness.py +18 -19
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +1582 -177
- wisent/core/contrastive_pairs/diagnostics/linearity.py +70 -80
- wisent/core/contrastive_pairs/diagnostics/vector_quality.py +6 -5
- wisent/core/contrastive_pairs/huggingface_pairs/hf_extractor_manifest.py +5 -19
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/__init__.py +11 -5
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/apps.py +146 -32
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue.py +2 -2
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/humaneval.py +98 -57
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
- wisent/core/contrastive_pairs/lm_eval_pairs/group_task_manifests/code_x_glue.py +8 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/group_task_manifests/freebase.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +11 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/agieval_aqua_rat.py +129 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/code_x_glue.py +11 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mbpp.py +47 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
- wisent/core/data_loaders/loaders/lm_loader.py +12 -1
- wisent/core/evaluators/benchmark_specific/apps_evaluator.py +133 -0
- wisent/core/evaluators/benchmark_specific/coding/metrics/evaluator.py +6 -1
- wisent/core/evaluators/benchmark_specific/conala_evaluator.py +31 -168
- wisent/core/evaluators/custom/examples/humanization_coherent.py +89 -35
- wisent/core/evaluators/oracles/truthfulqa_gen_evaluator.py +2 -20
- wisent/core/evaluators/personalization/coherence.py +46 -0
- wisent/core/geometry_runner.py +995 -0
- wisent/core/geometry_search_space.py +237 -0
- wisent/core/hyperparameter_optimizer.py +14 -14
- wisent/core/lm_eval_harness_ground_truth.py +7 -11
- wisent/core/main.py +6 -0
- wisent/core/models/core/atoms.py +5 -3
- wisent/core/models/wisent_model.py +9 -8
- wisent/core/opti/methods/opti_weights.py +29 -2
- wisent/core/optuna/classifier/activation_generator.py +14 -12
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/optuna/steering/steering_optimization.py +14 -9
- wisent/core/parser_arguments/check_linearity_parser.py +12 -2
- wisent/core/parser_arguments/cluster_benchmarks_parser.py +31 -0
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +22 -2
- wisent/core/parser_arguments/geometry_search_parser.py +61 -0
- wisent/core/parser_arguments/main_parser.py +16 -0
- wisent/core/parser_arguments/optimize_steering_parser.py +117 -10
- wisent/core/parser_arguments/optimize_weights_parser.py +6 -0
- wisent/core/parser_arguments/tasks_parser.py +7 -19
- wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
- wisent/core/steering.py +5 -3
- wisent/core/steering_methods/core/atoms.py +1 -2
- wisent/core/steering_methods/methods/caa.py +1 -1
- wisent/core/steering_methods/methods/hyperplane.py +75 -0
- wisent/core/steering_methods/methods/prism.py +1 -2
- wisent/core/steering_methods/methods/pulse.py +39 -8
- wisent/core/steering_methods/methods/titan.py +59 -14
- wisent/core/steering_methods/registry.py +52 -12
- wisent/core/steering_optimizer.py +15 -15
- wisent/core/synthetic/generators/nonsense_generator.py +30 -18
- wisent/core/trainers/steering_trainer.py +11 -20
- wisent/core/utils/device.py +27 -27
- wisent/core/utils/layer_combinations.py +70 -0
- wisent/examples/__init__.py +1 -0
- wisent/examples/scripts/__init__.py +1 -0
- wisent/examples/scripts/count_all_benchmarks.py +121 -0
- wisent/examples/scripts/discover_directions.py +469 -0
- wisent/examples/scripts/extract_benchmark_info.py +71 -0
- wisent/examples/scripts/generate_paper_data.py +384 -0
- wisent/examples/scripts/intervention_validation.py +626 -0
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_evaluation.json +324 -0
- wisent/examples/scripts/results/test_AraDiCE_ArabicMMLU_lev_pairs.json +92 -0
- wisent/examples/scripts/results/test_aexams_IslamicStudies_evaluation.json +324 -0
- wisent/examples/scripts/results/test_aexams_IslamicStudies_pairs.json +92 -0
- wisent/examples/scripts/results/test_afrimgsm_pairs.json +92 -0
- wisent/examples/scripts/results/test_afrimmlu_evaluation.json +324 -0
- wisent/examples/scripts/results/test_afrimmlu_pairs.json +92 -0
- wisent/examples/scripts/search_all_short_names.py +31 -0
- wisent/examples/scripts/test_all_benchmarks.py +138 -0
- wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
- wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
- wisent/examples/scripts/test_nonsense_baseline.py +261 -0
- wisent/examples/scripts/test_one_benchmark.py +324 -0
- wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
- wisent/examples/scripts/threshold_analysis.py +434 -0
- wisent/examples/scripts/visualization_gallery.py +582 -0
- wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
- wisent/parameters/lm_eval/category_directions.json +137 -0
- wisent/parameters/lm_eval/repair_plan.json +282 -0
- wisent/parameters/lm_eval/track_progress_not_lm_eval_tasks.json +19 -70
- wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
- wisent/parameters/lm_eval/working_benchmarks.json +206 -0
- wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
- wisent/scripts/run_quality_metrics_sweep.sh +22 -27
- wisent/tests/test_aggregation_geometry.py +236 -0
- wisent/tests/test_detector_accuracy.py +163 -0
- wisent/tests/test_geometry_exhaustive.py +1202 -0
- wisent/tests/visualize_geometry.py +255 -61
- {wisent-0.7.379.dist-info → wisent-0.7.901.dist-info}/METADATA +1 -1
- {wisent-0.7.379.dist-info → wisent-0.7.901.dist-info}/RECORD +376 -974
- wisent/core/activations/prompt_construction_strategy.py +0 -47
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue_code_to_text.py +0 -15
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue_code_to_text_go.py +0 -64
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue_code_to_text_java.py +0 -65
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue_code_to_text_javascript.py +0 -65
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue_code_to_text_php.py +0 -65
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue_code_to_text_python.py +0 -65
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codexglue_code_to_text_ruby.py +0 -65
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/freebase.py +0 -99
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/instruct_humaneval.py +0 -180
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/instructhumaneval.py +0 -129
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mbpp.py +0 -142
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/agieval.py +0 -155
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/code2text.py +0 -161
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/codexglue.py +0 -107
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livemathbench.py +0 -155
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/polymath.py +0 -155
- wisent/examples/scripts/results/benchmark_descriptions.json +0 -1244
- wisent/examples/scripts/results/benchmark_evaluation_methods.json +0 -66
- wisent/examples/scripts/results/benchmark_evaluator_mapping.json +0 -2781
- wisent/examples/scripts/results/benchmark_evaluator_mapping_updated.json +0 -30536
- wisent/examples/scripts/results/benchmark_evaluators_clean.json +0 -469
- wisent/examples/scripts/results/benchmark_methods_summary.json +0 -260
- wisent/examples/scripts/results/benchmark_pair_creation_methods.json +0 -66
- wisent/examples/scripts/results/benchmark_pair_totals.json +0 -269
- wisent/examples/scripts/results/benchmark_tags.json +0 -917
- wisent/examples/scripts/results/benchmark_test_summary_nov4.json +0 -71
- wisent/examples/scripts/results/coding_benchmarks_test_code_status.json +0 -150
- wisent/examples/scripts/results/failing_benchmarks.json +0 -946
- wisent/examples/scripts/results/failing_benchmarks_list.json +0 -41
- wisent/examples/scripts/results/failing_benchmarks_test_results.json +0 -945
- wisent/examples/scripts/results/missing_benchmark_tags.json +0 -341
- wisent/examples/scripts/results/test_20_newsgroups_evaluation.json +0 -30
- wisent/examples/scripts/results/test_20_newsgroups_pairs.json +0 -8
- wisent/examples/scripts/results/test_AraDICE_evaluation.json +0 -51
- wisent/examples/scripts/results/test_AraDICE_pairs.json +0 -14
- wisent/examples/scripts/results/test_AraDiCE_boolq_egy/test_AraDiCE_boolq_egy_evaluation.json +0 -30
- wisent/examples/scripts/results/test_AraDiCE_boolq_egy/test_AraDiCE_boolq_egy_pairs.json +0 -8
- wisent/examples/scripts/results/test_ArabCulture_evaluation.json +0 -51
- wisent/examples/scripts/results/test_ArabCulture_pairs.json +0 -14
- wisent/examples/scripts/results/test_Tag_evaluation.json +0 -30
- wisent/examples/scripts/results/test_Tag_pairs.json +0 -8
- wisent/examples/scripts/results/test_aclue_evaluation.json +0 -51
- wisent/examples/scripts/results/test_aclue_pairs.json +0 -14
- wisent/examples/scripts/results/test_acp_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_acp_bench_hard_evaluation.json +0 -51
- wisent/examples/scripts/results/test_acp_bench_hard_pairs.json +0 -14
- wisent/examples/scripts/results/test_acp_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_advanced_ai_risk_evaluation.json +0 -51
- wisent/examples/scripts/results/test_advanced_ai_risk_pairs.json +0 -14
- wisent/examples/scripts/results/test_aexams_evaluation.json +0 -51
- wisent/examples/scripts/results/test_aexams_pairs.json +0 -14
- wisent/examples/scripts/results/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/results/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/results/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/results/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/results/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/results/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/results/test_ag_news_evaluation.json +0 -30
- wisent/examples/scripts/results/test_ag_news_pairs.json +0 -8
- wisent/examples/scripts/results/test_agieval_evaluation.json +0 -51
- wisent/examples/scripts/results/test_agieval_pairs.json +0 -14
- wisent/examples/scripts/results/test_aime2024_evaluation.json +0 -30
- wisent/examples/scripts/results/test_aime2024_pairs.json +0 -8
- wisent/examples/scripts/results/test_aime2025_evaluation.json +0 -30
- wisent/examples/scripts/results/test_aime2025_pairs.json +0 -8
- wisent/examples/scripts/results/test_aime_evaluation.json +0 -30
- wisent/examples/scripts/results/test_aime_pairs.json +0 -8
- wisent/examples/scripts/results/test_anagrams1_evaluation.json +0 -30
- wisent/examples/scripts/results/test_anagrams1_pairs.json +0 -8
- wisent/examples/scripts/results/test_anagrams2_evaluation.json +0 -30
- wisent/examples/scripts/results/test_anagrams2_pairs.json +0 -8
- wisent/examples/scripts/results/test_anli_evaluation.json +0 -30
- wisent/examples/scripts/results/test_anli_pairs.json +0 -8
- wisent/examples/scripts/results/test_apps_evaluation.json +0 -30
- wisent/examples/scripts/results/test_apps_pairs.json +0 -8
- wisent/examples/scripts/results/test_arabic_exams_evaluation.json +0 -30
- wisent/examples/scripts/results/test_arabic_exams_pairs.json +0 -8
- wisent/examples/scripts/results/test_arabic_leaderboard_complete_evaluation.json +0 -51
- wisent/examples/scripts/results/test_arabic_leaderboard_complete_pairs.json +0 -14
- wisent/examples/scripts/results/test_arabic_leaderboard_light_evaluation.json +0 -51
- wisent/examples/scripts/results/test_arabic_leaderboard_light_pairs.json +0 -14
- wisent/examples/scripts/results/test_arabicmmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_arabicmmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_aradice/test_aradice_evaluation.json +0 -51
- wisent/examples/scripts/results/test_aradice/test_aradice_pairs.json +0 -14
- wisent/examples/scripts/results/test_aradice3/test_aradice_evaluation.json +0 -51
- wisent/examples/scripts/results/test_aradice3/test_aradice_pairs.json +0 -14
- wisent/examples/scripts/results/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/results/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/results/test_arc_challenge_evaluation.json +0 -30
- wisent/examples/scripts/results/test_arc_challenge_pairs.json +0 -8
- wisent/examples/scripts/results/test_arc_easy_evaluation.json +0 -30
- wisent/examples/scripts/results/test_arc_easy_pairs.json +0 -8
- wisent/examples/scripts/results/test_argument_topic_evaluation.json +0 -30
- wisent/examples/scripts/results/test_argument_topic_pairs.json +0 -8
- wisent/examples/scripts/results/test_arithmetic_evaluation.json +0 -51
- wisent/examples/scripts/results/test_arithmetic_pairs.json +0 -14
- wisent/examples/scripts/results/test_asdiv_evaluation.json +0 -30
- wisent/examples/scripts/results/test_asdiv_pairs.json +0 -8
- wisent/examples/scripts/results/test_assin_entailment_evaluation.json +0 -30
- wisent/examples/scripts/results/test_assin_entailment_pairs.json +0 -8
- wisent/examples/scripts/results/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/results/test_atis_pairs.json +0 -8
- wisent/examples/scripts/results/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/results/test_babi_pairs.json +0 -8
- wisent/examples/scripts/results/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/results/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/results/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/results/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/results/test_banking77_evaluation.json +0 -30
- wisent/examples/scripts/results/test_banking77_pairs.json +0 -8
- wisent/examples/scripts/results/test_basque/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/results/test_basque-glue_evaluation.json +0 -51
- wisent/examples/scripts/results/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/results/test_basque2/test_basque-glue_evaluation.json +0 -51
- wisent/examples/scripts/results/test_basque2/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/results/test_basque_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_basque_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_basque_glue/test_basque-glue_evaluation.json +0 -51
- wisent/examples/scripts/results/test_basque_glue/test_basque-glue_pairs.json +0 -14
- wisent/examples/scripts/results/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/results/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/results/test_bbh_evaluation.json +0 -51
- wisent/examples/scripts/results/test_bbh_pairs.json +0 -14
- wisent/examples/scripts/results/test_bbq_evaluation.json +0 -30
- wisent/examples/scripts/results/test_bbq_pairs.json +0 -8
- wisent/examples/scripts/results/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/results/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/results/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/results/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/results/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/results/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/results/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/results/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/results/test_bigbench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_bigbench_pairs.json +0 -14
- wisent/examples/scripts/results/test_blimp_evaluation.json +0 -51
- wisent/examples/scripts/results/test_blimp_pairs.json +0 -14
- wisent/examples/scripts/results/test_boolq/test_boolq_evaluation.json +0 -30
- wisent/examples/scripts/results/test_boolq/test_boolq_pairs.json +0 -8
- wisent/examples/scripts/results/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/results/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/results/test_boolq_evaluation.json +0 -30
- wisent/examples/scripts/results/test_boolq_pairs.json +0 -8
- wisent/examples/scripts/results/test_c4_evaluation.json +0 -30
- wisent/examples/scripts/results/test_c4_pairs.json +0 -8
- wisent/examples/scripts/results/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/results/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/results/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_catalan_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_catalan_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/results/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/results/test_cb_evaluation.json +0 -30
- wisent/examples/scripts/results/test_cb_pairs.json +0 -8
- wisent/examples/scripts/results/test_ceval/test_ceval_evaluation.json +0 -51
- wisent/examples/scripts/results/test_ceval/test_ceval_pairs.json +0 -14
- wisent/examples/scripts/results/test_ceval_accountant/test_ceval-valid_accountant_evaluation.json +0 -30
- wisent/examples/scripts/results/test_ceval_accountant/test_ceval-valid_accountant_pairs.json +0 -8
- wisent/examples/scripts/results/test_ceval_evaluation.json +0 -51
- wisent/examples/scripts/results/test_ceval_pairs.json +0 -14
- wisent/examples/scripts/results/test_ceval_valid/test_ceval_valid_evaluation.json +0 -51
- wisent/examples/scripts/results/test_ceval_valid/test_ceval_valid_pairs.json +0 -14
- wisent/examples/scripts/results/test_chain_of_thought_evaluation.json +0 -51
- wisent/examples/scripts/results/test_chain_of_thought_pairs.json +0 -14
- wisent/examples/scripts/results/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/results/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/results/test_cmmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_cmmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/results/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/results/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/results/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/results/test_codexglue_code_to_text_go_evaluation.json +0 -30
- wisent/examples/scripts/results/test_codexglue_code_to_text_go_pairs.json +0 -8
- wisent/examples/scripts/results/test_codexglue_code_to_text_java_evaluation.json +0 -30
- wisent/examples/scripts/results/test_codexglue_code_to_text_java_pairs.json +0 -8
- wisent/examples/scripts/results/test_codexglue_code_to_text_javascript_evaluation.json +0 -30
- wisent/examples/scripts/results/test_codexglue_code_to_text_javascript_pairs.json +0 -8
- wisent/examples/scripts/results/test_codexglue_code_to_text_php_evaluation.json +0 -30
- wisent/examples/scripts/results/test_codexglue_code_to_text_php_pairs.json +0 -8
- wisent/examples/scripts/results/test_codexglue_code_to_text_python_evaluation.json +0 -30
- wisent/examples/scripts/results/test_codexglue_code_to_text_python_pairs.json +0 -8
- wisent/examples/scripts/results/test_codexglue_code_to_text_ruby_evaluation.json +0 -30
- wisent/examples/scripts/results/test_codexglue_code_to_text_ruby_pairs.json +0 -8
- wisent/examples/scripts/results/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/results/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/results/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/results/test_cola_pairs.json +0 -8
- wisent/examples/scripts/results/test_commonsense_qa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_commonsense_qa_pairs.json +0 -8
- wisent/examples/scripts/results/test_conala_evaluation.json +0 -30
- wisent/examples/scripts/results/test_conala_pairs.json +0 -8
- wisent/examples/scripts/results/test_concode_evaluation.json +0 -30
- wisent/examples/scripts/results/test_concode_pairs.json +0 -8
- wisent/examples/scripts/results/test_copa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_copa_pairs.json +0 -8
- wisent/examples/scripts/results/test_copal_id_evaluation.json +0 -30
- wisent/examples/scripts/results/test_copal_id_pairs.json +0 -8
- wisent/examples/scripts/results/test_coqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_coqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/results/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/results/test_crows_pairs_evaluation.json +0 -51
- wisent/examples/scripts/results/test_crows_pairs_pairs.json +0 -14
- wisent/examples/scripts/results/test_csatqa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_csatqa_pairs.json +0 -14
- wisent/examples/scripts/results/test_cycle_letters_evaluation.json +0 -30
- wisent/examples/scripts/results/test_cycle_letters_pairs.json +0 -8
- wisent/examples/scripts/results/test_darija_bench/test_darija_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_darija_bench/test_darija_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_darija_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_darija_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_darijahellaswag_evaluation.json +0 -30
- wisent/examples/scripts/results/test_darijahellaswag_pairs.json +0 -8
- wisent/examples/scripts/results/test_darijammlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_darijammlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/results/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/results/test_drop_evaluation.json +0 -30
- wisent/examples/scripts/results/test_drop_pairs.json +0 -8
- wisent/examples/scripts/results/test_ds1000_evaluation.json +0 -30
- wisent/examples/scripts/results/test_ds1000_pairs.json +0 -8
- wisent/examples/scripts/results/test_egyhellaswag_evaluation.json +0 -30
- wisent/examples/scripts/results/test_egyhellaswag_pairs.json +0 -8
- wisent/examples/scripts/results/test_egymmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_egymmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/results/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/results/test_eq_bench_evaluation.json +0 -30
- wisent/examples/scripts/results/test_eq_bench_pairs.json +0 -8
- wisent/examples/scripts/results/test_escola_evaluation.json +0 -30
- wisent/examples/scripts/results/test_escola_pairs.json +0 -8
- wisent/examples/scripts/results/test_ethics_cm_evaluation.json +0 -30
- wisent/examples/scripts/results/test_ethics_cm_pairs.json +0 -8
- wisent/examples/scripts/results/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/results/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/results/test_eus_exams/test_eus_exams_evaluation.json +0 -51
- wisent/examples/scripts/results/test_eus_exams/test_eus_exams_pairs.json +0 -14
- wisent/examples/scripts/results/test_eus_exams_es_evaluation.json +0 -51
- wisent/examples/scripts/results/test_eus_exams_es_pairs.json +0 -14
- wisent/examples/scripts/results/test_eus_exams_evaluation.json +0 -51
- wisent/examples/scripts/results/test_eus_exams_pairs.json +0 -14
- wisent/examples/scripts/results/test_eus_proficiency_evaluation.json +0 -30
- wisent/examples/scripts/results/test_eus_proficiency_pairs.json +0 -8
- wisent/examples/scripts/results/test_eus_reading_evaluation.json +0 -30
- wisent/examples/scripts/results/test_eus_reading_pairs.json +0 -8
- wisent/examples/scripts/results/test_eus_trivia_evaluation.json +0 -30
- wisent/examples/scripts/results/test_eus_trivia_pairs.json +0 -8
- wisent/examples/scripts/results/test_evalita-mp_evaluation.json +0 -51
- wisent/examples/scripts/results/test_evalita-mp_pairs.json +0 -14
- wisent/examples/scripts/results/test_evalita-sp_sum_task_fp-small_p1_evaluation.json +0 -30
- wisent/examples/scripts/results/test_evalita-sp_sum_task_fp-small_p1_pairs.json +0 -8
- wisent/examples/scripts/results/test_evalita_LLM_evaluation.json +0 -51
- wisent/examples/scripts/results/test_evalita_LLM_pairs.json +0 -14
- wisent/examples/scripts/results/test_evalita_llm/test_evalita_llm_evaluation.json +0 -51
- wisent/examples/scripts/results/test_evalita_llm/test_evalita_llm_pairs.json +0 -14
- wisent/examples/scripts/results/test_evalita_mp/test_evalita-mp_te_prompt-1_evaluation.json +0 -30
- wisent/examples/scripts/results/test_evalita_mp/test_evalita-mp_te_prompt-1_pairs.json +0 -8
- wisent/examples/scripts/results/test_evalita_mp2/test_evalita_mp_evaluation.json +0 -51
- wisent/examples/scripts/results/test_evalita_mp2/test_evalita_mp_pairs.json +0 -14
- wisent/examples/scripts/results/test_evalita_sp2/test_evalita-sp_sum_task_fp-small_p1_evaluation.json +0 -30
- wisent/examples/scripts/results/test_evalita_sp2/test_evalita-sp_sum_task_fp-small_p1_pairs.json +0 -8
- wisent/examples/scripts/results/test_fda_evaluation.json +0 -30
- wisent/examples/scripts/results/test_fda_pairs.json +0 -8
- wisent/examples/scripts/results/test_financial_tweets_evaluation.json +0 -30
- wisent/examples/scripts/results/test_financial_tweets_pairs.json +0 -8
- wisent/examples/scripts/results/test_fld/test_fld_evaluation.json +0 -30
- wisent/examples/scripts/results/test_fld/test_fld_pairs.json +0 -8
- wisent/examples/scripts/results/test_fld_evaluation.json +0 -30
- wisent/examples/scripts/results/test_fld_fixed/test_fld_evaluation.json +0 -30
- wisent/examples/scripts/results/test_fld_fixed/test_fld_pairs.json +0 -8
- wisent/examples/scripts/results/test_fld_pairs.json +0 -8
- wisent/examples/scripts/results/test_flores_evaluation.json +0 -51
- wisent/examples/scripts/results/test_flores_pairs.json +0 -14
- wisent/examples/scripts/results/test_freebase_evaluation.json +0 -30
- wisent/examples/scripts/results/test_freebase_pairs.json +0 -8
- wisent/examples/scripts/results/test_french_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_french_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_galcola_evaluation.json +0 -30
- wisent/examples/scripts/results/test_galcola_pairs.json +0 -8
- wisent/examples/scripts/results/test_galician_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_galician_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_glianorex_evaluation.json +0 -30
- wisent/examples/scripts/results/test_glianorex_pairs.json +0 -8
- wisent/examples/scripts/results/test_global_mmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_global_mmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_glue_evaluation.json +0 -51
- wisent/examples/scripts/results/test_glue_pairs.json +0 -14
- wisent/examples/scripts/results/test_gpqa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_gpqa_pairs.json +0 -14
- wisent/examples/scripts/results/test_gpt3_translation_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/results/test_gpt3_translation_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/results/test_groundcocoa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_groundcocoa_pairs.json +0 -8
- wisent/examples/scripts/results/test_gsm8k_evaluation.json +0 -30
- wisent/examples/scripts/results/test_gsm8k_pairs.json +0 -8
- wisent/examples/scripts/results/test_haerae_evaluation.json +0 -51
- wisent/examples/scripts/results/test_haerae_pairs.json +0 -14
- wisent/examples/scripts/results/test_headqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_headqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_hellaswag_evaluation.json +0 -30
- wisent/examples/scripts/results/test_hellaswag_pairs.json +0 -8
- wisent/examples/scripts/results/test_hendrycks_ethics_evaluation.json +0 -51
- wisent/examples/scripts/results/test_hendrycks_ethics_pairs.json +0 -14
- wisent/examples/scripts/results/test_hendrycks_math_evaluation.json +0 -51
- wisent/examples/scripts/results/test_hendrycks_math_pairs.json +0 -14
- wisent/examples/scripts/results/test_histoires_morales_evaluation.json +0 -30
- wisent/examples/scripts/results/test_histoires_morales_pairs.json +0 -8
- wisent/examples/scripts/results/test_hmmt_evaluation.json +0 -30
- wisent/examples/scripts/results/test_hmmt_feb_2025_evaluation.json +0 -30
- wisent/examples/scripts/results/test_hmmt_feb_2025_pairs.json +0 -8
- wisent/examples/scripts/results/test_hmmt_pairs.json +0 -8
- wisent/examples/scripts/results/test_hrm8k_evaluation.json +0 -51
- wisent/examples/scripts/results/test_hrm8k_pairs.json +0 -14
- wisent/examples/scripts/results/test_humaneval_evaluation.json +0 -30
- wisent/examples/scripts/results/test_humaneval_pairs.json +0 -8
- wisent/examples/scripts/results/test_humaneval_plus_evaluation.json +0 -30
- wisent/examples/scripts/results/test_humaneval_plus_pairs.json +0 -8
- wisent/examples/scripts/results/test_ifeval_evaluation.json +0 -30
- wisent/examples/scripts/results/test_ifeval_pairs.json +0 -8
- wisent/examples/scripts/results/test_instruct_humaneval/test_instruct_humaneval_evaluation.json +0 -30
- wisent/examples/scripts/results/test_instruct_humaneval/test_instruct_humaneval_pairs.json +0 -8
- wisent/examples/scripts/results/test_instruct_humaneval_evaluation.json +0 -30
- wisent/examples/scripts/results/test_instruct_humaneval_pairs.json +0 -8
- wisent/examples/scripts/results/test_inverse_scaling_evaluation.json +0 -51
- wisent/examples/scripts/results/test_inverse_scaling_hindsight_neglect_10shot_evaluation.json +0 -30
- wisent/examples/scripts/results/test_inverse_scaling_hindsight_neglect_10shot_pairs.json +0 -8
- wisent/examples/scripts/results/test_inverse_scaling_mc/test_inverse_scaling_mc_evaluation.json +0 -51
- wisent/examples/scripts/results/test_inverse_scaling_mc/test_inverse_scaling_mc_pairs.json +0 -14
- wisent/examples/scripts/results/test_inverse_scaling_pairs.json +0 -14
- wisent/examples/scripts/results/test_iwslt2017-ar-en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_iwslt2017-ar-en_pairs.json +0 -8
- wisent/examples/scripts/results/test_iwslt2017-en-ar_evaluation.json +0 -30
- wisent/examples/scripts/results/test_iwslt2017-en-ar_pairs.json +0 -8
- wisent/examples/scripts/results/test_iwslt2017_ar_en/test_iwslt2017-ar-en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_iwslt2017_ar_en/test_iwslt2017-ar-en_pairs.json +0 -8
- wisent/examples/scripts/results/test_iwslt2017_en_ar/test_iwslt2017-en-ar_evaluation.json +0 -30
- wisent/examples/scripts/results/test_iwslt2017_en_ar/test_iwslt2017-en-ar_pairs.json +0 -8
- wisent/examples/scripts/results/test_iwslt2017_group/test_iwslt2017_evaluation.json +0 -30
- wisent/examples/scripts/results/test_iwslt2017_group/test_iwslt2017_pairs.json +0 -8
- wisent/examples/scripts/results/test_japanese_leaderboard_evaluation.json +0 -51
- wisent/examples/scripts/results/test_japanese_leaderboard_pairs.json +0 -14
- wisent/examples/scripts/results/test_jsonschema_bench/test_jsonschema_bench_evaluation.json +0 -30
- wisent/examples/scripts/results/test_jsonschema_bench/test_jsonschema_bench_pairs.json +0 -8
- wisent/examples/scripts/results/test_jsonschema_bench_evaluation.json +0 -30
- wisent/examples/scripts/results/test_jsonschema_bench_final/test_jsonschema_bench_evaluation.json +0 -30
- wisent/examples/scripts/results/test_jsonschema_bench_final/test_jsonschema_bench_pairs.json +0 -8
- wisent/examples/scripts/results/test_jsonschema_bench_pairs.json +0 -8
- wisent/examples/scripts/results/test_kbl_evaluation.json +0 -51
- wisent/examples/scripts/results/test_kbl_fixed/test_kbl_evaluation.json +0 -51
- wisent/examples/scripts/results/test_kbl_fixed/test_kbl_pairs.json +0 -14
- wisent/examples/scripts/results/test_kbl_pairs.json +0 -14
- wisent/examples/scripts/results/test_kmmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_kmmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_kobest_evaluation.json +0 -51
- wisent/examples/scripts/results/test_kobest_pairs.json +0 -14
- wisent/examples/scripts/results/test_kormedmcqa/test_kormedmcqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_kormedmcqa/test_kormedmcqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_kormedmcqa_dentist/test_kormedmcqa_dentist_evaluation.json +0 -30
- wisent/examples/scripts/results/test_kormedmcqa_dentist/test_kormedmcqa_dentist_pairs.json +0 -8
- wisent/examples/scripts/results/test_kormedmcqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_kormedmcqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_lambada_cloze_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lambada_cloze_pairs.json +0 -8
- wisent/examples/scripts/results/test_lambada_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lambada_final/test_lambada_openai_mt_stablelm_en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lambada_final/test_lambada_openai_mt_stablelm_en_pairs.json +0 -8
- wisent/examples/scripts/results/test_lambada_multilingual/test_lambada_multilingual_evaluation.json +0 -51
- wisent/examples/scripts/results/test_lambada_multilingual/test_lambada_multilingual_pairs.json +0 -14
- wisent/examples/scripts/results/test_lambada_multilingual_evaluation.json +0 -51
- wisent/examples/scripts/results/test_lambada_multilingual_pairs.json +0 -14
- wisent/examples/scripts/results/test_lambada_multilingual_stablelm_evaluation.json +0 -51
- wisent/examples/scripts/results/test_lambada_multilingual_stablelm_pairs.json +0 -14
- wisent/examples/scripts/results/test_lambada_openai_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lambada_openai_pairs.json +0 -8
- wisent/examples/scripts/results/test_lambada_pairs.json +0 -8
- wisent/examples/scripts/results/test_lambada_stablelm_en_fixed/test_lambada_openai_mt_stablelm_en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lambada_stablelm_en_fixed/test_lambada_openai_mt_stablelm_en_pairs.json +0 -8
- wisent/examples/scripts/results/test_lambada_stablelm_fixed/test_lambada_openai_mt_stablelm_en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lambada_stablelm_fixed/test_lambada_openai_mt_stablelm_en_pairs.json +0 -8
- wisent/examples/scripts/results/test_lambada_standard_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lambada_standard_pairs.json +0 -8
- wisent/examples/scripts/results/test_leaderboard_evaluation.json +0 -51
- wisent/examples/scripts/results/test_leaderboard_pairs.json +0 -14
- wisent/examples/scripts/results/test_libra/test_libra_evaluation.json +0 -51
- wisent/examples/scripts/results/test_libra/test_libra_pairs.json +0 -14
- wisent/examples/scripts/results/test_libra_evaluation.json +0 -51
- wisent/examples/scripts/results/test_libra_pairs.json +0 -14
- wisent/examples/scripts/results/test_lingoly_evaluation.json +0 -30
- wisent/examples/scripts/results/test_lingoly_pairs.json +0 -8
- wisent/examples/scripts/results/test_livecodebench_evaluation.json +0 -30
- wisent/examples/scripts/results/test_livecodebench_pairs.json +0 -8
- wisent/examples/scripts/results/test_livemathbench_cnmo_en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_livemathbench_cnmo_en_pairs.json +0 -8
- wisent/examples/scripts/results/test_livemathbench_cnmo_zh_evaluation.json +0 -30
- wisent/examples/scripts/results/test_livemathbench_cnmo_zh_pairs.json +0 -8
- wisent/examples/scripts/results/test_llama_evaluation.json +0 -30
- wisent/examples/scripts/results/test_llama_pairs.json +0 -8
- wisent/examples/scripts/results/test_logiqa2_evaluation.json +0 -30
- wisent/examples/scripts/results/test_logiqa2_pairs.json +0 -8
- wisent/examples/scripts/results/test_logiqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_logiqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_m_mmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_m_mmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_mastermind/test_mastermind_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mastermind/test_mastermind_pairs.json +0 -14
- wisent/examples/scripts/results/test_mastermind_24_easy/test_mastermind_24_easy_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mastermind_24_easy/test_mastermind_24_easy_pairs.json +0 -8
- wisent/examples/scripts/results/test_mastermind_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mastermind_pairs.json +0 -14
- wisent/examples/scripts/results/test_math500_evaluation.json +0 -30
- wisent/examples/scripts/results/test_math500_pairs.json +0 -8
- wisent/examples/scripts/results/test_math_evaluation.json +0 -30
- wisent/examples/scripts/results/test_math_pairs.json +0 -8
- wisent/examples/scripts/results/test_mathqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mathqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_mbpp_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mbpp_pairs.json +0 -8
- wisent/examples/scripts/results/test_mbpp_plus_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mbpp_plus_pairs.json +0 -8
- wisent/examples/scripts/results/test_mc_taco_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mc_taco_pairs.json +0 -8
- wisent/examples/scripts/results/test_med_concepts_qa/test_med_concepts_qa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_med_concepts_qa/test_med_concepts_qa_pairs.json +0 -14
- wisent/examples/scripts/results/test_med_concepts_qa_atc_easy/test_med_concepts_qa_atc_easy_evaluation.json +0 -30
- wisent/examples/scripts/results/test_med_concepts_qa_atc_easy/test_med_concepts_qa_atc_easy_pairs.json +0 -8
- wisent/examples/scripts/results/test_med_concepts_qa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_med_concepts_qa_pairs.json +0 -14
- wisent/examples/scripts/results/test_meddialog_evaluation.json +0 -30
- wisent/examples/scripts/results/test_meddialog_pairs.json +0 -8
- wisent/examples/scripts/results/test_meddialog_raw_perplexity/test_meddialog_raw_perplexity_evaluation.json +0 -30
- wisent/examples/scripts/results/test_meddialog_raw_perplexity/test_meddialog_raw_perplexity_pairs.json +0 -8
- wisent/examples/scripts/results/test_mediqa_qa2019_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mediqa_qa2019_pairs.json +0 -8
- wisent/examples/scripts/results/test_medmcqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_medmcqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_medqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_medqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_medtext_evaluation.json +0 -30
- wisent/examples/scripts/results/test_medtext_pairs.json +0 -8
- wisent/examples/scripts/results/test_mela_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mela_pairs.json +0 -14
- wisent/examples/scripts/results/test_meqsum_evaluation.json +0 -30
- wisent/examples/scripts/results/test_meqsum_pairs.json +0 -8
- wisent/examples/scripts/results/test_mercury_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mercury_pairs.json +0 -8
- wisent/examples/scripts/results/test_metabench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_metabench_pairs.json +0 -14
- wisent/examples/scripts/results/test_mgsm_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mgsm_pairs.json +0 -14
- wisent/examples/scripts/results/test_mimic_repsum_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mimic_repsum_pairs.json +0 -8
- wisent/examples/scripts/results/test_minerva_math_evaluation.json +0 -51
- wisent/examples/scripts/results/test_minerva_math_pairs.json +0 -14
- wisent/examples/scripts/results/test_mlqa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mlqa_pairs.json +0 -14
- wisent/examples/scripts/results/test_mmlu-pro-plus_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mmlu-pro-plus_pairs.json +0 -14
- wisent/examples/scripts/results/test_mmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_mmlu_pro_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mmlu_pro_pairs.json +0 -14
- wisent/examples/scripts/results/test_mmlu_prox_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mmlu_prox_pairs.json +0 -14
- wisent/examples/scripts/results/test_mmlusr_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mmlusr_pairs.json +0 -8
- wisent/examples/scripts/results/test_mmmu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_mmmu_pairs.json +0 -14
- wisent/examples/scripts/results/test_mnli_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mnli_pairs.json +0 -8
- wisent/examples/scripts/results/test_model_written_evals_evaluation.json +0 -51
- wisent/examples/scripts/results/test_model_written_evals_pairs.json +0 -14
- wisent/examples/scripts/results/test_moral_stories_evaluation.json +0 -30
- wisent/examples/scripts/results/test_moral_stories_pairs.json +0 -8
- wisent/examples/scripts/results/test_mts_dialog_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mts_dialog_pairs.json +0 -8
- wisent/examples/scripts/results/test_multiblimp_evaluation.json +0 -51
- wisent/examples/scripts/results/test_multiblimp_pairs.json +0 -14
- wisent/examples/scripts/results/test_multimedqa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_multimedqa_pairs.json +0 -14
- wisent/examples/scripts/results/test_multipl_e_evaluation.json +0 -30
- wisent/examples/scripts/results/test_multipl_e_pairs.json +0 -8
- wisent/examples/scripts/results/test_mutual_evaluation.json +0 -30
- wisent/examples/scripts/results/test_mutual_pairs.json +0 -8
- wisent/examples/scripts/results/test_non_greedy_robustness_agieval_aqua_rat_evaluation.json +0 -30
- wisent/examples/scripts/results/test_non_greedy_robustness_agieval_aqua_rat_pairs.json +0 -8
- wisent/examples/scripts/results/test_noreval_evaluation.json +0 -51
- wisent/examples/scripts/results/test_noreval_pairs.json +0 -14
- wisent/examples/scripts/results/test_noticia_evaluation.json +0 -30
- wisent/examples/scripts/results/test_noticia_pairs.json +0 -8
- wisent/examples/scripts/results/test_nq_open_evaluation.json +0 -30
- wisent/examples/scripts/results/test_nq_open_pairs.json +0 -8
- wisent/examples/scripts/results/test_olaph_evaluation.json +0 -30
- wisent/examples/scripts/results/test_olaph_pairs.json +0 -8
- wisent/examples/scripts/results/test_openbookqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_openbookqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_openllm_evaluation.json +0 -51
- wisent/examples/scripts/results/test_openllm_pairs.json +0 -14
- wisent/examples/scripts/results/test_option_order_robustness_agieval_aqua_rat_evaluation.json +0 -30
- wisent/examples/scripts/results/test_option_order_robustness_agieval_aqua_rat_pairs.json +0 -8
- wisent/examples/scripts/results/test_paloma_evaluation.json +0 -51
- wisent/examples/scripts/results/test_paloma_pairs.json +0 -14
- wisent/examples/scripts/results/test_passkey/test_passkey_evaluation.json +0 -30
- wisent/examples/scripts/results/test_passkey/test_passkey_pairs.json +0 -8
- wisent/examples/scripts/results/test_paws-x_evaluation.json +0 -51
- wisent/examples/scripts/results/test_paws-x_pairs.json +0 -14
- wisent/examples/scripts/results/test_paws_en/test_paws_en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_paws_en/test_paws_en_pairs.json +0 -8
- wisent/examples/scripts/results/test_penn_treebank_evaluation.json +0 -30
- wisent/examples/scripts/results/test_penn_treebank_pairs.json +0 -8
- wisent/examples/scripts/results/test_pile_10k/test_pile_10k_evaluation.json +0 -30
- wisent/examples/scripts/results/test_pile_10k/test_pile_10k_pairs.json +0 -8
- wisent/examples/scripts/results/test_piqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_piqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_polemo2_evaluation.json +0 -30
- wisent/examples/scripts/results/test_polemo2_pairs.json +0 -8
- wisent/examples/scripts/results/test_polymath_en_high_evaluation.json +0 -30
- wisent/examples/scripts/results/test_polymath_en_high_pairs.json +0 -8
- wisent/examples/scripts/results/test_polymath_en_medium_evaluation.json +0 -30
- wisent/examples/scripts/results/test_polymath_en_medium_pairs.json +0 -8
- wisent/examples/scripts/results/test_polymath_zh_high_evaluation.json +0 -30
- wisent/examples/scripts/results/test_polymath_zh_high_pairs.json +0 -8
- wisent/examples/scripts/results/test_polymath_zh_medium_evaluation.json +0 -30
- wisent/examples/scripts/results/test_polymath_zh_medium_pairs.json +0 -8
- wisent/examples/scripts/results/test_portuguese_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_portuguese_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_prompt_robustness_agieval_aqua_rat/test_prompt_robustness_agieval_aqua_rat_evaluation.json +0 -30
- wisent/examples/scripts/results/test_prompt_robustness_agieval_aqua_rat/test_prompt_robustness_agieval_aqua_rat_pairs.json +0 -8
- wisent/examples/scripts/results/test_prompt_robustness_agieval_aqua_rat_evaluation.json +0 -30
- wisent/examples/scripts/results/test_prompt_robustness_agieval_aqua_rat_pairs.json +0 -8
- wisent/examples/scripts/results/test_prost_evaluation.json +0 -30
- wisent/examples/scripts/results/test_prost_pairs.json +0 -8
- wisent/examples/scripts/results/test_ptb_evaluation.json +0 -30
- wisent/examples/scripts/results/test_ptb_pairs.json +0 -8
- wisent/examples/scripts/results/test_pubmedqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_pubmedqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_pythia_evaluation.json +0 -51
- wisent/examples/scripts/results/test_pythia_pairs.json +0 -14
- wisent/examples/scripts/results/test_qa4mre_evaluation.json +0 -30
- wisent/examples/scripts/results/test_qa4mre_pairs.json +0 -8
- wisent/examples/scripts/results/test_qasper_evaluation.json +0 -30
- wisent/examples/scripts/results/test_qasper_pairs.json +0 -8
- wisent/examples/scripts/results/test_race_evaluation.json +0 -30
- wisent/examples/scripts/results/test_race_pairs.json +0 -8
- wisent/examples/scripts/results/test_realtoxicityprompts_evaluation.json +0 -30
- wisent/examples/scripts/results/test_realtoxicityprompts_pairs.json +0 -8
- wisent/examples/scripts/results/test_recode_evaluation.json +0 -30
- wisent/examples/scripts/results/test_recode_pairs.json +0 -8
- wisent/examples/scripts/results/test_record_evaluation.json +0 -30
- wisent/examples/scripts/results/test_record_pairs.json +0 -8
- wisent/examples/scripts/results/test_ruler_evaluation.json +0 -51
- wisent/examples/scripts/results/test_ruler_pairs.json +0 -14
- wisent/examples/scripts/results/test_sciq_evaluation.json +0 -30
- wisent/examples/scripts/results/test_sciq_pairs.json +0 -8
- wisent/examples/scripts/results/test_score_evaluation.json +0 -51
- wisent/examples/scripts/results/test_score_pairs.json +0 -14
- wisent/examples/scripts/results/test_self_consistency_evaluation.json +0 -30
- wisent/examples/scripts/results/test_self_consistency_pairs.json +0 -8
- wisent/examples/scripts/results/test_siqa/test_siqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_siqa/test_siqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_siqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_siqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_spanish_bench_evaluation.json +0 -51
- wisent/examples/scripts/results/test_spanish_bench_pairs.json +0 -14
- wisent/examples/scripts/results/test_squad2_evaluation.json +0 -30
- wisent/examples/scripts/results/test_squad2_pairs.json +0 -8
- wisent/examples/scripts/results/test_squadv2_evaluation.json +0 -30
- wisent/examples/scripts/results/test_squadv2_pairs.json +0 -8
- wisent/examples/scripts/results/test_super-glue-lm-eval-v1-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/results/test_super-glue-lm-eval-v1-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/results/test_super-glue-lm-eval-v1_evaluation.json +0 -51
- wisent/examples/scripts/results/test_super-glue-lm-eval-v1_pairs.json +0 -14
- wisent/examples/scripts/results/test_swag_evaluation.json +0 -30
- wisent/examples/scripts/results/test_swag_pairs.json +0 -8
- wisent/examples/scripts/results/test_tinyBenchmarks_evaluation.json +0 -51
- wisent/examples/scripts/results/test_tinyBenchmarks_pairs.json +0 -14
- wisent/examples/scripts/results/test_tmmluplus_evaluation.json +0 -51
- wisent/examples/scripts/results/test_tmmluplus_pairs.json +0 -14
- wisent/examples/scripts/results/test_translation_evaluation.json +0 -51
- wisent/examples/scripts/results/test_translation_pairs.json +0 -14
- wisent/examples/scripts/results/test_triviaqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_triviaqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_truthfulqa-multi_evaluation.json +0 -51
- wisent/examples/scripts/results/test_truthfulqa-multi_pairs.json +0 -14
- wisent/examples/scripts/results/test_truthfulqa_evaluation.json +0 -30
- wisent/examples/scripts/results/test_truthfulqa_mc1_evaluation.json +0 -30
- wisent/examples/scripts/results/test_truthfulqa_mc1_pairs.json +0 -8
- wisent/examples/scripts/results/test_truthfulqa_mc2_evaluation.json +0 -30
- wisent/examples/scripts/results/test_truthfulqa_mc2_pairs.json +0 -8
- wisent/examples/scripts/results/test_truthfulqa_pairs.json +0 -8
- wisent/examples/scripts/results/test_turkishmmlu_evaluation.json +0 -51
- wisent/examples/scripts/results/test_turkishmmlu_pairs.json +0 -14
- wisent/examples/scripts/results/test_unfair_tos_evaluation.json +0 -30
- wisent/examples/scripts/results/test_unfair_tos_pairs.json +0 -8
- wisent/examples/scripts/results/test_unscramble_evaluation.json +0 -51
- wisent/examples/scripts/results/test_unscramble_pairs.json +0 -14
- wisent/examples/scripts/results/test_webqs_evaluation.json +0 -30
- wisent/examples/scripts/results/test_webqs_pairs.json +0 -8
- wisent/examples/scripts/results/test_wikitext103_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wikitext103_pairs.json +0 -8
- wisent/examples/scripts/results/test_wikitext_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wikitext_pairs.json +0 -8
- wisent/examples/scripts/results/test_winogender_evaluation.json +0 -51
- wisent/examples/scripts/results/test_winogender_pairs.json +0 -14
- wisent/examples/scripts/results/test_winogrande_evaluation.json +0 -30
- wisent/examples/scripts/results/test_winogrande_pairs.json +0 -8
- wisent/examples/scripts/results/test_wmdp_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wmdp_pairs.json +0 -8
- wisent/examples/scripts/results/test_wmt-ro-en-t5-prompt_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wmt-ro-en-t5-prompt_pairs.json +0 -8
- wisent/examples/scripts/results/test_wmt14_en_fr_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wmt14_en_fr_pairs.json +0 -8
- wisent/examples/scripts/results/test_wmt16_en_de_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wmt16_en_de_pairs.json +0 -8
- wisent/examples/scripts/results/test_wmt16_ro_en_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wmt16_ro_en_pairs.json +0 -8
- wisent/examples/scripts/results/test_wsc273_evaluation.json +0 -30
- wisent/examples/scripts/results/test_wsc273_pairs.json +0 -8
- wisent/examples/scripts/results/test_xcopa_evaluation.json +0 -51
- wisent/examples/scripts/results/test_xcopa_pairs.json +0 -14
- wisent/examples/scripts/results/test_xnli_eu_evaluation.json +0 -30
- wisent/examples/scripts/results/test_xnli_eu_pairs.json +0 -8
- wisent/examples/scripts/results/test_xnli_evaluation.json +0 -51
- wisent/examples/scripts/results/test_xnli_pairs.json +0 -14
- wisent/examples/scripts/results/test_xquad_evaluation.json +0 -51
- wisent/examples/scripts/results/test_xquad_pairs.json +0 -14
- wisent/examples/scripts/results/test_xstorycloze_evaluation.json +0 -51
- wisent/examples/scripts/results/test_xstorycloze_pairs.json +0 -14
- wisent/examples/scripts/results/test_xsum_evaluation.json +0 -30
- wisent/examples/scripts/results/test_xsum_pairs.json +0 -8
- wisent/examples/scripts/results/test_xwinograd_evaluation.json +0 -51
- wisent/examples/scripts/results/test_xwinograd_pairs.json +0 -14
- wisent/examples/scripts/results/test_yahoo_answers_topics_evaluation.json +0 -30
- wisent/examples/scripts/results/test_yahoo_answers_topics_pairs.json +0 -8
- {wisent-0.7.379.dist-info → wisent-0.7.901.dist-info}/WHEEL +0 -0
- {wisent-0.7.379.dist-info → wisent-0.7.901.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.379.dist-info → wisent-0.7.901.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.379.dist-info → wisent-0.7.901.dist-info}/top_level.txt +0 -0
|
@@ -24,6 +24,17 @@ __all__ = [
|
|
|
24
24
|
"GeometryAnalysisResult",
|
|
25
25
|
"StructureType",
|
|
26
26
|
"detect_geometry_structure",
|
|
27
|
+
"MultiLayerGeometryConfig",
|
|
28
|
+
"MultiLayerGeometryResult",
|
|
29
|
+
"LayerGeometryResult",
|
|
30
|
+
"detect_geometry_multi_layer",
|
|
31
|
+
"detect_geometry_all_layers",
|
|
32
|
+
"ExhaustiveCombinationResult",
|
|
33
|
+
"ExhaustiveGeometryAnalysisResult",
|
|
34
|
+
"detect_geometry_exhaustive",
|
|
35
|
+
"detect_geometry_limited",
|
|
36
|
+
"detect_geometry_contiguous",
|
|
37
|
+
"detect_geometry_smart",
|
|
27
38
|
]
|
|
28
39
|
|
|
29
40
|
|
|
@@ -84,7 +95,7 @@ def run_control_vector_diagnostics(
|
|
|
84
95
|
)
|
|
85
96
|
continue
|
|
86
97
|
|
|
87
|
-
flat = detached.to(
|
|
98
|
+
flat = detached.to(device="cpu").reshape(-1)
|
|
88
99
|
|
|
89
100
|
if not torch.isfinite(flat).all():
|
|
90
101
|
non_finite = (~torch.isfinite(flat)).sum().item()
|
|
@@ -1058,66 +1069,88 @@ def _detect_cone_structure_score(
|
|
|
1058
1069
|
neg_tensor: torch.Tensor,
|
|
1059
1070
|
cfg: GeometryAnalysisConfig,
|
|
1060
1071
|
) -> StructureScore:
|
|
1061
|
-
"""Detect cone structure
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
)
|
|
1072
|
+
"""Detect cone structure using RAW cosine similarity of difference vectors.
|
|
1073
|
+
|
|
1074
|
+
A cone structure means:
|
|
1075
|
+
- Multiple difference vectors (pos_i - neg_i) point in SIMILAR directions
|
|
1076
|
+
- High cosine similarity between raw difference vectors
|
|
1077
|
+
- NOT using gradient-optimized directions (which inflate the score)
|
|
1067
1078
|
|
|
1079
|
+
This matches what the visualization computes.
|
|
1080
|
+
"""
|
|
1068
1081
|
try:
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
#
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1082
|
+
# Compute raw difference vectors (what visualization uses)
|
|
1083
|
+
n_pairs = min(pos_tensor.shape[0], neg_tensor.shape[0])
|
|
1084
|
+
if n_pairs < 3:
|
|
1085
|
+
return StructureScore(StructureType.CONE, 0.0, 0.0, {"reason": "insufficient_pairs"})
|
|
1086
|
+
|
|
1087
|
+
diff_vectors = pos_tensor[:n_pairs] - neg_tensor[:n_pairs]
|
|
1088
|
+
|
|
1089
|
+
# Normalize difference vectors
|
|
1090
|
+
norms = diff_vectors.norm(dim=1, keepdim=True)
|
|
1091
|
+
valid_mask = (norms.squeeze() > 1e-8)
|
|
1092
|
+
if valid_mask.sum() < 3:
|
|
1093
|
+
return StructureScore(StructureType.CONE, 0.0, 0.0, {"reason": "zero_differences"})
|
|
1094
|
+
|
|
1095
|
+
diff_normalized = diff_vectors[valid_mask] / norms[valid_mask]
|
|
1096
|
+
|
|
1097
|
+
# Compute pairwise cosine similarity matrix
|
|
1098
|
+
cos_sim_matrix = diff_normalized @ diff_normalized.T
|
|
1099
|
+
|
|
1100
|
+
# Get off-diagonal elements (exclude self-similarity of 1.0)
|
|
1101
|
+
n = cos_sim_matrix.shape[0]
|
|
1102
|
+
mask = ~torch.eye(n, dtype=torch.bool, device=cos_sim_matrix.device)
|
|
1103
|
+
off_diagonal = cos_sim_matrix[mask]
|
|
1104
|
+
|
|
1105
|
+
# Raw cosine similarity statistics
|
|
1106
|
+
mean_cos_sim = float(off_diagonal.mean())
|
|
1107
|
+
std_cos_sim = float(off_diagonal.std())
|
|
1108
|
+
min_cos_sim = float(off_diagonal.min())
|
|
1109
|
+
max_cos_sim = float(off_diagonal.max())
|
|
1110
|
+
|
|
1111
|
+
# Fraction of pairs with positive correlation (same half-space)
|
|
1112
|
+
positive_fraction = float((off_diagonal > 0).float().mean())
|
|
1113
|
+
|
|
1114
|
+
# Fraction with strong correlation (>0.3)
|
|
1115
|
+
strong_fraction = float((off_diagonal > 0.3).float().mean())
|
|
1116
|
+
|
|
1117
|
+
# Cone score based on raw cosine similarity:
|
|
1118
|
+
# - High mean cosine = directions are aligned = cone
|
|
1119
|
+
# - Low mean cosine = directions are independent = NOT cone
|
|
1120
|
+
# - Negative mean cosine = directions are opposing = NOT cone
|
|
1121
|
+
|
|
1122
|
+
if mean_cos_sim < 0:
|
|
1123
|
+
# Negative correlation = definitely not a cone
|
|
1124
|
+
cone_score = 0.0
|
|
1125
|
+
elif mean_cos_sim < 0.1:
|
|
1126
|
+
# Near zero = orthogonal/independent, not cone
|
|
1127
|
+
cone_score = mean_cos_sim # 0.0 - 0.1
|
|
1128
|
+
elif mean_cos_sim < 0.3:
|
|
1129
|
+
# Weak correlation = weak cone
|
|
1130
|
+
cone_score = 0.1 + 0.2 * ((mean_cos_sim - 0.1) / 0.2) # 0.1 - 0.3
|
|
1131
|
+
elif mean_cos_sim < 0.7:
|
|
1132
|
+
# Moderate correlation = good cone (ideal range)
|
|
1133
|
+
cone_score = 0.3 + 0.5 * ((mean_cos_sim - 0.3) / 0.4) # 0.3 - 0.8
|
|
1092
1134
|
else:
|
|
1093
|
-
#
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
#
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
# Adjusted cone score
|
|
1101
|
-
cone_score = (
|
|
1102
|
-
0.25 * result.half_space_consistency +
|
|
1103
|
-
0.25 * cosine_score +
|
|
1104
|
-
0.20 * cone_advantage +
|
|
1105
|
-
0.15 * multi_dir_score +
|
|
1106
|
-
0.15 * (1 - pca_penalty) # Penalize when PCA is sufficient
|
|
1107
|
-
)
|
|
1135
|
+
# Very high correlation = almost linear, still cone-like
|
|
1136
|
+
cone_score = 0.8 + 0.2 * ((mean_cos_sim - 0.7) / 0.3) # 0.8 - 1.0
|
|
1137
|
+
|
|
1138
|
+
# Confidence based on consistency (low std = more consistent = higher confidence)
|
|
1139
|
+
consistency = max(0, 1 - std_cos_sim)
|
|
1140
|
+
confidence = consistency * min(1.0, n_pairs / 20)
|
|
1108
1141
|
|
|
1109
1142
|
return StructureScore(
|
|
1110
1143
|
StructureType.CONE,
|
|
1111
1144
|
score=float(cone_score),
|
|
1112
|
-
confidence=
|
|
1145
|
+
confidence=float(confidence),
|
|
1113
1146
|
details={
|
|
1114
|
-
"
|
|
1115
|
-
"
|
|
1116
|
-
"
|
|
1117
|
-
"
|
|
1118
|
-
"
|
|
1119
|
-
"
|
|
1120
|
-
"
|
|
1147
|
+
"raw_mean_cosine_similarity": mean_cos_sim,
|
|
1148
|
+
"raw_std_cosine_similarity": std_cos_sim,
|
|
1149
|
+
"raw_min_cosine_similarity": min_cos_sim,
|
|
1150
|
+
"raw_max_cosine_similarity": max_cos_sim,
|
|
1151
|
+
"positive_correlation_fraction": positive_fraction,
|
|
1152
|
+
"strong_correlation_fraction": strong_fraction,
|
|
1153
|
+
"n_valid_pairs": int(valid_mask.sum()),
|
|
1121
1154
|
}
|
|
1122
1155
|
)
|
|
1123
1156
|
except Exception as e:
|
|
@@ -1130,7 +1163,17 @@ def _detect_cluster_structure(
|
|
|
1130
1163
|
diff_vectors: torch.Tensor,
|
|
1131
1164
|
cfg: GeometryAnalysisConfig,
|
|
1132
1165
|
) -> StructureScore:
|
|
1133
|
-
"""Detect if activations form discrete clusters.
|
|
1166
|
+
"""Detect if activations form discrete clusters.
|
|
1167
|
+
|
|
1168
|
+
Cluster structure means:
|
|
1169
|
+
- Data forms DISCRETE, SEPARATED groups
|
|
1170
|
+
- Not just "pos vs neg" (that's trivially 2 clusters)
|
|
1171
|
+
- Actual subgroups within the data
|
|
1172
|
+
|
|
1173
|
+
Key insight: k-means will ALWAYS find clusters.
|
|
1174
|
+
We need high silhouette AND clear separation to claim clusters.
|
|
1175
|
+
Also, if pos/neg perfectly separate, that's "linear", not "cluster".
|
|
1176
|
+
"""
|
|
1134
1177
|
all_activations = torch.cat([pos_tensor, neg_tensor], dim=0)
|
|
1135
1178
|
n_samples = all_activations.shape[0]
|
|
1136
1179
|
|
|
@@ -1143,7 +1186,6 @@ def _detect_cluster_structure(
|
|
|
1143
1186
|
|
|
1144
1187
|
for k in range(2, min(cfg.max_clusters + 1, n_samples // 2)):
|
|
1145
1188
|
try:
|
|
1146
|
-
# Simple k-means implementation
|
|
1147
1189
|
labels, centroids, silhouette = _kmeans_with_silhouette(all_activations, k, max_iters=50)
|
|
1148
1190
|
silhouette_scores[k] = silhouette
|
|
1149
1191
|
|
|
@@ -1156,31 +1198,55 @@ def _detect_cluster_structure(
|
|
|
1156
1198
|
if best_silhouette < 0:
|
|
1157
1199
|
return StructureScore(StructureType.CLUSTER, 0.0, 0.0, {"reason": "clustering_failed"})
|
|
1158
1200
|
|
|
1159
|
-
# Check if clusters separate pos/neg
|
|
1201
|
+
# Check if clusters just separate pos/neg (that's linear, not cluster)
|
|
1160
1202
|
labels, _, _ = _kmeans_with_silhouette(all_activations, best_k, max_iters=50)
|
|
1161
1203
|
pos_labels = labels[:pos_tensor.shape[0]]
|
|
1162
1204
|
neg_labels = labels[pos_tensor.shape[0]:]
|
|
1163
1205
|
|
|
1164
|
-
#
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1206
|
+
# If k=2 and it perfectly separates pos/neg, that's LINEAR not cluster
|
|
1207
|
+
if best_k == 2:
|
|
1208
|
+
pos_mode = pos_labels.mode().values.item() if len(pos_labels) > 0 else -1
|
|
1209
|
+
neg_mode = neg_labels.mode().values.item() if len(neg_labels) > 0 else -1
|
|
1210
|
+
pos_purity = (pos_labels == pos_mode).float().mean()
|
|
1211
|
+
neg_purity = (neg_labels == neg_mode).float().mean()
|
|
1212
|
+
|
|
1213
|
+
if pos_mode != neg_mode and pos_purity > 0.8 and neg_purity > 0.8:
|
|
1214
|
+
# Perfect pos/neg separation - this is LINEAR, not cluster
|
|
1215
|
+
return StructureScore(
|
|
1216
|
+
StructureType.CLUSTER,
|
|
1217
|
+
score=0.1, # Low score - it's actually linear
|
|
1218
|
+
confidence=0.8,
|
|
1219
|
+
details={
|
|
1220
|
+
"reason": "pos_neg_separation_is_linear",
|
|
1221
|
+
"best_k": 2,
|
|
1222
|
+
"pos_purity": float(pos_purity),
|
|
1223
|
+
"neg_purity": float(neg_purity),
|
|
1224
|
+
}
|
|
1225
|
+
)
|
|
1226
|
+
|
|
1227
|
+
# For true cluster structure, we need:
|
|
1228
|
+
# 1. High silhouette (> 0.5 is good, > 0.7 is strong)
|
|
1229
|
+
# 2. k > 2 OR k=2 with mixed clusters
|
|
1230
|
+
|
|
1231
|
+
# Silhouette thresholds - be strict
|
|
1232
|
+
if best_silhouette < 0.4:
|
|
1233
|
+
# Very low silhouette = no clear cluster structure
|
|
1234
|
+
cluster_score = best_silhouette * 0.3 # Very low score
|
|
1235
|
+
elif best_silhouette < cfg.cluster_silhouette_threshold:
|
|
1236
|
+
# Moderate silhouette = weak cluster structure
|
|
1237
|
+
cluster_score = 0.1 + 0.2 * (best_silhouette / cfg.cluster_silhouette_threshold)
|
|
1179
1238
|
else:
|
|
1180
|
-
#
|
|
1181
|
-
#
|
|
1182
|
-
|
|
1183
|
-
|
|
1239
|
+
# High silhouette = good cluster structure
|
|
1240
|
+
# But only if it's not just pos/neg separation
|
|
1241
|
+
base_score = 0.3 + 0.5 * ((best_silhouette - cfg.cluster_silhouette_threshold) / (1 - cfg.cluster_silhouette_threshold))
|
|
1242
|
+
|
|
1243
|
+
# Bonus for k > 2 (more interesting structure)
|
|
1244
|
+
if best_k > 2:
|
|
1245
|
+
cluster_score = base_score + 0.2
|
|
1246
|
+
else:
|
|
1247
|
+
cluster_score = base_score
|
|
1248
|
+
|
|
1249
|
+
cluster_score = min(1.0, cluster_score)
|
|
1184
1250
|
|
|
1185
1251
|
return StructureScore(
|
|
1186
1252
|
StructureType.CLUSTER,
|
|
@@ -1190,7 +1256,6 @@ def _detect_cluster_structure(
|
|
|
1190
1256
|
"best_k": best_k,
|
|
1191
1257
|
"best_silhouette": float(best_silhouette),
|
|
1192
1258
|
"all_silhouettes": {str(k): float(v) for k, v in silhouette_scores.items()},
|
|
1193
|
-
"cluster_separation": float(cluster_separation),
|
|
1194
1259
|
"silhouette_threshold": cfg.cluster_silhouette_threshold,
|
|
1195
1260
|
}
|
|
1196
1261
|
)
|
|
@@ -1279,7 +1344,19 @@ def _detect_manifold_structure(
|
|
|
1279
1344
|
diff_vectors: torch.Tensor,
|
|
1280
1345
|
cfg: GeometryAnalysisConfig,
|
|
1281
1346
|
) -> StructureScore:
|
|
1282
|
-
"""Detect non-linear manifold structure
|
|
1347
|
+
"""Detect non-linear manifold structure.
|
|
1348
|
+
|
|
1349
|
+
Manifold structure means:
|
|
1350
|
+
- Data lies on a CURVED surface (not linear)
|
|
1351
|
+
- Linear methods (PCA, CAA) cannot capture the structure
|
|
1352
|
+
- Requires non-linear methods (TITAN, neural steering)
|
|
1353
|
+
|
|
1354
|
+
Key insight: Manifold should be a FALLBACK, not default.
|
|
1355
|
+
Only report manifold if:
|
|
1356
|
+
1. Linear doesn't work (PCA explains little variance)
|
|
1357
|
+
2. There's actual curvature (local neighborhoods don't align)
|
|
1358
|
+
3. BUT there IS structure (not just noise)
|
|
1359
|
+
"""
|
|
1283
1360
|
all_activations = torch.cat([pos_tensor, neg_tensor], dim=0)
|
|
1284
1361
|
n_samples = all_activations.shape[0]
|
|
1285
1362
|
|
|
@@ -1287,52 +1364,80 @@ def _detect_manifold_structure(
|
|
|
1287
1364
|
return StructureScore(StructureType.MANIFOLD, 0.0, 0.0, {"reason": "insufficient_data"})
|
|
1288
1365
|
|
|
1289
1366
|
try:
|
|
1290
|
-
#
|
|
1367
|
+
# 1. Check if linear works well (if yes, not manifold)
|
|
1368
|
+
centered = all_activations - all_activations.mean(dim=0, keepdim=True)
|
|
1369
|
+
try:
|
|
1370
|
+
_, S, _ = torch.linalg.svd(centered, full_matrices=False)
|
|
1371
|
+
total_var = (S ** 2).sum()
|
|
1372
|
+
if total_var > 0:
|
|
1373
|
+
# Top 2 PCs variance explained
|
|
1374
|
+
top2_var = (S[:2] ** 2).sum() / total_var
|
|
1375
|
+
linear_explains_well = float(top2_var) > 0.7
|
|
1376
|
+
else:
|
|
1377
|
+
linear_explains_well = True # No variance = trivial
|
|
1378
|
+
except Exception:
|
|
1379
|
+
linear_explains_well = False
|
|
1380
|
+
top2_var = torch.tensor(0.0)
|
|
1381
|
+
|
|
1382
|
+
if linear_explains_well:
|
|
1383
|
+
# Linear works well - not a manifold (it's linear)
|
|
1384
|
+
return StructureScore(
|
|
1385
|
+
StructureType.MANIFOLD,
|
|
1386
|
+
score=0.1,
|
|
1387
|
+
confidence=0.8,
|
|
1388
|
+
details={
|
|
1389
|
+
"reason": "linear_sufficient",
|
|
1390
|
+
"pca_top2_variance": float(top2_var),
|
|
1391
|
+
}
|
|
1392
|
+
)
|
|
1393
|
+
|
|
1394
|
+
# 2. Check for actual curvature (local PCA directions vary)
|
|
1395
|
+
local_nonlinearity = _compute_local_nonlinearity(all_activations, cfg.manifold_neighbors)
|
|
1396
|
+
|
|
1397
|
+
# 3. Check if there's meaningful structure (separation between pos/neg)
|
|
1291
1398
|
mean_diff = pos_tensor.mean(dim=0) - neg_tensor.mean(dim=0)
|
|
1292
1399
|
separation_strength = mean_diff.norm() / (pos_tensor.std() + neg_tensor.std() + 1e-8)
|
|
1293
1400
|
has_structure = min(float(separation_strength) / 2, 1.0)
|
|
1294
1401
|
|
|
1295
|
-
if has_structure < 0.
|
|
1296
|
-
# No
|
|
1297
|
-
return StructureScore(
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
# Also compute local linearity deviation
|
|
1307
|
-
local_nonlinearity = _compute_local_nonlinearity(all_activations, cfg.manifold_neighbors)
|
|
1402
|
+
if has_structure < 0.3:
|
|
1403
|
+
# No clear structure - likely noise, not manifold
|
|
1404
|
+
return StructureScore(
|
|
1405
|
+
StructureType.MANIFOLD,
|
|
1406
|
+
score=0.2,
|
|
1407
|
+
confidence=0.5,
|
|
1408
|
+
details={
|
|
1409
|
+
"reason": "weak_structure",
|
|
1410
|
+
"separation_strength": float(separation_strength),
|
|
1411
|
+
}
|
|
1412
|
+
)
|
|
1308
1413
|
|
|
1309
|
-
# Manifold
|
|
1310
|
-
#
|
|
1311
|
-
#
|
|
1414
|
+
# 4. Manifold requires BOTH:
|
|
1415
|
+
# - Linear doesn't work (already checked)
|
|
1416
|
+
# - AND there's curvature
|
|
1417
|
+
# - AND there's structure
|
|
1312
1418
|
|
|
1313
|
-
#
|
|
1314
|
-
if
|
|
1315
|
-
|
|
1316
|
-
manifold_score = 0.3 * has_structure
|
|
1419
|
+
# If nonlinearity is low, it might be orthogonal/independent, not curved
|
|
1420
|
+
if local_nonlinearity < 0.3:
|
|
1421
|
+
manifold_score = 0.3 * has_structure # Low score
|
|
1317
1422
|
else:
|
|
1423
|
+
# High nonlinearity + structure = manifold candidate
|
|
1318
1424
|
manifold_score = (
|
|
1319
|
-
0.30 *
|
|
1320
|
-
0.
|
|
1321
|
-
0.
|
|
1425
|
+
0.30 * local_nonlinearity +
|
|
1426
|
+
0.30 * (1 - float(top2_var)) + # Reward when linear fails
|
|
1427
|
+
0.40 * has_structure
|
|
1322
1428
|
)
|
|
1323
1429
|
|
|
1324
|
-
# Confidence based on sample size
|
|
1325
|
-
confidence = min(1.0, n_samples / 100)
|
|
1430
|
+
# Confidence based on sample size and consistency
|
|
1431
|
+
confidence = min(1.0, n_samples / 100) * has_structure
|
|
1326
1432
|
|
|
1327
1433
|
return StructureScore(
|
|
1328
1434
|
StructureType.MANIFOLD,
|
|
1329
1435
|
score=float(manifold_score),
|
|
1330
1436
|
confidence=float(confidence),
|
|
1331
1437
|
details={
|
|
1332
|
-
"
|
|
1333
|
-
"ambient_dimensionality": ambient_dim,
|
|
1334
|
-
"dim_ratio": float(dim_ratio),
|
|
1438
|
+
"pca_top2_variance": float(top2_var),
|
|
1335
1439
|
"local_nonlinearity": float(local_nonlinearity),
|
|
1440
|
+
"separation_strength": float(separation_strength),
|
|
1336
1441
|
}
|
|
1337
1442
|
)
|
|
1338
1443
|
except Exception as e:
|
|
@@ -1444,7 +1549,7 @@ def _detect_sparse_structure(
|
|
|
1444
1549
|
sorted_abs = abs_diff.sort().values
|
|
1445
1550
|
n = len(sorted_abs)
|
|
1446
1551
|
cumsum = sorted_abs.cumsum(0)
|
|
1447
|
-
gini = (2 * torch.arange(1, n + 1, dtype=
|
|
1552
|
+
gini = (2 * torch.arange(1, n + 1, dtype=sorted_abs.dtype, device=sorted_abs.device) @ sorted_abs - (n + 1) * sorted_abs.sum()) / (n * sorted_abs.sum() + 1e-10)
|
|
1448
1553
|
|
|
1449
1554
|
# Sparse score: high if few dimensions are active
|
|
1450
1555
|
sparse_score = 0.4 * (1 - float(l1_l2_ratio)) + 0.3 * (1 - float(active_fraction)) + 0.3 * float(gini)
|
|
@@ -1527,11 +1632,11 @@ def _compute_dip_statistic(data: torch.Tensor) -> float:
|
|
|
1527
1632
|
return 0.0
|
|
1528
1633
|
|
|
1529
1634
|
# Empirical CDF
|
|
1530
|
-
ecdf = torch.arange(1, n + 1, dtype=
|
|
1635
|
+
ecdf = torch.arange(1, n + 1, dtype=sorted_data.dtype, device=sorted_data.device) / n
|
|
1531
1636
|
|
|
1532
1637
|
# Greatest convex minorant and least concave majorant
|
|
1533
1638
|
# Simplified: measure deviation from uniform
|
|
1534
|
-
uniform = torch.linspace(0, 1, n)
|
|
1639
|
+
uniform = torch.linspace(0, 1, n, dtype=sorted_data.dtype, device=sorted_data.device)
|
|
1535
1640
|
|
|
1536
1641
|
# Kolmogorov-Smirnov like statistic
|
|
1537
1642
|
ks_stat = (ecdf - uniform).abs().max()
|
|
@@ -1547,83 +1652,85 @@ def _detect_orthogonal_structure(
|
|
|
1547
1652
|
) -> StructureScore:
|
|
1548
1653
|
"""Detect if behavior is encoded in multiple orthogonal/independent subspaces.
|
|
1549
1654
|
|
|
1550
|
-
Orthogonal structure means
|
|
1551
|
-
|
|
1552
|
-
|
|
1553
|
-
|
|
1554
|
-
if diff_vectors.shape[0] < cfg.num_components:
|
|
1555
|
-
return StructureScore(StructureType.ORTHOGONAL, 0.0, 0.0, {"reason": "insufficient_data"})
|
|
1655
|
+
Orthogonal structure means:
|
|
1656
|
+
- Multiple difference vectors point in INDEPENDENT directions
|
|
1657
|
+
- Low cosine similarity between difference vectors (near 0)
|
|
1658
|
+
- NOT correlated (that's cone) and NOT single direction (that's linear)
|
|
1556
1659
|
|
|
1660
|
+
This is the OPPOSITE of cone - if cosine sim is low, it's orthogonal.
|
|
1661
|
+
Uses raw cosine similarity like the cone detector for consistency.
|
|
1662
|
+
"""
|
|
1557
1663
|
try:
|
|
1558
|
-
#
|
|
1559
|
-
|
|
1560
|
-
|
|
1561
|
-
|
|
1562
|
-
|
|
1563
|
-
|
|
1564
|
-
|
|
1565
|
-
|
|
1566
|
-
|
|
1567
|
-
|
|
1568
|
-
|
|
1569
|
-
|
|
1570
|
-
|
|
1571
|
-
|
|
1572
|
-
|
|
1573
|
-
#
|
|
1574
|
-
|
|
1575
|
-
|
|
1576
|
-
#
|
|
1577
|
-
|
|
1578
|
-
|
|
1579
|
-
|
|
1580
|
-
|
|
1581
|
-
|
|
1582
|
-
|
|
1583
|
-
|
|
1584
|
-
|
|
1585
|
-
|
|
1586
|
-
#
|
|
1587
|
-
|
|
1588
|
-
|
|
1664
|
+
# Compute raw difference vectors (same as cone detector)
|
|
1665
|
+
n_pairs = min(pos_tensor.shape[0], neg_tensor.shape[0])
|
|
1666
|
+
if n_pairs < 3:
|
|
1667
|
+
return StructureScore(StructureType.ORTHOGONAL, 0.0, 0.0, {"reason": "insufficient_pairs"})
|
|
1668
|
+
|
|
1669
|
+
diff_vectors_raw = pos_tensor[:n_pairs] - neg_tensor[:n_pairs]
|
|
1670
|
+
|
|
1671
|
+
# Normalize difference vectors
|
|
1672
|
+
norms = diff_vectors_raw.norm(dim=1, keepdim=True)
|
|
1673
|
+
valid_mask = (norms.squeeze() > 1e-8)
|
|
1674
|
+
if valid_mask.sum() < 3:
|
|
1675
|
+
return StructureScore(StructureType.ORTHOGONAL, 0.0, 0.0, {"reason": "zero_differences"})
|
|
1676
|
+
|
|
1677
|
+
diff_normalized = diff_vectors_raw[valid_mask] / norms[valid_mask]
|
|
1678
|
+
|
|
1679
|
+
# Compute pairwise cosine similarity matrix
|
|
1680
|
+
cos_sim_matrix = diff_normalized @ diff_normalized.T
|
|
1681
|
+
|
|
1682
|
+
# Get off-diagonal elements
|
|
1683
|
+
n = cos_sim_matrix.shape[0]
|
|
1684
|
+
mask = ~torch.eye(n, dtype=torch.bool, device=cos_sim_matrix.device)
|
|
1685
|
+
off_diagonal = cos_sim_matrix[mask]
|
|
1686
|
+
|
|
1687
|
+
# Raw cosine similarity statistics
|
|
1688
|
+
mean_cos_sim = float(off_diagonal.mean())
|
|
1689
|
+
std_cos_sim = float(off_diagonal.std())
|
|
1690
|
+
abs_mean_cos_sim = float(off_diagonal.abs().mean())
|
|
1691
|
+
|
|
1692
|
+
# Fraction near zero (truly orthogonal)
|
|
1693
|
+
near_zero_fraction = float((off_diagonal.abs() < 0.2).float().mean())
|
|
1694
|
+
|
|
1695
|
+
# Orthogonal = LOW cosine similarity (opposite of cone)
|
|
1696
|
+
# Ideal orthogonal: mean cosine sim near 0, low absolute mean
|
|
1697
|
+
|
|
1698
|
+
if abs_mean_cos_sim < 0.1:
|
|
1699
|
+
# Very low correlation = strong orthogonal
|
|
1700
|
+
orthogonal_score = 0.8 + 0.2 * (1 - abs_mean_cos_sim / 0.1)
|
|
1701
|
+
elif abs_mean_cos_sim < 0.2:
|
|
1702
|
+
# Low correlation = moderate orthogonal
|
|
1703
|
+
orthogonal_score = 0.5 + 0.3 * (1 - (abs_mean_cos_sim - 0.1) / 0.1)
|
|
1704
|
+
elif abs_mean_cos_sim < 0.4:
|
|
1705
|
+
# Moderate correlation = weak orthogonal
|
|
1706
|
+
orthogonal_score = 0.2 + 0.3 * (1 - (abs_mean_cos_sim - 0.2) / 0.2)
|
|
1707
|
+
else:
|
|
1708
|
+
# High correlation = not orthogonal (probably cone or linear)
|
|
1709
|
+
orthogonal_score = max(0, 0.2 * (1 - (abs_mean_cos_sim - 0.4) / 0.6))
|
|
1589
1710
|
|
|
1590
|
-
# Check separation
|
|
1711
|
+
# Check if there's meaningful separation (not just noise)
|
|
1591
1712
|
mean_diff = pos_tensor.mean(dim=0) - neg_tensor.mean(dim=0)
|
|
1592
1713
|
separation_strength = mean_diff.norm() / (pos_tensor.std() + neg_tensor.std() + 1e-8)
|
|
1593
|
-
has_separation = min(float(separation_strength) /
|
|
1594
|
-
|
|
1595
|
-
#
|
|
1596
|
-
|
|
1597
|
-
|
|
1598
|
-
|
|
1599
|
-
|
|
1600
|
-
|
|
1601
|
-
if significant_dims < 2:
|
|
1602
|
-
# Too few dimensions = linear
|
|
1603
|
-
orthogonal_score = 0.2
|
|
1604
|
-
elif significant_dims > 10:
|
|
1605
|
-
# Too many = likely noise, not structure
|
|
1606
|
-
orthogonal_score = 0.3 * has_separation
|
|
1607
|
-
else:
|
|
1608
|
-
# Reasonable number of dimensions
|
|
1609
|
-
# Check if it's not dominated by first (would be linear)
|
|
1610
|
-
# and not too spread (would be noise)
|
|
1611
|
-
structure_score = (
|
|
1612
|
-
0.3 * (1 - first_var) + # Not dominated by one direction
|
|
1613
|
-
0.3 * min(significant_dims / 4, 1.0) + # 2-4 directions is ideal
|
|
1614
|
-
0.4 * has_separation # Must have separation
|
|
1615
|
-
)
|
|
1616
|
-
orthogonal_score = structure_score * 0.8 # Scale down - orthogonal is rare
|
|
1714
|
+
has_separation = min(float(separation_strength) / 2, 1.0)
|
|
1715
|
+
|
|
1716
|
+
# Orthogonal without separation is just noise
|
|
1717
|
+
if has_separation < 0.3:
|
|
1718
|
+
orthogonal_score *= 0.3 # Heavy penalty
|
|
1719
|
+
|
|
1720
|
+
# Confidence based on consistency and sample size
|
|
1721
|
+
confidence = near_zero_fraction * min(1.0, n_pairs / 20)
|
|
1617
1722
|
|
|
1618
1723
|
return StructureScore(
|
|
1619
1724
|
StructureType.ORTHOGONAL,
|
|
1620
1725
|
score=float(orthogonal_score),
|
|
1621
|
-
confidence=
|
|
1726
|
+
confidence=float(confidence),
|
|
1622
1727
|
details={
|
|
1623
|
-
"
|
|
1624
|
-
"
|
|
1625
|
-
"
|
|
1626
|
-
"
|
|
1728
|
+
"raw_mean_cosine_similarity": mean_cos_sim,
|
|
1729
|
+
"raw_abs_mean_cosine_similarity": abs_mean_cos_sim,
|
|
1730
|
+
"raw_std_cosine_similarity": std_cos_sim,
|
|
1731
|
+
"near_zero_fraction": near_zero_fraction,
|
|
1732
|
+
"separation_strength": float(separation_strength),
|
|
1733
|
+
"n_valid_pairs": int(valid_mask.sum()),
|
|
1627
1734
|
}
|
|
1628
1735
|
)
|
|
1629
1736
|
except Exception as e:
|
|
@@ -1652,4 +1759,1302 @@ def _generate_recommendation(best_structure: StructureType, all_scores: Dict[str
|
|
|
1652
1759
|
if second_best[1].score > 0.6:
|
|
1653
1760
|
base_rec += f" (Also consider {second_best[0]}: score {second_best[1].score:.2f})"
|
|
1654
1761
|
|
|
1655
|
-
return base_rec
|
|
1762
|
+
return base_rec
|
|
1763
|
+
|
|
1764
|
+
|
|
1765
|
+
# =============================================================================
|
|
1766
|
+
# Multi-Layer Geometry Analysis
|
|
1767
|
+
# =============================================================================
|
|
1768
|
+
|
|
1769
|
+
@dataclass
class MultiLayerGeometryConfig:
    """Options for the multi-layer geometry analysis.

    ``num_components`` and ``optimization_steps`` are passed on to the
    ``GeometryAnalysisConfig`` used for each individual analysis.  The other
    fields select how layers are merged and which groups of layers are
    analyzed.
    """

    # Settings handed to every individual geometry analysis.
    num_components: int = 5
    optimization_steps: int = 50

    # How activations of several layers are merged:
    # "concat", "mean" or "weighted".
    combination_method: str = "concat"

    # One switch per family of analyses.
    analyze_per_layer: bool = True
    analyze_combined: bool = True
    analyze_subsets: bool = True  # early / middle / late
    analyze_pairs: bool = True  # all pairs of layers
    analyze_adjacent: bool = True  # adjacent layer pairs
    analyze_skip: bool = True  # every other layer, every third, etc.

    # Extra user-defined layer groups; None means no custom analysis.
    analyze_custom: Optional[List[List[int]]] = None

    # Cap on the number of layer pairs that get analyzed.
    max_pair_combinations: int = 50
|
|
1784
|
+
|
|
1785
|
+
|
|
1786
|
+
@dataclass
class LayerGeometryResult:
    """Condensed outcome of the geometry analysis of one layer."""

    layer: int  # index of the analyzed layer
    best_structure: StructureType  # structure type reported as best
    best_score: float  # score of that best structure
    all_scores: Dict[str, float]  # score for every structure name
|
|
1793
|
+
|
|
1794
|
+
|
|
1795
|
+
@dataclass
class MultiLayerGeometryResult:
    """Everything produced by one multi-layer geometry analysis run."""

    # --- Raw analysis results, grouped by how the layers were selected ---

    # One summary per individually analyzed layer.
    per_layer_results: Dict[int, LayerGeometryResult]
    # Analysis of all layers merged together (None when it was not run).
    combined_result: Optional[GeometryAnalysisResult]
    # Layer subsets, keyed by name (e.g. 'early', 'middle', 'late').
    layer_subset_results: Dict[str, GeometryAnalysisResult]
    # Pairs of layers, keyed like 'L1+L5' or 'L2+L8'.
    layer_pair_results: Dict[str, GeometryAnalysisResult]
    # Adjacent layer pairs (e.g. 'L1+L2', 'L2+L3').
    adjacent_pair_results: Dict[str, GeometryAnalysisResult]
    # Skip patterns (e.g. 'every_2nd', 'every_3rd').
    skip_results: Dict[str, GeometryAnalysisResult]
    # User-defined layer combinations.
    custom_results: Dict[str, GeometryAnalysisResult]

    # --- Best single layer ---

    # Layer with the strongest structure detection.
    best_single_layer: int
    # Structure type detected at that layer.
    best_single_layer_structure: StructureType
    # Score reached at that layer.
    best_single_layer_score: float

    # --- Best combination ---

    # Name of the best layer combination, set only if it beats the best
    # single layer.
    best_combination: Optional[str]
    # Score of the best combination.
    best_combination_score: float
    # Structure type detected for the best combination.
    best_combination_structure: Optional[StructureType]

    # --- Cross-layer summaries ---

    # Text describing whether combining layers improves over a single layer.
    combined_vs_single: str
    # How much the layers agree on the structure type (0-1).
    layer_agreement: float
    # Per structure name, the scores across layers in depth order.
    structure_by_depth: Dict[str, List[float]]
    # Every analyzed combination as (name, score, structure), best first.
    all_combinations_ranked: List[Tuple[str, float, StructureType]]
    # Human-readable recommendation derived from the analysis.
    recommendation: str
|
|
1852
|
+
|
|
1853
|
+
|
|
1854
|
+
def detect_geometry_multi_layer(
    pos_activations_by_layer: Dict[int, torch.Tensor],
    neg_activations_by_layer: Dict[int, torch.Tensor],
    config: MultiLayerGeometryConfig | None = None,
) -> MultiLayerGeometryResult:
    """
    Detect geometric structure across multiple layers.

    Analyzes (each family can be switched off through ``config``):
    1. Each layer individually
    2. All layers combined (using ``config.combination_method``)
    3. Layer subsets (early, middle, late, first_half, second_half)
    4. Layer pairs (at most ``config.max_pair_combinations`` of them)
    5. Adjacent layer pairs (L1+L2, L2+L3, etc.)
    6. Skip patterns (every 2nd, every 3rd, first+last, first+middle+last)
    7. Custom layer combinations
    8. How structure varies by depth

    Only layers that appear in BOTH input dictionaries are analyzed; a layer
    missing on one side cannot be compared and is ignored instead of
    triggering a ``KeyError`` halfway through the analysis.

    Arguments:
        pos_activations_by_layer: Dict mapping layer index to positive activations [N, hidden_dim]
        neg_activations_by_layer: Dict mapping layer index to negative activations [N, hidden_dim]
        config: Analysis configuration (defaults to ``MultiLayerGeometryConfig()``)

    Returns:
        MultiLayerGeometryResult with comprehensive multi-layer analysis

    Raises:
        ValueError: if no layers are provided, or if no layer has both
            positive and negative activations.
    """
    from itertools import combinations

    cfg = config or MultiLayerGeometryConfig()
    geo_cfg = GeometryAnalysisConfig(
        num_components=cfg.num_components,
        optimization_steps=cfg.optimization_steps,
    )

    if not pos_activations_by_layer:
        raise ValueError("No layers provided")

    # A layer is only analyzable when both polarities are available.
    layers = sorted(set(pos_activations_by_layer) & set(neg_activations_by_layer))
    if not layers:
        raise ValueError("No layer has both positive and negative activations")
    n_layers = len(layers)

    # Every analysis (single layers included) is recorded here for the final
    # ranking; insertion order decides ties because the sort below is stable.
    all_combo_results: Dict[str, GeometryAnalysisResult] = {}

    def _analyze(name: str, combo_layers: List[int]) -> GeometryAnalysisResult:
        # Merge the given layers, run the structure detection and record it.
        combo_pos, combo_neg = _combine_layer_activations(
            pos_activations_by_layer,
            neg_activations_by_layer,
            combo_layers,
            cfg.combination_method,
        )
        result = detect_geometry_structure(combo_pos, combo_neg, geo_cfg)
        all_combo_results[name] = result
        return result

    # 1. Analyze each layer individually
    per_layer_results: Dict[int, LayerGeometryResult] = {}
    structure_by_depth: Dict[str, List[float]] = {
        "linear": [], "cone": [], "cluster": [], "manifold": [],
        "sparse": [], "bimodal": [], "orthogonal": [],
    }

    if cfg.analyze_per_layer:
        for layer in layers:
            result = detect_geometry_structure(
                pos_activations_by_layer[layer],
                neg_activations_by_layer[layer],
                geo_cfg,
            )
            all_scores = {name: score.score for name, score in result.all_scores.items()}
            per_layer_results[layer] = LayerGeometryResult(
                layer=layer,
                best_structure=result.best_structure,
                best_score=result.best_score,
                all_scores=all_scores,
            )
            all_combo_results[f"L{layer}"] = result

            # Structure names outside the known set are not tracked by depth.
            for struct_name, score in all_scores.items():
                if struct_name in structure_by_depth:
                    structure_by_depth[struct_name].append(score)

    # 2. Find best single layer (the lowest layer wins ties)
    if per_layer_results:
        best_single_layer = max(
            per_layer_results, key=lambda l: per_layer_results[l].best_score
        )
        best_single_layer_structure = per_layer_results[best_single_layer].best_structure
        best_single_layer_score = per_layer_results[best_single_layer].best_score
    else:
        # Per-layer analysis disabled: fall back to neutral placeholders.
        best_single_layer = layers[0]
        best_single_layer_structure = StructureType.UNKNOWN
        best_single_layer_score = 0.0

    # 3. Analyze all layers combined
    combined_result = None
    if cfg.analyze_combined and n_layers > 1:
        combined_result = _analyze("all_layers", layers)

    # 4. Analyze layer subsets (thirds and halves). With at least three
    #    layers every slice below contains at least one layer.
    layer_subset_results: Dict[str, GeometryAnalysisResult] = {}
    if cfg.analyze_subsets and n_layers >= 3:
        third = n_layers // 3
        half = n_layers // 2
        subsets = [
            ("early", layers[:third]),
            ("middle", layers[third:2 * third]),
            ("late", layers[2 * third:]),
            ("first_half", layers[:half]),
            ("second_half", layers[half:]),
        ]
        for subset_name, subset_layers in subsets:
            layer_subset_results[subset_name] = _analyze(subset_name, subset_layers)

    # 5. Analyze layer pairs, stopping at the configured cap
    layer_pair_results: Dict[str, GeometryAnalysisResult] = {}
    if cfg.analyze_pairs and n_layers >= 2:
        for pair_index, (l1, l2) in enumerate(combinations(layers, 2)):
            if pair_index >= cfg.max_pair_combinations:
                break
            pair_name = f"L{l1}+L{l2}"
            layer_pair_results[pair_name] = _analyze(pair_name, [l1, l2])

    # 6. Analyze adjacent layer pairs
    adjacent_pair_results: Dict[str, GeometryAnalysisResult] = {}
    if cfg.analyze_adjacent and n_layers >= 2:
        for l1, l2 in zip(layers, layers[1:]):
            pair_name = f"adj_L{l1}+L{l2}"
            adjacent_pair_results[pair_name] = _analyze(pair_name, [l1, l2])

    # 7. Analyze skip patterns. With four or more layers each selection
    #    below holds at least two layers.
    skip_results: Dict[str, GeometryAnalysisResult] = {}
    if cfg.analyze_skip and n_layers >= 4:
        skip_patterns = [("every_2nd", layers[::2])]
        if n_layers >= 6:
            skip_patterns.append(("every_3rd", layers[::3]))
        skip_patterns.append(("first_last", [layers[0], layers[-1]]))
        skip_patterns.append(
            ("first_mid_last", [layers[0], layers[n_layers // 2], layers[-1]])
        )
        for skip_name, skip_layers in skip_patterns:
            skip_results[skip_name] = _analyze(skip_name, skip_layers)

    # 8. Analyze custom combinations (unknown layers are dropped; the index
    #    in the name refers to the position in cfg.analyze_custom)
    custom_results: Dict[str, GeometryAnalysisResult] = {}
    if cfg.analyze_custom:
        available = set(layers)
        for i, custom_layers in enumerate(cfg.analyze_custom):
            valid_layers = [l for l in custom_layers if l in available]
            if valid_layers:
                custom_name = f"custom_{i}_L" + "+L".join(map(str, valid_layers))
                custom_results[custom_name] = _analyze(custom_name, valid_layers)

    # 9. Compute layer agreement: share of layers voting for the most
    #    common best structure
    if per_layer_results:
        structures = [r.best_structure for r in per_layer_results.values()]
        top_votes = max(structures.count(s) for s in set(structures))
        layer_agreement = top_votes / len(structures)
    else:
        layer_agreement = 0.0

    # 10. Rank all combinations and find best
    all_combinations_ranked = sorted(
        [(name, r.best_score, r.best_structure) for name, r in all_combo_results.items()],
        key=lambda x: x[1],
        reverse=True,
    )

    # A combination is only reported when it strictly beats the best single
    # layer; otherwise the single-layer values are mirrored.
    best_combination = None
    best_combination_score = best_single_layer_score
    best_combination_structure = best_single_layer_structure
    if all_combinations_ranked:
        top_name, top_score, top_structure = all_combinations_ranked[0]
        if top_score > best_single_layer_score:
            best_combination = top_name
            best_combination_score = top_score
            best_combination_structure = top_structure

    # 11. Compare combined vs single (differences up to 0.1 count as similar)
    if combined_result and per_layer_results:
        if combined_result.best_score > best_single_layer_score + 0.1:
            combined_vs_single = f"Combined ({combined_result.best_score:.2f}) better than single layer ({best_single_layer_score:.2f})"
        elif best_single_layer_score > combined_result.best_score + 0.1:
            combined_vs_single = f"Single layer {best_single_layer} ({best_single_layer_score:.2f}) better than combined ({combined_result.best_score:.2f})"
        else:
            combined_vs_single = f"Similar performance: combined={combined_result.best_score:.2f}, single={best_single_layer_score:.2f}"
    else:
        combined_vs_single = "No comparison available"

    # 12. Generate recommendation
    recommendation = _generate_multi_layer_recommendation_v2(
        per_layer_results, combined_result, layer_subset_results,
        layer_pair_results, skip_results,
        best_single_layer, best_single_layer_structure, best_single_layer_score,
        best_combination, best_combination_score, best_combination_structure,
        layer_agreement, all_combinations_ranked,
    )

    return MultiLayerGeometryResult(
        per_layer_results=per_layer_results,
        combined_result=combined_result,
        layer_subset_results=layer_subset_results,
        layer_pair_results=layer_pair_results,
        adjacent_pair_results=adjacent_pair_results,
        skip_results=skip_results,
        custom_results=custom_results,
        best_single_layer=best_single_layer,
        best_single_layer_structure=best_single_layer_structure,
        best_single_layer_score=best_single_layer_score,
        best_combination=best_combination,
        best_combination_score=best_combination_score,
        best_combination_structure=best_combination_structure,
        combined_vs_single=combined_vs_single,
        layer_agreement=layer_agreement,
        structure_by_depth=structure_by_depth,
        all_combinations_ranked=all_combinations_ranked,
        recommendation=recommendation,
    )
|
|
2127
|
+
|
|
2128
|
+
|
|
2129
|
+
def _combine_layer_activations(
|
|
2130
|
+
pos_by_layer: Dict[int, torch.Tensor],
|
|
2131
|
+
neg_by_layer: Dict[int, torch.Tensor],
|
|
2132
|
+
layers: List[int],
|
|
2133
|
+
method: str = "concat",
|
|
2134
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
2135
|
+
"""Combine activations from multiple layers."""
|
|
2136
|
+
pos_acts = [pos_by_layer[l] for l in layers if l in pos_by_layer]
|
|
2137
|
+
neg_acts = [neg_by_layer[l] for l in layers if l in neg_by_layer]
|
|
2138
|
+
|
|
2139
|
+
if not pos_acts or not neg_acts:
|
|
2140
|
+
raise ValueError("No activations found for specified layers")
|
|
2141
|
+
|
|
2142
|
+
if method == "concat":
|
|
2143
|
+
combined_pos = torch.cat(pos_acts, dim=-1)
|
|
2144
|
+
combined_neg = torch.cat(neg_acts, dim=-1)
|
|
2145
|
+
elif method == "mean":
|
|
2146
|
+
combined_pos = torch.stack(pos_acts, dim=0).mean(dim=0)
|
|
2147
|
+
combined_neg = torch.stack(neg_acts, dim=0).mean(dim=0)
|
|
2148
|
+
elif method == "weighted":
|
|
2149
|
+
weights = torch.linspace(0.5, 1.5, len(pos_acts))
|
|
2150
|
+
weights = weights / weights.sum()
|
|
2151
|
+
combined_pos = sum(w * a for w, a in zip(weights, pos_acts))
|
|
2152
|
+
combined_neg = sum(w * a for w, a in zip(weights, neg_acts))
|
|
2153
|
+
else:
|
|
2154
|
+
raise ValueError(f"Unknown combination method: {method}")
|
|
2155
|
+
|
|
2156
|
+
return combined_pos, combined_neg
|
|
2157
|
+
|
|
2158
|
+
|
|
2159
|
+
def _generate_multi_layer_recommendation(
|
|
2160
|
+
per_layer_results: Dict[int, LayerGeometryResult],
|
|
2161
|
+
combined_result: Optional[GeometryAnalysisResult],
|
|
2162
|
+
layer_subset_results: Dict[str, GeometryAnalysisResult],
|
|
2163
|
+
best_single_layer: int,
|
|
2164
|
+
best_single_layer_structure: StructureType,
|
|
2165
|
+
best_single_layer_score: float,
|
|
2166
|
+
layer_agreement: float,
|
|
2167
|
+
) -> str:
|
|
2168
|
+
"""Generate recommendation based on multi-layer analysis."""
|
|
2169
|
+
parts = []
|
|
2170
|
+
|
|
2171
|
+
# Layer agreement insight
|
|
2172
|
+
if layer_agreement > 0.8:
|
|
2173
|
+
parts.append(f"High layer agreement ({layer_agreement:.0%}): structure is consistent across depth.")
|
|
2174
|
+
elif layer_agreement < 0.4:
|
|
2175
|
+
parts.append(f"Low layer agreement ({layer_agreement:.0%}): different structures at different depths.")
|
|
2176
|
+
|
|
2177
|
+
# Best layer recommendation
|
|
2178
|
+
parts.append(f"Best single layer: {best_single_layer} with {best_single_layer_structure.value} ({best_single_layer_score:.2f}).")
|
|
2179
|
+
|
|
2180
|
+
# Combined vs single
|
|
2181
|
+
if combined_result:
|
|
2182
|
+
if combined_result.best_score > best_single_layer_score + 0.1:
|
|
2183
|
+
parts.append(f"Combined layers improve detection ({combined_result.best_score:.2f} vs {best_single_layer_score:.2f}). Use multi-layer steering.")
|
|
2184
|
+
else:
|
|
2185
|
+
parts.append(f"Single layer is sufficient. Target layer {best_single_layer}.")
|
|
2186
|
+
|
|
2187
|
+
# Layer subset insights
|
|
2188
|
+
if layer_subset_results:
|
|
2189
|
+
subset_scores = {name: r.best_score for name, r in layer_subset_results.items()}
|
|
2190
|
+
best_subset = max(subset_scores.keys(), key=lambda k: subset_scores[k])
|
|
2191
|
+
if subset_scores[best_subset] > best_single_layer_score:
|
|
2192
|
+
parts.append(f"'{best_subset}' layers show strongest structure ({subset_scores[best_subset]:.2f}).")
|
|
2193
|
+
|
|
2194
|
+
return " ".join(parts)
|
|
2195
|
+
|
|
2196
|
+
|
|
2197
|
+
def _generate_multi_layer_recommendation_v2(
|
|
2198
|
+
per_layer_results: Dict[int, LayerGeometryResult],
|
|
2199
|
+
combined_result: Optional[GeometryAnalysisResult],
|
|
2200
|
+
layer_subset_results: Dict[str, GeometryAnalysisResult],
|
|
2201
|
+
layer_pair_results: Dict[str, GeometryAnalysisResult],
|
|
2202
|
+
skip_results: Dict[str, GeometryAnalysisResult],
|
|
2203
|
+
best_single_layer: int,
|
|
2204
|
+
best_single_layer_structure: StructureType,
|
|
2205
|
+
best_single_layer_score: float,
|
|
2206
|
+
best_combination: Optional[str],
|
|
2207
|
+
best_combination_score: float,
|
|
2208
|
+
best_combination_structure: Optional[StructureType],
|
|
2209
|
+
layer_agreement: float,
|
|
2210
|
+
all_combinations_ranked: List[Tuple[str, float, StructureType]],
|
|
2211
|
+
) -> str:
|
|
2212
|
+
"""Generate comprehensive recommendation based on multi-layer analysis."""
|
|
2213
|
+
parts = []
|
|
2214
|
+
|
|
2215
|
+
# Layer agreement insight
|
|
2216
|
+
if layer_agreement > 0.8:
|
|
2217
|
+
parts.append(f"High layer agreement ({layer_agreement:.0%}): consistent structure across depth.")
|
|
2218
|
+
elif layer_agreement < 0.4:
|
|
2219
|
+
parts.append(f"Low layer agreement ({layer_agreement:.0%}): structure varies by depth.")
|
|
2220
|
+
else:
|
|
2221
|
+
parts.append(f"Moderate layer agreement ({layer_agreement:.0%}).")
|
|
2222
|
+
|
|
2223
|
+
# Overall best recommendation
|
|
2224
|
+
if best_combination and best_combination_score > best_single_layer_score + 0.05:
|
|
2225
|
+
improvement = best_combination_score - best_single_layer_score
|
|
2226
|
+
parts.append(
|
|
2227
|
+
f"BEST: '{best_combination}' ({best_combination_structure.value}: {best_combination_score:.2f}) "
|
|
2228
|
+
f"outperforms single layer {best_single_layer} by {improvement:.2f}."
|
|
2229
|
+
)
|
|
2230
|
+
else:
|
|
2231
|
+
parts.append(
|
|
2232
|
+
f"BEST: Layer {best_single_layer} ({best_single_layer_structure.value}: {best_single_layer_score:.2f}). "
|
|
2233
|
+
f"Multi-layer combinations don't improve detection."
|
|
2234
|
+
)
|
|
2235
|
+
|
|
2236
|
+
# Top 3 combinations summary
|
|
2237
|
+
if len(all_combinations_ranked) >= 3:
|
|
2238
|
+
top3 = all_combinations_ranked[:3]
|
|
2239
|
+
top3_str = ", ".join([f"{name}={score:.2f}" for name, score, _ in top3])
|
|
2240
|
+
parts.append(f"Top 3: {top3_str}.")
|
|
2241
|
+
|
|
2242
|
+
# Specific pattern insights
|
|
2243
|
+
if skip_results:
|
|
2244
|
+
skip_scores = {name: r.best_score for name, r in skip_results.items()}
|
|
2245
|
+
best_skip = max(skip_scores.keys(), key=lambda k: skip_scores[k])
|
|
2246
|
+
if skip_scores[best_skip] > best_single_layer_score:
|
|
2247
|
+
parts.append(f"Skip pattern '{best_skip}' is effective ({skip_scores[best_skip]:.2f}).")
|
|
2248
|
+
|
|
2249
|
+
if layer_pair_results:
|
|
2250
|
+
pair_scores = {name: r.best_score for name, r in layer_pair_results.items()}
|
|
2251
|
+
best_pair = max(pair_scores.keys(), key=lambda k: pair_scores[k])
|
|
2252
|
+
best_pair_score = pair_scores[best_pair]
|
|
2253
|
+
if best_pair_score > best_single_layer_score:
|
|
2254
|
+
parts.append(f"Layer pair '{best_pair}' shows synergy ({best_pair_score:.2f}).")
|
|
2255
|
+
|
|
2256
|
+
# Depth pattern analysis
|
|
2257
|
+
if per_layer_results and len(per_layer_results) >= 3:
|
|
2258
|
+
layers_sorted = sorted(per_layer_results.keys())
|
|
2259
|
+
early_score = per_layer_results[layers_sorted[0]].best_score
|
|
2260
|
+
late_score = per_layer_results[layers_sorted[-1]].best_score
|
|
2261
|
+
if late_score > early_score + 0.2:
|
|
2262
|
+
parts.append("Later layers show stronger structure than early layers.")
|
|
2263
|
+
elif early_score > late_score + 0.2:
|
|
2264
|
+
parts.append("Early layers show stronger structure than later layers.")
|
|
2265
|
+
|
|
2266
|
+
return " ".join(parts)
|
|
2267
|
+
|
|
2268
|
+
|
|
2269
|
+
def detect_geometry_all_layers(
    pairs_with_activations: List,
    layers: Optional[List[int]] = None,
    config: MultiLayerGeometryConfig | None = None,
) -> MultiLayerGeometryResult:
    """
    Run multi-layer geometry detection on pairs that already carry activations.

    The per-pair ``layers_activations`` mappings of the positive and negative
    responses are regrouped by layer, stacked into one tensor per layer and
    handed to ``detect_geometry_multi_layer``.

    Arguments:
        pairs_with_activations: List of ContrastivePair objects with layers_activations populated
        layers: Layers to keep; ``None`` keeps every layer found on the pairs
        config: Analysis configuration, forwarded unchanged

    Returns:
        MultiLayerGeometryResult

    Raises:
        ValueError: If ``pairs_with_activations`` is empty.
    """
    if not pairs_with_activations:
        raise ValueError("No pairs provided")

    def _gather(layer_acts, bucket) -> None:
        # Layer keys are normalised with int(); a missing activation is kept
        # as None here and filtered out just before stacking.
        for raw_key, tensor in layer_acts.items():
            layer_idx = int(raw_key)
            if layers is not None and layer_idx not in layers:
                continue
            bucket.setdefault(layer_idx, []).append(
                None if tensor is None else tensor.float()
            )

    pos_by_layer: Dict[int, List[torch.Tensor]] = {}
    neg_by_layer: Dict[int, List[torch.Tensor]] = {}

    for pair in pairs_with_activations:
        positive_acts = pair.positive_response.layers_activations
        negative_acts = pair.negative_response.layers_activations
        _gather(positive_acts, pos_by_layer)
        _gather(negative_acts, neg_by_layer)

    # Only layers that have at least one usable activation on both sides
    # are passed on to the analysis.
    pos_tensors = {}
    neg_tensors = {}
    for layer_idx, pos_list in pos_by_layer.items():
        kept_pos = [t for t in pos_list if t is not None]
        kept_neg = [t for t in neg_by_layer.get(layer_idx, []) if t is not None]
        if not kept_pos or not kept_neg:
            continue
        pos_tensors[layer_idx] = torch.stack(kept_pos)
        neg_tensors[layer_idx] = torch.stack(kept_neg)

    return detect_geometry_multi_layer(pos_tensors, neg_tensors, config)
|
|
2321
|
+
|
|
2322
|
+
|
|
2323
|
+
@dataclass
class ExhaustiveCombinationResult:
    """
    Result for a single layer combination.

    Attributes:
        layers: Layer indices that make up the combination.
        best_structure: Structure type reported as best for this combination.
        best_score: Score of ``best_structure``.
        all_scores: Score per structure name.

    Instances are orderable by ``(best_score, layers)``. The combination
    searches in this module store ``(best_score, result)`` tuples in a
    ``heapq``; when two scores tie, tuple comparison falls through to the
    result objects, which would raise ``TypeError`` without an ordering.
    """
    layers: Tuple[int, ...]
    best_structure: StructureType
    best_score: float
    all_scores: Dict[str, float]

    def __lt__(self, other: object) -> bool:
        """Order by score first, then by layer tuple as a deterministic tie-break."""
        if not isinstance(other, ExhaustiveCombinationResult):
            return NotImplemented
        return (self.best_score, self.layers) < (other.best_score, other.layers)
|
|
2330
|
+
|
|
2331
|
+
|
|
2332
|
+
@dataclass
class ExhaustiveGeometryAnalysisResult:
    """
    Results from exhaustive layer combination analysis.

    Returned by the layer-combination searches in this module
    (``detect_geometry_exhaustive``, ``detect_geometry_limited``,
    ``detect_geometry_contiguous`` and ``detect_geometry_smart``).
    """

    total_combinations: int
    """Total number of combinations tested."""

    # NOTE: the searches in this module retain only their ``top_k`` best
    # combinations, so this list can be shorter than ``total_combinations``.
    all_results: List[ExhaustiveCombinationResult]
    """All results, sorted by best_score descending."""

    # Empty tuple when no result was retained.
    best_combination: Tuple[int, ...]
    """Layer combination with highest score."""

    # 0.0 when no result was retained.
    best_score: float
    """Highest score achieved."""

    # StructureType.UNKNOWN when no result was retained.
    best_structure: StructureType
    """Structure type at best combination."""

    # First ten entries of ``all_results``.
    top_10: List[ExhaustiveCombinationResult]
    """Top 10 combinations."""

    # Falls back to the lowest layer index when no single-layer
    # combination was evaluated.
    single_layer_best: int
    """Best single layer."""

    # 0.0 when no single-layer combination was evaluated.
    single_layer_best_score: float
    """Score of best single layer."""

    # Computed as ``best_score > single_layer_best_score``.
    combination_beats_single: bool
    """Whether any multi-layer combination beats best single layer."""

    # Computed as ``best_score - single_layer_best_score``.
    improvement_over_single: float
    """How much best combination improves over best single layer."""

    # Produced by ``_analyze_combination_patterns``; keys include
    # "layer_frequency_in_top", "most_important_layers",
    # "size_distribution_in_top", "best_score_by_size" and
    # "optimal_combination_size".
    patterns: Dict[str, Any]
    """Discovered patterns (layer frequency in top combinations, etc.)."""

    # Text produced by ``_generate_exhaustive_recommendation``.
    recommendation: str
    """Final recommendation."""
|
|
2371
|
+
|
|
2372
|
+
|
|
2373
|
+
def detect_geometry_exhaustive(
    pos_activations_by_layer: Dict[int, torch.Tensor],
    neg_activations_by_layer: Dict[int, torch.Tensor],
    max_layers: int = 16,
    combination_method: str = "concat",
    num_components: int = 5,
    progress_callback: Optional[Callable[[int, int], None]] = None,
    top_k: int = 100,
) -> ExhaustiveGeometryAnalysisResult:
    """
    Exhaustively test all 2^N - 1 layer combinations for geometric structure.

    Memory-efficient: combinations are generated lazily and only the top_k
    results are kept in memory.

    Arguments:
        pos_activations_by_layer: Dict mapping layer index to positive activations [N, hidden_dim]
        neg_activations_by_layer: Dict mapping layer index to negative activations [N, hidden_dim]
        max_layers: Only the first ``max_layers`` layers (in ascending index order) are
            considered, which bounds the number of combinations
        combination_method: How to combine layers ("concat", "mean", "weighted");
            forwarded to ``_combine_layer_activations``
        num_components: Number of components for the analysis; forwarded to
            ``GeometryAnalysisConfig``
        progress_callback: Optional callback(current, total) invoked before each
            combination is evaluated
        top_k: Number of top results to keep in memory (default 100); a value
            <= 0 keeps no results

    Returns:
        ExhaustiveGeometryAnalysisResult with top combinations ranked by score
        (ties are ranked in evaluation order, i.e. smaller combinations first)

    Raises:
        ValueError: If no layers are available.
    """
    import heapq
    from itertools import combinations as itertools_combinations

    layers = sorted(pos_activations_by_layer.keys())[:max_layers]
    n_layers = len(layers)

    if n_layers == 0:
        raise ValueError("No layers provided")

    geo_cfg = GeometryAnalysisConfig(num_components=num_components, optimization_steps=50)

    # Number of non-empty subsets, computed without building the list.
    total_combinations = (1 << n_layers) - 1

    # Min-heap of (score, -evaluation_index, result): the root is the weakest
    # retained result, so it is the one evicted when a better score arrives.
    # The integer tie-breaker keeps heapq from ever comparing two result
    # objects (which are not orderable) when scores are equal, and makes the
    # most recently evaluated of equal-scored entries the first to be evicted.
    top_results_heap: List[Tuple[float, int, ExhaustiveCombinationResult]] = []
    single_layer_results: List[ExhaustiveCombinationResult] = []

    # Lazy generator: sizes 1..n, each in itertools.combinations order.
    def combo_generator():
        for r in range(1, n_layers + 1):
            yield from itertools_combinations(layers, r)

    for idx, combo in enumerate(combo_generator(), start=1):
        if progress_callback:
            progress_callback(idx, total_combinations)

        # Single layers are used as-is; larger subsets are merged first.
        if len(combo) == 1:
            layer = combo[0]
            combined_pos = pos_activations_by_layer[layer]
            combined_neg = neg_activations_by_layer[layer]
        else:
            combined_pos, combined_neg = _combine_layer_activations(
                pos_activations_by_layer, neg_activations_by_layer,
                list(combo), combination_method
            )

        result = detect_geometry_structure(combined_pos, combined_neg, geo_cfg)

        all_scores = {name: score.score for name, score in result.all_scores.items()}
        combo_result = ExhaustiveCombinationResult(
            layers=combo,
            best_structure=result.best_structure,
            best_score=result.best_score,
            all_scores=all_scores,
        )

        # Single-layer results are tracked separately so the single-layer
        # baseline is available even if it falls out of the top_k heap.
        if len(combo) == 1:
            single_layer_results.append(combo_result)

        heap_entry = (combo_result.best_score, -idx, combo_result)
        if len(top_results_heap) < top_k:
            heapq.heappush(top_results_heap, heap_entry)
        elif top_results_heap and combo_result.best_score > top_results_heap[0][0]:
            # The emptiness check covers top_k <= 0, where nothing is retained.
            heapq.heapreplace(top_results_heap, heap_entry)

    # Score descending; equal scores in evaluation order.
    all_results = [
        entry[2]
        for entry in sorted(top_results_heap, key=lambda e: (-e[0], -e[1]))
    ]

    best_result = all_results[0] if all_results else None
    best_combination = best_result.layers if best_result else ()
    best_score = best_result.best_score if best_result else 0.0
    best_structure = best_result.best_structure if best_result else StructureType.UNKNOWN

    top_10 = all_results[:10]

    # Best single layer (first one encountered on ties).
    if single_layer_results:
        best_single = max(single_layer_results, key=lambda r: r.best_score)
        single_layer_best = best_single.layers[0]
        single_layer_best_score = best_single.best_score
    else:
        single_layer_best = layers[0]
        single_layer_best_score = 0.0

    combination_beats_single = best_score > single_layer_best_score
    improvement_over_single = best_score - single_layer_best_score

    patterns = _analyze_combination_patterns(all_results, layers, top_k=min(50, len(all_results)))

    recommendation = _generate_exhaustive_recommendation(
        best_combination, best_score, best_structure,
        single_layer_best, single_layer_best_score,
        combination_beats_single, improvement_over_single,
        patterns, total_combinations
    )

    return ExhaustiveGeometryAnalysisResult(
        total_combinations=total_combinations,
        all_results=all_results,
        best_combination=best_combination,
        best_score=best_score,
        best_structure=best_structure,
        top_10=top_10,
        single_layer_best=single_layer_best,
        single_layer_best_score=single_layer_best_score,
        combination_beats_single=combination_beats_single,
        improvement_over_single=improvement_over_single,
        patterns=patterns,
        recommendation=recommendation,
    )
|
|
2510
|
+
|
|
2511
|
+
|
|
2512
|
+
def detect_geometry_limited(
    pos_activations_by_layer: Dict[int, torch.Tensor],
    neg_activations_by_layer: Dict[int, torch.Tensor],
    max_combo_size: int = 3,
    combination_method: str = "concat",
    num_components: int = 5,
    progress_callback: Optional[Callable[[int, int], None]] = None,
    top_k: int = 100,
) -> ExhaustiveGeometryAnalysisResult:
    """
    Test limited layer combinations: 1-layer, 2-layer, ..., max_combo_size-layer, plus all layers.

    Much faster than exhaustive search while still finding good combinations.
    For N layers with max_combo_size=3:
    - 1-layer: N combinations
    - 2-layer: N*(N-1)/2 combinations
    - 3-layer: N*(N-1)*(N-2)/6 combinations
    - all-layers: 1 combination

    Total: O(N^3) instead of O(2^N)

    The all-layers combination is only added when ``max_combo_size < N``;
    otherwise it is already covered by the size-N combinations.

    Arguments:
        pos_activations_by_layer: Dict mapping layer index to positive activations [N, hidden_dim]
        neg_activations_by_layer: Dict mapping layer index to negative activations [N, hidden_dim]
        max_combo_size: Maximum combination size to test (1, 2, 3, etc.) before jumping to all
        combination_method: How to combine layers ("concat", "mean", "weighted");
            forwarded to ``_combine_layer_activations``
        num_components: Number of components for the analysis; forwarded to
            ``GeometryAnalysisConfig``
        progress_callback: Optional callback(current, total) invoked before each
            combination is evaluated
        top_k: Number of top results to keep in memory (default 100); a value
            <= 0 keeps no results

    Returns:
        ExhaustiveGeometryAnalysisResult with top combinations ranked by score
        (ties are ranked in evaluation order, i.e. smaller combinations first)

    Raises:
        ValueError: If no layers are available.
    """
    import heapq
    from itertools import combinations as itertools_combinations
    from math import comb

    layers = sorted(pos_activations_by_layer.keys())
    n_layers = len(layers)

    if n_layers == 0:
        raise ValueError("No layers provided")

    geo_cfg = GeometryAnalysisConfig(num_components=num_components, optimization_steps=50)

    # Sum of C(n, r) for r = 1..max_combo_size, plus one for the all-layers
    # combination when it is not already among them.
    largest_size = min(max_combo_size, n_layers)
    include_all_layers = max_combo_size < n_layers
    total_combinations = sum(comb(n_layers, r) for r in range(1, largest_size + 1))
    if include_all_layers:
        total_combinations += 1

    # Min-heap of (score, -evaluation_index, result): the root is the weakest
    # retained result. The integer tie-breaker keeps heapq from ever comparing
    # two result objects (which are not orderable) when scores are equal.
    top_results_heap: List[Tuple[float, int, ExhaustiveCombinationResult]] = []
    single_layer_results: List[ExhaustiveCombinationResult] = []

    # Lazy generator: sizes 1..largest_size, then the full layer set.
    def combo_generator():
        for r in range(1, largest_size + 1):
            yield from itertools_combinations(layers, r)
        if include_all_layers:
            yield tuple(layers)

    for idx, combo in enumerate(combo_generator(), start=1):
        if progress_callback:
            progress_callback(idx, total_combinations)

        # Single layers are used as-is; larger subsets are merged first.
        if len(combo) == 1:
            layer = combo[0]
            combined_pos = pos_activations_by_layer[layer]
            combined_neg = neg_activations_by_layer[layer]
        else:
            combined_pos, combined_neg = _combine_layer_activations(
                pos_activations_by_layer, neg_activations_by_layer,
                list(combo), combination_method
            )

        result = detect_geometry_structure(combined_pos, combined_neg, geo_cfg)

        all_scores = {name: score.score for name, score in result.all_scores.items()}
        combo_result = ExhaustiveCombinationResult(
            layers=combo,
            best_structure=result.best_structure,
            best_score=result.best_score,
            all_scores=all_scores,
        )

        # Single-layer results are tracked separately so the single-layer
        # baseline is available even if it falls out of the top_k heap.
        if len(combo) == 1:
            single_layer_results.append(combo_result)

        heap_entry = (combo_result.best_score, -idx, combo_result)
        if len(top_results_heap) < top_k:
            heapq.heappush(top_results_heap, heap_entry)
        elif top_results_heap and combo_result.best_score > top_results_heap[0][0]:
            # The emptiness check covers top_k <= 0, where nothing is retained.
            heapq.heapreplace(top_results_heap, heap_entry)

    # Score descending; equal scores in evaluation order.
    all_results = [
        entry[2]
        for entry in sorted(top_results_heap, key=lambda e: (-e[0], -e[1]))
    ]

    best_result = all_results[0] if all_results else None
    best_combination = best_result.layers if best_result else ()
    best_score = best_result.best_score if best_result else 0.0
    best_structure = best_result.best_structure if best_result else StructureType.UNKNOWN

    top_10 = all_results[:10]

    # Best single layer (first one encountered on ties). The fallback is
    # reached when max_combo_size < 1, i.e. no single layer was evaluated.
    if single_layer_results:
        best_single = max(single_layer_results, key=lambda r: r.best_score)
        single_layer_best = best_single.layers[0]
        single_layer_best_score = best_single.best_score
    else:
        single_layer_best = layers[0]
        single_layer_best_score = 0.0

    combination_beats_single = best_score > single_layer_best_score
    improvement_over_single = best_score - single_layer_best_score

    patterns = _analyze_combination_patterns(all_results, layers, top_k=min(50, len(all_results)))

    recommendation = _generate_exhaustive_recommendation(
        best_combination, best_score, best_structure,
        single_layer_best, single_layer_best_score,
        combination_beats_single, improvement_over_single,
        patterns, total_combinations
    )

    return ExhaustiveGeometryAnalysisResult(
        total_combinations=total_combinations,
        all_results=all_results,
        best_combination=best_combination,
        best_score=best_score,
        best_structure=best_structure,
        top_10=top_10,
        single_layer_best=single_layer_best,
        single_layer_best_score=single_layer_best_score,
        combination_beats_single=combination_beats_single,
        improvement_over_single=improvement_over_single,
        patterns=patterns,
        recommendation=recommendation,
    )
|
|
2663
|
+
|
|
2664
|
+
|
|
2665
|
+
def detect_geometry_contiguous(
    pos_activations_by_layer: Dict[int, torch.Tensor],
    neg_activations_by_layer: Dict[int, torch.Tensor],
    combination_method: str = "concat",
    num_components: int = 5,
    progress_callback: Optional[Callable[[int, int], None]] = None,
    top_k: int = 100,
) -> ExhaustiveGeometryAnalysisResult:
    """
    Test contiguous layer combinations only.

    Only tests runs of neighbouring layers: 1-2, 2-3, 1-3, 5-8, etc.
    "Neighbouring" means adjacent in the sorted list of available layer
    indices, so if only layers 2, 5 and 9 are provided, (2, 5) counts as a
    contiguous run.
    Much faster: O(N^2) combinations instead of O(2^N).

    For N layers: N*(N+1)/2 combinations
    - 36 layers: 666 combinations
    - 24 layers: 300 combinations

    Arguments:
        pos_activations_by_layer: Dict mapping layer index to positive activations [N, hidden_dim]
        neg_activations_by_layer: Dict mapping layer index to negative activations [N, hidden_dim]
        combination_method: How to combine layers ("concat", "mean", "weighted");
            forwarded to ``_combine_layer_activations``
        num_components: Number of components for the analysis; forwarded to
            ``GeometryAnalysisConfig``
        progress_callback: Optional callback(current, total) invoked before each
            combination is evaluated
        top_k: Number of top results to keep in memory (default 100); a value
            <= 0 keeps no results

    Returns:
        ExhaustiveGeometryAnalysisResult with top combinations ranked by score
        (ties are ranked in evaluation order)

    Raises:
        ValueError: If no layers are available.
    """
    import heapq

    layers = sorted(pos_activations_by_layer.keys())
    n_layers = len(layers)

    if n_layers == 0:
        raise ValueError("No layers provided")

    geo_cfg = GeometryAnalysisConfig(num_components=num_components, optimization_steps=50)

    # One run per (start, end) pair with start <= end.
    total_combinations = n_layers * (n_layers + 1) // 2

    # Min-heap of (score, -evaluation_index, result): the root is the weakest
    # retained result. The integer tie-breaker keeps heapq from ever comparing
    # two result objects (which are not orderable) when scores are equal.
    top_results_heap: List[Tuple[float, int, ExhaustiveCombinationResult]] = []
    single_layer_results: List[ExhaustiveCombinationResult] = []

    # Lazy generator: every slice layers[start:end + 1] of the sorted layers.
    def combo_generator():
        for start_idx in range(n_layers):
            for end_idx in range(start_idx, n_layers):
                yield tuple(layers[start_idx:end_idx + 1])

    for idx, combo in enumerate(combo_generator(), start=1):
        if progress_callback:
            progress_callback(idx, total_combinations)

        # Single layers are used as-is; longer runs are merged first.
        if len(combo) == 1:
            layer = combo[0]
            combined_pos = pos_activations_by_layer[layer]
            combined_neg = neg_activations_by_layer[layer]
        else:
            combined_pos, combined_neg = _combine_layer_activations(
                pos_activations_by_layer, neg_activations_by_layer,
                list(combo), combination_method
            )

        result = detect_geometry_structure(combined_pos, combined_neg, geo_cfg)

        all_scores = {name: score.score for name, score in result.all_scores.items()}
        combo_result = ExhaustiveCombinationResult(
            layers=combo,
            best_structure=result.best_structure,
            best_score=result.best_score,
            all_scores=all_scores,
        )

        # Single-layer results are tracked separately so the single-layer
        # baseline is available even if it falls out of the top_k heap.
        if len(combo) == 1:
            single_layer_results.append(combo_result)

        heap_entry = (combo_result.best_score, -idx, combo_result)
        if len(top_results_heap) < top_k:
            heapq.heappush(top_results_heap, heap_entry)
        elif top_results_heap and combo_result.best_score > top_results_heap[0][0]:
            # The emptiness check covers top_k <= 0, where nothing is retained.
            heapq.heapreplace(top_results_heap, heap_entry)

    # Score descending; equal scores in evaluation order.
    all_results = [
        entry[2]
        for entry in sorted(top_results_heap, key=lambda e: (-e[0], -e[1]))
    ]

    best_result = all_results[0] if all_results else None
    best_combination = best_result.layers if best_result else ()
    best_score = best_result.best_score if best_result else 0.0
    best_structure = best_result.best_structure if best_result else StructureType.UNKNOWN

    top_10 = all_results[:10]

    # Best single layer (first one encountered on ties).
    if single_layer_results:
        best_single = max(single_layer_results, key=lambda r: r.best_score)
        single_layer_best = best_single.layers[0]
        single_layer_best_score = best_single.best_score
    else:
        single_layer_best = layers[0]
        single_layer_best_score = 0.0

    combination_beats_single = best_score > single_layer_best_score
    improvement_over_single = best_score - single_layer_best_score

    patterns = _analyze_combination_patterns(all_results, layers, top_k=min(50, len(all_results)))

    recommendation = _generate_exhaustive_recommendation(
        best_combination, best_score, best_structure,
        single_layer_best, single_layer_best_score,
        combination_beats_single, improvement_over_single,
        patterns, total_combinations
    )

    return ExhaustiveGeometryAnalysisResult(
        total_combinations=total_combinations,
        all_results=all_results,
        best_combination=best_combination,
        best_score=best_score,
        best_structure=best_structure,
        top_10=top_10,
        single_layer_best=single_layer_best,
        single_layer_best_score=single_layer_best_score,
        combination_beats_single=combination_beats_single,
        improvement_over_single=improvement_over_single,
        patterns=patterns,
        recommendation=recommendation,
    )
|
|
2806
|
+
|
|
2807
|
+
|
|
2808
|
+
def detect_geometry_smart(
    pos_activations_by_layer: Dict[int, torch.Tensor],
    neg_activations_by_layer: Dict[int, torch.Tensor],
    max_combo_size: int = 3,
    combination_method: str = "concat",
    num_components: int = 5,
    progress_callback: Optional[Callable[[int, int], None]] = None,
    top_k: int = 100,
) -> ExhaustiveGeometryAnalysisResult:
    """
    Smart layer combination search: contiguous + limited (1,2,3-layer) combinations.

    Tests:
    1. All contiguous runs of the sorted available layers (L1-L2, L1-L3, L5-L10, etc.)
    2. All combinations of up to ``max_combo_size`` layers, contiguous or not

    Combinations produced by both rules are evaluated once.

    For N=36 layers with max_combo_size=3:
    - Contiguous: 666 combinations
    - Sizes 1-3: 7,806 combinations (36 + 630 + 7,140), 105 of which are
      also contiguous
    - Total: 8,367 unique combinations

    Unlike the generator-based searches, the deduplicated combination list is
    built up front; only the top_k results are retained afterwards.

    Arguments:
        pos_activations_by_layer: Dict mapping layer index to positive activations [N, hidden_dim]
        neg_activations_by_layer: Dict mapping layer index to negative activations [N, hidden_dim]
        max_combo_size: Maximum combination size for non-contiguous (default: 3)
        combination_method: How to combine layers ("concat", "mean", "weighted");
            forwarded to ``_combine_layer_activations``
        num_components: Number of components for the analysis; forwarded to
            ``GeometryAnalysisConfig``
        progress_callback: Optional callback(current, total) invoked before each
            combination is evaluated
        top_k: Number of top results to keep in memory (default 100); a value
            <= 0 keeps no results

    Returns:
        ExhaustiveGeometryAnalysisResult with top combinations ranked by score
        (ties are ranked in evaluation order, i.e. smaller combinations first)

    Raises:
        ValueError: If no layers are available.
    """
    import heapq
    from itertools import combinations as itertools_combinations

    layers = sorted(pos_activations_by_layer.keys())
    n_layers = len(layers)

    if n_layers == 0:
        raise ValueError("No layers provided")

    geo_cfg = GeometryAnalysisConfig(num_components=num_components, optimization_steps=50)

    # A set removes the overlap between the two families of combinations.
    all_combos_set: set = set()

    # Contiguous runs of the sorted layer list.
    for start_idx in range(n_layers):
        for end_idx in range(start_idx, n_layers):
            all_combos_set.add(tuple(layers[start_idx:end_idx + 1]))

    # Every combination of size 1..max_combo_size.
    for r in range(1, min(max_combo_size, n_layers) + 1):
        all_combos_set.update(itertools_combinations(layers, r))

    # Deterministic evaluation order: by size, then lexicographically.
    all_combos = sorted(all_combos_set, key=lambda x: (len(x), x))
    total_combinations = len(all_combos)

    # Min-heap of (score, -evaluation_index, result): the root is the weakest
    # retained result. The integer tie-breaker keeps heapq from ever comparing
    # two result objects (which are not orderable) when scores are equal.
    top_results_heap: List[Tuple[float, int, ExhaustiveCombinationResult]] = []
    single_layer_results: List[ExhaustiveCombinationResult] = []

    for idx, combo in enumerate(all_combos, start=1):
        if progress_callback:
            progress_callback(idx, total_combinations)

        # Single layers are used as-is; larger subsets are merged first.
        if len(combo) == 1:
            layer = combo[0]
            combined_pos = pos_activations_by_layer[layer]
            combined_neg = neg_activations_by_layer[layer]
        else:
            combined_pos, combined_neg = _combine_layer_activations(
                pos_activations_by_layer, neg_activations_by_layer,
                list(combo), combination_method
            )

        result = detect_geometry_structure(combined_pos, combined_neg, geo_cfg)

        all_scores = {name: score.score for name, score in result.all_scores.items()}
        combo_result = ExhaustiveCombinationResult(
            layers=combo,
            best_structure=result.best_structure,
            best_score=result.best_score,
            all_scores=all_scores,
        )

        # Single-layer results are tracked separately so the single-layer
        # baseline is available even if it falls out of the top_k heap.
        if len(combo) == 1:
            single_layer_results.append(combo_result)

        heap_entry = (combo_result.best_score, -idx, combo_result)
        if len(top_results_heap) < top_k:
            heapq.heappush(top_results_heap, heap_entry)
        elif top_results_heap and combo_result.best_score > top_results_heap[0][0]:
            # The emptiness check covers top_k <= 0, where nothing is retained.
            heapq.heapreplace(top_results_heap, heap_entry)

    # Score descending; equal scores in evaluation order.
    all_results = [
        entry[2]
        for entry in sorted(top_results_heap, key=lambda e: (-e[0], -e[1]))
    ]

    best_result = all_results[0] if all_results else None
    best_combination = best_result.layers if best_result else ()
    best_score = best_result.best_score if best_result else 0.0
    best_structure = best_result.best_structure if best_result else StructureType.UNKNOWN

    top_10 = all_results[:10]

    # Best single layer (first one encountered on ties).
    if single_layer_results:
        best_single = max(single_layer_results, key=lambda r: r.best_score)
        single_layer_best = best_single.layers[0]
        single_layer_best_score = best_single.best_score
    else:
        single_layer_best = layers[0]
        single_layer_best_score = 0.0

    combination_beats_single = best_score > single_layer_best_score
    improvement_over_single = best_score - single_layer_best_score

    patterns = _analyze_combination_patterns(all_results, layers, top_k=min(50, len(all_results)))

    recommendation = _generate_exhaustive_recommendation(
        best_combination, best_score, best_structure,
        single_layer_best, single_layer_best_score,
        combination_beats_single, improvement_over_single,
        patterns, total_combinations
    )

    return ExhaustiveGeometryAnalysisResult(
        total_combinations=total_combinations,
        all_results=all_results,
        best_combination=best_combination,
        best_score=best_score,
        best_structure=best_structure,
        top_10=top_10,
        single_layer_best=single_layer_best,
        single_layer_best_score=single_layer_best_score,
        combination_beats_single=combination_beats_single,
        improvement_over_single=improvement_over_single,
        patterns=patterns,
        recommendation=recommendation,
    )
|
|
2960
|
+
|
|
2961
|
+
|
|
2962
|
+
def _analyze_combination_patterns(
    all_results: List[ExhaustiveCombinationResult],
    layers: List[int],
    top_k: int = 50,
) -> Dict[str, Any]:
    """Summarise recurring traits among the first ``top_k`` combination results.

    Args:
        all_results: Every evaluated combination. Only the first ``top_k``
            entries are treated as the "top" set; the list is sliced as-is,
            so it is presumably already ordered best-first (verify against
            the caller).
        layers: Candidate layer indices. The element at position
            ``len(layers) // 2`` is used as the early/late boundary.
        top_k: Number of leading results to inspect.

    Returns:
        A dict with layer frequencies, the five most frequent layers, the
        combination-size distribution, the best score per combination size
        (computed over *all* results), the size with the highest best score,
        structure frequencies, the dominant structure, the number of top
        combinations containing at least one pair of consecutive layers, and
        the early-to-late layer ratio (``inf`` when no late layer occurs).
    """
    from collections import Counter

    leaders = all_results[:top_k]
    # Flattened view of every layer index that occurs in the top set.
    leader_layers = [layer for result in leaders for layer in result.layers]

    layer_counts = Counter(leader_layers)
    size_counts = Counter(len(result.layers) for result in leaders)
    structure_counts = Counter(result.best_structure for result in leaders)

    # Best score per combination size, taken over the full result list.
    best_by_size: Dict[int, float] = {}
    for result in all_results:
        n_layers = len(result.layers)
        if n_layers in best_by_size and not (result.best_score > best_by_size[n_layers]):
            continue
        best_by_size[n_layers] = result.best_score

    def _has_consecutive_layers(combo) -> bool:
        # True when any two neighbouring entries of the sorted combo differ by 1.
        ordered = sorted(combo)
        return any(hi - lo == 1 for lo, hi in zip(ordered, ordered[1:]))

    # Each combination counts at most once, however many adjacent pairs it has.
    adjacent_total = sum(
        1 for result in leaders if _has_consecutive_layers(result.layers)
    )

    # Early/late split around the middle entry of the candidate layer list.
    pivot = layers[len(layers) // 2] if layers else 0
    early_hits = sum(1 for layer in leader_layers if layer < pivot)
    late_hits = sum(1 for layer in leader_layers if layer >= pivot)
    if late_hits > 0:
        early_late_ratio = early_hits / late_hits
    else:
        early_late_ratio = float('inf')

    if best_by_size:
        optimal_size = max(best_by_size, key=best_by_size.get)
    else:
        optimal_size = 1

    if structure_counts:
        dominant_structure = structure_counts.most_common(1)[0][0].value
    else:
        dominant_structure = "unknown"

    return {
        "layer_frequency_in_top": dict(layer_counts.most_common()),
        "most_important_layers": [layer for layer, _ in layer_counts.most_common(5)],
        "size_distribution_in_top": dict(size_counts),
        "best_score_by_size": best_by_size,
        "optimal_combination_size": optimal_size,
        "structure_frequency_in_top": {
            structure.value: count
            for structure, count in structure_counts.most_common()
        },
        "dominant_structure": dominant_structure,
        "adjacent_pairs_in_top": adjacent_total,
        "early_vs_late_ratio": early_late_ratio,
    }
|
|
3017
|
+
|
|
3018
|
+
|
|
3019
|
+
def _generate_exhaustive_recommendation(
    best_combination: Tuple[int, ...],
    best_score: float,
    best_structure: StructureType,
    single_layer_best: int,
    single_layer_best_score: float,
    combination_beats_single: bool,
    improvement_over_single: float,
    patterns: Dict[str, Any],
    total_combinations: int,
) -> str:
    """Build a one-paragraph, human-readable recommendation string.

    Args:
        best_combination: Layer indices of the best-scoring combination.
        best_score: Score achieved by ``best_combination``.
        best_structure: Structure type of the best combination; its
            ``.value`` is embedded in the message.
        single_layer_best: Index of the best single layer.
        single_layer_best_score: Score achieved by ``single_layer_best``.
        combination_beats_single: Whether the best combination outscored the
            best single layer.
        improvement_over_single: ``best_score - single_layer_best_score``.
        patterns: Pattern summary; the keys ``optimal_combination_size``,
            ``most_important_layers`` and ``dominant_structure`` are read,
            each with a fallback when absent.
        total_combinations: Number of combinations that were evaluated.

    Returns:
        Space-joined sentences describing the best choice and the observed
        patterns.
    """
    # A combination must beat the best single layer by more than this margin
    # before it is recommended over the simpler single-layer choice.
    min_meaningful_gain = 0.05

    parts = [f"Tested {total_combinations} layer combinations."]

    if combination_beats_single and improvement_over_single > min_meaningful_gain:
        layers_str = "+".join(f"L{l}" for l in best_combination)
        parts.append(
            f"BEST: {layers_str} ({best_structure.value}: {best_score:.3f}), "
            f"+{improvement_over_single:.3f} over single layer L{single_layer_best}."
        )
    else:
        # Report the single layer's own score here. ``best_score`` belongs to
        # the best combination, which may be a marginally better multi-layer
        # set, and would misattribute that score to the single layer.
        parts.append(
            f"BEST: Single layer L{single_layer_best} ({single_layer_best_score:.3f}). "
            "Multi-layer combinations don't significantly improve."
        )

    # Pattern insights
    opt_size = patterns.get("optimal_combination_size", 1)
    if opt_size > 1:
        parts.append(f"Optimal combination size: {opt_size} layers.")

    important_layers = patterns.get("most_important_layers", [])
    if important_layers:
        # Only the three most frequent layers are mentioned to keep it short.
        layers_str = ", ".join(f"L{l}" for l in important_layers[:3])
        parts.append(f"Most important layers: {layers_str}.")

    dominant = patterns.get("dominant_structure", "unknown")
    parts.append(f"Dominant structure: {dominant}.")

    return " ".join(parts)
|