wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,509 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Classifier model caching system for efficient Optuna optimization.
|
|
3
|
+
|
|
4
|
+
This module provides intelligent caching of trained classifier models to avoid
|
|
5
|
+
retraining identical configurations across optimization runs and trials.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import hashlib
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
import pickle
|
|
12
|
+
import time
|
|
13
|
+
from dataclasses import asdict, dataclass
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Any, Optional
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
from wisent_guard.core.classifier.classifier import Classifier
|
|
20
|
+
|
|
21
|
+
# Module-level logger. ClassifierCache methods do not use it; they log through
# their own child logger (f"{__name__}.ClassifierCache") created in __init__.
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class CacheMetadata:
    """Metadata describing one classifier stored in the on-disk cache.

    One instance exists per cache key. It is converted to a plain dict with
    :meth:`to_dict` for JSON persistence and rebuilt with :meth:`from_dict`.
    """

    cache_key: str
    model_name: str
    task_name: str
    model_type: str
    layer: int
    aggregation: str
    threshold: float
    hyperparameters: dict[str, Any]
    performance_metrics: dict[str, float]
    training_samples: int
    data_hash: str
    timestamp: float  # creation time, as returned by time.time()
    file_size_mb: float  # size of the pickled model file in MiB (bytes / 1024**2)

    def to_dict(self) -> dict[str, Any]:
        """Return all fields as a plain dict (nested containers are copied) for JSON serialization."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "CacheMetadata":
        """Create an instance from a dict produced by :meth:`to_dict`.

        Keys that are not fields of this dataclass are ignored. This lets an
        index written by a version with extra fields still be read. Without it,
        a single unknown key raises ``TypeError`` and callers that treat any
        failure as "no metadata" lose the whole index.

        Args:
            data: Mapping of field name to value. It is not mutated.

        Returns:
            The reconstructed metadata.

        Raises:
            TypeError: If a required field is missing from ``data``.
        """
        from dataclasses import fields

        known_fields = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in known_fields})
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
class CacheConfig:
    """Settings for the classifier cache.

    Constructing an instance creates ``cache_dir`` (including any missing
    parent directories) if it does not exist yet.
    """

    cache_dir: str = "./classifier_cache"
    max_cache_size_gb: float = 5.0
    max_age_days: float = 30.0
    memory_cache_size: int = 10  # Number of models to keep in memory

    def __post_init__(self):
        # Ensure the directory is there before anything tries to read or write it.
        directory = Path(self.cache_dir)
        directory.mkdir(exist_ok=True, parents=True)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ClassifierCache:
    """
    Intelligent caching system for trained classifier models.

    Features:
    - Hash-based cache keys for deterministic caching
    - Persistent disk storage with metadata
    - In-memory hot cache for frequently used models
    - Automatic cleanup based on size and age limits
    - Performance metrics tracking

    On disk, each model is pickled to ``<cache_dir>/<cache_key>.pkl``. An index
    of ``CacheMetadata`` entries lives in ``<cache_dir>/cache_metadata.json``.
    Both are written to a temporary sibling file and then renamed into place, so
    an interrupted write cannot leave a truncated file under the real name.

    Age and size limits are enforced once, when the cache is constructed.
    """

    def __init__(self, config: "CacheConfig"):
        """
        Args:
            config: Cache settings. Reads ``cache_dir``, ``max_cache_size_gb``,
                ``max_age_days`` and ``memory_cache_size``.
        """
        self.config = config
        self.cache_dir = Path(config.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.metadata_file = self.cache_dir / "cache_metadata.json"
        # Hot cache of unpickled classifiers, plus last-access times used for LRU eviction.
        self.memory_cache: dict[str, "Classifier"] = {}
        self.access_times: dict[str, float] = {}

        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

        # Load existing metadata
        self.metadata = self._load_metadata()

        # Cleanup old/large cache if needed
        self._cleanup_cache()

    def _model_path(self, cache_key: str) -> Path:
        """Return the pickle file path used for ``cache_key``."""
        return self.cache_dir / f"{cache_key}.pkl"

    @staticmethod
    def _tmp_path(path: Path) -> Path:
        """Return the sibling temporary path used for atomic writes to ``path``."""
        return path.with_name(path.name + ".tmp")

    def get_cache_key(
        self,
        model_name: str,
        task_name: str,
        model_type: str,
        layer: int,
        aggregation: str,
        threshold: float,
        hyperparameters: dict[str, Any],
        data_hash: str,
    ) -> str:
        """
        Generate deterministic cache key for classifier configuration.

        The threshold is formatted to 3 decimal places, so thresholds differing
        by less than that map to the same key. Hyperparameters are serialized
        with ``json.dumps(..., sort_keys=True)``, so key order does not matter.
        They must be JSON-serializable, otherwise ``TypeError`` is raised.

        Args:
            model_name: Name of the base model
            task_name: Task being optimized
            model_type: Type of classifier ("logistic", "mlp")
            layer: Layer index used
            aggregation: Token aggregation method
            threshold: Classification threshold
            hyperparameters: Model-specific hyperparameters
            data_hash: Hash of the training data

        Returns:
            Unique cache key string (first 16 hex chars of a SHA-256 digest)
        """
        # Normalize model name
        clean_model_name = model_name.replace("/", "_").replace(":", "_")

        # Sort hyperparameters for consistent hashing
        sorted_hyperparams = json.dumps(hyperparameters, sort_keys=True)

        # Create cache key components
        key_components = [
            clean_model_name,
            task_name,
            model_type,
            str(layer),
            aggregation,
            f"{threshold:.3f}",
            sorted_hyperparams,
            data_hash,
        ]

        # Generate hash
        key_string = "_".join(key_components)
        cache_key = hashlib.sha256(key_string.encode()).hexdigest()[:16]  # First 16 chars

        return cache_key

    def has_cached_model(self, cache_key: str) -> bool:
        """Check if a model with the given cache key exists."""
        return cache_key in self.metadata or cache_key in self.memory_cache

    def _store_in_memory(self, cache_key: str, classifier: "Classifier", evict: bool) -> None:
        """Put ``classifier`` in the in-memory cache and refresh its access time.

        An existing entry for ``cache_key`` is always replaced, so memory never
        holds a stale object. A new entry is added only when there is room. If
        the cache is full, the least recently accessed entries are evicted first
        when ``evict`` is True; otherwise the new entry is skipped. Nothing is
        stored when ``memory_cache_size`` is 0 or negative.
        """
        capacity = self.config.memory_cache_size
        if capacity <= 0:
            return

        if cache_key not in self.memory_cache and len(self.memory_cache) >= capacity:
            if not evict:
                return
            # Evict least recently accessed models until there is room for one more.
            while len(self.memory_cache) >= capacity and self.access_times:
                oldest_key = min(self.access_times, key=self.access_times.__getitem__)
                self.memory_cache.pop(oldest_key, None)
                del self.access_times[oldest_key]

        self.memory_cache[cache_key] = classifier
        self.access_times[cache_key] = time.time()

    def save_classifier(
        self,
        cache_key: str,
        classifier: "Classifier",
        model_name: str,
        task_name: str,
        layer: int,
        aggregation: str,
        threshold: float,
        hyperparameters: dict[str, Any],
        performance_metrics: dict[str, float],
        training_samples: int,
        data_hash: str,
    ) -> None:
        """
        Save a trained classifier to cache.

        The classifier is pickled to ``<cache_dir>/<cache_key>.pkl`` through a
        temporary file and an atomic rename. Its metadata entry is then added or
        replaced and the index is written to disk.

        The object is also placed in the in-memory cache if there is free room,
        or if ``cache_key`` is already held there (so a re-save never leaves a
        stale object in memory). Nothing is evicted on save, and age/size limits
        are not re-checked here.

        Args:
            cache_key: Unique cache key
            classifier: Trained classifier model (must be picklable and expose ``model_type``)
            model_name: Name of base model
            task_name: Task name
            layer: Layer index
            aggregation: Aggregation method
            threshold: Classification threshold
            hyperparameters: Model hyperparameters
            performance_metrics: Training/validation metrics
            training_samples: Number of training samples
            data_hash: Hash of training data

        Raises:
            Exception: Any error while pickling or writing the model is logged and re-raised.
        """
        try:
            # Save model to disk: write a temp file, then atomically rename it so a
            # crash mid-pickle never leaves a truncated .pkl under the final name.
            model_file = self._model_path(cache_key)
            tmp_file = self._tmp_path(model_file)
            try:
                with open(tmp_file, "wb") as f:
                    pickle.dump(classifier, f)
                tmp_file.replace(model_file)
            finally:
                # No-op after a successful rename; removes partial output otherwise.
                tmp_file.unlink(missing_ok=True)

            # Calculate file size
            file_size_mb = model_file.stat().st_size / (1024 * 1024)

            # Create metadata
            metadata = CacheMetadata(
                cache_key=cache_key,
                model_name=model_name,
                task_name=task_name,
                model_type=classifier.model_type,
                layer=layer,
                aggregation=aggregation,
                threshold=threshold,
                hyperparameters=hyperparameters,
                performance_metrics=performance_metrics,
                training_samples=training_samples,
                data_hash=data_hash,
                timestamp=time.time(),
                file_size_mb=file_size_mb,
            )

            # Update metadata
            self.metadata[cache_key] = metadata
            self._save_metadata()

            # Add to (or refresh in) memory cache without evicting anything.
            self._store_in_memory(cache_key, classifier, evict=False)

            self.logger.info(
                f"Cached classifier {cache_key}: {model_name}/{task_name} "
                f"layer_{layer} {classifier.model_type} ({file_size_mb:.2f}MB)"
            )

        except Exception as e:
            self.logger.error(f"Failed to save classifier {cache_key}: {e}")
            raise

    def load_classifier(self, cache_key: str) -> Optional["Classifier"]:
        """
        Load a cached classifier model.

        The in-memory cache is checked first, then the pickle on disk. A
        disk-loaded model is added to the memory cache, evicting the least
        recently used entry if the cache is full. If the metadata lists the key
        but the file is gone, the stale metadata entry is removed.

        Args:
            cache_key: Cache key to load

        Returns:
            Loaded classifier, or None if not found or if unpickling failed
        """
        # Try memory cache first
        if cache_key in self.memory_cache:
            self.access_times[cache_key] = time.time()
            self.logger.debug(f"Loaded classifier {cache_key} from memory cache")
            return self.memory_cache[cache_key]

        # Try disk cache
        if cache_key not in self.metadata:
            return None

        model_file = self._model_path(cache_key)
        if not model_file.exists():
            self.logger.warning(f"Cache file missing for {cache_key}")
            # Remove from metadata
            del self.metadata[cache_key]
            self._save_metadata()
            return None

        try:
            # NOTE(security): pickle.load can execute arbitrary code. cache_dir must
            # only contain files written by this cache or another trusted source.
            with open(model_file, "rb") as f:
                classifier = pickle.load(f)
        except Exception as e:
            self.logger.error(f"Failed to load classifier {cache_key}: {e}")
            return None

        # Add to memory cache (evicting the least recently used entry if needed).
        self._store_in_memory(cache_key, classifier, evict=True)

        self.logger.debug(f"Loaded classifier {cache_key} from disk cache")
        return classifier

    def get_cache_info(self) -> dict[str, Any]:
        """Return summary statistics about the cache contents and its configuration."""
        now = time.time()
        total_size_mb = sum(metadata.file_size_mb for metadata in self.metadata.values())

        # Group by task and model type
        task_counts: dict[str, int] = {}
        model_type_counts: dict[str, int] = {}

        for metadata in self.metadata.values():
            task_counts[metadata.task_name] = task_counts.get(metadata.task_name, 0) + 1
            model_type_counts[metadata.model_type] = model_type_counts.get(metadata.model_type, 0) + 1

        # With an empty cache the oldest age is reported as 0 hours.
        oldest_timestamp = min((m.timestamp for m in self.metadata.values()), default=now)

        return {
            "total_models": len(self.metadata),
            "total_size_mb": total_size_mb,
            "memory_cache_size": len(self.memory_cache),
            "cache_dir": str(self.cache_dir),
            "task_distribution": task_counts,
            "model_type_distribution": model_type_counts,
            "oldest_cache_age_hours": (now - oldest_timestamp) / 3600,
            "config": asdict(self.config),
        }

    def find_similar_models(
        self,
        model_name: str,
        task_name: str,
        model_type: Optional[str] = None,
        layer: Optional[int] = None,
        top_k: int = 5,
    ) -> list[tuple[str, "CacheMetadata", float]]:
        """
        Rank cached models by how closely their configuration matches the query.

        Each entry is scored as follows:
        - +0.4 if ``model_name`` matches.
        - +0.3 if ``task_name`` matches.
        - +0.2 if ``model_type`` is given and matches.
        - Up to +0.1 for layer proximity when ``layer`` is given. This decays
          linearly to 0 at a distance of 10 layers.

        ``model_type`` and ``layer`` only contribute to the score; they do not
        filter entries out. Entries scoring 0.1 or less are dropped, so a layer
        match alone never qualifies.

        Args:
            model_name: Base model name
            task_name: Task name
            model_type: Optional model type to reward in the score
            layer: Optional layer to reward proximity to
            top_k: Maximum number of results

        Returns:
            List of (cache_key, metadata, similarity_score) tuples, best first
        """
        candidates = []

        for cache_key, metadata in self.metadata.items():
            # Calculate similarity score
            score = 0.0

            # Model name match (highest weight)
            if metadata.model_name == model_name:
                score += 0.4

            # Task name match
            if metadata.task_name == task_name:
                score += 0.3

            # Model type match
            if model_type and metadata.model_type == model_type:
                score += 0.2

            # Layer proximity
            if layer is not None:
                layer_diff = abs(metadata.layer - layer)
                layer_score = max(0, 1.0 - layer_diff / 10.0)  # Decay with distance
                score += 0.1 * layer_score

            # Only include models with some similarity
            if score > 0.1:
                candidates.append((cache_key, metadata, score))

        # Sort by similarity score and return top_k
        candidates.sort(key=lambda x: x[2], reverse=True)
        return candidates[:top_k]

    def _remove_entry(self, cache_key: str) -> None:
        """Delete ``cache_key``'s pickle and drop it from memory and from the metadata dict.

        Does not persist the metadata. If deleting the file raises, the
        exception propagates and the entry is left in place.
        """
        self._model_path(cache_key).unlink(missing_ok=True)
        self.memory_cache.pop(cache_key, None)
        self.access_times.pop(cache_key, None)
        del self.metadata[cache_key]

    def clear_cache(self, keep_recent_hours: float = 0) -> int:
        """
        Clear cached models.

        Args:
            keep_recent_hours: Keep models newer than this many hours

        Returns:
            Number of models removed
        """
        cutoff_time = time.time() - (keep_recent_hours * 3600)
        keys_to_remove = [key for key, meta in self.metadata.items() if meta.timestamp < cutoff_time]

        removed_count = 0
        for cache_key in keys_to_remove:
            try:
                self._remove_entry(cache_key)
                removed_count += 1
            except Exception as e:
                self.logger.warning(f"Failed to remove cached model {cache_key}: {e}")

        self._save_metadata()
        self.logger.info(f"Cleared {removed_count} cached models")
        return removed_count

    def _load_metadata(self) -> dict[str, "CacheMetadata"]:
        """Load the metadata index from disk.

        Returns an empty dict if the index file is missing or cannot be parsed;
        the failure is logged as a warning.
        """
        if not self.metadata_file.exists():
            return {}

        try:
            with open(self.metadata_file, encoding="utf-8") as f:
                data = json.load(f)

            metadata = {
                cache_key: CacheMetadata.from_dict(metadata_dict) for cache_key, metadata_dict in data.items()
            }

            self.logger.debug(f"Loaded metadata for {len(metadata)} cached models")
            return metadata

        except Exception as e:
            self.logger.warning(f"Failed to load cache metadata: {e}")
            return {}

    def _save_metadata(self) -> None:
        """Write the metadata index to disk (best effort).

        The JSON is written to a temporary file and renamed over the real one,
        so a crash mid-write cannot corrupt the index. Errors are logged and not
        raised.
        """
        tmp_file = self._tmp_path(self.metadata_file)
        try:
            data = {cache_key: metadata.to_dict() for cache_key, metadata in self.metadata.items()}

            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            tmp_file.replace(self.metadata_file)

        except Exception as e:
            self.logger.error(f"Failed to save cache metadata: {e}")
            try:
                tmp_file.unlink(missing_ok=True)
            except OSError:
                pass  # Deliberately best-effort: a leftover .tmp file is harmless.

    def _cleanup_cache(self) -> None:
        """Enforce ``max_age_days`` and ``max_cache_size_gb`` on the disk cache.

        Entries older than the age limit are removed first. If the remaining
        total size still exceeds the size limit (GB treated as 1024 MiB), the
        largest entries are removed until it fits. The metadata is saved
        afterwards.
        """
        current_time = time.time()

        # Remove old models
        old_threshold = current_time - (self.config.max_age_days * 24 * 3600)
        old_models = [cache_key for cache_key, metadata in self.metadata.items() if metadata.timestamp < old_threshold]

        if old_models:
            removed_old = 0
            for cache_key in old_models:
                try:
                    self._remove_entry(cache_key)
                    removed_old += 1
                except Exception as e:
                    self.logger.warning(f"Failed to remove old model {cache_key}: {e}")

            self.logger.info(f"Removed {removed_old} old cached models")

        # Remove largest models if over size limit
        max_size_mb = self.config.max_cache_size_gb * 1024
        total_size_mb = sum(metadata.file_size_mb for metadata in self.metadata.values())

        if total_size_mb > max_size_mb:
            # Sort by size (largest first); iterating a copy so entries can be deleted.
            models_by_size = sorted(self.metadata.items(), key=lambda x: x[1].file_size_mb, reverse=True)

            removed_count = 0
            for cache_key, metadata in models_by_size:
                if total_size_mb <= max_size_mb:
                    break

                try:
                    self._remove_entry(cache_key)
                except Exception as e:
                    self.logger.warning(f"Failed to remove large model {cache_key}: {e}")
                    continue

                total_size_mb -= metadata.file_size_mb
                removed_count += 1

            if removed_count > 0:
                self.logger.info(f"Removed {removed_count} large cached models to free space")

        # Save updated metadata
        self._save_metadata()

    @staticmethod
    def _tensor_bytes(tensor: torch.Tensor) -> bytes:
        """Return the raw bytes of ``tensor`` after moving it to CPU.

        Dtypes that NumPy cannot represent (e.g. bfloat16, where ``.numpy()``
        raises ``TypeError``) are converted to float32 first.
        """
        data = tensor.detach().cpu()
        try:
            array = data.numpy()
        except TypeError:
            array = data.to(torch.float32).numpy()
        return array.tobytes()

    def compute_data_hash(self, X: torch.Tensor, y: torch.Tensor) -> str:
        """
        Compute a fingerprint of the training data for cache key generation.

        The shapes and the full contents of both tensors are hashed. A partial
        sample is not enough: a collision here makes the cache hand back a
        classifier trained on different data. Hashing is linear in the data
        size, which is cheaper than a single training pass over the same
        tensors.

        Args:
            X: Training features (torch tensor, any device; bfloat16 supported)
            y: Training labels (torch tensor, any device)

        Returns:
            Hash string ``"<x_shape>_<y_shape>_<x_data>_<y_data>"`` made of four
            8-hex-digit MD5 prefixes (a fingerprint, not a security hash)
        """
        x_hash = hashlib.md5(str(tuple(X.shape)).encode()).hexdigest()[:8]
        y_hash = hashlib.md5(str(tuple(y.shape)).encode()).hexdigest()[:8]

        x_data_hash = hashlib.md5(self._tensor_bytes(X)).hexdigest()[:8]
        y_data_hash = hashlib.md5(self._tensor_bytes(y)).hexdigest()[:8]

        return f"{x_hash}_{y_hash}_{x_data_hash}_{y_data_hash}"
|