reflectorch 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (83) hide show
  1. reflectorch/__init__.py +23 -0
  2. reflectorch/data_generation/__init__.py +130 -0
  3. reflectorch/data_generation/dataset.py +196 -0
  4. reflectorch/data_generation/likelihoods.py +86 -0
  5. reflectorch/data_generation/noise.py +371 -0
  6. reflectorch/data_generation/priors/__init__.py +66 -0
  7. reflectorch/data_generation/priors/base.py +61 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
  9. reflectorch/data_generation/priors/independent_priors.py +201 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +110 -0
  12. reflectorch/data_generation/priors/no_constraints.py +212 -0
  13. reflectorch/data_generation/priors/parametric_models.py +767 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
  15. reflectorch/data_generation/priors/params.py +258 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +306 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +377 -0
  19. reflectorch/data_generation/priors/utils.py +124 -0
  20. reflectorch/data_generation/process_data.py +47 -0
  21. reflectorch/data_generation/q_generator.py +232 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +56 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +81 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +58 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +123 -0
  28. reflectorch/data_generation/scale_curves.py +118 -0
  29. reflectorch/data_generation/smearing.py +67 -0
  30. reflectorch/data_generation/utils.py +154 -0
  31. reflectorch/extensions/__init__.py +6 -0
  32. reflectorch/extensions/jupyter/__init__.py +12 -0
  33. reflectorch/extensions/jupyter/callbacks.py +40 -0
  34. reflectorch/extensions/matplotlib/__init__.py +11 -0
  35. reflectorch/extensions/matplotlib/losses.py +38 -0
  36. reflectorch/inference/__init__.py +22 -0
  37. reflectorch/inference/inference_model.py +734 -0
  38. reflectorch/inference/multilayer_fitter.py +171 -0
  39. reflectorch/inference/multilayer_inference_model.py +193 -0
  40. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  41. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  42. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  43. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  44. reflectorch/inference/preprocess_exp/interpolation.py +16 -0
  45. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  46. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  47. reflectorch/inference/record_time.py +43 -0
  48. reflectorch/inference/sampler_solution.py +56 -0
  49. reflectorch/inference/scipy_fitter.py +171 -0
  50. reflectorch/inference/torch_fitter.py +87 -0
  51. reflectorch/ml/__init__.py +37 -0
  52. reflectorch/ml/basic_trainer.py +286 -0
  53. reflectorch/ml/callbacks.py +86 -0
  54. reflectorch/ml/dataloaders.py +27 -0
  55. reflectorch/ml/loggers.py +38 -0
  56. reflectorch/ml/schedulers.py +246 -0
  57. reflectorch/ml/trainers.py +126 -0
  58. reflectorch/ml/utils.py +9 -0
  59. reflectorch/models/__init__.py +22 -0
  60. reflectorch/models/activations.py +50 -0
  61. reflectorch/models/encoders/__init__.py +27 -0
  62. reflectorch/models/encoders/conv_encoder.py +211 -0
  63. reflectorch/models/encoders/conv_res_net.py +119 -0
  64. reflectorch/models/encoders/fno.py +127 -0
  65. reflectorch/models/encoders/transformers.py +56 -0
  66. reflectorch/models/networks/__init__.py +18 -0
  67. reflectorch/models/networks/mlp_networks.py +256 -0
  68. reflectorch/models/networks/residual_net.py +131 -0
  69. reflectorch/paths.py +33 -0
  70. reflectorch/runs/__init__.py +35 -0
  71. reflectorch/runs/config.py +31 -0
  72. reflectorch/runs/slurm_utils.py +99 -0
  73. reflectorch/runs/train.py +85 -0
  74. reflectorch/runs/utils.py +300 -0
  75. reflectorch/test_config.py +4 -0
  76. reflectorch/train.py +4 -0
  77. reflectorch/train_on_cluster.py +4 -0
  78. reflectorch/utils.py +74 -0
  79. reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
  80. reflectorch-1.0.0.dist-info/METADATA +115 -0
  81. reflectorch-1.0.0.dist-info/RECORD +83 -0
  82. reflectorch-1.0.0.dist-info/WHEEL +5 -0
  83. reflectorch-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,123 @@
1
+ # -*- coding: utf-8 -*-
2
+ from math import pi, sqrt, log
3
+
4
+ import torch
5
+ from torch import Tensor
6
+
7
+ from reflectorch.data_generation.reflectivity.abeles import abeles
8
+ from torch.nn.functional import conv1d, pad
9
+
10
+
11
def abeles_constant_smearing(
        q: Tensor,
        thickness: Tensor,
        roughness: Tensor,
        sld: Tensor,
        dq: Tensor = None,
        gauss_num: int = 51,
        constant_dq: bool = True,
        abeles_func=None,
):
    """Compute resolution-smeared reflectivity curves via Gaussian convolution.

    The reflectivity is evaluated on an auxiliary q grid derived from ``q``
    and ``dq``, every curve is convolved with its own Gaussian kernel, and the
    result is linearly interpolated back onto the requested ``q`` values.

    Args:
        q (Tensor): the q values; the first dimension is expanded to the
            number of curves when it differs from it.
        thickness (Tensor): forwarded unchanged to ``abeles_func``.
        roughness (Tensor): forwarded unchanged to ``abeles_func``.
        sld (Tensor): forwarded unchanged to ``abeles_func``.
        dq (Tensor): the resolutions used to build both the auxiliary q grid
            and the kernels. NOTE(review): the default is None, but the helpers
            do arithmetic on it -- a tensor appears to be required; confirm.
        gauss_num (int, optional): number of sample points per kernel.
        constant_dq (bool, optional): selects the constant-dq or the linear-dq
            auxiliary grid.
        abeles_func (callable, optional): reflectivity function called as
            ``abeles_func(q, thickness, roughness, sld)``. Falls back to
            ``abeles`` when not provided.

    Returns:
        Tensor: the smeared curves evaluated at ``q``.
    """
    reflectivity_fn = abeles_func if abeles_func else abeles

    fine_q = _get_q_axes(q, dq, gauss_num, constant_dq=constant_dq)
    gauss_kernels = _get_t_gauss_kernels(dq, gauss_num)

    fine_curves = reflectivity_fn(fine_q, thickness, roughness, sld)

    # Pad by half a kernel on both sides so the convolution keeps the length.
    half_kernel = (gauss_kernels.shape[-1] - 1) // 2
    padded_curves = pad(fine_curves[None], (half_kernel, half_kernel), 'reflect')

    # groups == number of kernels: each curve is convolved only with its own kernel.
    convolved = conv1d(
        padded_curves, gauss_kernels[:, None], groups=gauss_kernels.shape[0],
    )[0]

    num_curves = convolved.shape[0]
    if q.shape[0] != num_curves:
        q = q.expand(num_curves, *q.shape[1:])

    return _batch_linear_interp1d(fine_q, convolved, q)
38
+
39
+
40
+ _FWHM = 2 * sqrt(2 * log(2.0))
41
+ _2PI_SQRT = 1. / sqrt(2 * pi)
42
+
43
+
44
+ def _batch_linspace(start: Tensor, end: Tensor, num: int):
45
+ return torch.linspace(0, 1, int(num), device=end.device, dtype=end.dtype)[None] * (end - start) + start
46
+
47
+
48
def _torch_gauss(x, s):
    """Evaluate a zero-centred normal density with standard deviation ``s`` at ``x``.

    ``_2PI_SQRT`` holds the constant ``1 / sqrt(2 * pi)``.
    """
    exponent = -0.5 * x ** 2 / s / s
    return _2PI_SQRT / s * torch.exp(exponent)
50
+
51
+
52
def _get_t_gauss_kernels(resolutions: Tensor, gaussnum: int = 51):
    """Build one discretised Gaussian kernel per resolution value.

    Each kernel is sampled on ``gaussnum`` points spanning
    ``[-1.7 * resolution, 1.7 * resolution]``. The resolution is divided by
    ``_FWHM`` to obtain the standard deviation, and the samples are multiplied
    by the grid spacing.
    """
    half_width = 1.7 * resolutions
    grid = _batch_linspace(-half_width, half_width, gaussnum)
    spacing = (grid[:, 1] - grid[:, 0])[:, None]
    return _torch_gauss(grid, resolutions / _FWHM) * spacing
56
+
57
+
58
def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool = True):
    """Dispatch to the constant-dq or linear-dq builder of the auxiliary q grid."""
    build_axes = _get_q_axes_for_constant_dq if constant_dq else _get_q_axes_for_linear_dq
    return build_axes(q, resolutions, gaussnum)
63
+
64
+
65
def _get_q_axes_for_linear_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51):
    """Build a log-spaced auxiliary q grid per curve for q-proportional resolution.

    Args:
        q (Tensor): q values; min and max are taken along dimension 1.
        resolutions (Tensor): per-curve relative resolutions. Presumably of
            shape (batch, 1), as in the constant-dq variant -- TODO confirm
            against callers.
        gaussnum (int, optional): number of kernel sample points; sets the
            density of the auxiliary grid.

    Returns:
        Tensor: the auxiliary q grid, one row per curve. Rows that need fewer
        points than the longest row are padded by ``_batch_linspace_with_padding``.
    """
    gaussgpoint = (gaussnum - 1) / 2

    # Keep a trailing axis on the per-curve bounds so they line up row by row
    # with (batch, 1) resolutions, exactly like the constant-dq variant does.
    # Without it a (batch,) bound broadcasts against (batch, 1) resolutions
    # into a (batch, batch) matrix, which breaks for batch sizes above one.
    lowq = torch.clamp_min_(q.min(1).values, 1e-6)[:, None]
    highq = q.max(1).values[:, None]

    # Grid limits in log10 space, extended beyond the data range.
    start = torch.log10(lowq) - 6 * resolutions / _FWHM
    end = torch.log10(highq * (1 + 6 * resolutions / _FWHM))

    # Number of grid points each curve needs for its own resolution.
    interpnums = torch.abs(
        (torch.abs(end - start)) / (1.7 * resolutions / _FWHM / gaussgpoint)
    ).round().to(int)

    q_lin = 10 ** _batch_linspace_with_padding(start, end, interpnums)

    return q_lin
81
+
82
+
83
def _get_q_axes_for_constant_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51) -> Tensor:
    """Build a linearly spaced auxiliary q grid per curve for constant resolution.

    The grid covers the q range of each curve extended by ``1.7 * resolution``
    on both sides, with a spacing of ``1.7 * resolution / ((gaussnum - 1) / 2)``.
    Values below 1e-6 are clamped to 1e-6.
    """
    half_points = (gaussnum - 1) / 2

    q_low = q.min(1).values[:, None] - resolutions * 1.7
    q_high = q.max(1).values[:, None] + resolutions * 1.7

    spacing = 1.7 * resolutions / half_points
    point_counts = torch.abs(torch.abs(q_high - q_low) / spacing).round().to(int)

    grid = _batch_linspace_with_padding(q_low, q_high, point_counts)
    return torch.clamp_min_(grid, 1e-6)
97
+
98
+
99
def _batch_linspace_with_padding(start: Tensor, end: Tensor, nums: Tensor) -> Tensor:
    """Batched ``linspace`` where every row may request a different point count.

    All rows are returned with the length of the longest one. A row asking for
    fewer points starts its unit grid at a negative value; that part is
    clamped to zero, so the leading entries of the row repeat ``start``.
    """
    longest = nums.max().int().item()

    unit_step = 1 / (nums - 1)
    first = unit_step * (nums - longest)
    last = torch.ones_like(unit_step)

    unit_grid = torch.clamp_min_(_batch_linspace(first, last, longest), 0)

    return unit_grid * (end - start) + start
109
+
110
+
111
+ def _batch_linear_interp1d(x: Tensor, y: Tensor, x_new: Tensor) -> Tensor:
112
+ eps = torch.finfo(y.dtype).eps
113
+
114
+ ind = torch.searchsorted(x.contiguous(), x_new.contiguous())
115
+
116
+ ind = torch.clamp_(ind - 1, 0, x.shape[-1] - 2)
117
+ slopes = (y[..., 1:] - y[..., :-1]) / (eps + (x[..., 1:] - x[..., :-1]))
118
+ ind_y = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * y.shape[1]
119
+ ind_slopes = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * slopes.shape[1]
120
+
121
+ y_new = y.flatten()[ind_y] + slopes.flatten()[ind_slopes] * (x_new - x.flatten()[ind_y])
122
+
123
+ return y_new
@@ -0,0 +1,118 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ from torch import Tensor
11
+
12
+ from reflectorch.data_generation.priors import PriorSampler
13
+ from reflectorch.paths import SAVED_MODELS_DIR
14
+
15
+
16
class CurvesScaler(object):
    """Base class for curve scalers.

    Subclasses implement ``scale`` and its inverse ``restore``.
    """

    def scale(self, curves: Tensor):
        """Scale reflectivity curves; must be overridden by subclasses."""
        raise NotImplementedError()

    def restore(self, curves: Tensor):
        """Undo ``scale``; must be overridden by subclasses."""
        raise NotImplementedError()
23
+
24
+
25
class LogAffineCurvesScaler(CurvesScaler):
    r"""Curve scaler which scales the reflectivity curves according to the logarithmic affine transformation:
    :math:`\log_{10}(R + eps) \cdot weight + bias`.

    The docstring is a raw string because the formula contains backslashes
    that are not valid Python escape sequences.

    Args:
        weight (float): multiplication factor in the transformation
        bias (float): addition term in the transformation
        eps (float): sets the minimum intensity value of the reflectivity curves which is considered
    """
    def __init__(self, weight: float = 0.1, bias: float = 0.5, eps: float = 1e-10):
        self.weight = weight
        self.bias = bias
        self.eps = eps

    def scale(self, curves: Tensor):
        """scales the reflectivity curves to a ML-friendly range

        Args:
            curves (Tensor): original reflectivity curves

        Returns:
            Tensor: reflectivity curves scaled to a ML-friendly range
        """
        return torch.log10(curves + self.eps) * self.weight + self.bias

    def restore(self, curves: Tensor):
        """restores the physical reflectivity curves

        This is the algebraic inverse of :meth:`scale`.

        Args:
            curves (Tensor): scaled reflectivity curves

        Returns:
            Tensor: reflectivity curves restored to the physical range
        """
        return 10 ** ((curves - self.bias) / self.weight) - self.eps
60
+
61
+
62
class MeanNormalizationCurvesScaler(CurvesScaler):
    """Curve scaler which normalizes the reflectivity curves by a precomputed mean curve.

    Args:
        path (str, optional): file name of the stored mean (resolved by ``get_path``); it is
            only read when ``curves_mean`` is None. Defaults to None.
        curves_mean (Tensor, optional): the precomputed mean of the curves. Defaults to None.
        device (torch.device, optional): device the mean is moved to. Defaults to 'cuda'.
    """

    def __init__(self, path: str = None, curves_mean: Tensor = None, device: torch.device = 'cuda'):
        mean = curves_mean if curves_mean is not None else torch.load(self.get_path(path))
        self.curves_mean = mean.to(device)

    def scale(self, curves: Tensor):
        """Scale the curves as ``curves / mean - 1``.

        The stored mean is first converted to the dtype and device of ``curves``
        and kept in that form.

        Args:
            curves (Tensor): original reflectivity curves

        Returns:
            Tensor: the scaled reflectivity curves
        """
        mean = self.curves_mean = self.curves_mean.to(curves)
        return curves / mean - 1

    def restore(self, curves: Tensor):
        """Invert ``scale``: ``(curves + 1) * mean``.

        The stored mean is first converted to the dtype and device of ``curves``
        and kept in that form.

        Args:
            curves (Tensor): scaled reflectivity curves

        Returns:
            Tensor: reflectivity curves restored to the physical range
        """
        mean = self.curves_mean = self.curves_mean.to(curves)
        return (curves + 1) * mean

    @staticmethod
    def save(prior_sampler: PriorSampler, q: Tensor, path: str, num: int = 16384):
        """Sample ``num`` parameter sets, average their reflectivity curves and save the mean.

        Args:
            prior_sampler (PriorSampler): the prior sampler
            q (Tensor): the q values
            path (str): file name under which the mean is stored (resolved by ``get_path``)
            num (int, optional): the number of curves used to compute the mean. Defaults to 16384.
        """
        sampled_params = prior_sampler.sample(num)
        simulated = sampled_params.reflectivity(q, log=False)
        mean_curve = simulated.mean(0).cpu()
        torch.save(mean_curve, MeanNormalizationCurvesScaler.get_path(path))

    @staticmethod
    def get_path(path: str) -> Path:
        """Return ``SAVED_MODELS_DIR / path``, appending '.pt' when it is missing."""
        file_name = path if path.endswith('.pt') else path + '.pt'
        return SAVED_MODELS_DIR / file_name
@@ -0,0 +1,67 @@
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ from reflectorch.data_generation.priors.parametric_subpriors import BasicParams
5
+
6
+
7
class Smearing(object):
    """Class which applies resolution smearing to the reflectivity curves.
    The intensity at a q point will be the average of the intensities of neighbouring q points, weighted by a gaussian profile.

    Args:
        sigma_range (tuple, optional): the range for sampling the standard deviation of the gaussians. Defaults to (1e-4, 5e-3).
        constant_dq (bool, optional): whether the smearing is constant for each q point. Defaults to True.
        gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
        share_smeared (float, optional): the share of curves in the batch for which the resolution smearing is applied. Defaults to 0.2.
    """
    def __init__(self,
                 sigma_range: tuple = (1e-4, 5e-3),
                 constant_dq: bool = True,
                 gauss_num: int = 31,
                 share_smeared: float = 0.2,
                 ):
        self.sigma_min, self.sigma_max = sigma_range
        self.sigma_delta = self.sigma_max - self.sigma_min
        self.constant_dq = constant_dq
        self.gauss_num = gauss_num
        self.share_smeared = share_smeared

    def __repr__(self):
        # Both opening parentheses are closed so the output is well-formed.
        return f'Smearing(({self.sigma_min}, {self.sigma_max}))'

    def generate_resolutions(self, batch_size: int, device=None, dtype=None):
        """Sample resolutions and choose which curves of the batch get smeared.

        Args:
            batch_size (int): number of curves in the batch
            device (optional): device of the created tensors
            dtype (optional): dtype of the sampled resolutions

        Returns:
            tuple: ``(dq, indices)`` where ``dq`` has shape (num_smeared, 1) with values drawn
            uniformly from the sigma range and ``indices`` is a boolean mask of length
            ``batch_size`` marking the smeared curves. ``(None, None)`` is returned when
            ``int(batch_size * share_smeared)`` is zero.
        """
        num_smeared = int(batch_size * self.share_smeared)
        if not num_smeared:
            return None, None

        dq = torch.rand(num_smeared, 1, device=device, dtype=dtype) * self.sigma_delta + self.sigma_min

        # Mark a random subset of exactly num_smeared curves.
        indices = torch.zeros(batch_size, device=device, dtype=torch.bool)
        indices[torch.randperm(batch_size, device=device)[:num_smeared]] = True
        return dq, indices

    def get_curves(self, q_values: Tensor, params: "BasicParams"):
        """Simulate the curves of a batch, smearing a random share of them.

        Args:
            q_values (Tensor): the q values; a 2D tensor with more than one row is
                treated as per-curve q values, anything else is shared by all curves
            params (BasicParams): the batch of parameters to simulate

        Returns:
            Tensor: the reflectivity curves (not on a logarithmic scale)
        """
        dq, indices = self.generate_resolutions(params.batch_size, device=params.device, dtype=params.dtype)

        if dq is None:
            return params.reflectivity(q_values, log=False)

        per_curve_q = q_values.dim() == 2 and q_values.shape[0] > 1

        def select_q(mask):
            # Per-curve q values are subset together with the curves.
            return q_values[mask] if per_curve_q else q_values

        curves = torch.empty(params.batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)

        if (~indices).sum().item():
            curves[~indices] = params[~indices].reflectivity(select_q(~indices), log=False)

        if indices.sum().item():
            curves[indices] = params[indices].reflectivity(
                select_q(indices), dq=dq, constant_dq=self.constant_dq, log=False, gauss_num=self.gauss_num
            )

        return curves
@@ -0,0 +1,154 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Union
8
+ from math import sqrt, pi, log10
9
+
10
+ import torch
11
+ from torch import Tensor
12
+
13
+ __all__ = [
14
+ "get_reversed_params",
15
+ "get_density_profiles",
16
+ "uniform_sampler",
17
+ "logdist_sampler",
18
+ "triangular_sampler",
19
+ "get_param_labels",
20
+ "get_d_rhos",
21
+ "get_slds_from_d_rhos",
22
+ ]
23
+
24
+
25
def uniform_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    """Draw samples of the given shape uniformly between ``low`` and ``high``.

    When ``low`` is a tensor, its device and dtype override the keyword arguments.
    """
    if isinstance(low, Tensor):
        device = low.device
        dtype = low.dtype
    unit_samples = torch.rand(*shape, device=device, dtype=dtype)
    width = high - low
    return unit_samples * width + low
29
+
30
+
31
def logdist_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    """Draw samples whose base-10 logarithm is uniform between ``log10(low)`` and ``log10(high)``.

    When ``low`` is a tensor, its device and dtype override the keyword
    arguments and both bounds are passed through ``torch.log10``; otherwise
    ``math.log10`` is applied to both.
    """
    if isinstance(low, Tensor):
        device, dtype = low.device, low.dtype
        log_low = torch.log10(low)
        log_high = torch.log10(high)
    else:
        log_low = log10(low)
        log_high = log10(high)
    exponents = torch.rand(*shape, device=device, dtype=dtype) * (log_high - log_low) + log_low
    return 10 ** exponents
38
+
39
+
40
def triangular_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    """Draw samples between ``low`` and ``high`` as ``(high - low) * (1 - sqrt(u)) + low``.

    ``u`` is uniform on [0, 1). When ``low`` is a tensor, its device and dtype
    override the keyword arguments.
    """
    if isinstance(low, Tensor):
        device = low.device
        dtype = low.dtype

    unit_samples = torch.rand(*shape, device=device, dtype=dtype)
    width = high - low

    return width * (1 - torch.sqrt(unit_samples)) + low
47
+
48
+
49
def get_reversed_params(thicknesses: Tensor, roughnesses: Tensor, slds: Tensor):
    """Return the concatenated parameters of the layer stack in reversed order.

    Thicknesses and roughnesses are flipped along the last axis. The SLDs are
    converted to successive differences (with a leading zero column), the
    differences are flipped, and then accumulated again.

    Returns:
        Tensor: ``[reversed thicknesses, reversed roughnesses, reversed slds]``
        concatenated along the last axis.
    """
    zero_column = torch.zeros(slds.shape[0], 1).to(slds)
    sld_steps = torch.diff(torch.cat([zero_column, slds], dim=-1), dim=-1)
    flipped_slds = torch.cumsum(torch.flip(sld_steps, (-1,)), dim=-1)

    flipped_thicknesses = torch.flip(thicknesses, [-1])
    flipped_roughnesses = torch.flip(roughnesses, [-1])

    return torch.cat([flipped_thicknesses, flipped_roughnesses, flipped_slds], -1)
64
+
65
+
66
def get_density_profiles(
        thicknesses: Tensor,
        roughnesses: Tensor,
        slds: Tensor,
        z_axis: Tensor = None,
        num: int = 1000
):
    """Generates SLD profiles (and their derivative) based on batches of thicknesses, roughnesses and layer SLDs.

    The axis has its zero at the top (ambient medium) interface and is positive inside the film.

    Each interface contributes an error-function step of height equal to the SLD change at that
    interface; the derivative is the matching sum of Gaussians.

    Args:
        thicknesses (Tensor): the layer thicknesses (top to bottom)
        roughnesses (Tensor): the interlayer roughnesses (top to bottom)
        slds (Tensor): the layer SLDs (top to bottom)
        z_axis (Tensor, optional): a custom depth (z) axis. Defaults to None, in which case an axis of
            ``num`` points from ``-0.1 * z_max`` to ``1.1 * z_max`` is used, where ``z_max`` is the
            largest total thickness in the batch.
        num (int, optional): number of discretization points for the profile. Defaults to 1000.

    Returns:
        tuple: the z axis, the computed density profile rho(z) and the derivative of the density profile drho/dz(z)

    Raises:
        AssertionError: if any roughness or thickness is negative.
    """
    assert torch.all(roughnesses >= 0), 'Negative roughness happened'
    assert torch.all(thicknesses >= 0), 'Negative thickness happened'

    sample_num = thicknesses.shape[0]

    # SLD change at every interface.
    d_rhos = get_d_rhos(slds)

    # Interface depths: 0 at the top, then the cumulative layer thicknesses.
    zs = torch.cumsum(torch.cat([torch.zeros(sample_num, 1).to(thicknesses), thicknesses], dim=-1), dim=-1)

    if z_axis is None:
        z_axis = torch.linspace(- zs.max() * 0.1, zs.max() * 1.1, num, device=thicknesses.device)[None]
    elif len(z_axis.shape) == 1:
        z_axis = z_axis[None]

    # erf((z - z0) / (roughness * sqrt(2))) is the step of a Gaussian with
    # standard deviation `roughness`.
    erf_widths = roughnesses * sqrt(2)

    profile = get_erf(z_axis[:, None], zs[..., None], erf_widths[..., None], d_rhos[..., None]).sum(1)

    # The derivative of that step is a Gaussian whose standard deviation is the
    # roughness itself, not roughness * sqrt(2); passing the erf width here
    # would make the returned curve sqrt(2) too wide to be d(profile)/dz.
    d_profile = get_gauss(z_axis[:, None], zs[..., None], roughnesses[..., None], d_rhos[..., None]).sum(1)

    # Only the first row of the axis is returned.
    z_axis = z_axis[0]

    return z_axis, profile, d_profile
110
+
111
+
112
def get_d_rhos(slds: Tensor) -> Tensor:
    """Return the first SLD followed by the successive SLD differences along the last axis."""
    first_sld = slds[:, :1]
    sld_steps = torch.diff(slds, dim=-1)
    return torch.cat([first_sld, sld_steps], -1)
115
+
116
+
117
def get_slds_from_d_rhos(d_rhos: Tensor) -> Tensor:
    """Return the cumulative sum of ``d_rhos`` along the last axis."""
    return torch.cumsum(d_rhos, dim=-1)
120
+
121
+
122
def get_erf(z, z0, sigma, amp):
    """Error-function step centred at ``z0``: goes from 0 to ``amp`` with width ``sigma``."""
    scaled_distance = (z - z0) / sigma
    return (torch.erf(scaled_distance) + 1) * amp / 2
124
+
125
+
126
def get_gauss(z, z0, sigma, amp):
    """Gaussian centred at ``z0`` with standard deviation ``sigma`` and area ``amp``."""
    prefactor = amp / (sigma * sqrt(2 * pi))
    exponent = - (z - z0) ** 2 / 2 / sigma ** 2
    return prefactor * torch.exp(exponent)
128
+
129
+
130
def get_param_labels(
        num_layers: int, *,
        thickness_name: str = 'Thickness',
        roughness_name: str = 'Roughness',
        sld_name: str = 'SLD',
        substrate_name: str = 'sub',
) -> List[str]:
    """Build the parameter labels of a model with ``num_layers`` layers.

    Layers are numbered from ``num_layers`` down to 1. Roughnesses and SLDs
    get one additional label for the substrate.

    Returns:
        List[str]: thickness labels, then roughness labels, then SLD labels.
    """
    layer_tags = [f'L{num_layers - i}' for i in range(num_layers)]

    labels = [f'{thickness_name} {tag}' for tag in layer_tags]
    for quantity in (roughness_name, sld_name):
        labels.extend(f'{quantity} {tag}' for tag in layer_tags)
        labels.append(f'{quantity} {substrate_name}')
    return labels
141
+
142
def get_param_labels_absorption_model(
        num_layers: int, *,
        thickness_name: str = 'Thickness',
        roughness_name: str = 'Roughness',
        real_sld_name: str = 'SLD real',
        imag_sld_name: str = 'SLD imag',
        substrate_name: str = 'sub',
) -> List[str]:
    """Build the parameter labels of a model with real and imaginary SLDs.

    Layers are numbered from ``num_layers`` down to 1. Roughnesses, real SLDs
    and imaginary SLDs get one additional label for the substrate.

    Returns:
        List[str]: thickness, roughness, real SLD and imaginary SLD labels, in that order.
    """
    layer_tags = [f'L{num_layers - i}' for i in range(num_layers)]

    labels = [f'{thickness_name} {tag}' for tag in layer_tags]
    for quantity in (roughness_name, real_sld_name, imag_sld_name):
        labels.extend(f'{quantity} {tag}' for tag in layer_tags)
        labels.append(f'{quantity} {substrate_name}')
    return labels
@@ -0,0 +1,6 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .callbacks import JPlotLoss
8
+
9
+
10
+ __all__ = [
11
+ 'JPlotLoss',
12
+ ]
@@ -0,0 +1,40 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from IPython.display import clear_output
8
+
9
+ from ...ml import TrainerCallback, Trainer
10
+
11
+ from ..matplotlib import plot_losses
12
+
13
+
14
class JPlotLoss(TrainerCallback):
    """Callback for plotting the loss in a Jupyter notebook
    """
    def __init__(self, frequency: int, log: bool = True, clear: bool = True, **kwargs):
        """

        Args:
            frequency (int): plotting frequency (the plot is drawn on batches whose number is a multiple of it)
            log (bool, optional): if True, the plot is on a logarithmic scale. Defaults to True.
            clear (bool, optional): if True, the cell output is cleared before each plot. Defaults to True.
            **kwargs: additional keyword arguments forwarded to ``plot_losses``
        """
        self.frequency = frequency
        self.log = log
        self.kwargs = kwargs
        self.clear = clear

    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        """Plot the losses of ``trainer`` when ``batch_num`` is a multiple of the frequency."""
        if batch_num % self.frequency:
            return

        if self.clear:
            clear_output(wait=True)

        losses = trainer.losses
        # The stored 'saved_iteration' (if any) is marked on the plot.
        saved_iteration = trainer.callback_params.get('saved_iteration', None)

        plot_losses(losses, log=self.log, best_epoch=saved_iteration, **self.kwargs)
@@ -0,0 +1,11 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from reflectorch.extensions.matplotlib.losses import plot_losses
8
+
9
+ __all__ = [
10
+ "plot_losses",
11
+ ]
@@ -0,0 +1,38 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import matplotlib.pyplot as plt
8
+
9
+
10
def plot_losses(
        losses: dict,
        log: bool = False,
        show: bool = True,
        title: str = 'Losses',
        x_label: str = 'Iterations',
        best_epoch: float = None,
        **kwargs
):
    """Plot the loss curves stored in ``losses`` on the current matplotlib figure.

    Args:
        losses (dict): loss histories keyed by name. When it has at most two
            entries, only its ``'total_loss'`` entry is drawn, labelled 'loss'.
        log (bool, optional): if True, the y axis is logarithmic. Defaults to False.
        show (bool, optional): if True, ``plt.show()`` is called at the end. Defaults to True.
        title (str, optional): the plot title. Defaults to 'Losses'.
        x_label (str, optional): the x axis label. Defaults to 'Iterations'.
        best_epoch (float, optional): x position of a dashed red vertical marker. Defaults to None.
        **kwargs: additional keyword arguments forwarded to the plotting function
    """
    draw = plt.semilogy if log else plt.plot

    if len(losses) <= 2:
        losses = {'loss': losses['total_loss']}

    for name, history in losses.items():
        draw(history, label=name, **kwargs)

    if best_epoch is not None:
        plt.axvline(best_epoch, ls='--', color='red')

    plt.xlabel(x_label)

    # A legend is only useful when several curves were drawn.
    several_curves = len(losses) > 2
    if several_curves:
        plt.legend()

    plt.title(title)

    if show:
        plt.show()
@@ -0,0 +1,22 @@
1
+ from reflectorch.inference.inference_model import InferenceModel, EasyInferenceModel
2
+ from reflectorch.inference.multilayer_inference_model import MultilayerInferenceModel
3
+ from reflectorch.inference.preprocess_exp import (
4
+ StandardPreprocessing,
5
+ standard_preprocessing,
6
+ interp_reflectivity,
7
+ apply_attenuation_correction,
8
+ apply_footprint_correction,
9
+ )
10
+ from reflectorch.inference.torch_fitter import ReflGradientFit
11
+
12
+ __all__ = [
13
+ "InferenceModel",
14
+ "EasyInferenceModel",
15
+ "MultilayerInferenceModel",
16
+ "StandardPreprocessing",
17
+ "standard_preprocessing",
18
+ "ReflGradientFit",
19
+ "interp_reflectivity",
20
+ "apply_attenuation_correction",
21
+ "apply_footprint_correction",
22
+ ]