reflectorch 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/__init__.py +17 -0
- reflectorch/data_generation/__init__.py +128 -0
- reflectorch/data_generation/dataset.py +216 -0
- reflectorch/data_generation/likelihoods.py +80 -0
- reflectorch/data_generation/noise.py +471 -0
- reflectorch/data_generation/priors/__init__.py +60 -0
- reflectorch/data_generation/priors/base.py +55 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
- reflectorch/data_generation/priors/independent_priors.py +195 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +104 -0
- reflectorch/data_generation/priors/no_constraints.py +206 -0
- reflectorch/data_generation/priors/parametric_models.py +842 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
- reflectorch/data_generation/priors/params.py +252 -0
- reflectorch/data_generation/priors/sampler_strategies.py +370 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +371 -0
- reflectorch/data_generation/priors/utils.py +118 -0
- reflectorch/data_generation/process_data.py +41 -0
- reflectorch/data_generation/q_generator.py +280 -0
- reflectorch/data_generation/reflectivity/__init__.py +102 -0
- reflectorch/data_generation/reflectivity/abeles.py +97 -0
- reflectorch/data_generation/reflectivity/kinematical.py +71 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +138 -0
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/scale_curves.py +112 -0
- reflectorch/data_generation/smearing.py +99 -0
- reflectorch/data_generation/utils.py +223 -0
- reflectorch/extensions/__init__.py +0 -0
- reflectorch/extensions/jupyter/__init__.py +11 -0
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -0
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -0
- reflectorch/extensions/matplotlib/losses.py +32 -0
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/__init__.py +28 -0
- reflectorch/inference/inference_model.py +848 -0
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +55 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/plotting.py +524 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +19 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/query_matcher.py +82 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +364 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +32 -0
- reflectorch/ml/basic_trainer.py +292 -0
- reflectorch/ml/callbacks.py +81 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +56 -0
- reflectorch/ml/schedulers.py +356 -0
- reflectorch/ml/trainers.py +201 -0
- reflectorch/ml/utils.py +2 -0
- reflectorch/models/__init__.py +16 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +19 -0
- reflectorch/models/encoders/conv_encoder.py +219 -0
- reflectorch/models/encoders/conv_res_net.py +115 -0
- reflectorch/models/encoders/fno.py +134 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -0
- reflectorch/models/networks/mlp_networks.py +434 -0
- reflectorch/models/networks/residual_net.py +157 -0
- reflectorch/paths.py +29 -0
- reflectorch/runs/__init__.py +31 -0
- reflectorch/runs/config.py +25 -0
- reflectorch/runs/slurm_utils.py +93 -0
- reflectorch/runs/train.py +78 -0
- reflectorch/runs/utils.py +405 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +98 -0
- reflectorch-1.5.1.dist-info/METADATA +151 -0
- reflectorch-1.5.1.dist-info/RECORD +96 -0
- reflectorch-1.5.1.dist-info/WHEEL +5 -0
- reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
- reflectorch-1.5.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
import plotly.graph_objects as go
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class PlotlyPlotManager:
    """
    Manager for Plotly figures in Jupyter widgets.

    Keeps a registry of ``go.FigureWidget`` objects keyed by a string
    ``figure_id`` so that plotting helpers can look up, clear and reuse an
    existing widget instead of creating a new one on every call.
    """

    def __init__(self, verbose: bool = False):
        # figure_id -> go.Figure the widget was built from.
        # NOTE(review): go.FigureWidget(fig) appears to build its own copy, so this
        # entry is likely not updated when the widget changes — confirm before use.
        self.figures = {}
        # figure_id -> go.FigureWidget handed out to callers for display/updates
        self.widgets = {}
        # Stored for callers; not read by any method of this class.
        self.verbose = verbose

    def _create_figure(self, figure_id: str, width: int, height: int) -> "go.FigureWidget":
        """Build, register under ``figure_id`` and return a styled FigureWidget.

        Shared by the reflectivity and SLD factories, which use identical styling.
        An existing entry with the same ``figure_id`` is overwritten.
        """
        fig = go.Figure()

        fig.update_layout(
            width=width,
            height=height,
            showlegend=True,
            hovermode='closest',
            template='plotly_white',
            margin=dict(l=60, r=20, t=60, b=60),
            # Horizontal legend placed just above the plotting area
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="left",
                x=0
            )
        )

        plotly_widget = go.FigureWidget(fig)

        self.figures[figure_id] = fig
        self.widgets[figure_id] = plotly_widget

        return plotly_widget

    def create_reflectivity_figure(self,
                                   figure_id: str,
                                   width: int = 600,
                                   height: int = 300) -> "go.FigureWidget":
        """Create and register a reflectivity-only figure widget.

        Args:
            figure_id: Key under which the widget is stored.
            width: Figure width in pixels.
            height: Figure height in pixels.

        Returns:
            The newly created ``go.FigureWidget``.
        """
        return self._create_figure(figure_id, width, height)

    def create_sld_figure(self,
                          figure_id: str,
                          width: int = 600,
                          height: int = 250) -> "go.FigureWidget":
        """Create and register an SLD-only figure widget.

        Args:
            figure_id: Key under which the widget is stored.
            width: Figure width in pixels.
            height: Figure height in pixels.

        Returns:
            The newly created ``go.FigureWidget``.
        """
        return self._create_figure(figure_id, width, height)

    def _setup_figure_hover(self, figure_id: str, plotly_widget):
        """Setup hover functionality for the figure (no coordinate display).

        Only re-assigns the widget's traces as a list; ``figure_id`` is unused.
        Not called by any other method of this class.
        """
        plotly_widget.data = list(plotly_widget.data)

    def get_figure(self, figure_id: str) -> "go.FigureWidget":
        """Return the existing figure widget registered under ``figure_id``.

        Raises:
            ValueError: if no figure with that id has been created.
        """
        try:
            return self.widgets[figure_id]
        except KeyError:
            raise ValueError(
                f"Figure '{figure_id}' not found. Create it first with "
                f"create_reflectivity_figure() or create_sld_figure()."
            ) from None

    def get_widget(self, figure_id: str) -> "go.FigureWidget":
        """Return the plotly widget for display.

        Raises:
            ValueError: if no widget with that id has been created.
        """
        try:
            return self.widgets[figure_id]
        except KeyError:
            raise ValueError(f"Widget for figure '{figure_id}' not found.") from None

    def clear_figure(self, figure_id: str):
        """Remove all traces from the figure; unknown ids are silently ignored."""
        widget = self.widgets.get(figure_id)
        if widget is None:
            return
        # Assigning an empty sequence to ``data`` drops every trace at once.
        with widget.batch_update():
            widget.data = []

    def close_figure(self, figure_id: str):
        """Forget the figure and widget registered under ``figure_id`` (no-op if absent)."""
        self.figures.pop(figure_id, None)
        self.widgets.pop(figure_id, None)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def plot_reflectivity_only(
    plot_manager: PlotlyPlotManager,
    figure_id: str,
    *,
    q_exp=None,
    r_exp=None,
    yerr=None,
    xerr=None,
    q_pred=None,
    r_pred=None,
    q_pol=None,
    r_pol=None,
    logx=False,
    logy=True,
    exp_color='blue',
    exp_errcolor='purple',
    pred_color='red',
    pol_color='orange',
    exp_label='experimental data',
    pred_label='prediction',
    pol_label='polished prediction',
    width=600,
    height=300
):
    """
    Plot reflectivity data only using Plotly.

    Creates or updates a Plotly figure widget with reflectivity data only.
    If ``plot_manager`` already holds a widget for ``figure_id``, its traces
    are cleared and the widget is reused. Otherwise a new reflectivity figure
    is created.

    Points that are non-finite, or non-positive on a logarithmic axis, are
    dropped before plotting. The experimental error bars are filtered with
    the same mask, so each error bar stays attached to its data point.

    Args:
        plot_manager: Registry that owns the figure widgets.
        figure_id: Key of the widget to (re)use.
        q_exp, r_exp: Experimental data, drawn as markers.
        yerr, xerr: Optional per-point error bars for the experimental data;
            expected to have the same length as ``q_exp``/``r_exp``.
        q_pred, r_pred: Predicted curve, drawn as a solid line.
        q_pol, r_pol: Polished curve, drawn as a dashed line.
        logx, logy: Use logarithmic x / y axes.
        exp_color, exp_errcolor, pred_color, pol_color: Trace colors.
        exp_label, pred_label, pol_label: Legend labels.
        width, height: Size in pixels; only used when a new figure is created.

    Returns:
        The figure widget that was drawn into.
    """
    hovertemplate = '<b>%{fullData.name}</b><br>q: %{x}<br>R: %{y}<extra></extra>'

    def _np(a):
        return None if a is None else np.asarray(a)

    def _valid_mask(x, y):
        """Boolean mask of points that are finite and, on log axes, strictly positive."""
        mask = np.isfinite(x) & np.isfinite(y)
        if logx:
            mask &= (x > 0.0)
        if logy:
            mask &= (y > 0.0)
        return mask

    def _add_line(fig, x, y, line, name):
        """Add a masked line trace; skip it if inputs are missing or no point survives."""
        if x is None or y is None:
            return
        mask = _valid_mask(x, y)
        x_clean, y_clean = x[mask], y[mask]
        if x_clean.size == 0:
            return
        fig.add_trace(
            go.Scatter(
                x=x_clean,
                y=y_clean,
                mode='lines',
                line=line,
                name=name,
                hovertemplate=hovertemplate
            )
        )

    # Convert inputs to numpy arrays
    q_exp, r_exp, yerr, xerr = _np(q_exp), _np(r_exp), _np(yerr), _np(xerr)
    q_pred, r_pred = _np(q_pred), _np(r_pred)
    q_pol, r_pol = _np(q_pol), _np(r_pol)

    # Reuse (and clear) an existing widget, or create a new one
    try:
        fig = plot_manager.get_figure(figure_id)
        plot_manager.clear_figure(figure_id)
    except ValueError:
        fig = plot_manager.create_reflectivity_figure(figure_id, width, height)

    # Plot experimental data
    if q_exp is not None and r_exp is not None:
        mask = _valid_mask(q_exp, r_exp)
        q_exp_clean, r_exp_clean = q_exp[mask], r_exp[mask]

        if q_exp_clean.size > 0:
            # Filter the error arrays with the SAME combined mask. Applying the
            # finite and log masks one after another would shrink the array
            # between steps, misalign the bars, and raise IndexError.
            error_y = None
            if yerr is not None:
                error_y = dict(type='data', array=yerr[mask], visible=True, color=exp_errcolor)

            error_x = None
            if xerr is not None:
                error_x = dict(type='data', array=xerr[mask], visible=True, color=exp_errcolor)

            fig.add_trace(
                go.Scatter(
                    x=q_exp_clean,
                    y=r_exp_clean,
                    mode='markers',
                    marker=dict(color=exp_color, size=6),
                    error_y=error_y,
                    error_x=error_x,
                    name=exp_label,
                    hovertemplate=hovertemplate
                )
            )

    # Predicted curve (solid) and polished curve (dashed)
    _add_line(fig, q_pred, r_pred, dict(color=pred_color, width=2), pred_label)
    _add_line(fig, q_pol, r_pol, dict(color=pol_color, width=2, dash='dash'), pol_label)

    # Update axis settings for reflectivity plot
    fig.update_xaxes(
        title_text="q [Å⁻¹]",
        type='log' if logx else 'linear'
    )
    fig.update_yaxes(
        title_text="R(q)",
        type='log' if logy else 'linear'
    )

    # fig is a FigureWidget, so these changes appear in the widget directly
    return fig
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def plot_sld_only(
    plot_manager: PlotlyPlotManager,
    figure_id: str,
    *,
    z_sld=None,
    sld_pred=None,
    sld_pol=None,
    sld_pred_color='red',
    sld_pol_color='orange',
    sld_pred_label='pred. SLD',
    sld_pol_label='polished SLD',
    width=600,
    height=250
):
    """
    Draw predicted and/or polished SLD profiles in a Plotly figure widget.

    If ``plot_manager`` already holds a widget for ``figure_id``, its traces
    are cleared and the widget is reused. Otherwise a new SLD figure of the
    requested size is created. Profiles are drawn only when ``z_sld`` is given:
    ``sld_pred`` as a solid line and ``sld_pol`` as a dashed line, both plotted
    against ``z_sld``.

    Returns:
        The figure widget that was drawn into.
    """
    def as_array(value):
        return None if value is None else np.asarray(value)

    z_sld = as_array(z_sld)
    # (profile, colour, legend label, dash style) in drawing order
    profiles = (
        (as_array(sld_pred), sld_pred_color, sld_pred_label, None),
        (as_array(sld_pol), sld_pol_color, sld_pol_label, 'dash'),
    )

    try:
        fig = plot_manager.get_figure(figure_id)
        plot_manager.clear_figure(figure_id)  # start from an empty widget
    except ValueError:
        fig = plot_manager.create_sld_figure(figure_id, width, height)

    if z_sld is not None:
        for profile, colour, label, dash in profiles:
            if profile is None:
                continue
            line_style = dict(color=colour, width=2)
            if dash is not None:
                line_style['dash'] = dash
            trace = go.Scatter(
                x=z_sld,
                y=profile,
                mode='lines',
                line=line_style,
                name=label,
                hovertemplate='<b>%{fullData.name}</b><br>z: %{x}<br>SLD: %{y}<extra></extra>'
            )
            fig.add_trace(trace)

    fig.update_xaxes(title_text="z [Å]")
    fig.update_yaxes(title_text="SLD [10⁻⁶ Å⁻²]")

    # The widget reflects these updates directly; hand it back to the caller
    return fig
|