boltzmann9 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
boltzmann9/project.py ADDED
@@ -0,0 +1,195 @@
1
+ """Project and run management utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ import shutil
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+ from typing import Any, Dict
10
+
11
+
12
+ def _get_templates_dir() -> Path:
13
+ """Get the path to the templates directory."""
14
+ return Path(__file__).parent.parent / "templates"
15
+
16
+
17
def _load_config_template(data_path: str) -> str:
    """Load the config.py template and substitute the data path.

    Args:
        data_path: Path to set for ``csv_path`` in the config (the template
            is expected to use a double-quoted value).

    Returns:
        Config file content with the substituted data path.
    """
    template_path = _get_templates_dir() / "config.py"
    content = template_path.read_text()

    # Use a replacement callable instead of a replacement string: with a
    # string, backslashes in data_path (e.g. Windows paths like r"data\x")
    # would be parsed as regex escapes and raise or corrupt the output.
    return re.sub(
        r'("csv_path":\s*)"[^"]*"',
        lambda m: f'{m.group(1)}"{data_path}"',
        content,
    )
37
+
38
+
39
def _load_generator_template() -> str:
    """Return the contents of the bundled synthetic_generator.py template."""
    return (_get_templates_dir() / "synthetic_generator.py").read_text()
47
+
48
+
49
def create_project(project_path: str | Path) -> Path:
    """Create a new BoltzmaNN9 project skeleton on disk.

    Args:
        project_path: Path where the project will be created.

    Returns:
        Path to the created project.

    Raises:
        FileExistsError: If the project directory already exists.
    """
    root = Path(project_path)

    if root.exists():
        raise FileExistsError(f"Project directory already exists: {root}")

    # Lay out the directory structure: root, then data/ and output/.
    root.mkdir(parents=True)
    for sub in ("data", "output"):
        (root / sub).mkdir()

    # Write config.py from the template, pointing csv_path at data/data.csv.
    (root / "config.py").write_text(_load_config_template("data/data.csv"))

    # Drop the synthetic data generator script into data/.
    (root / "data" / "synthetic_generator.py").write_text(_load_generator_template())

    print(f"Created project: {root}")
    print(" - config.py")
    print(" - data/")
    print(" - synthetic_generator.py")
    print(" - output/")
    print("\nNext steps:")
    print(f" 1. cd {root}/data && python synthetic_generator.py")
    print(f" 2. Edit {root}/config.py as needed")
    print(f" 3. python boltzmann.py train --project {root}")

    return root
95
+
96
+
97
+ def create_run(project_path: str | Path, run_name: str | None = None) -> Path:
98
+ """Create a new run directory within a project.
99
+
100
+ Args:
101
+ project_path: Path to the project.
102
+ run_name: Optional custom run name. If None, uses timestamp.
103
+
104
+ Returns:
105
+ Path to the created run directory.
106
+ """
107
+ project_path = Path(project_path)
108
+ output_dir = project_path / "output"
109
+
110
+ if not output_dir.exists():
111
+ output_dir.mkdir(parents=True)
112
+
113
+ # Generate run name with timestamp
114
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
115
+ if run_name:
116
+ run_dir_name = f"run_{timestamp}_{run_name}"
117
+ else:
118
+ run_dir_name = f"run_{timestamp}"
119
+
120
+ run_dir = output_dir / run_dir_name
121
+ run_dir.mkdir()
122
+
123
+ # Create subdirectories
124
+ (run_dir / "plots").mkdir()
125
+ (run_dir / "checkpoints").mkdir()
126
+
127
+ return run_dir
128
+
129
+
130
def save_run_config(run_dir: Path, config_path: Path) -> None:
    """Snapshot the config file into the run directory as ``config.py``.

    Args:
        run_dir: Path to the run directory.
        config_path: Path to the original config file.
    """
    # copy2 preserves metadata (timestamps) alongside the file contents.
    shutil.copy2(config_path, run_dir / "config.py")
139
+
140
+
141
def get_run_paths(run_dir: Path) -> Dict[str, Path]:
    """Map standard artifact names to their locations inside a run directory.

    Args:
        run_dir: Path to the run directory.

    Returns:
        Dictionary with paths for model, log, plots, checkpoints, config,
        history and metrics.
    """
    layout = {
        "model": "model.pt",
        "log": "training.log",
        "plots": "plots",
        "checkpoints": "checkpoints",
        "config": "config.py",
        "history": "history.json",
        "metrics": "metrics.json",
    }
    return {name: run_dir / rel for name, rel in layout.items()}
159
+
160
+
161
+ def find_latest_run(project_path: str | Path) -> Path | None:
162
+ """Find the most recent run in a project.
163
+
164
+ Args:
165
+ project_path: Path to the project.
166
+
167
+ Returns:
168
+ Path to the latest run directory, or None if no runs exist.
169
+ """
170
+ project_path = Path(project_path)
171
+ output_dir = project_path / "output"
172
+
173
+ if not output_dir.exists():
174
+ return None
175
+
176
+ runs = sorted(output_dir.glob("run_*"), reverse=True)
177
+ return runs[0] if runs else None
178
+
179
+
180
+ def list_runs(project_path: str | Path) -> list[Path]:
181
+ """List all runs in a project.
182
+
183
+ Args:
184
+ project_path: Path to the project.
185
+
186
+ Returns:
187
+ List of run directory paths, sorted by date (newest first).
188
+ """
189
+ project_path = Path(project_path)
190
+ output_dir = project_path / "output"
191
+
192
+ if not output_dir.exists():
193
+ return []
194
+
195
+ return sorted(output_dir.glob("run_*"), reverse=True)
@@ -0,0 +1,262 @@
1
+ """Utilities for run logging and visualization."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import sys
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+ from typing import Any, Dict, TextIO
10
+
11
+
12
class RunLogger:
    """Tee logger that mirrors everything written to stdout into a file.

    Use as a context manager: on entry it replaces ``sys.stdout`` with
    itself; on exit it restores the original stream and closes the file.
    """

    def __init__(self, log_path: Path):
        """Initialize logger.

        Args:
            log_path: Path to log file.
        """
        self.log_path = log_path
        self.log_file: TextIO | None = None
        # Captured at construction time so writes always reach the real
        # console even while sys.stdout is redirected.
        self._original_stdout = sys.stdout

    def __enter__(self):
        # utf-8 explicitly: the platform default (e.g. cp1252 on Windows)
        # may not be able to encode arbitrary log text.
        self.log_file = open(self.log_path, "w", encoding="utf-8")
        sys.stdout = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self._original_stdout
        if self.log_file:
            self.log_file.close()
            # Drop the reference so late write()/flush() calls become
            # console-only no-ops instead of raising on a closed file.
            self.log_file = None
        return False  # never suppress exceptions

    def write(self, message: str):
        """Write to the real stdout and, when open, the log file."""
        self._original_stdout.write(message)
        if self.log_file:
            self.log_file.write(message)
            self.log_file.flush()

    def flush(self):
        """Flush both streams (required for a file-like object)."""
        self._original_stdout.flush()
        if self.log_file:
            self.log_file.flush()
46
+
47
+
48
def save_history(history: Dict[str, list], path: Path) -> None:
    """Save training history to JSON.

    Numeric entries are coerced to float so the payload serializes
    uniformly; non-numeric entries pass through unchanged.

    Args:
        history: Training history dictionary.
        path: Path to save JSON file.
    """
    payload = {
        key: [float(v) if isinstance(v, (int, float)) else v for v in values]
        for key, values in history.items()
    }
    with open(path, "w") as f:
        json.dump(payload, f, indent=2)
65
+
66
+
67
def load_history(path: Path) -> Dict[str, list]:
    """Load training history from JSON.

    Args:
        path: Path to history JSON file.

    Returns:
        Training history dictionary.
    """
    return json.loads(Path(path).read_text())
78
+
79
+
80
def save_metrics(metrics: Dict[str, Any], path: Path) -> None:
    """Save evaluation metrics to JSON.

    Recursively walks dicts/lists and coerces numbers to float so that
    non-JSON-native numerics serialize cleanly. Booleans are preserved.

    Args:
        metrics: Metrics dictionary.
        path: Path to save JSON file.
    """
    def convert(obj):
        if isinstance(obj, dict):
            return {k: convert(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [convert(v) for v in obj]
        # bool is a subclass of int: keep True/False boolean instead of
        # silently degrading them to 1.0/0.0 in the saved JSON.
        if isinstance(obj, bool):
            return obj
        if isinstance(obj, (int, float)):
            return float(obj)
        return obj

    with open(path, "w") as f:
        json.dump(convert(metrics), f, indent=2)
100
+
101
+
102
def load_metrics(path: Path) -> Dict[str, Any]:
    """Load evaluation metrics from JSON.

    Args:
        path: Path to metrics JSON file.

    Returns:
        Metrics dictionary.
    """
    return json.loads(Path(path).read_text())
113
+
114
+
115
def save_plots(history: Dict[str, list], plots_dir: Path, model: Any = None) -> None:
    """Save training plots to directory.

    Produces one standalone figure per tracked metric, an optional learning
    rate schedule plot, and a combined 1x3 summary panel.

    Args:
        history: Training history dictionary.
        plots_dir: Directory to save plots.
        model: Optional RBM model instance. If provided, saves block diagram.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("[save_plots] matplotlib not available, skipping plots")
        return

    import math

    plots_dir.mkdir(parents=True, exist_ok=True)

    # Save RBM block diagram if the model supports it; best-effort only.
    if model is not None and hasattr(model, "draw_blocks"):
        try:
            model.draw_blocks(save_path=str(plots_dir / "RBM_blocks.png"), show=False)
        except Exception as e:
            print(f"[save_plots] Failed to save block diagram: {e}")

    # X axis: recorded epoch numbers, or 1..N inferred from the FE series.
    epochs = history.get(
        "epoch", list(range(1, len(history.get("train_free_energy", [])) + 1))
    )

    def _is_all_nan(xs):
        # Treat empty / None-only / NaN-only series as absent so the val
        # curve is skipped rather than plotted as garbage.
        if not xs:
            return True
        return all(
            (x is None) or (isinstance(x, float) and math.isnan(x)) for x in xs
        )

    def _plot_pair(ax, train_key, val_key, *, linewidth=None):
        # Plot the train series, and the val series when it has real values.
        y_tr = history.get(train_key, [])
        y_va = history.get(val_key, [])
        kw = {} if linewidth is None else {"linewidth": linewidth}
        ax.plot(epochs[:len(y_tr)], y_tr, label="train", **kw)
        if not _is_all_nan(y_va):
            ax.plot(epochs[:len(y_va)], y_va, label="val", **kw)

    def _decorate(ax, title, ylabel, *, title_fs=None, label_fs=None):
        # Shared title/label/legend/grid styling for every axis.
        kw_title = {} if title_fs is None else {"fontsize": title_fs}
        kw_label = {} if label_fs is None else {"fontsize": label_fs}
        ax.set_title(title, **kw_title)
        ax.set_xlabel("Epoch", **kw_label)
        ax.set_ylabel(ylabel, **kw_label)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Standalone figures: (train key, val key, title, y label, filename).
    standalone = [
        ("train_free_energy", "val_free_energy", "Free Energy",
         "Mean FE (lower is better)", "free_energy.png"),
        ("train_recon_mse", "val_recon_mse", "Reconstruction MSE",
         "MSE", "reconstruction_mse.png"),
        ("train_recon_bit_error", "val_recon_bit_error", "Reconstruction Bit Error",
         "Fraction Mismatched", "bit_error.png"),
    ]
    for train_key, val_key, title, ylabel, fname in standalone:
        fig, ax = plt.subplots(figsize=(10, 6))
        _plot_pair(ax, train_key, val_key, linewidth=2)
        _decorate(ax, title, ylabel, title_fs=14, label_fs=12)
        fig.tight_layout()
        fig.savefig(plots_dir / fname, dpi=150)
        plt.close(fig)

    # Learning rate schedule (log scale), only when it was recorded.
    if "lr" in history and history["lr"]:
        fig, ax = plt.subplots(figsize=(10, 6))
        lr_vals = history["lr"]
        ax.plot(epochs[:len(lr_vals)], lr_vals, linewidth=2, color="green")
        ax.set_title("Learning Rate Schedule", fontsize=14)
        ax.set_xlabel("Epoch", fontsize=12)
        ax.set_ylabel("Learning Rate", fontsize=12)
        ax.set_yscale("log")
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(plots_dir / "learning_rate.png", dpi=150)
        plt.close(fig)

    # Combined 1x3 summary panel (default font sizes, no line width).
    fig, axes = plt.subplots(1, 3, figsize=(16, 4))
    combined = [
        ("train_free_energy", "val_free_energy", "Free Energy", "Mean FE"),
        ("train_recon_mse", "val_recon_mse", "Reconstruction MSE", "MSE"),
        ("train_recon_bit_error", "val_recon_bit_error", "Bit Error", "Fraction"),
    ]
    for ax, (train_key, val_key, title, ylabel) in zip(axes, combined):
        _plot_pair(ax, train_key, val_key)
        _decorate(ax, title, ylabel)
    fig.tight_layout()
    fig.savefig(plots_dir / "training_summary.png", dpi=150)
    plt.close(fig)

    print(f"Plots saved to: {plots_dir}")
boltzmann9/tester.py ADDED
@@ -0,0 +1,167 @@
1
+ """Testing utilities for RBM conditional probability evaluation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from collections import Counter
7
+ from typing import Any, Dict, Sequence
8
+
9
+ import torch
10
+ from tqdm.auto import tqdm
11
+
12
+
13
class RBMTester:
    """Test RBM conditional probability estimation.

    Evaluates how well the RBM can predict target visible units
    given clamped (observed) visible units.

    Args:
        model: Trained RBM model. Must expose ``W`` (weight matrix), ``nv``
            (number of visible units), ``eval()`` and ``sample_clamped()``.
        test_dataloader: DataLoader for test data; each batch is indexed
            as a 2-D tensor of visible vectors (batch, nv).
        clamp_idx: Indices of visible units to clamp (condition on).
        target_idx: Indices of visible units to predict.
        device: Target device for computations.
    """

    def __init__(
        self,
        model,
        test_dataloader,
        clamp_idx: Sequence[int],
        target_idx: Sequence[int],
        *,
        device=None,
    ):
        self.model = model
        self.test_dataloader = test_dataloader
        # Copy to plain lists so tensor/generator index sequences are frozen.
        self.clamp_idx = list(clamp_idx)
        self.target_idx = list(target_idx)
        # Fall back to the device of the model's weight matrix.
        self.device = device or model.W.device

    @staticmethod
    def _bits_to_int(bits: torch.Tensor) -> int:
        """Convert binary bits to integer (LSB-first)."""
        # weights = [1, 2, 4, ...] so bits[0] is the least significant bit.
        weights = 2 ** torch.arange(bits.numel(), device=bits.device)
        return int((bits * weights).sum().item())

    @torch.no_grad()
    def conditional_nll(
        self,
        *,
        n_samples: int = 500,
        burn_in: int = 300,
        thin: int = 10,
        laplace_alpha: float = 1.0,
        log_every: int = 50,
    ) -> Dict[str, Any]:
        """Compute conditional negative log-likelihood over test set.

        For each test sample, clamps the specified visible units and
        samples from the conditional distribution p(target | clamp).

        Computes per-bit NLL using empirical bit frequencies, which scales
        to large target dimensions.

        Args:
            n_samples: Number of MCMC samples per test point.
            burn_in: Burn-in steps for MCMC.
            thin: Thinning interval between samples.
            laplace_alpha: Laplace smoothing parameter for per-bit probabilities.
            log_every: Log progress every N samples.

        Returns:
            Dictionary with:
                - mean_nll_nats: Mean NLL in nats (sum over target bits).
                - mean_nll_bits: Mean NLL in bits.
                - mean_nll_per_bit: Mean NLL per target bit.
                - nll_nats_per_sample: List of per-sample NLL in nats.
                - nll_bits_per_sample: List of per-sample NLL in bits.
        """
        self.model.eval()

        # Conversion factor between nats and bits: NLL_bits = NLL_nats / ln 2.
        ln2 = math.log(2.0)

        nlls_nats = []
        nlls_bits = []

        total_points = len(self.test_dataloader.dataset)
        n_target_bits = len(self.target_idx)

        outer_pbar = tqdm(
            total=total_points,
            desc="RBM conditional NLL",
            leave=True,
        )

        for batch_idx, v in enumerate(self.test_dataloader):
            v = v.to(self.device)

            # One MCMC run per test point (row), so this is slow but exact
            # about which units are conditioned on per sample.
            for i in range(v.size(0)):
                # Clamp visible units: zeros everywhere except the observed
                # (clamped) positions, copied from the test vector.
                v_clamp = torch.zeros(
                    self.model.nv,
                    device=self.device,
                    dtype=self.model.W.dtype,
                )
                v_clamp[self.clamp_idx] = v[i, self.clamp_idx]

                # True target bits
                true_bits = v[i, self.target_idx]

                # Sample conditional. NOTE(review): assumes sample_clamped
                # returns a (n_samples, nv) tensor — confirm against model.
                samples = self.model.sample_clamped(
                    v_clamp=v_clamp,
                    clamp_idx=self.clamp_idx,
                    n_samples=n_samples,
                    burn_in=burn_in,
                    thin=thin,
                )

                # Get target bits from samples: (n_samples, n_target_bits)
                sampled_bits = samples[:, self.target_idx]

                # Compute per-bit empirical probability with Laplace smoothing.
                # For each target bit, estimate P(bit = true_value); the
                # denominator assumes exactly n_samples rows were returned.
                bit_matches = (sampled_bits == true_bits.unsqueeze(0)).float()
                bit_probs = (bit_matches.sum(dim=0) + laplace_alpha) / (n_samples + 2 * laplace_alpha)

                # NLL = -sum(log(p_i)) for each bit; this treats target bits
                # as independent given the clamp (a factorized approximation).
                # Clamp to avoid log(0).
                bit_probs = bit_probs.clamp(min=1e-10, max=1 - 1e-10)
                nll_nats = -torch.log(bit_probs).sum().item()
                nll_bits = nll_nats / ln2

                nlls_nats.append(nll_nats)
                nlls_bits.append(nll_bits)

                # Periodically surface running means on the progress bar.
                if len(nlls_nats) % log_every == 0:
                    mean_nats = sum(nlls_nats) / len(nlls_nats)
                    mean_bits = sum(nlls_bits) / len(nlls_bits)
                    outer_pbar.set_postfix(
                        nll_nats=f"{mean_nats:.3f}",
                        nll_bits=f"{mean_bits:.3f}",
                        per_bit=f"{mean_nats / n_target_bits:.4f}",
                    )

                outer_pbar.update(1)

        outer_pbar.close()

        # NaN when the test set is empty, rather than raising ZeroDivisionError.
        mean_nats = (
            float(sum(nlls_nats) / len(nlls_nats)) if nlls_nats else float("nan")
        )
        mean_bits = (
            float(sum(nlls_bits) / len(nlls_bits)) if nlls_bits else float("nan")
        )
        mean_per_bit = mean_nats / n_target_bits if n_target_bits > 0 else float("nan")

        return {
            "mean_nll_nats": mean_nats,
            "mean_nll_bits": mean_bits,
            "mean_nll_per_bit": mean_per_bit,
            "n_target_bits": n_target_bits,
            "nll_nats_per_sample": nlls_nats,
            "nll_bits_per_sample": nlls_bits,
        }
boltzmann9/utils.py ADDED
@@ -0,0 +1,42 @@
1
+ """Utility functions for device resolution and memory management."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+
8
+ def resolve_device(device_cfg: str | None) -> torch.device:
9
+ """Resolve device configuration to a torch.device.
10
+
11
+ Args:
12
+ device_cfg: Device configuration string. Can be:
13
+ - None or "auto": Auto-detect best available device
14
+ - "cpu": Use CPU
15
+ - "cuda:0", "cuda:1", etc.: Use specific CUDA device
16
+ - "mps": Use Apple Metal Performance Shaders
17
+
18
+ Returns:
19
+ torch.device instance for the resolved device.
20
+ """
21
+ if device_cfg is None or device_cfg == "auto":
22
+ if torch.cuda.is_available():
23
+ return torch.device("cuda:0")
24
+ if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
25
+ return torch.device("mps")
26
+ return torch.device("cpu")
27
+ return torch.device(device_cfg)
28
+
29
+
30
def resolve_pin_memory(pin_cfg, device: torch.device) -> bool:
    """Resolve the pin_memory setting based on the device type.

    Args:
        pin_cfg: Pin memory configuration. Can be "auto", True, or False.
        device: The target torch device.

    Returns:
        Boolean indicating whether to pin memory.
    """
    if pin_cfg != "auto":
        return bool(pin_cfg)
    # Pinned host memory only speeds up host-to-CUDA transfers.
    return device.type == "cuda"