reflectorch-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic.
- reflectorch/__init__.py +23 -0
- reflectorch/data_generation/__init__.py +130 -0
- reflectorch/data_generation/dataset.py +196 -0
- reflectorch/data_generation/likelihoods.py +86 -0
- reflectorch/data_generation/noise.py +371 -0
- reflectorch/data_generation/priors/__init__.py +66 -0
- reflectorch/data_generation/priors/base.py +61 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
- reflectorch/data_generation/priors/independent_priors.py +201 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +110 -0
- reflectorch/data_generation/priors/no_constraints.py +212 -0
- reflectorch/data_generation/priors/parametric_models.py +767 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
- reflectorch/data_generation/priors/params.py +258 -0
- reflectorch/data_generation/priors/sampler_strategies.py +306 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +377 -0
- reflectorch/data_generation/priors/utils.py +124 -0
- reflectorch/data_generation/process_data.py +47 -0
- reflectorch/data_generation/q_generator.py +232 -0
- reflectorch/data_generation/reflectivity/__init__.py +56 -0
- reflectorch/data_generation/reflectivity/abeles.py +81 -0
- reflectorch/data_generation/reflectivity/kinematical.py +58 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +123 -0
- reflectorch/data_generation/scale_curves.py +118 -0
- reflectorch/data_generation/smearing.py +67 -0
- reflectorch/data_generation/utils.py +154 -0
- reflectorch/extensions/__init__.py +6 -0
- reflectorch/extensions/jupyter/__init__.py +12 -0
- reflectorch/extensions/jupyter/callbacks.py +40 -0
- reflectorch/extensions/matplotlib/__init__.py +11 -0
- reflectorch/extensions/matplotlib/losses.py +38 -0
- reflectorch/inference/__init__.py +22 -0
- reflectorch/inference/inference_model.py +734 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +16 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +171 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +37 -0
- reflectorch/ml/basic_trainer.py +286 -0
- reflectorch/ml/callbacks.py +86 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +38 -0
- reflectorch/ml/schedulers.py +246 -0
- reflectorch/ml/trainers.py +126 -0
- reflectorch/ml/utils.py +9 -0
- reflectorch/models/__init__.py +22 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +27 -0
- reflectorch/models/encoders/conv_encoder.py +211 -0
- reflectorch/models/encoders/conv_res_net.py +119 -0
- reflectorch/models/encoders/fno.py +127 -0
- reflectorch/models/encoders/transformers.py +56 -0
- reflectorch/models/networks/__init__.py +18 -0
- reflectorch/models/networks/mlp_networks.py +256 -0
- reflectorch/models/networks/residual_net.py +131 -0
- reflectorch/paths.py +33 -0
- reflectorch/runs/__init__.py +35 -0
- reflectorch/runs/config.py +31 -0
- reflectorch/runs/slurm_utils.py +99 -0
- reflectorch/runs/train.py +85 -0
- reflectorch/runs/utils.py +300 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +74 -0
- reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
- reflectorch-1.0.0.dist-info/METADATA +115 -0
- reflectorch-1.0.0.dist-info/RECORD +83 -0
- reflectorch-1.0.0.dist-info/WHEEL +5 -0
- reflectorch-1.0.0.dist-info/top_level.txt +1 -0

@@ -0,0 +1,43 @@
from time import perf_counter
from contextlib import contextmanager
from functools import wraps


class EvaluateTime(list):
    @contextmanager
    def __call__(self, name: str, *args, **kwargs):
        start = perf_counter()
        yield
        self.action(perf_counter() - start, name, *args, **kwargs)

    @staticmethod
    def action(delta_time, name, *args, **kwargs):
        print(f"Time for {name} = {delta_time:.2f} sec")

    def __repr__(self):
        return f'EvaluateTime(total={sum(self)}, num_records={len(self)})'


def print_time(name: str or callable):
    if isinstance(name, str):
        return _print_time_context(name)
    else:
        return _print_time_wrap(name)


def _print_time_wrap(func, name: str = None):
    name = name or func.__name__

    @wraps(func)
    def wrapped_func(*args, **kwargs):
        with _print_time_context(name):
            return func(*args, **kwargs)

    return wrapped_func


@contextmanager
def _print_time_context(name: str):
    start = perf_counter()
    yield
    print(f"Time for {name} = {(perf_counter() - start):.2f} sec")
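
For illustration only (not part of the package diff): the 43 added lines above match the reflectorch/inference/record_time.py entry in the file list. Assuming that module path, print_time works as a context manager when given a string and as a decorator when given a callable:

from reflectorch.inference.record_time import print_time

# Used as a context manager (string argument):
with print_time("simulation"):
    total = sum(i * i for i in range(1_000_000))

# Used as a decorator (callable argument):
@print_time
def preprocess(values):
    return [v * 2 for v in values]

preprocess(list(range(10)))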

@@ -0,0 +1,56 @@
import torch
from torch import Tensor

from reflectorch.data_generation.priors.utils import uniform_sampler
from reflectorch.data_generation.priors.subprior_sampler import UniformSubPriorParams
from reflectorch.data_generation.priors.params import Params
from reflectorch.data_generation.likelihoods import LogLikelihood


def simple_sampler_solution(
        likelihood: LogLikelihood,
        predicted_params: UniformSubPriorParams,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        num: int = 2 ** 15,
        coef: float = 0.1,
) -> UniformSubPriorParams:
    sampled_params_t = sample_around_params(predicted_params, total_min_bounds, total_max_bounds, num=num, coef=coef)
    sampled_params = Params.from_tensor(sampled_params_t)
    return get_best_mse_param(sampled_params, likelihood, predicted_params.min_bounds, predicted_params.max_bounds)


def sample_around_params(predicted_params: UniformSubPriorParams,
                         total_min_bounds: Tensor,
                         total_max_bounds: Tensor,
                         num: int = 2 ** 15,
                         coef: float = 0.1,
                         ) -> Tensor:
    params_t = predicted_params.as_tensor(add_bounds=False)

    delta = (predicted_params.max_bounds - predicted_params.min_bounds) * coef
    min_bounds = torch.clamp(params_t - delta, total_min_bounds, total_max_bounds)
    max_bounds = torch.clamp(params_t + delta, total_min_bounds, total_max_bounds)

    sampled_params_t = uniform_sampler(min_bounds, max_bounds, num, params_t.shape[-1])
    sampled_params_t[0] = params_t[0]

    return sampled_params_t


def get_best_mse_param(
        params: Params,
        likelihood: LogLikelihood,
        min_bounds: Tensor = None,
        max_bounds: Tensor = None,
):
    sampled_curves = params.reflectivity(likelihood.q)
    log_probs = likelihood.calc_log_likelihood(sampled_curves)
    best_idx = torch.argmax(log_probs)
    best_param = params[best_idx:best_idx + 1]

    if min_bounds is not None:
        best_param = UniformSubPriorParams.from_tensor(
            torch.cat([best_param.as_tensor(), torch.atleast_2d(min_bounds), torch.atleast_2d(max_bounds)], -1)
        )
    return best_param
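
These 56 lines match the reflectorch/inference/sampler_solution.py entry in the file list. The functions refine a predicted parameter vector by sampling uniformly inside a box of relative width coef around the prediction (clamped to the global bounds) and keeping the candidate with the highest log-likelihood. Below is a self-contained toy sketch of that idea with plain tensors; toy_curve and all numeric values are hypothetical stand-ins for the package's Params and LogLikelihood objects:

import torch

torch.manual_seed(0)

def toy_curve(p):
    # Hypothetical stand-in for a reflectivity simulator: slope * q + offset
    return p[..., :1] * torch.linspace(0.0, 1.0, 16) + p[..., 1:]

true_p = torch.tensor([2.0, 0.5])
measured = toy_curve(true_p)

predicted = torch.tensor([1.8, 0.6])          # e.g. a network prediction
total_min, total_max = torch.tensor([0.0, 0.0]), torch.tensor([5.0, 2.0])

coef, num = 0.1, 2 ** 10
delta = (total_max - total_min) * coef
lo = torch.clamp(predicted - delta, total_min, total_max)
hi = torch.clamp(predicted + delta, total_min, total_max)

samples = lo + (hi - lo) * torch.rand(num, 2)  # uniform box around the prediction
samples[0] = predicted                         # keep the prediction itself as a candidate

log_probs = -((toy_curve(samples) - measured) ** 2).sum(-1)  # Gaussian-like log-likelihood
best = samples[log_probs.argmax()]
print(best)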

@@ -0,0 +1,171 @@
import warnings

import numpy as np
from scipy.optimize import minimize, curve_fit

from reflectorch.data_generation.reflectivity import abeles_np

__all__ = [
    "standard_refl_fit",
    "fit_refl_curve",
    "restore_masked_params",
    "get_fit_with_growth",
]


def standard_restore_params(fitted_params) -> dict:
    num_layers = (fitted_params.size - 2) // 3

    return dict(
        thickness=fitted_params[:num_layers],
        roughness=fitted_params[num_layers:2 * num_layers + 1],
        sld=fitted_params[2 * num_layers + 1:],
    )


def mse_loss(curve1, curve2):
    return np.sum((curve1 - curve2) ** 2)


def standard_refl_fit(
        q: np.ndarray, curve: np.ndarray,
        init_params: np.ndarray,
        bounds: np.ndarray = None,
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
        **kwargs
):
    if bounds is not None:
        kwargs['bounds'] = bounds
        init_params = np.clip(init_params, *bounds)

    res = curve_fit(
        get_scaled_curve_func(
            refl_generator=refl_generator,
            restore_params_func=restore_params_func,
            scale_curve_func=scale_curve_func,
        ),
        q, scale_curve_func(curve),
        p0=init_params, **kwargs
    )

    curve = refl_generator(q, **restore_params_func(res[0]))
    return res[0], curve


def get_fit_with_growth(
        q: np.ndarray, curve: np.ndarray,
        init_params: np.ndarray,
        bounds: np.ndarray = None,
        init_d_change: float = 0.,
        max_d_change: float = 30.,
        scale_curve_func=np.log10,
        **kwargs
):
    init_params = np.array(list(init_params) + [init_d_change])
    if bounds is not None:
        bounds = np.concatenate([bounds, np.array([0, max_d_change])[..., None]], -1)

    params, curve = standard_refl_fit(
        q, curve, init_params, bounds, refl_generator=growth_reflectivity,
        restore_params_func=get_restore_params_with_growth_func(q_size=q.size, d_idx=0),
        scale_curve_func=scale_curve_func, **kwargs
    )
    params[0] += params[-1] / 2
    return params, curve


def fit_refl_curve(q: np.ndarray, curve: np.ndarray,
                   init_params: np.ndarray,
                   bounds: np.ndarray = None,
                   refl_generator=abeles_np,
                   restore_params_func=standard_restore_params,
                   scale_curve_func=np.log10,
                   **kwargs
                   ) -> np.ndarray:
    fitting_func = get_fitting_func(
        q=q, curve=curve,
        refl_generator=refl_generator,
        restore_params_func=restore_params_func,
        scale_curve_func=scale_curve_func,
    )

    res = minimize(fitting_func, init_params, bounds=bounds, **kwargs)

    if not res.success:
        warnings.warn(f"Minimization did not converge.")
    return res.x


def get_scaled_curve_func(
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
):
    def scaled_curve_func(q, *fitted_params):
        fitted_params = restore_params_func(np.asarray(fitted_params))
        fitted_curve = refl_generator(q, **fitted_params)
        scaled_curve = scale_curve_func(fitted_curve)
        return scaled_curve

    return scaled_curve_func


def get_fitting_func(
        q: np.ndarray,
        curve: np.ndarray,
        refl_generator=abeles_np,
        restore_params_func=standard_restore_params,
        scale_curve_func=np.log10,
        loss_func=mse_loss,
):
    scaled_curve = scale_curve_func(curve)

    def fitting_func(fitted_params):
        fitted_params = restore_params_func(fitted_params)
        fitted_curve = refl_generator(q, **fitted_params)
        loss = loss_func(scale_curve_func(fitted_curve), scaled_curve)
        return loss

    return fitting_func


def restore_masked_params(fixed_params, fixed_mask):
    def restore_params(fitted_params) -> dict:
        params = np.empty_like(fixed_mask).astype(fitted_params.dtype)
        params[fixed_mask] = fixed_params
        params[~fixed_mask] = fitted_params
        return standard_restore_params(params)

    return restore_params


def base_params2growth(base_params: dict, d_shift: np.ndarray, d_idx: int = 0) -> dict:
    d_init = base_params['thickness'][None]
    q_size = d_shift.size
    d = d_init.repeat(q_size, 0)
    d[:, d_idx] = d[:, d_idx] + d_shift

    roughness = np.broadcast_to(base_params['roughness'][None], (q_size, base_params['roughness'].size))
    sld = np.broadcast_to(base_params['sld'][None], (q_size, base_params['sld'].size))

    return {
        'thickness': d,
        'roughness': roughness,
        'sld': sld,
    }


def get_restore_params_with_growth_func(q_size: int, d_idx: int = 0):
    def restore_params_with_growth(fitted_params) -> dict:
        fitted_params, delta_d = fitted_params[:-1], fitted_params[-1]
        base_params = standard_restore_params(fitted_params)
        d_shift = np.linspace(0, delta_d, q_size)
        return base_params2growth(base_params, d_shift, d_idx)

    return restore_params_with_growth


def growth_reflectivity(q: np.ndarray, **kwargs):
    return abeles_np(q[..., None], **kwargs).flatten()
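
These 171 lines match the reflectorch/inference/scipy_fitter.py entry in the file list. standard_restore_params implies a flat parameter layout of [thickness (N), roughness (N + 1), sld (N + 1)], i.e. 3N + 2 entries for N layers. A short sketch of that layout and of how standard_refl_fit would be called, assuming that module path and placeholder values for the data:

import numpy as np
from reflectorch.inference.scipy_fitter import standard_refl_fit, standard_restore_params

# One-layer film: N = 1 gives 3*N + 2 = 5 entries in the flat parameter vector.
init = np.array([120.0, 3.0, 2.0, 4.0, 2.07])   # thickness | roughness x2 | sld x2 (placeholder values)
print(standard_restore_params(init))
# -> {'thickness': array([120.]), 'roughness': array([3., 2.]), 'sld': array([4., 2.07])}

# With measured arrays q and curve and a (2, 5) bounds array [lower, upper],
# the fit would be called roughly as:
# fitted, fit_curve = standard_refl_fit(q, curve, init, bounds=bounds)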

@@ -0,0 +1,87 @@
from tqdm import trange

import torch
from torch import nn, Tensor

from reflectorch.data_generation import LogLikelihood, reflectivity, PriorSampler


class ReflGradientFit(object):
    """Directly optimizes the thin film parameters using a Pytorch optimizer

    Args:
        q (Tensor): the q positions
        exp_curve (Tensor): the experimental reflectivity curve
        prior_sampler (PriorSampler): the prior sampler
        params (Tensor): the initial thin film parameters
        fit_indices (Tensor): the indices of the thin film parameters which are to be fitted
        sigmas (Tensor, optional): error bars of the reflectivity curve, if not provided they are derived from ``rel_err`` and ``abs_err``. Defaults to None.
        optim_cls (Type[torch.optim.Optimizer], optional): the Pytorch optimizer class. Defaults to None.
        lr (float, optional): the learning rate. Defaults to 1e-2.
        rel_err (float, optional): the relative error in the reflectivity curve. Defaults to 0.1.
        abs_err (float, optional): the absolute error in the reflectivity curve. Defaults to 1e-7.
    """
    def __init__(self,
                 q: Tensor,
                 exp_curve: Tensor,
                 prior_sampler: PriorSampler,
                 params: Tensor,
                 fit_indices: Tensor,
                 sigmas: Tensor = None,
                 optim_cls=None,
                 lr: float = 1e-2,
                 rel_err: float = 0.1,
                 abs_err: float = 1e-7,
                 ):
        self.q = q

        if sigmas is None:
            sigmas = exp_curve * rel_err + abs_err

        self.likelihood = LogLikelihood(q, exp_curve, prior_sampler, sigmas)

        self.num_layers = params.shape[-1] // 3
        self.fit_indices = fit_indices
        self.init_params = params.clone()
        self.params_to_fit = nn.Parameter(self.init_params[fit_indices].clone())

        optim_cls = optim_cls or torch.optim.Adam
        self.optim = optim_cls([self.params_to_fit], lr)

        self.losses = []

    @property
    def params(self):
        params = self.init_params.clone()
        params[self.fit_indices] = self.params_to_fit
        return params

    def calc_log_likelihood(self):
        return self.likelihood.calc_log_likelihood(self.refl())

    def calc_log_prob_loss(self):
        return - self.calc_log_likelihood().mean()

    def refl(self):
        d, sigma, rho = torch.split(self.params, [self.num_layers, self.num_layers + 1, self.num_layers + 1], -1)
        return reflectivity(self.q, d, sigma, rho)

    def run(self, num_iterations: int = 500, disable_tqdm: bool = False):
        """Runs the optimization process

        Args:
            num_iterations (int, optional): number of iterations the optimization is run for. Defaults to 500.
            disable_tqdm (bool, optional): whether to disable the prograss bar. Defaults to False.
        """
        pbar = trange(num_iterations, disable=disable_tqdm)

        for _ in pbar:
            self.optim.zero_grad()
            loss = self.calc_log_prob_loss()
            loss.backward()
            self.optim.step()
            self.losses.append(loss.item())
            pbar.set_description(f'Loss = {loss.item():.2e}')

    def clear(self):
        self.losses.clear()
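
These 87 lines match the reflectorch/inference/torch_fitter.py entry in the file list. ReflGradientFit keeps most parameters fixed, exposes the entries selected by fit_indices as an nn.Parameter, and minimizes the negative log-likelihood with a PyTorch optimizer. A self-contained sketch of that pattern with a toy forward model (the exponential stand-in, the weighted squared-error loss, and all values are illustrative, not the package's reflectivity kernel or LogLikelihood):

import torch
from torch import nn

torch.manual_seed(0)

q = torch.linspace(0.02, 0.2, 64)

def forward_model(p):
    # Toy differentiable stand-in for a reflectivity simulator
    return torch.exp(-p[0] * q) * p[1]

true_params = torch.tensor([12.0, 1.0])
exp_curve = forward_model(true_params) * (1 + 0.01 * torch.randn_like(q))
sigmas = exp_curve * 0.1 + 1e-7

init_params = torch.tensor([8.0, 1.0])
fit_indices = torch.tensor([0])               # only fit the first parameter
params_to_fit = nn.Parameter(init_params[fit_indices].clone())
optim = torch.optim.Adam([params_to_fit], lr=1e-2)

for _ in range(500):
    optim.zero_grad()
    params = init_params.clone()
    params[fit_indices] = params_to_fit       # splice the free parameters into the full vector
    curve = forward_model(params)
    loss = (((curve - exp_curve) / sigmas) ** 2).mean()   # weighted squared error as the loss
    loss.backward()
    optim.step()

print(params_to_fit.detach())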

@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.

from reflectorch.ml.basic_trainer import *
from reflectorch.ml.callbacks import *
from reflectorch.ml.trainers import *
from reflectorch.ml.loggers import *
from reflectorch.ml.schedulers import *
from reflectorch.ml.dataloaders import *

__all__ = [
    'Trainer',
    'TrainerCallback',
    'DataLoader',
    'PeriodicTrainerCallback',
    'SaveBestModel',
    'LogLosses',
    'Logger',
    'Loggers',
    'PrintLogger',
    'ScheduleBatchSize',
    'ScheduleLR',
    'StepLR',
    'CyclicLR',
    'LogCyclicLR',
    'ReduceLROnPlateau',
    'OneCycleLR',
    'ReflectivityDataLoader',
    'MultilayerDataLoader',
    'RealTimeSimTrainer',
    'DenoisingAETrainer',
    'VAETrainer',
    'PointEstimatorTrainer',
]
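
The 37 lines above line up with the reflectorch/ml/__init__.py entry in the file list, so the trainer, callback, logger, scheduler, and dataloader classes appear to be re-exported under the flat reflectorch.ml namespace. A hedged one-liner, assuming that mapping is correct:

# Presumed flat imports enabled by the star re-exports above (names taken from __all__).
from reflectorch.ml import Trainer, TrainerCallback, StepLR, ReflectivityDataLoader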

@@ -0,0 +1,286 @@
# -*- coding: utf-8 -*-
#
#
# This source code is licensed under the GPL license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, Iterable, Any, Union, Type
from collections import defaultdict

from tqdm.notebook import trange
import numpy as np

import torch
from torch.nn import Module

from reflectorch.ml.loggers import Logger, Loggers

from .utils import is_divisor

__all__ = [
    'Trainer',
    'TrainerCallback',
    'DataLoader',
    'PeriodicTrainerCallback',
]


class Trainer(object):
    """Trainer class

    Args:
        model (nn.Module): neural network
        loader (DataLoader): data loader
        lr (float): learning rate
        batch_size (int): batch size
        clip_grad_norm (int, optional): maximum norm for gradient clipping if it is not ``None``. Defaults to None.
        logger (Union[Logger, Tuple[Logger, ...], Loggers], optional): logger. Defaults to None.
        optim_cls (Type[torch.optim.Optimizer], optional): Pytorch optimizer. Defaults to torch.optim.Adam.
        optim_kwargs (dict, optional): optimizer arguments. Defaults to None.
        train_with_q_input (bool, optional): if ``True`` the q values are also used as input. Defaults to False.
    """

    TOTAL_LOSS_KEY: str = 'total_loss'

    def __init__(self,
                 model: Module,
                 loader: 'DataLoader',
                 lr: float,
                 batch_size: int,
                 clip_grad_norm_max: Optional[int] = None,
                 train_with_q_input: bool = False,
                 logger: Union[Logger, Tuple[Logger, ...], Loggers] = None,
                 optim_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
                 optim_kwargs: dict = None,
                 **kwargs
                 ):

        self.model = model
        self.loader = loader
        self.batch_size = batch_size
        self.clip_grad_norm_max = clip_grad_norm_max
        self.train_with_q_input = train_with_q_input

        self.optim = self.configure_optimizer(optim_cls, lr=lr, **(optim_kwargs or {}))
        self.lrs = []
        self.losses = defaultdict(list)

        self.logger = _init_logger(logger)
        self.callback_params = {}

        for k, v in kwargs.items():
            setattr(self, k, v)

        self.init()

    def init(self):
        pass

    def log(self, name: str, data):
        """log data"""
        self.logger.log(name, data)

    def train(self,
              num_batches: int,
              callbacks: Union[Tuple['TrainerCallback', ...], 'TrainerCallback'] = (),
              disable_tqdm: bool = False,
              update_tqdm_freq: int = 10,
              grad_accumulation_steps: int = 1,
              ):
        """starts the training process

        Args:
            num_batches (int): total number of training iterations
            callbacks (Union[Tuple['TrainerCallback'], 'TrainerCallback']): the trainer callbacks. Defaults to ().
            disable_tqdm (bool, optional): if ``True``, the progress bar is disabled. Defaults to False.
            update_tqdm_freq (int, optional): frequency for updating the progress bar. Defaults to 10.
            grad_accumulation_steps (int, optional): number of gradient accumulation steps. Defaults to 1.
        """

        if isinstance(callbacks, TrainerCallback):
            callbacks = (callbacks,)

        callbacks = _StackedTrainerCallbacks(list(callbacks) + [self.loader])

        pbar = trange(num_batches, disable=disable_tqdm)

        callbacks.start_training(self)

        for batch_num in pbar:
            self.model.train()

            self.optim.zero_grad()
            total_loss, avr_loss_dict = 0, defaultdict(list)

            for _ in range(grad_accumulation_steps):

                batch_data = self.get_batch_by_idx(batch_num)
                loss_dict = self.get_loss_dict(batch_data)
                loss = loss_dict['loss'] / grad_accumulation_steps
                total_loss += loss.item()
                _update_loss_dict(avr_loss_dict, loss_dict)

                if not torch.isfinite(loss).item():
                    raise ValueError('Loss is not finite!')

                loss.backward()

            if self.clip_grad_norm_max is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm_max)
            self.optim.step()

            avr_loss_dict = {k: np.mean(v) for k, v in avr_loss_dict.items()}
            self._update_losses(avr_loss_dict, total_loss)

            if not disable_tqdm:
                self._update_tqdm(pbar, batch_num, update_tqdm_freq)

            break_epoch = callbacks.end_batch(self, batch_num)

            if break_epoch:
                break

        callbacks.end_training(self)

    def _update_tqdm(self, pbar, batch_num: int, update_tqdm_freq: int):
        if is_divisor(batch_num, update_tqdm_freq):
            last_loss = np.mean(self.losses[self.TOTAL_LOSS_KEY][-10:])
            pbar.set_description(f'Loss = {last_loss:.2e}')

    def get_batch_by_idx(self, batch_num: int) -> Any:
        raise NotImplementedError

    def get_loss_dict(self, batch_data) -> dict:
        raise NotImplementedError

    def _update_losses(self, loss_dict: dict, loss: float) -> None:
        _update_loss_dict(self.losses, loss_dict)
        self.losses[self.TOTAL_LOSS_KEY].append(loss)
        self.lrs.append(self.lr())

    def configure_optimizer(self, optim_cls, lr: float, **kwargs) -> torch.optim.Optimizer:
        """configure the optimizer based on the optimizer class, the learning rate and the optimizer keyword arguments

        Args:
            optim_cls: the class of the optimizer
            lr (float): the learning rate

        Returns:
            torch.optim.Optimizer:
        """
        optim = optim_cls(self.model.parameters(), lr, **kwargs)
        return optim

    def lr(self, param_group: int = 0) -> float:
        """get the learning rate"""
        return self.optim.param_groups[param_group]['lr']

    def set_lr(self, lr: float, param_group: int = 0) -> None:
        """set the learning rate"""
        self.optim.param_groups[param_group]['lr'] = lr


class TrainerCallback(object):
    """Base class for trainer callbacks
    """
    def start_training(self, trainer: Trainer) -> None:
        """add functionality the start of training

        Args:
            trainer (Trainer): the trainer object
        """
        pass

    def end_training(self, trainer: Trainer) -> None:
        """add functionality at the end of training

        Args:
            trainer (Trainer): the trainer object
        """
        pass

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """add functionality at the end of the iteration / batch

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]:
        """
        pass

    def __repr__(self):
        return f'{self.__class__.__name__}()'


class DataLoader(TrainerCallback):
    pass


class PeriodicTrainerCallback(TrainerCallback):
    """Base class for trainer callbacks which perform an action periodically after a number of iterations

    Args:
        step (int, optional): Number of iterations after which the action is repeated. Defaults to 1.
        last_epoch (int, optional): the last training iteration for which the action is performed. Defaults to -1.
    """
    def __init__(self, step: int = 1, last_epoch: int = -1):
        self.step = step
        self.last_epoch = last_epoch

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        """add functionality at the end of the iteration / batch

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch

        Returns:
            Union[bool, None]:
        """
        if (
            is_divisor(batch_num, self.step) and
            (self.last_epoch == -1 or batch_num < self.last_epoch)
        ):
            return self._end_batch(trainer, batch_num)

    def _end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        pass


class _StackedTrainerCallbacks(TrainerCallback):
    def __init__(self, callbacks: Iterable[TrainerCallback]):
        self.callbacks = tuple(callbacks)

    def start_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.start_training(trainer)

    def end_training(self, trainer: Trainer) -> None:
        for c in self.callbacks:
            c.end_training(trainer)

    def end_batch(self, trainer: Trainer, batch_num: int) -> Union[bool, None]:
        break_epoch = False
        for c in self.callbacks:
            break_epoch += bool(c.end_batch(trainer, batch_num))
        return break_epoch

    def __repr__(self):
        callbacks = ", ".join(repr(c) for c in self.callbacks)
        return f'StackedTrainerCallbacks({callbacks})'


def _init_logger(logger: Union[Logger, Tuple[Logger, ...], Loggers] = None):
    if not logger:
        return Logger()
    if isinstance(logger, Logger):
        return logger
    return Loggers(*logger)


def _update_loss_dict(loss_dict: dict, new_values: dict):
    for k, v in new_values.items():
        loss_dict[k].append(v.item())