reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (96) hide show
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,239 @@
1
+ from reflectorch.data_generation.priors.parametric_models import NuisanceParamsWrapper
2
+
3
class Layer:
    """A single slab layer described by prior bounds on its parameters.

    Bounds are stored exactly as given. The imaginary SLD bounds are optional and are only
    meaningful for models that support absorption.

    Args:
        thickness_bounds (Tuple[float, float]): Minimum and maximum thickness of the layer (in Å).
        roughness_bounds (Tuple[float, float]): Minimum and maximum interfacial roughness at the top of this layer (in Å).
        sld_bounds (Tuple[float, float]): Minimum and maximum real SLD of this layer (in 10⁻⁶ Å⁻²).
        imag_sld_bounds (Tuple[float, float], optional): Minimum and maximum imaginary SLD (in 10⁻⁶ Å⁻²) of this layer. Defaults to None.
    """

    def __init__(self, thickness_bounds, roughness_bounds, sld_bounds, imag_sld_bounds=None):
        # Geometry of the slab.
        self.thickness_bounds, self.roughness_bounds = thickness_bounds, roughness_bounds
        # Scattering length density: real part, then the optional imaginary part.
        self.sld_bounds, self.imag_sld_bounds = sld_bounds, imag_sld_bounds
19
+
20
class Backing:
    """The backing medium (substrate) that terminates a multilayer structure.

    The backing is semi-infinite, so it carries no thickness parameter; ``thickness_bounds``
    is still exposed (always ``None``) so that a backing can be handled through the same
    attributes as a regular layer.

    Args:
        roughness_bounds (Tuple[float, float]): Minimum and maximum interfacial roughness at the top of the backing medium (in Å).
        sld_bounds (Tuple[float, float]): Minimum and maximum real SLD of the backing medium (in 10⁻⁶ Å⁻²).
        imag_sld_bounds (Tuple[float, float], optional): Minimum and maximum imaginary SLD (in 10⁻⁶ Å⁻²) of the backing. Defaults to None.
    """

    def __init__(self, roughness_bounds, sld_bounds, imag_sld_bounds=None):
        # No thickness for a semi-infinite medium; None marks this for layer-wise consumers.
        self.thickness_bounds = None
        self.roughness_bounds, self.sld_bounds = roughness_bounds, sld_bounds
        self.imag_sld_bounds = imag_sld_bounds
36
+
37
class Structure:
    """Defines a multilayer structure and its parameter bounds in a layer-wise manner.

    This class allows the user to define the prior bounds for the full structure (film + backing) in a layer-wise format. It automatically constructs the
    flattened list of parameter bounds compatible with the inference model’s expected input format.

    Args:
        layers (List[Union[Layer, Backing]]): Ordered list of layers defining the structure, from the ambient side to the backing. The last element
            must be a :class:`Backing` instance. Note that the fronting medium (ambient) is not part of this list (since it is not a predicted parameter),
            and is treated by default as being 0 (air). For different fronting media one can use the ``ambient_sld`` argument of the prediction method.
        q_shift_bounds (Tuple[float, float], optional): Bounds for the global ``q_shift`` nuisance parameter. Defaults to None.
        r_scale_bounds (Tuple[float, float], optional): Bounds for the global reflectivity scale factor ``r_scale``. Defaults to None.
        log10_background_bounds (Tuple[float, float], optional): Bounds for the background term expressed as log10(background). Defaults to None.

    Attributes:
        thicknesses_bounds (List[Tuple[float, float]]): Bounds for all thicknesses (excluding backing).
        roughnesses_bounds (List[Tuple[float, float]]): Bounds for all roughnesses (including backing).
        slds_bounds (List[Tuple[float, float]]): Bounds for all real SLDs (including backing).
        imag_slds_bounds (List[Tuple[float, float]]): Bounds for all imaginary SLDs (if provided).
        prior_bounds (List[Tuple[float, float]]): Flattened list of all parameter bounds in the order expected by the model: thicknesses,
            roughnesses, real SLDs, imaginary SLDs (if present), followed by nuisance parameters.

    Example:
        >>> layer1 = Layer(thickness_bounds=[1, 100], roughness_bounds=[0, 10], sld_bounds=[-2, 2])
        >>> backing = Backing(roughness_bounds=[0, 15], sld_bounds=[0, 3])
        >>> structure = Structure(layers=[layer1, backing], r_scale_bounds=[0.9, 1.1])
        >>> structure.prior_bounds
        [[1, 100], [0, 10], [0, 15], [-2, 2], [0, 3], [0.9, 1.1]]
    """

    def __init__(self, layers, q_shift_bounds=None, r_scale_bounds=None, log10_background_bounds=None):
        self.layers = layers
        self.q_shift_bounds = q_shift_bounds
        self.r_scale_bounds = r_scale_bounds
        self.log10_background_bounds = log10_background_bounds
        self.thicknesses_bounds = []
        self.roughnesses_bounds = []
        self.slds_bounds = []
        self.imag_slds_bounds = []

        for layer in layers:
            # Entries without a thickness (the backing) contribute no thickness parameter.
            if layer.thickness_bounds is not None:
                self.thicknesses_bounds.append(layer.thickness_bounds)
            self.roughnesses_bounds.append(layer.roughness_bounds)
            self.slds_bounds.append(layer.sld_bounds)
            if layer.imag_sld_bounds is not None:
                self.imag_slds_bounds.append(layer.imag_sld_bounds)

        # Concatenation builds a new list, so extending it below leaves the per-type lists untouched.
        self.prior_bounds = self.thicknesses_bounds + self.roughnesses_bounds + self.slds_bounds + self.imag_slds_bounds

        # Nuisance parameters are appended in a fixed order: q_shift, r_scale, log10_background.
        for bounds in (q_shift_bounds, r_scale_bounds, log10_background_bounds):
            if bounds is not None:
                self.prior_bounds.append(bounds)

    @staticmethod
    def _outside(bounds, allowed):
        """Return True if ``bounds`` extends below ``allowed[0]`` or above ``allowed[1]``."""
        return bounds[0] < allowed[0] or bounds[1] > allowed[1]

    @staticmethod
    def _envelope(bounds_list):
        """Return ``[lowest minimum, highest maximum]`` over a non-empty list of bounds."""
        return [min(b[0] for b in bounds_list), max(b[1] for b in bounds_list)]

    def validate_parameters_and_ranges(self, inference_model):
        """Validate that all layer bounds and nuisance parameters match the model's configuration.

        This method checks that:
            * The number of layers matches the model’s expected number.
            * Each layer’s thickness, roughness, and SLD bounds are within the
              model’s training ranges.
            * The SLD bound width does not exceed the maximum training width.
            * Any nuisance parameters expected by the model (e.g. q_shift, r_scale,
              log10_background) are provided and within training bounds.
            * No bounds are provided for a nuisance parameter the model does not support
              (this includes every nuisance parameter when the model's parametrization is
              not a ``NuisanceParamsWrapper``).

        Prints ``All checks passed.`` when nothing is wrong.

        Args:
            inference_model (InferenceModel): A loaded instance of :class:`InferenceModel` used to access the model’s metadata.

        Raises:
            ValueError: If the number of layers, parameter ranges, or nuisance configurations are inconsistent with the model.
        """
        prior_sampler = inference_model.trainer.loader.prior_sampler
        max_num_layers = prior_sampler.max_num_layers
        num_layers = len(self.layers) - 1  # the backing is not counted as a layer

        if num_layers != max_num_layers:
            raise ValueError(f'Number of layers mismatch: this model expects {max_num_layers} layers (backing not included) but you provided {num_layers}')

        thickness_ranges = prior_sampler.param_ranges['thicknesses']
        roughness_ranges = prior_sampler.param_ranges['roughnesses']
        sld_ranges = prior_sampler.param_ranges['slds']
        max_sld_bounds_width = prior_sampler.bound_width_ranges['slds'][1]

        def layer_name(i):
            if i == max_num_layers:
                return 'the backing medium'
            return f'layer {i+1}'

        for i, layer in enumerate(self.layers):
            if layer.thickness_bounds is not None and self._outside(layer.thickness_bounds, thickness_ranges):
                raise ValueError(f"The provided prior bounds for the thickness of layer {i+1} are outside the training range of the network: {thickness_ranges}")
            if self._outside(layer.roughness_bounds, roughness_ranges):
                raise ValueError(f"The provided prior bounds for the roughness of {layer_name(i)} are outside the training range of the network: {roughness_ranges}")
            if self._outside(layer.sld_bounds, sld_ranges):
                raise ValueError(f"The provided prior bounds for the (real) SLD of {layer_name(i)} are outside the training range of the network: {sld_ranges}")
            if layer.sld_bounds[1] - layer.sld_bounds[0] > max_sld_bounds_width:
                raise ValueError(f"The provided prior bounds for the (real) SLD of {layer_name(i)} have a width (max - min) exceeding the maximum width used for training: {max_sld_bounds_width}")

        # A model whose parametrization is not wrapped for nuisance parameters supports none of
        # them; an empty config makes any provided nuisance bounds fail with a clear ValueError.
        if isinstance(prior_sampler.param_model, NuisanceParamsWrapper):
            nuisance_params_config = prior_sampler.shift_param_config
        else:
            nuisance_params_config = {}

        provided_nuisance_bounds = (
            ('q_shift', self.q_shift_bounds),
            ('r_scale', self.r_scale_bounds),
            ('log10_background', self.log10_background_bounds),
        )

        # NOTE(review): support is decided by key membership only; if shift_param_config maps
        # names to booleans, a False entry would still count as supported — confirm.
        for name, bounds in provided_nuisance_bounds:
            if bounds is None:
                continue
            if name not in nuisance_params_config:
                raise ValueError(f'Prior bounds for the {name} parameter were provided but this parameter is not supported by this model.')
            training_range = prior_sampler.param_ranges[name]
            if self._outside(bounds, training_range):
                raise ValueError(f"The provided prior bounds for the {name} are outside the training range of the network: {training_range}")

        for name, bounds in provided_nuisance_bounds:
            if name in nuisance_params_config and bounds is None:
                raise ValueError(f'Prior bounds for the {name} parameter are expected by this model but were not provided.')

        print("All checks passed.")

    def get_huggingface_filtering_query(self):
        """Constructs a metadata query for selecting compatible pretrained models from Huggingface. Currently it only supports the older (research style)
        layout of Huggingface repositories (such as 'valentinsingularity/reflectivity'), but not the newer layout (such as `reflectorch-ILL`).

        For thicknesses, roughnesses and SLDs the query holds the envelope of the per-layer bounds
        (smallest minimum and largest maximum). Nuisance bounds are passed through unchanged.

        Returns:
            dict: A dictionary describing the structure and parameter bounds, suitable for filtering available model configurations
            in a Huggingface repository using :class:`HuggingfaceQueryMatcher`.

        Example:
            >>> structure = Structure([...])
            >>> query = structure.get_huggingface_filtering_query()
            >>> matcher = HuggingfaceQueryMatcher(repo_id='valentinsingularity/reflectivity')
            >>> configs = matcher.get_matching_configs(query)
        """
        query = {'dset.prior_sampler.kwargs.max_num_layers': len(self.layers) - 1}

        query['dset.prior_sampler.kwargs.param_ranges.thicknesses'] = self._envelope(self.thicknesses_bounds)
        query['dset.prior_sampler.kwargs.param_ranges.roughnesses'] = self._envelope(self.roughnesses_bounds)
        query['dset.prior_sampler.kwargs.param_ranges.slds'] = self._envelope(self.slds_bounds)

        # Any imaginary SLD bounds select the absorption-aware model family.
        if self.imag_slds_bounds:
            query['dset.prior_sampler.kwargs.model_name'] = 'model_with_absorption'
            query['dset.prior_sampler.kwargs.param_ranges.islds'] = self._envelope(self.imag_slds_bounds)
        else:
            query['dset.prior_sampler.kwargs.model_name'] = 'standard_model'

        nuisance_bounds = (
            ('q_shift', self.q_shift_bounds),
            ('r_scale', self.r_scale_bounds),
            ('log10_background', self.log10_background_bounds),
        )
        for name, bounds in nuisance_bounds:
            if bounds is not None:
                query[f'dset.prior_sampler.kwargs.shift_param_config.{name}'] = True
                query[f'dset.prior_sampler.kwargs.param_ranges.{name}'] = bounds

        return query
212
+
213
+
214
if __name__ == '__main__':
    # Demo: describe a two-layer structure, look up a compatible pretrained configuration on
    # Huggingface, load it and validate the structure against it.
    from reflectorch.inference.query_matcher import HuggingfaceQueryMatcher
    from reflectorch import EasyInferenceModel

    layer1 = Layer(thickness_bounds=[1, 1000], roughness_bounds=[0, 60], sld_bounds=[-2, 2])
    layer2 = Layer(thickness_bounds=[1, 50], roughness_bounds=[0, 10], sld_bounds=[1, 4])
    backing = Backing(roughness_bounds=[0, 15], sld_bounds=[0, 3])

    structure = Structure(
        layers=[layer1, layer2, backing],
        r_scale_bounds=[0.9, 1.1],
        log10_background_bounds=[-8, -5],
    )

    print(structure.prior_bounds)

    query_matcher = HuggingfaceQueryMatcher(repo_id='valentinsingularity/reflectivity')
    filtering_query = structure.get_huggingface_filtering_query()
    print(filtering_query)

    matching_configs = query_matcher.get_matching_configs(filtering_query)
    print(f'Matching configs: {matching_configs}')

    # Without this guard an empty result surfaces as a bare IndexError on the next line.
    if len(matching_configs) == 0:
        raise SystemExit('No pretrained configuration matches the requested structure.')

    inference_model = EasyInferenceModel(config_name=matching_configs[0])
    structure.validate_parameters_and_ranges(inference_model)
@@ -0,0 +1,37 @@
1
+ from pathlib import Path
2
+ import numpy as np
3
+
4
def load_mft_data(filepath):
    """
    Load q, reflectivity, reflectivity error, and q-resolution from an .mft file.

    The column header is the first line that starts with ``q`` (ignoring surrounding
    whitespace) and contains ``q_res``. Every later line that splits into exactly four
    whitespace-separated numeric fields is read as one data row; all other lines are skipped.

    Parameters:
        filepath (str or Path): Path to the .mft file

    Returns:
        q, refl, refl_err, q_res : np.ndarray
            Returned as a single array of shape (4, n_rows), which unpacks into the four columns.

    Raises:
        ValueError: If the column header line is missing, or if no valid data row follows it.
    """
    filepath = Path(filepath)

    with filepath.open('r', encoding='utf-8') as f:
        lines = f.readlines()

    # A default of None keeps a missing header from leaking StopIteration to the caller.
    header_idx = next(
        (i for i, line in enumerate(lines)
         if line.strip().startswith('q') and 'q_res' in line),
        None,
    )
    if header_idx is None:
        raise ValueError(f"No column header line (starting with 'q' and containing 'q_res') found in {filepath}")

    data = []
    for line in lines[header_idx + 1:]:
        parts = line.split()
        if len(parts) != 4:
            continue
        try:
            # float() accepts both 'e' and 'E' exponent markers.
            data.append([float(p) for p in parts])
        except ValueError:
            # Four-field lines that are not numeric are deliberately ignored.
            continue

    if not data:
        raise ValueError(f"No valid data found in {filepath}")

    return np.array(data).T
@@ -1,171 +1,171 @@
1
- from tqdm import trange
2
-
3
- from time import perf_counter
4
-
5
- import torch
6
- from torch import nn, Tensor
7
-
8
- from reflectorch.data_generation.reflectivity import kinematical_approximation
9
- from reflectorch.data_generation.priors.multilayer_structures import SimpleMultilayerSampler
10
-
11
-
12
- class MultilayerFit(object):
13
- def __init__(self,
14
- q: Tensor,
15
- exp_curve: Tensor,
16
- params: Tensor,
17
- convert_func,
18
- scale_curve_func,
19
- min_bounds: Tensor,
20
- max_bounds: Tensor,
21
- optim_cls=None,
22
- lr: float = 5e-2,
23
- ):
24
-
25
- self.q = q
26
- self.scale_curve_func = scale_curve_func
27
- self.scaled_exp_curve = self.scale_curve_func(torch.atleast_2d(exp_curve))
28
- self.params_to_fit = nn.Parameter(params.clone())
29
- self.convert_func = convert_func
30
-
31
- self.min_bounds = min_bounds
32
- self.max_bounds = max_bounds
33
-
34
- optim_cls = optim_cls or torch.optim.Adam
35
- self.optim = optim_cls([self.params_to_fit], lr)
36
-
37
- self.best_loss = float('inf')
38
- self.best_params, self.best_curve = None, None
39
-
40
- self.get_best_solution()
41
- self.losses = []
42
-
43
- @classmethod
44
- def from_prior_sampler(
45
- cls,
46
- q: Tensor,
47
- exp_curve: Tensor,
48
- prior_sampler: SimpleMultilayerSampler,
49
- batch_size: int = 2 ** 13,
50
- **kwargs
51
- ):
52
- _, scaled_params = prior_sampler.optimized_sample(batch_size)
53
- params = prior_sampler.restore_params2parametrized(scaled_params)
54
-
55
- return cls(
56
- q.float(), exp_curve.float(), params.float(),
57
- convert_func=get_convert_func(prior_sampler.multilayer_model),
58
- scale_curve_func=_save_log, min_bounds=prior_sampler.min_bounds,
59
- max_bounds=prior_sampler.max_bounds,
60
- **kwargs
61
- )
62
-
63
- @classmethod
64
- def from_prediction(
65
- cls,
66
- param: Tensor,
67
- prior_sampler,
68
- q: Tensor,
69
- exp_curve: Tensor,
70
- batch_size: int = 2 ** 13,
71
- rel_bounds: float = 0.3,
72
- **kwargs
73
- ):
74
-
75
- num_params = param.shape[-1]
76
-
77
- deltas = prior_sampler.restore_params2parametrized(
78
- (torch.rand(batch_size, num_params, device=param.device) * 2 - 1) * rel_bounds
79
- ) - prior_sampler.min_bounds
80
-
81
- deltas[0] = 0.
82
-
83
- params = torch.atleast_2d(param) + deltas
84
-
85
- return cls(
86
- q.float(), exp_curve.float(), params.float(),
87
- convert_func=get_convert_func(prior_sampler.multilayer_model),
88
- scale_curve_func=_save_log, min_bounds=prior_sampler.min_bounds,
89
- max_bounds=prior_sampler.max_bounds,
90
- **kwargs
91
- )
92
-
93
- def get_clipped_params(self):
94
- return torch.clamp(self.params_to_fit.detach().clone(), self.min_bounds, self.max_bounds)
95
-
96
- def calc_loss(self, reduce=True, clip: bool = False):
97
- if clip:
98
- params = self.get_clipped_params()
99
- else:
100
- params = self.params_to_fit
101
- curves = self.get_curves(params=params)
102
- losses = ((self.scale_curve_func(curves) - self.scaled_exp_curve) ** 2)
103
-
104
- if reduce:
105
- return losses.sum()
106
- else:
107
- return losses.sum(-1)
108
-
109
- def get_curves(self, params: Tensor = None):
110
- if params is None:
111
- params = self.params_to_fit
112
- curves = kinematical_approximation(self.q, **self.convert_func(params))
113
- return torch.atleast_2d(curves)
114
-
115
- def run(self, num_iterations: int = 500, disable_tqdm: bool = False):
116
- pbar = trange(num_iterations, disable=disable_tqdm)
117
-
118
- for _ in pbar:
119
- self.optim.zero_grad()
120
- loss = self.calc_loss()
121
- loss.backward()
122
- self.optim.step()
123
- self.losses.append(loss.item())
124
- pbar.set_description(f'Loss = {loss.item():.2e}')
125
-
126
- def clear(self):
127
- self.losses.clear()
128
-
129
- def run_fixed_time(self, time_limit: float = 2.):
130
-
131
- start = perf_counter()
132
-
133
- while True:
134
- self.optim.zero_grad()
135
- loss = self.calc_loss()
136
- loss.backward()
137
- self.optim.step()
138
- self.losses.append(loss.item())
139
- time_spent = perf_counter() - start
140
-
141
- if time_spent > time_limit:
142
- break
143
-
144
- @torch.no_grad()
145
- def get_best_solution(self, clip: bool = True):
146
- losses = self.calc_loss(clip=clip, reduce=False)
147
- idx = torch.argmin(losses)
148
- best_loss = losses[idx].item()
149
-
150
- if best_loss < self.best_loss:
151
- best_curve = self.get_curves()[idx]
152
- best_params = self.params_to_fit.detach()[idx].clone()
153
- self.best_params, self.best_curve, self.best_loss = best_params, best_curve, best_loss
154
-
155
- return self.best_params
156
-
157
-
158
- def get_convert_func(multilayer_model):
159
- def func(params):
160
- params = multilayer_model.to_standard_params(params)
161
- return {
162
- 'thickness': params['thicknesses'], # very useful transformation indeed ...
163
- 'roughness': params['roughnesses'],
164
- 'sld': params['slds'],
165
- }
166
-
167
- return func
168
-
169
-
170
- def _save_log(curves, eps: float = 1e-10):
171
- return torch.log10(curves + eps)
1
+ from tqdm import trange
2
+
3
+ from time import perf_counter
4
+
5
+ import torch
6
+ from torch import nn, Tensor
7
+
8
+ from reflectorch.data_generation.reflectivity import kinematical_approximation
9
+ from reflectorch.data_generation.priors.multilayer_structures import SimpleMultilayerSampler
10
+
11
+
12
class MultilayerFit(object):
    """Gradient-based refinement of a batch of candidate parameter sets against one curve.

    Curves are simulated with ``kinematical_approximation`` from the candidate parameters
    (mapped through ``convert_func``). The loss is the summed squared difference between the
    scaled simulated curves and the scaled experimental curve.
    """

    def __init__(self,
                 q: Tensor,
                 exp_curve: Tensor,
                 params: Tensor,
                 convert_func,
                 scale_curve_func,
                 min_bounds: Tensor,
                 max_bounds: Tensor,
                 optim_cls=None,
                 lr: float = 5e-2,
                 ):
        """Store the fit inputs, build the optimizer and record the best starting candidate.

        Args:
            q: Momentum transfer values passed to the curve simulation.
            exp_curve: Experimental curve; promoted to at least 2D and scaled once here.
            params: Initial candidate parameters; a clone of them is optimized.
            convert_func: Maps the parameter tensor to the keyword arguments of the simulation.
            scale_curve_func: Scaling applied to both simulated and experimental curves.
            min_bounds: Lower bounds used when clipping parameters.
            max_bounds: Upper bounds used when clipping parameters.
            optim_cls: Optimizer class; ``torch.optim.Adam`` is used when falsy.
            lr: Learning rate handed to the optimizer.
        """
        self.q = q
        self.scale_curve_func = scale_curve_func
        target_curve = torch.atleast_2d(exp_curve)
        self.scaled_exp_curve = self.scale_curve_func(target_curve)
        self.params_to_fit = nn.Parameter(params.clone())
        self.convert_func = convert_func

        self.min_bounds, self.max_bounds = min_bounds, max_bounds

        optimizer_factory = optim_cls or torch.optim.Adam
        self.optim = optimizer_factory([self.params_to_fit], lr)

        self.best_loss = float('inf')
        self.best_params = None
        self.best_curve = None

        # Score the initial candidates so best_* is populated before any optimization step.
        self.get_best_solution()
        self.losses = []

    @classmethod
    def from_prior_sampler(
            cls,
            q: Tensor,
            exp_curve: Tensor,
            prior_sampler: SimpleMultilayerSampler,
            batch_size: int = 2 ** 13,
            **kwargs
    ):
        """Build a fit whose starting candidates are ``batch_size`` samples from ``prior_sampler``."""
        _, sampled_scaled = prior_sampler.optimized_sample(batch_size)
        start_params = prior_sampler.restore_params2parametrized(sampled_scaled)

        return cls(
            q.float(),
            exp_curve.float(),
            start_params.float(),
            convert_func=get_convert_func(prior_sampler.multilayer_model),
            scale_curve_func=_save_log,
            min_bounds=prior_sampler.min_bounds,
            max_bounds=prior_sampler.max_bounds,
            **kwargs
        )

    @classmethod
    def from_prediction(
            cls,
            param: Tensor,
            prior_sampler,
            q: Tensor,
            exp_curve: Tensor,
            batch_size: int = 2 ** 13,
            rel_bounds: float = 0.3,
            **kwargs
    ):
        """Build a fit whose starting candidates are random perturbations of ``param``.

        The first candidate is left unperturbed, so it equals ``param`` itself.
        """
        n_params = param.shape[-1]

        # Uniform noise in [-rel_bounds, rel_bounds), drawn in the sampler's scaled space.
        scaled_noise = (torch.rand(batch_size, n_params, device=param.device) * 2 - 1) * rel_bounds
        deltas = prior_sampler.restore_params2parametrized(scaled_noise) - prior_sampler.min_bounds

        deltas[0] = 0.

        start_params = torch.atleast_2d(param) + deltas

        return cls(
            q.float(),
            exp_curve.float(),
            start_params.float(),
            convert_func=get_convert_func(prior_sampler.multilayer_model),
            scale_curve_func=_save_log,
            min_bounds=prior_sampler.min_bounds,
            max_bounds=prior_sampler.max_bounds,
            **kwargs
        )

    def get_clipped_params(self):
        """Return a detached copy of the current parameters clamped to the stored bounds."""
        snapshot = self.params_to_fit.detach().clone()
        return torch.clamp(snapshot, self.min_bounds, self.max_bounds)

    def calc_loss(self, reduce=True, clip: bool = False):
        """Return the squared-error loss of the current candidates.

        With ``reduce`` the loss is summed over everything; otherwise it is summed over the
        last dimension only. With ``clip`` the clamped copy of the parameters is evaluated.
        """
        params = self.get_clipped_params() if clip else self.params_to_fit
        simulated = self.get_curves(params=params)
        squared_error = (self.scale_curve_func(simulated) - self.scaled_exp_curve) ** 2

        return squared_error.sum() if reduce else squared_error.sum(-1)

    def get_curves(self, params: Tensor = None):
        """Simulate curves for ``params`` (the fitted parameters when None), at least 2D."""
        source = self.params_to_fit if params is None else params
        simulated = kinematical_approximation(self.q, **self.convert_func(source))
        return torch.atleast_2d(simulated)

    def _optimization_step(self):
        """Run one optimizer update, log the loss in ``self.losses`` and return it as a float."""
        self.optim.zero_grad()
        loss = self.calc_loss()
        loss.backward()
        self.optim.step()
        loss_value = loss.item()
        self.losses.append(loss_value)
        return loss_value

    def run(self, num_iterations: int = 500, disable_tqdm: bool = False):
        """Perform ``num_iterations`` optimizer steps, showing the loss in a progress bar."""
        progress = trange(num_iterations, disable=disable_tqdm)

        for _ in progress:
            loss_value = self._optimization_step()
            progress.set_description(f'Loss = {loss_value:.2e}')

    def clear(self):
        """Empty the recorded loss history."""
        self.losses.clear()

    def run_fixed_time(self, time_limit: float = 2.):
        """Keep stepping until more than ``time_limit`` seconds have elapsed (at least one step)."""
        started = perf_counter()

        while True:
            self._optimization_step()
            if perf_counter() - started > time_limit:
                break

    @torch.no_grad()
    def get_best_solution(self, clip: bool = True):
        """Update the stored best candidate if the current batch improves on it; return it.

        Returns:
            The best parameters recorded so far (``self.best_params``).
        """
        per_candidate = self.calc_loss(clip=clip, reduce=False)
        winner = torch.argmin(per_candidate)
        candidate_loss = per_candidate[winner].item()

        if candidate_loss < self.best_loss:
            # NOTE(review): with clip=True candidates are ranked on clamped parameters, while
            # the stored curve and parameters come from the unclamped tensor — confirm intended.
            self.best_curve = self.get_curves()[winner]
            self.best_params = self.params_to_fit.detach()[winner].clone()
            self.best_loss = candidate_loss

        return self.best_params
156
+
157
+
158
def get_convert_func(multilayer_model):
    """Return a function that turns a parameter tensor into reflectivity keyword arguments.

    The returned function calls ``multilayer_model.to_standard_params`` and renames the
    plural keys ``thicknesses``/``roughnesses``/``slds`` to the singular
    ``thickness``/``roughness``/``sld``; any other keys are dropped.
    """
    # (output keyword, key produced by to_standard_params)
    renames = (
        ('thickness', 'thicknesses'),
        ('roughness', 'roughnesses'),
        ('sld', 'slds'),
    )

    def func(params):
        standard = multilayer_model.to_standard_params(params)
        return {target: standard[source] for target, source in renames}

    return func
168
+
169
+
170
def _save_log(curves, eps: float = 1e-10):
    """Return the base-10 logarithm of ``curves`` after adding the offset ``eps``."""
    shifted = curves + eps
    return torch.log10(shifted)