ma-agents 3.3.0 → 3.4.0

Files changed (80)
  1. package/.opencode/skills/.ma-agents.json +99 -99
  2. package/.roo/skills/.ma-agents.json +99 -99
  3. package/README.md +19 -1
  4. package/bin/cli.js +55 -0
  5. package/lib/agents.js +23 -0
  6. package/lib/bmad-cache/cache-manifest.json +1 -1
  7. package/lib/bmad-customizations/bmm-demerzel.customize.yaml +36 -0
  8. package/lib/bmad-customizations/demerzel.md +32 -0
  9. package/lib/bmad-extension/module-help.csv +13 -0
  10. package/lib/bmad-extension/skills/bmad-ma-agent-ml/.gitkeep +0 -0
  11. package/lib/bmad-extension/skills/bmad-ma-agent-ml/SKILL.md +59 -0
  12. package/lib/bmad-extension/skills/bmad-ma-agent-ml/bmad-skill-manifest.yaml +11 -0
  13. package/lib/bmad-extension/skills/generate-backlog/.gitkeep +0 -0
  14. package/lib/bmad-extension/skills/ml-advise/.gitkeep +0 -0
  15. package/lib/bmad-extension/skills/ml-advise/SKILL.md +76 -0
  16. package/lib/bmad-extension/skills/ml-advise/bmad-skill-manifest.yaml +3 -0
  17. package/lib/bmad-extension/skills/ml-advise/skill.json +7 -0
  18. package/lib/bmad-extension/skills/ml-analysis/.gitkeep +0 -0
  19. package/lib/bmad-extension/skills/ml-analysis/SKILL.md +60 -0
  20. package/lib/bmad-extension/skills/ml-analysis/bmad-skill-manifest.yaml +3 -0
  21. package/lib/bmad-extension/skills/ml-analysis/skill.json +7 -0
  22. package/lib/bmad-extension/skills/ml-architecture/.gitkeep +0 -0
  23. package/lib/bmad-extension/skills/ml-architecture/SKILL.md +55 -0
  24. package/lib/bmad-extension/skills/ml-architecture/bmad-skill-manifest.yaml +3 -0
  25. package/lib/bmad-extension/skills/ml-architecture/skill.json +7 -0
  26. package/lib/bmad-extension/skills/ml-detailed-design/.gitkeep +0 -0
  27. package/lib/bmad-extension/skills/ml-detailed-design/SKILL.md +67 -0
  28. package/lib/bmad-extension/skills/ml-detailed-design/bmad-skill-manifest.yaml +3 -0
  29. package/lib/bmad-extension/skills/ml-detailed-design/skill.json +7 -0
  30. package/lib/bmad-extension/skills/ml-eda/.gitkeep +0 -0
  31. package/lib/bmad-extension/skills/ml-eda/SKILL.md +56 -0
  32. package/lib/bmad-extension/skills/ml-eda/bmad-skill-manifest.yaml +3 -0
  33. package/lib/bmad-extension/skills/ml-eda/scripts/baseline_classifier.py +522 -0
  34. package/lib/bmad-extension/skills/ml-eda/scripts/class_weights_calculator.py +295 -0
  35. package/lib/bmad-extension/skills/ml-eda/scripts/clustering_explorer.py +383 -0
  36. package/lib/bmad-extension/skills/ml-eda/scripts/eda_analyzer.py +654 -0
  37. package/lib/bmad-extension/skills/ml-eda/skill.json +7 -0
  38. package/lib/bmad-extension/skills/ml-experiment/.gitkeep +0 -0
  39. package/lib/bmad-extension/skills/ml-experiment/SKILL.md +74 -0
  40. package/lib/bmad-extension/skills/ml-experiment/assets/advanced_trainer_configs.py +430 -0
  41. package/lib/bmad-extension/skills/ml-experiment/assets/quick_trainer_setup.py +233 -0
  42. package/lib/bmad-extension/skills/ml-experiment/assets/template_datamodule.py +219 -0
  43. package/lib/bmad-extension/skills/ml-experiment/assets/template_gnn_module.py +341 -0
  44. package/lib/bmad-extension/skills/ml-experiment/assets/template_lightning_module.py +158 -0
  45. package/lib/bmad-extension/skills/ml-experiment/bmad-skill-manifest.yaml +3 -0
  46. package/lib/bmad-extension/skills/ml-experiment/skill.json +7 -0
  47. package/lib/bmad-extension/skills/ml-hparam/.gitkeep +0 -0
  48. package/lib/bmad-extension/skills/ml-hparam/SKILL.md +81 -0
  49. package/lib/bmad-extension/skills/ml-hparam/bmad-skill-manifest.yaml +3 -0
  50. package/lib/bmad-extension/skills/ml-hparam/skill.json +7 -0
  51. package/lib/bmad-extension/skills/ml-ideation/.gitkeep +0 -0
  52. package/lib/bmad-extension/skills/ml-ideation/SKILL.md +50 -0
  53. package/lib/bmad-extension/skills/ml-ideation/bmad-skill-manifest.yaml +3 -0
  54. package/lib/bmad-extension/skills/ml-ideation/scripts/validate_ml_prd.py +287 -0
  55. package/lib/bmad-extension/skills/ml-ideation/skill.json +7 -0
  56. package/lib/bmad-extension/skills/ml-infra/.gitkeep +0 -0
  57. package/lib/bmad-extension/skills/ml-infra/SKILL.md +58 -0
  58. package/lib/bmad-extension/skills/ml-infra/bmad-skill-manifest.yaml +3 -0
  59. package/lib/bmad-extension/skills/ml-infra/skill.json +7 -0
  60. package/lib/bmad-extension/skills/ml-retrospective/.gitkeep +0 -0
  61. package/lib/bmad-extension/skills/ml-retrospective/SKILL.md +63 -0
  62. package/lib/bmad-extension/skills/ml-retrospective/bmad-skill-manifest.yaml +3 -0
  63. package/lib/bmad-extension/skills/ml-retrospective/skill.json +7 -0
  64. package/lib/bmad-extension/skills/ml-revision/.gitkeep +0 -0
  65. package/lib/bmad-extension/skills/ml-revision/SKILL.md +82 -0
  66. package/lib/bmad-extension/skills/ml-revision/bmad-skill-manifest.yaml +3 -0
  67. package/lib/bmad-extension/skills/ml-revision/skill.json +7 -0
  68. package/lib/bmad-extension/skills/ml-techspec/.gitkeep +0 -0
  69. package/lib/bmad-extension/skills/ml-techspec/SKILL.md +80 -0
  70. package/lib/bmad-extension/skills/ml-techspec/bmad-skill-manifest.yaml +3 -0
  71. package/lib/bmad-extension/skills/ml-techspec/skill.json +7 -0
  72. package/lib/bmad.js +85 -8
  73. package/lib/skill-authoring.js +1 -1
  74. package/package.json +2 -2
  75. package/test/agent-injection-strategy.test.js +4 -4
  76. package/test/bmad-version-bump.test.js +34 -34
  77. package/test/build-bmad-args.test.js +13 -6
  78. package/test/convert-agents-to-skills.test.js +11 -1
  79. package/test/extension-module-restructure.test.js +31 -7
  80. package/test/migration-validation.test.js +14 -11
package/lib/bmad-extension/skills/ml-experiment/assets/quick_trainer_setup.py
@@ -0,0 +1,233 @@
+"""
+quick_trainer_setup.py — BMAD DL Lifecycle
+Ready-to-run Lightning Trainer configuration for standard DL training runs.
+
+Covers: callbacks (early stopping, checkpointing, LR monitor),
+loggers (CSV + optional TensorBoard/W&B), and hardware-aware device selection.
+
+Usage:
+    python3 assets/quick_trainer_setup.py          # prints recommended config
+    python3 assets/quick_trainer_setup.py --run    # launches a training run (demo)
+
+Or import and call build_trainer() in your training script:
+
+    from assets.quick_trainer_setup import build_trainer
+    from src.models.your_model import YourModel
+    from src.data.your_datamodule import YourDataModule
+
+    trainer = build_trainer(max_epochs=50, experiment_name="run_001")
+    model = YourModel(num_classes=2)
+    dm = YourDataModule(data_dir="data/")
+    trainer.fit(model, dm)
+    trainer.test(model, dm)
+"""
+
+from __future__ import annotations
+
+import argparse
+import sys
+from pathlib import Path
+
+
+# ── Lightning import ──────────────────────────────────────────────────────────
+
+try:
+    import lightning as L
+    from lightning.pytorch.callbacks import (
+        EarlyStopping, ModelCheckpoint, LearningRateMonitor, RichProgressBar,
+    )
+    from lightning.pytorch.loggers import CSVLogger
+    LIGHTNING_PKG = "lightning"
+except ImportError:
+    try:
+        import pytorch_lightning as L
+        from pytorch_lightning.callbacks import (
+            EarlyStopping, ModelCheckpoint, LearningRateMonitor,
+        )
+        from pytorch_lightning.loggers import CSVLogger
+        RichProgressBar = None
+        LIGHTNING_PKG = "pytorch_lightning"
+    except ImportError:
+        L = None  # type: ignore
+        LIGHTNING_PKG = None
+
+
+def _detect_accelerator() -> tuple[str, int]:
+    """Return (accelerator, devices) based on available hardware, with explicit status output."""
+    try:
+        import torch
+        if torch.cuda.is_available():
+            device_name = torch.cuda.get_device_name(0)
+            vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+            print(f"GPU: {device_name} ({vram_gb:.1f} GB VRAM) — using CUDA")
+            return "gpu", torch.cuda.device_count()
+        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            print("GPU: Apple MPS — using Metal Performance Shaders")
+            return "mps", 1
+    except ImportError:
+        pass
+    print("WARNING: No GPU detected — training will run on CPU and be significantly slower.")
+    print("         If you expected a GPU, check your CUDA installation and driver.")
+    return "cpu", 1
+
+
+def build_trainer(
+    max_epochs: int = 50,
+    experiment_name: str = "experiment",
+    version: str | None = None,
+    log_dir: str | Path = "logs/",
+    monitor_metric: str = "val/loss",
+    monitor_mode: str = "min",
+    early_stopping_patience: int = 10,
+    gradient_clip_val: float = 1.0,
+    accumulate_grad_batches: int = 1,
+    precision: str = "16-mixed",
+) -> "L.Trainer":
+    """
+    Build a Lightning Trainer with standard callbacks and logging.
+
+    Args:
+        max_epochs: Maximum training epochs.
+        experiment_name: Name used for checkpoint dir and log subdir.
+        version: Run identifier appended to log path (e.g. "fold_0", "run_001").
+            Prevents different runs from overwriting each other's TensorBoard logs.
+        log_dir: Root directory for logs and checkpoints.
+        monitor_metric: Metric to monitor for early stopping and checkpointing.
+        monitor_mode: "min" (for loss) or "max" (for accuracy/F1).
+        early_stopping_patience: Stop after N epochs without improvement.
+        gradient_clip_val: Max gradient norm (0.0 to disable clipping).
+        accumulate_grad_batches: Simulate larger batch size via gradient accumulation.
+        precision: Training precision ("32", "16-mixed", "bf16-mixed").
+
+    Returns:
+        Configured Lightning Trainer.
+    """
+    if L is None:
+        raise ImportError(
+            "PyTorch Lightning not installed.\n"
+            "  pip install lightning  (recommended)\n"
+            "  or: pip install pytorch-lightning"
+        )
+
+    log_dir = Path(log_dir)
+    ckpt_dir = log_dir / "checkpoints" / experiment_name
+    ckpt_dir.mkdir(parents=True, exist_ok=True)
+
+    accelerator, devices = _detect_accelerator()
+
+    # ── Callbacks ─────────────────────────────────────────────────────────────
+    callbacks = [
+        ModelCheckpoint(
+            dirpath=ckpt_dir,
+            filename=f"{experiment_name}-{{epoch:02d}}-{{{monitor_metric}:.4f}}",
+            monitor=monitor_metric,
+            mode=monitor_mode,
+            save_top_k=3,
+            save_last=True,
+            verbose=True,
+        ),
+        EarlyStopping(
+            monitor=monitor_metric,
+            mode=monitor_mode,
+            patience=early_stopping_patience,
+            verbose=True,
+        ),
+        LearningRateMonitor(logging_interval="epoch"),
+    ]
+    if RichProgressBar is not None:
+        callbacks.append(RichProgressBar())
+
+    # ── Loggers ───────────────────────────────────────────────────────────────
+    # TensorBoard is required — install with: uv add tensorboard
+    # version= keeps each fold/run in its own subdir so they never overwrite each other
+    if LIGHTNING_PKG == "lightning":
+        from lightning.pytorch.loggers import TensorBoardLogger
+    else:
+        from pytorch_lightning.loggers import TensorBoardLogger
+
+    tb_logger = TensorBoardLogger(save_dir=str(log_dir), name=experiment_name, version=version)
+    csv_logger = CSVLogger(save_dir=str(log_dir), name=experiment_name, version=version)
+    loggers = [tb_logger, csv_logger]
+
+    log_path = Path(log_dir) / experiment_name / (version or f"version_{tb_logger.version}")
+    print(f"Logs → {log_path}/   run: tensorboard --logdir={log_dir}")
+
+    # ── Precision ─────────────────────────────────────────────────────────────
+    # Fall back to 32-bit on CPU (mixed precision not supported)
+    if accelerator == "cpu" and precision != "32":
+        precision = "32"
+
+    trainer = L.Trainer(
+        max_epochs=max_epochs,
+        accelerator=accelerator,
+        devices=devices,
+        precision=precision,
+        gradient_clip_val=gradient_clip_val if gradient_clip_val > 0 else None,
+        accumulate_grad_batches=accumulate_grad_batches,
+        callbacks=callbacks,
+        logger=loggers,
+        log_every_n_steps=10,
+        deterministic=False,  # set True for full reproducibility (slower)
+    )
+
+    return trainer
+
+
+def print_config(max_epochs: int, experiment_name: str, log_dir: str) -> None:
+    accelerator, devices = _detect_accelerator()
+    print(f"""
+┌─────────────────────────────────────────────────────┐
+│ BMAD DL — Quick Trainer Configuration               │
+├─────────────────────────────────────────────────────┤
+│ Lightning package : {LIGHTNING_PKG or 'NOT INSTALLED':<30} │
+│ Hardware          : {accelerator.upper()} ({devices} device(s)){'':<19} │
+│ Max epochs        : {max_epochs:<30} │
+│ Experiment name   : {experiment_name:<30} │
+│ Log directory     : {log_dir:<30} │
+├─────────────────────────────────────────────────────┤
+│ Callbacks active:                                   │
+│   ✓ ModelCheckpoint (top-3 + last)                  │
+│   ✓ EarlyStopping (patience=10)                     │
+│   ✓ LearningRateMonitor                             │
+│   ✓ RichProgressBar (if available)                  │
+│ Loggers active:                                     │
+│   ✓ CSVLogger                                       │
+│   ✓ TensorBoardLogger (if tensorboard installed)    │
+└─────────────────────────────────────────────────────┘
+
+Quick start in your training script:
+
+    from assets.quick_trainer_setup import build_trainer
+    trainer = build_trainer(max_epochs=50, experiment_name="run_001")
+    trainer.fit(model, datamodule)
+    trainer.test(model, datamodule)
+
+After training, parse results with:
+
+    python3 scripts/parse_training_logs.py \\
+        logs/{experiment_name}/version_0/metrics.csv \\
+        docs/prd/01_PRD.md
+""")
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser(description="Quick Trainer Setup — BMAD DL Lifecycle")
+    parser.add_argument("--run", action="store_true", help="Launch a demo training run")
+    parser.add_argument("--max-epochs", type=int, default=50)
+    parser.add_argument("--experiment-name", type=str, default="run_001")
+    parser.add_argument("--log-dir", type=str, default="logs/")
+    args = parser.parse_args()
+
+    if args.run:
+        if L is None:
+            print("Error: PyTorch Lightning not installed.", file=sys.stderr)
+            return 2
+        print("Demo run requires a model and datamodule. See module docstring.")
+        return 1
+
+    print_config(args.max_epochs, args.experiment_name, args.log_dir)
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
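
The version argument above is what keeps repeated runs and cross-validation folds from overwriting each other's TensorBoard logs. As a minimal sketch of that usage (the cv_run name is arbitrary, and YourModel / YourDataModule are the placeholders from the module docstring, not real classes in this package):

    from assets.quick_trainer_setup import build_trainer
    from src.models.your_model import YourModel
    from src.data.your_datamodule import YourDataModule

    for k in range(5):
        # each fold logs under logs/cv_run/fold_<k>/ instead of a shared version_0
        trainer = build_trainer(
            max_epochs=50,
            experiment_name="cv_run",
            version=f"fold_{k}",
        )
        trainer.fit(YourModel(num_classes=2), YourDataModule(data_dir="data/"))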
package/lib/bmad-extension/skills/ml-experiment/assets/template_datamodule.py
@@ -0,0 +1,219 @@
+"""
+template_datamodule.py — BMAD DL Lifecycle
+PyTorch Lightning LightningDataModule template.
+
+Handles train/val/test dataset loading, transforms, and DataLoader creation
+in a clean, reproducible, and Lightning-compatible way.
+
+Usage:
+    Copy to src/data/your_datamodule.py and implement the TODO sections.
+    Then pass it directly to the Trainer — no manual DataLoaders needed.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional
+
+import torch
+from torch.utils.data import DataLoader, Dataset, random_split
+
+try:
+    import lightning as L
+    LightningDataModule = L.LightningDataModule
+except ImportError:
+    try:
+        import pytorch_lightning as pl
+        LightningDataModule = pl.LightningDataModule
+    except ImportError:
+        raise ImportError("Install PyTorch Lightning: pip install lightning")
+
+try:
+    from torchvision import transforms
+    HAS_TORCHVISION = True
+except ImportError:
+    HAS_TORCHVISION = False
+
+
+# ── TODO: Define or import your Dataset ───────────────────────────────────────
+# Replace this stub with your actual Dataset class.
+
+class YourDataset(Dataset):
+    """
+    Stub dataset — replace with your implementation.
+
+    Expected output per __getitem__: (input_tensor, label_tensor)
+    """
+
+    def __init__(self, data_dir: Path, split: str = "train", transform=None):
+        self.data_dir = data_dir
+        self.split = split
+        self.transform = transform
+
+        # TODO: load file list, annotations, CSV rows, etc.
+        self.samples: list = []  # list of (path_or_data, label)
+
+    def __len__(self) -> int:
+        return len(self.samples)
+
+    def __getitem__(self, idx: int):
+        sample, label = self.samples[idx]
+        # TODO: load image/array/features from `sample`
+        # x = Image.open(sample).convert("RGB")
+        # if self.transform:
+        #     x = self.transform(x)
+        # return x, label
+        raise NotImplementedError("Implement __getitem__ in your Dataset")
+
+# ── END TODO ──────────────────────────────────────────────────────────────────
+
+
+class YourDataModule(LightningDataModule):
+    """
+    Template LightningDataModule.
+
+    Replace 'YourDataModule' with a descriptive name (e.g. DefectDataModule).
+
+    Args:
+        data_dir: Root directory of your dataset.
+        batch_size: Batch size for all DataLoaders.
+        num_workers: Number of worker processes for data loading.
+        val_split: Fraction of training data to use for validation
+            (only used when no explicit val/ directory exists).
+        seed: Random seed for reproducibility.
+        image_size: (H, W) for image resizing — set None to skip.
+    """
+
+    def __init__(
+        self,
+        data_dir: str | Path = "data/",
+        batch_size: int = 32,
+        num_workers: int = 4,
+        val_split: float = 0.15,
+        seed: int = 42,
+        image_size: tuple[int, int] | None = (224, 224),
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+        self.data_dir = Path(data_dir)
+
+        # Built in setup()
+        self.train_dataset: Optional[Dataset] = None
+        self.val_dataset: Optional[Dataset] = None
+        self.test_dataset: Optional[Dataset] = None
+
+    # ── Transforms ────────────────────────────────────────────────────────────
+
+    def _train_transform(self):
+        """
+        TODO: Define augmentation pipeline for training.
+        """
+        if not HAS_TORCHVISION:
+            return None
+        steps = []
+        if self.hparams.image_size:
+            steps.append(transforms.Resize(self.hparams.image_size))
+        steps += [
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomRotation(10),
+            transforms.ColorJitter(brightness=0.2, contrast=0.2),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                 std=[0.229, 0.224, 0.225]),
+        ]
+        return transforms.Compose(steps)
+
+    def _eval_transform(self):
+        """
+        TODO: Define deterministic transform for val/test (no augmentation).
+        """
+        if not HAS_TORCHVISION:
+            return None
+        steps = []
+        if self.hparams.image_size:
+            steps.append(transforms.Resize(self.hparams.image_size))
+        steps += [
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                 std=[0.229, 0.224, 0.225]),
+        ]
+        return transforms.Compose(steps)
+
+    # ── Setup ─────────────────────────────────────────────────────────────────
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        """
+        Called by Lightning before fit/test. Initializes dataset splits.
+
+        stage: "fit" (train+val), "test", "predict", or None (all).
+        """
+        # TODO: Adjust split detection logic for your directory layout.
+        # Option A — explicit split directories: data/train/, data/val/, data/test/
+        has_split_dirs = (
+            (self.data_dir / "train").exists() and
+            (self.data_dir / "val").exists()
+        )
+
+        if stage in (None, "fit"):
+            if has_split_dirs:
+                self.train_dataset = YourDataset(
+                    self.data_dir / "train", split="train",
+                    transform=self._train_transform(),
+                )
+                self.val_dataset = YourDataset(
+                    self.data_dir / "val", split="val",
+                    transform=self._eval_transform(),
+                )
+            else:
+                # Option B — random split from one directory (note: the val subset inherits the train transform here)
+                full_dataset = YourDataset(
+                    self.data_dir, split="train",
+                    transform=self._train_transform(),
+                )
+                val_size = int(len(full_dataset) * self.hparams.val_split)
+                train_size = len(full_dataset) - val_size
+                self.train_dataset, self.val_dataset = random_split(
+                    full_dataset,
+                    [train_size, val_size],
+                    generator=torch.Generator().manual_seed(self.hparams.seed),
+                )
+
+        if stage in (None, "test"):
+            test_dir = self.data_dir / "test"
+            if test_dir.exists():
+                self.test_dataset = YourDataset(
+                    test_dir, split="test",
+                    transform=self._eval_transform(),
+                )
+
+    # ── DataLoaders ───────────────────────────────────────────────────────────
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.hparams.batch_size,
+            shuffle=True,
+            num_workers=self.hparams.num_workers,
+            pin_memory=True,
+            drop_last=True,
+        )
+
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.hparams.batch_size,
+            shuffle=False,
+            num_workers=self.hparams.num_workers,
+            pin_memory=True,
+        )
+
+    def test_dataloader(self) -> DataLoader:
+        if self.test_dataset is None:
+            raise RuntimeError("No test dataset found. Check data_dir/test/ exists.")
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.hparams.batch_size,
+            shuffle=False,
+            num_workers=self.hparams.num_workers,
+            pin_memory=True,
+        )
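
Once the TODO sections are implemented, this template plugs straight into the Trainer built by quick_trainer_setup.py above. A minimal sketch, assuming the file has been copied to src/data/your_datamodule.py as its docstring suggests, and that model is a LightningModule (for instance one derived from template_lightning_module.py, also added in this release):

    from assets.quick_trainer_setup import build_trainer
    from src.data.your_datamodule import YourDataModule

    dm = YourDataModule(data_dir="data/", batch_size=32, image_size=(224, 224))
    trainer = build_trainer(max_epochs=50, experiment_name="run_001")
    trainer.fit(model, dm)   # Lightning calls dm.setup("fit") and the dataloaders itself
    trainer.test(model, dm)  # uses data/test/ if present; test_dataloader() raises otherwise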