ma-agents 3.3.0 → 3.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.opencode/skills/.ma-agents.json +99 -99
- package/.roo/skills/.ma-agents.json +99 -99
- package/README.md +19 -1
- package/bin/cli.js +55 -0
- package/lib/agents.js +23 -0
- package/lib/bmad-cache/cache-manifest.json +1 -1
- package/lib/bmad-customizations/bmm-demerzel.customize.yaml +36 -0
- package/lib/bmad-customizations/demerzel.md +32 -0
- package/lib/bmad-extension/module-help.csv +13 -0
- package/lib/bmad-extension/skills/bmad-ma-agent-ml/.gitkeep +0 -0
- package/lib/bmad-extension/skills/bmad-ma-agent-ml/SKILL.md +59 -0
- package/lib/bmad-extension/skills/bmad-ma-agent-ml/bmad-skill-manifest.yaml +11 -0
- package/lib/bmad-extension/skills/generate-backlog/.gitkeep +0 -0
- package/lib/bmad-extension/skills/ml-advise/.gitkeep +0 -0
- package/lib/bmad-extension/skills/ml-advise/SKILL.md +76 -0
- package/lib/bmad-extension/skills/ml-advise/bmad-skill-manifest.yaml +3 -0
- package/lib/bmad-extension/skills/ml-advise/skill.json +7 -0
- package/lib/bmad-extension/skills/ml-analysis/.gitkeep +0 -0
- package/lib/bmad-extension/skills/ml-analysis/SKILL.md +60 -0
- package/lib/bmad-extension/skills/ml-analysis/bmad-skill-manifest.yaml +3 -0
- package/lib/bmad-extension/skills/ml-analysis/skill.json +7 -0
- package/lib/bmad-extension/skills/ml-architecture/.gitkeep +0 -0
- package/lib/bmad-extension/skills/ml-architecture/SKILL.md +55 -0
- package/lib/bmad-extension/skills/ml-architecture/bmad-skill-manifest.yaml +3 -0
- package/lib/bmad-extension/skills/ml-architecture/skill.json +7 -0
- package/lib/bmad-extension/skills/ml-detailed-design/.gitkeep +0 -0
- package/lib/bmad-extension/skills/ml-detailed-design/SKILL.md +67 -0
- package/lib/bmad-extension/skills/ml-detailed-design/bmad-skill-manifest.yaml +3 -0
- package/lib/bmad-extension/skills/ml-detailed-design/skill.json +7 -0
- package/lib/bmad-extension/skills/ml-eda/.gitkeep +0 -0
- package/lib/bmad-extension/skills/ml-eda/SKILL.md +56 -0
- package/lib/bmad-extension/skills/ml-eda/bmad-skill-manifest.yaml +3 -0
- package/lib/bmad-extension/skills/ml-eda/scripts/baseline_classifier.py +522 -0
- package/lib/bmad-extension/skills/ml-eda/scripts/class_weights_calculator.py +295 -0
- package/lib/bmad-extension/skills/ml-eda/scripts/clustering_explorer.py +383 -0
- package/lib/bmad-extension/skills/ml-eda/scripts/eda_analyzer.py +654 -0
- package/lib/bmad-extension/skills/ml-eda/skill.json +7 -0
- package/lib/bmad-extension/skills/ml-experiment/.gitkeep +0 -0
- package/lib/bmad-extension/skills/ml-experiment/SKILL.md +74 -0
- package/lib/bmad-extension/skills/ml-experiment/assets/advanced_trainer_configs.py +430 -0
- package/lib/bmad-extension/skills/ml-experiment/assets/quick_trainer_setup.py +233 -0
- package/lib/bmad-extension/skills/ml-experiment/assets/template_datamodule.py +219 -0
- package/lib/bmad-extension/skills/ml-experiment/assets/template_gnn_module.py +341 -0
- package/lib/bmad-extension/skills/ml-experiment/assets/template_lightning_module.py +158 -0
- package/lib/bmad-extension/skills/ml-experiment/bmad-skill-manifest.yaml +3 -0
- package/lib/bmad-extension/skills/ml-experiment/skill.json +7 -0
- package/lib/bmad-extension/skills/ml-hparam/.gitkeep +0 -0
- package/lib/bmad-extension/skills/ml-hparam/SKILL.md +81 -0
- package/lib/bmad-extension/skills/ml-hparam/bmad-skill-manifest.yaml +3 -0
- package/lib/bmad-extension/skills/ml-hparam/skill.json +7 -0
- package/lib/bmad-extension/skills/ml-ideation/.gitkeep +0 -0
- package/lib/bmad-extension/skills/ml-ideation/SKILL.md +50 -0
- package/lib/bmad-extension/skills/ml-ideation/bmad-skill-manifest.yaml +3 -0
- package/lib/bmad-extension/skills/ml-ideation/scripts/validate_ml_prd.py +287 -0
- package/lib/bmad-extension/skills/ml-ideation/skill.json +7 -0
- package/lib/bmad-extension/skills/ml-infra/.gitkeep +0 -0
- package/lib/bmad-extension/skills/ml-infra/SKILL.md +58 -0
- package/lib/bmad-extension/skills/ml-infra/bmad-skill-manifest.yaml +3 -0
- package/lib/bmad-extension/skills/ml-infra/skill.json +7 -0
- package/lib/bmad-extension/skills/ml-retrospective/.gitkeep +0 -0
- package/lib/bmad-extension/skills/ml-retrospective/SKILL.md +63 -0
- package/lib/bmad-extension/skills/ml-retrospective/bmad-skill-manifest.yaml +3 -0
- package/lib/bmad-extension/skills/ml-retrospective/skill.json +7 -0
- package/lib/bmad-extension/skills/ml-revision/.gitkeep +0 -0
- package/lib/bmad-extension/skills/ml-revision/SKILL.md +82 -0
- package/lib/bmad-extension/skills/ml-revision/bmad-skill-manifest.yaml +3 -0
- package/lib/bmad-extension/skills/ml-revision/skill.json +7 -0
- package/lib/bmad-extension/skills/ml-techspec/.gitkeep +0 -0
- package/lib/bmad-extension/skills/ml-techspec/SKILL.md +80 -0
- package/lib/bmad-extension/skills/ml-techspec/bmad-skill-manifest.yaml +3 -0
- package/lib/bmad-extension/skills/ml-techspec/skill.json +7 -0
- package/lib/bmad.js +85 -8
- package/lib/skill-authoring.js +1 -1
- package/package.json +2 -2
- package/test/agent-injection-strategy.test.js +4 -4
- package/test/bmad-version-bump.test.js +34 -34
- package/test/build-bmad-args.test.js +13 -6
- package/test/convert-agents-to-skills.test.js +11 -1
- package/test/extension-module-restructure.test.js +31 -7
- package/test/migration-validation.test.js +14 -11
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
"""
quick_trainer_setup.py — BMAD DL Lifecycle
Ready-to-run Lightning Trainer configuration for standard DL training runs.

Covers: callbacks (early stopping, checkpointing, LR monitor, rich progress bar
when available), CSV + TensorBoard loggers, and hardware-aware device selection
(CUDA, then Apple MPS, then CPU).

Requires the `lightning` package (or the legacy `pytorch_lightning` package).

Usage:
    python3 assets/quick_trainer_setup.py          # prints recommended config
    python3 assets/quick_trainer_setup.py --run    # demo stub: reports that a run
                                                   # needs your own model/datamodule

Or import and call build_trainer() in your training script:

    from assets.quick_trainer_setup import build_trainer
    from src.models.your_model import YourModel
    from src.data.your_datamodule import YourDataModule

    trainer = build_trainer(max_epochs=50, experiment_name="run_001")
    model = YourModel(num_classes=2)
    dm = YourDataModule(data_dir="data/")
    trainer.fit(model, dm)
    trainer.test(model, dm)
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
import argparse
|
|
28
|
+
import sys
|
|
29
|
+
from pathlib import Path
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# ── Lightning import ──────────────────────────────────────────────────────────
# Prefer the unified `lightning` package and fall back to the legacy
# `pytorch_lightning` name. If neither can be imported, `L` and `LIGHTNING_PKG`
# are set to None so this module still loads (print_config reports
# "NOT INSTALLED"); build_trainer checks `L is None` and raises ImportError.
# NOTE: in that last case the callback/logger names below are left undefined.

try:
    import lightning as L
    from lightning.pytorch.callbacks import (
        EarlyStopping, ModelCheckpoint, LearningRateMonitor, RichProgressBar,
    )
    from lightning.pytorch.loggers import CSVLogger
    LIGHTNING_PKG = "lightning"
except ImportError:
    try:
        import pytorch_lightning as L
        from pytorch_lightning.callbacks import (
            EarlyStopping, ModelCheckpoint, LearningRateMonitor,
        )
        from pytorch_lightning.loggers import CSVLogger
        # RichProgressBar is not imported from the legacy package; None tells
        # build_trainer to skip adding it.
        RichProgressBar = None
        LIGHTNING_PKG = "pytorch_lightning"
    except ImportError:
        L = None  # type: ignore
        LIGHTNING_PKG = None
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _detect_accelerator() -> tuple[str, int]:
|
|
56
|
+
"""Return (accelerator, devices) based on available hardware, with explicit status output."""
|
|
57
|
+
try:
|
|
58
|
+
import torch
|
|
59
|
+
if torch.cuda.is_available():
|
|
60
|
+
device_name = torch.cuda.get_device_name(0)
|
|
61
|
+
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
|
|
62
|
+
print(f"GPU: {device_name} ({vram_gb:.1f} GB VRAM) — using CUDA")
|
|
63
|
+
return "gpu", torch.cuda.device_count()
|
|
64
|
+
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
|
65
|
+
print("GPU: Apple MPS — using Metal Performance Shaders")
|
|
66
|
+
return "mps", 1
|
|
67
|
+
except ImportError:
|
|
68
|
+
pass
|
|
69
|
+
print("WARNING: No GPU detected — training will run on CPU and be significantly slower.")
|
|
70
|
+
print(" If you expected a GPU, check your CUDA installation and driver.")
|
|
71
|
+
return "cpu", 1
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def build_trainer(
    max_epochs: int = 50,
    experiment_name: str = "experiment",
    version: str | None = None,
    log_dir: str | Path = "logs/",
    monitor_metric: str = "val/loss",
    monitor_mode: str = "min",
    early_stopping_patience: int = 10,
    gradient_clip_val: float = 1.0,
    accumulate_grad_batches: int = 1,
    precision: str = "16-mixed",
) -> "L.Trainer":
    """
    Build a Lightning Trainer with standard callbacks and logging.

    Args:
        max_epochs: Maximum training epochs.
        experiment_name: Name used for checkpoint dir and log subdir.
        version: Run identifier appended to log path (e.g. "fold_0", "run_001").
            Prevents different runs from overwriting each other's logs. When None,
            the next free ``version_N`` is chosen once and shared by all loggers.
        log_dir: Root directory for logs and checkpoints.
        monitor_metric: Metric to monitor for early stopping and checkpointing.
        monitor_mode: "min" (for loss) or "max" (for accuracy/F1).
        early_stopping_patience: Stop after N epochs without improvement.
        gradient_clip_val: Max gradient norm (0.0 to disable clipping).
        accumulate_grad_batches: Simulate larger batch size via gradient accumulation.
        precision: Training precision ("32", "16-mixed", "bf16-mixed").
            fp16 variants are replaced by "32" when running on CPU.

    Returns:
        Configured Lightning Trainer.

    Raises:
        ImportError: If neither ``lightning`` nor ``pytorch_lightning`` is installed.
    """
    if L is None:
        raise ImportError(
            "PyTorch Lightning not installed.\n"
            " pip install lightning (recommended)\n"
            " or: pip install pytorch-lightning"
        )

    log_dir = Path(log_dir)
    ckpt_dir = log_dir / "checkpoints" / experiment_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    accelerator, devices = _detect_accelerator()

    # ── Callbacks ─────────────────────────────────────────────────────────────
    # Metric names such as "val/loss" contain "/". With Lightning's automatic
    # "name=value" insertion that "/" would create extra checkpoint
    # sub-directories, so insert a sanitized label ourselves and turn it off.
    metric_label = monitor_metric.replace("/", "_")
    callbacks = [
        ModelCheckpoint(
            dirpath=ckpt_dir,
            filename=(
                f"{experiment_name}-epoch={{epoch:02d}}"
                f"-{metric_label}={{{monitor_metric}:.4f}}"
            ),
            auto_insert_metric_name=False,
            monitor=monitor_metric,
            mode=monitor_mode,
            save_top_k=3,
            save_last=True,
            verbose=True,
        ),
        EarlyStopping(
            monitor=monitor_metric,
            mode=monitor_mode,
            patience=early_stopping_patience,
            verbose=True,
        ),
        LearningRateMonitor(logging_interval="epoch"),
    ]
    if RichProgressBar is not None:
        # The class imports without the optional `rich` package, but constructing
        # it raises ModuleNotFoundError (an ImportError) — keep the default bar then.
        try:
            callbacks.append(RichProgressBar())
        except ImportError:
            pass

    # ── Loggers ───────────────────────────────────────────────────────────────
    # TensorBoard is used when installed (uv add tensorboard); otherwise CSV only.
    try:
        if LIGHTNING_PKG == "lightning":
            from lightning.pytorch.loggers import TensorBoardLogger
        else:
            from pytorch_lightning.loggers import TensorBoardLogger
        tb_logger = TensorBoardLogger(save_dir=str(log_dir), name=experiment_name, version=version)
    except ImportError:
        tb_logger = None
        print("WARNING: tensorboard not installed — logging to CSV only (install with: uv add tensorboard).")

    # Resolve the run version once and hand it to the CSV logger too. Left to
    # themselves, each logger picks "the next free version_N" on its own, and
    # whichever creates its directory first pushes the other to version_N+1.
    run_version = version
    if run_version is None and tb_logger is not None:
        run_version = tb_logger.version
    csv_logger = CSVLogger(save_dir=str(log_dir), name=experiment_name, version=run_version)
    loggers = [csv_logger] if tb_logger is None else [tb_logger, csv_logger]

    log_path = Path(csv_logger.log_dir)
    if tb_logger is not None:
        print(f"Logs → {log_path}/ run: tensorboard --logdir={log_dir}")
    else:
        print(f"Logs → {log_path}/")

    # ── Precision ─────────────────────────────────────────────────────────────
    # fp16 AMP ("16-mixed") is not supported on CPU, so fall back to 32-bit there.
    # bf16 variants work on CPU and are passed through unchanged.
    if accelerator == "cpu" and str(precision).startswith("16"):
        print(f"NOTE: precision={precision!r} is not supported on CPU — using 32-bit.")
        precision = "32"

    trainer = L.Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        devices=devices,
        precision=precision,
        gradient_clip_val=gradient_clip_val if gradient_clip_val > 0 else None,
        accumulate_grad_batches=accumulate_grad_batches,
        callbacks=callbacks,
        logger=loggers,
        log_every_n_steps=10,
        deterministic=False,  # set True for full reproducibility (slower)
    )

    return trainer
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def print_config(
    max_epochs: int,
    experiment_name: str,
    log_dir: str,
    early_stopping_patience: int = 10,
) -> None:
    """
    Print a boxed summary of the recommended trainer setup plus a quick-start snippet.

    Args:
        max_epochs: Maximum training epochs to report.
        experiment_name: Experiment name (log/checkpoint subdirectory).
        log_dir: Root log directory; also used in the suggested metrics.csv path.
        early_stopping_patience: Patience shown for EarlyStopping (10 matches
            build_trainer's default).
    """
    accelerator, devices = _detect_accelerator()

    width = 53  # characters between the left and right borders
    rule = "─" * width

    def row(text: str = "") -> str:
        # Pad every row to the same width so the right border lines up;
        # over-long values extend past it rather than being truncated.
        return "│" + f" {text}".ljust(width) + "│"

    box = [
        f"┌{rule}┐",
        row("BMAD DL — Quick Trainer Configuration"),
        f"├{rule}┤",
        row(f"Lightning package : {LIGHTNING_PKG or 'NOT INSTALLED'}"),
        row(f"Hardware          : {accelerator.upper()} ({devices} device(s))"),
        row(f"Max epochs        : {max_epochs}"),
        row(f"Experiment name   : {experiment_name}"),
        row(f"Log directory     : {log_dir}"),
        f"├{rule}┤",
        row("Callbacks active:"),
        row("  ✓ ModelCheckpoint (top-3 + last)"),
        row(f"  ✓ EarlyStopping (patience={early_stopping_patience})"),
        row("  ✓ LearningRateMonitor"),
        row("  ✓ RichProgressBar (if available)"),
        row("Loggers active:"),
        row("  ✓ CSVLogger"),
        row("  ✓ TensorBoardLogger (if tensorboard installed)"),
        f"└{rule}┘",
    ]

    # Suggest the first auto-numbered run under the *configured* log directory.
    metrics_csv = Path(log_dir) / experiment_name / "version_0" / "metrics.csv"
    quick_start = f"""
Quick start in your training script:

    from assets.quick_trainer_setup import build_trainer
    trainer = build_trainer(max_epochs={max_epochs}, experiment_name={experiment_name!r}, log_dir={log_dir!r})
    trainer.fit(model, datamodule)
    trainer.test(model, datamodule)

After training, parse results with:

    python3 scripts/parse_training_logs.py \\
        {metrics_csv.as_posix()} \\
        docs/prd/01_PRD.md
"""
    print("\n" + "\n".join(box) + "\n" + quick_start)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def main() -> int:
    """
    Command-line entry point.

    Without ``--run`` the recommended configuration is printed (exit code 0).
    With ``--run`` the script only reports what a demo run needs: exit code 2
    when Lightning is missing, otherwise 1.
    """
    cli = argparse.ArgumentParser(description="Quick Trainer Setup — BMAD DL Lifecycle")
    cli.add_argument("--run", action="store_true", help="Launch a demo training run")
    cli.add_argument("--max-epochs", type=int, default=50)
    cli.add_argument("--experiment-name", type=str, default="run_001")
    cli.add_argument("--log-dir", type=str, default="logs/")
    opts = cli.parse_args()

    if not opts.run:
        print_config(opts.max_epochs, opts.experiment_name, opts.log_dir)
        return 0

    if L is None:
        print("Error: PyTorch Lightning not installed.", file=sys.stderr)
        return 2
    print("Demo run requires a model and datamodule. See module docstring.")
    return 1
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
if __name__ == "__main__":
    # SystemExit carries main()'s return value out as the process exit status.
    raise SystemExit(main())
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
"""
template_datamodule.py — BMAD DL Lifecycle
PyTorch Lightning LightningDataModule template.

Handles train/val/test dataset loading, transforms, and DataLoader creation
in a clean, reproducible, and Lightning-compatible way.

Directory layouts understood by YourDataModule.setup() (adjust as needed):
    data_dir/train/ + data_dir/val/   explicit splits (both must exist), or
    data_dir/                         single directory, split by `val_split`
    data_dir/test/                    optional; needed only for test_dataloader()

Usage:
    Copy to src/data/your_datamodule.py and implement the TODO sections.
    Then pass it directly to the Trainer — no manual DataLoaders needed.
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch.utils.data import DataLoader, Dataset, random_split
|
|
20
|
+
|
|
21
|
+
# Prefer the unified `lightning` package, then the legacy `pytorch_lightning`
# name; only the LightningDataModule base class is needed from either.
try:
    import lightning as L
    LightningDataModule = L.LightningDataModule
except ImportError:
    try:
        import pytorch_lightning as pl
        LightningDataModule = pl.LightningDataModule
    except ImportError:
        # Lightning is mandatory here: this module cannot be imported without it.
        raise ImportError("Install PyTorch Lightning: pip install lightning")
|
|
30
|
+
|
|
31
|
+
# torchvision is optional: when it is missing, the DataModule's transform
# builders return None, so datasets are created with transform=None.
try:
    from torchvision import transforms
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# ── TODO: Define or import your Dataset ───────────────────────────────────────
|
|
39
|
+
# Replace this stub with your actual Dataset class.
|
|
40
|
+
|
|
41
|
+
class YourDataset(Dataset):
    """
    Placeholder dataset — swap in your real implementation.

    Each item is meant to be an ``(input_tensor, label_tensor)`` pair.

    Args:
        data_dir: Directory holding this split's data.
        split: Split name, e.g. "train", "val" or "test".
        transform: Optional callable applied to each loaded input.
    """

    def __init__(self, data_dir: Path, split: str = "train", transform=None):
        self.data_dir, self.split, self.transform = data_dir, split, transform

        # TODO: populate from a file list, annotation file, CSV rows, etc.
        # Each entry is a (path_or_data, label) pair.
        self.samples: list = []

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        # Index first, so an out-of-range idx still raises IndexError.
        sample, label = self.samples[idx]
        # TODO: load the raw input referenced by `sample`, for example:
        #   x = Image.open(sample).convert("RGB")
        #   x = self.transform(x) if self.transform else x
        #   return x, label
        raise NotImplementedError("Implement __getitem__ in your Dataset")
|
|
67
|
+
|
|
68
|
+
# ── END TODO ──────────────────────────────────────────────────────────────────
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class YourDataModule(LightningDataModule):
    """
    Template LightningDataModule.

    Replace 'YourDataModule' with a descriptive name (e.g. DefectDataModule).

    Args:
        data_dir: Root directory of your dataset.
        batch_size: Batch size for all DataLoaders.
        num_workers: Number of worker processes for data loading.
        val_split: Fraction of training data to use for validation
            (only used when no explicit val/ directory exists).
        seed: Random seed for reproducibility of the train/val split.
        image_size: (H, W) for image resizing — set None to skip.
    """

    def __init__(
        self,
        data_dir: str | Path = "data/",
        batch_size: int = 32,
        num_workers: int = 4,
        val_split: float = 0.15,
        seed: int = 42,
        image_size: tuple[int, int] | None = (224, 224),
    ):
        super().__init__()
        # Makes every __init__ argument available as self.hparams.<name>.
        self.save_hyperparameters()
        self.data_dir = Path(data_dir)

        # Built in setup()
        self.train_dataset: Optional[Dataset] = None
        self.val_dataset: Optional[Dataset] = None
        self.test_dataset: Optional[Dataset] = None

    # ── Transforms ────────────────────────────────────────────────────────────

    def _train_transform(self):
        """
        TODO: Define augmentation pipeline for training.

        Returns None when torchvision is unavailable.
        """
        if not HAS_TORCHVISION:
            return None
        steps = []
        if self.hparams.image_size:
            steps.append(transforms.Resize(self.hparams.image_size))
        steps += [
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    def _eval_transform(self):
        """
        TODO: Define deterministic transform for val/test (no augmentation).

        Returns None when torchvision is unavailable.
        """
        if not HAS_TORCHVISION:
            return None
        steps = []
        if self.hparams.image_size:
            steps.append(transforms.Resize(self.hparams.image_size))
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    # ── Setup ─────────────────────────────────────────────────────────────────

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Called by Lightning before fit/validate/test. Initializes dataset splits.

        stage: "fit" (train+val), "validate" (train+val), "test", "predict",
            or None (all).
        """
        # TODO: Adjust split detection logic for your directory layout.
        # Option A — explicit split directories: data/train/, data/val/, data/test/
        has_split_dirs = (
            (self.data_dir / "train").exists() and
            (self.data_dir / "val").exists()
        )

        # "validate" is included so trainer.validate(model, dm) gets a val_dataset.
        if stage in (None, "fit", "validate"):
            if has_split_dirs:
                self.train_dataset = YourDataset(
                    self.data_dir / "train", split="train",
                    transform=self._train_transform(),
                )
                self.val_dataset = YourDataset(
                    self.data_dir / "val", split="val",
                    transform=self._eval_transform(),
                )
            else:
                # Option B — random split from a single dataset directory.
                # Two views of the same data: augmented for training, deterministic
                # for validation, so val metrics are not computed on augmented
                # inputs. Assumes YourDataset lists samples in a deterministic
                # order so both views have identical length/ordering — confirm for
                # your implementation.
                train_view = YourDataset(
                    self.data_dir, split="train",
                    transform=self._train_transform(),
                )
                eval_view = YourDataset(
                    self.data_dir, split="train",
                    transform=self._eval_transform(),
                )
                val_size = int(len(train_view) * self.hparams.val_split)
                lengths = [len(train_view) - val_size, val_size]
                # Identically seeded generators produce the same permutation, so
                # the two calls below partition the same indices without overlap.
                self.train_dataset, _ = random_split(
                    train_view, lengths,
                    generator=torch.Generator().manual_seed(self.hparams.seed),
                )
                _, self.val_dataset = random_split(
                    eval_view, lengths,
                    generator=torch.Generator().manual_seed(self.hparams.seed),
                )

        if stage in (None, "test"):
            test_dir = self.data_dir / "test"
            if test_dir.exists():
                self.test_dataset = YourDataset(
                    test_dir, split="test",
                    transform=self._eval_transform(),
                )

    # ── DataLoaders ───────────────────────────────────────────────────────────

    def _make_loader(self, dataset: Dataset, *, shuffle: bool, drop_last: bool = False) -> DataLoader:
        """Create a DataLoader with the shared batch-size/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            num_workers=self.hparams.num_workers,
            # Pinned host memory only helps host→CUDA copies; without CUDA,
            # PyTorch merely warns, so enable it only when CUDA is present.
            pin_memory=torch.cuda.is_available(),
            drop_last=drop_last,
        )

    def train_dataloader(self) -> DataLoader:
        return self._make_loader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        if self.test_dataset is None:
            raise RuntimeError("No test dataset found. Check data_dir/test/ exists.")
        return self._make_loader(self.test_dataset, shuffle=False)
|