moospread 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. moospread/__init__.py +3 -0
  2. moospread/core.py +1881 -0
  3. moospread/problem.py +193 -0
  4. moospread/tasks/__init__.py +4 -0
  5. moospread/tasks/dtlz_torch.py +139 -0
  6. moospread/tasks/mw_torch.py +274 -0
  7. moospread/tasks/re_torch.py +394 -0
  8. moospread/tasks/zdt_torch.py +112 -0
  9. moospread/utils/__init__.py +8 -0
  10. moospread/utils/constraint_utils/__init__.py +2 -0
  11. moospread/utils/constraint_utils/gradient.py +72 -0
  12. moospread/utils/constraint_utils/mgda_core.py +69 -0
  13. moospread/utils/constraint_utils/pmgda_solver.py +308 -0
  14. moospread/utils/constraint_utils/prefs.py +64 -0
  15. moospread/utils/ditmoo.py +127 -0
  16. moospread/utils/lhs.py +74 -0
  17. moospread/utils/misc.py +28 -0
  18. moospread/utils/mobo_utils/__init__.py +11 -0
  19. moospread/utils/mobo_utils/evolution/__init__.py +0 -0
  20. moospread/utils/mobo_utils/evolution/dom.py +60 -0
  21. moospread/utils/mobo_utils/evolution/norm.py +40 -0
  22. moospread/utils/mobo_utils/evolution/utils.py +97 -0
  23. moospread/utils/mobo_utils/learning/__init__.py +0 -0
  24. moospread/utils/mobo_utils/learning/model.py +40 -0
  25. moospread/utils/mobo_utils/learning/model_init.py +33 -0
  26. moospread/utils/mobo_utils/learning/model_update.py +51 -0
  27. moospread/utils/mobo_utils/learning/prediction.py +116 -0
  28. moospread/utils/mobo_utils/learning/utils.py +143 -0
  29. moospread/utils/mobo_utils/lhs_for_mobo.py +243 -0
  30. moospread/utils/mobo_utils/mobo/__init__.py +0 -0
  31. moospread/utils/mobo_utils/mobo/acquisition.py +209 -0
  32. moospread/utils/mobo_utils/mobo/algorithms.py +91 -0
  33. moospread/utils/mobo_utils/mobo/factory.py +86 -0
  34. moospread/utils/mobo_utils/mobo/mobo.py +132 -0
  35. moospread/utils/mobo_utils/mobo/selection.py +182 -0
  36. moospread/utils/mobo_utils/mobo/solver/__init__.py +5 -0
  37. moospread/utils/mobo_utils/mobo/solver/moead.py +17 -0
  38. moospread/utils/mobo_utils/mobo/solver/nsga2.py +10 -0
  39. moospread/utils/mobo_utils/mobo/solver/parego/__init__.py +1 -0
  40. moospread/utils/mobo_utils/mobo/solver/parego/parego.py +62 -0
  41. moospread/utils/mobo_utils/mobo/solver/parego/utils.py +34 -0
  42. moospread/utils/mobo_utils/mobo/solver/pareto_discovery/__init__.py +1 -0
  43. moospread/utils/mobo_utils/mobo/solver/pareto_discovery/buffer.py +364 -0
  44. moospread/utils/mobo_utils/mobo/solver/pareto_discovery/pareto_discovery.py +571 -0
  45. moospread/utils/mobo_utils/mobo/solver/pareto_discovery/utils.py +168 -0
  46. moospread/utils/mobo_utils/mobo/solver/solver.py +74 -0
  47. moospread/utils/mobo_utils/mobo/surrogate_model/__init__.py +2 -0
  48. moospread/utils/mobo_utils/mobo/surrogate_model/base.py +36 -0
  49. moospread/utils/mobo_utils/mobo/surrogate_model/gaussian_process.py +177 -0
  50. moospread/utils/mobo_utils/mobo/surrogate_model/thompson_sampling.py +79 -0
  51. moospread/utils/mobo_utils/mobo/surrogate_problem.py +44 -0
  52. moospread/utils/mobo_utils/mobo/transformation.py +106 -0
  53. moospread/utils/mobo_utils/mobo/utils.py +65 -0
  54. moospread/utils/mobo_utils/spread_mobo_utils.py +854 -0
  55. moospread/utils/offline_utils/__init__.py +10 -0
  56. moospread/utils/offline_utils/handle_task.py +203 -0
  57. moospread/utils/offline_utils/proxies.py +338 -0
  58. moospread/utils/spread_utils.py +91 -0
  59. moospread-0.1.0.dist-info/METADATA +75 -0
  60. moospread-0.1.0.dist-info/RECORD +63 -0
  61. moospread-0.1.0.dist-info/WHEEL +5 -0
  62. moospread-0.1.0.dist-info/licenses/LICENSE +10 -0
  63. moospread-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,10 @@
1
+ from moospread.utils.offline_utils.proxies import (SingleModelBaseTrainer,
2
+ MultipleModels,
3
+ SingleModel,
4
+ offdata_get_dataloader)
5
+ from moospread.utils.offline_utils.handle_task import (offdata_min_max_normalize,
6
+ offdata_min_max_denormalize,
7
+ offdata_z_score_normalize,
8
+ offdata_z_score_denormalize,
9
+ offdata_to_integers,
10
+ offdata_to_logits)
@@ -0,0 +1,203 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from typing import List, Tuple
4
+
5
+
6
def one_hot(a: torch.Tensor, num_classes: int) -> torch.Tensor:
    """One-hot encode integer class indices as a float32 tensor.

    Args:
        a: Integer tensor (any shape) holding class indices.
        num_classes: Total number of distinct classes.
    Returns:
        Float32 tensor shaped ``a.shape + (num_classes,)``.
    Raises:
        ValueError: If ``a`` has a floating-point dtype.
    """
    if torch.is_floating_point(a):
        raise ValueError("cannot one-hot encode non-integers (got floating dtype)")
    encoded = F.one_hot(a.to(torch.long), num_classes=num_classes)
    return encoded.to(torch.float32)


def help_to_logits(
    x: torch.Tensor, num_classes: int, soft_interpolation: float = 0.6
) -> torch.Tensor:
    """Map integer labels to logits of a softened one-hot distribution.

    The Dirac (one-hot) encoding is blended with a uniform prior, turned
    into log-probabilities, and the first component is subtracted and
    dropped to remove the redundant degree of freedom.

    Args:
        x: Int tensor of shape ``(n_samples, 1)``.
        num_classes: Number of classes at this position.
        soft_interpolation: Blend weight on the one-hot part (remainder
            goes to the uniform prior).
    Returns:
        Float32 tensor of shape ``(n_samples, 1, num_classes - 1)``.
    Raises:
        ValueError: If ``x`` has a floating-point dtype.
    """
    if torch.is_floating_point(x):
        raise ValueError("cannot convert non-integers to logits")

    dirac = one_hot(x, num_classes=num_classes)  # (n, 1, C), float32

    uniform_prior = torch.full_like(
        dirac, 1.0 / float(num_classes), device=x.device
    )
    # Convex combination of Dirac and uniform, then log-probabilities.
    log_p = torch.log(
        soft_interpolation * dirac + (1.0 - soft_interpolation) * uniform_prior
    )
    # Anchor on class 0: subtract its log-prob and drop that column.
    return (log_p[:, :, 1:] - log_p[:, :, :1]).to(torch.float32)


def offdata_to_logits(
    x: torch.Tensor,
    num_classes_on_each_position: List[int],
    soft_interpolation: float = 0.6,
) -> torch.Tensor:
    """Convert a sequence of categorical integers to concatenated logits.

    Position ``i`` with ``k_i`` classes contributes a length ``k_i - 1``
    logit vector (one degree of freedom removed); all positions are then
    concatenated along the last axis.

    Args:
        x: Int tensor of shape ``(n_samples, seq_len)``.
        num_classes_on_each_position: Class counts per sequence position.
        soft_interpolation: Blend factor in ``[0, 1]``.
    Returns:
        Float tensor of shape ``(n_samples, sum_i (k_i - 1))``.
    """
    # Same adjustment as the original NumPy code: +1 everywhere, and an
    # extra dummy class for positions that would end up with one class.
    adjusted = [c + 1 for c in num_classes_on_each_position]
    adjusted = [c + 1 if c == 1 else c for c in adjusted]

    per_position = [
        help_to_logits(x[:, pos].reshape(-1, 1), k, soft_interpolation)
        for pos, k in enumerate(adjusted)
    ]

    # Concatenate along the class axis, then squeeze (mimics the original
    # NumPy .squeeze()).
    return torch.cat(per_position, dim=2).squeeze()
87
+
88
+
89
def help_to_integers(x: torch.Tensor, true_num_of_classes: int) -> torch.Tensor:
    """Recover class integers from one position's logits.

    Args:
        x: Float tensor of shape ``(n_samples, 1, k - 1)`` for a position
            with ``k`` classes.
        true_num_of_classes: The (possibly dummy-adjusted) class count ``k``.
    Returns:
        Int64 tensor of shape ``(n_samples, 1)`` with the selected classes.
    Raises:
        ValueError: If ``x`` is not a floating-point tensor.
    """
    if not torch.is_floating_point(x):
        raise ValueError("cannot convert non-floats to integers")

    # Degenerate single-class position (RFP-Exact-v0 path): always class 0.
    if true_num_of_classes == 1:
        return torch.zeros(x.shape[:-1], dtype=torch.int64, device=x.device)

    # Re-insert the dropped class-0 logit (fixed at 0) and take the argmax
    # over the full class dimension: (n, 1, k-1) -> (n, 1, k) -> (n, 1).
    with_anchor = F.pad(x, pad=(1, 0))
    return torch.argmax(with_anchor, dim=-1).to(torch.int64)


def offdata_to_integers(x: torch.Tensor, num_classes_on_each_position: List[int]) -> torch.Tensor:
    """Invert ``offdata_to_logits``: recover integers per sequence position.

    Args:
        x: Float tensor of shape ``(n_samples, total)`` where ``total`` is
            ``sum_i (k_i - 1)`` after the same class-count adjustment used
            when producing the logits.
        num_classes_on_each_position: True class counts per position.
    Returns:
        Int tensor of shape ``(n_samples, seq_len)``.
    """
    # Mirror the class-count adjustment applied when building the logits.
    adjusted = [c + 1 for c in num_classes_on_each_position]
    adjusted = [c + 1 if c == 1 else c for c in adjusted]

    decoded = []
    offset = 0
    for k in adjusted:
        width = k - 1
        # Slice this position's logits along dim=1 as the original did.
        chunk = x[:, offset : offset + width].reshape(-1, 1, width)
        decoded.append(help_to_integers(chunk, k))
        offset += width

    # Stack per-position columns back into a (n, seq_len) tensor.
    return torch.cat(decoded, dim=1)
140
+
141
+
142
def offdata_z_score_normalize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Z-score normalize each column of ``x``.

    Args:
        x: Float tensor of shape ``(n_samples, n_dim)``.
    Returns:
        Tuple ``(x_norm, mean, std)``; ``mean`` and ``std`` broadcast
        against ``x`` over the sample dimension.
    Raises:
        ValueError: If ``x`` is not a floating-point tensor.
    """
    if not torch.is_floating_point(x):
        raise ValueError("cannot normalize discrete design values")

    col_mean = torch.mean(x, dim=0)
    # Population std (ddof=0) to match NumPy's np.std default.
    col_std = torch.std(x, dim=0, unbiased=False)
    return (x - col_mean) / col_std, col_mean, col_std
158
+
159
+
160
def offdata_z_score_denormalize(x: torch.Tensor,
                                x_mean: torch.Tensor,
                                x_std: torch.Tensor) -> torch.Tensor:
    """Undo z-score normalization.

    Args:
        x: Float tensor ``(n_samples, n_dim)`` of normalized values.
        x_mean: Column means used during normalization.
        x_std: Column stds used during normalization.
    Returns:
        Tensor in the original scale.
    Raises:
        ValueError: If ``x`` is not a floating-point tensor.
    """
    if not torch.is_floating_point(x):
        raise ValueError("cannot denormalize discrete design values")
    # Move the statistics onto x's device before broadcasting.
    std = x_std.to(x.device)
    mean = x_mean.to(x.device)
    return x * std + mean
175
+
176
+
177
def offdata_min_max_normalize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Min-max normalize each column of ``x`` to ``[0, 1]``.

    Args:
        x: Tensor of shape ``(n_samples, n_dim)``.
    Returns:
        Tuple ``(x_norm, x_min, x_max)``.

    NOTE(review): a constant column makes ``x_max == x_min`` and yields a
    zero denominator (NaN/inf output) — confirm callers never pass one.
    """
    col_min = torch.min(x, dim=0).values
    col_max = torch.max(x, dim=0).values
    scaled = (x - col_min) / (col_max - col_min)
    return scaled, col_min, col_max
189
+
190
+
191
def offdata_min_max_denormalize(x: torch.Tensor,
                                x_min: torch.Tensor,
                                x_max: torch.Tensor) -> torch.Tensor:
    """Undo min-max normalization.

    Args:
        x: Tensor of shape ``(n_samples, n_dim)`` in normalized scale.
        x_min: Per-column minimum captured during normalization.
        x_max: Per-column maximum captured during normalization.
    Returns:
        Tensor in the original scale.
    """
    # Bring bounds onto x's device, then rescale and shift.
    lo = x_min.to(x.device)
    hi = x_max.to(x.device)
    return x * (hi - lo) + lo
@@ -0,0 +1,338 @@
1
+ """
2
+ Adapted from: https://github.com/lamda-bbo/offline-moo/blob/main/off_moo_baselines/multiple_models/nets.py
3
+ """
4
+
5
+ import os
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.optim import Adam
10
+ from torch.utils.data import DataLoader, TensorDataset, random_split
11
+ from tqdm import tqdm
12
+
13
def offdata_get_dataloader(
    X,
    y,
    train_ratio: float = 0.9,
    batch_size: int = 32,
):
    """Split ``(X, y)`` into training and validation DataLoaders.

    Args:
        X: Design tensor; first dimension indexes samples.
        y: Target tensor aligned with ``X``.
        train_ratio: Fraction of samples assigned to the training split.
        batch_size: Training batch size; validation uses ``4 * batch_size``.
    Returns:
        Tuple ``(train_loader, val_loader)``. The split is random
        (``torch.utils.data.random_split``); training batches shuffle,
        validation batches do not; no batches are dropped.
    """
    dataset = TensorDataset(X, y)
    n_train = int(train_ratio * len(dataset))
    train_set, val_set = random_split(dataset, [n_train, len(dataset) - n_train])

    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True, drop_last=False
    )
    val_loader = DataLoader(
        val_set, batch_size=batch_size * 4, shuffle=False, drop_last=False
    )
    return train_loader, val_loader
36
+
37
def spearman_correlation(x, y):
    """Columnwise Spearman rank correlation between two (n, m) tensors.

    Uses the closed form ``rho = 1 - 6 * sum(d^2) / (n * (n^2 - 1))`` with
    ``d`` the per-row rank difference. This formula assumes no ties within
    a column.

    Args:
        x: Tensor of shape ``(n, m)``.
        y: Tensor of shape ``(n, m)``.
    Returns:
        Float tensor of shape ``(m,)`` with one rho per column.
    """
    n = x.size(0)
    # BUGFIX: a single argsort yields the *sorting permutation*, not the
    # ranks. Ranks are the inverse permutation, i.e. argsort of argsort.
    rank_x = x.argsort(dim=0).argsort(dim=0)
    rank_y = y.argsort(dim=0).argsort(dim=0)

    d = rank_x - rank_y
    d_squared_sum = (d**2).sum(0).float()

    rho = 1 - (6 * d_squared_sum) / (n * (n**2 - 1))
    return rho
47
+
48
class SingleModelBaseTrainer(nn.Module):
    """Trainer for one proxy ``model`` targeting objective ``which_obj``.

    Runs a plain Adam regression loop with multiplicative learning-rate
    decay per epoch, records train/validation MSE and rank correlation
    into a statistics dict, and checkpoints the model whenever the
    validation PCC improves.
    """

    def __init__(self, model, which_obj, args):
        # ``args`` is a dict-like config; required keys: proxies_lr,
        # proxies_lr_decay, proxies_epochs, device, verbose.
        super(SingleModelBaseTrainer, self).__init__()
        self.args = args

        self.forward_lr = args["proxies_lr"]
        self.forward_lr_decay = args["proxies_lr_decay"]
        self.n_epochs = args["proxies_epochs"]
        self.device = args["device"]
        self.verbose = args["verbose"]

        self.model = model

        self.which_obj = which_obj

        self.forward_opt = Adam(model.parameters(), lr=args["proxies_lr"])
        # Training loss: per-sample mean squared error over objectives,
        # summed over the batch (so batches are weighted by size).
        self.train_criterion = lambda yhat, y: (
            torch.sum(torch.mean((yhat - y) ** 2, dim=1))
        )
        self.mse_criterion = nn.MSELoss()

    def _evaluate_performance(self, statistics, epoch, train_loader, val_loader):
        """Fill ``statistics`` with train/valid MSE, Spearman and PCC.

        Also checkpoints ``self.model`` when validation PCC exceeds the
        best seen so far. Relies on ``self.n_obj`` / ``self.min_pcc``
        being initialized by ``launch`` before the first call. ``epoch``
        is only referenced by the commented-out progress prints.
        """
        self.model.eval()
        with torch.no_grad():
            # Accumulate full-epoch predictions/targets on the training set.
            y_all = torch.zeros((0, self.n_obj)).to(self.device)
            outputs_all = torch.zeros((0, self.n_obj)).to(self.device)
            for (
                batch_x,
                batch_y,
            ) in train_loader:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)

                y_all = torch.cat((y_all, batch_y), dim=0)
                outputs = self.model(batch_x)
                outputs_all = torch.cat((outputs_all, outputs), dim=0)

            train_mse = self.mse_criterion(outputs_all, y_all)
            train_corr = spearman_correlation(outputs_all, y_all)
            train_pcc = self.compute_pcc(outputs_all, y_all)

            statistics[f"model_{self.which_obj}/train/mse"] = train_mse.item()
            # One rank-correlation entry per objective column.
            for i in range(self.n_obj):
                statistics[f"model_{self.which_obj}/train/rank_corr_{i + 1}"] = (
                    train_corr[i].item()
                )
            # if self.verbose:
            #     print(
            #         "Epoch [{}/{}], MSE: {:}, PCC: {:}".format(
            #             epoch + 1, self.n_epochs, train_mse.item(), train_pcc.item()
            #         )
            #     )

        with torch.no_grad():
            # Same accumulation pass over the validation set.
            y_all = torch.zeros((0, self.n_obj)).to(self.device)
            outputs_all = torch.zeros((0, self.n_obj)).to(self.device)

            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)

                y_all = torch.cat((y_all, batch_y), dim=0)
                outputs = self.model(batch_x)
                outputs_all = torch.cat((outputs_all, outputs))

            val_mse = self.mse_criterion(outputs_all, y_all)
            val_corr = spearman_correlation(outputs_all, y_all)
            val_pcc = self.compute_pcc(outputs_all, y_all)

            statistics[f"model_{self.which_obj}/valid/mse"] = val_mse.item()
            for i in range(self.n_obj):
                statistics[f"model_{self.which_obj}/valid/rank_corr_{i + 1}"] = (
                    val_corr[i].item()
                )

            # if self.verbose:
            #     print(
            #         "Valid MSE: {:}, Valid PCC: {:}".format(val_mse.item(), val_pcc.item())
            #     )

            # NOTE(review): despite its name, ``min_pcc`` tracks the
            # *maximum* validation PCC seen so far; checkpoint on improvement.
            if val_pcc.item() > self.min_pcc:
                # if self.verbose:
                #     print("🌸 New best epoch! 🌸")
                self.min_pcc = val_pcc.item()
                self.model.save(val_pcc=self.min_pcc)
        # ``statistics`` is mutated in place; the return value is a convenience.
        return statistics

    def launch(
        self,
        train_loader=None,
        val_loader=None,
        retrain_model: bool = True,
    ):
        """Train the proxy, or load an existing checkpoint.

        If ``retrain_model`` is False and a checkpoint exists at
        ``self.model.save_path``, loads it and returns immediately;
        otherwise runs the full training loop (both loaders required).
        """

        def update_lr(optimizer, lr):
            # Apply the decayed learning rate to every parameter group.
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        if not retrain_model and os.path.exists(self.model.save_path):
            self.model.load()
            return

        assert train_loader is not None
        assert val_loader is not None

        # n_obj is inferred from the first training batch below.
        self.n_obj = None
        self.min_pcc = -1.0
        statistics = {}

        with tqdm(
            total=self.n_epochs,
            desc=f"Proxy Training",
            unit="epoch",
        ) as pbar:

            for epoch in range(self.n_epochs):
                self.model.train()

                losses = []
                epoch_loss = 0.0
                for batch_x, batch_y in train_loader:
                    batch_x = batch_x.to(self.device)
                    batch_y = batch_y.to(self.device)
                    if self.n_obj is None:
                        self.n_obj = batch_y.shape[1]

                    self.forward_opt.zero_grad()
                    outputs = self.model(batch_x)
                    loss = self.train_criterion(outputs, batch_y)
                    # Record the per-sample loss for epoch statistics.
                    losses.append(loss.item() / batch_x.size(0))
                    loss.backward()
                    self.forward_opt.step()
                    epoch_loss += loss.item()

                statistics[f"model_{self.which_obj}/train/loss/mean"] = np.array(
                    losses
                ).mean()
                statistics[f"model_{self.which_obj}/train/loss/std"] = np.array(
                    losses
                ).std()
                statistics[f"model_{self.which_obj}/train/loss/max"] = np.array(
                    losses
                ).max()

                # Evaluate, log, and possibly checkpoint after every epoch.
                self._evaluate_performance(statistics, epoch, train_loader, val_loader)

                # Multiplicative learning-rate decay.
                statistics[f"model_{self.which_obj}/train/lr"] = self.forward_lr
                self.forward_lr *= self.forward_lr_decay
                update_lr(self.forward_opt, self.forward_lr)

                epoch_loss /= len(train_loader)
                pbar.set_postfix({
                    "Loss": epoch_loss,
                })
                pbar.update(1)

    def compute_pcc(self, valid_preds, valid_labels):
        """Pearson correlation over all elements (flattened across objectives).

        The 1e-12 terms guard against a zero denominator for constant inputs.
        """
        vx = valid_preds - torch.mean(valid_preds)
        vy = valid_labels - torch.mean(valid_labels)
        pcc = torch.sum(vx * vy) / (
            torch.sqrt(torch.sum(vx**2) + 1e-12) * torch.sqrt(torch.sum(vy**2) + 1e-12)
        )
        return pcc
212
+
213
+
214
class MultipleModels(nn.Module):
    """Container holding one independent ``SingleModel`` per objective.

    ``forward`` evaluates the requested objectives and concatenates their
    scalar predictions along dim 1.
    """

    def __init__(
        self, n_dim, n_obj, hidden_size, train_mode, device,
        save_dir=None, save_prefix=None
    ):
        super(MultipleModels, self).__init__()
        self.n_dim = n_dim
        self.n_obj = n_obj
        self.device = device

        self.obj2model = {}
        self.hidden_size = hidden_size
        self.train_mode = train_mode

        self.save_dir = save_dir
        self.save_prefix = save_prefix
        if self.save_dir is not None:
            os.makedirs(self.save_dir, exist_ok=True)

        # Build one sub-model per objective index.
        for objective_index in range(self.n_obj):
            self.create_models(objective_index)

    def create_models(self, learning_obj):
        """Instantiate and register the proxy model for one objective."""
        self.obj2model[learning_obj] = SingleModel(
            self.n_dim,
            self.hidden_size,
            which_obj=learning_obj,
            device=self.device,
            save_dir=self.save_dir,
            save_prefix=self.save_prefix,
        )

    def set_kwargs(self, device=None, dtype=torch.float32):
        """Propagate device/dtype to every sub-model."""
        for sub_model in self.obj2model.values():
            sub_model.set_kwargs(device=device, dtype=dtype)
            sub_model.to(device=device, dtype=dtype)

    def forward(self, x, forward_objs=None):
        """Evaluate the selected objectives and stack predictions columnwise."""
        objs = list(self.obj2model.keys()) if forward_objs is None else forward_objs
        predictions = [self.obj2model[obj](x) for obj in objs]
        return torch.cat(predictions, dim=1)


# Shared activation callables for SingleModel's hidden layers. LeakyReLU is
# stateless, so sharing instances across models is safe; the two entries
# match SingleModel.forward's use of one activation per hidden layer.
activate_functions = [nn.LeakyReLU(), nn.LeakyReLU()]


class SingleModel(nn.Module):
    """MLP proxy predicting a single objective value.

    Linear layers sized by ``hidden_size`` with LeakyReLU activations on
    the hidden layers and a final linear head of width 1.
    """

    def __init__(
        self, input_size, hidden_size, which_obj, device,
        save_dir=None, save_prefix=None
    ):
        super(SingleModel, self).__init__()
        self.n_dim = input_size
        self.n_obj = 1
        self.which_obj = which_obj
        self.activate_functions = activate_functions
        self.device = device

        # Stack: input -> hidden[0] -> ... -> hidden[-1] -> 1
        layer_list = [nn.Linear(input_size, hidden_size[0])]
        for idx in range(len(hidden_size) - 1):
            layer_list.append(nn.Linear(hidden_size[idx], hidden_size[idx + 1]))
        layer_list.append(nn.Linear(hidden_size[len(hidden_size) - 1], 1))

        self.layers = nn.Sequential(*layer_list)
        self.hidden_size = hidden_size

        self.save_path = os.path.join(save_dir, f"{save_prefix}-{which_obj}.pt")

    def forward(self, x):
        """Apply each hidden layer with its activation, then the linear head."""
        for idx in range(len(self.hidden_size)):
            x = self.activate_functions[idx](self.layers[idx](x))
        return self.layers[len(self.hidden_size)](x)

    def set_kwargs(self, device=None, dtype=torch.float32):
        """Move the model to the given device/dtype."""
        self.to(device=device, dtype=dtype)

    def check_model_path_exist(self, save_path=None):
        """Return True iff a checkpoint already exists at the save path."""
        assert (
            self.save_path is not None or save_path is not None
        ), "save path should be specified"
        target = self.save_path if save_path is None else save_path
        return os.path.exists(target)

    def save(self, val_pcc=None, save_path=None):
        """Checkpoint the state dict (plus optional validation PCC)."""
        assert (
            self.save_path is not None or save_path is not None
        ), "save path should be specified"
        if save_path is None:
            save_path = self.save_path

        # Serialize from CPU so the checkpoint is device-agnostic.
        self = self.to("cpu")
        checkpoint = {
            "model_state_dict": self.state_dict(),
        }
        if val_pcc is not None:
            checkpoint["valid_pcc"] = val_pcc

        torch.save(checkpoint, save_path)
        self = self.to(self.device)

    def load(self, save_path=None):
        """Restore weights from a checkpoint and report its validation PCC."""
        assert (
            self.save_path is not None or save_path is not None
        ), "save path should be specified"
        if save_path is None:
            save_path = self.save_path

        checkpoint = torch.load(save_path, weights_only=False)
        self.load_state_dict(checkpoint["model_state_dict"])
        valid_pcc = checkpoint["valid_pcc"]
        print(
            f"Successfully load trained proxy model from {save_path} "
            f"with valid PCC = {valid_pcc}"
        )
@@ -0,0 +1,91 @@
1
+ from torch.utils.data import DataLoader, TensorDataset
2
+ import dis
3
+ import inspect
4
+ import ast
5
+ import textwrap
6
+
7
def get_ddpm_dataloader(X,
                        y,
                        validation_split=0.1,
                        batch_size=32):
    """Build train (and optional validation) DataLoaders for DDPM training.

    Args:
        X: Sample tensor; the *last* ``validation_split`` fraction of rows
            becomes the validation set (no shuffling before the split).
        y: Targets aligned with ``X``.
        validation_split: Fraction held out for validation; ``0`` disables it.
        batch_size: Batch size for both loaders; incomplete batches are
            dropped (``drop_last=True``).
    Returns:
        Tuple ``(train_dataloader, val_dataloader)``; ``val_dataloader`` is
        ``None`` when ``validation_split <= 0``.
    """
    val_dataloader = None
    if validation_split > 0.0:
        train_size = int(X.shape[0] - int(X.shape[0] * validation_split))
        # Tail rows form the validation split.
        dataset_val = TensorDataset(X[train_size:].float(), y[train_size:].float())
        val_dataloader = DataLoader(
            dataset_val,
            batch_size=batch_size,
            shuffle=False,
            drop_last=True,
        )
        # Keep only the head rows for training.
        X, y = X[:train_size], y[:train_size]

    dataset_train = TensorDataset(X.float(), y.float())
    train_dataloader = DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

    return train_dataloader, val_dataloader
39
+
40
def is_pass_function(func) -> bool:
    """
    Return True iff the function body is effectively just `pass`
    (optionally with a docstring). Works on Python 3.8–3.12+.
    """
    # -------- Preferred path: inspect the AST of the source --------
    try:
        source = inspect.getsource(func)
    except (OSError, TypeError):
        source = None

    if source:
        tree = ast.parse(textwrap.dedent(source))
        fn_node = None
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                fn_node = node
                break
        if fn_node:
            statements = list(fn_node.body)
            # A leading string constant is a docstring — ignore it.
            if statements and isinstance(statements[0], ast.Expr):
                leading_value = getattr(statements[0], "value", None)
                if isinstance(leading_value, ast.Constant) and isinstance(
                    statements[0].value.value, str
                ):
                    statements = statements[1:]
            # True when every remaining statement is Pass (or there are none).
            return all(isinstance(stmt, ast.Pass) for stmt in statements)

    # -------- Fallback: inspect the compiled bytecode --------
    noise = {"RESUME", "CACHE", "EXTENDED_ARG", "NOP"}
    core = [ins for ins in dis.get_instructions(func) if ins.opname not in noise]

    # Strip a docstring-store pattern: LOAD_CONST <str>; STORE_* __doc__
    idx = 0
    while idx + 1 < len(core):
        first, second = core[idx], core[idx + 1]
        stores_doc = (
            first.opname == "LOAD_CONST"
            and isinstance(first.argval, str)
            and second.opname in {"STORE_NAME", "STORE_FAST"}
            and second.argval == "__doc__"
        )
        if stores_doc:
            del core[idx:idx + 2]
            continue
        idx += 1

    # A pass-only body compiles to a bare "return None":
    #   LOAD_CONST None; RETURN_VALUE   (pre-3.12)
    #   RETURN_CONST None               (3.12+)
    if (
        len(core) == 2
        and core[0].opname == "LOAD_CONST"
        and core[0].argval is None
        and core[1].opname == "RETURN_VALUE"
    ):
        return True
    if len(core) == 1 and core[0].opname == "RETURN_CONST" and core[0].argval is None:
        return True

    return False