moospread 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
moospread/core.py CHANGED
@@ -9,6 +9,11 @@ from torch.utils.data import DataLoader, TensorDataset
  import copy
  import math
  import json
+
+ import cv2
+ import re
+ import glob
+
  import os
  import pickle
  from tqdm import tqdm
@@ -25,7 +30,7 @@ import matplotlib.pyplot as plt
  from mpl_toolkits.mplot3d import Axes3D

  from moospread.utils import *
-
+ print
  class SPREAD:
  def __init__(self,
  problem,
@@ -54,7 +59,6 @@ class SPREAD:
  seed: int = 0,
  offline_global_clamping: bool = False,
  offline_normalization_method: str = "z_score",
- dominance_classifier = None,
  train_func_surrogate = None,
  plot_func = None,
  verbose: bool = True):
@@ -83,14 +87,12 @@ class SPREAD:
  if self.mode == "bayesian":
  self.mobo_coef_lcb = mobo_coef_lcb

-
  self.xi_shift = xi_shift
- self.model_dir = model_dir
+ self.model_dir = model_dir+f"/{self.problem.__class__.__name__}_{self.mode}"
  os.makedirs(self.model_dir, exist_ok=True)

  self.train_func_surrogate = train_func_surrogate
  self.plot_func = plot_func
- self.dominance_classifier = dominance_classifier

  self.seed = seed
  # Set the seed for reproducibility
@@ -132,7 +134,7 @@ class SPREAD:
  print("Training dataset generated.")

  if self.mode == "offline":
- assert offline_normalization_method in ["z_score", "min_max", "none"], "Invalid normalization method"
+ assert offline_normalization_method in ["z_score", "min_max", None], "Invalid normalization method"
  if offline_normalization_method == "z_score":
  self.offline_normalization = offdata_z_score_normalize
  self.offline_denormalization = offdata_z_score_denormalize
@@ -142,9 +144,23 @@ class SPREAD:
  else:
  self.offline_normalization = lambda x: x
  self.offline_denormalization = lambda x: x
+
+ if self.mode == "offline":
+ X, y = self.dataset
+ X = X.clone().detach()
+ y = y.clone().detach()
+ if self.problem.is_discrete:
+ X = offdata_to_logits(X)
+ _, n_dim, n_classes = tuple(X.shape)
+ X = X.reshape(-1, n_dim * n_classes)
+ if self.problem.is_sequence:
+ X = offdata_to_logits(X)
+ # For usual cases, we normalize the inputs
+ # and outputs with z-score normalization
+ X, self.X_meanormin, self.X_stdormax = self.offline_normalization(X)
+ y, self.y_meanormin, self.y_stdormax = self.offline_normalization(y)
+ self.dataset = (X, y)

- self.X_meanormin, self.y_meanormin = 0, 0
- self.X_stdormax, self.y_stdormax = 1, 1
  if self.problem.has_bounds():
  xl = self.problem.xl
  xu = self.problem.xu
@@ -156,7 +172,9 @@ class SPREAD:
  xl = xl.reshape(-1, n_dim * n_classes)
  if self.problem.is_sequence:
  xl = offdata_to_logits(xl)
- xl, _, _ = self.offline_normalization(xl)
+ xl, _, _ = self.offline_normalization(xl,
+ self.X_meanormin,
+ self.X_stdormax)
  # xu
  if self.problem.is_discrete:
  xu = offdata_to_logits(xu)
@@ -164,7 +182,9 @@ class SPREAD:
  xu = xu.reshape(-1, n_dim * n_classes)
  if self.problem.is_sequence:
  xu = offdata_to_logits(xu)
- xu, _, _ = self.offline_normalization(xu)
+ xu, _, _ = self.offline_normalization(xu,
+ self.X_meanormin,
+ self.X_stdormax)
  ## Set the normalized bounds
  self.problem.xl = xl
  self.problem.xu = xu
@@ -225,7 +245,9 @@ class SPREAD:
  num_inner_steps=10, lr_inner=0.9,
  free_initial_h=True,
  use_sigma_rep=False, kernel_sigma_rep=0.01,
- iterative_plot=True, plot_period=100,
+ iterative_plot=True, plot_period=100,
+ plot_dataset=False, plot_population=False,
+ elev=30, azim=45, legend=False,
  max_backtracks=100, label=None, save_results=True,
  load_models=False,
  samples_store_path="./samples_dir/",
@@ -238,19 +260,6 @@ class SPREAD:
  X, y = self.dataset

  if self.mode == "offline":
- X = X.clone().detach()
- y = y.clone().detach()
- if self.problem.is_discrete:
- X = offdata_to_logits(X)
- _, n_dim, n_classes = tuple(X.shape)
- X = X.reshape(-1, n_dim * n_classes)
- if self.problem.is_sequence:
- X = offdata_to_logits(X)
- # For usual cases, we normalize the inputs
- # and outputs with z-score normalization
- X, self.X_meanormin, self.X_stdormax = self.offline_normalization(X)
- y, self.y_meanormin, self.y_stdormax = self.offline_normalization(y)
-
  #### SURROGATE MODEL TRAINING ####
  if not load_models or self.surrogate_given:
  self.train_surrogate(X, y)
@@ -288,7 +297,9 @@ class SPREAD:
  num_inner_steps=num_inner_steps, lr_inner=lr_inner,
  free_initial_h=free_initial_h,
  use_sigma_rep=use_sigma_rep, kernel_sigma_rep=kernel_sigma_rep,
- iterative_plot=iterative_plot, plot_period=plot_period,
+ iterative_plot=iterative_plot, plot_period=plot_period,
+ plot_dataset=plot_dataset, plot_population=plot_population,
+ elev=elev, azim=azim, legend=legend,
  max_backtracks=max_backtracks, label=label,
  save_results=save_results,
  samples_store_path=samples_store_path,
@@ -382,6 +393,8 @@ class SPREAD:
  free_initial_h=free_initial_h,
  use_sigma_rep=use_sigma_rep, kernel_sigma_rep=kernel_sigma_rep,
  iterative_plot=iterative_plot, plot_period=plot_period,
+ plot_dataset=plot_dataset, plot_population=plot_population,
+ elev=elev, azim=azim, legend=legend,
  max_backtracks=max_backtracks, label=label,
  samples_store_path=samples_store_path,
  images_store_path=images_store_path,
@@ -716,7 +729,9 @@ class SPREAD:
  num_inner_steps=10, lr_inner=1e-4,
  free_initial_h=True,
  use_sigma_rep=False, kernel_sigma_rep=0.01,
- iterative_plot=True, plot_period=100,
+ iterative_plot=True, plot_period=100,
+ plot_dataset=False, plot_population=False,
+ elev=30, azim=45, legend=False,
  max_backtracks=25, label=None,
  samples_store_path="./samples_dir/",
  images_store_path="./images_dir/",
@@ -764,13 +779,14 @@ class SPREAD:
  x_t = torch.rand((num_points_sample, self.problem.n_var)) # in [0, 1]
  x_t = self.problem.bounds()[0] + (self.problem.bounds()[1] - self.problem.bounds()[0]) * x_t # scale to bounds
  if self.mode == "offline":
- x_t, _, _ = self.offline_normalization(x_t)
+ x_t, _, _ = self.offline_normalization(x_t,
+ self.X_meanormin,
+ self.X_stdormax)
  x_t = x_t.to(self.device)
  x_t.requires_grad = True
  if self.problem.need_repair:
  x_t.data = self.repair_bounds(x_t.data.clone())
-
- if self.mode == "online":
+ if self.mode in ["online", "offline"]:
  if iterative_plot and (not is_pass_function(self.problem._evaluate)):
  if self.problem.n_obj <= 3:
  pf_population = x_t.detach()
@@ -778,9 +794,45 @@ class SPREAD:
  pf_population,
  keep_shape=False
  )
- list_fi = self.objective_functions(pf_points).split(1, dim=1)
- list_fi = [fi.detach().cpu().numpy() for fi in list_fi]
+ if self.mode == "offline":
+ # Denormalize the points before plotting
+ res_x_t = pf_points.clone().detach()
+ res_x_t = self.offline_denormalization(res_x_t,
+ self.X_meanormin,
+ self.X_stdormax)
+ res_pop = pf_population.clone().detach()
+ res_pop = self.offline_denormalization(res_pop,
+ self.X_meanormin,
+ self.X_stdormax)
+ norm_xl, norm_xu = self.problem.bounds()
+ xl, xu = self.problem.original_bounds
+ self.problem.xl = xl
+ self.problem.xu = xu
+ if self.problem.is_discrete:
+ _, dim, n_classes = tuple(res_x_t.shape)
+ res_x_t = res_x_t.reshape(-1, dim, n_classes)
+ res_x_t = offdata_to_integers(res_x_t)
+
+ _, dim_pop, n_classes_pop = tuple(res_pop.shape)
+ res_pop = res_pop.reshape(-1, dim_pop, n_classes_pop)
+ res_pop = offdata_to_integers(res_pop)
+ if self.problem.is_sequence:
+ res_x_t = offdata_to_integers(res_x_t)
+ res_pop = offdata_to_integers(res_pop)
+ # we need to evaluate the true objective functions for plotting
+ list_fi = self.objective_functions(res_x_t,
+ evaluate_true=True).split(1, dim=1)
+ list_fi_pop = self.objective_functions(res_pop,
+ evaluate_true=True).split(1, dim=1)
+ # restore the normalized bounds
+ self.problem.xl = norm_xl
+ self.problem.xu = norm_xu
+ else:
+ list_fi = self.objective_functions(pf_points).split(1, dim=1)
+ list_fi_pop = self.objective_functions(pf_population.detach()).split(1, dim=1)
  pareto_front = None
+ list_fi = [fi.detach().cpu().numpy() for fi in list_fi]
+ list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
  if self.problem.pareto_front() is not None:
  pareto_front = self.problem.pareto_front()
  pareto_front = [pareto_front[:, i] for i in range(self.problem.n_obj)]
@@ -788,17 +840,18 @@ class SPREAD:
  self.plot_func(list_fi, self.timesteps,
  num_points_sample,
  extra=pareto_front,
+ plot_dataset=plot_dataset,
+ dataset = self.dataset,
+ elev=elev, azim=azim, legend=legend,
  label=label, images_store_path=images_store_path)
  else:
- plot_dataset = True if self.mode == "offline" else False
- list_fi_pop = self.objective_functions(pf_population.detach()).split(1, dim=1)
- list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
  self.plot_pareto_front(list_fi, self.timesteps,
  num_points_sample,
  extra=pareto_front,
  plot_dataset=plot_dataset,
- pop=list_fi_pop,
- label=label, images_store_path=images_store_path)
+ pop=list_fi_pop,
+ elev=elev, azim=azim, legend=legend,
+ label=label, images_store_path=images_store_path)

  prev_pf_points = None
  num_optimal_points = 0
@@ -812,7 +865,9 @@ class SPREAD:
  point_n0 = torch.rand((1, self.problem.n_var)) # in [0, 1]
  point_n0 = self.problem.bounds()[0] + (self.problem.bounds()[1] - self.problem.bounds()[0]) * point_n0 # scale to bounds
  if self.mode == "offline":
- point_n0, _, _ = self.offline_normalization(point_n0)
+ point_n0, _, _ = self.offline_normalization(point_n0,
+ self.X_meanormin,
+ self.X_stdormax)
  point_n0 = point_n0.to(self.device)
  point_n0.requires_grad = True
  if self.problem.need_repair:
@@ -868,7 +923,6 @@ class SPREAD:
  else:
  pf_population = copy.deepcopy(x_t.detach())

- # print("Number of points before selection:", len(pf_population))
  pf_points, _, _ = self.get_non_dominated_points(
  pf_population,
  keep_shape=False
@@ -881,7 +935,6 @@ class SPREAD:
  pf_points,
  keep_shape=False,
  )
- # print("Number of non-dominated points before selection:", len(non_dom_points))
  if len(pf_points) > num_points_sample:
  pf_points = self.select_top_n_candidates(
  pf_points,
@@ -902,6 +955,10 @@ class SPREAD:
  res_x_t = self.offline_denormalization(res_x_t,
  self.X_meanormin,
  self.X_stdormax)
+ res_pop = pf_population.clone().detach()
+ res_pop = self.offline_denormalization(res_pop,
+ self.X_meanormin,
+ self.X_stdormax)
  norm_xl, norm_xu = self.problem.bounds()
  xl, xu = self.problem.original_bounds
  self.problem.xl = xl
@@ -910,40 +967,55 @@ class SPREAD:
  _, dim, n_classes = tuple(res_x_t.shape)
  res_x_t = res_x_t.reshape(-1, dim, n_classes)
  res_x_t = offdata_to_integers(res_x_t)
+
+ _, dim_pop, n_classes_pop = tuple(res_pop.shape)
+ res_pop = res_pop.reshape(-1, dim_pop, n_classes_pop)
+ res_pop = offdata_to_integers(res_pop)
  if self.problem.is_sequence:
  res_x_t = offdata_to_integers(res_x_t)
+ res_pop = offdata_to_integers(res_pop)
  # we need to evaluate the true objective functions for plotting
- list_fi = self.objective_functions(pf_points, evaluate_true=True).split(1, dim=1)
+ list_fi = self.objective_functions(res_x_t,
+ evaluate_true=True).split(1, dim=1)
+ list_fi_pop = self.objective_functions(res_pop,
+ evaluate_true=True).split(1, dim=1)
+ list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
  # restore the normalized bounds
  self.problem.xl = norm_xl
  self.problem.xu = norm_xu
  elif self.mode == "bayesian":
  # we need to evaluate the true objective functions for plotting
  list_fi = self.objective_functions(pf_points, evaluate_true=True).split(1, dim=1)
+ list_fi_pop = self.objective_functions(pf_population.detach(), evaluate_true=True).split(1, dim=1)
+ list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
  else:
  list_fi = self.objective_functions(pf_points).split(1, dim=1)
+ list_fi_pop = self.objective_functions(pf_population.detach()).split(1, dim=1)
+ list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]

  list_fi = [fi.detach().cpu().numpy() for fi in list_fi]
  pareto_front = None
  if self.problem.pareto_front() is not None:
  pareto_front = self.problem.pareto_front()
  pareto_front = [pareto_front[:, i] for i in range(self.problem.n_obj)]
+
  if self.plot_func is not None:
  self.plot_func(list_fi, t,
  num_points_sample,
  extra= pareto_front,
+ plot_dataset=plot_dataset,
+ dataset = self.dataset,
+ elev=elev, azim=azim, legend=legend,
  label=label, images_store_path=images_store_path)
  else:
- plot_dataset = True if self.mode == "offline" else False
- list_fi_pop = self.objective_functions(pf_population.detach()).split(1, dim=1)
- list_fi_pop = [fi.detach().cpu().numpy() for fi in list_fi_pop]
  self.plot_pareto_front(list_fi, t,
  num_points_sample,
  extra= pareto_front,
- pop=list_fi_pop,
+ pop=list_fi_pop if plot_population else None,
  plot_dataset=plot_dataset,
+ elev=elev, azim=azim, legend=legend,
  label=label, images_store_path=images_store_path)
-
+

  x_t = x_t.detach()
  pbar.set_postfix({
@@ -1402,19 +1474,14 @@ class SPREAD:
  n_prob=grads[0].shape[0], n_obj=self.problem.n_obj,
  verbose=False)
  y = self.objective_functions(x, get_constraint=True)
- # print("y.keys():", y.keys())
  if "H" in y:
  pre_h_vals = y["H"].sum(dim=1)
  constraint_mtd='eq'
- # print("pre_h_vals.shape:", pre_h_vals.shape)
  elif "G" in y:
- # print("pre_h_vals.shape before:", y["G"].shape)
  pre_h_vals = y["G"].sum(dim=1)
  print("pre_h_vals.shape:", pre_h_vals.shape)
  constraint_mtd='ineq'
- # print("pre_h_vals.shape:", pre_h_vals.shape)
  y = y["F"]
- # print("pre_h_vals.shape:", pre_h_vals.shape)
  alphas = SOLVER.compute_weights(x, y, pre_h_vals=pre_h_vals,
  constraint_mtd=constraint_mtd)
  alphas = torch.nan_to_num(alphas, nan=torch.nanmean(alphas),
@@ -1624,7 +1691,6 @@ class SPREAD:
  i for i in range(N)
  if not any(label_matrix[j, i] == 2 for j in range(N))
  ]
- # print(f"Number of non-dominated points: {len(PS_idx)} out of {N}")
  else:
  raise ValueError(f"Mode {self.mode} not recognized!")

@@ -1699,7 +1765,6 @@ class SPREAD:
  final_idx = top_indices[torch.randperm(top_indices.size(0))]
  else:
  N = points.shape[0]
- # print(f"In selection, N={N}, n={n}")
  # 1) Predict dominance
  label_matrix, conf_matrix = nn_predict_dom_intra(points.detach().cpu().numpy(),
  self.dominance_classifier,
@@ -1709,7 +1774,6 @@ class SPREAD:
  i for i in range(N)
  if not any(label_matrix[j, i] == 2 for j in range(N))
  ]
- # print(f"(selection) Number of non-dominated points: {len(nondom_inds)} out of {N}")

  # --- CASE A: too many non‑dominated → pick top-n by crowding ---
  if len(nondom_inds) > n:
@@ -1784,6 +1848,7 @@ class SPREAD:
  label=None,
  plot_dataset=False,
  pop=None,
+ elev=30, azim=45, legend=False,
  images_store_path="./images_dir/"):
  name = (
  "spread"
@@ -1807,63 +1872,88 @@ class SPREAD:
  return None

  elif len(list_fi) == 2:
+ fig, ax = plt.subplots()
+ if plot_dataset and (self.dataset) is not None:
+ _, Y = self.dataset
+ # Denormalize the data
+ Y = self.offline_denormalization(Y,
+ self.y_meanormin,
+ self.y_stdormax)
+ ax.scatter(Y[:, 0], Y[:, 1],
+ c="violet", s=5, alpha=1.0,
+ label="Training Data")
+
  if extra is not None:
  f1, f2 = extra
- plt.scatter(f1, f2, c="yellow", s = 5, alpha=1.0,)
- # label="Pareto optimal points")
+ ax.scatter(f1, f2, c="yellow", s = 5, alpha=1.0,
+ label="Pareto Optimal")

  if pop is not None:
  f_pop1, f_pop2 = pop
- plt.scatter(f_pop1, f_pop2, c="blue", s=10, alpha=1.0,)
- # label="Population points")
+ ax.scatter(f_pop1, f_pop2, c="blue", s=10, alpha=1.0,
+ label="Gen Population")

  f1, f2 = list_fi
- plt.scatter(f1, f2, c="red", s=10, alpha=1.0,)
- # label="Generated optimal points")
- if plot_dataset and (self.dataset) is not None:
- _, Y = self.dataset
- Y = self.offline_denormalization(Y,
- self.y_meanormin,
- self.y_stdormax)
- plt.scatter(Y[:, 0], Y[:, 1],
- c="blue", s=5, alpha=1.0,)
- # label="Training data points")
-
- plt.xlabel("$f_1$", fontsize=14)
- plt.ylabel("$f_2$", fontsize=14)
- plt.title(f"Reverse Time Step: {t}", fontsize=14)
+ ax.scatter(f1, f2, c="red", s=10, alpha=1.0,
+ label="Gen Optimal")
+
+ ax.set_xlabel("$f_1$", fontsize=14)
+ ax.set_ylabel("$f_2$", fontsize=14)
+ ax.set_title(f"Reverse Time Step: {t}", fontsize=14)
+ ax.text(
+ -0.17, 0.5,
+ self.problem.__class__.__name__.upper() + f"({self.mode})",
+ transform=ax.transAxes,
+ va='center',
+ ha='center',
+ rotation='vertical',
+ fontsize=20,
+ fontweight='bold'
+ )

  elif len(list_fi) == 3:
  fig = plt.figure()
  ax = fig.add_subplot(111, projection="3d")

+ if plot_dataset and (self.dataset is not None):
+ _, Y = self.dataset
+ # Denormalize the data
+ Y = self.offline_denormalization(Y,
+ self.y_meanormin,
+ self.y_stdormax)
+ ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2],
+ c="violet", s=5, alpha=1.0,
+ label="Training Data")
+
  if extra is not None:
  f1, f2, f3 = extra
- ax.scatter(f1, f2, f3, c="yellow", s = 5, alpha=0.05,)
- # label="Pareto optimal points")
+ ax.scatter(f1, f2, f3, c="yellow", s = 5, alpha=0.05,
+ label="Pareto Optimal")

  if pop is not None:
  f_pop1, f_pop2, f_pop3 = pop
- ax.scatter(f_pop1, f_pop2, f_pop3, c="blue", s=10, alpha=1.0,)
- # label="Population points")
+ ax.scatter(f_pop1, f_pop2, f_pop3, c="blue", s=10, alpha=1.0,
+ label="Gen Population")

  f1, f2, f3 = list_fi
- ax.scatter(f1, f2, f3, c="red", s = 10, alpha=1.0,)
- # label="Generated optimal points")
-
- if plot_dataset and (self.dataset is not None):
- _, Y = self.dataset
- Y = self.offline_denormalization(Y,
- self.y_meanormin,
- self.y_stdormax)
- ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2],
- c="blue", s=5, alpha=1.0,)
- # label="Training data points")
+ ax.scatter(f1, f2, f3, c="red", s = 10, alpha=1.0,
+ label="Gen Optimal")
+
  ax.set_xlabel("$f_1$", fontsize=14)
  ax.set_ylabel("$f_2$", fontsize=14)
  ax.set_zlabel("$f_3$", fontsize=14)
- ax.view_init(elev=30, azim=45)
+ ax.view_init(elev=elev, azim=azim)
  ax.set_title(f"Reverse Time Step: {t}", fontsize=14)
+ ax.text2D(
+ -0.17, 0.5,
+ self.problem.__class__.__name__.upper() + f"({self.mode})",
+ transform=ax.transAxes,
+ va='center',
+ ha='center',
+ rotation='vertical',
+ fontsize=20,
+ fontweight='bold'
+ )

  img_dir = f"{images_store_path}/{self.problem.__class__.__name__}_{self.mode}"
  if label is not None:
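The `ax.text(...)` / `ax.text2D(...)` additions above pin the problem name in axes-fraction coordinates (`transform=ax.transAxes`), so the rotated label sits just left of the y-axis regardless of the data limits. The same trick in isolation (a minimal sketch; the label string and sine data are illustrative, not from the package):

```python
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(np.linspace(0, 1, 50), np.sin(np.linspace(0, 6, 50)))
# With transform=ax.transAxes, (x, y) are fractions of the axes box:
# (-0.17, 0.5) lands just outside the left spine, vertically centered.
ax.text(-0.17, 0.5, "ZDT2(online)", transform=ax.transAxes,
        va="center", ha="center", rotation="vertical",
        fontsize=20, fontweight="bold")
plt.savefig("labeled.jpg", bbox_inches="tight")  # "tight" keeps the outside label in frame
```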
@@ -1871,7 +1961,8 @@ class SPREAD:
  if not os.path.exists(img_dir):
  os.makedirs(img_dir)

- # plt.legend(fontsize=12)
+ if legend:
+ plt.legend(fontsize=12)

  plt.savefig(
  f"{img_dir}/{name}.jpg",
@@ -1879,3 +1970,96 @@
  bbox_inches="tight",
  )
  plt.close()
+
+ def create_video(self, image_folder, output_video,
+ total_duration_s=20.0,
+ first_transition_s=2.0,
+ fps=30,
+ extensions=("*.jpg", "*.png", "*.jpeg", "*.bmp")):
+ """Create a video from images in `image_folder`, sorted by t=... in filename.
+ The first transition (first->second image) lasts `first_transition_s` seconds.
+ The remaining transitions share the remaining time equally.
+ The output video has total duration `total_duration_s` seconds at `fps` frames per second.
+ """
+
+ # Collect and sort by t=... (descending)
+ paths = []
+ for ext in extensions:
+ paths.extend(glob.glob(os.path.join(image_folder, ext)))
+ if not paths:
+ raise RuntimeError(f"No images found in {image_folder}")
+
+ t_pat = re.compile(r"t=(\d+)")
+ def t_val(p):
+ m = t_pat.search(p)
+ return int(m.group(1)) if m else -1
+
+ paths.sort(key=lambda p: t_val(p), reverse=True)
+ N = len(paths)
+ if N < 2:
+ raise RuntimeError("Need at least two images for a transition.")
+
+ # Read first to get size
+ first_img = cv2.imread(paths[0])
+ if first_img is None:
+ raise RuntimeError(f"Cannot read: {paths[0]}")
+ h, w = first_img.shape[:2]
+ size = (w, h)
+
+ # Prepare writer
+ os.makedirs(os.path.dirname(output_video), exist_ok=True)
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+ writer = cv2.VideoWriter(output_video, fourcc, float(fps), size)
+
+ # Frame budget
+ F_total = int(round(total_duration_s * fps))
+ F_first = int(round(first_transition_s * fps))
+ F_first = max(0, min(F_first, F_total)) # clamp just in case
+ self.frames_written = 0
+
+ def write_transition(img1, img2, d):
+ """Blend img1->img2 over d frames; d==0 means hard cut; d==1 means single frame of img2."""
+ if d <= 0:
+ return
+ img1 = cv2.resize(img1, size)
+ img2 = cv2.resize(img2, size)
+ if d == 1:
+ writer.write(img2); self.frames_written += 1
+ return
+ for j in range(d):
+ alpha = j / (d - 1) # includes endpoints: j=0 -> img1, j=d-1 -> img2
+ frame = cv2.addWeighted(img1, 1 - alpha, img2, alpha, 0)
+ writer.write(frame); self.frames_written += 1
+
+ # Load all images (resized) once to avoid repeated disk I/O
+ imgs = []
+ for p in paths:
+ im = cv2.imread(p)
+ if im is None:
+ im = first_img.copy()
+ imgs.append(cv2.resize(im, size))
+
+ # 1) First transition: fixed 2 seconds (or less if total is tiny)
+ write_transition(imgs[0], imgs[1], F_first)
+
+ # 2) Remaining transitions share remaining frames
+ remaining_frames = F_total - self.frames_written
+ remaining_transitions = max(0, N - 2)
+
+ if remaining_transitions > 0:
+ # Distribute remaining frames across the remaining transitions.
+ # Some transitions may get 0 or 1 frame (hard/near-hard cut) if time is tight.
+ base = 0 if remaining_transitions == 0 else remaining_frames // remaining_transitions
+ extra = 0 if remaining_transitions == 0 else remaining_frames % remaining_transitions
+ # Ensure we don't exceed the total budget:
+ for i in range(remaining_transitions):
+ d = base + (1 if i < extra else 0)
+ write_transition(imgs[i + 1], imgs[i + 2], d)
+
+ # 3) If we still have spare frames (due to rounding), hold last frame
+ last = imgs[-1]
+ while self.frames_written < F_total:
+ writer.write(last); self.frames_written += 1
+
+ writer.release()
+ print(f"✅ Saved: {output_video} | duration={total_duration_s}s, fps={fps}, frames={F_total}")
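The frame budget in `create_video` is simple arithmetic: at `total_duration_s=20.0` and `fps=30`, `F_total = 600` frames; the first cross-fade gets `F_first = 60` frames (2 s); with, say, 12 images, the remaining 10 transitions split the leftover 540 frames (54 each). A usage sketch (the folder layout and run name are illustrative, assuming the iterative plots were saved with `t=...` in their filenames as the sorter expects):

```python
# Assuming `solver` is a trained SPREAD instance whose iterative plots were
# saved under images_store_path as "...t=1000.jpg", "...t=900.jpg", ...
solver.create_video(
    image_folder="./images_dir/ZDT2_online",      # hypothetical run directory
    output_video="./videos_dir/zdt2_online.mp4",
    total_duration_s=20.0,    # 20 s x 30 fps = 600-frame budget
    first_transition_s=2.0,   # first cross-fade gets 60 of those frames
    fps=30,
)
```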
moospread/tasks/__init__.py CHANGED
@@ -1,4 +1,5 @@
  from moospread.tasks.dtlz_torch import DTLZ, DTLZ2, DTLZ4, DTLZ7
  from moospread.tasks.zdt_torch import ZDT, ZDT1, ZDT2, ZDT3
  from moospread.tasks.re_torch import RE21, RE33, RE34, RE37, RE41
- from moospread.tasks.mw_torch import MW7
+ from moospread.tasks.bo_torch import BraninCurrin, Penicillin, VehicleSafety
+ from moospread.tasks.mw_torch import MW7
moospread/tasks/bo_torch.py ADDED
@@ -0,0 +1,300 @@
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from pymoo.problems import get_problem
+ from moospread.problem import PymooProblemTorch
+ from importlib import resources
+ import os
+ import numpy as np
+ import math
+
+
+ ######## RE Problems ########
+ # Note: For the sake of differentiability, we will use the strict bounds:
+ # - Lower bound: xl + 1e-6 (instead of xl)
+ # - Upper bound: xu - 1e-6 (instead of xu)
+ # ref_point: The default reference points are suitable for the "bayesian" mode.
+
+ "Adapted from: https://botorch.readthedocs.io/en/stable/_modules/botorch/test_functions/multi_objective.html"
+
+ class BraninCurrin(PymooProblemTorch):
+ r"""Two objective problem composed of the Branin and Currin functions.
+
+ Branin (rescaled):
+
+ f(x) = (
+ 15*x_1 - 5.1 * (15 * x_0 - 5) ** 2 / (4 * pi ** 2) + 5 * (15 * x_0 - 5)
+ / pi - 5
+ ) ** 2 + (10 - 10 / (8 * pi)) * cos(15 * x_0 - 5))
+
+ Currin:
+
+ f(x) = (1 - exp(-1 / (2 * x_1))) * (
+ 2300 * x_0 ** 3 + 1900 * x_0 ** 2 + 2092 * x_0 + 60
+ ) / 100 * x_0 ** 3 + 500 * x_0 ** 2 + 4 * x_0 + 20
+
+ """
+
+ def __init__(self, path = None, ref_point=None, negate=True, **kwargs):
+ super().__init__(n_var=2, n_obj=2,
+ xl=torch.zeros(2, dtype=torch.float) + 1e-6,
+ xu=torch.ones(2, dtype=torch.float) - 1e-6,
+ vtype=float, **kwargs)
+ if ref_point is None:
+ self.ref_point = [18.0, 6.0]
+ else:
+ self.ref_point = ref_point
+
+ self.path = path
+ self.max_hv = 59.36011874867746 # this is approximated using NSGA-II
+
+ def _calc_pareto_front(self, n_pareto_points: int = 100) -> torch.Tensor:
+ if self.path is not None:
+ front = np.loadtxt(self.path)
+ return torch.from_numpy(front).to(self.device)
+ else:
+ return None
+
+ def _branin(self, X: Tensor) -> Tensor:
+ t1 = (
+ X[..., 1]
+ - 5.1 / (4 * math.pi**2) * X[..., 0].pow(2)
+ + 5 / math.pi * X[..., 0]
+ - 6
+ )
+ t2 = 10 * (1 - 1 / (8 * math.pi)) * torch.cos(X[..., 0])
+ return t1.pow(2) + t2 + 10
+
+ def _rescaled_branin(self, X: Tensor) -> Tensor:
+ # return to Branin bounds
+ x_0 = 15 * X[..., 0] - 5
+ x_1 = 15 * X[..., 1]
+ return self._branin(torch.stack([x_0, x_1], dim=-1))
+
+ @staticmethod
+ def _currin(X: Tensor) -> Tensor:
+ x_0 = X[..., 0]
+ x_1 = X[..., 1]
+ factor1 = 1 - torch.exp(-1 / (2 * x_1))
+ numer = 2300 * x_0.pow(3) + 1900 * x_0.pow(2) + 2092 * x_0 + 60
+ denom = 100 * x_0.pow(3) + 500 * x_0.pow(2) + 4 * x_0 + 20
+ return factor1 * numer / denom
+
+ def _evaluate(self, X: torch.Tensor, out: dict, *args, **kwargs) -> None:
+ # branin rescaled with inputs to [0,1]^2
+ branin = self._rescaled_branin(X=X)
+ currin = self._currin(X=X)
+ out["F"] = torch.stack([branin, currin], dim=-1)
+
+ class Penicillin(PymooProblemTorch):
+ r"""A penicillin production simulator from [Liang2021]_.
+
+ This implementation is adapted from
+ https://github.com/HarryQL/TuRBO-Penicillin.
+
+ The goal is to maximize the penicillin yield while minimizing
+ time to ferment and the CO2 byproduct.
+
+ The function is defined for minimization of all objectives.
+
+ The reference point was set using the `infer_reference_point` heuristic
+ on the Pareto frontier obtained via NSGA-II.
+ """
+
+ Y_xs = 0.45
+ Y_ps = 0.90
+ K_1 = 10 ** (-10)
+ K_2 = 7 * 10 ** (-5)
+ m_X = 0.014
+ alpha_1 = 0.143
+ alpha_2 = 4 * 10 ** (-7)
+ alpha_3 = 10 ** (-4)
+ mu_X = 0.092
+ K_X = 0.15
+ mu_p = 0.005
+ K_p = 0.0002
+ K_I = 0.10
+ K = 0.04
+ k_g = 7.0 * 10**3
+ E_g = 5100.0
+ k_d = 10.0**33
+ E_d = 50000.0
+ lambd = 2.5 * 10 ** (-4)
+ T_v = 273.0 # Kelvin
+ T_o = 373.0
+ R = 1.9872 # CAL/(MOL K)
+ V_max = 180.0
+
+ def __init__(self, path = None, ref_point=None, negate=True, **kwargs):
+ super().__init__(n_var=7, n_obj=3,
+ xl=torch.tensor([60.0, 0.05, 293.0, 0.05, 0.01, 500.0, 5.0], dtype=torch.float) + 1e-6,
+ xu=torch.tensor([120.0, 18.0, 303.0, 18.0, 0.5, 700.0, 6.5], dtype=torch.float) - 1e-6,
+ vtype=float, **kwargs)
+ if ref_point is None:
+ self.ref_point = [25.935, 57.612, 935.5]
+ else:
+ self.ref_point = ref_point
+
+ self.path = path
+ self.max_hv = 2183455.909507436
+
+ def _calc_pareto_front(self, n_pareto_points: int = 100) -> torch.Tensor:
+ if self.path is not None:
+ front = np.loadtxt(self.path)
+ return torch.from_numpy(front).to(self.device)
+ else:
+ return None
+
+ @classmethod
+ def penicillin_vectorized(cls, X_input: Tensor) -> Tensor:
+ r"""Penicillin simulator, simplified and vectorized.
+
+ The 7 input parameters are (in order): culture volume, biomass
+ concentration, temperature, glucose concentration, substrate feed
+ rate, substrate feed concentration, and H+ concentration.
+
+ Args:
+ X_input: A `n x 7`-dim tensor of inputs.
+
+ Returns:
+ An `n x 3`-dim tensor of (negative) penicillin yield, CO2 and time.
+ """
+ V, X, T, S, F, s_f, H_ = torch.split(X_input, 1, -1)
+ P, CO2 = torch.zeros_like(V), torch.zeros_like(V)
+ H = torch.full_like(H_, 10.0).pow(-H_)
+
+ active = torch.ones_like(V).bool()
+ t_tensor = torch.full_like(V, 2500)
+
+ for t in range(1, 2501):
+ if active.sum() == 0:
+ break
+ F_loss = (
+ V[active]
+ * cls.lambd
+ * torch.special.expm1(5 * ((T[active] - cls.T_o) / (cls.T_v - cls.T_o)))
+ )
+ dV_dt = F[active] - F_loss
+ mu = (
+ (cls.mu_X / (1 + cls.K_1 / H[active] + H[active] / cls.K_2))
+ * (S[active] / (cls.K_X * X[active] + S[active]))
+ * (
+ (cls.k_g * torch.exp(-cls.E_g / (cls.R * T[active])))
+ - (cls.k_d * torch.exp(-cls.E_d / (cls.R * T[active])))
+ )
+ )
+ dX_dt = mu * X[active] - (X[active] / V[active]) * dV_dt
+ mu_pp = cls.mu_p * (
+ S[active] / (cls.K_p + S[active] + S[active].pow(2) / cls.K_I)
+ )
+ dS_dt = (
+ -(mu / cls.Y_xs) * X[active]
+ - (mu_pp / cls.Y_ps) * X[active]
+ - cls.m_X * X[active]
+ + F[active] * s_f[active] / V[active]
+ - (S[active] / V[active]) * dV_dt
+ )
+ dP_dt = (
+ (mu_pp * X[active])
+ - cls.K * P[active]
+ - (P[active] / V[active]) * dV_dt
+ )
+ dCO2_dt = cls.alpha_1 * dX_dt + cls.alpha_2 * X[active] + cls.alpha_3
+
+ # UPDATE
+ P[active] = P[active] + dP_dt # Penicillin concentration
+ V[active] = V[active] + dV_dt # Culture medium volume
+ X[active] = X[active] + dX_dt # Biomass concentration
+ S[active] = S[active] + dS_dt # Glucose concentration
+ CO2[active] = CO2[active] + dCO2_dt # CO2 concentration
+
+ # Update active indices
+ full_dpdt = torch.ones_like(P)
+ full_dpdt[active] = dP_dt
+ inactive = (V > cls.V_max) + (S < 0) + (full_dpdt < 10e-12)
+ t_tensor[inactive] = torch.minimum(
+ t_tensor[inactive], torch.full_like(t_tensor[inactive], t)
+ )
+ active[inactive] = 0
+
+ return torch.stack([-P, CO2, t_tensor], dim=-1)
+
+ def _evaluate(self, X: torch.Tensor, out: dict, *args, **kwargs) -> None:
+ # This uses in-place operations. Hence, the clone is to avoid modifying
+ # the original X in-place.
+ out["F"] = self.penicillin_vectorized(X.view(-1, self.dim).clone()).view(
+ *X.shape[:-1], self.num_objectives
+ )
+
+
+ class VehicleSafety(PymooProblemTorch):
+ r"""Optimize Vehicle crash-worthiness.
+
+ See [Tanabe2020]_ for details.
+
+ The reference point is 1.1 * the nadir point from
+ approximate front provided by [Tanabe2020]_.
+
+ The maximum hypervolume is computed using the approximate
+ pareto front from [Tanabe2020]_.
+ """
+
+ def __init__(self, path = None, ref_point=None, negate=True, **kwargs):
+ super().__init__(n_var=5, n_obj=3,
+ xl=torch.ones(5, dtype=torch.float) + 1e-6,
+ xu=3*torch.ones(5, dtype=torch.float) - 1e-6,
+ vtype=float, **kwargs)
+ if ref_point is None:
+ self.ref_point = [1864.72022, 11.81993945, 0.2903999384]
+ else:
+ self.ref_point = ref_point
+
+ self.path = path
+ self.max_hv = 246.81607081187002
+
+ def _calc_pareto_front(self, n_pareto_points: int = 100) -> torch.Tensor:
+ if self.path is not None:
+ front = np.loadtxt(self.path)
+ return torch.from_numpy(front).to(self.device)
+ else:
+ return None
+
+
+ def _evaluate(self, X: torch.Tensor, out: dict, *args, **kwargs) -> None:
+ X1, X2, X3, X4, X5 = torch.split(X, 1, -1)
+ f1 = (
+ 1640.2823
+ + 2.3573285 * X1
+ + 2.3220035 * X2
+ + 4.5688768 * X3
+ + 7.7213633 * X4
+ + 4.4559504 * X5
+ )
+ f2 = (
+ 6.5856
+ + 1.15 * X1
+ - 1.0427 * X2
+ + 0.9738 * X3
+ + 0.8364 * X4
+ - 0.3695 * X1 * X4
+ + 0.0861 * X1 * X5
+ + 0.3628 * X2 * X4
+ - 0.1106 * X1.pow(2)
+ - 0.3437 * X3.pow(2)
+ + 0.1764 * X4.pow(2)
+ )
+ f3 = (
+ -0.0551
+ + 0.0181 * X1
+ + 0.1024 * X2
+ + 0.0421 * X3
+ - 0.0073 * X1 * X2
+ + 0.024 * X2 * X3
+ - 0.0118 * X2 * X4
+ - 0.0204 * X3 * X4
+ - 0.008 * X3 * X5
+ - 0.0241 * X2.pow(2)
+ + 0.0109 * X4.pow(2)
+ )
+ f_X = torch.cat([f1, f2, f3], dim=-1)
+ out["F"] = f_X
moospread/tasks/re_torch.py CHANGED
@@ -2,6 +2,7 @@ import torch
  import torch.nn as nn
  from pymoo.problems import get_problem
  from moospread.problem import PymooProblemTorch
+ from importlib import resources
  import os
  import numpy as np

@@ -9,7 +10,7 @@ import numpy as np
  # Note: For the sake of differentiability, we will use the strict bounds:
  # - Lower bound: xl + 1e-6 (instead of xl)
  # - Upper bound: xu - 1e-6 (instead of xu)
- # ref_point: The default reference points are suitable for the "online" mode.
+ # ref_point: The default reference points are suitable for the "offline" mode.

  class RE21(PymooProblemTorch):
  def __init__(self, path = None, ref_point=None, **kwargs):
@@ -26,14 +27,19 @@ class RE21(PymooProblemTorch):
  vtype=float, **kwargs)
  if ref_point is None:
  self.ref_point = [3144.44, 0.05]
+ # self.ref_point = [3144.44, 0.05] # suitable for "online" mode
  else:
  self.ref_point = ref_point

  self.path = path

  def _calc_pareto_front(self, n_pareto_points: int = 100) -> torch.Tensor:
- assert self.path is not None, "Path to Pareto front file not specified."
- front = np.loadtxt(self.path)
+ try:
+ front = np.loadtxt(self.path)
+ except:
+ front = np.loadtxt(resources.files("moospread.tasks.pf_re_tasks").joinpath(f"reference_points_{self.__class__.__name__}.dat"))
+ else:
+ assert self.path is not None, "Path to Pareto front file not specified."
  return torch.from_numpy(front).to(self.device)

  def _evaluate(self, X: torch.Tensor, out: dict, *args, **kwargs) -> None:
@@ -81,15 +87,20 @@ class RE33(PymooProblemTorch):
  xu=torch.tensor([80.0, 110.0, 3000.0, 20.0])- 1e-6,
  vtype=float, **kwargs)
  if ref_point is None:
- self.ref_point = [5.01, 9.84, 4.30]
+ self.ref_point = [8.01, 8.84, 2343.30]
+ # self.ref_point = [5.01, 9.84, 4.30] # suitable for "online" mode
  else:
  self.ref_point = ref_point

  self.path = path

  def _calc_pareto_front(self, n_pareto_points: int = 100) -> torch.Tensor:
- assert self.path is not None, "Path to Pareto front file not specified."
- front = np.loadtxt(self.path)
+ try:
+ front = np.loadtxt(self.path)
+ except:
+ front = np.loadtxt(resources.files("moospread.tasks.pf_re_tasks").joinpath(f"reference_points_{self.__class__.__name__}.dat"))
+ else:
+ assert self.path is not None, "Path to Pareto front file not specified."
  return torch.from_numpy(front).to(self.device)

@@ -153,15 +164,20 @@ class RE34(PymooProblemTorch):
  xu=torch.full((5,), 3.0) - 1e-6,
  vtype=float, **kwargs)
  if ref_point is None:
- self.ref_point = [1.86472022e+03, 1.18199394e+01, 2.90399938e-01]
+ self.ref_point = [1702.52, 11.68, 0.26]
+ # self.ref_point = [1.86472022e+03, 1.18199394e+01, 2.90399938e-01] # suitable for "online" mode
  else:
  self.ref_point = ref_point

  self.path = path

  def _calc_pareto_front(self, n_pareto_points: int = 100) -> torch.Tensor:
- assert self.path is not None, "Path to Pareto front file not specified."
- front = np.loadtxt(self.path)
+ try:
+ front = np.loadtxt(self.path)
+ except:
+ front = np.loadtxt(resources.files("moospread.tasks.pf_re_tasks").joinpath(f"reference_points_{self.__class__.__name__}.dat"))
+ else:
+ assert self.path is not None, "Path to Pareto front file not specified."
  return torch.from_numpy(front).to(self.device)

  def _evaluate(self, x: torch.Tensor, out: dict, *args, **kwargs) -> None:
@@ -233,15 +249,20 @@ class RE37(PymooProblemTorch):
  xu=torch.ones(4) - 1e-6,
  vtype=float, **kwargs)
  if ref_point is None:
- self.ref_point = [1.1022, 1.20726899, 1.20318656]
+ self.ref_point = [0.99, 0.96, 0.99]
+ # self.ref_point = [1.1022, 1.20726899, 1.20318656] # suitable for "online" mode
  else:
  self.ref_point = ref_point

  self.path = path

  def _calc_pareto_front(self, n_pareto_points: int = 100) -> torch.Tensor:
- assert self.path is not None, "Path to Pareto front file not specified."
- front = np.loadtxt(self.path)
+ try:
+ front = np.loadtxt(self.path)
+ except:
+ front = np.loadtxt(resources.files("moospread.tasks.pf_re_tasks").joinpath(f"reference_points_{self.__class__.__name__}.dat"))
+ else:
+ assert self.path is not None, "Path to Pareto front file not specified."
  return torch.from_numpy(front).to(self.device)

  def _evaluate(self, x: torch.Tensor, out: dict, *args, **kwargs) -> None:
@@ -325,7 +346,8 @@ class RE41(PymooProblemTorch):
  xu=torch.tensor([1.5, 1.35, 1.5, 1.5, 2.625, 1.2, 1.2], dtype=torch.float) - 1e-6,
  vtype=float, **kwargs)
  if ref_point is None:
- self.ref_point = [47.04480682, 4.86997366, 14.40049127, 10.3941957 ]
+ self.ref_point = [42.65, 4.43, 13.08, 13.45]
+ # self.ref_point = [47.04480682, 4.86997366, 14.40049127, 10.3941957 ] # suitable for "online" mode
  else:
  self.ref_point = ref_point

@@ -391,4 +413,4 @@ class RE41(PymooProblemTorch):
  zero = torch.tensor(0.0, dtype=X.dtype, device=X.device)
  g = torch.where(g < 0, -g, zero)
  f4 = g.sum(dim=-1, keepdim=True)
- out["F"] = torch.cat([f1, f2, f3, f4], dim=-1)
+ out["F"] = torch.cat([f1, f2, f3, f4], dim=-1)
moospread/utils/offline_utils/handle_task.py CHANGED
@@ -1,6 +1,6 @@
  import torch
  import torch.nn.functional as F
- from typing import List, Tuple
+ from typing import List, Optional, Tuple


  def one_hot(a: torch.Tensor, num_classes: int) -> torch.Tensor:
@@ -139,7 +139,9 @@ def offdata_to_integers(x: torch.Tensor, num_classes_on_each_position: List[int]
  return torch.cat(integers, dim=1)


- def offdata_z_score_normalize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ def offdata_z_score_normalize(x: torch.Tensor,
+ mean: Optional[torch.Tensor] = None,
+ std: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  """Z-score normalize features columnwise (match NumPy semantics).

  Args:
@@ -150,11 +152,17 @@ def offdata_z_score_normalize(x: torch.Tensor) -> Tuple[torch.Tens
  if not torch.is_floating_point(x):
  raise ValueError("cannot normalize discrete design values")

+ if mean is not None and std is not None:
+ x_norm = (x - mean.to(x.device)) / std.to(x.device)
+ return x_norm, mean, std
+
  mean = torch.mean(x, dim=0)
  # NumPy's np.std uses population std (ddof=0) by default -> unbiased=False
  std = torch.std(x, dim=0, unbiased=False)
- x_norm = (x - mean) / std
- return x_norm, mean, std
+ eps = 1e-6
+ std_safe = torch.clamp(std, min=eps)
+ x_norm = (x - mean) / std_safe
+ return x_norm, mean, std_safe


  def offdata_z_score_denormalize(x: torch.Tensor,
@@ -174,7 +182,9 @@
  return x * x_std.to(x.device) + x_mean.to(x.device)


- def offdata_min_max_normalize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ def offdata_min_max_normalize(x: torch.Tensor,
+ min_val: Optional[torch.Tensor] = None,
+ max_val: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  """Min-max normalize features columnwise.

  Args:
@@ -182,10 +192,17 @@ def offdata_min_max_normalize(x: torch.Tensor) -> Tuple[torch.Tens
  Returns:
  (x_norm, x_min, x_max)
  """
+
+ if min_val is not None and max_val is not None:
+ x_norm = (x - min_val.to(x.device)) / (max_val.to(x.device) - min_val.to(x.device))
+ return x_norm, min_val, max_val
+
  x_min = torch.min(x, dim=0).values
  x_max = torch.max(x, dim=0).values
- x_norm = (x - x_min) / (x_max - x_min)
- return x_norm, x_min, x_max
+ eps = 1e-6
+ x_max_x_min_safe = torch.clamp(x_max - x_min, min=eps)
+ x_norm = (x - x_min) / x_max_x_min_safe
+ return x_norm, x_max-x_max_x_min_safe, x_max


  def offdata_min_max_denormalize(x: torch.Tensor,
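The new optional statistics arguments turn both normalizers into fit-or-apply functions: call once without statistics to fit on the offline dataset, then pass the returned statistics back in to transform new points (bounds, fresh samples) consistently, exactly as core.py now does with `self.X_meanormin` / `self.X_stdormax`. A minimal round-trip sketch (the random tensors are illustrative):

```python
import torch
from moospread.utils import offdata_z_score_normalize, offdata_z_score_denormalize

X_train = torch.randn(100, 3) * 5.0 + 2.0                       # stand-in offline dataset
X_norm, mean, std = offdata_z_score_normalize(X_train)          # fit: columnwise stats

X_new = torch.randn(10, 3)                                      # e.g. new samples or bounds
X_new_norm, _, _ = offdata_z_score_normalize(X_new, mean, std)  # apply: reuse stored stats

X_back = offdata_z_score_denormalize(X_new_norm, mean, std)
assert torch.allclose(X_back, X_new, atol=1e-5)                 # round trip recovers inputs
```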
moospread-0.1.2.dist-info/METADATA → moospread-0.1.4.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: moospread
- Version: 0.1.2
+ Version: 0.1.4
  Summary: Sampling-based Pareto front Refinement via Efficient Adaptive Diffusion
  Author-email: Sedjro Salomon Hotegni <salomon.hotegni@aims.ac.rw>
  Maintainer-email: Sedjro Salomon Hotegni <salomon.hotegni@tu-dortmund.de>
@@ -26,6 +26,8 @@ Requires-Dist: scipy>=1.10
  Requires-Dist: scikit-learn>=1.3
  Requires-Dist: matplotlib>=3.7
  Requires-Dist: pandas>=2.0
+ Requires-Dist: importlib
+ Requires-Dist: opencv-python-headless
  Requires-Dist: pytz
  Requires-Dist: PyYAML>=6.0
  Requires-Dist: tqdm>=4.66
@@ -61,6 +63,20 @@ Dynamic: license-file
  </p>
  -->

+ <p align="center">
+ <a href="https://pypi.org/project/moospread/"><img src="https://img.shields.io/pypi/v/moospread.svg" alt="PyPI version"></a>
+ <a href="https://moospread.readthedocs.io">
+ <img src="https://img.shields.io/badge/docs-online-brightgreen.svg" alt="Documentation">
+ </a>
+ </p>
+ <div align="center">
+ <h3>
+ <a href="https://pypi.org/project/moospread/">Installation</a> |
+ <a href="https://moospread.readthedocs.io/en/latest/">Documentation</a> |
+ <a href="https://arxiv.org/pdf/2509.21058">Paper</a>
+ </h3>
+ </div>
+
  # SPREAD: Sampling-based Pareto front Refinement via Efficient Adaptive Diffusion

  > SPREAD is a novel sampling-based approach for multi-objective optimization that leverages diffusion models to efficiently refine and generate well-spread Pareto front approximations. It combines the expressiveness of diffusion models with multi-objective optimization principles to achieve both high convergence to the Pareto front and excellent diversity across the objective space. SPREAD demonstrates competitive performance against state-of-the-art methods while providing a flexible framework for different optimization contexts.
@@ -96,35 +112,23 @@ from moospread import SPREAD
  from moospread.tasks import ZDT2

  # Define the problem
- n_var = 30
- problem = ZDT2(n_var=n_var)
+ problem = ZDT2(n_var=30)

  # Initialize the SPREAD solver
  solver = SPREAD(
  problem,
  data_size=10000,
- timesteps=5000,
+ timesteps=1000,
  num_epochs=1000,
  train_tol=100,
- num_blocks=3,
- validation_split=0.1,
  mode="online",
  seed=2026,
  verbose=True
  )

  # Solve the problem
- results = solver.solve(
+ res_x, res_y = solver.solve(
  num_points_sample=200,
- strict_guidance=False,
- rho_scale_gamma=0.9,
- nu_t=10.0,
- eta_init=0.9,
- num_inner_steps=10,
- lr_inner=0.9,
- free_initial_h=True,
- use_sigma_rep=False,
- kernel_sigma_rep=0.01,
  iterative_plot=True,
  plot_period=10,
  max_backtracks=25,
@@ -138,11 +142,10 @@ This will train a diffusion-based multi-objective solver, approximate the Pareto

  ---

- <!--
+
  ### 📚 Next steps

- For more advanced examples (offline mode, Bayesian mode, custom problems), see the full [documentation](https://moospread.readthedocs.io/en/latest/).
- -->
+ For more advanced examples (offline mode, mobo mode, tutorials), see the full [documentation](https://moospread.readthedocs.io/en/latest/).

  ## Citation
  If you find `moospread` useful in your research, please consider citing:
moospread-0.1.2.dist-info/RECORD → moospread-0.1.4.dist-info/RECORD CHANGED
@@ -1,10 +1,11 @@
  moospread/__init__.py,sha256=v9TLUZq0-q0j_23NB7S4ugJqogOMutUuL9MMCi4zu4I,124
- moospread/core.py,sha256=nuSEnWBCxxEHXN0AOe1KGp93l7UjkBJjEMUZD3xS64c,82331
+ moospread/core.py,sha256=KL7gOta4OUAzb3TYWjgCXrDTot2-Pxt-6vI4fZzzv4s,91644
  moospread/problem.py,sha256=YjT4k_K7qTZDhWzIYkaBQlsZRVnpP6iV3F8ShFhGAck,6042
- moospread/tasks/__init__.py,sha256=S_zM0GjBNbSPH9uyn1RAdQtA6-BDkcbmF4uWUi-8NVo,231
+ moospread/tasks/__init__.py,sha256=RrtuZs1TyL35fDzGjlm1PnFqxXd75Qdl-q3Sdld7dsw,309
+ moospread/tasks/bo_torch.py,sha256=0I77aWBM4TzfH7l2iQLTRm_5Orc25rmXabrcpPQpVb8,10371
  moospread/tasks/dtlz_torch.py,sha256=dhmzUj-dbhF3zXIOgx2Z-PpuIqvxAttnEoBM0C9nUt0,5134
  moospread/tasks/mw_torch.py,sha256=hvYGxcaCr-AFgZd_-rLcIul_a5cZrOV3dRVB5sl9Wuo,9585
- moospread/tasks/re_torch.py,sha256=Q5aANkHatO_FgPoPQsLVrwOESNo1LW8sDhXH351GdDM,12856
+ moospread/tasks/re_torch.py,sha256=TK7i-9i7qlDZC2H4vdhcVgHybarKwJ-M27yAibAg3LI,14062
  moospread/tasks/zdt_torch.py,sha256=xdtfF6K3LmgH0vINQa6ogeMrJOdoyg3xaQ1EmldcEvg,4161
  moospread/utils/__init__.py,sha256=l7iUhBmGcOaZVyJI4SI4XlecDNLhRy0L5CTNmdXccIc,336
  moospread/utils/ditmoo.py,sha256=fP3NZ-CmnCSoK7m8b7RSoDZiVh053rcOUbu8Llwpz48,3680
@@ -54,10 +55,10 @@ moospread/utils/mobo_utils/mobo/surrogate_model/base.py,sha256=YvTx-ORldAnnQ4bT0
  moospread/utils/mobo_utils/mobo/surrogate_model/gaussian_process.py,sha256=-OkFEte66FAXSPk32Dy2vhJ9RGoTsJWagbHosCE9qus,7722
  moospread/utils/mobo_utils/mobo/surrogate_model/thompson_sampling.py,sha256=Nmp63vuAgNv-iyVxxRn_yTaAlRLVXW28wsnIu5JpZ1o,3293
  moospread/utils/offline_utils/__init__.py,sha256=MJC-fqvQnbQ0T_wjCw_QK8nKo_xpQxh0buq91fxYjFY,742
- moospread/utils/offline_utils/handle_task.py,sha256=VJjcWZC5AoPm42YN_SKgSpcyHtKBAgXgWwSFU0-Ehis,7586
+ moospread/utils/offline_utils/handle_task.py,sha256=kgHpWotCIrmDBjs08KXyp1BuwXe5APkw3q7cu-xWlz4,8365
  moospread/utils/offline_utils/proxies.py,sha256=DPBykB8l1XJmT5QQCAQrgMZz-8FiGEiNwN0bBdYJIaY,11218
- moospread-0.1.2.dist-info/licenses/LICENSE,sha256=YwtV5PRo6WMw5CWQMD728fSF8cWEKKfwOhek37Yi1so,1079
- moospread-0.1.2.dist-info/METADATA,sha256=wzYcnaPd_T-vgO7vsUS1vgIRMr4uNyX4bcbrfdDoSGM,5839
- moospread-0.1.2.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
- moospread-0.1.2.dist-info/top_level.txt,sha256=LWi5kIahDQRNXNkx55T-gefn09Bgcq8SoCxp72S-7x0,10
- moospread-0.1.2.dist-info/RECORD,,
+ moospread-0.1.4.dist-info/licenses/LICENSE,sha256=YwtV5PRo6WMw5CWQMD728fSF8cWEKKfwOhek37Yi1so,1079
+ moospread-0.1.4.dist-info/METADATA,sha256=Q1Sw5QOvVA4j973tkEjk4Wse2BPXRNSh9Pzf5gFy4Ts,6179
+ moospread-0.1.4.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ moospread-0.1.4.dist-info/top_level.txt,sha256=LWi5kIahDQRNXNkx55T-gefn09Bgcq8SoCxp72S-7x0,10
+ moospread-0.1.4.dist-info/RECORD,,