ber-equalization-studio 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,168 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import asdict, dataclass, field
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+
8
+ def _default_device() -> str:
9
+ try:
10
+ import torch
11
+
12
+ if torch.cuda.is_available():
13
+ return "cuda"
14
+ if torch.backends.mps.is_available():
15
+ return "mps"
16
+ except Exception:
17
+ pass
18
+ return "cpu"
19
+
20
+
21
+ @dataclass(slots=True)
22
+ class DataConfig:
23
+ data_dirs: list[Path] = field(
24
+ default_factory=lambda: [Path("symbols_new"), Path("Symbols_1m_1ch_PR"), Path(".")]
25
+ )
26
+ max_files: int = 64
27
+ train_portion: float = 0.97
28
+ val_portion_within_train: float = 0.10
29
+ min_val_files: int = 1
30
+ randomize_file_split: bool = False
31
+ split_seed: int = 42
32
+ context_k: int = 32
33
+ input_dim: int = 2
34
+ power_normalize: bool = True
35
+ max_test_files: int | None = None
36
+
37
+ def __post_init__(self) -> None:
38
+ self.data_dirs = [Path(path) for path in self.data_dirs]
39
+
40
+ @property
41
+ def seq_len(self) -> int:
42
+ return 2 * self.context_k + 1
43
+
44
+
45
@dataclass(slots=True)
class ModelConfig:
    """Model selection and architecture hyperparameters.

    Every field except ``name`` is forwarded to the legacy backend's
    ``Config`` object under the upper-cased field name by
    ``ber_equalization_studio.legacy._legacy_overrides``.
    """

    # Model identifier; the experiment runner replaces it with the canonical
    # name of each model it trains.
    name: str = "complex_fastkan"
    hidden_dim: int = 96
    dropout: float = 0.2
    mlp_layers: int = 3
    tcn_hidden_dim: int = 96
    tcn_layers: int = 5
    tcn_kernel_size: int = 5
    fastkan_hidden_dim: int = 96
    fastkan_layers: int = 2
    fastkan_num_grids: int = 8
    fastkan_grid_min: float = -2.5
    fastkan_grid_max: float = 2.5
    efficient_kan_hidden_dim: int = 128
    efficient_kan_layers: int = 2
    efficient_kan_grid_size: int = 8
    efficient_kan_spline_order: int = 3
    # Forwarded to the legacy backend as a list, not a tuple.
    efficient_kan_grid_range: tuple[float, float] = (-3.0, 3.0)
    complex_light_channels: int = 48
    # Forwarded to the legacy backend as a list, not a tuple.
    complex_light_dilations: tuple[int, ...] = (1, 2, 4)
    complex_light_kernel_size: int = 3
    kan_prune_l1: float = 1e-5
    kan_prune_threshold: float = 0.02
69
+
70
+
71
@dataclass(slots=True)
class TrainingConfig:
    """Optimisation, scheduling and early-stopping settings.

    Every field except ``seed`` is forwarded to the legacy backend's
    ``Config`` object under the upper-cased field name by
    ``ber_equalization_studio.legacy._legacy_overrides``.
    """

    epochs: int = 250
    learning_rate: float = 1e-3
    weight_decay: float = 0.0
    optimizer: str = "adam"
    loss: str = "mse"
    train_block_size: int = 8192
    min_block_size: int = 1024
    use_amp: bool = True
    use_torch_compile: bool = False
    grad_clip_norm: float = 1.0
    lr_scheduler: str = "notebook_decay"
    scheduler_factor: float = 0.5
    scheduler_threshold: float = 1e-6
    decay_steps: int = 24
    min_lr: float = 1e-5
    early_stopping: bool = True
    early_stopping_patience: int = 72
    early_stopping_min_epochs: int = 40
    early_stopping_threshold: float = 0.0
    save_best_by: str = "val_ber"
    save_best: bool = True
    log_every: int = 1
    # Seeds random, numpy and torch at the start of ExperimentRunner.run().
    seed: int = 42
96
+
97
+
98
@dataclass(slots=True)
class EvaluationConfig:
    """Evaluation, BER scale-search and timing settings.

    Every field is forwarded to the legacy backend's ``Config`` object under
    the upper-cased field name by
    ``ber_equalization_studio.legacy._legacy_overrides``.
    """

    eval_batch_size: int = 65536
    eval_test_during_training: bool = False
    test_ber_every: int = 10
    compute_per_file_metrics: bool = True
    ber_scale_search: bool = True
    ber_scale_min: float = 0.5
    ber_scale_max: float = 1.5
    ber_scale_steps: int = 10
    ber_scale_offset: int = 10000
    ber_scale_samples: int = 1 << 20  # 1,048,576
    efficiency_batch_size: int = 16000
    efficiency_timing_warmup: int = 5
    efficiency_timing_repeats: int = 20
113
+
114
+
115
+ @dataclass(slots=True)
116
+ class OutputConfig:
117
+ out_dir: Path = Path("studio_outputs")
118
+ experiment_name: str = "experiment"
119
+ save_checkpoint: bool = True
120
+ save_history: bool = True
121
+ save_plots: bool = True
122
+ plot_backend: str = "plotly"
123
+
124
+ def __post_init__(self) -> None:
125
+ self.out_dir = Path(self.out_dir)
126
+
127
+ @property
128
+ def run_dir(self) -> Path:
129
+ return self.out_dir / self.experiment_name
130
+
131
+
132
@dataclass(slots=True)
class ExperimentConfig:
    """Which models an experiment trains, plus free-form metadata."""

    # Model names trained in order by ExperimentRunner.run(); when empty the
    # runner falls back to ModelConfig.name.
    models: list[str] = field(default_factory=lambda: ["complex_fastkan"])
    # NOTE(review): tags and notes are not read by any code visible here;
    # presumably descriptive metadata that is only serialised with the config.
    tags: dict[str, str] = field(default_factory=dict)
    notes: str = ""
137
+
138
+
139
@dataclass(slots=True)
class StudioConfig:
    """Top-level configuration: one object per section plus the device name."""

    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
    output: OutputConfig = field(default_factory=OutputConfig)
    experiment: ExperimentConfig = field(default_factory=ExperimentConfig)
    device: str = field(default_factory=_default_device)

    def to_dict(self) -> dict[str, Any]:
        """Return a nested plain-dict copy with ``Path`` values as strings."""
        raw = asdict(self)
        raw["data"]["data_dirs"] = [str(path) for path in self.data.data_dirs]
        raw["output"]["out_dir"] = str(self.output.out_dir)
        return raw

    def with_updates(self, **updates: Any) -> "StudioConfig":
        """Return a new config with the given fields replaced.

        Keys are either ``"section.field"`` (for example
        ``"training.epochs"``) or the name of a top-level scalar field
        (currently ``"device"``). Because the keys contain dots they are
        normally passed by unpacking a dict::

            cfg.with_updates(**{"training.epochs": 10, "device": "cpu"})

        Raises:
            ValueError: an un-dotted key does not name a top-level scalar.
            KeyError: the section part of a dotted key is unknown.
            TypeError: the field part is unknown (raised by the section
                constructor).
        """
        values = self.to_dict()
        for dotted_key, value in updates.items():
            section, dot, key = dotted_key.partition(".")
            if dot:
                values[section][key] = value
                continue
            # Un-dotted key: only top-level scalars such as "device" qualify;
            # section names map to dicts and must be addressed per field.
            if section not in values or isinstance(values[section], dict):
                raise ValueError(
                    f"Invalid update key {dotted_key!r}: expected 'section.field' "
                    "or a top-level scalar field such as 'device'"
                )
            values[section] = value
        return StudioConfig(
            data=DataConfig(**values["data"]),
            model=ModelConfig(**values["model"]),
            training=TrainingConfig(**values["training"]),
            evaluation=EvaluationConfig(**values["evaluation"]),
            output=OutputConfig(**values["output"]),
            experiment=ExperimentConfig(**values["experiment"]),
            device=values["device"],
        )
@@ -0,0 +1,31 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ from ber_equalization_studio.config import StudioConfig
7
+ from ber_equalization_studio.legacy import configured_legacy
8
+
9
+
10
def prepare_dataset(config: StudioConfig, legacy_path: Path | None = None) -> dict[str, Any]:
    """Run the legacy backend's ``prepare_data`` under *config*.

    The legacy ``Config`` is overridden only for the duration of the call;
    ``config.data.max_test_files`` is passed through as ``max_test_files``.
    """
    test_file_cap = config.data.max_test_files
    with configured_legacy(config, legacy_path) as backend:
        prepared = backend.prepare_data(max_test_files=test_file_cap)
    return prepared
13
+
14
+
15
def dataset_summary(data: dict[str, Any]) -> dict[str, Any]:
    """Summarise a prepared dataset dict into plain Python values.

    For each of the ``train``/``val``/``test`` splits that is present this
    records the first-dimension size of ``<split>_x`` and ``<split>_y`` and
    the ``file_idx`` of every entry in ``<split>_file_spans``. ``tx_scale``
    and ``rx_scale`` are copied as floats when present.
    """
    summary: dict[str, Any] = {}
    for split in ("train", "val", "test"):
        for suffix, label in (("x", "samples"), ("y", "targets")):
            tensor_key = f"{split}_{suffix}"
            if tensor_key in data:
                summary[f"{split}_{label}"] = int(data[tensor_key].size(0))
        file_spans = data.get(f"{split}_file_spans", [])
        if file_spans:
            summary[f"{split}_files"] = [int(span["file_idx"]) for span in file_spans]

    for scale_key in ("tx_scale", "rx_scale"):
        if scale_key not in data:
            continue
        raw = data[scale_key]
        # Tensor-like scalars expose .item(); plain numbers are used as-is.
        if hasattr(raw, "item"):
            raw = raw.item()
        summary[scale_key] = float(raw)
    return summary
@@ -0,0 +1,92 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import json
5
+ import random
6
+ from dataclasses import replace
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from ber_equalization_studio.config import StudioConfig
14
+ from ber_equalization_studio.data import dataset_summary, prepare_dataset
15
+ from ber_equalization_studio.legacy import configured_legacy
16
+ from ber_equalization_studio.models import canonical_model_name
17
+ from ber_equalization_studio.results import save_json, write_result_table
18
+ from ber_equalization_studio.visualization import plot_comparison, plot_history
19
+
20
+
21
def set_reproducible_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and torch (CPU and, if present, CUDA)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
27
+
28
+
29
class ExperimentRunner:
    """Train every requested model for one StudioConfig and collect results.

    Artifacts are written below ``config.output.run_dir``, with one
    sub-directory per canonical model name.

    NOTE(review): the original indentation was lost in the diff rendering
    this block was recovered from; the nesting below is the most plausible
    reading (in particular, ``metrics.json`` is written regardless of
    ``save_history``) -- confirm against the packaged source.
    """

    def __init__(self, config: StudioConfig, legacy_path: Path | None = None):
        self.config = config
        self.legacy_path = legacy_path
        self.run_dir = config.output.run_dir

    def _model_config(self, model_name: str) -> StudioConfig:
        """Return a copy of the config whose ``model.name`` is canonicalised *model_name*."""
        model = replace(self.config.model, name=canonical_model_name(model_name))
        return replace(self.config, model=model)

    def run(self) -> dict[str, Any]:
        """Prepare the data once, then train each model and write its artifacts.

        Returns a dict with ``run_dir``, ``results`` (one metrics dict per
        model), ``histories`` (keyed by canonical model name) and
        ``results_csv`` (path of the combined table).
        """
        set_reproducible_seed(self.config.training.seed)
        self.run_dir.mkdir(parents=True, exist_ok=True)
        save_json(self.run_dir / "config.json", self.config.to_dict())

        # The dataset is prepared once and shared by every model below.
        data = prepare_dataset(self.config, self.legacy_path)
        save_json(self.run_dir / "dataset_summary.json", dataset_summary(data))

        rows: list[dict[str, Any]] = []
        histories: dict[str, dict[str, Any]] = {}
        # Fall back to the single configured model when no list is given.
        models = self.config.experiment.models or [self.config.model.name]

        for requested_model_name in models:
            model_name = canonical_model_name(requested_model_name)
            model_config = self._model_config(model_name)
            model_run_dir = self.run_dir / model_name
            model_run_dir.mkdir(parents=True, exist_ok=True)

            with configured_legacy(model_config, self.legacy_path) as legacy:
                # Redirect the legacy backend's output into the per-model
                # directory; OUT_DIR is among the overridden keys, so
                # configured_legacy restores it on exit.
                legacy.Config.OUT_DIR = model_run_dir
                legacy.Config.OUT_DIR.mkdir(parents=True, exist_ok=True)
                model, history, metrics = legacy.train_one_model(model_name, data)

            if model_config.output.save_checkpoint:
                torch.save(model.state_dict(), model_run_dir / "final_state_dict.pt")

            # Work on a copy so the dict returned by the backend is not mutated.
            metrics = copy.deepcopy(metrics)
            metrics["model_type"] = model_name
            metrics["experiment_name"] = self.config.output.experiment_name
            rows.append(metrics)
            histories[model_name] = history

            if self.config.output.save_history:
                save_json(model_run_dir / "history.json", history)
            save_json(model_run_dir / "metrics.json", metrics)
            if self.config.output.save_plots:
                plot_history(history, model_name, model_run_dir)

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        df = write_result_table(self.run_dir / "results.csv", rows)
        if self.config.output.save_plots:
            plot_comparison(df, self.run_dir)
        return {
            "run_dir": self.run_dir,
            "results": rows,
            "histories": histories,
            "results_csv": self.run_dir / "results.csv",
        }
89
+
90
+
91
def run_experiment(config: StudioConfig, legacy_path: Path | None = None) -> dict[str, Any]:
    """Build an ``ExperimentRunner`` for *config* and return the result of its ``run()``."""
    runner = ExperimentRunner(config=config, legacy_path=legacy_path)
    return runner.run()
@@ -0,0 +1,149 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib.util
4
+ import sys
5
+ from contextlib import contextmanager
6
+ from pathlib import Path
7
+ from types import ModuleType
8
+ from typing import Any, Iterator
9
+
10
+ import torch
11
+
12
+ from ber_equalization_studio.config import StudioConfig
13
+
14
# Process-wide cache of the imported legacy backend; populated on the first
# successful load_legacy_module() call and reused afterwards.
_LEGACY_MODULE: ModuleType | None = None
15
+
16
+
17
def default_legacy_path() -> Path:
    """Locate the legacy ``ber_equalization.py`` backend script.

    Candidates are checked in order: the copy bundled next to this module
    (``_legacy_backend/``), ``BER_minimization_survey/`` beside or inside
    the directory three levels above this module's directory, then
    ``BER_minimization_survey/`` and the bare file in the current working
    directory.

    Returns:
        The first candidate path that exists.

    Raises:
        FileNotFoundError: none of the candidates exists.
    """
    here = Path(__file__).resolve()
    filename = "ber_equalization.py"
    survey_dir = "BER_minimization_survey"

    candidates = [here.parent / "_legacy_backend" / filename]
    # A shallow install path has fewer than four ancestors; indexing
    # parents[3] unconditionally would raise IndexError before any candidate
    # (including the bundled backend) is checked.
    ancestors = here.parents
    if len(ancestors) > 3:
        package_root = ancestors[3]
        candidates.append(package_root.parent / survey_dir / filename)
        candidates.append(package_root / survey_dir / filename)
    working_dir = Path.cwd()
    candidates.append(working_dir / survey_dir / filename)
    candidates.append(working_dir / filename)

    for path in candidates:
        if path.exists():
            return path
    raise FileNotFoundError(
        "Could not find legacy ber_equalization.py. Run from the repository root or pass legacy_path explicitly."
    )
33
+
34
+
35
def load_legacy_module(legacy_path: Path | None = None) -> ModuleType:
    """Import the legacy backend script once and cache it for the process.

    The first successful call imports the file as the module
    ``_ber_equalization_legacy`` and puts its directory at the front of
    ``sys.path``. Every later call returns the cached module and ignores
    *legacy_path*.

    Args:
        legacy_path: Location of ``ber_equalization.py``; when ``None`` the
            path comes from ``default_legacy_path()``.

    Raises:
        FileNotFoundError: no path was given and no default could be found.
        ImportError: no import spec/loader could be created for the path.
    """
    global _LEGACY_MODULE
    if _LEGACY_MODULE is not None:
        return _LEGACY_MODULE

    # Path() also accepts a plain string, which would otherwise fail on .resolve().
    path = (Path(legacy_path) if legacy_path else default_legacy_path()).resolve()
    search_dir = str(path.parent)
    # Keep the script's directory first in sys.path without stacking a
    # duplicate entry on every retry after a failed import.
    if not sys.path or sys.path[0] != search_dir:
        sys.path.insert(0, search_dir)
    spec = importlib.util.spec_from_file_location("_ber_equalization_legacy", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import legacy module from {path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module can be found by name while its
    # own top-level code runs.
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind in sys.modules.
        sys.modules.pop(spec.name, None)
        raise
    _LEGACY_MODULE = module
    return module
50
+
51
+
52
+ def _legacy_overrides(config: StudioConfig) -> dict[str, Any]:
53
+ data = config.data
54
+ model = config.model
55
+ training = config.training
56
+ evaluation = config.evaluation
57
+ output = config.output
58
+ return {
59
+ "DEVICE": torch.device(config.device),
60
+ "DATA_DIR_CANDIDATES": [Path(path) for path in data.data_dirs],
61
+ "MAX_FILES": data.max_files,
62
+ "TRAIN_PORTION": data.train_portion,
63
+ "VAL_PORTION_WITHIN_TRAIN": data.val_portion_within_train,
64
+ "MIN_VAL_FILES": data.min_val_files,
65
+ "RANDOMIZE_FILE_SPLIT": data.randomize_file_split,
66
+ "SPLIT_SEED": data.split_seed,
67
+ "CONTEXT_K": data.context_k,
68
+ "SEQ_LEN": data.seq_len,
69
+ "INPUT_DIM": data.input_dim,
70
+ "POWER_NORMALIZE": data.power_normalize,
71
+ "HIDDEN_DIM": model.hidden_dim,
72
+ "DROPOUT": model.dropout,
73
+ "MLP_LAYERS": model.mlp_layers,
74
+ "TCN_HIDDEN_DIM": model.tcn_hidden_dim,
75
+ "TCN_LAYERS": model.tcn_layers,
76
+ "TCN_KERNEL_SIZE": model.tcn_kernel_size,
77
+ "FASTKAN_HIDDEN_DIM": model.fastkan_hidden_dim,
78
+ "FASTKAN_LAYERS": model.fastkan_layers,
79
+ "FASTKAN_NUM_GRIDS": model.fastkan_num_grids,
80
+ "FASTKAN_GRID_MIN": model.fastkan_grid_min,
81
+ "FASTKAN_GRID_MAX": model.fastkan_grid_max,
82
+ "EFFICIENT_KAN_HIDDEN_DIM": model.efficient_kan_hidden_dim,
83
+ "EFFICIENT_KAN_LAYERS": model.efficient_kan_layers,
84
+ "EFFICIENT_KAN_GRID_SIZE": model.efficient_kan_grid_size,
85
+ "EFFICIENT_KAN_SPLINE_ORDER": model.efficient_kan_spline_order,
86
+ "EFFICIENT_KAN_GRID_RANGE": list(model.efficient_kan_grid_range),
87
+ "COMPLEX_LIGHT_CHANNELS": model.complex_light_channels,
88
+ "COMPLEX_LIGHT_DILATIONS": list(model.complex_light_dilations),
89
+ "COMPLEX_LIGHT_KERNEL_SIZE": model.complex_light_kernel_size,
90
+ "KAN_PRUNE_L1": model.kan_prune_l1,
91
+ "KAN_PRUNE_THRESHOLD": model.kan_prune_threshold,
92
+ "EPOCHS": training.epochs,
93
+ "LEARNING_RATE": training.learning_rate,
94
+ "WEIGHT_DECAY": training.weight_decay,
95
+ "OPTIMIZER": training.optimizer,
96
+ "LOSS": training.loss,
97
+ "TRAIN_BLOCK_SIZE": training.train_block_size,
98
+ "MIN_BLOCK_SIZE": training.min_block_size,
99
+ "USE_AMP": training.use_amp,
100
+ "USE_TORCH_COMPILE": training.use_torch_compile,
101
+ "GRAD_CLIP_NORM": training.grad_clip_norm,
102
+ "LR_SCHEDULER": training.lr_scheduler,
103
+ "SCHEDULER_FACTOR": training.scheduler_factor,
104
+ "SCHEDULER_THRESHOLD": training.scheduler_threshold,
105
+ "DECAY_STEPS": training.decay_steps,
106
+ "MIN_LR": training.min_lr,
107
+ "EARLY_STOPPING": training.early_stopping,
108
+ "EARLY_STOPPING_PATIENCE": training.early_stopping_patience,
109
+ "EARLY_STOPPING_MIN_EPOCHS": training.early_stopping_min_epochs,
110
+ "EARLY_STOPPING_THRESHOLD": training.early_stopping_threshold,
111
+ "SAVE_BEST_BY": training.save_best_by,
112
+ "SAVE_BEST": training.save_best,
113
+ "LOG_EVERY": training.log_every,
114
+ "EVAL_BATCH_SIZE": evaluation.eval_batch_size,
115
+ "EVAL_TEST_DURING_TRAINING": evaluation.eval_test_during_training,
116
+ "TEST_BER_EVERY": evaluation.test_ber_every,
117
+ "COMPUTE_PER_FILE_METRICS": evaluation.compute_per_file_metrics,
118
+ "BER_SCALE_SEARCH": evaluation.ber_scale_search,
119
+ "BER_SCALE_MIN": evaluation.ber_scale_min,
120
+ "BER_SCALE_MAX": evaluation.ber_scale_max,
121
+ "BER_SCALE_STEPS": evaluation.ber_scale_steps,
122
+ "BER_SCALE_OFFSET": evaluation.ber_scale_offset,
123
+ "BER_SCALE_SAMPLES": evaluation.ber_scale_samples,
124
+ "EFFICIENCY_BATCH_SIZE": evaluation.efficiency_batch_size,
125
+ "EFFICIENCY_TIMING_WARMUP": evaluation.efficiency_timing_warmup,
126
+ "EFFICIENCY_TIMING_REPEATS": evaluation.efficiency_timing_repeats,
127
+ "OUT_DIR": output.run_dir,
128
+ "MODEL_TYPES": list(config.experiment.models),
129
+ "RUN_MAIN_EXPERIMENTS": True,
130
+ "RUN_SWEEP_EXPERIMENTS": False,
131
+ "RUN_EFFICIENT_KAN_SWEEP": False,
132
+ "RUN_KAN_EXPERIMENT_SUITE": False,
133
+ "RUN_FASTKAN_CLASSIFIER_SWEEP": False,
134
+ }
135
+
136
+
137
@contextmanager
def configured_legacy(config: StudioConfig, legacy_path: Path | None = None) -> Iterator[ModuleType]:
    """Temporarily apply *config* to the legacy backend's ``Config`` object.

    Yields the loaded legacy module with every key from
    ``_legacy_overrides(config)`` set on ``module.Config``. On exit, whether
    normal or via an exception, each key gets back the value it had before;
    keys that did not exist beforehand are removed again.
    """
    module = load_legacy_module(legacy_path)
    legacy_config = module.Config
    overrides = _legacy_overrides(config)
    # A private sentinel separates "attribute was absent" from "attribute was
    # None", so an absent attribute is not restored as a spurious None.
    absent = object()
    previous = {key: getattr(legacy_config, key, absent) for key in overrides}
    try:
        # Applied inside the try so a failure midway still restores
        # everything that was already overridden.
        for key, value in overrides.items():
            setattr(legacy_config, key, value)
        yield module
    finally:
        for key, value in previous.items():
            if value is not absent:
                setattr(legacy_config, key, value)
                continue
            try:
                delattr(legacy_config, key)
            except AttributeError:
                # The override was never applied (or was already removed).
                pass
@@ -0,0 +1,86 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ if TYPE_CHECKING:
7
+ import torch.nn as nn
8
+
9
+ from ber_equalization_studio.config import StudioConfig
10
+
11
+
12
# Alternative spellings mapped to canonical model names; looked up by
# canonical_model_name() after the input is stripped and lower-cased.
MODEL_ALIASES: dict[str, str] = {
    "ekan": "efficient_kan_baseline",
    "efficient_kan": "efficient_kan_baseline",
    "kan": "complex_fastkan",
    "fastkan": "complex_fastkan",
    "rbf_kan": "complex_fastkan",
    "efficient_kan_classifier": "kan_classifier",
    "rbf_kan_classifier": "fastkan_classifier",
    "complex_rbf_kan_classifier": "complex_fastkan_classifier",
}
22
+
23
+
24
# One-line description per canonical model name; returned (sorted by name) by
# available_models() and used by describe_model().
MODEL_NOTES: dict[str, str] = {
    "efficient_kan_baseline": "flat IQ window -> B-spline EfficientKAN -> corrected I/Q",
    "efficient_kan_residual": "flat IQ window -> KAN correction -> rx center + correction",
    "efficient_kan_features": "handcrafted IQ statistics -> EfficientKAN -> corrected I/Q",
    "cnn_kan": "temporal CNN encoder -> EfficientKAN head -> corrected I/Q",
    "kan_classifier": "flat IQ window -> EfficientKAN -> 16 constellation logits",
    "fastkan_classifier": "flat IQ window -> Gaussian RBF/FastKAN -> 16 logits",
    "complex_fastkan": "lightweight complex temporal encoder -> RBF/FastKAN -> corrected I/Q",
    "complex_fastkan_classifier": "lightweight complex temporal encoder -> RBF/FastKAN -> 16 logits",
    "mlp": "flat IQ window -> residual MLP -> corrected I/Q",
    "cnn": "temporal CNN encoder -> MLP head -> corrected I/Q",
    "lstm": "LSTM over IQ window -> corrected I/Q",
    "hybrid": "CNN front-end + LSTM -> corrected I/Q",
    "complex_cnn": "complex feature encoder + temporal CNN -> corrected I/Q",
    "complex_lstm": "complex feature encoder + LSTM -> corrected I/Q",
    "complex_cnn_lstm": "complex temporal CNN + LSTM -> corrected I/Q",
    "complex_dbp_seqstat": "learnable DBP-inspired front-end + sequence statistics",
    "transformer": "local transformer encoder over IQ window -> corrected I/Q",
    "tcn": "dilated temporal convolutional network -> corrected I/Q",
    "mamba": "Mamba sequence blocks over IQ window -> corrected I/Q",
}
45
+
46
+
47
def canonical_model_name(name: str) -> str:
    """Return the canonical form of *name*.

    The name is stripped and lower-cased, then resolved through
    ``MODEL_ALIASES``; names without an alias are returned in their
    normalised form.
    """
    key = name.strip().lower()
    try:
        return MODEL_ALIASES[key]
    except KeyError:
        return key
50
+
51
+
52
def available_models() -> dict[str, str]:
    """Return a fresh dict of model name -> description, ordered by name."""
    return {model_name: MODEL_NOTES[model_name] for model_name in sorted(MODEL_NOTES)}
54
+
55
+
56
def build_model(
    name: str,
    config: StudioConfig | None = None,
    legacy_path: Path | None = None,
) -> "nn.Module":
    """Build a model from the studio registry.

    The current backend reuses the verified model implementations from
    `ber_equalization.py`, while the public API is driven by StudioConfig.

    When *config* is given, the legacy ``Config`` is overridden with it only
    while the model is constructed; without a config the legacy module is
    used in whatever state it is currently in.
    """

    from ber_equalization_studio.legacy import configured_legacy, load_legacy_module

    model_name = canonical_model_name(name)
    if config is not None:
        with configured_legacy(config, legacy_path) as backend:
            return backend.make_model(model_name)
    backend = load_legacy_module(legacy_path)
    return backend.make_model(model_name)
75
+
76
+
77
def count_parameters(model: "nn.Module") -> int:
    """Return the total element count of the parameters that require gradients."""
    total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            total += parameter.numel()
    return total
79
+
80
+
81
def describe_model(name: str) -> dict[str, Any]:
    """Return the canonical name and registry description for *name*.

    Models without a ``MODEL_NOTES`` entry get the description
    ``"custom or legacy model"``.
    """
    canonical = canonical_model_name(name)
    description = MODEL_NOTES.get(canonical, "custom or legacy model")
    return dict(name=canonical, description=description)
@@ -0,0 +1,74 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import pandas as pd
8
+
9
+
10
# Columns placed first, in this order, by compare_results() and
# write_result_table(); any other columns follow in their existing order.
PREFERRED_COLUMNS = [
    "model_type",
    "equalized_ber",
    "baseline_ber",
    "improvement_rel",
    "improvement_db",
    "accuracy",
    "ser",
    "best_val_ber",
    "trainable_params",
    "batch_time_sec",
    "samples_per_sec",
    "train_samples_per_sec",
    "epochs_ran",
    "stop_reason",
]
26
+
27
+
28
def _json_fallback(value: Any) -> Any:
    """Convert a value ``json`` cannot encode: ``tolist()`` if offered, else ``str``."""
    if hasattr(value, "tolist"):
        return value.tolist()
    return str(value)


def flatten_metrics(metrics: dict[str, Any]) -> dict[str, Any]:
    """Flatten a metrics dict into CSV-friendly scalar values.

    ``str``/``int``/``float``/``bool``/``None`` values are kept unchanged.
    Objects exposing ``item()`` are unwrapped through it. Everything else,
    including multi-element arrays whose ``item()`` fails, is stored as a
    JSON string.
    """
    flat: dict[str, Any] = {}
    for key, value in metrics.items():
        if value is None or isinstance(value, (str, int, float, bool)):
            flat[key] = value
            continue
        if hasattr(value, "item"):
            try:
                flat[key] = value.item()
                continue
            except (ValueError, RuntimeError):
                # item() only works for single-element containers; larger
                # ones fall through to the JSON encoding below.
                pass
        # The fallback keeps one unusual value (a Path, an array, ...) from
        # aborting the whole results table.
        flat[key] = json.dumps(value, ensure_ascii=True, default=_json_fallback)
    return flat
38
+
39
+
40
def save_json(path: Path, payload: dict[str, Any]) -> None:
    """Write *payload* to *path* as indented, ASCII-only JSON (UTF-8 file).

    Missing parent directories are created first.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    text = json.dumps(payload, indent=2, ensure_ascii=True)
    path.write_text(text, encoding="utf-8")
43
+
44
+
45
def load_result_table(path: Path) -> pd.DataFrame:
    """Load results from one CSV file or from every result CSV below a directory.

    For a directory, all ``results.csv`` files found recursively are
    concatenated in sorted path order; if there are none, files matching
    ``*summary*.csv`` are used instead.

    Raises:
        FileNotFoundError: *path* is a directory containing neither kind of file.
    """
    if not path.is_dir():
        return pd.read_csv(path)

    csv_files = sorted(path.glob("**/results.csv")) or sorted(path.glob("**/*summary*.csv"))
    if not csv_files:
        raise FileNotFoundError(f"No result CSV files found below {path}")
    tables = [pd.read_csv(csv_file) for csv_file in csv_files]
    return pd.concat(tables, ignore_index=True)
55
+
56
+
57
def compare_results(path: Path, sort_by: str = "equalized_ber") -> pd.DataFrame:
    """Load results from *path*, sort them and put preferred columns first.

    Sorting is skipped when *sort_by* is not a column. The order is
    descending for ``improvement_rel``, ``improvement_db``, ``accuracy`` and
    ``samples_per_sec``, and ascending for every other column.
    """
    table = load_result_table(path)
    descending_metrics = {"improvement_rel", "improvement_db", "accuracy", "samples_per_sec"}
    if sort_by in table.columns:
        table = table.sort_values(sort_by, ascending=sort_by not in descending_metrics)

    leading = [name for name in PREFERRED_COLUMNS if name in table.columns]
    trailing = [name for name in table.columns if name not in leading]
    return table[leading + trailing]
65
+
66
+
67
def write_result_table(path: Path, rows: list[dict[str, Any]]) -> pd.DataFrame:
    """Flatten *rows*, write them to *path* as CSV and return the DataFrame.

    Parent directories are created as needed. Columns listed in
    ``PREFERRED_COLUMNS`` come first; the rest keep their existing order.
    The CSV is written without the index.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    flattened = [flatten_metrics(row) for row in rows]
    table = pd.DataFrame(flattened)

    leading = [name for name in PREFERRED_COLUMNS if name in table.columns]
    trailing = [name for name in table.columns if name not in leading]
    table = table[leading + trailing]
    table.to_csv(path, index=False)
    return table