evograd-diff 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evograd/__init__.py +67 -0
- evograd/algorithms/__init__.py +138 -0
- evograd/algorithms/cmaes.py +1365 -0
- evograd/algorithms/de.py +895 -0
- evograd/algorithms/ga.py +532 -0
- evograd/algorithms/pso.py +648 -0
- evograd/algorithms/shade.py +1165 -0
- evograd/benchmarks/functions/__init__.py +229 -0
- evograd/benchmarks/functions/base.py +217 -0
- evograd/benchmarks/functions/cec2017/__init__.py +250 -0
- evograd/benchmarks/functions/cec2017/basic.py +413 -0
- evograd/benchmarks/functions/cec2017/composition.py +580 -0
- evograd/benchmarks/functions/cec2017/data.pkl +0 -0
- evograd/benchmarks/functions/cec2017/data.py +350 -0
- evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
- evograd/benchmarks/functions/cec2017/simple.py +326 -0
- evograd/benchmarks/functions/classical.py +649 -0
- evograd/benchmarks/functions/smoothed_funnel.py +476 -0
- evograd/benchmarks/functions/transforms.py +463 -0
- evograd/benchmarks/run_benchmark_functions.py +1208 -0
- evograd/core/__init__.py +73 -0
- evograd/core/algorithm.py +778 -0
- evograd/core/maximize.py +269 -0
- evograd/core/minimize.py +740 -0
- evograd/core/problem.py +444 -0
- evograd/core/result.py +571 -0
- evograd/core/termination.py +602 -0
- evograd/operators/__init__.py +178 -0
- evograd/operators/crossover.py +1117 -0
- evograd/operators/mutation.py +1098 -0
- evograd/operators/relaxations.py +175 -0
- evograd/operators/repair.py +601 -0
- evograd/operators/sampling.py +577 -0
- evograd/operators/selection.py +981 -0
- evograd/operators/survival.py +1000 -0
- evograd/tests/__init__.py +11 -0
- evograd/tests/run_all.py +78 -0
- evograd/tests/test_core.py +528 -0
- evograd/tests/test_ga.py +572 -0
- evograd/tests/test_operators.py +662 -0
- evograd/tests/test_per_individual.py +326 -0
- evograd/tests/test_utils.py +328 -0
- evograd/utils/__init__.py +97 -0
- evograd/utils/callbacks.py +926 -0
- evograd/utils/device.py +502 -0
- evograd/utils/duplicates.py +421 -0
- evograd_diff-0.1.0.dist-info/METADATA +439 -0
- evograd_diff-0.1.0.dist-info/RECORD +50 -0
- evograd_diff-0.1.0.dist-info/WHEEL +4 -0
- evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,1000 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Survival selection operators for generational replacement.
|
|
3
|
+
|
|
4
|
+
This module provides strategies for selecting which individuals
|
|
5
|
+
survive to the next generation. All operators support both classical
|
|
6
|
+
and differentiable (i.e., adaptive) modes.
|
|
7
|
+
|
|
8
|
+
Available survival strategies:
|
|
9
|
+
- MergeSurvival: (μ+λ) - Select best from parents + offspring
|
|
10
|
+
- CommaSurvival: (μ,λ) - Select only from offspring
|
|
11
|
+
- ReplaceWorstSurvival: Steady-state replacement of worst
|
|
12
|
+
- FitnessSurvival: Simple fitness-based truncation
|
|
13
|
+
- AgeSurvival: Age-based replacement with fitness tie-breaking
|
|
14
|
+
|
|
15
|
+
Elitism:
|
|
16
|
+
Most survival operators support elitism, which ensures that the
|
|
17
|
+
best n_elite individuals from the parent population are always
|
|
18
|
+
preserved in the next generation.
|
|
19
|
+
|
|
20
|
+
Differentiable Mode:
|
|
21
|
+
When `adaptive=True`, survival selection uses soft ranking
|
|
22
|
+
based on fitness scores with temperature-scaled softmax,
|
|
23
|
+
if and only if another operator has `adaptive=True`,
|
|
24
|
+
allowing gradients to flow through the selection process
|
|
25
|
+
via the temperature parameter using a straight-through estimator.
|
|
26
|
+
|
|
27
|
+
Example:
|
|
28
|
+
>>> from evograd.operators import MergeSurvival
|
|
29
|
+
>>>
|
|
30
|
+
>>> # Classical (μ+λ) survival
|
|
31
|
+
>>> survival = MergeSurvival(n_survive=100, elitism=True, n_elite=1)
|
|
32
|
+
>>> new_pop, new_fit = survival(
|
|
33
|
+
... parents, parent_fitness,
|
|
34
|
+
... offspring, offspring_fitness,
|
|
35
|
+
... )
|
|
36
|
+
>>>
|
|
37
|
+
>>> # Differentiable mode
|
|
38
|
+
>>> survival = MergeSurvival(
|
|
39
|
+
... n_survive=100,
|
|
40
|
+
... adaptive=True,
|
|
41
|
+
... temperature=1.0,
|
|
42
|
+
... )
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
from __future__ import annotations
|
|
46
|
+
|
|
47
|
+
from abc import ABC, abstractmethod
|
|
48
|
+
from typing import Optional, Tuple
|
|
49
|
+
|
|
50
|
+
import torch
|
|
51
|
+
import torch.nn as nn
|
|
52
|
+
from torch import Tensor
|
|
53
|
+
|
|
54
|
+
from evograd.operators.selection import TopKSelection
|
|
55
|
+
|
|
56
|
+
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "Survival",
    "MergeSurvival",
    "CommaSurvival",
    "ReplaceWorstSurvival",
    "FitnessSurvival",
    "AgeSurvival",
]
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# =============================================================================
|
|
67
|
+
# Base Survival Class
|
|
68
|
+
# =============================================================================
|
|
69
|
+
|
|
70
|
+
class Survival(nn.Module, ABC):
    """
    Abstract base class for survival selection operators.

    Survival selection determines which individuals from the combined
    parent and offspring populations survive to the next generation.

    Subclasses must implement:
        - _survive(): Perform survival selection

    Args:
        n_survive: Number of individuals to survive (population size).
            If None, ``forward`` falls back to the number of parents.
        elitism: If True, preserve best individuals from parents.
        n_elite: Number of elite individuals to preserve (forced to 0
            when ``elitism`` is False).
        adaptive: If True, use soft selection for gradients.
        temperature: Temperature for soft selection. Must be > 0.
        learn_temperature: If True and ``adaptive`` is True, the temperature
            is a learnable ``nn.Parameter``; otherwise it is a buffer.
        minimize: If True, lower fitness is better (default).

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(
        self,
        n_survive: Optional[int] = None,
        elitism: bool = True,
        n_elite: int = 1,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
    ) -> None:
        super().__init__()

        self.n_survive = n_survive
        self.elitism = elitism
        self.n_elite = n_elite if elitism else 0
        self.adaptive = adaptive
        self.minimize = minimize

        # The temperature is stored in log-space so that a learnable
        # temperature can never be driven to zero or below by an optimizer.
        log_temperature = torch.tensor(self._check_temperature(temperature)).log()
        if learn_temperature and adaptive:
            self._log_temperature = nn.Parameter(log_temperature)
        else:
            self.register_buffer("_log_temperature", log_temperature)

    @staticmethod
    def _check_temperature(value: float) -> float:
        """Return ``value`` as a float, rejecting non-positive or NaN temperatures."""
        value = float(value)
        # `not value > 0` also rejects NaN, whose log would poison every softmax.
        if not value > 0.0:
            raise ValueError(f"temperature must be > 0, got {value}")
        return value

    @property
    def temperature(self) -> Tensor:
        """Current temperature value (exp of the stored log-temperature)."""
        return self._log_temperature.exp()

    @temperature.setter
    def temperature(self, value: float) -> None:
        """Set temperature value (must be > 0)."""
        log_value = torch.tensor(self._check_temperature(value)).log()
        with torch.no_grad():
            # copy_ handles device/dtype differences, e.g. after .to("cuda").
            self._log_temperature.copy_(log_value)

    def _fitness_to_scores(self, fitness: Tensor) -> Tensor:
        """
        Convert fitness to selection scores (higher = better).

        Args:
            fitness: Raw fitness values [n].

        Returns:
            Selection scores [n] where higher is better.
        """
        return -fitness if self.minimize else fitness

    def _compute_soft_weights(self, fitness: Tensor) -> Tensor:
        """
        Compute soft selection weights using temperature-scaled softmax.

        Args:
            fitness: Fitness values [n].

        Returns:
            Soft selection weights [n] summing to 1.
        """
        scores = self._fitness_to_scores(fitness)
        return torch.softmax(scores / self.temperature, dim=0)

    # -------------------------------------------------------------------------
    # Shared Utilities
    # -------------------------------------------------------------------------

    def _best_indices(self, fitness: Tensor, k: int) -> Tensor:
        """
        Indices of the best k individuals according to `minimize`.

        Delegates to ``TopKSelection.best_indices`` with this operator's
        ``minimize`` flag, so it works for both minimization and maximization.
        """
        return TopKSelection.best_indices(fitness, k, minimize=self.minimize)

    def _worst_indices(self, fitness: Tensor, k: int) -> Tensor:
        """
        Indices of the worst k individuals according to `minimize`.

        Delegates to ``TopKSelection.worst_indices`` with this operator's
        ``minimize`` flag, so it works for both minimization and maximization.
        """
        return TopKSelection.worst_indices(fitness, k, minimize=self.minimize)

    def _sort_indices_best_first(self, fitness: Tensor) -> Tensor:
        """Return indices that sort individuals from best to worst."""
        return TopKSelection.sort_indices_best_first(fitness, minimize=self.minimize)

    def _gumbel_topk(
        self,
        logits: Tensor,
        k: int,
        dim: int = 0,
        eps: float = 1e-8,
    ) -> Tensor:
        """
        Differentiable top-k without replacement using sequential Gumbel-Softmax.

        Returns a stack of one-hot vectors (hard in forward, soft in backward)
        with shape [k, n] (assuming `dim=0` / logits is 1D).
        """
        weights, _ = TopKSelection.gumbel_topk(
            logits,
            k,
            temperature=self.temperature,
            dim=dim,
            eps=eps,
        )
        return weights

    def _adaptive_topk_select(
        self,
        population: Tensor,
        fitness: Tensor,
        k: int,
        mask_out: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Adaptive (differentiable) top-k selection with straight-through gradients.

        Args:
            population: Candidate individuals [n, n_var].
            fitness: Candidate fitness [n].
            k: Number to select.
            mask_out: Optional boolean mask [n] of candidates to exclude.
                Excluded candidates are removed from the pool before
                sampling, so they can never be selected regardless of how
                large the fitness magnitudes are.

        Returns:
            survivors: Selected individuals [k, n_var].
            survivor_fit: Selected fitness [k].

        Raises:
            ValueError: If k <= 0, k exceeds the (unmasked) candidate count,
                or mask_out has the wrong shape.
            TypeError: If mask_out is not a boolean tensor.
        """
        n = population.shape[0]
        if k <= 0:
            raise ValueError(f"k must be > 0, got {k}")
        if k > n:
            raise ValueError(f"Cannot select k={k} from n={n}")

        if mask_out is not None:
            if mask_out.dtype != torch.bool:
                raise TypeError("mask_out must be a boolean tensor")
            if mask_out.dim() != 1 or mask_out.shape[0] != n:
                raise ValueError(f"mask_out must have shape [{n}], got {tuple(mask_out.shape)}")
            # Drop excluded candidates instead of assigning them a fixed
            # "very negative" logit: a constant such as -1e9 is not low enough
            # when scores themselves are below -1e9, which would let masked
            # individuals (e.g. elites) be selected a second time.
            keep_idx = torch.nonzero(~mask_out, as_tuple=False).squeeze(1)
            n_available = keep_idx.shape[0]
            if k > n_available:
                raise ValueError(
                    f"Cannot select k={k} from n={n_available} unmasked candidates"
                )
            population = population[keep_idx]
            fitness = fitness[keep_idx]

        # Score-space logits (higher is better), shaped [1, n] for gumbel_topk.
        logits = self._fitness_to_scores(fitness).unsqueeze(0)

        # Sequential (without-replacement) ST Gumbel top-k.
        weights, _ = TopKSelection.gumbel_topk(
            logits,
            k=k,
            temperature=self.temperature,
            dim=-1,
        )

        # weights: [1, k, n] -> [k, n]
        weights = weights.squeeze(0)

        # Forward pass is hard (one-hot) due to straight-through Gumbel-Softmax,
        # backward pass carries gradients through the soft weights.
        selected = weights @ population  # [k, n_var]
        selected_fit = weights @ fitness  # [k]
        return selected, selected_fit

    def _survive(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Perform survival selection.

        Args:
            parents: Parent population [n_parents, n_var].
            parent_fitness: Parent fitness [n_parents].
            offspring: Offspring population [n_offspring, n_var].
            offspring_fitness: Offspring fitness [n_offspring].
            n_survive: Number of individuals to survive.

        Returns:
            Tuple of (survivors, survivor_fitness).

        Raises:
            NotImplementedError: Always in the base class; subclasses override it.
        """
        # Previously this silently returned None, which only surfaced later as
        # an opaque unpacking error in the caller.
        raise NotImplementedError(f"{type(self).__name__} must implement _survive()")

    def forward(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Apply survival selection.

        Args:
            parents: Parent population [n_parents, n_var].
            parent_fitness: Parent fitness [n_parents].
            offspring: Offspring population [n_offspring, n_var].
            offspring_fitness: Offspring fitness [n_offspring].
            n_survive: Number to survive (default: self.n_survive, or
                n_parents when that is None).

        Returns:
            Tuple of (survivors, survivor_fitness).
        """
        if n_survive is None:
            n_survive = self.n_survive if self.n_survive is not None else parents.shape[0]

        return self._survive(
            parents, parent_fitness,
            offspring, offspring_fitness,
            n_survive,
        )

    # Note: Do NOT override __call__. nn.Module.__call__ dispatches to
    # forward() and fires registered hooks (forward_pre_hooks, forward_hooks,
    # and the autograd profiler). Overriding __call__ would bypass all of these.
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
# =============================================================================
|
|
326
|
+
# Merge Survival (μ+λ)
|
|
327
|
+
# =============================================================================
|
|
328
|
+
|
|
329
|
+
class MergeSurvival(Survival):
    """
    (μ+λ) Merge survival: Select best from parents + offspring.

    Combines the parent and offspring populations, then selects
    the best n_survive individuals based on fitness. This is the
    most common survival strategy for evolutionary algorithms.

    With elitism enabled, the best n_elite parents are guaranteed
    to survive regardless of offspring fitness. The elites are removed
    from the candidate pool before the remaining slots are filled, so no
    individual appears twice among the survivors. The number of elites is
    clamped to the number of parents and to n_survive.

    Args:
        n_survive: Number of individuals to survive.
        elitism: If True, preserve best parents (default: True).
        n_elite: Number of elite parents to preserve.
        adaptive: If True, use soft weighted selection.
        temperature: Temperature for soft selection.
        minimize: If True, lower fitness is better.

    Example:
        >>> survival = MergeSurvival(n_survive=100, elitism=True, n_elite=2)
        >>> new_pop, new_fit = survival(parents, p_fit, offspring, o_fit)
    """

    def _survive(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """Select best from combined population."""
        if self.adaptive:
            return self._survive_adaptive(
                parents, parent_fitness,
                offspring, offspring_fitness,
                n_survive,
            )
        return self._survive_hard(
            parents, parent_fitness,
            offspring, offspring_fitness,
            n_survive,
        )

    def _effective_n_elite(self, n_parents: int, n_survive: int) -> int:
        """Number of elites actually preserved: 0 without elitism, else clamped."""
        if not self.elitism or self.n_elite <= 0:
            return 0
        return min(self.n_elite, n_parents, n_survive)

    def _survive_hard(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Hard (classical) survival selection.

        Returns survivors sorted best-first. With elitism, the elite parents
        are excluded from the pool used to fill the remaining slots, which
        prevents the best parent from being selected twice.
        """
        combined_pop = torch.cat([parents, offspring], dim=0)
        combined_fit = torch.cat([parent_fitness, offspring_fitness], dim=0)

        n_elite = self._effective_n_elite(parents.shape[0], n_survive)
        if n_elite == 0:
            # Simple selection from combined
            best_idx = self._best_indices(combined_fit, n_survive)
            return combined_pop[best_idx], combined_fit[best_idx]

        # Parents occupy the first block of `combined`, so parent indices are
        # also valid indices into the combined tensors.
        elite_idx = self._best_indices(parent_fitness, n_elite)
        selected_idx = elite_idx

        n_remaining = n_survive - n_elite
        if n_remaining > 0:
            available = torch.ones(
                combined_fit.shape[0], dtype=torch.bool, device=combined_fit.device
            )
            available[elite_idx] = False
            candidate_idx = torch.nonzero(available, as_tuple=False).squeeze(1)
            best_among = self._best_indices(combined_fit[candidate_idx], n_remaining)
            selected_idx = torch.cat([elite_idx, candidate_idx[best_among]], dim=0)

        survivor_fit = combined_fit[selected_idx]
        final_idx = self._sort_indices_best_first(survivor_fit)
        return combined_pop[selected_idx][final_idx], survivor_fit[final_idx]

    def _survive_adaptive(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Soft (differentiable/adaptive) survival selection.

        Uses straight-through estimator via Gumbel-Softmax:
            - Forward pass: discrete top-k without replacement
            - Backward pass: soft gradients flow through temperature

        Note: Elitism (if enabled) is applied as a hard constraint by first
        preserving elite parents, then selecting the remaining survivors from
        the remaining candidate pool (elites masked out).
        """
        combined_pop = torch.cat([parents, offspring], dim=0)
        combined_fit = torch.cat([parent_fitness, offspring_fitness], dim=0)

        n_elite = self._effective_n_elite(parents.shape[0], n_survive)
        if n_elite == 0:
            # No elitism: select directly from combined
            survivors, survivor_fit = self._adaptive_topk_select(
                combined_pop,
                combined_fit,
                k=n_survive,
            )
        else:
            elite_idx = self._best_indices(parent_fitness, n_elite)
            elite_pop = parents[elite_idx]
            elite_fit = parent_fitness[elite_idx]

            n_remaining = n_survive - n_elite
            if n_remaining <= 0:
                survivors, survivor_fit = elite_pop, elite_fit
            else:
                # Mask out elites in the combined candidate pool to avoid
                # duplicates; elite indices refer to parents, which occupy the
                # first block of `combined`.
                mask_out = torch.zeros(
                    combined_fit.shape[0], dtype=torch.bool, device=combined_fit.device
                )
                mask_out[elite_idx] = True

                survivors_rest, fit_rest = self._adaptive_topk_select(
                    combined_pop,
                    combined_fit,
                    k=n_remaining,
                    mask_out=mask_out,
                )
                survivors = torch.cat([elite_pop, survivors_rest], dim=0)
                survivor_fit = torch.cat([elite_fit, fit_rest], dim=0)

        # Re-sort by fitness (best first)
        final_idx = self._sort_indices_best_first(survivor_fit)
        return survivors[final_idx], survivor_fit[final_idx]
|
|
480
|
+
|
|
481
|
+
# =============================================================================
|
|
482
|
+
# Comma Survival (μ,λ)
|
|
483
|
+
# =============================================================================
|
|
484
|
+
|
|
485
|
+
class CommaSurvival(Survival):
    """
    (μ,λ) Comma survival: Select only from offspring.

    Discards all parents and selects the best n_survive individuals
    from the offspring population only. This can help escape local
    optima but requires n_offspring >= n_survive.

    Elitism can still be enabled to preserve the best parents, but
    they are added separately rather than competing with offspring.
    The number of elites is clamped to the number of parents and to
    n_survive.

    Args:
        n_survive: Number of individuals to survive.
        elitism: If True, preserve best parents (recommended).
        n_elite: Number of elite parents to preserve.
        adaptive: If True, use soft selection.
        temperature: Temperature for soft selection.
        minimize: If True, lower fitness is better.

    Example:
        >>> survival = CommaSurvival(n_survive=50, elitism=True)
        >>> # With the default n_elite=1 this needs offspring.shape[0] >= 49
        >>> new_pop, new_fit = survival(parents, p_fit, offspring, o_fit)

    Note:
        With elitism, offspring count must be >= n_survive - n_elite.
        Without elitism, offspring count must be >= n_survive.
    """

    def _survive(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Select from offspring only (with optional elite from parents).

        Returns survivors sorted best-first.

        Raises:
            ValueError: If there are fewer offspring than slots to fill
                from the offspring pool.
        """
        n_offspring = offspring.shape[0]
        # Never request more elites than there are parents or survivor slots.
        n_elite = min(self.n_elite, parents.shape[0], n_survive) if self.elitism else 0
        n_from_offspring = n_survive - n_elite

        if n_offspring < n_from_offspring:
            raise ValueError(
                f"CommaSurvival requires at least {n_from_offspring} offspring, "
                f"got {n_offspring}. Increase n_offsprings or reduce n_survive."
            )

        if n_from_offspring <= 0:
            # Elites fill every slot. An empty slice keeps dtype/device, and it
            # avoids calling the adaptive selector with k=0, which raises.
            survivors = offspring[:0]
            survivor_fit = offspring_fitness[:0]
        elif self.adaptive:
            # Straight-through Gumbel-Softmax top-k without replacement:
            #   - forward is discrete (hard survivors)
            #   - backward carries gradients through temperature
            survivors, survivor_fit = self._adaptive_topk_select(
                offspring,
                offspring_fitness,
                k=n_from_offspring,
            )
        else:
            # Hard selection from offspring
            best_idx = self._best_indices(offspring_fitness, n_from_offspring)
            survivors = offspring[best_idx]
            survivor_fit = offspring_fitness[best_idx]

        # Add elites from parents
        if n_elite > 0:
            elite_idx = self._best_indices(parent_fitness, n_elite)
            survivors = torch.cat([parents[elite_idx], survivors], dim=0)
            survivor_fit = torch.cat([parent_fitness[elite_idx], survivor_fit], dim=0)

        # Single final re-sort (best first)
        final_idx = self._sort_indices_best_first(survivor_fit)
        return survivors[final_idx], survivor_fit[final_idx]
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
# =============================================================================
|
|
573
|
+
# Replace Worst Survival (Steady-State)
|
|
574
|
+
# =============================================================================
|
|
575
|
+
|
|
576
|
+
class ReplaceWorstSurvival(Survival):
    """
    Steady-state survival: Replace worst parents with best offspring.

    Instead of replacing the entire population, only the worst
    individuals are replaced by better offspring. This creates
    higher selection pressure and can lead to faster convergence,
    but may also cause premature convergence.

    Each offspring competes with a parent. If the offspring is
    better, it replaces that parent. The pairing can be:
        - 'worst': The i-th best offspring is paired with the i-th
          worst (non-elite) parent
        - 'random': Random parent-offspring pairing

    The returned population always has as many rows as ``parents``;
    ``n_survive`` is accepted for interface compatibility but is not used
    by ``_survive``.

    Args:
        n_survive: Number of individuals in population (parents).
        elitism: If True, the best parents are never replaced.
        n_elite: Number of protected parents (clamped to the number of parents).
        replacement: Pairing strategy ('worst' or 'random').
        adaptive: If True, use soft replacement.
        temperature: Temperature for soft selection.
        minimize: If True, lower fitness is better.

    Raises:
        ValueError: If ``replacement`` is not 'worst' or 'random'.

    Example:
        >>> survival = ReplaceWorstSurvival(n_survive=100, elitism=True)
        >>> # With 10 offspring, up to 10 worst parents may be replaced
        >>> new_pop, new_fit = survival(parents, p_fit, offspring, o_fit)
    """

    def __init__(
        self,
        n_survive: Optional[int] = None,
        elitism: bool = True,
        n_elite: int = 1,
        replacement: str = 'worst',
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
    ) -> None:
        super().__init__(
            n_survive=n_survive,
            elitism=elitism,
            n_elite=n_elite,
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )

        if replacement not in ['worst', 'random']:
            raise ValueError(f"replacement must be 'worst' or 'random', got '{replacement}'")
        self.replacement = replacement

    def _survive(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Replace worst parents with better offspring.

        Parent slots (worst-first or random order, elites excluded) are paired
        one-to-one with offspring sorted best-first. An offspring replaces its
        paired parent only if it is strictly better. Each parent slot is
        visited at most once, so every comparison can be made against the
        original parent fitness and all replacements are applied at once.
        """
        n_parents = parents.shape[0]
        n_offspring = offspring.shape[0]

        # Start with parent population
        survivors = parents.clone()
        survivor_fit = parent_fitness.clone()

        # Order in which parent slots are offered for replacement
        if self.replacement == 'worst':
            # Best-first order reversed, i.e. worst parent first
            slot_order = self._sort_indices_best_first(parent_fitness).flip(0)
        else:  # random
            slot_order = torch.randperm(n_parents, device=parents.device)

        # Protect elites (a boolean mask keeps the result a long tensor even
        # when every parent is protected)
        n_elite = min(self.n_elite, n_parents) if self.elitism else 0
        if n_elite > 0:
            protected = torch.zeros(n_parents, dtype=torch.bool, device=parents.device)
            protected[self._best_indices(parent_fitness, n_elite)] = True
            slot_order = slot_order[~protected[slot_order]]

        n_replace = min(slot_order.shape[0], n_offspring)
        if n_replace == 0:
            return survivors, survivor_fit

        # Pair the i-th candidate slot with the i-th best offspring
        best_offspring_idx = self._sort_indices_best_first(offspring_fitness)
        parent_slots = slot_order[:n_replace]
        offspring_picks = best_offspring_idx[:n_replace]

        # Only replace if offspring is strictly better
        challenger_fit = offspring_fitness[offspring_picks]
        if self.minimize:
            better = challenger_fit < parent_fitness[parent_slots]
        else:
            better = challenger_fit > parent_fitness[parent_slots]

        replaced_slots = parent_slots[better]
        winning_offspring = offspring_picks[better]
        survivors[replaced_slots] = offspring[winning_offspring]
        survivor_fit[replaced_slots] = offspring_fitness[winning_offspring]

        if self.adaptive:
            # Zero-valued gradient carrier: adds nothing in the forward pass but
            # lets gradients reach the temperature (and fitness values) through
            # a temperature-scaled softmax over all candidates.
            all_fit = torch.cat([parent_fitness, offspring_fitness])
            scores = self._fitness_to_scores(all_fit)
            soft_probs = torch.softmax(scores / self.temperature, dim=0)
            soft_weighted_fit = torch.sum(soft_probs * all_fit)
            survivor_fit = survivor_fit + (soft_weighted_fit - soft_weighted_fit.detach())

        return survivors, survivor_fit
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
# =============================================================================
|
|
715
|
+
# Fitness Survival (Simple Truncation)
|
|
716
|
+
# =============================================================================
|
|
717
|
+
|
|
718
|
+
class FitnessSurvival(Survival):
    """
    Plain truncation survival on fitness.

    Parents and offspring are pooled and the ``n_survive`` best individuals
    are kept. Elitism is always disabled (``elitism=False``, ``n_elite=0`` are
    passed to the base class), so this is the simplest pooled-selection
    operator in this module.

    Args:
        n_survive: Number of individuals to keep; forwarded unchanged to the
            base ``Survival`` class (how ``None`` is resolved is decided
            there).
        adaptive: If True, survivors are chosen with the base class's
            ``_adaptive_topk_select`` (soft, gradient-carrying selection)
            instead of a hard top-k.
        temperature: Temperature for soft selection.
        learn_temperature: Forwarded unchanged to the base class.
        minimize: If True, lower fitness is better.

    Example:
        >>> survival = FitnessSurvival(n_survive=100)
        >>> new_pop, new_fit = survival(parents, p_fit, offspring, o_fit)
    """

    def __init__(
        self,
        n_survive: Optional[int] = None,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
    ) -> None:
        # Truncation never reserves elite slots.
        super().__init__(
            n_survive=n_survive,
            minimize=minimize,
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            elitism=False,
            n_elite=0,
        )

    def _survive(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """Keep the ``n_survive`` best of the pooled parents and offspring."""
        pool = torch.cat((parents, offspring), dim=0)
        pool_fit = torch.cat((parent_fitness, offspring_fitness), dim=0)

        if not self.adaptive:
            keep = self._best_indices(pool_fit, n_survive)
            return pool[keep], pool_fit[keep]

        # Soft path: the base-class top-k keeps the forward pass discrete
        # while letting gradients reach the temperature.
        chosen, chosen_fit = self._adaptive_topk_select(
            pool,
            pool_fit,
            k=n_survive,
        )
        order = self._sort_indices_best_first(chosen_fit)
        return chosen[order], chosen_fit[order]
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
# =============================================================================
|
|
790
|
+
# Age-Based Survival
|
|
791
|
+
# =============================================================================
|
|
792
|
+
|
|
793
|
+
class AgeSurvival(Survival):
    """
    Survival that ranks individuals by a blend of fitness and a parent age penalty.

    Parents and offspring are pooled and every individual receives a
    composite score (lower is better)::

        composite = (1 - age_weight) * fit_norm + age_weight * age / (max_age + 1)

    ``fit_norm`` is fitness min-max normalized over the pool to [0, 1]
    (0 = best, respecting ``minimize``; all zeros when every fitness value
    is equal). ``age`` is 1 for every parent and 0 for every offspring, so
    offspring get a small edge over parents of similar fitness. This is
    intended to keep the population from stagnating.

    Args:
        n_survive: Number of individuals to survive (forwarded to the base
            ``Survival`` class).
        max_age: Scales the age penalty: each parent is penalized by
            ``age_weight / (max_age + 1)``, so larger values weaken the age
            term. Must be >= 0. Ages are NOT tracked across generations, so
            no individual is force-replaced for exceeding ``max_age``.
        elitism: If True, the ``n_elite`` best parents by raw fitness are
            guaranteed to survive (hard selection only).
        n_elite: Number of elite parents. Clamped to the number of parents
            and to ``n_survive``.
        adaptive: If True, select via the base class's ``_gumbel_topk``
            (straight-through Gumbel top-k) over the composite score.
            Elitism is not applied on this path.
        temperature: Temperature for soft selection.
        learn_temperature: Forwarded to the base ``Survival`` class.
        minimize: If True, lower fitness is better.
        age_weight: Weight of the age term in the composite score, in
            [0, 1]; the fitness term gets ``1 - age_weight``. The default 0.3
            gives a 70% fitness / 30% age blend.

    Raises:
        ValueError: If ``max_age`` is negative or ``age_weight`` is outside
            [0, 1].

    Example:
        >>> survival = AgeSurvival(n_survive=100, max_age=10)
        >>> new_pop, new_fit = survival(parents, p_fit, offspring, o_fit)
    """

    def __init__(
        self,
        n_survive: Optional[int] = None,
        max_age: int = 10,
        elitism: bool = True,
        n_elite: int = 1,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
        age_weight: float = 0.3,
    ) -> None:
        if max_age < 0:
            raise ValueError(f"max_age must be >= 0, got {max_age}")
        if not 0.0 <= age_weight <= 1.0:
            raise ValueError(f"age_weight must be in [0, 1], got {age_weight}")
        super().__init__(
            n_survive=n_survive,
            elitism=elitism,
            n_elite=n_elite,
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )
        self.max_age = max_age
        self.age_weight = age_weight

    def _survive(
        self,
        parents: Tensor,
        parent_fitness: Tensor,
        offspring: Tensor,
        offspring_fitness: Tensor,
        n_survive: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Select ``n_survive`` individuals by composite (fitness + age) score.

        Returns:
            ``(survivors, survivor_fitness)`` sorted best-first by raw fitness.
        """
        n_parents = parents.shape[0]

        # Pool parents first, then offspring: rows [0, n_parents) are parents,
        # so indices into parent_fitness are also valid pooled indices.
        combined_pop = torch.cat([parents, offspring], dim=0)
        combined_fit = torch.cat([parent_fitness, offspring_fitness], dim=0)

        # Coarse ages: 1 for every parent, 0 for every offspring.
        ages = torch.zeros(combined_fit.shape[0], device=combined_fit.device)
        ages[:n_parents] = 1.0

        # Orient fitness so that lower is always better.
        fit_score = combined_fit if self.minimize else -combined_fit

        # Min-max normalize to [0, 1] so it is commensurable with the age term.
        fit_min, fit_max = fit_score.min(), fit_score.max()
        if fit_max > fit_min:
            fit_norm = (fit_score - fit_min) / (fit_max - fit_min)
        else:
            fit_norm = torch.zeros_like(fit_score)

        age_norm = ages / (self.max_age + 1)
        composite = (1.0 - self.age_weight) * fit_norm + self.age_weight * age_norm

        if self.adaptive:
            # Soft selection expects higher = better, so negate the composite.
            scores = -composite
            W = self._gumbel_topk(scores, k=n_survive, dim=0)  # [n_survive, n_total]
            survivors = W @ combined_pop
            survivor_fit = W @ combined_fit

            final_idx = self._sort_indices_best_first(survivor_fit)
            return survivors[final_idx], survivor_fit[final_idx]

        # Hard selection: the n_survive lowest composite scores, best first.
        keep_idx = torch.as_tensor(
            TopKSelection.best_indices(composite, k=n_survive, minimize=True),
            dtype=torch.long,
            device=combined_fit.device,
        )

        n_elite = min(self.n_elite, n_parents, n_survive) if self.elitism else 0
        if n_elite > 0:
            elite_idx = torch.as_tensor(
                self._best_indices(parent_fitness, n_elite),
                dtype=torch.long,
                device=keep_idx.device,
            )
            # Remove elites from the composite ranking so an elite that was
            # already selected is not duplicated, then fill the remaining
            # slots in composite order.
            is_elite = torch.zeros(
                combined_fit.shape[0], dtype=torch.bool, device=keep_idx.device
            )
            is_elite[elite_idx] = True
            others = keep_idx[~is_elite[keep_idx]]
            keep_idx = torch.cat([elite_idx, others[: n_survive - n_elite]])

        survivors = combined_pop[keep_idx]
        survivor_fit = combined_fit[keep_idx]

        # Consistent output ordering across all paths: best raw fitness first.
        final_idx = self._sort_indices_best_first(survivor_fit)
        return survivors[final_idx], survivor_fit[final_idx]
|
|
919
|
+
|
|
920
|
+
|
|
921
|
+
# =============================================================================
|
|
922
|
+
# Utility Functions
|
|
923
|
+
# =============================================================================
|
|
924
|
+
|
|
925
|
+
def get_survival(
    strategy: str,
    n_survive: Optional[int] = None,
    elitism: bool = True,
    n_elite: int = 1,
    adaptive: bool = False,
    **kwargs,
) -> Survival:
    """
    Create a survival operator from a strategy name.

    Name matching is case-insensitive and ignores surrounding whitespace.

    Args:
        strategy: Survival strategy name. Accepted names:
            - 'merge', 'plus', '(mu+lambda)', 'mu+lambda': MergeSurvival
            - 'comma', '(mu,lambda)', 'mu,lambda': CommaSurvival
            - 'replace_worst', 'steady_state', 'steady-state': ReplaceWorstSurvival
            - 'fitness', 'truncation': FitnessSurvival
            - 'age', 'age_based': AgeSurvival
        n_survive: Number of individuals to survive.
        elitism: Whether to preserve best individuals. Not forwarded to
            FitnessSurvival, which never uses elitism.
        n_elite: Number of elite individuals. Not forwarded to FitnessSurvival.
        adaptive: Enable gradient flow (soft selection).
        **kwargs: Extra keyword arguments forwarded unchanged to the chosen
            class constructor (e.g. ``temperature`` / ``minimize``, or
            ``max_age`` for AgeSurvival).

    Returns:
        Configured Survival operator.

    Raises:
        ValueError: If ``strategy`` is not one of the accepted names.

    Example:
        >>> survival = get_survival('plus', n_survive=100, elitism=True)
        >>> survival = get_survival('comma', n_survive=50, adaptive=True)
    """
    strategy = strategy.lower().strip()

    if strategy in ('merge', 'plus', '(mu+lambda)', 'mu+lambda'):
        return MergeSurvival(
            n_survive=n_survive,
            elitism=elitism,
            n_elite=n_elite,
            adaptive=adaptive,
            **kwargs,
        )
    if strategy in ('comma', '(mu,lambda)', 'mu,lambda'):
        return CommaSurvival(
            n_survive=n_survive,
            elitism=elitism,
            n_elite=n_elite,
            adaptive=adaptive,
            **kwargs,
        )
    if strategy in ('replace_worst', 'steady_state', 'steady-state'):
        return ReplaceWorstSurvival(
            n_survive=n_survive,
            elitism=elitism,
            n_elite=n_elite,
            adaptive=adaptive,
            **kwargs,
        )
    if strategy in ('fitness', 'truncation'):
        # FitnessSurvival has no elitism parameters.
        return FitnessSurvival(
            n_survive=n_survive,
            adaptive=adaptive,
            **kwargs,
        )
    if strategy in ('age', 'age_based'):
        return AgeSurvival(
            n_survive=n_survive,
            elitism=elitism,
            n_elite=n_elite,
            adaptive=adaptive,
            **kwargs,
        )
    raise ValueError(
        f"Unknown survival strategy: '{strategy}'. "
        "Options: merge/plus, comma, replace_worst, fitness, age"
    )
|