wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.2.dist-info/METADATA +67 -0
- wisent-0.5.2.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
wisent/activations/models.py
DELETED
|
@@ -1,95 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Data models for model activations.
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from dataclasses import dataclass, field
|
|
6
|
-
from typing import Dict, List, Optional, Union
|
|
7
|
-
|
|
8
|
-
import numpy as np
|
|
9
|
-
import torch
|
|
10
|
-
from pydantic import BaseModel, Field
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class Activation(BaseModel):
    """
    A single activation captured from one layer of a model at one token.

    Attributes:
        model_name: Name of the model that produced the activation
        layer: Index of the layer the activation was read from
        token_index: Index of the token the activation belongs to
        values: Activation values (Python list, NumPy array or PyTorch tensor)
        token_str: Optional string form of the token
    """

    model_name: str
    layer: int
    token_index: int
    values: Union[List[float], np.ndarray, torch.Tensor]
    token_str: Optional[str] = None

    class Config:
        # Lets pydantic accept np.ndarray / torch.Tensor as field values.
        arbitrary_types_allowed = True

    def to_dict(self) -> Dict:
        """
        Build a plain dict suitable for an API request body.

        Tensors are detached and moved to CPU, then tensors and NumPy arrays
        are converted to Python lists; list values are passed through as-is.
        """
        raw = self.values
        if isinstance(raw, torch.Tensor):
            raw = raw.detach().cpu().numpy()
        serializable = raw.tolist() if isinstance(raw, np.ndarray) else raw

        return dict(
            model_name=self.model_name,
            layer=self.layer,
            token_index=self.token_index,
            values=serializable,
            token_str=self.token_str,
        )
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
class ActivationBatch(BaseModel):
    """
    A group of activations produced by a model for a single prompt.

    Attributes:
        model_name: Name of the model
        prompt: Input prompt that generated the activations
        activations: The individual activations
        metadata: Optional extra information (defaults to an empty dict)
    """

    model_name: str
    prompt: str
    activations: List[Activation]
    metadata: Optional[Dict] = Field(default_factory=dict)

    class Config:
        # Allows non-pydantic types to appear in fields.
        arbitrary_types_allowed = True

    def to_dict(self) -> Dict:
        """
        Build a plain dict suitable for an API request body.

        Each activation is serialized through its own ``to_dict``; a missing
        or empty ``metadata`` is emitted as ``{}``.
        """
        serialized_activations = [item.to_dict() for item in self.activations]
        return dict(
            model_name=self.model_name,
            prompt=self.prompt,
            activations=serialized_activations,
            metadata=self.metadata if self.metadata else {},
        )
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
@dataclass
class ActivationExtractorConfig:
    """
    Options controlling which activations are extracted and how.

    Attributes:
        layers: Layer indices to extract activations from (default ``[-1]``)
        tokens_to_extract: Token indices to extract; negative indices count
            from the end (default ``[-1]``)
        batch_size: Batch size for processing
        device: Device to use for extraction
    """

    # ``[-1].copy`` gives every instance its own fresh ``[-1]`` list, so
    # mutating one config's list never affects another config.
    layers: List[int] = field(default_factory=[-1].copy)
    tokens_to_extract: List[int] = field(default_factory=[-1].copy)
    batch_size: int = 1
    # NOTE: this default is computed once, when the class body executes,
    # not each time an instance is created.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
wisent/client.py
DELETED
|
@@ -1,45 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Main client class for interacting with the Wisent backend services.
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import Dict, Optional
|
|
6
|
-
|
|
7
|
-
from wisent.activations import ActivationsClient
|
|
8
|
-
from wisent.control_vector import ControlVectorClient
|
|
9
|
-
from wisent.inference import InferenceClient
|
|
10
|
-
from wisent.utils.auth import AuthManager
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class WisentClient:
    """
    Entry point for the Wisent backend services.

    Builds one shared ``AuthManager`` and hands it to the feature-specific
    sub-clients exposed as ``activations``, ``control_vector`` and
    ``inference``.

    Args:
        api_key: Your Wisent API key
        base_url: The base URL for the Wisent API (default: https://api.wisent.ai)
        timeout: Request timeout in seconds (default: 60)
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.wisent.ai",
        timeout: int = 60,
    ):
        self.api_key, self.base_url, self.timeout = api_key, base_url, timeout

        # A single auth manager is shared by every sub-client.
        auth = AuthManager(api_key)
        self.auth = auth

        self.activations = ActivationsClient(auth, base_url, timeout)
        self.control_vector = ControlVectorClient(auth, base_url, timeout)
        self.inference = InferenceClient(auth, base_url, timeout)

    def __repr__(self) -> str:
        return f"WisentClient(base_url='{self.base_url}')"
|
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Functionality for working with control vectors.
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from wisent.control_vector.client import ControlVectorClient
|
|
6
|
-
from wisent.control_vector.manager import ControlVectorManager
|
|
7
|
-
from wisent.control_vector.models import ControlVector, ControlVectorConfig
|
|
8
|
-
|
|
9
|
-
# Names exported by ``from wisent.control_vector import *``.
__all__ = ["ControlVectorClient", "ControlVectorManager", "ControlVector", "ControlVectorConfig"]
|
wisent/control_vector/client.py
DELETED
|
@@ -1,85 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Client for interacting with the control vector API.
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import Dict, List, Optional, Union
|
|
6
|
-
|
|
7
|
-
from wisent.control_vector.models import ControlVector
|
|
8
|
-
from wisent.utils.auth import AuthManager
|
|
9
|
-
from wisent.utils.http import HTTPClient
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class ControlVectorClient:
    """
    Client for the control vector endpoints of the Wisent API.

    Args:
        auth_manager: Authentication manager; its headers are passed to the
            underlying HTTP client when this client is constructed
        base_url: Base URL for the API
        timeout: Request timeout in seconds
    """

    def __init__(self, auth_manager: AuthManager, base_url: str, timeout: int = 60):
        self.auth_manager = auth_manager
        self.http_client = HTTPClient(base_url, auth_manager.get_headers(), timeout)

    def get(self, name: str, model: str) -> ControlVector:
        """
        Fetch a single control vector by name.

        Args:
            name: Name of the control vector
            model: Model name, sent as the ``model`` query parameter

        Returns:
            The control vector built from the response payload
        """
        payload = self.http_client.get(f"/control_vectors/{name}", params={"model": model})
        return ControlVector(**payload)

    def list(
        self,
        model: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict]:
        """
        List available control vectors, one page at a time.

        Args:
            model: Only return vectors for this model (omitted when falsy)
            limit: Maximum number of results
            offset: Offset for pagination

        Returns:
            The response payload: a list of control vector metadata
        """
        query = dict(limit=limit, offset=offset)
        if model:
            query.update(model=model)
        return self.http_client.get("/control_vectors", params=query)

    def combine(
        self,
        vectors: Dict[str, float],
        model: str,
    ) -> ControlVector:
        """
        Ask the backend to combine several control vectors with weights.

        Args:
            vectors: Mapping of vector name to weight
            model: Model name

        Returns:
            The combined control vector returned by the backend
        """
        request_body = {"vectors": vectors, "model": model}
        response = self.http_client.post("/control_vectors/combine", json_data=request_body)
        return ControlVector(**response)
|
wisent/control_vector/manager.py
DELETED
|
@@ -1,168 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Manager for working with control vectors.
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
import logging
|
|
6
|
-
from typing import Dict, List, Optional, Union
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
|
|
10
|
-
from wisent.control_vector.models import ControlVector, ControlVectorConfig
|
|
11
|
-
from wisent.utils.auth import AuthManager
|
|
12
|
-
from wisent.utils.http import HTTPClient
|
|
13
|
-
|
|
14
|
-
# Module-level logger, named after this module; ControlVectorManager logs
# cache hits, fetches and combine decisions through it.
logger = logging.getLogger(__name__)
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class ControlVectorManager:
    """
    Manager for working with control vectors.

    Wraps the ``/control_vectors`` endpoints and keeps an in-memory cache of
    vectors fetched through :meth:`get`. :meth:`combine` merges vectors
    locally when every requested vector is already cached for the model, and
    otherwise falls back to the ``/control_vectors/combine`` endpoint.

    Args:
        api_key: Wisent API key
        base_url: Base URL for the API
        timeout: Request timeout in seconds
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.wisent.ai",
        timeout: int = 60,
    ):
        self.auth = AuthManager(api_key)
        self.http_client = HTTPClient(base_url, self.auth.get_headers(), timeout)
        # Simple in-memory cache mapping "<name>:<model>" -> ControlVector.
        # Entries are never evicted or invalidated.
        self.cache = {}

    @staticmethod
    def _cache_key(name: str, model: str) -> str:
        """Return the cache key for control vector *name* on *model*."""
        return f"{name}:{model}"

    def get(self, name: str, model: str) -> ControlVector:
        """
        Get a control vector, serving it from the cache when possible.

        Args:
            name: Name of the control vector
            model: Model name

        Returns:
            Control vector (the same cached instance on repeated calls)
        """
        cache_key = self._cache_key(name, model)
        if cache_key in self.cache:
            logger.info("Using cached control vector: %s for model %s", name, model)
            return self.cache[cache_key]

        logger.info("Fetching control vector: %s for model %s", name, model)
        data = self.http_client.get(f"/control_vectors/{name}", params={"model": model})
        vector = ControlVector(**data)

        # Cache the result so later get()/combine() calls can reuse it.
        self.cache[cache_key] = vector
        return vector

    def list(
        self,
        model: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict]:
        """
        List available control vectors from the Wisent backend.

        Args:
            model: Filter by model name (omitted when falsy)
            limit: Maximum number of results
            offset: Offset for pagination

        Returns:
            List of control vector metadata
        """
        params = {"limit": limit, "offset": offset}
        if model:
            params["model"] = model

        return self.http_client.get("/control_vectors", params=params)

    def combine(
        self,
        vectors: Dict[str, float],
        model: str,
    ) -> ControlVector:
        """
        Combine multiple control vectors with weights.

        The combination is done locally when every named vector is already
        cached for *model*; otherwise the backend performs it.

        Args:
            vectors: Dictionary mapping vector names to weights
            model: Model name

        Returns:
            Combined control vector

        Raises:
            ValueError: If *vectors* is empty.
        """
        if not vectors:
            # Fail early with a clear error; the local path would otherwise
            # hit next() on an empty iterator and raise a bare StopIteration.
            raise ValueError("At least one control vector is required to combine")

        cache_keys = {name: self._cache_key(name, model) for name in vectors}
        if all(key in self.cache for key in cache_keys.values()):
            logger.info("Combining vectors locally for model %s", model)
            local_vectors = {name: self.cache[key] for name, key in cache_keys.items()}
            return self._combine_locally(local_vectors, vectors, model)

        # At least one vector is not cached: let the API do the work.
        logger.info("Combining vectors via API for model %s", model)
        data = self.http_client.post(
            "/control_vectors/combine",
            json_data={
                "vectors": vectors,
                "model": model,
            },
        )
        return ControlVector(**data)

    def _combine_locally(
        self,
        vectors: Dict[str, ControlVector],
        weights: Dict[str, float],
        model: str,
    ) -> ControlVector:
        """
        Combine vectors locally as a weighted sum.

        Args:
            vectors: Dictionary mapping vector names to ControlVector objects
            weights: Dictionary mapping vector names to weights; names absent
                from *vectors* are skipped
            model: Model name

        Returns:
            Combined control vector

        Raises:
            ValueError: If *vectors* is empty.
        """
        if not vectors:
            raise ValueError("At least one control vector is required to combine")

        tensor_vectors = {name: vector.to_tensor() for name, vector in vectors.items()}

        # The first vector is the shape/device template for the accumulator.
        template = next(iter(tensor_vectors.values()))
        # Integer/bool templates would make the in-place add of float-weighted
        # terms fail with a dtype-cast error, so accumulate in a float dtype.
        if template.is_floating_point() or template.is_complex():
            accum_dtype = template.dtype
        else:
            accum_dtype = torch.get_default_dtype()
        combined = torch.zeros_like(template, dtype=accum_dtype)

        for name, weight in weights.items():
            if name in tensor_vectors:
                combined += tensor_vectors[name] * weight

        combined_name = f"combined_{'_'.join(weights)}"

        return ControlVector(
            name=combined_name,
            model_name=model,
            values=combined,
            metadata={
                "combined_from": dict(weights),
            },
        )
|
wisent/control_vector/models.py
DELETED
|
@@ -1,70 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Data models for control vectors.
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from dataclasses import dataclass
|
|
6
|
-
from typing import Dict, List, Optional, Union
|
|
7
|
-
|
|
8
|
-
import numpy as np
|
|
9
|
-
import torch
|
|
10
|
-
from pydantic import BaseModel, Field
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class ControlVector(BaseModel):
    """
    A control vector used to steer a model's outputs.

    Attributes:
        name: Name of the control vector
        model_name: Name of the model the vector is for
        values: Vector values (Python list, NumPy array or PyTorch tensor)
        metadata: Additional metadata (defaults to an empty dict)
    """

    name: str
    model_name: str
    values: Union[List[float], np.ndarray, torch.Tensor]
    metadata: Optional[Dict] = Field(default_factory=dict)

    class Config:
        # Lets pydantic accept np.ndarray / torch.Tensor as field values.
        arbitrary_types_allowed = True

    def to_dict(self) -> Dict:
        """
        Build a plain dict suitable for an API request body.

        Tensors are detached and moved to CPU, then tensors and NumPy arrays
        are converted to Python lists; a missing/empty ``metadata`` becomes ``{}``.
        """
        raw = self.values
        if isinstance(raw, torch.Tensor):
            raw = raw.detach().cpu().numpy()
        serializable = raw.tolist() if isinstance(raw, np.ndarray) else raw

        return dict(
            name=self.name,
            model_name=self.model_name,
            values=serializable,
            metadata=self.metadata if self.metadata else {},
        )

    def to_tensor(self, device: str = "cpu") -> torch.Tensor:
        """
        Return the values as a PyTorch tensor on *device*.

        Existing tensors are moved with ``.to(device)``; lists and NumPy arrays
        are copied into a new tensor via ``torch.tensor``.
        """
        raw = self.values
        if not isinstance(raw, torch.Tensor):
            return torch.tensor(raw, device=device)
        return raw.to(device)
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
@dataclass
class ControlVectorConfig:
    """
    Configuration for control vector application.

    Attributes:
        scale: Scaling factor for the control vector (default ``1.0``)
        method: Identifier of the method used to apply the control vector
            (default ``"caa"``)
        layers: Layers to apply the control vector to (default ``None``)
    """

    scale: float = 1.0
    # NOTE(review): this was previously commented as "Context-Aware Addition".
    # In the activation-steering literature CAA stands for Contrastive
    # Activation Addition; confirm against the steering implementation.
    method: str = "caa"
    layers: Optional[List[int]] = None
|
wisent/inference/__init__.py
DELETED
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Functionality for model inference with control vectors.
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from wisent.inference.client import InferenceClient
|
|
6
|
-
from wisent.inference.inferencer import Inferencer
|
|
7
|
-
from wisent.inference.models import InferenceConfig, InferenceResponse
|
|
8
|
-
|
|
9
|
-
# Names exported by ``from wisent.inference import *``.
__all__ = ["InferenceClient", "Inferencer", "InferenceConfig", "InferenceResponse"]
|
wisent/inference/client.py
DELETED
|
@@ -1,103 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Client for interacting with the inference API.
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import Dict, List, Optional, Union
|
|
6
|
-
|
|
7
|
-
from wisent.inference.models import InferenceConfig, InferenceResponse
|
|
8
|
-
from wisent.utils.auth import AuthManager
|
|
9
|
-
from wisent.utils.http import HTTPClient
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class InferenceClient:
    """
    Client for the text-generation endpoints of the Wisent API.

    Args:
        auth_manager: Authentication manager; its headers are passed to the
            underlying HTTP client when this client is constructed
        base_url: Base URL for the API
        timeout: Request timeout in seconds
    """

    def __init__(self, auth_manager: AuthManager, base_url: str, timeout: int = 60):
        self.auth_manager = auth_manager
        self.http_client = HTTPClient(base_url, auth_manager.get_headers(), timeout)

    @staticmethod
    def _sampling_fields(config: InferenceConfig) -> Dict:
        """Return the sampling options of *config* as request-body fields."""
        return {
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "top_k": config.top_k,
            "repetition_penalty": config.repetition_penalty,
            "stop_sequences": config.stop_sequences,
        }

    def generate(
        self,
        model_name: str,
        prompt: str,
        config: Optional[InferenceConfig] = None,
    ) -> InferenceResponse:
        """
        Generate text with a model.

        Args:
            model_name: Name of the model
            prompt: Input prompt
            config: Inference configuration; a default ``InferenceConfig()``
                is used when falsy

        Returns:
            Inference response
        """
        if not config:
            config = InferenceConfig()

        body = {"model": model_name, "prompt": prompt}
        body.update(self._sampling_fields(config))

        response = self.http_client.post("/inference/generate", json_data=body)
        return InferenceResponse(**response)

    def generate_with_control(
        self,
        model_name: str,
        prompt: str,
        control_vectors: Dict[str, float],
        method: str = "caa",
        scale: float = 1.0,
        config: Optional[InferenceConfig] = None,
    ) -> InferenceResponse:
        """
        Generate text with a model while applying weighted control vectors.

        Args:
            model_name: Name of the model
            prompt: Input prompt
            control_vectors: Mapping of vector name to weight
            method: Method for applying control vectors
            scale: Scaling factor for control vectors
            config: Inference configuration; a default ``InferenceConfig()``
                is used when falsy

        Returns:
            Inference response
        """
        if not config:
            config = InferenceConfig()

        body = {
            "model": model_name,
            "prompt": prompt,
            "control_vectors": control_vectors,
            "method": method,
            "scale": scale,
        }
        body.update(self._sampling_fields(config))

        response = self.http_client.post("/inference/generate_with_control", json_data=body)
        return InferenceResponse(**response)
|