evograd-diff 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. evograd/__init__.py +67 -0
  2. evograd/algorithms/__init__.py +138 -0
  3. evograd/algorithms/cmaes.py +1365 -0
  4. evograd/algorithms/de.py +895 -0
  5. evograd/algorithms/ga.py +532 -0
  6. evograd/algorithms/pso.py +648 -0
  7. evograd/algorithms/shade.py +1165 -0
  8. evograd/benchmarks/functions/__init__.py +229 -0
  9. evograd/benchmarks/functions/base.py +217 -0
  10. evograd/benchmarks/functions/cec2017/__init__.py +250 -0
  11. evograd/benchmarks/functions/cec2017/basic.py +413 -0
  12. evograd/benchmarks/functions/cec2017/composition.py +580 -0
  13. evograd/benchmarks/functions/cec2017/data.pkl +0 -0
  14. evograd/benchmarks/functions/cec2017/data.py +350 -0
  15. evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
  16. evograd/benchmarks/functions/cec2017/simple.py +326 -0
  17. evograd/benchmarks/functions/classical.py +649 -0
  18. evograd/benchmarks/functions/smoothed_funnel.py +476 -0
  19. evograd/benchmarks/functions/transforms.py +463 -0
  20. evograd/benchmarks/run_benchmark_functions.py +1208 -0
  21. evograd/core/__init__.py +73 -0
  22. evograd/core/algorithm.py +778 -0
  23. evograd/core/maximize.py +269 -0
  24. evograd/core/minimize.py +740 -0
  25. evograd/core/problem.py +444 -0
  26. evograd/core/result.py +571 -0
  27. evograd/core/termination.py +602 -0
  28. evograd/operators/__init__.py +178 -0
  29. evograd/operators/crossover.py +1117 -0
  30. evograd/operators/mutation.py +1098 -0
  31. evograd/operators/relaxations.py +175 -0
  32. evograd/operators/repair.py +601 -0
  33. evograd/operators/sampling.py +577 -0
  34. evograd/operators/selection.py +981 -0
  35. evograd/operators/survival.py +1000 -0
  36. evograd/tests/__init__.py +11 -0
  37. evograd/tests/run_all.py +78 -0
  38. evograd/tests/test_core.py +528 -0
  39. evograd/tests/test_ga.py +572 -0
  40. evograd/tests/test_operators.py +662 -0
  41. evograd/tests/test_per_individual.py +326 -0
  42. evograd/tests/test_utils.py +328 -0
  43. evograd/utils/__init__.py +97 -0
  44. evograd/utils/callbacks.py +926 -0
  45. evograd/utils/device.py +502 -0
  46. evograd/utils/duplicates.py +421 -0
  47. evograd_diff-0.1.0.dist-info/METADATA +439 -0
  48. evograd_diff-0.1.0.dist-info/RECORD +50 -0
  49. evograd_diff-0.1.0.dist-info/WHEEL +4 -0
  50. evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,11 @@
1
+ """
2
+ EvoGrad test suite.
3
+
4
+ Run individual test modules:
5
+ python -m tests.test_utils
6
+ python -m tests.test_core
7
+ python -m tests.test_operators
8
+
9
+ Run all tests:
10
+ python -m tests.run_all
11
+ """
@@ -0,0 +1,78 @@
1
+ """
2
+ Run all EvoGrad tests.
3
+
4
+ Usage:
5
+ python -m tests.run_all
6
+ # or
7
+ python tests/run_all.py
8
+ """
9
+
10
+ import sys
11
+ import os
12
+
13
+ # Add parent to path
14
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
15
+
16
+ from tests.test_utils import run_all_tests as test_utils
17
+ from tests.test_core import run_all_tests as test_core
18
+ from tests.test_operators import run_all_tests as test_operators
19
+
20
+
21
def _print_framed_banner(title, lead):
    """Print ``title`` centred inside a 60-column frame of block characters.

    Args:
        title: Text placed (centred to 58 columns) on the middle row.
        lead: String emitted immediately before the top edge, used for the
            blank lines that separate the banner from earlier output.
    """
    edge = "█" * 60
    blank_row = "█" + " " * 58 + "█"
    print(lead + edge)
    print(blank_row)
    print("█" + title.center(58) + "█")
    print(blank_row)
    print(edge)


def run_all():
    """Run the utils, core and operators test suites in sequence.

    Each suite's ``run_all_tests`` entry point is invoked in turn and its
    truthy/falsy outcome recorded. A per-suite summary and an overall
    verdict banner are printed afterwards.

    NOTE(review): the package also ships ``test_ga`` and
    ``test_per_individual`` test modules, which this runner does not
    invoke. Confirm whether that is intentional.

    Returns:
        True if every suite reported success, False otherwise.
    """
    _print_framed_banner(" EVOGRAD COMPLETE TEST SUITE", "\n")

    results = {}
    suites = (
        ("utils", test_utils),
        ("core", test_core),
        ("operators", test_operators),
    )
    for name, suite in suites:
        arrows = "▶" * 60
        print("\n\n" + arrows)
        print(f"▶ RUNNING {name.upper()} TESTS")
        print(arrows)
        results[name] = suite()

    _print_framed_banner(" TEST SUMMARY", "\n\n")

    for name, passed in results.items():
        verdict = "✓ PASSED" if passed else "✗ FAILED"
        print(f" {name:20s} {verdict}")

    everything_passed = all(results.values())

    print()
    if everything_passed:
        closing = " ✓ ALL TESTS PASSED!"
    else:
        closing = " ✗ SOME TESTS FAILED!"
    print("█" * 60)
    print("█" + closing.center(58) + "█")
    print("█" * 60)

    return everything_passed


if __name__ == "__main__":
    # Exit status 0 on success, 1 if any suite failed.
    sys.exit(0 if run_all() else 1)
@@ -0,0 +1,528 @@
1
+ """
2
+ Test script for EvoGrad core module.
3
+
4
+ Tests:
5
+ - problem.py: Problem definition and evaluation
6
+ - termination.py: Termination criteria
7
+ - result.py: Result container
8
+ - algorithm.py: Algorithm base class
9
+
10
+ Usage:
11
+ cd evograd && python tests/test_core.py
12
+ """
13
+
14
+ import sys
15
+ import torch
16
+ import torch.nn as nn
17
+ import tempfile
18
+ import os
19
+
20
+ # Add parent directory to path for imports (works when running from evograd/)
21
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
22
+
23
+ from core.problem import Problem
24
+ from core.termination import (
25
+ Termination,
26
+ MaxEvaluations,
27
+ MaxGenerations,
28
+ TargetReached,
29
+ ToleranceReached,
30
+ TimeLimit,
31
+ TerminationCollection,
32
+ )
33
+ from core.result import Result
34
+ from core.algorithm import Algorithm, AlgorithmState
35
+
36
+
37
def test_problem():
    """Exercise ``Problem`` construction, evaluation and constraint handling.

    Covers:
    - scalar and per-variable (list) bounds;
    - batched objective evaluation;
    - calling a problem instance directly;
    - inequality constraints written in ``g(x) <= 0`` form;
    - ``is_feasible``.
    """
    import math  # local: the module-level import block is outside this function

    print("\n" + "="*60)
    print("Testing problem.py")
    print("="*60)

    # Test basic Problem creation
    print("\n1. Testing Problem creation...")

    def sphere_func(x):
        return (x ** 2).sum(dim=-1)

    problem = Problem(
        n_var=10,
        n_obj=1,
        xl=-5.0,
        xu=5.0,
        objective=sphere_func,
    )

    print(f" n_var: {problem.n_var}")
    print(f" n_obj: {problem.n_obj}")
    print(f" xl shape: {problem.xl.shape}")
    print(f" xu shape: {problem.xu.shape}")
    assert problem.n_var == 10
    assert problem.n_obj == 1
    # Scalar bounds are expected to be expanded to one entry per variable,
    # and this applies to both the lower and the upper bound.
    assert problem.xl.shape == (10,)
    assert problem.xu.shape == (10,)

    # Test evaluation
    print("\n2. Testing Problem evaluation...")
    x = torch.randn(20, 10)  # 20 solutions, 10 variables
    f = problem.evaluate(x)
    print(f" Input shape: {x.shape}")
    print(f" Output shape: {f.shape}")
    print(f" Sample fitness values: {f[:3].tolist()}")
    assert f.shape == (20,)

    # Test bounds as lists
    print("\n3. Testing Problem with list bounds...")
    problem2 = Problem(
        n_var=3,
        xl=[-1.0, -2.0, -3.0],
        xu=[1.0, 2.0, 3.0],
        objective=sphere_func,
    )
    print(f" xl: {problem2.xl.tolist()}")
    print(f" xu: {problem2.xu.tolist()}")
    assert problem2.xl[1] == -2.0
    assert problem2.xu[2] == 3.0

    # Test Problem with different objective function
    print("\n4. Testing Problem with Rastrigin function...")

    def rastrigin(x):
        A = 10.0
        # Use the exact constant. A truncated pi (3.14159) happens to be exact
        # at the origin, but it distorts the function everywhere else.
        return A * x.shape[-1] + (x**2 - A * torch.cos(2 * math.pi * x)).sum(dim=-1)

    rastrigin_problem = Problem(
        objective=rastrigin,
        n_var=5,
        xl=-5.12,
        xu=5.12,
    )

    x = torch.zeros(1, 5)  # Global optimum
    f = rastrigin_problem(x)
    print(f" Rastrigin at origin: {f.item():.6f}")
    assert abs(f.item()) < 1e-5, "Rastrigin should be 0 at origin"

    # Test constraints
    print("\n5. Testing Problem with constraints...")

    def constraint_func(x):
        # g(x) <= 0 format: x[0] + x[1] <= 1
        return x[:, 0] + x[:, 1] - 1.0

    constrained_problem = Problem(
        n_var=2,
        xl=0.0,
        xu=2.0,
        objective=sphere_func,
        constraints=[(constraint_func, 'ineq')],
    )

    # Feasible point
    x_feasible = torch.tensor([[0.3, 0.3]])
    g = constrained_problem.evaluate_constraints(x_feasible)
    ineq_val = g['ineq'][0, 0].item()
    print(f" Constraint at [0.3, 0.3]: {ineq_val:.2f} (should be < 0)")
    assert ineq_val < 0, "Point should be feasible"

    # Infeasible point
    x_infeasible = torch.tensor([[0.8, 0.8]])
    g = constrained_problem.evaluate_constraints(x_infeasible)
    ineq_val = g['ineq'][0, 0].item()
    print(f" Constraint at [0.8, 0.8]: {ineq_val:.2f} (should be > 0)")
    assert ineq_val > 0, "Point should be infeasible"

    # Test is_feasible on a batch holding one feasible and one infeasible point.
    feasibility = constrained_problem.is_feasible(
        torch.cat([x_feasible, x_infeasible])
    )
    print(f" Feasibility: {feasibility.tolist()}")
    assert bool(feasibility[0]), "x_feasible should be reported feasible"
    assert not bool(feasibility[1]), "x_infeasible should be reported infeasible"

    print("\n✓ problem.py tests passed!")
143
+
144
+
145
def test_termination():
    """Check each termination criterion against a hand-driven mock algorithm.

    Each criterion is probed once before its threshold, expecting no stop, and
    once at or past it, expecting a stop. ``TerminationCollection`` is checked
    in both ``mode='or'`` and ``mode='and'``.
    """
    import time

    print("\n" + "="*60)
    print("Testing termination.py")
    print("="*60)

    # Mock algorithm: a plain object carrying the attributes the criteria are
    # driven through below. Presumably this is what they read; confirm
    # against termination.py if its interface changes.
    class MockAlgorithm:
        def __init__(self):
            self.n_evals = 0
            self.generation = 0
            self.best_fitness = float('inf')
            self.fitness = torch.tensor([100.0])
            self._prev_best = None

    # Test MaxEvaluations
    print("\n1. Testing MaxEvaluations...")
    max_evals = MaxEvaluations(max_evals=100)  # keyword arg as in actual API
    alg = MockAlgorithm()

    alg.n_evals = 50
    assert not max_evals.should_terminate(alg), "Should not terminate at 50 evals"
    print(f" Progress at 50 evals: {max_evals.progress(alg):.1%}")

    alg.n_evals = 100
    assert max_evals.should_terminate(alg), "Should terminate at 100 evals"
    print(" Terminated at 100 evals: ✓")

    # Test MaxGenerations
    print("\n2. Testing MaxGenerations...")
    max_gen = MaxGenerations(max_gens=50)  # keyword arg as in actual API
    alg = MockAlgorithm()

    alg.generation = 25
    assert not max_gen.should_terminate(alg)
    print(f" Progress at gen 25: {max_gen.progress(alg):.1%}")

    alg.generation = 50
    assert max_gen.should_terminate(alg)
    print(" Terminated at gen 50: ✓")

    # Test TargetReached (no tolerance param - just target_fitness and minimize)
    print("\n3. Testing TargetReached...")
    target = TargetReached(target_fitness=1.0, minimize=True)
    alg = MockAlgorithm()

    alg.best_fitness = 10.0
    assert not target.should_terminate(alg)
    print(" At fitness 10.0: not terminated")

    alg.best_fitness = 0.5  # Below target (for minimization)
    assert target.should_terminate(alg)
    print(" At fitness 0.5 (below target): terminated ✓")

    # Test ToleranceReached: improve quickly, then stagnate by less than `tol`.
    print("\n4. Testing ToleranceReached...")
    tol = ToleranceReached(tol=0.001, n_last=3, mode='absolute')
    alg = MockAlgorithm()

    fitness_history = [100.0, 50.0, 25.0, 24.9999, 24.9998, 24.9997]
    terminated_at = None

    for gen, fit in enumerate(fitness_history):
        alg.generation = gen
        alg.best_fitness = fit
        if tol.should_terminate(alg):
            terminated_at = gen
            break

    print(f" Fitness history: {fitness_history}")
    print(f" Terminated at generation: {terminated_at}")
    assert terminated_at is not None, "Should have terminated due to stagnation"

    # Test TimeLimit
    print("\n5. Testing TimeLimit...")
    time_limit = TimeLimit(max_seconds=0.1)  # 100ms, keyword arg as in actual API
    alg = MockAlgorithm()

    assert not time_limit.should_terminate(alg)
    print(f" Initial progress: {time_limit.progress(alg):.1%}")

    time.sleep(0.15)  # Wait 150ms, comfortably past the 100ms limit
    assert time_limit.should_terminate(alg)
    print(" After 150ms: terminated ✓")

    # Test TerminationCollection with mode='or': one satisfied criterion
    # (the target) is enough.
    print("\n6. Testing TerminationCollection (mode='or')...")
    combined = TerminationCollection(
        criteria=[
            MaxEvaluations(max_evals=1000),
            MaxGenerations(max_gens=100),
            TargetReached(target_fitness=0.0),
        ],
        mode='or',
    )

    alg = MockAlgorithm()
    alg.n_evals = 50
    alg.generation = 10
    alg.best_fitness = 0.0  # Target reached!

    assert combined.should_terminate(alg)
    print(" Terminated because target reached (any mode): ✓")

    # Test TerminationCollection with mode='and': every criterion must be met.
    print("\n7. Testing TerminationCollection (mode='and')...")
    combined_all = TerminationCollection(
        criteria=[
            MaxEvaluations(max_evals=100),
            MaxGenerations(max_gens=10),
        ],
        mode='and',
    )

    alg = MockAlgorithm()
    alg.n_evals = 100  # Met
    alg.generation = 5  # Not met
    assert not combined_all.should_terminate(alg)
    print(" 100 evals, 5 gens: not terminated (all mode)")

    alg.generation = 10  # Both met
    assert combined_all.should_terminate(alg)
    print(" 100 evals, 10 gens: terminated (all mode) ✓")

    print("\n✓ termination.py tests passed!")
272
+
273
+
274
def test_result():
    """Exercise ``Result``.

    Covers construction, attaching history and population, the
    ``save``/``load`` round-trip, ``str()`` and ``to_dict``.
    """
    print("\n" + "="*60)
    print("Testing result.py")
    print("="*60)

    # Create a result
    print("\n1. Testing Result creation...")
    result = Result(
        best_solution=torch.tensor([1.0, 2.0, 3.0]),
        best_fitness=0.5,
        n_evals=1000,
        n_gen=50,
        elapsed_time=10.5,
        success=True,
        termination_reason="Optimization completed",
    )

    print(f" Best solution: {result.best_solution.tolist()}")
    print(f" Best fitness: {result.best_fitness}")
    print(f" Evaluations: {result.n_evals}")
    print(f" Generations: {result.n_gen}")
    print(f" Time: {result.elapsed_time:.1f}s")

    # Test with history
    print("\n2. Testing Result with history...")
    result.history = {
        'best_fitness': [100, 50, 25, 10, 5, 1, 0.5],
        'generation': list(range(7)),
    }

    print(f" History keys: {list(result.history.keys())}")
    print(f" Best fitness history: {result.history['best_fitness']}")

    # Test population storage
    print("\n3. Testing Result with population...")
    result.population = torch.randn(20, 3)
    result.fitness = torch.randn(20)
    print(f" Population shape: {result.population.shape}")
    print(f" Fitness shape: {result.fitness.shape}")

    # Test save and load
    print("\n4. Testing save and load...")
    with tempfile.TemporaryDirectory() as tmpdir:
        filepath = os.path.join(tmpdir, "result.pt")
        result.save(filepath)
        print(f" Saved to: {filepath}")

        loaded = Result.load(filepath)
        print(f" Loaded best_solution: {loaded.best_solution.tolist()}")
        print(f" Loaded best_fitness: {loaded.best_fitness}")
        print(f" Loaded population shape: {loaded.population.shape}")
        print(f" Loaded fitness shape: {loaded.fitness.shape}")
        assert torch.allclose(result.best_solution, loaded.best_solution)
        assert result.best_fitness == loaded.best_fitness
        # The attached tensors must survive the round-trip by value, not
        # merely come back with the right shape.
        assert torch.allclose(result.population, loaded.population)
        assert torch.allclose(result.fitness, loaded.fitness)
        print(" Save/load verified ✓")

    # Test string representation
    print("\n5. Testing string representation...")
    print(result)

    # Test to_dict
    print("\n6. Testing to_dict...")
    result_dict = result.to_dict()
    print(f" Dict keys: {list(result_dict.keys())}")
    assert 'best_solution' in result_dict
    assert 'best_fitness' in result_dict
    assert 'population' in result_dict
    assert 'fitness' in result_dict

    print("\n✓ result.py tests passed!")
345
+
346
+
347
def test_algorithm():
    """Exercise the ``Algorithm`` base class through a minimal concrete GA.

    Builds an elitist ``SimpleGA`` subclass, then checks:
    - ``initialize`` and ``step``;
    - ``state_dict``/``load_state_dict``, both in memory and through a
      ``torch.save`` file;
    - the ``state`` container.
    """
    print("\n" + "="*60)
    print("Testing algorithm.py")
    print("="*60)

    # Import operators for a concrete implementation
    from operators.sampling import UniformSampling
    from operators.selection import TournamentSelection
    from operators.crossover import SBXCrossover
    from operators.mutation import PolynomialMutation
    from operators.repair import ReflectRepair

    # Create a simple concrete algorithm
    print("\n1. Creating concrete Algorithm subclass...")

    class SimpleGA(Algorithm):
        """Minimal GA for testing."""

        def __init__(self, pop_size=20, **kwargs):
            super().__init__(pop_size=pop_size, **kwargs)

        def _setup(self):
            """Initialise population (called by initialize)."""
            # Population is already created by parent class
            pass

        def _infill(self):
            """Generate offspring."""
            # Select parents
            parents = self.selection(
                self.population, self.fitness, self.pop_size
            )

            # Crossover: pair the first half of the parents with the second half
            n_pairs = self.pop_size // 2
            p1 = parents[:n_pairs]
            p2 = parents[n_pairs:2*n_pairs]
            offspring = self.crossover(p1, p2)

            # Mutation
            offspring = self.mutation(offspring, self.problem.xl, self.problem.xu)

            # Repair bounds
            offspring = self.repair(offspring, self.problem.xl, self.problem.xu)

            return offspring

        def _advance(self, offspring, offspring_fitness):
            """Update population (elitist: keep the best of parents + offspring)."""
            # Combine and select best
            combined_pop = torch.cat([self.population, offspring], dim=0)
            combined_fit = torch.cat([self.fitness, offspring_fitness], dim=0)

            # Select top pop_size (ascending sort: lower fitness is better)
            indices = torch.argsort(combined_fit)[:self.pop_size]

            # Update population (stored as nn.Parameter)
            with torch.no_grad():
                self._population.copy_(combined_pop[indices])
                self.state.fitness = combined_fit[indices]

            # Update best
            self.state.update_best(self.population, self.state.fitness)

    # Create problem
    def sphere(x):
        return (x ** 2).sum(dim=-1)

    problem = Problem(
        n_var=5,
        xl=-5.0,
        xu=5.0,
        objective=sphere,
    )

    def make_ga():
        """Build a SimpleGA with the operator configuration used throughout.

        Fresh operator objects are created on every call, so no operator
        instance is shared between algorithms.
        """
        return SimpleGA(
            pop_size=20,
            sampling=UniformSampling(seed=42),
            selection=TournamentSelection(tournament_size=3),
            crossover=SBXCrossover(eta=15, prob=0.9),
            mutation=PolynomialMutation(eta=20),
            repair=ReflectRepair(),
        )

    # Create algorithm with operators
    print("\n2. Testing Algorithm with dependency injection...")
    ga = make_ga()

    print(f" Algorithm: {ga}")
    print(f" Pop size: {ga.pop_size}")

    # Test initialize (not setup)
    print("\n3. Testing initialize...")
    ga.initialize(problem)
    print(f" Population shape: {ga.population.shape}")
    print(f" Initial best fitness: {ga.best_fitness:.4f}")

    # Test step
    print("\n4. Testing step() method...")
    # Snapshot as a plain float so an in-place update of the underlying state
    # cannot alias the "before" value.
    initial_best = float(ga.best_fitness)
    for _ in range(10):
        ga.step()
    print(f" Generation: {ga.generation}")
    print(f" Best fitness after 10 steps: {ga.best_fitness:.4f}")
    assert ga.generation == 10
    # SimpleGA._advance keeps the best pop_size of parents + offspring, so the
    # best fitness must never get worse over the run.
    assert float(ga.best_fitness) <= initial_best, "Elitist GA should not lose its best"

    # Test state dict
    print("\n5. Testing state_dict and load_state_dict...")
    state = ga.state_dict()
    print(f" State keys: {list(state.keys())}")

    # Create a new algorithm with the same operator configuration and load state
    ga2 = make_ga()
    ga2.initialize(problem)
    ga2.load_state_dict(state)
    print(f" Loaded generation: {ga2.generation}")
    print(f" Loaded best_fitness: {ga2.best_fitness:.4f}")
    assert ga2.generation == ga.generation
    assert abs(ga2.best_fitness - ga.best_fitness) < 1e-6

    # Test save and load using torch
    print("\n6. Testing save and load with torch...")
    with tempfile.TemporaryDirectory() as tmpdir:
        filepath = os.path.join(tmpdir, "algorithm.pt")
        torch.save(ga.state_dict(), filepath)
        print(f" Saved to: {filepath}")

        # Load while the temporary directory (and the file) still exists.
        ga3 = make_ga()
        ga3.initialize(problem)
        ga3.load_state_dict(torch.load(filepath))
        print(f" Loaded generation: {ga3.generation}")
        assert ga3.generation == ga.generation

    # Test AlgorithmState container
    print("\n7. Testing AlgorithmState container...")
    print(f" state.generation: {ga.state.generation}")
    print(f" state.n_evals: {ga.state.n_evals}")
    print(f" state.best_fitness: {ga.state.best_fitness:.4f}")
    assert isinstance(ga.state, AlgorithmState)
    assert ga.state.generation == ga.generation

    print("\n✓ algorithm.py tests passed!")
501
+
502
+
503
def run_all_tests():
    """Run every core-module test function in order and report the outcome.

    The tests run sequentially. The first exception (including a failed
    ``assert``) stops the run: its message and traceback are printed.

    Returns:
        True if all tests completed without raising, False otherwise.
    """
    hash_rule = "#" * 60
    print("\n" + hash_rule)
    print("# EvoGrad Core Module Tests")
    print(hash_rule)

    core_tests = (test_problem, test_termination, test_result, test_algorithm)
    try:
        for run_test in core_tests:
            run_test()

        eq_rule = "=" * 60
        print("\n" + eq_rule)
        print("✓ ALL CORE TESTS PASSED!")
        print(eq_rule)
        return True
    except Exception as err:
        print(f"\n✗ TEST FAILED: {err}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    # Exit status 0 on success, 1 on the first failing test.
    sys.exit(0 if run_all_tests() else 1)