evograd-diff 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evograd/__init__.py +67 -0
- evograd/algorithms/__init__.py +138 -0
- evograd/algorithms/cmaes.py +1365 -0
- evograd/algorithms/de.py +895 -0
- evograd/algorithms/ga.py +532 -0
- evograd/algorithms/pso.py +648 -0
- evograd/algorithms/shade.py +1165 -0
- evograd/benchmarks/functions/__init__.py +229 -0
- evograd/benchmarks/functions/base.py +217 -0
- evograd/benchmarks/functions/cec2017/__init__.py +250 -0
- evograd/benchmarks/functions/cec2017/basic.py +413 -0
- evograd/benchmarks/functions/cec2017/composition.py +580 -0
- evograd/benchmarks/functions/cec2017/data.pkl +0 -0
- evograd/benchmarks/functions/cec2017/data.py +350 -0
- evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
- evograd/benchmarks/functions/cec2017/simple.py +326 -0
- evograd/benchmarks/functions/classical.py +649 -0
- evograd/benchmarks/functions/smoothed_funnel.py +476 -0
- evograd/benchmarks/functions/transforms.py +463 -0
- evograd/benchmarks/run_benchmark_functions.py +1208 -0
- evograd/core/__init__.py +73 -0
- evograd/core/algorithm.py +778 -0
- evograd/core/maximize.py +269 -0
- evograd/core/minimize.py +740 -0
- evograd/core/problem.py +444 -0
- evograd/core/result.py +571 -0
- evograd/core/termination.py +602 -0
- evograd/operators/__init__.py +178 -0
- evograd/operators/crossover.py +1117 -0
- evograd/operators/mutation.py +1098 -0
- evograd/operators/relaxations.py +175 -0
- evograd/operators/repair.py +601 -0
- evograd/operators/sampling.py +577 -0
- evograd/operators/selection.py +981 -0
- evograd/operators/survival.py +1000 -0
- evograd/tests/__init__.py +11 -0
- evograd/tests/run_all.py +78 -0
- evograd/tests/test_core.py +528 -0
- evograd/tests/test_ga.py +572 -0
- evograd/tests/test_operators.py +662 -0
- evograd/tests/test_per_individual.py +326 -0
- evograd/tests/test_utils.py +328 -0
- evograd/utils/__init__.py +97 -0
- evograd/utils/callbacks.py +926 -0
- evograd/utils/device.py +502 -0
- evograd/utils/duplicates.py +421 -0
- evograd_diff-0.1.0.dist-info/METADATA +439 -0
- evograd_diff-0.1.0.dist-info/RECORD +50 -0
- evograd_diff-0.1.0.dist-info/WHEEL +4 -0
- evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
evograd/tests/test_ga.py
ADDED
|
@@ -0,0 +1,572 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Test script for EvoGrad GA (Genetic Algorithm) implementation.
|
|
3
|
+
|
|
4
|
+
Tests:
|
|
5
|
+
- GA creation with default and custom operators
|
|
6
|
+
- Different survival strategies (plus, comma, replace_worst)
|
|
7
|
+
- Elitism behavior
|
|
8
|
+
- Classical and differentiable modes
|
|
9
|
+
- State persistence (save/load)
|
|
10
|
+
- Convergence on test functions
|
|
11
|
+
|
|
12
|
+
Usage:
|
|
13
|
+
cd evograd && python tests/test_ga.py
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import sys
|
|
17
|
+
import os
|
|
18
|
+
import tempfile
|
|
19
|
+
import torch
|
|
20
|
+
|
|
21
|
+
# Add parent directory to path for imports
|
|
22
|
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
23
|
+
|
|
24
|
+
from evograd.core.problem import Problem
|
|
25
|
+
from evograd.algorithms.ga import GA, ga_default, ga_steady_state, ga_comma
|
|
26
|
+
from evograd.operators.sampling import UniformSampling
|
|
27
|
+
from evograd.operators.selection import TournamentSelection, RouletteSelection
|
|
28
|
+
from evograd.operators.crossover import SBXCrossover, BlendCrossover
|
|
29
|
+
from evograd.operators.mutation import PolynomialMutation, GaussianMutation
|
|
30
|
+
from evograd.operators.repair import ReflectRepair
|
|
31
|
+
from evograd.operators.survival import (
|
|
32
|
+
MergeSurvival,
|
|
33
|
+
CommaSurvival,
|
|
34
|
+
ReplaceWorstSurvival,
|
|
35
|
+
FitnessSurvival,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# =============================================================================
|
|
40
|
+
# Test Functions
|
|
41
|
+
# =============================================================================
|
|
42
|
+
|
|
43
|
+
def sphere(x):
    """Sphere benchmark: sum of squared coordinates along the last axis.

    The minimum value 0 is reached at the origin.
    """
    squared = x.pow(2)
    return squared.sum(dim=-1)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def rastrigin(x):
    """Rastrigin benchmark (highly multimodal), reduced over the last axis."""
    amplitude = 10.0
    n_dims = x.shape[-1]
    # Cosine ripple that creates the regularly spaced local optima.
    ripple = amplitude * torch.cos(2 * torch.pi * x)
    per_coordinate = x**2 - ripple
    return amplitude * n_dims + per_coordinate.sum(dim=-1)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def rosenbrock(x):
    """Rosenbrock benchmark (narrow curved valley), reduced over the last axis."""
    current = x[..., :-1]
    following = x[..., 1:]
    valley_term = 100 * (following - current**2)**2
    offset_term = (1 - current)**2
    return (valley_term + offset_term).sum(dim=-1)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# =============================================================================
|
|
61
|
+
# Tests
|
|
62
|
+
# =============================================================================
|
|
63
|
+
|
|
64
|
+
def test_ga_creation():
    """Test GA construction with default operators, custom operators, and factory functions.

    Covers three scenarios:
    1. ``GA(pop_size=20)`` must expose non-``None`` selection, crossover,
       mutation and survival operators.
    2. A GA built from explicitly supplied operators must keep the elitism
       settings of the survival operator it was given.
    3. ``ga_steady_state`` and ``ga_comma`` must produce a GA whose survival
       operator is a ``ReplaceWorstSurvival`` / ``CommaSurvival`` respectively.
    """
    print("\n" + "="*60)
    print("Testing GA Creation")
    print("="*60)

    # Test with all defaults
    print("\n1. Testing GA with default operators...")
    ga = GA(pop_size=20)
    print(f" Created: {ga}")
    print(f" Selection: {ga.selection}")
    print(f" Crossover: {ga.crossover}")
    print(f" Mutation: {ga.mutation}")
    print(f" Survival: {ga.survival}")
    assert ga.pop_size == 20
    assert ga.selection is not None
    assert ga.crossover is not None
    assert ga.mutation is not None
    assert ga.survival is not None

    # Test with custom operators
    print("\n2. Testing GA with custom operators...")
    ga = GA(
        pop_size=30,
        selection=RouletteSelection(differentiable=True),
        crossover=BlendCrossover(alpha=0.5, differentiable=True),
        mutation=GaussianMutation(sigma=0.1, differentiable=True),
        survival=MergeSurvival(elitism=True, n_elite=2, differentiable=True),
    )
    print(f" Created: {ga}")
    print(f" Survival: {ga.survival}")
    assert ga.survival.elitism == True
    assert ga.survival.n_elite == 2

    # Test factory functions
    print("\n3. Testing factory functions...")
    ga1 = ga_default(pop_size=50)
    print(f" ga_default: {ga1}")

    ga2 = ga_steady_state(pop_size=50, n_offsprings=2)
    print(f" ga_steady_state: {ga2}")
    print(f" Survival type: {type(ga2.survival).__name__}")
    # Check against the imported class rather than its name string, so the
    # test follows the class through renames and accepts subclasses.
    assert isinstance(ga2.survival, ReplaceWorstSurvival), \
        f"Expected ReplaceWorstSurvival, got {type(ga2.survival).__name__}"

    ga3 = ga_comma(pop_size=30, n_offsprings=60)
    print(f" ga_comma: {ga3}")
    print(f" Survival type: {type(ga3.survival).__name__}")
    assert isinstance(ga3.survival, CommaSurvival), \
        f"Expected CommaSurvival, got {type(ga3.survival).__name__}"

    print("\n✓ GA creation tests passed!")
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def test_ga_initialization():
    """Check the state of a GA right after ``initialize`` on a sphere problem.

    Asserts the population and fitness shapes, that the generation counter
    starts at 0, and that the evaluation counter equals the population size.
    """
    rule = "="*60
    print("\n" + rule)
    print("Testing GA Initialization")
    print(rule)

    pop_size, n_var = 20, 10
    problem = Problem(objective=sphere, n_var=n_var, xl=-5.0, xu=5.0)
    ga = GA(pop_size=pop_size, sampling=UniformSampling(seed=42), differentiable=True)

    print("\n1. Testing initialization...")
    ga.initialize(problem)
    print(f" Population shape: {ga.population.shape}")
    print(f" Fitness shape: {ga.fitness.shape}")
    print(f" Best fitness: {ga.best_fitness:.4f}")
    print(f" Best solution shape: {ga.best_solution.shape}")

    assert ga.population.shape == (pop_size, n_var)
    assert ga.fitness.shape == (pop_size,)
    assert ga.generation == 0
    # Only the initial population has been evaluated at this point.
    assert ga.n_evals == pop_size

    print("\n✓ GA initialization tests passed!")
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def test_ga_step():
    """Check the generation and evaluation counters while stepping a GA.

    Runs one step, then nine more, on a 10-variable sphere problem. The
    fitness improvement is printed for inspection only; it is not asserted.
    """
    rule = "="*60
    print("\n" + rule)
    print("Testing GA Step")
    print(rule)

    pop_size = 30
    problem = Problem(objective=sphere, n_var=10, xl=-5.0, xu=5.0)

    # Operators are built in the same order as the keyword arguments below.
    sampling = UniformSampling(seed=42)
    selection = TournamentSelection(tournament_size=3, differentiable=True)
    crossover = SBXCrossover(eta=15, prob=0.9, differentiable=True)
    mutation = PolynomialMutation(eta=20, differentiable=True)
    survival = MergeSurvival(n_survive=30, elitism=True, n_elite=1, differentiable=True)
    ga = GA(
        pop_size=pop_size,
        sampling=sampling,
        selection=selection,
        crossover=crossover,
        mutation=mutation,
        survival=survival,
        differentiable=True,
    )
    ga.initialize(problem)

    print("\n1. Testing single step...")
    initial_best = ga.best_fitness
    ga.step()
    print(f" Generation: {ga.generation}")
    print(f" Initial best: {initial_best:.4f}")
    print(f" After 1 step: {ga.best_fitness:.4f}")
    assert ga.generation == 1
    # Initial population plus one generation of evaluations.
    assert ga.n_evals == 2 * pop_size

    print("\n2. Testing multiple steps...")
    remaining_steps = 9
    while remaining_steps > 0:
        ga.step()
        remaining_steps -= 1
    print(f" Generation: {ga.generation}")
    print(f" Best fitness after 10 steps: {ga.best_fitness:.4f}")
    assert ga.generation == 10

    # Reported for inspection; the sphere is expected to be easy to improve on.
    improvement = initial_best - ga.best_fitness
    print(f" Improvement: {improvement:.4f}")

    print("\n✓ GA step tests passed!")
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def test_survival_strategies():
    """Smoke-test four survival operators on a 5-variable sphere problem.

    Each GA is initialized and stepped for a fixed number of generations and
    its best fitness is printed. Nothing is asserted: the test passes when no
    configuration raises.
    """
    rule = "="*60
    print("\n" + rule)
    print("Testing Survival Strategies")
    print(rule)

    problem = Problem(objective=sphere, n_var=5, xl=-5.0, xu=5.0)

    def evolve(algorithm, n_generations):
        """Initialize ``algorithm`` on the shared problem and step it ``n_generations`` times."""
        algorithm.initialize(problem)
        for _ in range(n_generations):
            algorithm.step()

    # (mu + lambda) - MergeSurvival
    print("\n1. Testing MergeSurvival (mu+lambda)...")
    ga_plus = GA(
        pop_size=20,
        survival=MergeSurvival(n_survive=20, elitism=True, n_elite=1),
        seed=42,
    )
    evolve(ga_plus, 5)
    print(f" Best fitness after 5 gens: {ga_plus.best_fitness:.4f}")

    # (mu, lambda) - CommaSurvival; n_offsprings must be >= pop_size
    print("\n2. Testing CommaSurvival (mu,lambda)...")
    ga_comma_inst = GA(
        pop_size=20,
        n_offsprings=40,
        survival=CommaSurvival(n_survive=20, elitism=True, n_elite=1),
        seed=42,
    )
    evolve(ga_comma_inst, 5)
    print(f" Best fitness after 5 gens: {ga_comma_inst.best_fitness:.4f}")

    # Steady-state - ReplaceWorstSurvival; more generations since fewer
    # offspring are produced per generation
    print("\n3. Testing ReplaceWorstSurvival (steady-state)...")
    ga_replace = GA(
        pop_size=20,
        n_offsprings=5,
        survival=ReplaceWorstSurvival(n_survive=20, elitism=True, n_elite=1),
        seed=42,
    )
    evolve(ga_replace, 20)
    print(f" Best fitness after 20 gens: {ga_replace.best_fitness:.4f}")

    # Pure truncation - FitnessSurvival (no elitism)
    print("\n4. Testing FitnessSurvival (pure truncation)...")
    ga_fitness = GA(
        pop_size=20,
        survival=FitnessSurvival(n_survive=20),
        seed=42,
    )
    evolve(ga_fitness, 5)
    print(f" Best fitness after 5 gens: {ga_fitness.best_fitness:.4f}")

    print("\n✓ Survival strategy tests passed!")
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def test_elitism():
    """Test elitism behavior.

    Part 1 runs a GA whose ``MergeSurvival`` keeps one elite and asserts that
    the recorded best fitness never gets worse from one generation to the
    next (minimization, with a 1e-8 tolerance).

    Part 2 runs a GA with ``FitnessSurvival`` (no elitism) as a smoke test;
    its final best fitness is printed but not asserted.
    """
    print("\n" + "="*60)
    print("Testing Elitism")
    print("="*60)

    problem = Problem(
        objective=sphere,
        n_var=5,
        xl=-5.0,
        xu=5.0,
    )

    # GA with elitism
    print("\n1. Testing GA with elitism...")
    ga_elite = GA(
        pop_size=20,
        survival=MergeSurvival(n_survive=20, elitism=True, n_elite=1),
        seed=42,
    )
    ga_elite.initialize(problem)

    # best_values[0] is the value after initialization; one more entry is
    # appended per generation.
    best_values = [ga_elite.best_fitness]
    for _ in range(10):
        ga_elite.step()
        best_values.append(ga_elite.best_fitness)

    # With elitism, best should never increase (for minimization)
    for previous, current in zip(best_values, best_values[1:]):
        assert current <= previous + 1e-8, \
            f"Elitism violated: {current} > {previous}"
    # The list also holds the initial value, so the number of generations
    # that were run is one less than its length.
    print(f" Best never increased over {len(best_values) - 1} generations ✓")

    # GA without elitism (best can worsen)
    print("\n2. Testing GA without elitism...")
    ga_no_elite = GA(
        pop_size=20,
        survival=FitnessSurvival(n_survive=20),  # No elitism
        seed=42,
    )
    ga_no_elite.initialize(problem)

    for _ in range(10):
        ga_no_elite.step()
    print(f" Final best: {ga_no_elite.best_fitness:.4f}")

    print("\n✓ Elitism tests passed!")
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def test_differentiable_mode():
    """Exercise the forward / backward / optimizer cycle of a differentiable GA.

    Asserts only that the loss returned by ``forward`` requires gradients.
    The number of parameters that received a non-zero gradient is printed,
    not asserted.
    """
    rule = "="*60
    print("\n" + rule)
    print("Testing Differentiable Mode")
    print(rule)

    problem = Problem(objective=sphere, n_var=5, xl=-5.0, xu=5.0)
    ga = GA(pop_size=20, differentiable=True, seed=42)
    ga.initialize(problem)

    # Optimizer over the GA's learnable parameters
    optimizer = torch.optim.Adam(ga.parameters(), lr=0.01)

    print("\n1. Testing forward pass...")
    loss = ga.forward()
    print(f" Loss (best fitness): {loss.item():.4f}")
    assert loss.requires_grad, "Loss should require gradients"

    print("\n2. Testing backward pass...")
    optimizer.zero_grad()
    loss.backward()

    # Count the parameters whose gradient exists and is not all zeros
    n_grads = sum(
        1
        for _name, param in ga.named_parameters()
        if param.grad is not None and param.grad.abs().sum() > 0
    )
    print(f" Parameters with gradients: {n_grads}")

    print("\n3. Testing optimizer step...")
    optimizer.step()
    ga.update_state()

    print(f" Generation after update: {ga.generation}")

    print("\n4. Testing multiple differentiable iterations...")
    for _ in range(5):
        optimizer.zero_grad()
        loss = ga.forward()
        loss.backward()
        optimizer.step()
        ga.update_state()
    print(f" Final best fitness: {ga.best_fitness:.4f}")

    print("\n✓ Differentiable mode tests passed!")
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def test_state_persistence():
    """Test state save and load, both in memory and through a file.

    Steps:
    1. Run a GA for 10 generations and snapshot it with ``state_dict()``.
    2. Load the snapshot into a second, differently seeded GA and check that
       the generation counter and best fitness match the original.
    3. Step the restored GA five more times and check that its generation
       counter advanced by five from the saved value.
    4. Round-trip the state through ``torch.save`` / ``torch.load`` into a
       third GA and apply the same generation and best-fitness checks as the
       in-memory path.
    """
    print("\n" + "="*60)
    print("Testing State Persistence")
    print("="*60)

    problem = Problem(
        objective=sphere,
        n_var=5,
        xl=-5.0,
        xu=5.0,
    )

    def make_ga(**extra):
        """Build a GA with the operator set shared by every instance in this test."""
        return GA(
            pop_size=20,
            sampling=UniformSampling(seed=42),
            selection=TournamentSelection(tournament_size=3),
            crossover=SBXCrossover(eta=15, prob=0.9),
            mutation=PolynomialMutation(eta=20),
            **extra,
        )

    # Create and run GA
    ga1 = make_ga(seed=42)
    ga1.initialize(problem)

    for _ in range(10):
        ga1.step()

    print("\n1. Original GA state:")
    print(f" Generation: {ga1.generation}")
    print(f" Best fitness: {ga1.best_fitness:.4f}")
    print(f" N evals: {ga1.n_evals}")

    # Save state
    state = ga1.state_dict()
    # Keep a plain-int copy so later checks do not depend on ga1's live state.
    generation_at_save = int(ga1.generation)

    # Create new GA with same structure
    ga2 = make_ga(seed=0)  # Different seed
    ga2.initialize(problem)

    print("\n2. New GA before load:")
    print(f" Generation: {ga2.generation}")
    print(f" Best fitness: {ga2.best_fitness:.4f}")

    # Load state
    ga2.load_state_dict(state)

    print("\n3. New GA after load:")
    print(f" Generation: {ga2.generation}")
    print(f" Best fitness: {ga2.best_fitness:.4f}")

    assert ga2.generation == ga1.generation
    assert abs(ga2.best_fitness - ga1.best_fitness) < 1e-6

    # Continue evolution
    for _ in range(5):
        ga2.step()
    print("\n4. After 5 more steps:")
    print(f" Generation: {ga2.generation}")
    print(f" Best fitness: {ga2.best_fitness:.4f}")
    # A restored GA must keep counting from the saved generation.
    assert int(ga2.generation) == generation_at_save + 5

    # Test save to file
    print("\n5. Testing save/load to file...")
    with tempfile.TemporaryDirectory() as tmpdir:
        filepath = os.path.join(tmpdir, "ga_state.pt")
        torch.save(ga1.state_dict(), filepath)
        print(f" Saved to: {filepath}")

        ga3 = make_ga()
        ga3.initialize(problem)
        ga3.load_state_dict(torch.load(filepath))
        print(f" Loaded generation: {ga3.generation}")
        assert ga3.generation == ga1.generation
        # Same fitness check as the in-memory round trip above.
        assert abs(ga3.best_fitness - ga1.best_fitness) < 1e-6

    print("\n✓ State persistence tests passed!")
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def test_convergence():
    """Check that a GA gets close to the optimum of a 5-variable sphere problem.

    Runs 100 generations with one elite kept by ``MergeSurvival`` and asserts
    that the final best fitness is below 1.0.
    """
    rule = "="*60
    print("\n" + rule)
    print("Testing Convergence")
    print(rule)

    # Easy problem: 5D sphere
    problem = Problem(objective=sphere, n_var=5, xl=-5.0, xu=5.0)

    selection = TournamentSelection(tournament_size=3)
    crossover = SBXCrossover(eta=15, prob=0.9)
    mutation = PolynomialMutation(eta=20)
    survival = MergeSurvival(n_survive=50, elitism=True, n_elite=1)
    ga = GA(
        pop_size=50,
        selection=selection,
        crossover=crossover,
        mutation=mutation,
        survival=survival,
        seed=42,
    )
    ga.initialize(problem)

    print("\n1. Running GA for 100 generations...")
    print(f" Initial best: {ga.best_fitness:.4f}")

    for generation in range(1, 101):
        ga.step()
        # Progress report every 25 generations
        if generation % 25 == 0:
            print(f" Gen {generation:3d}: best = {ga.best_fitness:.6f}")

    print("\n2. Final results:")
    print(f" Best fitness: {ga.best_fitness:.6f}")
    print(f" Best solution: {ga.best_solution.tolist()}")
    print(f" Distance to origin: {ga.best_solution.norm().item():.6f}")

    # The optimum is 0; anything below 1.0 counts as converged here
    assert ga.best_fitness < 1.0, f"GA should converge better, got {ga.best_fitness}"
    print("\n✓ Convergence tests passed!")
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def test_hyperparams():
    """Check the hyperparameter mapping returned by ``GA._get_hyperparams``.

    After one step, the mapping must contain ``pop_size`` (equal to 20) and
    ``elitism`` (equal to True).
    """
    rule = "="*60
    print("\n" + rule)
    print("Testing Hyperparameter Tracking")
    print(rule)

    problem = Problem(objective=sphere, n_var=5, xl=-5.0, xu=5.0)

    selection = TournamentSelection(tournament_size=3, differentiable=True)
    crossover = SBXCrossover(eta=15, prob=0.9, differentiable=True)
    mutation = PolynomialMutation(eta=20, differentiable=True)
    survival = MergeSurvival(n_survive=20, elitism=True, n_elite=2, differentiable=True)
    ga = GA(
        pop_size=20,
        selection=selection,
        crossover=crossover,
        mutation=mutation,
        survival=survival,
    )
    ga.initialize(problem)
    ga.step()

    params = ga._get_hyperparams()
    print("\n1. Hyperparameters:")
    for key, value in params.items():
        print(f" {key}: {value}")

    assert 'pop_size' in params
    assert params['pop_size'] == 20
    # Elitism is reported through the survival operator
    assert 'elitism' in params
    assert params['elitism'] == True

    print("\n✓ Hyperparameter tracking tests passed!")
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
# =============================================================================
|
|
539
|
+
# Main
|
|
540
|
+
# =============================================================================
|
|
541
|
+
|
|
542
|
+
def run_all_tests():
    """Run every GA test in order and report the overall outcome.

    Returns:
        True if every test completed without raising; False as soon as one
        raises an ``Exception`` (its message and traceback are printed).
    """
    banner = "#"*60
    print("\n" + banner)
    print("# EvoGrad GA Algorithm Tests")
    print(banner)

    try:
        for test_fn in (
            test_ga_creation,
            test_ga_initialization,
            test_ga_step,
            test_survival_strategies,
            test_elitism,
            test_differentiable_mode,
            test_state_persistence,
            test_convergence,
            test_hyperparams,
        ):
            test_fn()

        rule = "="*60
        print("\n" + rule)
        print("✓ ALL GA TESTS PASSED!")
        print(rule)
        return True
    except Exception as e:
        print(f"\n✗ TEST FAILED: {e}")
        import traceback
        traceback.print_exc()
        return False
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
if __name__ == "__main__":
    # Exit status 0 when every test passed, 1 otherwise.
    success = run_all_tests()
    sys.exit(int(not success))
|