blnetwork-0.1.0-py3-none-any.whl
- blnetwork/export.py +269 -0
- blnetwork/inference/__init__.py +10 -0
- blnetwork/inference/continuous.py +21 -0
- blnetwork/inference/discrete.py +40 -0
- blnetwork/model/__init__.py +7 -0
- blnetwork/model/bldeep.py +229 -0
- blnetwork/model/utils.py +68 -0
- blnetwork/training/__init__.py +17 -0
- blnetwork/training/amortized.py +234 -0
- blnetwork/training/base.py +289 -0
- blnetwork/training/continuous.py +40 -0
- blnetwork/training/discrete.py +34 -0
- blnetwork/training/losses.py +80 -0
- blnetwork/training/utils.py +207 -0
- blnetwork-0.1.0.dist-info/METADATA +121 -0
- blnetwork-0.1.0.dist-info/RECORD +19 -0
- blnetwork-0.1.0.dist-info/WHEEL +5 -0
- blnetwork-0.1.0.dist-info/licenses/LICENSE +21 -0
- blnetwork-0.1.0.dist-info/top_level.txt +1 -0
blnetwork/export.py
ADDED
@@ -0,0 +1,269 @@

from __future__ import annotations

from typing import List, Optional, Tuple, Dict

import numpy as np
import torch
import sys
import io
import torch.nn.functional as F

def _fmt_num(x: float, ndigits: int = 4) -> str:
    s = f"{x:.{ndigits}f}"
    s = s.rstrip("0").rstrip(".")
    return s if s != "" else "0"


def _safe_numpy(t):
    if t is None:
        return None
    if torch.is_tensor(t):
        return t.detach().cpu().numpy()
    return np.asarray(t)


def _get_bl_unit(block):
    unit = getattr(block, "unit", None)
    if unit is None:
        raise AttributeError("Block has no attribute 'unit' (expected BLBlock.unit).")
    return unit


def _get_lambdas(unit):
    lam = unit.lam
    if getattr(unit, "constrain_lambda", False):
        eps = float(getattr(unit, "eps", 1e-8))
        lam = F.softplus(lam) + eps

    lam_u = _safe_numpy(lam[0])
    lam_c = _safe_numpy(lam[1])
    lam_t = _safe_numpy(lam[2])

    return lam_u, lam_c, lam_t


def _get_backbone(model):
    backbone = getattr(model, "backbone", None)
    if backbone is not None:
        return backbone
    if hasattr(model, "blocks"):
        return model
    raise AttributeError("Cannot find backbone. Expected model.backbone or model.blocks.")


def _get_blocks(model) -> list:
    backbone = _get_backbone(model)
    blocks = getattr(backbone, "blocks", None)
    if blocks is None:
        raise AttributeError("Backbone has no attribute 'blocks'.")
    return list(blocks)


def _get_hidden_dims(model) -> Optional[Tuple[int, ...]]:
    backbone = _get_backbone(model)
    v = getattr(backbone, "hidden_dims", None)
    if v is None:
        return None
    return tuple(map(int, v))


def _get_output_linears(model) -> Dict[str, torch.nn.Linear]:
    outs: Dict[str, torch.nn.Linear] = {}

    lin = getattr(model, "linear_out", None)
    if lin is not None and hasattr(lin, "weight"):
        outs["Output Layer (Discrete)"] = lin
        return outs

    head = getattr(model, "head", None)
    if head is None:
        return outs

    if isinstance(head, torch.nn.Linear):
        outs["Output Layer"] = head
        return outs

    lin = getattr(head, "linear", None)
    if lin is not None and hasattr(lin, "weight"):
        outs["Output Layer"] = lin

    return outs


def _emit_part_lines(
    part_name: str,
    lam_val: float,
    w_row: np.ndarray,
    b_val: float,
    feature_names: List[str],
    ndigits: int = 4,
    tol: float = 0.0,
) -> List[str]:
    lines = []
    lines.append(f"{part_name}")
    lines.append(f"lambda {_fmt_num(float(lam_val), ndigits)}")

    for name, w in zip(feature_names, w_row):
        w = float(w)
        if abs(w) > tol:
            lines.append(f"----{name} {_fmt_num(w, ndigits)}")

    if abs(float(b_val)) > tol:
        lines.append(f"----C {_fmt_num(float(b_val), ndigits)}")

    return lines


def _print_blocks(
    block,
    feature_names: List[str],
    layer_idx: int,
    ndigits: int = 4,
    tol: float = 0.0,
) -> int:

    unit = _get_bl_unit(block)

    num_basis = int(unit.lin_u.out_features)

    lam_u, lam_c, lam_t = _get_lambdas(unit)

    w_u = _safe_numpy(unit.lin_u.weight)
    b_u = _safe_numpy(unit.lin_u.bias)
    w_c = _safe_numpy(unit.lin_c.weight)
    b_c = _safe_numpy(unit.lin_c.bias)
    w_t = _safe_numpy(unit.lin_t.weight)
    b_t = _safe_numpy(unit.lin_t.bias)

    for j in range(num_basis):
        block_id = j + 1
        print(f"--B{layer_idx}{block_id}")

        # U part
        lines = _emit_part_lines(
            "U",
            lam_val=float(lam_u[j]),
            w_row=w_u[j],
            b_val=float(b_u[j]),
            feature_names=feature_names,
            ndigits=ndigits,
            tol=tol,
        )
        for ln in lines:
            print(ln)

        # C part
        lines = _emit_part_lines(
            "C",
            lam_val=float(lam_c[j]),
            w_row=w_c[j],
            b_val=float(b_c[j]),
            feature_names=feature_names,
            ndigits=ndigits,
            tol=tol,
        )
        for ln in lines:
            print(ln)

        # T part
        lines = _emit_part_lines(
            "T",
            lam_val=float(lam_t[j]),
            w_row=w_t[j],
            b_val=float(b_t[j]),
            feature_names=feature_names,
            ndigits=ndigits,
            tol=tol,
        )
        for ln in lines:
            print(ln)

        print("")

    return num_basis


def _print_core(model, blocks, hidden_dims, feat_names, ndigits, tol, title="BL Model Structure"):
    print("=" * 72)
    print(title)
    print("=" * 72)

    print(f"hidden_dims = {hidden_dims}")
    print(f"feature_dim = {len(feat_names)}")
    print("")

    current_feat_names = feat_names

    for layer_idx, block in enumerate(blocks, start=1):
        num_basis = _print_blocks(
            block=block,
            feature_names=current_feat_names,
            layer_idx=layer_idx,
            ndigits=ndigits,
            tol=tol,
        )

        if layer_idx < len(blocks):
            current_feat_names = [f"B{layer_idx}{i+1}" for i in range(num_basis)]

    outs = _get_output_linears(model)
    if len(outs) > 0:
        print("=" * 72)
        print("OUTPUT LINEAR(S)")
        print("=" * 72)
        for name, lin in outs.items():
            w = _safe_numpy(lin.weight)
            b = _safe_numpy(lin.bias) if lin.bias is not None else None
            print(name)
            if getattr(w, "size", 0) and w.size <= 16:
                w_flat = [float(x) for x in w.reshape(-1)]
                w_fmt = ", ".join(_fmt_num(x, ndigits) for x in w_flat)
                print(f"weight values = [{w_fmt}]")

            if b is None:
                print("bias = None")
            else:
                if getattr(b, "size", 0) and b.size <= 16:
                    b_flat = [float(x) for x in b.reshape(-1)]
                    b_fmt = ", ".join(_fmt_num(x, ndigits) for x in b_flat)
                    print(f"bias values = [{b_fmt}]")
            print("")

def export_structure(
    model,
    df=None,
    feature_names=None,
    txt_path: Optional[str] = None,
    ndigits: int = 4,
    tol: float = 0.0,
    title: str = "BL Model Structure",
):

    blocks = _get_blocks(model)
    hidden_dims = _get_hidden_dims(model)

    if feature_names is not None:
        feat_names = list(feature_names)
    elif df is not None:
        feat_names = list(df.columns)
    else:
        unit0 = _get_bl_unit(blocks[0])
        if not hasattr(unit0, "lin_u"):
            raise AttributeError("Cannot infer input dim from BLUnit. Expected 'lin_u'.")
        in_dim = int(unit0.lin_u.in_features)
        feat_names = [f"x{i+1}" for i in range(in_dim)]

    if txt_path is None:
        _print_core(model, blocks, hidden_dims, feat_names, ndigits, tol, title)
        return

    old = sys.stdout
    buf = io.StringIO()
    sys.stdout = buf
    try:
        _print_core(model, blocks, hidden_dims, feat_names, ndigits, tol, title)
    finally:
        sys.stdout = old

    with open(txt_path, "w", encoding="utf-8") as f:
        f.write(buf.getvalue())
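
A minimal usage sketch for export_structure, assuming the BLDeep model defined
later in this diff (the toy data, feature names, and output path are
illustrative, not part of the wheel):

    import torch
    from blnetwork.model.bldeep import BLDeep
    from blnetwork.export import export_structure

    X = torch.randn(8, 3)   # toy data: 8 samples, 3 features
    y = torch.randn(8)      # scalar continuous target

    model = BLDeep(hidden_dims=[4, 2], task="continuous")
    model.build(X, y)  # lazily creates backbone and head from the data shapes

    # The backbone consumes cat([x, y], dim=1), so 4 names are needed here.
    export_structure(model, feature_names=["x1", "x2", "x3", "y"], ndigits=3)
    export_structure(model, txt_path="bl_structure.txt")  # same report, to a file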
blnetwork/inference/continuous.py
ADDED
@@ -0,0 +1,21 @@

from __future__ import annotations

import torch
import torch.nn as nn

from ..training import utils as U


@torch.no_grad()
def predict_continuous(
    predictor: nn.Module,
    x: torch.Tensor,
    *,
    device: str | torch.device | None = None,
    return_cpu: bool = True,
) -> torch.Tensor:

    predictor.eval()
    dev = U.resolve_device(model=predictor, tensor=x, device=device)
    y_hat = predictor(x.to(dev))
    return y_hat.detach().cpu() if return_cpu else y_hat.detach()
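
Usage sketch with a stand-in predictor (any trained nn.Module works; device
resolution is delegated internally to blnetwork.training.utils.resolve_device):

    import torch
    import torch.nn as nn
    from blnetwork.inference.continuous import predict_continuous

    predictor = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 1))
    x = torch.randn(5, 3)

    y_hat = predict_continuous(predictor, x)  # CPU tensor of shape (5, 1)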
blnetwork/inference/discrete.py
ADDED
@@ -0,0 +1,40 @@

import torch
import torch.nn.functional as F

from ..training import utils as U


@torch.no_grad()
def predict_proba_discrete(
    model,
    x: torch.Tensor,
    *,
    temperature: float = 1.0,
    device: str | torch.device | None = None,
    return_cpu: bool = False,
) -> torch.Tensor:

    dev = U.resolve_device(model, x, device)
    x = x.to(dev)

    scores = model.logits(x) if hasattr(model, "logits") else model(x)
    logits = scores / float(temperature)
    probs = F.softmax(logits, dim=1)

    return probs.cpu() if return_cpu else probs


@torch.no_grad()
def predict_class_discrete(
    model,
    x: torch.Tensor,
    *,
    temperature: float = 1.0,
    device: str | torch.device | None = None,
    return_cpu: bool = True,
) -> torch.Tensor:
    probs = predict_proba_discrete(
        model, x, temperature=temperature, device=device, return_cpu=False
    )
    pred = probs.argmax(dim=1)
    return pred.cpu() if return_cpu else pred
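
Usage sketch, assuming a BLDeep classifier (any module exposing logits(), or a
plain nn.Module returning class scores, also works, since the code falls back
to model(x)):

    import torch
    from blnetwork.model.bldeep import BLDeep
    from blnetwork.inference.discrete import (
        predict_proba_discrete,
        predict_class_discrete,
    )

    X = torch.randn(6, 3)
    y = torch.tensor([0, 1, 2, 0, 1, 2])  # labels already in [0..K-1]

    model = BLDeep(hidden_dims=[4, 2], task="discrete")
    model.build(X, y)

    probs = predict_proba_discrete(model, X)  # (6, 3), rows sum to 1
    preds = predict_class_discrete(model, X)  # (6,) class indices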
blnetwork/model/bldeep.py
ADDED
@@ -0,0 +1,229 @@

from typing import List, Sequence, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import utils as U


class BLUnit(nn.Module):
    def __init__(
        self,
        in_dim: int,
        num_basis: int,
        second_act_func: str = "relu",
        third_act_func: str = "abs",
        eps: float = 1e-8,
        constrain_lambda: bool = True,
        init_lambda: float = 1.0,
        beta: float = 1.0,
    ) -> None:
        super().__init__()
        self.in_dim = int(in_dim)
        self.num_basis = int(num_basis)
        self.second_act_func = str(second_act_func)
        self.third_act_func = str(third_act_func)
        self.eps = float(eps)
        self.constrain_lambda = bool(constrain_lambda)
        self.beta = float(beta)

        init_lambda = float(init_lambda)

        self.lin_u = nn.Linear(self.in_dim, self.num_basis, bias=True)
        self.lin_c = nn.Linear(self.in_dim, self.num_basis, bias=True)
        self.lin_t = nn.Linear(self.in_dim, self.num_basis, bias=True)

        self.lam = nn.Parameter(torch.full((3, self.num_basis), init_lambda))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        u = torch.tanh(self.lin_u(z))
        c = U.second_activation(self.lin_c(z), self.second_act_func, beta=self.beta)
        t = U.third_activation(self.lin_t(z), self.third_act_func)

        if self.constrain_lambda:
            lam = F.softplus(self.lam) + self.eps
        else:
            lam = self.lam

        lam_u = lam[0]
        lam_c = lam[1]
        lam_t = lam[2]

        return lam_u * u - lam_c * c - lam_t * t


class BLBlock(nn.Module):
    def __init__(
        self,
        in_dim: int,
        num_basis: int,
        second_act_func: str = "relu",
        third_act_func: str = "abs",
        constrain_lambda: bool = True,
        init_lambda: float = 1.0,
        beta: float = 1.0,
    ) -> None:
        super().__init__()
        self.unit = BLUnit(
            in_dim=in_dim,
            num_basis=num_basis,
            second_act_func=second_act_func,
            third_act_func=third_act_func,
            constrain_lambda=constrain_lambda,
            init_lambda=init_lambda,
            beta=beta,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.unit(x)


class BLDeepBackbone(nn.Module):
    def __init__(
        self,
        in_dim: int,
        hidden_dims: Sequence[int],
        second_act_func: str = "relu",
        third_act_func: str = "abs",
        constrain_lambda: bool = True,
        init_lambda: float = 1.0,
        beta: float = 1.0,
    ) -> None:
        super().__init__()

        self.in_dim = int(in_dim)
        self.hidden_dims = list(hidden_dims)
        dims: List[int] = [int(in_dim)] + list(hidden_dims[:-1])

        self.blocks = nn.ModuleList(
            BLBlock(
                dims[i],
                num_basis=hidden_dims[i],
                second_act_func=second_act_func,
                third_act_func=third_act_func,
                constrain_lambda=constrain_lambda,
                init_lambda=init_lambda,
                beta=beta,
            )
            for i in range(len(hidden_dims))
        )
        self.out_dim = int(hidden_dims[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for blk in self.blocks:
            x = blk(x)
        return x


class BLDeep(nn.Module):
    def __init__(
        self,
        hidden_dims: Sequence[int],
        second_act_func: str = "relu",
        third_act_func: str = "abs",
        head_bias: bool = True,
        num_classes: Optional[int] = None,
        task: str = "continuous",
        constrain_lambda: bool = True,
        init_lambda: float = 1.0,
        beta: float = 1.0,
    ) -> None:
        super().__init__()
        self.hidden_dims = list(hidden_dims)
        self.second_act_func = str(second_act_func)
        self.third_act_func = third_act_func
        self.head_bias = head_bias
        self.num_classes = num_classes

        if task not in {"continuous", "discrete"}:
            raise ValueError(f"task must be 'continuous' or 'discrete', got '{task}'")
        self.task = task
        self.constrain_lambda = bool(constrain_lambda)
        self.init_lambda = float(init_lambda)
        self.beta = float(beta)

        self.x_dim: Optional[int] = None
        self.y_dim: Optional[int] = None
        self.backbone: Optional[nn.Module] = None
        self.head: Optional[nn.Module] = None

    def _build_architecture(self, x: torch.Tensor) -> None:
        self.backbone = BLDeepBackbone(
            in_dim=self.x_dim + self.y_dim,
            hidden_dims=self.hidden_dims,
            second_act_func=self.second_act_func,
            third_act_func=self.third_act_func,
            constrain_lambda=self.constrain_lambda,
            init_lambda=self.init_lambda,
            beta=self.beta,
        )
        self.head = nn.Linear(self.backbone.out_dim, 1, bias=self.head_bias)

        device, dtype = x.device, x.dtype
        self.to(device=device, dtype=dtype)

    def build(self, X: torch.Tensor, y: torch.Tensor) -> None:
        self.x_dim = X.shape[1]

        if self.task == "discrete":
            m, y_idx = U.infer_num_classes(y)
            self.num_classes = int(m)
            self.y_dim = int(m)
            if not torch.equal(y_idx, y.long()):
                raise ValueError(
                    "Discrete labels must be in range [0..K-1]. "
                    "Got non-contiguous labels. Please remap your labels to [0, 1, 2, ...] first."
                )
        else:
            self.y_dim = 1 if y.ndim == 1 else y.shape[1]

        self._build_architecture(X)

    def build_for_discrete_inference(self, x: torch.Tensor, num_classes: int) -> None:
        if x.ndim != 2:
            raise ValueError(f"x must be 2D (B, x_dim), got shape {tuple(x.shape)}")
        if self.task != "discrete":
            raise ValueError(f"This function only works with task='discrete', got task='{self.task}'")

        self.x_dim = int(x.shape[1])
        self.num_classes = int(num_classes)
        self.y_dim = int(self.num_classes)

        self._build_architecture(x)

    def score(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if self.backbone is None:
            self.build(x, y)

        if y.ndim == 1:
            if self.task == "discrete":
                y = F.one_hot(y.long(), num_classes=self.num_classes).to(device=x.device, dtype=x.dtype)
            else:
                y = y.unsqueeze(1)
        elif y.ndim == 2:
            if self.task == "discrete" and y.shape[1] == 1:
                raise ValueError(
                    "For discrete task, y should be 1D class indices (shape (B,)), "
                    f"but got shape {tuple(y.shape)}. Use y.squeeze(1) to convert (B,1) -> (B,)."
                )

        z = torch.cat([x, y.to(device=x.device, dtype=x.dtype)], dim=1)
        feats = self.backbone(z)
        return self.head(feats)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return self.score(x, y)

    def logits(self, x: torch.Tensor, num_classes: int | None = None) -> torch.Tensor:
        if self.task != "discrete":
            raise RuntimeError("logits() is only available when task='discrete'.")

        m = int(num_classes or (self.num_classes or 0))
        if m <= 0:
            raise RuntimeError(
                "Unknown num_classes for discrete inference. "
                "Pass num_classes=K, or initialize BLDeep(..., num_classes=K)."
            )
        if self.backbone is None:
            self.build_for_discrete_inference(x, num_classes=m)

        return U.enumerate_onehot_logits(self.score, x, m=m)
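
BLDeep scores an (x, y) pair jointly: x and y are concatenated, passed through
the stack of BL blocks, and reduced to one scalar by the linear head; for
discrete tasks, per-class logits are recovered by scoring every one-hot
candidate (see enumerate_onehot_logits in blnetwork/model/utils.py, next in
this diff). A minimal continuous-task sketch:

    import torch
    from blnetwork.model.bldeep import BLDeep

    X = torch.randn(10, 4)
    y = torch.randn(10)

    model = BLDeep(hidden_dims=[8, 4], task="continuous")
    s = model(X, y)   # first call builds the network lazily, then scores
    print(s.shape)    # torch.Size([10, 1]) -- one score per (x, y) pair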
blnetwork/model/utils.py
ADDED
@@ -0,0 +1,68 @@

import torch
from typing import Callable, Tuple

__all__ = [
    "second_activation",
    "third_activation",
    "infer_num_classes",
    "onehot_candidates",
    "enumerate_onehot_logits",
]

def second_activation(z: torch.Tensor, second_act_func: str, beta: float = 1.0) -> torch.Tensor:
    if second_act_func == "relu":
        return torch.relu(z)
    if second_act_func == "softplus":
        return torch.nn.functional.softplus(z, beta=beta)
    raise ValueError(f"Unknown second_act_func='{second_act_func}'. Use 'relu' or 'softplus'.")

def third_activation(z: torch.Tensor, third_act_func: str) -> torch.Tensor:
    if third_act_func == "abs":
        return torch.abs(z)
    if third_act_func == "square":
        return z ** 2
    raise ValueError(f"Unknown third_act_func='{third_act_func}'. Use 'abs' or 'square'.")

def infer_num_classes(y: torch.Tensor) -> Tuple[int, torch.Tensor]:

    if y.ndim == 2:
        K = int(y.shape[1])
        y_idx = torch.argmax(y, dim=1).long()
        return K, y_idx

    if y.ndim != 1:
        raise ValueError(f"y must be 1D class-index or 2D one-hot, got shape={tuple(y.shape)}")

    if y.dtype.is_floating_point:
        if not torch.allclose(y, y.round()):
            raise ValueError("Discrete y looks continuous (non-integer floats).")
        y = y.round()

    y = y.long()
    classes, y_idx = torch.unique(y, sorted=True, return_inverse=True)
    K = int(classes.numel())
    return K, y_idx


@torch.no_grad()
def onehot_candidates(m: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    return torch.eye(int(m), device=device, dtype=dtype)


def enumerate_onehot_logits(
    score_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    x: torch.Tensor,
    m: int,
) -> torch.Tensor:

    B = x.shape[0]
    m = int(m)
    device, dtype = x.device, x.dtype

    X_rep = x.repeat_interleave(m, dim=0)
    Y = onehot_candidates(m, device, dtype)
    Y_rep = Y.repeat(B, 1)

    e = score_fn(X_rep, Y_rep)
    e = e.view(B, m)
    return e
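
A small sketch of the enumeration trick: each input row is repeated m times,
paired with every one-hot candidate, scored, and reshaped to (B, m). The
score_fn below is a hypothetical stand-in for BLDeep.score:

    import torch
    from blnetwork.model.utils import infer_num_classes, enumerate_onehot_logits

    # Index or one-hot labels both reduce to (K, class indices).
    K, idx = infer_num_classes(torch.tensor([1, 0, 2, 1]))  # K=3, idx=[1,0,2,1]

    def score_fn(x, y):  # hypothetical scorer returning shape (B*m, 1)
        w = torch.tensor([0.1, 0.5, -0.2])
        return (x.sum(dim=1) + y @ w).unsqueeze(1)

    logits = enumerate_onehot_logits(score_fn, torch.randn(4, 2), m=3)
    print(logits.shape)  # torch.Size([4, 3])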
blnetwork/training/__init__.py
ADDED
@@ -0,0 +1,17 @@

from __future__ import annotations

from .base import TrainConfig
from .utils import OptimConfig, ExportConfig
from .continuous import ContinuousTrainer
from .discrete import DiscreteTrainer
from .amortized import AmortizedConfig, fit_amortized_predictor

__all__ = [
    "OptimConfig",
    "TrainConfig",
    "ExportConfig",
    "ContinuousTrainer",
    "DiscreteTrainer",
    "AmortizedConfig",
    "fit_amortized_predictor",
]