reflectorch 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,329 @@
1
+ import numpy as np
2
+
3
+ import plotly.graph_objects as go
4
+
5
+
6
class PlotlyPlotManager:
    """
    Manager for Plotly figures in Jupyter widgets.

    Figures are registered under a caller-chosen ``figure_id`` in two dicts:

    * ``widgets`` maps the id to the ``go.FigureWidget`` that is displayed and
      that the plotting helpers of this module draw into.
    * ``figures`` maps the id to the ``go.Figure`` the widget was built from.
      The plotting helpers of this module only modify the widget, so this
      entry keeps the initial, trace-free layout.
    """

    def __init__(self, verbose: bool = False):
        self.figures = {}  # figure_id -> go.Figure the widget was created from
        self.widgets = {}  # figure_id -> go.FigureWidget used for display
        self.verbose = verbose  # stored for callers; not read by this class

    def _create_figure(self, figure_id: str, width: int, height: int):
        """Build a FigureWidget with the shared layout and register it under ``figure_id``."""
        fig = go.Figure()

        fig.update_layout(
            width=width,
            height=height,
            showlegend=True,
            hovermode='closest',
            template='plotly_white',
            margin=dict(l=60, r=20, t=60, b=60),
            # Horizontal legend anchored just above the top of the plot area.
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="left",
                x=0,
            ),
        )

        plotly_widget = go.FigureWidget(fig)

        # Re-using an existing id silently replaces the previous entries.
        self.figures[figure_id] = fig
        self.widgets[figure_id] = plotly_widget

        return plotly_widget

    def create_reflectivity_figure(self,
                                   figure_id: str,
                                   width: int = 600,
                                   height: int = 300):
        """
        Create and register an empty figure widget for reflectivity curves.

        Args:
            figure_id: Key under which the figure and widget are stored.
            width: Figure width in pixels.
            height: Figure height in pixels.

        Returns:
            The newly created ``go.FigureWidget``.
        """
        return self._create_figure(figure_id, width, height)

    def create_sld_figure(self,
                          figure_id: str,
                          width: int = 600,
                          height: int = 250):
        """
        Create and register an empty figure widget for SLD profiles.

        Args:
            figure_id: Key under which the figure and widget are stored.
            width: Figure width in pixels.
            height: Figure height in pixels.

        Returns:
            The newly created ``go.FigureWidget``.
        """
        return self._create_figure(figure_id, width, height)

    def _setup_figure_hover(self, figure_id: str, plotly_widget):
        """
        Re-assign the widget's traces as a list (no coordinate display is set up).

        ``figure_id`` is unused. Not called by the functions in this module.
        """
        plotly_widget.data = list(plotly_widget.data)

    def get_figure(self, figure_id: str):
        """
        Return the registered figure widget for ``figure_id``.

        Raises:
            ValueError: If no figure was created under ``figure_id``.
        """
        if figure_id not in self.widgets:
            raise ValueError(
                f"Figure '{figure_id}' not found. Create it first with "
                f"create_reflectivity_figure() or create_sld_figure()."
            )
        return self.widgets[figure_id]

    def get_widget(self, figure_id: str):
        """
        Return the plotly widget for display.

        Raises:
            ValueError: If no widget was registered under ``figure_id``.
        """
        if figure_id not in self.widgets:
            raise ValueError(f"Widget for figure '{figure_id}' not found.")
        return self.widgets[figure_id]

    def clear_figure(self, figure_id: str):
        """Remove all traces from the widget ``figure_id``; unknown ids are ignored."""
        if figure_id in self.widgets:
            widget = self.widgets[figure_id]
            # batch_update groups the change into a single front-end update.
            with widget.batch_update():
                widget.data = []

    def close_figure(self, figure_id: str):
        """Forget the figure and widget stored under ``figure_id``; unknown ids are ignored."""
        self.figures.pop(figure_id, None)
        self.widgets.pop(figure_id, None)
113
+
114
+
115
+
116
def plot_reflectivity_only(
    plot_manager: PlotlyPlotManager,
    figure_id: str,
    *,
    q_exp=None,
    r_exp=None,
    yerr=None,
    xerr=None,
    q_pred=None,
    r_pred=None,
    q_pol=None,
    r_pol=None,
    logx=False,
    logy=True,
    exp_color='blue',
    exp_errcolor='purple',
    pred_color='red',
    pol_color='orange',
    exp_label='experimental data',
    pred_label='prediction',
    pol_label='polished prediction',
    width=600,
    height=300
):
    """
    Plot reflectivity data only using Plotly.

    If ``plot_manager`` already holds a figure under ``figure_id``, its traces
    are cleared and redrawn. Otherwise a new reflectivity figure is created;
    ``width``/``height`` are only used in that case.

    Up to three traces are drawn, each only when both of its arrays are given:

    * experimental points (``q_exp``/``r_exp``, with optional ``yerr``/``xerr``
      error bars);
    * the predicted curve (``q_pred``/``r_pred``);
    * the polished curve (``q_pol``/``r_pol``).

    Points that are non-finite, or non-positive on an axis shown in log scale,
    are dropped. The same mask is applied to ``yerr``/``xerr``, so error bars
    stay aligned with their points; these must have the same shape as
    ``q_exp``.

    Returns:
        The figure widget obtained from ``plot_manager``.
    """
    hover_template = '<b>%{fullData.name}</b><br>q: %{x}<br>R: %{y}<extra></extra>'

    def _np(a):
        return None if a is None else np.asarray(a)

    def _valid_mask(x, y):
        """Boolean mask of points that are finite and, on log axes, strictly positive."""
        mask = np.isfinite(x) & np.isfinite(y)
        if logx:
            mask &= (x > 0.0)
        if logy:
            mask &= (y > 0.0)
        return mask

    def _error_bars(err, mask):
        """Plotly error-bar spec for ``err`` restricted to ``mask``, or None without errors."""
        if err is None:
            return None
        return dict(type='data', array=err[mask], visible=True, color=exp_errcolor)

    def _add_curve(x, y, name, line):
        """Add a line trace for the drawable points of ``(x, y)``, if there are any."""
        if x is None or y is None:
            return
        mask = _valid_mask(x, y)
        if mask.any():
            fig.add_trace(
                go.Scatter(
                    x=x[mask],
                    y=y[mask],
                    mode='lines',
                    line=line,
                    name=name,
                    hovertemplate=hover_template
                )
            )

    # Convert inputs to numpy arrays
    q_exp, r_exp, yerr, xerr = _np(q_exp), _np(r_exp), _np(yerr), _np(xerr)
    q_pred, r_pred = _np(q_pred), _np(r_pred)
    q_pol, r_pol = _np(q_pol), _np(r_pol)

    # Reuse (and clear) an existing figure, or create a new one.
    try:
        fig = plot_manager.get_figure(figure_id)
    except ValueError:
        fig = plot_manager.create_reflectivity_figure(figure_id, width, height)
    else:
        plot_manager.clear_figure(figure_id)

    # Experimental data: one mask shared by q, R and both error arrays.
    if q_exp is not None and r_exp is not None:
        exp_mask = _valid_mask(q_exp, r_exp)
        if exp_mask.any():
            fig.add_trace(
                go.Scatter(
                    x=q_exp[exp_mask],
                    y=r_exp[exp_mask],
                    mode='markers',
                    marker=dict(color=exp_color, size=6),
                    error_y=_error_bars(yerr, exp_mask),
                    error_x=_error_bars(xerr, exp_mask),
                    name=exp_label,
                    hovertemplate=hover_template
                )
            )

    # Predicted and polished curves
    _add_curve(q_pred, r_pred, pred_label, dict(color=pred_color, width=2))
    _add_curve(q_pol, r_pol, pol_label, dict(color=pol_color, width=2, dash='dash'))

    # Update axis settings for reflectivity plot
    fig.update_xaxes(
        title_text="q [Å⁻¹]",
        type='log' if logx else 'linear'
    )
    fig.update_yaxes(
        title_text="R(q)",
        type='log' if logy else 'linear'
    )

    # The fig is a FigureWidget, so the changes show up in the displayed widget.
    return fig
261
+
262
+
263
def plot_sld_only(
    plot_manager: PlotlyPlotManager,
    figure_id: str,
    *,
    z_sld=None,
    sld_pred=None,
    sld_pol=None,
    sld_pred_color='red',
    sld_pol_color='orange',
    sld_pred_label='pred. SLD',
    sld_pol_label='polished SLD',
    width=600,
    height=250
):
    """
    Plot SLD profile data only using Plotly.

    If ``plot_manager`` already holds a figure under ``figure_id``, its traces
    are cleared. Otherwise a new SLD figure of the given ``width``/``height``
    is created.

    Nothing is drawn unless ``z_sld`` is given. In that case:

    * ``sld_pred`` is drawn as a solid line;
    * ``sld_pol`` is drawn as a dashed line;

    each only when provided, and both against ``z_sld``.

    Returns:
        The figure widget obtained from ``plot_manager``.
    """

    def _to_array(value):
        if value is None:
            return None
        return np.asarray(value)

    z_sld = _to_array(z_sld)
    sld_pred = _to_array(sld_pred)
    sld_pol = _to_array(sld_pol)

    try:
        fig = plot_manager.get_figure(figure_id)
        plot_manager.clear_figure(figure_id)
    except ValueError:
        # No figure registered under this id yet.
        fig = plot_manager.create_sld_figure(figure_id, width, height)

    # (profile, legend name, line style) in drawing order.
    profiles = []
    if z_sld is not None:
        profiles = [
            (sld_pred, sld_pred_label, dict(color=sld_pred_color, width=2)),
            (sld_pol, sld_pol_label, dict(color=sld_pol_color, width=2, dash='dash')),
        ]

    hover = '<b>%{fullData.name}</b><br>z: %{x}<br>SLD: %{y}<extra></extra>'
    for profile, legend_name, line_style in profiles:
        if profile is None:
            continue
        trace = go.Scatter(
            x=z_sld,
            y=profile,
            mode='lines',
            line=line_style,
            name=legend_name,
            hovertemplate=hover
        )
        fig.add_trace(trace)

    fig.update_xaxes(title_text="z [Å]")
    fig.update_yaxes(title_text="SLD [10⁻⁶ Å⁻²]")

    return fig