wavedl-1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +43 -0
- wavedl/hpo.py +366 -0
- wavedl/models/__init__.py +86 -0
- wavedl/models/_template.py +157 -0
- wavedl/models/base.py +173 -0
- wavedl/models/cnn.py +249 -0
- wavedl/models/convnext.py +425 -0
- wavedl/models/densenet.py +406 -0
- wavedl/models/efficientnet.py +236 -0
- wavedl/models/registry.py +104 -0
- wavedl/models/resnet.py +555 -0
- wavedl/models/unet.py +304 -0
- wavedl/models/vit.py +372 -0
- wavedl/test.py +1069 -0
- wavedl/train.py +1079 -0
- wavedl/utils/__init__.py +151 -0
- wavedl/utils/config.py +269 -0
- wavedl/utils/cross_validation.py +509 -0
- wavedl/utils/data.py +1220 -0
- wavedl/utils/distributed.py +138 -0
- wavedl/utils/losses.py +216 -0
- wavedl/utils/metrics.py +1236 -0
- wavedl/utils/optimizers.py +216 -0
- wavedl/utils/schedulers.py +251 -0
- wavedl-1.2.0.dist-info/LICENSE +21 -0
- wavedl-1.2.0.dist-info/METADATA +991 -0
- wavedl-1.2.0.dist-info/RECORD +30 -0
- wavedl-1.2.0.dist-info/WHEEL +5 -0
- wavedl-1.2.0.dist-info/entry_points.txt +4 -0
- wavedl-1.2.0.dist-info/top_level.txt +1 -0
wavedl/__init__.py
ADDED
@@ -0,0 +1,43 @@
"""
WaveDL - Deep Learning Framework for Wave-Based Inverse Problems
=================================================================

A scalable deep learning framework for wave-based inverse problems,
from ultrasonic NDE and geophysics to biomedical tissue characterization.

Quick Start:
    from wavedl.models import build_model, list_models
    from wavedl.utils import prepare_data, load_test_data

For training:
    wavedl-train --model cnn --data_path train.npz
    # or: python -m wavedl.train --model cnn --data_path train.npz

For inference:
    wavedl-test --checkpoint best_checkpoint --data_path test.npz
    # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
"""

__version__ = "1.2.0"
__author__ = "Ductho Le"
__email__ = "ductho.le@outlook.com"

# Re-export key APIs for convenience
from wavedl.models import build_model, get_model, list_models, register_model
from wavedl.utils import (
    load_test_data,
    load_training_data,
    prepare_data,
)


__all__ = [
    "__version__",
    "build_model",
    "get_model",
    "list_models",
    "load_test_data",
    "load_training_data",
    "prepare_data",
    "register_model",
]
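Because of the re-exports above, everything in the Quick Start is also reachable from the package root. A minimal sketch of that workflow follows; the in_shape/out_size keyword arguments are assumptions carried over from the registry docstring shown later in this diff, not confirmed signatures:

import wavedl

# Package metadata and registry helpers re-exported at the root
print(wavedl.__version__)    # "1.2.0"
print(wavedl.list_models())  # registered names, e.g. "cnn", "resnet18", ...

# Build a model by registry name (kwargs assumed from the
# get_model example in wavedl/models/__init__.py below)
model = wavedl.build_model("cnn", in_shape=(500, 500), out_size=5)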
wavedl/hpo.py
ADDED
@@ -0,0 +1,366 @@
"""
WaveDL - Hyperparameter Optimization with Optuna
=================================================
Automated hyperparameter search for finding optimal training configurations.

Usage:
    # Basic HPO (50 trials)
    python hpo.py --data_path train.npz --n_trials 50

    # Quick search (fewer parameters)
    python hpo.py --data_path train.npz --n_trials 30 --quick

    # Full search with specific models
    python hpo.py --data_path train.npz --n_trials 100 --models cnn resnet18 efficientnet_b0

    # Parallel trials on multiple GPUs
    python hpo.py --data_path train.npz --n_trials 100 --n_jobs 4

Author: Ductho Le (ductho.le@outlook.com)
"""

import argparse
import json
import subprocess
import sys
import tempfile
from pathlib import Path


try:
    import optuna
    from optuna.trial import TrialState
except ImportError:
    print("Error: Optuna not installed. Run: pip install -e '.[hpo]'")
    sys.exit(1)


# =============================================================================
# DEFAULT SEARCH SPACES
# =============================================================================

DEFAULT_MODELS = ["cnn", "resnet18", "resnet34"]
QUICK_MODELS = ["cnn"]

# All 6 optimizers
DEFAULT_OPTIMIZERS = ["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"]
QUICK_OPTIMIZERS = ["adamw"]

# All 8 schedulers
DEFAULT_SCHEDULERS = [
    "plateau",
    "cosine",
    "cosine_restarts",
    "onecycle",
    "step",
    "multistep",
    "exponential",
    "linear_warmup",
]
QUICK_SCHEDULERS = ["plateau"]

# All 6 losses
DEFAULT_LOSSES = ["mse", "mae", "huber", "smooth_l1", "log_cosh", "weighted_mse"]
QUICK_LOSSES = ["mse"]


# =============================================================================
# OBJECTIVE FUNCTION
# =============================================================================


def create_objective(args):
    """Create Optuna objective function with configurable search space."""

    def objective(trial):
        # Select search space based on mode
        # CLI arguments always take precedence over defaults
        if args.quick:
            models = args.models or QUICK_MODELS
            optimizers = args.optimizers or QUICK_OPTIMIZERS
            schedulers = args.schedulers or QUICK_SCHEDULERS
            losses = args.losses or QUICK_LOSSES
        else:
            models = args.models or DEFAULT_MODELS
            optimizers = args.optimizers or DEFAULT_OPTIMIZERS
            schedulers = args.schedulers or DEFAULT_SCHEDULERS
            losses = args.losses or DEFAULT_LOSSES

        # Suggest hyperparameters
        model = trial.suggest_categorical("model", models)
        lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
        batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
        optimizer = trial.suggest_categorical("optimizer", optimizers)
        scheduler = trial.suggest_categorical("scheduler", schedulers)
        loss = trial.suggest_categorical("loss", losses)
        weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
        patience = trial.suggest_int("patience", 10, 30, step=5)

        # Conditional hyperparameters
        if loss == "huber":
            huber_delta = trial.suggest_float("huber_delta", 0.1, 2.0)
        else:
            huber_delta = None

        if optimizer == "sgd":
            momentum = trial.suggest_float("momentum", 0.8, 0.99)
        else:
            momentum = None

        # Build command
        cmd = [
            sys.executable,
            "-m",
            "wavedl.train",
            "--data_path",
            str(args.data_path),
            "--model",
            model,
            "--lr",
            str(lr),
            "--batch_size",
            str(batch_size),
            "--optimizer",
            optimizer,
            "--scheduler",
            scheduler,
            "--loss",
            loss,
            "--weight_decay",
            str(weight_decay),
            "--patience",
            str(patience),
            "--epochs",
            str(args.max_epochs),
            "--seed",
            str(args.seed),
        ]

        # Add conditional args
        if huber_delta:
            cmd.extend(["--huber_delta", str(huber_delta)])
        if momentum:
            cmd.extend(["--momentum", str(momentum)])

        # Use temporary directory for trial output
        with tempfile.TemporaryDirectory() as tmpdir:
            cmd.extend(["--output_dir", tmpdir])

            # Run training
            try:
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    text=True,
                    timeout=args.timeout,
                    cwd=Path(__file__).parent,
                )

                # Parse validation loss from output
                # Look for "Best val_loss: X.XXXX" in stdout
                val_loss = None
                for line in result.stdout.split("\n"):
                    if "Best val_loss:" in line:
                        try:
                            val_loss = float(line.split(":")[-1].strip())
                        except ValueError:
                            pass
                    # Also check for final validation loss
                    if "val_loss=" in line.lower():
                        try:
                            # Extract number after val_loss=
                            parts = line.lower().split("val_loss=")
                            if len(parts) > 1:
                                val_str = parts[1].split()[0].strip(",")
                                val_loss = float(val_str)
                        except (ValueError, IndexError):
                            pass

                if val_loss is None:
                    # Training failed or no loss found
                    print(f"Trial {trial.number}: Training failed")
                    return float("inf")

                print(f"Trial {trial.number}: val_loss={val_loss:.6f}")
                return val_loss

            except subprocess.TimeoutExpired:
                print(f"Trial {trial.number}: Timeout after {args.timeout}s")
                return float("inf")
            except Exception as e:
                print(f"Trial {trial.number}: Error - {e}")
                return float("inf")

    return objective


# =============================================================================
# MAIN FUNCTION
# =============================================================================


def main():
    parser = argparse.ArgumentParser(
        description="WaveDL Hyperparameter Optimization with Optuna",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python hpo.py --data_path train.npz --n_trials 50
  python hpo.py --data_path train.npz --n_trials 30 --quick
  python hpo.py --data_path train.npz --n_trials 100 --models cnn resnet18
        """,
    )

    # Required
    parser.add_argument(
        "--data_path", type=str, required=True, help="Path to training data"
    )

    # HPO settings
    parser.add_argument(
        "--n_trials", type=int, default=50, help="Number of HPO trials (default: 50)"
    )
    parser.add_argument(
        "--n_jobs", type=int, default=1, help="Parallel trials (default: 1)"
    )
    parser.add_argument(
        "--quick",
        action="store_true",
        help="Quick mode: search fewer parameters",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=3600,
        help="Timeout per trial in seconds (default: 3600)",
    )

    # Search space customization
    parser.add_argument(
        "--models",
        nargs="+",
        default=None,
        help=f"Models to search (default: {DEFAULT_MODELS})",
    )
    parser.add_argument(
        "--optimizers",
        nargs="+",
        default=None,
        help=f"Optimizers to search (default: {DEFAULT_OPTIMIZERS})",
    )
    parser.add_argument(
        "--schedulers",
        nargs="+",
        default=None,
        help=f"Schedulers to search (default: {DEFAULT_SCHEDULERS})",
    )
    parser.add_argument(
        "--losses",
        nargs="+",
        default=None,
        help=f"Losses to search (default: {DEFAULT_LOSSES})",
    )

    # Training settings for each trial
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=50,
        help="Max epochs per trial (default: 50, use early stopping)",
    )
    parser.add_argument(
        "--seed", type=int, default=2025, help="Random seed (default: 2025)"
    )

    # Output
    parser.add_argument(
        "--output",
        type=str,
        default="hpo_results.json",
        help="Output file for best params (default: hpo_results.json)",
    )
    parser.add_argument(
        "--study_name",
        type=str,
        default="wavedl_hpo",
        help="Optuna study name (default: wavedl_hpo)",
    )

    args = parser.parse_args()

    # Validate data path
    if not Path(args.data_path).exists():
        print(f"Error: Data file not found: {args.data_path}")
        sys.exit(1)

    # Create study
    print("=" * 60)
    print("WaveDL Hyperparameter Optimization")
    print("=" * 60)
    print(f"Data: {args.data_path}")
    print(f"Trials: {args.n_trials}")
    print(f"Mode: {'Quick' if args.quick else 'Full'}")
    print(f"Parallel jobs: {args.n_jobs}")
    print("=" * 60)

    study = optuna.create_study(
        study_name=args.study_name,
        direction="minimize",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10),
    )

    # Run optimization
    objective = create_objective(args)
    study.optimize(
        objective,
        n_trials=args.n_trials,
        n_jobs=args.n_jobs,
        show_progress_bar=True,
    )

    # Results
    print("\n" + "=" * 60)
    print("OPTIMIZATION COMPLETE")
    print("=" * 60)

    # Filter completed trials
    completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]

    if not completed_trials:
        print("No trials completed successfully.")
        sys.exit(1)

    print(f"\nCompleted trials: {len(completed_trials)}/{args.n_trials}")
    print(f"Best trial: #{study.best_trial.number}")
    print(f"Best val_loss: {study.best_value:.6f}")

    print("\nBest hyperparameters:")
    for key, value in study.best_params.items():
        print(f"  {key}: {value}")

    # Save results
    results = {
        "best_value": study.best_value,
        "best_params": study.best_params,
        "n_trials": len(completed_trials),
        "study_name": args.study_name,
    }

    with open(args.output, "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to: {args.output}")

    # Print command to train with best params
    print("\n" + "=" * 60)
    print("TO TRAIN WITH BEST PARAMETERS:")
    print("=" * 60)
    cmd_parts = ["accelerate launch train.py"]
    cmd_parts.append(f"--data_path {args.data_path}")
    for key, value in study.best_params.items():
        cmd_parts.append(f"--{key} {value}")
    print(" \\\n  ".join(cmd_parts))


if __name__ == "__main__":
    main()
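The results dict written at the end of main() gives hpo_results.json four top-level keys: best_value, best_params, n_trials, and study_name. A short sketch of reading it back for scripting, with the file name matching the --output default (the printed values are illustrative, not real output):

import json

# Read back the summary written by hpo.py; keys mirror the
# `results` dict assembled in main() above
with open("hpo_results.json") as f:
    results = json.load(f)

best = results["best_params"]  # e.g. {"model": "resnet18", "lr": 3e-4, ...}
print(f"best val_loss: {results['best_value']:.6f}")
print(f"retrain with:  --model {best['model']} --lr {best['lr']}")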
wavedl/models/__init__.py
ADDED

@@ -0,0 +1,86 @@
"""
Model Registry and Factory Pattern for Deep Learning Architectures
===================================================================

This module provides a centralized registry for neural network architectures,
enabling dynamic model selection via command-line arguments.

Usage:
    from wavedl.models import get_model, list_models, MODEL_REGISTRY

    # List available models
    print(list_models())

    # Get a model class by name
    ModelClass = get_model("cnn")
    model = ModelClass(in_shape=(500, 500), out_size=5)

Adding New Models:
    1. Create a new file in models/ (e.g., models/my_model.py)
    2. Inherit from BaseModel
    3. Use the @register_model decorator

Example:
    from wavedl.models.base import BaseModel
    from wavedl.models.registry import register_model

    @register_model("my_model")
    class MyModel(BaseModel):
        def __init__(self, in_shape, out_size, **kwargs):
            super().__init__(in_shape, out_size)
            ...

Author: Ductho Le (ductho.le@outlook.com)
Version: 1.0.0
"""

# Import registry first (no dependencies)
# Import base class (depends only on torch)
from .base import BaseModel

# Import model implementations (triggers registration via decorators)
from .cnn import CNN
from .convnext import ConvNeXtBase_, ConvNeXtSmall, ConvNeXtTiny
from .densenet import DenseNet121, DenseNet169
from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2
from .registry import (
    MODEL_REGISTRY,
    build_model,
    get_model,
    list_models,
    register_model,
)
from .resnet import ResNet18, ResNet34, ResNet50
from .unet import UNet, UNetRegression
from .vit import ViTBase_, ViTSmall, ViTTiny


# Export public API
__all__ = [
    # Models
    "CNN",
    # Registry
    "MODEL_REGISTRY",
    # Base class
    "BaseModel",
    "ConvNeXtBase_",
    "ConvNeXtSmall",
    "ConvNeXtTiny",
    "DenseNet121",
    "DenseNet169",
    "EfficientNetB0",
    "EfficientNetB1",
    "EfficientNetB2",
    "ResNet18",
    "ResNet34",
    "ResNet50",
    "UNet",
    "UNetRegression",
    "ViTBase_",
    "ViTSmall",
    "ViTTiny",
    "build_model",
    "get_model",
    "list_models",
    "register_model",
]
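The body of wavedl/models/registry.py (+104 lines per the file listing) is not shown in this section, so the following is only a rough sketch of the decorator-based pattern its imports imply; the names match the public API above, but every internal detail here is an assumption:

# Hypothetical sketch of the registry module; the real
# wavedl/models/registry.py may differ in details.
MODEL_REGISTRY: dict[str, type] = {}


def register_model(name: str):
    """Class decorator that adds a model class to MODEL_REGISTRY."""

    def decorator(cls):
        if name in MODEL_REGISTRY:
            raise KeyError(f"Model '{name}' is already registered")
        MODEL_REGISTRY[name] = cls
        return cls

    return decorator


def get_model(name: str) -> type:
    """Look up a registered model class by name."""
    try:
        return MODEL_REGISTRY[name]
    except KeyError:
        raise KeyError(f"Unknown model '{name}'. Available: {list_models()}")


def list_models() -> list[str]:
    """Return registered model names, sorted alphabetically."""
    return sorted(MODEL_REGISTRY)


def build_model(name: str, **kwargs):
    """Instantiate a registered model class with the given kwargs."""
    return get_model(name)(**kwargs)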
wavedl/models/_template.py
ADDED

@@ -0,0 +1,157 @@
"""
Model Template for New Architectures
=====================================

Copy this file and modify to add new model architectures to the framework.
The model will be automatically registered and available via --model flag.

Steps to Add a New Model:
    1. Copy this file to models/your_model.py
    2. Rename the class and update @register_model("your_model")
    3. Implement the __init__ and forward methods
    4. Import your model in models/__init__.py:
       from wavedl.models.your_model import YourModel
    5. Run: accelerate launch train.py --model your_model --wandb

Author: Ductho Le (ductho.le@outlook.com)
Version: 1.0.0
"""

from typing import Any

import torch
import torch.nn as nn

from wavedl.models.base import BaseModel


# Uncomment the decorator to register this model
# (requires: from wavedl.models.registry import register_model)
# @register_model("template")
class TemplateModel(BaseModel):
    """
    Template Model Architecture.

    Replace this docstring with your model description.
    The first line will appear in --list_models output.

    Args:
        in_shape: Input spatial dimensions (H, W)
        out_size: Number of regression output targets
        hidden_dim: Size of hidden layers (default: 256)
        num_layers: Number of convolutional layers (default: 4)
        dropout: Dropout rate (default: 0.1)

    Input Shape:
        (B, 1, H, W) - Single-channel images

    Output Shape:
        (B, out_size) - Regression predictions
    """

    def __init__(
        self,
        in_shape: tuple[int, int],
        out_size: int,
        hidden_dim: int = 256,
        num_layers: int = 4,
        dropout: float = 0.1,
        **kwargs,  # Accept extra kwargs for flexibility
    ):
        # REQUIRED: Call parent __init__ with in_shape and out_size
        super().__init__(in_shape, out_size)

        # Store hyperparameters as attributes (optional but recommended)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout_rate = dropout

        # =================================================================
        # BUILD YOUR ARCHITECTURE HERE
        # =================================================================

        # Example: Simple CNN encoder
        self.encoder = nn.Sequential(
            # Layer 1
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # Layer 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # Layer 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # Layer 4
            nn.Conv2d(128, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),  # Global average pooling
        )

        # Example: Regression head
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, out_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        REQUIRED: Must accept (B, C, H, W) and return (B, out_size)

        Args:
            x: Input tensor of shape (B, 1, H, W)

        Returns:
            Output tensor of shape (B, out_size)
        """
        # Encode
        features = self.encoder(x)

        # Predict
        output = self.head(features)

        return output

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """
        Return default hyperparameters for this model.

        OPTIONAL: Override to provide model-specific defaults.
        These can be used for documentation or config files.
        """
        return {
            "hidden_dim": 256,
            "num_layers": 4,
            "dropout": 0.1,
        }


# =============================================================================
# USAGE EXAMPLE
# =============================================================================
if __name__ == "__main__":
    # Quick test of the model
    model = TemplateModel(in_shape=(500, 500), out_size=5)

    # Print model summary
    print(f"Model: {model.__class__.__name__}")
    print(f"Parameters: {model.count_parameters():,}")
    print(f"Default config: {model.get_default_config()}")

    # Test forward pass
    dummy_input = torch.randn(2, 1, 500, 500)
    output = model(dummy_input)
    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")
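Following the five steps in the template's docstring, a registered model built from this skeleton would look something like the sketch below. The file name my_model.py and the MLP body are hypothetical illustrations; only BaseModel, register_model, and the (in_shape, out_size) contract come from the modules shown above:

# wavedl/models/my_model.py - hypothetical example following the template
import torch
import torch.nn as nn

from wavedl.models.base import BaseModel
from wavedl.models.registry import register_model


@register_model("my_model")
class MyModel(BaseModel):
    """Tiny MLP baseline: flattens the input and regresses directly."""

    def __init__(self, in_shape, out_size, hidden_dim: int = 128, **kwargs):
        super().__init__(in_shape, out_size)
        h, w = in_shape
        self.net = nn.Sequential(
            nn.Flatten(),                    # (B, 1, H, W) -> (B, H*W)
            nn.Linear(h * w, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_size),  # (B, out_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

After adding "from wavedl.models.my_model import MyModel" to models/__init__.py (step 4), the model becomes selectable via --model my_model.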