evograd-diff 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evograd/__init__.py +67 -0
- evograd/algorithms/__init__.py +138 -0
- evograd/algorithms/cmaes.py +1365 -0
- evograd/algorithms/de.py +895 -0
- evograd/algorithms/ga.py +532 -0
- evograd/algorithms/pso.py +648 -0
- evograd/algorithms/shade.py +1165 -0
- evograd/benchmarks/functions/__init__.py +229 -0
- evograd/benchmarks/functions/base.py +217 -0
- evograd/benchmarks/functions/cec2017/__init__.py +250 -0
- evograd/benchmarks/functions/cec2017/basic.py +413 -0
- evograd/benchmarks/functions/cec2017/composition.py +580 -0
- evograd/benchmarks/functions/cec2017/data.pkl +0 -0
- evograd/benchmarks/functions/cec2017/data.py +350 -0
- evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
- evograd/benchmarks/functions/cec2017/simple.py +326 -0
- evograd/benchmarks/functions/classical.py +649 -0
- evograd/benchmarks/functions/smoothed_funnel.py +476 -0
- evograd/benchmarks/functions/transforms.py +463 -0
- evograd/benchmarks/run_benchmark_functions.py +1208 -0
- evograd/core/__init__.py +73 -0
- evograd/core/algorithm.py +778 -0
- evograd/core/maximize.py +269 -0
- evograd/core/minimize.py +740 -0
- evograd/core/problem.py +444 -0
- evograd/core/result.py +571 -0
- evograd/core/termination.py +602 -0
- evograd/operators/__init__.py +178 -0
- evograd/operators/crossover.py +1117 -0
- evograd/operators/mutation.py +1098 -0
- evograd/operators/relaxations.py +175 -0
- evograd/operators/repair.py +601 -0
- evograd/operators/sampling.py +577 -0
- evograd/operators/selection.py +981 -0
- evograd/operators/survival.py +1000 -0
- evograd/tests/__init__.py +11 -0
- evograd/tests/run_all.py +78 -0
- evograd/tests/test_core.py +528 -0
- evograd/tests/test_ga.py +572 -0
- evograd/tests/test_operators.py +662 -0
- evograd/tests/test_per_individual.py +326 -0
- evograd/tests/test_utils.py +328 -0
- evograd/utils/__init__.py +97 -0
- evograd/utils/callbacks.py +926 -0
- evograd/utils/device.py +502 -0
- evograd/utils/duplicates.py +421 -0
- evograd_diff-0.1.0.dist-info/METADATA +439 -0
- evograd_diff-0.1.0.dist-info/RECORD +50 -0
- evograd_diff-0.1.0.dist-info/WHEEL +4 -0
- evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
evograd/tests/run_all.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Run all EvoGrad tests.
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
python -m tests.run_all
|
|
6
|
+
# or
|
|
7
|
+
python tests/run_all.py
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import sys
|
|
11
|
+
import os
|
|
12
|
+
|
|
13
|
+
# Add parent to path
|
|
14
|
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
15
|
+
|
|
16
|
+
from tests.test_utils import run_all_tests as test_utils
|
|
17
|
+
from tests.test_core import run_all_tests as test_core
|
|
18
|
+
from tests.test_operators import run_all_tests as test_operators
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _print_boxed_title(title, leading):
    """Print *title* centred inside a 60-column box drawn with '█'.

    *leading* is the newline prefix emitted before the top edge of the box.
    """
    edge = "█" * 60
    padding_row = "█" + " " * 58 + "█"
    print(leading + edge)
    print(padding_row)
    print("█" + title.center(58) + "█")
    print(padding_row)
    print(edge)


def run_all():
    """Run the utils, core and operators suites, then print a summary.

    Each suite's ``run_all_tests`` is invoked in turn and its return value is
    recorded under the suite's name.

    Returns:
        True if every suite reported success, False otherwise.
    """
    _print_boxed_title(" EVOGRAD COMPLETE TEST SUITE", "\n")

    # (result key, section heading, suite runner) in execution order.
    suites = (
        ("utils", "▶ RUNNING UTILS TESTS", test_utils),
        ("core", "▶ RUNNING CORE TESTS", test_core),
        ("operators", "▶ RUNNING OPERATORS TESTS", test_operators),
    )

    arrow_rule = "▶" * 60
    outcome = {}
    for key, heading, runner in suites:
        print("\n\n" + arrow_rule)
        print(heading)
        print(arrow_rule)
        outcome[key] = runner()

    _print_boxed_title(" TEST SUMMARY", "\n\n")

    for name, ok in outcome.items():
        label = "✓ PASSED" if ok else "✗ FAILED"
        print(f" {name:20s} {label}")

    everything_ok = all(outcome.values())

    print()
    verdict = " ✓ ALL TESTS PASSED!" if everything_ok else " ✗ SOME TESTS FAILED!"
    print("█" * 60)
    print("█" + verdict.center(58) + "█")
    print("█" * 60)

    return everything_ok
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
if __name__ == "__main__":
    # Report the aggregate outcome through the process exit status (0 = pass).
    exit_code = 0 if run_all() else 1
    sys.exit(exit_code)
|
|
@@ -0,0 +1,528 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Test script for EvoGrad core module.
|
|
3
|
+
|
|
4
|
+
Tests:
|
|
5
|
+
- problem.py: Problem definition and evaluation
|
|
6
|
+
- termination.py: Termination criteria
|
|
7
|
+
- result.py: Result container
|
|
8
|
+
- algorithm.py: Algorithm base class
|
|
9
|
+
|
|
10
|
+
Usage:
|
|
11
|
+
cd evograd && python tests/test_core.py
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import sys
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
import tempfile
|
|
18
|
+
import os
|
|
19
|
+
|
|
20
|
+
# Add parent directory to path for imports (works when running from evograd/)
|
|
21
|
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
22
|
+
|
|
23
|
+
from core.problem import Problem
|
|
24
|
+
from core.termination import (
|
|
25
|
+
Termination,
|
|
26
|
+
MaxEvaluations,
|
|
27
|
+
MaxGenerations,
|
|
28
|
+
TargetReached,
|
|
29
|
+
ToleranceReached,
|
|
30
|
+
TimeLimit,
|
|
31
|
+
TerminationCollection,
|
|
32
|
+
)
|
|
33
|
+
from core.result import Result
|
|
34
|
+
from core.algorithm import Algorithm, AlgorithmState
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def test_problem():
    """Test the Problem class.

    Covers construction with scalar and per-variable (list) bounds, batched
    evaluation, a multimodal objective (Rastrigin), and inequality-constraint
    evaluation / feasibility checks.

    Raises:
        AssertionError: if any checked property of ``Problem`` is wrong.
    """
    import math  # exact pi for the Rastrigin objective below

    print("\n" + "="*60)
    print("Testing problem.py")
    print("="*60)

    # Test basic Problem creation
    print("\n1. Testing Problem creation...")

    def sphere_func(x):
        # Sum of squares over the last dimension: one value per row of a batch.
        return (x ** 2).sum(dim=-1)

    problem = Problem(
        n_var=10,
        n_obj=1,
        xl=-5.0,
        xu=5.0,
        objective=sphere_func,
    )

    print(f" n_var: {problem.n_var}")
    print(f" n_obj: {problem.n_obj}")
    print(f" xl shape: {problem.xl.shape}")
    print(f" xu shape: {problem.xu.shape}")
    assert problem.n_var == 10
    assert problem.n_obj == 1
    # Scalar bounds are expected to be expanded to one entry per variable.
    assert problem.xl.shape == (10,)
    assert problem.xu.shape == (10,)

    # Test evaluation
    print("\n2. Testing Problem evaluation...")
    x = torch.randn(20, 10)  # 20 solutions, 10 variables
    f = problem.evaluate(x)
    print(f" Input shape: {x.shape}")
    print(f" Output shape: {f.shape}")
    print(f" Sample fitness values: {f[:3].tolist()}")
    assert f.shape == (20,)

    # Test bounds as lists
    print("\n3. Testing Problem with list bounds...")
    problem2 = Problem(
        n_var=3,
        xl=[-1.0, -2.0, -3.0],
        xu=[1.0, 2.0, 3.0],
        objective=sphere_func,
    )
    print(f" xl: {problem2.xl.tolist()}")
    print(f" xu: {problem2.xu.tolist()}")
    assert problem2.xl[1] == -2.0
    assert problem2.xu[2] == 3.0

    # Test Problem with different objective function
    print("\n4. Testing Problem with Rastrigin function...")

    def rastrigin(x):
        A = 10.0
        # Use math.pi rather than a truncated literal so the objective is
        # exact everywhere, not only at the origin where cos(0) == 1.
        return A * x.shape[-1] + (x**2 - A * torch.cos(2 * math.pi * x)).sum(dim=-1)

    rastrigin_problem = Problem(
        objective=rastrigin,
        n_var=5,
        xl=-5.12,
        xu=5.12,
    )

    x = torch.zeros(1, 5)  # Global optimum
    f = rastrigin_problem(x)
    print(f" Rastrigin at origin: {f.item():.6f}")
    assert abs(f.item()) < 1e-5, "Rastrigin should be 0 at origin"

    # Test constraints
    print("\n5. Testing Problem with constraints...")

    def constraint_func(x):
        # g(x) <= 0 format: x[0] + x[1] <= 1
        return x[:, 0] + x[:, 1] - 1.0

    constrained_problem = Problem(
        n_var=2,
        xl=0.0,
        xu=2.0,
        objective=sphere_func,
        constraints=[(constraint_func, 'ineq')],
    )

    # Feasible point
    x_feasible = torch.tensor([[0.3, 0.3]])
    g = constrained_problem.evaluate_constraints(x_feasible)
    ineq_val = g['ineq'][0, 0].item()
    print(f" Constraint at [0.3, 0.3]: {ineq_val:.2f} (should be < 0)")
    assert ineq_val < 0, "Point should be feasible"

    # Infeasible point
    x_infeasible = torch.tensor([[0.8, 0.8]])
    g = constrained_problem.evaluate_constraints(x_infeasible)
    ineq_val = g['ineq'][0, 0].item()
    print(f" Constraint at [0.8, 0.8]: {ineq_val:.2f} (should be > 0)")
    assert ineq_val > 0, "Point should be infeasible"

    # Test is_feasible on both points at once (row 0 feasible, row 1 not).
    feasibility = constrained_problem.is_feasible(
        torch.cat([x_feasible, x_infeasible])
    )
    print(f" Feasibility: {feasibility.tolist()}")
    assert bool(feasibility[0]), "[0.3, 0.3] should be reported feasible"
    assert not bool(feasibility[1]), "[0.8, 0.8] should be reported infeasible"

    print("\n✓ problem.py tests passed!")
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def test_termination():
    """Test the termination criteria against a minimal mock algorithm.

    Covers MaxEvaluations, MaxGenerations, TargetReached, ToleranceReached,
    TimeLimit, and TerminationCollection in both 'or' and 'and' modes.

    Raises:
        AssertionError: if a criterion fires (or fails to fire) unexpectedly.
    """
    import time  # used by the TimeLimit section below

    print("\n" + "="*60)
    print("Testing termination.py")
    print("="*60)

    # Create mock algorithm state
    class MockAlgorithm:
        """Bare stand-in exposing the attributes this test sets."""

        # NOTE(review): which of these each criterion actually reads is
        # defined in core/termination.py; keep in sync if that changes.
        def __init__(self):
            self.n_evals = 0
            self.generation = 0
            self.best_fitness = float('inf')
            self.fitness = torch.tensor([100.0])
            self._prev_best = None

    # Test MaxEvaluations
    print("\n1. Testing MaxEvaluations...")
    max_evals = MaxEvaluations(max_evals=100)  # keyword arg as in actual API
    alg = MockAlgorithm()

    alg.n_evals = 50
    assert not max_evals.should_terminate(alg), "Should not terminate at 50 evals"
    print(f" Progress at 50 evals: {max_evals.progress(alg):.1%}")

    alg.n_evals = 100
    assert max_evals.should_terminate(alg), "Should terminate at 100 evals"
    print(" Terminated at 100 evals: ✓")

    # Test MaxGenerations
    print("\n2. Testing MaxGenerations...")
    max_gen = MaxGenerations(max_gens=50)  # keyword arg as in actual API
    alg = MockAlgorithm()

    alg.generation = 25
    assert not max_gen.should_terminate(alg)
    print(f" Progress at gen 25: {max_gen.progress(alg):.1%}")

    alg.generation = 50
    assert max_gen.should_terminate(alg)
    print(" Terminated at gen 50: ✓")

    # Test TargetReached (no tolerance param - just target_fitness and minimize)
    print("\n3. Testing TargetReached...")
    target = TargetReached(target_fitness=1.0, minimize=True)
    alg = MockAlgorithm()

    alg.best_fitness = 10.0
    assert not target.should_terminate(alg)
    print(" At fitness 10.0: not terminated")

    alg.best_fitness = 0.5  # Below target (for minimization)
    assert target.should_terminate(alg)
    print(" At fitness 0.5 (below target): terminated ✓")

    # Test ToleranceReached
    print("\n4. Testing ToleranceReached...")
    tol = ToleranceReached(tol=0.001, n_last=3, mode='absolute')
    alg = MockAlgorithm()

    # Simulate improving then stagnating (steps of 1e-4, below tol=1e-3)
    fitness_history = [100.0, 50.0, 25.0, 24.9999, 24.9998, 24.9997]
    terminated_at = None

    for gen, fit in enumerate(fitness_history):
        alg.generation = gen
        alg.best_fitness = fit
        if tol.should_terminate(alg):
            terminated_at = gen
            break

    print(f" Fitness history: {fitness_history}")
    print(f" Terminated at generation: {terminated_at}")
    assert terminated_at is not None, "Should have terminated due to stagnation"

    # Test TimeLimit
    print("\n5. Testing TimeLimit...")
    time_limit = TimeLimit(max_seconds=0.1)  # 100ms, keyword arg as in actual API
    alg = MockAlgorithm()

    assert not time_limit.should_terminate(alg)
    print(f" Initial progress: {time_limit.progress(alg):.1%}")

    time.sleep(0.15)  # Wait 150ms, comfortably past the 100ms limit
    assert time_limit.should_terminate(alg)
    print(" After 150ms: terminated ✓")

    # Test TerminationCollection (mode='or': stop as soon as any criterion fires)
    print("\n6. Testing TerminationCollection (mode='or')...")
    combined = TerminationCollection(
        criteria=[
            MaxEvaluations(1000),
            MaxGenerations(100),
            TargetReached(0.0),
        ],
        mode='or',
    )

    alg = MockAlgorithm()
    alg.n_evals = 50
    alg.generation = 10
    alg.best_fitness = 0.0  # Target reached!

    assert combined.should_terminate(alg)
    print(" Terminated because target reached (any mode): ✓")

    # Test TerminationCollection (mode='and': stop only when every criterion fires)
    print("\n7. Testing TerminationCollection (mode='and')...")
    combined_all = TerminationCollection(
        criteria=[
            MaxEvaluations(100),
            MaxGenerations(10),
        ],
        mode='and',
    )

    alg = MockAlgorithm()
    alg.n_evals = 100  # Met
    alg.generation = 5  # Not met
    assert not combined_all.should_terminate(alg)
    print(" 100 evals, 5 gens: not terminated (all mode)")

    alg.generation = 10  # Both met
    assert combined_all.should_terminate(alg)
    print(" 100 evals, 10 gens: terminated (all mode) ✓")

    print("\n✓ termination.py tests passed!")
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def test_result():
    """Test the Result container.

    Covers construction, attaching history and population, a save/load
    round-trip through a temporary file, string representation and
    ``to_dict``.

    Raises:
        AssertionError: if stored or round-tripped values do not match.
    """
    print("\n" + "="*60)
    print("Testing result.py")
    print("="*60)

    # Create a result
    print("\n1. Testing Result creation...")
    result = Result(
        best_solution=torch.tensor([1.0, 2.0, 3.0]),
        best_fitness=0.5,
        n_evals=1000,
        n_gen=50,
        elapsed_time=10.5,
        success=True,
        termination_reason="Optimization completed",
    )

    print(f" Best solution: {result.best_solution.tolist()}")
    print(f" Best fitness: {result.best_fitness}")
    print(f" Evaluations: {result.n_evals}")
    print(f" Generations: {result.n_gen}")
    print(f" Time: {result.elapsed_time:.1f}s")
    assert result.n_evals == 1000
    assert result.n_gen == 50

    # Test with history
    print("\n2. Testing Result with history...")
    result.history = {
        'best_fitness': [100, 50, 25, 10, 5, 1, 0.5],
        'generation': list(range(7)),
    }

    print(f" History keys: {list(result.history.keys())}")
    print(f" Best fitness history: {result.history['best_fitness']}")

    # Test population storage
    print("\n3. Testing Result with population...")
    result.population = torch.randn(20, 3)
    result.fitness = torch.randn(20)
    print(f" Population shape: {result.population.shape}")
    print(f" Fitness shape: {result.fitness.shape}")

    # Test save and load
    print("\n4. Testing save and load...")
    with tempfile.TemporaryDirectory() as tmpdir:
        filepath = os.path.join(tmpdir, "result.pt")
        result.save(filepath)
        print(f" Saved to: {filepath}")

        loaded = Result.load(filepath)
        print(f" Loaded best_solution: {loaded.best_solution.tolist()}")
        print(f" Loaded best_fitness: {loaded.best_fitness}")
        print(f" Loaded population shape: {loaded.population.shape}")
        print(f" Loaded fitness shape: {loaded.fitness.shape}")
        assert torch.allclose(result.best_solution, loaded.best_solution)
        assert result.best_fitness == loaded.best_fitness
        # Population and fitness must survive the round-trip too, not just
        # exist: compare shapes, then values after aligning dtype/device.
        assert loaded.population.shape == result.population.shape
        assert torch.allclose(result.population, loaded.population.to(result.population))
        assert loaded.fitness.shape == result.fitness.shape
        assert torch.allclose(result.fitness, loaded.fitness.to(result.fitness))
        print(" Save/load verified ✓")

    # Test string representation
    print("\n5. Testing string representation...")
    print(result)

    # Test to_dict
    print("\n6. Testing to_dict...")
    result_dict = result.to_dict()
    print(f" Dict keys: {list(result_dict.keys())}")
    assert 'best_solution' in result_dict
    assert 'best_fitness' in result_dict
    assert 'population' in result_dict
    assert 'fitness' in result_dict

    print("\n✓ result.py tests passed!")
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def test_algorithm():
    """Test the Algorithm base class through a minimal GA subclass.

    Covers operator injection, ``initialize``, repeated ``step`` calls,
    ``state_dict``/``load_state_dict`` round-trips (in memory and via
    ``torch.save``/``torch.load``), and the ``AlgorithmState`` container.

    Raises:
        AssertionError: if generation counters or best fitness are inconsistent.
    """
    print("\n" + "="*60)
    print("Testing algorithm.py")
    print("="*60)

    # Import operators for a concrete implementation
    from operators.sampling import UniformSampling
    from operators.selection import TournamentSelection
    from operators.crossover import SBXCrossover
    from operators.mutation import PolynomialMutation
    from operators.repair import ReflectRepair

    # Create a simple concrete algorithm
    print("\n1. Creating concrete Algorithm subclass...")

    class SimpleGA(Algorithm):
        """Minimal elitist GA for testing: parents and offspring compete,
        and the pop_size best survive."""

        def __init__(self, pop_size=20, **kwargs):
            super().__init__(pop_size=pop_size, **kwargs)

        def _setup(self):
            """Initialise population (called by initialize)."""
            # Population is already created by parent class
            pass

        def _infill(self):
            """Generate offspring via selection, crossover, mutation and repair."""
            # Select parents
            parents = self.selection(
                self.population, self.fitness, self.pop_size
            )

            # Crossover: pair the first half of the parents with the second half
            n_pairs = self.pop_size // 2
            p1 = parents[:n_pairs]
            p2 = parents[n_pairs:2*n_pairs]
            offspring = self.crossover(p1, p2)

            # Mutation
            offspring = self.mutation(offspring, self.problem.xl, self.problem.xu)

            # Repair bounds
            offspring = self.repair(offspring, self.problem.xl, self.problem.xu)

            return offspring

        def _advance(self, offspring, offspring_fitness):
            """Keep the pop_size best individuals of parents + offspring."""
            # Combine and select best
            combined_pop = torch.cat([self.population, offspring], dim=0)
            combined_fit = torch.cat([self.fitness, offspring_fitness], dim=0)

            # Select top pop_size (ascending order: minimisation)
            indices = torch.argsort(combined_fit)[:self.pop_size]

            # Update population in place (stored as nn.Parameter per the
            # original note, hence no_grad + copy_)
            with torch.no_grad():
                self._population.copy_(combined_pop[indices])
            self.state.fitness = combined_fit[indices]

            # Update best
            self.state.update_best(self.population, self.state.fitness)

    # Create problem
    def sphere(x):
        return (x ** 2).sum(dim=-1)

    problem = Problem(
        n_var=5,
        xl=-5.0,
        xu=5.0,
        objective=sphere,
    )

    # Create algorithm with operators
    print("\n2. Testing Algorithm with dependency injection...")
    ga = SimpleGA(
        pop_size=20,
        sampling=UniformSampling(seed=42),
        selection=TournamentSelection(tournament_size=3),
        crossover=SBXCrossover(eta=15, prob=0.9),
        mutation=PolynomialMutation(eta=20),
        repair=ReflectRepair(),
    )

    print(f" Algorithm: {ga}")
    print(f" Pop size: {ga.pop_size}")

    # Test initialize (not setup)
    print("\n3. Testing initialize...")
    ga.initialize(problem)
    print(f" Population shape: {ga.population.shape}")
    print(f" Initial best fitness: {ga.best_fitness:.4f}")

    # Test step
    print("\n4. Testing step() method...")
    # Snapshot as a plain float so later in-place state updates cannot alias it.
    initial_best = float(ga.best_fitness)
    for _ in range(10):
        ga.step()
    print(f" Generation: {ga.generation}")
    print(f" Best fitness after 10 steps: {ga.best_fitness:.4f}")
    assert ga.generation == 10
    # _advance keeps the best of parents + offspring, so the best fitness
    # can never get worse across steps.
    assert float(ga.best_fitness) <= initial_best, "Elitist GA must not lose its best"

    # Test state dict
    print("\n5. Testing state_dict and load_state_dict...")
    state = ga.state_dict()
    print(f" State keys: {list(state.keys())}")

    # Create new algorithm with SAME operators and load state
    ga2 = SimpleGA(
        pop_size=20,
        sampling=UniformSampling(seed=42),
        selection=TournamentSelection(tournament_size=3),
        crossover=SBXCrossover(eta=15, prob=0.9),
        mutation=PolynomialMutation(eta=20),
        repair=ReflectRepair(),
    )
    ga2.initialize(problem)
    ga2.load_state_dict(state)
    print(f" Loaded generation: {ga2.generation}")
    print(f" Loaded best_fitness: {ga2.best_fitness:.4f}")
    assert ga2.generation == ga.generation
    assert abs(ga2.best_fitness - ga.best_fitness) < 1e-6

    # Test save and load using torch
    print("\n6. Testing save and load with torch...")
    with tempfile.TemporaryDirectory() as tmpdir:
        filepath = os.path.join(tmpdir, "algorithm.pt")
        torch.save(ga.state_dict(), filepath)
        print(f" Saved to: {filepath}")

        ga3 = SimpleGA(
            pop_size=20,
            sampling=UniformSampling(seed=42),
            selection=TournamentSelection(tournament_size=3),
            crossover=SBXCrossover(eta=15, prob=0.9),
            mutation=PolynomialMutation(eta=20),
            repair=ReflectRepair(),
        )
        ga3.initialize(problem)
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True;
        # if state_dict() ever contains non-tensor/non-primitive objects this
        # load will need weights_only=False (safe here: we wrote the file).
        ga3.load_state_dict(torch.load(filepath))
        print(f" Loaded generation: {ga3.generation}")
        assert ga3.generation == ga.generation

    # Test AlgorithmState container
    print("\n7. Testing AlgorithmState container...")
    print(f" state.generation: {ga.state.generation}")
    print(f" state.n_evals: {ga.state.n_evals}")
    print(f" state.best_fitness: {ga.state.best_fitness:.4f}")
    assert ga.state.generation == ga.generation

    print("\n✓ algorithm.py tests passed!")
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
def run_all_tests():
    """Run every core test in order and report the overall outcome.

    Stops at the first failing test. Failures are reported (message plus
    traceback) instead of propagated, so a caller such as ``run_all.py``
    can aggregate results across modules.

    Returns:
        True when all core tests pass, False as soon as one raises.
    """
    import traceback

    print("\n" + "#"*60)
    print("# EvoGrad Core Module Tests")
    print("#"*60)

    checks = (test_problem, test_termination, test_result, test_algorithm)
    try:
        for check in checks:
            check()

        # Kept inside the try so any error here is reported like a test failure.
        print("\n" + "="*60)
        print("✓ ALL CORE TESTS PASSED!")
        print("="*60)
        return True
    except Exception as err:
        print(f"\n✗ TEST FAILED: {err}")
        traceback.print_exc()
        return False
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
if __name__ == "__main__":
    # Translate the boolean outcome into a conventional process exit status.
    exit_code = 0 if run_all_tests() else 1
    sys.exit(exit_code)
|