reflectorch-1.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
reflectorch/runs/slurm_utils.py ADDED
@@ -0,0 +1,93 @@
+ from typing import Tuple, Union
+ from pathlib import Path
+ import subprocess
+
+ from reflectorch.paths import RUN_SCRIPTS_DIR
+
+
+ def save_sbatch_and_run(
+         name: str,
+         args: str,
+         time: str,
+         partition: str = None,
+         reservation: bool = False,
+         chdir: str = '~/maxwell_output',
+         run_dir: Path = None,
+         confirm: bool = False,
+ ) -> Union[Tuple[str, str], None]:
+     run_dir = Path(run_dir) if run_dir else RUN_SCRIPTS_DIR
+     sbatch_path = run_dir / f'{name}.sh'
+
+     if sbatch_path.is_file():
+         import warnings
+         warnings.warn(f'Sbatch file {str(sbatch_path)} already exists!')
+         if confirm and not confirm_input('Continue?'):
+             return
+
+     file_content = _generate_sbatch_str(
+         name,
+         args,
+         time=time,
+         reservation=reservation,
+         partition=partition,
+         chdir=chdir,
+     )
+
+     if confirm:
+         print(f'Generated file content: \n{file_content}\n')
+         if not confirm_input(f'Save to {str(sbatch_path)} and run?'):
+             return
+
+     with open(str(sbatch_path), 'w') as f:
+         f.write(file_content)
+
+     res = submit_job(str(sbatch_path))
+     return res
+
+
+ def _generate_sbatch_str(name: str,
+                          args: str,
+                          time: str,
+                          partition: str = None,
+                          reservation: bool = False,
+                          chdir: str = '~/maxwell_output',
+                          entry_point: str = 'python -m reflectorch.train',
+                          ) -> str:
+     chdir = str(Path(chdir).expanduser().absolute())
+     partition_keyword = 'reservation' if reservation else 'partition'
+
+     return f'''#!/bin/bash
+ #SBATCH --chdir {chdir}
+ #SBATCH --{partition_keyword}={partition}
+ #SBATCH --constraint=P100
+ #SBATCH --nodes=1
+ #SBATCH --job-name {name}
+ #SBATCH --time={time}
+ #SBATCH --output {name}.out
+ #SBATCH --error {name}.err
+
+ {entry_point} {args}
+ '''
+
+
+ def confirm_input(message: str) -> bool:
+     yes = ('y', 'yes')
+     no = ('n', 'no')
+     res = ''
+     valid_results = list(yes) + list(no)
+     message = f'{message} Y/n: '
+
+     while res not in valid_results:
+         res = input(message).lower()
+     return res in yes
+
+
+ def submit_job(sbatch_path: str) -> Tuple[str, str]:
+     process = subprocess.Popen(
+         ['sbatch', str(sbatch_path)],
+         stdout=subprocess.PIPE,
+         stderr=subprocess.PIPE,
+     )
+
+     stdout, stderr = process.communicate()
+     return stdout.decode(), stderr.decode()
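
A minimal usage sketch of the helpers above (not part of the package diff): the config name, partition, and time limit are placeholder values, and submission requires a SLURM cluster with `sbatch` on the PATH.

    from reflectorch.runs.slurm_utils import save_sbatch_and_run

    # Write '<RUN_SCRIPTS_DIR>/my_config.sh' and submit it via `sbatch`.
    # 'my_config' is a hypothetical config name; the partition is cluster-specific.
    res = save_sbatch_and_run(
        name='my_config',       # job name, also used for the .sh/.out/.err file names
        args='my_config',       # appended to `python -m reflectorch.train`
        time='1-00:00:00',      # SLURM time limit (here: one day)
        partition='allgpu',     # placeholder partition name
        confirm=False,          # set True to review the generated script first
    )
    if res is not None:         # None means the user aborted at a confirmation prompt
        stdout, stderr = res
        print(stderr or stdout)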
reflectorch/runs/train.py ADDED
@@ -0,0 +1,78 @@
+ import click
+
+ from reflectorch.runs.slurm_utils import save_sbatch_and_run
+ from reflectorch.runs.utils import train_from_config
+ from reflectorch.runs.config import load_config
+
+ __all__ = [
+     'run_train',
+     'run_train_on_cluster',
+     'run_test_config',
+ ]
+
+
+ @click.command()
+ @click.argument('config_name', type=str)
+ def run_train(config_name: str):
+     """Runs the training from the command line interface.
+     Example: python -m reflectorch.train 'conf_name'
+
+     Args:
+         config_name (str): name of the YAML configuration file
+     """
+     config = load_config(config_name)
+     train_from_config(config)
+
+
+ @click.command()
+ @click.argument('config_name', type=str)
+ @click.argument('batch_size', type=int, default=512)
+ @click.argument('num_iterations', type=int, default=10)
+ def run_test_config(config_name: str, batch_size: int, num_iterations: int):
+     """Runs a short training for the purpose of testing the configuration file.
+     Example: python -m reflectorch.test_config 'conf_name.yaml' 512 10
+
+     Args:
+         config_name (str): name of the YAML configuration file
+         batch_size (int): overwrites the batch size in the configuration file
+         num_iterations (int): overwrites the number of iterations in the configuration file
+     """
+     config = load_config(config_name)
+     config = _change_to_test_config(config, batch_size=batch_size, num_iterations=num_iterations)
+     train_from_config(config)
+
+
+ @click.command()
+ @click.argument('config_name')
+ def run_train_on_cluster(config_name: str):
+     config = load_config(config_name)
+     name = config['general']['name']
+     slurm_conf = config['slurm']
+
+     res = save_sbatch_and_run(
+         name,
+         config_name,
+         time=slurm_conf['time'],
+         partition=slurm_conf['partition'],
+         reservation=slurm_conf.get('reservation', False),
+         chdir=slurm_conf.get('chdir', '~/maxwell_output'),
+         run_dir=slurm_conf.get('run_dir', None),
+         confirm=slurm_conf.get('confirm', True),
+     )
+     if not res:
+         print('Aborted.')
+         return
+     out, err = res
+
+     if err:
+         print('Error occurred: ', err)
+     else:
+         print('Success!', out)
+
+
+ def _change_to_test_config(config, batch_size: int, num_iterations: int):
+     config = dict(config)
+     config['training']['num_iterations'] = num_iterations
+     config['training']['batch_size'] = batch_size
+     config['training']['update_tqdm_freq'] = 1
+     return config
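
Besides the `python -m reflectorch.train` / `python -m reflectorch.test_config` invocations shown in the docstrings, these click commands can also be driven programmatically; a sketch, assuming a config named 'my_config' exists:

    from reflectorch.runs.train import run_test_config

    # Equivalent to: python -m reflectorch.test_config my_config 64 5
    # standalone_mode=False makes click raise exceptions instead of calling sys.exit().
    run_test_config(args=['my_config', '64', '5'], standalone_mode=False)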
reflectorch/runs/utils.py ADDED
@@ -0,0 +1,405 @@
+ from pathlib import Path
+ from typing import Tuple
+
+ import torch
+ import safetensors.torch
+ import os
+
+ from reflectorch import *
+ from reflectorch.runs.config import load_config
+
+ __all__ = [
+     "train_from_config",
+     "get_trainer_from_config",
+     "get_paths_from_config",
+     "get_callbacks_from_config",
+     "get_trainer_by_name",
+     "get_callbacks_by_name",
+     "convert_pt_to_safetensors",
+ ]
+
+
+ def init_from_conf(conf, **kwargs):
+     """Initializes an object with the class type, args and kwargs specified in the configuration
+
+     Args:
+         conf (dict): configuration dictionary
+
+     Returns:
+         Any: the initialized object
+     """
+     if not conf:
+         return
+     cls_name = conf['cls']
+     if not cls_name:
+         return
+     cls = globals().get(cls_name)
+
+     if not cls:
+         raise ValueError(f'Unknown class {cls_name}')
+
+     conf_args = conf.get('args', [])
+     conf_kwargs = conf.get('kwargs', {})
+     conf_kwargs.update(kwargs)
+     return cls(*conf_args, **conf_kwargs)
+
+
+ def train_from_config(config: dict):
+     """Trains a model from a configuration dictionary
+
+     Args:
+         config (dict): configuration dictionary
+
+     Returns:
+         Trainer: the trainer object
+     """
+
+     folder_paths = get_paths_from_config(config, mkdir=True)
+
+     trainer = get_trainer_from_config(config, folder_paths)
+
+     callbacks = get_callbacks_from_config(config, folder_paths)
+
+     trainer.train(
+         config['training']['num_iterations'],
+         callbacks, disable_tqdm=False,
+         update_tqdm_freq=config['training']['update_tqdm_freq'],
+         grad_accumulation_steps=config['training'].get('grad_accumulation_steps', 1)
+     )
+
+     torch.save({
+         'paths': folder_paths,
+         'losses': trainer.losses,
+         'params': config,
+     }, folder_paths['losses'])
+
+     return trainer
+
+
+ def get_paths_from_config(config: dict, mkdir: bool = False):
+     """Gets the directory paths from a configuration dictionary
+
+     Args:
+         config (dict): configuration dictionary
+         mkdir (bool, optional): option to create a new directory for the saved model weights and losses.
+
+     Returns:
+         dict: dictionary containing the folder paths
+     """
+     root_dir = Path(config['general']['root_dir'] or ROOT_DIR)
+     name = config['general']['name']
+
+     assert root_dir.is_dir()
+
+     saved_models_dir = root_dir / 'saved_models'
+     saved_losses_dir = root_dir / 'saved_losses'
+
+     if mkdir:
+         saved_models_dir.mkdir(exist_ok=True)
+         saved_losses_dir.mkdir(exist_ok=True)
+
+     model_path = str((saved_models_dir / f'model_{name}.pt').absolute())
+
+     losses_path = saved_losses_dir / f'{name}_losses.pt'
+
+     return {
+         'name': name,
+         'model': model_path,
+         'losses': losses_path,
+         'root': root_dir,
+         'saved_models': saved_models_dir,
+         'saved_losses': saved_losses_dir,
+     }
+
+
+ def get_callbacks_from_config(config: dict, folder_paths: dict = None) -> Tuple['TrainerCallback', ...]:
+     """Initializes the training callbacks from a configuration dictionary
+
+     Returns:
+         tuple: tuple of callbacks
+     """
+     callbacks = []
+
+     folder_paths = folder_paths or get_paths_from_config(config)
+
+     train_conf = config['training']
+     callback_conf = dict(train_conf['callbacks'])
+     save_conf = callback_conf.pop('save_best_model')
+
+     if save_conf['enable']:
+         save_model = SaveBestModel(folder_paths['model'], freq=save_conf['freq'])
+         callbacks.append(save_model)
+
+     for conf in callback_conf.values():
+         callback = init_from_conf(conf)
+
+         if callback:
+             callbacks.append(callback)
+
+     if 'logger' in train_conf.keys():
+         callbacks.append(LogLosses())
+
+     return tuple(callbacks)
+
+
+ def get_trainer_from_config(config: dict, folder_paths: dict = None):
+     """Initializes a trainer from a configuration dictionary
+
+     Args:
+         config (dict): the configuration dictionary
+         folder_paths (dict, optional): dictionary containing the folder paths
+
+     Returns:
+         Trainer: the trainer object
+     """
+     dset = init_dset(config['dset'])
+
+     folder_paths = folder_paths or get_paths_from_config(config)
+
+     model = init_network(config['model']['network'], folder_paths['saved_models'])
+
+     train_conf = config['training']
+
+     optim_cls = getattr(torch.optim, train_conf['optimizer'])
+
+     logger_conf = train_conf.get('logger', None)
+     logger = init_from_conf(logger_conf) if logger_conf and logger_conf.get('cls') else None
+
+     clip_grad_norm_max = train_conf.get('clip_grad_norm_max', None)
+
+     trainer_cls = globals().get(train_conf['trainer_cls']) if 'trainer_cls' in train_conf else PointEstimatorTrainer
+     trainer_kwargs = train_conf.get('trainer_kwargs', {})
+
+
+     trainer = trainer_cls(
+         model, dset, train_conf['lr'], train_conf['batch_size'], clip_grad_norm_max=clip_grad_norm_max,
+         logger=logger, optim_cls=optim_cls,
+         **trainer_kwargs
+     )
+
+     if train_conf.get('train_with_q_input', False) and getattr(trainer, 'train_with_q_input', None) is not None:  # only for backward compatibility with configs from older versions
+         trainer.train_with_q_input = True
+
+     return trainer
+
+
+ def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weights: bool = True, inference_device: str = 'cuda'):
+     """Initializes a trainer object based on a configuration file (i.e. the model name) and optionally loads \
+     saved weights into the network
+
+     Args:
+         config_name (str): name of the configuration file
+         config_dir (str): path of the configuration directory
+         model_path (str, optional): path to the network weights. The default path is 'saved_models' located in the package directory
+         load_weights (bool, optional): if True, the saved network weights are loaded into the network. Defaults to True.
+         inference_device (str, optional): overwrites the device in the configuration file, allowing inference on a different device than the one training was performed on. Defaults to 'cuda'.
+
+     Returns:
+         Trainer: the trainer object
+     """
+     config = load_config(config_name, config_dir)
+     #config['model']['network']['pretrained_name'] = None
+
+     config['model']['network']['device'] = inference_device
+     config['dset']['prior_sampler']['kwargs']['device'] = inference_device
+     config['dset']['q_generator']['kwargs']['device'] = inference_device
+
+     trainer = get_trainer_from_config(config)
+
+     num_params = sum(p.numel() for p in trainer.model.parameters())
+
+     print(
+         f'Model {config_name} loaded. Number of parameters: {num_params / 10 ** 6:.2f} M',
+     )
+
+     if not load_weights:
+         return trainer
+
+     if not model_path:
+         model_name = f'model_{config_name}.pt'
+         model_path = SAVED_MODELS_DIR / model_name
+
+     if str(model_path).endswith('.pt'):
+         try:
+             state_dict = torch.load(model_path, map_location=inference_device, weights_only=False)
+         except Exception as err:
+             raise RuntimeError(f'Could not load model from {model_path}') from err
+
+         if 'model' in state_dict:
+             trainer.model.load_state_dict(state_dict['model'])
+         else:
+             trainer.model.load_state_dict(state_dict)
+
+     elif str(model_path).endswith('.safetensors'):
+         try:
+             load_state_dict_safetensors(model=trainer.model, filename=model_path, device=inference_device)
+         except Exception as err:
+             raise RuntimeError(f'Could not load model from {model_path}') from err
+
+     else:
+         raise RuntimeError('Weights file with unknown extension')
+
+     return trainer
+
+ def get_callbacks_by_name(config_name, config_dir=None):
+     """Initializes the trainer callbacks based on a configuration file
+
+     Args:
+         config_name (str): name of the configuration file
+         config_dir (str): path of the configuration directory
+     """
+     config = load_config(config_name, config_dir)
+     callbacks = get_callbacks_from_config(config)
+
+     return callbacks
+
+
+ def init_network(config: dict, saved_models_dir: Path):
+     """Initializes the network based on the configuration dictionary and optionally loads the weights from a pretrained model"""
+     device = config.get('device', 'cuda')
+     network = init_from_conf(config).to(device)
+     network = load_pretrained(network, config.get('pretrained_name', None), saved_models_dir)
+     return network
+
+
+ def load_pretrained(model, model_name: str, saved_models_dir: Path):
+     """Loads the saved weights into the network"""
+     if not model_name:
+         return model
+
+     if '.' not in model_name:
+         model_name = model_name + '.pt'
+     model_path = saved_models_dir / model_name
+
+     if not model_path.is_file():
+         model_path = saved_models_dir / f'model_{model_name}'
+
+     if not model_path.is_file():
+         raise FileNotFoundError(f'File {str(model_path)} does not exist.')
+
+     try:
+         pretrained = torch.load(model_path, weights_only=False)
+     except Exception as err:
+         raise RuntimeError(f'Could not load model from {str(model_path)}') from err
+
+     if 'model' in pretrained:
+         pretrained = pretrained['model']
+     try:
+         model.load_state_dict(pretrained)
+     except Exception as err:
+         raise RuntimeError(f'Could not update state dict from {str(model_path)}') from err
+
+     return model
+
+
+ def init_dset(config: dict):
+     """Initializes the dataset / dataloader object"""
+     dset_cls = globals().get(config['cls']) if 'cls' in config else ReflectivityDataLoader
+     dset_kwargs = config.get('kwargs', {})
+
+     prior_sampler = init_from_conf(config['prior_sampler'])
+     intensity_noise = init_from_conf(config['intensity_noise'])
+     q_generator = init_from_conf(config['q_generator'])
+     curves_scaler = init_from_conf(config['curves_scaler']) if 'curves_scaler' in config else None
+     smearing = init_from_conf(config['smearing']) if 'smearing' in config else None
+     q_noise = init_from_conf(config['q_noise']) if 'q_noise' in config else None
+
+     dset = dset_cls(
+         q_generator=q_generator,
+         prior_sampler=prior_sampler,
+         intensity_noise=intensity_noise,
+         curves_scaler=curves_scaler,
+         smearing=smearing,
+         q_noise=q_noise,
+         **dset_kwargs,
+     )
+
+     return dset
+
+ def split_complex_tensors(state_dict):
+     new_state_dict = {}
+     for key, tensor in state_dict.items():
+         if tensor.is_complex():
+             new_state_dict[f"{key}_real"] = tensor.real.clone()
+             new_state_dict[f"{key}_imag"] = tensor.imag.clone()
+         else:
+             new_state_dict[key] = tensor
+     return new_state_dict
+
+ def recombine_complex_tensors(state_dict):
+     new_state_dict = {}
+     keys = list(state_dict.keys())
+     visited = set()
+
+     for key in keys:
+         if key.endswith('_real') or key.endswith('_imag'):
+             base_key = key[:-5]
+             new_state_dict[base_key] = torch.complex(state_dict[base_key + '_real'], state_dict[base_key + '_imag'])
+             visited.add(base_key + '_real')
+             visited.add(base_key + '_imag')
+         elif key not in visited:
+             new_state_dict[key] = state_dict[key]
+
+     return new_state_dict
+
+ def convert_pt_to_safetensors(input_dir):
+     """Creates '.safetensors' files for all the model state dictionaries inside '.pt' files in the specified directory.
+
+     Args:
+         input_dir (str): directory containing model weights
+     """
+     if not os.path.isdir(input_dir):
+         raise ValueError(f"Input directory {input_dir} does not exist")
+
+     for file_name in os.listdir(input_dir):
+         if file_name.endswith('.pt'):
+             pt_file_path = os.path.join(input_dir, file_name)
+             safetensors_file_path = os.path.join(input_dir, file_name[:-3] + '.safetensors')
+
+             if os.path.exists(safetensors_file_path):
+                 print(f"Skipping {pt_file_path}, corresponding .safetensors file already exists.")
+                 continue
+
+             print(f"Converting {pt_file_path} to .safetensors format.")
+             data_pt = torch.load(pt_file_path, weights_only=False)
+             model_state_dict = data_pt["model"]
+             model_state_dict = split_complex_tensors(model_state_dict)  # handle tensors with complex dtype, which are not natively supported by safetensors
+
+             safetensors.torch.save_file(tensors=model_state_dict, filename=safetensors_file_path)
+
+ def convert_files_to_safetensors(files):
+     """
+     Converts specified .pt files to .safetensors format.
+
+     Args:
+         files (str or list of str): Path(s) to .pt files containing model state dictionaries.
+     """
+     if isinstance(files, str):
+         files = [files]
+
+     for pt_file_path in files:
+         if not pt_file_path.endswith('.pt'):
+             print(f"Skipping {pt_file_path}: not a .pt file.")
+             continue
+
+         if not os.path.exists(pt_file_path):
+             print(f"File {pt_file_path} does not exist.")
+             continue
+
+         safetensors_file_path = pt_file_path[:-3] + '.safetensors'
+
+         if os.path.exists(safetensors_file_path):
+             print(f"Skipping {pt_file_path}: .safetensors version already exists.")
+             continue
+
+         print(f"Converting {pt_file_path} to .safetensors format.")
+         data_pt = torch.load(pt_file_path, weights_only=False)
+         model_state_dict = data_pt["model"]
+         model_state_dict = split_complex_tensors(model_state_dict)
+
+         safetensors.torch.save_file(tensors=model_state_dict, filename=safetensors_file_path)
+
+ def load_state_dict_safetensors(model, filename, device):
+     state_dict = safetensors.torch.load_file(filename=filename, device=device)
+     state_dict = recombine_complex_tensors(state_dict)
+     model.load_state_dict(state_dict)
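
Taken together, these helpers cover the typical inference workflow; a sketch (the config name and directory below are placeholders, not values shipped with the package):

    from reflectorch.runs.utils import get_trainer_by_name, convert_pt_to_safetensors

    # Rebuild the trainer described by the config 'my_config' and load its saved
    # weights, overriding the training device so inference can run on CPU.
    trainer = get_trainer_by_name('my_config', inference_device='cpu')
    model = trainer.model.eval()

    # Mirror every '.pt' checkpoint in a directory as '.safetensors'. Complex
    # tensors are split into real/imaginary parts first (and recombined on load
    # by load_state_dict_safetensors), since safetensors has no complex dtype.
    convert_pt_to_safetensors('saved_models')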
reflectorch/test_config.py ADDED
@@ -0,0 +1,4 @@
+ from reflectorch.runs import run_test_config
+
+ if __name__ == '__main__':
+     run_test_config()
reflectorch/train.py ADDED
@@ -0,0 +1,4 @@
+ from reflectorch.runs import run_train
+
+ if __name__ == '__main__':
+     run_train()
reflectorch/train_on_cluster.py ADDED
@@ -0,0 +1,4 @@
+ from reflectorch.runs import run_train_on_cluster
+
+ if __name__ == '__main__':
+     run_train_on_cluster()
reflectorch/utils.py ADDED
@@ -0,0 +1,98 @@
1
+ import numpy as np
2
+ from numpy import ndarray
3
+
4
+ from torch import Tensor, tensor
5
+
6
+ __all__ = [
7
+ 'to_np',
8
+ 'to_t',
9
+ 'angle_to_q',
10
+ 'q_to_angle',
11
+ 'energy_to_wavelength',
12
+ 'wavelength_to_energy',
13
+ ]
14
+
15
+
16
+ def to_np(arr):
17
+ """Converts Pytorch tensor or Python list to Numpy array
18
+
19
+ Args:
20
+ arr (torch.Tensor or list): Input Pytorch tensor or Python list
21
+
22
+ Returns:
23
+ numpy.ndarray: Converted Numpy array
24
+ """
25
+
26
+ if isinstance(arr, Tensor):
27
+ return arr.detach().cpu().numpy()
28
+ return np.asarray(arr)
29
+
30
+
31
+ def to_t(arr, device=None, dtype=None):
32
+ """Converts Numpy array or Python list to Pytorch tensor
33
+
34
+ Args:
35
+ arr (numpy.ndarray or list): Input
36
+ device (torch.device or str, optional): device for the tensor ('cpu', 'cuda')
37
+ dtype (torch.dtype, optional): data type of the tensor (e.g. torch.float32)
38
+
39
+ Returns:
40
+ torch.Tensor: converted Pytorch tensor
41
+ """
42
+
43
+ if not isinstance(arr, Tensor):
44
+ return tensor(arr, device=device, dtype=dtype)
45
+ return arr
46
+
47
+
48
+ # taken from mlreflect package
49
+ # mlreflect/xrrloader/dataloader/transform.py
50
+
51
+ def angle_to_q(scattering_angle: ndarray or float, wavelength: float):
52
+ """Conversion from full scattering angle (degrees) to scattering vector (inverse angstroms)"""
53
+ return 4 * np.pi / wavelength * np.sin(scattering_angle / 2 * np.pi / 180)
54
+
55
+
56
+ def q_to_angle(q: ndarray or float, wavelength: float):
57
+ """Conversion from scattering vector (inverse angstroms) to full scattering angle (degrees)"""
58
+ return 2 * np.arcsin(q * wavelength / (4 * np.pi)) / np.pi * 180
59
+
60
+
61
+ def energy_to_wavelength(energy: float):
62
+ """Conversion from photon energy (eV) to photon wavelength (angstroms)"""
63
+ return 1.2398 / energy * 1e4
64
+
65
+
66
+ def wavelength_to_energy(wavelength: float):
67
+ """Conversion from photon wavelength (angstroms) to photon energy (eV)"""
68
+ return 1.2398 / wavelength * 1e4
69
+
70
+ def get_filtering_mask(Q, R, dR, threshold=0.3, consecutive=3,
71
+ remove_singles=True, remove_consecutives=True,
72
+ q_start_trunc=0.1):
73
+ Q, R, dR = Q.copy(), R.copy(), dR.copy()
74
+ rel_error = np.abs(dR / R)
75
+
76
+ # Mask for singles
77
+ mask_singles = (rel_error >= threshold) if remove_singles else np.zeros_like(Q, dtype=bool)
78
+
79
+ # Mask for truncation
80
+ mask_consecutive = np.zeros_like(Q, dtype=bool)
81
+ if remove_consecutives:
82
+ count = 0
83
+ cutoff_idx = None
84
+ for i in range(len(Q)):
85
+ if Q[i] < q_start_trunc:
86
+ continue
87
+ if rel_error[i] >= threshold:
88
+ count += 1
89
+ if count >= consecutive:
90
+ cutoff_idx = i - consecutive + 1
91
+ break
92
+ else:
93
+ count = 0
94
+ if cutoff_idx is not None:
95
+ mask_consecutive[cutoff_idx:] = True
96
+
97
+ final_mask = mask_singles | mask_consecutive
98
+ return ~final_mask