reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,758 @@
1
+ """
2
+ Jupyter Widget Components for Reflectorch
3
+
4
+ This module contains reusable widget components that can be composed
5
+ to create different interfaces for reflectometry analysis.
6
+
7
+ Components:
8
+ - ParameterTable: Interactive parameter table with sliders and results
9
+ - PreprocessingControls: Data preprocessing options
10
+ - PredictionControls: Prediction and computation settings
11
+ - PlottingControls: Plotting and visualization options
12
+ - AdditionalParametersControls: Controls for additional parameters like Q resolution
13
+ """
14
+
15
+ import numpy as np
16
+ from typing import Optional, Dict, Any, List, Tuple
17
+ import ipywidgets as widgets
18
+
19
+
20
class ParameterTable:
    """
    Interactive parameter table with sliders and result displays

    Features:
    - Structured table layout with aligned columns
    - Real-time result updates after predictions
    - Automatic slider validation
    - Professional styling

    Each parameter row holds a ``FloatRangeSlider`` plus two ``FloatText``
    boxes (min / max) that are kept in sync. ``get_prior_bounds`` reads the
    text boxes, so they must always hold values inside the global bounds.
    """

    def __init__(self, param_labels: List[str], min_bounds: np.ndarray,
                 max_bounds: np.ndarray, max_deltas: np.ndarray,
                 initial_bounds: Optional[np.ndarray] = None,
                 additional_params_controls: Optional['AdditionalParametersControls'] = None,
                 predict_button: Optional[widgets.Button] = None):
        """
        Initialize parameter table

        Args:
            param_labels: List of parameter names
            min_bounds: Minimum values for each parameter
            max_bounds: Maximum values for each parameter
            max_deltas: Maximum allowed range for each parameter
            initial_bounds: Initial bounds, shape (n_params, 2). Rows missing
                from it fall back to [min_bound, min_bound + max_delta].
            additional_params_controls: Optional additional parameters component
            predict_button: Optional predict button to include in the table header
        """
        self.param_labels = param_labels
        self.min_bounds = min_bounds
        self.max_bounds = max_bounds
        self.max_deltas = max_deltas
        self.sliders = []  # Store range sliders
        self.min_inputs = []  # Store min bound inputs
        self.max_inputs = []  # Store max bound inputs
        self.result_displays = {}
        self.additional_params_controls = additional_params_controls
        self.predict_button = predict_button

        self.widget = self._create_table(initial_bounds)

    def _create_table(self, initial_bounds: Optional[np.ndarray] = None) -> widgets.VBox:
        """Create the parameter table widget.

        Args:
            initial_bounds: Optional initial bounds, shape (n_params, 2).

        Returns:
            VBox with the styled table and, when present, the additional
            parameters section.
        """
        init_pb = np.array(initial_bounds) if initial_bounds is not None else None

        # Add custom CSS for styling
        custom_style = widgets.HTML("""
        <style>
        .widget-float-text input {
            text-align: center;
        }
        .widget-inline-hbox .widget-readout {
            margin-left: 20px !important;
        }
        .widget-slider .ui-slider {
            width: 280px !important;
        }
        </style>
        """)

        # Create header row
        header = widgets.HBox([
            widgets.HTML("<b>Parameter</b>", layout=widgets.Layout(width='150px')),
            widgets.HTML("<b>Prior Bounds</b>", layout=widgets.Layout(width='300px')),
            widgets.HTML("<b></b>", layout=widgets.Layout(width='50px')),
            widgets.HTML("<b></b>", layout=widgets.Layout(width='50px')),
            widgets.HTML("<b>Predicted</b>", layout=widgets.Layout(width='100px')),
            widgets.HTML("<b>Polished</b>", layout=widgets.Layout(width='100px')),
            widgets.HTML("<b>Uncertainty</b>", layout=widgets.Layout(width='100px'))
        ], layout=widgets.Layout(margin='5px 0px', align_items='center'))

        # Create parameter rows
        parameter_rows = []

        if not self.param_labels:
            # No parameters available - show message
            no_params_message = widgets.HTML(
                value="<i>No model loaded. Please load a model from the Models tab to see parameters.</i>",
                layout=widgets.Layout(margin='20px 5px')
            )
            parameter_rows.append(no_params_message)

        for i, label in enumerate(self.param_labels):
            global_min = float(self.min_bounds[i])
            global_max = float(self.max_bounds[i])

            if init_pb is not None and i < len(init_pb):
                init_min = float(init_pb[i, 0])
                init_max = float(init_pb[i, 1])
            else:
                init_min = global_min
                init_max = float(min(self.min_bounds[i] + self.max_deltas[i], self.max_bounds[i]))

            # The range slider silently clamps its value into [min, max]; clamp
            # here too so the text boxes (read by get_prior_bounds) agree.
            init_min = max(global_min, min(init_min, global_max))
            init_max = max(global_min, min(init_max, global_max))

            # Parameter label
            param_label = widgets.HTML(
                value=f"<b>{label}</b>",
                layout=widgets.Layout(width='150px', display='flex', align_items='center', justify_content='flex-start')
            )

            # Range slider - updates continuously
            slider = widgets.FloatRangeSlider(
                value=[init_min, init_max],
                min=global_min,
                max=global_max,
                step=0.01,
                layout=widgets.Layout(width='180px'),
                readout=False,  # Hide readout since we have input boxes
                continuous_update=True,  # Update instantly while dragging
                style={'description_width': '0px', 'handle_color': '#2196F3'}
            )

            # Min/max bound inputs. They sync on Enter/blur rather than on every
            # keystroke: validating partial input (e.g. "2" while typing "25")
            # would clamp it and overwrite what the user is still typing.
            min_input = widgets.FloatText(
                value=round(init_min, 2),
                step=0.01,
                layout=widgets.Layout(width='65px'),
                continuous_update=False
            )
            max_input = widgets.FloatText(
                value=round(init_max, 2),
                step=0.01,
                layout=widgets.Layout(width='65px'),
                continuous_update=False
            )

            # Result displays
            predicted_display = widgets.HTML(
                value="<i>-</i>",
                layout=widgets.Layout(width='100px', display='flex', align_items='center', justify_content='flex-start')
            )
            polished_display = widgets.HTML(
                value="<i>-</i>",
                layout=widgets.Layout(width='100px', display='flex', align_items='center', justify_content='flex-start')
            )
            uncertainty_display = widgets.HTML(
                value="<i>-</i>",
                layout=widgets.Layout(width='100px', display='flex', align_items='center', justify_content='flex-start')
            )

            # Store references for updating results
            self.result_displays[i] = {
                'predicted': predicted_display,
                'polished': polished_display,
                'uncertainty': uncertainty_display
            }

            # Add synchronization between slider and inputs
            self._add_widget_synchronization(slider, min_input, max_input, i)
            self.sliders.append(slider)
            self.min_inputs.append(min_input)
            self.max_inputs.append(max_input)

            # Create row layout with vertical alignment
            row = widgets.HBox([
                param_label,
                slider,
                min_input,
                max_input,
                predicted_display,
                polished_display,
                uncertainty_display
            ], layout=widgets.Layout(margin='5px 0px', align_items='center'))

            parameter_rows.append(row)

        self.param_title = widgets.HTML("<h4>Parameter Configuration</h4>")

        # Create title row with optional predict button
        if self.predict_button is not None:
            title_row = widgets.HBox([
                self.param_title,
                self.predict_button
            ], layout=widgets.Layout(justify_content='space-between', align_items='center'))
        else:
            title_row = self.param_title

        # Create the main parameter table
        main_table = widgets.VBox([
            custom_style,
            title_row,
            header,
            widgets.HTML("<hr style='margin: 5px 0px;'>"),
            *parameter_rows
        ])

        # Create the complete widget with optional additional parameters
        table_components = [main_table]

        # Add additional parameters section if available
        if (self.additional_params_controls is not None and
                self.additional_params_controls.additional_sliders):
            table_components.extend([
                widgets.HTML("<br>"),
                self.additional_params_controls.widget
            ])

        return widgets.VBox(table_components)

    @staticmethod
    def _clamp_range(min_val: float, max_val: float, source: str,
                     global_min: float, global_max: float, max_width: float,
                     min_step: float = 0.01) -> Tuple[float, float]:
        """Return a valid (min, max) pair for one parameter row.

        Enforces min < max (by ``min_step``) and max - min <= ``max_width``.
        ``source`` ('slider', 'min' or 'max') names the widget that changed, so
        the value the user just edited is kept where possible and the other
        end moves. The returned values always lie in [global_min, global_max].
        """
        # Clamp to global bounds first
        min_val = max(global_min, min(min_val, global_max))
        max_val = max(global_min, min(max_val, global_max))

        if min_val >= max_val:
            if source == 'min':
                # User changed min: move min below max; shift max if min hits the floor
                min_val = max_val - min_step
                if min_val < global_min:
                    min_val = global_min
                    max_val = min_val + min_step
                    if max_val > global_max:
                        max_val = global_max
                        min_val = max_val - min_step
            elif source == 'max':
                # User changed max: move max above min; shift min if max hits the ceiling
                max_val = min_val + min_step
                if max_val > global_max:
                    max_val = global_max
                    min_val = max_val - min_step
                    if min_val < global_min:
                        min_val = global_min
                        max_val = min_val + min_step
            else:  # source == 'slider'
                # Handles collapsed together: open a min_step gap around their
                # midpoint, shifted inward if it would cross a global bound.
                avg = (min_val + max_val) / 2
                min_val = max(global_min, avg - min_step / 2)
                max_val = min(global_max, min_val + min_step)
                min_val = max(global_min, max_val - min_step)

        # Ensure range doesn't exceed max_width
        if max_val - min_val > max_width:
            if source == 'min':
                # User changed min, keep max fixed and adjust min
                min_val = max_val - max_width
                if min_val < global_min:
                    min_val = global_min
                    max_val = min_val + max_width
            elif source == 'max':
                # User changed max, keep min fixed and adjust max
                max_val = min_val + max_width
                if max_val > global_max:
                    max_val = global_max
                    min_val = max_val - max_width
            else:  # source == 'slider'
                # Shrink symmetrically around the slider's center
                center = (min_val + max_val) / 2
                min_val = max(global_min, center - max_width / 2)
                max_val = min(global_max, min_val + max_width)

        # Final guard for degenerate cases (e.g. global range < min_step)
        min_val = max(global_min, min(min_val, global_max))
        max_val = max(global_min, min(max_val, global_max))
        return min_val, max_val

    def _add_widget_synchronization(self, slider: widgets.FloatRangeSlider,
                                    min_input: widgets.FloatText,
                                    max_input: widgets.FloatText,
                                    param_idx: int):
        """Add synchronization between slider and input boxes with validation.

        Any change to one of the three widgets is validated with
        ``_clamp_range``, rounded to 2 decimals, and written back to all three.
        """
        max_width = float(self.max_deltas[param_idx])
        global_min = float(self.min_bounds[param_idx])
        global_max = float(self.max_bounds[param_idx])

        # Re-entrancy guard: writing one widget fires the other observers,
        # which must not start a nested synchronization round.
        updating = {'active': False}

        def clamp_to_global(value):
            return max(global_min, min(value, global_max))

        def sync_all_widgets(min_val, max_val, source='slider'):
            """Update all three widgets if values changed"""
            if updating['active']:
                return

            updating['active'] = True
            try:
                min_val, max_val = self._clamp_range(
                    min_val, max_val, source, global_min, global_max, max_width)

                # Round for display, then re-clamp: rounding alone could push
                # a value just outside non-2-decimal global bounds.
                min_val = clamp_to_global(round(min_val, 2))
                max_val = clamp_to_global(round(max_val, 2))

                # Update all widgets if needed
                if slider.value != (min_val, max_val):
                    slider.value = (min_val, max_val)
                if min_input.value != min_val:
                    min_input.value = min_val
                if max_input.value != max_val:
                    max_input.value = max_val
            finally:
                # Always release the guard; otherwise a single exception from a
                # widget assignment would disable syncing for this row forever.
                updating['active'] = False

        def on_slider_change(change):
            min_val, max_val = change['new']
            sync_all_widgets(min_val, max_val, source='slider')

        def on_min_input_change(change):
            sync_all_widgets(change['new'], max_input.value, source='min')

        def on_max_input_change(change):
            sync_all_widgets(min_input.value, change['new'], source='max')

        # Attach observers
        slider.observe(on_slider_change, names='value')
        min_input.observe(on_min_input_change, names='value')
        max_input.observe(on_max_input_change, names='value')

    def get_prior_bounds(self) -> np.ndarray:
        """Get current prior bounds from input boxes.

        Returns:
            float32 array of shape (n_params, 2); shape (0, 2) when the table
            has no parameters.
        """
        if not self.min_inputs or not self.max_inputs:
            return np.array([], dtype=np.float32).reshape(0, 2)
        return np.array([[min_inp.value, max_inp.value]
                         for min_inp, max_inp in zip(self.min_inputs, self.max_inputs)],
                        dtype=np.float32)

    def update_results(self, prediction_result: Dict[str, Any]):
        """Update parameter result displays.

        Reads 'predicted_params_array', 'polished_params_array' and
        'polished_params_error_array' from ``prediction_result``. A missing or
        None entry, or one shorter than the parameter list, shows "-".
        """
        if not prediction_result:
            return

        placeholder = "<i>-</i>"
        predicted_params = prediction_result.get('predicted_params_array')
        polished_params = prediction_result.get('polished_params_array')
        error_bars = prediction_result.get('polished_params_error_array')

        def render(values, idx, template):
            # `is None` rather than truthiness: values may be numpy arrays
            if values is None or idx >= len(values):
                return placeholder
            return template.format(values[idx])

        for i, displays in self.result_displays.items():
            displays['predicted'].value = render(predicted_params, i, "{:.2f}")
            displays['polished'].value = render(polished_params, i, "{:.2f}")
            displays['uncertainty'].value = render(error_bars, i, "±{:.2f}")
375
+
376
+
377
class PreprocessingControls:
    """Data preprocessing controls for the widget.

    The widgets are looked up later by their ``description`` (see
    ``WidgetSettingsExtractor``), so descriptions must stay unique and stable.
    """

    def __init__(self, n_datapoints: int):
        """
        Initialize preprocessing controls

        Args:
            n_datapoints: Number of data points in the dataset. Values below 1
                (e.g. before any data is loaded) are tolerated: the truncation
                sliders get a minimal valid range instead of raising.
        """
        self.n_datapoints = n_datapoints
        self.widget = self._create_controls()

    def _create_controls(self) -> widgets.VBox:
        """Create preprocessing controls widget"""
        # IntSlider rejects max < min, so keep each slider's upper bound at
        # least at its lower bound (0 for the left index, 1 for the right).
        left_max = max(0, self.n_datapoints - 1)
        right_max = max(1, self.n_datapoints)

        return widgets.VBox([
            widgets.HTML("<h4>Data Preprocessing</h4>"),

            # Truncation section
            widgets.HTML("<h5>Data Truncation</h5>"),
            widgets.HTML("<i>Specify which data points to include in the analysis</i>"),
            widgets.VBox([
                widgets.IntSlider(
                    description='Left index:', min=0, max=left_max,
                    step=1, value=0,
                ),
                widgets.IntSlider(
                    description='Right index:', min=1, max=right_max,
                    step=1, value=right_max,
                )
            ]),

            widgets.HTML("<br>"),

            # Error bar filtering section
            widgets.HTML("<h5>Error Bar Filtering</h5>"),
            widgets.HTML("<i>Filter out unreliable data points based on error bars</i>"),
            widgets.VBox([
                widgets.Checkbox(description='Enable filtering', value=True),
                widgets.Checkbox(description='Remove singles', value=True),
                widgets.Checkbox(description='Remove consecutives', value=True)
            ]),
            widgets.VBox([
                widgets.FloatSlider(
                    description='Threshold:', min=0.0, max=1.0, step=0.01, value=0.3,
                ),
                widgets.IntSlider(
                    description='Consecutive:', min=1, max=10, step=1, value=3,
                ),
                widgets.FloatSlider(
                    description='Q start trunc:', min=0.0, max=1.0, step=0.01, value=0.1,
                )
            ])
        ])
431
+
432
+
433
class PredictionControls:
    """Checkbox panel for polishing options and for which outputs to compute."""

    def __init__(self):
        self.widget = self._create_controls()

    def _create_controls(self) -> widgets.VBox:
        """Assemble and return the prediction/computation settings panel."""
        polishing_boxes = widgets.VBox([
            widgets.Checkbox(description=label, value=True)
            for label in ('Polish prediction', 'Use sigmas for polishing')
        ])
        computation_boxes = widgets.VBox([
            widgets.Checkbox(description=label, value=True)
            for label in ('Calculate curve', 'Calculate pred SLD', 'Calculate polished SLD')
        ])

        panel_items = [
            widgets.HTML("<h4>Prediction & Computation Settings</h4>"),
            widgets.HTML("<h5>Prediction Options</h5>"),
            polishing_boxes,
            widgets.HTML("<br>"),
            widgets.HTML("<h5>Computation Settings</h5>"),
            widgets.HTML("<i>Choose what to calculate during prediction</i>"),
            computation_boxes,
        ]
        return widgets.VBox(panel_items)
462
+
463
+
464
class AdditionalParametersControls:
    """
    Controls for additional parameters that are not part of prior bounds

    These are fixed input parameters like Q resolution that are passed
    separately to the inference model.
    """

    def __init__(self, inference_model=None):
        """
        Initialize additional parameters controls

        Args:
            inference_model: The inference model to check for additional parameters (optional)
        """
        self.inference_model = inference_model
        self.additional_sliders = {}
        self.widget = self._create_controls()

    def _get_smearing(self):
        """Return ``inference_model.trainer.loader.smearing``, or None if any link is missing/None."""
        obj = self.inference_model
        for attr in ('trainer', 'loader', 'smearing'):
            if obj is None:
                return None
            obj = getattr(obj, attr, None)
        return obj

    def _create_controls(self) -> widgets.VBox:
        """Create additional parameters controls widget.

        A Q-resolution slider is added only when the model's loader has a
        smearing object; its range is [smearing.sigma_min, smearing.sigma_max].
        """
        controls = []

        smearing = self._get_smearing()
        if smearing is not None:
            q_res_min = float(smearing.sigma_min)
            q_res_max = float(smearing.sigma_max)

            # A fixed 0.001 step leaves only a few positions (or none) for a
            # narrow sigma range; refine it so the slider has >= ~100 steps.
            span = q_res_max - q_res_min
            step = min(0.001, span / 100) if span > 0 else 0.001

            # Q resolution slider
            q_res_slider = widgets.FloatSlider(
                description='Q resolution (dq/q):',
                min=q_res_min,
                max=q_res_max,
                step=step,
                value=(q_res_min + q_res_max) / 2,  # Default to middle value
                readout_format='.4f',
                style={'description_width': '120px'},
                layout=widgets.Layout(width='400px')
            )

            self.additional_sliders['q_resolution'] = q_res_slider
            controls.append(q_res_slider)
            controls.append(widgets.HTML("<br>"))

        return widgets.VBox(controls)

    def get_additional_params(self) -> Dict[str, float]:
        """Get current values of additional parameters, keyed by parameter name."""
        return {name: slider.value for name, slider in self.additional_sliders.items()}
518
+
519
+
520
class PlottingControls:
    """Panel of plot-appearance widgets: display toggles, SLD padding, colors."""

    def __init__(self):
        self.widget = self._create_controls()

    def _create_controls(self) -> widgets.VBox:
        """Build the plotting settings panel and return it as a VBox."""

        def checkbox_row(*specs):
            # One horizontal row of (description, default) checkboxes
            return widgets.HBox([widgets.Checkbox(description=d, value=v) for d, v in specs])

        def padding_field(label, default):
            return widgets.FloatText(
                description=label, value=default, step=0.1,
                style={'description_width': '100px'},
                layout=widgets.Layout(width='200px'),
            )

        def color_column(*specs):
            # One vertical stack of (description, hex color) pickers
            return widgets.VBox([widgets.ColorPicker(description=d, value=c) for d, c in specs])

        display_toggles = widgets.VBox([
            checkbox_row(('Show error bars', True), ('Show q-resolution', False)),
            checkbox_row(('Log x-axis', False), ('Plot SLD profile', True)),
        ])
        padding_row = widgets.HBox([
            padding_field('Left padding:', 0.2),
            padding_field('Right padding:', 1.1),
        ])

        return widgets.VBox([
            widgets.HTML("<h4>Plotting Settings</h4>"),
            widgets.HTML("<h5>Display Options</h5>"),
            display_toggles,
            widgets.HTML("<h5>SLD Profile Settings</h5>"),
            padding_row,
            widgets.HTML("<br>"),
            widgets.HTML("<h5>Color Customization</h5>"),
            color_column(('Data color:', '#0000FF'), ('Error bars:', '#800080')),
            color_column(('Prediction:', '#FF0000'), ('Polished:', '#FFA500')),
            color_column(('SLD pred:', '#FF0000'), ('SLD polish:', '#FFA500')),
        ])
576
+
577
+
578
class WidgetSettingsExtractor:
    """Utility class to extract settings from widget components"""

    # Setting name -> (widget description, attribute to read). Descriptions
    # must match those used by the control components.
    _WIDGET_MAP = {
        # Preprocessing parameters (match preprocess_and_predict signature)
        'truncate_index_left': ('Left index:', 'value'),
        'truncate_index_right': ('Right index:', 'value'),
        'enable_error_bars_filtering': ('Enable filtering', 'value'),
        'filter_remove_singles': ('Remove singles', 'value'),
        'filter_remove_consecutives': ('Remove consecutives', 'value'),
        'filter_threshold': ('Threshold:', 'value'),
        'filter_consecutive': ('Consecutive:', 'value'),
        'filter_q_start_trunc': ('Q start trunc:', 'value'),

        # Prediction parameters (match preprocess_and_predict signature)
        'polish_prediction': ('Polish prediction', 'value'),
        'use_sigmas_for_polishing': ('Use sigmas for polishing', 'value'),
        'calc_pred_curve': ('Calculate curve', 'value'),
        'calc_pred_sld_profile': ('Calculate pred SLD', 'value'),
        'calc_polished_sld_profile': ('Calculate polished SLD', 'value'),
        'sld_profile_padding_left': ('Left padding:', 'value'),
        'sld_profile_padding_right': ('Right padding:', 'value'),

        # Plotting parameters (not passed to preprocess_and_predict, kept for plotting)
        'show_error_bars': ('Show error bars', 'value'),
        'show_q_resolution': ('Show q-resolution', 'value'),
        'log_x_axis': ('Log x-axis', 'value'),
        'plot_sld_profile': ('Plot SLD profile', 'value'),
        'exp_color': ('Data color:', 'value'),
        'exp_errcolor': ('Error bars:', 'value'),
        'pred_color': ('Prediction:', 'value'),
        'pol_color': ('Polished:', 'value'),
        'sld_pred_color': ('SLD pred:', 'value'),
        'sld_pol_color': ('SLD polish:', 'value'),
    }

    # Fallback values used when no widget with the expected description exists.
    _DEFAULTS = {
        # Preprocessing parameters
        'truncate_index_left': 0,
        'truncate_index_right': 100,
        'enable_error_bars_filtering': True,
        'filter_remove_singles': True,
        'filter_remove_consecutives': True,
        'filter_threshold': 0.3,
        'filter_consecutive': 3,
        'filter_q_start_trunc': 0.1,

        # Prediction parameters
        'polish_prediction': True,
        'use_sigmas_for_polishing': True,
        'calc_pred_curve': True,
        'calc_pred_sld_profile': True,
        'calc_polished_sld_profile': True,
        'sld_profile_padding_left': 0.2,
        'sld_profile_padding_right': 1.1,

        # Plotting parameters
        'show_error_bars': True,
        'show_q_resolution': False,
        'log_x_axis': False,
        'plot_sld_profile': True,
        'exp_color': '#0000FF',
        'exp_errcolor': '#800080',
        'pred_color': '#FF0000',
        'pol_color': '#FFA500',
        'sld_pred_color': '#FF0000',
        'sld_pol_color': '#FFA500',
    }

    @staticmethod
    def extract_settings(parameter_table: ParameterTable,
                         preprocessing: PreprocessingControls,
                         prediction: PredictionControls,
                         plotting: PlottingControls,
                         additional_params: Optional['AdditionalParametersControls'] = None,
                         data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Extract all current widget settings into a dictionary

        Args:
            parameter_table: Parameter table component
            preprocessing: Preprocessing controls component
            prediction: Prediction controls component
            plotting: Plotting controls component
            additional_params: Additional parameters controls component (optional)
            data: Data dictionary with fallback values (optional)

        Returns:
            Dictionary containing all current settings ready for preprocess_and_predict
        """
        settings = {}
        data = data or {}

        # Get prior bounds from parameter table
        settings['prior_bounds'] = parameter_table.get_prior_bounds()

        # Widget values override values from `data`
        additional_param_values = {}
        if additional_params is not None:
            additional_param_values = additional_params.get_additional_params()

        settings['q_resolution'] = additional_param_values.get('q_resolution', data.get('q_resolution'))
        settings['ambient_sld'] = additional_param_values.get('ambient_sld', data.get('ambient_sld'))

        # Set other data parameters that are always from data
        settings['reflectivity_curve'] = data.get('reflectivity_curve')
        settings['q_values'] = data.get('q_values')
        settings['sigmas'] = data.get('sigmas')

        # Set fixed prediction parameters
        settings['clip_prediction'] = True  # Always clip predictions

        # Walk the widget trees once and index them by description, instead of
        # re-walking every tree for each setting. The first widget found (same
        # traversal order as _find_widgets_by_description) wins.
        wanted = {description for description, _ in WidgetSettingsExtractor._WIDGET_MAP.values()}
        by_description = {}
        for container in (preprocessing.widget, prediction.widget, plotting.widget):
            for widget in WidgetSettingsExtractor._find_widgets_by_description(container, wanted):
                by_description.setdefault(widget.description, widget)

        for setting_name, (description, attr) in WidgetSettingsExtractor._WIDGET_MAP.items():
            widget = by_description.get(description)
            if widget is not None:
                settings[setting_name] = getattr(widget, attr)
            else:
                settings[setting_name] = WidgetSettingsExtractor._DEFAULTS.get(setting_name)

        return settings

    @staticmethod
    def separate_settings(settings: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Separate settings into prediction parameters and plotting parameters

        Args:
            settings: Complete settings dictionary from extract_settings

        Returns:
            Tuple of (prediction_params, plotting_params)
        """
        # Parameters that go to preprocess_and_predict
        prediction_param_names = {
            'reflectivity_curve', 'q_values', 'prior_bounds', 'sigmas', 'q_resolution', 'ambient_sld',
            'clip_prediction', 'polish_prediction', 'use_sigmas_for_polishing',
            'calc_pred_curve', 'calc_pred_sld_profile', 'calc_polished_sld_profile',
            'sld_profile_padding_left', 'sld_profile_padding_right',
            'truncate_index_left', 'truncate_index_right', 'enable_error_bars_filtering',
            'filter_threshold', 'filter_remove_singles', 'filter_remove_consecutives',
            'filter_consecutive', 'filter_q_start_trunc'
        }

        # Parameters used for plotting
        plotting_param_names = {
            'show_error_bars', 'show_q_resolution', 'log_x_axis', 'plot_sld_profile',
            'exp_color', 'exp_errcolor', 'pred_color', 'pol_color',
            'sld_pred_color', 'sld_pol_color'
        }

        prediction_params = {k: v for k, v in settings.items() if k in prediction_param_names}
        plotting_params = {k: v for k, v in settings.items() if k in plotting_param_names}

        return prediction_params, plotting_params

    @staticmethod
    def _find_widgets_by_description(container, descriptions):
        """Return all widgets under ``container`` (depth-first, pre-order)
        whose ``description`` is in ``descriptions``."""
        found_widgets = []

        def search_widget(widget):
            if hasattr(widget, 'description') and widget.description in descriptions:
                found_widgets.append(widget)
            if hasattr(widget, 'children'):
                for child in widget.children:
                    search_widget(child)

        search_widget(container)
        return found_widgets
757
+
758
+