evograd-diff 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. evograd/__init__.py +67 -0
  2. evograd/algorithms/__init__.py +138 -0
  3. evograd/algorithms/cmaes.py +1365 -0
  4. evograd/algorithms/de.py +895 -0
  5. evograd/algorithms/ga.py +532 -0
  6. evograd/algorithms/pso.py +648 -0
  7. evograd/algorithms/shade.py +1165 -0
  8. evograd/benchmarks/functions/__init__.py +229 -0
  9. evograd/benchmarks/functions/base.py +217 -0
  10. evograd/benchmarks/functions/cec2017/__init__.py +250 -0
  11. evograd/benchmarks/functions/cec2017/basic.py +413 -0
  12. evograd/benchmarks/functions/cec2017/composition.py +580 -0
  13. evograd/benchmarks/functions/cec2017/data.pkl +0 -0
  14. evograd/benchmarks/functions/cec2017/data.py +350 -0
  15. evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
  16. evograd/benchmarks/functions/cec2017/simple.py +326 -0
  17. evograd/benchmarks/functions/classical.py +649 -0
  18. evograd/benchmarks/functions/smoothed_funnel.py +476 -0
  19. evograd/benchmarks/functions/transforms.py +463 -0
  20. evograd/benchmarks/run_benchmark_functions.py +1208 -0
  21. evograd/core/__init__.py +73 -0
  22. evograd/core/algorithm.py +778 -0
  23. evograd/core/maximize.py +269 -0
  24. evograd/core/minimize.py +740 -0
  25. evograd/core/problem.py +444 -0
  26. evograd/core/result.py +571 -0
  27. evograd/core/termination.py +602 -0
  28. evograd/operators/__init__.py +178 -0
  29. evograd/operators/crossover.py +1117 -0
  30. evograd/operators/mutation.py +1098 -0
  31. evograd/operators/relaxations.py +175 -0
  32. evograd/operators/repair.py +601 -0
  33. evograd/operators/sampling.py +577 -0
  34. evograd/operators/selection.py +981 -0
  35. evograd/operators/survival.py +1000 -0
  36. evograd/tests/__init__.py +11 -0
  37. evograd/tests/run_all.py +78 -0
  38. evograd/tests/test_core.py +528 -0
  39. evograd/tests/test_ga.py +572 -0
  40. evograd/tests/test_operators.py +662 -0
  41. evograd/tests/test_per_individual.py +326 -0
  42. evograd/tests/test_utils.py +328 -0
  43. evograd/utils/__init__.py +97 -0
  44. evograd/utils/callbacks.py +926 -0
  45. evograd/utils/device.py +502 -0
  46. evograd/utils/duplicates.py +421 -0
  47. evograd_diff-0.1.0.dist-info/METADATA +439 -0
  48. evograd_diff-0.1.0.dist-info/RECORD +50 -0
  49. evograd_diff-0.1.0.dist-info/WHEEL +4 -0
  50. evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,1000 @@
1
+ """
2
+ Survival selection operators for generational replacement.
3
+
4
+ This module provides strategies for selecting which individuals
5
+ survive to the next generation. All operators support both classical
6
+ and differentiable (i.e., adaptive) modes.
7
+
8
+ Available survival strategies:
9
+ - MergeSurvival: (μ+λ) - Select best from parents + offspring
10
+ - CommaSurvival: (μ,λ) - Select only from offspring
11
+ - ReplaceWorstSurvival: Steady-state replacement of worst
12
+ - FitnessSurvival: Simple fitness-based truncation
13
+ - AgeSurvival: Age-based replacement with fitness tie-breaking
14
+
15
+ Elitism:
16
+ Most survival operators support elitism, which ensures that the
17
+ best n_elite individuals from the parent population are always
18
+ preserved in the next generation.
19
+
20
+ Differentiable Mode:
21
+ When `adaptive=True`, survival selection uses soft ranking
22
+ based on fitness scores with temperature-scaled softmax,
23
+ if and only if another operator has `adaptive=True`,
24
+ allowing gradients to flow through the selection process
25
+ via the temperature parameter using a straight-through estimator.
26
+
27
+ Example:
28
+ >>> from evograd.operators import MergeSurvival
29
+ >>>
30
+ >>> # Classical (μ+λ) survival
31
+ >>> survival = MergeSurvival(n_survive=100, elitism=True, n_elite=1)
32
+ >>> new_pop, new_fit = survival(
33
+ ... parents, parent_fitness,
34
+ ... offspring, offspring_fitness,
35
+ ... )
36
+ >>>
37
+ >>> # Differentiable mode
38
+ >>> survival = MergeSurvival(
39
+ ... n_survive=100,
40
+ ... adaptive=True,
41
+ ... temperature=1.0,
42
+ ... )
43
+ """
44
+
45
+ from __future__ import annotations
46
+
47
+ from abc import ABC, abstractmethod
48
+ from typing import Optional, Tuple
49
+
50
+ import torch
51
+ import torch.nn as nn
52
+ from torch import Tensor
53
+
54
+ from evograd.operators.selection import TopKSelection
55
+
56
# Public API: the abstract base class followed by every concrete strategy.
__all__ = [
    "Survival", "MergeSurvival", "CommaSurvival",
    "ReplaceWorstSurvival", "FitnessSurvival", "AgeSurvival",
]
64
+
65
+
66
+ # =============================================================================
67
+ # Base Survival Class
68
+ # =============================================================================
69
+
70
class Survival(nn.Module, ABC):
    """
    Abstract base class for survival selection operators.

    Survival selection determines which individuals from the combined
    parent and offspring populations survive to the next generation.

    Subclasses must implement:
        - _survive(): Perform survival selection

    Args:
        n_survive: Number of individuals to survive (population size).
            If None, ``forward`` falls back to the number of parents.
        elitism: If True, preserve best individuals from parents.
        n_elite: Number of elite individuals to preserve. Forced to 0 when
            ``elitism`` is False; must be >= 0 when ``elitism`` is True.
        adaptive: If True, use soft selection for gradients.
        temperature: Temperature for soft selection. Must be > 0.
        learn_temperature: If True *and* ``adaptive`` is True, the
            temperature is stored as an ``nn.Parameter``; otherwise it is
            stored as a non-trainable buffer.
        minimize: If True, lower fitness is better (default).

    Raises:
        ValueError: If ``temperature`` is not strictly positive, or if
            ``elitism`` is True and ``n_elite`` is negative.
    """

    def __init__(
        self,
        n_survive: Optional[int] = None,
        elitism: bool = True,
        n_elite: int = 1,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
    ) -> None:
        super().__init__()

        if elitism and n_elite < 0:
            raise ValueError(f"n_elite must be >= 0, got {n_elite}")

        self.n_survive = n_survive
        self.elitism = elitism
        self.n_elite = n_elite if elitism else 0
        self.adaptive = adaptive
        self.minimize = minimize

        # The temperature is kept in log-space; `temperature` exponentiates
        # it, so the effective value is always positive even when learned.
        log_temperature = self._positive_log(temperature)
        if learn_temperature and adaptive:
            self._log_temperature = nn.Parameter(log_temperature)
        else:
            self.register_buffer("_log_temperature", log_temperature)

    @staticmethod
    def _positive_log(value: float) -> Tensor:
        """
        Return ``log(value)`` as a 0-dim tensor, rejecting non-positive values.

        Raises:
            ValueError: If ``value`` is not strictly positive (or is NaN).
        """
        value = float(value)
        if not value > 0.0:  # written as `not >` so NaN is rejected too
            raise ValueError(f"temperature must be > 0, got {value}")
        return torch.tensor(value).log()

    @property
    def temperature(self) -> Tensor:
        """Current temperature value."""
        return self._log_temperature.exp()

    @temperature.setter
    def temperature(self, value: float) -> None:
        """Set temperature value (must be > 0)."""
        log_value = self._positive_log(value)
        with torch.no_grad():
            self._log_temperature.fill_(log_value.item())

    def _fitness_to_scores(self, fitness: Tensor) -> Tensor:
        """
        Convert fitness to selection scores (higher = better).

        Args:
            fitness: Raw fitness values [n].

        Returns:
            Selection scores [n] where higher is better.
        """
        if self.minimize:
            return -fitness
        return fitness

    def _compute_soft_weights(self, fitness: Tensor) -> Tensor:
        """
        Compute soft selection weights using temperature-scaled softmax.

        Args:
            fitness: Fitness values [n].

        Returns:
            Soft selection weights [n] summing to 1.
        """
        scores = self._fitness_to_scores(fitness)
        return torch.softmax(scores / self.temperature, dim=0)

    # -------------------------------------------------------------------------
    # Shared Utilities
    # -------------------------------------------------------------------------

    def _best_indices(self, fitness: Tensor, k: int) -> Tensor:
        """
        Indices of the best k individuals according to `minimize`.

        Delegates to ``TopKSelection.best_indices`` with this operator's
        ``minimize`` flag.
        """
        return TopKSelection.best_indices(fitness, k, minimize=self.minimize)

    def _worst_indices(self, fitness: Tensor, k: int) -> Tensor:
        """
        Indices of the worst k individuals according to `minimize`.

        Delegates to ``TopKSelection.worst_indices`` with this operator's
        ``minimize`` flag.
        """
        return TopKSelection.worst_indices(fitness, k, minimize=self.minimize)

    def _sort_indices_best_first(self, fitness: Tensor) -> Tensor:
        """Return indices that sort individuals from best to worst."""
        return TopKSelection.sort_indices_best_first(fitness, minimize=self.minimize)

    def _gumbel_topk(
        self,
        logits: Tensor,
        k: int,
        dim: int = 0,
        eps: float = 1e-8,
    ) -> Tensor:
        """
        Differentiable top-k without replacement using sequential Gumbel-Softmax.

        Returns only the weight tensor produced by ``TopKSelection.gumbel_topk``
        (one-hot in forward, soft in backward); the index output is dropped.
        Presumably shaped [k, n] for 1-D logits with ``dim=0`` — confirm
        against ``TopKSelection.gumbel_topk``.
        """
        weights, _ = TopKSelection.gumbel_topk(
            logits,
            k,
            temperature=self.temperature,
            dim=dim,
            eps=eps,
        )
        return weights

    def _adaptive_topk_select(
        self,
        population: Tensor,
        fitness: Tensor,
        k: int,
        mask_out: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Adaptive (differentiable) top-k selection with straight-through gradients.

        Args:
            population: Candidate individuals [n, n_var].
            fitness: Candidate fitness [n].
            k: Number to select.
            mask_out: Optional boolean mask [n] of candidates to exclude.

        Returns:
            survivors: Selected individuals [k, n_var].
            survivor_fit: Selected fitness [k].

        Raises:
            ValueError: If ``k`` is not in ``[1, n]``, if ``mask_out`` has the
                wrong shape, or if fewer than ``k`` candidates remain unmasked.
            TypeError: If ``mask_out`` is not a boolean tensor.
        """
        n = population.shape[0]
        if k <= 0:
            raise ValueError(f"k must be > 0, got {k}")
        if k > n:
            raise ValueError(f"Cannot select k={k} from n={n}")

        # Convert fitness to score-space (higher is better).
        scores = self._fitness_to_scores(fitness)

        # Optional masking (e.g., to exclude elites when selecting the remainder).
        logits = scores.unsqueeze(0)  # [1, n]
        if mask_out is not None:
            if mask_out.dtype != torch.bool:
                raise TypeError("mask_out must be a boolean tensor")
            if tuple(mask_out.shape) != (n,):
                raise ValueError(f"mask_out must have shape [{n}], got {tuple(mask_out.shape)}")
            # Without this check, sampling without replacement would be forced
            # to pick masked candidates once the unmasked ones run out.
            n_available = n - int(mask_out.sum().item())
            if k > n_available:
                raise ValueError(
                    f"Cannot select k={k} from {n_available} unmasked candidates"
                )
            logits = logits.clone()
            logits[:, mask_out] = -1e9

        # Sequential (without-replacement) ST Gumbel top-k.
        weights, _ = TopKSelection.gumbel_topk(
            logits,
            k=k,
            temperature=self.temperature,
            dim=-1,
        )

        # weights: [1, k, n] -> [k, n]
        weights = weights.squeeze(0)

        selected = weights @ population  # [k, n_var]

        # Forward values are hard (one-hot rows), backward carries gradients
        # through the soft weights.
        selected_fit = weights @ fitness  # [k]

        return selected, selected_fit

    def _survive(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Perform survival selection. Must be overridden by subclasses.

        Args:
            parents: Parent population [n_parents, n_var].
            parent_fitness: Parent fitness [n_parents].
            offspring: Offspring population [n_offspring, n_var].
            offspring_fitness: Offspring fitness [n_offspring].
            n_survive: Number of individuals to survive.

        Returns:
            Tuple of (survivors, survivor_fitness).

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement _survive()")

    def forward(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Apply survival selection.

        Args:
            parents: Parent population [n_parents, n_var].
            parent_fitness: Parent fitness [n_parents].
            offspring: Offspring population [n_offspring, n_var].
            offspring_fitness: Offspring fitness [n_offspring].
            n_survive: Number to survive (default: self.n_survive, or
                n_parents when that is None).

        Returns:
            Tuple of (survivors, survivor_fitness).
        """
        if n_survive is None:
            n_survive = self.n_survive if self.n_survive is not None else parents.shape[0]

        return self._survive(
            parents, parent_fitness,
            offspring, offspring_fitness,
            n_survive,
        )

    # Note: Do NOT override __call__. nn.Module.__call__ dispatches to
    # forward() and fires registered hooks (forward_pre_hooks, forward_hooks,
    # and the autograd profiler). Overriding __call__ would bypass all of these.
323
+
324
+
325
+ # =============================================================================
326
+ # Merge Survival (μ+λ)
327
+ # =============================================================================
328
+
329
class MergeSurvival(Survival):
    """
    (μ+λ) Merge survival: Select best from parents + offspring.

    Combines the parent and offspring populations, then selects
    the best n_survive individuals based on fitness. This is the
    most common survival strategy for evolutionary algorithms.

    With elitism enabled, the best n_elite parents are guaranteed
    to survive regardless of offspring fitness. Elites are removed from
    the candidate pool before the remaining slots are filled, so no
    individual is selected twice. The number of elites is clamped to
    the number of parents and to n_survive.

    Args:
        n_survive: Number of individuals to survive.
        elitism: If True, preserve best parents (default: True).
        n_elite: Number of elite parents to preserve.
        adaptive: If True, use soft weighted selection.
        temperature: Temperature for soft selection.
        minimize: If True, lower fitness is better.

    Example:
        >>> survival = MergeSurvival(n_survive=100, elitism=True, n_elite=2)
        >>> new_pop, new_fit = survival(parents, p_fit, offspring, o_fit)
    """

    def _survive(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """Select best from combined population (hard or adaptive)."""
        if self.adaptive:
            return self._survive_adaptive(
                parents, parent_fitness,
                offspring, offspring_fitness,
                n_survive,
            )
        return self._survive_hard(
            parents, parent_fitness,
            offspring, offspring_fitness,
            n_survive,
        )

    def _effective_n_elite(self, n_parents: int, n_survive: int) -> int:
        """Number of elites actually preserved, clamped to parents and survivor slots."""
        if not (self.elitism and self.n_elite > 0):
            return 0
        return min(self.n_elite, n_parents, n_survive)

    def _survive_hard(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """Hard (classical) survival selection."""
        combined_pop = torch.cat([parents, offspring], dim=0)
        combined_fit = torch.cat([parent_fitness, offspring_fitness], dim=0)

        n_elite = self._effective_n_elite(parents.shape[0], n_survive)
        if n_elite == 0:
            # Simple truncation of the combined pool.
            sorted_idx = self._best_indices(combined_fit, n_survive)
            return combined_pop[sorted_idx], combined_fit[sorted_idx]

        elite_idx = self._best_indices(parent_fitness, n_elite)
        elite_pop = parents[elite_idx].clone()
        elite_fit = parent_fitness[elite_idx].clone()

        # Exclude the elites from the candidate pool so they cannot be picked
        # a second time. Parents occupy the first block of `combined`, so
        # parent indices are valid indices into the combined pool.
        keep = torch.ones(combined_fit.shape[0], dtype=torch.bool, device=combined_fit.device)
        keep[elite_idx] = False
        cand_pop = combined_pop[keep]
        cand_fit = combined_fit[keep]

        n_remaining = n_survive - n_elite
        if n_remaining > 0:
            rest_idx = self._best_indices(cand_fit, n_remaining)
            survivors = torch.cat([elite_pop, cand_pop[rest_idx]], dim=0)
            survivor_fit = torch.cat([elite_fit, cand_fit[rest_idx]], dim=0)
        else:
            survivors, survivor_fit = elite_pop, elite_fit

        # Re-sort by fitness (best first).
        final_idx = self._sort_indices_best_first(survivor_fit)
        return survivors[final_idx], survivor_fit[final_idx]

    def _survive_adaptive(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Soft (differentiable/adaptive) survival selection.

        Uses straight-through estimator via Gumbel-Softmax:
            - Forward pass: discrete top-k without replacement
            - Backward pass: soft gradients flow through temperature

        Note: Elitism (if enabled) is applied as a hard constraint by first
        preserving elite parents, then selecting the remaining survivors from
        the remaining candidate pool.
        """
        combined_pop = torch.cat([parents, offspring], dim=0)
        combined_fit = torch.cat([parent_fitness, offspring_fitness], dim=0)

        n_elite = self._effective_n_elite(parents.shape[0], n_survive)
        if n_elite == 0:
            # No elitism: select directly from combined.
            survivors, survivor_fit = self._adaptive_topk_select(
                combined_pop,
                combined_fit,
                k=n_survive,
            )
            final_idx = self._sort_indices_best_first(survivor_fit)
            return survivors[final_idx], survivor_fit[final_idx]

        elite_idx = self._best_indices(parent_fitness, n_elite)
        elite_pop = parents[elite_idx].clone()
        elite_fit = parent_fitness[elite_idx].clone()

        n_remaining = n_survive - n_elite
        if n_remaining <= 0:
            # Sort best-first for consistency.
            final_idx = self._sort_indices_best_first(elite_fit)
            return elite_pop[final_idx], elite_fit[final_idx]

        # Mask out elites in the combined candidate pool to avoid duplicates
        # (elite indices refer to parents, which occupy the first block).
        mask_out = torch.zeros(combined_fit.shape[0], dtype=torch.bool, device=combined_fit.device)
        mask_out[elite_idx] = True

        survivors_rest, fit_rest = self._adaptive_topk_select(
            combined_pop,
            combined_fit,
            k=n_remaining,
            mask_out=mask_out,
        )

        survivors = torch.cat([elite_pop, survivors_rest], dim=0)
        survivor_fit = torch.cat([elite_fit, fit_rest], dim=0)

        # Re-sort by fitness (best first).
        final_idx = self._sort_indices_best_first(survivor_fit)
        return survivors[final_idx], survivor_fit[final_idx]
480
+
481
+ # =============================================================================
482
+ # Comma Survival (μ,λ)
483
+ # =============================================================================
484
+
485
class CommaSurvival(Survival):
    """
    (μ,λ) Comma survival: Select only from offspring.

    Discards all parents and selects the best n_survive individuals
    from the offspring population only. This can help escape local
    optima but requires n_offspring >= n_survive.

    Elitism can still be enabled to preserve the best parents, but
    they are added separately rather than competing with offspring.
    The number of elites is clamped to the number of parents and to
    n_survive.

    Args:
        n_survive: Number of individuals to survive.
        elitism: If True, preserve best parents (recommended).
        n_elite: Number of elite parents to preserve.
        adaptive: If True, use soft selection.
        temperature: Temperature for soft selection.
        minimize: If True, lower fitness is better.

    Example:
        >>> survival = CommaSurvival(n_survive=50, elitism=True)
        >>> # Requires offspring.shape[0] >= 50 (or 49 if n_elite=1)
        >>> new_pop, new_fit = survival(parents, p_fit, offspring, o_fit)

    Note:
        With elitism, offspring count must be >= n_survive - n_elite.
        Without elitism, offspring count must be >= n_survive.
    """

    def _survive(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Select from offspring only (with optional elite from parents).

        Raises:
            ValueError: If there are fewer offspring than non-elite slots.
        """
        n_offspring = offspring.shape[0]
        # Clamp so that a large n_elite can neither over-index the parents
        # nor make the number of offspring slots negative.
        n_elite = min(self.n_elite, parents.shape[0], n_survive) if self.elitism else 0
        n_from_offspring = n_survive - n_elite

        if n_offspring < n_from_offspring:
            raise ValueError(
                f"CommaSurvival requires at least {n_from_offspring} offspring, "
                f"got {n_offspring}. Increase n_offsprings or reduce n_survive."
            )

        if n_from_offspring == 0:
            # Every slot is taken by elites; an empty slice keeps the
            # trailing dimensions so the concatenation below still works.
            survivors = offspring[:0]
            survivor_fit = offspring_fitness[:0]
        elif self.adaptive:
            # Straight-through Gumbel-Softmax top-k without replacement:
            #   - forward is discrete (hard survivors)
            #   - backward carries gradients through temperature
            hard_survivors, hard_fit = self._adaptive_topk_select(
                offspring,
                offspring_fitness,
                k=n_from_offspring,
            )

            # Re-sort by fitness (best first).
            final_idx = self._sort_indices_best_first(hard_fit)
            survivors = hard_survivors[final_idx]
            survivor_fit = hard_fit[final_idx]
        else:
            # Hard selection from offspring.
            sorted_idx = self._best_indices(offspring_fitness, n_from_offspring)
            survivors = offspring[sorted_idx]
            survivor_fit = offspring_fitness[sorted_idx]

        # Add elites from parents.
        if n_elite > 0:
            elite_idx = self._best_indices(parent_fitness, n_elite)
            elite_pop = parents[elite_idx].clone()
            elite_fit = parent_fitness[elite_idx].clone()

            survivors = torch.cat([elite_pop, survivors], dim=0)
            survivor_fit = torch.cat([elite_fit, survivor_fit], dim=0)

            # Re-sort (best first).
            final_idx = self._sort_indices_best_first(survivor_fit)
            survivors = survivors[final_idx]
            survivor_fit = survivor_fit[final_idx]

        return survivors, survivor_fit
570
+
571
+
572
+ # =============================================================================
573
+ # Replace Worst Survival (Steady-State)
574
+ # =============================================================================
575
+
576
class ReplaceWorstSurvival(Survival):
    """
    Steady-state survival: Replace worst parents with best offspring.

    Instead of replacing the entire population, only the worst
    individuals are replaced by better offspring. This creates
    higher selection pressure and can lead to faster convergence,
    but may also cause premature convergence.

    Each offspring competes with a parent. If the offspring is
    better, it replaces that parent. The pairing can be:
        - 'worst': Each offspring replaces worst remaining parent
        - 'random': Random parent-offspring pairing

    The returned population always has the same size as ``parents``;
    the ``n_survive`` argument is not used by this operator.

    Args:
        n_survive: Number of individuals in population (parents).
        elitism: If True, best parent is never replaced.
        n_elite: Number of protected parents (clamped to the parent count).
        replacement: Pairing strategy ('worst' or 'random').
        adaptive: If True, use soft replacement.
        temperature: Temperature for soft selection.
        minimize: If True, lower fitness is better.

    Example:
        >>> survival = ReplaceWorstSurvival(n_survive=100, elitism=True)
        >>> # With 10 offspring, up to 10 worst parents may be replaced
        >>> new_pop, new_fit = survival(parents, p_fit, offspring, o_fit)
    """

    def __init__(
        self,
        n_survive: Optional[int] = None,
        elitism: bool = True,
        n_elite: int = 1,
        replacement: str = 'worst',
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
    ) -> None:
        super().__init__(
            n_survive=n_survive,
            elitism=elitism,
            n_elite=n_elite,
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )

        if replacement not in ['worst', 'random']:
            raise ValueError(f"replacement must be 'worst' or 'random', got '{replacement}'")
        self.replacement = replacement

    def _survive(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Replace worst parents with better offspring.

        The i-th candidate parent (worst-first, or a random permutation) is
        paired with the i-th best offspring; the offspring replaces the
        parent only if it is strictly better.
        """
        n_parents = parents.shape[0]
        n_offspring = offspring.shape[0]
        device = parents.device

        # Start with parent population.
        survivors = parents.clone()
        survivor_fit = parent_fitness.clone()

        # Candidate parents for replacement, in pairing order.
        if self.replacement == 'worst':
            # Reverse of best-first order, i.e. worst first (for both
            # minimization and maximization).
            target_idx = self._sort_indices_best_first(parent_fitness).flip(0)
        else:  # random
            target_idx = torch.randperm(n_parents, device=device)

        # Protect elites: drop them from the candidate list, preserving order.
        n_elite = min(self.n_elite, n_parents) if self.elitism else 0
        if n_elite > 0:
            protected = torch.zeros(n_parents, dtype=torch.bool, device=device)
            protected[self._best_indices(parent_fitness, n_elite)] = True
            target_idx = target_idx[~protected[target_idx]]

        n_replace = min(target_idx.shape[0], n_offspring)
        if n_replace == 0:
            return survivors, survivor_fit

        # Pair candidate parents with offspring sorted best first.
        parent_idx = target_idx[:n_replace]
        off_idx = self._sort_indices_best_first(offspring_fitness)[:n_replace]

        # Every parent index appears at most once, so comparing against the
        # original parent fitness matches a sequential pairwise replacement.
        parent_vals = parent_fitness[parent_idx]
        off_vals = offspring_fitness[off_idx]
        if self.minimize:
            better = off_vals < parent_vals
        else:
            better = off_vals > parent_vals

        replaced = parent_idx[better]
        winners = off_idx[better]
        survivors[replaced] = offspring[winners].to(survivors.dtype)
        survivor_fit[replaced] = off_vals[better].to(survivor_fit.dtype)

        if self.adaptive:
            # Gradient carrier: zero-valued in forward, but lets gradients
            # reach the temperature through a softmax over all fitness values.
            all_fit = torch.cat([parent_fitness, offspring_fitness])
            scores = self._fitness_to_scores(all_fit)
            soft_probs = torch.softmax(scores / self.temperature, dim=0)
            soft_weighted_fit = torch.sum(soft_probs * all_fit)
            survivor_fit = survivor_fit + (soft_weighted_fit - soft_weighted_fit.detach())

        return survivors, survivor_fit
712
+
713
+
714
+ # =============================================================================
715
+ # Fitness Survival (Simple Truncation)
716
+ # =============================================================================
717
+
718
class FitnessSurvival(Survival):
    """
    Plain truncation survival over the merged parent + offspring pool.

    Keeps the ``n_survive`` fittest individuals of the concatenated
    population with no elite reservation. Behaves like MergeSurvival
    with elitism switched off, exposed as a lighter-weight option.

    Args:
        n_survive: Number of individuals to survive.
        adaptive: If True, use straight-through Gumbel top-k selection.
        temperature: Temperature for soft selection.
        minimize: If True, lower fitness is better.

    Example:
        >>> survival = FitnessSurvival(n_survive=100)
        >>> new_pop, new_fit = survival(parents, p_fit, offspring, o_fit)
    """

    def __init__(
        self,
        n_survive: Optional[int] = None,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
    ) -> None:
        # Truncation never reserves elite slots.
        super().__init__(
            n_survive=n_survive,
            elitism=False,
            n_elite=0,
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )

    def _survive(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """Keep the n_survive best individuals of parents + offspring."""
        pool = torch.cat([parents, offspring], dim=0)
        pool_fit = torch.cat([parent_fitness, offspring_fitness], dim=0)

        if not self.adaptive:
            keep = self._best_indices(pool_fit, n_survive)
            return pool[keep], pool_fit[keep]

        # Discrete survivors in forward; gradients reach the temperature
        # through the straight-through Gumbel-Softmax weights.
        chosen, chosen_fit = self._adaptive_topk_select(
            pool,
            pool_fit,
            k=n_survive,
        )
        order = self._sort_indices_best_first(chosen_fit)
        return chosen[order], chosen_fit[order]
787
+
788
+
789
+ # =============================================================================
790
+ # Age-Based Survival
791
+ # =============================================================================
792
+
793
class AgeSurvival(Survival):
    """
    Age-biased survival with fitness as the dominant criterion.

    Parents and offspring are merged and ranked by a composite score
    (lower is better)::

        composite = 0.7 * fitness_norm + 0.3 * age / (max_age + 1)

    where ``fitness_norm`` is the fitness min-max normalized to [0, 1],
    oriented so that 0 is the best individual. Among similarly fit
    individuals the younger one is preferred, which biases selection
    toward offspring and can help maintain diversity.

    Ages are NOT tracked across generations by this operator: on every
    call each parent is assigned age 1 and each offspring age 0, and
    ``max_age`` only scales the age term (it does not force replacement).

    Args:
        n_survive: Number of individuals to survive.
        max_age: Scale for the age term; larger values weaken the
            preference for offspring over parents.
        elitism: If True, the best ``n_elite`` parents (by raw fitness)
            always survive hard selection, each kept exactly once.
        n_elite: Number of elite parents to preserve (clamped to
            ``n_survive`` and to the number of parents).
        adaptive: If True, use differentiable soft selection
            (straight-through Gumbel-Softmax top-k). Elitism is not
            applied in this mode.
        temperature: Temperature for soft selection.
        learn_temperature: Forwarded to the ``Survival`` base class.
        minimize: If True, lower fitness is better.

    Example:
        >>> survival = AgeSurvival(n_survive=100, max_age=10)
        >>> new_pop, new_fit = survival(parents, p_fit, offspring, o_fit)
    """

    def __init__(
        self,
        n_survive: Optional[int] = None,
        max_age: int = 10,
        elitism: bool = True,
        n_elite: int = 1,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
    ) -> None:
        super().__init__(
            n_survive=n_survive,
            elitism=elitism,
            n_elite=n_elite,
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )
        self.max_age = max_age

    def _survive(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Select ``n_survive`` individuals using the fitness/age composite.

        Returns:
            ``(survivors, survivor_fitness)`` re-sorted best-first by
            fitness.
        """
        n_parents = parents.shape[0]
        n_offspring = offspring.shape[0]

        # Parents occupy the first n_parents rows of the combined tensors,
        # so a parent index is also a valid combined index.
        combined_pop = torch.cat([parents, offspring], dim=0)
        combined_fit = torch.cat([parent_fitness, offspring_fitness], dim=0)

        # Proxy ages: parents have survived at least one generation (1),
        # offspring are new (0). Real per-individual ages are not tracked.
        ages = torch.cat([
            torch.ones(n_parents, device=parents.device),
            torch.zeros(n_offspring, device=offspring.device),
        ])

        # Orient fitness so that lower is always better.
        fit_score = combined_fit if self.minimize else -combined_fit

        # Min-max normalize fitness to [0, 1] so it is commensurate with age.
        fit_min, fit_max = fit_score.min(), fit_score.max()
        if fit_max > fit_min:
            fit_norm = (fit_score - fit_min) / (fit_max - fit_min)
        else:
            fit_norm = torch.zeros_like(fit_score)

        # Composite: 70% fitness, 30% age (normalized). Lower = better.
        age_norm = ages / (self.max_age + 1)
        composite = 0.7 * fit_norm + 0.3 * age_norm

        if self.adaptive:
            # Soft selection: negate so that higher score = better.
            scores = -composite

            # Straight-through Gumbel-Softmax top-k without replacement
            W = self._gumbel_topk(scores, k=n_survive, dim=0)  # [n_survive, n_total]
            survivors = W @ combined_pop
            survivor_fit = W @ combined_fit

            # Re-sort by actual fitness (best first) for consistent output ordering
            final_idx = self._sort_indices_best_first(survivor_fit)
            return survivors[final_idx], survivor_fit[final_idx]

        # Hard selection. With elitism, the best parents are forced into the
        # top-k by giving them a score strictly below every composite value.
        # Each elite is then selected exactly once (no duplicates), and the
        # remaining slots go to the best non-elite composite scores.
        selection_score = composite
        if self.elitism and self.n_elite > 0:
            n_elite = min(self.n_elite, n_survive, n_parents)
            if n_elite > 0:
                elite_idx = self._best_indices(parent_fitness, n_elite)
                selection_score = composite.clone()
                selection_score[elite_idx] = composite.min() - 1.0

        sorted_idx = TopKSelection.best_indices(selection_score, k=n_survive, minimize=True)
        survivors = combined_pop[sorted_idx]
        survivor_fit = combined_fit[sorted_idx]

        # Re-sort by actual fitness (best first) for consistent output ordering
        final_idx = self._sort_indices_best_first(survivor_fit)
        return survivors[final_idx], survivor_fit[final_idx]
919
+
920
+
921
+ # =============================================================================
922
+ # Utility Functions
923
+ # =============================================================================
924
+
925
def get_survival(
    strategy: str,
    n_survive: Optional[int] = None,
    elitism: bool = True,
    n_elite: int = 1,
    adaptive: bool = False,
    **kwargs,
) -> Survival:
    """
    Build a survival operator from its strategy name.

    The name is matched case-insensitively after stripping surrounding
    whitespace.

    Args:
        strategy: Name (or alias) of the survival strategy:
            - 'merge', 'plus', '(mu+lambda)', 'mu+lambda': MergeSurvival
            - 'comma', '(mu,lambda)', 'mu,lambda': CommaSurvival
            - 'replace_worst', 'steady_state', 'steady-state':
              ReplaceWorstSurvival
            - 'fitness', 'truncation': FitnessSurvival
            - 'age', 'age_based': AgeSurvival
        n_survive: Number of individuals to survive.
        elitism: Whether to preserve best individuals (not forwarded to
            FitnessSurvival).
        n_elite: Number of elite individuals (not forwarded to
            FitnessSurvival).
        adaptive: Enable gradient flow.
        **kwargs: Extra keyword arguments forwarded unchanged to the
            selected operator's constructor.

    Returns:
        Configured Survival operator.

    Raises:
        ValueError: If ``strategy`` matches none of the names above.

    Example:
        >>> survival = get_survival('plus', n_survive=100, elitism=True)
        >>> survival = get_survival('comma', n_survive=50, adaptive=True)
    """
    key = strategy.lower().strip()

    # Every accepted spelling, grouped under one canonical name.
    alias_groups = {
        'merge': ('merge', 'plus', '(mu+lambda)', 'mu+lambda'),
        'comma': ('comma', '(mu,lambda)', 'mu,lambda'),
        'replace_worst': ('replace_worst', 'steady_state', 'steady-state'),
        'fitness': ('fitness', 'truncation'),
        'age': ('age', 'age_based'),
    }
    canonical = next(
        (name for name, aliases in alias_groups.items() if key in aliases),
        None,
    )
    if canonical is None:
        raise ValueError(
            f"Unknown survival strategy: '{key}'. "
            "Options: merge/plus, comma, replace_worst, fitness, age"
        )

    # Operator classes are resolved lazily, only once the name is known valid.
    resolvers = {
        'merge': lambda: MergeSurvival,
        'comma': lambda: CommaSurvival,
        'replace_worst': lambda: ReplaceWorstSurvival,
        'fitness': lambda: FitnessSurvival,
        'age': lambda: AgeSurvival,
    }
    operator_cls = resolvers[canonical]()

    options = {'n_survive': n_survive}
    if canonical != 'fitness':
        # FitnessSurvival is constructed without elitism settings.
        options['elitism'] = elitism
        options['n_elite'] = n_elite
    options['adaptive'] = adaptive

    return operator_cls(**options, **kwargs)