boltzmann9 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltzmann9/__init__.py +38 -0
- boltzmann9/__main__.py +4 -0
- boltzmann9/cli.py +389 -0
- boltzmann9/config.py +58 -0
- boltzmann9/data.py +145 -0
- boltzmann9/data_generator.py +234 -0
- boltzmann9/model.py +867 -0
- boltzmann9/pipeline.py +216 -0
- boltzmann9/preprocessor.py +627 -0
- boltzmann9/project.py +195 -0
- boltzmann9/run_utils.py +262 -0
- boltzmann9/tester.py +167 -0
- boltzmann9/utils.py +42 -0
- boltzmann9/visualization.py +115 -0
- {boltzmann9-0.1.4.dist-info → boltzmann9-0.1.6.dist-info}/METADATA +1 -1
- boltzmann9-0.1.6.dist-info/RECORD +19 -0
- boltzmann9-0.1.6.dist-info/top_level.txt +1 -0
- boltzmann9-0.1.4.dist-info/RECORD +0 -5
- boltzmann9-0.1.4.dist-info/top_level.txt +0 -1
- {boltzmann9-0.1.4.dist-info → boltzmann9-0.1.6.dist-info}/WHEEL +0 -0
- {boltzmann9-0.1.4.dist-info → boltzmann9-0.1.6.dist-info}/entry_points.txt +0 -0
boltzmann9/model.py
ADDED
@@ -0,0 +1,867 @@
"""
Restricted Boltzmann Machine (RBM) implementation in PyTorch
with cross-block restrictions (weight masking).

- Binary visible/hidden units
- PCD (persistent contrastive divergence)
- Momentum updates, weight decay, gradient clipping
- LR schedules (constant/exponential/step/cosine/plateau)
- Optional hidden sparsity regularization
- Optional early stopping with validation monitoring

Cross-block restrictions:
    config["model"]["cross_block_restrictions"] = [("v_block", "h_block"), ...]

These pairs indicate which V-block × H-block submatrices in W must be forced to 0.
"""

from __future__ import annotations

import math
from typing import Any, Dict, Optional, Sequence, Tuple

import torch
import torch.nn as nn


class RBM(nn.Module):
    def __init__(self, config: Dict[str, Any]) -> None:
        super().__init__()

        # Accept either a full app config with "model" key, or a model-only dict.
        model_cfg = config.get("model", config)

        visible_blocks: Dict[str, Any] = model_cfg["visible_blocks"]
        hidden_blocks: Dict[str, Any] = model_cfg["hidden_blocks"]
        restrictions = model_cfg.get("cross_block_restrictions", []) or []

        self.visible_blocks = {k: int(v) for k, v in visible_blocks.items()}
        self.hidden_blocks = {k: int(v) for k, v in hidden_blocks.items()}

        self.nv = sum(self.visible_blocks.values())
        self.nh = sum(self.hidden_blocks.values())

        # --------- build block ranges (name -> (start, end)) ----------
        self._v_block_ranges: Dict[str, Tuple[int, int]] = {}
        off = 0
        for name, size in self.visible_blocks.items():
            if size <= 0:
                raise ValueError(f"Visible block {name!r} must have positive size, got {size}.")
            self._v_block_ranges[name] = (off, off + size)
            off += size

        self._h_block_ranges: Dict[str, Tuple[int, int]] = {}
        off = 0
        for name, size in self.hidden_blocks.items():
            if size <= 0:
                raise ValueError(f"Hidden block {name!r} must have positive size, got {size}.")
            self._h_block_ranges[name] = (off, off + size)
            off += size

        # --------- parameters ----------
        self.W = nn.Parameter(torch.empty(self.nv, self.nh))
        self.bv = nn.Parameter(torch.zeros(self.nv))
        self.bh = nn.Parameter(torch.zeros(self.nh))

        nn.init.xavier_uniform_(self.W)

        # --------- mask construction ----------
        # Mask is float tensor with 1.0 for allowed edges and 0.0 for forbidden edges.
        mask = torch.ones(self.nv, self.nh, dtype=self.W.dtype)

        for pair in restrictions:
            if not (isinstance(pair, (tuple, list)) and len(pair) == 2):
                raise ValueError(
                    "Each cross_block_restrictions entry must be a pair (v_block, h_block). "
                    f"Got: {pair!r}"
                )
            v_block, h_block = pair
            if v_block not in self._v_block_ranges:
                raise KeyError(
                    f"Unknown visible block {v_block!r} in cross_block_restrictions. "
                    f"Known: {list(self._v_block_ranges.keys())}"
                )
            if h_block not in self._h_block_ranges:
                raise KeyError(
                    f"Unknown hidden block {h_block!r} in cross_block_restrictions. "
                    f"Known: {list(self._h_block_ranges.keys())}"
                )

            vs, ve = self._v_block_ranges[v_block]
            hs, he = self._h_block_ranges[h_block]
            mask[vs:ve, hs:he] = 0.0

        # register_buffer so it moves with .to(device) and is saved in state_dict
        self.register_buffer("mask", mask)

        # enforce mask at init
        with torch.no_grad():
            self.W.mul_(self.mask)

        # Persistent chain for PCD
        self.v_chain: Optional[torch.Tensor] = None

        # Momentum buffers (registered so they move with .to(device))
        self.register_buffer("_vW", torch.zeros_like(self.W))
        self.register_buffer("_vbv", torch.zeros_like(self.bv))
        self.register_buffer("_vbh", torch.zeros_like(self.bh))

        # Plateau scheduler state
        self._plateau_best: Optional[float] = None
        self._plateau_bad_count: int = 0

    @classmethod
    def from_run_folder(
        cls,
        run_folder: str,
        device: Optional[str] = None,
    ) -> Tuple["RBM", Dict[str, Any]]:
        """Load an RBM model from a run folder.

        Args:
            run_folder: Path to run folder containing model.pt and config.py.
            device: Device to load model to. If None, uses CPU.

        Returns:
            Tuple of (model, config) where config is the full configuration dict.
        """
        from pathlib import Path

        run_path = Path(run_folder)
        model_path = run_path / "model.pt"
        config_path = run_path / "config.py"

        if not model_path.exists():
            raise FileNotFoundError(f"Model file not found: {model_path}")
        if not config_path.exists():
            raise FileNotFoundError(f"Config file not found: {config_path}")

        # Load config
        from boltzmann9.config import load_config
        config = load_config(config_path)

        # Load checkpoint
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)

        # Create model from config
        model = cls(config)
        model.load_state_dict(checkpoint["model_state_dict"])

        if device:
            model = model.to(device)

        print(f"Model loaded from: {run_path}")
        return model, config

    # --------------------------------------------------
    # Core distributions
    # --------------------------------------------------

    def hidden_prob(self, v: torch.Tensor) -> torch.Tensor:
        """Compute P(h=1 | v)."""
        return torch.sigmoid(v @ self.W + self.bh)

    def visible_prob(self, h: torch.Tensor) -> torch.Tensor:
        """Compute P(v=1 | h)."""
        return torch.sigmoid(h @ self.W.T + self.bv)

    def _bernoulli(self, p: torch.Tensor) -> torch.Tensor:
        """Sample from Bernoulli distribution."""
        return torch.bernoulli(p)

    # --------------------------------------------------
    # Forward (semantic: inference, NOT training)
    # --------------------------------------------------

    def forward(self, v: torch.Tensor) -> torch.Tensor:
        """Return P(h=1 | v)."""
        return self.hidden_prob(v.to(self.W.dtype))

    # --------------------------------------------------
    # Phases
    # --------------------------------------------------

    def positive_phase(self, v: torch.Tensor, kind: str = "mean-field"):
        """Compute positive phase statistics."""
        v = v.to(self.W.dtype)
        ph = self.hidden_prob(v)
        h_used = ph if kind == "mean-field" else self._bernoulli(ph)

        pos_W = v.T @ h_used
        pos_bv = v.sum(dim=0)
        pos_bh = h_used.sum(dim=0)
        return pos_W, pos_bv, pos_bh, ph

    @torch.no_grad()
    def negative_phase(
        self,
        batch_size: int,
        k: int = 1,
        kind: str = "mean-field",
        device: Optional[torch.device] = None,
    ):
        """Compute negative phase statistics using PCD."""
        device = device or self.W.device

        if self.v_chain is None or self.v_chain.shape[0] != batch_size:
            self.v_chain = self._bernoulli(
                torch.full((batch_size, self.nv), 0.5, device=device, dtype=self.W.dtype)
            )

        v = self.v_chain

        for _ in range(k):
            h = self._bernoulli(self.hidden_prob(v))
            v = self._bernoulli(self.visible_prob(h))

        self.v_chain = v.detach()

        phk = self.hidden_prob(v)
        h_used = phk if kind == "mean-field" else self._bernoulli(phk)

        neg_W = v.T @ h_used
        neg_bv = v.sum(dim=0)
        neg_bh = h_used.sum(dim=0)
        return neg_W, neg_bv, neg_bh

    # --------------------------------------------------
    # Update helpers (momentum, clipping, regularization)
    # --------------------------------------------------

    @staticmethod
    def _clip_by_value(x: torch.Tensor, clip_value: Optional[float]) -> torch.Tensor:
        if clip_value is None:
            return x
        return x.clamp(min=-clip_value, max=clip_value)

    @staticmethod
    def _clip_by_global_norm(
        dW: torch.Tensor,
        dbv: torch.Tensor,
        dbh: torch.Tensor,
        max_norm: Optional[float],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if max_norm is None:
            return dW, dbv, dbh
        norm = torch.sqrt((dW * dW).sum() + (dbv * dbv).sum() + (dbh * dbh).sum())
        if norm > max_norm:
            scale = max_norm / (norm + 1e-12)
            dW = dW * scale
            dbv = dbv * scale
            dbh = dbh * scale
        return dW, dbv, dbh

    def _apply_update(
        self,
        *,
        lr: float,
        dW: torch.Tensor,
        dbv: torch.Tensor,
        dbh: torch.Tensor,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        clip_value: Optional[float] = None,
        clip_norm: Optional[float] = None,
    ) -> None:
        """Apply parameter updates with momentum, weight decay, clipping, and weight masking."""
        # Mask gradients early (avoids momentum accumulating on forbidden edges)
        if hasattr(self, "mask") and self.mask is not None:
            dW = dW * self.mask

        # L2 weight decay
        if weight_decay and weight_decay > 0.0:
            dW = dW - weight_decay * self.W

        # Clip by value
        dW = self._clip_by_value(dW, clip_value)
        dbv = self._clip_by_value(dbv, clip_value)
        dbh = self._clip_by_value(dbh, clip_value)

        # Clip by global norm
        dW, dbv, dbh = self._clip_by_global_norm(dW, dbv, dbh, clip_norm)

        # Momentum update
        with torch.no_grad():
            if momentum and momentum > 0.0:
                self._vW.mul_(momentum).add_(dW, alpha=lr)
                self._vbv.mul_(momentum).add_(dbv, alpha=lr)
                self._vbh.mul_(momentum).add_(dbh, alpha=lr)

                # keep momentum buffer masked too (optional but good)
                if hasattr(self, "mask") and self.mask is not None:
                    self._vW.mul_(self.mask)

                self.W.add_(self._vW)
                self.bv.add_(self._vbv)
                self.bh.add_(self._vbh)
            else:
                self.W.add_(dW, alpha=lr)
                self.bv.add_(dbv, alpha=lr)
                self.bh.add_(dbh, alpha=lr)

            # Re-apply mask to ensure restricted weights stay zero
            if hasattr(self, "mask") and self.mask is not None:
                self.W.mul_(self.mask)

    # --------------------------------------------------
    # LR scheduling
    # --------------------------------------------------

    def _lr_at_epoch(
        self,
        *,
        base_lr: float,
        epoch: int,
        epochs: int,
        schedule: Optional[Dict[str, Any]] = None,
        current_val_metric: Optional[float] = None,
    ) -> float:
        if not schedule:
            return float(base_lr)

        mode = schedule.get("mode", "constant")
        lr0 = float(base_lr)

        if mode == "constant":
            return lr0

        if mode == "exponential":
            gamma = float(schedule.get("gamma", 0.99))
            return lr0 * (gamma ** (epoch - 1))

        if mode == "step":
            step_size = int(schedule.get("step_size", 10))
            gamma = float(schedule.get("gamma", 0.5))
            n_steps = (epoch - 1) // step_size
            return lr0 * (gamma ** n_steps)

        if mode == "cosine":
            min_lr = float(schedule.get("min_lr", 0.0))
            t = (epoch - 1) / max(1, (epochs - 1))
            return min_lr + 0.5 * (lr0 - min_lr) * (1.0 + math.cos(math.pi * t))

        if mode == "plateau":
            factor = float(schedule.get("factor", 0.5))
            patience = int(schedule.get("patience", 3))
            min_lr = float(schedule.get("min_lr", 1e-6))
            threshold = float(schedule.get("threshold", 1e-4))

            if current_val_metric is None:
                return float(schedule.get("__current_lr", lr0))

            current_lr = float(schedule.get("__current_lr", lr0))

            if self._plateau_best is None or (self._plateau_best - current_val_metric) > threshold:
                self._plateau_best = float(current_val_metric)
                self._plateau_bad_count = 0
                return current_lr

            self._plateau_bad_count += 1
            if self._plateau_bad_count >= patience:
                new_lr = max(min_lr, current_lr * factor)
                schedule["__current_lr"] = new_lr
                self._plateau_bad_count = 0
                return new_lr

            return current_lr

        raise ValueError(f"Unknown lr schedule mode={mode!r}")

    # --------------------------------------------------
    # Training
    # --------------------------------------------------

    @torch.no_grad()
    def cd_step(
        self,
        v: torch.Tensor,
        *,
        lr: float,
        k: int = 1,
        kind: str = "mean-field",
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        clip_value: Optional[float] = None,
        clip_norm: Optional[float] = None,
        sparse_hidden: bool = False,
        rho: float = 0.1,
        lambda_sparse: float = 0.0,
    ) -> None:
        B = v.size(0)

        pos_W, pos_bv, pos_bh, ph = self.positive_phase(v, kind)
        neg_W, neg_bv, neg_bh = self.negative_phase(batch_size=B, k=k, kind=kind, device=v.device)

        dW = (pos_W - neg_W) / B
        dbv = (pos_bv - neg_bv) / B
        dbh = (pos_bh - neg_bh) / B

        if sparse_hidden and lambda_sparse > 0.0:
            err = ph.mean(dim=0) - rho
            dbh = dbh - lambda_sparse * err
            v_ = v.to(self.W.dtype)
            dW = dW - lambda_sparse * (v_.mean(dim=0).unsqueeze(1) * err.unsqueeze(0))

        self._apply_update(
            lr=lr,
            dW=dW,
            dbv=dbv,
            dbh=dbh,
            momentum=momentum,
            weight_decay=weight_decay,
            clip_value=clip_value,
            clip_norm=clip_norm,
        )

    # --------------------------------------------------
    # Utilities
    # --------------------------------------------------

    def reconstruct(self, v: torch.Tensor, k: int = 1) -> torch.Tensor:
        v = v.to(self.W.dtype)
        for _ in range(k):
            h = self._bernoulli(self.hidden_prob(v))
            v = self._bernoulli(self.visible_prob(h))
        return v

    def free_energy(self, v: torch.Tensor) -> torch.Tensor:
        v = v.to(self.W.dtype)
        wx_b = v @ self.W + self.bh
        return -v @ self.bv - torch.nn.functional.softplus(wx_b).sum(dim=1)

    @torch.no_grad()
    def evaluate(self, dataloader, *, recon_k: int = 1) -> Dict[str, float]:
        device = self.W.device

        fe_sum = 0.0
        mse_sum = 0.0
        ber_sum = 0.0
        n_samples = 0

        for v in dataloader:
            v = v.to(device, non_blocking=True)
            B = v.size(0)
            n_samples += B

            fe = self.free_energy(v).mean().item() / (self.nv + self.nh)  # per-node free energy
            v_rec = self.reconstruct(v, k=recon_k)

            mse = torch.mean((v - v_rec) ** 2).item()
            ber = torch.mean((v != v_rec).to(torch.float32)).item()

            fe_sum += fe * B
            mse_sum += mse * B
            ber_sum += ber * B

        return {
            "free_energy_mean": fe_sum / n_samples,
            "recon_mse_mean": mse_sum / n_samples,
            "recon_bit_error": ber_sum / n_samples,
        }

    # --------------------------------------------------
    # Training loop
    # --------------------------------------------------

    def fit(
        self,
        train_loader,
        *,
        val_loader: Optional[object] = None,
        epochs: int = 10,
        lr: float = 1e-3,
        k: int = 1,
        kind: str = "mean-field",
        eval_every: int = 1,
        recon_k: int = 1,
        lr_schedule: Optional[Dict[str, Any]] = None,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        clip_value: Optional[float] = None,
        clip_norm: Optional[float] = None,
        sparse_hidden: bool = False,
        rho: float = 0.1,
        lambda_sparse: float = 0.0,
        early_stopping: bool = False,
        es_patience: int = 10,
        es_min_delta: float = 1e-4,
    ) -> Dict[str, list]:
        device = self.W.device
        history: Dict[str, list] = {
            "epoch": [],
            "train_free_energy": [],
            "train_recon_mse": [],
            "train_recon_bit_error": [],
            "val_free_energy": [],
            "val_recon_mse": [],
            "val_recon_bit_error": [],
            "lr": [],
        }

        best_val: Optional[float] = None
        best_state: Optional[Dict[str, torch.Tensor]] = None
        bad_epochs = 0

        if lr_schedule and lr_schedule.get("mode") == "plateau":
            lr_schedule = dict(lr_schedule)
            lr_schedule["__current_lr"] = float(lr)

        for epoch in range(1, epochs + 1):
            self.train()

            current_lr = self._lr_at_epoch(
                base_lr=float(lr),
                epoch=epoch,
                epochs=epochs,
                schedule=lr_schedule,
                current_val_metric=None,
            )

            for v in train_loader:
                v = v.to(device, non_blocking=True)
                self.cd_step(
                    v,
                    lr=current_lr,
                    k=k,
                    kind=kind,
                    momentum=momentum,
                    weight_decay=weight_decay,
                    clip_value=clip_value,
                    clip_norm=clip_norm,
                    sparse_hidden=sparse_hidden,
                    rho=rho,
                    lambda_sparse=lambda_sparse,
                )

            if epoch % eval_every == 0:
                self.eval()

                train_metrics = self.evaluate(train_loader, recon_k=recon_k)
                history["epoch"].append(epoch)
                history["train_free_energy"].append(train_metrics["free_energy_mean"])
                history["train_recon_mse"].append(train_metrics["recon_mse_mean"])
                history["train_recon_bit_error"].append(train_metrics["recon_bit_error"])

                if val_loader is not None:
                    val_metrics = self.evaluate(val_loader, recon_k=recon_k)
                    val_fe = float(val_metrics["free_energy_mean"])

                    history["val_free_energy"].append(val_fe)
                    history["val_recon_mse"].append(val_metrics["recon_mse_mean"])
                    history["val_recon_bit_error"].append(val_metrics["recon_bit_error"])

                    if lr_schedule and lr_schedule.get("mode") == "plateau":
                        # plateau uses its internal "__current_lr"
                        _ = self._lr_at_epoch(
                            base_lr=float(lr),
                            epoch=epoch,
                            epochs=epochs,
                            schedule=lr_schedule,
                            current_val_metric=val_fe,
                        )
                        current_lr = float(lr_schedule.get("__current_lr", current_lr))

                    if early_stopping:
                        if best_val is None or (best_val - val_fe) > es_min_delta:
                            best_val = val_fe
                            best_state = {k: t.detach().clone() for k, t in self.state_dict().items()}
                            bad_epochs = 0
                        else:
                            bad_epochs += 1
                            if bad_epochs >= es_patience:
                                if best_state is not None:
                                    self.load_state_dict(best_state)
                                history["lr"].append(current_lr)
                                print(
                                    f"Early stopping at epoch {epoch} "
                                    f"(best val FE={best_val:.6f})"
                                )
                                self.visualize_history(history)
                                return history

                    print(
                        f"Epoch {epoch:04d} | lr={current_lr:.3e} | "
                        f"train FE={train_metrics['free_energy_mean']:.4f} "
                        f"val FE={val_fe:.4f} | "
                        f"train recon_mse={train_metrics['recon_mse_mean']:.4f} "
                        f"val recon_mse={val_metrics['recon_mse_mean']:.4f}"
                    )
                else:
                    history["val_free_energy"].append(float("nan"))
                    history["val_recon_mse"].append(float("nan"))
                    history["val_recon_bit_error"].append(float("nan"))
                    print(
                        f"Epoch {epoch:04d} | lr={current_lr:.3e} | "
                        f"train FE={train_metrics['free_energy_mean']:.4f} | "
                        f"train recon_mse={train_metrics['recon_mse_mean']:.4f}"
                    )

                history["lr"].append(current_lr)

        self.visualize_history(history)
        return history

    # --------------------------------------------------
    # Visualization
    # --------------------------------------------------

    def visualize_history(self, history: dict) -> None:
        try:
            import matplotlib.pyplot as plt
        except Exception as e:
            print(f"[visualize_history] matplotlib import failed: {e}")
            return

        epochs = history.get("epoch", list(range(1, len(history.get("train_free_energy", [])) + 1)))

        def _is_all_nan(xs):
            if not xs:
                return True
            return all((x is None) or (isinstance(x, float) and math.isnan(x)) for x in xs)

        def _plot(ax, y_train_key, y_val_key, title, ylabel):
            y_tr = history.get(y_train_key, [])
            y_va = history.get(y_val_key, [])

            ax.plot(epochs[: len(y_tr)], y_tr, label="train")
            if not _is_all_nan(y_va):
                ax.plot(epochs[: len(y_va)], y_va, label="val")

            ax.set_title(title)
            ax.set_xlabel("epoch")
            ax.set_ylabel(ylabel)
            ax.legend()

        try:
            fig, axes = plt.subplots(1, 3, figsize=(16, 4))

            _plot(axes[0], "train_free_energy", "val_free_energy", "Free Energy", "mean FE (lower is better)")
            _plot(axes[1], "train_recon_mse", "val_recon_mse", "Reconstruction MSE", "MSE")
            _plot(axes[2], "train_recon_bit_error", "val_recon_bit_error", "Reconstruction Bit Error", "fraction mismatched")

            fig.tight_layout()
            plt.show()
        except Exception as e:
            print(f"[visualize_history] plotting failed: {e}")

    # --------------------------------------------------
    # Sampling
    # --------------------------------------------------

    @torch.no_grad()
    def sample(
        self,
        n_samples: int,
        *,
        burn_in: int = 200,
        thin: int = 10,
        init: str = "random",
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        device = device or self.W.device
        dtype = self.W.dtype

        if init == "chain" and self.v_chain is not None:
            v = self.v_chain[:1].to(device=device, dtype=dtype)
        else:
            v = torch.bernoulli(torch.full((1, self.nv), 0.5, device=device, dtype=dtype))

        for _ in range(burn_in):
            h = torch.bernoulli(self.hidden_prob(v))
            v = torch.bernoulli(self.visible_prob(h))

        samples = []
        steps_needed = n_samples * thin
        for t in range(steps_needed):
            h = torch.bernoulli(self.hidden_prob(v))
            v = torch.bernoulli(self.visible_prob(h))
            if (t + 1) % thin == 0:
                samples.append(v.squeeze(0).clone())

        self.v_chain = v.detach()
        return torch.stack(samples, dim=0)

    def draw_blocks(self, save_path: Optional[str] = None, show: bool = True) -> None:
        """Visualize the RBM block structure at the block level.

        Draws a bipartite graph showing visible blocks (bottom) and hidden blocks (top),
        with block-to-block connections colored by whether they are allowed or restricted.

        Args:
            save_path: If provided, save the figure to this path.
            show: If True, display the plot interactively.
        """
        try:
            import matplotlib.pyplot as plt
            import matplotlib.patches as mpatches
            from matplotlib.patches import FancyBboxPatch
        except ImportError as e:
            print(f"[draw_blocks] matplotlib import failed: {e}")
            return

        fig, ax = plt.subplots(figsize=(14, 8))

        # Layout parameters
        v_y = 0.15  # visible layer y-coordinate
        h_y = 0.85  # hidden layer y-coordinate

        # Colors
        v_colors = list(plt.cm.Set2.colors)
        h_colors = list(plt.cm.Set3.colors)
        allowed_color = "#4CAF50"  # green for allowed connections
        restricted_color = "#E57373"  # red for restricted

        # Get block info
        v_blocks = list(self._v_block_ranges.items())
        h_blocks = list(self._h_block_ranges.items())

        n_v_blocks = len(v_blocks)
        n_h_blocks = len(h_blocks)

        # Calculate block positions (evenly spaced)
        def get_block_positions(n_blocks, y):
            if n_blocks == 1:
                return [0.5]
            return [0.1 + i * 0.8 / (n_blocks - 1) for i in range(n_blocks)]

        v_x_positions = get_block_positions(n_v_blocks, v_y)
        h_x_positions = get_block_positions(n_h_blocks, h_y)

        # Determine block-to-block connectivity from mask
        mask_np = self.mask.detach().cpu().numpy()

        def blocks_connected(v_block_range, h_block_range):
            """Check if any connection exists between two blocks."""
            v_start, v_end = v_block_range
            h_start, h_end = h_block_range
            return mask_np[v_start:v_end, h_start:h_end].any()

        # Draw connections between blocks
        for vi, (v_name, v_range) in enumerate(v_blocks):
            for hi, (h_name, h_range) in enumerate(h_blocks):
                vx, vy = v_x_positions[vi], v_y
                hx, hy = h_x_positions[hi], h_y

                connected = blocks_connected(v_range, h_range)
                color = allowed_color if connected else restricted_color
                alpha = 0.7 if connected else 0.2
                linewidth = 2.5 if connected else 1.0
                zorder = 2 if connected else 1

                ax.plot([vx, hx], [vy + 0.05, hy - 0.05],
                        color=color, alpha=alpha, linewidth=linewidth, zorder=zorder)

        # Draw visible blocks as rounded rectangles
        block_height = 0.08
        for i, (name, (start, end)) in enumerate(v_blocks):
            size = end - start
            x = v_x_positions[i]
            color = v_colors[i % len(v_colors)]

            # Block width proportional to log of size (for visual balance)
            width = 0.06 + 0.02 * min(3, max(0, (size / 100)))

            rect = FancyBboxPatch(
                (x - width/2, v_y - block_height/2), width, block_height,
                boxstyle="round,pad=0.01,rounding_size=0.02",
                facecolor=color, edgecolor="black", linewidth=1.5, zorder=3
            )
            ax.add_patch(rect)

            # Block label with size
            ax.text(x, v_y, f"{name}\n({size})", ha="center", va="center",
                    fontsize=9, fontweight="bold", zorder=4)

        # Draw hidden blocks as rounded rectangles
        for i, (name, (start, end)) in enumerate(h_blocks):
            size = end - start
            x = h_x_positions[i]
            color = h_colors[i % len(h_colors)]

            width = 0.06 + 0.02 * min(3, max(0, (size / 100)))

            rect = FancyBboxPatch(
                (x - width/2, h_y - block_height/2), width, block_height,
                boxstyle="round,pad=0.01,rounding_size=0.02",
                facecolor=color, edgecolor="black", linewidth=1.5, zorder=3
            )
            ax.add_patch(rect)

            ax.text(x, h_y, f"{name}\n({size})", ha="center", va="center",
                    fontsize=9, fontweight="bold", zorder=4)

        # Layer labels
        ax.text(0.02, v_y, "Visible\nLayer", ha="left", va="center", fontsize=11, fontweight="bold")
        ax.text(0.02, h_y, "Hidden\nLayer", ha="left", va="center", fontsize=11, fontweight="bold")

        # Legend
        allowed_patch = mpatches.Patch(color=allowed_color, label="Connected (allowed)")
        restricted_patch = mpatches.Patch(color=restricted_color, alpha=0.3, label="Restricted (masked)")
        ax.legend(handles=[allowed_patch, restricted_patch], loc="upper right", fontsize=10)

        # Title with summary
        n_allowed = int(mask_np.sum())
        n_total = self.nv * self.nh
        n_restricted = n_total - n_allowed
        ax.set_title(
            f"RBM Block Structure\n"
            f"Visible: {self.nv:,} units ({n_v_blocks} blocks) | "
            f"Hidden: {self.nh:,} units ({n_h_blocks} blocks)\n"
            f"Connections: {n_allowed:,}/{n_total:,} allowed, {n_restricted:,} restricted",
            fontsize=12, fontweight="bold"
        )

        ax.set_xlim(-0.05, 1.05)
        ax.set_ylim(0, 1)
        ax.set_aspect("equal")
        ax.axis("off")

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
            print(f"Block diagram saved to: {save_path}")

        if show:
            plt.show()
        else:
            plt.close(fig)

    @torch.no_grad()
    def sample_clamped(
        self,
        v_clamp: torch.Tensor,
        clamp_idx: Sequence[int],
        *,
        n_samples: int = 1000,
        burn_in: int = 200,
        thin: int = 10,
        init: str = "random",
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        device = device or self.W.device
        dtype = self.W.dtype

        if v_clamp.dim() == 1:
            v_clamp = v_clamp.unsqueeze(0)
        v_clamp = v_clamp.to(device=device, dtype=dtype)

        clamp_idx_t = torch.as_tensor(clamp_idx, device=device, dtype=torch.long)

        v = torch.bernoulli(torch.full((1, self.nv), 0.5, device=device, dtype=dtype))
        v[:, clamp_idx_t] = v_clamp[:, clamp_idx_t]

        for _ in range(burn_in):
            h = torch.bernoulli(self.hidden_prob(v))
            v = torch.bernoulli(self.visible_prob(h))
            v[:, clamp_idx_t] = v_clamp[:, clamp_idx_t]

        samples = []
        for t in range(n_samples * thin):
            h = torch.bernoulli(self.hidden_prob(v))
            v = torch.bernoulli(self.visible_prob(h))
            v[:, clamp_idx_t] = v_clamp[:, clamp_idx_t]
            if (t + 1) % thin == 0:
                samples.append(v.squeeze(0).clone())

        return torch.stack(samples, dim=0)
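
The module docstring above spells out the expected config layout (visible_blocks, hidden_blocks, and an optional cross_block_restrictions list of forbidden block pairs). Below is a minimal usage sketch against the constructor, fit, and sample signatures added in this version; the block names, sizes, and hyperparameter values are illustrative only and are not values shipped with the package.

import torch
from torch.utils.data import DataLoader

from boltzmann9.model import RBM

# Illustrative config: block names, sizes, and the restriction pair are made up for this sketch.
config = {
    "model": {
        "visible_blocks": {"v_a": 64, "v_b": 16},
        "hidden_blocks": {"h_shared": 32, "h_a_only": 16},
        # Force the v_b x h_a_only submatrix of W to zero.
        "cross_block_restrictions": [("v_b", "h_a_only")],
    }
}

rbm = RBM(config)

# Random binary training data with nv = 64 + 16 = 80 visible units.
data = torch.bernoulli(torch.full((512, rbm.nv), 0.3))
train_loader = DataLoader(data, batch_size=64, shuffle=True)

# PCD training with momentum and weight decay (placeholder values).
history = rbm.fit(
    train_loader,
    epochs=5,
    lr=1e-2,
    k=1,
    momentum=0.5,
    weight_decay=1e-4,
)

# Draw 10 Gibbs samples from the trained model.
samples = rbm.sample(10, burn_in=100, thin=5)
print(samples.shape)  # torch.Size([10, 80])

Since cd_step applies the CD/PCD updates in place under torch.no_grad, no external optimizer is attached; the mask registered in __init__ keeps the restricted submatrices at zero after every update.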