ber-equalization-studio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ber_equalization_studio/__init__.py +58 -0
- ber_equalization_studio/_legacy_backend/__init__.py +1 -0
- ber_equalization_studio/_legacy_backend/ber_equalization.py +3700 -0
- ber_equalization_studio/_legacy_backend/efficient_kan/__init__.py +3 -0
- ber_equalization_studio/_legacy_backend/efficient_kan/kan.py +218 -0
- ber_equalization_studio/api.py +348 -0
- ber_equalization_studio/cli.py +92 -0
- ber_equalization_studio/config.py +168 -0
- ber_equalization_studio/data.py +31 -0
- ber_equalization_studio/experiment.py +92 -0
- ber_equalization_studio/legacy.py +149 -0
- ber_equalization_studio/models.py +86 -0
- ber_equalization_studio/results.py +74 -0
- ber_equalization_studio/visualization.py +186 -0
- ber_equalization_studio-0.1.0.dist-info/METADATA +266 -0
- ber_equalization_studio-0.1.0.dist-info/RECORD +19 -0
- ber_equalization_studio-0.1.0.dist-info/WHEEL +5 -0
- ber_equalization_studio-0.1.0.dist-info/entry_points.txt +2 -0
- ber_equalization_studio-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,3700 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import time
|
|
3
|
+
from contextlib import nullcontext
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
6
|
+
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn as nn
|
|
12
|
+
import torch.nn.functional as F
|
|
13
|
+
import torch.optim as optim
|
|
14
|
+
|
|
15
|
+
# Optional dependency: KAN implementation from the `efficient_kan` package.
# If the import fails, the name is bound to None so the module still loads.
try:
    from efficient_kan import KAN as EfficientKAN
except ImportError:
    EfficientKAN = None

# Optional dependency: Mamba block from the `mamba_ssm` package.
# If the import fails, the name is bound to None so the module still loads.
try:
    from mamba_ssm import Mamba
except ImportError:
    Mamba = None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _default_device() -> torch.device:
    """Return the default compute device: CUDA first, then MPS, then CPU.

    ``torch.backends.mps`` does not exist on older PyTorch releases, so it is
    looked up defensively; a missing backend is treated as "not available"
    instead of raising ``AttributeError`` while this module is imported.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


class Config:
    """Module-wide settings, read as plain class attributes (``Config.NAME``).

    The model classes in this module read these values both when they are
    constructed and inside ``forward`` (for example ``SEQ_LEN``, ``INPUT_DIM``
    and ``CONTEXT_K``), so changing an attribute after a model has been built
    also affects that model's forward pass.
    """

    # Device picked once, when the class body is executed.
    DEVICE = _default_device()
    DATA_DIR_CANDIDATES = [Path("symbols_new"), Path("Symbols_1m_1ch_PR"), Path(".")]
    MAX_FILES = 64
    # File-level split: first TRAIN_PORTION files form train+val pool, the rest is a hold-out test set.
    TRAIN_PORTION = 0.97
    VAL_PORTION_WITHIN_TRAIN = 0.10
    MIN_VAL_FILES = 1
    RANDOMIZE_FILE_SPLIT = False
    SPLIT_SEED = 42

    # Window geometry: SEQ_LEN is CONTEXT_K positions on each side of a centre one.
    CONTEXT_K = 32
    SEQ_LEN = 2 * CONTEXT_K + 1
    INPUT_DIM = 2

    # Shared head / recurrent settings.
    HIDDEN_DIM = 96  # original author note: should be 64
    DROPOUT = 0.2
    LSTM_HIDDEN = 64  # original author note: should be 64
    LSTM_LAYERS = 2
    BIDIRECTIONAL = True
    USE_ATTENTION = True
    # TRANSFORMER_* settings.
    TRANSFORMER_DIM = 128
    TRANSFORMER_LAYERS = 4
    TRANSFORMER_HEADS = 4
    TRANSFORMER_FF_DIM = 256
    TRANSFORMER_CONV_KERNEL = 3
    # TCN_* settings.
    TCN_HIDDEN_DIM = 96
    TCN_LAYERS = 5
    TCN_KERNEL_SIZE = 5
    TCN_DILATIONS = [1, 2, 4, 8, 16]
    # MAMBA_* settings.
    MAMBA_DIM = 96
    MAMBA_LAYERS = 4
    MAMBA_D_STATE = 16
    MAMBA_D_CONV = 4
    MAMBA_EXPAND = 2
    # COMPLEX_* settings.
    COMPLEX_CHANNELS = 24
    COMPLEX_BLOCK_CHANNELS = [24, 32]
    COMPLEX_KERNEL_SIZES = [5, 7]
    COMPLEX_HEAD_DIM = 128
    COMPLEX_TEMPORAL_DIM = 96
    COMPLEX_TEMPORAL_DILATIONS = [1, 2, 4]
    COMPLEX_LIGHT_CHANNELS = 48
    COMPLEX_LIGHT_DILATIONS = [1, 2, 4]
    COMPLEX_LIGHT_KERNEL_SIZE = 3
    COMPLEX_SEQ_DIM = 96
    COMPLEX_LSTM_HIDDEN = 64
    COMPLEX_LSTM_LAYERS = 2
    COMPLEX_USE_KERR = False
    COMPLEX_USE_DBP_FRONTEND = False
    COMPLEX_KERR_KERNEL = 5
    COMPLEX_KERR_INIT_GAMMA = 0.02
    # DBP_* settings.
    DBP_NUM_STEPS = 20
    DBP_KERNEL_SIZE = 7
    DBP_FINAL_KERNEL_SIZE = 21
    DBP_USE_FINAL_FILTER = True
    DBP_USE_SYMMETRIC_FILTER = True
    DBP_USE_SYMMETRIC_NONLINEAR = True
    DBP_NL_MEMORY = 2
    DBP_INIT_FROM_LS = False
    DBP_INIT_SAMPLES = 65536
    DBP_INIT_FFT_SIZE = 4096
    DBP_JOINT_INIT = True
    DBP_JOINT_INIT_ITERS = 200
    DBP_JOINT_INIT_BATCH_SIZE = 1024
    DBP_JOINT_INIT_LR = 2e-3
    DBP_SEQSTAT_DIM = 128

    # FASTKAN_* / KAN_* / EFFICIENCY_* / EFFICIENT_KAN_* settings.
    FASTKAN_HIDDEN_DIM = 96
    FASTKAN_LAYERS = 2
    FASTKAN_NUM_GRIDS = 8
    FASTKAN_GRID_MIN = -2.5
    FASTKAN_GRID_MAX = 2.5
    FASTKAN_BASE_ACT = "silu"
    FASTKAN_USE_BASE_PATH = True
    KAN_INPUT_DROPOUT = 0.05
    KAN_HIDDEN_DROPOUT = 0.1
    KAN_PRUNE_L1 = 1e-5
    KAN_PRUNE_THRESHOLD = 0.02
    KAN_STRUCTURAL_PRUNE_AFTER_TRAINING = False
    KAN_STRUCTURAL_PRUNE_KEEP_RATIOS = [0.75, 0.5, 0.35, 0.25]
    KAN_STRUCTURAL_PRUNE_MIN_HIDDEN = 16
    KAN_STRUCTURAL_PRUNE_FINE_TUNE_EPOCHS = 20
    KAN_STRUCTURAL_PRUNE_FINE_TUNE_LR = 2e-4
    KAN_STRUCTURAL_PRUNE_SELECT_BY = "efficiency_score"
    EFFICIENCY_BATCH_SIZE = 16000
    EFFICIENCY_SCORE_POWER = 3.0
    EFFICIENCY_TIMING_WARMUP = 5
    EFFICIENCY_TIMING_REPEATS = 20
    EFFICIENT_KAN_HIDDEN_DIM = 128
    EFFICIENT_KAN_LAYERS = 2
    EFFICIENT_KAN_GRID_SIZE = 20
    EFFICIENT_KAN_SPLINE_ORDER = 3
    EFFICIENT_KAN_GRID_EPS = 0.02
    EFFICIENT_KAN_GRID_RANGE = [-3.0, 3.0]
    EFFICIENT_KAN_SCALE_NOISE = 0.1
    EFFICIENT_KAN_SCALE_BASE = 1.0
    EFFICIENT_KAN_SCALE_SPLINE = 1.0
    KAN_FEATURE_RADIUS = 2

    # Training-loop settings.
    EPOCHS = 250
    LEARNING_RATE = 1e-3
    WEIGHT_DECAY = 0.0
    TRAIN_BLOCK_SIZE = 8192
    EVAL_BATCH_SIZE = 65536
    MIN_BLOCK_SIZE = 1024
    USE_AMP = True
    USE_TORCH_COMPILE = False
    TORCH_COMPILE_MODE = "max-autotune-no-cudagraphs"
    OPTIMIZER = "adam"
    LOSS = "mse"
    GRAD_CLIP_NORM = 1.0
    LR_SCHEDULER = "notebook_decay"
    SCHEDULER_FACTOR = 0.5
    SCHEDULER_PATIENCE = 100
    SCHEDULER_THRESHOLD = 1e-6

    DECAY_STEPS = 24
    MIN_LR = 1e-5
    EARLY_STOPPING = True
    EARLY_STOPPING_PATIENCE = 72
    EARLY_STOPPING_MIN_EPOCHS = 40
    EARLY_STOPPING_THRESHOLD = 0.0
    LOG_EVERY = 1
    TEST_BER_EVERY = 10
    SAVE_BEST_BY = "val_ber"
    EVAL_TEST_DURING_TRAINING = False
    COMPUTE_PER_FILE_METRICS = True
    POWER_NORMALIZE = True
    BER_SCALE_SEARCH = True
    BER_SCALE_MIN = 0.5
    BER_SCALE_MAX = 1.5
    BER_SCALE_STEPS = 10
    BER_SCALE_OFFSET = 10000
    BER_SCALE_SAMPLES = 1 << 20

    # Output location and experiment switches.
    OUT_DIR = Path("clean_compare_outputs")
    RUN_MAIN_EXPERIMENTS = True
    MODEL_TYPES = [
        "efficient_kan_baseline",
        "kan_classifier",
    ]
    SAVE_BEST = True
    RUN_SWEEP_EXPERIMENTS = False
    SWEEP_TEST_FILES = 1
    WINDOW_SWEEP_VALUES = [2, 4, 8, 12, 16, 20]
    HIDDEN_SWEEP_VALUES = [8, 16, 32, 64]
    RUN_EFFICIENT_KAN_SWEEP = False
    EFFICIENT_KAN_SWEEP_MODELS = ["mlp", "efficient_kan_baseline", "kan_classifier"]
    EFFICIENT_KAN_SWEEP_EPOCHS = 60
    EFFICIENT_KAN_SWEEP_TEST_FILES = 1
    EFFICIENT_KAN_HIDDEN_SWEEP_VALUES = [64, 96, 128, 192]
    EFFICIENT_KAN_LR_SWEEP_VALUES = [1e-3]
    EFFICIENT_KAN_GRID_SWEEP_VALUES = [4, 8, 12, 16]
    EFFICIENT_KAN_ORDER_SWEEP_VALUES = [1, 2, 3, 4]
    EFFICIENT_KAN_LAYER_SWEEP_VALUES = [1, 2, 3]

    RUN_KAN_EXPERIMENT_SUITE = False
    EXPERIMENT_EPOCHS = 60
    EXPERIMENT_TEST_FILES = 1
    EXPERIMENT_COMPUTE_PER_FILE_METRICS = False
    EXPERIMENT_KAN_MODELS = ["efficient_kan_baseline", "kan_classifier"]
    EXPERIMENT_COMPARE_MODELS = ["efficient_kan_baseline", "mlp"]
    EXPERIMENT_COMPLEXITY_MODELS = ["complex_fastkan", "efficient_kan_baseline"]
    EXPERIMENT_FIXED_GRID = 16
    EXPERIMENT_FIXED_SPLINE_ORDER = 3
    EXPERIMENT_HIDDEN_VALUES = [64, 96, 128, 192]
    EXPERIMENT_WINDOW_VALUES = [8, 16, 24, 32, 48]
    EXPERIMENT_GRID_VALUES = [4, 8, 12, 16, 20]
    EXPERIMENT_SPLINE_ORDER_VALUES = [1, 2, 3, 4]
    EXPERIMENT_LAYER_VALUES = [1, 2, 3]
    MLP_LAYERS = 3

    RUN_FASTKAN_CLASSIFIER_SWEEP = True
    FASTKAN_CLASSIFIER_SWEEP_MODELS = ["fastkan_classifier", "complex_fastkan_classifier"]
    FASTKAN_CLASSIFIER_SWEEP_EPOCHS = 60
    FASTKAN_CLASSIFIER_SWEEP_TEST_FILES = 1
    FASTKAN_CLASSIFIER_HIDDEN_VALUES = [16, 32, 48, 64, 96]
    FASTKAN_CLASSIFIER_GRID_VALUES = [4, 8, 12, 16]
    FASTKAN_CLASSIFIER_LAYER_VALUES = [1, 2]
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
# Backend tuning flags are only touched when the selected device is CUDA.
if Config.DEVICE.type == "cuda":
    # Turn on cuDNN's benchmark mode.
    torch.backends.cudnn.benchmark = True
    # Allow TF32 in CUDA matmuls and in cuDNN.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Set the float32 matmul precision mode to "high".
    torch.set_float32_matmul_precision("high")
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
# The four coordinate values that appear in either column of the table.
_CONSTELLATION_LEVELS = (-0.948683, -0.316228, 0.316228, 0.948683)

# 16 x 2 float32 table: every (first, second) pair of levels, with the first
# column varying slowest and the second column varying fastest.
CONSTELLATION = torch.tensor(
    [
        [first, second]
        for first in _CONSTELLATION_LEVELS
        for second in _CONSTELLATION_LEVELS
    ],
    dtype=torch.float32,
)
|
|
238
|
+
|
|
239
|
+
# 2-bit sequence in which neighbouring entries differ in exactly one bit.
_TWO_BIT_SEQUENCE = ((0, 0), (0, 1), (1, 1), (1, 0))

# 16 x 4 uint8 table. The leading two bits step through the sequence once per
# group of four rows; the trailing two bits step through it on every row.
# NOTE(review): row i presumably labels row i of CONSTELLATION - confirm at the call sites.
BIT_LABELS = torch.tensor(
    [
        [*leading, *trailing]
        for leading in _TWO_BIT_SEQUENCE
        for trailing in _TWO_BIT_SEQUENCE
    ],
    dtype=torch.uint8,
)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
class AttentionLayer(nn.Module):
    """Pool a tensor over dim 1 using learned softmax weights.

    A small scorer (Linear -> Tanh -> Linear) produces one scalar per
    position along dim 1; the scalars are softmax-normalised over that
    dimension and used to form a weighted sum of the input.
    """

    def __init__(self, hidden_dim: int):
        """Build the scorer for inputs whose last dimension is ``hidden_dim``."""
        super().__init__()
        bottleneck = hidden_dim // 2
        scorer = [
            nn.Linear(hidden_dim, bottleneck),
            nn.Tanh(),
            nn.Linear(bottleneck, 1),
        ]
        self.attention = nn.Sequential(*scorer)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, lstm_output: torch.Tensor) -> torch.Tensor:
        """Return the softmax-weighted sum of ``lstm_output`` over dim 1."""
        scores = self.attention(lstm_output)
        weights = self.softmax(scores)
        return (weights * lstm_output).sum(dim=1)
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class LSTMRxEqualizer(nn.Module):
    """Embedding -> LSTM -> context pooling -> MLP head with two outputs.

    All sizes come from ``Config``. The forward pass also reads
    ``Config.SEQ_LEN``, ``Config.INPUT_DIM``, ``Config.CONTEXT_K`` and
    ``Config.BIDIRECTIONAL`` at call time.
    """

    def __init__(self):
        """Create the sub-modules and apply the custom weight initialisation."""
        super().__init__()
        directions = 2 if Config.BIDIRECTIONAL else 1
        lstm_out_dim = Config.LSTM_HIDDEN * directions
        embed_dim = 32

        self.embedding = nn.Sequential(
            nn.Linear(Config.INPUT_DIM, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
        )
        # nn.LSTM only applies dropout between stacked layers, hence the guard.
        inter_layer_dropout = Config.DROPOUT if Config.LSTM_LAYERS > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=Config.LSTM_HIDDEN,
            num_layers=Config.LSTM_LAYERS,
            batch_first=True,
            bidirectional=Config.BIDIRECTIONAL,
            dropout=inter_layer_dropout,
        )
        self.use_attention = Config.USE_ATTENTION
        # The attention module is always created, even when use_attention is False.
        self.attention = AttentionLayer(lstm_out_dim)
        self.center_fusion = nn.Sequential(
            nn.Linear(lstm_out_dim * 2, lstm_out_dim),
            nn.LayerNorm(lstm_out_dim),
            nn.GELU(),
        )
        self.lstm_norm = nn.LayerNorm(lstm_out_dim)

        head_dim = Config.HIDDEN_DIM
        half_head_dim = Config.HIDDEN_DIM // 2
        self.classifier = nn.Sequential(
            nn.Linear(lstm_out_dim, head_dim),
            nn.LayerNorm(head_dim),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(head_dim, half_head_dim),
            nn.LayerNorm(half_head_dim),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(half_head_dim, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Initialise the LSTM parameters, then every Linear/Conv1d module."""
        for name, param in self.lstm.named_parameters():
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param.data)
                continue
            if "weight_hh" in name:
                nn.init.orthogonal_(param.data)
                continue
            if "bias" in name:
                nn.init.constant_(param.data, 0)
                # PyTorch packs LSTM biases gate by gate (input, forget, cell,
                # output); the second chunk of size LSTM_HIDDEN is set to 1.
                gate = Config.LSTM_HIDDEN
                param.data[gate : 2 * gate] = 1.0

        for module in self.modules():
            if not isinstance(module, (nn.Linear, nn.Conv1d)):
                continue
            nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the model on ``x``, viewed as (batch, SEQ_LEN, INPUT_DIM)."""
        batch = x.size(0)
        tokens = self.embedding(x.view(batch, Config.SEQ_LEN, Config.INPUT_DIM))
        lstm_out, (hidden, _) = self.lstm(tokens)

        # Sequence summary: attention pooling, or the last layer's final hidden state(s).
        if self.use_attention:
            summary = self.attention(lstm_out)
        elif Config.BIDIRECTIONAL:
            summary = torch.cat([hidden[-2], hidden[-1]], dim=1)
        else:
            summary = hidden[-1]

        # The LSTM output at the centre position is fused with the summary.
        center_feature = lstm_out[:, Config.CONTEXT_K, :]
        fused = self.center_fusion(torch.cat([summary, center_feature], dim=1))
        return self.classifier(self.lstm_norm(fused))
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
class HybridCNNLSTMEqualizer(nn.Module):
    """Three-stage Conv1d stack feeding a 2-layer LSTM, attention pooling and a small head.

    The head ends in a Linear layer with two outputs.
    """

    def __init__(self):
        """Create the sub-modules and re-initialise all Linear/Conv1d weights."""
        super().__init__()

        # Conv1d -> BatchNorm1d -> GELU -> Dropout, repeated for 2->64->128->256 channels.
        stages = []
        in_channels = 2
        for out_channels in (64, 128, 256):
            stages.extend(
                [
                    nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
                    nn.BatchNorm1d(out_channels),
                    nn.GELU(),
                    nn.Dropout(Config.DROPOUT * 0.3),
                ]
            )
            in_channels = out_channels
        self.cnn = nn.Sequential(*stages)

        self.lstm = nn.LSTM(
            input_size=256,
            hidden_size=Config.LSTM_HIDDEN,
            num_layers=2,
            batch_first=True,
            bidirectional=Config.BIDIRECTIONAL,
            dropout=Config.DROPOUT,
        )
        directions = 2 if Config.BIDIRECTIONAL else 1
        out_dim = Config.LSTM_HIDDEN * directions
        self.attention = AttentionLayer(out_dim)
        self.classifier = nn.Sequential(
            nn.Linear(out_dim, Config.HIDDEN_DIM),
            nn.LayerNorm(Config.HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(Config.HIDDEN_DIM, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights and zero biases for every Linear/Conv1d module."""
        for module in self.modules():
            if not isinstance(module, (nn.Linear, nn.Conv1d)):
                continue
            nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the model on ``x``, viewed as (batch, SEQ_LEN, INPUT_DIM)."""
        sequence = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        # Conv1d wants channels on dim 1; the LSTM (batch_first) wants them last.
        conv_features = self.cnn(sequence.transpose(1, 2))
        lstm_out, _ = self.lstm(conv_features.transpose(1, 2))
        pooled = self.attention(lstm_out)
        return self.classifier(pooled)
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
class CNNRxEqualizer(nn.Module):
    """Conv1d feature extractor with centre and softmax-pooled features and an MLP head.

    The head input concatenates the convolutional feature at the centre
    position, a softmax-weighted sum of the features over all positions, the
    raw input at the centre position and the raw input averaged over
    positions. The head ends in a Linear layer with two outputs.
    """

    def __init__(self):
        """Create the sub-modules and apply the custom weight initialisation."""
        super().__init__()
        hidden_dim = Config.HIDDEN_DIM

        # (in_channels, kernel_size) per stage; padding keeps the length unchanged.
        stage_specs = ((Config.INPUT_DIM, 5), (hidden_dim, 5), (hidden_dim, 3))
        stages = []
        for in_channels, kernel in stage_specs:
            stages.extend(
                [
                    nn.Conv1d(in_channels, hidden_dim, kernel_size=kernel, padding=kernel // 2),
                    nn.BatchNorm1d(hidden_dim),
                    nn.GELU(),
                    nn.Dropout(Config.DROPOUT * 0.25),
                ]
            )
        self.cnn = nn.Sequential(*stages)

        # 1x1 convolution producing one pooling score per position.
        self.pool_score = nn.Conv1d(hidden_dim, 1, kernel_size=1)

        half_dim = hidden_dim // 2
        fused_dim = hidden_dim * 2 + 2 * Config.INPUT_DIM
        self.head = nn.Sequential(
            nn.Linear(fused_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(hidden_dim, half_dim),
            nn.LayerNorm(half_dim),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(half_dim, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal Linear/Conv1d weights; unit scale and zero shift for norms."""
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv1d)):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
                continue
            if isinstance(module, (nn.BatchNorm1d, nn.LayerNorm)):
                if getattr(module, "weight", None) is not None:
                    nn.init.constant_(module.weight, 1.0)
                if getattr(module, "bias", None) is not None:
                    nn.init.constant_(module.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the model on ``x``, viewed as (batch, SEQ_LEN, INPUT_DIM)."""
        raw = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        features = self.cnn(raw.transpose(1, 2))

        # Softmax over the position axis turns the scores into pooling weights.
        pool_weights = torch.softmax(self.pool_score(features), dim=2)
        pieces = [
            features[:, :, Config.CONTEXT_K],
            torch.sum(pool_weights * features, dim=2),
            raw[:, Config.CONTEXT_K, :],
            raw.mean(dim=1),
        ]
        return self.head(torch.cat(pieces, dim=1))
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
class ComplexConv1d(nn.Module):
    """1-D complex convolution on tensors with interleaved real/imag channels.

    Even channel indices are read as real parts and odd ones as imaginary
    parts. Two bias-free ``nn.Conv1d`` modules hold the real and imaginary
    kernels; the output is interleaved the same way. With ``symmetric=True``
    each stored kernel of length ``k`` is mirrored into a length ``2k - 1``
    kernel before it is applied.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, groups: int = 1, symmetric: bool = False):
        """Create the real and imaginary kernels (channel counts are complex channels)."""
        super().__init__()
        self.symmetric = symmetric
        # Length of the kernel that is actually applied in forward().
        applied_taps = 2 * kernel_size - 1 if symmetric else kernel_size
        conv_kwargs = dict(
            kernel_size=kernel_size,
            padding=applied_taps // 2,
            groups=groups,
            bias=False,
        )
        self.real_conv = nn.Conv1d(in_channels, out_channels, **conv_kwargs)
        self.imag_conv = nn.Conv1d(in_channels, out_channels, **conv_kwargs)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal initialisation (linear gain) for both kernels."""
        for conv in (self.real_conv, self.imag_conv):
            nn.init.kaiming_normal_(conv.weight, nonlinearity="linear")

    def _build_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Return ``weight`` unchanged, or mirrored about its last tap when symmetric."""
        if self.symmetric:
            mirrored_tail = weight.flip(dims=(2,))[:, :, 1:]
            return torch.cat([weight, mirrored_tail], dim=2)
        return weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply (W_re + i W_im) * (x_re + i x_im) and re-interleave the result."""
        real, imag = x[:, 0::2, :], x[:, 1::2, :]
        w_real = self._build_weight(self.real_conv.weight)
        w_imag = self._build_weight(self.imag_conv.weight)

        def convolve(signal: torch.Tensor, kernel: torch.Tensor, conv: nn.Conv1d) -> torch.Tensor:
            return F.conv1d(signal, kernel, padding=conv.padding[0], groups=conv.groups)

        out_real = convolve(real, w_real, self.real_conv) - convolve(imag, w_imag, self.imag_conv)
        out_imag = convolve(real, w_imag, self.imag_conv) + convolve(imag, w_real, self.real_conv)
        return torch.stack((out_real, out_imag), dim=2).flatten(1, 2)
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
class ComplexResidualBlock(nn.Module):
    """Residual block built from ComplexConv1d layers: expand -> depthwise -> project.

    Batch-norm layers are sized ``2 * channels`` because they act on the
    interleaved real/imag tensor. The skip path is a 1x1 complex convolution
    when the channel count changes and an identity otherwise.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        """Create the sub-modules; channel counts are complex channels."""
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.pre_norm = nn.BatchNorm1d(2 * in_channels)
        self.expand = ComplexConv1d(in_channels, out_channels, kernel_size=1)
        self.expand_norm = nn.BatchNorm1d(2 * out_channels)
        self.depthwise = ComplexConv1d(out_channels, out_channels, kernel_size=kernel_size, groups=out_channels)
        self.depthwise_norm = nn.BatchNorm1d(2 * out_channels)
        # Optional phase-rotation stage, enabled by Config.COMPLEX_USE_KERR.
        if Config.COMPLEX_USE_KERR:
            self.kerr = KerrLikeActivation(out_channels)
        else:
            self.kerr = nn.Identity()
        self.project = ComplexConv1d(out_channels, out_channels, kernel_size=1)
        self.project_norm = nn.BatchNorm1d(2 * out_channels)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(Config.DROPOUT * 0.5)
        if in_channels != out_channels:
            self.skip = ComplexConv1d(in_channels, out_channels, kernel_size=1)
        else:
            self.skip = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``skip(x)`` plus the transformed branch."""
        shortcut = self.skip(x)
        branch = self.expand(self.pre_norm(x))
        branch = self.activation(self.expand_norm(branch))
        branch = self.activation(self.depthwise_norm(self.depthwise(branch)))
        branch = self.kerr(branch)
        branch = self.project_norm(self.project(branch))
        return shortcut + self.dropout(branch)
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
class KerrLikeActivation(nn.Module):
    """Rotate each complex sample by a phase proportional to filtered signal power.

    The input holds interleaved real/imag channels (even indices real, odd
    indices imaginary). The per-channel power ``re^2 + im^2`` is passed
    through a depthwise filter, scaled by the learnable ``gamma`` and used as
    a rotation angle, so the magnitude of every sample is preserved.

    Args:
        channels: Number of complex channels.
        kernel_size: Stored length of the power filter; defaults to
            ``Config.COMPLEX_KERR_KERNEL``.
        init_gamma: Initial value of ``gamma``; defaults to
            ``Config.COMPLEX_KERR_INIT_GAMMA``.
        symmetric: When True the stored kernel of length ``k`` is mirrored
            about its last tap into an applied kernel of length ``2k - 1``.
    """

    def __init__(self, channels: int, kernel_size: Optional[int] = None, init_gamma: Optional[float] = None, symmetric: bool = False):
        super().__init__()
        kernel_size = Config.COMPLEX_KERR_KERNEL if kernel_size is None else kernel_size
        self.symmetric = symmetric
        self.power_filter = nn.Conv1d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=channels,
            bias=False,
        )
        gamma = Config.COMPLEX_KERR_INIT_GAMMA if init_gamma is None else init_gamma
        self.gamma = nn.Parameter(torch.full((1, channels, 1), gamma))
        self._init_weights()

    def _init_weights(self):
        """Initialise the applied power filter to a centred unit tap.

        In symmetric mode the applied kernel is ``[w0..w(k-1), w(k-2)..w0]``,
        so its centre is the LAST stored tap; otherwise it is the middle tap.
        """
        nn.init.zeros_(self.power_filter.weight)
        stored_taps = self.power_filter.weight.size(-1)
        center = stored_taps - 1 if self.symmetric else stored_taps // 2
        self.power_filter.weight.data[:, :, center] = 1.0

    def _build_weight(self) -> torch.Tensor:
        """Return the kernel applied in forward(): stored, or mirrored when symmetric."""
        weight = self.power_filter.weight
        if not self.symmetric:
            return weight
        return torch.cat([weight, weight.flip(dims=(2,))[:, :, 1:]], dim=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the power-dependent rotation and return interleaved real/imag channels."""
        real = x[:, 0::2, :]
        imag = x[:, 1::2, :]
        power_weight = self._build_weight()
        # Padding follows the applied kernel so the position axis keeps its length.
        power = F.conv1d(
            real.square() + imag.square(),
            power_weight,
            padding=power_weight.size(-1) // 2,
            groups=self.power_filter.groups,
        )
        phase = self.gamma * power
        cos_phase = torch.cos(phase)
        sin_phase = torch.sin(phase)
        # Multiplication by exp(-i * phase).
        out_real = cos_phase * real + sin_phase * imag
        out_imag = cos_phase * imag - sin_phase * real
        return torch.stack((out_real, out_imag), dim=2).flatten(1, 2)
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
class TemporalConvBlock(nn.Module):
    """Pre-norm residual block: dilated depthwise conv followed by a 1x1 conv.

    The kernel size is fixed at 3 and the padding is chosen so the position
    axis keeps its length, which the residual addition requires.
    """

    def __init__(self, channels: int, dilation: int):
        """Create the block for ``channels`` channels with the given dilation."""
        super().__init__()
        kernel_size = 3
        same_padding = dilation * (kernel_size - 1) // 2
        self.norm = nn.BatchNorm1d(channels)
        self.depthwise = nn.Conv1d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=same_padding,
            dilation=dilation,
            groups=channels,
            bias=False,
        )
        self.pointwise = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(Config.DROPOUT * 0.5)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal initialisation for both convolutions."""
        for conv in (self.depthwise, self.pointwise):
            nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the transformed branch."""
        branch = self.activation(self.norm(x))
        branch = self.pointwise(self.depthwise(branch))
        branch = self.dropout(self.activation(branch))
        return branch + x
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
class LightweightComplexTemporalBlock(nn.Module):
    """Gated residual block whose value and gate paths are biased by a power signal.

    A dilated depthwise convolution is followed by a 1x1 ``mix`` that doubles
    the channels; the halves become a GELU "value" and a sigmoid "gate". A 1x1
    projection of the single-channel ``power`` input is added to both halves
    before their activations.
    """

    def __init__(self, channels: int, kernel_size: int, dilation: int):
        """Create the block's convolutions, norm and dropout."""
        super().__init__()
        same_padding = dilation * (kernel_size - 1) // 2
        self.norm = nn.BatchNorm1d(channels)
        self.depthwise = nn.Conv1d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=same_padding,
            dilation=dilation,
            groups=channels,
            bias=False,
        )
        self.mix = nn.Conv1d(channels, channels * 2, kernel_size=1, bias=False)
        self.gate_proj = nn.Conv1d(1, channels * 2, kernel_size=1, bias=True)
        self.out_proj = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(Config.DROPOUT * 0.5)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights and zero biases for the four convolutions."""
        convs = (self.depthwise, self.mix, self.gate_proj, self.out_proj)
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0.0)

    def forward(self, x: torch.Tensor, power: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the gated branch conditioned on ``power``."""
        mixed = self.mix(self.depthwise(self.norm(x)))
        value, gate = mixed.chunk(2, dim=1)
        power_gate = self.gate_proj(power)

        # First half of the projection biases the value, the rest biases the gate.
        width = value.size(1)
        value = F.gelu(value + power_gate[:, :width, :])
        gate = torch.sigmoid(gate + power_gate[:, width:, :])

        branch = self.dropout(self.out_proj(value * gate))
        return branch + x
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
class LightweightComplexEncoder(nn.Module):
    """Encode a 2-channel window into hidden features with power-gated temporal blocks.

    Five hand-built channels (real, imag, magnitude, power, real*imag) go
    through a 1x1 stem and then one LightweightComplexTemporalBlock per entry
    of ``Config.COMPLEX_LIGHT_DILATIONS``.
    """

    def __init__(self):
        """Create the stem, the temporal blocks and the output norm."""
        super().__init__()
        channels = Config.COMPLEX_LIGHT_CHANNELS
        self.stem = nn.Sequential(
            nn.Conv1d(5, channels, kernel_size=1, bias=False),
            nn.BatchNorm1d(channels),
            nn.GELU(),
        )
        temporal_blocks = []
        for dilation in Config.COMPLEX_LIGHT_DILATIONS:
            temporal_blocks.append(
                LightweightComplexTemporalBlock(
                    channels=channels,
                    kernel_size=Config.COMPLEX_LIGHT_KERNEL_SIZE,
                    dilation=dilation,
                )
            )
        self.blocks = nn.ModuleList(temporal_blocks)
        self.out_norm = nn.BatchNorm1d(channels)
        self.out_channels = channels
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal Conv1d weights; unit scale and zero shift for BatchNorm1d."""
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
                continue
            if isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1.0)
                nn.init.constant_(module.bias, 0.0)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(raw, hidden, seq_features)``.

        ``raw`` is ``x`` viewed as (batch, SEQ_LEN, INPUT_DIM); ``hidden`` is
        the normalised block output with channels on dim 1; ``seq_features``
        is ``hidden`` concatenated with real, imag and magnitude, transposed
        so the position axis is dim 1.
        """
        raw = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        channels_first = raw.transpose(1, 2)
        real = channels_first[:, 0:1, :]
        imag = channels_first[:, 1:2, :]
        power = real.square() + imag.square()
        # The small epsilon keeps sqrt differentiable at zero power.
        magnitude = torch.sqrt(power + 1e-6)
        cross = real * imag

        hidden = self.stem(torch.cat([real, imag, magnitude, power, cross], dim=1))
        for block in self.blocks:
            hidden = block(hidden, power)
        hidden = self.out_norm(hidden)

        stacked = torch.cat([hidden, real, imag, magnitude], dim=1)
        return raw, hidden, stacked.transpose(1, 2)
|
|
697
|
+
|
|
698
|
+
|
|
699
|
+
class GaussianRBFExpansion(nn.Module):
    """Expand each input value into Gaussian bumps centred on a fixed grid.

    The centres are ``num_grids`` evenly spaced points in
    ``[grid_min, grid_max]`` (a non-trainable buffer). A single learnable log
    inverse width is shared by all centres and starts at
    ``1 / grid spacing``.
    """

    def __init__(self, num_grids: int, grid_min: float, grid_max: float):
        """Create the centre grid and the learnable log inverse width."""
        super().__init__()
        centers = torch.linspace(grid_min, grid_max, num_grids)
        if num_grids > 1:
            spacing = float(centers[1] - centers[0])
        else:
            # With a single centre there is no spacing; use the range, at least 1.
            spacing = max(abs(grid_max - grid_min), 1.0)
        self.register_buffer("centers", centers)
        initial_inv_scale = 1.0 / max(spacing, 1e-3)
        self.log_inv_scale = nn.Parameter(torch.log(torch.tensor(initial_inv_scale)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``exp(-((x - c) * inv_scale)^2)`` with a new trailing axis over centres."""
        inv_scale = self.log_inv_scale.exp().clamp(1e-3, 1e3)
        offsets = x.unsqueeze(-1) - self.centers
        scaled = offsets * inv_scale
        return torch.exp(-(scaled * scaled))
|
|
711
|
+
|
|
712
|
+
|
|
713
|
+
class FastKANLayer(nn.Module):
    """One FastKAN-style layer: optional base path plus a Gaussian-RBF "spline" path.

    The spline path expands every input feature over
    ``Config.FASTKAN_NUM_GRIDS`` Gaussian bumps and maps the flattened
    expansion linearly to ``out_features``. The base path applies GELU or SiLU
    to the input and a linear map. The sum is layer-normalised, passed
    through GELU and dropout.
    """

    def __init__(self, in_features: int, out_features: int):
        """Create the RBF expansion, both linear maps, the norm and the dropout."""
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rbf = GaussianRBFExpansion(
            num_grids=Config.FASTKAN_NUM_GRIDS,
            grid_min=Config.FASTKAN_GRID_MIN,
            grid_max=Config.FASTKAN_GRID_MAX,
        )
        expanded_features = in_features * Config.FASTKAN_NUM_GRIDS
        self.base_linear = nn.Linear(in_features, out_features)
        self.spline_linear = nn.Linear(expanded_features, out_features)
        self.norm = nn.LayerNorm(out_features)
        self.dropout = nn.Dropout(Config.KAN_HIDDEN_DROPOUT)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights with zero biases; identity-like LayerNorm."""
        for linear in (self.base_linear, self.spline_linear):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)
        nn.init.constant_(self.norm.weight, 1.0)
        nn.init.constant_(self.norm.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the activated, normalised sum of the base and spline paths."""
        # Any FASTKAN_BASE_ACT value other than "gelu" selects SiLU.
        base_activation = F.gelu if Config.FASTKAN_BASE_ACT == "gelu" else F.silu
        base = base_activation(x)
        if Config.FASTKAN_USE_BASE_PATH:
            base_out = self.base_linear(base)
        else:
            base_out = 0.0
        # Flatten (features, grids) into one axis for the linear map.
        spline_out = self.spline_linear(self.rbf(x).flatten(start_dim=1))
        activated = F.gelu(self.norm(base_out + spline_out))
        return self.dropout(activated)

    def regularization_loss(self) -> torch.Tensor:
        """Return the mean absolute value of the spline-path weights."""
        return self.spline_linear.weight.abs().mean()
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
class FastKANHead(nn.Module):
    """Stack of ``FastKANLayer`` blocks behind a normalised, gated input and a linear output.

    A learnable per-feature gate scales the (layer-normalised, dropped-out)
    input. Outside training, gates whose magnitude is below
    ``Config.KAN_PRUNE_THRESHOLD`` are zeroed.
    """

    def __init__(self, input_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.input_norm = nn.LayerNorm(input_dim)
        self.input_dropout = nn.Dropout(Config.KAN_INPUT_DROPOUT)
        self.feature_gate = nn.Parameter(torch.ones(input_dim))
        # Widths seen between consecutive layers: input_dim, then hidden_dim repeated.
        widths = [input_dim] + [hidden_dim] * Config.FASTKAN_LAYERS
        self.layers = nn.ModuleList(
            [FastKANLayer(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )
        self.output = nn.Linear(widths[-1], out_dim)
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise the output projection and reset the input LayerNorm to identity."""
        nn.init.xavier_uniform_(self.output.weight)
        nn.init.zeros_(self.output.bias)
        nn.init.constant_(self.input_norm.weight, 1.0)
        nn.init.constant_(self.input_norm.bias, 0.0)

    def _gated_features(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the feature gate, hard-pruning small gates when not training."""
        gate = self.feature_gate
        if self.training or not (Config.KAN_PRUNE_THRESHOLD > 0):
            return x * gate.unsqueeze(0)
        keep = gate.abs() >= Config.KAN_PRUNE_THRESHOLD
        return x * (gate * keep).unsqueeze(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise, drop out and gate the input, run every KAN layer, then project."""
        hidden = self._gated_features(self.input_dropout(self.input_norm(x)))
        for layer in self.layers:
            hidden = layer(hidden)
        return self.output(hidden)

    def regularization_loss(self) -> torch.Tensor:
        """Return the mean absolute gate value plus each layer's own regularisation term."""
        total = self.feature_gate.abs().mean()
        for layer in self.layers:
            total = total + layer.regularization_loss()
        return total
|
|
790
|
+
|
|
791
|
+
|
|
792
|
+
class EfficientKANBaselineEqualizer(nn.Module):
    """Baseline equalizer: an EfficientKAN applied directly to the flattened input window.

    The network maps ``Config.SEQ_LEN * Config.INPUT_DIM`` inputs to 2 outputs.
    Construction and regularisation are delegated to the module-level helpers
    ``make_efficient_kan`` and ``efficient_kan_regularization`` so this class
    stays in step with the other EfficientKAN-based equalizers in this file.

    Raises:
        ImportError: if the ``efficient_kan`` package could not be imported
            (raised by ``make_efficient_kan``).
    """

    def __init__(self):
        super().__init__()
        self.kan = make_efficient_kan(
            input_dim=Config.SEQ_LEN * Config.INPUT_DIM,
            output_dim=2,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the KAN on ``x`` unchanged."""
        return self.kan(x)

    def regularization_loss(self) -> torch.Tensor:
        """Return the KAN's regularisation term as a tensor (zero if it defines none)."""
        return efficient_kan_regularization(self.kan)
|
|
825
|
+
|
|
826
|
+
|
|
827
|
+
class FastKANClassifierEqualizer(nn.Module):
    """FastKAN head over the flattened window with one output per ``CONSTELLATION`` entry."""

    def __init__(self):
        super().__init__()
        flat_window = Config.SEQ_LEN * Config.INPUT_DIM
        num_classes = CONSTELLATION.size(0)
        self.head = FastKANHead(
            input_dim=flat_window,
            hidden_dim=Config.FASTKAN_HIDDEN_DIM,
            out_dim=num_classes,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head's output for ``x``."""
        return self.head(x)

    def regularization_loss(self) -> torch.Tensor:
        """Forward the head's regularisation term."""
        return self.head.regularization_loss()
|
|
842
|
+
|
|
843
|
+
|
|
844
|
+
def make_efficient_kan(input_dim: int, output_dim: int) -> nn.Module:
    """Build an ``EfficientKAN`` with hyper-parameters taken from ``Config``.

    The network has ``Config.EFFICIENT_KAN_LAYERS`` hidden layers (at least one),
    each ``Config.EFFICIENT_KAN_HIDDEN_DIM`` wide, between ``input_dim`` and
    ``output_dim``.

    Raises:
        ImportError: if the optional ``efficient_kan`` import failed at module load.
    """
    if EfficientKAN is None:
        raise ImportError(
            "efficient_kan is unavailable. Keep the local `efficient_kan/` package next to this script or "
            "install the library before running."
        )
    depth = max(Config.EFFICIENT_KAN_LAYERS, 1)
    widths = [input_dim] + [Config.EFFICIENT_KAN_HIDDEN_DIM] * depth + [output_dim]
    kan_kwargs = dict(
        layers_hidden=widths,
        grid_size=Config.EFFICIENT_KAN_GRID_SIZE,
        spline_order=Config.EFFICIENT_KAN_SPLINE_ORDER,
        scale_noise=Config.EFFICIENT_KAN_SCALE_NOISE,
        scale_base=Config.EFFICIENT_KAN_SCALE_BASE,
        scale_spline=Config.EFFICIENT_KAN_SCALE_SPLINE,
        base_activation=nn.SiLU,
        grid_eps=Config.EFFICIENT_KAN_GRID_EPS,
        grid_range=Config.EFFICIENT_KAN_GRID_RANGE,
    )
    return EfficientKAN(**kan_kwargs)
|
|
862
|
+
|
|
863
|
+
|
|
864
|
+
def efficient_kan_regularization(kan: nn.Module) -> torch.Tensor:
    """Return ``kan.regularization_loss()`` as a tensor.

    A module without that method yields a scalar zero, and a non-tensor result
    is wrapped into a tensor; both of those are placed on ``Config.DEVICE``.
    A tensor result is returned untouched.
    """
    if not hasattr(kan, "regularization_loss"):
        return torch.zeros((), device=Config.DEVICE)
    penalty = kan.regularization_loss()
    if not torch.is_tensor(penalty):
        penalty = torch.tensor(float(penalty), device=Config.DEVICE)
    return penalty
|
|
871
|
+
|
|
872
|
+
|
|
873
|
+
class EfficientKANResidualEqualizer(nn.Module):
    """Equalizer that adds a scaled EfficientKAN correction to the window's center sample."""

    def __init__(self):
        super().__init__()
        flat_window = Config.SEQ_LEN * Config.INPUT_DIM
        self.kan = make_efficient_kan(input_dim=flat_window, output_dim=2)
        # Learnable gain on the correction, starting at 1.
        self.residual_scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``center + residual_scale * kan(x)``.

        ``x`` is viewed as ``(batch, SEQ_LEN, INPUT_DIM)`` to pick the sample at
        index ``Config.CONTEXT_K``; the KAN itself receives ``x`` as given.
        """
        window = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        received_center = window[:, Config.CONTEXT_K, :]
        delta = self.kan(x)
        return received_center + self.residual_scale * delta

    def regularization_loss(self) -> torch.Tensor:
        """Return the KAN regularisation term via ``efficient_kan_regularization``."""
        return efficient_kan_regularization(self.kan)
|
|
888
|
+
|
|
889
|
+
|
|
890
|
+
class EfficientKANFeatureEqualizer(nn.Module):
    """EfficientKAN applied to hand-crafted summary statistics of the input window.

    ``_features`` concatenates six per-channel vectors (center sample, local
    mean/std, global mean/std, edge delta) and five scalars (center power, power
    mean/std, cross-product mean/std). The KAN input width is derived from
    ``Config.INPUT_DIM`` rather than hard-coded, so it always matches that
    concatenation; for ``INPUT_DIM == 2`` it is 17, as before.
    """

    def __init__(self):
        super().__init__()
        # 6 per-channel statistics + 5 scalar statistics (see _features).
        self.feature_dim = 6 * Config.INPUT_DIM + 5
        self.kan = make_efficient_kan(input_dim=self.feature_dim, output_dim=2)

    def _features(self, x: torch.Tensor) -> torch.Tensor:
        """Build the ``(batch, feature_dim)`` statistics vector from the flattened window.

        The cross term multiplies channels 0 and 1, so at least two input
        channels are required.
        """
        raw = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        center = raw[:, Config.CONTEXT_K, :]
        # The local neighbourhood can never extend past the window edges.
        radius = min(Config.KAN_FEATURE_RADIUS, Config.CONTEXT_K)
        local = raw[:, Config.CONTEXT_K - radius : Config.CONTEXT_K + radius + 1, :]

        global_mean = raw.mean(dim=1)
        global_std = raw.std(dim=1, unbiased=False)
        local_mean = local.mean(dim=1)
        local_std = local.std(dim=1, unbiased=False)

        # Per-position power: sum of squares over the channel axis.
        power = raw.square().sum(dim=2)
        power_center = power[:, Config.CONTEXT_K : Config.CONTEXT_K + 1]
        power_mean = power.mean(dim=1, keepdim=True)
        power_std = power.std(dim=1, unbiased=False, keepdim=True)

        # Per-position product of channel 0 and channel 1.
        cross = raw[:, :, 0] * raw[:, :, 1]
        cross_mean = cross.mean(dim=1, keepdim=True)
        cross_std = cross.std(dim=1, unbiased=False, keepdim=True)
        # Difference between the last and first samples of the window.
        edge_delta = raw[:, -1, :] - raw[:, 0, :]

        return torch.cat(
            [
                center,
                local_mean,
                local_std,
                global_mean,
                global_std,
                power_center,
                power_mean,
                power_std,
                cross_mean,
                cross_std,
                edge_delta,
            ],
            dim=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the KAN on the statistics extracted from ``x``."""
        return self.kan(self._features(x))

    def regularization_loss(self) -> torch.Tensor:
        """Return the KAN regularisation term via ``efficient_kan_regularization``."""
        return efficient_kan_regularization(self.kan)
|
|
939
|
+
|
|
940
|
+
|
|
941
|
+
class CNNKANEqualizer(nn.Module):
    """Three-stage 1-D CNN feature extractor followed by an EfficientKAN regressor.

    The KAN consumes the CNN feature at the center position, a softmax-weighted
    pooling of the CNN features over the sequence, the raw center sample and the
    raw window mean.
    """

    def __init__(self):
        super().__init__()
        width = Config.HIDDEN_DIM
        stages = []
        fan_in = Config.INPUT_DIM
        # Kernel sizes 5, 5, 3 with "same" padding (kernel // 2).
        for kernel in (5, 5, 3):
            stages.append(nn.Conv1d(fan_in, width, kernel_size=kernel, padding=kernel // 2))
            stages.append(nn.BatchNorm1d(width))
            stages.append(nn.GELU())
            fan_in = width
        self.cnn = nn.Sequential(*stages)
        self.pool_score = nn.Conv1d(width, 1, kernel_size=1)
        # center feature + pooled feature + raw center + raw mean
        fused_width = width * 2 + 2 * Config.INPUT_DIM
        self.kan = make_efficient_kan(input_dim=fused_width, output_dim=2)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise the convolutions and reset the batch norms to identity."""
        for layer in self.cnn.modules():
            if isinstance(layer, nn.Conv1d):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0.0)
            elif isinstance(layer, nn.BatchNorm1d):
                nn.init.constant_(layer.weight, 1.0)
                nn.init.constant_(layer.bias, 0.0)
        nn.init.kaiming_normal_(self.pool_score.weight, nonlinearity="linear")
        nn.init.constant_(self.pool_score.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fuse CNN and raw-window features and regress two outputs with the KAN."""
        window = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        # Conv1d expects channels first: (batch, INPUT_DIM, SEQ_LEN).
        feature_map = self.cnn(window.transpose(1, 2))
        center_feature = feature_map[:, :, Config.CONTEXT_K]
        # Softmax over the sequence axis turns the scores into pooling weights.
        attention = torch.softmax(self.pool_score(feature_map), dim=2)
        pooled = torch.sum(attention * feature_map, dim=2)
        window_center = window[:, Config.CONTEXT_K, :]
        window_mean = window.mean(dim=1)
        fused = torch.cat([center_feature, pooled, window_center, window_mean], dim=1)
        return self.kan(fused)

    def regularization_loss(self) -> torch.Tensor:
        """Return the KAN regularisation term via ``efficient_kan_regularization``."""
        return efficient_kan_regularization(self.kan)
|
|
986
|
+
|
|
987
|
+
|
|
988
|
+
class EfficientKANClassifierEqualizer(nn.Module):
    """EfficientKAN over the flattened window with one output per ``CONSTELLATION`` entry."""

    def __init__(self):
        super().__init__()
        flat_window = Config.SEQ_LEN * Config.INPUT_DIM
        num_classes = CONSTELLATION.size(0)
        self.kan = make_efficient_kan(input_dim=flat_window, output_dim=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the KAN on ``x`` unchanged."""
        return self.kan(x)

    def regularization_loss(self) -> torch.Tensor:
        """Return the KAN regularisation term via ``efficient_kan_regularization``."""
        return efficient_kan_regularization(self.kan)
|
|
999
|
+
|
|
1000
|
+
|
|
1001
|
+
class ComplexFeatureEncoder(nn.Module):
    """Complex-valued convolutional encoder with an optional DBP front end.

    A 1x1 complex stem is followed by one ``ComplexResidualBlock`` per entry of
    ``Config.COMPLEX_BLOCK_CHANNELS`` / ``Config.COMPLEX_KERNEL_SIZES``. Batch
    norms are sized ``2 * channels``.

    Raises:
        ValueError: if the channel and kernel-size lists differ in length.
    """

    def __init__(self):
        super().__init__()
        if Config.COMPLEX_USE_DBP_FRONTEND:
            self.dbp_frontend = ComplexDBPFrontEnd()
        else:
            self.dbp_frontend = None
        block_widths = Config.COMPLEX_BLOCK_CHANNELS
        block_kernels = Config.COMPLEX_KERNEL_SIZES
        if len(block_widths) != len(block_kernels):
            raise ValueError("COMPLEX_BLOCK_CHANNELS and COMPLEX_KERNEL_SIZES must have the same length.")

        first_width = block_widths[0]
        self.stem = ComplexConv1d(1, first_width, kernel_size=1)
        self.stem_norm = nn.BatchNorm1d(2 * first_width)
        self.blocks = nn.ModuleList()
        # The first block maps first_width -> first_width; later blocks chain widths.
        width = first_width
        for next_width, kernel in zip(block_widths, block_kernels):
            self.blocks.append(ComplexResidualBlock(width, next_width, kernel))
            width = next_width

        self.final_norm = nn.BatchNorm1d(2 * width)
        self.out_channels = width

    def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Forward data-driven initialisation to the DBP front end, if one is configured."""
        if self.dbp_frontend is None:
            return
        self.dbp_frontend.initialize_from_data(train_x, train_y)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode the window.

        Returns:
            ``(raw, encoded, seq_features)`` where ``raw`` is ``x`` viewed as
            ``(batch, SEQ_LEN, INPUT_DIM)``, ``encoded`` is the output of the
            final batch norm, and ``seq_features`` concatenates the real part,
            imaginary part and magnitude along the channel axis before moving
            the sequence axis to position 1.
        """
        raw = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        if self.dbp_frontend is None:
            encoded = raw.transpose(1, 2)
        else:
            encoded, _ = self.dbp_frontend(x, collect_states=False)
        encoded = self.stem_norm(self.stem(encoded))
        for block in self.blocks:
            encoded = block(encoded)
        encoded = self.final_norm(encoded)
        # Even channels are read as real parts, odd channels as imaginary parts.
        real = encoded[:, 0::2, :]
        imag = encoded[:, 1::2, :]
        # The epsilon keeps the square root differentiable at zero magnitude.
        magnitude = torch.sqrt(real.square() + imag.square() + 1e-6)
        seq_features = torch.cat([real, imag, magnitude], dim=1).transpose(1, 2)
        return raw, encoded, seq_features
|
|
1040
|
+
|
|
1041
|
+
|
|
1042
|
+
class ComplexDBPStep1Ch(nn.Module):
    """One single-channel DBP step: a complex linear filter followed by a Kerr-like nonlinearity."""

    def __init__(self):
        super().__init__()
        self.linear = ComplexConv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=Config.DBP_KERNEL_SIZE,
            symmetric=Config.DBP_USE_SYMMETRIC_FILTER,
        )
        # A symmetric nonlinear filter is given memory + 1 taps, otherwise 2 * memory + 1.
        if Config.DBP_USE_SYMMETRIC_NONLINEAR:
            nonlinear_taps = Config.DBP_NL_MEMORY + 1
        else:
            nonlinear_taps = 2 * Config.DBP_NL_MEMORY + 1
        self.nonlinear = KerrLikeActivation(
            1,
            kernel_size=nonlinear_taps,
            init_gamma=Config.COMPLEX_KERR_INIT_GAMMA,
            symmetric=Config.DBP_USE_SYMMETRIC_NONLINEAR,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear filter, then the nonlinear activation."""
        filtered = self.linear(x)
        return self.nonlinear(filtered)
|
|
1061
|
+
|
|
1062
|
+
|
|
1063
|
+
class ComplexDBPFrontEnd(nn.Module):
    """Learned digital-back-propagation front end built from ``ComplexDBPStep1Ch`` steps.

    ``Config.DBP_NUM_STEPS`` steps are applied in sequence, optionally followed
    by a final complex linear filter. ``valid_margin`` is the accumulated delay
    of all stages, as computed in the constructor.

    Raises:
        ValueError: if ``valid_margin`` exceeds ``Config.CONTEXT_K``.
    """

    def __init__(self):
        super().__init__()
        self.steps = nn.ModuleList([ComplexDBPStep1Ch() for _ in range(Config.DBP_NUM_STEPS)])
        self.final_linear = (
            ComplexConv1d(
                in_channels=1,
                out_channels=1,
                kernel_size=Config.DBP_FINAL_KERNEL_SIZE,
                symmetric=Config.DBP_USE_SYMMETRIC_FILTER,
            )
            if Config.DBP_USE_FINAL_FILTER
            else None
        )
        # Per-stage delays; the symmetric variants use a different tap-count convention.
        linear_delay = Config.DBP_KERNEL_SIZE - 1 if Config.DBP_USE_SYMMETRIC_FILTER else Config.DBP_KERNEL_SIZE // 2
        final_delay = 0
        if self.final_linear is not None:
            final_delay = (
                Config.DBP_FINAL_KERNEL_SIZE - 1
                if Config.DBP_USE_SYMMETRIC_FILTER
                else Config.DBP_FINAL_KERNEL_SIZE // 2
            )
        nl_delay = Config.DBP_NL_MEMORY if Config.DBP_USE_SYMMETRIC_NONLINEAR else 2 * Config.DBP_NL_MEMORY
        self.valid_margin = Config.DBP_NUM_STEPS * (linear_delay + nl_delay) + final_delay
        if self.valid_margin > Config.CONTEXT_K:
            raise ValueError(
                f"DBP receptive radius {self.valid_margin} exceeds CONTEXT_K={Config.CONTEXT_K}. "
                "Increase CONTEXT_K or reduce DBP kernels/steps."
            )

    @staticmethod
    def _seq_features(x: torch.Tensor) -> torch.Tensor:
        """Stack channel 0, channel 1 and their magnitude, with the sequence axis moved to dim 1."""
        real = x[:, 0:1, :]
        imag = x[:, 1:2, :]
        # The epsilon keeps the square root differentiable at zero magnitude.
        magnitude = torch.sqrt(real.square() + imag.square() + 1e-6)
        return torch.cat([real, imag, magnitude], dim=1).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        collect_states: bool = False,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run every DBP stage on the window.

        Args:
            x: input that is viewed as ``(batch, SEQ_LEN, INPUT_DIM)`` and then
                transposed to channels-first.
            collect_states: when true, also return the input state and the
                state after every stage (including the final filter).

        Returns:
            ``(state, states)``: the final state and the collected states
            (an empty list when ``collect_states`` is false).
        """
        state = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM).transpose(1, 2)
        states = [state] if collect_states else []
        for step in self.steps:
            state = step(state)
            if collect_states:
                states.append(state)
        if self.final_linear is not None:
            state = self.final_linear(state)
            if collect_states:
                states.append(state)
        return state, states

    @torch.no_grad()
    def _initialize_linear_kernels_from_ls(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Seed the linear filters from a least-squares fit of windows to targets.

        Does nothing when fewer than ``Config.SEQ_LEN`` samples are available.
        The nonlinear stage of every step is reset to a single unit tap with
        ``gamma`` set to ``COMPLEX_KERR_INIT_GAMMA / DBP_NUM_STEPS``.
        """
        samples = min(train_x.size(0), Config.DBP_INIT_SAMPLES)
        if samples < Config.SEQ_LEN:
            return
        windows = train_x[:samples].view(samples, Config.SEQ_LEN, Config.INPUT_DIM)
        target = train_y[:samples]
        # Channels 0/1 are combined into complex samples.
        rx_complex = torch.complex(windows[:, :, 0], windows[:, :, 1]).to(torch.complex64)
        tx_complex = torch.complex(target[:, 0], target[:, 1]).to(torch.complex64)

        # One SEQ_LEN-tap kernel mapping each received window to its target symbol.
        global_kernel = torch.linalg.lstsq(rx_complex, tx_complex.unsqueeze(1)).solution.squeeze(1)
        global_kernel = global_kernel / global_kernel.norm().clamp_min(1e-6)

        # FFT size: at least the next power of two >= SEQ_LEN.
        fft_size = max(Config.DBP_INIT_FFT_SIZE, 1 << int(np.ceil(np.log2(Config.SEQ_LEN))))
        global_response = centered_complex_kernel_to_frequency(global_kernel, fft_size)
        linear_stage_count = len(self.steps) + (1 if self.final_linear is not None else 0)
        # NOTE(review): complex_unit_response presumably returns the per-stage share of the
        # global response for `linear_stage_count` stages - confirm against its definition.
        step_response = complex_unit_response(global_response, linear_stage_count)
        # Whatever the steps do not account for is assigned to the final filter.
        final_response = global_response / step_response.pow(len(self.steps))

        step_kernel_centered = frequency_to_centered_complex_kernel(step_response, Config.SEQ_LEN)
        step_kernel = extract_kernel_from_centered_response(
            step_kernel_centered,
            Config.DBP_KERNEL_SIZE,
            Config.DBP_USE_SYMMETRIC_FILTER,
        )
        for step in self.steps:
            assign_complex_kernel(step.linear, step_kernel)
            center = step.nonlinear.power_filter.weight.size(-1) // 2
            step.nonlinear.power_filter.weight.zero_()
            step.nonlinear.power_filter.weight[:, :, center] = 1.0
            step.nonlinear.gamma.fill_(Config.COMPLEX_KERR_INIT_GAMMA / max(Config.DBP_NUM_STEPS, 1))

        if self.final_linear is not None:
            final_kernel_centered = frequency_to_centered_complex_kernel(final_response, Config.SEQ_LEN)
            final_kernel = extract_kernel_from_centered_response(
                final_kernel_centered,
                Config.DBP_FINAL_KERNEL_SIZE,
                Config.DBP_USE_SYMMETRIC_FILTER,
            )
            assign_complex_kernel(self.final_linear, final_kernel)

    def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Initialise the front end from training data.

        Runs the least-squares kernel initialisation when
        ``Config.DBP_INIT_FROM_LS`` is set and, if ``Config.DBP_JOINT_INIT`` is
        enabled, fine-tunes all parameters with Adam on an MSE loss between the
        output at index ``Config.CONTEXT_K`` and the targets. The joint phase is
        skipped when fewer samples than one batch are available. The module's
        train/eval mode is restored afterwards, even if the joint phase raises.
        """
        if not Config.DBP_INIT_FROM_LS:
            return
        self._initialize_linear_kernels_from_ls(train_x, train_y)
        if not Config.DBP_JOINT_INIT or Config.DBP_JOINT_INIT_ITERS <= 0:
            return

        was_training = self.training
        self.train()
        try:
            subset = min(train_x.size(0), Config.DBP_INIT_SAMPLES)
            if subset < Config.DBP_JOINT_INIT_BATCH_SIZE:
                return

            optimizer = optim.Adam(self.parameters(), lr=Config.DBP_JOINT_INIT_LR)
            criterion = nn.MSELoss()
            for _ in range(Config.DBP_JOINT_INIT_ITERS):
                # Sample a batch with replacement from the first `subset` examples.
                index = torch.randint(0, subset, (Config.DBP_JOINT_INIT_BATCH_SIZE,))
                xb = train_x[index].to(Config.DEVICE)
                yb = train_y[index].to(Config.DEVICE)
                optimizer.zero_grad(set_to_none=True)
                state, _ = self.forward(xb, collect_states=False)
                # (batch, channels) slice at the center position.
                preds = state[:, :, Config.CONTEXT_K]
                loss = criterion(preds, yb)
                loss.backward()
                optimizer.step()
        finally:
            # Leave the module in the mode the caller had it in.
            if not was_training:
                self.eval()
|
|
1187
|
+
|
|
1188
|
+
|
|
1189
|
+
class ComplexDBPSeqStatRxEqualizer(nn.Module):
    """DBP front end followed by an MLP over center values and sequence statistics.

    For the input state and every DBP stage output, real/imag/magnitude features
    are taken. The head receives their values at the center position, their mean
    and standard deviation over the valid part of the window, and the raw center
    sample.
    """

    def __init__(self):
        super().__init__()
        self.frontend = ComplexDBPFrontEnd()
        self.valid_margin = self.frontend.valid_margin
        # Input state + one state per step (+ one for the optional final filter), 3 features each.
        state_count = 1 + Config.DBP_NUM_STEPS + (1 if Config.DBP_USE_FINAL_FILTER else 0)
        per_position = 3 * state_count
        # center, mean and std of the per-position features, plus the raw center sample
        head_in = per_position * 3 + Config.INPUT_DIM
        self.head = nn.Sequential(
            nn.Linear(head_in, Config.DBP_SEQSTAT_DIM),
            nn.LayerNorm(Config.DBP_SEQSTAT_DIM),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(Config.DBP_SEQSTAT_DIM, Config.HIDDEN_DIM),
            nn.LayerNorm(Config.HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(Config.HIDDEN_DIM, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise every ``nn.Linear`` and reset every ``nn.LayerNorm`` under this module."""
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0.0)
            elif isinstance(layer, nn.LayerNorm):
                nn.init.constant_(layer.weight, 1.0)
                nn.init.constant_(layer.bias, 0.0)

    def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Delegate data-driven initialisation to the DBP front end."""
        self.frontend.initialize_from_data(train_x, train_y)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head's two outputs for the window ``x``."""
        window = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        _, stage_states = self.frontend(x, collect_states=True)
        per_stage = [self.frontend._seq_features(stage) for stage in stage_states]
        sequence = torch.cat(per_stage, dim=2)
        center = sequence[:, Config.CONTEXT_K, :]
        margin = self.valid_margin
        if margin > 0:
            # Statistics ignore positions within the front end's delay margin of either edge.
            trimmed = sequence[:, margin : Config.SEQ_LEN - margin, :]
        else:
            trimmed = sequence
        seq_mean = trimmed.mean(dim=1)
        seq_std = trimmed.std(dim=1, unbiased=False)
        fused = torch.cat([center, seq_mean, seq_std, window[:, Config.CONTEXT_K, :]], dim=1)
        return self.head(fused)
|
|
1238
|
+
|
|
1239
|
+
|
|
1240
|
+
class ComplexCNNRxEqualizer(nn.Module):
    """Complex encoder, dilated temporal convolutions and an MLP head.

    The head sees the temporal feature at the center position, a
    softmax-weighted pooling over the sequence, the raw center sample and the
    raw window mean.
    """

    def __init__(self):
        super().__init__()
        self.encoder = ComplexFeatureEncoder()
        # The encoder emits real, imaginary and magnitude features per output channel.
        encoder_width = 3 * self.encoder.out_channels
        temporal_width = Config.COMPLEX_TEMPORAL_DIM
        self.temporal_proj = nn.Conv1d(encoder_width, temporal_width, kernel_size=1, bias=False)
        self.temporal_blocks = nn.ModuleList(
            [TemporalConvBlock(Config.COMPLEX_TEMPORAL_DIM, rate) for rate in Config.COMPLEX_TEMPORAL_DILATIONS]
        )
        self.temporal_norm = nn.BatchNorm1d(Config.COMPLEX_TEMPORAL_DIM)
        self.pool_score = nn.Conv1d(Config.COMPLEX_TEMPORAL_DIM, 1, kernel_size=1)
        head_in = Config.COMPLEX_TEMPORAL_DIM * 2 + 2 * Config.INPUT_DIM
        self.head = nn.Sequential(
            nn.Linear(head_in, Config.COMPLEX_HEAD_DIM),
            nn.LayerNorm(Config.COMPLEX_HEAD_DIM),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(Config.COMPLEX_HEAD_DIM, Config.HIDDEN_DIM),
            nn.LayerNorm(Config.HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(Config.HIDDEN_DIM, 2),
        )
        self._init_weights()

    def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Delegate data-driven initialisation to the encoder."""
        self.encoder.initialize_from_data(train_x, train_y)

    def _init_weights(self):
        """Kaiming-initialise linear layers and reset batch/layer norms under this module."""
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0.0)
            elif isinstance(layer, nn.BatchNorm1d):
                nn.init.constant_(layer.weight, 1.0)
                nn.init.constant_(layer.bias, 0.0)
            elif isinstance(layer, nn.LayerNorm):
                nn.init.constant_(layer.weight, 1.0)
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head's two outputs for the window ``x``."""
        raw, _, seq_features = self.encoder(x)
        window_center = raw[:, Config.CONTEXT_K, :]
        window_mean = raw.mean(dim=1)
        # Conv1d expects channels first.
        feature_map = self.temporal_proj(seq_features.transpose(1, 2))
        for block in self.temporal_blocks:
            feature_map = block(feature_map)
        feature_map = self.temporal_norm(feature_map)

        center_feature = feature_map[:, :, Config.CONTEXT_K]
        # Softmax over the sequence axis turns the scores into pooling weights.
        attention = torch.softmax(self.pool_score(feature_map), dim=2)
        pooled = torch.sum(attention * feature_map, dim=2)
        fused = torch.cat([center_feature, pooled, window_center, window_mean], dim=1)
        return self.head(fused)
|
|
1297
|
+
|
|
1298
|
+
|
|
1299
|
+
class ComplexLSTMRxEqualizer(nn.Module):
    """Complex encoder followed by an LSTM, attention pooling and an MLP head.

    The LSTM output at the center position, the attention context and the raw
    center sample are fused before the final regression head.
    """

    def __init__(self):
        super().__init__()
        self.encoder = ComplexFeatureEncoder()
        # The encoder emits real, imaginary and magnitude features per output channel.
        encoder_width = 3 * self.encoder.out_channels
        self.seq_proj = nn.Sequential(
            nn.Linear(encoder_width, Config.COMPLEX_SEQ_DIM),
            nn.LayerNorm(Config.COMPLEX_SEQ_DIM),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
        )
        # Inter-layer dropout is only passed when the LSTM has more than one layer.
        recurrent_dropout = Config.DROPOUT if Config.COMPLEX_LSTM_LAYERS > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=Config.COMPLEX_SEQ_DIM,
            hidden_size=Config.COMPLEX_LSTM_HIDDEN,
            num_layers=Config.COMPLEX_LSTM_LAYERS,
            batch_first=True,
            bidirectional=Config.BIDIRECTIONAL,
            dropout=recurrent_dropout,
        )
        directions = 2 if Config.BIDIRECTIONAL else 1
        recurrent_width = Config.COMPLEX_LSTM_HIDDEN * directions
        self.attention = AttentionLayer(recurrent_width)
        self.center_fusion = nn.Sequential(
            nn.Linear(recurrent_width * 2 + Config.INPUT_DIM, Config.COMPLEX_HEAD_DIM),
            nn.LayerNorm(Config.COMPLEX_HEAD_DIM),
            nn.GELU(),
        )
        self.head = nn.Sequential(
            nn.Linear(Config.COMPLEX_HEAD_DIM, Config.HIDDEN_DIM),
            nn.LayerNorm(Config.HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(Config.HIDDEN_DIM, Config.HIDDEN_DIM // 2),
            nn.LayerNorm(Config.HIDDEN_DIM // 2),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(Config.HIDDEN_DIM // 2, 2),
        )
        self._init_weights()

    def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Delegate data-driven initialisation to the encoder."""
        self.encoder.initialize_from_data(train_x, train_y)

    def _init_weights(self):
        """Initialise the LSTM, then every ``nn.Linear`` / ``nn.LayerNorm`` under this module."""
        hidden = Config.COMPLEX_LSTM_HIDDEN
        for name, param in self.lstm.named_parameters():
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param.data)
            elif "weight_hh" in name:
                nn.init.orthogonal_(param.data)
            elif "bias" in name:
                nn.init.constant_(param.data, 0)
                # Second quarter of the bias vector (the forget gate in PyTorch's layout) is set to 1.
                # This applies to every parameter whose name contains "bias".
                param.data[hidden : 2 * hidden] = 1.0
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0.0)
            elif isinstance(layer, nn.LayerNorm):
                nn.init.constant_(layer.weight, 1.0)
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head's two outputs for the window ``x``."""
        raw, _, seq_features = self.encoder(x)
        projected = self.seq_proj(seq_features)
        recurrent, _ = self.lstm(projected)
        center_state = recurrent[:, Config.CONTEXT_K, :]
        pooled = self.attention(recurrent)
        window_center = raw[:, Config.CONTEXT_K, :]
        fused = self.center_fusion(torch.cat([center_state, pooled, window_center], dim=1))
        return self.head(fused)
|
|
1368
|
+
|
|
1369
|
+
|
|
1370
|
+
class ComplexCNNLSTMRxEqualizer(nn.Module):
|
|
1371
|
+
def __init__(self):
|
|
1372
|
+
super().__init__()
|
|
1373
|
+
self.encoder = ComplexFeatureEncoder()
|
|
1374
|
+
seq_in_dim = 3 * self.encoder.out_channels
|
|
1375
|
+
self.temporal_proj = nn.Conv1d(seq_in_dim, Config.COMPLEX_TEMPORAL_DIM, kernel_size=1, bias=False)
|
|
1376
|
+
self.temporal_blocks = nn.ModuleList(
|
|
1377
|
+
[TemporalConvBlock(Config.COMPLEX_TEMPORAL_DIM, dilation) for dilation in Config.COMPLEX_TEMPORAL_DILATIONS[:2]]
|
|
1378
|
+
)
|
|
1379
|
+
self.temporal_norm = nn.BatchNorm1d(Config.COMPLEX_TEMPORAL_DIM)
|
|
1380
|
+
self.lstm = nn.LSTM(
|
|
1381
|
+
input_size=Config.COMPLEX_TEMPORAL_DIM,
|
|
1382
|
+
hidden_size=Config.COMPLEX_LSTM_HIDDEN,
|
|
1383
|
+
num_layers=Config.COMPLEX_LSTM_LAYERS,
|
|
1384
|
+
batch_first=True,
|
|
1385
|
+
bidirectional=Config.BIDIRECTIONAL,
|
|
1386
|
+
dropout=Config.DROPOUT if Config.COMPLEX_LSTM_LAYERS > 1 else 0.0,
|
|
1387
|
+
)
|
|
1388
|
+
lstm_out_dim = Config.COMPLEX_LSTM_HIDDEN * (2 if Config.BIDIRECTIONAL else 1)
|
|
1389
|
+
self.attention = AttentionLayer(lstm_out_dim)
|
|
1390
|
+
self.center_fusion = nn.Sequential(
|
|
1391
|
+
nn.Linear(lstm_out_dim * 2 + Config.INPUT_DIM, Config.COMPLEX_HEAD_DIM),
|
|
1392
|
+
nn.LayerNorm(Config.COMPLEX_HEAD_DIM),
|
|
1393
|
+
nn.GELU(),
|
|
1394
|
+
)
|
|
1395
|
+
self.head = nn.Sequential(
|
|
1396
|
+
nn.Linear(Config.COMPLEX_HEAD_DIM, Config.HIDDEN_DIM),
|
|
1397
|
+
nn.LayerNorm(Config.HIDDEN_DIM),
|
|
1398
|
+
nn.GELU(),
|
|
1399
|
+
nn.Dropout(Config.DROPOUT),
|
|
1400
|
+
nn.Linear(Config.HIDDEN_DIM, Config.HIDDEN_DIM // 2),
|
|
1401
|
+
nn.LayerNorm(Config.HIDDEN_DIM // 2),
|
|
1402
|
+
nn.GELU(),
|
|
1403
|
+
nn.Dropout(Config.DROPOUT * 0.5),
|
|
1404
|
+
nn.Linear(Config.HIDDEN_DIM // 2, 2),
|
|
1405
|
+
)
|
|
1406
|
+
self._init_weights()
|
|
1407
|
+
|
|
1408
|
+
def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor):
|
|
1409
|
+
self.encoder.initialize_from_data(train_x, train_y)
|
|
1410
|
+
|
|
1411
|
+
def _init_weights(self):
|
|
1412
|
+
for name, param in self.lstm.named_parameters():
|
|
1413
|
+
if "weight_ih" in name:
|
|
1414
|
+
nn.init.xavier_uniform_(param.data)
|
|
1415
|
+
elif "weight_hh" in name:
|
|
1416
|
+
nn.init.orthogonal_(param.data)
|
|
1417
|
+
elif "bias" in name:
|
|
1418
|
+
nn.init.constant_(param.data, 0)
|
|
1419
|
+
gate = Config.COMPLEX_LSTM_HIDDEN
|
|
1420
|
+
param.data[gate : 2 * gate] = 1.0
|
|
1421
|
+
for module in self.modules():
|
|
1422
|
+
if isinstance(module, nn.Linear):
|
|
1423
|
+
nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
|
|
1424
|
+
if module.bias is not None:
|
|
1425
|
+
nn.init.constant_(module.bias, 0.0)
|
|
1426
|
+
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
|
|
1427
|
+
if hasattr(module, "weight") and module.weight is not None:
|
|
1428
|
+
nn.init.constant_(module.weight, 1.0)
|
|
1429
|
+
if hasattr(module, "bias") and module.bias is not None:
|
|
1430
|
+
nn.init.constant_(module.bias, 0.0)
|
|
1431
|
+
elif isinstance(module, nn.Conv1d):
|
|
1432
|
+
nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
|
|
1433
|
+
if module.bias is not None:
|
|
1434
|
+
nn.init.constant_(module.bias, 0.0)
|
|
1435
|
+
|
|
1436
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run encoder -> temporal blocks -> LSTM -> fusion -> head.

    The encoder's sequence features are transposed on dims 1 and 2 for
    ``temporal_proj`` / ``temporal_blocks`` / ``temporal_norm`` and transposed
    back before the LSTM.  The head input concatenates, along dim 1, the LSTM
    output at index ``Config.CONTEXT_K``, the attention summary of the LSTM
    output, and the encoder's raw output at the same index.
    """
    center_idx = Config.CONTEXT_K
    raw, _, seq_features = self.encoder(x)

    hidden = self.temporal_proj(seq_features.transpose(1, 2))
    for temporal_block in self.temporal_blocks:
        hidden = temporal_block(hidden)
    hidden = self.temporal_norm(hidden).transpose(1, 2)

    lstm_out, _ = self.lstm(hidden)
    pieces = [
        lstm_out[:, center_idx, :],
        self.attention(lstm_out),
        raw[:, center_idx, :],
    ]
    return self.head(self.center_fusion(torch.cat(pieces, dim=1)))
|
|
1447
|
+
|
|
1448
|
+
|
|
1449
|
+
class ComplexFastKANEqualizer(nn.Module):
    """``LightweightComplexEncoder`` followed by a two-output ``FastKANHead``."""

    def __init__(self):
        super().__init__()
        self.encoder = LightweightComplexEncoder()
        # Head input = three encoder-feature summaries (center, mean, std),
        # two raw summaries (center, mean) and one scalar power term.
        # NOTE(review): assumes the encoder's sequence features are
        # ``out_channels + 3`` wide and ``raw`` is ``Config.INPUT_DIM`` wide.
        encoder_width = self.encoder.out_channels + 3
        fused_dim = 3 * encoder_width + 2 * Config.INPUT_DIM + 1
        self.head = FastKANHead(
            input_dim=fused_dim,
            hidden_dim=Config.FASTKAN_HIDDEN_DIM,
            out_dim=2,
        )

    @staticmethod
    def _fuse(raw: torch.Tensor, seq_features: torch.Tensor) -> torch.Tensor:
        """Concatenate center/mean/std feature summaries with raw statistics."""
        center_idx = Config.CONTEXT_K
        feature_center = seq_features[:, center_idx, :]
        feature_mean = seq_features.mean(dim=1)
        feature_std = seq_features.std(dim=1, unbiased=False)
        raw_center = raw[:, center_idx, :]
        raw_mean = raw.mean(dim=1)
        # Squared magnitude of the center raw sample, kept as a column.
        center_power = raw_center.square().sum(dim=1, keepdim=True)
        parts = [feature_center, feature_mean, feature_std, raw_center, raw_mean, center_power]
        return torch.cat(parts, dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``, build the fused summary vector and apply the head."""
        raw, _, seq_features = self.encoder(x)
        return self.head(self._fuse(raw, seq_features))

    def regularization_loss(self) -> torch.Tensor:
        """Return the head's regularization term."""
        return self.head.regularization_loss()
|
|
1473
|
+
|
|
1474
|
+
|
|
1475
|
+
class ComplexFastKANClassifierEqualizer(nn.Module):
    """``LightweightComplexEncoder`` plus a ``FastKANHead`` emitting one logit per constellation point."""

    def __init__(self):
        super().__init__()
        self.encoder = LightweightComplexEncoder()
        # Same fused layout as the head input built in ``forward``: three
        # encoder-feature summaries, two raw summaries, one power scalar.
        # NOTE(review): assumes the encoder's sequence features are
        # ``out_channels + 3`` wide and ``raw`` is ``Config.INPUT_DIM`` wide.
        feature_width = self.encoder.out_channels + 3
        fused_dim = 3 * feature_width + 2 * Config.INPUT_DIM + 1
        num_classes = CONSTELLATION.size(0)
        self.head = FastKANHead(
            input_dim=fused_dim,
            hidden_dim=Config.FASTKAN_HIDDEN_DIM,
            out_dim=num_classes,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head's output for the fused feature summary of ``x``."""
        k = Config.CONTEXT_K
        raw, _, seq_features = self.encoder(x)

        summaries = [
            seq_features[:, k, :],
            seq_features.mean(dim=1),
            seq_features.std(dim=1, unbiased=False),
        ]
        raw_center = raw[:, k, :]
        summaries.append(raw_center)
        summaries.append(raw.mean(dim=1))
        # Squared magnitude of the center raw sample, kept as a column.
        summaries.append(raw_center.square().sum(dim=1, keepdim=True))

        return self.head(torch.cat(summaries, dim=1))

    def regularization_loss(self) -> torch.Tensor:
        """Return the head's regularization term."""
        return self.head.regularization_loss()
|
|
1499
|
+
|
|
1500
|
+
|
|
1501
|
+
class TransformerEncoderBlock(nn.Module):
    """Pre-norm residual block: self-attention, depthwise conv mixer, gated feed-forward.

    Args:
        dim: Embedding width of the input and output.
        heads: Number of attention heads.
        ff_dim: Width of the gated feed-forward branch (``ff_in`` produces
            ``2 * ff_dim`` features that are split into value and gate).
    """

    def __init__(self, dim: int, heads: int, ff_dim: int):
        super().__init__()
        drop_p = Config.DROPOUT
        conv_kernel = Config.TRANSFORMER_CONV_KERNEL

        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=heads,
            dropout=drop_p,
            batch_first=True,
        )

        self.norm2 = nn.LayerNorm(dim)
        # padding = kernel // 2 preserves the sequence length only for odd
        # kernels; an even kernel would break the residual addition below.
        depthwise = nn.Conv1d(
            dim,
            dim,
            kernel_size=conv_kernel,
            padding=conv_kernel // 2,
            groups=dim,
        )
        pointwise = nn.Conv1d(dim, dim, kernel_size=1)
        self.local_mixer = nn.Sequential(depthwise, nn.GELU(), pointwise)

        self.norm3 = nn.LayerNorm(dim)
        self.ff_in = nn.Linear(dim, ff_dim * 2)
        self.ff_out = nn.Linear(ff_dim, dim)
        self.dropout = nn.Dropout(drop_p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three residual sub-layers in order and return the result."""
        # 1) self-attention on the normalized input
        normed = self.norm1(x)
        with sdpa_context():
            attended, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + self.dropout(attended)

        # 2) convolutional mixing; Conv1d wants the feature axis on dim 1
        mixed = self.local_mixer(self.norm2(x).transpose(1, 2)).transpose(1, 2)
        x = x + self.dropout(mixed)

        # 3) gated feed-forward: value * GELU(gate)
        value, gate = self.ff_in(self.norm3(x)).chunk(2, dim=-1)
        projected = self.ff_out(value * F.gelu(gate))
        return x + self.dropout(projected)
|
|
1542
|
+
|
|
1543
|
+
|
|
1544
|
+
class TransformerRxEqualizer(nn.Module):
    """Transformer over the received window that regresses two output values.

    The flat input is viewed as ``(batch, Config.SEQ_LEN, Config.INPUT_DIM)``,
    projected to ``Config.TRANSFORMER_DIM``, offset by a learned positional
    embedding and passed through ``Config.TRANSFORMER_LAYERS`` encoder blocks.
    The token at ``Config.CONTEXT_K`` and the mean over all tokens are fused
    and fed to the regressor.
    """

    def __init__(self):
        super().__init__()
        dim = Config.TRANSFORMER_DIM
        width = Config.HIDDEN_DIM
        half_width = width // 2

        self.input_proj = nn.Linear(Config.INPUT_DIM, dim)
        self.pos_embedding = nn.Parameter(torch.zeros(1, Config.SEQ_LEN, dim))
        self.input_dropout = nn.Dropout(Config.DROPOUT * 0.5)

        encoder_layers = []
        for _ in range(Config.TRANSFORMER_LAYERS):
            encoder_layers.append(
                TransformerEncoderBlock(
                    dim=dim,
                    heads=Config.TRANSFORMER_HEADS,
                    ff_dim=Config.TRANSFORMER_FF_DIM,
                )
            )
        self.blocks = nn.ModuleList(encoder_layers)

        self.final_norm = nn.LayerNorm(dim)
        self.center_fusion = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.LayerNorm(dim),
            nn.GELU(),
        )
        self.regressor = nn.Sequential(
            nn.Linear(dim, width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(width, half_width),
            nn.LayerNorm(half_width),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(half_width, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Xavier-init linear/conv weights, zero biases, reset LayerNorms."""
        nn.init.normal_(self.pos_embedding, mean=0.0, std=0.02)
        for layer in self.modules():
            if isinstance(layer, nn.LayerNorm):
                nn.init.constant_(layer.bias, 0.0)
                nn.init.constant_(layer.weight, 1.0)
                continue
            if not isinstance(layer, (nn.Linear, nn.Conv1d)):
                continue
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the regressor output for a batch of flat windows."""
        tokens = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        tokens = self.input_proj(tokens)
        tokens = self.input_dropout(tokens + self.pos_embedding)
        for encoder_block in self.blocks:
            tokens = encoder_block(tokens)
        tokens = self.final_norm(tokens)

        center_token = tokens[:, Config.CONTEXT_K, :]
        pooled = tokens.mean(dim=1)
        fused = self.center_fusion(torch.cat([center_token, pooled], dim=1))
        return self.regressor(fused)
|
|
1602
|
+
|
|
1603
|
+
|
|
1604
|
+
class TCNResidualBlock(nn.Module):
    """Residual block: BatchNorm -> dilated depthwise conv -> gated pointwise conv.

    Args:
        channels: Number of input and output channels.
        kernel_size: Kernel size of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
    """

    def __init__(self, channels: int, kernel_size: int, dilation: int):
        super().__init__()
        # The sequence length is preserved only when dilation * (kernel_size - 1)
        # is even; otherwise the residual addition in ``forward`` would fail.
        same_padding = dilation * (kernel_size - 1) // 2

        self.norm = nn.BatchNorm1d(channels)
        self.depthwise = nn.Conv1d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=same_padding,
            dilation=dilation,
            groups=channels,
            bias=False,
        )
        # Twice the channels so the output can be split into value and gate.
        self.pointwise = nn.Conv1d(channels, channels * 2, kernel_size=1)
        self.out_proj = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(Config.DROPOUT * 0.5)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init the three convolutions and reset the BatchNorm affine."""
        convolutions = (self.depthwise, self.pointwise, self.out_proj)
        for conv in convolutions:
            nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")
            conv_bias = getattr(conv, "bias", None)
            if conv_bias is not None:
                nn.init.constant_(conv.bias, 0.0)
        nn.init.constant_(self.norm.weight, 1.0)
        nn.init.constant_(self.norm.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the dropout-regularized gated convolution branch."""
        branch = self.depthwise(self.norm(x))
        value, gate = self.pointwise(branch).chunk(2, dim=1)
        branch = self.out_proj(F.gelu(value) * torch.sigmoid(gate))
        return x + self.dropout(branch)
|
|
1638
|
+
|
|
1639
|
+
|
|
1640
|
+
class TCNRxEqualizer(nn.Module):
    """Dilated temporal CNN over the received window that regresses two outputs.

    The flat input is viewed as ``(batch, Config.SEQ_LEN, Config.INPUT_DIM)``
    and processed channels-first by a 1x1 stem and ``Config.TCN_LAYERS``
    residual blocks.  The head sees the hidden state at ``Config.CONTEXT_K``,
    the hidden mean over time, the raw center sample and the raw mean.
    """

    def __init__(self):
        super().__init__()
        hidden_dim = Config.TCN_HIDDEN_DIM
        layer_count = Config.TCN_LAYERS
        # Copy into a list so the padding below also works when
        # Config.TCN_DILATIONS is a tuple (tuple + list raises TypeError).
        dilations = list(Config.TCN_DILATIONS[:layer_count])
        if len(dilations) < layer_count:
            # Too few dilations configured: repeat the last one (or 1 if none).
            fill = dilations[-1] if dilations else 1
            dilations.extend([fill] * (layer_count - len(dilations)))

        self.stem = nn.Sequential(
            nn.Conv1d(Config.INPUT_DIM, hidden_dim, kernel_size=1, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU(),
        )
        self.blocks = nn.ModuleList(
            [TCNResidualBlock(hidden_dim, Config.TCN_KERNEL_SIZE, dilation) for dilation in dilations]
        )
        self.final_norm = nn.BatchNorm1d(hidden_dim)
        # center + global hidden summaries, plus raw center + raw mean
        fused_dim = hidden_dim * 2 + 2 * Config.INPUT_DIM
        self.head = nn.Sequential(
            nn.Linear(fused_dim, Config.HIDDEN_DIM),
            nn.LayerNorm(Config.HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(Config.HIDDEN_DIM, Config.HIDDEN_DIM // 2),
            nn.LayerNorm(Config.HIDDEN_DIM // 2),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(Config.HIDDEN_DIM // 2, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init linear/conv weights, zero biases, reset norm layers."""
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv1d)):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, (nn.BatchNorm1d, nn.LayerNorm)):
                # Norm layers built without affine parameters expose None here.
                if getattr(module, "weight", None) is not None:
                    nn.init.constant_(module.weight, 1.0)
                if getattr(module, "bias", None) is not None:
                    nn.init.constant_(module.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head output for a batch of flat windows."""
        raw = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        # Conv1d expects (batch, channels, time).
        hidden = self.stem(raw.transpose(1, 2))
        for block in self.blocks:
            hidden = block(hidden)
        hidden = self.final_norm(hidden)

        center = hidden[:, :, Config.CONTEXT_K]
        global_context = hidden.mean(dim=2)
        raw_center = raw[:, Config.CONTEXT_K, :]
        raw_mean = raw.mean(dim=1)
        return self.head(torch.cat([center, global_context, raw_center, raw_mean], dim=1))
|
|
1694
|
+
|
|
1695
|
+
|
|
1696
|
+
class MambaRxEqualizer(nn.Module):
    """Stack of pre-norm residual Mamba blocks over the received window.

    The flat input is viewed as ``(batch, Config.SEQ_LEN, Config.INPUT_DIM)``
    and projected to ``Config.MAMBA_DIM``.  The head sees the hidden state at
    ``Config.CONTEXT_K``, the hidden mean over the sequence and the raw
    center sample, and emits two values.

    Raises:
        ImportError: If ``mamba_ssm`` could not be imported (``Mamba is None``).
    """

    def __init__(self):
        super().__init__()
        if Mamba is None:
            raise ImportError(
                "mamba_ssm is unavailable. Install `mamba-ssm` to run the `mamba` model, "
                "or keep using `tcn`, `lstm`, and KAN baselines."
            )
        dim = Config.MAMBA_DIM
        self.input_proj = nn.Sequential(
            nn.Linear(Config.INPUT_DIM, dim),
            nn.LayerNorm(dim),
            nn.GELU(),
        )
        self.blocks = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "norm": nn.LayerNorm(dim),
                        "mamba": Mamba(
                            d_model=dim,
                            d_state=Config.MAMBA_D_STATE,
                            d_conv=Config.MAMBA_D_CONV,
                            expand=Config.MAMBA_EXPAND,
                        ),
                    }
                )
                for _ in range(Config.MAMBA_LAYERS)
            ]
        )
        self.final_norm = nn.LayerNorm(dim)
        self.head = nn.Sequential(
            nn.Linear(dim * 2 + Config.INPUT_DIM, Config.HIDDEN_DIM),
            nn.LayerNorm(Config.HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(Config.HIDDEN_DIM, Config.HIDDEN_DIM // 2),
            nn.LayerNorm(Config.HIDDEN_DIM // 2),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(Config.HIDDEN_DIM // 2, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Xavier-init linear weights, zero re-initializable biases, reset LayerNorms.

        ``self.modules()`` also yields the linear layers inside each Mamba
        block.  mamba_ssm marks its specially initialized ``dt_proj.bias``
        with ``_no_reinit = True``; that flag is honoured here so the
        time-step bias is not overwritten with zeros.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                bias = module.bias
                if bias is not None and not getattr(bias, "_no_reinit", False):
                    nn.init.constant_(bias, 0.0)
            elif isinstance(module, nn.LayerNorm):
                nn.init.constant_(module.weight, 1.0)
                nn.init.constant_(module.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head output for a batch of flat windows."""
        raw = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        hidden = self.input_proj(raw)
        for block in self.blocks:
            # Pre-norm residual connection around each Mamba mixer.
            hidden = hidden + block["mamba"](block["norm"](hidden))
        hidden = self.final_norm(hidden)

        center = hidden[:, Config.CONTEXT_K, :]
        global_context = hidden.mean(dim=1)
        raw_center = raw[:, Config.CONTEXT_K, :]
        return self.head(torch.cat([center, global_context, raw_center], dim=1))
|
|
1760
|
+
|
|
1761
|
+
|
|
1762
|
+
class MLPRxEqualizer(nn.Module):
    """MLP on the flat window plus a linear skip path, summed at the output.

    The skip path starts as an identity on the center sample and the MLP's
    last layer starts at zero, so a freshly built model returns the center
    input sample unchanged.
    """

    def __init__(self):
        super().__init__()
        flat_dim = Config.SEQ_LEN * Config.INPUT_DIM
        width = Config.HIDDEN_DIM
        self.skip = nn.Linear(flat_dim, 2)

        stack: List[nn.Module] = []
        fan_in = flat_dim
        # At least one hidden layer is always built.
        for _ in range(max(Config.MLP_LAYERS, 1)):
            stack += [nn.Linear(fan_in, width), nn.LayerNorm(width), nn.GELU()]
            fan_in = width
        stack.append(nn.Linear(fan_in, 2))
        self.net = nn.Sequential(*stack)
        self._init_weights()

    def _init_weights(self):
        """Identity skip on the center sample, zero output layer, Xavier hidden layers."""
        center_start = Config.CONTEXT_K * Config.INPUT_DIM
        center_cols = slice(center_start, center_start + Config.INPUT_DIM)
        with torch.no_grad():
            self.skip.weight.zero_()
            self.skip.bias.zero_()
            self.skip.weight[:, center_cols] = torch.eye(2)

        *hidden_linears, output_linear = [
            layer for layer in self.net.modules() if isinstance(layer, nn.Linear)
        ]
        for layer in hidden_linears:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0.0)
        # A zero output layer makes the MLP branch contribute nothing at start.
        nn.init.zeros_(output_linear.weight)
        nn.init.constant_(output_linear.bias, 0.0)

        for layer in self.net.modules():
            if isinstance(layer, nn.LayerNorm):
                nn.init.constant_(layer.bias, 0)
                nn.init.constant_(layer.weight, 1.0)

    @torch.no_grad()
    def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Fit the skip path by least squares from center samples to targets.

        Uses at most the first 262,144 rows and does nothing when fewer than
        three rows are available.  The fitted affine map overwrites the skip
        layer: center columns get the weights, all other columns are zeroed,
        and the intercept becomes the bias.
        """
        samples = min(train_x.size(0), 262_144)
        if samples < 3:
            return

        windows = train_x[:samples].view(samples, Config.SEQ_LEN, Config.INPUT_DIM)
        center = windows[:, Config.CONTEXT_K, :].float()
        target = train_y[:samples].float()
        # A trailing column of ones lets lstsq fit the intercept as well.
        intercept_col = torch.ones(samples, 1, dtype=center.dtype, device=center.device)
        design = torch.cat([center, intercept_col], dim=1)
        solution = torch.linalg.lstsq(design, target).solution

        weight = self.skip.weight
        bias = self.skip.bias
        weight.zero_()
        bias.copy_(solution[-1].to(bias.device, dtype=bias.dtype))
        center_start = Config.CONTEXT_K * Config.INPUT_DIM
        fitted = solution[: Config.INPUT_DIM].T.to(weight.device, dtype=weight.dtype)
        weight[:, center_start : center_start + Config.INPUT_DIM].copy_(fitted)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sum of the linear skip path and the MLP branch."""
        linear_part = self.skip(x)
        return linear_part + self.net(x)
|
|
1824
|
+
|
|
1825
|
+
|
|
1826
|
+
def log(msg: str) -> None:
    """Print ``msg`` on standard output and flush the stream immediately."""
    print(msg, flush=True)
|
|
1828
|
+
|
|
1829
|
+
|
|
1830
|
+
# One-line description of the data flow of each model, keyed by model name.
MODEL_NOTES: Dict[str, str] = {
    # KAN regressors
    "efficient_kan_baseline": "flat IQ window -> KAN -> corrected I/Q",
    "efficient_kan_residual": "flat IQ window -> KAN correction -> rx center + correction",
    "efficient_kan_features": "handcrafted local/global IQ features -> KAN -> corrected I/Q",
    "cnn_kan": "CNN extracts temporal features -> KAN head -> corrected I/Q",
    # constellation classifiers
    "kan_classifier": "flat IQ window -> KAN -> 16 constellation logits",
    "fastkan_classifier": "flat IQ window -> compact RBF/FastKAN -> 16 constellation logits",
    "complex_fastkan_classifier": "light complex encoder -> RBF/FastKAN -> 16 constellation logits",
    # non-KAN regressors
    "mlp": "flat IQ window -> MLP -> corrected I/Q",
    "cnn": "CNN extracts temporal features -> MLP head -> corrected I/Q",
    "tcn": "dilated temporal CNN over IQ window -> corrected I/Q",
    "mamba": "Mamba sequence blocks over IQ window -> corrected I/Q",
}
|
|
1843
|
+
|
|
1844
|
+
|
|
1845
|
+
def count_trainable_parameters(model: nn.Module) -> int:
    """Return the number of scalar parameters of ``model`` that require gradients."""
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total
|
|
1847
|
+
|
|
1848
|
+
|
|
1849
|
+
def build_criterion() -> nn.Module:
    """Return the regression loss selected by ``Config.LOSS``.

    ``"smooth_l1"`` selects ``nn.SmoothL1Loss(beta=0.05)``; every other value
    falls back to ``nn.MSELoss``.
    """
    use_smooth_l1 = Config.LOSS == "smooth_l1"
    criterion = nn.SmoothL1Loss(beta=0.05) if use_smooth_l1 else nn.MSELoss()
    return criterion
|
|
1853
|
+
|
|
1854
|
+
|
|
1855
|
+
def is_classifier_output(preds: torch.Tensor) -> bool:
    """Return True when ``preds`` is 2-D with one column per constellation point."""
    if preds.dim() != 2:
        return False
    return preds.size(1) == CONSTELLATION.size(0)
|
|
1857
|
+
|
|
1858
|
+
|
|
1859
|
+
def prediction_loss(preds: torch.Tensor, targets: torch.Tensor, criterion: nn.Module) -> torch.Tensor:
    """Compute the training loss for regression or classification outputs.

    Classifier-shaped predictions are scored with cross-entropy against the
    constellation class nearest to each target symbol; anything else is
    passed to ``criterion`` unchanged.
    """
    if not is_classifier_output(preds):
        return criterion(preds, targets)
    labels = symbols_to_classes(targets.float())
    logits = preds.float()
    return F.cross_entropy(logits, labels)
|
|
1864
|
+
|
|
1865
|
+
|
|
1866
|
+
def build_optimizer(model: nn.Module) -> optim.Optimizer:
    """Build the optimizer selected by ``Config.OPTIMIZER`` for ``model``.

    ``"adam"`` (case-insensitive) selects ``optim.Adam``; every other value
    selects ``optim.RMSprop``.  On a CUDA device the fused implementation is
    requested first and silently dropped if the optimizer rejects it.

    Args:
        model: Module whose parameters are optimized.

    Returns:
        The constructed optimizer, using ``Config.LEARNING_RATE`` and
        ``Config.WEIGHT_DECAY``.
    """
    optimizer_kwargs = {"lr": Config.LEARNING_RATE, "weight_decay": Config.WEIGHT_DECAY}
    if Config.DEVICE.type == "cuda":
        optimizer_kwargs["fused"] = True

    optimizer_name = Config.OPTIMIZER.lower()
    optimizer_cls = optim.Adam if optimizer_name == "adam" else optim.RMSprop

    try:
        return optimizer_cls(model.parameters(), **optimizer_kwargs)
    except (TypeError, RuntimeError):
        # TypeError: the optimizer class has no ``fused`` keyword.
        # RuntimeError: ``fused=True`` was rejected for these parameters
        # (for example they are not on a supported device).
        if "fused" not in optimizer_kwargs:
            raise
        optimizer_kwargs.pop("fused")
        # Genuine configuration errors resurface from this second attempt.
        return optimizer_cls(model.parameters(), **optimizer_kwargs)
|
|
1877
|
+
|
|
1878
|
+
|
|
1879
|
+
def compute_model_regularization(model: nn.Module) -> torch.Tensor:
    """Return the model's regularization term scaled by ``Config.KAN_PRUNE_L1``.

    A zero scalar on ``Config.DEVICE`` is returned when the coefficient is not
    positive or the model has no ``regularization_loss`` method.
    """
    prune_weight = Config.KAN_PRUNE_L1
    if prune_weight > 0 and hasattr(model, "regularization_loss"):
        penalty = model.regularization_loss()
        if not torch.is_tensor(penalty):
            # Plain Python numbers are promoted to a tensor on the run device.
            penalty = torch.tensor(float(penalty), device=Config.DEVICE)
        return penalty * Config.KAN_PRUNE_L1
    return torch.zeros((), device=Config.DEVICE)
|
|
1886
|
+
|
|
1887
|
+
|
|
1888
|
+
def complex_unit_response(response: torch.Tensor, order: int) -> torch.Tensor:
    """Return the element-wise ``order``-th root of a complex tensor.

    The magnitude is floored at 1e-6 before the root is taken and the phase
    is divided by ``order``.
    """
    floored_abs = torch.clamp(torch.abs(response), min=1e-6)
    root_abs = torch.pow(floored_abs, 1.0 / order)
    root_arg = torch.angle(response) / order
    return torch.polar(root_abs, root_arg)
|
|
1892
|
+
|
|
1893
|
+
|
|
1894
|
+
def centered_complex_kernel_to_frequency(kernel: torch.Tensor, fft_size: int) -> torch.Tensor:
    """Return the ``fft_size``-point FFT of a kernel whose center tap is at ``numel // 2``.

    The kernel is rotated so the center tap lands at index 0 before the
    transform.
    """
    half = kernel.numel() // 2
    return torch.fft.fft(kernel.roll(-half, 0), n=fft_size)
|
|
1898
|
+
|
|
1899
|
+
|
|
1900
|
+
def frequency_to_centered_complex_kernel(response: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Inverse-FFT ``response`` and return its first ``seq_len`` taps, re-centered.

    The kept taps are rotated right by ``seq_len // 2`` so that the tap at
    time index 0 ends up in the middle of the returned kernel.
    """
    impulse = torch.fft.ifft(response, n=response.numel())
    leading_taps = impulse[:seq_len]
    return leading_taps.roll(seq_len // 2, 0)
|
|
1904
|
+
|
|
1905
|
+
|
|
1906
|
+
def extract_kernel_from_centered_response(
    kernel: torch.Tensor,
    kernel_size: int,
    symmetric: bool,
) -> torch.Tensor:
    """Cut ``kernel_size`` taps out of a centered kernel.

    With ``symmetric=True`` the slice starts at the center tap
    (``numel // 2``); otherwise it starts ``kernel_size // 2`` taps before
    the center.  The result is returned as a contiguous tensor.
    """
    mid = kernel.numel() // 2
    first = mid if symmetric else mid - kernel_size // 2
    return kernel[first : first + kernel_size].contiguous()
|
|
1917
|
+
|
|
1918
|
+
|
|
1919
|
+
def assign_complex_kernel(module, kernel: torch.Tensor):
    """Copy the real and imaginary parts of ``kernel`` into ``module``'s two convolutions.

    Each part is reshaped to the target weight's shape and cast to its dtype
    before being copied in place into ``module.real_conv.weight`` and
    ``module.imag_conv.weight``.
    """
    real_weight = module.real_conv.weight
    real_weight.copy_(kernel.real.view_as(real_weight).to(real_weight.dtype))

    imag_weight = module.imag_conv.weight
    imag_weight.copy_(kernel.imag.view_as(imag_weight).to(imag_weight.dtype))
|
|
1922
|
+
|
|
1923
|
+
|
|
1924
|
+
def is_cuda_oom(error: RuntimeError) -> bool:
    """Return True when the error message contains "out of memory" (case-insensitive)."""
    message = str(error).lower()
    return message.find("out of memory") != -1
|
|
1926
|
+
|
|
1927
|
+
|
|
1928
|
+
def mark_cudagraph_step_begin():
    """Call ``torch.compiler.cudagraph_mark_step_begin()`` when running on CUDA.

    Does nothing on other devices or when the installed torch build does not
    provide that function.
    """
    if Config.DEVICE.type == "cuda":
        compiler_mod = getattr(torch, "compiler", None)
        marker = getattr(compiler_mod, "cudagraph_mark_step_begin", None)
        if compiler_mod is not None and marker is not None:
            marker()
|
|
1934
|
+
|
|
1935
|
+
|
|
1936
|
+
def autocast_context():
    """Return a CUDA float16 autocast context.

    The context is enabled only when ``Config.DEVICE`` is a CUDA device and
    ``Config.USE_AMP`` is set; otherwise it is created disabled.
    """
    amp_enabled = Config.DEVICE.type == "cuda" and Config.USE_AMP
    return torch.autocast(device_type="cuda", dtype=torch.float16, enabled=amp_enabled)
|
|
1942
|
+
|
|
1943
|
+
|
|
1944
|
+
def sdpa_context():
    """Return a context manager restricting scaled-dot-product attention to the math backend.

    On non-CUDA devices a ``nullcontext`` is returned.  On CUDA,
    ``torch.nn.attention.sdpa_kernel`` is used when available, then the
    ``torch.backends.cuda.sdp_kernel`` API, then a ``nullcontext``.
    """
    if Config.DEVICE.type != "cuda":
        return nullcontext()

    attention_mod = getattr(torch.nn, "attention", None)
    select_kernel = getattr(attention_mod, "sdpa_kernel", None)
    has_backend_enum = hasattr(attention_mod, "SDPBackend")
    if attention_mod is not None and select_kernel is not None and has_backend_enum:
        return select_kernel([attention_mod.SDPBackend.MATH])

    cuda_backends = getattr(torch.backends, "cuda", None)
    if cuda_backends is not None and hasattr(cuda_backends, "sdp_kernel"):
        return cuda_backends.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True)

    return nullcontext()
|
|
1954
|
+
|
|
1955
|
+
|
|
1956
|
+
def symbols_to_classes(symbols: torch.Tensor) -> torch.Tensor:
    """Return, per symbol, the index of the nearest ``CONSTELLATION`` point.

    Nearness is squared Euclidean distance summed over dim 2 of the pairwise
    difference tensor.
    """
    points = CONSTELLATION.to(symbols.device)
    # Broadcast every symbol against every constellation point.
    offsets = symbols[:, None] - points[None]
    squared_distance = offsets.square().sum(dim=2)
    return squared_distance.argmin(dim=1)
|
|
1961
|
+
|
|
1962
|
+
|
|
1963
|
+
def calculate_ber_from_classes(tx_classes: torch.Tensor, rx_classes: torch.Tensor) -> float:
    """Return the fraction of differing bits between transmitted and received classes.

    Both class tensors are mapped to bit patterns through ``BIT_LABELS`` and
    the mean mismatch over all bits is returned as a Python float.
    """
    labels = BIT_LABELS.to(tx_classes.device)
    bit_errors = torch.ne(labels[tx_classes], labels[rx_classes])
    return bit_errors.float().mean().item()
|
|
1968
|
+
|
|
1969
|
+
|
|
1970
|
+
def compute_rms_scale(symbols: torch.Tensor) -> torch.Tensor:
    """Return the RMS magnitude of the rows of ``symbols``, floored at 1e-6.

    The squared entries are summed over dim 1, averaged over the remaining
    elements and square-rooted.
    """
    per_symbol_power = torch.square(symbols).sum(dim=1)
    rms = torch.sqrt(torch.mean(per_symbol_power))
    return torch.clamp(rms, min=1e-6)
|
|
1973
|
+
|
|
1974
|
+
|
|
1975
|
+
def power_normalize_pair(
    tx_symbols: torch.Tensor,
    rx_symbols: torch.Tensor,
    tx_scale: Optional[torch.Tensor] = None,
    rx_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Divide both symbol tensors by their scale and return the scales used.

    Args:
        tx_symbols: Transmitted symbols.
        rx_symbols: Received symbols.
        tx_scale: Scale for ``tx_symbols``; computed with
            ``compute_rms_scale`` when ``None``.
        rx_scale: Scale for ``rx_symbols``; computed with
            ``compute_rms_scale`` when ``None``.

    Returns:
        ``(tx_symbols / tx_scale, rx_symbols / rx_scale, tx_scale, rx_scale)``.
    """
    tx_ref = compute_rms_scale(tx_symbols) if tx_scale is None else tx_scale
    rx_ref = compute_rms_scale(rx_symbols) if rx_scale is None else rx_scale
    normalized_tx = tx_symbols / tx_ref
    normalized_rx = rx_symbols / rx_ref
    return normalized_tx, normalized_rx, tx_ref, rx_ref
|
|
1986
|
+
|
|
1987
|
+
|
|
1988
|
+
def find_best_symbol_scale(tx_symbols: torch.Tensor, rx_symbols: torch.Tensor) -> float:
    """Search for the scalar gain on ``rx_symbols`` that minimizes hard-decision BER.

    A golden-section search runs over a window of symbols starting at
    ``Config.BER_SCALE_OFFSET`` and holding at most
    ``Config.BER_SCALE_SAMPLES`` symbols.  The search interval is
    ``[0.5, 1.5]`` times the tx/rx RMS ratio, clipped to
    ``[Config.BER_SCALE_MIN, Config.BER_SCALE_MAX]``.

    Returns:
        The midpoint of the final bracket, or ``1.0`` when the search is
        disabled (``Config.BER_SCALE_SEARCH``) or either input is empty.
    """
    if not Config.BER_SCALE_SEARCH or tx_symbols.numel() == 0 or rx_symbols.numel() == 0:
        return 1.0

    total = min(tx_symbols.size(0), rx_symbols.size(0))
    offset = min(Config.BER_SCALE_OFFSET, max(total - 1, 0))
    available = max(total - offset, 1)
    sample_count = min(available, Config.BER_SCALE_SAMPLES)
    tx_eval = tx_symbols[offset : offset + sample_count]
    rx_eval = rx_symbols[offset : offset + sample_count]

    # The RMS ratio is the gain that equalizes average power; search around it.
    center = (compute_rms_scale(tx_eval) / compute_rms_scale(rx_eval)).item()
    left = max(Config.BER_SCALE_MIN, 0.5 * center)
    right = min(Config.BER_SCALE_MAX, 1.5 * center)
    phi = (1 + 5**0.5) / 2

    # The reference decisions do not depend on the candidate scale, so they
    # are classified once instead of on every objective evaluation.
    target_classes = symbols_to_classes(tx_eval.float())

    def objective(scale: float) -> float:
        pred_classes = symbols_to_classes((rx_eval * scale).float())
        return calculate_ber_from_classes(target_classes, pred_classes)

    x1 = right - (right - left) / phi
    x2 = left + (right - left) / phi
    y1 = objective(x1)
    y2 = objective(x2)
    for _ in range(Config.BER_SCALE_STEPS):
        if y1 >= y2:
            # Minimum lies in [x1, right]: drop the left part, reuse x2 as x1.
            left = x1
            x1, y1 = x2, y2
            x2 = left + (right - left) / phi
            y2 = objective(x2)
        else:
            # Minimum lies in [left, x2]: drop the right part, reuse x1 as x2.
            right = x2
            x2, y2 = x1, y1
            x1 = right - (right - left) / phi
            y1 = objective(x1)
    return float((x1 + x2) * 0.5)
|
|
2029
|
+
|
|
2030
|
+
|
|
2031
|
+
def discover_symbol_files() -> Tuple[Path, List[int]]:
    """Locate the first candidate directory that holds numbered symbol files.

    Directories in ``Config.DATA_DIR_CANDIDATES`` are tried in order.  Files
    matching ``Symbols_1m_1ch_PR_*.csv`` whose trailing ``_``-separated stem
    component is not an integer are ignored instead of aborting discovery.

    Returns:
        The directory and its file indices in ascending order, truncated to
        ``Config.MAX_FILES`` entries.

    Raises:
        FileNotFoundError: If no candidate directory has a usable file.
    """
    for base_dir in Config.DATA_DIR_CANDIDATES:
        if not base_dir.exists():
            continue
        indices: List[int] = []
        for path in base_dir.glob("Symbols_1m_1ch_PR_*.csv"):
            try:
                indices.append(int(path.stem.split("_")[-1]))
            except ValueError:
                # e.g. "Symbols_1m_1ch_PR_backup.csv": matches the glob but
                # carries no numeric index.
                continue
        if indices:
            indices.sort()
            return base_dir, indices[: Config.MAX_FILES]
    raise FileNotFoundError("No Symbols_1m_1ch_PR_*.csv files found.")
|
|
2040
|
+
|
|
2041
|
+
|
|
2042
|
+
def resolve_splits(file_indices: List[int]) -> Tuple[List[int], List[int], List[int]]:
    """Partition file indices into train, validation and test lists.

    The first ``Config.TRAIN_PORTION`` share of the (optionally shuffled)
    files, at least two and never all of them, forms the train+val pool; the
    rest is the test set.  The tail of the pool becomes the validation set,
    sized by ``Config.VAL_PORTION_WITHIN_TRAIN`` and
    ``Config.MIN_VAL_FILES`` while leaving at least one training file.

    Raises:
        ValueError: If the pool has fewer than two files or no test file
            remains.
    """
    ordered = list(file_indices)
    if Config.RANDOMIZE_FILE_SPLIT:
        shuffler = np.random.default_rng(Config.SPLIT_SEED)
        ordered = shuffler.permutation(ordered).tolist()

    file_count = len(ordered)
    # At least two pool files are requested, and at least one file is
    # always left over for testing.
    pool_size = max(2, int(file_count * Config.TRAIN_PORTION))
    pool_size = min(pool_size, file_count - 1)
    pool = ordered[:pool_size]
    if len(pool) < 2:
        raise ValueError("Need at least 2 train+val files so validation can be held out.")

    wanted_val = int(round(len(pool) * Config.VAL_PORTION_WITHIN_TRAIN))
    val_count = max(Config.MIN_VAL_FILES, wanted_val)
    val_count = min(max(val_count, 1), len(pool) - 1)

    boundary = len(pool) - val_count
    train_idx = pool[:boundary]
    val_idx = pool[boundary:]
    test_idx = ordered[pool_size:]
    if not test_idx:
        raise ValueError("File-level split produced no test files. Increase MAX_FILES or lower TRAIN_PORTION.")
    return train_idx, val_idx, test_idx
|
|
2065
|
+
|
|
2066
|
+
|
|
2067
|
+
def load_symbol_file(base_dir: Path, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Read ``Symbols_1m_1ch_PR_{idx}.csv`` and split its columns into two tensors.

    The header-less CSV is parsed as float32.  Columns 0-1 form the first
    returned tensor and columns 2-3 the second; callers in this file treat
    them as transmitted and received symbols respectively.
    """
    csv_path = base_dir / f"Symbols_1m_1ch_PR_{idx}.csv"
    frame = pd.read_csv(csv_path, header=None, dtype=np.float32)
    values = frame.to_numpy(copy=False)
    # Copy each column slice so both tensors own contiguous memory.
    first_pair = torch.from_numpy(values[:, 0:2].copy())
    second_pair = torch.from_numpy(values[:, 2:4].copy())
    return first_pair, second_pair
|
|
2071
|
+
|
|
2072
|
+
|
|
2073
|
+
def load_files(base_dir: Path, file_indices: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load the given symbol files and concatenate them along dim 0.

    Returns:
        The concatenated first and second tensors of every
        ``load_symbol_file`` result, in ``file_indices`` order.
    """
    loaded = [load_symbol_file(base_dir, idx) for idx in file_indices]
    tx_list: List[torch.Tensor] = [tx_part for tx_part, _ in loaded]
    rx_list: List[torch.Tensor] = [rx_part for _, rx_part in loaded]
    return torch.cat(tx_list, dim=0), torch.cat(rx_list, dim=0)
|
|
2081
|
+
|
|
2082
|
+
|
|
2083
|
+
def load_file_splits(base_dir: Path, file_indices: List[int]) -> List[Dict[str, Any]]:
    """Load each symbol file into its own record, keeping files separate.

    Returns:
        One dict per index with keys ``"file_idx"``, ``"tx"`` and ``"rx"``,
        in ``file_indices`` order.
    """
    loaded = ((idx, load_symbol_file(base_dir, idx)) for idx in file_indices)
    return [
        {"file_idx": idx, "tx": pair[0], "rx": pair[1]}
        for idx, pair in loaded
    ]
|
|
2089
|
+
|
|
2090
|
+
|
|
2091
|
+
def normalize_file_splits(
    files: List[Dict[str, Any]],
    tx_scale: torch.Tensor,
    rx_scale: torch.Tensor,
) -> List[Dict[str, Any]]:
    """Return new file records with ``tx`` and ``rx`` divided by the given scales.

    The input records are left untouched; each output record keeps the
    original ``"file_idx"``.
    """
    scaled: List[Dict[str, Any]] = []
    for record in files:
        entry = {"file_idx": record["file_idx"]}
        entry["tx"] = record["tx"] / tx_scale
        entry["rx"] = record["rx"] / rx_scale
        scaled.append(entry)
    return scaled
|
|
2104
|
+
|
|
2105
|
+
|
|
2106
|
+
def make_windows(rx_symbols: torch.Tensor, tx_symbols: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build flattened sliding windows of standardized rx symbols with their tx targets.

    ``rx_symbols`` is standardized with ``mean`` / ``std`` (non-finite values
    become 0), cut into stride-1 windows of ``Config.SEQ_LEN`` along dim 0 and
    flattened per window.  The targets are ``tx_symbols`` with
    ``Config.CONTEXT_K`` entries trimmed from each end.

    NOTE(review): input and target counts only match when
    ``Config.SEQ_LEN == 2 * Config.CONTEXT_K + 1`` and both tensors have the
    same length - confirm against Config.
    """
    k = Config.CONTEXT_K
    standardized = torch.nan_to_num((rx_symbols - mean) / std, nan=0.0, posinf=0.0, neginf=0.0)
    # unfold appends the window axis last; move it ahead of the feature axis.
    windows = standardized.unfold(0, Config.SEQ_LEN, 1).permute(0, 2, 1).contiguous()
    x = windows.flatten(start_dim=1).contiguous()
    y = tx_symbols[k : tx_symbols.size(0) - k].contiguous()
    return x, y
|
|
2113
|
+
|
|
2114
|
+
|
|
2115
|
+
def make_windows_for_files(
    files: List[Dict[str, Any]],
    mean: torch.Tensor,
    std: torch.Tensor,
    collect_spans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
    """Window every file separately and concatenate the results.

    Files with fewer than ``Config.SEQ_LEN`` received symbols are skipped, so
    no window crosses a file boundary.

    Args:
        files: Records with ``"file_idx"``, ``"tx"`` and ``"rx"`` entries.
        mean: Standardization mean forwarded to ``make_windows``.
        std: Standardization scale forwarded to ``make_windows``.
        collect_spans: When true, also record for each file its row range in
            the concatenated output and its trimmed tx/rx symbols.

    Returns:
        ``(x, y, file_spans)``; ``file_spans`` is empty when ``collect_spans``
        is false.

    Raises:
        ValueError: If every file is too short to produce a window.
    """
    k = Config.CONTEXT_K
    inputs: List[torch.Tensor] = []
    targets: List[torch.Tensor] = []
    file_spans: List[Dict[str, Any]] = []
    cursor = 0

    for record in files:
        tx_symbols, rx_symbols = record["tx"], record["rx"]
        if rx_symbols.size(0) >= Config.SEQ_LEN:
            x_file, y_file = make_windows(rx_symbols, tx_symbols, mean, std)
            rows = x_file.size(0)
            if collect_spans:
                span = {
                    "file_idx": record["file_idx"],
                    "start": cursor,
                    "end": cursor + rows,
                    "tx_center": tx_symbols[k : tx_symbols.size(0) - k].contiguous(),
                    "rx_center": rx_symbols[k : rx_symbols.size(0) - k].contiguous(),
                }
                file_spans.append(span)
            inputs.append(x_file)
            targets.append(y_file)
            cursor += rows

    if len(inputs) == 0:
        raise ValueError("No file is long enough to build context windows.")
    return torch.cat(inputs, dim=0), torch.cat(targets, dim=0), file_spans
|
|
2150
|
+
|
|
2151
|
+
|
|
2152
|
+
def prepare_data(max_test_files: Optional[int] = None) -> Dict[str, Any]:
    """Load the symbol files, split them by file and build windowed tensors.

    Scaling factors (power normalisation) and the ``mean``/``std`` statistics
    are computed from the training files only and then applied to the
    validation and test files.

    Args:
        max_test_files: Optional cap on the number of test files kept (the
            first ``max_test_files`` entries of the test split).

    Returns:
        A dict with the windowed ``train``/``val``/``test`` tensors, the
        per-file spans for ``val`` and ``test``, the raw and trimmed test
        symbols, and the normalisation statistics/scales.

    Raises:
        ValueError: If ``max_test_files`` is negative or leaves no test file.
    """
    base_dir, all_indices = discover_symbol_files()
    train_idx, val_idx, test_idx = resolve_splits(all_indices)
    if max_test_files is not None:
        # A negative value would slice from the end and silently drop the
        # trailing files instead of capping the count.
        if max_test_files < 0:
            raise ValueError(f"max_test_files must be non-negative, got {max_test_files}")
        test_idx = test_idx[:max_test_files]
        if not test_idx:
            raise ValueError("max_test_files truncated test split to zero files.")
    log(f"Data dir: {base_dir}")
    log(
        "Split protocol: train/val/test are separated by file; "
        "test files are never used for fitting, checkpoint selection, or normalization statistics."
    )
    if Config.RANDOMIZE_FILE_SPLIT:
        log(f"File split randomized with SPLIT_SEED={Config.SPLIT_SEED}")
    log(f"Train files: {train_idx}")
    log(f"Val files: {val_idx}")
    log(f"Test files: {test_idx}")

    train_files = load_file_splits(base_dir, train_idx)
    val_files = load_file_splits(base_dir, val_idx)
    test_files = load_file_splits(base_dir, test_idx)
    tx_train_raw = torch.cat([item["tx"] for item in train_files], dim=0)
    rx_train_raw = torch.cat([item["rx"] for item in train_files], dim=0)

    if Config.POWER_NORMALIZE:
        # Only the scales are kept; the same factors are applied to every split.
        _, _, tx_scale, rx_scale = power_normalize_pair(tx_train_raw, rx_train_raw)
    else:
        tx_scale = torch.tensor(1.0, dtype=tx_train_raw.dtype)
        rx_scale = torch.tensor(1.0, dtype=rx_train_raw.dtype)

    train_files = normalize_file_splits(train_files, tx_scale, rx_scale)
    val_files = normalize_file_splits(val_files, tx_scale, rx_scale)
    test_files = normalize_file_splits(test_files, tx_scale, rx_scale)
    rx_train = torch.cat([item["rx"] for item in train_files], dim=0)
    tx_test = torch.cat([item["tx"] for item in test_files], dim=0)
    rx_test = torch.cat([item["rx"] for item in test_files], dim=0)

    mean = rx_train.mean(dim=0, keepdim=True)
    std = rx_train.std(dim=0, keepdim=True)
    std[std == 0] = 1.0  # constant features would otherwise divide by zero

    train_x, train_y, _ = make_windows_for_files(train_files, mean, std, collect_spans=False)
    val_x, val_y, val_file_spans = make_windows_for_files(val_files, mean, std)
    test_x, test_y, test_file_spans = make_windows_for_files(test_files, mean, std)
    log(
        "Window samples (built per file, no cross-file context): "
        f"train {train_x.size(0):,} | val {val_x.size(0):,} | test {test_x.size(0):,}"
    )
    rx_test_center = torch.cat([item["rx_center"] for item in test_file_spans], dim=0)
    tx_test_center = torch.cat([item["tx_center"] for item in test_file_spans], dim=0)

    return {
        "train_x": train_x,
        "train_y": train_y,
        "val_x": val_x,
        "val_y": val_y,
        "test_x": test_x,
        "test_y": test_y,
        "val_file_spans": val_file_spans,
        "test_file_spans": test_file_spans,
        "tx_test": tx_test,
        "rx_test": rx_test,
        "tx_test_center": tx_test_center,
        "rx_test_center": rx_test_center,
        "mean": mean,
        "std": std,
        "tx_scale": tx_scale,
        "rx_scale": rx_scale,
    }
|
|
2222
|
+
|
|
2223
|
+
|
|
2224
|
+
def make_model(name: str) -> nn.Module:
    """Instantiate the equalizer registered under ``name`` on ``Config.DEVICE``.

    Several architectures accept more than one alias. The factories are
    wrapped in lambdas so that a class is only looked up when its alias is
    actually requested.

    Raises:
        ValueError: If ``name`` matches no known alias.
    """
    registry = (
        ({"efficient_kan_baseline", "efficient_kan", "ekan"}, lambda: EfficientKANBaselineEqualizer()),
        ({"efficient_kan_residual", "kan_residual", "residual_kan"}, lambda: EfficientKANResidualEqualizer()),
        ({"efficient_kan_features", "kan_features", "feature_kan"}, lambda: EfficientKANFeatureEqualizer()),
        ({"cnn_kan", "kan_cnn"}, lambda: CNNKANEqualizer()),
        ({"kan_classifier", "efficient_kan_classifier"}, lambda: EfficientKANClassifierEqualizer()),
        ({"fastkan_classifier", "rbf_kan_classifier"}, lambda: FastKANClassifierEqualizer()),
        ({"cnn"}, lambda: CNNRxEqualizer()),
        ({"lstm"}, lambda: LSTMRxEqualizer()),
        ({"hybrid", "cnn_lstm"}, lambda: HybridCNNLSTMEqualizer()),
        ({"complex_fastkan", "fastkan", "kan"}, lambda: ComplexFastKANEqualizer()),
        ({"complex_fastkan_classifier", "complex_rbf_kan_classifier"}, lambda: ComplexFastKANClassifierEqualizer()),
        ({"complex_lstm"}, lambda: ComplexLSTMRxEqualizer()),
        ({"complex_dbp_seqstat"}, lambda: ComplexDBPSeqStatRxEqualizer()),
        ({"complex_cnn_lstm"}, lambda: ComplexCNNLSTMRxEqualizer()),
        ({"complex_cnn"}, lambda: ComplexCNNRxEqualizer()),
        ({"transformer"}, lambda: TransformerRxEqualizer()),
        ({"tcn"}, lambda: TCNRxEqualizer()),
        ({"mamba"}, lambda: MambaRxEqualizer()),
        ({"mlp"}, lambda: MLPRxEqualizer()),
    )
    for aliases, build in registry:
        if name in aliases:
            return build().to(Config.DEVICE)
    raise ValueError(f"Unknown model: {name}")
|
|
2264
|
+
|
|
2265
|
+
|
|
2266
|
+
def iter_tensor_batches(x: torch.Tensor, y: torch.Tensor, batch_size: int, shuffle: bool):
    """Yield ``(x, y)`` mini-batches moved to ``Config.DEVICE``.

    Args:
        x: Inputs, batched along dim 0.
        y: Targets, indexed in lockstep with ``x``.
        batch_size: Maximum number of rows per batch; must be positive.
        shuffle: When true, rows are permuted (same order for ``x`` and
            ``y``) before batching.

    Yields:
        ``(xb, yb)`` pairs; the final batch may be smaller than
        ``batch_size``.

    Raises:
        ValueError: On first iteration, if ``batch_size`` is not positive.
    """
    # A negative step would make range() empty and silently yield nothing.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    total = x.size(0)
    if shuffle:
        # Draw on the CPU generator (unchanged random stream), then move the
        # index next to each tensor: index_select needs a same-device index.
        order = torch.randperm(total)
        x = x.index_select(0, order.to(x.device))
        y = y.index_select(0, order.to(y.device))
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        xb = x[start:end].to(Config.DEVICE, non_blocking=Config.DEVICE.type == "cuda")
        yb = y[start:end].to(Config.DEVICE, non_blocking=Config.DEVICE.type == "cuda")
        yield xb, yb
|
|
2277
|
+
|
|
2278
|
+
|
|
2279
|
+
@torch.inference_mode()
def evaluate_split(
    model: nn.Module,
    x: torch.Tensor,
    y: torch.Tensor,
    batch_size: int,
    scale_search: bool = False,
) -> Tuple[float, float, float, int, float]:
    """Evaluate ``model`` on one split and report loss, accuracy and BER.

    On a CUDA out-of-memory error the whole pass is restarted with the batch
    size halved, down to ``Config.MIN_BLOCK_SIZE``.

    Args:
        model: Equalizer to evaluate; it is switched to eval mode.
        x: Windowed inputs.
        y: Target symbols aligned with ``x``.
        batch_size: Initial evaluation batch size.
        scale_search: For non-classifier outputs, search for the best symbol
            scale with ``find_best_symbol_scale`` before deciding symbols.

    Returns:
        ``(mean_loss, accuracy, ber, batch_size_used, scale)``. ``scale`` is
        1.0 for classifier outputs or when ``scale_search`` is false. An
        empty split returns zeros with ``scale`` 1.0.

    Raises:
        RuntimeError: Any non-OOM runtime error, or an OOM that persists at
            the minimum batch size.
    """
    model.eval()
    criterion = build_criterion()
    bit_labels = BIT_LABELS.to(Config.DEVICE)
    current_batch_size = batch_size

    # With no samples, torch.cat() below would fail on an empty list; report
    # the same neutral values the max(..., 1) guards are meant to produce.
    if x.size(0) == 0:
        return 0.0, 0.0, 0.0, current_batch_size, 1.0

    while True:
        # Reset on every attempt so a retry after OOM does not double count.
        total_loss = 0.0
        total_samples = 0
        preds_accum: List[torch.Tensor] = []
        targets_accum: List[torch.Tensor] = []
        try:
            for xb, yb in iter_tensor_batches(x, y, batch_size=current_batch_size, shuffle=False):
                with autocast_context():
                    mark_cudagraph_step_begin()
                    preds = model(xb)
                    loss = prediction_loss(preds, yb, criterion)
                # Keep per-batch outputs on the CPU to bound device memory.
                preds_accum.append(preds.float().detach().cpu())
                targets_accum.append(yb.float().detach().cpu())
                batch_size_now = yb.size(0)
                total_loss += loss.item() * batch_size_now
                total_samples += batch_size_now
            preds_all = torch.cat(preds_accum, dim=0)
            targets_all = torch.cat(targets_accum, dim=0)
            target_classes = symbols_to_classes(targets_all.to(Config.DEVICE))
            if is_classifier_output(preds_all):
                scale = 1.0
                pred_classes = torch.argmax(preds_all.to(Config.DEVICE), dim=1)
            else:
                scale = find_best_symbol_scale(targets_all, preds_all) if scale_search else 1.0
                pred_classes = symbols_to_classes((preds_all * scale).to(Config.DEVICE))
            correct = (pred_classes == target_classes).sum().item()
            tx_bits = bit_labels[target_classes]
            rx_bits = bit_labels[pred_classes]
            bit_errors = (tx_bits != rx_bits).sum().item()
            total_bits = tx_bits.numel()
            return (
                total_loss / max(total_samples, 1),
                correct / max(total_samples, 1),
                bit_errors / max(total_bits, 1),
                current_batch_size,
                scale,
            )
        except RuntimeError as error:
            if Config.DEVICE.type != "cuda" or not is_cuda_oom(error):
                raise
            if current_batch_size <= Config.MIN_BLOCK_SIZE:
                raise
            next_batch_size = max(current_batch_size // 2, Config.MIN_BLOCK_SIZE)
            log(f"eval | CUDA OOM at batch_size={current_batch_size}, retrying with {next_batch_size}")
            current_batch_size = next_batch_size
            torch.cuda.empty_cache()
|
|
2343
|
+
|
|
2344
|
+
|
|
2345
|
+
@torch.inference_mode()
def compute_split_file_metrics(
    model: nn.Module,
    data: Dict[str, Any],
    split: str,
    batch_size: int,
) -> Tuple[List[Dict[str, float]], int]:
    """Compute baseline and equalized BER separately for each file of a split.

    For every span in ``data[f"{split}_file_spans"]`` the un-equalized
    baseline BER is measured on the stored ``tx_center``/``rx_center``
    symbols (after a best-scale search), and the model is evaluated on the
    matching slice of ``data[f"{split}_x"]`` / ``data[f"{split}_y"]``.

    Returns:
        ``(rows, batch_size)``: one metrics dict per file, plus the batch
        size last returned by ``evaluate_split`` (the input ``batch_size``
        when the split has no spans).
    """
    spans = data.get(f"{split}_file_spans", [])
    if not spans:
        return [], batch_size

    inputs = data[f"{split}_x"]
    targets = data[f"{split}_y"]
    per_file: List[Dict[str, float]] = []
    active_batch_size = batch_size

    for span in spans:
        lo = int(span["start"])
        hi = int(span["end"])

        # Baseline: hard decisions on the raw received symbols.
        baseline_scale = find_best_symbol_scale(span["tx_center"], span["rx_center"])
        tx_cls = symbols_to_classes(span["tx_center"].to(Config.DEVICE))
        rx_cls = symbols_to_classes((span["rx_center"] * baseline_scale).to(Config.DEVICE))
        baseline_ber = calculate_ber_from_classes(tx_cls, rx_cls)

        # The batch size is carried across files so an OOM-driven reduction
        # is not retried for every file.
        loss, acc, ber, active_batch_size, equalizer_scale = evaluate_split(
            model,
            inputs[lo:hi],
            targets[lo:hi],
            active_batch_size,
            scale_search=True,
        )

        if baseline_ber > 0:
            relative_gain = (1 - ber / baseline_ber) * 100
        else:
            relative_gain = 0.0

        per_file.append(
            {
                "file_idx": int(span["file_idx"]),
                "samples": int(hi - lo),
                "baseline_ber": float(baseline_ber),
                "baseline_scale": float(baseline_scale),
                "equalized_ber": float(ber),
                "equalizer_scale": float(equalizer_scale),
                "accuracy": float(acc),
                "loss": float(loss),
                "improvement_rel": float(relative_gain),
            }
        )
    return per_file, active_batch_size
|
|
2388
|
+
|
|
2389
|
+
|
|
2390
|
+
def add_file_metric_summary(metrics: Dict[str, Any], split: str, file_metrics: List[Dict[str, float]]):
    """Add per-file BER aggregates for ``split`` to ``metrics`` in place.

    Writes the mean, (population) standard deviation and worst case of the
    equalized BER, the mean and worst case of the baseline BER, and two
    ``"file_idx:ber"`` strings joined with ``;``. Does nothing when
    ``file_metrics`` is empty.
    """
    if not file_metrics:
        return

    def _column(key: str) -> np.ndarray:
        return np.array([row[key] for row in file_metrics], dtype=np.float64)

    def _by_file(key: str) -> str:
        return ";".join(f"{int(row['file_idx'])}:{row[key]:.6e}" for row in file_metrics)

    equalized = _column("equalized_ber")
    baseline = _column("baseline_ber")
    prefix = f"{split}_file"
    metrics.update(
        {
            f"{prefix}_equalized_ber_mean": float(equalized.mean()),
            f"{prefix}_equalized_ber_std": float(equalized.std()),
            f"{prefix}_equalized_ber_worst": float(equalized.max()),
            f"{prefix}_baseline_ber_mean": float(baseline.mean()),
            f"{prefix}_baseline_ber_worst": float(baseline.max()),
            f"{prefix}_equalized_ber_by_file": _by_file("equalized_ber"),
            f"{prefix}_baseline_ber_by_file": _by_file("baseline_ber"),
        }
    )
|
|
2406
|
+
|
|
2407
|
+
|
|
2408
|
+
@torch.inference_mode()
def compute_test_metrics(model: nn.Module, data: Dict[str, Any]) -> Dict[str, Any]:
    """Evaluate ``model`` on the test split and compare it with the baseline.

    The baseline BER comes from hard decisions on the (best-scaled) received
    test symbols; the equalized BER comes from ``evaluate_split`` with scale
    search enabled. ``data["eval_batch_size"]`` overrides
    ``Config.EVAL_BATCH_SIZE`` when present.

    Returns:
        A dict with both BERs and scales, accuracy, SER, loss, the batch
        size that was actually usable, and absolute / relative / dB
        improvements.
    """
    eval_prefix = "test"
    baseline_scale = find_best_symbol_scale(data[f"tx_{eval_prefix}_center"], data[f"rx_{eval_prefix}_center"])
    tx_cls = symbols_to_classes(data[f"tx_{eval_prefix}_center"].to(Config.DEVICE))
    rx_cls = symbols_to_classes((data[f"rx_{eval_prefix}_center"] * baseline_scale).to(Config.DEVICE))
    baseline_ber = calculate_ber_from_classes(tx_cls, rx_cls)
    eval_batch_size = data.get("eval_batch_size", Config.EVAL_BATCH_SIZE)
    test_loss, test_acc, test_ber, safe_eval_batch_size, equalizer_scale = evaluate_split(
        model, data[f"{eval_prefix}_x"], data[f"{eval_prefix}_y"], eval_batch_size, scale_search=True
    )

    # Spell out the degenerate cases instead of feeding zeros to log10.
    if test_ber > 0 and baseline_ber > 0:
        improvement_db = 10 * np.log10(baseline_ber / test_ber)
    elif test_ber > 0:
        # Error-free baseline made worse by the equalizer: log10(0).
        improvement_db = float("-inf")
    elif baseline_ber > 0:
        improvement_db = float("inf")
    else:
        # Both error-free: no gain, matching improvement_abs/improvement_rel.
        improvement_db = 0.0

    return {
        "eval_split": eval_prefix,
        "baseline_ber": baseline_ber,
        "baseline_scale": baseline_scale,
        "equalized_ber": test_ber,
        "equalizer_scale": equalizer_scale,
        "accuracy": test_acc,
        "ser": 1.0 - test_acc,
        "test_loss": test_loss,
        "safe_eval_batch_size": safe_eval_batch_size,
        "improvement_abs": baseline_ber - test_ber,
        "improvement_rel": (1 - test_ber / baseline_ber) * 100 if baseline_ber > 0 else 0.0,
        "improvement_db": improvement_db,
    }
|
|
2433
|
+
|
|
2434
|
+
|
|
2435
|
+
def build_optimizer_with_lr(model: nn.Module, lr: float) -> optim.Optimizer:
    """Create an Adam optimizer for ``model`` with the given learning rate.

    Weight decay comes from ``Config.WEIGHT_DECAY``. On CUDA the fused
    implementation is requested first; if this PyTorch build or these
    parameters do not support it, the standard implementation is used.
    """
    optimizer_kwargs = {"lr": lr, "weight_decay": Config.WEIGHT_DECAY}
    if Config.DEVICE.type != "cuda":
        return optim.Adam(model.parameters(), **optimizer_kwargs)
    try:
        return optim.Adam(model.parameters(), fused=True, **optimizer_kwargs)
    except (TypeError, RuntimeError):
        # TypeError: the `fused` keyword is unknown to this PyTorch version.
        # RuntimeError: `fused=True` rejected for these parameters (e.g. not
        # floating point or not on a supported device).
        return optim.Adam(model.parameters(), **optimizer_kwargs)
|
|
2444
|
+
|
|
2445
|
+
|
|
2446
|
+
def get_efficient_kan_module(model: nn.Module) -> Optional[nn.Module]:
    """Return ``model.kan`` if it looks like a prunable EfficientKAN, else ``None``.

    The module qualifies when it has a non-empty ``layers`` collection and
    every layer exposes both ``base_weight`` and ``spline_weight``.
    """
    kan = getattr(model, "kan", None)
    if kan is None:
        return None
    if not hasattr(kan, "layers"):
        return None

    layers = list(kan.layers)
    if len(layers) == 0:
        return None
    for layer in layers:
        if not (hasattr(layer, "base_weight") and hasattr(layer, "spline_weight")):
            return None
    return kan
|
|
2454
|
+
|
|
2455
|
+
|
|
2456
|
+
def efficient_kan_layer_sizes(kan: nn.Module) -> List[int]:
    """Return the layer widths of ``kan`` as ``[in, hidden..., out]``.

    The first entry is the first layer's ``in_features``; every following
    entry is a layer's ``out_features``. An empty network yields ``[]``.
    """
    layers = list(kan.layers)
    if len(layers) == 0:
        return []
    sizes = [int(layers[0].in_features)]
    sizes.extend(int(layer.out_features) for layer in layers)
    return sizes
|
|
2461
|
+
|
|
2462
|
+
|
|
2463
|
+
@torch.no_grad()
def efficient_kan_edge_norm(layer: nn.Module) -> torch.Tensor:
    """Score every edge of a KAN layer by the magnitude of its weights.

    The score is ``|base_weight|`` plus the mean of ``|scaled_spline_weight|``
    taken over dim 2, both computed in float32.
    """
    base_magnitude = torch.abs(layer.base_weight.detach().float())
    spline_magnitude = torch.abs(layer.scaled_spline_weight.detach().float())
    return base_magnitude + spline_magnitude.mean(dim=2)
|
|
2468
|
+
|
|
2469
|
+
|
|
2470
|
+
@torch.no_grad()
def select_efficient_kan_hidden_units(kan: nn.Module, keep_ratio: float) -> List[torch.Tensor]:
    """Choose which hidden units of each hidden layer survive pruning.

    A unit's importance is the geometric mean of two averaged edge norms:
    the producing layer's norms reduced over dim 1 and the consuming layer's
    norms reduced over dim 0 (both clamped to at least 1e-12).

    Args:
        kan: Network whose ``layers`` are scored.
        keep_ratio: Fraction of units to keep per hidden layer; the count is
            raised to ``Config.KAN_STRUCTURAL_PRUNE_MIN_HIDDEN`` if smaller
            and capped at the layer width.

    Returns:
        One sorted CPU index tensor per hidden layer.
    """
    layers = list(kan.layers)
    kept: List[torch.Tensor] = []
    for producer, consumer in zip(layers[:-1], layers[1:]):
        width = int(producer.out_features)
        target = int(round(width * keep_ratio))
        target = max(Config.KAN_STRUCTURAL_PRUNE_MIN_HIDDEN, target)
        target = min(width, max(1, target))

        # NOTE(review): assumes edge norms are laid out (out, in) -- confirm.
        fan_in_strength = efficient_kan_edge_norm(producer).mean(dim=1)
        fan_out_strength = efficient_kan_edge_norm(consumer).mean(dim=0)
        score = torch.sqrt(fan_in_strength.clamp_min(1e-12) * fan_out_strength.clamp_min(1e-12))

        top = torch.topk(score, k=target, largest=True, sorted=False).indices
        # Sort so the surviving units keep their original relative order.
        kept.append(torch.sort(top.cpu()).values)
    return kept
|
|
2487
|
+
|
|
2488
|
+
|
|
2489
|
+
@torch.no_grad()
def copy_pruned_kan_layer(old_layer: nn.Module, new_layer: nn.Module, input_idx: torch.Tensor, output_idx: torch.Tensor):
    """Copy the selected rows/columns of ``old_layer`` into ``new_layer``.

    ``grid`` is sliced along dim 0 by ``input_idx``; ``base_weight``,
    ``spline_weight`` and (when both layers have it) ``spline_scaler`` are
    sliced along dim 0 by ``output_idx`` and dim 1 by ``input_idx``.
    ``new_layer`` is modified in place.
    """
    source_device = old_layer.base_weight.device
    cols = input_idx.to(source_device)
    rows = output_idx.to(source_device)

    def _crop(tensor: torch.Tensor) -> torch.Tensor:
        return tensor.index_select(0, rows).index_select(1, cols)

    new_layer.grid.copy_(old_layer.grid.index_select(0, cols).to(new_layer.grid.device))
    new_layer.base_weight.copy_(_crop(old_layer.base_weight).to(new_layer.base_weight.device))
    new_layer.spline_weight.copy_(_crop(old_layer.spline_weight).to(new_layer.spline_weight.device))

    both_have_scaler = hasattr(old_layer, "spline_scaler") and hasattr(new_layer, "spline_scaler")
    if both_have_scaler:
        new_layer.spline_scaler.copy_(_crop(old_layer.spline_scaler).to(new_layer.spline_scaler.device))
|
|
2508
|
+
|
|
2509
|
+
|
|
2510
|
+
def structurally_prune_efficient_kan_model(model: nn.Module, keep_ratio: float) -> Tuple[nn.Module, List[int]]:
    """Replace ``model.kan`` with a narrower EfficientKAN built from its strongest units.

    Input and output widths are kept; each hidden layer is reduced to the
    units chosen by ``select_efficient_kan_hidden_units``. The surviving
    weights are copied into a freshly constructed network, which is then
    assigned to ``model.kan`` (``model`` is modified in place).

    Returns:
        ``(model, new_layer_sizes)``.

    Raises:
        ImportError: If the ``EfficientKAN`` class could not be imported.
        ValueError: If ``model`` has no prunable ``.kan`` module.
    """
    if EfficientKAN is None:
        raise ImportError("EfficientKAN is unavailable.")
    source_kan = get_efficient_kan_module(model)
    if source_kan is None:
        raise ValueError("Model does not expose a prunable EfficientKAN module via `.kan`.")

    source_layers = list(source_kan.layers)
    source_sizes = efficient_kan_layer_sizes(source_kan)
    hidden_keep = select_efficient_kan_hidden_units(source_kan, keep_ratio)

    # kept[i] lists the surviving units at the boundary between layer i-1
    # and layer i; the network's inputs and outputs are never pruned.
    kept: List[torch.Tensor] = [torch.arange(source_sizes[0], dtype=torch.long)]
    kept.extend(hidden_keep)
    kept.append(torch.arange(source_sizes[-1], dtype=torch.long))
    new_sizes = [int(indices.numel()) for indices in kept]

    # NOTE(review): grid size and spline order come from the old network,
    # the remaining hyper-parameters from Config -- confirm they match the
    # values the original network was built with.
    pruned_kan = EfficientKAN(
        layers_hidden=new_sizes,
        grid_size=int(source_kan.grid_size),
        spline_order=int(source_kan.spline_order),
        scale_noise=Config.EFFICIENT_KAN_SCALE_NOISE,
        scale_base=Config.EFFICIENT_KAN_SCALE_BASE,
        scale_spline=Config.EFFICIENT_KAN_SCALE_SPLINE,
        base_activation=nn.SiLU,
        grid_eps=Config.EFFICIENT_KAN_GRID_EPS,
        grid_range=Config.EFFICIENT_KAN_GRID_RANGE,
    ).to(Config.DEVICE)

    layer_pairs = zip(source_layers, pruned_kan.layers)
    for position, (source_layer, target_layer) in enumerate(layer_pairs):
        copy_pruned_kan_layer(source_layer, target_layer, kept[position], kept[position + 1])

    model.kan = pruned_kan
    return model, new_sizes
|
|
2541
|
+
|
|
2542
|
+
|
|
2543
|
+
def efficiency_score_from_metrics(metrics: Dict[str, Any]) -> float:
    """Score a model by relative BER gain per second of batch inference.

    The score is ``gain ** Config.EFFICIENCY_SCORE_POWER / batch_time``,
    where ``gain`` is the relative BER reduction clipped at zero. Missing
    entries count as 0; a non-positive baseline BER or batch time gives 0.0.
    """
    baseline_ber, equalized_ber, seconds = (
        float(metrics.get(key, 0.0))
        for key in ("baseline_ber", "equalized_ber", "efficiency_batch_time_sec")
    )
    if baseline_ber <= 0 or seconds <= 0:
        return 0.0
    relative_gain = max((baseline_ber - equalized_ber) / baseline_ber, 0.0)
    score = relative_gain**Config.EFFICIENCY_SCORE_POWER / seconds
    return float(score)
|
|
2551
|
+
|
|
2552
|
+
|
|
2553
|
+
@torch.inference_mode()
def measure_batch_inference_time(model: nn.Module, x: torch.Tensor) -> float:
    """Return the mean wall-clock time of one forward pass on a fixed batch.

    The batch is the first ``min(Config.EFFICIENCY_BATCH_SIZE, len(x))`` rows
    of ``x``. ``Config.EFFICIENCY_TIMING_WARMUP`` untimed passes run first,
    then ``Config.EFFICIENCY_TIMING_REPEATS`` timed passes. On CUDA the
    device is synchronised before the clock starts and before it stops.
    Returns 0.0 when there is nothing to run.
    """
    model.eval()
    batch_size = min(Config.EFFICIENCY_BATCH_SIZE, x.size(0))
    if batch_size <= 0:
        return 0.0
    xb = x[:batch_size].to(Config.DEVICE, non_blocking=Config.DEVICE.type == "cuda")

    def _run(passes: int) -> None:
        for _ in range(passes):
            with autocast_context():
                mark_cudagraph_step_begin()
                model(xb)
        # Wait for queued kernels so they are charged to the right phase.
        if Config.DEVICE.type == "cuda":
            torch.cuda.synchronize()

    _run(Config.EFFICIENCY_TIMING_WARMUP)
    started = time.perf_counter()
    _run(Config.EFFICIENCY_TIMING_REPEATS)
    elapsed = time.perf_counter() - started
    return elapsed / max(Config.EFFICIENCY_TIMING_REPEATS, 1)
|
|
2575
|
+
|
|
2576
|
+
|
|
2577
|
+
def add_efficiency_metrics(model: nn.Module, data: Dict[str, Any], metrics: Dict[str, Any]) -> Dict[str, Any]:
    """Time ``model`` on the test inputs and add efficiency entries to ``metrics``.

    Adds ``efficiency_batch_size``, ``efficiency_batch_time_sec`` and
    ``efficiency_score`` in place and returns the same dict.
    """
    test_inputs = data["test_x"]
    seconds = measure_batch_inference_time(model, test_inputs)
    metrics.update(
        {
            "efficiency_batch_size": int(min(Config.EFFICIENCY_BATCH_SIZE, test_inputs.size(0))),
            "efficiency_batch_time_sec": float(seconds),
        }
    )
    # The score reads the batch time stored just above.
    metrics["efficiency_score"] = efficiency_score_from_metrics(metrics)
    return metrics
|
|
2583
|
+
|
|
2584
|
+
|
|
2585
|
+
def fine_tune_model_for_pruning(
    model: nn.Module,
    data: Dict[str, Any],
    epochs: int,
    lr: float,
    eval_batch_size: int,
) -> Tuple[nn.Module, Dict[str, float], int]:
    """Fine-tune a pruned model and restore the weights with the best validation BER.

    Each epoch visits the training data in blocks of
    ``Config.TRAIN_BLOCK_SIZE`` rows (block order shuffled, one optimizer
    step per block), then validates with scale search. With ``epochs <= 0``
    the model is only validated.

    Returns:
        ``(model, best_val_metrics, eval_batch_size)``; the batch size is the
        one last returned by ``evaluate_split``.
    """

    def _validate(batch_size: int) -> Tuple[Dict[str, float], int]:
        loss, acc, ber, batch_size, scale = evaluate_split(
            model, data["val_x"], data["val_y"], batch_size, scale_search=True
        )
        return {"val_loss": loss, "val_acc": acc, "val_ber": ber, "val_scale": scale}, batch_size

    def _snapshot() -> Dict[str, torch.Tensor]:
        # CPU copies so the checkpoint does not occupy device memory.
        return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    if epochs <= 0:
        val_metrics, eval_batch_size = _validate(eval_batch_size)
        return model, val_metrics, eval_batch_size

    optimizer = build_optimizer_with_lr(model, lr)
    criterion = build_criterion()
    on_cuda = Config.DEVICE.type == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=on_cuda and Config.USE_AMP)
    best_state = _snapshot()
    best_metrics = {"val_loss": float("inf"), "val_acc": 0.0, "val_ber": float("inf"), "val_scale": 1.0}
    block_size = Config.TRAIN_BLOCK_SIZE

    for epoch_number in range(1, epochs + 1):
        model.train()
        loss_sum = 0.0
        rows_seen = 0
        n_train = data["train_x"].size(0)
        n_blocks = (n_train + block_size - 1) // block_size  # ceil division
        for block in torch.randperm(n_blocks).tolist():
            lo = block * block_size
            hi = min(lo + block_size, n_train)
            xb = data["train_x"][lo:hi].to(Config.DEVICE, non_blocking=Config.DEVICE.type == "cuda")
            yb = data["train_y"][lo:hi].to(Config.DEVICE, non_blocking=Config.DEVICE.type == "cuda")
            optimizer.zero_grad(set_to_none=True)
            with autocast_context():
                mark_cudagraph_step_begin()
                preds = model(xb)
                loss = prediction_loss(preds, yb, criterion) + compute_model_regularization(model)
            scaler.scale(loss).backward()
            if Config.GRAD_CLIP_NORM > 0:
                # Gradients must be unscaled before their norm is measured.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), Config.GRAD_CLIP_NORM)
            scaler.step(optimizer)
            scaler.update()
            loss_sum += loss.item() * yb.size(0)
            rows_seen += yb.size(0)

        val_metrics, eval_batch_size = _validate(eval_batch_size)
        if val_metrics["val_ber"] < best_metrics["val_ber"]:
            best_metrics = val_metrics
            best_state = _snapshot()
        log(
            f"prune fine-tune | epoch {epoch_number:3d}/{epochs} | "
            f"train {loss_sum / max(rows_seen, 1):.6f} | val_ber {val_metrics['val_ber']:.6e}"
        )

    model.load_state_dict(best_state)
    return model, best_metrics, eval_batch_size
|
|
2643
|
+
|
|
2644
|
+
|
|
2645
|
+
def should_replace_by_pruned(candidate_metrics: Dict[str, Any], best_metrics: Dict[str, Any]) -> bool:
    """Decide whether a pruning candidate beats the current best model.

    The criterion is ``Config.KAN_STRUCTURAL_PRUNE_SELECT_BY``:

    * ``"val_ber"``: strictly lower validation BER (the incumbent falls back
      to ``best_val_ber`` when it has no ``prune_val_ber``).
    * ``"test_ber"``: strictly lower test BER.
    * anything else: strictly higher ``efficiency_score``.
    """
    criterion = Config.KAN_STRUCTURAL_PRUNE_SELECT_BY

    if criterion == "val_ber":
        challenger = float(candidate_metrics.get("prune_val_ber", float("inf")))
        incumbent = float(best_metrics.get("prune_val_ber", best_metrics.get("best_val_ber", float("inf"))))
        return challenger < incumbent

    if criterion == "test_ber":
        # NOTE(review): this mode selects the model on the test split.
        return float(candidate_metrics["equalized_ber"]) < float(best_metrics["equalized_ber"])

    challenger_score = float(candidate_metrics.get("efficiency_score", 0.0))
    incumbent_score = float(best_metrics.get("efficiency_score", 0.0))
    return challenger_score > incumbent_score
|
|
2654
|
+
|
|
2655
|
+
|
|
2656
|
+
def maybe_prune_efficient_kan_model(
    model: nn.Module,
    model_name: str,
    data: Dict[str, Any],
    base_metrics: Dict[str, Any],
    eval_batch_size: int,
) -> Tuple[nn.Module, Dict[str, Any], int]:
    """Try structurally pruned variants of ``model`` and keep the best one.

    Does nothing unless ``Config.KAN_STRUCTURAL_PRUNE_AFTER_TRAINING`` is set
    and the model exposes a prunable ``.kan``. Otherwise, for every ratio in
    ``Config.KAN_STRUCTURAL_PRUNE_KEEP_RATIOS`` a deep copy is pruned,
    fine-tuned, evaluated on the test split and timed. All candidates
    (including the unpruned model) are written to
    ``<model_name>_pruning_candidates.csv`` in ``Config.OUT_DIR``.

    Returns:
        ``(best_model, best_metrics, eval_batch_size)``.
    """
    if not Config.KAN_STRUCTURAL_PRUNE_AFTER_TRAINING:
        return model, base_metrics, eval_batch_size
    unpruned_kan = get_efficient_kan_module(model)
    if unpruned_kan is None:
        log(f"{model_name} | structural KAN pruning skipped: model has no prunable EfficientKAN `.kan`")
        return model, base_metrics, eval_batch_size

    # The unpruned model is the first entry of the comparison table.
    base_metrics = add_efficiency_metrics(model, data, base_metrics)
    base_metrics["pruned"] = False
    base_metrics["prune_keep_ratio"] = 1.0
    base_metrics["prune_layer_sizes"] = str(efficient_kan_layer_sizes(unpruned_kan))
    base_metrics["prune_val_ber"] = base_metrics.get("best_val_ber", float("inf"))

    winner = model
    winner_metrics = dict(base_metrics)
    table: List[Dict[str, Any]] = [dict(base_metrics, model_type=model_name, prune_candidate="unpruned")]

    for keep_ratio in Config.KAN_STRUCTURAL_PRUNE_KEEP_RATIOS:
        candidate = copy.deepcopy(model).to(Config.DEVICE)
        try:
            candidate, layer_sizes = structurally_prune_efficient_kan_model(candidate, keep_ratio)
        except Exception as exc:
            # Best effort: a ratio that cannot be applied is skipped.
            log(f"{model_name} | prune keep_ratio={keep_ratio:.2f} skipped: {exc}")
            continue

        candidate_params = count_trainable_parameters(candidate)
        log(
            f"{model_name} | prune candidate keep_ratio={keep_ratio:.2f} | "
            f"layers {layer_sizes} | params {candidate_params:,}"
        )
        candidate, val_metrics, eval_batch_size = fine_tune_model_for_pruning(
            candidate,
            data,
            Config.KAN_STRUCTURAL_PRUNE_FINE_TUNE_EPOCHS,
            Config.KAN_STRUCTURAL_PRUNE_FINE_TUNE_LR,
            eval_batch_size,
        )

        candidate_metrics = compute_test_metrics(candidate, {**data, "eval_batch_size": eval_batch_size})
        candidate_metrics.update(
            {
                "trainable_params": candidate_params,
                "pruned": True,
                "prune_keep_ratio": float(keep_ratio),
                "prune_layer_sizes": str(layer_sizes),
                "prune_val_loss": float(val_metrics["val_loss"]),
                "prune_val_acc": float(val_metrics["val_acc"]),
                "prune_val_ber": float(val_metrics["val_ber"]),
            }
        )
        candidate_metrics = add_efficiency_metrics(candidate, data, candidate_metrics)
        table.append(dict(candidate_metrics, model_type=model_name, prune_candidate=f"keep_{keep_ratio:.2f}"))
        log(
            f"{model_name} | prune keep_ratio={keep_ratio:.2f} | "
            f"test_ber {candidate_metrics['equalized_ber']:.6e} | "
            f"batch16k {candidate_metrics['efficiency_batch_time_sec']:.6f}s | "
            f"score {candidate_metrics['efficiency_score']:.3f}"
        )

        if should_replace_by_pruned(candidate_metrics, winner_metrics):
            winner = candidate
            winner_metrics = dict(candidate_metrics)
        elif Config.DEVICE.type == "cuda":
            # Rejected candidate: release its device memory right away.
            candidate.to("cpu")
            del candidate
            torch.cuda.empty_cache()

    prune_path = Config.OUT_DIR / f"{model_name}_pruning_candidates.csv"
    pd.DataFrame(table).to_csv(prune_path, index=False)
    log(
        f"{model_name} | pruning selected keep_ratio={winner_metrics.get('prune_keep_ratio', 1.0)} | "
        f"test_ber {winner_metrics['equalized_ber']:.6e} | "
        f"score {winner_metrics.get('efficiency_score', 0.0):.3f} | saved {prune_path}"
    )
    return winner, winner_metrics, eval_batch_size
|
|
2730
|
+
|
|
2731
|
+
|
|
2732
|
+
def plot_results(history: Dict, eval_results: Dict, model_name: str):
    """Save a 2x3 dashboard of training curves and BER results for one model.

    Panels: loss curves, accuracy, learning rate, baseline vs equalized BER,
    improvement metrics, and a combined training-dynamics view. The image is
    written to ``Config.OUT_DIR / f"ber_results_{model_name.lower()}.png"``.

    Args:
        history: Per-epoch series (``train_loss``, ``val_loss``, ``val_acc``,
            ``lr``, ``val_epochs``, ``test_epochs`` and the optional
            ``test_*`` series).
        eval_results: Final metrics; reads ``baseline_ber``,
            ``equalized_ber``, ``improvement_abs``, ``improvement_rel`` and
            ``improvement_db``.
        model_name: Label used in titles and in the output file name.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    # try/finally: the figure is closed even if plotting or saving fails.
    try:
        train_epochs = list(range(1, len(history["train_loss"]) + 1))
        val_epochs = history["val_epochs"]
        test_epochs = history["test_epochs"]

        axes[0, 0].plot(train_epochs, history["train_loss"], label="Train", linewidth=2, alpha=0.8)
        axes[0, 0].plot(val_epochs, history["val_loss"], label="Val", linewidth=2)
        if history["test_loss"]:
            axes[0, 0].plot(test_epochs, history["test_loss"], label="Test", linewidth=2, linestyle="--")
        axes[0, 0].set_title(f"{model_name} - Loss Curves", fontweight="bold")
        axes[0, 0].set_xlabel("Epoch")
        axes[0, 0].set_ylabel("Loss")
        axes[0, 0].set_yscale("log")
        axes[0, 0].grid(alpha=0.3)
        axes[0, 0].legend()

        axes[0, 1].plot(val_epochs, history["val_acc"], color="green", linewidth=2, label="Val Acc")
        if history["test_acc"]:
            axes[0, 1].plot(test_epochs, history["test_acc"], color="#0984e3", linewidth=2, linestyle="--", label="Test Acc")
        axes[0, 1].set_title(f"{model_name} - Accuracy", fontweight="bold")
        axes[0, 1].set_xlabel("Epoch")
        axes[0, 1].set_ylabel("Accuracy")
        axes[0, 1].set_ylim([0, 1])
        axes[0, 1].grid(alpha=0.3)
        axes[0, 1].legend()

        axes[0, 2].plot(train_epochs, history["lr"], color="red", linewidth=2)
        axes[0, 2].set_title("Learning Rate Schedule", fontweight="bold")
        axes[0, 2].set_xlabel("Epoch")
        axes[0, 2].set_ylabel("LR")
        axes[0, 2].set_yscale("log")
        axes[0, 2].grid(alpha=0.3)

        axes[1, 0].bar(
            ["Baseline", f"{model_name} EQ"],
            [eval_results["baseline_ber"], eval_results["equalized_ber"]],
            color=["#ff7675", "#55efc4"],
            edgecolor="black",
            linewidth=1.5,
        )
        axes[1, 0].set_title("BER Comparison (log scale)", fontweight="bold")
        axes[1, 0].set_ylabel("BER")
        axes[1, 0].set_yscale("log")
        axes[1, 0].grid(axis="y", alpha=0.3)

        metrics = ["Abs Reduction\n(pp)", "Rel Improvement\n(%)", "SNR Gain\n(dB)"]
        # improvement_db may be +/-inf when one of the BERs is zero; clamp it
        # on both sides so the bar height is always finite.
        gain_db = max(min(eval_results["improvement_db"], 20), -20)
        values = [
            eval_results["improvement_abs"] * 100,
            eval_results["improvement_rel"],
            gain_db,
        ]
        axes[1, 1].bar(metrics, values, color=["#74b9ff", "#a29bfe", "#fd79a8"], edgecolor="black", linewidth=1.5)
        axes[1, 1].set_title("Improvement Metrics", fontweight="bold")
        axes[1, 1].set_ylabel("Value")
        axes[1, 1].grid(axis="y", alpha=0.3)

        axes[1, 2].plot(train_epochs, [loss * 100 for loss in history["train_loss"]], label="Train Loss x100", alpha=0.7, linewidth=1)
        axes[1, 2].plot(val_epochs, [acc * 100 for acc in history["val_acc"]], label="Val Acc (%)", alpha=0.7, linewidth=2)
        if history["test_ber"]:
            axes[1, 2].plot(test_epochs, [ber * 100 for ber in history["test_ber"]], label="Test BER (%)", alpha=0.9, linewidth=2)
        axes[1, 2].set_title("Training Dynamics", fontweight="bold")
        axes[1, 2].set_xlabel("Epoch")
        axes[1, 2].set_ylabel("Value")
        axes[1, 2].grid(alpha=0.3)
        axes[1, 2].legend()

        # Address the figure directly rather than pyplot's current figure.
        fig.suptitle(f"{model_name} Equalizer Performance", fontsize=16, fontweight="bold")
        fig.tight_layout()
        out_path = Config.OUT_DIR / f"ber_results_{model_name.lower()}.png"
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
|
|
2804
|
+
|
|
2805
|
+
|
|
2806
|
+
def plot_architecture_summary(results: List[Dict[str, float]]):
    """Plot and export a side-by-side comparison of the trained architectures.

    The result rows are ranked by ascending ``equalized_ber`` and drawn as two
    bar charts: equalized BER on a log axis, and relative BER improvement.
    The figure is saved as ``architecture_summary.png`` and the ranked table
    as ``architecture_comparison.csv``, both under ``Config.OUT_DIR``.

    Args:
        results: One mapping per model. Each must provide the ``model_type``,
            ``equalized_ber`` and ``improvement_rel`` keys read below.
    """
    ranked = pd.DataFrame(results)
    ranked = ranked.sort_values("equalized_ber").reset_index(drop=True)
    labels = ranked["model_type"]

    fig, (ber_ax, gain_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: equalized BER per architecture, lowest BER first.
    ber_ax.bar(labels, ranked["equalized_ber"], color="#00b894", edgecolor="black")
    ber_ax.set_title("Equalized BER by Architecture", fontweight="bold")
    ber_ax.set_ylabel("BER")
    ber_ax.set_yscale("log")
    ber_ax.grid(axis="y", alpha=0.3)

    # Right panel: relative improvement, same model ordering as the left panel.
    gain_ax.bar(labels, ranked["improvement_rel"], color="#0984e3", edgecolor="black")
    gain_ax.set_title("Relative BER Improvement", fontweight="bold")
    gain_ax.set_ylabel("Improvement (%)")
    gain_ax.grid(axis="y", alpha=0.3)

    plt.tight_layout()
    figure_path = Config.OUT_DIR / "architecture_summary.png"
    plt.savefig(figure_path, dpi=150, bbox_inches="tight")
    plt.close(fig)

    table_path = Config.OUT_DIR / "architecture_comparison.csv"
    ranked.to_csv(table_path, index=False)
|
|
2827
|
+
|
|
2828
|
+
|
|
2829
|
+
def plot_sweep_per_model(df: pd.DataFrame, x_col: str, x_label: str, filename_prefix: str):
    """Save one BER-vs-parameter line chart per entry of ``Config.MODEL_TYPES``.

    Args:
        df: Sweep results with ``model_type``, ``equalized_ber`` and
            ``x_col`` columns.
        x_col: Column holding the swept parameter (x axis).
        x_label: Axis label, also embedded in the plot title.
        filename_prefix: Each figure is written to
            ``Config.OUT_DIR / f"{filename_prefix}_{model_name}.png"``.
    """
    for model_name in Config.MODEL_TYPES:
        belongs_to_model = df["model_type"] == model_name
        curve = df[belongs_to_model].sort_values(x_col)

        fig, ax = plt.subplots(figsize=(8, 5))
        ax.plot(curve[x_col], curve["equalized_ber"], marker="o", linewidth=2)
        ax.set_title(f"{model_name.upper()} - BER vs {x_label}", fontweight="bold")
        ax.set_xlabel(x_label)
        ax.set_ylabel("BER")
        ax.set_yscale("log")
        ax.grid(alpha=0.3)

        plt.tight_layout()
        target = Config.OUT_DIR / f"{filename_prefix}_{model_name}.png"
        plt.savefig(target, dpi=150, bbox_inches="tight")
        plt.close(fig)
|
|
2842
|
+
|
|
2843
|
+
|
|
2844
|
+
def plot_sweep_overlay(df: pd.DataFrame, x_col: str, x_label: str, filename: str):
    """Overlay the BER curves of every ``Config.MODEL_TYPES`` entry on one axis.

    Args:
        df: Sweep results with ``model_type``, ``equalized_ber`` and
            ``x_col`` columns.
        x_col: Column holding the swept parameter (x axis).
        x_label: Axis label, also embedded in the plot title.
        filename: Output file name, resolved under ``Config.OUT_DIR``.
    """
    fig, ax = plt.subplots(figsize=(9, 6))

    for model_name in Config.MODEL_TYPES:
        belongs_to_model = df["model_type"] == model_name
        curve = df[belongs_to_model].sort_values(x_col)
        ax.plot(
            curve[x_col],
            curve["equalized_ber"],
            marker="o",
            linewidth=2,
            label=model_name.upper(),
        )

    ax.set_title(f"BER vs {x_label} - All Models", fontweight="bold")
    ax.set_xlabel(x_label)
    ax.set_ylabel("BER")
    ax.set_yscale("log")
    ax.grid(alpha=0.3)
    ax.legend()

    plt.tight_layout()
    target = Config.OUT_DIR / filename
    plt.savefig(target, dpi=150, bbox_inches="tight")
    plt.close(fig)
|
|
2858
|
+
|
|
2859
|
+
|
|
2860
|
+
def plot_efficient_kan_sweep(df: pd.DataFrame, x_col: str, x_label: str, filename: str, title: str):
    """Draw BER-vs-parameter curves for the EfficientKAN sweep models.

    One line is drawn per entry of ``Config.EFFICIENT_KAN_SWEEP_MODELS``;
    models with no rows in ``df`` are left out of the figure.

    Args:
        df: Sweep results with ``model_type``, ``equalized_ber`` and
            ``x_col`` columns.
        x_col: Column holding the swept parameter (x axis).
        x_label: Label for the x axis.
        filename: Output file name, resolved under ``Config.OUT_DIR``.
        title: Plot title.
    """
    fig, ax = plt.subplots(figsize=(9, 6))

    for model_name in Config.EFFICIENT_KAN_SWEEP_MODELS:
        belongs_to_model = df["model_type"] == model_name
        curve = df[belongs_to_model].sort_values(x_col)
        if not curve.empty:
            ax.plot(
                curve[x_col],
                curve["equalized_ber"],
                marker="o",
                linewidth=2,
                label=model_name,
            )

    ax.set_title(title, fontweight="bold")
    ax.set_xlabel(x_label)
    ax.set_ylabel("BER")
    ax.set_yscale("log")
    ax.grid(alpha=0.3)
    ax.legend()

    plt.tight_layout()
    target = Config.OUT_DIR / filename
    plt.savefig(target, dpi=150, bbox_inches="tight")
    plt.close(fig)
|
|
2876
|
+
|
|
2877
|
+
|
|
2878
|
+
def plot_efficient_kan_tradeoff(df: pd.DataFrame, x_col: str, x_label: str, filename: str, title: str):
    """Scatter equalized BER against a cost metric for the EfficientKAN models.

    One point cloud is drawn per entry of
    ``Config.EFFICIENT_KAN_SWEEP_MODELS``; models with no rows in ``df`` are
    left out. Rows are plotted in their existing order (no sorting).

    Args:
        df: Sweep results with ``model_type``, ``equalized_ber`` and
            ``x_col`` columns.
        x_col: Column used for the x axis.
        x_label: Label for the x axis.
        filename: Output file name, resolved under ``Config.OUT_DIR``.
        title: Plot title.
    """
    fig, ax = plt.subplots(figsize=(9, 6))

    for model_name in Config.EFFICIENT_KAN_SWEEP_MODELS:
        points = df[df["model_type"] == model_name]
        if not points.empty:
            ax.scatter(points[x_col], points["equalized_ber"], s=60, label=model_name)

    ax.set_title(title, fontweight="bold")
    ax.set_xlabel(x_label)
    ax.set_ylabel("BER")
    ax.set_yscale("log")
    ax.grid(alpha=0.3)
    ax.legend()

    plt.tight_layout()
    target = Config.OUT_DIR / filename
    plt.savefig(target, dpi=150, bbox_inches="tight")
    plt.close(fig)
|
|
2894
|
+
|
|
2895
|
+
|
|
2896
|
+
def plot_experiment_lines(
    df: pd.DataFrame,
    x_col: str,
    x_label: str,
    filename: str,
    title: str,
    y_col: str = "equalized_ber",
):
    """Draw one line per ``model_type`` group found in ``df``.

    Each group is sorted by ``x_col`` before plotting so the line runs left
    to right. The y axis is logarithmic and always labelled "BER", whichever
    ``y_col`` is selected.

    Args:
        df: Experiment results with ``model_type``, ``x_col`` and ``y_col``
            columns.
        x_col: Column used for the x axis.
        x_label: Label for the x axis.
        filename: Output file name, resolved under ``Config.OUT_DIR``.
        title: Plot title.
        y_col: Column used for the y axis.
    """
    fig, ax = plt.subplots(figsize=(9, 6))

    for model_name, group in df.groupby("model_type"):
        ordered = group.sort_values(x_col)
        ax.plot(
            ordered[x_col],
            ordered[y_col],
            marker="o",
            linewidth=2,
            label=model_name,
        )

    ax.set_title(title, fontweight="bold")
    ax.set_xlabel(x_label)
    ax.set_ylabel("BER")
    ax.set_yscale("log")
    ax.grid(alpha=0.3)
    ax.legend()

    plt.tight_layout()
    target = Config.OUT_DIR / filename
    plt.savefig(target, dpi=150, bbox_inches="tight")
    plt.close(fig)
|
|
2917
|
+
|
|
2918
|
+
|
|
2919
|
+
def plot_experiment_complexity(df: pd.DataFrame, filename: str, title: str):
    """Plot equalized BER against trainable-parameter count, one line per model.

    Both axes are logarithmic. The figure is written to
    ``Config.OUT_DIR / filename``.

    Args:
        df: Experiment results with ``model_type``, ``trainable_params`` and
            ``equalized_ber`` columns.
        filename: Output file name, resolved under ``Config.OUT_DIR``.
        title: Plot title.
    """
    fig, ax = plt.subplots(figsize=(9, 6))
    for model_name, model_df in df.groupby("model_type"):
        # A line plot joins points in row order. Rows arrive in the order the
        # cases were run (and may mix several sweeps), so sort by the x value
        # first; otherwise the line doubles back on itself.
        model_df = model_df.sort_values("trainable_params")
        ax.plot(
            model_df["trainable_params"],
            model_df["equalized_ber"],
            marker="o",
            linewidth=1.5,
            linestyle="-",
            label=model_name,
        )
    ax.set_title(title, fontweight="bold")
    ax.set_xlabel("Trainable Parameters")
    ax.set_ylabel("BER")
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.grid(alpha=0.3)
    ax.legend()
    plt.tight_layout()
    plt.savefig(Config.OUT_DIR / filename, dpi=150, bbox_inches="tight")
    plt.close(fig)
|
|
2940
|
+
|
|
2941
|
+
|
|
2942
|
+
def hidden_overrides(hidden_dim: int) -> Dict[str, int]:
    """Return Config overrides that set every hidden-width knob to ``hidden_dim``.

    Args:
        hidden_dim: Width applied to the ``HIDDEN_DIM``,
            ``EFFICIENT_KAN_HIDDEN_DIM`` and ``FASTKAN_HIDDEN_DIM`` keys.

    Returns:
        A new dict mapping each of the three keys to ``hidden_dim``.
    """
    width_keys = ("HIDDEN_DIM", "EFFICIENT_KAN_HIDDEN_DIM", "FASTKAN_HIDDEN_DIM")
    return dict.fromkeys(width_keys, hidden_dim)
|
|
2948
|
+
|
|
2949
|
+
|
|
2950
|
+
def layer_overrides(layer_count: int) -> Dict[str, int]:
    """Return Config overrides that set every depth knob to ``layer_count``.

    Args:
        layer_count: Depth applied to the ``EFFICIENT_KAN_LAYERS``,
            ``FASTKAN_LAYERS`` and ``MLP_LAYERS`` keys.

    Returns:
        A new dict mapping each of the three keys to ``layer_count``.
    """
    depth_keys = ("EFFICIENT_KAN_LAYERS", "FASTKAN_LAYERS", "MLP_LAYERS")
    return dict.fromkeys(depth_keys, layer_count)
|
|
2956
|
+
|
|
2957
|
+
|
|
2958
|
+
def run_model_with_overrides(model_name: str, max_test_files: int, **overrides) -> Dict[str, Any]:
    """Train and evaluate one model under temporary ``Config`` overrides.

    The overrides are applied to the global ``Config`` object, the data is
    prepared, the model is trained via ``train_one_model``, and ``Config`` is
    then restored, even if training raises.

    Regardless of the overrides passed in, ``Config.SEQ_LEN`` is recomputed
    as ``2 * Config.CONTEXT_K + 1`` and ``Config.SAVE_BEST`` is forced to
    ``False`` for the duration of the run.

    Args:
        model_name: Model identifier forwarded to ``train_one_model``.
        max_test_files: Forwarded to ``prepare_data``.
        **overrides: ``Config`` attribute names mapped to temporary values.

    Returns:
        The results mapping (third element) returned by ``train_one_model``.
    """
    tracked_keys = [
        "CONTEXT_K",
        "SEQ_LEN",
        "HIDDEN_DIM",
        "LSTM_HIDDEN",
        "SAVE_BEST",
        "EPOCHS",
        "LEARNING_RATE",
        "EFFICIENT_KAN_HIDDEN_DIM",
        "EFFICIENT_KAN_LAYERS",
        "EFFICIENT_KAN_GRID_SIZE",
        "EFFICIENT_KAN_SPLINE_ORDER",
        "FASTKAN_HIDDEN_DIM",
        "FASTKAN_LAYERS",
        "FASTKAN_NUM_GRIDS",
        "MLP_LAYERS",
        "COMPUTE_PER_FILE_METRICS",
    ]
    previous = {key: getattr(Config, key) for key in tracked_keys}

    # Also snapshot overridden keys that are not in the tracked list, so that
    # no override leaks into later runs. The sentinel marks attributes that
    # did not exist before and must be removed again afterwards.
    absent = object()
    for key in overrides:
        if key not in previous:
            previous[key] = getattr(Config, key, absent)

    try:
        for key, value in overrides.items():
            setattr(Config, key, value)
        # The window length is derived from the context size, so it is
        # recomputed after the overrides are in place.
        Config.SEQ_LEN = 2 * Config.CONTEXT_K + 1
        Config.SAVE_BEST = False
        data = prepare_data(max_test_files=max_test_files)
        model, _, results = train_one_model(model_name, data)
        if Config.DEVICE.type == "cuda":
            # Move the weights off the GPU and drop the reference so the
            # cached memory can be released before the next case starts.
            model.to("cpu")
            del model
            torch.cuda.empty_cache()
        return results
    finally:
        for key, value in previous.items():
            if value is absent:
                if hasattr(Config, key):
                    delattr(Config, key)
            else:
                setattr(Config, key, value)
|
|
2993
|
+
|
|
2994
|
+
|
|
2995
|
+
def plot_fastkan_classifier_score(df: pd.DataFrame, filename: str):
    """Scatter the efficiency score against parameter count, grouped by model.

    The x axis (trainable parameters) is logarithmic; the y axis is linear.

    Args:
        df: Sweep results with ``model_type``, ``trainable_params`` and
            ``efficiency_score`` columns.
        filename: Output file name, resolved under ``Config.OUT_DIR``.
    """
    fig, ax = plt.subplots(figsize=(9, 6))

    for model_name, group in df.groupby("model_type"):
        param_counts = group["trainable_params"]
        scores = group["efficiency_score"]
        ax.scatter(param_counts, scores, s=70, label=model_name)

    ax.set_title("FastKAN Classifiers: BER-Speed Efficiency", fontweight="bold")
    ax.set_xlabel("Trainable Parameters")
    ax.set_ylabel("Efficiency Score")
    ax.set_xscale("log")
    ax.grid(alpha=0.3)
    ax.legend()

    plt.tight_layout()
    target = Config.OUT_DIR / filename
    plt.savefig(target, dpi=150, bbox_inches="tight")
    plt.close(fig)
|
|
3013
|
+
|
|
3014
|
+
|
|
3015
|
+
def run_fastkan_classifier_sweep():
    """Sweep hidden width, RBF grid count and depth for the FastKAN classifiers.

    For every model in ``Config.FASTKAN_CLASSIFIER_SWEEP_MODELS`` the three
    parameters are varied one at a time around the current ``Config``
    defaults. After each case the accumulated rows are rewritten to
    ``fastkan_classifier_sweep_all.csv``. Once all cases are done, one CSV
    and one line chart are produced per sweep, followed by the complexity
    and efficiency-score plots over all rows.
    """
    log("\nRunning compact FastKAN/RBF-KAN classifier sweep...")
    rows: List[Dict[str, Any]] = []
    output_path = Config.OUT_DIR / "fastkan_classifier_sweep_all.csv"

    # Defaults shared by every case; a case overrides exactly one of them.
    base = {
        "EPOCHS": Config.FASTKAN_CLASSIFIER_SWEEP_EPOCHS,
        "COMPUTE_PER_FILE_METRICS": False,
        "FASTKAN_HIDDEN_DIM": Config.FASTKAN_HIDDEN_DIM,
        "FASTKAN_LAYERS": Config.FASTKAN_LAYERS,
        "FASTKAN_NUM_GRIDS": Config.FASTKAN_NUM_GRIDS,
    }

    def run_case(sweep_name: str, model_name: str, **overrides):
        """Train one configuration, record its row and rewrite the CSV."""
        effective = {**base, **overrides}
        hidden = effective["FASTKAN_HIDDEN_DIM"]
        grids = effective["FASTKAN_NUM_GRIDS"]
        depth = effective["FASTKAN_LAYERS"]
        log(
            f"fastkan sweep | {sweep_name} | {model_name} | "
            f"hidden={hidden} | grids={grids} | "
            f"layers={depth}"
        )
        results = run_model_with_overrides(
            model_name,
            max_test_files=Config.FASTKAN_CLASSIFIER_SWEEP_TEST_FILES,
            **effective,
        )
        row = {
            "sweep": sweep_name,
            "model_type": model_name,
            "hidden_dim": hidden,
            "num_grids": grids,
            "layers": depth,
        }
        # Result keys win on collision, as with ``{..., **results}``.
        row.update(results)
        rows.append(row)
        # Rewritten after every case so partial results survive an interruption.
        pd.DataFrame(rows).to_csv(output_path, index=False)

    for model_name in Config.FASTKAN_CLASSIFIER_SWEEP_MODELS:
        for hidden_dim in Config.FASTKAN_CLASSIFIER_HIDDEN_VALUES:
            run_case("hidden_dim", model_name, FASTKAN_HIDDEN_DIM=hidden_dim)
        for num_grids in Config.FASTKAN_CLASSIFIER_GRID_VALUES:
            run_case("num_grids", model_name, FASTKAN_NUM_GRIDS=num_grids)
        for layers in Config.FASTKAN_CLASSIFIER_LAYER_VALUES:
            run_case("layers", model_name, FASTKAN_LAYERS=layers)

    df = pd.DataFrame(rows)
    sweep_plots = [
        ("hidden_dim", "hidden_dim", "Hidden Dimension"),
        ("num_grids", "num_grids", "RBF Grid Count"),
        ("layers", "layers", "FastKAN Layers"),
    ]
    for sweep_name, x_col, x_label in sweep_plots:
        selected = df["sweep"] == sweep_name
        sweep_df = df[selected].copy()
        csv_path = Config.OUT_DIR / f"fastkan_classifier_ber_vs_{sweep_name}.csv"
        sweep_df.to_csv(csv_path, index=False)
        plot_experiment_lines(
            sweep_df,
            x_col=x_col,
            x_label=x_label,
            filename=f"fastkan_classifier_ber_vs_{sweep_name}.png",
            title=f"FastKAN Classifier BER vs {x_label}",
        )

    plot_experiment_complexity(
        df,
        filename="fastkan_classifier_ber_vs_complexity.png",
        title="FastKAN Classifiers: BER vs Complexity",
    )
    plot_fastkan_classifier_score(df, "fastkan_classifier_efficiency_score.png")
    log(f"Saved FastKAN classifier sweep: {output_path}")
|
|
3081
|
+
|
|
3082
|
+
|
|
3083
|
+
def run_sweep_experiments():
    """Run the configured sweep: a dedicated suite or the default BER sweep.

    The ``Config`` flags are checked in order (FastKAN classifier sweep,
    KAN experiment suite, EfficientKAN sweep). The first enabled flag runs
    its suite and nothing else. With none enabled, every model in
    ``Config.MODEL_TYPES`` is trained over ``Config.WINDOW_SWEEP_VALUES``
    and ``Config.HIDDEN_SWEEP_VALUES``, and the results are written as CSVs
    plus per-model and overlay plots under ``Config.OUT_DIR``.
    """
    dedicated_suites = (
        ("RUN_FASTKAN_CLASSIFIER_SWEEP", run_fastkan_classifier_sweep),
        ("RUN_KAN_EXPERIMENT_SUITE", run_kan_experiment_suite),
        ("RUN_EFFICIENT_KAN_SWEEP", run_efficient_kan_sweep_experiments),
    )
    for flag_name, suite in dedicated_suites:
        # Flags are read lazily, so later ones are never touched once one matches.
        if getattr(Config, flag_name):
            suite()
            return

    log("\nRunning BER sweep experiments on one test file...")
    window_rows: List[Dict[str, float]] = []
    hidden_rows: List[Dict[str, float]] = []

    for model_name in Config.MODEL_TYPES:
        # Context-window sweep.
        for context_k in Config.WINDOW_SWEEP_VALUES:
            log(f"sweep | {model_name} | window={context_k}")
            results = run_model_with_overrides(
                model_name,
                max_test_files=Config.SWEEP_TEST_FILES,
                CONTEXT_K=context_k,
            )
            window_row = {
                "model_type": model_name,
                "context_k": context_k,
                "seq_len": 2 * context_k + 1,
            }
            window_row.update(results)
            window_rows.append(window_row)

        # Hidden-size sweep; the same size is applied to both hidden knobs.
        for hidden_size in Config.HIDDEN_SWEEP_VALUES:
            log(f"sweep | {model_name} | hidden={hidden_size}")
            results = run_model_with_overrides(
                model_name,
                max_test_files=Config.SWEEP_TEST_FILES,
                HIDDEN_DIM=hidden_size,
                LSTM_HIDDEN=hidden_size,
            )
            hidden_row = {
                "model_type": model_name,
                "hidden_size": hidden_size,
            }
            hidden_row.update(results)
            hidden_rows.append(hidden_row)

    window_df = pd.DataFrame(window_rows)
    hidden_df = pd.DataFrame(hidden_rows)
    window_df.to_csv(Config.OUT_DIR / "ber_vs_window.csv", index=False)
    hidden_df.to_csv(Config.OUT_DIR / "ber_vs_hidden.csv", index=False)

    window_axis = {"x_col": "seq_len", "x_label": "Window Size"}
    hidden_axis = {"x_col": "hidden_size", "x_label": "Hidden Size"}
    plot_sweep_per_model(window_df, filename_prefix="ber_vs_window", **window_axis)
    plot_sweep_per_model(hidden_df, filename_prefix="ber_vs_hidden", **hidden_axis)
    plot_sweep_overlay(window_df, filename="ber_vs_window_overlay.png", **window_axis)
    plot_sweep_overlay(hidden_df, filename="ber_vs_hidden_overlay.png", **hidden_axis)
|
|
3140
|
+
|
|
3141
|
+
|
|
3142
|
+
def run_kan_experiment_suite():
    """Run the KAN/MLP comparison experiments and export tables and plots.

    Sweeps performed, in order:
      * grid size and spline order for ``Config.EXPERIMENT_KAN_MODELS``;
      * hidden width, context window and layer count for
        ``Config.EXPERIMENT_COMPARE_MODELS``;
      * hidden width for ``Config.EXPERIMENT_COMPLEXITY_MODELS``.

    After each case the accumulated rows are rewritten to
    ``kan_experiment_suite_all.csv``. At the end one CSV is written per
    sweep, and a plot is produced for each sweep that has rows.
    """
    log("\nRunning KAN/MLP experiment suite...")
    rows: List[Dict[str, Any]] = []

    # Settings shared by every case unless a case overrides them.
    base = {
        "EPOCHS": Config.EXPERIMENT_EPOCHS,
        "EFFICIENT_KAN_GRID_SIZE": Config.EXPERIMENT_FIXED_GRID,
        "EFFICIENT_KAN_SPLINE_ORDER": Config.EXPERIMENT_FIXED_SPLINE_ORDER,
        "COMPUTE_PER_FILE_METRICS": Config.EXPERIMENT_COMPUTE_PER_FILE_METRICS,
    }

    def run_case(sweep_name: str, model_name: str, **overrides):
        """Train one configuration, record its row and rewrite the CSV."""
        effective = {**base, **overrides}

        def resolve(key: str):
            # An explicit setting wins; otherwise use the current Config value.
            return effective.get(key, getattr(Config, key))

        effective["SEQ_LEN"] = 2 * resolve("CONTEXT_K") + 1
        case_id = "_".join(
            (
                f"{sweep_name}",
                f"{model_name}",
                f"k{resolve('CONTEXT_K')}",
                f"h{resolve('EFFICIENT_KAN_HIDDEN_DIM')}",
                f"mlph{resolve('HIDDEN_DIM')}",
                f"l{resolve('EFFICIENT_KAN_LAYERS')}",
                f"mlpl{resolve('MLP_LAYERS')}",
                f"g{resolve('EFFICIENT_KAN_GRID_SIZE')}",
                f"o{resolve('EFFICIENT_KAN_SPLINE_ORDER')}",
            )
        )
        log(f"experiment | {case_id}")
        results = run_model_with_overrides(
            model_name,
            max_test_files=Config.EXPERIMENT_TEST_FILES,
            **effective,
        )
        row = {
            "case_id": case_id,
            "sweep": sweep_name,
            "model_type": model_name,
            "context_k": resolve("CONTEXT_K"),
            "seq_len": resolve("SEQ_LEN"),
            # Generic columns prefer the KAN setting and fall back to the MLP one.
            "hidden_dim": effective.get("EFFICIENT_KAN_HIDDEN_DIM", resolve("HIDDEN_DIM")),
            "mlp_hidden_dim": resolve("HIDDEN_DIM"),
            "kan_hidden_dim": resolve("EFFICIENT_KAN_HIDDEN_DIM"),
            "layers": effective.get("EFFICIENT_KAN_LAYERS", resolve("MLP_LAYERS")),
            "mlp_layers": resolve("MLP_LAYERS"),
            "kan_layers": resolve("EFFICIENT_KAN_LAYERS"),
            "grid_size": resolve("EFFICIENT_KAN_GRID_SIZE"),
            "spline_order": resolve("EFFICIENT_KAN_SPLINE_ORDER"),
            "epochs": effective["EPOCHS"],
        }
        # Result keys win on collision, as with ``{..., **results}``.
        row.update(results)
        rows.append(row)
        # Rewritten after every case so partial results survive an interruption.
        pd.DataFrame(rows).to_csv(Config.OUT_DIR / "kan_experiment_suite_all.csv", index=False)

    for model_name in Config.EXPERIMENT_KAN_MODELS:
        for grid_size in Config.EXPERIMENT_GRID_VALUES:
            run_case("ber_vs_grid", model_name, EFFICIENT_KAN_GRID_SIZE=grid_size)

        for spline_order in Config.EXPERIMENT_SPLINE_ORDER_VALUES:
            run_case("ber_vs_spline_order", model_name, EFFICIENT_KAN_SPLINE_ORDER=spline_order)

    for model_name in Config.EXPERIMENT_COMPARE_MODELS:
        for hidden_dim in Config.EXPERIMENT_HIDDEN_VALUES:
            width_settings = hidden_overrides(hidden_dim)
            run_case(
                "kan_mlp_vs_hidden_grid16",
                model_name,
                **width_settings,
                EFFICIENT_KAN_GRID_SIZE=Config.EXPERIMENT_FIXED_GRID,
            )

        for context_k in Config.EXPERIMENT_WINDOW_VALUES:
            run_case(
                "kan_mlp_vs_window",
                model_name,
                CONTEXT_K=context_k,
                EFFICIENT_KAN_GRID_SIZE=Config.EXPERIMENT_FIXED_GRID,
            )

        for layer_count in Config.EXPERIMENT_LAYER_VALUES:
            depth_settings = layer_overrides(layer_count)
            run_case(
                "kan_mlp_vs_layers",
                model_name,
                **depth_settings,
                EFFICIENT_KAN_GRID_SIZE=Config.EXPERIMENT_FIXED_GRID,
            )

    for model_name in Config.EXPERIMENT_COMPLEXITY_MODELS:
        for hidden_dim in Config.EXPERIMENT_HIDDEN_VALUES:
            width_settings = hidden_overrides(hidden_dim)
            run_case(
                "ber_vs_complexity",
                model_name,
                **width_settings,
                EFFICIENT_KAN_GRID_SIZE=Config.EXPERIMENT_FIXED_GRID,
            )

    df = pd.DataFrame(rows)
    all_path = Config.OUT_DIR / "kan_experiment_suite_all.csv"
    df.to_csv(all_path, index=False)

    # (sweep name, x column, x label, output file, title)
    plot_specs = [
        ("ber_vs_grid", "grid_size", "Grid Size", "ber_vs_grid.png", "BER vs Grid Size"),
        (
            "ber_vs_spline_order",
            "spline_order",
            "Spline Order",
            "ber_vs_spline_order.png",
            "BER vs Spline Order",
        ),
        (
            "kan_mlp_vs_hidden_grid16",
            "hidden_dim",
            "Hidden Dimension",
            "kan_mlp_ber_vs_hidden_grid16.png",
            "KAN vs MLP: BER vs Hidden Dim (grid=16)",
        ),
        (
            "kan_mlp_vs_window",
            "seq_len",
            "Window Size (symbols)",
            "kan_mlp_ber_vs_window.png",
            "KAN vs MLP: BER vs Window Size",
        ),
        (
            "kan_mlp_vs_layers",
            "layers",
            "Number of Layers",
            "kan_mlp_ber_vs_layers.png",
            "KAN vs MLP: BER vs Number of Layers",
        ),
    ]
    for sweep_name, x_col, x_label, filename, title in plot_specs:
        selected = df["sweep"] == sweep_name
        sweep_df = df[selected].copy()
        # The CSV is always written; the plot is skipped when there is no data.
        sweep_df.to_csv(Config.OUT_DIR / f"{sweep_name}.csv", index=False)
        if sweep_df.empty:
            continue
        plot_experiment_lines(sweep_df, x_col=x_col, x_label=x_label, filename=filename, title=title)

    complexity_df = df[df["sweep"] == "ber_vs_complexity"].copy()
    complexity_df.to_csv(Config.OUT_DIR / "ber_vs_complexity.csv", index=False)
    if not complexity_df.empty:
        plot_experiment_complexity(
            complexity_df,
            filename="ber_vs_complexity.png",
            title="BER vs Complexity",
        )

    log(f"Saved KAN experiment suite: {all_path}")
|
|
3284
|
+
|
|
3285
|
+
|
|
3286
|
+
def run_efficient_kan_sweep_experiments():
    """Sweep EfficientKAN hyper-parameters and export tables and plots.

    For every model in ``Config.EFFICIENT_KAN_SWEEP_MODELS`` this runs a
    hidden-width x learning-rate grid, then one-at-a-time sweeps over grid
    size, spline order and layer count around the current ``Config``
    defaults. After each case the accumulated rows are rewritten to
    ``efficient_kan_sweep_all.csv``. Per-sweep CSVs, BER curves and two
    trade-off scatter plots are written at the end.
    """
    log("\nRunning EfficientKAN regression/classifier sweep...")
    rows: List[Dict[str, Any]] = []

    # Defaults shared by every case; each case overrides one or two of them.
    base = {
        "EPOCHS": Config.EFFICIENT_KAN_SWEEP_EPOCHS,
        "EFFICIENT_KAN_HIDDEN_DIM": Config.EFFICIENT_KAN_HIDDEN_DIM,
        "EFFICIENT_KAN_LAYERS": Config.EFFICIENT_KAN_LAYERS,
        "EFFICIENT_KAN_GRID_SIZE": Config.EFFICIENT_KAN_GRID_SIZE,
        "EFFICIENT_KAN_SPLINE_ORDER": Config.EFFICIENT_KAN_SPLINE_ORDER,
        "LEARNING_RATE": Config.LEARNING_RATE,
    }

    def run_case(sweep_name: str, model_name: str, **overrides):
        """Train one configuration, record its row and rewrite the CSV."""
        effective = {**base, **overrides}
        hidden = effective["EFFICIENT_KAN_HIDDEN_DIM"]
        depth = effective["EFFICIENT_KAN_LAYERS"]
        grid = effective["EFFICIENT_KAN_GRID_SIZE"]
        order = effective["EFFICIENT_KAN_SPLINE_ORDER"]
        rate = effective["LEARNING_RATE"]

        case_id = (
            f"{sweep_name}_{model_name}_"
            f"h{hidden}_"
            f"l{depth}_"
            f"g{grid}_"
            f"o{order}_"
            f"lr{rate:.0e}"
        )
        log(f"sweep | {case_id}")
        results = run_model_with_overrides(
            model_name,
            max_test_files=Config.EFFICIENT_KAN_SWEEP_TEST_FILES,
            **effective,
        )
        row = {
            "case_id": case_id,
            "sweep": sweep_name,
            "model_type": model_name,
            "hidden_dim": hidden,
            "layers": depth,
            "grid_size": grid,
            "spline_order": order,
            "learning_rate": rate,
            "epochs": effective["EPOCHS"],
        }
        # Result keys win on collision, as with ``{..., **results}``.
        row.update(results)
        rows.append(row)
        # Rewritten after every case so partial results survive an interruption.
        pd.DataFrame(rows).to_csv(Config.OUT_DIR / "efficient_kan_sweep_all.csv", index=False)

    for model_name in Config.EFFICIENT_KAN_SWEEP_MODELS:
        for hidden_dim in Config.EFFICIENT_KAN_HIDDEN_SWEEP_VALUES:
            for learning_rate in Config.EFFICIENT_KAN_LR_SWEEP_VALUES:
                run_case(
                    "hidden_lr",
                    model_name,
                    EFFICIENT_KAN_HIDDEN_DIM=hidden_dim,
                    LEARNING_RATE=learning_rate,
                )

        for grid_size in Config.EFFICIENT_KAN_GRID_SWEEP_VALUES:
            run_case("grid_size", model_name, EFFICIENT_KAN_GRID_SIZE=grid_size)

        for spline_order in Config.EFFICIENT_KAN_ORDER_SWEEP_VALUES:
            run_case("spline_order", model_name, EFFICIENT_KAN_SPLINE_ORDER=spline_order)

        for layers in Config.EFFICIENT_KAN_LAYER_SWEEP_VALUES:
            run_case("layers", model_name, EFFICIENT_KAN_LAYERS=layers)

    df = pd.DataFrame(rows)
    df.to_csv(Config.OUT_DIR / "efficient_kan_sweep_all.csv", index=False)

    sweep_column = df["sweep"]
    hidden_df = df[sweep_column == "hidden_lr"].copy()
    grid_df = df[sweep_column == "grid_size"].copy()
    order_df = df[sweep_column == "spline_order"].copy()
    layer_df = df[sweep_column == "layers"].copy()

    hidden_df.to_csv(Config.OUT_DIR / "efficient_kan_ber_vs_hidden_lr.csv", index=False)
    grid_df.to_csv(Config.OUT_DIR / "efficient_kan_ber_vs_grid.csv", index=False)
    order_df.to_csv(Config.OUT_DIR / "efficient_kan_ber_vs_order.csv", index=False)
    layer_df.to_csv(Config.OUT_DIR / "efficient_kan_ber_vs_layers.csv", index=False)

    # One hidden-dimension curve per learning rate that produced rows.
    for learning_rate in Config.EFFICIENT_KAN_LR_SWEEP_VALUES:
        lr_df = hidden_df[hidden_df["learning_rate"] == learning_rate]
        if not lr_df.empty:
            plot_efficient_kan_sweep(
                lr_df,
                x_col="hidden_dim",
                x_label="EfficientKAN Hidden Dimension",
                filename=f"efficient_kan_ber_vs_hidden_lr_{learning_rate:.0e}.png",
                title=f"EfficientKAN BER vs Hidden Dim (lr={learning_rate:.0e})",
            )

    # (data, x column, x label, output file, title)
    curve_specs = [
        (
            grid_df,
            "grid_size",
            "Grid Size",
            "efficient_kan_ber_vs_grid.png",
            "EfficientKAN BER vs Grid Size",
        ),
        (
            order_df,
            "spline_order",
            "Spline Order",
            "efficient_kan_ber_vs_spline_order.png",
            "EfficientKAN BER vs Spline Order",
        ),
        (
            layer_df,
            "layers",
            "KAN Hidden Layers",
            "efficient_kan_ber_vs_layers.png",
            "EfficientKAN BER vs Number of Layers",
        ),
    ]
    for sweep_df, x_col, x_label, filename, title in curve_specs:
        plot_efficient_kan_sweep(
            sweep_df,
            x_col=x_col,
            x_label=x_label,
            filename=filename,
            title=title,
        )

    # (x column, x label, output file, title); both scatter plots use all rows.
    tradeoff_specs = [
        (
            "trainable_params",
            "Trainable Parameters",
            "efficient_kan_params_vs_ber.png",
            "EfficientKAN Parameter Count vs BER",
        ),
        (
            "train_samples_per_sec",
            "Training Samples/sec",
            "efficient_kan_speed_vs_ber.png",
            "EfficientKAN Training Speed vs BER",
        ),
    ]
    for x_col, x_label, filename, title in tradeoff_specs:
        plot_efficient_kan_tradeoff(
            df,
            x_col=x_col,
            x_label=x_label,
            filename=filename,
            title=title,
        )

    log(f"Saved EfficientKAN sweep: {Config.OUT_DIR / 'efficient_kan_sweep_all.csv'}")
|
|
3409
|
+
|
|
3410
|
+
|
|
3411
|
+
def train_one_model(model_name: str, data: Dict[str, Any]) -> Tuple[nn.Module, Dict, Dict[str, Any]]:
    """Train one equalizer, restore its best checkpoint, and evaluate it.

    Args:
        model_name: Key passed to ``make_model``; also used as the prefix of
            log lines and of the ``*_best.pth`` checkpoint file.
        data: Prepared dataset mapping. ``train_x``, ``train_y``, ``val_x``
            and ``val_y`` are read here; the whole mapping is forwarded to
            the metric helpers.

    Returns:
        ``(model, history, final_metrics)``. ``history`` holds per-epoch
        lists (losses, BER, learning rate, timing). ``final_metrics`` is the
        dictionary produced by ``compute_test_metrics`` for the restored best
        weights, extended with run bookkeeping and, optionally, per-file
        summaries.

    Raises:
        RuntimeError: Re-raised when a training step fails for a reason other
            than CUDA out-of-memory, or when out-of-memory still occurs at
            ``Config.MIN_BLOCK_SIZE``.
    """
    model = make_model(model_name)
    trainable_params = count_trainable_parameters(model)
    log(
        f"{model_name} | params {trainable_params:,} | "
        f"{MODEL_NOTES.get(model_name, 'custom equalizer')}"
    )
    # Data-driven initialisation must happen on the raw module, before any
    # compile wrapper is put around it.
    if hasattr(model, "initialize_from_data"):
        model.initialize_from_data(data["train_x"], data["train_y"])
        log(f"{model_name} | initialized from training windows")
    if Config.USE_TORCH_COMPILE and hasattr(torch, "compile"):
        # Compilation is an optimisation only: on failure keep the eager model.
        # NOTE(review): state_dict() of a compiled wrapper may carry an
        # "_orig_mod." key prefix -- confirm checkpoint consumers expect that.
        try:
            model = torch.compile(model, mode=Config.TORCH_COMPILE_MODE)
            log(f"{model_name} | torch.compile enabled ({Config.TORCH_COMPILE_MODE})")
        except Exception as exc:
            log(f"{model_name} | torch.compile disabled: {exc}")

    optimizer = build_optimizer(model)
    criterion = build_criterion()
    scaler = torch.amp.GradScaler("cuda", enabled=Config.DEVICE.type == "cuda" and Config.USE_AMP)
    best_train_loss = float("inf")
    best_val_ber = float("inf")
    best_score = float("inf")
    best_state = None
    steps_without_improvement = 0  # drives the learning-rate decay
    early_stop_without_improvement = 0  # drives early stopping (val BER only)
    early_stopped = False
    stop_reason = "completed"

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_acc": [],
        "val_ber": [],
        "val_scale": [],
        "val_epochs": [],
        "test_loss": [],
        "test_acc": [],
        "test_ber": [],
        "test_scale": [],
        "test_epochs": [],
        "lr": [],
        "epoch_time_sec": [],
        "train_samples_per_sec": [],
    }

    train_x = data["train_x"]
    train_y = data["train_y"]
    train_block_size = Config.TRAIN_BLOCK_SIZE
    eval_batch_size = Config.EVAL_BATCH_SIZE

    # When checkpoints are selected on a validation metric, score the
    # untrained model first so an epoch only replaces it if it is better.
    if Config.SAVE_BEST_BY in {"val_ber", "val_loss"}:
        init_val_loss, _, init_val_ber, eval_batch_size, _ = evaluate_split(
            model,
            data["val_x"],
            data["val_y"],
            eval_batch_size,
            scale_search=True,
        )
        best_val_ber = init_val_ber
        best_score = init_val_ber if Config.SAVE_BEST_BY == "val_ber" else init_val_loss
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        if Config.SAVE_BEST:
            torch.save(best_state, Config.OUT_DIR / f"{model_name}_best.pth")
        log(
            f"{model_name} | init checkpoint | val {init_val_loss:.6f} | "
            f"val_ber {init_val_ber:.6e}"
        )

    for epoch in range(Config.EPOCHS):
        model.train()
        epoch_start = time.time()
        running_loss = 0.0
        seen = 0
        total_train = train_x.size(0)
        # Freeze the block layout for the whole epoch. An OOM retry below may
        # shrink ``train_block_size``; if block offsets were derived from the
        # shrunken value, the remaining blocks would be mis-addressed and part
        # of the training set would be skipped for this epoch.
        epoch_block_size = train_block_size
        num_blocks = (total_train + epoch_block_size - 1) // epoch_block_size
        block_order = torch.randperm(num_blocks).tolist()
        for block_idx in block_order:
            block_start = block_idx * epoch_block_size
            block_end = min(block_start + epoch_block_size, total_train)
            start = block_start
            while start < block_end:
                # ``train_block_size`` is the step size inside the block; it
                # equals the block size unless an OOM retry reduced it.
                end = min(start + train_block_size, block_end)
                xb = train_x[start:end].to(Config.DEVICE, non_blocking=Config.DEVICE.type == "cuda")
                yb = train_y[start:end].to(Config.DEVICE, non_blocking=Config.DEVICE.type == "cuda")
                try:
                    optimizer.zero_grad(set_to_none=True)
                    with autocast_context():
                        mark_cudagraph_step_begin()
                        preds = model(xb)
                        loss = prediction_loss(preds, yb, criterion)
                        loss = loss + compute_model_regularization(model)
                    scaler.scale(loss).backward()
                    if Config.GRAD_CLIP_NORM > 0:
                        # Unscale first so the clip threshold applies to the
                        # true gradient magnitudes.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), Config.GRAD_CLIP_NORM)
                    scaler.step(optimizer)
                    scaler.update()
                    batch_size_now = yb.size(0)
                    running_loss += loss.item() * batch_size_now
                    seen += batch_size_now
                    start = end
                except RuntimeError as error:
                    if Config.DEVICE.type != "cuda" or not is_cuda_oom(error):
                        raise
                    optimizer.zero_grad(set_to_none=True)
                    if train_block_size <= Config.MIN_BLOCK_SIZE:
                        raise
                    # ``start`` was not advanced, so the same samples are
                    # retried with a smaller step.
                    next_block_size = max(train_block_size // 2, Config.MIN_BLOCK_SIZE)
                    log(f"{model_name} | CUDA OOM at train_block_size={train_block_size}, retrying with {next_block_size}")
                    train_block_size = next_block_size
                    torch.cuda.empty_cache()

        train_loss = running_loss / max(seen, 1)
        epoch_time = time.time() - epoch_start
        speed = seen / max(epoch_time, 1e-9)

        val_loss, val_acc, val_ber, eval_batch_size, val_scale = evaluate_split(
            model,
            data["val_x"],
            data["val_y"],
            eval_batch_size,
            scale_search=True,
        )
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        history["val_ber"].append(val_ber)
        history["val_scale"].append(val_scale)
        history["val_epochs"].append(epoch + 1)
        history["lr"].append(optimizer.param_groups[0]["lr"])
        history["epoch_time_sec"].append(epoch_time)
        history["train_samples_per_sec"].append(speed)

        # Test metrics are sampled on the first epoch, the last epoch and
        # every ``TEST_BER_EVERY`` epochs in between.
        should_eval_test = Config.EVAL_TEST_DURING_TRAINING and (
            (epoch + 1) % Config.TEST_BER_EVERY == 0 or epoch == 0 or epoch == Config.EPOCHS - 1
        )
        if should_eval_test:
            test_metrics = compute_test_metrics(model, {**data, "eval_batch_size": eval_batch_size})
            eval_batch_size = int(test_metrics["safe_eval_batch_size"])
            history["test_loss"].append(test_metrics["test_loss"])
            history["test_acc"].append(test_metrics["accuracy"])
            history["test_ber"].append(test_metrics["equalized_ber"])
            history["test_scale"].append(test_metrics["equalizer_scale"])
            history["test_epochs"].append(epoch + 1)
            log(
                f"{model_name} | epoch {epoch+1:4d}/{Config.EPOCHS} | "
                f"train {train_loss:.6f} | val {val_loss:.6f} | val_ber {val_ber:.6e} | "
                f"test_ber {test_metrics['equalized_ber']:.6e} | lr {optimizer.param_groups[0]['lr']:.2e} | "
                f"speed {speed:,.0f} samp/s | time {epoch_time:.1f}s"
            )
        elif epoch % Config.LOG_EVERY == 0 or epoch == Config.EPOCHS - 1:
            log(
                f"{model_name} | epoch {epoch+1:4d}/{Config.EPOCHS} | "
                f"train {train_loss:.6f} | val {val_loss:.6f} | val_ber {val_ber:.6e} | "
                f"lr {optimizer.param_groups[0]['lr']:.2e} | speed {speed:,.0f} samp/s | time {epoch_time:.1f}s"
            )

        if train_loss < best_train_loss:
            best_train_loss = train_loss

        # Early stopping always tracks validation BER, independently of the
        # metric used to select checkpoints.
        if val_ber + Config.EARLY_STOPPING_THRESHOLD < best_val_ber:
            best_val_ber = val_ber
            early_stop_without_improvement = 0
        else:
            early_stop_without_improvement += 1

        if Config.SAVE_BEST_BY == "train_loss":
            monitor_value = train_loss
        elif Config.SAVE_BEST_BY == "val_ber":
            monitor_value = val_ber
        else:
            monitor_value = val_loss
        if monitor_value + Config.SCHEDULER_THRESHOLD < best_score:
            best_score = monitor_value
            steps_without_improvement = 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            if Config.SAVE_BEST:
                torch.save(best_state, Config.OUT_DIR / f"{model_name}_best.pth")
        else:
            steps_without_improvement += 1

        if Config.LR_SCHEDULER == "notebook_decay" and steps_without_improvement >= Config.DECAY_STEPS:
            current_lr = optimizer.param_groups[0]["lr"]
            new_lr = current_lr * Config.SCHEDULER_FACTOR
            steps_without_improvement = 0
            if new_lr < Config.MIN_LR:
                # A further decay would drop below the floor: stop training.
                stop_reason = "lr_floor"
                log(f"{model_name} | epoch {epoch+1} -- stopping at lr floor ({current_lr:.6g})")
                break
            for param_group in optimizer.param_groups:
                param_group["lr"] = new_lr
            log(f"{model_name} | epoch {epoch+1} -- scheduler reduced lr to {new_lr:.6g}")

        if (
            Config.EARLY_STOPPING
            and epoch + 1 >= Config.EARLY_STOPPING_MIN_EPOCHS
            and early_stop_without_improvement >= Config.EARLY_STOPPING_PATIENCE
        ):
            early_stopped = True
            stop_reason = f"early_stop_val_ber_patience_{Config.EARLY_STOPPING_PATIENCE}"
            log(
                f"{model_name} | epoch {epoch+1} -- early stopping: "
                f"val_ber did not improve for {early_stop_without_improvement} epochs "
                f"(best_val_ber={best_val_ber:.6e})"
            )
            break

    # No checkpoint was ever captured (e.g. zero epochs while selecting on
    # train loss): fall back to the current weights.
    if best_state is None:
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    model.load_state_dict(best_state)
    final_metrics = compute_test_metrics(model, {**data, "eval_batch_size": eval_batch_size})
    _attach_run_summary(
        final_metrics, model, history, best_val_ber, best_train_loss, early_stopped, stop_reason
    )
    final_metrics = add_efficiency_metrics(model, data, final_metrics)
    model, final_metrics, eval_batch_size = maybe_prune_efficient_kan_model(
        model,
        model_name,
        data,
        final_metrics,
        eval_batch_size,
    )
    # The pruning step returns its own model and metrics, so the bookkeeping
    # is written again against whatever it handed back.
    _attach_run_summary(
        final_metrics, model, history, best_val_ber, best_train_loss, early_stopped, stop_reason
    )
    if Config.COMPUTE_PER_FILE_METRICS:
        val_file_metrics, eval_batch_size = compute_split_file_metrics(model, data, "val", eval_batch_size)
        test_file_metrics, eval_batch_size = compute_split_file_metrics(model, data, "test", eval_batch_size)
        add_file_metric_summary(final_metrics, "val", val_file_metrics)
        add_file_metric_summary(final_metrics, "test", test_file_metrics)
    return model, history, final_metrics


def _attach_run_summary(
    metrics: Dict[str, Any],
    model: nn.Module,
    history: Dict[str, List],
    best_val_ber: float,
    best_train_loss: float,
    early_stopped: bool,
    stop_reason: str,
) -> None:
    """Write training-run bookkeeping into ``metrics`` in place."""
    speeds = history["train_samples_per_sec"]
    epoch_times = history["epoch_time_sec"]
    metrics["trainable_params"] = count_trainable_parameters(model)
    metrics["best_val_ber"] = best_val_ber
    metrics["best_train_loss"] = best_train_loss
    # Averages default to 0.0 when no epoch ran.
    metrics["train_samples_per_sec"] = float(np.mean(speeds)) if speeds else 0.0
    metrics["mean_epoch_time_sec"] = float(np.mean(epoch_times)) if epoch_times else 0.0
    metrics["epochs_ran"] = len(history["train_loss"])
    metrics["early_stopped"] = early_stopped
    metrics["stop_reason"] = stop_reason
|
|
3652
|
+
|
|
3653
|
+
|
|
3654
|
+
def main():
    """Run the experiment modes that ``Config`` enables.

    The FastKAN classifier sweep, the KAN experiment suite and a sweep-only
    run are exclusive: the first one enabled runs and the function returns.
    Otherwise every model in ``Config.MODEL_TYPES`` is trained, plotted and
    saved, a summary is produced, and the sweep optionally runs afterwards.
    """
    Config.OUT_DIR.mkdir(parents=True, exist_ok=True)
    log(f"Device: {Config.DEVICE}")

    # Pick at most one exclusive mode, checking the flags in priority order.
    if Config.RUN_FASTKAN_CLASSIFIER_SWEEP:
        exclusive_mode = run_fastkan_classifier_sweep
    elif Config.RUN_KAN_EXPERIMENT_SUITE:
        exclusive_mode = run_kan_experiment_suite
    elif Config.RUN_SWEEP_EXPERIMENTS and not Config.RUN_MAIN_EXPERIMENTS:
        exclusive_mode = run_sweep_experiments
    else:
        exclusive_mode = None

    if exclusive_mode is not None:
        exclusive_mode()
        return

    data = prepare_data()

    collected = []
    for model_name in Config.MODEL_TYPES:
        log(f"\nTraining {model_name.upper()}...")
        trained, curves, results = train_one_model(model_name, data)
        results["model_type"] = model_name
        collected.append(results)

        plot_results(curves, results, model_name)
        final_path = Config.OUT_DIR / f"{model_name}_final.pth"
        torch.save(trained.state_dict(), final_path)
        log(
            f"{model_name.upper()} | baseline BER {results['baseline_ber']:.6e} | "
            f"equalized BER {results['equalized_ber']:.6e} | acc {results['accuracy']:.4%} | "
            f"rel improvement {results['improvement_rel']:.2f}%"
        )

        # Release GPU memory before the next model is built.
        if Config.DEVICE.type == "cuda":
            trained.to("cpu")
            del trained
            torch.cuda.empty_cache()

    plot_architecture_summary(collected)
    log(f"Saved summary: {Config.OUT_DIR / 'architecture_comparison.csv'}")

    if Config.RUN_SWEEP_EXPERIMENTS:
        run_sweep_experiments()


# Run the experiments only when this file is executed as a script.
if __name__ == "__main__":
    main()
|