reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,195 +1,195 @@
- from typing import Union, Tuple
- from math import log
-
- import torch
- from torch import Tensor
-
- from reflectorch.data_generation.priors.base import PriorSampler
- from reflectorch.data_generation.priors.params import Params
- from reflectorch.data_generation.utils import get_param_labels
-
- __all__ = [
-     'SingleParamPrior',
-     'UniformParamPrior',
-     'GaussianParamPrior',
-     'TruncatedGaussianParamPrior',
-     'SimplePriorSampler',
-
- ]
-
-
- class SingleParamPrior(object):
-     def sample(self, batch_num: int, device=None, dtype=None):
-         raise NotImplementedError
-
-     def log_prob(self, params: Tensor):
-         raise NotImplementedError
-
-     def to_conf(self):
-         vars_dict = {k: v for k, v in vars(self).items() if not k.startswith('_')}
-         return {
-             'cls': self.__class__.__name__,
-             'kwargs': vars_dict
-         }
-
-     def scale(self, params: Tensor) -> Tensor:
-         raise NotImplementedError
-
-     def restore(self, params: Tensor) -> Tensor:
-         raise NotImplementedError
-
-
- class SimplePriorSampler(PriorSampler):
-     def __init__(self,
-                  *params: Union[Tuple[float, float], SingleParamPrior],
-                  device=None, dtype=None, param_cls=None, scaling_prior: PriorSampler = None,
-                  ):
-         if param_cls is not None:
-             self.PARAM_CLS = param_cls
-         elif scaling_prior is not None:
-             self.PARAM_CLS = scaling_prior.PARAM_CLS
-         self.param_priors: Tuple[SingleParamPrior, ...] = self._init_params(params)
-         self._num_layers = self.PARAM_CLS.size2layers_num(len(params))
-         self.device = device
-         self.dtype = dtype
-         self.scaling_prior = scaling_prior
-
-     @property
-     def max_num_layers(self) -> int:
-         return self._num_layers
-
-     def _init_params(self, params: Tuple[Union[Tuple[float, float], SingleParamPrior], ...]):
-         assert len(params) == self.PARAM_CLS.layers_num2size(self.PARAM_CLS.size2layers_num(len(params)))
-         params = tuple(
-             param if isinstance(param, SingleParamPrior) else UniformParamPrior(*param)
-             for param in params
-         )
-         return params
-
-     def sample(self, batch_size: int) -> Params:
-         t_params = torch.stack([
-             param.sample(batch_size, device=self.device, dtype=self.dtype) for param in self.param_priors
-         ], -1)
-         params = self.PARAM_CLS.from_tensor(t_params)
-         return params
-
-     def scale_params(self, params: Params) -> Tensor:
-         if self.scaling_prior:
-             return self.scaling_prior.scale_params(params)
-
-         t_params = params.as_tensor()
-
-         scaled_params = torch.stack(
-             [param_prior.scale(param) for param, param_prior in zip(t_params.T, self.param_priors)], -1
-         )
-
-         return scaled_params
-
-     def restore_params(self, scaled_params: Tensor) -> Params:
-         if self.scaling_prior:
-             return self.scaling_prior.restore_params(scaled_params)
-         t_params = torch.stack(
-             [param_prior.restore(param) for param, param_prior in zip(scaled_params.T, self.param_priors)], -1
-         )
-         return self.PARAM_CLS.from_tensor(t_params)
-
-     def log_prob(self, params: Params) -> Tensor:
-         t_params = params.as_tensor()
-         log_probs = torch.stack(
-             [param_prior.log_prob(param) for param, param_prior in zip(t_params.T, self.param_priors)], -1
-         )
-         return log_probs.sum(1)
-
-     def get_indices_within_domain(self, params: Params) -> Tensor:
-         log_probs = self.log_prob(params)
-         return torch.isfinite(log_probs)
-
-     def get_indices_within_bounds(self, params: Params) -> Tensor:
-         return self.get_indices_within_domain(params)
-
-     def __repr__(self):
-         layers_num = self.PARAM_CLS.size2layers_num(len(self.param_priors))
-         labels = get_param_labels(layers_num)
-         prior_str = '\n\t'.join(f'{label}: {param_prior}' for label, param_prior in zip(labels, self.param_priors))
-         return f'SimplePriorSampler(\n\t{prior_str}\n)'
-
-
- class UniformParamPrior(SingleParamPrior):
-     def __init__(self, min_value: float, max_value: float, device=None, dtype=None):
-         assert min_value < max_value
-         self.min_value, self.max_value, self.delta = min_value, max_value, max_value - min_value
-         self._lob_prob_const = - log(self.delta)
-         self.device = device
-         self.dtype = dtype
-
-     def sample(self, batch_num: int, device=None, dtype=None):
-         params = torch.rand(
-             batch_num, device=(device or self.device), dtype=(dtype or self.dtype)
-         ) * self.delta + self.min_value
-         return params
-
-     def log_prob(self, params: Tensor):
-         log_probs = torch.fill_(torch.ones_like(params), self._lob_prob_const)
-         log_probs[(params < self.min_value) | (params > self.max_value)] = - float('inf')
-         return log_probs
-
-     def scale(self, params: Tensor) -> Tensor:
-         return (params - self.min_value) / self.delta
-
-     def restore(self, params: Tensor) -> Tensor:
-         return params * self.delta + self.min_value
-
-     def __repr__(self):
-         return f'UniformParamPrior(min={self.min_value}, max={self.max_value})'
-
-
- class GaussianParamPrior(SingleParamPrior):
-     _GAUSS_SCALE_CONST: float = 4.
-
-     def __init__(self, mean: float, std: float, device=None, dtype=None):
-         assert std > 0
-         self.mean = mean
-         self.std = std
-         self.device = device
-         self.dtype = dtype
-
-     def sample(self, batch_num: int, device=None, dtype=None):
-         params = torch.normal(
-             self.mean, self.std, (batch_num,), device=(device or self.device), dtype=(dtype or self.dtype)
-         )
-         return params
-
-     def log_prob(self, params: Tensor):
-         # ignore constant
-         log_probs = - ((params - self.mean) / self.std) ** 2 / 2
-         return log_probs
-
-     def scale(self, params: Tensor) -> Tensor:
-         return (params - self.mean) / (self.std * self._GAUSS_SCALE_CONST)
-
-     def restore(self, params: Tensor) -> Tensor:
-         return params * self.std * self._GAUSS_SCALE_CONST + self.mean
-
-     def __repr__(self):
-         return f'GaussianParamPrior(mean={self.mean}, std={self.std})'
-
-
- class TruncatedGaussianParamPrior(GaussianParamPrior):
-     def sample(self, batch_num: int, device=None, dtype=None) -> Tensor:
-         params = torch.normal(
-             self.mean, self.std, (batch_num,), device=(device or self.device), dtype=(dtype or self.dtype)
-         )
-         negative_params = params < 0.
-         num_negative_params = negative_params.sum().item()
-         if num_negative_params:
-             params[negative_params] = self.sample(num_negative_params, device=device, dtype=dtype)
-         return params
-
-     def log_prob(self, params: Tensor) -> Tensor:
-         # ignore constant
-         log_probs: Tensor = - ((params - self.mean) / self.std) ** 2 / 2
-         log_probs[params < 0.] = - float('inf')
-         return log_probs
-
-     def __repr__(self):
-         return f'TruncatedGaussianParamPrior(mean={self.mean}, std={self.std})'
+ from typing import Union, Tuple
+ from math import log
+
+ import torch
+ from torch import Tensor
+
+ from reflectorch.data_generation.priors.base import PriorSampler
+ from reflectorch.data_generation.priors.params import Params
+ from reflectorch.data_generation.utils import get_param_labels
+
+ __all__ = [
+     'SingleParamPrior',
+     'UniformParamPrior',
+     'GaussianParamPrior',
+     'TruncatedGaussianParamPrior',
+     'SimplePriorSampler',
+
+ ]
+
+
+ class SingleParamPrior(object):
+     def sample(self, batch_num: int, device=None, dtype=None):
+         raise NotImplementedError
+
+     def log_prob(self, params: Tensor):
+         raise NotImplementedError
+
+     def to_conf(self):
+         vars_dict = {k: v for k, v in vars(self).items() if not k.startswith('_')}
+         return {
+             'cls': self.__class__.__name__,
+             'kwargs': vars_dict
+         }
+
+     def scale(self, params: Tensor) -> Tensor:
+         raise NotImplementedError
+
+     def restore(self, params: Tensor) -> Tensor:
+         raise NotImplementedError
+
+
+ class SimplePriorSampler(PriorSampler):
+     def __init__(self,
+                  *params: Union[Tuple[float, float], SingleParamPrior],
+                  device=None, dtype=None, param_cls=None, scaling_prior: PriorSampler = None,
+                  ):
+         if param_cls is not None:
+             self.PARAM_CLS = param_cls
+         elif scaling_prior is not None:
+             self.PARAM_CLS = scaling_prior.PARAM_CLS
+         self.param_priors: Tuple[SingleParamPrior, ...] = self._init_params(params)
+         self._num_layers = self.PARAM_CLS.size2layers_num(len(params))
+         self.device = device
+         self.dtype = dtype
+         self.scaling_prior = scaling_prior
+
+     @property
+     def max_num_layers(self) -> int:
+         return self._num_layers
+
+     def _init_params(self, params: Tuple[Union[Tuple[float, float], SingleParamPrior], ...]):
+         assert len(params) == self.PARAM_CLS.layers_num2size(self.PARAM_CLS.size2layers_num(len(params)))
+         params = tuple(
+             param if isinstance(param, SingleParamPrior) else UniformParamPrior(*param)
+             for param in params
+         )
+         return params
+
+     def sample(self, batch_size: int) -> Params:
+         t_params = torch.stack([
+             param.sample(batch_size, device=self.device, dtype=self.dtype) for param in self.param_priors
+         ], -1)
+         params = self.PARAM_CLS.from_tensor(t_params)
+         return params
+
+     def scale_params(self, params: Params) -> Tensor:
+         if self.scaling_prior:
+             return self.scaling_prior.scale_params(params)
+
+         t_params = params.as_tensor()
+
+         scaled_params = torch.stack(
+             [param_prior.scale(param) for param, param_prior in zip(t_params.T, self.param_priors)], -1
+         )
+
+         return scaled_params
+
+     def restore_params(self, scaled_params: Tensor) -> Params:
+         if self.scaling_prior:
+             return self.scaling_prior.restore_params(scaled_params)
+         t_params = torch.stack(
+             [param_prior.restore(param) for param, param_prior in zip(scaled_params.T, self.param_priors)], -1
+         )
+         return self.PARAM_CLS.from_tensor(t_params)
+
+     def log_prob(self, params: Params) -> Tensor:
+         t_params = params.as_tensor()
+         log_probs = torch.stack(
+             [param_prior.log_prob(param) for param, param_prior in zip(t_params.T, self.param_priors)], -1
+         )
+         return log_probs.sum(1)
+
+     def get_indices_within_domain(self, params: Params) -> Tensor:
+         log_probs = self.log_prob(params)
+         return torch.isfinite(log_probs)
+
+     def get_indices_within_bounds(self, params: Params) -> Tensor:
+         return self.get_indices_within_domain(params)
+
+     def __repr__(self):
+         layers_num = self.PARAM_CLS.size2layers_num(len(self.param_priors))
+         labels = get_param_labels(layers_num)
+         prior_str = '\n\t'.join(f'{label}: {param_prior}' for label, param_prior in zip(labels, self.param_priors))
+         return f'SimplePriorSampler(\n\t{prior_str}\n)'
+
+
+ class UniformParamPrior(SingleParamPrior):
+     def __init__(self, min_value: float, max_value: float, device=None, dtype=None):
+         assert min_value < max_value
+         self.min_value, self.max_value, self.delta = min_value, max_value, max_value - min_value
+         self._lob_prob_const = - log(self.delta)
+         self.device = device
+         self.dtype = dtype
+
+     def sample(self, batch_num: int, device=None, dtype=None):
+         params = torch.rand(
+             batch_num, device=(device or self.device), dtype=(dtype or self.dtype)
+         ) * self.delta + self.min_value
+         return params
+
+     def log_prob(self, params: Tensor):
+         log_probs = torch.fill_(torch.ones_like(params), self._lob_prob_const)
+         log_probs[(params < self.min_value) | (params > self.max_value)] = - float('inf')
+         return log_probs
+
+     def scale(self, params: Tensor) -> Tensor:
+         return (params - self.min_value) / self.delta
+
+     def restore(self, params: Tensor) -> Tensor:
+         return params * self.delta + self.min_value
+
+     def __repr__(self):
+         return f'UniformParamPrior(min={self.min_value}, max={self.max_value})'
+
+
+ class GaussianParamPrior(SingleParamPrior):
+     _GAUSS_SCALE_CONST: float = 4.
+
+     def __init__(self, mean: float, std: float, device=None, dtype=None):
+         assert std > 0
+         self.mean = mean
+         self.std = std
+         self.device = device
+         self.dtype = dtype
+
+     def sample(self, batch_num: int, device=None, dtype=None):
+         params = torch.normal(
+             self.mean, self.std, (batch_num,), device=(device or self.device), dtype=(dtype or self.dtype)
+         )
+         return params
+
+     def log_prob(self, params: Tensor):
+         # ignore constant
+         log_probs = - ((params - self.mean) / self.std) ** 2 / 2
+         return log_probs
+
+     def scale(self, params: Tensor) -> Tensor:
+         return (params - self.mean) / (self.std * self._GAUSS_SCALE_CONST)
+
+     def restore(self, params: Tensor) -> Tensor:
+         return params * self.std * self._GAUSS_SCALE_CONST + self.mean
+
+     def __repr__(self):
+         return f'GaussianParamPrior(mean={self.mean}, std={self.std})'
+
+
+ class TruncatedGaussianParamPrior(GaussianParamPrior):
+     def sample(self, batch_num: int, device=None, dtype=None) -> Tensor:
+         params = torch.normal(
+             self.mean, self.std, (batch_num,), device=(device or self.device), dtype=(dtype or self.dtype)
+         )
+         negative_params = params < 0.
+         num_negative_params = negative_params.sum().item()
+         if num_negative_params:
+             params[negative_params] = self.sample(num_negative_params, device=device, dtype=dtype)
+         return params
+
+     def log_prob(self, params: Tensor) -> Tensor:
+         # ignore constant
+         log_probs: Tensor = - ((params - self.mean) / self.std) ** 2 / 2
+         log_probs[params < 0.] = - float('inf')
+         return log_probs
+
+     def __repr__(self):
+         return f'TruncatedGaussianParamPrior(mean={self.mean}, std={self.std})'
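
For reference, a minimal usage sketch of the single-parameter priors defined in the file above. It exercises only behavior visible in the diff (sampling, the scale/restore round trip, and log-probability bounds); the import path through the priors subpackage is an assumption about the package's re-exports.

import torch

# Import path is an assumption; the classes themselves are defined in the
# file shown in the diff above.
from reflectorch.data_generation.priors import UniformParamPrior, TruncatedGaussianParamPrior

# Uniform prior over a parameter range, e.g. a layer thickness between 10 and 500
thickness_prior = UniformParamPrior(10.0, 500.0)
samples = thickness_prior.sample(1024)

# scale() maps the physical range onto [0, 1]; restore() inverts it
scaled = thickness_prior.scale(samples)
assert torch.allclose(thickness_prior.restore(scaled), samples)

# Values outside [min_value, max_value] receive log-probability -inf
print(thickness_prior.log_prob(torch.tensor([5.0, 100.0])))  # 5.0 is out of bounds -> -inf

# TruncatedGaussianParamPrior resamples negative draws, so samples are non-negative
roughness_prior = TruncatedGaussianParamPrior(mean=5.0, std=3.0)
assert (roughness_prior.sample(10_000) >= 0).all()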