wisent 0.5.12__py3-none-any.whl → 0.5.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -1
- wisent/core/activations/__init__.py +26 -0
- wisent/core/activations/activations.py +96 -0
- wisent/core/activations/activations_collector.py +71 -20
- wisent/core/activations/prompt_construction_strategy.py +47 -0
- wisent/core/agent/__init__.py +1 -18
- wisent/core/agent/budget.py +2 -2
- wisent/core/agent/device_benchmarks.py +1 -1
- wisent/core/agent/diagnose/__init__.py +1 -55
- wisent/core/agent/diagnose/classifier_marketplace.py +8 -8
- wisent/core/agent/diagnose/response_diagnostics.py +4 -4
- wisent/core/agent/diagnose/synthetic_classifier_option.py +1 -1
- wisent/core/agent/diagnose/tasks/task_manager.py +3 -3
- wisent/core/agent/diagnose.py +2 -1
- wisent/core/autonomous_agent.py +10 -2
- wisent/core/benchmark_extractors.py +293 -0
- wisent/core/bigcode_integration.py +20 -7
- wisent/core/branding.py +108 -0
- wisent/core/cli/__init__.py +15 -0
- wisent/core/cli/create_steering_vector.py +138 -0
- wisent/core/cli/evaluate_responses.py +715 -0
- wisent/core/cli/generate_pairs.py +128 -0
- wisent/core/cli/generate_pairs_from_task.py +119 -0
- wisent/core/cli/generate_responses.py +129 -0
- wisent/core/cli/generate_vector_from_synthetic.py +149 -0
- wisent/core/cli/generate_vector_from_task.py +147 -0
- wisent/core/cli/get_activations.py +191 -0
- wisent/core/cli/optimize_classification.py +339 -0
- wisent/core/cli/optimize_steering.py +364 -0
- wisent/core/cli/tasks.py +182 -0
- wisent/core/cli_logger.py +22 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +27 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +49 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_challenge.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arc_easy.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/arithmetic.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/asdiv.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/boolq.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/cb.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/copa.py +118 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/coqa.py +146 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/drop.py +129 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/gsm8k.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/headqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/hellaswag.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/livecodebench.py +367 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/logiqa2.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mc-taco.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/medqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mrpc.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/multirc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/mutual.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/openbookqa.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pawsx.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/piqa.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/prost.py +113 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/pubmedqa.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qa4mre.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qasper.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/qqp.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/race.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/record.py +121 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/rte.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sciq.py +110 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/social_iqa.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/squad2.py +124 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/sst2.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/swag.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/triviaqa.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_gen.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc1.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/truthfulqa_mc2.py +117 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/webqs.py +127 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wic.py +119 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +1 -1
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wnli.py +111 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/wsc.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xnli.py +112 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xstorycloze.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/xwinograd.py +114 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +1 -1
- wisent/core/data_loaders/__init__.py +235 -0
- wisent/core/data_loaders/loaders/lm_loader.py +2 -2
- wisent/core/data_loaders/loaders/task_interface_loader.py +300 -0
- wisent/{cli/data_loaders/data_loader_rotator.py → core/data_loaders/rotator.py} +1 -1
- wisent/core/download_full_benchmarks.py +79 -2
- wisent/core/evaluators/benchmark_specific/__init__.py +26 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/evaluator.py +17 -17
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/cpp_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/java_sanitizer.py +2 -2
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/python_sanitizer.py +2 -2
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/__init__.py +3 -0
- wisent/core/evaluators/benchmark_specific/coding/providers/livecodebench/provider.py +305 -0
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/runtime.py +36 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/entrypoint.py +2 -4
- wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/recipes.py +1 -1
- wisent/core/evaluators/benchmark_specific/coding/solution_generator.py +258 -0
- wisent/core/evaluators/benchmark_specific/exact_match_evaluator.py +79 -0
- wisent/core/evaluators/benchmark_specific/f1_evaluator.py +101 -0
- wisent/core/evaluators/benchmark_specific/generation_evaluator.py +197 -0
- wisent/core/{log_likelihoods_evaluator.py → evaluators/benchmark_specific/log_likelihoods_evaluator.py} +10 -2
- wisent/core/evaluators/benchmark_specific/perplexity_evaluator.py +140 -0
- wisent/core/evaluators/benchmark_specific/personalization_evaluator.py +250 -0
- wisent/{cli/evaluators/evaluator_rotator.py → core/evaluators/rotator.py} +4 -4
- wisent/core/lm_eval_harness_ground_truth.py +3 -2
- wisent/core/main.py +57 -0
- wisent/core/model_persistence.py +2 -2
- wisent/core/models/wisent_model.py +6 -6
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +2 -2
- wisent/core/optuna/steering/steering_optimization.py +1 -1
- wisent/core/parser_arguments/__init__.py +10 -0
- wisent/core/parser_arguments/agent_parser.py +110 -0
- wisent/core/parser_arguments/configure_model_parser.py +7 -0
- wisent/core/parser_arguments/create_steering_vector_parser.py +59 -0
- wisent/core/parser_arguments/evaluate_parser.py +40 -0
- wisent/core/parser_arguments/evaluate_responses_parser.py +10 -0
- wisent/core/parser_arguments/full_optimize_parser.py +115 -0
- wisent/core/parser_arguments/generate_pairs_from_task_parser.py +33 -0
- wisent/core/parser_arguments/generate_pairs_parser.py +29 -0
- wisent/core/parser_arguments/generate_responses_parser.py +15 -0
- wisent/core/parser_arguments/generate_vector_from_synthetic_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_from_task_parser.py +127 -0
- wisent/core/parser_arguments/generate_vector_parser.py +90 -0
- wisent/core/parser_arguments/get_activations_parser.py +90 -0
- wisent/core/parser_arguments/main_parser.py +152 -0
- wisent/core/parser_arguments/model_config_parser.py +59 -0
- wisent/core/parser_arguments/monitor_parser.py +17 -0
- wisent/core/parser_arguments/multi_steer_parser.py +47 -0
- wisent/core/parser_arguments/optimize_classification_parser.py +67 -0
- wisent/core/parser_arguments/optimize_sample_size_parser.py +58 -0
- wisent/core/parser_arguments/optimize_steering_parser.py +147 -0
- wisent/core/parser_arguments/synthetic_parser.py +93 -0
- wisent/core/parser_arguments/tasks_parser.py +584 -0
- wisent/core/parser_arguments/test_nonsense_parser.py +26 -0
- wisent/core/parser_arguments/utils.py +111 -0
- wisent/core/prompts/core/prompt_formater.py +3 -3
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +2 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +2 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +2 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +2 -0
- wisent/{cli/steering_methods/steering_rotator.py → core/steering_methods/rotator.py} +4 -4
- wisent/core/steering_optimizer.py +45 -21
- wisent/{synthetic → core/synthetic}/cleaners/deduper_cleaner.py +3 -3
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_dedupers.py +2 -2
- wisent/{synthetic → core/synthetic}/cleaners/methods/base_refusalers.py +1 -1
- wisent/{synthetic → core/synthetic}/cleaners/pairs_cleaner.py +5 -5
- wisent/{synthetic → core/synthetic}/cleaners/refusaler_cleaner.py +4 -4
- wisent/{synthetic → core/synthetic}/db_instructions/mini_dp.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/diversities/methods/fast_diversity.py +1 -1
- wisent/{synthetic → core/synthetic}/generators/pairs_generator.py +38 -12
- wisent/core/tasks/livecodebench_task.py +4 -103
- wisent/core/timing_calibration.py +1 -1
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/METADATA +3 -3
- wisent-0.5.14.dist-info/RECORD +294 -0
- wisent-0.5.14.dist-info/entry_points.txt +2 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +0 -53
- wisent/classifiers/core/atoms.py +0 -747
- wisent/classifiers/models/logistic.py +0 -29
- wisent/classifiers/models/mlp.py +0 -47
- wisent/cli/classifiers/classifier_rotator.py +0 -137
- wisent/cli/cli_logger.py +0 -142
- wisent/cli/wisent_cli/commands/help_cmd.py +0 -52
- wisent/cli/wisent_cli/commands/listing.py +0 -154
- wisent/cli/wisent_cli/commands/train_cmd.py +0 -322
- wisent/cli/wisent_cli/main.py +0 -93
- wisent/cli/wisent_cli/shell.py +0 -80
- wisent/cli/wisent_cli/ui.py +0 -69
- wisent/cli/wisent_cli/util/aggregations.py +0 -43
- wisent/cli/wisent_cli/util/parsing.py +0 -126
- wisent/cli/wisent_cli/version.py +0 -4
- wisent/opti/methods/__init__.py +0 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent-0.5.12.dist-info/RECORD +0 -220
- /wisent/{benchmarks → core/evaluators/benchmark_specific/coding}/__init__.py +0 -0
- /wisent/{benchmarks/coding → core/evaluators/benchmark_specific/coding/metrics}/__init__.py +0 -0
- /wisent/{benchmarks/coding/metrics → core/evaluators/benchmark_specific/coding/metrics/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/metrics/passk.py +0 -0
- /wisent/{benchmarks/coding/metrics/core → core/evaluators/benchmark_specific/coding/output_sanitizer}/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/output_sanitizer/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/core/atoms.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/output_sanitizer/utils.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/__init__.py +0 -0
- /wisent/{benchmarks/coding/output_sanitizer → core/evaluators/benchmark_specific/coding/providers}/core/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/providers/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/providers/core → core/evaluators/benchmark_specific/coding/safe_docker}/__init__.py +0 -0
- /wisent/{benchmarks/coding/providers/livecodebench → core/evaluators/benchmark_specific/coding/safe_docker/core}/__init__.py +0 -0
- /wisent/{benchmarks → core/evaluators/benchmark_specific}/coding/safe_docker/core/atoms.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/__init__.py +0 -0
- /wisent/{benchmarks/coding/safe_docker → core/opti}/core/__init__.py +0 -0
- /wisent/{opti → core/opti}/core/atoms.py +0 -0
- /wisent/{classifiers → core/opti/methods}/__init__.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_classificator.py +0 -0
- /wisent/{opti → core/opti}/methods/opti_steering.py +0 -0
- /wisent/{classifiers/core → core/synthetic}/__init__.py +0 -0
- /wisent/{classifiers/models → core/synthetic/cleaners}/__init__.py +0 -0
- /wisent/{cli → core/synthetic/cleaners/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/core/atoms.py +0 -0
- /wisent/{cli/classifiers → core/synthetic/cleaners/methods}/__init__.py +0 -0
- /wisent/{cli/data_loaders → core/synthetic/cleaners/methods/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/cleaners/methods/core/atoms.py +0 -0
- /wisent/{cli/evaluators → core/synthetic/db_instructions}/__init__.py +0 -0
- /wisent/{cli/steering_methods → core/synthetic/db_instructions/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/db_instructions/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli → core/synthetic/generators}/__init__.py +0 -0
- /wisent/{cli/wisent_cli/commands → core/synthetic/generators/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/core/atoms.py +0 -0
- /wisent/{cli/wisent_cli/util → core/synthetic/generators/diversities}/__init__.py +0 -0
- /wisent/{opti → core/synthetic/generators/diversities/core}/__init__.py +0 -0
- /wisent/{synthetic → core/synthetic}/generators/diversities/core/core.py +0 -0
- /wisent/{opti/core → core/synthetic/generators/diversities/methods}/__init__.py +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/WHEEL +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.5.12.dist-info → wisent-0.5.14.dist-info}/top_level.txt +0 -0
|
@@ -1,29 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
from torch import nn
|
|
5
|
-
|
|
6
|
-
from wisent.classifiers.core.atoms import BaseClassifier
|
|
7
|
-
|
|
8
|
-
__all__ = ["LogisticClassifier"]
|
|
9
|
-
|
|
10
|
-
class LogisticModel(nn.Module):
    """Single linear unit followed by a sigmoid, used to score activations.

    Args:
        input_dim: Size of the last dimension of the input features.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.linear = nn.Linear(input_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid probabilities with at least two dimensions."""
        scores = self.linear(x)
        # An unbatched input vector produces shape (1,); promote it to (1, 1).
        shaped = scores.unsqueeze(1) if scores.ndim == 1 else scores
        return self.sigmoid(shaped)
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class LogisticClassifier(BaseClassifier):
    """Registered classifier ``"logistic"``: one logistic-regression layer."""

    name = "logistic"
    description = "One-layer logistic regression over dense features"

    def build_model(self, input_dim: int, **_: object) -> nn.Module:
        """Create the underlying ``LogisticModel``; extra model params are ignored."""
        model = LogisticModel(input_dim)
        return model
|
wisent/classifiers/models/mlp.py
DELETED
|
@@ -1,47 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
from torch import nn
|
|
5
|
-
|
|
6
|
-
from wisent.classifiers.core.atoms import BaseClassifier
|
|
7
|
-
|
|
8
|
-
__all__ = ["MLPClassifier"]
|
|
9
|
-
|
|
10
|
-
class MLPModel(nn.Module):
    """Feed-forward scorer: two ReLU hidden layers with dropout, sigmoid output.

    Layer widths are ``input_dim -> hidden_dim -> hidden_dim // 2 -> 1``.

    Args:
        input_dim: Size of the last dimension of the input features.
        hidden_dim: Width of the first hidden layer.
        dropout: Dropout probability applied after each hidden layer.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 128, dropout: float = 0.2):
        super().__init__()
        half_width = hidden_dim // 2
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, half_width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half_width, 1),
            nn.Sigmoid(),
        ]
        # Stored as ``net`` so state_dict keys stay "net.<index>.*".
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid probabilities with at least two dimensions."""
        probs = self.net(x)
        # An unbatched input vector produces shape (1,); promote it to (1, 1).
        return probs.unsqueeze(1) if probs.ndim == 1 else probs
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class MLPClassifier(BaseClassifier):
    """Registered classifier ``"mlp"`` backed by ``MLPModel``.

    Args:
        hidden_dim: Default width of the first hidden layer (the second hidden
            layer is half of it). Can be overridden per build via ``build_model``.
        **base_kwargs: Forwarded unchanged to ``BaseClassifier.__init__``.
    """

    name = "mlp"
    description = "Two-layer MLP with dropout and ReLU"

    # Dropout used when build_model() is not given an explicit "dropout".
    _DEFAULT_DROPOUT = 0.2

    def __init__(self, *, hidden_dim: int = 128, **base_kwargs):
        super().__init__(**base_kwargs)
        self._hidden_dim = int(hidden_dim)
        # Dropout of the most recently built model (reported by model_hyperparams).
        self._dropout = self._DEFAULT_DROPOUT

    def build_model(self, input_dim: int, **model_params: object) -> nn.Module:
        """Create an ``MLPModel`` for ``input_dim`` features.

        ``model_params`` may override ``hidden_dim`` (sticky: it becomes the
        default for later builds) and ``dropout`` (defaults to 0.2 on every
        build). The values used are recorded so ``model_hyperparams`` matches
        the network that was actually built.
        """
        hd = int(model_params.get("hidden_dim", self._hidden_dim))
        dp = float(model_params.get("dropout", self._DEFAULT_DROPOUT))
        self._hidden_dim = hd
        self._dropout = dp
        return MLPModel(input_dim, hidden_dim=hd, dropout=dp)

    def model_hyperparams(self) -> dict[str, int | float]:
        """Return the hyperparameters of the most recently built model.

        Previously ``dropout`` was hard-coded to 0.2 here even when the model
        had been built with another value. NOTE(review): presumably these
        values are used to rebuild the model on load; confirm in BaseClassifier.
        """
        # getattr guards instances created without running __init__.
        dropout = getattr(self, "_dropout", self._DEFAULT_DROPOUT)
        return {"hidden_dim": self._hidden_dim, "dropout": dropout}
|
|
@@ -1,137 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import importlib
|
|
4
|
-
import importlib.util
|
|
5
|
-
import inspect
|
|
6
|
-
import pkgutil
|
|
7
|
-
from pathlib import Path
|
|
8
|
-
from typing import Any
|
|
9
|
-
|
|
10
|
-
from wisent.core.classifiers.core.atoms import BaseClassifier, ClassifierError, ClassifierTrainReport
|
|
11
|
-
|
|
12
|
-
__all__ = ["ClassifierRotator"]
|
|
13
|
-
|
|
14
|
-
class ClassifierRotator:
    """
    Discover, list, and delegate to registered classifiers.

    Unless ``autoload`` is False, construction imports the classifier modules
    found at ``classifiers_location`` so their BaseClassifier subclasses
    self-register. It then instantiates the requested classifier. ``fit``,
    ``predict``, ``predict_proba``, ``evaluate``, ``save_model``,
    ``load_model`` and ``set_threshold`` forward to that instance.
    """

    # Package scanned by default. The old default,
    # "wisent_guard.core.classifiers.models", named the previous distribution.
    # Classes defined there subclass wisent_guard's BaseClassifier rather than
    # the wisent.core.classifiers one imported by this module.
    # NOTE(review): confirm this is the package holding the classifier models.
    DEFAULT_LOCATION = "wisent.core.classifiers.models"

    def __init__(
        self,
        classifier: str | BaseClassifier | type[BaseClassifier] | None = None,
        classifiers_location: str | Path = DEFAULT_LOCATION,
        autoload: bool = True,
        **classifier_kwargs: Any,
    ) -> None:
        """
        Args:
            classifier: Registered name, instance, subclass, or None (first
                registered classifier by name).
            classifiers_location: Dotted package path or directory to import
                classifier modules from when ``autoload`` is True.
            autoload: Whether to run discovery before resolving ``classifier``.
            **classifier_kwargs: Constructor kwargs when a class/name is given.
        """
        if autoload:
            self.discover_classifiers(classifiers_location)
        self._classifier = self._resolve_classifier(classifier, **classifier_kwargs)

    @staticmethod
    def discover_classifiers(location: str | Path = DEFAULT_LOCATION) -> None:
        """
        Import all classifier modules so BaseClassifier subclasses self-register.

        - If `location` is a dotted module path (str without existing FS path),
          import that package and iterate its __path__ (works with namespace packages).
        - If `location` is an existing directory (Path/str), import all .py files inside.

        Raises:
            ClassifierError: ``location`` is neither a directory nor an
                importable dotted module path.
        """
        loc_path = Path(str(location))
        if loc_path.exists() and loc_path.is_dir():
            ClassifierRotator._import_all_py_in_dir(loc_path)
            return

        if not isinstance(location, str):
            raise ClassifierError(
                f"Invalid classifiers location: {location!r}. Provide a dotted module path or a directory."
            )

        try:
            pkg = importlib.import_module(location)
        except ModuleNotFoundError as exc:
            raise ClassifierError(
                f"Cannot import classifier package {location!r}. "
                "Use a dotted path (no leading slash) and ensure your project root is on PYTHONPATH."
            ) from exc

        search_paths = list(getattr(pkg, "__path__", []))
        if not search_paths:
            # Plain module (no __path__): scan the directory containing it.
            pkg_file = getattr(pkg, "__file__", None)
            if pkg_file:
                search_paths = [str(Path(pkg_file).parent)]

        for _finder, name, _ispkg in pkgutil.iter_modules(search_paths):
            if name.startswith("_"):
                continue
            importlib.import_module(f"{location}.{name}")

    @staticmethod
    def _import_all_py_in_dir(directory: Path) -> None:
        """Execute every public ``*.py`` file in ``directory`` as its own module.

        Modules are named ``_dyn_classifiers_<stem>``. They are registered in
        ``sys.modules`` before execution and imported in sorted order, so the
        import order is deterministic.
        """
        import sys

        for py in sorted(directory.glob("*.py")):
            if py.name.startswith("_"):
                continue
            mod_name = f"_dyn_classifiers_{py.stem}"
            spec = importlib.util.spec_from_file_location(mod_name, py)
            if spec is None or spec.loader is None:
                continue
            module = importlib.util.module_from_spec(spec)
            # Follow the importlib "import a source file directly" recipe:
            # register first so code that looks the module up in sys.modules
            # (dataclasses, pickling, typing.get_type_hints) works.
            sys.modules[mod_name] = module
            try:
                spec.loader.exec_module(module)
            except BaseException:
                sys.modules.pop(mod_name, None)
                raise

    @staticmethod
    def list_classifiers() -> list[dict[str, Any]]:
        """Return ``{"name", "description", "class"}`` for each registered classifier, sorted by name."""
        out: list[dict[str, Any]] = []
        for name, cls in BaseClassifier.list_registered().items():
            out.append(
                {
                    "name": name,
                    "description": getattr(cls, "description", ""),
                    "class": f"{cls.__module__}.{cls.__name__}",
                }
            )
        return sorted(out, key=lambda x: x["name"])

    @staticmethod
    def _resolve_classifier(
        classifier: str | BaseClassifier | type[BaseClassifier] | None,
        **kwargs: Any,
    ) -> BaseClassifier:
        """Turn a name / class / instance / None into a BaseClassifier instance.

        ``kwargs`` are passed to the constructor; they are ignored when an
        instance is supplied.

        Raises:
            ClassifierError: ``classifier`` is None and nothing is registered.
            TypeError: ``classifier`` has an unsupported type.
        """
        if classifier is None:
            registry = BaseClassifier.list_registered()
            if not registry:
                raise ClassifierError("No classifiers registered.")
            # Deterministic pick: first by name
            return next(iter(sorted(registry.items())))[1](**kwargs)
        if isinstance(classifier, BaseClassifier):
            return classifier
        if inspect.isclass(classifier) and issubclass(classifier, BaseClassifier):
            return classifier(**kwargs)
        if isinstance(classifier, str):
            cls = BaseClassifier.get(classifier)
            return cls(**kwargs)
        raise TypeError(
            "classifier must be None, a name (str), BaseClassifier instance, or BaseClassifier subclass."
        )

    def use(self, classifier: str | BaseClassifier | type[BaseClassifier], **kwargs: Any) -> None:
        """Switch the active classifier (same resolution rules as the constructor)."""
        self._classifier = self._resolve_classifier(classifier, **kwargs)

    def fit(self, X, y, **kwargs) -> ClassifierTrainReport:
        """Train the active classifier; forwards to its ``fit``."""
        return self._classifier.fit(X, y, **kwargs)

    def predict(self, X):
        """Forward to the active classifier's ``predict``."""
        return self._classifier.predict(X)

    def predict_proba(self, X):
        """Forward to the active classifier's ``predict_proba``."""
        return self._classifier.predict_proba(X)

    def evaluate(self, X, y) -> dict[str, float]:
        """Forward to the active classifier's ``evaluate``."""
        return self._classifier.evaluate(X, y)

    def save_model(self, path: str) -> None:
        """Forward to the active classifier's ``save_model``."""
        self._classifier.save_model(path)

    def load_model(self, path: str) -> None:
        """Forward to the active classifier's ``load_model``."""
        self._classifier.load_model(path)

    def set_threshold(self, threshold: float) -> None:
        """Forward to the active classifier's ``set_threshold``."""
        self._classifier.set_threshold(threshold)
|
wisent/cli/cli_logger.py
DELETED
|
@@ -1,142 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import json
|
|
4
|
-
import logging
|
|
5
|
-
import sys
|
|
6
|
-
from datetime import datetime, timezone
|
|
7
|
-
from typing import Any, Mapping
|
|
8
|
-
|
|
9
|
-
__all__ = [
|
|
10
|
-
"setup_logger",
|
|
11
|
-
"bind",
|
|
12
|
-
"JsonFormatter",
|
|
13
|
-
"ContextAdapter",
|
|
14
|
-
"add_file_handler",
|
|
15
|
-
]
|
|
16
|
-
|
|
17
|
-
class JsonFormatter(logging.Formatter):
    """
    Minimal JSON formatter with structured fields + extras.

    Each record becomes one JSON object with the keys ts (UTC ISO-8601),
    level, logger, message, file, func and line. Non-standard record
    attributes (e.g. those passed via ``extra=``) are nested under ``"extra"``.
    An exception traceback goes under ``"exc"`` and a stack dump under
    ``"stack"``.
    """

    # Built-in LogRecord attributes that must not be reported as "extra".
    # "message"/"asctime" are set by logging.Formatter.format() when another
    # handler formats the same record first. "taskName" exists on Python 3.12+.
    _STD = {
        "name", "msg", "args", "levelname", "levelno", "pathname",
        "filename", "module", "exc_info", "exc_text", "stack_info",
        "lineno", "funcName", "created", "msecs", "relativeCreated",
        "thread", "threadName", "processName", "process",
        "message", "asctime", "taskName",
    }

    def format(self, record: logging.LogRecord) -> str:
        payload: dict[str, Any] = {
            "ts": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "file": record.filename,
            "func": record.funcName,
            "line": record.lineno,
        }
        extras = {
            k: v for k, v in record.__dict__.items()
            if k not in self._STD and not k.startswith("_")
        }
        if extras:
            payload["extra"] = extras
        if record.exc_info:
            payload["exc"] = self.formatException(record.exc_info)
        if record.stack_info:
            # stack_info is excluded from extras, so emit it explicitly.
            payload["stack"] = self.formatStack(record.stack_info)
        # default=str: a formatter must not raise on a non-JSON extra value
        # (sets, Paths, tensors, ...); fall back to its string form instead.
        return json.dumps(payload, ensure_ascii=False, default=str)
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
class ContextAdapter(logging.LoggerAdapter):
    """
    LoggerAdapter that ensures persistent context fields appear in every log entry.

    The bound context (``self.extra``) is merged into any per-call ``extra=``.
    On key clashes the bound value wins, as before.
    """

    def process(self, msg, kwargs):
        # Copy instead of updating in place: the caller's dict must not be
        # mutated, and an explicit ``extra=None`` must not crash.
        extra = dict(kwargs.get("extra") or {})
        extra.update(self.extra or {})
        kwargs["extra"] = extra
        return msg, kwargs
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
class _EnsureContextFilter(logging.Filter):
|
|
61
|
-
"""
|
|
62
|
-
Adds default values for context keys so format strings never KeyError.
|
|
63
|
-
"""
|
|
64
|
-
def __init__(self, defaults: Mapping[str, Any] | None = None):
|
|
65
|
-
super().__init__()
|
|
66
|
-
self.defaults = dict(defaults or {})
|
|
67
|
-
|
|
68
|
-
def filter(self, record: logging.LogRecord) -> bool:
|
|
69
|
-
for k, v in self.defaults.items():
|
|
70
|
-
if not hasattr(record, k):
|
|
71
|
-
setattr(record, k, v)
|
|
72
|
-
return True
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
def setup_logger(
    name: str = "wisent",
    level: int = logging.INFO,
    *,
    json_logs: bool = False,
    stream=None,
) -> logging.Logger:
    """
    Create or return a named logger with a single stream handler.
    Safe to call multiple times; won't duplicate handlers.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        level: Level applied to the logger on every call.
        json_logs: Format records with JsonFormatter instead of plain text.
            Only takes effect when the handler is first created.
        stream: Destination for the handler. ``None`` (default) means the
            ``sys.stderr`` current at call time. The old default
            (``sys.stderr``) was captured once at import time, so a stderr
            replaced after import but before this call was ignored.

    Returns:
        The logger, with propagation to ancestor loggers disabled.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    if not logger.handlers:
        # StreamHandler(None) resolves sys.stderr now, not at import time.
        handler = logging.StreamHandler(stream)
        if json_logs:
            handler.setFormatter(JsonFormatter())
        else:
            handler.setFormatter(logging.Formatter(
                fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s "
                    "[file=%(filename)s func=%(funcName)s line=%(lineno)d] "
                    "%(task_name)s%(subtask)s",
                datefmt="%Y-%m-%dT%H:%M:%S%z",
            ))
        # ensure context placeholders always exist
        handler.addFilter(_EnsureContextFilter({"task_name": "", "subtask": ""}))
        logger.addHandler(handler)
    logger.propagate = False
    return logger
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
def add_file_handler(
    logger: logging.Logger,
    filepath: str,
    *,
    level: int | None = None,
    json_logs: bool = False,
) -> None:
    """
    Optionally add a file handler (e.g., for long-running CLI jobs).

    Args:
        logger: Logger to extend; its existing handlers are kept.
        filepath: File opened by ``logging.FileHandler`` with UTF-8 encoding.
        level: Handler level. ``None`` means "use the logger's level". An
            explicit ``0`` (``logging.NOTSET``) is honoured; the previous
            ``level or logger.level`` treated it as "not given".
        json_logs: Format records with JsonFormatter instead of plain text.
    """
    fh = logging.FileHandler(filepath, encoding="utf-8")
    fh.setLevel(logger.level if level is None else level)
    if json_logs:
        fh.setFormatter(JsonFormatter())
    else:
        fh.setFormatter(logging.Formatter(
            fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s "
                "[file=%(filename)s func=%(funcName)s line=%(lineno)d] "
                "%(task_name)s%(subtask)s",
            datefmt="%Y-%m-%dT%H:%M:%S%z",
        ))
    # ensure context placeholders always exist
    fh.addFilter(_EnsureContextFilter({"task_name": "", "subtask": ""}))
    logger.addHandler(fh)
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
def bind(
    logger: logging.Logger | ContextAdapter,
    **extra: Any
) -> ContextAdapter:
    """
    Return a ContextAdapter carrying ``extra`` on top of any context already bound.

    Given a plain Logger, wrap it directly. Given an existing ContextAdapter,
    build a new adapter over the same underlying logger, in which the new
    keys override the previously bound ones. The input adapter is not modified.
    """
    if not isinstance(logger, ContextAdapter):
        return ContextAdapter(logger, extra)
    combined = dict(logger.extra)
    combined.update(extra)
    return ContextAdapter(logger.logger, combined)
|
|
@@ -1,52 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
import shlex
|
|
3
|
-
from typing import Optional
|
|
4
|
-
import typer
|
|
5
|
-
|
|
6
|
-
__all__ = ["app", "help_command"]
|
|
7
|
-
|
|
8
|
-
app = typer.Typer(help="Human-friendly help router")
|
|
9
|
-
|
|
10
|
-
def _run_cli_line(line: str) -> None:
    """Run ``line`` through the root wisent CLI as if it were typed at a shell.

    A SystemExit with code 0 or None is swallowed; any other exit code is
    re-raised.
    """
    from typer.main import get_command
    from wisent.cli.wisent_cli.main import app as root_app

    command = get_command(root_app)
    argv = shlex.split(line)
    try:
        command.main(args=argv, standalone_mode=False, prog_name="wisent")
    except SystemExit as exc:
        if exc.code is None or exc.code == 0:
            return
        raise
|
|
20
|
-
|
|
21
|
-
@app.command("help")
def help_command(topic: Optional[str] = typer.Argument(None), name: Optional[str] = typer.Argument(None)):
    """
    Examples:
      wisent help train
      wisent help method caa
      wisent help loader custom
      wisent help list-methods
    """
    # Route "wisent help <topic> [name]" to the matching "--help" or
    # "explain" invocation; anything unrecognised shows the root help.
    topic_key = (topic or "").strip().lower()

    if topic_key in {"", "app", "main"}:
        _run_cli_line("--help")
        return

    if topic_key in {"train", "list-methods", "list-loaders", "list-aggregations", "explain", "instructions", "start"}:
        _run_cli_line(f"{topic_key} --help")
        return

    if name:
        quoted = shlex.quote(name)
        if topic_key in {"method", "methods"}:
            _run_cli_line(f"explain --method {quoted}")
            return
        if topic_key in {"loader", "loaders"}:
            _run_cli_line(f"explain --loader {quoted}")
            _run_cli_line(f"loader-args {quoted}")
            return
        if topic_key in {"aggregation", "agg", "aggregations"}:
            _run_cli_line(f"explain --aggregation {quoted}")
            return

    _run_cli_line("--help")
|
|
@@ -1,154 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
import inspect
|
|
3
|
-
from typing import Optional
|
|
4
|
-
|
|
5
|
-
import typer
|
|
6
|
-
|
|
7
|
-
from wisent.cli.wisent_cli.ui import echo
|
|
8
|
-
from wisent.cli.wisent_cli.util import aggregations as aggs
|
|
9
|
-
|
|
10
|
-
# rich is an optional dependency: when it cannot be imported, the commands in
# this module fall back to plain print() output (HAS_RICH is False).
# Only ImportError is treated as "rich unavailable"; any other failure while
# importing rich is a genuine bug and should not be masked.
try:
    from rich.table import Table
    from rich.panel import Panel
    HAS_RICH = True
except ImportError:
    HAS_RICH = False
|
|
16
|
-
|
|
17
|
-
_APP_DESCRIPTION = "Listing, discovery, and explanation commands"

# Typer sub-app that groups the listing / discovery / explanation commands.
app = typer.Typer(help=_APP_DESCRIPTION)
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@app.command("list-aggregations")
def list_aggregations():
    # Show each aggregation strategy (lower-cased enum member name) with its
    # description: as a rich Table when rich is available, else as plain lines.
    rows = [(strategy.name.lower(), text) for strategy, text in aggs.descriptions().items()]

    if not HAS_RICH:
        for label, text in rows:
            print(f"- {label:22s} : {text}")
        return

    table = Table(title="Aggregation Strategies")
    table.add_column("Name", style="bold")
    table.add_column("Description")
    for label, text in rows:
        table.add_row(label, text)
    echo(table)
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
@app.command("list-methods")
def list_methods():
    # List the steering methods known to SteeringMethodRotator; exits with
    # code 1 when none are registered.
    from wisent.cli.steering_methods.steering_rotator import SteeringMethodRotator  # type: ignore

    registered = SteeringMethodRotator().list_methods()
    if not registered:
        typer.echo("No steering methods registered.")
        raise typer.Exit(code=1)

    if not HAS_RICH:
        for entry in registered:
            print(f"- {entry['name']}: {entry['description']} ({entry['class']})")
        return

    table = Table(title="Registered Steering Methods")
    table.add_column("Name", style="bold")
    table.add_column("Description")
    table.add_column("Class")
    for entry in registered:
        table.add_row(entry["name"], entry["description"], entry["class"])
    echo(table)
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
@app.command("list-loaders")
def list_loaders(
    loaders_location: Optional[str] = typer.Option(None, help="Package path or directory containing data loader modules"),
    scope_prefix: Optional[str] = typer.Option(None, help="Limit list to module path prefix"),
):
    # List registered data loaders, optionally discovering extra loaders from
    # `loaders_location` first and filtering by `scope_prefix`; exits with
    # code 1 when nothing is found.
    from wisent.cli.data_loaders.data_loader_rotator import DataLoaderRotator  # type: ignore

    if loaders_location:
        DataLoaderRotator.discover_loaders(loaders_location)

    found = DataLoaderRotator.list_loaders(scope_prefix=scope_prefix)
    if not found:
        typer.echo("No data loaders found.")
        raise typer.Exit(code=1)

    if not HAS_RICH:
        for entry in found:
            print(f"- {entry['name']}: {entry['description']} ({entry['class']})")
        return

    table = Table(title="Registered Data Loaders")
    table.add_column("Name", style="bold")
    table.add_column("Description")
    table.add_column("Class")
    for entry in found:
        table.add_row(entry["name"], entry["description"], entry["class"])
    echo(table)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
@app.command("explain")
def explain(
    method: Optional[str] = typer.Option(None, help="Steering method to describe"),
    loader: Optional[str] = typer.Option(None, help="Data loader to describe"),
    loaders_location: Optional[str] = typer.Option(None, help="Where to discover data loaders"),
    aggregation: Optional[str] = typer.Option(None, help="Aggregation to describe"),
):
    """Describe a steering method, a data loader, and/or an aggregation strategy."""
    # NOTE: do not add a function-local ``from rich.panel import Panel`` here.
    # Any import binding ``Panel`` inside this function makes ``Panel`` a local
    # name for the *whole* function, so the --method / --loader branches below
    # would raise UnboundLocalError when rich is installed. The module-level
    # import is used instead.
    from wisent.cli.data_loaders.data_loader_rotator import DataLoaderRotator  # type: ignore
    from wisent.cli.steering_methods.steering_rotator import SteeringMethodRotator  # type: ignore

    if loaders_location:
        DataLoaderRotator.discover_loaders(loaders_location)

    if method:
        m = SteeringMethodRotator._resolve_method(method)
        # The description shown is the docstring of the resolved object's class.
        doc = (getattr(type(m), "__doc__", None) or "No docstring.").strip()
        if HAS_RICH:
            echo(Panel(doc, title=f"Method: {getattr(m, 'name', type(m).__name__)}"))
        else:
            print(doc)

    if loader:
        # Case-insensitive lookup of the loader by its registered name.
        reg = DataLoaderRotator.list_loaders()
        match = next((x for x in reg if x["name"].lower() == loader.lower()), None)
        if match:
            desc = (match.get("description") or "No description.").strip()
            if HAS_RICH:
                echo(Panel(desc, title=f"Loader: {loader}"))
            else:
                print(desc)
        else:
            typer.echo(f"Unknown loader: {loader}")

    if aggregation:
        agg = aggs.pick(aggregation)
        desc = aggs.descriptions().get(agg, "No description.")
        if HAS_RICH:
            echo(Panel(desc, title=f"Aggregation: {aggregation}"))
        else:
            print(desc)
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
@app.command("loader-args")
def loader_args(
    name: str = typer.Argument(..., help="Loader name, e.g. 'custom'"),
    loaders_location: Optional[str] = typer.Option(None, help="Where to discover data loaders"),
):
    """
    Show the exact arguments accepted by the loader's `load(...)` method.
    Useful so users know precisely what to pass (e.g., for `custom`:
    `path, split_ratio, seed, training_limit, testing_limit`).
    """
    # The docstring above is the user-facing Typer help text; implementation
    # notes are kept in comments.
    from wisent.cli.data_loaders.data_loader_rotator import DataLoaderRotator  # type: ignore

    if loaders_location:
        DataLoaderRotator.discover_loaders(loaders_location)
    # NOTE(review): the fallback package path uses the ``wisent_guard`` namespace,
    # while this file imports from ``wisent``; confirm the default still resolves.
    rot = DataLoaderRotator(loader=name, loaders_location=loaders_location or "wisent_guard.core.data_loaders.loaders")

    # Best-effort introspection: prefer the wrapped loader (``rot._loader``),
    # fall back to the rotator itself, and take the first exposing ``load``.
    target = None
    for cand in (getattr(rot, "_loader", None), rot):
        if cand is None:
            continue
        if hasattr(cand, "load"):
            target = cand.load
            break
    if target is None:
        typer.echo("Could not introspect loader signature.")
        raise typer.Exit(code=1)

    # inspect.signature raises TypeError for unsupported objects and ValueError
    # when no signature can be produced; report both like a missing ``load``.
    try:
        sig = inspect.signature(target)
    except (TypeError, ValueError):
        typer.echo("Could not introspect loader signature.")
        raise typer.Exit(code=1) from None

    # Uses the module-level ``Panel`` (imported alongside HAS_RICH); a local
    # re-import is unnecessary.
    text = f"{name}.load{sig}"
    if HAS_RICH:
        echo(Panel(text, title="Loader load(...) signature"))
    else:
        print(text)
|