boltzmann9 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- boltzmann9/__init__.py +38 -0
- boltzmann9/__main__.py +4 -0
- boltzmann9/cli.py +389 -0
- boltzmann9/config.py +58 -0
- boltzmann9/data.py +145 -0
- boltzmann9/data_generator.py +234 -0
- boltzmann9/model.py +867 -0
- boltzmann9/pipeline.py +216 -0
- boltzmann9/preprocessor.py +627 -0
- boltzmann9/project.py +195 -0
- boltzmann9/run_utils.py +262 -0
- boltzmann9/tester.py +167 -0
- boltzmann9/utils.py +42 -0
- boltzmann9/visualization.py +115 -0
- {boltzmann9-0.1.4.dist-info → boltzmann9-0.1.6.dist-info}/METADATA +1 -1
- boltzmann9-0.1.6.dist-info/RECORD +19 -0
- boltzmann9-0.1.6.dist-info/top_level.txt +1 -0
- boltzmann9-0.1.4.dist-info/RECORD +0 -5
- boltzmann9-0.1.4.dist-info/top_level.txt +0 -1
- {boltzmann9-0.1.4.dist-info → boltzmann9-0.1.6.dist-info}/WHEEL +0 -0
- {boltzmann9-0.1.4.dist-info → boltzmann9-0.1.6.dist-info}/entry_points.txt +0 -0
boltzmann9/project.py
ADDED
@@ -0,0 +1,195 @@
+"""Project and run management utilities."""
+
+from __future__ import annotations
+
+import re
+import shutil
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict
+
+
+def _get_templates_dir() -> Path:
+    """Get the path to the templates directory."""
+    return Path(__file__).parent.parent / "templates"
+
+
+def _load_config_template(data_path: str) -> str:
+    """Load config.py template and substitute the data path.
+
+    Args:
+        data_path: Path to set for csv_path in the config.
+
+    Returns:
+        Config file content with substituted data path.
+    """
+    template_path = _get_templates_dir() / "config.py"
+    content = template_path.read_text()
+
+    # Replace the csv_path value (handles various quoting styles)
+    content = re.sub(
+        r'("csv_path":\s*)"[^"]*"',
+        f'\\1"{data_path}"',
+        content
+    )
+
+    return content
+
+
+def _load_generator_template() -> str:
+    """Load synthetic_generator.py template from templates directory.
+
+    Returns:
+        Generator script content.
+    """
+    template_path = _get_templates_dir() / "synthetic_generator.py"
+    return template_path.read_text()
+
+
+def create_project(project_path: str | Path) -> Path:
+    """Create a new BoltzmaNN9 project.
+
+    Args:
+        project_path: Path where the project will be created.
+
+    Returns:
+        Path to the created project.
+
+    Raises:
+        FileExistsError: If project directory already exists.
+    """
+    project_path = Path(project_path)
+
+    if project_path.exists():
+        raise FileExistsError(f"Project directory already exists: {project_path}")
+
+    # Create project structure
+    project_path.mkdir(parents=True)
+    data_dir = project_path / "data"
+    data_dir.mkdir()
+    output_dir = project_path / "output"
+    output_dir.mkdir()
+
+    # Create config.py from template with correct data path
+    data_csv_path = "data/data.csv"
+    config_content = _load_config_template(data_csv_path)
+    config_file = project_path / "config.py"
+    config_file.write_text(config_content)
+
+    # Create synthetic_generator.py in data folder from template
+    generator_content = _load_generator_template()
+    generator_file = data_dir / "synthetic_generator.py"
+    generator_file.write_text(generator_content)
+
+    print(f"Created project: {project_path}")
+    print(f"  - config.py")
+    print(f"  - data/")
+    print(f"    - synthetic_generator.py")
+    print(f"  - output/")
+    print(f"\nNext steps:")
+    print(f"  1. cd {project_path}/data && python synthetic_generator.py")
+    print(f"  2. Edit {project_path}/config.py as needed")
+    print(f"  3. python boltzmann.py train --project {project_path}")
+
+    return project_path
+
+
+def create_run(project_path: str | Path, run_name: str | None = None) -> Path:
+    """Create a new run directory within a project.
+
+    Args:
+        project_path: Path to the project.
+        run_name: Optional custom run name. If None, uses timestamp.
+
+    Returns:
+        Path to the created run directory.
+    """
+    project_path = Path(project_path)
+    output_dir = project_path / "output"
+
+    if not output_dir.exists():
+        output_dir.mkdir(parents=True)
+
+    # Generate run name with timestamp
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    if run_name:
+        run_dir_name = f"run_{timestamp}_{run_name}"
+    else:
+        run_dir_name = f"run_{timestamp}"
+
+    run_dir = output_dir / run_dir_name
+    run_dir.mkdir()
+
+    # Create subdirectories
+    (run_dir / "plots").mkdir()
+    (run_dir / "checkpoints").mkdir()
+
+    return run_dir
+
+
+def save_run_config(run_dir: Path, config_path: Path) -> None:
+    """Copy config file to run directory.
+
+    Args:
+        run_dir: Path to the run directory.
+        config_path: Path to the original config file.
+    """
+    dest = run_dir / "config.py"
+    shutil.copy2(config_path, dest)
+
+
+def get_run_paths(run_dir: Path) -> Dict[str, Path]:
+    """Get standard paths within a run directory.
+
+    Args:
+        run_dir: Path to the run directory.
+
+    Returns:
+        Dictionary with paths for model, log, plots, checkpoints, config.
+    """
+    return {
+        "model": run_dir / "model.pt",
+        "log": run_dir / "training.log",
+        "plots": run_dir / "plots",
+        "checkpoints": run_dir / "checkpoints",
+        "config": run_dir / "config.py",
+        "history": run_dir / "history.json",
+        "metrics": run_dir / "metrics.json",
+    }
+
+
+def find_latest_run(project_path: str | Path) -> Path | None:
+    """Find the most recent run in a project.
+
+    Args:
+        project_path: Path to the project.
+
+    Returns:
+        Path to the latest run directory, or None if no runs exist.
+    """
+    project_path = Path(project_path)
+    output_dir = project_path / "output"
+
+    if not output_dir.exists():
+        return None
+
+    runs = sorted(output_dir.glob("run_*"), reverse=True)
+    return runs[0] if runs else None
+
+
+def list_runs(project_path: str | Path) -> list[Path]:
+    """List all runs in a project.
+
+    Args:
+        project_path: Path to the project.
+
+    Returns:
+        List of run directory paths, sorted by date (newest first).
+    """
+    project_path = Path(project_path)
+    output_dir = project_path / "output"
+
+    if not output_dir.exists():
+        return []
+
+    return sorted(output_dir.glob("run_*"), reverse=True)
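For orientation, a minimal usage sketch of the project/run lifecycle these helpers define (the "my_experiment" path and "baseline" run name are illustrative, and create_project assumes the packaged templates directory is present):

from boltzmann9.project import (
    create_project, create_run, get_run_paths, save_run_config,
)

project = create_project("my_experiment")           # raises FileExistsError if the path exists
run_dir = create_run(project, run_name="baseline")  # output/run_<timestamp>_baseline
save_run_config(run_dir, project / "config.py")     # snapshot the config into the run
paths = get_run_paths(run_dir)                      # model.pt, training.log, plots/, ...
print(paths["model"])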
boltzmann9/run_utils.py
ADDED
@@ -0,0 +1,262 @@
+"""Utilities for run logging and visualization."""
+
+from __future__ import annotations
+
+import json
+import sys
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, TextIO
+
+
+class RunLogger:
+    """Logger that writes to both console and file."""
+
+    def __init__(self, log_path: Path):
+        """Initialize logger.
+
+        Args:
+            log_path: Path to log file.
+        """
+        self.log_path = log_path
+        self.log_file: TextIO | None = None
+        self._original_stdout = sys.stdout
+
+    def __enter__(self):
+        self.log_file = open(self.log_path, "w")
+        sys.stdout = self
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        sys.stdout = self._original_stdout
+        if self.log_file:
+            self.log_file.close()
+        return False
+
+    def write(self, message: str):
+        self._original_stdout.write(message)
+        if self.log_file:
+            self.log_file.write(message)
+            self.log_file.flush()
+
+    def flush(self):
+        self._original_stdout.flush()
+        if self.log_file:
+            self.log_file.flush()
+
+
+def save_history(history: Dict[str, list], path: Path) -> None:
+    """Save training history to JSON.
+
+    Args:
+        history: Training history dictionary.
+        path: Path to save JSON file.
+    """
+    # Convert any non-serializable types
+    serializable = {}
+    for key, values in history.items():
+        serializable[key] = [
+            float(v) if isinstance(v, (int, float)) else v
+            for v in values
+        ]
+
+    with open(path, "w") as f:
+        json.dump(serializable, f, indent=2)
+
+
+def load_history(path: Path) -> Dict[str, list]:
+    """Load training history from JSON.
+
+    Args:
+        path: Path to history JSON file.
+
+    Returns:
+        Training history dictionary.
+    """
+    with open(path) as f:
+        return json.load(f)
+
+
+def save_metrics(metrics: Dict[str, Any], path: Path) -> None:
+    """Save evaluation metrics to JSON.
+
+    Args:
+        metrics: Metrics dictionary.
+        path: Path to save JSON file.
+    """
+    # Convert nested structures
+    def convert(obj):
+        if isinstance(obj, dict):
+            return {k: convert(v) for k, v in obj.items()}
+        elif isinstance(obj, list):
+            return [convert(v) for v in obj]
+        elif isinstance(obj, (int, float)):
+            return float(obj)
+        else:
+            return obj
+
+    with open(path, "w") as f:
+        json.dump(convert(metrics), f, indent=2)
+
+
+def load_metrics(path: Path) -> Dict[str, Any]:
+    """Load evaluation metrics from JSON.
+
+    Args:
+        path: Path to metrics JSON file.
+
+    Returns:
+        Metrics dictionary.
+    """
+    with open(path) as f:
+        return json.load(f)
+
+
+def save_plots(history: Dict[str, list], plots_dir: Path, model: Any = None) -> None:
+    """Save training plots to directory.
+
+    Args:
+        history: Training history dictionary.
+        plots_dir: Directory to save plots.
+        model: Optional RBM model instance. If provided, saves block diagram.
+    """
+    try:
+        import matplotlib.pyplot as plt
+    except ImportError:
+        print("[save_plots] matplotlib not available, skipping plots")
+        return
+
+    import math
+
+    plots_dir.mkdir(parents=True, exist_ok=True)
+
+    # Save RBM block diagram if model is provided
+    if model is not None and hasattr(model, "draw_blocks"):
+        try:
+            model.draw_blocks(save_path=str(plots_dir / "RBM_blocks.png"), show=False)
+        except Exception as e:
+            print(f"[save_plots] Failed to save block diagram: {e}")
+
+    epochs = history.get(
+        "epoch", list(range(1, len(history.get("train_free_energy", [])) + 1))
+    )
+
+    def _is_all_nan(xs):
+        if not xs:
+            return True
+        return all(
+            (x is None) or (isinstance(x, float) and math.isnan(x)) for x in xs
+        )
+
+    # Plot 1: Free Energy
+    fig, ax = plt.subplots(figsize=(10, 6))
+    y_tr = history.get("train_free_energy", [])
+    y_va = history.get("val_free_energy", [])
+
+    ax.plot(epochs[:len(y_tr)], y_tr, label="train", linewidth=2)
+    if not _is_all_nan(y_va):
+        ax.plot(epochs[:len(y_va)], y_va, label="val", linewidth=2)
+
+    ax.set_title("Free Energy", fontsize=14)
+    ax.set_xlabel("Epoch", fontsize=12)
+    ax.set_ylabel("Mean FE (lower is better)", fontsize=12)
+    ax.legend()
+    ax.grid(True, alpha=0.3)
+    fig.tight_layout()
+    fig.savefig(plots_dir / "free_energy.png", dpi=150)
+    plt.close(fig)
+
+    # Plot 2: Reconstruction MSE
+    fig, ax = plt.subplots(figsize=(10, 6))
+    y_tr = history.get("train_recon_mse", [])
+    y_va = history.get("val_recon_mse", [])
+
+    ax.plot(epochs[:len(y_tr)], y_tr, label="train", linewidth=2)
+    if not _is_all_nan(y_va):
+        ax.plot(epochs[:len(y_va)], y_va, label="val", linewidth=2)
+
+    ax.set_title("Reconstruction MSE", fontsize=14)
+    ax.set_xlabel("Epoch", fontsize=12)
+    ax.set_ylabel("MSE", fontsize=12)
+    ax.legend()
+    ax.grid(True, alpha=0.3)
+    fig.tight_layout()
+    fig.savefig(plots_dir / "reconstruction_mse.png", dpi=150)
+    plt.close(fig)
+
+    # Plot 3: Reconstruction Bit Error
+    fig, ax = plt.subplots(figsize=(10, 6))
+    y_tr = history.get("train_recon_bit_error", [])
+    y_va = history.get("val_recon_bit_error", [])
+
+    ax.plot(epochs[:len(y_tr)], y_tr, label="train", linewidth=2)
+    if not _is_all_nan(y_va):
+        ax.plot(epochs[:len(y_va)], y_va, label="val", linewidth=2)
+
+    ax.set_title("Reconstruction Bit Error", fontsize=14)
+    ax.set_xlabel("Epoch", fontsize=12)
+    ax.set_ylabel("Fraction Mismatched", fontsize=12)
+    ax.legend()
+    ax.grid(True, alpha=0.3)
+    fig.tight_layout()
+    fig.savefig(plots_dir / "bit_error.png", dpi=150)
+    plt.close(fig)
+
+    # Plot 4: Learning Rate
+    if "lr" in history and history["lr"]:
+        fig, ax = plt.subplots(figsize=(10, 6))
+        lr_vals = history["lr"]
+        ax.plot(epochs[:len(lr_vals)], lr_vals, linewidth=2, color="green")
+        ax.set_title("Learning Rate Schedule", fontsize=14)
+        ax.set_xlabel("Epoch", fontsize=12)
+        ax.set_ylabel("Learning Rate", fontsize=12)
+        ax.set_yscale("log")
+        ax.grid(True, alpha=0.3)
+        fig.tight_layout()
+        fig.savefig(plots_dir / "learning_rate.png", dpi=150)
+        plt.close(fig)
+
+    # Combined plot
+    fig, axes = plt.subplots(1, 3, figsize=(16, 4))
+
+    # Free Energy
+    y_tr = history.get("train_free_energy", [])
+    y_va = history.get("val_free_energy", [])
+    axes[0].plot(epochs[:len(y_tr)], y_tr, label="train")
+    if not _is_all_nan(y_va):
+        axes[0].plot(epochs[:len(y_va)], y_va, label="val")
+    axes[0].set_title("Free Energy")
+    axes[0].set_xlabel("Epoch")
+    axes[0].set_ylabel("Mean FE")
+    axes[0].legend()
+    axes[0].grid(True, alpha=0.3)
+
+    # MSE
+    y_tr = history.get("train_recon_mse", [])
+    y_va = history.get("val_recon_mse", [])
+    axes[1].plot(epochs[:len(y_tr)], y_tr, label="train")
+    if not _is_all_nan(y_va):
+        axes[1].plot(epochs[:len(y_va)], y_va, label="val")
+    axes[1].set_title("Reconstruction MSE")
+    axes[1].set_xlabel("Epoch")
+    axes[1].set_ylabel("MSE")
+    axes[1].legend()
+    axes[1].grid(True, alpha=0.3)
+
+    # Bit Error
+    y_tr = history.get("train_recon_bit_error", [])
+    y_va = history.get("val_recon_bit_error", [])
+    axes[2].plot(epochs[:len(y_tr)], y_tr, label="train")
+    if not _is_all_nan(y_va):
+        axes[2].plot(epochs[:len(y_va)], y_va, label="val")
+    axes[2].set_title("Bit Error")
+    axes[2].set_xlabel("Epoch")
+    axes[2].set_ylabel("Fraction")
+    axes[2].legend()
+    axes[2].grid(True, alpha=0.3)
+
+    fig.tight_layout()
+    fig.savefig(plots_dir / "training_summary.png", dpi=150)
+    plt.close(fig)
+
+    print(f"Plots saved to: {plots_dir}")
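A brief sketch of how these utilities compose around a training loop (the history values are illustrative; the keys match what save_plots looks for):

from pathlib import Path
from boltzmann9.run_utils import RunLogger, save_history, save_plots

run_dir = Path("output/run_demo")
run_dir.mkdir(parents=True, exist_ok=True)

history = {
    "epoch": [1, 2, 3],
    "train_free_energy": [-10.2, -11.8, -12.5],
    "val_free_energy": [-9.9, -11.1, -11.9],
    "train_recon_mse": [0.21, 0.14, 0.11],
    "val_recon_mse": [0.23, 0.16, 0.13],
    "train_recon_bit_error": [0.30, 0.21, 0.17],
    "val_recon_bit_error": [0.32, 0.24, 0.19],
    "lr": [1e-2, 5e-3, 2.5e-3],
}

with RunLogger(run_dir / "training.log"):
    print("epoch 3 done")  # mirrored to console and training.log

save_history(history, run_dir / "history.json")
save_plots(history, run_dir / "plots")  # writes the per-metric PNGs and training_summary.png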
boltzmann9/tester.py
ADDED
@@ -0,0 +1,167 @@
+"""Testing utilities for RBM conditional probability evaluation."""
+
+from __future__ import annotations
+
+import math
+from collections import Counter
+from typing import Any, Dict, Sequence
+
+import torch
+from tqdm.auto import tqdm
+
+
+class RBMTester:
+    """Test RBM conditional probability estimation.
+
+    Evaluates how well the RBM can predict target visible units
+    given clamped (observed) visible units.
+
+    Args:
+        model: Trained RBM model.
+        test_dataloader: DataLoader for test data.
+        clamp_idx: Indices of visible units to clamp (condition on).
+        target_idx: Indices of visible units to predict.
+        device: Target device for computations.
+    """
+
+    def __init__(
+        self,
+        model,
+        test_dataloader,
+        clamp_idx: Sequence[int],
+        target_idx: Sequence[int],
+        *,
+        device=None,
+    ):
+        self.model = model
+        self.test_dataloader = test_dataloader
+        self.clamp_idx = list(clamp_idx)
+        self.target_idx = list(target_idx)
+        self.device = device or model.W.device
+
+    @staticmethod
+    def _bits_to_int(bits: torch.Tensor) -> int:
+        """Convert binary bits to integer (LSB-first)."""
+        weights = 2 ** torch.arange(bits.numel(), device=bits.device)
+        return int((bits * weights).sum().item())
+
+    @torch.no_grad()
+    def conditional_nll(
+        self,
+        *,
+        n_samples: int = 500,
+        burn_in: int = 300,
+        thin: int = 10,
+        laplace_alpha: float = 1.0,
+        log_every: int = 50,
+    ) -> Dict[str, Any]:
+        """Compute conditional negative log-likelihood over test set.
+
+        For each test sample, clamps the specified visible units and
+        samples from the conditional distribution p(target | clamp).
+
+        Computes per-bit NLL using empirical bit frequencies, which scales
+        to large target dimensions.
+
+        Args:
+            n_samples: Number of MCMC samples per test point.
+            burn_in: Burn-in steps for MCMC.
+            thin: Thinning interval between samples.
+            laplace_alpha: Laplace smoothing parameter for per-bit probabilities.
+            log_every: Log progress every N samples.
+
+        Returns:
+            Dictionary with:
+                - mean_nll_nats: Mean NLL in nats (sum over target bits).
+                - mean_nll_bits: Mean NLL in bits.
+                - mean_nll_per_bit: Mean NLL per target bit.
+                - nll_nats_per_sample: List of per-sample NLL in nats.
+                - nll_bits_per_sample: List of per-sample NLL in bits.
+        """
+        self.model.eval()
+
+        ln2 = math.log(2.0)
+
+        nlls_nats = []
+        nlls_bits = []
+
+        total_points = len(self.test_dataloader.dataset)
+        n_target_bits = len(self.target_idx)
+
+        outer_pbar = tqdm(
+            total=total_points,
+            desc="RBM conditional NLL",
+            leave=True,
+        )
+
+        for batch_idx, v in enumerate(self.test_dataloader):
+            v = v.to(self.device)
+
+            for i in range(v.size(0)):
+                # Clamp visible units
+                v_clamp = torch.zeros(
+                    self.model.nv,
+                    device=self.device,
+                    dtype=self.model.W.dtype,
+                )
+                v_clamp[self.clamp_idx] = v[i, self.clamp_idx]
+
+                # True target bits
+                true_bits = v[i, self.target_idx]
+
+                # Sample conditional
+                samples = self.model.sample_clamped(
+                    v_clamp=v_clamp,
+                    clamp_idx=self.clamp_idx,
+                    n_samples=n_samples,
+                    burn_in=burn_in,
+                    thin=thin,
+                )
+
+                # Get target bits from samples: (n_samples, n_target_bits)
+                sampled_bits = samples[:, self.target_idx]
+
+                # Compute per-bit empirical probability with Laplace smoothing
+                # For each target bit, estimate P(bit=true_value)
+                bit_matches = (sampled_bits == true_bits.unsqueeze(0)).float()
+                bit_probs = (bit_matches.sum(dim=0) + laplace_alpha) / (n_samples + 2 * laplace_alpha)
+
+                # NLL = -sum(log(p_i)) for each bit
+                # Clamp to avoid log(0)
+                bit_probs = bit_probs.clamp(min=1e-10, max=1 - 1e-10)
+                nll_nats = -torch.log(bit_probs).sum().item()
+                nll_bits = nll_nats / ln2
+
+                nlls_nats.append(nll_nats)
+                nlls_bits.append(nll_bits)
+
+                # Logging
+                if len(nlls_nats) % log_every == 0:
+                    mean_nats = sum(nlls_nats) / len(nlls_nats)
+                    mean_bits = sum(nlls_bits) / len(nlls_bits)
+                    outer_pbar.set_postfix(
+                        nll_nats=f"{mean_nats:.3f}",
+                        nll_bits=f"{mean_bits:.3f}",
+                        per_bit=f"{mean_nats / n_target_bits:.4f}",
+                    )
+
+                outer_pbar.update(1)
+
+        outer_pbar.close()
+
+        mean_nats = (
+            float(sum(nlls_nats) / len(nlls_nats)) if nlls_nats else float("nan")
+        )
+        mean_bits = (
+            float(sum(nlls_bits) / len(nlls_bits)) if nlls_bits else float("nan")
+        )
+        mean_per_bit = mean_nats / n_target_bits if n_target_bits > 0 else float("nan")
+
+        return {
+            "mean_nll_nats": mean_nats,
+            "mean_nll_bits": mean_bits,
+            "mean_nll_per_bit": mean_per_bit,
+            "n_target_bits": n_target_bits,
+            "nll_nats_per_sample": nlls_nats,
+            "nll_bits_per_sample": nlls_bits,
+        }
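The scoring inside conditional_nll is a Laplace-smoothed per-bit estimate, p_j = (matches_j + alpha) / (N + 2*alpha), summed as -sum(log p_j). A standalone sketch of just that estimator, with random stand-in draws in place of sample_clamped output:

import math
import torch

def per_bit_nll(sampled_bits: torch.Tensor, true_bits: torch.Tensor, alpha: float = 1.0) -> float:
    """NLL in nats from (n_samples, n_bits) binary draws scored against true_bits."""
    n = sampled_bits.size(0)
    matches = (sampled_bits == true_bits.unsqueeze(0)).float()
    probs = (matches.sum(dim=0) + alpha) / (n + 2 * alpha)
    probs = probs.clamp(min=1e-10, max=1 - 1e-10)
    return float(-torch.log(probs).sum())

true_bits = torch.tensor([1.0, 0.0, 1.0])
draws = torch.randint(0, 2, (500, 3)).float()  # stand-in for sample_clamped output
print(per_bit_nll(draws, true_bits))           # close to 3 * ln 2 for uniform draws
print(3 * math.log(2.0))                       # reference value, about 2.08 nats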
boltzmann9/utils.py
ADDED
@@ -0,0 +1,42 @@
+"""Utility functions for device resolution and memory management."""
+
+from __future__ import annotations
+
+import torch
+
+
+def resolve_device(device_cfg: str | None) -> torch.device:
+    """Resolve device configuration to a torch.device.
+
+    Args:
+        device_cfg: Device configuration string. Can be:
+            - None or "auto": Auto-detect best available device
+            - "cpu": Use CPU
+            - "cuda:0", "cuda:1", etc.: Use specific CUDA device
+            - "mps": Use Apple Metal Performance Shaders
+
+    Returns:
+        torch.device instance for the resolved device.
+    """
+    if device_cfg is None or device_cfg == "auto":
+        if torch.cuda.is_available():
+            return torch.device("cuda:0")
+        if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
+            return torch.device("mps")
+        return torch.device("cpu")
+    return torch.device(device_cfg)
+
+
+def resolve_pin_memory(pin_cfg, device: torch.device) -> bool:
+    """Resolve pin_memory configuration based on device type.
+
+    Args:
+        pin_cfg: Pin memory configuration. Can be "auto", True, or False.
+        device: The target torch device.
+
+    Returns:
+        Boolean indicating whether to pin memory.
+    """
+    if pin_cfg == "auto":
+        return device.type == "cuda"
+    return bool(pin_cfg)
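A sketch of how these resolvers would typically be wired into a DataLoader (the toy dataset is illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset
from boltzmann9.utils import resolve_device, resolve_pin_memory

device = resolve_device("auto")            # prefers cuda:0, then mps, then cpu
pin = resolve_pin_memory("auto", device)   # pin host memory only on CUDA

dataset = TensorDataset(torch.randint(0, 2, (128, 16)).float())
loader = DataLoader(dataset, batch_size=32, pin_memory=pin)
for (batch,) in loader:
    batch = batch.to(device)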