blnetwork 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
blnetwork/export.py ADDED
@@ -0,0 +1,269 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Optional, Tuple, Dict
4
+
5
+ import numpy as np
6
+ import torch
7
+ import sys
8
+ import io
9
+ import torch.nn.functional as F
10
+
11
+ def _fmt_num(x: float, ndigits: int = 4) -> str:
12
+ s = f"{x:.{ndigits}f}"
13
+ s = s.rstrip("0").rstrip(".")
14
+ return s if s != "" else "0"
15
+
16
+
17
+ def _safe_numpy(t):
18
+ if t is None:
19
+ return None
20
+ if torch.is_tensor(t):
21
+ return t.detach().cpu().numpy()
22
+ return np.asarray(t)
23
+
24
+
25
+ def _get_bl_unit(block):
26
+ unit = getattr(block, "unit", None)
27
+ if unit is None:
28
+ raise AttributeError("Block has no attribute 'unit' (expected BLBlock.unit).")
29
+ return unit
30
+
31
+
32
def _get_lambdas(unit):
    """Extract the (u, c, t) lambda vectors of *unit* as NumPy arrays.

    When the unit constrains its lambdas, the same softplus(+eps)
    reparameterization the forward pass applies is reproduced here so the
    exported values match what the model actually uses.
    """
    lam = unit.lam
    if getattr(unit, "constrain_lambda", False):
        lam = F.softplus(lam) + float(getattr(unit, "eps", 1e-8))
    # Row 0 -> lambda_u, row 1 -> lambda_c, row 2 -> lambda_t.
    return tuple(_safe_numpy(lam[row]) for row in range(3))
43
+
44
+
45
+ def _get_backbone(model):
46
+ backbone = getattr(model, "backbone", None)
47
+ if backbone is not None:
48
+ return backbone
49
+ if hasattr(model, "blocks"):
50
+ return model
51
+ raise AttributeError("Cannot find backbone. Expected model.backbone or model.blocks.")
52
+
53
+
54
def _get_blocks(model) -> list:
    """Return the backbone's blocks as a plain list."""
    blocks = getattr(_get_backbone(model), "blocks", None)
    if blocks is None:
        raise AttributeError("Backbone has no attribute 'blocks'.")
    return list(blocks)
60
+
61
+
62
def _get_hidden_dims(model) -> Optional[Tuple[int, ...]]:
    """Return the backbone's hidden layer sizes, or ``None`` if not recorded."""
    dims = getattr(_get_backbone(model), "hidden_dims", None)
    return None if dims is None else tuple(int(d) for d in dims)
68
+
69
+
70
+ def _get_output_linears(model) -> Dict[str, torch.nn.Linear]:
71
+ outs: Dict[str, torch.nn.Linear] = {}
72
+
73
+ lin = getattr(model, "linear_out", None)
74
+ if lin is not None and hasattr(lin, "weight"):
75
+ outs["Output Layer (Discrete)"] = lin
76
+ return outs
77
+
78
+ head = getattr(model, "head", None)
79
+ if head is None:
80
+ return outs
81
+
82
+ if isinstance(head, torch.nn.Linear):
83
+ outs["Output Layer"] = head
84
+ return outs
85
+
86
+ lin = getattr(head, "linear", None)
87
+ if lin is not None and hasattr(lin, "weight"):
88
+ outs["Output Layer"] = lin
89
+
90
+ return outs
91
+
92
+
93
def _emit_part_lines(
    part_name: str,
    lam_val: float,
    w_row: np.ndarray,
    b_val: float,
    feature_names: List[str],
    ndigits: int = 4,
    tol: float = 0.0,
) -> List[str]:
    """Render one U/C/T part as text: header, lambda line, then nonzero terms.

    Weights (and the bias, labelled ``C``) whose magnitude does not exceed
    *tol* are omitted from the listing.
    """
    out = [part_name, f"lambda {_fmt_num(float(lam_val), ndigits)}"]
    out.extend(
        f"----{fname} {_fmt_num(float(wv), ndigits)}"
        for fname, wv in zip(feature_names, w_row)
        if abs(float(wv)) > tol
    )
    bias = float(b_val)
    if abs(bias) > tol:
        out.append(f"----C {_fmt_num(bias, ndigits)}")
    return out
115
+
116
+
117
def _print_blocks(
    block,
    feature_names: List[str],
    layer_idx: int,
    ndigits: int = 4,
    tol: float = 0.0,
) -> int:
    """Print every basis function of *block* (its U, C and T parts) to stdout.

    Returns the number of basis functions, which the caller uses to name the
    synthetic features that feed the next layer.
    """
    unit = _get_bl_unit(block)
    num_basis = int(unit.lin_u.out_features)
    lam_u, lam_c, lam_t = _get_lambdas(unit)

    # (part label, lambda vector, weight matrix, bias vector) per sub-part,
    # in the fixed U -> C -> T print order.
    parts = [
        ("U", lam_u, _safe_numpy(unit.lin_u.weight), _safe_numpy(unit.lin_u.bias)),
        ("C", lam_c, _safe_numpy(unit.lin_c.weight), _safe_numpy(unit.lin_c.bias)),
        ("T", lam_t, _safe_numpy(unit.lin_t.weight), _safe_numpy(unit.lin_t.bias)),
    ]

    for j in range(num_basis):
        print(f"--B{layer_idx}{j + 1}")
        for label, lam, w, b in parts:
            for ln in _emit_part_lines(
                label,
                lam_val=float(lam[j]),
                w_row=w[j],
                b_val=float(b[j]),
                feature_names=feature_names,
                ndigits=ndigits,
                tol=tol,
            ):
                print(ln)
        print("")

    return num_basis
184
+
185
+
186
def _print_core(model, blocks, hidden_dims, feat_names, ndigits, tol, title="BL Model Structure"):
    """Print the full model structure (all blocks plus output layers) to stdout."""
    banner = "=" * 72

    print(banner)
    print(title)
    print(banner)
    print(f"hidden_dims = {hidden_dims}")
    print(f"feature_dim = {len(feat_names)}")
    print("")

    names = feat_names
    for layer_idx, block in enumerate(blocks, start=1):
        num_basis = _print_blocks(
            block=block,
            feature_names=names,
            layer_idx=layer_idx,
            ndigits=ndigits,
            tol=tol,
        )
        # Deeper layers consume the previous layer's basis outputs as features.
        if layer_idx < len(blocks):
            names = [f"B{layer_idx}{i + 1}" for i in range(num_basis)]

    outs = _get_output_linears(model)
    if not outs:
        return

    print(banner)
    print("OUTPUT LINEAR(S)")
    print(banner)
    for name, lin in outs.items():
        w = _safe_numpy(lin.weight)
        b = _safe_numpy(lin.bias) if lin.bias is not None else None
        print(name)
        # Only dump raw values for small layers to keep the report readable.
        if getattr(w, "size", 0) and w.size <= 16:
            w_fmt = ", ".join(_fmt_num(float(v), ndigits) for v in w.reshape(-1))
            print(f"weight values = [{w_fmt}]")
        if b is None:
            print("bias = None")
        elif getattr(b, "size", 0) and b.size <= 16:
            b_fmt = ", ".join(_fmt_num(float(v), ndigits) for v in b.reshape(-1))
            print(f"bias values = [{b_fmt}]")
        print("")
231
+
232
def export_structure(
    model,
    df=None,
    feature_names=None,
    txt_path: Optional[str] = None,
    ndigits: int = 4,
    tol: float = 0.0,
    title: str = "BL Model Structure",
):
    """Print a human-readable dump of a BL model, or write it to *txt_path*.

    Feature names are taken from *feature_names* if given, else from *df*'s
    columns, else generated as ``x1..xN`` from the first unit's input width.
    When *txt_path* is given the report is written there (UTF-8) instead of
    printed.
    """
    blocks = _get_blocks(model)
    hidden_dims = _get_hidden_dims(model)

    if feature_names is not None:
        feat_names = list(feature_names)
    elif df is not None:
        feat_names = list(df.columns)
    else:
        unit0 = _get_bl_unit(blocks[0])
        if not hasattr(unit0, "lin_u"):
            raise AttributeError("Cannot infer input dim from BLUnit. Expected 'lin_u'.")
        in_dim = int(unit0.lin_u.in_features)
        feat_names = [f"x{i+1}" for i in range(in_dim)]

    if txt_path is None:
        _print_core(model, blocks, hidden_dims, feat_names, ndigits, tol, title)
        return

    # Capture stdout with the stdlib context manager instead of swapping
    # sys.stdout by hand: it restores stdout even if _print_core raises and
    # composes correctly with nested redirects.
    from contextlib import redirect_stdout

    buf = io.StringIO()
    with redirect_stdout(buf):
        _print_core(model, blocks, hidden_dims, feat_names, ndigits, tol, title)

    with open(txt_path, "w", encoding="utf-8") as f:
        f.write(buf.getvalue())
@@ -0,0 +1,10 @@
1
+ from __future__ import annotations
2
+
3
+ from .continuous import predict_continuous
4
+ from .discrete import predict_class_discrete, predict_proba_discrete
5
+
6
# Public names re-exported by the prediction subpackage.
__all__ = [
    "predict_continuous",
    "predict_class_discrete",
    "predict_proba_discrete",
]
@@ -0,0 +1,21 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from ..training import utils as U
7
+
8
+
9
@torch.no_grad()
def predict_continuous(
    predictor: nn.Module,
    x: torch.Tensor,
    *,
    device: str | torch.device | None = None,
    return_cpu: bool = True,
) -> torch.Tensor:
    """Run *predictor* on *x* in eval mode and return its detached output.

    The input is moved to the device resolved from model/tensor/override, and
    the result is moved back to CPU unless *return_cpu* is False.
    """
    predictor.eval()
    target = U.resolve_device(model=predictor, tensor=x, device=device)
    prediction = predictor(x.to(target)).detach()
    return prediction.cpu() if return_cpu else prediction
@@ -0,0 +1,40 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from ..training import utils as U
5
+
6
+
7
@torch.no_grad()
def predict_proba_discrete(
    model,
    x: torch.Tensor,
    *,
    temperature: float = 1.0,
    device: str | torch.device | None = None,
    return_cpu: bool = False,
) -> torch.Tensor:
    """Return per-class probabilities for *x* via temperature-scaled softmax.

    Uses ``model.logits`` when the model provides it, otherwise calls the
    model directly and treats the output as logits.
    """
    target = U.resolve_device(model, x, device)
    inputs = x.to(target)

    if hasattr(model, "logits"):
        raw = model.logits(inputs)
    else:
        raw = model(inputs)

    probs = F.softmax(raw / float(temperature), dim=1)
    return probs.cpu() if return_cpu else probs
25
+
26
+
27
@torch.no_grad()
def predict_class_discrete(
    model,
    x: torch.Tensor,
    *,
    temperature: float = 1.0,
    device: str | torch.device | None = None,
    return_cpu: bool = True,
) -> torch.Tensor:
    """Return the most probable class index for each row of *x*."""
    proba = predict_proba_discrete(
        model, x, temperature=temperature, device=device, return_cpu=False
    )
    labels = proba.argmax(dim=1)
    return labels.cpu() if return_cpu else labels
@@ -0,0 +1,7 @@
1
+ from __future__ import annotations
2
+
3
+ from .bldeep import BLDeep
4
+
5
# Public model classes exposed at package level.
__all__ = [
    "BLDeep"
]
@@ -0,0 +1,229 @@
1
+ from typing import List, Sequence, Optional
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from . import utils as U
6
+
7
+
8
class BLUnit(nn.Module):
    """One BL layer: three parallel affine maps combined per basis function.

    The output is ``lam_u * tanh(W_u z) - lam_c * act2(W_c z) - lam_t * act3(W_t z)``
    where ``act2``/``act3`` are configurable and the per-basis lambda
    coefficients can be constrained positive via softplus(+eps).
    """

    def __init__(
        self,
        in_dim: int,
        num_basis: int,
        second_act_func: str = "relu",
        third_act_func: str = "abs",
        eps: float = 1e-8,
        constrain_lambda: bool = True,
        init_lambda: float = 1.0,
        beta: float = 1.0,
    ) -> None:
        super().__init__()
        self.in_dim = int(in_dim)
        self.num_basis = int(num_basis)
        self.second_act_func = str(second_act_func)
        self.third_act_func = str(third_act_func)
        self.eps = float(eps)
        self.constrain_lambda = bool(constrain_lambda)
        self.beta = float(beta)

        # Three independent affine maps feeding the U / C / T parts.
        self.lin_u = nn.Linear(self.in_dim, self.num_basis, bias=True)
        self.lin_c = nn.Linear(self.in_dim, self.num_basis, bias=True)
        self.lin_t = nn.Linear(self.in_dim, self.num_basis, bias=True)

        # Row 0 -> lambda_u, row 1 -> lambda_c, row 2 -> lambda_t.
        self.lam = nn.Parameter(torch.full((3, self.num_basis), float(init_lambda)))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        u = torch.tanh(self.lin_u(z))
        c = U.second_activation(self.lin_c(z), self.second_act_func, beta=self.beta)
        t = U.third_activation(self.lin_t(z), self.third_act_func)

        # Reparameterize lambdas to be strictly positive when constrained.
        lam = F.softplus(self.lam) + self.eps if self.constrain_lambda else self.lam
        lam_u, lam_c, lam_t = lam[0], lam[1], lam[2]
        return lam_u * u - lam_c * c - lam_t * t
52
+
53
+
54
class BLBlock(nn.Module):
    """Thin wrapper holding a single BLUnit as ``self.unit``.

    The export code accesses ``block.unit`` directly, so this indirection is
    part of the public structure; all computation is delegated to the unit.
    """

    def __init__(
        self,
        in_dim: int,
        num_basis: int,
        second_act_func: str = "relu",
        third_act_func: str = "abs",
        constrain_lambda: bool = True,
        init_lambda: float = 1.0,
        beta: float = 1.0,
    ) -> None:
        super().__init__()
        # All constructor arguments are forwarded unchanged to the unit.
        self.unit = BLUnit(
            in_dim=in_dim,
            num_basis=num_basis,
            second_act_func=second_act_func,
            third_act_func=third_act_func,
            constrain_lambda=constrain_lambda,
            init_lambda=init_lambda,
            beta=beta,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.unit(x)
78
+
79
+
80
class BLDeepBackbone(nn.Module):
    """Sequential stack of BLBlocks.

    Block *i* maps ``fan_in[i] -> hidden_dims[i]``, where ``fan_in`` is the
    input dimension followed by every hidden size except the last, so each
    block consumes the previous block's basis outputs.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dims: Sequence[int],
        second_act_func: str = "relu",
        third_act_func: str = "abs",
        constrain_lambda: bool = True,
        init_lambda: float = 1.0,
        beta: float = 1.0,
    ) -> None:
        super().__init__()
        self.in_dim = int(in_dim)
        self.hidden_dims = list(hidden_dims)

        # Input width of each block: original features, then each layer's
        # basis count except the last.
        fan_in: List[int] = [int(in_dim)] + list(hidden_dims[:-1])
        self.blocks = nn.ModuleList(
            BLBlock(
                fan_in[i],
                num_basis=hidden_dims[i],
                second_act_func=second_act_func,
                third_act_func=third_act_func,
                constrain_lambda=constrain_lambda,
                init_lambda=init_lambda,
                beta=beta,
            )
            for i in range(len(hidden_dims))
        )
        self.out_dim = int(hidden_dims[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        for block in self.blocks:
            out = block(out)
        return out
115
+
116
+
117
class BLDeep(nn.Module):
    """BL scoring model whose architecture is built lazily from data.

    ``score(x, y)`` concatenates inputs and (one-hot encoded) targets, runs
    the BL backbone and a linear head, and returns one scalar score per row.
    For discrete tasks, ``logits`` enumerates every one-hot candidate per
    input to produce a (B, K) score matrix.
    """

    def __init__(
        self,
        hidden_dims: Sequence[int],
        second_act_func: str = "relu",
        third_act_func: str = "abs",
        head_bias: bool = True,
        num_classes: Optional[int] = None,
        task: str = "continuous",
        constrain_lambda: bool = True,
        init_lambda: float = 1.0,
        beta: float = 1.0,
    ) -> None:
        super().__init__()
        self.hidden_dims = list(hidden_dims)
        self.second_act_func = str(second_act_func)
        self.third_act_func = third_act_func
        self.head_bias = head_bias
        self.num_classes = num_classes

        if task not in {"continuous", "discrete"}:
            raise ValueError(f"task must be 'continuous' or 'discrete', got '{task}'")
        self.task = task
        self.constrain_lambda = bool(constrain_lambda)
        self.init_lambda = float(init_lambda)
        self.beta = float(beta)

        # Filled in lazily by build()/build_for_discrete_inference() once the
        # input/target dimensions are known.
        self.x_dim: Optional[int] = None
        self.y_dim: Optional[int] = None
        self.backbone: Optional[nn.Module] = None
        self.head: Optional[nn.Module] = None

    def _build_architecture(self, x: torch.Tensor) -> None:
        # The backbone consumes [x, y] concatenated along dim 1.
        self.backbone = BLDeepBackbone(
            in_dim=self.x_dim + self.y_dim,
            hidden_dims=self.hidden_dims,
            second_act_func=self.second_act_func,
            third_act_func=self.third_act_func,
            constrain_lambda=self.constrain_lambda,
            init_lambda=self.init_lambda,
            beta=self.beta,
        )
        self.head = nn.Linear(self.backbone.out_dim, 1, bias=self.head_bias)

        # Match the device/dtype of the data that triggered the build.
        device, dtype = x.device, x.dtype
        self.to(device=device, dtype=dtype)

    def build(self, X: torch.Tensor, y: torch.Tensor) -> None:
        """Infer dimensions from training data and construct backbone + head."""
        self.x_dim = X.shape[1]

        if self.task == "discrete":
            m, y_idx = U.infer_num_classes(y)
            self.num_classes = int(m)
            self.y_dim = int(m)
            # infer_num_classes remaps labels to 0..K-1; if that remapping
            # changed anything, the caller's labels were not contiguous.
            if not torch.equal(y_idx, y.long()):
                raise ValueError(
                    f"Discrete labels must be in range [0..K-1]. "
                    f"Got non-continuous labels. Please remap your labels to [0, 1, 2, ...] first."
                )
        else:
            self.y_dim = 1 if y.ndim == 1 else y.shape[1]

        self._build_architecture(X)

    def build_for_discrete_inference(self, x: torch.Tensor, num_classes: int) -> None:
        """Construct the architecture for inference when no labels are available."""
        if x.ndim != 2:
            raise ValueError(f"x must be 2D (B, x_dim), got shape {tuple(x.shape)}")
        if self.task != "discrete":
            raise ValueError(f"This function only works with task='discrete', got task='{self.task}'")

        self.x_dim = int(x.shape[1])
        self.num_classes = int(num_classes)
        self.y_dim = int(self.num_classes)

        self._build_architecture(x)

    def score(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the scalar score per (x, y) row; builds lazily on first call."""
        if self.backbone is None:
            self.build(x, y)

        if y.ndim == 1:
            if self.task == "discrete":
                # Class indices -> one-hot rows so they concatenate with x.
                y = F.one_hot(y.long(), num_classes=self.num_classes).to(device=x.device, dtype=x.dtype)
            else:
                y = y.unsqueeze(1)
        elif y.ndim == 2:
            if self.task == "discrete" and y.shape[1] == 1:
                raise ValueError(
                    f"For discrete task, y should be 1D class indices (shape (B,)), "
                    f"but got shape {tuple(y.shape)}. Use y.squeeze(1) to convert (B,1) -> (B)."
                )

        z = torch.cat([x, y.to(device=x.device, dtype=x.dtype)], dim=1)
        feats = self.backbone(z)
        return self.head(feats)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return self.score(x, y)

    def logits(self, x: torch.Tensor, num_classes: int | None = None) -> torch.Tensor:
        """Return (B, K) scores by scoring every one-hot candidate per row."""
        if self.task != "discrete":
            raise RuntimeError("logits() is only available when task='discrete'.")

        # NOTE(review): an explicit num_classes=0 falls through to
        # self.num_classes because of the `or` chain — presumably intended,
        # since 0 classes is never valid; confirm if 0 should raise instead.
        m = int(num_classes or (self.num_classes or 0))
        if m <= 0:
            raise RuntimeError(
                "Unknown num_classes for discrete inference. "
                "Pass num_classes=K, or initialize BLDeep(..., num_classes=K), "
            )
        if self.backbone is None:
            self.build_for_discrete_inference(x, num_classes=m)

        return U.enumerate_onehot_logits(self.score, x, m=m)
@@ -0,0 +1,68 @@
1
+ import torch
2
+ from typing import Callable, Tuple
3
+
4
# Public helpers exported from this module.
__all__ = [
    "second_activation",
    "third_activation",
    "infer_num_classes",
    "onehot_candidates",
    "enumerate_onehot_logits",
]
11
+
12
def second_activation(z: torch.Tensor, second_act_func: str, beta: float = 1.0) -> torch.Tensor:
    """Apply the C-part activation: ``'relu'`` or ``'softplus'`` (with *beta*)."""
    if second_act_func == "softplus":
        return torch.nn.functional.softplus(z, beta=beta)
    if second_act_func == "relu":
        return torch.relu(z)
    raise ValueError(f"Unknown second_act_func='{second_act_func}'. Use 'relu' or 'softplus'.")
18
+
19
def third_activation(z: torch.Tensor, third_act_func: str) -> torch.Tensor:
    """Apply the T-part activation: ``'abs'`` or ``'square'``."""
    if third_act_func == "square":
        return z ** 2
    if third_act_func == "abs":
        return torch.abs(z)
    raise ValueError(f"Unknown third_act_func='{third_act_func}'. Use 'abs' or 'square'.")
25
+
26
def infer_num_classes(y: torch.Tensor) -> Tuple[int, torch.Tensor]:
    """Infer the class count K and 0..K-1 index labels from *y*.

    Accepts 2D one-hot (or soft) targets — K is the width, indices come from
    argmax — or a 1D vector of integer-valued labels, where K is the number
    of distinct values and each index is the rank of the label among the
    sorted distinct values.
    """
    if y.ndim == 2:
        # One-hot targets: the row width is the class count.
        return int(y.shape[1]), torch.argmax(y, dim=1).long()

    if y.ndim != 1:
        raise ValueError(f"y must be 1D class-index or 2D one-hot, got shape={tuple(y.shape)}")

    labels = y
    if labels.dtype.is_floating_point:
        # Floats are only accepted when they are whole numbers.
        if not torch.allclose(labels, labels.round()):
            raise ValueError("Discrete y looks continuous (non-integer floats).")
        labels = labels.round()

    distinct, indices = torch.unique(labels.long(), sorted=True, return_inverse=True)
    return int(distinct.numel()), indices
45
+
46
+
47
@torch.no_grad()
def onehot_candidates(m: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Return the m x m identity matrix: one candidate one-hot row per class."""
    eye = torch.eye(int(m), device=device, dtype=dtype)
    return eye
50
+
51
+
52
def enumerate_onehot_logits(
    score_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    x: torch.Tensor,
    m: int,
) -> torch.Tensor:
    """Score every (input row, class candidate) pair and return (B, m) logits.

    Each of the B rows of *x* is repeated m times and paired with every
    one-hot candidate; a single batched call to *score_fn* produces all
    B*m scores, which are then reshaped to (B, m).
    """
    batch = x.shape[0]
    num_classes = int(m)

    # Row i of x is paired with each of the num_classes candidate one-hots.
    x_tiled = x.repeat_interleave(num_classes, dim=0)
    candidates = onehot_candidates(num_classes, x.device, x.dtype)
    y_tiled = candidates.repeat(batch, 1)

    scores = score_fn(x_tiled, y_tiled)
    return scores.view(batch, num_classes)
@@ -0,0 +1,17 @@
1
+ from __future__ import annotations
2
+
3
+ from .base import TrainConfig
4
+ from .utils import OptimConfig, ExportConfig
5
+ from .continuous import ContinuousTrainer
6
+ from .discrete import DiscreteTrainer
7
+ from .amortized import AmortizedConfig, fit_amortized_predictor
8
+
9
# Public training API: configs and trainers re-exported at package level.
__all__ = [
    "OptimConfig",
    "TrainConfig",
    "ExportConfig",
    "ContinuousTrainer",
    "DiscreteTrainer",
    "AmortizedConfig",
    "fit_amortized_predictor"
]