reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch has been flagged as potentially problematic; see the registry's release page for details.

Files changed (96):
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -128
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -280
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -223
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -1374
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +36 -36
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +523 -516
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -19
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -262
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -200
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -15
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -19
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -434
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -404
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +97 -97
  91. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  94. reflectorch-1.4.0.dist-info/RECORD +0 -88
  95. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
  96. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,311 +1,311 @@
1
- from typing import Tuple
2
-
3
- import torch
4
- from torch import Tensor
5
-
6
# Public names exported by ``from ... import *``.
__all__ = ["MULTILAYER_MODELS", "MultilayerModel"]
10
-
11
-
12
class MultilayerModel:
    """Base class for parametrizations of repeating multilayer structures.

    Subclasses define ``NAME`` (the key used in ``MULTILAYER_MODELS``) and
    ``PARAMETER_NAMES`` (the column order of a parametrized-model tensor), and
    override :meth:`to_standard_params`.
    """

    NAME: str = ''
    PARAMETER_NAMES: Tuple[str, ...]

    def __init__(self, max_num_layers: int):
        # The subclasses in this module forward this value as ``d_full_rel_max``
        # to their conversion function, i.e. the number of block repetitions.
        self.max_num_layers = max_num_layers

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Convert a parametrized-model tensor to a dict of layer parameters.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def from_standard_params(self, params: dict) -> Tensor:
        """Presumably the inverse of :meth:`to_standard_params`.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError
24
-
25
-
26
class BasicMultilayerModel1(MultilayerModel):
    """``repeating_multilayer_v1``: a repeated two-sublayer block with a smooth cutoff.

    Columns of the parametrized-model tensor follow ``PARAMETER_NAMES``; the
    conversion itself is done by :func:`multilayer_model1`.
    """

    NAME = 'repeating_multilayer_v1'

    PARAMETER_NAMES = (
        # effective number of repetitions and width of its sigmoid cutoff
        "d_full_rel", "rel_sigmas",
        # repeated block: thickness, roughness relative to d_block, base SLD, SLD offset
        "d_block", "s_block_rel", "r_block", "dr",
        # extra layer placed after the repeated stack
        "d3_rel", "s3_rel", "r3",
        # SiO2 layer and Si
        "d_sio2", "s_sio2", "s_si", "r_sio2", "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Return thicknesses, roughnesses and SLDs for a batch of parameter rows."""
        return multilayer_model1(parametrized_model, d_full_rel_max=self.max_num_layers)
48
-
49
-
50
class BasicMultilayerModel2(MultilayerModel):
    """``repeating_multilayer_v2``: like v1, plus a sigmoid fade of the SLD offset ``dr``.

    Columns of the parametrized-model tensor follow ``PARAMETER_NAMES``; the
    conversion itself is done by :func:`multilayer_model2`.
    """

    NAME = 'repeating_multilayer_v2'

    PARAMETER_NAMES = (
        # effective number of repetitions and width of its sigmoid cutoff
        "d_full_rel", "rel_sigmas",
        # position and width of the sigmoid that fades the SLD offset dr
        "dr_sigmoid_rel_pos", "dr_sigmoid_rel_width",
        # repeated block: thickness, roughness relative to d_block, base SLD, SLD offset
        "d_block", "s_block_rel", "r_block", "dr",
        # extra layer placed after the repeated stack
        "d3_rel", "s3_rel", "r3",
        # SiO2 layer and Si
        "d_sio2", "s_sio2", "s_si", "r_sio2", "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Return thicknesses, roughnesses and SLDs for a batch of parameter rows."""
        return multilayer_model2(parametrized_model, d_full_rel_max=self.max_num_layers)
74
-
75
-
76
class BasicMultilayerModel3(MultilayerModel):
    """``repeating_multilayer_v3``: like v2, with unequal sublayer thicknesses.

    Columns of the parametrized-model tensor follow ``PARAMETER_NAMES``; the
    conversion itself is done by :func:`multilayer_model3`.
    """

    NAME = 'repeating_multilayer_v3'

    PARAMETER_NAMES = (
        # effective number of repetitions and width of its sigmoid cutoff
        "d_full_rel", "rel_sigmas",
        # offset and width of the sigmoid that modulates the SLD offset dr
        "dr_sigmoid_rel_pos", "dr_sigmoid_rel_width",
        # fraction of d_block taken by the first sublayer of each block
        "d_block1_rel",
        # repeated block: thickness, roughness relative to d_block, base SLD, SLD offset
        "d_block", "s_block_rel", "r_block", "dr",
        # extra layer placed after the repeated stack
        "d3_rel", "s3_rel", "r3",
        # SiO2 layer and Si
        "d_sio2", "s_sio2", "s_si", "r_sio2", "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Return thicknesses, roughnesses and SLDs for a batch of parameter rows."""
        return multilayer_model3(parametrized_model, d_full_rel_max=self.max_num_layers)
101
-
102
-
103
# Registry of the available parametrizations, keyed by each class's NAME.
MULTILAYER_MODELS = {
    model_cls.NAME: model_cls
    for model_cls in (BasicMultilayerModel1, BasicMultilayerModel2, BasicMultilayerModel3)
}
108
-
109
-
110
def multilayer_model1(parametrized_model: Tensor, d_full_rel_max: int = 50) -> dict:
    """Convert ``repeating_multilayer_v1`` parameters into per-layer arrays.

    The structure is ``n = d_full_rel_max`` repetitions of a two-sublayer block
    (each sublayer ``d_block / 2`` thick, SLDs ``r_block`` and ``r_block + dr``),
    followed by an extra layer (``d3``), a SiO2 layer and a final Si entry that
    has a roughness and an SLD but no thickness. Sublayers whose position is
    well above ``2 * d_full_rel`` have their SLD scaled toward zero by a
    sigmoid of width ``rel_sigmas``.

    Args:
        parametrized_model: Tensor of shape ``(batch_size, 14)`` whose columns
            follow ``BasicMultilayerModel1.PARAMETER_NAMES``.
        d_full_rel_max: Number ``n`` of block repetitions to generate.

    Returns:
        dict with ``thicknesses`` of shape ``(batch_size, 2 * n + 2)`` and
        ``roughnesses`` / ``slds`` of shape ``(batch_size, 2 * n + 3)``.

    Raises:
        ValueError: If ``parametrized_model`` is not 2D with 14 columns.
    """
    n = d_full_rel_max

    # One column per entry of BasicMultilayerModel1.PARAMETER_NAMES.
    num_params = 14
    if parametrized_model.dim() != 2 or parametrized_model.shape[1] != num_params:
        raise ValueError(
            f"multilayer_model1 expects a tensor of shape (batch_size, {num_params}), "
            f"got {tuple(parametrized_model.shape)}"
        )

    (
        d_full_rel,
        rel_sigmas,
        d_block,
        s_block_rel,
        r_block,
        dr,
        d3_rel,
        s3_rel,
        r3,
        d_sio2,
        s_sio2,
        s_si,
        r_sio2,
        r_si,
    ) = parametrized_model.unbind(dim=1)

    batch_size = parametrized_model.shape[0]

    # Positions 2n, 2n - 1, ..., 1 assigned to the 2n sublayers in output order.
    r_positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)[None].repeat(batch_size, 1)

    # Smooth cutoff: ~1 for positions below 2 * d_full_rel, ~0 well above it.
    r_modulations = torch.sigmoid(-(r_positions - 2 * d_full_rel[..., None]) / rel_sigmas[..., None])

    r_block = r_block[:, None].repeat(1, n)
    dr = dr[:, None].repeat(1, n)

    # Interleave the two sublayer SLDs: r_block, r_block + dr, r_block, ...
    sld_blocks = torch.stack([r_block, r_block + dr], -1).flatten(1)

    sld_blocks = r_modulations * sld_blocks

    d3 = d3_rel * d_block

    thicknesses = torch.cat(
        [(d_block / 2)[:, None].repeat(1, n * 2), d3[:, None], d_sio2[:, None]], -1
    )

    # Block roughness is given relative to the block thickness.
    s_block = s_block_rel * d_block

    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, n * 2), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]], -1
    )

    slds = torch.cat(
        [sld_blocks, r3[:, None], r_sio2[:, None], r_si[:, None]], -1
    )

    params = dict(
        thicknesses=thicknesses,
        roughnesses=roughnesses,
        slds=slds
    )
    return params
165
-
166
-
167
def multilayer_model2(parametrized_model: Tensor, d_full_rel_max: int = 50) -> dict:
    """Convert ``repeating_multilayer_v2`` parameters into per-layer arrays.

    Same structure as ``multilayer_model1``: ``n = d_full_rel_max`` repetitions
    of a two-sublayer block, then an extra layer, a SiO2 layer and a final Si
    entry without thickness. In addition, the SLD offset ``dr`` of each block
    is multiplied by a sigmoid centred at ``2 * d_full_rel * dr_sigmoid_rel_pos``
    with width ``dr_sigmoid_rel_width``. Blocks well above that position
    therefore lose their SLD contrast.

    Args:
        parametrized_model: Tensor of shape ``(batch_size, 16)`` whose columns
            follow ``BasicMultilayerModel2.PARAMETER_NAMES``.
        d_full_rel_max: Number ``n`` of block repetitions to generate.

    Returns:
        dict with ``thicknesses`` of shape ``(batch_size, 2 * n + 2)`` and
        ``roughnesses`` / ``slds`` of shape ``(batch_size, 2 * n + 3)``.

    Raises:
        ValueError: If ``parametrized_model`` is not 2D with 16 columns.
    """
    n = d_full_rel_max

    # One column per entry of BasicMultilayerModel2.PARAMETER_NAMES.
    num_params = 16
    if parametrized_model.dim() != 2 or parametrized_model.shape[1] != num_params:
        raise ValueError(
            f"multilayer_model2 expects a tensor of shape (batch_size, {num_params}), "
            f"got {tuple(parametrized_model.shape)}"
        )

    (
        d_full_rel,
        rel_sigmas,
        dr_sigmoid_rel_pos,
        dr_sigmoid_rel_width,
        d_block,
        s_block_rel,
        r_block,
        dr,
        d3_rel,
        s3_rel,
        r3,
        d_sio2,
        s_sio2,
        s_si,
        r_sio2,
        r_si,
    ) = parametrized_model.unbind(dim=1)

    batch_size = parametrized_model.shape[0]

    # Positions 2n, 2n - 1, ..., 1 assigned to the 2n sublayers in output order.
    r_positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)[None].repeat(batch_size, 1)

    # Smooth cutoff: ~1 for positions below 2 * d_full_rel, ~0 well above it.
    r_modulations = torch.sigmoid(-(r_positions - 2 * d_full_rel[..., None]) / rel_sigmas[..., None])

    r_block = r_block[:, None].repeat(1, n)
    dr = dr[:, None].repeat(1, n)

    # Position of the first sublayer of each block (one value per block).
    dr_positions = r_positions[:, ::2]

    dr_modulations = torch.sigmoid(
        -(dr_positions - (2 * d_full_rel * dr_sigmoid_rel_pos)[..., None]) / dr_sigmoid_rel_width[..., None]
    )

    dr = dr * dr_modulations

    # Interleave the two sublayer SLDs: r_block, r_block + dr, r_block, ...
    sld_blocks = torch.stack([r_block, r_block + dr], -1).flatten(1)

    sld_blocks = r_modulations * sld_blocks

    d3 = d3_rel * d_block

    thicknesses = torch.cat(
        [(d_block / 2)[:, None].repeat(1, n * 2), d3[:, None], d_sio2[:, None]], -1
    )

    # Block roughness is given relative to the block thickness.
    s_block = s_block_rel * d_block

    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, n * 2), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]], -1
    )

    slds = torch.cat(
        [sld_blocks, r3[:, None], r_sio2[:, None], r_si[:, None]], -1
    )

    params = dict(
        thicknesses=thicknesses,
        roughnesses=roughnesses,
        slds=slds
    )
    return params
232
-
233
-
234
def multilayer_model3(parametrized_model: Tensor, d_full_rel_max: int = 30) -> dict:
    """Convert ``repeating_multilayer_v3`` parameters into per-layer arrays.

    ``n = d_full_rel_max`` repetitions of a two-sublayer block, then an extra
    layer, a SiO2 layer and a final Si entry without thickness. The sublayers
    of each block are ``d_block * d_block1_rel`` and
    ``d_block * (1 - d_block1_rel)`` thick. A per-block term
    ``m = dr * (1 - sigmoid(...))`` shifts the first sublayer SLD up by
    ``m * (1 - d_block1_rel)`` and the second (``r_block + dr``) down by
    ``m * d_block1_rel``.

    Args:
        parametrized_model: Tensor of shape ``(batch_size, 17)`` whose columns
            follow ``BasicMultilayerModel3.PARAMETER_NAMES``.
        d_full_rel_max: Number ``n`` of block repetitions to generate. The
            default (30) differs from the other two models (50).

    Returns:
        dict with ``thicknesses`` of shape ``(batch_size, 2 * n + 2)`` and
        ``roughnesses`` / ``slds`` of shape ``(batch_size, 2 * n + 3)``.

    Raises:
        ValueError: If ``parametrized_model`` is not 2D with 17 columns.
    """
    n = d_full_rel_max

    # One column per entry of BasicMultilayerModel3.PARAMETER_NAMES.
    num_params = 17
    if parametrized_model.dim() != 2 or parametrized_model.shape[1] != num_params:
        raise ValueError(
            f"multilayer_model3 expects a tensor of shape (batch_size, {num_params}), "
            f"got {tuple(parametrized_model.shape)}"
        )

    (
        d_full_rel,
        rel_sigmas,
        dr_sigmoid_rel_pos,
        dr_sigmoid_rel_width,
        d_block1_rel,
        d_block,
        s_block_rel,
        r_block,
        dr,
        d3_rel,
        s3_rel,
        r3,
        d_sio2,
        s_sio2,
        s_si,
        r_sio2,
        r_si,
    ) = parametrized_model.unbind(dim=1)

    batch_size = parametrized_model.shape[0]

    # Positions 2n, 2n - 1, ..., 1 assigned to the 2n sublayers in output order.
    r_positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)[None].repeat(batch_size, 1)

    # Smooth cutoff: ~1 for positions below 2 * d_full_rel, ~0 well above it.
    r_modulations = torch.sigmoid(
        -(
            r_positions - 2 * d_full_rel[..., None]
        ) / rel_sigmas[..., None]
    )

    # Position of the first sublayer of each block (one value per block).
    dr_positions = r_positions[:, ::2]

    # NOTE(review): unlike multilayer_model2, dr_sigmoid_rel_pos is used here as
    # an absolute offset (2 * dr_sigmoid_rel_pos), not scaled by d_full_rel —
    # confirm this is intended.
    dr_modulations = dr[..., None] * (1 - torch.sigmoid(
        -(
            dr_positions - 2 * d_full_rel[..., None] + 2 * dr_sigmoid_rel_pos[..., None]
        ) / dr_sigmoid_rel_width[..., None]
    ))

    r_block = r_block[..., None].repeat(1, n)
    dr = dr[..., None].repeat(1, n)

    sld_blocks = torch.stack(
        [
            r_block + dr_modulations * (1 - d_block1_rel[..., None]),
            r_block + dr - dr_modulations * d_block1_rel[..., None]
        ], -1).flatten(1)

    sld_blocks = r_modulations * sld_blocks

    d3 = d3_rel * d_block

    # Split each block's thickness between its two sublayers.
    d1, d2 = d_block * d_block1_rel, d_block * (1 - d_block1_rel)

    thickness_blocks = torch.stack([d1[:, None].repeat(1, n), d2[:, None].repeat(1, n)], -1).flatten(1)

    thicknesses = torch.cat(
        [thickness_blocks, d3[:, None], d_sio2[:, None]], -1
    )

    # Block roughness is given relative to the block thickness.
    s_block = s_block_rel * d_block

    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, n * 2), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]], -1
    )

    slds = torch.cat(
        [sld_blocks, r3[:, None], r_sio2[:, None], r_si[:, None]], -1
    )

    params = dict(
        thicknesses=thicknesses,
        roughnesses=roughnesses,
        slds=slds
    )
    return params
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
# Public names exported by ``from ... import *``.
__all__ = ["MULTILAYER_MODELS", "MultilayerModel"]
10
+
11
+
12
class MultilayerModel:
    """Base class for parametrizations of repeating multilayer structures.

    Subclasses define ``NAME`` (the key used in ``MULTILAYER_MODELS``) and
    ``PARAMETER_NAMES`` (the column order of a parametrized-model tensor), and
    override :meth:`to_standard_params`.
    """

    NAME: str = ''
    PARAMETER_NAMES: Tuple[str, ...]

    def __init__(self, max_num_layers: int):
        # The subclasses in this module forward this value as ``d_full_rel_max``
        # to their conversion function, i.e. the number of block repetitions.
        self.max_num_layers = max_num_layers

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Convert a parametrized-model tensor to a dict of layer parameters.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def from_standard_params(self, params: dict) -> Tensor:
        """Presumably the inverse of :meth:`to_standard_params`.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError
24
+
25
+
26
class BasicMultilayerModel1(MultilayerModel):
    """``repeating_multilayer_v1``: a repeated two-sublayer block with a smooth cutoff.

    Columns of the parametrized-model tensor follow ``PARAMETER_NAMES``; the
    conversion itself is done by :func:`multilayer_model1`.
    """

    NAME = 'repeating_multilayer_v1'

    PARAMETER_NAMES = (
        # effective number of repetitions and width of its sigmoid cutoff
        "d_full_rel", "rel_sigmas",
        # repeated block: thickness, roughness relative to d_block, base SLD, SLD offset
        "d_block", "s_block_rel", "r_block", "dr",
        # extra layer placed after the repeated stack
        "d3_rel", "s3_rel", "r3",
        # SiO2 layer and Si
        "d_sio2", "s_sio2", "s_si", "r_sio2", "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Return thicknesses, roughnesses and SLDs for a batch of parameter rows."""
        return multilayer_model1(parametrized_model, d_full_rel_max=self.max_num_layers)
48
+
49
+
50
class BasicMultilayerModel2(MultilayerModel):
    """``repeating_multilayer_v2``: like v1, plus a sigmoid fade of the SLD offset ``dr``.

    Columns of the parametrized-model tensor follow ``PARAMETER_NAMES``; the
    conversion itself is done by :func:`multilayer_model2`.
    """

    NAME = 'repeating_multilayer_v2'

    PARAMETER_NAMES = (
        # effective number of repetitions and width of its sigmoid cutoff
        "d_full_rel", "rel_sigmas",
        # position and width of the sigmoid that fades the SLD offset dr
        "dr_sigmoid_rel_pos", "dr_sigmoid_rel_width",
        # repeated block: thickness, roughness relative to d_block, base SLD, SLD offset
        "d_block", "s_block_rel", "r_block", "dr",
        # extra layer placed after the repeated stack
        "d3_rel", "s3_rel", "r3",
        # SiO2 layer and Si
        "d_sio2", "s_sio2", "s_si", "r_sio2", "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Return thicknesses, roughnesses and SLDs for a batch of parameter rows."""
        return multilayer_model2(parametrized_model, d_full_rel_max=self.max_num_layers)
74
+
75
+
76
class BasicMultilayerModel3(MultilayerModel):
    """``repeating_multilayer_v3``: like v2, with unequal sublayer thicknesses.

    Columns of the parametrized-model tensor follow ``PARAMETER_NAMES``; the
    conversion itself is done by :func:`multilayer_model3`.
    """

    NAME = 'repeating_multilayer_v3'

    PARAMETER_NAMES = (
        # effective number of repetitions and width of its sigmoid cutoff
        "d_full_rel", "rel_sigmas",
        # offset and width of the sigmoid that modulates the SLD offset dr
        "dr_sigmoid_rel_pos", "dr_sigmoid_rel_width",
        # fraction of d_block taken by the first sublayer of each block
        "d_block1_rel",
        # repeated block: thickness, roughness relative to d_block, base SLD, SLD offset
        "d_block", "s_block_rel", "r_block", "dr",
        # extra layer placed after the repeated stack
        "d3_rel", "s3_rel", "r3",
        # SiO2 layer and Si
        "d_sio2", "s_sio2", "s_si", "r_sio2", "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Return thicknesses, roughnesses and SLDs for a batch of parameter rows."""
        return multilayer_model3(parametrized_model, d_full_rel_max=self.max_num_layers)
101
+
102
+
103
# Registry of the available parametrizations, keyed by each class's NAME.
MULTILAYER_MODELS = {
    model_cls.NAME: model_cls
    for model_cls in (BasicMultilayerModel1, BasicMultilayerModel2, BasicMultilayerModel3)
}
108
+
109
+
110
def multilayer_model1(parametrized_model: Tensor, d_full_rel_max: int = 50) -> dict:
    """Convert ``repeating_multilayer_v1`` parameters into per-layer arrays.

    The structure is ``n = d_full_rel_max`` repetitions of a two-sublayer block
    (each sublayer ``d_block / 2`` thick, SLDs ``r_block`` and ``r_block + dr``),
    followed by an extra layer (``d3``), a SiO2 layer and a final Si entry that
    has a roughness and an SLD but no thickness. Sublayers whose position is
    well above ``2 * d_full_rel`` have their SLD scaled toward zero by a
    sigmoid of width ``rel_sigmas``.

    Args:
        parametrized_model: Tensor of shape ``(batch_size, 14)`` whose columns
            follow ``BasicMultilayerModel1.PARAMETER_NAMES``.
        d_full_rel_max: Number ``n`` of block repetitions to generate.

    Returns:
        dict with ``thicknesses`` of shape ``(batch_size, 2 * n + 2)`` and
        ``roughnesses`` / ``slds`` of shape ``(batch_size, 2 * n + 3)``.

    Raises:
        ValueError: If ``parametrized_model`` is not 2D with 14 columns.
    """
    n = d_full_rel_max

    # One column per entry of BasicMultilayerModel1.PARAMETER_NAMES.
    num_params = 14
    if parametrized_model.dim() != 2 or parametrized_model.shape[1] != num_params:
        raise ValueError(
            f"multilayer_model1 expects a tensor of shape (batch_size, {num_params}), "
            f"got {tuple(parametrized_model.shape)}"
        )

    (
        d_full_rel,
        rel_sigmas,
        d_block,
        s_block_rel,
        r_block,
        dr,
        d3_rel,
        s3_rel,
        r3,
        d_sio2,
        s_sio2,
        s_si,
        r_sio2,
        r_si,
    ) = parametrized_model.unbind(dim=1)

    batch_size = parametrized_model.shape[0]

    # Positions 2n, 2n - 1, ..., 1 assigned to the 2n sublayers in output order.
    r_positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)[None].repeat(batch_size, 1)

    # Smooth cutoff: ~1 for positions below 2 * d_full_rel, ~0 well above it.
    r_modulations = torch.sigmoid(-(r_positions - 2 * d_full_rel[..., None]) / rel_sigmas[..., None])

    r_block = r_block[:, None].repeat(1, n)
    dr = dr[:, None].repeat(1, n)

    # Interleave the two sublayer SLDs: r_block, r_block + dr, r_block, ...
    sld_blocks = torch.stack([r_block, r_block + dr], -1).flatten(1)

    sld_blocks = r_modulations * sld_blocks

    d3 = d3_rel * d_block

    thicknesses = torch.cat(
        [(d_block / 2)[:, None].repeat(1, n * 2), d3[:, None], d_sio2[:, None]], -1
    )

    # Block roughness is given relative to the block thickness.
    s_block = s_block_rel * d_block

    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, n * 2), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]], -1
    )

    slds = torch.cat(
        [sld_blocks, r3[:, None], r_sio2[:, None], r_si[:, None]], -1
    )

    params = dict(
        thicknesses=thicknesses,
        roughnesses=roughnesses,
        slds=slds
    )
    return params
165
+
166
+
167
def multilayer_model2(parametrized_model: Tensor, d_full_rel_max: int = 50) -> dict:
    """Convert ``repeating_multilayer_v2`` parameters into per-layer arrays.

    Same structure as ``multilayer_model1``: ``n = d_full_rel_max`` repetitions
    of a two-sublayer block, then an extra layer, a SiO2 layer and a final Si
    entry without thickness. In addition, the SLD offset ``dr`` of each block
    is multiplied by a sigmoid centred at ``2 * d_full_rel * dr_sigmoid_rel_pos``
    with width ``dr_sigmoid_rel_width``. Blocks well above that position
    therefore lose their SLD contrast.

    Args:
        parametrized_model: Tensor of shape ``(batch_size, 16)`` whose columns
            follow ``BasicMultilayerModel2.PARAMETER_NAMES``.
        d_full_rel_max: Number ``n`` of block repetitions to generate.

    Returns:
        dict with ``thicknesses`` of shape ``(batch_size, 2 * n + 2)`` and
        ``roughnesses`` / ``slds`` of shape ``(batch_size, 2 * n + 3)``.

    Raises:
        ValueError: If ``parametrized_model`` is not 2D with 16 columns.
    """
    n = d_full_rel_max

    # One column per entry of BasicMultilayerModel2.PARAMETER_NAMES.
    num_params = 16
    if parametrized_model.dim() != 2 or parametrized_model.shape[1] != num_params:
        raise ValueError(
            f"multilayer_model2 expects a tensor of shape (batch_size, {num_params}), "
            f"got {tuple(parametrized_model.shape)}"
        )

    (
        d_full_rel,
        rel_sigmas,
        dr_sigmoid_rel_pos,
        dr_sigmoid_rel_width,
        d_block,
        s_block_rel,
        r_block,
        dr,
        d3_rel,
        s3_rel,
        r3,
        d_sio2,
        s_sio2,
        s_si,
        r_sio2,
        r_si,
    ) = parametrized_model.unbind(dim=1)

    batch_size = parametrized_model.shape[0]

    # Positions 2n, 2n - 1, ..., 1 assigned to the 2n sublayers in output order.
    r_positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)[None].repeat(batch_size, 1)

    # Smooth cutoff: ~1 for positions below 2 * d_full_rel, ~0 well above it.
    r_modulations = torch.sigmoid(-(r_positions - 2 * d_full_rel[..., None]) / rel_sigmas[..., None])

    r_block = r_block[:, None].repeat(1, n)
    dr = dr[:, None].repeat(1, n)

    # Position of the first sublayer of each block (one value per block).
    dr_positions = r_positions[:, ::2]

    dr_modulations = torch.sigmoid(
        -(dr_positions - (2 * d_full_rel * dr_sigmoid_rel_pos)[..., None]) / dr_sigmoid_rel_width[..., None]
    )

    dr = dr * dr_modulations

    # Interleave the two sublayer SLDs: r_block, r_block + dr, r_block, ...
    sld_blocks = torch.stack([r_block, r_block + dr], -1).flatten(1)

    sld_blocks = r_modulations * sld_blocks

    d3 = d3_rel * d_block

    thicknesses = torch.cat(
        [(d_block / 2)[:, None].repeat(1, n * 2), d3[:, None], d_sio2[:, None]], -1
    )

    # Block roughness is given relative to the block thickness.
    s_block = s_block_rel * d_block

    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, n * 2), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]], -1
    )

    slds = torch.cat(
        [sld_blocks, r3[:, None], r_sio2[:, None], r_si[:, None]], -1
    )

    params = dict(
        thicknesses=thicknesses,
        roughnesses=roughnesses,
        slds=slds
    )
    return params
232
+
233
+
234
def multilayer_model3(parametrized_model: Tensor, d_full_rel_max: int = 30) -> dict:
    """Convert ``repeating_multilayer_v3`` parameters into per-layer arrays.

    ``n = d_full_rel_max`` repetitions of a two-sublayer block, then an extra
    layer, a SiO2 layer and a final Si entry without thickness. The sublayers
    of each block are ``d_block * d_block1_rel`` and
    ``d_block * (1 - d_block1_rel)`` thick. A per-block term
    ``m = dr * (1 - sigmoid(...))`` shifts the first sublayer SLD up by
    ``m * (1 - d_block1_rel)`` and the second (``r_block + dr``) down by
    ``m * d_block1_rel``.

    Args:
        parametrized_model: Tensor of shape ``(batch_size, 17)`` whose columns
            follow ``BasicMultilayerModel3.PARAMETER_NAMES``.
        d_full_rel_max: Number ``n`` of block repetitions to generate. The
            default (30) differs from the other two models (50).

    Returns:
        dict with ``thicknesses`` of shape ``(batch_size, 2 * n + 2)`` and
        ``roughnesses`` / ``slds`` of shape ``(batch_size, 2 * n + 3)``.

    Raises:
        ValueError: If ``parametrized_model`` is not 2D with 17 columns.
    """
    n = d_full_rel_max

    # One column per entry of BasicMultilayerModel3.PARAMETER_NAMES.
    num_params = 17
    if parametrized_model.dim() != 2 or parametrized_model.shape[1] != num_params:
        raise ValueError(
            f"multilayer_model3 expects a tensor of shape (batch_size, {num_params}), "
            f"got {tuple(parametrized_model.shape)}"
        )

    (
        d_full_rel,
        rel_sigmas,
        dr_sigmoid_rel_pos,
        dr_sigmoid_rel_width,
        d_block1_rel,
        d_block,
        s_block_rel,
        r_block,
        dr,
        d3_rel,
        s3_rel,
        r3,
        d_sio2,
        s_sio2,
        s_si,
        r_sio2,
        r_si,
    ) = parametrized_model.unbind(dim=1)

    batch_size = parametrized_model.shape[0]

    # Positions 2n, 2n - 1, ..., 1 assigned to the 2n sublayers in output order.
    r_positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)[None].repeat(batch_size, 1)

    # Smooth cutoff: ~1 for positions below 2 * d_full_rel, ~0 well above it.
    r_modulations = torch.sigmoid(
        -(
            r_positions - 2 * d_full_rel[..., None]
        ) / rel_sigmas[..., None]
    )

    # Position of the first sublayer of each block (one value per block).
    dr_positions = r_positions[:, ::2]

    # NOTE(review): unlike multilayer_model2, dr_sigmoid_rel_pos is used here as
    # an absolute offset (2 * dr_sigmoid_rel_pos), not scaled by d_full_rel —
    # confirm this is intended.
    dr_modulations = dr[..., None] * (1 - torch.sigmoid(
        -(
            dr_positions - 2 * d_full_rel[..., None] + 2 * dr_sigmoid_rel_pos[..., None]
        ) / dr_sigmoid_rel_width[..., None]
    ))

    r_block = r_block[..., None].repeat(1, n)
    dr = dr[..., None].repeat(1, n)

    sld_blocks = torch.stack(
        [
            r_block + dr_modulations * (1 - d_block1_rel[..., None]),
            r_block + dr - dr_modulations * d_block1_rel[..., None]
        ], -1).flatten(1)

    sld_blocks = r_modulations * sld_blocks

    d3 = d3_rel * d_block

    # Split each block's thickness between its two sublayers.
    d1, d2 = d_block * d_block1_rel, d_block * (1 - d_block1_rel)

    thickness_blocks = torch.stack([d1[:, None].repeat(1, n), d2[:, None].repeat(1, n)], -1).flatten(1)

    thicknesses = torch.cat(
        [thickness_blocks, d3[:, None], d_sio2[:, None]], -1
    )

    # Block roughness is given relative to the block thickness.
    s_block = s_block_rel * d_block

    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, n * 2), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]], -1
    )

    slds = torch.cat(
        [sld_blocks, r3[:, None], r_sio2[:, None], r_si[:, None]], -1
    )

    params = dict(
        thicknesses=thicknesses,
        roughnesses=roughnesses,
        slds=slds
    )
    return params