reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,758 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Jupyter Widget Components for Reflectorch
|
|
3
|
+
|
|
4
|
+
This module contains reusable widget components that can be composed
|
|
5
|
+
to create different interfaces for reflectometry analysis.
|
|
6
|
+
|
|
7
|
+
Components:
|
|
8
|
+
- ParameterTable: Interactive parameter table with sliders and results
|
|
9
|
+
- PreprocessingControls: Data preprocessing options
|
|
10
|
+
- PredictionControls: Prediction and computation settings
|
|
11
|
+
- PlottingControls: Plotting and visualization options
|
|
12
|
+
- AdditionalParametersControls: Controls for additional parameters like Q resolution
- WidgetSettingsExtractor: Collects the current settings of all components into a dictionary
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
from typing import Optional, Dict, Any, List, Tuple
|
|
17
|
+
import ipywidgets as widgets
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ParameterTable:
    """
    Interactive parameter table with sliders and result displays.

    Every parameter gets one row containing a range slider, a min and a max
    text input (kept in sync with the slider), and three read-only cells for
    the predicted value, the polished value and its uncertainty.

    Features:
    - Structured table layout with aligned columns
    - Result cells refreshed by ``update_results`` after a prediction
    - Validation of the selected bounds: global prior limits, maximum allowed
      width (``max_deltas``) and a minimum separation between min and max
    """

    def __init__(self, param_labels: List[str], min_bounds: np.ndarray,
                 max_bounds: np.ndarray, max_deltas: np.ndarray,
                 initial_bounds: Optional[np.ndarray] = None,
                 additional_params_controls: Optional['AdditionalParametersControls'] = None,
                 predict_button: Optional['widgets.Button'] = None):
        """
        Initialize parameter table

        Args:
            param_labels: List of parameter names
            min_bounds: Minimum values for each parameter
            max_bounds: Maximum values for each parameter
            max_deltas: Maximum allowed range (max - min) for each parameter
            initial_bounds: Initial bounds, shape (n_params, 2). Each pair is
                validated like a user edit; parameters without a row fall back
                to the default ``[min_bound, min(min_bound + max_delta, max_bound)]``.
            additional_params_controls: Optional additional parameters component
            predict_button: Optional predict button to include in the table header
        """
        self.param_labels = param_labels
        self.min_bounds = min_bounds
        self.max_bounds = max_bounds
        self.max_deltas = max_deltas
        self.sliders = []  # Range sliders, one per parameter
        self.min_inputs = []  # Lower-bound text inputs, one per parameter
        self.max_inputs = []  # Upper-bound text inputs, one per parameter
        # param index -> {'predicted' | 'polished' | 'uncertainty': HTML widget}
        self.result_displays = {}
        self.additional_params_controls = additional_params_controls
        self.predict_button = predict_button

        self.widget = self._create_table(initial_bounds)

    @staticmethod
    def _clamp_range(min_val: float, max_val: float,
                     global_min: float, global_max: float, max_width: float,
                     source: str = 'slider',
                     min_step: float = 0.01) -> Tuple[float, float]:
        """
        Apply all bound constraints and return a valid ``(min, max)`` pair.

        Args:
            min_val: Requested lower bound
            max_val: Requested upper bound
            global_min: Smallest value the lower bound may take
            global_max: Largest value the upper bound may take
            max_width: Largest allowed ``max - min``
            source: Which widget triggered the change ('slider', 'min' or
                'max'). For 'min'/'max' the bound the user did NOT edit is kept
                fixed where possible and the edited one is corrected; for
                'slider' the pair is adjusted around its midpoint.
            min_step: Minimum separation enforced between min and max

        Returns:
            Tuple ``(min_val, max_val)`` lying inside ``[global_min, global_max]``.
        """
        # Clamp to global bounds first
        min_val = max(global_min, min(min_val, global_max))
        max_val = max(global_min, min(max_val, global_max))

        # Ensure min < max
        if min_val >= max_val:
            if source == 'min':
                # User changed min: move min below max
                min_val = max_val - min_step
                if min_val < global_min:
                    min_val = global_min
                    max_val = min_val + min_step
                    if max_val > global_max:
                        max_val = global_max
                        min_val = max_val - min_step
            elif source == 'max':
                # User changed max: move max above min
                max_val = min_val + min_step
                if max_val > global_max:
                    max_val = global_max
                    min_val = max_val - min_step
                    if min_val < global_min:
                        min_val = global_min
                        max_val = min_val + min_step
            else:  # source == 'slider': open a minimal gap around the midpoint
                avg = (min_val + max_val) / 2
                min_val = avg - min_step / 2
                max_val = avg + min_step / 2

        # Ensure the range does not exceed max_width
        if max_val - min_val > max_width:
            if source == 'min':
                # User changed min: keep max fixed and pull min up
                min_val = max_val - max_width
                if min_val < global_min:
                    min_val = global_min
                    max_val = min_val + max_width
            elif source == 'max':
                # User changed max: keep min fixed and pull max down
                max_val = min_val + max_width
                if max_val > global_max:
                    max_val = global_max
                    min_val = max_val - max_width
            else:  # source == 'slider': shrink around the centre
                center = (min_val + max_val) / 2
                min_val = max(global_min, center - max_width / 2)
                max_val = min(global_max, min_val + max_width)
                if max_val - min_val > max_width:
                    max_val = min(global_max, center + max_width / 2)
                    min_val = max(global_min, max_val - max_width)

        # Final guard: the adjustments above can step outside the global
        # limits (e.g. both slider handles at an edge, or a prior range
        # narrower than min_step). The slider would clamp silently while the
        # text inputs kept the out-of-range value.
        min_val = max(global_min, min(min_val, global_max))
        max_val = max(global_min, min(max_val, global_max))
        return min_val, max_val

    def _create_table(self, initial_bounds: Optional[np.ndarray] = None) -> 'widgets.VBox':
        """Build the table widget: style, title row, header and one row per parameter"""
        init_pb = np.array(initial_bounds) if initial_bounds is not None else None

        # Add custom CSS for styling
        custom_style = widgets.HTML("""
        <style>
        .widget-float-text input {
            text-align: center;
        }
        .widget-inline-hbox .widget-readout {
            margin-left: 20px !important;
        }
        .widget-slider .ui-slider {
            width: 280px !important;
        }
        </style>
        """)

        # Create header row
        header = widgets.HBox([
            widgets.HTML("<b>Parameter</b>", layout=widgets.Layout(width='150px')),
            widgets.HTML("<b>Prior Bounds</b>", layout=widgets.Layout(width='300px')),
            widgets.HTML("<b></b>", layout=widgets.Layout(width='50px')),
            widgets.HTML("<b></b>", layout=widgets.Layout(width='50px')),
            widgets.HTML("<b>Predicted</b>", layout=widgets.Layout(width='100px')),
            widgets.HTML("<b>Polished</b>", layout=widgets.Layout(width='100px')),
            widgets.HTML("<b>Uncertainty</b>", layout=widgets.Layout(width='100px'))
        ], layout=widgets.Layout(margin='5px 0px', align_items='center'))

        # Create parameter rows
        parameter_rows = []

        if not self.param_labels:
            # No parameters available - show message
            no_params_message = widgets.HTML(
                value="<i>No model loaded. Please load a model from the Models tab to see parameters.</i>",
                layout=widgets.Layout(margin='20px 5px')
            )
            parameter_rows.append(no_params_message)

        for i, label in enumerate(self.param_labels):
            global_min = float(self.min_bounds[i])
            global_max = float(self.max_bounds[i])
            max_width = float(self.max_deltas[i])

            if init_pb is not None and i < len(init_pb):
                init_min, init_max = float(init_pb[i, 0]), float(init_pb[i, 1])
            else:
                init_min = global_min
                init_max = min(global_min + max_width, global_max)
            # Validate the initial pair like any user edit, so the text inputs
            # agree with the slider (which clamps its value to [min, max]).
            init_min, init_max = self._clamp_range(
                init_min, init_max, global_min, global_max, max_width)

            # Parameter label
            param_label = widgets.HTML(
                value=f"<b>{label}</b>",
                layout=widgets.Layout(width='150px', display='flex', align_items='center', justify_content='flex-start')
            )

            # Range slider - updates continuously
            slider = widgets.FloatRangeSlider(
                value=[init_min, init_max],
                min=global_min,
                max=global_max,
                step=0.01,
                layout=widgets.Layout(width='180px'),
                readout=False,  # Hide readout since we have input boxes
                continuous_update=True,  # Update instantly while dragging
                style={'description_width': '0px', 'handle_color': '#2196F3'}
            )

            # Min bound input - updates continuously
            min_input = widgets.FloatText(
                value=round(init_min, 2),
                step=0.01,
                layout=widgets.Layout(width='65px'),
                continuous_update=True  # Update on every keystroke
            )

            # Max bound input - updates continuously
            max_input = widgets.FloatText(
                value=round(init_max, 2),
                step=0.01,
                layout=widgets.Layout(width='65px'),
                continuous_update=True  # Update on every keystroke
            )

            # Result displays
            predicted_display = widgets.HTML(
                value="<i>-</i>",
                layout=widgets.Layout(width='100px', display='flex', align_items='center', justify_content='flex-start')
            )
            polished_display = widgets.HTML(
                value="<i>-</i>",
                layout=widgets.Layout(width='100px', display='flex', align_items='center', justify_content='flex-start')
            )
            uncertainty_display = widgets.HTML(
                value="<i>-</i>",
                layout=widgets.Layout(width='100px', display='flex', align_items='center', justify_content='flex-start')
            )

            # Store references for updating results
            self.result_displays[i] = {
                'predicted': predicted_display,
                'polished': polished_display,
                'uncertainty': uncertainty_display
            }

            # Add synchronization between slider and inputs
            self._add_widget_synchronization(slider, min_input, max_input, i)
            self.sliders.append(slider)
            self.min_inputs.append(min_input)
            self.max_inputs.append(max_input)

            # Create row layout with vertical alignment
            row = widgets.HBox([
                param_label,
                slider,
                min_input,
                max_input,
                predicted_display,
                polished_display,
                uncertainty_display
            ], layout=widgets.Layout(margin='5px 0px', align_items='center'))

            parameter_rows.append(row)

        self.param_title = widgets.HTML("<h4>Parameter Configuration</h4>")

        # Create title row with optional predict button
        if self.predict_button is not None:
            title_row = widgets.HBox([
                self.param_title,
                self.predict_button
            ], layout=widgets.Layout(justify_content='space-between', align_items='center'))
        else:
            title_row = self.param_title

        # Create the main parameter table
        main_table = widgets.VBox([
            custom_style,
            title_row,
            header,
            widgets.HTML("<hr style='margin: 5px 0px;'>"),
            *parameter_rows
        ])

        # Create the complete widget with optional additional parameters
        table_components = [main_table]

        # Add additional parameters section if available
        if (self.additional_params_controls is not None and
                self.additional_params_controls.additional_sliders):
            table_components.extend([
                widgets.HTML("<br>"),
                self.additional_params_controls.widget
            ])

        return widgets.VBox(table_components)

    def _add_widget_synchronization(self, slider: 'widgets.FloatRangeSlider',
                                    min_input: 'widgets.FloatText',
                                    max_input: 'widgets.FloatText',
                                    param_idx: int):
        """
        Keep one row's slider and min/max inputs in sync, with validation.

        Every change is validated by ``_clamp_range``, rounded to 2 decimals
        and written back to all three widgets.
        """
        max_width = float(self.max_deltas[param_idx])
        global_min = float(self.min_bounds[param_idx])
        global_max = float(self.max_bounds[param_idx])

        # Re-entrancy guard: writing to one widget fires the observers of the others
        updating = {'active': False}

        def clip(value):
            """Clamp a value to this parameter's global limits"""
            return max(global_min, min(value, global_max))

        def sync_all_widgets(min_val, max_val, source='slider'):
            """Validate the pair and write it to all three widgets if it changed"""
            if updating['active']:
                return

            updating['active'] = True
            try:
                min_val, max_val = self._clamp_range(
                    min_val, max_val, global_min, global_max, max_width, source)

                # Round to 2 decimal places for display; re-clip because
                # rounding can step past limits that have more decimals.
                min_val = clip(round(min_val, 2))
                max_val = clip(round(max_val, 2))

                # Update all widgets if needed
                if slider.value != (min_val, max_val):
                    slider.value = (min_val, max_val)
                if min_input.value != min_val:
                    min_input.value = min_val
                if max_input.value != max_val:
                    max_input.value = max_val
            finally:
                # Always release the guard: a stuck flag would silently
                # disable syncing for this row after any failed assignment.
                updating['active'] = False

        def on_slider_change(change):
            min_val, max_val = change['new']
            sync_all_widgets(min_val, max_val, source='slider')

        def on_min_input_change(change):
            sync_all_widgets(change['new'], max_input.value, source='min')

        def on_max_input_change(change):
            sync_all_widgets(min_input.value, change['new'], source='max')

        # Attach observers
        slider.observe(on_slider_change, names='value')
        min_input.observe(on_min_input_change, names='value')
        max_input.observe(on_max_input_change, names='value')

    def get_prior_bounds(self) -> np.ndarray:
        """Return the current prior bounds from the input boxes as an (n_params, 2) float32 array"""
        if not self.min_inputs or not self.max_inputs:
            return np.array([], dtype=np.float32).reshape(0, 2)
        return np.array([[min_inp.value, max_inp.value]
                         for min_inp, max_inp in zip(self.min_inputs, self.max_inputs)],
                        dtype=np.float32)

    @staticmethod
    def _format_entry(values, index: int, template: str) -> str:
        """Format ``values[index]`` with ``template``, or a dash when unavailable"""
        if values is None or index >= len(values):
            return "<i>-</i>"
        return template.format(values[index])

    def update_results(self, prediction_result: Dict[str, Any]):
        """
        Update the predicted / polished / uncertainty cells of every row.

        Args:
            prediction_result: Dict that may contain 'predicted_params_array',
                'polished_params_array' and 'polished_params_error_array'.
                Missing or ``None`` entries, and rows beyond an array's length,
                are shown as a dash. An empty/falsy result leaves the cells untouched.
        """
        if not prediction_result:
            return

        # .get() without a default: a key present with value None is treated
        # the same as a missing key (len(None) would raise otherwise).
        predicted_params = prediction_result.get('predicted_params_array')
        polished_params = prediction_result.get('polished_params_array')
        error_bars = prediction_result.get('polished_params_error_array')

        for i, displays in self.result_displays.items():
            displays['predicted'].value = self._format_entry(predicted_params, i, "{:.2f}")
            displays['polished'].value = self._format_entry(polished_params, i, "{:.2f}")
            displays['uncertainty'].value = self._format_entry(error_bars, i, "±{:.2f}")
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
class PreprocessingControls:
    """
    Data preprocessing controls for the widget.

    ``widget`` is a VBox whose children are, in order: title, truncation
    heading, truncation hint, VBox(left index slider, right index slider),
    spacer, error-bar heading, error-bar hint, VBox(enable filtering,
    remove singles, remove consecutives checkboxes), VBox(threshold,
    consecutive, Q start trunc sliders). No references to the inner controls
    are stored, so their values have to be read by position in this tree.
    """

    def __init__(self, n_datapoints: int):
        """
        Initialize preprocessing controls

        Args:
            n_datapoints: Number of data points in the dataset (0 is allowed,
                e.g. before any data is loaded)
        """
        self.n_datapoints = n_datapoints
        self.widget = self._create_controls()

    def _create_controls(self) -> widgets.VBox:
        """Create preprocessing controls widget"""
        # IntSlider raises TraitError when max < min. The right-index slider
        # has min=1, so an empty dataset must not give it max=0; this mirrors
        # the max(0, ...) guard used for the left-index slider.
        right_max = max(1, self.n_datapoints)

        return widgets.VBox([
            widgets.HTML("<h4>Data Preprocessing</h4>"),

            # Truncation section
            widgets.HTML("<h5>Data Truncation</h5>"),
            widgets.HTML("<i>Specify which data points to include in the analysis</i>"),
            widgets.VBox([
                widgets.IntSlider(
                    description='Left index:', min=0, max=max(0, self.n_datapoints - 1),
                    step=1, value=0,
                ),
                widgets.IntSlider(
                    description='Right index:', min=1, max=right_max,
                    step=1, value=right_max,
                )
            ]),

            widgets.HTML("<br>"),

            # Error bar filtering section
            widgets.HTML("<h5>Error Bar Filtering</h5>"),
            widgets.HTML("<i>Filter out unreliable data points based on error bars</i>"),
            widgets.VBox([
                widgets.Checkbox(description='Enable filtering', value=True),
                widgets.Checkbox(description='Remove singles', value=True),
                widgets.Checkbox(description='Remove consecutives', value=True)
            ]),
            widgets.VBox([
                widgets.FloatSlider(
                    description='Threshold:', min=0.0, max=1.0, step=0.01, value=0.3,
                ),
                widgets.IntSlider(
                    description='Consecutive:', min=1, max=10, step=1, value=3,
                ),
                widgets.FloatSlider(
                    description='Q start trunc:', min=0.0, max=1.0, step=0.01, value=0.1,
                )
            ])
        ])
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
class PredictionControls:
    """
    Checkbox panel for prediction options and optional computations.

    ``widget`` is a VBox laid out as: title, "Prediction Options" heading,
    VBox(Polish prediction, Use sigmas for polishing), spacer,
    "Computation Settings" heading, hint, VBox(Calculate curve,
    Calculate pred SLD, Calculate polished SLD). Every checkbox starts checked.
    """

    def __init__(self):
        self.widget = self._create_controls()

    def _create_controls(self) -> widgets.VBox:
        """Assemble the prediction / computation settings panel"""
        def checked_box(label):
            # All options in this panel default to enabled
            return widgets.Checkbox(description=label, value=True)

        title = widgets.HTML("<h4>Prediction & Computation Settings</h4>")

        prediction_heading = widgets.HTML("<h5>Prediction Options</h5>")
        prediction_options = widgets.VBox(
            [checked_box(text) for text in ('Polish prediction', 'Use sigmas for polishing')]
        )

        spacer = widgets.HTML("<br>")

        computation_heading = widgets.HTML("<h5>Computation Settings</h5>")
        computation_hint = widgets.HTML("<i>Choose what to calculate during prediction</i>")
        computation_options = widgets.VBox(
            [checked_box(text) for text in ('Calculate curve',
                                            'Calculate pred SLD',
                                            'Calculate polished SLD')]
        )

        panel_children = [
            title,
            prediction_heading,
            prediction_options,
            spacer,
            computation_heading,
            computation_hint,
            computation_options,
        ]
        return widgets.VBox(panel_children)
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
class AdditionalParametersControls:
    """
    Controls for inputs that are not part of the prior bounds.

    These are fixed input parameters (here: the Q resolution) that are passed
    to the inference model separately. A control is only created when the
    model exposes the corresponding setting; otherwise ``widget`` is an empty
    VBox and ``additional_sliders`` stays empty.
    """

    def __init__(self, inference_model=None):
        """
        Initialize additional parameters controls

        Args:
            inference_model: The inference model to check for additional parameters (optional)
        """
        self.inference_model = inference_model
        self.additional_sliders = {}  # parameter name -> slider widget
        self.widget = self._create_controls()

    def _create_controls(self) -> widgets.VBox:
        """Build the controls; returns an empty VBox when the model has no smearing"""
        # Walk inference_model.trainer.loader.smearing, stopping at the first
        # missing or None link.
        smearing = self.inference_model
        for attr_name in ('trainer', 'loader', 'smearing'):
            if smearing is None:
                break
            smearing = getattr(smearing, attr_name, None)

        if smearing is None:
            return widgets.VBox([])

        lower = float(smearing.sigma_min)
        upper = float(smearing.sigma_max)

        resolution_slider = widgets.FloatSlider(
            description='Q resolution (dq/q):',
            min=lower,
            max=upper,
            step=0.001,
            value=(lower + upper) / 2,  # start at the midpoint of [sigma_min, sigma_max]
            readout_format='.4f',
            style={'description_width': '120px'},
            layout=widgets.Layout(width='400px')
        )
        self.additional_sliders['q_resolution'] = resolution_slider

        return widgets.VBox([resolution_slider, widgets.HTML("<br>")])

    def get_additional_params(self) -> Dict[str, float]:
        """Return a mapping of parameter name to the slider's current value"""
        return {param_name: control.value
                for param_name, control in self.additional_sliders.items()}
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
class PlottingControls:
    """
    Plotting and visualization controls.

    ``widget`` is a VBox with, in order: title, "Display Options" heading,
    VBox(HBox(Show error bars, Show q-resolution), HBox(Log x-axis,
    Plot SLD profile)), "SLD Profile Settings" heading, HBox(left padding,
    right padding), spacer, "Color Customization" heading and three VBoxes
    holding two color pickers each (data/error bars, prediction/polished,
    SLD pred/SLD polish).
    """

    def __init__(self):
        self.widget = self._create_controls()

    def _create_controls(self) -> widgets.VBox:
        """Assemble the plotting settings panel"""
        def toggle(label, enabled):
            return widgets.Checkbox(description=label, value=enabled)

        def padding_field(label, default):
            # A fresh style/layout per field so the two inputs do not share models
            return widgets.FloatText(
                description=label, value=default, step=0.1,
                style={'description_width': '100px'}, layout=widgets.Layout(width='200px')
            )

        def picker_column(*entries):
            return widgets.VBox([widgets.ColorPicker(description=label, value=color)
                                 for label, color in entries])

        display_toggles = widgets.VBox([
            widgets.HBox([toggle('Show error bars', True), toggle('Show q-resolution', False)]),
            widgets.HBox([toggle('Log x-axis', False), toggle('Plot SLD profile', True)]),
        ])

        padding_row = widgets.HBox([
            padding_field('Left padding:', 0.2),
            padding_field('Right padding:', 1.1),
        ])

        data_colors = picker_column(('Data color:', '#0000FF'), ('Error bars:', '#800080'))
        curve_colors = picker_column(('Prediction:', '#FF0000'), ('Polished:', '#FFA500'))
        sld_colors = picker_column(('SLD pred:', '#FF0000'), ('SLD polish:', '#FFA500'))

        return widgets.VBox([
            widgets.HTML("<h4>Plotting Settings</h4>"),
            widgets.HTML("<h5>Display Options</h5>"),
            display_toggles,
            widgets.HTML("<h5>SLD Profile Settings</h5>"),
            padding_row,
            widgets.HTML("<br>"),
            widgets.HTML("<h5>Color Customization</h5>"),
            data_colors,
            curve_colors,
            sld_colors,
        ])
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
class WidgetSettingsExtractor:
    """Utility class to extract settings from widget components.

    Widgets are located by their ``description`` text. When several widgets in
    the searched components share a description, the first one found wins.
    Components are searched in the order preprocessing -> prediction ->
    plotting, each depth-first in pre-order over ``children``.
    """

    # Map widget descriptions to preprocess_and_predict parameter names
    # (plus plotting-only settings). Values are (description, attribute).
    _WIDGET_MAP = {
        # Preprocessing parameters (match preprocess_and_predict signature)
        'truncate_index_left': ('Left index:', 'value'),
        'truncate_index_right': ('Right index:', 'value'),
        'enable_error_bars_filtering': ('Enable filtering', 'value'),
        'filter_remove_singles': ('Remove singles', 'value'),
        'filter_remove_consecutives': ('Remove consecutives', 'value'),
        'filter_threshold': ('Threshold:', 'value'),
        'filter_consecutive': ('Consecutive:', 'value'),
        'filter_q_start_trunc': ('Q start trunc:', 'value'),

        # Prediction parameters (match preprocess_and_predict signature)
        'polish_prediction': ('Polish prediction', 'value'),
        'use_sigmas_for_polishing': ('Use sigmas for polishing', 'value'),
        'calc_pred_curve': ('Calculate curve', 'value'),
        'calc_pred_sld_profile': ('Calculate pred SLD', 'value'),
        'calc_polished_sld_profile': ('Calculate polished SLD', 'value'),
        'sld_profile_padding_left': ('Left padding:', 'value'),
        'sld_profile_padding_right': ('Right padding:', 'value'),

        # Plotting parameters (not passed to preprocess_and_predict, kept for plotting)
        'show_error_bars': ('Show error bars', 'value'),
        'show_q_resolution': ('Show q-resolution', 'value'),
        'log_x_axis': ('Log x-axis', 'value'),
        'plot_sld_profile': ('Plot SLD profile', 'value'),
        'exp_color': ('Data color:', 'value'),
        'exp_errcolor': ('Error bars:', 'value'),
        'pred_color': ('Prediction:', 'value'),
        'pol_color': ('Polished:', 'value'),
        'sld_pred_color': ('SLD pred:', 'value'),
        'sld_pol_color': ('SLD polish:', 'value'),
    }

    # Fallback values used when no widget with the expected description exists.
    # Built once at class creation instead of on every missing widget.
    _DEFAULTS = {
        # Preprocessing parameters
        'truncate_index_left': 0,
        'truncate_index_right': 100,
        'enable_error_bars_filtering': True,
        'filter_remove_singles': True,
        'filter_remove_consecutives': True,
        'filter_threshold': 0.3,
        'filter_consecutive': 3,
        'filter_q_start_trunc': 0.1,

        # Prediction parameters
        'polish_prediction': True,
        'use_sigmas_for_polishing': True,
        'calc_pred_curve': True,
        'calc_pred_sld_profile': True,
        'calc_polished_sld_profile': True,
        'sld_profile_padding_left': 0.2,
        'sld_profile_padding_right': 1.1,

        # Plotting parameters
        'show_error_bars': True,
        'show_q_resolution': False,
        'log_x_axis': False,
        'plot_sld_profile': True,
        'exp_color': '#0000FF',
        'exp_errcolor': '#800080',
        'pred_color': '#FF0000',
        'pol_color': '#FFA500',
        'sld_pred_color': '#FF0000',
        'sld_pol_color': '#FFA500',
    }

    # Parameters that go to preprocess_and_predict
    _PREDICTION_PARAM_NAMES = frozenset({
        'reflectivity_curve', 'q_values', 'prior_bounds', 'sigmas', 'q_resolution', 'ambient_sld',
        'clip_prediction', 'polish_prediction', 'use_sigmas_for_polishing',
        'calc_pred_curve', 'calc_pred_sld_profile', 'calc_polished_sld_profile',
        'sld_profile_padding_left', 'sld_profile_padding_right',
        'truncate_index_left', 'truncate_index_right', 'enable_error_bars_filtering',
        'filter_threshold', 'filter_remove_singles', 'filter_remove_consecutives',
        'filter_consecutive', 'filter_q_start_trunc',
    })

    # Parameters used for plotting
    _PLOTTING_PARAM_NAMES = frozenset({
        'show_error_bars', 'show_q_resolution', 'log_x_axis', 'plot_sld_profile',
        'exp_color', 'exp_errcolor', 'pred_color', 'pol_color',
        'sld_pred_color', 'sld_pol_color',
    })

    @staticmethod
    def extract_settings(parameter_table: 'ParameterTable',
                         preprocessing: 'PreprocessingControls',
                         prediction: 'PredictionControls',
                         plotting: 'PlottingControls',
                         additional_params: Optional['AdditionalParametersControls'] = None,
                         data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Extract all current widget settings into a dictionary

        Args:
            parameter_table: Parameter table component (provides ``get_prior_bounds()``)
            preprocessing: Preprocessing controls component (provides ``.widget``)
            prediction: Prediction controls component (provides ``.widget``)
            plotting: Plotting controls component (provides ``.widget``)
            additional_params: Additional parameters controls component (optional,
                provides ``get_additional_params()``)
            data: Data dictionary with fallback values (optional)

        Returns:
            Dictionary containing all current settings ready for preprocess_and_predict,
            plus plotting-only settings (split them with ``separate_settings``).
            Settings whose widget cannot be found take the value from ``_DEFAULTS``.
        """
        settings = {}
        data = data or {}

        # Get prior bounds from parameter table
        settings['prior_bounds'] = parameter_table.get_prior_bounds()

        # Handle data parameters with widget override or fallback to data
        additional_param_values = {}
        if additional_params is not None:
            additional_param_values = additional_params.get_additional_params()

        # A key present in the widget values wins (even if its value is None);
        # otherwise the data dict is used, otherwise None.
        settings['q_resolution'] = additional_param_values.get('q_resolution', data.get('q_resolution'))
        settings['ambient_sld'] = additional_param_values.get('ambient_sld', data.get('ambient_sld'))

        # Set other data parameters that are always from data
        settings['reflectivity_curve'] = data.get('reflectivity_curve')
        settings['q_values'] = data.get('q_values')
        settings['sigmas'] = data.get('sigmas')

        # Set fixed prediction parameters
        settings['clip_prediction'] = True  # Always clip predictions

        # Walk every component's widget tree once instead of once per setting.
        widgets_by_description = WidgetSettingsExtractor._index_widgets_by_description(
            [preprocessing.widget, prediction.widget, plotting.widget]
        )

        for setting_name, (description, attr) in WidgetSettingsExtractor._WIDGET_MAP.items():
            widget = widgets_by_description.get(description)
            if widget is not None:
                settings[setting_name] = getattr(widget, attr)
            else:
                settings[setting_name] = WidgetSettingsExtractor._DEFAULTS.get(setting_name)

        return settings

    @staticmethod
    def separate_settings(settings: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Separate settings into prediction parameters and plotting parameters

        Keys that belong to neither group are dropped.

        Args:
            settings: Complete settings dictionary from extract_settings

        Returns:
            Tuple of (prediction_params, plotting_params)
        """
        prediction_names = WidgetSettingsExtractor._PREDICTION_PARAM_NAMES
        plotting_names = WidgetSettingsExtractor._PLOTTING_PARAM_NAMES

        prediction_params = {k: v for k, v in settings.items() if k in prediction_names}
        plotting_params = {k: v for k, v in settings.items() if k in plotting_names}

        return prediction_params, plotting_params

    @staticmethod
    def _index_widgets_by_description(containers) -> Dict[str, Any]:
        """Map each string description to the first widget carrying it.

        The containers are visited in the given order, each depth-first in
        pre-order over ``children``. This is the same order in which
        ``_find_widgets_by_description`` reports matches, so the first hit per
        description is identical.
        """
        index: Dict[str, Any] = {}

        def visit(widget):
            description = getattr(widget, 'description', None)
            # Only string descriptions can match the lookup keys; this also
            # keeps unhashable attribute values out of the dict.
            if isinstance(description, str):
                index.setdefault(description, widget)
            if hasattr(widget, 'children'):
                for child in widget.children:
                    visit(child)

        for container in containers:
            visit(container)
        return index

    @staticmethod
    def _find_widgets_by_description(container, descriptions):
        """Return all widgets under ``container`` whose description is in ``descriptions``.

        The search is depth-first in pre-order: a container is checked before
        its children.
        """
        found_widgets = []

        def search_widget(widget):
            if hasattr(widget, 'description') and widget.description in descriptions:
                found_widgets.append(widget)
            if hasattr(widget, 'children'):
                for child in widget.children:
                    search_widget(child)

        search_widget(container)
        return found_widgets
|
|
757
|
+
|
|
758
|
+
|