wisent 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of wisent has been flagged as potentially problematic; see the registry advisory for details.
- wisent/__init__.py +8 -0
- wisent/activations/__init__.py +9 -0
- wisent/activations/client.py +97 -0
- wisent/activations/extractor.py +251 -0
- wisent/activations/models.py +95 -0
- wisent/client.py +45 -0
- wisent/control_vector/__init__.py +9 -0
- wisent/control_vector/client.py +85 -0
- wisent/control_vector/manager.py +168 -0
- wisent/control_vector/models.py +70 -0
- wisent/inference/__init__.py +9 -0
- wisent/inference/client.py +103 -0
- wisent/inference/inferencer.py +250 -0
- wisent/inference/models.py +66 -0
- wisent/utils/__init__.py +3 -0
- wisent/utils/auth.py +30 -0
- wisent/utils/http.py +228 -0
- wisent/version.py +3 -0
- wisent-0.1.1.dist-info/LICENSE +21 -0
- wisent-0.1.1.dist-info/METADATA +142 -0
- wisent-0.1.1.dist-info/RECORD +23 -0
- wisent-0.1.1.dist-info/WHEEL +5 -0
- wisent-0.1.1.dist-info/top_level.txt +1 -0
wisent/control_vector/manager.py ADDED

@@ -0,0 +1,168 @@
+"""
+Manager for working with control vectors.
+"""
+
+import logging
+from typing import Dict, List, Optional, Union
+
+import torch
+
+from wisent.control_vector.models import ControlVector, ControlVectorConfig
+from wisent.utils.auth import AuthManager
+from wisent.utils.http import HTTPClient
+
+logger = logging.getLogger(__name__)
+
+
+class ControlVectorManager:
+    """
+    Manager for working with control vectors.
+
+    Args:
+        api_key: Wisent API key
+        base_url: Base URL for the API
+        timeout: Request timeout in seconds
+    """
+
+    def __init__(
+        self,
+        api_key: str,
+        base_url: str = "https://api.wisent.ai",
+        timeout: int = 60,
+    ):
+        self.auth = AuthManager(api_key)
+        self.http_client = HTTPClient(base_url, self.auth.get_headers(), timeout)
+        self.cache = {}  # Simple in-memory cache
+
+    def get(self, name: str, model: str) -> ControlVector:
+        """
+        Get a control vector from the Wisent backend.
+
+        Args:
+            name: Name of the control vector
+            model: Model name
+
+        Returns:
+            Control vector
+        """
+        cache_key = f"{name}:{model}"
+        if cache_key in self.cache:
+            logger.info(f"Using cached control vector: {name} for model {model}")
+            return self.cache[cache_key]
+
+        logger.info(f"Fetching control vector: {name} for model {model}")
+        data = self.http_client.get(f"/control_vectors/{name}", params={"model": model})
+        vector = ControlVector(**data)
+
+        # Cache the result
+        self.cache[cache_key] = vector
+
+        return vector
+
+    def list(
+        self,
+        model: Optional[str] = None,
+        limit: int = 100,
+        offset: int = 0,
+    ) -> List[Dict]:
+        """
+        List available control vectors from the Wisent backend.
+
+        Args:
+            model: Filter by model name
+            limit: Maximum number of results
+            offset: Offset for pagination
+
+        Returns:
+            List of control vector metadata
+        """
+        params = {"limit": limit, "offset": offset}
+        if model:
+            params["model"] = model
+
+        return self.http_client.get("/control_vectors", params=params)
+
+    def combine(
+        self,
+        vectors: Dict[str, float],
+        model: str,
+    ) -> ControlVector:
+        """
+        Combine multiple control vectors with weights.
+
+        Args:
+            vectors: Dictionary mapping vector names to weights
+            model: Model name
+
+        Returns:
+            Combined control vector
+        """
+        # Check if we can combine locally
+        can_combine_locally = True
+        local_vectors = {}
+
+        for name in vectors.keys():
+            cache_key = f"{name}:{model}"
+            if cache_key not in self.cache:
+                can_combine_locally = False
+                break
+            local_vectors[name] = self.cache[cache_key]
+
+        if can_combine_locally:
+            logger.info(f"Combining vectors locally for model {model}")
+            return self._combine_locally(local_vectors, vectors, model)
+
+        # Otherwise, use the API
+        logger.info(f"Combining vectors via API for model {model}")
+        data = self.http_client.post(
+            "/control_vectors/combine",
+            json_data={
+                "vectors": vectors,
+                "model": model,
+            }
+        )
+        return ControlVector(**data)
+
+    def _combine_locally(
+        self,
+        vectors: Dict[str, ControlVector],
+        weights: Dict[str, float],
+        model: str,
+    ) -> ControlVector:
+        """
+        Combine vectors locally.
+
+        Args:
+            vectors: Dictionary mapping vector names to ControlVector objects
+            weights: Dictionary mapping vector names to weights
+            model: Model name
+
+        Returns:
+            Combined control vector
+        """
+        # Convert all vectors to tensors
+        tensor_vectors = {}
+        for name, vector in vectors.items():
+            tensor_vectors[name] = vector.to_tensor()
+
+        # Get the shape from the first vector
+        first_vector = next(iter(tensor_vectors.values()))
+        combined = torch.zeros_like(first_vector)
+
+        # Combine vectors with weights
+        for name, weight in weights.items():
+            if name in tensor_vectors:
+                combined += tensor_vectors[name] * weight
+
+        # Create a new control vector
+        vector_names = list(weights.keys())
+        combined_name = f"combined_{'_'.join(vector_names)}"
+
+        return ControlVector(
+            name=combined_name,
+            model_name=model,
+            values=combined,
+            metadata={
+                "combined_from": {name: weight for name, weight in weights.items()},
+            }
+        )
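A minimal usage sketch for ControlVectorManager, assuming a valid API key and a reachable backend; the vector names and model ID below are placeholders:

    from wisent.control_vector.manager import ControlVectorManager

    manager = ControlVectorManager(api_key="YOUR_API_KEY")

    # First fetch hits GET /control_vectors/{name} and populates the cache;
    # repeat calls for the same (name, model) pair are served from memory.
    vector = manager.get(name="helpfulness", model="llama-3-8b")

    # combine() runs locally only when every named vector is already cached;
    # otherwise it defers to POST /control_vectors/combine on the backend.
    blended = manager.combine({"helpfulness": 0.8, "honesty": 0.4}, model="llama-3-8b")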
wisent/control_vector/models.py ADDED

@@ -0,0 +1,70 @@
+"""
+Data models for control vectors.
+"""
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from pydantic import BaseModel, Field
+
+
+class ControlVector(BaseModel):
+    """
+    Represents a control vector for steering model outputs.
+
+    Attributes:
+        name: Name of the control vector
+        model_name: Name of the model the vector is for
+        values: Vector values
+        metadata: Additional metadata
+    """
+
+    name: str
+    model_name: str
+    values: Union[List[float], np.ndarray, torch.Tensor]
+    metadata: Optional[Dict] = Field(default_factory=dict)
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def to_dict(self) -> Dict:
+        """Convert to dictionary for API requests."""
+        values = self.values
+        if isinstance(values, torch.Tensor):
+            values = values.detach().cpu().numpy()
+        if isinstance(values, np.ndarray):
+            values = values.tolist()
+
+        return {
+            "name": self.name,
+            "model_name": self.model_name,
+            "values": values,
+            "metadata": self.metadata or {},
+        }
+
+    def to_tensor(self, device: str = "cpu") -> torch.Tensor:
+        """Convert values to a PyTorch tensor."""
+        if isinstance(self.values, torch.Tensor):
+            return self.values.to(device)
+        elif isinstance(self.values, np.ndarray):
+            return torch.tensor(self.values, device=device)
+        else:
+            return torch.tensor(self.values, device=device)
+
+
+@dataclass
+class ControlVectorConfig:
+    """
+    Configuration for control vector application.
+
+    Attributes:
+        scale: Scaling factor for the control vector
+        method: Method for applying the control vector
+        layers: Layers to apply the control vector to
+    """
+
+    scale: float = 1.0
+    method: str = "caa"  # Context-Aware Addition
+    layers: Optional[List[int]] = None
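A short sketch of the ControlVector model itself; the three-element vector is a toy value chosen only to show the list-to-tensor round trip (real vectors match the target model's hidden size), and the names are placeholders:

    import torch
    from wisent.control_vector.models import ControlVector

    cv = ControlVector(
        name="calmness",             # placeholder name
        model_name="llama-3-8b",     # placeholder model ID
        values=[0.1, -0.2, 0.3],
    )

    t = cv.to_tensor()               # CPU tensor by default
    assert torch.allclose(t, torch.tensor([0.1, -0.2, 0.3]))

    payload = cv.to_dict()           # tensors/arrays serialize back to plain lists
    assert payload["values"] == [0.1, -0.2, 0.3]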
wisent/inference/__init__.py ADDED

@@ -0,0 +1,9 @@
+"""
+Functionality for model inference with control vectors.
+"""
+
+from wisent.inference.client import InferenceClient
+from wisent.inference.inferencer import Inferencer
+from wisent.inference.models import InferenceConfig, InferenceResponse
+
+__all__ = ["InferenceClient", "Inferencer", "InferenceConfig", "InferenceResponse"]
wisent/inference/client.py ADDED

@@ -0,0 +1,103 @@
+"""
+Client for interacting with the inference API.
+"""
+
+from typing import Dict, List, Optional, Union
+
+from wisent.inference.models import InferenceConfig, InferenceResponse
+from wisent.utils.auth import AuthManager
+from wisent.utils.http import HTTPClient
+
+
+class InferenceClient:
+    """
+    Client for interacting with the inference API.
+
+    Args:
+        auth_manager: Authentication manager
+        base_url: Base URL for the API
+        timeout: Request timeout in seconds
+    """
+
+    def __init__(self, auth_manager: AuthManager, base_url: str, timeout: int = 60):
+        self.auth_manager = auth_manager
+        self.http_client = HTTPClient(base_url, auth_manager.get_headers(), timeout)
+
+    def generate(
+        self,
+        model_name: str,
+        prompt: str,
+        config: Optional[InferenceConfig] = None,
+    ) -> InferenceResponse:
+        """
+        Generate text using a model.
+
+        Args:
+            model_name: Name of the model
+            prompt: Input prompt
+            config: Inference configuration
+
+        Returns:
+            Inference response
+        """
+        config = config or InferenceConfig()
+
+        data = self.http_client.post(
+            "/inference/generate",
+            json_data={
+                "model": model_name,
+                "prompt": prompt,
+                "max_tokens": config.max_tokens,
+                "temperature": config.temperature,
+                "top_p": config.top_p,
+                "top_k": config.top_k,
+                "repetition_penalty": config.repetition_penalty,
+                "stop_sequences": config.stop_sequences,
+            }
+        )
+
+        return InferenceResponse(**data)
+
+    def generate_with_control(
+        self,
+        model_name: str,
+        prompt: str,
+        control_vectors: Dict[str, float],
+        method: str = "caa",
+        scale: float = 1.0,
+        config: Optional[InferenceConfig] = None,
+    ) -> InferenceResponse:
+        """
+        Generate text using a model with control vectors.
+
+        Args:
+            model_name: Name of the model
+            prompt: Input prompt
+            control_vectors: Dictionary mapping vector names to weights
+            method: Method for applying control vectors
+            scale: Scaling factor for control vectors
+            config: Inference configuration
+
+        Returns:
+            Inference response
+        """
+        config = config or InferenceConfig()
+
+        data = self.http_client.post(
+            "/inference/generate_with_control",
+            json_data={
+                "model": model_name,
+                "prompt": prompt,
+                "control_vectors": control_vectors,
+                "method": method,
+                "scale": scale,
+                "max_tokens": config.max_tokens,
+                "temperature": config.temperature,
+                "top_p": config.top_p,
+                "top_k": config.top_k,
+                "repetition_penalty": config.repetition_penalty,
+                "stop_sequences": config.stop_sequences,
+            }
+        )
+
+        return InferenceResponse(**data)
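A sketch of the remote path, assuming the hosted API at api.wisent.ai is reachable and the named vector exists server-side; note that InferenceClient takes an AuthManager plus base URL rather than a raw key, and the model and vector names are placeholders:

    from wisent.inference.client import InferenceClient
    from wisent.inference.models import InferenceConfig
    from wisent.utils.auth import AuthManager

    client = InferenceClient(AuthManager("YOUR_API_KEY"), base_url="https://api.wisent.ai")

    response = client.generate_with_control(
        model_name="llama-3-8b",               # placeholder model ID
        prompt="Explain control vectors in one sentence.",
        control_vectors={"conciseness": 1.0},  # placeholder vector name
        scale=0.8,
        config=InferenceConfig(max_tokens=64, temperature=0.2),
    )
    print(response.text)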
wisent/inference/inferencer.py ADDED

@@ -0,0 +1,250 @@
+"""
+Functionality for local inference with control vectors.
+"""
+
+import logging
+from typing import Dict, List, Optional, Union
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+
+from wisent.control_vector.models import ControlVector
+from wisent.inference.models import ControlVectorInferenceConfig, InferenceConfig, InferenceResponse
+
+logger = logging.getLogger(__name__)
+
+
+class ControlVectorHook:
+    """
+    Hook for applying control vectors during inference.
+
+    Args:
+        control_vector: Control vector to apply
+        config: Configuration for applying the control vector
+    """
+
+    def __init__(
+        self,
+        control_vector: ControlVector,
+        config: ControlVectorInferenceConfig,
+    ):
+        self.control_vector = control_vector
+        self.config = config
+        self.device = None
+        self.vector_tensor = None
+        self.hooks = []
+
+    def register(self, model):
+        """
+        Register hooks on the model.
+
+        Args:
+            model: The model to register hooks on
+        """
+        self.device = next(model.parameters()).device
+        self.vector_tensor = self.control_vector.to_tensor(self.device)
+
+        # Get transformer layers
+        if hasattr(model, "transformer"):
+            transformer_layers = model.transformer.h
+        elif hasattr(model, "model") and hasattr(model.model, "layers"):
+            transformer_layers = model.model.layers
+        else:
+            raise ValueError(f"Unsupported model architecture: {model.__class__.__name__}")
+
+        # Determine which layers to apply the control vector to
+        num_layers = len(transformer_layers)
+        layers = self.config.layers or [num_layers - 1]  # Default to last layer
+
+        # Resolve negative indices
+        resolved_layers = []
+        for layer in layers:
+            if layer < 0:
+                resolved_layer = num_layers + layer
+            else:
+                resolved_layer = layer
+
+            if 0 <= resolved_layer < num_layers:
+                resolved_layers.append(resolved_layer)
+
+        # Register hooks
+        for layer_idx in resolved_layers:
+            layer = transformer_layers[layer_idx]
+
+            # Define hook function
+            def hook_fn(module, input, output, layer_idx=layer_idx):
+                if isinstance(output, tuple):
+                    hidden_states = output[0]
+                else:
+                    hidden_states = output
+
+                # Apply the control vector
+                if self.config.method == "caa":  # Context-Aware Addition
+                    # Add the control vector to the hidden states
+                    modified = hidden_states + self.vector_tensor * self.config.scale
+
+                    if isinstance(output, tuple):
+                        return (modified,) + output[1:]
+                    else:
+                        return modified
+                else:
+                    logger.warning(f"Unsupported method: {self.config.method}, using original output")
+                    return output
+
+            # Register hook
+            if hasattr(layer, "output"):
+                handle = layer.output.register_forward_hook(
+                    lambda module, input, output, layer_idx=layer_idx: hook_fn(module, input, output, layer_idx)
+                )
+            else:
+                handle = layer.register_forward_hook(
+                    lambda module, input, output, layer_idx=layer_idx: hook_fn(module, input, output, layer_idx)
+                )
+
+            self.hooks.append(handle)
+
+    def remove(self):
+        """Remove all registered hooks."""
+        for hook in self.hooks:
+            hook.remove()
+        self.hooks = []
+
+
+class Inferencer:
+    """
+    Performs local inference with control vectors.
+
+    Args:
+        model_name: Name of the model
+        device: Device to use for inference
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        device: Optional[str] = None,
+    ):
+        self.model_name = model_name
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.model = None
+        self.tokenizer = None
+
+        logger.info(f"Initializing Inferencer for model {model_name} on {self.device}")
+
+    def _load_model(self):
+        """Load the model and tokenizer."""
+        if self.model is None:
+            logger.info(f"Loading model {self.model_name}")
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.model_name,
+                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+                device_map=self.device
+            )
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+            logger.info("Model loaded successfully")
+
+    def generate(
+        self,
+        prompt: str,
+        control_vector: Optional[ControlVector] = None,
+        method: str = "caa",
+        scale: float = 1.0,
+        layers: Optional[List[int]] = None,
+        config: Optional[InferenceConfig] = None,
+    ) -> InferenceResponse:
+        """
+        Generate text using the model, optionally with a control vector.
+
+        Args:
+            prompt: Input prompt
+            control_vector: Control vector to apply (optional)
+            method: Method for applying the control vector
+            scale: Scaling factor for the control vector
+            layers: Layers to apply the control vector to
+            config: Inference configuration
+
+        Returns:
+            Inference response
+        """
+        # Initialize before the try block so the except handler can always reference it
+        hook = None
+        try:
+            self._load_model()
+
+            config = config or InferenceConfig()
+
+            # Register control vector hook if provided
+            if control_vector is not None:
+                cv_config = ControlVectorInferenceConfig(
+                    method=method,
+                    scale=scale,
+                    layers=layers,
+                )
+                hook = ControlVectorHook(control_vector, cv_config)
+                hook.register(self.model)
+
+            # Tokenize input
+            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+            prompt_length = inputs.input_ids.shape[1]
+
+            # Configure generation
+            generation_config = GenerationConfig(
+                max_new_tokens=config.max_tokens,
+                temperature=config.temperature,
+                top_p=config.top_p,
+                top_k=config.top_k,
+                repetition_penalty=config.repetition_penalty,
+                do_sample=config.temperature > 0,
+                pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
+            )
+
+            # Generate
+            with torch.no_grad():
+                output_ids = self.model.generate(
+                    inputs.input_ids,
+                    attention_mask=inputs.attention_mask,
+                    generation_config=generation_config,
+                )
+
+            # Remove control vector hook if registered
+            if hook is not None:
+                hook.remove()
+
+            # Decode output
+            generated_text = self.tokenizer.decode(
+                output_ids[0][prompt_length:],
+                skip_special_tokens=True
+            )
+
+            # Create response
+            return InferenceResponse(
+                text=generated_text,
+                model=self.model_name,
+                prompt=prompt,
+                finish_reason="length",  # Simplified
+                usage={
+                    "prompt_tokens": prompt_length,
+                    "completion_tokens": output_ids.shape[1] - prompt_length,
+                    "total_tokens": output_ids.shape[1],
+                },
+                metadata={
+                    "control_vector": control_vector.name if control_vector else None,
+                    "method": method if control_vector else None,
+                    "scale": scale if control_vector else None,
+                }
+            )
+
+        except Exception as e:
+            logger.error(f"Error during inference: {str(e)}")
+            if hook is not None:
+                hook.remove()
+            raise
+
+    def __del__(self):
+        """Clean up resources."""
+        # Free GPU memory
+        if self.model is not None and hasattr(self.model, "to"):
+            self.model = self.model.to("cpu")
+
+        # Clear CUDA cache
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
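A local-inference sketch under stated assumptions: gpt2 is small enough to download and run on CPU, its blocks live at model.transformer.h (the first supported architecture branch), and 768 is its hidden size, so a low-norm random vector broadcasts cleanly onto the hidden states:

    import torch
    from wisent.control_vector.models import ControlVector
    from wisent.inference.inferencer import Inferencer
    from wisent.inference.models import InferenceConfig

    cv = ControlVector(
        name="toy",
        model_name="gpt2",
        values=torch.randn(768) * 0.01,   # small norm: perturb, don't overwhelm
    )

    inferencer = Inferencer("gpt2")
    response = inferencer.generate(
        "The weather today is",
        control_vector=cv,
        scale=2.0,
        layers=[-1],                      # negative indices resolve from the end
        config=InferenceConfig(max_tokens=20),
    )
    print(response.text)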
wisent/inference/models.py ADDED

@@ -0,0 +1,66 @@
+"""
+Data models for inference.
+"""
+
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Union
+
+from pydantic import BaseModel, Field
+
+
+class InferenceConfig(BaseModel):
+    """
+    Configuration for model inference.
+
+    Attributes:
+        max_tokens: Maximum number of tokens to generate
+        temperature: Sampling temperature
+        top_p: Top-p sampling parameter
+        top_k: Top-k sampling parameter
+        repetition_penalty: Repetition penalty
+        stop_sequences: Sequences that stop generation
+    """
+
+    max_tokens: int = 256
+    temperature: float = 0.7
+    top_p: float = 0.9
+    top_k: int = 50
+    repetition_penalty: float = 1.0
+    stop_sequences: Optional[List[str]] = None
+
+
+class InferenceResponse(BaseModel):
+    """
+    Response from model inference.
+
+    Attributes:
+        text: Generated text
+        model: Model used for generation
+        prompt: Input prompt
+        finish_reason: Reason generation stopped
+        usage: Token usage information
+        metadata: Additional metadata
+    """
+
+    text: str
+    model: str
+    prompt: str
+    finish_reason: str = "length"
+    usage: Dict[str, int] = Field(default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0})
+    metadata: Dict = Field(default_factory=dict)
+
+
+@dataclass
+class ControlVectorInferenceConfig:
+    """
+    Configuration for inference with control vectors.
+
+    Attributes:
+        method: Method for applying control vectors
+        scale: Scaling factor for control vectors
+        layers: Layers to apply control vectors to
+    """
+
+    method: str = "caa"  # Context-Aware Addition
+    scale: float = 1.0
+    layers: Optional[List[int]] = None
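The config models are plain containers; a quick sketch of overriding the sampling defaults (setting temperature to 0 makes the local Inferencer decode greedily, since it sets do_sample = temperature > 0):

    from wisent.inference.models import InferenceConfig

    config = InferenceConfig(max_tokens=128, temperature=0.0)
    # Fields left unset keep their declared defaults:
    assert config.top_p == 0.9 and config.top_k == 50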
wisent/utils/__init__.py ADDED

wisent/utils/auth.py ADDED

@@ -0,0 +1,30 @@
+"""
+Authentication utilities for the Wisent API.
+"""
+
+from typing import Dict
+
+
+class AuthManager:
+    """
+    Manages authentication for Wisent API requests.
+
+    Args:
+        api_key: The Wisent API key
+    """
+
+    def __init__(self, api_key: str):
+        self.api_key = api_key
+
+    def get_headers(self) -> Dict[str, str]:
+        """
+        Get the authentication headers for API requests.
+
+        Returns:
+            Dict containing the authentication headers
+        """
+        return {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
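AuthManager is a thin header factory consumed by both the manager and the clients; a sketch with a placeholder key:

    from wisent.utils.auth import AuthManager

    headers = AuthManager("sk-example-key").get_headers()
    assert headers["Authorization"] == "Bearer sk-example-key"
    assert headers["Content-Type"] == "application/json"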