reflectorch 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
reflectorch/inference/scipy_fitter.py
@@ -0,0 +1,364 @@
+
+ import warnings
+ import joblib
+ from joblib import Parallel, delayed
+ import numpy as np
+ from scipy.optimize import minimize, curve_fit
+ import torch
+
+ from reflectorch.data_generation.priors.base import PriorSampler
+ from reflectorch.data_generation.reflectivity import abeles_np
+
+ __all__ = [
+     "standard_refl_fit",
+     "refl_fit",
+     "fit_refl_curve",
+     "restore_masked_params",
+     "get_fit_with_growth",
+     "batch_refl_fit",
+ ]
+
+
+ def standard_restore_params(fitted_params) -> dict:
+     num_layers = (fitted_params.size - 2) // 3
+
+     return dict(
+         thickness=fitted_params[:num_layers],
+         roughness=fitted_params[num_layers:2 * num_layers + 1],
+         sld=fitted_params[2 * num_layers + 1:],
+     )
+
+
+ def mse_loss(curve1, curve2):
+     return np.sum((curve1 - curve2) ** 2)
+
+
+ def standard_refl_fit(
+         q: np.ndarray, curve: np.ndarray,
+         init_params: np.ndarray,
+         bounds: np.ndarray = None,
+         refl_generator=abeles_np,
+         restore_params_func=standard_restore_params,
+         scale_curve_func=np.log10,
+         **kwargs
+ ):
+     if bounds is not None:
+         kwargs['bounds'] = bounds
+         init_params = np.clip(init_params, *bounds)
+
+     res = curve_fit(
+         standard_get_scaled_curve_func(
+             refl_generator=refl_generator,
+             restore_params_func=restore_params_func,
+             scale_curve_func=scale_curve_func,
+         ),
+         q, scale_curve_func(curve),
+         p0=init_params, **kwargs
+     )
+
+     curve = refl_generator(q, **restore_params_func(res[0]))
+     return res[0], curve
+
+
+ def batch_refl_fit(
+         q: np.ndarray,
+         curves: np.ndarray,
+         init_params: np.ndarray,  # (n_curves, n_params)
+         prior_sampler: PriorSampler,
+         bounds: np.ndarray = None,
+         error_bars: np.ndarray = None,
+         scale_curve_func=np.log10,
+         method: str = 'trf',  # 'lm' or 'trf'
+         polishing_max_steps: int = None,
+         reflectivity_kwargs: dict = None,
+         n_jobs: int = -1,
+         verbose: int = 5,
+         **kwargs
+ ):
+     """
+     Fit (i.e. polish an initial fit of) multiple reflectivity curves in parallel using joblib.
+
+     Parameters
+     ----------
+     q : np.ndarray
+         1D array of momentum transfer values (shared by all curves).
+     curves : np.ndarray
+         2D array of reflectivity curves with shape (n_curves, n_q).
+     init_params : np.ndarray
+         2D array of initial parameter guesses, shape (n_curves, n_params).
+     prior_sampler : PriorSampler
+         The prior sampler.
+     bounds : np.ndarray, optional
+         Parameter bounds, either shape (2, n_params) shared by all curves or
+         shape (n_curves, 2, n_params) for per-curve bounds. Default: None.
+     error_bars : np.ndarray, optional
+         Error bars for the curves, shape (n_curves, n_q). Default: None.
+     scale_curve_func : callable, optional
+         Function used to scale the curves. Default: ``np.log10``.
+     method : str, optional
+         The optimization method used for the fit. Default: 'trf'.
+     polishing_max_steps : int, optional
+         Maximum number of function evaluations for the polishing step. Default: None.
+     reflectivity_kwargs : dict, optional
+         Keyword arguments passed to the reflectivity function. Default: None.
+     n_jobs : int, optional
+         Number of jobs to run in parallel. Default: -1 (all CPUs).
+     verbose : int, optional
+         Verbosity level for joblib. Default: 5.
+     **kwargs : dict
+         Extra keyword arguments passed to ``scipy.optimize.curve_fit``.
+
+     Returns
+     -------
+     params_array : np.ndarray
+         Fitted parameter values for each curve, shape (n_curves, n_params).
+     error_bars_array : np.ndarray
+         Error bars of the fitted parameter values, shape (n_curves, n_params).
+     curves_array : np.ndarray
+         Fitted reflectivity curves, shape (n_curves, n_q).
+     """
+     if bounds is not None:
+         if bounds.ndim == 2:
+             bounds = bounds[None].repeat(curves.shape[0], 0)
+         elif bounds.ndim == 3:
+             assert bounds.shape[0] == curves.shape[0], \
+                 f"Expected one (2, n_params) bounds set per curve, got {bounds.shape[0]} sets for {curves.shape[0]} curves"
+         else:
+             raise ValueError(f"Bounds must be a 2D or 3D array, got {bounds.ndim}D array")
+     else:
+         bounds = [None] * curves.shape[0]
+
+     if error_bars is None:
+         error_bars = [None] * curves.shape[0]
+
+     results = Parallel(n_jobs=n_jobs, verbose=verbose)(
+         delayed(refl_fit)(
+             q=q, curve=curve, init_params=init_param,
+             bounds=bound,
+             prior_sampler=prior_sampler,
+             error_bars=error_bar,
+             method=method,
+             scale_curve_func=scale_curve_func,
+             polishing_max_steps=polishing_max_steps,
+             reflectivity_kwargs=reflectivity_kwargs,
+             **kwargs
+         )
+         for curve, init_param, bound, error_bar in zip(curves, init_params, bounds, error_bars)
+     )
+
+     params_array, error_bars_array, curves_array = zip(*results)
+     return np.array(params_array), np.array(error_bars_array), np.array(curves_array)
+
+
+ def refl_fit(
+         q: np.ndarray,
+         curve: np.ndarray,
+         init_params: np.ndarray,
+         prior_sampler: PriorSampler,
+         bounds: np.ndarray = None,
+         error_bars: np.ndarray = None,
+         scale_curve_func=np.log10,
+         method: str = 'trf',  # 'lm' or 'trf'
+         polishing_max_steps: int = None,
+         reflectivity_kwargs: dict = None,
+         **kwargs
+ ):
+     if bounds is not None:
+         if bounds.ndim != 2:
+             raise ValueError(f"Bounds must be a 2D array, got {bounds.ndim}D array")
+         # introduce a small perturbation for fixed (zero-width) bounds
+         epsilon = 1e-6
+         adjusted_bounds = bounds.copy()
+
+         for i in range(bounds.shape[1]):
+             if bounds[0, i] == bounds[1, i]:
+                 adjusted_bounds[0, i] -= epsilon
+                 adjusted_bounds[1, i] += epsilon
+
+         init_params = np.clip(init_params, *adjusted_bounds)
+         if method != 'lm':
+             kwargs['bounds'] = adjusted_bounds
+
+     reflectivity_kwargs = reflectivity_kwargs or {}
+     for key, value in reflectivity_kwargs.items():
+         if isinstance(value, float):
+             reflectivity_kwargs[key] = torch.tensor([[value]], dtype=torch.float64)
+         elif isinstance(value, np.ndarray):
+             reflectivity_kwargs[key] = torch.tensor(value, dtype=torch.float32).unsqueeze(0)
+
+     curve = np.clip(curve, a_min=1e-12, a_max=None)
+
+     if error_bars is not None and scale_curve_func == np.log10:
+         # propagate the error bars through log10: d(log10 R) = dR / (R ln 10)
+         error_bars = np.clip(error_bars, a_min=1e-20, a_max=None)
+         scaled_error_bars = error_bars / (curve * np.log(10))
+     else:
+         scaled_error_bars = None
+
+     if polishing_max_steps is not None:
+         if method == 'lm':
+             kwargs['maxfev'] = polishing_max_steps
+         else:
+             kwargs['max_nfev'] = polishing_max_steps
+
+     res = curve_fit(
+         f=get_scaled_curve_func(
+             scale_curve_func=scale_curve_func,
+             prior_sampler=prior_sampler,
+             reflectivity_kwargs=reflectivity_kwargs,
+         ),
+         xdata=q,
+         ydata=scale_curve_func(curve).reshape(-1),
+         p0=init_params,
+         sigma=scaled_error_bars,
+         absolute_sigma=True,
+         method=method,
+         **kwargs
+     )
+
+     curve = prior_sampler.param_model.reflectivity(
+         torch.tensor(q, dtype=torch.float64),
+         torch.tensor(res[0], dtype=torch.float64).unsqueeze(0),
+         **reflectivity_kwargs,
+     ).squeeze().numpy()
+
+     # diagonal of the covariance matrix --> variance of each parameter estimate
+     if res[1] is not None and np.ndim(res[1]) == 2 and np.all(np.isfinite(res[1])):
+         pol_param_errs = np.sqrt(np.diag(res[1]))
+     else:
+         pol_param_errs = np.full_like(res[0], np.nan)
+     return res[0], pol_param_errs, curve
+
+
+ def get_fit_with_growth(
+         q: np.ndarray,
+         curve: np.ndarray,
+         init_params: np.ndarray,
+         bounds: np.ndarray = None,
+         init_d_change: float = 0.,
+         max_d_change: float = 30.,
+         scale_curve_func=np.log10,
+         **kwargs
+ ):
+     # append the thickness change over the measurement as an extra fit parameter
+     init_params = np.array(list(init_params) + [init_d_change])
+     if bounds is not None:
+         bounds = np.concatenate([bounds, np.array([0, max_d_change])[..., None]], -1)
+
+     params, curve = standard_refl_fit(
+         q,
+         curve,
+         init_params,
+         bounds,
+         refl_generator=growth_reflectivity,
+         restore_params_func=get_restore_params_with_growth_func(q_size=q.size, d_idx=0),
+         scale_curve_func=scale_curve_func,
+         **kwargs
+     )
+
+     # report the growing layer's thickness at mid-measurement: initial value plus half the fitted change
+     params[0] += params[-1] / 2
+     return params, curve
+
+
+ def fit_refl_curve(q: np.ndarray, curve: np.ndarray,
+                    init_params: np.ndarray,
+                    bounds: np.ndarray = None,
+                    refl_generator=abeles_np,
+                    restore_params_func=standard_restore_params,
+                    scale_curve_func=np.log10,
+                    **kwargs
+                    ) -> np.ndarray:
+     fitting_func = get_fitting_func(
+         q=q, curve=curve,
+         refl_generator=refl_generator,
+         restore_params_func=restore_params_func,
+         scale_curve_func=scale_curve_func,
+     )
+
+     res = minimize(fitting_func, init_params, bounds=bounds, **kwargs)
+
+     if not res.success:
+         warnings.warn("Minimization did not converge.")
+     return res.x
+
+
+ def standard_get_scaled_curve_func(
+         refl_generator=abeles_np,
+         restore_params_func=standard_restore_params,
+         scale_curve_func=np.log10,
+ ):
+     def scaled_curve_func(q, *fitted_params):
+         fitted_params = restore_params_func(np.asarray(fitted_params))
+         fitted_curve = refl_generator(q, **fitted_params)
+         scaled_curve = scale_curve_func(fitted_curve)
+         return scaled_curve
+
+     return scaled_curve_func
+
+
+ def get_scaled_curve_func(
+         scale_curve_func=np.log10,
+         prior_sampler: PriorSampler = None,
+         reflectivity_kwargs: dict = None,
+ ):
+     reflectivity_kwargs = reflectivity_kwargs or {}
+
+     def scaled_curve_func(q, *fitted_params):
+         q_tensor = torch.from_numpy(q).to(torch.float64)
+         fitted_params_tensor = torch.tensor(fitted_params, dtype=torch.float64).unsqueeze(0)
+
+         fitted_curve_tensor = prior_sampler.param_model.reflectivity(q_tensor, fitted_params_tensor, **reflectivity_kwargs)
+         fitted_curve = fitted_curve_tensor.squeeze().numpy()
+
+         scaled_curve = scale_curve_func(fitted_curve)
+         return scaled_curve.reshape(-1)
+
+     return scaled_curve_func
+
+
+ def get_fitting_func(
+         q: np.ndarray,
+         curve: np.ndarray,
+         refl_generator=abeles_np,
+         restore_params_func=standard_restore_params,
+         scale_curve_func=np.log10,
+         loss_func=mse_loss,
+ ):
+     scaled_curve = scale_curve_func(curve)
+
+     def fitting_func(fitted_params):
+         fitted_params = restore_params_func(fitted_params)
+         fitted_curve = refl_generator(q, **fitted_params)
+         loss = loss_func(scale_curve_func(fitted_curve), scaled_curve)
+         return loss
+
+     return fitting_func
+
+
+ def restore_masked_params(fixed_params, fixed_mask):
+     def restore_params(fitted_params) -> dict:
+         params = np.empty_like(fixed_mask).astype(fitted_params.dtype)
+         params[fixed_mask] = fixed_params
+         params[~fixed_mask] = fitted_params
+         return standard_restore_params(params)
+
+     return restore_params
+
+
+ def base_params2growth(base_params: dict, d_shift: np.ndarray, d_idx: int = 0) -> dict:
+     d_init = base_params['thickness'][None]
+     q_size = d_shift.size
+     d = d_init.repeat(q_size, 0)
+     d[:, d_idx] = d[:, d_idx] + d_shift
+
+     roughness = np.broadcast_to(base_params['roughness'][None], (q_size, base_params['roughness'].size))
+     sld = np.broadcast_to(base_params['sld'][None], (q_size, base_params['sld'].size))
+
+     return {
+         'thickness': d,
+         'roughness': roughness,
+         'sld': sld,
+     }
+
+
+ def get_restore_params_with_growth_func(q_size: int, d_idx: int = 0):
+     def restore_params_with_growth(fitted_params) -> dict:
+         fitted_params, delta_d = fitted_params[:-1], fitted_params[-1]
+         base_params = standard_restore_params(fitted_params)
+         d_shift = np.linspace(0, delta_d, q_size)
+         return base_params2growth(base_params, d_shift, d_idx)
+
+     return restore_params_with_growth
+
+
+ def growth_reflectivity(q: np.ndarray, **kwargs):
+     return abeles_np(q[..., None], **kwargs).flatten()
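
For orientation, here is a minimal usage sketch of the standard fitting path in this file (illustrative only, not part of the package diff; the module path, layer count, and parameter values are assumptions). It simulates a one-layer curve with abeles_np and recovers the parameters with standard_refl_fit; the flat parameter vector follows the standard_restore_params layout of n thicknesses, n + 1 roughnesses, and n + 1 SLDs.

    # Hypothetical usage sketch -- module path and numbers are assumptions.
    import numpy as np
    from reflectorch.inference.scipy_fitter import standard_refl_fit
    from reflectorch.data_generation.reflectivity import abeles_np

    q = np.linspace(0.02, 0.15, 128)             # momentum transfer grid
    true = np.array([120.0,                      # thickness (n = 1 layer)
                     3.0, 5.0,                   # roughnesses (n + 1)
                     4.5, 2.07])                 # SLDs (n + 1)
    curve = abeles_np(q, thickness=true[:1], roughness=true[1:3], sld=true[3:])

    init = true * 0.9                            # perturbed initial guess
    bounds = np.array([true * 0.5, true * 1.5])  # (2, n_params) box constraints
    fitted_params, fitted_curve = standard_refl_fit(q, curve, init, bounds=bounds)

The fitted parameters come back in the same flat layout, so standard_restore_params(fitted_params) turns them back into a thickness/roughness/sld dict.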
reflectorch/inference/torch_fitter.py
@@ -0,0 +1,87 @@
+ from tqdm import trange
+
+ import torch
+ from torch import nn, Tensor
+
+ from reflectorch.data_generation import LogLikelihood, reflectivity, PriorSampler
+
+
+ class ReflGradientFit(object):
+     """Directly optimizes the thin-film parameters using a PyTorch optimizer
+
+     Args:
+         q (Tensor): the q positions
+         exp_curve (Tensor): the experimental reflectivity curve
+         prior_sampler (PriorSampler): the prior sampler
+         params (Tensor): the initial thin-film parameters
+         fit_indices (Tensor): the indices of the thin-film parameters which are to be fitted
+         sigmas (Tensor, optional): error bars of the reflectivity curve; if not provided, they are derived from ``rel_err`` and ``abs_err``. Defaults to None.
+         optim_cls (Type[torch.optim.Optimizer], optional): the PyTorch optimizer class. Defaults to None.
+         lr (float, optional): the learning rate. Defaults to 1e-2.
+         rel_err (float, optional): the relative error in the reflectivity curve. Defaults to 0.1.
+         abs_err (float, optional): the absolute error in the reflectivity curve. Defaults to 1e-7.
+     """
+     def __init__(self,
+                  q: Tensor,
+                  exp_curve: Tensor,
+                  prior_sampler: PriorSampler,
+                  params: Tensor,
+                  fit_indices: Tensor,
+                  sigmas: Tensor = None,
+                  optim_cls=None,
+                  lr: float = 1e-2,
+                  rel_err: float = 0.1,
+                  abs_err: float = 1e-7,
+                  ):
+         self.q = q
+
+         if sigmas is None:
+             sigmas = exp_curve * rel_err + abs_err
+
+         self.likelihood = LogLikelihood(q, exp_curve, prior_sampler, sigmas)
+
+         self.num_layers = params.shape[-1] // 3
+         self.fit_indices = fit_indices
+         self.init_params = params.clone()
+         self.params_to_fit = nn.Parameter(self.init_params[fit_indices].clone())
+
+         optim_cls = optim_cls or torch.optim.Adam
+         self.optim = optim_cls([self.params_to_fit], lr)
+
+         self.losses = []
+
+     @property
+     def params(self):
+         params = self.init_params.clone()
+         params[self.fit_indices] = self.params_to_fit
+         return params
+
+     def calc_log_likelihood(self):
+         return self.likelihood.calc_log_likelihood(self.refl())
+
+     def calc_log_prob_loss(self):
+         return -self.calc_log_likelihood().mean()
+
+     def refl(self):
+         d, sigma, rho = torch.split(self.params, [self.num_layers, self.num_layers + 1, self.num_layers + 1], -1)
+         return reflectivity(self.q, d, sigma, rho)
+
+     def run(self, num_iterations: int = 500, disable_tqdm: bool = False):
+         """Runs the optimization process
+
+         Args:
+             num_iterations (int, optional): the number of iterations the optimization is run for. Defaults to 500.
+             disable_tqdm (bool, optional): whether to disable the progress bar. Defaults to False.
+         """
+         pbar = trange(num_iterations, disable=disable_tqdm)
+
+         for _ in pbar:
+             self.optim.zero_grad()
+             loss = self.calc_log_prob_loss()
+             loss.backward()
+             self.optim.step()
+             self.losses.append(loss.item())
+             pbar.set_description(f'Loss = {loss.item():.2e}')
+
+     def clear(self):
+         self.losses.clear()
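
ReflGradientFit is essentially maximum-likelihood fitting by gradient descent: wrap the free parameters in nn.Parameter, score the simulated curve against the measurement with a Gaussian log-likelihood, and step a PyTorch optimizer. Below is a self-contained toy sketch of that same loop with a stand-in forward model in place of the reflectivity simulator (everything here is illustrative; none of it is reflectorch API).

    import torch
    from torch import nn

    q = torch.linspace(0.02, 0.15, 64)
    true = torch.tensor([120.0, 4.5])

    def model(q, p):
        # toy stand-in for reflectivity(q, d, sigma, rho)
        return torch.exp(-q * p[0] / 50.0) * p[1]

    exp_curve = model(q, true)
    sigmas = exp_curve * 0.1 + 1e-7       # mirrors the rel_err/abs_err defaults above

    params = nn.Parameter(true * 0.8)     # perturbed starting point
    optim = torch.optim.Adam([params], lr=1e-2)

    for _ in range(500):
        optim.zero_grad()
        resid = (model(q, params) - exp_curve) / sigmas
        loss = 0.5 * (resid ** 2).mean()  # Gaussian NLL up to an additive constant
        loss.backward()
        optim.step()

Minimizing this loss is equivalent to maximizing the Gaussian log-likelihood, which is what calc_log_prob_loss does via LogLikelihood.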
reflectorch/ml/__init__.py
@@ -0,0 +1,32 @@
+ from reflectorch.ml.basic_trainer import *
+ from reflectorch.ml.callbacks import *
+ from reflectorch.ml.trainers import *
+ from reflectorch.ml.loggers import *
+ from reflectorch.ml.schedulers import *
+ from reflectorch.ml.dataloaders import *
+
+ __all__ = [
+     'Trainer',
+     'TrainerCallback',
+     'DataLoader',
+     'PeriodicTrainerCallback',
+     'SaveBestModel',
+     'LogLosses',
+     'Logger',
+     'Loggers',
+     'PrintLogger',
+     'TensorBoardLogger',
+     'ScheduleBatchSize',
+     'ScheduleLR',
+     'StepLR',
+     'CyclicLR',
+     'LogCyclicLR',
+     'ReduceLROnPlateau',
+     'OneCycleLR',
+     'CosineAnnealingWithWarmup',
+     'ReflectivityDataLoader',
+     'MultilayerDataLoader',
+     'RealTimeSimTrainer',
+     'DenoisingAETrainer',
+     'PointEstimatorTrainer',
+ ]