wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,685 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Device-specific performance benchmarking for wisent-guard.
|
|
3
|
+
|
|
4
|
+
This module runs quick performance tests on the current device to measure
|
|
5
|
+
actual execution times for different operations, then saves those estimates
|
|
6
|
+
for future budget calculations.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import time
|
|
11
|
+
import os
|
|
12
|
+
import tempfile
|
|
13
|
+
import subprocess
|
|
14
|
+
import sys
|
|
15
|
+
from typing import Dict, Any, Optional, List
|
|
16
|
+
from dataclasses import dataclass, asdict
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
import hashlib
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
|
|
22
|
+
from wisent_guard.core.utils.device import resolve_default_device
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class DeviceBenchmark:
    """Performance benchmark results for a specific device.

    Instances are stored as one JSON object per ``device_id`` and rebuilt with
    :meth:`from_dict` when the cache file is read back.
    """
    device_id: str  # presumably the hash produced by DeviceBenchmarker.get_device_id()
    device_type: str  # "cpu", "cuda", "mps", etc.
    model_loading_seconds: float
    benchmark_eval_seconds_per_100_examples: float
    classifier_training_seconds_per_100_samples: float  # Actually measures full classifier creation time (per 100 classifiers)
    data_generation_seconds_per_example: float
    steering_seconds_per_example: float
    benchmark_timestamp: float  # seconds since the epoch (compared against time.time() for cache expiry)
    python_version: str
    platform_info: str

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dictionary for JSON serialization."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'DeviceBenchmark':
        """Create an instance from a dictionary loaded from JSON.

        Keys that are not fields of this dataclass are ignored. This lets a
        cache entry written by a version with additional fields still load,
        instead of failing with ``TypeError``. Missing fields still raise
        ``TypeError`` from the generated constructor.

        Args:
            data: Mapping of field name to value; it is not modified.

        Returns:
            A new ``DeviceBenchmark``.
        """
        from dataclasses import fields

        known_fields = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in known_fields})
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class DeviceBenchmarker:
|
|
50
|
+
"""Runs performance benchmarks and manages device-specific estimates."""
|
|
51
|
+
|
|
52
|
+
def __init__(self, benchmarks_file: str = "device_benchmarks.json"):
    """Set up a benchmarker that persists its results in a JSON file.

    Args:
        benchmarks_file: Path of the JSON file holding per-device results.
    """
    # No benchmark has been loaded or run yet.
    self.cached_benchmark: Optional[DeviceBenchmark] = None
    self.benchmarks_file = benchmarks_file
|
|
55
|
+
|
|
56
|
+
def get_device_id(self) -> str:
    """Return a 12-hex-character fingerprint of this machine's configuration.

    The fingerprint hashes the CPU architecture/processor, OS name and
    release, the full Python version string and, when present, the active
    CUDA GPU name or an "mps" marker. Upgrading any of these yields a new id,
    so results cached under the old id are no longer looked up.
    """
    import platform

    fingerprint = [
        platform.machine(),
        platform.processor(),
        platform.system(),
        platform.release(),
        sys.version,
    ]

    # Accelerators distinguish otherwise identical hosts.
    kind = resolve_default_device()
    if kind == "cuda" and torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
        fingerprint.append(f"cuda_{gpu_name}")
    elif kind == "mps":
        fingerprint.append("mps")

    digest = hashlib.md5("|".join(map(str, fingerprint)).encode())
    return digest.hexdigest()[:12]
|
|
80
|
+
|
|
81
|
+
def get_device_type(self) -> str:
    """Return the backend name benchmarks run on (e.g. "cpu", "cuda", "mps").

    This simply reports whatever ``resolve_default_device()`` selects.
    """
    backend = resolve_default_device()
    return backend
|
|
84
|
+
|
|
85
|
+
def load_cached_benchmark(self, max_age_days: float = 7.0) -> "Optional[DeviceBenchmark]":
    """Load this device's cached benchmark if it exists and is recent enough.

    Args:
        max_age_days: Entries older than this many days are treated as stale.
            Defaults to 7, the previously hard-coded threshold.

    Returns:
        The ``DeviceBenchmark`` stored under ``self.get_device_id()`` in
        ``self.benchmarks_file``. Returns ``None`` when the file does not
        exist, has no entry for this device, or the entry is stale. Also
        returns ``None`` if reading or parsing fails in any way; such errors
        are printed, not raised.
    """
    if not os.path.exists(self.benchmarks_file):
        return None

    try:
        with open(self.benchmarks_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        device_id = self.get_device_id()
        if device_id not in data:
            return None

        benchmark = DeviceBenchmark.from_dict(data[device_id])

        # Performance characteristics drift over time, so old results expire.
        age_days = (time.time() - benchmark.benchmark_timestamp) / (24 * 3600)
        if age_days > max_age_days:
            print(f" ⚠️ Cached benchmark is {age_days:.1f} days old, will re-run")
            return None

        return benchmark

    except Exception as e:
        # Best-effort cache: any problem is reported and treated as a cache miss.
        print(f" ⚠️ Error loading cached benchmark: {e}")
        return None
|
|
114
|
+
|
|
115
|
+
def save_benchmark(self, benchmark: "DeviceBenchmark") -> None:
    """Save benchmark results to the JSON file, keyed by ``benchmark.device_id``.

    Entries for other devices already in the file are preserved. If the
    existing file is not valid JSON, or is not a JSON object, it is replaced
    with a warning. Previously such a file made every save fail, so new
    results were never stored. Other errors (e.g. permissions) are printed,
    not raised.

    Args:
        benchmark: Result to store; must provide ``device_id`` and ``to_dict()``.
    """
    try:
        existing_data: Dict[str, Any] = {}
        if os.path.exists(self.benchmarks_file):
            try:
                with open(self.benchmarks_file, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
            except ValueError as e:  # JSONDecodeError and UnicodeDecodeError
                print(f" ⚠️ Existing benchmark file is corrupt ({e}); overwriting it")
            else:
                if isinstance(loaded, dict):
                    existing_data = loaded
                else:
                    print(" ⚠️ Existing benchmark file is not a JSON object; overwriting it")

        # Update with new benchmark
        existing_data[benchmark.device_id] = benchmark.to_dict()

        with open(self.benchmarks_file, 'w', encoding='utf-8') as f:
            json.dump(existing_data, f, indent=2)

        print(f" 💾 Saved benchmark results to {self.benchmarks_file}")

    except Exception as e:
        print(f" ❌ Error saving benchmark: {e}")
|
|
135
|
+
|
|
136
|
+
def run_model_loading_benchmark(self) -> Optional[float]:
    """Benchmark actual model loading time using the real model.

    A standalone script is written to a temporary file and run in a fresh
    interpreter with a 2-minute timeout. The script loads
    ``meta-llama/Llama-3.1-8B-Instruct`` via ``wisent_guard`` and reports the
    elapsed time on a ``BENCHMARK_RESULT:<seconds>`` stdout line. The
    temporary script is always removed, even on timeout.

    Returns:
        Loading time in seconds, or ``None`` if the subprocess finished
        without printing a result (e.g. it crashed on import).

    Raises:
        RuntimeError: If writing or running the script fails, including
            ``subprocess.TimeoutExpired``, or the result line cannot be parsed.
    """
    print(" 📊 Benchmarking model loading...")

    # Create actual model loading test script
    test_script = '''
import time
import sys
sys.path.append('.')

start_time = time.time()
try:
    from wisent_guard.core.model import Model
    # Use the actual model that will be used in production
    model = Model("meta-llama/Llama-3.1-8B-Instruct")
    end_time = time.time()
    print(f"BENCHMARK_RESULT:{end_time - start_time}")
except Exception as e:
    print(f"BENCHMARK_ERROR:{e}")
    raise
'''

    temp_script = None
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
            temp_script = f.name  # record first so `finally` can always clean up
            f.write(test_script)

        # Run with 2-minute timeout
        result = subprocess.run(
            [sys.executable, temp_script],
            capture_output=True, text=True, timeout=120,
        )

        for line in result.stdout.splitlines():
            if line.startswith('BENCHMARK_RESULT:'):
                loading_time = float(line.split(':', 1)[1])
                print(f" Model loading: {loading_time:.1f}s")
                return loading_time

        # Report why no timing was obtained instead of returning silently.
        print(" ❌ No BENCHMARK_RESULT found in model loading output!")
        print(result.stdout)
        if result.stderr:
            print(result.stderr)
        return None

    except Exception as e:
        print(f" Error in model loading benchmark: {e}")
        raise RuntimeError(f"Model loading benchmark failed: {e}") from e
    finally:
        # Remove the temp script on every path (success, no result, timeout).
        if temp_script is not None:
            try:
                os.unlink(temp_script)
            except OSError:
                pass
|
|
181
|
+
|
|
182
|
+
def run_benchmark_eval_test(self) -> Optional[float]:
    """Benchmark evaluation performance using real CLI functionality.

    A standalone script runs ``run_task_pipeline`` on ``truthfulqa_mc`` with a
    small example limit and no steering. It runs in a fresh interpreter with
    a 2-minute timeout. The script reports the elapsed time scaled to 100
    examples on a ``BENCHMARK_RESULT:<seconds>`` line. The temporary script
    is always removed, even on timeout.

    Returns:
        Seconds per 100 examples, or ``None`` if no result was reported or
        any error occurred (errors are printed with a traceback).
    """
    print(" 📊 Benchmarking evaluation performance...")
    print(" 🔧 DEBUG: Creating evaluation test script...")

    # Create evaluation test script using actual CLI
    test_script = '''
import time
import sys
sys.path.append('.')

NUM_EXAMPLES = 3  # Minimum examples for timing

print("BENCHMARK_DEBUG: Starting evaluation benchmark")
start_time = time.time()
try:
    print("BENCHMARK_DEBUG: Importing CLI...")
    from wisent_guard.cli import run_task_pipeline
    print("BENCHMARK_DEBUG: CLI imported successfully")

    print("BENCHMARK_DEBUG: Running task pipeline...")
    # Run actual evaluation with real model and minimal examples
    run_task_pipeline(
        task_name="truthfulqa_mc",
        model_name="meta-llama/Llama-3.1-8B-Instruct",
        layer="15",  # Required parameter
        limit=NUM_EXAMPLES,
        steering_mode=False,  # No steering for baseline timing
        verbose=False,
        allow_small_dataset=True,
        output_mode="likelihoods"
    )
    print("BENCHMARK_DEBUG: Task pipeline completed")

    end_time = time.time()
    total_time = end_time - start_time
    print(f"BENCHMARK_DEBUG: Total time: {total_time}s for {NUM_EXAMPLES} examples")
    # Scale to per-100-examples
    time_per_100 = (total_time / NUM_EXAMPLES) * 100
    print(f"BENCHMARK_DEBUG: Scaled time per 100: {time_per_100}s")
    print(f"BENCHMARK_RESULT:{time_per_100}")

except Exception as e:
    print(f"BENCHMARK_ERROR:{e}")
    import traceback
    traceback.print_exc()
    raise
'''

    print(" 🔧 DEBUG: Writing test script to temporary file...")
    temp_script = None
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
            temp_script = f.name  # record first so `finally` can always clean up
            f.write(test_script)
        print(f" 🔧 DEBUG: Test script written to {temp_script}")

        print(" 🔧 DEBUG: Running evaluation subprocess...")
        result = subprocess.run(
            [sys.executable, temp_script],
            capture_output=True, text=True, timeout=120,  # 2-minute timeout
        )

        print(f" 🔧 DEBUG: Subprocess completed with return code: {result.returncode}")
        print(f" 🔧 DEBUG: Stdout length: {len(result.stdout)} chars")
        print(f" 🔧 DEBUG: Stderr length: {len(result.stderr)} chars")
        if result.stderr:
            print(f" ⚠️ DEBUG: Stderr content:\n{result.stderr}")

        # Parse result
        print(" 🔧 DEBUG: Parsing output for BENCHMARK_RESULT...")
        for line in result.stdout.splitlines():
            print(f" 🔍 DEBUG: Output line: {repr(line)}")
            if line.startswith('BENCHMARK_RESULT:'):
                eval_time = float(line.split(':', 1)[1])
                print(f" ✅ Evaluation: {eval_time:.1f}s per 100 examples")
                return eval_time

        print(" ❌ DEBUG: No BENCHMARK_RESULT found in output!")
        print(" 📜 DEBUG: Full stdout:")
        print(result.stdout)
        return None

    except Exception as e:
        print(f" ❌ Error in evaluation benchmark: {e}")
        import traceback
        traceback.print_exc()
        return None
    finally:
        # Remove the temp script on every path (success, no result, timeout).
        if temp_script is not None:
            try:
                os.unlink(temp_script)
                print(" 🔧 DEBUG: Temporary script cleaned up")
            except OSError:
                pass
|
|
273
|
+
|
|
274
|
+
def run_classifier_training_test(self) -> Optional[float]:
    """Benchmark ACTUAL classifier training using real synthetic classifier creation.

    A standalone script loads the model, then creates one classifier with
    ``create_classifier_from_trait_description``. It runs in a fresh
    interpreter with a 20-minute timeout. The time for that single
    classifier creation is scaled to 100 classifiers and reported on a
    ``BENCHMARK_RESULT:<seconds>`` line.

    Two fixes versus the previous script:
    - It referenced ``start_time`` without defining it, so it always crashed
      with NameError before reporting.
    - Model loading is no longer part of the scaled figure; it has its own
      ``model_loading_seconds`` measurement.

    The temporary script is always removed, even on timeout.

    Returns:
        Seconds per 100 classifier creations, or ``None`` if no result was
        reported or any error occurred (errors are printed with a traceback).
    """
    print(" 📊 Benchmarking classifier training...")
    print(" 🔧 DEBUG: Creating classifier training test script...")

    # Create test script that uses real synthetic classifier creation
    test_script = '''
import time

# Start of the whole benchmark (imports + model load + classifier creation).
start_time = time.time()
try:
    print("BENCHMARK_DEBUG: Importing required modules...")
    from wisent_guard.core.model import Model
    from wisent_guard.core.agent.diagnose.synthetic_classifier_option import create_classifier_from_trait_description
    from wisent_guard.core.agent.budget import set_time_budget
    print("BENCHMARK_DEBUG: All modules imported successfully")

    print("BENCHMARK_DEBUG: Starting classifier benchmark")

    # Set a budget for the classifier creation
    print("BENCHMARK_DEBUG: Setting time budget...")
    set_time_budget(5.0)  # 5 minutes
    print("BENCHMARK_DEBUG: Set time budget to 5.0 minutes")

    # Load the actual model
    print("BENCHMARK_DEBUG: Loading model...")
    model_start = time.time()
    model = Model("meta-llama/Llama-3.1-8B-Instruct")
    model_time = time.time() - model_start
    print(f"BENCHMARK_DEBUG: Model loaded in {model_time}s")

    # Create ONE actual classifier using the real synthetic process
    print("BENCHMARK_DEBUG: Creating classifier...")
    classifier_start = time.time()
    classifier = create_classifier_from_trait_description(
        model=model,
        trait_description="accuracy and truthfulness",
        num_pairs=3  # Minimum needed for training
    )
    classifier_time = time.time() - classifier_start
    print(f"BENCHMARK_DEBUG: Classifier created in {classifier_time}s")

    total_time = time.time() - start_time
    print(f"BENCHMARK_DEBUG: Total benchmark time: {total_time}s")

    # Scale ONE classifier creation to "per 100 classifiers". Model loading is
    # excluded: it is measured separately, and multiplying it by 100 would
    # count it a hundred times.
    time_per_100 = classifier_time * 100
    print(f"BENCHMARK_DEBUG: Scaled time per 100 classifiers: {time_per_100}s")
    print(f"BENCHMARK_RESULT:{time_per_100}")

except Exception as e:
    print(f"BENCHMARK_ERROR:{e}")
    import traceback
    traceback.print_exc()
    raise
'''

    print(" 🔧 DEBUG: Writing classifier test script to temporary file...")
    temp_script = None
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
            temp_script = f.name  # record first so `finally` can always clean up
            f.write(test_script)
        print(f" 🔧 DEBUG: Classifier test script written to {temp_script}")

        print(" 🔧 DEBUG: Running classifier training subprocess (20 min timeout)...")
        result = subprocess.run(
            [sys.executable, temp_script],
            capture_output=True, text=True, timeout=1200,
        )

        print(f" 🔧 DEBUG: Classifier subprocess completed with return code: {result.returncode}")
        print(f" 🔧 DEBUG: Stdout length: {len(result.stdout)} chars")
        print(f" 🔧 DEBUG: Stderr length: {len(result.stderr)} chars")
        if result.stderr:
            print(f" ⚠️ DEBUG: Classifier stderr content:\n{result.stderr}")

        # Parse result
        print(" 🔧 DEBUG: Parsing classifier output for BENCHMARK_RESULT...")
        for line in result.stdout.splitlines():
            print(f" 🔍 DEBUG: Classifier output line: {repr(line)}")
            if line.startswith('BENCHMARK_RESULT:'):
                training_time = float(line.split(':', 1)[1])
                print(f" ✅ Classifier training: {training_time:.1f}s per 100 classifiers")
                return training_time

        print(" ❌ DEBUG: No BENCHMARK_RESULT found in classifier output!")
        print(" 📜 DEBUG: Full classifier stdout:")
        print(result.stdout)
        return None

    except Exception as e:
        print(f" ❌ Error in classifier training benchmark: {e}")
        import traceback
        traceback.print_exc()
        return None
    finally:
        # Remove the temp script on every path (success, no result, timeout).
        if temp_script is not None:
            try:
                os.unlink(temp_script)
                print(" 🔧 DEBUG: Classifier temporary script cleaned up")
            except OSError:
                pass
|
|
379
|
+
|
|
380
|
+
def run_steering_test(self) -> Optional[float]:
    """Benchmark steering performance using real CLI functionality.

    A standalone script runs ``run_task_pipeline`` with CAA steering on
    ``truthfulqa_mc`` for 2 examples. It runs in a fresh interpreter with a
    5-minute timeout and reports seconds per example on a
    ``BENCHMARK_RESULT:<seconds>`` line. The temporary script is always
    removed, even on timeout.

    Returns:
        Seconds per steered example, or ``None`` if the subprocess finished
        without reporting a result.

    Raises:
        RuntimeError: If writing or running the script fails, including
            ``subprocess.TimeoutExpired``, or the result line cannot be parsed.
    """
    print(" 📊 Benchmarking steering performance...")

    # Create steering test script using actual CLI
    test_script = '''
import time
import sys
sys.path.append('.')

NUM_EXAMPLES = 2  # Minimum examples for timing

start_time = time.time()
try:
    from wisent_guard.cli import run_task_pipeline

    # Run actual steering with real model and minimal examples
    run_task_pipeline(
        task_name="truthfulqa_mc",
        model_name="meta-llama/Llama-3.1-8B-Instruct",
        limit=NUM_EXAMPLES,
        steering_mode=True,
        steering_method="CAA",
        steering_strength=1.0,
        layer="15",
        verbose=False,
        allow_small_dataset=True,
        output_mode="likelihoods"
    )

    end_time = time.time()
    total_time = end_time - start_time
    # Time per example
    time_per_example = total_time / NUM_EXAMPLES
    print(f"BENCHMARK_RESULT:{time_per_example}")

except Exception as e:
    print(f"BENCHMARK_ERROR:{e}")
    raise
'''

    temp_script = None
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
            temp_script = f.name  # record first so `finally` can always clean up
            f.write(test_script)

        result = subprocess.run(
            [sys.executable, temp_script],
            capture_output=True, text=True, timeout=300,
        )

        for line in result.stdout.splitlines():
            if line.startswith('BENCHMARK_RESULT:'):
                steering_time = float(line.split(':', 1)[1])
                print(f" Steering: {steering_time:.1f}s per example")
                return steering_time

        print(" ❌ No BENCHMARK_RESULT found in steering output!")
        print(result.stdout)
        return None

    except Exception as e:
        print(f" Error in steering benchmark: {e}")
        raise RuntimeError(f"Steering benchmark failed: {e}") from e
    finally:
        # Remove the temp script on every path (success, no result, timeout).
        if temp_script is not None:
            try:
                os.unlink(temp_script)
            except OSError:
                pass
|
|
444
|
+
|
|
445
|
+
def run_data_generation_test(self) -> Optional[float]:
    """Benchmark data generation performance using real synthetic generation.

    Writes a standalone script to a temporary ``.py`` file and runs it with the
    current interpreter in a subprocess (5-minute timeout). The script loads the
    model, generates one synthetic contrastive pair and prints
    ``BENCHMARK_RESULT:<seconds per generated response>`` on stdout.

    The temporary script is always removed, even when the subprocess times out
    or fails.

    Returns:
        Seconds per generated response, or ``None`` if the subprocess printed no
        ``BENCHMARK_RESULT`` line (its stdout/stderr are printed for diagnosis).

    Raises:
        RuntimeError: if writing/running the script or parsing its result fails
            (including a subprocess timeout); the original error is chained.
    """
    print(" 📊 Benchmarking data generation...")

    # Create data generation test script using actual synthetic pair generation.
    # NOTE(review): start_time is taken before the imports and the model load, so
    # the per-example figure also includes model-loading time — confirm intended.
    test_script = '''
import time
import sys
sys.path.append('.')

start_time = time.time()
try:
    from wisent_guard.core.model import Model
    from wisent_guard.core.contrastive_pairs.generate_synthetically import SyntheticContrastivePairGenerator

    # Load the actual model
    model = Model("meta-llama/Llama-3.1-8B-Instruct")

    # Create generator and generate actual synthetic pairs
    generator = SyntheticContrastivePairGenerator(model)

    # Generate a small set of pairs for timing
    pair_set = generator.generate_contrastive_pair_set(
        trait_description="accuracy and truthfulness",
        num_pairs=1, # Minimum needed for estimation
        name="benchmark_test"
    )

    end_time = time.time()
    total_time = end_time - start_time

    # Calculate time per generated pair (each pair has 2 responses)
    num_generated_responses = len(pair_set.pairs) * 2
    if num_generated_responses == 0:
        raise RuntimeError("No pairs were generated during data generation benchmark")

    time_per_example = total_time / num_generated_responses
    print(f"BENCHMARK_RESULT:{time_per_example}")

except Exception as e:
    print(f"BENCHMARK_ERROR:{e}")
    raise
'''

    temp_script = None
    try:
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
            f.write(test_script)
            temp_script = f.name

        result = subprocess.run([
            sys.executable, temp_script
        ], capture_output=True, text=True, timeout=300)  # 5-minute timeout

        # Parse result: the first BENCHMARK_RESULT line wins.
        for line in result.stdout.splitlines():
            if line.startswith('BENCHMARK_RESULT:'):
                generation_time = float(line.split(':', 1)[1])
                print(f" Data generation: {generation_time:.1f}s per example")
                return generation_time

        # No result: surface what the subprocess said instead of failing silently.
        print(" ❌ No BENCHMARK_RESULT found in data generation output!")
        print(result.stdout)
        if result.stderr:
            print(result.stderr)
        return None

    except Exception as e:
        print(f" Error in data generation benchmark: {e}")
        raise RuntimeError(f"Data generation benchmark failed: {e}") from e

    finally:
        # Remove the temporary script on every path, including timeouts.
        if temp_script is not None:
            try:
                os.unlink(temp_script)
            except OSError:
                pass
|
|
510
|
+
|
|
511
|
+
def run_full_benchmark(self, force_rerun: bool = False) -> DeviceBenchmark:
    """Run the complete device benchmark suite, reusing cached results if possible.

    Unless ``force_rerun`` is true, a benchmark returned by
    ``self.load_cached_benchmark()`` is reused as-is (and stored in
    ``self.cached_benchmark``). Otherwise every sub-benchmark is run, and the
    resulting ``DeviceBenchmark`` is persisted with ``self.save_benchmark`` and
    kept in ``self.cached_benchmark``.

    Model loading, steering and data generation are required: an exception
    from one of them is re-raised, and a ``None`` result aborts the run with a
    ``RuntimeError``. Evaluation and classifier training are optional and fall
    back to fixed default timings (60.0 and 600.0 seconds per 100 items).

    Args:
        force_rerun: Ignore any cached benchmark and measure again.

    Returns:
        The cached or freshly measured benchmark.

    Raises:
        RuntimeError: if a required sub-benchmark returns ``None``.
    """
    # Check for cached results first
    if not force_rerun:
        cached = self.load_cached_benchmark()
        if cached:
            print(f" ✅ Using cached benchmark results (device: {cached.device_id[:8]}...)")
            self.cached_benchmark = cached
            return cached

    print("🚀 Running device performance benchmark...")
    print(" This will take 1-2 minutes to measure your hardware performance")

    import platform

    device_id = self.get_device_id()
    device_type = self.get_device_type()

    print(f" 🖥️ Device ID: {device_id[:8]}... ({device_type})")

    # Required benchmarks: the None checks sit outside the try so a None result
    # is reported once, instead of being re-caught and reported again as a
    # generic failure.
    try:
        model_loading = self.run_model_loading_benchmark()
    except Exception as e:
        print(f" ❌ Model loading benchmark failed: {e}")
        raise
    if model_loading is None:
        print(" ❌ Model loading benchmark returned None")
        raise RuntimeError("Model loading benchmark failed")

    # Optional benchmarks: fall back to default timings on error or None.
    try:
        benchmark_eval = self.run_benchmark_eval_test()
        if benchmark_eval is None:
            print(" ⚠️ Evaluation benchmark returned None, using default value")
            benchmark_eval = 60.0  # Default 60 seconds per 100 examples
    except Exception as e:
        print(f" ❌ Evaluation benchmark failed: {e}")
        benchmark_eval = 60.0  # Default fallback

    try:
        classifier_training = self.run_classifier_training_test()
        if classifier_training is None:
            print(" ⚠️ Classifier training benchmark returned None, using default value")
            classifier_training = 600.0  # Default 600 seconds per 100 classifiers
    except Exception as e:
        print(f" ❌ Classifier training benchmark failed: {e}")
        classifier_training = 600.0  # Default fallback

    try:
        steering = self.run_steering_test()
    except Exception as e:
        print(f" ❌ Steering benchmark failed: {e}")
        raise
    if steering is None:
        print(" ❌ Steering benchmark returned None")
        raise RuntimeError("Steering benchmark failed")

    try:
        data_generation = self.run_data_generation_test()
    except Exception as e:
        print(f" ❌ Data generation benchmark failed: {e}")
        raise
    if data_generation is None:
        print(" ❌ Data generation benchmark returned None")
        raise RuntimeError("Data generation benchmark failed")

    # Create benchmark result
    benchmark = DeviceBenchmark(
        device_id=device_id,
        device_type=device_type,
        model_loading_seconds=model_loading,
        benchmark_eval_seconds_per_100_examples=benchmark_eval,
        # NOTE: despite the field name, this value is reported below as
        # "per 100 classifiers".
        classifier_training_seconds_per_100_samples=classifier_training,
        data_generation_seconds_per_example=data_generation,
        steering_seconds_per_example=steering,
        benchmark_timestamp=time.time(),
        python_version=sys.version,
        platform_info=platform.platform()
    )

    # Save results
    self.save_benchmark(benchmark)
    self.cached_benchmark = benchmark

    print(" ✅ Benchmark complete!")
    print(f" Model loading: {model_loading:.1f}s")
    print(f" Evaluation: {benchmark_eval:.1f}s per 100 examples")
    print(f" Classifier creation: {classifier_training:.1f}s per 100 classifiers")
    print(f" Steering: {steering:.1f}s per example")
    print(f" Generation: {data_generation:.1f}s per example")

    return benchmark
|
|
603
|
+
|
|
604
|
+
def get_current_benchmark(self, auto_run: bool = True) -> Optional[DeviceBenchmark]:
    """Return this device's benchmark, measuring it on demand.

    Lookup order: the in-memory ``self.cached_benchmark``; then
    ``self.load_cached_benchmark()`` (remembered in ``self.cached_benchmark``
    when it yields something); finally, only if ``auto_run`` is true, a fresh
    ``self.run_full_benchmark()``.

    Args:
        auto_run: Whether to run the full benchmark when nothing is cached.

    Returns:
        The benchmark, or ``None`` when nothing is cached and ``auto_run`` is
        false.
    """
    in_memory = self.cached_benchmark
    if in_memory:
        return in_memory

    loaded = self.load_cached_benchmark()
    if not loaded:
        return self.run_full_benchmark() if auto_run else None

    self.cached_benchmark = loaded
    return loaded
|
|
618
|
+
|
|
619
|
+
def estimate_task_time(self, task_type: str, quantity: int = 1) -> float:
    """
    Estimate how many seconds ``quantity`` units of ``task_type`` will take.

    The estimate is derived from this device's benchmark, obtained through
    ``self.get_current_benchmark()``.

    Args:
        task_type: One of "model_loading", "benchmark_eval",
            "classifier_training", "steering" or "data_generation".
        quantity: Number of items (examples, samples, etc.). Ignored for
            "model_loading", which returns the measured load time as-is.

    Returns:
        Estimated time in seconds

    Raises:
        RuntimeError: if no benchmark is available for this device.
        ValueError: if ``task_type`` is not one of the names above.
    """
    bench = self.get_current_benchmark()
    if not bench:
        raise RuntimeError("No benchmark available for device. Run benchmark first with: python -m wisent_guard.core.agent.budget benchmark")

    if task_type == "model_loading":
        # Single measured value; quantity is not applied.
        return bench.model_loading_seconds

    if task_type == "benchmark_eval":
        per_hundred = bench.benchmark_eval_seconds_per_100_examples
        return (per_hundred / 100.0) * quantity

    if task_type == "classifier_training":
        # Field is named "per 100 samples", but run_full_benchmark reports it
        # as seconds per 100 classifiers.
        per_hundred = bench.classifier_training_seconds_per_100_samples
        return (per_hundred / 100.0) * quantity

    if task_type == "steering":
        return bench.steering_seconds_per_example * quantity

    if task_type == "data_generation":
        return bench.data_generation_seconds_per_example * quantity

    raise ValueError(f"Unknown task type: {task_type}")
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
# Global benchmarker instance
# Created once, when this module is imported. get_device_benchmarker,
# ensure_benchmark_exists, estimate_task_time and get_current_device_info all
# go through this single object, so anything it caches (e.g. cached_benchmark)
# is shared by every caller.
_device_benchmarker = DeviceBenchmarker()
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
def get_device_benchmarker() -> DeviceBenchmarker:
    """Return the shared, module-level ``DeviceBenchmarker``.

    Every call returns the same object, so state it caches (such as its
    ``cached_benchmark``) is visible to all callers.
    """
    shared = _device_benchmarker
    return shared
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
def ensure_benchmark_exists(force_rerun: bool = False) -> DeviceBenchmark:
    """Make sure this device has a benchmark and return it.

    Delegates to the shared benchmarker's ``run_full_benchmark``: a cached
    benchmark is reused unless ``force_rerun`` is true, in which case the suite
    is measured again.
    """
    benchmarker = _device_benchmarker
    return benchmarker.run_full_benchmark(force_rerun=force_rerun)
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
def estimate_task_time(task_type: str, quantity: int = 1) -> float:
    """
    Estimate, in seconds, how long ``quantity`` units of ``task_type`` take.

    Module-level shortcut for ``DeviceBenchmarker.estimate_task_time`` on the
    shared benchmarker instance.

    Args:
        task_type: One of "model_loading", "benchmark_eval",
            "classifier_training", "steering" or "data_generation".
        quantity: Number of items to estimate for.

    Returns:
        Estimated time in seconds
    """
    benchmarker = _device_benchmarker
    return benchmarker.estimate_task_time(task_type, quantity)
|
|
677
|
+
|
|
678
|
+
|
|
679
|
+
def get_current_device_info() -> Dict[str, str]:
    """Describe the current device.

    Returns:
        A dict with the keys ``"device_id"`` and ``"device_type"``, as reported
        by the shared benchmarker.
    """
    bm = get_device_benchmarker()
    return dict(
        device_id=bm.get_device_id(),
        device_type=bm.get_device_type(),
    )
|