wisent 0.7.701__py3-none-any.whl → 0.7.1045__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -1
- wisent/comparison/__init__.py +1 -0
- wisent/comparison/detect_bos_features.py +275 -0
- wisent/comparison/fgaa.py +465 -0
- wisent/comparison/lora.py +669 -0
- wisent/comparison/lora_dpo.py +592 -0
- wisent/comparison/main.py +444 -0
- wisent/comparison/ours.py +76 -0
- wisent/comparison/sae.py +304 -0
- wisent/comparison/utils.py +381 -0
- wisent/core/activations/activation_cache.py +393 -0
- wisent/core/activations/activations.py +3 -3
- wisent/core/activations/activations_collector.py +12 -7
- wisent/core/activations/classifier_inference_strategy.py +12 -11
- wisent/core/activations/extraction_strategy.py +260 -84
- wisent/core/classifiers/classifiers/core/atoms.py +3 -2
- wisent/core/cli/__init__.py +2 -1
- wisent/core/cli/agent/train_classifier.py +16 -3
- wisent/core/cli/check_linearity.py +35 -3
- wisent/core/cli/cluster_benchmarks.py +4 -6
- wisent/core/cli/create_steering_vector.py +6 -4
- wisent/core/cli/diagnose_vectors.py +7 -4
- wisent/core/cli/estimate_unified_goodness_time.py +6 -4
- wisent/core/cli/generate_pairs_from_task.py +9 -56
- wisent/core/cli/generate_vector_from_task.py +11 -20
- wisent/core/cli/geometry_search.py +137 -0
- wisent/core/cli/get_activations.py +2 -2
- wisent/core/cli/method_optimizer.py +4 -3
- wisent/core/cli/modify_weights.py +3 -2
- wisent/core/cli/optimize_sample_size.py +1 -1
- wisent/core/cli/optimize_steering.py +14 -16
- wisent/core/cli/optimize_weights.py +2 -1
- wisent/core/cli/preview_pairs.py +203 -0
- wisent/core/cli/steering_method_trainer.py +3 -3
- wisent/core/cli/tasks.py +19 -76
- wisent/core/cli/train_unified_goodness.py +3 -3
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +4 -4
- wisent/core/contrastive_pairs/diagnostics/linearity.py +7 -0
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/agentic_search.py +37 -347
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/aider_polyglot.py +113 -136
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/codeforces.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/coding_benchmarks.py +124 -504
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/faithbench.py +40 -63
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flames.py +46 -89
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/flores.py +15 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/frames.py +36 -20
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/hallucinations_leaderboard.py +3 -45
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/livemathbench.py +42 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/longform_writing.py +2 -112
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/math500.py +39 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/medium_priority_benchmarks.py +475 -525
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/mercury.py +65 -42
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/olympiadbench.py +2 -12
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/planbench.py +78 -219
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/polymath.py +37 -4
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/recode.py +84 -69
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/refusalbench.py +168 -160
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/simpleqa.py +44 -25
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/tau_bench.py +3 -103
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolbench.py +3 -97
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/toolemu.py +48 -182
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +3 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +19 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aclue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/acp_bench_hard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/advanced.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aexams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrimmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/afrixnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabculture.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_complete.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabic_leaderboard_light.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arabicmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/aradice.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +1 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/babi.py +36 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/basque_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/belebele.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/benchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bertaqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhs.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/bhtc.py +3 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/blimp_nl.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +22 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/c4.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cabbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/careqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalan_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catalanqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/catcola.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +10 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ceval_valid.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chain.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/chartqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/claim.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/click.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cnn.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cocoteros.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coedit.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/commonsense_qa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copal_id.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/csatqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cycle.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darija_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijahellaswag.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/darijammlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/dbpedia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/discrim_eval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/doc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/epec.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_ca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eq_bench_es.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/esbbq.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ethics.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_exams.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_proficiency.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_reading.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/eus_trivia.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/evalita_llm.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/financial.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/flan.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/french_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/galician_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gaokao.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/glianorex.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/global_piqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gpt3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/groundcocoa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/haerae.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_ethics.py +5 -9
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hendrycks_math.py +63 -16
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/histoires_morales.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hrm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/humaneval_infilling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/icelandic_winogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/inverse_scaling.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/japanese_leaderboard_mc.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kobest.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/kormedmcqa.py +5 -17
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_cloze.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lambada_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/law.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/leaderboard.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lingoly.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/llama3.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/lm_syneval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/longbenchv2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mastermind.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/med_concepts_qa.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/meddialog.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medical.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medmcqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mela.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/metabench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/minerva_math.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mmlusr.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multiblimp.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/non.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_gen_exact.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/noreval_mc_log_likelihoods.py +4 -8
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/nq_open.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_arc_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_hellaswag_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_mmlu_multilingual.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/okapi_truthfulqa_multilingual.py +2 -5
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/olaph.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/option.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafraseja.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/parafrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/paws_x.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/persona.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/phrases.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pile.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/portuguese_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prompt.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper_bool.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnlieu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/random.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/reversed.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/ruler.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/score.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/scrolls_mc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/self.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sglue_rte.py +2 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/siqa.py +4 -7
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/spanish_bench.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/storycloze.py +2 -6
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/summarization.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/super_glue.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swde.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sycophancy.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/t0.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/teca.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyarc.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinybenchmarks.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinygsm8k.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinyhellaswag.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinymmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinytruthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tinywinogrande.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/tmmluplus.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +9 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turblimp_core.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/turkishmmlu_mc.py +0 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/unscramble.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/vaxx.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +3 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wmdp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc273.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xcopa.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xlsum.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xquad.py +2 -4
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +2 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +2 -2
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/zhoblimp.py +1 -3
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +173 -6
- wisent/core/data_loaders/loaders/lm_loader.py +12 -1
- wisent/core/geometry_runner.py +995 -0
- wisent/core/geometry_search_space.py +237 -0
- wisent/core/hyperparameter_optimizer.py +1 -1
- wisent/core/main.py +3 -0
- wisent/core/models/core/atoms.py +5 -3
- wisent/core/models/wisent_model.py +1 -1
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/parser_arguments/check_linearity_parser.py +12 -2
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +2 -2
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +6 -13
- wisent/core/parser_arguments/geometry_search_parser.py +61 -0
- wisent/core/parser_arguments/get_activations_parser.py +5 -14
- wisent/core/parser_arguments/main_parser.py +8 -0
- wisent/core/parser_arguments/train_unified_goodness_parser.py +2 -2
- wisent/core/steering.py +5 -3
- wisent/core/steering_methods/methods/hyperplane.py +2 -1
- wisent/core/synthetic/generators/nonsense_generator.py +30 -18
- wisent/core/trainers/steering_trainer.py +2 -2
- wisent/core/utils/device.py +27 -27
- wisent/core/utils/layer_combinations.py +70 -0
- wisent/examples/__init__.py +1 -0
- wisent/examples/scripts/__init__.py +1 -0
- wisent/examples/scripts/count_all_benchmarks.py +121 -0
- wisent/examples/scripts/discover_directions.py +469 -0
- wisent/examples/scripts/extract_benchmark_info.py +71 -0
- wisent/examples/scripts/search_all_short_names.py +31 -0
- wisent/examples/scripts/test_all_benchmarks.py +138 -0
- wisent/examples/scripts/test_all_benchmarks_new.py +28 -0
- wisent/examples/scripts/test_contrastive_pairs_all_supported.py +230 -0
- wisent/examples/scripts/test_nonsense_baseline.py +261 -0
- wisent/examples/scripts/test_one_benchmark.py +324 -0
- wisent/examples/scripts/test_one_coding_benchmark.py +293 -0
- wisent/parameters/lm_eval/broken_in_lm_eval.json +179 -2
- wisent/parameters/lm_eval/category_directions.json +137 -0
- wisent/parameters/lm_eval/repair_plan.json +282 -0
- wisent/parameters/lm_eval/weak_contrastive_pairs.json +38 -0
- wisent/parameters/lm_eval/working_benchmarks.json +206 -0
- wisent/parameters/lm_eval/working_benchmarks_categorized.json +236 -0
- wisent/tests/test_detector_accuracy.py +1 -1
- wisent/tests/visualize_geometry.py +1 -1
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/METADATA +5 -1
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/RECORD +328 -358
- wisent/core/contrastive_pairs/huggingface_pairs/hf_task_extractors/browsecomp.py +0 -245
- wisent/examples/contrastive_pairs/humanization_human_vs_ai.json +0 -2112
- wisent/examples/scripts/1/test_basqueglue_evaluation.json +0 -51
- wisent/examples/scripts/1/test_basqueglue_pairs.json +0 -14
- wisent/examples/scripts/1/test_bec2016eu_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bec2016eu_pairs.json +0 -14
- wisent/examples/scripts/1/test_belebele_evaluation.json +0 -51
- wisent/examples/scripts/1/test_belebele_pairs.json +0 -14
- wisent/examples/scripts/1/test_benchmarks_evaluation.json +0 -51
- wisent/examples/scripts/1/test_benchmarks_pairs.json +0 -14
- wisent/examples/scripts/1/test_bertaqa_evaluation.json +0 -51
- wisent/examples/scripts/1/test_bertaqa_pairs.json +0 -14
- wisent/examples/scripts/1/test_bhtc_v2_evaluation.json +0 -30
- wisent/examples/scripts/1/test_bhtc_v2_pairs.json +0 -8
- wisent/examples/scripts/1/test_boolq-seq2seq_evaluation.json +0 -30
- wisent/examples/scripts/1/test_boolq-seq2seq_pairs.json +0 -8
- wisent/examples/scripts/1/test_cabreu_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cabreu_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_en_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_en_pairs.json +0 -8
- wisent/examples/scripts/1/test_careqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_careqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catalanqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catalanqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_catcola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_catcola_pairs.json +0 -8
- wisent/examples/scripts/1/test_chartqa_evaluation.json +0 -30
- wisent/examples/scripts/1/test_chartqa_pairs.json +0 -8
- wisent/examples/scripts/1/test_claim_stance_topic_evaluation.json +0 -30
- wisent/examples/scripts/1/test_claim_stance_topic_pairs.json +0 -8
- wisent/examples/scripts/1/test_cnn_dailymail_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cnn_dailymail_pairs.json +0 -8
- wisent/examples/scripts/1/test_cocoteros_es_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cocoteros_es_pairs.json +0 -8
- wisent/examples/scripts/1/test_coedit_gec_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coedit_gec_pairs.json +0 -8
- wisent/examples/scripts/1/test_cola_evaluation.json +0 -30
- wisent/examples/scripts/1/test_cola_pairs.json +0 -8
- wisent/examples/scripts/1/test_coqcat_evaluation.json +0 -30
- wisent/examples/scripts/1/test_coqcat_pairs.json +0 -8
- wisent/examples/scripts/1/test_dbpedia_14_evaluation.json +0 -30
- wisent/examples/scripts/1/test_dbpedia_14_pairs.json +0 -8
- wisent/examples/scripts/1/test_epec_koref_bin_evaluation.json +0 -30
- wisent/examples/scripts/1/test_epec_koref_bin_pairs.json +0 -8
- wisent/examples/scripts/1/test_ethos_binary_evaluation.json +0 -30
- wisent/examples/scripts/1/test_ethos_binary_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimgsm_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrimmlu_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_evaluation.json +0 -30
- wisent/examples/scripts/2/test_afrixnli_en_direct_amh_pairs.json +0 -8
- wisent/examples/scripts/2/test_arc_ar_evaluation.json +0 -30
- wisent/examples/scripts/2/test_arc_ar_pairs.json +0 -8
- wisent/examples/scripts/2/test_atis_evaluation.json +0 -30
- wisent/examples/scripts/2/test_atis_pairs.json +0 -8
- wisent/examples/scripts/2/test_babi_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babi_pairs.json +0 -8
- wisent/examples/scripts/2/test_babilong_evaluation.json +0 -30
- wisent/examples/scripts/2/test_babilong_pairs.json +0 -8
- wisent/examples/scripts/2/test_bangla_mmlu_evaluation.json +0 -30
- wisent/examples/scripts/2/test_bangla_mmlu_pairs.json +0 -8
- wisent/examples/scripts/2/test_basque-glue_pairs.json +0 -14
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/WHEEL +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/entry_points.txt +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.7.701.dist-info → wisent-0.7.1045.dist-info}/top_level.txt +0 -0

wisent/core/cli/create_steering_vector.py:

@@ -8,6 +8,7 @@ import torch
 from collections import defaultdict

 from wisent.core.errors import SteeringMethodUnknownError, VectorQualityTooLowError
+from wisent.core.utils.device import preferred_dtype


 def execute_create_steering_vector(args):
@@ -46,20 +47,21 @@ def execute_create_steering_vector(args):

     # Structure: {layer_str: {"positive": [tensors], "negative": [tensors]}}
     layer_activations = defaultdict(lambda: {"positive": [], "negative": []})
+    dtype = preferred_dtype()

     for pair in pairs_list:
         # Extract positive activations
         pos_layers = pair['positive_response'].get('layers_activations', {})
         for layer_str, activation_list in pos_layers.items():
             if activation_list is not None:
-                tensor = torch.tensor(activation_list, dtype=
+                tensor = torch.tensor(activation_list, dtype=dtype)
                 layer_activations[layer_str]["positive"].append(tensor)

         # Extract negative activations
         neg_layers = pair['negative_response'].get('layers_activations', {})
         for layer_str, activation_list in neg_layers.items():
             if activation_list is not None:
-                tensor = torch.tensor(activation_list, dtype=
+                tensor = torch.tensor(activation_list, dtype=dtype)
                 layer_activations[layer_str]["negative"].append(tensor)

     available_layers = sorted(layer_activations.keys(), key=int)
@@ -232,7 +234,7 @@ def execute_create_steering_vector(args):
     # If multiple layers, save the first one (or could save all and let user specify)
     if len(steering_vectors) == 1:
         layer_str = list(steering_vectors.keys())[0]
-        vector_tensor = torch.tensor(steering_vectors[layer_str], dtype=
+        vector_tensor = torch.tensor(steering_vectors[layer_str], dtype=dtype)
         torch.save({
             'steering_vector': vector_tensor,
             'layer_index': int(layer_str),
@@ -251,7 +253,7 @@ def execute_create_steering_vector(args):
         # Save multiple layers - save each to separate file
         for layer_str in steering_vectors.keys():
             layer_output = args.output.replace('.pt', f'_layer_{layer_str}.pt')
-            vector_tensor = torch.tensor(steering_vectors[layer_str], dtype=
+            vector_tensor = torch.tensor(steering_vectors[layer_str], dtype=dtype)
             torch.save({
                 'steering_vector': vector_tensor,
                 'layer_index': int(layer_str),
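These call sites replace a hardcoded tensor dtype with `preferred_dtype()` from `wisent.core.utils.device` (a module that also changes in this release, +27 -27). The helper's body is not shown in this diff, so the following is only a minimal sketch of the kind of selection logic the call sites assume; the actual wisent implementation may differ.

```python
# Hypothetical sketch of a preferred_dtype()-style helper; the real body is
# not part of this diff and may choose dtypes differently.
import torch


def preferred_dtype() -> torch.dtype:
    """Pick a tensor dtype that is safe on the available accelerator."""
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16  # assumption: prefer bf16 on capable GPUs
    return torch.float32       # CPU/MPS fall back to full precision


# Call sites in the diff then avoid hardcoding a dtype:
dtype = preferred_dtype()
tensor = torch.tensor([0.1, 0.2, 0.3], dtype=dtype)
```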
wisent/core/cli/diagnose_vectors.py:

@@ -6,6 +6,7 @@ import os
 import math

 import torch
+from wisent.core.utils.device import preferred_dtype


 def execute_diagnose_vectors(args):
@@ -227,10 +228,11 @@ def _run_cone_analysis(
         return

     # Convert to tensors if needed
+    dtype = preferred_dtype()
     if not isinstance(pos_acts, torch.Tensor):
-        pos_acts = torch.tensor(pos_acts, dtype=
+        pos_acts = torch.tensor(pos_acts, dtype=dtype)
     if not isinstance(neg_acts, torch.Tensor):
-        neg_acts = torch.tensor(neg_acts, dtype=
+        neg_acts = torch.tensor(neg_acts, dtype=dtype)

     print(f" Positive samples: {pos_acts.shape[0]}")
     print(f" Negative samples: {neg_acts.shape[0]}")
@@ -342,10 +344,11 @@ def _run_geometry_analysis(
         return

     # Convert to tensors
+    dtype = preferred_dtype()
     if not isinstance(pos_acts, torch.Tensor):
-        pos_acts = torch.tensor(pos_acts, dtype=
+        pos_acts = torch.tensor(pos_acts, dtype=dtype)
     if not isinstance(neg_acts, torch.Tensor):
-        neg_acts = torch.tensor(neg_acts, dtype=
+        neg_acts = torch.tensor(neg_acts, dtype=dtype)

     print(f" Positive samples: {pos_acts.shape[0]}")
     print(f" Negative samples: {neg_acts.shape[0]}")
wisent/core/cli/estimate_unified_goodness_time.py:

@@ -141,8 +141,10 @@ def estimate_runtime(
     results = {}

     # 1. Model loading (one-time)
-    if device == 'cpu':
-
+    if device == 'cpu' or device == 'auto':
+        from wisent.core.utils.device import resolve_default_device
+        actual_device = resolve_default_device() if device == 'auto' else device
+        model_time = TIME_ESTIMATES['model_load_cpu'] if actual_device == 'cpu' else TIME_ESTIMATES['model_load_gpu']
     else:
         model_time = TIME_ESTIMATES['model_load_gpu']
     results['model_loading'] = model_time
@@ -269,8 +271,8 @@ def main():
         help="Skip evaluation phase"
     )
     parser.add_argument(
-        "--device", choices=["cuda", "cpu"], default="
-        help="Device for computation"
+        "--device", choices=["cuda", "cpu", "mps", "auto"], default="auto",
+        help="Device for computation (auto = detect best available)"
     )
     parser.add_argument(
         "--show-breakdown", action="store_true",
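`resolve_default_device()`, also from `wisent.core.utils.device`, backs the new `--device auto` default here and the `map_location` fall-backs further down. Its body is likewise not shown, so this is a hedged sketch of the detection order the call sites imply (CUDA, then Apple MPS, then CPU):

```python
# Hypothetical sketch of a resolve_default_device()-style helper; the actual
# wisent implementation is not shown in this diff.
import torch


def resolve_default_device() -> str:
    """Return the best available device string: 'cuda', 'mps', or 'cpu'."""
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


# Resolving the CLI's new default:
device = "auto"
actual_device = resolve_default_device() if device == "auto" else device
```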
wisent/core/cli/generate_pairs_from_task.py:

@@ -4,8 +4,6 @@ import sys
 import json
 import os

-from wisent.core.errors import InvalidDataFormatError
-

 def execute_generate_pairs_from_task(args):
     """Execute the generate-pairs-from-task command - load and save contrastive pairs from a task."""
@@ -14,9 +12,8 @@ def execute_generate_pairs_from_task(args):
     if hasattr(args, 'task_name') and args.task_name:
         args.task_name = expand_task_if_skill_or_risk(args.task_name)

-    from wisent.core.contrastive_pairs.huggingface_pairs.hf_extractor_manifest import HF_EXTRACTORS
     from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import (
-
+        build_contrastive_pairs,
     )

     print(f"\nš Generating contrastive pairs from task: {args.task_name}")
@@ -26,58 +23,14 @@ def execute_generate_pairs_from_task(args):

     try:
         print(f"\nš Loading task '{args.task_name}'...")
-
-
-
-
-
-
-
-
-            print(f" šØ Building contrastive pairs...")
-            pairs = lm_build_contrastive_pairs(
-                task_name=args.task_name,
-                lm_eval_task=None,  # HF extractors don't need lm_eval_task
-                limit=args.limit,
-            )
-            pairs_task_name = args.task_name
-        else:
-            # lm-eval task - load via LMEvalDataLoader
-            from wisent.core.data_loaders.loaders.lm_loader import LMEvalDataLoader
-            loader = LMEvalDataLoader()
-            task_obj = loader.load_lm_eval_task(args.task_name)
-
-            # Handle both lm-eval tasks (dict or ConfigurableTask)
-            if isinstance(task_obj, dict):
-                # lm-eval task group with subtasks
-                if len(task_obj) != 1:
-                    keys = ", ".join(sorted(task_obj.keys()))
-                    raise InvalidDataFormatError(
-                        reason=f"Task '{args.task_name}' returned {len(task_obj)} subtasks ({keys}). "
-                        "Specify an explicit subtask, e.g. 'benchmark/subtask'."
-                    )
-                (subname, task), = task_obj.items()
-                pairs_task_name = subname
-
-                # Generate contrastive pairs using lm-eval interface
-                print(f" šØ Building contrastive pairs...")
-                pairs = lm_build_contrastive_pairs(
-                    task_name=pairs_task_name,
-                    lm_eval_task=task,
-                    limit=args.limit,
-                )
-            else:
-                # Single lm-eval task (ConfigurableTask), not wrapped in dict
-                task = task_obj
-                pairs_task_name = args.task_name
-
-                # Generate contrastive pairs using lm-eval interface
-                print(f" šØ Building contrastive pairs...")
-                pairs = lm_build_contrastive_pairs(
-                    task_name=pairs_task_name,
-                    lm_eval_task=task,
-                    limit=args.limit,
-                )
+        print(f" šØ Building contrastive pairs...")
+
+        # Use unified loader - handles HF, lm-eval, and group tasks automatically
+        pairs = build_contrastive_pairs(
+            task_name=args.task_name,
+            limit=args.limit,
+        )
+        pairs_task_name = args.task_name

         print(f" ā Generated {len(pairs)} contrastive pairs")

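With this change, `build_contrastive_pairs` becomes the single entry point for HF, lm-eval, and group tasks. A minimal usage sketch; the task name is a placeholder, and the only keyword arguments confirmed by the diff are `task_name` and `limit`:

```python
# Sketch of calling the unified pair builder introduced in 0.7.1045.
# "boolq" is a placeholder task name used for illustration.
from wisent.core.contrastive_pairs.lm_eval_pairs.lm_task_pairs_generation import (
    build_contrastive_pairs,
)

pairs = build_contrastive_pairs(
    task_name="boolq",  # HF, lm-eval, and group tasks are all handled here
    limit=20,           # cap the number of generated pairs
)
print(f"Generated {len(pairs)} contrastive pairs")
```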
wisent/core/cli/generate_vector_from_task.py:

@@ -30,8 +30,7 @@ def _load_optimal_defaults(model_name: str, task_name: str, args):
         "layer": result.layer,
         "strength": result.strength,
         "strategy": result.strategy,
-        "
-        "prompt_strategy": result.prompt_strategy,
+        "extraction_strategy": getattr(result, 'extraction_strategy', None),
         "score": result.score,
     }

@@ -89,31 +88,24 @@ def execute_generate_vector_from_task(args):
     print(f" Method: {optimal_config['method']}")
     print(f" Layer: {optimal_config['layer']}")
     print(f" Strength: {optimal_config['strength']}")
-
+    if optimal_config.get('extraction_strategy'):
+        print(f" Extraction Strategy: {optimal_config['extraction_strategy']}")
     print(f" Score: {optimal_config['score']:.3f}")
     print(f"{'='*60}")
-
+
     # Apply optimal defaults if user didn't explicitly override
     if not getattr(args, '_layers_set_by_user', False) and args.layers is None:
         args.layers = str(optimal_config['layer'])
         print(f" ā Using optimal layer: {args.layers}")
-
-    if not getattr(args, '
-
-
-
-            "mean_pooling": "average",
-            "first_token": "first",
-            "max_pooling": "max",
-        }
-        mapped_agg = token_agg_map.get(optimal_config['token_aggregation'], args.token_aggregation)
-        args.token_aggregation = mapped_agg
-        print(f" ā Using optimal token aggregation: {args.token_aggregation}")
-
+
+    if not getattr(args, '_extraction_strategy_set_by_user', False) and optimal_config.get('extraction_strategy'):
+        args.extraction_strategy = optimal_config['extraction_strategy']
+        print(f" ā Using optimal extraction strategy: {args.extraction_strategy}")
+
     if not getattr(args, '_method_set_by_user', False):
         args.method = optimal_config['method'].lower()
         print(f" ā Using optimal method: {args.method}")
-
+
     # Store optimal config for later use
     args._optimal_config = optimal_config
     print()
@@ -176,8 +168,7 @@ def execute_generate_vector_from_task(args):
         model=args.model,
         device=args.device,
         layers=args.layers,
-
-        prompt_strategy=args.prompt_strategy,
+        extraction_strategy=args.extraction_strategy,
         verbose=args.verbose,
         timing=args.timing,
     )
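The `_extraction_strategy_set_by_user` / `_method_set_by_user` checks only work if the argument parser records whether the user explicitly passed each flag before the stored optima are applied. That wiring lives in the parser modules, which this excerpt does not show; the snippet below is a hypothetical illustration of the pattern, not wisent's actual parser code.

```python
# Hypothetical illustration of the "_set_by_user" sentinel pattern; the real
# wisent parsers are not shown in this excerpt.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--extraction-strategy", dest="extraction_strategy", default=None)
args = parser.parse_args(["--extraction-strategy", "chat_last"])

# Record explicit user choices so stored optima never overwrite them.
args._extraction_strategy_set_by_user = args.extraction_strategy is not None

optimal_config = {"extraction_strategy": "chat_mean"}  # stand-in for stored optima
if not getattr(args, "_extraction_strategy_set_by_user", False) and optimal_config.get("extraction_strategy"):
    args.extraction_strategy = optimal_config["extraction_strategy"]

print(args.extraction_strategy)  # "chat_last" - the explicit user value wins
```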
wisent/core/cli/geometry_search.py (new file):

@@ -0,0 +1,137 @@
+"""Run geometry search across benchmarks to find unified goodness direction."""
+
+import json
+import sys
+import os
+from pathlib import Path
+
+
+def execute_geometry_search(args):
+    """Execute the geometry-search command."""
+    print(f"\n{'='*60}")
+    print("GEOMETRY SEARCH")
+    print(f"{'='*60}")
+    print(f"Model: {args.model}")
+    print(f"Output: {args.output}")
+    print(f"Pairs per benchmark: {args.pairs_per_benchmark}")
+    print(f"Max layer combo size: {args.max_layer_combo_size}")
+
+    # Import dependencies
+    from wisent.core.models.wisent_model import WisentModel
+    from wisent.core.geometry_search_space import GeometrySearchSpace, GeometrySearchConfig
+    from wisent.core.geometry_runner import GeometryRunner
+    from wisent.core.activations.extraction_strategy import ExtractionStrategy
+
+    # Parse strategies
+    if args.strategies:
+        strategy_names = [s.strip() for s in args.strategies.split(',')]
+        strategies = [ExtractionStrategy(s) for s in strategy_names]
+        print(f"Strategies: {strategy_names}")
+    else:
+        strategies = None  # Use default (all 7)
+        print("Strategies: all 7 default strategies")
+
+    # Parse benchmarks
+    if args.benchmarks:
+        if args.benchmarks.endswith('.txt'):
+            with open(args.benchmarks) as f:
+                benchmarks = [line.strip() for line in f if line.strip()]
+        else:
+            benchmarks = [b.strip() for b in args.benchmarks.split(',')]
+        print(f"Benchmarks: {len(benchmarks)} specified")
+    else:
+        benchmarks = None  # Use default (all)
+        print("Benchmarks: all available")
+
+    # Create config
+    config = GeometrySearchConfig(
+        pairs_per_benchmark=args.pairs_per_benchmark,
+        max_layer_combo_size=args.max_layer_combo_size,
+        random_seed=args.seed,
+        cache_activations=True,
+        cache_dir=args.cache_dir,
+    )
+
+    # Create search space
+    search_space = GeometrySearchSpace(
+        models=[args.model],
+        strategies=strategies,
+        benchmarks=benchmarks,
+        config=config,
+    )
+
+    print(f"\n{search_space.summary()}")
+
+    # Load model
+    print(f"\nLoading model {args.model}...")
+    model = WisentModel(args.model, device=args.device)
+    print(f"Model loaded: {model.num_layers} layers, hidden_size={model.hidden_size}")
+
+    # Create runner
+    cache_dir = args.cache_dir or f"/tmp/wisent_geometry_cache_{args.model.replace('/', '_')}"
+    runner = GeometryRunner(search_space, model, cache_dir=cache_dir)
+
+    # Run search
+    print(f"\nStarting geometry search...")
+    results = runner.run(show_progress=True)
+
+    # Save results
+    output_path = Path(args.output)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    results.save(str(output_path))
+    print(f"\nResults saved to: {output_path}")
+
+    # Print summary
+    print(f"\n{'='*60}")
+    print("SUMMARY")
+    print(f"{'='*60}")
+    print(f"Total time: {results.total_time_seconds / 3600:.2f} hours")
+    print(f" Extraction: {results.extraction_time_seconds / 3600:.2f} hours")
+    print(f" Testing: {results.test_time_seconds / 60:.1f} minutes")
+    print(f"Benchmarks tested: {results.benchmarks_tested}")
+    print(f"Strategies tested: {results.strategies_tested}")
+    print(f"Layer combos tested: {results.layer_combos_tested}")
+
+    print(f"\nStructure distribution:")
+    for struct, count in sorted(results.get_structure_distribution().items(), key=lambda x: -x[1]):
+        pct = 100 * count / results.layer_combos_tested
+        print(f" {struct}: {count} ({pct:.1f}%)")
+
+    print(f"\nTop 10 by linear score:")
+    for r in results.get_best_by_linear_score(10):
+        print(f" {r.benchmark}/{r.strategy} layers={r.layers}: linear={r.linear_score:.3f} best={r.best_structure}")
+
+    print(f"\nTop 10 by cone score:")
+    for r in results.get_best_by_structure('cone', 10):
+        print(f" {r.benchmark}/{r.strategy} layers={r.layers}: cone={r.cone_score:.3f} best={r.best_structure}")
+
+    # Summary by benchmark
+    print(f"\nSummary by benchmark (avg linear score):")
+    by_bench = results.get_summary_by_benchmark()
+    sorted_benches = sorted(by_bench.items(), key=lambda x: -x[1]['mean'])[:20]
+    for bench, stats in sorted_benches:
+        print(f" {bench}: mean={stats['mean']:.3f} max={stats['max']:.3f}")
+
+    print(f"\n{'='*60}")
+    print("CONCLUSION")
+    print(f"{'='*60}")
+
+    # Determine if unified direction exists
+    dist = results.get_structure_distribution()
+    total = sum(dist.values())
+    linear_pct = 100 * dist.get('linear', 0) / total if total > 0 else 0
+    cone_pct = 100 * dist.get('cone', 0) / total if total > 0 else 0
+    orthogonal_pct = 100 * dist.get('orthogonal', 0) / total if total > 0 else 0
+
+    if linear_pct > 50:
+        print(f"UNIFIED LINEAR DIRECTION EXISTS ({linear_pct:.1f}% linear)")
+        print("Recommendation: Use CAA with the best layer/strategy combination")
+    elif cone_pct > 30:
+        print(f"CONE STRUCTURE DETECTED ({cone_pct:.1f}% cone)")
+        print("Recommendation: Use PRISM with multi-directional steering")
+    elif orthogonal_pct > 50:
+        print(f"ORTHOGONAL STRUCTURE ({orthogonal_pct:.1f}% orthogonal)")
+        print("Recommendation: No unified direction - use per-benchmark directions or TITAN")
+    else:
+        print("MIXED STRUCTURE - no clear unified direction")
+        print("Recommendation: Use TITAN for adaptive multi-component steering")
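The new file is driven by the `geometry-search` subcommand added via `geometry_search_parser.py` and `main_parser.py`. For reference, the same machinery can be exercised programmatically; this condensed sketch reuses only calls visible in the file above, with the model id, paths, and numeric values as placeholders:

```python
# Condensed sketch of driving the new geometry-search classes directly,
# mirroring the CLI code above. Model id, paths, and numbers are placeholders.
from wisent.core.models.wisent_model import WisentModel
from wisent.core.geometry_search_space import GeometrySearchSpace, GeometrySearchConfig
from wisent.core.geometry_runner import GeometryRunner

config = GeometrySearchConfig(
    pairs_per_benchmark=50,
    max_layer_combo_size=2,
    random_seed=42,
    cache_activations=True,
    cache_dir="/tmp/wisent_geometry_cache",
)
search_space = GeometrySearchSpace(
    models=["meta-llama/Llama-3.2-1B-Instruct"],  # placeholder model id
    strategies=None,   # None = all default extraction strategies
    benchmarks=None,   # None = all available benchmarks
    config=config,
)

model = WisentModel("meta-llama/Llama-3.2-1B-Instruct", device="cuda")
runner = GeometryRunner(search_space, model, cache_dir="/tmp/wisent_geometry_cache")
results = runner.run(show_progress=True)
results.save("geometry_results.json")
print(results.get_structure_distribution())
```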
wisent/core/cli/get_activations.py:

@@ -90,7 +90,7 @@ def execute_get_activations(args):

     # 6. Collect activations
     print(f"\nā” Collecting activations...")
-    collector = ActivationCollector(model=model
+    collector = ActivationCollector(model=model)

     enriched_pairs = []
     for i, pair in enumerate(pair_set.pairs):
@@ -114,7 +114,7 @@ def execute_get_activations(args):
         'trait_label': trait_label,
         'model': args.model,
         'layers': layers,
-        '
+        'extraction_strategy': extraction_strategy.value,
         'num_pairs': len(enriched_pairs),
         'pairs': []
     }
wisent/core/cli/method_optimizer.py:

@@ -24,6 +24,7 @@ import torch
 from wisent.core.activations.activations_collector import ActivationCollector
 from wisent.core.activations.extraction_strategy import ExtractionStrategy
 from wisent.core.activations.core.atoms import LayerActivations
+from wisent.core.utils.device import resolve_default_device

 from wisent.core.contrastive_pairs.core.pair import ContrastivePair
 from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
@@ -175,7 +176,7 @@ class MethodOptimizer:
         self,
         model,
         method_name: str,
-        device: str =
+        device: str | None = None,
         verbose: bool = True,
     ):
         """
@@ -189,7 +190,7 @@ class MethodOptimizer:
         """
         self.model = model
         self.method_name = method_name.lower()
-        self.device = device
+        self.device = device or resolve_default_device()
         self.verbose = verbose

         # Validate method exists
@@ -250,7 +251,7 @@ class MethodOptimizer:
             "mean_pooling": ExtractionStrategy.CHAT_MEAN,
             "first_token": ExtractionStrategy.CHAT_FIRST,
             "max_pooling": ExtractionStrategy.CHAT_MAX_NORM,
-            "continuation_token": ExtractionStrategy.
+            "continuation_token": ExtractionStrategy.CHAT_FIRST,  # First answer token
         }

         prompt_strat_map = {
wisent/core/cli/modify_weights.py:

@@ -14,6 +14,7 @@ import time
 from pathlib import Path
 import torch

+from wisent.core.utils.device import resolve_default_device
 from wisent.core.cli_logger import setup_logger, bind
 from wisent.core.models.wisent_model import WisentModel
 from wisent.core.weight_modification import (
@@ -72,7 +73,7 @@ def execute_modify_weights(args):

     if vector_path.suffix == '.pt':
         # Load PyTorch format (from train-unified-goodness or similar)
-        checkpoint = torch.load(args.steering_vectors, map_location=
+        checkpoint = torch.load(args.steering_vectors, map_location=resolve_default_device(), weights_only=False)

         # Handle different .pt file formats
         if 'steering_vectors' in checkpoint:
@@ -354,7 +355,7 @@ def execute_modify_weights(args):

         execute_train_unified_goodness(unified_args)

-        checkpoint = torch.load(unified_args.output, map_location=
+        checkpoint = torch.load(unified_args.output, map_location=resolve_default_device(), weights_only=False)

         if 'steering_vectors' in checkpoint:
             raw_vectors = checkpoint['steering_vectors']
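Both torch.load call sites in this file (and the one in the vector-generation path further down) now pass an explicit `map_location` plus `weights_only=False`, which matters on PyTorch 2.6+ where `weights_only` defaults to `True` and would refuse these pickled checkpoint dicts. A hedged sketch of reading such a checkpoint outside the CLI; the path is a placeholder and the key names mirror those visible in the diff:

```python
# Sketch of loading a steering-vector checkpoint saved by these commands.
# The path is a placeholder; key names mirror those visible in the diff.
import torch
from wisent.core.utils.device import resolve_default_device

checkpoint = torch.load(
    "steering_vector.pt",                   # placeholder path
    map_location=resolve_default_device(),  # land tensors on the best available device
    weights_only=False,                     # checkpoint is a plain dict, not just tensors
)

if "steering_vector" in checkpoint:         # single-layer format
    vector = checkpoint["steering_vector"]
    layer = checkpoint["layer_index"]
    print(layer, tuple(vector.shape))
elif "steering_vectors" in checkpoint:      # multi-layer format
    print(sorted(checkpoint["steering_vectors"].keys()))
```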
wisent/core/cli/optimize_sample_size.py:

@@ -87,7 +87,7 @@ def execute_optimize_sample_size(args):
     # Get extraction strategy from args
     extraction_strategy = ExtractionStrategy(getattr(args, 'extraction_strategy', 'chat_last'))

-    collector = ActivationCollector(model=model
+    collector = ActivationCollector(model=model)

     # Collect test activations for all test pairs (ONCE)
     X_test_list = []
wisent/core/cli/optimize_steering.py:

@@ -77,7 +77,7 @@ def _run_optuna_search_for_task(

     try:
         # Collect activations
-        collector = ActivationCollector(model=model
+        collector = ActivationCollector(model=model)
         pos_acts = []
         neg_acts = []

@@ -389,7 +389,7 @@ def execute_comprehensive(args, model, loader):
         "first_token": ExtractionStrategy.CHAT_FIRST,
         "max_pooling": ExtractionStrategy.CHAT_MAX_NORM,
         "choice_token": ExtractionStrategy.MC_BALANCED,
-        "continuation_token": ExtractionStrategy.
+        "continuation_token": ExtractionStrategy.CHAT_FIRST,  # First answer token
     }
     if hasattr(args, 'search_token_aggregations') and args.search_token_aggregations:
         token_agg_names = [x.strip() for x in args.search_token_aggregations.split(',')]
@@ -610,7 +610,7 @@ def execute_comprehensive(args, model, loader):
         layer_str = str(layer)

         # Step 1: Generate steering vector using CAA with current token aggregation
-        collector = ActivationCollector(model=model
+        collector = ActivationCollector(model=model)

         pos_acts = []
         neg_acts = []
@@ -1456,7 +1456,7 @@ def execute_compare_methods(args, model, loader):

     # Collect activations once for all methods
     layer_str = str(args.layer)
-    collector = ActivationCollector(model=model
+    collector = ActivationCollector(model=model)

     print("šÆ Collecting training activations (ONCE)...")
     pos_acts = []
@@ -1719,7 +1719,7 @@ def execute_optimize_layer(args, model, loader):
         print("Aborted by user.")
         return {"action": "optimize-layer", "status": "aborted", "reason": "user declined reduced search"}

-    collector = ActivationCollector(model=model
+    collector = ActivationCollector(model=model)
     layer_results = {}
     best_layer = None
     best_accuracy = 0.0
@@ -1986,7 +1986,7 @@ def execute_optimize_strength(args, model, loader):

     # Collect activations ONCE
     layer_str = str(args.layer)
-    collector = ActivationCollector(model=model
+    collector = ActivationCollector(model=model)

     print("šÆ Collecting training activations (ONCE)...")
     pos_acts = []
@@ -2277,7 +2277,7 @@ def execute_auto(args, model, loader):
     print(f" Testing {len(strengths_to_test)} strengths: {strengths_to_test[0]:.2f} to {strengths_to_test[-1]:.2f}")
     print(f" Total configurations: {len(layers_to_test) * len(strengths_to_test)}\n")

-    collector = ActivationCollector(model=model
+    collector = ActivationCollector(model=model)
     all_results = {}
     best_config = None
     best_accuracy = 0.0
@@ -2575,16 +2575,15 @@ def execute_personalization(args, model):
     min_strength, max_strength = args.strength_range
     strengths_to_test = np.linspace(min_strength, max_strength, 7)

-    # Token aggregation strategies to test
+    # Token aggregation strategies to test
     token_aggregations_to_test = [
         ExtractionStrategy.CHAT_LAST,
         ExtractionStrategy.CHAT_MEAN,
         ExtractionStrategy.CHAT_FIRST,
         ExtractionStrategy.CHAT_MAX_NORM,
-        ExtractionStrategy.CHAT_GEN_POINT,
     ]

-    # Prompt construction strategies to test
+    # Prompt construction strategies to test
     prompt_constructions_to_test = [
         ExtractionStrategy.CHAT_LAST,
         ExtractionStrategy.CHAT_LAST,
@@ -2655,7 +2654,7 @@ def execute_personalization(args, model):
     print(flush=True)

     # Initialize activation collector
-    collector = ActivationCollector(model=model
+    collector = ActivationCollector(model=model)

     # Track results for all configurations
     all_results = {}
@@ -3108,16 +3107,15 @@ def execute_multi_personalization(args, model):
     min_strength, max_strength = args.strength_range
     strengths_to_test = np.linspace(min_strength, max_strength, 7)

-    # Token aggregation strategies to test
+    # Token aggregation strategies to test
     token_aggregations_to_test = [
         ExtractionStrategy.CHAT_LAST,
         ExtractionStrategy.CHAT_MEAN,
         ExtractionStrategy.CHAT_FIRST,
         ExtractionStrategy.CHAT_MAX_NORM,
-        ExtractionStrategy.CHAT_GEN_POINT,
     ]

-    # Prompt construction strategies to test
+    # Prompt construction strategies to test
     prompt_constructions_to_test = [
         ExtractionStrategy.CHAT_LAST,
         ExtractionStrategy.CHAT_LAST,
@@ -3176,7 +3174,7 @@ def execute_multi_personalization(args, model):
     print(f"\nš Test prompts: {test_prompts}", flush=True)

     # Initialize collector
-    collector = ActivationCollector(model=model
+    collector = ActivationCollector(model=model)

     # Track results
     all_results = {}
@@ -3565,7 +3563,7 @@ def execute_universal(args, model, loader):
     optimizer = MethodOptimizer(
         model=model,
         method_name=method_name,
-        device=args.device
+        device=args.device if hasattr(args, "device") and args.device else None,
         verbose=args.verbose if hasattr(args, "verbose") else True,
     )

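The optimizer entry points keep accepting the legacy token-aggregation names and map them onto the `ExtractionStrategy` enum; with `CHAT_GEN_POINT` dropped from the search lists, `continuation_token` now resolves to `CHAT_FIRST`. A small sketch of that normalization, using only enum members and values that appear in this diff:

```python
# Sketch of normalizing legacy token-aggregation names to ExtractionStrategy,
# mirroring the mapping visible in method_optimizer.py and optimize_steering.py.
from wisent.core.activations.extraction_strategy import ExtractionStrategy

TOKEN_AGG_MAP = {
    "mean_pooling": ExtractionStrategy.CHAT_MEAN,
    "first_token": ExtractionStrategy.CHAT_FIRST,
    "max_pooling": ExtractionStrategy.CHAT_MAX_NORM,
    "choice_token": ExtractionStrategy.MC_BALANCED,
    "continuation_token": ExtractionStrategy.CHAT_FIRST,  # first answer token
}


def to_extraction_strategy(name: str) -> ExtractionStrategy:
    """Accept either a legacy alias or a native value such as 'chat_last'."""
    if name in TOKEN_AGG_MAP:
        return TOKEN_AGG_MAP[name]
    return ExtractionStrategy(name)


print(to_extraction_strategy("continuation_token"))
```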
wisent/core/cli/optimize_weights.py:

@@ -28,6 +28,7 @@ from dataclasses import dataclass
 from typing import Any, Callable

 import torch
+from wisent.core.utils.device import resolve_default_device


 def upload_to_s3(local_path: str, s3_bucket: str, s3_key: str) -> bool:
@@ -661,7 +662,7 @@ def _generate_steering_vectors(args, num_pairs: int, num_layers: int = None) ->
     execute_train_unified_goodness(vector_args)

     # Load the .pt file
-    checkpoint = torch.load(temp_output_pt, map_location=
+    checkpoint = torch.load(temp_output_pt, map_location=resolve_default_device(), weights_only=False)

     # Handle different checkpoint formats
     if 'all_layer_vectors' in checkpoint: