wavedl-1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +43 -0
- wavedl/hpo.py +366 -0
- wavedl/models/__init__.py +86 -0
- wavedl/models/_template.py +157 -0
- wavedl/models/base.py +173 -0
- wavedl/models/cnn.py +249 -0
- wavedl/models/convnext.py +425 -0
- wavedl/models/densenet.py +406 -0
- wavedl/models/efficientnet.py +236 -0
- wavedl/models/registry.py +104 -0
- wavedl/models/resnet.py +555 -0
- wavedl/models/unet.py +304 -0
- wavedl/models/vit.py +372 -0
- wavedl/test.py +1069 -0
- wavedl/train.py +1079 -0
- wavedl/utils/__init__.py +151 -0
- wavedl/utils/config.py +269 -0
- wavedl/utils/cross_validation.py +509 -0
- wavedl/utils/data.py +1220 -0
- wavedl/utils/distributed.py +138 -0
- wavedl/utils/losses.py +216 -0
- wavedl/utils/metrics.py +1236 -0
- wavedl/utils/optimizers.py +216 -0
- wavedl/utils/schedulers.py +251 -0
- wavedl-1.2.0.dist-info/LICENSE +21 -0
- wavedl-1.2.0.dist-info/METADATA +991 -0
- wavedl-1.2.0.dist-info/RECORD +30 -0
- wavedl-1.2.0.dist-info/WHEEL +5 -0
- wavedl-1.2.0.dist-info/entry_points.txt +4 -0
- wavedl-1.2.0.dist-info/top_level.txt +1 -0
wavedl/utils/optimizers.py
@@ -0,0 +1,216 @@
"""
Optimizers for Deep Learning Training
=====================================

Provides a comprehensive set of optimizers with a factory function
for easy selection via CLI arguments.

Supported Optimizers:
    - adamw: AdamW (default, best for most cases)
    - adam: Adam (legacy)
    - sgd: SGD with momentum
    - nadam: NAdam (Adam + Nesterov momentum)
    - radam: RAdam (variance-adaptive Adam)
    - rmsprop: RMSprop (good for RNNs)

Author: Ductho Le (ductho.le@outlook.com)
Version: 1.0.0
"""

from collections.abc import Iterator

import torch
import torch.optim as optim
from torch.nn import Parameter


# ==============================================================================
# OPTIMIZER REGISTRY
# ==============================================================================
def list_optimizers() -> list[str]:
    """
    Return list of available optimizer names.

    Returns:
        List of registered optimizer names
    """
    return ["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"]


def get_optimizer(
    name: str,
    params: Iterator[Parameter] | list[dict],
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    # SGD-specific
    momentum: float = 0.9,
    nesterov: bool = False,
    dampening: float = 0.0,
    # Adam/AdamW-specific
    betas: tuple = (0.9, 0.999),
    eps: float = 1e-8,
    amsgrad: bool = False,
    # RMSprop-specific
    alpha: float = 0.99,
    centered: bool = False,
    **kwargs,
) -> optim.Optimizer:
    """
    Factory function to create optimizer by name.

    Args:
        name: Optimizer name (see list_optimizers())
        params: Model parameters or parameter groups
        lr: Learning rate
        weight_decay: Weight decay (L2 penalty)
        momentum: Momentum factor (SGD, RMSprop)
        nesterov: Enable Nesterov momentum (SGD)
        dampening: Dampening for momentum (SGD)
        betas: Coefficients for computing running averages (Adam variants)
        eps: Term for numerical stability (Adam variants, RMSprop)
        amsgrad: Use AMSGrad variant (Adam variants)
        alpha: Smoothing constant (RMSprop)
        centered: Compute centered RMSprop
        **kwargs: Additional optimizer-specific arguments

    Returns:
        Instantiated optimizer

    Raises:
        ValueError: If optimizer name is not recognized

    Example:
        >>> optimizer = get_optimizer("adamw", model.parameters(), lr=1e-3)
        >>> optimizer = get_optimizer("sgd", model.parameters(), lr=1e-2, momentum=0.9)
    """
    name_lower = name.lower()

    if name_lower == "adamw":
        return optim.AdamW(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            **kwargs,
        )

    elif name_lower == "adam":
        return optim.Adam(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            **kwargs,
        )

    elif name_lower == "sgd":
        return optim.SGD(
            params,
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            **kwargs,
        )

    elif name_lower == "nadam":
        return optim.NAdam(
            params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, **kwargs
        )

    elif name_lower == "radam":
        return optim.RAdam(
            params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, **kwargs
        )

    elif name_lower == "rmsprop":
        return optim.RMSprop(
            params,
            lr=lr,
            alpha=alpha,
            eps=eps,
            weight_decay=weight_decay,
            momentum=momentum,
            centered=centered,
            **kwargs,
        )

    else:
        available = ", ".join(list_optimizers())
        raise ValueError(f"Unknown optimizer: '{name}'. Available options: {available}")


def get_optimizer_with_param_groups(
    name: str,
    model: torch.nn.Module,
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    no_decay_keywords: list[str] = None,
    **kwargs,
) -> optim.Optimizer:
    """
    Create optimizer with automatic parameter grouping.

    Separates parameters into decay and no-decay groups based on
    parameter names. By default, bias and normalization layer parameters
    are excluded from weight decay.

    Args:
        name: Optimizer name
        model: PyTorch model
        lr: Learning rate
        weight_decay: Weight decay for applicable parameters
        no_decay_keywords: Keywords to identify no-decay parameters
            Default: ['bias', 'norm', 'bn', 'ln']
        **kwargs: Additional optimizer arguments

    Returns:
        Optimizer with parameter groups configured

    Example:
        >>> optimizer = get_optimizer_with_param_groups(
        ...     "adamw", model, lr=1e-3, weight_decay=1e-4
        ... )
    """
    if no_decay_keywords is None:
        no_decay_keywords = ["bias", "norm", "bn", "ln"]

    decay_params = []
    no_decay_params = []

    for name_param, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # Check if any no-decay keyword is in parameter name
        if any(kw in name_param.lower() for kw in no_decay_keywords):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    param_groups = []
    if decay_params:
        param_groups.append(
            {
                "params": decay_params,
                "weight_decay": weight_decay,
            }
        )
    if no_decay_params:
        param_groups.append(
            {
                "params": no_decay_params,
                "weight_decay": 0.0,
            }
        )

    # Fall back to all parameters if grouping fails
    if not param_groups:
        param_groups = [{"params": model.parameters()}]

    return get_optimizer(name, param_groups, lr=lr, weight_decay=0.0, **kwargs)
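For orientation, the two factories above are meant to compose: get_optimizer covers the flat case where every parameter shares one weight_decay, while get_optimizer_with_param_groups first splits parameters into decay and no-decay groups by name and then delegates to get_optimizer. The sketch below is illustrative only and is not part of the packaged files; it assumes the module is importable as wavedl.utils.optimizers (adjust to whatever wavedl/utils/__init__.py actually re-exports), and TinyNet is a hypothetical model used just to show the grouping.

# Illustrative usage sketch -- not part of the wheel contents above.
# Assumes the import path `wavedl.utils.optimizers`; TinyNet is a placeholder model.
import torch
import torch.nn as nn

from wavedl.utils.optimizers import (
    get_optimizer,
    get_optimizer_with_param_groups,
    list_optimizers,
)


class TinyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.bn1 = nn.BatchNorm1d(32)
        self.fc2 = nn.Linear(32, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.bn1(self.fc1(x))))


model = TinyNet()
print(list_optimizers())  # ['adamw', 'adam', 'sgd', 'nadam', 'radam', 'rmsprop']

# Flat case: every trainable parameter shares the same weight decay.
opt = get_optimizer("adamw", model.parameters(), lr=1e-3, weight_decay=1e-4)

# Grouped case: parameters whose names contain 'bias'/'bn'/'norm'/'ln'
# (fc1.bias, bn1.weight, bn1.bias, fc2.bias) land in a weight_decay=0.0 group;
# the remaining weights keep the requested decay.
opt_grouped = get_optimizer_with_param_groups("adamw", model, lr=1e-3, weight_decay=1e-4)
for group in opt_grouped.param_groups:
    print(len(group["params"]), group["weight_decay"])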
wavedl/utils/schedulers.py
@@ -0,0 +1,251 @@
"""
Learning Rate Schedulers
========================

Provides a comprehensive set of learning rate schedulers with a factory
function for easy selection via CLI arguments.

Supported Schedulers:
    - plateau: ReduceLROnPlateau (default, adaptive)
    - cosine: CosineAnnealingLR
    - cosine_restarts: CosineAnnealingWarmRestarts
    - onecycle: OneCycleLR
    - step: StepLR
    - multistep: MultiStepLR
    - exponential: ExponentialLR
    - linear_warmup: LinearLR (warmup phase)

Author: Ductho Le (ductho.le@outlook.com)
Version: 1.0.0
"""

import torch.optim as optim
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    ExponentialLR,
    LinearLR,
    LRScheduler,
    MultiStepLR,
    OneCycleLR,
    ReduceLROnPlateau,
    SequentialLR,
    StepLR,
)


# ==============================================================================
# SCHEDULER REGISTRY
# ==============================================================================
def list_schedulers() -> list[str]:
    """
    Return list of available scheduler names.

    Returns:
        List of registered scheduler names
    """
    return [
        "plateau",
        "cosine",
        "cosine_restarts",
        "onecycle",
        "step",
        "multistep",
        "exponential",
        "linear_warmup",
    ]


def get_scheduler(
    name: str,
    optimizer: optim.Optimizer,
    # Common parameters
    epochs: int = 100,
    steps_per_epoch: int | None = None,
    min_lr: float = 1e-6,
    # ReduceLROnPlateau parameters
    patience: int = 10,
    factor: float = 0.5,
    # Cosine parameters
    T_max: int | None = None,
    T_0: int = 10,
    T_mult: int = 2,
    # OneCycleLR parameters
    max_lr: float | None = None,
    pct_start: float = 0.3,
    # Step/MultiStep parameters
    step_size: int = 30,
    milestones: list[int] | None = None,
    gamma: float = 0.1,
    # Linear warmup parameters
    warmup_epochs: int = 5,
    start_factor: float = 0.1,
    **kwargs,
) -> LRScheduler:
    """
    Factory function to create learning rate scheduler by name.

    Args:
        name: Scheduler name (see list_schedulers())
        optimizer: Optimizer instance to schedule
        epochs: Total training epochs (for cosine, onecycle)
        steps_per_epoch: Steps per epoch (required for onecycle)
        min_lr: Minimum learning rate (eta_min for cosine)
        patience: Patience for ReduceLROnPlateau
        factor: Reduction factor for plateau/step
        T_max: Period for CosineAnnealingLR (default: epochs)
        T_0: Initial period for CosineAnnealingWarmRestarts
        T_mult: Period multiplier for warm restarts
        max_lr: Maximum LR for OneCycleLR (default: optimizer's initial LR)
        pct_start: Percentage of cycle spent increasing LR (OneCycleLR)
        step_size: Period for StepLR
        milestones: Epochs to decay LR for MultiStepLR
        gamma: Decay factor for step/multistep/exponential
        warmup_epochs: Number of warmup epochs for linear_warmup
        start_factor: Starting LR factor for warmup (LR * start_factor)
        **kwargs: Additional arguments passed to scheduler

    Returns:
        Instantiated learning rate scheduler

    Raises:
        ValueError: If scheduler name is not recognized

    Example:
        >>> scheduler = get_scheduler("plateau", optimizer, patience=15)
        >>> scheduler = get_scheduler("cosine", optimizer, epochs=100)
        >>> scheduler = get_scheduler(
        ...     "onecycle", optimizer, epochs=100, steps_per_epoch=1000, max_lr=1e-3
        ... )
    """
    name_lower = name.lower().replace("-", "_")

    # Get initial LR from optimizer
    base_lr = optimizer.param_groups[0]["lr"]

    if name_lower == "plateau":
        return ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=factor,
            patience=patience,
            min_lr=min_lr,
            **kwargs,
        )

    elif name_lower == "cosine":
        return CosineAnnealingLR(
            optimizer,
            T_max=T_max if T_max is not None else epochs,
            eta_min=min_lr,
            **kwargs,
        )

    elif name_lower == "cosine_restarts":
        return CosineAnnealingWarmRestarts(
            optimizer, T_0=T_0, T_mult=T_mult, eta_min=min_lr, **kwargs
        )

    elif name_lower == "onecycle":
        if steps_per_epoch is None:
            raise ValueError(
                "OneCycleLR requires 'steps_per_epoch'. "
                "Pass len(train_dataloader) as steps_per_epoch."
            )
        return OneCycleLR(
            optimizer,
            max_lr=max_lr if max_lr is not None else base_lr,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            pct_start=pct_start,
            **kwargs,
        )

    elif name_lower == "step":
        return StepLR(optimizer, step_size=step_size, gamma=gamma, **kwargs)

    elif name_lower == "multistep":
        if milestones is None:
            # Default milestones at 30%, 60%, 90% of epochs
            milestones = [int(epochs * 0.3), int(epochs * 0.6), int(epochs * 0.9)]
        return MultiStepLR(optimizer, milestones=milestones, gamma=gamma, **kwargs)

    elif name_lower == "exponential":
        return ExponentialLR(optimizer, gamma=gamma, **kwargs)

    elif name_lower == "linear_warmup":
        return LinearLR(
            optimizer,
            start_factor=start_factor,
            end_factor=1.0,
            total_iters=warmup_epochs,
            **kwargs,
        )

    else:
        available = ", ".join(list_schedulers())
        raise ValueError(f"Unknown scheduler: '{name}'. Available options: {available}")


def get_scheduler_with_warmup(
    name: str,
    optimizer: optim.Optimizer,
    warmup_epochs: int = 5,
    start_factor: float = 0.1,
    **kwargs,
) -> LRScheduler:
    """
    Create a scheduler with linear warmup phase.

    Combines LinearLR warmup with any other scheduler using SequentialLR.

    Args:
        name: Main scheduler name (after warmup)
        optimizer: Optimizer instance
        warmup_epochs: Number of warmup epochs
        start_factor: Starting LR factor for warmup
        **kwargs: Arguments for main scheduler (see get_scheduler)

    Returns:
        SequentialLR combining warmup and main scheduler

    Example:
        >>> scheduler = get_scheduler_with_warmup(
        ...     "cosine", optimizer, warmup_epochs=5, epochs=100
        ... )
    """
    # Create warmup scheduler
    warmup_scheduler = LinearLR(
        optimizer,
        start_factor=start_factor,
        end_factor=1.0,
        total_iters=warmup_epochs,
    )

    # Create main scheduler
    main_scheduler = get_scheduler(name, optimizer, **kwargs)

    # Combine with SequentialLR
    return SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, main_scheduler],
        milestones=[warmup_epochs],
    )


def is_epoch_based(name: str) -> bool:
    """
    Check if scheduler should be stepped per epoch (True) or per batch (False).

    Args:
        name: Scheduler name

    Returns:
        True if scheduler should step per epoch, False for per batch
    """
    name_lower = name.lower().replace("-", "_")

    # OneCycleLR steps per batch, all others step per epoch
    per_batch_schedulers = {"onecycle"}

    return name_lower not in per_batch_schedulers
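The scheduler factory pairs with the optimizer factory above: is_epoch_based tells a training loop whether scheduler.step() belongs at the end of each epoch or inside the batch loop (only "onecycle" steps per batch), and "plateau" is the one option that expects a validation metric when stepped. A short sketch under the same assumptions (illustrative only, not part of the packaged files; import paths and the loop variables are placeholders):

# Illustrative usage sketch -- not part of the wheel contents above.
# Assumes the import paths `wavedl.utils.optimizers` / `wavedl.utils.schedulers`.
import torch.nn as nn

from wavedl.utils.optimizers import get_optimizer
from wavedl.utils.schedulers import get_scheduler, get_scheduler_with_warmup, is_epoch_based

model = nn.Linear(8, 1)
optimizer = get_optimizer("adamw", model.parameters(), lr=1e-3)

# Cosine annealing preceded by a 5-epoch linear warmup (SequentialLR under the hood).
scheduler_name = "cosine"
scheduler = get_scheduler_with_warmup(scheduler_name, optimizer, warmup_epochs=5, epochs=100)

for epoch in range(100):
    # ... run the training batches here; a per-batch scheduler ("onecycle")
    # would call scheduler.step() inside that batch loop instead ...
    optimizer.step()  # stand-in for the real per-batch parameter updates
    if is_epoch_based(scheduler_name):
        scheduler.step()

# ReduceLROnPlateau is the one scheduler that needs a metric when stepping:
plateau = get_scheduler("plateau", optimizer, patience=10, factor=0.5)
val_loss = 0.123  # placeholder validation metric
plateau.step(val_loss)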
wavedl-1.2.0.dist-info/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Ductho Le

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.