reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of reflectorch might be problematic.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/runs/utils.py CHANGED
@@ -1,402 +1,405 @@
 from pathlib import Path
 from typing import Tuple

 import torch
 import safetensors.torch
 import os

 from reflectorch import *
 from reflectorch.runs.config import load_config

 __all__ = [
     "train_from_config",
     "get_trainer_from_config",
     "get_paths_from_config",
     "get_callbacks_from_config",
     "get_trainer_by_name",
     "get_callbacks_by_name",
     "convert_pt_to_safetensors",
 ]


 def init_from_conf(conf, **kwargs):
     """Initializes an object with class type, args and kwargs specified in the configuration

     Args:
         conf (dict): configuration dictionary

     Returns:
         Any: the initialized object
     """
     if not conf:
         return
     cls_name = conf['cls']
     if not cls_name:
         return
     cls = globals().get(cls_name)

     if not cls:
         raise ValueError(f'Unknown class {cls_name}')

     conf_args = conf.get('args', [])
     conf_kwargs = conf.get('kwargs', {})
     conf_kwargs.update(kwargs)
     return cls(*conf_args, **conf_kwargs)


 def train_from_config(config: dict):
     """Train a model from a configuration dictionary

     Args:
         config (dict): configuration dictionary

     Returns:
         Trainer: the trainer object
     """

     folder_paths = get_paths_from_config(config, mkdir=True)

     trainer = get_trainer_from_config(config, folder_paths)

     callbacks = get_callbacks_from_config(config, folder_paths)

     trainer.train(
         config['training']['num_iterations'],
         callbacks, disable_tqdm=False,
         update_tqdm_freq=config['training']['update_tqdm_freq'],
         grad_accumulation_steps=config['training'].get('grad_accumulation_steps', 1)
     )

     torch.save({
         'paths': folder_paths,
         'losses': trainer.losses,
         'params': config,
     }, folder_paths['losses'])

     return trainer


 def get_paths_from_config(config: dict, mkdir: bool = False):
     """Get the directory paths from a configuration dictionary

     Args:
         config (dict): configuration dictionary
         mkdir (bool, optional): option to create a new directory for the saved model weights and losses.

     Returns:
         dict: dictionary containing the folder paths
     """
     root_dir = Path(config['general']['root_dir'] or ROOT_DIR)
     name = config['general']['name']

     assert root_dir.is_dir()

     saved_models_dir = root_dir / 'saved_models'
     saved_losses_dir = root_dir / 'saved_losses'

     if mkdir:
         saved_models_dir.mkdir(exist_ok=True)
         saved_losses_dir.mkdir(exist_ok=True)

     model_path = str((saved_models_dir / f'model_{name}.pt').absolute())

     losses_path = saved_losses_dir / f'{name}_losses.pt'

     return {
         'name': name,
         'model': model_path,
         'losses': losses_path,
         'root': root_dir,
         'saved_models': saved_models_dir,
         'saved_losses': saved_losses_dir,
     }


 def get_callbacks_from_config(config: dict, folder_paths: dict = None) -> Tuple['TrainerCallback', ...]:
     """Initializes the training callbacks from a configuration dictionary

     Returns:
         tuple: tuple of callbacks
     """
     callbacks = []

     folder_paths = folder_paths or get_paths_from_config(config)

     train_conf = config['training']
     callback_conf = dict(train_conf['callbacks'])
     save_conf = callback_conf.pop('save_best_model')

     if save_conf['enable']:
         save_model = SaveBestModel(folder_paths['model'], freq=save_conf['freq'])
         callbacks.append(save_model)

     for conf in callback_conf.values():
         callback = init_from_conf(conf)

         if callback:
             callbacks.append(callback)

     if 'logger' in train_conf.keys():
         callbacks.append(LogLosses())

     return tuple(callbacks)


 def get_trainer_from_config(config: dict, folder_paths: dict = None):
     """Initializes a trainer from a configuration dictionary

     Args:
         config (dict): the configuration dictionary
         folder_paths (dict, optional): dictionary containing the folder paths

     Returns:
         Trainer: the trainer object
     """
     dset = init_dset(config['dset'])

     folder_paths = folder_paths or get_paths_from_config(config)

     model = init_network(config['model']['network'], folder_paths['saved_models'])

     train_conf = config['training']

     optim_cls = getattr(torch.optim, train_conf['optimizer'])

     logger_conf = train_conf.get('logger', None)
     logger = init_from_conf(logger_conf) if logger_conf and logger_conf.get('cls') else None

     clip_grad_norm_max = train_conf.get('clip_grad_norm_max', None)

     trainer_cls = globals().get(train_conf['trainer_cls']) if 'trainer_cls' in train_conf else PointEstimatorTrainer
     trainer_kwargs = train_conf.get('trainer_kwargs', {})


     trainer = trainer_cls(
         model, dset, train_conf['lr'], train_conf['batch_size'], clip_grad_norm_max=clip_grad_norm_max,
         logger=logger, optim_cls=optim_cls,
         **trainer_kwargs
     )

     if train_conf.get('train_with_q_input', False) and getattr(trainer, 'train_with_q_input', None) is not None: #only for back-compatibility with configs in older versions
         trainer.train_with_q_input = True

     return trainer


 def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weights: bool = True, inference_device: str = 'cuda'):
     """Initializes a trainer object based on a configuration file (i.e. the model name) and optionally loads \
     saved weights into the network

     Args:
         config_name (str): name of the configuration file
         config_dir (str): path of the configuration directory
         model_path (str, optional): path to the network weights. The default path is 'saved_models' located in the package directory
         load_weights (bool, optional): if True the saved network weights are loaded into the network. Defaults to True.
         inference_device (str, optional): overwrites the device in the configuration file for the purpose of inference on a different device then the training was performed on. Defaults to 'cuda'.

     Returns:
         Trainer: the trainer object
     """
     config = load_config(config_name, config_dir)
     #config['model']['network']['pretrained_name'] = None

     config['model']['network']['device'] = inference_device
     config['dset']['prior_sampler']['kwargs']['device'] = inference_device
     config['dset']['q_generator']['kwargs']['device'] = inference_device

     trainer = get_trainer_from_config(config)

     num_params = sum(p.numel() for p in trainer.model.parameters())

     print(
         f'Model {config_name} loaded. Number of parameters: {num_params / 10 ** 6:.2f} M',
     )

     if not load_weights:
         return trainer

     if not model_path:
         model_name = f'model_{config_name}.pt'
         model_path = SAVED_MODELS_DIR / model_name

     if str(model_path).endswith('.pt'):
         try:
             state_dict = torch.load(model_path, map_location=inference_device, weights_only=False)
         except Exception as err:
             raise RuntimeError(f'Could not load model from {model_path}') from err

         if 'model' in state_dict:
             trainer.model.load_state_dict(state_dict['model'])
         else:
             trainer.model.load_state_dict(state_dict)

     elif str(model_path).endswith('.safetensors'):
         try:
             load_state_dict_safetensors(model=trainer.model, filename=model_path, device=inference_device)
         except Exception as err:
             raise RuntimeError(f'Could not load model from {model_path}') from err

     else:
         raise RuntimeError('Weigths file with unknown extension')

     return trainer

 def get_callbacks_by_name(config_name, config_dir=None):
     """Initializes the trainer callbacks based on a configuration file

     Args:
         config_name (str): name of the configuration file
         config_dir (str): path of the configuration directory
     """
     config = load_config(config_name, config_dir)
     callbacks = get_callbacks_from_config(config)

     return callbacks


 def init_network(config: dict, saved_models_dir: Path):
     """Initializes the network based on the configuration dictionary and optionally loades the weights from a pretrained model"""
     device = config.get('device', 'cuda')
     network = init_from_conf(config).to(device)
     network = load_pretrained(network, config.get('pretrained_name', None), saved_models_dir)
     return network


 def load_pretrained(model, model_name: str, saved_models_dir: Path):
     """Loads the saved weights into the network"""
     if not model_name:
         return model

     if '.' not in model_name:
         model_name = model_name + '.pt'
     model_path = saved_models_dir / model_name

     if not model_path.is_file():
         model_path = saved_models_dir / f'model_{model_name}'

     if not model_path.is_file():
         raise FileNotFoundError(f'File {str(model_path)} does not exist.')

     try:
-        pretrained = torch.load(model_path)
+        pretrained = torch.load(model_path, weights_only=False)
     except Exception as err:
         raise RuntimeError(f'Could not load model from {str(model_path)}') from err

     if 'model' in pretrained:
         pretrained = pretrained['model']
     try:
         model.load_state_dict(pretrained)
     except Exception as err:
         raise RuntimeError(f'Could not update state dict from {str(model_path)}') from err

     return model


 def init_dset(config: dict):
     """Initializes the dataset / dataloader object"""
     dset_cls = globals().get(config['cls']) if 'cls' in config else ReflectivityDataLoader
+    dset_kwargs = config.get('kwargs', {})
+
     prior_sampler = init_from_conf(config['prior_sampler'])
     intensity_noise = init_from_conf(config['intensity_noise'])
     q_generator = init_from_conf(config['q_generator'])
     curves_scaler = init_from_conf(config['curves_scaler']) if 'curves_scaler' in config else None
     smearing = init_from_conf(config['smearing']) if 'smearing' in config else None
     q_noise = init_from_conf(config['q_noise']) if 'q_noise' in config else None

     dset = dset_cls(
         q_generator=q_generator,
         prior_sampler=prior_sampler,
         intensity_noise=intensity_noise,
         curves_scaler=curves_scaler,
         smearing=smearing,
         q_noise=q_noise,
+        **dset_kwargs,
     )

     return dset

 def split_complex_tensors(state_dict):
     new_state_dict = {}
     for key, tensor in state_dict.items():
         if tensor.is_complex():
             new_state_dict[f"{key}_real"] = tensor.real.clone()
             new_state_dict[f"{key}_imag"] = tensor.imag.clone()
         else:
             new_state_dict[key] = tensor
     return new_state_dict

 def recombine_complex_tensors(state_dict):
     new_state_dict = {}
     keys = list(state_dict.keys())
     visited = set()

     for key in keys:
         if key.endswith('_real') or key.endswith('_imag'):
             base_key = key[:-5]
             new_state_dict[base_key] = torch.complex(state_dict[base_key + '_real'], state_dict[base_key + '_imag'])
             visited.add(base_key + '_real')
             visited.add(base_key + '_imag')
         elif key not in visited:
             new_state_dict[key] = state_dict[key]

     return new_state_dict

 def convert_pt_to_safetensors(input_dir):
     """Creates '.safetensors' files for all the model state dictionaries inside '.pt' files in the specified directory.

     Args:
         input_dir (str): directory containing model weights
     """
     if not os.path.isdir(input_dir):
         raise ValueError(f"Input directory {input_dir} does not exist")

     for file_name in os.listdir(input_dir):
         if file_name.endswith('.pt'):
             pt_file_path = os.path.join(input_dir, file_name)
             safetensors_file_path = os.path.join(input_dir, file_name[:-3] + '.safetensors')

             if os.path.exists(safetensors_file_path):
                 print(f"Skipping {pt_file_path}, corresponding .safetensors file already exists.")
                 continue

             print(f"Converting {pt_file_path} to .safetensors format.")
-            data_pt = torch.load(pt_file_path)
+            data_pt = torch.load(pt_file_path, weights_only=False)
             model_state_dict = data_pt["model"]
             model_state_dict = split_complex_tensors(model_state_dict) #handle tensors with complex dtype which are not natively supported by safetensors

             safetensors.torch.save_file(tensors=model_state_dict, filename=safetensors_file_path)

 def convert_files_to_safetensors(files):
     """
     Converts specified .pt files to .safetensors format.

     Args:
         files (str or list of str): Path(s) to .pt files containing model state dictionaries.
     """
     if isinstance(files, str):
         files = [files]

     for pt_file_path in files:
         if not pt_file_path.endswith('.pt'):
             print(f"Skipping {pt_file_path}: not a .pt file.")
             continue

         if not os.path.exists(pt_file_path):
             print(f"File {pt_file_path} does not exist.")
             continue

         safetensors_file_path = pt_file_path[:-3] + '.safetensors'

         if os.path.exists(safetensors_file_path):
             print(f"Skipping {pt_file_path}: .safetensors version already exists.")
             continue

         print(f"Converting {pt_file_path} to .safetensors format.")
         data_pt = torch.load(pt_file_path, weights_only=False)
         model_state_dict = data_pt["model"]
         model_state_dict = split_complex_tensors(model_state_dict)

         safetensors.torch.save_file(tensors=model_state_dict, filename=safetensors_file_path)

 def load_state_dict_safetensors(model, filename, device):
     state_dict = safetensors.torch.load_file(filename=filename, device=device)
     state_dict = recombine_complex_tensors(state_dict)
     model.load_state_dict(state_dict)
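
The main functional change in this file is small: the torch.load calls in load_pretrained and convert_pt_to_safetensors now pass weights_only=False, matching the call sites that already did so in 1.3.0 (get_trainer_by_name and convert_files_to_safetensors). Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to unpickle arbitrary Python objects; reflectorch checkpoints bundle non-tensor data (train_from_config, for instance, saves a dict with paths, losses, and the config), so the explicit opt-out keeps them loadable. A minimal sketch of the distinction; the checkpoint file name is hypothetical:

import torch

# A reflectorch checkpoint stores more than raw tensors: the losses file
# saved by train_from_config contains paths and the config dict, and model
# checkpoints may nest the state dict under a 'model' key. Under the
# PyTorch >= 2.6 default (weights_only=True) such pickles are rejected,
# hence the explicit opt-out.
checkpoint = torch.load('model_example.pt', map_location='cpu', weights_only=False)

# Mirror the branching in get_trainer_by_name: the state dict may be nested.
state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint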
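
The other change, in init_dset, forwards an optional top-level 'kwargs' entry of the dset config section to the dataloader constructor. A sketch of the accepted shape; every class name and value here is illustrative rather than taken from a shipped config:

# Each component entry follows the {'cls': ..., 'args': [...], 'kwargs': {...}}
# convention consumed by init_from_conf, with classes resolved via globals()
# from the reflectorch namespace. The top-level 'kwargs' dict is new in this
# version and is expanded into the dataloader constructor via **dset_kwargs.
dset_config = {
    'cls': 'ReflectivityDataLoader',
    'kwargs': {},                      # forwarded as **dset_kwargs (new in 1.5.0)
    'prior_sampler': {'cls': 'SubpriorParametricSampler', 'kwargs': {'device': 'cuda'}},
    'q_generator': {'cls': 'ConstantQ', 'kwargs': {'device': 'cuda'}},
    'intensity_noise': {'cls': 'BasicExpIntensityNoise', 'kwargs': {}},
}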
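
split_complex_tensors and recombine_complex_tensors exist because safetensors has no complex dtype: complex weights (as occur, for example, in spectral layers such as an FNO encoder) are stored as separate _real/_imag tensors and merged back on load, which is what load_state_dict_safetensors does. A round-trip sketch using only the helpers shown above:

import torch
import safetensors.torch

from reflectorch.runs.utils import split_complex_tensors, recombine_complex_tensors

state_dict = {
    'fc.weight': torch.randn(4, 4),
    'spectral.weight': torch.randn(4, 4, dtype=torch.cfloat),  # complex weights
}

# Complex entries are split into real/imaginary parts so safetensors can store them...
flat = split_complex_tensors(state_dict)
safetensors.torch.save_file(tensors=flat, filename='example.safetensors')

# ...and recombined after loading, recovering the original complex tensor.
loaded = safetensors.torch.load_file(filename='example.safetensors', device='cpu')
restored = recombine_complex_tensors(loaded)
assert torch.equal(restored['spectral.weight'], state_dict['spectral.weight'])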
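
For completeness, the entry point that exercises most of this module at inference time is get_trainer_by_name, which patches the device settings of the loaded config and then restores weights from either the .pt or the .safetensors format. A usage sketch; 'standard_model' is a placeholder config name that would have to match a configuration file in the config directory:

from reflectorch.runs.utils import get_trainer_by_name, convert_pt_to_safetensors

# Weights are looked up as 'model_standard_model.pt' under SAVED_MODELS_DIR
# unless an explicit model_path is given.
trainer = get_trainer_by_name('standard_model', load_weights=True, inference_device='cpu')
model = trainer.model.eval()

# Optionally convert all .pt checkpoints in a directory to .safetensors:
convert_pt_to_safetensors('saved_models')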