reflectorch 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (83):
  1. reflectorch/__init__.py +23 -0
  2. reflectorch/data_generation/__init__.py +130 -0
  3. reflectorch/data_generation/dataset.py +196 -0
  4. reflectorch/data_generation/likelihoods.py +86 -0
  5. reflectorch/data_generation/noise.py +371 -0
  6. reflectorch/data_generation/priors/__init__.py +66 -0
  7. reflectorch/data_generation/priors/base.py +61 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
  9. reflectorch/data_generation/priors/independent_priors.py +201 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +110 -0
  12. reflectorch/data_generation/priors/no_constraints.py +212 -0
  13. reflectorch/data_generation/priors/parametric_models.py +767 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
  15. reflectorch/data_generation/priors/params.py +258 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +306 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +377 -0
  19. reflectorch/data_generation/priors/utils.py +124 -0
  20. reflectorch/data_generation/process_data.py +47 -0
  21. reflectorch/data_generation/q_generator.py +232 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +56 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +81 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +58 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +123 -0
  28. reflectorch/data_generation/scale_curves.py +118 -0
  29. reflectorch/data_generation/smearing.py +67 -0
  30. reflectorch/data_generation/utils.py +154 -0
  31. reflectorch/extensions/__init__.py +6 -0
  32. reflectorch/extensions/jupyter/__init__.py +12 -0
  33. reflectorch/extensions/jupyter/callbacks.py +40 -0
  34. reflectorch/extensions/matplotlib/__init__.py +11 -0
  35. reflectorch/extensions/matplotlib/losses.py +38 -0
  36. reflectorch/inference/__init__.py +22 -0
  37. reflectorch/inference/inference_model.py +734 -0
  38. reflectorch/inference/multilayer_fitter.py +171 -0
  39. reflectorch/inference/multilayer_inference_model.py +193 -0
  40. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  41. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  42. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  43. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  44. reflectorch/inference/preprocess_exp/interpolation.py +16 -0
  45. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  46. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  47. reflectorch/inference/record_time.py +43 -0
  48. reflectorch/inference/sampler_solution.py +56 -0
  49. reflectorch/inference/scipy_fitter.py +171 -0
  50. reflectorch/inference/torch_fitter.py +87 -0
  51. reflectorch/ml/__init__.py +37 -0
  52. reflectorch/ml/basic_trainer.py +286 -0
  53. reflectorch/ml/callbacks.py +86 -0
  54. reflectorch/ml/dataloaders.py +27 -0
  55. reflectorch/ml/loggers.py +38 -0
  56. reflectorch/ml/schedulers.py +246 -0
  57. reflectorch/ml/trainers.py +126 -0
  58. reflectorch/ml/utils.py +9 -0
  59. reflectorch/models/__init__.py +22 -0
  60. reflectorch/models/activations.py +50 -0
  61. reflectorch/models/encoders/__init__.py +27 -0
  62. reflectorch/models/encoders/conv_encoder.py +211 -0
  63. reflectorch/models/encoders/conv_res_net.py +119 -0
  64. reflectorch/models/encoders/fno.py +127 -0
  65. reflectorch/models/encoders/transformers.py +56 -0
  66. reflectorch/models/networks/__init__.py +18 -0
  67. reflectorch/models/networks/mlp_networks.py +256 -0
  68. reflectorch/models/networks/residual_net.py +131 -0
  69. reflectorch/paths.py +33 -0
  70. reflectorch/runs/__init__.py +35 -0
  71. reflectorch/runs/config.py +31 -0
  72. reflectorch/runs/slurm_utils.py +99 -0
  73. reflectorch/runs/train.py +85 -0
  74. reflectorch/runs/utils.py +300 -0
  75. reflectorch/test_config.py +4 -0
  76. reflectorch/train.py +4 -0
  77. reflectorch/train_on_cluster.py +4 -0
  78. reflectorch/utils.py +74 -0
  79. reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
  80. reflectorch-1.0.0.dist-info/METADATA +115 -0
  81. reflectorch-1.0.0.dist-info/RECORD +83 -0
  82. reflectorch-1.0.0.dist-info/WHEEL +5 -0
  83. reflectorch-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,23 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from reflectorch.data_generation import *
8
+ from reflectorch.ml import *
9
+ from reflectorch.models import *
10
+ from reflectorch.utils import *
11
+ from reflectorch.paths import *
12
+ from reflectorch.runs import *
13
+ from reflectorch.inference import *
14
+
15
+ from reflectorch.data_generation import __all__ as all_data_generation
16
+ from reflectorch.ml import __all__ as all_ml
17
+ from reflectorch.models import __all__ as all_models
18
+ from reflectorch.utils import __all__ as all_utils
19
+ from reflectorch.paths import __all__ as all_paths
20
+ from reflectorch.runs import __all__ as all_runs
21
+ from reflectorch.inference import __all__ as all_inference
22
+
23
# Public API of the top-level package: the concatenation (in this order) of the ``__all__``
# sequences of every subpackage star-imported above. Concatenation does not remove duplicates.
__all__ = all_data_generation + all_ml + all_models + all_utils + all_paths + all_runs + all_inference
@@ -0,0 +1,130 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from reflectorch.data_generation.dataset import BasicDataset, BATCH_DATA_TYPE
8
+ from reflectorch.data_generation.priors import (
9
+ Params,
10
+ PriorSampler,
11
+ BasicPriorSampler,
12
+ SingleParamPrior,
13
+ SimplePriorSampler,
14
+ UniformParamPrior,
15
+ GaussianParamPrior,
16
+ TruncatedGaussianParamPrior,
17
+ UniformSubPriorParams,
18
+ UniformSubPriorSampler,
19
+ NarrowSldUniformSubPriorSampler,
20
+ ExpUniformSubPriorSampler,
21
+ SimpleMultilayerSampler,
22
+ SubpriorParametricSampler,
23
+ BasicParams,
24
+ ParametricModel,
25
+ MULTILAYER_MODELS,
26
+ SamplerStrategy,
27
+ BasicSamplerStrategy,
28
+ ConstrainedRoughnessSamplerStrategy,
29
+ ConstrainedRoughnessAndImgSldSamplerStrategy,
30
+ )
31
+ from reflectorch.data_generation.process_data import ProcessData, ProcessPipeline
32
+ from reflectorch.data_generation.q_generator import (
33
+ QGenerator,
34
+ ConstantAngle,
35
+ ConstantQ,
36
+ VariableQ,
37
+ EquidistantQ,
38
+ )
39
+ from reflectorch.data_generation.noise import (
40
+ QNoiseGenerator,
41
+ IntensityNoiseGenerator,
42
+ QNormalNoiseGenerator,
43
+ QSystematicShiftGenerator,
44
+ PoissonNoiseGenerator,
45
+ MultiplicativeLogNormalNoiseGenerator,
46
+ ShiftNoise,
47
+ ScalingNoise,
48
+ BackgroundNoise,
49
+ BasicExpIntensityNoise,
50
+ BasicQNoiseGenerator,
51
+ )
52
+ from reflectorch.data_generation.scale_curves import (
53
+ CurvesScaler,
54
+ LogAffineCurvesScaler,
55
+ MeanNormalizationCurvesScaler,
56
+ )
57
+ from reflectorch.data_generation.utils import (
58
+ get_reversed_params,
59
+ get_density_profiles,
60
+ uniform_sampler,
61
+ logdist_sampler,
62
+ triangular_sampler,
63
+ get_param_labels,
64
+ )
65
+
66
+ from reflectorch.data_generation.smearing import Smearing
67
+
68
+ from reflectorch.data_generation.reflectivity import reflectivity
69
+
70
+ from reflectorch.data_generation.likelihoods import (
71
+ LogLikelihood,
72
+ PoissonLogLikelihood,
73
+ )
74
+
75
# Names re-exported by ``from reflectorch.data_generation import *``. Every entry is imported at
# the top of this module. The list is also concatenated into the top-level ``reflectorch.__all__``.
__all__ = [
    "Params",
    "PriorSampler",
    "BasicPriorSampler",
    "BasicDataset",
    "ProcessData",
    "ProcessPipeline",
    "QGenerator",
    "ConstantQ",
    "VariableQ",
    "EquidistantQ",
    "QNoiseGenerator",
    "IntensityNoiseGenerator",
    "MultiplicativeLogNormalNoiseGenerator",
    "PoissonNoiseGenerator",
    "CurvesScaler",
    "ShiftNoise",
    "ScalingNoise",
    "BackgroundNoise",
    "QNormalNoiseGenerator",
    "QSystematicShiftGenerator",
    "LogAffineCurvesScaler",
    "MeanNormalizationCurvesScaler",
    "get_reversed_params",
    "get_density_profiles",
    "logdist_sampler",
    "uniform_sampler",
    "triangular_sampler",
    "get_param_labels",
    "reflectivity",
    "Smearing",
    "SingleParamPrior",
    "SimplePriorSampler",
    "UniformParamPrior",
    "GaussianParamPrior",
    "TruncatedGaussianParamPrior",
    "UniformSubPriorParams",
    "UniformSubPriorSampler",
    "NarrowSldUniformSubPriorSampler",
    "ExpUniformSubPriorSampler",
    "SimpleMultilayerSampler",
    "BATCH_DATA_TYPE",
    "LogLikelihood",
    "PoissonLogLikelihood",
    "BasicExpIntensityNoise",
    "BasicQNoiseGenerator",
    "ConstantAngle",
    "SubpriorParametricSampler",
    "BasicParams",
    "ParametricModel",
    "MULTILAYER_MODELS",
    "SamplerStrategy",
    "BasicSamplerStrategy",
    "ConstrainedRoughnessSamplerStrategy",
    "ConstrainedRoughnessAndImgSldSamplerStrategy",
]
@@ -0,0 +1,196 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Dict, Union
8
+ import warnings
9
+
10
+ from torch import Tensor
11
+ import torch
12
+
13
+ from reflectorch.data_generation.priors import PriorSampler, BasicParams
14
+ from reflectorch.data_generation.noise import QNoiseGenerator, IntensityNoiseGenerator
15
+ from reflectorch.data_generation.q_generator import QGenerator
16
+ from reflectorch.data_generation.scale_curves import CurvesScaler, LogAffineCurvesScaler
17
+ from reflectorch.data_generation.smearing import Smearing
18
+
19
# Type of the batch dictionary built by ``BasicDataset.get_batch``: string keys (e.g. 'params',
# 'q_values', 'scaled_noisy_curves') mapping to either a tensor or a ``BasicParams`` instance.
BATCH_DATA_TYPE = Dict[str, Union[Tensor, BasicParams]]
20
+
21
+
22
class BasicDataset:
    """Reflectometry dataset. It generates the q positions, samples the thin film parameters from the prior,
    simulates the reflectivity curves and applies noise to the curves.

    Args:
        q_generator (QGenerator): the momentum transfer (q) generator
        prior_sampler (PriorSampler): the prior sampler
        intensity_noise (IntensityNoiseGenerator, optional): the intensity noise generator. Defaults to None.
        q_noise (QNoiseGenerator, optional): the q noise generator. Defaults to None.
        curves_scaler (CurvesScaler, optional): the reflectivity curve scaler. Defaults to an instance of LogAffineCurvesScaler,
            which scales the curves to the range [-1, 1], the minimum considered intensity being 1e-10.
        calc_denoised_curves (bool, optional): whether to add the curves without noise to the dictionary. Defaults to False.
        smearing (Smearing, optional): curve smearing generator. Defaults to None.
    """

    # Upper bound on how many times ``get_batch`` regenerates a batch whose scaled noisy curves
    # contain non-finite values. Bounding the retries makes a configuration that always produces
    # infinities fail with a clear RuntimeError instead of recursing or looping forever.
    max_batch_attempts: int = 1000

    def __init__(self,
                 q_generator: QGenerator,
                 prior_sampler: PriorSampler,
                 intensity_noise: IntensityNoiseGenerator = None,
                 q_noise: QNoiseGenerator = None,
                 curves_scaler: CurvesScaler = None,
                 calc_denoised_curves: bool = False,
                 smearing: Smearing = None,
                 ):
        self.q_generator = q_generator
        self.intensity_noise = intensity_noise
        self.q_noise = q_noise
        self.curves_scaler = curves_scaler or LogAffineCurvesScaler()
        self.prior_sampler = prior_sampler
        self.smearing = smearing
        self.calc_denoised_curves = calc_denoised_curves

    def update_batch_data(self, batch_data: BATCH_DATA_TYPE) -> None:
        """implement in a subclass to edit batch_data dict inplace"""
        pass

    def _sample_from_prior(self, batch_size: int):
        """Sample ``batch_size`` parameter sets and return them together with their scaled tensor form."""
        params: BasicParams = self.prior_sampler.sample(batch_size)
        scaled_params: Tensor = self.prior_sampler.scale_params(params)
        return params, scaled_params

    def get_batch(self, batch_size: int) -> BATCH_DATA_TYPE:
        """get a batch of data as a dictionary with keys ``params``, ``scaled_params``, ``q_values``, ``curves``, ``scaled_noisy_curves``

        ``original_q_values`` is only present when q noise is applied, and ``curves`` only when
        ``calc_denoised_curves`` is True. A batch whose scaled noisy curves contain any non-finite
        value is discarded with a warning and regenerated from scratch.

        Args:
            batch_size (int): the batch size

        Returns:
            BATCH_DATA_TYPE: the batch dictionary, after ``update_batch_data`` has been applied to it.

        Raises:
            RuntimeError: if no fully finite batch is produced within ``max_batch_attempts`` attempts.
        """
        for _ in range(self.max_batch_attempts):
            batch_data = self._generate_batch(batch_size)

            # per-curve flag: True when every point of that scaled noisy curve is finite
            is_finite = torch.all(torch.isfinite(batch_data['scaled_noisy_curves']), -1)
            if torch.all(is_finite).item():
                # the subclass hook only ever sees a batch that passed the finiteness check
                self.update_batch_data(batch_data)
                return batch_data

            infinite_indices = ~is_finite
            warnings.warn(f'Batch with {infinite_indices.sum().item()} curves with infinities skipped.')

        raise RuntimeError(
            f'Could not generate a batch with finite curves in {self.max_batch_attempts} attempts.'
        )

    def _generate_batch(self, batch_size: int) -> BATCH_DATA_TYPE:
        """Run one generation attempt: sample params and q, simulate, add noise, scale (no finiteness check)."""
        batch_data = {}

        params, scaled_params = self._sample_from_prior(batch_size)

        batch_data['params'] = params
        batch_data['scaled_params'] = scaled_params

        # the q generator receives the partially filled dict (presumably so q may depend on the
        # sampled params -- confirm against QGenerator implementations)
        q_values: Tensor = self.q_generator.get_batch(batch_size, batch_data)

        if self.q_noise:
            # keep the noise-free q grid alongside the perturbed one
            batch_data['original_q_values'] = q_values
            q_values = self.q_noise.apply(q_values, batch_data)

        batch_data['q_values'] = q_values

        curves = self._calc_curves(q_values, params)

        if self.calc_denoised_curves:
            batch_data['curves'] = curves

        noisy_curves = curves

        if self.intensity_noise:
            noisy_curves = self.intensity_noise(noisy_curves, batch_data)

        batch_data['scaled_noisy_curves'] = self.curves_scaler.scale(noisy_curves)
        return batch_data

    def _calc_curves(self, q_values: Tensor, params: BasicParams):
        """Simulate reflectivity curves at ``q_values``, using the smearing generator when one is set."""
        if self.smearing:
            # NOTE(review): only the unsmeared path below is cast to q_values' dtype/device;
            # confirm Smearing.get_curves already returns a matching tensor.
            curves = self.smearing.get_curves(q_values, params)
        else:
            curves = params.reflectivity(q_values)
            curves = curves.to(q_values)
        return curves
116
+
117
+
118
def _insert_batch_data(tgt_batch_data, add_batch_data, indices):
    """Copy entries of ``add_batch_data`` into ``tgt_batch_data`` in place at ``indices``.

    For every key of the target dict, the value is merged by indexed assignment
    (``target[indices] = add_batch_data[key]``) when it is a ``BasicParams`` instance or has
    exactly two dimensions. Any other entry is left untouched and a warning is emitted.
    A key missing from ``add_batch_data`` raises ``KeyError``.
    """
    for name in list(tgt_batch_data):
        target = tgt_batch_data[name]
        can_merge = isinstance(target, BasicParams) or len(target.shape) == 2
        if not can_merge:
            warnings.warn(f'Ignore {name} while merging batch_data.')
            continue
        target[indices] = add_batch_data[name]
125
+
126
+
127
if __name__ == '__main__':
    # Demo: build a dataset from a 2-layer uniform sub-prior with intensity noise, q noise and
    # smearing, time the generation of one batch, and (if matplotlib is installed) plot scaled
    # reference curves against the scaled noisy curves of the batch.
    from time import perf_counter

    from reflectorch.data_generation.q_generator import ConstantQ
    from reflectorch.data_generation.priors import UniformSubPriorSampler
    from reflectorch.data_generation.noise import BasicExpIntensityNoise
    from reflectorch.data_generation.noise import BasicQNoiseGenerator
    from reflectorch.utils import to_np

    q_generator = ConstantQ((0, 0.2, 65), device='cpu')
    noise_gen = BasicExpIntensityNoise(
        relative_errors=(0.05, 0.2),
        logdist=True,
        apply_shift=True,
    )
    q_noise_gen = BasicQNoiseGenerator(
        shift_std=5e-4,
        noise_std=(0, 1e-3),
    )
    prior_sampler = UniformSubPriorSampler(
        thickness_range=(0, 250),
        roughness_range=(0, 40),
        sld_range=(0, 60),
        num_layers=2,
        device=torch.device('cpu'),
        dtype=torch.float64,
        smaller_roughnesses=True,
        logdist=True,
        relative_min_bound_width=5e-4,
    )
    smearing = Smearing(
        sigma_range=(0.8e-3, 5e-3),
        gauss_num=31,
        share_smeared=0.5,
    )

    dataset = BasicDataset(
        q_generator,
        prior_sampler,
        noise_gen,
        q_noise=q_noise_gen,
        smearing=smearing
    )
    start = perf_counter()
    batch_data = dataset.get_batch(32)
    print(f'Total time = {(perf_counter() - start):.3f} sec ')
    print(batch_data['params'].roughnesses[:10])
    print(batch_data['scaled_noisy_curves'].min().item())

    scaled_noisy_curves = batch_data['scaled_noisy_curves']
    # NOTE(review): the reference curves are simulated on the noise-free ``q_generator.q`` grid and
    # without smearing, whereas the batch curves used noisy q and (possibly) smearing.
    scaled_curves = dataset.curves_scaler.scale(
        batch_data['params'].reflectivity(q_generator.q)
    )

    # only the import is optional; errors raised while plotting should surface normally
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        plt = None

    if plt is not None:
        # the q axis is the same for every curve, so convert it once
        q_np = to_np(q_generator.q.squeeze().cpu().numpy())
        for i in range(min(16, len(scaled_noisy_curves))):
            plt.plot(q_np, to_np(scaled_curves[i]))
            plt.plot(q_np, to_np(scaled_noisy_curves[i]))

        plt.show()
@@ -0,0 +1,86 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Union, Tuple
8
+
9
+ import torch
10
+ from torch import Tensor
11
+
12
+ from reflectorch.data_generation import (
13
+ PriorSampler,
14
+ Params,
15
+ )
16
+
17
+
18
class LogLikelihood:
    """Gaussian log-likelihood (and unnormalized log-posterior) of thin film parameters for a measured curve.

    Args:
        q (Tensor): the q values
        exp_curve (Tensor): the experimental reflectivity curve
        priors (PriorSampler): the prior sampler
        sigmas (Union[float, Tensor]): the sigmas (i.e. intensity error bars)
    """

    def __init__(self, q: Tensor, exp_curve: Tensor, priors: PriorSampler, sigmas: Union[float, Tensor]):
        # at least two dimensions, so a single curve is treated as a batch of one
        self.exp_curve = torch.atleast_2d(exp_curve)
        self.priors: PriorSampler = priors
        self.q = q
        self.sigmas = sigmas
        # squared error bars are computed once and reused in every likelihood evaluation
        self.sigmas2 = self.sigmas ** 2

    def calc_log_likelihood(self, curves: Tensor):
        """Gaussian log likelihood summed over the last axis.

        The ``-0.5 * log(2 * pi * sigma^2)`` normalization term is omitted, i.e. the result is
        defined up to an additive constant.
        """
        residuals = self.exp_curve - curves
        per_point = -(residuals ** 2) / self.sigmas2 / 2
        return per_point.sum(-1)

    def __call__(self, params: Union[Params, Tensor], curves: Tensor = None):
        """Return prior log-probability plus log-likelihood for each parameter sample.

        Args:
            params (Union[Params, Tensor]): parameter samples. A tensor is first converted with
                ``priors.PARAM_CLS.from_tensor``.
            curves (Tensor, optional): precomputed curves for all samples in ``params``. When
                omitted, curves are simulated at ``self.q``.

        Returns:
            Tensor: the tensor returned by ``priors.log_prob`` with the log-likelihood added in
            place at entries whose prior log-probability is finite. Other entries keep their
            (non-finite) prior value, and no curves are simulated for them.
        """
        if not isinstance(params, Params):
            params = self.priors.PARAM_CLS.from_tensor(params)

        log_post: Tensor = self.priors.log_prob(params)
        in_support: Tensor = torch.isfinite(log_post)

        # no sample has a finite prior: nothing to simulate or add
        if in_support.sum().item() == 0:
            return log_post

        supported_params: Params = params[in_support]

        if curves is not None:
            curves = curves[in_support]
        else:
            curves = supported_params.reflectivity(self.q)

        log_post[in_support] += self.calc_log_likelihood(curves)
        return log_post

    calc_log_posterior = __call__

    def get_importance_sampling_weights(
        self, sampled_params: Params, nf_log_probs: Tensor, curves: Tensor = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Self-normalized importance weights of samples drawn from a proposal distribution.

        Args:
            sampled_params (Params): the proposal samples.
            nf_log_probs (Tensor): log-densities of the samples under the proposal (presumably a
                normalizing flow, judging by the name -- confirm with callers).
            curves (Tensor, optional): precomputed curves, forwarded to ``calc_log_posterior``.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: (weights normalized to sum to one, log weights shifted
            so that their maximum is 0, unnormalized log posterior of each sample).

        NOTE(review): if every log weight is -inf, the max-shift produces NaN weights.
        """
        log_probs = self.calc_log_posterior(sampled_params, curves=curves)

        raw_log_weights = log_probs - nf_log_probs
        # shift by the maximum so the largest unnormalized weight is exp(0) = 1
        log_weights = raw_log_weights - raw_log_weights.max()

        # exponentiate in double precision, then return to the working dtype/device
        unnormalized = torch.exp(log_weights.to(torch.float64)).to(log_weights)
        weights = unnormalized / unnormalized.sum()

        return weights, log_weights, log_probs
72
+
73
+
74
class PoissonLogLikelihood(LogLikelihood):
    """Computes the Poisson log likelihood of the thin film parameters

    Args:
        q (Tensor): the q values
        exp_curve (Tensor): the experimental reflectivity curve
        priors (PriorSampler): the prior sampler
        sigmas (Union[float, Tensor]): the sigmas (i.e. intensity error bars)
    """

    def calc_log_likelihood(self, curves: Tensor):
        """computes the Poisson log likelihood

        Each point contributes ``(R_exp / sigma^2) * (R_exp * log(R_model) - R_model)``, and the
        contributions are summed over the last axis.

        NOTE(review): this matches a Poisson log-likelihood in count units, up to an additive
        constant, when ``sigma^2`` is proportional to ``R_exp``, so that ``R_exp / sigma^2`` acts as
        an inverse count scale. Confirm against the intended noise model. Non-positive model
        intensities make ``log`` non-finite.
        """
        inverse_scale = self.exp_curve / self.sigmas2
        per_point = inverse_scale * (self.exp_curve * torch.log(curves) - curves)
        return per_point.sum(-1)