wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (237) hide show
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,509 @@
1
+ """
2
+ Classifier model caching system for efficient Optuna optimization.
3
+
4
+ This module provides intelligent caching of trained classifier models to avoid
5
+ retraining identical configurations across optimization runs and trials.
6
+ """
7
+
8
+ import hashlib
9
+ import json
10
+ import logging
11
+ import pickle
12
+ import time
13
+ from dataclasses import asdict, dataclass
14
+ from pathlib import Path
15
+ from typing import Any, Optional
16
+
17
+ import torch
18
+
19
+ from wisent_guard.core.classifier.classifier import Classifier
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
@dataclass
class CacheMetadata:
    """Descriptive record kept for one cached classifier model.

    Instances are converted to plain dictionaries with :meth:`to_dict` and
    rebuilt from such dictionaries with :meth:`from_dict`.
    """

    cache_key: str
    model_name: str
    task_name: str
    model_type: str
    layer: int
    aggregation: str
    threshold: float
    hyperparameters: dict[str, Any]
    performance_metrics: dict[str, float]
    training_samples: int
    data_hash: str
    timestamp: float
    file_size_mb: float

    def to_dict(self) -> dict[str, Any]:
        """Return the record as a plain dictionary (via ``dataclasses.asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "CacheMetadata":
        """Build a record from a dictionary such as the one ``to_dict`` produces.

        Keys that do not correspond to a declared field are ignored, so a
        dictionary carrying additional entries can still be read.

        Args:
            data: Mapping of field names to values.

        Returns:
            The reconstructed ``CacheMetadata``.

        Raises:
            TypeError: If a required field is absent from ``data``.
        """
        # Function-scope import: the module only imports asdict/dataclass.
        from dataclasses import fields

        known_names = {field.name for field in fields(cls)}
        recognised = {name: value for name, value in data.items() if name in known_names}
        return cls(**recognised)
50
+
51
+
52
@dataclass
class CacheConfig:
    """Settings for the classifier cache.

    Creating an instance also creates ``cache_dir`` (including missing parent
    directories) on disk.
    """

    cache_dir: str = "./classifier_cache"
    max_cache_size_gb: float = 5.0
    max_age_days: float = 30.0
    memory_cache_size: int = 10  # how many models may be held in memory

    def __post_init__(self):
        # Ensure the configured directory is present; an existing one is fine.
        target_dir = Path(self.cache_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
63
+
64
+
65
class ClassifierCache:
    """
    Cache for trained classifier models, backed by disk and fronted by memory.

    What it offers:
    - Deterministic, hash-derived cache keys
    - Pickled models on disk with a JSON metadata index
    - A bounded in-memory cache for recently used models
    - Cleanup driven by configured age and size limits
    - Performance metrics recorded with every entry
    """

    def __init__(self, config: CacheConfig):
        """Prepare the cache directory, read the metadata index and run cleanup.

        Args:
            config: Cache settings; ``config.cache_dir`` is created if missing.
        """
        self.config = config

        # Set the logger up early: the helpers called below report through it.
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

        root = Path(config.cache_dir)
        root.mkdir(parents=True, exist_ok=True)
        self.cache_dir = root
        self.metadata_file = root / "cache_metadata.json"

        # In-memory models plus the last access time recorded for each key.
        self.memory_cache: dict[str, Classifier] = {}
        self.access_times: dict[str, float] = {}

        # Read whatever index already exists, then enforce age/size limits.
        self.metadata = self._load_metadata()
        self._cleanup_cache()
93
+
94
def get_cache_key(
    self,
    model_name: str,
    task_name: str,
    model_type: str,
    layer: int,
    aggregation: str,
    threshold: float,
    hyperparameters: dict[str, Any],
    data_hash: str,
) -> str:
    """
    Derive a deterministic cache key from a classifier configuration.

    Args:
        model_name: Name of the base model ("/" and ":" are replaced by "_")
        task_name: Task being optimized
        model_type: Type of classifier ("logistic", "mlp")
        layer: Layer index used
        aggregation: Token aggregation method
        threshold: Classification threshold (formatted to three decimals)
        hyperparameters: Model-specific hyperparameters (must be JSON-serializable)
        data_hash: Hash of the training data

    Returns:
        The first 16 hex characters of the SHA-256 digest of the joined parts
    """
    # One pass replaces both separator characters.
    safe_model_name = model_name.translate(str.maketrans("/:", "__"))

    # sort_keys makes the JSON text independent of dict insertion order.
    parts = (
        safe_model_name,
        task_name,
        model_type,
        str(layer),
        aggregation,
        f"{threshold:.3f}",
        json.dumps(hyperparameters, sort_keys=True),
        data_hash,
    )

    digest = hashlib.sha256("_".join(parts).encode()).hexdigest()
    return digest[:16]
144
+
145
def has_cached_model(self, cache_key: str) -> bool:
    """Tell whether ``cache_key`` appears in the metadata index or the memory cache."""
    if cache_key in self.metadata:
        return True
    return cache_key in self.memory_cache
148
+
149
def save_classifier(
    self,
    cache_key: str,
    classifier: Classifier,
    model_name: str,
    task_name: str,
    layer: int,
    aggregation: str,
    threshold: float,
    hyperparameters: dict[str, Any],
    performance_metrics: dict[str, float],
    training_samples: int,
    data_hash: str,
) -> None:
    """
    Persist a trained classifier and record its metadata.

    The model is pickled to ``<cache_dir>/<cache_key>.pkl``, a metadata entry
    is stored under ``cache_key`` and the metadata index is written out. The
    classifier is also placed in the in-memory cache when there is room, or
    when the key is already held there (so memory never serves an older
    object than the one just written to disk).

    Args:
        cache_key: Unique cache key
        classifier: Trained classifier model (its ``model_type`` is recorded)
        model_name: Name of base model
        task_name: Task name
        layer: Layer index
        aggregation: Aggregation method
        threshold: Classification threshold
        hyperparameters: Model hyperparameters
        performance_metrics: Training/validation metrics
        training_samples: Number of training samples
        data_hash: Hash of training data

    Raises:
        Exception: Any failure is logged and then re-raised unchanged.
    """
    try:
        # Write the pickled model to disk
        model_file = self.cache_dir / f"{cache_key}.pkl"
        with open(model_file, "wb") as f:
            pickle.dump(classifier, f)

        # Size of the file just written, in MiB
        file_size_mb = model_file.stat().st_size / (1024 * 1024)

        metadata = CacheMetadata(
            cache_key=cache_key,
            model_name=model_name,
            task_name=task_name,
            model_type=classifier.model_type,
            layer=layer,
            aggregation=aggregation,
            threshold=threshold,
            hyperparameters=hyperparameters,
            performance_metrics=performance_metrics,
            training_samples=training_samples,
            data_hash=data_hash,
            timestamp=time.time(),
            file_size_mb=file_size_mb,
        )

        self.metadata[cache_key] = metadata
        self._save_metadata()

        # Refresh an entry that is already in memory even when the cache is
        # full; otherwise later loads would return the previous object for
        # this key. New keys are only added while there is free space.
        already_in_memory = cache_key in self.memory_cache
        if already_in_memory or len(self.memory_cache) < self.config.memory_cache_size:
            self.memory_cache[cache_key] = classifier
            self.access_times[cache_key] = time.time()

        self.logger.info(
            f"Cached classifier {cache_key}: {model_name}/{task_name} "
            f"layer_{layer} {classifier.model_type} ({file_size_mb:.2f}MB)"
        )

    except Exception as e:
        self.logger.error(f"Failed to save classifier {cache_key}: {e}")
        raise
222
+
223
def load_classifier(self, cache_key: str) -> Optional[Classifier]:
    """
    Return the cached classifier for ``cache_key``, if there is one.

    The in-memory cache is consulted first. Otherwise the pickle file named
    after the key is read from the cache directory and, when the configured
    memory cache size is positive, kept in memory (evicting the least
    recently accessed entry if the memory cache is full). A metadata entry
    whose file has disappeared is dropped from the index.

    Args:
        cache_key: Cache key to load

    Returns:
        Loaded classifier, or None if the key is unknown, its file is missing
        or the file cannot be unpickled
    """
    # Memory cache hit: just refresh the access time.
    if cache_key in self.memory_cache:
        self.access_times[cache_key] = time.time()
        self.logger.debug(f"Loaded classifier {cache_key} from memory cache")
        return self.memory_cache[cache_key]

    # Unknown to the disk index as well.
    if cache_key not in self.metadata:
        return None

    model_file = self.cache_dir / f"{cache_key}.pkl"
    if not model_file.exists():
        self.logger.warning(f"Cache file missing for {cache_key}")
        # The index entry is stale; forget it.
        del self.metadata[cache_key]
        self._save_metadata()
        return None

    try:
        # NOTE(security): unpickling runs code embedded in the file, so the
        # cache directory must only contain files written by trusted code.
        with open(model_file, "rb") as f:
            classifier = pickle.load(f)
    except Exception as e:
        self.logger.error(f"Failed to load classifier {cache_key}: {e}")
        return None

    # Memory bookkeeping happens outside the try block so that it can never
    # turn a successful load into a None result. A non-positive size simply
    # disables the memory cache (previously min() failed on the empty map).
    capacity = self.config.memory_cache_size
    if capacity > 0:
        while self.memory_cache and len(self.memory_cache) >= capacity:
            # Evict the entry with the oldest recorded access time.
            oldest_key = min(self.memory_cache, key=lambda key: self.access_times.get(key, 0.0))
            del self.memory_cache[oldest_key]
            self.access_times.pop(oldest_key, None)

        self.memory_cache[cache_key] = classifier
        self.access_times[cache_key] = time.time()

    self.logger.debug(f"Loaded classifier {cache_key} from disk cache")
    return classifier
271
+
272
def get_cache_info(self) -> dict[str, Any]:
    """Summarise the cache: entry counts, sizes, distributions, age and config."""
    entries = list(self.metadata.values())
    total_size_mb = sum(entry.file_size_mb for entry in entries)

    # Number of entries per task and per classifier type.
    per_task = {}
    per_model_type = {}
    for entry in entries:
        per_task.setdefault(entry.task_name, 0)
        per_task[entry.task_name] += 1
        per_model_type.setdefault(entry.model_type, 0)
        per_model_type[entry.model_type] += 1

    # With no entries the "oldest" timestamp falls back to the current time.
    now = time.time()
    oldest_timestamp = min((entry.timestamp for entry in entries), default=time.time())
    oldest_age_hours = (now - oldest_timestamp) / 3600

    summary = {
        "total_models": len(self.metadata),
        "total_size_mb": total_size_mb,
        "memory_cache_size": len(self.memory_cache),
        "cache_dir": str(self.cache_dir),
        "task_distribution": per_task,
        "model_type_distribution": per_model_type,
        "oldest_cache_age_hours": oldest_age_hours,
        "config": asdict(self.config),
    }
    return summary
297
+
298
def find_similar_models(
    self,
    model_name: str,
    task_name: str,
    model_type: Optional[str] = None,
    layer: Optional[int] = None,
    top_k: int = 5,
) -> list[tuple[str, CacheMetadata, float]]:
    """
    Rank cached models by how closely their configuration matches a query.

    Scoring adds 0.4 for an identical model name, 0.3 for an identical task
    name, 0.2 for an identical model type (only when ``model_type`` is given)
    and up to 0.1 for layer proximity (only when ``layer`` is given; the
    bonus falls linearly to zero at a distance of 10 layers). Entries scoring
    0.1 or less are left out.

    Args:
        model_name: Base model name
        task_name: Task name
        model_type: Optional model type to match
        layer: Optional layer to compare against
        top_k: Maximum number of results

    Returns:
        List of (cache_key, metadata, similarity_score) tuples, best first
    """

    def _similarity(entry) -> float:
        # Accumulate in a fixed order so float results stay reproducible.
        total = 0.0
        if entry.model_name == model_name:
            total += 0.4
        if entry.task_name == task_name:
            total += 0.3
        if model_type and entry.model_type == model_type:
            total += 0.2
        if layer is not None:
            distance = abs(entry.layer - layer)
            total += 0.1 * max(0, 1.0 - distance / 10.0)
        return total

    matches = [
        (cache_key, entry, similarity)
        for cache_key, entry in self.metadata.items()
        if (similarity := _similarity(entry)) > 0.1
    ]

    # Stable sort: entries with equal scores keep their index order.
    ranked = sorted(matches, key=lambda match: match[2], reverse=True)
    return ranked[:top_k]
350
+
351
def clear_cache(self, keep_recent_hours: float = 0) -> int:
    """
    Delete cached models older than a cutoff.

    For each affected entry the pickle file (if present), the in-memory copy,
    its access time and its metadata record are removed. The metadata index
    is written out afterwards.

    Args:
        keep_recent_hours: Entries saved within this many hours are kept

    Returns:
        Number of models removed
    """
    cutoff_time = time.time() - (keep_recent_hours * 3600)

    # Collect the keys first so the index is not modified while iterating it.
    keys_to_remove = [
        cache_key for cache_key, metadata in self.metadata.items() if metadata.timestamp < cutoff_time
    ]

    removed_count = 0
    for cache_key in keys_to_remove:
        try:
            model_file = self.cache_dir / f"{cache_key}.pkl"
            if model_file.exists():
                model_file.unlink()

            # Drop the in-memory copy and its bookkeeping, if any.
            self.memory_cache.pop(cache_key, None)
            self.access_times.pop(cache_key, None)

            del self.metadata[cache_key]
        except Exception as e:
            self.logger.warning(f"Failed to remove cached model {cache_key}: {e}")
        else:
            removed_count += 1

    self._save_metadata()
    self.logger.info(f"Cleared {removed_count} cached models")
    return removed_count
392
+
393
def _load_metadata(self) -> dict[str, CacheMetadata]:
    """Read the metadata index from disk.

    Returns:
        Mapping of cache key to ``CacheMetadata``. An empty mapping is
        returned when the index file is absent or cannot be read/parsed
        (the latter is logged as a warning).
    """
    index_path = self.metadata_file
    if not index_path.exists():
        return {}

    try:
        with open(index_path) as f:
            raw_index = json.load(f)

        # Rebuild one CacheMetadata per stored dictionary.
        metadata = {
            cache_key: CacheMetadata.from_dict(raw_entry) for cache_key, raw_entry in raw_index.items()
        }
    except Exception as e:
        self.logger.warning(f"Failed to load cache metadata: {e}")
        return {}

    self.logger.debug(f"Loaded metadata for {len(metadata)} cached models")
    return metadata
412
+
413
+ def _save_metadata(self) -> None:
414
+ """Save cache metadata to disk."""
415
+ try:
416
+ data = {}
417
+ for cache_key, metadata in self.metadata.items():
418
+ data[cache_key] = metadata.to_dict()
419
+
420
+ with open(self.metadata_file, "w") as f:
421
+ json.dump(data, f, indent=2)
422
+
423
+ except Exception as e:
424
+ self.logger.error(f"Failed to save cache metadata: {e}")
425
+
426
+ def _cleanup_cache(self) -> None:
427
+ """Clean up cache based on size and age limits."""
428
+ current_time = time.time()
429
+ total_size_mb = sum(metadata.file_size_mb for metadata in self.metadata.values())
430
+
431
+ # Remove old models
432
+ old_threshold = current_time - (self.config.max_age_days * 24 * 3600)
433
+ old_models = [cache_key for cache_key, metadata in self.metadata.items() if metadata.timestamp < old_threshold]
434
+
435
+ if old_models:
436
+ for cache_key in old_models:
437
+ try:
438
+ model_file = self.cache_dir / f"{cache_key}.pkl"
439
+ if model_file.exists():
440
+ model_file.unlink()
441
+ del self.metadata[cache_key]
442
+ except Exception as e:
443
+ self.logger.warning(f"Failed to remove old model {cache_key}: {e}")
444
+
445
+ self.logger.info(f"Removed {len(old_models)} old cached models")
446
+ total_size_mb = sum(metadata.file_size_mb for metadata in self.metadata.values())
447
+
448
+ # Remove largest models if over size limit
449
+ if total_size_mb > self.config.max_cache_size_gb * 1024:
450
+ # Sort by size (largest first)
451
+ models_by_size = sorted(self.metadata.items(), key=lambda x: x[1].file_size_mb, reverse=True)
452
+
453
+ removed_count = 0
454
+ for cache_key, metadata in models_by_size:
455
+ if total_size_mb <= self.config.max_cache_size_gb * 1024:
456
+ break
457
+
458
+ try:
459
+ model_file = self.cache_dir / f"{cache_key}.pkl"
460
+ if model_file.exists():
461
+ model_file.unlink()
462
+
463
+ total_size_mb -= metadata.file_size_mb
464
+ del self.metadata[cache_key]
465
+ removed_count += 1
466
+
467
+ except Exception as e:
468
+ self.logger.warning(f"Failed to remove large model {cache_key}: {e}")
469
+
470
+ if removed_count > 0:
471
+ self.logger.info(f"Removed {removed_count} large cached models to free space")
472
+
473
+ # Save updated metadata
474
+ self._save_metadata()
475
+
476
def compute_data_hash(self, X: torch.Tensor, y: torch.Tensor) -> str:
    """
    Compute a short fingerprint of the training data for cache keys.

    The fingerprint combines the shapes of ``X`` and ``y`` with a small
    sample of their contents: when ``X`` has more than 10 rows, 10 evenly
    spaced rows are taken (and the matching entries of ``y``); at most the
    first 100 flattened feature values are hashed. Because only a sample is
    inspected, different datasets can produce the same fingerprint.

    bfloat16 tensors are up-cast to float32 before conversion to bytes,
    since the NumPy conversion does not accept that dtype; the result for
    every other dtype is unchanged.

    Args:
        X: Training features (torch tensor)
        y: Training labels (torch tensor)

    Returns:
        String of the form ``<x_shape>_<y_shape>_<x_sample>_<y_sample>``,
        each part being the first 8 hex characters of an MD5 digest
    """

    def _short_md5(payload: bytes) -> str:
        # MD5 is used purely as a fingerprint here, not for security.
        return hashlib.md5(payload).hexdigest()[:8]

    def _raw_bytes(tensor: torch.Tensor) -> bytes:
        host_tensor = tensor.detach().cpu()
        if host_tensor.dtype == torch.bfloat16:
            # Lossless up-cast to a dtype NumPy can represent.
            host_tensor = host_tensor.to(torch.float32)
        return host_tensor.numpy().tobytes()

    x_hash = _short_md5(str(tuple(X.shape)).encode())
    y_hash = _short_md5(str(tuple(y.shape)).encode())

    # Sample a bounded amount of data so hashing stays cheap for large inputs.
    if X.size(0) > 10:
        sample_indices = torch.linspace(0, X.size(0) - 1, 10, dtype=torch.long)
        x_sample = X[sample_indices].flatten()[:100]
        y_sample = y[sample_indices]
    else:
        x_sample = X.flatten()[:100]
        y_sample = y

    x_sample_hash = _short_md5(_raw_bytes(x_sample))
    y_sample_hash = _short_md5(_raw_bytes(y_sample))

    return f"{x_hash}_{y_hash}_{x_sample_hash}_{y_sample_hash}"