reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic; more details are linked from the original registry diff page.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,193 +1,193 @@
1
- from typing import Tuple
2
-
3
- import torch
4
- from torch import Tensor
5
- import numpy as np
6
-
7
- from reflectorch.inference.inference_model import (
8
- InferenceModel,
9
- )
10
- from reflectorch.data_generation.reflectivity import kinematical_approximation_np, abeles_np
11
-
12
- from reflectorch.data_generation.priors import (
13
- MultilayerStructureParams,
14
- SimpleMultilayerSampler,
15
- )
16
- from reflectorch.inference.record_time import print_time
17
- from reflectorch.inference.scipy_fitter import standard_refl_fit
18
- from reflectorch.inference.multilayer_fitter import MultilayerFit
19
-
20
-
21
- class MultilayerInferenceModel(InferenceModel):
22
- def predict(self,
23
- intensity: np.ndarray,
24
- scattering_angle: np.ndarray,
25
- attenuation: np.ndarray,
26
- priors: np.ndarray = None,
27
- preprocessing_parameters: dict = None,
28
- polish: bool = True,
29
- use_raw_q: bool = False,
30
- **kwargs
31
- ) -> dict:
32
-
33
- with print_time("everything"):
34
- with print_time("preprocess"):
35
- preprocessed_dict = self.preprocess(
36
- intensity, scattering_angle, attenuation, **(preprocessing_parameters or {})
37
- )
38
-
39
- preprocessed_curve = preprocessed_dict["curve_interp"]
40
-
41
- raw_curve, raw_q = preprocessed_dict["curve"], preprocessed_dict["q_values"]
42
-
43
- with print_time("predict_from_preprocessed_curve"):
44
- preprocessed_dict.update(self.predict_from_preprocessed_curve(
45
- preprocessed_curve, priors,
46
- raw_curve=(raw_curve if use_raw_q else None),
47
- raw_q=raw_q,
48
- polish=polish,
49
- use_raw_q=use_raw_q,
50
- **kwargs
51
- ))
52
-
53
- return preprocessed_dict
54
-
55
- def predict_from_preprocessed_curve(self,
56
- curve: np.ndarray,
57
- priors: np.ndarray = None, *, # ignore the priors so far
58
- polish: bool = True,
59
- raw_curve: np.ndarray = None,
60
- raw_q: np.ndarray = None,
61
- clip_prediction: bool = True,
62
- use_raw_q: bool = False,
63
- use_sampler: bool = False,
64
- fitted_time_limit: float = 3.,
65
- sampler_rel_bounds: float = 0.3,
66
- polish_with_abeles: bool = False,
67
- **kwargs
68
- ) -> dict:
69
-
70
- scaled_curve = self._scale_curve(curve)
71
-
72
- predicted_params, parametrized = self._simple_prediction(scaled_curve)
73
-
74
- if use_sampler:
75
- parametrized: Tensor = self._sampler_solution(
76
- curve, parametrized,
77
- time_limit=fitted_time_limit,
78
- rel_bounds=sampler_rel_bounds,
79
- )
80
-
81
- init_raw_q = raw_q
82
-
83
- if raw_curve is None:
84
- raw_curve = curve
85
- raw_q = self.q.squeeze().cpu().numpy()
86
- raw_q_t = self.q
87
- else:
88
- raw_q_t = torch.from_numpy(raw_q).to(self.q)
89
-
90
- # if q_ratio != 1.:
91
- # predicted_params.scale_with_q(q_ratio)
92
- # raw_q = raw_q * q_ratio
93
- # raw_q_t = raw_q_t * q_ratio
94
-
95
- prediction_dict = {
96
- "params": parametrized.squeeze().cpu().numpy(),
97
- "param_names": list(self._prior_sampler.multilayer_model.PARAMETER_NAMES),
98
- "curve_predicted": predicted_params.reflectivity(raw_q_t).squeeze().cpu().numpy()
99
- }
100
-
101
- # sld_x_axis, sld_profile, _ = get_density_profiles(
102
- # predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, num=1024,
103
- # )
104
- #
105
- # prediction_dict['sld_profile'] = sld_profile.squeeze().cpu().numpy()
106
- # prediction_dict['sld_x_axis'] = sld_x_axis.squeeze().cpu().numpy()
107
-
108
- if polish:
109
- prediction_dict.update(self._polish_prediction(
110
- raw_q, raw_curve, parametrized, q_values=init_raw_q, use_kinematical=True
111
- ))
112
- if polish_with_abeles:
113
- prediction_dict.update(self._polish_prediction(
114
- raw_q, raw_curve, parametrized, q_values=init_raw_q, use_kinematical=False
115
- ))
116
-
117
- return prediction_dict
118
-
119
- def _simple_prediction(self, scaled_curve) -> Tuple[MultilayerStructureParams, Tensor]:
120
- with torch.no_grad():
121
- self.trainer.model.eval()
122
- scaled_params = self.trainer.model(scaled_curve)
123
-
124
- predicted_params, parametrized = self._restore_predicted_params(scaled_params)
125
- return predicted_params, parametrized
126
-
127
- def _restore_predicted_params(self, scaled_params: Tensor) -> Tuple[MultilayerStructureParams, Tensor]:
128
- parametrized = self._prior_sampler.restore_params2parametrized(scaled_params)
129
- predicted_params: MultilayerStructureParams = self._prior_sampler.restore_params(scaled_params)
130
- return predicted_params, parametrized
131
-
132
- @print_time
133
- def _sampler_solution(
134
- self,
135
- curve: Tensor or np.ndarray,
136
- predicted_params: Tensor,
137
- batch_size: int = 2 ** 13,
138
- time_limit: float = 3.,
139
- rel_bounds: float = 0.3,
140
- ) -> Tensor:
141
-
142
- fit_obj = MultilayerFit.from_prediction(
143
- predicted_params, self._prior_sampler, self.q, torch.as_tensor(curve).to(self.q),
144
- batch_size=batch_size, rel_bounds=rel_bounds,
145
- )
146
-
147
- fit_obj.run_fixed_time(time_limit)
148
-
149
- best_params = fit_obj.get_best_solution()
150
-
151
- return best_params
152
-
153
- @property
154
- def _prior_sampler(self) -> SimpleMultilayerSampler:
155
- return self.trainer.loader.prior_sampler
156
-
157
- @print_time
158
- def _polish_prediction(self,
159
- q: np.ndarray,
160
- curve: np.ndarray,
161
- predicted_params: Tensor,
162
- q_values: np.ndarray,
163
- use_kinematical: bool = True,
164
- ) -> dict:
165
-
166
- params = predicted_params.squeeze().cpu().numpy()
167
- polished_params_dict = {}
168
-
169
- if use_kinematical:
170
- refl_generator = kinematical_approximation_np
171
- else:
172
- refl_generator = abeles_np
173
-
174
- try:
175
- polished_params_arr, curve_polished = standard_refl_fit(
176
- q, curve, params, restore_params_func=self._prior_sampler.restore_np_params,
177
- refl_generator=refl_generator,
178
- bounds=self._prior_sampler.get_np_bounds(),
179
- )
180
- params = self._prior_sampler.restore_np_params(polished_params_arr)
181
- if q_values is None:
182
- q_values = q
183
- curve_polished = abeles_np(q_values, **params)
184
-
185
- except Exception as err:
186
- self.log.exception(err)
187
- polished_params_arr = params
188
- curve_polished = np.zeros_like(q)
189
-
190
- polished_params_dict['params_polished'] = polished_params_arr
191
- polished_params_dict['curve_polished'] = curve_polished
192
-
193
- return polished_params_dict
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ import numpy as np
6
+
7
+ from reflectorch.inference.inference_model import (
8
+ InferenceModel,
9
+ )
10
+ from reflectorch.data_generation.reflectivity import kinematical_approximation_np, abeles_np
11
+
12
+ from reflectorch.data_generation.priors import (
13
+ MultilayerStructureParams,
14
+ SimpleMultilayerSampler,
15
+ )
16
+ from reflectorch.inference.record_time import print_time
17
+ from reflectorch.inference.scipy_fitter import standard_refl_fit
18
+ from reflectorch.inference.multilayer_fitter import MultilayerFit
19
+
20
+
21
+ class MultilayerInferenceModel(InferenceModel):
22
+ def predict(self,
23
+ intensity: np.ndarray,
24
+ scattering_angle: np.ndarray,
25
+ attenuation: np.ndarray,
26
+ priors: np.ndarray = None,
27
+ preprocessing_parameters: dict = None,
28
+ polish: bool = True,
29
+ use_raw_q: bool = False,
30
+ **kwargs
31
+ ) -> dict:
32
+
33
+ with print_time("everything"):
34
+ with print_time("preprocess"):
35
+ preprocessed_dict = self.preprocess(
36
+ intensity, scattering_angle, attenuation, **(preprocessing_parameters or {})
37
+ )
38
+
39
+ preprocessed_curve = preprocessed_dict["curve_interp"]
40
+
41
+ raw_curve, raw_q = preprocessed_dict["curve"], preprocessed_dict["q_values"]
42
+
43
+ with print_time("predict_from_preprocessed_curve"):
44
+ preprocessed_dict.update(self.predict_from_preprocessed_curve(
45
+ preprocessed_curve, priors,
46
+ raw_curve=(raw_curve if use_raw_q else None),
47
+ raw_q=raw_q,
48
+ polish=polish,
49
+ use_raw_q=use_raw_q,
50
+ **kwargs
51
+ ))
52
+
53
+ return preprocessed_dict
54
+
55
+ def predict_from_preprocessed_curve(self,
56
+ curve: np.ndarray,
57
+ priors: np.ndarray = None, *, # ignore the priors so far
58
+ polish: bool = True,
59
+ raw_curve: np.ndarray = None,
60
+ raw_q: np.ndarray = None,
61
+ clip_prediction: bool = True,
62
+ use_raw_q: bool = False,
63
+ use_sampler: bool = False,
64
+ fitted_time_limit: float = 3.,
65
+ sampler_rel_bounds: float = 0.3,
66
+ polish_with_abeles: bool = False,
67
+ **kwargs
68
+ ) -> dict:
69
+
70
+ scaled_curve = self._scale_curve(curve)
71
+
72
+ predicted_params, parametrized = self._simple_prediction(scaled_curve)
73
+
74
+ if use_sampler:
75
+ parametrized: Tensor = self._sampler_solution(
76
+ curve, parametrized,
77
+ time_limit=fitted_time_limit,
78
+ rel_bounds=sampler_rel_bounds,
79
+ )
80
+
81
+ init_raw_q = raw_q
82
+
83
+ if raw_curve is None:
84
+ raw_curve = curve
85
+ raw_q = self.q.squeeze().cpu().numpy()
86
+ raw_q_t = self.q
87
+ else:
88
+ raw_q_t = torch.from_numpy(raw_q).to(self.q)
89
+
90
+ # if q_ratio != 1.:
91
+ # predicted_params.scale_with_q(q_ratio)
92
+ # raw_q = raw_q * q_ratio
93
+ # raw_q_t = raw_q_t * q_ratio
94
+
95
+ prediction_dict = {
96
+ "params": parametrized.squeeze().cpu().numpy(),
97
+ "param_names": list(self._prior_sampler.multilayer_model.PARAMETER_NAMES),
98
+ "curve_predicted": predicted_params.reflectivity(raw_q_t).squeeze().cpu().numpy()
99
+ }
100
+
101
+ # sld_x_axis, sld_profile, _ = get_density_profiles(
102
+ # predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, num=1024,
103
+ # )
104
+ #
105
+ # prediction_dict['sld_profile'] = sld_profile.squeeze().cpu().numpy()
106
+ # prediction_dict['sld_x_axis'] = sld_x_axis.squeeze().cpu().numpy()
107
+
108
+ if polish:
109
+ prediction_dict.update(self._polish_prediction(
110
+ raw_q, raw_curve, parametrized, q_values=init_raw_q, use_kinematical=True
111
+ ))
112
+ if polish_with_abeles:
113
+ prediction_dict.update(self._polish_prediction(
114
+ raw_q, raw_curve, parametrized, q_values=init_raw_q, use_kinematical=False
115
+ ))
116
+
117
+ return prediction_dict
118
+
119
+ def _simple_prediction(self, scaled_curve) -> Tuple[MultilayerStructureParams, Tensor]:
120
+ with torch.no_grad():
121
+ self.trainer.model.eval()
122
+ scaled_params = self.trainer.model(scaled_curve)
123
+
124
+ predicted_params, parametrized = self._restore_predicted_params(scaled_params)
125
+ return predicted_params, parametrized
126
+
127
+ def _restore_predicted_params(self, scaled_params: Tensor) -> Tuple[MultilayerStructureParams, Tensor]:
128
+ parametrized = self._prior_sampler.restore_params2parametrized(scaled_params)
129
+ predicted_params: MultilayerStructureParams = self._prior_sampler.restore_params(scaled_params)
130
+ return predicted_params, parametrized
131
+
132
+ @print_time
133
+ def _sampler_solution(
134
+ self,
135
+ curve: Tensor or np.ndarray,
136
+ predicted_params: Tensor,
137
+ batch_size: int = 2 ** 13,
138
+ time_limit: float = 3.,
139
+ rel_bounds: float = 0.3,
140
+ ) -> Tensor:
141
+
142
+ fit_obj = MultilayerFit.from_prediction(
143
+ predicted_params, self._prior_sampler, self.q, torch.as_tensor(curve).to(self.q),
144
+ batch_size=batch_size, rel_bounds=rel_bounds,
145
+ )
146
+
147
+ fit_obj.run_fixed_time(time_limit)
148
+
149
+ best_params = fit_obj.get_best_solution()
150
+
151
+ return best_params
152
+
153
+ @property
154
+ def _prior_sampler(self) -> SimpleMultilayerSampler:
155
+ return self.trainer.loader.prior_sampler
156
+
157
+ @print_time
158
+ def _polish_prediction(self,
159
+ q: np.ndarray,
160
+ curve: np.ndarray,
161
+ predicted_params: Tensor,
162
+ q_values: np.ndarray,
163
+ use_kinematical: bool = True,
164
+ ) -> dict:
165
+
166
+ params = predicted_params.squeeze().cpu().numpy()
167
+ polished_params_dict = {}
168
+
169
+ if use_kinematical:
170
+ refl_generator = kinematical_approximation_np
171
+ else:
172
+ refl_generator = abeles_np
173
+
174
+ try:
175
+ polished_params_arr, curve_polished = standard_refl_fit(
176
+ q, curve, params, restore_params_func=self._prior_sampler.restore_np_params,
177
+ refl_generator=refl_generator,
178
+ bounds=self._prior_sampler.get_np_bounds(),
179
+ )
180
+ params = self._prior_sampler.restore_np_params(polished_params_arr)
181
+ if q_values is None:
182
+ q_values = q
183
+ curve_polished = abeles_np(q_values, **params)
184
+
185
+ except Exception as err:
186
+ self.log.exception(err)
187
+ polished_params_arr = params
188
+ curve_polished = np.zeros_like(q)
189
+
190
+ polished_params_dict['params_polished'] = polished_params_arr
191
+ polished_params_dict['curve_polished'] = curve_polished
192
+
193
+ return polished_params_dict