wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic; see the package's registry page for more details.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
@@ -1,95 +0,0 @@
1
- """
2
- Data models for model activations.
3
- """
4
-
5
- from dataclasses import dataclass, field
6
- from typing import Dict, List, Optional, Union
7
-
8
- import numpy as np
9
- import torch
10
- from pydantic import BaseModel, Field
11
-
12
-
13
class Activation(BaseModel):
    """
    A single activation captured from one layer of a model at one token.

    Attributes:
        model_name: Identifier of the model that produced the activation
        layer: Index of the layer the activation was read from
        token_index: Position of the token the activation belongs to
        values: Activation values (Python list, NumPy array, or torch tensor)
        token_str: Optional text of the token at ``token_index``
    """

    model_name: str
    layer: int
    token_index: int
    values: Union[List[float], np.ndarray, torch.Tensor]
    token_str: Optional[str] = None

    class Config:
        # Lets pydantic accept np.ndarray / torch.Tensor as field values.
        arbitrary_types_allowed = True

    def to_dict(self) -> Dict:
        """Return a plain, JSON-friendly dict suitable for API request bodies."""
        raw = self.values
        if isinstance(raw, torch.Tensor):
            # Drop autograd history and bring the data to host memory first.
            raw = raw.detach().cpu().numpy()
        plain_values = raw.tolist() if isinstance(raw, np.ndarray) else raw
        return dict(
            model_name=self.model_name,
            layer=self.layer,
            token_index=self.token_index,
            values=plain_values,
            token_str=self.token_str,
        )
49
-
50
-
51
class ActivationBatch(BaseModel):
    """
    Activations gathered from one model for a single input prompt.

    Attributes:
        model_name: Identifier of the model that produced the activations
        prompt: Input text the activations were captured for
        activations: The individual ``Activation`` records
        metadata: Optional free-form extra information (defaults to ``{}``)
    """

    model_name: str
    prompt: str
    activations: List[Activation]
    metadata: Optional[Dict] = Field(default_factory=dict)

    class Config:
        arbitrary_types_allowed = True

    def to_dict(self) -> Dict:
        """Return a plain, JSON-friendly dict suitable for API request bodies."""
        payload = {
            "model_name": self.model_name,
            "prompt": self.prompt,
        }
        payload["activations"] = [item.to_dict() for item in self.activations]
        # A None (or otherwise falsy) metadata value is sent as an empty dict.
        payload["metadata"] = self.metadata if self.metadata else {}
        return payload
78
-
79
-
80
@dataclass
class ActivationExtractorConfig:
    """
    Settings that control how activations are extracted.

    Attributes:
        layers: Layer indices to read activations from
        tokens_to_extract: Token indices to keep (negative indices count from the end)
        batch_size: Number of inputs processed per batch
        device: Torch device string used for extraction
    """

    # Factories give every instance its own fresh list. The default [-1] is
    # presumably "last layer" / "last token".
    layers: List[int] = field(default_factory=lambda: [-1])
    tokens_to_extract: List[int] = field(default_factory=lambda: [-1])
    batch_size: int = 1
    # Evaluated once, when the class body runs: CPU unless CUDA is available.
    device: str = "cpu" if not torch.cuda.is_available() else "cuda"
wisent/client.py DELETED
@@ -1,45 +0,0 @@
1
- """
2
- Main client class for interacting with the Wisent backend services.
3
- """
4
-
5
- from typing import Dict, Optional
6
-
7
- from wisent.activations import ActivationsClient
8
- from wisent.control_vector import ControlVectorClient
9
- from wisent.inference import InferenceClient
10
- from wisent.utils.auth import AuthManager
11
-
12
-
13
class WisentClient:
    """
    Single entry point for the Wisent backend services.

    Bundles the feature-specific sub-clients (activations, control vectors,
    inference). All of them share one auth manager, base URL and timeout.

    Args:
        api_key: Your Wisent API key
        base_url: The base URL for the Wisent API (default: https://api.wisent.ai)
        timeout: Request timeout in seconds (default: 60)
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.wisent.ai",
        timeout: int = 60,
    ):
        self.api_key, self.base_url, self.timeout = api_key, base_url, timeout

        # One auth manager instance is handed to every sub-client.
        self.auth = AuthManager(api_key)

        shared_args = (self.auth, base_url, timeout)
        self.activations = ActivationsClient(*shared_args)
        self.control_vector = ControlVectorClient(*shared_args)
        self.inference = InferenceClient(*shared_args)

    def __repr__(self) -> str:
        return f"WisentClient(base_url='{self.base_url}')"
@@ -1,9 +0,0 @@
1
- """
2
- Functionality for working with control vectors.
3
- """
4
-
5
- from wisent.control_vector.client import ControlVectorClient
6
- from wisent.control_vector.manager import ControlVectorManager
7
- from wisent.control_vector.models import ControlVector, ControlVectorConfig
8
-
9
__all__ = [
    "ControlVectorClient",
    "ControlVectorManager",
    "ControlVector",
    "ControlVectorConfig",
]
@@ -1,85 +0,0 @@
1
- """
2
- Client for interacting with the control vector API.
3
- """
4
-
5
- from typing import Dict, List, Optional, Union
6
-
7
- from wisent.control_vector.models import ControlVector
8
- from wisent.utils.auth import AuthManager
9
- from wisent.utils.http import HTTPClient
10
-
11
-
12
class ControlVectorClient:
    """
    Client for the control-vector endpoints of the Wisent API.

    Args:
        auth_manager: Authentication manager supplying request headers
        base_url: Base URL for the API
        timeout: Request timeout in seconds
    """

    def __init__(self, auth_manager: AuthManager, base_url: str, timeout: int = 60):
        self.auth_manager = auth_manager
        # Headers are fetched from the auth manager once, at construction time.
        self.http_client = HTTPClient(base_url, auth_manager.get_headers(), timeout)

    def get(self, name: str, model: str) -> ControlVector:
        """
        Get a control vector from the Wisent backend.

        Args:
            name: Name of the control vector. It is percent-encoded before
                being placed in the URL path, so characters such as "/", "?",
                "#" or spaces cannot redirect the request to another endpoint.
            model: Model name (passed as the ``model`` request parameter)

        Returns:
            Control vector built from the response payload
        """
        from urllib.parse import quote

        data = self.http_client.get(
            f"/control_vectors/{quote(name, safe='')}", params={"model": model}
        )
        return ControlVector(**data)

    def list(
        self,
        model: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict]:
        """
        List available control vectors from the Wisent backend.

        Args:
            model: Filter by model name (omitted from the request when falsy)
            limit: Maximum number of results
            offset: Offset for pagination

        Returns:
            List of control vector metadata, as returned by the API
        """
        params = {"limit": limit, "offset": offset}
        if model:
            params["model"] = model

        return self.http_client.get("/control_vectors", params=params)

    def combine(
        self,
        vectors: Dict[str, float],
        model: str,
    ) -> ControlVector:
        """
        Combine multiple control vectors with weights (server-side).

        Args:
            vectors: Dictionary mapping vector names to weights
            model: Model name

        Returns:
            Combined control vector
        """
        data = self.http_client.post(
            "/control_vectors/combine",
            json_data={
                "vectors": vectors,
                "model": model,
            },
        )
        return ControlVector(**data)
@@ -1,168 +0,0 @@
1
- """
2
- Manager for working with control vectors.
3
- """
4
-
5
- import logging
6
- from typing import Dict, List, Optional, Union
7
-
8
- import torch
9
-
10
- from wisent.control_vector.models import ControlVector, ControlVectorConfig
11
- from wisent.utils.auth import AuthManager
12
- from wisent.utils.http import HTTPClient
13
-
14
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
15
-
16
-
17
class ControlVectorManager:
    """
    Fetch, cache and combine control vectors from the Wisent backend.

    Vectors returned by :meth:`get` are stored in an in-memory dict keyed by
    ``"<name>:<model>"``. :meth:`combine` mixes vectors locally when every
    requested vector is already cached and otherwise delegates to the API.

    Args:
        api_key: Wisent API key
        base_url: Base URL for the API
        timeout: Request timeout in seconds
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.wisent.ai",
        timeout: int = 60,
    ):
        self.auth = AuthManager(api_key)
        self.http_client = HTTPClient(base_url, self.auth.get_headers(), timeout)
        # Simple unbounded in-memory cache: "<name>:<model>" -> ControlVector.
        self.cache = {}

    @staticmethod
    def _cache_key(name: str, model: str) -> str:
        """Build the cache key shared by get() and combine()."""
        return f"{name}:{model}"

    def get(self, name: str, model: str) -> ControlVector:
        """
        Get a control vector, serving it from the in-memory cache when possible.

        Args:
            name: Name of the control vector (percent-encoded in the URL path)
            model: Model name

        Returns:
            Control vector
        """
        from urllib.parse import quote

        cache_key = self._cache_key(name, model)
        if cache_key in self.cache:
            logger.info("Using cached control vector: %s for model %s", name, model)
            return self.cache[cache_key]

        logger.info("Fetching control vector: %s for model %s", name, model)
        # Encode the name so characters like "/" or "?" cannot change the path.
        data = self.http_client.get(
            f"/control_vectors/{quote(name, safe='')}", params={"model": model}
        )
        vector = ControlVector(**data)

        # Cache the result for later get()/combine() calls.
        self.cache[cache_key] = vector
        return vector

    def list(
        self,
        model: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict]:
        """
        List available control vectors from the Wisent backend (not cached).

        Args:
            model: Filter by model name (omitted from the request when falsy)
            limit: Maximum number of results
            offset: Offset for pagination

        Returns:
            List of control vector metadata, as returned by the API
        """
        params = {"limit": limit, "offset": offset}
        if model:
            params["model"] = model

        return self.http_client.get("/control_vectors", params=params)

    def combine(
        self,
        vectors: Dict[str, float],
        model: str,
    ) -> ControlVector:
        """
        Combine multiple control vectors with weights.

        The combination is computed locally when every named vector is already
        cached for ``model``. Otherwise the request is sent to the API.

        Args:
            vectors: Dictionary mapping vector names to weights
            model: Model name

        Returns:
            Combined control vector

        Raises:
            ValueError: If ``vectors`` is empty, or if cached vectors have
                mismatched shapes.
        """
        if not vectors:
            # Previously this surfaced as a bare StopIteration from the local path.
            raise ValueError("vectors must contain at least one control vector name")

        cache_keys = {name: self._cache_key(name, model) for name in vectors}
        if all(key in self.cache for key in cache_keys.values()):
            local_vectors = {name: self.cache[key] for name, key in cache_keys.items()}
            logger.info("Combining vectors locally for model %s", model)
            return self._combine_locally(local_vectors, vectors, model)

        # Otherwise, use the API
        logger.info("Combining vectors via API for model %s", model)
        data = self.http_client.post(
            "/control_vectors/combine",
            json_data={
                "vectors": vectors,
                "model": model,
            },
        )
        return ControlVector(**data)

    def _combine_locally(
        self,
        vectors: Dict[str, ControlVector],
        weights: Dict[str, float],
        model: str,
    ) -> ControlVector:
        """
        Compute the weighted sum of cached vectors on the CPU.

        Args:
            vectors: Dictionary mapping vector names to ControlVector objects
            weights: Dictionary mapping vector names to weights
            model: Model name

        Returns:
            Combined control vector named ``combined_<name1>_<name2>...``

        Raises:
            ValueError: If ``vectors`` is empty or the tensors differ in shape.
        """
        tensor_vectors = {name: vector.to_tensor() for name, vector in vectors.items()}
        if not tensor_vectors:
            raise ValueError("cannot combine an empty set of control vectors")

        first_vector = next(iter(tensor_vectors.values()))
        # Reject mismatched shapes explicitly instead of letting torch broadcast
        # e.g. a (1,) vector silently across a (hidden_size,) one.
        for name, tensor in tensor_vectors.items():
            if tensor.shape != first_vector.shape:
                raise ValueError(
                    f"Control vector '{name}' has shape {tuple(tensor.shape)}, "
                    f"expected {tuple(first_vector.shape)}"
                )

        combined = torch.zeros_like(first_vector)
        for name, weight in weights.items():
            if name in tensor_vectors:
                # Out-of-place add: an in-place += on an integer tensor fails when
                # the float-weighted term cannot be cast back to the integer dtype.
                combined = combined + tensor_vectors[name] * weight

        combined_name = f"combined_{'_'.join(weights)}"

        return ControlVector(
            name=combined_name,
            model_name=model,
            values=combined,
            metadata={
                "combined_from": dict(weights),
            },
        )
@@ -1,70 +0,0 @@
1
- """
2
- Data models for control vectors.
3
- """
4
-
5
- from dataclasses import dataclass
6
- from typing import Dict, List, Optional, Union
7
-
8
- import numpy as np
9
- import torch
10
- from pydantic import BaseModel, Field
11
-
12
-
13
class ControlVector(BaseModel):
    """
    Represents a control vector for steering model outputs.

    Attributes:
        name: Name of the control vector
        model_name: Name of the model the vector is for
        values: Vector values (Python list, NumPy array, or torch tensor)
        metadata: Additional metadata (defaults to an empty dict)
    """

    name: str
    model_name: str
    values: Union[List[float], np.ndarray, torch.Tensor]
    metadata: Optional[Dict] = Field(default_factory=dict)

    class Config:
        # Lets pydantic accept np.ndarray / torch.Tensor as field values.
        arbitrary_types_allowed = True

    def to_dict(self) -> Dict:
        """Convert to a JSON-friendly dictionary for API requests."""
        values = self.values
        if isinstance(values, torch.Tensor):
            # Detach from autograd and move to host memory before converting.
            values = values.detach().cpu().numpy()
        if isinstance(values, np.ndarray):
            values = values.tolist()

        return {
            "name": self.name,
            "model_name": self.model_name,
            "values": values,
            "metadata": self.metadata or {},
        }

    def to_tensor(self, device: str = "cpu") -> torch.Tensor:
        """
        Return the vector values as a tensor on ``device``.

        A stored tensor is moved with ``Tensor.to``, which returns the same
        object when it is already on ``device``. Lists and NumPy arrays are
        copied into a new tensor.
        """
        if isinstance(self.values, torch.Tensor):
            return self.values.to(device)
        # Lists and NumPy arrays share one path; the original had two identical
        # branches for them.
        return torch.tensor(self.values, device=device)
55
-
56
-
57
@dataclass
class ControlVectorConfig:
    """
    Configuration for control vector application.

    Attributes:
        scale: Scaling factor for the control vector
        method: Method for applying the control vector
        layers: Layers to apply the control vector to (None means unspecified)
    """

    scale: float = 1.0
    method: str = "caa"  # Contrastive Activation Addition
    layers: Optional[List[int]] = None
@@ -1,9 +0,0 @@
1
- """
2
- Functionality for model inference with control vectors.
3
- """
4
-
5
- from wisent.inference.client import InferenceClient
6
- from wisent.inference.inferencer import Inferencer
7
- from wisent.inference.models import InferenceConfig, InferenceResponse
8
-
9
__all__ = [
    "InferenceClient",
    "Inferencer",
    "InferenceConfig",
    "InferenceResponse",
]
@@ -1,103 +0,0 @@
1
- """
2
- Client for interacting with the inference API.
3
- """
4
-
5
- from typing import Dict, List, Optional, Union
6
-
7
- from wisent.inference.models import InferenceConfig, InferenceResponse
8
- from wisent.utils.auth import AuthManager
9
- from wisent.utils.http import HTTPClient
10
-
11
-
12
class InferenceClient:
    """
    Client for the text-generation endpoints of the Wisent API.

    Args:
        auth_manager: Authentication manager supplying request headers
        base_url: Base URL for the API
        timeout: Request timeout in seconds
    """

    def __init__(self, auth_manager: AuthManager, base_url: str, timeout: int = 60):
        self.auth_manager = auth_manager
        self.http_client = HTTPClient(base_url, auth_manager.get_headers(), timeout)

    @staticmethod
    def _sampling_params(config: InferenceConfig) -> Dict:
        """Return the sampling fields shared by both generation endpoints."""
        return {
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "top_k": config.top_k,
            "repetition_penalty": config.repetition_penalty,
            "stop_sequences": config.stop_sequences,
        }

    def generate(
        self,
        model_name: str,
        prompt: str,
        config: Optional[InferenceConfig] = None,
    ) -> InferenceResponse:
        """
        Generate text using a model.

        Args:
            model_name: Name of the model
            prompt: Input prompt
            config: Inference configuration (a default InferenceConfig if None)

        Returns:
            Inference response
        """
        config = config if config is not None else InferenceConfig()

        data = self.http_client.post(
            "/inference/generate",
            json_data={
                "model": model_name,
                "prompt": prompt,
                **self._sampling_params(config),
            },
        )

        return InferenceResponse(**data)

    def generate_with_control(
        self,
        model_name: str,
        prompt: str,
        control_vectors: Dict[str, float],
        method: str = "caa",
        scale: float = 1.0,
        config: Optional[InferenceConfig] = None,
    ) -> InferenceResponse:
        """
        Generate text using a model with control vectors.

        Args:
            model_name: Name of the model
            prompt: Input prompt
            control_vectors: Dictionary mapping vector names to weights
            method: Method for applying control vectors
            scale: Scaling factor for control vectors
            config: Inference configuration (a default InferenceConfig if None)

        Returns:
            Inference response
        """
        config = config if config is not None else InferenceConfig()

        data = self.http_client.post(
            "/inference/generate_with_control",
            json_data={
                "model": model_name,
                "prompt": prompt,
                "control_vectors": control_vectors,
                "method": method,
                "scale": scale,
                **self._sampling_params(config),
            },
        )

        return InferenceResponse(**data)