wisent 0.5.11__py3-none-any.whl → 0.5.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (225) hide show
  1. wisent/__init__.py +1 -1
  2. wisent/core/activations/__init__.py +26 -0
  3. wisent/core/activations/activations.py +96 -0
  4. wisent/core/activations/activations_collector.py +71 -20
  5. wisent/core/activations/prompt_construction_strategy.py +47 -0
  6. wisent/core/agent/budget.py +2 -2
  7. wisent/core/agent/device_benchmarks.py +1 -1
  8. wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
  9. wisent/core/agent/diagnose/response_diagnostics.py +4 -4
  10. wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
  11. wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
  12. wisent/core/agent/diagnose.py +2 -1
  13. wisent/core/autonomous_agent.py +10 -2
  14. wisent/core/benchmark_extractors.py +293 -0
  15. wisent/core/bigcode_integration.py +20 -7
  16. wisent/core/branding.py +108 -0
  17. wisent/core/cli/__init__.py +15 -0
  18. wisent/core/cli/create_steering_vector.py +138 -0
  19. wisent/core/cli/evaluate_responses.py +715 -0
  20. wisent/core/cli/generate_pairs.py +128 -0
  21. wisent/core/cli/generate_pairs_from_task.py +119 -0
  22. wisent/core/cli/generate_responses.py +129 -0
  23. wisent/core/cli/generate_vector_from_synthetic.py +149 -0
  24. wisent/core/cli/generate_vector_from_task.py +147 -0
  25. wisent/core/cli/get_activations.py +191 -0
  26. wisent/core/cli/optimize_classification.py +339 -0
  27. wisent/core/cli/optimize_steering.py +364 -0
  28. wisent/core/cli/tasks.py +182 -0
  29. wisent/core/cli_logger.py +22 -0
  30. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
  31. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
  32. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
  33. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
  34. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
  35. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
  36. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
  37. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
  38. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
  39. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
  40. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
  41. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
  42. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
  43. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
  44. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
  45. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
  46. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
  47. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
  48. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
  49. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
  50. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
  51. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
  52. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
  53. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
  54. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
  55. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
  56. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
  57. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
  58. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
  59. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
  60. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
  61. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
  62. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
  63. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
  64. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
  65. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
  66. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
  67. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
  68. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
  69. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
  70. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
  71. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
  72. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
  73. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
  74. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
  75. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
  76. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
  77. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
  78. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
  79. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
  80. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
  81. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
  82. wisent/core/data_loaders/__init__.py +235 -0
  83. wisent/core/data_loaders/loaders/lm_loader.py +2 -2
  84. wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
  85. wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
  86. wisent/core/download_full_benchmarks.py +79 -2
  87. wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
  88. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
  89. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
  90. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
  91. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
  92. wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
  93. wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
  94. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
  95. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
  96. wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
  97. wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
  98. wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
  99. wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
  100. wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
  101. wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
  102. wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
  103. wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
  104. wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
  105. wisent/core/lm_eval_harness_ground_truth.py +3 -2
  106. wisent/core/main.py +57 -0
  107. wisent/core/model_persistence.py +2 -2
  108. wisent/core/models/wisent_model.py +8 -6
  109. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
  110. wisent/core/optuna/steering/steering_optimization.py +1 -1
  111. wisent/core/parser_arguments/__init__.py +10 -0
  112. wisent/core/parser_arguments/agent_parser.py +110 -0
  113. wisent/core/parser_arguments/configure_model_parser.py +7 -0
  114. wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
  115. wisent/core/parser_arguments/evaluate_parser.py +40 -0
  116. wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
  117. wisent/core/parser_arguments/full_optimize_parser.py +115 -0
  118. wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
  119. wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
  120. wisent/core/parser_arguments/generate_responses_parser.py +15 -0
  121. wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
  122. wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
  123. wisent/core/parser_arguments/generate_vector_parser.py +90 -0
  124. wisent/core/parser_arguments/get_activations_parser.py +90 -0
  125. wisent/core/parser_arguments/main_parser.py +152 -0
  126. wisent/core/parser_arguments/model_config_parser.py +59 -0
  127. wisent/core/parser_arguments/monitor_parser.py +17 -0
  128. wisent/core/parser_arguments/multi_steer_parser.py +47 -0
  129. wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
  130. wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
  131. wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
  132. wisent/core/parser_arguments/synthetic_parser.py +93 -0
  133. wisent/core/parser_arguments/tasks_parser.py +584 -0
  134. wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
  135. wisent/core/parser_arguments/utils.py +111 -0
  136. wisent/core/prompts/core/prompt_formater.py +3 -3
  137. wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
  138. wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
  139. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
  140. wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
  141. wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
  142. wisent/core/steering_optimizer.py +45 -21
  143. wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
  144. wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
  145. wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
  146. wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
  147. wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
  148. wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
  149. wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
  150. wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
  151. wisent/core/tasks/livecodebench_task.py +4 -103
  152. wisent/core/timing_calibration.py +1 -1
  153. {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/METADATA +3 -3
  154. wisent-0.5.13.dist-info/RECORD +294 -0
  155. wisent-0.5.13.dist-info/entry_points.txt +2 -0
  156. wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
  157. wisent/classifiers/core/atoms.py +0 -747
  158. wisent/classifiers/models/logistic.py +0 -29
  159. wisent/classifiers/models/mlp.py +0 -47
  160. wisent/cli/classifiers/classifier_rotator.py +0 -137
  161. wisent/cli/cli_logger.py +0 -142
  162. wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
  163. wisent/cli/wisent_cli/commands/listing.py +0 -154
  164. wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
  165. wisent/cli/wisent_cli/main.py +0 -93
  166. wisent/cli/wisent_cli/shell.py +0 -80
  167. wisent/cli/wisent_cli/ui.py +0 -69
  168. wisent/cli/wisent_cli/util/aggregations.py +0 -43
  169. wisent/cli/wisent_cli/util/parsing.py +0 -126
  170. wisent/cli/wisent_cli/version.py +0 -4
  171. wisent/opti/methods/__init__.py +0 -0
  172. wisent/synthetic/__init__.py +0 -0
  173. wisent/synthetic/cleaners/__init__.py +0 -0
  174. wisent/synthetic/cleaners/core/__init__.py +0 -0
  175. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  176. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  177. wisent/synthetic/db_instructions/__init__.py +0 -0
  178. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  179. wisent/synthetic/generators/__init__.py +0 -0
  180. wisent/synthetic/generators/core/__init__.py +0 -0
  181. wisent/synthetic/generators/diversities/__init__.py +0 -0
  182. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  183. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  184. wisent-0.5.11.dist-info/RECORD +0 -220
  185. /wisent/{benchmarks → core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
  186. /wisent/{benchmarks/coding → core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
  187. /wisent/{benchmarks/coding/metrics → core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
  188. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
  189. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
  190. /wisent/{benchmarks/coding/metrics/core → core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
  191. /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
  192. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
  193. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
  194. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
  195. /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
  196. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
  197. /wisent/{benchmarks/coding/providers/core → core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
  198. /wisent/{benchmarks/coding/providers/livecodebench → core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
  199. /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
  200. /wisent/{benchmarks/coding/safe_docker → core/opti}/__init__.py +0 -0
  201. /wisent/{benchmarks/coding/safe_docker → core/opti}/core/__init__.py +0 -0
  202. /wisent/{opti → core/opti}/core/atoms.py +0 -0
  203. /wisent/{classifiers → core/opti/methods}/__init__.py +0 -0
  204. /wisent/{opti → core/opti}/methods/opti_classificator.py +0 -0
  205. /wisent/{opti → core/opti}/methods/opti_steering.py +0 -0
  206. /wisent/{classifiers/core → core/synthetic}/__init__.py +0 -0
  207. /wisent/{classifiers/models → core/synthetic/cleaners}/__init__.py +0 -0
  208. /wisent/{cli → core/synthetic/cleaners/core}/__init__.py +0 -0
  209. /wisent/{synthetic → core/synthetic}/cleaners/core/atoms.py +0 -0
  210. /wisent/{cli/classifiers → core/synthetic/cleaners/methods}/__init__.py +0 -0
  211. /wisent/{cli/data_loaders → core/synthetic/cleaners/methods/core}/__init__.py +0 -0
  212. /wisent/{synthetic → core/synthetic}/cleaners/methods/core/atoms.py +0 -0
  213. /wisent/{cli/evaluators → core/synthetic/db_instructions}/__init__.py +0 -0
  214. /wisent/{cli/steering_methods → core/synthetic/db_instructions/core}/__init__.py +0 -0
  215. /wisent/{synthetic → core/synthetic}/db_instructions/core/atoms.py +0 -0
  216. /wisent/{cli/wisent_cli → core/synthetic/generators}/__init__.py +0 -0
  217. /wisent/{cli/wisent_cli/commands → core/synthetic/generators/core}/__init__.py +0 -0
  218. /wisent/{synthetic → core/synthetic}/generators/core/atoms.py +0 -0
  219. /wisent/{cli/wisent_cli/util → core/synthetic/generators/diversities}/__init__.py +0 -0
  220. /wisent/{opti → core/synthetic/generators/diversities/core}/__init__.py +0 -0
  221. /wisent/{synthetic → core/synthetic}/generators/diversities/core/core.py +0 -0
  222. /wisent/{opti/core → core/synthetic/generators/diversities/methods}/__init__.py +0 -0
  223. {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/WHEEL +0 -0
  224. {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/licenses/LICENSE +0 -0
  225. {wisent-0.5.11.dist-info → wisent-0.5.13.dist-info}/top_level.txt +0 -0
@@ -1,322 +0,0 @@
1
- from __future__ import annotations
2
- import json
3
- import os
4
- from pathlib import Path
5
- from typing import Dict, List, Optional
6
-
7
- import typer
8
-
9
- from wisent.cli.wisent_cli.ui import echo
10
- from wisent.cli.wisent_cli.util import aggregations as aggs
11
- from wisent.cli.wisent_cli.util.parsing import (
12
- parse_natural_tokens, parse_kv, parse_layers, to_bool, DTYPE_MAP,
13
- )
14
-
15
- try:
16
- from rich.table import Table
17
- from rich.panel import Panel
18
- from rich.syntax import Syntax
19
- HAS_RICH = True
20
- except Exception:
21
- HAS_RICH = False
22
-
23
- __all__ = ["app", "train"]
24
-
25
- app = typer.Typer(help="Training workflow")
26
-
27
- def _resolve_method(method_name: Optional[str], methods_location: Optional[str]):
28
- from wisent.cli.steering_methods.steering_rotator import SteeringMethodRotator # type: ignore
29
- # Best effort discovery if available
30
- try:
31
- if methods_location and hasattr(SteeringMethodRotator, "discover_methods"):
32
- SteeringMethodRotator.discover_methods(methods_location) # type: ignore[attr-defined]
33
- except Exception:
34
- pass
35
-
36
- rot = SteeringMethodRotator()
37
- if method_name:
38
- # Case-insensitive match from registry
39
- registry = {m["name"].lower(): m["name"] for m in rot.list_methods()}
40
- real = registry.get(method_name.lower(), method_name)
41
- try:
42
- rot.use(real)
43
- inst = getattr(rot, "_method", None)
44
- if inst is not None:
45
- return inst
46
- except Exception:
47
- pass
48
- # Fallback to private resolver
49
- try:
50
- return SteeringMethodRotator._resolve_method(real)
51
- except Exception as ex:
52
- raise typer.BadParameter(f"Unknown steering method: {method_name!r}") from ex
53
-
54
- # No name provided -> default to first or 'caa' if present
55
- names = [m["name"] for m in rot.list_methods()]
56
- if "caa" in [n.lower() for n in names]:
57
- rot.use("caa")
58
- return getattr(rot, "_method", SteeringMethodRotator._resolve_method("caa"))
59
- if not names:
60
- raise typer.BadParameter("No steering methods registered.")
61
- rot.use(names[0])
62
- return getattr(rot, "_method", SteeringMethodRotator._resolve_method(names[0]))
63
-
64
-
65
- def _show_plan(
66
- *,
67
- model: str,
68
- loader: Optional[str],
69
- loaders_location: Optional[str],
70
- loader_kwargs: Dict[str, object],
71
- method_name: Optional[str],
72
- method_kwargs: Dict[str, object],
73
- layers: Optional[str],
74
- aggregation_name: str,
75
- store_device: str,
76
- dtype: Optional[str],
77
- return_full_sequence: bool,
78
- normalize_layers: bool,
79
- save_dir: Optional[Path],
80
- ) -> None:
81
- plan = {
82
- "Model": model,
83
- "Data loader": loader or "(default)",
84
- "Loaders location": loaders_location or "(auto)",
85
- "Loader kwargs": loader_kwargs or {},
86
- "Method": method_name or "(resolved automatically)",
87
- "Method kwargs": method_kwargs or {},
88
- "Layers": layers or "(all)",
89
- "Aggregation": aggregation_name,
90
- "Return full sequence": return_full_sequence,
91
- "Normalize layers": normalize_layers,
92
- "Store device": store_device,
93
- "Dtype": dtype or "(unchanged)",
94
- "Save dir": str(save_dir) if save_dir else "(none)",
95
- }
96
-
97
- code = f"""
98
- # Example: Training steering vectors (auto-generated plan)
99
- from wisent.core.trainers.steering_trainer import WisentSteeringTrainer
100
- from wisent.core.models.wisent_model import WisentModel
101
- from wisent.cli.data_loaders.data_loader_rotator import DataLoaderRotator
102
- from wisent.cli.steering_methods.steering_rotator import SteeringMethodRotator
103
- from wisent.core.activations.core.atoms import ActivationAggregationStrategy
104
-
105
- # 1) Model
106
- model = WisentModel(model_name={model!r}, layers={{}}, device={store_device!r})
107
-
108
- # 2) Data loader
109
- rot = DataLoaderRotator(loader={loader!r}, loaders_location={loaders_location!r})
110
- load = rot.load(**{json.dumps(loader_kwargs)})
111
-
112
- # 3) Method
113
- method = SteeringMethodRotator._resolve_method({(method_name or 'caa')!r})
114
-
115
- # 4) Trainer
116
- trainer = WisentSteeringTrainer(model=model, pair_set=load["train_qa_pairs"], steering_method=method,
117
- store_device={store_device!r}, dtype={dtype!r})
118
-
119
- # 5) Train
120
- result = trainer.run(
121
- layers_spec={layers!r},
122
- method_kwargs={json.dumps(method_kwargs)},
123
- aggregation=ActivationAggregationStrategy.{aggs.pick(aggregation_name).name},
124
- return_full_sequence={return_full_sequence!r},
125
- normalize_layers={normalize_layers!r},
126
- save_dir={str(save_dir) if save_dir else None!r},
127
- )
128
- """.strip()
129
-
130
- if HAS_RICH:
131
- t = Table(title="Execution Plan")
132
- t.add_column("Key", style="bold", no_wrap=True)
133
- t.add_column("Value")
134
- for k, v in plan.items():
135
- t.add_row(k, json.dumps(v) if isinstance(v, (dict, list)) else str(v))
136
- echo(Panel(t, expand=False))
137
- echo(Panel(Syntax(code, "python", word_wrap=False), title="Code Preview", expand=False))
138
- else:
139
- print(json.dumps(plan, indent=2))
140
- print("\n" + code)
141
-
142
-
143
- @app.command("train", context_settings={"ignore_unknown_options": True, "allow_extra_args": True})
144
- def train(ctx: typer.Context, params: List[str] = typer.Argument(None)):
145
- """
146
- Natural (no-dash) usage examples:
147
-
148
- wisent train model meta-llama/Llama-3.2-1B-Instruct loader custom path ./custom.json training_limit 5 method caa
149
-
150
- wisent train interactive true
151
-
152
- See `wisent loader-args custom` to view the exact loader arguments.
153
- """
154
- # Lazy imports
155
- from wisent.cli.data_loaders.data_loader_rotator import DataLoaderRotator # type: ignore
156
- from wisent.core.models.wisent_model import WisentModel # type: ignore
157
- from wisent.core.trainers.steering_trainer import WisentSteeringTrainer # type: ignore
158
-
159
- tokens = list(params or []) + list(ctx.args or [])
160
- top, loader_kv_raw, method_kv_raw = parse_natural_tokens(tokens)
161
-
162
- # Core args
163
- model = top.get("model")
164
- if not model:
165
- raise typer.BadParameter("Please specify a model (e.g. `train model meta-llama/Llama-3.2-1B-Instruct`) or use `interactive true`.")
166
-
167
- loader = top.get("loader")
168
- loaders_location = top.get("loaders_location")
169
- methods_location = top.get("methods_location")
170
- method_name = top.get("method")
171
-
172
- layers = parse_layers(top.get("layers")) if top.get("layers") else None
173
- aggregation_name = (top.get("aggregation") or "continuation_token").lower()
174
- store_device = top.get("device") or top.get("store_device") or "cpu"
175
- dtype = top.get("dtype")
176
- save_dir = Path(top["save_dir"]) if top.get("save_dir") else None
177
- return_full_sequence = to_bool(top.get("return_full_sequence", "false")) if "return_full_sequence" in top else False
178
- normalize_layers = to_bool(top.get("normalize_layers", "false")) if "normalize_layers" in top else False
179
- interactive = to_bool(top.get("interactive", "false")) if "interactive" in top else False
180
- plan_only = to_bool(top.get("plan-only", top.get("plan_only", "false"))) if ( "plan-only" in top or "plan_only" in top ) else False
181
- confirm = to_bool(top.get("confirm", "true")) if "confirm" in top else True
182
-
183
- # Convert kwargs
184
- loader_kwargs = parse_kv([f"{k}={v}" for k, v in loader_kv_raw.items()])
185
- method_kwargs = parse_kv([f"{k}={v}" for k, v in method_kv_raw.items()])
186
-
187
- # Interactive wizard
188
- if interactive:
189
- if loaders_location:
190
- DataLoaderRotator.discover_loaders(loaders_location)
191
- if not loader:
192
- options = [d["name"] for d in DataLoaderRotator.list_loaders()]
193
- loader = typer.prompt("Choose data loader", default=(options[0] if options else "custom"))
194
- if loader and loader.lower() == "custom":
195
- echo(Panel(
196
- "[b]Custom loader arguments[/]\n\n"
197
- "• path (str) [required]\n"
198
- "• split_ratio (float | None)\n"
199
- "• seed (int | None)\n"
200
- "• training_limit (int | None)\n"
201
- "• testing_limit (int | None)",
202
- title="custom.load(...)",
203
- ) if HAS_RICH else
204
- None
205
- )
206
- if "path" not in loader_kwargs:
207
- loader_kwargs["path"] = typer.prompt("Path to dataset JSON (required)")
208
- for name, cast, default in [
209
- ("split_ratio", float, ""),
210
- ("seed", int, ""),
211
- ("training_limit", int, ""),
212
- ("testing_limit", int, ""),
213
- ]:
214
- if name not in loader_kwargs:
215
- val = typer.prompt(f"{name} (optional)", default=default)
216
- if str(val).strip() != "":
217
- try:
218
- loader_kwargs[name] = cast(val)
219
- except Exception:
220
- loader_kwargs[name] = val
221
- if not method_name:
222
- method_name = typer.prompt("Choose steering method (see list-methods)", default="caa")
223
- if layers is None:
224
- layers = parse_layers(typer.prompt("Layers (e.g., '10..12', '5,7,9' or leave empty for all)", default="") or None)
225
- if "aggregation" not in top:
226
- aggregation_name = typer.prompt("Aggregation (see list-aggregations)", default="continuation_token")
227
- if "dtype" not in top:
228
- dtype = typer.prompt("Activation dtype (float32/float16/bfloat16 or blank)", default="") or None
229
- if "device" not in top and "store_device" not in top:
230
- store_device = typer.prompt("Device to store activations on (cpu / cuda / cuda:0 / ...)", default="cpu")
231
- if "normalize_layers" not in top:
232
- normalize_layers = typer.confirm("Normalize activations per layer?", default=True)
233
- if "return_full_sequence" not in top:
234
- return_full_sequence = typer.confirm("Return full [T,H] sequence per layer?", default=False)
235
- if "save_dir" not in top:
236
- default_out = os.path.abspath("./steering_output")
237
- p = typer.prompt("Save directory for artifacts (blank to skip saving)", default=default_out)
238
- if p.strip():
239
- save_dir = Path(p)
240
- if "plan-only" not in top and "plan_only" not in top:
241
- plan_only = typer.confirm("Only show the plan and code preview?", default=False)
242
- if "confirm" not in top:
243
- confirm = typer.confirm("Confirm before running?", default=True)
244
-
245
- # Validate dtype
246
- if dtype not in DTYPE_MAP:
247
- raise typer.BadParameter("dtype must be one of: float32, float16, bfloat16")
248
-
249
- # Validate aggregation
250
- try:
251
- agg = aggs.pick(aggregation_name)
252
- except ValueError as ex:
253
- raise typer.BadParameter(str(ex)) from ex
254
-
255
- # Plan
256
- _show_plan(
257
- model=model,
258
- loader=loader,
259
- loaders_location=loaders_location,
260
- loader_kwargs=loader_kwargs,
261
- method_name=method_name,
262
- method_kwargs=method_kwargs,
263
- layers=layers,
264
- aggregation_name=aggregation_name,
265
- store_device=store_device,
266
- dtype=dtype,
267
- return_full_sequence=return_full_sequence,
268
- normalize_layers=normalize_layers,
269
- save_dir=save_dir,
270
- )
271
-
272
- if plan_only:
273
- return
274
-
275
- if confirm and not typer.confirm("Proceed with training?", default=True):
276
- typer.echo("Aborted.")
277
- raise typer.Exit(code=1)
278
-
279
- # -- Model -----------------------------------------------------------------
280
- typer.echo(f"[+] Loading model: {model}")
281
- from wisent.core.models.wisent_model import WisentModel # type: ignore
282
- wmodel = WisentModel(model_name=model, layers={}, device=store_device)
283
-
284
- # -- Data loader -----------------------------------------------------------
285
- from wisent.cli.data_loaders.data_loader_rotator import DataLoaderRotator # type: ignore
286
- if loaders_location:
287
- DataLoaderRotator.discover_loaders(loaders_location)
288
- dl_rot = DataLoaderRotator(loader=loader, loaders_location=loaders_location or "wisent_guard.core.data_loaders.loaders")
289
- typer.echo(f"[+] Using data loader: {loader or '(default)'}")
290
- load_result = dl_rot.load(**loader_kwargs)
291
- pair_set = load_result["train_qa_pairs"]
292
- typer.echo(f"[+] Loaded training pairs: {len(pair_set)} (task_type={load_result['task_type']})")
293
-
294
- # -- Steering method -------------------------------------------------------
295
- method_inst = _resolve_method(method_name, methods_location)
296
- name_shown = getattr(method_inst, "name", type(method_inst).__name__)
297
- typer.echo(f"[+] Steering method: {name_shown}")
298
-
299
- # -- Trainer ---------------------------------------------------------------
300
- from wisent.core.trainers.steering_trainer import WisentSteeringTrainer # type: ignore
301
- torch_dtype = None if dtype is None else __import__("torch").__dict__[DTYPE_MAP[dtype]]
302
- trainer = WisentSteeringTrainer(
303
- model=wmodel,
304
- pair_set=pair_set,
305
- steering_method=method_inst,
306
- store_device=store_device,
307
- dtype=torch_dtype,
308
- )
309
-
310
- result = trainer.run(
311
- layers_spec=layers,
312
- method_kwargs=method_kwargs,
313
- aggregation=agg,
314
- return_full_sequence=return_full_sequence,
315
- normalize_layers=normalize_layers,
316
- save_dir=save_dir,
317
- )
318
-
319
- typer.echo("\n=== Training Summary ===")
320
- typer.echo(json.dumps(result.metadata, indent=2))
321
- if save_dir is not None:
322
- typer.echo(f"\nArtifacts saved in: {Path(save_dir).resolve()}\n")
@@ -1,93 +0,0 @@
1
- from __future__ import annotations
2
- from typing import Optional
3
-
4
- import typer
5
-
6
- from wisent.cli.wisent_cli.version import APP_NAME, APP_VERSION
7
- from wisent.cli.wisent_cli.ui import print_banner
8
- from wisent.cli.wisent_cli.commands.listing import app as listing_app
9
- from wisent.cli.wisent_cli.commands.train_cmd import app as train_app
10
- from wisent.cli.wisent_cli.commands.help_cmd import app as help_router_app
11
- from wisent.cli.wisent_cli.shell import app as shell_app
12
-
13
# Root Typer application.
# Help text is rendered as Markdown (rich_markup_mode="markdown"), so emphasis
# is written as **...**; Rich console tags such as [bold]...[/] are not
# interpreted in that mode and would be shown literally.
app = typer.Typer(
    no_args_is_help=True,
    add_completion=False,
    rich_markup_mode="markdown",
    help=(
        "**Wisent Guard** – steerable activations / steering vectors.\n"
        "Collect activations, train steering vectors, and inspect loaders & methods.\n\n"
        "Natural commands (no dashes) + `help <topic>` and a `wisent` shell."
    ),
)

# Sub-applications, registered in the order they should be listed.
app.add_typer(listing_app, name="list")
app.add_typer(train_app, name="")  # attach commands directly (e.g., train)
app.add_typer(help_router_app, name="")  # help router lives at root (help ...)
app.add_typer(shell_app, name="shell")

# Process-wide CLI state; the root callback stores the --verbose flag here.
STATE = {"verbose": False}
30
-
31
@app.callback(invoke_without_command=True)
def _main_callback(
    ctx: typer.Context,
    version: Optional[bool] = typer.Option(None, "--version", "-V", help="Show version and exit."),
    no_banner: bool = typer.Option(False, "--no-banner", help="Disable the startup banner."),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose logging."),
    logo_width: Optional[int] = typer.Option(None, "--logo-width", help="Width of the wisent badge (28–96)."),
):
    """
    Welcome to **Wisent Guard**.

    Examples:
    • `wisent help train`
    • `wisent train model meta-llama/Llama-3.2-1B-Instruct loader custom path ./custom.json training_limit 5 method caa`
    • `wisent list list-methods` (methods)
    • `wisent list list-loaders` (loaders)
    • `wisent list list-aggregations` (aggregations)
    • `wisent shell start` (interactive 'wisent' prompt)
    """
    # --version short-circuits everything else, including the verbose flag.
    if version:
        typer.echo(f"{APP_NAME} {APP_VERSION}")
        raise typer.Exit()

    STATE["verbose"] = verbose

    bare_invocation = ctx.invoked_subcommand is None
    banner_context = bare_invocation or ctx.info_name in {"--help", None}
    if banner_context and not no_banner:
        # A missing (or zero) --logo-width falls back to 48 columns.
        print_banner(APP_NAME, width=logo_width or 48)

    # Without a subcommand there is nothing to run: show the help and stop.
    if bare_invocation:
        typer.echo(ctx.get_help())
        raise typer.Exit()
61
-
62
@app.command("instructions")
def instructions():
    # Documented with comments only: a docstring here would become the
    # command's help text in Typer and change the CLI output.
    #
    # Prints a quickstart cheat sheet. The text uses Rich console markup and is
    # shown inside a green panel when that is possible; on any failure the raw
    # text is printed instead.
    quickstart = """
[bold]Quickstart (no-dash style)[/]

• Preview a run without executing:
[b]wisent train model gpt2 plan-only true[/]

• Discover components:
[b]wisent list list-methods[/] | [b]wisent list list-loaders[/] | [b]wisent list list-aggregations[/]

• Get help:
[b]wisent help train[/], [b]wisent help method caa[/], [b]wisent help loader custom[/]

• Full training example (natural):
[b]wisent train model meta-llama/Llama-3.2-1B-Instruct loader custom path ./wisent_guard/cli/cli_examples/custom_dataset.json training_limit 5 \\
method caa layers 10..12 aggregation continuation_token device cuda dtype float16 save_dir ./steering_output normalize_layers true[/]

Tip: add [b]interactive true[/] for a guided wizard.
"""
    try:
        from rich.panel import Panel
        from wisent.cli.wisent_cli.ui import echo

        framed = Panel.fit(quickstart, title="Instructions", border_style="green")
        echo(framed)
    except Exception:
        # Best-effort fallback: plain output rather than failing the command.
        print(quickstart)
88
-
89
def run() -> None:
    """Run the root Typer application."""
    app()


# Allow executing this module directly as a script.
if __name__ == "__main__":
    run()
@@ -1,80 +0,0 @@
1
- from __future__ import annotations
2
- import shlex
3
- import typer
4
-
5
- from wisent.cli.wisent_cli.ui import print_banner, echo
6
-
7
- try:
8
- from rich.panel import Panel
9
- HAS_RICH = True
10
- except Exception:
11
- HAS_RICH = False
12
-
13
- __all__ = ["app", "start"]
14
-
15
- app = typer.Typer(help="Interactive shell")
16
-
17
def _run_cli_line(line: str) -> None:
    """Execute one line typed at the interactive prompt as a ``wisent`` command.

    The line is tokenised with :func:`shlex.split` and dispatched to the root
    Typer application in non-standalone mode, in which Click raises its errors
    instead of handling them itself. Mistakes the user can correct (unbalanced
    quotes, unknown commands, bad options, an aborted command) are reported
    and swallowed so the calling shell loop keeps running.

    Args:
        line: Raw command text, e.g. ``"help train"``.

    Raises:
        SystemExit: Re-raised when a command exits with a non-zero status.
        Exception: Any other failure raised by the command itself.
    """
    # Imported lazily: the main module imports this module's ``app``, so a
    # top-level import of main here would be circular.
    from typer.main import get_command
    from wisent.cli.wisent_cli.main import app as root_app

    try:
        args = shlex.split(line)
    except ValueError as exc:
        # shlex raises ValueError for e.g. an unterminated quote.
        typer.echo(f"Could not parse command: {exc}", err=True)
        return

    click_cmd = get_command(root_app)
    try:
        click_cmd.main(args=args, standalone_mode=False, prog_name="wisent")
    except SystemExit as exc:
        # A clean exit (code 0 / None) must not terminate the shell.
        if exc.code not in (0, None):
            raise
    except typer.Abort:
        # Click turns Ctrl-C / EOF inside a command into Abort.
        typer.echo("Aborted.", err=True)
    except Exception as exc:
        # Click's user-facing errors (ClickException and subclasses such as
        # UsageError) provide show(), which prints the formatted message.
        # They are recognised by that method so click need not be imported
        # here; anything without it is a genuine failure and propagates.
        show = getattr(exc, "show", None)
        if not callable(show):
            raise
        show()
27
-
28
@app.command("start")
def start(
    logo_width: int = typer.Option(48, "--logo-width", "-w", help="Logo width for the banner (28–96)."),
    show_banner: bool = typer.Option(True, "--banner/--no-banner", help="Show banner when the shell starts."),
):
    """
    Launch the interactive **wisent** shell.

    Inside the shell, run:
    • `help` / `help train` / `help method caa`
    • `instructions`
    • `train model ...` (no dashes)

    Type `exit`, `quit`, or press `Ctrl-D` to leave.
    """
    if show_banner:
        print_banner("Wisent Guard", logo_width)

    hint = "Type 'help', 'help train', 'instructions', or any command like 'train model ...'. Type 'exit' to quit."

    # Pick the prompt reader once; the original rebuilt a Console on every
    # iteration of the loop below.
    if HAS_RICH:
        from rich.console import Console

        echo(Panel(hint, title="Welcome to the wisent shell", border_style="green"))
        console = Console()

        def read_line() -> str:
            return console.input("[bold green]wisent[/] ")
    else:
        print(hint)
        GREEN, OFF = ("\x1b[32m", "\x1b[0m")

        def read_line() -> str:
            return input(f"{GREEN}wisent{OFF} ")

    # Shell-only spellings mapped onto the real CLI invocation. Every other
    # line (including "help <topic>") is passed through unchanged.
    aliases = {
        "help": "--help",
        "--help": "--help",
        "-h": "--help",
        "instructions": "instructions",
        "--instructions": "instructions",
    }
    leave_words = {"exit", "quit", "q"}

    while True:
        try:
            line = read_line()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session.
            typer.echo("\nBye.")
            break

        line = line.strip()
        if not line:
            continue
        if line in leave_words:
            typer.echo("Bye.")
            break

        _run_cli_line(aliases.get(line, line))
@@ -1,69 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Any
4
-
5
- try:
6
- from rich.console import Console
7
- from rich.panel import Panel
8
- from rich.text import Text
9
- HAS_RICH = True
10
- _console = Console()
11
- except Exception: # pragma: no cover
12
- HAS_RICH = False
13
- _console = None
14
-
15
- __all__ = ["print_banner", "echo"]
16
-
17
- def _render_wisent_mark(width: int = 48) -> str:
18
- logo = """
19
- ................. .:--++*##%%%%##**+=-:. .................
20
- .. .:=*%@@@@@@@%%%%%%%@@@@@@%*=:. ..
21
- . .-*%@@@%#+=-::.........:-=+#%@@@%*=. .
22
- . -*%@@@#=:. .:=*%@@@*-. .
23
- . .-#@@@*=. .-*@@@#-. .
24
- . :#@@@*: :+%@@#- .
25
- . .+@@@*: :+@@@+. .
26
- . .*@@@@%*=:. -%@@#: .
27
- . .#@@#=*%@@@%*-:. .#@@%: .
28
- ..*@@%. .-+#@@@@#+-:. .*@@%..
29
- .=@@@- :-+#@@@@%*=:. .%@@*.
30
- :#@@+ .:-+#@@@@%#+=:. -@@@-
31
- =@@@: .-=*%@@@@%#+=:.. .#@@+
32
- +@@@*=:. .:-+*%@@@@%#*=-:.. *@@+
33
- +@@@@@@#+-.. .:-=*#@@@@@%#*+--.. +@@+
34
- +@@#-+%@@@%: .:-=*#%@@@@@%#*+=-:.*@@+
35
- =@@%. .=@@@: ..:-=+#%%@@@@@%@@@+
36
- :%@@= :@@@- ..::-=+#@@@=
37
- .+@@%. .#@@* +@@#:
38
- ..%@@*. =@@@: =@@@-.
39
- . :%@@*..#@@#. .:.. =@@@= .
40
- . :%@@*.:%@@*. :#@@%#*+=-::..+@@@= .
41
- . :#@@%-:%@@#: .+@@@#%%@@@@@@%%@@%- .
42
- . .+@@@*=#@@%- .=%@@%=...::-=#@@@@*. .
43
- . :*@@@#%@@@*: .=%@@@+. .:*%@@#- .
44
- . :+%@@@@@@@*-. :=*@@@%+. .-+%@@@*-. .
45
- . .=*%@@@@@@#+:.:-+#@@@%*-. .:-+#%@@@#+: .
46
- . .-+#%@@@@@@@@@@@@#*+**#@@@@@%*=:. .
47
- .............. ..-=+*#%%%@@@@@@@@%%#*=-:. ..............
48
- ................... ....:::::::::.... ...................
49
- """.strip("\n")
50
- return "\n".join(line.center(width) for line in logo.splitlines())
51
-
52
-
53
def print_banner(title: str, width: int | None = None) -> None:
    """Print the wisent badge followed by a one-line tagline.

    With Rich available the badge is drawn inside a green panel titled
    *title* (badge width 64 when *width* is None). Otherwise the badge is
    printed in ANSI green (width 48 when *width* is None or 0) with *title*
    prefixed to the tagline.
    """
    if not HAS_RICH:
        # Plain-terminal fallback using raw ANSI colour codes.
        green_on, reset = "\x1b[32m", "\x1b[0m"
        print(green_on + _render_wisent_mark(width or 48) + reset)
        print(f"{title} – Steering vectors & activation tooling\n")
        return

    badge_width = width if width is not None else 64
    badge_text = Text(_render_wisent_mark(badge_width), style="green")
    _console.print(Panel.fit(badge_text, title=f"[b green]{title}[/]", border_style="green"))
    _console.print("[dim]Steering vectors & activation tooling[/]\n")
63
-
64
-
65
def echo(obj: Any) -> None:
    """Print *obj* through the shared Rich console, or plain ``print`` without Rich."""
    emit = _console.print if HAS_RICH else print
    emit(obj)
@@ -1,43 +0,0 @@
1
- from __future__ import annotations
2
- from typing import Dict, Optional
3
-
4
- __all__ = ["agg_cls", "descriptions", "name_map", "pick"]
5
-
6
def agg_cls():
    """Return the ``ActivationAggregationStrategy`` class, imported on first use."""
    from wisent.core.activations.core.atoms import (
        ActivationAggregationStrategy as strategy_cls,
    )

    return strategy_cls
9
-
10
def descriptions() -> Dict[object, str]:
    """Return a one-line description for each supported aggregation strategy.

    Keys are members of the class returned by ``agg_cls()``; insertion order
    is the order listed below.
    """
    strategy = agg_cls()
    catalogue = (
        (strategy.CHOICE_TOKEN, "Target A/B choice tokens (multiple choice)."),
        (strategy.CONTINUATION_TOKEN, "Use the first token of the continuation."),
        (strategy.LAST_TOKEN, "Always select the last token."),
        (strategy.FIRST_TOKEN, "Always select the first token."),
        (strategy.MEAN_POOLING, "Aggregate by mean over all tokens."),
        (strategy.MAX_POOLING, "Aggregate by max over all tokens."),
    )
    return dict(catalogue)
20
-
21
def name_map() -> Dict[str, object]:
    """Return the lookup table from accepted names to aggregation strategies.

    Contains the lower-cased member name of every described strategy, followed
    by the short aliases ``cont``, ``choice``, ``mean``, ``max``, ``first``
    and ``last``.
    """
    strategy = agg_cls()

    by_member_name = {}
    for member in descriptions():
        by_member_name[member.name.lower()] = member

    shorthand = {
        "cont": strategy.CONTINUATION_TOKEN,
        "choice": strategy.CHOICE_TOKEN,
        "mean": strategy.MEAN_POOLING,
        "max": strategy.MAX_POOLING,
        "first": strategy.FIRST_TOKEN,
        "last": strategy.LAST_TOKEN,
    }
    # Aliases are merged last, exactly like dict.update on the base table.
    return {**by_member_name, **shorthand}
34
-
35
def pick(name: Optional[str]):
    """Resolve a user-supplied aggregation name to its strategy member.

    Matching ignores case and surrounding whitespace. A missing, empty or
    whitespace-only name selects the default, ``CONTINUATION_TOKEN``.

    Args:
        name: Full member name (e.g. ``"mean_pooling"``) or a short alias
            (e.g. ``"mean"``); may be ``None``.

    Returns:
        The matching member of the class returned by ``agg_cls()``.

    Raises:
        ValueError: If the name is not recognised; the message lists every
            valid name.
    """
    # Normalise before testing for emptiness so that "   " is treated like ""
    # instead of being rejected as an unknown aggregation.
    key = (name or "").strip().lower()
    if not key:
        return agg_cls().CONTINUATION_TOKEN

    mapping = name_map()
    if key not in mapping:
        valid = ", ".join(sorted(mapping))
        raise ValueError(f"Unknown aggregation {name!r}. Valid: {valid}")
    return mapping[key]