evograd-diff 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evograd/__init__.py +67 -0
- evograd/algorithms/__init__.py +138 -0
- evograd/algorithms/cmaes.py +1365 -0
- evograd/algorithms/de.py +895 -0
- evograd/algorithms/ga.py +532 -0
- evograd/algorithms/pso.py +648 -0
- evograd/algorithms/shade.py +1165 -0
- evograd/benchmarks/functions/__init__.py +229 -0
- evograd/benchmarks/functions/base.py +217 -0
- evograd/benchmarks/functions/cec2017/__init__.py +250 -0
- evograd/benchmarks/functions/cec2017/basic.py +413 -0
- evograd/benchmarks/functions/cec2017/composition.py +580 -0
- evograd/benchmarks/functions/cec2017/data.pkl +0 -0
- evograd/benchmarks/functions/cec2017/data.py +350 -0
- evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
- evograd/benchmarks/functions/cec2017/simple.py +326 -0
- evograd/benchmarks/functions/classical.py +649 -0
- evograd/benchmarks/functions/smoothed_funnel.py +476 -0
- evograd/benchmarks/functions/transforms.py +463 -0
- evograd/benchmarks/run_benchmark_functions.py +1208 -0
- evograd/core/__init__.py +73 -0
- evograd/core/algorithm.py +778 -0
- evograd/core/maximize.py +269 -0
- evograd/core/minimize.py +740 -0
- evograd/core/problem.py +444 -0
- evograd/core/result.py +571 -0
- evograd/core/termination.py +602 -0
- evograd/operators/__init__.py +178 -0
- evograd/operators/crossover.py +1117 -0
- evograd/operators/mutation.py +1098 -0
- evograd/operators/relaxations.py +175 -0
- evograd/operators/repair.py +601 -0
- evograd/operators/sampling.py +577 -0
- evograd/operators/selection.py +981 -0
- evograd/operators/survival.py +1000 -0
- evograd/tests/__init__.py +11 -0
- evograd/tests/run_all.py +78 -0
- evograd/tests/test_core.py +528 -0
- evograd/tests/test_ga.py +572 -0
- evograd/tests/test_operators.py +662 -0
- evograd/tests/test_per_individual.py +326 -0
- evograd/tests/test_utils.py +328 -0
- evograd/utils/__init__.py +97 -0
- evograd/utils/callbacks.py +926 -0
- evograd/utils/device.py +502 -0
- evograd/utils/duplicates.py +421 -0
- evograd_diff-0.1.0.dist-info/METADATA +439 -0
- evograd_diff-0.1.0.dist-info/RECORD +50 -0
- evograd_diff-0.1.0.dist-info/WHEEL +4 -0
- evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,662 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Test script for EvoGrad operators module.
|
|
3
|
+
|
|
4
|
+
Tests:
|
|
5
|
+
- sampling.py: Population initialisation
|
|
6
|
+
- selection.py: Parent selection
|
|
7
|
+
- crossover.py: Recombination operators
|
|
8
|
+
- mutation.py: Mutation operators
|
|
9
|
+
- repair.py: Bounds handling
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import sys
|
|
13
|
+
import torch
|
|
14
|
+
import os
|
|
15
|
+
|
|
16
|
+
# Add parent to path for imports
|
|
17
|
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
18
|
+
|
|
19
|
+
from core.problem import Problem
|
|
20
|
+
|
|
21
|
+
from operators.sampling import (
|
|
22
|
+
UniformSampling,
|
|
23
|
+
LatinHypercubeSampling,
|
|
24
|
+
NormalSampling,
|
|
25
|
+
LogUniformSampling,
|
|
26
|
+
HaltonSampling,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
from operators.selection import (
|
|
30
|
+
TournamentSelection,
|
|
31
|
+
RouletteSelection,
|
|
32
|
+
RankSelection,
|
|
33
|
+
RandomSelection,
|
|
34
|
+
TruncationSelection,
|
|
35
|
+
StochasticUniversalSampling,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
from operators.crossover import (
|
|
39
|
+
SBXCrossover,
|
|
40
|
+
BlendCrossover,
|
|
41
|
+
BinomialCrossover,
|
|
42
|
+
ExponentialCrossover,
|
|
43
|
+
UniformCrossover,
|
|
44
|
+
ArithmeticCrossover,
|
|
45
|
+
NPointCrossover,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
from operators.mutation import (
|
|
49
|
+
PolynomialMutation,
|
|
50
|
+
GaussianMutation,
|
|
51
|
+
UniformMutation,
|
|
52
|
+
NonUniformMutation,
|
|
53
|
+
BoundaryMutation,
|
|
54
|
+
NoMutation,
|
|
55
|
+
CombinedMutation,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
from operators.repair import (
|
|
59
|
+
ClipRepair,
|
|
60
|
+
ReflectRepair,
|
|
61
|
+
WrapRepair,
|
|
62
|
+
RandomRepair,
|
|
63
|
+
BoundsRepair,
|
|
64
|
+
SoftClipRepair,
|
|
65
|
+
PenaltyRepair,
|
|
66
|
+
NoRepair,
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Helper to create a test problem
|
|
71
|
+
def create_test_problem(n_var=10):
    """Build a sphere-objective ``Problem`` with ``n_var`` variables.

    Every variable is bounded to the interval [-5, 5].
    """
    lower_bound, upper_bound = -5.0, 5.0

    def sphere(x):
        # Sum of squares over the last dimension.
        squared = x ** 2
        return squared.sum(dim=-1)

    return Problem(n_var=n_var, xl=lower_bound, xu=upper_bound, objective=sphere)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def test_sampling():
    """Exercise the sampling operators on small box-bounded problems.

    Asserts output shape and bound feasibility for the uniform, Latin
    hypercube and Halton samplers, seed reproducibility for the uniform
    sampler, and strict positivity for the log-uniform sampler.  The normal
    sampler is only reported on (fraction of values inside the bounds).

    Raises:
        AssertionError: If any checked property does not hold.
    """
    print("\n" + "=" * 60)
    print("Testing sampling.py")
    print("=" * 60)

    # Single source of truth for the problem dimension used in the checks.
    n_var = 5
    problem = create_test_problem(n_var=n_var)
    n_samples = 50

    # Test UniformSampling
    print("\n1. Testing UniformSampling...")
    sampler = UniformSampling(seed=42)
    pop = sampler(n_samples, problem)
    print(f" Shape: {pop.shape}")
    assert pop.shape == (n_samples, n_var)
    assert (pop >= problem.xl).all() and (pop <= problem.xu).all()
    print(" All within bounds: ✓")

    # Test reproducibility: a second sampler with the same seed must agree.
    sampler2 = UniformSampling(seed=42)
    pop2 = sampler2(n_samples, problem)
    assert torch.allclose(pop, pop2)
    print(" Reproducibility with seed: ✓")

    # Test LatinHypercubeSampling
    print("\n2. Testing LatinHypercubeSampling...")
    lhs = LatinHypercubeSampling(smooth=True, seed=42)
    pop_lhs = lhs(n_samples, problem)
    print(f" Shape: {pop_lhs.shape}")
    assert pop_lhs.shape == (n_samples, n_var)
    assert (pop_lhs >= problem.xl).all() and (pop_lhs <= problem.xu).all()
    print(" All within bounds: ✓")

    # Test NormalSampling
    print("\n3. Testing NormalSampling...")
    normal = NormalSampling(sigma_factor=0.2, seed=42)
    pop_normal = normal(n_samples, problem)
    print(f" Shape: {pop_normal.shape}")
    # Most should be within bounds (3-sigma); reported only, not asserted.
    within = ((pop_normal >= problem.xl) & (pop_normal <= problem.xu)).float().mean()
    print(f" Fraction within bounds: {within:.1%}")

    # Test LogUniformSampling
    print("\n4. Testing LogUniformSampling...")
    # Use positive bounds for log sampling
    log_problem = Problem(n_var=n_var, xl=0.001, xu=1000.0, objective=lambda x: x.sum(-1))
    log_sampler = LogUniformSampling(base=10, seed=42)
    pop_log = log_sampler(n_samples, log_problem)
    print(f" Shape: {pop_log.shape}")
    print(f" Min: {pop_log.min().item():.4f}, Max: {pop_log.max().item():.4f}")
    assert (pop_log > 0).all()

    # Test HaltonSampling
    print("\n5. Testing HaltonSampling...")
    halton = HaltonSampling(scramble=True, seed=42)
    pop_halton = halton(n_samples, problem)
    print(f" Shape: {pop_halton.shape}")
    assert pop_halton.shape == (n_samples, n_var)
    assert (pop_halton >= problem.xl).all() and (pop_halton <= problem.xu).all()
    print(" All within bounds: ✓")

    print("\n✓ sampling.py tests passed!")
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def test_selection():
    """Exercise the parent-selection operators on a random population.

    Asserts the output shape for tournament and roulette selection, the
    shape and range of the indices returned by tournament selection, and
    that a gradient reaches the population parameter in differentiable
    tournament mode.  Rank, random, truncation and stochastic universal
    sampling are smoke-tested: they are called and their shape is printed.

    Raises:
        AssertionError: If any checked property does not hold.
    """
    print("\n" + "=" * 60)
    print("Testing selection.py")
    print("=" * 60)

    # Create test population and fitness
    n_pop = 50
    n_var = 5
    population = torch.randn(n_pop, n_var)
    fitness = torch.randn(n_pop)  # Lower is better

    n_select = 30

    # Test TournamentSelection
    print("\n1. Testing TournamentSelection...")
    tournament = TournamentSelection(tournament_size=3, replacement=True)
    selected = tournament(population, fitness, n_select)
    print(f" Selected shape: {selected.shape}")
    assert selected.shape == (n_select, n_var)

    # Test with indices
    selected, indices = tournament(population, fitness, n_select, return_indices=True)
    print(f" Indices shape: {indices.shape}")
    assert indices.shape == (n_select,)
    assert (indices >= 0).all() and (indices < n_pop).all()
    print(" Index range valid: ✓")

    # Test differentiable mode
    print("\n2. Testing TournamentSelection (differentiable)...")
    tournament_diff = TournamentSelection(
        tournament_size=3,
        differentiable=True,
        temperature=1.0,
        learn_temperature=True,
    )
    pop_param = torch.nn.Parameter(population.clone())
    selected_diff = tournament_diff(pop_param, fitness, n_select)
    print(f" Differentiable selection shape: {selected_diff.shape}")

    # Check gradients flow
    loss = selected_diff.sum()
    loss.backward()
    assert pop_param.grad is not None
    print(" Gradients flow: ✓")

    # Test RouletteSelection
    print("\n3. Testing RouletteSelection...")
    roulette = RouletteSelection()
    selected_roulette = roulette(population, fitness, n_select)
    print(f" Selected shape: {selected_roulette.shape}")
    assert selected_roulette.shape == (n_select, n_var)

    # Test RankSelection
    print("\n4. Testing RankSelection...")
    rank_sel = RankSelection(scheme='linear', selection_pressure=1.5)
    selected_rank = rank_sel(population, fitness, n_select)
    print(f" Selected shape: {selected_rank.shape}")

    # Smoke test only: the call must succeed; the result is not inspected.
    rank_exp = RankSelection(scheme='exponential')
    rank_exp(population, fitness, n_select)
    print(" Exponential scheme works: ✓")

    # Test RandomSelection
    print("\n5. Testing RandomSelection...")
    random_sel = RandomSelection(replacement=True)
    selected_random = random_sel(population, fitness, n_select)
    print(f" Selected shape: {selected_random.shape}")

    # Test TruncationSelection
    print("\n6. Testing TruncationSelection...")
    truncation = TruncationSelection(truncation_ratio=0.5)
    selected_trunc = truncation(population, fitness, n_select)
    print(f" Selected shape: {selected_trunc.shape}")

    # Test StochasticUniversalSampling
    print("\n7. Testing StochasticUniversalSampling...")
    sus = StochasticUniversalSampling()
    selected_sus = sus(population, fitness, n_select)
    print(f" Selected shape: {selected_sus.shape}")

    print("\n✓ selection.py tests passed!")
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def test_crossover():
    """Exercise the crossover operators on random parent pairs.

    Asserts the offspring shape for SBX, that a gradient reaches the first
    parent in differentiable SBX and differentiable binomial crossover, and
    that whole arithmetic crossover with ``alpha=0.5`` equals the mean of
    the two parents.  Blend, binomial, exponential, uniform and n-point
    crossover are smoke-tested: they are called and their shape is printed.

    Raises:
        AssertionError: If any checked property does not hold.
    """
    print("\n" + "=" * 60)
    print("Testing crossover.py")
    print("=" * 60)

    n_pairs = 25
    n_var = 10

    # Create parent pairs
    parent1 = torch.randn(n_pairs, n_var)
    parent2 = torch.randn(n_pairs, n_var)

    # Test SBXCrossover
    print("\n1. Testing SBXCrossover...")
    sbx = SBXCrossover(eta=15, prob=0.9)
    offspring = sbx(parent1, parent2)
    print(f" Offspring shape: {offspring.shape}")
    assert offspring.shape == (n_pairs, n_var)

    # Test differentiable mode
    print("\n2. Testing SBXCrossover (differentiable)...")
    sbx_diff = SBXCrossover(
        eta=15,
        prob=0.9,
        differentiable=True,
        learn_eta=True,
        learn_prob=True,
    )
    p1 = torch.nn.Parameter(parent1.clone())
    p2 = torch.nn.Parameter(parent2.clone())
    offspring_diff = sbx_diff(p1, p2)

    loss = offspring_diff.sum()
    loss.backward()
    assert p1.grad is not None
    print(" Gradients flow through SBX: ✓")
    # NOTE(review): this prints the raw `_log_eta` parameter under an "eta"
    # label — presumably a log-space value; confirm whether it should be
    # exponentiated before display.
    print(f" Learnable eta: {sbx_diff._log_eta.item():.4f}")

    # Test BlendCrossover
    print("\n3. Testing BlendCrossover...")
    blend = BlendCrossover(alpha=0.5)
    offspring_blend = blend(parent1, parent2)
    print(f" Offspring shape: {offspring_blend.shape}")

    # Test BinomialCrossover (DE-style)
    print("\n4. Testing BinomialCrossover...")
    binomial = BinomialCrossover(cr=0.9)
    offspring_bin = binomial(parent1, parent2)  # parent1=target, parent2=donor
    print(f" Offspring shape: {offspring_bin.shape}")

    # Test differentiable binomial (fresh parameter so .grad starts as None)
    binomial_diff = BinomialCrossover(cr=0.9, differentiable=True, learn_cr=True)
    p1 = torch.nn.Parameter(parent1.clone())
    offspring_bin_diff = binomial_diff(p1, parent2)
    loss = offspring_bin_diff.sum()
    loss.backward()
    assert p1.grad is not None
    print(" Gradients flow through binomial: ✓")

    # Test ExponentialCrossover
    print("\n5. Testing ExponentialCrossover...")
    exponential = ExponentialCrossover(cr=0.9)
    offspring_exp = exponential(parent1, parent2)
    print(f" Offspring shape: {offspring_exp.shape}")

    # Test UniformCrossover
    print("\n6. Testing UniformCrossover...")
    uniform = UniformCrossover(prob=0.9)
    offspring_unif = uniform(parent1, parent2)
    print(f" Offspring shape: {offspring_unif.shape}")

    # Test ArithmeticCrossover
    print("\n7. Testing ArithmeticCrossover...")
    arithmetic = ArithmeticCrossover(alpha=0.5, whole=True)
    offspring_arith = arithmetic(parent1, parent2)
    print(f" Offspring shape: {offspring_arith.shape}")

    # Verify whole arithmetic is weighted average
    expected = 0.5 * parent1 + 0.5 * parent2
    assert torch.allclose(offspring_arith, expected)
    print(" Weighted average verified: ✓")

    # Test NPointCrossover
    print("\n8. Testing NPointCrossover...")
    npoint = NPointCrossover(n_points=2)
    offspring_npoint = npoint(parent1, parent2)
    print(f" Offspring shape: {offspring_npoint.shape}")

    print("\n✓ crossover.py tests passed!")
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def test_mutation():
    """Exercise the mutation operators on a random population in [-5, 5].

    Asserts that polynomial mutation preserves the population shape, that a
    gradient reaches the population parameter in differentiable polynomial
    and Gaussian mutation, and that ``NoMutation`` returns the population
    unchanged.  The remaining operators are smoke-tested: they are called
    and their output shape (or a short diagnostic) is printed.

    Raises:
        AssertionError: If any checked property does not hold.
    """
    print("\n" + "=" * 60)
    print("Testing mutation.py")
    print("=" * 60)

    n_pop = 50
    n_var = 10

    population = torch.randn(n_pop, n_var)
    xl = torch.full((n_var,), -5.0)
    xu = torch.full((n_var,), 5.0)

    # Test PolynomialMutation
    print("\n1. Testing PolynomialMutation...")
    poly_mut = PolynomialMutation(eta=20, prob=0.1)
    mutated = poly_mut(population, xl, xu)
    print(f" Mutated shape: {mutated.shape}")
    assert mutated.shape == population.shape

    # Check some genes changed (count of columns with at least one change)
    changed = (mutated != population).any(dim=0).sum()
    print(f" Genes with changes: {changed.item()}/{n_var}")

    # Test differentiable mode
    print("\n2. Testing PolynomialMutation (differentiable)...")
    poly_diff = PolynomialMutation(
        eta=20,
        prob=0.1,
        differentiable=True,
        learn_eta=True,
        learn_prob=True,
    )
    pop_param = torch.nn.Parameter(population.clone())
    mutated_diff = poly_diff(pop_param, xl, xu)

    loss = mutated_diff.sum()
    loss.backward()
    assert pop_param.grad is not None
    print(" Gradients flow through polynomial mutation: ✓")

    # Test GaussianMutation
    print("\n3. Testing GaussianMutation...")
    gauss_mut = GaussianMutation(sigma=0.1, prob=0.2)
    mutated_gauss = gauss_mut(population, xl, xu)
    print(f" Mutated shape: {mutated_gauss.shape}")

    # Test with sigma_frac (smoke test only: the call must succeed)
    gauss_frac = GaussianMutation(sigma_frac=0.05)
    gauss_frac(population, xl, xu)
    print(" With sigma_frac: ✓")

    # Test differentiable Gaussian (fresh parameter so .grad starts as None)
    gauss_diff = GaussianMutation(sigma=0.1, differentiable=True, learn_sigma=True)
    pop_param = torch.nn.Parameter(population.clone())
    mutated_gauss_diff = gauss_diff(pop_param, xl, xu)
    loss = mutated_gauss_diff.sum()
    loss.backward()
    assert pop_param.grad is not None
    print(" Gradients flow through Gaussian mutation: ✓")

    # Test UniformMutation
    print("\n4. Testing UniformMutation...")
    unif_mut = UniformMutation(prob=0.1)
    mutated_unif = unif_mut(population, xl, xu)
    print(f" Mutated shape: {mutated_unif.shape}")

    # Test NonUniformMutation
    print("\n5. Testing NonUniformMutation...")
    nonunif_mut = NonUniformMutation(max_generations=100, b=5.0)
    nonunif_mut.set_generation(50)
    mutated_nonunif = nonunif_mut(population, xl, xu)
    print(f" Mutated shape: {mutated_nonunif.shape}")
    print(f" Current generation: {nonunif_mut.generation}")

    # Test BoundaryMutation
    print("\n6. Testing BoundaryMutation...")
    boundary_mut = BoundaryMutation(prob=0.1)
    mutated_boundary = boundary_mut(population, xl, xu)
    print(f" Mutated shape: {mutated_boundary.shape}")

    # Check boundary values (reported only, not asserted)
    at_lower = (mutated_boundary == xl).any()
    at_upper = (mutated_boundary == xu).any()
    print(f" Has values at lower bound: {at_lower}")
    print(f" Has values at upper bound: {at_upper}")

    # Test NoMutation
    print("\n7. Testing NoMutation...")
    no_mut = NoMutation()
    mutated_none = no_mut(population, xl, xu)
    assert torch.allclose(mutated_none, population)
    print(" NoMutation returns unchanged: ✓")

    # Test CombinedMutation
    print("\n8. Testing CombinedMutation...")
    combined = CombinedMutation([
        GaussianMutation(sigma=0.05, prob=0.5),
        PolynomialMutation(eta=20, prob=0.1),
    ])
    mutated_combined = combined(population, xl, xu)
    print(f" Combined mutation shape: {mutated_combined.shape}")

    print("\n✓ mutation.py tests passed!")
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def test_repair():
    """Exercise the repair operators on a population that violates [0, 1].

    Asserts that clip, reflect, wrap and random repair (directly and through
    ``BoundsRepair``) leave no out-of-bounds values, that 1.3 reflects to
    about 0.7 and wraps to about 0.3 on [0, 1], that a gradient reaches the
    population parameter through ``SoftClipRepair``, that ``PenaltyRepair``
    and ``NoRepair`` return the population unchanged, and that
    ``is_within_bounds`` reports the clipped population as fully feasible.

    Raises:
        AssertionError: If any checked property does not hold.
    """
    print("\n" + "=" * 60)
    print("Testing repair.py")
    print("=" * 60)

    n_pop = 50
    n_var = 5

    xl = torch.zeros(n_var)
    xu = torch.ones(n_var)

    # Create population with violations
    population = torch.randn(n_pop, n_var) * 2  # Some will be outside [0, 1]

    n_violations_before = ((population < xl) | (population > xu)).sum().item()
    print(f"\n Violations before repair: {n_violations_before}")

    # Test ClipRepair
    print("\n1. Testing ClipRepair...")
    clip = ClipRepair()
    repaired_clip = clip(population, xl, xu)
    violations_clip = ((repaired_clip < xl) | (repaired_clip > xu)).sum().item()
    print(f" Violations after clip: {violations_clip}")
    assert violations_clip == 0
    assert (repaired_clip >= xl).all() and (repaired_clip <= xu).all()
    print(" All within bounds: ✓")

    # Test ReflectRepair
    print("\n2. Testing ReflectRepair...")
    reflect = ReflectRepair()
    repaired_reflect = reflect(population, xl, xu)
    violations_reflect = ((repaired_reflect < xl) | (repaired_reflect > xu)).sum().item()
    print(f" Violations after reflect: {violations_reflect}")
    assert violations_reflect == 0
    print(" All within bounds: ✓")

    # Test that reflection preserves "momentum"
    x_test = torch.tensor([[1.3]])  # 0.3 above upper bound of 1
    xl_test = torch.tensor([0.0])
    xu_test = torch.tensor([1.0])
    reflected = reflect(x_test, xl_test, xu_test)
    print(f" 1.3 reflects to: {reflected.item():.2f} (expected ~0.7)")
    assert abs(reflected.item() - 0.7) < 0.01

    # Test WrapRepair
    print("\n3. Testing WrapRepair...")
    wrap = WrapRepair()
    repaired_wrap = wrap(population, xl, xu)
    violations_wrap = ((repaired_wrap < xl) | (repaired_wrap > xu)).sum().item()
    print(f" Violations after wrap: {violations_wrap}")
    assert violations_wrap == 0

    # Test wrapping behaviour (reuses the [0, 1] test bounds from above)
    x_test = torch.tensor([[1.3]])
    wrapped = wrap(x_test, xl_test, xu_test)
    print(f" 1.3 wraps to: {wrapped.item():.2f} (expected ~0.3)")
    assert abs(wrapped.item() - 0.3) < 0.01

    # Test RandomRepair
    print("\n4. Testing RandomRepair...")
    random_repair = RandomRepair()
    repaired_random = random_repair(population, xl, xu)
    violations_random = ((repaired_random < xl) | (repaired_random > xu)).sum().item()
    print(f" Violations after random: {violations_random}")
    assert violations_random == 0

    # Test BoundsRepair (configurable)
    print("\n5. Testing BoundsRepair...")
    for method in ['clip', 'reflect', 'wrap', 'random']:
        bounds_repair = BoundsRepair(method=method)
        repaired = bounds_repair(population, xl, xu)
        violations = ((repaired < xl) | (repaired > xu)).sum().item()
        assert violations == 0
        print(f" Method '{method}': ✓")

    # Test SoftClipRepair
    print("\n6. Testing SoftClipRepair...")
    soft_clip = SoftClipRepair(beta=10.0)
    repaired_soft = soft_clip(population, xl, xu)
    print(f" Soft clip shape: {repaired_soft.shape}")

    # Check gradients flow
    pop_param = torch.nn.Parameter(population.clone())
    soft_repaired = soft_clip(pop_param, xl, xu)
    loss = soft_repaired.sum()
    loss.backward()
    assert pop_param.grad is not None
    print(" Gradients flow through soft clip: ✓")

    # Test PenaltyRepair
    print("\n7. Testing PenaltyRepair...")
    penalty_repair = PenaltyRepair(penalty_weight=100.0, power=2.0)

    # Compute penalty
    penalty = penalty_repair.compute_penalty(population, xl, xu)
    print(f" Penalty shape: {penalty.shape}")
    print(f" Mean penalty: {penalty.mean().item():.2f}")

    # Verify no repair happens
    repaired_penalty = penalty_repair(population, xl, xu)
    assert torch.allclose(repaired_penalty, population)
    print(" PenaltyRepair returns unchanged: ✓")

    # Test NoRepair
    print("\n8. Testing NoRepair...")
    no_repair = NoRepair()
    repaired_none = no_repair(population, xl, xu)
    assert torch.allclose(repaired_none, population)
    print(" NoRepair returns unchanged: ✓")

    # Test is_within_bounds helper
    print("\n9. Testing is_within_bounds...")
    feasible = clip.is_within_bounds(repaired_clip, xl, xu)
    assert feasible.all()

    # Feasibility mask of the unrepaired population (reported, not asserted).
    feasible_before = clip.is_within_bounds(population, xl, xu)
    print(f" Feasible before repair: {feasible_before.sum().item()}/{n_pop}")

    print("\n✓ repair.py tests passed!")
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
def test_operator_integration():
    """Run the operators end to end, first plainly and then differentiably.

    The first pipeline chains sampling, evaluation, tournament selection,
    SBX crossover, polynomial mutation and reflect repair, then asserts that
    every offspring lies within the problem bounds.  The second pipeline
    chains differentiable selection, crossover and Gaussian mutation, then
    asserts that gradients reach the population parameter and the learnable
    ``_log_eta`` / ``_log_sigma`` operator parameters.

    Raises:
        AssertionError: If any checked property does not hold.
    """
    print("\n" + "=" * 60)
    print("Testing Operator Integration")
    print("=" * 60)

    # Create problem
    problem = create_test_problem(n_var=10)

    # Population size and the split point used to form parent pairs.
    n_pop = 50
    n_pairs = n_pop // 2

    print("\n1. Testing full GA-style pipeline...")

    # Sampling
    sampling = UniformSampling(seed=42)
    population = sampling(n_pop, problem)
    print(f" Initial population: {population.shape}")

    # Evaluate
    fitness = problem.evaluate(population)
    print(f" Fitness: {fitness.shape}")

    # Selection
    selection = TournamentSelection(tournament_size=3)
    parents = selection(population, fitness, n_pop)
    print(f" Selected parents: {parents.shape}")

    # Crossover: first half of the parents is paired with the second half
    crossover = SBXCrossover(eta=15, prob=0.9)
    p1, p2 = parents[:n_pairs], parents[n_pairs:]
    offspring = crossover(p1, p2)
    print(f" Offspring after crossover: {offspring.shape}")

    # Mutation
    mutation = PolynomialMutation(eta=20, prob=0.1)
    offspring = mutation(offspring, problem.xl, problem.xu)
    print(f" Offspring after mutation: {offspring.shape}")

    # Repair
    repair = ReflectRepair()
    offspring = repair(offspring, problem.xl, problem.xu)
    print(f" Offspring after repair: {offspring.shape}")

    # Verify bounds
    assert (offspring >= problem.xl).all() and (offspring <= problem.xu).all()
    print(" All offspring within bounds: ✓")

    print("\n2. Testing differentiable pipeline...")

    # Create differentiable operators
    selection_diff = TournamentSelection(
        tournament_size=3,
        differentiable=True,
        temperature=1.0,
    )
    crossover_diff = SBXCrossover(
        eta=15,
        prob=0.9,
        differentiable=True,
        learn_eta=True,
    )
    mutation_diff = GaussianMutation(
        sigma=0.1,
        differentiable=True,
        learn_sigma=True,
    )

    # Run through pipeline with gradient tracking
    pop_param = torch.nn.Parameter(population.clone())

    parents_diff = selection_diff(pop_param, fitness, n_pop)
    p1_diff, p2_diff = parents_diff[:n_pairs], parents_diff[n_pairs:]
    offspring_diff = crossover_diff(p1_diff, p2_diff)
    offspring_diff = mutation_diff(offspring_diff, problem.xl, problem.xu)

    # Compute loss and backprop
    loss = offspring_diff.sum()
    loss.backward()

    assert pop_param.grad is not None
    print(" Gradients flow through entire pipeline: ✓")

    # Check learnable parameters have gradients
    assert crossover_diff._log_eta.grad is not None
    print(f" Crossover eta gradient: {crossover_diff._log_eta.grad.item():.6f}")

    assert mutation_diff._log_sigma.grad is not None
    print(f" Mutation sigma gradient: {mutation_diff._log_sigma.grad.item():.6f}")

    print("\n✓ Operator integration tests passed!")
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
def run_all_tests():
    """Run every operator test group in order and report the outcome.

    Returns:
        bool: ``True`` when all groups finish without raising; ``False`` as
        soon as one raises, after printing the failure and its traceback.
    """
    print("\n" + "#" * 60)
    print("# EvoGrad Operators Module Tests")
    print("#" * 60)

    try:
        test_sampling()
        test_selection()
        test_crossover()
        test_mutation()
        test_repair()
        test_operator_integration()
    except Exception as e:
        # Top-level boundary of the script: report and signal failure rather
        # than propagate.  The exception type is included because the tests
        # use bare asserts, whose AssertionError carries no message.
        print(f"\n✗ TEST FAILED: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return False
    else:
        print("\n" + "=" * 60)
        print("✓ ALL OPERATORS TESTS PASSED!")
        print("=" * 60)
        return True
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
if __name__ == "__main__":
    # Script entry point: the exit status mirrors the overall test result
    # (0 when everything passed, 1 otherwise).
    success = run_all_tests()
    exit_status = 0 if success else 1
    sys.exit(exit_status)
|