moospread 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- moospread/core.py +709 -218
- moospread/tasks/bo_torch.py +24 -9
- moospread/utils/mobo_utils/learning/model_init.py +0 -4
- moospread/utils/mobo_utils/learning/utils.py +41 -17
- {moospread-0.1.3.dist-info → moospread-0.1.5.dist-info}/METADATA +17 -4
- {moospread-0.1.3.dist-info → moospread-0.1.5.dist-info}/RECORD +9 -9
- {moospread-0.1.3.dist-info → moospread-0.1.5.dist-info}/WHEEL +0 -0
- {moospread-0.1.3.dist-info → moospread-0.1.5.dist-info}/licenses/LICENSE +0 -0
- {moospread-0.1.3.dist-info → moospread-0.1.5.dist-info}/top_level.txt +0 -0
moospread/core.py
CHANGED
@@ -30,7 +30,7 @@ import matplotlib.pyplot as plt
 from mpl_toolkits.mplot3d import Axes3D
 
 from moospread.utils import *
-
+
 class SPREAD:
     def __init__(self,
                  problem,
@@ -53,8 +53,8 @@ class SPREAD:
                  train_tol: int = 100,
                  train_tol_surrogate: int = 100,
                  mobo_coef_lcb=0.1,
-                 model_dir: str = "./model_dir",
-                 proxies_store_path: str = "./proxies_dir",
+                 model_dir: str = "./model_dir/",
+                 proxies_store_path: str = "./proxies_dir/",
                  device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                  seed: int = 0,
                  offline_global_clamping: bool = False,
@@ -62,6 +62,50 @@ class SPREAD:
                  train_func_surrogate = None,
                  plot_func = None,
                  verbose: bool = True):
+        """
+        Initialize a SPREAD solver instance: configure the mode (online, offline, bayesian),
+        prepare the dataset (with normalization in offline mode), and set up the diffusion
+        model and (optionally) the surrogate model.
+
+        Arguments
+        problem: Optimization problem instance. Must define n_var, n_obj, bounds (xl, xu / bounds()),
+            and evaluate(...). In online/bayesian mode it must implement _evaluate.
+        mode (str): One of {"online", "offline", "bayesian"}.
+        model: Diffusion model. If None, defaults to DiTMOO(...).
+        surrogate_model: Optional surrogate model (used in offline and bayesian modes).
+        dataset: Optional (X, y) training set. If None, SPREAD may generate data (depending on mode).
+        xi_shift: Optional constant shift added to conditioning objective values to ensure positivity.
+        data_size (int): Number of samples to generate if dataset is not provided.
+        validation_split (float): Fraction of the dataset used for validation in diffusion training.
+        hidden_dim, num_heads, num_blocks: Diffusion model hyperparameters (for the default DiTMOO).
+        timesteps (int): Number of diffusion steps.
+        batch_size (int): Batch size for diffusion training.
+        train_lr (float): Learning rate for diffusion training.
+        train_lr_surrogate (float): Learning rate for surrogate training (offline/bayesian).
+        num_epochs (int): Max epochs for diffusion training.
+        num_epochs_surrogate (int): Max epochs for surrogate training (offline/bayesian).
+        train_tol (int): Early-stopping patience for the diffusion validation loss.
+        train_tol_surrogate (int): Early-stopping patience for surrogate training (if implemented).
+        mobo_coef_lcb (float): LCB coefficient in Bayesian mode: mean - coef * std.
+        model_dir (str): Directory to store diffusion checkpoints.
+        proxies_store_path (str): Directory to store offline proxy models.
+        device: Torch device used for training/sampling.
+        seed (int): Random seed for reproducibility.
+        offline_global_clamping (bool): If True, clamp all dimensions with a global min/max rather than per-dimension.
+        offline_normalization_method (str|None): One of {"z_score", "min_max", None}, used in offline mode for normalizing X and y.
+        train_func_surrogate: Optional user-defined surrogate training function.
+        plot_func: Optional custom plotting function.
+        verbose (bool): If True, prints parameter counts and progress messages.
+
+        Outputs
+        None (constructor). Initializes internal fields such as:
+            self.model, self.surrogate_model, self.dataset
+            offline normalization parameters: X_meanormin, X_stdormax, etc.
+            normalized bounds in offline mode (the original bounds are stored).
+
+        Raises
+        ValueError: If the mode is invalid, or if constraints on method availability are violated.
+        """
 
         self.mode = mode.lower()
         if self.mode not in ["offline", "online", "bayesian"]:
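Taken together, the new constructor docstring implies a calling pattern like the following. This is a hedged sketch, not code from the package: the DummyProblem class and everything inside it are invented here purely to satisfy the documented problem contract (n_var, n_obj, bounds, evaluate, _evaluate), and the set of constructor arguments actually required may differ.

```python
# Hypothetical usage sketch based on the documented constructor contract.
# DummyProblem and its internals are assumptions, not part of moospread.
import torch
from moospread.core import SPREAD

class DummyProblem:
    n_var, n_obj = 2, 2
    n_ieq_constr = n_eq_constr = 0
    xl, xu = torch.zeros(2), torch.ones(2)
    need_repair = False
    def bounds(self):
        return self.xl, self.xu
    def evaluate(self, x):
        # two toy objectives: distance to 0 and distance to 1
        return torch.stack([x.pow(2).sum(1), (x - 1).pow(2).sum(1)], dim=1)
    def _evaluate(self, x, out):
        out["F"] = self.evaluate(x)
    def pareto_front(self):
        return None

solver = SPREAD(problem=DummyProblem(),
                mode="online",              # or "offline" / "bayesian"
                model_dir="./model_dir/",   # trailing slash, per the 0.1.5 default
                seed=0)
```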
@@ -88,7 +132,7 @@ class SPREAD:
         self.mobo_coef_lcb = mobo_coef_lcb
 
         self.xi_shift = xi_shift
-        self.model_dir = model_dir+f"
+        self.model_dir = model_dir+f"{self.problem.__class__.__name__}_{self.mode}"
         os.makedirs(self.model_dir, exist_ok=True)
 
         self.train_func_surrogate = train_func_surrogate
@@ -198,6 +242,34 @@ class SPREAD:
                             get_constraint=False,
                             get_grad_mobo=False,
                             evaluate_true=False):
+        """
+        Evaluate the objective functions at the given points based on the current mode.
+        online: call the true problem evaluation.
+        offline: evaluate proxy models (one per objective).
+        bayesian: evaluate surrogate mean/std and optionally gradients (LCB objective).
+
+        Arguments
+        points (torch.Tensor): Decision points, shape (N, n_var)
+            (or encoded forms for discrete/sequence problems, depending on your pipeline).
+        return_as_dict (bool): Forwarded to problem.evaluate(...) in online/true evaluation.
+        return_values_of (list|None): Forwarded to problem.evaluate(...) to request "F", "G", "H", etc.
+        get_constraint (bool): If True, returns constraint info ("G", "H") along with objectives when supported.
+        get_grad_mobo (bool): Bayesian only. If True, requests gradients from the surrogate (dF, dS) and returns the gradient of the LCB.
+        evaluate_true (bool): If True, forces true problem evaluation even if the mode is not online (used for plotting in offline mode).
+
+        Returns
+        If get_constraint=True (true/online): a dict with keys typically including "F" and optionally "G", "H".
+        If mode is online or evaluate_true=True: objective tensor (N, n_obj) (or a dict if return_as_dict=True).
+        If mode is offline: tensor (N, n_obj) built by stacking each proxy output.
+        If mode is bayesian:
+            if get_grad_mobo=False: tensor (N, n_obj) containing mean - coef * std.
+            if get_grad_mobo=True: dict with:
+                "F": LCB tensor (N, n_obj)
+                "dF": list of length n_obj, each a tensor (N, n_var).
+
+        Raises
+        ValueError: If the mode is invalid.
+        """
         if evaluate_true:
             if self.problem.need_repair:
                 points = self.repair_bounds(points)
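For intuition, the bayesian branch's lower confidence bound is simply mean - coef * std per objective (optimistic for minimization). A standalone illustration in plain torch, not the package's code:

```python
import torch

def lcb(mean: torch.Tensor, std: torch.Tensor, coef: float = 0.1) -> torch.Tensor:
    # Optimistic value for minimization: subtract coef standard deviations.
    return mean - coef * std

mean = torch.tensor([[0.50, 1.20], [0.30, 0.90]])   # (N, n_obj) surrogate means
std  = torch.tensor([[0.10, 0.20], [0.05, 0.10]])   # (N, n_obj) surrogate stds
print(lcb(mean, std))   # tensor([[0.4900, 1.1800], [0.2950, 0.8900]])
```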
@@ -247,7 +319,7 @@ class SPREAD:
               use_sigma_rep=False, kernel_sigma_rep=0.01,
               iterative_plot=True, plot_period=100,
               plot_dataset=False, plot_population=False,
-              elev=30, azim=45, legend=False,
+              elev=30, azim=45, legend=False, alpha_pf_3d=0.05,
               max_backtracks=100, label=None, save_results=True,
               load_models=False,
               samples_store_path="./samples_dir/",
@@ -255,6 +327,43 @@ class SPREAD:
               n_init_mobo=100, use_escape_local_mobo=True,
               n_steps_mobo=20, spread_num_samp_mobo=25,
               batch_select_mobo=5):
+        """
+        offline/online: (optionally) train the surrogate → train the diffusion model → sample Pareto solutions.
+        bayesian: run a full MOBO loop in which SPREAD generates candidate designs and selects batches
+            via hypervolume improvement.
+
+        Arguments
+        Key arguments (offline/online path):
+        num_points_sample (int): Number of solutions to sample.
+        strict_guidance (bool): If True, direction perturbation uses the MGDA direction from a
+            single evolving target point.
+        rho_scale_gamma, nu_t, eta_init, num_inner_steps, lr_inner, free_initial_h: Guidance
+            inner-loop controls.
+        use_sigma_rep, kernel_sigma_rep: Repulsion loss configuration.
+        iterative_plot, legend, plot_period, plot_dataset, plot_population: Plot controls during sampling.
+        elev, azim, alpha_pf_3d: 3D plot controls.
+        max_backtracks: Max Armijo backtracking steps.
+        label: Extra tag appended to saved files/plots.
+        save_results (bool): Save sampled results to disk.
+        load_models (bool): If True, loads saved diffusion/proxy models instead of training.
+        samples_store_path, images_store_path: Output directories.
+        Key arguments (bayesian path):
+        n_init_mobo (int): Initial random evaluations.
+        use_escape_local_mobo (bool): Enables switching between the diffusion operator and SBX when HV stagnates.
+        n_steps_mobo (int): Number of MOBO iterations.
+        spread_num_samp_mobo (int): How many SPREAD sampling runs to aggregate into the candidate pool.
+        batch_select_mobo (int): Batch size per iteration (selected by HV improvement).
+
+        Returns
+        If mode in {offline, online}: (res_x, res_y)
+            res_x: np.ndarray of sampled Pareto(-like) decision vectors, shape (K, n_var) (K ≤ requested, due to filtering).
+            res_y: np.ndarray of evaluated objectives, shape (K, n_obj) (or None if final evaluation is disabled).
+        If mode is bayesian: ([X, Y], hv_all_value)
+            X: np.ndarray of all evaluated designs
+            Y: np.ndarray of all evaluated objectives
+            hv_all_value: list of hypervolume values per iteration.
+        """
+
         set_seed(self.seed)
         if self.mode in ["offline", "online"]:
             X, y = self.dataset
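The "selects batches via hypervolume improvement" step in the bayesian path can be illustrated with a greedy loop like the one below, which mirrors the best_hv_value / best_subset pattern visible in a later hunk. This is a hedged standalone sketch: the helper name is invented, and pymoo's HV indicator is used here on the assumption that a compatible hypervolume routine is available.

```python
import numpy as np
from pymoo.indicators.hv import HV   # assumed available; moospread already uses pymoo-style operators

def greedy_hv_batch(Y_pool, Y_evaluated, ref_point, batch):
    """Pick `batch` candidates one at a time, each maximizing hypervolume improvement."""
    hv = HV(ref_point=ref_point)
    remaining = list(range(len(Y_pool)))
    chosen = []
    Y_base = Y_evaluated.copy()
    for _ in range(batch):
        gains = [hv(np.vstack([Y_base, Y_pool[k]])) for k in remaining]
        k_best = remaining.pop(int(np.argmax(gains)))
        chosen.append(k_best)
        Y_base = np.vstack([Y_base, Y_pool[k_best]])   # commit the pick before the next round
    return chosen
```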
@@ -299,7 +408,7 @@ class SPREAD:
                 use_sigma_rep=use_sigma_rep, kernel_sigma_rep=kernel_sigma_rep,
                 iterative_plot=iterative_plot, plot_period=plot_period,
                 plot_dataset=plot_dataset, plot_population=plot_population,
-                elev=elev, azim=azim, legend=legend,
+                elev=elev, azim=azim, legend=legend, alpha_pf_3d=alpha_pf_3d,
                 max_backtracks=max_backtracks, label=label,
                 save_results=save_results,
                 samples_store_path=samples_store_path,
@@ -309,12 +418,35 @@ class SPREAD:
 
         elif self.mode == "bayesian":
             self.verbose = False
+            all_selected_batch_y = []
             hv_all_value = []
             # initialize n_init solutions
             x_init = lhs_no_evaluation(self.problem.n_var,
                                        n_init_mobo)
             x_init = torch.from_numpy(x_init).float().to(self.device)
             y_init = self.problem.evaluate(x_init).detach().cpu().numpy()
+
+            if iterative_plot:
+                list_fi_init = [y_init[:, i] for i in range(y_init.shape[1])]
+                pareto_front = None
+                if self.problem.pareto_front() is not None:
+                    pareto_front = self.problem.pareto_front()
+                    pareto_front = [pareto_front[:, i] for i in range(self.problem.n_obj)]
+                if self.plot_func is not None:
+                    self.plot_func(list_fi=None, t=0,
+                                   num_points_sample=batch_select_mobo,
+                                   extra=pareto_front,
+                                   dataset = self.dataset,
+                                   pop=list_fi_init,
+                                   elev=elev, azim=azim, legend=legend, mode=self.mode, alpha_pf_3d=alpha_pf_3d,
+                                   label=label, images_store_path=images_store_path)
+                else:
+                    self.plot_pareto_front(list_fi=None, t=0,
+                                           num_points_sample=batch_select_mobo,
+                                           extra=pareto_front,
+                                           pop=list_fi_init,
+                                           elev=elev, azim=azim, legend=legend, alpha_pf_3d=alpha_pf_3d,
+                                           label=label, images_store_path=images_store_path)
 
             # initialize dominance-classifier for non-dominance relation
             p_rel_map, s_rel_map = init_dom_rel_map(300)
@@ -394,7 +526,7 @@ class SPREAD:
                     use_sigma_rep=use_sigma_rep, kernel_sigma_rep=kernel_sigma_rep,
                     iterative_plot=iterative_plot, plot_period=plot_period,
                     plot_dataset=plot_dataset, plot_population=plot_population,
-                    elev=elev, azim=azim, legend=legend,
+                    elev=elev, azim=azim, legend=legend, alpha_pf_3d=alpha_pf_3d,
                     max_backtracks=max_backtracks, label=label,
                     samples_store_path=samples_store_path,
                     images_store_path=images_store_path,
@@ -416,6 +548,7 @@ class SPREAD:
 
                 pop_size_used = X_psl.shape[0]
 
+
                 # Mutate the new offspring
                 X_psl = pm_mutation(X_psl, [self.problem.xl.detach().cpu().numpy(),
                                             self.problem.xu.detach().cpu().numpy()])
@@ -451,7 +584,7 @@ class SPREAD:
                     if hv_value_subset > best_hv_value:
                         best_hv_value = hv_value_subset
                         best_subset = [k]
-
+
                 Y_p = np.vstack([Y_p, Y_candidate_mean[best_subset]])
                 best_subset_list.append(best_subset)
@@ -460,7 +593,10 @@ class SPREAD:
             X_candidate = X_psl
             X_new = X_candidate[best_subset_list]
             Y_new = self.problem.evaluate(torch.from_numpy(X_new).float().to(self.device)).detach().cpu().numpy()
-
+
+            list_fi_new = [Y_new[:, i] for i in range(Y_new.shape[1])]
+            all_selected_batch_y.append(list_fi_new)
+
             Y_new = torch.tensor(Y_new).to(self.device)
             X_new = torch.tensor(X_new).to(self.device)
@@ -480,6 +616,23 @@ class SPREAD:
 
             hv_text = f"{hv_value:.4e}"
             evaluated = evaluated + batch_select_mobo
+
+            if iterative_plot:
+                if self.plot_func is not None:
+                    self.plot_func(list_fi=all_selected_batch_y, t=k_iter+1,
+                                   num_points_sample=batch_select_mobo,
+                                   extra=pareto_front,
+                                   dataset = self.dataset,
+                                   pop=list_fi_init, alpha_pf_3d=alpha_pf_3d,
+                                   elev=elev, azim=azim, legend=legend, mode=self.mode,
+                                   label=label, images_store_path=images_store_path)
+                else:
+                    self.plot_pareto_front(list_fi=all_selected_batch_y, t=k_iter+1,
+                                           num_points_sample=batch_select_mobo,
+                                           extra=pareto_front,
+                                           pop=list_fi_init, alpha_pf_3d=alpha_pf_3d,
+                                           elev=elev, azim=azim, legend=legend,
+                                           label=label, images_store_path=images_store_path)
 
             #### DECISION TO SWITCH OPERATOR ####
             if use_escape_local_mobo:
@@ -558,13 +711,27 @@ class SPREAD:
             with open(outfile, "wb") as f:
                 pickle.dump(hv_all_value, f)
 
-        return X, Y, hv_all_value
+        return [X, Y], hv_all_value
 
     def train(self,
               train_dataloader,
              val_dataloader=None,
              disable_progress_bar=False):
+        """
+        Train the diffusion model (DDPM) using MSE loss on the predicted noise.
+
+        Arguments
+        train_dataloader (DataLoader): Yields (x, obj_values) batches.
+        val_dataloader (DataLoader|None): Optional validation loader for early stopping
+            and best-checkpoint saving.
+        disable_progress_bar (bool): Disables the tqdm progress bar.
+
+        Returns
+        None. Saves checkpoints:
+            checkpoint_ddpm_best.pth (if validation is enabled)
+            checkpoint_ddpm_last.pth (always)
+        """
         set_seed(self.seed)
         if self.verbose:
             print(datetime.datetime.now())
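The docstring above describes the standard DDPM objective: MSE between the true noise and the noise predicted at a random timestep. A generic, self-contained sketch of one such training step — the model(x_t, t, cond) signature here is an assumption, not moospread's actual interface:

```python
import torch
import torch.nn as nn

def ddpm_training_step(model, x0, cond, alphas_bar):
    """One generic DDPM step: noise x0 at a random t, predict the noise, take MSE."""
    B = x0.shape[0]
    t = torch.randint(0, len(alphas_bar), (B,))         # random timestep per sample
    ab = alphas_bar[t].view(B, 1)                       # cumulative alpha-bar at t
    noise = torch.randn_like(x0)
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * noise      # forward-noised sample
    return nn.MSELoss()(model(x_t, t, cond), noise)     # the "l_simple" loss
```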
@@ -731,12 +898,37 @@ class SPREAD:
                use_sigma_rep=False, kernel_sigma_rep=0.01,
                iterative_plot=True, plot_period=100,
                plot_dataset=False, plot_population=False,
-               elev=30, azim=45, legend=False,
+               elev=30, azim=45, legend=False, alpha_pf_3d=0.05,
                max_backtracks=25, label=None,
                samples_store_path="./samples_dir/",
                images_store_path="./images_dir/",
                disable_progress_bar=False,
                save_results=True, evaluate_final=True):
+        """
+        Generate num_points_sample decision vectors by reverse diffusion, with Pareto guidance
+        and optional plotting/saving.
+
+        Arguments
+        num_points_sample (int): Number of points to generate.
+        strict_guidance, rho_scale_gamma, nu_t, eta_init, num_inner_steps, lr_inner, free_initial_h: Guidance controls.
+        use_sigma_rep, kernel_sigma_rep: Repulsion controls.
+        iterative_plot, legend, plot_period, plot_dataset, plot_population: Plot controls.
+        elev, azim, alpha_pf_3d: Plot styling.
+        max_backtracks (int): Armijo backtracking limit.
+        label (str|None): Extra name tag.
+        samples_store_path, images_store_path: Output directories.
+        disable_progress_bar (bool): Disables tqdm.
+        save_results (bool): Saves *_x.npy, and optionally *_y.npy and an HV pickle.
+        evaluate_final (bool): If True, evaluates objectives on the final points and filters NaN/Inf.
+
+        Returns
+        (res_x, res_y)
+            res_x: np.ndarray, shape (K, n_var), sampled decision vectors.
+            res_y: np.ndarray, shape (K, n_obj), final objective values if evaluate_final=True, else None.
+
+        Raises
+        ValueError: If no trained diffusion checkpoint is found.
+        """
         # Set the seed
         set_seed(self.seed)
         if save_results:
@@ -798,12 +990,12 @@ class SPREAD:
                         # Denormalize the points before plotting
                         res_x_t = pf_points.clone().detach()
                         res_x_t = self.offline_denormalization(res_x_t,
-
-
+                                                               self.X_meanormin,
+                                                               self.X_stdormax)
                         res_pop = pf_population.clone().detach()
                         res_pop = self.offline_denormalization(res_pop,
-
-
+                                                               self.X_meanormin,
+                                                               self.X_stdormax)
                         norm_xl, norm_xu = self.problem.bounds()
                         xl, xu = self.problem.original_bounds
                         self.problem.xl = xl
@@ -842,14 +1034,15 @@ class SPREAD:
                                        extra=pareto_front,
                                        plot_dataset=plot_dataset,
                                        dataset = self.dataset,
-
+                                       pop=list_fi_pop, alpha_pf_3d=alpha_pf_3d,
+                                       elev=elev, azim=azim, legend=legend, mode=self.mode,
                                        label=label, images_store_path=images_store_path)
                     else:
                         self.plot_pareto_front(list_fi, self.timesteps,
                                                num_points_sample,
                                                extra=pareto_front,
                                                plot_dataset=plot_dataset,
-                                               pop=list_fi_pop,
+                                               pop=list_fi_pop, alpha_pf_3d=alpha_pf_3d,
                                                elev=elev, azim=azim, legend=legend,
                                                label=label, images_store_path=images_store_path)
@@ -927,7 +1120,7 @@ class SPREAD:
                 pf_population,
                 keep_shape=False
             )
-
+
             if prev_pf_points is not None:
                 pf_points = torch.cat((prev_pf_points, pf_points), dim=0)
             if self.mode != "bayesian":
@@ -935,6 +1128,7 @@ class SPREAD:
                     pf_points,
                     keep_shape=False,
                 )
+
             if len(pf_points) > num_points_sample:
                 pf_points = self.select_top_n_candidates(
                     pf_points,
@@ -946,76 +1140,73 @@ class SPREAD:
             prev_pf_points = pf_points
             num_optimal_points = len(pf_points)
 
-            if
-            if self.problem.
-            if
-            if self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                                list_fi_pop = self.objective_functions(res_pop,
+            if self.mode in ["online", "offline"]:
+                if iterative_plot and (not is_pass_function(self.problem._evaluate)):
+                    if self.problem.n_obj <= 3:
+                        if (t % plot_period == 0) or (t == self.timesteps - 1):
+                            if self.mode == "offline":
+                                # Denormalize the points before plotting
+                                res_x_t = pf_points.clone().detach()
+                                res_x_t = self.offline_denormalization(res_x_t,
+                                                                       self.X_meanormin,
+                                                                       self.X_stdormax)
+                                res_pop = pf_population.clone().detach()
+                                res_pop = self.offline_denormalization(res_pop,
+                                                                       self.X_meanormin,
+                                                                       self.X_stdormax)
+                                norm_xl, norm_xu = self.problem.bounds()
+                                xl, xu = self.problem.original_bounds
+                                self.problem.xl = xl
+                                self.problem.xu = xu
+                                if self.problem.is_discrete:
+                                    _, dim, n_classes = tuple(res_x_t.shape)
+                                    res_x_t = res_x_t.reshape(-1, dim, n_classes)
+                                    res_x_t = offdata_to_integers(res_x_t)
+
+                                    _, dim_pop, n_classes_pop = tuple(res_pop.shape)
+                                    res_pop = res_pop.reshape(-1, dim_pop, n_classes_pop)
+                                    res_pop = offdata_to_integers(res_pop)
+                                if self.problem.is_sequence:
+                                    res_x_t = offdata_to_integers(res_x_t)
+                                    res_pop = offdata_to_integers(res_pop)
+                                # we need to evaluate the true objective functions for plotting
+                                list_fi = self.objective_functions(res_x_t,
                                                                    evaluate_true=True).split(1, dim=1)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                            if self.plot_func is not None:
-                                self.plot_func(list_fi, t,
-                                               num_points_sample,
-                                               extra= pareto_front,
-                                               plot_dataset=plot_dataset,
-                                               dataset = self.dataset,
-                                               elev=elev, azim=azim, legend=legend,
-                                               label=label, images_store_path=images_store_path)
-                            else:
-                                self.plot_pareto_front(list_fi, t,
+                                list_fi_pop = self.objective_functions(res_pop,
+                                                                       evaluate_true=True).split(1, dim=1)
+                                list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
+                                # restore the normalized bounds
+                                self.problem.xl = norm_xl
+                                self.problem.xu = norm_xu
+                            else:
+                                list_fi = self.objective_functions(pf_points).split(1, dim=1)
+                                list_fi_pop = self.objective_functions(pf_population.detach()).split(1, dim=1)
+                                list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
+
+                            list_fi = [fi.detach().cpu().numpy() for fi in list_fi]
+                            pareto_front = None
+                            if self.problem.pareto_front() is not None:
+                                pareto_front = self.problem.pareto_front()
+                                pareto_front = [pareto_front[:, i] for i in range(self.problem.n_obj)]
+
+                            if self.plot_func is not None:
+                                self.plot_func(list_fi, t,
                                                num_points_sample,
                                                extra= pareto_front,
                                                pop=list_fi_pop if plot_population else None,
                                                plot_dataset=plot_dataset,
-
+                                               dataset = self.dataset, alpha_pf_3d=alpha_pf_3d,
+                                               elev=elev, azim=azim, legend=legend, mode=self.mode,
                                                label=label, images_store_path=images_store_path)
-
+                            else:
+                                self.plot_pareto_front(list_fi, t,
+                                                       num_points_sample,
+                                                       extra= pareto_front,
+                                                       pop=list_fi_pop if plot_population else None,
+                                                       plot_dataset=plot_dataset, alpha_pf_3d=alpha_pf_3d,
+                                                       elev=elev, azim=azim, legend=legend,
+                                                       label=label, images_store_path=images_store_path)
+
 
             x_t = x_t.detach()
             pbar.set_postfix({
@@ -1081,6 +1272,25 @@ class SPREAD:
                         lr=1e-3,
                         lr_decay=0.95,
                         n_epochs=200):
+        """
+        Train or fit the surrogate model, depending on the mode.
+        bayesian: fits a single multi-output surrogate.
+        offline: trains one proxy per objective (then loads them into a list).
+        If a user surrogate was provided, delegates to train_surrogate_user_defined.
+
+        Arguments
+        X (torch.Tensor or np.ndarray): Training inputs.
+        y (torch.Tensor or np.ndarray): Training objectives.
+        val_ratio (float): Validation split for offline proxy training.
+        batch_size (int): Batch size for proxy training.
+        lr, lr_decay, n_epochs: Proxy training hyperparameters.
+
+        Returns
+        None. Updates self.surrogate_model (a list of models in offline mode, a single model in bayesian mode).
+
+        Raises
+        ValueError: If called in a mode that does not support surrogates.
+        """
 
         set_seed(self.seed)
         self.surrogate_model = self.get_surrogate()
@@ -1143,22 +1353,35 @@ class SPREAD:
 
     def train_surrogate_user_defined(self, X, y):
         """
-        Train
-        If self.mode == "offline", the train_func should return a list of trained surrogate models,
-        one for each objective.
-        If self.mode == "bayesian", the train_func should return a single trained surrogate model for all objectives.
+        Train a user-provided surrogate via self.train_func_surrogate.
 
-
-
-
-
-
-
-
+        Arguments
+        X: Training inputs.
+        y: Training objectives.
+
+        Returns
+        None. Sets self.surrogate_model to whatever train_func_surrogate(X, y) returns:
+            offline: typically a list of per-objective models
+            bayesian: typically a single multi-objective surrogate.
         """
         self.surrogate_model = self.train_func_surrogate(X, y)
 
     def get_surrogate(self):
+        """
+        Construct a default surrogate model if the user did not provide one.
+
+        Arguments
+        None.
+
+        Returns
+        If the user provided a surrogate: returns it as-is.
+        Else:
+            bayesian: returns GaussianProcess(...)
+            offline: returns MultipleModels(...) configured for per-objective proxies.
+
+        Raises
+        ValueError: If called in a mode without surrogate support.
+        """
         if self.surrogate_given:
            return self.surrogate_model
         else:
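Per the contract documented above, a hedged sketch of a user-supplied train_func_surrogate for offline mode, using a trivial least-squares proxy per objective. The helper is illustrative only and not part of moospread; a real proxy would normally be a trained torch model:

```python
import numpy as np

def my_train_func_surrogate(X, y):
    """Offline contract: return one trained proxy per objective column of y."""
    X, y = np.asarray(X), np.asarray(y)
    A = np.c_[X, np.ones(len(X))]                 # linear features + bias
    models = []
    for j in range(y.shape[1]):
        w, *_ = np.linalg.lstsq(A, y[:, j], rcond=None)
        models.append(w)                          # stand-in for a trained model
    return models

# passed at construction: SPREAD(..., train_func_surrogate=my_train_func_surrogate)
```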
@@ -1189,6 +1412,30 @@ class SPREAD:
                         use_sigma=False, kernel_sigma=1.0, strict_guidance = False,
                         max_backtracks=100, point_n0=None, optimizer_n0=None,
                         ):
+        """
+        Perform one reverse-diffusion step plus Pareto guidance:
+        1. DDPM reverse step using the predicted noise
+        2. Compute gradients of the objectives
+        3. Compute the MGDA direction
+        4. Armijo step-size selection
+        5. Solve the inner problem for h_tilde (alignment + repulsion)
+        6. Update x_t
+
+        Arguments (main ones)
+        x_t (torch.Tensor): Current samples, shape (N, n_var).
+        num_points_sample (int): N.
+        t (int): Current timestep index.
+        beta_t, alpha_bar_t: Diffusion schedule values at t.
+        rho_scale_gamma, nu_t, eta_init, num_inner_steps, lr_inner, free_initial_h: Guidance parameters.
+        use_sigma, kernel_sigma: Repulsion settings.
+        strict_guidance (bool): Uses target-direction perturbation from point_n0.
+        max_backtracks (int): Armijo backtracking cap.
+        point_n0 (torch.Tensor|None): Target point used in strict guidance.
+        optimizer_n0 (Optimizer|None): Optimizer updating point_n0.
+
+        Returns
+        x_t (torch.Tensor): Updated samples after one SPREAD step, shape (N, n_var).
+        """
 
         # Create a tensor of timesteps with shape (num_points_sample, 1)
         t_tensor = torch.full(
@@ -1246,7 +1493,7 @@ class SPREAD:
         std_dev = torch.sqrt(beta_t)
         z = torch.randn_like(x_t) if t > 0 else 0.0  # No noise for the final step
         x_t = mean + std_dev * z
-
+
         #### Pareto Guidance step
         if self.problem.need_repair:
             x_t.data = self.repair_bounds(x_t.data.clone())
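Step 1 of the method above is the standard DDPM posterior-mean update visible in the context lines (std_dev = sqrt(beta_t), no noise at the final step). As a standalone reference, in the generic formulation rather than moospread's exact code:

```python
import torch

def ddpm_reverse_step(x_t, eps_pred, t, beta_t, alpha_bar_t):
    """Generic DDPM reverse update: posterior mean from predicted noise, plus noise if t > 0."""
    alpha_t = 1.0 - beta_t
    mean = (x_t - beta_t / (1 - alpha_bar_t).sqrt() * eps_pred) / alpha_t.sqrt()
    z = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
    return mean + beta_t.sqrt() * z        # std_dev = sqrt(beta_t), as in the diff
```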
@@ -1300,10 +1547,13 @@ class SPREAD:
                 rho_scale_gamma=rho_scale_gamma
             )
 
-        h_tilde = torch.nan_to_num(h_tilde,
-
-
-
+        # h_tilde = torch.nan_to_num(h_tilde,
+        #                            nan=torch.nanmean(h_tilde),
+        #                            posinf=0.0,
+        #                            neginf=0.0)
+        finite = torch.isfinite(h_tilde)
+        fill = h_tilde[finite].mean() if finite.any() else h_tilde.new_tensor(0.0)
+        h_tilde = torch.where(finite, h_tilde, fill)
 
         x_t = x_t - eta * h_tilde
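The new finite-mask logic replaces every NaN/±Inf entry of h_tilde with the mean of the finite entries, falling back to 0 when nothing is finite. A toy demonstration of exactly that pattern:

```python
import torch

h = torch.tensor([1.0, float("nan"), float("inf"), 3.0])
finite = torch.isfinite(h)
fill = h[finite].mean() if finite.any() else h.new_tensor(0.0)
print(torch.where(finite, h, fill))   # tensor([1., 2., 2., 3.])
```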
@@ -1326,8 +1576,28 @@ class SPREAD:
                       rho_scale_gamma=0.9
                       ):
         """
-
-
+        Inner-loop optimizer that learns/updates h to trade off:
+        alignment with the MGDA direction g (descent)
+        repulsion/diversity in objective space (kernel-based)
+
+        Arguments
+        x_t_prime (torch.Tensor): Starting points, shape (N, n_var).
+        g_x_t_prime (torch.Tensor): MGDA/PMGDA direction, shape (N, n_var).
+        grads (torch.Tensor): Objective gradients, shape (m, N, n_var).
+        g_w (torch.Tensor|None): Optional strict-guidance target direction, shape compatible
+            with (N, n_var) or (1, n_var).
+        eta (torch.Tensor): Step size, shape (N, 1) (broadcast).
+        nu_t (float): Weight of the repulsion term.
+        sigma (float): Kernel bandwidth when use_sigma=True.
+        use_sigma (bool): Whether to use a fixed sigma.
+        num_inner_steps (int): Number of optimization steps on h.
+        lr_inner (float): Learning rate for the inner optimizer.
+        strict_guidance (bool): If True, uses g_w as the target direction.
+        free_initial_h (bool): If False, initialize h at g; else initialize with a small constant.
+        rho_scale_gamma (float): Safety factor used in adaptive scaling.
+
+        Returns
+        h_tilde (torch.Tensor): Final guidance direction, shape (N, n_var) (detached).
         """
 
         x_t_h = x_t_prime.clone().detach()
@@ -1384,7 +1654,16 @@ class SPREAD:
 
     def get_training_data(self, problem, num_samples=10000):
         """
-
+        Generate a training dataset by LHS sampling within the bounds and evaluating the objectives.
+
+        Arguments
+        problem: Problem instance with bounds and evaluate.
+        num_samples (int): Number of sampled candidates.
+
+        Returns
+        (Xcand, F)
+            Xcand: sampled decision vectors
+            F: evaluated objectives.
         """
         sampler = LHS()
         # Problem bounds
@@ -1398,15 +1677,15 @@ class SPREAD:
 
     def betas_for_alpha_bar(self, T, alpha_bar, max_beta=0.999):
         """
-
-
-
-
-
-
-
-
-
+        Discretize a continuous cumulative alpha-bar function into a beta schedule.
+
+        Arguments
+        T (int): Number of steps.
+        alpha_bar (callable): Function t ∈ [0,1] -> ᾱ(t).
+        max_beta (float): Clamp for numerical stability.
+
+        Returns
+        torch.Tensor: Betas of shape (T,).
         """
         betas = []
         for i in range(T):
@@ -1417,7 +1696,13 @@ class SPREAD:
 
     def cosine_beta_schedule(self, s=0.008):
         """
-
+        Compute the cosine-based beta schedule for self.timesteps.
+
+        Arguments
+        s (float): Offset used in the cosine schedule.
+
+        Returns
+        torch.Tensor: Betas of shape (self.timesteps,).
         """
         return self.betas_for_alpha_bar(
             self.timesteps,
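For reference, the cosine schedule documented above is usually the Nichol–Dhariwal formulation, ᾱ(t) = cos²((t/T + s)/(1 + s) · π/2), discretized through betas_for_alpha_bar as β_i = 1 - ᾱ(t_{i+1})/ᾱ(t_i). A self-contained sketch of that standard formulation (assumed, not verified against moospread's exact constants):

```python
import math
import torch

def betas_for_alpha_bar(T, alpha_bar, max_beta=0.999):
    """beta_i = 1 - abar(t_{i+1}) / abar(t_i), clamped for numerical stability."""
    return torch.tensor([
        min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), max_beta)
        for i in range(T)
    ])

def cosine_alpha_bar(t, s=0.008):
    return math.cos((t + s) / (1 + s) * math.pi / 2) ** 2

betas = betas_for_alpha_bar(1000, cosine_alpha_bar)   # shape (1000,)
```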
@@ -1425,9 +1710,36 @@ class SPREAD:
         )
 
     def l_simple_loss(self, predicted_noise, actual_noise):
+        """
+        Compute the DDPM training loss (MSE between predicted and true noise).
+
+        Arguments
+        predicted_noise (torch.Tensor): Predicted noise.
+        actual_noise (torch.Tensor): True sampled noise.
+
+        Returns
+        torch.Tensor: Scalar loss.
+        """
         return nn.MSELoss()(predicted_noise, actual_noise)
 
     def get_target_dir(self, grads, mth="mgda", x=None):
+        """
+        Compute a single descent direction from multiple objective gradients using:
+        mgda: the convex combination minimizing the norm of the weighted gradient
+        pmgda: a constrained variant using PMGDASolver and the constraints G/H
+
+        Arguments
+        grads (list[torch.Tensor]): List of gradients, each of shape (N, n_var) (or a consistent tensor shape).
+        mth (str): "mgda" or "pmgda".
+        x (torch.Tensor|None): Required for pmgda (used for constraint evaluation and weights).
+
+        Returns
+        torch.Tensor: Combined direction g, same shape as one gradient (N, n_var).
+
+        Raises
+        AssertionError: If constraints exist but mth="mgda".
+        ValueError: Unknown method.
+        """
         m = len(grads)
         if self.problem.n_ieq_constr + self.problem.n_eq_constr > 0:
             assert mth != "mgda", "MGDA not supported with constraints. Use mth ='pmgda'."
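For the two-objective case, the MGDA min-norm weight has a closed form; a hedged standalone sketch that illustrates the principle (moospread's solver may handle m > 2 via a quadratic program, which this does not):

```python
import torch

def mgda_direction_2obj(g1, g2, eps=1e-12):
    """Min-norm convex combination of two per-point gradients, each of shape (N, n_var)."""
    diff = g2 - g1
    # argmin_a || a*g1 + (1-a)*g2 ||^2  =>  a = <g2 - g1, g2> / ||g2 - g1||^2, clipped to [0, 1]
    a = ((diff * g2).sum(-1) / diff.pow(2).sum(-1).clamp_min(eps)).clamp(0, 1).unsqueeze(-1)
    return a * g1 + (1 - a) * g2
```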
@@ -1505,12 +1817,20 @@ class SPREAD:
                        max_backtracks=100,
                        ):
         """
-        Batched Armijo
-
+        Batched Armijo backtracking line search for multi-objective descent.
+
+        Arguments
+        x_t (torch.Tensor): Current points (N, n_var).
+        d (torch.Tensor): Descent direction (N, n_var).
+        f_old (torch.Tensor): Current objective values (N, m).
+        grads (torch.Tensor): Gradients (N, m, n_var) (the code uses einsum accordingly).
+        eta_init (float|torch.Tensor): Initial step size (scalar or per-point).
+        rho (float): Backtracking shrink factor.
+        c1 (float): Armijo constant.
+        max_backtracks (int): Maximum backtracking iterations.
 
         Returns
-
-        eta : torch.Tensor, shape (N,)
-            Final step sizes.
+        eta (torch.Tensor): Step sizes, shaped (N, 1).
         """
 
         x = x_t.clone().detach()
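A self-contained sketch of such a batched multi-objective Armijo search, under the shapes documented above. This is a generic formulation; the sign conventions and the exact sufficient-decrease test in moospread may differ:

```python
import torch

def batched_armijo(f, x, d, f_old, grads, eta_init=1.0, rho=0.5, c1=1e-4, max_backtracks=25):
    """Shrink each point's step size until every objective satisfies the Armijo condition."""
    N = x.shape[0]
    eta = torch.full((N, 1), float(eta_init))
    dir_deriv = torch.einsum("nmd,nd->nm", grads, d)       # <grad f_j(x_i), d_i>, shape (N, m)
    for _ in range(max_backtracks):
        f_new = f(x + eta * d)                             # (N, m)
        ok = (f_new <= f_old + c1 * eta * dir_deriv).all(dim=1, keepdim=True)
        if ok.all():
            break
        eta = torch.where(ok, eta, eta * rho)              # shrink only unsatisfied points
    return eta
```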
@@ -1556,14 +1876,14 @@ class SPREAD:
 
         ∇f_j(x_i)^T [ g_i + rho_i * delta_raw_i ] > 0   for all j.
 
-
+        Arguments
         g (torch.Tensor): [n_points, d], the multi-objective "gradient"
             (which we *subtract* in the update).
         delta_raw (torch.Tensor): [n_points, d] or [1, d], the unscaled diversity/repulsion direction.
         grads (torch.Tensor): [m, n_points, d], storing ∇f_j(x_i).
         gamma (float): Safety factor in (0,1).
 
-        Returns
+        Returns
         delta_scaled (torch.Tensor): [n_points, d], scaled directions s.t.
             for all j: grads[j,i]ᵀ [g[i] + delta_scaled[i]] > 0.
         """
@@ -1603,14 +1923,13 @@ class SPREAD:
 
     def repair_bounds(self, x):
         """
-
-
-
-
-        x (torch.Tensor): Input tensor of shape [N, d].
+        Clamp candidate decision vectors into the problem bounds (either per-dimension or globally).
+
+        Arguments
+        x (torch.Tensor): Shape (N, n_var).
 
-        Returns
-        torch.Tensor:
+        Returns
+        torch.Tensor: Clipped tensor of the same shape.
         """
 
         xl, xu = self.problem.bounds()[0], self.problem.bounds()[1]
@@ -1627,9 +1946,15 @@ class SPREAD:
                        sigma=1.0,
                        use_sigma=False):
         """
-
-
-
+        Compute an RBF-style repulsion loss in objective space to encourage diversity.
+
+        Arguments
+        F_ (torch.Tensor): Objective values, shape (N, m).
+        sigma (float): Kernel bandwidth when use_sigma=True.
+        use_sigma (bool): Whether to use the fixed sigma or the median heuristic.
+
+        Returns
+        torch.Tensor: Scalar repulsion loss.
         """
         n = F_.shape[0]
         # Compute pairwise differences: shape [n, n, m]
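A standalone sketch of such a kernel repulsion term. The median-heuristic bandwidth in the use_sigma=False branch is an assumption about the package's behavior; the rest follows the documented shapes:

```python
import torch

def repulsion_loss(F_, sigma=1.0, use_sigma=False):
    """Mean RBF kernel over pairwise objective-space distances; high when points cluster."""
    diff = F_.unsqueeze(0) - F_.unsqueeze(1)        # (N, N, m) pairwise differences
    sq_dist = diff.pow(2).sum(-1)                   # (N, N) squared distances
    if not use_sigma:
        n = F_.shape[0]
        off_diag = sq_dist[~torch.eye(n, dtype=torch.bool)]
        sigma = off_diag.median().clamp_min(1e-12).sqrt()   # median heuristic (assumed)
    return torch.exp(-sq_dist / (2 * sigma ** 2)).mean()    # minimizing this spreads points out
```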
@@ -1649,6 +1974,16 @@ class SPREAD:
         return loss
 
     def eps_dominance(self, Obj_space, alpha=0.0):
+        """
+        Compute indices of (epsilon-)nondominated points using the shift epsilon = alpha * min(Obj_space).
+
+        Arguments
+        Obj_space (np.ndarray): Objective values (N, m).
+        alpha (float): Epsilon scaling factor.
+
+        Returns
+        list[int]: Indices of epsilon-nondominated points.
+        """
         epsilon = alpha * np.min(Obj_space, axis=0)
         N = len(Obj_space)
         Pareto_set_idx = list(range(N))
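One common formulation of that filter, as a hedged standalone sketch (moospread's exact epsilon convention and tie-breaking may differ):

```python
import numpy as np

def eps_nondominated(F, alpha=0.0):
    """Indices i such that no epsilon-shifted row dominates F[i] (minimization)."""
    eps = alpha * F.min(axis=0)
    keep = []
    for i in range(len(F)):
        others = np.delete(F, i, axis=0)
        dominated = np.any(
            np.all(others - eps <= F[i], axis=1) &   # no worse in every objective
            np.any(others - eps <  F[i], axis=1)     # strictly better in at least one
        )
        if not dominated:
            keep.append(i)
    return keep

F = np.array([[1.0, 2.0], [2.0, 1.0], [2.0, 2.0]])
print(eps_nondominated(F))   # [0, 1]: the third point is dominated by both others
```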
@@ -1670,6 +2005,30 @@ class SPREAD:
                       indx_only=False,
                       p_front=None
                       ):
+        """
+        Extract nondominated points (or indices) depending on the mode:
+        offline/online: uses eps_dominance
+        bayesian: uses the dominance classifier's pairwise predictions
+
+        Arguments
+        points_pred (torch.Tensor|None): Decision points (N, n_var).
+        keep_shape (bool): If True, returns a tensor of the same length N by replacing dominated
+            entries with their nearest nondominated neighbor.
+        indx_only (bool): If True, returns only the indices.
+        p_front (np.ndarray|None): If provided and points_pred is None, uses this objective array directly.
+
+        Returns
+        If indx_only=True: PS_idx (list[int])
+        Else if keep_shape=True: (pf_points, points_pred, PS_idx)
+            pf_points: tensor shaped like points_pred (length N)
+            points_pred: the original input
+            PS_idx: nondominated indices
+        Else: (pf_points[PS_idx], points_pred, PS_idx)
+
+        Raises
+        ValueError: If insufficient inputs are provided.
+        """
+
         if not indx_only and points_pred is None:
             raise ValueError("points_pred cannot be None when indx_only is False.")
         if points_pred is not None:
@@ -1720,9 +2079,13 @@ class SPREAD:
 
     def crowding_distance(self, points):
         """
-        Compute crowding
-
-
+        Compute the crowding distance (NSGA-II style) for a set of objective points.
+
+        Arguments
+        points (torch.Tensor): Objective values, shape (N, m).
+
+        Returns
+        torch.Tensor: Crowding distances, shape (N,) (boundary points are set to inf per objective).
         """
         N, D = points.shape
         distances = torch.zeros(N, device=points.device)
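As a reference, the classic NSGA-II crowding distance the docstring describes can be computed as below (standalone sketch; the normalization constant is an assumption):

```python
import torch

def crowding_distance(points):
    """Sum of normalized neighbor gaps per objective; boundary points get inf."""
    N, M = points.shape
    d = torch.zeros(N, device=points.device)
    for m in range(M):
        order = torch.argsort(points[:, m])
        f = points[order, m]
        d[order[0]] = d[order[-1]] = float("inf")          # boundary points
        span = (f[-1] - f[0]).clamp_min(1e-12)
        d[order[1:-1]] += (f[2:] - f[:-2]) / span          # gap between each point's neighbors
    return d
```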
@@ -1749,10 +2112,17 @@ class SPREAD:
             top_frac: float = 0.9
     ) -> torch.Tensor:
         """
-
+        Select a subset of candidate decision vectors based on diversity and (in bayesian mode)
+        predicted nondominance.
+
+        Arguments
+        points (torch.Tensor): Candidate decision vectors, shape (N, n_var).
+        n (int): Number of points to return.
+        top_frac (float): Bayesian-only fallback: fraction of the best-ranked points considered
+            before crowding selection.
 
-        Returns
-        torch.Tensor:
+        Returns
+        torch.Tensor: Selected decision vectors, shape (min(n, N), n_var).
         """
 
         if self.mode in ["online", "offline"]:
@@ -1848,8 +2218,31 @@ class SPREAD:
                           label=None,
                           plot_dataset=False,
                           pop=None,
-                          elev=30, azim=45, legend=False,
+                          elev=30, azim=45, legend=False, alpha_pf_3d=0.05,
                           images_store_path="./images_dir/"):
+        """
+        Save a 2D/3D scatter plot of the generated Pareto(-like) points (and optionally the dataset,
+        the population, and the true Pareto front).
+
+        Arguments
+        list_fi: Objectives to plot.
+            offline/online: list of arrays [f1, f2] or [f1, f2, f3]
+            bayesian: can be a list over iterations, each element containing [f1, f2] (or 3D)
+        t (int): Reverse timestep (or MOBO iteration).
+        num_points_sample (int): Used in the filename.
+        extra: Optional true Pareto front arrays.
+        label: Optional name suffix.
+        plot_dataset (bool): If True, plots the training data (offline).
+        pop: Optional population objectives to plot.
+        elev, azim: 3D view parameters.
+        legend (bool): Add a legend.
+        alpha_pf_3d (float): Alpha for 3D Pareto-front points.
+        images_store_path (str): Folder in which to save images.
+
+        Returns
+        None. Writes an image file to disk.
+        """
+
        name = (
            "spread"
            + "_"
@@ -1866,94 +2259,177 @@ class SPREAD:
             + self.mode
         )
         if label is not None:
-            name += f"_{label}"
+            name += f"_{label}"
 
-        if
+        if self.problem.n_obj > 3:
             return None
+
+        if self.mode != "bayesian":
+            if len(list_fi) == 2:
+                fig, ax = plt.subplots()
+                if plot_dataset and (self.dataset) is not None:
+                    _, Y = self.dataset
+                    # Denormalize the data
+                    Y = self.offline_denormalization(Y,
+                                                     self.y_meanormin,
+                                                     self.y_stdormax)
+                    ax.scatter(Y[:, 0], Y[:, 1],
+                               c="violet", s=5, alpha=1.0,
+                               label="Training Data")
+
+                if extra is not None:
+                    f1, f2 = extra
+                    ax.scatter(f1, f2, c="yellow", s = 5, alpha=1.0,
+                               label="Pareto Optimal")
+
+                if pop is not None:
+                    f_pop1, f_pop2 = pop
+                    ax.scatter(f_pop1, f_pop2, c="blue", s=10, alpha=1.0,
+                               label="Gen Population")
+
+                f1, f2 = list_fi
+                ax.scatter(f1, f2, c="red", s=10, alpha=1.0,
+                           label="Gen Optimal")
+
+                ax.set_xlabel("$f_1$", fontsize=14)
+                ax.set_ylabel("$f_2$", fontsize=14)
+                ax.set_title(f"Reverse Time Step: {t}", fontsize=14)
+                ax.text(
+                    -0.17, 0.5,
+                    self.problem.__class__.__name__.upper() + f"({self.mode})",
+                    transform=ax.transAxes,
+                    va='center',
+                    ha='center',
+                    rotation='vertical',
+                    fontsize=20,
+                    fontweight='bold'
+                )
 
-
-
-
-
-
-
-
-
-
+            elif len(list_fi) == 3:
+                fig = plt.figure()
+                ax = fig.add_subplot(111, projection="3d")
+
+                if plot_dataset and (self.dataset is not None):
+                    _, Y = self.dataset
+                    # Denormalize the data
+                    Y = self.offline_denormalization(Y,
+                                                     self.y_meanormin,
+                                                     self.y_stdormax)
+                    ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2],
                                c="violet", s=5, alpha=1.0,
                                label="Training Data")
-
-
-
-
+
+                if extra is not None:
+                    f1, f2, f3 = extra
+                    ax.scatter(f1, f2, f3, c="yellow", s = 5, alpha=alpha_pf_3d,
                                label="Pareto Optimal")
 
-
-
-
+                if pop is not None:
+                    f_pop1, f_pop2, f_pop3 = pop
+                    ax.scatter(f_pop1, f_pop2, f_pop3, c="blue", s=10, alpha=1.0,
                                label="Gen Population")
-
-
-
+
+                f1, f2, f3 = list_fi
+                ax.scatter(f1, f2, f3, c="red", s = 10, alpha=1.0,
                            label="Gen Optimal")
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+                ax.set_xlabel("$f_1$", fontsize=14)
+                ax.set_ylabel("$f_2$", fontsize=14)
+                ax.set_zlabel("$f_3$", fontsize=14)
+                ax.view_init(elev=elev, azim=azim)
+                ax.set_title(f"Reverse Time Step: {t}", fontsize=14)
+                ax.text2D(
+                    -0.17, 0.5,
+                    self.problem.__class__.__name__.upper() + f"({self.mode})",
+                    transform=ax.transAxes,
+                    va='center',
+                    ha='center',
+                    rotation='vertical',
+                    fontsize=20,
+                    fontweight='bold'
+                )
+        else:
+            # Bayesian mode
+            if self.problem.n_obj == 2:
+                fig, ax = plt.subplots()
+                if extra is not None:
+                    f1, f2 = extra
+                    ax.scatter(f1, f2, c="yellow", s = 5, alpha=1.0,
+                               label="Pareto Optimal")
+                if pop is not None:
+                    f_pop1, f_pop2 = pop
+                    ax.scatter(f_pop1, f_pop2, c="green", s=10, alpha=1.0,
+                               label="Init Points")
+                if list_fi is not None:
+                    n = len(list_fi)
+                    for i in range(len(list_fi)):
+                        f1, f2 = list_fi[i]
+                        # alpha for the inner color only
+                        a = 1.0 / (n - i)
+                        face_color = (1.0, 0.0, 0.0, a)   # red with fading alpha
+                        edge_color = (1.0, 0.0, 0.0, 1.0) # solid red border
+                        ax.scatter(f1, f2, c="red", s=10,
+                                   facecolors=face_color,
+                                   edgecolors=edge_color,
+                                   linewidths=0.5,
+                                   marker='o',
+                                   label="Gen Optimal" if i==len(list_fi)-1 else None)
+
+                ax.set_xlabel("$f_1$", fontsize=14)
+                ax.set_ylabel("$f_2$", fontsize=14)
+                ax.set_title(f"Step: {t}", fontsize=14)
+                ax.text(
+                    -0.17, 0.5,
+                    self.problem.__class__.__name__.upper() + f"(mobo)",
+                    transform=ax.transAxes,
+                    va='center',
+                    ha='center',
+                    rotation='vertical',
+                    fontsize=20,
+                    fontweight='bold'
+                )
 
-
-
-
-
-            if plot_dataset and (self.dataset is not None):
-                _, Y = self.dataset
-                # Denormalize the data
-                Y = self.offline_denormalization(Y,
-                                                 self.y_meanormin,
-                                                 self.y_stdormax)
-                ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2],
-                           c="violet", s=5, alpha=1.0,
-                           label="Training Data")
-
-            if extra is not None:
-                f1, f2, f3 = extra
-                ax.scatter(f1, f2, f3, c="yellow", s = 5, alpha=0.05,
-                           label="Pareto Optimal")
-
-            if pop is not None:
-                f_pop1, f_pop2, f_pop3 = pop
-                ax.scatter(f_pop1, f_pop2, f_pop3, c="blue", s=10, alpha=1.0,
-                           label="Gen Population")
+            elif self.problem.n_obj == 3:
+                fig = plt.figure()
+                ax = fig.add_subplot(111, projection="3d")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                if extra is not None:
+                    f1, f2, f3 = extra
+                    ax.scatter(f1, f2, f3, c="yellow", s = 5, alpha=alpha_pf_3d,
+                               label="Pareto Optimal")
+                if pop is not None:
+                    f_pop1, f_pop2, f_pop3 = pop
+                    ax.scatter(f_pop1, f_pop2, f_pop3, c="green", s=10, alpha=1.0,
+                               label="Init Points")
+                if list_fi is not None:
+                    n = len(list_fi)
+                    for i in range(len(list_fi)):
+                        f1, f2, f3 = list_fi[i]
+                        a = 1.0 / (n - i)
+                        face_color = (1.0, 0.0, 0.0, a)   # red with fading alpha
+                        edge_color = (1.0, 0.0, 0.0, 1.0) # solid red border
+                        ax.scatter(f1, f2, f3, c="red", s = 10,
+                                   facecolors=face_color,
+                                   edgecolors=edge_color,
+                                   linewidths=0.5,
+                                   marker='o',
+                                   label="Gen Optimal" if i==len(list_fi)-1 else None)
+
+                ax.set_xlabel("$f_1$", fontsize=14)
+                ax.set_ylabel("$f_2$", fontsize=14)
+                ax.set_zlabel("$f_3$", fontsize=14)
+                ax.view_init(elev=elev, azim=azim)
+                ax.set_title(f"Step: {t}", fontsize=14)
+                ax.text2D(
+                    -0.17, 0.5,
+                    self.problem.__class__.__name__.upper() + f"(mobo)",
+                    transform=ax.transAxes,
+                    va='center',
+                    ha='center',
+                    rotation='vertical',
+                    fontsize=20,
+                    fontweight='bold'
+                )
 
         img_dir = f"{images_store_path}/{self.problem.__class__.__name__}_{self.mode}"
         if label is not None:
@@ -1975,14 +2451,29 @@ class SPREAD:
                         total_duration_s=20.0,
                         first_transition_s=2.0,
                         fps=30,
+                        reverse=True,
                         extensions=("*.jpg", "*.png", "*.jpeg", "*.bmp")):
-        """
-
-
-
+        """
+        Create an MP4 video by blending a sequence of images (sorted by the t=... in the filename).
+        The first transition gets a fixed duration; the remaining transitions share the rest.
+
+        Arguments
+        image_folder (str): Directory containing the images.
+        output_video (str): Output video path (MP4).
+        total_duration_s (float): Total video duration in seconds.
+        first_transition_s (float): Duration of the first transition in seconds.
+        fps (int): Frames per second.
+        reverse (bool): If True, sort images by decreasing t=....
+        extensions (tuple[str, ...]): Glob patterns for image files.
+
+        Returns
+        None. Writes the video to output_video and prints a success message.
+
+        Raises
+        RuntimeError: If no images are found, fewer than two images exist, or the first image cannot be read.
+        """
 
-        # Collect and sort by t=... (descending)
+        # Collect and sort by t=... (descending/ascending)
         paths = []
         for ext in extensions:
             paths.extend(glob.glob(os.path.join(image_folder, ext)))
@@ -1994,7 +2485,7 @@ class SPREAD:
             m = t_pat.search(p)
             return int(m.group(1)) if m else -1
 
-        paths.sort(key=lambda p: t_val(p), reverse=
+        paths.sort(key=lambda p: t_val(p), reverse=reverse)
         N = len(paths)
         if N < 2:
             raise RuntimeError("Need at least two images for a transition.")