evograd-diff 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evograd/__init__.py +67 -0
- evograd/algorithms/__init__.py +138 -0
- evograd/algorithms/cmaes.py +1365 -0
- evograd/algorithms/de.py +895 -0
- evograd/algorithms/ga.py +532 -0
- evograd/algorithms/pso.py +648 -0
- evograd/algorithms/shade.py +1165 -0
- evograd/benchmarks/functions/__init__.py +229 -0
- evograd/benchmarks/functions/base.py +217 -0
- evograd/benchmarks/functions/cec2017/__init__.py +250 -0
- evograd/benchmarks/functions/cec2017/basic.py +413 -0
- evograd/benchmarks/functions/cec2017/composition.py +580 -0
- evograd/benchmarks/functions/cec2017/data.pkl +0 -0
- evograd/benchmarks/functions/cec2017/data.py +350 -0
- evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
- evograd/benchmarks/functions/cec2017/simple.py +326 -0
- evograd/benchmarks/functions/classical.py +649 -0
- evograd/benchmarks/functions/smoothed_funnel.py +476 -0
- evograd/benchmarks/functions/transforms.py +463 -0
- evograd/benchmarks/run_benchmark_functions.py +1208 -0
- evograd/core/__init__.py +73 -0
- evograd/core/algorithm.py +778 -0
- evograd/core/maximize.py +269 -0
- evograd/core/minimize.py +740 -0
- evograd/core/problem.py +444 -0
- evograd/core/result.py +571 -0
- evograd/core/termination.py +602 -0
- evograd/operators/__init__.py +178 -0
- evograd/operators/crossover.py +1117 -0
- evograd/operators/mutation.py +1098 -0
- evograd/operators/relaxations.py +175 -0
- evograd/operators/repair.py +601 -0
- evograd/operators/sampling.py +577 -0
- evograd/operators/selection.py +981 -0
- evograd/operators/survival.py +1000 -0
- evograd/tests/__init__.py +11 -0
- evograd/tests/run_all.py +78 -0
- evograd/tests/test_core.py +528 -0
- evograd/tests/test_ga.py +572 -0
- evograd/tests/test_operators.py +662 -0
- evograd/tests/test_per_individual.py +326 -0
- evograd/tests/test_utils.py +328 -0
- evograd/utils/__init__.py +97 -0
- evograd/utils/callbacks.py +926 -0
- evograd/utils/device.py +502 -0
- evograd/utils/duplicates.py +421 -0
- evograd_diff-0.1.0.dist-info/METADATA +439 -0
- evograd_diff-0.1.0.dist-info/RECORD +50 -0
- evograd_diff-0.1.0.dist-info/WHEEL +4 -0
- evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,981 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Selection operators for parent selection.
|
|
3
|
+
|
|
4
|
+
This module provides strategies for selecting parents from the
|
|
5
|
+
population for recombination. All selectors support both classical
|
|
6
|
+
(hard) and differentiable (``adaptive=True``, Gumbel-Softmax) modes.
|
|
7
|
+
|
|
8
|
+
Available selectors:
|
|
9
|
+
- TournamentSelection: Tournament-based selection
|
|
10
|
+
- RouletteSelection: Fitness-proportionate selection
|
|
11
|
+
- RankSelection: Rank-based selection
|
|
12
|
+
- RandomSelection: Uniform random selection
|
|
13
|
+
- TruncationSelection: Select *from* the top fraction (samples within truncated set)
|
|
14
|
+
- TopKSelection: Deterministic top-k WITHOUT replacement (hard + differentiable)
|
|
15
|
+
|
|
16
|
+
Differentiable Mode:
|
|
17
|
+
When `adaptive=True`, selection uses Gumbel-Softmax
|
|
18
|
+
relaxation with straight-through estimator, allowing gradients
|
|
19
|
+
to flow through the selection process.
|
|
20
|
+
|
|
21
|
+
Example:
|
|
22
|
+
>>> from evograd.operators import TournamentSelection
|
|
23
|
+
>>>
|
|
24
|
+
>>> # Classical mode
|
|
25
|
+
>>> selector = TournamentSelection(tournament_size=3)
|
|
26
|
+
>>> parents = selector(population, fitness, n_select=50)
|
|
27
|
+
>>>
|
|
28
|
+
>>> # Differentiable mode
|
|
29
|
+
>>> selector = TournamentSelection(
|
|
30
|
+
... tournament_size=3,
|
|
31
|
+
... adaptive=True,
|
|
32
|
+
... temperature=1.0,
|
|
33
|
+
... )
|
|
34
|
+
>>> parents = selector(population, fitness, n_select=50)
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
from __future__ import annotations
|
|
38
|
+
|
|
39
|
+
from abc import ABC, abstractmethod
|
|
40
|
+
from typing import Optional, Tuple, Union
|
|
41
|
+
|
|
42
|
+
import math
|
|
43
|
+
import torch
|
|
44
|
+
import torch.nn as nn
|
|
45
|
+
from torch import Tensor
|
|
46
|
+
|
|
47
|
+
from evograd.operators.relaxations import gumbel_softmax
|
|
48
|
+
|
|
49
|
+
__all__ = [
|
|
50
|
+
"Selection",
|
|
51
|
+
"TournamentSelection",
|
|
52
|
+
"RouletteSelection",
|
|
53
|
+
"RankSelection",
|
|
54
|
+
"RandomSelection",
|
|
55
|
+
"TopKSelection",
|
|
56
|
+
"TruncationSelection",
|
|
57
|
+
"StochasticUniversalSampling",
|
|
58
|
+
]
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# =============================================================================
|
|
62
|
+
# Base Selection Class
|
|
63
|
+
# =============================================================================
|
|
64
|
+
|
|
65
|
+
class Selection(nn.Module, ABC):
    """
    Abstract base class for selection operators.

    Subclasses must implement:
        - _select(): Perform selection and return ``(selected, indices)``
          (returning only ``selected`` is also accepted by ``forward``).

    Args:
        adaptive: If True, use Gumbel-Softmax for soft selection.
        temperature: Initial Gumbel-Softmax temperature (lower = harder).
            Must be non-negative. It is clamped into
            ``[_MIN_TEMPERATURE, _MAX_TEMPERATURE]`` at the start of every
            ``forward`` call, so ``0.0`` effectively means the minimum.
        learn_temperature: If True *and* ``adaptive`` is True, the
            temperature is a learnable ``nn.Parameter`` (stored in log
            space); otherwise it is stored as a non-trainable buffer.
        minimize: If True, lower fitness is better (default).

    Raises:
        ValueError: If ``temperature`` is negative or NaN.
    """

    def __init__(
        self,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
    ) -> None:
        super().__init__()

        # Bounds enforced by _clamp_temperature() before each selection.
        self._MIN_TEMPERATURE = 0.05
        self._MAX_TEMPERATURE = 10.0

        self.adaptive = adaptive
        self.minimize = minimize

        # Stored as log(temperature) so a learned value always stays positive.
        log_temperature = torch.tensor(self._check_temperature(temperature)).log()
        if learn_temperature and adaptive:
            self._log_temperature = nn.Parameter(log_temperature)
        else:
            self.register_buffer("_log_temperature", log_temperature)

    @staticmethod
    def _check_temperature(value: Union[float, Tensor]) -> float:
        """Return ``value`` as a float, rejecting negative or NaN temperatures."""
        value = float(value)
        # ``not value >= 0`` is also True for NaN. A NaN log-temperature could
        # not be repaired later by the clamp in _clamp_temperature().
        if not value >= 0.0:
            raise ValueError(f"temperature must be non-negative, got {value}")
        return value

    @property
    def temperature(self) -> Tensor:
        """Current temperature value, ``exp(_log_temperature)``."""
        return self._log_temperature.exp()

    @temperature.setter
    def temperature(self, value: Union[float, Tensor]) -> None:
        """Set the temperature in place, without recording autograd history.

        Raises:
            ValueError: If ``value`` is negative or NaN.
        """
        value = self._check_temperature(value)
        # log(0) -> -inf mirrors what the constructor stores for 0.0; the
        # forward-time clamp maps it to _MIN_TEMPERATURE.
        log_value = math.log(value) if value > 0.0 else float("-inf")
        with torch.no_grad():
            self._log_temperature.fill_(log_value)

    def _fitness_to_scores(self, fitness: Tensor) -> Tensor:
        """
        Convert fitness to selection scores (higher = better).

        Args:
            fitness: Raw fitness values [n_pop].

        Returns:
            Selection scores [n_pop] where higher is better.
        """
        if self.minimize:
            # Negate so lower fitness = higher score
            return -fitness
        return fitness

    def _clamp_temperature(self) -> None:
        """Clamp the stored log-temperature in place for numerical stability."""
        if hasattr(self, "_log_temperature") and self._log_temperature is not None:
            with torch.no_grad():
                self._log_temperature.clamp_(
                    math.log(self._MIN_TEMPERATURE),
                    math.log(self._MAX_TEMPERATURE),
                )

    @abstractmethod
    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Perform selection.

        Args:
            population: Population tensor [n_pop, n_var].
            fitness: Fitness values [n_pop].
            n_select: Number of individuals to select.

        Returns:
            Either a tuple ``(selected, indices)`` with selected individuals
            [n_select, n_var] and their indices, or ``selected`` alone.
            ``forward`` strips the indices unless ``return_indices=True``.
        """

    def forward(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: Optional[int] = None,
        return_indices: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Select parents from population.

        Args:
            population: Population tensor [n_pop, n_var].
            fitness: Fitness values [n_pop].
            n_select: Number to select (default: population size).
            return_indices: If True, also return selection indices.

        Returns:
            Selected individuals [n_select, n_var], or
            tuple (selected, indices) if return_indices=True and the
            subclass's ``_select`` provides indices.
        """
        if n_select is None:
            n_select = population.shape[0]

        self._clamp_temperature()
        result = self._select(population, fitness, n_select)

        if return_indices:
            return result
        if isinstance(result, tuple):
            return result[0]
        return result

    # Note: Do NOT override __call__. nn.Module.__call__ dispatches to
    # forward() and fires registered hooks (forward_pre_hooks, forward_hooks,
    # and the autograd profiler). Overriding __call__ would bypass all of these.
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
# =============================================================================
|
|
203
|
+
# Tournament Selection
|
|
204
|
+
# =============================================================================
|
|
205
|
+
|
|
206
|
+
class TournamentSelection(Selection):
    """
    Tournament selection.

    Each of the ``n_select`` picks runs an independent tournament: draw
    ``tournament_size`` participants at random and keep the one with the
    best score. This is the most common parent selector for genetic
    algorithms.

    With ``adaptive=True`` the winner of each tournament is drawn with
    Gumbel-Softmax over the participants' scores, and the returned
    individual is the weight-combination of the participants.

    Args:
        tournament_size: Participants per tournament (must be >= 2).
        adaptive: If True, use Gumbel-Softmax selection.
        temperature: Temperature for Gumbel-Softmax.
        learn_temperature: If True, temperature is learnable.
        minimize: If True, lower fitness is better.
        replacement: If True, participants are drawn with replacement (an
            individual may appear several times in one tournament). If
            False, each tournament uses distinct individuals (slower: one
            ``randperm`` per tournament; if ``tournament_size`` exceeds the
            population size the whole population takes part).

    Raises:
        ValueError: If ``tournament_size`` < 2.

    Example:
        >>> selector = TournamentSelection(tournament_size=3)
        >>> parents = selector(population, fitness, n_select=50)
    """

    def __init__(
        self,
        tournament_size: int = 3,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
        replacement: bool = True,
    ) -> None:
        super().__init__(
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )
        if tournament_size < 2:
            raise ValueError(f"tournament_size must be >= 2, got {tournament_size}")
        self.tournament_size = tournament_size
        self.replacement = replacement

    def _sample_tournaments(
        self,
        n_pop: int,
        n_select: int,
        device: torch.device,
    ) -> Tensor:
        """Draw participant indices, shape [n_select, tournament_size]."""
        size = self.tournament_size
        if self.replacement:
            return torch.randint(0, n_pop, (n_select, size), device=device)
        rows = [torch.randperm(n_pop, device=device)[:size] for _ in range(n_select)]
        return torch.stack(rows)

    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Tuple[Tensor, Tensor]:
        n_pop, _ = population.shape  # also enforces a 2-D population

        entrants = self._sample_tournaments(n_pop, n_select, population.device)
        entrant_scores = self._fitness_to_scores(fitness)[entrants]  # [n_select, k]

        if self.adaptive:
            # Straight-through Gumbel-Softmax over each tournament's entrants.
            weights = gumbel_softmax(
                entrant_scores,
                temperature=self.temperature,
                dim=-1,
            )
            selected = torch.einsum("nk,nkd->nd", weights, population[entrants])
            winner_col = weights.argmax(dim=-1)
        else:
            winner_col = entrant_scores.argmax(dim=-1)

        # Translate the winning column of each tournament into a population index.
        indices = torch.gather(entrants, 1, winner_col[:, None])[:, 0]

        if not self.adaptive:
            selected = population[indices]
        return selected, indices

    def __repr__(self) -> str:
        temp = self.temperature.item()
        return (
            f"TournamentSelection(tournament_size={self.tournament_size}, "
            f"adaptive={self.adaptive}, temperature={temp:.3f})"
        )
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
# =============================================================================
|
|
316
|
+
# Roulette (Fitness Proportionate) Selection
|
|
317
|
+
# =============================================================================
|
|
318
|
+
|
|
319
|
+
class RouletteSelection(Selection):
    """
    Roulette wheel (fitness proportionate) selection.

    Scores (fitness, negated when minimizing) are shifted so that the worst
    individual sits at ``eps`` and every other individual is positive. Each
    pick then lands on an individual with probability proportional to its
    shifted score, sampled with replacement.

    With ``adaptive=True``, each pick is a Gumbel-Softmax draw over the whole
    population using the log of the shifted scores as logits.

    Args:
        adaptive: If True, use Gumbel-Softmax selection.
        temperature: Temperature for Gumbel-Softmax.
        learn_temperature: If True, temperature is learnable.
        minimize: If True, lower fitness is better.
        eps: Small constant to avoid zero weights and log(0).

    Example:
        >>> selector = RouletteSelection()
        >>> parents = selector(population, fitness, n_select=50)
    """

    def __init__(
        self,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
        eps: float = 1e-10,
    ) -> None:
        super().__init__(
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )
        self.eps = eps

    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Tuple[Tensor, Tensor]:
        scores = self._fitness_to_scores(fitness)
        positive = scores - scores.min() + self.eps

        if not self.adaptive:
            # Classical wheel: normalise and sample with replacement.
            probs = positive / positive.sum()
            indices = torch.multinomial(probs, n_select, replacement=True)
            return population[indices], indices

        # One row of identical logits per pick.
        logits = torch.log(positive + self.eps).unsqueeze(0).expand(n_select, -1)
        weights = gumbel_softmax(
            logits,
            temperature=self.temperature,
            dim=-1,
        )
        return torch.matmul(weights, population), weights.argmax(dim=-1)

    def __repr__(self) -> str:
        temp = self.temperature.item()
        return f"RouletteSelection(adaptive={self.adaptive}, temperature={temp:.3f})"
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
# =============================================================================
|
|
409
|
+
# Rank Selection
|
|
410
|
+
# =============================================================================
|
|
411
|
+
|
|
412
|
+
class RankSelection(Selection):
    """
    Rank-based selection.

    Selection probability is based on rank rather than raw fitness.
    This reduces selection pressure compared to fitness-proportionate
    selection and is more robust to fitness scaling.

    Two ranking schemes are available:
        - 'linear': Probability decreases linearly with rank
        - 'exponential': Probability decays exponentially with rank

    Args:
        scheme: Ranking scheme ('linear' or 'exponential').
        selection_pressure: Controls selection intensity.
            For 'linear': must lie in [0.0, 2.0] (values outside this range
            give negative probabilities). The conventional range is
            [1.0, 2.0]: 1.0 is uniform, higher = more pressure; values below
            1.0 favour worse individuals.
            For 'exponential': decay factor ``c`` in ``exp(-c * (rank - 1))``,
            higher = more pressure.
        adaptive: If True, use Gumbel-Softmax selection.
        temperature: Temperature for Gumbel-Softmax.
        learn_temperature: If True, temperature is learnable.
        minimize: If True, lower fitness is better.

    Raises:
        ValueError: If ``scheme`` is unknown, or if the linear scheme's
            ``selection_pressure`` lies outside [0, 2].

    Example:
        >>> selector = RankSelection(scheme='linear', selection_pressure=1.5)
        >>> parents = selector(population, fitness, n_select=50)
    """

    def __init__(
        self,
        scheme: str = "linear",
        selection_pressure: float = 1.5,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
    ) -> None:
        super().__init__(
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )

        if scheme not in ("linear", "exponential"):
            raise ValueError(f"scheme must be 'linear' or 'exponential', got '{scheme}'")
        if scheme == "linear" and not 0.0 <= selection_pressure <= 2.0:
            # Outside [0, 2] the linear formula yields negative probabilities,
            # which would crash torch.multinomial or produce NaN logits.
            raise ValueError(
                "selection_pressure for the 'linear' scheme must be in [0, 2] "
                f"(typically [1, 2]), got {selection_pressure}"
            )

        self.scheme = scheme
        self.selection_pressure = selection_pressure

    def _compute_rank_probabilities(
        self,
        n_pop: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tensor:
        """Selection probability for each rank, best (index 0) to worst."""
        # Probabilities must be floating point even for integer populations.
        if not dtype.is_floating_point:
            dtype = torch.get_default_dtype()

        # A single individual is selected with certainty. The linear formula
        # below would otherwise divide 0 by n * (n - 1) == 0 and give NaN.
        if n_pop == 1:
            return torch.ones(1, device=device, dtype=dtype)

        # Ranks from 1 (best) to n_pop (worst)
        ranks = torch.arange(1, n_pop + 1, device=device, dtype=dtype)

        if self.scheme == "linear":
            # Linear ranking: P(rank=i) = (2-s)/n + 2*(s-1)*(n-i)/(n*(n-1))
            # where s is the selection pressure.
            s = self.selection_pressure
            n = float(n_pop)
            probs = (2 - s) / n + 2 * (s - 1) * (n - ranks) / (n * (n - 1))
            return probs / probs.sum()

        # Exponential ranking: P(rank=i) proportional to exp(-c * (i-1)).
        # softmax normalises without overflowing when c < 0.
        c = self.selection_pressure
        return torch.softmax(-c * (ranks - 1), dim=0)

    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Tuple[Tensor, Tensor]:
        n_pop = population.shape[0]

        scores = self._fitness_to_scores(fitness)

        # Population indices ordered best first, so position == rank - 1.
        sorted_indices = torch.argsort(scores, descending=True)
        rank_probs = self._compute_rank_probabilities(
            n_pop, population.device, population.dtype
        )

        if self.adaptive:
            # Gumbel-Softmax in rank space, with log-probabilities as logits.
            logits = torch.log(rank_probs + 1e-10)
            logits_expanded = logits.unsqueeze(0).expand(n_select, -1)
            weights = gumbel_softmax(
                logits_expanded,
                temperature=self.temperature,
                dim=-1,
            )

            sorted_pop = population[sorted_indices]
            selected = torch.matmul(weights, sorted_pop)

            # Map rank positions back to original population indices.
            rank_indices = weights.argmax(dim=-1)
            indices = sorted_indices[rank_indices]
        else:
            # Classical sampling by rank
            rank_indices = torch.multinomial(rank_probs, n_select, replacement=True)
            indices = sorted_indices[rank_indices]
            selected = population[indices]

        return selected, indices

    def __repr__(self) -> str:
        return (
            f"RankSelection("
            f"scheme='{self.scheme}', "
            f"selection_pressure={self.selection_pressure}, "
            f"adaptive={self.adaptive})"
        )
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
# =============================================================================
|
|
540
|
+
# Random Selection
|
|
541
|
+
# =============================================================================
|
|
542
|
+
|
|
543
|
+
class RandomSelection(Selection):
    """
    Uniform random selection (baseline).

    Selects individuals uniformly at random, ignoring fitness.
    Useful as a baseline or for algorithms that don't use
    fitness-based selection.

    Args:
        replacement: If True, allow selecting same individual multiple times.
            Only affects classical mode; adaptive mode draws every pick
            independently, so duplicates are always possible there.
        adaptive: If True, use Gumbel-Softmax (uniform logits).
        temperature: Temperature for Gumbel-Softmax (not learnable here).

    Raises:
        ValueError: (from selection) if ``replacement=False`` in classical
            mode and more individuals are requested than exist.

    Example:
        >>> selector = RandomSelection()
        >>> parents = selector(population, fitness, n_select=50)
    """

    def __init__(
        self,
        replacement: bool = True,
        adaptive: bool = False,
        temperature: float = 1.0,
    ) -> None:
        super().__init__(
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=False,
            minimize=True,
        )
        self.replacement = replacement

    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Tuple[Tensor, Tensor]:
        n_pop = population.shape[0]
        device = population.device

        if self.adaptive:
            # Uniform logits. Build them in the population's floating dtype
            # (default dtype otherwise) so the Gumbel-Softmax weights derived
            # from them can be multiplied with a float64/half population;
            # previously they were always created in the default dtype.
            dtype = population.dtype if population.is_floating_point() else None
            logits = torch.zeros(n_select, n_pop, device=device, dtype=dtype)
            weights = gumbel_softmax(
                logits,
                temperature=self.temperature,
                dim=-1,
            )

            selected = torch.matmul(weights, population)
            indices = weights.argmax(dim=-1)
        else:
            if self.replacement:
                indices = torch.randint(0, n_pop, (n_select,), device=device)
            else:
                if n_select > n_pop:
                    raise ValueError(
                        f"Cannot select {n_select} from {n_pop} without replacement"
                    )
                indices = torch.randperm(n_pop, device=device)[:n_select]

            selected = population[indices]

        return selected, indices

    def __repr__(self) -> str:
        return f"RandomSelection(replacement={self.replacement})"
|
|
611
|
+
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
# =============================================================================
|
|
615
|
+
# Top-K Selection (Deterministic / Without Replacement)
|
|
616
|
+
# =============================================================================
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
class TopKSelection(Selection):
    """
    Top-k (elitist) selection without replacement.

    Selects the best `k` individuals according to fitness. This is the
    canonical "take the best" operator. Unlike most parent selectors,
    this returns a *subset* of the population (no duplicates).

    In differentiable mode, uses a sequential (without replacement)
    straight-through Gumbel-Softmax relaxation.

    Args:
        adaptive: If True, use differentiable top-k selection.
        temperature: Temperature for Gumbel-Softmax.
        learn_temperature: If True, temperature is learnable.
        minimize: If True, lower fitness is better.

    Raises:
        ValueError: (from selection) if ``n_select`` exceeds the population
            size.

    Example:
        >>> # Keep best 10 individuals
        >>> selector = TopKSelection()
        >>> best10, idx = selector(population, fitness, n_select=10, return_indices=True)
    """

    @staticmethod
    def best_indices(
        fitness: Tensor,
        k: int,
        minimize: bool = True,
    ) -> Tensor:
        """Return indices of the best-k individuals (no replacement)."""
        scores = -fitness if minimize else fitness
        return torch.topk(scores, k=k, largest=True).indices

    @staticmethod
    def worst_indices(
        fitness: Tensor,
        k: int,
        minimize: bool = True,
    ) -> Tensor:
        """Return indices of the worst-k individuals (no replacement)."""
        scores = -fitness if minimize else fitness
        return torch.topk(scores, k=k, largest=False).indices

    @staticmethod
    def sort_indices_best_first(
        fitness: Tensor,
        minimize: bool = True,
    ) -> Tensor:
        """Return indices sorting the population best->worst."""
        scores = -fitness if minimize else fitness
        return torch.argsort(scores, descending=True)

    @staticmethod
    def _mask_value(logits: Tensor) -> float:
        """Finite fill value lying well below every finite entry of ``logits``.

        A fixed -1e9 is not enough: when logits are themselves around -1e9 or
        lower (e.g. negated fitness of badly scaled problems), an already
        selected item masked to -1e9 becomes the *largest* logit and is picked
        again. The value returned here is at most ``-1e9`` and at least
        ``|min| + 1e9`` below the smallest finite logit. It stays finite
        because an -inf logit would make the temperature gradient NaN (0 * inf
        in the backward of ``logits / temperature``).
        """
        values = logits.detach()
        values = values[torch.isfinite(values)]
        lowest = float(values.min()) if values.numel() > 0 else 0.0
        return lowest - abs(lowest) - 1e9

    @staticmethod
    def gumbel_topk(
        logits: Tensor,
        k: int,
        temperature: Union[float, Tensor] = 1.0,
        dim: int = -1,
        eps: float = 1e-10,
    ) -> Tuple[Tensor, Tensor]:
        """
        Differentiable top-k without replacement using sequential Gumbel-Softmax.

        Forward pass returns hard one-hot selections (straight-through),
        and previously selected items are masked out to avoid duplicates.

        Args:
            logits: Logits over items [..., n_items].
            k: Number of items to select.
            temperature: Gumbel-Softmax temperature.
            dim: Dimension along which to select.
            eps: Numerical stability constant.

        Returns:
            weights: Selection weights, one per step stacked at dim -2
                ([..., k, n_items] when ``dim=-1``; one-hot ST).
            indices: Selected indices, stacked on the last dim ([..., k]).

        Raises:
            ValueError: If ``k <= 0`` or ``k`` exceeds the size of ``dim``.
        """
        if k <= 0:
            raise ValueError(f"k must be > 0, got {k}")

        n_items = logits.shape[dim]
        if k > n_items:
            raise ValueError(f"Cannot select k={k} from n_items={n_items}")

        # Computed once from the unmasked logits so it does not drift.
        mask_value = TopKSelection._mask_value(logits)

        # Out-of-place scatter below keeps the autograd graph of `logits` intact.
        masked_logits = logits

        weights_list = []
        indices_list = []

        for _ in range(k):
            w = gumbel_softmax(
                masked_logits,
                temperature=temperature,
                dim=dim,
                eps=eps,
            )
            idx = w.argmax(dim=dim)

            weights_list.append(w)
            indices_list.append(idx)

            # Mask selected items so they cannot be chosen again. unsqueeze(dim)
            # restores the reduced dimension at the same (possibly negative)
            # position, so no permutation is needed.
            masked_logits = masked_logits.scatter(dim, idx.unsqueeze(dim), mask_value)

        weights = torch.stack(weights_list, dim=-2)
        indices = torch.stack(indices_list, dim=-1)
        return weights, indices

    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Tuple[Tensor, Tensor]:
        n_pop = population.shape[0]

        if n_select > n_pop:
            raise ValueError(f"Cannot select {n_select} from {n_pop} without replacement")

        scores = self._fitness_to_scores(fitness)

        if self.adaptive:
            # Sequential Gumbel top-k in score space over a single "batch" row.
            weights, idx = self.gumbel_topk(
                scores.unsqueeze(0),  # [1, n_pop]
                k=n_select,
                temperature=self.temperature,
                dim=-1,
            )
            # weights: [1, k, n_pop] -> k selected individuals as ST mixtures
            selected = torch.matmul(weights.squeeze(0), population)  # [k, n_var]
            indices = idx.squeeze(0)  # [k]
        else:
            # Hard deterministic top-k
            indices = torch.topk(scores, n_select, largest=True).indices
            selected = population[indices]

        return selected, indices

    def __repr__(self) -> str:
        return (
            f"TopKSelection("
            f"adaptive={self.adaptive}, "
            f"temperature={self.temperature.item():.3f})"
        )
|
|
789
|
+
|
|
790
|
+
# =============================================================================
|
|
791
|
+
# Truncation Selection
|
|
792
|
+
# =============================================================================
|
|
793
|
+
|
|
794
|
+
class TruncationSelection(Selection):
    """
    Truncation (elitist) selection.

    Ranks the population by score, keeps only the best
    ``truncation_ratio`` fraction, and draws parents from that subset.
    This provides strong selection pressure.

    Args:
        truncation_ratio: Fraction of population to select from (0, 1].
            The pool size is ``floor(n_pop * truncation_ratio)``, clamped
            to ``[1, n_pop]``.
        adaptive: If True, each parent is a Gumbel-Softmax-weighted mixture
            of the truncated population (soft selection).
        temperature: Temperature for soft selection.
        learn_temperature: If True, temperature is learnable.
        minimize: If True, lower fitness is better.

    Example:
        >>> # Select from top 20% of population
        >>> selector = TruncationSelection(truncation_ratio=0.2)
        >>> parents = selector(population, fitness, n_select=50)
    """

    def __init__(
        self,
        truncation_ratio: float = 0.5,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
    ) -> None:
        super().__init__(
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )

        # The chained comparison also rejects NaN.
        if not 0 < truncation_ratio <= 1:
            raise ValueError(
                f"truncation_ratio must be in (0, 1], got {truncation_ratio}"
            )

        self.truncation_ratio = truncation_ratio

    @staticmethod
    def _truncation_size(n_pop: int, ratio: float) -> int:
        """Return how many top-ranked individuals form the selection pool."""
        raw = n_pop * ratio
        # Decimal ratios such as 0.29 are not exactly representable, so
        # 100 * 0.29 evaluates to 28.999999999999996. Snap values within
        # round-off of an integer before flooring, so the pool is 29, not 28.
        nearest = round(raw)
        if abs(raw - nearest) <= 1e-9 * max(1.0, raw):
            raw = nearest
        return max(1, min(n_pop, int(raw)))

    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Tuple[Tensor, Tensor]:
        """Draw ``n_select`` parents from the truncated (best) subset.

        Args:
            population: Candidate solutions, indexed along dim 0.
            fitness: Fitness values, converted to scores by
                ``self._fitness_to_scores`` (``torch.topk`` below keeps the
                largest scores).
            n_select: Number of parents to draw (with replacement).

        Returns:
            ``(selected, indices)``.

            - Non-adaptive mode: parents are drawn uniformly at random from
              the pool, and ``selected = population[indices]``.
            - Adaptive mode: ``selected`` is ``weights @ population[pool]``,
              and ``indices`` is the argmax (dominant) pool member of each
              draw.
        """
        n_pop = population.shape[0]
        device = population.device

        scores = self._fitness_to_scores(fitness)

        # Determine truncation size (robust to float round-off).
        n_truncated = self._truncation_size(n_pop, self.truncation_ratio)

        # Get indices of top individuals
        _, top_indices = torch.topk(scores, n_truncated)

        if self.adaptive:
            # Soft selection within truncated set: one logit row per draw.
            top_scores = scores[top_indices]
            logits = top_scores.unsqueeze(0).expand(n_select, -1)

            weights = gumbel_softmax(
                logits,
                temperature=self.temperature,
                dim=-1,
            )

            top_pop = population[top_indices]
            selected = torch.matmul(weights, top_pop)

            # Report the pool member carrying the largest weight per draw.
            relative_idx = weights.argmax(dim=-1)
            indices = top_indices[relative_idx]
        else:
            # Uniform random selection (with replacement) from truncated set
            relative_idx = torch.randint(0, n_truncated, (n_select,), device=device)
            indices = top_indices[relative_idx]
            selected = population[indices]

        return selected, indices

    def __repr__(self) -> str:
        return (
            f"TruncationSelection("
            f"truncation_ratio={self.truncation_ratio}, "
            f"adaptive={self.adaptive})"
        )
|
|
883
|
+
|
|
884
|
+
|
|
885
|
+
# =============================================================================
|
|
886
|
+
# Stochastic Universal Sampling (SUS)
|
|
887
|
+
# =============================================================================
|
|
888
|
+
|
|
889
|
+
class StochasticUniversalSampling(Selection):
    """
    Stochastic Universal Sampling (SUS).

    Similar to roulette selection but uses evenly spaced pointers
    on the wheel, reducing variance and ensuring a more uniform
    sampling of fit individuals.

    Args:
        adaptive: If True, replace SUS with a Gumbel-Softmax draw over the
            log of the shifted scores (a differentiable roulette
            approximation; the evenly spaced pointers are not used).
        temperature: Temperature for Gumbel-Softmax.
        learn_temperature: If True, temperature is learnable.
        minimize: If True, lower fitness is better.
        eps: Small constant for numerical stability. It is added after
            shifting scores so every individual has a non-zero wheel slice.

    Example:
        >>> selector = StochasticUniversalSampling()
        >>> parents = selector(population, fitness, n_select=50)
    """

    def __init__(
        self,
        adaptive: bool = False,
        temperature: float = 1.0,
        learn_temperature: bool = True,
        minimize: bool = True,
        eps: float = 1e-10,
    ) -> None:
        super().__init__(
            adaptive=adaptive,
            temperature=temperature,
            learn_temperature=learn_temperature,
            minimize=minimize,
        )
        self.eps = eps

    def _select(
        self,
        population: Tensor,
        fitness: Tensor,
        n_select: int,
    ) -> Tuple[Tensor, Tensor]:
        """Select ``n_select`` parents proportionally to shifted scores.

        Args:
            population: Candidate solutions, indexed along dim 0. May use
                any dtype in the classical path (e.g. integer or bool
                encodings), since only row indexing is applied to it there.
            fitness: Fitness values, converted to scores by
                ``self._fitness_to_scores``.
            n_select: Number of parents to draw.

        Returns:
            ``(selected, indices)``.

            - Classical SUS: ``selected = population[indices]``.
            - Adaptive mode: ``selected = weights @ population``, and
              ``indices`` is the argmax of each weight row.
        """
        n_pop = population.shape[0]
        device = population.device

        scores = self._fitness_to_scores(fitness)

        # Shift so the worst individual has slice width eps > 0.
        shifted_scores = scores - scores.min() + self.eps
        total = shifted_scores.sum()

        # Pointer arithmetic must use the wheel's own floating dtype.
        # population.dtype may be integer/bool (torch.rand rejects those) or
        # a different precision than the cumulative sums searched below.
        wheel_dtype = shifted_scores.dtype

        if self.adaptive:
            # Fall back to Gumbel-Softmax (SUS is inherently discrete).
            # eps is added again here; shifted_scores already includes it.
            logits = torch.log(shifted_scores + self.eps)
            logits_expanded = logits.unsqueeze(0).expand(n_select, -1)

            weights = gumbel_softmax(
                logits_expanded,
                temperature=self.temperature,
                dim=-1,
            )
            selected = torch.matmul(weights, population)
            indices = weights.argmax(dim=-1)
        else:
            # Classical SUS: one random offset, then n_select pointers spaced
            # total / n_select apart along the cumulative-score wheel.
            cumsum = torch.cumsum(shifted_scores, dim=0)

            pointer_distance = total / n_select
            start = torch.rand(1, device=device, dtype=wheel_dtype) * pointer_distance
            pointers = start + pointer_distance * torch.arange(
                n_select, device=device, dtype=wheel_dtype
            )

            # First slot whose cumulative score reaches each pointer.
            indices = torch.searchsorted(cumsum, pointers)
            # Round-off between sum() and cumsum() can push the last pointer
            # just past cumsum[-1]; clamp back into range.
            indices = torch.clamp(indices, 0, n_pop - 1)

            selected = population[indices]

        return selected, indices

    def __repr__(self) -> str:
        return (
            f"StochasticUniversalSampling("
            f"adaptive={self.adaptive})"
        )
|