reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -128
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -280
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -223
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -1374
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +36 -36
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +523 -516
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -19
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -262
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -200
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -15
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -19
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -434
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -404
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +97 -97
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.4.0.dist-info/RECORD +0 -88
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/ml/trainers.py
CHANGED
@@ -1,201 +1,201 @@

All 201 lines are removed and re-added with identical text (a whole-file rewrite with no textual changes, e.g. a line-ending or rebuild difference), so the file content is shown once:

```python
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from dataclasses import dataclass
from typing import Optional

from reflectorch.data_generation import BATCH_DATA_TYPE
from reflectorch.ml.basic_trainer import Trainer
from reflectorch.ml.dataloaders import ReflectivityDataLoader

__all__ = [
    'RealTimeSimTrainer',
    'DenoisingAETrainer',
    'PointEstimatorTrainer',
]


@dataclass
class BasicBatchData:
    scaled_curves: torch.Tensor
    scaled_bounds: torch.Tensor
    scaled_params: torch.Tensor = None
    scaled_sigmas: Optional[torch.Tensor] = None
    scaled_q_values: Optional[torch.Tensor] = None
    scaled_denoised_curves: Optional[torch.Tensor] = None
    key_padding_mask: Optional[torch.Tensor] = None
    scaled_conditioning_params: Optional[torch.Tensor] = None
    unscaled_q_values: Optional[torch.Tensor] = None

class RealTimeSimTrainer(Trainer):
    """Trainer with functionality to customize the sampled batch of data"""
    loader: ReflectivityDataLoader

    def get_batch_by_idx(self, batch_num: int):
        """Gets a batch of data with the default batch size"""
        batch_data = self.loader.get_batch(self.batch_size)
        return self._get_batch(batch_data)

    def get_batch_by_size(self, batch_size: int):
        """Gets a batch of data with a custom batch size"""
        batch_data = self.loader.get_batch(batch_size)
        return self._get_batch(batch_data)

    def _get_batch(self, batch_data: BATCH_DATA_TYPE):
        """Modify the batch of data sampled from the data loader"""
        raise NotImplementedError


class PointEstimatorTrainer(RealTimeSimTrainer):
    """Point estimator trainer for the inverse problem."""

    def init(self):
        if getattr(self, 'use_l1_loss', False):
            self.criterion = nn.L1Loss(reduction='none')
        else:
            self.criterion = nn.MSELoss(reduction='none')
        self.use_curve_reconstruction_loss = getattr(self, 'use_curve_reconstruction_loss', False)
        self.rescale_loss_interval_width = getattr(self, 'rescale_loss_interval_width', False)
        if self.use_curve_reconstruction_loss:
            self.loader.calc_denoised_curves = True

        self.train_with_q_input = getattr(self, 'train_with_q_input', False)
        self.train_with_sigmas = getattr(self, 'train_with_sigmas', False)
        self.condition_on_q_resolutions = getattr(self, 'condition_on_q_resolutions', False)

    def _get_batch(self, batch_data: BATCH_DATA_TYPE) -> BasicBatchData:
        def get_scaled_or_none(key, scaler=None):
            value = batch_data.get(key)
            if value is None:
                return None
            scale_func = scaler or (lambda x: x)
            return scale_func(value).to(torch.float32)

        scaled_params = batch_data['scaled_params'].to(torch.float32)
        scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
        scaled_denoised_curves = get_scaled_or_none('curves', self.loader.curves_scaler.scale)
        scaled_q_values = get_scaled_or_none('q_values', self.loader.q_generator.scale_q) if self.train_with_q_input else None
        key_padding_mask = batch_data.get('key_padding_mask', None)

        scaled_q_resolutions = get_scaled_or_none('q_resolutions', self.loader.smearing.scale_resolutions) if self.condition_on_q_resolutions else None
        conditioning_params = []
        if scaled_q_resolutions is not None:
            conditioning_params.append(scaled_q_resolutions)
        scaled_conditioning_params = torch.cat(conditioning_params, dim=-1) if len(conditioning_params) > 0 else None

        num_params = scaled_params.shape[-1] // 3
        assert num_params * 3 == scaled_params.shape[-1]
        scaled_params, scaled_bounds = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)

        return BasicBatchData(
            scaled_params=scaled_params,
            scaled_bounds=scaled_bounds,
            scaled_curves=scaled_curves,
            scaled_q_values=scaled_q_values,
            scaled_denoised_curves=scaled_denoised_curves,
            scaled_conditioning_params=scaled_conditioning_params,
            unscaled_q_values=batch_data['q_values'],
            key_padding_mask=key_padding_mask,
        )

    def get_loss_dict(self, batch_data: BasicBatchData):
        """Returns the regression loss"""
        scaled_params = batch_data.scaled_params
        scaled_curves = batch_data.scaled_curves
        scaled_bounds = batch_data.scaled_bounds
        scaled_q_values = batch_data.scaled_q_values
        key_padding_mask = batch_data.key_padding_mask
        scaled_conditioning_params = batch_data.scaled_conditioning_params
        unscaled_q_values = batch_data.unscaled_q_values

        predicted_params = self.model(
            curves=scaled_curves,
            bounds=scaled_bounds,
            q_values=scaled_q_values,
            conditioning_params=scaled_conditioning_params,
            key_padding_mask=key_padding_mask,
            unscaled_q_values=unscaled_q_values,
        )

        if not self.rescale_loss_interval_width:
            loss = self.criterion(predicted_params, scaled_params).mean()
        else:
            n_params = scaled_params.shape[-1]
            b_min = scaled_bounds[..., :n_params]
            b_max = scaled_bounds[..., n_params:]
            interval_width = b_max - b_min

            base_loss = self.criterion(predicted_params, scaled_params)
            if isinstance(self.criterion, torch.nn.MSELoss):
                width_factors = (interval_width / 2) ** 2
            elif isinstance(self.criterion, torch.nn.L1Loss):
                width_factors = interval_width / 2

            loss = (width_factors * base_loss).mean()

        return {'loss': loss}


# class PointEstimatorTrainer(RealTimeSimTrainer):
#     """Trainer for the regression inverse problem with incorporation of prior bounds"""
#     add_sigmas_to_context: bool = False

#     def _get_batch(self, batch_data: BATCH_DATA_TYPE):
#         scaled_params = batch_data['scaled_params'].to(torch.float32)
#         scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
#         if self.train_with_q_input:
#             q_values = batch_data['q_values'].to(torch.float32)
#             scaled_q_values = self.loader.q_generator.scale_q(q_values)
#         else:
#             scaled_q_values = None

#         num_params = scaled_params.shape[-1] // 3
#         assert num_params * 3 == scaled_params.shape[-1]
#         scaled_params, scaled_bounds = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)

#         return scaled_params, scaled_bounds, scaled_curves, scaled_q_values

#     def get_loss_dict(self, batch_data):
#         """computes the loss dictionary"""

#         scaled_params, scaled_bounds, scaled_curves, scaled_q_values = batch_data

#         if self.train_with_q_input:
#             predicted_params = self.model(scaled_curves, scaled_bounds, scaled_q_values)
#         else:
#             predicted_params = self.model(scaled_curves, scaled_bounds)

#         loss = self.mse(predicted_params, scaled_params)
#         return {'loss': loss}

#     def init(self):
#         self.mse = nn.MSELoss()


class DenoisingAETrainer(RealTimeSimTrainer):
    """Trainer which can be used for training a denoising autoencoder model. Overrides _get_batch and get_loss_dict methods"""
    def init(self):
        self.loader.calc_denoised_curves = True

        if getattr(self, 'use_l1_loss', False):
            self.criterion = nn.L1Loss()
        else:
            self.criterion = nn.MSELoss()

    def _get_batch(self, batch_data: BATCH_DATA_TYPE):
        """returns scaled curves with and without noise"""
        scaled_noisy_curves, curves = batch_data['scaled_noisy_curves'], batch_data['curves']
        scaled_curves = self.loader.curves_scaler.scale(curves)

        scaled_noisy_curves, scaled_curves = scaled_noisy_curves.to(torch.float32), scaled_curves.to(torch.float32)

        return scaled_noisy_curves, scaled_curves

    def get_loss_dict(self, batch_data):
        """returns the reconstruction loss of the autoencoder"""
        scaled_noisy_curves, scaled_curves = batch_data
        restored_curves = self.model(scaled_noisy_curves)
        loss = self.criterion(scaled_curves, restored_curves)
        return {'loss': loss}
```
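In `_get_batch` above, each sample's target parameters and their prior bounds arrive packed in a single tensor of width 3·num_params, which `torch.split` separates into params (first third) and bounds (remaining two thirds). The `rescale_loss_interval_width` branch then weights the element-wise loss by each parameter's prior half-width, squared for MSE and linear for L1. A minimal sketch of that weighting on dummy tensors (assuming, as the half-width factor suggests, that parameters are min-max scaled to [-1, 1] within their per-sample bounds, so the factor converts scaled errors back to physical scale):

```python
import torch

# Dummy batch: 4 samples, 3 parameters each (illustrative values, not package data)
predicted = torch.rand(4, 3)
target = torch.rand(4, 3)
bounds = torch.cat([torch.zeros(4, 3), 2 * torch.ones(4, 3)], dim=-1)  # [b_min | b_max]

n_params = target.shape[-1]
b_min, b_max = bounds[..., :n_params], bounds[..., n_params:]
interval_width = b_max - b_min

# Element-wise loss (reduction='none'), weighted by (half-width)**2 so errors on
# parameters with wide priors contribute proportionally more than on narrow ones
base_loss = torch.nn.MSELoss(reduction='none')(predicted, target)
loss = ((interval_width / 2) ** 2 * base_loss).mean()
print(loss)
```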
reflectorch/ml/utils.py
CHANGED
@@ -1,2 +1,2 @@

Both lines are removed and re-added with identical text:

```python
def is_divisor(num: int, div: int):
    return num and not num % div
```
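Note that `num and not num % div` short-circuits: the helper returns `True`/`False` for nonzero `num`, but returns `0` (falsy, not `False`) when `num == 0`. Hypothetical calls for illustration:

```python
def is_divisor(num: int, div: int):
    return num and not num % div

print(is_divisor(10, 5))  # True  (10 % 5 == 0)
print(is_divisor(10, 3))  # False (10 % 3 == 1)
print(is_divisor(0, 5))   # 0     (`num` is falsy, so the `and` short-circuits)
```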
reflectorch/models/__init__.py
CHANGED
@@ -1,16 +1,16 @@

Lines 1-15 are removed and re-added with identical text (line 16, the closing `]`, is unchanged context):

```python
from reflectorch.models.encoders import *
from reflectorch.models.networks import *

__all__ = [
    "ConvEncoder",
    "ConvDecoder",
    "ConvAutoencoder",
    "FnoEncoder",
    "IntegralConvEmbedding",
    "SpectralConv1d",
    "ConvResidualNet1D",
    "ResidualMLP",
    "NetworkWithPriors",
    "NetworkWithPriorsConvEmb",
    "NetworkWithPriorsFnoEmb",
]
```
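The star-imports plus the explicit `__all__` re-export the encoder and network classes at the `reflectorch.models` level, so (hypothetical usage, assuming the subpackages export these names as the star-imports imply) both paths resolve to the same class:

```python
from reflectorch.models import ResidualMLP
from reflectorch.models.networks import ResidualMLP as ResidualMLPDirect

assert ResidualMLP is ResidualMLPDirect
```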
reflectorch/models/activations.py
CHANGED
@@ -1,50 +1,50 @@

All 50 lines are removed and re-added with identical text (the file header was lost in extraction; this hunk matches reflectorch/models/activations.py, +50 -50, in the file list above):

```python
import torch
from torch import nn
from torch.nn.functional import relu

class Rowdy(nn.Module):
    """adaptive activation function"""
    def __init__(self, K=9):
        super().__init__()
        self.K = K
        self.alpha = nn.Parameter(torch.cat((torch.ones(1), torch.zeros(K-1))))
        self.alpha.requiresGrad = True
        self.omega = nn.Parameter(torch.ones(K))
        self.omega.requiresGrad = True

    def forward(self, x):
        rowdy = self.alpha[0]*relu(self.omega[0]*x)
        for k in range(1, self.K):
            rowdy += self.alpha[k]*torch.sin(self.omega[k]*k*x)
        return rowdy


ACTIVATIONS = {
    'relu': nn.ReLU,
    'lrelu': nn.LeakyReLU,
    'gelu': nn.GELU,
    'selu': nn.SELU,
    'elu': nn.ELU,
    'sigmoid': nn.Sigmoid,
    'tanh': nn.Tanh,
    'silu': nn.SiLU,
    'mish': nn.Mish,
    'rowdy': Rowdy,
}


def activation_by_name(name):
    """returns an activation function module corresponding to its name

    Args:
        name (str): string denoting the activation function ('relu', 'lrelu', 'gelu', 'selu', 'elu', 'sigmoid', 'silu', 'mish', 'rowdy')

    Returns:
        nn.Module: Pytorch activation function module
    """
    if not isinstance(name, str):
        return name
    try:
        return ACTIVATIONS[name.lower()]
    except KeyError:
        raise KeyError(f'Unknown activation function {name}')
```
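Two details in this file: `activation_by_name` returns the activation class, not an instance, so callers instantiate it themselves; and `self.alpha.requiresGrad = True` sets a plain attribute (the actual PyTorch flag is spelled `requires_grad`), which is harmless here because `nn.Parameter` already defaults to `requires_grad=True`. A short usage sketch (assuming the package is installed):

```python
import torch
from reflectorch.models.activations import activation_by_name

act_cls = activation_by_name('gelu')   # returns the class nn.GELU, not an instance
act = act_cls()                        # instantiate before use
print(act(torch.linspace(-1.0, 1.0, 5)))
```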
reflectorch/models/encoders/__init__.py
CHANGED
@@ -1,19 +1,19 @@

All 19 lines are removed and re-added with identical text (file header restored from the file list above, +19 -19):

```python
from reflectorch.models.encoders.conv_encoder import (
    ConvEncoder,
    ConvDecoder,
    ConvAutoencoder,
)
from reflectorch.models.encoders.fno import FnoEncoder, SpectralConv1d
from reflectorch.models.encoders.integral_kernel_embedding import IntegralConvEmbedding
from reflectorch.models.encoders.conv_res_net import ConvResidualNet1D


__all__ = [
    "ConvEncoder",
    "ConvDecoder",
    "ConvAutoencoder",
    "ConvResidualNet1D",
    "FnoEncoder",
    "SpectralConv1d",
    "IntegralConvEmbedding",
]
```