reflectorch 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (83) hide show
  1. reflectorch/__init__.py +23 -0
  2. reflectorch/data_generation/__init__.py +130 -0
  3. reflectorch/data_generation/dataset.py +196 -0
  4. reflectorch/data_generation/likelihoods.py +86 -0
  5. reflectorch/data_generation/noise.py +371 -0
  6. reflectorch/data_generation/priors/__init__.py +66 -0
  7. reflectorch/data_generation/priors/base.py +61 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
  9. reflectorch/data_generation/priors/independent_priors.py +201 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +110 -0
  12. reflectorch/data_generation/priors/no_constraints.py +212 -0
  13. reflectorch/data_generation/priors/parametric_models.py +767 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
  15. reflectorch/data_generation/priors/params.py +258 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +306 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +377 -0
  19. reflectorch/data_generation/priors/utils.py +124 -0
  20. reflectorch/data_generation/process_data.py +47 -0
  21. reflectorch/data_generation/q_generator.py +232 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +56 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +81 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +58 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +123 -0
  28. reflectorch/data_generation/scale_curves.py +118 -0
  29. reflectorch/data_generation/smearing.py +67 -0
  30. reflectorch/data_generation/utils.py +154 -0
  31. reflectorch/extensions/__init__.py +6 -0
  32. reflectorch/extensions/jupyter/__init__.py +12 -0
  33. reflectorch/extensions/jupyter/callbacks.py +40 -0
  34. reflectorch/extensions/matplotlib/__init__.py +11 -0
  35. reflectorch/extensions/matplotlib/losses.py +38 -0
  36. reflectorch/inference/__init__.py +22 -0
  37. reflectorch/inference/inference_model.py +734 -0
  38. reflectorch/inference/multilayer_fitter.py +171 -0
  39. reflectorch/inference/multilayer_inference_model.py +193 -0
  40. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  41. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  42. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  43. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  44. reflectorch/inference/preprocess_exp/interpolation.py +16 -0
  45. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  46. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  47. reflectorch/inference/record_time.py +43 -0
  48. reflectorch/inference/sampler_solution.py +56 -0
  49. reflectorch/inference/scipy_fitter.py +171 -0
  50. reflectorch/inference/torch_fitter.py +87 -0
  51. reflectorch/ml/__init__.py +37 -0
  52. reflectorch/ml/basic_trainer.py +286 -0
  53. reflectorch/ml/callbacks.py +86 -0
  54. reflectorch/ml/dataloaders.py +27 -0
  55. reflectorch/ml/loggers.py +38 -0
  56. reflectorch/ml/schedulers.py +246 -0
  57. reflectorch/ml/trainers.py +126 -0
  58. reflectorch/ml/utils.py +9 -0
  59. reflectorch/models/__init__.py +22 -0
  60. reflectorch/models/activations.py +50 -0
  61. reflectorch/models/encoders/__init__.py +27 -0
  62. reflectorch/models/encoders/conv_encoder.py +211 -0
  63. reflectorch/models/encoders/conv_res_net.py +119 -0
  64. reflectorch/models/encoders/fno.py +127 -0
  65. reflectorch/models/encoders/transformers.py +56 -0
  66. reflectorch/models/networks/__init__.py +18 -0
  67. reflectorch/models/networks/mlp_networks.py +256 -0
  68. reflectorch/models/networks/residual_net.py +131 -0
  69. reflectorch/paths.py +33 -0
  70. reflectorch/runs/__init__.py +35 -0
  71. reflectorch/runs/config.py +31 -0
  72. reflectorch/runs/slurm_utils.py +99 -0
  73. reflectorch/runs/train.py +85 -0
  74. reflectorch/runs/utils.py +300 -0
  75. reflectorch/test_config.py +4 -0
  76. reflectorch/train.py +4 -0
  77. reflectorch/train_on_cluster.py +4 -0
  78. reflectorch/utils.py +74 -0
  79. reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
  80. reflectorch-1.0.0.dist-info/METADATA +115 -0
  81. reflectorch-1.0.0.dist-info/RECORD +83 -0
  82. reflectorch-1.0.0.dist-info/WHEEL +5 -0
  83. reflectorch-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,300 @@
1
+ from pathlib import Path
2
+ from typing import Tuple
3
+
4
+ import torch
5
+
6
+ from reflectorch import *
7
+ from reflectorch.runs.config import load_config
8
+
9
+ __all__ = [
10
+ "train_from_config",
11
+ "get_trainer_from_config",
12
+ "get_paths_from_config",
13
+ "get_callbacks_from_config",
14
+ "get_trainer_by_name",
15
+ "get_callbacks_by_name",
16
+ ]
17
+
18
+
19
def init_from_conf(conf, **kwargs):
    """Initializes an object with class type, args and kwargs specified in the configuration

    The class is looked up by name among the global names of this module. Keyword
    arguments passed directly to this function take precedence over the ones stored
    in the configuration. The configuration dictionary itself is left unmodified.

    Args:
        conf (dict): configuration dictionary with the key ``cls`` (name of the class) and the
            optional keys ``args`` (positional arguments) and ``kwargs`` (keyword arguments)
        **kwargs: additional keyword arguments forwarded to the class constructor

    Returns:
        Any: the initialized object, or None if the configuration or the class name is empty

    Raises:
        KeyError: if a non-empty configuration has no ``cls`` key
        ValueError: if no global with the given class name exists
    """
    if not conf:
        return None
    cls_name = conf['cls']
    if not cls_name:
        return None

    cls = globals().get(cls_name)
    if not cls:
        raise ValueError(f'Unknown class {cls_name}')

    # ``or`` also covers entries that are present but explicitly set to None
    conf_args = conf.get('args') or []
    # Merge into a fresh dict: updating ``conf['kwargs']`` in place would leak the
    # extra keyword arguments into the caller's configuration.
    conf_kwargs = {**(conf.get('kwargs') or {}), **kwargs}
    return cls(*conf_args, **conf_kwargs)
42
+
43
+
44
def train_from_config(config: dict):
    """Run a complete training session described by a configuration dictionary

    The output folders are created, the trainer and its callbacks are built from the
    configuration, the training loop is executed and finally the folder paths, the
    trainer losses and the configuration are saved to the losses file.

    Args:
        config (dict): configuration dictionary

    Returns:
        Trainer: the trainer object
    """
    paths = get_paths_from_config(config, mkdir=True)
    trainer = get_trainer_from_config(config, paths)
    callbacks = get_callbacks_from_config(config, paths)

    train_conf = config['training']
    trainer.train(
        train_conf['num_iterations'],
        callbacks,
        disable_tqdm=False,
        update_tqdm_freq=train_conf['update_tqdm_freq'],
        grad_accumulation_steps=train_conf.get('grad_accumulation_steps', 1),
    )

    # Keep a record of the run next to the saved losses
    summary = {
        'paths': paths,
        'losses': trainer.losses,
        'params': config,
    }
    torch.save(summary, paths['losses'])

    return trainer
74
+
75
+
76
def get_paths_from_config(config: dict, mkdir: bool = False):
    """Get the directory paths from a configuration dictionary

    The root directory is read from ``config['general']['root_dir']``; if that value is
    empty, the package level ``ROOT_DIR`` is used instead.

    Args:
        config (dict): configuration dictionary
        mkdir (bool, optional): option to create the ``saved_models`` and ``saved_losses``
            directories inside the root directory if they do not exist yet. Defaults to False.

    Returns:
        dict: dictionary containing the run name (``name``), the path of the model weights
        as an absolute path string (``model``), the path of the losses file (``losses``),
        the root directory (``root``) and the two output directories (``saved_models``,
        ``saved_losses``)

    Raises:
        NotADirectoryError: if the root directory does not exist or is not a directory
    """
    general_conf = config['general']
    root_dir = Path(general_conf['root_dir'] or ROOT_DIR)
    name = general_conf['name']

    # Explicit check instead of ``assert``: assertions are removed when Python runs
    # with -O, which would postpone the failure to the mkdir / save calls.
    if not root_dir.is_dir():
        raise NotADirectoryError(f'Root directory {str(root_dir)} does not exist or is not a directory.')

    saved_models_dir = root_dir / 'saved_models'
    saved_losses_dir = root_dir / 'saved_losses'

    if mkdir:
        saved_models_dir.mkdir(exist_ok=True)
        saved_losses_dir.mkdir(exist_ok=True)

    model_path = str((saved_models_dir / f'model_{name}.pt').absolute())
    losses_path = saved_losses_dir / f'{name}_losses.pt'

    return {
        'name': name,
        'model': model_path,
        'losses': losses_path,
        'root': root_dir,
        'saved_models': saved_models_dir,
        'saved_losses': saved_losses_dir,
    }
110
+
111
+
112
def get_callbacks_from_config(config: dict, folder_paths: dict = None) -> Tuple['TrainerCallback', ...]:
    """Build the training callbacks described in a configuration dictionary

    The ``save_best_model`` entry of ``config['training']['callbacks']`` is handled
    separately; every other entry is instantiated through ``init_from_conf``. A
    ``LogLosses`` callback is appended when neptune logging is enabled.

    Args:
        config (dict): configuration dictionary
        folder_paths (dict, optional): dictionary containing the folder paths; derived from
            the configuration when not provided

    Returns:
        tuple: tuple of callbacks
    """
    if not folder_paths:
        folder_paths = get_paths_from_config(config)

    train_conf = config['training']

    # Work on a shallow copy so that popping the entry does not alter the configuration
    remaining_confs = dict(train_conf['callbacks'])
    save_conf = remaining_confs.pop('save_best_model')

    callbacks = []

    if save_conf['enable']:
        callbacks.append(SaveBestModel(folder_paths['model'], freq=save_conf['freq']))

    # Entries that do not produce an object (empty configuration) are skipped
    created = (init_from_conf(conf) for conf in remaining_confs.values())
    callbacks.extend(callback for callback in created if callback)

    if train_conf['logger']['use_neptune']:
        callbacks.append(LogLosses())

    return tuple(callbacks)
140
+
141
+
142
def get_trainer_from_config(config: dict, folder_paths: dict = None):
    """Build a trainer (dataset, network and optimizer class) from a configuration dictionary

    Args:
        config (dict): the configuration dictionary
        folder_paths (dict, optional): dictionary containing the folder paths; derived from
            the configuration when not provided

    Returns:
        Trainer: the trainer object
    """
    dset = init_dset(config['dset'])

    if not folder_paths:
        folder_paths = get_paths_from_config(config)

    model = init_network(config['model']['network'], folder_paths['saved_models'])

    train_conf = config['training']
    optim_cls = getattr(torch.optim, train_conf['optimizer'])

    # The trainer class is looked up by name among the module globals,
    # with PointEstimatorTrainer as the default
    if 'trainer_cls' in train_conf:
        trainer_cls = globals().get(train_conf['trainer_cls'])
    else:
        trainer_cls = PointEstimatorTrainer

    extra_kwargs = train_conf.get('trainer_kwargs', {})

    return trainer_cls(
        model,
        dset,
        train_conf['lr'],
        train_conf['batch_size'],
        clip_grad_norm_max=train_conf.get('clip_grad_norm_max', None),
        logger=None,
        optim_cls=optim_cls,
        train_with_q_input=train_conf.get('train_with_q_input', False),
        **extra_kwargs
    )
178
+
179
+
180
def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weights: bool = True, inference_device=None):
    """Initializes a trainer object based on a configuration file (i.e. the model name) and optionally loads \
    saved weights into the network

    Args:
        config_name (str): name of the configuration file
        config_dir (str): path of the configuration directory
        model_path (str, optional): path to the network weights. If not provided, the file
            ``model_{config_name}.pt`` inside ``SAVED_MODELS_DIR`` is used.
        load_weights (bool, optional): if True the saved network weights are loaded into the network. Defaults to True.
        inference_device (str, optional): if provided, it overrides the device of the network, of the prior
            sampler and of the q generator in the loaded configuration, and the saved weights are mapped
            to this device when they are loaded.

    Returns:
        Trainer: the trainer object

    Raises:
        RuntimeError: if the saved weights cannot be loaded from ``model_path``
    """
    config = load_config(config_name, config_dir)
    # The weights are loaded explicitly below, and no remote logging is wanted here
    config['model']['network']['pretrained_name'] = None
    config['training']['logger']['use_neptune'] = False

    if inference_device:
        config['model']['network']['device'] = inference_device
        config['dset']['prior_sampler']['kwargs']['device'] = inference_device
        config['dset']['q_generator']['kwargs']['device'] = inference_device

    trainer = get_trainer_from_config(config)

    num_params = sum(p.numel() for p in trainer.model.parameters())

    print(
        f'Model {config_name} loaded. Number of parameters: {num_params / 10 ** 6:.2f} M',
    )

    if not load_weights:
        return trainer

    if not model_path:
        model_name = f'model_{config_name}.pt'
        model_path = SAVED_MODELS_DIR / model_name

    try:
        # Map the stored tensors to the requested device; otherwise weights saved on a GPU
        # cannot be deserialized on a machine without one. With no device requested,
        # map_location=None keeps the default behaviour of torch.load.
        # NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
        state_dict = torch.load(model_path, map_location=inference_device or None)
    except Exception as err:
        raise RuntimeError(f'Could not load model from {model_path}') from err

    # Checkpoints either wrap the weights under the 'model' key or are a bare state dict
    if 'model' in state_dict:
        state_dict = state_dict['model']
    trainer.model.load_state_dict(state_dict)

    return trainer
229
+
230
def get_callbacks_by_name(config_name, config_dir=None):
    """Load a configuration file and build the trainer callbacks it describes

    Args:
        config_name (str): name of the configuration file
        config_dir (str): path of the configuration directory

    Returns:
        tuple: tuple of callbacks
    """
    return get_callbacks_from_config(load_config(config_name, config_dir))
241
+
242
+
243
def init_network(config: dict, saved_models_dir: Path):
    """Create the network described by the configuration, move it to its device and optionally
    load pretrained weights into it

    Args:
        config (dict): network configuration; ``device`` defaults to ``'cuda'`` and
            ``pretrained_name`` to None when absent
        saved_models_dir (Path): directory in which the pretrained weights are searched

    Returns:
        the initialized network
    """
    target_device = config.get('device', 'cuda')
    net = init_from_conf(config)
    net = net.to(target_device)
    pretrained_name = config.get('pretrained_name', None)
    return load_pretrained(net, pretrained_name, saved_models_dir)
249
+
250
+
251
def load_pretrained(model, model_name: str, saved_models_dir: Path):
    """Loads the saved weights into the network

    The weights file is searched in ``saved_models_dir`` first under ``model_name`` and then
    under ``model_{model_name}``; the extension ``.pt`` is appended when the name contains no dot.

    Args:
        model: the network whose state dict is updated in place
        model_name (str): name of the weights file; if empty, the model is returned unchanged
        saved_models_dir (Path): directory containing the saved weights

    Returns:
        the same model object that was passed in

    Raises:
        FileNotFoundError: if neither candidate file exists
        RuntimeError: if the file cannot be loaded or the state dict does not fit the model
    """
    if not model_name:
        return model

    if '.' not in model_name:
        model_name = model_name + '.pt'
    model_path = saved_models_dir / model_name

    if not model_path.is_file():
        # Fall back to the naming scheme with the 'model_' prefix
        model_path = saved_models_dir / f'model_{model_name}'

    if not model_path.is_file():
        raise FileNotFoundError(f'File {str(model_path)} does not exist.')

    try:
        # Deserialize onto the CPU: load_state_dict copies the values into the existing
        # parameters of the model on whatever device they live, so this also works for
        # weights that were saved on a GPU which is not available here.
        # NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
        pretrained = torch.load(model_path, map_location='cpu')
    except Exception as err:
        raise RuntimeError(f'Could not load model from {str(model_path)}') from err

    # Checkpoints either wrap the weights under the 'model' key or are a bare state dict
    if 'model' in pretrained:
        pretrained = pretrained['model']
    try:
        model.load_state_dict(pretrained)
    except Exception as err:
        raise RuntimeError(f'Could not update state dict from {str(model_path)}') from err

    return model
279
+
280
+
281
def init_dset(config: dict):
    """Build the dataset / dataloader object from its configuration

    The components ``prior_sampler``, ``intensity_noise`` and ``q_generator`` are always
    initialized from the configuration, while ``curves_scaler``, ``smearing`` and ``q_noise``
    are set to None when their key is missing. The dataset class is looked up by name among
    the module globals, with ReflectivityDataLoader as the default.

    Args:
        config (dict): dataset configuration dictionary

    Returns:
        the initialized dataset / dataloader object
    """
    if 'cls' in config:
        dset_cls = globals().get(config['cls'])
    else:
        dset_cls = ReflectivityDataLoader

    def _optional(key):
        # A component whose key is missing from the configuration is disabled
        if key in config:
            return init_from_conf(config[key])
        return None

    prior_sampler = init_from_conf(config['prior_sampler'])
    intensity_noise = init_from_conf(config['intensity_noise'])
    q_generator = init_from_conf(config['q_generator'])
    curves_scaler = _optional('curves_scaler')
    smearing = _optional('smearing')
    q_noise = _optional('q_noise')

    return dset_cls(
        q_generator=q_generator,
        prior_sampler=prior_sampler,
        intensity_noise=intensity_noise,
        curves_scaler=curves_scaler,
        smearing=smearing,
        q_noise=q_noise,
    )
@@ -0,0 +1,4 @@
1
+ from reflectorch.runs import run_test_config
2
+
3
# Script entry point: run_test_config() is called only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    run_test_config()
reflectorch/train.py ADDED
@@ -0,0 +1,4 @@
1
+ from reflectorch.runs import run_train
2
+
3
# Script entry point: run_train() is called only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    run_train()
@@ -0,0 +1,4 @@
1
+ from reflectorch.runs import run_train_on_cluster
2
+
3
# Script entry point: run_train_on_cluster() is called only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    run_train_on_cluster()
reflectorch/utils.py ADDED
@@ -0,0 +1,74 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ from numpy import ndarray
9
+
10
+ from torch import Tensor, tensor
11
+
12
+ __all__ = [
13
+ 'to_np',
14
+ 'to_t',
15
+ 'angle_to_q',
16
+ 'q_to_angle',
17
+ 'energy_to_wavelength',
18
+ 'wavelength_to_energy',
19
+ ]
20
+
21
+
22
def to_np(arr):
    """Return the input as a Numpy array

    Pytorch tensors are detached from the graph and moved to the CPU before the
    conversion; any other input is passed to ``numpy.asarray``.

    Args:
        arr (torch.Tensor or list): Input Pytorch tensor or Python list

    Returns:
        numpy.ndarray: Converted Numpy array
    """
    if not isinstance(arr, Tensor):
        return np.asarray(arr)
    detached = arr.detach()
    return detached.cpu().numpy()
35
+
36
+
37
def to_t(arr, device=None, dtype=None):
    """Converts Numpy array or Python list to Pytorch tensor

    Inputs which are not tensors are converted with ``torch.tensor``. An input which is
    already a tensor is returned as is when neither ``device`` nor ``dtype`` is given,
    otherwise it is converted to the requested device and data type.

    Args:
        arr (numpy.ndarray or list or torch.Tensor): Input
        device (torch.device or str, optional): device for the tensor ('cpu', 'cuda')
        dtype (torch.dtype, optional): data type of the tensor (e.g. torch.float32)

    Returns:
        torch.Tensor: converted Pytorch tensor
    """
    if not isinstance(arr, Tensor):
        return tensor(arr, device=device, dtype=dtype)

    if device is None and dtype is None:
        return arr

    # The requested device / dtype must be honoured for tensor inputs as well;
    # Tensor.to returns the tensor itself when it already matches.
    return arr.to(device=device, dtype=dtype)
52
+
53
+
54
+ # taken from mlreflect package
55
+ # mlreflect/xrrloader/dataloader/transform.py
56
+
57
def angle_to_q(scattering_angle: 'ndarray | float', wavelength: float):
    """Conversion from full scattering angle (degrees) to scattering vector (inverse angstroms)

    Computes ``q = 4 * pi / wavelength * sin(scattering_angle / 2)`` with the angle
    converted from degrees to radians.

    Args:
        scattering_angle (numpy.ndarray or float): full scattering angle in degrees
        wavelength (float): wavelength in angstroms

    Returns:
        numpy.ndarray or float: the scattering vector in inverse angstroms
    """
    # Half of the full scattering angle, expressed in radians
    half_angle_rad = scattering_angle / 2 * np.pi / 180
    return 4 * np.pi / wavelength * np.sin(half_angle_rad)
60
+
61
+
62
def q_to_angle(q: 'ndarray | float', wavelength: float):
    """Conversion from scattering vector (inverse angstroms) to full scattering angle (degrees)

    Computes ``2 * arcsin(q * wavelength / (4 * pi))`` and converts the result from
    radians to degrees.

    Args:
        q (numpy.ndarray or float): scattering vector in inverse angstroms
        wavelength (float): wavelength in angstroms

    Returns:
        numpy.ndarray or float: the full scattering angle in degrees
    """
    # Half of the full scattering angle, in radians
    half_angle_rad = np.arcsin(q * wavelength / (4 * np.pi))
    return 2 * half_angle_rad / np.pi * 180
65
+
66
+
67
def energy_to_wavelength(energy: float):
    """Convert a photon energy given in eV to the photon wavelength in angstroms"""
    scaled_inverse = 1.2398 / energy
    return scaled_inverse * 1e4
70
+
71
+
72
def wavelength_to_energy(wavelength: float):
    """Convert a photon wavelength given in angstroms to the photon energy in eV"""
    scaled_inverse = 1.2398 / wavelength
    return scaled_inverse * 1e4