moospread 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- moospread/__init__.py +3 -0
- moospread/core.py +1881 -0
- moospread/problem.py +193 -0
- moospread/tasks/__init__.py +4 -0
- moospread/tasks/dtlz_torch.py +139 -0
- moospread/tasks/mw_torch.py +274 -0
- moospread/tasks/re_torch.py +394 -0
- moospread/tasks/zdt_torch.py +112 -0
- moospread/utils/__init__.py +8 -0
- moospread/utils/constraint_utils/__init__.py +2 -0
- moospread/utils/constraint_utils/gradient.py +72 -0
- moospread/utils/constraint_utils/mgda_core.py +69 -0
- moospread/utils/constraint_utils/pmgda_solver.py +308 -0
- moospread/utils/constraint_utils/prefs.py +64 -0
- moospread/utils/ditmoo.py +127 -0
- moospread/utils/lhs.py +74 -0
- moospread/utils/misc.py +28 -0
- moospread/utils/mobo_utils/__init__.py +11 -0
- moospread/utils/mobo_utils/evolution/__init__.py +0 -0
- moospread/utils/mobo_utils/evolution/dom.py +60 -0
- moospread/utils/mobo_utils/evolution/norm.py +40 -0
- moospread/utils/mobo_utils/evolution/utils.py +97 -0
- moospread/utils/mobo_utils/learning/__init__.py +0 -0
- moospread/utils/mobo_utils/learning/model.py +40 -0
- moospread/utils/mobo_utils/learning/model_init.py +33 -0
- moospread/utils/mobo_utils/learning/model_update.py +51 -0
- moospread/utils/mobo_utils/learning/prediction.py +116 -0
- moospread/utils/mobo_utils/learning/utils.py +143 -0
- moospread/utils/mobo_utils/lhs_for_mobo.py +243 -0
- moospread/utils/mobo_utils/mobo/__init__.py +0 -0
- moospread/utils/mobo_utils/mobo/acquisition.py +209 -0
- moospread/utils/mobo_utils/mobo/algorithms.py +91 -0
- moospread/utils/mobo_utils/mobo/factory.py +86 -0
- moospread/utils/mobo_utils/mobo/mobo.py +132 -0
- moospread/utils/mobo_utils/mobo/selection.py +182 -0
- moospread/utils/mobo_utils/mobo/solver/__init__.py +5 -0
- moospread/utils/mobo_utils/mobo/solver/moead.py +17 -0
- moospread/utils/mobo_utils/mobo/solver/nsga2.py +10 -0
- moospread/utils/mobo_utils/mobo/solver/parego/__init__.py +1 -0
- moospread/utils/mobo_utils/mobo/solver/parego/parego.py +62 -0
- moospread/utils/mobo_utils/mobo/solver/parego/utils.py +34 -0
- moospread/utils/mobo_utils/mobo/solver/pareto_discovery/__init__.py +1 -0
- moospread/utils/mobo_utils/mobo/solver/pareto_discovery/buffer.py +364 -0
- moospread/utils/mobo_utils/mobo/solver/pareto_discovery/pareto_discovery.py +571 -0
- moospread/utils/mobo_utils/mobo/solver/pareto_discovery/utils.py +168 -0
- moospread/utils/mobo_utils/mobo/solver/solver.py +74 -0
- moospread/utils/mobo_utils/mobo/surrogate_model/__init__.py +2 -0
- moospread/utils/mobo_utils/mobo/surrogate_model/base.py +36 -0
- moospread/utils/mobo_utils/mobo/surrogate_model/gaussian_process.py +177 -0
- moospread/utils/mobo_utils/mobo/surrogate_model/thompson_sampling.py +79 -0
- moospread/utils/mobo_utils/mobo/surrogate_problem.py +44 -0
- moospread/utils/mobo_utils/mobo/transformation.py +106 -0
- moospread/utils/mobo_utils/mobo/utils.py +65 -0
- moospread/utils/mobo_utils/spread_mobo_utils.py +854 -0
- moospread/utils/offline_utils/__init__.py +10 -0
- moospread/utils/offline_utils/handle_task.py +203 -0
- moospread/utils/offline_utils/proxies.py +338 -0
- moospread/utils/spread_utils.py +91 -0
- moospread-0.1.0.dist-info/METADATA +75 -0
- moospread-0.1.0.dist-info/RECORD +63 -0
- moospread-0.1.0.dist-info/WHEEL +5 -0
- moospread-0.1.0.dist-info/licenses/LICENSE +10 -0
- moospread-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from moospread.utils.offline_utils.proxies import (SingleModelBaseTrainer,
|
|
2
|
+
MultipleModels,
|
|
3
|
+
SingleModel,
|
|
4
|
+
offdata_get_dataloader)
|
|
5
|
+
from moospread.utils.offline_utils.handle_task import (offdata_min_max_normalize,
|
|
6
|
+
offdata_min_max_denormalize,
|
|
7
|
+
offdata_z_score_normalize,
|
|
8
|
+
offdata_z_score_denormalize,
|
|
9
|
+
offdata_to_integers,
|
|
10
|
+
offdata_to_logits)
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
from typing import List, Tuple
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def one_hot(a: torch.Tensor, num_classes: int) -> torch.Tensor:
    """One-hot encode an integer tensor as float32.

    Args:
        a: Integer tensor of any shape holding class indices.
        num_classes: Total number of classes.

    Returns:
        Tensor of shape ``a.shape + (num_classes,)`` with dtype float32.

    Raises:
        ValueError: If ``a`` has a floating-point dtype.
    """
    if torch.is_floating_point(a):
        raise ValueError("cannot one-hot encode non-integers (got floating dtype)")
    encoded = F.one_hot(a.to(torch.long), num_classes=num_classes)
    return encoded.to(torch.float32)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def help_to_logits(
    x: torch.Tensor, num_classes: int, soft_interpolation: float = 0.6
) -> torch.Tensor:
    """Turn integer labels into logits of a smoothed categorical distribution.

    The one-hot (Dirac) distribution is blended with a uniform prior, the
    blend is converted to log-probabilities, and the first component is
    subtracted and dropped to remove the redundant degree of freedom.

    Args:
        x: Integer tensor of shape ``(n_samples, 1)``.
        num_classes: Number of classes at this position.
        soft_interpolation: Weight of the one-hot part, in ``[0, 1]``.

    Returns:
        Float32 tensor of shape ``(n_samples, 1, num_classes - 1)``.

    Raises:
        ValueError: If ``x`` has a floating-point dtype.
    """
    if torch.is_floating_point(x):
        raise ValueError("cannot convert non-integers to logits")

    # Dirac part, shape (n, 1, num_classes); one-hot encoding inlined.
    dirac = F.one_hot(x.to(torch.long), num_classes=num_classes).to(torch.float32)

    # Convex combination with the uniform prior 1/num_classes (a constant,
    # so plain scalar arithmetic broadcasts identically to a full tensor).
    blended = soft_interpolation * dirac + (1.0 - soft_interpolation) / float(num_classes)

    log_probs = torch.log(blended)

    # Anchor on the first class's log-probability and drop that component.
    return (log_probs[..., 1:] - log_probs[..., :1]).to(torch.float32)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def offdata_to_logits(
    x: torch.Tensor,
    num_classes_on_each_position: List[int],
    soft_interpolation: float = 0.6,
) -> torch.Tensor:
    """Encode categorical sequences as concatenated per-position logits.

    Each position ``i`` with ``k_i`` (adjusted) classes contributes a
    ``k_i - 1`` dimensional logit vector; all positions are concatenated
    along the last axis and the result squeezed.

    Args:
        x: Integer tensor of shape ``(n_samples, seq_len)``.
        num_classes_on_each_position: Class counts per sequence position.
        soft_interpolation: Interpolation factor in ``[0, 1]``.

    Returns:
        Float tensor of shape ``(n_samples, sum_i (k_i - 1))`` after squeeze.
    """
    # Match the original NumPy behavior: every count gains a dummy class,
    # and positions left with exactly one class gain one more.
    adjusted = [count + 1 for count in num_classes_on_each_position]
    adjusted = [count + 1 if count == 1 else count for count in adjusted]

    per_position = [
        help_to_logits(x[:, pos].reshape(-1, 1), adjusted[pos], soft_interpolation)
        for pos in range(len(adjusted))
    ]

    # Concatenate along the class axis, then mimic NumPy's `.squeeze()`.
    return torch.cat(per_position, dim=2).squeeze()
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def help_to_integers(x: torch.Tensor, true_num_of_classes: int) -> torch.Tensor:
    """Recover class indices from a position's logit vectors.

    Args:
        x: Float tensor of shape ``(n_samples, 1, k - 1)`` for ``k`` classes.
        true_num_of_classes: The (dummy-adjusted) class count ``k``.

    Returns:
        Int64 tensor of shape ``(n_samples, 1)`` with the selected classes.

    Raises:
        ValueError: If ``x`` is not floating point.
    """
    if not torch.is_floating_point(x):
        raise ValueError("cannot convert non-floats to integers")

    # Degenerate single-class position (RFP-Exact-v0 path): class 0 always wins.
    if true_num_of_classes == 1:
        return torch.zeros(x.shape[:-1], dtype=torch.int64, device=x.device)

    # The dropped reference class corresponds to an implicit zero logit,
    # so prepend a zero column before taking the argmax.
    with_reference = F.pad(x, pad=(1, 0))
    return with_reference.argmax(dim=-1).to(torch.int64)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def offdata_to_integers(x: torch.Tensor, num_classes_on_each_position: List[int]) -> torch.Tensor:
    """Invert ``offdata_to_logits``: map concatenated logits back to classes.

    Args:
        x: Float tensor of shape ``(n_samples, total)`` where ``total`` is
            ``sum_i (k_i - 1)`` under the same class-count adjustment used
            by ``offdata_to_logits``.
        num_classes_on_each_position: True class counts per position.

    Returns:
        Int tensor of shape ``(n_samples, seq_len)``.
    """
    # Same dummy-class adjustment as the encoder.
    counts = [c + 1 for c in num_classes_on_each_position]
    counts = [c + 1 if c == 1 else c for c in counts]

    decoded = []
    offset = 0
    for count in counts:
        span = count - 1
        # Each position owns `span` consecutive columns of the flat layout.
        chunk = x[:, offset : offset + span].reshape(-1, 1, span)
        decoded.append(help_to_integers(chunk, count))
        offset += span

    # Stack positions back along the sequence axis.
    return torch.cat(decoded, dim=1)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def offdata_z_score_normalize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Column-wise z-score normalization (NumPy-compatible statistics).

    Args:
        x: Float tensor of shape ``(n_samples, n_dim)``.

    Returns:
        Tuple ``(x_norm, mean, std)``; ``mean``/``std`` broadcast over rows.

    Raises:
        ValueError: If ``x`` is not floating point.
    """
    if not torch.is_floating_point(x):
        raise ValueError("cannot normalize discrete design values")

    col_mean = x.mean(dim=0)
    # Population std (ddof=0), matching NumPy's np.std default.
    col_std = x.std(dim=0, unbiased=False)
    return (x - col_mean) / col_std, col_mean, col_std
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def offdata_z_score_denormalize(x: torch.Tensor,
                                x_mean: torch.Tensor,
                                x_std: torch.Tensor) -> torch.Tensor:
    """Undo z-score normalization.

    Args:
        x: Float tensor of shape ``(n_samples, n_dim)``.
        x_mean: Mean used during normalization.
        x_std: Std used during normalization.

    Returns:
        Denormalized tensor on ``x``'s device.

    Raises:
        ValueError: If ``x`` is not floating point.
    """
    if not torch.is_floating_point(x):
        raise ValueError("cannot denormalize discrete design values")
    scale = x_std.to(x.device)
    shift = x_mean.to(x.device)
    return x * scale + shift
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def offdata_min_max_normalize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Column-wise min-max scaling.

    Args:
        x: Tensor of shape ``(n_samples, n_dim)``.

    Returns:
        Tuple ``(x_norm, x_min, x_max)``.
    """
    lo = x.min(dim=0).values
    hi = x.max(dim=0).values
    # NOTE: a constant column (hi == lo) divides by zero, as in the original.
    return (x - lo) / (hi - lo), lo, hi
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def offdata_min_max_denormalize(x: torch.Tensor,
                                x_min: torch.Tensor,
                                x_max: torch.Tensor) -> torch.Tensor:
    """Undo min-max normalization.

    Args:
        x: Tensor of shape ``(n_samples, n_dim)``.
        x_min: Per-dimension minimum from the forward pass.
        x_max: Per-dimension maximum from the forward pass.

    Returns:
        Denormalized tensor on ``x``'s device.
    """
    lo = x_min.to(x.device)
    hi = x_max.to(x.device)
    return x * (hi - lo) + lo
|
|
@@ -0,0 +1,338 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Adapted from: https://github.com/lamda-bbo/offline-moo/blob/main/off_moo_baselines/multiple_models/nets.py
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
from torch.optim import Adam
|
|
10
|
+
from torch.utils.data import DataLoader, TensorDataset, random_split
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
|
|
13
|
+
def offdata_get_dataloader(
    X,
    y,
    train_ratio: float = 0.9,
    batch_size: int = 32,
):
    """Split ``(X, y)`` into shuffled train / sequential validation loaders.

    Args:
        X: Design tensor; the first axis indexes samples.
        y: Target tensor aligned with ``X``.
        train_ratio: Fraction of samples assigned to the training split.
        batch_size: Training batch size; validation uses ``4 * batch_size``.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    full = TensorDataset(X, y)
    n_train = int(train_ratio * len(full))
    # random_split shuffles which samples land in each subset.
    train_set, val_set = random_split(full, [n_train, len(full) - n_train])

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=False),
        DataLoader(val_set, batch_size=batch_size * 4, shuffle=False, drop_last=False),
    )
|
|
36
|
+
|
|
37
|
+
def spearman_correlation(x, y):
    """Column-wise Spearman rank correlation between two tensors.

    Args:
        x: Tensor of shape ``(n, m)`` (e.g. predictions).
        y: Tensor of shape ``(n, m)`` (e.g. targets).

    Returns:
        Float tensor of shape ``(m,)`` with one coefficient per column.

    Note:
        Uses the classic formula ``1 - 6 * sum(d^2) / (n * (n^2 - 1))``,
        which assumes no ties within a column.
    """
    n = x.size(0)
    # BUG FIX: `x.sort(0)` returns the argsort (positions of the sorted
    # elements), not the ranks. The rank of each element is the argsort
    # of the argsort; using raw argsort yields wrong coefficients
    # (e.g. x=[3,1,2], y=[2,1,3] gave -1.0 instead of 0.5).
    rank_x = x.argsort(dim=0).argsort(dim=0)
    rank_y = y.argsort(dim=0).argsort(dim=0)

    d = rank_x - rank_y
    d_squared_sum = (d**2).sum(0).float()

    rho = 1 - (6 * d_squared_sum) / (n * (n**2 - 1))
    return rho
|
|
47
|
+
|
|
48
|
+
class SingleModelBaseTrainer(nn.Module):
    """Training harness for one proxy ``model`` of a single objective.

    Hyperparameters come from the ``args`` mapping (keys ``proxies_lr``,
    ``proxies_lr_decay``, ``proxies_epochs``, ``device``, ``verbose``).
    The model is optimized with Adam and checkpointed (via ``model.save``)
    whenever the validation Pearson correlation improves.
    """

    def __init__(self, model, which_obj, args):
        super(SingleModelBaseTrainer, self).__init__()
        self.args = args

        self.forward_lr = args["proxies_lr"]
        self.forward_lr_decay = args["proxies_lr_decay"]
        self.n_epochs = args["proxies_epochs"]
        self.device = args["device"]
        self.verbose = args["verbose"]

        self.model = model

        self.which_obj = which_obj

        self.forward_opt = Adam(model.parameters(), lr=args["proxies_lr"])
        # Per-sample MSE averaged across objectives, then summed over the
        # batch (so the loss scales with batch size).
        self.train_criterion = lambda yhat, y: (
            torch.sum(torch.mean((yhat - y) ** 2, dim=1))
        )
        self.mse_criterion = nn.MSELoss()

    def _evaluate_performance(self, statistics, epoch, train_loader, val_loader):
        """Record train/valid MSE and rank correlations into ``statistics``.

        Mutates ``statistics`` in place and also returns it. Saves a
        checkpoint whenever the validation PCC improves on ``self.min_pcc``.
        NOTE: relies on ``launch`` having set ``self.n_obj`` and
        ``self.min_pcc`` beforehand.
        """
        self.model.eval()
        with torch.no_grad():
            # Accumulate targets and predictions over the full train split.
            y_all = torch.zeros((0, self.n_obj)).to(self.device)
            outputs_all = torch.zeros((0, self.n_obj)).to(self.device)
            for (
                batch_x,
                batch_y,
            ) in train_loader:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)

                y_all = torch.cat((y_all, batch_y), dim=0)
                outputs = self.model(batch_x)
                outputs_all = torch.cat((outputs_all, outputs), dim=0)

            train_mse = self.mse_criterion(outputs_all, y_all)
            # spearman_correlation returns one coefficient per column.
            train_corr = spearman_correlation(outputs_all, y_all)
            train_pcc = self.compute_pcc(outputs_all, y_all)

            statistics[f"model_{self.which_obj}/train/mse"] = train_mse.item()
            for i in range(self.n_obj):
                statistics[f"model_{self.which_obj}/train/rank_corr_{i + 1}"] = (
                    train_corr[i].item()
                )
            # if self.verbose:
            #     print(
            #         "Epoch [{}/{}], MSE: {:}, PCC: {:}".format(
            #             epoch + 1, self.n_epochs, train_mse.item(), train_pcc.item()
            #         )
            #     )

        # Same accumulation over the validation split.
        with torch.no_grad():
            y_all = torch.zeros((0, self.n_obj)).to(self.device)
            outputs_all = torch.zeros((0, self.n_obj)).to(self.device)

            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)

                y_all = torch.cat((y_all, batch_y), dim=0)
                outputs = self.model(batch_x)
                outputs_all = torch.cat((outputs_all, outputs))

            val_mse = self.mse_criterion(outputs_all, y_all)
            val_corr = spearman_correlation(outputs_all, y_all)
            val_pcc = self.compute_pcc(outputs_all, y_all)

            statistics[f"model_{self.which_obj}/valid/mse"] = val_mse.item()
            for i in range(self.n_obj):
                statistics[f"model_{self.which_obj}/valid/rank_corr_{i + 1}"] = (
                    val_corr[i].item()
                )

            # if self.verbose:
            #     print(
            #         "Valid MSE: {:}, Valid PCC: {:}".format(val_mse.item(), val_pcc.item())
            #     )

            # Model selection: keep the checkpoint with the best valid PCC.
            if val_pcc.item() > self.min_pcc:
                # if self.verbose:
                #     print("🌸 New best epoch! 🌸")
                self.min_pcc = val_pcc.item()
                self.model.save(val_pcc=self.min_pcc)
        return statistics

    def launch(
        self,
        train_loader=None,
        val_loader=None,
        retrain_model: bool = True,
    ):
        """Run the training loop (or load an existing checkpoint).

        Args:
            train_loader: Loader of ``(x, y)`` training batches; required
                unless a checkpoint is loaded.
            val_loader: Validation loader; required unless a checkpoint is
                loaded.
            retrain_model: When False and ``self.model.save_path`` exists,
                load that checkpoint and skip training entirely.
        """

        def update_lr(optimizer, lr):
            # Apply the decayed learning rate to every parameter group.
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        if not retrain_model and os.path.exists(self.model.save_path):
            self.model.load()
            return

        assert train_loader is not None
        assert val_loader is not None

        # n_obj is discovered lazily from the first batch's target width.
        self.n_obj = None
        self.min_pcc = -1.0
        statistics = {}

        with tqdm(
            total=self.n_epochs,
            desc=f"Proxy Training",
            unit="epoch",
        ) as pbar:

            for epoch in range(self.n_epochs):
                self.model.train()

                losses = []
                epoch_loss = 0.0
                for batch_x, batch_y in train_loader:
                    batch_x = batch_x.to(self.device)
                    batch_y = batch_y.to(self.device)
                    if self.n_obj is None:
                        self.n_obj = batch_y.shape[1]

                    self.forward_opt.zero_grad()
                    outputs = self.model(batch_x)
                    loss = self.train_criterion(outputs, batch_y)
                    # Track the per-sample loss (criterion sums over batch).
                    losses.append(loss.item() / batch_x.size(0))
                    loss.backward()
                    self.forward_opt.step()
                    epoch_loss += loss.item()

                statistics[f"model_{self.which_obj}/train/loss/mean"] = np.array(
                    losses
                ).mean()
                statistics[f"model_{self.which_obj}/train/loss/std"] = np.array(
                    losses
                ).std()
                statistics[f"model_{self.which_obj}/train/loss/max"] = np.array(
                    losses
                ).max()

                self._evaluate_performance(statistics, epoch, train_loader, val_loader)

                # Exponential learning-rate decay after each epoch.
                statistics[f"model_{self.which_obj}/train/lr"] = self.forward_lr
                self.forward_lr *= self.forward_lr_decay
                update_lr(self.forward_opt, self.forward_lr)

                epoch_loss /= len(train_loader)
                pbar.set_postfix({
                    "Loss": epoch_loss,
                })
                pbar.update(1)

    def compute_pcc(self, valid_preds, valid_labels):
        """Pearson correlation over all elements (flattened), eps-stabilized."""
        vx = valid_preds - torch.mean(valid_preds)
        vy = valid_labels - torch.mean(valid_labels)
        pcc = torch.sum(vx * vy) / (
            torch.sqrt(torch.sum(vx**2) + 1e-12) * torch.sqrt(torch.sum(vy**2) + 1e-12)
        )
        return pcc
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class MultipleModels(nn.Module):
    """A bundle of independent single-objective proxy models.

    One ``SingleModel`` per objective is kept in ``self.obj2model``;
    ``forward`` evaluates a subset (or all) of them and concatenates the
    scalar outputs column-wise.
    """

    def __init__(
        self, n_dim, n_obj, hidden_size, train_mode, device,
        save_dir=None, save_prefix=None
    ):
        super(MultipleModels, self).__init__()
        self.n_dim = n_dim
        self.n_obj = n_obj
        self.device = device

        self.obj2model = {}
        self.hidden_size = hidden_size
        self.train_mode = train_mode

        self.save_dir = save_dir
        self.save_prefix = save_prefix
        if self.save_dir is not None:
            os.makedirs(self.save_dir, exist_ok=True)

        for objective_index in range(self.n_obj):
            self.create_models(objective_index)

    def create_models(self, learning_obj):
        """Register a fresh ``SingleModel`` for objective ``learning_obj``."""
        self.obj2model[learning_obj] = SingleModel(
            self.n_dim,
            self.hidden_size,
            which_obj=learning_obj,
            device=self.device,
            save_dir=self.save_dir,
            save_prefix=self.save_prefix,
        )

    def set_kwargs(self, device=None, dtype=torch.float32):
        """Propagate device/dtype to every sub-model."""
        for sub_model in self.obj2model.values():
            sub_model.set_kwargs(device=device, dtype=dtype)
            sub_model.to(device=device, dtype=dtype)

    def forward(self, x, forward_objs=None):
        """Evaluate the requested objectives and stack outputs column-wise."""
        objs = list(self.obj2model.keys()) if forward_objs is None else forward_objs
        columns = [self.obj2model[obj](x) for obj in objs]
        return torch.cat(columns, dim=1)
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
# Activation instances shared by every SingleModel, indexed per hidden layer
# in SingleModel.forward. NOTE(review): only two entries, so forward assumes
# len(hidden_size) <= 2 — confirm before configuring deeper networks.
activate_functions = [nn.LeakyReLU(), nn.LeakyReLU()]
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
class SingleModel(nn.Module):
    """A scalar-output MLP proxy for one objective.

    Architecture: ``Linear -> act -> ... -> act -> Linear(., 1)`` where
    the activations come from the module-level ``activate_functions``
    list. Checkpoints live at ``{save_dir}/{save_prefix}-{which_obj}.pt``.
    """

    def __init__(
        self, input_size, hidden_size, which_obj, device,
        save_dir=None, save_prefix=None
    ):
        super(SingleModel, self).__init__()
        self.n_dim = input_size
        self.n_obj = 1
        self.which_obj = which_obj
        self.activate_functions = activate_functions
        self.device = device

        # Layer widths: input -> hidden[0] -> ... -> hidden[-1] -> 1.
        # Linear modules are created in the same order as the original,
        # so parameter initialization draws are identical.
        widths = [input_size] + list(hidden_size) + [1]
        self.layers = nn.Sequential(
            *[nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )
        self.hidden_size = hidden_size

        self.save_path = os.path.join(save_dir, f"{save_prefix}-{which_obj}.pt")

    def forward(self, x):
        """Apply each hidden Linear + activation, then the final linear head."""
        for idx in range(len(self.hidden_size)):
            x = self.activate_functions[idx](self.layers[idx](x))
        return self.layers[len(self.hidden_size)](x)

    def set_kwargs(self, device=None, dtype=torch.float32):
        """Move the model to the given device/dtype."""
        self.to(device=device, dtype=dtype)

    def check_model_path_exist(self, save_path=None):
        """Return True when a checkpoint exists at ``save_path`` (or the default)."""
        assert (
            self.save_path is not None or save_path is not None
        ), "save path should be specified"
        return os.path.exists(self.save_path if save_path is None else save_path)

    def save(self, val_pcc=None, save_path=None):
        """Checkpoint the weights (serialized on CPU), optionally with the valid PCC."""
        assert (
            self.save_path is not None or save_path is not None
        ), "save path should be specified"
        if save_path is None:
            save_path = self.save_path

        # Serialize from CPU so the checkpoint is device-agnostic.
        self = self.to("cpu")
        checkpoint = {
            "model_state_dict": self.state_dict(),
        }
        if val_pcc is not None:
            checkpoint["valid_pcc"] = val_pcc

        torch.save(checkpoint, save_path)
        self = self.to(self.device)

    def load(self, save_path=None):
        """Restore weights from a checkpoint and report its validation PCC."""
        assert (
            self.save_path is not None or save_path is not None
        ), "save path should be specified"
        if save_path is None:
            save_path = self.save_path

        checkpoint = torch.load(save_path, weights_only=False)
        self.load_state_dict(checkpoint["model_state_dict"])
        # NOTE(review): assumes the checkpoint was saved with val_pcc set;
        # a checkpoint saved without it would raise KeyError here.
        valid_pcc = checkpoint["valid_pcc"]
        print(
            f"Successfully load trained proxy model from {save_path} "
            f"with valid PCC = {valid_pcc}"
        )
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
2
|
+
import dis
|
|
3
|
+
import inspect
|
|
4
|
+
import ast
|
|
5
|
+
import textwrap
|
|
6
|
+
|
|
7
|
+
def get_ddpm_dataloader(X,
                        y,
                        validation_split=0.1,
                        batch_size=32):
    """Build train/validation loaders for DDPM training.

    The validation split is taken deterministically from the tail of the
    data (no shuffling before the split), and both loaders drop the last
    incomplete batch.

    Args:
        X: Sample tensor; the first axis indexes samples.
        y: Target tensor aligned with ``X``.
        validation_split: Fraction held out for validation; 0 disables it.
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_dataloader, val_dataloader)``; ``val_dataloader`` is
        ``None`` when no validation split was requested.
    """
    val_dataloader = None
    if validation_split > 0.0:
        n_train = int(X.shape[0] - int(X.shape[0] * validation_split))
        X, X_val = X[:n_train], X[n_train:]
        y, y_val = y[:n_train], y[n_train:]

        val_dataloader = DataLoader(
            TensorDataset(X_val.float(), y_val.float()),
            batch_size=batch_size,
            shuffle=False,
            drop_last=True,
        )

    train_dataloader = DataLoader(
        TensorDataset(X.float(), y.float()),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

    return train_dataloader, val_dataloader
|
|
39
|
+
|
|
40
|
+
def is_pass_function(func) -> bool:
    """
    Return True iff the function body is effectively just `pass`
    (optionally with a docstring). Works on Python 3.8–3.12+.
    """
    # -------- Try AST first (most reliable) --------
    # getsource fails for builtins / C functions (TypeError) or when the
    # source file is unavailable (OSError); fall through to bytecode then.
    try:
        src = inspect.getsource(func)
    except (OSError, TypeError):
        src = None

    if src:
        # dedent: getsource keeps the original indentation for nested defs.
        mod = ast.parse(textwrap.dedent(src))
        # First (async) function definition found anywhere in the parse tree.
        fn = next((n for n in ast.walk(mod)
                   if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))), None)
        if fn:
            body = list(fn.body)
            # drop docstring if present (a leading string-constant expression)
            if body and isinstance(body[0], ast.Expr) and \
               isinstance(getattr(body[0], "value", None), ast.Constant) and \
               isinstance(body[0].value.value, str):
                body = body[1:]
            # True if remaining stmts are all Pass (or empty)
            return all(isinstance(n, ast.Pass) for n in body)

    # -------- Fallback: bytecode pattern (version-tolerant) --------
    instrs = list(dis.get_instructions(func))

    # remove version-specific noise (e.g. RESUME appeared in 3.11,
    # CACHE entries come from adaptive specialization)
    noise = {"RESUME", "CACHE", "EXTENDED_ARG", "NOP"}
    core = [i for i in instrs if i.opname not in noise]

    # strip docstring store: LOAD_CONST <str>; STORE_* __doc__
    i = 0
    while i + 1 < len(core):
        a, b = core[i], core[i + 1]
        if (a.opname == "LOAD_CONST" and isinstance(a.argval, str)
                and b.opname in {"STORE_NAME", "STORE_FAST"} and b.argval == "__doc__"):
            del core[i:i+2]
            # re-check the same index: the deletion shifted elements left
            continue
        i += 1

    # Accept either:
    #   - LOAD_CONST None; RETURN_VALUE
    #   - RETURN_CONST None (3.12+)
    if len(core) == 2 and core[0].opname == "LOAD_CONST" and core[0].argval is None \
            and core[1].opname == "RETURN_VALUE":
        return True
    if len(core) == 1 and core[0].opname == "RETURN_CONST" and core[0].argval is None:
        return True

    return False
|