wavedl 1.5.7-py3-none-any.whl → 1.6.1-py3-none-any.whl
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +451 -451
- wavedl/models/__init__.py +80 -4
- wavedl/models/_pretrained_utils.py +366 -0
- wavedl/models/base.py +48 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/cnn.py +2 -27
- wavedl/models/convnext.py +113 -51
- wavedl/models/convnext_v2.py +488 -0
- wavedl/models/densenet.py +10 -23
- wavedl/models/efficientnet.py +6 -6
- wavedl/models/efficientnetv2.py +315 -315
- wavedl/models/efficientvit.py +398 -0
- wavedl/models/fastvit.py +252 -0
- wavedl/models/mamba.py +555 -0
- wavedl/models/maxvit.py +254 -0
- wavedl/models/mobilenetv3.py +295 -295
- wavedl/models/regnet.py +406 -406
- wavedl/models/resnet.py +19 -61
- wavedl/models/resnet3d.py +258 -258
- wavedl/models/swin.py +443 -443
- wavedl/models/tcn.py +393 -409
- wavedl/models/unet.py +2 -6
- wavedl/models/unireplknet.py +491 -0
- wavedl/models/vit.py +9 -9
- wavedl/train.py +1430 -1425
- wavedl/utils/config.py +367 -367
- wavedl/utils/cross_validation.py +530 -530
- wavedl/utils/data.py +39 -6
- wavedl/utils/losses.py +216 -216
- wavedl/utils/optimizers.py +216 -216
- wavedl/utils/schedulers.py +251 -251
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
- wavedl-1.6.1.dist-info/RECORD +46 -0
- wavedl-1.5.7.dist-info/RECORD +0 -38
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/utils/optimizers.py
CHANGED
@@ -1,216 +1,216 @@

All 216 lines are shown as removed and re-added with identical content on both sides of the hunk, so the module is listed once:

"""
Optimizers for Deep Learning Training
=====================================

Provides a comprehensive set of optimizers with a factory function
for easy selection via CLI arguments.

Supported Optimizers:
- adamw: AdamW (default, best for most cases)
- adam: Adam (legacy)
- sgd: SGD with momentum
- nadam: NAdam (Adam + Nesterov momentum)
- radam: RAdam (variance-adaptive Adam)
- rmsprop: RMSprop (good for RNNs)

Author: Ductho Le (ductho.le@outlook.com)
Version: 1.0.0
"""

from collections.abc import Iterator

import torch
import torch.optim as optim
from torch.nn import Parameter


# ==============================================================================
# OPTIMIZER REGISTRY
# ==============================================================================
def list_optimizers() -> list[str]:
    """
    Return list of available optimizer names.

    Returns:
        List of registered optimizer names
    """
    return ["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"]


def get_optimizer(
    name: str,
    params: Iterator[Parameter] | list[dict],
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    # SGD-specific
    momentum: float = 0.9,
    nesterov: bool = False,
    dampening: float = 0.0,
    # Adam/AdamW-specific
    betas: tuple = (0.9, 0.999),
    eps: float = 1e-8,
    amsgrad: bool = False,
    # RMSprop-specific
    alpha: float = 0.99,
    centered: bool = False,
    **kwargs,
) -> optim.Optimizer:
    """
    Factory function to create optimizer by name.

    Args:
        name: Optimizer name (see list_optimizers())
        params: Model parameters or parameter groups
        lr: Learning rate
        weight_decay: Weight decay (L2 penalty)
        momentum: Momentum factor (SGD, RMSprop)
        nesterov: Enable Nesterov momentum (SGD)
        dampening: Dampening for momentum (SGD)
        betas: Coefficients for computing running averages (Adam variants)
        eps: Term for numerical stability (Adam variants, RMSprop)
        amsgrad: Use AMSGrad variant (Adam variants)
        alpha: Smoothing constant (RMSprop)
        centered: Compute centered RMSprop
        **kwargs: Additional optimizer-specific arguments

    Returns:
        Instantiated optimizer

    Raises:
        ValueError: If optimizer name is not recognized

    Example:
        >>> optimizer = get_optimizer("adamw", model.parameters(), lr=1e-3)
        >>> optimizer = get_optimizer("sgd", model.parameters(), lr=1e-2, momentum=0.9)
    """
    name_lower = name.lower()

    if name_lower == "adamw":
        return optim.AdamW(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            **kwargs,
        )

    elif name_lower == "adam":
        return optim.Adam(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            **kwargs,
        )

    elif name_lower == "sgd":
        return optim.SGD(
            params,
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            **kwargs,
        )

    elif name_lower == "nadam":
        return optim.NAdam(
            params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, **kwargs
        )

    elif name_lower == "radam":
        return optim.RAdam(
            params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, **kwargs
        )

    elif name_lower == "rmsprop":
        return optim.RMSprop(
            params,
            lr=lr,
            alpha=alpha,
            eps=eps,
            weight_decay=weight_decay,
            momentum=momentum,
            centered=centered,
            **kwargs,
        )

    else:
        available = ", ".join(list_optimizers())
        raise ValueError(f"Unknown optimizer: '{name}'. Available options: {available}")


def get_optimizer_with_param_groups(
    name: str,
    model: torch.nn.Module,
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    no_decay_keywords: list[str] = None,
    **kwargs,
) -> optim.Optimizer:
    """
    Create optimizer with automatic parameter grouping.

    Separates parameters into decay and no-decay groups based on
    parameter names. By default, bias and normalization layer parameters
    are excluded from weight decay.

    Args:
        name: Optimizer name
        model: PyTorch model
        lr: Learning rate
        weight_decay: Weight decay for applicable parameters
        no_decay_keywords: Keywords to identify no-decay parameters
            Default: ['bias', 'norm', 'bn', 'ln']
        **kwargs: Additional optimizer arguments

    Returns:
        Optimizer with parameter groups configured

    Example:
        >>> optimizer = get_optimizer_with_param_groups(
        ...     "adamw", model, lr=1e-3, weight_decay=1e-4
        ... )
    """
    if no_decay_keywords is None:
        no_decay_keywords = ["bias", "norm", "bn", "ln"]

    decay_params = []
    no_decay_params = []

    for name_param, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # Check if any no-decay keyword is in parameter name
        if any(kw in name_param.lower() for kw in no_decay_keywords):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    param_groups = []
    if decay_params:
        param_groups.append(
            {
                "params": decay_params,
                "weight_decay": weight_decay,
            }
        )
    if no_decay_params:
        param_groups.append(
            {
                "params": no_decay_params,
                "weight_decay": 0.0,
            }
        )

    # Fall back to all parameters if grouping fails
    if not param_groups:
        param_groups = [{"params": model.parameters()}]

    return get_optimizer(name, param_groups, lr=lr, weight_decay=0.0, **kwargs)
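For orientation, here is a minimal usage sketch of the two factory functions in this module. The `TinyNet` model and the printed group summary are illustrative assumptions, not part of wavedl; the import path simply follows the file location shown above.

```python
# Illustrative sketch only: TinyNet is an assumption, not part of the package.
import torch
import torch.nn as nn

from wavedl.utils.optimizers import get_optimizer, get_optimizer_with_param_groups


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.norm = nn.LayerNorm(32)  # "norm" in the name -> excluded from decay
        self.fc2 = nn.Linear(32, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.norm(self.fc1(x))))


model = TinyNet()

# Plain factory: every parameter shares the same weight decay.
opt = get_optimizer("adamw", model.parameters(), lr=1e-3, weight_decay=1e-4)

# Grouped factory: bias and normalization parameters land in a group with
# weight_decay=0.0; the remaining weights keep the requested decay.
opt_grouped = get_optimizer_with_param_groups("adamw", model, lr=1e-3, weight_decay=1e-4)
for i, group in enumerate(opt_grouped.param_groups):
    print(f"group {i}: {len(group['params'])} tensors, weight_decay={group['weight_decay']}")
```

Excluding biases and normalization parameters from weight decay, which is what the keyword-based grouping implements, is the usual convention for AdamW-style training.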