reflectorch 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,625 @@
1
+ """
2
+ Reflectorch Jupyter Widget
3
+ """
4
+ import torch
5
+ import numpy as np
6
+ from typing import Optional, Union, TYPE_CHECKING
7
+
8
+ from torch.types import Device
9
+
10
+ if TYPE_CHECKING:
11
+ from reflectorch.inference.inference_model import InferenceModel
12
+ import ipywidgets as widgets
13
+ from IPython.display import display
14
+
15
+ from reflectorch.extensions.jupyter.plotly_plot_manager import (
16
+ PlotlyPlotManager,
17
+ plot_reflectivity_only,
18
+ plot_sld_only,
19
+ )
20
+ from reflectorch.extensions.jupyter.components import (
21
+ ParameterTable,
22
+ PreprocessingControls,
23
+ PredictionControls,
24
+ PlottingControls,
25
+ AdditionalParametersControls,
26
+ WidgetSettingsExtractor,
27
+ )
28
+ from reflectorch.extensions.jupyter.model_selection import ModelSelection
29
+ from reflectorch.extensions.jupyter.log_widget import LogWidget
30
+
31
+ from huggingface_hub.utils import disable_progress_bars
32
+
33
# Hugging Face Hub download progress bars have been reported to trigger
# Rust-related errors while fetching models, so switch them off once, at import.
disable_progress_bars()
35
+
36
+
37
class ReflectorchPlotlyWidget:
    """
    Interactive Jupyter Widget for Reflectometry Analysis using Plotly

    The experimental data (and optionally a model) are passed to the
    constructor; ``display()`` then builds the tabbed controls, the
    reflectivity/SLD plots and the log area, and wires up event handlers.

    Attributes:
        model: The InferenceModel instance (``None`` until one is passed in
            or loaded from the Models tab)
        prediction_result: Latest prediction results, or ``None``
        plot_manager: PlotlyPlotManager for handling interactive plots
        plot_width: Width in pixels used for every plot (set by ``display()``)
        plot_height: Height in pixels used for every plot (set by ``display()``)

    Example:
        ```python
        from reflectorch.inference import InferenceModel
        from reflectorch.extensions.jupyter import ReflectorchPlotlyWidget

        model = InferenceModel('config.yaml')
        widget = ReflectorchPlotlyWidget(
            reflectivity_curve=data,
            q_values=q_values,
            sigmas=sigmas,
            model=model,
        )

        widget.display(plot_width=400, plot_height=300)

        # Access results
        results = widget.prediction_result
        ```
    """

    # Fallback plot size in pixels; matches the defaults of ``display()``.
    _DEFAULT_PLOT_WIDTH = 400
    _DEFAULT_PLOT_HEIGHT = 300

    def __init__(self,
                 reflectivity_curve: np.ndarray,
                 q_values: np.ndarray,
                 sigmas: Optional[np.ndarray] = None,
                 q_resolution: Optional[Union[float, np.ndarray]] = None,
                 initial_prior_bounds: Optional[np.ndarray] = None,
                 ambient_sld: Optional[float] = None,
                 model: Optional["InferenceModel"] = None,
                 root_dir: Optional[str] = None,
                 ):
        """
        Initialize the Reflectorch Plotly widget

        Args:
            reflectivity_curve: Experimental reflectivity data
            q_values: Momentum transfer values
            sigmas: Experimental uncertainties (optional)
            q_resolution: Q-resolution, float or array (optional)
            initial_prior_bounds: Initial bounds for priors, shape (n_params, 2)
            ambient_sld: Ambient SLD value (optional)
            model: InferenceModel instance for making predictions (optional)
            root_dir: Root directory for the model (optional)

        Raises:
            ValueError: if ``model`` is given but lacks ``trainer`` or
                ``preprocess_and_predict``.
        """
        self.model = model
        self.prediction_result = None
        self.plot_manager = PlotlyPlotManager()
        self.root_dir = root_dir
        # Store data for prediction; also used as the fallback source of
        # settings by WidgetSettingsExtractor.
        self._data = {
            'reflectivity_curve': reflectivity_curve,
            'q_values': q_values,
            'sigmas': sigmas,
            'q_resolution': q_resolution,
            'ambient_sld': ambient_sld
        }

        self.initial_prior_bounds = initial_prior_bounds

        # Plot size; overwritten by display() and reused by every redraw.
        self.plot_width = self._DEFAULT_PLOT_WIDTH
        self.plot_height = self._DEFAULT_PLOT_HEIGHT

        # Widget components (initialized when display is called)
        self.parameter_table = None
        self.preprocessing_controls = None
        self.prediction_controls = None
        self.plotting_controls = None
        self.additional_params_controls = None
        self.model_selection = None
        self.tabs_widget = None  # Store reference to tabs for updating
        self.predict_button = None  # Store reference to predict button
        self.log_widget = None  # Store reference to log widget
        self.tab_indices = {}  # tab title -> tab index, filled by display()

        if self.model is not None:
            self._validate_model()

    def _resolve_plot_size(self, plot_width=None, plot_height=None):
        """
        Return the ``(width, height)`` in pixels to use for a plot.

        Explicit arguments win; ``None`` falls back to the size stored by
        ``display()`` (or the class defaults when it has not been stored).
        """
        if plot_width is None:
            plot_width = getattr(self, 'plot_width', self._DEFAULT_PLOT_WIDTH)
        if plot_height is None:
            plot_height = getattr(self, 'plot_height', self._DEFAULT_PLOT_HEIGHT)
        return plot_width, plot_height

    def _create_parameter_components(self, initial_prior_bounds=None, predict_button=None):
        """
        Create parameter table and additional parameters controls based on current model

        Args:
            initial_prior_bounds: Initial bounds for priors, shape (n_params, 2)
            predict_button: Optional predict button to include in parameter table

        Returns:
            Tuple of (parameter_table, additional_params_controls)
        """
        # Create additional params controls first so we can pass it to parameter table
        additional_params_controls = AdditionalParametersControls(self.model)

        # Get model parameters info if model is available
        if self.model is not None:
            prior_sampler = self.model.trainer.loader.prior_sampler
            param_labels = prior_sampler.param_model.get_param_labels()
            min_bounds = prior_sampler.min_bounds.cpu().numpy().flatten()
            max_bounds = prior_sampler.max_bounds.cpu().numpy().flatten()
            max_deltas = prior_sampler.max_delta.cpu().numpy().flatten()
        else:
            # Default empty parameters when no model is loaded
            param_labels = []
            min_bounds = np.array([])
            max_bounds = np.array([])
            max_deltas = np.array([])

        parameter_table = ParameterTable(
            param_labels, min_bounds, max_bounds, max_deltas, initial_prior_bounds,
            additional_params_controls=additional_params_controls,
            predict_button=predict_button
        )

        return parameter_table, additional_params_controls

    def _validate_model(self):
        """
        Validate that the model has required attributes.

        Raises:
            ValueError: naming the first missing attribute out of
                ``trainer`` and ``preprocess_and_predict``.
        """
        required_attrs = ['trainer', 'preprocess_and_predict']
        for attr in required_attrs:
            if not hasattr(self.model, attr):
                raise ValueError(f"Model must have '{attr}' attribute")

    def display(self,
                controls_width: int = 700,
                plot_width: int = 400,
                plot_height: int = 300
                ):
        """
        Build and display the widget interface.

        The experimental data, prior bounds and optional model come from the
        constructor; this method only controls the layout.

        Args:
            controls_width: Width of the controls area in pixels. Default is 700px.
            plot_width: Width of the plots in pixels. Default is 400px.
            plot_height: Height of the plots in pixels. Default is 300px.
        """
        # Remember the plot size so redraws after a prediction or a plotting
        # control change keep the same dimensions as the initial plots.
        self.plot_width = plot_width
        self.plot_height = plot_height

        # Create predict button (disabled by default if no model)
        self.predict_button = widgets.Button(
            description="Predict",
            button_style='primary',
            tooltip='Run prediction with current settings' if self.model else 'Load a model first',
            layout=widgets.Layout(width='120px'),
            disabled=(self.model is None)
        )

        # Create widget components
        self.parameter_table, self.additional_params_controls = self._create_parameter_components(
            self.initial_prior_bounds, self.predict_button
        )
        self.preprocessing_controls = PreprocessingControls(len(self._data['reflectivity_curve']))
        self.prediction_controls = PredictionControls()
        self.plotting_controls = PlottingControls()
        self.model_selection = ModelSelection()

        # Create log widget
        self.log_widget = LogWidget()

        # Create tabbed interface
        # if model is not provided, then Models tabs goes first.
        tabs = widgets.Tab()
        if self.model is None:
            tabs.children = [
                self.model_selection.widget,
                self.parameter_table.widget,
                self.preprocessing_controls.widget,
                self.prediction_controls.widget,
                self.plotting_controls.widget
            ]
            tabs.titles = ['Models', 'Parameters', 'Preprocessing', 'Prediction', 'Plotting']
        else:
            tabs.children = [
                self.parameter_table.widget,
                self.preprocessing_controls.widget,
                self.prediction_controls.widget,
                self.plotting_controls.widget,
                self.model_selection.widget
            ]
            tabs.titles = ['Parameters', 'Preprocessing', 'Prediction', 'Plotting', 'Models']

        self.tab_indices = dict(zip(tabs.titles, range(len(tabs.titles))))
        # Store reference to tabs for later updates
        self.tabs_widget = tabs

        # Create plot containers (initially empty)
        reflectivity_plot_container = widgets.VBox([])
        sld_plot_container = widgets.VBox([])

        # Combine plots vertically on the right
        plot_area = widgets.VBox([
            reflectivity_plot_container,
            sld_plot_container
        ], layout=widgets.Layout(margin='50px 0px 0px 0px'))

        # Main layout with controls on left, plots on right
        header = widgets.HTML("<h2>Reflectorch Widget</h2>")

        controls_area = widgets.VBox([
            header,
            tabs
        ], layout=widgets.Layout(width=f'{controls_width}px'))

        # Horizontal layout: controls on left, plots on right
        main_content = widgets.HBox([
            controls_area,
            plot_area
        ])

        # Complete layout with log at the bottom
        main_layout = widgets.VBox([
            main_content,
            self.log_widget.widget
        ])

        # Add border around the entire widget
        container = widgets.VBox([main_layout], layout=widgets.Layout(
            border='2px solid #d0d0d0',
            border_radius='8px',
            padding='15px',
            margin='10px',
            background_color='#fafafa'
        ))
        display(container)

        # Setup event handlers
        self._setup_event_handlers(self.predict_button, reflectivity_plot_container, sld_plot_container, container)

        # Setup model selection integration
        self._setup_model_selection_integration()

        # Setup truncation synchronization
        self._setup_truncation_sync()

        # Create initial plots with experimental data
        self._create_initial_plots(reflectivity_plot_container, sld_plot_container, plot_width, plot_height)

        # Setup reactive plot updates for plotting controls
        self._setup_reactive_plot_updates(reflectivity_plot_container, sld_plot_container)

    def _create_initial_plots(self, reflectivity_container, sld_container, plot_width, plot_height):
        """Create initial plots showing experimental data; errors are logged, not raised."""
        try:
            # Use default settings for initial plots
            settings = {
                'show_error_bars': True,
                'show_q_resolution': True,
                'exp_color': 'blue',
                'exp_errcolor': 'purple',
                'log_x_axis': False,
                'plot_sld_profile': True
            }

            # Plot initial experimental data
            self._plot_initial_data(settings, reflectivity_container, sld_container, plot_width, plot_height)

        except Exception as e:
            if self.log_widget:
                self.log_widget.log(f"⚠️ Could not create initial plots: {str(e)}")
            else:
                print(f"⚠️ Could not create initial plots: {str(e)}")

    def _plot_initial_data(self, settings, reflectivity_container, sld_container,
                           plot_width=None, plot_height=None):
        """
        Plot only experimental data before any prediction.

        Args:
            settings: Plotting settings; reads 'show_error_bars',
                'show_q_resolution', 'exp_color', 'exp_errcolor',
                'log_x_axis' and 'plot_sld_profile'.
            reflectivity_container: VBox that receives the reflectivity plot.
            sld_container: VBox that receives the (empty) SLD plot.
            plot_width: Plot width in pixels; ``None`` uses the stored size.
            plot_height: Plot height in pixels; ``None`` uses the stored size.
        """
        plot_width, plot_height = self._resolve_plot_size(plot_width, plot_height)

        # Prepare experimental data for plotting
        q_exp_plot = self._data['q_values']
        r_exp_plot = self._data['reflectivity_curve']
        yerr_plot = self._data['sigmas'] if settings['show_error_bars'] and self._data['sigmas'] is not None else None
        xerr_plot = self._data['q_resolution'] if settings['show_q_resolution'] and self._data['q_resolution'] is not None else None

        # Create reflectivity plot
        plot_reflectivity_only(
            plot_manager=self.plot_manager,
            figure_id="reflectivity_plot",
            q_exp=q_exp_plot, r_exp=r_exp_plot, yerr=yerr_plot, xerr=xerr_plot,
            exp_color=settings['exp_color'],
            exp_errcolor=settings['exp_errcolor'],
            exp_label='experimental data',
            logx=settings['log_x_axis'], logy=True,
            width=plot_width, height=plot_height
        )

        # Get the reflectivity plotly widget and add it to container
        reflectivity_widget = self.plot_manager.get_widget("reflectivity_plot")
        reflectivity_container.children = [reflectivity_widget]

        # Create empty SLD plot (will be populated after prediction)
        if settings['plot_sld_profile']:
            plot_sld_only(
                plot_manager=self.plot_manager,
                figure_id="sld_plot",
                z_sld=None, sld_pred=None, sld_pol=None,
                width=plot_width, height=plot_height
            )

            # Get the SLD plotly widget and add it to container
            sld_widget = self.plot_manager.get_widget("sld_plot")
            sld_container.children = [sld_widget]

    def _setup_event_handlers(self, predict_button, reflectivity_container, sld_container, container):
        """Attach the predict-button click handler (``container`` is currently unused)."""

        def on_predict(_):
            """Handle predict button click"""
            with self.log_widget.capture_prints():
                # Store original button state
                original_description = predict_button.description
                original_disabled = predict_button.disabled

                try:
                    # Check if model is loaded
                    if self.model is None:
                        print("❌ No model loaded. Please load a model from the Models tab first.")
                        return

                    # Disable button and show "Predicting..."
                    predict_button.disabled = True
                    predict_button.description = "Predicting..."

                    # Extract settings from all components with data fallback
                    settings = WidgetSettingsExtractor.extract_settings(
                        self.parameter_table,
                        self.preprocessing_controls,
                        self.prediction_controls,
                        self.plotting_controls,
                        self.additional_params_controls,
                        data=self._data
                    )

                    # Separate prediction and plotting parameters
                    prediction_params, plotting_params = WidgetSettingsExtractor.separate_settings(settings)

                    # Run prediction with all parameters
                    prediction_result = self.model.preprocess_and_predict(**prediction_params)

                    # Update parameter table with results
                    self.parameter_table.update_results(prediction_result)

                    # Plot results
                    self._plot_results(prediction_result, plotting_params, reflectivity_container, sld_container)

                    # Store results
                    self.prediction_result = prediction_result

                except Exception as e:
                    print(f"❌ Prediction error: {str(e)}")
                    import traceback
                    traceback.print_exc()
                finally:
                    # Always restore button state, even if there was an error
                    predict_button.description = original_description
                    predict_button.disabled = original_disabled

        # Connect event handlers
        predict_button.on_click(on_predict)

    def _setup_model_selection_integration(self):
        """Wire the Models tab download button to load a new InferenceModel."""
        if not self.model_selection:
            return

        load_button = self.model_selection._widgets['download_button']

        def on_load_model(_):
            """Handle load model button click"""
            with self.log_widget.capture_prints():
                # Store original button state
                original_description = load_button.description
                original_disabled = load_button.disabled

                try:
                    selected_model_info = self.model_selection.get_selected_model_info()
                    if selected_model_info is None:
                        print("❌ No model selected")
                        return

                    # Disable button and show "Downloading..."
                    load_button.disabled = True
                    load_button.description = "Downloading..."

                    print(f"🔄 Loading model: {selected_model_info['model_name']} ...")

                    from reflectorch.inference.inference_model import InferenceModel

                    device: Device = 'cuda' if torch.cuda.is_available() else 'cpu'

                    # Create new model instance
                    new_model = InferenceModel(
                        config_name=selected_model_info['config_name'],
                        repo_id=selected_model_info['repo_id'],
                        device=device,
                        root_dir=self.root_dir
                    )

                    print(f"📥 Model downloaded and initialized successfully")

                    # Replace current model
                    self.model = new_model

                    # Enable the predict button since we now have a model
                    if self.predict_button:
                        self.predict_button.disabled = False
                        self.predict_button.tooltip = 'Run prediction with current settings'

                    # Create new parameter components for the new model (reuse the same predict button)
                    new_parameter_table, new_additional_params_controls = self._create_parameter_components(predict_button=self.predict_button)

                    # Update the parameter table title with the new model name
                    new_parameter_table.param_title.value = f"<h4>Parameter Configuration for {selected_model_info['model_name']}</h4>"

                    # Update the parameter table in the tabs
                    tabs_available = self.tabs_widget is not None and hasattr(self.tabs_widget, 'children')
                    if tabs_available:
                        children_list = list(self.tabs_widget.children)
                        children_list[self.tab_indices['Parameters']] = new_parameter_table.widget
                        self.tabs_widget.children = children_list

                    # Update our references to the components
                    self.parameter_table = new_parameter_table
                    self.additional_params_controls = new_additional_params_controls

                    # Switch to Parameters tab automatically (only if the tabs exist)
                    if tabs_available:
                        self.tabs_widget.selected_index = self.tab_indices['Parameters']

                    # Clear previous prediction results
                    self.prediction_result = None

                    # Get parameter info for success message
                    param_labels = self.model.trainer.loader.prior_sampler.param_model.get_param_labels()
                    max_layers = self.model.trainer.loader.prior_sampler.max_num_layers

                    print(f"✅ Model loaded successfully: {selected_model_info['config_name']}")
                    print(f"Parameters: {len(param_labels)} parameters, {max_layers} max layers")
                    print("💡 Tip: Go to the Parameters tab to see the updated parameter ranges for the new model")

                except Exception as e:
                    print(f"❌ Error loading model: {str(e)}")
                    import traceback
                    traceback.print_exc()
                finally:
                    # Always restore button state, even if there was an error
                    load_button.description = original_description
                    load_button.disabled = original_disabled

        # Connect the load button
        load_button.on_click(on_load_model)

    def _plot_results(self, prediction_result, settings, reflectivity_container, sld_container,
                      plot_width=None, plot_height=None):
        """
        Plot prediction results with current settings.

        Args:
            prediction_result: Dict returned by ``model.preprocess_and_predict``;
                optional keys 'q_plot_pred', 'predicted_curve',
                'polished_curve', 'predicted_sld_xaxis',
                'predicted_sld_profile', 'sld_profile_polished' are read.
            settings: Plotting settings (colors, toggles, 'log_x_axis').
            reflectivity_container: Unused here; kept for signature parity.
            sld_container: Unused here; kept for signature parity.
            plot_width: Plot width in pixels; ``None`` uses the size given to
                ``display()`` so plots keep their size after a prediction.
            plot_height: Plot height in pixels; ``None`` uses the stored size.
        """
        plot_width, plot_height = self._resolve_plot_size(plot_width, plot_height)

        # Prepare plotting data
        q_exp_plot = self._data['q_values']
        r_exp_plot = self._data['reflectivity_curve']
        yerr_plot = self._data['sigmas'] if settings['show_error_bars'] else None
        xerr_plot = self._data['q_resolution'] if settings['show_q_resolution'] else None

        q_pred = prediction_result.get('q_plot_pred', None)
        r_pred = prediction_result.get('predicted_curve', None)
        q_pol = self._data['q_values'] if 'polished_curve' in prediction_result else None
        r_pol = prediction_result.get('polished_curve', None)

        z_sld = prediction_result.get('predicted_sld_xaxis', None)
        sld_pred = prediction_result.get('predicted_sld_profile', None)
        sld_pol = prediction_result.get('sld_profile_polished', None)

        # Handle complex SLD: only the real part is plotted
        if sld_pred is not None and np.iscomplexobj(sld_pred):
            sld_pred = sld_pred.real
        if sld_pol is not None and np.iscomplexobj(sld_pol):
            sld_pol = sld_pol.real

        # Update reflectivity plot
        plot_reflectivity_only(
            plot_manager=self.plot_manager,
            figure_id="reflectivity_plot",
            q_exp=q_exp_plot, r_exp=r_exp_plot, yerr=yerr_plot, xerr=xerr_plot,
            exp_color=settings['exp_color'],
            exp_errcolor=settings['exp_errcolor'],
            q_pred=q_pred, r_pred=r_pred, pred_color=settings['pred_color'],
            q_pol=q_pol, r_pol=r_pol, pol_color=settings['pol_color'],
            logx=settings['log_x_axis'], logy=True,
            width=plot_width, height=plot_height
        )

        # Update SLD plot if requested
        if settings['plot_sld_profile']:
            plot_sld_only(
                plot_manager=self.plot_manager,
                figure_id="sld_plot",
                z_sld=z_sld, sld_pred=sld_pred, sld_pol=sld_pol,
                sld_pred_color=settings['sld_pred_color'],
                sld_pol_color=settings['sld_pol_color'],
                width=plot_width, height=plot_height
            )

    def _setup_truncation_sync(self):
        """Keep the left truncation index strictly below the right one."""
        # Find truncation widgets
        trunc_widgets = WidgetSettingsExtractor._find_widgets_by_description(
            self.preprocessing_controls.widget, ['Left index:', 'Right index:']
        )

        if len(trunc_widgets) == 2:
            trunc_left, trunc_right = trunc_widgets

            def sync_truncation(_):
                if trunc_left.value >= trunc_right.value:
                    trunc_left.value = max(0, trunc_right.value - 1)

            trunc_left.observe(sync_truncation, names='value')
            trunc_right.observe(sync_truncation, names='value')

    def _setup_reactive_plot_updates(self, reflectivity_container, sld_container):
        """Setup observers for plotting controls that should trigger immediate plot updates"""
        if not self.plotting_controls:
            return

        # Find plotting controls that should trigger plot updates
        reactive_controls = WidgetSettingsExtractor._find_widgets_by_description(
            self.plotting_controls.widget,
            [
                'Show error bars', 'Show q-resolution', 'Log x-axis', 'Plot SLD profile',
                'Data color:', 'Error bars:', 'Prediction:', 'Polished:',
                'SLD pred:', 'SLD polish:'
            ]
        )

        def update_plot_on_change(change):
            """Update plot when plotting controls change"""
            # Only update if we have prediction results to show
            if self.prediction_result is not None:
                try:
                    # Extract current settings
                    settings = WidgetSettingsExtractor.extract_settings(
                        self.parameter_table,
                        self.preprocessing_controls,
                        self.prediction_controls,
                        self.plotting_controls
                    )

                    # Update plot with new settings
                    self._plot_results(self.prediction_result, settings, reflectivity_container, sld_container)

                except Exception as e:
                    if self.log_widget:
                        self.log_widget.log(f"⚠️ Error updating plot: {str(e)}")
                    else:
                        print(f"⚠️ Error updating plot: {str(e)}")
            else:
                # If no prediction results yet, just update the initial plot
                try:
                    self._update_initial_plot_style(reflectivity_container, sld_container)
                except Exception as e:
                    if self.log_widget:
                        self.log_widget.log(f"⚠️ Error updating initial plot: {str(e)}")
                    else:
                        print(f"⚠️ Error updating initial plot: {str(e)}")

        # Setup observers for all reactive controls
        for control in reactive_controls:
            if hasattr(control, 'observe'):
                control.observe(update_plot_on_change, names='value')

    def _update_initial_plot_style(self, reflectivity_container, sld_container):
        """Redraw the data-only plots using the current plotting control settings."""
        if not self.plotting_controls:
            return

        try:
            # Extract current plotting settings
            settings = WidgetSettingsExtractor.extract_settings(
                self.parameter_table,
                self.preprocessing_controls,
                self.prediction_controls,
                self.plotting_controls
            )

            # Update initial plot with new styling; pass the stored plot size
            # explicitly (previously omitted, which raised a TypeError).
            self._plot_initial_data(
                settings, reflectivity_container, sld_container,
                self.plot_width, self.plot_height
            )

        except Exception as e:
            if self.log_widget:
                self.log_widget.log(f"⚠️ Error updating initial plot style: {str(e)}")
            else:
                print(f"⚠️ Error updating initial plot style: {str(e)}")
@@ -0,0 +1,5 @@
1
+ from reflectorch.extensions.matplotlib.losses import plot_losses
2
+
3
# Public API of the matplotlib extension package.
__all__ = ["plot_losses"]
@@ -0,0 +1,32 @@
1
+ import matplotlib.pyplot as plt
2
+
3
+
4
+ def plot_losses(
5
+ losses: dict,
6
+ log: bool = False,
7
+ show: bool = True,
8
+ title: str = 'Losses',
9
+ x_label: str = 'Iterations',
10
+ best_epoch: float = None,
11
+ **kwargs
12
+ ):
13
+ func = plt.semilogy if log else plt.plot
14
+
15
+ if len(losses) <= 2:
16
+ losses = {'loss': losses['total_loss']}
17
+
18
+ for k, data in losses.items():
19
+ func(data, label=k, **kwargs)
20
+
21
+ if best_epoch is not None:
22
+ plt.axvline(best_epoch, ls='--', color='red')
23
+
24
+ plt.xlabel(x_label)
25
+
26
+ if len(losses) > 2:
27
+ plt.legend()
28
+
29
+ plt.title(title)
30
+
31
+ if show:
32
+ plt.show()