morphml-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of morphml might be problematic.
- morphml/__init__.py +14 -0
- morphml/api/__init__.py +26 -0
- morphml/api/app.py +326 -0
- morphml/api/auth.py +193 -0
- morphml/api/client.py +338 -0
- morphml/api/models.py +132 -0
- morphml/api/rate_limit.py +192 -0
- morphml/benchmarking/__init__.py +36 -0
- morphml/benchmarking/comparison.py +430 -0
- morphml/benchmarks/__init__.py +56 -0
- morphml/benchmarks/comparator.py +409 -0
- morphml/benchmarks/datasets.py +280 -0
- morphml/benchmarks/metrics.py +199 -0
- morphml/benchmarks/openml_suite.py +201 -0
- morphml/benchmarks/problems.py +289 -0
- morphml/benchmarks/suite.py +318 -0
- morphml/cli/__init__.py +5 -0
- morphml/cli/commands/experiment.py +329 -0
- morphml/cli/main.py +457 -0
- morphml/cli/quickstart.py +312 -0
- morphml/config.py +278 -0
- morphml/constraints/__init__.py +19 -0
- morphml/constraints/handler.py +205 -0
- morphml/constraints/predicates.py +285 -0
- morphml/core/__init__.py +3 -0
- morphml/core/crossover.py +449 -0
- morphml/core/dsl/README.md +359 -0
- morphml/core/dsl/__init__.py +72 -0
- morphml/core/dsl/ast_nodes.py +364 -0
- morphml/core/dsl/compiler.py +318 -0
- morphml/core/dsl/layers.py +368 -0
- morphml/core/dsl/lexer.py +336 -0
- morphml/core/dsl/parser.py +455 -0
- morphml/core/dsl/search_space.py +386 -0
- morphml/core/dsl/syntax.py +199 -0
- morphml/core/dsl/type_system.py +361 -0
- morphml/core/dsl/validator.py +386 -0
- morphml/core/graph/__init__.py +40 -0
- morphml/core/graph/edge.py +124 -0
- morphml/core/graph/graph.py +507 -0
- morphml/core/graph/mutations.py +409 -0
- morphml/core/graph/node.py +196 -0
- morphml/core/graph/serialization.py +361 -0
- morphml/core/graph/visualization.py +431 -0
- morphml/core/objectives/__init__.py +20 -0
- morphml/core/search/__init__.py +33 -0
- morphml/core/search/individual.py +252 -0
- morphml/core/search/parameters.py +453 -0
- morphml/core/search/population.py +375 -0
- morphml/core/search/search_engine.py +340 -0
- morphml/distributed/__init__.py +76 -0
- morphml/distributed/fault_tolerance.py +497 -0
- morphml/distributed/health_monitor.py +348 -0
- morphml/distributed/master.py +709 -0
- morphml/distributed/proto/README.md +224 -0
- morphml/distributed/proto/__init__.py +74 -0
- morphml/distributed/proto/worker.proto +170 -0
- morphml/distributed/proto/worker_pb2.py +79 -0
- morphml/distributed/proto/worker_pb2_grpc.py +423 -0
- morphml/distributed/resource_manager.py +416 -0
- morphml/distributed/scheduler.py +567 -0
- morphml/distributed/storage/__init__.py +33 -0
- morphml/distributed/storage/artifacts.py +381 -0
- morphml/distributed/storage/cache.py +366 -0
- morphml/distributed/storage/checkpointing.py +329 -0
- morphml/distributed/storage/database.py +459 -0
- morphml/distributed/worker.py +549 -0
- morphml/evaluation/__init__.py +5 -0
- morphml/evaluation/heuristic.py +237 -0
- morphml/exceptions.py +55 -0
- morphml/execution/__init__.py +5 -0
- morphml/execution/local_executor.py +350 -0
- morphml/integrations/__init__.py +28 -0
- morphml/integrations/jax_adapter.py +206 -0
- morphml/integrations/pytorch_adapter.py +530 -0
- morphml/integrations/sklearn_adapter.py +206 -0
- morphml/integrations/tensorflow_adapter.py +230 -0
- morphml/logging_config.py +93 -0
- morphml/meta_learning/__init__.py +66 -0
- morphml/meta_learning/architecture_similarity.py +277 -0
- morphml/meta_learning/experiment_database.py +240 -0
- morphml/meta_learning/knowledge_base/__init__.py +19 -0
- morphml/meta_learning/knowledge_base/embedder.py +179 -0
- morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
- morphml/meta_learning/knowledge_base/meta_features.py +265 -0
- morphml/meta_learning/knowledge_base/vector_store.py +271 -0
- morphml/meta_learning/predictors/__init__.py +27 -0
- morphml/meta_learning/predictors/ensemble.py +221 -0
- morphml/meta_learning/predictors/gnn_predictor.py +552 -0
- morphml/meta_learning/predictors/learning_curve.py +231 -0
- morphml/meta_learning/predictors/proxy_metrics.py +261 -0
- morphml/meta_learning/strategy_evolution/__init__.py +27 -0
- morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
- morphml/meta_learning/strategy_evolution/bandit.py +276 -0
- morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
- morphml/meta_learning/transfer.py +581 -0
- morphml/meta_learning/warm_start.py +286 -0
- morphml/optimizers/__init__.py +74 -0
- morphml/optimizers/adaptive_operators.py +399 -0
- morphml/optimizers/bayesian/__init__.py +52 -0
- morphml/optimizers/bayesian/acquisition.py +387 -0
- morphml/optimizers/bayesian/base.py +319 -0
- morphml/optimizers/bayesian/gaussian_process.py +635 -0
- morphml/optimizers/bayesian/smac.py +534 -0
- morphml/optimizers/bayesian/tpe.py +411 -0
- morphml/optimizers/differential_evolution.py +220 -0
- morphml/optimizers/evolutionary/__init__.py +61 -0
- morphml/optimizers/evolutionary/cma_es.py +416 -0
- morphml/optimizers/evolutionary/differential_evolution.py +556 -0
- morphml/optimizers/evolutionary/encoding.py +426 -0
- morphml/optimizers/evolutionary/particle_swarm.py +449 -0
- morphml/optimizers/genetic_algorithm.py +486 -0
- morphml/optimizers/gradient_based/__init__.py +22 -0
- morphml/optimizers/gradient_based/darts.py +550 -0
- morphml/optimizers/gradient_based/enas.py +585 -0
- morphml/optimizers/gradient_based/operations.py +474 -0
- morphml/optimizers/gradient_based/utils.py +601 -0
- morphml/optimizers/hill_climbing.py +169 -0
- morphml/optimizers/multi_objective/__init__.py +56 -0
- morphml/optimizers/multi_objective/indicators.py +504 -0
- morphml/optimizers/multi_objective/nsga2.py +647 -0
- morphml/optimizers/multi_objective/visualization.py +427 -0
- morphml/optimizers/nsga2.py +308 -0
- morphml/optimizers/random_search.py +172 -0
- morphml/optimizers/simulated_annealing.py +181 -0
- morphml/plugins/__init__.py +35 -0
- morphml/plugins/custom_evaluator_example.py +81 -0
- morphml/plugins/custom_optimizer_example.py +63 -0
- morphml/plugins/plugin_system.py +454 -0
- morphml/reports/__init__.py +30 -0
- morphml/reports/generator.py +362 -0
- morphml/tracking/__init__.py +7 -0
- morphml/tracking/experiment.py +309 -0
- morphml/tracking/logger.py +301 -0
- morphml/tracking/reporter.py +357 -0
- morphml/utils/__init__.py +6 -0
- morphml/utils/checkpoint.py +189 -0
- morphml/utils/comparison.py +390 -0
- morphml/utils/export.py +407 -0
- morphml/utils/progress.py +392 -0
- morphml/utils/validation.py +392 -0
- morphml/version.py +7 -0
- morphml/visualization/__init__.py +50 -0
- morphml/visualization/analytics.py +423 -0
- morphml/visualization/architecture_diagrams.py +353 -0
- morphml/visualization/architecture_plot.py +223 -0
- morphml/visualization/convergence_plot.py +174 -0
- morphml/visualization/crossover_viz.py +386 -0
- morphml/visualization/graph_viz.py +338 -0
- morphml/visualization/pareto_plot.py +149 -0
- morphml/visualization/plotly_dashboards.py +422 -0
- morphml/visualization/population.py +309 -0
- morphml/visualization/progress.py +260 -0
- morphml-1.0.0.dist-info/METADATA +434 -0
- morphml-1.0.0.dist-info/RECORD +158 -0
- morphml-1.0.0.dist-info/WHEEL +4 -0
- morphml-1.0.0.dist-info/entry_points.txt +3 -0
- morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
morphml/distributed/storage/cache.py
@@ -0,0 +1,366 @@
"""Redis-based distributed cache for fast intermediate storage.

Caches architecture evaluations, optimizer state, and temporary results.

Author: Eshan Roy <eshanized@proton.me>
Organization: TONMOY INFRASTRUCTURE & VISION
"""

import pickle
from typing import Any, Dict, List, Optional

try:
    import redis

    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False

from morphml.exceptions import DistributedError
from morphml.logging_config import get_logger

logger = get_logger(__name__)


class DistributedCache:
    """
    Redis-based distributed cache.

    Provides fast caching for:
    - Architecture evaluation results
    - Optimizer state
    - Temporary computation results
    - Worker metadata

    Args:
        redis_url: Redis connection URL (default: redis://localhost:6379)
        prefix: Key prefix for namespacing (default: 'morphml')
        default_ttl: Default time-to-live in seconds (default: None = no expiry)

    Example:
        >>> cache = DistributedCache('redis://localhost:6379')
        >>> cache.set('key', {'value': 42}, ttl=3600)
        >>> result = cache.get('key')
        >>> cache.cache_architecture_result('abc123', {'fitness': 0.95})
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        prefix: str = "morphml",
        default_ttl: Optional[int] = None,
    ):
        """Initialize distributed cache."""
        if not REDIS_AVAILABLE:
            raise DistributedError("Redis not available. Install with: pip install redis")

        self.redis_url = redis_url
        self.prefix = prefix
        self.default_ttl = default_ttl

        try:
            self.client = redis.from_url(redis_url, decode_responses=False)
            # Test connection
            self.client.ping()
            logger.info(f"Connected to Redis: {redis_url}")
        except redis.ConnectionError as e:
            raise DistributedError(f"Failed to connect to Redis: {e}")

    def _make_key(self, key: str) -> str:
        """Add prefix to key."""
        return f"{self.prefix}:{key}"

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """
        Set value in cache.

        Args:
            key: Cache key
            value: Value to cache (will be pickled)
            ttl: Time-to-live in seconds (None = use default)
        """
        full_key = self._make_key(key)
        serialized = pickle.dumps(value)

        ttl = ttl or self.default_ttl

        if ttl:
            self.client.setex(full_key, ttl, serialized)
        else:
            self.client.set(full_key, serialized)

        logger.debug(f"Cached {key} (ttl={ttl})")

    def get(self, key: str) -> Optional[Any]:
        """
        Get value from cache.

        Args:
            key: Cache key

        Returns:
            Cached value or None if not found
        """
        full_key = self._make_key(key)
        value = self.client.get(full_key)

        if value is None:
            logger.debug(f"Cache miss: {key}")
            return None

        logger.debug(f"Cache hit: {key}")
        return pickle.loads(value)

    def delete(self, key: str) -> bool:
        """
        Delete key from cache.

        Args:
            key: Cache key

        Returns:
            True if key was deleted
        """
        full_key = self._make_key(key)
        result = self.client.delete(full_key)
        return result > 0

    def exists(self, key: str) -> bool:
        """
        Check if key exists.

        Args:
            key: Cache key

        Returns:
            True if key exists
        """
        full_key = self._make_key(key)
        return self.client.exists(full_key) > 0

    def cache_architecture_result(
        self,
        arch_hash: str,
        result: Dict[str, Any],
        ttl: int = 86400,  # 24 hours
    ) -> None:
        """
        Cache architecture evaluation result.

        Args:
            arch_hash: Architecture hash
            result: Evaluation result
            ttl: Time-to-live (seconds)
        """
        key = f"arch:{arch_hash}"
        self.set(key, result, ttl=ttl)

    def get_architecture_result(self, arch_hash: str) -> Optional[Dict[str, Any]]:
        """
        Get cached architecture result.

        Args:
            arch_hash: Architecture hash

        Returns:
            Cached result or None
        """
        key = f"arch:{arch_hash}"
        return self.get(key)

    def cache_optimizer_state(
        self,
        experiment_id: str,
        generation: int,
        state: Dict[str, Any],
        ttl: int = 3600,  # 1 hour
    ) -> None:
        """
        Cache optimizer state for quick recovery.

        Args:
            experiment_id: Experiment ID
            generation: Generation number
            state: Optimizer state
            ttl: Time-to-live (seconds)
        """
        key = f"optimizer:{experiment_id}:gen{generation}"
        self.set(key, state, ttl=ttl)

    def get_optimizer_state(self, experiment_id: str, generation: int) -> Optional[Dict[str, Any]]:
        """
        Get cached optimizer state.

        Args:
            experiment_id: Experiment ID
            generation: Generation number

        Returns:
            Cached state or None
        """
        key = f"optimizer:{experiment_id}:gen{generation}"
        return self.get(key)

    def invalidate_pattern(self, pattern: str) -> int:
        """
        Invalidate all keys matching pattern.

        Args:
            pattern: Key pattern (supports * wildcard)

        Returns:
            Number of keys deleted
        """
        full_pattern = self._make_key(pattern)
        keys = list(self.client.scan_iter(match=full_pattern))

        if keys:
            deleted = self.client.delete(*keys)
            logger.info(f"Invalidated {deleted} keys matching {pattern}")
            return deleted

        return 0

    def invalidate_experiment(self, experiment_id: str) -> int:
        """
        Invalidate all cache entries for experiment.

        Args:
            experiment_id: Experiment ID

        Returns:
            Number of keys deleted
        """
        return self.invalidate_pattern(f"*:{experiment_id}:*")

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Statistics dictionary
        """
        info = self.client.info("stats")

        return {
            "total_connections": info.get("total_connections_received", 0),
            "total_commands": info.get("total_commands_processed", 0),
            "keyspace_hits": info.get("keyspace_hits", 0),
            "keyspace_misses": info.get("keyspace_misses", 0),
            "hit_rate": (
                info.get("keyspace_hits", 0)
                / max(
                    info.get("keyspace_hits", 0) + info.get("keyspace_misses", 0),
                    1,
                )
                * 100
            ),
        }

    def clear_all(self) -> None:
        """
        Clear all cache entries with prefix.

        Warning: This deletes all keys with the configured prefix!
        """
        pattern = f"{self.prefix}:*"
        keys = list(self.client.scan_iter(match=pattern))

        if keys:
            self.client.delete(*keys)
            logger.warning(f"Cleared {len(keys)} cache entries")

    def set_multiple(self, mapping: Dict[str, Any], ttl: Optional[int] = None) -> None:
        """
        Set multiple key-value pairs.

        Args:
            mapping: Dictionary of key-value pairs
            ttl: Time-to-live for all keys
        """
        pipe = self.client.pipeline()

        for key, value in mapping.items():
            full_key = self._make_key(key)
            serialized = pickle.dumps(value)

            if ttl:
                pipe.setex(full_key, ttl, serialized)
            else:
                pipe.set(full_key, serialized)

        pipe.execute()
        logger.debug(f"Cached {len(mapping)} keys")

    def get_multiple(self, keys: List[str]) -> Dict[str, Any]:
        """
        Get multiple values.

        Args:
            keys: List of cache keys

        Returns:
            Dictionary of found key-value pairs
        """
        full_keys = [self._make_key(k) for k in keys]
        values = self.client.mget(full_keys)

        result = {}
        for key, value in zip(keys, values):
            if value is not None:
                result[key] = pickle.loads(value)

        return result

    def increment(self, key: str, amount: int = 1) -> int:
        """
        Increment counter.

        Args:
            key: Counter key
            amount: Increment amount

        Returns:
            New counter value
        """
        full_key = self._make_key(key)
        return self.client.incr(full_key, amount)

    def decrement(self, key: str, amount: int = 1) -> int:
        """
        Decrement counter.

        Args:
            key: Counter key
            amount: Decrement amount

        Returns:
            New counter value
        """
        full_key = self._make_key(key)
        return self.client.decr(full_key, amount)

    def get_ttl(self, key: str) -> Optional[int]:
        """
        Get remaining TTL for key.

        Args:
            key: Cache key

        Returns:
            TTL in seconds, or None if no TTL set
        """
        full_key = self._make_key(key)
        ttl = self.client.ttl(full_key)

        if ttl == -1:  # No TTL
            return None
        elif ttl == -2:  # Key doesn't exist
            return None
        else:
            return ttl

    def close(self) -> None:
        """Close Redis connection."""
        self.client.close()
        logger.info("Closed Redis connection")
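The main use of this cache in a NAS run is memoization: hash an architecture, check Redis before training, and store the evaluation result afterward. Below is a minimal usage sketch of that pattern, not code from the package: it only calls the DistributedCache API shown in the diff above, and it assumes a Redis server on localhost:6379; hash_architecture() and evaluate_architecture() are hypothetical stand-ins for the caller's own hashing and training code.

# Usage sketch (illustrative, not part of the wheel).
from morphml.distributed.storage.cache import DistributedCache

cache = DistributedCache("redis://localhost:6379", prefix="morphml")

def evaluate_with_cache(architecture) -> dict:
    arch_hash = hash_architecture(architecture)  # hypothetical helper
    cached = cache.get_architecture_result(arch_hash)
    if cached is not None:
        return cached  # skip re-training an architecture seen before
    result = evaluate_architecture(architecture)  # hypothetical, expensive
    cache.cache_architecture_result(arch_hash, result, ttl=86400)  # keep 24h
    return result

print(cache.get_statistics())  # includes hit_rate as a percentage
cache.close()

Note that values are pickled, so anything picklable can be stored, but the cache should only be pointed at a trusted Redis instance, since pickle.loads() on untrusted data is unsafe.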
morphml/distributed/storage/checkpointing.py
@@ -0,0 +1,329 @@
"""Checkpoint management for experiment recovery.

Enables saving and restoring experiment state for fault tolerance.

Author: Eshan Roy <eshanized@proton.me>
Organization: TONMOY INFRASTRUCTURE & VISION
"""

import pickle
import time
from typing import Any, Dict, List, Optional

from morphml.core.search import Individual
from morphml.distributed.storage.artifacts import ArtifactStore
from morphml.distributed.storage.cache import DistributedCache
from morphml.logging_config import get_logger

logger = get_logger(__name__)


class CheckpointManager:
    """
    Manage experiment checkpoints for recovery.

    Uses both cache (fast) and artifact storage (persistent):
    - Cache: Recent checkpoints for quick recovery
    - Storage: All checkpoints for long-term persistence

    Args:
        artifact_store: ArtifactStore for persistent storage
        cache: DistributedCache for fast access (optional)
        checkpoint_interval: Save checkpoint every N generations
        max_checkpoints: Maximum checkpoints to keep (0 = unlimited)

    Example:
        >>> manager = CheckpointManager(artifact_store, cache)
        >>>
        >>> # Save checkpoint
        >>> manager.save_checkpoint(
        ...     experiment_id='exp1',
        ...     generation=10,
        ...     optimizer_state={'population_size': 50},
        ...     population=population
        ... )
        >>>
        >>> # Load latest checkpoint
        >>> checkpoint = manager.load_checkpoint('exp1')
        >>> if checkpoint:
        ...     generation = checkpoint['generation']
        ...     optimizer_state = checkpoint['optimizer_state']
    """

    def __init__(
        self,
        artifact_store: ArtifactStore,
        cache: Optional[DistributedCache] = None,
        checkpoint_interval: int = 10,
        max_checkpoints: int = 5,
    ):
        """Initialize checkpoint manager."""
        self.store = artifact_store
        self.cache = cache
        self.checkpoint_interval = checkpoint_interval
        self.max_checkpoints = max_checkpoints

        logger.info(
            f"Initialized CheckpointManager "
            f"(interval={checkpoint_interval}, max={max_checkpoints})"
        )

    def save_checkpoint(
        self,
        experiment_id: str,
        generation: int,
        optimizer_state: Dict[str, Any],
        population: Optional[List[Individual]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Save experiment checkpoint.

        Args:
            experiment_id: Experiment identifier
            generation: Current generation number
            optimizer_state: Optimizer state dictionary
            population: Current population (optional)
            metadata: Additional metadata (optional)
        """
        checkpoint = {
            "experiment_id": experiment_id,
            "generation": generation,
            "optimizer_state": optimizer_state,
            "population": [ind.to_dict() for ind in population] if population else [],
            "metadata": metadata or {},
            "timestamp": time.time(),
        }

        # Save to cache for fast recovery
        if self.cache:
            try:
                self.cache.set(f"checkpoint:{experiment_id}:latest", checkpoint, ttl=3600)
                self.cache.set(
                    f"checkpoint:{experiment_id}:gen{generation}",
                    checkpoint,
                    ttl=7200,
                )
                logger.debug(f"Cached checkpoint for generation {generation}")
            except Exception as e:
                logger.warning(f"Failed to cache checkpoint: {e}")

        # Save to persistent storage
        try:
            # Serialize checkpoint
            checkpoint_bytes = pickle.dumps(checkpoint)

            # Upload to S3
            s3_key = f"checkpoints/{experiment_id}/gen_{generation:06d}.pkl"
            self.store.upload_bytes(
                checkpoint_bytes,
                s3_key,
                metadata={
                    "experiment_id": experiment_id,
                    "generation": str(generation),
                    "timestamp": str(int(time.time())),
                },
            )

            logger.info(f"Saved checkpoint for {experiment_id} generation {generation}")

            # Clean up old checkpoints
            if self.max_checkpoints > 0:
                self._cleanup_old_checkpoints(experiment_id)

        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")
            raise

    def load_checkpoint(
        self, experiment_id: str, generation: Optional[int] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Load checkpoint.

        Args:
            experiment_id: Experiment identifier
            generation: Specific generation (None = latest)

        Returns:
            Checkpoint dictionary or None if not found
        """
        # Try cache first (latest only)
        if generation is None and self.cache:
            try:
                checkpoint = self.cache.get(f"checkpoint:{experiment_id}:latest")

                if checkpoint:
                    logger.info("Loaded checkpoint from cache (latest)")
                    return checkpoint
            except Exception as e:
                logger.warning(f"Failed to load from cache: {e}")

        # Load from persistent storage
        try:
            if generation is not None:
                # Load specific generation
                s3_key = f"checkpoints/{experiment_id}/gen_{generation:06d}.pkl"

                if not self.store.exists(s3_key):
                    logger.warning(f"Checkpoint not found for generation {generation}")
                    return None

                checkpoint_bytes = self.store.download_bytes(s3_key)
                checkpoint = pickle.loads(checkpoint_bytes)

                logger.info(f"Loaded checkpoint for {experiment_id} generation {generation}")

                return checkpoint

            else:
                # Load latest
                checkpoints = self.list_checkpoints(experiment_id)

                if not checkpoints:
                    logger.info(f"No checkpoints found for {experiment_id}")
                    return None

                # Get latest
                latest = checkpoints[-1]
                s3_key = latest["key"]

                checkpoint_bytes = self.store.download_bytes(s3_key)
                checkpoint = pickle.loads(checkpoint_bytes)

                logger.info(
                    f"Loaded latest checkpoint for {experiment_id} "
                    f"(generation {checkpoint['generation']})"
                )

                return checkpoint

        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            return None

    def list_checkpoints(self, experiment_id: str) -> List[Dict[str, Any]]:
        """
        List all checkpoints for experiment.

        Args:
            experiment_id: Experiment identifier

        Returns:
            List of checkpoint metadata dictionaries
        """
        prefix = f"checkpoints/{experiment_id}/"

        try:
            checkpoints = self.store.list_objects(prefix)

            # Sort by generation (extracted from key)
            checkpoints.sort(key=lambda x: x["key"])

            return checkpoints

        except Exception as e:
            logger.error(f"Failed to list checkpoints: {e}")
            return []

    def delete_checkpoint(self, experiment_id: str, generation: int) -> None:
        """
        Delete specific checkpoint.

        Args:
            experiment_id: Experiment identifier
            generation: Generation number
        """
        s3_key = f"checkpoints/{experiment_id}/gen_{generation:06d}.pkl"

        try:
            self.store.delete(s3_key)
            logger.info(f"Deleted checkpoint for generation {generation}")

        except Exception as e:
            logger.error(f"Failed to delete checkpoint: {e}")

    def delete_all_checkpoints(self, experiment_id: str) -> int:
        """
        Delete all checkpoints for experiment.

        Args:
            experiment_id: Experiment identifier

        Returns:
            Number of checkpoints deleted
        """
        prefix = f"checkpoints/{experiment_id}/"

        try:
            deleted = self.store.delete_prefix(prefix)
            logger.info(f"Deleted {deleted} checkpoints for {experiment_id}")

            # Also clear cache
            if self.cache:
                self.cache.invalidate_pattern(f"checkpoint:{experiment_id}:*")

            return deleted

        except Exception as e:
            logger.error(f"Failed to delete checkpoints: {e}")
            return 0

    def _cleanup_old_checkpoints(self, experiment_id: str) -> None:
        """Remove old checkpoints keeping only max_checkpoints."""
        checkpoints = self.list_checkpoints(experiment_id)

        if len(checkpoints) <= self.max_checkpoints:
            return

        # Delete oldest checkpoints
        num_to_delete = len(checkpoints) - self.max_checkpoints

        for checkpoint in checkpoints[:num_to_delete]:
            try:
                self.store.delete(checkpoint["key"])
                logger.debug(f"Cleaned up old checkpoint: {checkpoint['key']}")
            except Exception as e:
                logger.warning(f"Failed to cleanup checkpoint: {e}")

    def should_checkpoint(self, generation: int) -> bool:
        """
        Check if should save checkpoint at this generation.

        Args:
            generation: Current generation

        Returns:
            True if should checkpoint
        """
        return generation % self.checkpoint_interval == 0

    def get_statistics(self, experiment_id: str) -> Dict[str, Any]:
        """
        Get checkpoint statistics.

        Args:
            experiment_id: Experiment identifier

        Returns:
            Statistics dictionary
        """
        checkpoints = self.list_checkpoints(experiment_id)

        if not checkpoints:
            return {
                "experiment_id": experiment_id,
                "num_checkpoints": 0,
                "total_size_bytes": 0,
            }

        total_size = sum(cp["size"] for cp in checkpoints)

        return {
            "experiment_id": experiment_id,
            "num_checkpoints": len(checkpoints),
            "total_size_bytes": total_size,
            "total_size_mb": total_size / (1024 * 1024),
            "oldest": checkpoints[0]["last_modified"],
            "newest": checkpoints[-1]["last_modified"],
        }
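The intended recovery flow is: try load_checkpoint() at startup, resume from the stored generation if one exists, and call should_checkpoint() inside the search loop. Below is a minimal resume-loop sketch under stated assumptions, not code from the package: the ArtifactStore constructor arguments are not shown in this diff, so `store` is assumed to be configured elsewhere, and run_generation() is a hypothetical stand-in for one step of the caller's optimizer.

# Resume-loop sketch (illustrative, not part of the wheel).
from morphml.distributed.storage.checkpointing import CheckpointManager

manager = CheckpointManager(store, cache=cache, checkpoint_interval=10)

checkpoint = manager.load_checkpoint("exp1")  # None on a fresh start
start_gen = checkpoint["generation"] + 1 if checkpoint else 0
state = checkpoint["optimizer_state"] if checkpoint else {}

for generation in range(start_gen, 100):
    state, population = run_generation(state)  # hypothetical training step
    if manager.should_checkpoint(generation):  # every checkpoint_interval gens
        manager.save_checkpoint(
            experiment_id="exp1",
            generation=generation,
            optimizer_state=state,
            population=population,
        )

One design point worth noting: save_checkpoint() treats cache failures as warnings but re-raises storage failures, so a lost Redis connection degrades recovery speed without silently losing the persistent checkpoint trail.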