boltzmann9 0.1.3__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {boltzmann9-0.1.3 → boltzmann9-0.1.4}/PKG-INFO +1 -1
  2. {boltzmann9-0.1.3 → boltzmann9-0.1.4}/pyproject.toml +1 -1
  3. {boltzmann9-0.1.3 → boltzmann9-0.1.4}/src/boltzmann9.egg-info/PKG-INFO +1 -1
  4. boltzmann9-0.1.4/src/boltzmann9.egg-info/SOURCES.txt +10 -0
  5. boltzmann9-0.1.4/src/boltzmann9.egg-info/top_level.txt +1 -0
  6. boltzmann9-0.1.3/src/boltzmann9/__init__.py +0 -38
  7. boltzmann9-0.1.3/src/boltzmann9/__main__.py +0 -4
  8. boltzmann9-0.1.3/src/boltzmann9/cli.py +0 -389
  9. boltzmann9-0.1.3/src/boltzmann9/config.py +0 -58
  10. boltzmann9-0.1.3/src/boltzmann9/data.py +0 -145
  11. boltzmann9-0.1.3/src/boltzmann9/data_generator.py +0 -234
  12. boltzmann9-0.1.3/src/boltzmann9/model.py +0 -867
  13. boltzmann9-0.1.3/src/boltzmann9/pipeline.py +0 -216
  14. boltzmann9-0.1.3/src/boltzmann9/preprocessor.py +0 -627
  15. boltzmann9-0.1.3/src/boltzmann9/project.py +0 -195
  16. boltzmann9-0.1.3/src/boltzmann9/run_utils.py +0 -262
  17. boltzmann9-0.1.3/src/boltzmann9/tester.py +0 -167
  18. boltzmann9-0.1.3/src/boltzmann9/utils.py +0 -42
  19. boltzmann9-0.1.3/src/boltzmann9/visualization.py +0 -115
  20. boltzmann9-0.1.3/src/boltzmann9.egg-info/SOURCES.txt +0 -24
  21. boltzmann9-0.1.3/src/boltzmann9.egg-info/top_level.txt +0 -1
  22. {boltzmann9-0.1.3 → boltzmann9-0.1.4}/README.md +0 -0
  23. {boltzmann9-0.1.3 → boltzmann9-0.1.4}/setup.cfg +0 -0
  24. {boltzmann9-0.1.3 → boltzmann9-0.1.4}/src/boltzmann9.egg-info/dependency_links.txt +0 -0
  25. {boltzmann9-0.1.3 → boltzmann9-0.1.4}/src/boltzmann9.egg-info/entry_points.txt +0 -0
  26. {boltzmann9-0.1.3 → boltzmann9-0.1.4}/src/boltzmann9.egg-info/requires.txt +0 -0
  27. {boltzmann9-0.1.3 → boltzmann9-0.1.4}/src/templates/config.py +0 -0
  28. {boltzmann9-0.1.3 → boltzmann9-0.1.4}/src/templates/synthetic_generator.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: boltzmann9
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Summary: Restricted Boltzmann Machine implementation in PyTorch
5
5
  License: MIT
6
6
  Requires-Python: >=3.10
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "boltzmann9"
7
- version = "0.1.3"
7
+ version = "0.1.4"
8
8
  description = "Restricted Boltzmann Machine implementation in PyTorch"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: boltzmann9
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Summary: Restricted Boltzmann Machine implementation in PyTorch
5
5
  License: MIT
6
6
  Requires-Python: >=3.10
@@ -0,0 +1,10 @@
1
+ README.md
2
+ pyproject.toml
3
+ src/boltzmann9.egg-info/PKG-INFO
4
+ src/boltzmann9.egg-info/SOURCES.txt
5
+ src/boltzmann9.egg-info/dependency_links.txt
6
+ src/boltzmann9.egg-info/entry_points.txt
7
+ src/boltzmann9.egg-info/requires.txt
8
+ src/boltzmann9.egg-info/top_level.txt
9
+ src/templates/config.py
10
+ src/templates/synthetic_generator.py
@@ -1,38 +0,0 @@
1
- """Boltzmann Machine library for PyTorch."""
2
-
3
- from .model import RBM
4
- from .data import BMDataset, GBMDataloader, split_rbm_loaders
5
- from .utils import resolve_device, resolve_pin_memory
6
- from .tester import RBMTester
7
- from .config import load_config
8
- from .data_generator import SyntheticDataGenerator, GeneratorConfig
9
- from .pipeline import train_rbm, evaluate_rbm, save_model, load_model
10
- from .project import create_project, create_run, list_runs, find_latest_run
11
-
12
- __all__ = [
13
- # Model
14
- "RBM",
15
- # Data
16
- "BMDataset",
17
- "GBMDataloader",
18
- "split_rbm_loaders",
19
- # Data generation
20
- "SyntheticDataGenerator",
21
- "GeneratorConfig",
22
- # Utils
23
- "resolve_device",
24
- "resolve_pin_memory",
25
- # Testing
26
- "RBMTester",
27
- # Config & Pipeline
28
- "load_config",
29
- "train_rbm",
30
- "evaluate_rbm",
31
- "save_model",
32
- "load_model",
33
- # Project management
34
- "create_project",
35
- "create_run",
36
- "list_runs",
37
- "find_latest_run",
38
- ]
@@ -1,4 +0,0 @@
1
- from boltzmann9.cli import main
2
-
3
- if __name__ == "__main__":
4
- main()
@@ -1,389 +0,0 @@
1
- #!/usr/bin/env python
2
- """BoltzmaNN9 command-line interface."""
3
-
4
- import argparse
5
- import sys
6
- from pathlib import Path
7
-
8
-
9
- def cmd_new_project(args):
10
- """Create a new project."""
11
- from boltzmann.project import create_project
12
-
13
- try:
14
- create_project(args.project_name)
15
- except FileExistsError as e:
16
- print(f"Error: {e}")
17
- return 1
18
- return 0
19
-
20
-
21
- def cmd_train(args):
22
- """Train a model."""
23
- from boltzmann.config import load_config
24
- from boltzmann.pipeline import train_rbm, save_model
25
- from boltzmann.project import create_run, save_run_config, get_run_paths
26
- from boltzmann.run_utils import RunLogger, save_history, save_plots
27
-
28
- project_path = Path(args.project)
29
-
30
- if not project_path.exists():
31
- print(f"Error: Project not found: {project_path}")
32
- return 1
33
-
34
- config_path = project_path / "config.py"
35
- if not config_path.exists():
36
- print(f"Error: Config file not found: {config_path}")
37
- return 1
38
-
39
- # Create run directory
40
- run_dir = create_run(project_path, args.name)
41
- paths = get_run_paths(run_dir)
42
-
43
- # Copy config to run
44
- save_run_config(run_dir, config_path)
45
-
46
- print(f"Starting training run: {run_dir.name}")
47
- print("=" * 60)
48
-
49
- # Train with logging
50
- with RunLogger(paths["log"]):
51
- # Load config and adjust data path to be relative to project
52
- cfg = load_config(config_path)
53
-
54
- # Make data path relative to project
55
- data_path = cfg.get("data", {}).get("csv_path", "")
56
- if data_path and not Path(data_path).is_absolute():
57
- cfg["data"]["csv_path"] = str(project_path / data_path)
58
-
59
- # Import here to avoid circular imports
60
- from boltzmann.data import BMDataset, split_rbm_loaders
61
- from boltzmann.model import RBM
62
- from boltzmann.utils import resolve_device, resolve_pin_memory
63
-
64
- # Setup
65
- device = resolve_device(cfg.get("device", "auto"))
66
- print(f"Using device: {device}")
67
-
68
- data_cfg = cfg.get("data", {})
69
- drop_cols = data_cfg.get("drop_cols", [])
70
- dataset = BMDataset(cfg["data"]["csv_path"], drop_cols=drop_cols)
71
-
72
- print(f"Loaded dataset: {len(dataset)} samples")
73
- print(f" Columns: {dataset.columns}")
74
-
75
- dl = cfg.get("dataloader", {})
76
- pin_memory = resolve_pin_memory(dl.get("pin_memory", "auto"), device)
77
-
78
- loaders = split_rbm_loaders(
79
- dataset,
80
- batch_size=dl.get("batch_size", 256),
81
- split=tuple(dl.get("split", (0.8, 0.1, 0.1))),
82
- seed=dl.get("seed", 42),
83
- shuffle_train=dl.get("shuffle_train", True),
84
- num_workers=dl.get("num_workers", 0),
85
- pin_memory=pin_memory,
86
- drop_last_train=dl.get("drop_last_train", True),
87
- )
88
-
89
- # Model
90
- model_cfg = dict(cfg.get("model", {}))
91
- model = RBM(model_cfg).to(device)
92
-
93
- # Train
94
- train_cfg = dict(cfg.get("train", {}))
95
- history = model.fit(
96
- loaders["train"],
97
- val_loader=loaders["val"],
98
- **train_cfg,
99
- )
100
-
101
- # Save outputs
102
- import torch
103
- checkpoint = {
104
- "model_state_dict": model.state_dict(),
105
- "nv": model.nv,
106
- "nh": model.nh,
107
- "config": cfg,
108
- }
109
- torch.save(checkpoint, paths["model"])
110
- print(f"Model saved to: {paths['model']}")
111
-
112
- save_history(history, paths["history"])
113
- save_plots(history, paths["plots"], model=model)
114
-
115
- print("\n" + "=" * 60)
116
- print("TRAINING COMPLETE")
117
- print("=" * 60)
118
- print(f"Run directory: {run_dir}")
119
- print(f" - model.pt")
120
- print(f" - config.py")
121
- print(f" - training.log")
122
- print(f" - history.json")
123
- print(f" - plots/")
124
-
125
- return 0
126
-
127
-
128
- def cmd_evaluate(args):
129
- """Evaluate a trained model."""
130
- from boltzmann.config import load_config
131
- from boltzmann.pipeline import load_model
132
- from boltzmann.project import get_run_paths, find_latest_run
133
- from boltzmann.run_utils import save_metrics
134
-
135
- # Determine run directory
136
- run_path = Path(args.run)
137
-
138
- if not run_path.exists():
139
- # Maybe it's a project path - find latest run
140
- project_path = run_path
141
- run_path = find_latest_run(project_path)
142
- if run_path is None:
143
- print(f"Error: No runs found in project: {project_path}")
144
- return 1
145
- print(f"Using latest run: {run_path.name}")
146
-
147
- paths = get_run_paths(run_path)
148
-
149
- # Load model and config from run
150
- if not paths["model"].exists():
151
- print(f"Error: Model not found: {paths['model']}")
152
- return 1
153
-
154
- if not paths["config"].exists():
155
- print(f"Error: Config not found: {paths['config']}")
156
- return 1
157
-
158
- model, saved_cfg = load_model(paths["model"])
159
- cfg = load_config(paths["config"])
160
-
161
- # Find project path from run path (output/run_xxx -> project)
162
- project_path = run_path.parent.parent
163
-
164
- # Adjust data path
165
- data_path = cfg.get("data", {}).get("csv_path", "")
166
- if data_path and not Path(data_path).is_absolute():
167
- cfg["data"]["csv_path"] = str(project_path / data_path)
168
-
169
- print(f"Evaluating run: {run_path.name}")
170
- print("=" * 60)
171
-
172
- # Import and setup
173
- from boltzmann.data import BMDataset, split_rbm_loaders
174
- from boltzmann.tester import RBMTester
175
- from boltzmann.utils import resolve_device, resolve_pin_memory
176
-
177
- device = resolve_device(cfg.get("device", "auto"))
178
- model = model.to(device)
179
- print(f"Using device: {device}")
180
-
181
- data_cfg = cfg.get("data", {})
182
- drop_cols = data_cfg.get("drop_cols", [])
183
- dataset = BMDataset(cfg["data"]["csv_path"], drop_cols=drop_cols)
184
-
185
- dl = cfg.get("dataloader", {})
186
- pin_memory = resolve_pin_memory(dl.get("pin_memory", "auto"), device)
187
-
188
- loaders = split_rbm_loaders(
189
- dataset,
190
- batch_size=dl.get("batch_size", 256),
191
- split=tuple(dl.get("split", (0.8, 0.1, 0.1))),
192
- seed=dl.get("seed", 42),
193
- shuffle_train=False,
194
- num_workers=dl.get("num_workers", 0),
195
- pin_memory=pin_memory,
196
- drop_last_train=False,
197
- )
198
-
199
- print(f"Evaluating on {len(loaders['test'].dataset)} test samples")
200
-
201
- # Basic metrics
202
- eval_cfg = cfg.get("eval", {})
203
- test_metrics = model.evaluate(loaders["test"], recon_k=eval_cfg.get("recon_k", 1))
204
- print("Test metrics:", test_metrics)
205
-
206
- # Conditional NLL
207
- cond_cfg = cfg.get("conditional", {})
208
- tester = RBMTester(
209
- model=model,
210
- test_dataloader=loaders["test"],
211
- clamp_idx=cond_cfg["clamp_idx"],
212
- target_idx=cond_cfg["target_idx"],
213
- )
214
- conditional_results = tester.conditional_nll(
215
- n_samples=cond_cfg.get("n_samples", 100),
216
- burn_in=cond_cfg.get("burn_in", 500),
217
- thin=cond_cfg.get("thin", 10),
218
- )
219
-
220
- # Save metrics
221
- all_metrics = {
222
- "test_metrics": test_metrics,
223
- "conditional_nll_nats": conditional_results["mean_nll_nats"],
224
- "conditional_nll_bits": conditional_results["mean_nll_bits"],
225
- "conditional_nll_per_bit": conditional_results["mean_nll_per_bit"],
226
- "n_target_bits": conditional_results["n_target_bits"],
227
- }
228
- save_metrics(all_metrics, paths["metrics"])
229
-
230
- print("\n" + "=" * 60)
231
- print("EVALUATION RESULTS")
232
- print("=" * 60)
233
- print(f"Test metrics: {test_metrics}")
234
- print(f"Conditional NLL (nats): {conditional_results['mean_nll_nats']:.4f}")
235
- print(f"Conditional NLL (bits): {conditional_results['mean_nll_bits']:.4f}")
236
- print(f"Conditional NLL per bit: {conditional_results['mean_nll_per_bit']:.4f}")
237
- print(f"Target bits: {conditional_results['n_target_bits']}")
238
- print(f"\nMetrics saved to: {paths['metrics']}")
239
-
240
- return 0
241
-
242
-
243
- def cmd_preprocess_raw(args):
244
- """Preprocess raw data based on config."""
245
- from boltzmann.config import load_config
246
- from boltzmann.preprocessor import DataPreprocessor
247
-
248
- config_path = Path(args.config)
249
-
250
- if not config_path.exists():
251
- print(f"Error: Config file not found: {config_path}")
252
- return 1
253
-
254
- cfg = load_config(config_path)
255
-
256
- print(f"Running preprocessor with config: {config_path}")
257
- print("=" * 60)
258
-
259
- preprocessor = DataPreprocessor(cfg, config_dir=config_path.parent)
260
- df = preprocessor.fit_transform()
261
-
262
- print(f"\nPreprocessing complete!")
263
- print(f" Output CSV: {preprocessor.output_csv_path}")
264
- print(f" Samples: {len(df)}")
265
- print(f" Columns: {len(df.columns)}")
266
- print(f" Visible blocks: {preprocessor.get_visible_blocks_sizes()}")
267
-
268
- return 0
269
-
270
-
271
- def cmd_list_runs(args):
272
- """List all runs in a project."""
273
- from boltzmann.project import list_runs
274
-
275
- project_path = Path(args.project)
276
-
277
- if not project_path.exists():
278
- print(f"Error: Project not found: {project_path}")
279
- return 1
280
-
281
- runs = list_runs(project_path)
282
-
283
- if not runs:
284
- print(f"No runs found in: {project_path}")
285
- return 0
286
-
287
- print(f"Runs in {project_path}:")
288
- print("-" * 40)
289
- for run in runs:
290
- print(f" {run.name}")
291
-
292
- return 0
293
-
294
-
295
- def main():
296
- parser = argparse.ArgumentParser(
297
- description="BoltzmaNN9 - Restricted Boltzmann Machine toolkit",
298
- formatter_class=argparse.RawDescriptionHelpFormatter,
299
- )
300
-
301
- subparsers = parser.add_subparsers(dest="command", help="Available commands")
302
-
303
- # new_project command
304
- new_proj_parser = subparsers.add_parser(
305
- "new_project",
306
- help="Create a new project",
307
- )
308
- new_proj_parser.add_argument(
309
- "project_name",
310
- type=str,
311
- help="Name/path for the new project",
312
- )
313
-
314
- # train command
315
- train_parser = subparsers.add_parser(
316
- "train",
317
- help="Train a model",
318
- )
319
- train_parser.add_argument(
320
- "--project", "-p",
321
- type=str,
322
- required=True,
323
- help="Path to project directory",
324
- )
325
- train_parser.add_argument(
326
- "--name", "-n",
327
- type=str,
328
- default=None,
329
- help="Optional name suffix for the run",
330
- )
331
-
332
- # evaluate command
333
- eval_parser = subparsers.add_parser(
334
- "evaluate",
335
- help="Evaluate a trained model",
336
- )
337
- eval_parser.add_argument(
338
- "--run", "-r",
339
- type=str,
340
- required=True,
341
- help="Path to run directory or project (uses latest run)",
342
- )
343
-
344
- # list command
345
- list_parser = subparsers.add_parser(
346
- "list",
347
- help="List runs in a project",
348
- )
349
- list_parser.add_argument(
350
- "--project", "-p",
351
- type=str,
352
- required=True,
353
- help="Path to project directory",
354
- )
355
-
356
- # preprocess_raw command
357
- preprocess_parser = subparsers.add_parser(
358
- "preprocess_raw",
359
- help="Preprocess raw data using DataPreprocessor",
360
- )
361
- preprocess_parser.add_argument(
362
- "--config", "-c",
363
- type=str,
364
- required=True,
365
- help="Path to config file",
366
- )
367
-
368
- args = parser.parse_args()
369
-
370
- if args.command is None:
371
- parser.print_help()
372
- return 0
373
-
374
- if args.command == "new_project":
375
- return cmd_new_project(args)
376
- elif args.command == "train":
377
- return cmd_train(args)
378
- elif args.command == "evaluate":
379
- return cmd_evaluate(args)
380
- elif args.command == "list":
381
- return cmd_list_runs(args)
382
- elif args.command == "preprocess_raw":
383
- return cmd_preprocess_raw(args)
384
-
385
- return 0
386
-
387
-
388
- if __name__ == "__main__":
389
- sys.exit(main())
@@ -1,58 +0,0 @@
1
- """Configuration loading utilities."""
2
-
3
- from __future__ import annotations
4
-
5
- import importlib.util
6
- from pathlib import Path
7
- from typing import Any, Dict
8
-
9
-
10
- def load_config(config_path: str | Path) -> Dict[str, Any]:
11
- """Load configuration from a Python file.
12
-
13
- The config file should contain a dictionary named `config`.
14
-
15
- Args:
16
- config_path: Path to the configuration .py file.
17
-
18
- Returns:
19
- Configuration dictionary.
20
-
21
- Raises:
22
- FileNotFoundError: If config file doesn't exist.
23
- ValueError: If config file doesn't contain a 'config' dictionary.
24
-
25
- Example config.py:
26
- config = {
27
- "device": "auto",
28
- "data": {"csv_path": "data.csv", ...},
29
- ...
30
- }
31
- """
32
- config_path = Path(config_path)
33
-
34
- if not config_path.exists():
35
- raise FileNotFoundError(f"Config file not found: {config_path}")
36
-
37
- # Load the Python file as a module
38
- spec = importlib.util.spec_from_file_location("config_module", config_path)
39
- if spec is None or spec.loader is None:
40
- raise ValueError(f"Could not load config from: {config_path}")
41
-
42
- module = importlib.util.module_from_spec(spec)
43
- spec.loader.exec_module(module)
44
-
45
- # Extract the config dictionary
46
- if not hasattr(module, "config"):
47
- raise ValueError(
48
- f"Config file must contain a 'config' dictionary: {config_path}"
49
- )
50
-
51
- config = getattr(module, "config")
52
-
53
- if not isinstance(config, dict):
54
- raise ValueError(
55
- f"'config' must be a dictionary, got {type(config).__name__}"
56
- )
57
-
58
- return config
@@ -1,145 +0,0 @@
1
- """Dataset and DataLoader utilities for Boltzmann Machines."""
2
-
3
- from __future__ import annotations
4
-
5
- from pathlib import Path
6
- from typing import Dict, Sequence
7
-
8
- import pandas as pd
9
- import torch
10
- from torch.utils.data import Dataset, DataLoader, random_split
11
-
12
-
13
- class BMDataset(Dataset):
14
- """PyTorch Dataset for Boltzmann Machine training data.
15
-
16
- Loads data from a CSV file.
17
-
18
- Args:
19
- csv_path: Path to CSV file containing the training data.
20
- Each row is a sample, columns are features.
21
- drop_cols: Optional list of column names to drop from the data.
22
- """
23
-
24
- def __init__(
25
- self,
26
- csv_path: str | Path,
27
- drop_cols: Sequence[str] | None = None,
28
- ):
29
- self.csv_path = Path(csv_path)
30
- self.df = pd.read_csv(self.csv_path)
31
-
32
- if drop_cols:
33
- self.df = self.df.drop(columns=list(drop_cols))
34
-
35
- self.columns = list(self.df.columns)
36
-
37
- def __len__(self):
38
- return len(self.df)
39
-
40
- def __getitem__(self, idx):
41
- row = self.df.iloc[idx]
42
- x = torch.tensor(row.values, dtype=torch.float32)
43
- return x
44
-
45
-
46
- class GBMDataloader:
47
- """Simple wrapper around PyTorch DataLoader for Boltzmann Machine data.
48
-
49
- Args:
50
- dataset: BMDataset instance.
51
- batch_size: Number of samples per batch.
52
- shuffle: Whether to shuffle data each epoch.
53
- num_workers: Number of worker processes for data loading.
54
- """
55
-
56
- def __init__(
57
- self,
58
- dataset: BMDataset,
59
- batch_size: int = 1,
60
- shuffle: bool = True,
61
- num_workers: int = 0,
62
- ):
63
- self.loader = DataLoader(
64
- dataset,
65
- batch_size=batch_size,
66
- shuffle=shuffle,
67
- num_workers=num_workers,
68
- )
69
-
70
- def __iter__(self):
71
- return iter(self.loader)
72
-
73
- def __len__(self):
74
- return len(self.loader)
75
-
76
-
77
- def split_rbm_loaders(
78
- dataset: BMDataset,
79
- *,
80
- batch_size: int,
81
- split: tuple[float, float, float] = (0.8, 0.1, 0.1),
82
- seed: int = 42,
83
- shuffle_train: bool = True,
84
- num_workers: int = 0,
85
- pin_memory: bool = True,
86
- drop_last_train: bool = True,
87
- ) -> Dict[str, DataLoader]:
88
- """Split dataset and create train/val/test DataLoaders.
89
-
90
- Args:
91
- dataset: BMDataset to split.
92
- batch_size: Number of samples per batch.
93
- split: Tuple of (train_frac, val_frac, test_frac). Must sum to 1.0.
94
- seed: Random seed for reproducible splits.
95
- shuffle_train: Whether to shuffle training data.
96
- num_workers: Number of data loading workers.
97
- pin_memory: Whether to pin memory (good for CUDA).
98
- drop_last_train: Whether to drop last incomplete batch during training.
99
-
100
- Returns:
101
- Dictionary with 'train', 'val', 'test' DataLoader instances.
102
-
103
- Raises:
104
- ValueError: If split fractions don't sum to 1.0.
105
- """
106
- train_frac, val_frac, test_frac = split
107
- if abs(train_frac + val_frac + test_frac - 1.0) > 1e-6:
108
- raise ValueError("split fractions must sum to 1.0")
109
-
110
- n = len(dataset)
111
- n_train = int(round(train_frac * n))
112
- n_val = int(round(val_frac * n))
113
- n_test = n - n_train - n_val
114
-
115
- gen = torch.Generator().manual_seed(seed)
116
- train_set, val_set, test_set = random_split(
117
- dataset, [n_train, n_val, n_test], generator=gen
118
- )
119
-
120
- train_loader = DataLoader(
121
- train_set,
122
- batch_size=batch_size,
123
- shuffle=shuffle_train,
124
- drop_last=drop_last_train,
125
- num_workers=num_workers,
126
- pin_memory=pin_memory,
127
- )
128
- val_loader = DataLoader(
129
- val_set,
130
- batch_size=batch_size,
131
- shuffle=False,
132
- drop_last=False,
133
- num_workers=num_workers,
134
- pin_memory=pin_memory,
135
- )
136
- test_loader = DataLoader(
137
- test_set,
138
- batch_size=batch_size,
139
- shuffle=False,
140
- drop_last=False,
141
- num_workers=num_workers,
142
- pin_memory=pin_memory,
143
- )
144
-
145
- return {"train": train_loader, "val": val_loader, "test": test_loader}