reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -128
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -280
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -223
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -1374
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +36 -36
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +523 -516
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -19
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -262
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -200
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -15
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -19
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -434
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -404
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +97 -97
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.4.0.dist-info/RECORD +0 -88
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,625 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Reflectorch Jupyter Widget
|
|
3
|
+
"""
|
|
4
|
+
import torch
|
|
5
|
+
import numpy as np
|
|
6
|
+
from typing import Optional, Union, TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from torch.types import Device
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from reflectorch.inference.inference_model import InferenceModel
|
|
12
|
+
import ipywidgets as widgets
|
|
13
|
+
from IPython.display import display
|
|
14
|
+
|
|
15
|
+
from reflectorch.extensions.jupyter.plotly_plot_manager import (
|
|
16
|
+
PlotlyPlotManager,
|
|
17
|
+
plot_reflectivity_only,
|
|
18
|
+
plot_sld_only,
|
|
19
|
+
)
|
|
20
|
+
from reflectorch.extensions.jupyter.components import (
|
|
21
|
+
ParameterTable,
|
|
22
|
+
PreprocessingControls,
|
|
23
|
+
PredictionControls,
|
|
24
|
+
PlottingControls,
|
|
25
|
+
AdditionalParametersControls,
|
|
26
|
+
WidgetSettingsExtractor,
|
|
27
|
+
)
|
|
28
|
+
from reflectorch.extensions.jupyter.model_selection import ModelSelection
|
|
29
|
+
from reflectorch.extensions.jupyter.log_widget import LogWidget
|
|
30
|
+
|
|
31
|
+
from huggingface_hub.utils import disable_progress_bars
|
|
32
|
+
|
|
33
|
+
# Turn off huggingface_hub progress bars: with them enabled, downloading
# models from Hugging Face causes some Rust-related errors.
disable_progress_bars()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ReflectorchPlotlyWidget:
    """
    Interactive Jupyter Widget for Reflectometry Analysis using Plotly

    The experimental data (and optionally a model) are given to the
    constructor; ``display()`` builds and shows the tabbed controls, the
    reflectivity / SLD plots and a log area. Predictions run when the
    "Predict" button is clicked. A model can also be loaded later from the
    "Models" tab.

    Attributes:
        model: The InferenceModel instance (None until one is passed in or
            loaded from the Models tab)
        prediction_result: Latest prediction results (None before the first
            prediction and after a new model has been loaded)
        plot_manager: PlotlyPlotManager for handling interactive plots
        root_dir: Root directory passed to InferenceModel when loading models
        plot_width: Width in pixels used for every (re)drawn plot
        plot_height: Height in pixels used for every (re)drawn plot

    Example:
        ```python
        from reflectorch.inference import InferenceModel
        from reflectorch.extensions.jupyter import ReflectorchPlotlyWidget

        model = InferenceModel('config.yaml')
        widget = ReflectorchPlotlyWidget(
            reflectivity_curve=data,
            q_values=q_values,
            sigmas=sigmas,
            model=model,
        )
        widget.display()

        # Access results (after clicking "Predict")
        results = widget.prediction_result
        ```
    """

    # Plot size used until display() records the user's choice
    # (matches the defaults of display()).
    _DEFAULT_PLOT_WIDTH = 400
    _DEFAULT_PLOT_HEIGHT = 300

    def __init__(self,
                 reflectivity_curve: np.ndarray,
                 q_values: np.ndarray,
                 sigmas: Optional[np.ndarray] = None,
                 q_resolution: Optional[Union[float, np.ndarray]] = None,
                 initial_prior_bounds: Optional[np.ndarray] = None,
                 ambient_sld: Optional[float] = None,
                 model: Optional["InferenceModel"] = None,
                 root_dir: Optional[str] = None,
                 ):
        """
        Initialize the Reflectorch Plotly widget

        Args:
            reflectivity_curve: Experimental reflectivity data
            q_values: Momentum transfer values
            sigmas: Experimental uncertainties (optional)
            q_resolution: Q-resolution, float or array (optional)
            initial_prior_bounds: Initial bounds for priors, shape (n_params, 2)
            ambient_sld: Ambient SLD value (optional)
            model: InferenceModel instance for making predictions (optional)
            root_dir: Root directory for the model (optional)

        Raises:
            ValueError: if ``model`` is given but lacks a required attribute.
        """
        self.model = model
        self.prediction_result = None
        self.plot_manager = PlotlyPlotManager()
        self.root_dir = root_dir
        # Store data for prediction
        self._data = {
            'reflectivity_curve': reflectivity_curve,
            'q_values': q_values,
            'sigmas': sigmas,
            'q_resolution': q_resolution,
            'ambient_sld': ambient_sld
        }

        self.initial_prior_bounds = initial_prior_bounds

        # Plot size; overwritten by display() with the requested values
        self.plot_width = self._DEFAULT_PLOT_WIDTH
        self.plot_height = self._DEFAULT_PLOT_HEIGHT

        # Widget components (initialized when display is called)
        self.parameter_table = None
        self.preprocessing_controls = None
        self.prediction_controls = None
        self.plotting_controls = None
        self.additional_params_controls = None
        self.model_selection = None
        self.tabs_widget = None  # Store reference to tabs for updating
        self.predict_button = None  # Store reference to predict button
        self.log_widget = None  # Store reference to log widget
        self.tab_indices = {}  # Tab title -> tab index, filled by display()

        if self.model is not None:
            self._validate_model()

    def _resolve_plot_size(self, width=None, height=None):
        """
        Return ``(width, height)`` for a plot, falling back to the stored size.

        Args:
            width: Explicit width in pixels, or None to use ``self.plot_width``
            height: Explicit height in pixels, or None to use ``self.plot_height``
        """
        return (
            self.plot_width if width is None else width,
            self.plot_height if height is None else height,
        )

    def _log_warning(self, message):
        """Send ``message`` to the log widget if it exists, else print it."""
        if self.log_widget:
            self.log_widget.log(message)
        else:
            print(message)

    def _create_parameter_components(self, initial_prior_bounds=None, predict_button=None):
        """
        Create parameter table and additional parameters controls based on current model

        Args:
            initial_prior_bounds: Initial bounds for priors, shape (n_params, 2)
            predict_button: Optional predict button to include in parameter table

        Returns:
            Tuple of (parameter_table, additional_params_controls)
        """
        # Create additional params controls first so we can pass it to parameter table
        additional_params_controls = AdditionalParametersControls(self.model)

        # Get model parameters info if model is available
        if self.model is not None:
            prior_sampler = self.model.trainer.loader.prior_sampler
            param_labels = prior_sampler.param_model.get_param_labels()
            min_bounds = prior_sampler.min_bounds.cpu().numpy().flatten()
            max_bounds = prior_sampler.max_bounds.cpu().numpy().flatten()
            max_deltas = prior_sampler.max_delta.cpu().numpy().flatten()
        else:
            # Default empty parameters when no model is loaded
            param_labels = []
            min_bounds = np.array([])
            max_bounds = np.array([])
            max_deltas = np.array([])

        parameter_table = ParameterTable(
            param_labels, min_bounds, max_bounds, max_deltas, initial_prior_bounds,
            additional_params_controls=additional_params_controls,
            predict_button=predict_button
        )

        return parameter_table, additional_params_controls

    def _validate_model(self):
        """
        Validate that the model has required attributes

        Raises:
            ValueError: if ``self.model`` lacks ``trainer`` or
                ``preprocess_and_predict``.
        """
        required_attrs = ['trainer', 'preprocess_and_predict']
        for attr in required_attrs:
            if not hasattr(self.model, attr):
                raise ValueError(f"Model must have '{attr}' attribute")

    def _extract_separated_settings(self):
        """
        Read the current state of all control components.

        This is the single place where widget settings are extracted, so the
        predict button and the reactive plot updates read them the same way.

        Returns:
            Tuple ``(prediction_params, plotting_params)`` as produced by
            ``WidgetSettingsExtractor.separate_settings``.
        """
        settings = WidgetSettingsExtractor.extract_settings(
            self.parameter_table,
            self.preprocessing_controls,
            self.prediction_controls,
            self.plotting_controls,
            self.additional_params_controls,
            data=self._data
        )
        return WidgetSettingsExtractor.separate_settings(settings)

    def display(self,
                controls_width: int = 700,
                plot_width: int = 400,
                plot_height: int = 300
                ):
        """
        Display the widget interface

        The experimental data and the optional model come from the
        constructor; this method only controls the layout.

        Parameters:
        ----------
            controls_width: Width of the controls area in pixels. Default is 700px.
            plot_width: Width of the plots in pixels. Default is 400px.
            plot_height: Height of the plots in pixels. Default is 300px.
        """
        # Remember the plot size so later redraws (after a prediction or a
        # change of the plotting controls) keep the requested dimensions.
        self.plot_width = plot_width
        self.plot_height = plot_height

        # Create predict button (disabled by default if no model)
        self.predict_button = widgets.Button(
            description="Predict",
            button_style='primary',
            tooltip='Run prediction with current settings' if self.model else 'Load a model first',
            layout=widgets.Layout(width='120px'),
            disabled=(self.model is None)
        )

        # Create widget components
        self.parameter_table, self.additional_params_controls = self._create_parameter_components(
            self.initial_prior_bounds, self.predict_button
        )
        self.preprocessing_controls = PreprocessingControls(len(self._data['reflectivity_curve']))
        self.prediction_controls = PredictionControls()
        self.plotting_controls = PlottingControls()
        self.model_selection = ModelSelection()

        # Create log widget
        self.log_widget = LogWidget()

        # Create tabbed interface
        # if model is not provided, then Models tabs goes first.
        tabs = widgets.Tab()
        if self.model is None:
            tabs.children = [
                self.model_selection.widget,
                self.parameter_table.widget,
                self.preprocessing_controls.widget,
                self.prediction_controls.widget,
                self.plotting_controls.widget
            ]
            tabs.titles = ['Models', 'Parameters', 'Preprocessing', 'Prediction', 'Plotting']
        else:
            tabs.children = [
                self.parameter_table.widget,
                self.preprocessing_controls.widget,
                self.prediction_controls.widget,
                self.plotting_controls.widget,
                self.model_selection.widget
            ]
            tabs.titles = ['Parameters', 'Preprocessing', 'Prediction', 'Plotting', 'Models']

        self.tab_indices = {title: index for index, title in enumerate(tabs.titles)}
        # Store reference to tabs for later updates
        self.tabs_widget = tabs

        # Create plot containers (initially empty)
        reflectivity_plot_container = widgets.VBox([])
        sld_plot_container = widgets.VBox([])

        # Combine plots vertically on the right
        plot_area = widgets.VBox([
            reflectivity_plot_container,
            sld_plot_container
        ], layout=widgets.Layout(margin='50px 0px 0px 0px'))

        # Main layout with controls on left, plots on right
        header = widgets.HTML("<h2>Reflectorch Widget</h2>")

        controls_area = widgets.VBox([
            header,
            tabs
        ], layout=widgets.Layout(width=f'{controls_width}px'))

        # Horizontal layout: controls on left, plots on right
        main_content = widgets.HBox([
            controls_area,
            plot_area
        ])

        # Complete layout with log at the bottom
        main_layout = widgets.VBox([
            main_content,
            self.log_widget.widget
        ])

        # Add border around the entire widget
        container = widgets.VBox([main_layout], layout=widgets.Layout(
            border='2px solid #d0d0d0',
            border_radius='8px',
            padding='15px',
            margin='10px',
            background_color='#fafafa'
        ))
        display(container)

        # Setup event handlers
        self._setup_event_handlers(self.predict_button, reflectivity_plot_container, sld_plot_container, container)

        # Setup model selection integration
        self._setup_model_selection_integration()

        # Setup truncation synchronization
        self._setup_truncation_sync()

        # Create initial plots with experimental data
        self._create_initial_plots(reflectivity_plot_container, sld_plot_container, plot_width, plot_height)

        # Setup reactive plot updates for plotting controls
        self._setup_reactive_plot_updates(reflectivity_plot_container, sld_plot_container)

    def _create_initial_plots(self, reflectivity_container, sld_container, plot_width, plot_height):
        """Create initial plots showing experimental data (errors are logged, not raised)"""
        try:
            # Use default settings for initial plots
            settings = {
                'show_error_bars': True,
                'show_q_resolution': True,
                'exp_color': 'blue',
                'exp_errcolor': 'purple',
                'log_x_axis': False,
                'plot_sld_profile': True
            }

            # Plot initial experimental data
            self._plot_initial_data(settings, reflectivity_container, sld_container, plot_width, plot_height)

        except Exception as e:
            self._log_warning(f"⚠️ Could not create initial plots: {str(e)}")

    def _plot_initial_data(self, settings, reflectivity_container, sld_container,
                           plot_width=None, plot_height=None):
        """
        Plot only experimental data before any prediction

        Args:
            settings: Plotting settings; the keys read are ``show_error_bars``,
                ``show_q_resolution``, ``exp_color``, ``exp_errcolor``,
                ``log_x_axis`` and ``plot_sld_profile``
            reflectivity_container: VBox that receives the reflectivity plot widget
            sld_container: VBox that receives the SLD plot widget
            plot_width: Plot width in pixels (defaults to ``self.plot_width``)
            plot_height: Plot height in pixels (defaults to ``self.plot_height``)
        """
        plot_width, plot_height = self._resolve_plot_size(plot_width, plot_height)

        # Prepare experimental data for plotting
        q_exp_plot = self._data['q_values']
        r_exp_plot = self._data['reflectivity_curve']
        yerr_plot = self._data['sigmas'] if settings['show_error_bars'] and self._data['sigmas'] is not None else None
        xerr_plot = self._data['q_resolution'] if settings['show_q_resolution'] and self._data['q_resolution'] is not None else None

        # Create reflectivity plot
        plot_reflectivity_only(
            plot_manager=self.plot_manager,
            figure_id="reflectivity_plot",
            q_exp=q_exp_plot, r_exp=r_exp_plot, yerr=yerr_plot, xerr=xerr_plot,
            exp_color=settings['exp_color'],
            exp_errcolor=settings['exp_errcolor'],
            exp_label='experimental data',
            logx=settings['log_x_axis'], logy=True,
            width=plot_width, height=plot_height
        )

        # Get the reflectivity plotly widget and add it to container
        reflectivity_widget = self.plot_manager.get_widget("reflectivity_plot")
        reflectivity_container.children = [reflectivity_widget]

        # Create empty SLD plot (will be populated after prediction)
        if settings['plot_sld_profile']:
            plot_sld_only(
                plot_manager=self.plot_manager,
                figure_id="sld_plot",
                z_sld=None, sld_pred=None, sld_pol=None,
                width=plot_width, height=plot_height
            )

            # Get the SLD plotly widget and add it to container
            sld_widget = self.plot_manager.get_widget("sld_plot")
            sld_container.children = [sld_widget]

    def _setup_event_handlers(self, predict_button, reflectivity_container, sld_container, container):
        """
        Setup button event handlers

        Args:
            predict_button: Button whose clicks run a prediction
            reflectivity_container: VBox holding the reflectivity plot
            sld_container: VBox holding the SLD plot
            container: Outer widget container (currently unused)
        """

        def on_predict(_):
            """Handle predict button click"""
            with self.log_widget.capture_prints():
                # Store original button state
                original_description = predict_button.description
                original_disabled = predict_button.disabled

                try:
                    # Check if model is loaded
                    if self.model is None:
                        print("❌ No model loaded. Please load a model from the Models tab first.")
                        return

                    # Disable button and show "Predicting..."
                    predict_button.disabled = True
                    predict_button.description = "Predicting..."

                    # Extract settings from all components with data fallback,
                    # separated into prediction and plotting parameters
                    prediction_params, plotting_params = self._extract_separated_settings()

                    # Run prediction with all parameters
                    prediction_result = self.model.preprocess_and_predict(**prediction_params)

                    # Store results right away so they remain accessible even
                    # if updating the table or the plots fails below
                    self.prediction_result = prediction_result

                    # Update parameter table with results
                    self.parameter_table.update_results(prediction_result)

                    # Plot results
                    self._plot_results(prediction_result, plotting_params, reflectivity_container, sld_container)

                except Exception as e:
                    print(f"❌ Prediction error: {str(e)}")
                    import traceback
                    traceback.print_exc()
                finally:
                    # Always restore button state, even if there was an error
                    predict_button.description = original_description
                    predict_button.disabled = original_disabled

        # Connect event handlers
        predict_button.on_click(on_predict)

    def _setup_model_selection_integration(self):
        """Setup integration with model selection tab (loading a new model)"""
        if not self.model_selection:
            return

        load_button = self.model_selection._widgets['download_button']

        def on_load_model(_):
            """Handle load model button click"""
            with self.log_widget.capture_prints():
                # Store original button state
                original_description = load_button.description
                original_disabled = load_button.disabled

                try:
                    selected_model_info = self.model_selection.get_selected_model_info()
                    if selected_model_info is None:
                        print("❌ No model selected")
                        return

                    # Disable button and show "Downloading..."
                    load_button.disabled = True
                    load_button.description = "Downloading..."

                    print(f"🔄 Loading model: {selected_model_info['model_name']} ...")

                    # Imported here: the module-level import is TYPE_CHECKING-only,
                    # presumably to avoid an import cycle.
                    from reflectorch.inference.inference_model import InferenceModel

                    device: Device = 'cuda' if torch.cuda.is_available() else 'cpu'

                    # Create new model instance
                    new_model = InferenceModel(
                        config_name=selected_model_info['config_name'],
                        repo_id=selected_model_info['repo_id'],
                        device=device,
                        root_dir=self.root_dir
                    )

                    print("📥 Model downloaded and initialized successfully")

                    # Replace current model
                    self.model = new_model

                    # Enable the predict button since we now have a model
                    if self.predict_button:
                        self.predict_button.disabled = False
                        self.predict_button.tooltip = 'Run prediction with current settings'

                    # Create new parameter components for the new model (reuse the same predict button)
                    new_parameter_table, new_additional_params_controls = self._create_parameter_components(
                        predict_button=self.predict_button
                    )

                    # Update the parameter table title with the new model name
                    new_parameter_table.param_title.value = f"<h4>Parameter Configuration for {selected_model_info['model_name']}</h4>"

                    # Update the parameter table in the tabs and switch to it
                    if self.tabs_widget and hasattr(self.tabs_widget, 'children'):
                        children_list = list(self.tabs_widget.children)
                        children_list[self.tab_indices['Parameters']] = new_parameter_table.widget
                        self.tabs_widget.children = children_list
                        self.tabs_widget.selected_index = self.tab_indices['Parameters']

                    # Update our references to the components
                    self.parameter_table = new_parameter_table
                    self.additional_params_controls = new_additional_params_controls

                    # Clear previous prediction results
                    self.prediction_result = None

                    # Get parameter info for success message
                    prior_sampler = self.model.trainer.loader.prior_sampler
                    param_labels = prior_sampler.param_model.get_param_labels()
                    max_layers = prior_sampler.max_num_layers

                    print(f"✅ Model loaded successfully: {selected_model_info['config_name']}")
                    print(f"Parameters: {len(param_labels)} parameters, {max_layers} max layers")
                    print("💡 Tip: Go to the Parameters tab to see the updated parameter ranges for the new model")

                except Exception as e:
                    print(f"❌ Error loading model: {str(e)}")
                    import traceback
                    traceback.print_exc()
                finally:
                    # Always restore button state, even if there was an error
                    load_button.description = original_description
                    load_button.disabled = original_disabled

        # Connect the load button
        load_button.on_click(on_load_model)

    def _plot_results(self, prediction_result, settings, reflectivity_container, sld_container):
        """
        Plot prediction results with current settings

        Args:
            prediction_result: Dict returned by ``preprocess_and_predict``
            settings: Plotting settings (colors, toggles, axis scaling)
            reflectivity_container: VBox holding the reflectivity plot
            sld_container: VBox holding the SLD plot
        """
        # Use the size chosen in display() rather than fixed dimensions
        width, height = self._resolve_plot_size()

        # Prepare plotting data
        q_exp_plot = self._data['q_values']
        r_exp_plot = self._data['reflectivity_curve']
        yerr_plot = self._data['sigmas'] if settings['show_error_bars'] else None
        xerr_plot = self._data['q_resolution'] if settings['show_q_resolution'] else None

        q_pred = prediction_result.get('q_plot_pred', None)
        r_pred = prediction_result.get('predicted_curve', None)
        q_pol = self._data['q_values'] if 'polished_curve' in prediction_result else None
        r_pol = prediction_result.get('polished_curve', None)

        z_sld = prediction_result.get('predicted_sld_xaxis', None)
        sld_pred = prediction_result.get('predicted_sld_profile', None)
        sld_pol = prediction_result.get('sld_profile_polished', None)

        # Handle complex SLD: only the real part is plotted
        if sld_pred is not None and np.iscomplexobj(sld_pred):
            sld_pred = sld_pred.real
        if sld_pol is not None and np.iscomplexobj(sld_pol):
            sld_pol = sld_pol.real

        # Update reflectivity plot
        plot_reflectivity_only(
            plot_manager=self.plot_manager,
            figure_id="reflectivity_plot",
            q_exp=q_exp_plot, r_exp=r_exp_plot, yerr=yerr_plot, xerr=xerr_plot,
            exp_color=settings['exp_color'],
            exp_errcolor=settings['exp_errcolor'],
            q_pred=q_pred, r_pred=r_pred, pred_color=settings['pred_color'],
            q_pol=q_pol, r_pol=r_pol, pol_color=settings['pol_color'],
            logx=settings['log_x_axis'], logy=True,
            width=width, height=height
        )

        # Update SLD plot if requested
        if settings['plot_sld_profile']:
            plot_sld_only(
                plot_manager=self.plot_manager,
                figure_id="sld_plot",
                z_sld=z_sld, sld_pred=sld_pred, sld_pol=sld_pol,
                sld_pred_color=settings['sld_pred_color'],
                sld_pol_color=settings['sld_pol_color'],
                width=width, height=height
            )

    def _setup_truncation_sync(self):
        """Keep the left truncation index strictly below the right one"""
        # Find truncation widgets
        trunc_widgets = WidgetSettingsExtractor._find_widgets_by_description(
            self.preprocessing_controls.widget, ['Left index:', 'Right index:']
        )

        if len(trunc_widgets) == 2:
            trunc_left, trunc_right = trunc_widgets

            def sync_truncation(_):
                if trunc_left.value >= trunc_right.value:
                    trunc_left.value = max(0, trunc_right.value - 1)

            trunc_left.observe(sync_truncation, names='value')
            trunc_right.observe(sync_truncation, names='value')

    def _setup_reactive_plot_updates(self, reflectivity_container, sld_container):
        """Setup observers for plotting controls that should trigger immediate plot updates"""
        if not self.plotting_controls:
            return

        # Find plotting controls that should trigger plot updates
        reactive_controls = WidgetSettingsExtractor._find_widgets_by_description(
            self.plotting_controls.widget,
            [
                'Show error bars', 'Show q-resolution', 'Log x-axis', 'Plot SLD profile',
                'Data color:', 'Error bars:', 'Prediction:', 'Polished:',
                'SLD pred:', 'SLD polish:'
            ]
        )

        def update_plot_on_change(change):
            """Update plot when plotting controls change"""
            if self.prediction_result is not None:
                # Redraw prediction results with the new settings, reading
                # them exactly as the predict button does
                try:
                    _, plotting_params = self._extract_separated_settings()
                    self._plot_results(self.prediction_result, plotting_params, reflectivity_container, sld_container)
                except Exception as e:
                    self._log_warning(f"⚠️ Error updating plot: {str(e)}")
            else:
                # If no prediction results yet, just update the initial plot
                try:
                    self._update_initial_plot_style(reflectivity_container, sld_container)
                except Exception as e:
                    self._log_warning(f"⚠️ Error updating initial plot: {str(e)}")

        # Setup observers for all reactive controls
        for control in reactive_controls:
            if hasattr(control, 'observe'):
                control.observe(update_plot_on_change, names='value')

    def _update_initial_plot_style(self, reflectivity_container, sld_container):
        """Redraw the experimental-data-only plots using the current plotting controls"""
        if not self.plotting_controls:
            return

        try:
            # Extract current plotting settings
            _, plotting_params = self._extract_separated_settings()

            # Update initial plot with new styling, keeping the display() size
            plot_width, plot_height = self._resolve_plot_size()
            self._plot_initial_data(plotting_params, reflectivity_container, sld_container,
                                    plot_width, plot_height)

        except Exception as e:
            self._log_warning(f"⚠️ Error updating initial plot style: {str(e)}")
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
from reflectorch.extensions.matplotlib.losses import plot_losses
|
|
2
|
-
|
|
3
|
-
__all__ = [
|
|
4
|
-
"plot_losses",
|
|
5
|
-
]
|
|
1
|
+
from reflectorch.extensions.matplotlib.losses import plot_losses
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"plot_losses",
|
|
5
|
+
]
|