wisent 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

wisent/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ """
2
+ Wisent - Client library for interacting with the Wisent backend services.
3
+ """
4
+
5
+ from wisent.client import WisentClient
6
+ from wisent.version import __version__
7
+
8
+ __all__ = ["WisentClient", "__version__"]
@@ -0,0 +1,9 @@
1
+ """
2
+ Functionality for extracting and managing model activations.
3
+ """
4
+
5
+ from wisent.activations.client import ActivationsClient
6
+ from wisent.activations.extractor import ActivationExtractor
7
+ from wisent.activations.models import Activation, ActivationBatch
8
+
9
+ __all__ = ["ActivationsClient", "ActivationExtractor", "Activation", "ActivationBatch"]
@@ -0,0 +1,97 @@
1
+ """
2
+ Client for interacting with the activations API.
3
+ """
4
+
5
+ from typing import Dict, List, Optional, Union
6
+
7
+ from wisent.activations.extractor import ActivationExtractor
8
+ from wisent.activations.models import Activation, ActivationBatch
9
+ from wisent.utils.auth import AuthManager
10
+ from wisent.utils.http import HTTPClient
11
+
12
+
13
class ActivationsClient:
    """
    Gateway to the activations endpoints of the Wisent backend.

    Local extraction is delegated to ``ActivationExtractor``; upload, fetch
    and listing go through an ``HTTPClient`` created in the constructor.

    Args:
        auth_manager: Authentication manager; its ``get_headers()`` result is
            handed to the HTTP client when this object is constructed
        base_url: Base URL for the API
        timeout: Request timeout in seconds
    """

    def __init__(self, auth_manager: AuthManager, base_url: str, timeout: int = 60):
        self.auth_manager = auth_manager
        request_headers = auth_manager.get_headers()
        self.http_client = HTTPClient(base_url, request_headers, timeout)

    def extract(
        self,
        model_name: str,
        prompt: str,
        layers: Optional[List[int]] = None,
        tokens_to_extract: Optional[List[int]] = None,
        device: Optional[str] = None,
    ) -> ActivationBatch:
        """
        Run a local extraction of activations for ``prompt``.

        A new ``ActivationExtractor`` is created on every call.

        Args:
            model_name: Name of the model
            prompt: Input prompt
            layers: Layers to extract activations from; when omitted the
                extractor falls back to its configuration
            tokens_to_extract: Token indices to extract; when omitted the
                extractor falls back to its configuration
            device: Device to use for extraction; when omitted the
                extractor's configured device is used

        Returns:
            Batch of activations
        """
        return ActivationExtractor(model_name, device=device).extract(
            prompt, layers, tokens_to_extract
        )

    def upload(self, batch: ActivationBatch) -> Dict:
        """
        Send a batch of activations to the Wisent backend.

        Args:
            batch: Batch of activations; serialized with ``batch.to_dict()``

        Returns:
            Response from the API
        """
        payload = batch.to_dict()
        return self.http_client.post("/activations/upload", json_data=payload)

    def get(self, batch_id: str) -> ActivationBatch:
        """
        Fetch a stored batch of activations by its ID.

        Args:
            batch_id: ID of the batch

        Returns:
            Batch of activations built from the response fields
        """
        return ActivationBatch(**self.http_client.get(f"/activations/{batch_id}"))

    def list(
        self,
        model_name: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict]:
        """
        List activation batches stored in the Wisent backend.

        Args:
            model_name: Filter by model name; only sent when truthy
            limit: Maximum number of results
            offset: Offset for pagination

        Returns:
            List of activation batch metadata
        """
        query = {"limit": limit, "offset": offset}
        if not model_name:
            return self.http_client.get("/activations", params=query)

        query["model_name"] = model_name
        return self.http_client.get("/activations", params=query)
@@ -0,0 +1,251 @@
1
+ """
2
+ Functionality for extracting activations from models.
3
+ """
4
+
5
+ import logging
6
+ from typing import Dict, List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ from torch.utils.hooks import RemovableHandle
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+
12
+ from wisent.activations.models import Activation, ActivationBatch, ActivationExtractorConfig
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class ActivationExtractor:
    """
    Extracts activations from transformer models.

    The model and tokenizer are loaded lazily on the first call to
    ``extract``. Activations are captured with forward hooks that are
    registered for a single forward pass and always removed afterwards.

    Args:
        model_name: Name of the model to extract activations from
        config: Configuration for extraction
        device: Device to use for extraction (takes precedence over
            ``config.device``)
    """

    def __init__(
        self,
        model_name: str,
        config: Optional[ActivationExtractorConfig] = None,
        device: Optional[str] = None,
    ):
        self.model_name = model_name

        if config is None:
            config = ActivationExtractorConfig()
        elif device:
            # Work on a copy so the caller's config object is not mutated
            # by the device override below.
            import copy

            config = copy.copy(config)

        if device:
            config.device = device

        self.config = config
        self.device = self.config.device
        self.model = None
        self.tokenizer = None
        self._hooks = []
        self._activations = {}

        logger.info(f"Initializing ActivationExtractor for model {model_name} on {self.device}")

    def _load_model(self) -> None:
        """Load the model and tokenizer if they are not loaded yet."""
        if self.model is not None and self.tokenizer is not None:
            return

        logger.info(f"Loading model {self.model_name}")

        # Any CUDA device string ("cuda", "cuda:0", ...) gets half precision.
        on_cuda = str(self.device).startswith("cuda")
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16 if on_cuda else torch.float32,
            device_map=self.device,
        )
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Assign only after both loads succeeded, so a failed tokenizer load
        # cannot leave the extractor with a model but no tokenizer.
        self.model = model
        self.tokenizer = tokenizer
        logger.info("Model loaded successfully")

    @staticmethod
    def _resolve_indices(indices: List[int], size: int, label: str) -> List[int]:
        """Map possibly negative indices into ``[0, size)``, skipping invalid ones."""
        resolved = []
        for idx in indices:
            position = size + idx if idx < 0 else idx
            if 0 <= position < size:
                resolved.append(position)
            else:
                logger.warning(f"{label} index {idx} out of range (0-{size-1}), skipping")
        return resolved

    def _get_transformer_layers(self):
        """Return the sequence of transformer blocks of the loaded model."""
        model = self.model
        if hasattr(model, "transformer") and hasattr(model.transformer, "h"):
            return model.transformer.h
        if hasattr(model, "model") and hasattr(model.model, "layers"):
            return model.model.layers
        raise ValueError(f"Unsupported model architecture: {self.model_name}")

    def _make_hook(self, layer_idx: int):
        """Build a forward hook that records the hidden states of one layer."""

        def hook_fn(module, inputs, output):
            # Some layers return a tuple; the hidden states are then taken
            # from its first element.
            hidden_states = output[0] if isinstance(output, tuple) else output
            self._activations.setdefault(layer_idx, []).append(hidden_states.detach())

        return hook_fn

    def _register_hooks(self, layers: List[int]) -> None:
        """
        Register hooks to capture activations from specified layers.

        Previously registered hooks and captured activations are discarded.
        Out-of-range layer indices are skipped with a warning.

        Args:
            layers: List of layer indices to capture (negative indices count
                from the last layer)

        Raises:
            ValueError: If the model exposes neither ``transformer.h`` nor
                ``model.layers``
        """
        self._remove_hooks()
        self._activations = {}

        transformer_layers = self._get_transformer_layers()
        num_layers = len(transformer_layers)

        # Drop duplicates (e.g. -1 and num_layers - 1) while keeping order so
        # the same layer is never hooked twice.
        resolved_layers = list(
            dict.fromkeys(self._resolve_indices(layers, num_layers, "Layer"))
        )

        for layer_idx in resolved_layers:
            layer = transformer_layers[layer_idx]
            # Hook the layer's `output` attribute when it has one, otherwise
            # the layer itself.
            target = layer.output if hasattr(layer, "output") else layer
            self._hooks.append(target.register_forward_hook(self._make_hook(layer_idx)))

        logger.info(f"Registered hooks for layers: {resolved_layers}")

    def _remove_hooks(self) -> None:
        """Remove all registered hooks."""
        for hook in self._hooks:
            hook.remove()
        self._hooks = []

    def _get_token_indices(self, tokens_to_extract: List[int], total_tokens: int) -> List[int]:
        """
        Resolve token indices, handling negative indices.

        Args:
            tokens_to_extract: List of token indices to extract
            total_tokens: Total number of tokens

        Returns:
            List of resolved token indices; out-of-range entries are skipped
            with a warning
        """
        return self._resolve_indices(tokens_to_extract, total_tokens, "Token")

    def extract(
        self,
        prompt: str,
        layers: Optional[List[int]] = None,
        tokens_to_extract: Optional[List[int]] = None,
    ) -> ActivationBatch:
        """
        Extract activations from the model for a given prompt.

        Args:
            prompt: Input prompt
            layers: List of layers to extract activations from (default: from config)
            tokens_to_extract: List of token indices to extract (default: from config)

        Returns:
            Batch of activations, one entry per (layer, token) pair, with the
            prompt's token count stored under ``metadata["total_tokens"]``

        Raises:
            Exception: Any error raised while loading or running the model is
                logged and re-raised unchanged
        """
        try:
            self._load_model()

            layers = layers or self.config.layers
            tokens_to_extract = tokens_to_extract or self.config.tokens_to_extract

            self._register_hooks(layers)

            inputs = self.tokenizer(prompt, return_tensors="pt")
            input_ids = inputs.input_ids.to(self.device)
            total_tokens = input_ids.shape[1]
            token_indices = self._get_token_indices(tokens_to_extract, total_tokens)

            # The forward pass is run only for its side effect on the hooks.
            with torch.no_grad():
                self.model(input_ids)

            # Token strings do not depend on the layer, so decode them once.
            token_strings = {
                token_idx: self.tokenizer.decode([input_ids[0, token_idx].item()])
                for token_idx in token_indices
            }

            activations = []
            for layer_idx, layer_activations in self._activations.items():
                # First capture of this layer; indexed below as
                # [batch, token, hidden].
                hidden_states = layer_activations[0]
                for token_idx in token_indices:
                    activations.append(
                        Activation(
                            model_name=self.model_name,
                            layer=layer_idx,
                            token_index=token_idx,
                            values=hidden_states[0, token_idx, :].cpu(),
                            token_str=token_strings.get(token_idx),
                        )
                    )

            return ActivationBatch(
                model_name=self.model_name,
                prompt=prompt,
                activations=activations,
                metadata={"total_tokens": total_tokens},
            )

        except Exception as e:
            logger.error(f"Error extracting activations: {str(e)}")
            raise

        finally:
            # Runs on success and failure alike: detach the hooks and drop the
            # captured hidden states so they do not stay referenced between
            # calls.
            self._remove_hooks()
            self._activations = {}

    def __del__(self):
        """Clean up resources."""
        # A finalizer must never raise: __init__ may have failed before the
        # attributes existed, and modules may already be torn down at
        # interpreter shutdown.
        try:
            for hook in getattr(self, "_hooks", None) or ():
                hook.remove()
            self._hooks = []

            # Drop our reference instead of moving the model to the CPU, so a
            # model still referenced elsewhere is left untouched.
            self.model = None

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass
@@ -0,0 +1,95 @@
1
+ """
2
+ Data models for model activations.
3
+ """
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Dict, List, Optional, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from pydantic import BaseModel, Field
11
+
12
+
13
class Activation(BaseModel):
    """
    Represents a single activation from a model.

    Attributes:
        model_name: Name of the model
        layer: Layer index
        token_index: Token index
        values: Activation values
        token_str: String representation of the token (optional)
    """

    model_name: str
    layer: int
    token_index: int
    values: Union[List[float], np.ndarray, torch.Tensor]
    token_str: Optional[str] = None

    class Config:
        # Needed so that np.ndarray and torch.Tensor are accepted as field types.
        arbitrary_types_allowed = True

    def to_dict(self) -> Dict:
        """
        Convert to dictionary for API requests.

        Returns:
            Dictionary with the keys ``model_name``, ``layer``,
            ``token_index``, ``values`` and ``token_str``; tensor and array
            values are converted to plain (nested) Python lists
        """
        values = self.values
        if isinstance(values, torch.Tensor):
            # Convert straight to a list: going through NumPy fails for
            # dtypes NumPy cannot represent (e.g. bfloat16).
            values = values.detach().cpu().tolist()
        elif isinstance(values, np.ndarray):
            values = values.tolist()

        return {
            "model_name": self.model_name,
            "layer": self.layer,
            "token_index": self.token_index,
            "values": values,
            "token_str": self.token_str,
        }
49
+
50
+
51
class ActivationBatch(BaseModel):
    """
    A group of activations produced from one prompt on one model.

    Attributes:
        model_name: Name of the model
        prompt: Input prompt that generated the activations
        activations: List of activations
        metadata: Additional metadata (optional, defaults to an empty dict)
    """

    model_name: str
    prompt: str
    activations: List[Activation]
    metadata: Optional[Dict] = Field(default_factory=dict)

    class Config:
        arbitrary_types_allowed = True

    def to_dict(self) -> Dict:
        """
        Build the dictionary form used for API requests.

        Returns:
            Dictionary with the keys ``model_name``, ``prompt``,
            ``activations`` (each converted with its own ``to_dict``) and
            ``metadata`` (an empty dict when no metadata is set)
        """
        serialized = []
        for activation in self.activations:
            serialized.append(activation.to_dict())

        extra = self.metadata if self.metadata else {}

        return {
            "model_name": self.model_name,
            "prompt": self.prompt,
            "activations": serialized,
            "metadata": extra,
        }
78
+
79
+
80
@dataclass
class ActivationExtractorConfig:
    """
    Configuration for activation extraction.

    Attributes:
        layers: List of layers to extract activations from
        tokens_to_extract: List of token indices to extract (negative indices count from the end)
        batch_size: Batch size for processing
        device: Device to use for extraction; defaults to "cuda" when CUDA is
            available at the time the config is created, otherwise "cpu"
    """

    layers: List[int] = field(default_factory=lambda: [-1])
    tokens_to_extract: List[int] = field(default_factory=lambda: [-1])
    batch_size: int = 1
    # Resolved per instance rather than once at import time, so the default
    # reflects CUDA availability when the config is actually built.
    device: str = field(
        default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu"
    )
wisent/client.py ADDED
@@ -0,0 +1,45 @@
1
+ """
2
+ Main client class for interacting with the Wisent backend services.
3
+ """
4
+
5
+ from typing import Dict, Optional
6
+
7
+ from wisent.activations import ActivationsClient
8
+ from wisent.control_vector import ControlVectorClient
9
+ from wisent.inference import InferenceClient
10
+ from wisent.utils.auth import AuthManager
11
+
12
+
13
class WisentClient:
    """
    Entry point for talking to the Wisent backend services.

    The constructor builds one ``AuthManager`` from the API key and shares
    it with the feature-specific sub-clients exposed as ``activations``,
    ``control_vector`` and ``inference``.

    Args:
        api_key: Your Wisent API key
        base_url: The base URL for the Wisent API (default: https://api.wisent.ai)
        timeout: Request timeout in seconds (default: 60)
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.wisent.ai",
        timeout: int = 60,
    ):
        self.api_key, self.base_url, self.timeout = api_key, base_url, timeout

        # One auth manager shared by every sub-client.
        self.auth = AuthManager(api_key)

        # Every sub-client takes the same (auth, base_url, timeout) triple.
        shared_args = (self.auth, base_url, timeout)
        self.activations = ActivationsClient(*shared_args)
        self.control_vector = ControlVectorClient(*shared_args)
        self.inference = InferenceClient(*shared_args)

    def __repr__(self) -> str:
        # Only the base URL is shown; the API key is left out of the repr.
        return f"WisentClient(base_url='{self.base_url}')"
@@ -0,0 +1,9 @@
1
+ """
2
+ Functionality for working with control vectors.
3
+ """
4
+
5
+ from wisent.control_vector.client import ControlVectorClient
6
+ from wisent.control_vector.manager import ControlVectorManager
7
+ from wisent.control_vector.models import ControlVector, ControlVectorConfig
8
+
9
+ __all__ = ["ControlVectorClient", "ControlVectorManager", "ControlVector", "ControlVectorConfig"]
@@ -0,0 +1,85 @@
1
+ """
2
+ Client for interacting with the control vector API.
3
+ """
4
+
5
+ from typing import Dict, List, Optional, Union
6
+
7
+ from wisent.control_vector.models import ControlVector
8
+ from wisent.utils.auth import AuthManager
9
+ from wisent.utils.http import HTTPClient
10
+
11
+
12
class ControlVectorClient:
    """
    Gateway to the control vector endpoints of the Wisent backend.

    Args:
        auth_manager: Authentication manager; its ``get_headers()`` result is
            handed to the HTTP client when this object is constructed
        base_url: Base URL for the API
        timeout: Request timeout in seconds
    """

    def __init__(self, auth_manager: AuthManager, base_url: str, timeout: int = 60):
        self.auth_manager = auth_manager
        request_headers = auth_manager.get_headers()
        self.http_client = HTTPClient(base_url, request_headers, timeout)

    def get(self, name: str, model: str) -> ControlVector:
        """
        Fetch a single control vector by name.

        Args:
            name: Name of the control vector
            model: Model name, sent as the ``model`` query parameter

        Returns:
            Control vector built from the response fields
        """
        endpoint = f"/control_vectors/{name}"
        response = self.http_client.get(endpoint, params={"model": model})
        return ControlVector(**response)

    def list(
        self,
        model: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict]:
        """
        List the control vectors available in the Wisent backend.

        Args:
            model: Filter by model name; only sent when truthy
            limit: Maximum number of results
            offset: Offset for pagination

        Returns:
            List of control vector metadata
        """
        query = {"limit": limit, "offset": offset}
        if not model:
            return self.http_client.get("/control_vectors", params=query)

        query["model"] = model
        return self.http_client.get("/control_vectors", params=query)

    def combine(
        self,
        vectors: Dict[str, float],
        model: str,
    ) -> ControlVector:
        """
        Ask the backend to combine several control vectors with weights.

        Args:
            vectors: Dictionary mapping vector names to weights
            model: Model name

        Returns:
            Combined control vector built from the response fields
        """
        payload = {
            "vectors": vectors,
            "model": model,
        }
        response = self.http_client.post("/control_vectors/combine", json_data=payload)
        return ControlVector(**response)