sqil-core 0.1.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqil_core/__init__.py +1 -0
- sqil_core/config_log.py +42 -0
- sqil_core/experiment/__init__.py +11 -0
- sqil_core/experiment/_analysis.py +125 -0
- sqil_core/experiment/_events.py +25 -0
- sqil_core/experiment/_experiment.py +553 -0
- sqil_core/experiment/data/plottr.py +778 -0
- sqil_core/experiment/helpers/_function_override_handler.py +111 -0
- sqil_core/experiment/helpers/_labone_wrappers.py +12 -0
- sqil_core/experiment/instruments/__init__.py +2 -0
- sqil_core/experiment/instruments/_instrument.py +190 -0
- sqil_core/experiment/instruments/drivers/SignalCore_SC5511A.py +515 -0
- sqil_core/experiment/instruments/local_oscillator.py +205 -0
- sqil_core/experiment/instruments/server.py +175 -0
- sqil_core/experiment/instruments/setup.yaml +21 -0
- sqil_core/experiment/instruments/zurich_instruments.py +55 -0
- sqil_core/fit/__init__.py +23 -0
- sqil_core/fit/_core.py +179 -31
- sqil_core/fit/_fit.py +544 -94
- sqil_core/fit/_guess.py +304 -0
- sqil_core/fit/_models.py +50 -1
- sqil_core/fit/_quality.py +266 -0
- sqil_core/resonator/__init__.py +2 -0
- sqil_core/resonator/_resonator.py +256 -74
- sqil_core/utils/__init__.py +40 -13
- sqil_core/utils/_analysis.py +226 -0
- sqil_core/utils/_const.py +83 -18
- sqil_core/utils/_formatter.py +127 -55
- sqil_core/utils/_plot.py +272 -6
- sqil_core/utils/_read.py +178 -95
- sqil_core/utils/_utils.py +147 -0
- {sqil_core-0.1.0.dist-info → sqil_core-1.1.0.dist-info}/METADATA +9 -1
- sqil_core-1.1.0.dist-info/RECORD +36 -0
- {sqil_core-0.1.0.dist-info → sqil_core-1.1.0.dist-info}/WHEEL +1 -1
- sqil_core-0.1.0.dist-info/RECORD +0 -19
- {sqil_core-0.1.0.dist-info → sqil_core-1.1.0.dist-info}/entry_points.txt +0 -0
sqil_core/fit/_guess.py
ADDED
@@ -0,0 +1,304 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from numpy.fft import rfft, rfftfreq
|
3
|
+
from scipy.signal import hilbert
|
4
|
+
|
5
|
+
from sqil_core.utils import compute_fft, get_peaks
|
6
|
+
|
7
|
+
|
8
|
+
def estimate_peak(
    x_data: np.ndarray, y_data: np.ndarray
) -> tuple[float, float, float, float, bool]:
    """
    Estimates the key properties of a peak or dip in 1D data.

    The dominant feature is classified as a peak when the maximum lies at least
    as far from the median as the minimum does, otherwise as a dip. The
    following quantities are then estimated:
    - The position of the peak/dip (x0)
    - The full width at half maximum (FWHM)
    - The peak/dip height
    - The baseline value (y0)
    - A flag indicating if it is a peak (True) or a dip (False)

    Parameters
    ----------
    x_data : np.ndarray
        Array of x-values (any array-like is accepted).
    y_data : np.ndarray
        Array of y-values corresponding to `x_data`.

    Returns
    -------
    x0 : float
        The x-position of the maximum (peak) or minimum (dip) of `y_data`.
    fwhm : float
        Estimated full width at half maximum. Always non-negative.
    peak_height : float
        Height (or depth) of the peak or dip relative to the baseline.
    y0 : float
        Baseline level: the minimum of `y_data` for a peak, the maximum for a dip.
    is_peak : bool
        True if the feature is a peak; False if it is a dip.

    Notes
    -----
    - The median of `y_data` is only used to decide between peak and dip.
    - FWHM is the distance between the first and the last sample at which the
      signal crosses the half-max level.
    - If fewer than two crossings are found, the FWHM falls back to 1/10th of
      the absolute x-range, so it stays non-negative for descending `x_data`.
    """
    x = np.asarray(x_data)
    y = np.asarray(y_data)

    y_median = np.median(y)
    y_max, y_min = np.max(y), np.min(y)

    # Peak if the maximum deviates from the median at least as much as the minimum
    is_peak = bool(y_max - y_median >= y_median - y_min)
    if is_peak:
        idx = np.argmax(y)
        y0 = y_min
        peak_height = y_max - y0
    else:
        idx = np.argmin(y)
        y0 = y_max
        peak_height = y0 - y_min

    x0 = x[idx]

    # Half-max level sits halfway between the baseline and the extremum
    half_max = y0 + (peak_height / 2.0 if is_peak else -peak_height / 2.0)
    # Indices where (y - half_max) changes sign between consecutive samples
    crossings = np.where(np.diff(np.sign(y - half_max)))[0]
    if len(crossings) >= 2:
        fwhm = np.abs(x[crossings[-1]] - x[crossings[0]])
    else:
        # abs() keeps the fallback width positive when x is sorted descending
        fwhm = np.abs(x[-1] - x[0]) / 10.0

    return x0, fwhm, peak_height, y0, is_peak
|
78
|
+
|
79
|
+
|
80
|
+
def lorentzian_guess(x_data, y_data):
    """
    Guess lorentzian fit parameters.

    Returns
    -------
    list
        `[A, x0, fwhm, y0]`, built from the output of `estimate_peak`.
        `A` is negative when the feature is a dip.
    """
    center, width, height, baseline, peak_like = estimate_peak(x_data, y_data)

    # The height at the center equals 2A / FWHM, hence A = height * FWHM / 2
    half_area = (height * width) / 2.0
    amplitude = half_area if peak_like else -half_area

    return [amplitude, center, width, baseline]
|
91
|
+
|
92
|
+
|
93
|
+
def lorentzian_bounds(x_data, y_data, guess):
    """
    Guess lorentzian fit bounds.

    Parameters
    ----------
    x_data, y_data : array-like
        Data being fitted.
    guess : sequence
        Initial guess; only its first element (the amplitude A) is used.

    Returns
    -------
    tuple of list
        `(lower, upper)` bounds ordered as `[A, x0, fwhm, y0]`.
    """
    amplitude, *_rest = guess

    x_low, x_high = np.min(x_data), np.max(x_data)
    span = x_high - x_low

    # A zero amplitude guess would collapse the bounds, so fall back to 1.0
    scale = np.abs(amplitude) if amplitude != 0 else 1.0

    # Smallest allowed width: the first sample spacing when available
    if len(x_data) > 1:
        min_width = x_data[1] - x_data[0]
    else:
        min_width = span / 10

    lower = [
        -10 * scale,
        x_low - 0.1 * span,
        min_width,
        np.min(y_data) - 0.5 * scale,
    ]
    upper = [
        +10 * scale,
        x_high + 0.1 * span,
        span,
        np.max(y_data) + 0.5 * scale,
    ]
    return (lower, upper)
|
107
|
+
|
108
|
+
|
109
|
+
def gaussian_guess(x_data, y_data):
    """
    Guess gaussian fit parameters.

    Returns
    -------
    list
        `[A, x0, sigma, y0]`, built from the output of `estimate_peak`.
        `A` is negative when the feature is a dip.
    """
    center, width, height, baseline, peak_like = estimate_peak(x_data, y_data)

    # FWHM = 2 * sqrt(2 * ln 2) * sigma
    sigma = width / (2 * np.sqrt(2 * np.log(2)))

    # Amplitude such that A / (sigma * sqrt(2 * pi)) matches the estimated height
    magnitude = height * sigma * np.sqrt(2 * np.pi)
    amplitude = magnitude if peak_like else -magnitude

    return [amplitude, center, sigma, baseline]
|
121
|
+
|
122
|
+
|
123
|
+
def gaussian_bounds(x_data, y_data, guess):
    """
    Guess gaussian fit bounds.

    Parameters
    ----------
    x_data, y_data : array-like
        Data being fitted.
    guess : sequence
        Initial guess; only its first element (the amplitude A) is used.

    Returns
    -------
    tuple of list
        `(lower, upper)` bounds ordered as `[A, x0, sigma, y0]`.

    Notes
    -----
    When the amplitude guess is exactly zero a unit scale is used instead, in
    the same way as `lorentzian_bounds`. Without it the lower and upper
    amplitude bounds would both be zero, and the offset bounds would also
    coincide for constant `y_data`.
    """
    x, y = x_data, y_data
    A, *_ = guess

    x_span = np.max(x) - np.min(x)
    # Smallest sigma: a tenth of the first sample spacing when available
    sigma_min = (x[1] - x[0]) / 10 if len(x) > 1 else x_span / 100
    sigma_max = x_span
    # Guard against a zero amplitude guess collapsing the bounds
    A_abs = np.abs(A) if A != 0 else 1.0

    bounds = (
        [-10 * A_abs, np.min(x) - 0.1 * x_span, sigma_min, np.min(y) - 0.5 * A_abs],
        [10 * A_abs, np.max(x) + 0.1 * x_span, sigma_max, np.max(y) + 0.5 * A_abs],
    )
    return bounds
|
138
|
+
|
139
|
+
|
140
|
+
def oscillations_guess(x_data, y_data, num_init=10):
    """
    Generate robust initial guesses for oscillation parameters.

    Parameters
    ----------
    x_data, y_data : array-like
        Data being fitted. The FFT step uses the mean spacing of `x_data`.
    num_init : int, optional
        Number of phase candidates to generate. Default is 10.

    Returns
    -------
    list
        `[A, y0_candidates, phi_candidates, T]` where `A` and `T` are scalars,
        `y0_candidates` is a two-element list (tail median, global mean) and
        `phi_candidates` is an array of `num_init` values wrapped into one period.
    """
    xs = np.asarray(x_data)
    ys = np.asarray(y_data)
    step = np.mean(np.diff(xs))

    # Amplitude: half the 5th-95th percentile spread, robust against outliers
    high = np.percentile(ys, 95)
    low = np.percentile(ys, 5)
    amplitude = (high - low) / 2

    # Offset candidates: median of the trailing samples and the global mean
    tail_len = max(5, len(ys) // 10)
    mean_level = np.mean(ys)
    offsets = [np.median(ys[-tail_len:]), mean_level]

    # Period: inverse of the strongest non-DC component of the spectrum
    centered = ys - mean_level
    magnitude = np.abs(rfft(centered))
    frequencies = rfftfreq(len(xs), d=step)
    dominant = frequencies[np.argmax(magnitude[1:]) + 1]
    if dominant > 0:
        period = 1 / dominant
    else:
        # Fall back to the x-range
        period = np.ptp(xs)

    # Phase: lag maximizing the cross-correlation with a reference cosine
    reference = np.cos(2 * np.pi * xs / period)
    xcorr = np.correlate(centered, reference, mode="full")
    lag = np.argmax(xcorr) - len(xs) + 1
    phase_center = xs[0] + lag * step
    phases = np.mod(
        np.linspace(phase_center - period, phase_center + period, num_init), period
    )

    return [amplitude, offsets, phases, period]
|
170
|
+
|
171
|
+
|
172
|
+
def oscillations_bounds(x_data, y_data, guess):
    """
    Generate realistic bounds for oscillation parameters.

    Parameters
    ----------
    x_data, y_data : array-like
        Data being fitted. Only the range of `y_data` enters the bounds.
    guess : sequence
        `[A, y0, phi, T]`; only `A` and `T` are used.

    Returns
    -------
    tuple of list
        `(lower, upper)` bounds ordered as `[A, y0, phi, T]`.
    """
    x_data = np.asarray(x_data)
    values = np.asarray(y_data)

    amplitude, _offset, _phase, period = guess

    # Small offset subtracted from lower bounds so they don't collapse onto the upper ones
    tiny = 1e-12

    lower = [
        0.1 * amplitude - tiny,
        np.min(values) - tiny,
        0.0 - tiny,
        0.1 * period - tiny,
    ]
    upper = [
        10 * amplitude,
        np.max(values),
        period,  # the phase is wrapped within one period
        10 * period,
    ]
    return (lower, upper)
|
197
|
+
|
198
|
+
|
199
|
+
def decaying_oscillations_guess(x_data, y_data, num_init=10):
    """
    Generate robust initial guesses for decaying oscillation parameters.

    Parameters
    ----------
    x_data, y_data : array-like
        Data being fitted.
    num_init : int, optional
        Number of phase candidates, forwarded to `oscillations_guess`.

    Returns
    -------
    list
        `[A, tau, y0_candidates, phi_candidates, T]`. The oscillation entries come
        from `oscillations_guess`, with one extra offset candidate appended.
    """
    xs = np.asarray(x_data)
    ys = np.asarray(y_data)

    # Oscillation parameters
    amplitude, offsets, phases, period = oscillations_guess(xs, ys, num_init)

    # Decay time from a linear fit to the log of the Hilbert envelope
    try:
        centered = ys - np.mean(ys)
        log_envelope = np.log(np.clip(np.abs(hilbert(centered)), 1e-10, None))
        decay_rate = np.polyfit(xs, log_envelope, 1)[0]
        if decay_rate < 0:
            tau = -1 / decay_rate
        else:
            # Non-decaying envelope: fall back to the x-range
            tau = np.ptp(xs)
    except Exception:
        # Best effort: any failure in the envelope fit falls back to the x-range
        tau = np.ptp(xs)

    # Extra offset candidate: the smaller of the global minimum and the tail mean
    tail_len = max(3, int(0.1 * len(ys)))
    offsets.append(min(np.min(ys), np.mean(ys[-tail_len:])))

    return [amplitude, tau, offsets, phases, period]
|
225
|
+
|
226
|
+
|
227
|
+
def decaying_oscillations_bounds(x_data, y_data, guess):
    """
    Generate realistic bounds for decaying oscillation parameters.

    Parameters
    ----------
    x_data, y_data : array-like
        Data being fitted.
    guess : sequence
        `[A, tau, y0, phi, T]`.

    Returns
    -------
    tuple of list
        `(lower, upper)` bounds ordered as `[A, tau, y0, phi, T]`.
    """
    xs = np.asarray(x_data)
    ys = np.asarray(y_data)

    amplitude, tau, offset, phase, period = guess

    # Start from the plain oscillation bounds, then slot tau in at position 1
    lower, upper = oscillations_bounds(xs, ys, [amplitude, offset, phase, period])
    lower[1:1] = [0.01 * tau]
    upper[1:1] = [10 * tau]

    return (lower, upper)
|
241
|
+
|
242
|
+
|
243
|
+
def many_decaying_oscillations_guess(x_data, y_data, n):
    """
    Guess parameters for a sum of `n` decaying oscillations.

    Parameters
    ----------
    x_data, y_data : array-like
        Data being fitted.
    n : int
        Number of oscillation components to initialize.

    Returns
    -------
    list
        Flat list `[A_0, tau_0, phi_0, T_0, ..., A_{n-1}, tau_{n-1}, phi_{n-1},
        T_{n-1}, offset]`.

    Raises
    ------
    ValueError
        If fewer than `n` peaks are found in the spectrum.

    Notes
    -----
    The fourth entry of each component is taken directly from the peak
    positions returned by `get_peaks`.
    NOTE(review): these are presumably frequencies, not periods, despite the
    `T` name - confirm against `compute_fft` / `get_peaks`.
    """
    baseline = np.mean(y_data)

    freqs, fft_mag = compute_fft(x_data, y_data - baseline)
    peak_freqs, peak_mags = get_peaks(freqs, fft_mag)

    if len(peak_freqs) < n:
        raise ValueError(
            f"Not enough frequency peaks found to initialize {n} oscillations."
        )

    duration = x_data[-1] - x_data[0]

    params = []
    for k in range(n):
        params += [
            peak_mags[k],
            duration / (2 + k),  # tau shrinks as the component index grows
            0.0,  # phase guess, can be refined
            peak_freqs[k],
        ]

    params.append(baseline)
    return params
|
267
|
+
|
268
|
+
|
269
|
+
def decaying_exp_guess(x_data: np.ndarray, y_data: np.ndarray) -> list[float]:
    """
    Robust initial guess for decaying exponential even if the full decay isn't captured.

    Parameters
    ----------
    x_data, y_data : np.ndarray
        Data being fitted.

    Returns
    -------
    list
        `[A, tau, y0]`. `A` is at least 1e-12 and `tau` at least machine epsilon.
    """
    xs = np.asarray(x_data)
    ys = np.asarray(y_data)

    # Offset: the smaller of the global minimum and the mean of the last points
    tail_len = max(3, int(0.1 * len(ys)))
    offset = min(np.min(ys), np.mean(ys[-tail_len:]))

    # Amplitude from the first sample, kept strictly positive
    amplitude = np.clip(ys[0] - offset, 1e-12, None)

    # Prefer the distance from the maximum when it is larger
    top_gap = np.max(ys) - offset
    if np.abs(top_gap) > np.abs(amplitude):
        amplitude = top_gap

    # Default decay time: a third of the x-range
    tau = (xs[-1] - xs[0]) / 3

    # Log-linear fit on roughly the first 30% of the data
    head_len = max(5, int(0.3 * len(xs)))
    head = ys[:head_len] - offset
    usable = head > 0  # log() is only valid on positive values

    if np.count_nonzero(usable) > 1:
        rate = np.polyfit(xs[:head_len][usable], np.log(head[usable]), 1)[0]
        if rate < 0:
            tau = -1 / rate

    tau = max(tau, np.finfo(float).eps)

    return [amplitude, tau, offset]
|
sqil_core/fit/_models.py
CHANGED
@@ -10,6 +10,17 @@ def lorentzian(x, A, x0, fwhm, y0):
|
|
10
10
|
return A * (np.abs(fwhm) / 2.0) / ((x - x0) ** 2.0 + fwhm**2.0 / 4.0) + y0
|
11
11
|
|
12
12
|
|
13
|
+
def two_lorentzians_shared_x0(x_data_1, x_data_2, A1, fwhm1, y01, A2, fwhm2, y02, x0):
    r"""
    Concatenates two lorentzians with same x0.

    L_1(x) = A_1 * (|FWHM_1| / 2) / ((x - x0)^2 + (FWHM_1^2 / 4)) + y0_1
    L_2(x) = A_2 * (|FWHM_2| / 2) / ((x - x0)^2 + (FWHM_2^2 / 4)) + y0_2

    L_1 is evaluated on `x_data_1` and L_2 on `x_data_2`; the output is L_1
    followed by L_2 in a single array.
    """
    return np.concatenate(
        (
            lorentzian(x_data_1, A1, x0, fwhm1, y01),
            lorentzian(x_data_2, A2, x0, fwhm2, y02),
        )
    )
|
22
|
+
|
23
|
+
|
13
24
|
def gaussian(x, A, x0, sigma, y0):
|
14
25
|
r"""
|
15
26
|
G(x) = A / (|σ| * sqrt(2π)) * exp(- (x - x0)^2 / (2σ^2)) + y0
|
@@ -24,6 +35,17 @@ def gaussian(x, A, x0, sigma, y0):
|
|
24
35
|
)
|
25
36
|
|
26
37
|
|
38
|
+
def two_gaussians_shared_x0(x_data_1, x_data_2, A1, fwhm1, y01, A2, fwhm2, y02, x0):
    r"""
    Concatenates two gaussians with same x0.

    G_1(x) = A_1 / (|σ_1| * sqrt(2π)) * exp(- (x - x0)^2 / (2σ_1^2)) + y0_1
    G_2(x) = A_2 / (|σ_2| * sqrt(2π)) * exp(- (x - x0)^2 / (2σ_2^2)) + y0_2

    G_1 is evaluated on `x_data_1` and G_2 on `x_data_2`; the output is G_1
    followed by G_2 in a single array.

    Note that `fwhm1` and `fwhm2` are forwarded to `gaussian` in the position
    of its `sigma` argument, so they play the role of σ_1 and σ_2 above.
    """
    return np.concatenate(
        (
            gaussian(x_data_1, A1, x0, fwhm1, y01),
            gaussian(x_data_2, A2, x0, fwhm2, y02),
        )
    )
|
47
|
+
|
48
|
+
|
27
49
|
def decaying_exp(x, A, tau, y0):
|
28
50
|
r"""
|
29
51
|
f(x) = A * exp(-x / τ) + y0
|
@@ -52,10 +74,37 @@ def decaying_oscillations(x, A, tau, y0, phi, T):
|
|
52
74
|
return A * np.exp(-x / tau) * np.cos(2.0 * np.pi * (x - phi) / T) + y0
|
53
75
|
|
54
76
|
|
77
|
+
def many_decaying_oscillations(t, *params):
    r"""
    Sum of decaying oscillations plus a constant offset.

    f(t) = SUM_i A_i * exp(-t / τ_i) * cos(2π * T_i * t + φ_i) + y0

    $$f(t) = \sum_i A_i \cdot e^{-t/\tau_i} \cdot \cos\left(2\pi T_i t + \phi_i\right) + y_0$$

    Parameters
    ----------
    t : array-like
        Points at which the function is evaluated. Integer arrays and plain
        sequences are accepted.
    *params : float
        Flat parameter list `A_0, tau_0, phi_0, T_0, ..., y0`: four values per
        oscillation followed by the offset as the last value.

    Returns
    -------
    np.ndarray
        The function evaluated at `t`.

    Notes
    -----
    `T_i` multiplies `t` inside the cosine, so it acts as a frequency here,
    unlike the period `T` used by `oscillations` and `decaying_oscillations`.
    """
    t = np.asarray(t)
    n = (len(params) - 1) // 4  # Each oscillation has 4 params: A, tau, phi, T
    offset = params[-1]
    result = np.zeros_like(t)
    for i in range(n):
        A, tau, phi, T = params[4 * i : 4 * i + 4]
        # Out-of-place sum: an in-place += cannot store floats in an integer array
        result = result + A * np.exp(-t / tau) * np.cos(2 * np.pi * T * t + phi)
    return result + offset
|
93
|
+
|
94
|
+
|
95
|
+
def oscillations(x, A, y0, phi, T):
    r"""
    Cosine oscillation with amplitude A, offset y0, shift φ and period T.

    f(x) = A * cos(2π * (x - φ) / T) + y0

    $$f(x) = A \cos\left( 2\pi \frac{x - \phi}{T} \right) + y_0$$
    """
    angle = 2.0 * np.pi * (x - phi) / T
    return A * np.cos(angle) + y0
|
102
|
+
|
103
|
+
|
55
104
|
def skewed_lorentzian(
|
56
105
|
f: np.ndarray, A1: float, A2: float, A3: float, A4: float, fr: float, Q_tot: float
|
57
106
|
) -> np.ndarray:
|
58
|
-
"""
|
107
|
+
r"""
|
59
108
|
Computes the skewed Lorentzian function.
|
60
109
|
|
61
110
|
This function models asymmetric resonance peaks using a skewed Lorentzian
|
@@ -0,0 +1,266 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from enum import IntEnum
|
4
|
+
from typing import TYPE_CHECKING, Literal
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
from tabulate import tabulate
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from sqil_core.fit._core import FitResult
|
11
|
+
|
12
|
+
|
13
|
+
class FitQuality(IntEnum):
    """Ordered fit quality categories; a larger value means a better fit."""

    BAD = 0
    ACCEPTABLE = 1
    GOOD = 2
    GREAT = 3

    def __str__(self):
        # Render as the bare member name, e.g. "GOOD"
        return self._name_
|
21
|
+
|
22
|
+
|
23
|
+
# Quality lookup tables, one per supported metric. Each entry is an
# (upper threshold, quality) pair and entries are listed by ascending threshold.
# `evaluate_fit_quality` returns the quality of the first entry whose threshold
# is greater than or equal to the metric value.
FIT_QUALITY_THRESHOLDS = dict(
    nrmse=[
        (0.01, FitQuality.GREAT),
        (0.03, FitQuality.GOOD),
        (0.08, FitQuality.ACCEPTABLE),
        (np.inf, FitQuality.BAD),
    ],
    nmae=[
        (0.01, FitQuality.GREAT),
        (0.03, FitQuality.GOOD),
        (0.08, FitQuality.ACCEPTABLE),
        (np.inf, FitQuality.BAD),
    ],
    # Reduced chi-squared is best near 1 and degrades in both directions
    red_chi2=[
        (0.5, FitQuality.ACCEPTABLE),
        (0.9, FitQuality.GOOD),
        (1.1, FitQuality.GREAT),
        (2.0, FitQuality.GOOD),
        (5.0, FitQuality.ACCEPTABLE),
        (np.inf, FitQuality.BAD),
    ],
)
|
45
|
+
|
46
|
+
|
47
|
+
def evaluate_fit_quality(fit_metrics: dict, recipe: str = "nrmse") -> FitQuality:
    """
    Evaluates the quality category of a fit based on a specified metric recipe.

    The metric named by `recipe` is looked up in `fit_metrics` and compared
    against the threshold list stored under the same name in
    `FIT_QUALITY_THRESHOLDS`. The quality of the first threshold that is greater
    than or equal to the metric value is returned.

    Parameters
    ----------
    fit_metrics : dict
        Dictionary containing computed metrics from a fit. Must include the key
        specified by `recipe`.
    recipe : str, optional
        The name of the metric to evaluate quality against. Default is "nrmse".

    Returns
    -------
    FitQuality
        A qualitative classification of the fit (GREAT, GOOD, ACCEPTABLE, BAD).
        BAD is also returned when the value matches no threshold.

    Raises
    ------
    KeyError
        If `fit_metrics` has no value (or None) for `recipe`.
    NotImplementedError
        If `FIT_QUALITY_THRESHOLDS` has no entry for `recipe`.
    """
    metric = fit_metrics.get(recipe)
    if metric is None:
        raise KeyError(
            f"The metrics provided aren't sufficient to use recipe '{recipe}'"
        )

    ladder = FIT_QUALITY_THRESHOLDS.get(recipe)
    if ladder is None:
        raise NotImplementedError(
            f"No fit quality threshold available for '{recipe}'."
            " You can add them to 'FIT_QUALITY_THRESHOLDS'"
        )

    # First threshold not exceeded by the metric wins; BAD when none matches
    return next(
        (quality for limit, quality in ladder if metric <= limit),
        FitQuality.BAD,
    )
|
89
|
+
|
90
|
+
|
91
|
+
def get_best_fit(
    fit_res_a: FitResult,
    fit_res_b: FitResult,
    recipe: Literal["nrmse_aic"] = "nrmse_aic",
):
    """
    Selects the better fit result according to a specified selection recipe.

    This function is a dispatcher: it forwards the two fit results to the
    comparison strategy named by `recipe`.

    Supported recipes:
    - "nrmse_aic": uses NRMSE as primary metric and adjusts it with AIC if the
        NRMSE are in the same quality category.

    Parameters
    ----------
    fit_res_a : FitResult
        The first fit result object containing metrics and parameters.
    fit_res_b : FitResult
        The second fit result object containing metrics and parameters.
    recipe : Literal["nrmse_aic"], optional
        The name of the comparison strategy to use.

    Returns
    -------
    FitResult
        The selected fit result, based on the comparison strategy.

    Raises
    ------
    NotImplementedError
        If `recipe` is not a supported strategy.

    Examples
    --------
    >>> best_fit = get_best_fit(fit1, fit2)
    >>> print("Best-fit parameters:", best_fit.params)
    """
    if recipe != "nrmse_aic":
        raise NotImplementedError(f"Recipe {recipe} does not exist")
    return get_best_fit_nrmse_aic(fit_res_a, fit_res_b)
|
129
|
+
|
130
|
+
|
131
|
+
def get_best_fit_nrmse_aic(
    fit_res_a: FitResult, fit_res_b: FitResult, aic_rel_tol: float = 0.01
):
    """
    Selects the better fit result based on NRMSE quality and AIC with complexity penalty.

    The two results are first ranked by the quality category of their NRMSE. If
    the categories differ, the better one is selected. Otherwise their Akaike
    Information Criterion (AIC) values are compared: when they differ by less
    than a relative tolerance, the result with fewer parameters is preferred;
    when they differ by more, the lower AIC is preferred.

    Parameters
    ----------
    fit_res_a : FitResult
        The first FitResult object.
    fit_res_b : FitResult
        The second FitResult object.
    aic_rel_tol : float, optional
        The relative tolerance for AIC comparison. If the relative difference in AIC
        is below this threshold, models are considered equally good, and complexity
        (number of parameters) is used as a tiebreaker. Default is 0.01.

    Returns
    -------
    FitResult
        The preferred fit result based on NRMSE category, AIC, and model simplicity.

    Raises
    ------
    ValueError
        If the NRMSE categories are equal and either result has no "aic" metric.

    Notes
    -----
    - If models are equivalent in AIC and have the same number of parameters,
      the first result is returned for consistency.
    - The relative difference is normalized by the absolute value of the smaller
      AIC. If that is zero, the absolute difference is compared against
      `aic_rel_tol` instead.

    Examples
    --------
    >>> best_fit = get_best_fit_nrmse_aic(fit1, fit2)
    >>> print("Selected model parameters:", best_fit.params)
    """
    rank_a = evaluate_fit_quality(fit_res_a.metrics)
    rank_b = evaluate_fit_quality(fit_res_b.metrics)

    # Different NRMSE categories: the better category wins outright
    if rank_a > rank_b:
        return fit_res_a
    if rank_b > rank_a:
        return fit_res_b

    # Same category: use AIC to penalize fit complexity
    aic_a = fit_res_a.metrics.get("aic")
    aic_b = fit_res_b.metrics.get("aic")
    if aic_a is None or aic_b is None:
        raise ValueError("Missing AIC value in one of the fits")

    gap = abs(aic_a - aic_b)
    reference = abs(min(aic_a, aic_b))
    if reference != 0:
        relative_gap = gap / reference
    else:
        relative_gap = gap

    if relative_gap < aic_rel_tol:
        # Within tolerance: equivalent fits, so prefer the one with fewer
        # parameters and keep the first result on a tie
        count_a, count_b = len(fit_res_a.params), len(fit_res_b.params)
        return fit_res_b if count_b < count_a else fit_res_a

    # Outside tolerance: pick the one with lower AIC
    return fit_res_a if aic_a < aic_b else fit_res_b
|
197
|
+
|
198
|
+
|
199
|
+
def format_fit_metrics(fit_metrics: dict, keys: list[str] | None = None) -> str:
    """
    Formats and summarizes selected fit metrics with qualitative evaluations.

    This function generates a human-readable table that reports selected fit metrics
    (reduced χ², R², NRMSE, NMAE) alongside their numerical values and qualitative
    quality assessments. Quality categories are determined using `evaluate_fit_quality`.

    Parameters
    ----------
    fit_metrics : dict
        Dictionary of fit metrics to display. Should contain values for keys like
        "red_chi2", "r2", "nrmse", etc.
    keys : list of str, optional
        Subset of metric keys to include in the output. If None, all available keys
        in `fit_metrics` are considered.

    Returns
    -------
    str
        A plain-text table summarizing the selected metrics with their values and
        associated quality labels.

    Raises
    ------
    KeyError
        If a supported key listed in `keys` is missing from `fit_metrics`.

    Notes
    -----
    - Only "red_chi2", "r2", "nrmse" and "nmae" are reported; other keys are skipped.
    - Complex-valued R² metrics are skipped.
    - Keys are optionally renamed for output formatting (e.g., "red_chi2" → "reduced χ²").
    - A metric without thresholds in `FIT_QUALITY_THRESHOLDS` is reported with an
      empty quality label instead of raising.

    Examples
    --------
    With metrics ``{"red_chi2": 1.2, "r2": 0.97, "nrmse": 0.05}`` and thresholds
    defined only for "nrmse", "nmae" and "red_chi2", the table has three rows:
    ``reduced χ²  1.200e+00  GOOD``, ``R²  9.700e-01`` (no quality label) and
    ``nrmse  5.000e-02  ACCEPTABLE``. Column spacing is decided by `tabulate`.
    """
    # Display label for each supported metric key
    labels = {
        "red_chi2": "reduced χ²",
        "r2": "R²",
        "nrmse": "nrmse",
        "nmae": "nmae",
    }

    if keys is None:
        keys = list(fit_metrics) if fit_metrics else []

    table_data = []
    for key in keys:
        label = labels.get(key)
        if label is None:
            # Not a metric this table knows how to report
            continue

        value = fit_metrics[key]

        # Skip complex R² (covers numpy complex scalars too): it cannot be
        # formatted as a real number
        if key == "r2" and np.iscomplexobj(value):
            continue

        try:
            quality_label = str(evaluate_fit_quality(fit_metrics, key))
        except NotImplementedError:
            # No thresholds registered for this metric: report the value only
            quality_label = ""

        table_data.append([label, f"{value:.3e}", quality_label])

    return tabulate(table_data, tablefmt="plain")
|