wisent 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of wisent might be problematic.

@@ -0,0 +1,168 @@
+ """
+ Manager for working with control vectors.
+ """
+
+ import logging
+ from typing import Dict, List, Optional, Union
+
+ import torch
+
+ from wisent.control_vector.models import ControlVector, ControlVectorConfig
+ from wisent.utils.auth import AuthManager
+ from wisent.utils.http import HTTPClient
+
+ logger = logging.getLogger(__name__)
+
+
+ class ControlVectorManager:
+     """
+     Manager for working with control vectors.
+
+     Args:
+         api_key: Wisent API key
+         base_url: Base URL for the API
+         timeout: Request timeout in seconds
+     """
+
+     def __init__(
+         self,
+         api_key: str,
+         base_url: str = "https://api.wisent.ai",
+         timeout: int = 60,
+     ):
+         self.auth = AuthManager(api_key)
+         self.http_client = HTTPClient(base_url, self.auth.get_headers(), timeout)
+         self.cache = {}  # Simple in-memory cache
+
+     def get(self, name: str, model: str) -> ControlVector:
+         """
+         Get a control vector from the Wisent backend.
+
+         Args:
+             name: Name of the control vector
+             model: Model name
+
+         Returns:
+             Control vector
+         """
+         cache_key = f"{name}:{model}"
+         if cache_key in self.cache:
+             logger.info(f"Using cached control vector: {name} for model {model}")
+             return self.cache[cache_key]
+
+         logger.info(f"Fetching control vector: {name} for model {model}")
+         data = self.http_client.get(f"/control_vectors/{name}", params={"model": model})
+         vector = ControlVector(**data)
+
+         # Cache the result
+         self.cache[cache_key] = vector
+
+         return vector
+
+     def list(
+         self,
+         model: Optional[str] = None,
+         limit: int = 100,
+         offset: int = 0,
+     ) -> List[Dict]:
+         """
+         List available control vectors from the Wisent backend.
+
+         Args:
+             model: Filter by model name
+             limit: Maximum number of results
+             offset: Offset for pagination
+
+         Returns:
+             List of control vector metadata
+         """
+         params = {"limit": limit, "offset": offset}
+         if model:
+             params["model"] = model
+
+         return self.http_client.get("/control_vectors", params=params)
+
+     def combine(
+         self,
+         vectors: Dict[str, float],
+         model: str,
+     ) -> ControlVector:
+         """
+         Combine multiple control vectors with weights.
+
+         Args:
+             vectors: Dictionary mapping vector names to weights
+             model: Model name
+
+         Returns:
+             Combined control vector
+         """
+         # Check if we can combine locally
+         can_combine_locally = True
+         local_vectors = {}
+
+         for name in vectors.keys():
+             cache_key = f"{name}:{model}"
+             if cache_key not in self.cache:
+                 can_combine_locally = False
+                 break
+             local_vectors[name] = self.cache[cache_key]
+
+         if can_combine_locally:
+             logger.info(f"Combining vectors locally for model {model}")
+             return self._combine_locally(local_vectors, vectors, model)
+
+         # Otherwise, use the API
+         logger.info(f"Combining vectors via API for model {model}")
+         data = self.http_client.post(
+             "/control_vectors/combine",
+             json_data={
+                 "vectors": vectors,
+                 "model": model,
+             }
+         )
+         return ControlVector(**data)
+
+     def _combine_locally(
+         self,
+         vectors: Dict[str, ControlVector],
+         weights: Dict[str, float],
+         model: str,
+     ) -> ControlVector:
+         """
+         Combine vectors locally.
+
+         Args:
+             vectors: Dictionary mapping vector names to ControlVector objects
+             weights: Dictionary mapping vector names to weights
+             model: Model name
+
+         Returns:
+             Combined control vector
+         """
+         # Convert all vectors to tensors
+         tensor_vectors = {}
+         for name, vector in vectors.items():
+             tensor_vectors[name] = vector.to_tensor()
+
+         # Get the shape from the first vector
+         first_vector = next(iter(tensor_vectors.values()))
+         combined = torch.zeros_like(first_vector)
+
+         # Combine vectors with weights
+         for name, weight in weights.items():
+             if name in tensor_vectors:
+                 combined += tensor_vectors[name] * weight
+
+         # Create a new control vector
+         vector_names = list(weights.keys())
+         combined_name = f"combined_{'_'.join(vector_names)}"
+
+         return ControlVector(
+             name=combined_name,
+             model_name=model,
+             values=combined,
+             metadata={
+                 "combined_from": {name: weight for name, weight in weights.items()},
+             }
+         )
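
For orientation, here is a minimal usage sketch for the ControlVectorManager defined above. It is an editor's illustration rather than part of the release: the API key, the vector names "honesty" and "brevity", the model identifier, and the import path are all assumptions.

```python
# Hypothetical usage of ControlVectorManager; names and import path are assumed.
from wisent.control_vector.manager import ControlVectorManager  # import path assumed

manager = ControlVectorManager(api_key="YOUR_API_KEY")

# First call hits the backend; repeated calls are served from the in-memory cache.
vector_a = manager.get(name="honesty", model="example-model")
vector_b = manager.get(name="brevity", model="example-model")

# Both vectors are now cached, so the weighted combination happens locally.
combined = manager.combine({"honesty": 1.0, "brevity": 0.5}, model="example-model")
print(combined.name)  # combined_honesty_brevity
```
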
@@ -0,0 +1,70 @@
+ """
+ Data models for control vectors.
+ """
+
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Union
+
+ import numpy as np
+ import torch
+ from pydantic import BaseModel, Field
+
+
+ class ControlVector(BaseModel):
+     """
+     Represents a control vector for steering model outputs.
+
+     Attributes:
+         name: Name of the control vector
+         model_name: Name of the model the vector is for
+         values: Vector values
+         metadata: Additional metadata
+     """
+
+     name: str
+     model_name: str
+     values: Union[List[float], np.ndarray, torch.Tensor]
+     metadata: Optional[Dict] = Field(default_factory=dict)
+
+     class Config:
+         arbitrary_types_allowed = True
+
+     def to_dict(self) -> Dict:
+         """Convert to dictionary for API requests."""
+         values = self.values
+         if isinstance(values, torch.Tensor):
+             values = values.detach().cpu().numpy()
+         if isinstance(values, np.ndarray):
+             values = values.tolist()
+
+         return {
+             "name": self.name,
+             "model_name": self.model_name,
+             "values": values,
+             "metadata": self.metadata or {},
+         }
+
+     def to_tensor(self, device: str = "cpu") -> torch.Tensor:
+         """Convert values to a PyTorch tensor."""
+         if isinstance(self.values, torch.Tensor):
+             return self.values.to(device)
+         elif isinstance(self.values, np.ndarray):
+             return torch.tensor(self.values, device=device)
+         else:
+             return torch.tensor(self.values, device=device)
+
+
+ @dataclass
+ class ControlVectorConfig:
+     """
+     Configuration for control vector application.
+
+     Attributes:
+         scale: Scaling factor for the control vector
+         method: Method for applying the control vector
+         layers: Layers to apply the control vector to
+     """
+
+     scale: float = 1.0
+     method: str = "caa"  # Context-Aware Addition
+     layers: Optional[List[int]] = None
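
A brief sketch of the data model above, again an editor's illustration rather than code from the package; the vector length and all names are placeholders.

```python
# Illustrative construction of a ControlVector and its two conversions.
import torch
from wisent.control_vector.models import ControlVector, ControlVectorConfig

vec = ControlVector(
    name="example",
    model_name="example-model",      # placeholder model name
    values=torch.randn(4096),        # hidden size is illustrative
    metadata={"source": "hand-built"},
)

tensor = vec.to_tensor(device="cpu")   # always returns a torch.Tensor
payload = vec.to_dict()                # "values" becomes a plain list of floats

cfg = ControlVectorConfig(scale=0.8, method="caa", layers=[-1])
```
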
@@ -0,0 +1,9 @@
+ """
+ Functionality for model inference with control vectors.
+ """
+
+ from wisent.inference.client import InferenceClient
+ from wisent.inference.inferencer import Inferencer
+ from wisent.inference.models import InferenceConfig, InferenceResponse
+
+ __all__ = ["InferenceClient", "Inferencer", "InferenceConfig", "InferenceResponse"]
@@ -0,0 +1,103 @@
+ """
+ Client for interacting with the inference API.
+ """
+
+ from typing import Dict, List, Optional, Union
+
+ from wisent.inference.models import InferenceConfig, InferenceResponse
+ from wisent.utils.auth import AuthManager
+ from wisent.utils.http import HTTPClient
+
+
+ class InferenceClient:
+     """
+     Client for interacting with the inference API.
+
+     Args:
+         auth_manager: Authentication manager
+         base_url: Base URL for the API
+         timeout: Request timeout in seconds
+     """
+
+     def __init__(self, auth_manager: AuthManager, base_url: str, timeout: int = 60):
+         self.auth_manager = auth_manager
+         self.http_client = HTTPClient(base_url, auth_manager.get_headers(), timeout)
+
+     def generate(
+         self,
+         model_name: str,
+         prompt: str,
+         config: Optional[InferenceConfig] = None,
+     ) -> InferenceResponse:
+         """
+         Generate text using a model.
+
+         Args:
+             model_name: Name of the model
+             prompt: Input prompt
+             config: Inference configuration
+
+         Returns:
+             Inference response
+         """
+         config = config or InferenceConfig()
+
+         data = self.http_client.post(
+             "/inference/generate",
+             json_data={
+                 "model": model_name,
+                 "prompt": prompt,
+                 "max_tokens": config.max_tokens,
+                 "temperature": config.temperature,
+                 "top_p": config.top_p,
+                 "top_k": config.top_k,
+                 "repetition_penalty": config.repetition_penalty,
+                 "stop_sequences": config.stop_sequences,
+             }
+         )
+
+         return InferenceResponse(**data)
+
+     def generate_with_control(
+         self,
+         model_name: str,
+         prompt: str,
+         control_vectors: Dict[str, float],
+         method: str = "caa",
+         scale: float = 1.0,
+         config: Optional[InferenceConfig] = None,
+     ) -> InferenceResponse:
+         """
+         Generate text using a model with control vectors.
+
+         Args:
+             model_name: Name of the model
+             prompt: Input prompt
+             control_vectors: Dictionary mapping vector names to weights
+             method: Method for applying control vectors
+             scale: Scaling factor for control vectors
+             config: Inference configuration
+
+         Returns:
+             Inference response
+         """
+         config = config or InferenceConfig()
+
+         data = self.http_client.post(
+             "/inference/generate_with_control",
+             json_data={
+                 "model": model_name,
+                 "prompt": prompt,
+                 "control_vectors": control_vectors,
+                 "method": method,
+                 "scale": scale,
+                 "max_tokens": config.max_tokens,
+                 "temperature": config.temperature,
+                 "top_p": config.top_p,
+                 "top_k": config.top_k,
+                 "repetition_penalty": config.repetition_penalty,
+                 "stop_sequences": config.stop_sequences,
+             }
+         )
+
+         return InferenceResponse(**data)
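
The client above could plausibly be driven as follows. This is an editor's sketch, not package documentation: the model and vector names and the API key are placeholders, and the base URL simply mirrors the default that ControlVectorManager uses.

```python
# Hypothetical call against the hosted inference API.
from wisent.inference.client import InferenceClient
from wisent.inference.models import InferenceConfig
from wisent.utils.auth import AuthManager

client = InferenceClient(AuthManager("YOUR_API_KEY"), base_url="https://api.wisent.ai")

response = client.generate_with_control(
    model_name="example-model",
    prompt="Explain control vectors in one sentence.",
    control_vectors={"honesty": 1.0},     # vector name is illustrative
    method="caa",
    scale=1.0,
    config=InferenceConfig(max_tokens=128, temperature=0.2),
)
print(response.text)
```
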
@@ -0,0 +1,250 @@
+ """
+ Functionality for local inference with control vectors.
+ """
+
+ import logging
+ from typing import Dict, List, Optional, Union
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+
+ from wisent.control_vector.models import ControlVector
+ from wisent.inference.models import ControlVectorInferenceConfig, InferenceConfig, InferenceResponse
+
+ logger = logging.getLogger(__name__)
+
+
+ class ControlVectorHook:
+     """
+     Hook for applying control vectors during inference.
+
+     Args:
+         control_vector: Control vector to apply
+         config: Configuration for applying the control vector
+     """
+
+     def __init__(
+         self,
+         control_vector: ControlVector,
+         config: ControlVectorInferenceConfig,
+     ):
+         self.control_vector = control_vector
+         self.config = config
+         self.device = None
+         self.vector_tensor = None
+         self.hooks = []
+
+     def register(self, model):
+         """
+         Register hooks on the model.
+
+         Args:
+             model: The model to register hooks on
+         """
+         self.device = next(model.parameters()).device
+         self.vector_tensor = self.control_vector.to_tensor(self.device)
+
+         # Get transformer layers
+         if hasattr(model, "transformer"):
+             transformer_layers = model.transformer.h
+         elif hasattr(model, "model") and hasattr(model.model, "layers"):
+             transformer_layers = model.model.layers
+         else:
+             raise ValueError(f"Unsupported model architecture: {model.__class__.__name__}")
+
+         # Determine which layers to apply the control vector to
+         num_layers = len(transformer_layers)
+         layers = self.config.layers or [num_layers - 1]  # Default to last layer
+
+         # Resolve negative indices
+         resolved_layers = []
+         for layer in layers:
+             if layer < 0:
+                 resolved_layer = num_layers + layer
+             else:
+                 resolved_layer = layer
+
+             if 0 <= resolved_layer < num_layers:
+                 resolved_layers.append(resolved_layer)
+
+         # Register hooks
+         for layer_idx in resolved_layers:
+             layer = transformer_layers[layer_idx]
+
+             # Define hook function
+             def hook_fn(module, input, output, layer_idx=layer_idx):
+                 if isinstance(output, tuple):
+                     hidden_states = output[0]
+                 else:
+                     hidden_states = output
+
+                 # Apply the control vector
+                 if self.config.method == "caa":  # Context-Aware Addition
+                     # Add the control vector to the hidden states
+                     modified = hidden_states + self.vector_tensor * self.config.scale
+
+                     if isinstance(output, tuple):
+                         return (modified,) + output[1:]
+                     else:
+                         return modified
+                 else:
+                     logger.warning(f"Unsupported method: {self.config.method}, using original output")
+                     return output
+
+             # Register hook
+             if hasattr(layer, "output"):
+                 handle = layer.output.register_forward_hook(
+                     lambda module, input, output, layer_idx=layer_idx: hook_fn(module, input, output, layer_idx)
+                 )
+             else:
+                 handle = layer.register_forward_hook(
+                     lambda module, input, output, layer_idx=layer_idx: hook_fn(module, input, output, layer_idx)
+                 )
+
+             self.hooks.append(handle)
+
+     def remove(self):
+         """Remove all registered hooks."""
+         for hook in self.hooks:
+             hook.remove()
+         self.hooks = []
+
+
+ class Inferencer:
+     """
+     Performs local inference with control vectors.
+
+     Args:
+         model_name: Name of the model
+         device: Device to use for inference
+     """
+
+     def __init__(
+         self,
+         model_name: str,
+         device: Optional[str] = None,
+     ):
+         self.model_name = model_name
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+         self.model = None
+         self.tokenizer = None
+
+         logger.info(f"Initializing Inferencer for model {model_name} on {self.device}")
+
+     def _load_model(self):
+         """Load the model and tokenizer."""
+         if self.model is None:
+             logger.info(f"Loading model {self.model_name}")
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 self.model_name,
+                 torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+                 device_map=self.device
+             )
+             self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+             logger.info(f"Model loaded successfully")
+
+     def generate(
+         self,
+         prompt: str,
+         control_vector: Optional[ControlVector] = None,
+         method: str = "caa",
+         scale: float = 1.0,
+         layers: Optional[List[int]] = None,
+         config: Optional[InferenceConfig] = None,
+     ) -> InferenceResponse:
+         """
+         Generate text using the model, optionally with a control vector.
+
+         Args:
+             prompt: Input prompt
+             control_vector: Control vector to apply (optional)
+             method: Method for applying the control vector
+             scale: Scaling factor for the control vector
+             layers: Layers to apply the control vector to
+             config: Inference configuration
+
+         Returns:
+             Inference response
+         """
+         try:
+             self._load_model()
+
+             config = config or InferenceConfig()
+             hook = None
+
+             # Register control vector hook if provided
+             if control_vector is not None:
+                 cv_config = ControlVectorInferenceConfig(
+                     method=method,
+                     scale=scale,
+                     layers=layers,
+                 )
+                 hook = ControlVectorHook(control_vector, cv_config)
+                 hook.register(self.model)
+
+             # Tokenize input
+             inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+             prompt_length = inputs.input_ids.shape[1]
+
+             # Configure generation
+             generation_config = GenerationConfig(
+                 max_new_tokens=config.max_tokens,
+                 temperature=config.temperature,
+                 top_p=config.top_p,
+                 top_k=config.top_k,
+                 repetition_penalty=config.repetition_penalty,
+                 do_sample=config.temperature > 0,
+                 pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
+             )
+
+             # Generate
+             with torch.no_grad():
+                 output_ids = self.model.generate(
+                     inputs.input_ids,
+                     attention_mask=inputs.attention_mask,
+                     generation_config=generation_config,
+                 )
+
+             # Remove control vector hook if registered
+             if hook is not None:
+                 hook.remove()
+
+             # Decode output
+             generated_text = self.tokenizer.decode(
+                 output_ids[0][prompt_length:],
+                 skip_special_tokens=True
+             )
+
+             # Create response
+             return InferenceResponse(
+                 text=generated_text,
+                 model=self.model_name,
+                 prompt=prompt,
+                 finish_reason="length",  # Simplified
+                 usage={
+                     "prompt_tokens": prompt_length,
+                     "completion_tokens": output_ids.shape[1] - prompt_length,
+                     "total_tokens": output_ids.shape[1],
+                 },
+                 metadata={
+                     "control_vector": control_vector.name if control_vector else None,
+                     "method": method if control_vector else None,
+                     "scale": scale if control_vector else None,
+                 }
+             )
+
+         except Exception as e:
+             logger.error(f"Error during inference: {str(e)}")
+             if hook is not None:
+                 hook.remove()
+             raise
+
+     def __del__(self):
+         """Clean up resources."""
+         # Free GPU memory
+         if self.model is not None and hasattr(self.model, "to"):
+             self.model = self.model.to("cpu")
+
+         # Clear CUDA cache
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
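
A local-generation sketch for the Inferencer above (editor's illustration, not from the diff): the Hugging Face model id and vector length are placeholders, the vector would normally come from the backend rather than be built by hand, and loading a causal LM locally requires the corresponding weights and memory.

```python
# Hypothetical local run with a hand-built steering vector.
import torch
from wisent.control_vector.models import ControlVector
from wisent.inference.inferencer import Inferencer
from wisent.inference.models import InferenceConfig

vector = ControlVector(
    name="demo",
    model_name="example/model",     # placeholder HF model id
    values=torch.zeros(4096),       # must match the model's hidden size
)

inferencer = Inferencer("example/model")
response = inferencer.generate(
    prompt="Write one sentence about bison.",
    control_vector=vector,
    method="caa",    # the only method the hook implements; others log a warning
    scale=1.0,
    layers=[-1],     # negative indices count back from the last layer
    config=InferenceConfig(max_tokens=32),
)
print(response.text, response.usage)
```
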
@@ -0,0 +1,66 @@
+ """
+ Data models for inference.
+ """
+
+ from dataclasses import dataclass, field
+ from typing import Dict, List, Optional, Union
+
+ from pydantic import BaseModel, Field
+
+
+ class InferenceConfig(BaseModel):
+     """
+     Configuration for model inference.
+
+     Attributes:
+         max_tokens: Maximum number of tokens to generate
+         temperature: Sampling temperature
+         top_p: Top-p sampling parameter
+         top_k: Top-k sampling parameter
+         repetition_penalty: Repetition penalty
+         stop_sequences: Sequences that stop generation
+     """
+
+     max_tokens: int = 256
+     temperature: float = 0.7
+     top_p: float = 0.9
+     top_k: int = 50
+     repetition_penalty: float = 1.0
+     stop_sequences: Optional[List[str]] = None
+
+
+ class InferenceResponse(BaseModel):
+     """
+     Response from model inference.
+
+     Attributes:
+         text: Generated text
+         model: Model used for generation
+         prompt: Input prompt
+         finish_reason: Reason generation stopped
+         usage: Token usage information
+         metadata: Additional metadata
+     """
+
+     text: str
+     model: str
+     prompt: str
+     finish_reason: str = "length"
+     usage: Dict[str, int] = Field(default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0})
+     metadata: Dict = Field(default_factory=dict)
+
+
+ @dataclass
+ class ControlVectorInferenceConfig:
+     """
+     Configuration for inference with control vectors.
+
+     Attributes:
+         method: Method for applying control vectors
+         scale: Scaling factor for control vectors
+         layers: Layers to apply control vectors to
+     """
+
+     method: str = "caa"  # Context-Aware Addition
+     scale: float = 1.0
+     layers: Optional[List[int]] = None
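
For completeness, a tiny illustration of the defaults defined above (editor's note, not part of the package); values in the comments simply restate the field defaults.

```python
# Echoing the configuration and response models with their defaults.
from wisent.inference.models import (
    ControlVectorInferenceConfig,
    InferenceConfig,
    InferenceResponse,
)

gen_cfg = InferenceConfig()            # max_tokens=256, temperature=0.7, top_p=0.9, top_k=50
cv_cfg = ControlVectorInferenceConfig(scale=2.0, layers=[-2, -1])

# Responses are plain pydantic models, so they can be rebuilt from API JSON.
resp = InferenceResponse(text="hello", model="example-model", prompt="hi")
print(resp.finish_reason)              # "length" by default
```
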
@@ -0,0 +1,3 @@
+ """
+ Utility functions and classes for the Wisent package.
+ """
wisent/utils/auth.py ADDED
@@ -0,0 +1,30 @@
+ """
+ Authentication utilities for the Wisent API.
+ """
+
+ from typing import Dict
+
+
+ class AuthManager:
+     """
+     Manages authentication for Wisent API requests.
+
+     Args:
+         api_key: The Wisent API key
+     """
+
+     def __init__(self, api_key: str):
+         self.api_key = api_key
+
+     def get_headers(self) -> Dict[str, str]:
+         """
+         Get the authentication headers for API requests.
+
+         Returns:
+             Dict containing the authentication headers
+         """
+         return {
+             "Authorization": f"Bearer {self.api_key}",
+             "Content-Type": "application/json",
+             "Accept": "application/json",
+         }