oscura 0.7.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. oscura/__init__.py +19 -19
  2. oscura/analyzers/__init__.py +2 -0
  3. oscura/analyzers/digital/extraction.py +2 -3
  4. oscura/analyzers/digital/quality.py +1 -1
  5. oscura/analyzers/digital/timing.py +1 -1
  6. oscura/analyzers/eye/__init__.py +5 -1
  7. oscura/analyzers/eye/generation.py +501 -0
  8. oscura/analyzers/jitter/__init__.py +6 -6
  9. oscura/analyzers/jitter/timing.py +419 -0
  10. oscura/analyzers/patterns/__init__.py +94 -0
  11. oscura/analyzers/patterns/reverse_engineering.py +991 -0
  12. oscura/analyzers/power/__init__.py +35 -12
  13. oscura/analyzers/power/basic.py +3 -3
  14. oscura/analyzers/power/soa.py +1 -1
  15. oscura/analyzers/power/switching.py +3 -3
  16. oscura/analyzers/signal_classification.py +529 -0
  17. oscura/analyzers/signal_integrity/sparams.py +3 -3
  18. oscura/analyzers/statistics/__init__.py +4 -0
  19. oscura/analyzers/statistics/basic.py +152 -0
  20. oscura/analyzers/statistics/correlation.py +47 -6
  21. oscura/analyzers/validation.py +1 -1
  22. oscura/analyzers/waveform/__init__.py +2 -0
  23. oscura/analyzers/waveform/measurements.py +329 -163
  24. oscura/analyzers/waveform/measurements_with_uncertainty.py +91 -35
  25. oscura/analyzers/waveform/spectral.py +498 -54
  26. oscura/api/dsl/commands.py +15 -6
  27. oscura/api/server/templates/base.html +137 -146
  28. oscura/api/server/templates/export.html +84 -110
  29. oscura/api/server/templates/home.html +248 -267
  30. oscura/api/server/templates/protocols.html +44 -48
  31. oscura/api/server/templates/reports.html +27 -35
  32. oscura/api/server/templates/session_detail.html +68 -78
  33. oscura/api/server/templates/sessions.html +62 -72
  34. oscura/api/server/templates/waveforms.html +54 -64
  35. oscura/automotive/__init__.py +1 -1
  36. oscura/automotive/can/session.py +1 -1
  37. oscura/automotive/dbc/generator.py +638 -23
  38. oscura/automotive/dtc/data.json +102 -17
  39. oscura/automotive/uds/decoder.py +99 -6
  40. oscura/cli/analyze.py +8 -2
  41. oscura/cli/batch.py +36 -5
  42. oscura/cli/characterize.py +18 -4
  43. oscura/cli/export.py +47 -5
  44. oscura/cli/main.py +2 -0
  45. oscura/cli/onboarding/wizard.py +10 -6
  46. oscura/cli/pipeline.py +585 -0
  47. oscura/cli/visualize.py +6 -4
  48. oscura/convenience.py +400 -32
  49. oscura/core/config/loader.py +0 -1
  50. oscura/core/measurement_result.py +286 -0
  51. oscura/core/progress.py +1 -1
  52. oscura/core/schemas/device_mapping.json +8 -2
  53. oscura/core/schemas/packet_format.json +24 -4
  54. oscura/core/schemas/protocol_definition.json +12 -2
  55. oscura/core/types.py +300 -199
  56. oscura/correlation/multi_protocol.py +1 -1
  57. oscura/export/legacy/__init__.py +11 -0
  58. oscura/export/legacy/wav.py +75 -0
  59. oscura/exporters/__init__.py +19 -0
  60. oscura/exporters/wireshark.py +809 -0
  61. oscura/hardware/acquisition/file.py +5 -19
  62. oscura/hardware/acquisition/saleae.py +10 -10
  63. oscura/hardware/acquisition/socketcan.py +4 -6
  64. oscura/hardware/acquisition/synthetic.py +1 -5
  65. oscura/hardware/acquisition/visa.py +6 -6
  66. oscura/hardware/security/side_channel_detector.py +5 -508
  67. oscura/inference/message_format.py +686 -1
  68. oscura/jupyter/display.py +2 -2
  69. oscura/jupyter/magic.py +3 -3
  70. oscura/loaders/__init__.py +17 -12
  71. oscura/loaders/binary.py +1 -1
  72. oscura/loaders/chipwhisperer.py +1 -2
  73. oscura/loaders/configurable.py +1 -1
  74. oscura/loaders/csv_loader.py +2 -2
  75. oscura/loaders/hdf5_loader.py +1 -1
  76. oscura/loaders/lazy.py +6 -1
  77. oscura/loaders/mmap_loader.py +0 -1
  78. oscura/loaders/numpy_loader.py +8 -7
  79. oscura/loaders/preprocessing.py +3 -5
  80. oscura/loaders/rigol.py +21 -7
  81. oscura/loaders/sigrok.py +2 -5
  82. oscura/loaders/tdms.py +3 -2
  83. oscura/loaders/tektronix.py +38 -32
  84. oscura/loaders/tss.py +20 -27
  85. oscura/loaders/vcd.py +13 -8
  86. oscura/loaders/wav.py +1 -6
  87. oscura/pipeline/__init__.py +76 -0
  88. oscura/pipeline/handlers/__init__.py +165 -0
  89. oscura/pipeline/handlers/analyzers.py +1045 -0
  90. oscura/pipeline/handlers/decoders.py +899 -0
  91. oscura/pipeline/handlers/exporters.py +1103 -0
  92. oscura/pipeline/handlers/filters.py +891 -0
  93. oscura/pipeline/handlers/loaders.py +640 -0
  94. oscura/pipeline/handlers/transforms.py +768 -0
  95. oscura/reporting/__init__.py +88 -1
  96. oscura/reporting/automation.py +348 -0
  97. oscura/reporting/citations.py +374 -0
  98. oscura/reporting/core.py +54 -0
  99. oscura/reporting/formatting/__init__.py +11 -0
  100. oscura/reporting/formatting/measurements.py +320 -0
  101. oscura/reporting/html.py +57 -0
  102. oscura/reporting/interpretation.py +431 -0
  103. oscura/reporting/summary.py +329 -0
  104. oscura/reporting/templates/enhanced/protocol_re.html +504 -503
  105. oscura/reporting/visualization.py +542 -0
  106. oscura/side_channel/__init__.py +38 -57
  107. oscura/utils/builders/signal_builder.py +5 -5
  108. oscura/utils/comparison/compare.py +7 -9
  109. oscura/utils/comparison/golden.py +1 -1
  110. oscura/utils/filtering/convenience.py +2 -2
  111. oscura/utils/math/arithmetic.py +38 -62
  112. oscura/utils/math/interpolation.py +20 -20
  113. oscura/utils/pipeline/__init__.py +4 -17
  114. oscura/utils/progressive.py +1 -4
  115. oscura/utils/triggering/edge.py +1 -1
  116. oscura/utils/triggering/pattern.py +2 -2
  117. oscura/utils/triggering/pulse.py +2 -2
  118. oscura/utils/triggering/window.py +3 -3
  119. oscura/validation/hil_testing.py +11 -11
  120. oscura/visualization/__init__.py +47 -284
  121. oscura/visualization/batch.py +160 -0
  122. oscura/visualization/plot.py +542 -53
  123. oscura/visualization/styles.py +184 -318
  124. oscura/workflows/__init__.py +2 -0
  125. oscura/workflows/batch/advanced.py +1 -1
  126. oscura/workflows/batch/aggregate.py +7 -8
  127. oscura/workflows/complete_re.py +251 -23
  128. oscura/workflows/digital.py +27 -4
  129. oscura/workflows/multi_trace.py +136 -17
  130. oscura/workflows/waveform.py +788 -0
  131. {oscura-0.7.0.dist-info → oscura-0.10.0.dist-info}/METADATA +59 -79
  132. {oscura-0.7.0.dist-info → oscura-0.10.0.dist-info}/RECORD +135 -149
  133. oscura/side_channel/dpa.py +0 -1025
  134. oscura/utils/optimization/__init__.py +0 -19
  135. oscura/utils/optimization/parallel.py +0 -443
  136. oscura/utils/optimization/search.py +0 -532
  137. oscura/utils/pipeline/base.py +0 -338
  138. oscura/utils/pipeline/composition.py +0 -248
  139. oscura/utils/pipeline/parallel.py +0 -449
  140. oscura/utils/pipeline/pipeline.py +0 -375
  141. oscura/utils/search/__init__.py +0 -16
  142. oscura/utils/search/anomaly.py +0 -424
  143. oscura/utils/search/context.py +0 -294
  144. oscura/utils/search/pattern.py +0 -288
  145. oscura/utils/storage/__init__.py +0 -61
  146. oscura/utils/storage/database.py +0 -1166
  147. oscura/visualization/accessibility.py +0 -526
  148. oscura/visualization/annotations.py +0 -371
  149. oscura/visualization/axis_scaling.py +0 -305
  150. oscura/visualization/colors.py +0 -451
  151. oscura/visualization/digital.py +0 -436
  152. oscura/visualization/eye.py +0 -571
  153. oscura/visualization/histogram.py +0 -281
  154. oscura/visualization/interactive.py +0 -1035
  155. oscura/visualization/jitter.py +0 -1042
  156. oscura/visualization/keyboard.py +0 -394
  157. oscura/visualization/layout.py +0 -400
  158. oscura/visualization/optimization.py +0 -1079
  159. oscura/visualization/palettes.py +0 -446
  160. oscura/visualization/power.py +0 -508
  161. oscura/visualization/power_extended.py +0 -955
  162. oscura/visualization/presets.py +0 -469
  163. oscura/visualization/protocols.py +0 -1246
  164. oscura/visualization/render.py +0 -223
  165. oscura/visualization/rendering.py +0 -444
  166. oscura/visualization/reverse_engineering.py +0 -838
  167. oscura/visualization/signal_integrity.py +0 -989
  168. oscura/visualization/specialized.py +0 -643
  169. oscura/visualization/spectral.py +0 -1226
  170. oscura/visualization/thumbnails.py +0 -340
  171. oscura/visualization/time_axis.py +0 -351
  172. oscura/visualization/waveform.py +0 -454
  173. {oscura-0.7.0.dist-info → oscura-0.10.0.dist-info}/WHEEL +0 -0
  174. {oscura-0.7.0.dist-info → oscura-0.10.0.dist-info}/entry_points.txt +0 -0
  175. {oscura-0.7.0.dist-info → oscura-0.10.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,532 +0,0 @@
1
- """Parameter optimization via grid search and random search.
2
-
3
- This module provides tools for finding optimal analysis parameters through
4
- systematic or random search of the parameter space.
5
- """
6
-
7
- from __future__ import annotations
8
-
9
- import itertools
10
- from collections.abc import Callable
11
- from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
12
- from dataclasses import dataclass
13
- from typing import TYPE_CHECKING, Any, Literal
14
-
15
- import numpy as np
16
- import pandas as pd
17
-
18
- from oscura.analyzers.waveform.spectral import thd as compute_thd
19
- from oscura.core.exceptions import AnalysisError
20
-
21
- if TYPE_CHECKING:
22
- from numpy.typing import NDArray
23
-
24
- from oscura.core.types import WaveformTrace
25
-
26
- ScoringFunction = Callable[[WaveformTrace, dict[str, Any]], float]
27
- else:
28
- ScoringFunction = Callable
29
-
30
-
31
@dataclass
class SearchResult:
    """Outcome of a parameter search run.

    Attributes:
        best_params: Parameter combination whose mean score was highest.
        best_score: Mean score achieved by ``best_params``.
        all_results: Table with one row per evaluated combination. It holds
            the parameter values plus ``mean_score`` and ``std_score``, and
            ``cv_<i>`` fold columns when folds were used.
        cv_scores: Per-fold scores of the best combination, or ``None`` when
            no folds were requested.

    Example:
        >>> result = search.fit(traces, apply_filter)
        >>> print(f"Best params: {result.best_params}")
        >>> print(f"Best score: {result.best_score}")

    References:
        API-014: Parameter Grid Search
    """

    # Winning parameter combination.
    best_params: dict[str, Any]
    # Mean score of the winning combination.
    best_score: float
    # Every evaluated combination together with its scores.
    all_results: pd.DataFrame
    # Fold-by-fold scores of the winner; stays None without folds.
    cv_scores: NDArray[np.float64] | None = None
54
-
55
-
56
def _default_snr_scorer(trace: WaveformTrace, params: dict[str, Any]) -> float:
    """Default SNR scoring function.

    Crude heuristic: the first half of the trace is treated as signal and the
    mean-removed second half as noise.

    Args:
        trace: Waveform trace to analyze.
        params: Parameters to apply (not used in basic SNR).

    Returns:
        Signal-to-noise ratio in dB. Returns ``inf`` when the noise half has
        zero variance. Returns ``-inf`` when the signal half is all zeros.
        Returns NaN when the trace has fewer than two samples, because
        neither half can be formed.
    """
    # Promote to float64 so squaring integer samples (e.g. int16 ADC codes)
    # cannot overflow and wrap around.
    data = np.asarray(trace.data, dtype=np.float64)
    if len(data) < 2:
        # A 1-sample trace used to score +inf (empty signal half, zero-variance
        # noise half) and would win every search; report "undefined" instead.
        return float("nan")

    mid = len(data) // 2
    signal_half = data[:mid]
    noise_half = data[mid:]

    signal_power = np.mean(signal_half**2)
    noise_power = np.mean((noise_half - np.mean(noise_half)) ** 2)

    if noise_power == 0:
        return float("inf")

    return float(10 * np.log10(signal_power / noise_power))
78
-
79
-
80
def _default_thd_scorer(trace: WaveformTrace, params: dict[str, Any]) -> float:
    """Score a trace by its harmonic distortion, where less distortion is better.

    The parameter searches always pick the highest score. The value reported
    by ``compute_thd`` is therefore negated, so a cleaner trace yields a larger
    score.

    Args:
        trace: Waveform trace whose distortion is measured.
        params: Candidate parameters; this scorer ignores them.

    Returns:
        ``-compute_thd(trace)`` as a Python float, in whatever units
        ``compute_thd`` reports.
    """
    distortion = compute_thd(trace)
    score = -distortion
    return float(score)
93
-
94
-
95
- class GridSearchCV:
96
- """Grid search over parameter space with optional cross-validation.
97
-
98
- Systematically evaluates all combinations of parameters to find the
99
- optimal configuration.
100
-
101
- Example:
102
- >>> from oscura.utils.optimization.search import GridSearchCV
103
- >>> param_grid = {
104
- ... 'cutoff': [1e5, 5e5, 1e6],
105
- ... 'order': [2, 4, 6]
106
- ... }
107
- >>> search = GridSearchCV(
108
- ... param_grid=param_grid,
109
- ... scoring='snr',
110
- ... cv=3
111
- ... )
112
- >>> result = search.fit(traces, apply_filter)
113
- >>> print(result.best_params)
114
-
115
- References:
116
- API-014: Parameter Grid Search
117
- """
118
-
119
- def __init__(
120
- self,
121
- param_grid: dict[str, list[Any]],
122
- scoring: Literal["snr", "thd"] | ScoringFunction = "snr",
123
- cv: int | None = None,
124
- *,
125
- parallel: bool = True,
126
- max_workers: int | None = None,
127
- use_threads: bool = True,
128
- ) -> None:
129
- """Initialize grid search.
130
-
131
- Args:
132
- param_grid: Dictionary mapping parameter names to lists of values.
133
- scoring: Scoring function. Built-in: 'snr', 'thd', or custom callable.
134
- cv: Number of cross-validation folds. None for no CV.
135
- parallel: Enable parallel evaluation.
136
- max_workers: Maximum parallel workers.
137
- use_threads: Use threads instead of processes.
138
-
139
- Raises:
140
- AnalysisError: If scoring function is invalid.
141
-
142
- Example:
143
- >>> param_grid = {'cutoff': [1e6, 2e6], 'order': [4, 6]}
144
- >>> search = GridSearchCV(param_grid, scoring='snr', cv=3)
145
- """
146
- self.param_grid = param_grid
147
- self.cv = cv
148
- self.parallel = parallel
149
- self.max_workers = max_workers
150
- self.use_threads = use_threads
151
-
152
- # Set scoring function
153
- if scoring == "snr":
154
- self.scoring_fn = _default_snr_scorer
155
- elif scoring == "thd":
156
- self.scoring_fn = _default_thd_scorer
157
- elif callable(scoring):
158
- self.scoring_fn = scoring # type: ignore[assignment]
159
- else:
160
- raise AnalysisError(f"Unknown scoring function: {scoring}")
161
-
162
- self.best_params_: dict[str, Any] | None = None
163
- self.best_score_: float | None = None
164
- self.results_df_: pd.DataFrame | None = None
165
-
166
- def fit(
167
- self,
168
- traces: list[WaveformTrace] | WaveformTrace,
169
- transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
170
- ) -> SearchResult:
171
- """Fit grid search on traces.
172
-
173
- Evaluates all parameter combinations and finds the best.
174
-
175
- Args:
176
- traces: Trace or list of traces to evaluate on.
177
- transform_fn: Function that applies parameters to trace.
178
- Should accept (trace, **params) and return transformed trace.
179
-
180
- Returns:
181
- SearchResult with best parameters and all results.
182
-
183
- Example:
184
- >>> def apply_filter(trace, cutoff, order):
185
- ... return lowpass_filter(trace, cutoff=cutoff, order=order)
186
- >>> result = search.fit(traces, apply_filter)
187
-
188
- References:
189
- API-014: Parameter Grid Search
190
- """
191
- # Convert single trace to list
192
- if not isinstance(traces, list):
193
- traces = [traces]
194
-
195
- # Generate all parameter combinations
196
- param_combinations = self._generate_combinations()
197
-
198
- # Evaluate each combination
199
- results = self._evaluate_combinations(param_combinations, traces, transform_fn)
200
-
201
- # Convert to DataFrame
202
- self.results_df_ = pd.DataFrame(results)
203
-
204
- # Find best
205
- best_idx = self.results_df_["mean_score"].idxmax()
206
- best_row = self.results_df_.iloc[best_idx]
207
-
208
- self.best_params_ = {k: best_row[k] for k in self.param_grid}
209
- self.best_score_ = float(best_row["mean_score"])
210
-
211
- # Collect CV scores if available
212
- cv_scores = None
213
- if self.cv:
214
- cv_cols = [c for c in self.results_df_.columns if c.startswith("cv_")]
215
- if cv_cols:
216
- cv_scores = self.results_df_.loc[best_idx, cv_cols].values
217
-
218
- return SearchResult(
219
- best_params=self.best_params_,
220
- best_score=self.best_score_,
221
- all_results=self.results_df_,
222
- cv_scores=cv_scores,
223
- )
224
-
225
- def _generate_combinations(self) -> list[dict[str, Any]]:
226
- """Generate all parameter combinations from grid.
227
-
228
- Returns:
229
- List of parameter dictionaries.
230
- """
231
- keys = list(self.param_grid.keys())
232
- values = [self.param_grid[k] for k in keys]
233
-
234
- combinations = []
235
- for combo in itertools.product(*values):
236
- combinations.append(dict(zip(keys, combo, strict=False)))
237
-
238
- return combinations
239
-
240
- def _evaluate_combinations(
241
- self,
242
- param_combinations: list[dict[str, Any]],
243
- traces: list[WaveformTrace],
244
- transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
245
- ) -> list[dict[str, Any]]:
246
- """Evaluate all parameter combinations.
247
-
248
- Args:
249
- param_combinations: List of parameter dicts to evaluate.
250
- traces: Traces to evaluate on.
251
- transform_fn: Transformation function.
252
-
253
- Returns:
254
- List of result dictionaries.
255
- """
256
- if self.parallel:
257
- return self._evaluate_parallel(param_combinations, traces, transform_fn)
258
- else:
259
- return self._evaluate_sequential(param_combinations, traces, transform_fn)
260
-
261
- def _evaluate_one(
262
- self,
263
- params: dict[str, Any],
264
- traces: list[WaveformTrace],
265
- transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
266
- ) -> dict[str, Any]:
267
- """Evaluate one parameter combination.
268
-
269
- Args:
270
- params: Parameters to evaluate.
271
- traces: Traces to evaluate on.
272
- transform_fn: Transformation function.
273
-
274
- Returns:
275
- Result dictionary with scores.
276
- """
277
- scores: list[float] = []
278
-
279
- if self.cv:
280
- # Cross-validation - split traces into folds
281
- fold_size = len(traces) // self.cv
282
- for i in range(self.cv):
283
- # Select fold
284
- start = i * fold_size
285
- end = start + fold_size if i < self.cv - 1 else len(traces)
286
- fold_traces = traces[start:end]
287
-
288
- # Evaluate on fold
289
- fold_scores = []
290
- for trace in fold_traces:
291
- transformed = transform_fn(trace, **params) # type: ignore[call-arg]
292
- score = self.scoring_fn(transformed, params)
293
- fold_scores.append(score)
294
-
295
- scores.append(float(np.mean(fold_scores)))
296
-
297
- else:
298
- # No CV - evaluate on all traces
299
- for trace in traces:
300
- transformed = transform_fn(trace, **params) # type: ignore[call-arg]
301
- score = self.scoring_fn(transformed, params)
302
- scores.append(score)
303
-
304
- # Build result
305
- result = params.copy()
306
- result["mean_score"] = float(np.mean(scores))
307
- result["std_score"] = float(np.std(scores))
308
-
309
- if self.cv:
310
- for i, score in enumerate(scores):
311
- result[f"cv_{i}"] = float(score)
312
-
313
- return result
314
-
315
- def _evaluate_sequential(
316
- self,
317
- param_combinations: list[dict[str, Any]],
318
- traces: list[WaveformTrace],
319
- transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
320
- ) -> list[dict[str, Any]]:
321
- """Evaluate combinations sequentially.
322
-
323
- Args:
324
- param_combinations: Parameter combinations.
325
- traces: Traces to evaluate on.
326
- transform_fn: Transformation function.
327
-
328
- Returns:
329
- List of results.
330
- """
331
- results = []
332
- for params in param_combinations:
333
- result = self._evaluate_one(params, traces, transform_fn)
334
- results.append(result)
335
- return results
336
-
337
- def _evaluate_parallel(
338
- self,
339
- param_combinations: list[dict[str, Any]],
340
- traces: list[WaveformTrace],
341
- transform_fn: Callable[[WaveformTrace, dict[str, Any]], WaveformTrace],
342
- ) -> list[dict[str, Any]]:
343
- """Evaluate combinations in parallel.
344
-
345
- Args:
346
- param_combinations: Parameter combinations.
347
- traces: Traces to evaluate on.
348
- transform_fn: Transformation function.
349
-
350
- Returns:
351
- List of results.
352
- """
353
- executor_class = ThreadPoolExecutor if self.use_threads else ProcessPoolExecutor
354
-
355
- with executor_class(max_workers=self.max_workers) as executor:
356
- futures = {
357
- executor.submit(self._evaluate_one, params, traces, transform_fn): params
358
- for params in param_combinations
359
- }
360
-
361
- results = []
362
- for future in as_completed(futures):
363
- result = future.result()
364
- results.append(result)
365
-
366
- return results
367
-
368
-
369
class RandomizedSearchCV:
    """Random search over parameter distributions.

    Samples random combinations from parameter distributions rather than
    exhaustively evaluating all combinations. Evaluation reuses
    ``GridSearchCV``'s machinery, so scoring, fold splitting and parallel
    options behave the same way.

    Example:
        >>> from oscura.utils.optimization.search import RandomizedSearchCV
        >>> import numpy as np
        >>> param_distributions = {
        ...     'cutoff': lambda: np.random.uniform(1e5, 1e7),
        ...     'order': lambda: np.random.choice([2, 4, 6, 8])
        ... }
        >>> search = RandomizedSearchCV(
        ...     param_distributions=param_distributions,
        ...     n_iter=20,
        ...     scoring='snr'
        ... )
        >>> result = search.fit(traces, apply_filter)

    References:
        API-014: Parameter Grid Search
    """

    def __init__(
        self,
        param_distributions: dict[str, Callable[[], Any]],
        n_iter: int = 10,
        scoring: Literal["snr", "thd"] | ScoringFunction = "snr",
        cv: int | None = None,
        *,
        parallel: bool = True,
        max_workers: int | None = None,
        use_threads: bool = True,
        random_state: int | None = None,
    ) -> None:
        """Initialize randomized search.

        Args:
            param_distributions: Dict mapping parameter names to zero-argument
                sampling functions.
            n_iter: Number of parameter combinations to sample.
            scoring: Scoring function ('snr', 'thd' or a custom callable).
            cv: Number of folds the trace list is split into (see
                ``GridSearchCV``).
            parallel: Enable parallel evaluation.
            max_workers: Maximum parallel workers.
            use_threads: Use threads instead of processes.
            random_state: Seed passed to ``np.random.seed`` right before each
                round of sampling. Repeated ``fit`` calls therefore draw the
                same combinations, regardless of RNG use in between. Only
                samplers that draw from NumPy's global RNG are affected.

        Raises:
            AnalysisError: If scoring function is invalid.

        Example:
            >>> param_dist = {'cutoff': lambda: np.random.uniform(1e5, 1e7)}
            >>> search = RandomizedSearchCV(param_dist, n_iter=50)
        """
        self.param_distributions = param_distributions
        self.n_iter = n_iter
        self.cv = cv
        self.parallel = parallel
        self.max_workers = max_workers
        self.use_threads = use_threads
        # Seeding is deferred to _sample_combinations; seeding here would be
        # undone by any global-RNG use between construction and fit().
        self.random_state = random_state

        # Set scoring function
        if scoring == "snr":
            self.scoring_fn = _default_snr_scorer
        elif scoring == "thd":
            self.scoring_fn = _default_thd_scorer
        elif callable(scoring):
            self.scoring_fn = scoring  # type: ignore[assignment]
        else:
            raise AnalysisError(f"Unknown scoring function: {scoring}")

        self.best_params_: dict[str, Any] | None = None
        self.best_score_: float | None = None
        self.results_df_: pd.DataFrame | None = None

    def fit(
        self,
        traces: list[WaveformTrace] | WaveformTrace,
        transform_fn: Callable[..., WaveformTrace],
    ) -> SearchResult:
        """Fit randomized search on traces.

        Args:
            traces: Trace or list of traces to evaluate on.
            transform_fn: Function that applies parameters to a trace. It is
                called as ``transform_fn(trace, **params)``.

        Returns:
            SearchResult with best parameters.

        Raises:
            AnalysisError: If ``n_iter < 1``, no traces are given, ``cv`` does
                not fit the number of traces, or every sample scores NaN.

        Example:
            >>> result = search.fit(traces, apply_filter)
            >>> print(f"Best cutoff: {result.best_params['cutoff']:.2e}")

        References:
            API-014: Parameter Grid Search
        """
        # Convert single trace to list
        if not isinstance(traces, list):
            traces = [traces]

        # Validate before sampling so no RNG draws are spent on bad input.
        if self.n_iter < 1:
            raise AnalysisError(f"n_iter must be at least 1, got {self.n_iter}")
        if not traces:
            raise AnalysisError("At least one trace is required for parameter search")
        if self.cv and not 1 <= self.cv <= len(traces):
            raise AnalysisError(
                f"cv must be between 1 and the number of traces ({len(traces)}), got {self.cv}"
            )

        param_combinations = self._sample_combinations()

        # Reuse grid search evaluation logic
        grid_search = GridSearchCV(
            param_grid={},  # Not used
            scoring=self.scoring_fn,
            cv=self.cv,
            parallel=self.parallel,
            max_workers=self.max_workers,
            use_threads=self.use_threads,
        )
        results = grid_search._evaluate_combinations(param_combinations, traces, transform_fn)

        results_df, best_params, best_score, cv_scores = self._summarize(
            results, list(self.param_distributions)
        )

        return SearchResult(
            best_params=best_params,
            best_score=best_score,
            all_results=results_df,
            cv_scores=cv_scores,
        )

    def _summarize(
        self,
        results: list[dict[str, Any]],
        param_names: list[str],
    ) -> tuple[pd.DataFrame, dict[str, Any], float, NDArray[np.float64] | None]:
        """Tabulate evaluation results and select the best combination.

        Stores ``results_df_``, ``best_params_`` and ``best_score_`` on the
        instance. Returns them together with the best row's per-fold scores.

        Best parameters are taken from the original result dicts, not from a
        DataFrame row. Row selection upcasts mixed int/float columns to float,
        which would turn an integer parameter such as ``order=4`` into ``4.0``.

        Args:
            results: Result dicts (parameters plus ``mean_score``,
                ``std_score`` and, with folds, ``cv_<i>`` entries).
            param_names: Keys of ``results`` that are search parameters.

        Returns:
            Tuple ``(results_df, best_params, best_score, cv_scores)``.

        Raises:
            AnalysisError: If ``results`` is empty or every mean score is NaN.
        """
        if not results:
            raise AnalysisError("No parameter combinations were evaluated")

        results_df = pd.DataFrame(results)
        mean_scores = results_df["mean_score"].to_numpy(dtype=np.float64)
        if np.isnan(mean_scores).all():
            raise AnalysisError("All parameter combinations produced NaN scores")

        # First maximum with NaNs skipped: the same rule as Series.idxmax().
        # DataFrame(results) has a RangeIndex, so position == label.
        best_idx = int(np.nanargmax(mean_scores))
        best_result = results[best_idx]

        best_params = {name: best_result[name] for name in param_names}
        best_score = float(best_result["mean_score"])

        cv_scores = None
        if self.cv:
            # Exact fold column names; a "cv_" prefix match would also grab
            # parameters whose names happen to start with "cv_".
            fold_cols = [f"cv_{i}" for i in range(self.cv)]
            cv_scores = results_df.loc[best_idx, fold_cols].to_numpy(dtype=np.float64)

        self.results_df_ = results_df
        self.best_params_ = best_params
        self.best_score_ = best_score
        return results_df, best_params, best_score, cv_scores

    def _sample_combinations(self) -> list[dict[str, Any]]:
        """Sample random parameter combinations.

        Seeds NumPy's global RNG with ``random_state`` (when set) first, so
        each call yields the same sequence of combinations.

        Returns:
            List of ``n_iter`` sampled parameter dictionaries.
        """
        if self.random_state is not None:
            np.random.seed(self.random_state)

        return [
            {key: sampler() for key, sampler in self.param_distributions.items()}
            for _ in range(self.n_iter)
        ]
525
-
526
-
527
# Public API of this module (names exported by ``import *``).
__all__ = ["GridSearchCV", "RandomizedSearchCV", "ScoringFunction", "SearchResult"]