morphml-1.0.0-py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their public registries, and is provided for informational purposes only.

Potentially problematic release: this version of morphml might be problematic.

Files changed (158)
  1. morphml/__init__.py +14 -0
  2. morphml/api/__init__.py +26 -0
  3. morphml/api/app.py +326 -0
  4. morphml/api/auth.py +193 -0
  5. morphml/api/client.py +338 -0
  6. morphml/api/models.py +132 -0
  7. morphml/api/rate_limit.py +192 -0
  8. morphml/benchmarking/__init__.py +36 -0
  9. morphml/benchmarking/comparison.py +430 -0
  10. morphml/benchmarks/__init__.py +56 -0
  11. morphml/benchmarks/comparator.py +409 -0
  12. morphml/benchmarks/datasets.py +280 -0
  13. morphml/benchmarks/metrics.py +199 -0
  14. morphml/benchmarks/openml_suite.py +201 -0
  15. morphml/benchmarks/problems.py +289 -0
  16. morphml/benchmarks/suite.py +318 -0
  17. morphml/cli/__init__.py +5 -0
  18. morphml/cli/commands/experiment.py +329 -0
  19. morphml/cli/main.py +457 -0
  20. morphml/cli/quickstart.py +312 -0
  21. morphml/config.py +278 -0
  22. morphml/constraints/__init__.py +19 -0
  23. morphml/constraints/handler.py +205 -0
  24. morphml/constraints/predicates.py +285 -0
  25. morphml/core/__init__.py +3 -0
  26. morphml/core/crossover.py +449 -0
  27. morphml/core/dsl/README.md +359 -0
  28. morphml/core/dsl/__init__.py +72 -0
  29. morphml/core/dsl/ast_nodes.py +364 -0
  30. morphml/core/dsl/compiler.py +318 -0
  31. morphml/core/dsl/layers.py +368 -0
  32. morphml/core/dsl/lexer.py +336 -0
  33. morphml/core/dsl/parser.py +455 -0
  34. morphml/core/dsl/search_space.py +386 -0
  35. morphml/core/dsl/syntax.py +199 -0
  36. morphml/core/dsl/type_system.py +361 -0
  37. morphml/core/dsl/validator.py +386 -0
  38. morphml/core/graph/__init__.py +40 -0
  39. morphml/core/graph/edge.py +124 -0
  40. morphml/core/graph/graph.py +507 -0
  41. morphml/core/graph/mutations.py +409 -0
  42. morphml/core/graph/node.py +196 -0
  43. morphml/core/graph/serialization.py +361 -0
  44. morphml/core/graph/visualization.py +431 -0
  45. morphml/core/objectives/__init__.py +20 -0
  46. morphml/core/search/__init__.py +33 -0
  47. morphml/core/search/individual.py +252 -0
  48. morphml/core/search/parameters.py +453 -0
  49. morphml/core/search/population.py +375 -0
  50. morphml/core/search/search_engine.py +340 -0
  51. morphml/distributed/__init__.py +76 -0
  52. morphml/distributed/fault_tolerance.py +497 -0
  53. morphml/distributed/health_monitor.py +348 -0
  54. morphml/distributed/master.py +709 -0
  55. morphml/distributed/proto/README.md +224 -0
  56. morphml/distributed/proto/__init__.py +74 -0
  57. morphml/distributed/proto/worker.proto +170 -0
  58. morphml/distributed/proto/worker_pb2.py +79 -0
  59. morphml/distributed/proto/worker_pb2_grpc.py +423 -0
  60. morphml/distributed/resource_manager.py +416 -0
  61. morphml/distributed/scheduler.py +567 -0
  62. morphml/distributed/storage/__init__.py +33 -0
  63. morphml/distributed/storage/artifacts.py +381 -0
  64. morphml/distributed/storage/cache.py +366 -0
  65. morphml/distributed/storage/checkpointing.py +329 -0
  66. morphml/distributed/storage/database.py +459 -0
  67. morphml/distributed/worker.py +549 -0
  68. morphml/evaluation/__init__.py +5 -0
  69. morphml/evaluation/heuristic.py +237 -0
  70. morphml/exceptions.py +55 -0
  71. morphml/execution/__init__.py +5 -0
  72. morphml/execution/local_executor.py +350 -0
  73. morphml/integrations/__init__.py +28 -0
  74. morphml/integrations/jax_adapter.py +206 -0
  75. morphml/integrations/pytorch_adapter.py +530 -0
  76. morphml/integrations/sklearn_adapter.py +206 -0
  77. morphml/integrations/tensorflow_adapter.py +230 -0
  78. morphml/logging_config.py +93 -0
  79. morphml/meta_learning/__init__.py +66 -0
  80. morphml/meta_learning/architecture_similarity.py +277 -0
  81. morphml/meta_learning/experiment_database.py +240 -0
  82. morphml/meta_learning/knowledge_base/__init__.py +19 -0
  83. morphml/meta_learning/knowledge_base/embedder.py +179 -0
  84. morphml/meta_learning/knowledge_base/knowledge_base.py +313 -0
  85. morphml/meta_learning/knowledge_base/meta_features.py +265 -0
  86. morphml/meta_learning/knowledge_base/vector_store.py +271 -0
  87. morphml/meta_learning/predictors/__init__.py +27 -0
  88. morphml/meta_learning/predictors/ensemble.py +221 -0
  89. morphml/meta_learning/predictors/gnn_predictor.py +552 -0
  90. morphml/meta_learning/predictors/learning_curve.py +231 -0
  91. morphml/meta_learning/predictors/proxy_metrics.py +261 -0
  92. morphml/meta_learning/strategy_evolution/__init__.py +27 -0
  93. morphml/meta_learning/strategy_evolution/adaptive_optimizer.py +226 -0
  94. morphml/meta_learning/strategy_evolution/bandit.py +276 -0
  95. morphml/meta_learning/strategy_evolution/portfolio.py +230 -0
  96. morphml/meta_learning/transfer.py +581 -0
  97. morphml/meta_learning/warm_start.py +286 -0
  98. morphml/optimizers/__init__.py +74 -0
  99. morphml/optimizers/adaptive_operators.py +399 -0
  100. morphml/optimizers/bayesian/__init__.py +52 -0
  101. morphml/optimizers/bayesian/acquisition.py +387 -0
  102. morphml/optimizers/bayesian/base.py +319 -0
  103. morphml/optimizers/bayesian/gaussian_process.py +635 -0
  104. morphml/optimizers/bayesian/smac.py +534 -0
  105. morphml/optimizers/bayesian/tpe.py +411 -0
  106. morphml/optimizers/differential_evolution.py +220 -0
  107. morphml/optimizers/evolutionary/__init__.py +61 -0
  108. morphml/optimizers/evolutionary/cma_es.py +416 -0
  109. morphml/optimizers/evolutionary/differential_evolution.py +556 -0
  110. morphml/optimizers/evolutionary/encoding.py +426 -0
  111. morphml/optimizers/evolutionary/particle_swarm.py +449 -0
  112. morphml/optimizers/genetic_algorithm.py +486 -0
  113. morphml/optimizers/gradient_based/__init__.py +22 -0
  114. morphml/optimizers/gradient_based/darts.py +550 -0
  115. morphml/optimizers/gradient_based/enas.py +585 -0
  116. morphml/optimizers/gradient_based/operations.py +474 -0
  117. morphml/optimizers/gradient_based/utils.py +601 -0
  118. morphml/optimizers/hill_climbing.py +169 -0
  119. morphml/optimizers/multi_objective/__init__.py +56 -0
  120. morphml/optimizers/multi_objective/indicators.py +504 -0
  121. morphml/optimizers/multi_objective/nsga2.py +647 -0
  122. morphml/optimizers/multi_objective/visualization.py +427 -0
  123. morphml/optimizers/nsga2.py +308 -0
  124. morphml/optimizers/random_search.py +172 -0
  125. morphml/optimizers/simulated_annealing.py +181 -0
  126. morphml/plugins/__init__.py +35 -0
  127. morphml/plugins/custom_evaluator_example.py +81 -0
  128. morphml/plugins/custom_optimizer_example.py +63 -0
  129. morphml/plugins/plugin_system.py +454 -0
  130. morphml/reports/__init__.py +30 -0
  131. morphml/reports/generator.py +362 -0
  132. morphml/tracking/__init__.py +7 -0
  133. morphml/tracking/experiment.py +309 -0
  134. morphml/tracking/logger.py +301 -0
  135. morphml/tracking/reporter.py +357 -0
  136. morphml/utils/__init__.py +6 -0
  137. morphml/utils/checkpoint.py +189 -0
  138. morphml/utils/comparison.py +390 -0
  139. morphml/utils/export.py +407 -0
  140. morphml/utils/progress.py +392 -0
  141. morphml/utils/validation.py +392 -0
  142. morphml/version.py +7 -0
  143. morphml/visualization/__init__.py +50 -0
  144. morphml/visualization/analytics.py +423 -0
  145. morphml/visualization/architecture_diagrams.py +353 -0
  146. morphml/visualization/architecture_plot.py +223 -0
  147. morphml/visualization/convergence_plot.py +174 -0
  148. morphml/visualization/crossover_viz.py +386 -0
  149. morphml/visualization/graph_viz.py +338 -0
  150. morphml/visualization/pareto_plot.py +149 -0
  151. morphml/visualization/plotly_dashboards.py +422 -0
  152. morphml/visualization/population.py +309 -0
  153. morphml/visualization/progress.py +260 -0
  154. morphml-1.0.0.dist-info/METADATA +434 -0
  155. morphml-1.0.0.dist-info/RECORD +158 -0
  156. morphml-1.0.0.dist-info/WHEEL +4 -0
  157. morphml-1.0.0.dist-info/entry_points.txt +3 -0
  158. morphml-1.0.0.dist-info/licenses/LICENSE +21 -0
morphml/distributed/storage/cache.py
@@ -0,0 +1,366 @@
+"""Redis-based distributed cache for fast intermediate storage.
+
+Caches architecture evaluations, optimizer state, and temporary results.
+
+Author: Eshan Roy <eshanized@proton.me>
+Organization: TONMOY INFRASTRUCTURE & VISION
+"""
+
+import pickle
+from typing import Any, Dict, List, Optional
+
+try:
+    import redis
+
+    REDIS_AVAILABLE = True
+except ImportError:
+    REDIS_AVAILABLE = False
+
+from morphml.exceptions import DistributedError
+from morphml.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+class DistributedCache:
+    """
+    Redis-based distributed cache.
+
+    Provides fast caching for:
+    - Architecture evaluation results
+    - Optimizer state
+    - Temporary computation results
+    - Worker metadata
+
+    Args:
+        redis_url: Redis connection URL (default: redis://localhost:6379)
+        prefix: Key prefix for namespacing (default: 'morphml')
+        default_ttl: Default time-to-live in seconds (default: None = no expiry)
+
+    Example:
+        >>> cache = DistributedCache('redis://localhost:6379')
+        >>> cache.set('key', {'value': 42}, ttl=3600)
+        >>> result = cache.get('key')
+        >>> cache.cache_architecture_result('abc123', {'fitness': 0.95})
+    """
+
+    def __init__(
+        self,
+        redis_url: str = "redis://localhost:6379",
+        prefix: str = "morphml",
+        default_ttl: Optional[int] = None,
+    ):
+        """Initialize distributed cache."""
+        if not REDIS_AVAILABLE:
+            raise DistributedError("Redis not available. Install with: pip install redis")
+
+        self.redis_url = redis_url
+        self.prefix = prefix
+        self.default_ttl = default_ttl
+
+        try:
+            self.client = redis.from_url(redis_url, decode_responses=False)
+            # Test connection
+            self.client.ping()
+            logger.info(f"Connected to Redis: {redis_url}")
+        except redis.ConnectionError as e:
+            raise DistributedError(f"Failed to connect to Redis: {e}")
+
+    def _make_key(self, key: str) -> str:
+        """Add prefix to key."""
+        return f"{self.prefix}:{key}"
+
+    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
+        """
+        Set value in cache.
+
+        Args:
+            key: Cache key
+            value: Value to cache (will be pickled)
+            ttl: Time-to-live in seconds (None = use default)
+        """
+        full_key = self._make_key(key)
+        serialized = pickle.dumps(value)
+
+        ttl = ttl or self.default_ttl
+
+        if ttl:
+            self.client.setex(full_key, ttl, serialized)
+        else:
+            self.client.set(full_key, serialized)
+
+        logger.debug(f"Cached {key} (ttl={ttl})")
+
+    def get(self, key: str) -> Optional[Any]:
+        """
+        Get value from cache.
+
+        Args:
+            key: Cache key
+
+        Returns:
+            Cached value or None if not found
+        """
+        full_key = self._make_key(key)
+        value = self.client.get(full_key)
+
+        if value is None:
+            logger.debug(f"Cache miss: {key}")
+            return None
+
+        logger.debug(f"Cache hit: {key}")
+        return pickle.loads(value)
+
+    def delete(self, key: str) -> bool:
+        """
+        Delete key from cache.
+
+        Args:
+            key: Cache key
+
+        Returns:
+            True if key was deleted
+        """
+        full_key = self._make_key(key)
+        result = self.client.delete(full_key)
+        return result > 0
+
+    def exists(self, key: str) -> bool:
+        """
+        Check if key exists.
+
+        Args:
+            key: Cache key
+
+        Returns:
+            True if key exists
+        """
+        full_key = self._make_key(key)
+        return self.client.exists(full_key) > 0
+
+    def cache_architecture_result(
+        self,
+        arch_hash: str,
+        result: Dict[str, Any],
+        ttl: int = 86400,  # 24 hours
+    ) -> None:
+        """
+        Cache architecture evaluation result.
+
+        Args:
+            arch_hash: Architecture hash
+            result: Evaluation result
+            ttl: Time-to-live (seconds)
+        """
+        key = f"arch:{arch_hash}"
+        self.set(key, result, ttl=ttl)
+
+    def get_architecture_result(self, arch_hash: str) -> Optional[Dict[str, Any]]:
+        """
+        Get cached architecture result.
+
+        Args:
+            arch_hash: Architecture hash
+
+        Returns:
+            Cached result or None
+        """
+        key = f"arch:{arch_hash}"
+        return self.get(key)
+
+    def cache_optimizer_state(
+        self,
+        experiment_id: str,
+        generation: int,
+        state: Dict[str, Any],
+        ttl: int = 3600,  # 1 hour
+    ) -> None:
+        """
+        Cache optimizer state for quick recovery.
+
+        Args:
+            experiment_id: Experiment ID
+            generation: Generation number
+            state: Optimizer state
+            ttl: Time-to-live (seconds)
+        """
+        key = f"optimizer:{experiment_id}:gen{generation}"
+        self.set(key, state, ttl=ttl)
+
+    def get_optimizer_state(self, experiment_id: str, generation: int) -> Optional[Dict[str, Any]]:
+        """
+        Get cached optimizer state.
+
+        Args:
+            experiment_id: Experiment ID
+            generation: Generation number
+
+        Returns:
+            Cached state or None
+        """
+        key = f"optimizer:{experiment_id}:gen{generation}"
+        return self.get(key)
+
+    def invalidate_pattern(self, pattern: str) -> int:
+        """
+        Invalidate all keys matching pattern.
+
+        Args:
+            pattern: Key pattern (supports * wildcard)
+
+        Returns:
+            Number of keys deleted
+        """
+        full_pattern = self._make_key(pattern)
+        keys = list(self.client.scan_iter(match=full_pattern))
+
+        if keys:
+            deleted = self.client.delete(*keys)
+            logger.info(f"Invalidated {deleted} keys matching {pattern}")
+            return deleted
+
+        return 0
+
+    def invalidate_experiment(self, experiment_id: str) -> int:
+        """
+        Invalidate all cache entries for experiment.
+
+        Args:
+            experiment_id: Experiment ID
+
+        Returns:
+            Number of keys deleted
+        """
+        return self.invalidate_pattern(f"*:{experiment_id}:*")
+
+    def get_statistics(self) -> Dict[str, Any]:
+        """
+        Get cache statistics.
+
+        Returns:
+            Statistics dictionary
+        """
+        info = self.client.info("stats")
+
+        return {
+            "total_connections": info.get("total_connections_received", 0),
+            "total_commands": info.get("total_commands_processed", 0),
+            "keyspace_hits": info.get("keyspace_hits", 0),
+            "keyspace_misses": info.get("keyspace_misses", 0),
+            "hit_rate": (
+                info.get("keyspace_hits", 0)
+                / max(
+                    info.get("keyspace_hits", 0) + info.get("keyspace_misses", 0),
+                    1,
+                )
+                * 100
+            ),
+        }
+
+    def clear_all(self) -> None:
+        """
+        Clear all cache entries with prefix.
+
+        Warning: This deletes all keys with the configured prefix!
+        """
+        pattern = f"{self.prefix}:*"
+        keys = list(self.client.scan_iter(match=pattern))
+
+        if keys:
+            self.client.delete(*keys)
+            logger.warning(f"Cleared {len(keys)} cache entries")
+
+    def set_multiple(self, mapping: Dict[str, Any], ttl: Optional[int] = None) -> None:
+        """
+        Set multiple key-value pairs.
+
+        Args:
+            mapping: Dictionary of key-value pairs
+            ttl: Time-to-live for all keys
+        """
+        pipe = self.client.pipeline()
+
+        for key, value in mapping.items():
+            full_key = self._make_key(key)
+            serialized = pickle.dumps(value)
+
+            if ttl:
+                pipe.setex(full_key, ttl, serialized)
+            else:
+                pipe.set(full_key, serialized)
+
+        pipe.execute()
+        logger.debug(f"Cached {len(mapping)} keys")
+
+    def get_multiple(self, keys: List[str]) -> Dict[str, Any]:
+        """
+        Get multiple values.
+
+        Args:
+            keys: List of cache keys
+
+        Returns:
+            Dictionary of found key-value pairs
+        """
+        full_keys = [self._make_key(k) for k in keys]
+        values = self.client.mget(full_keys)
+
+        result = {}
+        for key, value in zip(keys, values):
+            if value is not None:
+                result[key] = pickle.loads(value)
+
+        return result
+
+    def increment(self, key: str, amount: int = 1) -> int:
+        """
+        Increment counter.
+
+        Args:
+            key: Counter key
+            amount: Increment amount
+
+        Returns:
+            New counter value
+        """
+        full_key = self._make_key(key)
+        return self.client.incr(full_key, amount)
+
+    def decrement(self, key: str, amount: int = 1) -> int:
+        """
+        Decrement counter.
+
+        Args:
+            key: Counter key
+            amount: Decrement amount
+
+        Returns:
+            New counter value
+        """
+        full_key = self._make_key(key)
+        return self.client.decr(full_key, amount)
+
+    def get_ttl(self, key: str) -> Optional[int]:
+        """
+        Get remaining TTL for key.
+
+        Args:
+            key: Cache key
+
+        Returns:
+            TTL in seconds, or None if no TTL set
+        """
+        full_key = self._make_key(key)
+        ttl = self.client.ttl(full_key)
+
+        if ttl == -1:  # No TTL
+            return None
+        elif ttl == -2:  # Key doesn't exist
+            return None
+        else:
+            return ttl
+
+    def close(self) -> None:
+        """Close Redis connection."""
+        self.client.close()
+        logger.info("Closed Redis connection")
morphml/distributed/storage/checkpointing.py
@@ -0,0 +1,329 @@
+"""Checkpoint management for experiment recovery.
+
+Enables saving and restoring experiment state for fault tolerance.
+
+Author: Eshan Roy <eshanized@proton.me>
+Organization: TONMOY INFRASTRUCTURE & VISION
+"""
+
+import pickle
+import time
+from typing import Any, Dict, List, Optional
+
+from morphml.core.search import Individual
+from morphml.distributed.storage.artifacts import ArtifactStore
+from morphml.distributed.storage.cache import DistributedCache
+from morphml.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+class CheckpointManager:
+    """
+    Manage experiment checkpoints for recovery.
+
+    Uses both cache (fast) and artifact storage (persistent):
+    - Cache: Recent checkpoints for quick recovery
+    - Storage: All checkpoints for long-term persistence
+
+    Args:
+        artifact_store: ArtifactStore for persistent storage
+        cache: DistributedCache for fast access (optional)
+        checkpoint_interval: Save checkpoint every N generations
+        max_checkpoints: Maximum checkpoints to keep (0 = unlimited)
+
+    Example:
+        >>> manager = CheckpointManager(artifact_store, cache)
+        >>>
+        >>> # Save checkpoint
+        >>> manager.save_checkpoint(
+        ...     experiment_id='exp1',
+        ...     generation=10,
+        ...     optimizer_state={'population_size': 50},
+        ...     population=population
+        ... )
+        >>>
+        >>> # Load latest checkpoint
+        >>> checkpoint = manager.load_checkpoint('exp1')
+        >>> if checkpoint:
+        ...     generation = checkpoint['generation']
+        ...     optimizer_state = checkpoint['optimizer_state']
+    """
+
+    def __init__(
+        self,
+        artifact_store: ArtifactStore,
+        cache: Optional[DistributedCache] = None,
+        checkpoint_interval: int = 10,
+        max_checkpoints: int = 5,
+    ):
+        """Initialize checkpoint manager."""
+        self.store = artifact_store
+        self.cache = cache
+        self.checkpoint_interval = checkpoint_interval
+        self.max_checkpoints = max_checkpoints
+
+        logger.info(
+            f"Initialized CheckpointManager "
+            f"(interval={checkpoint_interval}, max={max_checkpoints})"
+        )
+
+    def save_checkpoint(
+        self,
+        experiment_id: str,
+        generation: int,
+        optimizer_state: Dict[str, Any],
+        population: Optional[List[Individual]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Save experiment checkpoint.
+
+        Args:
+            experiment_id: Experiment identifier
+            generation: Current generation number
+            optimizer_state: Optimizer state dictionary
+            population: Current population (optional)
+            metadata: Additional metadata (optional)
+        """
+        checkpoint = {
+            "experiment_id": experiment_id,
+            "generation": generation,
+            "optimizer_state": optimizer_state,
+            "population": [ind.to_dict() for ind in population] if population else [],
+            "metadata": metadata or {},
+            "timestamp": time.time(),
+        }
+
+        # Save to cache for fast recovery
+        if self.cache:
+            try:
+                self.cache.set(f"checkpoint:{experiment_id}:latest", checkpoint, ttl=3600)
+                self.cache.set(
+                    f"checkpoint:{experiment_id}:gen{generation}",
+                    checkpoint,
+                    ttl=7200,
+                )
+                logger.debug(f"Cached checkpoint for generation {generation}")
+            except Exception as e:
+                logger.warning(f"Failed to cache checkpoint: {e}")
+
+        # Save to persistent storage
+        try:
+            # Serialize checkpoint
+            checkpoint_bytes = pickle.dumps(checkpoint)
+
+            # Upload to S3
+            s3_key = f"checkpoints/{experiment_id}/gen_{generation:06d}.pkl"
+            self.store.upload_bytes(
+                checkpoint_bytes,
+                s3_key,
+                metadata={
+                    "experiment_id": experiment_id,
+                    "generation": str(generation),
+                    "timestamp": str(int(time.time())),
+                },
+            )
+
+            logger.info(f"Saved checkpoint for {experiment_id} generation {generation}")
+
+            # Clean up old checkpoints
+            if self.max_checkpoints > 0:
+                self._cleanup_old_checkpoints(experiment_id)
+
+        except Exception as e:
+            logger.error(f"Failed to save checkpoint: {e}")
+            raise
+
+    def load_checkpoint(
+        self, experiment_id: str, generation: Optional[int] = None
+    ) -> Optional[Dict[str, Any]]:
+        """
+        Load checkpoint.
+
+        Args:
+            experiment_id: Experiment identifier
+            generation: Specific generation (None = latest)
+
+        Returns:
+            Checkpoint dictionary or None if not found
+        """
+        # Try cache first (latest only)
+        if generation is None and self.cache:
+            try:
+                checkpoint = self.cache.get(f"checkpoint:{experiment_id}:latest")
+
+                if checkpoint:
+                    logger.info("Loaded checkpoint from cache (latest)")
+                    return checkpoint
+            except Exception as e:
+                logger.warning(f"Failed to load from cache: {e}")
+
+        # Load from persistent storage
+        try:
+            if generation is not None:
+                # Load specific generation
+                s3_key = f"checkpoints/{experiment_id}/gen_{generation:06d}.pkl"
+
+                if not self.store.exists(s3_key):
+                    logger.warning(f"Checkpoint not found for generation {generation}")
+                    return None
+
+                checkpoint_bytes = self.store.download_bytes(s3_key)
+                checkpoint = pickle.loads(checkpoint_bytes)
+
+                logger.info(f"Loaded checkpoint for {experiment_id} generation {generation}")
+
+                return checkpoint
+
+            else:
+                # Load latest
+                checkpoints = self.list_checkpoints(experiment_id)
+
+                if not checkpoints:
+                    logger.info(f"No checkpoints found for {experiment_id}")
+                    return None
+
+                # Get latest
+                latest = checkpoints[-1]
+                s3_key = latest["key"]
+
+                checkpoint_bytes = self.store.download_bytes(s3_key)
+                checkpoint = pickle.loads(checkpoint_bytes)
+
+                logger.info(
+                    f"Loaded latest checkpoint for {experiment_id} "
+                    f"(generation {checkpoint['generation']})"
+                )
+
+                return checkpoint
+
+        except Exception as e:
+            logger.error(f"Failed to load checkpoint: {e}")
+            return None
+
+    def list_checkpoints(self, experiment_id: str) -> List[Dict[str, Any]]:
+        """
+        List all checkpoints for experiment.
+
+        Args:
+            experiment_id: Experiment identifier
+
+        Returns:
+            List of checkpoint metadata dictionaries
+        """
+        prefix = f"checkpoints/{experiment_id}/"
+
+        try:
+            checkpoints = self.store.list_objects(prefix)
+
+            # Sort by generation (extracted from key)
+            checkpoints.sort(key=lambda x: x["key"])
+
+            return checkpoints
+
+        except Exception as e:
+            logger.error(f"Failed to list checkpoints: {e}")
+            return []
+
+    def delete_checkpoint(self, experiment_id: str, generation: int) -> None:
+        """
+        Delete specific checkpoint.
+
+        Args:
+            experiment_id: Experiment identifier
+            generation: Generation number
+        """
+        s3_key = f"checkpoints/{experiment_id}/gen_{generation:06d}.pkl"
+
+        try:
+            self.store.delete(s3_key)
+            logger.info(f"Deleted checkpoint for generation {generation}")
+
+        except Exception as e:
+            logger.error(f"Failed to delete checkpoint: {e}")
+
+    def delete_all_checkpoints(self, experiment_id: str) -> int:
+        """
+        Delete all checkpoints for experiment.
+
+        Args:
+            experiment_id: Experiment identifier
+
+        Returns:
+            Number of checkpoints deleted
+        """
+        prefix = f"checkpoints/{experiment_id}/"
+
+        try:
+            deleted = self.store.delete_prefix(prefix)
+            logger.info(f"Deleted {deleted} checkpoints for {experiment_id}")
+
+            # Also clear cache
+            if self.cache:
+                self.cache.invalidate_pattern(f"checkpoint:{experiment_id}:*")
+
+            return deleted
+
+        except Exception as e:
+            logger.error(f"Failed to delete checkpoints: {e}")
+            return 0
+
+    def _cleanup_old_checkpoints(self, experiment_id: str) -> None:
+        """Remove old checkpoints keeping only max_checkpoints."""
+        checkpoints = self.list_checkpoints(experiment_id)
+
+        if len(checkpoints) <= self.max_checkpoints:
+            return
+
+        # Delete oldest checkpoints
+        num_to_delete = len(checkpoints) - self.max_checkpoints
+
+        for checkpoint in checkpoints[:num_to_delete]:
+            try:
+                self.store.delete(checkpoint["key"])
+                logger.debug(f"Cleaned up old checkpoint: {checkpoint['key']}")
+            except Exception as e:
+                logger.warning(f"Failed to cleanup checkpoint: {e}")
+
+    def should_checkpoint(self, generation: int) -> bool:
+        """
+        Check if should save checkpoint at this generation.
+
+        Args:
+            generation: Current generation
+
+        Returns:
+            True if should checkpoint
+        """
+        return generation % self.checkpoint_interval == 0
+
+    def get_statistics(self, experiment_id: str) -> Dict[str, Any]:
+        """
+        Get checkpoint statistics.
+
+        Args:
+            experiment_id: Experiment identifier
+
+        Returns:
+            Statistics dictionary
+        """
+        checkpoints = self.list_checkpoints(experiment_id)
+
+        if not checkpoints:
+            return {
+                "experiment_id": experiment_id,
+                "num_checkpoints": 0,
+                "total_size_bytes": 0,
+            }
+
+        total_size = sum(cp["size"] for cp in checkpoints)
+
+        return {
+            "experiment_id": experiment_id,
+            "num_checkpoints": len(checkpoints),
+            "total_size_bytes": total_size,
+            "total_size_mb": total_size / (1024 * 1024),
+            "oldest": checkpoints[0]["last_modified"],
+            "newest": checkpoints[-1]["last_modified"],
+        }
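To make the recovery flow concrete, here is a minimal sketch of a resumable search loop built on CheckpointManager. Constructing the ArtifactStore and DistributedCache is left to the caller (their constructors live in artifacts.py and cache.py); `run_resumable` and the commented-out per-generation work are hypothetical stand-ins, not package APIs.

```python
from typing import Any, Dict

from morphml.distributed.storage.checkpointing import CheckpointManager


def run_resumable(manager: CheckpointManager, experiment_id: str, total_gens: int) -> None:
    """Resume an experiment from its latest checkpoint, then checkpoint as it runs."""
    start_gen = 0
    state: Dict[str, Any] = {"population_size": 50}

    checkpoint = manager.load_checkpoint(experiment_id)  # None if nothing saved yet
    if checkpoint:
        start_gen = checkpoint["generation"] + 1
        state = checkpoint["optimizer_state"]

    for gen in range(start_gen, total_gens):
        # ... one generation of search would update `state` here (hypothetical) ...
        if manager.should_checkpoint(gen):  # every `checkpoint_interval` generations
            manager.save_checkpoint(
                experiment_id=experiment_id,
                generation=gen,
                optimizer_state=state,
            )
```

Each save_checkpoint call writes both tiers shown above: a TTL-bounded copy in the cache for fast recovery and a pickled object in the artifact store for persistence, with the oldest stored checkpoints pruned down to max_checkpoints.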