evograd-diff 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. evograd/__init__.py +67 -0
  2. evograd/algorithms/__init__.py +138 -0
  3. evograd/algorithms/cmaes.py +1365 -0
  4. evograd/algorithms/de.py +895 -0
  5. evograd/algorithms/ga.py +532 -0
  6. evograd/algorithms/pso.py +648 -0
  7. evograd/algorithms/shade.py +1165 -0
  8. evograd/benchmarks/functions/__init__.py +229 -0
  9. evograd/benchmarks/functions/base.py +217 -0
  10. evograd/benchmarks/functions/cec2017/__init__.py +250 -0
  11. evograd/benchmarks/functions/cec2017/basic.py +413 -0
  12. evograd/benchmarks/functions/cec2017/composition.py +580 -0
  13. evograd/benchmarks/functions/cec2017/data.pkl +0 -0
  14. evograd/benchmarks/functions/cec2017/data.py +350 -0
  15. evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
  16. evograd/benchmarks/functions/cec2017/simple.py +326 -0
  17. evograd/benchmarks/functions/classical.py +649 -0
  18. evograd/benchmarks/functions/smoothed_funnel.py +476 -0
  19. evograd/benchmarks/functions/transforms.py +463 -0
  20. evograd/benchmarks/run_benchmark_functions.py +1208 -0
  21. evograd/core/__init__.py +73 -0
  22. evograd/core/algorithm.py +778 -0
  23. evograd/core/maximize.py +269 -0
  24. evograd/core/minimize.py +740 -0
  25. evograd/core/problem.py +444 -0
  26. evograd/core/result.py +571 -0
  27. evograd/core/termination.py +602 -0
  28. evograd/operators/__init__.py +178 -0
  29. evograd/operators/crossover.py +1117 -0
  30. evograd/operators/mutation.py +1098 -0
  31. evograd/operators/relaxations.py +175 -0
  32. evograd/operators/repair.py +601 -0
  33. evograd/operators/sampling.py +577 -0
  34. evograd/operators/selection.py +981 -0
  35. evograd/operators/survival.py +1000 -0
  36. evograd/tests/__init__.py +11 -0
  37. evograd/tests/run_all.py +78 -0
  38. evograd/tests/test_core.py +528 -0
  39. evograd/tests/test_ga.py +572 -0
  40. evograd/tests/test_operators.py +662 -0
  41. evograd/tests/test_per_individual.py +326 -0
  42. evograd/tests/test_utils.py +328 -0
  43. evograd/utils/__init__.py +97 -0
  44. evograd/utils/callbacks.py +926 -0
  45. evograd/utils/device.py +502 -0
  46. evograd/utils/duplicates.py +421 -0
  47. evograd_diff-0.1.0.dist-info/METADATA +439 -0
  48. evograd_diff-0.1.0.dist-info/RECORD +50 -0
  49. evograd_diff-0.1.0.dist-info/WHEEL +4 -0
  50. evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,981 @@
1
+ """
2
+ Selection operators for parent selection.
3
+
4
+ This module provides strategies for selecting parents from the
5
+ population for recombination. All selectors support both classical
6
+ (hard) and differentiable (adaptive, Gumbel-Softmax) modes.
7
+
8
+ Available selectors:
9
+ - TournamentSelection: Tournament-based selection
10
+ - RouletteSelection: Fitness-proportionate selection
11
+ - RankSelection: Rank-based selection
12
+ - RandomSelection: Uniform random selection
13
+ - TruncationSelection: Select *from* the top fraction (samples within truncated set)
14
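+ - TopKSelection: Deterministic top-k WITHOUT replacement (hard + differentiable)
+ - StochasticUniversalSampling: Fitness-proportionate selection with evenly spaced pointers (lower variance)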
15
+
16
+ Differentiable Mode:
17
+ When `adaptive=True`, selection uses Gumbel-Softmax
18
+ relaxation with straight-through estimator, allowing gradients
19
+ to flow through the selection process.
20
+
21
+ Example:
22
+ >>> from evograd.operators import TournamentSelection
23
+ >>>
24
+ >>> # Classical mode
25
+ >>> selector = TournamentSelection(tournament_size=3)
26
+ >>> parents = selector(population, fitness, n_select=50)
27
+ >>>
28
+ >>> # Differentiable mode
29
+ >>> selector = TournamentSelection(
30
+ ... tournament_size=3,
31
+ ... adaptive=True,
32
+ ... temperature=1.0,
33
+ ... )
34
+ >>> parents = selector(population, fitness, n_select=50)
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ from abc import ABC, abstractmethod
40
+ from typing import Optional, Tuple, Union
41
+
42
+ import math
43
+ import torch
44
+ import torch.nn as nn
45
+ from torch import Tensor
46
+
47
+ from evograd.operators.relaxations import gumbel_softmax
48
+
49
# Public names exported by this module (order preserved).
__all__ = [
    # Abstract base class
    "Selection",
    # Fitness-driven stochastic selectors
    "TournamentSelection",
    "RouletteSelection",
    "RankSelection",
    # Fitness-agnostic baseline
    "RandomSelection",
    # Elitist selectors
    "TopKSelection",
    "TruncationSelection",
    # Low-variance fitness-proportionate sampling
    "StochasticUniversalSampling",
]
59
+
60
+
61
+ # =============================================================================
62
+ # Base Selection Class
63
+ # =============================================================================
64
+
65
class Selection(nn.Module, ABC):
    """
    Abstract base class for selection operators.

    Subclasses must implement:
        - _select(): Perform selection and return ``(selected, indices)``
          (returning only ``selected`` is also tolerated by ``forward``).

    Args:
        adaptive: If True, use Gumbel-Softmax for soft selection.
        temperature: Temperature for Gumbel-Softmax (lower = harder).
            Must be strictly positive, because it is stored as a logarithm.
        learn_temperature: If True *and* ``adaptive`` is True, the temperature
            is a learnable ``nn.Parameter``; otherwise it is a registered
            buffer.
        minimize: If True, lower fitness is better (default).

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(
        self,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
    ) -> None:
        super().__init__()

        # Range enforced by _clamp_temperature() at the start of every forward().
        self._MIN_TEMPERATURE = 0.05
        self._MAX_TEMPERATURE = 10.0

        self.adaptive = adaptive
        self.minimize = minimize

        # Store log(T) so that gradient updates can never drive T to <= 0.
        # Validation matters here: log of a non-positive value would silently
        # become -inf / NaN and poison every later softmax.
        log_temperature = torch.tensor(self._check_temperature(temperature)).log()
        if learn_temperature and adaptive:
            self._log_temperature = nn.Parameter(log_temperature)
        else:
            self.register_buffer("_log_temperature", log_temperature)

    @staticmethod
    def _check_temperature(value: Union[float, Tensor]) -> float:
        """Return ``value`` as a float, rejecting non-positive (or NaN) input."""
        value = float(value)
        if not value > 0.0:  # written this way so NaN is rejected as well
            raise ValueError(f"temperature must be > 0, got {value}")
        return value

    @property
    def temperature(self) -> Tensor:
        """Current temperature value (``exp`` of the stored log-temperature)."""
        return self._log_temperature.exp()

    @temperature.setter
    def temperature(self, value: Union[float, Tensor]) -> None:
        """
        Set the temperature in place (no gradient tracking).

        Raises:
            ValueError: If ``value`` is not strictly positive.
        """
        value = self._check_temperature(value)
        with torch.no_grad():
            self._log_temperature.fill_(math.log(value))

    def _fitness_to_scores(self, fitness: Tensor) -> Tensor:
        """
        Convert fitness to selection scores (higher = better).

        Args:
            fitness: Raw fitness values, shape [n_pop] or [n_pop, 1].

        Returns:
            Selection scores [n_pop] where higher is better.
        """
        # Accept column-vector fitness; downstream indexing/argmax/topk all
        # operate on the last dimension and expect a flat [n_pop] vector.
        if fitness.dim() == 2 and fitness.shape[-1] == 1:
            fitness = fitness.squeeze(-1)
        # Negate for minimisation so that lower fitness maps to a higher score.
        return -fitness if self.minimize else fitness

    def _clamp_temperature(self) -> None:
        """Clamp the temperature into [_MIN_TEMPERATURE, _MAX_TEMPERATURE] in place."""
        # Clamping is done in log space for numerical stability.
        if hasattr(self, "_log_temperature") and self._log_temperature is not None:
            with torch.no_grad():
                self._log_temperature.clamp_(
                    math.log(self._MIN_TEMPERATURE),
                    math.log(self._MAX_TEMPERATURE),
                )

    @abstractmethod
    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Perform selection.

        Args:
            population: Population tensor [n_pop, n_var].
            fitness: Fitness values [n_pop].
            n_select: Number of individuals to select.

        Returns:
            Tuple ``(selected, indices)`` with ``selected`` of shape
            [n_select, n_var], or ``selected`` alone.
        """

    def forward(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: Optional[int] = None,
        return_indices: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Select parents from population.

        Args:
            population: Population tensor [n_pop, n_var].
            fitness: Fitness values [n_pop].
            n_select: Number to select (default: population size).
            return_indices: If True, return whatever ``_select`` produced
                (the ``(selected, indices)`` tuple for all built-in selectors).

        Returns:
            Selected individuals [n_select, n_var], or the tuple
            ``(selected, indices)`` if ``return_indices=True``.
        """
        if n_select is None:
            n_select = population.shape[0]

        self._clamp_temperature()
        result = self._select(population, fitness, n_select)

        if return_indices:
            return result
        if isinstance(result, tuple):
            return result[0]
        return result

    # Note: Do NOT override __call__. nn.Module.__call__ dispatches to
    # forward() and fires registered hooks (forward_pre_hooks, forward_hooks,
    # and the autograd profiler). Overriding __call__ would bypass all of these.
200
+
201
+
202
+ # =============================================================================
203
+ # Tournament Selection
204
+ # =============================================================================
205
+
206
class TournamentSelection(Selection):
    """
    Tournament selection.

    For each selection, randomly pick `tournament_size` individuals
    and select the best one. This is the most common selection
    operator for genetic algorithms.

    In differentiable mode, tournament winners are selected using
    Gumbel-Softmax over the tournament participants.

    Args:
        tournament_size: Number of individuals per tournament (>= 2).
        adaptive: If True, use Gumbel-Softmax selection.
        temperature: Temperature for Gumbel-Softmax.
        learn_temperature: If True, temperature is learnable.
        minimize: If True, lower fitness is better.
        replacement: If True, the same individual may appear several times
            in one tournament. If False, participants of a tournament are
            distinct; tournaments are then capped at ``n_pop`` participants
            when ``tournament_size > n_pop``.

    Raises:
        ValueError: If ``tournament_size < 2``.

    Example:
        >>> selector = TournamentSelection(tournament_size=3)
        >>> parents = selector(population, fitness, n_select=50)
    """

    def __init__(
        self,
        tournament_size: int = 3,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
        replacement: bool = True,
    ) -> None:
        super().__init__(
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )

        if tournament_size < 2:
            raise ValueError(f"tournament_size must be >= 2, got {tournament_size}")

        self.tournament_size = tournament_size
        self.replacement = replacement

    def _sample_tournaments(
        self,
        n_pop: int,
        n_select: int,
        device: torch.device,
    ) -> Tensor:
        """Return participant indices for every tournament, shape [n_select, k]."""
        if self.replacement:
            return torch.randint(
                0, n_pop, (n_select, self.tournament_size), device=device
            )

        # Without replacement: the positions of the k largest of n_pop i.i.d.
        # uniform keys form a uniformly random k-subset of the population, so a
        # single batched topk replaces a Python loop of n_select randperm calls
        # (and also handles n_select == 0). Cap k at n_pop, matching what
        # slicing a randperm of length n_pop would yield.
        k = min(self.tournament_size, n_pop)
        keys = torch.rand(n_select, n_pop, device=device)
        return keys.topk(k, dim=-1).indices

    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Tuple[Tensor, Tensor]:
        """Run ``n_select`` tournaments and return ``(winners, winner_indices)``."""
        scores = self._fitness_to_scores(fitness)

        # Participants of every tournament: [n_select, k]
        tournament_idx = self._sample_tournaments(
            population.shape[0], n_select, population.device
        )
        tournament_scores = scores[tournament_idx]  # [n_select, k]

        if self.adaptive:
            # Gumbel-Softmax over the participants, using scores as logits.
            weights = gumbel_softmax(
                tournament_scores,
                temperature=self.temperature,
                dim=-1,
            )
            tournament_pop = population[tournament_idx]  # [n_select, k, n_var]
            # Weighted combination of participants (per-tournament weights).
            selected = torch.einsum("nk,nkd->nd", weights, tournament_pop)
            # Report the participant carrying the largest weight as the winner.
            relative_idx = weights.argmax(dim=-1)  # [n_select]
            indices = tournament_idx.gather(1, relative_idx.unsqueeze(-1)).squeeze(-1)
        else:
            # Hard selection: best participant of each tournament.
            relative_idx = tournament_scores.argmax(dim=-1)  # [n_select]
            indices = tournament_idx.gather(1, relative_idx.unsqueeze(-1)).squeeze(-1)
            selected = population[indices]

        return selected, indices

    def __repr__(self) -> str:
        return (
            f"TournamentSelection("
            f"tournament_size={self.tournament_size}, "
            f"adaptive={self.adaptive}, "
            f"temperature={self.temperature.item():.3f})"
        )
313
+
314
+
315
+ # =============================================================================
316
+ # Roulette (Fitness Proportionate) Selection
317
+ # =============================================================================
318
+
319
class RouletteSelection(Selection):
    """
    Roulette-wheel (fitness-proportionate) selection.

    Each individual occupies a slice of the wheel whose size grows with its
    score, so better individuals are drawn more often. Scores are shifted so
    that the worst individual gets a slice of size ``eps``.

    In differentiable mode, the log of the shifted scores is used as logits
    for a Gumbel-Softmax draw over the whole population, once per selection.

    Args:
        adaptive: If True, use Gumbel-Softmax selection.
        temperature: Temperature for Gumbel-Softmax.
        learn_temperature: If True, temperature is learnable.
        minimize: If True, lower fitness is better.
        eps: Small constant to avoid division by zero / log of zero.

    Example:
        >>> selector = RouletteSelection()
        >>> parents = selector(population, fitness, n_select=50)

    Note:
        For minimisation problems, fitness is transformed to
        ensure positive selection probabilities.
    """

    def __init__(
        self,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
        eps: float = 1e-10,
    ) -> None:
        super().__init__(
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )
        self.eps = eps

    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Tuple[Tensor, Tensor]:
        """Spin the wheel ``n_select`` times; return ``(selected, indices)``."""
        score = self._fitness_to_scores(fitness)
        # Slice sizes of the wheel: strictly positive, worst individual at eps.
        wheel = score - score.min() + self.eps

        if not self.adaptive:
            # Classical draw, with replacement, proportional to slice size.
            probs = wheel / wheel.sum()
            picks = torch.multinomial(probs, n_select, replacement=True)
            return population[picks], picks

        # Differentiable draw: one Gumbel-Softmax row per selection.
        log_wheel = torch.log(wheel + self.eps)
        weights = gumbel_softmax(
            log_wheel.unsqueeze(0).expand(n_select, -1),
            temperature=self.temperature,
            dim=-1,
        )
        mixed = weights @ population  # [n_select, n_var]
        return mixed, weights.argmax(dim=-1)

    def __repr__(self) -> str:
        return (
            f"RouletteSelection("
            f"adaptive={self.adaptive}, "
            f"temperature={self.temperature.item():.3f})"
        )
406
+
407
+
408
+ # =============================================================================
409
+ # Rank Selection
410
+ # =============================================================================
411
+
412
class RankSelection(Selection):
    """
    Rank-based selection.

    Selection probability is based on rank rather than raw fitness.
    This reduces selection pressure compared to fitness-proportionate
    selection and is more robust to fitness scaling.

    Two ranking schemes are available:
        - 'linear': Probability decreases linearly from best to worst rank
        - 'exponential': Probability decays exponentially with rank

    Args:
        scheme: Ranking scheme ('linear' or 'exponential').
        selection_pressure: Controls selection intensity.
            For 'linear': must be in [1.0, 2.0], higher = more pressure.
            For 'exponential': decay factor, higher = more pressure.
        adaptive: If True, use Gumbel-Softmax selection.
        temperature: Temperature for Gumbel-Softmax.
        learn_temperature: If True, temperature is learnable.
        minimize: If True, lower fitness is better.

    Raises:
        ValueError: If ``scheme`` is unknown, or if ``scheme='linear'`` and
            ``selection_pressure`` is outside [1.0, 2.0].

    Example:
        >>> selector = RankSelection(scheme='linear', selection_pressure=1.5)
        >>> parents = selector(population, fitness, n_select=50)
    """

    def __init__(
        self,
        scheme: str = "linear",
        selection_pressure: float = 1.5,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
    ) -> None:
        super().__init__(
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )

        if scheme not in ("linear", "exponential"):
            raise ValueError(f"scheme must be 'linear' or 'exponential', got '{scheme}'")

        # Outside [1, 2] the linear formula either yields negative
        # probabilities (s > 2, invalid for multinomial) or favours worse
        # ranks (s < 1); reject early instead of failing during selection.
        if scheme == "linear" and not 1.0 <= selection_pressure <= 2.0:
            raise ValueError(
                "selection_pressure must be in [1.0, 2.0] for linear ranking, "
                f"got {selection_pressure}"
            )

        self.scheme = scheme
        self.selection_pressure = selection_pressure

    def _compute_rank_probabilities(
        self,
        n_pop: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tensor:
        """
        Compute normalised selection probabilities indexed by rank.

        Returns:
            Tensor [n_pop]; entry 0 is the best rank, entry n_pop-1 the worst.
        """
        if n_pop == 1:
            # A single individual is selected with certainty. The linear
            # formula would otherwise compute 0 / (n * (n - 1)) = 0 / 0 = NaN.
            return torch.ones(1, device=device, dtype=dtype)

        # Ranks from 1 (best) to n_pop (worst)
        ranks = torch.arange(1, n_pop + 1, device=device, dtype=dtype)

        if self.scheme == "linear":
            # Linear ranking: P(rank=i) = (2-s)/n + 2*(s-1)*(n-i)/(n*(n-1))
            # where s is the selection pressure in [1, 2].
            s = self.selection_pressure
            n = float(n_pop)
            probs = (2 - s) / n + 2 * (s - 1) * (n - ranks) / (n * (n - 1))
        else:  # exponential
            # Exponential ranking: P(rank=i) ∝ exp(-c * (i-1))
            c = self.selection_pressure
            probs = torch.exp(-c * (ranks - 1))

        # Normalise
        return probs / probs.sum()

    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Tuple[Tensor, Tensor]:
        """Sample ``n_select`` individuals by rank; return ``(selected, indices)``."""
        n_pop = population.shape[0]
        device = population.device
        # Probabilities need a floating dtype; integer or bool populations
        # (e.g. binary encodings) must not dictate the dtype of the rank tensor.
        prob_dtype = (
            population.dtype
            if population.is_floating_point()
            else torch.get_default_dtype()
        )

        scores = self._fitness_to_scores(fitness)

        # Sort by score (descending, so best first)
        sorted_indices = torch.argsort(scores, descending=True)

        # Compute rank-based probabilities
        rank_probs = self._compute_rank_probabilities(n_pop, device, prob_dtype)

        if self.adaptive:
            # Use log-probabilities as logits
            logits = torch.log(rank_probs + 1e-10)
            logits_expanded = logits.unsqueeze(0).expand(n_select, -1)

            # Gumbel-Softmax selection in rank space
            weights = gumbel_softmax(
                logits_expanded,
                temperature=self.temperature,
                dim=-1,
            )

            # Weights index ranks, so combine the rank-ordered population.
            sorted_pop = population[sorted_indices]
            selected = torch.matmul(weights, sorted_pop)

            # Map the chosen rank back to the original population index.
            rank_indices = weights.argmax(dim=-1)
            indices = sorted_indices[rank_indices]
        else:
            # Classical sampling by rank
            rank_indices = torch.multinomial(rank_probs, n_select, replacement=True)
            indices = sorted_indices[rank_indices]
            selected = population[indices]

        return selected, indices

    def __repr__(self) -> str:
        return (
            f"RankSelection("
            f"scheme='{self.scheme}', "
            f"selection_pressure={self.selection_pressure}, "
            f"adaptive={self.adaptive})"
        )
537
+
538
+
539
+ # =============================================================================
540
+ # Random Selection
541
+ # =============================================================================
542
+
543
class RandomSelection(Selection):
    """
    Uniform random selection (baseline).

    Selects individuals uniformly at random, ignoring fitness.
    Useful as a baseline or for algorithms that don't use
    fitness-based selection.

    Args:
        replacement: If True, allow selecting the same individual multiple
            times. Only applies to the classical (non-adaptive) mode; in
            adaptive mode every selection is an independent Gumbel-Softmax draw.
        adaptive: If True, use Gumbel-Softmax (uniform logits).
        temperature: Temperature for Gumbel-Softmax (not learnable here).

    Raises:
        ValueError: (from ``_select``) if ``replacement=False`` and more
            individuals are requested than exist.

    Example:
        >>> selector = RandomSelection()
        >>> parents = selector(population, fitness, n_select=50)
    """

    def __init__(
        self,
        replacement: bool = True,
        adaptive: bool = False,
        temperature: float = 1.0,
    ) -> None:
        super().__init__(
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=False,
            minimize=True,
        )
        self.replacement = replacement

    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Tuple[Tensor, Tensor]:
        """Pick ``n_select`` individuals uniformly; ``fitness`` is ignored."""
        n_pop = population.shape[0]
        device = population.device

        if self.adaptive:
            # Uniform logits, created in the population's floating dtype so
            # the matmul below does not fail on e.g. float64 populations.
            logits_dtype = (
                population.dtype
                if population.is_floating_point()
                else torch.get_default_dtype()
            )
            logits = torch.zeros(n_select, n_pop, device=device, dtype=logits_dtype)
            weights = gumbel_softmax(
                logits,
                temperature=self.temperature,
                dim=-1,
            )

            selected = torch.matmul(weights, population)
            indices = weights.argmax(dim=-1)
        else:
            if self.replacement:
                indices = torch.randint(0, n_pop, (n_select,), device=device)
            else:
                if n_select > n_pop:
                    raise ValueError(
                        f"Cannot select {n_select} from {n_pop} without replacement"
                    )
                indices = torch.randperm(n_pop, device=device)[:n_select]

            selected = population[indices]

        return selected, indices

    def __repr__(self) -> str:
        return f"RandomSelection(replacement={self.replacement})"
611
+
612
+
613
+
614
+ # =============================================================================
615
+ # Top-K Selection (Deterministic / Without Replacement)
616
+ # =============================================================================
617
+
618
+
619
class TopKSelection(Selection):
    """
    Top-k (elitist) selection without replacement.

    Selects the best `k` individuals according to fitness. This is the
    canonical "take the best" operator. Unlike most parent selectors,
    this returns a *subset* of the population (no duplicates).

    In differentiable mode, uses a sequential (without replacement)
    Gumbel-Softmax relaxation, masking items once they are chosen.

    Args:
        adaptive: If True, use differentiable top-k selection.
        temperature: Temperature for Gumbel-Softmax.
        learn_temperature: If True, temperature is learnable.
        minimize: If True, lower fitness is better.

    Example:
        >>> # Keep best 10 individuals
        >>> selector = TopKSelection()
        >>> best10, idx = selector(population, fitness, n_select=10, return_indices=True)
    """

    @staticmethod
    def best_indices(
        fitness: Tensor,
        k: int,
        minimize: bool = True,
    ) -> Tensor:
        """Return indices of the best-k individuals (no replacement), best first."""
        scores = -fitness if minimize else fitness
        return torch.topk(scores, k=k, largest=True).indices

    @staticmethod
    def worst_indices(
        fitness: Tensor,
        k: int,
        minimize: bool = True,
    ) -> Tensor:
        """Return indices of the worst-k individuals (no replacement), worst first."""
        scores = -fitness if minimize else fitness
        return torch.topk(scores, k=k, largest=False).indices

    @staticmethod
    def sort_indices_best_first(
        fitness: Tensor,
        minimize: bool = True,
    ) -> Tensor:
        """Return indices sorting the population best->worst."""
        scores = -fitness if minimize else fitness
        return torch.argsort(scores, descending=True)

    @staticmethod
    def gumbel_topk(
        logits: Tensor,
        k: int,
        temperature: Union[float, Tensor] = 1.0,
        dim: int = -1,
        eps: float = 1e-10,
    ) -> Tuple[Tensor, Tensor]:
        """
        Differentiable top-k without replacement using sequential Gumbel-Softmax.

        Each of the ``k`` rounds draws one item with ``gumbel_softmax`` and then
        masks the drawn item so it cannot be chosen again.

        Args:
            logits: Logits over items (items along ``dim``).
            k: Number of items to select.
            temperature: Gumbel-Softmax temperature.
            dim: Dimension along which to select.
            eps: Numerical stability constant (forwarded to ``gumbel_softmax``).

        Returns:
            weights: Per-round weights stacked at dim -2; for ``dim=-1`` this is
                [..., k, n_items].
            indices: Selected indices, stacked along the last dim ([..., k]).

        Raises:
            ValueError: If ``k <= 0`` or ``k`` exceeds the number of items.
        """
        if k <= 0:
            raise ValueError(f"k must be > 0, got {k}")

        n_items = logits.shape[dim]
        if k > n_items:
            raise ValueError(f"Cannot select k={k} from n_items={n_items}")

        # Fill value for already-chosen items: the most negative finite value of
        # the dtype. A fixed constant such as -1e9 is *larger* than the scores of
        # problems whose |fitness| exceeds 1e9, which would make masked items the
        # preferred choice and produce duplicates.
        if logits.is_floating_point():
            mask_value = torch.finfo(logits.dtype).min
        else:
            mask_value = -1e9

        # Out-of-place updates below, so gradients still reach `logits`.
        masked_logits = logits.clone()

        weights_list = []
        indices_list = []

        for _ in range(k):
            w = gumbel_softmax(
                masked_logits,
                temperature=temperature,
                dim=dim,
                eps=eps,
            )
            idx = w.argmax(dim=dim)  # logits.shape with `dim` removed

            weights_list.append(w)
            indices_list.append(idx)

            # Mask the chosen item. movedim keeps the remaining axes in their
            # original order, which is exactly the layout of `idx`, so after
            # moving the item axis last the scatter index lines up for any dim.
            moved = masked_logits.movedim(dim, -1)
            moved = moved.scatter(-1, idx.unsqueeze(-1), mask_value)
            masked_logits = moved.movedim(-1, dim)

        weights = torch.stack(weights_list, dim=-2)
        indices = torch.stack(indices_list, dim=-1)  # [..., k]
        return weights, indices

    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Select the ``n_select`` best individuals without replacement.

        Raises:
            ValueError: If ``n_select`` exceeds the population size.
        """
        n_pop = population.shape[0]

        if n_select > n_pop:
            raise ValueError(f"Cannot select {n_select} from {n_pop} without replacement")

        scores = self._fitness_to_scores(fitness)

        if self.adaptive:
            # Sequential Gumbel top-k in score space
            logits = scores.unsqueeze(0)  # [1, n_pop]
            weights, idx = self.gumbel_topk(
                logits,
                k=n_select,
                temperature=self.temperature,
                dim=-1,
            )

            # weights: [1, k, n_pop] -> k selected individuals as mixtures
            selected = torch.matmul(weights.squeeze(0), population)  # [k, n_var]
            indices = idx.squeeze(0)  # [k]
        else:
            # Hard deterministic top-k (best first)
            indices = torch.topk(scores, n_select, largest=True).indices
            selected = population[indices]

        return selected, indices

    def __repr__(self) -> str:
        return (
            f"TopKSelection("
            f"adaptive={self.adaptive}, "
            f"temperature={self.temperature.item():.3f})"
        )
789
+
790
+ # =============================================================================
791
+ # Truncation Selection
792
+ # =============================================================================
793
+
794
class TruncationSelection(Selection):
    """
    Truncation (elitist) selection.

    Keeps only the top ``truncation_ratio`` fraction of the population and
    samples (with replacement) from within that truncated set. This provides
    strong selection pressure.

    Args:
        truncation_ratio: Fraction of population to select from (0, 1].
            At least one individual is always kept.
        adaptive: If True, use Gumbel-Softmax over the truncated population.
        temperature: Temperature for soft selection.
        learn_temperature: If True, temperature is learnable.
        minimize: If True, lower fitness is better.

    Raises:
        ValueError: If ``truncation_ratio`` is not in (0, 1].

    Example:
        >>> # Select from top 20% of population
        >>> selector = TruncationSelection(truncation_ratio=0.2)
        >>> parents = selector(population, fitness, n_select=50)
    """

    def __init__(
        self,
        truncation_ratio: float = 0.5,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
    ) -> None:
        super().__init__(
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )

        if not 0 < truncation_ratio <= 1:
            raise ValueError(
                f"truncation_ratio must be in (0, 1], got {truncation_ratio}"
            )

        self.truncation_ratio = truncation_ratio

    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Tuple[Tensor, Tensor]:
        """Sample ``n_select`` individuals from the top fraction."""
        n_pop = population.shape[0]
        device = population.device

        scores = self._fitness_to_scores(fitness)

        # Size of the truncated set. The tiny tolerance absorbs binary
        # floating-point error, e.g. 100 * 0.29 == 28.999999999999996, which
        # plain int() would truncate to 28 instead of 29.
        n_truncated = min(
            n_pop, max(1, math.floor(n_pop * self.truncation_ratio + 1e-9))
        )

        # Indices of the top individuals (best first)
        _, top_indices = torch.topk(scores, n_truncated)

        if self.adaptive:
            # Soft selection within the truncated set, scores as logits
            top_scores = scores[top_indices]
            logits = top_scores.unsqueeze(0).expand(n_select, -1)

            weights = gumbel_softmax(
                logits,
                temperature=self.temperature,
                dim=-1,
            )

            top_pop = population[top_indices]
            selected = torch.matmul(weights, top_pop)

            relative_idx = weights.argmax(dim=-1)
            indices = top_indices[relative_idx]
        else:
            # Uniform random selection (with replacement) from the truncated set
            relative_idx = torch.randint(0, n_truncated, (n_select,), device=device)
            indices = top_indices[relative_idx]
            selected = population[indices]

        return selected, indices

    def __repr__(self) -> str:
        return (
            f"TruncationSelection("
            f"truncation_ratio={self.truncation_ratio}, "
            f"adaptive={self.adaptive})"
        )
883
+
884
+
885
+ # =============================================================================
886
+ # Stochastic Universal Sampling (SUS)
887
+ # =============================================================================
888
+
889
class StochasticUniversalSampling(Selection):
    """
    Stochastic Universal Sampling (SUS).

    Similar to roulette selection but uses ``n_select`` evenly spaced
    pointers on the wheel with a single random offset, reducing variance
    and ensuring a more uniform sampling of fit individuals.

    Args:
        adaptive: If True, fall back to an independent Gumbel-Softmax draw per
            selection (SUS itself is inherently discrete).
        temperature: Temperature for Gumbel-Softmax.
        learn_temperature: If True, temperature is learnable.
        minimize: If True, lower fitness is better.
        eps: Small constant for numerical stability.

    Example:
        >>> selector = StochasticUniversalSampling()
        >>> parents = selector(population, fitness, n_select=50)
    """

    def __init__(
        self,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
        eps: float = 1e-10,
    ) -> None:
        super().__init__(
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )
        self.eps = eps

    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Tuple[Tensor, Tensor]:
        """Place ``n_select`` equally spaced pointers on the fitness wheel."""
        n_pop = population.shape[0]
        device = population.device

        scores = self._fitness_to_scores(fitness)

        # Shift to strictly positive slice sizes (worst individual at eps)
        shifted_scores = scores - scores.min() + self.eps

        if self.adaptive:
            # Fall back to Gumbel-Softmax (SUS is inherently discrete)
            logits = torch.log(shifted_scores + self.eps)
            logits_expanded = logits.unsqueeze(0).expand(n_select, -1)

            weights = gumbel_softmax(
                logits_expanded,
                temperature=self.temperature,
                dim=-1,
            )
            selected = torch.matmul(weights, population)
            indices = weights.argmax(dim=-1)
        else:
            # Build the pointers in the dtype of the wheel itself, not the
            # population: torch.searchsorted expects boundaries and values of
            # the same dtype, and torch.rand cannot sample integer/bool dtypes
            # (e.g. binary-encoded populations).
            wheel_dtype = shifted_scores.dtype

            cumsum = torch.cumsum(shifted_scores, dim=0)
            total = shifted_scores.sum()

            # Distance between consecutive pointers
            pointer_distance = total / n_select

            # Single random offset in [0, pointer_distance)
            start = torch.rand(1, device=device, dtype=wheel_dtype) * pointer_distance

            pointers = start + pointer_distance * torch.arange(
                n_select, device=device, dtype=wheel_dtype
            )

            # Slot each pointer lands in; clamp guards against rounding past
            # the end of the cumulative sum.
            indices = torch.searchsorted(cumsum, pointers)
            indices = torch.clamp(indices, 0, n_pop - 1)

            selected = population[indices]

        return selected, indices

    def __repr__(self) -> str:
        return (
            f"StochasticUniversalSampling("
            f"adaptive={self.adaptive})"
        )