ber-equalization-studio 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
+ from .kan import KAN, KANLinear
2
+
3
+ __all__ = ["KANLinear", "KAN"]
@@ -0,0 +1,218 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
class KANLinear(torch.nn.Module):
    """Linear-style layer combining an activated base path with a B-spline path.

    The output is ``F.linear(base_activation(x), base_weight)`` plus a linear
    map applied to the B-spline basis expansion of ``x``.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
    ):
        """Create the grid buffer and the base/spline parameters, then initialise them.

        ``grid_range`` is a two-element sequence ``(low, high)``; it is only
        indexed, so lists and tuples are both accepted.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knots covering grid_range, extended by spline_order extra
        # knots on each side; the same knot row is replicated per input feature.
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0])
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.empty(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(torch.empty(out_features, in_features))

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise base weights, spline coefficients and the optional scaler."""
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Small uniform noise sampled at the grid_size + 1 interior knots,
            # converted to spline coefficients by a least-squares fit.
            noise = (
                (torch.rand(self.grid_size + 1, self.in_features, self.out_features) - 0.5)
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(self.grid.T[self.spline_order : -self.spline_order], noise)
            )
            if self.enable_standalone_scale_spline:
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """Evaluate the B-spline bases for ``x``.

        Args:
            x: Tensor of shape ``(batch, in_features)``.

        Returns:
            Tensor of shape ``(batch, in_features, grid_size + spline_order)``.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = self.grid
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the half-open knot interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Raise the order one step at a time; each step drops one basis.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:-k])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """Fit spline coefficients to sample points by least squares.

        Args:
            x: Tensor of shape ``(batch, in_features)``.
            y: Tensor of shape ``(batch, in_features, out_features)``.

        Returns:
            Tensor of shape ``(out_features, in_features, grid_size + spline_order)``.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        # One least-squares problem per input feature (batched over dim 0).
        a = self.b_splines(x).transpose(0, 1)
        b = y.transpose(0, 1)
        solution = torch.linalg.lstsq(a, b).solution
        result = solution.permute(2, 0, 1)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline coefficients, multiplied by ``spline_scaler`` when it is enabled."""
        if not self.enable_standalone_scale_spline:
            return self.spline_weight
        return self.spline_weight * self.spline_scaler.unsqueeze(-1)

    def forward(self, x: torch.Tensor):
        """Apply the layer to ``x`` whose last dimension is ``in_features``.

        Leading dimensions are flattened for the computation and restored on
        the output, whose last dimension is ``out_features``.
        """
        assert x.size(-1) == self.in_features

        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output
        return output.reshape(*original_shape[:-1], self.out_features)

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Re-fit the knot grid to the distribution of ``x`` and refit the coefficients.

        Args:
            x: Tensor of shape ``(batch, in_features)``.
            margin: Padding added on both sides of the observed range when
                building the uniform part of the new grid.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        batch = x.size(0)
        # Per-edge spline outputs on the current grid; these are the targets
        # the new coefficients are fitted to after the grid changes.
        splines = self.b_splines(x).permute(1, 0, 2)
        orig_coeff = self.scaled_spline_weight.permute(1, 2, 0)
        unreduced_spline_output = torch.bmm(splines, orig_coeff).permute(1, 0, 2)

        # Adaptive knots: evenly spaced order statistics of each feature.
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device)
        ]

        # Uniform knots spanning the observed range plus the margin.
        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(self.grid_size + 1, dtype=torch.float32, device=x.device).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        # Blend the two grids, then extend by spline_order knots on each side.
        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Return a weighted sum of an L1-style term and an entropy term.

        Both terms are computed from the mean absolute spline coefficient of
        each (output, input) edge. Edges whose coefficients are all zero
        contribute zero entropy instead of producing NaN.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        # Clamp with the smallest normal float so that 0 / 0 and 0 * log(0)
        # cannot produce NaN; values above that threshold are unchanged.
        tiny = torch.finfo(l1_fake.dtype).tiny
        p = l1_fake / regularization_loss_activation.clamp_min(tiny)
        regularization_loss_entropy = -torch.sum(p * p.clamp_min(tiny).log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )
171
+
172
+
173
class KAN(torch.nn.Module):
    """Sequential stack of ``KANLinear`` layers."""

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
    ):
        """Build one ``KANLinear`` per consecutive pair of sizes in ``layers_hidden``.

        All remaining keyword arguments are forwarded unchanged to every layer.
        """
        super().__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order
        self.layers = torch.nn.ModuleList()

        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        """Run ``x`` through every layer in order.

        When ``update_grid`` is true, each layer's grid is refitted to the
        input it is about to receive before that layer is applied.
        """
        for layer in self.layers:
            if update_grid:
                # update_grid only accepts 2-D input, whereas the layer's
                # forward accepts any leading dimensions; flatten them here so
                # both paths accept the same tensors.
                layer.update_grid(x.reshape(-1, x.size(-1)))
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Return the sum of every layer's ``regularization_loss``."""
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )
@@ -0,0 +1,348 @@
1
+ from __future__ import annotations
2
+
3
+ import itertools
4
+ import re
5
+ from dataclasses import fields, replace
6
+ from pathlib import Path
7
+ from typing import Any, Iterable
8
+
9
+ from ber_equalization_studio.config import (
10
+ DataConfig,
11
+ EvaluationConfig,
12
+ ExperimentConfig,
13
+ ModelConfig,
14
+ OutputConfig,
15
+ StudioConfig,
16
+ TrainingConfig,
17
+ )
18
+
19
+
20
# Section name -> dataclass type for that section. The names double as the
# attribute names read from the top-level config when updates are applied.
_CONFIG_SECTIONS = dict(
    data=DataConfig,
    model=ModelConfig,
    training=TrainingConfig,
    evaluation=EvaluationConfig,
    output=OutputConfig,
    experiment=ExperimentConfig,
)
28
+
29
+ _SHORTCUTS: dict[str, tuple[str, str]] = {
30
+ "data_dirs": ("data", "data_dirs"),
31
+ "context_k": ("data", "context_k"),
32
+ "max_files": ("data", "max_files"),
33
+ "max_test_files": ("data", "max_test_files"),
34
+ "train_portion": ("data", "train_portion"),
35
+ "val_portion_within_train": ("data", "val_portion_within_train"),
36
+ "randomize_file_split": ("data", "randomize_file_split"),
37
+ "split_seed": ("data", "split_seed"),
38
+ "input_dim": ("data", "input_dim"),
39
+ "power_normalize": ("data", "power_normalize"),
40
+ "model_name": ("model", "name"),
41
+ "model": ("model", "name"),
42
+ "hidden_dim": ("model", "hidden_dim"),
43
+ "dropout": ("model", "dropout"),
44
+ "mlp_layers": ("model", "mlp_layers"),
45
+ "fastkan_hidden_dim": ("model", "fastkan_hidden_dim"),
46
+ "fastkan_layers": ("model", "fastkan_layers"),
47
+ "fastkan_num_grids": ("model", "fastkan_num_grids"),
48
+ "fastkan_grids": ("model", "fastkan_num_grids"),
49
+ "efficient_kan_hidden_dim": ("model", "efficient_kan_hidden_dim"),
50
+ "efficient_kan_layers": ("model", "efficient_kan_layers"),
51
+ "efficient_kan_grid_size": ("model", "efficient_kan_grid_size"),
52
+ "efficient_kan_grid": ("model", "efficient_kan_grid_size"),
53
+ "epochs": ("training", "epochs"),
54
+ "lr": ("training", "learning_rate"),
55
+ "learning_rate": ("training", "learning_rate"),
56
+ "weight_decay": ("training", "weight_decay"),
57
+ "optimizer": ("training", "optimizer"),
58
+ "loss": ("training", "loss"),
59
+ "train_block_size": ("training", "train_block_size"),
60
+ "use_amp": ("training", "use_amp"),
61
+ "use_torch_compile": ("training", "use_torch_compile"),
62
+ "grad_clip_norm": ("training", "grad_clip_norm"),
63
+ "early_stopping": ("training", "early_stopping"),
64
+ "early_stopping_patience": ("training", "early_stopping_patience"),
65
+ "seed": ("training", "seed"),
66
+ "eval_batch_size": ("evaluation", "eval_batch_size"),
67
+ "eval_test_during_training": ("evaluation", "eval_test_during_training"),
68
+ "test_ber_every": ("evaluation", "test_ber_every"),
69
+ "compute_per_file_metrics": ("evaluation", "compute_per_file_metrics"),
70
+ "ber_scale_search": ("evaluation", "ber_scale_search"),
71
+ "out_dir": ("output", "out_dir"),
72
+ "experiment_name": ("output", "experiment_name"),
73
+ "save_checkpoint": ("output", "save_checkpoint"),
74
+ "save_history": ("output", "save_history"),
75
+ "save_plots": ("output", "save_plots"),
76
+ "plot_backend": ("output", "plot_backend"),
77
+ "models": ("experiment", "models"),
78
+ "tags": ("experiment", "tags"),
79
+ "notes": ("experiment", "notes"),
80
+ }
81
+
82
+
83
def _field_index() -> dict[str, tuple[str, str] | None]:
    """Map every section field name to its ``(section, field)`` location.

    A name declared by more than one section maps to ``None`` so that it
    cannot be resolved without naming the section explicitly.
    """
    owners: dict[str, tuple[str, str] | None] = {}
    for section, section_type in _CONFIG_SECTIONS.items():
        for spec in fields(section_type):
            name = spec.name
            # A second sighting of the same name marks it as ambiguous.
            owners[name] = None if name in owners else (section, name)
    return owners


_FIELD_INDEX = _field_index()
95
+
96
+
97
+ def _normalize_value(section: str, key: str, value: Any) -> Any:
98
+ if section == "data" and key == "data_dirs":
99
+ if isinstance(value, (str, Path)):
100
+ return [Path(value)]
101
+ return [Path(item) for item in value]
102
+ if section == "output" and key == "out_dir":
103
+ return Path(value)
104
+ if section == "experiment" and key == "models":
105
+ if isinstance(value, str):
106
+ return [item.strip() for item in value.split(",") if item.strip()]
107
+ return list(value)
108
+ return value
109
+
110
+
111
def _locate_update(key: str) -> tuple[str, str]:
    """Resolve an update keyword to a ``(section, field)`` pair.

    Resolution order: dotted ``section.field`` form (``root`` addresses the
    top-level config), the special ``device`` keyword, the shortcut table,
    and finally field names that belong to exactly one section.

    Raises:
        KeyError: for an unknown dotted section, or for an un-dotted key that
            is unknown or shared by several sections.
    """
    section, dot, field = key.partition(".")
    if dot:
        if section == "root" or section in _CONFIG_SECTIONS:
            return section, field
        raise KeyError(f"Unknown config section: {section}")

    if key == "device":
        return "root", "device"

    shortcut = _SHORTCUTS.get(key)
    if shortcut is not None:
        return shortcut

    inferred = _FIELD_INDEX.get(key)
    if inferred is not None:
        return inferred
    raise KeyError(f"Ambiguous or unknown config field: {key}. Use dotted form, for example 'training.{key}'.")
125
+
126
+
127
def _apply_updates(config: StudioConfig, updates: dict[str, Any]) -> StudioConfig:
    """Return a copy of ``config`` with ``updates`` applied.

    Keys are resolved with ``_locate_update``; entries whose value is ``None``
    are ignored. Section updates are applied in ``_CONFIG_SECTIONS`` order and
    top-level (``root``) updates last. ``config`` itself is not modified.
    """
    per_section: dict[str, dict[str, Any]] = {}
    top_level: dict[str, Any] = {}

    for key, value in updates.items():
        if value is None:
            continue
        section, field = _locate_update(key)
        if section == "root":
            top_level[field] = value
        else:
            per_section.setdefault(section, {})[field] = _normalize_value(section, field, value)

    result = config
    for section in _CONFIG_SECTIONS:
        changes = per_section.get(section)
        if changes:
            updated_section = replace(getattr(result, section), **changes)
            result = replace(result, **{section: updated_section})

    if top_level:
        result = replace(result, **top_level)
    return result
148
+
149
+
150
+ def _slug(value: str) -> str:
151
+ slug = re.sub(r"[^a-zA-Z0-9_.-]+", "_", value.strip())
152
+ return slug.strip("_") or "run"
153
+
154
+
155
+ def _grid_items(grid: dict[str, Iterable[Any]]) -> list[dict[str, Any]]:
156
+ keys = list(grid)
157
+ values = [list(grid[key]) for key in keys]
158
+ return [dict(zip(keys, combination, strict=True)) for combination in itertools.product(*values)]
159
+
160
+
161
+ def _pd():
162
+ import pandas as pd
163
+
164
+ return pd
165
+
166
+
167
class RunResult:
    """Notebook-friendly wrapper around a single experiment run."""

    def __init__(self, payload: dict[str, Any], config: StudioConfig):
        """Store the runner's ``payload`` dict and the config that produced it."""
        self.config = config
        self.payload = payload

    @property
    def run_dir(self) -> Path:
        """Path built from ``payload["run_dir"]``."""
        location = self.payload["run_dir"]
        return Path(location)

    @property
    def results_csv(self) -> Path:
        """Path built from ``payload["results_csv"]``."""
        location = self.payload["results_csv"]
        return Path(location)

    @property
    def results(self):
        """Result table: loaded from the CSV if it exists, else built from the payload."""
        from ber_equalization_studio.results import flatten_metrics, load_result_table

        if not self.results_csv.exists():
            rows = self.payload.get("results", [])
            return _pd().DataFrame([flatten_metrics(row) for row in rows])
        return load_result_table(self.results_csv)

    @property
    def histories(self) -> dict[str, dict[str, Any]]:
        """``payload["histories"]``, or an empty dict when absent."""
        return self.payload.get("histories", {})

    def best(self, by: str = "equalized_ber", ascending: bool | None = None) -> pd.Series:
        """Return the top row of the results after sorting by column ``by``.

        When ``ascending`` is ``None`` the direction is inferred: descending
        for the listed higher-is-better columns, ascending otherwise.

        Raises:
            KeyError: if ``by`` is not a column of the results.
        """
        table = self.results
        if by not in table.columns:
            raise KeyError(f"Column not found in results: {by}")
        if ascending is None:
            higher_is_better = {"improvement_rel", "improvement_db", "accuracy", "samples_per_sec"}
            ascending = by not in higher_is_better
        ranked = table.sort_values(by, ascending=ascending)
        return ranked.iloc[0]

    def compare(self, sort_by: str = "equalized_ber") -> pd.DataFrame:
        """Delegate to ``compare_results`` on this run's results CSV."""
        from ber_equalization_studio.results import compare_results

        return compare_results(self.results_csv, sort_by=sort_by)

    def history(self, model: str | None = None) -> dict[str, Any]:
        """Return the history for ``model``, or the only history if there is exactly one.

        Raises:
            ValueError: if ``model`` is omitted and the run does not hold
                exactly one history.
            KeyError: if ``model`` is given but has no history.
        """
        if model is not None:
            return self.histories[model]
        if len(self.histories) != 1:
            raise ValueError("Pass model=... when a run contains more than one model.")
        (only,) = self.histories.values()
        return only

    def plot_history(self, model: str | None = None, metric: str = "ber"):
        """Build a history figure for ``model`` (or the run's only model)."""
        from ber_equalization_studio.visualization import history_figure

        selected = self.history(model)
        label = model if model else next(iter(self.histories))
        return history_figure(selected, label, metric=metric)

    def plot_comparison(
        self,
        x: str = "model_type",
        y: str = "equalized_ber",
        color: str | None = None,
        kind: str = "bar",
    ):
        """Build a comparison figure from this run's result table."""
        from ber_equalization_studio.visualization import comparison_figure

        table = self.results
        return comparison_figure(table, x=x, y=y, color=color, kind=kind)

    def _repr_html_(self) -> str:
        """Render as the result table's HTML representation."""
        table = self.results
        return table._repr_html_()
234
+
235
+
236
class StudyResult:
    """Result object for a sequence of runs, usually produced by Studio.sweep()."""

    def __init__(self, runs: list[RunResult], parameters: list[dict[str, Any]]):
        """Store the runs together with the parameter dict used for each one.

        ``parameters[i]`` is looked up for ``runs[i]`` when results are built.
        """
        self.runs = runs
        self.parameters = parameters

    @property
    def results(self):
        """Concatenate every run's result table into one DataFrame.

        Each run's rows gain a ``run_dir`` column and one column per sweep
        parameter. Returns an empty DataFrame when there are no runs.
        """
        pd = _pd()
        frames: list[pd.DataFrame] = []
        for index, run in enumerate(self.runs):
            df = run.results.copy()
            df["run_dir"] = str(run.run_dir)
            for key, value in self.parameters[index].items():
                if pd.api.types.is_list_like(value):
                    # A list-like parameter (e.g. a list of model names) must
                    # be stored whole in every row; plain assignment would
                    # spread its elements across rows or fail on a length
                    # mismatch.
                    df[key] = pd.Series([value] * len(df), index=df.index, dtype=object)
                else:
                    df[key] = value
            frames.append(df)
        if not frames:
            return pd.DataFrame()
        return pd.concat(frames, ignore_index=True)

    def best(self, by: str = "equalized_ber", ascending: bool | None = None) -> pd.Series:
        """Return the top row of the combined results after sorting by ``by``.

        When ``ascending`` is ``None`` the direction is inferred: descending
        for the listed higher-is-better columns, ascending otherwise.

        Raises:
            KeyError: if ``by`` is not a column of the combined results.
        """
        df = self.results
        if by not in df.columns:
            raise KeyError(f"Column not found in results: {by}")
        if ascending is None:
            ascending = by not in {"improvement_rel", "improvement_db", "accuracy", "samples_per_sec"}
        return df.sort_values(by, ascending=ascending).iloc[0]

    def plot_tradeoff(
        self,
        x: str = "trainable_params",
        y: str = "equalized_ber",
        color: str | None = "model_type",
    ):
        """Build a scatter comparison figure from the combined results."""
        from ber_equalization_studio.visualization import comparison_figure

        return comparison_figure(self.results, x=x, y=y, color=color, kind="scatter")

    def _repr_html_(self) -> str:
        """Render as the combined result table's HTML representation."""
        return self.results._repr_html_()
277
+
278
+
279
class Studio:
    """High-level research interface for notebooks and scripts."""

    def __init__(
        self,
        base_config: StudioConfig | None = None,
        legacy_path: str | Path | None = None,
        **defaults: Any,
    ):
        """Create a studio whose base config is ``base_config`` plus ``defaults``.

        ``defaults`` use the same keyword forms as ``config()``. ``legacy_path``
        is stored as a ``Path`` (or ``None``) and handed to the experiment
        runner on every run.
        """
        self.base_config = _apply_updates(base_config or StudioConfig(), defaults)
        self.legacy_path = Path(legacy_path) if legacy_path is not None else None

    def config(self, **updates: Any) -> StudioConfig:
        """Return the base config with ``updates`` applied; the base is not modified."""
        return _apply_updates(self.base_config, updates)

    def models(self) -> pd.DataFrame:
        """Return a DataFrame with one ``model``/``description`` row per available model."""
        from ber_equalization_studio.models import available_models

        rows = [{"model": name, "description": note} for name, note in available_models().items()]
        return _pd().DataFrame(rows)

    def run(
        self,
        name: str = "experiment",
        models: str | Iterable[str] | None = None,
        **updates: Any,
    ) -> RunResult:
        """Run one experiment and wrap its payload in a ``RunResult``.

        ``name`` becomes the ``experiment_name`` update; ``models`` (when not
        ``None``) becomes the ``models`` update. Remaining keywords are config
        updates.
        """
        if models is not None:
            updates["models"] = models
        updates["experiment_name"] = name
        config = self.config(**updates)
        from ber_equalization_studio.experiment import ExperimentRunner

        payload = ExperimentRunner(config, legacy_path=self.legacy_path).run()
        return RunResult(payload, config)

    def sweep(
        self,
        name: str,
        grid: dict[str, Iterable[Any]],
        models: str | Iterable[str] | None = None,
        **updates: Any,
    ) -> StudyResult:
        """Run one experiment per combination in ``grid``.

        Each run is named from ``name``, a 1-based three-digit index and the
        slugged parameter values. Grid parameters override same-named entries
        in ``updates``; a ``"models"`` grid entry likewise overrides the
        ``models`` argument for that run.
        """
        runs: list[RunResult] = []
        parameters = _grid_items(grid)
        for index, params in enumerate(parameters, start=1):
            parts = [f"{_slug(key)}-{_slug(str(value))}" for key, value in params.items()]
            run_name = _slug("_".join([name, f"{index:03d}", *parts]))
            run_updates = {**updates, **params}
            # "models" is also an explicit parameter of run(); leaving it in
            # run_updates would raise TypeError for a duplicate keyword.
            run_models = run_updates.pop("models", models)
            runs.append(self.run(name=run_name, models=run_models, **run_updates))
        return StudyResult(runs, parameters)
330
+
331
+
332
def run(
    name: str = "experiment",
    models: str | Iterable[str] | None = None,
    legacy_path: str | Path | None = None,
    **updates: Any,
) -> RunResult:
    """Run a single experiment on a freshly created ``Studio``."""
    studio = Studio(legacy_path=legacy_path)
    return studio.run(name=name, models=models, **updates)
339
+
340
+
341
def sweep(
    name: str,
    grid: dict[str, Iterable[Any]],
    models: str | Iterable[str] | None = None,
    legacy_path: str | Path | None = None,
    **updates: Any,
) -> StudyResult:
    """Run a parameter sweep on a freshly created ``Studio``."""
    studio = Studio(legacy_path=legacy_path)
    return studio.sweep(name=name, grid=grid, models=models, **updates)
@@ -0,0 +1,92 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+ from ber_equalization_studio.config import DataConfig, ExperimentConfig, ModelConfig, OutputConfig, StudioConfig, TrainingConfig
7
+
8
+
9
+ def _split_models(value: str) -> list[str]:
10
+ return [item.strip() for item in value.split(",") if item.strip()]
11
+
12
+
13
def build_parser() -> argparse.ArgumentParser:
    """Build the ``ber-studio`` parser with its ``run``, ``models`` and ``compare`` commands."""
    root = argparse.ArgumentParser(prog="ber-studio", description="BER equalization experiment studio")
    commands = root.add_subparsers(dest="command", required=True)

    # Options of the "run" command, registered in declaration order.
    run_options = (
        ("--models", {"default": "complex_fastkan", "help": "Comma-separated model names"}),
        (
            "--data-dir",
            {"action": "append", "type": Path, "default": None, "help": "Directory with Symbols_1m_1ch_PR_*.csv"},
        ),
        ("--out-dir", {"type": Path, "default": Path("studio_outputs")}),
        ("--name", {"default": "experiment"}),
        ("--legacy-path", {"type": Path, "default": None}),
        ("--epochs", {"type": int, "default": 250}),
        ("--lr", {"type": float, "default": 1e-3}),
        ("--context-k", {"type": int, "default": 32}),
        ("--max-files", {"type": int, "default": 64}),
        ("--max-test-files", {"type": int, "default": None}),
        ("--device", {"default": None}),
        ("--fastkan-hidden", {"type": int, "default": 96}),
        ("--fastkan-grids", {"type": int, "default": 8}),
        ("--efficient-kan-hidden", {"type": int, "default": 128}),
        ("--efficient-kan-grid", {"type": int, "default": 8}),
        ("--no-per-file", {"action": "store_true"}),
        ("--no-plots", {"action": "store_true"}),
    )
    run_cmd = commands.add_parser("run", help="Train one or more equalizer models")
    for flag, spec in run_options:
        run_cmd.add_argument(flag, **spec)

    commands.add_parser("models", help="List available model names")

    compare_cmd = commands.add_parser("compare", help="Compare result CSV files below a run directory")
    compare_cmd.add_argument("path", type=Path)
    compare_cmd.add_argument("--sort-by", default="equalized_ber")

    return root
43
+
44
+
45
def main(argv: list[str] | None = None) -> None:
    """Entry point for the ``ber-studio`` command line.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    The ``models`` command prints the available model names, ``compare``
    prints a comparison table, and ``run`` builds a config from the parsed
    options, executes the experiment and prints where the results were saved.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    if args.command == "models":
        from ber_equalization_studio.models import available_models

        for name, note in available_models().items():
            print(f"{name:32s} {note}")
        return

    if args.command == "compare":
        from ber_equalization_studio.results import compare_results

        df = compare_results(args.path, sort_by=args.sort_by)
        print(df.to_string(index=False))
        return

    # Only the "run" command remains from here on.
    data_dirs = args.data_dir if args.data_dir else DataConfig().data_dirs
    config = StudioConfig(
        data=DataConfig(
            data_dirs=data_dirs,
            max_files=args.max_files,
            context_k=args.context_k,
            max_test_files=args.max_test_files,
        ),
        model=ModelConfig(
            fastkan_hidden_dim=args.fastkan_hidden,
            fastkan_num_grids=args.fastkan_grids,
            efficient_kan_hidden_dim=args.efficient_kan_hidden,
            efficient_kan_grid_size=args.efficient_kan_grid,
        ),
        training=TrainingConfig(epochs=args.epochs, learning_rate=args.lr),
        output=OutputConfig(out_dir=args.out_dir, experiment_name=args.name, save_plots=not args.no_plots),
        experiment=ExperimentConfig(models=_split_models(args.models)),
        device=args.device or StudioConfig().device,
    )
    if args.no_per_file:
        # Derive a new config instead of assigning to the nested attribute:
        # this matches how the notebook API applies updates and does not rely
        # on the config dataclasses being mutable.
        from dataclasses import replace

        config = replace(
            config,
            evaluation=replace(config.evaluation, compute_per_file_metrics=False),
        )

    from ber_equalization_studio.experiment import ExperimentRunner

    result = ExperimentRunner(config, legacy_path=args.legacy_path).run()
    print(f"Saved results: {result['results_csv']}")
89
+
90
+
91
+ if __name__ == "__main__":
92
+ main()