wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.2.dist-info/METADATA +67 -0
- wisent-0.5.2.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,727 @@
from __future__ import annotations

import logging
from contextlib import contextmanager
from typing import Any, Iterable

import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    TextIteratorStreamer
)

from wisent.core.models.core.atoms import SteeringPlan, SteeringVector, HookHandleGroup, GenerationStats, TopLogits
from wisent.core.activations.core.atoms import RawActivationMap

from wisent.core.prompts.core.atom import ChatMessage
from wisent.core.utils.device import resolve_default_device, resolve_torch_device
from wisent.core.contrastive_pairs.diagnostics import run_control_steering_diagnostics

import threading

__all__ = ["WisentModel"]

logger = logging.getLogger(__name__)


class WisentModel:
    """
    Wrapper around a causal LM (HF transformers) with steering capabilities.

    attributes:
        model_name:
            HF repo id or local path (e.g., 'meta-llama/Llama-3-8B-Instruct', 'Qwen/Qwen2.5-7B-Instruct').
        device:
            'cuda', 'cuda:0', 'cpu', etc. If None, leave to HF defaults/accelerate.
        hf_model:
            the loaded PreTrainedModel instance.
        tokenizer:
            the loaded PreTrainedTokenizerBase instance.
        hidden_size:
            model hidden size (last dim of residual stream).
        num_layers:
            number of decoder blocks we can hook.
        _steering_plan:
            current SteeringPlan (can be empty).
        _hook_group:
            manages active hooks for clean detach.
    """
    def __init__(
        self,
        model_name: str,
        steering_layers: list[RawActivationMap] | RawActivationMap | None = None,
        steering_weights: list[float] | None = None,
        layers_description: list[str] | None = None,
        device: str | None = None,
        hf_model: AutoModelForCausalLM | None = None
    ):
        """
        Initialize the wrapper (model + tokenizer + default steering plan).

        arguments:
            model_name:
                HF repo id or local path (e.g., 'meta-llama/Llama-3-8B-Instruct', 'Qwen/Qwen2.5-7B-Instruct').
            steering_layers:
                list of RawActivationMap or single RawActivationMap of steering vectors (layer_name -> tensor), optional (can be {}).
                We can have for example steering vectors obtained during training on a specific trait (e.g., toxicity and evilness).
                So, by passing multiple steering vectors, we can combine them at inference time. If we don't pass any weights,
                they will be equally weighted.
            steering_weights:
                list of weights for each steering vector, optional (can be None). If None, all vectors are equally weighted.
            device:
                'cuda', 'cuda:0', 'cpu', etc. If None, leave to HF defaults/accelerate.
            hf_model:
                optional preloaded model (skips from_pretrained if provided).
        """
        self.model_name = model_name
        self.device = device or resolve_default_device()

        # Determine appropriate dtype and settings for the device
        load_kwargs = {
            "trust_remote_code": True,
            "low_cpu_mem_usage": True,
        }

        if self.device == "mps":
            load_kwargs["dtype"] = torch.float16
            load_kwargs["device_map"] = "mps"
            load_kwargs["attn_implementation"] = "eager"  # MPS doesn't support flash attention
        elif self.device == "cuda":
            load_kwargs["dtype"] = torch.float16
            load_kwargs["device_map"] = "auto"
        else:
            load_kwargs["dtype"] = torch.float32
            load_kwargs["device_map"] = None

        self.hf_model: PreTrainedModel = hf_model or AutoModelForCausalLM.from_pretrained(
            model_name,
            **load_kwargs
        )

        device_map_used = load_kwargs.get("device_map")

        # Only move to device if device_map wasn't used
        if device_map_used is None:
            self.hf_model.to(self.device)

        self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
            model_name, use_fast=True, trust_remote_code=True
        )

        if not self._is_chat_tokenizer():
            raise ValueError("Tokenizer does not support chat templates (missing apply_chat_template method). Change to a chat-capable model.")

        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if getattr(self.hf_model.generation_config, "pad_token_id", None) is None:
            self.hf_model.generation_config.pad_token_id = self.tokenizer.pad_token_id

        self._steering_plan: SteeringPlan = SteeringPlan.from_raw(
            raw=steering_layers,
            weights=steering_weights,
            layers_description=layers_description,
        )
        self._hook_group = HookHandleGroup()

        self._layers, self._hidden_size = self._resolve_decoder_layers_and_hidden()

    @property
    def hidden_size(self) -> int:
        return self._hidden_size

    @property
    def num_layers(self) -> int:
        return len(self._layers)

    def _resolve_decoder_layers_and_hidden(self) -> tuple[list[nn.Module], int]:
        m = self.hf_model
        hidden_size = getattr(m.config, "hidden_size", None) or getattr(m.config, "n_embd", None)
        layers: list[nn.Module] = []

        candidates = [
            "layers",
            "model.layers",
            "model.decoder.layers",
            "transformer.h",
            "base_model.model.layers",
            "blocks", "model.blocks",
        ]
        for path in candidates:
            obj = m
            try:
                for attr in path.split("."):
                    if attr:
                        obj = getattr(obj, attr)
                if (isinstance(obj, nn.ModuleList) or
                        (isinstance(obj, (list, tuple)) and obj and isinstance(obj[0], nn.Module))):
                    layers = list(obj)
                    break
            except AttributeError:
                continue

        if not layers:
            raise RuntimeError("Could not resolve decoder layers for steering hooks.")

        if hidden_size is None:
            for p in m.parameters():
                if p.ndim >= 2:
                    hidden_size = int(p.shape[-1]); break
            if hidden_size is None:
                raise RuntimeError("Could not infer hidden size from model config.")

        return layers, int(hidden_size)

    def _is_chat_tokenizer(self) -> bool:
        return hasattr(self.tokenizer, "apply_chat_template") and callable(getattr(self.tokenizer, "apply_chat_template"))

    def apply_steering(self, plan: SteeringPlan | None = None) -> None:
        """
        Register forward hooks to add steering vectors *after* the selected decoder blocks.
        If plan is None, use the internal plan set at init or via set_steering_from_raw().
        Multiple vectors per layer are summed inside the hook.

        arguments:
            plan:
                optional SteeringPlan to use for this call only (overrides internal plan).
                If None, uses the internal plan.

                SteeringPlan maps layer names (str) to list of SteeringVector, each with its own scale.
                Like this:
                    plan = SteeringPlan.from_raw({"6": torch.randn(wm.hidden_size)}, scale=0.7)

        example:
            >>> wm.apply_steering()  # uses current internal plan
            >>> # ... generate ...
            >>> wm.detach()  # back to vanilla
        """
        p = plan or self._steering_plan
        if p.is_empty():
            return

        p.validate_hidden_size(hidden_size=self._hidden_size)
        self.detach()

        name_to_index = {str(i + 1): i for i in range(len(self._layers))}

        for lname, vec in p.layers.items():
            if lname not in name_to_index:
                continue
            idx = name_to_index[lname]
            layer = self._layers[idx]

            def _hook_factory(v: SteeringVector):
                def _hook(_mod: nn.Module, _inp: tuple, out: torch.Tensor | tuple) -> torch.Tensor | tuple:
                    if isinstance(out, tuple):
                        hs = out[0]
                        delta = torch.zeros_like(hs)
                        delta = delta + v.materialize(hs)
                        return (hs + delta,) + out[1:]
                    else:
                        hs = out
                        delta = torch.zeros_like(hs)
                        delta = delta + v.materialize(hs)
                        return hs + delta
                return _hook

            handle = layer.register_forward_hook(_hook_factory(vec))
            self._hook_group.add(handle)

    def detach(self) -> None:
        """
        Remove all registered steering hooks; model returns to unsteered behavior.
        """
        self._hook_group.remove_all()

    @contextmanager
    def detached(self):
        """
        Context manager that guarantees a vanilla (unsteered) model inside the block.

        example:
            >>> with wm.detached():
            ...     txt = wm.generate([[{"role": "user", "content": "Plain run"}]], use_steering=False)[0]
        """
        self.detach()
        try:
            yield
        finally:
            self.detach()

    def _encode_one(
        self,
        message: list[ChatMessage],
        add_generation_prompt: bool = True,
    ) -> dict[str, torch.Tensor]:
        """
        Encode a single input in chat format.

        arguments:
            message:
                list of {'role': str, 'content': str} dicts (chat messages).
            add_generation_prompt:
                If True, append the model's generation prompt at the end.

        returns:
            dict with 'input_ids' and 'attention_mask' tensors.

        example:
            >>> msgs = [
            ...     {"role":"system","content":"Be concise."},
            ...     {"role":"user","content":"Two bullet points about koalas."}
            ... ]
            >>> wm._encode_one(msgs, add_generation_prompt=True)
            {"input_ids": tensor([[...]]), "attention_mask": tensor([[...]])}
        """

        ids = self.tokenizer.apply_chat_template(
            message, tokenize=True, add_generation_prompt=add_generation_prompt, return_tensors="pt"
        )[0]
        return {
            "input_ids": ids,
            "attention_mask": torch.ones_like(ids),
        }

    def _batch_encode(
        self,
        inputs: list[list[ChatMessage]],
        add_generation_prompt: bool = True,
    ) -> dict[str, torch.Tensor]:
        """
        Batch-encode a list of chat conversations.

        arguments:
            inputs:
                list of chat conversations (each a list of {'role','content'} dicts).
            add_generation_prompt:
                If True, append the model's generation prompt at the end of each.

        returns:
            dict with batched 'input_ids' and 'attention_mask' tensors.

        example:
            >>> msgs1 = [
            ...     {"role":"system","content":"Be concise."},
            ...     {"role":"user","content":"Two bullet points about koalas."}
            ... ]
            >>> msgs2 = [
            ...     {"role":"user","content":"Write a haiku about rain."}
            ... ]
            >>> wm._batch_encode([msgs1, msgs2], add_generation_prompt=True)
            {"input_ids": tensor([[...],[...]]), "attention_mask": tensor([[...],[...]])}
        """

        singles = []
        for item in inputs:
            singles.append(self._encode_one(item, add_generation_prompt=add_generation_prompt))

        batch = self.tokenizer.pad(singles, padding=True, return_tensors="pt")

        batch = {k: v.to(resolve_torch_device()) for k, v in batch.items()}

        return batch

    @torch.inference_mode()
    def generate(
        self,
        inputs: list[list[ChatMessage]],
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_p: float = 0.95,
        do_sample: bool = True,
        num_return_sequences: int = 1,
        use_steering: bool = False,
        steering_plan: SteeringPlan | None = None,
        **gen_kwargs: Any,
    ) -> list[str]:
        """
        Batched text generation with optional steering.

        arguments:
            inputs:
                list of chat conversations (each a list of {'role','content'} dicts).
            max_new_tokens:
                max tokens to generate (beyond the prompt).
            temperature:
                sampling temperature (0 = greedy, 1 = default sampling).
            top_p:
                nucleus sampling probability (1.0 = no nucleus filtering).
            do_sample:
                if False, uses greedy decoding (top_k=1).
            num_return_sequences:
                number of completions to generate per input.
            use_steering:
                if True, apply the current steering plan (if any).
            steering_plan:
                optional SteeringPlan to use for this call only (overrides internal plan).
                If None, uses the internal plan.
            **gen_kwargs:
                additional kwargs passed to 'model.generate()'.

        returns:
            list of generated strings (length = len(inputs) * num_return_sequences).

        generation flow:
            notation:
                - Let B be batch size, T_in the (padded) input length, H the hidden size.
                - Decoder has L layers; we index user-facing layers as strings "1".."L" (layer 1 is the first decoder block).
                - Steering plan maps layer names to one or more steering vectors with scales:
                  '{"6": [SteeringVector(v6, scale=0.7)], "12": [SteeringVector(v12a, 1.0), SteeringVector(v12b, 0.4)]}'

            preparation:
                Given chat messages:
                    msgs = [
                        {"role":"system","content":"Be concise."},
                        {"role":"user","content":"Two bullet points about koalas."}
                    ]

                Encoding produces:
                    - If a chat template is available, 'apply_chat_template(..., tokenize=True)' yields `input_ids` of shape '[T1]'.
                    - After 'tokenizer.pad([...])', the batch tensors have shapes:
                        - 'input_ids: [B, T_in]'
                        - 'attention_mask: [B, T_in]'
                      where 'T_in = T1' and 'B = 2' in this example.

            without steering:
                >>> wm = WisentModel("meta-llama/Meta-Llama-3-8B-Instruct", steering_layers={}, device="cuda")
                >>> out_plain = wm.generate([msgs], max_new_tokens=32, use_steering=False)
                # out_plain: list[str] length B (or B * num_return_sequences)

                >>> for i, msg in enumerate(msgs):
                ...     print(f"User {i+1}: {msg['content']}")
                ...     print(f"Assistant {i+1}: {out_plain[i]}")

            internally during generation step 't = 0..T_out-1':
                - Each decoder block 'i' outputs a residual stream tensor of shape '[B, T_in + t, H]'.
                - No modification is applied; the model returns logits → token → appended to sequence.

            with steering (add AFTER layer i):
                # Build steering vectors of shape [H] for chosen layers; scales are per-vector.
                >>> plan = SteeringPlan.from_raw({
                ...     "6": torch.randn(wm.hidden_size),   # will be normalized/broadcast if needed
                ...     "12": torch.randn(wm.hidden_size),
                ... }, scale=0.7, normalize=True)

                # Set once and use
                >>> wm.set_steering_from_raw({"6": plan.layers["6"][0].vector, "12": plan.layers["12"][0].vector},
                ...                          scale=0.7, normalize=True)

            What the hook 'sees' at a steered layer 'i' on step 't':
                - The layer's output (residual stream) 'h_i' has shape '[B, T_in + t, H]'.
                - Your steering vector 'v_i' is materialized to '[1, 1, H]' (or '[B, T, H]' if you passed that) and cast to the same dtype/device.
                - The hook returns h_i' = h_i + α_i * v_i (if multiple vectors are configured for the same layer, it sums them).
                - This addition is cheap: one broadcasted add per steered layer, per step.

            shapes recap at generation step t (same for chat or plain strings):
                - Decoder block output: '[B, T_in + t, H]'
                - Materialized steering vector: '[1, 1, H]' (broadcast to '[B, T_in + t, H]')
                - Residual after steering (per layer): '[B, T_in + t, H]'

        example (one batch):
            >>> msgs = [
            ...     {"role":"system","content":"Be concise."},
            ...     {"role":"user","content":"Two bullet points about koalas."}
            ... ]
            >>> wm.apply_steering()  # or pass use_steering=True below
            >>> out = wm.generate([msgs], max_new_tokens=32, use_steering=True)
            >>> for i, msg in enumerate(msgs):
            ...     print(f"User {i+1}: {msg['content']}")
            ...     print(f"Assistant {i+1}: {out[i]}")
        """
        if use_steering:
            self.apply_steering(steering_plan)

        batch = self._batch_encode(inputs, add_generation_prompt=True)

        gen_out = self.hf_model.generate(
            **batch,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
            num_return_sequences=num_return_sequences,
            return_dict_in_generate=True,
            output_scores=False,
            **gen_kwargs,
        )

        if use_steering:
            self.detach()

        seqs = gen_out.sequences  # [B * num_return_sequences, T_total]
        texts = self.tokenizer.batch_decode(seqs, skip_special_tokens=True)
        return texts

    @torch.inference_mode()
    def generate_with_stats(
        self,
        inputs: list[list[ChatMessage]],
        max_new_tokens: int = 64,
        temperature: float = 0.7,
        top_p: float = 0.95,
        do_sample: bool = True,
        num_return_sequences: int = 1,
        collect_topk: int = 5,
        use_steering: bool = False,
        steering_plan: SteeringPlan | None = None,
        **gen_kwargs: Any,
    ) -> tuple[list[str], list[GenerationStats]]:
        """
        Generate with efficient per-token stats (logits / probs), compatible with steering.
        Implementation detail: uses `output_scores=True` + `return_dict_in_generate=True` (HF standard).

        arguments:
            inputs:
                list of chat conversations (each a list of {'role','content'} dicts).
            max_new_tokens:
                max tokens to generate (beyond the prompt).
            temperature:
                sampling temperature (0 = greedy, 1 = default sampling).
            top_p:
                nucleus sampling probability (1.0 = no nucleus filtering).
            do_sample:
                if False, uses greedy decoding (top_k=1).
            num_return_sequences:
                number of completions to generate per input.
            collect_topk:
                if > 0, collect top-k logits/probs per step for analysis/visualization.
            use_steering:
                if True, apply the current steering plan (if any).
            steering_plan:
                optional SteeringPlan to use for this call only (overrides internal plan).
                If None, uses the internal plan.
            **gen_kwargs:
                additional kwargs passed to 'model.generate()'.

        returns:
            - list of generated strings (length = len(inputs) * num_return_sequences).
            - list of GenerationStats (length = len(inputs) * num_return_sequences).
              Each GenerationStats has:
                  tokens:
                      list of generated token ids (length = actual generated tokens).
                  per_step:
                      if collect_topk > 0, list of TopLogits (length = actual generated tokens).
                      Each TopLogits has:
                          token_id:
                              the generated token id at that step.
                          logit:
                              the raw logit for that token.
                          prob:
                              the softmax probability for that token.
                          topk_ids:
                              if collect_topk > 0, list of top-k token ids at that step.
                          topk_probs:
                              if collect_topk > 0, list of top-k probabilities at that step.

        example:
            >>> msgs = [[
            ...     {"role":"system","content":"Be concise."},
            ...     {"role":"user","content":"Two bullet points about koalas."}
            ... ]]
            >>> wm = WisentModel("meta-llama/Meta-Llama-3-8B-Instruct", steering_layers={}, device="cuda")
            >>> wm.set_steering_from_raw({"6": torch.randn(wm.hidden_size), "12": torch.randn(wm.hidden_size)}, scale=0.7, normalize=True)
            >>> texts, stats = wm.generate_with_stats(
            ...     msgs,
            ...     max_new_tokens=48, collect_topk=5, use_steering=True
            ... )
            >>> stats[0].per_step[0].prob  # probability of the first generated token
        """
        if use_steering:
            self.apply_steering(steering_plan)

        batch = self._batch_encode(inputs, add_generation_prompt=True)

        out = self.hf_model.generate(
            **batch,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
            num_return_sequences=num_return_sequences,
            return_dict_in_generate=True,
            output_scores=True,
            **gen_kwargs,
        )

        if use_steering:
            self.detach()

        texts = self.tokenizer.batch_decode(out.sequences, skip_special_tokens=True)

        scores: list[torch.Tensor] = list(out.scores or [])
        stats: list[GenerationStats] = []

        if scores:
            stacked = torch.stack(scores, dim=0)  # [steps, B*num_ret, V]
            steps = stacked.size(0)
            gen_token_ids = out.sequences[:, -steps:]  # [B*num_ret, steps]

            logprobs = torch.log_softmax(stacked.float(), dim=-1)  # [steps, B*num_ret, V]
            B = logprobs.size(1)
            V = logprobs.size(2)

            for b in range(B):
                toks = gen_token_ids[b].tolist()
                per_step: list[TopLogits] = []
                for t, tok_id in enumerate(toks):
                    lp_row = logprobs[t, b]  # [V]
                    logit = scores[t][b, tok_id].item()
                    prob = float(lp_row[tok_id].exp().item())
                    if collect_topk > 0:
                        topk_vals, topk_ids = lp_row.topk(min(collect_topk, V))
                        per_step.append(TopLogits(
                            token_id=int(tok_id),
                            logit=float(logit),
                            prob=float(prob),
                            topk_ids=topk_ids.tolist(),
                            topk_probs=topk_vals.exp().tolist(),
                        ))
                    else:
                        per_step.append(TopLogits(
                            token_id=int(tok_id),
                            logit=float(logit),
                            prob=float(prob),
                        ))
                stats.append(GenerationStats(tokens=toks, per_step=per_step))
        else:
            for _ in range(out.sequences.size(0)):
                stats.append(GenerationStats(tokens=[], per_step=None))

        return texts, stats
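
    # Illustrative sketch (not prescribed by the package): each TopLogits entry above
    # carries `prob`, so an average negative log-likelihood for one completion can be
    # derived directly from a GenerationStats record, e.g.:
    #
    #   import math
    #   steps = stats[0].per_step or []
    #   nll = -sum(math.log(s.prob) for s in steps) / max(len(steps), 1)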

    @torch.inference_mode()
    def generate_stream(
        self,
        inputs: list[list[ChatMessage]],
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_p: float = 0.95,
        do_sample: bool = True,
        use_steering: bool = False,
        steering_plan: SteeringPlan | None = None,
        skip_prompt: bool = True,
        skip_special_tokens: bool = True,
        **gen_kwargs: Any,
    ) -> Iterable[str]:
        """
        Streamed text generation with optional steering.
        Uses the TextIteratorStreamer from transformers.

        arguments:
            inputs:
                list of chat conversations (each a list of {'role','content'} dicts). Currently only one conversation is supported.
            max_new_tokens:
                max tokens to generate (beyond the prompt).
            temperature:
                sampling temperature (0 = greedy, 1 = default sampling).
            top_p:
                nucleus sampling probability (1.0 = no nucleus filtering).
            do_sample:
                if False, uses greedy decoding (top_k=1).
            use_steering:
                if True, apply the current steering plan (if any).
            steering_plan:
                optional SteeringPlan to use for this call only (overrides internal plan).
                If None, uses the internal plan.
            skip_prompt:
                if True, the yielded text excludes the input prompt.
            skip_special_tokens:
                if True, special tokens are removed from the yielded text.
            **gen_kwargs:
                additional kwargs passed to 'model.generate()'.

        yields:
            generated text chunks (str), as they become available.
        """

        if len(inputs) != 1:
            raise ValueError(
                f"generate_stream currently supports exactly one conversation at a time (got {len(inputs)})."
            )
        if use_steering:
            self.apply_steering(steering_plan)

        batch = self._batch_encode(inputs, add_generation_prompt=True)

        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=skip_prompt,
            skip_special_tokens=skip_special_tokens,
        )

        generation_kwargs = dict(
            batch,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
            return_dict_in_generate=False,
            output_scores=False,
            streamer=streamer,
            **gen_kwargs,
        )

        worker = threading.Thread(
            target=self.hf_model.generate,
            kwargs=generation_kwargs,
            daemon=True,
        )
        worker.start()

        try:
            for new_text in streamer:
                if new_text:
                    yield new_text
        finally:
            if use_steering:
                self.detach()
            worker.join(timeout=0.0)
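
    # Illustrative sketch (assumes `wm` and `msgs` as in the docstring examples above):
    # generate_stream() accepts exactly one conversation and yields text chunks, so a
    # caller can print them incrementally:
    #
    #   for chunk in wm.generate_stream([msgs], max_new_tokens=64, use_steering=True):
    #       print(chunk, end="", flush=True)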

    def set_steering_from_raw(self, raw: list[RawActivationMap] | RawActivationMap | None, layers_description: list[str] | None = None, steering_weights: list[float] | None = None, scale: float = 1.0, normalize: bool = False) -> None:
        """
        Replace the internal steering plan using a RawActivationMap (layer_name -> tensor).
        If raw is None or empty, clears any existing steering.
        """
        if not raw:
            self._steering_plan = SteeringPlan()
            return

        # TODO: this should be outside
        reports = run_control_steering_diagnostics(raw)
        for report in reports:
            for issue in report.issues:
                log_method = logger.error if issue.severity == "critical" else logger.warning
                log_method(
                    "[control_vector diagnostics] %s (details=%s)",
                    issue.message,
                    issue.details,
                )

        if any(report.has_critical_issues for report in reports):
            raise ValueError("Control vector diagnostics found critical issues; refusing to set steering.")

        self._steering_plan = SteeringPlan.from_raw(
            raw=raw,
            layers_description=layers_description,
            weights=steering_weights,
            scale=scale,
            normalize=normalize,
            expected_hidden_size=self._hidden_size
        )

    def clear_steering(self) -> None:
        """
        Remove any existing steering configuration and active hooks.
        After calling this, generation is vanilla.
        """
        self._steering_plan = SteeringPlan()
        self.detach()