boltzmann9 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
boltzmann9/model.py ADDED
@@ -0,0 +1,867 @@
+ """
+ Restricted Boltzmann Machine (RBM) implementation in PyTorch
+ with cross-block restrictions (weight masking).
+
+ - Binary visible/hidden units
+ - PCD (persistent contrastive divergence)
+ - Momentum updates, weight decay, gradient clipping
+ - LR schedules (constant/exponential/step/cosine/plateau)
+ - Optional hidden sparsity regularization
+ - Optional early stopping with validation monitoring
+
+ Cross-block restrictions:
+     config["model"]["cross_block_restrictions"] = [("v_block", "h_block"), ...]
+
+ These pairs indicate which V-block × H-block submatrices in W must be forced to 0.
+ """
+
+ from __future__ import annotations
+
+ import math
+ from typing import Any, Dict, Optional, Sequence, Tuple
+
+ import torch
+ import torch.nn as nn
+
+
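For orientation, a minimal sketch of the config shape the constructor below expects; the block names and sizes are illustrative only, and the keys follow the module docstring above.

    # Hypothetical config; block names and sizes are made up for illustration.
    config = {
        "model": {
            "visible_blocks": {"v_a": 100, "v_b": 50},
            "hidden_blocks": {"h_shared": 64, "h_b_only": 32},
            # forbid all weights between visible block "v_a" and hidden block "h_b_only"
            "cross_block_restrictions": [("v_a", "h_b_only")],
        }
    }
    rbm = RBM(config)      # a model-only dict (without the "model" key) also works
    print(rbm.nv, rbm.nh)  # 150 visible units, 96 hidden units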
+ class RBM(nn.Module):
+     def __init__(self, config: Dict[str, Any]) -> None:
+         super().__init__()
+
+         # Accept either a full app config with "model" key, or a model-only dict.
+         model_cfg = config.get("model", config)
+
+         visible_blocks: Dict[str, Any] = model_cfg["visible_blocks"]
+         hidden_blocks: Dict[str, Any] = model_cfg["hidden_blocks"]
+         restrictions = model_cfg.get("cross_block_restrictions", []) or []
+
+         self.visible_blocks = {k: int(v) for k, v in visible_blocks.items()}
+         self.hidden_blocks = {k: int(v) for k, v in hidden_blocks.items()}
+
+         self.nv = sum(self.visible_blocks.values())
+         self.nh = sum(self.hidden_blocks.values())
+
+         # --------- build block ranges (name -> (start, end)) ----------
+         self._v_block_ranges: Dict[str, Tuple[int, int]] = {}
+         off = 0
+         for name, size in self.visible_blocks.items():
+             if size <= 0:
+                 raise ValueError(f"Visible block {name!r} must have positive size, got {size}.")
+             self._v_block_ranges[name] = (off, off + size)
+             off += size
+
+         self._h_block_ranges: Dict[str, Tuple[int, int]] = {}
+         off = 0
+         for name, size in self.hidden_blocks.items():
+             if size <= 0:
+                 raise ValueError(f"Hidden block {name!r} must have positive size, got {size}.")
+             self._h_block_ranges[name] = (off, off + size)
+             off += size
+
+         # --------- parameters ----------
+         self.W = nn.Parameter(torch.empty(self.nv, self.nh))
+         self.bv = nn.Parameter(torch.zeros(self.nv))
+         self.bh = nn.Parameter(torch.zeros(self.nh))
+
+         nn.init.xavier_uniform_(self.W)
+
+         # --------- mask construction ----------
+         # Mask is float tensor with 1.0 for allowed edges and 0.0 for forbidden edges.
+         mask = torch.ones(self.nv, self.nh, dtype=self.W.dtype)
+
+         for pair in restrictions:
+             if not (isinstance(pair, (tuple, list)) and len(pair) == 2):
+                 raise ValueError(
+                     "Each cross_block_restrictions entry must be a pair (v_block, h_block). "
+                     f"Got: {pair!r}"
+                 )
+             v_block, h_block = pair
+             if v_block not in self._v_block_ranges:
+                 raise KeyError(
+                     f"Unknown visible block {v_block!r} in cross_block_restrictions. "
+                     f"Known: {list(self._v_block_ranges.keys())}"
+                 )
+             if h_block not in self._h_block_ranges:
+                 raise KeyError(
+                     f"Unknown hidden block {h_block!r} in cross_block_restrictions. "
+                     f"Known: {list(self._h_block_ranges.keys())}"
+                 )
+
+             vs, ve = self._v_block_ranges[v_block]
+             hs, he = self._h_block_ranges[h_block]
+             mask[vs:ve, hs:he] = 0.0
+
+         # register_buffer so it moves with .to(device) and is saved in state_dict
+         self.register_buffer("mask", mask)
+
+         # enforce mask at init
+         with torch.no_grad():
+             self.W.mul_(self.mask)
+
+         # Persistent chain for PCD
+         self.v_chain: Optional[torch.Tensor] = None
+
+         # Momentum buffers (registered so they move with .to(device))
+         self.register_buffer("_vW", torch.zeros_like(self.W))
+         self.register_buffer("_vbv", torch.zeros_like(self.bv))
+         self.register_buffer("_vbh", torch.zeros_like(self.bh))
+
+         # Plateau scheduler state
+         self._plateau_best: Optional[float] = None
+         self._plateau_bad_count: int = 0
+
+     @classmethod
+     def from_run_folder(
+         cls,
+         run_folder: str,
+         device: Optional[str] = None,
+     ) -> Tuple["RBM", Dict[str, Any]]:
+         """Load an RBM model from a run folder.
+
+         Args:
+             run_folder: Path to run folder containing model.pt and config.py.
+             device: Device to load model to. If None, uses CPU.
+
+         Returns:
+             Tuple of (model, config) where config is the full configuration dict.
+         """
+         from pathlib import Path
+
+         run_path = Path(run_folder)
+         model_path = run_path / "model.pt"
+         config_path = run_path / "config.py"
+
+         if not model_path.exists():
+             raise FileNotFoundError(f"Model file not found: {model_path}")
+         if not config_path.exists():
+             raise FileNotFoundError(f"Config file not found: {config_path}")
+
+         # Load config
+         from boltzmann9.config import load_config
+         config = load_config(config_path)
+
+         # Load checkpoint
+         checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
+
+         # Create model from config
+         model = cls(config)
+         model.load_state_dict(checkpoint["model_state_dict"])
+
+         if device:
+             model = model.to(device)
+
+         print(f"Model loaded from: {run_path}")
+         return model, config
+
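A quick loading sketch for the classmethod above; the run-folder path is hypothetical and is assumed to contain the model.pt and config.py files described in the docstring.

    # Hypothetical run folder from a previous training run.
    model, config = RBM.from_run_folder("runs/example_run", device="cuda:0")
    model.eval()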
+     # --------------------------------------------------
+     # Core distributions
+     # --------------------------------------------------
+
+     def hidden_prob(self, v: torch.Tensor) -> torch.Tensor:
+         """Compute P(h=1 | v)."""
+         return torch.sigmoid(v @ self.W + self.bh)
+
+     def visible_prob(self, h: torch.Tensor) -> torch.Tensor:
+         """Compute P(v=1 | h)."""
+         return torch.sigmoid(h @ self.W.T + self.bv)
+
+     def _bernoulli(self, p: torch.Tensor) -> torch.Tensor:
+         """Sample from Bernoulli distribution."""
+         return torch.bernoulli(p)
+
+     # --------------------------------------------------
+     # Forward (semantic: inference, NOT training)
+     # --------------------------------------------------
+
+     def forward(self, v: torch.Tensor) -> torch.Tensor:
+         """Return P(h=1 | v)."""
+         return self.hidden_prob(v.to(self.W.dtype))
+
+     # --------------------------------------------------
+     # Phases
+     # --------------------------------------------------
+
+     def positive_phase(self, v: torch.Tensor, kind: str = "mean-field"):
+         """Compute positive phase statistics."""
+         v = v.to(self.W.dtype)
+         ph = self.hidden_prob(v)
+         h_used = ph if kind == "mean-field" else self._bernoulli(ph)
+
+         pos_W = v.T @ h_used
+         pos_bv = v.sum(dim=0)
+         pos_bh = h_used.sum(dim=0)
+         return pos_W, pos_bv, pos_bh, ph
+
+     @torch.no_grad()
+     def negative_phase(
+         self,
+         batch_size: int,
+         k: int = 1,
+         kind: str = "mean-field",
+         device: Optional[torch.device] = None,
+     ):
+         """Compute negative phase statistics using PCD."""
+         device = device or self.W.device
+
+         if self.v_chain is None or self.v_chain.shape[0] != batch_size:
+             self.v_chain = self._bernoulli(
+                 torch.full((batch_size, self.nv), 0.5, device=device, dtype=self.W.dtype)
+             )
+
+         v = self.v_chain
+
+         for _ in range(k):
+             h = self._bernoulli(self.hidden_prob(v))
+             v = self._bernoulli(self.visible_prob(h))
+
+         self.v_chain = v.detach()
+
+         phk = self.hidden_prob(v)
+         h_used = phk if kind == "mean-field" else self._bernoulli(phk)
+
+         neg_W = v.T @ h_used
+         neg_bv = v.sum(dim=0)
+         neg_bh = h_used.sum(dim=0)
+         return neg_W, neg_bv, neg_bh
+
+     # --------------------------------------------------
+     # Update helpers (momentum, clipping, regularization)
+     # --------------------------------------------------
+
+     @staticmethod
+     def _clip_by_value(x: torch.Tensor, clip_value: Optional[float]) -> torch.Tensor:
+         if clip_value is None:
+             return x
+         return x.clamp(min=-clip_value, max=clip_value)
+
+     @staticmethod
+     def _clip_by_global_norm(
+         dW: torch.Tensor,
+         dbv: torch.Tensor,
+         dbh: torch.Tensor,
+         max_norm: Optional[float],
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         if max_norm is None:
+             return dW, dbv, dbh
+         norm = torch.sqrt((dW * dW).sum() + (dbv * dbv).sum() + (dbh * dbh).sum())
+         if norm > max_norm:
+             scale = max_norm / (norm + 1e-12)
+             dW = dW * scale
+             dbv = dbv * scale
+             dbh = dbh * scale
+         return dW, dbv, dbh
+
+     def _apply_update(
+         self,
+         *,
+         lr: float,
+         dW: torch.Tensor,
+         dbv: torch.Tensor,
+         dbh: torch.Tensor,
+         momentum: float = 0.0,
+         weight_decay: float = 0.0,
+         clip_value: Optional[float] = None,
+         clip_norm: Optional[float] = None,
+     ) -> None:
+         """Apply parameter updates with momentum, weight decay, clipping, and weight masking."""
+         # Mask gradients early (avoids momentum accumulating on forbidden edges)
+         if hasattr(self, "mask") and self.mask is not None:
+             dW = dW * self.mask
+
+         # L2 weight decay
+         if weight_decay and weight_decay > 0.0:
+             dW = dW - weight_decay * self.W
+
+         # Clip by value
+         dW = self._clip_by_value(dW, clip_value)
+         dbv = self._clip_by_value(dbv, clip_value)
+         dbh = self._clip_by_value(dbh, clip_value)
+
+         # Clip by global norm
+         dW, dbv, dbh = self._clip_by_global_norm(dW, dbv, dbh, clip_norm)
+
+         # Momentum update
+         with torch.no_grad():
+             if momentum and momentum > 0.0:
+                 self._vW.mul_(momentum).add_(dW, alpha=lr)
+                 self._vbv.mul_(momentum).add_(dbv, alpha=lr)
+                 self._vbh.mul_(momentum).add_(dbh, alpha=lr)
+
+                 # keep momentum buffer masked too (optional but good)
+                 if hasattr(self, "mask") and self.mask is not None:
+                     self._vW.mul_(self.mask)
+
+                 self.W.add_(self._vW)
+                 self.bv.add_(self._vbv)
+                 self.bh.add_(self._vbh)
+             else:
+                 self.W.add_(dW, alpha=lr)
+                 self.bv.add_(dbv, alpha=lr)
+                 self.bh.add_(dbh, alpha=lr)
+
+             # Re-apply mask to ensure restricted weights stay zero
+             if hasattr(self, "mask") and self.mask is not None:
+                 self.W.mul_(self.mask)
+
+     # --------------------------------------------------
+     # LR scheduling
+     # --------------------------------------------------
+
+     def _lr_at_epoch(
+         self,
+         *,
+         base_lr: float,
+         epoch: int,
+         epochs: int,
+         schedule: Optional[Dict[str, Any]] = None,
+         current_val_metric: Optional[float] = None,
+     ) -> float:
+         if not schedule:
+             return float(base_lr)
+
+         mode = schedule.get("mode", "constant")
+         lr0 = float(base_lr)
+
+         if mode == "constant":
+             return lr0
+
+         if mode == "exponential":
+             gamma = float(schedule.get("gamma", 0.99))
+             return lr0 * (gamma ** (epoch - 1))
+
+         if mode == "step":
+             step_size = int(schedule.get("step_size", 10))
+             gamma = float(schedule.get("gamma", 0.5))
+             n_steps = (epoch - 1) // step_size
+             return lr0 * (gamma ** n_steps)
+
+         if mode == "cosine":
+             min_lr = float(schedule.get("min_lr", 0.0))
+             t = (epoch - 1) / max(1, (epochs - 1))
+             return min_lr + 0.5 * (lr0 - min_lr) * (1.0 + math.cos(math.pi * t))
+
+         if mode == "plateau":
+             factor = float(schedule.get("factor", 0.5))
+             patience = int(schedule.get("patience", 3))
+             min_lr = float(schedule.get("min_lr", 1e-6))
+             threshold = float(schedule.get("threshold", 1e-4))
+
+             if current_val_metric is None:
+                 return float(schedule.get("__current_lr", lr0))
+
+             current_lr = float(schedule.get("__current_lr", lr0))
+
+             if self._plateau_best is None or (self._plateau_best - current_val_metric) > threshold:
+                 self._plateau_best = float(current_val_metric)
+                 self._plateau_bad_count = 0
+                 return current_lr
+
+             self._plateau_bad_count += 1
+             if self._plateau_bad_count >= patience:
+                 new_lr = max(min_lr, current_lr * factor)
+                 schedule["__current_lr"] = new_lr
+                 self._plateau_bad_count = 0
+                 return new_lr
+
+             return current_lr
+
+         raise ValueError(f"Unknown lr schedule mode={mode!r}")
+
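For reference, schedule dicts in the shape `_lr_at_epoch` reads above; the specific values are illustrative, not package defaults beyond those visible in the code.

    exp_schedule = {"mode": "exponential", "gamma": 0.97}
    step_schedule = {"mode": "step", "step_size": 20, "gamma": 0.5}
    cosine_schedule = {"mode": "cosine", "min_lr": 1e-5}
    plateau_schedule = {"mode": "plateau", "factor": 0.5, "patience": 3, "min_lr": 1e-6, "threshold": 1e-4}

    # e.g. exponential decay at epoch 10 with base_lr=1e-3:
    # lr = 1e-3 * 0.97 ** (10 - 1) ≈ 7.6e-4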
+     # --------------------------------------------------
+     # Training
+     # --------------------------------------------------
+
+     @torch.no_grad()
+     def cd_step(
+         self,
+         v: torch.Tensor,
+         *,
+         lr: float,
+         k: int = 1,
+         kind: str = "mean-field",
+         momentum: float = 0.0,
+         weight_decay: float = 0.0,
+         clip_value: Optional[float] = None,
+         clip_norm: Optional[float] = None,
+         sparse_hidden: bool = False,
+         rho: float = 0.1,
+         lambda_sparse: float = 0.0,
+     ) -> None:
+         B = v.size(0)
+
+         pos_W, pos_bv, pos_bh, ph = self.positive_phase(v, kind)
+         neg_W, neg_bv, neg_bh = self.negative_phase(batch_size=B, k=k, kind=kind, device=v.device)
+
+         dW = (pos_W - neg_W) / B
+         dbv = (pos_bv - neg_bv) / B
+         dbh = (pos_bh - neg_bh) / B
+
+         if sparse_hidden and lambda_sparse > 0.0:
+             err = ph.mean(dim=0) - rho
+             dbh = dbh - lambda_sparse * err
+             v_ = v.to(self.W.dtype)
+             dW = dW - lambda_sparse * (v_.mean(dim=0).unsqueeze(1) * err.unsqueeze(0))
+
+         self._apply_update(
+             lr=lr,
+             dW=dW,
+             dbv=dbv,
+             dbh=dbh,
+             momentum=momentum,
+             weight_decay=weight_decay,
+             clip_value=clip_value,
+             clip_norm=clip_norm,
+         )
+
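A small sanity-check sketch of the masking behavior: after a `cd_step`, weights inside a restricted block pair should remain exactly zero, because `_apply_update` masks the gradient, the momentum buffer, and the updated W. The tiny config and random batch here are hypothetical, and `_v_block_ranges`/`_h_block_ranges` are internal attributes.

    cfg = {
        "visible_blocks": {"v_a": 8, "v_b": 4},
        "hidden_blocks": {"h_a": 6, "h_b": 6},
        "cross_block_restrictions": [("v_a", "h_b")],
    }
    rbm = RBM(cfg)
    v = torch.bernoulli(torch.full((16, rbm.nv), 0.5))
    rbm.cd_step(v, lr=1e-2, k=1, momentum=0.5, weight_decay=1e-4)

    vs, ve = rbm._v_block_ranges["v_a"]
    hs, he = rbm._h_block_ranges["h_b"]
    assert torch.all(rbm.W[vs:ve, hs:he] == 0)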
+     # --------------------------------------------------
+     # Utilities
+     # --------------------------------------------------
+
+     def reconstruct(self, v: torch.Tensor, k: int = 1) -> torch.Tensor:
+         v = v.to(self.W.dtype)
+         for _ in range(k):
+             h = self._bernoulli(self.hidden_prob(v))
+             v = self._bernoulli(self.visible_prob(h))
+         return v
+
+     def free_energy(self, v: torch.Tensor) -> torch.Tensor:
+         v = v.to(self.W.dtype)
+         wx_b = v @ self.W + self.bh
+         return -v @ self.bv - torch.nn.functional.softplus(wx_b).sum(dim=1)
+
+     @torch.no_grad()
+     def evaluate(self, dataloader, *, recon_k: int = 1) -> Dict[str, float]:
+         device = self.W.device
+
+         fe_sum = 0.0
+         mse_sum = 0.0
+         ber_sum = 0.0
+         n_samples = 0
+
+         for v in dataloader:
+             v = v.to(device, non_blocking=True)
+             B = v.size(0)
+             n_samples += B
+
+             fe = self.free_energy(v).mean().item() / (self.nv + self.nh)  # per-node free energy
+             v_rec = self.reconstruct(v, k=recon_k)
+
+             mse = torch.mean((v - v_rec) ** 2).item()
+             ber = torch.mean((v != v_rec).to(torch.float32)).item()
+
+             fe_sum += fe * B
+             mse_sum += mse * B
+             ber_sum += ber * B
+
+         return {
+             "free_energy_mean": fe_sum / n_samples,
+             "recon_mse_mean": mse_sum / n_samples,
+             "recon_bit_error": ber_sum / n_samples,
+         }
+
+     # --------------------------------------------------
+     # Training loop
+     # --------------------------------------------------
+
+     def fit(
+         self,
+         train_loader,
+         *,
+         val_loader: Optional[object] = None,
+         epochs: int = 10,
+         lr: float = 1e-3,
+         k: int = 1,
+         kind: str = "mean-field",
+         eval_every: int = 1,
+         recon_k: int = 1,
+         lr_schedule: Optional[Dict[str, Any]] = None,
+         momentum: float = 0.0,
+         weight_decay: float = 0.0,
+         clip_value: Optional[float] = None,
+         clip_norm: Optional[float] = None,
+         sparse_hidden: bool = False,
+         rho: float = 0.1,
+         lambda_sparse: float = 0.0,
+         early_stopping: bool = False,
+         es_patience: int = 10,
+         es_min_delta: float = 1e-4,
+     ) -> Dict[str, list]:
+         device = self.W.device
+         history: Dict[str, list] = {
+             "epoch": [],
+             "train_free_energy": [],
+             "train_recon_mse": [],
+             "train_recon_bit_error": [],
+             "val_free_energy": [],
+             "val_recon_mse": [],
+             "val_recon_bit_error": [],
+             "lr": [],
+         }
+
+         best_val: Optional[float] = None
+         best_state: Optional[Dict[str, torch.Tensor]] = None
+         bad_epochs = 0
+
+         if lr_schedule and lr_schedule.get("mode") == "plateau":
+             lr_schedule = dict(lr_schedule)
+             lr_schedule["__current_lr"] = float(lr)
+
+         for epoch in range(1, epochs + 1):
+             self.train()
+
+             current_lr = self._lr_at_epoch(
+                 base_lr=float(lr),
+                 epoch=epoch,
+                 epochs=epochs,
+                 schedule=lr_schedule,
+                 current_val_metric=None,
+             )
+
+             for v in train_loader:
+                 v = v.to(device, non_blocking=True)
+                 self.cd_step(
+                     v,
+                     lr=current_lr,
+                     k=k,
+                     kind=kind,
+                     momentum=momentum,
+                     weight_decay=weight_decay,
+                     clip_value=clip_value,
+                     clip_norm=clip_norm,
+                     sparse_hidden=sparse_hidden,
+                     rho=rho,
+                     lambda_sparse=lambda_sparse,
+                 )
+
+             if epoch % eval_every == 0:
+                 self.eval()
+
+                 train_metrics = self.evaluate(train_loader, recon_k=recon_k)
+                 history["epoch"].append(epoch)
+                 history["train_free_energy"].append(train_metrics["free_energy_mean"])
+                 history["train_recon_mse"].append(train_metrics["recon_mse_mean"])
+                 history["train_recon_bit_error"].append(train_metrics["recon_bit_error"])
+
+                 if val_loader is not None:
+                     val_metrics = self.evaluate(val_loader, recon_k=recon_k)
+                     val_fe = float(val_metrics["free_energy_mean"])
+
+                     history["val_free_energy"].append(val_fe)
+                     history["val_recon_mse"].append(val_metrics["recon_mse_mean"])
+                     history["val_recon_bit_error"].append(val_metrics["recon_bit_error"])
+
+                     if lr_schedule and lr_schedule.get("mode") == "plateau":
+                         # plateau uses its internal "__current_lr"
+                         _ = self._lr_at_epoch(
+                             base_lr=float(lr),
+                             epoch=epoch,
+                             epochs=epochs,
+                             schedule=lr_schedule,
+                             current_val_metric=val_fe,
+                         )
+                         current_lr = float(lr_schedule.get("__current_lr", current_lr))
+
+                     if early_stopping:
+                         if best_val is None or (best_val - val_fe) > es_min_delta:
+                             best_val = val_fe
+                             best_state = {name: t.detach().clone() for name, t in self.state_dict().items()}
+                             bad_epochs = 0
+                         else:
+                             bad_epochs += 1
+                             if bad_epochs >= es_patience:
+                                 if best_state is not None:
+                                     self.load_state_dict(best_state)
+                                 history["lr"].append(current_lr)
+                                 print(
+                                     f"Early stopping at epoch {epoch} "
+                                     f"(best val FE={best_val:.6f})"
+                                 )
+                                 self.visualize_history(history)
+                                 return history
+
+                     print(
+                         f"Epoch {epoch:04d} | lr={current_lr:.3e} | "
+                         f"train FE={train_metrics['free_energy_mean']:.4f} "
+                         f"val FE={val_fe:.4f} | "
+                         f"train recon_mse={train_metrics['recon_mse_mean']:.4f} "
+                         f"val recon_mse={val_metrics['recon_mse_mean']:.4f}"
+                     )
+                 else:
+                     history["val_free_energy"].append(float("nan"))
+                     history["val_recon_mse"].append(float("nan"))
+                     history["val_recon_bit_error"].append(float("nan"))
+                     print(
+                         f"Epoch {epoch:04d} | lr={current_lr:.3e} | "
+                         f"train FE={train_metrics['free_energy_mean']:.4f} | "
+                         f"train recon_mse={train_metrics['recon_mse_mean']:.4f}"
+                     )
+
+                 history["lr"].append(current_lr)
+
+         self.visualize_history(history)
+         return history
+
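A training-call sketch wiring the pieces above together; it assumes `train_loader`/`val_loader` yield float tensors of shape (batch, nv) with 0/1 entries.

    history = rbm.fit(
        train_loader,
        val_loader=val_loader,
        epochs=100,
        lr=1e-3,
        k=5,
        momentum=0.9,
        weight_decay=1e-4,
        lr_schedule={"mode": "plateau", "factor": 0.5, "patience": 3},
        early_stopping=True,
        es_patience=10,
    )
    # history["epoch"], history["val_free_energy"], history["lr"] are aligned per evaluated epoch.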
+     # --------------------------------------------------
+     # Visualization
+     # --------------------------------------------------
+
+     def visualize_history(self, history: dict) -> None:
+         try:
+             import matplotlib.pyplot as plt
+         except Exception as e:
+             print(f"[visualize_history] matplotlib import failed: {e}")
+             return
+
+         epochs = history.get("epoch", list(range(1, len(history.get("train_free_energy", [])) + 1)))
+
+         def _is_all_nan(xs):
+             if not xs:
+                 return True
+             return all((x is None) or (isinstance(x, float) and math.isnan(x)) for x in xs)
+
+         def _plot(ax, y_train_key, y_val_key, title, ylabel):
+             y_tr = history.get(y_train_key, [])
+             y_va = history.get(y_val_key, [])
+
+             ax.plot(epochs[: len(y_tr)], y_tr, label="train")
+             if not _is_all_nan(y_va):
+                 ax.plot(epochs[: len(y_va)], y_va, label="val")
+
+             ax.set_title(title)
+             ax.set_xlabel("epoch")
+             ax.set_ylabel(ylabel)
+             ax.legend()
+
+         try:
+             fig, axes = plt.subplots(1, 3, figsize=(16, 4))
+
+             _plot(axes[0], "train_free_energy", "val_free_energy", "Free Energy", "mean FE (lower is better)")
+             _plot(axes[1], "train_recon_mse", "val_recon_mse", "Reconstruction MSE", "MSE")
+             _plot(axes[2], "train_recon_bit_error", "val_recon_bit_error", "Reconstruction Bit Error", "fraction mismatched")
+
+             fig.tight_layout()
+             plt.show()
+         except Exception as e:
+             print(f"[visualize_history] plotting failed: {e}")
+
+     # --------------------------------------------------
+     # Sampling
+     # --------------------------------------------------
+
+     @torch.no_grad()
+     def sample(
+         self,
+         n_samples: int,
+         *,
+         burn_in: int = 200,
+         thin: int = 10,
+         init: str = "random",
+         device: Optional[torch.device] = None,
+     ) -> torch.Tensor:
+         device = device or self.W.device
+         dtype = self.W.dtype
+
+         if init == "chain" and self.v_chain is not None:
+             v = self.v_chain[:1].to(device=device, dtype=dtype)
+         else:
+             v = torch.bernoulli(torch.full((1, self.nv), 0.5, device=device, dtype=dtype))
+
+         for _ in range(burn_in):
+             h = torch.bernoulli(self.hidden_prob(v))
+             v = torch.bernoulli(self.visible_prob(h))
+
+         samples = []
+         steps_needed = n_samples * thin
+         for t in range(steps_needed):
+             h = torch.bernoulli(self.hidden_prob(v))
+             v = torch.bernoulli(self.visible_prob(h))
+             if (t + 1) % thin == 0:
+                 samples.append(v.squeeze(0).clone())
+
+         self.v_chain = v.detach()
+         return torch.stack(samples, dim=0)
+
+     def draw_blocks(self, save_path: Optional[str] = None, show: bool = True) -> None:
+         """Visualize the RBM block structure at the block level.
+
+         Draws a bipartite graph showing visible blocks (bottom) and hidden blocks (top),
+         with block-to-block connections colored by whether they are allowed or restricted.
+
+         Args:
+             save_path: If provided, save the figure to this path.
+             show: If True, display the plot interactively.
+         """
+         try:
+             import matplotlib.pyplot as plt
+             import matplotlib.patches as mpatches
+             from matplotlib.patches import FancyBboxPatch
+         except ImportError as e:
+             print(f"[draw_blocks] matplotlib import failed: {e}")
+             return
+
+         fig, ax = plt.subplots(figsize=(14, 8))
+
+         # Layout parameters
+         v_y = 0.15  # visible layer y-coordinate
+         h_y = 0.85  # hidden layer y-coordinate
+
+         # Colors
+         v_colors = list(plt.cm.Set2.colors)
+         h_colors = list(plt.cm.Set3.colors)
+         allowed_color = "#4CAF50"  # green for allowed connections
+         restricted_color = "#E57373"  # red for restricted
+
+         # Get block info
+         v_blocks = list(self._v_block_ranges.items())
+         h_blocks = list(self._h_block_ranges.items())
+
+         n_v_blocks = len(v_blocks)
+         n_h_blocks = len(h_blocks)
+
+         # Calculate block positions (evenly spaced)
+         def get_block_positions(n_blocks, y):
+             if n_blocks == 1:
+                 return [0.5]
+             return [0.1 + i * 0.8 / (n_blocks - 1) for i in range(n_blocks)]
+
+         v_x_positions = get_block_positions(n_v_blocks, v_y)
+         h_x_positions = get_block_positions(n_h_blocks, h_y)
+
+         # Determine block-to-block connectivity from mask
+         mask_np = self.mask.detach().cpu().numpy()
+
+         def blocks_connected(v_block_range, h_block_range):
+             """Check if any connection exists between two blocks."""
+             v_start, v_end = v_block_range
+             h_start, h_end = h_block_range
+             return mask_np[v_start:v_end, h_start:h_end].any()
+
+         # Draw connections between blocks
+         for vi, (v_name, v_range) in enumerate(v_blocks):
+             for hi, (h_name, h_range) in enumerate(h_blocks):
+                 vx, vy = v_x_positions[vi], v_y
+                 hx, hy = h_x_positions[hi], h_y
+
+                 connected = blocks_connected(v_range, h_range)
+                 color = allowed_color if connected else restricted_color
+                 alpha = 0.7 if connected else 0.2
+                 linewidth = 2.5 if connected else 1.0
+                 zorder = 2 if connected else 1
+
+                 ax.plot([vx, hx], [vy + 0.05, hy - 0.05],
+                         color=color, alpha=alpha, linewidth=linewidth, zorder=zorder)
+
+         # Draw visible blocks as rounded rectangles
+         block_height = 0.08
+         for i, (name, (start, end)) in enumerate(v_blocks):
+             size = end - start
+             x = v_x_positions[i]
+             color = v_colors[i % len(v_colors)]
+
+             # Block width grows with size, capped for visual balance
+             width = 0.06 + 0.02 * min(3, max(0, (size / 100)))
+
+             rect = FancyBboxPatch(
+                 (x - width/2, v_y - block_height/2), width, block_height,
+                 boxstyle="round,pad=0.01,rounding_size=0.02",
+                 facecolor=color, edgecolor="black", linewidth=1.5, zorder=3
+             )
+             ax.add_patch(rect)
+
+             # Block label with size
+             ax.text(x, v_y, f"{name}\n({size})", ha="center", va="center",
+                     fontsize=9, fontweight="bold", zorder=4)
+
+         # Draw hidden blocks as rounded rectangles
+         for i, (name, (start, end)) in enumerate(h_blocks):
+             size = end - start
+             x = h_x_positions[i]
+             color = h_colors[i % len(h_colors)]
+
+             width = 0.06 + 0.02 * min(3, max(0, (size / 100)))
+
+             rect = FancyBboxPatch(
+                 (x - width/2, h_y - block_height/2), width, block_height,
+                 boxstyle="round,pad=0.01,rounding_size=0.02",
+                 facecolor=color, edgecolor="black", linewidth=1.5, zorder=3
+             )
+             ax.add_patch(rect)
+
+             ax.text(x, h_y, f"{name}\n({size})", ha="center", va="center",
+                     fontsize=9, fontweight="bold", zorder=4)
+
+         # Layer labels
+         ax.text(0.02, v_y, "Visible\nLayer", ha="left", va="center", fontsize=11, fontweight="bold")
+         ax.text(0.02, h_y, "Hidden\nLayer", ha="left", va="center", fontsize=11, fontweight="bold")
+
+         # Legend
+         allowed_patch = mpatches.Patch(color=allowed_color, label="Connected (allowed)")
+         restricted_patch = mpatches.Patch(color=restricted_color, alpha=0.3, label="Restricted (masked)")
+         ax.legend(handles=[allowed_patch, restricted_patch], loc="upper right", fontsize=10)
+
+         # Title with summary
+         n_allowed = int(mask_np.sum())
+         n_total = self.nv * self.nh
+         n_restricted = n_total - n_allowed
+         ax.set_title(
+             f"RBM Block Structure\n"
+             f"Visible: {self.nv:,} units ({n_v_blocks} blocks) | "
+             f"Hidden: {self.nh:,} units ({n_h_blocks} blocks)\n"
+             f"Connections: {n_allowed:,}/{n_total:,} allowed, {n_restricted:,} restricted",
+             fontsize=12, fontweight="bold"
+         )
+
+         ax.set_xlim(-0.05, 1.05)
+         ax.set_ylim(0, 1)
+         ax.set_aspect("equal")
+         ax.axis("off")
+
+         fig.tight_layout()
+
+         if save_path:
+             fig.savefig(save_path, dpi=150, bbox_inches="tight")
+             print(f"Block diagram saved to: {save_path}")
+
+         if show:
+             plt.show()
+         else:
+             plt.close(fig)
+
+     @torch.no_grad()
+     def sample_clamped(
+         self,
+         v_clamp: torch.Tensor,
+         clamp_idx: Sequence[int],
+         *,
+         n_samples: int = 1000,
+         burn_in: int = 200,
+         thin: int = 10,
+         init: str = "random",
+         device: Optional[torch.device] = None,
+     ) -> torch.Tensor:
+         device = device or self.W.device
+         dtype = self.W.dtype
+
+         if v_clamp.dim() == 1:
+             v_clamp = v_clamp.unsqueeze(0)
+         v_clamp = v_clamp.to(device=device, dtype=dtype)
+
+         clamp_idx_t = torch.as_tensor(clamp_idx, device=device, dtype=torch.long)
+
+         v = torch.bernoulli(torch.full((1, self.nv), 0.5, device=device, dtype=dtype))
+         v[:, clamp_idx_t] = v_clamp[:, clamp_idx_t]
+
+         for _ in range(burn_in):
+             h = torch.bernoulli(self.hidden_prob(v))
+             v = torch.bernoulli(self.visible_prob(h))
+             v[:, clamp_idx_t] = v_clamp[:, clamp_idx_t]
+
+         samples = []
+         for t in range(n_samples * thin):
+             h = torch.bernoulli(self.hidden_prob(v))
+             v = torch.bernoulli(self.visible_prob(h))
+             v[:, clamp_idx_t] = v_clamp[:, clamp_idx_t]
+             if (t + 1) % thin == 0:
+                 samples.append(v.squeeze(0).clone())
+
+         return torch.stack(samples, dim=0)
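Finally, a short sampling sketch; the block name "v_a" and the clamped pattern are illustrative, and `_v_block_ranges` is an internal attribute.

    # Unconditional: 100 binary visible vectors after burn-in, keeping every 10th Gibbs step.
    samples = rbm.sample(100, burn_in=500, thin=10)  # shape (100, nv)

    # Conditional: clamp one visible block to an observed pattern and sample the rest.
    vs, ve = rbm._v_block_ranges["v_a"]
    v_obs = torch.zeros(rbm.nv)
    v_obs[vs:ve] = 1.0
    clamped = rbm.sample_clamped(v_obs, clamp_idx=list(range(vs, ve)), n_samples=200)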