reflectorch 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/__init__.py +17 -0
- reflectorch/data_generation/__init__.py +128 -0
- reflectorch/data_generation/dataset.py +216 -0
- reflectorch/data_generation/likelihoods.py +80 -0
- reflectorch/data_generation/noise.py +471 -0
- reflectorch/data_generation/priors/__init__.py +60 -0
- reflectorch/data_generation/priors/base.py +55 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
- reflectorch/data_generation/priors/independent_priors.py +195 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +104 -0
- reflectorch/data_generation/priors/no_constraints.py +206 -0
- reflectorch/data_generation/priors/parametric_models.py +842 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
- reflectorch/data_generation/priors/params.py +252 -0
- reflectorch/data_generation/priors/sampler_strategies.py +370 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +371 -0
- reflectorch/data_generation/priors/utils.py +118 -0
- reflectorch/data_generation/process_data.py +41 -0
- reflectorch/data_generation/q_generator.py +280 -0
- reflectorch/data_generation/reflectivity/__init__.py +102 -0
- reflectorch/data_generation/reflectivity/abeles.py +97 -0
- reflectorch/data_generation/reflectivity/kinematical.py +71 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +138 -0
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/scale_curves.py +112 -0
- reflectorch/data_generation/smearing.py +99 -0
- reflectorch/data_generation/utils.py +223 -0
- reflectorch/extensions/__init__.py +0 -0
- reflectorch/extensions/jupyter/__init__.py +11 -0
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -0
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -0
- reflectorch/extensions/matplotlib/losses.py +32 -0
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/__init__.py +28 -0
- reflectorch/inference/inference_model.py +848 -0
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +55 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/plotting.py +524 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +19 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/query_matcher.py +82 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +364 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +32 -0
- reflectorch/ml/basic_trainer.py +292 -0
- reflectorch/ml/callbacks.py +81 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +56 -0
- reflectorch/ml/schedulers.py +356 -0
- reflectorch/ml/trainers.py +201 -0
- reflectorch/ml/utils.py +2 -0
- reflectorch/models/__init__.py +16 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +19 -0
- reflectorch/models/encoders/conv_encoder.py +219 -0
- reflectorch/models/encoders/conv_res_net.py +115 -0
- reflectorch/models/encoders/fno.py +134 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -0
- reflectorch/models/networks/mlp_networks.py +434 -0
- reflectorch/models/networks/residual_net.py +157 -0
- reflectorch/paths.py +29 -0
- reflectorch/runs/__init__.py +31 -0
- reflectorch/runs/config.py +25 -0
- reflectorch/runs/slurm_utils.py +93 -0
- reflectorch/runs/train.py +78 -0
- reflectorch/runs/utils.py +405 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +98 -0
- reflectorch-1.5.1.dist-info/METADATA +151 -0
- reflectorch-1.5.1.dist-info/RECORD +96 -0
- reflectorch-1.5.1.dist-info/WHEEL +5 -0
- reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
- reflectorch-1.5.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from typing import Tuple, Union
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
import subprocess
|
|
4
|
+
|
|
5
|
+
from reflectorch.paths import RUN_SCRIPTS_DIR
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def save_sbatch_and_run(
        name: str,
        args: str,
        time: str,
        partition: str = None,
        reservation: bool = False,
        chdir: str = '~/maxwell_output',
        run_dir: Path = None,
        confirm: bool = False,
) -> Union[Tuple[str, str], None]:
    """Write a SLURM batch script for a training run and submit it with ``sbatch``.

    Args:
        name: job name; also used for the script file name (``{name}.sh``).
        args: arguments appended to the training entry point in the script.
        time: SLURM time limit (``#SBATCH --time``).
        partition: SLURM partition (or reservation name if ``reservation`` is True).
        reservation: if True, ``partition`` is written as ``--reservation``.
        chdir: working directory of the job.
        run_dir: directory where the script is saved; defaults to ``RUN_SCRIPTS_DIR``.
            It is created if it does not exist yet.
        confirm: if True, ask interactively before continuing when the script
            already exists, and before saving/submitting it.

    Returns:
        ``(stdout, stderr)`` of the ``sbatch`` call, or ``None`` if the user aborted.
    """
    run_dir = Path(run_dir) if run_dir else Path(RUN_SCRIPTS_DIR)
    sbatch_path = run_dir / f'{name}.sh'

    if sbatch_path.is_file():
        import warnings
        warnings.warn(f'Sbatch file {str(sbatch_path)} already exists!')
        if confirm and not confirm_input('Continue?'):
            return None

    file_content = _generate_sbatch_str(
        name,
        args,
        time=time,
        reservation=reservation,
        partition=partition,
        chdir=chdir,
    )

    if confirm:
        print(f'Generated file content: \n{file_content}\n')
        if not confirm_input(f'Save to {str(sbatch_path)} and run?'):
            return None

    # The script directory may not exist yet (fresh setup or a custom run_dir);
    # without this, writing the script fails with FileNotFoundError.
    run_dir.mkdir(parents=True, exist_ok=True)
    with open(str(sbatch_path), 'w') as f:
        f.write(file_content)

    return submit_job(str(sbatch_path))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _generate_sbatch_str(name: str,
                         args: str,
                         time: str,
                         partition: str = None,
                         reservation: bool = False,
                         chdir: str = '~/maxwell_output',
                         entry_point: str = 'python -m reflectorch.train',
                         constraint: str = 'P100',
                         ) -> str:
    """Render the text of a SLURM batch script that runs ``{entry_point} {args}``.

    Args:
        name: job name; also used for the ``.out`` / ``.err`` log file names.
        args: arguments appended to ``entry_point``.
        time: SLURM time limit (``--time``).
        partition: partition (or reservation) name. If None, no partition /
            reservation directive is written, so SLURM uses its default
            (instead of the invalid ``--partition=None``).
        reservation: if True, ``partition`` is passed as ``--reservation``.
        chdir: working directory of the job; ``~`` is expanded and the path
            is made absolute.
        entry_point: command that starts the training.
        constraint: value for ``--constraint``; falsy values omit the directive.

    Returns:
        The script content, terminated by a newline.
    """
    chdir = str(Path(chdir).expanduser().absolute())

    lines = [
        '#!/bin/bash',
        f'#SBATCH --chdir {chdir}',
    ]
    if partition is not None:
        partition_keyword = 'reservation' if reservation else 'partition'
        lines.append(f'#SBATCH --{partition_keyword}={partition}')
    if constraint:
        lines.append(f'#SBATCH --constraint={constraint}')
    lines += [
        '#SBATCH --nodes=1',
        f'#SBATCH --job-name {name}',
        f'#SBATCH --time={time}',
        f'#SBATCH --output {name}.out',
        f'#SBATCH --error {name}.err',
        '',
        f'{entry_point} {args}',
        '',  # trailing newline
    ]
    return '\n'.join(lines)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def confirm_input(message: str, input_fn=None) -> bool:
    """Ask a yes/no question until a valid answer is given.

    Accepted answers are ``y``/``yes`` and ``n``/``no``. They are case-insensitive,
    and surrounding whitespace is ignored. Anything else, including an empty line,
    repeats the question. There is no default answer.

    Args:
        message: question to display; ``' [y/n]: '`` is appended.
        input_fn: callable reading one line given a prompt; defaults to the
            built-in ``input``. Useful for non-interactive use.

    Returns:
        True for a "yes" answer, False for a "no" answer.
    """
    read = input if input_fn is None else input_fn
    yes = ('y', 'yes')
    no = ('n', 'no')
    prompt = f'{message} [y/n]: '

    while True:
        # strip() so stray spaces or a trailing '\r' do not cause endless re-prompting
        answer = read(prompt).strip().lower()
        if answer in yes:
            return True
        if answer in no:
            return False
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def submit_job(sbatch_path: str) -> Tuple[str, str]:
    """Submit a batch script with ``sbatch`` and return its decoded output.

    Args:
        sbatch_path: path of the script to submit.

    Returns:
        ``(stdout, stderr)`` of the ``sbatch`` process as strings.
    """
    completed = subprocess.run(
        ['sbatch', str(sbatch_path)],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    out_text = completed.stdout.decode()
    err_text = completed.stderr.decode()
    return out_text, err_text
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import click
|
|
2
|
+
|
|
3
|
+
from reflectorch.runs.slurm_utils import save_sbatch_and_run
|
|
4
|
+
from reflectorch.runs.utils import train_from_config
|
|
5
|
+
from reflectorch.runs.config import load_config
|
|
6
|
+
|
|
7
|
+
# Click commands exported by this module (the names exported by a star import).
__all__ = [
    'run_train',
    'run_train_on_cluster',
    'run_test_config',
]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@click.command()
@click.argument('config_name', type=str)
def run_train(config_name: str):
    """Runs the training from the command line interface

    Example: python -m reflectorch.train 'conf_name'

    Args:
        config_name (str): name of the YAML configuration file
    """
    train_from_config(load_config(config_name))
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@click.command()
@click.argument('config_name', type=str)
@click.argument('batch_size', type=int, default=512)
@click.argument('num_iterations', type=int, default=10)
def run_test_config(config_name: str, batch_size: int, num_iterations: int):
    """Run a short training to check that a configuration file works.

    Example: python -m reflectorch.test_config 'conf_name.yaml' 512 10

    Args:
        config_name (str): name of the YAML configuration file
        batch_size (int): overwrites the batch size in the configuration file
        num_iterations (int): overwrites the number of iterations in the configuration file
    """
    base_config = load_config(config_name)
    test_config = _change_to_test_config(
        base_config,
        batch_size=batch_size,
        num_iterations=num_iterations,
    )
    train_from_config(test_config)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@click.command()
@click.argument('config_name')
def run_train_on_cluster(config_name: str):
    """Submit the training for a configuration file as a SLURM job.

    Uses the ``general.name`` entry and the ``slurm`` section of the
    configuration to write a batch script, submits it with ``sbatch``
    and prints the resulting output or error.

    Args:
        config_name (str): name of the YAML configuration file
    """
    config = load_config(config_name)
    job_name = config['general']['name']
    slurm = config['slurm']

    result = save_sbatch_and_run(
        job_name,
        config_name,
        time=slurm['time'],
        partition=slurm['partition'],
        reservation=slurm.get('reservation', False),
        chdir=slurm.get('chdir', '~/maxwell_output'),
        run_dir=slurm.get('run_dir', None),
        confirm=slurm.get('confirm', True),
    )
    if not result:
        # the user declined one of the confirmation prompts
        print('Aborted.')
        return

    stdout_text, stderr_text = result
    if not stderr_text:
        print('Success!', stdout_text)
    else:
        print('Error occurred: ', stderr_text)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _change_to_test_config(config, batch_size: int, num_iterations: int):
    """Return a copy of ``config`` adjusted for a quick test run.

    Sets ``training.num_iterations`` and ``training.batch_size`` to the given
    values and ``training.update_tqdm_freq`` to 1. The caller's configuration,
    including its nested ``training`` dictionary, is left unchanged.

    Args:
        config (dict): configuration dictionary with a ``training`` section
        batch_size (int): batch size for the test run
        num_iterations (int): number of iterations for the test run

    Returns:
        dict: the modified copy of the configuration
    """
    config = dict(config)
    # dict(config) is only a shallow copy: copy the nested section as well,
    # otherwise the assignments below would modify the caller's dictionary.
    training = dict(config['training'])
    training['num_iterations'] = num_iterations
    training['batch_size'] = batch_size
    training['update_tqdm_freq'] = 1
    config['training'] = training
    return config
|
|
@@ -0,0 +1,405 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import Tuple
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import safetensors.torch
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
from reflectorch import *
|
|
9
|
+
from reflectorch.runs.config import load_config
|
|
10
|
+
|
|
11
|
+
# Public API of this module (the names exported by a star import).
# Helpers such as init_from_conf, init_network, init_dset and the safetensors
# utilities defined below are importable explicitly but are not listed here.
__all__ = [
    "train_from_config",
    "get_trainer_from_config",
    "get_paths_from_config",
    "get_callbacks_from_config",
    "get_trainer_by_name",
    "get_callbacks_by_name",
    "convert_pt_to_safetensors",
]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def init_from_conf(conf, **kwargs):
    """Initializes an object with class type, args and kwargs specified in the configuration

    The configuration entry has the form ``{'cls': 'ClassName', 'args': [...], 'kwargs': {...}}``.
    The class is looked up by name in this module's global namespace.

    Args:
        conf (dict): configuration dictionary. ``args`` and ``kwargs`` are optional;
            missing or null values are treated as empty.
        **kwargs: additional keyword arguments; they take precedence over ``conf['kwargs']``.

    Returns:
        Any: the initialized object, or None if ``conf`` or ``conf['cls']`` is empty

    Raises:
        ValueError: if no class with the name ``conf['cls']`` is known
    """
    if not conf:
        return None
    cls_name = conf['cls']
    if not cls_name:
        return None
    cls = globals().get(cls_name)

    if not cls:
        raise ValueError(f'Unknown class {cls_name}')

    conf_args = conf.get('args') or []
    # Build a new dict instead of updating conf['kwargs'] in place, so the
    # configuration itself (which is saved together with the losses) stays unchanged.
    conf_kwargs = {**(conf.get('kwargs') or {}), **kwargs}
    return cls(*conf_args, **conf_kwargs)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def train_from_config(config: dict):
    """Train a model from a configuration dictionary

    The output directories are created if needed. After training, the folder
    paths, the losses and the configuration are saved with ``torch.save`` to
    the ``losses`` path.

    Args:
        config (dict): configuration dictionary

    Returns:
        Trainer: the trainer object
    """
    paths = get_paths_from_config(config, mkdir=True)
    trainer = get_trainer_from_config(config, paths)
    callbacks = get_callbacks_from_config(config, paths)

    training_conf = config['training']
    trainer.train(
        training_conf['num_iterations'],
        callbacks,
        disable_tqdm=False,
        update_tqdm_freq=training_conf['update_tqdm_freq'],
        grad_accumulation_steps=training_conf.get('grad_accumulation_steps', 1),
    )

    summary = {
        'paths': paths,
        'losses': trainer.losses,
        'params': config,
    }
    torch.save(summary, paths['losses'])

    return trainer
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def get_paths_from_config(config: dict, mkdir: bool = False):
    """Get the directory paths from a configuration dictionary

    Args:
        config (dict): configuration dictionary. ``general.name`` is required;
            ``general.root_dir`` is optional and falls back to ``ROOT_DIR``.
        mkdir (bool, optional): option to create a new directory for the saved model weights and losses.

    Returns:
        dict: dictionary containing the folder paths (``model`` is an absolute path
        string, the other paths are ``Path`` objects)

    Raises:
        NotADirectoryError: if the root directory does not exist
    """
    general_conf = config['general']
    root_dir = Path(general_conf.get('root_dir') or ROOT_DIR)
    name = general_conf['name']

    # Explicit check instead of `assert`, which is removed when Python runs with -O.
    if not root_dir.is_dir():
        raise NotADirectoryError(f'Root directory {root_dir} does not exist or is not a directory')

    saved_models_dir = root_dir / 'saved_models'
    saved_losses_dir = root_dir / 'saved_losses'

    if mkdir:
        saved_models_dir.mkdir(exist_ok=True)
        saved_losses_dir.mkdir(exist_ok=True)

    model_path = str((saved_models_dir / f'model_{name}.pt').absolute())
    losses_path = saved_losses_dir / f'{name}_losses.pt'

    return {
        'name': name,
        'model': model_path,
        'losses': losses_path,
        'root': root_dir,
        'saved_models': saved_models_dir,
        'saved_losses': saved_losses_dir,
    }
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def get_callbacks_from_config(config: dict, folder_paths: dict = None) -> Tuple['TrainerCallback', ...]:
    """Initializes the training callbacks from a configuration dictionary

    ``training.callbacks.save_best_model`` is required. If its ``enable`` flag is
    set, a ``SaveBestModel`` callback writing to ``folder_paths['model']`` is
    created. Every other entry in ``training.callbacks`` is built with
    ``init_from_conf``, and entries yielding a falsy result are skipped. A
    ``LogLosses`` callback is appended when the training section has a
    ``logger`` key.

    Args:
        config (dict): configuration dictionary
        folder_paths (dict, optional): folder paths; computed from ``config`` if not given

    Returns:
        tuple: tuple of callbacks
    """
    paths = folder_paths or get_paths_from_config(config)

    training_conf = config['training']
    remaining_confs = dict(training_conf['callbacks'])
    save_best = remaining_confs.pop('save_best_model')

    callbacks = []
    if save_best['enable']:
        callbacks.append(SaveBestModel(paths['model'], freq=save_best['freq']))

    callbacks.extend(
        callback for callback in map(init_from_conf, remaining_confs.values()) if callback
    )

    if 'logger' in training_conf:
        callbacks.append(LogLosses())

    return tuple(callbacks)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def get_trainer_from_config(config: dict, folder_paths: dict = None):
    """Initializes a trainer from a configuration dictionary

    Args:
        config (dict): the configuration dictionary
        folder_paths (dict, optional): dictionary containing the folder paths

    Returns:
        Trainer: the trainer object

    Raises:
        ValueError: if ``training.trainer_cls`` names an unknown class
    """
    dset = init_dset(config['dset'])

    folder_paths = folder_paths or get_paths_from_config(config)

    model = init_network(config['model']['network'], folder_paths['saved_models'])

    train_conf = config['training']

    optim_cls = getattr(torch.optim, train_conf['optimizer'])

    logger_conf = train_conf.get('logger', None)
    logger = init_from_conf(logger_conf) if logger_conf and logger_conf.get('cls') else None

    clip_grad_norm_max = train_conf.get('clip_grad_norm_max', None)

    if 'trainer_cls' in train_conf:
        trainer_cls_name = train_conf['trainer_cls']
        trainer_cls = globals().get(trainer_cls_name)
        if trainer_cls is None:
            # Otherwise the failure would surface as "'NoneType' object is not callable".
            raise ValueError(f'Unknown trainer class {trainer_cls_name}')
    else:
        trainer_cls = PointEstimatorTrainer
    # A YAML entry `trainer_kwargs:` without a value is loaded as None.
    trainer_kwargs = train_conf.get('trainer_kwargs') or {}

    trainer = trainer_cls(
        model, dset, train_conf['lr'], train_conf['batch_size'], clip_grad_norm_max=clip_grad_norm_max,
        logger=logger, optim_cls=optim_cls,
        **trainer_kwargs
    )

    # Only for backward compatibility with configs from older versions.
    if train_conf.get('train_with_q_input', False) and getattr(trainer, 'train_with_q_input', None) is not None:
        trainer.train_with_q_input = True

    return trainer
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weights: bool = True, inference_device: str = 'cuda'):
    """Initializes a trainer object based on a configuration file (i.e. the model name) and optionally loads \
    saved weights into the network

    Args:
        config_name (str): name of the configuration file
        config_dir (str): path of the configuration directory
        model_path (str, optional): path to the network weights ('.pt' or '.safetensors'). Defaults to
            ``SAVED_MODELS_DIR / f'model_{config_name}.pt'``.
        load_weights (bool, optional): if True the saved network weights are loaded into the network. Defaults to True.
        inference_device (str, optional): overwrites the device in the configuration file for the purpose of inference on a different device than the training was performed on. Defaults to 'cuda'.

    Returns:
        Trainer: the trainer object

    Raises:
        RuntimeError: if the weights cannot be loaded or the weights file has an unsupported extension
    """
    config = load_config(config_name, config_dir)

    # Run the network and the data generators on the requested device. Missing or
    # empty `kwargs` sections of the dataset components are created first.
    config['model']['network']['device'] = inference_device
    for section in ('prior_sampler', 'q_generator'):
        section_conf = config['dset'][section]
        if section_conf.get('kwargs') is None:
            section_conf['kwargs'] = {}
        section_conf['kwargs']['device'] = inference_device

    trainer = get_trainer_from_config(config)

    num_params = sum(p.numel() for p in trainer.model.parameters())

    print(
        f'Model {config_name} loaded. Number of parameters: {num_params / 10 ** 6:.2f} M',
    )

    if not load_weights:
        return trainer

    if not model_path:
        model_name = f'model_{config_name}.pt'
        model_path = SAVED_MODELS_DIR / model_name

    if str(model_path).endswith('.pt'):
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
        try:
            state_dict = torch.load(model_path, map_location=inference_device, weights_only=False)
        except Exception as err:
            raise RuntimeError(f'Could not load model from {model_path}') from err

        # Training checkpoints keep the weights under 'model'; bare state dicts are used directly.
        if 'model' in state_dict:
            state_dict = state_dict['model']
        trainer.model.load_state_dict(state_dict)

    elif str(model_path).endswith('.safetensors'):
        try:
            load_state_dict_safetensors(model=trainer.model, filename=model_path, device=inference_device)
        except Exception as err:
            raise RuntimeError(f'Could not load model from {model_path}') from err

    else:
        raise RuntimeError(
            f'Weights file {model_path} has an unknown extension (expected .pt or .safetensors)'
        )

    return trainer
|
|
243
|
+
|
|
244
|
+
def get_callbacks_by_name(config_name, config_dir=None):
    """Initializes the trainer callbacks based on a configuration file

    Args:
        config_name (str): name of the configuration file
        config_dir (str): path of the configuration directory

    Returns:
        tuple: the callbacks built by ``get_callbacks_from_config``
    """
    return get_callbacks_from_config(load_config(config_name, config_dir))
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def init_network(config: dict, saved_models_dir: Path):
    """Initializes the network from its configuration and optionally loads the weights of a pretrained model.

    Args:
        config (dict): network configuration (``cls``/``args``/``kwargs`` as used by
            ``init_from_conf``), with the optional keys ``device`` (default ``'cuda'``)
            and ``pretrained_name``.
        saved_models_dir (Path): directory in which the pretrained weights are searched.

    Returns:
        the network, moved to ``device``

    Raises:
        ValueError: if the configuration does not name a network class
    """
    device = config.get('device', 'cuda')
    network = init_from_conf(config)
    if network is None:
        # init_from_conf returns None for an empty config / missing class name;
        # fail with a clear message instead of an AttributeError on `.to`.
        raise ValueError('The network configuration does not specify a class (`cls`).')
    network = network.to(device)
    network = load_pretrained(network, config.get('pretrained_name', None), saved_models_dir)
    return network
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def load_pretrained(model, model_name: str, saved_models_dir: Path):
    """Loads saved weights into the network.

    The file is looked up in ``saved_models_dir``, first as ``model_name`` and then
    as ``model_{model_name}``. If the name contains no '.', '.pt' is appended.
    Both training checkpoints (``{'model': state_dict, ...}``) and bare state dicts
    are accepted.

    Args:
        model: network receiving the weights
        model_name (str): name of the weights file; if empty, ``model`` is returned unchanged
        saved_models_dir (Path): directory containing the weights (str is also accepted)

    Returns:
        the same ``model`` object with the loaded weights

    Raises:
        FileNotFoundError: if no matching file exists
        RuntimeError: if the file cannot be loaded or does not match the network
    """
    if not model_name:
        return model

    saved_models_dir = Path(saved_models_dir)

    if '.' not in model_name:
        model_name = model_name + '.pt'
    model_path = saved_models_dir / model_name

    if not model_path.is_file():
        model_path = saved_models_dir / f'model_{model_name}'

    if not model_path.is_file():
        raise FileNotFoundError(f'File {str(model_path)} does not exist.')

    try:
        # Load onto the CPU: load_state_dict copies the values into the model's existing
        # parameters on their own device, so GPU-saved checkpoints also work on CPU-only machines.
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
        pretrained = torch.load(model_path, map_location='cpu', weights_only=False)
    except Exception as err:
        raise RuntimeError(f'Could not load model from {str(model_path)}') from err

    if isinstance(pretrained, dict) and 'model' in pretrained:
        pretrained = pretrained['model']
    try:
        model.load_state_dict(pretrained)
    except Exception as err:
        raise RuntimeError(f'Could not update state dict from {str(model_path)}') from err

    return model
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def init_dset(config: dict):
    """Initializes the dataset / dataloader object from the ``dset`` configuration.

    ``prior_sampler``, ``intensity_noise`` and ``q_generator`` entries are required.
    ``curves_scaler``, ``smearing`` and ``q_noise`` are optional (None if absent).
    The dataset class defaults to ``ReflectivityDataLoader``.

    Args:
        config (dict): dataset configuration

    Returns:
        the dataset object

    Raises:
        ValueError: if ``cls`` names an unknown dataset class
    """
    if 'cls' in config:
        dset_cls = globals().get(config['cls'])
        if dset_cls is None:
            # Otherwise the failure would surface as "'NoneType' object is not callable".
            raise ValueError(f"Unknown dataset class {config['cls']}")
    else:
        dset_cls = ReflectivityDataLoader
    # A YAML entry `kwargs:` without a value is loaded as None.
    dset_kwargs = config.get('kwargs') or {}

    prior_sampler = init_from_conf(config['prior_sampler'])
    intensity_noise = init_from_conf(config['intensity_noise'])
    q_generator = init_from_conf(config['q_generator'])
    # init_from_conf returns None for a missing (None) entry.
    curves_scaler = init_from_conf(config.get('curves_scaler'))
    smearing = init_from_conf(config.get('smearing'))
    q_noise = init_from_conf(config.get('q_noise'))

    dset = dset_cls(
        q_generator=q_generator,
        prior_sampler=prior_sampler,
        intensity_noise=intensity_noise,
        curves_scaler=curves_scaler,
        smearing=smearing,
        q_noise=q_noise,
        **dset_kwargs,
    )

    return dset
|
|
318
|
+
|
|
319
|
+
def split_complex_tensors(state_dict):
    """Replace each complex tensor of a state dict by its real and imaginary parts.

    safetensors does not natively support complex dtypes, so a complex entry
    ``key`` becomes two entries ``key_real`` and ``key_imag`` (cloned tensors).
    Other entries are passed through unchanged. A new dict is returned.
    """
    result = {}
    for name, value in state_dict.items():
        if not value.is_complex():
            result[name] = value
            continue
        result[name + "_real"] = value.real.clone()
        result[name + "_imag"] = value.imag.clone()
    return result
|
|
328
|
+
|
|
329
|
+
def recombine_complex_tensors(state_dict):
    """Inverse of ``split_complex_tensors``: merge ``key_real``/``key_imag`` pairs into ``key``.

    Only keys whose partner is present are merged into a complex tensor. An
    unpaired key that happens to end with ``_real`` or ``_imag`` is kept as it is.
    Entries keep the position of their first occurrence in ``state_dict``.

    Args:
        state_dict (dict): mapping of names to tensors

    Returns:
        dict: new state dict with complex tensors restored
    """
    suffixes = ('_real', '_imag')
    new_state_dict = {}

    for key, tensor in state_dict.items():
        if key.endswith(suffixes):
            base_key = key[:-5]  # both suffixes have 5 characters
            real_key, imag_key = base_key + '_real', base_key + '_imag'
            if real_key in state_dict and imag_key in state_dict:
                # Build the complex tensor once, when the first key of the pair is seen.
                if base_key not in new_state_dict:
                    new_state_dict[base_key] = torch.complex(state_dict[real_key], state_dict[imag_key])
                continue
        new_state_dict[key] = tensor

    return new_state_dict
|
|
344
|
+
|
|
345
|
+
def convert_pt_to_safetensors(input_dir):
    """Creates '.safetensors' files for all the model state dictionaries inside '.pt' files in the specified directory.

    Both training checkpoints (``{'model': state_dict, ...}``) and bare state dicts
    are supported. Existing '.safetensors' files are not overwritten. Files are
    processed in sorted order.

    NOTE: ``torch.load(..., weights_only=False)`` unpickles arbitrary objects;
    only convert files from trusted sources.

    Args:
        input_dir (str): directory containing model weights

    Raises:
        ValueError: if ``input_dir`` is not an existing directory
    """
    if not os.path.isdir(input_dir):
        raise ValueError(f"Input directory {input_dir} does not exist")

    for file_name in sorted(os.listdir(input_dir)):
        if not file_name.endswith('.pt'):
            continue

        pt_file_path = os.path.join(input_dir, file_name)
        safetensors_file_path = os.path.join(input_dir, file_name[:-3] + '.safetensors')

        if os.path.exists(safetensors_file_path):
            print(f"Skipping {pt_file_path}, corresponding .safetensors file already exists.")
            continue

        print(f"Converting {pt_file_path} to .safetensors format.")
        data_pt = torch.load(pt_file_path, weights_only=False)
        # Training checkpoints wrap the weights under 'model'; bare state dicts are used as-is.
        if isinstance(data_pt, dict) and 'model' in data_pt:
            model_state_dict = data_pt['model']
        else:
            model_state_dict = data_pt
        model_state_dict = split_complex_tensors(model_state_dict)  # complex dtypes are not natively supported by safetensors

        safetensors.torch.save_file(tensors=model_state_dict, filename=safetensors_file_path)
|
|
369
|
+
|
|
370
|
+
def convert_files_to_safetensors(files):
    """
    Converts specified .pt files to .safetensors format.

    The output file is written next to the input with the '.pt' extension
    replaced by '.safetensors'. Existing '.safetensors' files are not
    overwritten. Files that are missing or do not end in '.pt' are skipped
    with a message. Both training checkpoints (``{'model': state_dict, ...}``)
    and bare state dicts are supported.

    NOTE: ``torch.load(..., weights_only=False)`` unpickles arbitrary objects;
    only convert files from trusted sources.

    Args:
        files (str, os.PathLike or list of these): Path(s) to .pt files containing model state dictionaries.
    """
    if isinstance(files, (str, os.PathLike)):
        files = [files]

    for pt_file_path in files:
        pt_file_path = os.fspath(pt_file_path)

        if not pt_file_path.endswith('.pt'):
            print(f"Skipping {pt_file_path}: not a .pt file.")
            continue

        if not os.path.exists(pt_file_path):
            print(f"File {pt_file_path} does not exist.")
            continue

        safetensors_file_path = pt_file_path[:-3] + '.safetensors'

        if os.path.exists(safetensors_file_path):
            print(f"Skipping {pt_file_path}: .safetensors version already exists.")
            continue

        print(f"Converting {pt_file_path} to .safetensors format.")
        data_pt = torch.load(pt_file_path, weights_only=False)
        # Training checkpoints wrap the weights under 'model'; bare state dicts are used as-is.
        if isinstance(data_pt, dict) and 'model' in data_pt:
            model_state_dict = data_pt['model']
        else:
            model_state_dict = data_pt
        # Complex dtypes are not natively supported by safetensors.
        model_state_dict = split_complex_tensors(model_state_dict)

        safetensors.torch.save_file(tensors=model_state_dict, filename=safetensors_file_path)
|
|
401
|
+
|
|
402
|
+
def load_state_dict_safetensors(model, filename, device):
    """Load weights stored in the '.safetensors' format into ``model``.

    Complex tensors saved as ``*_real`` / ``*_imag`` pairs are merged back with
    ``recombine_complex_tensors`` before ``model.load_state_dict`` is called.

    Args:
        model: network receiving the weights
        filename: path to the '.safetensors' file
        device: device the tensors are loaded onto
    """
    raw_tensors = safetensors.torch.load_file(filename=filename, device=device)
    model.load_state_dict(recombine_complex_tensors(raw_tensors))
|
reflectorch/train.py
ADDED
reflectorch/utils.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from numpy import ndarray
|
|
3
|
+
|
|
4
|
+
from torch import Tensor, tensor
|
|
5
|
+
|
|
6
|
+
# Public helpers exported by a star import of this module.
# get_filtering_mask (defined below) is not listed and must be imported explicitly.
__all__ = [
    'to_np',
    'to_t',
    'angle_to_q',
    'q_to_angle',
    'energy_to_wavelength',
    'wavelength_to_energy',
]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def to_np(arr):
    """Converts a Pytorch tensor or any array-like (e.g. a Python list) to a Numpy array.

    Tensors are detached from the autograd graph and moved to the CPU first.
    Everything else goes through ``np.asarray``, so existing Numpy arrays are
    returned without a copy.

    Args:
        arr (torch.Tensor or array-like): input

    Returns:
        numpy.ndarray: converted Numpy array
    """
    if not isinstance(arr, Tensor):
        return np.asarray(arr)
    return arr.detach().cpu().numpy()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def to_t(arr, device=None, dtype=None):
    """Converts a Numpy array or Python list to a Pytorch tensor.

    Inputs that already are tensors are returned unchanged; ``device`` and
    ``dtype`` are only applied when a new tensor is created.

    Args:
        arr (numpy.ndarray or list): input
        device (torch.device or str, optional): device for the tensor ('cpu', 'cuda')
        dtype (torch.dtype, optional): data type of the tensor (e.g. torch.float32)

    Returns:
        torch.Tensor: converted Pytorch tensor
    """
    if isinstance(arr, Tensor):
        return arr
    return tensor(arr, device=device, dtype=dtype)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# taken from mlreflect package
# mlreflect/xrrloader/dataloader/transform.py

def angle_to_q(scattering_angle: "ndarray | float", wavelength: float):
    """Conversion from full scattering angle (degrees) to scattering vector (inverse angstroms)

    q = 4 * pi / wavelength * sin(scattering_angle / 2), with the angle converted to radians.

    Args:
        scattering_angle: full scattering angle in degrees (scalar or array)
        wavelength: wavelength in angstroms

    Returns:
        scattering vector q in inverse angstroms (same shape as ``scattering_angle``)
    """
    # Annotation written as a string: `ndarray or float` would evaluate to just `ndarray`.
    return 4 * np.pi / wavelength * np.sin(scattering_angle / 2 * np.pi / 180)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def q_to_angle(q: "ndarray | float", wavelength: float):
    """Conversion from scattering vector (inverse angstroms) to full scattering angle (degrees)

    Inverse of ``angle_to_q``: angle = 2 * arcsin(q * wavelength / (4 * pi)), converted to degrees.

    Args:
        q: scattering vector in inverse angstroms (scalar or array)
        wavelength: wavelength in angstroms

    Returns:
        full scattering angle in degrees (same shape as ``q``)
    """
    # Annotation written as a string: `ndarray or float` would evaluate to just `ndarray`.
    return 2 * np.arcsin(q * wavelength / (4 * np.pi)) / np.pi * 180
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def energy_to_wavelength(energy: float):
    """Conversion from photon energy (eV) to photon wavelength (angstroms)

    Uses wavelength = 1.2398e4 / energy, i.e. h*c taken as 12398 eV*angstrom.
    """
    hc_scaled = 1.2398  # h*c in units of 1e4 eV*angstrom
    return hc_scaled / energy * 1e4
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def wavelength_to_energy(wavelength: float):
    """Conversion from photon wavelength (angstroms) to photon energy (eV)

    Uses energy = 1.2398e4 / wavelength, i.e. h*c taken as 12398 eV*angstrom.
    """
    hc_scaled = 1.2398  # h*c in units of 1e4 eV*angstrom
    return hc_scaled / wavelength * 1e4
|
|
69
|
+
|
|
70
|
+
def get_filtering_mask(Q, R, dR, threshold=0.3, consecutive=3,
                       remove_singles=True, remove_consecutives=True,
                       q_start_trunc=0.1):
    """Build a boolean mask selecting the reliable points of a reflectivity curve.

    A point counts as unreliable when its relative error ``|dR / R|`` is
    ``>= threshold``. Two filters can be combined:

    * ``remove_singles``: drop every individual unreliable point.
    * ``remove_consecutives``: scanning the points in array order and ignoring
      points with ``Q < q_start_trunc``, find the first run of ``consecutive``
      unreliable points and drop everything from the start of that run to the
      end of the curve. This presumably assumes ``Q`` is sorted ascending —
      verify for your data.

    Args:
        Q: momentum transfer values (1D array-like).
        R: reflectivity values, same length as ``Q``.
        dR: uncertainties of ``R``, same length as ``Q``.
        threshold: relative-error threshold.
        consecutive: length of the run of unreliable points that triggers truncation.
        remove_singles: enable the single-point filter.
        remove_consecutives: enable the truncation filter.
        q_start_trunc: points below this ``Q`` are ignored by the truncation search.

    Returns:
        numpy.ndarray of bool: True for points to keep.
    """
    # np.asarray accepts lists as well as arrays; the inputs are never modified.
    Q = np.asarray(Q)
    R = np.asarray(R)
    dR = np.asarray(dR)

    # R == 0 yields inf (unreliable) or NaN for 0/0 (NaN >= threshold is False);
    # silence the corresponding NumPy warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        rel_error = np.abs(dR / R)
        unreliable = rel_error >= threshold

    # Mask for singles
    mask_singles = unreliable if remove_singles else np.zeros_like(Q, dtype=bool)

    # Mask for truncation
    mask_consecutive = np.zeros_like(Q, dtype=bool)
    if remove_consecutives:
        count = 0
        for i, (q, is_bad) in enumerate(zip(Q, unreliable)):
            if q < q_start_trunc:
                continue
            if is_bad:
                count += 1
                if count >= consecutive:
                    # Truncate from the first point of the run onwards.
                    mask_consecutive[i - consecutive + 1:] = True
                    break
            else:
                count = 0

    return ~(mask_singles | mask_consecutive)
|