wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.2.dist-info/METADATA +67 -0
  215. wisent-0.5.2.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
wisent/inference/inferencer.py DELETED
@@ -1,250 +0,0 @@
1
- """
2
- Functionality for local inference with control vectors.
3
- """
4
-
5
- import logging
6
- from typing import Dict, List, Optional, Union
7
-
8
- import torch
9
- from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
10
-
11
- from wisent.control_vector.models import ControlVector
12
- from wisent.inference.models import ControlVectorInferenceConfig, InferenceConfig, InferenceResponse
13
-
14
- logger = logging.getLogger(__name__)
15
-
16
-
17
- class ControlVectorHook:
18
- """
19
- Hook for applying control vectors during inference.
20
-
21
- Args:
22
- control_vector: Control vector to apply
23
- config: Configuration for applying the control vector
24
- """
25
-
26
- def __init__(
27
- self,
28
- control_vector: ControlVector,
29
- config: ControlVectorInferenceConfig,
30
- ):
31
- self.control_vector = control_vector
32
- self.config = config
33
- self.device = None
34
- self.vector_tensor = None
35
- self.hooks = []
36
-
37
- def register(self, model):
38
- """
39
- Register hooks on the model.
40
-
41
- Args:
42
- model: The model to register hooks on
43
- """
44
- self.device = next(model.parameters()).device
45
- self.vector_tensor = self.control_vector.to_tensor(self.device)
46
-
47
- # Get transformer layers
48
- if hasattr(model, "transformer"):
49
- transformer_layers = model.transformer.h
50
- elif hasattr(model, "model") and hasattr(model.model, "layers"):
51
- transformer_layers = model.model.layers
52
- else:
53
- raise ValueError(f"Unsupported model architecture: {model.__class__.__name__}")
54
-
55
- # Determine which layers to apply the control vector to
56
- num_layers = len(transformer_layers)
57
- layers = self.config.layers or [num_layers - 1] # Default to last layer
58
-
59
- # Resolve negative indices
60
- resolved_layers = []
61
- for layer in layers:
62
- if layer < 0:
63
- resolved_layer = num_layers + layer
64
- else:
65
- resolved_layer = layer
66
-
67
- if 0 <= resolved_layer < num_layers:
68
- resolved_layers.append(resolved_layer)
69
-
70
- # Register hooks
71
- for layer_idx in resolved_layers:
72
- layer = transformer_layers[layer_idx]
73
-
74
- # Define hook function
75
- def hook_fn(module, input, output, layer_idx=layer_idx):
76
- if isinstance(output, tuple):
77
- hidden_states = output[0]
78
- else:
79
- hidden_states = output
80
-
81
- # Apply the control vector
82
- if self.config.method == "caa": # Context-Aware Addition
83
- # Add the control vector to the hidden states
84
- modified = hidden_states + self.vector_tensor * self.config.scale
85
-
86
- if isinstance(output, tuple):
87
- return (modified,) + output[1:]
88
- else:
89
- return modified
90
- else:
91
- logger.warning(f"Unsupported method: {self.config.method}, using original output")
92
- return output
93
-
94
- # Register hook
95
- if hasattr(layer, "output"):
96
- handle = layer.output.register_forward_hook(
97
- lambda module, input, output, layer_idx=layer_idx: hook_fn(module, input, output, layer_idx)
98
- )
99
- else:
100
- handle = layer.register_forward_hook(
101
- lambda module, input, output, layer_idx=layer_idx: hook_fn(module, input, output, layer_idx)
102
- )
103
-
104
- self.hooks.append(handle)
105
-
106
- def remove(self):
107
- """Remove all registered hooks."""
108
- for hook in self.hooks:
109
- hook.remove()
110
- self.hooks = []
111
-
112
-
113
- class Inferencer:
114
- """
115
- Performs local inference with control vectors.
116
-
117
- Args:
118
- model_name: Name of the model
119
- device: Device to use for inference
120
- """
121
-
122
- def __init__(
123
- self,
124
- model_name: str,
125
- device: Optional[str] = None,
126
- ):
127
- self.model_name = model_name
128
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
129
- self.model = None
130
- self.tokenizer = None
131
-
132
- logger.info(f"Initializing Inferencer for model {model_name} on {self.device}")
133
-
134
- def _load_model(self):
135
- """Load the model and tokenizer."""
136
- if self.model is None:
137
- logger.info(f"Loading model {self.model_name}")
138
- self.model = AutoModelForCausalLM.from_pretrained(
139
- self.model_name,
140
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
141
- device_map=self.device
142
- )
143
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
144
- logger.info(f"Model loaded successfully")
145
-
146
- def generate(
147
- self,
148
- prompt: str,
149
- control_vector: Optional[ControlVector] = None,
150
- method: str = "caa",
151
- scale: float = 1.0,
152
- layers: Optional[List[int]] = None,
153
- config: Optional[InferenceConfig] = None,
154
- ) -> InferenceResponse:
155
- """
156
- Generate text using the model, optionally with a control vector.
157
-
158
- Args:
159
- prompt: Input prompt
160
- control_vector: Control vector to apply (optional)
161
- method: Method for applying the control vector
162
- scale: Scaling factor for the control vector
163
- layers: Layers to apply the control vector to
164
- config: Inference configuration
165
-
166
- Returns:
167
- Inference response
168
- """
169
- try:
170
- self._load_model()
171
-
172
- config = config or InferenceConfig()
173
- hook = None
174
-
175
- # Register control vector hook if provided
176
- if control_vector is not None:
177
- cv_config = ControlVectorInferenceConfig(
178
- method=method,
179
- scale=scale,
180
- layers=layers,
181
- )
182
- hook = ControlVectorHook(control_vector, cv_config)
183
- hook.register(self.model)
184
-
185
- # Tokenize input
186
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
187
- prompt_length = inputs.input_ids.shape[1]
188
-
189
- # Configure generation
190
- generation_config = GenerationConfig(
191
- max_new_tokens=config.max_tokens,
192
- temperature=config.temperature,
193
- top_p=config.top_p,
194
- top_k=config.top_k,
195
- repetition_penalty=config.repetition_penalty,
196
- do_sample=config.temperature > 0,
197
- pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
198
- )
199
-
200
- # Generate
201
- with torch.no_grad():
202
- output_ids = self.model.generate(
203
- inputs.input_ids,
204
- attention_mask=inputs.attention_mask,
205
- generation_config=generation_config,
206
- )
207
-
208
- # Remove control vector hook if registered
209
- if hook is not None:
210
- hook.remove()
211
-
212
- # Decode output
213
- generated_text = self.tokenizer.decode(
214
- output_ids[0][prompt_length:],
215
- skip_special_tokens=True
216
- )
217
-
218
- # Create response
219
- return InferenceResponse(
220
- text=generated_text,
221
- model=self.model_name,
222
- prompt=prompt,
223
- finish_reason="length", # Simplified
224
- usage={
225
- "prompt_tokens": prompt_length,
226
- "completion_tokens": output_ids.shape[1] - prompt_length,
227
- "total_tokens": output_ids.shape[1],
228
- },
229
- metadata={
230
- "control_vector": control_vector.name if control_vector else None,
231
- "method": method if control_vector else None,
232
- "scale": scale if control_vector else None,
233
- }
234
- )
235
-
236
- except Exception as e:
237
- logger.error(f"Error during inference: {str(e)}")
238
- if hook is not None:
239
- hook.remove()
240
- raise
241
-
242
- def __del__(self):
243
- """Clean up resources."""
244
- # Free GPU memory
245
- if self.model is not None and hasattr(self.model, "to"):
246
- self.model = self.model.to("cpu")
247
-
248
- # Clear CUDA cache
249
- if torch.cuda.is_available():
250
- torch.cuda.empty_cache()
wisent/inference/models.py DELETED
@@ -1,66 +0,0 @@
1
- """
2
- Data models for inference.
3
- """
4
-
5
- from dataclasses import dataclass, field
6
- from typing import Dict, List, Optional, Union
7
-
8
- from pydantic import BaseModel, Field
9
-
10
-
11
- class InferenceConfig(BaseModel):
12
- """
13
- Configuration for model inference.
14
-
15
- Attributes:
16
- max_tokens: Maximum number of tokens to generate
17
- temperature: Sampling temperature
18
- top_p: Top-p sampling parameter
19
- top_k: Top-k sampling parameter
20
- repetition_penalty: Repetition penalty
21
- stop_sequences: Sequences that stop generation
22
- """
23
-
24
- max_tokens: int = 256
25
- temperature: float = 0.7
26
- top_p: float = 0.9
27
- top_k: int = 50
28
- repetition_penalty: float = 1.0
29
- stop_sequences: Optional[List[str]] = None
30
-
31
-
32
- class InferenceResponse(BaseModel):
33
- """
34
- Response from model inference.
35
-
36
- Attributes:
37
- text: Generated text
38
- model: Model used for generation
39
- prompt: Input prompt
40
- finish_reason: Reason generation stopped
41
- usage: Token usage information
42
- metadata: Additional metadata
43
- """
44
-
45
- text: str
46
- model: str
47
- prompt: str
48
- finish_reason: str = "length"
49
- usage: Dict[str, int] = Field(default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0})
50
- metadata: Dict = Field(default_factory=dict)
51
-
52
-
53
- @dataclass
54
- class ControlVectorInferenceConfig:
55
- """
56
- Configuration for inference with control vectors.
57
-
58
- Attributes:
59
- method: Method for applying control vectors
60
- scale: Scaling factor for control vectors
61
- layers: Layers to apply control vectors to
62
- """
63
-
64
- method: str = "caa" # Context-Aware Addition
65
- scale: float = 1.0
66
- layers: Optional[List[int]] = None
wisent/utils/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- """
2
- Utility functions and classes for the Wisent package.
3
- """
wisent/utils/auth.py DELETED
@@ -1,30 +0,0 @@
1
- """
2
- Authentication utilities for the Wisent API.
3
- """
4
-
5
- from typing import Dict
6
-
7
-
8
- class AuthManager:
9
- """
10
- Manages authentication for Wisent API requests.
11
-
12
- Args:
13
- api_key: The Wisent API key
14
- """
15
-
16
- def __init__(self, api_key: str):
17
- self.api_key = api_key
18
-
19
- def get_headers(self) -> Dict[str, str]:
20
- """
21
- Get the authentication headers for API requests.
22
-
23
- Returns:
24
- Dict containing the authentication headers
25
- """
26
- return {
27
- "Authorization": f"Bearer {self.api_key}",
28
- "Content-Type": "application/json",
29
- "Accept": "application/json",
30
- }
wisent/utils/http.py DELETED
@@ -1,228 +0,0 @@
1
- """
2
- HTTP request utilities for the Wisent API.
3
- """
4
-
5
- import json
6
- from typing import Any, Dict, Optional, Union
7
-
8
- import aiohttp
9
- import requests
10
- from requests.exceptions import RequestException
11
-
12
-
13
- class APIError(Exception):
14
- """Exception raised for API errors."""
15
-
16
- def __init__(self, message: str, status_code: Optional[int] = None, response: Optional[Dict[str, Any]] = None):
17
- self.message = message
18
- self.status_code = status_code
19
- self.response = response
20
- super().__init__(self.message)
21
-
22
-
23
- class HTTPClient:
24
- """
25
- HTTP client for making requests to the Wisent API.
26
-
27
- Args:
28
- base_url: The base URL for the API
29
- headers: Headers to include in all requests
30
- timeout: Request timeout in seconds
31
- """
32
-
33
- def __init__(self, base_url: str, headers: Dict[str, str], timeout: int = 60):
34
- self.base_url = base_url.rstrip("/")
35
- self.headers = headers
36
- self.timeout = timeout
37
-
38
- def _build_url(self, endpoint: str) -> str:
39
- """Build the full URL for an API endpoint."""
40
- endpoint = endpoint.lstrip("/")
41
- return f"{self.base_url}/{endpoint}"
42
-
43
- def get(self, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
44
- """
45
- Make a GET request to the API.
46
-
47
- Args:
48
- endpoint: API endpoint
49
- params: Query parameters
50
-
51
- Returns:
52
- Response data as a dictionary
53
-
54
- Raises:
55
- APIError: If the request fails
56
- """
57
- url = self._build_url(endpoint)
58
- try:
59
- response = requests.get(
60
- url,
61
- headers=self.headers,
62
- params=params,
63
- timeout=self.timeout
64
- )
65
- response.raise_for_status()
66
- return response.json()
67
- except RequestException as e:
68
- status_code = getattr(e.response, "status_code", None) if hasattr(e, "response") else None
69
- response_data = None
70
-
71
- if hasattr(e, "response") and e.response is not None:
72
- try:
73
- response_data = e.response.json()
74
- except (ValueError, AttributeError):
75
- response_data = {"error": str(e)}
76
-
77
- raise APIError(
78
- f"GET request to {url} failed: {str(e)}",
79
- status_code=status_code,
80
- response=response_data
81
- ) from e
82
-
83
- def post(self, endpoint: str, data: Optional[Dict[str, Any]] = None, json_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
84
- """
85
- Make a POST request to the API.
86
-
87
- Args:
88
- endpoint: API endpoint
89
- data: Form data
90
- json_data: JSON data
91
-
92
- Returns:
93
- Response data as a dictionary
94
-
95
- Raises:
96
- APIError: If the request fails
97
- """
98
- url = self._build_url(endpoint)
99
- try:
100
- response = requests.post(
101
- url,
102
- headers=self.headers,
103
- data=data,
104
- json=json_data,
105
- timeout=self.timeout
106
- )
107
- response.raise_for_status()
108
- return response.json()
109
- except RequestException as e:
110
- status_code = getattr(e.response, "status_code", None) if hasattr(e, "response") else None
111
- response_data = None
112
-
113
- if hasattr(e, "response") and e.response is not None:
114
- try:
115
- response_data = e.response.json()
116
- except (ValueError, AttributeError):
117
- response_data = {"error": str(e)}
118
-
119
- raise APIError(
120
- f"POST request to {url} failed: {str(e)}",
121
- status_code=status_code,
122
- response=response_data
123
- ) from e
124
-
125
-
126
- class AsyncHTTPClient:
127
- """
128
- Asynchronous HTTP client for making requests to the Wisent API.
129
-
130
- Args:
131
- base_url: The base URL for the API
132
- headers: Headers to include in all requests
133
- timeout: Request timeout in seconds
134
- """
135
-
136
- def __init__(self, base_url: str, headers: Dict[str, str], timeout: int = 60):
137
- self.base_url = base_url.rstrip("/")
138
- self.headers = headers
139
- self.timeout = timeout
140
-
141
- def _build_url(self, endpoint: str) -> str:
142
- """Build the full URL for an API endpoint."""
143
- endpoint = endpoint.lstrip("/")
144
- return f"{self.base_url}/{endpoint}"
145
-
146
- async def get(self, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
147
- """
148
- Make an asynchronous GET request to the API.
149
-
150
- Args:
151
- endpoint: API endpoint
152
- params: Query parameters
153
-
154
- Returns:
155
- Response data as a dictionary
156
-
157
- Raises:
158
- APIError: If the request fails
159
- """
160
- url = self._build_url(endpoint)
161
- try:
162
- async with aiohttp.ClientSession() as session:
163
- async with session.get(
164
- url,
165
- headers=self.headers,
166
- params=params,
167
- timeout=self.timeout
168
- ) as response:
169
- response.raise_for_status()
170
- return await response.json()
171
- except aiohttp.ClientError as e:
172
- status_code = getattr(response, "status", None) if 'response' in locals() else None
173
- response_data = None
174
-
175
- if 'response' in locals():
176
- try:
177
- response_data = await response.json()
178
- except (ValueError, AttributeError):
179
- response_data = {"error": str(e)}
180
-
181
- raise APIError(
182
- f"Async GET request to {url} failed: {str(e)}",
183
- status_code=status_code,
184
- response=response_data
185
- ) from e
186
-
187
- async def post(self, endpoint: str, data: Optional[Dict[str, Any]] = None, json_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
188
- """
189
- Make an asynchronous POST request to the API.
190
-
191
- Args:
192
- endpoint: API endpoint
193
- data: Form data
194
- json_data: JSON data
195
-
196
- Returns:
197
- Response data as a dictionary
198
-
199
- Raises:
200
- APIError: If the request fails
201
- """
202
- url = self._build_url(endpoint)
203
- try:
204
- async with aiohttp.ClientSession() as session:
205
- async with session.post(
206
- url,
207
- headers=self.headers,
208
- data=data,
209
- json=json_data,
210
- timeout=self.timeout
211
- ) as response:
212
- response.raise_for_status()
213
- return await response.json()
214
- except aiohttp.ClientError as e:
215
- status_code = getattr(response, "status", None) if 'response' in locals() else None
216
- response_data = None
217
-
218
- if 'response' in locals():
219
- try:
220
- response_data = await response.json()
221
- except (ValueError, AttributeError):
222
- response_data = {"error": str(e)}
223
-
224
- raise APIError(
225
- f"Async POST request to {url} failed: {str(e)}",
226
- status_code=status_code,
227
- response=response_data
228
- ) from e
wisent/version.py DELETED
@@ -1,3 +0,0 @@
1
- """Version information."""
2
-
3
- __version__ = "0.1.0"