ber-equalization-studio 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3700 @@
1
+ import copy
2
+ import time
3
+ from contextlib import nullcontext
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.optim as optim
14
+
15
+ try:
16
+ from efficient_kan import KAN as EfficientKAN
17
+ except ImportError:
18
+ EfficientKAN = None
19
+
20
+ try:
21
+ from mamba_ssm import Mamba
22
+ except ImportError:
23
+ Mamba = None
24
+
25
+
26
class Config:
    """Class-level settings shared by the models and experiment code in this module.

    All values are plain class attributes and are read as ``Config.NAME``.
    Sections below are grouped by name prefix; several groups are consumed
    by code outside the part of the file shown here.
    """

    # ------------------------------------------------------------------
    # Runtime device: prefer CUDA, then Apple MPS, otherwise CPU.
    # ------------------------------------------------------------------
    DEVICE = torch.device(
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )

    # ------------------------------------------------------------------
    # Data location and file-level split.
    # ------------------------------------------------------------------
    DATA_DIR_CANDIDATES = [
        Path("symbols_new"),
        Path("Symbols_1m_1ch_PR"),
        Path("."),
    ]
    MAX_FILES = 64
    # File-level split: first TRAIN_PORTION files form train+val pool,
    # the rest is a hold-out test set.
    TRAIN_PORTION = 0.97
    VAL_PORTION_WITHIN_TRAIN = 0.1
    MIN_VAL_FILES = 1
    RANDOMIZE_FILE_SPLIT = False
    SPLIT_SEED = 42

    # ------------------------------------------------------------------
    # Input window: CONTEXT_K symbols on each side of the centre symbol.
    # ------------------------------------------------------------------
    CONTEXT_K = 32
    SEQ_LEN = 2 * CONTEXT_K + 1
    INPUT_DIM = 2

    # ------------------------------------------------------------------
    # Shared head width / dropout and LSTM settings.
    # ------------------------------------------------------------------
    HIDDEN_DIM = 96  # author's note: should be 64
    DROPOUT = 0.2
    LSTM_HIDDEN = 64  # author's note: should be 64
    LSTM_LAYERS = 2
    BIDIRECTIONAL = True
    USE_ATTENTION = True

    # ------------------------------------------------------------------
    # Transformer / TCN / Mamba settings.
    # ------------------------------------------------------------------
    TRANSFORMER_DIM = 128
    TRANSFORMER_LAYERS = 4
    TRANSFORMER_HEADS = 4
    TRANSFORMER_FF_DIM = 256
    TRANSFORMER_CONV_KERNEL = 3
    TCN_HIDDEN_DIM = 96
    TCN_LAYERS = 5
    TCN_KERNEL_SIZE = 5
    TCN_DILATIONS = [1, 2, 4, 8, 16]
    MAMBA_DIM = 96
    MAMBA_LAYERS = 4
    MAMBA_D_STATE = 16
    MAMBA_D_CONV = 4
    MAMBA_EXPAND = 2

    # ------------------------------------------------------------------
    # Complex-valued model settings.
    # ------------------------------------------------------------------
    COMPLEX_CHANNELS = 24
    COMPLEX_BLOCK_CHANNELS = [24, 32]
    COMPLEX_KERNEL_SIZES = [5, 7]
    COMPLEX_HEAD_DIM = 128
    COMPLEX_TEMPORAL_DIM = 96
    COMPLEX_TEMPORAL_DILATIONS = [1, 2, 4]
    COMPLEX_LIGHT_CHANNELS = 48
    COMPLEX_LIGHT_DILATIONS = [1, 2, 4]
    COMPLEX_LIGHT_KERNEL_SIZE = 3
    COMPLEX_SEQ_DIM = 96
    COMPLEX_LSTM_HIDDEN = 64
    COMPLEX_LSTM_LAYERS = 2
    COMPLEX_USE_KERR = False
    COMPLEX_USE_DBP_FRONTEND = False
    COMPLEX_KERR_KERNEL = 5
    COMPLEX_KERR_INIT_GAMMA = 0.02

    # ------------------------------------------------------------------
    # DBP_* settings (consumers are outside this view).
    # ------------------------------------------------------------------
    DBP_NUM_STEPS = 20
    DBP_KERNEL_SIZE = 7
    DBP_FINAL_KERNEL_SIZE = 21
    DBP_USE_FINAL_FILTER = True
    DBP_USE_SYMMETRIC_FILTER = True
    DBP_USE_SYMMETRIC_NONLINEAR = True
    DBP_NL_MEMORY = 2
    DBP_INIT_FROM_LS = False
    DBP_INIT_SAMPLES = 65536
    DBP_INIT_FFT_SIZE = 4096
    DBP_JOINT_INIT = True
    DBP_JOINT_INIT_ITERS = 200
    DBP_JOINT_INIT_BATCH_SIZE = 1024
    DBP_JOINT_INIT_LR = 2e-3
    DBP_SEQSTAT_DIM = 128

    # ------------------------------------------------------------------
    # FastKAN (RBF-based KAN) settings and KAN pruning / regularisation.
    # ------------------------------------------------------------------
    FASTKAN_HIDDEN_DIM = 96
    FASTKAN_LAYERS = 2
    FASTKAN_NUM_GRIDS = 8
    FASTKAN_GRID_MIN = -2.5
    FASTKAN_GRID_MAX = 2.5
    FASTKAN_BASE_ACT = "silu"
    FASTKAN_USE_BASE_PATH = True
    KAN_INPUT_DROPOUT = 0.05
    KAN_HIDDEN_DROPOUT = 0.1
    KAN_PRUNE_L1 = 1e-5
    KAN_PRUNE_THRESHOLD = 0.02
    KAN_STRUCTURAL_PRUNE_AFTER_TRAINING = False
    KAN_STRUCTURAL_PRUNE_KEEP_RATIOS = [0.75, 0.5, 0.35, 0.25]
    KAN_STRUCTURAL_PRUNE_MIN_HIDDEN = 16
    KAN_STRUCTURAL_PRUNE_FINE_TUNE_EPOCHS = 20
    KAN_STRUCTURAL_PRUNE_FINE_TUNE_LR = 2e-4
    KAN_STRUCTURAL_PRUNE_SELECT_BY = "efficiency_score"

    # ------------------------------------------------------------------
    # Efficiency measurement settings.
    # ------------------------------------------------------------------
    EFFICIENCY_BATCH_SIZE = 16000
    EFFICIENCY_SCORE_POWER = 3.0
    EFFICIENCY_TIMING_WARMUP = 5
    EFFICIENCY_TIMING_REPEATS = 20

    # ------------------------------------------------------------------
    # efficient_kan (spline KAN) settings, forwarded to the KAN constructor.
    # ------------------------------------------------------------------
    EFFICIENT_KAN_HIDDEN_DIM = 128
    EFFICIENT_KAN_LAYERS = 2
    EFFICIENT_KAN_GRID_SIZE = 20
    EFFICIENT_KAN_SPLINE_ORDER = 3
    EFFICIENT_KAN_GRID_EPS = 0.02
    EFFICIENT_KAN_GRID_RANGE = [-3.0, 3.0]
    EFFICIENT_KAN_SCALE_NOISE = 0.1
    EFFICIENT_KAN_SCALE_BASE = 1.0
    EFFICIENT_KAN_SCALE_SPLINE = 1.0
    KAN_FEATURE_RADIUS = 2

    # ------------------------------------------------------------------
    # Optimisation settings.
    # ------------------------------------------------------------------
    EPOCHS = 250
    LEARNING_RATE = 1e-3
    WEIGHT_DECAY = 0.0
    TRAIN_BLOCK_SIZE = 8192
    EVAL_BATCH_SIZE = 65536
    MIN_BLOCK_SIZE = 1024
    USE_AMP = True
    USE_TORCH_COMPILE = False
    TORCH_COMPILE_MODE = "max-autotune-no-cudagraphs"
    OPTIMIZER = "adam"
    LOSS = "mse"
    GRAD_CLIP_NORM = 1.0
    LR_SCHEDULER = "notebook_decay"
    SCHEDULER_FACTOR = 0.5
    SCHEDULER_PATIENCE = 100
    SCHEDULER_THRESHOLD = 1e-6

    DECAY_STEPS = 24
    MIN_LR = 1e-5
    EARLY_STOPPING = True
    EARLY_STOPPING_PATIENCE = 72
    EARLY_STOPPING_MIN_EPOCHS = 40
    EARLY_STOPPING_THRESHOLD = 0.0

    # ------------------------------------------------------------------
    # Logging, evaluation and BER scale-search settings.
    # ------------------------------------------------------------------
    LOG_EVERY = 1
    TEST_BER_EVERY = 10
    SAVE_BEST_BY = "val_ber"
    EVAL_TEST_DURING_TRAINING = False
    COMPUTE_PER_FILE_METRICS = True
    POWER_NORMALIZE = True
    BER_SCALE_SEARCH = True
    BER_SCALE_MIN = 0.5
    BER_SCALE_MAX = 1.5
    BER_SCALE_STEPS = 10
    BER_SCALE_OFFSET = 10000
    BER_SCALE_SAMPLES = 2**20

    # ------------------------------------------------------------------
    # Main experiment selection and outputs.
    # ------------------------------------------------------------------
    OUT_DIR = Path("clean_compare_outputs")
    RUN_MAIN_EXPERIMENTS = True
    MODEL_TYPES = [
        "efficient_kan_baseline",
        "kan_classifier",
    ]
    SAVE_BEST = True

    # ------------------------------------------------------------------
    # Generic and efficient_kan sweeps.
    # ------------------------------------------------------------------
    RUN_SWEEP_EXPERIMENTS = False
    SWEEP_TEST_FILES = 1
    WINDOW_SWEEP_VALUES = [2, 4, 8, 12, 16, 20]
    HIDDEN_SWEEP_VALUES = [8, 16, 32, 64]
    RUN_EFFICIENT_KAN_SWEEP = False
    EFFICIENT_KAN_SWEEP_MODELS = ["mlp", "efficient_kan_baseline", "kan_classifier"]
    EFFICIENT_KAN_SWEEP_EPOCHS = 60
    EFFICIENT_KAN_SWEEP_TEST_FILES = 1
    EFFICIENT_KAN_HIDDEN_SWEEP_VALUES = [64, 96, 128, 192]
    EFFICIENT_KAN_LR_SWEEP_VALUES = [1e-3]
    EFFICIENT_KAN_GRID_SWEEP_VALUES = [4, 8, 12, 16]
    EFFICIENT_KAN_ORDER_SWEEP_VALUES = [1, 2, 3, 4]
    EFFICIENT_KAN_LAYER_SWEEP_VALUES = [1, 2, 3]

    # ------------------------------------------------------------------
    # KAN experiment suite.
    # ------------------------------------------------------------------
    RUN_KAN_EXPERIMENT_SUITE = False
    EXPERIMENT_EPOCHS = 60
    EXPERIMENT_TEST_FILES = 1
    EXPERIMENT_COMPUTE_PER_FILE_METRICS = False
    EXPERIMENT_KAN_MODELS = ["efficient_kan_baseline", "kan_classifier"]
    EXPERIMENT_COMPARE_MODELS = ["efficient_kan_baseline", "mlp"]
    EXPERIMENT_COMPLEXITY_MODELS = ["complex_fastkan", "efficient_kan_baseline"]
    EXPERIMENT_FIXED_GRID = 16
    EXPERIMENT_FIXED_SPLINE_ORDER = 3
    EXPERIMENT_HIDDEN_VALUES = [64, 96, 128, 192]
    EXPERIMENT_WINDOW_VALUES = [8, 16, 24, 32, 48]
    EXPERIMENT_GRID_VALUES = [4, 8, 12, 16, 20]
    EXPERIMENT_SPLINE_ORDER_VALUES = [1, 2, 3, 4]
    EXPERIMENT_LAYER_VALUES = [1, 2, 3]
    MLP_LAYERS = 3

    # ------------------------------------------------------------------
    # FastKAN classifier sweep.
    # ------------------------------------------------------------------
    RUN_FASTKAN_CLASSIFIER_SWEEP = True
    FASTKAN_CLASSIFIER_SWEEP_MODELS = ["fastkan_classifier", "complex_fastkan_classifier"]
    FASTKAN_CLASSIFIER_SWEEP_EPOCHS = 60
    FASTKAN_CLASSIFIER_SWEEP_TEST_FILES = 1
    FASTKAN_CLASSIFIER_HIDDEN_VALUES = [16, 32, 48, 64, 96]
    FASTKAN_CLASSIFIER_GRID_VALUES = [4, 8, 12, 16]
    FASTKAN_CLASSIFIER_LAYER_VALUES = [1, 2]
208
+
209
+
210
# When the selected device is a CUDA GPU, allow TF32 for matmul and cuDNN,
# turn on cuDNN kernel autotuning, and request "high" float32 matmul precision.
if Config.DEVICE.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.set_float32_matmul_precision("high")
215
+
216
+
217
# Per-axis amplitude levels of the 16-point square constellation.
_CONSTELLATION_LEVELS = (-0.948683, -0.316228, 0.316228, 0.948683)

# 16 x 2 table of constellation points. Row index is 4 * i + q, where i selects
# the first coordinate and q the second, both in ascending level order.
CONSTELLATION = torch.tensor(
    [
        [first, second]
        for first in _CONSTELLATION_LEVELS
        for second in _CONSTELLATION_LEVELS
    ],
    dtype=torch.float32,
)
238
+
239
# 2-bit Gray sequence used for each half of the 4-bit symbol label.
_GRAY_PAIRS = ((0, 0), (0, 1), (1, 1), (1, 0))

# 16 x 4 table of bit labels. Row 4 * i + q is the Gray pair for i followed by
# the Gray pair for q, so neighbouring indices along either half differ in one bit.
BIT_LABELS = torch.tensor(
    [
        [*high_pair, *low_pair]
        for high_pair in _GRAY_PAIRS
        for low_pair in _GRAY_PAIRS
    ],
    dtype=torch.uint8,
)
260
+
261
+
262
class AttentionLayer(nn.Module):
    """Learned softmax pooling over dimension 1 of the input.

    A two-layer scorer (Linear -> Tanh -> Linear) produces one score per
    position along dim 1; the scores are softmax-normalised along that
    dimension and used as weights for a sum over it.
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        bottleneck = hidden_dim // 2
        scorer_layers = [
            nn.Linear(hidden_dim, bottleneck),
            nn.Tanh(),
            nn.Linear(bottleneck, 1),
        ]
        self.attention = nn.Sequential(*scorer_layers)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, lstm_output: torch.Tensor) -> torch.Tensor:
        """Return the attention-weighted sum of ``lstm_output`` over dim 1."""
        scores = self.attention(lstm_output)
        weights = self.softmax(scores)
        return (weights * lstm_output).sum(dim=1)
275
+
276
+
277
class LSTMRxEqualizer(nn.Module):
    """LSTM equalizer: per-symbol embedding -> LSTM -> pooled context -> MLP head.

    The pooled sequence summary is fused with the LSTM output at the centre
    position (index ``Config.CONTEXT_K``) and mapped to 2 output values.
    """

    def __init__(self):
        super().__init__()
        directions = 2 if Config.BIDIRECTIONAL else 1
        lstm_out_dim = Config.LSTM_HIDDEN * directions
        head_dim = Config.HIDDEN_DIM

        self.embedding = nn.Sequential(
            nn.Linear(Config.INPUT_DIM, 32),
            nn.LayerNorm(32),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
        )
        self.lstm = nn.LSTM(
            input_size=32,
            hidden_size=Config.LSTM_HIDDEN,
            num_layers=Config.LSTM_LAYERS,
            batch_first=True,
            bidirectional=Config.BIDIRECTIONAL,
            # Inter-layer dropout only applies when there is more than one layer.
            dropout=Config.DROPOUT if Config.LSTM_LAYERS > 1 else 0.0,
        )
        self.use_attention = Config.USE_ATTENTION
        # Always constructed, even when attention pooling is switched off.
        self.attention = AttentionLayer(lstm_out_dim)
        self.center_fusion = nn.Sequential(
            nn.Linear(lstm_out_dim * 2, lstm_out_dim),
            nn.LayerNorm(lstm_out_dim),
            nn.GELU(),
        )
        self.lstm_norm = nn.LayerNorm(lstm_out_dim)
        self.classifier = nn.Sequential(
            nn.Linear(lstm_out_dim, head_dim),
            nn.LayerNorm(head_dim),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(head_dim, head_dim // 2),
            nn.LayerNorm(head_dim // 2),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(head_dim // 2, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Initialise LSTM weights, then Kaiming-initialise every Linear/Conv1d."""
        hidden = Config.LSTM_HIDDEN
        for name, param in self.lstm.named_parameters():
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param.data)
                continue
            if "weight_hh" in name:
                nn.init.orthogonal_(param.data)
                continue
            if "bias" in name:
                nn.init.constant_(param.data, 0)
                # Second quarter of the stacked gate bias (the forget gate in
                # PyTorch's i, f, g, o layout) starts at 1.
                param.data[hidden : 2 * hidden] = 1.0

        for module in self.modules():
            if not isinstance(module, (nn.Linear, nn.Conv1d)):
                continue
            nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of flattened windows to a ``(batch, 2)`` output."""
        batch = x.size(0)
        tokens = self.embedding(x.view(batch, Config.SEQ_LEN, Config.INPUT_DIM))
        lstm_out, (hidden, _) = self.lstm(tokens)
        center_feature = lstm_out[:, Config.CONTEXT_K, :]

        if self.use_attention:
            summary = self.attention(lstm_out)
        elif Config.BIDIRECTIONAL:
            # Last layer's final hidden state for both directions.
            summary = torch.cat([hidden[-2], hidden[-1]], dim=1)
        else:
            summary = hidden[-1]

        fused = self.center_fusion(torch.cat([summary, center_feature], dim=1))
        return self.classifier(self.lstm_norm(fused))
346
+
347
+
348
class HybridCNNLSTMEqualizer(nn.Module):
    """CNN front-end followed by a 2-layer LSTM, attention pooling and an MLP head.

    The input is reshaped to ``(batch, SEQ_LEN, INPUT_DIM)``, passed through
    three Conv1d/BatchNorm/GELU/Dropout stages (64, 128, 256 channels), then
    through the LSTM; the LSTM outputs are pooled by ``AttentionLayer`` and
    mapped to 2 output values.
    """

    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            # Input channel count follows Config.INPUT_DIM so it always agrees
            # with the reshape in forward() (previously hard-coded to 2).
            nn.Conv1d(Config.INPUT_DIM, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.3),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.3),
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.3),
        )
        self.lstm = nn.LSTM(
            input_size=256,
            hidden_size=Config.LSTM_HIDDEN,
            num_layers=2,
            batch_first=True,
            bidirectional=Config.BIDIRECTIONAL,
            dropout=Config.DROPOUT,
        )
        out_dim = Config.LSTM_HIDDEN * (2 if Config.BIDIRECTIONAL else 1)
        self.attention = AttentionLayer(out_dim)
        self.classifier = nn.Sequential(
            nn.Linear(out_dim, Config.HIDDEN_DIM),
            nn.LayerNorm(Config.HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(Config.HIDDEN_DIM, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise every Linear/Conv1d weight and zero its bias."""
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv1d)):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of flattened windows to a ``(batch, 2)`` output."""
        # Conv1d expects channels on dim 1, the LSTM (batch_first) expects them last.
        x = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM).transpose(1, 2)
        x = self.cnn(x).transpose(1, 2)
        x, _ = self.lstm(x)
        x = self.attention(x)
        return self.classifier(x)
397
+
398
+
399
class CNNRxEqualizer(nn.Module):
    """Pure-CNN equalizer with centre-tap and softmax-pooled features.

    Three Conv1d/BatchNorm/GELU/Dropout stages produce per-position features.
    The head sees the centre-position feature, a learned softmax-weighted
    pool over all positions, the raw centre symbol and the raw window mean.
    """

    def __init__(self):
        super().__init__()
        width = Config.HIDDEN_DIM

        def conv_stage(in_channels: int, kernel: int) -> list:
            # "Same" padding for the odd kernel sizes used here.
            return [
                nn.Conv1d(in_channels, width, kernel_size=kernel, padding=kernel // 2),
                nn.BatchNorm1d(width),
                nn.GELU(),
                nn.Dropout(Config.DROPOUT * 0.25),
            ]

        self.cnn = nn.Sequential(
            *conv_stage(Config.INPUT_DIM, 5),
            *conv_stage(width, 5),
            *conv_stage(width, 3),
        )
        self.pool_score = nn.Conv1d(width, 1, kernel_size=1)
        # centre feature + pooled feature + raw centre symbol + raw window mean
        fused_dim = width * 2 + 2 * Config.INPUT_DIM
        self.head = nn.Sequential(
            nn.Linear(fused_dim, width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(width, width // 2),
            nn.LayerNorm(width // 2),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(width // 2, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming init for Linear/Conv1d; unit scale and zero shift for norm layers."""
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv1d)):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
                continue
            if not isinstance(module, (nn.BatchNorm1d, nn.LayerNorm)):
                continue
            if hasattr(module, "weight") and module.weight is not None:
                nn.init.constant_(module.weight, 1.0)
            if hasattr(module, "bias") and module.bias is not None:
                nn.init.constant_(module.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of flattened windows to a ``(batch, 2)`` output."""
        center_idx = Config.CONTEXT_K
        raw = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        features = self.cnn(raw.transpose(1, 2))

        # Softmax over the position axis gives one pooling weight per position.
        pool_weights = torch.softmax(self.pool_score(features), dim=2)
        pooled = (pool_weights * features).sum(dim=2)

        fused = torch.cat(
            [
                features[:, :, center_idx],
                pooled,
                raw[:, center_idx, :],
                raw.mean(dim=1),
            ],
            dim=1,
        )
        return self.head(fused)
455
+
456
+
457
class ComplexConv1d(nn.Module):
    """Complex-valued 1-D convolution on channel-interleaved real/imag tensors.

    Input channels alternate real, imag, real, imag, ... along dim 1 and the
    output uses the same layout. Two real Conv1d modules hold the real and
    imaginary parts of the kernel; neither has a bias.

    With ``symmetric=True`` the stored ``kernel_size`` taps are mirrored
    around the last tap, giving an effective kernel of ``2 * kernel_size - 1``.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, groups: int = 1, symmetric: bool = False):
        super().__init__()
        self.symmetric = symmetric
        if symmetric:
            effective_kernel = 2 * kernel_size - 1
        else:
            effective_kernel = kernel_size
        conv_kwargs = dict(
            kernel_size=kernel_size,
            padding=effective_kernel // 2,
            groups=groups,
            bias=False,
        )
        self.real_conv = nn.Conv1d(in_channels, out_channels, **conv_kwargs)
        self.imag_conv = nn.Conv1d(in_channels, out_channels, **conv_kwargs)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init (linear gain) for both kernel parts."""
        for conv in (self.real_conv, self.imag_conv):
            nn.init.kaiming_normal_(conv.weight, nonlinearity="linear")

    def _build_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Return ``weight`` as-is, or mirrored around its last tap when symmetric."""
        if self.symmetric:
            # Drop the first flipped tap so the centre tap is not duplicated.
            mirrored_tail = weight.flip(dims=(2,))[:, :, 1:]
            return torch.cat([weight, mirrored_tail], dim=2)
        return weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply (a + ib) * (w_r + i w_i) as four real convolutions."""
        real, imag = x[:, 0::2, :], x[:, 1::2, :]
        w_real = self._build_weight(self.real_conv.weight)
        w_imag = self._build_weight(self.imag_conv.weight)
        real_kw = dict(padding=self.real_conv.padding[0], groups=self.real_conv.groups)
        imag_kw = dict(padding=self.imag_conv.padding[0], groups=self.imag_conv.groups)

        out_real = F.conv1d(real, w_real, **real_kw) - F.conv1d(imag, w_imag, **imag_kw)
        out_imag = F.conv1d(real, w_imag, **imag_kw) + F.conv1d(imag, w_real, **real_kw)
        # Re-interleave: (B, C, 2, T) -> (B, 2C, T).
        return torch.stack((out_real, out_imag), dim=2).flatten(1, 2)
502
+
503
+
504
class ComplexResidualBlock(nn.Module):
    """Residual block of complex convolutions: expand (1x1) -> depthwise -> project (1x1).

    Tensors carry ``2 * channels`` real channels (interleaved real/imag, as
    consumed by ``ComplexConv1d``). The skip path is a 1x1 complex conv when
    the channel count changes, and the identity otherwise.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.pre_norm = nn.BatchNorm1d(2 * in_channels)
        self.expand = ComplexConv1d(in_channels, out_channels, kernel_size=1)
        self.expand_norm = nn.BatchNorm1d(2 * out_channels)
        self.depthwise = ComplexConv1d(out_channels, out_channels, kernel_size=kernel_size, groups=out_channels)
        self.depthwise_norm = nn.BatchNorm1d(2 * out_channels)
        # Optional Kerr-like phase rotation, controlled by the global config.
        if Config.COMPLEX_USE_KERR:
            self.kerr = KerrLikeActivation(out_channels)
        else:
            self.kerr = nn.Identity()
        self.project = ComplexConv1d(out_channels, out_channels, kernel_size=1)
        self.project_norm = nn.BatchNorm1d(2 * out_channels)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(Config.DROPOUT * 0.5)
        if in_channels != out_channels:
            self.skip = ComplexConv1d(in_channels, out_channels, kernel_size=1)
        else:
            self.skip = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``skip(x)`` plus the transformed branch."""
        shortcut = self.skip(x)
        branch = self.pre_norm(x)
        branch = self.expand_norm(self.expand(branch))
        branch = self.activation(branch)
        branch = self.depthwise_norm(self.depthwise(branch))
        branch = self.activation(branch)
        branch = self.kerr(branch)
        branch = self.project_norm(self.project(branch))
        branch = self.dropout(branch)
        return shortcut + branch
529
+
530
+
531
class KerrLikeActivation(nn.Module):
    """Power-dependent phase rotation on channel-interleaved real/imag tensors.

    For each complex channel the instantaneous power ``re^2 + im^2`` is
    filtered by a depthwise kernel, scaled by a learnable per-channel
    ``gamma`` and used as a phase: the output is ``x * exp(-1j * gamma * power)``.

    Args:
        channels: number of complex channels (the input has ``2 * channels``
            real channels, alternating real, imag along dim 1).
        kernel_size: number of stored power-filter taps; defaults to
            ``Config.COMPLEX_KERR_KERNEL``.
        init_gamma: initial value of ``gamma``; defaults to
            ``Config.COMPLEX_KERR_INIT_GAMMA``.
        symmetric: when true, the stored taps are mirrored around the last
            tap, giving an effective kernel of ``2 * kernel_size - 1`` taps.
    """

    def __init__(self, channels: int, kernel_size: Optional[int] = None, init_gamma: Optional[float] = None, symmetric: bool = False):
        super().__init__()
        kernel_size = Config.COMPLEX_KERR_KERNEL if kernel_size is None else kernel_size
        self.symmetric = symmetric
        self.power_filter = nn.Conv1d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=channels,
            bias=False,
        )
        gamma = Config.COMPLEX_KERR_INIT_GAMMA if init_gamma is None else init_gamma
        self.gamma = nn.Parameter(torch.full((1, channels, 1), gamma))
        self._init_weights()

    def _init_weights(self):
        """Initialise the power filter to a unit impulse at the effective centre.

        In symmetric mode ``_build_weight`` mirrors the stored taps around the
        LAST one, so that is where the impulse must go; using ``size // 2``
        there would produce two off-centre taps for kernel sizes >= 3.
        """
        nn.init.zeros_(self.power_filter.weight)
        taps = self.power_filter.weight.size(-1)
        center = taps - 1 if self.symmetric else taps // 2
        self.power_filter.weight.data[:, :, center] = 1.0

    def _build_weight(self) -> torch.Tensor:
        """Return the effective power-filter kernel (mirrored when symmetric)."""
        weight = self.power_filter.weight
        if not self.symmetric:
            return weight
        # Drop the first flipped tap so the centre tap is not duplicated.
        return torch.cat([weight, weight.flip(dims=(2,))[:, :, 1:]], dim=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate each complex sample by ``-gamma * filtered_power``."""
        real = x[:, 0::2, :]
        imag = x[:, 1::2, :]
        power_weight = self._build_weight()
        # Padding is derived from the effective kernel so the length is preserved.
        power = F.conv1d(
            real.square() + imag.square(),
            power_weight,
            padding=power_weight.size(-1) // 2,
            groups=self.power_filter.groups,
        )
        phase = self.gamma * power
        cos_phase = torch.cos(phase)
        sin_phase = torch.sin(phase)
        # (re + i im) * (cos - i sin)
        out_real = cos_phase * real + sin_phase * imag
        out_imag = cos_phase * imag - sin_phase * real
        # Re-interleave: (B, C, 2, T) -> (B, 2C, T).
        return torch.stack((out_real, out_imag), dim=2).flatten(1, 2)
575
+
576
+
577
class TemporalConvBlock(nn.Module):
    """Pre-norm residual block: BatchNorm -> GELU -> dilated depthwise conv -> 1x1 conv -> GELU.

    The depthwise kernel has 3 taps and length-preserving padding, so the
    block keeps both the channel count and the sequence length.
    """

    def __init__(self, channels: int, dilation: int):
        super().__init__()
        taps = 3
        self.norm = nn.BatchNorm1d(channels)
        self.depthwise = nn.Conv1d(
            channels,
            channels,
            kernel_size=taps,
            padding=dilation * (taps - 1) // 2,
            dilation=dilation,
            groups=channels,
            bias=False,
        )
        self.pointwise = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(Config.DROPOUT * 0.5)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for the depthwise and pointwise kernels."""
        for conv in (self.depthwise, self.pointwise):
            nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the transformed branch added to the unmodified input."""
        branch = self.activation(self.norm(x))
        branch = self.pointwise(self.depthwise(branch))
        branch = self.dropout(self.activation(branch))
        return branch + x
608
+
609
+
610
class LightweightComplexTemporalBlock(nn.Module):
    """Gated residual block whose value and gate paths are biased by signal power.

    After BatchNorm and a dilated depthwise conv, a 1x1 conv produces a value
    half and a gate half. A 1x1 projection of the single-channel ``power``
    input is added to both halves before GELU (value) and sigmoid (gate).
    """

    def __init__(self, channels: int, kernel_size: int, dilation: int):
        super().__init__()
        self.norm = nn.BatchNorm1d(channels)
        self.depthwise = nn.Conv1d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=dilation * (kernel_size - 1) // 2,
            dilation=dilation,
            groups=channels,
            bias=False,
        )
        doubled = channels * 2
        self.mix = nn.Conv1d(channels, doubled, kernel_size=1, bias=False)
        self.gate_proj = nn.Conv1d(1, doubled, kernel_size=1, bias=True)
        self.out_proj = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(Config.DROPOUT * 0.5)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for all conv kernels; zero any bias."""
        convs = (self.depthwise, self.mix, self.gate_proj, self.out_proj)
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")
            if conv.bias is None:
                continue
            nn.init.constant_(conv.bias, 0.0)

    def forward(self, x: torch.Tensor, power: torch.Tensor) -> torch.Tensor:
        """Return the gated branch added to the unmodified input ``x``."""
        branch = self.depthwise(self.norm(x))
        value, gate = self.mix(branch).chunk(2, dim=1)
        # First half of the power projection biases the value, second half the gate.
        value_bias, gate_bias = self.gate_proj(power).split(value.size(1), dim=1)
        value = F.gelu(value + value_bias)
        gate = torch.sigmoid(gate + gate_bias)
        branch = self.dropout(self.out_proj(value * gate))
        return branch + x
647
+
648
+
649
class LightweightComplexEncoder(nn.Module):
    """Encoder built from five hand-crafted channels and a stack of gated temporal blocks.

    The five stem inputs are real, imag, magnitude, power and the real*imag
    cross term of the input window. ``out_channels`` exposes the hidden width.
    """

    def __init__(self):
        super().__init__()
        width = Config.COMPLEX_LIGHT_CHANNELS
        self.stem = nn.Sequential(
            nn.Conv1d(5, width, kernel_size=1, bias=False),
            nn.BatchNorm1d(width),
            nn.GELU(),
        )
        temporal_blocks = []
        for dilation in Config.COMPLEX_LIGHT_DILATIONS:
            temporal_blocks.append(
                LightweightComplexTemporalBlock(
                    channels=width,
                    kernel_size=Config.COMPLEX_LIGHT_KERNEL_SIZE,
                    dilation=dilation,
                )
            )
        self.blocks = nn.ModuleList(temporal_blocks)
        self.out_norm = nn.BatchNorm1d(width)
        self.out_channels = width
        self._init_weights()

    def _init_weights(self):
        """Kaiming init for every Conv1d; unit scale and zero shift for BatchNorm."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1.0)
                nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode a batch of flattened windows.

        Returns:
            A tuple ``(raw, hidden, seq_features)``: the input reshaped to
            ``(batch, SEQ_LEN, INPUT_DIM)``, the normalised hidden features
            with channels on dim 1, and the hidden features concatenated with
            real, imag and magnitude, transposed so channels are last.
        """
        raw = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        channels_first = raw.transpose(1, 2)
        real, imag = channels_first[:, :1, :], channels_first[:, 1:2, :]
        power = real.square() + imag.square()
        # Small epsilon keeps the sqrt gradient finite at zero power.
        magnitude = torch.sqrt(power + 1e-6)
        cross = real * imag

        hidden = self.stem(torch.cat([real, imag, magnitude, power, cross], dim=1))
        for block in self.blocks:
            hidden = block(hidden, power)
        hidden = self.out_norm(hidden)

        seq_features = torch.cat([hidden, real, imag, magnitude], dim=1).transpose(1, 2)
        return raw, hidden, seq_features
697
+
698
+
699
class GaussianRBFExpansion(nn.Module):
    """Expand each scalar input into Gaussian radial-basis responses on a fixed grid.

    ``centers`` is a non-trainable buffer of ``num_grids`` evenly spaced points
    in ``[grid_min, grid_max]``. A single learnable log inverse width is
    shared by all centres and initialised to one over the grid spacing.
    """

    def __init__(self, num_grids: int, grid_min: float, grid_max: float):
        super().__init__()
        centers = torch.linspace(grid_min, grid_max, num_grids)
        if num_grids > 1:
            spacing = float(centers[1] - centers[0])
        else:
            # No neighbouring centre to measure against; fall back to the range (at least 1).
            spacing = max(abs(grid_max - grid_min), 1.0)
        self.register_buffer("centers", centers)
        initial_inv_scale = 1.0 / max(spacing, 1e-3)
        self.log_inv_scale = nn.Parameter(torch.log(torch.tensor(initial_inv_scale)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``exp(-((x - c) * inv_scale)^2)`` with a new trailing axis over centres."""
        inv_scale = self.log_inv_scale.exp().clamp(1e-3, 1e3)
        scaled = (x.unsqueeze(-1) - self.centers) * inv_scale
        return torch.exp(-scaled.square())
711
+
712
+
713
class FastKANLayer(nn.Module):
    """One FastKAN-style layer: a linear map over RBF features plus an optional base path.

    The base path applies GELU or SiLU to the input (chosen by
    ``Config.FASTKAN_BASE_ACT``) followed by a linear layer. The spline path
    expands each input into ``Config.FASTKAN_NUM_GRIDS`` Gaussian responses
    and applies one linear layer. The sum is layer-normalised, passed
    through GELU and dropout.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        num_grids = Config.FASTKAN_NUM_GRIDS
        self.rbf = GaussianRBFExpansion(
            num_grids=num_grids,
            grid_min=Config.FASTKAN_GRID_MIN,
            grid_max=Config.FASTKAN_GRID_MAX,
        )
        self.base_linear = nn.Linear(in_features, out_features)
        self.spline_linear = nn.Linear(in_features * num_grids, out_features)
        self.norm = nn.LayerNorm(out_features)
        self.dropout = nn.Dropout(Config.KAN_HIDDEN_DROPOUT)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for both linears; identity LayerNorm."""
        for linear in (self.base_linear, self.spline_linear):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)
        nn.init.constant_(self.norm.weight, 1.0)
        nn.init.constant_(self.norm.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to a ``(batch, in_features)`` input."""
        # Any value other than "gelu" selects SiLU.
        base_activation = F.gelu if Config.FASTKAN_BASE_ACT == "gelu" else F.silu
        activated = base_activation(x)
        base_out = self.base_linear(activated) if Config.FASTKAN_USE_BASE_PATH else 0.0
        # Flatten (features, grids) into a single feature axis for the linear map.
        rbf_features = self.rbf(x).flatten(start_dim=1)
        combined = base_out + self.spline_linear(rbf_features)
        return self.dropout(F.gelu(self.norm(combined)))

    def regularization_loss(self) -> torch.Tensor:
        """Return the mean absolute value of the spline-path weights."""
        return self.spline_linear.weight.abs().mean()
750
+
751
+
752
class FastKANHead(nn.Module):
    """Stack of ``FastKANLayer`` modules with a learnable per-feature input gate.

    Inputs are layer-normalised, passed through dropout and multiplied by
    ``feature_gate``. In eval mode, gate entries whose magnitude is below
    ``Config.KAN_PRUNE_THRESHOLD`` are zeroed.
    """

    def __init__(self, input_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.input_norm = nn.LayerNorm(input_dim)
        self.input_dropout = nn.Dropout(Config.KAN_INPUT_DROPOUT)
        self.feature_gate = nn.Parameter(torch.ones(input_dim))
        # Layer widths: input_dim followed by one hidden_dim entry per KAN layer.
        widths = [input_dim] + [hidden_dim] * Config.FASTKAN_LAYERS
        self.layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(FastKANLayer(fan_in, fan_out))
        self.output = nn.Linear(widths[-1], out_dim)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform output layer with zero bias; identity input LayerNorm."""
        nn.init.xavier_uniform_(self.output.weight)
        nn.init.zeros_(self.output.bias)
        nn.init.constant_(self.input_norm.weight, 1.0)
        nn.init.constant_(self.input_norm.bias, 0.0)

    def _gated_features(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the feature gate, hard-pruning small gates outside training."""
        gate = self.feature_gate
        threshold = Config.KAN_PRUNE_THRESHOLD
        prune_now = (not self.training) and threshold > 0
        if prune_now:
            keep_mask = gate.abs() >= threshold
            gate = gate * keep_mask
        return x * gate.unsqueeze(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, input_dim)`` features to ``(batch, out_dim)`` outputs."""
        hidden = self._gated_features(self.input_dropout(self.input_norm(x)))
        for layer in self.layers:
            hidden = layer(hidden)
        return self.output(hidden)

    def regularization_loss(self) -> torch.Tensor:
        """Return mean |gate| plus the sum of each layer's regularisation term."""
        gate_penalty = self.feature_gate.abs().mean()
        return sum((layer.regularization_loss() for layer in self.layers), gate_penalty)
790
+
791
+
792
class EfficientKANBaselineEqualizer(nn.Module):
    """Baseline equalizer: an ``efficient_kan`` KAN applied to the flattened window.

    The network maps ``Config.SEQ_LEN * Config.INPUT_DIM`` inputs to 2 outputs.
    Construction and regularisation go through the module-level helpers
    ``make_efficient_kan`` and ``efficient_kan_regularization`` so this class
    stays consistent with the other EfficientKAN equalizers in this file.

    Raises:
        ImportError: from ``make_efficient_kan`` when ``efficient_kan`` could
            not be imported.
    """

    def __init__(self):
        super().__init__()
        self.kan = make_efficient_kan(
            input_dim=Config.SEQ_LEN * Config.INPUT_DIM,
            output_dim=2,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the KAN directly to ``x``."""
        return self.kan(x)

    def regularization_loss(self) -> torch.Tensor:
        """Return the KAN's regularisation term as a tensor (zero if unsupported)."""
        return efficient_kan_regularization(self.kan)
825
+
826
+
827
class FastKANClassifierEqualizer(nn.Module):
    """FastKAN head over the flattened window with one output per constellation point."""

    def __init__(self):
        super().__init__()
        flat_window = Config.SEQ_LEN * Config.INPUT_DIM
        num_classes = CONSTELLATION.size(0)
        self.head = FastKANHead(
            input_dim=flat_window,
            hidden_dim=Config.FASTKAN_HIDDEN_DIM,
            out_dim=num_classes,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head's output for ``x``."""
        return self.head(x)

    def regularization_loss(self) -> torch.Tensor:
        """Delegate to the head's regularisation term."""
        return self.head.regularization_loss()
842
+
843
+
844
def make_efficient_kan(input_dim: int, output_dim: int) -> nn.Module:
    """Build an ``efficient_kan`` KAN using the ``EFFICIENT_KAN_*`` settings in ``Config``.

    Args:
        input_dim: width of the first layer.
        output_dim: width of the last layer.

    Returns:
        A KAN with at least one hidden layer of ``Config.EFFICIENT_KAN_HIDDEN_DIM`` units.

    Raises:
        ImportError: if the optional ``efficient_kan`` import failed.
    """
    if EfficientKAN is None:
        raise ImportError(
            "efficient_kan is unavailable. Keep the local `efficient_kan/` package next to this script or "
            "install the library before running."
        )

    # Always keep at least one hidden layer, even if the config asks for zero.
    depth = max(Config.EFFICIENT_KAN_LAYERS, 1)
    layer_widths = [input_dim] + [Config.EFFICIENT_KAN_HIDDEN_DIM] * depth + [output_dim]
    kan_options = dict(
        grid_size=Config.EFFICIENT_KAN_GRID_SIZE,
        spline_order=Config.EFFICIENT_KAN_SPLINE_ORDER,
        scale_noise=Config.EFFICIENT_KAN_SCALE_NOISE,
        scale_base=Config.EFFICIENT_KAN_SCALE_BASE,
        scale_spline=Config.EFFICIENT_KAN_SCALE_SPLINE,
        base_activation=nn.SiLU,
        grid_eps=Config.EFFICIENT_KAN_GRID_EPS,
        grid_range=Config.EFFICIENT_KAN_GRID_RANGE,
    )
    return EfficientKAN(layers_hidden=layer_widths, **kan_options)
862
+
863
+
864
def efficient_kan_regularization(kan: nn.Module) -> torch.Tensor:
    """Return ``kan.regularization_loss()`` as a tensor.

    Modules without a ``regularization_loss`` attribute contribute a scalar
    zero on ``Config.DEVICE``. A non-tensor result is converted to a float
    tensor on ``Config.DEVICE``.
    """
    if not hasattr(kan, "regularization_loss"):
        return torch.zeros((), device=Config.DEVICE)
    penalty = kan.regularization_loss()
    if not torch.is_tensor(penalty):
        penalty = torch.tensor(float(penalty), device=Config.DEVICE)
    return penalty
871
+
872
+
873
class EfficientKANResidualEqualizer(nn.Module):
    """Residual equalizer: the centre received symbol plus a scaled KAN correction.

    The KAN sees the flattened window and predicts a 2-value correction that
    is multiplied by a learnable scalar (initialised to 1) before being added.
    """

    def __init__(self):
        super().__init__()
        flat_window = Config.SEQ_LEN * Config.INPUT_DIM
        self.kan = make_efficient_kan(input_dim=flat_window, output_dim=2)
        self.residual_scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``center_symbol + residual_scale * kan(x)``."""
        window = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        center_rx = window[:, Config.CONTEXT_K, :]
        scaled_correction = self.residual_scale * self.kan(x)
        return center_rx + scaled_correction

    def regularization_loss(self) -> torch.Tensor:
        """Return the KAN's regularisation term."""
        return efficient_kan_regularization(self.kan)
888
+
889
+
890
class EfficientKANFeatureEqualizer(nn.Module):
    """KAN equalizer driven by hand-crafted summary statistics of the input window.

    ``feature_dim`` is fixed at 17, which matches the width produced by
    ``_features`` when ``Config.INPUT_DIM`` is 2.
    """

    def __init__(self):
        super().__init__()
        self.feature_dim = 17
        self.kan = make_efficient_kan(input_dim=self.feature_dim, output_dim=2)

    def _features(self, x: torch.Tensor) -> torch.Tensor:
        """Build the per-sample feature vector.

        The columns are, in order: centre sample, local mean/std (within
        ``Config.KAN_FEATURE_RADIUS`` of the centre, clipped to ``CONTEXT_K``),
        global mean/std, centre power, power mean/std, mean/std of the product
        of the two input components, and the last-minus-first sample difference.
        """
        mid = Config.CONTEXT_K
        window = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)

        half_width = min(Config.KAN_FEATURE_RADIUS, mid)
        neighbourhood = window[:, mid - half_width : mid + half_width + 1, :]

        # Per-position power (sum of squared components) and component product.
        energy = window.square().sum(dim=2)
        product = window[:, :, 0] * window[:, :, 1]

        columns = [
            window[:, mid, :],
            neighbourhood.mean(dim=1),
            neighbourhood.std(dim=1, unbiased=False),
            window.mean(dim=1),
            window.std(dim=1, unbiased=False),
            energy[:, mid : mid + 1],
            energy.mean(dim=1, keepdim=True),
            energy.std(dim=1, unbiased=False, keepdim=True),
            product.mean(dim=1, keepdim=True),
            product.std(dim=1, unbiased=False, keepdim=True),
            window[:, -1, :] - window[:, 0, :],
        ]
        return torch.cat(columns, dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the KAN to the extracted feature vector."""
        feats = self._features(x)
        return self.kan(feats)

    def regularization_loss(self) -> torch.Tensor:
        """Return the KAN regularization term via ``efficient_kan_regularization``."""
        return efficient_kan_regularization(self.kan)
939
+
940
+
941
class CNNKANEqualizer(nn.Module):
    """Three-layer 1-D CNN feature extractor followed by an EfficientKAN head.

    The head receives the CNN feature at the centre position, a softmax-weighted
    pooling of all positions, the raw centre sample and the raw window mean.
    """

    def __init__(self):
        super().__init__()
        width = Config.HIDDEN_DIM

        # (input channels, kernel size) per conv stage; padding keeps the length.
        stages = ((Config.INPUT_DIM, 5), (width, 5), (width, 3))
        layers = []
        for in_channels, kernel in stages:
            layers.extend(
                [
                    nn.Conv1d(in_channels, width, kernel_size=kernel, padding=kernel // 2),
                    nn.BatchNorm1d(width),
                    nn.GELU(),
                ]
            )
        self.cnn = nn.Sequential(*layers)

        self.pool_score = nn.Conv1d(width, 1, kernel_size=1)
        fused_dim = 2 * width + 2 * Config.INPUT_DIM
        self.kan = make_efficient_kan(input_dim=fused_dim, output_dim=2)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise the convolutions and reset the batch-norm affine terms."""
        for layer in self.cnn.modules():
            if isinstance(layer, nn.Conv1d):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.BatchNorm1d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
        nn.init.kaiming_normal_(self.pool_score.weight, nonlinearity="linear")
        nn.init.zeros_(self.pool_score.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Equalize a batch of windows and return two outputs per sample."""
        mid = Config.CONTEXT_K
        window = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        feature_map = self.cnn(window.transpose(1, 2))

        # Attention-style pooling: softmax over the sequence axis.
        attention = torch.softmax(self.pool_score(feature_map), dim=2)
        pooled = (attention * feature_map).sum(dim=2)

        pieces = [feature_map[:, :, mid], pooled, window[:, mid, :], window.mean(dim=1)]
        return self.kan(torch.cat(pieces, dim=1))

    def regularization_loss(self) -> torch.Tensor:
        """Return the KAN regularization term via ``efficient_kan_regularization``."""
        return efficient_kan_regularization(self.kan)
986
+
987
+
988
class EfficientKANClassifierEqualizer(nn.Module):
    """EfficientKAN that maps a flat window to one output per constellation point.

    The output width is ``CONSTELLATION.size(0)``; the input width is
    ``Config.SEQ_LEN * Config.INPUT_DIM``.
    """

    def __init__(self):
        super().__init__()
        num_classes = CONSTELLATION.size(0)
        flat_width = Config.SEQ_LEN * Config.INPUT_DIM
        self.kan = make_efficient_kan(input_dim=flat_width, output_dim=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the KAN output for ``x`` (passed through unchanged)."""
        scores = self.kan(x)
        return scores

    def regularization_loss(self) -> torch.Tensor:
        """Return the KAN regularization term via ``efficient_kan_regularization``."""
        return efficient_kan_regularization(self.kan)
999
+
1000
+
1001
class ComplexFeatureEncoder(nn.Module):
    """Complex-valued convolutional encoder with an optional DBP front-end.

    Layout: optional ``ComplexDBPFrontEnd`` -> 1x1 ``ComplexConv1d`` stem ->
    a chain of ``ComplexResidualBlock``s sized by ``Config.COMPLEX_BLOCK_CHANNELS``
    and ``Config.COMPLEX_KERNEL_SIZES`` -> batch norm. ``out_channels`` is the
    channel count of the last block.
    """

    def __init__(self):
        super().__init__()
        self.dbp_frontend = ComplexDBPFrontEnd() if Config.COMPLEX_USE_DBP_FRONTEND else None

        widths = Config.COMPLEX_BLOCK_CHANNELS
        kernel_sizes = Config.COMPLEX_KERNEL_SIZES
        if len(widths) != len(kernel_sizes):
            raise ValueError("COMPLEX_BLOCK_CHANNELS and COMPLEX_KERNEL_SIZES must have the same length.")

        first_width = widths[0]
        self.stem = ComplexConv1d(1, first_width, kernel_size=1)
        self.stem_norm = nn.BatchNorm1d(2 * first_width)

        # Each block consumes the previous block's width; the first one consumes the stem's.
        block_inputs = [first_width, *widths[:-1]]
        self.blocks = nn.ModuleList(
            ComplexResidualBlock(c_in, c_out, kernel)
            for c_in, c_out, kernel in zip(block_inputs, widths, kernel_sizes)
        )

        last_width = widths[-1]
        self.final_norm = nn.BatchNorm1d(2 * last_width)
        self.out_channels = last_width

    def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Forward data-driven initialisation to the DBP front-end, if one exists."""
        frontend = self.dbp_frontend
        if frontend is None:
            return
        frontend.initialize_from_data(train_x, train_y)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode a batch of windows.

        Returns:
            A tuple ``(raw, encoded, seq_features)`` where ``raw`` is the input
            viewed as ``(batch, SEQ_LEN, INPUT_DIM)``, ``encoded`` is the output
            of the final norm, and ``seq_features`` concatenates the even-indexed
            channels, the odd-indexed channels and their magnitude, transposed
            so that the sequence axis comes second.
        """
        raw = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)

        if self.dbp_frontend is None:
            hidden = raw.transpose(1, 2)
        else:
            hidden, _ = self.dbp_frontend(x, collect_states=False)

        hidden = self.stem_norm(self.stem(hidden))
        for layer in self.blocks:
            hidden = layer(hidden)
        hidden = self.final_norm(hidden)

        # Even channels are treated as real parts, odd channels as imaginary parts.
        re_part = hidden[:, 0::2, :]
        im_part = hidden[:, 1::2, :]
        # The small constant keeps the square root away from exactly zero.
        modulus = torch.sqrt(re_part.square() + im_part.square() + 1e-6)
        seq_features = torch.cat([re_part, im_part, modulus], dim=1).transpose(1, 2)
        return raw, hidden, seq_features
1040
+
1041
+
1042
class ComplexDBPStep1Ch(nn.Module):
    """One single-channel DBP step: a complex linear filter followed by a Kerr-like nonlinearity."""

    def __init__(self):
        super().__init__()
        self.linear = ComplexConv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=Config.DBP_KERNEL_SIZE,
            symmetric=Config.DBP_USE_SYMMETRIC_FILTER,
        )

        # Kernel length passed to the nonlinearity depends on the symmetric flag.
        memory = Config.DBP_NL_MEMORY
        if Config.DBP_USE_SYMMETRIC_NONLINEAR:
            nl_taps = memory + 1
        else:
            nl_taps = 2 * memory + 1

        self.nonlinear = KerrLikeActivation(
            1,
            kernel_size=nl_taps,
            init_gamma=Config.COMPLEX_KERR_INIT_GAMMA,
            symmetric=Config.DBP_USE_SYMMETRIC_NONLINEAR,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear filter, then the nonlinearity."""
        filtered = self.linear(x)
        return self.nonlinear(filtered)
1061
+
1062
+
1063
class ComplexDBPFrontEnd(nn.Module):
    """Learnable DBP front-end: ``Config.DBP_NUM_STEPS`` steps plus an optional final filter.

    ``valid_margin`` is the number of positions at each window edge that the
    stacked filters are assumed to contaminate; construction fails if it
    exceeds ``Config.CONTEXT_K``.
    """

    def __init__(self):
        super().__init__()
        self.steps = nn.ModuleList([ComplexDBPStep1Ch() for _ in range(Config.DBP_NUM_STEPS)])
        self.final_linear = (
            ComplexConv1d(
                in_channels=1,
                out_channels=1,
                kernel_size=Config.DBP_FINAL_KERNEL_SIZE,
                symmetric=Config.DBP_USE_SYMMETRIC_FILTER,
            )
            if Config.DBP_USE_FINAL_FILTER
            else None
        )
        linear_delay = Config.DBP_KERNEL_SIZE - 1 if Config.DBP_USE_SYMMETRIC_FILTER else Config.DBP_KERNEL_SIZE // 2
        final_delay = 0
        if self.final_linear is not None:
            final_delay = (
                Config.DBP_FINAL_KERNEL_SIZE - 1
                if Config.DBP_USE_SYMMETRIC_FILTER
                else Config.DBP_FINAL_KERNEL_SIZE // 2
            )
        nl_delay = Config.DBP_NL_MEMORY if Config.DBP_USE_SYMMETRIC_NONLINEAR else 2 * Config.DBP_NL_MEMORY
        self.valid_margin = Config.DBP_NUM_STEPS * (linear_delay + nl_delay) + final_delay
        if self.valid_margin > Config.CONTEXT_K:
            raise ValueError(
                f"DBP receptive radius {self.valid_margin} exceeds CONTEXT_K={Config.CONTEXT_K}. "
                "Increase CONTEXT_K or reduce DBP kernels/steps."
            )

    @staticmethod
    def _seq_features(x: torch.Tensor) -> torch.Tensor:
        """Stack channel 0, channel 1 and their magnitude, with the sequence axis second."""
        real = x[:, 0:1, :]
        imag = x[:, 1:2, :]
        # The small constant keeps the square root away from exactly zero.
        magnitude = torch.sqrt(real.square() + imag.square() + 1e-6)
        return torch.cat([real, imag, magnitude], dim=1).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        collect_states: bool = False,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run all steps (and the final filter, if any) over a batch of windows.

        Args:
            x: Batch of windows, viewed as ``(batch, SEQ_LEN, INPUT_DIM)`` and
                transposed to channel-first before processing.
            collect_states: When true, also return the input and every
                intermediate output.

        Returns:
            ``(state, states)``: the final output and the list of collected
            states (empty when ``collect_states`` is false).
        """
        state = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM).transpose(1, 2)
        states = [state] if collect_states else []
        for step in self.steps:
            state = step(state)
            if collect_states:
                states.append(state)
        if self.final_linear is not None:
            state = self.final_linear(state)
            if collect_states:
                states.append(state)
        return state, states

    @torch.no_grad()
    def _initialize_linear_kernels_from_ls(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Initialise the linear kernels from a least-squares fit of windows to targets.

        A single complex kernel mapping each received window to its target is
        solved for, converted to a frequency response, split across the linear
        stages, and written back as truncated kernels. Each step's power filter
        is reset to a single unit tap and its gamma to
        ``COMPLEX_KERR_INIT_GAMMA / DBP_NUM_STEPS``. Does nothing when fewer
        than ``SEQ_LEN`` samples are available.
        """
        samples = min(train_x.size(0), Config.DBP_INIT_SAMPLES)
        if samples < Config.SEQ_LEN:
            return
        windows = train_x[:samples].view(samples, Config.SEQ_LEN, Config.INPUT_DIM)
        target = train_y[:samples]
        rx_complex = torch.complex(windows[:, :, 0], windows[:, :, 1]).to(torch.complex64)
        tx_complex = torch.complex(target[:, 0], target[:, 1]).to(torch.complex64)

        global_kernel = torch.linalg.lstsq(rx_complex, tx_complex.unsqueeze(1)).solution.squeeze(1)
        global_kernel = global_kernel / global_kernel.norm().clamp_min(1e-6)

        # FFT size: the configured size, but never below the next power of two >= SEQ_LEN.
        fft_size = max(Config.DBP_INIT_FFT_SIZE, 1 << int(np.ceil(np.log2(Config.SEQ_LEN))))
        global_response = centered_complex_kernel_to_frequency(global_kernel, fft_size)
        linear_stage_count = len(self.steps) + (1 if self.final_linear is not None else 0)
        step_response = complex_unit_response(global_response, linear_stage_count)
        # Whatever the per-step responses do not account for goes to the final filter.
        final_response = global_response / step_response.pow(len(self.steps))

        step_kernel_centered = frequency_to_centered_complex_kernel(step_response, Config.SEQ_LEN)
        step_kernel = extract_kernel_from_centered_response(
            step_kernel_centered,
            Config.DBP_KERNEL_SIZE,
            Config.DBP_USE_SYMMETRIC_FILTER,
        )
        for step in self.steps:
            assign_complex_kernel(step.linear, step_kernel)
            center = step.nonlinear.power_filter.weight.size(-1) // 2
            step.nonlinear.power_filter.weight.zero_()
            step.nonlinear.power_filter.weight[:, :, center] = 1.0
            step.nonlinear.gamma.fill_(Config.COMPLEX_KERR_INIT_GAMMA / max(Config.DBP_NUM_STEPS, 1))

        if self.final_linear is not None:
            final_kernel_centered = frequency_to_centered_complex_kernel(final_response, Config.SEQ_LEN)
            final_kernel = extract_kernel_from_centered_response(
                final_kernel_centered,
                Config.DBP_FINAL_KERNEL_SIZE,
                Config.DBP_USE_SYMMETRIC_FILTER,
            )
            assign_complex_kernel(self.final_linear, final_kernel)

    def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Initialise the front-end from training data.

        Runs the least-squares kernel initialisation when
        ``Config.DBP_INIT_FROM_LS`` is set, then (if ``Config.DBP_JOINT_INIT``
        is set and enough samples exist) tunes all parameters for
        ``Config.DBP_JOINT_INIT_ITERS`` Adam steps on an MSE loss between the
        centre output sample and the target. The module's train/eval mode is
        restored afterwards, even if tuning raises.
        """
        if not Config.DBP_INIT_FROM_LS:
            return
        self._initialize_linear_kernels_from_ls(train_x, train_y)
        if not Config.DBP_JOINT_INIT or Config.DBP_JOINT_INIT_ITERS <= 0:
            return

        # Check the cheap exit condition before touching the module's mode.
        subset = min(train_x.size(0), Config.DBP_INIT_SAMPLES)
        if subset < Config.DBP_JOINT_INIT_BATCH_SIZE:
            return

        was_training = self.training
        self.train()
        try:
            optimizer = optim.Adam(self.parameters(), lr=Config.DBP_JOINT_INIT_LR)
            criterion = nn.MSELoss()
            for _ in range(Config.DBP_JOINT_INIT_ITERS):
                index = torch.randint(0, subset, (Config.DBP_JOINT_INIT_BATCH_SIZE,))
                xb = train_x[index].to(Config.DEVICE)
                yb = train_y[index].to(Config.DEVICE)
                optimizer.zero_grad(set_to_none=True)
                state, _ = self(xb, collect_states=False)
                # Centre position of the channel-first output, one row per sample.
                preds = state[:, :, Config.CONTEXT_K]
                loss = criterion(preds, yb)
                loss.backward()
                optimizer.step()
        finally:
            self.train(was_training)
1187
+
1188
+
1189
class ComplexDBPSeqStatRxEqualizer(nn.Module):
    """DBP front-end followed by an MLP over centre values and sequence statistics.

    Features are collected from the front-end input and every intermediate
    output; the head sees their centre value, mean and standard deviation over
    the valid region, plus the raw centre sample.
    """

    def __init__(self):
        super().__init__()
        self.frontend = ComplexDBPFrontEnd()
        self.valid_margin = self.frontend.valid_margin

        # Three features per collected state: input + each step (+ final filter).
        state_count = 1 + Config.DBP_NUM_STEPS + (1 if Config.DBP_USE_FINAL_FILTER else 0)
        per_position = 3 * state_count
        head_in = per_position * 3 + Config.INPUT_DIM

        wide, narrow = Config.DBP_SEQSTAT_DIM, Config.HIDDEN_DIM
        self.head = nn.Sequential(
            nn.Linear(head_in, wide),
            nn.LayerNorm(wide),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(wide, narrow),
            nn.LayerNorm(narrow),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(narrow, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise every ``nn.Linear`` and reset every ``nn.LayerNorm``."""
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.LayerNorm):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)

    def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Delegate data-driven initialisation to the front-end."""
        self.frontend.initialize_from_data(train_x, train_y)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Equalize a batch of windows and return two outputs per sample."""
        mid = Config.CONTEXT_K
        window = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)

        _, snapshots = self.frontend(x, collect_states=True)
        seq = torch.cat([self.frontend._seq_features(snap) for snap in snapshots], dim=2)

        # Statistics are taken only over positions outside the edge margin.
        margin = self.valid_margin
        usable = seq[:, margin : Config.SEQ_LEN - margin, :] if margin > 0 else seq

        pieces = [
            seq[:, mid, :],
            usable.mean(dim=1),
            usable.std(dim=1, unbiased=False),
            window[:, mid, :],
        ]
        return self.head(torch.cat(pieces, dim=1))
1238
+
1239
+
1240
class ComplexCNNRxEqualizer(nn.Module):
    """Complex encoder + dilated temporal conv stack + MLP head.

    The head receives the temporal feature at the centre position, a
    softmax-weighted pooling over the sequence, the raw centre sample and the
    raw window mean.
    """

    def __init__(self):
        super().__init__()
        self.encoder = ComplexFeatureEncoder()

        width = Config.COMPLEX_TEMPORAL_DIM
        # The encoder emits real, imaginary and magnitude features per channel.
        encoder_features = 3 * self.encoder.out_channels
        self.temporal_proj = nn.Conv1d(encoder_features, width, kernel_size=1, bias=False)
        self.temporal_blocks = nn.ModuleList(
            [TemporalConvBlock(width, rate) for rate in Config.COMPLEX_TEMPORAL_DILATIONS]
        )
        self.temporal_norm = nn.BatchNorm1d(width)
        self.pool_score = nn.Conv1d(width, 1, kernel_size=1)

        head_in = 2 * width + 2 * Config.INPUT_DIM
        head_dim, hidden = Config.COMPLEX_HEAD_DIM, Config.HIDDEN_DIM
        self.head = nn.Sequential(
            nn.Linear(head_in, head_dim),
            nn.LayerNorm(head_dim),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(head_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(hidden, 2),
        )
        self._init_weights()

    def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Delegate data-driven initialisation to the encoder."""
        self.encoder.initialize_from_data(train_x, train_y)

    def _init_weights(self):
        """Kaiming-initialise linear layers and reset batch/layer-norm affine terms."""
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, (nn.BatchNorm1d, nn.LayerNorm)):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Equalize a batch of windows and return two outputs per sample."""
        mid = Config.CONTEXT_K
        raw, _, seq_features = self.encoder(x)

        hidden = self.temporal_proj(seq_features.transpose(1, 2))
        for layer in self.temporal_blocks:
            hidden = layer(hidden)
        hidden = self.temporal_norm(hidden)

        # Softmax over the sequence axis turns the scores into pooling weights.
        attention = torch.softmax(self.pool_score(hidden), dim=2)
        pooled = (attention * hidden).sum(dim=2)

        pieces = [hidden[:, :, mid], pooled, raw[:, mid, :], raw.mean(dim=1)]
        return self.head(torch.cat(pieces, dim=1))
1297
+
1298
+
1299
class ComplexLSTMRxEqualizer(nn.Module):
    """Complex encoder + (optionally bidirectional) LSTM + attention + MLP head.

    The centre LSTM output, the attention context and the raw centre sample are
    fused and regressed to two outputs.
    """

    def __init__(self):
        super().__init__()
        self.encoder = ComplexFeatureEncoder()

        seq_dim = Config.COMPLEX_SEQ_DIM
        # The encoder emits real, imaginary and magnitude features per channel.
        encoder_features = 3 * self.encoder.out_channels
        self.seq_proj = nn.Sequential(
            nn.Linear(encoder_features, seq_dim),
            nn.LayerNorm(seq_dim),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
        )

        layer_count = Config.COMPLEX_LSTM_LAYERS
        # Inter-layer dropout is only enabled for stacked LSTMs.
        inter_layer_dropout = Config.DROPOUT if layer_count > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=seq_dim,
            hidden_size=Config.COMPLEX_LSTM_HIDDEN,
            num_layers=layer_count,
            batch_first=True,
            bidirectional=Config.BIDIRECTIONAL,
            dropout=inter_layer_dropout,
        )

        directions = 2 if Config.BIDIRECTIONAL else 1
        rnn_width = Config.COMPLEX_LSTM_HIDDEN * directions
        self.attention = AttentionLayer(rnn_width)

        head_dim, hidden = Config.COMPLEX_HEAD_DIM, Config.HIDDEN_DIM
        self.center_fusion = nn.Sequential(
            nn.Linear(rnn_width * 2 + Config.INPUT_DIM, head_dim),
            nn.LayerNorm(head_dim),
            nn.GELU(),
        )
        self.head = nn.Sequential(
            nn.Linear(head_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(hidden, hidden // 2),
            nn.LayerNorm(hidden // 2),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(hidden // 2, 2),
        )
        self._init_weights()

    def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Delegate data-driven initialisation to the encoder."""
        self.encoder.initialize_from_data(train_x, train_y)

    def _init_weights(self):
        """Initialise LSTM weights/biases, then every linear and layer-norm module."""
        unit = Config.COMPLEX_LSTM_HIDDEN
        for pname, tensor in self.lstm.named_parameters():
            if "weight_ih" in pname:
                nn.init.xavier_uniform_(tensor.data)
            elif "weight_hh" in pname:
                nn.init.orthogonal_(tensor.data)
            elif "bias" in pname:
                nn.init.zeros_(tensor.data)
                # Second quarter of the bias vector is set to one.
                tensor.data[unit : 2 * unit] = 1.0

        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.LayerNorm):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Equalize a batch of windows and return two outputs per sample."""
        mid = Config.CONTEXT_K
        raw, _, seq_features = self.encoder(x)

        rnn_out, _ = self.lstm(self.seq_proj(seq_features))
        centre_state = rnn_out[:, mid, :]
        attended = self.attention(rnn_out)

        joined = torch.cat([centre_state, attended, raw[:, mid, :]], dim=1)
        return self.head(self.center_fusion(joined))
1368
+
1369
+
1370
class ComplexCNNLSTMRxEqualizer(nn.Module):
    """Complex encoder + two dilated temporal conv blocks + LSTM + attention + MLP head.

    Only the first two entries of ``Config.COMPLEX_TEMPORAL_DILATIONS`` are used
    for the convolutional stage.
    """

    def __init__(self):
        super().__init__()
        self.encoder = ComplexFeatureEncoder()

        width = Config.COMPLEX_TEMPORAL_DIM
        # The encoder emits real, imaginary and magnitude features per channel.
        encoder_features = 3 * self.encoder.out_channels
        self.temporal_proj = nn.Conv1d(encoder_features, width, kernel_size=1, bias=False)
        self.temporal_blocks = nn.ModuleList(
            [TemporalConvBlock(width, rate) for rate in Config.COMPLEX_TEMPORAL_DILATIONS[:2]]
        )
        self.temporal_norm = nn.BatchNorm1d(width)

        layer_count = Config.COMPLEX_LSTM_LAYERS
        # Inter-layer dropout is only enabled for stacked LSTMs.
        inter_layer_dropout = Config.DROPOUT if layer_count > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=width,
            hidden_size=Config.COMPLEX_LSTM_HIDDEN,
            num_layers=layer_count,
            batch_first=True,
            bidirectional=Config.BIDIRECTIONAL,
            dropout=inter_layer_dropout,
        )

        directions = 2 if Config.BIDIRECTIONAL else 1
        rnn_width = Config.COMPLEX_LSTM_HIDDEN * directions
        self.attention = AttentionLayer(rnn_width)

        head_dim, hidden = Config.COMPLEX_HEAD_DIM, Config.HIDDEN_DIM
        self.center_fusion = nn.Sequential(
            nn.Linear(rnn_width * 2 + Config.INPUT_DIM, head_dim),
            nn.LayerNorm(head_dim),
            nn.GELU(),
        )
        self.head = nn.Sequential(
            nn.Linear(head_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(hidden, hidden // 2),
            nn.LayerNorm(hidden // 2),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(hidden // 2, 2),
        )
        self._init_weights()

    def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Delegate data-driven initialisation to the encoder."""
        self.encoder.initialize_from_data(train_x, train_y)

    def _init_weights(self):
        """Initialise LSTM parameters, then every linear, norm and ``nn.Conv1d`` module."""
        unit = Config.COMPLEX_LSTM_HIDDEN
        for pname, tensor in self.lstm.named_parameters():
            if "weight_ih" in pname:
                nn.init.xavier_uniform_(tensor.data)
            elif "weight_hh" in pname:
                nn.init.orthogonal_(tensor.data)
            elif "bias" in pname:
                nn.init.zeros_(tensor.data)
                # Second quarter of the bias vector is set to one.
                tensor.data[unit : 2 * unit] = 1.0

        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, (nn.LayerNorm, nn.BatchNorm1d)):
                # Norm layers may have been built without affine parameters.
                if getattr(layer, "weight", None) is not None:
                    nn.init.ones_(layer.weight)
                if getattr(layer, "bias", None) is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Conv1d):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Equalize a batch of windows and return two outputs per sample."""
        mid = Config.CONTEXT_K
        raw, _, seq_features = self.encoder(x)

        hidden = self.temporal_proj(seq_features.transpose(1, 2))
        for layer in self.temporal_blocks:
            hidden = layer(hidden)
        # Back to sequence-second layout for the batch-first LSTM.
        hidden = self.temporal_norm(hidden).transpose(1, 2)

        rnn_out, _ = self.lstm(hidden)
        centre_state = rnn_out[:, mid, :]
        attended = self.attention(rnn_out)

        joined = torch.cat([centre_state, attended, raw[:, mid, :]], dim=1)
        return self.head(self.center_fusion(joined))
1447
+
1448
+
1449
class ComplexFastKANEqualizer(nn.Module):
    """Lightweight complex encoder followed by a FastKAN regression head (two outputs).

    The head input concatenates the encoder features at the centre position,
    their mean and standard deviation over the sequence, the raw centre sample,
    the raw window mean and the centre-sample power.
    """

    def __init__(self):
        super().__init__()
        self.encoder = LightweightComplexEncoder()
        # centre/mean/std of the encoder features + raw centre + raw mean + centre power.
        encoder_width = self.encoder.out_channels + 3
        head_in = 3 * encoder_width + 2 * Config.INPUT_DIM + 1
        self.head = FastKANHead(
            input_dim=head_in,
            hidden_dim=Config.FASTKAN_HIDDEN_DIM,
            out_dim=2,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Equalize a batch of windows and return two outputs per sample."""
        mid = Config.CONTEXT_K
        raw, _, seq_features = self.encoder(x)
        centre_sample = raw[:, mid, :]

        pieces = [
            seq_features[:, mid, :],
            seq_features.mean(dim=1),
            seq_features.std(dim=1, unbiased=False),
            centre_sample,
            raw.mean(dim=1),
            centre_sample.square().sum(dim=1, keepdim=True),
        ]
        return self.head(torch.cat(pieces, dim=1))

    def regularization_loss(self) -> torch.Tensor:
        """Return the regularization term reported by the FastKAN head."""
        penalty = self.head.regularization_loss()
        return penalty
1473
+
1474
+
1475
class ComplexFastKANClassifierEqualizer(nn.Module):
    """Lightweight complex encoder followed by a FastKAN head with one output per constellation point.

    The head input concatenates the encoder features at the centre position,
    their mean and standard deviation over the sequence, the raw centre sample,
    the raw window mean and the centre-sample power.
    """

    def __init__(self):
        super().__init__()
        self.encoder = LightweightComplexEncoder()
        # centre/mean/std of the encoder features + raw centre + raw mean + centre power.
        encoder_width = self.encoder.out_channels + 3
        head_in = 3 * encoder_width + 2 * Config.INPUT_DIM + 1
        self.head = FastKANHead(
            input_dim=head_in,
            hidden_dim=Config.FASTKAN_HIDDEN_DIM,
            out_dim=CONSTELLATION.size(0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head output (``CONSTELLATION.size(0)`` values per sample)."""
        mid = Config.CONTEXT_K
        raw, _, seq_features = self.encoder(x)
        centre_sample = raw[:, mid, :]

        pieces = [
            seq_features[:, mid, :],
            seq_features.mean(dim=1),
            seq_features.std(dim=1, unbiased=False),
            centre_sample,
            raw.mean(dim=1),
            centre_sample.square().sum(dim=1, keepdim=True),
        ]
        return self.head(torch.cat(pieces, dim=1))

    def regularization_loss(self) -> torch.Tensor:
        """Return the regularization term reported by the FastKAN head."""
        penalty = self.head.regularization_loss()
        return penalty
1499
+
1500
+
1501
class TransformerEncoderBlock(nn.Module):
    """Pre-norm encoder block: self-attention, depthwise conv mixer, gated feed-forward.

    Each of the three sub-layers is applied to a layer-normalised copy of the
    input and added back through a dropout-regularised residual connection.

    Args:
        dim: Model width.
        heads: Number of attention heads.
        ff_dim: Width of the gated feed-forward branch (the input projection
            produces twice this many features: value and gate).
    """

    def __init__(self, dim: int, heads: int, ff_dim: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=heads,
            dropout=Config.DROPOUT,
            batch_first=True,
        )
        self.norm2 = nn.LayerNorm(dim)

        kernel = Config.TRANSFORMER_CONV_KERNEL
        depthwise = nn.Conv1d(dim, dim, kernel_size=kernel, padding=kernel // 2, groups=dim)
        pointwise = nn.Conv1d(dim, dim, kernel_size=1)
        self.local_mixer = nn.Sequential(depthwise, nn.GELU(), pointwise)

        self.norm3 = nn.LayerNorm(dim)
        self.ff_in = nn.Linear(dim, ff_dim * 2)
        self.ff_out = nn.Linear(ff_dim, dim)
        self.dropout = nn.Dropout(Config.DROPOUT)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three residual sub-layers to a batch-first sequence tensor."""
        # 1) Self-attention.
        normed = self.norm1(x)
        with sdpa_context():
            attended, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + self.dropout(attended)

        # 2) Local mixing; Conv1d expects the channel axis second.
        mixed = self.local_mixer(self.norm2(x).transpose(1, 2)).transpose(1, 2)
        x = x + self.dropout(mixed)

        # 3) Gated feed-forward: value * GELU(gate).
        value, gate = self.ff_in(self.norm3(x)).chunk(2, dim=-1)
        projected = self.ff_out(value * F.gelu(gate))
        return x + self.dropout(projected)
1542
+
1543
+
1544
class TransformerRxEqualizer(nn.Module):
    """Transformer equalizer with a learned positional embedding.

    The centre token and the mean over all tokens are fused and regressed to
    two outputs.
    """

    def __init__(self):
        super().__init__()
        width = Config.TRANSFORMER_DIM
        hidden = Config.HIDDEN_DIM

        self.input_proj = nn.Linear(Config.INPUT_DIM, width)
        self.pos_embedding = nn.Parameter(torch.zeros(1, Config.SEQ_LEN, width))
        self.input_dropout = nn.Dropout(Config.DROPOUT * 0.5)

        encoder_layers = [
            TransformerEncoderBlock(
                dim=width,
                heads=Config.TRANSFORMER_HEADS,
                ff_dim=Config.TRANSFORMER_FF_DIM,
            )
            for _ in range(Config.TRANSFORMER_LAYERS)
        ]
        self.blocks = nn.ModuleList(encoder_layers)
        self.final_norm = nn.LayerNorm(width)

        self.center_fusion = nn.Sequential(
            nn.Linear(width * 2, width),
            nn.LayerNorm(width),
            nn.GELU(),
        )
        self.regressor = nn.Sequential(
            nn.Linear(width, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(hidden, hidden // 2),
            nn.LayerNorm(hidden // 2),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(hidden // 2, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Draw the positional embedding, Xavier-initialise linear/conv layers, reset layer norms."""
        nn.init.normal_(self.pos_embedding, mean=0.0, std=0.02)
        for layer in self.modules():
            if isinstance(layer, (nn.Linear, nn.Conv1d)):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.LayerNorm):
                nn.init.zeros_(layer.bias)
                nn.init.ones_(layer.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Equalize a batch of windows and return two outputs per sample."""
        tokens = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        tokens = self.input_dropout(self.input_proj(tokens) + self.pos_embedding)

        for layer in self.blocks:
            tokens = layer(tokens)
        tokens = self.final_norm(tokens)

        centre_token = tokens[:, Config.CONTEXT_K, :]
        pooled = tokens.mean(dim=1)
        fused = self.center_fusion(torch.cat([centre_token, pooled], dim=1))
        return self.regressor(fused)
1602
+
1603
+
1604
class TCNResidualBlock(nn.Module):
    """Residual block: batch norm, dilated depthwise conv, gated pointwise conv, projection.

    Args:
        channels: Number of input and output channels.
        kernel_size: Kernel length of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
    """

    def __init__(self, channels: int, kernel_size: int, dilation: int):
        super().__init__()
        # Symmetric padding derived from the dilated kernel span.
        pad = dilation * (kernel_size - 1) // 2
        self.norm = nn.BatchNorm1d(channels)
        self.depthwise = nn.Conv1d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=pad,
            dilation=dilation,
            groups=channels,
            bias=False,
        )
        # Twice the channels: one half is the value, the other the gate.
        self.pointwise = nn.Conv1d(channels, channels * 2, kernel_size=1)
        self.out_proj = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(Config.DROPOUT * 0.5)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise the three convolutions and reset the batch norm."""
        for conv in (self.depthwise, self.pointwise, self.out_proj):
            nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")
            conv_bias = getattr(conv, "bias", None)
            if conv_bias is not None:
                nn.init.zeros_(conv_bias)
        nn.init.ones_(self.norm.weight)
        nn.init.zeros_(self.norm.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the dropout-regularised gated branch output."""
        branch = self.depthwise(self.norm(x))
        value, gate = self.pointwise(branch).chunk(2, dim=1)
        gated = F.gelu(value) * torch.sigmoid(gate)
        return x + self.dropout(self.out_proj(gated))
1638
+
1639
+
1640
class TCNRxEqualizer(nn.Module):
    """Temporal convolutional equalizer built from ``TCNResidualBlock``s.

    The head receives the feature at the centre position, the mean feature over
    the sequence, the raw centre sample and the raw window mean.
    """

    def __init__(self):
        super().__init__()
        width = Config.TCN_HIDDEN_DIM
        hidden = Config.HIDDEN_DIM

        # Use the first TCN_LAYERS dilations; if too few are configured, repeat
        # the last one (or 1 when none are configured) to reach the layer count.
        rates = Config.TCN_DILATIONS[: Config.TCN_LAYERS]
        missing = Config.TCN_LAYERS - len(rates)
        if missing > 0:
            filler = rates[-1] if rates else 1
            rates = rates + [filler] * missing

        self.stem = nn.Sequential(
            nn.Conv1d(Config.INPUT_DIM, width, kernel_size=1, bias=False),
            nn.BatchNorm1d(width),
            nn.GELU(),
        )
        self.blocks = nn.ModuleList(
            [TCNResidualBlock(width, Config.TCN_KERNEL_SIZE, rate) for rate in rates]
        )
        self.final_norm = nn.BatchNorm1d(width)

        head_in = width * 2 + 2 * Config.INPUT_DIM
        self.head = nn.Sequential(
            nn.Linear(head_in, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(hidden, hidden // 2),
            nn.LayerNorm(hidden // 2),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(hidden // 2, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise linear/conv layers and reset batch/layer-norm affine terms."""
        for layer in self.modules():
            if isinstance(layer, (nn.Linear, nn.Conv1d)):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, (nn.BatchNorm1d, nn.LayerNorm)):
                # Norm layers may have been built without affine parameters.
                if getattr(layer, "weight", None) is not None:
                    nn.init.ones_(layer.weight)
                if getattr(layer, "bias", None) is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Equalize a batch of windows and return two outputs per sample."""
        mid = Config.CONTEXT_K
        window = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)

        features = self.stem(window.transpose(1, 2))
        for layer in self.blocks:
            features = layer(features)
        features = self.final_norm(features)

        pieces = [
            features[:, :, mid],
            features.mean(dim=2),
            window[:, mid, :],
            window.mean(dim=1),
        ]
        return self.head(torch.cat(pieces, dim=1))
1694
+
1695
+
1696
class MambaRxEqualizer(nn.Module):
    """Equalizer that runs pre-norm residual Mamba blocks over the IQ window.

    The flat input is reshaped to ``(batch, Config.SEQ_LEN, Config.INPUT_DIM)``,
    projected to ``Config.MAMBA_DIM`` channels, passed through
    ``Config.MAMBA_LAYERS`` residual blocks and reduced to a 2-value output by
    an MLP head. The head is fed the centre hidden state, the sequence-mean
    hidden state and the raw centre sample.

    Raises:
        ImportError: On construction when the optional ``mamba_ssm`` package
            could not be imported (``Mamba is None``).
    """

    def __init__(self):
        super().__init__()
        if Mamba is None:
            raise ImportError(
                "mamba_ssm is unavailable. Install `mamba-ssm` to run the `mamba` model, "
                "or keep using `tcn`, `lstm`, and KAN baselines."
            )
        dim = Config.MAMBA_DIM
        self.input_proj = nn.Sequential(
            nn.Linear(Config.INPUT_DIM, dim),
            nn.LayerNorm(dim),
            nn.GELU(),
        )
        self.blocks = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "norm": nn.LayerNorm(dim),
                        "mamba": Mamba(
                            d_model=dim,
                            d_state=Config.MAMBA_D_STATE,
                            d_conv=Config.MAMBA_D_CONV,
                            expand=Config.MAMBA_EXPAND,
                        ),
                    }
                )
                for _ in range(Config.MAMBA_LAYERS)
            ]
        )
        self.final_norm = nn.LayerNorm(dim)
        self.head = nn.Sequential(
            nn.Linear(dim * 2 + Config.INPUT_DIM, Config.HIDDEN_DIM),
            nn.LayerNorm(Config.HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(Config.HIDDEN_DIM, Config.HIDDEN_DIM // 2),
            nn.LayerNorm(Config.HIDDEN_DIM // 2),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT * 0.5),
            nn.Linear(Config.HIDDEN_DIM // 2, 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize Linear weights, zero biases, reset LayerNorm affines.

        ``self.modules()`` also walks the Linear layers inside each ``Mamba``
        block. ``mamba_ssm`` marks a bias that carries a deliberate
        initialization (the ``dt_proj`` bias) with ``_no_reinit = True``; such
        biases are left untouched so that initialization is not wiped out.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                bias = module.bias
                if bias is not None and not getattr(bias, "_no_reinit", False):
                    nn.init.constant_(bias, 0.0)
            elif isinstance(module, nn.LayerNorm):
                nn.init.constant_(module.weight, 1.0)
                nn.init.constant_(module.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a flat window batch to the head's 2-value output per sample."""
        raw = x.view(x.size(0), Config.SEQ_LEN, Config.INPUT_DIM)
        hidden = self.input_proj(raw)
        for block in self.blocks:
            # Pre-norm residual connection around each Mamba mixer.
            hidden = hidden + block["mamba"](block["norm"](hidden))
        hidden = self.final_norm(hidden)
        center = hidden[:, Config.CONTEXT_K, :]
        global_context = hidden.mean(dim=1)
        raw_center = raw[:, Config.CONTEXT_K, :]
        return self.head(torch.cat([center, global_context, raw_center], dim=1))
1760
+
1761
+
1762
class MLPRxEqualizer(nn.Module):
    """MLP equalizer with a linear skip path from the flat window to the output.

    The output is ``skip(x) + net(x)``. At construction the skip path copies
    the centre sample of the window to the output and the last layer of
    ``net`` is zeroed, so the untrained model returns the centre sample.
    """

    def __init__(self):
        super().__init__()
        input_dim = Config.SEQ_LEN * Config.INPUT_DIM
        hidden_dim = Config.HIDDEN_DIM
        self.skip = nn.Linear(input_dim, 2)
        layers: List[nn.Module] = []
        in_dim = input_dim
        # At least one hidden layer is always built, even if MLP_LAYERS < 1.
        for _ in range(max(Config.MLP_LAYERS, 1)):
            layers.extend(
                [
                    nn.Linear(in_dim, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.GELU(),
                ]
            )
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, 2))
        self.net = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Set the skip path to a centre-sample pass-through and zero the MLP output layer."""
        with torch.no_grad():
            self.skip.weight.zero_()
            self.skip.bias.zero_()
            center_start = Config.CONTEXT_K * Config.INPUT_DIM
            # Rectangular identity so the slice shape matches for any
            # INPUT_DIM, like the (2, INPUT_DIM) block initialize_from_data writes.
            self.skip.weight[:, center_start : center_start + Config.INPUT_DIM] = torch.eye(
                2, Config.INPUT_DIM
            )

        linear_layers = [module for module in self.net.modules() if isinstance(module, nn.Linear)]
        for idx, module in enumerate(linear_layers):
            if idx == len(linear_layers) - 1:
                # Zero output layer: net(x) contributes nothing before training.
                nn.init.zeros_(module.weight)
                nn.init.constant_(module.bias, 0.0)
            else:
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
        for module in self.net.modules():
            if isinstance(module, nn.LayerNorm):
                nn.init.constant_(module.bias, 0)
                nn.init.constant_(module.weight, 1.0)

    @torch.no_grad()
    def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor):
        """Fit the skip path by least squares from the centre sample to the target.

        Uses at most the first 262,144 rows and does nothing when fewer than 3
        rows are available. The fitted affine map (weights for the centre
        sample plus a bias) overwrites the skip layer; all other skip weights
        are zeroed.
        """
        samples = min(train_x.size(0), 262_144)
        if samples < 3:
            return
        center = train_x[:samples].view(samples, Config.SEQ_LEN, Config.INPUT_DIM)[:, Config.CONTEXT_K, :].float()
        target = train_y[:samples].float()
        ones = torch.ones(samples, 1, dtype=center.dtype, device=center.device)
        # Trailing column of ones makes the last row of the solution the bias.
        design = torch.cat([center, ones], dim=1)
        solution = torch.linalg.lstsq(design, target).solution

        self.skip.weight.zero_()
        self.skip.bias.copy_(solution[-1].to(self.skip.bias.device, dtype=self.skip.bias.dtype))
        center_start = Config.CONTEXT_K * Config.INPUT_DIM
        self.skip.weight[:, center_start : center_start + Config.INPUT_DIM].copy_(
            solution[: Config.INPUT_DIM].T.to(self.skip.weight.device, dtype=self.skip.weight.dtype)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sum of the linear skip path and the MLP correction."""
        return self.skip(x) + self.net(x)
1824
+
1825
+
1826
def log(msg: str):
    """Print ``msg`` on its own line and flush stdout immediately."""
    print(msg, end="\n", flush=True)
1828
+
1829
+
1830
# One-line description of each model's data flow, keyed by model name.
MODEL_NOTES = dict(
    (
        ("efficient_kan_baseline", "flat IQ window -> KAN -> corrected I/Q"),
        ("efficient_kan_residual", "flat IQ window -> KAN correction -> rx center + correction"),
        ("efficient_kan_features", "handcrafted local/global IQ features -> KAN -> corrected I/Q"),
        ("cnn_kan", "CNN extracts temporal features -> KAN head -> corrected I/Q"),
        ("kan_classifier", "flat IQ window -> KAN -> 16 constellation logits"),
        ("fastkan_classifier", "flat IQ window -> compact RBF/FastKAN -> 16 constellation logits"),
        ("complex_fastkan_classifier", "light complex encoder -> RBF/FastKAN -> 16 constellation logits"),
        ("mlp", "flat IQ window -> MLP -> corrected I/Q"),
        ("cnn", "CNN extracts temporal features -> MLP head -> corrected I/Q"),
        ("tcn", "dilated temporal CNN over IQ window -> corrected I/Q"),
        ("mamba", "Mamba sequence blocks over IQ window -> corrected I/Q"),
    )
)
1843
+
1844
+
1845
def count_trainable_parameters(model: nn.Module) -> int:
    """Return the total element count of parameters with ``requires_grad=True``."""
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total
1847
+
1848
+
1849
def build_criterion() -> nn.Module:
    """Return ``SmoothL1Loss(beta=0.05)`` when ``Config.LOSS`` is ``"smooth_l1"``, else ``MSELoss``."""
    wants_smooth_l1 = Config.LOSS == "smooth_l1"
    return nn.SmoothL1Loss(beta=0.05) if wants_smooth_l1 else nn.MSELoss()
1853
+
1854
+
1855
def is_classifier_output(preds: torch.Tensor) -> bool:
    """Return True when ``preds`` is 2-D with one column per ``CONSTELLATION`` row."""
    if preds.dim() != 2:
        return False
    return preds.size(1) == CONSTELLATION.size(0)
1857
+
1858
+
1859
def prediction_loss(preds: torch.Tensor, targets: torch.Tensor, criterion: nn.Module) -> torch.Tensor:
    """Compute the training loss for either regression or classifier outputs.

    Classifier-shaped predictions are scored with cross-entropy against the
    class index of each target symbol; anything else goes through
    ``criterion`` unchanged.
    """
    if not is_classifier_output(preds):
        return criterion(preds, targets)
    labels = symbols_to_classes(targets.float())
    return F.cross_entropy(preds.float(), labels)
1864
+
1865
+
1866
def build_optimizer(model: nn.Module) -> optim.Optimizer:
    """Build the optimizer selected by ``Config.OPTIMIZER``.

    ``"adam"`` (case-insensitive) selects Adam; every other value selects
    RMSprop. On CUDA the fused implementation is requested first, and the
    optimizer is rebuilt without ``fused`` if the constructor rejects that
    keyword with ``TypeError``.
    """
    settings = dict(lr=Config.LEARNING_RATE, weight_decay=Config.WEIGHT_DECAY)
    if Config.DEVICE.type == "cuda":
        settings["fused"] = True
    factory = optim.Adam if Config.OPTIMIZER.lower() == "adam" else optim.RMSprop
    try:
        return factory(model.parameters(), **settings)
    except TypeError:
        # This optimizer class does not accept `fused`; retry without it.
        settings.pop("fused", None)
        return factory(model.parameters(), **settings)
1877
+
1878
+
1879
def compute_model_regularization(model: nn.Module) -> torch.Tensor:
    """Return ``model.regularization_loss()`` scaled by ``Config.KAN_PRUNE_L1``.

    Yields a scalar zero on ``Config.DEVICE`` when the coefficient is not
    positive or the model has no ``regularization_loss`` attribute.
    """
    if Config.KAN_PRUNE_L1 <= 0:
        return torch.zeros((), device=Config.DEVICE)
    if not hasattr(model, "regularization_loss"):
        return torch.zeros((), device=Config.DEVICE)
    penalty = model.regularization_loss()
    if not torch.is_tensor(penalty):
        # Plain Python numbers are promoted to a tensor on Config.DEVICE.
        penalty = torch.tensor(float(penalty), device=Config.DEVICE)
    return penalty * Config.KAN_PRUNE_L1
1886
+
1887
+
1888
def complex_unit_response(response: torch.Tensor, order: int) -> torch.Tensor:
    """Return an element-wise ``order``-th root of a complex tensor.

    The magnitude is floored at 1e-6 before the root is taken, and the phase
    is the principal angle divided by ``order``.
    """
    exponent = 1.0 / order
    root_magnitude = torch.clamp_min(response.abs(), 1e-6) ** exponent
    root_phase = response.angle() / order
    return torch.polar(root_magnitude, root_phase)
1892
+
1893
+
1894
def centered_complex_kernel_to_frequency(kernel: torch.Tensor, fft_size: int) -> torch.Tensor:
    """Return the ``fft_size``-point FFT of a kernel whose middle tap is time zero.

    The kernel is rotated so that element ``numel // 2`` lands at index 0
    before the transform.
    """
    half = kernel.numel() // 2
    zero_first = kernel.roll(shifts=-half, dims=0)
    return torch.fft.fft(zero_first, n=fft_size)
1898
+
1899
+
1900
def frequency_to_centered_complex_kernel(response: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Inverse-FFT ``response`` and return its first ``seq_len`` taps, rotated.

    The kept taps are rolled right by ``seq_len // 2`` so that the time-zero
    tap ends up in the middle of the returned kernel.
    """
    impulse = torch.fft.ifft(response, n=response.numel())
    leading_taps = impulse[:seq_len]
    return torch.roll(leading_taps, shifts=seq_len // 2, dims=0)
1904
+
1905
+
1906
def extract_kernel_from_centered_response(
    kernel: torch.Tensor,
    kernel_size: int,
    symmetric: bool,
) -> torch.Tensor:
    """Slice ``kernel_size`` taps out of a centred kernel.

    With ``symmetric`` set, the slice starts at the centre tap
    (``numel // 2``); otherwise it is positioned around the centre, starting
    ``kernel_size // 2`` taps before it.
    """
    center = kernel.numel() // 2
    first = center if symmetric else center - kernel_size // 2
    return kernel[first : first + kernel_size].contiguous()
1917
+
1918
+
1919
def assign_complex_kernel(module, kernel: torch.Tensor):
    """Copy a complex kernel into ``module.real_conv`` and ``module.imag_conv`` weights.

    The real and imaginary parts are each viewed in the shape of the
    destination weight and cast to its dtype before the in-place copy.
    """
    real_weight = module.real_conv.weight
    real_weight.copy_(kernel.real.view_as(real_weight).to(real_weight.dtype))
    imag_weight = module.imag_conv.weight
    imag_weight.copy_(kernel.imag.view_as(imag_weight).to(imag_weight.dtype))
1922
+
1923
+
1924
def is_cuda_oom(error: RuntimeError) -> bool:
    """Return True when the error message contains "out of memory" (case-insensitive)."""
    message = str(error).lower()
    return "out of memory" in message
1926
+
1927
+
1928
def mark_cudagraph_step_begin():
    """Call ``torch.compiler.cudagraph_mark_step_begin()`` when running on CUDA.

    Does nothing on other devices or when this PyTorch build lacks the hook.
    """
    if Config.DEVICE.type != "cuda":
        return
    compiler = getattr(torch, "compiler", None)
    if compiler is None or not hasattr(compiler, "cudagraph_mark_step_begin"):
        return
    compiler.cudagraph_mark_step_begin()
1934
+
1935
+
1936
def autocast_context():
    """Return a CUDA float16 autocast context.

    It is enabled only when ``Config.DEVICE`` is CUDA and ``Config.USE_AMP``
    is truthy.
    """
    amp_enabled = Config.DEVICE.type == "cuda" and Config.USE_AMP
    return torch.autocast(device_type="cuda", dtype=torch.float16, enabled=amp_enabled)
1942
+
1943
+
1944
def sdpa_context():
    """Return a context manager that restricts scaled-dot-product attention to the math backend.

    On CUDA it prefers ``torch.nn.attention.sdpa_kernel`` and falls back to
    ``torch.backends.cuda.sdp_kernel``. A ``nullcontext`` is returned on other
    devices or when neither API exists.
    """
    if Config.DEVICE.type != "cuda":
        return nullcontext()

    attention = getattr(torch.nn, "attention", None)
    has_current_api = (
        attention is not None and hasattr(attention, "sdpa_kernel") and hasattr(attention, "SDPBackend")
    )
    if has_current_api:
        return attention.sdpa_kernel([attention.SDPBackend.MATH])

    legacy = getattr(torch.backends, "cuda", None)
    if legacy is not None and hasattr(legacy, "sdp_kernel"):
        return legacy.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True)
    return nullcontext()
1954
+
1955
+
1956
def symbols_to_classes(symbols: torch.Tensor) -> torch.Tensor:
    """Return, for each row of ``symbols``, the index of the nearest ``CONSTELLATION`` row.

    Nearness is squared Euclidean distance; the constellation is moved to the
    device of ``symbols`` first.
    """
    points = CONSTELLATION.to(symbols.device)
    # Broadcast every symbol against every constellation point.
    delta = symbols[:, None] - points[None]
    return delta.square().sum(dim=2).argmin(dim=1)
1961
+
1962
+
1963
def calculate_ber_from_classes(tx_classes: torch.Tensor, rx_classes: torch.Tensor) -> float:
    """Return the fraction of differing entries between the ``BIT_LABELS`` rows of two class tensors."""
    table = BIT_LABELS.to(tx_classes.device)
    mismatches = table[tx_classes] != table[rx_classes]
    return mismatches.float().mean().item()
1968
+
1969
+
1970
def compute_rms_scale(symbols: torch.Tensor) -> torch.Tensor:
    """Return the RMS of the row norms of ``symbols``, floored at 1e-6."""
    mean_power = symbols.square().sum(dim=1).mean()
    return mean_power.sqrt().clamp_min(1e-6)
1973
+
1974
+
1975
def power_normalize_pair(
    tx_symbols: torch.Tensor,
    rx_symbols: torch.Tensor,
    tx_scale: Optional[torch.Tensor] = None,
    rx_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Divide TX and RX symbols by their scales.

    A scale passed as ``None`` is computed from the matching tensor with
    ``compute_rms_scale``.

    Returns:
        ``(tx_normalized, rx_normalized, tx_scale, rx_scale)``.
    """
    tx_scale = compute_rms_scale(tx_symbols) if tx_scale is None else tx_scale
    rx_scale = compute_rms_scale(rx_symbols) if rx_scale is None else rx_scale
    normalized_tx = tx_symbols / tx_scale
    normalized_rx = rx_symbols / rx_scale
    return normalized_tx, normalized_rx, tx_scale, rx_scale
1986
+
1987
+
1988
def find_best_symbol_scale(tx_symbols: torch.Tensor, rx_symbols: torch.Tensor) -> float:
    """Search for the scalar gain on ``rx_symbols`` that minimizes BER against ``tx_symbols``.

    A golden-section search runs for ``Config.BER_SCALE_STEPS`` iterations over
    a bracket of 0.5x to 1.5x the TX/RX RMS ratio, clipped to
    ``[Config.BER_SCALE_MIN, Config.BER_SCALE_MAX]``. Only a window of at most
    ``Config.BER_SCALE_SAMPLES`` symbols, starting at
    ``Config.BER_SCALE_OFFSET``, is evaluated.

    Returns:
        The midpoint of the final two probe points, or ``1.0`` when the search
        is disabled (``Config.BER_SCALE_SEARCH`` falsy) or either input is
        empty.
    """
    if not Config.BER_SCALE_SEARCH or tx_symbols.numel() == 0 or rx_symbols.numel() == 0:
        return 1.0

    total = min(tx_symbols.size(0), rx_symbols.size(0))
    # Clamp the offset so at least one sample is always left to evaluate.
    offset = min(Config.BER_SCALE_OFFSET, max(total - 1, 0))
    available = max(total - offset, 1)
    sample_count = min(available, Config.BER_SCALE_SAMPLES)
    tx_eval = tx_symbols[offset : offset + sample_count]
    rx_eval = rx_symbols[offset : offset + sample_count]

    tx_norm = tx_eval.square().sum(dim=1).mean().sqrt().clamp_min(1e-6)
    rx_norm = rx_eval.square().sum(dim=1).mean().sqrt().clamp_min(1e-6)
    center = (tx_norm / rx_norm).item()
    left = max(Config.BER_SCALE_MIN, 0.5 * center)
    right = min(Config.BER_SCALE_MAX, 1.5 * center)
    phi = (1 + 5**0.5) / 2

    # The reference decisions do not depend on the candidate scale, so they
    # are classified once instead of on every objective evaluation.
    target_classes = symbols_to_classes(tx_eval.float())

    def objective(scale: float) -> float:
        pred_classes = symbols_to_classes((rx_eval * scale).float())
        return calculate_ber_from_classes(target_classes, pred_classes)

    x1 = right - (right - left) / phi
    x2 = left + (right - left) / phi
    y1 = objective(x1)
    y2 = objective(x2)
    for _ in range(Config.BER_SCALE_STEPS):
        if y1 >= y2:
            # Minimum lies in [x1, right]: drop the left part, reuse x2 as new x1.
            left = x1
            x1 = x2
            y1 = y2
            x2 = left + (right - left) / phi
            y2 = objective(x2)
        else:
            # Minimum lies in [left, x2]: drop the right part, reuse x1 as new x2.
            right = x2
            x2 = x1
            y2 = y1
            x1 = right - (right - left) / phi
            y1 = objective(x1)
    return float((x1 + x2) * 0.5)
2029
+
2030
+
2031
def discover_symbol_files() -> Tuple[Path, List[int]]:
    """Find the first candidate directory that holds symbol CSV files.

    Directories in ``Config.DATA_DIR_CANDIDATES`` are tried in order. The
    numeric suffix of every ``Symbols_1m_1ch_PR_*.csv`` file name is parsed.

    Returns:
        ``(directory, indices)`` where ``indices`` are the sorted file numbers,
        truncated to ``Config.MAX_FILES`` entries.

    Raises:
        FileNotFoundError: If no candidate directory contains a matching file.
    """
    pattern = "Symbols_1m_1ch_PR_*.csv"
    for candidate in Config.DATA_DIR_CANDIDATES:
        if not candidate.exists():
            continue
        matches = sorted(candidate.glob(pattern))
        if not matches:
            continue
        numbers = sorted(int(match.stem.rsplit("_", 1)[-1]) for match in matches)
        return candidate, numbers[: Config.MAX_FILES]
    raise FileNotFoundError("No Symbols_1m_1ch_PR_*.csv files found.")
2040
+
2041
+
2042
def resolve_splits(file_indices: List[int]) -> Tuple[List[int], List[int], List[int]]:
    """Split file indices into train, validation and test lists.

    The first ``Config.TRAIN_PORTION`` of the files (at least 2, always
    leaving one file out) form the train+val pool; the rest are the test set.
    Validation takes the last files of the pool: ``Config.VAL_PORTION_WITHIN_TRAIN``
    of it, at least ``Config.MIN_VAL_FILES`` and at least 1, while leaving at
    least one training file. The order is shuffled first when
    ``Config.RANDOMIZE_FILE_SPLIT`` is set.

    Raises:
        ValueError: If fewer than 2 files end up in the pool, or no test file
            remains.
    """
    ordered = list(file_indices)
    if Config.RANDOMIZE_FILE_SPLIT:
        ordered = np.random.default_rng(Config.SPLIT_SEED).permutation(ordered).tolist()

    n_files = len(ordered)
    pool_size = max(2, int(n_files * Config.TRAIN_PORTION))
    if pool_size >= n_files:
        # Always keep one file out of the pool for the test set.
        pool_size = n_files - 1
    pool = ordered[:pool_size]
    if len(pool) < 2:
        raise ValueError("Need at least 2 train+val files so validation can be held out.")

    requested = int(round(len(pool) * Config.VAL_PORTION_WITHIN_TRAIN))
    val_count = max(Config.MIN_VAL_FILES, requested)
    val_count = min(max(val_count, 1), len(pool) - 1)

    held_out = ordered[pool_size:]
    if not held_out:
        raise ValueError("File-level split produced no test files. Increase MAX_FILES or lower TRAIN_PORTION.")
    return pool[:-val_count], pool[-val_count:], held_out
2065
+
2066
+
2067
def load_symbol_file(base_dir: Path, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Read ``Symbols_1m_1ch_PR_{idx}.csv`` from ``base_dir`` as float32 tensors.

    The file is parsed without a header row.

    Returns:
        Two tensors: columns 0-1 and columns 2-3 of the file. Callers in this
        module treat them as the TX and RX symbols respectively.
    """
    csv_path = base_dir / f"Symbols_1m_1ch_PR_{idx}.csv"
    frame = pd.read_csv(csv_path, header=None, dtype=np.float32)
    values = frame.to_numpy(copy=False)
    # Copy each column pair so the tensors own contiguous memory.
    first_pair = torch.from_numpy(values[:, 0:2].copy())
    second_pair = torch.from_numpy(values[:, 2:4].copy())
    return first_pair, second_pair
2071
+
2072
+
2073
def load_files(base_dir: Path, file_indices: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load the listed symbol files and concatenate them along dim 0.

    Returns:
        ``(tx, rx)`` tensors holding all files back to back, in the order of
        ``file_indices``.
    """
    pairs = [load_symbol_file(base_dir, idx) for idx in file_indices]
    tx_all = torch.cat([tx for tx, _ in pairs], dim=0)
    rx_all = torch.cat([rx for _, rx in pairs], dim=0)
    return tx_all, rx_all
2081
+
2082
+
2083
def load_file_splits(base_dir: Path, file_indices: List[int]) -> List[Dict[str, Any]]:
    """Load each listed symbol file into its own record.

    Returns:
        One dict per file, in input order, with keys ``"file_idx"``, ``"tx"``
        and ``"rx"``.
    """

    def as_record(idx: int) -> Dict[str, Any]:
        tx_symbols, rx_symbols = load_symbol_file(base_dir, idx)
        return {"file_idx": idx, "tx": tx_symbols, "rx": rx_symbols}

    return [as_record(idx) for idx in file_indices]
2089
+
2090
+
2091
def normalize_file_splits(
    files: List[Dict[str, Any]],
    tx_scale: torch.Tensor,
    rx_scale: torch.Tensor,
) -> List[Dict[str, Any]]:
    """Return new file records with ``tx`` and ``rx`` divided by their scales.

    The input records are not modified; ``file_idx`` is carried over as is.
    """
    scaled: List[Dict[str, Any]] = []
    for record in files:
        scaled.append(
            {
                "file_idx": record["file_idx"],
                "tx": record["tx"] / tx_scale,
                "rx": record["rx"] / rx_scale,
            }
        )
    return scaled
2104
+
2105
+
2106
def make_windows(rx_symbols: torch.Tensor, tx_symbols: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build flattened sliding windows over standardized RX and the matching TX targets.

    RX is standardized with ``mean`` / ``std`` (non-finite results become 0)
    and cut into stride-1 windows of ``Config.SEQ_LEN`` samples, each
    flattened sample by sample. Targets are the TX rows with
    ``Config.CONTEXT_K`` rows trimmed from both ends.

    Returns:
        ``(x, y)`` where ``x`` has one flattened window per row.
    """
    context = Config.CONTEXT_K
    standardized = torch.nan_to_num((rx_symbols - mean) / std, nan=0.0, posinf=0.0, neginf=0.0)
    # unfold puts the window axis last; move it ahead of the feature axis.
    windows = standardized.unfold(0, Config.SEQ_LEN, 1).permute(0, 2, 1).contiguous()
    x = windows.view(windows.size(0), -1).contiguous()
    y = tx_symbols[context : tx_symbols.size(0) - context].contiguous()
    return x, y
2113
+
2114
+
2115
def make_windows_for_files(
    files: List[Dict[str, Any]],
    mean: torch.Tensor,
    std: torch.Tensor,
    collect_spans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
    """Build windows file by file and concatenate them.

    Windows never cross a file boundary. Files with fewer than
    ``Config.SEQ_LEN`` RX rows are skipped.

    Returns:
        ``(x, y, spans)``. With ``collect_spans`` set, ``spans`` holds one
        dict per kept file with its ``file_idx``, the ``start``/``end`` row
        range inside ``x``/``y``, and the context-trimmed ``tx_center`` /
        ``rx_center`` symbols; otherwise it is empty.

    Raises:
        ValueError: If every file was too short.
    """
    context = Config.CONTEXT_K
    inputs: List[torch.Tensor] = []
    targets: List[torch.Tensor] = []
    spans: List[Dict[str, Any]] = []
    cursor = 0

    for record in files:
        tx, rx = record["tx"], record["rx"]
        if rx.size(0) < Config.SEQ_LEN:
            continue
        file_x, file_y = make_windows(rx, tx, mean, std)
        n_windows = file_x.size(0)
        if collect_spans:
            spans.append(
                {
                    "file_idx": record["file_idx"],
                    "start": cursor,
                    "end": cursor + n_windows,
                    "tx_center": tx[context : tx.size(0) - context].contiguous(),
                    "rx_center": rx[context : rx.size(0) - context].contiguous(),
                }
            )
        inputs.append(file_x)
        targets.append(file_y)
        cursor += n_windows

    if not inputs:
        raise ValueError("No file is long enough to build context windows.")
    return torch.cat(inputs, dim=0), torch.cat(targets, dim=0), spans
2150
+
2151
+
2152
def prepare_data(max_test_files: Optional[int] = None) -> Dict[str, Any]:
    """Load the symbol files, split them by file, normalize, and build windows.

    Power scales and the RX mean/std are computed from the training files only
    and then applied to the validation and test files.

    Args:
        max_test_files: When given, keep only this many test files.

    Returns:
        A dict with the windowed ``train``/``val``/``test`` tensors
        (``*_x``, ``*_y``), per-file spans for val and test, the normalized
        test symbols (full and context-trimmed), and the normalization
        statistics (``mean``, ``std``, ``tx_scale``, ``rx_scale``).

    Raises:
        ValueError: If ``max_test_files`` leaves no test file.
    """
    base_dir, all_indices = discover_symbol_files()
    train_idx, val_idx, test_idx = resolve_splits(all_indices)
    if max_test_files is not None:
        test_idx = test_idx[:max_test_files]
        if not test_idx:
            raise ValueError("max_test_files truncated test split to zero files.")
    log(f"Data dir: {base_dir}")
    log(
        "Split protocol: train/val/test are separated by file; "
        "test files are never used for fitting, checkpoint selection, or normalization statistics."
    )
    if Config.RANDOMIZE_FILE_SPLIT:
        log(f"File split randomized with SPLIT_SEED={Config.SPLIT_SEED}")
    log(f"Train files: {train_idx}")
    log(f"Val files: {val_idx}")
    log(f"Test files: {test_idx}")

    train_files = load_file_splits(base_dir, train_idx)
    val_files = load_file_splits(base_dir, val_idx)
    test_files = load_file_splits(base_dir, test_idx)
    tx_train_raw = torch.cat([item["tx"] for item in train_files], dim=0)
    rx_train_raw = torch.cat([item["rx"] for item in train_files], dim=0)

    if Config.POWER_NORMALIZE:
        # Only the scales are needed here, so compute them directly rather
        # than building normalized copies of the training set to discard.
        tx_scale = compute_rms_scale(tx_train_raw)
        rx_scale = compute_rms_scale(rx_train_raw)
    else:
        tx_scale = torch.tensor(1.0, dtype=tx_train_raw.dtype)
        rx_scale = torch.tensor(1.0, dtype=rx_train_raw.dtype)

    train_files = normalize_file_splits(train_files, tx_scale, rx_scale)
    val_files = normalize_file_splits(val_files, tx_scale, rx_scale)
    test_files = normalize_file_splits(test_files, tx_scale, rx_scale)
    rx_train = torch.cat([item["rx"] for item in train_files], dim=0)
    tx_test = torch.cat([item["tx"] for item in test_files], dim=0)
    rx_test = torch.cat([item["rx"] for item in test_files], dim=0)

    mean = rx_train.mean(dim=0, keepdim=True)
    std = rx_train.std(dim=0, keepdim=True)
    # Avoid dividing by zero for a constant column.
    std[std == 0] = 1.0

    train_x, train_y, _ = make_windows_for_files(train_files, mean, std, collect_spans=False)
    val_x, val_y, val_file_spans = make_windows_for_files(val_files, mean, std)
    test_x, test_y, test_file_spans = make_windows_for_files(test_files, mean, std)
    log(
        "Window samples (built per file, no cross-file context): "
        f"train {train_x.size(0):,} | val {val_x.size(0):,} | test {test_x.size(0):,}"
    )
    rx_test_center = torch.cat([item["rx_center"] for item in test_file_spans], dim=0)
    tx_test_center = torch.cat([item["tx_center"] for item in test_file_spans], dim=0)

    return {
        "train_x": train_x,
        "train_y": train_y,
        "val_x": val_x,
        "val_y": val_y,
        "test_x": test_x,
        "test_y": test_y,
        "val_file_spans": val_file_spans,
        "test_file_spans": test_file_spans,
        "tx_test": tx_test,
        "rx_test": rx_test,
        "tx_test_center": tx_test_center,
        "rx_test_center": rx_test_center,
        "mean": mean,
        "std": std,
        "tx_scale": tx_scale,
        "rx_scale": rx_scale,
    }
2222
+
2223
+
2224
def make_model(name: str) -> nn.Module:
    """Instantiate the model registered under ``name`` and move it to ``Config.DEVICE``.

    Several models accept more than one alias.

    Raises:
        ValueError: If ``name`` matches no alias.
    """
    # Ordered (aliases, constructor) table. Lambdas defer the class lookup
    # until the matching entry is actually built.
    registry = (
        ({"efficient_kan_baseline", "efficient_kan", "ekan"}, lambda: EfficientKANBaselineEqualizer()),
        ({"efficient_kan_residual", "kan_residual", "residual_kan"}, lambda: EfficientKANResidualEqualizer()),
        ({"efficient_kan_features", "kan_features", "feature_kan"}, lambda: EfficientKANFeatureEqualizer()),
        ({"cnn_kan", "kan_cnn"}, lambda: CNNKANEqualizer()),
        ({"kan_classifier", "efficient_kan_classifier"}, lambda: EfficientKANClassifierEqualizer()),
        ({"fastkan_classifier", "rbf_kan_classifier"}, lambda: FastKANClassifierEqualizer()),
        ({"cnn"}, lambda: CNNRxEqualizer()),
        ({"lstm"}, lambda: LSTMRxEqualizer()),
        ({"hybrid", "cnn_lstm"}, lambda: HybridCNNLSTMEqualizer()),
        ({"complex_fastkan", "fastkan", "kan"}, lambda: ComplexFastKANEqualizer()),
        (
            {"complex_fastkan_classifier", "complex_rbf_kan_classifier"},
            lambda: ComplexFastKANClassifierEqualizer(),
        ),
        ({"complex_lstm"}, lambda: ComplexLSTMRxEqualizer()),
        ({"complex_dbp_seqstat"}, lambda: ComplexDBPSeqStatRxEqualizer()),
        ({"complex_cnn_lstm"}, lambda: ComplexCNNLSTMRxEqualizer()),
        ({"complex_cnn"}, lambda: ComplexCNNRxEqualizer()),
        ({"transformer"}, lambda: TransformerRxEqualizer()),
        ({"tcn"}, lambda: TCNRxEqualizer()),
        ({"mamba"}, lambda: MambaRxEqualizer()),
        ({"mlp"}, lambda: MLPRxEqualizer()),
    )
    for aliases, build in registry:
        if name in aliases:
            return build().to(Config.DEVICE)
    raise ValueError(f"Unknown model: {name}")
2264
+
2265
+
2266
def iter_tensor_batches(x: torch.Tensor, y: torch.Tensor, batch_size: int, shuffle: bool):
    """Yield ``(x, y)`` mini-batches moved to ``Config.DEVICE``.

    Args:
        x: Inputs, batched along dim 0.
        y: Targets aligned with ``x`` along dim 0.
        batch_size: Maximum number of rows per batch; the last batch may be
            smaller.
        shuffle: Draw rows in a fresh random order when True.

    Yields:
        ``(xb, yb)`` tensors on ``Config.DEVICE``.
    """
    total = x.size(0)
    # Keep only the permutation. Gathering rows per batch avoids a second
    # full copy of the dataset while producing the same batches.
    order = torch.randperm(total) if shuffle else None
    non_blocking = Config.DEVICE.type == "cuda"
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        if order is None:
            xb, yb = x[start:end], y[start:end]
        else:
            picked = order[start:end]
            xb = x.index_select(0, picked)
            yb = y.index_select(0, picked)
        xb = xb.to(Config.DEVICE, non_blocking=non_blocking)
        yb = yb.to(Config.DEVICE, non_blocking=non_blocking)
        yield xb, yb
2277
+
2278
+
2279
@torch.inference_mode()
def evaluate_split(
    model: nn.Module,
    x: torch.Tensor,
    y: torch.Tensor,
    batch_size: int,
    scale_search: bool = False,
) -> Tuple[float, float, float, int, float]:
    """Evaluate ``model`` on one split and report loss, accuracy and BER.

    On a CUDA out-of-memory error the whole pass restarts with half the batch
    size, down to ``Config.MIN_BLOCK_SIZE``. For regression outputs,
    ``scale_search`` enables a search for a scalar gain applied to the
    predictions before they are mapped to classes.

    Returns:
        ``(mean_loss, symbol_accuracy, bit_error_rate, batch_size_used, scale)``.

    Raises:
        RuntimeError: Any non-OOM error, or an OOM that persists at the
            minimum batch size.
    """
    model.eval()
    criterion = build_criterion()
    bit_labels = BIT_LABELS.to(Config.DEVICE)
    current_batch_size = batch_size

    while True:
        # Fresh accumulators for every attempt; an OOM retry starts over.
        loss_sum = 0.0
        seen = 0
        pred_chunks: List[torch.Tensor] = []
        target_chunks: List[torch.Tensor] = []
        try:
            for xb, yb in iter_tensor_batches(x, y, batch_size=current_batch_size, shuffle=False):
                with autocast_context():
                    mark_cudagraph_step_begin()
                    preds = model(xb)
                    loss = prediction_loss(preds, yb, criterion)
                pred_chunks.append(preds.float().detach().cpu())
                target_chunks.append(yb.float().detach().cpu())
                rows = yb.size(0)
                loss_sum += loss.item() * rows
                seen += rows

            preds_all = torch.cat(pred_chunks, dim=0)
            targets_all = torch.cat(target_chunks, dim=0)
            target_classes = symbols_to_classes(targets_all.to(Config.DEVICE))
            if is_classifier_output(preds_all):
                scale = 1.0
                pred_classes = torch.argmax(preds_all.to(Config.DEVICE), dim=1)
            else:
                scale = find_best_symbol_scale(targets_all, preds_all) if scale_search else 1.0
                pred_classes = symbols_to_classes((preds_all * scale).to(Config.DEVICE))

            hits = (pred_classes == target_classes).sum().item()
            tx_bits = bit_labels[target_classes]
            rx_bits = bit_labels[pred_classes]
            wrong_bits = (tx_bits != rx_bits).sum().item()
            n_bits = tx_bits.numel()
            return (
                loss_sum / max(seen, 1),
                hits / max(seen, 1),
                wrong_bits / max(n_bits, 1),
                current_batch_size,
                scale,
            )
        except RuntimeError as error:
            if Config.DEVICE.type != "cuda" or not is_cuda_oom(error):
                raise
            if current_batch_size <= Config.MIN_BLOCK_SIZE:
                raise
            next_batch_size = max(current_batch_size // 2, Config.MIN_BLOCK_SIZE)
            log(f"eval | CUDA OOM at batch_size={current_batch_size}, retrying with {next_batch_size}")
            current_batch_size = next_batch_size
            torch.cuda.empty_cache()
2343
+
2344
+
2345
@torch.inference_mode()
def compute_split_file_metrics(
    model: nn.Module,
    data: Dict[str, Any],
    split: str,
    batch_size: int,
) -> Tuple[List[Dict[str, float]], int]:
    """Evaluate the model separately on every file of a split.

    For each span in ``data[f"{split}_file_spans"]`` this computes the BER of
    the unequalized RX centre symbols (after a scale search) and the BER of
    the equalizer output on the same rows.

    Returns:
        ``(rows, batch_size)``: one metrics dict per file, and the batch size
        in use after the last evaluation (it may have been reduced). An empty
        list and the input batch size are returned when the split has no
        spans.
    """
    spans = data.get(f"{split}_file_spans", [])
    if not spans:
        return [], batch_size

    features = data[f"{split}_x"]
    labels = data[f"{split}_y"]
    safe_batch_size = batch_size
    rows: List[Dict[str, float]] = []
    for span in spans:
        lo, hi = int(span["start"]), int(span["end"])
        tx_center, rx_center = span["tx_center"], span["rx_center"]

        # Baseline: raw RX decisions with the best scalar gain applied.
        baseline_scale = find_best_symbol_scale(tx_center, rx_center)
        tx_cls = symbols_to_classes(tx_center.to(Config.DEVICE))
        rx_cls = symbols_to_classes((rx_center * baseline_scale).to(Config.DEVICE))
        baseline_ber = calculate_ber_from_classes(tx_cls, rx_cls)

        loss, acc, ber, safe_batch_size, equalizer_scale = evaluate_split(
            model,
            features[lo:hi],
            labels[lo:hi],
            safe_batch_size,
            scale_search=True,
        )
        relative_gain = (1 - ber / baseline_ber) * 100 if baseline_ber > 0 else 0.0
        rows.append(
            {
                "file_idx": int(span["file_idx"]),
                "samples": int(hi - lo),
                "baseline_ber": float(baseline_ber),
                "baseline_scale": float(baseline_scale),
                "equalized_ber": float(ber),
                "equalizer_scale": float(equalizer_scale),
                "accuracy": float(acc),
                "loss": float(loss),
                "improvement_rel": float(relative_gain),
            }
        )
    return rows, safe_batch_size
2388
+
2389
+
2390
def add_file_metric_summary(metrics: Dict[str, Any], split: str, file_metrics: List[Dict[str, float]]):
    """Add per-file BER summary statistics for ``split`` to ``metrics`` in place.

    Writes the mean, population standard deviation and worst (maximum)
    equalized BER, the mean and worst baseline BER, and two
    ``"file_idx:ber"`` listings joined with ``;``. Does nothing when
    ``file_metrics`` is empty.
    """
    if not file_metrics:
        return

    prefix = f"{split}_file"
    equalized = np.asarray([row["equalized_ber"] for row in file_metrics], dtype=np.float64)
    baseline = np.asarray([row["baseline_ber"] for row in file_metrics], dtype=np.float64)

    metrics[f"{prefix}_equalized_ber_mean"] = float(equalized.mean())
    metrics[f"{prefix}_equalized_ber_std"] = float(equalized.std())
    metrics[f"{prefix}_equalized_ber_worst"] = float(equalized.max())
    metrics[f"{prefix}_baseline_ber_mean"] = float(baseline.mean())
    metrics[f"{prefix}_baseline_ber_worst"] = float(baseline.max())
    for kind in ("equalized", "baseline"):
        listing = [f"{int(row['file_idx'])}:{row[kind + '_ber']:.6e}" for row in file_metrics]
        metrics[f"{prefix}_{kind}_ber_by_file"] = ";".join(listing)
2406
+
2407
+
2408
@torch.inference_mode()
def compute_test_metrics(model: nn.Module, data: Dict[str, Any]) -> Dict[str, Any]:
    """Compare the equalizer against the unequalized baseline on the test split.

    The baseline BER comes from the context-trimmed raw RX symbols after a
    scale search; the equalized BER comes from ``evaluate_split`` with scale
    search enabled. The batch size is ``data["eval_batch_size"]`` when
    present, else ``Config.EVAL_BATCH_SIZE``.

    Returns:
        A dict with baseline and equalized BER and scales, accuracy, SER,
        test loss, the batch size that fit in memory, and absolute, relative
        (percent) and dB improvements.
    """
    eval_prefix = "test"
    tx_center = data[f"tx_{eval_prefix}_center"]
    rx_center = data[f"rx_{eval_prefix}_center"]

    baseline_scale = find_best_symbol_scale(tx_center, rx_center)
    tx_cls = symbols_to_classes(tx_center.to(Config.DEVICE))
    rx_cls = symbols_to_classes((rx_center * baseline_scale).to(Config.DEVICE))
    baseline_ber = calculate_ber_from_classes(tx_cls, rx_cls)

    eval_batch_size = data.get("eval_batch_size", Config.EVAL_BATCH_SIZE)
    test_loss, test_acc, test_ber, safe_eval_batch_size, equalizer_scale = evaluate_split(
        model, data[f"{eval_prefix}_x"], data[f"{eval_prefix}_y"], eval_batch_size, scale_search=True
    )

    relative_gain = (1 - test_ber / baseline_ber) * 100 if baseline_ber > 0 else 0.0
    # A zero equalized BER is reported as an infinite dB gain.
    gain_db = 10 * np.log10(baseline_ber / test_ber) if test_ber > 0 else float("inf")
    return {
        "eval_split": eval_prefix,
        "baseline_ber": baseline_ber,
        "baseline_scale": baseline_scale,
        "equalized_ber": test_ber,
        "equalizer_scale": equalizer_scale,
        "accuracy": test_acc,
        "ser": 1.0 - test_acc,
        "test_loss": test_loss,
        "safe_eval_batch_size": safe_eval_batch_size,
        "improvement_abs": baseline_ber - test_ber,
        "improvement_rel": relative_gain,
        "improvement_db": gain_db,
    }
2433
+
2434
+
2435
def build_optimizer_with_lr(model: nn.Module, lr: float) -> optim.Optimizer:
    """Build an Adam optimizer with an explicit learning rate.

    Weight decay comes from ``Config.WEIGHT_DECAY``. On CUDA the fused
    implementation is requested first, and the optimizer is rebuilt without
    ``fused`` if the constructor rejects that keyword with ``TypeError``.
    """
    settings = dict(lr=lr, weight_decay=Config.WEIGHT_DECAY)
    if Config.DEVICE.type == "cuda":
        settings["fused"] = True
    try:
        return optim.Adam(model.parameters(), **settings)
    except TypeError:
        # This Adam does not accept `fused`; retry without it.
        settings.pop("fused", None)
        return optim.Adam(model.parameters(), **settings)
2444
+
2445
+
2446
def get_efficient_kan_module(model: nn.Module) -> Optional[nn.Module]:
    """Return ``model.kan`` when it looks like an efficient-KAN network, else ``None``.

    The module qualifies when it has a non-empty ``layers`` collection in
    which every layer exposes both ``base_weight`` and ``spline_weight``.
    """
    kan = getattr(model, "kan", None)
    if kan is None or not hasattr(kan, "layers"):
        return None
    layers = list(kan.layers)
    if not layers:
        return None
    for layer in layers:
        if not (hasattr(layer, "base_weight") and hasattr(layer, "spline_weight")):
            return None
    return kan
2454
+
2455
+
2456
def efficient_kan_layer_sizes(kan: nn.Module) -> List[int]:
    """Return the layer widths of ``kan``: the first layer's input size, then every output size.

    An empty list is returned when ``kan.layers`` is empty.
    """
    layers = list(kan.layers)
    if not layers:
        return []
    sizes = [int(layers[0].in_features)]
    sizes.extend(int(layer.out_features) for layer in layers)
    return sizes
2461
+
2462
+
2463
@torch.no_grad()
def efficient_kan_edge_norm(layer: nn.Module) -> torch.Tensor:
    """Per-edge magnitude of a KAN layer.

    Adds the absolute base weight to the absolute scaled spline weight averaged over its
    third dimension. Both operands are detached and cast to float32 first.
    """
    base_magnitude = torch.abs(layer.base_weight.detach().float())
    spline_magnitude = torch.abs(layer.scaled_spline_weight.detach().float())
    return base_magnitude + spline_magnitude.mean(dim=2)
2468
+
2469
+
2470
@torch.no_grad()
def select_efficient_kan_hidden_units(kan: nn.Module, keep_ratio: float) -> List[torch.Tensor]:
    """Choose which hidden units to keep at each boundary between consecutive KAN layers.

    A unit's importance is ``sqrt(incoming * outgoing)``, where ``incoming`` is the mean edge
    norm of the edges feeding it (producer layer) and ``outgoing`` is the mean edge norm of
    the edges leaving it (consumer layer); both are clamped to at least ``1e-12``.

    The number kept is ``round(width * keep_ratio)``, raised to
    ``Config.KAN_STRUCTURAL_PRUNE_MIN_HIDDEN`` and then limited to ``[1, width]``.

    Returns:
        One sorted CPU index tensor per hidden boundary (``len(kan.layers) - 1`` entries).
    """
    kan_layers = list(kan.layers)
    kept_per_boundary: List[torch.Tensor] = []
    for producer, consumer in zip(kan_layers, kan_layers[1:]):
        width = int(producer.out_features)
        target = int(round(width * keep_ratio))
        target = max(Config.KAN_STRUCTURAL_PRUNE_MIN_HIDDEN, target)
        target = min(width, max(1, target))

        fan_in = efficient_kan_edge_norm(producer).mean(dim=1)
        fan_out = efficient_kan_edge_norm(consumer).mean(dim=0)
        score = torch.sqrt(fan_in.clamp_min(1e-12) * fan_out.clamp_min(1e-12))

        strongest = torch.topk(score, k=target, largest=True, sorted=False).indices
        # Sort so the surviving units keep their original relative order.
        kept_per_boundary.append(torch.sort(strongest.cpu()).values)
    return kept_per_boundary
2487
+
2488
+
2489
@torch.no_grad()
def copy_pruned_kan_layer(old_layer: nn.Module, new_layer: nn.Module, input_idx: torch.Tensor, output_idx: torch.Tensor):
    """Copy the selected slice of ``old_layer`` into ``new_layer`` in place.

    ``grid`` is sliced along dim 0 by ``input_idx``. ``base_weight``, ``spline_weight`` and,
    when both layers have it, ``spline_scaler`` are sliced along dim 0 by ``output_idx`` and
    along dim 1 by ``input_idx``. Each slice is moved to the destination tensor's device
    before being copied.
    """
    source_device = old_layer.base_weight.device
    input_idx = input_idx.to(source_device)
    output_idx = output_idx.to(source_device)

    def _kept_block(weight: torch.Tensor) -> torch.Tensor:
        return weight.index_select(0, output_idx).index_select(1, input_idx)

    kept_grid = old_layer.grid.index_select(0, input_idx)
    new_layer.grid.copy_(kept_grid.to(new_layer.grid.device))

    for attr in ("base_weight", "spline_weight"):
        destination = getattr(new_layer, attr)
        destination.copy_(_kept_block(getattr(old_layer, attr)).to(destination.device))

    if hasattr(old_layer, "spline_scaler") and hasattr(new_layer, "spline_scaler"):
        scaler_destination = new_layer.spline_scaler
        scaler_destination.copy_(_kept_block(old_layer.spline_scaler).to(scaler_destination.device))
2508
+
2509
+
2510
def structurally_prune_efficient_kan_model(model: nn.Module, keep_ratio: float) -> Tuple[nn.Module, List[int]]:
    """Replace ``model.kan`` with a narrower EfficientKAN that keeps the selected hidden units.

    Input and output widths are preserved; only hidden widths shrink. The new network is built
    on ``Config.DEVICE`` with the old network's ``grid_size`` and ``spline_order`` and the
    remaining hyper-parameters from ``Config``. Weights of the surviving units are then copied
    over layer by layer. ``model`` is modified in place.

    Returns:
        ``(model, new_sizes)`` where ``new_sizes`` are the layer widths of the pruned network.

    Raises:
        ImportError: ``efficient_kan`` could not be imported.
        ValueError: ``model`` has no prunable EfficientKAN under ``.kan``.
    """
    if EfficientKAN is None:
        raise ImportError("EfficientKAN is unavailable.")
    source_kan = get_efficient_kan_module(model)
    if source_kan is None:
        raise ValueError("Model does not expose a prunable EfficientKAN module via `.kan`.")

    source_layers = list(source_kan.layers)
    source_sizes = efficient_kan_layer_sizes(source_kan)
    kept_hidden = select_efficient_kan_hidden_units(source_kan, keep_ratio)

    # One index tensor per layer boundary: all inputs, the kept hidden units, all outputs.
    boundaries: List[torch.Tensor] = [torch.arange(source_sizes[0], dtype=torch.long)]
    boundaries.extend(kept_hidden)
    boundaries.append(torch.arange(source_sizes[-1], dtype=torch.long))
    new_sizes = [int(indices.numel()) for indices in boundaries]

    pruned_kan = EfficientKAN(
        layers_hidden=new_sizes,
        grid_size=int(source_kan.grid_size),
        spline_order=int(source_kan.spline_order),
        scale_noise=Config.EFFICIENT_KAN_SCALE_NOISE,
        scale_base=Config.EFFICIENT_KAN_SCALE_BASE,
        scale_spline=Config.EFFICIENT_KAN_SCALE_SPLINE,
        base_activation=nn.SiLU,
        grid_eps=Config.EFFICIENT_KAN_GRID_EPS,
        grid_range=Config.EFFICIENT_KAN_GRID_RANGE,
    )
    pruned_kan = pruned_kan.to(Config.DEVICE)

    layer_pairs = zip(source_layers, pruned_kan.layers, boundaries, boundaries[1:])
    for old_layer, new_layer, kept_inputs, kept_outputs in layer_pairs:
        copy_pruned_kan_layer(old_layer, new_layer, kept_inputs, kept_outputs)

    model.kan = pruned_kan
    return model, new_sizes
2541
+
2542
+
2543
def efficiency_score_from_metrics(metrics: Dict[str, Any]) -> float:
    """Score a model as ``relative_ber_gain ** Config.EFFICIENCY_SCORE_POWER / batch_time``.

    Reads ``baseline_ber``, ``equalized_ber`` and ``efficiency_batch_time_sec`` from
    ``metrics`` (each defaulting to 0.0). The relative gain is floored at 0. Returns 0.0 when
    the baseline BER or the batch time is not positive.
    """
    baseline_ber = float(metrics.get("baseline_ber", 0.0))
    equalized_ber = float(metrics.get("equalized_ber", 0.0))
    seconds_per_batch = float(metrics.get("efficiency_batch_time_sec", 0.0))

    if baseline_ber <= 0 or seconds_per_batch <= 0:
        return 0.0

    relative_gain = max((baseline_ber - equalized_ber) / baseline_ber, 0.0)
    powered_gain = relative_gain**Config.EFFICIENCY_SCORE_POWER
    return float(powered_gain / seconds_per_batch)
2551
+
2552
+
2553
@torch.inference_mode()
def measure_batch_inference_time(model: nn.Module, x: torch.Tensor) -> float:
    """Measure the average wall-clock time of one forward pass on a fixed-size batch.

    The first ``min(Config.EFFICIENCY_BATCH_SIZE, len(x))`` rows of ``x`` are moved to
    ``Config.DEVICE``. The model is run ``Config.EFFICIENCY_TIMING_WARMUP`` times untimed,
    then ``Config.EFFICIENCY_TIMING_REPEATS`` times under the timer. The model is put in eval
    mode and left there.

    The device queue is drained before the timer starts and before it stops, on both CUDA and
    MPS, so asynchronously dispatched kernels are included in the measurement.

    Returns:
        Seconds per forward pass, or 0.0 when the batch would be empty.
    """
    device_type = Config.DEVICE.type

    def _wait_for_device() -> None:
        # Both back-ends run kernels asynchronously; without this the timer would only
        # capture dispatch overhead.
        if device_type == "cuda":
            torch.cuda.synchronize()
        elif device_type == "mps":
            torch.mps.synchronize()

    model.eval()
    batch_size = min(Config.EFFICIENCY_BATCH_SIZE, x.size(0))
    if batch_size <= 0:
        return 0.0
    xb = x[:batch_size].to(Config.DEVICE, non_blocking=device_type == "cuda")

    def _forward_once() -> None:
        with autocast_context():
            mark_cudagraph_step_begin()
            _ = model(xb)

    for _ in range(Config.EFFICIENCY_TIMING_WARMUP):
        _forward_once()
    _wait_for_device()

    start = time.perf_counter()
    for _ in range(Config.EFFICIENCY_TIMING_REPEATS):
        _forward_once()
    _wait_for_device()
    elapsed = time.perf_counter() - start

    # Guard the division in case the repeat count is configured as zero.
    return elapsed / max(Config.EFFICIENCY_TIMING_REPEATS, 1)
2575
+
2576
+
2577
def add_efficiency_metrics(model: nn.Module, data: Dict[str, Any], metrics: Dict[str, Any]) -> Dict[str, Any]:
    """Time ``model`` on ``data["test_x"]`` and record the timing fields in ``metrics``.

    Adds ``efficiency_batch_size``, ``efficiency_batch_time_sec`` and ``efficiency_score`` to
    ``metrics`` in place and returns the same dictionary.
    """
    test_inputs = data["test_x"]
    seconds_per_batch = measure_batch_inference_time(model, test_inputs)
    timed_rows = min(Config.EFFICIENCY_BATCH_SIZE, test_inputs.size(0))

    metrics["efficiency_batch_size"] = int(timed_rows)
    metrics["efficiency_batch_time_sec"] = float(seconds_per_batch)
    # The score reads the two fields written just above, so it must be computed last.
    metrics["efficiency_score"] = efficiency_score_from_metrics(metrics)
    return metrics
2583
+
2584
+
2585
def fine_tune_model_for_pruning(
    model: nn.Module,
    data: Dict[str, Any],
    epochs: int,
    lr: float,
    eval_batch_size: int,
) -> Tuple[nn.Module, Dict[str, float], int]:
    """Briefly retrain a pruned model and restore the weights with the lowest validation BER.

    With ``epochs <= 0`` the model is only evaluated on the validation split. Otherwise each
    epoch cuts the training tensors into contiguous blocks of ``Config.TRAIN_BLOCK_SIZE``
    rows, visits the blocks in random order and takes one optimizer step per block. After
    every epoch the validation split is evaluated with scale search, and the model state is
    snapshotted to CPU whenever the validation BER improves.

    Returns:
        ``(model, metrics, eval_batch_size)``. ``metrics`` has the keys ``val_loss``,
        ``val_acc``, ``val_ber`` and ``val_scale``; ``eval_batch_size`` is the value last
        returned by ``evaluate_split``.
    """

    def _validate(batch_size: int) -> Tuple[Dict[str, float], int]:
        loss_value, accuracy, ber, batch_size, scale = evaluate_split(
            model, data["val_x"], data["val_y"], batch_size, scale_search=True
        )
        report = {"val_loss": loss_value, "val_acc": accuracy, "val_ber": ber, "val_scale": scale}
        return report, batch_size

    def _snapshot() -> Dict[str, torch.Tensor]:
        return {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}

    if epochs <= 0:
        report, eval_batch_size = _validate(eval_batch_size)
        return model, report, eval_batch_size

    optimizer = build_optimizer_with_lr(model, lr)
    criterion = build_criterion()
    scaler = torch.amp.GradScaler("cuda", enabled=Config.DEVICE.type == "cuda" and Config.USE_AMP)
    best_state = _snapshot()
    best_metrics = {"val_loss": float("inf"), "val_acc": 0.0, "val_ber": float("inf"), "val_scale": 1.0}
    block_size = Config.TRAIN_BLOCK_SIZE

    for epoch_number in range(1, epochs + 1):
        model.train()
        loss_sum = 0.0
        rows_seen = 0
        train_rows = data["train_x"].size(0)
        block_count = (train_rows + block_size - 1) // block_size

        for block_idx in torch.randperm(block_count).tolist():
            lo = block_idx * block_size
            hi = min(lo + block_size, train_rows)
            async_copy = Config.DEVICE.type == "cuda"
            xb = data["train_x"][lo:hi].to(Config.DEVICE, non_blocking=async_copy)
            yb = data["train_y"][lo:hi].to(Config.DEVICE, non_blocking=async_copy)

            optimizer.zero_grad(set_to_none=True)
            with autocast_context():
                mark_cudagraph_step_begin()
                preds = model(xb)
                loss = prediction_loss(preds, yb, criterion) + compute_model_regularization(model)
            scaler.scale(loss).backward()
            if Config.GRAD_CLIP_NORM > 0:
                # Gradients must be unscaled before their norm is clipped.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), Config.GRAD_CLIP_NORM)
            scaler.step(optimizer)
            scaler.update()

            block_rows = yb.size(0)
            loss_sum += loss.item() * block_rows
            rows_seen += block_rows

        report, eval_batch_size = _validate(eval_batch_size)
        val_ber = report["val_ber"]
        if val_ber < best_metrics["val_ber"]:
            best_metrics = report
            best_state = _snapshot()
        log(
            f"prune fine-tune | epoch {epoch_number:3d}/{epochs} | "
            f"train {loss_sum / max(rows_seen, 1):.6f} | val_ber {val_ber:.6e}"
        )

    model.load_state_dict(best_state)
    return model, best_metrics, eval_batch_size
2643
+
2644
+
2645
def should_replace_by_pruned(candidate_metrics: Dict[str, Any], best_metrics: Dict[str, Any]) -> bool:
    """Decide whether a pruning candidate beats the current best model.

    The criterion comes from ``Config.KAN_STRUCTURAL_PRUNE_SELECT_BY``:

    * ``"val_ber"``: strictly lower ``prune_val_ber`` (the incumbent falls back to
      ``best_val_ber``, then to infinity).
    * ``"test_ber"``: strictly lower ``equalized_ber``.
    * anything else: strictly higher ``efficiency_score`` (missing scores count as 0.0).
    """
    criterion = Config.KAN_STRUCTURAL_PRUNE_SELECT_BY
    worst = float("inf")

    if criterion == "val_ber":
        challenger = float(candidate_metrics.get("prune_val_ber", worst))
        incumbent = float(best_metrics.get("prune_val_ber", best_metrics.get("best_val_ber", worst)))
        return challenger < incumbent

    if criterion == "test_ber":
        challenger = float(candidate_metrics["equalized_ber"])
        incumbent = float(best_metrics["equalized_ber"])
        return challenger < incumbent

    challenger = float(candidate_metrics.get("efficiency_score", 0.0))
    incumbent = float(best_metrics.get("efficiency_score", 0.0))
    return challenger > incumbent
2654
+
2655
+
2656
def maybe_prune_efficient_kan_model(
    model: nn.Module,
    model_name: str,
    data: Dict[str, Any],
    base_metrics: Dict[str, Any],
    eval_batch_size: int,
) -> Tuple[nn.Module, Dict[str, Any], int]:
    """Try structurally pruned variants of a trained KAN model and keep the best one.

    Nothing happens unless ``Config.KAN_STRUCTURAL_PRUNE_AFTER_TRAINING`` is set and the model
    exposes a prunable EfficientKAN under ``.kan``. Otherwise, for every ratio in
    ``Config.KAN_STRUCTURAL_PRUNE_KEEP_RATIOS`` a deep copy of the original model is pruned,
    fine-tuned, evaluated and timed. A candidate becomes the new best when
    ``should_replace_by_pruned`` says so. All rows, including the unpruned reference, are
    written to ``<model_name>_pruning_candidates.csv`` in ``Config.OUT_DIR``.

    ``base_metrics`` is extended in place with efficiency and pruning fields.

    Returns:
        ``(best_model, best_metrics, eval_batch_size)``.
    """
    if not Config.KAN_STRUCTURAL_PRUNE_AFTER_TRAINING:
        return model, base_metrics, eval_batch_size
    if get_efficient_kan_module(model) is None:
        log(f"{model_name} | structural KAN pruning skipped: model has no prunable EfficientKAN `.kan`")
        return model, base_metrics, eval_batch_size

    # Reference row: the unpruned model, timed the same way as every candidate.
    base_metrics = add_efficiency_metrics(model, data, base_metrics)
    base_metrics.update(
        {
            "pruned": False,
            "prune_keep_ratio": 1.0,
            "prune_layer_sizes": str(efficient_kan_layer_sizes(get_efficient_kan_module(model))),
        }
    )
    base_metrics["prune_val_ber"] = base_metrics.get("best_val_ber", float("inf"))

    winner_model = model
    winner_metrics = dict(base_metrics)
    report_rows: List[Dict[str, Any]] = [
        {**base_metrics, "model_type": model_name, "prune_candidate": "unpruned"}
    ]

    for keep_ratio in Config.KAN_STRUCTURAL_PRUNE_KEEP_RATIOS:
        # Every candidate starts from the original trained weights, not from the previous one.
        candidate = copy.deepcopy(model).to(Config.DEVICE)
        try:
            candidate, layer_sizes = structurally_prune_efficient_kan_model(candidate, keep_ratio)
        except Exception as exc:
            log(f"{model_name} | prune keep_ratio={keep_ratio:.2f} skipped: {exc}")
            continue

        candidate_params = count_trainable_parameters(candidate)
        log(
            f"{model_name} | prune candidate keep_ratio={keep_ratio:.2f} | "
            f"layers {layer_sizes} | params {candidate_params:,}"
        )

        candidate, val_metrics, eval_batch_size = fine_tune_model_for_pruning(
            candidate,
            data,
            Config.KAN_STRUCTURAL_PRUNE_FINE_TUNE_EPOCHS,
            Config.KAN_STRUCTURAL_PRUNE_FINE_TUNE_LR,
            eval_batch_size,
        )

        candidate_metrics = compute_test_metrics(candidate, {**data, "eval_batch_size": eval_batch_size})
        candidate_metrics.update(
            {
                "trainable_params": candidate_params,
                "pruned": True,
                "prune_keep_ratio": float(keep_ratio),
                "prune_layer_sizes": str(layer_sizes),
                "prune_val_loss": float(val_metrics["val_loss"]),
                "prune_val_acc": float(val_metrics["val_acc"]),
                "prune_val_ber": float(val_metrics["val_ber"]),
            }
        )
        candidate_metrics = add_efficiency_metrics(candidate, data, candidate_metrics)
        report_rows.append(
            {**candidate_metrics, "model_type": model_name, "prune_candidate": f"keep_{keep_ratio:.2f}"}
        )
        log(
            f"{model_name} | prune keep_ratio={keep_ratio:.2f} | "
            f"test_ber {candidate_metrics['equalized_ber']:.6e} | "
            f"batch16k {candidate_metrics['efficiency_batch_time_sec']:.6f}s | "
            f"score {candidate_metrics['efficiency_score']:.3f}"
        )

        if should_replace_by_pruned(candidate_metrics, winner_metrics):
            winner_model = candidate
            winner_metrics = dict(candidate_metrics)
        elif Config.DEVICE.type == "cuda":
            # Rejected candidate: release its GPU memory before building the next one.
            candidate.to("cpu")
            del candidate
            torch.cuda.empty_cache()

    prune_path = Config.OUT_DIR / f"{model_name}_pruning_candidates.csv"
    pd.DataFrame(report_rows).to_csv(prune_path, index=False)
    log(
        f"{model_name} | pruning selected keep_ratio={winner_metrics.get('prune_keep_ratio', 1.0)} | "
        f"test_ber {winner_metrics['equalized_ber']:.6e} | "
        f"score {winner_metrics.get('efficiency_score', 0.0):.3f} | saved {prune_path}"
    )
    return winner_model, winner_metrics, eval_batch_size
2730
+
2731
+
2732
def plot_results(history: Dict, eval_results: Dict, model_name: str):
    """Draw the 2x3 training/evaluation dashboard for one model and save it as a PNG.

    Top row: loss curves, accuracy curves, learning-rate schedule.
    Bottom row: baseline vs equalized BER, improvement metrics, combined training dynamics.
    Test curves are drawn only when the corresponding history list is non-empty. The dB gain
    bar is capped at 20.

    The figure is written to ``Config.OUT_DIR / "ber_results_<model_name lowercased>.png"``
    and then closed.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    (loss_ax, acc_ax, lr_ax), (ber_ax, gain_ax, dynamics_ax) = axes

    train_epochs = list(range(1, len(history["train_loss"]) + 1))
    val_epochs = history["val_epochs"]
    test_epochs = history["test_epochs"]

    # --- Loss curves ---
    loss_ax.plot(train_epochs, history["train_loss"], label="Train", linewidth=2, alpha=0.8)
    loss_ax.plot(val_epochs, history["val_loss"], label="Val", linewidth=2)
    if history["test_loss"]:
        loss_ax.plot(test_epochs, history["test_loss"], label="Test", linewidth=2, linestyle="--")
    loss_ax.set_title(f"{model_name} - Loss Curves", fontweight="bold")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_yscale("log")
    loss_ax.grid(alpha=0.3)
    loss_ax.legend()

    # --- Accuracy ---
    acc_ax.plot(val_epochs, history["val_acc"], color="green", linewidth=2, label="Val Acc")
    if history["test_acc"]:
        acc_ax.plot(test_epochs, history["test_acc"], color="#0984e3", linewidth=2, linestyle="--", label="Test Acc")
    acc_ax.set_title(f"{model_name} - Accuracy", fontweight="bold")
    acc_ax.set_xlabel("Epoch")
    acc_ax.set_ylabel("Accuracy")
    acc_ax.set_ylim([0, 1])
    acc_ax.grid(alpha=0.3)
    acc_ax.legend()

    # --- Learning-rate schedule ---
    lr_ax.plot(train_epochs, history["lr"], color="red", linewidth=2)
    lr_ax.set_title("Learning Rate Schedule", fontweight="bold")
    lr_ax.set_xlabel("Epoch")
    lr_ax.set_ylabel("LR")
    lr_ax.set_yscale("log")
    lr_ax.grid(alpha=0.3)

    # --- Baseline vs equalized BER ---
    ber_labels = ["Baseline", f"{model_name} EQ"]
    ber_values = [eval_results["baseline_ber"], eval_results["equalized_ber"]]
    ber_ax.bar(
        ber_labels,
        ber_values,
        color=["#ff7675", "#55efc4"],
        edgecolor="black",
        linewidth=1.5,
    )
    ber_ax.set_title("BER Comparison (log scale)", fontweight="bold")
    ber_ax.set_ylabel("BER")
    ber_ax.set_yscale("log")
    ber_ax.grid(axis="y", alpha=0.3)

    # --- Improvement metrics ---
    gain_labels = ["Abs Reduction\n(pp)", "Rel Improvement\n(%)", "SNR Gain\n(dB)"]
    gain_values = [
        eval_results["improvement_abs"] * 100,
        eval_results["improvement_rel"],
        min(eval_results["improvement_db"], 20),  # keeps an infinite gain drawable
    ]
    gain_ax.bar(gain_labels, gain_values, color=["#74b9ff", "#a29bfe", "#fd79a8"], edgecolor="black", linewidth=1.5)
    gain_ax.set_title("Improvement Metrics", fontweight="bold")
    gain_ax.set_ylabel("Value")
    gain_ax.grid(axis="y", alpha=0.3)

    # --- Combined dynamics, everything scaled by 100 ---
    scaled_train_loss = [loss * 100 for loss in history["train_loss"]]
    scaled_val_acc = [acc * 100 for acc in history["val_acc"]]
    dynamics_ax.plot(train_epochs, scaled_train_loss, label="Train Loss x100", alpha=0.7, linewidth=1)
    dynamics_ax.plot(val_epochs, scaled_val_acc, label="Val Acc (%)", alpha=0.7, linewidth=2)
    if history["test_ber"]:
        scaled_test_ber = [ber * 100 for ber in history["test_ber"]]
        dynamics_ax.plot(test_epochs, scaled_test_ber, label="Test BER (%)", alpha=0.9, linewidth=2)
    dynamics_ax.set_title("Training Dynamics", fontweight="bold")
    dynamics_ax.set_xlabel("Epoch")
    dynamics_ax.set_ylabel("Value")
    dynamics_ax.grid(alpha=0.3)
    dynamics_ax.legend()

    plt.suptitle(f"{model_name} Equalizer Performance", fontsize=16, fontweight="bold")
    plt.tight_layout()
    out_path = Config.OUT_DIR / f"ber_results_{model_name.lower()}.png"
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
2804
+
2805
+
2806
def plot_architecture_summary(results: List[Dict[str, float]]):
    """Compare architectures side by side and export the comparison table.

    ``results`` is turned into a DataFrame sorted by ascending ``equalized_ber``. Two bar
    charts (equalized BER on a log axis, relative improvement in percent) are saved to
    ``architecture_summary.png``. The sorted table is then written to
    ``architecture_comparison.csv``. Both files go to ``Config.OUT_DIR``.
    """
    summary = pd.DataFrame(results)
    summary = summary.sort_values("equalized_ber").reset_index(drop=True)
    model_labels = summary["model_type"]

    fig, (ber_ax, gain_ax) = plt.subplots(1, 2, figsize=(14, 5))

    ber_ax.bar(model_labels, summary["equalized_ber"], color="#00b894", edgecolor="black")
    ber_ax.set_title("Equalized BER by Architecture", fontweight="bold")
    ber_ax.set_ylabel("BER")
    ber_ax.set_yscale("log")
    ber_ax.grid(axis="y", alpha=0.3)

    gain_ax.bar(model_labels, summary["improvement_rel"], color="#0984e3", edgecolor="black")
    gain_ax.set_title("Relative BER Improvement", fontweight="bold")
    gain_ax.set_ylabel("Improvement (%)")
    gain_ax.grid(axis="y", alpha=0.3)

    plt.tight_layout()
    figure_path = Config.OUT_DIR / "architecture_summary.png"
    plt.savefig(figure_path, dpi=150, bbox_inches="tight")
    plt.close(fig)

    table_path = Config.OUT_DIR / "architecture_comparison.csv"
    summary.to_csv(table_path, index=False)
2827
+
2828
+
2829
def plot_sweep_per_model(df: pd.DataFrame, x_col: str, x_label: str, filename_prefix: str):
    """Save one BER-vs-``x_col`` line chart per model listed in ``Config.MODEL_TYPES``.

    Each chart uses a log-scaled BER axis and is written to
    ``Config.OUT_DIR / "<filename_prefix>_<model_name>.png"``.
    """
    for model_name in Config.MODEL_TYPES:
        is_this_model = df["model_type"] == model_name
        ordered = df[is_this_model].sort_values(x_col)

        fig, ax = plt.subplots(figsize=(8, 5))
        ax.plot(ordered[x_col], ordered["equalized_ber"], marker="o", linewidth=2)
        ax.set_title(f"{model_name.upper()} - BER vs {x_label}", fontweight="bold")
        ax.set_xlabel(x_label)
        ax.set_ylabel("BER")
        ax.set_yscale("log")
        ax.grid(alpha=0.3)
        plt.tight_layout()

        target = Config.OUT_DIR / f"{filename_prefix}_{model_name}.png"
        plt.savefig(target, dpi=150, bbox_inches="tight")
        plt.close(fig)
2842
+
2843
+
2844
def plot_sweep_overlay(df: pd.DataFrame, x_col: str, x_label: str, filename: str):
    """Overlay the BER-vs-``x_col`` curves of all ``Config.MODEL_TYPES`` in a single chart.

    Legend entries are the upper-cased model names. The figure is saved to
    ``Config.OUT_DIR / filename``.
    """
    fig, ax = plt.subplots(figsize=(9, 6))

    for model_name in Config.MODEL_TYPES:
        curve = df[df["model_type"] == model_name].sort_values(x_col)
        x_values = curve[x_col]
        ber_values = curve["equalized_ber"]
        ax.plot(x_values, ber_values, marker="o", linewidth=2, label=model_name.upper())

    ax.set_title(f"BER vs {x_label} - All Models", fontweight="bold")
    ax.set_xlabel(x_label)
    ax.set_ylabel("BER")
    ax.set_yscale("log")
    ax.grid(alpha=0.3)
    ax.legend()

    plt.tight_layout()
    target = Config.OUT_DIR / filename
    plt.savefig(target, dpi=150, bbox_inches="tight")
    plt.close(fig)
2858
+
2859
+
2860
def plot_efficient_kan_sweep(df: pd.DataFrame, x_col: str, x_label: str, filename: str, title: str):
    """Plot BER against ``x_col`` as one line per model in ``Config.EFFICIENT_KAN_SWEEP_MODELS``.

    Models with no rows in ``df`` are left out. The chart uses a log BER axis and is saved
    to ``Config.OUT_DIR / filename``.
    """
    fig, ax = plt.subplots(figsize=(9, 6))

    for model_name in Config.EFFICIENT_KAN_SWEEP_MODELS:
        curve = df[df["model_type"] == model_name].sort_values(x_col)
        if not curve.empty:
            ax.plot(curve[x_col], curve["equalized_ber"], marker="o", linewidth=2, label=model_name)

    ax.set_title(title, fontweight="bold")
    ax.set_xlabel(x_label)
    ax.set_ylabel("BER")
    ax.set_yscale("log")
    ax.grid(alpha=0.3)
    ax.legend()

    plt.tight_layout()
    target = Config.OUT_DIR / filename
    plt.savefig(target, dpi=150, bbox_inches="tight")
    plt.close(fig)
2876
+
2877
+
2878
def plot_efficient_kan_tradeoff(df: pd.DataFrame, x_col: str, x_label: str, filename: str, title: str):
    """Scatter BER against ``x_col`` with one point set per ``Config.EFFICIENT_KAN_SWEEP_MODELS`` entry.

    Models with no rows in ``df`` are left out. The chart uses a log BER axis and is saved
    to ``Config.OUT_DIR / filename``.
    """
    fig, ax = plt.subplots(figsize=(9, 6))

    for model_name in Config.EFFICIENT_KAN_SWEEP_MODELS:
        points = df[df["model_type"] == model_name]
        if not points.empty:
            ax.scatter(points[x_col], points["equalized_ber"], s=60, label=model_name)

    ax.set_title(title, fontweight="bold")
    ax.set_xlabel(x_label)
    ax.set_ylabel("BER")
    ax.set_yscale("log")
    ax.grid(alpha=0.3)
    ax.legend()

    plt.tight_layout()
    target = Config.OUT_DIR / filename
    plt.savefig(target, dpi=150, bbox_inches="tight")
    plt.close(fig)
2894
+
2895
+
2896
def plot_experiment_lines(
    df: pd.DataFrame,
    x_col: str,
    x_label: str,
    filename: str,
    title: str,
    y_col: str = "equalized_ber",
    y_label: str = "BER",
):
    """Plot ``y_col`` against ``x_col`` as one line per ``model_type`` group.

    Args:
        df: Experiment rows; must contain ``model_type``, ``x_col`` and ``y_col``.
        x_col: Column used for the x axis; each group is sorted by it before plotting.
        x_label: Label of the x axis.
        filename: File name inside ``Config.OUT_DIR`` to save the figure under.
        title: Chart title.
        y_col: Column used for the y axis.
        y_label: Label of the y axis. The default matches the default ``y_col``; pass a
            different label when plotting another metric so the axis is not mislabelled.

    The y axis is log-scaled.
    """
    fig, ax = plt.subplots(figsize=(9, 6))
    for model_name, model_df in df.groupby("model_type"):
        model_df = model_df.sort_values(x_col)
        ax.plot(model_df[x_col], model_df[y_col], marker="o", linewidth=2, label=model_name)
    ax.set_title(title, fontweight="bold")
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_yscale("log")
    ax.grid(alpha=0.3)
    ax.legend()
    plt.tight_layout()
    plt.savefig(Config.OUT_DIR / filename, dpi=150, bbox_inches="tight")
    plt.close(fig)
2917
+
2918
+
2919
def plot_experiment_complexity(df: pd.DataFrame, filename: str, title: str):
    """Plot BER against trainable parameter count, one connected line per ``model_type``.

    Each group is sorted by ``trainable_params`` before plotting. The points are joined by a
    line, so rows arriving in run order (for example several sweeps mixed in one frame) would
    otherwise make the line double back on itself.

    Both axes are log-scaled. The figure is saved to ``Config.OUT_DIR / filename``.
    """
    fig, ax = plt.subplots(figsize=(9, 6))
    for model_name, model_df in df.groupby("model_type"):
        model_df = model_df.sort_values("trainable_params")
        ax.plot(
            model_df["trainable_params"],
            model_df["equalized_ber"],
            marker="o",
            linewidth=1.5,
            linestyle="-",
            label=model_name,
        )
    ax.set_title(title, fontweight="bold")
    ax.set_xlabel("Trainable Parameters")
    ax.set_ylabel("BER")
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.grid(alpha=0.3)
    ax.legend()
    plt.tight_layout()
    plt.savefig(Config.OUT_DIR / filename, dpi=150, bbox_inches="tight")
    plt.close(fig)
2940
+
2941
+
2942
def hidden_overrides(hidden_dim: int) -> Dict[str, int]:
    """Build the ``Config`` overrides that set every hidden-width option to ``hidden_dim``."""
    width_options = ("HIDDEN_DIM", "EFFICIENT_KAN_HIDDEN_DIM", "FASTKAN_HIDDEN_DIM")
    return dict.fromkeys(width_options, hidden_dim)
2948
+
2949
+
2950
def layer_overrides(layer_count: int) -> Dict[str, int]:
    """Build the ``Config`` overrides that set every layer-count option to ``layer_count``."""
    depth_options = ("EFFICIENT_KAN_LAYERS", "FASTKAN_LAYERS", "MLP_LAYERS")
    return dict.fromkeys(depth_options, layer_count)
2956
+
2957
+
2958
def run_model_with_overrides(model_name: str, max_test_files: int, **overrides) -> Dict[str, Any]:
    """Train and evaluate one model with temporary ``Config`` overrides.

    Each keyword in ``overrides`` is written onto ``Config`` for the duration of the run.
    ``Config.SEQ_LEN`` is then recomputed from ``Config.CONTEXT_K`` and ``Config.SAVE_BEST``
    is forced to ``False``. Data is prepared with ``max_test_files`` and the model is trained
    through ``train_one_model``.

    On exit, whether normal or by exception, every attribute this function may have changed
    is restored. That covers the fixed list of tracked keys and any other key passed in
    ``overrides``, so an override outside the tracked list cannot leak into later runs.
    An override attribute that did not exist beforehand is removed again.

    Args:
        model_name: Model identifier passed to ``train_one_model``.
        max_test_files: Passed to ``prepare_data``.
        **overrides: ``Config`` attribute names mapped to their temporary values.

    Returns:
        The results dictionary produced by ``train_one_model``.
    """
    tracked_keys = [
        "CONTEXT_K",
        "SEQ_LEN",
        "HIDDEN_DIM",
        "LSTM_HIDDEN",
        "SAVE_BEST",
        "EPOCHS",
        "LEARNING_RATE",
        "EFFICIENT_KAN_HIDDEN_DIM",
        "EFFICIENT_KAN_LAYERS",
        "EFFICIENT_KAN_GRID_SIZE",
        "EFFICIENT_KAN_SPLINE_ORDER",
        "FASTKAN_HIDDEN_DIM",
        "FASTKAN_LAYERS",
        "FASTKAN_NUM_GRIDS",
        "MLP_LAYERS",
        "COMPUTE_PER_FILE_METRICS",
    ]
    missing = object()  # marks override keys that Config did not define beforehand
    previous = {key: getattr(Config, key) for key in tracked_keys}
    for key in overrides:
        if key not in previous:
            previous[key] = getattr(Config, key, missing)
    try:
        for key, value in overrides.items():
            setattr(Config, key, value)
        # Derived / forced settings always win over whatever the caller passed.
        Config.SEQ_LEN = 2 * Config.CONTEXT_K + 1
        Config.SAVE_BEST = False
        data = prepare_data(max_test_files=max_test_files)
        model, _, results = train_one_model(model_name, data)
        if Config.DEVICE.type == "cuda":
            model.to("cpu")
            del model
            torch.cuda.empty_cache()
        return results
    finally:
        for key, value in previous.items():
            if value is missing:
                if key in vars(Config):
                    delattr(Config, key)
            else:
                setattr(Config, key, value)
2993
+
2994
+
2995
def plot_fastkan_classifier_score(df: pd.DataFrame, filename: str):
    """Scatter efficiency score against trainable parameter count, one series per ``model_type``.

    The x axis is log-scaled. The figure is saved to ``Config.OUT_DIR / filename``.
    """
    fig, ax = plt.subplots(figsize=(9, 6))

    for model_name, group in df.groupby("model_type"):
        param_counts = group["trainable_params"]
        scores = group["efficiency_score"]
        ax.scatter(param_counts, scores, s=70, label=model_name)

    ax.set_title("FastKAN Classifiers: BER-Speed Efficiency", fontweight="bold")
    ax.set_xlabel("Trainable Parameters")
    ax.set_ylabel("Efficiency Score")
    ax.set_xscale("log")
    ax.grid(alpha=0.3)
    ax.legend()

    plt.tight_layout()
    target = Config.OUT_DIR / filename
    plt.savefig(target, dpi=150, bbox_inches="tight")
    plt.close(fig)
3013
+
3014
+
3015
def run_fastkan_classifier_sweep():
    """Sweep hidden width, RBF grid count and depth for the FastKAN classifier models.

    For every model in ``Config.FASTKAN_CLASSIFIER_SWEEP_MODELS`` one parameter is varied at a
    time; the others stay at the ``Config`` values read when the function starts. The
    accumulated result table is rewritten to ``fastkan_classifier_sweep_all.csv`` after every
    run. At the end, one CSV and one line chart per sweep are saved, plus a BER-vs-complexity
    chart and an efficiency-score chart, all in ``Config.OUT_DIR``.
    """
    log("\nRunning compact FastKAN/RBF-KAN classifier sweep...")
    collected: List[Dict[str, Any]] = []
    output_path = Config.OUT_DIR / "fastkan_classifier_sweep_all.csv"
    defaults = {
        "EPOCHS": Config.FASTKAN_CLASSIFIER_SWEEP_EPOCHS,
        "COMPUTE_PER_FILE_METRICS": False,
        "FASTKAN_HIDDEN_DIM": Config.FASTKAN_HIDDEN_DIM,
        "FASTKAN_LAYERS": Config.FASTKAN_LAYERS,
        "FASTKAN_NUM_GRIDS": Config.FASTKAN_NUM_GRIDS,
    }

    def run_case(sweep_name: str, model_name: str, **overrides):
        settings = {**defaults, **overrides}
        hidden = settings["FASTKAN_HIDDEN_DIM"]
        grids = settings["FASTKAN_NUM_GRIDS"]
        depth = settings["FASTKAN_LAYERS"]
        log(
            f"fastkan sweep | {sweep_name} | {model_name} | "
            f"hidden={hidden} | grids={grids} | "
            f"layers={depth}"
        )
        results = run_model_with_overrides(
            model_name,
            max_test_files=Config.FASTKAN_CLASSIFIER_SWEEP_TEST_FILES,
            **settings,
        )
        record: Dict[str, Any] = {
            "sweep": sweep_name,
            "model_type": model_name,
            "hidden_dim": hidden,
            "num_grids": grids,
            "layers": depth,
        }
        record.update(results)
        collected.append(record)
        # Rewrite the full table after every run so partial results survive an interruption.
        pd.DataFrame(collected).to_csv(output_path, index=False)

    swept_parameters = [
        ("hidden_dim", "FASTKAN_HIDDEN_DIM", Config.FASTKAN_CLASSIFIER_HIDDEN_VALUES),
        ("num_grids", "FASTKAN_NUM_GRIDS", Config.FASTKAN_CLASSIFIER_GRID_VALUES),
        ("layers", "FASTKAN_LAYERS", Config.FASTKAN_CLASSIFIER_LAYER_VALUES),
    ]
    for model_name in Config.FASTKAN_CLASSIFIER_SWEEP_MODELS:
        for sweep_name, config_key, values in swept_parameters:
            for value in values:
                run_case(sweep_name, model_name, **{config_key: value})

    df = pd.DataFrame(collected)
    chart_specs = [
        ("hidden_dim", "hidden_dim", "Hidden Dimension"),
        ("num_grids", "num_grids", "RBF Grid Count"),
        ("layers", "layers", "FastKAN Layers"),
    ]
    for sweep_name, x_col, x_label in chart_specs:
        sweep_df = df[df["sweep"] == sweep_name].copy()
        stem = f"fastkan_classifier_ber_vs_{sweep_name}"
        sweep_df.to_csv(Config.OUT_DIR / f"{stem}.csv", index=False)
        plot_experiment_lines(
            sweep_df,
            x_col=x_col,
            x_label=x_label,
            filename=f"{stem}.png",
            title=f"FastKAN Classifier BER vs {x_label}",
        )

    plot_experiment_complexity(
        df,
        filename="fastkan_classifier_ber_vs_complexity.png",
        title="FastKAN Classifiers: BER vs Complexity",
    )
    plot_fastkan_classifier_score(df, "fastkan_classifier_efficiency_score.png")
    log(f"Saved FastKAN classifier sweep: {output_path}")
3081
+
3082
+
3083
def run_sweep_experiments():
    """Entry point for sweep experiments.

    Dispatches to the first enabled specialised sweep (FastKAN classifier sweep, KAN
    experiment suite, EfficientKAN sweep) and returns. When none is enabled, runs the generic
    sweep: for every model in ``Config.MODEL_TYPES`` it varies the context window over
    ``Config.WINDOW_SWEEP_VALUES`` and the hidden size over ``Config.HIDDEN_SWEEP_VALUES``,
    then writes ``ber_vs_window.csv`` / ``ber_vs_hidden.csv`` and the per-model and overlay
    charts to ``Config.OUT_DIR``.
    """
    if Config.RUN_FASTKAN_CLASSIFIER_SWEEP:
        run_fastkan_classifier_sweep()
        return
    if Config.RUN_KAN_EXPERIMENT_SUITE:
        run_kan_experiment_suite()
        return
    if Config.RUN_EFFICIENT_KAN_SWEEP:
        run_efficient_kan_sweep_experiments()
        return

    log("\nRunning BER sweep experiments on one test file...")
    window_records: List[Dict[str, float]] = []
    hidden_records: List[Dict[str, float]] = []

    for model_name in Config.MODEL_TYPES:
        # Context-window sweep.
        for context_k in Config.WINDOW_SWEEP_VALUES:
            log(f"sweep | {model_name} | window={context_k}")
            outcome = run_model_with_overrides(
                model_name,
                max_test_files=Config.SWEEP_TEST_FILES,
                CONTEXT_K=context_k,
            )
            record = {"model_type": model_name, "context_k": context_k, "seq_len": 2 * context_k + 1}
            record.update(outcome)
            window_records.append(record)

        # Hidden-size sweep; the LSTM width follows the generic hidden width.
        for hidden_size in Config.HIDDEN_SWEEP_VALUES:
            log(f"sweep | {model_name} | hidden={hidden_size}")
            outcome = run_model_with_overrides(
                model_name,
                max_test_files=Config.SWEEP_TEST_FILES,
                HIDDEN_DIM=hidden_size,
                LSTM_HIDDEN=hidden_size,
            )
            record = {"model_type": model_name, "hidden_size": hidden_size}
            record.update(outcome)
            hidden_records.append(record)

    window_df = pd.DataFrame(window_records)
    hidden_df = pd.DataFrame(hidden_records)
    window_df.to_csv(Config.OUT_DIR / "ber_vs_window.csv", index=False)
    hidden_df.to_csv(Config.OUT_DIR / "ber_vs_hidden.csv", index=False)

    window_axis = {"x_col": "seq_len", "x_label": "Window Size"}
    hidden_axis = {"x_col": "hidden_size", "x_label": "Hidden Size"}
    plot_sweep_per_model(window_df, filename_prefix="ber_vs_window", **window_axis)
    plot_sweep_per_model(hidden_df, filename_prefix="ber_vs_hidden", **hidden_axis)
    plot_sweep_overlay(window_df, filename="ber_vs_window_overlay.png", **window_axis)
    plot_sweep_overlay(hidden_df, filename="ber_vs_hidden_overlay.png", **hidden_axis)
3140
+
3141
+
3142
+ def run_kan_experiment_suite():
3143
+ log("\nRunning KAN/MLP experiment suite...")
3144
+ rows: List[Dict[str, Any]] = []
3145
+
3146
+ base = {
3147
+ "EPOCHS": Config.EXPERIMENT_EPOCHS,
3148
+ "EFFICIENT_KAN_GRID_SIZE": Config.EXPERIMENT_FIXED_GRID,
3149
+ "EFFICIENT_KAN_SPLINE_ORDER": Config.EXPERIMENT_FIXED_SPLINE_ORDER,
3150
+ "COMPUTE_PER_FILE_METRICS": Config.EXPERIMENT_COMPUTE_PER_FILE_METRICS,
3151
+ }
3152
+
3153
+ def run_case(sweep_name: str, model_name: str, **overrides):
3154
+ effective = {**base, **overrides}
3155
+ effective["SEQ_LEN"] = 2 * effective.get("CONTEXT_K", Config.CONTEXT_K) + 1
3156
+ case_id = (
3157
+ f"{sweep_name}_{model_name}_"
3158
+ f"k{effective.get('CONTEXT_K', Config.CONTEXT_K)}_"
3159
+ f"h{effective.get('EFFICIENT_KAN_HIDDEN_DIM', Config.EFFICIENT_KAN_HIDDEN_DIM)}_"
3160
+ f"mlph{effective.get('HIDDEN_DIM', Config.HIDDEN_DIM)}_"
3161
+ f"l{effective.get('EFFICIENT_KAN_LAYERS', Config.EFFICIENT_KAN_LAYERS)}_"
3162
+ f"mlpl{effective.get('MLP_LAYERS', Config.MLP_LAYERS)}_"
3163
+ f"g{effective.get('EFFICIENT_KAN_GRID_SIZE', Config.EFFICIENT_KAN_GRID_SIZE)}_"
3164
+ f"o{effective.get('EFFICIENT_KAN_SPLINE_ORDER', Config.EFFICIENT_KAN_SPLINE_ORDER)}"
3165
+ )
3166
+ log(f"experiment | {case_id}")
3167
+ results = run_model_with_overrides(
3168
+ model_name,
3169
+ max_test_files=Config.EXPERIMENT_TEST_FILES,
3170
+ **effective,
3171
+ )
3172
+ row = {
3173
+ "case_id": case_id,
3174
+ "sweep": sweep_name,
3175
+ "model_type": model_name,
3176
+ "context_k": effective.get("CONTEXT_K", Config.CONTEXT_K),
3177
+ "seq_len": effective.get("SEQ_LEN", Config.SEQ_LEN),
3178
+ "hidden_dim": effective.get("EFFICIENT_KAN_HIDDEN_DIM", effective.get("HIDDEN_DIM", Config.HIDDEN_DIM)),
3179
+ "mlp_hidden_dim": effective.get("HIDDEN_DIM", Config.HIDDEN_DIM),
3180
+ "kan_hidden_dim": effective.get("EFFICIENT_KAN_HIDDEN_DIM", Config.EFFICIENT_KAN_HIDDEN_DIM),
3181
+ "layers": effective.get("EFFICIENT_KAN_LAYERS", effective.get("MLP_LAYERS", Config.MLP_LAYERS)),
3182
+ "mlp_layers": effective.get("MLP_LAYERS", Config.MLP_LAYERS),
3183
+ "kan_layers": effective.get("EFFICIENT_KAN_LAYERS", Config.EFFICIENT_KAN_LAYERS),
3184
+ "grid_size": effective.get("EFFICIENT_KAN_GRID_SIZE", Config.EFFICIENT_KAN_GRID_SIZE),
3185
+ "spline_order": effective.get("EFFICIENT_KAN_SPLINE_ORDER", Config.EFFICIENT_KAN_SPLINE_ORDER),
3186
+ "epochs": effective["EPOCHS"],
3187
+ **results,
3188
+ }
3189
+ rows.append(row)
3190
+ pd.DataFrame(rows).to_csv(Config.OUT_DIR / "kan_experiment_suite_all.csv", index=False)
3191
+
3192
+ for model_name in Config.EXPERIMENT_KAN_MODELS:
3193
+ for grid_size in Config.EXPERIMENT_GRID_VALUES:
3194
+ run_case("ber_vs_grid", model_name, EFFICIENT_KAN_GRID_SIZE=grid_size)
3195
+
3196
+ for spline_order in Config.EXPERIMENT_SPLINE_ORDER_VALUES:
3197
+ run_case("ber_vs_spline_order", model_name, EFFICIENT_KAN_SPLINE_ORDER=spline_order)
3198
+
3199
+ for model_name in Config.EXPERIMENT_COMPARE_MODELS:
3200
+ for hidden_dim in Config.EXPERIMENT_HIDDEN_VALUES:
3201
+ run_case(
3202
+ "kan_mlp_vs_hidden_grid16",
3203
+ model_name,
3204
+ **hidden_overrides(hidden_dim),
3205
+ EFFICIENT_KAN_GRID_SIZE=Config.EXPERIMENT_FIXED_GRID,
3206
+ )
3207
+
3208
+ for context_k in Config.EXPERIMENT_WINDOW_VALUES:
3209
+ run_case(
3210
+ "kan_mlp_vs_window",
3211
+ model_name,
3212
+ CONTEXT_K=context_k,
3213
+ EFFICIENT_KAN_GRID_SIZE=Config.EXPERIMENT_FIXED_GRID,
3214
+ )
3215
+
3216
+ for layer_count in Config.EXPERIMENT_LAYER_VALUES:
3217
+ run_case(
3218
+ "kan_mlp_vs_layers",
3219
+ model_name,
3220
+ **layer_overrides(layer_count),
3221
+ EFFICIENT_KAN_GRID_SIZE=Config.EXPERIMENT_FIXED_GRID,
3222
+ )
3223
+
3224
+ for model_name in Config.EXPERIMENT_COMPLEXITY_MODELS:
3225
+ for hidden_dim in Config.EXPERIMENT_HIDDEN_VALUES:
3226
+ run_case(
3227
+ "ber_vs_complexity",
3228
+ model_name,
3229
+ **hidden_overrides(hidden_dim),
3230
+ EFFICIENT_KAN_GRID_SIZE=Config.EXPERIMENT_FIXED_GRID,
3231
+ )
3232
+
3233
+ df = pd.DataFrame(rows)
3234
+ all_path = Config.OUT_DIR / "kan_experiment_suite_all.csv"
3235
+ df.to_csv(all_path, index=False)
3236
+
3237
+ plot_specs = [
3238
+ ("ber_vs_grid", "grid_size", "Grid Size", "ber_vs_grid.png", "BER vs Grid Size"),
3239
+ (
3240
+ "ber_vs_spline_order",
3241
+ "spline_order",
3242
+ "Spline Order",
3243
+ "ber_vs_spline_order.png",
3244
+ "BER vs Spline Order",
3245
+ ),
3246
+ (
3247
+ "kan_mlp_vs_hidden_grid16",
3248
+ "hidden_dim",
3249
+ "Hidden Dimension",
3250
+ "kan_mlp_ber_vs_hidden_grid16.png",
3251
+ "KAN vs MLP: BER vs Hidden Dim (grid=16)",
3252
+ ),
3253
+ (
3254
+ "kan_mlp_vs_window",
3255
+ "seq_len",
3256
+ "Window Size (symbols)",
3257
+ "kan_mlp_ber_vs_window.png",
3258
+ "KAN vs MLP: BER vs Window Size",
3259
+ ),
3260
+ (
3261
+ "kan_mlp_vs_layers",
3262
+ "layers",
3263
+ "Number of Layers",
3264
+ "kan_mlp_ber_vs_layers.png",
3265
+ "KAN vs MLP: BER vs Number of Layers",
3266
+ ),
3267
+ ]
3268
+ for sweep_name, x_col, x_label, filename, title in plot_specs:
3269
+ sweep_df = df[df["sweep"] == sweep_name].copy()
3270
+ sweep_df.to_csv(Config.OUT_DIR / f"{sweep_name}.csv", index=False)
3271
+ if not sweep_df.empty:
3272
+ plot_experiment_lines(sweep_df, x_col=x_col, x_label=x_label, filename=filename, title=title)
3273
+
3274
+ complexity_df = df[df["sweep"] == "ber_vs_complexity"].copy()
3275
+ complexity_df.to_csv(Config.OUT_DIR / "ber_vs_complexity.csv", index=False)
3276
+ if not complexity_df.empty:
3277
+ plot_experiment_complexity(
3278
+ complexity_df,
3279
+ filename="ber_vs_complexity.png",
3280
+ title="BER vs Complexity",
3281
+ )
3282
+
3283
+ log(f"Saved KAN experiment suite: {all_path}")
3284
+
3285
+
3286
def run_efficient_kan_sweep_experiments():
    """Run the EfficientKAN hyper-parameter sweeps and export CSVs and plots.

    Four sweeps are executed for every model in
    ``Config.EFFICIENT_KAN_SWEEP_MODELS``: hidden dimension x learning rate,
    grid size, spline order and layer count. Each case is trained through
    ``run_model_with_overrides`` and the accumulated rows are re-written to
    ``efficient_kan_sweep_all.csv`` after every case, so partial results are
    on disk if a later case fails.

    When the configured sweeps yield no cases at all, the (empty) aggregate
    CSV is still written and the function returns before the per-sweep
    exports and plots, instead of failing on a missing ``sweep`` column.
    """
    log("\nRunning EfficientKAN regression/classifier sweep...")
    # Rows mix strings (ids, names) with numeric settings and metrics.
    rows: List[Dict[str, Any]] = []
    base = {
        "EPOCHS": Config.EFFICIENT_KAN_SWEEP_EPOCHS,
        "EFFICIENT_KAN_HIDDEN_DIM": Config.EFFICIENT_KAN_HIDDEN_DIM,
        "EFFICIENT_KAN_LAYERS": Config.EFFICIENT_KAN_LAYERS,
        "EFFICIENT_KAN_GRID_SIZE": Config.EFFICIENT_KAN_GRID_SIZE,
        "EFFICIENT_KAN_SPLINE_ORDER": Config.EFFICIENT_KAN_SPLINE_ORDER,
        "LEARNING_RATE": Config.LEARNING_RATE,
    }

    def run_case(sweep_name: str, model_name: str, **overrides):
        """Train one configuration (``base`` updated with ``overrides``) and record it."""
        effective = {**base, **overrides}
        case_id = (
            f"{sweep_name}_{model_name}_"
            f"h{effective['EFFICIENT_KAN_HIDDEN_DIM']}_"
            f"l{effective['EFFICIENT_KAN_LAYERS']}_"
            f"g{effective['EFFICIENT_KAN_GRID_SIZE']}_"
            f"o{effective['EFFICIENT_KAN_SPLINE_ORDER']}_"
            f"lr{effective['LEARNING_RATE']:.0e}"
        )
        log(f"sweep | {case_id}")
        results = run_model_with_overrides(
            model_name,
            max_test_files=Config.EFFICIENT_KAN_SWEEP_TEST_FILES,
            **effective,
        )
        row = {
            "case_id": case_id,
            "sweep": sweep_name,
            "model_type": model_name,
            "hidden_dim": effective["EFFICIENT_KAN_HIDDEN_DIM"],
            "layers": effective["EFFICIENT_KAN_LAYERS"],
            "grid_size": effective["EFFICIENT_KAN_GRID_SIZE"],
            "spline_order": effective["EFFICIENT_KAN_SPLINE_ORDER"],
            "learning_rate": effective["LEARNING_RATE"],
            "epochs": effective["EPOCHS"],
            **results,
        }
        rows.append(row)
        # Checkpoint the aggregate table after every case.
        pd.DataFrame(rows).to_csv(Config.OUT_DIR / "efficient_kan_sweep_all.csv", index=False)

    for model_name in Config.EFFICIENT_KAN_SWEEP_MODELS:
        for hidden_dim in Config.EFFICIENT_KAN_HIDDEN_SWEEP_VALUES:
            for learning_rate in Config.EFFICIENT_KAN_LR_SWEEP_VALUES:
                run_case(
                    "hidden_lr",
                    model_name,
                    EFFICIENT_KAN_HIDDEN_DIM=hidden_dim,
                    LEARNING_RATE=learning_rate,
                )

        for grid_size in Config.EFFICIENT_KAN_GRID_SWEEP_VALUES:
            run_case("grid_size", model_name, EFFICIENT_KAN_GRID_SIZE=grid_size)

        for spline_order in Config.EFFICIENT_KAN_ORDER_SWEEP_VALUES:
            run_case("spline_order", model_name, EFFICIENT_KAN_SPLINE_ORDER=spline_order)

        for layers in Config.EFFICIENT_KAN_LAYER_SWEEP_VALUES:
            run_case("layers", model_name, EFFICIENT_KAN_LAYERS=layers)

    df = pd.DataFrame(rows)
    df.to_csv(Config.OUT_DIR / "efficient_kan_sweep_all.csv", index=False)

    if df.empty:
        # A frame built from zero rows has no columns, so the "sweep" filters
        # below would raise KeyError; there is nothing to split or plot.
        log("EfficientKAN sweep produced no cases; skipping per-sweep exports and plots")
        return

    hidden_df = df[df["sweep"] == "hidden_lr"].copy()
    grid_df = df[df["sweep"] == "grid_size"].copy()
    order_df = df[df["sweep"] == "spline_order"].copy()
    layer_df = df[df["sweep"] == "layers"].copy()

    hidden_df.to_csv(Config.OUT_DIR / "efficient_kan_ber_vs_hidden_lr.csv", index=False)
    grid_df.to_csv(Config.OUT_DIR / "efficient_kan_ber_vs_grid.csv", index=False)
    order_df.to_csv(Config.OUT_DIR / "efficient_kan_ber_vs_order.csv", index=False)
    layer_df.to_csv(Config.OUT_DIR / "efficient_kan_ber_vs_layers.csv", index=False)

    # One hidden-dimension plot per learning rate of the hidden_lr sweep.
    for learning_rate in Config.EFFICIENT_KAN_LR_SWEEP_VALUES:
        lr_df = hidden_df[hidden_df["learning_rate"] == learning_rate]
        if lr_df.empty:
            continue
        plot_efficient_kan_sweep(
            lr_df,
            x_col="hidden_dim",
            x_label="EfficientKAN Hidden Dimension",
            filename=f"efficient_kan_ber_vs_hidden_lr_{learning_rate:.0e}.png",
            title=f"EfficientKAN BER vs Hidden Dim (lr={learning_rate:.0e})",
        )

    plot_efficient_kan_sweep(
        grid_df,
        x_col="grid_size",
        x_label="Grid Size",
        filename="efficient_kan_ber_vs_grid.png",
        title="EfficientKAN BER vs Grid Size",
    )
    plot_efficient_kan_sweep(
        order_df,
        x_col="spline_order",
        x_label="Spline Order",
        filename="efficient_kan_ber_vs_spline_order.png",
        title="EfficientKAN BER vs Spline Order",
    )
    plot_efficient_kan_sweep(
        layer_df,
        x_col="layers",
        x_label="KAN Hidden Layers",
        filename="efficient_kan_ber_vs_layers.png",
        title="EfficientKAN BER vs Number of Layers",
    )
    plot_efficient_kan_tradeoff(
        df,
        x_col="trainable_params",
        x_label="Trainable Parameters",
        filename="efficient_kan_params_vs_ber.png",
        title="EfficientKAN Parameter Count vs BER",
    )
    plot_efficient_kan_tradeoff(
        df,
        x_col="train_samples_per_sec",
        x_label="Training Samples/sec",
        filename="efficient_kan_speed_vs_ber.png",
        title="EfficientKAN Training Speed vs BER",
    )
    log(f"Saved EfficientKAN sweep: {Config.OUT_DIR / 'efficient_kan_sweep_all.csv'}")
3409
+
3410
+
3411
def train_one_model(model_name: str, data: Dict[str, Any]) -> Tuple[nn.Module, Dict, Dict[str, Any]]:
    """Train one equalizer and return it with its history and final metrics.

    The model is created by ``make_model``, optionally initialised from the
    training windows (when it exposes ``initialize_from_data``) and optionally
    wrapped with ``torch.compile``. It is then trained for up to
    ``Config.EPOCHS`` epochs over contiguous blocks of the training tensors,
    visited in a random order each epoch. After every epoch the validation
    split is evaluated; the best state according to ``Config.SAVE_BEST_BY`` is
    kept (and saved to ``<model_name>_best.pth`` when ``Config.SAVE_BEST``),
    the "notebook_decay" scheduler may lower the learning rate, and early
    stopping on validation BER may end the run. The best state is restored
    before the final test metrics are computed.

    On a CUDA out-of-memory error the training batch size is halved (down to
    ``Config.MIN_BLOCK_SIZE``) and the same batch is retried.

    Args:
        model_name: Identifier passed to ``make_model``; also used in log
            lines and checkpoint file names.
        data: Dataset dictionary. ``train_x``, ``train_y``, ``val_x`` and
            ``val_y`` are read here; the dictionary is also forwarded to the
            test/efficiency/pruning/per-file metric helpers.

    Returns:
        ``(model, history, final_metrics)``: the model returned by
        ``maybe_prune_efficient_kan_model``, the per-epoch history lists, and
        the final metrics dictionary extended with training summary fields.

    Raises:
        RuntimeError: Any non-OOM runtime error from a training step, or a
            CUDA OOM once the batch size is already at
            ``Config.MIN_BLOCK_SIZE``.
    """
    model = make_model(model_name)
    trainable_params = count_trainable_parameters(model)
    log(
        f"{model_name} | params {trainable_params:,} | "
        f"{MODEL_NOTES.get(model_name, 'custom equalizer')}"
    )
    if hasattr(model, "initialize_from_data"):
        model.initialize_from_data(data["train_x"], data["train_y"])
        log(f"{model_name} | initialized from training windows")
    if Config.USE_TORCH_COMPILE and hasattr(torch, "compile"):
        try:
            model = torch.compile(model, mode=Config.TORCH_COMPILE_MODE)
            log(f"{model_name} | torch.compile enabled ({Config.TORCH_COMPILE_MODE})")
        except Exception as exc:
            # Compilation is best-effort: fall back to the eager model.
            log(f"{model_name} | torch.compile disabled: {exc}")

    optimizer = build_optimizer(model)
    criterion = build_criterion()
    scaler = torch.amp.GradScaler("cuda", enabled=Config.DEVICE.type == "cuda" and Config.USE_AMP)
    best_train_loss = float("inf")
    best_val_ber = float("inf")
    best_score = float("inf")
    best_state = None
    steps_without_improvement = 0
    early_stop_without_improvement = 0
    early_stopped = False
    stop_reason = "completed"

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_acc": [],
        "val_ber": [],
        "val_scale": [],
        "val_epochs": [],
        "test_loss": [],
        "test_acc": [],
        "test_ber": [],
        "test_scale": [],
        "test_epochs": [],
        "lr": [],
        "epoch_time_sec": [],
        "train_samples_per_sec": [],
    }

    def snapshot_state() -> Dict[str, Any]:
        """Return a detached CPU copy of the current model state dict."""
        return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    def attach_run_summary(metrics: Dict[str, Any], current_model: nn.Module) -> None:
        """Write the training summary fields of this run into ``metrics``."""
        speeds = history["train_samples_per_sec"]
        epoch_times = history["epoch_time_sec"]
        metrics["trainable_params"] = count_trainable_parameters(current_model)
        metrics["best_val_ber"] = best_val_ber
        metrics["best_train_loss"] = best_train_loss
        metrics["train_samples_per_sec"] = float(np.mean(speeds)) if speeds else 0.0
        metrics["mean_epoch_time_sec"] = float(np.mean(epoch_times)) if epoch_times else 0.0
        metrics["epochs_ran"] = len(history["train_loss"])
        metrics["early_stopped"] = early_stopped
        metrics["stop_reason"] = stop_reason

    train_x = data["train_x"]
    train_y = data["train_y"]
    train_block_size = Config.TRAIN_BLOCK_SIZE
    eval_batch_size = Config.EVAL_BATCH_SIZE
    non_blocking = Config.DEVICE.type == "cuda"

    if Config.SAVE_BEST_BY in {"val_ber", "val_loss"}:
        # Score the untrained model so a run that never improves on it still
        # ends up with a valid checkpoint.
        init_val_loss, _, init_val_ber, eval_batch_size, _ = evaluate_split(
            model,
            data["val_x"],
            data["val_y"],
            eval_batch_size,
            scale_search=True,
        )
        best_val_ber = init_val_ber
        best_score = init_val_ber if Config.SAVE_BEST_BY == "val_ber" else init_val_loss
        best_state = snapshot_state()
        if Config.SAVE_BEST:
            torch.save(best_state, Config.OUT_DIR / f"{model_name}_best.pth")
        log(
            f"{model_name} | init checkpoint | val {init_val_loss:.6f} | "
            f"val_ber {init_val_ber:.6e}"
        )

    for epoch in range(Config.EPOCHS):
        model.train()
        epoch_start = time.time()
        running_loss = 0.0
        seen = 0
        total_train = train_x.size(0)
        # Freeze the block partition for this epoch. train_block_size may be
        # halved by the OOM handler below; recomputing block offsets from the
        # shrunken value would make the remaining shuffled indices revisit
        # some samples and skip others.
        epoch_block_size = train_block_size
        num_blocks = (total_train + epoch_block_size - 1) // epoch_block_size
        block_order = torch.randperm(num_blocks).tolist()
        for block_idx in block_order:
            block_start = block_idx * epoch_block_size
            block_end = min(block_start + epoch_block_size, total_train)
            start = block_start
            while start < block_end:
                # The step uses the current (possibly reduced) batch size, so
                # a block is walked in several chunks after an OOM.
                end = min(start + train_block_size, block_end)
                xb = train_x[start:end].to(Config.DEVICE, non_blocking=non_blocking)
                yb = train_y[start:end].to(Config.DEVICE, non_blocking=non_blocking)
                try:
                    optimizer.zero_grad(set_to_none=True)
                    with autocast_context():
                        mark_cudagraph_step_begin()
                        preds = model(xb)
                        loss = prediction_loss(preds, yb, criterion)
                        loss = loss + compute_model_regularization(model)
                    scaler.scale(loss).backward()
                    if Config.GRAD_CLIP_NORM > 0:
                        # Gradients must be unscaled before clipping by norm.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), Config.GRAD_CLIP_NORM)
                    scaler.step(optimizer)
                    scaler.update()
                    batch_size_now = yb.size(0)
                    running_loss += loss.item() * batch_size_now
                    seen += batch_size_now
                    start = end
                except RuntimeError as error:
                    if Config.DEVICE.type != "cuda" or not is_cuda_oom(error):
                        raise
                    optimizer.zero_grad(set_to_none=True)
                    if train_block_size <= Config.MIN_BLOCK_SIZE:
                        raise
                    next_block_size = max(train_block_size // 2, Config.MIN_BLOCK_SIZE)
                    log(f"{model_name} | CUDA OOM at train_block_size={train_block_size}, retrying with {next_block_size}")
                    train_block_size = next_block_size
                    torch.cuda.empty_cache()
                    # `start` is left unchanged so the same samples are retried.

        train_loss = running_loss / max(seen, 1)
        epoch_time = time.time() - epoch_start
        speed = seen / max(epoch_time, 1e-9)

        val_loss, val_acc, val_ber, eval_batch_size, val_scale = evaluate_split(
            model,
            data["val_x"],
            data["val_y"],
            eval_batch_size,
            scale_search=True,
        )
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        history["val_ber"].append(val_ber)
        history["val_scale"].append(val_scale)
        history["val_epochs"].append(epoch + 1)
        history["lr"].append(optimizer.param_groups[0]["lr"])
        history["epoch_time_sec"].append(epoch_time)
        history["train_samples_per_sec"].append(speed)

        should_eval_test = Config.EVAL_TEST_DURING_TRAINING and (
            (epoch + 1) % Config.TEST_BER_EVERY == 0 or epoch == 0 or epoch == Config.EPOCHS - 1
        )
        if should_eval_test:
            test_metrics = compute_test_metrics(model, {**data, "eval_batch_size": eval_batch_size})
            eval_batch_size = int(test_metrics["safe_eval_batch_size"])
            history["test_loss"].append(test_metrics["test_loss"])
            history["test_acc"].append(test_metrics["accuracy"])
            history["test_ber"].append(test_metrics["equalized_ber"])
            history["test_scale"].append(test_metrics["equalizer_scale"])
            history["test_epochs"].append(epoch + 1)
            log(
                f"{model_name} | epoch {epoch+1:4d}/{Config.EPOCHS} | "
                f"train {train_loss:.6f} | val {val_loss:.6f} | val_ber {val_ber:.6e} | "
                f"test_ber {test_metrics['equalized_ber']:.6e} | lr {optimizer.param_groups[0]['lr']:.2e} | "
                f"speed {speed:,.0f} samp/s | time {epoch_time:.1f}s"
            )
        elif epoch % Config.LOG_EVERY == 0 or epoch == Config.EPOCHS - 1:
            log(
                f"{model_name} | epoch {epoch+1:4d}/{Config.EPOCHS} | "
                f"train {train_loss:.6f} | val {val_loss:.6f} | val_ber {val_ber:.6e} | "
                f"lr {optimizer.param_groups[0]['lr']:.2e} | speed {speed:,.0f} samp/s | time {epoch_time:.1f}s"
            )

        if train_loss < best_train_loss:
            best_train_loss = train_loss

        # Early-stopping counter: always driven by validation BER.
        if val_ber + Config.EARLY_STOPPING_THRESHOLD < best_val_ber:
            best_val_ber = val_ber
            early_stop_without_improvement = 0
        else:
            early_stop_without_improvement += 1

        # Checkpoint/scheduler counter: driven by the configured monitor.
        if Config.SAVE_BEST_BY == "train_loss":
            monitor_value = train_loss
        elif Config.SAVE_BEST_BY == "val_ber":
            monitor_value = val_ber
        else:
            monitor_value = val_loss
        if monitor_value + Config.SCHEDULER_THRESHOLD < best_score:
            best_score = monitor_value
            steps_without_improvement = 0
            best_state = snapshot_state()
            if Config.SAVE_BEST:
                torch.save(best_state, Config.OUT_DIR / f"{model_name}_best.pth")
        else:
            steps_without_improvement += 1

        if Config.LR_SCHEDULER == "notebook_decay" and steps_without_improvement >= Config.DECAY_STEPS:
            current_lr = optimizer.param_groups[0]["lr"]
            new_lr = current_lr * Config.SCHEDULER_FACTOR
            steps_without_improvement = 0
            if new_lr < Config.MIN_LR:
                stop_reason = "lr_floor"
                log(f"{model_name} | epoch {epoch+1} -- stopping at lr floor ({current_lr:.6g})")
                break
            for param_group in optimizer.param_groups:
                param_group["lr"] = new_lr
            log(f"{model_name} | epoch {epoch+1} -- scheduler reduced lr to {new_lr:.6g}")

        if (
            Config.EARLY_STOPPING
            and epoch + 1 >= Config.EARLY_STOPPING_MIN_EPOCHS
            and early_stop_without_improvement >= Config.EARLY_STOPPING_PATIENCE
        ):
            early_stopped = True
            stop_reason = f"early_stop_val_ber_patience_{Config.EARLY_STOPPING_PATIENCE}"
            log(
                f"{model_name} | epoch {epoch+1} -- early stopping: "
                f"val_ber did not improve for {early_stop_without_improvement} epochs "
                f"(best_val_ber={best_val_ber:.6e})"
            )
            break

    if best_state is None:
        best_state = snapshot_state()
    model.load_state_dict(best_state)
    final_metrics = compute_test_metrics(model, {**data, "eval_batch_size": eval_batch_size})
    attach_run_summary(final_metrics, model)
    final_metrics = add_efficiency_metrics(model, data, final_metrics)
    model, final_metrics, eval_batch_size = maybe_prune_efficient_kan_model(
        model,
        model_name,
        data,
        final_metrics,
        eval_batch_size,
    )
    # The pruning helper returns its own model and metrics objects, so the
    # summary fields are written again onto whatever it handed back.
    attach_run_summary(final_metrics, model)
    if Config.COMPUTE_PER_FILE_METRICS:
        val_file_metrics, eval_batch_size = compute_split_file_metrics(model, data, "val", eval_batch_size)
        test_file_metrics, eval_batch_size = compute_split_file_metrics(model, data, "test", eval_batch_size)
        add_file_metric_summary(final_metrics, "val", val_file_metrics)
        add_file_metric_summary(final_metrics, "test", test_file_metrics)
    return model, history, final_metrics
3652
+
3653
+
3654
def main():
    """Script entry point: dispatch to the configured experiment mode.

    The output directory is created first. The FastKAN classifier sweep, the
    KAN experiment suite, and the sweep-only mode are mutually exclusive
    shortcuts checked in that order; each one runs and returns. Otherwise
    every model in ``Config.MODEL_TYPES`` is trained, plotted and saved, an
    architecture summary is produced, and the sweep experiments are run
    afterwards when ``Config.RUN_SWEEP_EXPERIMENTS`` is set.
    """
    Config.OUT_DIR.mkdir(parents=True, exist_ok=True)
    log(f"Device: {Config.DEVICE}")

    # Exclusive modes: the first enabled one runs alone.
    if Config.RUN_FASTKAN_CLASSIFIER_SWEEP:
        run_fastkan_classifier_sweep()
        return
    if Config.RUN_KAN_EXPERIMENT_SUITE:
        run_kan_experiment_suite()
        return
    if Config.RUN_SWEEP_EXPERIMENTS and not Config.RUN_MAIN_EXPERIMENTS:
        run_sweep_experiments()
        return

    data = prepare_data()
    summaries = []

    for model_name in Config.MODEL_TYPES:
        display_name = model_name.upper()
        log(f"\nTraining {display_name}...")

        trained, curves, metrics = train_one_model(model_name, data)
        metrics["model_type"] = model_name
        summaries.append(metrics)

        plot_results(curves, metrics, model_name)
        torch.save(trained.state_dict(), Config.OUT_DIR / f"{model_name}_final.pth")

        report = " | ".join(
            [
                display_name,
                f"baseline BER {metrics['baseline_ber']:.6e}",
                f"equalized BER {metrics['equalized_ber']:.6e}",
                f"acc {metrics['accuracy']:.4%}",
                f"rel improvement {metrics['improvement_rel']:.2f}%",
            ]
        )
        log(report)

        if Config.DEVICE.type == "cuda":
            # Release GPU memory before the next model is built.
            trained.to("cpu")
            del trained
            torch.cuda.empty_cache()

    plot_architecture_summary(summaries)
    log(f"Saved summary: {Config.OUT_DIR / 'architecture_comparison.csv'}")

    if Config.RUN_SWEEP_EXPERIMENTS:
        run_sweep_experiments()


if __name__ == "__main__":
    main()