evograd-diff 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. evograd/__init__.py +67 -0
  2. evograd/algorithms/__init__.py +138 -0
  3. evograd/algorithms/cmaes.py +1365 -0
  4. evograd/algorithms/de.py +895 -0
  5. evograd/algorithms/ga.py +532 -0
  6. evograd/algorithms/pso.py +648 -0
  7. evograd/algorithms/shade.py +1165 -0
  8. evograd/benchmarks/functions/__init__.py +229 -0
  9. evograd/benchmarks/functions/base.py +217 -0
  10. evograd/benchmarks/functions/cec2017/__init__.py +250 -0
  11. evograd/benchmarks/functions/cec2017/basic.py +413 -0
  12. evograd/benchmarks/functions/cec2017/composition.py +580 -0
  13. evograd/benchmarks/functions/cec2017/data.pkl +0 -0
  14. evograd/benchmarks/functions/cec2017/data.py +350 -0
  15. evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
  16. evograd/benchmarks/functions/cec2017/simple.py +326 -0
  17. evograd/benchmarks/functions/classical.py +649 -0
  18. evograd/benchmarks/functions/smoothed_funnel.py +476 -0
  19. evograd/benchmarks/functions/transforms.py +463 -0
  20. evograd/benchmarks/run_benchmark_functions.py +1208 -0
  21. evograd/core/__init__.py +73 -0
  22. evograd/core/algorithm.py +778 -0
  23. evograd/core/maximize.py +269 -0
  24. evograd/core/minimize.py +740 -0
  25. evograd/core/problem.py +444 -0
  26. evograd/core/result.py +571 -0
  27. evograd/core/termination.py +602 -0
  28. evograd/operators/__init__.py +178 -0
  29. evograd/operators/crossover.py +1117 -0
  30. evograd/operators/mutation.py +1098 -0
  31. evograd/operators/relaxations.py +175 -0
  32. evograd/operators/repair.py +601 -0
  33. evograd/operators/sampling.py +577 -0
  34. evograd/operators/selection.py +981 -0
  35. evograd/operators/survival.py +1000 -0
  36. evograd/tests/__init__.py +11 -0
  37. evograd/tests/run_all.py +78 -0
  38. evograd/tests/test_core.py +528 -0
  39. evograd/tests/test_ga.py +572 -0
  40. evograd/tests/test_operators.py +662 -0
  41. evograd/tests/test_per_individual.py +326 -0
  42. evograd/tests/test_utils.py +328 -0
  43. evograd/utils/__init__.py +97 -0
  44. evograd/utils/callbacks.py +926 -0
  45. evograd/utils/device.py +502 -0
  46. evograd/utils/duplicates.py +421 -0
  47. evograd_diff-0.1.0.dist-info/METADATA +439 -0
  48. evograd_diff-0.1.0.dist-info/RECORD +50 -0
  49. evograd_diff-0.1.0.dist-info/WHEEL +4 -0
  50. evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,926 @@
1
+ """
2
+ Callbacks for monitoring and controlling the optimisation process.
3
+
4
+ This module provides a callback system inspired by Keras/PyTorch Lightning,
5
+ allowing users to hook into the optimisation loop at various points.
6
+
7
+ Available callbacks:
8
+ - HistoryCallback: Track fitness and hyperparameter history
9
+ - EarlyStoppingCallback: Stop when no improvement is detected
10
+ - ConvergenceCallback: Stop when fitness change falls below threshold
11
+ - PrintCallback: Print progress during optimisation
12
+ - CheckpointCallback: Save algorithm state periodically
13
+ - CompositeCallback: Combine multiple callbacks
14
+
15
+ Example:
16
+ >>> from evograd.utils.callbacks import HistoryCallback, EarlyStoppingCallback
17
+ >>>
18
+ >>> history = HistoryCallback(track_population=False)
19
+ >>> early_stop = EarlyStoppingCallback(patience=50, min_delta=1e-6)
20
+ >>>
21
+ >>> # Pass to minimize function
22
+ >>> result = minimize(algorithm, callbacks=[history, early_stop])
23
+ >>>
24
+ >>> # Access history after optimisation
25
+ >>> print(history.best_fitness) # List of best fitness per generation
26
+ >>> print(history.to_dataframe()) # Pandas DataFrame if available
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import time
32
+ import warnings
33
+ from abc import ABC, abstractmethod
34
+ from dataclasses import dataclass, field
35
+ from enum import Enum, auto
36
+ from pathlib import Path
37
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
38
+
39
+ import torch
40
+
41
+ if TYPE_CHECKING:
42
+ from torch import Tensor
43
+
44
# Public API of this module. ``CallbackState`` is the object handed to every
# callback hook (custom callbacks need it for annotations) and
# ``create_default_callbacks`` is the convenience factory defined at the end of
# the file, so both are exported alongside the callback classes.
__all__ = [
    "CallbackEvent",
    "CallbackState",
    "Callback",
    "HistoryCallback",
    "EarlyStoppingCallback",
    "ConvergenceCallback",
    "PrintCallback",
    "CheckpointCallback",
    "CompositeCallback",
    "CallbackList",
    "create_default_callbacks",
]
55
+
56
+
57
+ # =============================================================================
58
+ # Callback Events
59
+ # =============================================================================
60
+
61
class CallbackEvent(Enum):
    """Enumeration of the points in the optimisation loop that have a hook."""

    # Whole-run boundaries.
    OPTIMISATION_START = auto()
    OPTIMISATION_END = auto()

    # Per-generation boundaries.
    GENERATION_START = auto()
    GENERATION_END = auto()

    # Fitness-evaluation boundaries.
    EVALUATION_START = auto()
    EVALUATION_END = auto()
70
+
71
+
72
+ # =============================================================================
73
+ # Callback State Container
74
+ # =============================================================================
75
+
76
@dataclass
class CallbackState:
    """
    Snapshot of the optimisation run that is handed to every callback hook.

    Attributes:
        generation: Current generation number (0-indexed).
        n_evals: Total number of fitness evaluations so far.
        max_evals: Maximum allowed fitness evaluations.
        max_generations: Maximum allowed generations (if set).
        best_fitness: Best fitness value found so far.
        best_solution: Best solution found so far.
        current_fitness: Fitness values of current population.
        current_population: Current population tensor.
        algorithm: Reference to the algorithm instance.
        hyperparams: Dictionary of current hyperparameter values.
        elapsed_time: Time elapsed since optimisation start.
        stop_optimisation: Flag to signal early stopping.
        extra: Dictionary for custom data from algorithms.
    """

    # Progress counters and budgets.
    generation: int = 0
    n_evals: int = 0
    max_evals: Optional[int] = None
    max_generations: Optional[int] = None

    # Best-so-far and current-population data.
    best_fitness: float = float("inf")
    best_solution: Optional[Tensor] = None
    current_fitness: Optional[Tensor] = None
    current_population: Optional[Tensor] = None

    # Algorithm handle and its hyperparameters (fresh dict per instance).
    algorithm: Optional[Any] = None
    hyperparams: Dict[str, Any] = field(default_factory=dict)

    # Timing, stop flag and free-form extras (fresh dict per instance).
    elapsed_time: float = 0.0
    stop_optimisation: bool = False
    extra: Dict[str, Any] = field(default_factory=dict)

    def request_stop(self, reason: str = "") -> None:
        """
        Raise the stop flag and remember why.

        Args:
            reason: Free-text explanation, stored under ``extra["stop_reason"]``
                (any previously stored reason is overwritten).
        """
        self.extra["stop_reason"] = reason
        self.stop_optimisation = True
115
+
116
+
117
+ # =============================================================================
118
+ # Base Callback
119
+ # =============================================================================
120
+
121
class Callback(ABC):
    """
    Base class for all callbacks.

    Every hook below is a no-op that returns ``None``; subclasses override
    only the hooks they care about:

    - on_optimisation_start: Called once before optimisation begins
    - on_optimisation_end: Called once after optimisation completes
    - on_generation_start: Called at the start of each generation
    - on_generation_end: Called at the end of each generation
    - on_evaluation_start: Called before fitness evaluation
    - on_evaluation_end: Called after fitness evaluation

    Example:
        >>> class MyCallback(Callback):
        ...     def on_generation_end(self, state: CallbackState) -> None:
        ...         if state.generation % 100 == 0:
        ...             print(f"Gen {state.generation}: {state.best_fitness:.6f}")
    """

    def on_optimisation_start(self, state: CallbackState) -> None:
        """Hook invoked when optimisation begins (default: do nothing)."""

    def on_optimisation_end(self, state: CallbackState) -> None:
        """Hook invoked when optimisation ends (default: do nothing)."""

    def on_generation_start(self, state: CallbackState) -> None:
        """Hook invoked at the start of each generation (default: do nothing)."""

    def on_generation_end(self, state: CallbackState) -> None:
        """Hook invoked at the end of each generation (default: do nothing)."""

    def on_evaluation_start(self, state: CallbackState) -> None:
        """Hook invoked before fitness evaluation (default: do nothing)."""

    def on_evaluation_end(self, state: CallbackState) -> None:
        """Hook invoked after fitness evaluation (default: do nothing)."""
165
+
166
+
167
+ # =============================================================================
168
+ # History Callback
169
+ # =============================================================================
170
+
171
class HistoryCallback(Callback):
    """
    Track optimisation history including fitness values and hyperparameters.

    One record is taken when optimisation starts and one more at the end of
    every generation. Every series exported by :meth:`to_dict` is kept the
    same length as ``generations``: when a value is unavailable for a record
    (no population fitness, no population, a hyperparameter that is absent
    from that record) a placeholder (``nan`` or ``None``) is stored instead of
    skipping the entry, so :meth:`to_dataframe` always gets equal-length
    columns.

    Args:
        track_population: Whether to store full population at each generation.
            Warning: This can consume significant memory for large populations.
        track_hyperparams: Whether to track hyperparameter changes.
        track_diversity: Whether to compute population diversity metrics.
        track_fitness_stats: Whether to track min/max/mean/std of fitness.

    Attributes:
        generations: Generation number of each record.
        best_fitness: List of best fitness at each record.
        best_solution: List of best solutions (detached clones); a record
            whose state has no best solution adds nothing to this list.
        mean_fitness: List of mean fitness at each record.
        std_fitness: List of fitness std at each record.
        min_fitness: List of min fitness at each record.
        max_fitness: List of max fitness at each record.
        n_evals: List of cumulative evaluations at each record.
        elapsed_time: List of elapsed time at each record.
        hyperparams: Dict mapping hyperparam names to lists of values.
        populations: List of population tensors (if track_population=True).
        diversity: List of diversity metrics (if track_diversity=True).

    Example:
        >>> history = HistoryCallback(track_hyperparams=True)
        >>> result = minimize(algorithm, callbacks=[history])
        >>>
        >>> # Plot convergence
        >>> import matplotlib.pyplot as plt
        >>> plt.plot(history.best_fitness)
        >>> plt.xlabel('Generation')
        >>> plt.ylabel('Best Fitness')
    """

    def __init__(
        self,
        track_population: bool = False,
        track_hyperparams: bool = True,
        track_diversity: bool = False,
        track_fitness_stats: bool = True,
    ) -> None:
        self.track_population = track_population
        self.track_hyperparams = track_hyperparams
        self.track_diversity = track_diversity
        self.track_fitness_stats = track_fitness_stats

        # Core tracking
        self.best_fitness: List[float] = []
        self.best_solution: List[Tensor] = []
        self.n_evals: List[int] = []
        self.elapsed_time: List[float] = []
        self.generations: List[int] = []

        # Fitness statistics
        self.mean_fitness: List[float] = []
        self.std_fitness: List[float] = []
        self.min_fitness: List[float] = []
        self.max_fitness: List[float] = []

        # Hyperparameters
        self.hyperparams: Dict[str, List[Any]] = {}

        # Population (optional)
        self.populations: List[Tensor] = []

        # Diversity (optional)
        self.diversity: List[float] = []

        # Timing; set by on_optimisation_start
        self._start_time: Optional[float] = None

    def _record_metrics(self, state: CallbackState) -> None:
        """Append one record taken from ``state`` to every tracked series."""
        nan = float("nan")
        # Records taken before this one; used to back-fill hyperparameter
        # series that are shorter than the history (e.g. a new name).
        n_previous = len(self.generations)

        # Core metrics
        self.generations.append(state.generation)
        self.best_fitness.append(float(state.best_fitness))
        self.n_evals.append(state.n_evals)

        if self._start_time is not None:
            self.elapsed_time.append(time.perf_counter() - self._start_time)
        else:
            # Hook fired without on_optimisation_start: fall back to the time
            # carried by the state so the series stays aligned.
            self.elapsed_time.append(float(state.elapsed_time))

        # Best solution (detached clone so later in-place updates don't leak in)
        if state.best_solution is not None:
            self.best_solution.append(state.best_solution.detach().clone())

        # Fitness statistics
        if self.track_fitness_stats:
            fitness = state.current_fitness
            if fitness is not None and fitness.numel() > 0:
                self.mean_fitness.append(float(fitness.mean().item()))
                # std of a single value is nan in torch; skip the call to
                # avoid its degrees-of-freedom warning.
                self.std_fitness.append(
                    float(fitness.std().item()) if fitness.numel() > 1 else nan
                )
                self.min_fitness.append(float(fitness.min().item()))
                self.max_fitness.append(float(fitness.max().item()))
            else:
                self.mean_fitness.append(nan)
                self.std_fitness.append(nan)
                self.min_fitness.append(nan)
                self.max_fitness.append(nan)

        # Hyperparameters
        if self.track_hyperparams:
            for name, value in (state.hyperparams or {}).items():
                # Convert tensor to scalar if needed
                if isinstance(value, torch.Tensor):
                    value = float(value.mean()) if value.numel() > 1 else float(value)
                series = self.hyperparams.setdefault(name, [])
                series.extend([None] * (n_previous - len(series)))
                series.append(value)
            # Names known from earlier records but absent now get a placeholder.
            for series in self.hyperparams.values():
                series.extend([None] * (n_previous + 1 - len(series)))

        # Population snapshot
        if self.track_population and state.current_population is not None:
            self.populations.append(state.current_population.detach().clone())

        # Diversity
        if self.track_diversity:
            if state.current_population is not None:
                self.diversity.append(
                    self._compute_diversity(state.current_population)
                )
            else:
                self.diversity.append(nan)

    def on_optimisation_start(self, state: CallbackState) -> None:
        """Reset history, start the timer and record the initial state."""
        self._start_time = time.perf_counter()

        # Clear in place so external references to the lists stay valid.
        for series in (
            self.best_fitness,
            self.best_solution,
            self.n_evals,
            self.elapsed_time,
            self.generations,
            self.mean_fitness,
            self.std_fitness,
            self.min_fitness,
            self.max_fitness,
            self.populations,
            self.diversity,
        ):
            series.clear()
        self.hyperparams.clear()

        self._record_metrics(state)

    def on_generation_end(self, state: CallbackState) -> None:
        """Record metrics at end of generation."""
        self._record_metrics(state)

    def _compute_diversity(self, population: Tensor) -> float:
        """
        Compute population diversity as mean pairwise distance.

        Uses L2 norm between individuals (rows of ``population``). Returns
        ``0.0`` when there are fewer than two individuals.
        """
        n = population.shape[0]
        if n < 2:
            return 0.0

        # Compute pairwise distances
        dists = torch.cdist(population, population, p=2)

        # Mean of upper triangle (excluding diagonal): each pair counted once.
        mask = torch.triu(torch.ones(n, n, device=population.device), diagonal=1).bool()
        return float(dists[mask].mean())

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert history to dictionary.

        The values are the callback's own lists (not copies). Hyperparameter
        series are exported under ``hp_<name>`` keys.
        """
        result = {
            "generation": self.generations,
            "best_fitness": self.best_fitness,
            "n_evals": self.n_evals,
            "elapsed_time": self.elapsed_time,
        }

        if self.track_fitness_stats:
            result["mean_fitness"] = self.mean_fitness
            result["std_fitness"] = self.std_fitness
            result["min_fitness"] = self.min_fitness
            result["max_fitness"] = self.max_fitness

        if self.track_hyperparams:
            for name, values in self.hyperparams.items():
                result[f"hp_{name}"] = values

        if self.track_diversity:
            result["diversity"] = self.diversity

        return result

    def to_dataframe(self):
        """
        Convert history to pandas DataFrame.

        Returns:
            pandas.DataFrame built from :meth:`to_dict`.

        Raises:
            ImportError: If pandas is not installed.
        """
        # Keep the try body to the import only so unrelated errors raised
        # while building the frame are not mistaken for a missing dependency.
        try:
            import pandas as pd
        except ImportError as exc:
            raise ImportError(
                "pandas is required for to_dataframe(). "
                "Install with: pip install pandas"
            ) from exc
        return pd.DataFrame(self.to_dict())

    def __len__(self) -> int:
        """Number of recorded generations."""
        return len(self.generations)

    def __repr__(self) -> str:
        n_gens = len(self.generations)
        best = self.best_fitness[-1] if self.best_fitness else float('inf')
        return f"HistoryCallback(generations={n_gens}, best_fitness={best:.6g})"
377
+
378
+
379
+ # =============================================================================
380
+ # Early Stopping Callback
381
+ # =============================================================================
382
+
383
class EarlyStoppingCallback(Callback):
    """
    Stop optimisation when fitness stops improving.

    Monitors the best fitness and stops if no improvement is seen
    for a specified number of generations (patience).

    Args:
        patience: Number of generations to wait for improvement.
        min_delta: Minimum change to qualify as an improvement.
            Improvement is defined as: new_best < best - min_delta
        baseline: Initial value to beat. If None, the reference starts at
            ``inf`` so the first finite best fitness counts as an improvement.
        restore_best: Whether to publish the best solution seen when the
            optimisation ends (via ``state.extra``).
        verbose: Whether to print when stopping.

    Attributes:
        best_fitness: Best fitness seen so far (plain float).
        best_generation: Generation where best fitness was found.
        best_solution: Detached clone of the best solution seen, if any.
        wait: Current number of generations without improvement.
        stopped_generation: Generation where stopping was triggered.

    Raises:
        ValueError: If ``patience < 1`` or ``min_delta < 0``.

    Example:
        >>> early_stop = EarlyStoppingCallback(patience=100, min_delta=1e-8)
        >>> result = minimize(algorithm, callbacks=[early_stop])
        >>>
        >>> if early_stop.stopped_generation is not None:
        ...     print(f"Stopped at generation {early_stop.stopped_generation}")
    """

    def __init__(
        self,
        patience: int = 50,
        min_delta: float = 0.0,
        baseline: Optional[float] = None,
        restore_best: bool = True,
        verbose: bool = False,
    ) -> None:
        if patience < 1:
            raise ValueError(f"patience must be >= 1, got {patience}")
        if min_delta < 0:
            raise ValueError(f"min_delta must be >= 0, got {min_delta}")

        self.patience = patience
        self.min_delta = min_delta
        self.baseline = baseline
        self.restore_best = restore_best
        self.verbose = verbose

        # State
        self.best_fitness: float = float('inf')
        self.best_generation: int = 0
        self.best_solution: Optional[Tensor] = None
        self.wait: int = 0
        self.stopped_generation: Optional[int] = None

    def on_optimisation_start(self, state: CallbackState) -> None:
        """Reset state at start of optimisation."""
        self.best_fitness = self.baseline if self.baseline is not None else float('inf')
        self.best_generation = 0
        self.best_solution = None
        self.wait = 0
        self.stopped_generation = None

    def on_generation_end(self, state: CallbackState) -> None:
        """Check for improvement and possibly stop."""
        # Coerce to a plain float, as the other callbacks in this module do, so
        # a tensor-valued best fitness is never stored by reference (or later
        # published through ``state.extra``).
        current = float(state.best_fitness)

        # Check for improvement
        if current < self.best_fitness - self.min_delta:
            self.best_fitness = current
            self.best_generation = state.generation
            if state.best_solution is not None:
                self.best_solution = state.best_solution.detach().clone()
            self.wait = 0
        else:
            self.wait += 1

        # Check patience
        if self.wait >= self.patience:
            self.stopped_generation = state.generation
            state.request_stop(
                f"EarlyStopping: No improvement for {self.patience} generations. "
                f"Best: {self.best_fitness:.6g} at generation {self.best_generation}"
            )

            if self.verbose:
                print(
                    f"Early stopping at generation {state.generation}. "
                    f"Best fitness: {self.best_fitness:.6g} "
                    f"(generation {self.best_generation})"
                )

    def on_optimisation_end(self, state: CallbackState) -> None:
        """Optionally publish the best solution for the caller to restore."""
        if self.restore_best and self.best_solution is not None:
            # Stored in ``extra``; the original comment says minimize picks
            # these keys up (consumer not visible in this file).
            state.extra["restored_solution"] = self.best_solution
            state.extra["restored_fitness"] = self.best_fitness

    def __repr__(self) -> str:
        status = "active" if self.stopped_generation is None else f"stopped@{self.stopped_generation}"
        return (
            f"EarlyStoppingCallback(patience={self.patience}, "
            f"min_delta={self.min_delta}, status={status})"
        )
488
+
489
+
490
+ # =============================================================================
491
+ # Convergence Callback
492
+ # =============================================================================
493
+
494
class ConvergenceCallback(Callback):
    """
    Request a stop once the best fitness has (almost) stopped changing.

    At the end of each generation the best fitness is appended to
    ``fitness_history``. The change is measured between the first and last of
    the most recent ``window`` entries, either absolutely or relative to the
    older value, and a stop is requested when it drops below ``threshold``.

    Args:
        threshold: Convergence threshold for fitness change.
        window: Number of most recent history entries the change spans.
        mode: 'absolute' or 'relative' change measurement.
        min_generations: Minimum generations before convergence check starts.
        verbose: Whether to print when stopping.

    Raises:
        ValueError: If ``threshold <= 0``, ``window < 2`` or ``mode`` is not
            one of the two supported strings.

    Example:
        >>> conv = ConvergenceCallback(threshold=1e-6, window=20, mode='relative')
        >>> result = minimize(algorithm, callbacks=[conv])
    """

    def __init__(
        self,
        threshold: float = 1e-6,
        window: int = 10,
        mode: str = "absolute",
        min_generations: int = 100,
        verbose: bool = False,
    ) -> None:
        if threshold <= 0:
            raise ValueError(f"threshold must be > 0, got {threshold}")
        if window < 2:
            raise ValueError(f"window must be >= 2, got {window}")
        if mode not in ("absolute", "relative"):
            raise ValueError(f"mode must be 'absolute' or 'relative', got '{mode}'")

        # Configuration
        self.threshold = threshold
        self.window = window
        self.mode = mode
        self.min_generations = min_generations
        self.verbose = verbose

        # Run state, reset in on_optimisation_start
        self.fitness_history: List[float] = []
        self.stopped_generation: Optional[int] = None
        self.final_change: Optional[float] = None

    def on_optimisation_start(self, state: CallbackState) -> None:
        """Forget everything recorded during a previous run."""
        self.final_change = None
        self.stopped_generation = None
        self.fitness_history.clear()

    def on_generation_end(self, state: CallbackState) -> None:
        """Record the best fitness and request a stop if it has converged."""
        self.fitness_history.append(float(state.best_fitness))

        # Nothing to check until the window is full and the warm-up is over.
        window_full = len(self.fitness_history) >= self.window
        if not window_full or state.generation < self.min_generations:
            return

        window_values = self.fitness_history[-self.window:]
        first, last = window_values[0], window_values[-1]

        # Relative mode falls back to the absolute difference when the older
        # value is too close to zero to divide by.
        if self.mode != "absolute" and not abs(first) < 1e-10:
            delta = abs((last - first) / first)
        else:
            delta = abs(last - first)

        self.final_change = delta

        if not delta < self.threshold:
            return

        self.stopped_generation = state.generation
        state.request_stop(
            f"Convergence: {self.mode} change {delta:.2e} < {self.threshold:.2e} "
            f"over {self.window} generations"
        )

        if self.verbose:
            print(
                f"Converged at generation {state.generation}. "
                f"{self.mode.capitalize()} change: {delta:.2e}"
            )

    def __repr__(self) -> str:
        if self.stopped_generation is None:
            status = "active"
        else:
            status = f"stopped@{self.stopped_generation}"
        return (
            f"ConvergenceCallback(threshold={self.threshold}, "
            f"window={self.window}, mode='{self.mode}', status={status})"
        )
592
+
593
+
594
+ # =============================================================================
595
+ # Print Callback
596
+ # =============================================================================
597
+
598
class PrintCallback(Callback):
    """
    Print progress during optimisation.

    Args:
        every: Print every N generations.
        show_hyperparams: Whether to show hyperparameter values.
        show_evals: Whether to show evaluation count.
        show_time: Whether to show elapsed time.
        format_spec: Format specification for fitness (e.g., '.6f', '.4e').

    Raises:
        ValueError: If ``every < 1``.

    Example:
        >>> printer = PrintCallback(every=50, show_time=True)
        >>> result = minimize(algorithm, callbacks=[printer])

        # Output:
        # Gen 50 | Best: 1.234567e+02 | Evals: 5000 | Time: 1.23s
        # Gen 100 | Best: 4.567890e+01 | Evals: 10000 | Time: 2.45s
    """

    def __init__(
        self,
        every: int = 1,
        show_hyperparams: bool = False,
        show_evals: bool = True,
        show_time: bool = True,
        format_spec: str = ".6e",
    ) -> None:
        if every < 1:
            raise ValueError(f"every must be >= 1, got {every}")

        self.every = every
        self.show_hyperparams = show_hyperparams
        self.show_evals = show_evals
        self.show_time = show_time
        self.format_spec = format_spec

        # Set by on_optimisation_start; None until then.
        self._start_time: Optional[float] = None

    def _report_generation(self, state: CallbackState) -> None:
        """Print one progress line if ``state.generation`` is a multiple of ``every``."""
        if state.generation % self.every != 0:
            return

        # Build output string
        parts = [f"Gen {state.generation:5d}"]
        parts.append(f"Best: {state.best_fitness:{self.format_spec}}")

        if self.show_evals:
            parts.append(f"Evals: {state.n_evals:7d}")

        if self.show_time and self._start_time is not None:
            elapsed = time.perf_counter() - self._start_time
            parts.append(f"Time: {elapsed:.2f}s")

        if self.show_hyperparams and state.hyperparams:
            hp_str = ", ".join(
                f"{k}={v:.4f}" if isinstance(v, float) else f"{k}={v}"
                for k, v in list(state.hyperparams.items())[:3]  # Limit to 3
            )
            parts.append(f"HP: {hp_str}")

        print(" | ".join(parts))

    def on_optimisation_start(self, state: CallbackState) -> None:
        """Record start time, print the header and report the initial state."""
        self._start_time = time.perf_counter()

        # Print header
        header = "Gen"
        header += " | Best Fitness"
        if self.show_evals:
            header += " | Evals"
        if self.show_time:
            header += " | Time"

        print("-" * len(header))
        print(header)
        print("-" * len(header))

        self._report_generation(state)

    def on_generation_end(self, state: CallbackState) -> None:
        """Print progress after each generation."""
        self._report_generation(state)

    def on_optimisation_end(self, state: CallbackState) -> None:
        """Print final summary."""
        # Test against None (as _report_generation does) rather than
        # truthiness, so a legitimate 0.0 timestamp is not treated as "unset".
        if self._start_time is not None:
            elapsed = time.perf_counter() - self._start_time
        else:
            elapsed = 0.0

        print("-" * 60)
        print("Optimisation complete!")
        print(f" Final best: {state.best_fitness:{self.format_spec}}")
        print(f" Generations: {state.generation}")
        print(f" Evaluations: {state.n_evals}")
        print(f" Time: {elapsed:.2f}s")

        if state.extra.get("stop_reason"):
            print(f" Stop reason: {state.extra['stop_reason']}")
697
+
698
+
699
+ # =============================================================================
700
+ # Checkpoint Callback
701
+ # =============================================================================
702
+
703
class CheckpointCallback(Callback):
    """
    Save algorithm state periodically.

    Checkpoints are written with ``torch.save`` to
    ``<directory>/<prefix>_gen<generation:06d>.pt``.

    Args:
        directory: Directory to save checkpoints.
        every: Save every N generations (ignored when ``save_best_only``).
        save_best_only: Only save when best fitness improves.
        max_to_keep: Maximum number of checkpoints to keep (None for all).
        prefix: Filename prefix for checkpoints.

    Raises:
        ValueError: If ``every < 1`` while periodic saving is in use, or if
            ``max_to_keep`` is negative.

    Example:
        >>> ckpt = CheckpointCallback(
        ...     directory="checkpoints",
        ...     every=100,
        ...     save_best_only=True
        ... )
        >>> result = minimize(algorithm, callbacks=[ckpt])
    """

    def __init__(
        self,
        directory: Union[str, Path],
        every: int = 100,
        save_best_only: bool = False,
        max_to_keep: Optional[int] = 5,
        prefix: str = "checkpoint",
    ) -> None:
        # Fail at construction time instead of with ZeroDivisionError in the
        # middle of a run. ``every`` is only consulted for periodic saving.
        if not save_best_only and every < 1:
            raise ValueError(f"every must be >= 1, got {every}")
        # A negative limit would delete every checkpoint and then pop from an
        # empty list during clean-up.
        if max_to_keep is not None and max_to_keep < 0:
            raise ValueError(f"max_to_keep must be >= 0 or None, got {max_to_keep}")

        self.directory = Path(directory)
        self.every = every
        self.save_best_only = save_best_only
        self.max_to_keep = max_to_keep
        self.prefix = prefix

        # State
        self.best_fitness: float = float('inf')
        self.saved_paths: List[Path] = []

    def on_optimisation_start(self, state: CallbackState) -> None:
        """Create directory and reset state."""
        self.directory.mkdir(parents=True, exist_ok=True)
        self.best_fitness = float('inf')
        self.saved_paths.clear()

    def on_generation_end(self, state: CallbackState) -> None:
        """Save checkpoint if conditions are met."""
        if self.save_best_only:
            # Plain float, matching what is written into the checkpoint below.
            current = float(state.best_fitness)
            if not current < self.best_fitness:
                return
            self.best_fitness = current
        elif state.generation % self.every != 0:
            return

        # Build checkpoint; tensors are detached and moved to CPU.
        checkpoint = {
            "generation": state.generation,
            "n_evals": state.n_evals,
            "best_fitness": float(state.best_fitness),
            "best_solution": state.best_solution.detach().cpu() if state.best_solution is not None else None,
            "hyperparams": {
                k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
                for k, v in state.hyperparams.items()
            },
        }

        # Save algorithm state dict if available
        if state.algorithm is not None and hasattr(state.algorithm, "state_dict"):
            checkpoint["algorithm_state"] = state.algorithm.state_dict()

        # Save
        filename = f"{self.prefix}_gen{state.generation:06d}.pt"
        filepath = self.directory / filename
        torch.save(checkpoint, filepath)
        self.saved_paths.append(filepath)

        # Clean up old checkpoints, oldest first
        if self.max_to_keep is not None:
            while len(self.saved_paths) > self.max_to_keep:
                old_path = self.saved_paths.pop(0)
                if old_path.exists():
                    old_path.unlink()

    @staticmethod
    def load_checkpoint(path: Union[str, Path]) -> Dict[str, Any]:
        """
        Load a checkpoint file.

        Args:
            path: Checkpoint file to read.

        Returns:
            The object stored in the file, as returned by ``torch.load``.
        """
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load checkpoints from a
        # trusted source.
        return torch.load(path, weights_only=False)

    def __repr__(self) -> str:
        return (
            f"CheckpointCallback(directory='{self.directory}', "
            f"every={self.every}, saved={len(self.saved_paths)})"
        )
801
+
802
+
803
+ # =============================================================================
804
+ # Composite Callback
805
+ # =============================================================================
806
+
807
class CompositeCallback(Callback):
    """
    Bundle several callbacks behind a single callback object.

    Each hook is forwarded to every wrapped callback, in list order.

    Args:
        callbacks: List of callbacks to combine (copied into a new list).

    Example:
        >>> composite = CompositeCallback([
        ...     HistoryCallback(),
        ...     EarlyStoppingCallback(patience=50),
        ...     PrintCallback(every=10)
        ... ])
        >>> result = minimize(algorithm, callbacks=[composite])
    """

    def __init__(self, callbacks: List[Callback]) -> None:
        self.callbacks = list(callbacks)

    def add(self, callback: Callback) -> "CompositeCallback":
        """Append ``callback`` and return ``self`` so calls can be chained."""
        self.callbacks.append(callback)
        return self

    def _dispatch(self, hook: str, state: CallbackState) -> None:
        """Call the method named ``hook`` on each wrapped callback in order."""
        for member in self.callbacks:
            getattr(member, hook)(state)

    def on_optimisation_start(self, state: CallbackState) -> None:
        self._dispatch("on_optimisation_start", state)

    def on_optimisation_end(self, state: CallbackState) -> None:
        self._dispatch("on_optimisation_end", state)

    def on_generation_start(self, state: CallbackState) -> None:
        self._dispatch("on_generation_start", state)

    def on_generation_end(self, state: CallbackState) -> None:
        self._dispatch("on_generation_end", state)

    def on_evaluation_start(self, state: CallbackState) -> None:
        self._dispatch("on_evaluation_start", state)

    def on_evaluation_end(self, state: CallbackState) -> None:
        self._dispatch("on_evaluation_end", state)

    def __len__(self) -> int:
        """Number of wrapped callbacks."""
        return len(self.callbacks)

    def __iter__(self):
        """Iterate over the wrapped callbacks in order."""
        return iter(self.callbacks)

    def __repr__(self) -> str:
        return f"CompositeCallback({len(self.callbacks)} callbacks)"
866
+
867
+
868
+ # =============================================================================
869
+ # Callback List (Convenience Alias)
870
+ # =============================================================================
871
+
872
class CallbackList(CompositeCallback):
    """
    CompositeCallback with a few list-style conveniences.

    Adds ``append``, ``extend`` and indexing on top of the composite,
    mimicking Keras's CallbackList for familiarity.
    """

    def append(self, callback: Callback) -> None:
        """Add one callback at the end."""
        self.callbacks.append(callback)

    def extend(self, callbacks: List[Callback]) -> None:
        """Add several callbacks at the end, keeping their order."""
        # In-place ``+=`` on a list is the same operation as ``extend``.
        self.callbacks += callbacks

    def __getitem__(self, idx: int) -> Callback:
        """Return the callback stored at position ``idx``."""
        return self.callbacks[idx]
889
+
890
+
891
+ # =============================================================================
892
+ # Utility Functions
893
+ # =============================================================================
894
+
895
def create_default_callbacks(
    verbose: bool = True,
    history: bool = True,
    early_stopping: bool = False,
    patience: int = 50,
    print_every: int = 100,
) -> CallbackList:
    """
    Assemble a commonly used set of callbacks.

    The callbacks are added in a fixed order: history, early stopping,
    printing. Each one is included only if its flag is true.

    Args:
        verbose: Whether to include PrintCallback.
        history: Whether to include HistoryCallback.
        early_stopping: Whether to include EarlyStoppingCallback.
        patience: Patience for early stopping.
        print_every: Print interval.

    Returns:
        CallbackList with requested callbacks.
    """
    selected: List[Callback] = []

    if history:
        selected.append(HistoryCallback())
    if early_stopping:
        selected.append(EarlyStoppingCallback(patience=patience))
    if verbose:
        selected.append(PrintCallback(every=print_every))

    return CallbackList(selected)