wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.2.dist-info/METADATA +67 -0
  215. wisent-0.5.2.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,153 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any
5
+ from collections import defaultdict
6
+ import inspect
7
+
8
+ import torch
9
+
10
+ from wisent.core.activations.core.atoms import LayerActivations, RawActivationMap, LayerName
11
+ from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
12
+
13
__all__ = [
    "SteeringError",
    "BaseSteeringError",
    "BaseSteeringMethod",
    "PerLayerBaseSteeringMethod",
]


class BaseSteeringError(RuntimeError):
    """Raised when a steering method fails or is misconfigured."""


# `__all__` has always advertised `SteeringError` (and `BaseSteeringMethod.get`
# documents raising it), but only `BaseSteeringError` was defined, so a
# star-import of this module failed. Keep both names bound to the same class.
SteeringError = BaseSteeringError
21
+
22
class BaseSteeringMethod(ABC):
    """
    Abstract base for steering methods, with an automatic name registry.

    Every concrete (non-abstract) subclass is recorded in ``_REGISTRY`` under its
    ``name`` class attribute when the subclass is created. It can later be
    retrieved with ``BaseSteeringMethod.get(name)``. Abstract intermediate
    subclasses are not registered.

    Keyword arguments given to the constructor are stored unchanged in
    ``self.kwargs`` for subclasses to read.
    """

    name: str = "base"
    description: str = "Abstract steering method"
    # One dict for the whole hierarchy: every subclass registers into this same object.
    _REGISTRY: dict[str, type[BaseSteeringMethod]] = {}

    def __init_subclass__(cls, **kwargs):
        """
        Register each concrete subclass under its ``name``.

        raises:
            TypeError if a concrete subclass does not override ``name``.
            ValueError if a class is already registered under the same name.
        """
        super().__init_subclass__(**kwargs)
        # Abstract intermediates (classes that still declare abstract methods)
        # cannot be instantiated, so they are not registered.
        if inspect.isabstract(cls):
            return
        name = getattr(cls, "name", None)
        # `name` is always inherited from this base (default "base"). A truthiness
        # check alone therefore never fires, and a subclass that forgot to set it
        # would silently be registered as "base". Reject the default explicitly.
        if not name or name == BaseSteeringMethod.name:
            raise TypeError(
                f"BaseSteeringMethod subclasses must define `name` ({cls.__qualname__} does not)."
            )
        if name in BaseSteeringMethod._REGISTRY:
            raise ValueError(f"Duplicate steering method: {name!r}")
        BaseSteeringMethod._REGISTRY[name] = cls

    def __init__(self, **kwargs: Any) -> None:
        # Copy so later mutation of the caller's dict does not leak in.
        self.kwargs: dict[str, Any] = dict(kwargs)

    @abstractmethod
    def train(self, pair_set: ContrastivePairSet) -> LayerActivations:
        """
        Produce per-layer vectors from the given contrastive set.

        arguments:
            pair_set: ContrastivePairSet with collected activations.

        returns:
            LayerActivations with one steering vector per layer.
        """
        ...

    @classmethod
    def list_registered(cls) -> dict[str, type[BaseSteeringMethod]]:
        """
        List all registered steering methods.

        returns:
            A new dict mapping method name to class. Mutating it does not affect
            the registry.
        """
        return dict(cls._REGISTRY)

    @classmethod
    def get(cls, name: str) -> type[BaseSteeringMethod]:
        """
        Get a registered steering method class by name.

        arguments:
            name: str name of the steering method.

        returns:
            BaseSteeringMethod subclass.

        raises:
            BaseSteeringError if name is unknown.
        """
        try:
            return cls._REGISTRY[name]
        except KeyError as exc:
            raise BaseSteeringError(f"Unknown steering method: {name!r}") from exc
83
+
84
+
85
class PerLayerBaseSteeringMethod(BaseSteeringMethod):
    """
    Base for steering methods that compute one vector per layer independently.

    ``train`` groups the positive/negative activations of every pair by layer
    name, then calls ``train_for_layer`` once per layer.
    Subclasses must implement 'train_for_layer'.
    """

    @abstractmethod
    def train_for_layer(self, pos_list: list[torch.Tensor], neg_list: list[torch.Tensor]) -> torch.Tensor:
        """
        Compute a vector for ONE layer from lists of positives/negatives.

        arguments:
            pos_list: list of tensors from positive examples.
            neg_list: list of tensors from negative examples.

        returns:
            torch.Tensor steering vector for the layer.
        """
        ...

    def _collect_from_set(self, pair_set: ContrastivePairSet) -> dict[LayerName, tuple[list[torch.Tensor], list[torch.Tensor]]]:
        """
        Build {layer_name: ([pos tensors...], [neg tensors...])} by iterating pairs.

        A pair is skipped entirely when either response has no
        ``layers_activations``. A layer contributes for a given pair only when
        BOTH sides hold a ``torch.Tensor`` for it. The two lists of every bucket
        are appended together, so they always have equal length and are
        index-aligned: the i-th positive and the i-th negative come from the
        same pair.

        arguments:
            pair_set: ContrastivePairSet with collected activations.

        returns:
            dict mapping layer names to tuples of (list of pos tensors, list of neg tensors).
        """
        buckets: dict[LayerName, tuple[list[torch.Tensor], list[torch.Tensor]]] = defaultdict(lambda: ([], []))
        for pair in pair_set.pairs:  # ContrastivePair
            pos_la = getattr(pair.positive_response, "layers_activations", None)
            neg_la = getattr(pair.negative_response, "layers_activations", None)
            if pos_la is None or neg_la is None:
                continue

            # Convert each side once per pair; previously to_dict() was re-run
            # on every iteration of the layer loop.
            pos_map = pos_la.to_dict()
            neg_map = neg_la.to_dict()

            # Only layers present on both sides can yield a usable (pos, neg) pair.
            for layer in set(pos_map.keys()) & set(neg_map.keys()):
                p = pos_map.get(layer)
                n = neg_map.get(layer)
                if isinstance(p, torch.Tensor) and isinstance(n, torch.Tensor):
                    pos_bucket, neg_bucket = buckets[layer]
                    pos_bucket.append(p)
                    neg_bucket.append(n)

        # Hand back a plain dict, so a later lookup of an unknown layer raises
        # instead of silently inserting an empty bucket.
        return dict(buckets)

    def train(self, pair_set: ContrastivePairSet) -> LayerActivations:
        """
        Produce per-layer steering vectors from the given contrastive set.

        Layers are processed in order of (name length, name). For digit-only
        string names this gives numeric order (e.g. "2" before "10"). The
        ``dtype`` and ``activation_aggregation_strategy`` entries of
        ``self.kwargs`` (both default None) are forwarded to ``LayerActivations``.

        arguments:
            pair_set: ContrastivePairSet with collected activations.

        returns:
            LayerActivations with one steering vector per layer.
        """
        buckets = self._collect_from_set(pair_set)

        raw: RawActivationMap = {}
        # NOTE(review): len() in the sort key assumes layer names are strings — confirm LayerName type.
        for layer, (pos_list, neg_list) in sorted(buckets.items(), key=lambda kv: (len(kv[0]), kv[0])):
            # Defensive: an overriding _collect_from_set might produce an empty side.
            if not pos_list or not neg_list:
                continue
            raw[layer] = self.train_for_layer(pos_list, neg_list)

        dtype = self.kwargs.get("dtype", None)
        agg = self.kwargs.get("activation_aggregation_strategy", None)
        return LayerActivations(raw, activation_aggregation_strategy=agg, dtype=dtype)
File without changes
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import List
4
+ import torch
5
+
6
+ from wisent.core.steering_methods.core.atoms import PerLayerBaseSteeringMethod
7
+
8
+ __all__ = [
9
+ "CAAMethod",
10
+ ]
11
+
12
class CAAMethod(PerLayerBaseSteeringMethod):
    """
    Contrastive Activation Addition (CAA) steering.

    For each layer, the steering vector is the mean of the flattened positive
    activations minus the mean of the flattened negative activations. The
    result is scaled to unit L2 norm unless the constructor received
    ``normalize=False``. The ``dtype`` and ``activation_aggregation_strategy``
    kwargs are read by ``PerLayerBaseSteeringMethod.train`` when it builds the
    final ``LayerActivations``.
    """

    name = "caa"
    description = "Per-layer mean(pos)-mean(neg) over ContrastivePairSet."

    def train_for_layer(self, pos_list: List[torch.Tensor], neg_list: List[torch.Tensor]) -> torch.Tensor:
        """
        Compute the CAA direction for a single layer.

        arguments:
            pos_list: activations (torch.Tensor) of the positive examples for this layer.
            neg_list: activations (torch.Tensor) of the negative examples for this layer.

        returns:
            1-D float32 CPU tensor: mean(pos) - mean(neg), unit-normalized by default.

        raises:
            ValueError if either list is empty.
        """
        if not (pos_list and neg_list):
            raise ValueError("Both positive and negative lists must be non-empty.")

        def _as_matrix(tensors: List[torch.Tensor]) -> torch.Tensor:
            # Detach, move to CPU, upcast to float32 and flatten every example,
            # then stack the rows into an [N, H] matrix.
            rows = [t.detach().to("cpu").float().reshape(-1) for t in tensors]
            return torch.stack(rows, dim=0)

        pos_mean = _as_matrix(pos_list).mean(dim=0)
        neg_mean = _as_matrix(neg_list).mean(dim=0)
        direction = pos_mean - neg_mean

        wants_unit_norm = bool(self.kwargs.get("normalize", True))
        return self._safe_l2_normalize(direction) if wants_unit_norm else direction

    def _safe_l2_normalize(self, v: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
        """Flatten ``v`` to 1-D if needed and divide by its L2 norm; ``eps`` avoids division by zero."""
        flat = v if v.ndim == 1 else v.reshape(-1)
        return flat / (torch.linalg.norm(flat) + eps)