ber-equalization-studio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ber_equalization_studio/__init__.py +58 -0
- ber_equalization_studio/_legacy_backend/__init__.py +1 -0
- ber_equalization_studio/_legacy_backend/ber_equalization.py +3700 -0
- ber_equalization_studio/_legacy_backend/efficient_kan/__init__.py +3 -0
- ber_equalization_studio/_legacy_backend/efficient_kan/kan.py +218 -0
- ber_equalization_studio/api.py +348 -0
- ber_equalization_studio/cli.py +92 -0
- ber_equalization_studio/config.py +168 -0
- ber_equalization_studio/data.py +31 -0
- ber_equalization_studio/experiment.py +92 -0
- ber_equalization_studio/legacy.py +149 -0
- ber_equalization_studio/models.py +86 -0
- ber_equalization_studio/results.py +74 -0
- ber_equalization_studio/visualization.py +186 -0
- ber_equalization_studio-0.1.0.dist-info/METADATA +266 -0
- ber_equalization_studio-0.1.0.dist-info/RECORD +19 -0
- ber_equalization_studio-0.1.0.dist-info/WHEEL +5 -0
- ber_equalization_studio-0.1.0.dist-info/entry_points.txt +2 -0
- ber_equalization_studio-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class KANLinear(torch.nn.Module):
|
|
8
|
+
def __init__(
|
|
9
|
+
self,
|
|
10
|
+
in_features,
|
|
11
|
+
out_features,
|
|
12
|
+
grid_size=5,
|
|
13
|
+
spline_order=3,
|
|
14
|
+
scale_noise=0.1,
|
|
15
|
+
scale_base=1.0,
|
|
16
|
+
scale_spline=1.0,
|
|
17
|
+
enable_standalone_scale_spline=True,
|
|
18
|
+
base_activation=torch.nn.SiLU,
|
|
19
|
+
grid_eps=0.02,
|
|
20
|
+
grid_range=[-1, 1],
|
|
21
|
+
):
|
|
22
|
+
super().__init__()
|
|
23
|
+
self.in_features = in_features
|
|
24
|
+
self.out_features = out_features
|
|
25
|
+
self.grid_size = grid_size
|
|
26
|
+
self.spline_order = spline_order
|
|
27
|
+
|
|
28
|
+
h = (grid_range[1] - grid_range[0]) / grid_size
|
|
29
|
+
grid = (
|
|
30
|
+
(torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0])
|
|
31
|
+
.expand(in_features, -1)
|
|
32
|
+
.contiguous()
|
|
33
|
+
)
|
|
34
|
+
self.register_buffer("grid", grid)
|
|
35
|
+
|
|
36
|
+
self.base_weight = torch.nn.Parameter(torch.empty(out_features, in_features))
|
|
37
|
+
self.spline_weight = torch.nn.Parameter(
|
|
38
|
+
torch.empty(out_features, in_features, grid_size + spline_order)
|
|
39
|
+
)
|
|
40
|
+
if enable_standalone_scale_spline:
|
|
41
|
+
self.spline_scaler = torch.nn.Parameter(torch.empty(out_features, in_features))
|
|
42
|
+
|
|
43
|
+
self.scale_noise = scale_noise
|
|
44
|
+
self.scale_base = scale_base
|
|
45
|
+
self.scale_spline = scale_spline
|
|
46
|
+
self.enable_standalone_scale_spline = enable_standalone_scale_spline
|
|
47
|
+
self.base_activation = base_activation()
|
|
48
|
+
self.grid_eps = grid_eps
|
|
49
|
+
|
|
50
|
+
self.reset_parameters()
|
|
51
|
+
|
|
52
|
+
def reset_parameters(self):
|
|
53
|
+
torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
|
|
54
|
+
with torch.no_grad():
|
|
55
|
+
noise = (
|
|
56
|
+
(torch.rand(self.grid_size + 1, self.in_features, self.out_features) - 0.5)
|
|
57
|
+
* self.scale_noise
|
|
58
|
+
/ self.grid_size
|
|
59
|
+
)
|
|
60
|
+
self.spline_weight.data.copy_(
|
|
61
|
+
(self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
|
|
62
|
+
* self.curve2coeff(self.grid.T[self.spline_order : -self.spline_order], noise)
|
|
63
|
+
)
|
|
64
|
+
if self.enable_standalone_scale_spline:
|
|
65
|
+
torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
|
|
66
|
+
|
|
67
|
+
def b_splines(self, x: torch.Tensor):
|
|
68
|
+
assert x.dim() == 2 and x.size(1) == self.in_features
|
|
69
|
+
|
|
70
|
+
grid: torch.Tensor = self.grid
|
|
71
|
+
x = x.unsqueeze(-1)
|
|
72
|
+
bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
|
|
73
|
+
for k in range(1, self.spline_order + 1):
|
|
74
|
+
bases = (
|
|
75
|
+
(x - grid[:, : -(k + 1)])
|
|
76
|
+
/ (grid[:, k:-1] - grid[:, : -(k + 1)])
|
|
77
|
+
* bases[:, :, :-1]
|
|
78
|
+
) + (
|
|
79
|
+
(grid[:, k + 1 :] - x)
|
|
80
|
+
/ (grid[:, k + 1 :] - grid[:, 1:-k])
|
|
81
|
+
* bases[:, :, 1:]
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
assert bases.size() == (
|
|
85
|
+
x.size(0),
|
|
86
|
+
self.in_features,
|
|
87
|
+
self.grid_size + self.spline_order,
|
|
88
|
+
)
|
|
89
|
+
return bases.contiguous()
|
|
90
|
+
|
|
91
|
+
def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
|
|
92
|
+
assert x.dim() == 2 and x.size(1) == self.in_features
|
|
93
|
+
assert y.size() == (x.size(0), self.in_features, self.out_features)
|
|
94
|
+
|
|
95
|
+
a = self.b_splines(x).transpose(0, 1)
|
|
96
|
+
b = y.transpose(0, 1)
|
|
97
|
+
solution = torch.linalg.lstsq(a, b).solution
|
|
98
|
+
result = solution.permute(2, 0, 1)
|
|
99
|
+
|
|
100
|
+
assert result.size() == (
|
|
101
|
+
self.out_features,
|
|
102
|
+
self.in_features,
|
|
103
|
+
self.grid_size + self.spline_order,
|
|
104
|
+
)
|
|
105
|
+
return result.contiguous()
|
|
106
|
+
|
|
107
|
+
@property
|
|
108
|
+
def scaled_spline_weight(self):
|
|
109
|
+
if not self.enable_standalone_scale_spline:
|
|
110
|
+
return self.spline_weight
|
|
111
|
+
return self.spline_weight * self.spline_scaler.unsqueeze(-1)
|
|
112
|
+
|
|
113
|
+
def forward(self, x: torch.Tensor):
|
|
114
|
+
assert x.size(-1) == self.in_features
|
|
115
|
+
|
|
116
|
+
original_shape = x.shape
|
|
117
|
+
x = x.reshape(-1, self.in_features)
|
|
118
|
+
|
|
119
|
+
base_output = F.linear(self.base_activation(x), self.base_weight)
|
|
120
|
+
spline_output = F.linear(
|
|
121
|
+
self.b_splines(x).view(x.size(0), -1),
|
|
122
|
+
self.scaled_spline_weight.view(self.out_features, -1),
|
|
123
|
+
)
|
|
124
|
+
output = base_output + spline_output
|
|
125
|
+
return output.reshape(*original_shape[:-1], self.out_features)
|
|
126
|
+
|
|
127
|
+
@torch.no_grad()
|
|
128
|
+
def update_grid(self, x: torch.Tensor, margin=0.01):
|
|
129
|
+
assert x.dim() == 2 and x.size(1) == self.in_features
|
|
130
|
+
|
|
131
|
+
batch = x.size(0)
|
|
132
|
+
splines = self.b_splines(x).permute(1, 0, 2)
|
|
133
|
+
orig_coeff = self.scaled_spline_weight.permute(1, 2, 0)
|
|
134
|
+
unreduced_spline_output = torch.bmm(splines, orig_coeff).permute(1, 0, 2)
|
|
135
|
+
|
|
136
|
+
x_sorted = torch.sort(x, dim=0)[0]
|
|
137
|
+
grid_adaptive = x_sorted[
|
|
138
|
+
torch.linspace(0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device)
|
|
139
|
+
]
|
|
140
|
+
|
|
141
|
+
uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
|
|
142
|
+
grid_uniform = (
|
|
143
|
+
torch.arange(self.grid_size + 1, dtype=torch.float32, device=x.device).unsqueeze(1)
|
|
144
|
+
* uniform_step
|
|
145
|
+
+ x_sorted[0]
|
|
146
|
+
- margin
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
|
|
150
|
+
grid = torch.concatenate(
|
|
151
|
+
[
|
|
152
|
+
grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
|
|
153
|
+
grid,
|
|
154
|
+
grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
|
|
155
|
+
],
|
|
156
|
+
dim=0,
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
self.grid.copy_(grid.T)
|
|
160
|
+
self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
|
|
161
|
+
|
|
162
|
+
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
|
|
163
|
+
l1_fake = self.spline_weight.abs().mean(-1)
|
|
164
|
+
regularization_loss_activation = l1_fake.sum()
|
|
165
|
+
p = l1_fake / regularization_loss_activation
|
|
166
|
+
regularization_loss_entropy = -torch.sum(p * p.log())
|
|
167
|
+
return (
|
|
168
|
+
regularize_activation * regularization_loss_activation
|
|
169
|
+
+ regularize_entropy * regularization_loss_entropy
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class KAN(torch.nn.Module):
    """Feed-forward stack of ``KANLinear`` layers.

    Args:
        layers_hidden: Layer widths ``[in, hidden_1, ..., out]``; one ``KANLinear`` is
            created for each consecutive pair.
        grid_size, spline_order, scale_noise, scale_base, scale_spline, base_activation,
        grid_eps, grid_range: Passed unchanged to every ``KANLinear``.
        enable_standalone_scale_spline: Passed unchanged to every ``KANLinear``. Defaults
            to ``True``, matching ``KANLinear``'s own default.
    """

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
        enable_standalone_scale_spline=True,
    ):
        super().__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order
        self.layers = torch.nn.ModuleList()

        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    enable_standalone_scale_spline=enable_standalone_scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        """Apply the layers in order.

        Args:
            x: Input whose last dimension is ``layers_hidden[0]``; leading dimensions
                are preserved.
            update_grid: If true, each layer first refits its grid to the activations it
                is about to receive.
        """
        for layer in self.layers:
            if update_grid:
                # KANLinear.update_grid only accepts 2-D input, while forward accepts
                # any leading dimensions, so flatten them for the refit only.
                layer.update_grid(x.reshape(-1, layer.in_features))
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Sum of every layer's ``regularization_loss`` (integer 0 if there are no layers)."""
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )
|
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import itertools
|
|
4
|
+
import re
|
|
5
|
+
from dataclasses import fields, replace
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Iterable
|
|
8
|
+
|
|
9
|
+
from ber_equalization_studio.config import (
|
|
10
|
+
DataConfig,
|
|
11
|
+
EvaluationConfig,
|
|
12
|
+
ExperimentConfig,
|
|
13
|
+
ModelConfig,
|
|
14
|
+
OutputConfig,
|
|
15
|
+
StudioConfig,
|
|
16
|
+
TrainingConfig,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Section name -> dataclass type. Each name is also the attribute on StudioConfig that
# holds that section: _apply_updates reads it with getattr() and rebuilds it with replace().
_CONFIG_SECTIONS = {
    "data": DataConfig,
    "model": ModelConfig,
    "training": TrainingConfig,
    "evaluation": EvaluationConfig,
    "output": OutputConfig,
    "experiment": ExperimentConfig,
}

# Flat override keys -> (section, field). _locate_update checks these before inferring a
# section from field names. Several entries are aliases: "lr" -> learning_rate,
# "model"/"model_name" -> model.name, "fastkan_grids" -> fastkan_num_grids,
# "efficient_kan_grid" -> efficient_kan_grid_size.
_SHORTCUTS: dict[str, tuple[str, str]] = {
    "data_dirs": ("data", "data_dirs"),
    "context_k": ("data", "context_k"),
    "max_files": ("data", "max_files"),
    "max_test_files": ("data", "max_test_files"),
    "train_portion": ("data", "train_portion"),
    "val_portion_within_train": ("data", "val_portion_within_train"),
    "randomize_file_split": ("data", "randomize_file_split"),
    "split_seed": ("data", "split_seed"),
    "input_dim": ("data", "input_dim"),
    "power_normalize": ("data", "power_normalize"),
    "model_name": ("model", "name"),
    "model": ("model", "name"),
    "hidden_dim": ("model", "hidden_dim"),
    "dropout": ("model", "dropout"),
    "mlp_layers": ("model", "mlp_layers"),
    "fastkan_hidden_dim": ("model", "fastkan_hidden_dim"),
    "fastkan_layers": ("model", "fastkan_layers"),
    "fastkan_num_grids": ("model", "fastkan_num_grids"),
    "fastkan_grids": ("model", "fastkan_num_grids"),
    "efficient_kan_hidden_dim": ("model", "efficient_kan_hidden_dim"),
    "efficient_kan_layers": ("model", "efficient_kan_layers"),
    "efficient_kan_grid_size": ("model", "efficient_kan_grid_size"),
    "efficient_kan_grid": ("model", "efficient_kan_grid_size"),
    "epochs": ("training", "epochs"),
    "lr": ("training", "learning_rate"),
    "learning_rate": ("training", "learning_rate"),
    "weight_decay": ("training", "weight_decay"),
    "optimizer": ("training", "optimizer"),
    "loss": ("training", "loss"),
    "train_block_size": ("training", "train_block_size"),
    "use_amp": ("training", "use_amp"),
    "use_torch_compile": ("training", "use_torch_compile"),
    "grad_clip_norm": ("training", "grad_clip_norm"),
    "early_stopping": ("training", "early_stopping"),
    "early_stopping_patience": ("training", "early_stopping_patience"),
    "seed": ("training", "seed"),
    "eval_batch_size": ("evaluation", "eval_batch_size"),
    "eval_test_during_training": ("evaluation", "eval_test_during_training"),
    "test_ber_every": ("evaluation", "test_ber_every"),
    "compute_per_file_metrics": ("evaluation", "compute_per_file_metrics"),
    "ber_scale_search": ("evaluation", "ber_scale_search"),
    "out_dir": ("output", "out_dir"),
    "experiment_name": ("output", "experiment_name"),
    "save_checkpoint": ("output", "save_checkpoint"),
    "save_history": ("output", "save_history"),
    "save_plots": ("output", "save_plots"),
    "plot_backend": ("output", "plot_backend"),
    "models": ("experiment", "models"),
    "tags": ("experiment", "tags"),
    "notes": ("experiment", "notes"),
}
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _field_index() -> dict[str, tuple[str, str] | None]:
    """Map every config field name to the one section that defines it.

    A field name that occurs in more than one section dataclass maps to ``None``. That
    signals it is ambiguous and must be given in dotted ``section.field`` form.
    """
    lookup: dict[str, tuple[str, str] | None] = {}
    for section_name, section_cls in _CONFIG_SECTIONS.items():
        for name in (spec.name for spec in fields(section_cls)):
            lookup[name] = None if name in lookup else (section_name, name)
    return lookup
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# Built once at import time from the section dataclasses: field name -> (section, field),
# or None when the same name is defined by more than one section.
_FIELD_INDEX = _field_index()
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _normalize_value(section: str, key: str, value: Any) -> Any:
|
|
98
|
+
if section == "data" and key == "data_dirs":
|
|
99
|
+
if isinstance(value, (str, Path)):
|
|
100
|
+
return [Path(value)]
|
|
101
|
+
return [Path(item) for item in value]
|
|
102
|
+
if section == "output" and key == "out_dir":
|
|
103
|
+
return Path(value)
|
|
104
|
+
if section == "experiment" and key == "models":
|
|
105
|
+
if isinstance(value, str):
|
|
106
|
+
return [item.strip() for item in value.split(",") if item.strip()]
|
|
107
|
+
return list(value)
|
|
108
|
+
return value
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _locate_update(key: str) -> tuple[str, str]:
    """Resolve an override key to a ``(section, field)`` pair.

    Resolution order:

    1. Dotted keys (``"training.epochs"``) are split on the first ``.``. The section must
       be in ``_CONFIG_SECTIONS`` or be ``"root"``; the field part is not checked here.
    2. ``"device"`` targets the root-level ``device`` field.
    3. Aliases listed in ``_SHORTCUTS``.
    4. A field name defined by exactly one section dataclass (``_FIELD_INDEX``).

    Raises:
        KeyError: If the dotted section is unknown, if the name matches no section
            field, or if the name is defined by several sections (the message lists them).
    """
    if "." in key:
        section, field = key.split(".", 1)
        if section not in _CONFIG_SECTIONS and section != "root":
            raise KeyError(f"Unknown config section: {section}")
        return section, field
    if key == "device":
        return "root", "device"
    if key in _SHORTCUTS:
        return _SHORTCUTS[key]
    if key not in _FIELD_INDEX:
        # Not ambiguous, just unknown: say so instead of suggesting an arbitrary section.
        sections = ", ".join([*_CONFIG_SECTIONS, "root"])
        raise KeyError(f"Unknown config field: {key}. Dotted form accepts sections: {sections}.")
    inferred = _FIELD_INDEX[key]
    if inferred is None:
        # Name the sections that actually define the field so the hint is usable.
        owners = [
            section
            for section, section_type in _CONFIG_SECTIONS.items()
            if any(spec.name == key for spec in fields(section_type))
        ]
        raise KeyError(
            f"Ambiguous config field: {key} is defined in {', '.join(owners)}. "
            f"Use dotted form, for example '{owners[0]}.{key}'."
        )
    return inferred
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _apply_updates(config: StudioConfig, updates: dict[str, Any]) -> StudioConfig:
    """Return a copy of ``config`` with ``updates`` applied.

    Keys are resolved with ``_locate_update`` (dotted ``section.field``, ``device``,
    shortcuts, or unambiguous field names). Entries whose value is ``None`` are skipped,
    so this path cannot set a field to ``None``. Section values go through
    ``_normalize_value``; root-level values are applied as given, after all sections.
    Nothing is mutated: new dataclass instances are produced with ``replace``.
    """
    pending: dict[str, dict[str, Any]] = {name: {} for name in _CONFIG_SECTIONS}
    top_level: dict[str, Any] = {}
    for raw_key, raw_value in updates.items():
        if raw_value is None:
            continue
        target_section, target_field = _locate_update(raw_key)
        if target_section == "root":
            top_level[target_field] = raw_value
        else:
            pending[target_section][target_field] = _normalize_value(target_section, target_field, raw_value)

    result = config
    for target_section, overrides in pending.items():
        if overrides:
            rebuilt = replace(getattr(result, target_section), **overrides)
            result = replace(result, **{target_section: rebuilt})
    return replace(result, **top_level) if top_level else result
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _slug(value: str) -> str:
|
|
151
|
+
slug = re.sub(r"[^a-zA-Z0-9_.-]+", "_", value.strip())
|
|
152
|
+
return slug.strip("_") or "run"
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _grid_items(grid: dict[str, Iterable[Any]]) -> list[dict[str, Any]]:
|
|
156
|
+
keys = list(grid)
|
|
157
|
+
values = [list(grid[key]) for key in keys]
|
|
158
|
+
return [dict(zip(keys, combination, strict=True)) for combination in itertools.product(*values)]
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _pd():
|
|
162
|
+
import pandas as pd
|
|
163
|
+
|
|
164
|
+
return pd
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class RunResult:
    """Notebook-friendly wrapper around a single experiment run.

    ``payload`` is the dict returned by ``ExperimentRunner.run()``. The keys read here are
    ``run_dir``, ``results_csv``, ``results`` (metric rows) and ``histories``
    (model name -> history dict).
    """

    def __init__(self, payload: dict[str, Any], config: StudioConfig):
        self.payload = payload
        self.config = config

    @property
    def run_dir(self) -> Path:
        """Directory recorded for this run in the payload."""
        return Path(self.payload["run_dir"])

    @property
    def results_csv(self) -> Path:
        """Path of the results CSV recorded in the payload."""
        return Path(self.payload["results_csv"])

    @property
    def results(self):
        """Metric table as a DataFrame.

        Loaded from ``results_csv`` when that file exists; otherwise built from the
        in-memory ``payload["results"]`` rows via ``flatten_metrics``.
        """
        from ber_equalization_studio.results import flatten_metrics, load_result_table

        if self.results_csv.exists():
            return load_result_table(self.results_csv)
        return _pd().DataFrame([flatten_metrics(row) for row in self.payload.get("results", [])])

    @property
    def histories(self) -> dict[str, dict[str, Any]]:
        """Training histories keyed by model name (empty dict if absent)."""
        return self.payload.get("histories", {})

    def best(self, by: str = "equalized_ber", ascending: bool | None = None) -> pd.Series:
        """Return the top-ranked result row by column ``by``.

        When ``ascending`` is ``None``, larger values win for ``improvement_rel``,
        ``improvement_db``, ``accuracy`` and ``samples_per_sec``; smaller values win for
        every other column.

        Raises:
            KeyError: If ``by`` is not a column of :attr:`results`.
        """
        df = self.results
        if by not in df.columns:
            raise KeyError(f"Column not found in results: {by}")
        if ascending is None:
            ascending = by not in {"improvement_rel", "improvement_db", "accuracy", "samples_per_sec"}
        return df.sort_values(by, ascending=ascending).iloc[0]

    def compare(self, sort_by: str = "equalized_ber") -> pd.DataFrame:
        """Comparison table produced by ``compare_results`` for :attr:`results_csv`."""
        from ber_equalization_studio.results import compare_results

        return compare_results(self.results_csv, sort_by=sort_by)

    def history(self, model: str | None = None) -> dict[str, Any]:
        """Return the training history of ``model``.

        ``model`` may be omitted only when the run holds exactly one history.

        Raises:
            ValueError: If ``model`` is omitted and the run has zero or several histories.
            KeyError: If ``model`` names no recorded history.
        """
        if model is None:
            histories = self.histories
            if not histories:
                # Distinct message: "more than one model" would be misleading here.
                raise ValueError("This run has no training histories.")
            if len(histories) > 1:
                raise ValueError("Pass model=... when a run contains more than one model.")
            return next(iter(histories.values()))
        return self.histories[model]

    def plot_history(self, model: str | None = None, metric: str = "ber"):
        """Figure of ``metric`` over training for ``model`` (see :meth:`history`)."""
        from ber_equalization_studio.visualization import history_figure

        history = self.history(model)
        model_name = model or next(iter(self.histories))
        return history_figure(history, model_name, metric=metric)

    def plot_comparison(
        self,
        x: str = "model_type",
        y: str = "equalized_ber",
        color: str | None = None,
        kind: str = "bar",
    ):
        """Figure comparing columns ``x`` and ``y`` of :attr:`results`."""
        from ber_equalization_studio.visualization import comparison_figure

        return comparison_figure(self.results, x=x, y=y, color=color, kind=kind)

    def _repr_html_(self) -> str:
        # Jupyter rich display: render the results table.
        return self.results._repr_html_()
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class StudyResult:
    """Result object for a sequence of runs, usually produced by Studio.sweep().

    ``parameters[i]`` holds the grid values that produced ``runs[i]``.
    """

    def __init__(self, runs: list[RunResult], parameters: list[dict[str, Any]]):
        self.runs = runs
        self.parameters = parameters

    @property
    def results(self):
        """All runs' result tables concatenated into one DataFrame.

        Each row also carries a ``run_dir`` column and one column per grid parameter of
        the run it came from. List-like parameter values (lists, tuples, dicts, ...) are
        stored whole in every row.
        """
        pd = _pd()
        frames: list[pd.DataFrame] = []
        for index, run in enumerate(self.runs):
            df = run.results.copy()
            df["run_dir"] = str(run.run_dir)
            for key, value in self.parameters[index].items():
                if pd.api.types.is_list_like(value):
                    # ``df[key] = value`` would treat a list-like as one element per row:
                    # it raises on a length mismatch and silently spreads the items when
                    # the lengths happen to match. Store the whole value in each row.
                    df[key] = pd.Series([value] * len(df), index=df.index, dtype=object)
                else:
                    df[key] = value
            frames.append(df)
        if not frames:
            return pd.DataFrame()
        return pd.concat(frames, ignore_index=True)

    def best(self, by: str = "equalized_ber", ascending: bool | None = None) -> pd.Series:
        """Return the top-ranked row across all runs by column ``by``.

        When ``ascending`` is ``None``, larger values win for ``improvement_rel``,
        ``improvement_db``, ``accuracy`` and ``samples_per_sec``; smaller values win for
        every other column.

        Raises:
            KeyError: If ``by`` is not a column of :attr:`results`.
        """
        df = self.results
        if by not in df.columns:
            raise KeyError(f"Column not found in results: {by}")
        if ascending is None:
            ascending = by not in {"improvement_rel", "improvement_db", "accuracy", "samples_per_sec"}
        return df.sort_values(by, ascending=ascending).iloc[0]

    def plot_tradeoff(
        self,
        x: str = "trainable_params",
        y: str = "equalized_ber",
        color: str | None = "model_type",
    ):
        """Scatter figure of ``y`` against ``x`` over all runs, coloured by ``color``."""
        from ber_equalization_studio.visualization import comparison_figure

        return comparison_figure(self.results, x=x, y=y, color=color, kind="scatter")

    def _repr_html_(self) -> str:
        # Jupyter rich display: render the combined results table.
        return self.results._repr_html_()
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
class Studio:
    """High-level research interface for notebooks and scripts.

    Holds a base ``StudioConfig`` (``StudioConfig()`` if none is given), adjusted by
    keyword ``defaults``. Per-run configs are derived from it with flat override keys
    (``epochs=10``) or dotted ones (``**{"training.epochs": 10}``).
    """

    def __init__(
        self,
        base_config: StudioConfig | None = None,
        legacy_path: str | Path | None = None,
        **defaults: Any,
    ):
        self.base_config = _apply_updates(base_config or StudioConfig(), defaults)
        self.legacy_path = Path(legacy_path) if legacy_path is not None else None

    def config(self, **updates: Any) -> StudioConfig:
        """Return a copy of ``base_config`` with ``updates`` applied (the base is unchanged)."""
        return _apply_updates(self.base_config, updates)

    def models(self) -> pd.DataFrame:
        """Table of available model names and their descriptions."""
        from ber_equalization_studio.models import available_models

        rows = [{"model": name, "description": note} for name, note in available_models().items()]
        return _pd().DataFrame(rows)

    def run(
        self,
        name: str = "experiment",
        models: str | Iterable[str] | None = None,
        **updates: Any,
    ) -> RunResult:
        """Train and evaluate one experiment.

        Args:
            name: Stored as ``output.experiment_name``; this overrides any
                ``experiment_name`` given in ``updates``.
            models: Model name(s) to train. ``None`` keeps the configured list.
            **updates: Config overrides, as accepted by :meth:`config`.
        """
        if models is not None:
            updates["models"] = models
        updates["experiment_name"] = name
        config = self.config(**updates)
        from ber_equalization_studio.experiment import ExperimentRunner

        payload = ExperimentRunner(config, legacy_path=self.legacy_path).run()
        return RunResult(payload, config)

    def sweep(
        self,
        name: str,
        grid: dict[str, Iterable[Any]],
        models: str | Iterable[str] | None = None,
        **updates: Any,
    ) -> StudyResult:
        """Run one experiment per combination of the values in ``grid``.

        Each run is named ``<name>_<NNN>_<key>-<value>...`` (slugified). Grid values take
        precedence over ``updates``, and a ``"models"`` grid key takes precedence over
        the ``models`` argument.
        """
        runs: list[RunResult] = []
        parameters = _grid_items(grid)
        for index, params in enumerate(parameters, start=1):
            parts = [f"{_slug(key)}-{_slug(str(value))}" for key, value in params.items()]
            run_name = _slug("_".join([name, f"{index:03d}", *parts]))
            run_updates = {**updates, **params}
            # ``models`` is an explicit parameter of run(). If it also appeared in
            # **run_updates the call would raise TypeError (multiple values for 'models'),
            # so lift a swept value out and pass it explicitly.
            run_models = run_updates.pop("models", models)
            runs.append(self.run(name=run_name, models=run_models, **run_updates))
        return StudyResult(runs, parameters)
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def run(
    name: str = "experiment",
    models: str | Iterable[str] | None = None,
    legacy_path: str | Path | None = None,
    **updates: Any,
) -> RunResult:
    """Run one experiment with a fresh ``Studio`` built on the default ``StudioConfig``.

    Arguments are forwarded unchanged to :meth:`Studio.run`; ``legacy_path`` configures
    the temporary studio.
    """
    studio = Studio(legacy_path=legacy_path)
    return studio.run(name=name, models=models, **updates)
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def sweep(
    name: str,
    grid: dict[str, Iterable[Any]],
    models: str | Iterable[str] | None = None,
    legacy_path: str | Path | None = None,
    **updates: Any,
) -> StudyResult:
    """Run a parameter sweep with a fresh ``Studio`` built on the default ``StudioConfig``.

    Arguments are forwarded unchanged to :meth:`Studio.sweep`; ``legacy_path`` configures
    the temporary studio.
    """
    studio = Studio(legacy_path=legacy_path)
    return studio.sweep(name=name, grid=grid, models=models, **updates)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from ber_equalization_studio.config import DataConfig, ExperimentConfig, ModelConfig, OutputConfig, StudioConfig, TrainingConfig
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _split_models(value: str) -> list[str]:
|
|
10
|
+
return [item.strip() for item in value.split(",") if item.strip()]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the ``ber-studio`` command line.

    Sub-commands (the chosen one is stored in ``args.command``):

    * ``run``: train equalizer models; every option has a default.
    * ``models``: list available model names.
    * ``compare``: compare result CSV files below ``path``.
    """
    root = argparse.ArgumentParser(prog="ber-studio", description="BER equalization experiment studio")
    commands = root.add_subparsers(dest="command", required=True)

    train_cmd = commands.add_parser("run", help="Train one or more equalizer models")
    train_cmd.add_argument("--models", default="complex_fastkan", help="Comma-separated model names")
    # Repeatable; stays None when absent so the caller can fall back to a default list.
    train_cmd.add_argument(
        "--data-dir",
        action="append",
        type=Path,
        default=None,
        help="Directory with Symbols_1m_1ch_PR_*.csv",
    )
    train_cmd.add_argument("--out-dir", type=Path, default=Path("studio_outputs"))
    train_cmd.add_argument("--name", default="experiment")
    train_cmd.add_argument("--legacy-path", type=Path, default=None)

    # (flag, converter, default). Registration order determines --help order.
    # A converter of None keeps the raw string, which is argparse's default behaviour.
    valued_options = (
        ("--epochs", int, 250),
        ("--lr", float, 1e-3),
        ("--context-k", int, 32),
        ("--max-files", int, 64),
        ("--max-test-files", int, None),
        ("--device", None, None),
        ("--fastkan-hidden", int, 96),
        ("--fastkan-grids", int, 8),
        ("--efficient-kan-hidden", int, 128),
        ("--efficient-kan-grid", int, 8),
    )
    for flag, converter, fallback in valued_options:
        train_cmd.add_argument(flag, type=converter, default=fallback)

    for switch in ("--no-per-file", "--no-plots"):
        train_cmd.add_argument(switch, action="store_true")

    commands.add_parser("models", help="List available model names")

    compare_cmd = commands.add_parser("compare", help="Compare result CSV files below a run directory")
    compare_cmd.add_argument("path", type=Path)
    compare_cmd.add_argument("--sort-by", default="equalized_ber")

    return root
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def main(argv: list[str] | None = None) -> None:
    """Run the ``ber-studio`` command line interface.

    Args:
        argv: Arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    ``models`` prints the available model names. ``compare`` prints a comparison table
    for ``args.path``. ``run`` builds a ``StudioConfig`` from the options, trains the
    requested models and prints the path of the results CSV.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    if args.command == "models":
        from ber_equalization_studio.models import available_models

        for name, note in available_models().items():
            print(f"{name:32s} {note}")
        return

    if args.command == "compare":
        from ber_equalization_studio.results import compare_results

        df = compare_results(args.path, sort_by=args.sort_by)
        print(df.to_string(index=False))
        return

    # --data-dir is None unless given at least once; fall back to DataConfig's default.
    data_dirs = args.data_dir if args.data_dir else DataConfig().data_dirs
    config = StudioConfig(
        data=DataConfig(
            data_dirs=data_dirs,
            max_files=args.max_files,
            context_k=args.context_k,
            max_test_files=args.max_test_files,
        ),
        model=ModelConfig(
            fastkan_hidden_dim=args.fastkan_hidden,
            fastkan_num_grids=args.fastkan_grids,
            efficient_kan_hidden_dim=args.efficient_kan_hidden,
            efficient_kan_grid_size=args.efficient_kan_grid,
        ),
        training=TrainingConfig(epochs=args.epochs, learning_rate=args.lr),
        output=OutputConfig(out_dir=args.out_dir, experiment_name=args.name, save_plots=not args.no_plots),
        experiment=ExperimentConfig(models=_split_models(args.models)),
        device=args.device or StudioConfig().device,
    )
    if args.no_per_file:
        from dataclasses import replace

        # Build modified copies instead of assigning config.evaluation.<field> in place.
        # In-place assignment fails on frozen dataclasses, and it would also alter an
        # EvaluationConfig instance that may be shared with other StudioConfig objects.
        config = replace(config, evaluation=replace(config.evaluation, compute_per_file_metrics=False))

    from ber_equalization_studio.experiment import ExperimentRunner

    result = ExperimentRunner(config, legacy_path=args.legacy_path).run()
    print(f"Saved results: {result['results_csv']}")
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# Run the CLI when this module is executed directly,
# e.g. ``python -m ber_equalization_studio.cli run --models ...``.
if __name__ == "__main__":
    main()
|