reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -128
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -280
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -223
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -1374
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +36 -36
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +523 -516
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -19
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -262
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -200
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -15
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -19
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -434
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -404
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +97 -97
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.4.0.dist-info/RECORD +0 -88
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/models/networks/mlp_networks.py

@@ -1,434 +1,434 @@

(The hunk re-emits the entire file; the removed and re-added sides are line-for-line identical, so the content is listed once below.)

# -*- coding: utf-8 -*-

import math
from typing import Optional
import torch
from torch import nn, cat, split, Tensor

from reflectorch.models.networks.residual_net import ResidualMLP
from reflectorch.models.encoders.conv_encoder import ConvEncoder
from reflectorch.models.encoders.integral_kernel_embedding import IntegralConvEmbedding
from reflectorch.models.encoders.fno import FnoEncoder
from reflectorch.models.activations import activation_by_name

class NetworkWithPriors(nn.Module):
    """MLP network with an embedding network

    .. image:: ../documentation/FigureReflectometryNetwork.png
        :width: 800px
        :align: center

    Args:
        embedding_net_type (str): the type of embedding network, either 'conv', 'fno' or 'integral_conv'.
        embedding_net_kwargs (dict): dictionary containing the keyword arguments for the embedding network.
        dim_out (int, optional): the dimension of the output produced by the MLP. Defaults to 8.
        dim_conditioning_params (int, optional): the dimension of other parameters the network is conditioned on (e.g. for the smearing coefficient dq/q)
        layer_width (int, optional): the width of a linear layer in the MLP. Defaults to 512.
        num_blocks (int, optional): the number of residual blocks in the MLP. Defaults to 4.
        repeats_per_block (int, optional): the number of normalization/activation/linear repeats in a block. Defaults to 2.
        mlp_activation (str, optional): the type of activation function in the MLP. Defaults to 'gelu'.
        use_batch_norm (bool, optional): whether to use batch normalization in the MLP. Defaults to True.
        use_layer_norm (bool, optional): whether to use layer normalization in the MLP (if use_batch_norm is False). Defaults to False.
        dropout_rate (float, optional): dropout rate for each block. Defaults to 0.0.
        tanh_output (bool, optional): whether to apply a tanh function to the output. Defaults to False.
        use_selu_init (bool, optional): whether to use the special weights initialization for the 'selu' activation function. Defaults to False.
        pretrained_embedding_net (str, optional): the path to the weights of a pretrained embedding network. Defaults to None.
        residual (bool, optional): whether the blocks have a residual skip connection. Defaults to True.
        adaptive_activation (bool, optional): must be set to ``True`` if the activation function is adaptive. Defaults to False.
        conditioning (str, optional): the manner in which the prior bounds are provided as input to the network. Defaults to 'concat'.
    """
    def __init__(self,
                 embedding_net_type: str,  # 'conv', 'fno'
                 embedding_net_kwargs: dict,
                 pretrained_embedding_net: str = None,
                 dim_out: int = 8,
                 dim_conditioning_params: int = 0,
                 layer_width: int = 512,
                 num_blocks: int = 4,
                 repeats_per_block: int = 2,
                 mlp_activation: str = 'gelu',
                 use_batch_norm: bool = True,
                 use_layer_norm: bool = False,
                 dropout_rate: float = 0.0,
                 tanh_output: bool = False,
                 use_selu_init: bool = False,
                 residual: bool = True,
                 adaptive_activation: bool = False,
                 conditioning: str = 'concat',
                 concat_condition_first_layer: bool = True):
        super().__init__()

        self.conditioning = conditioning
        self.dim_prior_bounds = 2 * dim_out
        self.dim_conditioning_params = dim_conditioning_params
        self.tanh_output = tanh_output

        if embedding_net_type == 'conv':
            self.embedding_net = ConvEncoder(**embedding_net_kwargs)
        elif embedding_net_type == 'fno':
            self.embedding_net = FnoEncoder(**embedding_net_kwargs)
        elif embedding_net_type == 'integral_conv':
            self.embedding_net = IntegralConvEmbedding(**embedding_net_kwargs)
        elif embedding_net_type == 'no_embedding_net':
            self.embedding_net = nn.Identity()
        else:
            raise ValueError(f"Unsupported embedding_net_type: {embedding_net_type}")

        self.dim_embedding = embedding_net_kwargs['dim_embedding']

        if conditioning == 'concat':
            dim_mlp_in = self.dim_embedding + self.dim_prior_bounds + self.dim_conditioning_params
            dim_condition = 0
        elif conditioning == 'glu' or conditioning == 'film':
            dim_mlp_in = self.dim_embedding
            dim_condition = self.dim_prior_bounds + self.dim_conditioning_params
        else:
            raise NotImplementedError(f"Conditioning type '{conditioning}' is not supported.")

        self.mlp = ResidualMLP(
            dim_in=dim_mlp_in,
            dim_out=dim_out,
            dim_condition=dim_condition,
            layer_width=layer_width,
            num_blocks=num_blocks,
            repeats_per_block=repeats_per_block,
            activation=mlp_activation,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm,
            dropout_rate=dropout_rate,
            residual=residual,
            adaptive_activation=adaptive_activation,
            conditioning=conditioning,
            concat_condition_first_layer=concat_condition_first_layer,
        )

        if use_selu_init and embedding_net_kwargs.get('activation', None) == 'selu':
            self.embedding_net.apply(selu_init)
        if use_selu_init and mlp_activation == 'selu':
            self.mlp.apply(selu_init)

        if pretrained_embedding_net:
            self.embedding_net.load_weights(pretrained_embedding_net)


    def forward(self, curves, bounds, q_values=None, conditioning_params=None, key_padding_mask=None, unscaled_q_values=None):
        """
        Args:
            scaled_curves (torch.Tensor): Input tensor of shape [batch_size, n_points] or [batch_size, n_channels, n_points].
            scaled_bounds (torch.Tensor): Tensor representing prior bounds, shape [batch_size, 2*n_params].
            scaled_q_values (torch.Tensor, optional): Tensor of shape [batch_size, n_points].
            scaled_conditioning_params (torch.Tensor, optional): Additional parameters for conditioning, shape [batch_size, ...].
        """

        if curves.dim() == 2:
            curves = curves.unsqueeze(1)

        additional_channels = []
        if q_values is not None and not isinstance(self.embedding_net, IntegralConvEmbedding):
            additional_channels.append(q_values.unsqueeze(1))

        if additional_channels:
            curves = torch.cat([curves] + additional_channels, dim=1)  # [batch_size, n_channels, n_points]

        if isinstance(self.embedding_net, IntegralConvEmbedding):
            x = self.embedding_net(q=unscaled_q_values.float(), y=curves.permute(0, 2, 1), drop_mask=key_padding_mask)
        else:
            x = self.embedding_net(curves)

        if self.conditioning == 'concat':
            x = torch.cat([x, bounds] + ([conditioning_params] if conditioning_params is not None else []), dim=-1)
            x = self.mlp(x)

        elif self.conditioning in ['glu', 'film']:
            condition = torch.cat([bounds] + ([conditioning_params] if conditioning_params is not None else []), dim=-1)
            x = self.mlp(x, condition=condition)

        else:
            raise NotImplementedError(f"Conditioning type {self.conditioning} not recognized.")

        if self.tanh_output:
            x = torch.tanh(x)

        return x

class NetworkWithPriorsConvEmb(NetworkWithPriors):
    """Wrapper for back-compatibility with previous versions of the package"""
    def __init__(self, **kwargs):
        embedding_net_kwargs = {
            'in_channels': kwargs.pop('in_channels', 1),
            'hidden_channels': kwargs.pop('hidden_channels', [32, 64, 128, 256, 512]),
            'dim_embedding': kwargs.pop('dim_embedding', 128),
            'dim_avpool': kwargs.pop('dim_avpool', 1),
            'activation': kwargs.pop('embedding_net_activation', 'gelu'),
            'use_batch_norm': kwargs.pop('use_batch_norm', False),
        }

        super().__init__(
            embedding_net_type='conv',
            embedding_net_kwargs=embedding_net_kwargs,
            **kwargs
        )

class NetworkWithPriorsFnoEmb(NetworkWithPriors):
    """Wrapper for back-compatibility with previous versions of the package"""
    def __init__(self, **kwargs):
        embedding_net_kwargs = {
            'in_channels': kwargs.pop('in_channels', 2),
            'dim_embedding': kwargs.pop('dim_embedding', 128),
            'modes': kwargs.pop('modes', 16),
            'width_fno': kwargs.pop('width_fno', 64),
            'n_fno_blocks': kwargs.pop('n_fno_blocks', 6),
            'activation': kwargs.pop('embedding_net_activation', 'gelu'),
            'fusion_self_attention': kwargs.pop('fusion_self_attention', False),
        }

        super().__init__(
            embedding_net_type='fno',
            embedding_net_kwargs=embedding_net_kwargs,
            **kwargs
        )

# class NetworkWithPriorsConvEmb(nn.Module):
#     """MLP network with 1D CNN embedding network

#     .. image:: ../documentation/FigureReflectometryNetwork.png
#         :width: 800px
#         :align: center

#     Args:
#         in_channels (int, optional): the number of input channels of the 1D CNN. Defaults to 1.
#         hidden_channels (tuple, optional): list with the number of channels for each layer of the 1D CNN. Defaults to (32, 64, 128, 256, 512).
#         dim_embedding (int, optional): the dimension of the embedding produced by the 1D CNN. Defaults to 128.
#         dim_avpool (int, optional): the type of activation function in the 1D CNN. Defaults to 1.
#         embedding_net_activation (str, optional): the type of activation function in the 1D CNN. Defaults to 'gelu'.
#         use_batch_norm (bool, optional): whether to use batch normalization (in both the 1D CNN and the MLP). Defaults to False.
#         dim_out (int, optional): the dimension of the output produced by the MLP. Defaults to 8.
#         layer_width (int, optional): the width of a linear layer in the MLP. Defaults to 512.
#         num_blocks (int, optional): the number of residual blocks in the MLP. Defaults to 4.
#         repeats_per_block (int, optional): the number of normalization/activation/linear repeats in a block. Defaults to 2.
#         mlp_activation (str, optional): the type of activation function in the MLP. Defaults to 'gelu'.
#         dropout_rate (float, optional): dropout rate for each block. Defaults to 0.0.
#         use_selu_init (bool, optional): whether to use the special weights initialization for the 'selu' activation function. Defaults to False.
#         pretrained_embedding_net (str, optional): the path to the weights of a pretrained embedding network. Defaults to None.
#         residual (bool, optional): whether the blocks have a residual skip connection. Defaults to True.
#         adaptive_activation (bool, optional): must be set to ``True`` if the activation function is adaptive. Defaults to False.
#         conditioning (str, optional): the manner in which the prior bounds are provided as input to the network. Defaults to 'concat'.
#     """
#     def __init__(self,
#                  in_channels: int = 1,
#                  hidden_channels: tuple = (32, 64, 128, 256, 512),
#                  dim_embedding: int = 128,
#                  dim_avpool: int = 1,
#                  embedding_net_activation: str = 'gelu',
#                  use_batch_norm: bool = False,
#                  dim_out: int = 8,
#                  layer_width: int = 512,
#                  num_blocks: int = 4,
#                  repeats_per_block: int = 2,
#                  mlp_activation: str = 'gelu',
#                  dropout_rate: float = 0.0,
#                  use_selu_init: bool = False,
#                  pretrained_embedding_net: str = None,
#                  residual: bool = True,
#                  adaptive_activation: bool = False,
#                  conditioning: str = 'concat',
#                  ):
#         super().__init__()

#         self.in_channels = in_channels
#         self.conditioning = conditioning

#         self.embedding_net = ConvEncoder(
#             in_channels=in_channels,
#             hidden_channels=hidden_channels,
#             dim_latent=dim_embedding,
#             dim_avpool=dim_avpool,
#             use_batch_norm=use_batch_norm,
#             activation=embedding_net_activation
#         )

#         self.dim_prior_bounds = 2 * dim_out

#         if conditioning == 'concat':
#             dim_mlp_in = dim_embedding + self.dim_prior_bounds
#             dim_condition = 0
#         elif conditioning == 'glu' or conditioning == 'film':
#             dim_mlp_in = dim_embedding
#             dim_condition = self.dim_prior_bounds
#         else:
#             raise NotImplementedError

#         self.mlp = ResidualMLP(
#             dim_in=dim_mlp_in,
#             dim_out=dim_out,
#             dim_condition=dim_condition,
#             layer_width=layer_width,
#             num_blocks=num_blocks,
#             repeats_per_block=repeats_per_block,
#             activation=mlp_activation,
#             use_batch_norm=use_batch_norm,
#             dropout_rate=dropout_rate,
#             residual=residual,
#             adaptive_activation=adaptive_activation,
#             conditioning=conditioning,
#         )

#         if use_selu_init and embedding_net_activation == 'selu':
#             self.embedding_net.apply(selu_init)

#         if use_selu_init and mlp_activation == 'selu':
#             self.mlp.apply(selu_init)

#         if pretrained_embedding_net:
#             self.embedding_net.load_weights(pretrained_embedding_net)


#     def forward(self, curves: Tensor, bounds: Tensor, q_values: Optional[Tensor] = None):
#         """
#         Args:
#             curves (Tensor): reflectivity curves
#             bounds (Tensor): prior bounds
#             q_values (Tensor, optional): q values. Defaults to None.

#         Returns:
#             Tensor: prediction
#         """
#         if q_values is not None:
#             curves = torch.cat([curves[:, None, :], q_values[:, None, :]], dim=1)

#         if self.conditioning == 'concat':
#             x = torch.cat([self.embedding_net(curves), bounds], dim=-1)
#             x = self.mlp(x)

#         elif self.conditioning == 'glu' or self.conditioning == 'film':
#             x = self.mlp(self.embedding_net(curves), condition=bounds)

#         return x


# class NetworkWithPriorsFnoEmb(nn.Module):
#     """MLP network with FNO embedding network

#     Args:
#         in_channels (int, optional): the number of input channels to the FNO-based embedding network. Defaults to 2.
#         dim_embedding (int, optional): the dimension of the embedding produced by the FNO. Defaults to 128.
#         modes (int, optional): the number of Fourier modes that are utilized. Defaults to 16.
#         width_fno (int, optional): the number of channels in the FNO blocks. Defaults to 64.
#         embedding_net_activation (str, optional): the type of activation function in the embedding network. Defaults to 'gelu'.
#         n_fno_blocks (int, optional): the number of FNO blocks. Defaults to 6.
#         fusion_self_attention (bool, optional): if ``True`` a fusion layer is used after the FNO blocks to produce the final output. Defaults to False.
#         dim_out (int, optional): the dimension of the output produced by the MLP. Defaults to 8.
#         layer_width (int, optional): the width of a linear layer in the MLP. Defaults to 512.
#         num_blocks (int, optional): the number of residual blocks in the MLP. Defaults to 4.
#         repeats_per_block (int, optional): the number of normalization/activation/linear repeats in a block. Defaults to 2.
#         use_batch_norm (bool, optional): whether to use batch normalization (only in the MLP). Defaults to False.
#         mlp_activation (str, optional): the type of activation function in the MLP. Defaults to 'gelu'.
#         dropout_rate (float, optional): dropout rate for each block. Defaults to 0.0.
#         use_selu_init (bool, optional): whether to use the special weights initialization for the 'selu' activation function. Defaults to False.
#         residual (bool, optional): whether the blocks have a residual skip connection. Defaults to True.
#         adaptive_activation (bool, optional): must be set to ``True`` if the activation function is adaptive. Defaults to False.
#         conditioning (str, optional): the manner in which the prior bounds are provided as input to the network. Defaults to 'concat'.
#     """
#     def __init__(self,
#                  in_channels: int = 2,
#                  dim_embedding: int = 128,
#                  modes: int = 16,
#                  width_fno: int = 64,
#                  embedding_net_activation: str = 'gelu',
#                  n_fno_blocks : int = 6,
#                  fusion_self_attention: bool = False,
#                  dim_out: int = 8,
#                  layer_width: int = 512,
#                  num_blocks: int = 4,
#                  repeats_per_block: int = 2,
#                  use_batch_norm: bool = False,
#                  mlp_activation: str = 'gelu',
#                  dropout_rate: float = 0.0,
#                  use_selu_init: bool = False,
#                  residual: bool = True,
#                  adaptive_activation: bool = False,
#                  conditioning: str = 'concat',
#                  ):
#         super().__init__()

#         self.conditioning = conditioning

#         self.embedding_net = FnoEncoder(
#             ch_in=in_channels,
#             dim_embedding=dim_embedding,
#             modes=modes,
#             width_fno=width_fno,
#             n_fno_blocks=n_fno_blocks,
#             activation=embedding_net_activation,
#             fusion_self_attention=fusion_self_attention
#         )

#         self.dim_prior_bounds = 2 * dim_out

#         if conditioning == 'concat':
#             dim_mlp_in = dim_embedding + self.dim_prior_bounds
#             dim_condition = 0
#         elif conditioning == 'glu' or conditioning == 'film':
#             dim_mlp_in = dim_embedding
#             dim_condition = self.dim_prior_bounds
#         else:
#             raise NotImplementedError

#         self.mlp = ResidualMLP(
#             dim_in=dim_mlp_in,
#             dim_out=dim_out,
#             dim_condition=dim_condition,
#             layer_width=layer_width,
#             num_blocks=num_blocks,
#             repeats_per_block=repeats_per_block,
#             activation=mlp_activation,
#             use_batch_norm=use_batch_norm,
#             dropout_rate=dropout_rate,
#             residual=residual,
#             adaptive_activation=adaptive_activation,
#             conditioning=conditioning,
#         )

#         if use_selu_init and embedding_net_activation == 'selu':
#             self.FnoEncoder.apply(selu_init)

#         if use_selu_init and mlp_activation == 'selu':
#             self.mlp.apply(selu_init)


#     def forward(self, curves: Tensor, bounds: Tensor, q_values: Optional[Tensor] =None):
#         """
#         Args:
#             curves (Tensor): reflectivity curves
#             bounds (Tensor): prior bounds
#             q_values (Tensor, optional): q values. Defaults to None.

#         Returns:
#             Tensor: prediction
#         """
#         if curves.dim() < 3:
#             curves = curves[:, None, :]
#         if q_values is not None:
#             curves = torch.cat([curves, q_values[:, None, :]], dim=1)

#         if self.conditioning == 'concat':
#             x = torch.cat([self.embedding_net(curves), bounds], dim=-1)
#             x = self.mlp(x)

#         elif self.conditioning == 'glu' or self.conditioning == 'film':
#             x = self.mlp(self.embedding_net(curves), condition=bounds)

#         return x



def selu_init(m):
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        m.weight.data.normal_(0.0, 0.5 / math.sqrt(m.weight.numel()))
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm1d):
        size = m.weight.size()
        fan_in = size[0]

        m.weight.data.normal_(0.0, 1.0 / math.sqrt(fan_in))
        m.bias.data.fill_(0)
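For orientation, below is a minimal usage sketch of the `NetworkWithPriors` class shown in the hunk above. It assumes the class is importable from `reflectorch.models.networks.mlp_networks` (as the file listing suggests) and that `ConvEncoder` accepts the keyword arguments that the `NetworkWithPriorsConvEmb` wrapper forwards to it; the configuration values, tensor shapes, and tensor contents are illustrative only and are not taken from the package.

```python
import torch
from reflectorch.models.networks.mlp_networks import NetworkWithPriors  # assumed import path

# Illustrative configuration: the embedding kwargs mirror those forwarded by
# the NetworkWithPriorsConvEmb back-compatibility wrapper in this file.
model = NetworkWithPriors(
    embedding_net_type='conv',
    embedding_net_kwargs={
        'in_channels': 2,                          # curve channel + q-values channel
        'hidden_channels': [32, 64, 128, 256, 512],
        'dim_embedding': 128,
        'dim_avpool': 1,
        'activation': 'gelu',
        'use_batch_norm': False,
    },
    dim_out=8,                                     # number of predicted parameters
    conditioning='concat',                         # prior bounds concatenated to the embedding
)

batch_size, n_points = 16, 128
curves = torch.rand(batch_size, n_points)          # scaled reflectivity curves
q_values = torch.rand(batch_size, n_points)        # scaled q values (added as a second channel)
bounds = torch.rand(batch_size, 2 * 8)             # scaled prior bounds, a (min, max) pair per parameter

predictions = model(curves, bounds, q_values=q_values)
print(predictions.shape)  # torch.Size([16, 8])
```

With `conditioning='concat'`, the scaled prior bounds (and any extra conditioning parameters) are appended to the curve embedding before the MLP, so the MLP input width is `dim_embedding + 2 * dim_out + dim_conditioning_params`; with `'glu'` or `'film'` they are instead passed to `ResidualMLP` as a separate condition vector.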