wisent 0.5.12__py3-none-any.whl → 0.5.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic.
- wisent/__init__.py +1 -1
- wisent/core/activations/__init__.py +26 -0
- wisent/core/activations/activations.py +96 -0
- wisent/core/activations/activations_collector.py +71 -20
- wisent/core/activations/prompt_construction_strategy.py +47 -0
- wisent/core/agent/__init__.py +1 -18
- wisent/core/agent/budget.py +2 -2
- wisent/core/agent/device_benchmarks.py +1 -1
- wisent/core/agent/diagnose/__init__.py +1 -55
- wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
- wisent/core/agent/diagnose/response_diagnostics.py +4 -4
- wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
- wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
- wisent/core/agent/diagnose.py +2 -1
- wisent/core/autonomous_agent.py +10 -2
- wisent/core/benchmark_extractors.py +293 -0
- wisent/core/bigcode_integration.py +20 -7
- wisent/core/branding.py +108 -0
- wisent/core/cli/__init__.py +15 -0
- wisent/core/cli/create_steering_vector.py +138 -0
- wisent/core/cli/evaluate_responses.py +715 -0
- wisent/core/cli/generate_pairs.py +128 -0
- wisent/core/cli/generate_pairs_from_task.py +119 -0
- wisent/core/cli/generate_responses.py +129 -0
- wisent/core/cli/generate_vector_from_synthetic.py +149 -0
- wisent/core/cli/generate_vector_from_task.py +147 -0
- wisent/core/cli/get_activations.py +191 -0
- wisent/core/cli/optimize_classification.py +339 -0
- wisent/core/cli/optimize_steering.py +364 -0
- wisent/core/cli/tasks.py +182 -0
- wisent/core/cli_logger.py +22 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
- wisent/core/data_loaders/__init__.py +235 -0
- wisent/core/data_loaders/loaders/lm_loader.py +2 -2
- wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
- wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
- wisent/core/download_full_benchmarks.py +79 -2
- wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
- wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
- wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
- wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
- wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
- wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
- wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
- wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
- wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
- wisent/core/lm_eval_harness_ground_truth.py +3 -2
- wisent/core/main.py +57 -0
- wisent/core/model_persistence.py +2 -2
- wisent/core/models/wisent_model.py +6 -6
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/optuna/steering/steering_optimization.py +1 -1
- wisent/core/parser_arguments/__init__.py +10 -0
- wisent/core/parser_arguments/agent_parser.py +110 -0
- wisent/core/parser_arguments/configure_model_parser.py +7 -0
- wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
- wisent/core/parser_arguments/evaluate_parser.py +40 -0
- wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
- wisent/core/parser_arguments/full_optimize_parser.py +115 -0
- wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
- wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
- wisent/core/parser_arguments/generate_responses_parser.py +15 -0
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_parser.py +90 -0
- wisent/core/parser_arguments/get_activations_parser.py +90 -0
- wisent/core/parser_arguments/main_parser.py +152 -0
- wisent/core/parser_arguments/model_config_parser.py +59 -0
- wisent/core/parser_arguments/monitor_parser.py +17 -0
- wisent/core/parser_arguments/multi_steer_parser.py +47 -0
- wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
- wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
- wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
- wisent/core/parser_arguments/synthetic_parser.py +93 -0
- wisent/core/parser_arguments/tasks_parser.py +584 -0
- wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
- wisent/core/parser_arguments/utils.py +111 -0
- wisent/core/prompts/core/prompt_formater.py +3 -3
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
- wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
- wisent/core/steering_optimizer.py +45 -21
- wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
- wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
- wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
- wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
- wisent/core/tasks/livecodebench_task.py +4 -103
- wisent/core/timing_calibration.py +1 -1
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/METADATA +3 -3
- wisent-0.5.14.dist-info/RECORD +294 -0
- wisent-0.5.14.dist-info/entry_points.txt +2 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
- wisent/classifiers/core/atoms.py +0 -747
- wisent/classifiers/models/logistic.py +0 -29
- wisent/classifiers/models/mlp.py +0 -47
- wisent/cli/classifiers/classifier_rotator.py +0 -137
- wisent/cli/cli_logger.py +0 -142
- wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
- wisent/cli/wisent_cli/commands/listing.py +0 -154
- wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
- wisent/cli/wisent_cli/main.py +0 -93
- wisent/cli/wisent_cli/shell.py +0 -80
- wisent/cli/wisent_cli/ui.py +0 -69
- wisent/cli/wisent_cli/util/aggregations.py +0 -43
- wisent/cli/wisent_cli/util/parsing.py +0 -126
- wisent/cli/wisent_cli/version.py +0 -4
- wisent/opti/methods/__init__.py +0 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent-0.5.12.dist-info/RECORD +0 -220
- /wisent/{benchmarks → core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
- /wisent/{benchmarks/coding → core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
- /wisent/{benchmarks/coding/metrics → core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
- /wisent/{benchmarks/coding/metrics/core → core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/providers/core → core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
- /wisent/{benchmarks/coding/providers/livecodebench → core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/__init__.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/core/__init__.py +0 -0
- /wisent/{opti → core/opti}/core/atoms.py +0 -0
- /wisent/{classifiers → core/opti/methods}/__init__.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_classificator.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_steering.py +0 -0
- /wisent/{classifiers/core → core/synthetic}/__init__.py +0 -0
- /wisent/{classifiers/models → core/synthetic/cleaners}/__init__.py +0 -0
- /wisent/{cli → core/synthetic/cleaners/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/core/atoms.py +0 -0
- /wisent/{cli/classifiers → core/synthetic/cleaners/methods}/__init__.py +0 -0
- /wisent/{cli/data_loaders → core/synthetic/cleaners/methods/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/methods/core/atoms.py +0 -0
- /wisent/{cli/evaluators → core/synthetic/db_instructions}/__init__.py +0 -0
- /wisent/{cli/steering_methods → core/synthetic/db_instructions/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/db_instructions/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli → core/synthetic/generators}/__init__.py +0 -0
- /wisent/{cli/wisent_cli/commands → core/synthetic/generators/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli/util → core/synthetic/generators/diversities}/__init__.py +0 -0
- /wisent/{opti → core/synthetic/generators/diversities/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/diversities/core/core.py +0 -0
- /wisent/{opti/core → core/synthetic/generators/diversities/methods}/__init__.py +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/WHEEL +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/top_level.txt +0 -0
wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py
ADDED
@@ -0,0 +1,140 @@
+"""Perplexity evaluator for language modeling benchmarks.
+
+Used for tasks that measure language modeling performance.
+"""
+
+from typing import Any
+import logging
+import math
+import torch
+
+from wisent.core.evaluators.core.atoms import BaseEvaluator, EvalResult
+
+logger = logging.getLogger(__name__)
+
+
+class PerplexityEvaluator(BaseEvaluator):
+    """Evaluator using perplexity for language modeling tasks.
+
+    Compatible with:
+    - WikiText: Language modeling
+    - LAMBADA: Word prediction in context
+    - Any loglikelihood_rolling task
+    """
+
+    name = "perplexity"
+    description = "Perplexity evaluator for language modeling"
+    task_names = ("wikitext", "lambada_openai", "lambada_standard")
+
+    def __init__(self, model=None):
+        """Initialize perplexity evaluator.
+
+        Args:
+            model: Model with loglikelihood capabilities
+        """
+        super().__init__()
+        self.model = model
+
+    def evaluate(self, response: str, expected: Any, **kwargs) -> EvalResult:
+        """Evaluate using perplexity.
+
+        Args:
+            response: Text to evaluate (for language modeling)
+            expected: NOT USED (perplexity is computed on response)
+            **kwargs:
+                model: Model instance (WisentModel or similar, overrides self.model)
+                context: Optional context for conditional generation
+
+        Returns:
+            EvalResult with perplexity as confidence metric (lower is better)
+        """
+        model = kwargs.get('model', self.model)
+        context = kwargs.get('context', '')
+
+        if model is None:
+            raise ValueError(
+                "No model provided for perplexity computation. "
+                "Please provide a model via __init__ or as a kwarg."
+            )
+
+        try:
+            # Compute perplexity
+            full_text = f"{context}{response}" if context else response
+            perplexity = self._compute_perplexity(model, full_text)
+
+            # Lower perplexity is better, so we use negative for confidence
+            # (higher confidence = lower perplexity)
+            confidence = -perplexity
+
+            return EvalResult(
+                ground_truth="EVALUATED",
+                method_used=self.name,
+                confidence=confidence,
+                details=f"Perplexity: {perplexity:.4f} (lower is better)",
+            )
+
+        except Exception as e:
+            logger.error(f"Error computing perplexity: {e}")
+            return EvalResult(
+                ground_truth="ERROR",
+                method_used=self.name,
+                confidence=0.0,
+                details=f"Perplexity computation failed: {str(e)}",
+            )
+
+    def _compute_perplexity(self, model, text: str) -> float:
+        """Compute perplexity for text.
+
+        Args:
+            model: Model with HuggingFace interface (WisentModel or similar)
+            text: Text to evaluate
+
+        Returns:
+            Perplexity value (lower is better)
+        """
+        # Get model and tokenizer from WisentModel
+        if hasattr(model, 'hf_model') and hasattr(model, 'tokenizer'):
+            hf_model = model.hf_model
+            tokenizer = model.tokenizer
+        else:
+            # Assume model is directly a HuggingFace model
+            hf_model = model
+            tokenizer = getattr(model, 'tokenizer', None)
+            if tokenizer is None:
+                raise ValueError("Model must have a tokenizer attribute")
+
+        # Tokenize the text
+        encodings = tokenizer(text, return_tensors='pt')
+        input_ids = encodings['input_ids'].to(hf_model.device)
+
+        # Get model outputs (logits)
+        with torch.no_grad():
+            outputs = hf_model(input_ids)
+            logits = outputs.logits
+
+        # Shift logits and labels for next-token prediction
+        # logits: [batch, seq_len, vocab_size]
+        # We want to predict tokens 1..N from tokens 0..N-1
+        shift_logits = logits[:, :-1, :].contiguous()
+        shift_labels = input_ids[:, 1:].contiguous()
+
+        # Compute log probabilities
+        log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
+
+        # Gather the log probabilities of the actual tokens
+        # shift_labels: [batch, seq_len-1]
+        # We need to gather from log_probs: [batch, seq_len-1, vocab_size]
+        batch_size, seq_len = shift_labels.shape
+        token_log_probs = log_probs.gather(
+            dim=-1,
+            index=shift_labels.unsqueeze(-1)
+        ).squeeze(-1)
+
+        # Compute negative log-likelihood (NLL)
+        nll = -token_log_probs.sum()
+
+        # Compute perplexity = exp(NLL / num_tokens)
+        num_tokens = seq_len
+        perplexity = torch.exp(nll / num_tokens)
+
+        return float(perplexity)
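
A minimal usage sketch of the new evaluator (not part of the diff): it relies only on the duck-typed `.hf_model` / `.tokenizer` attributes checked in `_compute_perplexity` above; the model name, sample text, and the assumption that `EvalResult` exposes its fields as attributes are illustrative.

    from types import SimpleNamespace
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from wisent.core.evaluators.benchmark_specific.perplexity_evaluator import PerplexityEvaluator

    # Illustrative model choice; any causal LM with a matching tokenizer works.
    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")

    # The evaluator duck-types on .hf_model / .tokenizer, so a simple namespace
    # can stand in for a full WisentModel in this sketch.
    wrapper = SimpleNamespace(hf_model=lm, tokenizer=tok)

    evaluator = PerplexityEvaluator(model=wrapper)
    result = evaluator.evaluate("The quick brown fox jumps over the lazy dog.", expected=None)
    print(result.details)  # e.g. "Perplexity: 41.2345 (lower is better)" -- assuming attribute access on EvalResult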
wisent/core/evaluators/benchmark_specific/personalization_evaluator.py
ADDED
@@ -0,0 +1,250 @@
+"""Personalization evaluator for assessing personality trait manifestation in model responses.
+
+Evaluates how well model-generated responses exhibit specified personality traits
+by using an LLM judge to assess trait alignment and intensity.
+"""
+
+from typing import Any, Optional
+import logging
+import re
+
+from wisent.core.evaluators.core.atoms import BaseEvaluator, EvalResult
+
+logger = logging.getLogger(__name__)
+
+
+class PersonalizationEvaluator(BaseEvaluator):
+    """Evaluator for personality trait manifestation in responses.
+
+    Uses an LLM judge to evaluate how well responses exhibit target personality traits.
+    Compatible with personalization benchmarks that require trait assessment.
+    """
+
+    name = "personalization"
+    description = "Evaluates personality trait manifestation in model responses"
+    task_names = ("personalization",)
+
+    # Default judge model - can be overridden
+    default_judge_model = "gpt-4"
+
+    # Evaluation prompt template
+    JUDGE_PROMPT_TEMPLATE = """You are an expert evaluator analyzing personality traits in text responses.
+
+**Target Trait**: {trait}
+**Trait Description**: {trait_description}
+
+**Prompt Given to Model**: {prompt}
+
+**Model Response to Evaluate**:
+{response}
+
+Please evaluate how well this response exhibits the target trait "{trait}".
+
+Provide your evaluation in the following format:
+
+TRAIT_SCORE: [0-10]
+INTENSITY: [weak/moderate/strong]
+EXPLANATION: [Brief explanation of your rating]
+
+Where TRAIT_SCORE represents:
+- 0-2: Trait not present or contradicts trait
+- 3-4: Minimal trait manifestation
+- 5-6: Moderate trait manifestation
+- 7-8: Strong trait manifestation
+- 9-10: Exemplary trait manifestation
+
+Evaluate ONLY based on how well the response demonstrates the target trait, not overall quality."""
+
+    def __init__(self, judge_model: Optional[str] = None):
+        """Initialize personalization evaluator.
+
+        Args:
+            judge_model: Model to use as judge (default: gpt-4)
+        """
+        self.judge_model = judge_model or self.default_judge_model
+
+    def evaluate(self, response: str, expected: Any, **kwargs) -> EvalResult:
+        """Evaluate personality trait manifestation in response.
+
+        Args:
+            response: Generated model response to evaluate
+            expected: Expected trait information (dict with 'trait', 'trait_description')
+            **kwargs:
+                prompt: Original prompt given to model
+                judge_model: Override default judge model
+                use_mock: Use mock evaluation for testing (default: False)
+
+        Returns:
+            EvalResult with trait_score as confidence, details containing analysis
+        """
+        # Extract trait information
+        if isinstance(expected, dict):
+            trait = expected.get('trait', 'unknown')
+            trait_description = expected.get('trait_description', '')
+        else:
+            trait = str(expected)
+            trait_description = f"The trait: {trait}"
+
+        prompt = kwargs.get('prompt', '')
+        judge_model = kwargs.get('judge_model', self.judge_model)
+        use_mock = kwargs.get('use_mock', False)
+
+        # Mock evaluation for testing without API calls
+        if use_mock:
+            return self._mock_evaluate(response, trait, prompt)
+
+        # Prepare judge prompt
+        judge_prompt = self.JUDGE_PROMPT_TEMPLATE.format(
+            trait=trait,
+            trait_description=trait_description,
+            prompt=prompt,
+            response=response
+        )
+
+        try:
+            # Call LLM judge
+            judge_response = self._call_judge(judge_prompt, judge_model)
+
+            # Parse judge response
+            trait_score, intensity, explanation = self._parse_judge_response(judge_response)
+
+            # Normalize score to 0-1 confidence
+            confidence = trait_score / 10.0
+
+            # Determine ground truth based on score
+            if trait_score >= 7:
+                ground_truth = "TRUTHFUL"  # Strong trait manifestation
+            elif trait_score >= 4:
+                ground_truth = "UNKNOWN"  # Moderate trait manifestation
+            else:
+                ground_truth = "UNTRUTHFUL"  # Weak/absent trait manifestation
+
+            return EvalResult(
+                ground_truth=ground_truth,
+                method_used=self.name,
+                confidence=confidence,
+                details=f"Trait: {trait} | Score: {trait_score}/10 | Intensity: {intensity} | {explanation}",
+                meta={
+                    "trait": trait,
+                    "trait_score": trait_score,
+                    "intensity": intensity,
+                    "judge_model": judge_model,
+                    "explanation": explanation
+                }
+            )
+
+        except Exception as e:
+            logger.exception(f"Error in personalization evaluation: {e}")
+            return EvalResult(
+                ground_truth="UNKNOWN",
+                method_used=self.name,
+                confidence=0.0,
+                details=f"Evaluation error: {str(e)}",
+                meta={"trait": trait, "error": str(e)}
+            )
+
+    def _mock_evaluate(self, response: str, trait: str, prompt: str) -> EvalResult:
+        """Mock evaluation for testing without API calls.
+
+        Args:
+            response: Model response
+            trait: Target trait
+            prompt: Original prompt
+
+        Returns:
+            EvalResult with mock scores
+        """
+        # Simple heuristic: check if trait appears in response
+        response_lower = response.lower()
+        trait_lower = trait.lower()
+
+        # Calculate mock score based on simple heuristics
+        score = 5  # baseline
+
+        if trait_lower in response_lower:
+            score += 2
+
+        if len(response) > 100:
+            score += 1
+
+        if len(response.split('.')) > 2:
+            score += 1
+
+        score = min(score, 10)
+
+        intensity = "weak" if score < 5 else ("moderate" if score < 7 else "strong")
+
+        confidence = score / 10.0
+        ground_truth = "TRUTHFUL" if score >= 7 else ("UNKNOWN" if score >= 4 else "UNTRUTHFUL")
+
+        return EvalResult(
+            ground_truth=ground_truth,
+            method_used=f"{self.name}_mock",
+            confidence=confidence,
+            details=f"Mock evaluation | Trait: {trait} | Score: {score}/10 | Intensity: {intensity}",
+            meta={
+                "trait": trait,
+                "trait_score": score,
+                "intensity": intensity,
+                "judge_model": "mock",
+                "explanation": "Mock evaluation based on simple heuristics"
+            }
+        )
+
+    def _call_judge(self, prompt: str, model: str) -> str:
+        """Call LLM judge to evaluate response.
+
+        Args:
+            prompt: Judge prompt
+            model: Model identifier
+
+        Returns:
+            Judge response text
+        """
+        # This would call OpenAI API or other LLM API
+        # For now, raising NotImplementedError to indicate it needs implementation
+        try:
+            import openai
+
+            response = openai.ChatCompletion.create(
+                model=model,
+                messages=[
+                    {"role": "system", "content": "You are an expert evaluator of personality traits in text."},
+                    {"role": "user", "content": prompt}
+                ],
+                temperature=0.0,
+                max_tokens=500
+            )
+
+            return response.choices[0].message.content
+
+        except ImportError:
+            raise NotImplementedError(
+                "OpenAI package not installed. Install with: pip install openai\n"
+                "Or use use_mock=True for testing without API calls."
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error calling judge model: {e}")
+
+    def _parse_judge_response(self, judge_response: str) -> tuple[float, str, str]:
+        """Parse judge response to extract score, intensity, and explanation.
+
+        Args:
+            judge_response: Raw judge response text
+
+        Returns:
+            Tuple of (trait_score, intensity, explanation)
+        """
+        # Extract TRAIT_SCORE
+        score_match = re.search(r'TRAIT_SCORE:\s*(\d+(?:\.\d+)?)', judge_response, re.IGNORECASE)
+        trait_score = float(score_match.group(1)) if score_match else 5.0
+
+        # Extract INTENSITY
+        intensity_match = re.search(r'INTENSITY:\s*(weak|moderate|strong)', judge_response, re.IGNORECASE)
+        intensity = intensity_match.group(1).lower() if intensity_match else "moderate"
+
+        # Extract EXPLANATION
+        explanation_match = re.search(r'EXPLANATION:\s*(.+?)(?=\n\n|\Z)', judge_response, re.IGNORECASE | re.DOTALL)
+        explanation = explanation_match.group(1).strip() if explanation_match else "No explanation provided"
+
+        return trait_score, intensity, explanation
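
A quick sketch of the judge-free path (not part of the diff): it uses only names defined in the file above, the `trait`/`trait_description` keys documented in `evaluate()`, and `use_mock=True` so no OpenAI call is made; the sample texts and the assumption that `EvalResult.meta` is attribute-accessible are illustrative.

    from wisent.core.evaluators.benchmark_specific.personalization_evaluator import PersonalizationEvaluator

    evaluator = PersonalizationEvaluator()  # defaults to the "gpt-4" judge, unused when mocking
    result = evaluator.evaluate(
        response="I double-checked every step before answering, just to be safe.",
        expected={"trait": "conscientiousness", "trait_description": "Careful, diligent, detail-oriented."},
        prompt="Describe how you approach a new task.",
        use_mock=True,  # skips the LLM judge and scores with the simple heuristics above
    )
    print(result.meta["trait_score"], result.meta["intensity"])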
wisent/core/evaluators/rotator.py (renamed from wisent/cli/evaluators/evaluator_rotator.py)
CHANGED
@@ -20,7 +20,7 @@ class EvaluatorRotator:
         self,
         evaluator: Union[str, BaseEvaluator, Type[BaseEvaluator], None] = None,
         task_name: Optional[str] = None,
-        evaluators_location: Union[str, Path] = "
+        evaluators_location: Union[str, Path] = "wisent.core.evaluators.oracles",
         autoload: bool = True,
     ) -> None:
         if autoload:
@@ -29,7 +29,7 @@ class EvaluatorRotator:
         self._task_name = task_name
 
     @staticmethod
-    def discover_evaluators(location: Union[str, Path] = "
+    def discover_evaluators(location: Union[str, Path] = "wisent.core.evaluators.oracles") -> None:
         """
         Import all evaluator modules so BaseEvaluator subclasses self-register.
 
@@ -130,10 +130,10 @@ class EvaluatorRotator:
 
 
 if __name__ == "__main__":
-    from
+    from wisent.core.evaluators.rotator import EvaluatorRotator
 
     rot = EvaluatorRotator(
-        evaluators_location="
+        evaluators_location="wisent.core.evaluators.oracles",  # << no leading slash
         autoload=True,
     )
 
wisent/core/lm_eval_harness_ground_truth.py
CHANGED
@@ -7,7 +7,8 @@ This module provides ground truth evaluation using the lm-eval-harness framework
 import logging
 from typing import Any, Dict
 
-from wisent.core.activations import ActivationAggregationStrategy
+from wisent.core.activations.core.atoms import ActivationAggregationStrategy
+from wisent.core.activations.activations import Activations
 from wisent.core.layer import Layer
 
 logger = logging.getLogger(__name__)
@@ -636,7 +637,7 @@ class LMEvalHarnessGroundTruth:
         try:
             import json
 
-            eval_methods_path = "
+            eval_methods_path = "wisent/parameters/benchmarks/benchmark_evaluation_methods.json"
             with open(eval_methods_path) as f:
                 benchmark_methods = json.load(f)
             return benchmark_methods.get(task_name, "text-generation")
wisent/core/main.py
ADDED
@@ -0,0 +1,57 @@
+"""
+Main entry point for the Wisent CLI.
+
+This module connects the argparse parser (wisent/core/parser_arguments/) to execution logic
+and provides the main() function that serves as the CLI entry point.
+"""
+
+import sys
+from wisent.core.parser_arguments import setup_parser
+from wisent.core.branding import print_banner
+from wisent.core.cli import execute_tasks, execute_generate_pairs_from_task, execute_generate_pairs, execute_get_activations, execute_create_steering_vector, execute_generate_vector_from_task, execute_generate_vector_from_synthetic, execute_optimize_classification, execute_optimize_steering, execute_generate_responses, execute_evaluate_responses
+
+
+def main():
+    """Main entry point for the Wisent CLI."""
+    # Show banner
+    print_banner("Wisent CLI", width=64, use_color=True)
+
+    # Parse arguments
+    parser = setup_parser()
+    args = parser.parse_args()
+
+    # If no command specified, show help
+    if not hasattr(args, 'command') or args.command is None:
+        parser.print_help()
+        sys.exit(0)
+
+    # Execute based on command
+    if args.command == 'tasks':
+        execute_tasks(args)
+    elif args.command == 'generate-pairs':
+        execute_generate_pairs(args)
+    elif args.command == 'generate-pairs-from-task':
+        execute_generate_pairs_from_task(args)
+    elif args.command == 'get-activations':
+        execute_get_activations(args)
+    elif args.command == 'create-steering-vector':
+        execute_create_steering_vector(args)
+    elif args.command == 'generate-vector-from-task':
+        execute_generate_vector_from_task(args)
+    elif args.command == 'generate-vector-from-synthetic':
+        execute_generate_vector_from_synthetic(args)
+    elif args.command == 'optimize-classification':
+        execute_optimize_classification(args)
+    elif args.command == 'optimize-steering':
+        execute_optimize_steering(args)
+    elif args.command == 'generate-responses':
+        execute_generate_responses(args)
+    elif args.command == 'evaluate-responses':
+        execute_evaluate_responses(args)
+    else:
+        print(f"\n✗ Command '{args.command}' is not yet implemented")
+        sys.exit(1)
+
+
+if __name__ == '__main__':
+    main()
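
A minimal sketch (not part of the diff) of driving the new dispatcher in-process; the subcommand names come from main.py above, and the argv values are illustrative. With no subcommand, main() prints the parser help and exits 0; real runs would pass one of the dispatched commands (e.g. tasks, generate-pairs) plus whatever flags its parser requires.

    import sys
    from wisent.core.main import main

    sys.argv = ["wisent"]          # argv[0] is arbitrary; no subcommand -> help + exit(0)
    try:
        main()
    except SystemExit as exc:      # main() calls sys.exit(); catch it when embedding
        print(f"exited with status {exc.code}")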
wisent/core/model_persistence.py
CHANGED
@@ -268,7 +268,7 @@ def create_classifier_metadata(
         'token_aggregation': token_aggregation,
         'detection_threshold': detection_threshold,
         'created_at': datetime.datetime.now().isoformat(),
-        '
+        'wisent_version': '1.0.0'  # Could be dynamically determined
     }
 
     # Add any additional metadata
@@ -308,7 +308,7 @@ def create_steering_vector_metadata(
         'vector_strength': vector_strength,
         'training_samples': training_samples,
         'created_at': datetime.datetime.now().isoformat(),
-        '
+        'wisent_version': '1.0.0'
     }
 
     # Add any additional metadata
wisent/core/models/wisent_model.py
CHANGED
@@ -95,7 +95,7 @@ class WisentModel:
         elif self.device == "cuda":
             load_kwargs["dtype"] = torch.float16
             load_kwargs["device_map"] = "auto"
-            load_kwargs["attn_implementation"] = "flash_attention_2"  #
+            load_kwargs["attn_implementation"] = "flash_attention_2"  # Uses flash-attn for 2-4x speedup
         else:
             load_kwargs["dtype"] = torch.float32
             load_kwargs["device_map"] = None
@@ -330,7 +330,7 @@ class WisentModel:
 
         batch = self.tokenizer.pad(singles, padding=True, return_tensors="pt")
 
-        batch = {k: v.to(
+        batch = {k: v.to(self.device) for k, v in batch.items()}
 
         return batch
 
@@ -463,8 +463,8 @@ class WisentModel:
             )
             # Move tensors to the correct device (same as _batch_encode does)
             batch = {
-                "input_ids": tokenizer_output["input_ids"].to(
-                "attention_mask": tokenizer_output["attention_mask"].to(
+                "input_ids": tokenizer_output["input_ids"].to(self.device),
+                "attention_mask": tokenizer_output["attention_mask"].to(self.device)
             }
         else:
             # Current behavior: apply chat template
@@ -695,8 +695,8 @@ class WisentModel:
             )
             # Move tensors to the correct device (same as _batch_encode does)
             batch = {
-                "input_ids": tokenizer_output["input_ids"].to(
-                "attention_mask": tokenizer_output["attention_mask"].to(
+                "input_ids": tokenizer_output["input_ids"].to(self.device),
+                "attention_mask": tokenizer_output["attention_mask"].to(self.device)
             }
         else:
             # Current behavior: apply chat template
@@ -28,12 +28,12 @@ def get_model_dtype(model) -> torch.dtype:
     Extract model's native dtype from parameters.
 
     Args:
-        model: PyTorch model or
+        model: PyTorch model or wisent Model wrapper
 
     Returns:
         The model's native dtype
     """
-    # Handle
+    # Handle wisent Model wrapper
     if hasattr(model, "hf_model"):
         model_params = model.hf_model.parameters()
     else:
wisent/core/steering_optimizer.py
CHANGED
@@ -989,7 +989,7 @@ class SteeringOptimizer:
         On subsequent calls: Return cached classifier from current session
 
         Args:
-            model: Language model (
+            model: Language model (wisent Model wrapper)
             optimization_config: Primary configuration source
             model_name: Fallback model name if optimization_config not provided
             task_name: Fallback task name if optimization_config not provided
wisent/core/parser_arguments/__init__.py
ADDED
@@ -0,0 +1,10 @@
+"""
+Parser arguments package for Wisent CLI.
+
+This package contains argument parser definitions for each CLI command.
+Each command has its own parser file for better organization and maintainability.
+"""
+
+from wisent.core.parser_arguments.main_parser import setup_parser
+
+__all__ = ["setup_parser"]