ber-equalization-studio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ber_equalization_studio/__init__.py +58 -0
- ber_equalization_studio/_legacy_backend/__init__.py +1 -0
- ber_equalization_studio/_legacy_backend/ber_equalization.py +3700 -0
- ber_equalization_studio/_legacy_backend/efficient_kan/__init__.py +3 -0
- ber_equalization_studio/_legacy_backend/efficient_kan/kan.py +218 -0
- ber_equalization_studio/api.py +348 -0
- ber_equalization_studio/cli.py +92 -0
- ber_equalization_studio/config.py +168 -0
- ber_equalization_studio/data.py +31 -0
- ber_equalization_studio/experiment.py +92 -0
- ber_equalization_studio/legacy.py +149 -0
- ber_equalization_studio/models.py +86 -0
- ber_equalization_studio/results.py +74 -0
- ber_equalization_studio/visualization.py +186 -0
- ber_equalization_studio-0.1.0.dist-info/METADATA +266 -0
- ber_equalization_studio-0.1.0.dist-info/RECORD +19 -0
- ber_equalization_studio-0.1.0.dist-info/WHEEL +5 -0
- ber_equalization_studio-0.1.0.dist-info/entry_points.txt +2 -0
- ber_equalization_studio-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import asdict, dataclass, field
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _default_device() -> str:
    """Return the name of the preferred torch device.

    CUDA wins over Apple MPS, which wins over CPU. Any error raised while
    probing (torch not installed, backend attribute missing, driver issues)
    is treated as "no accelerator" and yields ``"cpu"``.
    """
    try:
        import torch

        cuda_ok = torch.cuda.is_available()
        if cuda_ok:
            return "cuda"
        mps_ok = torch.backends.mps.is_available()
        if mps_ok:
            return "mps"
    except Exception:
        # Deliberate best effort: probing must never break config creation.
        return "cpu"
    return "cpu"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(slots=True)
|
|
22
|
+
class DataConfig:
|
|
23
|
+
data_dirs: list[Path] = field(
|
|
24
|
+
default_factory=lambda: [Path("symbols_new"), Path("Symbols_1m_1ch_PR"), Path(".")]
|
|
25
|
+
)
|
|
26
|
+
max_files: int = 64
|
|
27
|
+
train_portion: float = 0.97
|
|
28
|
+
val_portion_within_train: float = 0.10
|
|
29
|
+
min_val_files: int = 1
|
|
30
|
+
randomize_file_split: bool = False
|
|
31
|
+
split_seed: int = 42
|
|
32
|
+
context_k: int = 32
|
|
33
|
+
input_dim: int = 2
|
|
34
|
+
power_normalize: bool = True
|
|
35
|
+
max_test_files: int | None = None
|
|
36
|
+
|
|
37
|
+
def __post_init__(self) -> None:
|
|
38
|
+
self.data_dirs = [Path(path) for path in self.data_dirs]
|
|
39
|
+
|
|
40
|
+
@property
|
|
41
|
+
def seq_len(self) -> int:
|
|
42
|
+
return 2 * self.context_k + 1
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass(slots=True)
|
|
46
|
+
class ModelConfig:
|
|
47
|
+
name: str = "complex_fastkan"
|
|
48
|
+
hidden_dim: int = 96
|
|
49
|
+
dropout: float = 0.2
|
|
50
|
+
mlp_layers: int = 3
|
|
51
|
+
tcn_hidden_dim: int = 96
|
|
52
|
+
tcn_layers: int = 5
|
|
53
|
+
tcn_kernel_size: int = 5
|
|
54
|
+
fastkan_hidden_dim: int = 96
|
|
55
|
+
fastkan_layers: int = 2
|
|
56
|
+
fastkan_num_grids: int = 8
|
|
57
|
+
fastkan_grid_min: float = -2.5
|
|
58
|
+
fastkan_grid_max: float = 2.5
|
|
59
|
+
efficient_kan_hidden_dim: int = 128
|
|
60
|
+
efficient_kan_layers: int = 2
|
|
61
|
+
efficient_kan_grid_size: int = 8
|
|
62
|
+
efficient_kan_spline_order: int = 3
|
|
63
|
+
efficient_kan_grid_range: tuple[float, float] = (-3.0, 3.0)
|
|
64
|
+
complex_light_channels: int = 48
|
|
65
|
+
complex_light_dilations: tuple[int, ...] = (1, 2, 4)
|
|
66
|
+
complex_light_kernel_size: int = 3
|
|
67
|
+
kan_prune_l1: float = 1e-5
|
|
68
|
+
kan_prune_threshold: float = 0.02
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass(slots=True)
|
|
72
|
+
class TrainingConfig:
|
|
73
|
+
epochs: int = 250
|
|
74
|
+
learning_rate: float = 1e-3
|
|
75
|
+
weight_decay: float = 0.0
|
|
76
|
+
optimizer: str = "adam"
|
|
77
|
+
loss: str = "mse"
|
|
78
|
+
train_block_size: int = 8192
|
|
79
|
+
min_block_size: int = 1024
|
|
80
|
+
use_amp: bool = True
|
|
81
|
+
use_torch_compile: bool = False
|
|
82
|
+
grad_clip_norm: float = 1.0
|
|
83
|
+
lr_scheduler: str = "notebook_decay"
|
|
84
|
+
scheduler_factor: float = 0.5
|
|
85
|
+
scheduler_threshold: float = 1e-6
|
|
86
|
+
decay_steps: int = 24
|
|
87
|
+
min_lr: float = 1e-5
|
|
88
|
+
early_stopping: bool = True
|
|
89
|
+
early_stopping_patience: int = 72
|
|
90
|
+
early_stopping_min_epochs: int = 40
|
|
91
|
+
early_stopping_threshold: float = 0.0
|
|
92
|
+
save_best_by: str = "val_ber"
|
|
93
|
+
save_best: bool = True
|
|
94
|
+
log_every: int = 1
|
|
95
|
+
seed: int = 42
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass(slots=True)
|
|
99
|
+
class EvaluationConfig:
|
|
100
|
+
eval_batch_size: int = 65536
|
|
101
|
+
eval_test_during_training: bool = False
|
|
102
|
+
test_ber_every: int = 10
|
|
103
|
+
compute_per_file_metrics: bool = True
|
|
104
|
+
ber_scale_search: bool = True
|
|
105
|
+
ber_scale_min: float = 0.5
|
|
106
|
+
ber_scale_max: float = 1.5
|
|
107
|
+
ber_scale_steps: int = 10
|
|
108
|
+
ber_scale_offset: int = 10000
|
|
109
|
+
ber_scale_samples: int = 1 << 20
|
|
110
|
+
efficiency_batch_size: int = 16000
|
|
111
|
+
efficiency_timing_warmup: int = 5
|
|
112
|
+
efficiency_timing_repeats: int = 20
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass(slots=True)
|
|
116
|
+
class OutputConfig:
|
|
117
|
+
out_dir: Path = Path("studio_outputs")
|
|
118
|
+
experiment_name: str = "experiment"
|
|
119
|
+
save_checkpoint: bool = True
|
|
120
|
+
save_history: bool = True
|
|
121
|
+
save_plots: bool = True
|
|
122
|
+
plot_backend: str = "plotly"
|
|
123
|
+
|
|
124
|
+
def __post_init__(self) -> None:
|
|
125
|
+
self.out_dir = Path(self.out_dir)
|
|
126
|
+
|
|
127
|
+
@property
|
|
128
|
+
def run_dir(self) -> Path:
|
|
129
|
+
return self.out_dir / self.experiment_name
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@dataclass(slots=True)
|
|
133
|
+
class ExperimentConfig:
|
|
134
|
+
models: list[str] = field(default_factory=lambda: ["complex_fastkan"])
|
|
135
|
+
tags: dict[str, str] = field(default_factory=dict)
|
|
136
|
+
notes: str = ""
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@dataclass(slots=True)
class StudioConfig:
    """Top-level configuration: one object per section plus the torch device."""

    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
    output: OutputConfig = field(default_factory=OutputConfig)
    experiment: ExperimentConfig = field(default_factory=ExperimentConfig)
    device: str = field(default_factory=_default_device)

    def to_dict(self) -> dict[str, Any]:
        """Return a nested plain-dict copy with Path values converted to str."""
        raw = asdict(self)
        raw["data"]["data_dirs"] = [str(path) for path in self.data.data_dirs]
        raw["output"]["out_dir"] = str(self.output.out_dir)
        return raw

    def with_updates(self, **updates: Any) -> "StudioConfig":
        """Return a new config with selected values replaced; ``self`` is untouched.

        Keys are either ``"section.field"`` (e.g. ``"training.epochs"``) or a
        top-level scalar field such as ``"device"``. Dotted names are not valid
        keyword identifiers, so pass them as ``with_updates(**{"training.epochs": 5})``.

        Raises:
            KeyError: the key names an unknown section/top-level field, or a
                whole section without a field.
            TypeError: the field does not exist on that section (raised by the
                section dataclass constructor).
        """
        values = self.to_dict()
        for dotted_key, value in updates.items():
            section, sep, key = dotted_key.partition(".")
            if not sep:
                # Top-level scalar such as "device". Whole sections cannot be
                # replaced this way because the constructor below expects dicts.
                if section not in values or isinstance(values[section], dict):
                    raise KeyError(f"Unknown top-level config key {dotted_key!r}")
                values[section] = value
                continue
            target = values.get(section)
            if not isinstance(target, dict):
                raise KeyError(f"Unknown config section {section!r} in key {dotted_key!r}")
            target[key] = value
        return StudioConfig(
            data=DataConfig(**values["data"]),
            model=ModelConfig(**values["model"]),
            training=TrainingConfig(**values["training"]),
            evaluation=EvaluationConfig(**values["evaluation"]),
            output=OutputConfig(**values["output"]),
            experiment=ExperimentConfig(**values["experiment"]),
            device=values["device"],
        )
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from ber_equalization_studio.config import StudioConfig
|
|
7
|
+
from ber_equalization_studio.legacy import configured_legacy
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def prepare_dataset(config: StudioConfig, legacy_path: Path | None = None) -> dict[str, Any]:
    """Build the dataset through the legacy backend configured from *config*.

    The backend's ``Config`` is temporarily overridden from *config* while its
    ``prepare_data`` runs. The dict it returns is passed through unchanged.
    """
    test_file_limit = config.data.max_test_files
    with configured_legacy(config, legacy_path) as backend:
        dataset = backend.prepare_data(max_test_files=test_file_limit)
    return dataset
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def dataset_summary(data: dict[str, Any]) -> dict[str, Any]:
    """Summarise a prepared dataset dict as JSON-friendly scalars.

    For each of ``train``/``val``/``test``, the summary records:
      - ``<split>_samples`` / ``<split>_targets``: ``size(0)`` of the
        ``<split>_x`` / ``<split>_y`` entries (tensor-like values with a
        ``size()`` method are assumed);
      - ``<split>_files``: the ``file_idx`` of each entry in a non-empty
        ``<split>_file_spans`` list.

    ``tx_scale`` / ``rx_scale`` are added as floats when present. Missing keys
    are simply skipped.
    """
    summary: dict[str, Any] = {}
    count_fields = (("x", "samples"), ("y", "targets"))
    for split in ("train", "val", "test"):
        for suffix, label in count_fields:
            source_key = f"{split}_{suffix}"
            if source_key in data:
                summary[f"{split}_{label}"] = int(data[source_key].size(0))
        file_spans = data.get(f"{split}_file_spans", [])
        if file_spans:
            summary[f"{split}_files"] = [int(span["file_idx"]) for span in file_spans]
    for scale_key in ("tx_scale", "rx_scale"):
        if scale_key not in data:
            continue
        raw = data[scale_key]
        # Tensors / NumPy scalars expose .item(); plain numbers do not.
        summary[scale_key] = float(raw.item()) if hasattr(raw, "item") else float(raw)
    return summary
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import json
|
|
5
|
+
import random
|
|
6
|
+
from dataclasses import replace
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
from ber_equalization_studio.config import StudioConfig
|
|
14
|
+
from ber_equalization_studio.data import dataset_summary, prepare_dataset
|
|
15
|
+
from ber_equalization_studio.legacy import configured_legacy
|
|
16
|
+
from ber_equalization_studio.models import canonical_model_name
|
|
17
|
+
from ber_equalization_studio.results import save_json, write_result_table
|
|
18
|
+
from ber_equalization_studio.visualization import plot_comparison, plot_history
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def set_reproducible_seed(seed: int) -> None:
    """Seed Python's, NumPy's and PyTorch's global RNGs with *seed*.

    When CUDA is available, the generators of all visible CUDA devices are
    seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ExperimentRunner:
    """Train and evaluate a list of models on one shared dataset.

    Artefacts go below ``config.output.run_dir``:
      - at the run level: ``config.json``, ``dataset_summary.json``,
        ``results.csv`` and, when plots are enabled, a comparison plot;
      - in one sub-directory per model: final checkpoint, ``history.json``,
        ``metrics.json`` and history plots (each subject to the output toggles).
    """

    def __init__(self, config: StudioConfig, legacy_path: Path | None = None):
        self.config = config
        self.legacy_path = legacy_path
        self.run_dir = config.output.run_dir

    def _model_config(self, model_name: str) -> StudioConfig:
        """Return a copy of the run config whose ``model.name`` is *model_name* (canonicalised)."""
        model = replace(self.config.model, name=canonical_model_name(model_name))
        return replace(self.config, model=model)

    def _requested_models(self) -> list[str]:
        """Return the canonical model names to train, de-duplicated in first-seen order.

        Aliases such as ``"kan"`` and ``"fastkan"`` canonicalise to the same
        model. Training both would write into the same directory twice and
        overwrite the first model's artefacts and history entry.
        """
        requested = self.config.experiment.models or [self.config.model.name]
        return list(dict.fromkeys(canonical_model_name(name) for name in requested))

    def run(self) -> dict[str, Any]:
        """Prepare the data once, then train/evaluate every requested model.

        Returns:
            dict with ``run_dir``, ``results`` (one metrics dict per model),
            ``histories`` (keyed by canonical model name) and ``results_csv``.
        """
        seed = self.config.training.seed
        set_reproducible_seed(seed)
        self.run_dir.mkdir(parents=True, exist_ok=True)
        save_json(self.run_dir / "config.json", self.config.to_dict())

        # The dataset is prepared once and shared by every model.
        data = prepare_dataset(self.config, self.legacy_path)
        save_json(self.run_dir / "dataset_summary.json", dataset_summary(data))

        rows: list[dict[str, Any]] = []
        histories: dict[str, dict[str, Any]] = {}

        for model_name in self._requested_models():
            model_config = self._model_config(model_name)
            model_run_dir = self.run_dir / model_name
            model_run_dir.mkdir(parents=True, exist_ok=True)

            # Re-seed before every model so its result does not depend on how
            # much global RNG state the data preparation or earlier models consumed.
            set_reproducible_seed(seed)
            with configured_legacy(model_config, self.legacy_path) as legacy:
                # Send backend outputs to this model's directory; configured_legacy
                # restores OUT_DIR on exit because it is one of its overrides.
                legacy.Config.OUT_DIR = model_run_dir
                model, history, metrics = legacy.train_one_model(model_name, data)

            if model_config.output.save_checkpoint:
                torch.save(model.state_dict(), model_run_dir / "final_state_dict.pt")

            # Copy before annotating so the backend's metrics object is not mutated.
            metrics = copy.deepcopy(metrics)
            metrics["model_type"] = model_name
            metrics["experiment_name"] = self.config.output.experiment_name
            rows.append(metrics)
            histories[model_name] = history

            if self.config.output.save_history:
                save_json(model_run_dir / "history.json", history)
            save_json(model_run_dir / "metrics.json", metrics)
            if self.config.output.save_plots:
                plot_history(history, model_name, model_run_dir)

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        results_csv = self.run_dir / "results.csv"
        df = write_result_table(results_csv, rows)
        if self.config.output.save_plots:
            plot_comparison(df, self.run_dir)
        return {
            "run_dir": self.run_dir,
            "results": rows,
            "histories": histories,
            "results_csv": results_csv,
        }
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def run_experiment(config: StudioConfig, legacy_path: Path | None = None) -> dict[str, Any]:
    """Convenience wrapper: build an ExperimentRunner for *config* and run it.

    Returns the runner's result dict (``run_dir``, ``results``, ``histories``,
    ``results_csv``).
    """
    runner = ExperimentRunner(config=config, legacy_path=legacy_path)
    return runner.run()
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib.util
|
|
4
|
+
import sys
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from types import ModuleType
|
|
8
|
+
from typing import Any, Iterator
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from ber_equalization_studio.config import StudioConfig
|
|
13
|
+
|
|
14
|
+
# Module-level cache for the imported legacy backend: None until
# load_legacy_module() executes the script and stores the module object here.
_LEGACY_MODULE: ModuleType | None = None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def default_legacy_path() -> Path:
    """Locate the legacy ``ber_equalization.py`` backend script.

    Candidates are tried in order:
      1. the copy bundled in this package (``_legacy_backend/``);
      2. ``BER_minimization_survey/ber_equalization.py`` beside, then inside,
         the directory three levels above this package (only when this file
         is nested deeply enough for that ancestor to exist);
      3. ``BER_minimization_survey/ber_equalization.py`` under the current
         working directory;
      4. ``ber_equalization.py`` in the current working directory.

    Raises:
        FileNotFoundError: none of the candidates exist.
    """
    this_file = Path(__file__).resolve()
    package_dir = this_file.parent
    candidates = [package_dir / "_legacy_backend" / "ber_equalization.py"]

    # Indexing parents[3] unconditionally raised IndexError for installs close
    # to the filesystem root, before even the bundled copy was checked.
    ancestors = this_file.parents
    if len(ancestors) > 3:
        package_root = ancestors[3]
        candidates += [
            package_root.parent / "BER_minimization_survey" / "ber_equalization.py",
            package_root / "BER_minimization_survey" / "ber_equalization.py",
        ]

    cwd = Path.cwd()
    candidates += [
        cwd / "BER_minimization_survey" / "ber_equalization.py",
        cwd / "ber_equalization.py",
    ]
    for path in candidates:
        if path.exists():
            return path
    raise FileNotFoundError(
        "Could not find legacy ber_equalization.py. Run from the repository root or pass legacy_path explicitly."
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def load_legacy_module(legacy_path: Path | None = None) -> ModuleType:
    """Import the legacy backend script as module ``_ber_equalization_legacy``.

    The loaded module is cached in ``_LEGACY_MODULE``:
      - Called without *legacy_path*, the cached module is returned if there
        is one; otherwise the script found by default_legacy_path() is loaded.
      - Called with a *legacy_path* (Path or str) pointing at the cached file,
        the cache is reused.
      - Called with a path to a different file, that file is loaded and
        cached; the previously loaded module is no longer returned.

    The script's directory is put on ``sys.path`` once so the script's own
    sibling imports resolve.

    Raises:
        FileNotFoundError: no *legacy_path* given and no default could be found.
        ImportError: no import spec could be created for the path.
        Exception: anything raised while executing the script propagates. The
            half-initialised module is then removed from ``sys.modules``, and
            any previously registered module is reinstated.
    """
    global _LEGACY_MODULE
    cached = _LEGACY_MODULE
    if cached is not None:
        if legacy_path is None:
            return cached
        cached_file = getattr(cached, "__file__", None)
        if cached_file is not None and Path(cached_file).resolve() == Path(legacy_path).resolve():
            return cached

    path = (Path(legacy_path) if legacy_path is not None else default_legacy_path()).resolve()
    script_dir = str(path.parent)
    if script_dir not in sys.path:
        sys.path.insert(0, script_dir)

    spec = importlib.util.spec_from_file_location("_ber_equalization_legacy", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import legacy module from {path}")
    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(spec.name)
    # Registered before execution so the script can refer to itself by name.
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module registered.
        if previous is None:
            sys.modules.pop(spec.name, None)
        else:
            sys.modules[spec.name] = previous
        raise
    _LEGACY_MODULE = module
    return module
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _legacy_overrides(config: StudioConfig) -> dict[str, Any]:
|
|
53
|
+
data = config.data
|
|
54
|
+
model = config.model
|
|
55
|
+
training = config.training
|
|
56
|
+
evaluation = config.evaluation
|
|
57
|
+
output = config.output
|
|
58
|
+
return {
|
|
59
|
+
"DEVICE": torch.device(config.device),
|
|
60
|
+
"DATA_DIR_CANDIDATES": [Path(path) for path in data.data_dirs],
|
|
61
|
+
"MAX_FILES": data.max_files,
|
|
62
|
+
"TRAIN_PORTION": data.train_portion,
|
|
63
|
+
"VAL_PORTION_WITHIN_TRAIN": data.val_portion_within_train,
|
|
64
|
+
"MIN_VAL_FILES": data.min_val_files,
|
|
65
|
+
"RANDOMIZE_FILE_SPLIT": data.randomize_file_split,
|
|
66
|
+
"SPLIT_SEED": data.split_seed,
|
|
67
|
+
"CONTEXT_K": data.context_k,
|
|
68
|
+
"SEQ_LEN": data.seq_len,
|
|
69
|
+
"INPUT_DIM": data.input_dim,
|
|
70
|
+
"POWER_NORMALIZE": data.power_normalize,
|
|
71
|
+
"HIDDEN_DIM": model.hidden_dim,
|
|
72
|
+
"DROPOUT": model.dropout,
|
|
73
|
+
"MLP_LAYERS": model.mlp_layers,
|
|
74
|
+
"TCN_HIDDEN_DIM": model.tcn_hidden_dim,
|
|
75
|
+
"TCN_LAYERS": model.tcn_layers,
|
|
76
|
+
"TCN_KERNEL_SIZE": model.tcn_kernel_size,
|
|
77
|
+
"FASTKAN_HIDDEN_DIM": model.fastkan_hidden_dim,
|
|
78
|
+
"FASTKAN_LAYERS": model.fastkan_layers,
|
|
79
|
+
"FASTKAN_NUM_GRIDS": model.fastkan_num_grids,
|
|
80
|
+
"FASTKAN_GRID_MIN": model.fastkan_grid_min,
|
|
81
|
+
"FASTKAN_GRID_MAX": model.fastkan_grid_max,
|
|
82
|
+
"EFFICIENT_KAN_HIDDEN_DIM": model.efficient_kan_hidden_dim,
|
|
83
|
+
"EFFICIENT_KAN_LAYERS": model.efficient_kan_layers,
|
|
84
|
+
"EFFICIENT_KAN_GRID_SIZE": model.efficient_kan_grid_size,
|
|
85
|
+
"EFFICIENT_KAN_SPLINE_ORDER": model.efficient_kan_spline_order,
|
|
86
|
+
"EFFICIENT_KAN_GRID_RANGE": list(model.efficient_kan_grid_range),
|
|
87
|
+
"COMPLEX_LIGHT_CHANNELS": model.complex_light_channels,
|
|
88
|
+
"COMPLEX_LIGHT_DILATIONS": list(model.complex_light_dilations),
|
|
89
|
+
"COMPLEX_LIGHT_KERNEL_SIZE": model.complex_light_kernel_size,
|
|
90
|
+
"KAN_PRUNE_L1": model.kan_prune_l1,
|
|
91
|
+
"KAN_PRUNE_THRESHOLD": model.kan_prune_threshold,
|
|
92
|
+
"EPOCHS": training.epochs,
|
|
93
|
+
"LEARNING_RATE": training.learning_rate,
|
|
94
|
+
"WEIGHT_DECAY": training.weight_decay,
|
|
95
|
+
"OPTIMIZER": training.optimizer,
|
|
96
|
+
"LOSS": training.loss,
|
|
97
|
+
"TRAIN_BLOCK_SIZE": training.train_block_size,
|
|
98
|
+
"MIN_BLOCK_SIZE": training.min_block_size,
|
|
99
|
+
"USE_AMP": training.use_amp,
|
|
100
|
+
"USE_TORCH_COMPILE": training.use_torch_compile,
|
|
101
|
+
"GRAD_CLIP_NORM": training.grad_clip_norm,
|
|
102
|
+
"LR_SCHEDULER": training.lr_scheduler,
|
|
103
|
+
"SCHEDULER_FACTOR": training.scheduler_factor,
|
|
104
|
+
"SCHEDULER_THRESHOLD": training.scheduler_threshold,
|
|
105
|
+
"DECAY_STEPS": training.decay_steps,
|
|
106
|
+
"MIN_LR": training.min_lr,
|
|
107
|
+
"EARLY_STOPPING": training.early_stopping,
|
|
108
|
+
"EARLY_STOPPING_PATIENCE": training.early_stopping_patience,
|
|
109
|
+
"EARLY_STOPPING_MIN_EPOCHS": training.early_stopping_min_epochs,
|
|
110
|
+
"EARLY_STOPPING_THRESHOLD": training.early_stopping_threshold,
|
|
111
|
+
"SAVE_BEST_BY": training.save_best_by,
|
|
112
|
+
"SAVE_BEST": training.save_best,
|
|
113
|
+
"LOG_EVERY": training.log_every,
|
|
114
|
+
"EVAL_BATCH_SIZE": evaluation.eval_batch_size,
|
|
115
|
+
"EVAL_TEST_DURING_TRAINING": evaluation.eval_test_during_training,
|
|
116
|
+
"TEST_BER_EVERY": evaluation.test_ber_every,
|
|
117
|
+
"COMPUTE_PER_FILE_METRICS": evaluation.compute_per_file_metrics,
|
|
118
|
+
"BER_SCALE_SEARCH": evaluation.ber_scale_search,
|
|
119
|
+
"BER_SCALE_MIN": evaluation.ber_scale_min,
|
|
120
|
+
"BER_SCALE_MAX": evaluation.ber_scale_max,
|
|
121
|
+
"BER_SCALE_STEPS": evaluation.ber_scale_steps,
|
|
122
|
+
"BER_SCALE_OFFSET": evaluation.ber_scale_offset,
|
|
123
|
+
"BER_SCALE_SAMPLES": evaluation.ber_scale_samples,
|
|
124
|
+
"EFFICIENCY_BATCH_SIZE": evaluation.efficiency_batch_size,
|
|
125
|
+
"EFFICIENCY_TIMING_WARMUP": evaluation.efficiency_timing_warmup,
|
|
126
|
+
"EFFICIENCY_TIMING_REPEATS": evaluation.efficiency_timing_repeats,
|
|
127
|
+
"OUT_DIR": output.run_dir,
|
|
128
|
+
"MODEL_TYPES": list(config.experiment.models),
|
|
129
|
+
"RUN_MAIN_EXPERIMENTS": True,
|
|
130
|
+
"RUN_SWEEP_EXPERIMENTS": False,
|
|
131
|
+
"RUN_EFFICIENT_KAN_SWEEP": False,
|
|
132
|
+
"RUN_KAN_EXPERIMENT_SUITE": False,
|
|
133
|
+
"RUN_FASTKAN_CLASSIFIER_SWEEP": False,
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@contextmanager
def configured_legacy(config: StudioConfig, legacy_path: Path | None = None) -> Iterator[ModuleType]:
    """Temporarily apply *config* to the legacy backend's ``Config`` and yield the module.

    Every override from _legacy_overrides() is set on ``module.Config`` for the
    duration of the ``with`` block. On exit, even after an error (including one
    raised while applying the overrides), each touched attribute is put back:
      - attributes that existed before get their previous value;
      - attributes that did not exist are deleted again rather than left
        behind as ``None``.
    """
    module = load_legacy_module(legacy_path)
    legacy_config = module.Config
    overrides = _legacy_overrides(config)
    missing = object()  # sentinel: attribute was absent before we touched it
    saved: dict[str, Any] = {}
    try:
        for key, value in overrides.items():
            # Record the original before overwriting so a failure part-way
            # through can still be rolled back.
            saved[key] = getattr(legacy_config, key, missing)
            setattr(legacy_config, key, value)
        yield module
    finally:
        for key, original in saved.items():
            if original is missing:
                try:
                    delattr(legacy_config, key)
                except AttributeError:
                    pass
            else:
                setattr(legacy_config, key, original)
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
|
|
9
|
+
from ber_equalization_studio.config import StudioConfig
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Shorthand / historical model names -> canonical registry names. The
# canonical name is what gets passed to the legacy backend. Lookups happen
# after canonical_model_name() strips and lower-cases the input.
MODEL_ALIASES: dict[str, str] = dict(
    ekan="efficient_kan_baseline",
    efficient_kan="efficient_kan_baseline",
    kan="complex_fastkan",
    fastkan="complex_fastkan",
    rbf_kan="complex_fastkan",
    efficient_kan_classifier="kan_classifier",
    rbf_kan_classifier="fastkan_classifier",
    complex_rbf_kan_classifier="complex_fastkan_classifier",
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# One-line architecture summaries keyed by canonical model name; served by
# available_models() and describe_model().
MODEL_NOTES: dict[str, str] = dict(
    efficient_kan_baseline="flat IQ window -> B-spline EfficientKAN -> corrected I/Q",
    efficient_kan_residual="flat IQ window -> KAN correction -> rx center + correction",
    efficient_kan_features="handcrafted IQ statistics -> EfficientKAN -> corrected I/Q",
    cnn_kan="temporal CNN encoder -> EfficientKAN head -> corrected I/Q",
    kan_classifier="flat IQ window -> EfficientKAN -> 16 constellation logits",
    fastkan_classifier="flat IQ window -> Gaussian RBF/FastKAN -> 16 logits",
    complex_fastkan="lightweight complex temporal encoder -> RBF/FastKAN -> corrected I/Q",
    complex_fastkan_classifier="lightweight complex temporal encoder -> RBF/FastKAN -> 16 logits",
    mlp="flat IQ window -> residual MLP -> corrected I/Q",
    cnn="temporal CNN encoder -> MLP head -> corrected I/Q",
    lstm="LSTM over IQ window -> corrected I/Q",
    hybrid="CNN front-end + LSTM -> corrected I/Q",
    complex_cnn="complex feature encoder + temporal CNN -> corrected I/Q",
    complex_lstm="complex feature encoder + LSTM -> corrected I/Q",
    complex_cnn_lstm="complex temporal CNN + LSTM -> corrected I/Q",
    complex_dbp_seqstat="learnable DBP-inspired front-end + sequence statistics",
    transformer="local transformer encoder over IQ window -> corrected I/Q",
    tcn="dilated temporal convolutional network -> corrected I/Q",
    mamba="Mamba sequence blocks over IQ window -> corrected I/Q",
)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def canonical_model_name(name: str) -> str:
    """Strip and lower-case *name*, then resolve it through MODEL_ALIASES.

    Names that are not aliases are returned in their normalised form.
    """
    key = name.strip().lower()
    if key in MODEL_ALIASES:
        return MODEL_ALIASES[key]
    return key
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def available_models() -> dict[str, str]:
    """Return a new ``{model name: description}`` dict sorted by model name."""
    return {model: MODEL_NOTES[model] for model in sorted(MODEL_NOTES)}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def build_model(
    name: str,
    config: StudioConfig | None = None,
    legacy_path: Path | None = None,
) -> "nn.Module":
    """Instantiate model *name* (aliases allowed) via the legacy backend.

    With a *config*, the legacy ``Config`` is temporarily overridden from it
    while ``make_model`` runs. Without one, the backend's current ``Config``
    values are used as they are.
    """
    # Deferred import: legacy.py imports torch at module level, whereas this
    # module only imports torch under TYPE_CHECKING. Keeping the import local
    # appears intended to keep models.py importable without torch.
    from ber_equalization_studio.legacy import configured_legacy, load_legacy_module

    model_name = canonical_model_name(name)
    if config is not None:
        with configured_legacy(config, legacy_path) as backend:
            return backend.make_model(model_name)
    backend = load_legacy_module(legacy_path)
    return backend.make_model(model_name)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def count_parameters(model: "nn.Module") -> int:
    """Return the number of elements across all parameters with ``requires_grad``."""
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def describe_model(name: str) -> dict[str, Any]:
    """Return ``{"name": canonical name, "description": one-line summary}``.

    Names without an entry in MODEL_NOTES get a generic description.
    """
    canonical = canonical_model_name(name)
    description = MODEL_NOTES.get(canonical, "custom or legacy model")
    return {"name": canonical, "description": description}
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Columns moved to the front of result tables, in this order (only those that
# are present). All remaining columns follow in their existing order.
PREFERRED_COLUMNS = """
    model_type equalized_ber baseline_ber improvement_rel improvement_db
    accuracy ser best_val_ber trainable_params batch_time_sec
    samples_per_sec train_samples_per_sec epochs_ran stop_reason
""".split()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _json_default(value: Any) -> Any:
    """``json.dumps`` fallback for values the standard encoder rejects.

    Converts ``Path`` to ``str``. Array/tensor-like objects (anything with
    ``tolist`` or ``item``, e.g. NumPy scalars and arrays or torch tensors)
    become plain Python values. Anything else raises ``TypeError`` exactly as
    ``json.dumps`` would without a fallback.
    """
    if isinstance(value, Path):
        return str(value)
    if hasattr(value, "tolist"):
        return value.tolist()
    if hasattr(value, "item"):
        return value.item()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def flatten_metrics(metrics: dict[str, Any]) -> dict[str, Any]:
    """Flatten a metrics dict into CSV-friendly scalar columns.

    - ``None`` and plain ``str``/``int``/``float``/``bool`` values are kept.
    - ``Path`` values become strings.
    - Values whose ``.item()`` succeeds (single-element arrays/tensors, NumPy
      scalars) become Python scalars.
    - Everything else (lists, dicts, multi-element arrays, ...) is JSON-encoded
      to a string. Nested NumPy/torch values and paths are converted rather
      than crashing the encoder.
    """
    flat: dict[str, Any] = {}
    for key, value in metrics.items():
        if value is None or isinstance(value, (str, int, float, bool)):
            flat[key] = value
            continue
        if isinstance(value, Path):
            flat[key] = str(value)
            continue
        if hasattr(value, "item"):
            try:
                flat[key] = value.item()
                continue
            except (ValueError, RuntimeError):
                # NumPy raises ValueError and torch RuntimeError for
                # multi-element inputs; fall through to JSON encoding.
                pass
        flat[key] = json.dumps(value, ensure_ascii=True, default=_json_default)
    return flat


def save_json(path: Path, payload: dict[str, Any]) -> None:
    """Write *payload* to *path* as indented ASCII JSON, creating parent dirs.

    NumPy/torch values and ``Path`` objects anywhere in *payload* (e.g. in
    backend-produced metrics or histories) are converted by _json_default
    instead of raising ``TypeError``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    text = json.dumps(payload, indent=2, ensure_ascii=True, default=_json_default)
    path.write_text(text, encoding="utf-8")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def load_result_table(path: Path) -> pd.DataFrame:
    """Load result rows from a CSV file, or from all result CSVs below a directory.

    For a directory, every ``results.csv`` found recursively (in sorted path
    order) is concatenated. If there are none, files matching
    ``*summary*.csv`` are used instead. *path* may be a ``str`` or ``Path``.

    Raises:
        FileNotFoundError: *path* is a directory without any matching CSV, or a
            file path that does not exist (raised by ``pd.read_csv``).
    """
    path = Path(path)
    if not path.is_dir():
        return pd.read_csv(path)
    candidates = sorted(path.glob("**/results.csv")) or sorted(path.glob("**/*summary*.csv"))
    if not candidates:
        raise FileNotFoundError(f"No result CSV files found below {path}")
    return pd.concat([pd.read_csv(candidate) for candidate in candidates], ignore_index=True)


def compare_results(path: Path, sort_by: str = "equalized_ber") -> pd.DataFrame:
    """Load results (see load_result_table) and rank them by *sort_by*.

    Metrics where higher is better (improvements, accuracy, throughput) are
    sorted descending; everything else ascending. The sort is stable so ties
    keep their load order. If *sort_by* is not a column, row order is left
    unchanged. Columns in PREFERRED_COLUMNS come first.
    """
    # Throughput metrics (both inference and training) are "higher is better".
    higher_is_better = {
        "improvement_rel",
        "improvement_db",
        "accuracy",
        "samples_per_sec",
        "train_samples_per_sec",
    }
    df = load_result_table(path)
    if sort_by in df.columns:
        df = df.sort_values(sort_by, ascending=sort_by not in higher_is_better, kind="stable")
    ordered = [column for column in PREFERRED_COLUMNS if column in df.columns]
    rest = [column for column in df.columns if column not in ordered]
    return df[ordered + rest]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def write_result_table(path: Path, rows: list[dict[str, Any]]) -> pd.DataFrame:
    """Flatten *rows* into a table, order its columns and write it to *path* as CSV.

    Each row goes through flatten_metrics(). Columns listed in
    PREFERRED_COLUMNS come first, then the rest in their original order.
    Parent directories are created as needed and the index is not written.
    Returns the DataFrame that was written.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    table = pd.DataFrame(list(map(flatten_metrics, rows)))
    leading = [name for name in PREFERRED_COLUMNS if name in table.columns]
    trailing = [name for name in table.columns if name not in leading]
    table = table[leading + trailing]
    table.to_csv(path, index=False)
    return table
|