oscura 0.8.0__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oscura/__init__.py +19 -19
- oscura/analyzers/__init__.py +2 -0
- oscura/analyzers/digital/extraction.py +2 -3
- oscura/analyzers/digital/quality.py +1 -1
- oscura/analyzers/digital/timing.py +1 -1
- oscura/analyzers/patterns/__init__.py +66 -0
- oscura/analyzers/power/basic.py +3 -3
- oscura/analyzers/power/soa.py +1 -1
- oscura/analyzers/power/switching.py +3 -3
- oscura/analyzers/signal_classification.py +529 -0
- oscura/analyzers/signal_integrity/sparams.py +3 -3
- oscura/analyzers/statistics/basic.py +10 -7
- oscura/analyzers/validation.py +1 -1
- oscura/analyzers/waveform/measurements.py +200 -156
- oscura/analyzers/waveform/measurements_with_uncertainty.py +91 -35
- oscura/analyzers/waveform/spectral.py +164 -73
- oscura/api/dsl/commands.py +15 -6
- oscura/api/server/templates/base.html +137 -146
- oscura/api/server/templates/export.html +84 -110
- oscura/api/server/templates/home.html +248 -267
- oscura/api/server/templates/protocols.html +44 -48
- oscura/api/server/templates/reports.html +27 -35
- oscura/api/server/templates/session_detail.html +68 -78
- oscura/api/server/templates/sessions.html +62 -72
- oscura/api/server/templates/waveforms.html +54 -64
- oscura/automotive/__init__.py +1 -1
- oscura/automotive/can/session.py +1 -1
- oscura/automotive/dbc/generator.py +638 -23
- oscura/automotive/uds/decoder.py +99 -6
- oscura/cli/analyze.py +8 -2
- oscura/cli/batch.py +36 -5
- oscura/cli/characterize.py +18 -4
- oscura/cli/export.py +47 -5
- oscura/cli/main.py +2 -0
- oscura/cli/onboarding/wizard.py +10 -6
- oscura/cli/pipeline.py +585 -0
- oscura/cli/visualize.py +6 -4
- oscura/convenience.py +400 -32
- oscura/core/measurement_result.py +286 -0
- oscura/core/progress.py +1 -1
- oscura/core/types.py +232 -239
- oscura/correlation/multi_protocol.py +1 -1
- oscura/export/legacy/__init__.py +11 -0
- oscura/export/legacy/wav.py +75 -0
- oscura/exporters/__init__.py +19 -0
- oscura/exporters/wireshark.py +809 -0
- oscura/hardware/acquisition/file.py +5 -19
- oscura/hardware/acquisition/saleae.py +10 -10
- oscura/hardware/acquisition/socketcan.py +4 -6
- oscura/hardware/acquisition/synthetic.py +1 -5
- oscura/hardware/acquisition/visa.py +6 -6
- oscura/hardware/security/side_channel_detector.py +5 -508
- oscura/inference/message_format.py +686 -1
- oscura/jupyter/display.py +2 -2
- oscura/jupyter/magic.py +3 -3
- oscura/loaders/__init__.py +17 -12
- oscura/loaders/binary.py +1 -1
- oscura/loaders/chipwhisperer.py +1 -2
- oscura/loaders/configurable.py +1 -1
- oscura/loaders/csv_loader.py +2 -2
- oscura/loaders/hdf5_loader.py +1 -1
- oscura/loaders/lazy.py +6 -1
- oscura/loaders/mmap_loader.py +0 -1
- oscura/loaders/numpy_loader.py +8 -7
- oscura/loaders/preprocessing.py +3 -5
- oscura/loaders/rigol.py +21 -7
- oscura/loaders/sigrok.py +2 -5
- oscura/loaders/tdms.py +3 -2
- oscura/loaders/tektronix.py +38 -32
- oscura/loaders/tss.py +20 -27
- oscura/loaders/vcd.py +13 -8
- oscura/loaders/wav.py +1 -6
- oscura/pipeline/__init__.py +76 -0
- oscura/pipeline/handlers/__init__.py +165 -0
- oscura/pipeline/handlers/analyzers.py +1045 -0
- oscura/pipeline/handlers/decoders.py +899 -0
- oscura/pipeline/handlers/exporters.py +1103 -0
- oscura/pipeline/handlers/filters.py +891 -0
- oscura/pipeline/handlers/loaders.py +640 -0
- oscura/pipeline/handlers/transforms.py +768 -0
- oscura/reporting/formatting/measurements.py +55 -14
- oscura/reporting/templates/enhanced/protocol_re.html +504 -503
- oscura/side_channel/__init__.py +38 -57
- oscura/utils/builders/signal_builder.py +5 -5
- oscura/utils/comparison/compare.py +7 -9
- oscura/utils/comparison/golden.py +1 -1
- oscura/utils/filtering/convenience.py +2 -2
- oscura/utils/math/arithmetic.py +38 -62
- oscura/utils/math/interpolation.py +20 -20
- oscura/utils/pipeline/__init__.py +4 -17
- oscura/utils/progressive.py +1 -4
- oscura/utils/triggering/edge.py +1 -1
- oscura/utils/triggering/pattern.py +2 -2
- oscura/utils/triggering/pulse.py +2 -2
- oscura/utils/triggering/window.py +3 -3
- oscura/validation/hil_testing.py +11 -11
- oscura/visualization/__init__.py +46 -284
- oscura/visualization/batch.py +72 -433
- oscura/visualization/plot.py +542 -53
- oscura/visualization/styles.py +184 -318
- oscura/workflows/batch/advanced.py +1 -1
- oscura/workflows/batch/aggregate.py +7 -8
- oscura/workflows/complete_re.py +251 -23
- oscura/workflows/digital.py +27 -4
- oscura/workflows/multi_trace.py +136 -17
- oscura/workflows/waveform.py +11 -6
- {oscura-0.8.0.dist-info → oscura-0.10.0.dist-info}/METADATA +59 -79
- {oscura-0.8.0.dist-info → oscura-0.10.0.dist-info}/RECORD +111 -136
- oscura/side_channel/dpa.py +0 -1025
- oscura/utils/optimization/__init__.py +0 -19
- oscura/utils/optimization/parallel.py +0 -443
- oscura/utils/optimization/search.py +0 -532
- oscura/utils/pipeline/base.py +0 -338
- oscura/utils/pipeline/composition.py +0 -248
- oscura/utils/pipeline/parallel.py +0 -449
- oscura/utils/pipeline/pipeline.py +0 -375
- oscura/utils/search/__init__.py +0 -16
- oscura/utils/search/anomaly.py +0 -424
- oscura/utils/search/context.py +0 -294
- oscura/utils/search/pattern.py +0 -288
- oscura/utils/storage/__init__.py +0 -61
- oscura/utils/storage/database.py +0 -1166
- oscura/visualization/accessibility.py +0 -526
- oscura/visualization/annotations.py +0 -371
- oscura/visualization/axis_scaling.py +0 -305
- oscura/visualization/colors.py +0 -451
- oscura/visualization/digital.py +0 -436
- oscura/visualization/eye.py +0 -571
- oscura/visualization/histogram.py +0 -281
- oscura/visualization/interactive.py +0 -1035
- oscura/visualization/jitter.py +0 -1042
- oscura/visualization/keyboard.py +0 -394
- oscura/visualization/layout.py +0 -400
- oscura/visualization/optimization.py +0 -1079
- oscura/visualization/palettes.py +0 -446
- oscura/visualization/power.py +0 -508
- oscura/visualization/power_extended.py +0 -955
- oscura/visualization/presets.py +0 -469
- oscura/visualization/protocols.py +0 -1246
- oscura/visualization/render.py +0 -223
- oscura/visualization/rendering.py +0 -444
- oscura/visualization/reverse_engineering.py +0 -838
- oscura/visualization/signal_integrity.py +0 -989
- oscura/visualization/specialized.py +0 -643
- oscura/visualization/spectral.py +0 -1226
- oscura/visualization/thumbnails.py +0 -340
- oscura/visualization/time_axis.py +0 -351
- oscura/visualization/waveform.py +0 -454
- {oscura-0.8.0.dist-info → oscura-0.10.0.dist-info}/WHEEL +0 -0
- {oscura-0.8.0.dist-info → oscura-0.10.0.dist-info}/entry_points.txt +0 -0
- {oscura-0.8.0.dist-info → oscura-0.10.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,532 +0,0 @@
|
|
|
1
|
-
"""Parameter optimization via grid search and random search.
|
|
2
|
-
|
|
3
|
-
This module provides tools for finding optimal analysis parameters through
|
|
4
|
-
systematic or random search of the parameter space.
|
|
5
|
-
"""
|
|
6
|
-
|
|
7
|
-
from __future__ import annotations
|
|
8
|
-
|
|
9
|
-
import itertools
|
|
10
|
-
from collections.abc import Callable
|
|
11
|
-
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
|
12
|
-
from dataclasses import dataclass
|
|
13
|
-
from typing import TYPE_CHECKING, Any, Literal
|
|
14
|
-
|
|
15
|
-
import numpy as np
|
|
16
|
-
import pandas as pd
|
|
17
|
-
|
|
18
|
-
from oscura.analyzers.waveform.spectral import thd as compute_thd
|
|
19
|
-
from oscura.core.exceptions import AnalysisError
|
|
20
|
-
|
|
21
|
-
if TYPE_CHECKING:
|
|
22
|
-
from numpy.typing import NDArray
|
|
23
|
-
|
|
24
|
-
from oscura.core.types import WaveformTrace
|
|
25
|
-
|
|
26
|
-
ScoringFunction = Callable[[WaveformTrace, dict[str, Any]], float]
|
|
27
|
-
else:
|
|
28
|
-
ScoringFunction = Callable
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
@dataclass()
class SearchResult:
    """Outcome of a parameter search (grid or randomized).

    Attributes:
        best_params: Parameter values of the highest-scoring combination.
        best_score: Mean score achieved by ``best_params`` (scores are
            maximized, so higher is better).
        all_results: One row per evaluated combination, holding the
            parameter values together with their scores.
        cv_scores: Per-fold scores of the best combination when
            cross-validation was used, otherwise ``None``.

    Example:
        >>> result = search.fit(traces)
        >>> print(f"Best params: {result.best_params}")
        >>> print(f"Best score: {result.best_score}")

    References:
        API-014: Parameter Grid Search
    """

    # Winning parameter combination.
    best_params: dict[str, Any]
    # Mean score of the winning combination.
    best_score: float
    # Every evaluated combination with its scores.
    all_results: pd.DataFrame
    # Fold-level scores for the winner; None when no CV was run.
    cv_scores: NDArray[np.float64] | None = None
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
def _default_snr_scorer(trace: WaveformTrace, params: dict[str, Any]) -> float:
    """Default SNR scoring function.

    Uses a deliberately crude model: the first half of the record is treated
    as "signal" and the mean-removed content of the second half as "noise".

    Args:
        trace: Waveform trace to analyze.
        params: Parameters that were applied (not used by this scorer).

    Returns:
        Signal-to-noise ratio in dB. ``inf`` when the noise half has zero
        variance, ``-inf`` when the signal half has zero power.

    Raises:
        AnalysisError: If the trace has fewer than 2 samples, in which case
            one of the halves would be empty.
    """
    # Promote to float64 so squaring integer samples (e.g. int16 ADC codes)
    # cannot silently wrap around.
    data = np.asarray(trace.data, dtype=np.float64)
    if len(data) < 2:
        # A 0/1-sample trace would otherwise score NaN or +inf and could be
        # picked as the "best" configuration.
        raise AnalysisError(f"SNR scoring requires at least 2 samples, got {len(data)}")

    mid = len(data) // 2
    signal_half = data[:mid]
    noise_half = data[mid:]

    signal_power = np.mean(signal_half**2)
    noise_power = np.mean((noise_half - np.mean(noise_half)) ** 2)

    if noise_power == 0:
        return float("inf")
    if signal_power == 0:
        # Same value log10(0) would give, without the divide-by-zero warning.
        return float("-inf")

    return float(10 * np.log10(signal_power / noise_power))
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
def _default_thd_scorer(trace: WaveformTrace, params: dict[str, Any]) -> float:
    """Score a trace by its total harmonic distortion.

    Args:
        trace: Waveform trace to analyze.
        params: Parameters that were applied (ignored by this scorer).

    Returns:
        The value reported by ``compute_thd`` for *trace*, negated. The search
        maximizes scores, so negation makes lower distortion score higher.
        (Units are whatever ``compute_thd`` returns — NOTE(review): the
        original docstring says percentage; confirm against spectral.thd.)
    """
    distortion = compute_thd(trace)
    # Negate before converting, exactly as the original did.
    score = -distortion
    return float(score)
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
class GridSearchCV:
|
|
96
|
-
"""Grid search over parameter space with optional cross-validation.
|
|
97
|
-
|
|
98
|
-
Systematically evaluates all combinations of parameters to find the
|
|
99
|
-
optimal configuration.
|
|
100
|
-
|
|
101
|
-
Example:
|
|
102
|
-
>>> from oscura.utils.optimization.search import GridSearchCV
|
|
103
|
-
>>> param_grid = {
|
|
104
|
-
... 'cutoff': [1e5, 5e5, 1e6],
|
|
105
|
-
... 'order': [2, 4, 6]
|
|
106
|
-
... }
|
|
107
|
-
>>> search = GridSearchCV(
|
|
108
|
-
... param_grid=param_grid,
|
|
109
|
-
... scoring='snr',
|
|
110
|
-
... cv=3
|
|
111
|
-
... )
|
|
112
|
-
>>> result = search.fit(traces, apply_filter)
|
|
113
|
-
>>> print(result.best_params)
|
|
114
|
-
|
|
115
|
-
References:
|
|
116
|
-
API-014: Parameter Grid Search
|
|
117
|
-
"""
|
|
118
|
-
|
|
119
|
-
def __init__(
|
|
120
|
-
self,
|
|
121
|
-
param_grid: dict[str, list[Any]],
|
|
122
|
-
scoring: Literal["snr", "thd"] | ScoringFunction = "snr",
|
|
123
|
-
cv: int | None = None,
|
|
124
|
-
*,
|
|
125
|
-
parallel: bool = True,
|
|
126
|
-
max_workers: int | None = None,
|
|
127
|
-
use_threads: bool = True,
|
|
128
|
-
) -> None:
|
|
129
|
-
"""Initialize grid search.
|
|
130
|
-
|
|
131
|
-
Args:
|
|
132
|
-
param_grid: Dictionary mapping parameter names to lists of values.
|
|
133
|
-
scoring: Scoring function. Built-in: 'snr', 'thd', or custom callable.
|
|
134
|
-
cv: Number of cross-validation folds. None for no CV.
|
|
135
|
-
parallel: Enable parallel evaluation.
|
|
136
|
-
max_workers: Maximum parallel workers.
|
|
137
|
-
use_threads: Use threads instead of processes.
|
|
138
|
-
|
|
139
|
-
Raises:
|
|
140
|
-
AnalysisError: If scoring function is invalid.
|
|
141
|
-
|
|
142
|
-
Example:
|
|
143
|
-
>>> param_grid = {'cutoff': [1e6, 2e6], 'order': [4, 6]}
|
|
144
|
-
>>> search = GridSearchCV(param_grid, scoring='snr', cv=3)
|
|
145
|
-
"""
|
|
146
|
-
self.param_grid = param_grid
|
|
147
|
-
self.cv = cv
|
|
148
|
-
self.parallel = parallel
|
|
149
|
-
self.max_workers = max_workers
|
|
150
|
-
self.use_threads = use_threads
|
|
151
|
-
|
|
152
|
-
# Set scoring function
|
|
153
|
-
if scoring == "snr":
|
|
154
|
-
self.scoring_fn = _default_snr_scorer
|
|
155
|
-
elif scoring == "thd":
|
|
156
|
-
self.scoring_fn = _default_thd_scorer
|
|
157
|
-
elif callable(scoring):
|
|
158
|
-
self.scoring_fn = scoring # type: ignore[assignment]
|
|
159
|
-
else:
|
|
160
|
-
raise AnalysisError(f"Unknown scoring function: {scoring}")
|
|
161
|
-
|
|
162
|
-
self.best_params_: dict[str, Any] | None = None
|
|
163
|
-
self.best_score_: float | None = None
|
|
164
|
-
self.results_df_: pd.DataFrame | None = None
|
|
165
|
-
|
|
166
|
-
def fit(
|
|
167
|
-
self,
|
|
168
|
-
traces: list[WaveformTrace] | WaveformTrace,
|
|
169
|
-
transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
|
|
170
|
-
) -> SearchResult:
|
|
171
|
-
"""Fit grid search on traces.
|
|
172
|
-
|
|
173
|
-
Evaluates all parameter combinations and finds the best.
|
|
174
|
-
|
|
175
|
-
Args:
|
|
176
|
-
traces: Trace or list of traces to evaluate on.
|
|
177
|
-
transform_fn: Function that applies parameters to trace.
|
|
178
|
-
Should accept (trace, **params) and return transformed trace.
|
|
179
|
-
|
|
180
|
-
Returns:
|
|
181
|
-
SearchResult with best parameters and all results.
|
|
182
|
-
|
|
183
|
-
Example:
|
|
184
|
-
>>> def apply_filter(trace, cutoff, order):
|
|
185
|
-
... return lowpass_filter(trace, cutoff=cutoff, order=order)
|
|
186
|
-
>>> result = search.fit(traces, apply_filter)
|
|
187
|
-
|
|
188
|
-
References:
|
|
189
|
-
API-014: Parameter Grid Search
|
|
190
|
-
"""
|
|
191
|
-
# Convert single trace to list
|
|
192
|
-
if not isinstance(traces, list):
|
|
193
|
-
traces = [traces]
|
|
194
|
-
|
|
195
|
-
# Generate all parameter combinations
|
|
196
|
-
param_combinations = self._generate_combinations()
|
|
197
|
-
|
|
198
|
-
# Evaluate each combination
|
|
199
|
-
results = self._evaluate_combinations(param_combinations, traces, transform_fn)
|
|
200
|
-
|
|
201
|
-
# Convert to DataFrame
|
|
202
|
-
self.results_df_ = pd.DataFrame(results)
|
|
203
|
-
|
|
204
|
-
# Find best
|
|
205
|
-
best_idx = self.results_df_["mean_score"].idxmax()
|
|
206
|
-
best_row = self.results_df_.iloc[best_idx]
|
|
207
|
-
|
|
208
|
-
self.best_params_ = {k: best_row[k] for k in self.param_grid}
|
|
209
|
-
self.best_score_ = float(best_row["mean_score"])
|
|
210
|
-
|
|
211
|
-
# Collect CV scores if available
|
|
212
|
-
cv_scores = None
|
|
213
|
-
if self.cv:
|
|
214
|
-
cv_cols = [c for c in self.results_df_.columns if c.startswith("cv_")]
|
|
215
|
-
if cv_cols:
|
|
216
|
-
cv_scores = self.results_df_.loc[best_idx, cv_cols].values
|
|
217
|
-
|
|
218
|
-
return SearchResult(
|
|
219
|
-
best_params=self.best_params_,
|
|
220
|
-
best_score=self.best_score_,
|
|
221
|
-
all_results=self.results_df_,
|
|
222
|
-
cv_scores=cv_scores,
|
|
223
|
-
)
|
|
224
|
-
|
|
225
|
-
def _generate_combinations(self) -> list[dict[str, Any]]:
|
|
226
|
-
"""Generate all parameter combinations from grid.
|
|
227
|
-
|
|
228
|
-
Returns:
|
|
229
|
-
List of parameter dictionaries.
|
|
230
|
-
"""
|
|
231
|
-
keys = list(self.param_grid.keys())
|
|
232
|
-
values = [self.param_grid[k] for k in keys]
|
|
233
|
-
|
|
234
|
-
combinations = []
|
|
235
|
-
for combo in itertools.product(*values):
|
|
236
|
-
combinations.append(dict(zip(keys, combo, strict=False)))
|
|
237
|
-
|
|
238
|
-
return combinations
|
|
239
|
-
|
|
240
|
-
def _evaluate_combinations(
|
|
241
|
-
self,
|
|
242
|
-
param_combinations: list[dict[str, Any]],
|
|
243
|
-
traces: list[WaveformTrace],
|
|
244
|
-
transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
|
|
245
|
-
) -> list[dict[str, Any]]:
|
|
246
|
-
"""Evaluate all parameter combinations.
|
|
247
|
-
|
|
248
|
-
Args:
|
|
249
|
-
param_combinations: List of parameter dicts to evaluate.
|
|
250
|
-
traces: Traces to evaluate on.
|
|
251
|
-
transform_fn: Transformation function.
|
|
252
|
-
|
|
253
|
-
Returns:
|
|
254
|
-
List of result dictionaries.
|
|
255
|
-
"""
|
|
256
|
-
if self.parallel:
|
|
257
|
-
return self._evaluate_parallel(param_combinations, traces, transform_fn)
|
|
258
|
-
else:
|
|
259
|
-
return self._evaluate_sequential(param_combinations, traces, transform_fn)
|
|
260
|
-
|
|
261
|
-
def _evaluate_one(
|
|
262
|
-
self,
|
|
263
|
-
params: dict[str, Any],
|
|
264
|
-
traces: list[WaveformTrace],
|
|
265
|
-
transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
|
|
266
|
-
) -> dict[str, Any]:
|
|
267
|
-
"""Evaluate one parameter combination.
|
|
268
|
-
|
|
269
|
-
Args:
|
|
270
|
-
params: Parameters to evaluate.
|
|
271
|
-
traces: Traces to evaluate on.
|
|
272
|
-
transform_fn: Transformation function.
|
|
273
|
-
|
|
274
|
-
Returns:
|
|
275
|
-
Result dictionary with scores.
|
|
276
|
-
"""
|
|
277
|
-
scores: list[float] = []
|
|
278
|
-
|
|
279
|
-
if self.cv:
|
|
280
|
-
# Cross-validation - split traces into folds
|
|
281
|
-
fold_size = len(traces) // self.cv
|
|
282
|
-
for i in range(self.cv):
|
|
283
|
-
# Select fold
|
|
284
|
-
start = i * fold_size
|
|
285
|
-
end = start + fold_size if i < self.cv - 1 else len(traces)
|
|
286
|
-
fold_traces = traces[start:end]
|
|
287
|
-
|
|
288
|
-
# Evaluate on fold
|
|
289
|
-
fold_scores = []
|
|
290
|
-
for trace in fold_traces:
|
|
291
|
-
transformed = transform_fn(trace, **params) # type: ignore[call-arg]
|
|
292
|
-
score = self.scoring_fn(transformed, params)
|
|
293
|
-
fold_scores.append(score)
|
|
294
|
-
|
|
295
|
-
scores.append(float(np.mean(fold_scores)))
|
|
296
|
-
|
|
297
|
-
else:
|
|
298
|
-
# No CV - evaluate on all traces
|
|
299
|
-
for trace in traces:
|
|
300
|
-
transformed = transform_fn(trace, **params) # type: ignore[call-arg]
|
|
301
|
-
score = self.scoring_fn(transformed, params)
|
|
302
|
-
scores.append(score)
|
|
303
|
-
|
|
304
|
-
# Build result
|
|
305
|
-
result = params.copy()
|
|
306
|
-
result["mean_score"] = float(np.mean(scores))
|
|
307
|
-
result["std_score"] = float(np.std(scores))
|
|
308
|
-
|
|
309
|
-
if self.cv:
|
|
310
|
-
for i, score in enumerate(scores):
|
|
311
|
-
result[f"cv_{i}"] = float(score)
|
|
312
|
-
|
|
313
|
-
return result
|
|
314
|
-
|
|
315
|
-
def _evaluate_sequential(
|
|
316
|
-
self,
|
|
317
|
-
param_combinations: list[dict[str, Any]],
|
|
318
|
-
traces: list[WaveformTrace],
|
|
319
|
-
transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
|
|
320
|
-
) -> list[dict[str, Any]]:
|
|
321
|
-
"""Evaluate combinations sequentially.
|
|
322
|
-
|
|
323
|
-
Args:
|
|
324
|
-
param_combinations: Parameter combinations.
|
|
325
|
-
traces: Traces to evaluate on.
|
|
326
|
-
transform_fn: Transformation function.
|
|
327
|
-
|
|
328
|
-
Returns:
|
|
329
|
-
List of results.
|
|
330
|
-
"""
|
|
331
|
-
results = []
|
|
332
|
-
for params in param_combinations:
|
|
333
|
-
result = self._evaluate_one(params, traces, transform_fn)
|
|
334
|
-
results.append(result)
|
|
335
|
-
return results
|
|
336
|
-
|
|
337
|
-
def _evaluate_parallel(
|
|
338
|
-
self,
|
|
339
|
-
param_combinations: list[dict[str, Any]],
|
|
340
|
-
traces: list[WaveformTrace],
|
|
341
|
-
transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
|
|
342
|
-
) -> list[dict[str, Any]]:
|
|
343
|
-
"""Evaluate combinations in parallel.
|
|
344
|
-
|
|
345
|
-
Args:
|
|
346
|
-
param_combinations: Parameter combinations.
|
|
347
|
-
traces: Traces to evaluate on.
|
|
348
|
-
transform_fn: Transformation function.
|
|
349
|
-
|
|
350
|
-
Returns:
|
|
351
|
-
List of results.
|
|
352
|
-
"""
|
|
353
|
-
executor_class = ThreadPoolExecutor if self.use_threads else ProcessPoolExecutor
|
|
354
|
-
|
|
355
|
-
with executor_class(max_workers=self.max_workers) as executor:
|
|
356
|
-
futures = {
|
|
357
|
-
executor.submit(self._evaluate_one, params, traces, transform_fn): params
|
|
358
|
-
for params in param_combinations
|
|
359
|
-
}
|
|
360
|
-
|
|
361
|
-
results = []
|
|
362
|
-
for future in as_completed(futures):
|
|
363
|
-
result = future.result()
|
|
364
|
-
results.append(result)
|
|
365
|
-
|
|
366
|
-
return results
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
class RandomizedSearchCV:
    """Random search over parameter distributions.

    Samples random combinations from parameter distributions rather than
    exhaustively evaluating all combinations. Scores are maximized.

    Example:
        >>> from oscura.utils.optimization.search import RandomizedSearchCV
        >>> import numpy as np
        >>> param_distributions = {
        ...     'cutoff': lambda: np.random.uniform(1e5, 1e7),
        ...     'order': lambda: np.random.choice([2, 4, 6, 8])
        ... }
        >>> search = RandomizedSearchCV(
        ...     param_distributions=param_distributions,
        ...     n_iter=20,
        ...     scoring='snr'
        ... )
        >>> result = search.fit(traces, apply_filter)

    References:
        API-014: Parameter Grid Search
    """

    def __init__(
        self,
        param_distributions: dict[str, Callable[[], Any]],
        n_iter: int = 10,
        scoring: Literal["snr", "thd"] | ScoringFunction = "snr",
        cv: int | None = None,
        *,
        parallel: bool = True,
        max_workers: int | None = None,
        use_threads: bool = True,
        random_state: int | None = None,
    ) -> None:
        """Initialize randomized search.

        Args:
            param_distributions: Dict mapping parameter names to zero-argument
                sampling functions.
            n_iter: Number of parameter combinations to sample (>= 1).
            scoring: Scoring function ('snr', 'thd', or custom callable).
            cv: Number of cross-validation folds. ``None`` (or 0) for no CV.
            parallel: Enable parallel evaluation.
            max_workers: Maximum parallel workers.
            use_threads: Use threads instead of processes.
            random_state: Random seed for reproducibility. When set, the
                global NumPy RNG is seeded for the duration of sampling only
                and its previous state is restored afterwards, so every
                ``fit`` call draws the same combinations.

        Raises:
            AnalysisError: If scoring function is invalid, ``n_iter`` < 1, or
                ``cv`` is negative.

        Example:
            >>> param_dist = {'cutoff': lambda: np.random.uniform(1e5, 1e7)}
            >>> search = RandomizedSearchCV(param_dist, n_iter=50)
        """
        if n_iter < 1:
            raise AnalysisError(f"n_iter must be >= 1, got {n_iter}")
        if cv is not None and cv < 0:
            raise AnalysisError(f"cv must be a non-negative integer or None, got {cv}")

        self.param_distributions = param_distributions
        self.n_iter = n_iter
        self.cv = cv
        self.parallel = parallel
        self.max_workers = max_workers
        self.use_threads = use_threads
        # Applied in _sample_combinations instead of seeding the global RNG
        # here, which both leaked state and broke reproducibility if anything
        # consumed the RNG before fit() (or if fit() was called twice).
        self.random_state = random_state

        # Set scoring function
        if scoring == "snr":
            self.scoring_fn = _default_snr_scorer
        elif scoring == "thd":
            self.scoring_fn = _default_thd_scorer
        elif callable(scoring):
            self.scoring_fn = scoring  # type: ignore[assignment]
        else:
            raise AnalysisError(f"Unknown scoring function: {scoring}")

        self.best_params_: dict[str, Any] | None = None
        self.best_score_: float | None = None
        self.results_df_: pd.DataFrame | None = None

    def fit(
        self,
        traces: list[WaveformTrace] | tuple[WaveformTrace, ...] | WaveformTrace,
        transform_fn: Callable[..., WaveformTrace],
    ) -> SearchResult:
        """Fit randomized search on traces.

        Args:
            traces: Trace, or list/tuple of traces, to evaluate on.
            transform_fn: Function that applies parameters to trace; called as
                ``transform_fn(trace, **params)``.

        Returns:
            SearchResult with best parameters.

        Raises:
            AnalysisError: If no traces are given, ``cv`` exceeds the number of
                traces, or every sampled combination scored NaN.

        Example:
            >>> result = search.fit(traces, apply_filter)
            >>> print(f"Best cutoff: {result.best_params['cutoff']:.2e}")

        References:
            API-014: Parameter Grid Search
        """
        # Accept a single trace or any list/tuple of traces.
        trace_list = list(traces) if isinstance(traces, (list, tuple)) else [traces]
        if not trace_list:
            raise AnalysisError("At least one trace is required for parameter search")
        if self.cv and self.cv > len(trace_list):
            # Fewer traces than folds would leave folds empty (NaN scores).
            raise AnalysisError(
                f"cv={self.cv} folds requested but only {len(trace_list)} trace(s) were provided"
            )

        param_combinations = self._sample_combinations()

        # Reuse the grid search evaluation machinery; its own grid is unused.
        evaluator = GridSearchCV(
            param_grid={},
            scoring=self.scoring_fn,
            cv=self.cv,
            parallel=self.parallel,
            max_workers=self.max_workers,
            use_threads=self.use_threads,
        )
        results = evaluator._evaluate_combinations(param_combinations, trace_list, transform_fn)
        self.results_df_ = pd.DataFrame(results)

        best = results[self._best_index(results)]
        # Read parameters from the result dict rather than a DataFrame row: a
        # row spanning int and float columns is upcast to float64, which would
        # turn e.g. order=4 into 4.0.
        self.best_params_ = {k: best[k] for k in self.param_distributions}
        self.best_score_ = float(best["mean_score"])

        return SearchResult(
            best_params=self.best_params_,
            best_score=self.best_score_,
            all_results=self.results_df_,
            cv_scores=self._cv_scores(best),
        )

    @staticmethod
    def _best_index(results: list[dict[str, Any]]) -> int:
        """Return the index of the highest ``mean_score`` (first one on ties).

        Raises:
            AnalysisError: If there are no results or every score is NaN.
        """
        scores = np.array([r["mean_score"] for r in results], dtype=np.float64)
        if scores.size == 0 or np.isnan(scores).all():
            raise AnalysisError("No parameter combination produced a valid (non-NaN) score")
        # NaN scores are skipped, matching pandas idxmax(skipna=True).
        return int(np.nanargmax(scores))

    def _cv_scores(self, result: dict[str, Any]) -> NDArray[np.float64] | None:
        """Return the per-fold scores stored in *result*, or None without CV."""
        if not self.cv:
            return None
        # Look fold columns up by exact name; scanning for a "cv_" prefix would
        # also pick up user parameters such as "cv_threshold".
        return np.array([result[f"cv_{i}"] for i in range(self.cv)], dtype=np.float64)

    def _sample_combinations(self) -> list[dict[str, Any]]:
        """Sample random parameter combinations.

        When ``random_state`` is set, the global NumPy RNG (which the
        documented samplers such as ``np.random.uniform`` draw from) is seeded
        for the duration of sampling and then restored to its prior state.

        Returns:
            List of ``n_iter`` sampled parameter dictionaries.
        """

        def draw() -> list[dict[str, Any]]:
            return [
                {key: sampler() for key, sampler in self.param_distributions.items()}
                for _ in range(self.n_iter)
            ]

        if self.random_state is None:
            return draw()

        saved_state = np.random.get_state()
        np.random.seed(self.random_state)
        try:
            return draw()
        finally:
            np.random.set_state(saved_state)
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
# Public API of this module (search classes, scorer alias, result type).
__all__ = ["GridSearchCV", "RandomizedSearchCV", "ScoringFunction", "SearchResult"]
|