moospread 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. moospread/__init__.py +3 -0
  2. moospread/core.py +1881 -0
  3. moospread/problem.py +193 -0
  4. moospread/tasks/__init__.py +4 -0
  5. moospread/tasks/dtlz_torch.py +139 -0
  6. moospread/tasks/mw_torch.py +274 -0
  7. moospread/tasks/re_torch.py +394 -0
  8. moospread/tasks/zdt_torch.py +112 -0
  9. moospread/utils/__init__.py +8 -0
  10. moospread/utils/constraint_utils/__init__.py +2 -0
  11. moospread/utils/constraint_utils/gradient.py +72 -0
  12. moospread/utils/constraint_utils/mgda_core.py +69 -0
  13. moospread/utils/constraint_utils/pmgda_solver.py +308 -0
  14. moospread/utils/constraint_utils/prefs.py +64 -0
  15. moospread/utils/ditmoo.py +127 -0
  16. moospread/utils/lhs.py +74 -0
  17. moospread/utils/misc.py +28 -0
  18. moospread/utils/mobo_utils/__init__.py +11 -0
  19. moospread/utils/mobo_utils/evolution/__init__.py +0 -0
  20. moospread/utils/mobo_utils/evolution/dom.py +60 -0
  21. moospread/utils/mobo_utils/evolution/norm.py +40 -0
  22. moospread/utils/mobo_utils/evolution/utils.py +97 -0
  23. moospread/utils/mobo_utils/learning/__init__.py +0 -0
  24. moospread/utils/mobo_utils/learning/model.py +40 -0
  25. moospread/utils/mobo_utils/learning/model_init.py +33 -0
  26. moospread/utils/mobo_utils/learning/model_update.py +51 -0
  27. moospread/utils/mobo_utils/learning/prediction.py +116 -0
  28. moospread/utils/mobo_utils/learning/utils.py +143 -0
  29. moospread/utils/mobo_utils/lhs_for_mobo.py +243 -0
  30. moospread/utils/mobo_utils/mobo/__init__.py +0 -0
  31. moospread/utils/mobo_utils/mobo/acquisition.py +209 -0
  32. moospread/utils/mobo_utils/mobo/algorithms.py +91 -0
  33. moospread/utils/mobo_utils/mobo/factory.py +86 -0
  34. moospread/utils/mobo_utils/mobo/mobo.py +132 -0
  35. moospread/utils/mobo_utils/mobo/selection.py +182 -0
  36. moospread/utils/mobo_utils/mobo/solver/__init__.py +5 -0
  37. moospread/utils/mobo_utils/mobo/solver/moead.py +17 -0
  38. moospread/utils/mobo_utils/mobo/solver/nsga2.py +10 -0
  39. moospread/utils/mobo_utils/mobo/solver/parego/__init__.py +1 -0
  40. moospread/utils/mobo_utils/mobo/solver/parego/parego.py +62 -0
  41. moospread/utils/mobo_utils/mobo/solver/parego/utils.py +34 -0
  42. moospread/utils/mobo_utils/mobo/solver/pareto_discovery/__init__.py +1 -0
  43. moospread/utils/mobo_utils/mobo/solver/pareto_discovery/buffer.py +364 -0
  44. moospread/utils/mobo_utils/mobo/solver/pareto_discovery/pareto_discovery.py +571 -0
  45. moospread/utils/mobo_utils/mobo/solver/pareto_discovery/utils.py +168 -0
  46. moospread/utils/mobo_utils/mobo/solver/solver.py +74 -0
  47. moospread/utils/mobo_utils/mobo/surrogate_model/__init__.py +2 -0
  48. moospread/utils/mobo_utils/mobo/surrogate_model/base.py +36 -0
  49. moospread/utils/mobo_utils/mobo/surrogate_model/gaussian_process.py +177 -0
  50. moospread/utils/mobo_utils/mobo/surrogate_model/thompson_sampling.py +79 -0
  51. moospread/utils/mobo_utils/mobo/surrogate_problem.py +44 -0
  52. moospread/utils/mobo_utils/mobo/transformation.py +106 -0
  53. moospread/utils/mobo_utils/mobo/utils.py +65 -0
  54. moospread/utils/mobo_utils/spread_mobo_utils.py +854 -0
  55. moospread/utils/offline_utils/__init__.py +10 -0
  56. moospread/utils/offline_utils/handle_task.py +203 -0
  57. moospread/utils/offline_utils/proxies.py +338 -0
  58. moospread/utils/spread_utils.py +91 -0
  59. moospread-0.1.0.dist-info/METADATA +75 -0
  60. moospread-0.1.0.dist-info/RECORD +63 -0
  61. moospread-0.1.0.dist-info/WHEEL +5 -0
  62. moospread-0.1.0.dist-info/licenses/LICENSE +10 -0
  63. moospread-0.1.0.dist-info/top_level.txt +1 -0
moospread/problem.py ADDED
@@ -0,0 +1,193 @@
1
+ """
2
+ This module defines an abstract base class for
3
+ multi-objective optimization problems using PyTorch.
4
+ """
5
+
6
+ import numpy as np
7
+ import torch
8
+ from abc import ABC, abstractmethod
9
+ from pymoo.util.cache import Cache
10
+
11
+
12
def default_shape(problem, n):
    """Return the expected array shape for each standard output key.

    Args:
        problem: Object exposing ``n_var``, ``n_obj``, ``n_ieq_constr`` and
            ``n_eq_constr`` attributes.
        n: Number of solutions in the batch being evaluated.

    Returns:
        Dict mapping 'F', 'G', 'H' to ``(n, count)``, their gradients
        'dF', 'dG', 'dH' to ``(n, count, n_var)``, and 'CV' to ``(n, 1)``.
    """
    width = problem.n_var
    counts = {
        'F': problem.n_obj,
        'G': problem.n_ieq_constr,
        'H': problem.n_eq_constr,
    }
    shapes = {key: (n, count) for key, count in counts.items()}
    shapes.update({'d' + key: (n, count, width) for key, count in counts.items()})
    shapes['CV'] = (n, 1)
    return shapes
23
+
24
class PymooProblemTorch(ABC):
    """Abstract base class for multi-objective problems evaluated with PyTorch.

    Mirrors the attribute layout of pymoo's ``Problem`` (``n_var``, ``n_obj``,
    constraint counts, ``xl``/``xu`` box bounds) while keeping every tensor
    operation differentiable.  Subclasses implement :meth:`_evaluate`.
    """

    def __init__(
        self,
        n_var: int = -1,
        n_obj: int = 1,
        n_ieq_constr: int = 0,
        n_eq_constr: int = 0,
        xl=None,
        xu=None,
        device: torch.device = None,
        elementwise: bool = False,
        replace_nan_by: float = None,
        strict: bool = True,
        **kwargs
    ):
        """Store problem dimensions, the torch device, and optional box bounds.

        Args:
            n_var: Number of decision variables (-1 = unspecified).
            n_obj: Number of objectives.
            n_ieq_constr: Number of inequality constraints.
            n_eq_constr: Number of equality constraints.
            xl: Lower bound; a tensor is stored as-is, a scalar is broadcast
                to shape ``(n_var,)``.  Only used when ``n_var > 0``.
            xu: Upper bound; same conventions as ``xl``.
            device: Target torch device; defaults to CPU.
            elementwise: Stored flag describing the :meth:`_evaluate`
                contract (1-d input per call vs. a whole 2-d batch).
            replace_nan_by: Stored but not consulted by the methods here;
                NaNs are instead handled in :meth:`evaluate` via
                ``torch.nan_to_num``.
            strict: Stored but not consulted by the methods here.
            **kwargs: Extra keywords (e.g. ``vtype`` passed by subclasses)
                are accepted and silently discarded.
        """
        self.n_var = n_var
        self.n_obj = n_obj
        self.n_ieq_constr = n_ieq_constr
        self.n_eq_constr = n_eq_constr
        # Default to CPU when no device is given.
        self.device = device or torch.device('cpu')
        self.elementwise = elementwise
        self.replace_nan_by = replace_nan_by
        self.strict = strict
        # Reference point (e.g. for hypervolume); subclasses assign it.
        self.ref_point = None

        # bounds: scalars are broadcast to (n_var,); tensors are used verbatim.
        if n_var > 0 and xl is not None:
            self.xl = (
                xl
                if isinstance(xl, torch.Tensor)
                else torch.full((n_var,), xl, dtype=torch.float, device=self.device)
            )
        else:
            self.xl = None
        if n_var > 0 and xu is not None:
            self.xu = (
                xu
                if isinstance(xu, torch.Tensor)
                else torch.full((n_var,), xu, dtype=torch.float, device=self.device)
            )
        else:
            self.xu = None

        self.is_discrete = False
        self.is_sequence = False
        self.original_bounds = None
        self.need_repair = False
        if self.has_bounds():
            # Keep a pristine copy of the bounds before any later mutation.
            self.original_bounds = [self.xl.clone(), self.xu.clone()]
            # NOTE(review): global_clamping is only defined when bounds
            # exist; reading it on an unbounded problem raises AttributeError.
            self.global_clamping = False
            self.need_repair = True

    def has_bounds(self) -> bool:
        """Return True when both lower and upper bounds are set."""
        return self.xl is not None and self.xu is not None

    def has_constraints(self) -> bool:
        """Return True when the problem declares any constraint."""
        return (self.n_ieq_constr + self.n_eq_constr) > 0

    def name(self) -> str:
        """Human-readable problem name (the concrete subclass name)."""
        return self.__class__.__name__

    def bounds(self):
        """Return ``[xl, xu]``; either entry may be None."""
        return [self.xl, self.xu]

    def evaluate(
        self,
        x: torch.Tensor,
        return_as_dict: bool = False,
        return_values_of=None
    ):
        """
        Evaluate objectives on input tensor x.

        Note: ``x.data`` is modified in place — NaN/inf entries are replaced
        and values are clamped into the box bounds before evaluation.

        Args:
            x: Tensor of shape (batch_size, n_var); set requires_grad=True
                on it if gradients through the objectives are needed.
            return_as_dict: If True, return the full dict of computed values.
            return_values_of: Optional key (or list of keys) of the internal
                out dict to return instead of F.

        Returns:
            If return_as_dict=True: the out dict (all computed values).
            Else if return_values_of names exactly one present key: that
            value alone.  Else if it names several: a tuple of those values
            (F itself is not included unless requested).
            Else: Tensor F of shape (batch_size, n_obj).

        Raises:
            TypeError: If ``x`` is not a torch.Tensor.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")
        # Sanitize in place via .data so the autograd graph is untouched.
        x.data = torch.nan_to_num(x.data, nan=0.0, posinf=1.0, neginf=0.0)

        # Ensure x is within bounds (element-wise clamp, in place).
        if self.has_bounds():
            if self.xl is not None:
                x.data = torch.max(x.data, self.xl.to(x.device))
            if self.xu is not None:
                x.data = torch.min(x.data, self.xu.to(x.device))

        out: dict = {}
        self._evaluate(x, out)

        # Prepare base return
        F = out.get("F")
        # If specific extra values requested
        extra = None
        if return_values_of is not None:
            if isinstance(return_values_of, (list, tuple)):
                extra = {k: out[k] for k in return_values_of if k in out}
            else:
                key = return_values_of
                extra = {key: out[key]} if key in out else {}

        if return_as_dict:
            if extra is not None:
                # extra is a subset of out, so this update is a no-op.
                out.update(extra)
            return out

        if extra is not None:
            if len(extra) == 1:
                return list(extra.values())[0]
            return tuple(extra[k] for k in extra)

        return F

    @abstractmethod
    def _evaluate(self, x, out: dict):
        """
        User-implemented evaluation.
        If elementwise: x is 1d Tensor of shape (n_var,).
        Otherwise x is 2d Tensor of shape (n_samples, n_var).
        out should be populated with keys 'F', 'G', 'H' as torch.Tensors.
        """
        pass

    @Cache
    def pareto_front(self, *args, **kwargs):
        """Return the Pareto front as a 2-d float tensor, or None.

        Decorated with pymoo's ``Cache`` — presumably memoizes the result;
        confirm semantics against ``pymoo.util.cache``.
        """
        pf = self._calc_pareto_front(*args, **kwargs)
        if pf is None:
            return None
        if not isinstance(pf, torch.Tensor):
            pf = torch.as_tensor(pf, dtype=torch.float32)
        if pf.dim() == 1:
            # Promote a single point to a (1, n_obj) matrix.
            pf = pf.unsqueeze(0)
        if pf.size(1) == 2:
            # Sort bi-objective fronts by the first objective.
            pf = pf[pf[:, 0].argsort()]
        return pf

    def __str__(self):
        # Compact one-line summary of the problem dimensions.
        return (
            f"Problem(name={self.name()}, n_var={self.n_var}, n_obj={self.n_obj}, "
            f"n_ieq_constr={self.n_ieq_constr}, n_eq_constr={self.n_eq_constr})"
        )
172
+
173
+
174
class BaseProblem(PymooProblemTorch):
    """Generic differentiable problem with bounds nudged inside (0, 1).

    The box bounds are tightened by 1e-6 on each side so gradients stay
    finite at the edges of the unit cube.
    """

    def __init__(self, n_var=30, n_obj=2, **kwargs):
        eps = 1e-6
        lower = torch.zeros(n_var, dtype=torch.float) + eps
        upper = torch.ones(n_var, dtype=torch.float) - eps
        super().__init__(n_var=n_var, n_obj=n_obj, xl=lower, xu=upper,
                         vtype=float, **kwargs)

    def _calc_pareto_front(self,
                           n_pareto_points: int = 100) -> torch.Tensor:
        # The generic base class has no analytic Pareto front.
        pass

    def _evaluate(self, x: torch.Tensor,
                  out: dict,
                  *args,
                  **kwargs) -> None:
        # Subclasses supply the actual objective computation.
        pass
@@ -0,0 +1,4 @@
1
+ from moospread.tasks.dtlz_torch import DTLZ, DTLZ2, DTLZ4, DTLZ7
2
+ from moospread.tasks.zdt_torch import ZDT, ZDT1, ZDT2, ZDT3
3
+ from moospread.tasks.re_torch import RE21, RE33, RE34, RE37, RE41
4
+ from moospread.tasks.mw_torch import MW7
@@ -0,0 +1,139 @@
1
+ import torch
2
+ from pymoo.problems import get_problem
3
+ from moospread.problem import PymooProblemTorch
4
+ import math
5
+ from torch import Tensor
6
+
7
+ ######## DTLZ Problems ########
8
+ # Note: For the sake of differentiability, we will use the strict bounds:
9
+ # - Lower bound: 0 + 1e-6 (instead of 0)
10
+ # - Upper bound: 1 - 1e-6 (instead of 1)
11
+ # ref_point: The default reference points are suitable for the "online" mode.
12
+
13
class DTLZ(PymooProblemTorch):
    r"""Common base for the DTLZ test suite.

    Bounds are tightened to [1e-6, 1 - 1e-6] so gradients remain finite at
    the box edges.  See [Deb2005dtlz]_ for the problem family.
    """

    def __init__(self, n_var=30, n_obj=3, **kwargs):
        eps = 1e-6
        super().__init__(
            n_var=n_var,
            n_obj=n_obj,
            xl=torch.zeros(n_var, dtype=torch.float) + eps,
            xu=torch.ones(n_var, dtype=torch.float) - eps,
            vtype=float,
            **kwargs,
        )

        if n_var <= n_obj:
            raise ValueError(
                f"n_var must be > n_obj, but got {n_var} and {n_obj}."
            )
        # All decision variables are continuous.
        self.continuous_inds = list(range(n_var))
        # Size of the distance-related tail block X_m.
        self.k = n_var - n_obj + 1
31
+
32
+
33
class DTLZ2(DTLZ):
    r"""DTLZ2 test problem (minimization).

    g(x) sums (x_i - 0.5)^2 over the tail block, and each objective is
    (1 + g) times a product of cos/sin terms of x_j * pi / 2.  The Pareto
    front is the positive orthant of the unit hypersphere
    \sum_i f_i^2 = 1 and is completely concave.
    """

    def __init__(self, n_var=30, n_obj=3, ref_point=None, **kwargs):
        super().__init__(n_var, n_obj, **kwargs)
        if ref_point is None and n_obj == 3:
            # Default reference point for the 3-objective "online" setup.
            self.ref_point = [2.8390, 2.9011, 2.8575]
        else:
            assert ref_point is not None, "Please provide a reference point for n_obj != 3"
            self.ref_point = ref_point

    def pareto_front(self):
        # Delegate to pymoo's reference implementation of the front.
        return get_problem("dtlz2", n_var=self.n_var, n_obj=self.n_obj).pareto_front()

    def _evaluate(self, X: torch.Tensor, out: dict, *args, **kwargs) -> None:
        # Distance term from the tail block.
        tail = X[..., -self.k:]
        g = (tail - 0.5).pow(2).sum(dim=-1)
        base = 1 + g
        half_pi = math.pi / 2
        objectives = []
        for i in range(self.n_obj):
            cut = self.n_obj - 1 - i
            obj = base * torch.cos(X[..., :cut] * half_pi).prod(dim=-1)
            if i > 0:
                obj = obj * torch.sin(X[..., cut] * half_pi)
            objectives.append(obj)
        out["F"] = torch.stack(objectives, dim=-1)
71
+
72
+
73
class DTLZ4(DTLZ):
    r"""DTLZ4 test problem.

    Same construction as DTLZ2, but the position variables are raised to
    alpha = 100 before entering the cos/sin terms.
    """

    def __init__(self, n_var=30, n_obj=3, ref_point=None, **kwargs):
        super().__init__(n_var, n_obj, **kwargs)
        if ref_point is None and n_obj == 3:
            # Default reference point for the 3-objective "online" setup.
            self.ref_point = [3.2675, 2.6443, 2.4263]
        else:
            assert ref_point is not None, "Please provide a reference point for n_obj != 3"
            self.ref_point = ref_point

    def pareto_front(self):
        # Delegate to pymoo's reference implementation of the front.
        return get_problem("dtlz4", n_var=self.n_var, n_obj=self.n_obj).pareto_front()

    def _evaluate(self, x: torch.Tensor, out: dict, *args, **kwargs) -> None:
        # Split position variables (head) from the distance block (tail).
        head, tail = x[:, :self.n_obj - 1], x[:, self.n_obj - 1:]
        alpha = 100
        g = (tail - 0.5).pow(2).sum(dim=-1)
        base = 1 + g
        half_pi = math.pi / 2
        objectives = []
        for i in range(self.n_obj):
            cut = self.n_obj - 1 - i
            obj = base * torch.cos((head[..., :cut] ** alpha) * half_pi).prod(dim=-1)
            if i > 0:
                obj = obj * torch.sin((head[..., cut] ** alpha) * half_pi)
            objectives.append(obj)
        out["F"] = torch.stack(objectives, dim=-1)
101
+
102
+
103
class DTLZ7(DTLZ):
    r"""DTLZ7 test problem.

    The first M-1 objectives are the leading decision variables themselves;
    the last objective is (1 + g) * h with

        g = 1 + 9/k * sum(tail block)
        h = M - sum_i f_i / (1 + g) * (1 + sin(3 * pi * f_i))

    which produces several disconnected Pareto-optimal regions.  The front
    corresponds to the tail block being zero.
    """

    def __init__(self, n_var=30, n_obj=3, ref_point=None, **kwargs):
        super().__init__(n_var, n_obj, **kwargs)
        if ref_point is None and n_obj == 3:
            # Default reference point for the 3-objective "online" setup.
            self.ref_point = [0.9984, 0.9961, 22.8114]
        else:
            assert ref_point is not None, "Please provide a reference point for n_obj != 3"
            self.ref_point = ref_point

    def pareto_front(self):
        # Delegate to pymoo's reference implementation of the front.
        return get_problem("dtlz7", n_var=self.n_var, n_obj=self.n_obj).pareto_front()

    def _evaluate(self, X: torch.Tensor, out: dict, *args, **kwargs) -> None:
        # First M-1 objectives copy the leading decision variables.
        leading = torch.stack([X[..., j] for j in range(self.n_obj - 1)], dim=-1)

        g = 1 + 9 / self.k * torch.sum(X[..., -self.k:], dim=-1)
        h = self.n_obj - torch.sum(
            leading / (1 + g.unsqueeze(-1)) * (1 + torch.sin(3 * math.pi * leading)),
            dim=-1,
        )
        out["F"] = torch.cat([leading, ((1 + g) * h).unsqueeze(-1)], dim=-1)
@@ -0,0 +1,274 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from pymoo.problems import get_problem
4
+ from pymoo.util.remote import Remote
5
+ from moospread.problem import PymooProblemTorch
6
+ import numpy as np
7
+
8
+ ######## MW Problems ########
9
+ # Note: For the sake of differentiability, we will use the strict bounds:
10
+ # - Lower bound: 0 + 1e-6 (instead of 0)
11
+ # - Upper bound: 1 - 1e-6 (instead of 1)
12
+ # ref_point: The default reference points are suitable for the "online" mode.
13
+
14
+ import torch
15
+ from torch import Tensor
16
+ from typing import Optional
17
+
18
class MW(PymooProblemTorch):
    """Common base for the constrained MW test suite.

    Bounds are tightened to [1e-6, 1 - 1e-6] so gradients remain finite on
    the box boundary.  Provides the smooth building blocks LA1/LA2/LA3 and
    the distance functions g1/g2/g3 shared by the concrete MW problems.
    """

    def __init__(self, n_var, n_obj,
                 n_ieq_constr: int = 0,
                 n_eq_constr: int = 0,
                 **kwargs):
        eps = 1e-6
        super().__init__(n_var=n_var, n_obj=n_obj,
                         xl=torch.zeros(n_var, dtype=torch.float) + eps,
                         xu=torch.ones(n_var, dtype=torch.float) - eps,
                         n_ieq_constr=n_ieq_constr, n_eq_constr=n_eq_constr,
                         vtype=float, **kwargs)

    # ---- smooth building blocks ----
    @staticmethod
    def LA1(A, B, C, D, theta: Tensor) -> Tensor:
        """A * sin(B * pi * theta**C) ** D, on theta's dtype/device."""
        amp = torch.as_tensor(A, dtype=theta.dtype, device=theta.device)
        freq = torch.as_tensor(B, dtype=theta.dtype, device=theta.device)
        return amp * torch.sin(freq * torch.pi * theta.pow(C)).pow(D)

    @staticmethod
    def LA2(A, B, C, D, theta: Tensor) -> Tensor:
        """A * sin(B * theta**C) ** D, on theta's dtype/device."""
        amp = torch.as_tensor(A, dtype=theta.dtype, device=theta.device)
        freq = torch.as_tensor(B, dtype=theta.dtype, device=theta.device)
        return amp * torch.sin(freq * theta.pow(C)).pow(D)

    @staticmethod
    def LA3(A, B, C, D, theta: Tensor) -> Tensor:
        """A * cos(B * theta**C) ** D, on theta's dtype/device."""
        amp = torch.as_tensor(A, dtype=theta.dtype, device=theta.device)
        freq = torch.as_tensor(B, dtype=theta.dtype, device=theta.device)
        return amp * torch.cos(freq * theta.pow(C)).pow(D)

    # ---- distance/landscape functions ----
    def g1(self, X: Tensor) -> Tensor:
        """Distance function g1; X is (batch, n_var), result is (batch,)."""
        if not isinstance(X, Tensor):
            X = torch.as_tensor(X, dtype=torch.get_default_dtype())
        dim = self.n_var
        power = dim - self.n_obj

        # Tail variables raised to (n_var - n_obj); shape (batch, dim - (n_obj-1)).
        tail = X[:, self.n_obj - 1:].pow(power)

        # Per-column offsets, broadcast over the batch dimension.
        idx = torch.arange(self.n_obj - 1, dim, device=X.device, dtype=X.dtype)
        shift = tail - (0.5 + idx / (2.0 * dim))
        return 1.0 + (1.0 - torch.exp(-10.0 * (shift * shift))).sum(dim=1)

    def g2(self, X: Tensor) -> Tensor:
        """Distance function g2; X is (batch, n_var), result is (batch,)."""
        if not isinstance(X, Tensor):
            X = torch.as_tensor(X, dtype=torch.get_default_dtype())
        dim = self.n_var
        scale = float(dim)

        idx = torch.arange(self.n_obj - 1, dim, device=X.device, dtype=X.dtype)
        z = 1.0 - torch.exp(-10.0 * (X[:, self.n_obj - 1:] - idx / scale).pow(2))
        parts = (0.1 / scale) * (z * z) + 1.5 - 1.5 * torch.cos(2.0 * torch.pi * z)
        return 1.0 + parts.sum(dim=1)

    def g3(self, X: Tensor) -> Tensor:
        """Distance function g3; X is (batch, n_var), result is (batch,)."""
        if not isinstance(X, Tensor):
            X = torch.as_tensor(X, dtype=torch.get_default_dtype())

        tail = X[:, self.n_obj - 1:]           # last block
        prev = X[:, self.n_obj - 2:-1] - 0.5   # previous block (shifted)
        residual = tail + (prev * prev) - 1.0
        return 1.0 + (2.0 * residual.pow(2)).sum(dim=1)
103
+
104
+
105
class MW7(MW):
    """MW7: two objectives with two inequality constraints, built on g3."""

    def __init__(self, n_var: int = 15, ref_point=None, **kwargs):
        super().__init__(n_var=n_var,
                         n_obj=2,
                         n_ieq_constr=2, **kwargs)
        # Default reference point for the bi-objective "online" setup.
        self.ref_point = [2.0, 2.0] if ref_point is None else ref_point

    def _calc_pareto_front(self, n_pareto_points: int = 100) -> torch.Tensor:
        # The reference front ships with pymoo's remote data store.
        return Remote.get_instance().load("pymoo", "pf", "MW", "MW7.pf")

    @torch.no_grad()
    def _check_shapes(self, X: Tensor):
        # Inputs must be a (batch, n_var) matrix.
        if X.ndim != 2 or X.shape[1] != self.n_var:
            raise ValueError(f"X must have shape (batch, {self.n_var})")

    def _evaluate(self, X: torch.Tensor, out: dict, *args, **kwargs) -> None:
        """Populate out with F of shape (batch, 2) and G of shape (batch, 2)."""
        if not isinstance(X, Tensor):
            X = torch.as_tensor(X, dtype=torch.get_default_dtype())
        self._check_shapes(X)

        # MW.g3 returns values >= 1, so dividing by it is safe.
        dist = self.g3(X)                      # (batch,)

        obj0 = dist * X[:, 0]                  # (batch,)
        # sqrt argument clamped at zero for numerical safety.
        radial = 1.0 - (obj0 / dist).pow(2)
        obj1 = dist * torch.sqrt(torch.clamp(radial, min=0.0))

        # atan2 covers every quadrant and obj0 == 0 without special cases
        # (safer than arctan(obj1 / obj0)).
        angle = torch.atan2(obj1, obj0)        # (batch,)

        # Constraint boundaries via LA2, i.e. sin(B * theta**C) ** D.
        band0 = 1.2 + torch.abs(self.LA2(0.4, 4.0, 1.0, 16.0, angle))
        norm_sq = obj0.pow(2) + obj1.pow(2)
        cons0 = norm_sq - band0.pow(2)

        band1 = 1.15 - self.LA2(0.2, 4.0, 1.0, 8.0, angle)
        cons1 = band1.pow(2) - obj0.pow(2) - obj1.pow(2)

        out["F"] = torch.stack([obj0, obj1], dim=1)   # (batch, 2)
        out["G"] = torch.stack([cons0, cons1], dim=1)  # (batch, 2)
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+
164
+
165
+
166
+
167
+
168
+
169
+
170
+
171
+
172
+
173
+
174
+
175
class ZDT(PymooProblemTorch):
    """
    Common base for the bi-objective ZDT suite.
    Bounds are nudged inside (0, 1) by 1e-6 so that gradients stay finite
    at the edges of the unit cube.
    """

    def __init__(self, n_var=30, **kwargs):
        eps = 1e-6
        lower = torch.zeros(n_var, dtype=torch.float) + eps
        upper = torch.ones(n_var, dtype=torch.float) - eps
        super().__init__(n_var=n_var, n_obj=2, xl=lower, xu=upper,
                         vtype=float, **kwargs)
186
+
187
class ZDT1(ZDT):
    """
    ZDT1 test problem in PyTorch; both objectives are differentiable.
    """

    def __init__(self, n_var=30, ref_point=None, **kwargs):
        super().__init__(n_var, **kwargs)
        # Default reference point for the "online" setting.
        self.ref_point = [0.9994, 6.0576] if ref_point is None else ref_point

    def _calc_pareto_front(self, n_pareto_points: int = 100) -> torch.Tensor:
        # Front: f2 = 1 - sqrt(f1), attained when g == 1.
        f1 = torch.linspace(0.0, 1.0, n_pareto_points, device=self.device)
        return torch.stack([f1, 1.0 - torch.sqrt(f1)], dim=1)

    def _evaluate(self, x: torch.Tensor, out: dict, *args, **kwargs) -> None:
        # x: (batch_size, n_var)
        first = x[:, 0]
        dist = 1.0 + 9.0 / (self.n_var - 1) * torch.sum(x[:, 1:], dim=1)
        # Clamp keeps sqrt safe against tiny negative ratios.
        ratio = torch.clamp(first / dist, min=0.0)
        second = dist * (1.0 - torch.sqrt(ratio))
        out["F"] = torch.stack([first, second], dim=1)
213
+
214
class ZDT2(ZDT):
    """
    ZDT2 test problem in PyTorch; both objectives are differentiable.
    """

    def __init__(self, n_var=30, ref_point=None, **kwargs):
        super().__init__(n_var, **kwargs)
        # Default reference point for the "online" setting.
        self.ref_point = [0.9994, 6.8960] if ref_point is None else ref_point

    def _calc_pareto_front(self, n_pareto_points: int = 100) -> torch.Tensor:
        # Front: f2 = 1 - f1^2, attained when g == 1.
        f1 = torch.linspace(0.0, 1.0, n_pareto_points, device=self.device)
        return torch.stack([f1, 1.0 - f1.pow(2)], dim=1)

    def _evaluate(self, x: torch.Tensor, out: dict, *args, **kwargs) -> None:
        first = x[:, 0]
        tail_sum = torch.sum(x[:, 1:], dim=1)
        dist = 1.0 + 9.0 * tail_sum / (self.n_var - 1)
        # Clamp keeps the ratio non-negative for numerical safety.
        ratio = torch.clamp(first / dist, min=0.0)
        second = dist * (1.0 - ratio.pow(2))
        out["F"] = torch.stack([first, second], dim=1)
236
+
237
class ZDT3(ZDT):
    """
    ZDT3 test problem in PyTorch; the Pareto front consists of five
    disconnected segments.
    """

    def __init__(self, n_var=30, ref_point=None, **kwargs):
        super().__init__(n_var, **kwargs)
        # Default reference point for the "online" setting.
        self.ref_point = [0.9994, 6.0571] if ref_point is None else ref_point

    def _calc_pareto_front(
        self,
        n_points: int = 100,
        flatten: bool = True
    ) -> torch.Tensor:
        # x1 intervals of the five Pareto-optimal segments.
        segments = [
            [0.0, 0.0830015349],
            [0.182228780, 0.2577623634],
            [0.4093136748, 0.4538821041],
            [0.6183967944, 0.6525117038],
            [0.8233317983, 0.8518328654]
        ]
        per_segment = int(n_points / len(segments))
        pieces = []
        for lo, hi in segments:
            x1 = torch.linspace(lo, hi, per_segment, device=self.device)
            x2 = 1.0 - torch.sqrt(x1) - x1 * torch.sin(10.0 * torch.pi * x1)
            pieces.append(torch.stack([x1, x2], dim=1))
        return torch.cat(pieces, dim=0) if flatten else torch.stack(pieces, dim=0)

    def _evaluate(self, x: torch.Tensor, out: dict, *args, **kwargs) -> None:
        first = x[:, 0]
        tail_sum = torch.sum(x[:, 1:], dim=1)
        dist = 1.0 + 9.0 * tail_sum / (self.n_var - 1)
        # Clamp keeps sqrt safe against tiny negative ratios.
        ratio = torch.clamp(first / dist, min=0.0)
        second = dist * (1.0 - torch.sqrt(ratio) - ratio * torch.sin(10.0 * torch.pi * first))
        out["F"] = torch.stack([first, second], dim=1)