wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of wisent might be problematic.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/core/sample_size_optimizer.py
@@ -0,0 +1,648 @@
"""
Sample Size Optimizer for finding the optimal training sample size for classifiers.
"""

import json
import logging
import os
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

from wisent_guard.core.classifier.classifier import Classifier

from .activations import ActivationAggregationStrategy
from .contrastive_pairs import ContrastivePairSet
from .model import Model
from .model_config_manager import ModelConfigManager

logger = logging.getLogger(__name__)


class SampleSizeOptimizer:
    """Optimizes training sample size for classifiers."""

    def __init__(
        self,
        model_name: str,
        task_name: str = "truthfulqa_mc1",
        layer: int = 0,
        token_aggregation: str = "average",
        threshold: float = 0.5,
        test_split: float = 0.2,
        sample_sizes: Optional[List[int]] = None,
        device: Optional[str] = None,
        verbose: bool = False,
    ):
        """
        Initialize the sample size optimizer.

        Args:
            model_name: Name of the model to optimize
            task_name: Task to optimize for
            layer: Layer index to optimize
            token_aggregation: Token aggregation method (average, final, first, max, min)
            threshold: Detection threshold for classification
            test_split: Fraction of data to use for testing
            sample_sizes: List of sample sizes to test
            device: Device to use for computation
            verbose: Enable verbose output
        """
        self.model_name = model_name
        self.task_name = task_name
        self.layer = layer
        self.token_aggregation = token_aggregation
        self.threshold = threshold
        self.test_split = test_split
        self.verbose = verbose

        # Default sample sizes if not provided
        if sample_sizes is None:
            self.sample_sizes = [1, 2, 5, 10, 20, 50, 100, 200, 500]
        else:
            self.sample_sizes = sorted(sample_sizes)

        # Initialize model
        self.model = Model(name=model_name, device=device)
        self.device = self.model.device

        # Storage for results
        self.results = []
        self.optimal_sample_size = None

        logger.info(f"Initialized SampleSizeOptimizer for {model_name}")
        logger.info(f"Task: {task_name}, Layer: {layer}")
        logger.info(f"Sample sizes to test: {self.sample_sizes}")

    def load_and_split_data(self, limit: Optional[int] = None) -> Tuple[ContrastivePairSet, ContrastivePairSet]:
        """
        Load task data and split into train/test sets.

        Args:
            limit: Maximum number of samples to load (None for all)

        Returns:
            Tuple of (train_pairs, test_pairs)
        """
        logger.info(f"Loading data for task: {self.task_name}")

        # Load task data using the model
        max_samples = limit or 1000  # Default to 1000 if not specified

        # Try to use cached benchmark data first
        qa_pairs = None
        try:
            from .managed_cached_benchmarks import get_managed_cache

            cache = get_managed_cache()
            logger.info(f"Attempting to load from cache with limit={max_samples}")

            # Load samples from cache (it will download if needed)
            samples = cache.get_task_samples(self.task_name, limit=max_samples)

            if samples:
                logger.info(f"Loaded {len(samples)} samples from cache")
                # Convert cached samples to QA pairs format
                qa_pairs = []
                for sample in samples:
                    # The cached sample has 'normalized' field with the QA pair
                    normalized = sample.get("normalized", {})
                    # Handle both formats: good_response/bad_response and correct_answer
                    if "good_response" in normalized and "bad_response" in normalized:
                        qa_pair = {
                            "question": normalized.get("context", normalized.get("question", "")),
                            "correct_answer": normalized.get("good_response", ""),
                            "incorrect_answer": normalized.get("bad_response", ""),
                            "metadata": normalized.get("metadata", {}),
                        }
                    else:
                        # For truthfulqa_mc1, we need to get incorrect answers from mc1_targets
                        raw_data = sample.get("raw_data", {})
                        mc1_targets = raw_data.get("mc1_targets", {})
                        choices = mc1_targets.get("choices", [])
                        labels = mc1_targets.get("labels", [])

                        # Find first incorrect answer
                        incorrect_answer = None
                        for i, label in enumerate(labels):
                            if label == 0 and i < len(choices):
                                incorrect_answer = choices[i]
                                break

                        if not incorrect_answer:
                            incorrect_answer = "This is incorrect"

                        qa_pair = {
                            "question": normalized.get("question", ""),
                            "correct_answer": normalized.get("correct_answer", ""),
                            "incorrect_answer": incorrect_answer,
                            "metadata": normalized.get("metadata", {}),
                        }
                    qa_pairs.append(qa_pair)
                logger.info(f"Converted {len(qa_pairs)} cached samples to QA pairs")
        except Exception as e:
            logger.warning(f"Failed to load from cache: {e}")
            qa_pairs = None

        # Fallback to loading from lm-eval if cache failed
        if not qa_pairs:
            logger.info("Loading from lm-eval harness...")
            # Load lm-eval task
            task_data = self.model.load_lm_eval_task(self.task_name, shots=0, limit=max_samples)

            # Split into train/test docs
            docs, _ = self.model.split_task_data(task_data, split_ratio=1.0)  # Use all for now

            if not docs:
                raise ValueError(f"No documents loaded for task {self.task_name}")

            logger.info(f"Loaded {len(docs)} documents from {self.task_name}")

            # Extract QA pairs from task docs
            qa_pairs = ContrastivePairSet.extract_qa_pairs_from_task_docs(self.task_name, task_data, docs)

            if not qa_pairs:
                raise ValueError(f"No QA pairs could be extracted from task {self.task_name}")

            logger.info(f"Extracted {len(qa_pairs)} QA pairs")

        # Create contrastive pairs from QA pairs
        from wisent_guard.core.activations.activation_collection_method import ActivationCollectionLogic

        collector = ActivationCollectionLogic(model=self.model)

        # Import token aggregation function

        # Create contrastive pairs
        all_pairs = []
        for qa_pair in qa_pairs:
            # Create prompts for positive and negative cases
            question = qa_pair["question"]
            correct_answer = qa_pair["correct_answer"]
            incorrect_answer = qa_pair["incorrect_answer"]

            # Generate with model to get activations
            # Positive case (correct answer)
            pos_prompt = self.model.format_prompt(question)
            pos_response = correct_answer

            # Negative case (incorrect answer)
            neg_prompt = self.model.format_prompt(question)
            neg_response = incorrect_answer

            # Create contrastive pair
            from .contrastive_pairs import ContrastivePair
            from .response import NegativeResponse, PositiveResponse

            pair = ContrastivePair(
                prompt=question,
                positive_response=PositiveResponse(text=pos_response),
                negative_response=NegativeResponse(text=neg_response),
            )
            all_pairs.append(pair)

        if not all_pairs:
            raise ValueError(f"No contrastive pairs created for task {self.task_name}")

        # Extract activations for all pairs at the specified layer
        logger.info(f"Extracting activations at layer {self.layer}")

        # Use the collector to extract activations
        # For MULTIPLE_CHOICE, we use CHOICE_TOKEN targeting
        all_pairs = collector.collect_activations_batch(
            all_pairs,
            layer_index=self.layer,
            device=self.device,
            token_targeting_strategy=ActivationAggregationStrategy.CHOICE_TOKEN,
        )

        # Filter out any pairs without activations
        all_pairs = [p for p in all_pairs if p.positive_activations is not None and p.negative_activations is not None]

        logger.info(f"Loaded {len(all_pairs)} contrastive pairs")

        # Calculate split index
        n_test = int(len(all_pairs) * self.test_split)
        n_train = len(all_pairs) - n_test

        # Create train and test sets
        # Use a fixed seed for reproducibility
        np.random.seed(42)
        indices = np.random.permutation(len(all_pairs))

        train_indices = indices[:n_train]
        test_indices = indices[n_train:]

        train_pairs = [all_pairs[i] for i in train_indices]
        test_pairs = [all_pairs[i] for i in test_indices]

        # Create ContrastivePairSet objects
        train_set = ContrastivePairSet(name=f"{self.task_name}_train", pairs=train_pairs)
        test_set = ContrastivePairSet(name=f"{self.task_name}_test", pairs=test_pairs)

        logger.info(f"Split data: {len(train_pairs)} train, {len(test_pairs)} test")

        return train_set, test_set

    def _aggregate_activations(self, activations):
        """
        Apply token aggregation to activations based on configured method.

        Since we're using CHOICE_TOKEN strategy, activations should be a single vector.
        This method is here for consistency with the main CLI approach.

        Args:
            activations: Activation vector or tensor

        Returns:
            Aggregated activation vector
        """
        # For CHOICE_TOKEN strategy, activations are already a single vector
        # No aggregation needed
        return activations

    def train_classifier_with_sample_size(
        self, train_set: ContrastivePairSet, sample_size: int
    ) -> Tuple[Classifier, float]:
        """
        Train a classifier with a specific sample size.

        Args:
            train_set: Full training set
            sample_size: Number of samples to use for training

        Returns:
            Tuple of (trained_classifier, training_time)
        """
        # Limit training set to sample_size
        if sample_size >= len(train_set.pairs):
            train_pairs = train_set.pairs
        else:
            # Use first sample_size pairs (already shuffled)
            train_pairs = train_set.pairs[:sample_size]

        logger.info(f"Training classifier with {len(train_pairs)} samples")

        # Ensure we have enough samples for training
        if len(train_pairs) < 2:
            logger.warning(f"Not enough training samples ({len(train_pairs)}). Skipping.")
            return None, 0.0

        # Extract activations
        X_train = []
        y_train = []

        for pair in train_pairs:
            # Positive example (correct answer)
            X_train.append(pair.positive_activations)
            y_train.append(0)  # 0 for correct/truthful

            # Negative example (incorrect answer)
            X_train.append(pair.negative_activations)
            y_train.append(1)  # 1 for incorrect/untruthful

        # Create and train classifier
        classifier = Classifier(model_type="logistic", device=self.device)

        start_time = time.time()
        classifier.fit(X_train, y_train)
        training_time = time.time() - start_time

        return classifier, training_time

    def evaluate_classifier(self, classifier: Classifier, test_set: ContrastivePairSet) -> Dict[str, float]:
        """
        Evaluate a classifier on the test set.

        Args:
            classifier: Trained classifier
            test_set: Test set to evaluate on

        Returns:
            Dictionary of metrics
        """
        X_test = []
        y_test = []

        for pair in test_set.pairs:
            # Positive example
            X_test.append(pair.positive_activations)
            y_test.append(0)

            # Negative example
            X_test.append(pair.negative_activations)
            y_test.append(1)

        # Get predictions
        y_pred = []
        for x in X_test:
            pred = classifier.predict(x)
            y_pred.append(1 if pred > 0.5 else 0)

        # Calculate metrics
        metrics = {
            "accuracy": accuracy_score(y_test, y_pred),
            "precision": precision_score(y_test, y_pred, zero_division=0),
            "recall": recall_score(y_test, y_pred, zero_division=0),
            "f1": f1_score(y_test, y_pred, zero_division=0),
        }

        return metrics

    def find_optimal_sample_size(self) -> int:
        """
        Determine the optimal sample size based on diminishing returns.

        Returns:
            Optimal sample size
        """
        if len(self.results) < 2:
            return self.sample_sizes[-1]

        # Extract accuracies and times
        accuracies = [r["metrics"]["accuracy"] for r in self.results]
        times = [r["training_time"] for r in self.results]
        sizes = [r["sample_size"] for r in self.results]

        # Calculate accuracy gains
        gains = []
        for i in range(1, len(accuracies)):
            gain = accuracies[i] - accuracies[i - 1]
            gains.append(gain)

        # Find where gain drops below threshold (2% improvement)
        threshold = 0.02
        optimal_idx = len(sizes) - 1  # Default to largest

        for i, gain in enumerate(gains):
            if gain < threshold and accuracies[i + 1] > 0.7:  # Ensure reasonable accuracy
                optimal_idx = i + 1
                break

        # Also consider training time - if time increases dramatically, prefer smaller
        if optimal_idx < len(sizes) - 1 and times[optimal_idx] > 0:
            time_ratio = times[optimal_idx + 1] / times[optimal_idx]
            if time_ratio > 2.0 and gains[optimal_idx] < 0.01:
                # Training time doubled for < 1% gain, stick with current
                pass
            elif accuracies[optimal_idx + 1] - accuracies[optimal_idx] > 0.05:
                # Significant accuracy improvement, use larger size
                optimal_idx += 1

        return sizes[optimal_idx]

    def run_optimization(self) -> Dict[str, Any]:
        """
        Run the complete sample size optimization process.

        Returns:
            Dictionary containing results and optimal sample size
        """
        logger.info("Starting sample size optimization...")

        # Load and split data
        dataset_limit = getattr(self, "dataset_limit", None)
        train_set, test_set = self.load_and_split_data(limit=dataset_limit)

        # Ensure we don't test sample sizes larger than training set
        max_train_size = len(train_set.pairs)
        valid_sample_sizes = [s for s in self.sample_sizes if s <= max_train_size]

        if not valid_sample_sizes:
            raise ValueError(f"No valid sample sizes. Training set has only {max_train_size} samples.")

        logger.info(f"Testing sample sizes: {valid_sample_sizes}")

        # Test each sample size
        for sample_size in valid_sample_sizes:
            logger.info(f"\n{'=' * 50}")
            logger.info(f"Testing sample size: {sample_size}")

            # Train classifier
            classifier, training_time = self.train_classifier_with_sample_size(train_set, sample_size)

            # Skip if classifier training failed
            if classifier is None:
                logger.warning(f"Skipping sample size {sample_size} - not enough samples for training")
                continue

            # Evaluate on test set
            metrics = self.evaluate_classifier(classifier, test_set)

            # Store results
            result = {"sample_size": sample_size, "training_time": training_time, "metrics": metrics}
            self.results.append(result)

            logger.info(f"Accuracy: {metrics['accuracy']:.3f}")
            logger.info(f"F1 Score: {metrics['f1']:.3f}")
            logger.info(f"Training time: {training_time:.3f}s")

        # Find optimal sample size
        self.optimal_sample_size = self.find_optimal_sample_size()

        logger.info(f"\n{'=' * 50}")
        logger.info(f"Optimal sample size: {self.optimal_sample_size}")

        # Create summary
        summary = {
            "model": self.model_name,
            "task": self.task_name,
            "layer": self.layer,
            "test_split": self.test_split,
            "results": self.results,
            "optimal_sample_size": self.optimal_sample_size,
            "timestamp": datetime.now().isoformat(),
        }

        return summary

    def save_results(self, output_dir: Optional[str] = None) -> str:
        """
        Save optimization results to file.

        Args:
            output_dir: Directory to save results (uses default if None)

        Returns:
            Path to saved results file
        """
        if output_dir is None:
            output_dir = "./sample_size_optimization_results"

        os.makedirs(output_dir, exist_ok=True)

        # Create filename
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_safe = self.model_name.replace("/", "_")
        filename = f"sample_size_{model_safe}_{self.task_name}_layer{self.layer}_{timestamp}.json"
        filepath = os.path.join(output_dir, filename)

        # Prepare data for saving
        save_data = {
            "model": self.model_name,
            "task": self.task_name,
            "layer": self.layer,
            "test_split": self.test_split,
            "results": self.results,
            "optimal_sample_size": self.optimal_sample_size,
            "timestamp": datetime.now().isoformat(),
        }

        # Save to file
        with open(filepath, "w") as f:
            json.dump(save_data, f, indent=2)

        logger.info(f"Results saved to: {filepath}")
        return filepath

    def plot_results(self, save_path: Optional[str] = None, show: bool = True) -> None:
        """
        Plot accuracy vs sample size curve.

        Args:
            save_path: Path to save plot (optional)
            show: Whether to display the plot
        """
        if not self.results:
            logger.warning("No results to plot")
            return

        # Extract data
        sizes = [r["sample_size"] for r in self.results]
        accuracies = [r["metrics"]["accuracy"] for r in self.results]
        f1_scores = [r["metrics"]["f1"] for r in self.results]
        times = [r["training_time"] for r in self.results]

        # Create figure with subplots
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))

        # Plot 1: Accuracy and F1 vs Sample Size
        ax1.plot(sizes, accuracies, "b-o", label="Accuracy", linewidth=2, markersize=8)
        ax1.plot(sizes, f1_scores, "g--s", label="F1 Score", linewidth=2, markersize=8)

        # Mark optimal sample size
        if self.optimal_sample_size:
            ax1.axvline(
                self.optimal_sample_size, color="r", linestyle=":", label=f"Optimal: {self.optimal_sample_size}"
            )

        ax1.set_xlabel("Sample Size")
        ax1.set_ylabel("Score")
        ax1.set_title(
            f"Classifier Performance vs Sample Size\n{self.model_name} - {self.task_name} - Layer {self.layer}"
        )
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        # Use linear scale for x-axis
        ax1.set_xticks(sizes)
        ax1.set_xticklabels([str(s) for s in sizes])

        # Plot 2: Training Time vs Sample Size
        ax2.plot(sizes, times, "r-^", linewidth=2, markersize=8)
        ax2.set_xlabel("Sample Size")
        ax2.set_ylabel("Training Time (seconds)")
        ax2.set_title("Training Time vs Sample Size")
        ax2.grid(True, alpha=0.3)
        # Use linear scale for x-axis
        ax2.set_xticks(sizes)
        ax2.set_xticklabels([str(s) for s in sizes])

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            logger.info(f"Plot saved to: {save_path}")

        if show:
            plt.show()

        plt.close()


def run_sample_size_optimization(
    model_name: str,
    task_name: str = "truthfulqa_mc1",
    layer: int = 0,
    token_aggregation: str = "average",
    threshold: float = 0.5,
    test_split: float = 0.2,
    sample_sizes: Optional[List[int]] = None,
    dataset_limit: Optional[int] = None,
    device: Optional[str] = None,
    verbose: bool = False,
    save_plot: bool = True,
    save_to_config: bool = True,
) -> Dict[str, Any]:
    """
    Run sample size optimization and optionally save to model config.

    Args:
        model_name: Name of the model
        task_name: Task to optimize for
        layer: Layer index
        token_aggregation: Token aggregation method
        threshold: Detection threshold
        test_split: Test split ratio
        sample_sizes: Sample sizes to test
        dataset_limit: Maximum number of samples to load from dataset
        device: Computation device
        verbose: Verbose output
        save_plot: Whether to save the plot
        save_to_config: Whether to save to model config

    Returns:
        Optimization results dictionary
    """
    # Create optimizer
    optimizer = SampleSizeOptimizer(
        model_name=model_name,
        task_name=task_name,
        layer=layer,
        token_aggregation=token_aggregation,
        threshold=threshold,
        test_split=test_split,
        sample_sizes=sample_sizes,
        device=device,
        verbose=verbose,
    )

    # Run optimization with dataset limit
    optimizer.dataset_limit = dataset_limit
    results = optimizer.run_optimization()

    # Save results
    results_path = optimizer.save_results()

    # Create plot
    if save_plot:
        plot_path = results_path.replace(".json", ".png")
        optimizer.plot_results(save_path=plot_path, show=False)

    # Save to model config if requested
    if save_to_config and optimizer.optimal_sample_size:
        config_manager = ModelConfigManager()

        # Load existing config or create new
        existing_config = config_manager.load_model_config(model_name)

        if existing_config:
            # Update existing config
            if "optimal_sample_sizes" not in existing_config:
                existing_config["optimal_sample_sizes"] = {}

            if task_name not in existing_config["optimal_sample_sizes"]:
                existing_config["optimal_sample_sizes"][task_name] = {}

            existing_config["optimal_sample_sizes"][task_name][str(layer)] = optimizer.optimal_sample_size

            # Save updated config
            config_manager.update_model_config(model_name, existing_config)
            logger.info(f"Updated model config with optimal sample size: {optimizer.optimal_sample_size}")
        else:
            logger.warning("No existing model config found. Run optimize-classification first.")

    return results
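
For reference, a minimal usage sketch of the new run_sample_size_optimization entry point. The import path is assumed from the file listing above (wisent/core/sample_size_optimizer.py), and the model name is only an example; neither is confirmed by the diff itself.

    # Hypothetical usage sketch; module path and model name are assumptions.
    from wisent.core.sample_size_optimizer import run_sample_size_optimization

    results = run_sample_size_optimization(
        model_name="meta-llama/Llama-3.1-8B-Instruct",  # example model, not from the diff
        task_name="truthfulqa_mc1",
        layer=15,
        sample_sizes=[10, 50, 100, 200],
        dataset_limit=500,
        save_plot=True,
        save_to_config=False,
    )
    # The summary dict mirrors what run_optimization() returns in the diff above.
    print(results["optimal_sample_size"])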