evograd-diff 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50):
  1. evograd/__init__.py +67 -0
  2. evograd/algorithms/__init__.py +138 -0
  3. evograd/algorithms/cmaes.py +1365 -0
  4. evograd/algorithms/de.py +895 -0
  5. evograd/algorithms/ga.py +532 -0
  6. evograd/algorithms/pso.py +648 -0
  7. evograd/algorithms/shade.py +1165 -0
  8. evograd/benchmarks/functions/__init__.py +229 -0
  9. evograd/benchmarks/functions/base.py +217 -0
  10. evograd/benchmarks/functions/cec2017/__init__.py +250 -0
  11. evograd/benchmarks/functions/cec2017/basic.py +413 -0
  12. evograd/benchmarks/functions/cec2017/composition.py +580 -0
  13. evograd/benchmarks/functions/cec2017/data.pkl +0 -0
  14. evograd/benchmarks/functions/cec2017/data.py +350 -0
  15. evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
  16. evograd/benchmarks/functions/cec2017/simple.py +326 -0
  17. evograd/benchmarks/functions/classical.py +649 -0
  18. evograd/benchmarks/functions/smoothed_funnel.py +476 -0
  19. evograd/benchmarks/functions/transforms.py +463 -0
  20. evograd/benchmarks/run_benchmark_functions.py +1208 -0
  21. evograd/core/__init__.py +73 -0
  22. evograd/core/algorithm.py +778 -0
  23. evograd/core/maximize.py +269 -0
  24. evograd/core/minimize.py +740 -0
  25. evograd/core/problem.py +444 -0
  26. evograd/core/result.py +571 -0
  27. evograd/core/termination.py +602 -0
  28. evograd/operators/__init__.py +178 -0
  29. evograd/operators/crossover.py +1117 -0
  30. evograd/operators/mutation.py +1098 -0
  31. evograd/operators/relaxations.py +175 -0
  32. evograd/operators/repair.py +601 -0
  33. evograd/operators/sampling.py +577 -0
  34. evograd/operators/selection.py +981 -0
  35. evograd/operators/survival.py +1000 -0
  36. evograd/tests/__init__.py +11 -0
  37. evograd/tests/run_all.py +78 -0
  38. evograd/tests/test_core.py +528 -0
  39. evograd/tests/test_ga.py +572 -0
  40. evograd/tests/test_operators.py +662 -0
  41. evograd/tests/test_per_individual.py +326 -0
  42. evograd/tests/test_utils.py +328 -0
  43. evograd/utils/__init__.py +97 -0
  44. evograd/utils/callbacks.py +926 -0
  45. evograd/utils/device.py +502 -0
  46. evograd/utils/duplicates.py +421 -0
  47. evograd_diff-0.1.0.dist-info/METADATA +439 -0
  48. evograd_diff-0.1.0.dist-info/RECORD +50 -0
  49. evograd_diff-0.1.0.dist-info/WHEEL +4 -0
  50. evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,662 @@
1
+ """
2
+ Test script for EvoGrad operators module.
3
+
4
+ Tests:
5
+ - sampling.py: Population initialisation
6
+ - selection.py: Parent selection
7
+ - crossover.py: Recombination operators
8
+ - mutation.py: Mutation operators
9
+ - repair.py: Bounds handling
10
+ """
11
+
12
+ import sys
13
+ import torch
14
+ import os
15
+
16
+ # Add parent to path for imports
17
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
18
+
19
+ from core.problem import Problem
20
+
21
+ from operators.sampling import (
22
+ UniformSampling,
23
+ LatinHypercubeSampling,
24
+ NormalSampling,
25
+ LogUniformSampling,
26
+ HaltonSampling,
27
+ )
28
+
29
+ from operators.selection import (
30
+ TournamentSelection,
31
+ RouletteSelection,
32
+ RankSelection,
33
+ RandomSelection,
34
+ TruncationSelection,
35
+ StochasticUniversalSampling,
36
+ )
37
+
38
+ from operators.crossover import (
39
+ SBXCrossover,
40
+ BlendCrossover,
41
+ BinomialCrossover,
42
+ ExponentialCrossover,
43
+ UniformCrossover,
44
+ ArithmeticCrossover,
45
+ NPointCrossover,
46
+ )
47
+
48
+ from operators.mutation import (
49
+ PolynomialMutation,
50
+ GaussianMutation,
51
+ UniformMutation,
52
+ NonUniformMutation,
53
+ BoundaryMutation,
54
+ NoMutation,
55
+ CombinedMutation,
56
+ )
57
+
58
+ from operators.repair import (
59
+ ClipRepair,
60
+ ReflectRepair,
61
+ WrapRepair,
62
+ RandomRepair,
63
+ BoundsRepair,
64
+ SoftClipRepair,
65
+ PenaltyRepair,
66
+ NoRepair,
67
+ )
68
+
69
+
70
+ # Helper to create a test problem
71
def create_test_problem(n_var=10):
    """Build a sphere-minimisation ``Problem`` on the box [-5, 5]^n_var.

    Args:
        n_var: Number of decision variables.

    Returns:
        A ``Problem`` whose objective sums the squared coordinates over the
        last axis, so a batch of points yields one objective value per row.
    """
    def _sum_of_squares(x):
        # Reduce over the last axis so batched inputs give one value per row.
        return (x ** 2).sum(dim=-1)

    return Problem(
        n_var=n_var,
        xl=-5.0,
        xu=5.0,
        objective=_sum_of_squares,
    )
75
+
76
+
77
def test_sampling():
    """Test the population-initialisation operators in ``sampling.py``.

    Each sampler draws ``n_samples`` points for a small problem. The test
    checks:
      * the output shape of every sampler;
      * box feasibility for the samplers expected to respect the bounds;
      * seed reproducibility for ``UniformSampling``.
    """
    print("\n" + "=" * 60)
    print("Testing sampling.py")
    print("=" * 60)

    n_var = 5
    n_samples = 50
    expected_shape = (n_samples, n_var)
    problem = create_test_problem(n_var=n_var)

    def _assert_within_bounds(pop):
        # Every coordinate of every sample must lie inside [xl, xu].
        assert (pop >= problem.xl).all() and (pop <= problem.xu).all()
        print(" All within bounds: ✓")

    # Test UniformSampling
    print("\n1. Testing UniformSampling...")
    pop = UniformSampling(seed=42)(n_samples, problem)
    print(f" Shape: {pop.shape}")
    assert pop.shape == expected_shape
    _assert_within_bounds(pop)

    # Two samplers built with the same seed must yield identical populations.
    pop2 = UniformSampling(seed=42)(n_samples, problem)
    assert torch.allclose(pop, pop2)
    print(" Reproducibility with seed: ✓")

    # Test LatinHypercubeSampling
    print("\n2. Testing LatinHypercubeSampling...")
    pop_lhs = LatinHypercubeSampling(smooth=True, seed=42)(n_samples, problem)
    print(f" Shape: {pop_lhs.shape}")
    assert pop_lhs.shape == expected_shape
    _assert_within_bounds(pop_lhs)

    # Test NormalSampling
    print("\n3. Testing NormalSampling...")
    pop_normal = NormalSampling(sigma_factor=0.2, seed=42)(n_samples, problem)
    print(f" Shape: {pop_normal.shape}")
    assert pop_normal.shape == expected_shape
    # Gaussian draws are not guaranteed to stay inside the box, so the
    # in-bounds fraction is reported rather than asserted.
    within = ((pop_normal >= problem.xl) & (pop_normal <= problem.xu)).float().mean()
    print(f" Fraction within bounds: {within:.1%}")

    # Test LogUniformSampling
    print("\n4. Testing LogUniformSampling...")
    # Log-scale sampling needs strictly positive bounds.
    log_problem = Problem(n_var=n_var, xl=0.001, xu=1000.0, objective=lambda x: x.sum(-1))
    pop_log = LogUniformSampling(base=10, seed=42)(n_samples, log_problem)
    print(f" Shape: {pop_log.shape}")
    print(f" Min: {pop_log.min().item():.4f}, Max: {pop_log.max().item():.4f}")
    assert pop_log.shape == expected_shape
    assert (pop_log > 0).all()

    # Test HaltonSampling
    print("\n5. Testing HaltonSampling...")
    pop_halton = HaltonSampling(scramble=True, seed=42)(n_samples, problem)
    print(f" Shape: {pop_halton.shape}")
    assert pop_halton.shape == expected_shape
    _assert_within_bounds(pop_halton)

    print("\n✓ sampling.py tests passed!")
139
+
140
+
141
def test_selection():
    """Test the parent-selection operators in ``selection.py``.

    Uses a random population with random fitness (lower is better). The test
    checks that:
      * every operator returns ``n_select`` rows of width ``n_var``;
      * tournament indices are in range and match the returned rows;
      * gradients reach the population in differentiable tournament mode.
    """
    print("\n" + "=" * 60)
    print("Testing selection.py")
    print("=" * 60)

    # Seed the global RNG so the population/fitness (and any operator that
    # draws from the global generator) are reproducible across runs.
    torch.manual_seed(0)

    n_pop = 50
    n_var = 5
    population = torch.randn(n_pop, n_var)
    fitness = torch.randn(n_pop)  # Lower is better

    n_select = 30
    expected_shape = (n_select, n_var)

    # Test TournamentSelection
    print("\n1. Testing TournamentSelection...")
    tournament = TournamentSelection(tournament_size=3, replacement=True)
    selected = tournament(population, fitness, n_select)
    print(f" Selected shape: {selected.shape}")
    assert selected.shape == expected_shape

    # Test with indices
    selected, indices = tournament(population, fitness, n_select, return_indices=True)
    print(f" Indices shape: {indices.shape}")
    assert indices.shape == (n_select,)
    assert (indices >= 0).all() and (indices < n_pop).all()
    # The returned rows must be exactly the population rows at those indices.
    assert torch.equal(selected, population[indices])
    print(" Index range valid: ✓")

    # Test differentiable mode
    print("\n2. Testing TournamentSelection (differentiable)...")
    tournament_diff = TournamentSelection(
        tournament_size=3,
        differentiable=True,
        temperature=1.0,
        learn_temperature=True,
    )
    pop_param = torch.nn.Parameter(population.clone())
    selected_diff = tournament_diff(pop_param, fitness, n_select)
    print(f" Differentiable selection shape: {selected_diff.shape}")
    assert selected_diff.shape == expected_shape

    # Check gradients flow back to the population.
    selected_diff.sum().backward()
    assert pop_param.grad is not None
    print(" Gradients flow: ✓")

    # Test RouletteSelection
    print("\n3. Testing RouletteSelection...")
    selected_roulette = RouletteSelection()(population, fitness, n_select)
    print(f" Selected shape: {selected_roulette.shape}")
    assert selected_roulette.shape == expected_shape

    # Test RankSelection (linear and exponential schemes)
    print("\n4. Testing RankSelection...")
    selected_rank = RankSelection(scheme='linear', selection_pressure=1.5)(
        population, fitness, n_select
    )
    print(f" Selected shape: {selected_rank.shape}")
    assert selected_rank.shape == expected_shape

    selected_rank_exp = RankSelection(scheme='exponential')(population, fitness, n_select)
    assert selected_rank_exp.shape == expected_shape
    print(" Exponential scheme works: ✓")

    # Test RandomSelection
    print("\n5. Testing RandomSelection...")
    selected_random = RandomSelection(replacement=True)(population, fitness, n_select)
    print(f" Selected shape: {selected_random.shape}")
    assert selected_random.shape == expected_shape

    # Test TruncationSelection
    print("\n6. Testing TruncationSelection...")
    selected_trunc = TruncationSelection(truncation_ratio=0.5)(population, fitness, n_select)
    print(f" Selected shape: {selected_trunc.shape}")
    assert selected_trunc.shape == expected_shape

    # Test StochasticUniversalSampling
    print("\n7. Testing StochasticUniversalSampling...")
    selected_sus = StochasticUniversalSampling()(population, fitness, n_select)
    print(f" Selected shape: {selected_sus.shape}")
    assert selected_sus.shape == expected_shape

    print("\n✓ selection.py tests passed!")
223
+
224
+
225
def test_crossover():
    """Test the recombination operators in ``crossover.py``.

    The test checks that:
      * every operator maps ``n_pairs`` parent pairs to ``n_pairs`` offspring
        of width ``n_var``;
      * DE-style (binomial/exponential) crossover copies each gene verbatim
        from one of the two parents;
      * whole arithmetic crossover equals the weighted average;
      * gradients flow in differentiable SBX and binomial modes.
    """
    print("\n" + "=" * 60)
    print("Testing crossover.py")
    print("=" * 60)

    # Seed the global RNG so parents (and any operator that draws from the
    # global generator) are reproducible across runs.
    torch.manual_seed(0)

    n_pairs = 25
    n_var = 10
    expected_shape = (n_pairs, n_var)

    # Create parent pairs
    parent1 = torch.randn(n_pairs, n_var)
    parent2 = torch.randn(n_pairs, n_var)

    def _genes_from_parents(child):
        # Discrete recombination copies every gene from one of the parents.
        return bool(((child == parent1) | (child == parent2)).all())

    # Test SBXCrossover
    print("\n1. Testing SBXCrossover...")
    offspring = SBXCrossover(eta=15, prob=0.9)(parent1, parent2)
    print(f" Offspring shape: {offspring.shape}")
    assert offspring.shape == expected_shape

    # Test differentiable mode
    print("\n2. Testing SBXCrossover (differentiable)...")
    sbx_diff = SBXCrossover(
        eta=15,
        prob=0.9,
        differentiable=True,
        learn_eta=True,
        learn_prob=True,
    )
    p1 = torch.nn.Parameter(parent1.clone())
    p2 = torch.nn.Parameter(parent2.clone())
    offspring_diff = sbx_diff(p1, p2)
    assert offspring_diff.shape == expected_shape

    offspring_diff.sum().backward()
    assert p1.grad is not None
    assert sbx_diff._log_eta.grad is not None
    print(" Gradients flow through SBX: ✓")
    # The stored learnable parameter is ``_log_eta``; label it as such.
    print(f" Learnable log_eta: {sbx_diff._log_eta.item():.4f}")

    # Test BlendCrossover
    print("\n3. Testing BlendCrossover...")
    offspring_blend = BlendCrossover(alpha=0.5)(parent1, parent2)
    print(f" Offspring shape: {offspring_blend.shape}")
    assert offspring_blend.shape == expected_shape

    # Test BinomialCrossover (DE-style): parent1=target, parent2=donor
    print("\n4. Testing BinomialCrossover...")
    offspring_bin = BinomialCrossover(cr=0.9)(parent1, parent2)
    print(f" Offspring shape: {offspring_bin.shape}")
    assert offspring_bin.shape == expected_shape
    assert _genes_from_parents(offspring_bin)

    # Test differentiable binomial; only the target is a leaf parameter.
    binomial_diff = BinomialCrossover(cr=0.9, differentiable=True, learn_cr=True)
    p1 = torch.nn.Parameter(parent1.clone())
    offspring_bin_diff = binomial_diff(p1, parent2)
    offspring_bin_diff.sum().backward()
    assert p1.grad is not None
    print(" Gradients flow through binomial: ✓")

    # Test ExponentialCrossover (DE-style)
    print("\n5. Testing ExponentialCrossover...")
    offspring_exp = ExponentialCrossover(cr=0.9)(parent1, parent2)
    print(f" Offspring shape: {offspring_exp.shape}")
    assert offspring_exp.shape == expected_shape
    assert _genes_from_parents(offspring_exp)

    # Test UniformCrossover
    print("\n6. Testing UniformCrossover...")
    offspring_unif = UniformCrossover(prob=0.9)(parent1, parent2)
    print(f" Offspring shape: {offspring_unif.shape}")
    assert offspring_unif.shape == expected_shape

    # Test ArithmeticCrossover
    print("\n7. Testing ArithmeticCrossover...")
    offspring_arith = ArithmeticCrossover(alpha=0.5, whole=True)(parent1, parent2)
    print(f" Offspring shape: {offspring_arith.shape}")

    # Whole arithmetic crossover with alpha=0.5 is the plain average.
    expected = 0.5 * parent1 + 0.5 * parent2
    assert torch.allclose(offspring_arith, expected)
    print(" Weighted average verified: ✓")

    # Test NPointCrossover
    print("\n8. Testing NPointCrossover...")
    offspring_npoint = NPointCrossover(n_points=2)(parent1, parent2)
    print(f" Offspring shape: {offspring_npoint.shape}")
    assert offspring_npoint.shape == expected_shape

    print("\n✓ crossover.py tests passed!")
315
+
316
+
317
def test_mutation():
    """Test the mutation operators in ``mutation.py``.

    The test checks that:
      * every operator preserves the population shape;
      * polynomial mutation actually changes some genes;
      * boundary mutation only moves genes onto a bound;
      * ``NoMutation`` is the identity;
      * gradients flow in the differentiable polynomial and Gaussian modes.
    """
    print("\n" + "=" * 60)
    print("Testing mutation.py")
    print("=" * 60)

    # Seed the global RNG so the population (and any operator that draws
    # from the global generator) is reproducible across runs.
    torch.manual_seed(0)

    n_pop = 50
    n_var = 10

    population = torch.randn(n_pop, n_var)
    xl = torch.full((n_var,), -5.0)
    xu = torch.full((n_var,), 5.0)

    # Test PolynomialMutation
    print("\n1. Testing PolynomialMutation...")
    mutated = PolynomialMutation(eta=20, prob=0.1)(population, xl, xu)
    print(f" Mutated shape: {mutated.shape}")
    assert mutated.shape == population.shape

    # Count genes (columns) changed in at least one individual. With 500
    # genes at prob=0.1, seeing no change at all would indicate a bug.
    changed = (mutated != population).any(dim=0).sum()
    print(f" Genes with changes: {changed.item()}/{n_var}")
    assert changed.item() > 0

    # Test differentiable mode
    print("\n2. Testing PolynomialMutation (differentiable)...")
    poly_diff = PolynomialMutation(
        eta=20,
        prob=0.1,
        differentiable=True,
        learn_eta=True,
        learn_prob=True,
    )
    pop_param = torch.nn.Parameter(population.clone())
    poly_diff(pop_param, xl, xu).sum().backward()
    assert pop_param.grad is not None
    print(" Gradients flow through polynomial mutation: ✓")

    # Test GaussianMutation
    print("\n3. Testing GaussianMutation...")
    mutated_gauss = GaussianMutation(sigma=0.1, prob=0.2)(population, xl, xu)
    print(f" Mutated shape: {mutated_gauss.shape}")
    assert mutated_gauss.shape == population.shape

    # Test with sigma_frac (step size relative to the bound range)
    mutated_frac = GaussianMutation(sigma_frac=0.05)(population, xl, xu)
    assert mutated_frac.shape == population.shape
    print(" With sigma_frac: ✓")

    # Test differentiable Gaussian
    gauss_diff = GaussianMutation(sigma=0.1, differentiable=True, learn_sigma=True)
    pop_param = torch.nn.Parameter(population.clone())
    gauss_diff(pop_param, xl, xu).sum().backward()
    assert pop_param.grad is not None
    print(" Gradients flow through Gaussian mutation: ✓")

    # Test UniformMutation
    print("\n4. Testing UniformMutation...")
    mutated_unif = UniformMutation(prob=0.1)(population, xl, xu)
    print(f" Mutated shape: {mutated_unif.shape}")
    assert mutated_unif.shape == population.shape

    # Test NonUniformMutation
    print("\n5. Testing NonUniformMutation...")
    nonunif_mut = NonUniformMutation(max_generations=100, b=5.0)
    nonunif_mut.set_generation(50)
    mutated_nonunif = nonunif_mut(population, xl, xu)
    print(f" Mutated shape: {mutated_nonunif.shape}")
    print(f" Current generation: {nonunif_mut.generation}")
    assert mutated_nonunif.shape == population.shape
    assert nonunif_mut.generation == 50

    # Test BoundaryMutation
    print("\n6. Testing BoundaryMutation...")
    mutated_boundary = BoundaryMutation(prob=0.1)(population, xl, xu)
    print(f" Mutated shape: {mutated_boundary.shape}")
    assert mutated_boundary.shape == population.shape

    # Check boundary values
    at_lower = mutated_boundary == xl
    at_upper = mutated_boundary == xu
    print(f" Has values at lower bound: {at_lower.any()}")
    print(f" Has values at upper bound: {at_upper.any()}")
    # A gene is either untouched or snapped exactly onto one of the bounds.
    assert ((mutated_boundary == population) | at_lower | at_upper).all()

    # Test NoMutation
    print("\n7. Testing NoMutation...")
    mutated_none = NoMutation()(population, xl, xu)
    assert torch.allclose(mutated_none, population)
    print(" NoMutation returns unchanged: ✓")

    # Test CombinedMutation
    print("\n8. Testing CombinedMutation...")
    combined = CombinedMutation([
        GaussianMutation(sigma=0.05, prob=0.5),
        PolynomialMutation(eta=20, prob=0.1),
    ])
    mutated_combined = combined(population, xl, xu)
    print(f" Combined mutation shape: {mutated_combined.shape}")
    assert mutated_combined.shape == population.shape

    print("\n✓ mutation.py tests passed!")
421
+
422
+
423
def test_repair():
    """Test the bounds-handling operators in ``repair.py``.

    Draws a population with many out-of-box coordinates and checks that:
      * the hard repairs (clip, reflect, wrap, random, ``BoundsRepair``)
        remove every violation;
      * reflection and wrapping map a known point to the expected value;
      * clipping leaves already-feasible coordinates untouched;
      * soft clipping is differentiable;
      * ``PenaltyRepair`` and ``NoRepair`` return the population unchanged;
      * ``is_within_bounds`` agrees with the repaired population.
    """
    print("\n" + "=" * 60)
    print("Testing repair.py")
    print("=" * 60)

    # Seed the global RNG so the population (and RandomRepair, if it draws
    # from the global generator) is reproducible across runs.
    torch.manual_seed(0)

    n_pop = 50
    n_var = 5

    xl = torch.zeros(n_var)
    xu = torch.ones(n_var)

    def _count_violations(x, lo, hi):
        # Number of individual coordinates lying outside [lo, hi].
        return ((x < lo) | (x > hi)).sum().item()

    # Create population with violations (many entries fall outside [0, 1])
    population = torch.randn(n_pop, n_var) * 2
    inside = (population >= xl) & (population <= xu)

    n_violations_before = _count_violations(population, xl, xu)
    print(f"\n Violations before repair: {n_violations_before}")
    # Guard against a vacuous test: the repairs must have something to fix.
    assert n_violations_before > 0

    # Test ClipRepair
    print("\n1. Testing ClipRepair...")
    clip = ClipRepair()
    repaired_clip = clip(population, xl, xu)
    violations_clip = _count_violations(repaired_clip, xl, xu)
    print(f" Violations after clip: {violations_clip}")
    assert violations_clip == 0
    # Clipping must not move coordinates that were already feasible.
    assert torch.equal(repaired_clip[inside], population[inside])
    print(" All within bounds: ✓")

    # Test ReflectRepair
    print("\n2. Testing ReflectRepair...")
    reflect = ReflectRepair()
    repaired_reflect = reflect(population, xl, xu)
    violations_reflect = _count_violations(repaired_reflect, xl, xu)
    print(f" Violations after reflect: {violations_reflect}")
    assert violations_reflect == 0
    print(" All within bounds: ✓")

    # Reflection mirrors the overshoot back into the box: 1.3 -> 0.7 on [0, 1].
    x_test = torch.tensor([[1.3]])  # 0.3 above upper bound of 1
    xl_test = torch.tensor([0.0])
    xu_test = torch.tensor([1.0])
    reflected = reflect(x_test, xl_test, xu_test)
    print(f" 1.3 reflects to: {reflected.item():.2f} (expected ~0.7)")
    assert abs(reflected.item() - 0.7) < 0.01

    # Test WrapRepair
    print("\n3. Testing WrapRepair...")
    wrap = WrapRepair()
    repaired_wrap = wrap(population, xl, xu)
    violations_wrap = _count_violations(repaired_wrap, xl, xu)
    print(f" Violations after wrap: {violations_wrap}")
    assert violations_wrap == 0

    # Wrapping is periodic: 1.3 -> 0.3 on [0, 1].
    wrapped = wrap(torch.tensor([[1.3]]), xl_test, xu_test)
    print(f" 1.3 wraps to: {wrapped.item():.2f} (expected ~0.3)")
    assert abs(wrapped.item() - 0.3) < 0.01

    # Test RandomRepair
    print("\n4. Testing RandomRepair...")
    repaired_random = RandomRepair()(population, xl, xu)
    violations_random = _count_violations(repaired_random, xl, xu)
    print(f" Violations after random: {violations_random}")
    assert violations_random == 0

    # Test BoundsRepair (configurable)
    print("\n5. Testing BoundsRepair...")
    for method in ['clip', 'reflect', 'wrap', 'random']:
        repaired = BoundsRepair(method=method)(population, xl, xu)
        assert _count_violations(repaired, xl, xu) == 0
        print(f" Method '{method}': ✓")

    # Test SoftClipRepair
    print("\n6. Testing SoftClipRepair...")
    soft_clip = SoftClipRepair(beta=10.0)
    repaired_soft = soft_clip(population, xl, xu)
    print(f" Soft clip shape: {repaired_soft.shape}")
    assert repaired_soft.shape == population.shape

    # Check gradients flow
    pop_param = torch.nn.Parameter(population.clone())
    soft_clip(pop_param, xl, xu).sum().backward()
    assert pop_param.grad is not None
    print(" Gradients flow through soft clip: ✓")

    # Test PenaltyRepair
    print("\n7. Testing PenaltyRepair...")
    penalty_repair = PenaltyRepair(penalty_weight=100.0, power=2.0)

    # Compute penalty; a penalty term should never reward violations.
    penalty = penalty_repair.compute_penalty(population, xl, xu)
    print(f" Penalty shape: {penalty.shape}")
    print(f" Mean penalty: {penalty.mean().item():.2f}")
    assert (penalty >= 0).all()

    # Verify no repair happens
    repaired_penalty = penalty_repair(population, xl, xu)
    assert torch.allclose(repaired_penalty, population)
    print(" PenaltyRepair returns unchanged: ✓")

    # Test NoRepair
    print("\n8. Testing NoRepair...")
    repaired_none = NoRepair()(population, xl, xu)
    assert torch.allclose(repaired_none, population)
    print(" NoRepair returns unchanged: ✓")

    # Test is_within_bounds helper
    print("\n9. Testing is_within_bounds...")
    feasible_after = clip.is_within_bounds(repaired_clip, xl, xu)
    assert feasible_after.all()

    # This is a feasibility mask (True = within bounds), not an infeasibility one.
    feasible_before = clip.is_within_bounds(population, xl, xu)
    print(f" Feasible before repair: {feasible_before.sum().item()}/{n_pop}")

    print("\n✓ repair.py tests passed!")
543
+
544
+
545
def test_operator_integration():
    """Test the operators chained into a GA-style generation step.

    Part 1 runs sampling -> evaluation -> selection -> SBX crossover ->
    polynomial mutation -> reflect repair, and checks shapes and bounds.

    Part 2 runs a differentiable selection -> SBX -> Gaussian mutation chain.
    It checks that gradients reach the population and the learnable operator
    parameters.
    """
    print("\n" + "=" * 60)
    print("Testing Operator Integration")
    print("=" * 60)

    # Seed the global RNG so unseeded stochastic operators are reproducible.
    torch.manual_seed(0)

    n_var = 10
    n_pop = 50
    # Crossover pairs the first half of the parents with the second half, so
    # the pair count is derived from n_pop instead of hard-coded separately.
    n_pairs = n_pop // 2

    # Create problem
    problem = create_test_problem(n_var=n_var)

    print("\n1. Testing full GA-style pipeline...")

    # Sampling
    population = UniformSampling(seed=42)(n_pop, problem)
    print(f" Initial population: {population.shape}")
    assert population.shape == (n_pop, n_var)

    # Evaluate
    fitness = problem.evaluate(population)
    print(f" Fitness: {fitness.shape}")

    # Selection
    parents = TournamentSelection(tournament_size=3)(population, fitness, n_pop)
    print(f" Selected parents: {parents.shape}")
    assert parents.shape == (n_pop, n_var)

    # Crossover
    p1, p2 = parents[:n_pairs], parents[n_pairs:2 * n_pairs]
    offspring = SBXCrossover(eta=15, prob=0.9)(p1, p2)
    print(f" Offspring after crossover: {offspring.shape}")
    assert offspring.shape == (n_pairs, n_var)

    # Mutation
    offspring = PolynomialMutation(eta=20, prob=0.1)(offspring, problem.xl, problem.xu)
    print(f" Offspring after mutation: {offspring.shape}")

    # Repair
    offspring = ReflectRepair()(offspring, problem.xl, problem.xu)
    print(f" Offspring after repair: {offspring.shape}")
    assert offspring.shape == (n_pairs, n_var)

    # Verify bounds
    assert (offspring >= problem.xl).all() and (offspring <= problem.xu).all()
    print(" All offspring within bounds: ✓")

    print("\n2. Testing differentiable pipeline...")

    # Create differentiable operators
    selection_diff = TournamentSelection(
        tournament_size=3,
        differentiable=True,
        temperature=1.0,
    )
    crossover_diff = SBXCrossover(
        eta=15,
        prob=0.9,
        differentiable=True,
        learn_eta=True,
    )
    mutation_diff = GaussianMutation(
        sigma=0.1,
        differentiable=True,
        learn_sigma=True,
    )

    # Run through pipeline with gradient tracking
    pop_param = torch.nn.Parameter(population.clone())

    parents_diff = selection_diff(pop_param, fitness, n_pop)
    p1_diff, p2_diff = parents_diff[:n_pairs], parents_diff[n_pairs:2 * n_pairs]
    offspring_diff = crossover_diff(p1_diff, p2_diff)
    offspring_diff = mutation_diff(offspring_diff, problem.xl, problem.xu)

    # Compute loss and backprop
    offspring_diff.sum().backward()

    assert pop_param.grad is not None
    print(" Gradients flow through entire pipeline: ✓")

    # Check learnable parameters have gradients
    assert crossover_diff._log_eta.grad is not None
    print(f" Crossover eta gradient: {crossover_diff._log_eta.grad.item():.6f}")

    assert mutation_diff._log_sigma.grad is not None
    print(f" Mutation sigma gradient: {mutation_diff._log_sigma.grad.item():.6f}")

    print("\n✓ Operator integration tests passed!")
633
+
634
+
635
def run_all_tests():
    """Run every operator test and print a summary.

    Every test runs even if an earlier one fails, so a single invocation
    reports all broken operator modules at once instead of stopping at the
    first failure.

    Returns:
        True if every test passed, False if any raised an exception.
    """
    import traceback

    print("\n" + "#" * 60)
    print("# EvoGrad Operators Module Tests")
    print("#" * 60)

    tests = (
        test_sampling,
        test_selection,
        test_crossover,
        test_mutation,
        test_repair,
        test_operator_integration,
    )

    failed = []
    for test in tests:
        try:
            test()
        except Exception as e:
            # Include the exception type: a bare ``assert`` has an empty
            # message, which would otherwise print nothing informative.
            print(f"\n✗ TEST FAILED: {test.__name__}: {type(e).__name__}: {e}")
            traceback.print_exc()
            failed.append(test.__name__)

    print("\n" + "=" * 60)
    if failed:
        print(f"✗ {len(failed)}/{len(tests)} OPERATOR TEST(S) FAILED: {', '.join(failed)}")
        print("=" * 60)
        return False

    print("✓ ALL OPERATORS TESTS PASSED!")
    print("=" * 60)
    return True
658
+
659
+
660
if __name__ == "__main__":
    # Translate the boolean outcome into a conventional process exit status
    # (0 = success, 1 = failure); SystemExit is what sys.exit raises.
    raise SystemExit(0 if run_all_tests() else 1)