reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,210 +1,210 @@
1
- from typing import Dict, Union
2
- import warnings
3
-
4
- from torch import Tensor
5
- import torch
6
-
7
- from reflectorch.data_generation.priors import PriorSampler, BasicParams
8
- from reflectorch.data_generation.noise import QNoiseGenerator, IntensityNoiseGenerator
9
- from reflectorch.data_generation.q_generator import QGenerator
10
- from reflectorch.data_generation.scale_curves import CurvesScaler, LogAffineCurvesScaler
11
- from reflectorch.data_generation.smearing import Smearing
12
-
# Type of a batch dictionary: string keys mapped either to tensors or to a BasicParams instance.
BATCH_DATA_TYPE = Dict[str, Union[Tensor, BasicParams]]
14
-
15
-
class BasicDataset(object):
    """Reflectometry dataset. It generates the q positions, samples the thin film parameters from the prior,
    simulates the reflectivity curves and applies noise to the curves.

    Args:
        q_generator (QGenerator): the momentum transfer (q) generator
        prior_sampler (PriorSampler): the prior sampler
        intensity_noise (IntensityNoiseGenerator, optional): the intensity noise generator. Defaults to None.
        q_noise (QNoiseGenerator, optional): the q noise generator. Defaults to None.
        curves_scaler (CurvesScaler, optional): the reflectivity curve scaler. Defaults to an instance of LogAffineCurvesScaler,
            which scales the curves to the range [-1, 1], the minimum considered intensity being 1e-10.
        calc_denoised_curves (bool, optional): whether to add the curves without noise to the dictionary. Defaults to False.
        calc_nonsmeared_curves (bool, optional): whether to add the curves without smearing to the dictionary
            (only relevant when smearing is applied). Defaults to False.
        smearing (Smearing, optional): curve smearing generator. Defaults to None.
    """
    def __init__(self,
                 q_generator: QGenerator,
                 prior_sampler: PriorSampler,
                 intensity_noise: IntensityNoiseGenerator = None,
                 q_noise: QNoiseGenerator = None,
                 curves_scaler: CurvesScaler = None,
                 calc_denoised_curves: bool = False,
                 calc_nonsmeared_curves: bool = False,
                 smearing: Smearing = None,
                 ):
        self.q_generator = q_generator
        self.intensity_noise = intensity_noise
        self.q_noise = q_noise
        # Fall back to the default log-affine scaler when none is supplied.
        self.curves_scaler = curves_scaler if curves_scaler is not None else LogAffineCurvesScaler()
        self.prior_sampler = prior_sampler
        self.smearing = smearing
        self.calc_denoised_curves = calc_denoised_curves
        self.calc_nonsmeared_curves = calc_nonsmeared_curves

    def update_batch_data(self, batch_data: BATCH_DATA_TYPE) -> None:
        """implement in a subclass to edit batch_data dict inplace"""
        pass

    def _sample_from_prior(self, batch_size: int):
        """Sample ``batch_size`` parameter sets and return them together with their scaled tensor form."""
        params: BasicParams = self.prior_sampler.sample(batch_size)
        scaled_params: Tensor = self.prior_sampler.scale_params(params)
        return params, scaled_params

    def get_batch(self, batch_size: int) -> BATCH_DATA_TYPE:
        """Get a batch of data as a dictionary.

        The dictionary always contains ``params``, ``scaled_params``, ``q_values`` and
        ``scaled_noisy_curves``. Depending on the configuration it may also contain
        ``original_q_values`` (when q noise is applied), ``q_resolutions`` (when smearing
        returns a tensor of resolutions), ``nonsmeared_curves`` (smearing together with
        ``calc_nonsmeared_curves``) and ``curves`` (when ``calc_denoised_curves`` is set).
        The :meth:`update_batch_data` hook is applied exactly once, to the final batch.

        Args:
            batch_size (int): the batch size

        Returns:
            BATCH_DATA_TYPE: the batch dictionary
        """
        batch_data = self._generate_batch(batch_size)
        self.update_batch_data(batch_data)
        return batch_data

    def _generate_batch(self, batch_size: int) -> BATCH_DATA_TYPE:
        """Generate a batch whose scaled noisy curves are all finite, without calling the hook.

        Rows whose scaled noisy curves contain non-finite values (inf or NaN) are regenerated
        recursively and written back in place. :meth:`update_batch_data` is deliberately NOT
        called here, so regenerated rows are not passed through the in-place hook twice.
        """
        batch_data = {}

        params, scaled_params = self._sample_from_prior(batch_size)

        batch_data['params'] = params
        batch_data['scaled_params'] = scaled_params

        # The q generator receives the partially filled batch dict (already holding the params).
        q_values: Tensor = self.q_generator.get_batch(batch_size, batch_data)

        if self.q_noise is not None:
            # Keep the nominal q positions next to the noisy ones.
            batch_data['original_q_values'] = q_values
            q_values = self.q_noise.apply(q_values, batch_data)

        batch_data['q_values'] = q_values

        refl_kwargs = {}

        curves, q_resolutions, nonsmeared_curves = self._calc_curves(q_values, params, refl_kwargs)

        if torch.is_tensor(q_resolutions):
            batch_data['q_resolutions'] = q_resolutions

        if torch.is_tensor(nonsmeared_curves):
            batch_data['nonsmeared_curves'] = nonsmeared_curves

        if self.calc_denoised_curves:
            batch_data['curves'] = curves

        noisy_curves = curves

        if self.intensity_noise is not None:
            noisy_curves = self.intensity_noise(noisy_curves, batch_data)

        scaled_noisy_curves = self.curves_scaler.scale(noisy_curves)
        batch_data['scaled_noisy_curves'] = scaled_noisy_curves

        # A row is valid only if every point of its scaled noisy curve is finite.
        is_finite = torch.all(torch.isfinite(scaled_noisy_curves), -1)

        if not torch.all(is_finite).item():
            infinite_indices = ~is_finite
            to_recalculate = infinite_indices.sum().item()
            warnings.warn(f'Infinite number appeared in the curve simulation! Recalculate {to_recalculate} curves.')
            # Regenerate only the invalid rows (hook-free) and overwrite them in place.
            recalculated_batch_data = self._generate_batch(to_recalculate)
            _insert_batch_data(batch_data, recalculated_batch_data, infinite_indices)

            is_finite = torch.all(torch.isfinite(batch_data['scaled_noisy_curves']), -1)
            assert torch.all(is_finite).item()

        return batch_data

    def _calc_curves(self, q_values: Tensor, params: BasicParams, refl_kwargs):
        """Simulate the (optionally smeared) curves.

        Returns:
            tuple: ``(curves, q_resolutions, nonsmeared_curves)``. ``q_resolutions`` is None without
            smearing; ``nonsmeared_curves`` is None unless smearing and ``calc_nonsmeared_curves``
            are both enabled. Returned curves are cast to the dtype/device of ``q_values``.
        """
        nonsmeared_curves = None

        if self.smearing is not None:
            if self.calc_nonsmeared_curves:
                # Cast like ``curves`` below so both curve tensors share dtype/device.
                nonsmeared_curves = params.reflectivity(q_values, **refl_kwargs).to(q_values)
            curves, q_resolutions = self.smearing.get_curves(q_values, params, refl_kwargs)
        else:
            curves = params.reflectivity(q_values, **refl_kwargs)
            q_resolutions = None

        curves = curves.to(q_values)
        return curves, q_resolutions, nonsmeared_curves
130
-
131
-
def _insert_batch_data(tgt_batch_data, add_batch_data, indices):
    """Overwrite selected rows of ``tgt_batch_data`` in place with the values from ``add_batch_data``.

    Only BasicParams entries and 2-D entries are merged; any other entry is left untouched
    and a warning is emitted for it.

    Args:
        tgt_batch_data: batch dict to update in place
        add_batch_data: batch dict providing the replacement rows, with the same keys
        indices: row selector applied to each target entry (e.g. a boolean mask over the batch)
    """
    for name in list(tgt_batch_data):
        target = tgt_batch_data[name]
        mergeable = isinstance(target, BasicParams) or len(target.shape) == 2
        if not mergeable:
            warnings.warn(f'Ignore {name} while merging batch_data.')
            continue
        target[indices] = add_batch_data[name]
139
-
140
-
if __name__ == '__main__':
    from time import perf_counter

    from reflectorch.data_generation.q_generator import ConstantQ
    from reflectorch.data_generation.priors import UniformSubPriorSampler
    from reflectorch.data_generation.noise import BasicExpIntensityNoise, BasicQNoiseGenerator
    from reflectorch.utils import to_np

    # Demo: simulate a batch with q noise, intensity noise and smearing,
    # report the timing and plot a few curves.
    q_generator = ConstantQ((0, 0.2, 65), device='cpu')
    noise_gen = BasicExpIntensityNoise(
        relative_errors=(0.05, 0.2),
        logdist=True,
        apply_shift=True,
    )
    q_noise_gen = BasicQNoiseGenerator(
        shift_std=5e-4,
        noise_std=(0, 1e-3),
    )
    prior_sampler = UniformSubPriorSampler(
        thickness_range=(0, 250),
        roughness_range=(0, 40),
        sld_range=(0, 60),
        num_layers=2,
        device=torch.device('cpu'),
        dtype=torch.float64,
        smaller_roughnesses=True,
        logdist=True,
        relative_min_bound_width=5e-4,
    )
    smearing = Smearing(
        sigma_range=(0.8e-3, 5e-3),
        gauss_num=31,
        share_smeared=0.5,
    )

    dataset = BasicDataset(
        q_generator,
        prior_sampler,
        noise_gen,
        q_noise=q_noise_gen,
        smearing=smearing,
    )
    start = perf_counter()
    batch_data = dataset.get_batch(32)
    print(f'Total time = {(perf_counter() - start):.3f} sec ')
    print(batch_data['params'].roughnesses[:10])
    print(batch_data['scaled_noisy_curves'].min().item())

    scaled_noisy_curves = batch_data['scaled_noisy_curves']
    # Reference: noise-free, non-smeared curves evaluated on the nominal q grid.
    scaled_curves = dataset.curves_scaler.scale(
        batch_data['params'].reflectivity(q_generator.q)
    )
    nominal_q = to_np(q_generator.q.squeeze())
    # The noisy curves were simulated at the (q-noised) positions stored in the batch,
    # so plot them against those positions rather than the nominal grid.
    noisy_q = batch_data['q_values']

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        pass
    else:
        for i in range(16):
            plt.plot(nominal_q, to_np(scaled_curves[i]))
            plt.plot(to_np(noisy_q[i]), to_np(scaled_noisy_curves[i]))

        plt.show()
1
+ from typing import Dict, Union
2
+ import warnings
3
+
4
+ from torch import Tensor
5
+ import torch
6
+
7
+ from reflectorch.data_generation.priors import PriorSampler, BasicParams
8
+ from reflectorch.data_generation.noise import QNoiseGenerator, IntensityNoiseGenerator
9
+ from reflectorch.data_generation.q_generator import QGenerator
10
+ from reflectorch.data_generation.scale_curves import CurvesScaler, LogAffineCurvesScaler
11
+ from reflectorch.data_generation.smearing import Smearing
12
+
# Type of a batch dictionary: string keys mapped either to tensors or to a BasicParams instance.
BATCH_DATA_TYPE = Dict[str, Union[Tensor, BasicParams]]
14
+
15
+
class BasicDataset(object):
    """Reflectometry dataset. It generates the q positions, samples the thin film parameters from the prior,
    simulates the reflectivity curves and applies noise to the curves.

    Args:
        q_generator (QGenerator): the momentum transfer (q) generator
        prior_sampler (PriorSampler): the prior sampler
        intensity_noise (IntensityNoiseGenerator, optional): the intensity noise generator. Defaults to None.
        q_noise (QNoiseGenerator, optional): the q noise generator. Defaults to None.
        curves_scaler (CurvesScaler, optional): the reflectivity curve scaler. Defaults to an instance of LogAffineCurvesScaler,
            which scales the curves to the range [-1, 1], the minimum considered intensity being 1e-10.
        calc_denoised_curves (bool, optional): whether to add the curves without noise to the dictionary. Defaults to False.
        calc_nonsmeared_curves (bool, optional): whether to add the curves without smearing to the dictionary
            (only relevant when smearing is applied). Defaults to False.
        smearing (Smearing, optional): curve smearing generator. Defaults to None.
    """
    def __init__(self,
                 q_generator: QGenerator,
                 prior_sampler: PriorSampler,
                 intensity_noise: IntensityNoiseGenerator = None,
                 q_noise: QNoiseGenerator = None,
                 curves_scaler: CurvesScaler = None,
                 calc_denoised_curves: bool = False,
                 calc_nonsmeared_curves: bool = False,
                 smearing: Smearing = None,
                 ):
        self.q_generator = q_generator
        self.intensity_noise = intensity_noise
        self.q_noise = q_noise
        # Fall back to the default log-affine scaler when none is supplied.
        self.curves_scaler = curves_scaler if curves_scaler is not None else LogAffineCurvesScaler()
        self.prior_sampler = prior_sampler
        self.smearing = smearing
        self.calc_denoised_curves = calc_denoised_curves
        self.calc_nonsmeared_curves = calc_nonsmeared_curves

    def update_batch_data(self, batch_data: BATCH_DATA_TYPE) -> None:
        """implement in a subclass to edit batch_data dict inplace"""
        pass

    def _sample_from_prior(self, batch_size: int):
        """Sample ``batch_size`` parameter sets and return them together with their scaled tensor form."""
        params: BasicParams = self.prior_sampler.sample(batch_size)
        scaled_params: Tensor = self.prior_sampler.scale_params(params)
        return params, scaled_params

    def get_batch(self, batch_size: int) -> BATCH_DATA_TYPE:
        """Get a batch of data as a dictionary.

        The dictionary always contains ``params``, ``scaled_params``, ``q_values`` and
        ``scaled_noisy_curves``. Depending on the configuration it may also contain
        ``original_q_values`` (when q noise is applied), ``q_resolutions`` (when smearing
        returns a tensor of resolutions), ``nonsmeared_curves`` (smearing together with
        ``calc_nonsmeared_curves``) and ``curves`` (when ``calc_denoised_curves`` is set).
        The :meth:`update_batch_data` hook is applied exactly once, to the final batch.

        Args:
            batch_size (int): the batch size

        Returns:
            BATCH_DATA_TYPE: the batch dictionary
        """
        batch_data = self._generate_batch(batch_size)
        self.update_batch_data(batch_data)
        return batch_data

    def _generate_batch(self, batch_size: int) -> BATCH_DATA_TYPE:
        """Generate a batch whose scaled noisy curves are all finite, without calling the hook.

        Rows whose scaled noisy curves contain non-finite values (inf or NaN) are regenerated
        recursively and written back in place. :meth:`update_batch_data` is deliberately NOT
        called here, so regenerated rows are not passed through the in-place hook twice.
        """
        batch_data = {}

        params, scaled_params = self._sample_from_prior(batch_size)

        batch_data['params'] = params
        batch_data['scaled_params'] = scaled_params

        # The q generator receives the partially filled batch dict (already holding the params).
        q_values: Tensor = self.q_generator.get_batch(batch_size, batch_data)

        if self.q_noise is not None:
            # Keep the nominal q positions next to the noisy ones.
            batch_data['original_q_values'] = q_values
            q_values = self.q_noise.apply(q_values, batch_data)

        batch_data['q_values'] = q_values

        refl_kwargs = {}

        curves, q_resolutions, nonsmeared_curves = self._calc_curves(q_values, params, refl_kwargs)

        if torch.is_tensor(q_resolutions):
            batch_data['q_resolutions'] = q_resolutions

        if torch.is_tensor(nonsmeared_curves):
            batch_data['nonsmeared_curves'] = nonsmeared_curves

        if self.calc_denoised_curves:
            batch_data['curves'] = curves

        noisy_curves = curves

        if self.intensity_noise is not None:
            noisy_curves = self.intensity_noise(noisy_curves, batch_data)

        scaled_noisy_curves = self.curves_scaler.scale(noisy_curves)
        batch_data['scaled_noisy_curves'] = scaled_noisy_curves

        # A row is valid only if every point of its scaled noisy curve is finite.
        is_finite = torch.all(torch.isfinite(scaled_noisy_curves), -1)

        if not torch.all(is_finite).item():
            infinite_indices = ~is_finite
            to_recalculate = infinite_indices.sum().item()
            warnings.warn(f'Infinite number appeared in the curve simulation! Recalculate {to_recalculate} curves.')
            # Regenerate only the invalid rows (hook-free) and overwrite them in place.
            recalculated_batch_data = self._generate_batch(to_recalculate)
            _insert_batch_data(batch_data, recalculated_batch_data, infinite_indices)

            is_finite = torch.all(torch.isfinite(batch_data['scaled_noisy_curves']), -1)
            assert torch.all(is_finite).item()

        return batch_data

    def _calc_curves(self, q_values: Tensor, params: BasicParams, refl_kwargs):
        """Simulate the (optionally smeared) curves.

        Returns:
            tuple: ``(curves, q_resolutions, nonsmeared_curves)``. ``q_resolutions`` is None without
            smearing; ``nonsmeared_curves`` is None unless smearing and ``calc_nonsmeared_curves``
            are both enabled. Returned curves are cast to the dtype/device of ``q_values``.
        """
        nonsmeared_curves = None

        if self.smearing is not None:
            if self.calc_nonsmeared_curves:
                # Cast like ``curves`` below so both curve tensors share dtype/device.
                nonsmeared_curves = params.reflectivity(q_values, **refl_kwargs).to(q_values)
            curves, q_resolutions = self.smearing.get_curves(q_values, params, refl_kwargs)
        else:
            curves = params.reflectivity(q_values, **refl_kwargs)
            q_resolutions = None

        curves = curves.to(q_values)
        return curves, q_resolutions, nonsmeared_curves
130
+
131
+
def _insert_batch_data(tgt_batch_data, add_batch_data, indices):
    """Overwrite selected rows of ``tgt_batch_data`` in place with the values from ``add_batch_data``.

    Only BasicParams entries and 2-D entries are merged; any other entry is left untouched
    and a warning is emitted for it.

    Args:
        tgt_batch_data: batch dict to update in place
        add_batch_data: batch dict providing the replacement rows, with the same keys
        indices: row selector applied to each target entry (e.g. a boolean mask over the batch)
    """
    for name in list(tgt_batch_data):
        target = tgt_batch_data[name]
        mergeable = isinstance(target, BasicParams) or len(target.shape) == 2
        if not mergeable:
            warnings.warn(f'Ignore {name} while merging batch_data.')
            continue
        target[indices] = add_batch_data[name]
139
+
140
+
if __name__ == '__main__':
    from time import perf_counter

    from reflectorch.data_generation.q_generator import ConstantQ
    from reflectorch.data_generation.priors import UniformSubPriorSampler
    from reflectorch.data_generation.noise import BasicExpIntensityNoise, BasicQNoiseGenerator
    from reflectorch.utils import to_np

    # Demo: simulate a batch with q noise, intensity noise and smearing,
    # report the timing and plot a few curves.
    q_generator = ConstantQ((0, 0.2, 65), device='cpu')
    noise_gen = BasicExpIntensityNoise(
        relative_errors=(0.05, 0.2),
        logdist=True,
        apply_shift=True,
    )
    q_noise_gen = BasicQNoiseGenerator(
        shift_std=5e-4,
        noise_std=(0, 1e-3),
    )
    prior_sampler = UniformSubPriorSampler(
        thickness_range=(0, 250),
        roughness_range=(0, 40),
        sld_range=(0, 60),
        num_layers=2,
        device=torch.device('cpu'),
        dtype=torch.float64,
        smaller_roughnesses=True,
        logdist=True,
        relative_min_bound_width=5e-4,
    )
    smearing = Smearing(
        sigma_range=(0.8e-3, 5e-3),
        gauss_num=31,
        share_smeared=0.5,
    )

    dataset = BasicDataset(
        q_generator,
        prior_sampler,
        noise_gen,
        q_noise=q_noise_gen,
        smearing=smearing,
    )
    start = perf_counter()
    batch_data = dataset.get_batch(32)
    print(f'Total time = {(perf_counter() - start):.3f} sec ')
    print(batch_data['params'].roughnesses[:10])
    print(batch_data['scaled_noisy_curves'].min().item())

    scaled_noisy_curves = batch_data['scaled_noisy_curves']
    # Reference: noise-free, non-smeared curves evaluated on the nominal q grid.
    scaled_curves = dataset.curves_scaler.scale(
        batch_data['params'].reflectivity(q_generator.q)
    )
    nominal_q = to_np(q_generator.q.squeeze())
    # The noisy curves were simulated at the (q-noised) positions stored in the batch,
    # so plot them against those positions rather than the nominal grid.
    noisy_q = batch_data['q_values']

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        pass
    else:
        for i in range(16):
            plt.plot(nominal_q, to_np(scaled_curves[i]))
            plt.plot(to_np(noisy_q[i]), to_np(scaled_noisy_curves[i]))

        plt.show()
@@ -1,80 +1,80 @@
1
- from typing import Union, Tuple
2
-
3
- import torch
4
- from torch import Tensor
5
-
6
- from reflectorch.data_generation import (
7
- PriorSampler,
8
- Params,
9
- )
10
-
11
-
class LogLikelihood(object):
    """Gaussian log likelihood (and log posterior) of thin film parameters given a measured curve.

    Args:
        q (Tensor): the q values
        exp_curve (Tensor): the experimental reflectivity curve
        priors (PriorSampler): the prior sampler
        sigmas (Union[float, Tensor]): the sigmas (i.e. intensity error bars)
    """
    def __init__(self, q: Tensor, exp_curve: Tensor, priors: PriorSampler, sigmas: Union[float, Tensor]):
        # atleast_2d turns a single 1-D curve into a (1, n) row, presumably so it broadcasts
        # against batches of simulated curves.
        self.exp_curve = torch.atleast_2d(exp_curve)
        self.priors: PriorSampler = priors
        self.q = q
        self.sigmas = sigmas
        # Only the variances enter the likelihood formulas, so cache them once.
        self.sigmas2 = self.sigmas ** 2

    def calc_log_likelihood(self, curves: Tensor):
        """Return the Gaussian log likelihood (without normalization constant), summed over the last dimension."""
        squared_residuals = (self.exp_curve - curves) ** 2
        log_probs = -squared_residuals / self.sigmas2 / 2
        return log_probs.sum(-1)

    def __call__(self, params: Union[Params, Tensor], curves: Tensor = None):
        """Return log prior + log likelihood for every parameter set.

        Entries whose log prior is not finite are returned as-is; the likelihood is evaluated
        only for the remaining entries. The tensor returned by ``priors.log_prob`` is updated
        in place and returned.

        Args:
            params: a ``Params`` instance, or a tensor converted via ``priors.PARAM_CLS.from_tensor``
            curves: optional precomputed curves for ``params``; simulated at ``self.q`` when omitted
        """
        if not isinstance(params, Params):
            params: Params = self.priors.PARAM_CLS.from_tensor(params)

        log_posterior: Tensor = self.priors.log_prob(params)
        in_support: Tensor = torch.isfinite(log_posterior)

        # No parameter set has a finite prior: nothing to add.
        if not in_support.sum().item():
            return log_posterior

        supported_params: Params = params[in_support]

        if curves is not None:
            curves = curves[in_support]
        else:
            curves = supported_params.reflectivity(self.q)

        log_posterior[in_support] += self.calc_log_likelihood(curves)

        return log_posterior

    calc_log_posterior = __call__

    def get_importance_sampling_weights(
            self, sampled_params: Params, nf_log_probs: Tensor, curves: Tensor = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute self-normalized importance weights for samples drawn from a proposal.

        Args:
            sampled_params: samples drawn from the proposal distribution
            nf_log_probs: log probabilities of the samples under the proposal
                (presumably a normalizing flow, judging by the name)
            curves: optional precomputed curves for ``sampled_params``

        Returns:
            Tuple of (normalized weights, max-shifted log weights, log posterior values).
        """
        log_posterior = self.calc_log_posterior(sampled_params, curves=curves)
        raw_log_weights = log_posterior - nf_log_probs
        # Shift so the largest log weight is 0, keeping exp() from overflowing.
        # NOTE(review): if every raw log weight is -inf, this shift produces NaN.
        log_weights = raw_log_weights - raw_log_weights.max()

        # Exponentiate in float64, then cast back to the dtype/device of the log weights.
        unnormalized = torch.exp(log_weights.to(torch.float64)).to(log_weights)
        weights = unnormalized / unnormalized.sum()

        return weights, log_weights, log_posterior
66
-
67
-
class PoissonLogLikelihood(LogLikelihood):
    """Poisson log likelihood (and log posterior) of thin film parameters given a measured curve.

    Args:
        q (Tensor): the q values
        exp_curve (Tensor): the experimental reflectivity curve
        priors (PriorSampler): the prior sampler
        sigmas (Union[float, Tensor]): the sigmas (i.e. intensity error bars)
    """
    def calc_log_likelihood(self, curves: Tensor):
        """Return the Poisson log likelihood (up to an additive constant), summed over the last dimension.

        ``exp_curve / sigmas**2`` presumably acts as the factor converting intensities into
        effective counts.
        """
        counts_scale = self.exp_curve / self.sigmas2
        log_probs = counts_scale * (self.exp_curve * torch.log(curves) - curves)
        return log_probs.sum(-1)
1
+ from typing import Union, Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from reflectorch.data_generation import (
7
+ PriorSampler,
8
+ Params,
9
+ )
10
+
11
+
class LogLikelihood(object):
    """Gaussian log likelihood (and log posterior) of thin film parameters given a measured curve.

    Args:
        q (Tensor): the q values
        exp_curve (Tensor): the experimental reflectivity curve
        priors (PriorSampler): the prior sampler
        sigmas (Union[float, Tensor]): the sigmas (i.e. intensity error bars)
    """
    def __init__(self, q: Tensor, exp_curve: Tensor, priors: PriorSampler, sigmas: Union[float, Tensor]):
        # atleast_2d turns a single 1-D curve into a (1, n) row, presumably so it broadcasts
        # against batches of simulated curves.
        self.exp_curve = torch.atleast_2d(exp_curve)
        self.priors: PriorSampler = priors
        self.q = q
        self.sigmas = sigmas
        # Only the variances enter the likelihood formulas, so cache them once.
        self.sigmas2 = self.sigmas ** 2

    def calc_log_likelihood(self, curves: Tensor):
        """Return the Gaussian log likelihood (without normalization constant), summed over the last dimension."""
        squared_residuals = (self.exp_curve - curves) ** 2
        log_probs = -squared_residuals / self.sigmas2 / 2
        return log_probs.sum(-1)

    def __call__(self, params: Union[Params, Tensor], curves: Tensor = None):
        """Return log prior + log likelihood for every parameter set.

        Entries whose log prior is not finite are returned as-is; the likelihood is evaluated
        only for the remaining entries. The tensor returned by ``priors.log_prob`` is updated
        in place and returned.

        Args:
            params: a ``Params`` instance, or a tensor converted via ``priors.PARAM_CLS.from_tensor``
            curves: optional precomputed curves for ``params``; simulated at ``self.q`` when omitted
        """
        if not isinstance(params, Params):
            params: Params = self.priors.PARAM_CLS.from_tensor(params)

        log_posterior: Tensor = self.priors.log_prob(params)
        in_support: Tensor = torch.isfinite(log_posterior)

        # No parameter set has a finite prior: nothing to add.
        if not in_support.sum().item():
            return log_posterior

        supported_params: Params = params[in_support]

        if curves is not None:
            curves = curves[in_support]
        else:
            curves = supported_params.reflectivity(self.q)

        log_posterior[in_support] += self.calc_log_likelihood(curves)

        return log_posterior

    calc_log_posterior = __call__

    def get_importance_sampling_weights(
            self, sampled_params: Params, nf_log_probs: Tensor, curves: Tensor = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute self-normalized importance weights for samples drawn from a proposal.

        Args:
            sampled_params: samples drawn from the proposal distribution
            nf_log_probs: log probabilities of the samples under the proposal
                (presumably a normalizing flow, judging by the name)
            curves: optional precomputed curves for ``sampled_params``

        Returns:
            Tuple of (normalized weights, max-shifted log weights, log posterior values).
        """
        log_posterior = self.calc_log_posterior(sampled_params, curves=curves)
        raw_log_weights = log_posterior - nf_log_probs
        # Shift so the largest log weight is 0, keeping exp() from overflowing.
        # NOTE(review): if every raw log weight is -inf, this shift produces NaN.
        log_weights = raw_log_weights - raw_log_weights.max()

        # Exponentiate in float64, then cast back to the dtype/device of the log weights.
        unnormalized = torch.exp(log_weights.to(torch.float64)).to(log_weights)
        weights = unnormalized / unnormalized.sum()

        return weights, log_weights, log_posterior
66
+
67
+
class PoissonLogLikelihood(LogLikelihood):
    """Poisson log likelihood (and log posterior) of thin film parameters given a measured curve.

    Args:
        q (Tensor): the q values
        exp_curve (Tensor): the experimental reflectivity curve
        priors (PriorSampler): the prior sampler
        sigmas (Union[float, Tensor]): the sigmas (i.e. intensity error bars)
    """
    def calc_log_likelihood(self, curves: Tensor):
        """Return the Poisson log likelihood (up to an additive constant), summed over the last dimension.

        ``exp_curve / sigmas**2`` presumably acts as the factor converting intensities into
        effective counts.
        """
        counts_scale = self.exp_curve / self.sigmas2
        log_probs = counts_scale * (self.exp_curve * torch.log(curves) - curves)
        return log_probs.sum(-1)