moospread 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
moospread/core.py CHANGED
@@ -30,7 +30,7 @@ import matplotlib.pyplot as plt
 from mpl_toolkits.mplot3d import Axes3D
 
 from moospread.utils import *
- print
+
 class SPREAD:
 def __init__(self,
 problem,
@@ -53,8 +53,8 @@ class SPREAD:
 train_tol: int = 100,
 train_tol_surrogate: int = 100,
 mobo_coef_lcb=0.1,
- model_dir: str = "./model_dir",
- proxies_store_path: str = "./proxies_dir",
+ model_dir: str = "./model_dir/",
+ proxies_store_path: str = "./proxies_dir/",
 device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
 seed: int = 0,
 offline_global_clamping: bool = False,
@@ -62,6 +62,50 @@ class SPREAD:
 train_func_surrogate = None,
 plot_func = None,
 verbose: bool = True):
+ """
+ Initialize a SPREAD solver instance: configure the mode (online, offline, bayesian),
+ prepare the dataset (and normalization in offline mode), and set up the diffusion
+ model and (optionally) the surrogate model.
+
+ Arguments
+ problem: Optimization problem instance. Must define n_var, n_obj, bounds (xl, xu / bounds()),
+ and evaluate(...). In online/bayesian mode it must implement _evaluate.
+ mode (str): One of {"online","offline","bayesian"}.
+ model: Diffusion model. If None, defaults to DiTMOO(...).
+ surrogate_model: Optional surrogate model (used in offline and bayesian modes).
+ dataset: Optional (X, y) training set. If None, SPREAD may generate data (depending on mode).
+ xi_shift: Optional constant shift added to conditioning objective values to ensure positivity.
+ data_size (int): Number of samples to generate if dataset is not provided.
+ validation_split (float): Fraction of the dataset used for validation in diffusion training.
+ hidden_dim, num_heads, num_blocks: Diffusion model hyperparameters (for the default DiTMOO).
+ timesteps (int): Number of diffusion steps.
+ batch_size (int): Batch size for diffusion training.
+ train_lr (float): Learning rate for diffusion training.
+ train_lr_surrogate (float): Learning rate for surrogate training (offline/bayesian).
+ num_epochs (int): Max epochs for diffusion training.
+ num_epochs_surrogate (int): Max epochs for surrogate training (offline/bayesian).
+ train_tol (int): Early-stopping patience for the diffusion validation loss.
+ train_tol_surrogate (int): Early-stopping patience for surrogate training (if implemented).
+ mobo_coef_lcb (float): LCB coefficient in bayesian mode: mean - coef * std.
+ model_dir (str): Directory to store diffusion checkpoints.
+ proxies_store_path (str): Directory to store offline proxy models.
+ device: Torch device used for training/sampling.
+ seed (int): Random seed for reproducibility.
+ offline_global_clamping (bool): If True, clamp all dimensions with the global min/max rather than per dimension.
+ offline_normalization_method (str|None): One of {"z_score","min_max",None}; used in offline mode for normalizing X and y.
+ train_func_surrogate: Optional user-defined surrogate training function.
+ plot_func: Optional custom plotting function.
+ verbose (bool): If True, prints parameter counts and progress messages.
+
+ Returns
+ None (constructor). Initializes internal fields such as:
+ self.model, self.surrogate_model, self.dataset
+ offline normalization parameters: X_meanormin, X_stdormax, etc.
+ normalized bounds in offline mode (the original bounds are stored).
+
+ Raises
+ ValueError: If the mode is invalid, or if constraints on method availability are violated.
+ """
 
 self.mode = mode.lower()
 if self.mode not in ["offline", "online", "bayesian"]:
@@ -88,7 +132,7 @@ class SPREAD:
 self.mobo_coef_lcb = mobo_coef_lcb
 
 self.xi_shift = xi_shift
- self.model_dir = model_dir+f"/{self.problem.__class__.__name__}_{self.mode}"
+ self.model_dir = model_dir+f"{self.problem.__class__.__name__}_{self.mode}"
 os.makedirs(self.model_dir, exist_ok=True)
 
 self.train_func_surrogate = train_func_surrogate
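
The two hunks above move the path separator from the f-string into the default value, so callers passing a custom model_dir now need to include the trailing slash themselves. A separator-agnostic sketch using os.path.join (shown for comparison only, not what the package does; checkpoint_dir is a hypothetical helper):

    import os

    def checkpoint_dir(model_dir: str, problem_name: str, mode: str) -> str:
        # os.path.join tolerates model_dir with or without a trailing slash.
        path = os.path.join(model_dir, f"{problem_name}_{mode}")
        os.makedirs(path, exist_ok=True)
        return path
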
@@ -198,6 +242,34 @@ class SPREAD:
 get_constraint=False,
 get_grad_mobo=False,
 evaluate_true=False):
+ """
+ Evaluate the objective functions at the given points, based on the current mode.
+ online: calls the true problem evaluation.
+ offline: evaluates the proxy models (one per objective).
+ bayesian: evaluates the surrogate mean/std and optionally gradients (LCB objective).
+
+ Arguments
+ points (torch.Tensor): Decision points, shape (N, n_var)
+ (or encoded forms for discrete/sequence problems, depending on the pipeline).
+ return_as_dict (bool): Forwarded to problem.evaluate(...) in online/true evaluation.
+ return_values_of (list|None): Forwarded to problem.evaluate(...) to request "F", "G", "H", etc.
+ get_constraint (bool): If True, returns constraint info ("G", "H") along with objectives when supported.
+ get_grad_mobo (bool): Bayesian only. If True, requests gradients from the surrogate (dF, dS) and returns the gradient of the LCB.
+ evaluate_true (bool): If True, forces true problem evaluation even if the mode is not online (used for plotting in offline mode).
+
+ Returns
+ If get_constraint=True (true/online): a dict with keys typically including "F" and optionally "G", "H".
+ If mode is online or evaluate_true=True: objective tensor (N, n_obj) (or a dict if return_as_dict=True).
+ If mode is offline: tensor (N, n_obj) built by stacking each proxy output.
+ If mode is bayesian:
+ if get_grad_mobo=False: tensor (N, n_obj) containing mean - coef * std.
+ if get_grad_mobo=True: dict with:
+ "F": LCB tensor (N, n_obj)
+ "dF": list of length n_obj, each a tensor (N, n_var).
+
+ Raises
+ ValueError: If the mode is invalid.
+ """
 if evaluate_true:
 if self.problem.need_repair:
 points = self.repair_bounds(points)
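
For reference, the bayesian-mode value described above is a plain lower confidence bound; a minimal standalone sketch (the mean/std tensors stand in for the class's surrogate outputs):

    import torch

    def lcb(mean: torch.Tensor, std: torch.Tensor, coef: float = 0.1) -> torch.Tensor:
        # Optimistic value for minimization: mean - coef * std, shape (N, n_obj).
        return mean - coef * std

    mean, std = torch.rand(8, 2), torch.rand(8, 2)
    print(lcb(mean, std).shape)  # torch.Size([8, 2])
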
@@ -247,7 +319,7 @@ class SPREAD:
 use_sigma_rep=False, kernel_sigma_rep=0.01,
 iterative_plot=True, plot_period=100,
 plot_dataset=False, plot_population=False,
- elev=30, azim=45, legend=False,
+ elev=30, azim=45, legend=False, alpha_pf_3d=0.05,
 max_backtracks=100, label=None, save_results=True,
 load_models=False,
 samples_store_path="./samples_dir/",
@@ -255,6 +327,43 @@ class SPREAD:
 n_init_mobo=100, use_escape_local_mobo=True,
 n_steps_mobo=20, spread_num_samp_mobo=25,
 batch_select_mobo=5):
+ """
+ offline/online: (optionally) train the surrogate → train the diffusion model → sample Pareto solutions.
+ bayesian: runs a full MOBO loop in which SPREAD generates candidate designs and selects batches
+ via hypervolume improvement.
+
+ Arguments
+ Key arguments (offline/online path)
+ num_points_sample (int): Number of solutions to sample.
+ strict_guidance (bool): If True, the direction perturbation uses the MGDA direction from a
+ single evolving target point.
+ rho_scale_gamma, nu_t, eta_init, num_inner_steps, lr_inner, free_initial_h: Guidance
+ inner-loop controls.
+ use_sigma_rep, kernel_sigma_rep: Repulsion loss configuration.
+ iterative_plot, legend, plot_period, plot_dataset, plot_population: Plot controls during sampling.
+ elev, azim, alpha_pf_3d: 3D plot controls.
+ max_backtracks: Max Armijo backtracking steps.
+ label: Extra tag appended to saved files/plots.
+ save_results (bool): Save sampled results to disk.
+ load_models (bool): If True, loads saved diffusion/proxy models instead of training.
+ samples_store_path, images_store_path: Output directories.
+ Key arguments (bayesian path)
+ n_init_mobo (int): Number of initial random evaluations.
+ use_escape_local_mobo (bool): Enables switching between the diffusion operator and SBX when the HV stagnates.
+ n_steps_mobo (int): Number of MOBO iterations.
+ spread_num_samp_mobo (int): How many SPREAD sampling runs to aggregate into the candidate pool.
+ batch_select_mobo (int): Batch size per iteration (selected by HV improvement).
+
+ Returns
+ If mode in {offline, online}: (res_x, res_y)
+ res_x: np.ndarray of sampled Pareto(-like) decision vectors, shape (K, n_var) (K ≤ requested, due to filtering).
+ res_y: np.ndarray of evaluated objectives, shape (K, n_obj) (or None if final evaluation is disabled).
+ If mode is bayesian: ([X, Y], hv_all_value)
+ X: np.ndarray of all evaluated designs
+ Y: np.ndarray of all evaluated objectives
+ hv_all_value: list of hypervolume values per iteration.
+ """
+
 set_seed(self.seed)
 if self.mode in ["offline", "online"]:
 X, y = self.dataset
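
A minimal usage sketch of the return contract documented above, including the bayesian return shape changed in this release (problem is a placeholder object that must satisfy the constructor's requirements):

    # Hypothetical usage; `problem` must expose n_var, n_obj, bounds and evaluate.
    solver = SPREAD(problem, mode="online")
    res_x, res_y = solver.solve(num_points_sample=100)   # offline/online path
    # res_x: (K, n_var) designs, res_y: (K, n_obj) objectives, K <= 100

    mobo = SPREAD(problem, mode="bayesian")
    [X, Y], hv_per_iter = mobo.solve(n_init_mobo=100,    # 0.1.3 returned X, Y, hv
                                     n_steps_mobo=20,
                                     batch_select_mobo=5)
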
@@ -299,7 +408,7 @@ class SPREAD:
 use_sigma_rep=use_sigma_rep, kernel_sigma_rep=kernel_sigma_rep,
 iterative_plot=iterative_plot, plot_period=plot_period,
 plot_dataset=plot_dataset, plot_population=plot_population,
- elev=elev, azim=azim, legend=legend,
+ elev=elev, azim=azim, legend=legend, alpha_pf_3d=alpha_pf_3d,
 max_backtracks=max_backtracks, label=label,
 save_results=save_results,
 samples_store_path=samples_store_path,
@@ -309,12 +418,35 @@ class SPREAD:
 
 elif self.mode == "bayesian":
 self.verbose = False
+ all_selected_batch_y = []
 hv_all_value = []
 # initialize n_init solutions
 x_init = lhs_no_evaluation(self.problem.n_var,
 n_init_mobo)
 x_init = torch.from_numpy(x_init).float().to(self.device)
 y_init = self.problem.evaluate(x_init).detach().cpu().numpy()
+
+ if iterative_plot:
+ list_fi_init = [y_init[:, i] for i in range(y_init.shape[1])]
+ pareto_front = None
+ if self.problem.pareto_front() is not None:
+ pareto_front = self.problem.pareto_front()
+ pareto_front = [pareto_front[:, i] for i in range(self.problem.n_obj)]
+ if self.plot_func is not None:
+ self.plot_func(list_fi=None, t=0,
+ num_points_sample=batch_select_mobo,
+ extra=pareto_front,
+ dataset = self.dataset,
+ pop=list_fi_init,
+ elev=elev, azim=azim, legend=legend, mode=self.mode, alpha_pf_3d=alpha_pf_3d,
+ label=label, images_store_path=images_store_path)
+ else:
+ self.plot_pareto_front(list_fi=None, t=0,
+ num_points_sample=batch_select_mobo,
+ extra=pareto_front,
+ pop=list_fi_init,
+ elev=elev, azim=azim, legend=legend, alpha_pf_3d=alpha_pf_3d,
+ label=label, images_store_path=images_store_path)
 
 # initialize dominance-classifier for non-dominance relation
 p_rel_map, s_rel_map = init_dom_rel_map(300)
@@ -394,7 +526,7 @@ class SPREAD:
 use_sigma_rep=use_sigma_rep, kernel_sigma_rep=kernel_sigma_rep,
 iterative_plot=iterative_plot, plot_period=plot_period,
 plot_dataset=plot_dataset, plot_population=plot_population,
- elev=elev, azim=azim, legend=legend,
+ elev=elev, azim=azim, legend=legend, alpha_pf_3d=alpha_pf_3d,
 max_backtracks=max_backtracks, label=label,
 samples_store_path=samples_store_path,
 images_store_path=images_store_path,
@@ -416,6 +548,7 @@ class SPREAD:
 
 pop_size_used = X_psl.shape[0]
 
+
 # Mutate the new offspring
 X_psl = pm_mutation(X_psl, [self.problem.xl.detach().cpu().numpy(),
 self.problem.xu.detach().cpu().numpy()])
@@ -451,7 +584,7 @@ class SPREAD:
 if hv_value_subset > best_hv_value:
 best_hv_value = hv_value_subset
 best_subset = [k]
-
+
 Y_p = np.vstack([Y_p, Y_candidate_mean[best_subset]])
 best_subset_list.append(best_subset)
 
@@ -460,7 +593,10 @@ class SPREAD:
 X_candidate = X_psl
 X_new = X_candidate[best_subset_list]
 Y_new = self.problem.evaluate(torch.from_numpy(X_new).float().to(self.device)).detach().cpu().numpy()
-
+
+ list_fi_new = [Y_new[:, i] for i in range(Y_new.shape[1])]
+ all_selected_batch_y.append(list_fi_new)
+
 Y_new = torch.tensor(Y_new).to(self.device)
 X_new = torch.tensor(X_new).to(self.device)
 
@@ -480,6 +616,23 @@ class SPREAD:
 
 hv_text = f"{hv_value:.4e}"
 evaluated = evaluated + batch_select_mobo
+
+ if iterative_plot:
+ if self.plot_func is not None:
+ self.plot_func(list_fi=all_selected_batch_y, t=k_iter+1,
+ num_points_sample=batch_select_mobo,
+ extra=pareto_front,
+ dataset = self.dataset,
+ pop=list_fi_init, alpha_pf_3d=alpha_pf_3d,
+ elev=elev, azim=azim, legend=legend, mode=self.mode,
+ label=label, images_store_path=images_store_path)
+ else:
+ self.plot_pareto_front(list_fi=all_selected_batch_y, t=k_iter+1,
+ num_points_sample=batch_select_mobo,
+ extra=pareto_front,
+ pop=list_fi_init, alpha_pf_3d=alpha_pf_3d,
+ elev=elev, azim=azim, legend=legend,
+ label=label, images_store_path=images_store_path)
 
 #### DECISION TO SWITCH OPERATOR ####
 if use_escape_local_mobo:
@@ -558,13 +711,27 @@ class SPREAD:
 with open(outfile, "wb") as f:
 pickle.dump(hv_all_value, f)
 
- return X, Y, hv_all_value
+ return [X, Y], hv_all_value
 
 
 def train(self,
 train_dataloader,
 val_dataloader=None,
 disable_progress_bar=False):
+ """
+ Train the diffusion model (DDPM) using an MSE loss on the predicted noise.
+
+ Arguments
+ train_dataloader (DataLoader): Yields (x, obj_values) batches.
+ val_dataloader (DataLoader|None): Optional validation loader for early stopping
+ and best-checkpoint saving.
+ disable_progress_bar (bool): Disables the tqdm progress bar.
+
+ Returns
+ None. Saves checkpoints:
+ checkpoint_ddpm_best.pth (if validation is enabled)
+ checkpoint_ddpm_last.pth (always)
+ """
 set_seed(self.seed)
 if self.verbose:
 print(datetime.datetime.now())
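
The objective named in the docstring is the standard DDPM noise-regression loss; a generic, self-contained sketch of one training step under it (the model signature and tensor names are assumptions, and the package's own loop additionally handles conditioning, validation, and checkpoints):

    import torch
    import torch.nn as nn

    def ddpm_train_step(model, x0, cond, alpha_bar, optimizer):
        # Sample a timestep and noise, build x_t by forward diffusion,
        # then regress the injected noise with an MSE loss (L_simple).
        t = torch.randint(0, len(alpha_bar), (x0.shape[0],))
        eps = torch.randn_like(x0)
        a = alpha_bar[t].unsqueeze(-1)                 # cumulative alphas, (N, 1)
        x_t = a.sqrt() * x0 + (1 - a).sqrt() * eps
        loss = nn.MSELoss()(model(x_t, t, cond), eps)  # predicted vs. true noise
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()
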
@@ -731,12 +898,37 @@ class SPREAD:
 use_sigma_rep=False, kernel_sigma_rep=0.01,
 iterative_plot=True, plot_period=100,
 plot_dataset=False, plot_population=False,
- elev=30, azim=45, legend=False,
+ elev=30, azim=45, legend=False, alpha_pf_3d=0.05,
 max_backtracks=25, label=None,
 samples_store_path="./samples_dir/",
 images_store_path="./images_dir/",
 disable_progress_bar=False,
 save_results=True, evaluate_final=True):
+ """
+ Generate num_points_sample decision vectors by reverse diffusion, with Pareto guidance
+ and optional plotting/saving.
+
+ Arguments
+ num_points_sample (int): Number of points to generate.
+ strict_guidance, rho_scale_gamma, nu_t, eta_init, num_inner_steps, lr_inner, free_initial_h: Guidance controls.
+ use_sigma_rep, kernel_sigma_rep: Repulsion controls.
+ iterative_plot, legend, plot_period, plot_dataset, plot_population: Plot controls.
+ elev, azim, alpha_pf_3d: Plot styling.
+ max_backtracks (int): Armijo backtracking limit.
+ label (str|None): Extra name tag.
+ samples_store_path, images_store_path: Output directories.
+ disable_progress_bar (bool): Disables tqdm.
+ save_results (bool): Saves *_x.npy, optionally *_y.npy and an HV pickle.
+ evaluate_final (bool): If True, evaluates the objectives on the final points and filters NaN/Inf.
+
+ Returns
+ (res_x, res_y)
+ res_x: np.ndarray, shape (K, n_var), sampled decision vectors.
+ res_y: np.ndarray, shape (K, n_obj), final objective values if evaluate_final=True, else None.
+
+ Raises
+ ValueError: If no trained diffusion checkpoint is found.
+ """
 # Set the seed
 set_seed(self.seed)
 if save_results:
@@ -798,12 +990,12 @@ class SPREAD:
 # Denormalize the points before plotting
 res_x_t = pf_points.clone().detach()
 res_x_t = self.offline_denormalization(res_x_t,
- self.X_meanormin,
- self.X_stdormax)
+ self.X_meanormin,
+ self.X_stdormax)
 res_pop = pf_population.clone().detach()
 res_pop = self.offline_denormalization(res_pop,
- self.X_meanormin,
- self.X_stdormax)
+ self.X_meanormin,
+ self.X_stdormax)
 norm_xl, norm_xu = self.problem.bounds()
 xl, xu = self.problem.original_bounds
 self.problem.xl = xl
@@ -842,14 +1034,15 @@ class SPREAD:
 extra=pareto_front,
 plot_dataset=plot_dataset,
 dataset = self.dataset,
- elev=elev, azim=azim, legend=legend,
+ pop=list_fi_pop, alpha_pf_3d=alpha_pf_3d,
+ elev=elev, azim=azim, legend=legend, mode=self.mode,
 label=label, images_store_path=images_store_path)
 else:
 self.plot_pareto_front(list_fi, self.timesteps,
 num_points_sample,
 extra=pareto_front,
 plot_dataset=plot_dataset,
- pop=list_fi_pop,
+ pop=list_fi_pop, alpha_pf_3d=alpha_pf_3d,
 elev=elev, azim=azim, legend=legend,
 label=label, images_store_path=images_store_path)
 
@@ -927,7 +1120,7 @@ class SPREAD:
 pf_population,
 keep_shape=False
 )
-
+
 if prev_pf_points is not None:
 pf_points = torch.cat((prev_pf_points, pf_points), dim=0)
 if self.mode != "bayesian":
@@ -935,6 +1128,7 @@ class SPREAD:
 pf_points,
 keep_shape=False,
 )
+
 if len(pf_points) > num_points_sample:
 pf_points = self.select_top_n_candidates(
 pf_points,
@@ -946,76 +1140,73 @@ class SPREAD:
 prev_pf_points = pf_points
 num_optimal_points = len(pf_points)
 
- if iterative_plot and (not is_pass_function(self.problem._evaluate)):
- if self.problem.n_obj <= 3:
- if (t % plot_period == 0) or (t == self.timesteps - 1):
- if self.mode == "offline":
- # Denormalize the points before plotting
- res_x_t = pf_points.clone().detach()
- res_x_t = self.offline_denormalization(res_x_t,
- self.X_meanormin,
- self.X_stdormax)
- res_pop = pf_population.clone().detach()
- res_pop = self.offline_denormalization(res_pop,
- self.X_meanormin,
- self.X_stdormax)
- norm_xl, norm_xu = self.problem.bounds()
- xl, xu = self.problem.original_bounds
- self.problem.xl = xl
- self.problem.xu = xu
- if self.problem.is_discrete:
- _, dim, n_classes = tuple(res_x_t.shape)
- res_x_t = res_x_t.reshape(-1, dim, n_classes)
- res_x_t = offdata_to_integers(res_x_t)
-
- _, dim_pop, n_classes_pop = tuple(res_pop.shape)
- res_pop = res_pop.reshape(-1, dim_pop, n_classes_pop)
- res_pop = offdata_to_integers(res_pop)
- if self.problem.is_sequence:
- res_x_t = offdata_to_integers(res_x_t)
- res_pop = offdata_to_integers(res_pop)
- # we need to evaluate the true objective functions for plotting
- list_fi = self.objective_functions(res_x_t,
- evaluate_true=True).split(1, dim=1)
- list_fi_pop = self.objective_functions(res_pop,
+ if self.mode in ["online", "offline"]:
+ if iterative_plot and (not is_pass_function(self.problem._evaluate)):
+ if self.problem.n_obj <= 3:
+ if (t % plot_period == 0) or (t == self.timesteps - 1):
+ if self.mode == "offline":
+ # Denormalize the points before plotting
+ res_x_t = pf_points.clone().detach()
+ res_x_t = self.offline_denormalization(res_x_t,
+ self.X_meanormin,
+ self.X_stdormax)
+ res_pop = pf_population.clone().detach()
+ res_pop = self.offline_denormalization(res_pop,
+ self.X_meanormin,
+ self.X_stdormax)
+ norm_xl, norm_xu = self.problem.bounds()
+ xl, xu = self.problem.original_bounds
+ self.problem.xl = xl
+ self.problem.xu = xu
+ if self.problem.is_discrete:
+ _, dim, n_classes = tuple(res_x_t.shape)
+ res_x_t = res_x_t.reshape(-1, dim, n_classes)
+ res_x_t = offdata_to_integers(res_x_t)
+
+ _, dim_pop, n_classes_pop = tuple(res_pop.shape)
+ res_pop = res_pop.reshape(-1, dim_pop, n_classes_pop)
+ res_pop = offdata_to_integers(res_pop)
+ if self.problem.is_sequence:
+ res_x_t = offdata_to_integers(res_x_t)
+ res_pop = offdata_to_integers(res_pop)
+ # we need to evaluate the true objective functions for plotting
+ list_fi = self.objective_functions(res_x_t,
 evaluate_true=True).split(1, dim=1)
- list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
- # restore the normalized bounds
- self.problem.xl = norm_xl
- self.problem.xu = norm_xu
- elif self.mode == "bayesian":
- # we need to evaluate the true objective functions for plotting
- list_fi = self.objective_functions(pf_points, evaluate_true=True).split(1, dim=1)
- list_fi_pop = self.objective_functions(pf_population.detach(), evaluate_true=True).split(1, dim=1)
- list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
- else:
- list_fi = self.objective_functions(pf_points).split(1, dim=1)
- list_fi_pop = self.objective_functions(pf_population.detach()).split(1, dim=1)
- list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
-
- list_fi = [fi.detach().cpu().numpy() for fi in list_fi]
- pareto_front = None
- if self.problem.pareto_front() is not None:
- pareto_front = self.problem.pareto_front()
- pareto_front = [pareto_front[:, i] for i in range(self.problem.n_obj)]
-
- if self.plot_func is not None:
- self.plot_func(list_fi, t,
- num_points_sample,
- extra= pareto_front,
- plot_dataset=plot_dataset,
- dataset = self.dataset,
- elev=elev, azim=azim, legend=legend,
- label=label, images_store_path=images_store_path)
- else:
- self.plot_pareto_front(list_fi, t,
+ list_fi_pop = self.objective_functions(res_pop,
+ evaluate_true=True).split(1, dim=1)
+ list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
+ # restore the normalized bounds
+ self.problem.xl = norm_xl
+ self.problem.xu = norm_xu
+ else:
+ list_fi = self.objective_functions(pf_points).split(1, dim=1)
+ list_fi_pop = self.objective_functions(pf_population.detach()).split(1, dim=1)
+ list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
+
+ list_fi = [fi.detach().cpu().numpy() for fi in list_fi]
+ pareto_front = None
+ if self.problem.pareto_front() is not None:
+ pareto_front = self.problem.pareto_front()
+ pareto_front = [pareto_front[:, i] for i in range(self.problem.n_obj)]
+
+ if self.plot_func is not None:
+ self.plot_func(list_fi, t,
 num_points_sample,
 extra= pareto_front,
 pop=list_fi_pop if plot_population else None,
 plot_dataset=plot_dataset,
- elev=elev, azim=azim, legend=legend,
+ dataset = self.dataset, alpha_pf_3d=alpha_pf_3d,
+ elev=elev, azim=azim, legend=legend, mode=self.mode,
 label=label, images_store_path=images_store_path)
-
+ else:
+ self.plot_pareto_front(list_fi, t,
+ num_points_sample,
+ extra= pareto_front,
+ pop=list_fi_pop if plot_population else None,
+ plot_dataset=plot_dataset, alpha_pf_3d=alpha_pf_3d,
+ elev=elev, azim=azim, legend=legend,
+ label=label, images_store_path=images_store_path)
+
 
 x_t = x_t.detach()
 pbar.set_postfix({
@@ -1081,6 +1272,25 @@ class SPREAD:
 lr=1e-3,
 lr_decay=0.95,
 n_epochs=200):
+ """
+ Train or fit the surrogate model, depending on the mode.
+ bayesian: fits a single multi-output surrogate.
+ offline: trains one proxy per objective (then loads them into a list).
+ If a user surrogate was provided, delegates to train_surrogate_user_defined.
+
+ Arguments
+ X (torch.Tensor or np.ndarray): Training inputs.
+ y (torch.Tensor or np.ndarray): Training objectives.
+ val_ratio (float): Validation split for offline proxy training.
+ batch_size (int): Batch size for proxy training.
+ lr, lr_decay, n_epochs: Proxy training hyperparameters.
+
+ Returns
+ None. Updates self.surrogate_model (a list of models in offline mode, a single model in bayesian mode).
+
+ Raises
+ ValueError: If called in a mode that does not support surrogates.
+ """
 
 set_seed(self.seed)
 self.surrogate_model = self.get_surrogate()
  self.surrogate_model = self.get_surrogate()
@@ -1143,22 +1353,35 @@ class SPREAD:
1143
1353
 
1144
1354
  def train_surrogate_user_defined(self, X, y):
1145
1355
  """
1146
- Train the user-defined surrogate model.
1147
- If self.mode == "offline", the train_func should return a list of trained surrogate models,
1148
- one for each objective.
1149
- If self.mode == "bayesian", the train_func should return a single trained surrogate model for all objectives.
1356
+ Train a user-provided surrogate via self.train_func_surrogate.
1150
1357
 
1151
- Parameters
1152
- ----------
1153
- train_func : function
1154
- A function that takes X, y as input and returns a trained surrogate model.
1155
- **kwargs : dict
1156
- Additional keyword arguments to pass to the train_func.
1157
- -----------
1358
+ Arguments
1359
+ X: Training inputs.
1360
+ y: Training objectives.
1361
+
1362
+ Returns
1363
+ None. Sets self.surrogate_model to whatever train_func_surrogate(X, y) returns:
1364
+ offline: typically a list of per-objective models
1365
+ bayesian: typically a single multi-objective surrogate.
1158
1366
  """
1159
1367
  self.surrogate_model = self.train_func_surrogate(X, y)
1160
1368
 
1161
1369
  def get_surrogate(self):
1370
+ """
1371
+ Construct a default surrogate model if the user did not provide one.
1372
+
1373
+ Arguments
1374
+ None.
1375
+
1376
+ Returns
1377
+ If user provided a surrogate: returns it as-is.
1378
+ Else:
1379
+ bayesian: returns GaussianProcess(...)
1380
+ offline: returns MultipleModels(...) configured for per-objective proxies.
1381
+
1382
+ Raises
1383
+ ValueError: If called in a mode without surrogate support.
1384
+ """
1162
1385
  if self.surrogate_given:
1163
1386
  return self.surrogate_model
1164
1387
  else:
@@ -1189,6 +1412,30 @@ class SPREAD:
 use_sigma=False, kernel_sigma=1.0, strict_guidance = False,
 max_backtracks=100, point_n0=None, optimizer_n0=None,
 ):
+ """
+ Perform one reverse-diffusion step plus Pareto guidance:
+ 1. DDPM reverse step using the predicted noise
+ 2. Compute gradients of the objectives
+ 3. Compute the MGDA direction
+ 4. Armijo step-size selection
+ 5. Solve the inner problem for h_tilde (alignment + repulsion)
+ 6. Update x_t
+
+ Arguments (main ones)
+ x_t (torch.Tensor): Current samples, shape (N, n_var).
+ num_points_sample (int): N.
+ t (int): Current timestep index.
+ beta_t, alpha_bar_t: Diffusion schedule values at t.
+ rho_scale_gamma, nu_t, eta_init, num_inner_steps, lr_inner, free_initial_h: Guidance parameters.
+ use_sigma, kernel_sigma: Repulsion settings.
+ strict_guidance (bool): Uses the target-direction perturbation from point_n0.
+ max_backtracks (int): Armijo backtracking cap.
+ point_n0 (torch.Tensor|None): Target point used in strict guidance.
+ optimizer_n0 (Optimizer|None): Optimizer updating point_n0.
+
+ Returns
+ x_t (torch.Tensor): Updated samples after one SPREAD step, shape (N, n_var).
+ """
 
 # Create a tensor of timesteps with shape (num_points_sample, 1)
 t_tensor = torch.full(
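
Step 1 in the list above is the textbook DDPM posterior update; a minimal sketch mirroring the mean + std_dev * z code in the next hunk (function and argument names are assumptions):

    import torch

    def ddpm_reverse_step(eps_pred, x_t, t, beta_t, alpha_t, alpha_bar_t):
        # Posterior mean given the predicted noise, then add sqrt(beta_t) * z;
        # no noise on the final step, as in the code below.
        mean = (x_t - beta_t / torch.sqrt(1 - alpha_bar_t) * eps_pred) / torch.sqrt(alpha_t)
        z = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
        return mean + torch.sqrt(beta_t) * z
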
@@ -1246,7 +1493,7 @@ class SPREAD:
 std_dev = torch.sqrt(beta_t)
 z = torch.randn_like(x_t) if t > 0 else 0.0 # No noise for the final step
 x_t = mean + std_dev * z
-
+
 #### Pareto Guidance step
 if self.problem.need_repair:
 x_t.data = self.repair_bounds(x_t.data.clone())
@@ -1300,10 +1547,13 @@ class SPREAD:
 rho_scale_gamma=rho_scale_gamma
 )
 
- h_tilde = torch.nan_to_num(h_tilde,
- nan=torch.nanmean(h_tilde),
- posinf=0.0,
- neginf=0.0)
+ # h_tilde = torch.nan_to_num(h_tilde,
+ # nan=torch.nanmean(h_tilde),
+ # posinf=0.0,
+ # neginf=0.0)
+ finite = torch.isfinite(h_tilde)
+ fill = h_tilde[finite].mean() if finite.any() else h_tilde.new_tensor(0.0)
+ h_tilde = torch.where(finite, h_tilde, fill)
 
 x_t = x_t - eta * h_tilde
 
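
The new code fills every non-finite entry with the mean of the finite entries, whereas the old call replaced NaN with the nanmean and infinities with zero, and passed a tensor where torch.nan_to_num expects a Python number. The idiom in isolation:

    import torch

    x = torch.tensor([1.0, float("nan"), float("inf"), 3.0])
    finite = torch.isfinite(x)
    fill = x[finite].mean() if finite.any() else x.new_tensor(0.0)
    print(torch.where(finite, x, fill))  # tensor([1., 2., 2., 3.])
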
@@ -1326,8 +1576,28 @@ class SPREAD:
 rho_scale_gamma=0.9
 ):
 """
- Returns:
- h_tilde: Optimized h (Tensor of shape (batch_size, n_var)).
+ Inner-loop optimizer that learns/updates h to trade off:
+ alignment with the MGDA direction g (descent)
+ repulsion/diversity in objective space (kernel-based)
+
+ Arguments
+ x_t_prime (torch.Tensor): Starting points, shape (N, n_var).
+ g_x_t_prime (torch.Tensor): MGDA/PMGDA direction, shape (N, n_var).
+ grads (torch.Tensor): Objective gradients, shape (m, N, n_var).
+ g_w (torch.Tensor|None): Optional strict-guidance target direction, shape compatible
+ with (N, n_var) or (1, n_var).
+ eta (torch.Tensor): Step size, shape (N, 1) (broadcast).
+ nu_t (float): Weight of the repulsion term.
+ sigma (float): Kernel bandwidth when use_sigma=True.
+ use_sigma (bool): Whether to use a fixed sigma.
+ num_inner_steps (int): Number of optimization steps on h.
+ lr_inner (float): Learning rate for the inner optimizer.
+ strict_guidance (bool): If True, uses g_w as the target direction.
+ free_initial_h (bool): If False, initialize h at g; else initialize with a small constant.
+ rho_scale_gamma (float): Safety factor used in adaptive scaling.
+
+ Returns
+ h_tilde (torch.Tensor): Final guidance direction, shape (N, n_var) (detached).
 """
 
 x_t_h = x_t_prime.clone().detach()
@@ -1384,7 +1654,16 @@ class SPREAD:
 
 def get_training_data(self, problem, num_samples=10000):
 """
- Sample points, using LHS, based on lowest constraint violation
+ Generate a training dataset by LHS sampling within the bounds and evaluating the objectives.
+
+ Arguments
+ problem: Problem instance with bounds and evaluate.
+ num_samples (int): Number of sampled candidates.
+
+ Returns
+ (Xcand, F)
+ Xcand: sampled decision vectors
+ F: evaluated objectives.
 """
 sampler = LHS()
 # Problem bounds
@@ -1398,15 +1677,15 @@ class SPREAD:
 
 def betas_for_alpha_bar(self, T, alpha_bar, max_beta=0.999):
 """
- Create a beta schedule that discretizes the given alpha_t_bar function,
- which defines the cumulative product of (1-beta) over time from t = [0,1].
-
- :param T: the number of betas to produce.
- :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
- produces the cumulative product of (1-beta) up to that
- part of the diffusion process.
- :param max_beta: the maximum beta to use; use values lower than 1 to
- prevent singularities.
+ Discretize a continuous cumulative alpha-bar function into a beta schedule.
+
+ Arguments
+ T (int): Number of steps.
+ alpha_bar (callable): Function t in [0,1] -> ᾱ(t).
+ max_beta (float): Clamp for numerical stability.
+
+ Returns
+ torch.Tensor: Betas of shape (T,).
 """
 betas = []
 for i in range(T):
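
A standalone sketch of the discretization described above, paired with the cosine ᾱ that cosine_beta_schedule uses (a common formulation of the same idea, not a verbatim copy of the method):

    import math
    import torch

    def betas_for_alpha_bar(T, alpha_bar, max_beta=0.999):
        # beta_i = 1 - abar((i+1)/T) / abar(i/T), clamped at max_beta.
        return torch.tensor([
            min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), max_beta)
            for i in range(T)
        ])

    cosine = lambda t, s=0.008: math.cos((t + s) / (1 + s) * math.pi / 2) ** 2
    print(betas_for_alpha_bar(1000, cosine)[:3])  # small betas near t=0
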
@@ -1417,7 +1696,13 @@ class SPREAD:
 
 def cosine_beta_schedule(self, s=0.008):
 """
- Cosine schedule for beta values over timesteps.
+ Compute a cosine-based beta schedule for self.timesteps.
+
+ Arguments
+ s (float): Offset used in the cosine schedule.
+
+ Returns
+ torch.Tensor: Betas of shape (self.timesteps,).
 """
 return self.betas_for_alpha_bar(
 self.timesteps,
@@ -1425,9 +1710,36 @@ class SPREAD:
 )
 
 def l_simple_loss(self, predicted_noise, actual_noise):
+ """
+ Compute the DDPM training loss (MSE between predicted and true noise).
+
+ Arguments
+ predicted_noise (torch.Tensor): Predicted noise.
+ actual_noise (torch.Tensor): True sampled noise.
+
+ Returns
+ torch.Tensor: Scalar loss.
+ """
 return nn.MSELoss()(predicted_noise, actual_noise)
 
 def get_target_dir(self, grads, mth="mgda", x=None):
+ """
+ Compute a single descent direction from multiple objective gradients using:
+ mgda: the convex combination minimizing the norm of the weighted gradient
+ pmgda: a constrained variant using PMGDASolver and the constraints G/H
+
+ Arguments
+ grads (list[torch.Tensor]): List of gradients, each of shape (N, n_var) (or a consistent tensor shape).
+ mth (str): "mgda" or "pmgda".
+ x (torch.Tensor|None): Required for pmgda (used for constraint evaluation and weights).
+
+ Returns
+ torch.Tensor: Combined direction g, same shape as one gradient (N, n_var).
+
+ Raises
+ AssertionError: If constraints exist but mth="mgda".
+ ValueError: Unknown method.
+ """
 m = len(grads)
 if self.problem.n_ieq_constr + self.problem.n_eq_constr > 0:
 assert mth != "mgda", "MGDA not supported with constraints. Use mth ='pmgda'."
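
For two objectives the MGDA min-norm weights have a closed form, which makes the docstring's "convex combination minimizing the norm" concrete; a batched sketch (standard MGDA, independent of the package's solver):

    import torch

    def mgda_direction_2obj(g1, g2):
        # Min-norm point of the segment [g1, g2] per row: w*g1 + (1-w)*g2
        # with w = <g2 - g1, g2> / ||g1 - g2||^2 clipped to [0, 1].
        diff = g1 - g2
        w = ((g2 - g1) * g2).sum(dim=1) / (diff * diff).sum(dim=1).clamp_min(1e-12)
        w = w.clamp(0.0, 1.0).unsqueeze(1)
        return w * g1 + (1 - w) * g2

    g1, g2 = torch.randn(5, 3), torch.randn(5, 3)
    print(mgda_direction_2obj(g1, g2).shape)  # torch.Size([5, 3])
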
@@ -1505,12 +1817,20 @@ class SPREAD:
 max_backtracks=100,
 ):
 """
- Batched Armijo back-tracking line search for Multiple-Gradient-Descent (MGD).
-
+ Batched Armijo backtracking line search for multi-objective descent.
+
+ Arguments
+ x_t (torch.Tensor): Current points (N, n_var).
+ d (torch.Tensor): Descent direction (N, n_var).
+ f_old (torch.Tensor): Current objective values (N, m).
+ grads (torch.Tensor): Gradients (N, m, n_var); consumed via einsum.
+ eta_init (float|torch.Tensor): Initial step size (scalar or per point).
+ rho (float): Backtracking shrink factor.
+ c1 (float): Armijo constant.
+ max_backtracks (int): Maximum backtracking iterations.
+
 Returns
- -------
- eta : torch.Tensor, shape (N,)
- Final step sizes.
+ eta (torch.Tensor): Step sizes shaped (N, 1).
 """
 
 x = x_t.clone().detach()
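
A non-batched sketch of the Armijo rule documented above, assuming a callable f that returns the stacked objective values (illustrative only; the method itself vectorizes this over N points with einsum):

    import torch

    def armijo(f, x, d, f_old, grads, eta=1.0, rho=0.5, c1=1e-4, max_backtracks=25):
        # Shrink eta until every objective satisfies sufficient decrease:
        # f_j(x + eta*d) <= f_j(x) + c1 * eta * <grad f_j(x), d>.
        slope = grads @ d                      # (m,) directional derivatives
        for _ in range(max_backtracks):
            if (f(x + eta * d) <= f_old + c1 * eta * slope).all():
                break
            eta *= rho
        return eta
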
@@ -1556,14 +1876,14 @@ class SPREAD:
 
 ∇f_j(x_i)^T [ g_i + rho_i * delta_raw_i ] > 0 for all j.
 
- Args:
+ Arguments
 g (torch.Tensor): [n_points, d], the multi-objective "gradient"
 (which we *subtract* in the update).
 delta_raw (torch.Tensor): [n_points, d] or [1, d], the unscaled diversity/repulsion direction.
 grads (torch.Tensor): [m, n_points, d], storing ∇f_j(x_i).
 gamma (float): Safety factor in (0,1).
 
- Returns:
+ Returns
 delta_scaled (torch.Tensor): [n_points, d], scaled directions s.t.
 for all j: grads[j,i]ᵀ [g[i] + delta_scaled[i]] > 0.
 """
@@ -1603,14 +1923,13 @@ class SPREAD:
 
 def repair_bounds(self, x):
 """
- Clips a tensor x of shape [N, d] such that for each column j:
- x[:, j] is clipped to be between xl[j] and xu[j].
-
- Parameters:
- x (torch.Tensor): Input tensor of shape [N, d].
+ Clamp candidate decision vectors into the problem bounds (either per dimension or globally).
+
+ Arguments
+ x (torch.Tensor): Shape (N, n_var).
 
- Returns:
- torch.Tensor: The clipped tensor with the same shape as x.
+ Returns
+ torch.Tensor: Clipped tensor of the same shape.
 """
 
 xl, xu = self.problem.bounds()[0], self.problem.bounds()[1]
@@ -1627,9 +1946,15 @@ class SPREAD:
 sigma=1.0,
 use_sigma=False):
 """
- Computes the repulsion loss over a batch of points in the objective space.
- F_: Tensors of shape (n, m), where n is the batch size.
- Only unique pairs (i < j) are considered.
+ Compute an RBF-style repulsion loss in objective space to encourage diversity.
+ Only unique pairs (i < j) are considered.
+
+ Arguments
+ F_ (torch.Tensor): Objective values, shape (N, m).
+ sigma (float): Kernel bandwidth when use_sigma=True.
+ use_sigma (bool): Whether to use a fixed sigma or the median heuristic.
+
+ Returns
+ torch.Tensor: Scalar repulsion loss.
 """
 n = F_.shape[0]
 # Compute pairwise differences: shape [n, n, m]
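
A self-contained sketch of the fixed-sigma variant described above, with the RBF kernel summed over unique pairs in objective space (the median-heuristic branch is omitted):

    import torch

    def repulsion_loss(F_, sigma=1.0):
        # Kernel values are large for clustered points, so minimizing the
        # sum over unique pairs i < j pushes the batch apart.
        d2 = torch.cdist(F_, F_).pow(2)                          # (n, n)
        k = torch.exp(-d2 / (2 * sigma ** 2))
        iu = torch.triu_indices(F_.shape[0], F_.shape[0], offset=1)
        return k[iu[0], iu[1]].sum()

    print(repulsion_loss(torch.rand(6, 2)))
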
@@ -1649,6 +1974,16 @@ class SPREAD:
 return loss
 
 def eps_dominance(self, Obj_space, alpha=0.0):
+ """
+ Compute the indices of (epsilon-)nondominated points using the shift epsilon = alpha * min(Obj_space).
+
+ Arguments
+ Obj_space (np.ndarray): Objective values (N, m).
+ alpha (float): Epsilon scaling factor.
+
+ Returns
+ list[int]: Indices of epsilon-nondominated points.
+ """
 epsilon = alpha * np.min(Obj_space, axis=0)
 N = len(Obj_space)
 Pareto_set_idx = list(range(N))
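
The epsilon shift described above amounts to a standard nondominated filter after subtracting epsilon from the competitor set; a direct O(N²) NumPy sketch under the minimization convention (not the method's exact loop):

    import numpy as np

    def eps_nondominated(Y, alpha=0.0):
        # Keep i unless a shifted competitor weakly dominates it with at
        # least one strict improvement: Y[j] - eps <= Y[i], < somewhere.
        eps = alpha * Y.min(axis=0)
        S = Y - eps
        keep = []
        for i in range(len(Y)):
            dom = (S <= Y[i]).all(axis=1) & (S < Y[i]).any(axis=1)
            dom[i] = False                    # a point never removes itself
            if not dom.any():
                keep.append(i)
        return keep

    print(eps_nondominated(np.array([[1.0, 4.0], [2.0, 3.0], [3.0, 3.0]])))  # [0, 1]
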
@@ -1670,6 +2005,30 @@ class SPREAD:
 indx_only=False,
 p_front=None
 ):
+ """
+ Extract nondominated points (or indices), depending on the mode:
+ offline/online: uses eps_dominance
+ bayesian: uses pairwise predictions from the dominance classifier
+
+ Arguments
+ points_pred (torch.Tensor|None): Decision points (N, n_var).
+ keep_shape (bool): If True, returns a tensor of the same length N, replacing dominated
+ entries with the nearest nondominated neighbor.
+ indx_only (bool): If True, returns only the indices.
+ p_front (np.ndarray|None): If provided and points_pred is None, uses this objective array directly.
+
+ Returns
+ If indx_only=True: PS_idx (list[int])
+ Else if keep_shape=True: (pf_points, points_pred, PS_idx)
+ pf_points: tensor shaped like points_pred (length N)
+ points_pred: the original input
+ PS_idx: nondominated indices
+ Else: (pf_points[PS_idx], points_pred, PS_idx)
+
+ Raises
+ ValueError: If insufficient inputs are provided.
+ """
+
 if not indx_only and points_pred is None:
 raise ValueError("points_pred cannot be None when indx_only is False.")
 if points_pred is not None:
@@ -1720,9 +2079,13 @@ class SPREAD:
 
 def crowding_distance(self, points):
 """
- Compute crowding distances for points.
- points: Tensor of shape (N, D) in the objective space.
- Returns: Tensor of shape (N,) containing crowding distances.
+ Compute the crowding distance (NSGA-II style) for a set of objective points.
+
+ Arguments
+ points (torch.Tensor): Objective values, shape (N, m).
+
+ Returns
+ torch.Tensor: Crowding distances, shape (N,) (boundary points are set to inf per objective).
 """
 N, D = points.shape
 distances = torch.zeros(N, device=points.device)
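
For reference, the NSGA-II crowding distance that the docstring names: per objective, sort the points, force the boundary points to inf, and give interior points the normalized gap between their neighbors. A compact sketch:

    import torch

    def crowding_distance(points):
        N, D = points.shape
        dist = torch.zeros(N)
        for j in range(D):
            vals, order = points[:, j].sort()
            span = (vals[-1] - vals[0]).clamp_min(1e-12)
            dist[order[0]] = dist[order[-1]] = float("inf")
            dist[order[1:-1]] += (vals[2:] - vals[:-2]) / span
        return dist

    print(crowding_distance(torch.tensor([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]])))
    # tensor([inf, 2., inf])
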
@@ -1749,10 +2112,17 @@ class SPREAD:
 top_frac: float = 0.9
 ) -> torch.Tensor:
 """
- Selects the top `n` points from `points` based on crowding distance.
+ Select a subset of candidate decision vectors based on diversity and (in bayesian mode)
+ predicted nondominance.
+
+ Arguments
+ points (torch.Tensor): Candidate decision vectors, shape (N, n_var).
+ n (int): Number of points to return.
+ top_frac (float): Bayesian-only fallback: fraction of the best-ranked points considered
+ before crowding selection.
 
- Returns:
- torch.Tensor: The best subset of points (shape [n, D]).
+ Returns
+ torch.Tensor: Selected decision vectors, shape (min(n, N), n_var).
 """
 
 if self.mode in ["online", "offline"]:
@@ -1848,8 +2218,31 @@ class SPREAD:
 label=None,
 plot_dataset=False,
 pop=None,
- elev=30, azim=45, legend=False,
+ elev=30, azim=45, legend=False, alpha_pf_3d=0.05,
 images_store_path="./images_dir/"):
+ """
+ Save a 2D/3D scatter plot of the generated Pareto(-like) points (and, optionally, the dataset,
+ the population, and the true Pareto front).
+
+ Arguments
+ list_fi: Objectives to plot.
+ offline/online: a list of arrays [f1, f2] or [f1, f2, f3]
+ bayesian: can be a list over iterations, each element containing [f1, f2] (or 3D)
+ t (int): Reverse timestep (or MOBO iteration).
+ num_points_sample (int): Used in the filename.
+ extra: Optional true Pareto front arrays.
+ label: Optional name suffix.
+ plot_dataset (bool): If True, plots the training data (offline).
+ pop: Optional population objectives to plot.
+ elev, azim: 3D view parameters.
+ legend (bool): Add a legend.
+ alpha_pf_3d (float): Alpha for 3D Pareto front points.
+ images_store_path (str): Folder in which to save images.
+
+ Returns
+ None. Writes an image file to disk.
+ """
+
 name = (
 "spread"
 + "_"
@@ -1866,94 +2259,177 @@ class SPREAD:
 + self.mode
 )
 if label is not None:
- name += f"_{label}"
+ name += f"_{label}"
 
- if len(list_fi) > 3:
+ if self.problem.n_obj > 3:
 return None
+
+ if self.mode != "bayesian":
+ if len(list_fi) == 2:
+ fig, ax = plt.subplots()
+ if plot_dataset and (self.dataset) is not None:
+ _, Y = self.dataset
+ # Denormalize the data
+ Y = self.offline_denormalization(Y,
+ self.y_meanormin,
+ self.y_stdormax)
+ ax.scatter(Y[:, 0], Y[:, 1],
+ c="violet", s=5, alpha=1.0,
+ label="Training Data")
+
+ if extra is not None:
+ f1, f2 = extra
+ ax.scatter(f1, f2, c="yellow", s = 5, alpha=1.0,
+ label="Pareto Optimal")
+
+ if pop is not None:
+ f_pop1, f_pop2 = pop
+ ax.scatter(f_pop1, f_pop2, c="blue", s=10, alpha=1.0,
+ label="Gen Population")
+
+ f1, f2 = list_fi
+ ax.scatter(f1, f2, c="red", s=10, alpha=1.0,
+ label="Gen Optimal")
+
+ ax.set_xlabel("$f_1$", fontsize=14)
+ ax.set_ylabel("$f_2$", fontsize=14)
+ ax.set_title(f"Reverse Time Step: {t}", fontsize=14)
+ ax.text(
+ -0.17, 0.5,
+ self.problem.__class__.__name__.upper() + f"({self.mode})",
+ transform=ax.transAxes,
+ va='center',
+ ha='center',
+ rotation='vertical',
+ fontsize=20,
+ fontweight='bold'
+ )
 
- elif len(list_fi) == 2:
- fig, ax = plt.subplots()
- if plot_dataset and (self.dataset) is not None:
- _, Y = self.dataset
- # Denormalize the data
- Y = self.offline_denormalization(Y,
- self.y_meanormin,
- self.y_stdormax)
- ax.scatter(Y[:, 0], Y[:, 1],
+ elif len(list_fi) == 3:
+ fig = plt.figure()
+ ax = fig.add_subplot(111, projection="3d")
+
+ if plot_dataset and (self.dataset is not None):
+ _, Y = self.dataset
+ # Denormalize the data
+ Y = self.offline_denormalization(Y,
+ self.y_meanormin,
+ self.y_stdormax)
+ ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2],
 c="violet", s=5, alpha=1.0,
 label="Training Data")
-
- if extra is not None:
- f1, f2 = extra
- ax.scatter(f1, f2, c="yellow", s = 5, alpha=1.0,
+
+ if extra is not None:
+ f1, f2, f3 = extra
+ ax.scatter(f1, f2, f3, c="yellow", s = 5, alpha=alpha_pf_3d,
 label="Pareto Optimal")
 
- if pop is not None:
- f_pop1, f_pop2 = pop
- ax.scatter(f_pop1, f_pop2, c="blue", s=10, alpha=1.0,
+ if pop is not None:
+ f_pop1, f_pop2, f_pop3 = pop
+ ax.scatter(f_pop1, f_pop2, f_pop3, c="blue", s=10, alpha=1.0,
 label="Gen Population")
-
- f1, f2 = list_fi
- ax.scatter(f1, f2, c="red", s=10, alpha=1.0,
+
+ f1, f2, f3 = list_fi
+ ax.scatter(f1, f2, f3, c="red", s = 10, alpha=1.0,
 label="Gen Optimal")
-
- ax.set_xlabel("$f_1$", fontsize=14)
- ax.set_ylabel("$f_2$", fontsize=14)
- ax.set_title(f"Reverse Time Step: {t}", fontsize=14)
- ax.text(
- -0.15, 0.5,
- self.problem.__class__.__name__.upper() + f"({self.mode})",
- transform=ax.transAxes,
- va='center',
- ha='center',
- rotation='vertical',
- fontsize=20,
- fontweight='bold'
- )
+
+ ax.set_xlabel("$f_1$", fontsize=14)
+ ax.set_ylabel("$f_2$", fontsize=14)
+ ax.set_zlabel("$f_3$", fontsize=14)
+ ax.view_init(elev=elev, azim=azim)
+ ax.set_title(f"Reverse Time Step: {t}", fontsize=14)
+ ax.text2D(
+ -0.17, 0.5,
+ self.problem.__class__.__name__.upper() + f"({self.mode})",
+ transform=ax.transAxes,
+ va='center',
+ ha='center',
+ rotation='vertical',
+ fontsize=20,
+ fontweight='bold'
+ )
+ else:
+ # Bayesian mode
+ if self.problem.n_obj == 2:
+ fig, ax = plt.subplots()
+ if extra is not None:
+ f1, f2 = extra
+ ax.scatter(f1, f2, c="yellow", s = 5, alpha=1.0,
+ label="Pareto Optimal")
+ if pop is not None:
+ f_pop1, f_pop2 = pop
+ ax.scatter(f_pop1, f_pop2, c="green", s=10, alpha=1.0,
+ label="Init Points")
+ if list_fi is not None:
+ n = len(list_fi)
+ for i in range(len(list_fi)):
+ f1, f2 = list_fi[i]
+ # alpha for the inner color only
+ a = 1.0 / (n - i)
+ face_color = (1.0, 0.0, 0.0, a) # red with fading alpha
+ edge_color = (1.0, 0.0, 0.0, 1.0) # solid red border
+ ax.scatter(f1, f2, c="red", s=10,
+ facecolors=face_color,
+ edgecolors=edge_color,
+ linewidths=0.5,
+ marker='o',
+ label="Gen Optimal" if i==len(list_fi)-1 else None)
+
+ ax.set_xlabel("$f_1$", fontsize=14)
+ ax.set_ylabel("$f_2$", fontsize=14)
+ ax.set_title(f"Step: {t}", fontsize=14)
+ ax.text(
+ -0.17, 0.5,
+ self.problem.__class__.__name__.upper() + f"(mobo)",
+ transform=ax.transAxes,
+ va='center',
+ ha='center',
+ rotation='vertical',
+ fontsize=20,
+ fontweight='bold'
+ )
 
- elif len(list_fi) == 3:
- fig = plt.figure()
- ax = fig.add_subplot(111, projection="3d")
-
- if plot_dataset and (self.dataset is not None):
- _, Y = self.dataset
- # Denormalize the data
- Y = self.offline_denormalization(Y,
- self.y_meanormin,
- self.y_stdormax)
- ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2],
- c="violet", s=5, alpha=1.0,
- label="Training Data")
-
- if extra is not None:
- f1, f2, f3 = extra
- ax.scatter(f1, f2, f3, c="yellow", s = 5, alpha=0.05,
- label="Pareto Optimal")
-
- if pop is not None:
- f_pop1, f_pop2, f_pop3 = pop
- ax.scatter(f_pop1, f_pop2, f_pop3, c="blue", s=10, alpha=1.0,
- label="Gen Population")
+ elif self.problem.n_obj == 3:
+ fig = plt.figure()
+ ax = fig.add_subplot(111, projection="3d")
 
- f1, f2, f3 = list_fi
- ax.scatter(f1, f2, f3, c="red", s = 10, alpha=1.0,
- label="Gen Optimal")
-
- ax.set_xlabel("$f_1$", fontsize=14)
- ax.set_ylabel("$f_2$", fontsize=14)
- ax.set_zlabel("$f_3$", fontsize=14)
- ax.view_init(elev=elev, azim=azim)
- ax.set_title(f"Reverse Time Step: {t}", fontsize=14)
- ax.text(
- -0.15, 0.5,
- self.problem.__class__.__name__.upper() + f"({self.mode})",
- transform=ax.transAxes,
- va='center',
- ha='center',
- rotation='vertical',
- fontsize=20,
- fontweight='bold'
- )
+ if extra is not None:
+ f1, f2, f3 = extra
+ ax.scatter(f1, f2, f3, c="yellow", s = 5, alpha=alpha_pf_3d,
+ label="Pareto Optimal")
+ if pop is not None:
+ f_pop1, f_pop2, f_pop3 = pop
+ ax.scatter(f_pop1, f_pop2, f_pop3, c="green", s=10, alpha=1.0,
+ label="Init Points")
+ if list_fi is not None:
+ n = len(list_fi)
+ for i in range(len(list_fi)):
+ f1, f2, f3 = list_fi[i]
+ a = 1.0 / (n - i)
+ face_color = (1.0, 0.0, 0.0, a) # red with fading alpha
+ edge_color = (1.0, 0.0, 0.0, 1.0) # solid red border
+ ax.scatter(f1, f2, f3, c="red", s = 10,
+ facecolors=face_color,
+ edgecolors=edge_color,
+ linewidths=0.5,
+ marker='o',
+ label="Gen Optimal" if i==len(list_fi)-1 else None)
+
+ ax.set_xlabel("$f_1$", fontsize=14)
+ ax.set_ylabel("$f_2$", fontsize=14)
+ ax.set_zlabel("$f_3$", fontsize=14)
+ ax.view_init(elev=elev, azim=azim)
+ ax.set_title(f"Step: {t}", fontsize=14)
+ ax.text2D(
+ -0.17, 0.5,
+ self.problem.__class__.__name__.upper() + f"(mobo)",
+ transform=ax.transAxes,
+ va='center',
+ ha='center',
+ rotation='vertical',
+ fontsize=20,
+ fontweight='bold'
+ )
 
 img_dir = f"{images_store_path}/{self.problem.__class__.__name__}_{self.mode}"
 if label is not None:
@@ -1975,14 +2451,29 @@ class SPREAD:
 total_duration_s=20.0,
 first_transition_s=2.0,
 fps=30,
+ reverse=True,
 extensions=("*.jpg", "*.png", "*.jpeg", "*.bmp")):
- """Create a video from images in `image_folder`, sorted by t=... in filename.
- The first transition (first->second image) lasts `first_transition_s` seconds.
- The remaining transitions share the remaining time equally.
- The output video has total duration `total_duration_s` seconds at `fps` frames per second.
+ """
+ Create an MP4 video by blending a sequence of images (sorted by t=... in the filename). The first
+ transition gets a fixed duration; the remaining transitions share the rest equally.
+
+ Arguments
+ image_folder (str): Directory containing the images.
+ output_video (str): Output video path (MP4).
+ total_duration_s (float): Total video duration in seconds.
+ first_transition_s (float): Duration of the first transition in seconds.
+ fps (int): Frames per second.
+ reverse (bool): If True, sort images by decreasing t=....
+ extensions (tuple[str,...]): Glob patterns for image files.
+
+ Returns
+ None. Writes the video to output_video and prints a success message.
+
+ Raises
+ RuntimeError: If no images are found, fewer than two images exist, or the first image cannot be read.
 """
 
- # Collect and sort by t=... (descending)
+ # Collect and sort by t=... (descending/ascending)
 paths = []
 for ext in extensions:
 paths.extend(glob.glob(os.path.join(image_folder, ext)))
1994
2485
  m = t_pat.search(p)
1995
2486
  return int(m.group(1)) if m else -1
1996
2487
 
1997
- paths.sort(key=lambda p: t_val(p), reverse=True)
2488
+ paths.sort(key=lambda p: t_val(p), reverse=reverse)
1998
2489
  N = len(paths)
1999
2490
  if N < 2:
2000
2491
  raise RuntimeError("Need at least two images for a transition.")