evograd-diff 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. evograd/__init__.py +67 -0
  2. evograd/algorithms/__init__.py +138 -0
  3. evograd/algorithms/cmaes.py +1365 -0
  4. evograd/algorithms/de.py +895 -0
  5. evograd/algorithms/ga.py +532 -0
  6. evograd/algorithms/pso.py +648 -0
  7. evograd/algorithms/shade.py +1165 -0
  8. evograd/benchmarks/functions/__init__.py +229 -0
  9. evograd/benchmarks/functions/base.py +217 -0
  10. evograd/benchmarks/functions/cec2017/__init__.py +250 -0
  11. evograd/benchmarks/functions/cec2017/basic.py +413 -0
  12. evograd/benchmarks/functions/cec2017/composition.py +580 -0
  13. evograd/benchmarks/functions/cec2017/data.pkl +0 -0
  14. evograd/benchmarks/functions/cec2017/data.py +350 -0
  15. evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
  16. evograd/benchmarks/functions/cec2017/simple.py +326 -0
  17. evograd/benchmarks/functions/classical.py +649 -0
  18. evograd/benchmarks/functions/smoothed_funnel.py +476 -0
  19. evograd/benchmarks/functions/transforms.py +463 -0
  20. evograd/benchmarks/run_benchmark_functions.py +1208 -0
  21. evograd/core/__init__.py +73 -0
  22. evograd/core/algorithm.py +778 -0
  23. evograd/core/maximize.py +269 -0
  24. evograd/core/minimize.py +740 -0
  25. evograd/core/problem.py +444 -0
  26. evograd/core/result.py +571 -0
  27. evograd/core/termination.py +602 -0
  28. evograd/operators/__init__.py +178 -0
  29. evograd/operators/crossover.py +1117 -0
  30. evograd/operators/mutation.py +1098 -0
  31. evograd/operators/relaxations.py +175 -0
  32. evograd/operators/repair.py +601 -0
  33. evograd/operators/sampling.py +577 -0
  34. evograd/operators/selection.py +981 -0
  35. evograd/operators/survival.py +1000 -0
  36. evograd/tests/__init__.py +11 -0
  37. evograd/tests/run_all.py +78 -0
  38. evograd/tests/test_core.py +528 -0
  39. evograd/tests/test_ga.py +572 -0
  40. evograd/tests/test_operators.py +662 -0
  41. evograd/tests/test_per_individual.py +326 -0
  42. evograd/tests/test_utils.py +328 -0
  43. evograd/utils/__init__.py +97 -0
  44. evograd/utils/callbacks.py +926 -0
  45. evograd/utils/device.py +502 -0
  46. evograd/utils/duplicates.py +421 -0
  47. evograd_diff-0.1.0.dist-info/METADATA +439 -0
  48. evograd_diff-0.1.0.dist-info/RECORD +50 -0
  49. evograd_diff-0.1.0.dist-info/WHEEL +4 -0
  50. evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,444 @@
1
+ """
2
+ Problem definition for EvoGrad optimisation.
3
+
4
+ This module provides the Problem class that encapsulates:
5
+ - Objective function(s)
6
+ - Variable bounds (xl, xu)
7
+ - Constraints (equality and inequality)
8
+ - Problem metadata
9
+
10
+ The Problem class is intentionally minimal. Bounds handling (repair),
11
+ population initialisation (sampling), and other operations belong
12
+ to their respective operator classes in the operators subpackage.
13
+
14
+ Problems can be defined in two ways:
15
+ 1. Functional: Pass objective as a callable
16
+ 2. Subclassing: Override _evaluate() method
17
+
18
+ Example (functional):
19
+ >>> from evograd.core import Problem
20
+ >>>
21
+ >>> problem = Problem(
22
+ ... objective=ackley,
23
+ ... n_var=30,
24
+ ... xl=-32.768,
25
+ ... xu=32.768,
26
+ ... )
27
+
28
+ Example (subclassing):
29
+ >>> class Rastrigin(Problem):
30
+ ... def __init__(self, n_var=30):
31
+ ... super().__init__(n_var=n_var, xl=-5.12, xu=5.12)
32
+ ...
33
+ ... def _evaluate(self, x):
34
+ ... A = 10.0
35
+ ... return A * x.shape[-1] + (x**2 - A * torch.cos(2 * torch.pi * x)).sum(dim=-1)
36
+ """
37
+
38
+ from __future__ import annotations
39
+
40
+ from typing import (
41
+ TYPE_CHECKING,
42
+ Any,
43
+ Callable,
44
+ Dict,
45
+ List,
46
+ Optional,
47
+ Tuple,
48
+ Union,
49
+ )
50
+
51
+ import torch
52
+ import torch.nn as nn
53
+
54
+ from evograd.utils.device import get_device, ensure_tensor
55
+
56
+ if TYPE_CHECKING:
57
+ from torch import Tensor
58
+
59
+ __all__ = [
60
+ "Problem",
61
+ ]
62
+
63
+
64
+ # =============================================================================
65
+ # Problem Class
66
+ # =============================================================================
67
+
68
class Problem(nn.Module):
    """
    Optimisation problem definition.

    Encapsulates the objective function, variable bounds, and optional
    constraints. The Problem class is intentionally minimal - it only
    stores problem definition, not operations on solutions.

    For bounds handling, use operators/repair.py.
    For population initialisation, use operators/sampling.py.

    The problem can be defined either by passing an objective callable
    to the constructor, or by subclassing and overriding _evaluate().

    Args:
        objective: Callable that takes (N, n_var) tensor and returns (N,) fitness.
            If None, subclass must override _evaluate().
        n_var: Number of decision variables.
        xl: Lower bounds. Can be:
            - Scalar (applied to all variables)
            - List of length n_var
            - Tensor of shape (n_var,)
        xu: Upper bounds (same format as xl).
        constraints: List of constraint tuples: (func, type).
            - func: Callable (N, n_var) -> (N,) or (N, n_constraints)
            - type: 'ineq' for g(x) <= 0, 'eq' for h(x) = 0
        n_obj: Number of objectives (default: 1, multi-objective planned).
        name: Optional problem name for identification.
        device: Computation device (default: auto-detect).
        dtype: Tensor dtype (default: float32). Use ``torch.float64`` for
            problems that require higher numerical precision, such as parameter
            estimation with stiff ODE solvers. All operators respect the dtype
            propagated from the Problem; ensure the Algorithm is created with a
            matching dtype to avoid silent precision loss.

    Attributes:
        n_var: Number of decision variables.
        n_obj: Number of objectives.
        n_ieq_constr: Number of registered inequality constraint callables
            (one per ``(func, 'ineq')`` tuple, regardless of how many columns
            the callable returns).
        n_eq_constr: Number of registered equality constraint callables
            (counted the same way).
        n_constr: Total number of registered constraint callables.
        xl: Lower bounds tensor of shape (n_var,).
        xu: Upper bounds tensor of shape (n_var,).
        name: Problem name.
        device: Device inputs are moved to before evaluation. Kept in sync
            with the bound buffers when the module is moved with ``.to()``.
        dtype: Dtype inputs are cast to before evaluation. Kept in sync with
            the bound buffers when the module is cast.

    Raises:
        ValueError: If neither an objective nor an ``_evaluate`` override is
            available, if ``n_var``/``n_obj`` are missing or < 1, if a
            constraint type is not 'ineq'/'eq', or if the bounds are invalid.

    Example:
        >>> # Simple unconstrained problem
        >>> problem = Problem(
        ...     objective=lambda x: (x ** 2).sum(dim=-1),
        ...     n_var=10,
        ...     xl=-5.0,
        ...     xu=5.0,
        ... )
        >>>
        >>> # Evaluate a batch of solutions
        >>> x = torch.rand(100, 10) * 10 - 5
        >>> fitness = problem.evaluate(x)
        >>> print(fitness.shape)  # torch.Size([100])

        >>> # Problem with constraints
        >>> problem = Problem(
        ...     objective=lambda x: x[:, 0] + x[:, 1],
        ...     n_var=2,
        ...     xl=0.0,
        ...     xu=10.0,
        ...     constraints=[
        ...         (lambda x: x[:, 0] + x[:, 1] - 5, 'ineq'),  # x0 + x1 <= 5
        ...         (lambda x: x[:, 0] - 2 * x[:, 1], 'eq'),  # x0 = 2 * x1
        ...     ],
        ... )
    """

    # Constraint type constants
    INEQ = "ineq"  # Inequality: g(x) <= 0
    EQ = "eq"  # Equality: h(x) = 0

    def __init__(
        self,
        objective: Optional[Callable[[Tensor], Tensor]] = None,
        n_var: Optional[int] = None,
        xl: Union[float, List[float], Tensor] = -100.0,
        xu: Union[float, List[float], Tensor] = 100.0,
        constraints: Optional[List[Tuple[Callable[[Tensor], Tensor], str]]] = None,
        n_obj: int = 1,
        name: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()

        # Validate inputs. A subclass counts as "providing an objective" only
        # if it actually overrides _evaluate.
        if objective is None and type(self)._evaluate is Problem._evaluate:
            raise ValueError(
                "Either provide 'objective' callable or subclass Problem "
                "and override _evaluate()"
            )

        if n_var is None:
            raise ValueError("n_var (number of variables) must be specified")

        if n_var < 1:
            raise ValueError(f"n_var must be >= 1, got {n_var}")

        if n_obj < 1:
            raise ValueError(f"n_obj must be >= 1, got {n_obj}")

        # Store configuration
        self.n_var = n_var
        self.n_obj = n_obj
        self._objective = objective
        self.name = name or self.__class__.__name__
        self.device = get_device(device)
        self.dtype = dtype

        # Process and register bounds. Buffers follow the module through
        # .to()/.cuda()/.double(); see _apply for the matching attribute sync.
        xl_tensor, xu_tensor = self._process_bounds(xl, xu, n_var)
        self.register_buffer("xl", xl_tensor)
        self.register_buffer("xu", xu_tensor)

        # Process constraints
        self._constraints: List[Tuple[Callable, str]] = []
        self.n_ieq_constr = 0
        self.n_eq_constr = 0

        if constraints is not None:
            for func, ctype in constraints:
                ctype = ctype.lower()
                if ctype not in (self.INEQ, self.EQ):
                    raise ValueError(
                        f"Constraint type must be 'ineq' or 'eq', got '{ctype}'"
                    )
                self._constraints.append((func, ctype))
                if ctype == self.INEQ:
                    self.n_ieq_constr += 1
                else:
                    self.n_eq_constr += 1

        self.n_constr = self.n_ieq_constr + self.n_eq_constr

    def _apply(self, fn, *args, **kwargs):
        """
        Apply ``fn`` to the module's tensors and resync ``device``/``dtype``.

        ``nn.Module.to()``, ``.cuda()``, ``.double()`` and friends move the
        registered ``xl``/``xu`` buffers through this hook. Without the resync
        below, ``self.device``/``self.dtype`` would keep their constructor
        values and ``evaluate()`` would cast inputs back to the old
        device/dtype while the bounds live on the new one.

        Extra positional/keyword arguments are forwarded untouched so the
        override works regardless of the installed torch version's signature.
        """
        module = super()._apply(fn, *args, **kwargs)
        bounds = getattr(self, "xl", None)
        if isinstance(bounds, torch.Tensor):
            self.device = bounds.device
            self.dtype = bounds.dtype
        return module

    def _coerce_bound(
        self,
        value: Union[float, List[float], Tensor],
        label: str,
        n_var: int,
    ) -> Tensor:
        """Convert one bound specification to a validated (n_var,) tensor."""
        bound = ensure_tensor(value, device=self.device, dtype=self.dtype)

        # Broadcast any single-element input (0-dim, (1,), (1, 1), ...) to
        # the full dimension; clone so the buffer owns its memory.
        if bound.numel() == 1:
            bound = bound.reshape(()).expand(n_var).clone()

        # Only the leading dimension used to be checked, which let
        # multi-dimensional bounds such as (n_var, k) slip through.
        if bound.dim() != 1:
            raise ValueError(
                f"{label} must be a scalar or a 1-D sequence of length "
                f"{n_var}, got shape {tuple(bound.shape)}"
            )
        if bound.shape[0] != n_var:
            raise ValueError(
                f"{label} has {bound.shape[0]} elements but n_var={n_var}"
            )
        return bound

    def _process_bounds(
        self,
        xl: Union[float, List[float], Tensor],
        xu: Union[float, List[float], Tensor],
        n_var: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Process and validate bounds.

        Args:
            xl: Lower bounds (scalar, sequence, or tensor).
            xu: Upper bounds (scalar, sequence, or tensor).
            n_var: Number of decision variables.

        Returns:
            Tuple ``(xl, xu)`` of tensors, each of shape (n_var,).

        Raises:
            ValueError: If a bound has the wrong shape, contains NaN, or if
                any lower bound exceeds its upper bound.
        """
        xl_tensor = self._coerce_bound(xl, "xl", n_var)
        xu_tensor = self._coerce_bound(xu, "xu", n_var)

        # NaN compares False against everything, so it would silently pass
        # the ordering check below; reject it explicitly.
        if torch.isnan(xl_tensor).any() or torch.isnan(xu_tensor).any():
            raise ValueError("Bounds must not contain NaN")

        # Validate bounds ordering
        if (xl_tensor > xu_tensor).any():
            raise ValueError("Lower bounds must be <= upper bounds")

        return xl_tensor, xu_tensor

    # =========================================================================
    # Evaluation
    # =========================================================================

    def _evaluate(self, x: Tensor) -> Tensor:
        """
        Evaluate objective function.

        Override this method in subclasses to define custom objectives.

        Args:
            x: Decision variables of shape (N, n_var).

        Returns:
            Fitness values of shape (N,) for single-objective,
            or (N, n_obj) for multi-objective.

        Raises:
            NotImplementedError: If no objective callable was provided and
                the method was not overridden.
        """
        if self._objective is not None:
            return self._objective(x)
        raise NotImplementedError(
            "Subclass must implement _evaluate() if objective not provided"
        )

    def evaluate(self, x: Tensor) -> Tensor:
        """
        Evaluate fitness of solutions.

        Args:
            x: Decision variables of shape (N, n_var) or (n_var,).

        Returns:
            Fitness values of shape (N,) or scalar.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``n_var``.
        """
        # Handle single solution
        squeeze_output = False
        if x.dim() == 1:
            x = x.unsqueeze(0)
            squeeze_output = True

        # Ensure correct device and dtype
        x = x.to(device=self.device, dtype=self.dtype)

        # Validate shape
        if x.shape[-1] != self.n_var:
            raise ValueError(
                f"Expected {self.n_var} variables, got {x.shape[-1]}"
            )

        # Evaluate objective
        fitness = self._evaluate(x)

        # An objective that reduces everything to a scalar still yields a
        # one-element batch so the squeeze logic below stays uniform.
        if fitness.dim() == 0:
            fitness = fitness.unsqueeze(0)

        if squeeze_output and fitness.shape[0] == 1:
            fitness = fitness.squeeze(0)

        return fitness

    def evaluate_constraints(self, x: Tensor) -> Dict[str, Tensor]:
        """
        Evaluate all constraints.

        Args:
            x: Decision variables of shape (N, n_var) or (n_var,).

        Returns:
            Dictionary with keys:
            - 'ineq': Inequality constraint values g(x), one column block per
              inequality callable, shape (N, total inequality columns)
            - 'eq': Equality constraint values h(x), shape
              (N, total equality columns)
            - 'cv': Total constraint violation per solution, shape (N,)
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)

        x = x.to(device=self.device, dtype=self.dtype)
        n_solutions = x.shape[0]

        ineq_values = []
        eq_values = []

        for func, ctype in self._constraints:
            val = func(x)
            if val.dim() == 0:
                # torch.cat rejects 0-dim tensors; treat a scalar result as a
                # single (1, 1) entry, mirroring how evaluate() handles a
                # scalar objective.
                val = val.reshape(1, 1)
            elif val.dim() == 1:
                val = val.unsqueeze(-1)

            if ctype == self.INEQ:
                ineq_values.append(val)
            else:
                eq_values.append(val)

        # Stack constraint values (empty (N, 0) tensors when a kind is absent)
        if ineq_values:
            ineq = torch.cat(ineq_values, dim=-1)
        else:
            ineq = torch.zeros(n_solutions, 0, device=self.device, dtype=self.dtype)

        if eq_values:
            eq = torch.cat(eq_values, dim=-1)
        else:
            eq = torch.zeros(n_solutions, 0, device=self.device, dtype=self.dtype)

        # Compute constraint violation
        # For ineq: max(0, g(x)) (violation if positive)
        # For eq: |h(x)| (violation if non-zero)
        cv = torch.zeros(n_solutions, device=self.device, dtype=self.dtype)
        if ineq.shape[-1] > 0:
            cv = cv + torch.clamp(ineq, min=0).sum(dim=-1)
        if eq.shape[-1] > 0:
            cv = cv + torch.abs(eq).sum(dim=-1)

        return {
            "ineq": ineq,
            "eq": eq,
            "cv": cv,
        }

    def is_feasible(self, x: Tensor, tol: float = 1e-6) -> Tensor:
        """
        Check if solutions satisfy all constraints.

        Args:
            x: Decision variables of shape (N, n_var) or (n_var,).
            tol: Tolerance for constraint satisfaction (a solution is
                feasible when its total violation is <= tol).

        Returns:
            Boolean tensor of shape (N,) or scalar.
        """
        squeeze = x.dim() == 1

        # Unconstrained problems are trivially feasible; skip evaluation.
        if self.n_constr == 0:
            if squeeze:
                return torch.tensor(True, device=self.device)
            return torch.ones(x.shape[0], dtype=torch.bool, device=self.device)

        cv = self.evaluate_constraints(x)["cv"]
        result = cv <= tol

        if squeeze:
            result = result.squeeze(0)

        return result

    def has_constraints(self) -> bool:
        """Check if problem has any constraints."""
        return self.n_constr > 0

    # =========================================================================
    # PyTorch Forward
    # =========================================================================

    def forward(self, x: Tensor) -> Tensor:
        """
        PyTorch forward pass (alias for evaluate).

        Enables using Problem as an nn.Module in computation graphs.
        """
        return self.evaluate(x)

    # =========================================================================
    # String Representation
    # =========================================================================

    def __repr__(self) -> str:
        parts = [
            f"name='{self.name}'",
            f"n_var={self.n_var}",
            f"n_obj={self.n_obj}",
        ]

        if self.n_constr > 0:
            parts.append(f"n_constr={self.n_ieq_constr}ineq+{self.n_eq_constr}eq")

        return f"{self.__class__.__name__}({', '.join(parts)})"

    def summary(self) -> str:
        """Return detailed problem summary."""
        rule = "=" * 50
        lines = [
            rule,
            f"Problem: {self.name}",
            rule,
            f" Variables: {self.n_var}",
            f" Objectives: {self.n_obj}",
        ]

        if self.n_constr > 0:
            lines.append(
                f" Constraints: {self.n_constr} "
                f"({self.n_ieq_constr} inequality, {self.n_eq_constr} equality)"
            )
        else:
            lines.append(" Constraints: None")

        lines.extend([
            "",
            "Bounds:",
            f" xl: [{float(self.xl.min()):.4g}, {float(self.xl.max()):.4g}]",
            f" xu: [{float(self.xu.min()):.4g}, {float(self.xu.max()):.4g}]",
            "",
            f"Device: {self.device}",
            rule,
        ])

        return "\n".join(lines)