reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of reflectorch might be problematic.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/runs/utils.py
CHANGED
@@ -1,402 +1,405 @@

Removed (1.3.0): lines 1-297 match the 1.5.0 version listed below except for one line in load_pretrained:

-        pretrained = torch.load(model_path)
+        pretrained = torch.load(model_path, weights_only=False)

The 1.3.0 text of lines 298-402 is not preserved in this diff listing.

Added (1.5.0), lines 1-405:

from pathlib import Path
from typing import Tuple

import torch
import safetensors.torch
import os

from reflectorch import *
from reflectorch.runs.config import load_config

__all__ = [
    "train_from_config",
    "get_trainer_from_config",
    "get_paths_from_config",
    "get_callbacks_from_config",
    "get_trainer_by_name",
    "get_callbacks_by_name",
    "convert_pt_to_safetensors",
]


def init_from_conf(conf, **kwargs):
    """Initializes an object with class type, args and kwargs specified in the configuration

    Args:
        conf (dict): configuration dictionary

    Returns:
        Any: the initialized object
    """
    if not conf:
        return
    cls_name = conf['cls']
    if not cls_name:
        return
    cls = globals().get(cls_name)

    if not cls:
        raise ValueError(f'Unknown class {cls_name}')

    conf_args = conf.get('args', [])
    conf_kwargs = conf.get('kwargs', {})
    conf_kwargs.update(kwargs)
    return cls(*conf_args, **conf_kwargs)


def train_from_config(config: dict):
    """Train a model from a configuration dictionary

    Args:
        config (dict): configuration dictionary

    Returns:
        Trainer: the trainer object
    """

    folder_paths = get_paths_from_config(config, mkdir=True)

    trainer = get_trainer_from_config(config, folder_paths)

    callbacks = get_callbacks_from_config(config, folder_paths)

    trainer.train(
        config['training']['num_iterations'],
        callbacks, disable_tqdm=False,
        update_tqdm_freq=config['training']['update_tqdm_freq'],
        grad_accumulation_steps=config['training'].get('grad_accumulation_steps', 1)
    )

    torch.save({
        'paths': folder_paths,
        'losses': trainer.losses,
        'params': config,
    }, folder_paths['losses'])

    return trainer


def get_paths_from_config(config: dict, mkdir: bool = False):
    """Get the directory paths from a configuration dictionary

    Args:
        config (dict): configuration dictionary
        mkdir (bool, optional): option to create a new directory for the saved model weights and losses.

    Returns:
        dict: dictionary containing the folder paths
    """
    root_dir = Path(config['general']['root_dir'] or ROOT_DIR)
    name = config['general']['name']

    assert root_dir.is_dir()

    saved_models_dir = root_dir / 'saved_models'
    saved_losses_dir = root_dir / 'saved_losses'

    if mkdir:
        saved_models_dir.mkdir(exist_ok=True)
        saved_losses_dir.mkdir(exist_ok=True)

    model_path = str((saved_models_dir / f'model_{name}.pt').absolute())

    losses_path = saved_losses_dir / f'{name}_losses.pt'

    return {
        'name': name,
        'model': model_path,
        'losses': losses_path,
        'root': root_dir,
        'saved_models': saved_models_dir,
        'saved_losses': saved_losses_dir,
    }


def get_callbacks_from_config(config: dict, folder_paths: dict = None) -> Tuple['TrainerCallback', ...]:
    """Initializes the training callbacks from a configuration dictionary

    Returns:
        tuple: tuple of callbacks
    """
    callbacks = []

    folder_paths = folder_paths or get_paths_from_config(config)

    train_conf = config['training']
    callback_conf = dict(train_conf['callbacks'])
    save_conf = callback_conf.pop('save_best_model')

    if save_conf['enable']:
        save_model = SaveBestModel(folder_paths['model'], freq=save_conf['freq'])
        callbacks.append(save_model)

    for conf in callback_conf.values():
        callback = init_from_conf(conf)

        if callback:
            callbacks.append(callback)

    if 'logger' in train_conf.keys():
        callbacks.append(LogLosses())

    return tuple(callbacks)


def get_trainer_from_config(config: dict, folder_paths: dict = None):
    """Initializes a trainer from a configuration dictionary

    Args:
        config (dict): the configuration dictionary
        folder_paths (dict, optional): dictionary containing the folder paths

    Returns:
        Trainer: the trainer object
    """
    dset = init_dset(config['dset'])

    folder_paths = folder_paths or get_paths_from_config(config)

    model = init_network(config['model']['network'], folder_paths['saved_models'])

    train_conf = config['training']

    optim_cls = getattr(torch.optim, train_conf['optimizer'])

    logger_conf = train_conf.get('logger', None)
    logger = init_from_conf(logger_conf) if logger_conf and logger_conf.get('cls') else None

    clip_grad_norm_max = train_conf.get('clip_grad_norm_max', None)

    trainer_cls = globals().get(train_conf['trainer_cls']) if 'trainer_cls' in train_conf else PointEstimatorTrainer
    trainer_kwargs = train_conf.get('trainer_kwargs', {})

    trainer = trainer_cls(
        model, dset, train_conf['lr'], train_conf['batch_size'], clip_grad_norm_max=clip_grad_norm_max,
        logger=logger, optim_cls=optim_cls,
        **trainer_kwargs
    )

    if train_conf.get('train_with_q_input', False) and getattr(trainer, 'train_with_q_input', None) is not None:  #only for back-compatibility with configs in older versions
        trainer.train_with_q_input = True

    return trainer


def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weights: bool = True, inference_device: str = 'cuda'):
    """Initializes a trainer object based on a configuration file (i.e. the model name) and optionally loads \
    saved weights into the network

    Args:
        config_name (str): name of the configuration file
        config_dir (str): path of the configuration directory
        model_path (str, optional): path to the network weights. The default path is 'saved_models' located in the package directory
        load_weights (bool, optional): if True the saved network weights are loaded into the network. Defaults to True.
        inference_device (str, optional): overwrites the device in the configuration file for the purpose of inference on a different device then the training was performed on. Defaults to 'cuda'.

    Returns:
        Trainer: the trainer object
    """
    config = load_config(config_name, config_dir)
    #config['model']['network']['pretrained_name'] = None

    config['model']['network']['device'] = inference_device
    config['dset']['prior_sampler']['kwargs']['device'] = inference_device
    config['dset']['q_generator']['kwargs']['device'] = inference_device

    trainer = get_trainer_from_config(config)

    num_params = sum(p.numel() for p in trainer.model.parameters())

    print(
        f'Model {config_name} loaded. Number of parameters: {num_params / 10 ** 6:.2f} M',
    )

    if not load_weights:
        return trainer

    if not model_path:
        model_name = f'model_{config_name}.pt'
        model_path = SAVED_MODELS_DIR / model_name

    if str(model_path).endswith('.pt'):
        try:
            state_dict = torch.load(model_path, map_location=inference_device, weights_only=False)
        except Exception as err:
            raise RuntimeError(f'Could not load model from {model_path}') from err

        if 'model' in state_dict:
            trainer.model.load_state_dict(state_dict['model'])
        else:
            trainer.model.load_state_dict(state_dict)

    elif str(model_path).endswith('.safetensors'):
        try:
            load_state_dict_safetensors(model=trainer.model, filename=model_path, device=inference_device)
        except Exception as err:
            raise RuntimeError(f'Could not load model from {model_path}') from err

    else:
        raise RuntimeError('Weigths file with unknown extension')

    return trainer


def get_callbacks_by_name(config_name, config_dir=None):
    """Initializes the trainer callbacks based on a configuration file

    Args:
        config_name (str): name of the configuration file
        config_dir (str): path of the configuration directory
    """
    config = load_config(config_name, config_dir)
    callbacks = get_callbacks_from_config(config)

    return callbacks


def init_network(config: dict, saved_models_dir: Path):
    """Initializes the network based on the configuration dictionary and optionally loades the weights from a pretrained model"""
    device = config.get('device', 'cuda')
    network = init_from_conf(config).to(device)
    network = load_pretrained(network, config.get('pretrained_name', None), saved_models_dir)
    return network


def load_pretrained(model, model_name: str, saved_models_dir: Path):
    """Loads the saved weights into the network"""
    if not model_name:
        return model

    if '.' not in model_name:
        model_name = model_name + '.pt'
    model_path = saved_models_dir / model_name

    if not model_path.is_file():
        model_path = saved_models_dir / f'model_{model_name}'

    if not model_path.is_file():
        raise FileNotFoundError(f'File {str(model_path)} does not exist.')

    try:
        pretrained = torch.load(model_path, weights_only=False)
    except Exception as err:
        raise RuntimeError(f'Could not load model from {str(model_path)}') from err

    if 'model' in pretrained:
        pretrained = pretrained['model']
    try:
        model.load_state_dict(pretrained)
    except Exception as err:
        raise RuntimeError(f'Could not update state dict from {str(model_path)}') from err

    return model


def init_dset(config: dict):
    """Initializes the dataset / dataloader object"""
    dset_cls = globals().get(config['cls']) if 'cls' in config else ReflectivityDataLoader
    dset_kwargs = config.get('kwargs', {})

    prior_sampler = init_from_conf(config['prior_sampler'])
    intensity_noise = init_from_conf(config['intensity_noise'])
    q_generator = init_from_conf(config['q_generator'])
    curves_scaler = init_from_conf(config['curves_scaler']) if 'curves_scaler' in config else None
    smearing = init_from_conf(config['smearing']) if 'smearing' in config else None
    q_noise = init_from_conf(config['q_noise']) if 'q_noise' in config else None

    dset = dset_cls(
        q_generator=q_generator,
        prior_sampler=prior_sampler,
        intensity_noise=intensity_noise,
        curves_scaler=curves_scaler,
        smearing=smearing,
        q_noise=q_noise,
        **dset_kwargs,
    )

    return dset


def split_complex_tensors(state_dict):
    new_state_dict = {}
    for key, tensor in state_dict.items():
        if tensor.is_complex():
            new_state_dict[f"{key}_real"] = tensor.real.clone()
            new_state_dict[f"{key}_imag"] = tensor.imag.clone()
        else:
            new_state_dict[key] = tensor
    return new_state_dict


def recombine_complex_tensors(state_dict):
    new_state_dict = {}
    keys = list(state_dict.keys())
    visited = set()

    for key in keys:
        if key.endswith('_real') or key.endswith('_imag'):
            base_key = key[:-5]
            new_state_dict[base_key] = torch.complex(state_dict[base_key + '_real'], state_dict[base_key + '_imag'])
            visited.add(base_key + '_real')
            visited.add(base_key + '_imag')
        elif key not in visited:
            new_state_dict[key] = state_dict[key]

    return new_state_dict


def convert_pt_to_safetensors(input_dir):
    """Creates '.safetensors' files for all the model state dictionaries inside '.pt' files in the specified directory.

    Args:
        input_dir (str): directory containing model weights
    """
    if not os.path.isdir(input_dir):
        raise ValueError(f"Input directory {input_dir} does not exist")

    for file_name in os.listdir(input_dir):
        if file_name.endswith('.pt'):
            pt_file_path = os.path.join(input_dir, file_name)
            safetensors_file_path = os.path.join(input_dir, file_name[:-3] + '.safetensors')

            if os.path.exists(safetensors_file_path):
                print(f"Skipping {pt_file_path}, corresponding .safetensors file already exists.")
                continue

            print(f"Converting {pt_file_path} to .safetensors format.")
            data_pt = torch.load(pt_file_path, weights_only=False)
            model_state_dict = data_pt["model"]
            model_state_dict = split_complex_tensors(model_state_dict)  #handle tensors with complex dtype which are not natively supported by safetensors

            safetensors.torch.save_file(tensors=model_state_dict, filename=safetensors_file_path)


def convert_files_to_safetensors(files):
    """
    Converts specified .pt files to .safetensors format.

    Args:
        files (str or list of str): Path(s) to .pt files containing model state dictionaries.
    """
    if isinstance(files, str):
        files = [files]

    for pt_file_path in files:
        if not pt_file_path.endswith('.pt'):
            print(f"Skipping {pt_file_path}: not a .pt file.")
            continue

        if not os.path.exists(pt_file_path):
            print(f"File {pt_file_path} does not exist.")
            continue

        safetensors_file_path = pt_file_path[:-3] + '.safetensors'

        if os.path.exists(safetensors_file_path):
            print(f"Skipping {pt_file_path}: .safetensors version already exists.")
            continue

        print(f"Converting {pt_file_path} to .safetensors format.")
        data_pt = torch.load(pt_file_path, weights_only=False)
        model_state_dict = data_pt["model"]
        model_state_dict = split_complex_tensors(model_state_dict)

        safetensors.torch.save_file(tensors=model_state_dict, filename=safetensors_file_path)


def load_state_dict_safetensors(model, filename, device):
    state_dict = safetensors.torch.load_file(filename=filename, device=device)
    state_dict = recombine_complex_tensors(state_dict)
    model.load_state_dict(state_dict)
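For orientation, a minimal usage sketch of the safetensors-related helpers defined above. The checkpoint directory, config name, and file paths below are placeholders rather than values taken from the diff, and the example assumes the functions are imported directly from reflectorch.runs.utils.

import torch
from reflectorch.runs.utils import (
    convert_pt_to_safetensors,
    get_trainer_by_name,
    split_complex_tensors,
    recombine_complex_tensors,
)

# Write a .safetensors file next to every .pt checkpoint in a directory
# (the helper skips files that already have a converted counterpart).
convert_pt_to_safetensors('saved_models')  # placeholder directory

# get_trainer_by_name dispatches on the file extension: '.pt' goes through
# torch.load(..., weights_only=False), '.safetensors' through load_state_dict_safetensors.
trainer = get_trainer_by_name(
    config_name='my_model_config',  # placeholder config name
    model_path='saved_models/model_my_model_config.safetensors',  # placeholder path
    inference_device='cpu',
)

# The converter stores complex tensors as paired *_real/*_imag entries because
# safetensors has no complex dtype; split/recombine is an exact round trip.
sd = {'spectral.weight': torch.randn(3, 3, dtype=torch.cfloat), 'head.bias': torch.zeros(3)}
restored = recombine_complex_tensors(split_complex_tensors(sd))
assert torch.equal(restored['spectral.weight'], sd['spectral.weight'])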