reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,248 +1,272 @@
1
- import warnings
2
-
3
- import numpy as np
4
- from scipy.optimize import minimize, curve_fit
5
- import torch
6
-
7
- from reflectorch.data_generation.priors.base import PriorSampler
8
- from reflectorch.data_generation.reflectivity import abeles_np
9
-
10
# Public API of this module.
__all__ = ['standard_refl_fit', 'refl_fit', 'fit_refl_curve',
           'restore_masked_params', 'get_fit_with_growth']
17
-
18
-
19
def standard_restore_params(fitted_params) -> dict:
    """Split a flat parameter vector into thickness / roughness / sld arrays.

    The layer count ``n`` is inferred from the vector size as
    ``(size - 2) // 3``. The vector is read as ``n`` thicknesses, then
    ``n + 1`` roughnesses, then every remaining entry as SLDs.
    """
    n_layers = (fitted_params.size - 2) // 3
    roughness_end = 2 * n_layers + 1
    return {
        'thickness': fitted_params[:n_layers],
        'roughness': fitted_params[n_layers:roughness_end],
        'sld': fitted_params[roughness_end:],
    }
27
-
28
-
29
def mse_loss(curve1, curve2):
    """Return the sum of squared differences between two curves.

    Note: despite the name, the sum is not divided by the number of points.
    """
    residuals = curve1 - curve2
    return np.sum(np.square(residuals))
31
-
32
def standard_refl_fit(
        q: np.ndarray, curve: np.ndarray,
        init_params: np.ndarray,
        bounds: np.ndarray = None,
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
        **kwargs
):
    """Fit ``curve`` with ``scipy.optimize.curve_fit`` using a numpy reflectivity generator.

    Args:
        q: momentum transfer values passed to ``refl_generator``.
        curve: measured reflectivity to fit.
        init_params: initial flat parameter vector; clipped into ``bounds``.
        bounds: ``(lower, upper)`` pair as accepted by ``curve_fit``; each
            entry may be a scalar or a per-parameter array. Parameters whose
            lower and upper bounds coincide are widened by 1e-6, because the
            least-squares solver requires lower < upper strictly.
        refl_generator: callable ``f(q, **params) -> curve``.
        restore_params_func: maps the flat parameter vector to the keyword
            arguments of ``refl_generator``.
        scale_curve_func: transformation applied to data and model before
            fitting (``np.log10`` by default).
        **kwargs: forwarded to ``curve_fit``.

    Returns:
        ``(fitted_params, fitted_curve)``. ``fitted_curve`` is the unscaled
        model curve at ``q`` for the fitted parameters.
    """
    if bounds is not None:
        epsilon = 1e-6
        # float conversion so the epsilon shift is not truncated for int bounds
        lower = np.asarray(bounds[0], dtype=float)
        upper = np.asarray(bounds[1], dtype=float)
        fixed = lower == upper
        lower = np.where(fixed, lower - epsilon, lower)
        upper = np.where(fixed, upper + epsilon, upper)
        kwargs['bounds'] = (lower, upper)
        init_params = np.clip(init_params, lower, upper)

    model = standard_get_scaled_curve_func(
        refl_generator=refl_generator,
        restore_params_func=restore_params_func,
        scale_curve_func=scale_curve_func,
    )
    res = curve_fit(model, q, scale_curve_func(curve), p0=init_params, **kwargs)

    fitted_params = res[0]
    fitted_curve = refl_generator(q, **restore_params_func(fitted_params))
    return fitted_params, fitted_curve
57
-
58
def refl_fit(
        q: np.ndarray,
        curve: np.ndarray,
        init_params: np.ndarray,
        prior_sampler: PriorSampler,
        bounds: np.ndarray = None,
        error_bars: np.ndarray = None,
        scale_curve_func=np.log10,
        reflectivity_kwargs: dict = None,
        **kwargs
):
    """Fit a reflectivity curve with ``scipy.optimize.curve_fit``.

    The model curve comes from ``prior_sampler.param_model.reflectivity``.
    Data and model are compared after applying ``scale_curve_func``.

    Args:
        q: momentum transfer values.
        curve: measured reflectivity. With log10 scaling, values below 1e-12
            are clipped so that the logarithm stays finite.
        init_params: initial parameter vector; clipped into ``bounds``.
        prior_sampler: provides ``param_model.reflectivity``.
        bounds: ``(lower, upper)`` bounds. Parameters whose bounds coincide
            are widened by 1e-6 so the solver accepts them.
        error_bars: absolute uncertainties of ``curve``. With log10 scaling
            they are propagated to log10 space (``sigma / (R ln 10)``).
            Otherwise they are used as given.
        scale_curve_func: transformation applied to both curves before fitting.
        reflectivity_kwargs: extra keyword arguments for the reflectivity
            model. Floats and numpy arrays are converted to tensors on a copy;
            the caller's dict is not modified.
        **kwargs: forwarded to ``curve_fit``.

    Returns:
        ``(fitted_params, fitted_curve)``.
    """
    if bounds is not None:
        # introduce a small perturbation for fixed bounds (solver needs lower < upper);
        # float conversion keeps the epsilon from being truncated for int bounds
        epsilon = 1e-6
        lower = np.asarray(bounds[0], dtype=float)
        upper = np.asarray(bounds[1], dtype=float)
        fixed = lower == upper
        lower = np.where(fixed, lower - epsilon, lower)
        upper = np.where(fixed, upper + epsilon, upper)

        init_params = np.clip(init_params, lower, upper)
        kwargs['bounds'] = (lower, upper)

    # copy so the tensor conversion below does not mutate the caller's dict
    reflectivity_kwargs = dict(reflectivity_kwargs or {})
    for key, value in list(reflectivity_kwargs.items()):
        if isinstance(value, float):
            reflectivity_kwargs[key] = torch.tensor([[value]], dtype=torch.float64)
        elif isinstance(value, np.ndarray):
            reflectivity_kwargs[key] = torch.tensor(value, dtype=torch.float32).unsqueeze(0)

    log_scaling = scale_curve_func is np.log10
    if log_scaling:
        # keep log10 finite for zero / negative measured points
        curve = np.clip(curve, a_min=1e-12, a_max=None)

    sigma = None
    if error_bars is not None:
        if log_scaling:
            # residuals live in log10 space, so the uncertainties must too
            error_bars = np.clip(error_bars, a_min=1e-20, a_max=None)
            sigma = np.reshape(error_bars / (curve * np.log(10)), -1)
        else:
            sigma = np.reshape(error_bars, -1)

    res = curve_fit(
        f=get_scaled_curve_func(
            scale_curve_func=scale_curve_func,
            prior_sampler=prior_sampler,
            reflectivity_kwargs=reflectivity_kwargs,
        ),
        xdata=q,
        ydata=np.reshape(scale_curve_func(curve), -1),
        p0=init_params,
        sigma=sigma,
        absolute_sigma=True,
        **kwargs
    )

    fitted_curve = prior_sampler.param_model.reflectivity(
        torch.tensor(q, dtype=torch.float64),
        torch.tensor(res[0], dtype=torch.float64).unsqueeze(0),
        **reflectivity_kwargs,
    ).squeeze().numpy()
    return res[0], fitted_curve
107
-
108
-
109
def get_fit_with_growth(
        q: np.ndarray,
        curve: np.ndarray,
        init_params: np.ndarray,
        bounds: np.ndarray = None,
        init_d_change: float = 0.,
        max_d_change: float = 30.,
        scale_curve_func=np.log10,
        **kwargs
):
    """Fit a curve in which the thickness of layer 0 changes linearly across the q points.

    The standard parameter vector gets one extra trailing entry: the total
    thickness change. Its start value is ``init_d_change``; when ``bounds`` is
    given, it is bounded to ``[0, max_d_change]``. After fitting, the first
    thickness is shifted by half of the fitted change, i.e. it is reported at
    the midpoint of the ramp.

    Returns:
        ``(params, curve)`` as produced by ``standard_refl_fit``, with the
        adjusted first thickness.
    """
    extended_init = np.array([*init_params, init_d_change])

    if bounds is not None:
        growth_column = np.array([0, max_d_change])[..., None]
        bounds = np.concatenate([bounds, growth_column], -1)

    growth_restore = get_restore_params_with_growth_func(q_size=q.size, d_idx=0)
    fitted, fitted_curve = standard_refl_fit(
        q,
        curve,
        extended_init,
        bounds,
        refl_generator=growth_reflectivity,
        restore_params_func=growth_restore,
        scale_curve_func=scale_curve_func,
        **kwargs
    )

    # report the thickness at the middle of the linear change
    fitted[0] += fitted[-1] / 2
    return fitted, fitted_curve
136
-
137
-
138
def fit_refl_curve(q: np.ndarray, curve: np.ndarray,
                   init_params: np.ndarray,
                   bounds: np.ndarray = None,
                   refl_generator=abeles_np,
                   restore_params_func=standard_restore_params,
                   scale_curve_func=np.log10,
                   **kwargs
                   ) -> np.ndarray:
    """Fit by minimizing a scalar loss between scaled curves with ``scipy.optimize.minimize``.

    The loss is built by ``get_fitting_func`` from the scaled data and the
    scaled model curve.

    Note: ``bounds`` is forwarded unchanged to ``minimize``. ``minimize``
    expects one ``(min, max)`` pair per parameter, or a
    ``scipy.optimize.Bounds`` instance. It does not accept the
    ``(lower, upper)`` layout used by the ``curve_fit``-based functions.

    Returns:
        The optimizer's final parameter vector. If the optimizer reports
        failure, a warning carrying its message is issued.
    """
    fitting_func = get_fitting_func(
        q=q, curve=curve,
        refl_generator=refl_generator,
        restore_params_func=restore_params_func,
        scale_curve_func=scale_curve_func,
    )

    res = minimize(fitting_func, init_params, bounds=bounds, **kwargs)

    if not res.success:
        # surface the optimizer's reason; stacklevel attributes the warning to the caller
        warnings.warn(
            f"Minimization did not converge: {getattr(res, 'message', '')}",
            stacklevel=2,
        )
    return res.x
158
-
159
def standard_get_scaled_curve_func(
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
):
    """Build a ``curve_fit``-compatible model ``f(q, *params)``.

    The returned callable maps the flat parameters through
    ``restore_params_func`` and evaluates ``refl_generator`` at ``q``. It
    returns that curve transformed by ``scale_curve_func``.
    """
    def scaled_curve_func(q, *fitted_params):
        param_kwargs = restore_params_func(np.asarray(fitted_params))
        return scale_curve_func(refl_generator(q, **param_kwargs))

    return scaled_curve_func
171
-
172
def get_scaled_curve_func(
        scale_curve_func=np.log10,
        prior_sampler: PriorSampler = None,
        reflectivity_kwargs: dict = None,
):
    """Build a ``curve_fit`` model that evaluates ``prior_sampler.param_model.reflectivity``.

    ``q`` and the parameters are converted to float64 tensors, with the
    parameters given a leading batch axis of size 1. The returned value is
    the squeezed model curve passed through ``scale_curve_func``.
    ``prior_sampler`` is only accessed when the model is called.
    """
    extra_kwargs = reflectivity_kwargs or {}

    def scaled_curve_func(q, *fitted_params):
        q_t = torch.from_numpy(q).to(torch.float64)
        params_t = torch.tensor(fitted_params, dtype=torch.float64).unsqueeze(0)
        curve_t = prior_sampler.param_model.reflectivity(q_t, params_t, **extra_kwargs)
        return scale_curve_func(curve_t.squeeze().numpy())

    return scaled_curve_func
190
-
191
-
192
def get_fitting_func(
        q: np.ndarray,
        curve: np.ndarray,
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
        loss_func=mse_loss,
):
    """Build a scalar objective ``f(params) -> loss`` for ``scipy.optimize.minimize``.

    The measured ``curve`` is scaled once, up front. Each evaluation scales
    the model curve the same way and returns
    ``loss_func(scaled_model, scaled_data)``.
    """
    target = scale_curve_func(curve)

    def fitting_func(fitted_params):
        model_kwargs = restore_params_func(fitted_params)
        model_scaled = scale_curve_func(refl_generator(q, **model_kwargs))
        return loss_func(model_scaled, target)

    return fitting_func
209
-
210
-
211
def restore_masked_params(fixed_params, fixed_mask):
    """Return a restore function that merges fixed and fitted parameters.

    Positions where ``fixed_mask`` is True are taken from ``fixed_params``.
    All other positions are taken from the vector being fitted. The merged
    vector is then split by ``standard_restore_params``.
    """
    def restore_params(fitted_params) -> dict:
        merged = np.empty_like(fixed_mask, dtype=fitted_params.dtype)
        merged[fixed_mask] = fixed_params
        merged[~fixed_mask] = fitted_params
        return standard_restore_params(merged)

    return restore_params
219
-
220
-
221
def base_params2growth(base_params: dict, d_shift: np.ndarray, d_idx: int = 0) -> dict:
    """Expand one parameter set into one row per entry of ``d_shift``.

    Row ``i`` of the returned ``thickness`` array is the base thicknesses with
    ``d_shift[i]`` added to layer ``d_idx``. ``roughness`` and ``sld`` are
    broadcast to the same number of rows as read-only views.
    """
    n_rows = d_shift.size
    thickness = base_params['thickness'][None].repeat(n_rows, 0)
    thickness[:, d_idx] = thickness[:, d_idx] + d_shift

    result = {'thickness': thickness}
    for name in ('roughness', 'sld'):
        values = base_params[name]
        result[name] = np.broadcast_to(values[None], (n_rows, values.size))
    return result
235
-
236
-
237
def get_restore_params_with_growth_func(q_size: int, d_idx: int = 0):
    """Return a restore function for parameter vectors that end with a thickness change.

    The last entry ``delta_d`` is split off. The remainder is decoded by
    ``standard_restore_params``. Layer ``d_idx`` then receives a shift that
    ramps linearly from 0 to ``delta_d`` over ``q_size`` rows.
    """
    def restore_params_with_growth(fitted_params) -> dict:
        base_vector = fitted_params[:-1]
        delta_d = fitted_params[-1]
        decoded = standard_restore_params(base_vector)
        ramp = np.linspace(0, delta_d, q_size)
        return base_params2growth(decoded, ramp, d_idx)

    return restore_params_with_growth
245
-
246
-
247
def growth_reflectivity(q: np.ndarray, **kwargs):
    """Evaluate ``abeles_np`` on ``q`` with an added trailing axis and flatten the result.

    The trailing axis presumably lets ``abeles_np`` pair each q value with
    its own row of batched parameters. Confirm this against ``abeles_np``'s
    broadcasting rules.
    """
    column_q = q[..., np.newaxis]
    return abeles_np(column_q, **kwargs).flatten()
1
+ import warnings
2
+
3
+ import numpy as np
4
+ from scipy.optimize import minimize, curve_fit
5
+ import torch
6
+
7
+ from reflectorch.data_generation.priors.base import PriorSampler
8
+ from reflectorch.data_generation.reflectivity import abeles_np
9
+
10
# Public API of this module.
__all__ = ['standard_refl_fit', 'refl_fit', 'fit_refl_curve',
           'restore_masked_params', 'get_fit_with_growth']
17
+
18
+
19
def standard_restore_params(fitted_params) -> dict:
    """Split a flat parameter vector into thickness / roughness / sld arrays.

    The layer count ``n`` is inferred from the vector size as
    ``(size - 2) // 3``. The vector is read as ``n`` thicknesses, then
    ``n + 1`` roughnesses, then every remaining entry as SLDs.
    """
    n_layers = (fitted_params.size - 2) // 3
    roughness_end = 2 * n_layers + 1
    return {
        'thickness': fitted_params[:n_layers],
        'roughness': fitted_params[n_layers:roughness_end],
        'sld': fitted_params[roughness_end:],
    }
27
+
28
+
29
def mse_loss(curve1, curve2):
    """Return the sum of squared differences between two curves.

    Note: despite the name, the sum is not divided by the number of points.
    """
    residuals = curve1 - curve2
    return np.sum(np.square(residuals))
31
+
32
def standard_refl_fit(
        q: np.ndarray, curve: np.ndarray,
        init_params: np.ndarray,
        bounds: np.ndarray = None,
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
        **kwargs
):
    """Fit ``curve`` with ``scipy.optimize.curve_fit`` using a numpy reflectivity generator.

    Args:
        q: momentum transfer values passed to ``refl_generator``.
        curve: measured reflectivity to fit.
        init_params: initial flat parameter vector; clipped into ``bounds``.
        bounds: ``(lower, upper)`` pair as accepted by ``curve_fit``; each
            entry may be a scalar or a per-parameter array. Parameters whose
            lower and upper bounds coincide are widened by 1e-6, because the
            least-squares solver requires lower < upper strictly.
        refl_generator: callable ``f(q, **params) -> curve``.
        restore_params_func: maps the flat parameter vector to the keyword
            arguments of ``refl_generator``.
        scale_curve_func: transformation applied to data and model before
            fitting (``np.log10`` by default).
        **kwargs: forwarded to ``curve_fit``.

    Returns:
        ``(fitted_params, fitted_curve)``. ``fitted_curve`` is the unscaled
        model curve at ``q`` for the fitted parameters.
    """
    if bounds is not None:
        epsilon = 1e-6
        # float conversion so the epsilon shift is not truncated for int bounds
        lower = np.asarray(bounds[0], dtype=float)
        upper = np.asarray(bounds[1], dtype=float)
        fixed = lower == upper
        lower = np.where(fixed, lower - epsilon, lower)
        upper = np.where(fixed, upper + epsilon, upper)
        kwargs['bounds'] = (lower, upper)
        init_params = np.clip(init_params, lower, upper)

    model = standard_get_scaled_curve_func(
        refl_generator=refl_generator,
        restore_params_func=restore_params_func,
        scale_curve_func=scale_curve_func,
    )
    res = curve_fit(model, q, scale_curve_func(curve), p0=init_params, **kwargs)

    fitted_params = res[0]
    fitted_curve = refl_generator(q, **restore_params_func(fitted_params))
    return fitted_params, fitted_curve
57
+
58
def refl_fit(
        q: np.ndarray,
        curve: np.ndarray,
        init_params: np.ndarray,
        prior_sampler: PriorSampler,
        bounds: np.ndarray = None,
        error_bars: np.ndarray = None,
        scale_curve_func=np.log10,
        method: str = 'trf',  # 'lm', 'trf'
        polishing_max_steps: int = None,
        reflectivity_kwargs: dict = None,
        **kwargs
):
    """Fit a reflectivity curve with ``scipy.optimize.curve_fit``.

    The model curve comes from ``prior_sampler.param_model.reflectivity``.
    Data and model are compared after applying ``scale_curve_func``.

    Args:
        q: momentum transfer values.
        curve: measured reflectivity. Values below 1e-12 are clipped so that
            log scaling stays finite.
        init_params: initial parameter vector; clipped into ``bounds``.
        prior_sampler: provides ``param_model.reflectivity``.
        bounds: ``(lower, upper)`` bounds. Parameters whose bounds coincide
            are widened by 1e-6 so the solver accepts them. Bounds are not
            forwarded when ``method='lm'``.
        error_bars: absolute uncertainties of ``curve``. They are propagated
            to log10 space (``sigma / (R ln 10)``) when ``scale_curve_func`` is
            ``np.log10``. For any other scaling they are ignored, and a
            warning is issued.
        scale_curve_func: transformation applied to both curves before fitting.
        method: ``curve_fit`` method, e.g. ``'trf'`` or ``'lm'``.
        polishing_max_steps: optional cap on function evaluations. It is
            passed as ``maxfev`` for ``'lm'`` and as ``max_nfev`` otherwise.
        reflectivity_kwargs: extra keyword arguments for the reflectivity
            model. Floats and numpy arrays are converted to tensors on a copy;
            the caller's dict is not modified.
        **kwargs: forwarded to ``curve_fit``.

    Returns:
        ``(fitted_params, param_errors, fitted_curve)``. ``param_errors`` holds
        the square roots of the covariance diagonal, one per parameter. It is
        all-NaN, with the same shape as ``fitted_params``, when the covariance
        could not be estimated.
    """
    if bounds is not None:
        # introduce a small perturbation for fixed bounds (solver needs lower < upper);
        # float conversion keeps the epsilon from being truncated for int bounds
        epsilon = 1e-6
        lower = np.asarray(bounds[0], dtype=float)
        upper = np.asarray(bounds[1], dtype=float)
        fixed = lower == upper
        lower = np.where(fixed, lower - epsilon, lower)
        upper = np.where(fixed, upper + epsilon, upper)

        init_params = np.clip(init_params, lower, upper)
        if method != 'lm':  # curve_fit rejects bounds with method='lm'
            kwargs['bounds'] = (lower, upper)

    # copy so the tensor conversion below does not mutate the caller's dict
    reflectivity_kwargs = dict(reflectivity_kwargs or {})
    for key, value in list(reflectivity_kwargs.items()):
        if isinstance(value, float):
            reflectivity_kwargs[key] = torch.tensor([[value]], dtype=torch.float64)
        elif isinstance(value, np.ndarray):
            reflectivity_kwargs[key] = torch.tensor(value, dtype=torch.float32).unsqueeze(0)

    curve = np.clip(curve, a_min=1e-12, a_max=None)

    scaled_error_bars = None
    if error_bars is not None:
        if scale_curve_func is np.log10:
            # residuals live in log10 space: sigma_log = sigma / (R ln 10);
            # flattened to match the flattened ydata below
            error_bars = np.clip(error_bars, a_min=1e-20, a_max=None)
            scaled_error_bars = np.reshape(error_bars / (curve * np.log(10)), -1)
        else:
            warnings.warn(
                "error_bars are only propagated for scale_curve_func=np.log10; "
                "they are ignored for the current scaling.",
                stacklevel=2,
            )

    if polishing_max_steps is not None:
        # leastsq ('lm') and least_squares (other methods) name this limit differently
        kwargs['maxfev' if method == 'lm' else 'max_nfev'] = polishing_max_steps

    res = curve_fit(
        f=get_scaled_curve_func(
            scale_curve_func=scale_curve_func,
            prior_sampler=prior_sampler,
            reflectivity_kwargs=reflectivity_kwargs,
        ),
        xdata=q,
        ydata=np.reshape(scale_curve_func(curve), -1),
        p0=init_params,
        sigma=scaled_error_bars,
        absolute_sigma=True,
        method=method,
        **kwargs
    )
    popt, pcov = res[0], res[1]

    fitted_curve = prior_sampler.param_model.reflectivity(
        torch.tensor(q, dtype=torch.float64),
        torch.tensor(popt, dtype=torch.float64).unsqueeze(0),
        **reflectivity_kwargs,
    ).squeeze().numpy()

    # cov matrix --> 1-sigma errors of the parameter estimates
    if pcov is not None and np.ndim(pcov) == 2 and np.all(np.isfinite(pcov)):
        param_errors = np.sqrt(np.diag(pcov))
    else:
        # one NaN per parameter, matching the shape of the successful branch
        param_errors = np.full(np.shape(popt), np.nan)
    return popt, param_errors, fitted_curve
130
+
131
+
132
def get_fit_with_growth(
        q: np.ndarray,
        curve: np.ndarray,
        init_params: np.ndarray,
        bounds: np.ndarray = None,
        init_d_change: float = 0.,
        max_d_change: float = 30.,
        scale_curve_func=np.log10,
        **kwargs
):
    """Fit a curve in which the thickness of layer 0 changes linearly across the q points.

    The standard parameter vector gets one extra trailing entry: the total
    thickness change. Its start value is ``init_d_change``; when ``bounds`` is
    given, it is bounded to ``[0, max_d_change]``. After fitting, the first
    thickness is shifted by half of the fitted change, i.e. it is reported at
    the midpoint of the ramp.

    Returns:
        ``(params, curve)`` as produced by ``standard_refl_fit``, with the
        adjusted first thickness.
    """
    extended_init = np.array([*init_params, init_d_change])

    if bounds is not None:
        growth_column = np.array([0, max_d_change])[..., None]
        bounds = np.concatenate([bounds, growth_column], -1)

    growth_restore = get_restore_params_with_growth_func(q_size=q.size, d_idx=0)
    fitted, fitted_curve = standard_refl_fit(
        q,
        curve,
        extended_init,
        bounds,
        refl_generator=growth_reflectivity,
        restore_params_func=growth_restore,
        scale_curve_func=scale_curve_func,
        **kwargs
    )

    # report the thickness at the middle of the linear change
    fitted[0] += fitted[-1] / 2
    return fitted, fitted_curve
159
+
160
+
161
def fit_refl_curve(q: np.ndarray, curve: np.ndarray,
                   init_params: np.ndarray,
                   bounds: np.ndarray = None,
                   refl_generator=abeles_np,
                   restore_params_func=standard_restore_params,
                   scale_curve_func=np.log10,
                   **kwargs
                   ) -> np.ndarray:
    """Fit by minimizing a scalar loss between scaled curves with ``scipy.optimize.minimize``.

    The loss is built by ``get_fitting_func`` from the scaled data and the
    scaled model curve.

    Note: ``bounds`` is forwarded unchanged to ``minimize``. ``minimize``
    expects one ``(min, max)`` pair per parameter, or a
    ``scipy.optimize.Bounds`` instance. It does not accept the
    ``(lower, upper)`` layout used by the ``curve_fit``-based functions.

    Returns:
        The optimizer's final parameter vector. If the optimizer reports
        failure, a warning carrying its message is issued.
    """
    fitting_func = get_fitting_func(
        q=q, curve=curve,
        refl_generator=refl_generator,
        restore_params_func=restore_params_func,
        scale_curve_func=scale_curve_func,
    )

    res = minimize(fitting_func, init_params, bounds=bounds, **kwargs)

    if not res.success:
        # surface the optimizer's reason; stacklevel attributes the warning to the caller
        warnings.warn(
            f"Minimization did not converge: {getattr(res, 'message', '')}",
            stacklevel=2,
        )
    return res.x
181
+
182
def standard_get_scaled_curve_func(
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
):
    """Build a ``curve_fit``-compatible model ``f(q, *params)``.

    The returned callable maps the flat parameters through
    ``restore_params_func`` and evaluates ``refl_generator`` at ``q``. It
    returns that curve transformed by ``scale_curve_func``.
    """
    def scaled_curve_func(q, *fitted_params):
        param_kwargs = restore_params_func(np.asarray(fitted_params))
        return scale_curve_func(refl_generator(q, **param_kwargs))

    return scaled_curve_func
194
+
195
def get_scaled_curve_func(
        scale_curve_func=np.log10,
        prior_sampler: PriorSampler = None,
        reflectivity_kwargs: dict = None,
):
    """Build a ``curve_fit`` model that evaluates ``prior_sampler.param_model.reflectivity``.

    ``q`` and the parameters are converted to float64 tensors, with the
    parameters given a leading batch axis of size 1. The squeezed model curve
    is passed through ``scale_curve_func`` and flattened to 1-D, which is the
    shape ``curve_fit`` expects. ``prior_sampler`` is only accessed when the
    model is called.
    """
    extra_kwargs = reflectivity_kwargs or {}

    def scaled_curve_func(q, *fitted_params):
        q_t = torch.from_numpy(q).to(torch.float64)
        params_t = torch.tensor(fitted_params, dtype=torch.float64).unsqueeze(0)
        curve_t = prior_sampler.param_model.reflectivity(q_t, params_t, **extra_kwargs)
        scaled = scale_curve_func(curve_t.squeeze().numpy())
        return scaled.reshape(-1)

    return scaled_curve_func
214
+
215
+
216
def get_fitting_func(
        q: np.ndarray,
        curve: np.ndarray,
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
        loss_func=mse_loss,
):
    """Build a scalar objective ``f(params) -> loss`` for ``scipy.optimize.minimize``.

    The measured ``curve`` is scaled once, up front. Each evaluation scales
    the model curve the same way and returns
    ``loss_func(scaled_model, scaled_data)``.
    """
    target = scale_curve_func(curve)

    def fitting_func(fitted_params):
        model_kwargs = restore_params_func(fitted_params)
        model_scaled = scale_curve_func(refl_generator(q, **model_kwargs))
        return loss_func(model_scaled, target)

    return fitting_func
233
+
234
+
235
def restore_masked_params(fixed_params, fixed_mask):
    """Return a restore function that merges fixed and fitted parameters.

    Positions where ``fixed_mask`` is True are taken from ``fixed_params``.
    All other positions are taken from the vector being fitted. The merged
    vector is then split by ``standard_restore_params``.
    """
    def restore_params(fitted_params) -> dict:
        merged = np.empty_like(fixed_mask, dtype=fitted_params.dtype)
        merged[fixed_mask] = fixed_params
        merged[~fixed_mask] = fitted_params
        return standard_restore_params(merged)

    return restore_params
243
+
244
+
245
def base_params2growth(base_params: dict, d_shift: np.ndarray, d_idx: int = 0) -> dict:
    """Expand one parameter set into one row per entry of ``d_shift``.

    Row ``i`` of the returned ``thickness`` array is the base thicknesses with
    ``d_shift[i]`` added to layer ``d_idx``. ``roughness`` and ``sld`` are
    broadcast to the same number of rows as read-only views.
    """
    n_rows = d_shift.size
    thickness = base_params['thickness'][None].repeat(n_rows, 0)
    thickness[:, d_idx] = thickness[:, d_idx] + d_shift

    result = {'thickness': thickness}
    for name in ('roughness', 'sld'):
        values = base_params[name]
        result[name] = np.broadcast_to(values[None], (n_rows, values.size))
    return result
259
+
260
+
261
def get_restore_params_with_growth_func(q_size: int, d_idx: int = 0):
    """Return a restore function for parameter vectors that end with a thickness change.

    The last entry ``delta_d`` is split off. The remainder is decoded by
    ``standard_restore_params``. Layer ``d_idx`` then receives a shift that
    ramps linearly from 0 to ``delta_d`` over ``q_size`` rows.
    """
    def restore_params_with_growth(fitted_params) -> dict:
        base_vector = fitted_params[:-1]
        delta_d = fitted_params[-1]
        decoded = standard_restore_params(base_vector)
        ramp = np.linspace(0, delta_d, q_size)
        return base_params2growth(decoded, ramp, d_idx)

    return restore_params_with_growth
269
+
270
+
271
def growth_reflectivity(q: np.ndarray, **kwargs):
    """Evaluate ``abeles_np`` on ``q`` with an added trailing axis and flatten the result.

    The trailing axis presumably lets ``abeles_np`` pair each q value with
    its own row of batched parameters. Confirm this against ``abeles_np``'s
    broadcasting rules.
    """
    column_q = q[..., np.newaxis]
    return abeles_np(column_q, **kwargs).flatten()