wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import importlib.util
|
|
5
|
+
import inspect
|
|
6
|
+
import logging
|
|
7
|
+
import pkgutil
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Type
|
|
10
|
+
|
|
11
|
+
from wisent_guard.core.steering_methods.core.atoms import BaseSteeringError, BaseSteeringMethod
|
|
12
|
+
from wisent_guard.core.contrastive_pairs.core.set import ContrastivePairSet
|
|
13
|
+
|
|
14
|
+
from wisent_guard.core.activations.core.atoms import LayerActivations
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
"SteeringMethodRotator",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
class SteeringMethodRotator:
    """Discover/select a steering method and train it on a ContrastivePairSet.

    The rotator holds exactly one active ``BaseSteeringMethod`` instance
    (``self._method``). Methods are looked up through the registry exposed by
    ``BaseSteeringMethod.list_registered()`` / ``BaseSteeringMethod.get()``;
    ``discover_methods`` imports the modules that are expected to populate it.
    """

    def __init__(
        self,
        method: str | BaseSteeringMethod | Type[BaseSteeringMethod] | None = None,
        methods_location: str | Path = "wisent_guard.core.steering_methods.methods",
        autoload: bool = True,
        **default_method_kwargs: Any,
    ) -> None:
        """Optionally discover methods, then resolve the active one.

        Args:
            method: Registered name, instance, subclass, or ``None`` to pick
                the alphabetically first registered method.
            methods_location: Directory of ``*.py`` files or dotted package
                path handed to :meth:`discover_methods`.
            autoload: When true, run discovery before resolving ``method``.
            **default_method_kwargs: Forwarded to :meth:`_resolve_method`.
        """
        if autoload:
            self.discover_methods(methods_location)
        self._method = self._resolve_method(method, **default_method_kwargs)

    @staticmethod
    def discover_methods(location: str | Path) -> None:
        """Import every non-private module found at ``location``.

        ``location`` is first tried as a filesystem directory; each ``*.py``
        file whose name does not start with ``_`` is executed as a module named
        ``_dyn_steering_<stem>``. Otherwise it must be a dotted package path,
        whose non-private submodules are imported.

        Raises:
            BaseSteeringError: If ``location`` is neither an existing directory
                nor a string, or if the package cannot be imported.
        """
        # Local import: only this method needs ``sys``.
        import sys

        loc_path = Path(str(location))
        if loc_path.exists() and loc_path.is_dir():
            # Sorted so the load order does not depend on the filesystem.
            for py in sorted(loc_path.glob("*.py")):
                if py.name.startswith("_"):
                    continue
                mod_name = f"_dyn_steering_{py.stem}"
                spec = importlib.util.spec_from_file_location(mod_name, py)
                if not (spec and spec.loader):
                    continue
                module = importlib.util.module_from_spec(spec)
                # Register before executing, as the importlib recipe for
                # importing a source file does; code that looks itself up via
                # ``sys.modules[__name__]`` fails otherwise.
                previous = sys.modules.get(mod_name)
                sys.modules[mod_name] = module
                try:
                    spec.loader.exec_module(module)
                except BaseException:
                    # Do not leave a half-initialised module behind.
                    if previous is None:
                        sys.modules.pop(mod_name, None)
                    else:
                        sys.modules[mod_name] = previous
                    raise
            return

        if not isinstance(location, str):
            raise BaseSteeringError(f"Invalid methods location: {location!r}")

        try:
            pkg = importlib.import_module(location)
        except ModuleNotFoundError as exc:
            raise BaseSteeringError(f"Cannot import steering package {location!r}.") from exc

        # A plain module has no ``__path__``; fall back to its parent directory.
        search_paths = list(getattr(pkg, "__path__", [])) or [Path(getattr(pkg, "__file__", "")).parent.as_posix()]
        for _, name, _ in pkgutil.iter_modules(search_paths):
            if name.startswith("_"):
                continue
            importlib.import_module(f"{location}.{name}")

    @staticmethod
    def list_methods() -> list[dict[str, Any]]:
        """Return name, description and dotted class path of each registered method, sorted by name."""
        return [
            {
                "name": name,
                "description": getattr(cls, "description", ""),
                "class": f"{cls.__module__}.{cls.__name__}",
            }
            for name, cls in sorted(BaseSteeringMethod.list_registered().items(), key=lambda kv: kv[0])
        ]

    @staticmethod
    def _resolve_method(
        method: str | BaseSteeringMethod | Type[BaseSteeringMethod] | None,
        **kwargs: Any,
    ) -> BaseSteeringMethod:
        """Turn ``method`` into a ``BaseSteeringMethod`` instance.

        Raises:
            BaseSteeringError: If ``method`` is ``None`` and nothing is registered.
            TypeError: If ``method`` is of an unsupported type.
        """
        if method is None:
            reg = BaseSteeringMethod.list_registered()
            if not reg:
                raise BaseSteeringError("No steering methods registered.")
            # Default to the alphabetically first registered method.
            first = next(iter(sorted(reg.items(), key=lambda kv: kv[0])))[1]
            return first(**kwargs)
        if isinstance(method, BaseSteeringMethod):
            # The instance's own kwargs win over the supplied defaults.
            method.kwargs = {**kwargs, **method.kwargs}
            return method
        if inspect.isclass(method) and issubclass(method, BaseSteeringMethod):
            return method(**kwargs)
        if isinstance(method, str):
            return BaseSteeringMethod.get(method)(**kwargs)
        raise TypeError("method must be None, str name, BaseSteeringMethod instance, or subclass.")

    def use(self, method: str | BaseSteeringMethod | Type[BaseSteeringMethod], **kwargs: Any) -> None:
        """Replace the active method with the one resolved from ``method``."""
        self._method = self._resolve_method(method, **kwargs)

    def train(self, pair_set: ContrastivePairSet, **overrides: Any) -> LayerActivations:
        """Train the active method on ``pair_set`` with per-call kwarg overrides.

        ``overrides`` are merged over the method's kwargs for this call only;
        the previous kwargs are restored afterwards, even if training raises.
        """
        old = dict(self._method.kwargs)
        try:
            self._method.kwargs = {**old, **overrides}
            return self._method.train(pair_set)
        finally:
            self._method.kwargs = old
|
|
105
|
+
|
|
106
|
+
if __name__ == "__main__":
    # Manual smoke check: run default discovery and print what got registered.
    rotator = SteeringMethodRotator()
    print("Available steering methods:")
    for entry in rotator.list_methods():
        name, description, dotted = entry["name"], entry["description"], entry["class"]
        print(f" - {name}: {description} ({dotted})")
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import shlex
|
|
3
|
+
from typing import Optional
|
|
4
|
+
import typer
|
|
5
|
+
|
|
6
|
+
__all__ = ["app", "help_command"]
|
|
7
|
+
|
|
8
|
+
app = typer.Typer(help="Human-friendly help router")
|
|
9
|
+
|
|
10
|
+
def _run_cli_line(line: str) -> None:
    """Execute one shell-style command line against the root ``wisent`` app, in-process.

    ``line`` is tokenised with ``shlex.split`` and handed to the Click command
    built from the root Typer app. A ``SystemExit`` with code ``0`` or ``None``
    is swallowed; any other exit code propagates.
    """
    # Imported lazily; the root app module is only needed when a line is run.
    from typer.main import get_command
    from wisent_guard.cli.wisent_cli.main import app as root_app

    root_command = get_command(root_app)
    argv = shlex.split(line)
    try:
        root_command.main(args=argv, standalone_mode=False, prog_name="wisent")
    except SystemExit as exit_signal:
        if exit_signal.code in (0, None):
            return
        raise
|
|
20
|
+
|
|
21
|
+
@app.command("help")
def help_command(topic: Optional[str] = typer.Argument(None), name: Optional[str] = typer.Argument(None)):
    """
    Examples:
      wisent help train
      wisent help method caa
      wisent help loader custom
      wisent help list-methods
    """
    # The docstring above is kept verbatim: Typer shows it as this command's help.
    key = (topic or "").strip().lower()

    # Topics that are themselves sub-commands: show that command's own --help.
    if key in {"train", "list-methods", "list-loaders", "list-aggregations", "explain", "instructions", "start"}:
        _run_cli_line(f"{key} --help")
        return

    # "<kind> <name>" topics are routed to `explain` (plus `loader-args` for loaders).
    if name:
        quoted = shlex.quote(name)
        if key in {"method", "methods"}:
            _run_cli_line(f"explain --method {quoted}")
            return
        if key in {"loader", "loaders"}:
            _run_cli_line(f"explain --loader {quoted}")
            _run_cli_line(f"loader-args {quoted}")
            return
        if key in {"aggregation", "agg", "aggregations"}:
            _run_cli_line(f"explain --aggregation {quoted}")
            return

    # Empty topic, "app"/"main", unknown topics, or a kind without a name:
    # fall back to the top-level help.
    _run_cli_line("--help")
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import inspect
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import typer
|
|
6
|
+
|
|
7
|
+
from wisent_guard.cli.wisent_cli.ui import echo
|
|
8
|
+
from wisent_guard.cli.wisent_cli.util import aggregations as aggs
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from rich.table import Table
|
|
12
|
+
from rich.panel import Panel
|
|
13
|
+
HAS_RICH = True
|
|
14
|
+
except Exception:
|
|
15
|
+
HAS_RICH = False
|
|
16
|
+
|
|
17
|
+
app = typer.Typer(help="Listing, discovery, and explanation commands")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@app.command("list-aggregations")
def list_aggregations():
    # Prints every aggregation strategy with its description: a rich table when
    # `rich` is importable, plain lines otherwise. No docstring is added on
    # purpose, because Typer would surface it as help text.
    if not HAS_RICH:
        for strategy, text in aggs.descriptions().items():
            label = strategy.name.lower()
            print(f"- {label:22s} : {text}")
        return

    table = Table(title="Aggregation Strategies")
    table.add_column("Name", style="bold")
    table.add_column("Description")
    for strategy, text in aggs.descriptions().items():
        table.add_row(strategy.name.lower(), text)
    echo(table)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@app.command("list-methods")
def list_methods():
    # Lists the registered steering methods; exits with code 1 when there are
    # none. Instantiating the rotator runs its default discovery first.
    from wisent_guard.cli.steering_methods.steering_rotator import SteeringMethodRotator  # type: ignore

    registered = SteeringMethodRotator().list_methods()
    if not registered:
        typer.echo("No steering methods registered.")
        raise typer.Exit(code=1)

    if not HAS_RICH:
        for entry in registered:
            print(f"- {entry['name']}: {entry['description']} ({entry['class']})")
        return

    table = Table(title="Registered Steering Methods")
    for header, options in (("Name", {"style": "bold"}), ("Description", {}), ("Class", {})):
        table.add_column(header, **options)
    for entry in registered:
        table.add_row(entry["name"], entry["description"], entry["class"])
    echo(table)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@app.command("list-loaders")
def list_loaders(
    loaders_location: Optional[str] = typer.Option(None, help="Package path or directory containing data loader modules"),
    scope_prefix: Optional[str] = typer.Option(None, help="Limit list to module path prefix"),
):
    # Lists the registered data loaders, optionally after discovering extra
    # ones at `loaders_location`; exits with code 1 when nothing is found.
    from wisent_guard.cli.data_loaders.data_loader_rotator import DataLoaderRotator  # type: ignore

    if loaders_location:
        DataLoaderRotator.discover_loaders(loaders_location)

    found = DataLoaderRotator.list_loaders(scope_prefix=scope_prefix)
    if not found:
        typer.echo("No data loaders found.")
        raise typer.Exit(code=1)

    if not HAS_RICH:
        for entry in found:
            print(f"- {entry['name']}: {entry['description']} ({entry['class']})")
        return

    table = Table(title="Registered Data Loaders")
    for header, options in (("Name", {"style": "bold"}), ("Description", {}), ("Class", {})):
        table.add_column(header, **options)
    for entry in found:
        table.add_row(entry["name"], entry["description"], entry["class"])
    echo(table)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@app.command("explain")
def explain(
    method: Optional[str] = typer.Option(None, help="Steering method to describe"),
    loader: Optional[str] = typer.Option(None, help="Data loader to describe"),
    loaders_location: Optional[str] = typer.Option(None, help="Where to discover data loaders"),
    aggregation: Optional[str] = typer.Option(None, help="Aggregation to describe"),
):
    # Describes a steering method, a data loader and/or an aggregation strategy.
    # Each given option is handled independently; with no option nothing is
    # printed. No docstring is added on purpose, because Typer would surface
    # it as help text.
    from wisent_guard.cli.data_loaders.data_loader_rotator import DataLoaderRotator  # type: ignore
    from wisent_guard.cli.steering_methods.steering_rotator import SteeringMethodRotator  # type: ignore

    if loaders_location:
        DataLoaderRotator.discover_loaders(loaders_location)

    if method:
        # Instantiate the rotator so its default discovery runs before the name
        # is resolved, as `list-methods` and `train` do. The static
        # `_resolve_method` only reads the registry, so calling it directly
        # presumably fails for built-in names unless something else imported
        # the method modules earlier -- TODO confirm.
        m = SteeringMethodRotator(method=method)._method
        doc = (getattr(type(m), "__doc__", None) or "No docstring.").strip()
        if HAS_RICH:
            echo(Panel(doc, title=f"Method: {getattr(m, 'name', type(m).__name__)}"))
        else:
            print(doc)

    if loader:
        reg = DataLoaderRotator.list_loaders()
        # Loader names are matched case-insensitively.
        match = next((x for x in reg if x["name"].lower() == loader.lower()), None)
        if match:
            desc = (match.get("description") or "No description.").strip()
            if HAS_RICH:
                echo(Panel(desc, title=f"Loader: {loader}"))
            else:
                print(desc)
        else:
            typer.echo(f"Unknown loader: {loader}")

    if aggregation:
        agg = aggs.pick(aggregation)
        desc = aggs.descriptions().get(agg, "No description.")
        if HAS_RICH:
            # `Panel` is already imported at module level whenever HAS_RICH is true.
            echo(Panel(desc, title=f"Aggregation: {aggregation}"))
        else:
            print(desc)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@app.command("loader-args")
def loader_args(
    name: str = typer.Argument(..., help="Loader name, e.g. 'custom'"),
    loaders_location: Optional[str] = typer.Option(None, help="Where to discover data loaders"),
):
    """
    Show the exact arguments accepted by the loader's `load(...)` method.
    Useful so users know precisely what to pass (e.g., for `custom`:
    `path, split_ratio, seed, training_limit, testing_limit`).
    """
    # The docstring above is kept verbatim: Typer shows it as this command's help.
    from wisent_guard.cli.data_loaders.data_loader_rotator import DataLoaderRotator  # type: ignore

    if loaders_location:
        DataLoaderRotator.discover_loaders(loaders_location)
    rotator = DataLoaderRotator(
        loader=name,
        loaders_location=loaders_location or "wisent_guard.core.data_loaders.loaders",
    )

    # Best-effort introspection: prefer the wrapped loader's `load`, then the
    # rotator's own.
    candidates = (getattr(rotator, "_loader", None), rotator)
    load_fn = next(
        (obj.load for obj in candidates if obj is not None and hasattr(obj, "load")),
        None,
    )
    if load_fn is None:
        typer.echo("Could not introspect loader signature.")
        raise typer.Exit(code=1)

    rendered = f"{name}.load{inspect.signature(load_fn)}"
    if HAS_RICH:
        echo(Panel(rendered, title="Loader load(...) signature"))
    else:
        print(rendered)
|
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
import typer
|
|
8
|
+
|
|
9
|
+
from wisent_guard.cli.wisent_cli.ui import echo
|
|
10
|
+
from wisent_guard.cli.wisent_cli.util import aggregations as aggs
|
|
11
|
+
from wisent_guard.cli.wisent_cli.util.parsing import (
|
|
12
|
+
parse_natural_tokens, parse_kv, parse_layers, to_bool, DTYPE_MAP,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
from rich.table import Table
|
|
17
|
+
from rich.panel import Panel
|
|
18
|
+
from rich.syntax import Syntax
|
|
19
|
+
HAS_RICH = True
|
|
20
|
+
except Exception:
|
|
21
|
+
HAS_RICH = False
|
|
22
|
+
|
|
23
|
+
__all__ = ["app", "train"]
|
|
24
|
+
|
|
25
|
+
app = typer.Typer(help="Training workflow")
|
|
26
|
+
|
|
27
|
+
def _resolve_method(method_name: Optional[str], methods_location: Optional[str]):
    """Return a steering-method instance for the ``train`` command.

    Args:
        method_name: Registered method name, matched case-insensitively.
            When empty, ``caa`` is used if registered, otherwise the first
            registered method.
        methods_location: Optional extra directory or package to discover
            methods from before resolving.

    Raises:
        typer.BadParameter: If ``method_name`` cannot be resolved, or no
            method is registered at all.
    """
    from wisent_guard.cli.steering_methods.steering_rotator import SteeringMethodRotator  # type: ignore

    # Discovery of a user-supplied location stays best-effort, but a failure is
    # reported instead of being dropped silently; otherwise the user only sees
    # a confusing "Unknown steering method" later.
    if methods_location and hasattr(SteeringMethodRotator, "discover_methods"):
        try:
            SteeringMethodRotator.discover_methods(methods_location)  # type: ignore[attr-defined]
        except Exception as ex:
            typer.echo(
                f"Warning: could not discover steering methods in {methods_location!r}: {ex}",
                err=True,
            )

    rot = SteeringMethodRotator()
    names = [m["name"] for m in rot.list_methods()]
    # Lower-cased name -> registered spelling, for case-insensitive lookups.
    by_lower = {n.lower(): n for n in names}

    if method_name:
        real = by_lower.get(method_name.lower(), method_name)
        try:
            rot.use(real)
            inst = getattr(rot, "_method", None)
            if inst is not None:
                return inst
        except Exception:
            # Fall through to the private resolver, which reports the failure.
            pass
        try:
            return SteeringMethodRotator._resolve_method(real)
        except Exception as ex:
            raise typer.BadParameter(f"Unknown steering method: {method_name!r}") from ex

    # No name provided -> default to 'caa' if present, else the first method.
    if not names:
        raise typer.BadParameter("No steering methods registered.")
    # Use the registered spelling of "caa": the membership check is
    # case-insensitive, so the lookup must not assume lowercase.
    chosen = by_lower.get("caa", names[0])
    rot.use(chosen)
    inst = getattr(rot, "_method", None)
    # Resolve a second time only if the rotator did not expose an instance.
    return inst if inst is not None else SteeringMethodRotator._resolve_method(chosen)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _show_plan(
    *,
    model: str,
    loader: Optional[str],
    loaders_location: Optional[str],
    loader_kwargs: Dict[str, object],
    method_name: Optional[str],
    method_kwargs: Dict[str, object],
    layers: Optional[str],
    aggregation_name: str,
    store_device: str,
    dtype: Optional[str],
    return_full_sequence: bool,
    normalize_layers: bool,
    save_dir: Optional[Path],
) -> None:
    """Print the resolved training plan and an equivalent Python code preview.

    Renders a rich table plus a syntax-highlighted snippet when ``rich`` is
    available, otherwise plain JSON followed by the snippet. Only displays;
    nothing is executed.
    """
    plan = {
        "Model": model,
        "Data loader": loader or "(default)",
        "Loaders location": loaders_location or "(auto)",
        "Loader kwargs": loader_kwargs or {},
        "Method": method_name or "(resolved automatically)",
        "Method kwargs": method_kwargs or {},
        "Layers": layers or "(all)",
        "Aggregation": aggregation_name,
        "Return full sequence": return_full_sequence,
        "Normalize layers": normalize_layers,
        "Store device": store_device,
        "Dtype": dtype or "(unchanged)",
        "Save dir": str(save_dir) if save_dir else "(none)",
    }

    # The preview is Python source, so kwargs are rendered with repr().
    # json.dumps would emit true/false/null, which are not valid Python.
    loader_kwargs_src = repr(loader_kwargs or {})
    method_kwargs_src = repr(method_kwargs or {})
    save_dir_src = repr(str(save_dir) if save_dir else None)
    aggregation_member = aggs.pick(aggregation_name).name

    # NOTE(review): the snippet's continuation-line indentation was not
    # recoverable from the source; it affects the displayed preview only.
    code = f"""
# Example: Training steering vectors (auto-generated plan)
from wisent_guard.core.trainers.steering_trainer import WisentSteeringTrainer
from wisent_guard.core.models.wisent_model import WisentModel
from wisent_guard.cli.data_loaders.data_loader_rotator import DataLoaderRotator
from wisent_guard.cli.steering_methods.steering_rotator import SteeringMethodRotator
from wisent_guard.core.activations.core.atoms import ActivationAggregationStrategy

# 1) Model
model = WisentModel(model_name={model!r}, layers={{}}, device={store_device!r})

# 2) Data loader
rot = DataLoaderRotator(loader={loader!r}, loaders_location={loaders_location!r})
load = rot.load(**{loader_kwargs_src})

# 3) Method
method = SteeringMethodRotator._resolve_method({(method_name or 'caa')!r})

# 4) Trainer
trainer = WisentSteeringTrainer(model=model, pair_set=load["train_qa_pairs"], steering_method=method,
                                store_device={store_device!r}, dtype={dtype!r})

# 5) Train
result = trainer.run(
    layers_spec={layers!r},
    method_kwargs={method_kwargs_src},
    aggregation=ActivationAggregationStrategy.{aggregation_member},
    return_full_sequence={return_full_sequence!r},
    normalize_layers={normalize_layers!r},
    save_dir={save_dir_src},
)
""".strip()

    if HAS_RICH:
        t = Table(title="Execution Plan")
        t.add_column("Key", style="bold", no_wrap=True)
        t.add_column("Value")
        for k, v in plan.items():
            # default=str: kwargs are Dict[str, object] and may hold values
            # that json cannot serialise natively.
            t.add_row(k, json.dumps(v, default=str) if isinstance(v, (dict, list)) else str(v))
        echo(Panel(t, expand=False))
        echo(Panel(Syntax(code, "python", word_wrap=False), title="Code Preview", expand=False))
    else:
        print(json.dumps(plan, indent=2, default=str))
        print("\n" + code)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
@app.command("train", context_settings={"ignore_unknown_options": True, "allow_extra_args": True})
def train(ctx: typer.Context, params: List[str] = typer.Argument(None)):
    """
    Natural (no-dash) usage examples:

      wisent train model meta-llama/Llama-3.2-1B-Instruct loader custom path ./custom.json training_limit 5 method caa

      wisent train interactive true

    See `wisent loader-args custom` to view the exact loader arguments.
    """
    # NOTE: the docstring above doubles as the CLI help text, so developer
    # documentation lives in comments instead.
    #
    # Flow:
    #   1. Merge positional `params` with Typer's extra args and split them into
    #      top-level options, loader kwargs and method kwargs.
    #   2. Resolve every option (prompting for missing ones in interactive mode).
    #   3. Validate dtype and aggregation, show the plan, optionally stop there.
    #   4. Ask for confirmation, then load model -> data -> method -> trainer and
    #      print the resulting metadata.
    #
    # Raises:
    #   typer.BadParameter: no model given, unknown dtype, or an aggregation name
    #       rejected by `aggs.pick`.
    #   typer.Exit(code=1): the user declined the "Proceed with training?" prompt.

    # Lazy imports: kept inside the command so they only run when it is invoked.
    from wisent_guard.cli.data_loaders.data_loader_rotator import DataLoaderRotator  # type: ignore
    from wisent_guard.core.models.wisent_model import WisentModel  # type: ignore
    from wisent_guard.core.trainers.steering_trainer import WisentSteeringTrainer  # type: ignore

    tokens = list(params or []) + list(ctx.args or [])
    top, loader_kv_raw, method_kv_raw = parse_natural_tokens(tokens)

    def flag(*keys, default):
        """Return the boolean for the first of `keys` present in `top`, else `default`."""
        for key in keys:
            if key in top:
                return to_bool(top[key])
        return default

    # Core args
    interactive = flag("interactive", default=False)
    model = top.get("model")
    if not model and interactive:
        # The error message below advertises `interactive true` as an alternative
        # to passing a model, so the wizard has to be able to ask for it.
        model = typer.prompt("Model name or path (e.g. meta-llama/Llama-3.2-1B-Instruct)")
    if not model:
        raise typer.BadParameter("Please specify a model (e.g. `train model meta-llama/Llama-3.2-1B-Instruct`) or use `interactive true`.")

    loader = top.get("loader")
    loaders_location = top.get("loaders_location")
    methods_location = top.get("methods_location")
    method_name = top.get("method")

    layers = parse_layers(top.get("layers")) if top.get("layers") else None
    aggregation_name = (top.get("aggregation") or "continuation_token").lower()
    store_device = top.get("device") or top.get("store_device") or "cpu"
    dtype = top.get("dtype")
    save_dir = Path(top["save_dir"]) if top.get("save_dir") else None
    return_full_sequence = flag("return_full_sequence", default=False)
    normalize_layers = flag("normalize_layers", default=False)
    # "plan-only" takes precedence over "plan_only" when both are supplied.
    plan_only = flag("plan-only", "plan_only", default=False)
    confirm = flag("confirm", default=True)

    # Convert kwargs
    loader_kwargs = parse_kv([f"{k}={v}" for k, v in loader_kv_raw.items()])
    method_kwargs = parse_kv([f"{k}={v}" for k, v in method_kv_raw.items()])

    # Interactive wizard: only prompts for values not already given on the command line.
    if interactive:
        if loaders_location:
            DataLoaderRotator.discover_loaders(loaders_location)
        if not loader:
            options = [d["name"] for d in DataLoaderRotator.list_loaders()]
            loader = typer.prompt("Choose data loader", default=(options[0] if options else "custom"))
        if loader and loader.lower() == "custom":
            custom_args = (
                "• path (str) [required]\n"
                "• split_ratio (float | None)\n"
                "• seed (int | None)\n"
                "• training_limit (int | None)\n"
                "• testing_limit (int | None)"
            )
            if HAS_RICH:
                echo(Panel(
                    "[b]Custom loader arguments[/]\n\n" + custom_args,
                    title="custom.load(...)",
                ))
            else:
                # Plain-text fallback instead of echoing `None` when rich is missing.
                typer.echo("Custom loader arguments\n\n" + custom_args)
            if "path" not in loader_kwargs:
                loader_kwargs["path"] = typer.prompt("Path to dataset JSON (required)")
            optional_args = (
                ("split_ratio", float),
                ("seed", int),
                ("training_limit", int),
                ("testing_limit", int),
            )
            for name, cast in optional_args:
                if name in loader_kwargs:
                    continue
                val = typer.prompt(f"{name} (optional)", default="")
                if str(val).strip() == "":
                    continue
                try:
                    loader_kwargs[name] = cast(val)
                except (TypeError, ValueError):
                    # Best effort: keep the raw answer and let the loader decide.
                    loader_kwargs[name] = val
        if not method_name:
            method_name = typer.prompt("Choose steering method (see list-methods)", default="caa")
        if layers is None:
            layers_raw = typer.prompt("Layers (e.g., '10..12', '5,7,9' or leave empty for all)", default="")
            # Mirror the non-interactive path: a blank answer means "no layer spec".
            layers = parse_layers(layers_raw) if layers_raw else None
        if "aggregation" not in top:
            # Lower-cased for consistency with the non-interactive path.
            aggregation_name = typer.prompt("Aggregation (see list-aggregations)", default="continuation_token").lower()
        if "dtype" not in top:
            dtype = typer.prompt("Activation dtype (float32/float16/bfloat16 or blank)", default="") or None
        if "device" not in top and "store_device" not in top:
            store_device = typer.prompt("Device to store activations on (cpu / cuda / cuda:0 / ...)", default="cpu")
        if "normalize_layers" not in top:
            normalize_layers = typer.confirm("Normalize activations per layer?", default=True)
        if "return_full_sequence" not in top:
            return_full_sequence = typer.confirm("Return full [T,H] sequence per layer?", default=False)
        if "save_dir" not in top:
            default_out = os.path.abspath("./steering_output")
            p = typer.prompt("Save directory for artifacts (blank to skip saving)", default=default_out)
            if p.strip():
                save_dir = Path(p)
        if "plan-only" not in top and "plan_only" not in top:
            plan_only = typer.confirm("Only show the plan and code preview?", default=False)
        if "confirm" not in top:
            confirm = typer.confirm("Confirm before running?", default=True)

    # Validate dtype. `None` means "not specified" and is handled explicitly when
    # the trainer is built, so it must not be rejected here.
    if dtype is not None and dtype not in DTYPE_MAP:
        raise typer.BadParameter("dtype must be one of: float32, float16, bfloat16")

    # Validate aggregation
    try:
        agg = aggs.pick(aggregation_name)
    except ValueError as ex:
        raise typer.BadParameter(str(ex)) from ex

    # Plan
    _show_plan(
        model=model,
        loader=loader,
        loaders_location=loaders_location,
        loader_kwargs=loader_kwargs,
        method_name=method_name,
        method_kwargs=method_kwargs,
        layers=layers,
        aggregation_name=aggregation_name,
        store_device=store_device,
        dtype=dtype,
        return_full_sequence=return_full_sequence,
        normalize_layers=normalize_layers,
        save_dir=save_dir,
    )

    if plan_only:
        return

    if confirm and not typer.confirm("Proceed with training?", default=True):
        typer.echo("Aborted.")
        raise typer.Exit(code=1)

    # -- Model -----------------------------------------------------------------
    typer.echo(f"[+] Loading model: {model}")
    wmodel = WisentModel(model_name=model, layers={}, device=store_device)

    # -- Data loader -----------------------------------------------------------
    if loaders_location:
        DataLoaderRotator.discover_loaders(loaders_location)
    dl_rot = DataLoaderRotator(loader=loader, loaders_location=loaders_location or "wisent_guard.core.data_loaders.loaders")
    typer.echo(f"[+] Using data loader: {loader or '(default)'}")
    load_result = dl_rot.load(**loader_kwargs)
    pair_set = load_result["train_qa_pairs"]
    typer.echo(f"[+] Loaded training pairs: {len(pair_set)} (task_type={load_result['task_type']})")

    # -- Steering method -------------------------------------------------------
    method_inst = _resolve_method(method_name, methods_location)
    name_shown = getattr(method_inst, "name", type(method_inst).__name__)
    typer.echo(f"[+] Steering method: {name_shown}")

    # -- Trainer ---------------------------------------------------------------
    if dtype is None:
        torch_dtype = None
    else:
        # torch is imported only when a dtype has to be resolved to a torch attribute.
        import torch

        torch_dtype = getattr(torch, DTYPE_MAP[dtype])
    trainer = WisentSteeringTrainer(
        model=wmodel,
        pair_set=pair_set,
        steering_method=method_inst,
        store_device=store_device,
        dtype=torch_dtype,
    )

    result = trainer.run(
        layers_spec=layers,
        method_kwargs=method_kwargs,
        aggregation=agg,
        return_full_sequence=return_full_sequence,
        normalize_layers=normalize_layers,
        save_dir=save_dir,
    )

    typer.echo("\n=== Training Summary ===")
    typer.echo(json.dumps(result.metadata, indent=2))
    if save_dir is not None:
        typer.echo(f"\nArtifacts saved in: {Path(save_dir).resolve()}\n")
|