boltzmann9 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {boltzmann9-0.1.1.dist-info → boltzmann9-0.1.2.dist-info}/METADATA +1 -1
- boltzmann9-0.1.2.dist-info/RECORD +5 -0
- boltzmann9-0.1.2.dist-info/top_level.txt +1 -0
- boltzmann/__init__.py +0 -38
- boltzmann/cli.py +0 -389
- boltzmann/config.py +0 -58
- boltzmann/data.py +0 -145
- boltzmann/data_generator.py +0 -234
- boltzmann/model.py +0 -867
- boltzmann/pipeline.py +0 -216
- boltzmann/preprocessor.py +0 -627
- boltzmann/project.py +0 -195
- boltzmann/run_utils.py +0 -262
- boltzmann/tester.py +0 -167
- boltzmann/utils.py +0 -42
- boltzmann/visualization.py +0 -115
- boltzmann9-0.1.1.dist-info/RECORD +0 -18
- boltzmann9-0.1.1.dist-info/top_level.txt +0 -1
- {boltzmann9-0.1.1.dist-info → boltzmann9-0.1.2.dist-info}/WHEEL +0 -0
- {boltzmann9-0.1.1.dist-info → boltzmann9-0.1.2.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
boltzmann9-0.1.2.dist-info/METADATA,sha256=3ocsO1ccm-6SxfSo5CsBjzQRwLN6LtBXLIr3bS8kgmA,3097
|
|
2
|
+
boltzmann9-0.1.2.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
3
|
+
boltzmann9-0.1.2.dist-info/entry_points.txt,sha256=AZgq8QRNPYvXn6WQspFyXz4U2pcYcwg3RntGcdRXTHY,50
|
|
4
|
+
boltzmann9-0.1.2.dist-info/top_level.txt,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
5
|
+
boltzmann9-0.1.2.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
boltzmann/__init__.py
DELETED
|
@@ -1,38 +0,0 @@
|
|
|
1
|
-
"""Boltzmann Machine library for PyTorch."""
|
|
2
|
-
|
|
3
|
-
from .model import RBM
|
|
4
|
-
from .data import BMDataset, GBMDataloader, split_rbm_loaders
|
|
5
|
-
from .utils import resolve_device, resolve_pin_memory
|
|
6
|
-
from .tester import RBMTester
|
|
7
|
-
from .config import load_config
|
|
8
|
-
from .data_generator import SyntheticDataGenerator, GeneratorConfig
|
|
9
|
-
from .pipeline import train_rbm, evaluate_rbm, save_model, load_model
|
|
10
|
-
from .project import create_project, create_run, list_runs, find_latest_run
|
|
11
|
-
|
|
12
|
-
__all__ = [
|
|
13
|
-
# Model
|
|
14
|
-
"RBM",
|
|
15
|
-
# Data
|
|
16
|
-
"BMDataset",
|
|
17
|
-
"GBMDataloader",
|
|
18
|
-
"split_rbm_loaders",
|
|
19
|
-
# Data generation
|
|
20
|
-
"SyntheticDataGenerator",
|
|
21
|
-
"GeneratorConfig",
|
|
22
|
-
# Utils
|
|
23
|
-
"resolve_device",
|
|
24
|
-
"resolve_pin_memory",
|
|
25
|
-
# Testing
|
|
26
|
-
"RBMTester",
|
|
27
|
-
# Config & Pipeline
|
|
28
|
-
"load_config",
|
|
29
|
-
"train_rbm",
|
|
30
|
-
"evaluate_rbm",
|
|
31
|
-
"save_model",
|
|
32
|
-
"load_model",
|
|
33
|
-
# Project management
|
|
34
|
-
"create_project",
|
|
35
|
-
"create_run",
|
|
36
|
-
"list_runs",
|
|
37
|
-
"find_latest_run",
|
|
38
|
-
]
|
boltzmann/cli.py
DELETED
|
@@ -1,389 +0,0 @@
|
|
|
1
|
-
#!/usr/bin/env python
|
|
2
|
-
"""BoltzmaNN9 command-line interface."""
|
|
3
|
-
|
|
4
|
-
import argparse
|
|
5
|
-
import sys
|
|
6
|
-
from pathlib import Path
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
def cmd_new_project(args):
|
|
10
|
-
"""Create a new project."""
|
|
11
|
-
from boltzmann.project import create_project
|
|
12
|
-
|
|
13
|
-
try:
|
|
14
|
-
create_project(args.project_name)
|
|
15
|
-
except FileExistsError as e:
|
|
16
|
-
print(f"Error: {e}")
|
|
17
|
-
return 1
|
|
18
|
-
return 0
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def cmd_train(args):
|
|
22
|
-
"""Train a model."""
|
|
23
|
-
from boltzmann.config import load_config
|
|
24
|
-
from boltzmann.pipeline import train_rbm, save_model
|
|
25
|
-
from boltzmann.project import create_run, save_run_config, get_run_paths
|
|
26
|
-
from boltzmann.run_utils import RunLogger, save_history, save_plots
|
|
27
|
-
|
|
28
|
-
project_path = Path(args.project)
|
|
29
|
-
|
|
30
|
-
if not project_path.exists():
|
|
31
|
-
print(f"Error: Project not found: {project_path}")
|
|
32
|
-
return 1
|
|
33
|
-
|
|
34
|
-
config_path = project_path / "config.py"
|
|
35
|
-
if not config_path.exists():
|
|
36
|
-
print(f"Error: Config file not found: {config_path}")
|
|
37
|
-
return 1
|
|
38
|
-
|
|
39
|
-
# Create run directory
|
|
40
|
-
run_dir = create_run(project_path, args.name)
|
|
41
|
-
paths = get_run_paths(run_dir)
|
|
42
|
-
|
|
43
|
-
# Copy config to run
|
|
44
|
-
save_run_config(run_dir, config_path)
|
|
45
|
-
|
|
46
|
-
print(f"Starting training run: {run_dir.name}")
|
|
47
|
-
print("=" * 60)
|
|
48
|
-
|
|
49
|
-
# Train with logging
|
|
50
|
-
with RunLogger(paths["log"]):
|
|
51
|
-
# Load config and adjust data path to be relative to project
|
|
52
|
-
cfg = load_config(config_path)
|
|
53
|
-
|
|
54
|
-
# Make data path relative to project
|
|
55
|
-
data_path = cfg.get("data", {}).get("csv_path", "")
|
|
56
|
-
if data_path and not Path(data_path).is_absolute():
|
|
57
|
-
cfg["data"]["csv_path"] = str(project_path / data_path)
|
|
58
|
-
|
|
59
|
-
# Import here to avoid circular imports
|
|
60
|
-
from boltzmann.data import BMDataset, split_rbm_loaders
|
|
61
|
-
from boltzmann.model import RBM
|
|
62
|
-
from boltzmann.utils import resolve_device, resolve_pin_memory
|
|
63
|
-
|
|
64
|
-
# Setup
|
|
65
|
-
device = resolve_device(cfg.get("device", "auto"))
|
|
66
|
-
print(f"Using device: {device}")
|
|
67
|
-
|
|
68
|
-
data_cfg = cfg.get("data", {})
|
|
69
|
-
drop_cols = data_cfg.get("drop_cols", [])
|
|
70
|
-
dataset = BMDataset(cfg["data"]["csv_path"], drop_cols=drop_cols)
|
|
71
|
-
|
|
72
|
-
print(f"Loaded dataset: {len(dataset)} samples")
|
|
73
|
-
print(f" Columns: {dataset.columns}")
|
|
74
|
-
|
|
75
|
-
dl = cfg.get("dataloader", {})
|
|
76
|
-
pin_memory = resolve_pin_memory(dl.get("pin_memory", "auto"), device)
|
|
77
|
-
|
|
78
|
-
loaders = split_rbm_loaders(
|
|
79
|
-
dataset,
|
|
80
|
-
batch_size=dl.get("batch_size", 256),
|
|
81
|
-
split=tuple(dl.get("split", (0.8, 0.1, 0.1))),
|
|
82
|
-
seed=dl.get("seed", 42),
|
|
83
|
-
shuffle_train=dl.get("shuffle_train", True),
|
|
84
|
-
num_workers=dl.get("num_workers", 0),
|
|
85
|
-
pin_memory=pin_memory,
|
|
86
|
-
drop_last_train=dl.get("drop_last_train", True),
|
|
87
|
-
)
|
|
88
|
-
|
|
89
|
-
# Model
|
|
90
|
-
model_cfg = dict(cfg.get("model", {}))
|
|
91
|
-
model = RBM(model_cfg).to(device)
|
|
92
|
-
|
|
93
|
-
# Train
|
|
94
|
-
train_cfg = dict(cfg.get("train", {}))
|
|
95
|
-
history = model.fit(
|
|
96
|
-
loaders["train"],
|
|
97
|
-
val_loader=loaders["val"],
|
|
98
|
-
**train_cfg,
|
|
99
|
-
)
|
|
100
|
-
|
|
101
|
-
# Save outputs
|
|
102
|
-
import torch
|
|
103
|
-
checkpoint = {
|
|
104
|
-
"model_state_dict": model.state_dict(),
|
|
105
|
-
"nv": model.nv,
|
|
106
|
-
"nh": model.nh,
|
|
107
|
-
"config": cfg,
|
|
108
|
-
}
|
|
109
|
-
torch.save(checkpoint, paths["model"])
|
|
110
|
-
print(f"Model saved to: {paths['model']}")
|
|
111
|
-
|
|
112
|
-
save_history(history, paths["history"])
|
|
113
|
-
save_plots(history, paths["plots"], model=model)
|
|
114
|
-
|
|
115
|
-
print("\n" + "=" * 60)
|
|
116
|
-
print("TRAINING COMPLETE")
|
|
117
|
-
print("=" * 60)
|
|
118
|
-
print(f"Run directory: {run_dir}")
|
|
119
|
-
print(f" - model.pt")
|
|
120
|
-
print(f" - config.py")
|
|
121
|
-
print(f" - training.log")
|
|
122
|
-
print(f" - history.json")
|
|
123
|
-
print(f" - plots/")
|
|
124
|
-
|
|
125
|
-
return 0
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
def cmd_evaluate(args):
|
|
129
|
-
"""Evaluate a trained model."""
|
|
130
|
-
from boltzmann.config import load_config
|
|
131
|
-
from boltzmann.pipeline import load_model
|
|
132
|
-
from boltzmann.project import get_run_paths, find_latest_run
|
|
133
|
-
from boltzmann.run_utils import save_metrics
|
|
134
|
-
|
|
135
|
-
# Determine run directory
|
|
136
|
-
run_path = Path(args.run)
|
|
137
|
-
|
|
138
|
-
if not run_path.exists():
|
|
139
|
-
# Maybe it's a project path - find latest run
|
|
140
|
-
project_path = run_path
|
|
141
|
-
run_path = find_latest_run(project_path)
|
|
142
|
-
if run_path is None:
|
|
143
|
-
print(f"Error: No runs found in project: {project_path}")
|
|
144
|
-
return 1
|
|
145
|
-
print(f"Using latest run: {run_path.name}")
|
|
146
|
-
|
|
147
|
-
paths = get_run_paths(run_path)
|
|
148
|
-
|
|
149
|
-
# Load model and config from run
|
|
150
|
-
if not paths["model"].exists():
|
|
151
|
-
print(f"Error: Model not found: {paths['model']}")
|
|
152
|
-
return 1
|
|
153
|
-
|
|
154
|
-
if not paths["config"].exists():
|
|
155
|
-
print(f"Error: Config not found: {paths['config']}")
|
|
156
|
-
return 1
|
|
157
|
-
|
|
158
|
-
model, saved_cfg = load_model(paths["model"])
|
|
159
|
-
cfg = load_config(paths["config"])
|
|
160
|
-
|
|
161
|
-
# Find project path from run path (output/run_xxx -> project)
|
|
162
|
-
project_path = run_path.parent.parent
|
|
163
|
-
|
|
164
|
-
# Adjust data path
|
|
165
|
-
data_path = cfg.get("data", {}).get("csv_path", "")
|
|
166
|
-
if data_path and not Path(data_path).is_absolute():
|
|
167
|
-
cfg["data"]["csv_path"] = str(project_path / data_path)
|
|
168
|
-
|
|
169
|
-
print(f"Evaluating run: {run_path.name}")
|
|
170
|
-
print("=" * 60)
|
|
171
|
-
|
|
172
|
-
# Import and setup
|
|
173
|
-
from boltzmann.data import BMDataset, split_rbm_loaders
|
|
174
|
-
from boltzmann.tester import RBMTester
|
|
175
|
-
from boltzmann.utils import resolve_device, resolve_pin_memory
|
|
176
|
-
|
|
177
|
-
device = resolve_device(cfg.get("device", "auto"))
|
|
178
|
-
model = model.to(device)
|
|
179
|
-
print(f"Using device: {device}")
|
|
180
|
-
|
|
181
|
-
data_cfg = cfg.get("data", {})
|
|
182
|
-
drop_cols = data_cfg.get("drop_cols", [])
|
|
183
|
-
dataset = BMDataset(cfg["data"]["csv_path"], drop_cols=drop_cols)
|
|
184
|
-
|
|
185
|
-
dl = cfg.get("dataloader", {})
|
|
186
|
-
pin_memory = resolve_pin_memory(dl.get("pin_memory", "auto"), device)
|
|
187
|
-
|
|
188
|
-
loaders = split_rbm_loaders(
|
|
189
|
-
dataset,
|
|
190
|
-
batch_size=dl.get("batch_size", 256),
|
|
191
|
-
split=tuple(dl.get("split", (0.8, 0.1, 0.1))),
|
|
192
|
-
seed=dl.get("seed", 42),
|
|
193
|
-
shuffle_train=False,
|
|
194
|
-
num_workers=dl.get("num_workers", 0),
|
|
195
|
-
pin_memory=pin_memory,
|
|
196
|
-
drop_last_train=False,
|
|
197
|
-
)
|
|
198
|
-
|
|
199
|
-
print(f"Evaluating on {len(loaders['test'].dataset)} test samples")
|
|
200
|
-
|
|
201
|
-
# Basic metrics
|
|
202
|
-
eval_cfg = cfg.get("eval", {})
|
|
203
|
-
test_metrics = model.evaluate(loaders["test"], recon_k=eval_cfg.get("recon_k", 1))
|
|
204
|
-
print("Test metrics:", test_metrics)
|
|
205
|
-
|
|
206
|
-
# Conditional NLL
|
|
207
|
-
cond_cfg = cfg.get("conditional", {})
|
|
208
|
-
tester = RBMTester(
|
|
209
|
-
model=model,
|
|
210
|
-
test_dataloader=loaders["test"],
|
|
211
|
-
clamp_idx=cond_cfg["clamp_idx"],
|
|
212
|
-
target_idx=cond_cfg["target_idx"],
|
|
213
|
-
)
|
|
214
|
-
conditional_results = tester.conditional_nll(
|
|
215
|
-
n_samples=cond_cfg.get("n_samples", 100),
|
|
216
|
-
burn_in=cond_cfg.get("burn_in", 500),
|
|
217
|
-
thin=cond_cfg.get("thin", 10),
|
|
218
|
-
)
|
|
219
|
-
|
|
220
|
-
# Save metrics
|
|
221
|
-
all_metrics = {
|
|
222
|
-
"test_metrics": test_metrics,
|
|
223
|
-
"conditional_nll_nats": conditional_results["mean_nll_nats"],
|
|
224
|
-
"conditional_nll_bits": conditional_results["mean_nll_bits"],
|
|
225
|
-
"conditional_nll_per_bit": conditional_results["mean_nll_per_bit"],
|
|
226
|
-
"n_target_bits": conditional_results["n_target_bits"],
|
|
227
|
-
}
|
|
228
|
-
save_metrics(all_metrics, paths["metrics"])
|
|
229
|
-
|
|
230
|
-
print("\n" + "=" * 60)
|
|
231
|
-
print("EVALUATION RESULTS")
|
|
232
|
-
print("=" * 60)
|
|
233
|
-
print(f"Test metrics: {test_metrics}")
|
|
234
|
-
print(f"Conditional NLL (nats): {conditional_results['mean_nll_nats']:.4f}")
|
|
235
|
-
print(f"Conditional NLL (bits): {conditional_results['mean_nll_bits']:.4f}")
|
|
236
|
-
print(f"Conditional NLL per bit: {conditional_results['mean_nll_per_bit']:.4f}")
|
|
237
|
-
print(f"Target bits: {conditional_results['n_target_bits']}")
|
|
238
|
-
print(f"\nMetrics saved to: {paths['metrics']}")
|
|
239
|
-
|
|
240
|
-
return 0
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
def cmd_preprocess_raw(args):
|
|
244
|
-
"""Preprocess raw data based on config."""
|
|
245
|
-
from boltzmann.config import load_config
|
|
246
|
-
from boltzmann.preprocessor import DataPreprocessor
|
|
247
|
-
|
|
248
|
-
config_path = Path(args.config)
|
|
249
|
-
|
|
250
|
-
if not config_path.exists():
|
|
251
|
-
print(f"Error: Config file not found: {config_path}")
|
|
252
|
-
return 1
|
|
253
|
-
|
|
254
|
-
cfg = load_config(config_path)
|
|
255
|
-
|
|
256
|
-
print(f"Running preprocessor with config: {config_path}")
|
|
257
|
-
print("=" * 60)
|
|
258
|
-
|
|
259
|
-
preprocessor = DataPreprocessor(cfg, config_dir=config_path.parent)
|
|
260
|
-
df = preprocessor.fit_transform()
|
|
261
|
-
|
|
262
|
-
print(f"\nPreprocessing complete!")
|
|
263
|
-
print(f" Output CSV: {preprocessor.output_csv_path}")
|
|
264
|
-
print(f" Samples: {len(df)}")
|
|
265
|
-
print(f" Columns: {len(df.columns)}")
|
|
266
|
-
print(f" Visible blocks: {preprocessor.get_visible_blocks_sizes()}")
|
|
267
|
-
|
|
268
|
-
return 0
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
def cmd_list_runs(args):
|
|
272
|
-
"""List all runs in a project."""
|
|
273
|
-
from boltzmann.project import list_runs
|
|
274
|
-
|
|
275
|
-
project_path = Path(args.project)
|
|
276
|
-
|
|
277
|
-
if not project_path.exists():
|
|
278
|
-
print(f"Error: Project not found: {project_path}")
|
|
279
|
-
return 1
|
|
280
|
-
|
|
281
|
-
runs = list_runs(project_path)
|
|
282
|
-
|
|
283
|
-
if not runs:
|
|
284
|
-
print(f"No runs found in: {project_path}")
|
|
285
|
-
return 0
|
|
286
|
-
|
|
287
|
-
print(f"Runs in {project_path}:")
|
|
288
|
-
print("-" * 40)
|
|
289
|
-
for run in runs:
|
|
290
|
-
print(f" {run.name}")
|
|
291
|
-
|
|
292
|
-
return 0
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
def main():
|
|
296
|
-
parser = argparse.ArgumentParser(
|
|
297
|
-
description="BoltzmaNN9 - Restricted Boltzmann Machine toolkit",
|
|
298
|
-
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
299
|
-
)
|
|
300
|
-
|
|
301
|
-
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
|
302
|
-
|
|
303
|
-
# new_project command
|
|
304
|
-
new_proj_parser = subparsers.add_parser(
|
|
305
|
-
"new_project",
|
|
306
|
-
help="Create a new project",
|
|
307
|
-
)
|
|
308
|
-
new_proj_parser.add_argument(
|
|
309
|
-
"project_name",
|
|
310
|
-
type=str,
|
|
311
|
-
help="Name/path for the new project",
|
|
312
|
-
)
|
|
313
|
-
|
|
314
|
-
# train command
|
|
315
|
-
train_parser = subparsers.add_parser(
|
|
316
|
-
"train",
|
|
317
|
-
help="Train a model",
|
|
318
|
-
)
|
|
319
|
-
train_parser.add_argument(
|
|
320
|
-
"--project", "-p",
|
|
321
|
-
type=str,
|
|
322
|
-
required=True,
|
|
323
|
-
help="Path to project directory",
|
|
324
|
-
)
|
|
325
|
-
train_parser.add_argument(
|
|
326
|
-
"--name", "-n",
|
|
327
|
-
type=str,
|
|
328
|
-
default=None,
|
|
329
|
-
help="Optional name suffix for the run",
|
|
330
|
-
)
|
|
331
|
-
|
|
332
|
-
# evaluate command
|
|
333
|
-
eval_parser = subparsers.add_parser(
|
|
334
|
-
"evaluate",
|
|
335
|
-
help="Evaluate a trained model",
|
|
336
|
-
)
|
|
337
|
-
eval_parser.add_argument(
|
|
338
|
-
"--run", "-r",
|
|
339
|
-
type=str,
|
|
340
|
-
required=True,
|
|
341
|
-
help="Path to run directory or project (uses latest run)",
|
|
342
|
-
)
|
|
343
|
-
|
|
344
|
-
# list command
|
|
345
|
-
list_parser = subparsers.add_parser(
|
|
346
|
-
"list",
|
|
347
|
-
help="List runs in a project",
|
|
348
|
-
)
|
|
349
|
-
list_parser.add_argument(
|
|
350
|
-
"--project", "-p",
|
|
351
|
-
type=str,
|
|
352
|
-
required=True,
|
|
353
|
-
help="Path to project directory",
|
|
354
|
-
)
|
|
355
|
-
|
|
356
|
-
# preprocess_raw command
|
|
357
|
-
preprocess_parser = subparsers.add_parser(
|
|
358
|
-
"preprocess_raw",
|
|
359
|
-
help="Preprocess raw data using DataPreprocessor",
|
|
360
|
-
)
|
|
361
|
-
preprocess_parser.add_argument(
|
|
362
|
-
"--config", "-c",
|
|
363
|
-
type=str,
|
|
364
|
-
required=True,
|
|
365
|
-
help="Path to config file",
|
|
366
|
-
)
|
|
367
|
-
|
|
368
|
-
args = parser.parse_args()
|
|
369
|
-
|
|
370
|
-
if args.command is None:
|
|
371
|
-
parser.print_help()
|
|
372
|
-
return 0
|
|
373
|
-
|
|
374
|
-
if args.command == "new_project":
|
|
375
|
-
return cmd_new_project(args)
|
|
376
|
-
elif args.command == "train":
|
|
377
|
-
return cmd_train(args)
|
|
378
|
-
elif args.command == "evaluate":
|
|
379
|
-
return cmd_evaluate(args)
|
|
380
|
-
elif args.command == "list":
|
|
381
|
-
return cmd_list_runs(args)
|
|
382
|
-
elif args.command == "preprocess_raw":
|
|
383
|
-
return cmd_preprocess_raw(args)
|
|
384
|
-
|
|
385
|
-
return 0
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
if __name__ == "__main__":
|
|
389
|
-
sys.exit(main())
|
boltzmann/config.py
DELETED
|
@@ -1,58 +0,0 @@
|
|
|
1
|
-
"""Configuration loading utilities."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import importlib.util
|
|
6
|
-
from pathlib import Path
|
|
7
|
-
from typing import Any, Dict
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
def load_config(config_path: str | Path) -> Dict[str, Any]:
|
|
11
|
-
"""Load configuration from a Python file.
|
|
12
|
-
|
|
13
|
-
The config file should contain a dictionary named `config`.
|
|
14
|
-
|
|
15
|
-
Args:
|
|
16
|
-
config_path: Path to the configuration .py file.
|
|
17
|
-
|
|
18
|
-
Returns:
|
|
19
|
-
Configuration dictionary.
|
|
20
|
-
|
|
21
|
-
Raises:
|
|
22
|
-
FileNotFoundError: If config file doesn't exist.
|
|
23
|
-
ValueError: If config file doesn't contain a 'config' dictionary.
|
|
24
|
-
|
|
25
|
-
Example config.py:
|
|
26
|
-
config = {
|
|
27
|
-
"device": "auto",
|
|
28
|
-
"data": {"csv_path": "data.csv", ...},
|
|
29
|
-
...
|
|
30
|
-
}
|
|
31
|
-
"""
|
|
32
|
-
config_path = Path(config_path)
|
|
33
|
-
|
|
34
|
-
if not config_path.exists():
|
|
35
|
-
raise FileNotFoundError(f"Config file not found: {config_path}")
|
|
36
|
-
|
|
37
|
-
# Load the Python file as a module
|
|
38
|
-
spec = importlib.util.spec_from_file_location("config_module", config_path)
|
|
39
|
-
if spec is None or spec.loader is None:
|
|
40
|
-
raise ValueError(f"Could not load config from: {config_path}")
|
|
41
|
-
|
|
42
|
-
module = importlib.util.module_from_spec(spec)
|
|
43
|
-
spec.loader.exec_module(module)
|
|
44
|
-
|
|
45
|
-
# Extract the config dictionary
|
|
46
|
-
if not hasattr(module, "config"):
|
|
47
|
-
raise ValueError(
|
|
48
|
-
f"Config file must contain a 'config' dictionary: {config_path}"
|
|
49
|
-
)
|
|
50
|
-
|
|
51
|
-
config = getattr(module, "config")
|
|
52
|
-
|
|
53
|
-
if not isinstance(config, dict):
|
|
54
|
-
raise ValueError(
|
|
55
|
-
f"'config' must be a dictionary, got {type(config).__name__}"
|
|
56
|
-
)
|
|
57
|
-
|
|
58
|
-
return config
|
boltzmann/data.py
DELETED
|
@@ -1,145 +0,0 @@
|
|
|
1
|
-
"""Dataset and DataLoader utilities for Boltzmann Machines."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
from pathlib import Path
|
|
6
|
-
from typing import Dict, Sequence
|
|
7
|
-
|
|
8
|
-
import pandas as pd
|
|
9
|
-
import torch
|
|
10
|
-
from torch.utils.data import Dataset, DataLoader, random_split
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class BMDataset(Dataset):
|
|
14
|
-
"""PyTorch Dataset for Boltzmann Machine training data.
|
|
15
|
-
|
|
16
|
-
Loads data from a CSV file.
|
|
17
|
-
|
|
18
|
-
Args:
|
|
19
|
-
csv_path: Path to CSV file containing the training data.
|
|
20
|
-
Each row is a sample, columns are features.
|
|
21
|
-
drop_cols: Optional list of column names to drop from the data.
|
|
22
|
-
"""
|
|
23
|
-
|
|
24
|
-
def __init__(
|
|
25
|
-
self,
|
|
26
|
-
csv_path: str | Path,
|
|
27
|
-
drop_cols: Sequence[str] | None = None,
|
|
28
|
-
):
|
|
29
|
-
self.csv_path = Path(csv_path)
|
|
30
|
-
self.df = pd.read_csv(self.csv_path)
|
|
31
|
-
|
|
32
|
-
if drop_cols:
|
|
33
|
-
self.df = self.df.drop(columns=list(drop_cols))
|
|
34
|
-
|
|
35
|
-
self.columns = list(self.df.columns)
|
|
36
|
-
|
|
37
|
-
def __len__(self):
|
|
38
|
-
return len(self.df)
|
|
39
|
-
|
|
40
|
-
def __getitem__(self, idx):
|
|
41
|
-
row = self.df.iloc[idx]
|
|
42
|
-
x = torch.tensor(row.values, dtype=torch.float32)
|
|
43
|
-
return x
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
class GBMDataloader:
|
|
47
|
-
"""Simple wrapper around PyTorch DataLoader for Boltzmann Machine data.
|
|
48
|
-
|
|
49
|
-
Args:
|
|
50
|
-
dataset: BMDataset instance.
|
|
51
|
-
batch_size: Number of samples per batch.
|
|
52
|
-
shuffle: Whether to shuffle data each epoch.
|
|
53
|
-
num_workers: Number of worker processes for data loading.
|
|
54
|
-
"""
|
|
55
|
-
|
|
56
|
-
def __init__(
|
|
57
|
-
self,
|
|
58
|
-
dataset: BMDataset,
|
|
59
|
-
batch_size: int = 1,
|
|
60
|
-
shuffle: bool = True,
|
|
61
|
-
num_workers: int = 0,
|
|
62
|
-
):
|
|
63
|
-
self.loader = DataLoader(
|
|
64
|
-
dataset,
|
|
65
|
-
batch_size=batch_size,
|
|
66
|
-
shuffle=shuffle,
|
|
67
|
-
num_workers=num_workers,
|
|
68
|
-
)
|
|
69
|
-
|
|
70
|
-
def __iter__(self):
|
|
71
|
-
return iter(self.loader)
|
|
72
|
-
|
|
73
|
-
def __len__(self):
|
|
74
|
-
return len(self.loader)
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def split_rbm_loaders(
|
|
78
|
-
dataset: BMDataset,
|
|
79
|
-
*,
|
|
80
|
-
batch_size: int,
|
|
81
|
-
split: tuple[float, float, float] = (0.8, 0.1, 0.1),
|
|
82
|
-
seed: int = 42,
|
|
83
|
-
shuffle_train: bool = True,
|
|
84
|
-
num_workers: int = 0,
|
|
85
|
-
pin_memory: bool = True,
|
|
86
|
-
drop_last_train: bool = True,
|
|
87
|
-
) -> Dict[str, DataLoader]:
|
|
88
|
-
"""Split dataset and create train/val/test DataLoaders.
|
|
89
|
-
|
|
90
|
-
Args:
|
|
91
|
-
dataset: BMDataset to split.
|
|
92
|
-
batch_size: Number of samples per batch.
|
|
93
|
-
split: Tuple of (train_frac, val_frac, test_frac). Must sum to 1.0.
|
|
94
|
-
seed: Random seed for reproducible splits.
|
|
95
|
-
shuffle_train: Whether to shuffle training data.
|
|
96
|
-
num_workers: Number of data loading workers.
|
|
97
|
-
pin_memory: Whether to pin memory (good for CUDA).
|
|
98
|
-
drop_last_train: Whether to drop last incomplete batch during training.
|
|
99
|
-
|
|
100
|
-
Returns:
|
|
101
|
-
Dictionary with 'train', 'val', 'test' DataLoader instances.
|
|
102
|
-
|
|
103
|
-
Raises:
|
|
104
|
-
ValueError: If split fractions don't sum to 1.0.
|
|
105
|
-
"""
|
|
106
|
-
train_frac, val_frac, test_frac = split
|
|
107
|
-
if abs(train_frac + val_frac + test_frac - 1.0) > 1e-6:
|
|
108
|
-
raise ValueError("split fractions must sum to 1.0")
|
|
109
|
-
|
|
110
|
-
n = len(dataset)
|
|
111
|
-
n_train = int(round(train_frac * n))
|
|
112
|
-
n_val = int(round(val_frac * n))
|
|
113
|
-
n_test = n - n_train - n_val
|
|
114
|
-
|
|
115
|
-
gen = torch.Generator().manual_seed(seed)
|
|
116
|
-
train_set, val_set, test_set = random_split(
|
|
117
|
-
dataset, [n_train, n_val, n_test], generator=gen
|
|
118
|
-
)
|
|
119
|
-
|
|
120
|
-
train_loader = DataLoader(
|
|
121
|
-
train_set,
|
|
122
|
-
batch_size=batch_size,
|
|
123
|
-
shuffle=shuffle_train,
|
|
124
|
-
drop_last=drop_last_train,
|
|
125
|
-
num_workers=num_workers,
|
|
126
|
-
pin_memory=pin_memory,
|
|
127
|
-
)
|
|
128
|
-
val_loader = DataLoader(
|
|
129
|
-
val_set,
|
|
130
|
-
batch_size=batch_size,
|
|
131
|
-
shuffle=False,
|
|
132
|
-
drop_last=False,
|
|
133
|
-
num_workers=num_workers,
|
|
134
|
-
pin_memory=pin_memory,
|
|
135
|
-
)
|
|
136
|
-
test_loader = DataLoader(
|
|
137
|
-
test_set,
|
|
138
|
-
batch_size=batch_size,
|
|
139
|
-
shuffle=False,
|
|
140
|
-
drop_last=False,
|
|
141
|
-
num_workers=num_workers,
|
|
142
|
-
pin_memory=pin_memory,
|
|
143
|
-
)
|
|
144
|
-
|
|
145
|
-
return {"train": train_loader, "val": val_loader, "test": test_loader}
|