reflectorch 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (39)
  1. reflectorch/data_generation/__init__.py +2 -0
  2. reflectorch/data_generation/dataset.py +27 -7
  3. reflectorch/data_generation/noise.py +115 -9
  4. reflectorch/data_generation/priors/parametric_models.py +90 -15
  5. reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
  6. reflectorch/data_generation/priors/sampler_strategies.py +67 -3
  7. reflectorch/data_generation/q_generator.py +31 -11
  8. reflectorch/data_generation/reflectivity/__init__.py +56 -14
  9. reflectorch/data_generation/reflectivity/abeles.py +31 -16
  10. reflectorch/data_generation/reflectivity/kinematical.py +5 -6
  11. reflectorch/data_generation/reflectivity/memory_eff.py +1 -1
  12. reflectorch/data_generation/reflectivity/smearing.py +25 -10
  13. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  14. reflectorch/data_generation/smearing.py +42 -11
  15. reflectorch/data_generation/utils.py +92 -18
  16. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  17. reflectorch/inference/inference_model.py +220 -105
  18. reflectorch/inference/plotting.py +98 -0
  19. reflectorch/inference/scipy_fitter.py +84 -7
  20. reflectorch/ml/__init__.py +2 -0
  21. reflectorch/ml/basic_trainer.py +18 -6
  22. reflectorch/ml/callbacks.py +5 -4
  23. reflectorch/ml/loggers.py +25 -0
  24. reflectorch/ml/schedulers.py +116 -0
  25. reflectorch/ml/trainers.py +122 -23
  26. reflectorch/models/__init__.py +1 -1
  27. reflectorch/models/encoders/__init__.py +0 -2
  28. reflectorch/models/encoders/conv_encoder.py +54 -40
  29. reflectorch/models/encoders/fno.py +23 -16
  30. reflectorch/models/networks/__init__.py +2 -0
  31. reflectorch/models/networks/mlp_networks.py +324 -152
  32. reflectorch/models/networks/residual_net.py +31 -5
  33. reflectorch/runs/train.py +0 -1
  34. reflectorch/runs/utils.py +43 -9
  35. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/METADATA +19 -17
  36. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/RECORD +39 -36
  37. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/WHEEL +1 -1
  38. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info/licenses}/LICENSE.txt +0 -0
  39. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/top_level.txt +0 -0
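
The headline change in this release is in reflectorch/models/networks/mlp_networks.py (diffed in full below): the separate NetworkWithPriorsConvEmb and NetworkWithPriorsFnoEmb classes are merged into a single NetworkWithPriors class that receives its embedding network through the new embedding_net_type and embedding_net_kwargs arguments, while the old class names remain as thin back-compatibility wrappers. A minimal construction sketch; the argument names come from the diff below, the concrete values are purely illustrative:

    from reflectorch.models.networks.mlp_networks import NetworkWithPriors

    model = NetworkWithPriors(
        embedding_net_type='conv',            # 'conv', 'fno', or 'no_embedding_net'
        embedding_net_kwargs=dict(
            in_channels=1,                    # 2 if scaled q values are stacked as a second input channel
            hidden_channels=(32, 64, 128, 256, 512),
            dim_embedding=128,                # read back by NetworkWithPriors to size the MLP input
            dim_avpool=1,
            activation='gelu',
            use_batch_norm=False,
        ),
        dim_out=8,                            # dimension of the MLP output (number of predicted parameters)
        dim_conditioning_params=1,            # e.g. the smearing coefficient dq/q
        conditioning='concat',                # prior bounds concatenated to the embedding
    )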
@@ -10,75 +10,77 @@ from reflectorch.models.encoders.conv_encoder import ConvEncoder
 from reflectorch.models.encoders.fno import FnoEncoder
 from reflectorch.models.activations import activation_by_name
 
-class NetworkWithPriorsConvEmb(nn.Module):
-    """MLP network with 1D CNN embedding network
+class NetworkWithPriors(nn.Module):
+    """MLP network with an embedding network
 
     .. image:: ../documentation/FigureReflectometryNetwork.png
         :width: 800px
         :align: center
 
     Args:
-        in_channels (int, optional): the number of input channels of the 1D CNN. Defaults to 1.
-        hidden_channels (tuple, optional): list with the number of channels for each layer of the 1D CNN. Defaults to (32, 64, 128, 256, 512).
-        dim_embedding (int, optional): the dimension of the embedding produced by the 1D CNN. Defaults to 128.
-        dim_avpool (int, optional): the type of activation function in the 1D CNN. Defaults to 1.
-        embedding_net_activation (str, optional): the type of activation function in the 1D CNN. Defaults to 'gelu'.
-        use_batch_norm (bool, optional): whether to use batch normalization (in both the 1D CNN and the MLP). Defaults to False.
+        embedding_net_type (str): the type of embedding network, either 'conv' or 'fno'.
+        embedding_net_kwargs (dict): dictionary containing the keyword arguments for the embedding network.
         dim_out (int, optional): the dimension of the output produced by the MLP. Defaults to 8.
+        dim_conditioning_params (int, optional): the dimension of other parameters the network is conditioned on (e.g. for the smearing coefficient dq/q)
         layer_width (int, optional): the width of a linear layer in the MLP. Defaults to 512.
         num_blocks (int, optional): the number of residual blocks in the MLP. Defaults to 4.
         repeats_per_block (int, optional): the number of normalization/activation/linear repeats in a block. Defaults to 2.
         mlp_activation (str, optional): the type of activation function in the MLP. Defaults to 'gelu'.
+        use_batch_norm (bool, optional): whether to use batch normalization in the MLP. Defaults to True.
+        use_layer_norm (bool, optional): whether to use layer normalization in the MLP (if use_batch_norm is False). Defaults to False.
         dropout_rate (float, optional): dropout rate for each block. Defaults to 0.0.
+        tanh_output (bool, optional): whether to apply a tanh function to the output. Defaults to False.
        use_selu_init (bool, optional): whether to use the special weights initialization for the 'selu' activation function. Defaults to False.
         pretrained_embedding_net (str, optional): the path to the weights of a pretrained embedding network. Defaults to None.
         residual (bool, optional): whether the blocks have a residual skip connection. Defaults to True.
         adaptive_activation (bool, optional): must be set to ``True`` if the activation function is adaptive. Defaults to False.
         conditioning (str, optional): the manner in which the prior bounds are provided as input to the network. Defaults to 'concat'.
-    """
+    """
     def __init__(self,
-                 in_channels: int = 1,
-                 hidden_channels: tuple = (32, 64, 128, 256, 512),
-                 dim_embedding: int = 128,
-                 dim_avpool: int = 1,
-                 embedding_net_activation: str = 'gelu',
-                 use_batch_norm: bool = False,
+                 embedding_net_type: str, # 'conv', 'fno'
+                 embedding_net_kwargs: dict,
+                 pretrained_embedding_net: str = None,
                  dim_out: int = 8,
+                 dim_conditioning_params: int = 0,
                  layer_width: int = 512,
                  num_blocks: int = 4,
                  repeats_per_block: int = 2,
                  mlp_activation: str = 'gelu',
+                 use_batch_norm: bool = True,
+                 use_layer_norm: bool = False,
                  dropout_rate: float = 0.0,
+                 tanh_output: bool = False,
                  use_selu_init: bool = False,
-                 pretrained_embedding_net: str = None,
                  residual: bool = True,
                  adaptive_activation: bool = False,
                  conditioning: str = 'concat',
-                 ):
+                 concat_condition_first_layer: bool = True):
         super().__init__()
 
-        self.in_channels = in_channels
         self.conditioning = conditioning
-
-        self.embedding_net = ConvEncoder(
-            in_channels=in_channels,
-            hidden_channels=hidden_channels,
-            dim_latent=dim_embedding,
-            dim_avpool=dim_avpool,
-            use_batch_norm=use_batch_norm,
-            activation=embedding_net_activation
-        )
-
         self.dim_prior_bounds = 2 * dim_out
+        self.dim_conditioning_params = dim_conditioning_params
+        self.tanh_output = tanh_output
+
+        if embedding_net_type == 'conv':
+            self.embedding_net = ConvEncoder(**embedding_net_kwargs)
+        elif embedding_net_type == 'fno':
+            self.embedding_net = FnoEncoder(**embedding_net_kwargs)
+        elif embedding_net_type == 'no_embedding_net':
+            self.embedding_net = nn.Identity()
+        else:
+            raise ValueError(f"Unsupported embedding_net_type: {embedding_net_type}")
+
+        self.dim_embedding = embedding_net_kwargs['dim_embedding']
 
         if conditioning == 'concat':
-            dim_mlp_in = dim_embedding + self.dim_prior_bounds
+            dim_mlp_in = self.dim_embedding + self.dim_prior_bounds + self.dim_conditioning_params
             dim_condition = 0
         elif conditioning == 'glu' or conditioning == 'film':
-            dim_mlp_in = dim_embedding
-            dim_condition = self.dim_prior_bounds
+            dim_mlp_in = self.dim_embedding
+            dim_condition = self.dim_prior_bounds + self.dim_conditioning_params
         else:
-            raise NotImplementedError
+            raise NotImplementedError(f"Conditioning type '{conditioning}' is not supported.")
 
         self.mlp = ResidualMLP(
             dim_in=dim_mlp_in,
@@ -89,15 +91,16 @@ class NetworkWithPriorsConvEmb(nn.Module):
             repeats_per_block=repeats_per_block,
             activation=mlp_activation,
             use_batch_norm=use_batch_norm,
+            use_layer_norm=use_layer_norm,
             dropout_rate=dropout_rate,
             residual=residual,
             adaptive_activation=adaptive_activation,
             conditioning=conditioning,
+            concat_condition_first_layer=concat_condition_first_layer,
         )
 
-        if use_selu_init and embedding_net_activation == 'selu':
+        if use_selu_init and embedding_net_kwargs.get('activation', None) == 'selu':
             self.embedding_net.apply(selu_init)
-
         if use_selu_init and mlp_activation == 'selu':
             self.mlp.apply(selu_init)
 
@@ -105,142 +108,311 @@
             self.embedding_net.load_weights(pretrained_embedding_net)
 
 
-    def forward(self, curves: Tensor, bounds: Tensor, q_values: Optional[Tensor] = None):
+    def forward(self, curves, bounds, q_values=None, conditioning_params=None):
         """
         Args:
-            curves (Tensor): reflectivity curves
-            bounds (Tensor): prior bounds
-            q_values (Tensor, optional): q values. Defaults to None.
-
-        Returns:
-            Tensor: prediction
+            scaled_curves (torch.Tensor): Input tensor of shape [batch_size, n_points] or [batch_size, n_channels, n_points].
+            scaled_bounds (torch.Tensor): Tensor representing prior bounds, shape [batch_size, 2*n_params].
+            scaled_q_values (torch.Tensor, optional): Tensor of shape [batch_size, n_points].
+            scaled_conditioning_params (torch.Tensor, optional): Additional parameters for conditioning, shape [batch_size, ...].
         """
+
+        if curves.dim() == 2:
+            curves = curves.unsqueeze(1)
+
+        additional_channels = []
         if q_values is not None:
-            curves = torch.cat([curves[:, None, :], q_values[:, None, :]], dim=1)
+            additional_channels.append(q_values.unsqueeze(1))
+
+        if additional_channels:
+            curves = torch.cat([curves] + additional_channels, dim=1) # [batch_size, n_channels, n_points]
+
+        x = self.embedding_net(curves)
 
         if self.conditioning == 'concat':
-            x = torch.cat([self.embedding_net(curves), bounds], dim=-1)
+            x = torch.cat([x, bounds] + ([conditioning_params] if conditioning_params is not None else []), dim=-1)
             x = self.mlp(x)
 
-        elif self.conditioning == 'glu' or self.conditioning == 'film':
-            x = self.mlp(self.embedding_net(curves), condition=bounds)
+        elif self.conditioning in ['glu', 'film']:
+            condition = torch.cat([bounds] + ([conditioning_params] if conditioning_params is not None else []), dim=-1)
+            x = self.mlp(x, condition=condition)
 
-        return x
-
-
-class NetworkWithPriorsFnoEmb(nn.Module):
-    """MLP network with FNO embedding network
+        else:
+            raise NotImplementedError(f"Conditioning type {self.conditioning} not recognized.")
 
-    Args:
-        in_channels (int, optional): the number of input channels to the FNO-based embedding network. Defaults to 2.
-        dim_embedding (int, optional): the dimension of the embedding produced by the FNO. Defaults to 128.
-        modes (int, optional): the number of Fourier modes that are utilized. Defaults to 16.
-        width_fno (int, optional): the number of channels in the FNO blocks. Defaults to 64.
-        embedding_net_activation (str, optional): the type of activation function in the embedding network. Defaults to 'gelu'.
-        n_fno_blocks (int, optional): the number of FNO blocks. Defaults to 6.
-        fusion_self_attention (bool, optional): if ``True`` a fusion layer is used after the FNO blocks to produce the final output. Defaults to False.
-        dim_out (int, optional): the dimension of the output produced by the MLP. Defaults to 8.
-        layer_width (int, optional): the width of a linear layer in the MLP. Defaults to 512.
-        num_blocks (int, optional): the number of residual blocks in the MLP. Defaults to 4.
-        repeats_per_block (int, optional): the number of normalization/activation/linear repeats in a block. Defaults to 2.
-        use_batch_norm (bool, optional): whether to use batch normalization (only in the MLP). Defaults to False.
-        mlp_activation (str, optional): the type of activation function in the MLP. Defaults to 'gelu'.
-        dropout_rate (float, optional): dropout rate for each block. Defaults to 0.0.
-        use_selu_init (bool, optional): whether to use the special weights initialization for the 'selu' activation function. Defaults to False.
-        residual (bool, optional): whether the blocks have a residual skip connection. Defaults to True.
-        adaptive_activation (bool, optional): must be set to ``True`` if the activation function is adaptive. Defaults to False.
-        conditioning (str, optional): the manner in which the prior bounds are provided as input to the network. Defaults to 'concat'.
-    """
-    def __init__(self,
-                 in_channels: int = 2,
-                 dim_embedding: int = 128,
-                 modes: int = 16,
-                 width_fno: int = 64,
-                 embedding_net_activation: str = 'gelu',
-                 n_fno_blocks : int = 6,
-                 fusion_self_attention: bool = False,
-                 dim_out: int = 8,
-                 layer_width: int = 512,
-                 num_blocks: int = 4,
-                 repeats_per_block: int = 2,
-                 use_batch_norm: bool = False,
-                 mlp_activation: str = 'gelu',
-                 dropout_rate: float = 0.0,
-                 use_selu_init: bool = False,
-                 residual: bool = True,
-                 adaptive_activation: bool = False,
-                 conditioning: str = 'concat',
-                 ):
-        super().__init__()
+        if self.tanh_output:
+            x = torch.tanh(x)
 
-        self.conditioning = conditioning
+        return x
 
-        self.embedding_net = FnoEncoder(
-            ch_in=in_channels,
-            dim_embedding=dim_embedding,
-            modes=modes,
-            width_fno=width_fno,
-            n_fno_blocks=n_fno_blocks,
-            activation=embedding_net_activation,
-            fusion_self_attention=fusion_self_attention
+class NetworkWithPriorsConvEmb(NetworkWithPriors):
+    """Wrapper for back-compatibility with previous versions of the package"""
+    def __init__(self, **kwargs):
+        embedding_net_kwargs = {
+            'in_channels': kwargs.pop('in_channels', 1),
+            'hidden_channels': kwargs.pop('hidden_channels', [32, 64, 128, 256, 512]),
+            'dim_embedding': kwargs.pop('dim_embedding', 128),
+            'dim_avpool': kwargs.pop('dim_avpool', 1),
+            'activation': kwargs.pop('embedding_net_activation', 'gelu'),
+            'use_batch_norm': kwargs.pop('use_batch_norm', False),
+        }
+
+        super().__init__(
+            embedding_net_type='conv',
+            embedding_net_kwargs=embedding_net_kwargs,
+            **kwargs
         )
 
-        self.dim_prior_bounds = 2 * dim_out
-
-        if conditioning == 'concat':
-            dim_mlp_in = dim_embedding + self.dim_prior_bounds
-            dim_condition = 0
-        elif conditioning == 'glu' or conditioning == 'film':
-            dim_mlp_in = dim_embedding
-            dim_condition = self.dim_prior_bounds
-        else:
-            raise NotImplementedError
-
-        self.mlp = ResidualMLP(
-            dim_in=dim_mlp_in,
-            dim_out=dim_out,
-            dim_condition=dim_condition,
-            layer_width=layer_width,
-            num_blocks=num_blocks,
-            repeats_per_block=repeats_per_block,
-            activation=mlp_activation,
-            use_batch_norm=use_batch_norm,
-            dropout_rate=dropout_rate,
-            residual=residual,
-            adaptive_activation=adaptive_activation,
-            conditioning=conditioning,
+class NetworkWithPriorsFnoEmb(NetworkWithPriors):
+    """Wrapper for back-compatibility with previous versions of the package"""
+    def __init__(self, **kwargs):
+        embedding_net_kwargs = {
+            'in_channels': kwargs.pop('in_channels', 2),
+            'dim_embedding': kwargs.pop('dim_embedding', 128),
+            'modes': kwargs.pop('modes', 16),
+            'width_fno': kwargs.pop('width_fno', 64),
+            'n_fno_blocks': kwargs.pop('n_fno_blocks', 6),
+            'activation': kwargs.pop('embedding_net_activation', 'gelu'),
+            'fusion_self_attention': kwargs.pop('fusion_self_attention', False),
+        }
+
+        super().__init__(
+            embedding_net_type='fno',
+            embedding_net_kwargs=embedding_net_kwargs,
+            **kwargs
         )
 
-        if use_selu_init and embedding_net_activation == 'selu':
-            self.FnoEncoder.apply(selu_init)
-
-        if use_selu_init and mlp_activation == 'selu':
-            self.mlp.apply(selu_init)
+# class NetworkWithPriorsConvEmb(nn.Module):
+#     """MLP network with 1D CNN embedding network
+
+#     .. image:: ../documentation/FigureReflectometryNetwork.png
+#         :width: 800px
+#         :align: center
+
+#     Args:
+#         in_channels (int, optional): the number of input channels of the 1D CNN. Defaults to 1.
+#         hidden_channels (tuple, optional): list with the number of channels for each layer of the 1D CNN. Defaults to (32, 64, 128, 256, 512).
+#         dim_embedding (int, optional): the dimension of the embedding produced by the 1D CNN. Defaults to 128.
+#         dim_avpool (int, optional): the type of activation function in the 1D CNN. Defaults to 1.
+#         embedding_net_activation (str, optional): the type of activation function in the 1D CNN. Defaults to 'gelu'.
+#         use_batch_norm (bool, optional): whether to use batch normalization (in both the 1D CNN and the MLP). Defaults to False.
+#         dim_out (int, optional): the dimension of the output produced by the MLP. Defaults to 8.
+#         layer_width (int, optional): the width of a linear layer in the MLP. Defaults to 512.
+#         num_blocks (int, optional): the number of residual blocks in the MLP. Defaults to 4.
+#         repeats_per_block (int, optional): the number of normalization/activation/linear repeats in a block. Defaults to 2.
+#         mlp_activation (str, optional): the type of activation function in the MLP. Defaults to 'gelu'.
+#         dropout_rate (float, optional): dropout rate for each block. Defaults to 0.0.
+#         use_selu_init (bool, optional): whether to use the special weights initialization for the 'selu' activation function. Defaults to False.
+#         pretrained_embedding_net (str, optional): the path to the weights of a pretrained embedding network. Defaults to None.
+#         residual (bool, optional): whether the blocks have a residual skip connection. Defaults to True.
+#         adaptive_activation (bool, optional): must be set to ``True`` if the activation function is adaptive. Defaults to False.
+#         conditioning (str, optional): the manner in which the prior bounds are provided as input to the network. Defaults to 'concat'.
+#     """
+#     def __init__(self,
+#                  in_channels: int = 1,
+#                  hidden_channels: tuple = (32, 64, 128, 256, 512),
+#                  dim_embedding: int = 128,
+#                  dim_avpool: int = 1,
+#                  embedding_net_activation: str = 'gelu',
+#                  use_batch_norm: bool = False,
+#                  dim_out: int = 8,
+#                  layer_width: int = 512,
+#                  num_blocks: int = 4,
+#                  repeats_per_block: int = 2,
+#                  mlp_activation: str = 'gelu',
+#                  dropout_rate: float = 0.0,
+#                  use_selu_init: bool = False,
+#                  pretrained_embedding_net: str = None,
+#                  residual: bool = True,
+#                  adaptive_activation: bool = False,
+#                  conditioning: str = 'concat',
+#                  ):
+#         super().__init__()
+
+#         self.in_channels = in_channels
+#         self.conditioning = conditioning
+
+#         self.embedding_net = ConvEncoder(
+#             in_channels=in_channels,
+#             hidden_channels=hidden_channels,
+#             dim_latent=dim_embedding,
+#             dim_avpool=dim_avpool,
+#             use_batch_norm=use_batch_norm,
+#             activation=embedding_net_activation
+#         )
+
+#         self.dim_prior_bounds = 2 * dim_out
+
+#         if conditioning == 'concat':
+#             dim_mlp_in = dim_embedding + self.dim_prior_bounds
+#             dim_condition = 0
+#         elif conditioning == 'glu' or conditioning == 'film':
+#             dim_mlp_in = dim_embedding
+#             dim_condition = self.dim_prior_bounds
+#         else:
+#             raise NotImplementedError
+
+#         self.mlp = ResidualMLP(
+#             dim_in=dim_mlp_in,
+#             dim_out=dim_out,
+#             dim_condition=dim_condition,
+#             layer_width=layer_width,
+#             num_blocks=num_blocks,
+#             repeats_per_block=repeats_per_block,
+#             activation=mlp_activation,
+#             use_batch_norm=use_batch_norm,
+#             dropout_rate=dropout_rate,
+#             residual=residual,
+#             adaptive_activation=adaptive_activation,
+#             conditioning=conditioning,
+#         )
+
+#         if use_selu_init and embedding_net_activation == 'selu':
+#             self.embedding_net.apply(selu_init)
+
+#         if use_selu_init and mlp_activation == 'selu':
+#             self.mlp.apply(selu_init)
+
+#         if pretrained_embedding_net:
+#             self.embedding_net.load_weights(pretrained_embedding_net)
+
+
+#     def forward(self, curves: Tensor, bounds: Tensor, q_values: Optional[Tensor] = None):
+#         """
+#         Args:
+#             curves (Tensor): reflectivity curves
+#             bounds (Tensor): prior bounds
+#             q_values (Tensor, optional): q values. Defaults to None.
+
+#         Returns:
+#             Tensor: prediction
+#         """
+#         if q_values is not None:
+#             curves = torch.cat([curves[:, None, :], q_values[:, None, :]], dim=1)
+
+#         if self.conditioning == 'concat':
+#             x = torch.cat([self.embedding_net(curves), bounds], dim=-1)
+#             x = self.mlp(x)
+
+#         elif self.conditioning == 'glu' or self.conditioning == 'film':
+#             x = self.mlp(self.embedding_net(curves), condition=bounds)
+
+#         return x
+
+
+# class NetworkWithPriorsFnoEmb(nn.Module):
+#     """MLP network with FNO embedding network
+
+#     Args:
+#         in_channels (int, optional): the number of input channels to the FNO-based embedding network. Defaults to 2.
+#         dim_embedding (int, optional): the dimension of the embedding produced by the FNO. Defaults to 128.
+#         modes (int, optional): the number of Fourier modes that are utilized. Defaults to 16.
+#         width_fno (int, optional): the number of channels in the FNO blocks. Defaults to 64.
+#         embedding_net_activation (str, optional): the type of activation function in the embedding network. Defaults to 'gelu'.
+#         n_fno_blocks (int, optional): the number of FNO blocks. Defaults to 6.
+#         fusion_self_attention (bool, optional): if ``True`` a fusion layer is used after the FNO blocks to produce the final output. Defaults to False.
+#         dim_out (int, optional): the dimension of the output produced by the MLP. Defaults to 8.
+#         layer_width (int, optional): the width of a linear layer in the MLP. Defaults to 512.
+#         num_blocks (int, optional): the number of residual blocks in the MLP. Defaults to 4.
+#         repeats_per_block (int, optional): the number of normalization/activation/linear repeats in a block. Defaults to 2.
+#         use_batch_norm (bool, optional): whether to use batch normalization (only in the MLP). Defaults to False.
+#         mlp_activation (str, optional): the type of activation function in the MLP. Defaults to 'gelu'.
+#         dropout_rate (float, optional): dropout rate for each block. Defaults to 0.0.
+#         use_selu_init (bool, optional): whether to use the special weights initialization for the 'selu' activation function. Defaults to False.
+#         residual (bool, optional): whether the blocks have a residual skip connection. Defaults to True.
+#         adaptive_activation (bool, optional): must be set to ``True`` if the activation function is adaptive. Defaults to False.
+#         conditioning (str, optional): the manner in which the prior bounds are provided as input to the network. Defaults to 'concat'.
+#     """
+#     def __init__(self,
+#                  in_channels: int = 2,
+#                  dim_embedding: int = 128,
+#                  modes: int = 16,
+#                  width_fno: int = 64,
+#                  embedding_net_activation: str = 'gelu',
+#                  n_fno_blocks : int = 6,
+#                  fusion_self_attention: bool = False,
+#                  dim_out: int = 8,
+#                  layer_width: int = 512,
+#                  num_blocks: int = 4,
+#                  repeats_per_block: int = 2,
+#                  use_batch_norm: bool = False,
+#                  mlp_activation: str = 'gelu',
+#                  dropout_rate: float = 0.0,
+#                  use_selu_init: bool = False,
+#                  residual: bool = True,
+#                  adaptive_activation: bool = False,
+#                  conditioning: str = 'concat',
+#                  ):
+#         super().__init__()
+
+#         self.conditioning = conditioning
+
+#         self.embedding_net = FnoEncoder(
+#             ch_in=in_channels,
+#             dim_embedding=dim_embedding,
+#             modes=modes,
+#             width_fno=width_fno,
+#             n_fno_blocks=n_fno_blocks,
+#             activation=embedding_net_activation,
+#             fusion_self_attention=fusion_self_attention
+#         )
+
+#         self.dim_prior_bounds = 2 * dim_out
+
+#         if conditioning == 'concat':
+#             dim_mlp_in = dim_embedding + self.dim_prior_bounds
+#             dim_condition = 0
+#         elif conditioning == 'glu' or conditioning == 'film':
+#             dim_mlp_in = dim_embedding
+#             dim_condition = self.dim_prior_bounds
+#         else:
+#             raise NotImplementedError
+
+#         self.mlp = ResidualMLP(
+#             dim_in=dim_mlp_in,
+#             dim_out=dim_out,
+#             dim_condition=dim_condition,
+#             layer_width=layer_width,
+#             num_blocks=num_blocks,
+#             repeats_per_block=repeats_per_block,
+#             activation=mlp_activation,
+#             use_batch_norm=use_batch_norm,
+#             dropout_rate=dropout_rate,
+#             residual=residual,
+#             adaptive_activation=adaptive_activation,
+#             conditioning=conditioning,
+#         )
+
+#         if use_selu_init and embedding_net_activation == 'selu':
+#             self.FnoEncoder.apply(selu_init)
+
+#         if use_selu_init and mlp_activation == 'selu':
+#             self.mlp.apply(selu_init)
 
 
-    def forward(self, curves: Tensor, bounds: Tensor, q_values: Optional[Tensor] =None):
-        """
-        Args:
-            curves (Tensor): reflectivity curves
-            bounds (Tensor): prior bounds
-            q_values (Tensor, optional): q values. Defaults to None.
-
-        Returns:
-            Tensor: prediction
-        """
-        if curves.dim() < 3:
-            curves = curves[:, None, :]
-        if q_values is not None:
-            curves = torch.cat([curves, q_values[:, None, :]], dim=1)
-
-        if self.conditioning == 'concat':
-            x = torch.cat([self.embedding_net(curves), bounds], dim=-1)
-            x = self.mlp(x)
-
-        elif self.conditioning == 'glu' or self.conditioning == 'film':
-            x = self.mlp(self.embedding_net(curves), condition=bounds)
-
-        return x
+#     def forward(self, curves: Tensor, bounds: Tensor, q_values: Optional[Tensor] =None):
+#         """
+#         Args:
+#             curves (Tensor): reflectivity curves
+#             bounds (Tensor): prior bounds
+#             q_values (Tensor, optional): q values. Defaults to None.
+
+#         Returns:
+#             Tensor: prediction
+#         """
+#         if curves.dim() < 3:
+#             curves = curves[:, None, :]
+#         if q_values is not None:
+#             curves = torch.cat([curves, q_values[:, None, :]], dim=1)
+
+#         if self.conditioning == 'concat':
+#             x = torch.cat([self.embedding_net(curves), bounds], dim=-1)
+#             x = self.mlp(x)
+
+#         elif self.conditioning == 'glu' or self.conditioning == 'film':
+#             x = self.mlp(self.embedding_net(curves), condition=bounds)
+
+#         return x
 
 
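
The new forward() accepts an additional conditioning_params tensor, and the tensor shapes follow the updated docstring in the hunk above. A usage sketch, assuming a NetworkWithPriors instance like the one sketched after the file list (q_values is omitted here, since stacking it as a second channel would require a two-channel embedding network):

    import torch

    batch_size, n_points, n_params = 16, 128, 8

    curves = torch.rand(batch_size, n_points)         # scaled reflectivity curves
    bounds = torch.rand(batch_size, 2 * n_params)     # scaled prior bounds (min and max per parameter)
    dq_q = torch.rand(batch_size, 1)                  # conditioning parameter, e.g. dq/q (dim_conditioning_params=1)

    prediction = model(curves, bounds, conditioning_params=dq_q)
    print(prediction.shape)                           # expected: torch.Size([16, 8])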