wisent 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +8 -0
- wisent/activations/__init__.py +9 -0
- wisent/activations/client.py +97 -0
- wisent/activations/extractor.py +251 -0
- wisent/activations/models.py +95 -0
- wisent/client.py +45 -0
- wisent/control_vector/__init__.py +9 -0
- wisent/control_vector/client.py +85 -0
- wisent/control_vector/manager.py +168 -0
- wisent/control_vector/models.py +70 -0
- wisent/inference/__init__.py +9 -0
- wisent/inference/client.py +103 -0
- wisent/inference/inferencer.py +250 -0
- wisent/inference/models.py +66 -0
- wisent/utils/__init__.py +3 -0
- wisent/utils/auth.py +30 -0
- wisent/utils/http.py +228 -0
- wisent/version.py +3 -0
- wisent-0.1.1.dist-info/LICENSE +21 -0
- wisent-0.1.1.dist-info/METADATA +142 -0
- wisent-0.1.1.dist-info/RECORD +23 -0
- wisent-0.1.1.dist-info/WHEEL +5 -0
- wisent-0.1.1.dist-info/top_level.txt +1 -0
wisent/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Functionality for extracting and managing model activations.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from wisent.activations.client import ActivationsClient
|
|
6
|
+
from wisent.activations.extractor import ActivationExtractor
|
|
7
|
+
from wisent.activations.models import Activation, ActivationBatch
|
|
8
|
+
|
|
9
|
+
__all__ = ["ActivationsClient", "ActivationExtractor", "Activation", "ActivationBatch"]
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Client for interacting with the activations API.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Dict, List, Optional, Union
|
|
6
|
+
|
|
7
|
+
from wisent.activations.extractor import ActivationExtractor
|
|
8
|
+
from wisent.activations.models import Activation, ActivationBatch
|
|
9
|
+
from wisent.utils.auth import AuthManager
|
|
10
|
+
from wisent.utils.http import HTTPClient
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ActivationsClient:
    """
    Entry point for activation-related operations.

    Extraction is delegated to ``ActivationExtractor``; uploading, fetching
    and listing batches go through an ``HTTPClient`` created at construction.

    Args:
        auth_manager: Authentication manager; its ``get_headers()`` result is
            passed to the HTTP client when this object is built
        base_url: Base URL for the API
        timeout: Request timeout in seconds
    """

    def __init__(self, auth_manager: AuthManager, base_url: str, timeout: int = 60):
        self.auth_manager = auth_manager
        request_headers = auth_manager.get_headers()
        self.http_client = HTTPClient(base_url, request_headers, timeout)

    def extract(
        self,
        model_name: str,
        prompt: str,
        layers: Optional[List[int]] = None,
        tokens_to_extract: Optional[List[int]] = None,
        device: Optional[str] = None,
    ) -> ActivationBatch:
        """
        Run ``prompt`` through ``model_name`` and collect activations.

        A new ``ActivationExtractor`` is created for every call.

        Args:
            model_name: Name of the model
            prompt: Input prompt
            layers: Layers to extract from; forwarded unchanged to the extractor
            tokens_to_extract: Token indices to extract; forwarded unchanged
            device: Device handed to the extractor

        Returns:
            Batch of activations produced by the extractor
        """
        return ActivationExtractor(model_name, device=device).extract(
            prompt,
            layers,
            tokens_to_extract,
        )

    def upload(self, batch: ActivationBatch) -> Dict:
        """
        Send a batch of activations to the Wisent backend.

        Args:
            batch: Batch of activations; serialised with ``batch.to_dict()``

        Returns:
            Whatever the HTTP client returns for the POST request
        """
        payload = batch.to_dict()
        return self.http_client.post("/activations/upload", json_data=payload)

    def get(self, batch_id: str) -> ActivationBatch:
        """
        Fetch one batch of activations from the Wisent backend.

        Args:
            batch_id: ID of the batch

        Returns:
            ``ActivationBatch`` built from the fields of the response
        """
        return ActivationBatch(**self.http_client.get(f"/activations/{batch_id}"))

    def list(
        self,
        model_name: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict]:
        """
        List activation batches stored on the Wisent backend.

        Args:
            model_name: Filter by model name; omitted from the query when falsy
            limit: Maximum number of results
            offset: Offset for pagination

        Returns:
            Whatever the HTTP client returns for the GET request
        """
        name_filter = {"model_name": model_name} if model_name else {}
        query = {"limit": limit, "offset": offset, **name_filter}
        return self.http_client.get("/activations", params=query)
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Functionality for extracting activations from models.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch.utils.hooks import RemovableHandle
|
|
10
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
11
|
+
|
|
12
|
+
from wisent.activations.models import Activation, ActivationBatch, ActivationExtractorConfig
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ActivationExtractor:
    """
    Extracts activations from transformer models.

    The model and tokenizer are loaded lazily on the first :meth:`extract`
    call. Forward hooks capture the output of each requested layer during a
    single forward pass; the hooks and the captured tensors are released
    again before :meth:`extract` returns or raises.

    Args:
        model_name: Name of the model to extract activations from
        config: Configuration for extraction; a default
            ``ActivationExtractorConfig`` is created when omitted
        device: Device to use for extraction; when given it overrides
            ``config.device`` on a copy, leaving the caller's config untouched
    """

    def __init__(
        self,
        model_name: str,
        config: Optional[ActivationExtractorConfig] = None,
        device: Optional[str] = None,
    ):
        import copy  # local import so the block does not depend on a file-level change

        self.model_name = model_name
        # Initialise cleanup-relevant state first so __del__ is safe even if
        # the remainder of __init__ raises.
        self.model = None
        self.tokenizer = None
        self._hooks = []
        self._activations = {}

        if not config:
            config = ActivationExtractorConfig()
        elif device:
            # Do not write the device override into an object the caller owns.
            config = copy.copy(config)
        if device:
            config.device = device

        self.config = config
        self.device = self.config.device

        logger.info(f"Initializing ActivationExtractor for model {model_name} on {self.device}")

    def _load_model(self) -> None:
        """Load the model and the tokenizer, each only if not loaded yet."""
        if self.model is None:
            logger.info(f"Loading model {self.model_name}")
            # Match "cuda" as well as indexed devices such as "cuda:0".
            on_cuda = str(self.device).startswith("cuda")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if on_cuda else torch.float32,
                device_map=self.device,
            )
        # Checked separately so a failed tokenizer load is retried on the next
        # call instead of leaving the tokenizer as None forever.
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            logger.info("Model loaded successfully")

    @staticmethod
    def _resolve_indices(indices: List[int], size: int, label: str) -> List[int]:
        """Map possibly-negative indices into ``[0, size)``; warn about and drop the rest."""
        resolved = []
        for index in indices:
            position = size + index if index < 0 else index
            if 0 <= position < size:
                resolved.append(position)
            else:
                logger.warning(f"{label} index {index} out of range (0-{size - 1}), skipping")
        return resolved

    def _get_transformer_layers(self):
        """Return the model's sequence of transformer layers or raise ``ValueError``."""
        transformer = getattr(self.model, "transformer", None)
        if transformer is not None and hasattr(transformer, "h"):
            return transformer.h
        inner = getattr(self.model, "model", None)
        if inner is not None and hasattr(inner, "layers"):
            return inner.layers
        raise ValueError(f"Unsupported model architecture: {self.model_name}")

    def _make_hook(self, layer_idx: int):
        """Build a forward hook that stores the layer output under ``layer_idx``."""

        def hook_fn(module, inputs, output):
            # When the output is a tuple, its first element is taken as the hidden states.
            hidden_states = output[0] if isinstance(output, tuple) else output
            self._activations.setdefault(layer_idx, []).append(hidden_states.detach())

        return hook_fn

    def _register_hooks(self, layers: List[int]) -> None:
        """
        Register hooks to capture activations from specified layers.

        Previously registered hooks and captured activations are discarded
        first. Negative indices count from the last layer; out-of-range
        indices are skipped with a warning and duplicates are hooked once.

        Args:
            layers: List of layer indices to capture

        Raises:
            ValueError: If the layer container of the model cannot be located
        """
        self._remove_hooks()
        self._activations = {}

        transformer_layers = self._get_transformer_layers()
        num_layers = len(transformer_layers)

        # dict.fromkeys de-duplicates while keeping the requested order.
        resolved_layers = list(
            dict.fromkeys(self._resolve_indices(layers, num_layers, "Layer"))
        )

        for layer_idx in resolved_layers:
            layer = transformer_layers[layer_idx]
            # Hook the layer's ``output`` attribute when it has one, else the layer itself.
            target = layer.output if hasattr(layer, "output") else layer
            self._hooks.append(target.register_forward_hook(self._make_hook(layer_idx)))

        logger.info(f"Registered hooks for layers: {resolved_layers}")

    def _remove_hooks(self) -> None:
        """Remove all registered hooks."""
        # getattr keeps this safe on an instance whose __init__ never ran.
        for handle in getattr(self, "_hooks", ()):
            handle.remove()
        self._hooks = []

    def _get_token_indices(self, tokens_to_extract: List[int], total_tokens: int) -> List[int]:
        """
        Resolve token indices, handling negative indices.

        Args:
            tokens_to_extract: List of token indices to extract
            total_tokens: Total number of tokens

        Returns:
            List of resolved token indices; out-of-range entries are dropped
            with a warning
        """
        return self._resolve_indices(tokens_to_extract, total_tokens, "Token")

    def extract(
        self,
        prompt: str,
        layers: Optional[List[int]] = None,
        tokens_to_extract: Optional[List[int]] = None,
    ) -> ActivationBatch:
        """
        Extract activations from the model for a given prompt.

        Args:
            prompt: Input prompt
            layers: List of layers to extract activations from (default: from config)
            tokens_to_extract: List of token indices to extract (default: from config)

        Returns:
            Batch of activations, one per (layer, token) pair, with
            ``metadata["total_tokens"]`` set to the prompt's token count

        Raises:
            Exception: Any error raised while loading or running the model is
                logged and re-raised unchanged
        """
        try:
            self._load_model()

            layers = layers or self.config.layers
            tokens_to_extract = tokens_to_extract or self.config.tokens_to_extract

            self._register_hooks(layers)

            inputs = self.tokenizer(prompt, return_tensors="pt")
            input_ids = inputs.input_ids.to(self.device)
            total_tokens = input_ids.shape[1]
            token_indices = self._get_token_indices(tokens_to_extract, total_tokens)

            # The hooks fill self._activations during this forward pass.
            with torch.no_grad():
                self.model(input_ids)

            # Token strings do not depend on the layer, so decode them once.
            token_strings = {
                token_idx: self.tokenizer.decode([input_ids[0, token_idx].item()])
                for token_idx in token_indices
            }

            activations = []
            for layer_idx, layer_activations in self._activations.items():
                # First captured output of the layer, indexed as [batch, token, hidden] below.
                hidden_states = layer_activations[0]
                for token_idx in token_indices:
                    activations.append(
                        Activation(
                            model_name=self.model_name,
                            layer=layer_idx,
                            token_index=token_idx,
                            values=hidden_states[0, token_idx, :].cpu(),
                            token_str=token_strings.get(token_idx),
                        )
                    )

            return ActivationBatch(
                model_name=self.model_name,
                prompt=prompt,
                activations=activations,
                metadata={"total_tokens": total_tokens},
            )

        except Exception as e:
            logger.error(f"Error extracting activations: {e}")
            raise
        finally:
            # Always detach the hooks and drop the captured tensors so they do
            # not outlive the call.
            self._remove_hooks()
            self._activations = {}

    def __del__(self):
        """Clean up resources."""
        self._remove_hooks()

        # Move the model off the accelerator; getattr guards against a
        # partially initialised instance.
        model = getattr(self, "model", None)
        if model is not None and hasattr(model, "to"):
            self.model = model.to("cpu")

        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Data models for model activations.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Dict, List, Optional, Union
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
from pydantic import BaseModel, Field
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Activation(BaseModel):
    """
    Represents a single activation from a model.

    Attributes:
        model_name: Name of the model
        layer: Layer index
        token_index: Token index
        values: Activation values (list of floats, NumPy array or torch tensor)
        token_str: String representation of the token (optional)
    """

    model_name: str
    layer: int
    token_index: int
    values: Union[List[float], np.ndarray, torch.Tensor]
    token_str: Optional[str] = None

    class Config:
        # Needed so np.ndarray and torch.Tensor are accepted as field types.
        arbitrary_types_allowed = True

    def to_dict(self) -> Dict:
        """
        Convert to dictionary for API requests.

        Returns:
            Dict with the five fields; tensor and array ``values`` are
            converted to plain Python lists, anything else is passed through
        """
        values = self.values
        if isinstance(values, torch.Tensor):
            # Tensor.tolist() works for every dtype; going through .numpy()
            # raises for dtypes NumPy lacks, such as bfloat16.
            values = values.detach().cpu().tolist()
        elif isinstance(values, np.ndarray):
            values = values.tolist()

        return {
            "model_name": self.model_name,
            "layer": self.layer,
            "token_index": self.token_index,
            "values": values,
            "token_str": self.token_str,
        }
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class ActivationBatch(BaseModel):
    """
    A group of activations that came from one prompt on one model.

    Attributes:
        model_name: Name of the model
        prompt: Input prompt that generated the activations
        activations: List of activations
        metadata: Additional metadata (optional, defaults to an empty dict)
    """

    model_name: str
    prompt: str
    activations: List[Activation]
    metadata: Optional[Dict] = Field(default_factory=dict)

    class Config:
        arbitrary_types_allowed = True

    def to_dict(self) -> Dict:
        """Build the request payload: each activation serialised via its own ``to_dict``."""
        serialized = []
        for item in self.activations:
            serialized.append(item.to_dict())

        # A missing or empty metadata value is sent as an empty dict.
        extra = self.metadata if self.metadata else {}

        return dict(
            model_name=self.model_name,
            prompt=self.prompt,
            activations=serialized,
            metadata=extra,
        )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _last_index_only() -> List[int]:
    """Default factory: return a fresh ``[-1]`` list for each config instance."""
    return [-1]


@dataclass
class ActivationExtractorConfig:
    """
    Settings that control how activations are extracted.

    Attributes:
        layers: List of layers to extract activations from
        tokens_to_extract: List of token indices to extract (negative indices count from the end)
        batch_size: Batch size for processing
        device: Device to use for extraction
    """

    layers: List[int] = field(default_factory=_last_index_only)
    tokens_to_extract: List[int] = field(default_factory=_last_index_only)
    batch_size: int = 1
    # Evaluated once, when the class body runs: "cuda" if torch reports it available.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
wisent/client.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Main client class for interacting with the Wisent backend services.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Dict, Optional
|
|
6
|
+
|
|
7
|
+
from wisent.activations import ActivationsClient
|
|
8
|
+
from wisent.control_vector import ControlVectorClient
|
|
9
|
+
from wisent.inference import InferenceClient
|
|
10
|
+
from wisent.utils.auth import AuthManager
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class WisentClient:
    """
    Top-level client for the Wisent backend services.

    One ``AuthManager`` is created from the API key and shared by three
    sub-clients, exposed as ``activations``, ``control_vector`` and
    ``inference``.

    Args:
        api_key: Your Wisent API key
        base_url: The base URL for the Wisent API (default: https://api.wisent.ai)
        timeout: Request timeout in seconds (default: 60)
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.wisent.ai",
        timeout: int = 60,
    ):
        self.api_key, self.base_url, self.timeout = api_key, base_url, timeout

        # One auth manager shared by every sub-client.
        self.auth = AuthManager(api_key)

        # Each sub-client takes the same (auth, base_url, timeout) triple.
        shared_args = (self.auth, base_url, timeout)
        self.activations = ActivationsClient(*shared_args)
        self.control_vector = ControlVectorClient(*shared_args)
        self.inference = InferenceClient(*shared_args)

    def __repr__(self) -> str:
        # Only the base URL is shown; the API key is not part of the repr.
        return "WisentClient(base_url='{}')".format(self.base_url)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Functionality for working with control vectors.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from wisent.control_vector.client import ControlVectorClient
|
|
6
|
+
from wisent.control_vector.manager import ControlVectorManager
|
|
7
|
+
from wisent.control_vector.models import ControlVector, ControlVectorConfig
|
|
8
|
+
|
|
9
|
+
__all__ = ["ControlVectorClient", "ControlVectorManager", "ControlVector", "ControlVectorConfig"]
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Client for interacting with the control vector API.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Dict, List, Optional, Union
|
|
6
|
+
|
|
7
|
+
from wisent.control_vector.models import ControlVector
|
|
8
|
+
from wisent.utils.auth import AuthManager
|
|
9
|
+
from wisent.utils.http import HTTPClient
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ControlVectorClient:
    """
    Wrapper around the control vector endpoints of the Wisent API.

    Args:
        auth_manager: Authentication manager; its ``get_headers()`` result is
            passed to the HTTP client when this object is built
        base_url: Base URL for the API
        timeout: Request timeout in seconds
    """

    def __init__(self, auth_manager: AuthManager, base_url: str, timeout: int = 60):
        self.auth_manager = auth_manager
        request_headers = auth_manager.get_headers()
        self.http_client = HTTPClient(base_url, request_headers, timeout)

    def get(self, name: str, model: str) -> ControlVector:
        """
        Fetch a single control vector from the Wisent backend.

        Args:
            name: Name of the control vector
            model: Model name, sent as the ``model`` query parameter

        Returns:
            ``ControlVector`` built from the fields of the response
        """
        query = {"model": model}
        return ControlVector(**self.http_client.get(f"/control_vectors/{name}", params=query))

    def list(
        self,
        model: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict]:
        """
        List the control vectors available on the Wisent backend.

        Args:
            model: Filter by model name; omitted from the query when falsy
            limit: Maximum number of results
            offset: Offset for pagination

        Returns:
            Whatever the HTTP client returns for the GET request
        """
        model_filter = {"model": model} if model else {}
        query = {"limit": limit, "offset": offset, **model_filter}
        return self.http_client.get("/control_vectors", params=query)

    def combine(
        self,
        vectors: Dict[str, float],
        model: str,
    ) -> ControlVector:
        """
        Ask the backend to combine several control vectors with weights.

        Args:
            vectors: Dictionary mapping vector names to weights
            model: Model name

        Returns:
            ``ControlVector`` built from the fields of the response
        """
        payload = {"vectors": vectors, "model": model}
        response = self.http_client.post("/control_vectors/combine", json_data=payload)
        return ControlVector(**response)
|