moospread 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- moospread/__init__.py +3 -0
- moospread/core.py +1881 -0
- moospread/problem.py +193 -0
- moospread/tasks/__init__.py +4 -0
- moospread/tasks/dtlz_torch.py +139 -0
- moospread/tasks/mw_torch.py +274 -0
- moospread/tasks/re_torch.py +394 -0
- moospread/tasks/zdt_torch.py +112 -0
- moospread/utils/__init__.py +8 -0
- moospread/utils/constraint_utils/__init__.py +2 -0
- moospread/utils/constraint_utils/gradient.py +72 -0
- moospread/utils/constraint_utils/mgda_core.py +69 -0
- moospread/utils/constraint_utils/pmgda_solver.py +308 -0
- moospread/utils/constraint_utils/prefs.py +64 -0
- moospread/utils/ditmoo.py +127 -0
- moospread/utils/lhs.py +74 -0
- moospread/utils/misc.py +28 -0
- moospread/utils/mobo_utils/__init__.py +11 -0
- moospread/utils/mobo_utils/evolution/__init__.py +0 -0
- moospread/utils/mobo_utils/evolution/dom.py +60 -0
- moospread/utils/mobo_utils/evolution/norm.py +40 -0
- moospread/utils/mobo_utils/evolution/utils.py +97 -0
- moospread/utils/mobo_utils/learning/__init__.py +0 -0
- moospread/utils/mobo_utils/learning/model.py +40 -0
- moospread/utils/mobo_utils/learning/model_init.py +33 -0
- moospread/utils/mobo_utils/learning/model_update.py +51 -0
- moospread/utils/mobo_utils/learning/prediction.py +116 -0
- moospread/utils/mobo_utils/learning/utils.py +143 -0
- moospread/utils/mobo_utils/lhs_for_mobo.py +243 -0
- moospread/utils/mobo_utils/mobo/__init__.py +0 -0
- moospread/utils/mobo_utils/mobo/acquisition.py +209 -0
- moospread/utils/mobo_utils/mobo/algorithms.py +91 -0
- moospread/utils/mobo_utils/mobo/factory.py +86 -0
- moospread/utils/mobo_utils/mobo/mobo.py +132 -0
- moospread/utils/mobo_utils/mobo/selection.py +182 -0
- moospread/utils/mobo_utils/mobo/solver/__init__.py +5 -0
- moospread/utils/mobo_utils/mobo/solver/moead.py +17 -0
- moospread/utils/mobo_utils/mobo/solver/nsga2.py +10 -0
- moospread/utils/mobo_utils/mobo/solver/parego/__init__.py +1 -0
- moospread/utils/mobo_utils/mobo/solver/parego/parego.py +62 -0
- moospread/utils/mobo_utils/mobo/solver/parego/utils.py +34 -0
- moospread/utils/mobo_utils/mobo/solver/pareto_discovery/__init__.py +1 -0
- moospread/utils/mobo_utils/mobo/solver/pareto_discovery/buffer.py +364 -0
- moospread/utils/mobo_utils/mobo/solver/pareto_discovery/pareto_discovery.py +571 -0
- moospread/utils/mobo_utils/mobo/solver/pareto_discovery/utils.py +168 -0
- moospread/utils/mobo_utils/mobo/solver/solver.py +74 -0
- moospread/utils/mobo_utils/mobo/surrogate_model/__init__.py +2 -0
- moospread/utils/mobo_utils/mobo/surrogate_model/base.py +36 -0
- moospread/utils/mobo_utils/mobo/surrogate_model/gaussian_process.py +177 -0
- moospread/utils/mobo_utils/mobo/surrogate_model/thompson_sampling.py +79 -0
- moospread/utils/mobo_utils/mobo/surrogate_problem.py +44 -0
- moospread/utils/mobo_utils/mobo/transformation.py +106 -0
- moospread/utils/mobo_utils/mobo/utils.py +65 -0
- moospread/utils/mobo_utils/spread_mobo_utils.py +854 -0
- moospread/utils/offline_utils/__init__.py +10 -0
- moospread/utils/offline_utils/handle_task.py +203 -0
- moospread/utils/offline_utils/proxies.py +338 -0
- moospread/utils/spread_utils.py +91 -0
- moospread-0.1.0.dist-info/METADATA +75 -0
- moospread-0.1.0.dist-info/RECORD +63 -0
- moospread-0.1.0.dist-info/WHEEL +5 -0
- moospread-0.1.0.dist-info/licenses/LICENSE +10 -0
- moospread-0.1.0.dist-info/top_level.txt +1 -0
moospread/core.py
ADDED
|
@@ -0,0 +1,1881 @@
|
|
|
1
|
+
"""Main module."""
|
|
2
|
+
import numpy as np
|
|
3
|
+
import random
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
import torch.optim as optim
|
|
7
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
8
|
+
|
|
9
|
+
import copy
|
|
10
|
+
import math
|
|
11
|
+
import json
|
|
12
|
+
import os
|
|
13
|
+
import pickle
|
|
14
|
+
from tqdm import tqdm
|
|
15
|
+
import datetime
|
|
16
|
+
from time import time
|
|
17
|
+
import dis
|
|
18
|
+
|
|
19
|
+
from pymoo.config import Config
|
|
20
|
+
Config.warnings['not_compiled'] = False
|
|
21
|
+
from pymoo.indicators.hv import HV
|
|
22
|
+
from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
|
|
23
|
+
|
|
24
|
+
import matplotlib.pyplot as plt
|
|
25
|
+
from mpl_toolkits.mplot3d import Axes3D
|
|
26
|
+
|
|
27
|
+
from moospread.utils import *
|
|
28
|
+
|
|
29
|
+
class SPREAD:
|
|
30
|
+
def __init__(self,
             problem,
             mode: str = "online",
             model=None,
             surrogate_model=None,
             dataset=None,
             xi_shift=None,
             data_size: int = 10000,
             validation_split=0.1,
             hidden_dim: int = 128,
             num_heads: int = 4,
             num_blocks: int = 2,
             timesteps: int = 1000,
             batch_size: int = 256,
             train_lr: float = 1e-4,
             train_lr_surrogate: float = 1e-4,
             num_epochs: int = 1000,
             num_epochs_surrogate: int = 1000,
             train_tol: int = 100,
             train_tol_surrogate: int = 100,
             mobo_coef_lcb=0.1,
             model_dir: str = "./model_dir",
             proxies_store_path: str = "./proxies_dir",
             device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
             seed: int = 0,
             offline_global_clamping: bool = False,
             offline_normalization_method: str = "z_score",
             dominance_classifier=None,
             train_func_surrogate=None,
             plot_func=None,
             verbose: bool = True):
    """Configure a SPREAD solver for a multi-objective problem.

    Args:
        problem: Problem to solve. ``problem.n_var`` and ``problem.n_obj``
            size the default denoising model. In ``"online"``/``"bayesian"``
            mode (and whenever a training dataset has to be generated) it must
            implement ``_evaluate``. In ``"offline"`` mode ``has_bounds()``,
            ``xl``/``xu``, ``is_discrete`` and ``is_sequence`` are read, and
            ``xl``/``xu`` are overwritten with encoded, normalized versions.
        mode: ``"offline"``, ``"online"`` or ``"bayesian"`` (case-insensitive).
        model: Denoising network. Defaults to a ``DiTMOO`` built from
            ``hidden_dim``, ``num_heads`` and ``num_blocks``.
        surrogate_model: Optional pre-built surrogate. When given,
            ``surrogate_given`` is ``True``.
        dataset: Optional ``(X, y)`` training pair. In ``"offline"`` and
            ``"online"`` mode it is generated with ``get_training_data``
            (``data_size`` samples) when omitted.
        xi_shift: Constant added to the objective-value conditioning during
            diffusion training. ``None`` lets training derive a shift per batch.
        validation_split: Fraction of data held out for validation, in
            ``[0.0, 1.0)``.
        timesteps, batch_size, train_lr, num_epochs, train_tol: Diffusion
            training settings. ``train_tol`` is the early-stopping patience.
        train_lr_surrogate, num_epochs_surrogate, train_tol_surrogate:
            Surrogate training settings. Stored only in ``"offline"`` and
            ``"bayesian"`` mode.
        mobo_coef_lcb: Weight of the predictive std in
            ``mean - mobo_coef_lcb * std``. Stored only in ``"bayesian"`` mode.
        model_dir: Directory for diffusion checkpoints (created if missing).
        proxies_store_path: Directory holding the offline proxy models.
        device: Torch device used for training and sampling.
        seed: Passed to ``set_seed`` here and again before training/sampling.
        offline_global_clamping: In ``"offline"`` mode with bounds, set
            ``problem.global_clamping = True``.
        offline_normalization_method: ``"z_score"``, ``"min_max"`` or
            ``"none"``. Selects the offline (de)normalization functions.
        dominance_classifier, train_func_surrogate, plot_func: Optional
            hooks stored as attributes.
        verbose: Print progress information.

    Raises:
        ValueError: If ``mode`` is not a supported mode.
        AssertionError: If one of the ``assert`` preconditions fails.
    """
    self.mode = mode.lower()
    if self.mode not in ["offline", "online", "bayesian"]:
        raise ValueError(f"Invalid mode: {mode}. Must be one of ['offline', 'online', 'bayesian']")

    assert problem is not None, "Problem must be provided"
    self.problem = problem
    if self.mode in ["online", "bayesian"]:
        assert not is_pass_function(self.problem._evaluate), "Problem must have the '_evaluate' method implemented."
    self.device = device

    assert 0.0 <= validation_split < 1.0, "validation_split must be in [0.0, 1.0)"
    self.validation_split = validation_split
    self.train_lr = train_lr
    self.batch_size = batch_size
    self.num_epochs = num_epochs
    self.timesteps = timesteps
    self.train_tol = train_tol
    if self.mode in ["offline", "bayesian"]:
        self.train_lr_surrogate = train_lr_surrogate
        self.num_epochs_surrogate = num_epochs_surrogate
        self.train_tol_surrogate = train_tol_surrogate
    if self.mode == "bayesian":
        self.mobo_coef_lcb = mobo_coef_lcb

    self.xi_shift = xi_shift
    self.model_dir = model_dir
    os.makedirs(self.model_dir, exist_ok=True)

    self.train_func_surrogate = train_func_surrogate
    self.plot_func = plot_func
    self.dominance_classifier = dominance_classifier

    self.seed = seed
    # Set the seed for reproducibility (before the default model is built).
    set_seed(self.seed)

    self.model = model
    if self.model is None:
        self.model = DiTMOO(
            input_dim=problem.n_var,
            num_obj=problem.n_obj,
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            num_blocks=num_blocks
        )
    self.surrogate_model = surrogate_model
    self.surrogate_given = self.surrogate_model is not None
    self.proxies_store_path = proxies_store_path

    self.verbose = verbose
    if self.verbose:
        total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"Total number of learnable parameters: {total_params}")

    self.dataset = dataset
    if self.mode in ["offline", "online"] and self.dataset is None:
        if self.mode == "offline":
            print("Training dataset not provided for offline mode.")
        assert data_size > 0, "Training data size must be positive."
        # The training data is produced by evaluating the problem, so
        # ``_evaluate`` has to be implemented.
        assert (not is_pass_function(self.problem._evaluate)), "Problem must have the '_evaluate' method implemented when dataset is not provided in offline mode."
        if self.verbose:
            print("Generating training dataset ...")
        self.dataset = self.get_training_data(self.problem,
                                              num_samples=data_size)
        if self.verbose:
            print("Training dataset generated.")

    if self.mode == "offline":
        assert offline_normalization_method in ["z_score", "min_max", "none"], "Invalid normalization method"
        if offline_normalization_method == "z_score":
            self.offline_normalization = offdata_z_score_normalize
            self.offline_denormalization = offdata_z_score_denormalize
        elif offline_normalization_method == "min_max":
            self.offline_normalization = offdata_min_max_normalize
            self.offline_denormalization = offdata_min_max_denormalize
        else:
            # Identity transform. It must follow the same convention as the
            # real normalizers: every caller unpacks three values
            # (data, mean-or-min, std-or-max). Returning the bare tensor
            # broke the "none" option. The denormalizer accepts and ignores
            # whatever statistics it is given.
            self.offline_normalization = lambda x, *args, **kwargs: (x, 0, 1)
            self.offline_denormalization = lambda x, *args, **kwargs: x

        # Identity statistics until solve() fits them on the dataset.
        self.X_meanormin, self.y_meanormin = 0, 0
        self.X_stdormax, self.y_stdormax = 1, 1
        if self.problem.has_bounds():
            def _encode_bound(bound):
                # Encode a bound the same way solve() encodes X: discrete
                # inputs become flattened logits, sequence inputs become logits.
                if self.problem.is_discrete:
                    bound = offdata_to_logits(bound)
                    _, n_dim, n_classes = tuple(bound.shape)
                    bound = bound.reshape(-1, n_dim * n_classes)
                if self.problem.is_sequence:
                    bound = offdata_to_logits(bound)
                # NOTE(review): each bound is normalized with its own
                # statistics, not with the dataset statistics solve() later
                # stores in X_meanormin / X_stdormax. Confirm this is intended.
                bound, _, _ = self.offline_normalization(bound)
                return bound

            xl = _encode_bound(self.problem.xl)
            xu = _encode_bound(self.problem.xu)
            ## Set the normalized bounds
            self.problem.xl = xl
            self.problem.xu = xu
            if offline_global_clamping:
                self.problem.global_clamping = True
|
|
173
|
+
|
|
174
|
+
def objective_functions(self,
                        points,
                        return_as_dict: bool = False,
                        return_values_of=None,
                        get_constraint=False,
                        get_grad_mobo=False,
                        evaluate_true=False):
    """Evaluate objective values at ``points`` according to the solver mode.

    Args:
        points: Batch of candidate solutions, one per row. In
            ``"offline"``/``"bayesian"`` mode a torch tensor is expected
            (``.shape`` / ``.detach()`` are used).
        return_as_dict, return_values_of: Forwarded to ``problem.evaluate``
            when the true problem is evaluated.
        get_constraint: When the true problem is evaluated, request
            ``["F", "G", "H"]`` as a dict instead.
        get_grad_mobo: In ``"bayesian"`` mode, also return gradients.
        evaluate_true: Evaluate the true problem regardless of the mode,
            repairing bounds first if ``problem.need_repair`` is set.

    Returns:
        * True problem (``evaluate_true`` or ``"online"`` mode): whatever
          ``problem.evaluate`` returns.
        * ``"offline"``: the proxies' predictions stacked along dim 1, i.e.
          ``(n_points, n_proxies)`` when each proxy yields one value per point.
        * ``"bayesian"``: ``mean - mobo_coef_lcb * std`` as a float tensor on
          ``self.device``. With ``get_grad_mobo`` it is a dict
          ``{"dF": [...], "F": lcb}`` instead.

    Raises:
        ValueError: If ``self.mode`` is not a supported mode.
    """
    # True-problem evaluation: explicitly requested, or implied by online mode.
    if evaluate_true or self.mode == "online":
        # Bound repair applies only to explicit true evaluations.
        if evaluate_true and self.problem.need_repair:
            points = self.repair_bounds(points)
        if get_constraint:
            return self.problem.evaluate(points, return_as_dict=True,
                                         return_values_of=["F", "G", "H"])
        return self.problem.evaluate(points, return_as_dict=return_as_dict,
                                     return_values_of=return_values_of)

    if self.mode == "offline":
        n_points = points.shape[0]
        # Flatten each prediction to (n_points, -1), then drop only a
        # trailing singleton dimension. A bare ``squeeze()`` also removed the
        # batch dimension when n_points == 1, yielding 0-d tensors that
        # ``torch.stack(..., dim=1)`` rejects.
        scores = [proxy(points).reshape(n_points, -1).squeeze(-1)
                  for proxy in self.surrogate_model]
        return torch.stack(scores, dim=1)

    if self.mode == "bayesian":
        x = points.detach().cpu().numpy()
        eval_result = self.surrogate_model.evaluate(x, std=True,
                                                    calc_gradient=get_grad_mobo)
        mean = torch.from_numpy(eval_result["F"]).float().to(self.device)
        std = torch.from_numpy(eval_result["S"]).float().to(self.device)
        # Lower confidence bound of the surrogate prediction.
        Y_val = mean - self.mobo_coef_lcb * std
        if not get_grad_mobo:
            return Y_val
        mean_grad = torch.from_numpy(eval_result["dF"]).float().to(self.device)
        std_grad = torch.from_numpy(eval_result["dS"]).float().to(self.device)
        Grad_val = mean_grad - self.mobo_coef_lcb * std_grad
        # One slice per index of dim 1 (presumably one per objective).
        return {"dF": [Grad_val[:, i, :] for i in range(Grad_val.shape[1])],
                "F": Y_val}

    raise ValueError(f"Invalid mode: {self.mode}")
|
|
220
|
+
|
|
221
|
+
def solve(self, num_points_sample=500,
          strict_guidance=False,
          rho_scale_gamma=0.9,
          nu_t=10.0, eta_init=0.9,
          num_inner_steps=10, lr_inner=0.9,
          free_initial_h=True,
          use_sigma_rep=False, kernel_sigma_rep=0.01,
          iterative_plot=True, plot_period=100,
          max_backtracks=100, label=None, save_results=True,
          load_models=False,
          samples_store_path="./samples_dir/",
          images_store_path="./images_dir/",
          n_init_mobo=100, use_escape_local_mobo=True,
          n_steps_mobo=20, spread_num_samp_mobo=25,
          batch_select_mobo=5):
    """Run SPREAD end to end for the configured mode.

    ``"offline"``/``"online"``:
        Prepare the stored dataset. Offline mode also encodes
        discrete/sequence inputs, normalizes X and y, and trains or loads
        the per-objective proxies. Then train the diffusion model (unless
        ``load_models``) and draw ``num_points_sample`` solutions with
        ``sampling``.

    ``"bayesian"``:
        Evaluate an initial design of ``n_init_mobo`` points from
        ``lhs_no_evaluation``. Then, for ``n_steps_mobo`` iterations:

        * refit the surrogate;
        * generate candidates by diffusion sampling, or by SBX while the
          escape operator is active;
        * mutate the candidates;
        * greedily pick up to ``batch_select_mobo`` distinct candidates by
          hypervolume gain of their predicted ``mean - mobo_coef_lcb * std``;
        * evaluate the picked candidates on the true problem.

    Args:
        num_points_sample: Number of solutions sampled (offline/online only).
        strict_guidance ... images_store_path: Forwarded to ``sampling``.
        load_models: Offline/online only. Skip diffusion training and reuse
            the checkpoints in ``model_dir``. In offline mode, unless a
            surrogate was supplied, the proxies are loaded from
            ``proxies_store_path``.
        save_results: Persist results. In Bayesian mode this writes
            ``<name>_x.npy``, ``<name>_y.npy`` and ``<name>_hv_results.pkl``
            into ``samples_store_path``.
        n_init_mobo: Size of the initial design (Bayesian mode).
        use_escape_local_mobo: Bayesian mode. Restrict the generating
            population to ``environment_selection`` of a third of the archive,
            and toggle between diffusion and SBX when the recent hypervolume
            stagnates.
        n_steps_mobo: Number of Bayesian iterations.
        spread_num_samp_mobo: ``sampling`` calls per diffusion iteration.
        batch_select_mobo: Maximum true evaluations per iteration.

    Returns:
        Offline/online: ``(res_x, res_y)`` as returned by ``sampling``.
        Bayesian: ``(X, Y, hv_all_value)``, i.e. all evaluated designs, their
        true objective values, and the hypervolume w.r.t.
        ``problem.ref_point`` after the initial design and after each
        iteration.

    Raises:
        RuntimeError: Bayesian mode, if every candidate in an iteration has a
            NaN surrogate prediction.

    Note:
        Bayesian mode sets ``self.verbose = False`` permanently.
    """
    set_seed(self.seed)
    if self.mode in ["offline", "online"]:
        X, y = self.dataset

        if self.mode == "offline":
            X = X.clone().detach()
            y = y.clone().detach()
            if self.problem.is_discrete:
                X = offdata_to_logits(X)
                _, n_dim, n_classes = tuple(X.shape)
                X = X.reshape(-1, n_dim * n_classes)
            if self.problem.is_sequence:
                X = offdata_to_logits(X)
            # Normalize inputs and outputs with the method chosen in
            # __init__. Keep the statistics for later denormalization.
            X, self.X_meanormin, self.X_stdormax = self.offline_normalization(X)
            y, self.y_meanormin, self.y_stdormax = self.offline_normalization(y)

            #### SURROGATE MODEL TRAINING ####
            if not load_models or self.surrogate_given:
                self.train_surrogate(X, y)
            else:
                # Load one stored proxy per objective.
                self.surrogate_model = []
                for i in range(self.problem.n_obj):
                    classifier = SingleModel(
                        input_size=self.problem.n_var,
                        which_obj=i,
                        device=self.device,
                        hidden_size=[2048, 2048],
                        save_dir=self.proxies_store_path,
                        save_prefix=f"MultipleModels-Vallina-{self.problem.__class__.__name__}-{self.seed}",
                    )
                    classifier.load()
                    classifier = classifier.to(self.device)
                    classifier.eval()
                    self.surrogate_model.append(classifier)

        # Create DataLoader
        train_dataloader, val_dataloader = get_ddpm_dataloader(X,
                                                               y,
                                                               validation_split=self.validation_split,
                                                               batch_size=self.batch_size)

        #### DIFFUSION MODEL TRAINING ####
        if not load_models:
            self.train(train_dataloader, val_dataloader=val_dataloader)

        #### SPREAD SAMPLING ####
        res_x, res_y = self.sampling(num_points_sample,
                                     strict_guidance=strict_guidance,
                                     rho_scale_gamma=rho_scale_gamma,
                                     nu_t=nu_t, eta_init=eta_init,
                                     num_inner_steps=num_inner_steps, lr_inner=lr_inner,
                                     free_initial_h=free_initial_h,
                                     use_sigma_rep=use_sigma_rep, kernel_sigma_rep=kernel_sigma_rep,
                                     iterative_plot=iterative_plot, plot_period=plot_period,
                                     max_backtracks=max_backtracks, label=label,
                                     save_results=save_results,
                                     samples_store_path=samples_store_path,
                                     images_store_path=images_store_path)

        return res_x, res_y

    elif self.mode == "bayesian":
        # Silence the per-iteration training/sampling output.
        self.verbose = False
        hv_all_value = []
        true_hv = HV(ref_point=np.array(self.problem.ref_point))

        # Initial design, evaluated on the true problem.
        x_init = lhs_no_evaluation(self.problem.n_var,
                                   n_init_mobo)
        x_init = torch.from_numpy(x_init).float().to(self.device)
        y_init = self.problem.evaluate(x_init).detach().cpu().numpy()

        # Dominance classifier used to rank the population.
        p_rel_map, _ = init_dom_rel_map(300)
        p_model = init_dom_nn_classifier(
            x_init, y_init, p_rel_map, pareto_dominance, self.problem.n_var,
        )
        self.dominance_classifier = p_model

        evaluated = len(y_init)
        X = x_init.detach().cpu().numpy()
        Y = y_init
        hv_value = true_hv(Y)
        hv_all_value.append(hv_value)

        # True while SBX replaces diffusion sampling as the offspring operator.
        escape_flag = False
        if use_escape_local_mobo:
            iteration_since_switch = 0  # iterations since the last operator switch
            hv_change_threshold = 0.05  # relative HV change treated as stagnation
            hv_history_length = 3  # number of recent iterations compared
            hv_history = []  # recent HV values (adaptive reference point)
            history_Y = []  # recent archive snapshots

        # Start the main loop for Bayesian-SPREAD
        with tqdm(
            total=n_steps_mobo,
            desc="SPREAD (MOBO)",
            unit="k",
        ) as pbar:
            for _ in range(n_steps_mobo):
                # Solution normalization
                transformation = StandardTransform([0, 1])
                transformation.fit(X, Y)
                X_norm, Y_norm = transformation.do(X, Y)

                #### SURROGATE MODEL TRAINING ####
                self.train_surrogate(X_norm, Y_norm)

                # Population the new candidates are generated from.
                if use_escape_local_mobo:
                    _, index = environment_selection(Y, len(X) // 3)
                    PopDec = X[index, :]
                else:
                    PopDec = X

                PopDec_dom_labels, PopDec_cfs = nn_predict_dom_intra(
                    PopDec, p_model, self.device
                )
                sorted_pop = sort_population(PopDec, PopDec_dom_labels, PopDec_cfs)

                if not escape_flag:
                    # **** Generate new offspring using SPREAD **** #
                    self.model.to(self.device)
                    #### DIFFUSION MODEL TRAINING ####
                    train_dataloader, val_dataloader, dataset_size = mobo_get_ddpm_dataloader(
                        sorted_pop,
                        self.objective_functions,
                        self.device,
                        self.batch_size,
                        self.validation_split)
                    self.train(train_dataloader,
                               val_dataloader=val_dataloader,
                               disable_progress_bar=True)
                    #### SPREAD SAMPLING ####
                    new_offsprings = []
                    for _sample_round in range(spread_num_samp_mobo):
                        pf_points, _unused = self.sampling(dataset_size,
                                                           strict_guidance=strict_guidance,
                                                           rho_scale_gamma=rho_scale_gamma,
                                                           nu_t=nu_t, eta_init=eta_init,
                                                           num_inner_steps=num_inner_steps, lr_inner=lr_inner,
                                                           free_initial_h=free_initial_h,
                                                           use_sigma_rep=use_sigma_rep, kernel_sigma_rep=kernel_sigma_rep,
                                                           iterative_plot=iterative_plot, plot_period=plot_period,
                                                           max_backtracks=max_backtracks, label=label,
                                                           samples_store_path=samples_store_path,
                                                           images_store_path=images_store_path,
                                                           disable_progress_bar=True,
                                                           save_results=False, evaluate_final=False)
                        new_offsprings.append(pf_points)
                    X_psl = np.vstack(new_offsprings)
                else:
                    #### SBX OFFSPRING GENERATION ####
                    # Recombine the first third of the sorted population
                    # (presumably the best-ranked). Keep an even parent count.
                    rows_to_take = int(1 / 3 * sorted_pop.shape[0])
                    offspringA = sorted_pop[:rows_to_take, :]
                    if len(offspringA) % 2 == 1:
                        offspringA = offspringA[:-1]
                    # Collect all rounds and stack once. The empty float64
                    # head keeps the same dtype promotion as incremental stacking.
                    X_psl = np.vstack(
                        [np.empty((0, self.problem.n_var))]
                        + [sbx(offspringA, eta=15) for _round in range(1000)]
                    )

                pop_size_used = X_psl.shape[0]

                # Mutate the new offspring
                X_psl = pm_mutation(X_psl, [self.problem.xl.detach().cpu().numpy(),
                                            self.problem.xu.detach().cpu().numpy()])

                # A single surrogate call provides both mean ("F") and std ("S").
                surrogate_out = self.surrogate_model.evaluate(X_psl, std=True)
                Y_candidate_mean = surrogate_out["F"]
                Y_candidate_std = surrogate_out["S"]

                # Discard candidates with NaN predicted means.
                valid = ~np.any(np.isnan(Y_candidate_mean), axis=1)
                X_psl = X_psl[valid]
                Y_candidate = Y_candidate_mean[valid] - self.mobo_coef_lcb * Y_candidate_std[valid]
                if len(Y_candidate) == 0:
                    raise RuntimeError(
                        "All candidate solutions have NaN surrogate predictions; "
                        "no batch can be selected."
                    )

                #### BATCH SELECTION ####
                # Greedily grow the batch. Each round adds the candidate that
                # maximizes the hypervolume of the current normalized front
                # plus the candidates already chosen.
                Y_p = Y_norm[NonDominatedSorting().do(Y_norm)[0]]
                selected = []
                for _round in range(batch_select_mobo):
                    hv_batch = HV(
                        ref_point=np.max(np.vstack([Y_p, Y_candidate]), axis=0)
                    )
                    # Starting from -inf guarantees a pick even when every
                    # hypervolume is 0. Already-chosen candidates are skipped
                    # so no point is evaluated twice.
                    best_idx, best_hv_value = None, -np.inf
                    for j in range(len(Y_candidate)):
                        if j in selected:
                            continue
                        hv_j = hv_batch(np.vstack([Y_p, Y_candidate[j]]))
                        if hv_j > best_hv_value:
                            best_idx, best_hv_value = j, hv_j
                    if best_idx is None:
                        break  # fewer distinct candidates than batch_select_mobo
                    Y_p = np.vstack([Y_p, Y_candidate[best_idx]])
                    selected.append(best_idx)

                # Evaluate the selected batch on the true problem.
                X_new = X_psl[selected]
                Y_new = self.problem.evaluate(
                    torch.from_numpy(X_new).float().to(self.device)
                ).detach().cpu().numpy()

                X = np.vstack([X, X_new])
                Y = np.vstack([Y, Y_new])
                hv_value = true_hv(Y)
                hv_all_value.append(hv_value)

                # Drop evaluations that produced NaN objectives.
                rows_with_nan = np.any(np.isnan(Y), axis=1)
                X = X[~rows_with_nan, :]
                Y = Y[~rows_with_nan, :]

                update_dom_nn_classifier(
                    p_model, X, Y, p_rel_map, pareto_dominance, self.problem
                )

                hv_text = f"{hv_value:.4e}"
                evaluated = evaluated + len(selected)

                #### DECISION TO SWITCH OPERATOR ####
                if use_escape_local_mobo:
                    operator_text = "SBX" if escape_flag else "Diffusion"
                    # Adaptive reference point: 1.1 x the 95% quantile of the
                    # non-dominated points of the recent archive snapshots.
                    # Only the last hv_history_length snapshots are kept.
                    history_Y.append(Y)
                    if len(history_Y) > hv_history_length:
                        history_Y.pop(0)
                    all_Y = np.vstack(history_Y)
                    Y_nds_hist = all_Y[NonDominatedSorting().do(all_Y)[0]]
                    ref_point_method2 = 1.1 * np.quantile(Y_nds_hist, 0.95, axis=0)
                    hv_value_method2 = HV(ref_point=ref_point_method2)(Y)

                    hv_history.append(hv_value_method2)
                    if len(hv_history) > hv_history_length:
                        hv_history.pop(0)

                    if len(hv_history) == hv_history_length:
                        # Relative change of the newest HV against the mean
                        # of the previous ones.
                        avg_hv = sum(hv_history[:-1]) / (hv_history_length - 1)
                        if avg_hv == 0:
                            hv_change = 0
                        else:
                            hv_change = abs((hv_history[-1] - avg_hv) / avg_hv)
                        # Switching is considered only once the counter has
                        # reached 2 since the previous switch.
                        if iteration_since_switch >= 2:
                            if hv_change < hv_change_threshold:
                                escape_flag = not escape_flag
                                iteration_since_switch = 0
                        else:
                            iteration_since_switch += 1

                    pbar.set_postfix({"HV": hv_text, "Operator": operator_text, "Population": pop_size_used, "Num Points": evaluated})
                else:
                    pbar.set_postfix({"HV": hv_text, "Population": pop_size_used, "Num Points": evaluated})

                pbar.update(1)

        if save_results:
            name_t = (
                f"spread_{self.problem.__class__.__name__}"
                f"_T{self.timesteps}"
                f"_K{n_steps_mobo}"
                f"_FE{n_steps_mobo * batch_select_mobo}"
                f"_seed={self.seed}"
                f"_{self.mode}"
            )
            os.makedirs(samples_store_path, exist_ok=True)
            # os.path.join also handles directories given without a trailing slash.
            np.save(os.path.join(samples_store_path, name_t + "_x.npy"), X)
            np.save(os.path.join(samples_store_path, name_t + "_y.npy"), Y)
            print("\n================ Final Results ================\n")
            print(f"Total function evaluations: {evaluated}")
            print(f"Final hypervolume: {hv_value:.4e}")
            print(f"Samples and HV values are saved to {samples_store_path}\n")

            outfile = os.path.join(samples_store_path, name_t + "_hv_results.pkl")
            with open(outfile, "wb") as f:
                pickle.dump(hv_all_value, f)

        return X, Y, hv_all_value
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def train(self,
          train_dataloader,
          val_dataloader=None,
          disable_progress_bar=False):
    """Train the conditional denoising diffusion model.

    Each batch is noised at timesteps drawn with ``torch.randint`` using the
    schedule from ``cosine_beta_schedule``. The model is conditioned on the
    batch's objective values plus a shift, and ``l_simple_loss`` between
    predicted and true noise is minimized with Adam at ``train_lr``.

    Args:
        train_dataloader: Iterable of ``(points, obj_values)`` batches.
        val_dataloader: Optional loader with the same layout. When truthy, the
            epoch with the lowest validation loss is saved to
            ``checkpoint_ddpm_best.pth``, and training stops after
            ``train_tol`` consecutive epochs without improvement.
        disable_progress_bar: Hide the tqdm progress bar.

    Side effects:
        Writes ``checkpoint_ddpm_last.pth`` (and, with validation,
        ``checkpoint_ddpm_best.pth``) in ``model_dir``. A best checkpoint left
        by a previous call is deleted first. ``sampling`` loads the best
        checkpoint in preference to the last one, and would otherwise
        silently pick up a stale model.
    """
    set_seed(self.seed)
    if self.verbose:
        print(datetime.datetime.now())

    self.model = self.model.to(self.device)
    optimizer = optim.Adam(self.model.parameters(), lr=self.train_lr)

    betas = self.cosine_beta_schedule(self.timesteps)
    alphas = 1 - betas
    alpha_cumprod = torch.cumprod(alphas, dim=0).to(self.device)

    DDPM_FILE_LAST = str("%s/checkpoint_ddpm_last.pth" % (self.model_dir))
    DDPM_FILE_BEST = str("%s/checkpoint_ddpm_best.pth" % (self.model_dir))
    if os.path.exists(DDPM_FILE_BEST):
        os.remove(DDPM_FILE_BEST)
    if val_dataloader:
        best_val_loss = np.inf
        tol_violation_epoch = self.train_tol
        cur_violation_epoch = 0

    # Shift used when xi_shift is None and a batch has no positive objective
    # value. Training and validation now share it, so both are conditioned
    # identically.
    fallback_shift = 1e-5

    def _shifted_condition(obj_values):
        # Conditioning information: objective values shifted by the fixed
        # xi_shift, or else by the batch's smallest positive value.
        if self.xi_shift is not None:
            return obj_values + self.xi_shift
        positive = obj_values[obj_values > 0]
        shift = positive.min() if positive.numel() > 0 else fallback_shift
        return obj_values + shift

    def _denoising_loss(batch, obj_values):
        # Noise the batch at random timesteps and score the model's noise
        # prediction. Inputs need no gradients: only model parameters are trained.
        points = batch.to(self.device)
        obj_values = obj_values.to(self.device)
        t = torch.randint(0, self.timesteps, (points.shape[0],)).to(self.device)
        alpha_bar_t = alpha_cumprod[t].unsqueeze(1)  # shape: [batch_size, 1]
        noise = torch.randn_like(points).to(self.device)  # shape: [batch_size, n_var]
        x_t = torch.sqrt(alpha_bar_t) * points + torch.sqrt(1 - alpha_bar_t) * noise
        predicted_noise = self.model(x_t, t.float() / self.timesteps,
                                     _shifted_condition(obj_values))
        return self.l_simple_loss(predicted_noise, noise.detach())

    time_start = time()

    with tqdm(
        total=self.num_epochs,
        desc="DDPM Training",
        unit="epoch",
        disable=disable_progress_bar,
    ) as pbar:
        for epoch in range(self.num_epochs):
            self.model.train()
            train_loss = 0.0
            for batch, obj_values in train_dataloader:
                optimizer.zero_grad()
                loss_simple = _denoising_loss(batch, obj_values)
                loss_simple.backward()
                optimizer.step()
                train_loss += loss_simple.item()
            train_loss = train_loss / len(train_dataloader)

            if val_dataloader:
                # Validation needs no autograd graph.
                self.model.eval()
                val_loss = 0.0
                with torch.no_grad():
                    for val_batch, val_obj_values in val_dataloader:
                        val_loss += _denoising_loss(val_batch, val_obj_values).item()
                val_loss = val_loss / len(val_dataloader)

                if val_loss <= best_val_loss:
                    best_val_loss = val_loss
                    cur_violation_epoch = 0
                    # Save the model
                    checkpoint = {
                        "model_state_dict": self.model.state_dict(),
                        "epoch": epoch + 1,
                    }
                    torch.save(checkpoint, DDPM_FILE_BEST)
                else:
                    cur_violation_epoch += 1
                    if cur_violation_epoch >= tol_violation_epoch:
                        if self.verbose:
                            print(f"Early Stopping at epoch {epoch + 1}.")
                        break

                pbar.set_postfix({"val_loss": val_loss})
            else:
                pbar.set_postfix({"train_loss": train_loss})
            pbar.update(1)

    comp_time = time() - time_start
    if self.verbose:
        convert_seconds(comp_time)
        print(datetime.datetime.now())

    if val_dataloader and self.verbose and os.path.exists(DDPM_FILE_BEST):
        print(f"Best model saved at: {DDPM_FILE_BEST}")
    # Save the model
    checkpoint = {
        "model_state_dict": self.model.state_dict(),
        "train_time": comp_time,
        "num_epochs": self.num_epochs,
        "train_tol": self.train_tol,
        "train_lr": self.train_lr,
        "batch_size": self.batch_size,
        "timesteps": self.timesteps,
    }
    torch.save(checkpoint, DDPM_FILE_LAST)
    if self.verbose:
        print(f"Final model saved at: {DDPM_FILE_LAST}")
|
|
711
|
+
|
|
712
|
+
def sampling(self, num_points_sample,
             strict_guidance=False,
             rho_scale_gamma=0.9,
             nu_t=10.0, eta_init=0.9,
             num_inner_steps=10, lr_inner=1e-4,
             free_initial_h=True,
             use_sigma_rep=False, kernel_sigma_rep=0.01,
             iterative_plot=True, plot_period=100,
             max_backtracks=25, label=None,
             samples_store_path="./samples_dir/",
             images_store_path="./images_dir/",
             disable_progress_bar=False,
             save_results=True, evaluate_final=True):
    """
    Generate candidate solutions by reverse diffusion with Pareto guidance.

    Loads the trained denoiser from ``self.model_dir`` (the "best" checkpoint
    if present, otherwise the "last" one). Starts from points drawn uniformly
    inside ``self.problem.bounds()`` and runs ``self.timesteps`` calls to
    ``one_spread_sampling_step``. After every step, the non-dominated points
    are merged with the set kept from the previous step. When more than
    ``num_points_sample`` points remain, the set is reduced with
    ``select_top_n_candidates``.

    Parameters
    ----------
    num_points_sample : int
        Number of points sampled in parallel, and the maximum size of the
        running set of kept points.
    strict_guidance : bool
        If True, an auxiliary point ``point_n0`` with its own Adam optimizer
        (lr=1e-2) is created and passed to every sampling step.
    rho_scale_gamma, nu_t, eta_init, num_inner_steps, lr_inner, free_initial_h, max_backtracks
        Forwarded unchanged to ``one_spread_sampling_step``.
    use_sigma_rep, kernel_sigma_rep
        Forwarded to ``one_spread_sampling_step`` as ``use_sigma`` and
        ``kernel_sigma``.
    iterative_plot : bool
        Plot progress. Only applies when ``n_obj <= 3`` and the problem's
        ``_evaluate`` is not a pass-through stub.
    plot_period : int
        Plot when ``t % plot_period == 0`` and at the first reverse step.
    label : str or None
        Appended to the results file name and forwarded to the plot functions.
    samples_store_path : str
        Directory for result files; created if missing. It is concatenated
        directly with the file name, so it must end with a path separator.
    images_store_path : str
        Forwarded to the plot functions.
    disable_progress_bar : bool
        Disable the tqdm progress bar.
    save_results : bool
        Save ``*_x.npy``. When objectives are evaluated, also save ``*_y.npy``
        and ``*_hv_results.pkl``.
    evaluate_final : bool
        Evaluate the final points with ``self.problem.evaluate`` and drop rows
        of ``res_x``/``res_y`` that contain inf or NaN.

    Returns
    -------
    res_x : numpy.ndarray
        Final decision vectors (denormalized in offline mode).
    res_y : numpy.ndarray or None
        Their objective values, or None when they were not evaluated.

    Raises
    ------
    ValueError
        If neither checkpoint file exists in ``self.model_dir``.
    """
    # Set the seed
    set_seed(self.seed)
    if save_results:
        # Create the output directory and build the file-name stem for results.
        if not (os.path.exists(samples_store_path)):
            os.makedirs(samples_store_path)

        name = (
            "spread"
            + "_"
            + self.problem.__class__.__name__
            + "_"
            + f"T={self.timesteps}"
            + "_"
            + f"N={num_points_sample}"
            + "_"
            + f"seed={self.seed}"
            + "_"
            + self.mode
        )

        if label is not None:
            name += f"_{label}"

    self.model.to(self.device)
    DDPM_FILE = str("%s/checkpoint_ddpm_best.pth" % (self.model_dir))
    # Load the best model if it exists, else load the last model.
    if not os.path.exists(DDPM_FILE):
        DDPM_FILE = str("%s/checkpoint_ddpm_last.pth" % (self.model_dir))
        if not os.path.exists(DDPM_FILE):
            raise ValueError(f"No trained model found in {self.model_dir}")
    checkpoint = torch.load(DDPM_FILE, map_location=self.device, weights_only=False)
    self.model.load_state_dict(checkpoint["model_state_dict"])

    # Diffusion schedule. cosine_beta_schedule ignores its argument and
    # always uses self.timesteps.
    betas = self.cosine_beta_schedule(self.timesteps)
    alphas = 1 - betas
    alpha_cumprod = torch.cumprod(alphas, dim=0).to(self.device)

    # Start from random points in the decision space.
    x_t = torch.rand((num_points_sample, self.problem.n_var))  # in [0, 1]
    x_t = self.problem.bounds()[0] + (self.problem.bounds()[1] - self.problem.bounds()[0]) * x_t  # scale to bounds
    if self.mode == "offline":
        x_t, _, _ = self.offline_normalization(x_t)
    x_t = x_t.to(self.device)
    x_t.requires_grad = True
    if self.problem.need_repair:
        x_t.data = self.repair_bounds(x_t.data.clone())

    # Optional plot of the initial population (online mode only).
    if self.mode == "online":
        if iterative_plot and (not is_pass_function(self.problem._evaluate)):
            if self.problem.n_obj <= 3:
                pf_population = x_t.detach()
                pf_points, _, _ = self.get_non_dominated_points(
                    pf_population,
                    keep_shape=False
                )
                list_fi = self.objective_functions(pf_points).split(1, dim=1)
                list_fi = [fi.detach().cpu().numpy() for fi in list_fi]
                pareto_front = None
                if self.problem.pareto_front() is not None:
                    pareto_front = self.problem.pareto_front()
                    pareto_front = [pareto_front[:, i] for i in range(self.problem.n_obj)]
                if self.plot_func is not None:
                    self.plot_func(list_fi, self.timesteps,
                                   num_points_sample,
                                   extra=pareto_front,
                                   label=label, images_store_path=images_store_path)
                else:
                    # NOTE(review): this branch only runs in online mode, so
                    # plot_dataset is always False here.
                    plot_dataset = True if self.mode == "offline" else False
                    list_fi_pop = self.objective_functions(pf_population.detach()).split(1, dim=1)
                    list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
                    self.plot_pareto_front(list_fi, self.timesteps,
                                           num_points_sample,
                                           extra=pareto_front,
                                           plot_dataset=plot_dataset,
                                           pop=list_fi_pop,
                                           label=label, images_store_path=images_store_path)

    # Running set of kept (non-dominated) points across reverse steps.
    prev_pf_points = None
    num_optimal_points = 0

    point_n0 = None
    optimizer_n0 = None
    if strict_guidance:
        # Initialize one auxiliary target point for direction perturbation.
        # With strict_guidance, the perturbation is meant to follow the MGD
        # direction of this single point. Otherwise a random perturbation is
        # used (as in the paper).
        point_n0 = torch.rand((1, self.problem.n_var))  # in [0, 1]
        point_n0 = self.problem.bounds()[0] + (self.problem.bounds()[1] - self.problem.bounds()[0]) * point_n0  # scale to bounds
        if self.mode == "offline":
            point_n0, _, _ = self.offline_normalization(point_n0)
        point_n0 = point_n0.to(self.device)
        point_n0.requires_grad = True
        if self.problem.need_repair:
            point_n0.data = self.repair_bounds(
                point_n0.data.clone()
            )
        optimizer_n0 = optim.Adam([point_n0], lr=1e-2)

    if self.verbose:
        print(datetime.datetime.now())
        print(f"START sampling {num_points_sample} points ...")

    time_start = time()
    self.model.eval()

    with tqdm(
        total=self.timesteps,
        desc=f"SPREAD Sampling",
        unit="t",
        disable=disable_progress_bar,
    ) as pbar:

        # Reverse diffusion: t = T-1, ..., 0.
        for t in reversed(range(self.timesteps)):
            x_t.requires_grad_(True)
            if self.problem.need_repair:
                x_t.data = self.repair_bounds(
                    x_t.data.clone()
                )
            # Compute beta_t and alpha_bar_t for this step.
            beta_t = 1 - alphas[t]
            alpha_bar_t = alpha_cumprod[t]

            x_t = self.one_spread_sampling_step(
                x_t,
                num_points_sample,
                t, beta_t, alpha_bar_t, rho_scale_gamma,
                nu_t, eta_init, num_inner_steps, lr_inner,
                free_initial_h=free_initial_h,
                use_sigma=use_sigma_rep, kernel_sigma=kernel_sigma_rep,
                strict_guidance=strict_guidance, max_backtracks=max_backtracks,
                point_n0=point_n0, optimizer_n0=optimizer_n0,
            )

            if strict_guidance and self.problem.need_repair:
                point_n0.data = self.repair_bounds(
                    point_n0.data.clone()
                )

            # Detached copy of the current population (repaired if required).
            if self.problem.need_repair:
                pf_population = self.repair_bounds(
                    copy.deepcopy(x_t.detach())
                )
            else:
                pf_population = copy.deepcopy(x_t.detach())

            pf_points, _, _ = self.get_non_dominated_points(
                pf_population,
                keep_shape=False
            )

            # Merge with the points kept so far. Outside bayesian mode,
            # re-filter for non-dominance, then cap the set at
            # num_points_sample.
            if prev_pf_points is not None:
                pf_points = torch.cat((prev_pf_points, pf_points), dim=0)
                if self.mode != "bayesian":
                    pf_points, _, _ = self.get_non_dominated_points(
                        pf_points,
                        keep_shape=False,
                    )
                if len(pf_points) > num_points_sample:
                    pf_points = self.select_top_n_candidates(
                        pf_points,
                        num_points_sample,
                    )

            # In bayesian mode, fall back to the whole population if the
            # selection came back empty.
            if self.mode == "bayesian" and (pf_points is None or len(pf_points) == 0):
                pf_points = x_t.detach()
            prev_pf_points = pf_points
            num_optimal_points = len(pf_points)

            if iterative_plot and (not is_pass_function(self.problem._evaluate)):
                if self.problem.n_obj <= 3:
                    if (t % plot_period == 0) or (t == self.timesteps - 1):
                        if self.mode == "offline":
                            # Denormalize the points before plotting.
                            res_x_t = pf_points.clone().detach()
                            res_x_t = self.offline_denormalization(res_x_t,
                                                                   self.X_meanormin,
                                                                   self.X_stdormax)
                            # Temporarily switch the problem to its original
                            # bounds. This is not protected by try/finally, so
                            # an exception during evaluation leaves them set.
                            norm_xl, norm_xu = self.problem.bounds()
                            xl, xu = self.problem.original_bounds
                            self.problem.xl = xl
                            self.problem.xu = xu
                            if self.problem.is_discrete:
                                # NOTE(review): unpacking three dims assumes
                                # res_x_t is 3-D here — confirm.
                                _, dim, n_classes = tuple(res_x_t.shape)
                                res_x_t = res_x_t.reshape(-1, dim, n_classes)
                                res_x_t = offdata_to_integers(res_x_t)
                            if self.problem.is_sequence:
                                res_x_t = offdata_to_integers(res_x_t)
                            # Evaluate the true objective functions for plotting.
                            # NOTE(review): res_x_t (denormalized) is computed
                            # above but never used; the normalized pf_points is
                            # evaluated instead. Confirm which is intended.
                            list_fi = self.objective_functions(pf_points, evaluate_true=True).split(1, dim=1)
                            # Restore the normalized bounds.
                            self.problem.xl = norm_xl
                            self.problem.xu = norm_xu
                        elif self.mode == "bayesian":
                            # Evaluate the true objective functions for plotting.
                            list_fi = self.objective_functions(pf_points, evaluate_true=True).split(1, dim=1)
                        else:
                            list_fi = self.objective_functions(pf_points).split(1, dim=1)

                        list_fi = [fi.detach().cpu().numpy() for fi in list_fi]
                        pareto_front = None
                        if self.problem.pareto_front() is not None:
                            pareto_front = self.problem.pareto_front()
                            pareto_front = [pareto_front[:, i] for i in range(self.problem.n_obj)]
                        if self.plot_func is not None:
                            self.plot_func(list_fi, t,
                                           num_points_sample,
                                           extra= pareto_front,
                                           label=label, images_store_path=images_store_path)
                        else:
                            plot_dataset = True if self.mode == "offline" else False
                            list_fi_pop = self.objective_functions(pf_population.detach()).split(1, dim=1)
                            list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
                            self.plot_pareto_front(list_fi, t,
                                                   num_points_sample,
                                                   extra= pareto_front,
                                                   pop=list_fi_pop,
                                                   plot_dataset=plot_dataset,
                                                   label=label, images_store_path=images_store_path)

            # Cut the autograd history before the next reverse step.
            x_t = x_t.detach()
            pbar.set_postfix({
                "Points": num_optimal_points,
            })
            pbar.update(1)
    if self.verbose:
        print(f"END sampling !")

    comp_time = time() - time_start

    # Offline mode: map the kept points back to the original decision space
    # and restore the problem's original bounds.
    if self.mode == "offline":
        pf_points = pf_points.detach()
        pf_points = self.offline_denormalization(pf_points,
                                                 self.X_meanormin,
                                                 self.X_stdormax)
        if self.problem.is_discrete:
            _, dim, n_classes = tuple(pf_points.shape)
            pf_points = pf_points.reshape(-1, dim, n_classes)
            pf_points = offdata_to_integers(pf_points)
        if self.problem.is_sequence:
            pf_points = offdata_to_integers(pf_points)
        if self.problem.has_bounds():
            self.problem.xl, self.problem.xu = self.problem.original_bounds

    res_x = pf_points.detach().cpu().numpy()
    res_y = None
    if evaluate_final and (not is_pass_function(self.problem._evaluate)):
        res_y = self.problem.evaluate(pf_points).detach().cpu().numpy()
        # Drop every row whose objectives or decision vector contain inf/NaN.
        visible_masks = np.ones(len(res_y))
        visible_masks[np.where(np.logical_or(np.isinf(res_y), np.isnan(res_y)))[0]] = 0
        visible_masks[np.where(np.logical_or(np.isinf(res_x), np.isnan(res_x)))[0]] = 0
        res_x = res_x[np.where(visible_masks == 1)[0]]
        res_y = res_y[np.where(visible_masks == 1)[0]]
        if save_results:
            np.save(samples_store_path + name + "_y.npy", res_y)
            # Hypervolume w.r.t. the problem's reference point.
            hv = HV(ref_point=self.problem.ref_point)
            hv_value = hv(res_y)
            hv_results = {
                "ref_point": self.problem.ref_point,
                "hypervolume": hv_value,
                "computation_time": comp_time
            }
            with open(samples_store_path + name + "_hv_results.pkl", "wb") as f:
                pickle.dump(hv_results, f)
            if self.verbose:
                print(f"Hypervolume: {hv_value} for seed {self.seed}")
                print("---------------------------------------")
                # Print computation time
                convert_seconds(comp_time)
                print(datetime.datetime.now())

    if save_results:
        np.save(samples_store_path + name + "_x.npy", res_x)

    return res_x, res_y
|
|
1003
|
+
|
|
1004
|
+
|
|
1005
|
+
def train_surrogate(self,
                    X, y,
                    val_ratio=0.1,
                    batch_size=32,
                    lr=1e-3,
                    lr_decay=0.95,
                    n_epochs=200):
    """
    Fit the surrogate model(s) that stand in for the true objectives.

    The surrogate is chosen by ``get_surrogate``:
      - a user-supplied surrogate is trained through
        ``train_surrogate_user_defined``;
      - in "bayesian" mode, the Gaussian process is fitted on ``(X, y)``;
      - in "offline" mode, one proxy network per objective is trained with
        ``SingleModelBaseTrainer``. The saved proxies are then reloaded from
        ``self.proxies_store_path`` into a list stored in
        ``self.surrogate_model``.

    Parameters
    ----------
    X, y : training inputs and objective values. In offline mode ``y`` must
        support 2-D column indexing and ``.clone()`` (e.g. a tensor).
    val_ratio : fraction of the data held out for validation (offline mode).
    batch_size : mini-batch size for the offline data loaders.
    lr, lr_decay, n_epochs : optimizer settings forwarded to the offline trainer.

    Raises
    ------
    ValueError
        If ``self.mode`` has no surrogate support.
    """
    set_seed(self.seed)
    self.surrogate_model = self.get_surrogate()

    # A user-supplied surrogate is handled entirely by the user's routine.
    if self.surrogate_given:
        return self.train_surrogate_user_defined(X, y)

    if self.mode == "bayesian":
        self.surrogate_model.fit(X, y)
        return None

    if self.mode != "offline":
        raise ValueError(f"Mode {self.mode} does not support surrogate model!")

    num_objectives = y.shape[1]
    self.surrogate_model.set_kwargs(device=self.device, dtype=torch.float32)

    trainer_args = {
        "proxies_lr": lr,
        "proxies_lr_decay": lr_decay,
        "proxies_epochs": n_epochs,
        "device": self.device,
        "verbose": self.verbose,
    }

    # Train one proxy per objective; each trainer gets its own args dict.
    for obj_idx in range(num_objectives):
        target_column = y[:, obj_idx].clone().reshape(-1, 1)
        proxy_trainer = SingleModelBaseTrainer(
            model=list(self.surrogate_model.obj2model.values())[obj_idx],
            which_obj=obj_idx,
            args=dict(trainer_args),
        )
        loaders = offdata_get_dataloader(
            X,
            target_column,
            train_ratio=(1 - val_ratio),
            batch_size=batch_size,
        )
        proxy_trainer.launch(*loaders)

    # Reload the trained proxies from disk, one per objective.
    self.surrogate_model = []
    for obj_idx in range(num_objectives):
        proxy = SingleModel(
            input_size=self.problem.n_var,
            which_obj=obj_idx,
            device=self.device,
            hidden_size=[2048, 2048],
            save_dir=self.proxies_store_path,
            save_prefix=f"MultipleModels-Vallina-{self.problem.__class__.__name__}-{self.seed}",
        )
        proxy.load()
        proxy = proxy.to(self.device)
        proxy.eval()
        self.surrogate_model.append(proxy)
|
|
1071
|
+
|
|
1072
|
+
def train_surrogate_user_defined(self, X, y):
    """
    Train the user-defined surrogate model.

    Calls ``self.train_func_surrogate(X, y)`` and stores whatever it returns
    in ``self.surrogate_model``.

    The user-supplied training function is expected to return:
      - in "offline" mode, a list of trained surrogate models, one per objective;
      - in "bayesian" mode, a single trained surrogate model for all objectives.
    These expectations are not validated here.

    Parameters
    ----------
    X : decision vectors, passed through unchanged to ``self.train_func_surrogate``.
    y : objective values, passed through unchanged to ``self.train_func_surrogate``.

    Returns
    -------
    None
        The result is stored on ``self.surrogate_model``.
    """
    self.surrogate_model = self.train_func_surrogate(X, y)
|
|
1088
|
+
|
|
1089
|
+
def get_surrogate(self):
    """
    Return the surrogate model to be trained for the current mode.

    - If a surrogate was supplied by the user, return it unchanged.
    - "bayesian": a fresh ``GaussianProcess(n_var, n_obj, nu=5)``.
    - "offline": a fresh ``MultipleModels`` wrapper. Its save directory
      ``self.proxies_store_path`` is created first if missing.

    Raises
    ------
    ValueError
        For any other mode.
    """
    if self.surrogate_given:
        return self.surrogate_model

    if self.mode == "bayesian":
        return GaussianProcess(self.problem.n_var, self.problem.n_obj, nu=5)

    if self.mode == "offline":
        os.makedirs(self.proxies_store_path, exist_ok=True)
        prefix = f"MultipleModels-Vallina-{self.problem.__class__.__name__}-{self.seed}"
        return MultipleModels(
            n_dim=self.problem.n_var,
            n_obj=self.problem.n_obj,
            train_mode="Vallina",
            device=self.device,
            hidden_size=[2048, 2048],
            save_dir=self.proxies_store_path,
            save_prefix=prefix,
        )

    raise ValueError(f"Mode {self.mode} does not support surrogate model!")
|
|
1110
|
+
|
|
1111
|
+
def one_spread_sampling_step(
    self,
    x_t,
    num_points_sample,
    t, beta_t, alpha_bar_t, rho_scale_gamma,
    nu_t, eta_init, num_inner_steps, lr_inner, free_initial_h,
    use_sigma=False, kernel_sigma=1.0, strict_guidance=False,
    max_backtracks=100, point_n0=None, optimizer_n0=None,
):
    """
    One reverse-diffusion step followed by a Pareto-guidance correction.

    1. Predict the noise with ``self.model`` conditioned on the objective
       values of ``x_t``, and take the DDPM mean/noise update.
    2. Compute per-objective gradients at the denoised points, combine them
       with ``get_target_dir`` (MGDA, or PMGDA when constraints exist), pick
       step sizes with ``mgd_armijo_step``, solve for the guided direction
       with ``solve_for_h``, and move the points by ``-eta * h_tilde``.

    When ``strict_guidance`` is True, ``point_n0`` is advanced one optimizer
    step along its own MGD direction ``g_w``. ``g_w`` is then used by
    ``solve_for_h`` as the perturbation target instead of a random vector.

    Parameters
    ----------
    x_t : current points, one row per sample.
    num_points_sample : number of rows in ``x_t``.
    t : integer timestep.
    beta_t, alpha_bar_t : schedule values for ``t`` (tensors).
    rho_scale_gamma, nu_t, num_inner_steps, lr_inner, free_initial_h :
        forwarded to ``solve_for_h``.
    eta_init, max_backtracks : forwarded to ``mgd_armijo_step``.
    use_sigma, kernel_sigma : forwarded to ``solve_for_h`` as ``use_sigma`` and
        ``sigma``. They select the kernel-bandwidth form of the repulsion loss.
    strict_guidance : use ``point_n0``'s MGD direction as the perturbation target.
    point_n0, optimizer_n0 : auxiliary point and its optimizer. Required when
        ``strict_guidance`` is True.

    Returns
    -------
    The updated points.
    """
    # Timestep tensor of shape (num_points_sample,).
    t_tensor = torch.full(
        (num_points_sample,),
        t,
        device=self.device,
        dtype=torch.float32,
    )
    # Objective values at the incoming points: used as the model's
    # conditioning and as the Armijo reference values below.
    obj_values = self.objective_functions(x_t)

    # Conditioning information
    c = obj_values
    with torch.no_grad():
        predicted_noise = self.model(x_t,
                                     t_tensor / self.timesteps,
                                     c)

    g_w = None
    if strict_guidance:
        # Move the auxiliary point along its own MGD direction. The direction
        # g_w is later used as the perturbation target in solve_for_h.
        if self.mode in ["online", "offline"]:
            list_fi_n0 = self.objective_functions(point_n0).split(1, dim=1)
            list_grad_i_n0 = []
            for fi_n0 in list_fi_n0:
                fi_n0.sum().backward(retain_graph=True)
                grad_i_n0 = point_n0.grad.detach().clone()
                point_n0.grad.zero_()
                grad_i_n0 = torch.nn.functional.normalize(grad_i_n0, dim=0)
                list_grad_i_n0.append(grad_i_n0)
        else:
            list_grad_i_n0 = self.objective_functions(point_n0, get_grad_mobo=True)["dF"]
            for i in range(len(list_grad_i_n0)):
                # Normalize gradients
                list_grad_i_n0[i] = torch.nn.functional.normalize(list_grad_i_n0[i], dim=0)

        optimizer_n0.zero_grad()
        mth = "mgda"
        if self.problem.n_ieq_constr + self.problem.n_eq_constr > 0:
            mth = "pmgda"
        g_w = self.get_target_dir(list_grad_i_n0, mth=mth, x=point_n0)
        point_n0.grad = g_w
        optimizer_n0.step()

    #### Reverse diffusion step
    sqrt_1_minus_alpha_t = torch.sqrt(torch.clamp(1 - alpha_bar_t, min=1e-6))
    sqrt_1_minus_beta_t = torch.sqrt(torch.clamp(1 - beta_t, min=1e-6))
    mean = (1 / sqrt_1_minus_beta_t) * (
        x_t - (beta_t / sqrt_1_minus_alpha_t) * (predicted_noise)
    )
    std_dev = torch.sqrt(beta_t)
    z = torch.randn_like(x_t) if t > 0 else 0.0  # No noise for the final step
    x_t = mean + std_dev * z

    #### Pareto Guidance step
    if self.problem.need_repair:
        x_t.data = self.repair_bounds(x_t.data.clone())
    X = x_t.clone().detach().requires_grad_()
    if self.mode in ["online", "offline"]:
        list_fi = self.objective_functions(X).split(1, dim=1)
        list_grad_i = []
        for fi in list_fi:
            fi.sum().backward(retain_graph=True)
            grad_i = X.grad.detach().clone()
            grad_i = torch.nn.functional.normalize(grad_i, dim=0)
            list_grad_i.append(grad_i)
            X.grad.zero_()
    elif self.mode == "bayesian":
        list_grad_i = self.objective_functions(X, get_grad_mobo=True)["dF"]
        for i in range(len(list_grad_i)):
            # Normalize gradients
            list_grad_i[i] = torch.nn.functional.normalize(list_grad_i[i], dim=0)
    else:
        raise ValueError(f"Mode {self.mode} not recognized!")

    grads = torch.stack(list_grad_i, dim=0)  # (m, N, d)
    grads_copy = torch.stack(list_grad_i, dim=1).detach()  # (N, m, d)
    mth = "mgda"
    if self.problem.n_ieq_constr + self.problem.n_eq_constr > 0:
        mth = "pmgda"
    g_x_t_prime = self.get_target_dir(list_grad_i, mth=mth, x=X)

    # NOTE(review): the line search probes x + eta * g, while the actual update
    # below is x - eta * h_tilde. Confirm the intended sign convention.
    eta = self.mgd_armijo_step(
        x_t, g_x_t_prime,
        obj_values, grads_copy,  # (N, m, d)
        eta_init=eta_init,
        max_backtracks=max_backtracks
    )

    # Forward the caller's repulsion/guidance settings. Previously use_sigma
    # was hard-coded to False and strict_guidance was dropped, so g_w was
    # silently ignored.
    h_tilde = self.solve_for_h(
        x_t,
        g_x_t_prime,
        grads,
        g_w,
        eta=eta,
        nu_t=nu_t,
        sigma=kernel_sigma,
        use_sigma=use_sigma,
        num_inner_steps=num_inner_steps,
        lr_inner=lr_inner,
        strict_guidance=strict_guidance,
        free_initial_h=free_initial_h,
        rho_scale_gamma=rho_scale_gamma
    )

    # Replace NaNs with the mean of the finite entries; zero out infinities.
    h_tilde = torch.nan_to_num(h_tilde,
                               nan=torch.nanmean(h_tilde),
                               posinf=0.0,
                               neginf=0.0)

    x_t = x_t - eta * h_tilde

    return x_t
|
|
1239
|
+
|
|
1240
|
+
def solve_for_h(
    self,
    x_t_prime,
    g_x_t_prime,
    grads,
    g_w,
    eta,
    nu_t,
    sigma=1.0,
    use_sigma=False,
    num_inner_steps=10,
    lr_inner=1e-2,
    strict_guidance=False,
    free_initial_h=False,
    rho_scale_gamma=0.9
):
    """
    Optimize the guidance direction h with a few Adam steps.

    Each inner step:
      * scales the perturbation target with ``adaptive_scale_delta_vect``;
      * moves a copy of the points by ``-eta * (h + scaled_target)``;
      * minimizes ``-mean(<g, h>) + nu_t * repulsion_loss(F(x))``.

    The perturbation target is ``g_w`` when ``strict_guidance`` is True,
    otherwise a fresh random vector of shape (1, n_var).

    Returns:
        h_tilde: Optimized h plus the scaled perturbation (detached), with the
        same shape as ``g_x_t_prime``. With ``num_inner_steps <= 0`` this is the
        initial h plus its scaled perturbation.
    """
    x_t_h = x_t_prime.clone().detach()
    g = g_x_t_prime.clone().detach()

    if strict_guidance:
        g_targ = g_w.clone().detach()
    else:
        g_targ = torch.randn((1, g.shape[1]), device=g.device)

    # Initialize h either at g or as a near-zero free parameter.
    if not free_initial_h:
        h = g_x_t_prime.clone().detach().requires_grad_()
    else:
        h = torch.zeros_like(g, requires_grad=False) + 1e-6
        h = h.requires_grad_()

    optimizer_inner = optim.Adam([h], lr=lr_inner)

    # Bound up front so h_tilde is defined even when the loop does not run.
    gtarg_scaled = None
    for _ in range(num_inner_steps):

        gtarg_scaled = self.adaptive_scale_delta_vect(
            h, g_targ, grads, gamma=rho_scale_gamma
        )

        # Alignment term: maximize <g, h>, i.e. minimize its negation.
        alignment = -torch.mean(torch.sum(g * h, dim=-1))

        # Update points (cumulatively across inner steps).
        x_t_h = x_t_h - eta * (h + gtarg_scaled)
        if self.problem.need_repair:
            x_t_h.data = self.repair_bounds(
                x_t_h.data
            )
        # Map the updated points to the objective space.
        F_ = self.objective_functions(x_t_h)
        # Repulsion loss to encourage diversity.
        if use_sigma:
            rep_loss = self.repulsion_loss(F_, sigma)
        else:
            rep_loss = self.repulsion_loss(F_, use_sigma=False)

        # Composite objective.
        loss = alignment + nu_t * rep_loss

        optimizer_inner.zero_grad()
        loss.backward(retain_graph=True)
        optimizer_inner.step()

    if gtarg_scaled is None:
        # No inner steps ran: scale the perturbation for the initial h.
        gtarg_scaled = self.adaptive_scale_delta_vect(
            h, g_targ, grads, gamma=rho_scale_gamma
        )

    h_tilde = h + gtarg_scaled  # This is h_tilde in the paper

    return h_tilde.detach()
|
|
1312
|
+
|
|
1313
|
+
def get_training_data(self, problem, num_samples=10000):
    """
    Draw ``num_samples`` Latin-hypercube candidates and evaluate them.

    The LHS samples are mapped affinely onto ``[problem.xl, problem.xu]`` and
    then evaluated with ``problem.evaluate``. No constraint-based filtering
    is applied.
    NOTE(review): the mapping assumes ``LHS().do(...).get("X")`` returns values
    in [0, 1]^n_var — confirm against the LHS implementation.

    Returns
    -------
    (Xcand, F) : the scaled candidates and their evaluations.
    """
    lhs_sampler = LHS()
    lower, upper = problem.xl, problem.xu
    unit_cube = lhs_sampler.do(problem, num_samples).get("X")
    candidates = lower + (upper - lower) * unit_cube
    evaluations = problem.evaluate(candidates)
    return candidates, evaluations
|
|
1326
|
+
|
|
1327
|
+
def betas_for_alpha_bar(self, T, alpha_bar, max_beta=0.999):
    """
    Discretize a cumulative-alpha function into ``T`` beta values.

    For step i in 0..T-1:
        beta_i = min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), max_beta)

    :param T: number of betas to produce.
    :param alpha_bar: callable on t in [0, 1] returning the cumulative product
                      of (1 - beta) up to that point of the diffusion.
    :param max_beta: upper cap on each beta; values below 1 avoid singularities.
    :return: float32 tensor of shape (T,).
    """
    beta_values = [
        min(1 - alpha_bar((step + 1) / T) / alpha_bar(step / T), max_beta)
        for step in range(T)
    ]
    return torch.from_numpy(np.array(beta_values)).float()
|
|
1345
|
+
|
|
1346
|
+
def cosine_beta_schedule(self, s=0.008):
    """
    Cosine beta schedule over ``self.timesteps`` steps.

    Uses alpha_bar(t) = cos(((t + 0.008) / 1.008) * pi / 2) ** 2, discretized
    by ``betas_for_alpha_bar``.

    NOTE: ``s`` is accepted but NOT used; the offset is fixed at 0.008.
    ``sampling`` calls this as ``cosine_beta_schedule(self.timesteps)``, so
    the slot currently receives the step count. It must not start being
    honored without updating such callers.
    """
    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    return self.betas_for_alpha_bar(self.timesteps, _cosine_alpha_bar)
|
|
1354
|
+
|
|
1355
|
+
def l_simple_loss(self, predicted_noise, actual_noise):
    """Return the DDPM "simple" loss: mean squared error between predicted and actual noise."""
    return nn.functional.mse_loss(predicted_noise, actual_noise, reduction="mean")
|
|
1357
|
+
|
|
1358
|
+
def get_target_dir(self, grads, mth="mgda", x=None):
    """
    Combine a list of per-objective gradients into one target direction.

    Parameters
    ----------
    grads : list of tensors, all with the same shape, one per objective.
    mth : "mgda" or "pmgda".
        * "mgda": a single weight vector alpha is computed from the Gram
          matrix of the flattened gradients: ``alpha ∝ pinv(Gram) @ 1``.
          Negative entries are clipped and alpha is renormalized; if the
          weights are degenerate (zero or non-finite sum), uniform weights
          are used.
        * "pmgda": weights come from ``PMGDASolver.compute_weights``. It uses
          the objective values at ``x`` and the summed equality (``"H"``) or
          inequality (``"G"``) constraint values when the objective function
          reports them.
    x : points at which the gradients were taken. Required for "pmgda".

    Returns
    -------
    Tensor with the shape of ``grads[0]``: the weighted sum of the gradients.

    Raises
    ------
    AssertionError
        If the problem has constraints and ``mth == "mgda"``.
    ValueError
        If ``mth`` is not recognized.
    """
    m = len(grads)
    if self.problem.n_ieq_constr + self.problem.n_eq_constr > 0:
        assert mth != "mgda", "MGDA not supported with constraints. Use mth ='pmgda'."

    if mth == "mgda":
        # Stack the flattened gradients into a matrix of shape (m, p).
        G = torch.stack([gr.reshape(-1) for gr in grads], dim=0)
        # Gram matrix (m, m): entry (i, j) = g_i · g_j.
        gram_matrix = G @ G.t()

        # Equality-constrained min-norm solution via the pseudo-inverse (robust
        # to a singular Gram matrix), followed by an active-set style clip.
        ones = torch.ones(m, device=gram_matrix.device, dtype=gram_matrix.dtype)
        inv_gram = torch.linalg.pinv(gram_matrix)
        alpha = inv_gram @ ones
        alpha_sum = alpha.sum()
        if torch.isfinite(alpha_sum) and alpha_sum != 0:
            alpha = alpha / alpha_sum  # enforce sum(alpha) = 1
        else:
            # Degenerate case (e.g. all-zero gradients): use equal weights
            # instead of dividing by zero and producing NaNs.
            alpha = torch.full_like(alpha, 1.0 / m)

        # Clamp negative weights to 0 and renormalize.
        if (alpha < 0).any():
            alpha = torch.clamp(alpha, min=0.0)
            if alpha.sum() == 0:
                alpha = torch.full_like(alpha, 1.0 / m)
            else:
                alpha = alpha / alpha.sum()

        g = torch.zeros_like(grads[0])
        for weight, grad in zip(alpha, grads):
            g += weight * grad
    elif mth == "pmgda":
        pre_h_vals = None
        constraint_mtd = 'pbi'
        solver = PMGDASolver(self.problem, prefs=None,
                             n_prob=grads[0].shape[0], n_obj=self.problem.n_obj,
                             verbose=False)
        y = self.objective_functions(x, get_constraint=True)
        if "H" in y:
            pre_h_vals = y["H"].sum(dim=1)
            constraint_mtd = 'eq'
        elif "G" in y:
            # (The unconditional debug print of pre_h_vals.shape was removed.)
            pre_h_vals = y["G"].sum(dim=1)
            constraint_mtd = 'ineq'
        y = y["F"]
        alphas = solver.compute_weights(x, y, pre_h_vals=pre_h_vals,
                                        constraint_mtd=constraint_mtd)
        # Replace NaN weights by the mean of the finite ones; zero out infinities.
        alphas = torch.nan_to_num(alphas, nan=torch.nanmean(alphas),
                                  posinf=0.0, neginf=0.0).split(1, dim=1)
        g = torch.zeros_like(grads[0])
        for weight, grad in zip(alphas, grads):
            g += weight * grad
    else:
        raise ValueError(f"Method {mth} not recognized.")
    return g
|
|
1428
|
+
|
|
1429
|
+
def mgd_armijo_step(
    self,
    x_t: torch.Tensor,
    d: torch.Tensor,
    f_old,
    grads,  # (N, m, d)
    eta_init=0.9,
    rho=0.9,
    c1=1e-4,
    max_backtracks=100,
):
    """
    Batched Armijo back-tracking line search for Multiple-Gradient-Descent (MGD).

    For every point n, the trial point ``x[n] + eta[n] * d[n]`` is accepted
    when all objectives satisfy

        f_new[n, k] <= f_old[n, k] + c1 * eta[n] * <grads[n, k], d[n]>.

    Otherwise ``eta[n]`` is multiplied by ``rho`` and the point is re-tested on
    the next round, for at most ``max_backtracks`` rounds. Points that never
    pass end with ``eta_init * rho ** max_backtracks``. The directional
    derivatives are computed once at ``x_t``.

    NOTE(review): the trial point moves along ``+d``, which matches the textbook
    Armijo condition only when ``d`` is a descent direction. The sampling step
    in this file passes the MGD direction g and then updates with ``-eta * h``.
    Confirm the intended sign convention.

    Parameters
    ----------
    x_t : torch.Tensor, shape (N, d)
        Points at which the search starts (cloned and detached).
    d : torch.Tensor, shape (N, d)
        Search direction per point.
    f_old : tensor, shape (N, m)
        Reference objective values compared against the trial evaluations.
    grads : tensor, shape (N, m, d)
        Per-point, per-objective gradients.
    eta_init : float or tensor of shape (N,)
        Initial step size(s); a tensor is cast to ``x``'s dtype/device.
    rho : float
        Shrink factor applied on each failed test.
    c1 : float
        Armijo sufficient-decrease constant.
    max_backtracks : int
        Maximum number of test/shrink rounds.

    Returns
    -------
    eta : torch.Tensor, shape (N, 1)
        Final step sizes.
    """

    x = x_t.clone().detach()

    if not torch.is_tensor(eta_init):
        eta = torch.full((x.shape[0],), float(eta_init),
                         dtype=x.dtype, device=x.device)
    else:
        eta = eta_init.clone().to(x)

    # Directional derivatives <grads[n, k], d[n]>, shape (N, m).
    grad_dot = torch.einsum('nkd,nd->nk', grads, d)

    # Points still searching for an acceptable step.
    improve = torch.ones_like(eta, dtype=torch.bool)

    for _ in range(max_backtracks):
        if not improve.any():
            break

        # Evaluate objectives at trial points (only for points still searching).
        trial_x = x[improve] + eta[improve, None] * d[improve]
        f_new = self.objective_functions(trial_x)

        # Armijo test (vectorised over objectives):
        # f_new <= f_old + c1 * eta * grad_dot for every objective.
        ok = (f_new <= f_old[improve] + c1 * eta[improve, None] * grad_dot[improve]).all(dim=1)

        # Keep eta where the test passed, shrink it elsewhere; then drop the
        # accepted points from the active mask (order matters: eta is indexed
        # with the mask as it was before this round's update).
        eta[improve] = torch.where(ok, eta[improve], rho * eta[improve])
        improve_mask = improve.clone()
        improve[improve_mask] = ~ok

    return eta[:, None]
|
|
1479
|
+
|
|
1480
|
+
def adaptive_scale_delta_vect(
    self,
    g: torch.Tensor,
    delta_raw: torch.Tensor,
    grads: torch.Tensor,
    gamma: float = 0.9
) -> torch.Tensor:
    """
    Adaptive scaling to preserve *positivity*:

        ∇f_j(x_i)^T [ g_i + rho_i * delta_raw_i ] > 0 for all j.

    Each point gets a scale ``rho_i = min(1, gamma * min_j alpha_ij / -beta_ij)``,
    where the inner minimum runs only over objectives with ``alpha_ij > 0`` and
    ``beta_ij < 0``. Here ``alpha = ∇f^T g`` and ``beta = ∇f^T delta_raw``.
    If no objective constrains the point, ``rho_i = 1``. The scale is capped at
    1, so ``delta_raw`` is only ever shrunk, never amplified.

    Args:
        g (torch.Tensor): [n_points, d], the multi-objective "gradient"
            (which we *subtract* in the update).
        delta_raw (torch.Tensor): [n_points, d], the unscaled diversity/repulsion direction.
        grads (torch.Tensor): [m, n_points, d], storing ∇f_j(x_i).
        gamma (float): Safety factor in (0,1).

    Returns:
        delta_scaled (torch.Tensor): [n_points, d], ``delta_raw`` scaled row-wise
            by ``rho``. For every objective with alpha > 0 this keeps
            grads[j,i]ᵀ [g[i] + delta_scaled[i]] > 0.
    """

    # alpha_{i,j} = ∇f_j(x_i)^T g_i          -> [n_points, m]
    alpha = torch.einsum("j i d, i d -> i j", grads, g)
    # beta_{i,j}  = ∇f_j(x_i)^T delta_raw_i  -> [n_points, m]
    beta = torch.einsum("j i d, i d -> i j", grads, delta_raw)

    # Only alpha > 0 and beta < 0 restricts rho: alpha + rho*beta > 0 needs
    # rho < alpha / -beta.
    mask = (alpha > 0.0) & (beta < 0.0)

    ratio = torch.full_like(alpha, float("inf"))
    ratio[mask] = alpha[mask] / (-beta[mask])

    ratio_min, _ = ratio.min(dim=1)  # [n_points]
    rho = gamma * ratio_min
    # No constraint at all -> rho = 1 (explicit so that gamma == 0 cannot give 0*inf = nan).
    rho = torch.where(torch.isinf(ratio_min), torch.ones_like(rho), rho)
    # gamma * ratio_min is only an upper bound on the admissible scale. It can
    # be far larger than 1 when beta is a tiny negative number. Cap it at 1 so
    # constrained points are never scaled more than unconstrained ones.
    rho = rho.clamp(max=1.0)

    # Scale delta_raw row-wise by rho_i
    delta_scaled = delta_raw * rho.unsqueeze(1)

    return delta_scaled
|
|
1536
|
+
|
|
1537
|
+
def repair_bounds(self, x):
    """
    Clip ``x`` into the problem's box bounds.

    Parameters:
        x (torch.Tensor): Input tensor of shape [N, d].

    Returns:
        torch.Tensor: A detached copy of ``x`` with the same shape. Each column j
        is clipped to ``[xl[j], xu[j]]``, where ``xl, xu`` are the first two
        entries of ``self.problem.bounds()``. If ``self.problem.global_clamping``
        is truthy, every column is clipped to the single range
        ``[min(xl), max(xu)]`` instead.
    """
    # Evaluate bounds() once instead of calling it separately for xl and xu.
    bounds = self.problem.bounds()
    lower = bounds[0].detach().clone().to(x.device)
    upper = bounds[1].detach().clone().to(x.device)

    # x.data.clone() gives a detached copy, so the input tensor is untouched.
    x_copy = x.data.clone()
    if self.problem.global_clamping:
        return torch.clamp(x_copy, min=lower.min(), max=upper.max())
    return torch.clamp(x_copy, min=lower, max=upper)
|
|
1557
|
+
|
|
1558
|
+
def repulsion_loss(self,
                   F_,
                   sigma=1.0,
                   use_sigma=False):
    """
    RBF repulsion loss over a batch of points in the objective space.

    Parameters:
        F_ (torch.Tensor): Tensor of shape (n, m): n points, m objectives.
        sigma (float): RBF bandwidth, used only when ``use_sigma`` is True.
        use_sigma (bool): If False, a bandwidth is derived from the median of
            the (detached) pairwise squared distances.

    Returns:
        torch.Tensor: Scalar ``2 / (n * (n - 1)) * sum_{i,j} k(F_i, F_j)``.
        The sum runs over the full n x n kernel matrix. Each unordered pair is
        therefore counted twice, and the diagonal adds the constant ``n``, which
        does not affect gradients. For ``n < 2`` there is no pair, and a zero
        loss (still attached to ``F_``'s autograd graph) is returned.
    """
    n = F_.shape[0]
    if n < 2:
        # No pairs: avoids dividing by n * (n - 1) == 0 and by log(1) == 0.
        return F_.sum() * 0.0

    # Pairwise squared distances: [n, n]
    dist_sq = torch.norm(F_[:, None] - F_, dim=2).pow(2)
    # Compute RBF values for the distances
    if use_sigma:
        repulsion = torch.exp(-dist_sq / (2 * sigma**2))
    else:
        tensor = dist_sq.detach().flatten()
        tensor_max = tensor.max()[None]
        median_dist = (torch.cat((tensor, tensor_max)).median() + tensor.median()) / 2.0
        s = median_dist / math.log(n)
        # NOTE(review): this parses as (-dist_sq / 5e-6) * s, i.e. the
        # bandwidth multiplies instead of divides. Kept as-is because the
        # constant appears tuned around it; confirm the intent.
        repulsion = torch.exp(-dist_sq / 5e-6 * s)

    # Normalize by the number of pairs
    loss = (2/(n*(n-1))) * repulsion.sum()
    return loss
|
|
1583
|
+
|
|
1584
|
+
def eps_dominance(self, Obj_space, alpha=0.0):
    """
    Return indices of points that no *other* point epsilon-dominates.

    With ``epsilon = alpha * min(Obj_space, axis=0)`` and
    ``c = Obj_space[i] - epsilon``, point i is treated as dominated when some
    ``j != i`` satisfies ``all(c >= Obj_space[j]) and any(c > Obj_space[j])``
    (minimisation). With ``alpha=0`` this is plain Pareto dominance.

    Parameters:
        Obj_space (array-like): Objective values, shape (N, m).
        alpha (float): Relative epsilon factor.

    Returns:
        list[int]: Sorted indices of the non-dominated points.
    """
    objs = np.asarray(Obj_space)
    if objs.ndim == 1:
        # A flat vector is N single-objective values.
        objs = objs[:, None]
    n_points = len(objs)
    if n_points == 0:
        return []

    epsilon = alpha * np.min(objs, axis=0)
    shifted = objs - epsilon

    keep = []
    for i in range(n_points):
        # Compare point i (shifted by epsilon) against all points at once.
        weakly = np.all(shifted[i] >= objs, axis=1)
        strictly = np.any(shifted[i] > objs, axis=1)
        dominators = weakly & strictly
        # A point must never dominate itself. That happens when epsilon has
        # negative entries, e.g. when the objective minima are negative.
        dominators[i] = False
        if not dominators.any():
            keep.append(i)
    return keep
|
|
1597
|
+
|
|
1598
|
+
|
|
1599
|
+
def get_non_dominated_points(
    self,
    points_pred=None,
    keep_shape=True,
    indx_only=False,
    p_front=None
):
    """
    Identify the non-dominated points of a batch.

    Parameters:
        points_pred (torch.Tensor | None): Decision-space points, shape (N, d).
            Required unless ``indx_only`` is True. When given, objective values
            are recomputed with ``self.objective_functions``, and ``p_front`` is
            ignored.
        keep_shape (bool): If True, each dominated row of the returned points
            is replaced by the non-dominated row with the nearest *index*, so
            the result keeps N rows.
        indx_only (bool): If True, return only the non-dominated indices.
        p_front (np.ndarray | None): Objective values, used only when
            ``points_pred`` is None.

    Returns:
        - ``indx_only``: the non-dominated indices.
        - ``keep_shape``: ``(pf_points, points_pred, PS_idx)``, where
          ``pf_points`` is a modified copy of ``points_pred`` and ``PS_idx`` is
          a sorted integer array. If no point is non-dominated, the copy is
          returned unmodified.
        - otherwise: ``(pf_points[PS_idx], points_pred, PS_idx)``.

    Raises:
        ValueError: if ``points_pred`` is missing when required, or
            ``self.mode`` is not recognised.
    """
    if not indx_only and points_pred is None:
        raise ValueError("points_pred cannot be None when indx_only is False.")
    if points_pred is not None:
        # Work on a copy so the caller's tensor is never modified.
        pf_points = copy.deepcopy(points_pred.detach())
        p_front = self.objective_functions(pf_points).detach().cpu().numpy()
    else:
        assert p_front is not None, "p_front must be provided if points_pred is None."

    if self.mode in ["online", "offline"]:
        PS_idx = self.eps_dominance(p_front)
    elif self.mode == "bayesian":
        if points_pred is None:
            raise ValueError("points_pred is required in 'bayesian' mode.")
        N = points_pred.shape[0]
        # 1) Predict pairwise dominance labels
        label_matrix, _ = nn_predict_dom_intra(points_pred.detach().cpu().numpy(),
                                               self.dominance_classifier,
                                               self.device)
        # 2) Keep i only if no row j carries label 2 in column i. This method
        #    treats label 2 at [j, i] as "j dominates i".
        dominated = (label_matrix[:N, :N] == 2).any(0)
        PS_idx = [i for i in range(N) if not bool(dominated[i])]
    else:
        raise ValueError(f"Mode {self.mode} not recognized!")

    if indx_only:
        return PS_idx

    if keep_shape:
        PS_idx = np.sort(np.asarray(PS_idx, dtype=int))
        if PS_idx.size == 0:
            # No non-dominated point to copy from; return the copy unmodified.
            return pf_points, points_pred, PS_idx
        not_in_PS_idx = np.setdiff1d(np.arange(p_front.shape[0]), PS_idx)
        # Replace each dominated row with the non-dominated row closest in index.
        for idx in not_in_PS_idx:
            nearest_idx = PS_idx[np.argmin(np.abs(PS_idx - idx))]
            pf_points[idx] = pf_points[nearest_idx]
        return pf_points, points_pred, PS_idx

    return pf_points[PS_idx], points_pred, PS_idx
|
|
1653
|
+
|
|
1654
|
+
|
|
1655
|
+
def crowding_distance(self, points):
|
|
1656
|
+
"""
|
|
1657
|
+
Compute crowding distances for points.
|
|
1658
|
+
points: Tensor of shape (N, D) in the objective space.
|
|
1659
|
+
Returns: Tensor of shape (N,) containing crowding distances.
|
|
1660
|
+
"""
|
|
1661
|
+
N, D = points.shape
|
|
1662
|
+
distances = torch.zeros(N, device=points.device)
|
|
1663
|
+
|
|
1664
|
+
for d in range(D):
|
|
1665
|
+
sorted_points, indices = torch.sort(points[:, d])
|
|
1666
|
+
distances[indices[0]] = distances[indices[-1]] = float("inf")
|
|
1667
|
+
|
|
1668
|
+
min_d, max_d = sorted_points[0], sorted_points[-1]
|
|
1669
|
+
norm_range = max_d - min_d if max_d > min_d else 1.0
|
|
1670
|
+
|
|
1671
|
+
# Compute normalized crowding distance
|
|
1672
|
+
distances[indices[1:-1]] += (
|
|
1673
|
+
sorted_points[2:] - sorted_points[:-2]
|
|
1674
|
+
) / norm_range
|
|
1675
|
+
|
|
1676
|
+
return distances
|
|
1677
|
+
|
|
1678
|
+
|
|
1679
|
+
def select_top_n_candidates(
    self,
    points: torch.Tensor,
    n: int,
    top_frac: float = 0.9
) -> torch.Tensor:
    """
    Select up to ``n`` diverse candidates from ``points``.

    * ``online`` / ``offline`` mode: if there are at most ``n`` points, all of
      them are returned. Otherwise the ``n`` points with the largest crowding
      distance in objective space are kept.
    * Any other mode (dominance-classifier based): non-dominated points are
      identified with ``nn_predict_dom_intra``.
      - If there are more than ``n`` of them, the ``n`` with the largest
        crowding distance are kept.
      - Otherwise all of them are kept. The remainder is filled by crowding
        distance from the best ``top_frac`` fraction of a ranking by (number
        of predicted dominators ascending, mean dominator confidence
        descending), then padded from the full ranking if still short.

    The returned rows are shuffled and contain no duplicate indices.

    Parameters:
        points (torch.Tensor): Candidate points, shape (N, D).
        n (int): Number of points to select.
        top_frac (float): Fraction of the ranking eligible for the fill step.

    Returns:
        torch.Tensor: The selected points, shape (min(n, N), D).
    """

    if self.mode in ["online", "offline"]:
        if len(points) <= n:
            final_idx = torch.randperm(points.size(0))
        else:
            full_p_front = self.objective_functions(points)
            distances = self.crowding_distance(full_p_front)
            top_indices = torch.topk(distances, n).indices
            final_idx = top_indices[torch.randperm(top_indices.size(0))]
        return points[final_idx]

    N = points.shape[0]
    # 1) Predict pairwise dominance labels and confidences
    label_matrix, conf_matrix = nn_predict_dom_intra(points.detach().cpu().numpy(),
                                                     self.dominance_classifier,
                                                     self.device)
    # dom[j, i] is True when pair (j, i) has label 2, which this method treats
    # as "j dominates i". Computed once, not rescanned per point.
    dom = label_matrix[:N, :N] == 2
    dom_counts = [int(dom[:, i].sum()) for i in range(N)]

    # 2) Non-dominated = no predicted dominator
    nondom_inds = [i for i in range(N) if dom_counts[i] == 0]

    # --- CASE A: too many non-dominated -> pick top-n by crowding ---
    if len(nondom_inds) > n:
        pts_nd = points[nondom_inds].to(self.device)
        Y_t = self.objective_functions(pts_nd)
        distances = self.crowding_distance(Y_t)
        topk = torch.topk(distances, n).indices.tolist()
        selected_nd = [nondom_inds[i] for i in topk]

        # Shuffle before returning
        perm = torch.randperm(n, device=points.device)
        final_idx = torch.tensor(selected_nd, device=points.device)[perm]
        return points[final_idx]

    # --- CASE B: nondom <= n -> fill up via rank + top_frac + crowding ---
    # 3) Average confidence of the predicted dominators of each point
    avg_conf = []
    for i in range(N):
        dom_by = dom[:, i]
        cnt = dom_counts[i]
        avg_conf.append(
            float(conf_matrix[dom_by, i].sum()) / cnt
            if cnt > 0 else 0.0
        )

    # 4) Rank all points by (dom_count asc, avg_conf desc)
    idxs = sorted(range(N), key=lambda i: (dom_counts[i], -avg_conf[i]))

    # 5) Only the top top_frac fraction of the ranking is eligible for the fill
    k_top = int(np.floor(top_frac * N))
    top_pool = idxs[:k_top]

    final_inds = list(nondom_inds)
    chosen = set(final_inds)
    need = n - len(final_inds)

    # 6-7) Fill by crowding distance computed over the whole pool. The
    # non-dominated points rank first, so they are also in the pool; skip them
    # so no index is selected twice. The guard also avoids evaluating crowding
    # distance on an empty pool.
    if need > 0 and top_pool:
        Y_t = self.objective_functions(points[top_pool])
        distances = self.crowding_distance(Y_t)
        for j in torch.argsort(distances, descending=True).tolist():
            if need == 0:
                break
            cand = top_pool[j]
            if cand not in chosen:
                final_inds.append(cand)
                chosen.add(cand)
                need -= 1

    # 9) If still short, pad with the best remaining points of the ranking
    if len(final_inds) < n:
        remaining = [i for i in idxs if i not in chosen]
        final_inds += remaining[:n - len(final_inds)]

    # 10) Shuffle final indices (explicit long dtype so an empty list still indexes)
    perm = torch.randperm(len(final_inds), device=points.device)
    final_idx = torch.tensor(final_inds, dtype=torch.long, device=points.device)[perm]

    return points[final_idx]
|
|
1778
|
+
|
|
1779
|
+
def plot_pareto_front(self,
                      list_fi,
                      t,
                      num_points_sample,
                      extra=None,
                      label=None,
                      plot_dataset=False,
                      pop=None,
                      images_store_path="./images_dir/"):
    """
    Save a scatter plot of objective values for a 2- or 3-objective problem.

    Parameters:
        list_fi: Per-objective values ``(f1, f2[, f3])`` of the generated
            points (red).
        t: Reverse time step, used in the title and the file name.
        num_points_sample: Number of sampled points, used in the file name.
        extra: Optional ``(f1, f2[, f3])`` reference points (yellow).
        label: Optional suffix for the file name and the image directory.
        plot_dataset (bool): If True and ``self.dataset`` is set, also plot the
            de-normalised dataset objectives (blue).
        pop: Optional ``(f1, f2[, f3])`` population points (blue).
        images_store_path (str): Root directory for saved images.

    Returns:
        None. The figure is written as a JPEG under
        ``images_store_path/<ProblemClass>_<mode>[_<label>]/``. Nothing is drawn
        or written unless ``len(list_fi)`` is 2 or 3.
    """
    n_obj = len(list_fi)
    if n_obj not in (2, 3):
        # Only 2-D and 3-D objective spaces can be plotted.
        return None

    name = (
        "spread"
        + "_"
        + self.problem.__class__.__name__
        + "_"
        + f"T={self.timesteps}"
        + "_"
        + f"N={num_points_sample}"
        + "_"
        + f"t={t}"
        + "_"
        + f"seed={self.seed}"
        + "_"
        + self.mode
    )
    if label is not None:
        name += f"_{label}"

    # Always draw on a fresh figure (previously the 2-D branch reused whatever
    # figure was current) and always close it, even if plotting fails.
    fig = plt.figure()
    try:
        if n_obj == 2:
            ax = fig.add_subplot(111)
            if extra is not None:
                f1, f2 = extra
                ax.scatter(f1, f2, c="yellow", s=5, alpha=1.0)  # Pareto optimal points

            if pop is not None:
                f_pop1, f_pop2 = pop
                ax.scatter(f_pop1, f_pop2, c="blue", s=10, alpha=1.0)  # population

            f1, f2 = list_fi
            ax.scatter(f1, f2, c="red", s=10, alpha=1.0)  # generated points

            if plot_dataset and (self.dataset is not None):
                _, Y = self.dataset
                Y = self.offline_denormalization(Y,
                                                 self.y_meanormin,
                                                 self.y_stdormax)
                ax.scatter(Y[:, 0], Y[:, 1], c="blue", s=5, alpha=1.0)  # training data

            ax.set_xlabel("$f_1$", fontsize=14)
            ax.set_ylabel("$f_2$", fontsize=14)
            ax.set_title(f"Reverse Time Step: {t}", fontsize=14)
        else:
            ax = fig.add_subplot(111, projection="3d")

            if extra is not None:
                f1, f2, f3 = extra
                ax.scatter(f1, f2, f3, c="yellow", s=5, alpha=0.05)  # Pareto optimal points

            if pop is not None:
                f_pop1, f_pop2, f_pop3 = pop
                ax.scatter(f_pop1, f_pop2, f_pop3, c="blue", s=10, alpha=1.0)  # population

            f1, f2, f3 = list_fi
            ax.scatter(f1, f2, f3, c="red", s=10, alpha=1.0)  # generated points

            if plot_dataset and (self.dataset is not None):
                _, Y = self.dataset
                Y = self.offline_denormalization(Y,
                                                 self.y_meanormin,
                                                 self.y_stdormax)
                ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2],
                           c="blue", s=5, alpha=1.0)  # training data
            ax.set_xlabel("$f_1$", fontsize=14)
            ax.set_ylabel("$f_2$", fontsize=14)
            ax.set_zlabel("$f_3$", fontsize=14)
            ax.view_init(elev=30, azim=45)
            ax.set_title(f"Reverse Time Step: {t}", fontsize=14)

        dir_name = f"{self.problem.__class__.__name__}_{self.mode}"
        if label is not None:
            dir_name += f"_{label}"
        img_dir = os.path.join(images_store_path, dir_name)
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(img_dir, exist_ok=True)

        fig.savefig(
            os.path.join(img_dir, f"{name}.jpg"),
            dpi=300,
            bbox_inches="tight",
        )
    finally:
        plt.close(fig)
|