evograd-diff 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. evograd/__init__.py +67 -0
  2. evograd/algorithms/__init__.py +138 -0
  3. evograd/algorithms/cmaes.py +1365 -0
  4. evograd/algorithms/de.py +895 -0
  5. evograd/algorithms/ga.py +532 -0
  6. evograd/algorithms/pso.py +648 -0
  7. evograd/algorithms/shade.py +1165 -0
  8. evograd/benchmarks/functions/__init__.py +229 -0
  9. evograd/benchmarks/functions/base.py +217 -0
  10. evograd/benchmarks/functions/cec2017/__init__.py +250 -0
  11. evograd/benchmarks/functions/cec2017/basic.py +413 -0
  12. evograd/benchmarks/functions/cec2017/composition.py +580 -0
  13. evograd/benchmarks/functions/cec2017/data.pkl +0 -0
  14. evograd/benchmarks/functions/cec2017/data.py +350 -0
  15. evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
  16. evograd/benchmarks/functions/cec2017/simple.py +326 -0
  17. evograd/benchmarks/functions/classical.py +649 -0
  18. evograd/benchmarks/functions/smoothed_funnel.py +476 -0
  19. evograd/benchmarks/functions/transforms.py +463 -0
  20. evograd/benchmarks/run_benchmark_functions.py +1208 -0
  21. evograd/core/__init__.py +73 -0
  22. evograd/core/algorithm.py +778 -0
  23. evograd/core/maximize.py +269 -0
  24. evograd/core/minimize.py +740 -0
  25. evograd/core/problem.py +444 -0
  26. evograd/core/result.py +571 -0
  27. evograd/core/termination.py +602 -0
  28. evograd/operators/__init__.py +178 -0
  29. evograd/operators/crossover.py +1117 -0
  30. evograd/operators/mutation.py +1098 -0
  31. evograd/operators/relaxations.py +175 -0
  32. evograd/operators/repair.py +601 -0
  33. evograd/operators/sampling.py +577 -0
  34. evograd/operators/selection.py +981 -0
  35. evograd/operators/survival.py +1000 -0
  36. evograd/tests/__init__.py +11 -0
  37. evograd/tests/run_all.py +78 -0
  38. evograd/tests/test_core.py +528 -0
  39. evograd/tests/test_ga.py +572 -0
  40. evograd/tests/test_operators.py +662 -0
  41. evograd/tests/test_per_individual.py +326 -0
  42. evograd/tests/test_utils.py +328 -0
  43. evograd/utils/__init__.py +97 -0
  44. evograd/utils/callbacks.py +926 -0
  45. evograd/utils/device.py +502 -0
  46. evograd/utils/duplicates.py +421 -0
  47. evograd_diff-0.1.0.dist-info/METADATA +439 -0
  48. evograd_diff-0.1.0.dist-info/RECORD +50 -0
  49. evograd_diff-0.1.0.dist-info/WHEEL +4 -0
  50. evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,778 @@
1
+ """
2
+ Abstract base class for all EvoGrad algorithms.
3
+
4
+ This module provides the foundation for population-based optimisers,
5
+ supporting both differentiable (gradient-enabled) and classical modes.
6
+
7
+ The design follows an infill/advance pattern inspired by pymoo:
8
+ - _infill(): Generate offspring using evolutionary operators
9
+ - _advance(): Update population state based on offspring fitness
10
+
11
+ Algorithms receive operators via dependency injection (pymoo-style):
12
+ - sampling: Initial population generation
13
+ - selection: Parent selection (optional)
14
+ - crossover: Recombination operator (optional)
15
+ - mutation: Perturbation operator (optional)
16
+ - survival: Survivor selection (optional)
17
+ - repair: Constraint handling (optional)
18
+
19
+ For differentiable mode, the entire generation is a differentiable
20
+ computation graph, enabling gradient-based hyperparameter learning.
21
+
22
+ Example:
23
+ >>> from evograd.core import Problem
24
+ >>> from evograd.operators import FloatRandomSampling, TournamentSelection
25
+ >>>
26
+ >>> problem = Problem(objective=ackley, n_var=30, xl=-100.0, xu=100.0)
27
+ >>>
28
+ >>> ga = GA(
29
+ ... pop_size=100,
30
+ ... sampling=FloatRandomSampling(),
31
+ ... selection=TournamentSelection(tournament_size=3),
32
+ ... crossover=SBXCrossover(eta=15, prob=0.9),
33
+ ... mutation=PolynomialMutation(eta=20),
34
+ ... differentiable=True,
35
+ ... )
36
+ >>>
37
+ >>> ga.initialize(problem)
38
+ >>> result = minimize(ga, problem, max_evals=10000)
39
+ """
40
+
41
+ from __future__ import annotations
42
+
43
+ from abc import ABC, abstractmethod
44
+ from typing import (
45
+ TYPE_CHECKING,
46
+ Any,
47
+ Dict,
48
+ List,
49
+ Optional,
50
+ Union,
51
+ )
52
+
53
+ import torch
54
+ import torch.nn as nn
55
+
56
+ from evograd.utils.device import set_seed
57
+ from evograd.utils.duplicates import DuplicateEliminator, DuplicateMethod
58
+ from evograd.operators.sampling import UniformSampling
59
+
60
+ if TYPE_CHECKING:
61
+ from torch import Tensor
62
+ from evograd.core.problem import Problem
63
+
64
+ __all__ = [
65
+ "Algorithm",
66
+ "AlgorithmState",
67
+ ]
68
+
69
+
70
+ # =============================================================================
71
+ # Algorithm State Container
72
+ # =============================================================================
73
+
74
class AlgorithmState:
    """
    Container for algorithm state that can be saved/loaded.

    This separates the persistent state from the algorithm logic,
    making it easier to checkpoint and resume optimisation.

    Attributes:
        generation: Current generation number.
        n_evals: Total fitness evaluations performed.
        population: Current population tensor.
        fitness: Current fitness values.
        best_fitness: Best (lowest) fitness found so far.
        best_solution: Best solution found so far.
        hyperparams: Dictionary of current hyperparameter values.
        extra: Dictionary for algorithm-specific state.
    """

    def __init__(self) -> None:
        self.generation: int = 0
        self.n_evals: int = 0
        self.population: Optional[Tensor] = None
        self.fitness: Optional[Tensor] = None
        self.best_fitness: float = float('inf')
        self.best_solution: Optional[Tensor] = None
        self.hyperparams: Dict[str, Any] = {}
        self.extra: Dict[str, Any] = {}

    @staticmethod
    def _tensors_to_cpu(values: Dict[str, Any]) -> Dict[str, Any]:
        """Return a shallow copy of ``values`` with top-level tensors detached on CPU."""
        return {
            key: value.detach().cpu() if isinstance(value, torch.Tensor) else value
            for key, value in values.items()
        }

    @staticmethod
    def _tensors_to_device(values: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
        """Return a shallow copy of ``values`` with top-level tensors moved to ``device``."""
        return {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in values.items()
        }

    def update_best(self, population: Tensor, fitness: Tensor) -> None:
        """
        Update the best solution if ``fitness`` contains an improvement.

        Minimisation is assumed: a candidate replaces the stored best only
        when its fitness is strictly lower than ``best_fitness``.

        Args:
            population: Candidate solutions, indexed along the first axis.
            fitness: Fitness values matching ``population`` row by row.
        """
        # Nothing to compare against; argmin would raise on an empty tensor.
        if fitness.numel() == 0:
            return

        flat = fitness.detach().reshape(-1)

        # A NaN can be selected by argmin and then fails the `<` test below,
        # which would hide a genuine improvement elsewhere in the batch.
        # Treat NaN as "worst possible" so valid individuals still compete.
        if flat.is_floating_point():
            flat = flat.masked_fill(torch.isnan(flat), float('inf'))

        best_idx = torch.argmin(flat)
        best_val = float(flat[best_idx])

        if best_val < self.best_fitness:
            self.best_fitness = best_val
            self.best_solution = population[best_idx].detach().clone()

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert state to a dictionary for serialisation.

        Tensors are detached and moved to CPU. ``hyperparams`` and ``extra``
        are returned as shallow copies so the checkpoint does not alias the
        live dictionaries; only their top-level tensor values are converted.
        """
        return {
            "generation": self.generation,
            "n_evals": self.n_evals,
            "population": self.population.detach().cpu() if self.population is not None else None,
            "fitness": self.fitness.detach().cpu() if self.fitness is not None else None,
            "best_fitness": self.best_fitness,
            "best_solution": self.best_solution.detach().cpu() if self.best_solution is not None else None,
            "hyperparams": self._tensors_to_cpu(self.hyperparams),
            "extra": self._tensors_to_cpu(self.extra),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any], device: torch.device) -> "AlgorithmState":
        """
        Restore state from a dictionary produced by ``to_dict()``.

        Args:
            data: Serialised state. ``generation``, ``n_evals`` and
                ``best_fitness`` are required; the remaining keys are optional.
            device: Device on which restored tensors are placed.

        Returns:
            A new ``AlgorithmState`` populated from ``data``.
        """
        state = cls()
        state.generation = data["generation"]
        state.n_evals = data["n_evals"]
        state.best_fitness = data["best_fitness"]

        for key in ("population", "fitness", "best_solution"):
            tensor = data.get(key)
            if tensor is not None:
                setattr(state, key, tensor.to(device))

        # Mirror to_dict(): tensors stored on CPU go back to the target device,
        # and the restored dictionaries never alias the input payload.
        state.hyperparams = cls._tensors_to_device(data.get("hyperparams", {}), device)
        state.extra = cls._tensors_to_device(data.get("extra", {}), device)

        return state
146
+
147
+
148
+ # =============================================================================
149
+ # Abstract Algorithm Base Class
150
+ # =============================================================================
151
+
152
class Algorithm(nn.Module, ABC):
    """
    Abstract base class for all EvoGrad optimisation algorithms.

    This class provides the common infrastructure for population-based
    optimisers following the pymoo-style dependency injection pattern.
    Operators (selection, crossover, mutation, etc.) are passed to the
    constructor and used during the optimisation loop.

    Subclasses must implement:
        - _infill(): Generate offspring population
        - _advance(): Update state based on offspring evaluation

    Optionally override:
        - _setup(): One-time setup after initialisation
        - _get_hyperparams(): Return current hyperparameter values

    Args:
        pop_size: Population size.
        sampling: Operator for initial population generation
            (default: ``UniformSampling()``).
        selection: Parent selection operator (optional, algorithm-specific).
        crossover: Crossover/recombination operator (optional).
        mutation: Mutation operator (optional).
        survival: Survivor selection operator (optional).
        repair: Repair operator for constraint handling (optional).
        eliminate_duplicates: Duplicate handling strategy:
            - True: Use default epsilon-based elimination
            - False: No duplicate elimination
            - DuplicateEliminator instance: Custom eliminator
        n_offsprings: Number of offspring per generation (default: pop_size).
        differentiable: Enable gradient flow through operations.
        adaptive: Flag stored on the instance and reported by
            ``state_dict()`` and ``summary()``. This base class does not act
            on it; its meaning is presumably defined by subclasses.
        dtype: Tensor dtype (default: torch.float32). Use ``torch.float64``
            when the objective requires higher numerical precision (e.g.,
            parameter estimation with stiff ODE solvers). The dtype should
            match the Problem's dtype to avoid silent precision loss in
            operator computations and log/exp hyperparameter transforms.

    Attributes:
        pop_size: Population size.
        n_offsprings: Number of offspring per generation.
        differentiable: Whether gradients are enabled.
        adaptive: Value of the ``adaptive`` constructor argument.
        dtype: Tensor data type.
        device: Device taken from the problem in initialize(); ``None``
            until then.
        problem: The Problem instance (set after initialize()).
        state: AlgorithmState containing current optimisation state.

    Example:
        >>> ga = GA(
        ...     pop_size=100,
        ...     sampling=FloatRandomSampling(),
        ...     selection=TournamentSelection(tournament_size=3),
        ...     crossover=SBXCrossover(eta=15, prob=0.9),
        ...     mutation=PolynomialMutation(eta=20),
        ...     repair=BoundsRepair(method='reflect'),
        ...     eliminate_duplicates=True,
        ...     differentiable=True,
        ... )
        >>>
        >>> problem = Problem(objective=ackley, n_var=30, xl=-100, xu=100)
        >>> ga.initialize(problem)
    """

    def __init__(
        self,
        pop_size: int = 100,
        sampling: Optional[nn.Module] = None,
        selection: Optional[nn.Module] = None,
        crossover: Optional[nn.Module] = None,
        mutation: Optional[nn.Module] = None,
        survival: Optional[nn.Module] = None,
        repair: Optional[nn.Module] = None,
        eliminate_duplicates: Union[bool, DuplicateEliminator] = True,
        n_offsprings: Optional[int] = None,
        differentiable: bool = True,
        adaptive: bool = True,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()

        # Validate inputs
        if pop_size < 1:
            raise ValueError(f"pop_size must be >= 1, got {pop_size}")

        # Device and dtype. The device is only known once a problem is
        # attached, but the attribute must exist from the start because
        # __repr__(), summary() and load_state_dict() read it.
        self.device: Optional[torch.device] = None
        self.dtype = dtype

        # Population parameters
        self.pop_size = pop_size
        self.n_offsprings = n_offsprings if n_offsprings is not None else pop_size
        self.differentiable = differentiable
        self.adaptive = adaptive

        # Store operators (can be None for algorithms that don't use them)
        if sampling is None:
            sampling = UniformSampling()
        self.sampling = sampling

        self.selection = selection
        self.crossover = crossover
        self.mutation = mutation
        self.survival = survival
        self.repair = repair

        # Register operators as submodules if they are nn.Module
        # This ensures their parameters are included in algorithm.parameters()
        self._register_operator("sampling", sampling)
        self._register_operator("selection", selection)
        self._register_operator("crossover", crossover)
        self._register_operator("mutation", mutation)
        self._register_operator("survival", survival)
        self._register_operator("repair", repair)

        # Set up duplicate elimination
        if eliminate_duplicates is True:
            self.duplicate_eliminator = DuplicateEliminator(
                method=DuplicateMethod.EPSILON_L2,
                epsilon=1e-8,
            )
        elif eliminate_duplicates is False:
            self.duplicate_eliminator = None
        else:
            self.duplicate_eliminator = eliminate_duplicates

        # Problem reference (set in initialize())
        self.problem: Optional[Problem] = None

        # Algorithm state
        self.state = AlgorithmState()

        # Internal flags
        self._is_initialized = False

    def _register_operator(self, name: str, operator: Optional[nn.Module]) -> None:
        """
        Register an operator as a submodule if it's an nn.Module.

        NOTE(review): assigning an nn.Module to an attribute already
        registers it under the attribute name, so this adds a second alias
        (``_op_<name>``) for the same module. Kept as-is because removing it
        would presumably change the keys emitted by state_dict() - confirm
        before touching.
        """
        if operator is not None and isinstance(operator, nn.Module):
            self.add_module(f"_op_{name}", operator)

    # =========================================================================
    # Core Interface (to be implemented by subclasses)
    # =========================================================================

    @abstractmethod
    def _infill(self) -> Tensor:
        """
        Generate offspring population.

        This method implements the core evolutionary operators:
        selection, crossover, mutation, etc. Subclasses use the
        operators passed to __init__ as needed.

        Returns:
            Offspring population tensor of shape (n_offsprings, n_var).
        """
        pass

    @abstractmethod
    def _advance(self, offspring: Tensor, offspring_fitness: Tensor) -> None:
        """
        Update algorithm state based on offspring evaluation.

        This method implements survivor selection and state updates.
        Should update self.state.population, self.state.fitness, and
        call self.state.update_best().

        Args:
            offspring: Offspring population tensor.
            offspring_fitness: Fitness values of offspring.
        """
        pass

    # =========================================================================
    # Optional Hooks (can be overridden)
    # =========================================================================

    def _setup(self) -> None:
        """
        One-time setup after initialisation.

        Override to perform algorithm-specific setup that requires
        the problem and population to be initialized.
        Called at the end of initialize().
        """
        pass

    def _get_hyperparams(self) -> Dict[str, Any]:
        """
        Return current hyperparameter values.

        Override to include algorithm-specific hyperparameters.
        These are stored in ``self.state.hyperparams`` after every
        generation and listed by summary().

        Returns:
            Dictionary of hyperparameter names to values.
        """
        return {}

    # =========================================================================
    # Initialisation
    # =========================================================================

    def initialize(self, problem: Problem) -> "Algorithm":
        """
        Initialize the algorithm with a problem.

        Creates initial population, evaluates fitness, and sets up
        internal state. Must be called before step() or forward().
        Calling it again on an already initialized algorithm is a no-op;
        use reset() to start over.

        Args:
            problem: Problem instance defining objective, bounds, etc.

        Returns:
            Self for method chaining.
        """
        if self._is_initialized:
            return self

        # Store problem reference
        self.problem = problem
        self.device = problem.device

        # Move problem bounds to device
        self.register_buffer(
            "xl",
            problem.xl.to(device=self.device, dtype=self.dtype)
        )
        self.register_buffer(
            "xu",
            problem.xu.to(device=self.device, dtype=self.dtype)
        )

        # Create initial population using sampling operator
        population = self.sampling(self.pop_size, problem)

        # Apply repair if provided
        if self.repair is not None:
            population = self.repair(population, self.xl, self.xu)

        # Eliminate duplicates
        if self.duplicate_eliminator is not None:
            population = self.duplicate_eliminator(population, self.xl, self.xu)

        # Store population
        if self.differentiable:
            self._population = nn.Parameter(population)
        else:
            self.register_buffer("_population", population)

        # Evaluate initial population
        fitness = self._evaluate(population)

        # Initialize state
        # Always reference the registered _population (nn.Parameter or buffer)
        # to ensure gradient flow in differentiable mode and a single source of truth.
        self.state.population = self._population
        self.state.fitness = fitness
        self.state.generation = 0
        self.state.n_evals = self.pop_size
        self.state.update_best(self._population, fitness)

        # Algorithm-specific setup
        self._setup()

        self._is_initialized = True
        return self

    def reset(self, seed: Optional[int] = None) -> "Algorithm":
        """
        Reset algorithm to initial state.

        Discards the current AlgorithmState and any offspring left pending
        by forward(), then runs initialize() again on the stored problem.

        Args:
            seed: Optional new random seed.

        Returns:
            Self for method chaining.

        Raises:
            RuntimeError: If problem not set (never initialized).
        """
        if self.problem is None:
            raise RuntimeError(
                "Cannot reset: algorithm was never initialized with a problem."
            )

        if seed is not None:
            set_seed(seed)
            self._seed = seed

        # Offspring produced by a forward() that was never committed belong
        # to the previous run; drop them so a later update_state() cannot
        # advance the fresh population with stale data.
        for pending in ("_pending_offspring", "_pending_fitness"):
            if hasattr(self, pending):
                delattr(self, pending)

        self._is_initialized = False
        self.state = AlgorithmState()

        return self.initialize(self.problem)

    # =========================================================================
    # Evaluation
    # =========================================================================

    def _evaluate(self, x: Tensor) -> Tensor:
        """
        Evaluate fitness of population.

        Args:
            x: Population tensor of shape (N, n_var).

        Returns:
            Fitness tensor of shape (N,).
        """
        fitness = self.problem.evaluate(x)

        # Ensure correct shape and type
        if fitness.dim() == 0:
            fitness = fitness.unsqueeze(0)

        return fitness.to(device=self.device, dtype=self.dtype)

    # =========================================================================
    # Main Evolution Methods
    # =========================================================================

    def step(self) -> float:
        """
        Perform one generation of evolution (classical mode).

        This is the main entry point for advancing the algorithm
        in classical mode. Handles the complete cycle:
        infill -> repair -> eliminate_duplicates -> evaluate -> advance.

        Returns:
            Best fitness after this generation.

        Raises:
            RuntimeError: If algorithm not initialized.
        """
        if not self._is_initialized:
            raise RuntimeError(
                "Algorithm not initialized. Call initialize(problem) first."
            )

        # Generate offspring
        offspring = self._infill()

        # Apply repair if provided
        if self.repair is not None:
            offspring = self.repair(offspring, self.xl, self.xu)

        # Eliminate duplicates
        if self.duplicate_eliminator is not None:
            offspring = self.duplicate_eliminator(offspring, self.xl, self.xu)

        # Evaluate offspring
        offspring_fitness = self._evaluate(offspring)
        self.state.n_evals += offspring.shape[0]

        # Update state (implemented by subclass)
        self._advance(offspring, offspring_fitness)
        self.state.generation += 1

        # Update hyperparams in state
        self.state.hyperparams = self._get_hyperparams()

        return self.state.best_fitness

    def forward(self) -> Tensor:
        """
        PyTorch forward pass for differentiable optimisation.

        In differentiable mode, this builds a computation graph
        through the entire generation, returning the best fitness
        as a differentiable scalar loss. Call update_state() after
        loss.backward() and optimizer.step() to commit changes.

        Returns:
            Best fitness as a scalar tensor (for backprop).

        Raises:
            RuntimeError: If algorithm not initialized.
        """
        if not self._is_initialized:
            raise RuntimeError(
                "Algorithm not initialized. Call initialize(problem) first."
            )

        # Generate offspring (differentiable)
        offspring = self._infill()

        # Apply repair if provided (should be differentiable)
        if self.repair is not None:
            offspring = self.repair(offspring, self.xl, self.xu)

        # Note: duplicate elimination is typically not differentiable
        # Skip in forward pass, apply in update_state if needed

        # Evaluate offspring (differentiable if objective supports it)
        offspring_fitness = self._evaluate(offspring)
        self.state.n_evals += offspring.shape[0]

        # Store for update_state() to commit later
        self._pending_offspring = offspring
        self._pending_fitness = offspring_fitness

        # Return best fitness as loss
        return offspring_fitness.min()

    @torch.no_grad()
    def update_state(self) -> None:
        """
        Commit pending changes after backward pass.

        In differentiable mode, call this after loss.backward()
        and optimizer.step() to update the algorithm state.
        Does nothing when no forward() result is pending.
        """
        if not hasattr(self, "_pending_offspring"):
            return

        offspring = self._pending_offspring
        offspring_fitness = self._pending_fitness

        # Now apply duplicate elimination (non-differentiable)
        if self.duplicate_eliminator is not None:
            # Duplicate elimination is non-differentiable. We can avoid
            # a full re-evaluation by only re-evaluating individuals that
            # were actually resampled.
            offspring, changed_indices = self.duplicate_eliminator(
                offspring, self.xl, self.xu, return_indices=True
            )

            if changed_indices.numel() > 0:
                # Ensure we can assign into the fitness tensor safely.
                offspring_fitness = offspring_fitness.clone()

                changed_fitness = self._evaluate(offspring[changed_indices])
                offspring_fitness[changed_indices] = changed_fitness
                self.state.n_evals += int(changed_indices.numel())

        # Advance with pending offspring
        self._advance(offspring, offspring_fitness)
        self.state.generation += 1
        self.state.hyperparams = self._get_hyperparams()

        # Clean up
        del self._pending_offspring
        del self._pending_fitness

    # =========================================================================
    # Properties
    # =========================================================================

    @property
    def population(self) -> Tensor:
        """Current population tensor."""
        return self._population

    @property
    def fitness(self) -> Optional[Tensor]:
        """Current fitness values."""
        return self.state.fitness

    @property
    def best_fitness(self) -> float:
        """Best fitness found so far."""
        return self.state.best_fitness

    @property
    def best_solution(self) -> Optional[Tensor]:
        """Best solution found so far."""
        return self.state.best_solution

    @property
    def n_evals(self) -> int:
        """Total number of fitness evaluations."""
        return self.state.n_evals

    @property
    def generation(self) -> int:
        """Current generation number."""
        return self.state.generation

    @property
    def n_var(self) -> Optional[int]:
        """Number of variables (from problem)."""
        return self.problem.n_var if self.problem is not None else None

    # =========================================================================
    # Serialisation
    # =========================================================================

    def state_dict(self) -> Dict[str, Any]:
        """
        Get complete state dictionary for checkpointing.

        Returns:
            Dictionary with the serialised AlgorithmState
            (``algorithm_state``), the nn.Module state (``model_state``),
            the constructor configuration (``config``) and the
            initialisation flag (``is_initialized``).
        """
        state = {
            "algorithm_state": self.state.to_dict(),
            "model_state": super().state_dict(),
            "config": {
                "pop_size": self.pop_size,
                "n_offsprings": self.n_offsprings,
                "differentiable": self.differentiable,
                "adaptive": self.adaptive,
            },
            "is_initialized": self._is_initialized,
        }
        return state

    def load_state_dict(
        self,
        state_dict: Dict[str, Any],
        strict: bool = True,
    ) -> None:
        """
        Load state from dictionary.

        Note: The problem must be set separately via initialize()
        or by setting self.problem before calling this method.

        Args:
            state_dict: State dictionary from state_dict().
            strict: Whether to require exact key matching.
        """
        # Load model parameters (population, operator params, etc.)
        super().load_state_dict(state_dict["model_state"], strict=strict)

        # self.device is only populated by initialize(). When the problem was
        # attached by hand instead, fall back to the problem's device, and to
        # CPU as a last resort, rather than failing.
        device = self.device
        if device is None and self.problem is not None:
            device = self.problem.device
        if device is None:
            device = torch.device("cpu")

        # Load algorithm state
        self.state = AlgorithmState.from_dict(
            state_dict["algorithm_state"], device
        )

        self._is_initialized = state_dict.get("is_initialized", True)

    # =========================================================================
    # String Representation
    # =========================================================================

    def __repr__(self) -> str:
        mode = "differentiable" if self.differentiable else "classical"
        status = "initialized" if self._is_initialized else "not initialized"
        n_var = self.n_var if self.problem is not None else "?"
        return (
            f"{self.__class__.__name__}("
            f"pop_size={self.pop_size}, "
            f"n_var={n_var}, "
            f"mode={mode}, "
            f"status={status}, "
            f"device={self.device})"
        )

    def summary(self) -> str:
        """
        Return a detailed summary of the algorithm configuration.

        Safe to call before initialize(): sections that need the problem
        bounds or run state are skipped when that data is not available.
        """
        rule = "=" * 60
        lines = [
            rule,
            f"Algorithm: {self.__class__.__name__}",
            rule,
            f"  Population size:  {self.pop_size}",
            f"  Offspring size:   {self.n_offsprings}",
            f"  Mode:             {'Differentiable' if self.differentiable else 'Classical'}",
            f"  Adaptive:         {self.adaptive}",
            f"  Device:           {self.device}",
            f"  Initialized:      {self._is_initialized}",
        ]

        # Problem info
        if self.problem is not None:
            lines.extend([
                "",
                "Problem:",
                f"  Variables:        {self.problem.n_var}",
            ])
            # The bounds buffers are registered by initialize(); a problem
            # attached by hand does not have them yet.
            xl = getattr(self, "xl", None)
            xu = getattr(self, "xu", None)
            if xl is not None and xu is not None:
                lines.append(
                    f"  Bounds:           [{float(xl.min()):.2g}, {float(xu.max()):.2g}]"
                )

        # Operators
        lines.append("")
        lines.append("Operators:")
        operators = [
            ("Sampling", self.sampling),
            ("Selection", self.selection),
            ("Crossover", self.crossover),
            ("Mutation", self.mutation),
            ("Survival", self.survival),
            ("Repair", self.repair),
        ]
        for name, op in operators:
            if op is not None:
                lines.append(f"  {name}:  {op.__class__.__name__}")
            else:
                lines.append(f"  {name}:  None")

        # Duplicate elimination
        if self.duplicate_eliminator is not None:
            # A user-supplied eliminator is not guaranteed to expose
            # `.method`; fall back to its class name in that case.
            method = getattr(self.duplicate_eliminator, "method", None)
            label = getattr(method, "name", None)
            if label is None:
                label = self.duplicate_eliminator.__class__.__name__
            lines.append(f"  Duplicates:  {label}")
        else:
            lines.append("  Duplicates:  Disabled")

        # State info
        if self._is_initialized:
            lines.extend([
                "",
                "State:",
                f"  Generation:       {self.state.generation}",
                f"  Evaluations:      {self.state.n_evals}",
                f"  Best fitness:     {self.state.best_fitness:.6g}",
            ])

        # Hyperparameters
        hp = self._get_hyperparams()
        if hp:
            lines.append("")
            lines.append("Hyperparameters:")
            for name, value in hp.items():
                if isinstance(value, torch.Tensor) and value.numel() > 0:
                    if value.numel() > 1:
                        # mean() is undefined for integer/bool tensors.
                        numeric = value if value.is_floating_point() else value.float()
                        value = float(numeric.mean())
                    else:
                        value = float(value)
                if isinstance(value, float):
                    lines.append(f"  {name}:  {value:.4g}")
                else:
                    lines.append(f"  {name}:  {value}")

        # Parameter count
        n_params = sum(p.numel() for p in self.parameters())
        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        lines.extend([
            "",
            "Parameters:",
            f"  Total:            {n_params:,}",
            f"  Trainable:        {n_trainable:,}",
            rule,
        ])

        return "\n".join(lines)