moospread 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. moospread/__init__.py +3 -0
  2. moospread/core.py +1881 -0
  3. moospread/problem.py +193 -0
  4. moospread/tasks/__init__.py +4 -0
  5. moospread/tasks/dtlz_torch.py +139 -0
  6. moospread/tasks/mw_torch.py +274 -0
  7. moospread/tasks/re_torch.py +394 -0
  8. moospread/tasks/zdt_torch.py +112 -0
  9. moospread/utils/__init__.py +8 -0
  10. moospread/utils/constraint_utils/__init__.py +2 -0
  11. moospread/utils/constraint_utils/gradient.py +72 -0
  12. moospread/utils/constraint_utils/mgda_core.py +69 -0
  13. moospread/utils/constraint_utils/pmgda_solver.py +308 -0
  14. moospread/utils/constraint_utils/prefs.py +64 -0
  15. moospread/utils/ditmoo.py +127 -0
  16. moospread/utils/lhs.py +74 -0
  17. moospread/utils/misc.py +28 -0
  18. moospread/utils/mobo_utils/__init__.py +11 -0
  19. moospread/utils/mobo_utils/evolution/__init__.py +0 -0
  20. moospread/utils/mobo_utils/evolution/dom.py +60 -0
  21. moospread/utils/mobo_utils/evolution/norm.py +40 -0
  22. moospread/utils/mobo_utils/evolution/utils.py +97 -0
  23. moospread/utils/mobo_utils/learning/__init__.py +0 -0
  24. moospread/utils/mobo_utils/learning/model.py +40 -0
  25. moospread/utils/mobo_utils/learning/model_init.py +33 -0
  26. moospread/utils/mobo_utils/learning/model_update.py +51 -0
  27. moospread/utils/mobo_utils/learning/prediction.py +116 -0
  28. moospread/utils/mobo_utils/learning/utils.py +143 -0
  29. moospread/utils/mobo_utils/lhs_for_mobo.py +243 -0
  30. moospread/utils/mobo_utils/mobo/__init__.py +0 -0
  31. moospread/utils/mobo_utils/mobo/acquisition.py +209 -0
  32. moospread/utils/mobo_utils/mobo/algorithms.py +91 -0
  33. moospread/utils/mobo_utils/mobo/factory.py +86 -0
  34. moospread/utils/mobo_utils/mobo/mobo.py +132 -0
  35. moospread/utils/mobo_utils/mobo/selection.py +182 -0
  36. moospread/utils/mobo_utils/mobo/solver/__init__.py +5 -0
  37. moospread/utils/mobo_utils/mobo/solver/moead.py +17 -0
  38. moospread/utils/mobo_utils/mobo/solver/nsga2.py +10 -0
  39. moospread/utils/mobo_utils/mobo/solver/parego/__init__.py +1 -0
  40. moospread/utils/mobo_utils/mobo/solver/parego/parego.py +62 -0
  41. moospread/utils/mobo_utils/mobo/solver/parego/utils.py +34 -0
  42. moospread/utils/mobo_utils/mobo/solver/pareto_discovery/__init__.py +1 -0
  43. moospread/utils/mobo_utils/mobo/solver/pareto_discovery/buffer.py +364 -0
  44. moospread/utils/mobo_utils/mobo/solver/pareto_discovery/pareto_discovery.py +571 -0
  45. moospread/utils/mobo_utils/mobo/solver/pareto_discovery/utils.py +168 -0
  46. moospread/utils/mobo_utils/mobo/solver/solver.py +74 -0
  47. moospread/utils/mobo_utils/mobo/surrogate_model/__init__.py +2 -0
  48. moospread/utils/mobo_utils/mobo/surrogate_model/base.py +36 -0
  49. moospread/utils/mobo_utils/mobo/surrogate_model/gaussian_process.py +177 -0
  50. moospread/utils/mobo_utils/mobo/surrogate_model/thompson_sampling.py +79 -0
  51. moospread/utils/mobo_utils/mobo/surrogate_problem.py +44 -0
  52. moospread/utils/mobo_utils/mobo/transformation.py +106 -0
  53. moospread/utils/mobo_utils/mobo/utils.py +65 -0
  54. moospread/utils/mobo_utils/spread_mobo_utils.py +854 -0
  55. moospread/utils/offline_utils/__init__.py +10 -0
  56. moospread/utils/offline_utils/handle_task.py +203 -0
  57. moospread/utils/offline_utils/proxies.py +338 -0
  58. moospread/utils/spread_utils.py +91 -0
  59. moospread-0.1.0.dist-info/METADATA +75 -0
  60. moospread-0.1.0.dist-info/RECORD +63 -0
  61. moospread-0.1.0.dist-info/WHEEL +5 -0
  62. moospread-0.1.0.dist-info/licenses/LICENSE +10 -0
  63. moospread-0.1.0.dist-info/top_level.txt +1 -0
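For orientation, here is a minimal usage sketch of the SPREAD class defined in moospread/core.py (diffed below). It is a sketch, not documented usage: the task import and its constructor are assumptions inferred from the file names above, while the SPREAD(...) and solve(...) arguments follow the signatures visible in core.py.

from moospread.core import SPREAD
from moospread.tasks import zdt_torch          # task module listed above; the ZDT1 class name is an assumption

problem = zdt_torch.ZDT1()                     # hypothetical constructor; must expose n_var, n_obj, _evaluate, bounds()
solver = SPREAD(problem=problem,
                mode="online",                 # one of "online", "offline", "bayesian"
                timesteps=1000,
                num_epochs=1000,
                seed=0)
res_x, res_y = solver.solve(num_points_sample=500)  # decision vectors and their objective values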
moospread/core.py ADDED
@@ -0,0 +1,1881 @@
1
+ """Main module."""
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.utils.data import DataLoader, TensorDataset
8
+
9
+ import copy
10
+ import math
11
+ import json
12
+ import os
13
+ import pickle
14
+ from tqdm import tqdm
15
+ import datetime
16
+ from time import time
17
+ import dis
18
+
19
+ from pymoo.config import Config
20
+ Config.warnings['not_compiled'] = False
21
+ from pymoo.indicators.hv import HV
22
+ from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
23
+
24
+ import matplotlib.pyplot as plt
25
+ from mpl_toolkits.mplot3d import Axes3D
26
+
27
+ from moospread.utils import *
28
+
29
+ class SPREAD:
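+ """Diffusion-model-based multi-objective solver supporting 'online', 'offline', and 'bayesian' (MOBO) modes."""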
30
+ def __init__(self,
31
+ problem,
32
+ mode: str = "online",
33
+ model = None,
34
+ surrogate_model = None,
35
+ dataset = None,
36
+ xi_shift = None,
37
+ data_size: int = 10000,
38
+ validation_split=0.1,
39
+ hidden_dim: int = 128,
40
+ num_heads: int = 4,
41
+ num_blocks: int = 2,
42
+ timesteps: int = 1000,
43
+ batch_size: int = 256,
44
+ train_lr: float = 1e-4,
45
+ train_lr_surrogate: float = 1e-4,
46
+ num_epochs: int = 1000,
47
+ num_epochs_surrogate: int = 1000,
48
+ train_tol: int = 100,
49
+ train_tol_surrogate: int = 100,
50
+ mobo_coef_lcb=0.1,
51
+ model_dir: str = "./model_dir",
52
+ proxies_store_path: str = "./proxies_dir",
53
+ device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
54
+ seed: int = 0,
55
+ offline_global_clamping: bool = False,
56
+ offline_normalization_method: str = "z_score",
57
+ dominance_classifier = None,
58
+ train_func_surrogate = None,
59
+ plot_func = None,
60
+ verbose: bool = True):
61
+
62
+ self.mode = mode.lower()
63
+ if self.mode not in ["offline", "online", "bayesian"]:
64
+ raise ValueError(f"Invalid mode: {mode}. Must be one of ['offline', 'online', 'bayesian']")
65
+
66
+ assert problem is not None, "Problem must be provided"
67
+ self.problem = problem
68
+ if self.mode in ["online", "bayesian"]:
69
+ assert not is_pass_function(self.problem._evaluate), "Problem must have the '_evaluate' method implemented."
70
+ self.device = device
71
+
72
+ assert 0.0 <= validation_split < 1.0, "validation_split must be in [0.0, 1.0)"
73
+ self.validation_split = validation_split
74
+ self.train_lr = train_lr
75
+ self.batch_size = batch_size
76
+ self.num_epochs = num_epochs
77
+ self.timesteps = timesteps
78
+ self.train_tol = train_tol
79
+ if self.mode in ["offline", "bayesian"]:
80
+ self.train_lr_surrogate = train_lr_surrogate
81
+ self.num_epochs_surrogate = num_epochs_surrogate
82
+ self.train_tol_surrogate = train_tol_surrogate
83
+ if self.mode == "bayesian":
84
+ self.mobo_coef_lcb = mobo_coef_lcb
85
+
86
+
87
+ self.xi_shift = xi_shift
88
+ self.model_dir = model_dir
89
+ os.makedirs(self.model_dir, exist_ok=True)
90
+
91
+ self.train_func_surrogate = train_func_surrogate
92
+ self.plot_func = plot_func
93
+ self.dominance_classifier = dominance_classifier
94
+
95
+ self.seed = seed
96
+ # Set the seed for reproducibility
97
+ set_seed(self.seed)
98
+
99
+ self.model = model
100
+ if self.model is None:
101
+ self.model = DiTMOO(
102
+ input_dim=problem.n_var,
103
+ num_obj=problem.n_obj,
104
+ hidden_dim=hidden_dim,
105
+ num_heads=num_heads,
106
+ num_blocks=num_blocks
107
+ )
108
+ self.surrogate_model = surrogate_model
109
+ if self.surrogate_model is not None:
110
+ self.surrogate_given = True
111
+ else:
112
+ self.surrogate_given = False
113
+ self.proxies_store_path = proxies_store_path
114
+
115
+ self.verbose = verbose
116
+ if self.verbose:
117
+ total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
118
+ print(f"Total number of learnable parameters: {total_params}")
119
+
120
+ self.dataset = dataset
121
+ if self.mode in ["offline", "online"]:
122
+ if self.dataset is None:
123
+ if self.mode == "offline":
124
+ print("Training dataset not provided for offline mode.")
125
+ assert data_size > 0, "Training data size must be positive."
126
+ # self.problem._evaluate should exist in this case
127
+ assert (not is_pass_function(self.problem._evaluate)), "Problem must have the '_evaluate' method implemented when dataset is not provided in offline mode."
128
+ print("Generating training dataset ...")
129
+ self.dataset = self.get_training_data(self.problem,
130
+ num_samples=data_size)
131
+ if self.verbose:
132
+ print("Training dataset generated.")
133
+
134
+ if self.mode == "offline":
135
+ assert offline_normalization_method in ["z_score", "min_max", "none"], "Invalid normalization method"
136
+ if offline_normalization_method == "z_score":
137
+ self.offline_normalization = offdata_z_score_normalize
138
+ self.offline_denormalization = offdata_z_score_denormalize
139
+ elif offline_normalization_method == "min_max":
140
+ self.offline_normalization = offdata_min_max_normalize
141
+ self.offline_denormalization = offdata_min_max_denormalize
142
+ else:
143
+ self.offline_normalization = lambda x: x
144
+ self.offline_denormalization = lambda x: x
145
+
146
+ self.X_meanormin, self.y_meanormin = 0, 0
147
+ self.X_stdormax, self.y_stdormax = 1, 1
148
+ if self.problem.has_bounds():
149
+ xl = self.problem.xl
150
+ xu = self.problem.xu
151
+ ## Normalize the bounds
152
+ # xl
153
+ if self.problem.is_discrete:
154
+ xl = offdata_to_logits(xl)
155
+ _, n_dim, n_classes = tuple(xl.shape)
156
+ xl = xl.reshape(-1, n_dim * n_classes)
157
+ if self.problem.is_sequence:
158
+ xl = offdata_to_logits(xl)
159
+ xl, _, _ = self.offline_normalization(xl)
160
+ # xu
161
+ if self.problem.is_discrete:
162
+ xu = offdata_to_logits(xu)
163
+ _, n_dim, n_classes = tuple(xu.shape)
164
+ xu = xu.reshape(-1, n_dim * n_classes)
165
+ if self.problem.is_sequence:
166
+ xu = offdata_to_logits(xu)
167
+ xu, _, _ = self.offline_normalization(xu)
168
+ ## Set the normalized bounds
169
+ self.problem.xl = xl
170
+ self.problem.xu = xu
171
+ if offline_global_clamping:
172
+ self.problem.global_clamping = True
173
+
174
+ def objective_functions(self,
175
+ points,
176
+ return_as_dict: bool = False,
177
+ return_values_of=None,
178
+ get_constraint=False,
179
+ get_grad_mobo=False,
180
+ evaluate_true=False):
181
+ if evaluate_true:
182
+ if self.problem.need_repair:
183
+ points = self.repair_bounds(points)
184
+ if get_constraint:
185
+ return self.problem.evaluate(points, return_as_dict=True,
186
+ return_values_of=["F", "G", "H"])
187
+ return self.problem.evaluate(points, return_as_dict=return_as_dict,
188
+ return_values_of=return_values_of)
189
+ # Define the objective functions for the optimization problem
190
+ if self.mode == "online":
191
+ if get_constraint:
192
+ return self.problem.evaluate(points, return_as_dict=True,
193
+ return_values_of=["F", "G", "H"])
194
+ return self.problem.evaluate(points, return_as_dict=return_as_dict,
195
+ return_values_of=return_values_of)
196
+ elif self.mode == "offline":
197
+ scores = []
198
+ for proxy in self.surrogate_model:
199
+ scores.append(proxy(points).squeeze())
200
+ return torch.stack(scores, dim=1)
201
+ elif self.mode == "bayesian":
202
+ x = points.detach().cpu().numpy()
203
+ eval_result = self.surrogate_model.evaluate(x, std=True,
204
+ calc_gradient=get_grad_mobo)
205
+ mean = torch.from_numpy(eval_result["F"]).float().to(self.device)
206
+ std = torch.from_numpy(eval_result["S"]).float().to(self.device)
207
+ Y_val = mean - self.mobo_coef_lcb * std
208
+ if get_grad_mobo:
209
+ out = {}
210
+ mean_grad = torch.from_numpy(eval_result["dF"]).float().to(self.device)
211
+ std_grad = torch.from_numpy(eval_result["dS"]).float().to(self.device)
212
+ Grad_val = mean_grad - self.mobo_coef_lcb * std_grad
213
+ out["dF"] = [Grad_val[:, i, :] for i in range(Grad_val.shape[1])]
214
+ out["F"] = Y_val
215
+ else:
216
+ out = Y_val
217
+ return out
218
+ else:
219
+ raise ValueError(f"Invalid mode: {self.mode}")
220
+
221
+ def solve(self, num_points_sample=500,
222
+ strict_guidance=False,
223
+ rho_scale_gamma=0.9,
224
+ nu_t=10.0, eta_init=0.9,
225
+ num_inner_steps=10, lr_inner=0.9,
226
+ free_initial_h=True,
227
+ use_sigma_rep=False, kernel_sigma_rep=0.01,
228
+ iterative_plot=True, plot_period=100,
229
+ max_backtracks=100, label=None, save_results=True,
230
+ load_models=False,
231
+ samples_store_path="./samples_dir/",
232
+ images_store_path="./images_dir/",
233
+ n_init_mobo=100, use_escape_local_mobo=True,
234
+ n_steps_mobo=20, spread_num_samp_mobo=25,
235
+ batch_select_mobo=5):
236
+ set_seed(self.seed)
237
+ if self.mode in ["offline", "online"]:
238
+ X, y = self.dataset
239
+
240
+ if self.mode == "offline":
241
+ X = X.clone().detach()
242
+ y = y.clone().detach()
243
+ if self.problem.is_discrete:
244
+ X = offdata_to_logits(X)
245
+ _, n_dim, n_classes = tuple(X.shape)
246
+ X = X.reshape(-1, n_dim * n_classes)
247
+ if self.problem.is_sequence:
248
+ X = offdata_to_logits(X)
249
+ # Normalize the inputs and outputs with the selected offline normalization method (z-score, min-max, or none)
251
+ X, self.X_meanormin, self.X_stdormax = self.offline_normalization(X)
252
+ y, self.y_meanormin, self.y_stdormax = self.offline_normalization(y)
253
+
254
+ #### SURROGATE MODEL TRAINING ####
255
+ if not load_models or self.surrogate_given:
256
+ self.train_surrogate(X, y)
257
+ else:
258
+ # Load the proxies
259
+ self.surrogate_model = []
260
+ for i in range(self.problem.n_obj):
261
+ classifier = SingleModel(
262
+ input_size=self.problem.n_var,
263
+ which_obj=i,
264
+ device=self.device,
265
+ hidden_size=[2048, 2048],
266
+ save_dir=self.proxies_store_path,
267
+ save_prefix=f"MultipleModels-Vallina-{self.problem.__class__.__name__}-{self.seed}",
268
+ )
269
+ classifier.load()
270
+ classifier = classifier.to(self.device)
271
+ classifier.eval()
272
+ self.surrogate_model.append(classifier)
273
+
274
+ # Create DataLoader
275
+ train_dataloader, val_dataloader = get_ddpm_dataloader(X,
276
+ y,
277
+ validation_split=self.validation_split,
278
+ batch_size=self.batch_size)
279
+
280
+ #### DIFFUSION MODEL TRAINING ####
281
+ if not load_models:
282
+ self.train(train_dataloader, val_dataloader=val_dataloader)
283
+ #### SPREAD SAMPLING ####
284
+ res_x, res_y = self.sampling(num_points_sample,
285
+ strict_guidance=strict_guidance,
286
+ rho_scale_gamma=rho_scale_gamma,
287
+ nu_t=nu_t, eta_init=eta_init,
288
+ num_inner_steps=num_inner_steps, lr_inner=lr_inner,
289
+ free_initial_h=free_initial_h,
290
+ use_sigma_rep=use_sigma_rep, kernel_sigma_rep=kernel_sigma_rep,
291
+ iterative_plot=iterative_plot, plot_period=plot_period,
292
+ max_backtracks=max_backtracks, label=label,
293
+ save_results=save_results,
294
+ samples_store_path=samples_store_path,
295
+ images_store_path=images_store_path)
296
+
297
+ return res_x, res_y
298
+
299
+ elif self.mode == "bayesian":
300
+ self.verbose = False
301
+ hv_all_value = []
302
+ # initialize n_init solutions
303
+ x_init = lhs_no_evaluation(self.problem.n_var,
304
+ n_init_mobo)
305
+ x_init = torch.from_numpy(x_init).float().to(self.device)
306
+ y_init = self.problem.evaluate(x_init).detach().cpu().numpy()
307
+
308
+ # initialize dominance-classifier for non-dominance relation
309
+ p_rel_map, s_rel_map = init_dom_rel_map(300)
310
+ p_model = init_dom_nn_classifier(
311
+ x_init, y_init, p_rel_map, pareto_dominance, self.problem.n_var,
312
+ )
313
+ self.dominance_classifier = p_model
314
+
315
+ evaluated = len(y_init)
316
+ X = x_init.detach().cpu().numpy()
317
+ Y = y_init
318
+ hv = HV(ref_point=np.array(self.problem.ref_point))
319
+ hv_value = hv(Y)
320
+ hv_all_value.append(hv_value)
321
+ z = torch.zeros(self.problem.n_obj).to(self.device)
322
+
323
+ escape_flag = False
324
+ if use_escape_local_mobo:
325
+ # Counter for tracking iterations since last switch
326
+ iteration_since_switch = 0
327
+ # Parameters for switching methods
328
+ hv_change_threshold = 0.05 # Threshold for HV value change
329
+ hv_history_length = 3 # Number of recent iterations to consider
330
+ hv_history = [] # Store recent HV values
331
+ # Initialize list to store historical data
332
+ history_Y = []
333
+
334
+ # Start the main loop for Bayesian-SPREAD
335
+ with tqdm(
336
+ total=n_steps_mobo,
337
+ desc=f"SPREAD (MOBO)",
338
+ unit="k",
339
+ ) as pbar:
340
+
341
+ for k_iter in range(n_steps_mobo):
342
+ # Solution normalization
343
+ transformation = StandardTransform([0, 1])
344
+ transformation.fit(X, Y)
345
+ X_norm, Y_norm = transformation.do(X, Y)
346
+
347
+ #### SURROGATE MODEL TRAINING ####
348
+ self.train_surrogate(X_norm, Y_norm)
349
+
350
+ if use_escape_local_mobo:
351
+ _, index = environment_selection(Y, len(X) // 3)
352
+ PopDec = X[index, :]
353
+ else:
354
+ PopDec = X
355
+
356
+ PopDec_dom_labels, PopDec_cfs = nn_predict_dom_intra(
357
+ PopDec, p_model, self.device
358
+ )
359
+ sorted_pop = sort_population(PopDec, PopDec_dom_labels, PopDec_cfs)
360
+
361
+ if not escape_flag:
362
+ # **** Generate new offspring using SPREAD**** #
363
+ self.model.to(self.device)
364
+ #### DIFFUSION MODEL TRAINING ####
365
+ train_dataloader, val_dataloader, dataset_size = mobo_get_ddpm_dataloader(sorted_pop,
366
+ self.objective_functions,
367
+ self.device,
368
+ self.batch_size,
369
+ self.validation_split)
370
+ self.train(train_dataloader,
371
+ val_dataloader=val_dataloader,
372
+ disable_progress_bar=True)
373
+ #### SPREAD SAMPLING ####
374
+ new_offsprings = []
375
+ for i in range(spread_num_samp_mobo):
376
+ num_points_sample = dataset_size
377
+ pf_points, _ = self.sampling(num_points_sample,
378
+ strict_guidance=strict_guidance,
379
+ rho_scale_gamma=rho_scale_gamma,
380
+ nu_t=nu_t, eta_init=eta_init,
381
+ num_inner_steps=num_inner_steps, lr_inner=lr_inner,
382
+ free_initial_h=free_initial_h,
383
+ use_sigma_rep=use_sigma_rep, kernel_sigma_rep=kernel_sigma_rep,
384
+ iterative_plot=iterative_plot, plot_period=plot_period,
385
+ max_backtracks=max_backtracks, label=label,
386
+ samples_store_path=samples_store_path,
387
+ images_store_path=images_store_path,
388
+ disable_progress_bar=True,
389
+ save_results=False, evaluate_final=False)
390
+ new_offsprings.append(pf_points)
391
+ X_psl = np.vstack(new_offsprings)
392
+ else:
393
+ #### SBX OFFSPRING GENERATION ####
394
+ rows_to_take = int(1 / 3 * sorted_pop.shape[0])
395
+ offspringA = sorted_pop[:rows_to_take, :]
396
+ if len(offspringA) % 2 == 1:
397
+ offspringA = offspringA[:-1]
398
+ new_pop = np.empty((0, self.problem.n_var))
399
+ for _ in range(1000):
400
+ result = sbx(offspringA, eta=15)
401
+ new_pop = np.vstack((new_pop, result))
402
+ X_psl = new_pop
403
+
404
+ pop_size_used = X_psl.shape[0]
405
+
406
+ # Mutate the new offspring
407
+ X_psl = pm_mutation(X_psl, [self.problem.xl.detach().cpu().numpy(),
408
+ self.problem.xu.detach().cpu().numpy()])
409
+
410
+ Y_candidate_mean = self.surrogate_model.evaluate(X_psl)["F"]
411
+ Y_candidate_std = self.surrogate_model.evaluate(X_psl, std=True)["S"]
412
+
413
+ rows_with_nan = np.any(np.isnan(Y_candidate_mean), axis=1)
414
+ Y_candidate_mean = Y_candidate_mean[~rows_with_nan]
415
+ Y_candidate_std = Y_candidate_std[~rows_with_nan]
416
+ X_psl = X_psl[~rows_with_nan]
417
+
418
+ Y_candidate = Y_candidate_mean - self.mobo_coef_lcb * Y_candidate_std
419
+ Y_candidate_mean = Y_candidate
420
+
421
+ #### BATCH SELECTION ####
422
+ nds = NonDominatedSorting()
423
+ idx_nds = nds.do(Y_norm)
424
+ Y_nds = Y_norm[idx_nds[0]]
425
+ best_subset_list = []
426
+ Y_p = Y_nds
427
+ for b in range(batch_select_mobo):
428
+ hv = HV(
429
+ ref_point=np.max(np.vstack([Y_p, Y_candidate_mean]), axis=0)
430
+ )
431
+ best_hv_value = 0
432
+ best_subset = None
433
+
434
+ for k in range(len(Y_candidate_mean)):
435
+ Y_subset = Y_candidate_mean[k]
436
+ Y_comb = np.vstack([Y_p, Y_subset])
437
+ hv_value_subset = hv(Y_comb)
438
+ if hv_value_subset > best_hv_value:
439
+ best_hv_value = hv_value_subset
440
+ best_subset = [k]
441
+
442
+ Y_p = np.vstack([Y_p, Y_candidate_mean[best_subset]])
443
+ best_subset_list.append(best_subset)
444
+
445
+ best_subset_list = np.array(best_subset_list).T[0]
446
+
447
+ X_candidate = X_psl
448
+ X_new = X_candidate[best_subset_list]
449
+ Y_new = self.problem.evaluate(torch.from_numpy(X_new).float().to(self.device)).detach().cpu().numpy()
450
+
451
+ Y_new = torch.tensor(Y_new).to(self.device)
452
+ X_new = torch.tensor(X_new).to(self.device)
453
+
454
+ X = np.vstack([X, X_new.detach().cpu().numpy()])
455
+ Y = np.vstack([Y, Y_new.detach().cpu().numpy()])
456
+ hv = HV(ref_point=np.array(self.problem.ref_point))
457
+ hv_value = hv(Y)
458
+ hv_all_value.append(hv_value)
459
+
460
+ rows_with_nan = np.any(np.isnan(Y), axis=1)
461
+ X = X[~rows_with_nan, :]
462
+ Y = Y[~rows_with_nan, :]
463
+
464
+ update_dom_nn_classifier(
465
+ p_model, X, Y, p_rel_map, pareto_dominance, self.problem
466
+ )
467
+
468
+ hv_text = f"{hv_value:.4e}"
469
+ evaluated = evaluated + batch_select_mobo
470
+
471
+ #### DECISION TO SWITCH OPERATOR ####
472
+ if use_escape_local_mobo:
473
+ # Current operator
474
+ if not escape_flag:
475
+ operator_text = "Diffusion"
476
+ else:
477
+ operator_text = "SBX"
478
+ # Update historical data and calculate reference point
479
+ history_Y.append(Y)
480
+ if len(history_Y) > k:
481
+ history_Y.pop(0)
482
+ all_Y = np.vstack(history_Y) # Combine historical data
483
+ nds_hist = NonDominatedSorting()
484
+ idx_nds_hist = nds_hist.do(all_Y)
485
+ Y_nds_hist = all_Y[idx_nds_hist[0]] # Get non-dominated individuals
486
+ quantile_values = np.quantile(Y_nds_hist, 0.95, axis=0)
487
+ ref_point_method2 = 1.1 * quantile_values
488
+ # Calculate approximate HV
489
+ hv_method2 = HV(ref_point=ref_point_method2)
490
+ hv_value_method2 = hv_method2(Y)
491
+ # Update HV value history
492
+ hv_history.append(hv_value_method2)
493
+ if len(hv_history) > hv_history_length:
494
+ hv_history.pop(0)
495
+
496
+ if len(hv_history) == hv_history_length:
497
+ avg_hv = sum(hv_history[:-1]) / (hv_history_length - 1)
498
+ if avg_hv == 0:
499
+ hv_change = 0
500
+ else:
501
+ hv_change = abs((hv_history[-1] - avg_hv) / avg_hv)
502
+ # Determine if method needs to be switched
503
+ if iteration_since_switch >= 2:
504
+ if hv_change < hv_change_threshold:
505
+ escape_flag = not escape_flag
506
+ iteration_since_switch = 0 # Reset counter
507
+ else:
508
+ iteration_since_switch += 1 # If already switched, increment counter
509
+
510
+ pbar.set_postfix({"HV": hv_text, "Operator": operator_text, "Population": pop_size_used, "Num Points": evaluated})
511
+ else:
512
+ pbar.set_postfix({"HV": hv_text, "Population": pop_size_used, "Num Points": evaluated})
513
+
514
+ pbar.update(1)
515
+
516
+ name_t = (
517
+ "spread"
518
+ + "_"
519
+ + self.problem.__class__.__name__
520
+ + "_T"
521
+ + str(self.timesteps)
522
+ + "_K"
523
+ + str(n_steps_mobo)
524
+ + "_FE"
525
+ + str(n_steps_mobo*batch_select_mobo)
526
+ + "_"
527
+ + f"seed={self.seed}"
528
+ + "_"
529
+ + self.mode
530
+ )
531
+
532
+ if not (os.path.exists(samples_store_path)):
533
+ os.makedirs(samples_store_path)
534
+
535
+ if save_results:
536
+ # Save the samples and HV values
537
+ np.save(samples_store_path + name_t + "_x.npy", X)
538
+ np.save(samples_store_path + name_t + "_y.npy", Y)
539
+ print("\n================ Final Results ================\n")
540
+ print(f"Total function evaluations: {evaluated}")
541
+ print(f"Final hypervolume: {hv_value:.4e}")
542
+ print(f"Samples and HV values are saved to {samples_store_path}\n")
543
+
544
+ outfile = samples_store_path + name_t + "_hv_results.pkl"
545
+ with open(outfile, "wb") as f:
546
+ pickle.dump(hv_all_value, f)
547
+
548
+ return X, Y, hv_all_value
549
+
550
+
551
+ def train(self,
552
+ train_dataloader,
553
+ val_dataloader=None,
554
+ disable_progress_bar=False):
555
+ set_seed(self.seed)
556
+ if self.verbose:
557
+ print(datetime.datetime.now())
558
+
559
+ self.model = self.model.to(self.device)
560
+ optimizer = optim.Adam(self.model.parameters(), lr=self.train_lr)
561
+
562
+ betas = self.cosine_beta_schedule()
563
+ alphas = 1 - betas
564
+ alpha_cumprod = torch.cumprod(alphas, dim=0).to(self.device)
565
+
566
+ DDPM_FILE_LAST = str("%s/checkpoint_ddpm_last.pth" % (self.model_dir))
567
+ if val_dataloader:
568
+ DDPM_FILE_BEST = str("%s/checkpoint_ddpm_best.pth" % (self.model_dir))
569
+ best_val_loss = np.inf
570
+ tol_violation_epoch = self.train_tol
571
+ cur_violation_epoch = 0
572
+
573
+ time_start = time()
574
+
575
+ with tqdm(
576
+ total=self.num_epochs,
577
+ desc=f"DDPM Training",
578
+ unit="epoch",
579
+ disable=disable_progress_bar,
580
+ ) as pbar:
581
+
582
+ for epoch in range(self.num_epochs):
583
+ self.model.train()
584
+ train_loss = 0.0
585
+ for indx_batch, (batch, obj_values) in enumerate(train_dataloader):
586
+ optimizer.zero_grad()
587
+
588
+ # Extract batch points
589
+ points = batch.to(self.device)
590
+ points.requires_grad = True # Enable gradients
591
+ obj_values = obj_values.to(self.device)
592
+
593
+ # Sample random timesteps for each data point
594
+ t = torch.randint(0, self.timesteps, (points.shape[0],)).to(self.device)
595
+ alpha_bar_t = alpha_cumprod[t].unsqueeze(1) # shape: [batch_size, 1]
596
+
597
+ # Forward process: Add noise to points
598
+ noise = torch.randn_like(points).to(
599
+ self.device
600
+ ) # shape: [batch_size, n_var]
601
+ x_t = (
602
+ torch.sqrt(alpha_bar_t) * points
603
+ + torch.sqrt(1 - alpha_bar_t) * noise
604
+ )
605
+
606
+ # Conditioning information
607
+ c = obj_values
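+ # Shift the conditioning objective values: use the user-supplied xi_shift if given, otherwise the smallest positive value in the batch (fallback 1e-5)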
608
+ if self.xi_shift is not None:
609
+ c = c + self.xi_shift
610
+ else:
611
+ xi_shift = c[c > 0].min() if (c > 0).any() else 1e-5
612
+ c = c + xi_shift
613
+ # Model predicts noise
614
+ predicted_noise = self.model(
615
+ x_t, t.float() / self.timesteps, c
616
+ )
617
+
618
+ # Compute loss
619
+ loss_simple = self.l_simple_loss(predicted_noise,
620
+ noise.detach())
621
+ loss_simple.backward()
622
+ optimizer.step()
623
+ train_loss += loss_simple.item()
624
+
625
+ train_loss = train_loss / len(train_dataloader)
626
+
627
+ if val_dataloader:
628
+ # Validation
629
+ self.model.eval()
630
+ val_loss = 0.0
631
+ for indx_batch, (val_batch, val_obj_values) in enumerate(val_dataloader):
632
+ # Extract batch points
633
+ val_points = val_batch.to(
634
+ self.device
635
+ )
636
+ val_points.requires_grad = True # Enable gradients
637
+ val_obj_values = val_obj_values.to(self.device)
638
+
639
+ # Sample random timesteps for each data point
640
+ t = torch.randint(0, self.timesteps, (val_points.shape[0],)).to(self.device)
641
+ alpha_bar_t = alpha_cumprod[t].unsqueeze(1) # shape: [batch_size, 1]
642
+
643
+ # Forward process: Add noise to points
644
+ val_noise = torch.randn_like(val_points).to(
645
+ self.device
646
+ ) # shape: [batch_size, n_var]
647
+ val_x_t = (
648
+ torch.sqrt(alpha_bar_t) * val_points
649
+ + torch.sqrt(1 - alpha_bar_t) * val_noise
650
+ )
651
+
652
+ # Conditioning information
653
+ c = val_obj_values
654
+ if self.xi_shift is not None:
655
+ c = c + self.xi_shift
656
+ else:
657
+ xi_shift = c[c > 0].min() if (c > 0).any() else 1e-6
658
+ c = c + xi_shift
659
+ # Model predicts noise
660
+ val_predicted_noise = self.model(
661
+ val_x_t, t.float() / self.timesteps, c
662
+ )
663
+ loss_simple = self.l_simple_loss(val_predicted_noise,
664
+ val_noise.detach())
665
+ val_loss += loss_simple.item()
666
+
667
+ val_loss = val_loss / len(val_dataloader)
668
+
669
+ if val_loss <= best_val_loss:
670
+ best_val_loss = val_loss
671
+ cur_violation_epoch = 0
672
+ # Save the model
673
+ checkpoint = {
674
+ "model_state_dict": self.model.state_dict(),
675
+ "epoch": epoch + 1,
676
+ }
677
+ torch.save(checkpoint, DDPM_FILE_BEST)
678
+ else:
679
+ cur_violation_epoch += 1
680
+ if cur_violation_epoch >= tol_violation_epoch:
681
+ if self.verbose:
682
+ print(f"Early Stopping at epoch {epoch + 1}.")
683
+ break
684
+
685
+ pbar.set_postfix({"val_loss": val_loss})
686
+ else:
687
+ pbar.set_postfix({"train_loss": train_loss})
688
+ pbar.update(1)
689
+
690
+ comp_time = time() - time_start
691
+ if self.verbose:
692
+ convert_seconds(comp_time)
693
+ if self.verbose:
694
+ print(datetime.datetime.now())
695
+
696
+ if val_dataloader and self.verbose:
697
+ print(f"Best model saved at: {DDPM_FILE_BEST}")
698
+ # Save the model
699
+ checkpoint = {
700
+ "model_state_dict": self.model.state_dict(),
701
+ "train_time": comp_time,
702
+ "num_epochs": self.num_epochs,
703
+ "train_tol": self.train_tol,
704
+ "train_lr": self.train_lr,
705
+ "batch_size": self.batch_size,
706
+ "timesteps": self.timesteps,
707
+ }
708
+ torch.save(checkpoint, DDPM_FILE_LAST)
709
+ if self.verbose:
710
+ print(f"Final model saved at: {DDPM_FILE_LAST}")
711
+
712
+ def sampling(self, num_points_sample,
713
+ strict_guidance=False,
714
+ rho_scale_gamma=0.9,
715
+ nu_t=10.0, eta_init=0.9,
716
+ num_inner_steps=10, lr_inner=1e-4,
717
+ free_initial_h=True,
718
+ use_sigma_rep=False, kernel_sigma_rep=0.01,
719
+ iterative_plot=True, plot_period=100,
720
+ max_backtracks=25, label=None,
721
+ samples_store_path="./samples_dir/",
722
+ images_store_path="./images_dir/",
723
+ disable_progress_bar=False,
724
+ save_results=True, evaluate_final=True):
725
+ # Set the seed
726
+ set_seed(self.seed)
727
+ if save_results:
728
+ # Store the results
729
+ if not (os.path.exists(samples_store_path)):
730
+ os.makedirs(samples_store_path)
731
+
732
+ name = (
733
+ "spread"
734
+ + "_"
735
+ + self.problem.__class__.__name__
736
+ + "_"
737
+ + f"T={self.timesteps}"
738
+ + "_"
739
+ + f"N={num_points_sample}"
740
+ + "_"
741
+ + f"seed={self.seed}"
742
+ + "_"
743
+ + self.mode
744
+ )
745
+
746
+ if label is not None:
747
+ name += f"_{label}"
748
+
749
+ self.model.to(self.device)
750
+ DDPM_FILE = str("%s/checkpoint_ddpm_best.pth" % (self.model_dir))
751
+ # Load the best model if exists, else load the last model
752
+ if not os.path.exists(DDPM_FILE):
753
+ DDPM_FILE = str("%s/checkpoint_ddpm_last.pth" % (self.model_dir))
754
+ if not os.path.exists(DDPM_FILE):
755
+ raise ValueError(f"No trained model found in {self.model_dir}")
756
+ checkpoint = torch.load(DDPM_FILE, map_location=self.device, weights_only=False)
757
+ self.model.load_state_dict(checkpoint["model_state_dict"])
758
+
759
+ betas = self.cosine_beta_schedule()
760
+ alphas = 1 - betas
761
+ alpha_cumprod = torch.cumprod(alphas, dim=0).to(self.device)
762
+
763
+ # Start from random points in the decision space
764
+ x_t = torch.rand((num_points_sample, self.problem.n_var)) # in [0, 1]
765
+ x_t = self.problem.bounds()[0] + (self.problem.bounds()[1] - self.problem.bounds()[0]) * x_t # scale to bounds
766
+ if self.mode == "offline":
767
+ x_t, _, _ = self.offline_normalization(x_t)
768
+ x_t = x_t.to(self.device)
769
+ x_t.requires_grad = True
770
+ if self.problem.need_repair:
771
+ x_t.data = self.repair_bounds(x_t.data.clone())
772
+
773
+ if self.mode == "online":
774
+ if iterative_plot and (not is_pass_function(self.problem._evaluate)):
775
+ if self.problem.n_obj <= 3:
776
+ pf_population = x_t.detach()
777
+ pf_points, _, _ = self.get_non_dominated_points(
778
+ pf_population,
779
+ keep_shape=False
780
+ )
781
+ list_fi = self.objective_functions(pf_points).split(1, dim=1)
782
+ list_fi = [fi.detach().cpu().numpy() for fi in list_fi]
783
+ pareto_front = None
784
+ if self.problem.pareto_front() is not None:
785
+ pareto_front = self.problem.pareto_front()
786
+ pareto_front = [pareto_front[:, i] for i in range(self.problem.n_obj)]
787
+ if self.plot_func is not None:
788
+ self.plot_func(list_fi, self.timesteps,
789
+ num_points_sample,
790
+ extra=pareto_front,
791
+ label=label, images_store_path=images_store_path)
792
+ else:
793
+ plot_dataset = True if self.mode == "offline" else False
794
+ list_fi_pop = self.objective_functions(pf_population.detach()).split(1, dim=1)
795
+ list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
796
+ self.plot_pareto_front(list_fi, self.timesteps,
797
+ num_points_sample,
798
+ extra=pareto_front,
799
+ plot_dataset=plot_dataset,
800
+ pop=list_fi_pop,
801
+ label=label, images_store_path=images_store_path)
802
+
803
+ prev_pf_points = None
804
+ num_optimal_points = 0
805
+
806
+ point_n0 = None
807
+ optimizer_n0 = None
808
+ if strict_guidance:
809
+ #### Initialize 1 target point for direction perturbation.
810
+ # (If strict_guidance, get new perturbation based on the MGD direction of
811
+ # a single initialized point. Otherwise, a random perturbation is used (as in the paper).)
812
+ point_n0 = torch.rand((1, self.problem.n_var)) # in [0, 1]
813
+ point_n0 = self.problem.bounds()[0] + (self.problem.bounds()[1] - self.problem.bounds()[0]) * point_n0 # scale to bounds
814
+ if self.mode == "offline":
815
+ point_n0, _, _ = self.offline_normalization(point_n0)
816
+ point_n0 = point_n0.to(self.device)
817
+ point_n0.requires_grad = True
818
+ if self.problem.need_repair:
819
+ point_n0.data = self.repair_bounds(
820
+ point_n0.data.clone()
821
+ )
822
+ optimizer_n0 = optim.Adam([point_n0], lr=1e-2)
823
+
824
+ if self.verbose:
825
+ print(datetime.datetime.now())
826
+ print(f"START sampling {num_points_sample} points ...")
827
+
828
+ time_start = time()
829
+ self.model.eval()
830
+
831
+ with tqdm(
832
+ total=self.timesteps,
833
+ desc=f"SPREAD Sampling",
834
+ unit="t",
835
+ disable=disable_progress_bar,
836
+ ) as pbar:
837
+
838
+ for t in reversed(range(self.timesteps)):
839
+ x_t.requires_grad_(True)
840
+ if self.problem.need_repair:
841
+ x_t.data = self.repair_bounds(
842
+ x_t.data.clone()
843
+ )
844
+ # Compute beta_t and alpha_t
845
+ beta_t = 1 - alphas[t]
846
+ alpha_bar_t = alpha_cumprod[t]
847
+
848
+ x_t = self.one_spread_sampling_step(
849
+ x_t,
850
+ num_points_sample,
851
+ t, beta_t, alpha_bar_t, rho_scale_gamma,
852
+ nu_t, eta_init, num_inner_steps, lr_inner,
853
+ free_initial_h=free_initial_h,
854
+ use_sigma=use_sigma_rep, kernel_sigma=kernel_sigma_rep,
855
+ strict_guidance=strict_guidance, max_backtracks=max_backtracks,
856
+ point_n0=point_n0, optimizer_n0=optimizer_n0,
857
+ )
858
+
859
+ if strict_guidance and self.problem.need_repair:
860
+ point_n0.data = self.repair_bounds(
861
+ point_n0.data.clone()
862
+ )
863
+
864
+ if self.problem.need_repair:
865
+ pf_population = self.repair_bounds(
866
+ copy.deepcopy(x_t.detach())
867
+ )
868
+ else:
869
+ pf_population = copy.deepcopy(x_t.detach())
870
+
871
+ # print("Number of points before selection:", len(pf_population))
872
+ pf_points, _, _ = self.get_non_dominated_points(
873
+ pf_population,
874
+ keep_shape=False
875
+ )
876
+
877
+ if prev_pf_points is not None:
878
+ pf_points = torch.cat((prev_pf_points, pf_points), dim=0)
879
+ if self.mode != "bayesian":
880
+ pf_points, _, _ = self.get_non_dominated_points(
881
+ pf_points,
882
+ keep_shape=False,
883
+ )
884
+ # print("Number of non-dominated points before selection:", len(non_dom_points))
885
+ if len(pf_points) > num_points_sample:
886
+ pf_points = self.select_top_n_candidates(
887
+ pf_points,
888
+ num_points_sample,
889
+ )
890
+
891
+ if self.mode == "bayesian" and (pf_points is None or len(pf_points) == 0):
892
+ pf_points = x_t.detach()
893
+ prev_pf_points = pf_points
894
+ num_optimal_points = len(pf_points)
895
+
896
+ if iterative_plot and (not is_pass_function(self.problem._evaluate)):
897
+ if self.problem.n_obj <= 3:
898
+ if (t % plot_period == 0) or (t == self.timesteps - 1):
899
+ if self.mode == "offline":
900
+ # Denormalize the points before plotting
901
+ res_x_t = pf_points.clone().detach()
902
+ res_x_t = self.offline_denormalization(res_x_t,
903
+ self.X_meanormin,
904
+ self.X_stdormax)
905
+ norm_xl, norm_xu = self.problem.bounds()
906
+ xl, xu = self.problem.original_bounds
907
+ self.problem.xl = xl
908
+ self.problem.xu = xu
909
+ if self.problem.is_discrete:
910
+ _, dim, n_classes = tuple(res_x_t.shape)
911
+ res_x_t = res_x_t.reshape(-1, dim, n_classes)
912
+ res_x_t = offdata_to_integers(res_x_t)
913
+ if self.problem.is_sequence:
914
+ res_x_t = offdata_to_integers(res_x_t)
915
+ # we need to evaluate the true objective functions for plotting
916
+ list_fi = self.objective_functions(pf_points, evaluate_true=True).split(1, dim=1)
917
+ # restore the normalized bounds
918
+ self.problem.xl = norm_xl
919
+ self.problem.xu = norm_xu
920
+ elif self.mode == "bayesian":
921
+ # we need to evaluate the true objective functions for plotting
922
+ list_fi = self.objective_functions(pf_points, evaluate_true=True).split(1, dim=1)
923
+ else:
924
+ list_fi = self.objective_functions(pf_points).split(1, dim=1)
925
+
926
+ list_fi = [fi.detach().cpu().numpy() for fi in list_fi]
927
+ pareto_front = None
928
+ if self.problem.pareto_front() is not None:
929
+ pareto_front = self.problem.pareto_front()
930
+ pareto_front = [pareto_front[:, i] for i in range(self.problem.n_obj)]
931
+ if self.plot_func is not None:
932
+ self.plot_func(list_fi, t,
933
+ num_points_sample,
934
+ extra= pareto_front,
935
+ label=label, images_store_path=images_store_path)
936
+ else:
937
+ plot_dataset = True if self.mode == "offline" else False
938
+ list_fi_pop = self.objective_functions(pf_population.detach()).split(1, dim=1)
939
+ list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
940
+ self.plot_pareto_front(list_fi, t,
941
+ num_points_sample,
942
+ extra= pareto_front,
943
+ pop=list_fi_pop,
944
+ plot_dataset=plot_dataset,
945
+ label=label, images_store_path=images_store_path)
946
+
947
+
948
+ x_t = x_t.detach()
949
+ pbar.set_postfix({
950
+ "Points": num_optimal_points,
951
+ })
952
+ pbar.update(1)
953
+ if self.verbose:
954
+ print(f"END sampling !")
955
+
956
+ comp_time = time() - time_start
957
+
958
+ if self.mode == "offline":
959
+ pf_points = pf_points.detach()
960
+ pf_points = self.offline_denormalization(pf_points,
961
+ self.X_meanormin,
962
+ self.X_stdormax)
963
+ if self.problem.is_discrete:
964
+ _, dim, n_classes = tuple(pf_points.shape)
965
+ pf_points = pf_points.reshape(-1, dim, n_classes)
966
+ pf_points = offdata_to_integers(pf_points)
967
+ if self.problem.is_sequence:
968
+ pf_points = offdata_to_integers(pf_points)
969
+ if self.problem.has_bounds():
970
+ self.problem.xl, self.problem.xu = self.problem.original_bounds
971
+
972
+ res_x = pf_points.detach().cpu().numpy()
973
+ res_y = None
974
+ if evaluate_final and (not is_pass_function(self.problem._evaluate)):
975
+ res_y = self.problem.evaluate(pf_points).detach().cpu().numpy()
976
+ visible_masks = np.ones(len(res_y))
977
+ visible_masks[np.where(np.logical_or(np.isinf(res_y), np.isnan(res_y)))[0]] = 0
978
+ visible_masks[np.where(np.logical_or(np.isinf(res_x), np.isnan(res_x)))[0]] = 0
979
+ res_x = res_x[np.where(visible_masks == 1)[0]]
980
+ res_y = res_y[np.where(visible_masks == 1)[0]]
981
+ if save_results:
982
+ np.save(samples_store_path + name + "_y.npy", res_y)
983
+ hv = HV(ref_point=self.problem.ref_point)
984
+ hv_value = hv(res_y)
985
+ hv_results = {
986
+ "ref_point": self.problem.ref_point,
987
+ "hypervolume": hv_value,
988
+ "computation_time": comp_time
989
+ }
990
+ with open(samples_store_path + name + "_hv_results.pkl", "wb") as f:
991
+ pickle.dump(hv_results, f)
992
+ if self.verbose:
993
+ print(f"Hypervolume: {hv_value} for seed {self.seed}")
994
+ print("---------------------------------------")
995
+ # Print computation time
996
+ convert_seconds(comp_time)
997
+ print(datetime.datetime.now())
998
+
999
+ if save_results:
1000
+ np.save(samples_store_path + name + "_x.npy", res_x)
1001
+
1002
+ return res_x, res_y
1003
+
1004
+
1005
+ def train_surrogate(self,
1006
+ X, y,
1007
+ val_ratio=0.1,
1008
+ batch_size=32,
1009
+ lr=1e-3,
1010
+ lr_decay=0.95,
1011
+ n_epochs=200):
1012
+
1013
+ set_seed(self.seed)
1014
+ self.surrogate_model = self.get_surrogate()
1015
+ if self.surrogate_given:
1016
+ return self.train_surrogate_user_defined(X, y)
1017
+
1018
+ # Train the surrogate model
1019
+ if self.mode == "bayesian":
1020
+ self.surrogate_model.fit(X, y)
1021
+ elif self.mode == "offline":
1022
+ n_obj = y.shape[1]
1023
+ tkwargs = {"device": self.device, "dtype": torch.float32}
1024
+ self.surrogate_model.set_kwargs(**tkwargs)
1025
+
1026
+ trainer_func = SingleModelBaseTrainer
1027
+
1028
+ for which_obj in range(n_obj):
1029
+
1030
+ y0 = y[:, which_obj].clone().reshape(-1, 1)
1031
+
1032
+ trainer = trainer_func(
1033
+ model=list(self.surrogate_model.obj2model.values())[which_obj],
1034
+ which_obj=which_obj,
1035
+ args={
1036
+ "proxies_lr": lr,
1037
+ "proxies_lr_decay": lr_decay,
1038
+ "proxies_epochs": n_epochs,
1039
+ "device": self.device,
1040
+ "verbose": self.verbose,
1041
+ },
1042
+ )
1043
+
1044
+ (train_loader, val_loader) = offdata_get_dataloader(
1045
+ X,
1046
+ y0,
1047
+ train_ratio=(
1048
+ 1 - val_ratio
1049
+ ),
1050
+ batch_size=batch_size,
1051
+ )
1052
+
1053
+ trainer.launch(train_loader, val_loader)
1054
+ # Load the proxies
1055
+ self.surrogate_model = []
1056
+ for i in range(n_obj):
1057
+ classifier = SingleModel(
1058
+ input_size=self.problem.n_var,
1059
+ which_obj=i,
1060
+ device=self.device,
1061
+ hidden_size=[2048, 2048],
1062
+ save_dir=self.proxies_store_path,
1063
+ save_prefix=f"MultipleModels-Vallina-{self.problem.__class__.__name__}-{self.seed}",
1064
+ )
1065
+ classifier.load()
1066
+ classifier = classifier.to(self.device)
1067
+ classifier.eval()
1068
+ self.surrogate_model.append(classifier)
1069
+ else:
1070
+ raise ValueError(f"Mode {self.mode} does not support surrogate model!")
1071
+
1072
+ def train_surrogate_user_defined(self, X, y):
1073
+ """
1074
+ Train the user-defined surrogate model.
1075
+ If self.mode == "offline", the train_func should return a list of trained surrogate models,
1076
+ one for each objective.
1077
+ If self.mode == "bayesian", the train_func should return a single trained surrogate model for all objectives.
1078
+
1079
+ Parameters
1080
+ ----------
1081
+ train_func : function
1082
+ A function that takes X, y as input and returns a trained surrogate model.
1083
+ **kwargs : dict
1084
+ Additional keyword arguments to pass to the train_func.
1085
+ -----------
1086
+ """
1087
+ self.surrogate_model = self.train_func_surrogate(X, y)
1088
+
1089
+ def get_surrogate(self):
1090
+ if self.surrogate_given:
1091
+ return self.surrogate_model
1092
+ else:
1093
+ if self.mode == "bayesian":
1094
+ return GaussianProcess(self.problem.n_var,
1095
+ self.problem.n_obj,
1096
+ nu=5)
1097
+ elif self.mode == "offline":
1098
+ os.makedirs(self.proxies_store_path, exist_ok=True)
1099
+ return MultipleModels(
1100
+ n_dim=self.problem.n_var,
1101
+ n_obj=self.problem.n_obj,
1102
+ train_mode="Vallina",
1103
+ device=self.device,
1104
+ hidden_size=[2048, 2048],
1105
+ save_dir=self.proxies_store_path,
1106
+ save_prefix=f"MultipleModels-Vallina-{self.problem.__class__.__name__}-{self.seed}",
1107
+ )
1108
+ else:
1109
+ raise ValueError(f"Mode {self.mode} does not support surrogate model!")
1110
+
1111
+ def one_spread_sampling_step(
1112
+ self,
1113
+ x_t,
1114
+ num_points_sample,
1115
+ t, beta_t, alpha_bar_t, rho_scale_gamma,
1116
+ nu_t, eta_init, num_inner_steps, lr_inner, free_initial_h,
1117
+ use_sigma=False, kernel_sigma=1.0, strict_guidance = False,
1118
+ max_backtracks=100, point_n0=None, optimizer_n0=None,
1119
+ ):
1120
+
1121
+ # Create a tensor of timesteps with shape (num_points_sample, 1)
1122
+ t_tensor = torch.full(
1123
+ (num_points_sample,),
1124
+ t,
1125
+ device=self.device,
1126
+ dtype=torch.float32,
1127
+ )
1128
+ # Compute objective values
1129
+ obj_values = self.objective_functions(x_t)
1130
+
1131
+ # Conditioning information
1132
+ c = obj_values
1133
+ with torch.no_grad():
1134
+ predicted_noise = self.model(x_t,
1135
+ t_tensor / self.timesteps,
1136
+ c)
1137
+
1138
+ g_w = None
1139
+ if strict_guidance:
1140
+ #### If strict_guidance, get new perturbation based on the MGD direction of
1141
+ # a single initialized point. Otherwise, a random perturbation is used.
1142
+ if self.mode in ["online", "offline"]:
1143
+ list_fi_n0 = self.objective_functions(point_n0).split(1, dim=1)
1144
+ list_grad_i_n0 = []
1145
+ for fi_n0 in list_fi_n0:
1146
+ fi_n0.sum().backward(retain_graph=True)
1147
+ grad_i_n0 = point_n0.grad.detach().clone()
1148
+ point_n0.grad.zero_()
1149
+ grad_i_n0 = torch.nn.functional.normalize(grad_i_n0, dim=0)
1150
+ list_grad_i_n0.append(grad_i_n0)
1151
+ else:
1152
+ list_grad_i_n0 = self.objective_functions(point_n0, get_grad_mobo=True)["dF"]
1153
+ for i in range(len(list_grad_i_n0)):
1154
+ # X.grad.zero_()
1155
+ grad_i_n0 = list_grad_i_n0[i]
1156
+ # Normalize gradients
1157
+ grad_i_n0 = torch.nn.functional.normalize(grad_i_n0, dim=0)
1158
+ list_grad_i_n0[i] = grad_i_n0
1159
+
1160
+ optimizer_n0.zero_grad()
1161
+ mth = "mgda"
1162
+ if self.problem.n_ieq_constr + self.problem.n_eq_constr > 0:
1163
+ mth = "pmgda"
1164
+ g_w = self.get_target_dir(list_grad_i_n0, mth=mth, x=point_n0)
1165
+ point_n0.grad = g_w
1166
+ optimizer_n0.step()
1167
+
1168
+ #### Reverse diffusion step
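+ # Standard DDPM reverse-step mean: mu_t = (1 / sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta), with alpha_t = 1 - beta_t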
1169
+ sqrt_1_minus_alpha_t = torch.sqrt(torch.clamp(1 - alpha_bar_t, min=1e-6))
1170
+ sqrt_1_minus_beta_t = torch.sqrt(torch.clamp(1 - beta_t, min=1e-6))
1171
+ mean = (1 / sqrt_1_minus_beta_t) * (
1172
+ x_t - (beta_t / sqrt_1_minus_alpha_t) * (predicted_noise)
1173
+ )
1174
+ std_dev = torch.sqrt(beta_t)
1175
+ z = torch.randn_like(x_t) if t > 0 else 0.0 # No noise for the final step
1176
+ x_t = mean + std_dev * z
1177
+
1178
+ #### Pareto Guidance step
1179
+ if self.problem.need_repair:
1180
+ x_t.data = self.repair_bounds(x_t.data.clone())
1181
+ X = x_t.clone().detach().requires_grad_()
1182
+ if self.mode in ["online", "offline"]:
1183
+ list_fi = self.objective_functions(X).split(1, dim=1)
1184
+ list_grad_i = []
1185
+ for fi in list_fi:
1186
+ fi.sum().backward(retain_graph=True)
1187
+ grad_i = X.grad.detach().clone()
1188
+ grad_i = torch.nn.functional.normalize(grad_i, dim=0)
1189
+ list_grad_i.append(grad_i)
1190
+ X.grad.zero_()
1191
+ elif self.mode == "bayesian":
1192
+ list_grad_i = self.objective_functions(X, get_grad_mobo=True)["dF"]
1193
+ for i in range(len(list_grad_i)):
1194
+ # X.grad.zero_()
1195
+ grad_i = list_grad_i[i]
1196
+ # Normalize gradients
1197
+ grad_i = torch.nn.functional.normalize(grad_i, dim=0)
1198
+ list_grad_i[i] = grad_i
1199
+ else:
1200
+ raise ValueError(f"Mode {self.mode} not recognized!")
1201
+
1202
+ grads = torch.stack(list_grad_i, dim=0) # (m, N, d)
1203
+ grads_copy = torch.stack(list_grad_i, dim=1).detach() # (N, m, d)
1204
+ mth = "mgda"
1205
+ if self.problem.n_ieq_constr + self.problem.n_eq_constr > 0:
1206
+ mth = "pmgda"
1207
+ g_x_t_prime = self.get_target_dir(list_grad_i, mth=mth, x=X)
1208
+
1209
+ eta = self.mgd_armijo_step(
1210
+ x_t, g_x_t_prime,
1211
+ obj_values, grads_copy, # (N, m, d)
1212
+ eta_init=eta_init,
1213
+ max_backtracks=max_backtracks
1214
+ )
1215
+
1216
+ h_tilde = self.solve_for_h(
1217
+ x_t,
1218
+ g_x_t_prime,
1219
+ grads,
1220
+ g_w,
1221
+ eta=eta,
1222
+ nu_t=nu_t,
1223
+ sigma=kernel_sigma,
1224
+ free_initial_h=free_initial_h,
1225
+ use_sigma=use_sigma,
+ strict_guidance=strict_guidance,
1226
+ num_inner_steps=num_inner_steps,
1227
+ lr_inner=lr_inner,
1228
+ rho_scale_gamma=rho_scale_gamma
1229
+ )
1230
+
1231
+ h_tilde = torch.nan_to_num(h_tilde,
1232
+ nan=torch.nanmean(h_tilde),
1233
+ posinf=0.0,
1234
+ neginf=0.0)
1235
+
1236
+ x_t = x_t - eta * h_tilde
1237
+
1238
+ return x_t
1239
+
1240
+ def solve_for_h(
1241
+ self,
1242
+ x_t_prime,
1243
+ g_x_t_prime,
1244
+ grads,
1245
+ g_w,
1246
+ eta,
1247
+ nu_t,
1248
+ sigma=1.0,
1249
+ use_sigma=False,
1250
+ num_inner_steps=10,
1251
+ lr_inner=1e-2,
1252
+ strict_guidance=False,
1253
+ free_initial_h=False,
1254
+ rho_scale_gamma=0.9
1255
+ ):
1256
+ """
1257
+ Returns:
1258
+ h_tilde: Optimized h (Tensor of shape (batch_size, n_var)).
1259
+ """
1260
+
1261
+ x_t_h = x_t_prime.clone().detach()
1262
+ g = g_x_t_prime.clone().detach()
1263
+
1264
+ if strict_guidance:
1265
+ g_targ = g_w.clone().detach()
1266
+ else:
1267
+ g_targ = torch.randn((1, g.shape[1]), device=g.device)
1268
+
1269
+ # Initialize h
1270
+ if not free_initial_h:
1271
+ h = g_x_t_prime.clone().detach().requires_grad_() # initialize at g
1272
+ else:
1273
+ h = torch.zeros_like(g, requires_grad=False) + 1e-6 # or as a free parameter
1274
+ h = h.requires_grad_()
1275
+
1276
+ optimizer_inner = optim.Adam([h], lr=lr_inner)
1277
+
1278
+ for step in range(num_inner_steps):
1279
+
1280
+ gtarg_scaled = self.adaptive_scale_delta_vect(
1281
+ h, g_targ, grads, gamma=rho_scale_gamma
1282
+ )
1283
+
1284
+ # Alignment term: maximize <g, h>
1285
+ # To maximize L, we minimize -L:
1286
+ alignment = -torch.mean(torch.sum(g * h, dim=-1))
1287
+
1288
+ # Update points:
1289
+ x_t_h = x_t_h - eta * (h + gtarg_scaled)
1290
+ if self.problem.need_repair:
1291
+ x_t_h.data = self.repair_bounds(
1292
+ x_t_h.data
1293
+ )
1294
+ # Map the updated points to the objective space
1295
+ F_ = self.objective_functions(x_t_h)
1296
+ # Compute repulsion loss to encourage diversity
1297
+ if use_sigma:
1298
+ rep_loss = self.repulsion_loss(F_, sigma)
1299
+ else:
1300
+ rep_loss = self.repulsion_loss(F_, use_sigma=False)
1301
+
1302
+ # Our composite objective L is:
1303
+ loss = alignment + nu_t * rep_loss
1304
+
1305
+ optimizer_inner.zero_grad()
1306
+ loss.backward(retain_graph=True)
1307
+ optimizer_inner.step()
1308
+
1309
+ h_tilde = h + gtarg_scaled # This is h_tilde in the paper
1310
+
1311
+ return h_tilde.detach()
1312
+
1313
+ def get_training_data(self, problem, num_samples=10000):
1314
+ """
1315
+ Sample points, using LHS, based on lowest constraint violation
1316
+ """
1317
+ sampler = LHS()
1318
+ # Problem bounds
1319
+ xl, xu = problem.xl, problem.xu
1320
+ # Draw n_sample candidates in [0,1]^n_var
1321
+ Xcand = sampler.do(problem, num_samples).get("X")
1322
+ # Scale to actual bounds
1323
+ Xcand = xl + (xu - xl) * Xcand
1324
+ F = problem.evaluate(Xcand)
1325
+ return Xcand, F
1326
+
1327
+ def betas_for_alpha_bar(self, T, alpha_bar, max_beta=0.999):
1328
+ """
1329
+ Create a beta schedule that discretizes the given alpha_t_bar function,
1330
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
1331
+
1332
+ :param T: the number of betas to produce.
1333
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
1334
+ produces the cumulative product of (1-beta) up to that
1335
+ part of the diffusion process.
1336
+ :param max_beta: the maximum beta to use; use values lower than 1 to
1337
+ prevent singularities.
1338
+ """
1339
+ betas = []
1340
+ for i in range(T):
1341
+ t1 = i / T
1342
+ t2 = (i + 1) / T
1343
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
1344
+ return torch.from_numpy(np.array(betas)).float()
1345
+
1346
+ def cosine_beta_schedule(self, s=0.008):
1347
+ """
1348
+ Cosine schedule for beta values over timesteps.
1349
+ """
1350
+ return self.betas_for_alpha_bar(
1351
+ self.timesteps,
1352
+ lambda t: math.cos((t + s) / (1 + s) * math.pi / 2) ** 2,
1353
+ )
1354
+
1355
+ def l_simple_loss(self, predicted_noise, actual_noise):
1356
+ return nn.MSELoss()(predicted_noise, actual_noise)
1357
+
1358
+ def get_target_dir(self, grads, mth="mgda", x=None):
1359
+ m = len(grads)
1360
+ if self.problem.n_ieq_constr + self.problem.n_eq_constr > 0:
1361
+ assert mth != "mgda", "MGDA not supported with constraints. Use mth ='pmgda'."
1362
+
1363
+ if mth == "mgda":
1364
+ """
1365
+ Compute the MGDA combined descent direction given a list of gradient tensors.
1366
+ All grads are assumed to have the same shape (parameters' shape).
1367
+ Returns a tensor of the same shape as each gradient, representing the direction g.
1368
+ """
1369
+ # Flatten gradients and stack into matrix of shape (m, p), where p is number of params
1370
+ flat_grads = [g.reshape(-1) for g in grads]
1371
+ G = torch.stack(flat_grads, dim=0) # shape: (m, p)
1372
+ # Compute Gram matrix of size (m, m): entry (i,j) = g_i \cdot g_j
1373
+ gram_matrix = G @ G.t() # shape: (m, m)
1374
+
1375
+ # Solve quadratic problem: minimize 0.5 * alpha^T Gram * alpha s.t. sum(alpha)=1, alpha>=0
1376
+ # We use the closed-form solution via KKT for equality constraint, then adjust for alpha>=0.
1377
+ ones = torch.ones(m, device=gram_matrix.device, dtype=gram_matrix.dtype)
1378
+ # Solve Gram * alpha = mu * 1 (plus sum(alpha)=1). This is a linear system with Lagrange multiplier mu.
1379
+ # Use pseudo-inverse in case Gram is singular.
1380
+ inv_gram = torch.linalg.pinv(gram_matrix)
1381
+ alpha = inv_gram @ ones # solution of Gram * alpha = 1 (unnormalized)
1382
+ alpha = alpha / alpha.sum() # enforce sum(alpha) = 1
1383
+
1384
+ # Clamp negative weights to 0 and renormalize if needed (active-set correction for constraints)
1385
+ if (alpha < 0).any():
1386
+ alpha = torch.clamp(alpha, min=0.0)
1387
+ if alpha.sum() == 0:
1388
+ # If all alpha became 0 (numerical issues), fall back to equal weights
1389
+ alpha = torch.ones(m, device=alpha.device) / m
1390
+ else:
1391
+ alpha = alpha / alpha.sum()
1392
+
1393
+ # Compute the combined gradient direction g
1394
+ # Reshape each gradient to original shape and sum with weights
1395
+ g = torch.zeros_like(grads[0])
1396
+ for weight, grad in zip(alpha, grads):
1397
+ g += weight * grad
1398
+ elif mth == "pmgda":
1399
+ pre_h_vals=None
1400
+ constraint_mtd='pbi'
1401
+ SOLVER = PMGDASolver(self.problem, prefs=None,
1402
+ n_prob=grads[0].shape[0], n_obj=self.problem.n_obj,
1403
+ verbose=False)
1404
+ y = self.objective_functions(x, get_constraint=True)
1405
+ # print("y.keys():", y.keys())
1406
+ if "H" in y:
1407
+ pre_h_vals = y["H"].sum(dim=1)
1408
+ constraint_mtd='eq'
1409
+ # print("pre_h_vals.shape:", pre_h_vals.shape)
1410
+ elif "G" in y:
1411
+ # print("pre_h_vals.shape before:", y["G"].shape)
1412
+ pre_h_vals = y["G"].sum(dim=1)
1413
+ print("pre_h_vals.shape:", pre_h_vals.shape)
1414
+ constraint_mtd='ineq'
1415
+ # print("pre_h_vals.shape:", pre_h_vals.shape)
1416
+ y = y["F"]
1417
+ # print("pre_h_vals.shape:", pre_h_vals.shape)
1418
+ alphas = SOLVER.compute_weights(x, y, pre_h_vals=pre_h_vals,
1419
+ constraint_mtd=constraint_mtd)
1420
+ alphas = torch.nan_to_num(alphas, nan=torch.nanmean(alphas),
1421
+ posinf=0.0, neginf=0.0).split(1, dim=1)
1422
+ g = torch.zeros_like(grads[0])
1423
+ for weight, grad in zip(alphas, grads):
1424
+ g += weight * grad
1425
+ else:
1426
+ raise ValueError(f"Method {mth} not recognized.")
1427
+ return g
1428
+
1429
+ def mgd_armijo_step(
1430
+ self,
1431
+ x_t: torch.Tensor,
1432
+ d: torch.Tensor,
1433
+ f_old,
1434
+ grads, # (N, m, d)
1435
+ eta_init=0.9,
1436
+ rho=0.9,
1437
+ c1=1e-4,
1438
+ max_backtracks=100,
1439
+ ):
1440
+ """
1441
+ Batched Armijo back-tracking line search for Multiple-Gradient-Descent (MGD).
1442
+
1443
+ Returns
1444
+ -------
1445
+ eta : torch.Tensor, shape (N,)
1446
+ Final step sizes.
1447
+ """
1448
+
1449
+ x = x_t.clone().detach()
1450
+
1451
+ if not torch.is_tensor(eta_init):
1452
+ eta = torch.full((x.shape[0],), float(eta_init),
1453
+ dtype=x.dtype, device=x.device)
1454
+ else:
1455
+ eta = eta_init.clone().to(x)
1456
+
1457
+ grad_dot = torch.einsum('nkd,nd->nk', grads, d)
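+ # grad_dot[n, k] = <grad f_k(x_n), d[n]>: per-point directional derivative of each objective along d, used in the Armijo test below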
1458
+
1459
+ improve = torch.ones_like(eta, dtype=torch.bool)
1460
+
1461
+ for _ in range(max_backtracks):
1462
+ if not improve.any():
1463
+ break
1464
+
1465
+ # Evaluate objectives at trial points
1466
+ trial_x = x[improve] + eta[improve, None] * d[improve]
1467
+ f_new = self.objective_functions(trial_x)
1468
+
1469
+ # Armijo test (vectorised over objectives)
1470
+ # f_new <= f_old + c1 * eta * grad_dot (element-wise)
1471
+ ok = (f_new <= f_old[improve] + c1 * eta[improve, None] * grad_dot[improve]).all(dim=1)
1472
+
1473
+ # Update masks and step sizes
1474
+ eta[improve] = torch.where(ok, eta[improve], rho * eta[improve])
1475
+ improve_mask = improve.clone()
1476
+ improve[improve_mask] = ~ok
1477
+
1478
+ return eta[:, None]
1479
+
1480
+     def adaptive_scale_delta_vect(
+         self,
+         g: torch.Tensor,
+         delta_raw: torch.Tensor,
+         grads: torch.Tensor,
+         gamma: float = 0.9
+     ) -> torch.Tensor:
+         """
+         Adaptive scaling to preserve *positivity*:
+
+             ∇f_j(x_i)^T [ g_i + rho_i * delta_raw_i ] > 0   for all j.
+
+         Args:
+             g (torch.Tensor): [n_points, d], the multi-objective "gradient"
+                 (which we *subtract* in the update).
+             delta_raw (torch.Tensor): [n_points, d] or [1, d], the unscaled diversity/repulsion direction.
+             grads (torch.Tensor): [m, n_points, d], storing ∇f_j(x_i).
+             gamma (float): Safety factor in (0, 1).
+
+         Returns:
+             delta_scaled (torch.Tensor): [n_points, d], scaled directions s.t.
+                 for all j: grads[j, i]ᵀ [g[i] + delta_scaled[i]] > 0.
+         """
+
+         # Compute alpha_{i,j} = ∇f_j(x_i)^T g_i
+         # shape of alpha: [n_points, m]
+         alpha = torch.einsum("j i d, i d -> i j", grads, g)
+
+         # Compute beta_{i,j} = ∇f_j(x_i)^T delta_raw_i
+         # shape of beta: [n_points, m]
+         beta = torch.einsum("j i d, i d -> i j", grads, delta_raw)
+
+         # We only need to restrict rho_i if alpha_{i,j} > 0 and beta_{i,j} < 0,
+         # because for alpha + rho*beta to stay > 0 we need
+         #     rho < alpha / (-beta)
+         # whenever beta < 0 and alpha > 0.
+         mask = (alpha > 0.0) & (beta < 0.0)
+
+         # Prepare an array of ratios = alpha / (-beta), default +∞
+         ratio = torch.full_like(alpha, float("inf"))
+
+         # Where mask is True, compute ratio_{i,j}
+         ratio[mask] = alpha[mask] / (-beta[mask])  # rho_i must remain below this
+
+         # For each point i, pick rho_i = gamma * min_j ratio[i, j].
+         ratio_min, _ = ratio.min(dim=1)  # [n_points]
+         rho = gamma * ratio_min
+         # If ratio_min == +∞ => no constraint => set rho_i = 1.
+         inf_mask = torch.isinf(ratio_min)
+         rho[inf_mask] = 1.0
+
+         # Scale delta_raw by rho_i
+         delta_scaled = delta_raw * rho.unsqueeze(1)  # broadcast along the feature dim
+
+         return delta_scaled
+
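A small numerical check of the scaling rule above: with alpha_ij = ∇f_j(x_i)ᵀ g_i and beta_ij = ∇f_j(x_i)ᵀ delta_i, any pair with alpha > 0 and beta < 0 caps rho_i at gamma · min_j alpha / (-beta), which keeps ∇f_jᵀ (g_i + rho_i · delta_i) positive. The tensors are made-up toy values:

import torch

grads = torch.tensor([[[1.0, 0.0]],
                      [[0.0, 1.0]]])             # (m=2, n=1, d=2)
g = torch.tensor([[0.5, 0.5]])                   # aligned with both gradients
delta = torch.tensor([[1.0, -2.0]])              # pushes against the second objective

alpha = torch.einsum("jid,id->ij", grads, g)     # [[0.5, 0.5]]
beta = torch.einsum("jid,id->ij", grads, delta)  # [[1.0, -2.0]]

ratio = torch.full_like(alpha, float("inf"))
mask = (alpha > 0) & (beta < 0)
ratio[mask] = alpha[mask] / (-beta[mask])
rho = 0.9 * ratio.min(dim=1).values              # 0.9 * 0.25 = 0.225
rho[torch.isinf(rho)] = 1.0

scaled = delta * rho[:, None]
print(torch.einsum("jid,id->ij", grads, g + scaled))  # both entries remain > 0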
+     def repair_bounds(self, x):
+         """
+         Clips a tensor x of shape [N, d] such that for each column j:
+         x[:, j] is clipped to lie between xl[j] and xu[j].
+
+         Parameters:
+             x (torch.Tensor): Input tensor of shape [N, d].
+
+         Returns:
+             torch.Tensor: The clipped tensor with the same shape as x.
+         """
+
+         xl, xu = self.problem.bounds()[0], self.problem.bounds()[1]
+         lower = xl.detach().clone().to(x.device)
+         upper = xu.detach().clone().to(x.device)
+
+         if self.problem.global_clamping:
+             return torch.clamp(x.data.clone(), min=lower.min(), max=upper.max())
+         else:
+             return torch.clamp(x.data.clone(), min=lower, max=upper)
+
+     def repulsion_loss(self,
+                        F_,
+                        sigma=1.0,
+                        use_sigma=False):
+         """
+         Computes the repulsion loss over a batch of points in the objective space.
+         F_: Tensor of shape (n, m), where n is the batch size.
+         All ordered pairs (including the diagonal) contribute to the sum; the
+         result is normalized by the number of unique pairs n*(n-1)/2.
+         """
+         n = F_.shape[0]
+         # Squared pairwise distances: shape [n, n]
+         dist_sq = torch.norm(F_[:, None] - F_, dim=2).pow(2)
+         # Compute RBF values for the distances
+         if use_sigma:
+             repulsion = torch.exp(-dist_sq / (2 * sigma**2))
+         else:
+             # Median-style heuristic for the kernel bandwidth
+             tensor = dist_sq.detach().flatten()
+             tensor_max = tensor.max()[None]
+             median_dist = (torch.cat((tensor, tensor_max)).median() + tensor.median()) / 2.0
+             s = median_dist / math.log(n)
+             repulsion = torch.exp((-dist_sq / 5e-6) * s)
+
+         # Normalize by the number of unique pairs
+         loss = (2 / (n * (n - 1))) * repulsion.sum()
+         return loss
+
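For intuition, the fixed-sigma branch of the repulsion term on a toy batch of objective vectors (illustrative values only; the adaptive-bandwidth branch is omitted):

import torch

F_ = torch.tensor([[0.0, 1.0],
                   [0.1, 0.9],
                   [1.0, 0.0]])
n = F_.shape[0]
dist_sq = torch.norm(F_[:, None] - F_, dim=2).pow(2)  # (n, n) squared pairwise distances
sigma = 1.0
repulsion = torch.exp(-dist_sq / (2 * sigma ** 2))    # close points contribute values near 1
loss = (2 / (n * (n - 1))) * repulsion.sum()
print(loss)                                           # grows when points crowd together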
+     def eps_dominance(self, Obj_space, alpha=0.0):
+         """Return the indices of (epsilon-)non-dominated rows of Obj_space (minimization)."""
+         epsilon = alpha * np.min(Obj_space, axis=0)
+         N = len(Obj_space)
+         Pareto_set_idx = list(range(N))
+         Dominated = []
+         for i in range(N):
+             candt = Obj_space[i] - epsilon
+             for j in range(N):
+                 if np.all(candt >= Obj_space[j]) and np.any(candt > Obj_space[j]):
+                     Dominated.append(i)
+                     break
+         PS_idx = list(set(Pareto_set_idx) - set(Dominated))
+         return PS_idx
+
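The same dominance test on a toy NumPy array; with alpha = 0 the epsilon shift vanishes and the loop reduces to plain Pareto dominance (the array is illustrative):

import numpy as np

F = np.array([[0.2, 0.8],
              [0.8, 0.2],
              [0.5, 0.9]])                 # point 2 is dominated by point 0
eps = 0.0 * np.min(F, axis=0)
keep = []
for i in range(len(F)):
    cand = F[i] - eps
    dominated = any(np.all(cand >= F[j]) and np.any(cand > F[j]) for j in range(len(F)))
    if not dominated:
        keep.append(i)
print(keep)                                # [0, 1]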
+     def get_non_dominated_points(
+         self,
+         points_pred=None,
+         keep_shape=True,
+         indx_only=False,
+         p_front=None
+     ):
+         if not indx_only and points_pred is None:
+             raise ValueError("points_pred cannot be None when indx_only is False.")
+         if points_pred is not None:
+             pf_points = copy.deepcopy(points_pred.detach())
+             p_front = self.objective_functions(pf_points).detach().cpu().numpy()
+         else:
+             assert p_front is not None, "p_front must be provided if points_pred is None."
+
+         if self.mode in ["online", "offline"]:
+             PS_idx = self.eps_dominance(p_front)
+         elif self.mode == "bayesian":
+             N = points_pred.shape[0]
+             # 1) Predict dominance
+             label_matrix, _ = nn_predict_dom_intra(points_pred.detach().cpu().numpy(),
+                                                    self.dominance_classifier,
+                                                    self.device)
+             # 2) Find non-dominated indices
+             PS_idx = [
+                 i for i in range(N)
+                 if not any(label_matrix[j, i] == 2 for j in range(N))
+             ]
+         else:
+             raise ValueError(f"Mode {self.mode} not recognized!")
+
+         if indx_only:
+             return PS_idx
+
+         elif keep_shape:
+             PS_idx = np.sort(PS_idx)
+             # Create an array of all indices
+             all_indices = np.arange(p_front.shape[0])
+             # Identify the indices not in PS_idx
+             not_in_PS_idx = np.setdiff1d(all_indices, PS_idx)
+             # For each index not in PS_idx, find the nearest index in PS_idx
+             for idx in not_in_PS_idx:
+                 # Compute the distance to all indices in PS_idx
+                 distances = np.abs(PS_idx - idx)
+                 nearest_idx = PS_idx[np.argmin(distances)]  # Find the closest index
+                 # Replace with the value at the closest index
+                 pf_points[idx] = pf_points[nearest_idx]
+
+             return pf_points, points_pred, PS_idx
+
+         else:
+             return pf_points[PS_idx], points_pred, PS_idx
+
+     def crowding_distance(self, points):
+         """
+         Compute crowding distances for points.
+         points: Tensor of shape (N, D) in the objective space.
+         Returns: Tensor of shape (N,) containing crowding distances.
+         """
+         N, D = points.shape
+         distances = torch.zeros(N, device=points.device)
+
+         for d in range(D):
+             sorted_points, indices = torch.sort(points[:, d])
+             distances[indices[0]] = distances[indices[-1]] = float("inf")
+
+             min_d, max_d = sorted_points[0], sorted_points[-1]
+             norm_range = max_d - min_d if max_d > min_d else 1.0
+
+             # Compute normalized crowding distance
+             distances[indices[1:-1]] += (
+                 sorted_points[2:] - sorted_points[:-2]
+             ) / norm_range
+
+         return distances
+
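A quick check of the crowding-distance logic on a toy two-objective front: boundary points get infinite distance, interior points accumulate the normalized gap between their sorted neighbours. The snippet is a stand-alone re-implementation for illustration, not a call into the class:

import torch

front = torch.tensor([[0.0, 3.0],
                      [1.0, 2.0],
                      [2.0, 1.0],
                      [3.0, 0.0]])
N, D = front.shape
dist = torch.zeros(N)
for d in range(D):
    vals, idx = torch.sort(front[:, d])
    dist[idx[0]] = dist[idx[-1]] = float("inf")
    span = (vals[-1] - vals[0]).clamp_min(1e-12)
    dist[idx[1:-1]] += (vals[2:] - vals[:-2]) / span
print(dist)   # boundary points -> inf, interior points -> about 1.33 each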
+     def select_top_n_candidates(
+         self,
+         points: torch.Tensor,
+         n: int,
+         top_frac: float = 0.9
+     ) -> torch.Tensor:
+         """
+         Selects the top `n` points from `points` based on non-dominance and crowding distance.
+
+         Returns:
+             torch.Tensor: The best subset of points (shape [n, D]).
+         """
+
+         if self.mode in ["online", "offline"]:
+             if len(points) <= n:
+                 final_idx = torch.randperm(points.size(0))
+             else:
+                 full_p_front = self.objective_functions(points)
+                 distances = self.crowding_distance(full_p_front)
+                 top_indices = torch.topk(distances, n).indices
+                 final_idx = top_indices[torch.randperm(top_indices.size(0))]
+         else:
+             N = points.shape[0]
+             # 1) Predict dominance
+             label_matrix, conf_matrix = nn_predict_dom_intra(points.detach().cpu().numpy(),
+                                                              self.dominance_classifier,
+                                                              self.device)
+             # 2) Find non-dominated indices
+             nondom_inds = [
+                 i for i in range(N)
+                 if not any(label_matrix[j, i] == 2 for j in range(N))
+             ]
+
+             # --- CASE A: too many non-dominated → pick top-n by crowding ---
+             if len(nondom_inds) > n:
+                 # Evaluate objectives on just the non-dominated set
+                 pts_nd = points[nondom_inds].to(self.device)
+                 Y_t = self.objective_functions(pts_nd)
+
+                 # Compute crowding distances and select top-n
+                 distances = self.crowding_distance(Y_t)
+                 topk = torch.topk(distances, n).indices.tolist()
+
+                 selected_nd = [nondom_inds[i] for i in topk]
+
+                 # Shuffle before returning
+                 perm = torch.randperm(n, device=points.device)
+                 final_idx = torch.tensor(selected_nd, device=points.device)[perm]
+                 return points[final_idx]
+
+             # --- CASE B: nondom ≤ n → fill up via rank + top_frac + crowding ---
+             # 3) Compute domination counts & average confidence for all points
+             dom_counts = []
+             avg_conf = []
+             for i in range(N):
+                 dom_by = (label_matrix[:, i] == 2)
+                 cnt = int(dom_by.sum())
+                 dom_counts.append(cnt)
+                 avg_conf.append(
+                     float(conf_matrix[dom_by, i].sum()) / cnt
+                     if cnt > 0 else 0.0
+                 )
+
+             # 4) Sort all points by (dom_count asc, avg_conf desc)
+             idxs = list(range(N))
+             idxs.sort(key=lambda i: (dom_counts[i], -avg_conf[i]))
+
+             # 5) Keep only the top `top_frac` fraction of that ranking
+             k_frac = int(np.floor(top_frac * N))
+             top_ranked = idxs[:k_frac]
+
+             # 6) Evaluate objectives on that subset
+             pts_ranked = points[top_ranked]
+             Y_t = self.objective_functions(pts_ranked)
+
+             # 7) Crowding distance & pick as many as needed to reach n
+             distances = self.crowding_distance(Y_t)
+             need = max(n - len(nondom_inds), 0)
+             k_sel = min(need, len(top_ranked))
+             sel = torch.topk(distances, k_sel).indices.tolist()
+             selected_from_top_frac = [top_ranked[i] for i in sel]
+
+             # 8) Build the final list: all non-dominated + selected_from_top_frac,
+             #    dropping duplicates (a point can appear in both sets)
+             final_inds = list(dict.fromkeys(nondom_inds + selected_from_top_frac))
+
+             # 9) If still short (e.g. N < n), pad with the best remaining in idxs
+             if len(final_inds) < n:
+                 remaining = [i for i in idxs if i not in final_inds]
+                 to_add = n - len(final_inds)
+                 final_inds += remaining[:to_add]
+
+             # 10) Shuffle final indices
+             perm = torch.randperm(len(final_inds), device=points.device)
+             final_idx = torch.tensor(final_inds, device=points.device)[perm]
+
+         return points[final_idx]
+
+     def plot_pareto_front(self,
+                           list_fi,
+                           t,
+                           num_points_sample,
+                           extra=None,
+                           label=None,
+                           plot_dataset=False,
+                           pop=None,
+                           images_store_path="./images_dir/"):
+         name = (
+             "spread"
+             + "_"
+             + self.problem.__class__.__name__
+             + "_"
+             + f"T={self.timesteps}"
+             + "_"
+             + f"N={num_points_sample}"
+             + "_"
+             + f"t={t}"
+             + "_"
+             + f"seed={self.seed}"
+             + "_"
+             + self.mode
+         )
+         if label is not None:
+             name += f"_{label}"
+
+         if len(list_fi) > 3:
+             return None
+
+         elif len(list_fi) == 2:
+             if extra is not None:
+                 # Pareto-optimal reference points
+                 f1, f2 = extra
+                 plt.scatter(f1, f2, c="yellow", s=5, alpha=1.0)
+
+             if pop is not None:
+                 # Population points
+                 f_pop1, f_pop2 = pop
+                 plt.scatter(f_pop1, f_pop2, c="blue", s=10, alpha=1.0)
+
+             # Generated points
+             f1, f2 = list_fi
+             plt.scatter(f1, f2, c="red", s=10, alpha=1.0)
+
+             if plot_dataset and (self.dataset is not None):
+                 # Training data points
+                 _, Y = self.dataset
+                 Y = self.offline_denormalization(Y,
+                                                  self.y_meanormin,
+                                                  self.y_stdormax)
+                 plt.scatter(Y[:, 0], Y[:, 1],
+                             c="blue", s=5, alpha=1.0)
+
+             plt.xlabel("$f_1$", fontsize=14)
+             plt.ylabel("$f_2$", fontsize=14)
+             plt.title(f"Reverse Time Step: {t}", fontsize=14)
+
+         elif len(list_fi) == 3:
+             fig = plt.figure()
+             ax = fig.add_subplot(111, projection="3d")
+
+             if extra is not None:
+                 # Pareto-optimal reference points
+                 f1, f2, f3 = extra
+                 ax.scatter(f1, f2, f3, c="yellow", s=5, alpha=0.05)
+
+             if pop is not None:
+                 # Population points
+                 f_pop1, f_pop2, f_pop3 = pop
+                 ax.scatter(f_pop1, f_pop2, f_pop3, c="blue", s=10, alpha=1.0)
+
+             # Generated points
+             f1, f2, f3 = list_fi
+             ax.scatter(f1, f2, f3, c="red", s=10, alpha=1.0)
+
+             if plot_dataset and (self.dataset is not None):
+                 # Training data points
+                 _, Y = self.dataset
+                 Y = self.offline_denormalization(Y,
+                                                  self.y_meanormin,
+                                                  self.y_stdormax)
+                 ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2],
+                            c="blue", s=5, alpha=1.0)
+             ax.set_xlabel("$f_1$", fontsize=14)
+             ax.set_ylabel("$f_2$", fontsize=14)
+             ax.set_zlabel("$f_3$", fontsize=14)
+             ax.view_init(elev=30, azim=45)
+             ax.set_title(f"Reverse Time Step: {t}", fontsize=14)
+
+         img_dir = f"{images_store_path}/{self.problem.__class__.__name__}_{self.mode}"
+         if label is not None:
+             img_dir += f"_{label}"
+         if not os.path.exists(img_dir):
+             os.makedirs(img_dir)
+
+         plt.savefig(
+             f"{img_dir}/{name}.jpg",
+             dpi=300,
+             bbox_inches="tight",
+         )
+         plt.close()