reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/data_generation/dataset.py

@@ -1,210 +1,210 @@

All 210 lines of this file are removed and re-added with identical content, so the file is shown only once below.

from typing import Dict, Union
import warnings

from torch import Tensor
import torch

from reflectorch.data_generation.priors import PriorSampler, BasicParams
from reflectorch.data_generation.noise import QNoiseGenerator, IntensityNoiseGenerator
from reflectorch.data_generation.q_generator import QGenerator
from reflectorch.data_generation.scale_curves import CurvesScaler, LogAffineCurvesScaler
from reflectorch.data_generation.smearing import Smearing

BATCH_DATA_TYPE = Dict[str, Union[Tensor, BasicParams]]


class BasicDataset(object):
    """Reflectometry dataset. It generates the q positions, samples the thin film parameters from the prior,
    simulates the reflectivity curves and applies noise to the curves.

    Args:
        q_generator (QGenerator): the momentum transfer (q) generator
        prior_sampler (PriorSampler): the prior sampler
        intensity_noise (IntensityNoiseGenerator, optional): the intensity noise generator. Defaults to None.
        q_noise (QNoiseGenerator, optional): the q noise generator. Defaults to None.
        curves_scaler (CurvesScaler, optional): the reflectivity curve scaler. Defaults to an instance of LogAffineCurvesScaler,
            which scales the curves to the range [-1, 1], the minimum considered intensity being 1e-10.
        calc_denoised_curves (bool, optional): whether to add the curves without noise to the dictionary. Defaults to False.
        calc_nonsmeared_curves (bool, optional): whether to add the curves without smearing to the dictionary (only relevant when smearing is applied). Defaults to False.
        smearing (Smearing, optional): curve smearing generator. Defaults to None.
    """
    def __init__(self,
                 q_generator: QGenerator,
                 prior_sampler: PriorSampler,
                 intensity_noise: IntensityNoiseGenerator = None,
                 q_noise: QNoiseGenerator = None,
                 curves_scaler: CurvesScaler = None,
                 calc_denoised_curves: bool = False,
                 calc_nonsmeared_curves: bool = False,
                 smearing: Smearing = None,
                 ):
        self.q_generator = q_generator
        self.intensity_noise = intensity_noise
        self.q_noise = q_noise
        self.curves_scaler = curves_scaler or LogAffineCurvesScaler()
        self.prior_sampler = prior_sampler
        self.smearing = smearing
        self.calc_denoised_curves = calc_denoised_curves
        self.calc_nonsmeared_curves = calc_nonsmeared_curves

    def update_batch_data(self, batch_data: BATCH_DATA_TYPE) -> None:
        """implement in a subclass to edit batch_data dict inplace"""
        pass

    def _sample_from_prior(self, batch_size: int):
        params: BasicParams = self.prior_sampler.sample(batch_size)
        scaled_params: Tensor = self.prior_sampler.scale_params(params)
        return params, scaled_params

    def get_batch(self, batch_size: int) -> BATCH_DATA_TYPE:
        """get a batch of data as a dictionary with keys ``params``, ``scaled_params``, ``q_values``, ``curves``, ``scaled_noisy_curves``

        Args:
            batch_size (int): the batch size
        """
        batch_data = {}

        params, scaled_params = self._sample_from_prior(batch_size)

        batch_data['params'] = params
        batch_data['scaled_params'] = scaled_params

        q_values: Tensor = self.q_generator.get_batch(batch_size, batch_data)

        if self.q_noise:
            batch_data['original_q_values'] = q_values
            q_values = self.q_noise.apply(q_values, batch_data)

        batch_data['q_values'] = q_values

        refl_kwargs = {}

        curves, q_resolutions, nonsmeared_curves = self._calc_curves(q_values, params, refl_kwargs)

        if torch.is_tensor(q_resolutions):
            batch_data['q_resolutions'] = q_resolutions

        if torch.is_tensor(nonsmeared_curves):
            batch_data['nonsmeared_curves'] = nonsmeared_curves

        if self.calc_denoised_curves:
            batch_data['curves'] = curves

        noisy_curves = curves

        if self.intensity_noise:
            noisy_curves = self.intensity_noise(noisy_curves, batch_data)

        scaled_noisy_curves = self.curves_scaler.scale(noisy_curves)
        batch_data['scaled_noisy_curves'] = scaled_noisy_curves

        is_finite = torch.all(torch.isfinite(scaled_noisy_curves), -1)

        if not torch.all(is_finite).item():
            infinite_indices = ~is_finite
            to_recalculate = infinite_indices.sum().item()
            warnings.warn(f'Infinite number appeared in the curve simulation! Recalculate {to_recalculate} curves.')
            recalculated_batch_data = self.get_batch(to_recalculate)
            _insert_batch_data(batch_data, recalculated_batch_data, infinite_indices)

            is_finite = torch.all(torch.isfinite(batch_data['scaled_noisy_curves']), -1)
            assert torch.all(is_finite).item()

        self.update_batch_data(batch_data)

        return batch_data

    def _calc_curves(self, q_values: Tensor, params: BasicParams, refl_kwargs):
        nonsmeared_curves = None

        if self.smearing:
            if self.calc_nonsmeared_curves:
                nonsmeared_curves = params.reflectivity(q_values, **refl_kwargs)
            curves, q_resolutions = self.smearing.get_curves(q_values, params, refl_kwargs)
        else:
            curves = params.reflectivity(q_values, **refl_kwargs)
            q_resolutions = None

        curves = curves.to(q_values)
        return curves, q_resolutions, nonsmeared_curves


def _insert_batch_data(tgt_batch_data, add_batch_data, indices):
    for key in tuple(tgt_batch_data.keys()):
        value = tgt_batch_data[key]
        if isinstance(value, BasicParams) or len(value.shape) == 2:
            value[indices] = add_batch_data[key]
        else:
            warnings.warn(f'Ignore {key} while merging batch_data.')


if __name__ == '__main__':
    from reflectorch.data_generation.q_generator import ConstantQ
    from reflectorch.data_generation.priors import BasicPriorSampler, UniformSubPriorSampler
    from reflectorch.data_generation.noise import BasicExpIntensityNoise
    from reflectorch.data_generation.noise import BasicQNoiseGenerator
    from reflectorch.utils import to_np
    from time import perf_counter

    q_generator = ConstantQ((0, 0.2, 65), device='cpu')
    noise_gen = BasicExpIntensityNoise(
        relative_errors=(0.05, 0.2),
        # scale_range=(-1e-2, 1e-2),
        logdist=True,
        apply_shift=True,
    )
    q_noise_gen = BasicQNoiseGenerator(
        shift_std=5e-4,
        noise_std=(0, 1e-3),
    )
    prior_sampler = UniformSubPriorSampler(
        thickness_range=(0, 250),
        roughness_range=(0, 40),
        sld_range=(0, 60),
        num_layers=2,
        device=torch.device('cpu'),
        dtype=torch.float64,
        smaller_roughnesses=True,
        logdist=True,
        relative_min_bound_width=5e-4,
    )
    smearing = Smearing(
        sigma_range=(0.8e-3, 5e-3),
        gauss_num=31,
        share_smeared=0.5,
    )

    dataset = BasicDataset(
        q_generator,
        prior_sampler,
        noise_gen,
        q_noise=q_noise_gen,
        smearing=smearing
    )
    start = perf_counter()
    batch_data = dataset.get_batch(32)
    print(f'Total time = {(perf_counter() - start):.3f} sec ')
    print(batch_data['params'].roughnesses[:10])
    print(batch_data['scaled_noisy_curves'].min().item())

    scaled_noisy_curves = batch_data['scaled_noisy_curves']
    scaled_curves = dataset.curves_scaler.scale(
        batch_data['params'].reflectivity(q_generator.q)
    )

    try:
        import matplotlib.pyplot as plt

        for i in range(16):
            plt.plot(
                to_np(q_generator.q.squeeze().cpu().numpy()),
                to_np(scaled_curves[i])
            )
            plt.plot(
                to_np(q_generator.q.squeeze().cpu().numpy()),
                to_np(scaled_noisy_curves[i])
            )

        plt.show()
    except ImportError:
        pass
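The `update_batch_data` hook above is the documented extension point: `get_batch` calls it after all standard keys have been filled, so a subclass can add derived quantities to the batch in place. A minimal sketch (the subclass name and the added `q_max` key are illustrative, not part of the package; it assumes `q_values` has shape `(batch_size, num_points)`):

class DatasetWithQMax(BasicDataset):
    def update_batch_data(self, batch_data: BATCH_DATA_TYPE) -> None:
        # store the largest q value of each curve alongside the keys set by get_batch
        batch_data['q_max'] = batch_data['q_values'].max(dim=-1).values

Because update_batch_data is the last step before get_batch returns, the derived key is present in every batch produced by such a subclass.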
reflectorch/data_generation/likelihoods.py

@@ -1,80 +1,80 @@

All 80 lines of this file are removed and re-added with identical content, so the file is shown only once below.

from typing import Union, Tuple

import torch
from torch import Tensor

from reflectorch.data_generation import (
    PriorSampler,
    Params,
)


class LogLikelihood(object):
    """Computes the gaussian log likelihood of the thin film parameters

    Args:
        q (Tensor): the q values
        exp_curve (Tensor): the experimental reflectivity curve
        priors (PriorSampler): the prior sampler
        sigmas (Union[float, Tensor]): the sigmas (i.e. intensity error bars)
    """
    def __init__(self, q: Tensor, exp_curve: Tensor, priors: PriorSampler, sigmas: Union[float, Tensor]):
        self.exp_curve = torch.atleast_2d(exp_curve)
        self.priors: PriorSampler = priors
        self.q = q
        self.sigmas = sigmas
        self.sigmas2 = self.sigmas ** 2

    def calc_log_likelihood(self, curves: Tensor):
        "computes the gaussian log likelihood"
        log_probs = - (self.exp_curve - curves) ** 2 / self.sigmas2 / 2
        return log_probs.sum(-1)

    def __call__(self, params: Union[Params, Tensor], curves: Tensor = None):
        if not isinstance(params, Params):
            params: Params = self.priors.PARAM_CLS.from_tensor(params)
        log_priors: Tensor = self.priors.log_prob(params)
        indices: Tensor = torch.isfinite(log_priors)

        if not indices.sum().item():
            return log_priors

        finite_params: Params = params[indices]

        if curves is None:
            curves: Tensor = finite_params.reflectivity(self.q)
        else:
            curves = curves[indices]

        log_priors[indices] += self.calc_log_likelihood(curves)

        return log_priors

    calc_log_posterior = __call__

    def get_importance_sampling_weights(
            self, sampled_params: Params, nf_log_probs: Tensor, curves: Tensor = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        log_probs = self.calc_log_posterior(sampled_params, curves=curves)
        log_weights = log_probs - nf_log_probs
        log_weights = log_weights - log_weights.max()

        weights = torch.exp(log_weights.to(torch.float64)).to(log_weights)
        weights = weights / weights.sum()

        return weights, log_weights, log_probs


class PoissonLogLikelihood(LogLikelihood):
    """Computes the Poisson log likelihood of the thin film parameters

    Args:
        q (Tensor): the q values
        exp_curve (Tensor): the experimental reflectivity curve
        priors (PriorSampler): the prior sampler
        sigmas (Union[float, Tensor]): the sigmas (i.e. intensity error bars)
    """
    def calc_log_likelihood(self, curves: Tensor):
        """computes the Poisson log likelihood"""
        log_probs = self.exp_curve / self.sigmas2 * (self.exp_curve * torch.log(curves) - curves)
        return log_probs.sum(-1)
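For reference, the two `calc_log_likelihood` implementations above correspond to the following expressions, with the sums running over the points of the curve, $R^{\mathrm{exp}}_i$ the experimental intensities, $R_i$ the simulated ones, and $\sigma_i$ the error bars:

$$\log \mathcal{L}_{\mathrm{Gauss}} = -\sum_i \frac{\left(R^{\mathrm{exp}}_i - R_i\right)^2}{2\sigma_i^2}, \qquad \log \mathcal{L}_{\mathrm{Poisson}} = \sum_i \frac{R^{\mathrm{exp}}_i}{\sigma_i^2}\left(R^{\mathrm{exp}}_i \log R_i - R_i\right)$$

`get_importance_sampling_weights` then normalizes the exponentiated difference between these log posteriors and the log probabilities of the proposal (presumably a normalizing flow, given the `nf_log_probs` argument), after subtracting the maximum for numerical stability:

$$w_j \propto \exp\!\left(\log p\left(\theta_j \mid R^{\mathrm{exp}}\right) - \log q_{\mathrm{NF}}\left(\theta_j\right)\right), \qquad \sum_j w_j = 1$$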