wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of wisent might be problematic.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/inference/inferencer.py
DELETED
@@ -1,250 +0,0 @@
-"""
-Functionality for local inference with control vectors.
-"""
-
-import logging
-from typing import Dict, List, Optional, Union
-
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
-
-from wisent.control_vector.models import ControlVector
-from wisent.inference.models import ControlVectorInferenceConfig, InferenceConfig, InferenceResponse
-
-logger = logging.getLogger(__name__)
-
-
-class ControlVectorHook:
-    """
-    Hook for applying control vectors during inference.
-
-    Args:
-        control_vector: Control vector to apply
-        config: Configuration for applying the control vector
-    """
-
-    def __init__(
-        self,
-        control_vector: ControlVector,
-        config: ControlVectorInferenceConfig,
-    ):
-        self.control_vector = control_vector
-        self.config = config
-        self.device = None
-        self.vector_tensor = None
-        self.hooks = []
-
-    def register(self, model):
-        """
-        Register hooks on the model.
-
-        Args:
-            model: The model to register hooks on
-        """
-        self.device = next(model.parameters()).device
-        self.vector_tensor = self.control_vector.to_tensor(self.device)
-
-        # Get transformer layers
-        if hasattr(model, "transformer"):
-            transformer_layers = model.transformer.h
-        elif hasattr(model, "model") and hasattr(model.model, "layers"):
-            transformer_layers = model.model.layers
-        else:
-            raise ValueError(f"Unsupported model architecture: {model.__class__.__name__}")
-
-        # Determine which layers to apply the control vector to
-        num_layers = len(transformer_layers)
-        layers = self.config.layers or [num_layers - 1]  # Default to last layer
-
-        # Resolve negative indices
-        resolved_layers = []
-        for layer in layers:
-            if layer < 0:
-                resolved_layer = num_layers + layer
-            else:
-                resolved_layer = layer
-
-            if 0 <= resolved_layer < num_layers:
-                resolved_layers.append(resolved_layer)
-
-        # Register hooks
-        for layer_idx in resolved_layers:
-            layer = transformer_layers[layer_idx]
-
-            # Define hook function
-            def hook_fn(module, input, output, layer_idx=layer_idx):
-                if isinstance(output, tuple):
-                    hidden_states = output[0]
-                else:
-                    hidden_states = output
-
-                # Apply the control vector
-                if self.config.method == "caa":  # Context-Aware Addition
-                    # Add the control vector to the hidden states
-                    modified = hidden_states + self.vector_tensor * self.config.scale
-
-                    if isinstance(output, tuple):
-                        return (modified,) + output[1:]
-                    else:
-                        return modified
-                else:
-                    logger.warning(f"Unsupported method: {self.config.method}, using original output")
-                    return output
-
-            # Register hook
-            if hasattr(layer, "output"):
-                handle = layer.output.register_forward_hook(
-                    lambda module, input, output, layer_idx=layer_idx: hook_fn(module, input, output, layer_idx)
-                )
-            else:
-                handle = layer.register_forward_hook(
-                    lambda module, input, output, layer_idx=layer_idx: hook_fn(module, input, output, layer_idx)
-                )
-
-            self.hooks.append(handle)
-
-    def remove(self):
-        """Remove all registered hooks."""
-        for hook in self.hooks:
-            hook.remove()
-        self.hooks = []
-
-
-class Inferencer:
-    """
-    Performs local inference with control vectors.
-
-    Args:
-        model_name: Name of the model
-        device: Device to use for inference
-    """
-
-    def __init__(
-        self,
-        model_name: str,
-        device: Optional[str] = None,
-    ):
-        self.model_name = model_name
-        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        self.model = None
-        self.tokenizer = None
-
-        logger.info(f"Initializing Inferencer for model {model_name} on {self.device}")
-
-    def _load_model(self):
-        """Load the model and tokenizer."""
-        if self.model is None:
-            logger.info(f"Loading model {self.model_name}")
-            self.model = AutoModelForCausalLM.from_pretrained(
-                self.model_name,
-                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
-                device_map=self.device
-            )
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            logger.info(f"Model loaded successfully")
-
-    def generate(
-        self,
-        prompt: str,
-        control_vector: Optional[ControlVector] = None,
-        method: str = "caa",
-        scale: float = 1.0,
-        layers: Optional[List[int]] = None,
-        config: Optional[InferenceConfig] = None,
-    ) -> InferenceResponse:
-        """
-        Generate text using the model, optionally with a control vector.
-
-        Args:
-            prompt: Input prompt
-            control_vector: Control vector to apply (optional)
-            method: Method for applying the control vector
-            scale: Scaling factor for the control vector
-            layers: Layers to apply the control vector to
-            config: Inference configuration
-
-        Returns:
-            Inference response
-        """
-        try:
-            self._load_model()
-
-            config = config or InferenceConfig()
-            hook = None
-
-            # Register control vector hook if provided
-            if control_vector is not None:
-                cv_config = ControlVectorInferenceConfig(
-                    method=method,
-                    scale=scale,
-                    layers=layers,
-                )
-                hook = ControlVectorHook(control_vector, cv_config)
-                hook.register(self.model)
-
-            # Tokenize input
-            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
-            prompt_length = inputs.input_ids.shape[1]
-
-            # Configure generation
-            generation_config = GenerationConfig(
-                max_new_tokens=config.max_tokens,
-                temperature=config.temperature,
-                top_p=config.top_p,
-                top_k=config.top_k,
-                repetition_penalty=config.repetition_penalty,
-                do_sample=config.temperature > 0,
-                pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
-            )
-
-            # Generate
-            with torch.no_grad():
-                output_ids = self.model.generate(
-                    inputs.input_ids,
-                    attention_mask=inputs.attention_mask,
-                    generation_config=generation_config,
-                )
-
-            # Remove control vector hook if registered
-            if hook is not None:
-                hook.remove()
-
-            # Decode output
-            generated_text = self.tokenizer.decode(
-                output_ids[0][prompt_length:],
-                skip_special_tokens=True
-            )
-
-            # Create response
-            return InferenceResponse(
-                text=generated_text,
-                model=self.model_name,
-                prompt=prompt,
-                finish_reason="length",  # Simplified
-                usage={
-                    "prompt_tokens": prompt_length,
-                    "completion_tokens": output_ids.shape[1] - prompt_length,
-                    "total_tokens": output_ids.shape[1],
-                },
-                metadata={
-                    "control_vector": control_vector.name if control_vector else None,
-                    "method": method if control_vector else None,
-                    "scale": scale if control_vector else None,
-                }
-            )
-
-        except Exception as e:
-            logger.error(f"Error during inference: {str(e)}")
-            if hook is not None:
-                hook.remove()
-            raise
-
-    def __del__(self):
-        """Clean up resources."""
-        # Free GPU memory
-        if self.model is not None and hasattr(self.model, "to"):
-            self.model = self.model.to("cpu")
-
-        # Clear CUDA cache
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
wisent/inference/models.py
DELETED
@@ -1,66 +0,0 @@
-"""
-Data models for inference.
-"""
-
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Union
-
-from pydantic import BaseModel, Field
-
-
-class InferenceConfig(BaseModel):
-    """
-    Configuration for model inference.
-
-    Attributes:
-        max_tokens: Maximum number of tokens to generate
-        temperature: Sampling temperature
-        top_p: Top-p sampling parameter
-        top_k: Top-k sampling parameter
-        repetition_penalty: Repetition penalty
-        stop_sequences: Sequences that stop generation
-    """
-
-    max_tokens: int = 256
-    temperature: float = 0.7
-    top_p: float = 0.9
-    top_k: int = 50
-    repetition_penalty: float = 1.0
-    stop_sequences: Optional[List[str]] = None
-
-
-class InferenceResponse(BaseModel):
-    """
-    Response from model inference.
-
-    Attributes:
-        text: Generated text
-        model: Model used for generation
-        prompt: Input prompt
-        finish_reason: Reason generation stopped
-        usage: Token usage information
-        metadata: Additional metadata
-    """
-
-    text: str
-    model: str
-    prompt: str
-    finish_reason: str = "length"
-    usage: Dict[str, int] = Field(default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0})
-    metadata: Dict = Field(default_factory=dict)
-
-
-@dataclass
-class ControlVectorInferenceConfig:
-    """
-    Configuration for inference with control vectors.
-
-    Attributes:
-        method: Method for applying control vectors
-        scale: Scaling factor for control vectors
-        layers: Layers to apply control vectors to
-    """
-
-    method: str = "caa"  # Context-Aware Addition
-    scale: float = 1.0
-    layers: Optional[List[int]] = None
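These removed data models defined the request/response schema used by the old inferencer: InferenceConfig (pydantic) held the sampling parameters, InferenceResponse (pydantic) carried the generated text plus token usage, and ControlVectorInferenceConfig (a dataclass) described how a control vector was applied. A small construction sketch, using only the fields visible above:

from wisent.inference.models import InferenceConfig, ControlVectorInferenceConfig

# Defaults in the removed module were 256 tokens, temperature 0.7, top_p 0.9, top_k 50.
gen_cfg = InferenceConfig(max_tokens=128, temperature=0.2)
# Steering config: "caa" adds the vector to hidden states; layers=None meant "last layer only".
cv_cfg = ControlVectorInferenceConfig(method="caa", scale=0.8, layers=[-1])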
wisent/utils/__init__.py
DELETED
wisent/utils/auth.py
DELETED
@@ -1,30 +0,0 @@
-"""
-Authentication utilities for the Wisent API.
-"""
-
-from typing import Dict
-
-
-class AuthManager:
-    """
-    Manages authentication for Wisent API requests.
-
-    Args:
-        api_key: The Wisent API key
-    """
-
-    def __init__(self, api_key: str):
-        self.api_key = api_key
-
-    def get_headers(self) -> Dict[str, str]:
-        """
-        Get the authentication headers for API requests.
-
-        Returns:
-            Dict containing the authentication headers
-        """
-        return {
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json",
-            "Accept": "application/json",
-        }
wisent/utils/http.py
DELETED
@@ -1,228 +0,0 @@
-"""
-HTTP request utilities for the Wisent API.
-"""
-
-import json
-from typing import Any, Dict, Optional, Union
-
-import aiohttp
-import requests
-from requests.exceptions import RequestException
-
-
-class APIError(Exception):
-    """Exception raised for API errors."""
-
-    def __init__(self, message: str, status_code: Optional[int] = None, response: Optional[Dict[str, Any]] = None):
-        self.message = message
-        self.status_code = status_code
-        self.response = response
-        super().__init__(self.message)
-
-
-class HTTPClient:
-    """
-    HTTP client for making requests to the Wisent API.
-
-    Args:
-        base_url: The base URL for the API
-        headers: Headers to include in all requests
-        timeout: Request timeout in seconds
-    """
-
-    def __init__(self, base_url: str, headers: Dict[str, str], timeout: int = 60):
-        self.base_url = base_url.rstrip("/")
-        self.headers = headers
-        self.timeout = timeout
-
-    def _build_url(self, endpoint: str) -> str:
-        """Build the full URL for an API endpoint."""
-        endpoint = endpoint.lstrip("/")
-        return f"{self.base_url}/{endpoint}"
-
-    def get(self, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
-        """
-        Make a GET request to the API.
-
-        Args:
-            endpoint: API endpoint
-            params: Query parameters
-
-        Returns:
-            Response data as a dictionary
-
-        Raises:
-            APIError: If the request fails
-        """
-        url = self._build_url(endpoint)
-        try:
-            response = requests.get(
-                url,
-                headers=self.headers,
-                params=params,
-                timeout=self.timeout
-            )
-            response.raise_for_status()
-            return response.json()
-        except RequestException as e:
-            status_code = getattr(e.response, "status_code", None) if hasattr(e, "response") else None
-            response_data = None
-
-            if hasattr(e, "response") and e.response is not None:
-                try:
-                    response_data = e.response.json()
-                except (ValueError, AttributeError):
-                    response_data = {"error": str(e)}
-
-            raise APIError(
-                f"GET request to {url} failed: {str(e)}",
-                status_code=status_code,
-                response=response_data
-            ) from e
-
-    def post(self, endpoint: str, data: Optional[Dict[str, Any]] = None, json_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
-        """
-        Make a POST request to the API.
-
-        Args:
-            endpoint: API endpoint
-            data: Form data
-            json_data: JSON data
-
-        Returns:
-            Response data as a dictionary
-
-        Raises:
-            APIError: If the request fails
-        """
-        url = self._build_url(endpoint)
-        try:
-            response = requests.post(
-                url,
-                headers=self.headers,
-                data=data,
-                json=json_data,
-                timeout=self.timeout
-            )
-            response.raise_for_status()
-            return response.json()
-        except RequestException as e:
-            status_code = getattr(e.response, "status_code", None) if hasattr(e, "response") else None
-            response_data = None
-
-            if hasattr(e, "response") and e.response is not None:
-                try:
-                    response_data = e.response.json()
-                except (ValueError, AttributeError):
-                    response_data = {"error": str(e)}
-
-            raise APIError(
-                f"POST request to {url} failed: {str(e)}",
-                status_code=status_code,
-                response=response_data
-            ) from e
-
-
-class AsyncHTTPClient:
-    """
-    Asynchronous HTTP client for making requests to the Wisent API.
-
-    Args:
-        base_url: The base URL for the API
-        headers: Headers to include in all requests
-        timeout: Request timeout in seconds
-    """
-
-    def __init__(self, base_url: str, headers: Dict[str, str], timeout: int = 60):
-        self.base_url = base_url.rstrip("/")
-        self.headers = headers
-        self.timeout = timeout
-
-    def _build_url(self, endpoint: str) -> str:
-        """Build the full URL for an API endpoint."""
-        endpoint = endpoint.lstrip("/")
-        return f"{self.base_url}/{endpoint}"
-
-    async def get(self, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
-        """
-        Make an asynchronous GET request to the API.
-
-        Args:
-            endpoint: API endpoint
-            params: Query parameters
-
-        Returns:
-            Response data as a dictionary
-
-        Raises:
-            APIError: If the request fails
-        """
-        url = self._build_url(endpoint)
-        try:
-            async with aiohttp.ClientSession() as session:
-                async with session.get(
-                    url,
-                    headers=self.headers,
-                    params=params,
-                    timeout=self.timeout
-                ) as response:
-                    response.raise_for_status()
-                    return await response.json()
-        except aiohttp.ClientError as e:
-            status_code = getattr(response, "status", None) if 'response' in locals() else None
-            response_data = None
-
-            if 'response' in locals():
-                try:
-                    response_data = await response.json()
-                except (ValueError, AttributeError):
-                    response_data = {"error": str(e)}
-
-            raise APIError(
-                f"Async GET request to {url} failed: {str(e)}",
-                status_code=status_code,
-                response=response_data
-            ) from e
-
-    async def post(self, endpoint: str, data: Optional[Dict[str, Any]] = None, json_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
-        """
-        Make an asynchronous POST request to the API.
-
-        Args:
-            endpoint: API endpoint
-            data: Form data
-            json_data: JSON data
-
-        Returns:
-            Response data as a dictionary
-
-        Raises:
-            APIError: If the request fails
-        """
-        url = self._build_url(endpoint)
-        try:
-            async with aiohttp.ClientSession() as session:
-                async with session.post(
-                    url,
-                    headers=self.headers,
-                    data=data,
-                    json=json_data,
-                    timeout=self.timeout
-                ) as response:
-                    response.raise_for_status()
-                    return await response.json()
-        except aiohttp.ClientError as e:
-            status_code = getattr(response, "status", None) if 'response' in locals() else None
-            response_data = None
-
-            if 'response' in locals():
-                try:
-                    response_data = await response.json()
-                except (ValueError, AttributeError):
-                    response_data = {"error": str(e)}
-
-            raise APIError(
-                f"Async POST request to {url} failed: {str(e)}",
-                status_code=status_code,
-                response=response_data
-            ) from e
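Together with AuthManager above, this removed HTTP layer made up the 0.1.x API-client plumbing: AuthManager.get_headers() produced bearer-token headers that HTTPClient / AsyncHTTPClient attached to every requests or aiohttp call, wrapping failures in APIError. A sketch of how the two pieces were combined, using only names defined in the removed modules; the API key, base URL, endpoint, and payload are placeholders, since the real endpoints are not shown in this diff.

from wisent.utils.auth import AuthManager
from wisent.utils.http import HTTPClient, APIError

auth = AuthManager(api_key="sk-...")                     # placeholder key
client = HTTPClient(
    base_url="https://api.example.com",                  # placeholder URL
    headers=auth.get_headers(),
    timeout=30,
)
try:
    data = client.post("inference", json_data={"prompt": "hello"})  # hypothetical endpoint/payload
except APIError as err:
    print(err.status_code, err.response)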
wisent/version.py
DELETED