reflectorch-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of reflectorch might be problematic.
- reflectorch/__init__.py +23 -0
- reflectorch/data_generation/__init__.py +130 -0
- reflectorch/data_generation/dataset.py +196 -0
- reflectorch/data_generation/likelihoods.py +86 -0
- reflectorch/data_generation/noise.py +371 -0
- reflectorch/data_generation/priors/__init__.py +66 -0
- reflectorch/data_generation/priors/base.py +61 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
- reflectorch/data_generation/priors/independent_priors.py +201 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +110 -0
- reflectorch/data_generation/priors/no_constraints.py +212 -0
- reflectorch/data_generation/priors/parametric_models.py +767 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
- reflectorch/data_generation/priors/params.py +258 -0
- reflectorch/data_generation/priors/sampler_strategies.py +306 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +377 -0
- reflectorch/data_generation/priors/utils.py +124 -0
- reflectorch/data_generation/process_data.py +47 -0
- reflectorch/data_generation/q_generator.py +232 -0
- reflectorch/data_generation/reflectivity/__init__.py +56 -0
- reflectorch/data_generation/reflectivity/abeles.py +81 -0
- reflectorch/data_generation/reflectivity/kinematical.py +58 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +123 -0
- reflectorch/data_generation/scale_curves.py +118 -0
- reflectorch/data_generation/smearing.py +67 -0
- reflectorch/data_generation/utils.py +154 -0
- reflectorch/extensions/__init__.py +6 -0
- reflectorch/extensions/jupyter/__init__.py +12 -0
- reflectorch/extensions/jupyter/callbacks.py +40 -0
- reflectorch/extensions/matplotlib/__init__.py +11 -0
- reflectorch/extensions/matplotlib/losses.py +38 -0
- reflectorch/inference/__init__.py +22 -0
- reflectorch/inference/inference_model.py +734 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +16 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +171 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +37 -0
- reflectorch/ml/basic_trainer.py +286 -0
- reflectorch/ml/callbacks.py +86 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +38 -0
- reflectorch/ml/schedulers.py +246 -0
- reflectorch/ml/trainers.py +126 -0
- reflectorch/ml/utils.py +9 -0
- reflectorch/models/__init__.py +22 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +27 -0
- reflectorch/models/encoders/conv_encoder.py +211 -0
- reflectorch/models/encoders/conv_res_net.py +119 -0
- reflectorch/models/encoders/fno.py +127 -0
- reflectorch/models/encoders/transformers.py +56 -0
- reflectorch/models/networks/__init__.py +18 -0
- reflectorch/models/networks/mlp_networks.py +256 -0
- reflectorch/models/networks/residual_net.py +131 -0
- reflectorch/paths.py +33 -0
- reflectorch/runs/__init__.py +35 -0
- reflectorch/runs/config.py +31 -0
- reflectorch/runs/slurm_utils.py +99 -0
- reflectorch/runs/train.py +85 -0
- reflectorch/runs/utils.py +300 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +74 -0
- reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
- reflectorch-1.0.0.dist-info/METADATA +115 -0
- reflectorch-1.0.0.dist-info/RECORD +83 -0
- reflectorch-1.0.0.dist-info/WHEEL +5 -0
- reflectorch-1.0.0.dist-info/top_level.txt +1 -0
reflectorch/runs/utils.py ADDED

@@ -0,0 +1,300 @@
from pathlib import Path
from typing import Tuple

import torch

from reflectorch import *
from reflectorch.runs.config import load_config

__all__ = [
    "train_from_config",
    "get_trainer_from_config",
    "get_paths_from_config",
    "get_callbacks_from_config",
    "get_trainer_by_name",
    "get_callbacks_by_name",
]


def init_from_conf(conf, **kwargs):
    """Initializes an object with class type, args and kwargs specified in the configuration

    Args:
        conf (dict): configuration dictionary

    Returns:
        Any: the initialized object
    """
    if not conf:
        return
    cls_name = conf['cls']
    if not cls_name:
        return
    cls = globals().get(cls_name)

    if not cls:
        raise ValueError(f'Unknown class {cls_name}')

    conf_args = conf.get('args', [])
    conf_kwargs = conf.get('kwargs', {})
    conf_kwargs.update(kwargs)
    return cls(*conf_args, **conf_kwargs)


def train_from_config(config: dict):
    """Train a model from a configuration dictionary

    Args:
        config (dict): configuration dictionary

    Returns:
        Trainer: the trainer object
    """

    folder_paths = get_paths_from_config(config, mkdir=True)

    trainer = get_trainer_from_config(config, folder_paths)

    callbacks = get_callbacks_from_config(config, folder_paths)

    trainer.train(
        config['training']['num_iterations'],
        callbacks, disable_tqdm=False,
        update_tqdm_freq=config['training']['update_tqdm_freq'],
        grad_accumulation_steps=config['training'].get('grad_accumulation_steps', 1)
    )

    torch.save({
        'paths': folder_paths,
        'losses': trainer.losses,
        'params': config,
    }, folder_paths['losses'])

    return trainer


def get_paths_from_config(config: dict, mkdir: bool = False):
    """Get the directory paths from a configuration dictionary

    Args:
        config (dict): configuration dictionary
        mkdir (bool, optional): option to create a new directory for the saved model weights and losses.

    Returns:
        dict: dictionary containing the folder paths
    """
    root_dir = Path(config['general']['root_dir'] or ROOT_DIR)
    name = config['general']['name']

    assert root_dir.is_dir()

    saved_models_dir = root_dir / 'saved_models'
    saved_losses_dir = root_dir / 'saved_losses'

    if mkdir:
        saved_models_dir.mkdir(exist_ok=True)
        saved_losses_dir.mkdir(exist_ok=True)

    model_path = str((saved_models_dir / f'model_{name}.pt').absolute())

    losses_path = saved_losses_dir / f'{name}_losses.pt'

    return {
        'name': name,
        'model': model_path,
        'losses': losses_path,
        'root': root_dir,
        'saved_models': saved_models_dir,
        'saved_losses': saved_losses_dir,
    }


def get_callbacks_from_config(config: dict, folder_paths: dict = None) -> Tuple['TrainerCallback', ...]:
    """Initializes the training callbacks from a configuration dictionary

    Returns:
        tuple: tuple of callbacks
    """
    callbacks = []

    folder_paths = folder_paths or get_paths_from_config(config)

    train_conf = config['training']
    callback_conf = dict(train_conf['callbacks'])
    save_conf = callback_conf.pop('save_best_model')

    if save_conf['enable']:
        save_model = SaveBestModel(folder_paths['model'], freq=save_conf['freq'])
        callbacks.append(save_model)

    for conf in callback_conf.values():
        callback = init_from_conf(conf)

        if callback:
            callbacks.append(callback)

    if train_conf['logger']['use_neptune']:
        callbacks.append(LogLosses())

    return tuple(callbacks)


def get_trainer_from_config(config: dict, folder_paths: dict = None):
    """Initializes a trainer from a configuration dictionary

    Args:
        config (dict): the configuration dictionary
        folder_paths (dict, optional): dictionary containing the folder paths

    Returns:
        Trainer: the trainer object
    """
    dset = init_dset(config['dset'])

    folder_paths = folder_paths or get_paths_from_config(config)

    model = init_network(config['model']['network'], folder_paths['saved_models'])

    train_conf = config['training']

    optim_cls = getattr(torch.optim, train_conf['optimizer'])

    logger = None

    train_with_q_input = train_conf.get('train_with_q_input', False)
    clip_grad_norm_max = train_conf.get('clip_grad_norm_max', None)

    trainer_cls = globals().get(train_conf['trainer_cls']) if 'trainer_cls' in train_conf else PointEstimatorTrainer
    trainer_kwargs = train_conf.get('trainer_kwargs', {})


    trainer = trainer_cls(
        model, dset, train_conf['lr'], train_conf['batch_size'], clip_grad_norm_max=clip_grad_norm_max,
        logger=logger, optim_cls=optim_cls, train_with_q_input=train_with_q_input,
        **trainer_kwargs
    )

    return trainer


def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weights: bool = True, inference_device=None):
    """Initializes a trainer object based on a configuration file (i.e. the model name) and optionally loads \
    saved weights into the network

    Args:
        config_name (str): name of the configuration file
        config_dir (str): path of the configuration directory
        model_path (str, optional): path to the network weights.
        load_weights (bool, optional): if True the saved network weights are loaded into the network. Defaults to True.


    Returns:
        Trainer: the trainer object
    """
    config = load_config(config_name, config_dir)
    config['model']['network']['pretrained_name'] = None
    config['training']['logger']['use_neptune'] = False

    if inference_device:
        config['model']['network']['device'] = inference_device
        config['dset']['prior_sampler']['kwargs']['device'] = inference_device
        config['dset']['q_generator']['kwargs']['device'] = inference_device

    trainer = get_trainer_from_config(config)

    num_params = sum(p.numel() for p in trainer.model.parameters())

    print(
        f'Model {config_name} loaded. Number of parameters: {num_params / 10 ** 6:.2f} M',
    )

    if not load_weights:
        return trainer

    if not model_path:
        model_name = f'model_{config_name}.pt'
        model_path = SAVED_MODELS_DIR / model_name

    try:
        state_dict = torch.load(model_path)
    except Exception as err:
        raise RuntimeError(f'Could not load model from {model_path}') from err

    if 'model' in state_dict:
        trainer.model.load_state_dict(state_dict['model'])
    else:
        trainer.model.load_state_dict(state_dict)

    return trainer

def get_callbacks_by_name(config_name, config_dir=None):
    """Initializes the trainer callbacks based on a configuration file

    Args:
        config_name (str): name of the configuration file
        config_dir (str): path of the configuration directory
    """
    config = load_config(config_name, config_dir)
    callbacks = get_callbacks_from_config(config)

    return callbacks


def init_network(config: dict, saved_models_dir: Path):
    """Initializes the network based on the configuration dictionary and optionally loads the weights from a pretrained model"""
    device = config.get('device', 'cuda')
    network = init_from_conf(config).to(device)
    network = load_pretrained(network, config.get('pretrained_name', None), saved_models_dir)
    return network


def load_pretrained(model, model_name: str, saved_models_dir: Path):
    """Loads the saved weights into the network"""
    if not model_name:
        return model

    if '.' not in model_name:
        model_name = model_name + '.pt'
    model_path = saved_models_dir / model_name

    if not model_path.is_file():
        model_path = saved_models_dir / f'model_{model_name}'

    if not model_path.is_file():
        raise FileNotFoundError(f'File {str(model_path)} does not exist.')

    try:
        pretrained = torch.load(model_path)
    except Exception as err:
        raise RuntimeError(f'Could not load model from {str(model_path)}') from err

    if 'model' in pretrained:
        pretrained = pretrained['model']
    try:
        model.load_state_dict(pretrained)
    except Exception as err:
        raise RuntimeError(f'Could not update state dict from {str(model_path)}') from err

    return model


def init_dset(config: dict):
    """Initializes the dataset / dataloader object"""
    dset_cls = globals().get(config['cls']) if 'cls' in config else ReflectivityDataLoader
    prior_sampler = init_from_conf(config['prior_sampler'])
    intensity_noise = init_from_conf(config['intensity_noise'])
    q_generator = init_from_conf(config['q_generator'])
    curves_scaler = init_from_conf(config['curves_scaler']) if 'curves_scaler' in config else None
    smearing = init_from_conf(config['smearing']) if 'smearing' in config else None
    q_noise = init_from_conf(config['q_noise']) if 'q_noise' in config else None

    dset = dset_cls(
        q_generator=q_generator,
        prior_sampler=prior_sampler,
        intensity_noise=intensity_noise,
        curves_scaler=curves_scaler,
        smearing=smearing,
        q_noise=q_noise,
    )

    return dset
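For orientation, here is a brief usage sketch of the inference entry point defined above. It is illustrative only: the configuration name, directory, and device strings are placeholders rather than values shipped with the package.

from reflectorch.runs.utils import get_trainer_by_name

# Rebuild the trainer described by a config file and load its saved weights.
# 'my_model' and 'path/to/configs' are hypothetical; load_config resolves the actual
# file, and the weights default to SAVED_MODELS_DIR / 'model_my_model.pt' when
# model_path is not given.
trainer = get_trainer_by_name(
    config_name='my_model',
    config_dir='path/to/configs',
    load_weights=True,
    inference_device='cpu',   # overrides the devices set in the config
)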
reflectorch/train.py ADDED
reflectorch/utils.py ADDED
@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from numpy import ndarray

from torch import Tensor, tensor

__all__ = [
    'to_np',
    'to_t',
    'angle_to_q',
    'q_to_angle',
    'energy_to_wavelength',
    'wavelength_to_energy',
]


def to_np(arr):
    """Converts Pytorch tensor or Python list to Numpy array

    Args:
        arr (torch.Tensor or list): Input Pytorch tensor or Python list

    Returns:
        numpy.ndarray: Converted Numpy array
    """

    if isinstance(arr, Tensor):
        return arr.detach().cpu().numpy()
    return np.asarray(arr)


def to_t(arr, device=None, dtype=None):
    """Converts Numpy array or Python list to Pytorch tensor

    Args:
        arr (numpy.ndarray or list): Input
        device (torch.device or str, optional): device for the tensor ('cpu', 'cuda')
        dtype (torch.dtype, optional): data type of the tensor (e.g. torch.float32)

    Returns:
        torch.Tensor: converted Pytorch tensor
    """

    if not isinstance(arr, Tensor):
        return tensor(arr, device=device, dtype=dtype)
    return arr


# taken from mlreflect package
# mlreflect/xrrloader/dataloader/transform.py

def angle_to_q(scattering_angle: ndarray or float, wavelength: float):
    """Conversion from full scattering angle (degrees) to scattering vector (inverse angstroms)"""
    return 4 * np.pi / wavelength * np.sin(scattering_angle / 2 * np.pi / 180)


def q_to_angle(q: ndarray or float, wavelength: float):
    """Conversion from scattering vector (inverse angstroms) to full scattering angle (degrees)"""
    return 2 * np.arcsin(q * wavelength / (4 * np.pi)) / np.pi * 180


def energy_to_wavelength(energy: float):
    """Conversion from photon energy (eV) to photon wavelength (angstroms)"""
    return 1.2398 / energy * 1e4


def wavelength_to_energy(wavelength: float):
    """Conversion from photon wavelength (angstroms) to photon energy (eV)"""
    return 1.2398 / wavelength * 1e4
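The unit conventions above (degrees for the full scattering angle, inverse angstroms for q, eV for energy, angstroms for wavelength) can be checked with a short, self-contained example; the numeric values below are illustrative and not part of the package.

import numpy as np
from reflectorch.utils import angle_to_q, q_to_angle, energy_to_wavelength

wavelength = energy_to_wavelength(8048.0)              # Cu K-alpha: ~1.54 angstroms
q = angle_to_q(np.array([0.5, 1.0, 2.0]), wavelength)  # full scattering angles in degrees -> q in 1/angstrom
angles = q_to_angle(q, wavelength)                     # inverse transform recovers the input angles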