wavedl-1.5.7-py3-none-any.whl → wavedl-1.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. wavedl/__init__.py +1 -1
  2. wavedl/hpo.py +451 -451
  3. wavedl/models/__init__.py +80 -4
  4. wavedl/models/_pretrained_utils.py +366 -0
  5. wavedl/models/base.py +48 -0
  6. wavedl/models/caformer.py +270 -0
  7. wavedl/models/cnn.py +2 -27
  8. wavedl/models/convnext.py +113 -51
  9. wavedl/models/convnext_v2.py +488 -0
  10. wavedl/models/densenet.py +10 -23
  11. wavedl/models/efficientnet.py +6 -6
  12. wavedl/models/efficientnetv2.py +315 -315
  13. wavedl/models/efficientvit.py +398 -0
  14. wavedl/models/fastvit.py +252 -0
  15. wavedl/models/mamba.py +555 -0
  16. wavedl/models/maxvit.py +254 -0
  17. wavedl/models/mobilenetv3.py +295 -295
  18. wavedl/models/regnet.py +406 -406
  19. wavedl/models/resnet.py +19 -61
  20. wavedl/models/resnet3d.py +258 -258
  21. wavedl/models/swin.py +443 -443
  22. wavedl/models/tcn.py +393 -409
  23. wavedl/models/unet.py +2 -6
  24. wavedl/models/unireplknet.py +491 -0
  25. wavedl/models/vit.py +9 -9
  26. wavedl/train.py +1430 -1425
  27. wavedl/utils/config.py +367 -367
  28. wavedl/utils/cross_validation.py +530 -530
  29. wavedl/utils/data.py +39 -6
  30. wavedl/utils/losses.py +216 -216
  31. wavedl/utils/optimizers.py +216 -216
  32. wavedl/utils/schedulers.py +251 -251
  33. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
  34. wavedl-1.6.1.dist-info/RECORD +46 -0
  35. wavedl-1.5.7.dist-info/RECORD +0 -38
  36. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
  37. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
  38. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
  39. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/utils/optimizers.py
@@ -1,216 +1,216 @@
(The removed and added sides of this hunk render with identical content; the file is shown once below.)
"""
Optimizers for Deep Learning Training
=====================================

Provides a comprehensive set of optimizers with a factory function
for easy selection via CLI arguments.

Supported Optimizers:
- adamw: AdamW (default, best for most cases)
- adam: Adam (legacy)
- sgd: SGD with momentum
- nadam: NAdam (Adam + Nesterov momentum)
- radam: RAdam (variance-adaptive Adam)
- rmsprop: RMSprop (good for RNNs)

Author: Ductho Le (ductho.le@outlook.com)
Version: 1.0.0
"""

from collections.abc import Iterator

import torch
import torch.optim as optim
from torch.nn import Parameter


# ==============================================================================
# OPTIMIZER REGISTRY
# ==============================================================================
def list_optimizers() -> list[str]:
    """
    Return list of available optimizer names.

    Returns:
        List of registered optimizer names
    """
    return ["adamw", "adam", "sgd", "nadam", "radam", "rmsprop"]


def get_optimizer(
    name: str,
    params: Iterator[Parameter] | list[dict],
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    # SGD-specific
    momentum: float = 0.9,
    nesterov: bool = False,
    dampening: float = 0.0,
    # Adam/AdamW-specific
    betas: tuple = (0.9, 0.999),
    eps: float = 1e-8,
    amsgrad: bool = False,
    # RMSprop-specific
    alpha: float = 0.99,
    centered: bool = False,
    **kwargs,
) -> optim.Optimizer:
    """
    Factory function to create optimizer by name.

    Args:
        name: Optimizer name (see list_optimizers())
        params: Model parameters or parameter groups
        lr: Learning rate
        weight_decay: Weight decay (L2 penalty)
        momentum: Momentum factor (SGD, RMSprop)
        nesterov: Enable Nesterov momentum (SGD)
        dampening: Dampening for momentum (SGD)
        betas: Coefficients for computing running averages (Adam variants)
        eps: Term for numerical stability (Adam variants, RMSprop)
        amsgrad: Use AMSGrad variant (Adam variants)
        alpha: Smoothing constant (RMSprop)
        centered: Compute centered RMSprop
        **kwargs: Additional optimizer-specific arguments

    Returns:
        Instantiated optimizer

    Raises:
        ValueError: If optimizer name is not recognized

    Example:
        >>> optimizer = get_optimizer("adamw", model.parameters(), lr=1e-3)
        >>> optimizer = get_optimizer("sgd", model.parameters(), lr=1e-2, momentum=0.9)
    """
    name_lower = name.lower()

    if name_lower == "adamw":
        return optim.AdamW(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            **kwargs,
        )

    elif name_lower == "adam":
        return optim.Adam(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            **kwargs,
        )

    elif name_lower == "sgd":
        return optim.SGD(
            params,
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            **kwargs,
        )

    elif name_lower == "nadam":
        return optim.NAdam(
            params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, **kwargs
        )

    elif name_lower == "radam":
        return optim.RAdam(
            params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, **kwargs
        )

    elif name_lower == "rmsprop":
        return optim.RMSprop(
            params,
            lr=lr,
            alpha=alpha,
            eps=eps,
            weight_decay=weight_decay,
            momentum=momentum,
            centered=centered,
            **kwargs,
        )

    else:
        available = ", ".join(list_optimizers())
        raise ValueError(f"Unknown optimizer: '{name}'. Available options: {available}")


def get_optimizer_with_param_groups(
    name: str,
    model: torch.nn.Module,
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    no_decay_keywords: list[str] = None,
    **kwargs,
) -> optim.Optimizer:
    """
    Create optimizer with automatic parameter grouping.

    Separates parameters into decay and no-decay groups based on
    parameter names. By default, bias and normalization layer parameters
    are excluded from weight decay.

    Args:
        name: Optimizer name
        model: PyTorch model
        lr: Learning rate
        weight_decay: Weight decay for applicable parameters
        no_decay_keywords: Keywords to identify no-decay parameters
            Default: ['bias', 'norm', 'bn', 'ln']
        **kwargs: Additional optimizer arguments

    Returns:
        Optimizer with parameter groups configured

    Example:
        >>> optimizer = get_optimizer_with_param_groups(
        ...     "adamw", model, lr=1e-3, weight_decay=1e-4
        ... )
    """
    if no_decay_keywords is None:
        no_decay_keywords = ["bias", "norm", "bn", "ln"]

    decay_params = []
    no_decay_params = []

    for name_param, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # Check if any no-decay keyword is in parameter name
        if any(kw in name_param.lower() for kw in no_decay_keywords):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    param_groups = []
    if decay_params:
        param_groups.append(
            {
                "params": decay_params,
                "weight_decay": weight_decay,
            }
        )
    if no_decay_params:
        param_groups.append(
            {
                "params": no_decay_params,
                "weight_decay": 0.0,
            }
        )

    # Fall back to all parameters if grouping fails
    if not param_groups:
        param_groups = [{"params": model.parameters()}]

    return get_optimizer(name, param_groups, lr=lr, weight_decay=0.0, **kwargs)
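For context, a minimal usage sketch of the two factory functions above (illustration only, not part of the diff). It assumes the module is importable as wavedl.utils.optimizers, matching its path inside the wheel; the TinyNet model is a made-up example chosen so its parameter names exercise the default 'bias'/'norm' keywords.

# Minimal usage sketch; TinyNet is hypothetical, the import path is assumed
# from the file location wavedl/utils/optimizers.py.
import torch
import torch.nn as nn

from wavedl.utils.optimizers import get_optimizer, get_optimizer_with_param_groups


class TinyNet(nn.Module):
    """Toy model whose parameter names trigger the default no-decay keywords."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.norm = nn.BatchNorm1d(32)
        self.fc2 = nn.Linear(32, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.norm(self.fc1(x))))


model = TinyNet()

# Plain factory: a single parameter group with uniform weight decay.
optimizer = get_optimizer("adamw", model.parameters(), lr=1e-3, weight_decay=1e-4)

# Grouped factory: 'fc1.weight' and 'fc2.weight' keep weight_decay=1e-4, while
# the biases and the BatchNorm parameters ('norm.weight', 'norm.bias') fall into
# the weight_decay=0.0 group because their names match the default keywords.
optimizer = get_optimizer_with_param_groups("adamw", model, lr=1e-3, weight_decay=1e-4)
for group in optimizer.param_groups:
    print(group["weight_decay"], sum(p.numel() for p in group["params"]))

# One optimization step with the returned optimizer.
x, y = torch.randn(8, 16), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()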