reflectorch-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of reflectorch has been flagged as possibly problematic.
- reflectorch/__init__.py +23 -0
- reflectorch/data_generation/__init__.py +130 -0
- reflectorch/data_generation/dataset.py +196 -0
- reflectorch/data_generation/likelihoods.py +86 -0
- reflectorch/data_generation/noise.py +371 -0
- reflectorch/data_generation/priors/__init__.py +66 -0
- reflectorch/data_generation/priors/base.py +61 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
- reflectorch/data_generation/priors/independent_priors.py +201 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +110 -0
- reflectorch/data_generation/priors/no_constraints.py +212 -0
- reflectorch/data_generation/priors/parametric_models.py +767 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
- reflectorch/data_generation/priors/params.py +258 -0
- reflectorch/data_generation/priors/sampler_strategies.py +306 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +377 -0
- reflectorch/data_generation/priors/utils.py +124 -0
- reflectorch/data_generation/process_data.py +47 -0
- reflectorch/data_generation/q_generator.py +232 -0
- reflectorch/data_generation/reflectivity/__init__.py +56 -0
- reflectorch/data_generation/reflectivity/abeles.py +81 -0
- reflectorch/data_generation/reflectivity/kinematical.py +58 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +123 -0
- reflectorch/data_generation/scale_curves.py +118 -0
- reflectorch/data_generation/smearing.py +67 -0
- reflectorch/data_generation/utils.py +154 -0
- reflectorch/extensions/__init__.py +6 -0
- reflectorch/extensions/jupyter/__init__.py +12 -0
- reflectorch/extensions/jupyter/callbacks.py +40 -0
- reflectorch/extensions/matplotlib/__init__.py +11 -0
- reflectorch/extensions/matplotlib/losses.py +38 -0
- reflectorch/inference/__init__.py +22 -0
- reflectorch/inference/inference_model.py +734 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +16 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +171 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +37 -0
- reflectorch/ml/basic_trainer.py +286 -0
- reflectorch/ml/callbacks.py +86 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +38 -0
- reflectorch/ml/schedulers.py +246 -0
- reflectorch/ml/trainers.py +126 -0
- reflectorch/ml/utils.py +9 -0
- reflectorch/models/__init__.py +22 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +27 -0
- reflectorch/models/encoders/conv_encoder.py +211 -0
- reflectorch/models/encoders/conv_res_net.py +119 -0
- reflectorch/models/encoders/fno.py +127 -0
- reflectorch/models/encoders/transformers.py +56 -0
- reflectorch/models/networks/__init__.py +18 -0
- reflectorch/models/networks/mlp_networks.py +256 -0
- reflectorch/models/networks/residual_net.py +131 -0
- reflectorch/paths.py +33 -0
- reflectorch/runs/__init__.py +35 -0
- reflectorch/runs/config.py +31 -0
- reflectorch/runs/slurm_utils.py +99 -0
- reflectorch/runs/train.py +85 -0
- reflectorch/runs/utils.py +300 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +74 -0
- reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
- reflectorch-1.0.0.dist-info/METADATA +115 -0
- reflectorch-1.0.0.dist-info/RECORD +83 -0
- reflectorch-1.0.0.dist-info/WHEEL +5 -0
- reflectorch-1.0.0.dist-info/top_level.txt +1 -0
reflectorch/data_generation/reflectivity/smearing.py
@@ -0,0 +1,123 @@
# -*- coding: utf-8 -*-
from math import pi, sqrt, log

import torch
from torch import Tensor

from reflectorch.data_generation.reflectivity.abeles import abeles
from torch.nn.functional import conv1d, pad


def abeles_constant_smearing(
        q: Tensor,
        thickness: Tensor,
        roughness: Tensor,
        sld: Tensor,
        dq: Tensor = None,
        gauss_num: int = 51,
        constant_dq: bool = True,
        abeles_func=None,
):
    abeles_func = abeles_func or abeles
    q_lin = _get_q_axes(q, dq, gauss_num, constant_dq=constant_dq)
    kernels = _get_t_gauss_kernels(dq, gauss_num)

    curves = abeles_func(q_lin, thickness, roughness, sld)

    padding = (kernels.shape[-1] - 1) // 2
    smeared_curves = conv1d(
        pad(curves[None], (padding, padding), 'reflect'), kernels[:, None], groups=kernels.shape[0],
    )[0]

    if q.shape[0] != smeared_curves.shape[0]:
        q = q.expand(smeared_curves.shape[0], *q.shape[1:])

    smeared_curves = _batch_linear_interp1d(q_lin, smeared_curves, q)

    return smeared_curves


_FWHM = 2 * sqrt(2 * log(2.0))
_2PI_SQRT = 1. / sqrt(2 * pi)


def _batch_linspace(start: Tensor, end: Tensor, num: int):
    return torch.linspace(0, 1, int(num), device=end.device, dtype=end.dtype)[None] * (end - start) + start


def _torch_gauss(x, s):
    return _2PI_SQRT / s * torch.exp(-0.5 * x ** 2 / s / s)


def _get_t_gauss_kernels(resolutions: Tensor, gaussnum: int = 51):
    gauss_x = _batch_linspace(-1.7 * resolutions, 1.7 * resolutions, gaussnum)
    gauss_y = _torch_gauss(gauss_x, resolutions / _FWHM) * (gauss_x[:, 1] - gauss_x[:, 0])[:, None]
    return gauss_y


def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool = True):
    if constant_dq:
        return _get_q_axes_for_constant_dq(q, resolutions, gaussnum)
    else:
        return _get_q_axes_for_linear_dq(q, resolutions, gaussnum)


def _get_q_axes_for_linear_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51):
    gaussgpoint = (gaussnum - 1) / 2

    lowq = torch.clamp_min_(q.min(1).values, 1e-6)
    highq = q.max(1).values

    start = torch.log10(lowq) - 6 * resolutions / _FWHM
    end = torch.log10(highq * (1 + 6 * resolutions / _FWHM))

    interpnums = torch.abs(
        (torch.abs(end - start)) / (1.7 * resolutions / _FWHM / gaussgpoint)
    ).round().to(int)

    q_lin = 10 ** _batch_linspace_with_padding(start, end, interpnums)

    return q_lin


def _get_q_axes_for_constant_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51) -> Tensor:
    gaussgpoint = (gaussnum - 1) / 2

    start = q.min(1).values[:, None] - resolutions * 1.7
    end = q.max(1).values[:, None] + resolutions * 1.7

    interpnums = torch.abs(
        (torch.abs(end - start)) / (1.7 * resolutions / gaussgpoint)
    ).round().to(int)

    q_lin = _batch_linspace_with_padding(start, end, interpnums)
    q_lin = torch.clamp_min_(q_lin, 1e-6)

    return q_lin


def _batch_linspace_with_padding(start: Tensor, end: Tensor, nums: Tensor) -> Tensor:
    max_num = nums.max().int().item()

    deltas = 1 / (nums - 1)

    x = torch.clamp_min_(_batch_linspace(deltas * (nums - max_num), torch.ones_like(deltas), max_num), 0)

    x = x * (end - start) + start

    return x


def _batch_linear_interp1d(x: Tensor, y: Tensor, x_new: Tensor) -> Tensor:
    eps = torch.finfo(y.dtype).eps

    ind = torch.searchsorted(x.contiguous(), x_new.contiguous())

    ind = torch.clamp_(ind - 1, 0, x.shape[-1] - 2)
    slopes = (y[..., 1:] - y[..., :-1]) / (eps + (x[..., 1:] - x[..., :-1]))
    ind_y = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * y.shape[1]
    ind_slopes = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * slopes.shape[1]

    y_new = y.flatten()[ind_y] + slopes.flatten()[ind_slopes] * (x_new - x.flatten()[ind_y])

    return y_new
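For orientation, a minimal usage sketch of abeles_constant_smearing. The batch shapes follow the convention visible in the abeles call above — thickness (B, L), roughness and sld (B, L + 1) — but the concrete values and units are assumptions, not documented API:

import torch
from reflectorch.data_generation.reflectivity.smearing import abeles_constant_smearing

B, L, N = 4, 2, 128                        # batch size, layers, q points (assumed)
q = torch.linspace(0.02, 0.15, N)[None]    # (1, N); expanded over the batch internally
thickness = torch.rand(B, L) * 100 + 20    # assumed units: angstrom
roughness = torch.rand(B, L + 1) * 5
sld = torch.rand(B, L + 1) * 10
dq = torch.full((B, 1), 5e-3)              # one constant-dq resolution per curve

smeared = abeles_constant_smearing(q, thickness, roughness, sld, dq=dq)
print(smeared.shape)                       # torch.Size([4, 128])

Internally, each curve is computed on a padded auxiliary q axis, convolved with a per-curve Gaussian kernel of FWHM dq (sigma = dq / (2 * sqrt(2 * ln 2))), and interpolated back onto the requested q grid.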
reflectorch/data_generation/scale_curves.py
@@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path

import torch
from torch import Tensor

from reflectorch.data_generation.priors import PriorSampler
from reflectorch.paths import SAVED_MODELS_DIR


class CurvesScaler(object):
    """Base class for curve scalers"""
    def scale(self, curves: Tensor):
        raise NotImplementedError

    def restore(self, curves: Tensor):
        raise NotImplementedError


class LogAffineCurvesScaler(CurvesScaler):
    r"""Curve scaler which scales the reflectivity curves according to the logarithmic affine transformation
    :math:`\log_{10}(R + eps) \cdot weight + bias`.

    Args:
        weight (float): multiplication factor in the transformation
        bias (float): addition term in the transformation
        eps (float): sets the minimum intensity value of the reflectivity curves which is considered
    """
    def __init__(self, weight: float = 0.1, bias: float = 0.5, eps: float = 1e-10):
        self.weight = weight
        self.bias = bias
        self.eps = eps

    def scale(self, curves: Tensor):
        """scales the reflectivity curves to a ML-friendly range

        Args:
            curves (Tensor): original reflectivity curves

        Returns:
            Tensor: reflectivity curves scaled to a ML-friendly range
        """
        return torch.log10(curves + self.eps) * self.weight + self.bias

    def restore(self, curves: Tensor):
        """restores the physical reflectivity curves

        Args:
            curves (Tensor): scaled reflectivity curves

        Returns:
            Tensor: reflectivity curves restored to the physical range
        """
        return 10 ** ((curves - self.bias) / self.weight) - self.eps


class MeanNormalizationCurvesScaler(CurvesScaler):
    """Curve scaler which scales the reflectivity curves by the precomputed mean of a batch of curves

    Args:
        path (str, optional): path to the precomputed mean of the curves, only used if ``curves_mean`` is None. Defaults to None.
        curves_mean (Tensor, optional): the precomputed mean of the curves. Defaults to None.
        device (torch.device, optional): the PyTorch device. Defaults to 'cuda'.
    """

    def __init__(self, path: str = None, curves_mean: Tensor = None, device: torch.device = 'cuda'):
        if curves_mean is None:
            curves_mean = torch.load(self.get_path(path))
        self.curves_mean = curves_mean.to(device)

    def scale(self, curves: Tensor):
        """scales the reflectivity curves to a ML-friendly range

        Args:
            curves (Tensor): original reflectivity curves

        Returns:
            Tensor: reflectivity curves scaled to a ML-friendly range
        """
        self.curves_mean = self.curves_mean.to(curves)
        return curves / self.curves_mean - 1

    def restore(self, curves: Tensor):
        """restores the physical reflectivity curves

        Args:
            curves (Tensor): scaled reflectivity curves

        Returns:
            Tensor: reflectivity curves restored to the physical range
        """
        self.curves_mean = self.curves_mean.to(curves)
        return (curves + 1) * self.curves_mean

    @staticmethod
    def save(prior_sampler: PriorSampler, q: Tensor, path: str, num: int = 16384):
        """computes the mean of a batch of reflectivity curves and saves it

        Args:
            prior_sampler (PriorSampler): the prior sampler
            q (Tensor): the q values
            path (str): the path for saving the mean of the curves
            num (int, optional): the number of curves used to compute the mean. Defaults to 16384.
        """
        params = prior_sampler.sample(num)
        curves_mean = params.reflectivity(q, log=False).mean(0).cpu()
        torch.save(curves_mean, MeanNormalizationCurvesScaler.get_path(path))

    @staticmethod
    def get_path(path: str) -> Path:
        if not path.endswith('.pt'):
            path = path + '.pt'
        return SAVED_MODELS_DIR / path
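A quick round-trip sketch for LogAffineCurvesScaler; the input values are arbitrary:

import torch
from reflectorch.data_generation.scale_curves import LogAffineCurvesScaler

scaler = LogAffineCurvesScaler(weight=0.1, bias=0.5, eps=1e-10)
curves = torch.tensor([[1.0, 1e-2, 1e-6]])

scaled = scaler.scale(curves)       # log10(R + eps) * 0.1 + 0.5
restored = scaler.restore(scaled)   # inverse transform

print(scaled)                            # approximately [[0.5, 0.3, -0.1]]
print(torch.allclose(curves, restored))  # True (up to floating point)

With the default weight and bias, reflectivities between 1 and 1e-10 map into roughly [-0.5, 0.5], which is the "ML-friendly range" mentioned in the docstrings.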
reflectorch/data_generation/smearing.py
@@ -0,0 +1,67 @@
import torch
from torch import Tensor

from reflectorch.data_generation.priors.parametric_subpriors import BasicParams


class Smearing(object):
    """Class which applies resolution smearing to the reflectivity curves.
    The intensity at a q point will be the average of the intensities of neighbouring q points, weighted by a gaussian profile.

    Args:
        sigma_range (tuple, optional): the range for sampling the standard deviation of the gaussians. Defaults to (1e-4, 5e-3).
        constant_dq (bool, optional): whether the smearing is constant for each q point. Defaults to True.
        gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
        share_smeared (float, optional): the share of curves in the batch for which the resolution smearing is applied. Defaults to 0.2.
    """
    def __init__(self,
                 sigma_range: tuple = (1e-4, 5e-3),
                 constant_dq: bool = True,
                 gauss_num: int = 31,
                 share_smeared: float = 0.2,
                 ):
        self.sigma_min, self.sigma_max = sigma_range
        self.sigma_delta = self.sigma_max - self.sigma_min
        self.constant_dq = constant_dq
        self.gauss_num = gauss_num
        self.share_smeared = share_smeared

    def __repr__(self):
        return f'Smearing(({self.sigma_min}, {self.sigma_max}))'

    def generate_resolutions(self, batch_size: int, device=None, dtype=None):
        num_smeared = int(batch_size * self.share_smeared)
        if not num_smeared:
            return None, None
        dq = torch.rand(num_smeared, 1, device=device, dtype=dtype) * self.sigma_delta + self.sigma_min
        indices = torch.zeros(batch_size, device=device, dtype=torch.bool)
        indices[torch.randperm(batch_size, device=device)[:num_smeared]] = True
        return dq, indices

    def get_curves(self, q_values: Tensor, params: BasicParams):
        dq, indices = self.generate_resolutions(params.batch_size, device=params.device, dtype=params.dtype)

        if dq is None:
            return params.reflectivity(q_values, log=False)

        curves = torch.empty(params.batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)

        if (~indices).sum().item():
            if q_values.dim() == 2 and q_values.shape[0] > 1:
                q = q_values[~indices]
            else:
                q = q_values

            curves[~indices] = params[~indices].reflectivity(q, log=False)

        if indices.sum().item():
            if q_values.dim() == 2 and q_values.shape[0] > 1:
                q = q_values[indices]
            else:
                q = q_values

            curves[indices] = params[indices].reflectivity(
                q, dq=dq, constant_dq=self.constant_dq, log=False, gauss_num=self.gauss_num
            )

        return curves
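A short sketch of how Smearing partitions a batch via generate_resolutions (the BasicParams-dependent get_curves path is omitted here, since BasicParams is defined elsewhere in the package):

import torch
from reflectorch.data_generation.smearing import Smearing

smearing = Smearing(sigma_range=(1e-4, 5e-3), share_smeared=0.25)

dq, indices = smearing.generate_resolutions(batch_size=8)
print(dq.shape)       # torch.Size([2, 1]) -- one sampled sigma per smeared curve
print(indices.sum())  # tensor(2) -- 25% of the batch is marked for smearing

get_curves then computes the unsmeared curves for ~indices and the smeared ones for indices, so only the configured share of each training batch pays the cost of the convolution.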
reflectorch/data_generation/utils.py
@@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Union
from math import sqrt, pi, log10

import torch
from torch import Tensor

__all__ = [
    "get_reversed_params",
    "get_density_profiles",
    "uniform_sampler",
    "logdist_sampler",
    "triangular_sampler",
    "get_param_labels",
    "get_d_rhos",
    "get_slds_from_d_rhos",
]


def uniform_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    if isinstance(low, Tensor):
        device, dtype = low.device, low.dtype
    return torch.rand(*shape, device=device, dtype=dtype) * (high - low) + low


def logdist_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    if isinstance(low, Tensor):
        device, dtype = low.device, low.dtype
        low, high = map(torch.log10, (low, high))
    else:
        low, high = map(log10, (low, high))
    return 10 ** (torch.rand(*shape, device=device, dtype=dtype) * (high - low) + low)


def triangular_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    if isinstance(low, Tensor):
        device, dtype = low.device, low.dtype

    x = torch.rand(*shape, device=device, dtype=dtype)

    return (high - low) * (1 - torch.sqrt(x)) + low


def get_reversed_params(thicknesses: Tensor, roughnesses: Tensor, slds: Tensor):
    reversed_slds = torch.cumsum(
        torch.flip(
            torch.diff(
                torch.cat([torch.zeros(slds.shape[0], 1).to(slds), slds], dim=-1),
                dim=-1
            ), (-1,)
        ),
        dim=-1
    )
    reversed_thicknesses = torch.flip(thicknesses, [-1])
    reversed_roughnesses = torch.flip(roughnesses, [-1])
    reversed_params = torch.cat([reversed_thicknesses, reversed_roughnesses, reversed_slds], -1)

    return reversed_params


def get_density_profiles(
        thicknesses: Tensor,
        roughnesses: Tensor,
        slds: Tensor,
        z_axis: Tensor = None,
        num: int = 1000
):
    """Generates SLD profiles (and their derivatives) based on batches of thicknesses, roughnesses and layer SLDs.

    The z axis has its zero at the top (ambient medium) interface and is positive inside the film.

    Args:
        thicknesses (Tensor): the layer thicknesses (top to bottom)
        roughnesses (Tensor): the interlayer roughnesses (top to bottom)
        slds (Tensor): the layer SLDs (top to bottom)
        z_axis (Tensor, optional): a custom depth (z) axis. Defaults to None.
        num (int, optional): number of discretization points for the profile. Defaults to 1000.

    Returns:
        tuple: the z axis, the computed density profile rho(z) and the derivative of the density profile drho/dz(z)
    """
    assert torch.all(roughnesses >= 0), 'Negative roughness happened'
    assert torch.all(thicknesses >= 0), 'Negative thickness happened'

    sample_num = thicknesses.shape[0]

    d_rhos = get_d_rhos(slds)

    zs = torch.cumsum(torch.cat([torch.zeros(sample_num, 1).to(thicknesses), thicknesses], dim=-1), dim=-1)

    if z_axis is None:
        z_axis = torch.linspace(- zs.max() * 0.1, zs.max() * 1.1, num, device=thicknesses.device)[None]
    elif len(z_axis.shape) == 1:
        z_axis = z_axis[None]

    sigmas = roughnesses * sqrt(2)

    profile = get_erf(z_axis[:, None], zs[..., None], sigmas[..., None], d_rhos[..., None]).sum(1)

    d_profile = get_gauss(z_axis[:, None], zs[..., None], sigmas[..., None], d_rhos[..., None]).sum(1)

    z_axis = z_axis[0]

    return z_axis, profile, d_profile


def get_d_rhos(slds: Tensor) -> Tensor:
    d_rhos = torch.cat([slds[:, 0][:, None], torch.diff(slds, dim=-1)], -1)
    return d_rhos


def get_slds_from_d_rhos(d_rhos: Tensor) -> Tensor:
    slds = torch.cumsum(d_rhos, dim=-1)
    return slds


def get_erf(z, z0, sigma, amp):
    return (torch.erf((z - z0) / sigma) + 1) * amp / 2


def get_gauss(z, z0, sigma, amp):
    return amp / (sigma * sqrt(2 * pi)) * torch.exp(- (z - z0) ** 2 / 2 / sigma ** 2)


def get_param_labels(
        num_layers: int, *,
        thickness_name: str = 'Thickness',
        roughness_name: str = 'Roughness',
        sld_name: str = 'SLD',
        substrate_name: str = 'sub',
) -> List[str]:
    thickness_labels = [f'{thickness_name} L{num_layers - i}' for i in range(num_layers)]
    roughness_labels = [f'{roughness_name} L{num_layers - i}' for i in range(num_layers)] + [f'{roughness_name} {substrate_name}']
    sld_labels = [f'{sld_name} L{num_layers - i}' for i in range(num_layers)] + [f'{sld_name} {substrate_name}']
    return thickness_labels + roughness_labels + sld_labels


def get_param_labels_absorption_model(
        num_layers: int, *,
        thickness_name: str = 'Thickness',
        roughness_name: str = 'Roughness',
        real_sld_name: str = 'SLD real',
        imag_sld_name: str = 'SLD imag',
        substrate_name: str = 'sub',
) -> List[str]:
    thickness_labels = [f'{thickness_name} L{num_layers - i}' for i in range(num_layers)]
    roughness_labels = [f'{roughness_name} L{num_layers - i}' for i in range(num_layers)] + [f'{roughness_name} {substrate_name}']
    real_sld_labels = [f'{real_sld_name} L{num_layers - i}' for i in range(num_layers)] + [f'{real_sld_name} {substrate_name}']
    imag_sld_labels = [f'{imag_sld_name} L{num_layers - i}' for i in range(num_layers)] + [f'{imag_sld_name} {substrate_name}']
    return thickness_labels + roughness_labels + real_sld_labels + imag_sld_labels
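For concreteness, a minimal sketch exercising two of the helpers above; the layer values are arbitrary:

import torch
from reflectorch.data_generation.utils import get_density_profiles, get_param_labels

print(get_param_labels(2))
# ['Thickness L2', 'Thickness L1',
#  'Roughness L2', 'Roughness L1', 'Roughness sub',
#  'SLD L2', 'SLD L1', 'SLD sub']

# one two-layer structure: num_layers thicknesses, num_layers + 1 roughnesses and SLDs
thicknesses = torch.tensor([[80.0, 40.0]])
roughnesses = torch.tensor([[3.0, 5.0, 2.0]])
slds = torch.tensor([[4.0, 8.0, 2.0]])

z, rho, drho = get_density_profiles(thicknesses, roughnesses, slds, num=500)
print(z.shape, rho.shape, drho.shape)
# torch.Size([500]) torch.Size([1, 500]) torch.Size([1, 500])

Each interface contributes an error-function step of amplitude d_rho (the SLD contrast) and width roughness * sqrt(2), so the profile derivative is the corresponding sum of Gaussians.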
reflectorch/extensions/jupyter/callbacks.py
@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.

from IPython.display import clear_output

from ...ml import TrainerCallback, Trainer

from ..matplotlib import plot_losses


class JPlotLoss(TrainerCallback):
    """Callback for plotting the loss in a Jupyter notebook
    """
    def __init__(self, frequency: int, log: bool = True, clear: bool = True, **kwargs):
        """

        Args:
            frequency (int): plotting frequency (in batches)
            log (bool, optional): if True, the plot is on a logarithmic scale. Defaults to True.
            clear (bool, optional): if True, the previous cell output is cleared before re-plotting. Defaults to True.
        """
        self.frequency = frequency
        self.log = log
        self.kwargs = kwargs
        self.clear = clear

    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        if not batch_num % self.frequency:
            if self.clear:
                clear_output(wait=True)

            plot_losses(
                trainer.losses,
                log=self.log,
                best_epoch=trainer.callback_params.get('saved_iteration', None),
                **self.kwargs
            )
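A minimal wiring sketch for the callback. How callbacks are registered on a Trainer is defined in reflectorch.ml and is not shown in this diff, so the hook invocation is only indicated in a comment:

from reflectorch.extensions.jupyter.callbacks import JPlotLoss

# re-plot the losses every 100 batches on a log scale,
# clearing the previous notebook output each time
callback = JPlotLoss(frequency=100, log=True, clear=True)

# schematic: the trainer is expected to invoke the hook after each batch, i.e.
#   callback.end_batch(trainer, batch_num)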
reflectorch/extensions/matplotlib/__init__.py
@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.

from reflectorch.extensions.matplotlib.losses import plot_losses

__all__ = [
    "plot_losses",
]
reflectorch/extensions/matplotlib/losses.py
@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.

import matplotlib.pyplot as plt


def plot_losses(
        losses: dict,
        log: bool = False,
        show: bool = True,
        title: str = 'Losses',
        x_label: str = 'Iterations',
        best_epoch: float = None,
        **kwargs
):
    func = plt.semilogy if log else plt.plot

    if len(losses) <= 2:
        losses = {'loss': losses['total_loss']}

    for k, data in losses.items():
        func(data, label=k, **kwargs)

    if best_epoch is not None:
        plt.axvline(best_epoch, ls='--', color='red')

    plt.xlabel(x_label)

    if len(losses) > 2:
        plt.legend()

    plt.title(title)

    if show:
        plt.show()
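And a minimal call of plot_losses with a synthetic loss history; note that with at most two keys in the dict, only 'total_loss' is drawn and the legend is skipped:

from reflectorch.extensions.matplotlib.losses import plot_losses

losses = {'total_loss': [1.0 / (i + 1) for i in range(200)]}

plot_losses(losses, log=True, best_epoch=150)  # dashed red line marks the saved iteration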
reflectorch/inference/__init__.py
@@ -0,0 +1,22 @@
from reflectorch.inference.inference_model import InferenceModel, EasyInferenceModel
from reflectorch.inference.multilayer_inference_model import MultilayerInferenceModel
from reflectorch.inference.preprocess_exp import (
    StandardPreprocessing,
    standard_preprocessing,
    interp_reflectivity,
    apply_attenuation_correction,
    apply_footprint_correction,
)
from reflectorch.inference.torch_fitter import ReflGradientFit

__all__ = [
    "InferenceModel",
    "EasyInferenceModel",
    "MultilayerInferenceModel",
    "StandardPreprocessing",
    "standard_preprocessing",
    "ReflGradientFit",
    "interp_reflectivity",
    "apply_attenuation_correction",
    "apply_footprint_correction",
]