reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,852 +1,848 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
import
|
|
5
|
-
|
|
6
|
-
import
|
|
7
|
-
import
|
|
8
|
-
|
|
9
|
-
from
|
|
10
|
-
|
|
11
|
-
from
|
|
12
|
-
from
|
|
13
|
-
|
|
14
|
-
from reflectorch.
|
|
15
|
-
from reflectorch.
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
from reflectorch.
|
|
19
|
-
from reflectorch.
|
|
20
|
-
|
|
21
|
-
from reflectorch.
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
from
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
self.
|
|
49
|
-
self.
|
|
50
|
-
self.
|
|
51
|
-
self.
|
|
52
|
-
|
|
53
|
-
self.
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
if
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
config_name_no_extension = config_name
|
|
74
|
-
self.config_name =
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
self.
|
|
80
|
-
|
|
81
|
-
self.
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
print("
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
print(
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
print("
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
"""
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
if
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
)
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
if
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
)
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
eps = torch.finfo(r.dtype).eps
|
|
849
|
-
ind = torch.searchsorted(q[None].expand_as(qs).contiguous(), qs.contiguous())
|
|
850
|
-
ind = torch.clamp(ind - 1, 0, q.shape[0] - 2)
|
|
851
|
-
slopes = (r[1:] - r[:-1]) / (eps + (q[1:] - q[:-1]))
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
from typing import List, Tuple, Union
|
|
7
|
+
from huggingface_hub import hf_hub_download
|
|
8
|
+
|
|
9
|
+
from reflectorch.data_generation.priors import BasicParams
|
|
10
|
+
from reflectorch.data_generation.priors.parametric_models import NuisanceParamsWrapper
|
|
11
|
+
from reflectorch.data_generation.q_generator import ConstantQ, VariableQ, MaskedVariableQ
|
|
12
|
+
from reflectorch.data_generation.utils import get_density_profiles
|
|
13
|
+
from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
|
|
14
|
+
from reflectorch.paths import CONFIG_DIR, SAVED_MODELS_DIR
|
|
15
|
+
from reflectorch.runs.utils import (
|
|
16
|
+
get_trainer_by_name
|
|
17
|
+
)
|
|
18
|
+
from reflectorch.ml.trainers import PointEstimatorTrainer
|
|
19
|
+
from reflectorch.data_generation.likelihoods import LogLikelihood
|
|
20
|
+
|
|
21
|
+
from reflectorch.inference.scipy_fitter import refl_fit, get_fit_with_growth
|
|
22
|
+
from reflectorch.inference.sampler_solution import get_best_mse_param
|
|
23
|
+
from reflectorch.utils import get_filtering_mask, to_t
|
|
24
|
+
|
|
25
|
+
from huggingface_hub.utils import disable_progress_bars
|
|
26
|
+
|
|
27
|
+
# that causes some Rust related errors when downloading models from Huggingface
|
|
28
|
+
disable_progress_bars()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class InferenceModel(object):
|
|
32
|
+
"""Facilitates the inference process using pretrained models
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
config_name (str, optional): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension). Defaults to None.
|
|
36
|
+
model_name (str, optional): the name of the file containing the weights of the model (either with or without the '.pt' extension), only required if different than: `'model_' + config_name + '.pt'`. Defaults to None
|
|
37
|
+
root_dir (str, optional): path to root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR). Defaults to None.
|
|
38
|
+
weights_format (str, optional): format (extension) of the weights file, either 'pt' or 'safetensors'. Defaults to 'safetensors'.
|
|
39
|
+
repo_id (str, optional): the id of the Huggingface repository from which the configuration files and model weights should be downloaded automatically if not found locally (in the 'configs' and 'saved_models' subdirectories of the root directory). Defaults to 'valentinsingularity/reflectivity'.
|
|
40
|
+
trainer (PointEstimatorTrainer, optional): if provided, this trainer instance is used directly instead of being initialized from the configuration file. Defaults to None.
|
|
41
|
+
device (str, optional): the Pytorch device ('cuda' or 'cpu'). Defaults to 'cuda'.
|
|
42
|
+
"""
|
|
43
|
+
def __init__(self, config_name: str = None, model_name: str = None, root_dir:str = None, weights_format: str = 'safetensors',
|
|
44
|
+
repo_id: str = 'valentinsingularity/reflectivity', trainer: PointEstimatorTrainer = None, device='cuda'):
|
|
45
|
+
self.config_name = config_name
|
|
46
|
+
self.model_name = model_name
|
|
47
|
+
self.root_dir = root_dir
|
|
48
|
+
self.weights_format = weights_format
|
|
49
|
+
self.repo_id = repo_id
|
|
50
|
+
self.trainer = trainer
|
|
51
|
+
self.device = device
|
|
52
|
+
|
|
53
|
+
if trainer is None and self.config_name is not None:
|
|
54
|
+
self.load_model(self.config_name, self.model_name, self.root_dir)
|
|
55
|
+
|
|
56
|
+
self.prediction_result = None
|
|
57
|
+
|
|
58
|
+
def load_model(self, config_name: str, model_name: str, root_dir: str) -> None:
|
|
59
|
+
"""Loads a model for inference
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
config_name (str): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension).
|
|
63
|
+
model_name (str): the name of the file containing the weights of the model (either with or without the '.pt' or '.safetensors' extension), only required if different than: `'model_' + config_name + extension`.
|
|
64
|
+
root_dir (str): path to root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR).
|
|
65
|
+
"""
|
|
66
|
+
if self.config_name == config_name and self.trainer is not None:
|
|
67
|
+
return
|
|
68
|
+
|
|
69
|
+
if not config_name.endswith('.yaml'):
|
|
70
|
+
config_name_no_extension = config_name
|
|
71
|
+
self.config_name = config_name_no_extension + '.yaml'
|
|
72
|
+
else:
|
|
73
|
+
config_name_no_extension = config_name[:-5]
|
|
74
|
+
self.config_name = config_name
|
|
75
|
+
|
|
76
|
+
self.config_dir = Path(root_dir) / 'configs' if root_dir else CONFIG_DIR
|
|
77
|
+
weights_extension = '.' + self.weights_format
|
|
78
|
+
self.model_name = model_name or 'model_' + config_name_no_extension + weights_extension
|
|
79
|
+
if not self.model_name.endswith(weights_extension):
|
|
80
|
+
self.model_name += weights_extension
|
|
81
|
+
self.model_dir = Path(root_dir) / 'saved_models' if root_dir else SAVED_MODELS_DIR
|
|
82
|
+
|
|
83
|
+
def _download_with_fallback(filename: str, local_target_dir: Path, legacy_subfolder: str):
|
|
84
|
+
"""Try to download from repo root (new layout). If not found, retry with legacy `subfolder=legacy_subfolder`. Place result under local_target_dir using `local_dir`.
|
|
85
|
+
"""
|
|
86
|
+
try: # new layout: files at repo root (same level as README.md)
|
|
87
|
+
hf_hub_download(repo_id=self.repo_id + '/' + config_name, filename=filename, local_dir=str(local_target_dir))
|
|
88
|
+
except Exception : # legacy layout fallback: e.g. subfolder='configs' or 'saved_models'
|
|
89
|
+
hf_hub_download(repo_id=self.repo_id, filename=filename, subfolder=legacy_subfolder, local_dir=str(local_target_dir.parent))
|
|
90
|
+
|
|
91
|
+
config_path = Path(self.config_dir) / self.config_name
|
|
92
|
+
if config_path.exists():
|
|
93
|
+
print(f"Configuration file `{config_path}` found locally.")
|
|
94
|
+
else:
|
|
95
|
+
print(f"Configuration file `{config_path}` not found locally.")
|
|
96
|
+
if self.repo_id is None:
|
|
97
|
+
raise ValueError("repo_id must be provided to download files from Huggingface.")
|
|
98
|
+
print("Downloading from Huggingface...")
|
|
99
|
+
_download_with_fallback(self.config_name, self.config_dir, legacy_subfolder='configs')
|
|
100
|
+
|
|
101
|
+
model_path = Path(self.model_dir) / self.model_name
|
|
102
|
+
if model_path.exists():
|
|
103
|
+
print(f"Weights file `{model_path}` found locally.")
|
|
104
|
+
else:
|
|
105
|
+
print(f"Weights file `{model_path}` not found locally.")
|
|
106
|
+
if self.repo_id is None:
|
|
107
|
+
raise ValueError("repo_id must be provided to download files from Huggingface.")
|
|
108
|
+
print("Downloading from Huggingface...")
|
|
109
|
+
_download_with_fallback(self.model_name, self.model_dir, legacy_subfolder='saved_models')
|
|
110
|
+
|
|
111
|
+
self.trainer = get_trainer_by_name(config_name=config_name, config_dir=self.config_dir, model_path=model_path, load_weights=True, inference_device = self.device)
|
|
112
|
+
self.trainer.model.eval()
|
|
113
|
+
|
|
114
|
+
param_model = self.trainer.loader.prior_sampler.param_model
|
|
115
|
+
param_model_name = param_model.base_model.NAME if isinstance(param_model, NuisanceParamsWrapper) else param_model.NAME
|
|
116
|
+
print(f'The model corresponds to a `{param_model_name}` parameterization with {self.trainer.loader.prior_sampler.max_num_layers} layers ({self.trainer.loader.prior_sampler.param_dim} predicted parameters)')
|
|
117
|
+
print("Parameter types and total ranges:")
|
|
118
|
+
for param, range_ in self.trainer.loader.prior_sampler.param_ranges.items():
|
|
119
|
+
print(f"- {param}: {range_}")
|
|
120
|
+
print("Allowed widths of the prior bound intervals (max-min):")
|
|
121
|
+
for param, range_ in self.trainer.loader.prior_sampler.bound_width_ranges.items():
|
|
122
|
+
print(f"- {param}: {range_}")
|
|
123
|
+
|
|
124
|
+
if isinstance(self.trainer.loader.q_generator, ConstantQ):
|
|
125
|
+
q_min = self.trainer.loader.q_generator.q[0].item()
|
|
126
|
+
q_max = self.trainer.loader.q_generator.q[-1].item()
|
|
127
|
+
n_q = self.trainer.loader.q_generator.q.shape[0]
|
|
128
|
+
print(f'The model was trained on curves discretized at {n_q} uniform points between q_min={q_min} and q_max={q_max}')
|
|
129
|
+
elif isinstance(self.trainer.loader.q_generator, VariableQ):
|
|
130
|
+
q_min_range = self.trainer.loader.q_generator.q_min_range
|
|
131
|
+
q_max_range = self.trainer.loader.q_generator.q_max_range
|
|
132
|
+
n_q_range = self.trainer.loader.q_generator.n_q_range
|
|
133
|
+
if n_q_range[0] == n_q_range[1]:
|
|
134
|
+
n_q_fixed = n_q_range[0]
|
|
135
|
+
print(f'The model was trained on curves discretized at exactly {n_q_fixed} uniform points, '
|
|
136
|
+
f'between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
|
|
137
|
+
else:
|
|
138
|
+
print(f'The model was trained on curves discretized at a number between {n_q_range[0]} and {n_q_range[1]} '
|
|
139
|
+
f'of uniform points between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
|
|
140
|
+
|
|
141
|
+
if self.trainer.loader.smearing is not None:
|
|
142
|
+
q_res_min = self.trainer.loader.smearing.sigma_min
|
|
143
|
+
q_res_max = self.trainer.loader.smearing.sigma_max
|
|
144
|
+
if self.trainer.loader.smearing.constant_dq == False:
|
|
145
|
+
print(f"The model was trained with linear resolution smearing (dq/q) in the range [{q_res_min}, {q_res_max}]")
|
|
146
|
+
elif self.trainer.loader.smearing.constant_dq == True:
|
|
147
|
+
print(f"The model was trained with constant resolution smearing in the range [{q_res_min}, {q_res_max}]")
|
|
148
|
+
|
|
149
|
+
additional_inputs = ["prior bounds"]
|
|
150
|
+
if self.trainer.train_with_q_input:
|
|
151
|
+
additional_inputs.append("q values")
|
|
152
|
+
if self.trainer.condition_on_q_resolutions:
|
|
153
|
+
additional_inputs.append("the resolution dq/q")
|
|
154
|
+
if additional_inputs:
|
|
155
|
+
inputs_str = ", ".join(additional_inputs)
|
|
156
|
+
print(f"The following quantities are additional inputs to the network: {inputs_str}.")
|
|
157
|
+
|
|
158
|
+
def preprocess_and_predict(self,
|
|
159
|
+
reflectivity_curve: np.ndarray,
|
|
160
|
+
q_values: np.ndarray = None,
|
|
161
|
+
prior_bounds: Union[np.ndarray, List[Tuple]] = None,
|
|
162
|
+
sigmas: np.ndarray = None,
|
|
163
|
+
q_resolution: Union[float, np.ndarray] = None,
|
|
164
|
+
ambient_sld: float = None,
|
|
165
|
+
clip_prediction: bool = True,
|
|
166
|
+
polish_prediction: bool = False,
|
|
167
|
+
polishing_method: str = 'trf',
|
|
168
|
+
polishing_kwargs_reflectivity: dict = None,
|
|
169
|
+
use_sigmas_for_polishing: bool = False,
|
|
170
|
+
polishing_max_steps: int = None,
|
|
171
|
+
fit_growth: bool = False,
|
|
172
|
+
max_d_change: float = 5.,
|
|
173
|
+
calc_pred_curve: bool = True,
|
|
174
|
+
calc_pred_sld_profile: bool = False,
|
|
175
|
+
calc_polished_sld_profile: bool = False,
|
|
176
|
+
sld_profile_padding_left: float = 0.2,
|
|
177
|
+
sld_profile_padding_right: float = 1.1,
|
|
178
|
+
kwargs_param_labels: dict = {},
|
|
179
|
+
|
|
180
|
+
truncate_index_left: int = None,
|
|
181
|
+
truncate_index_right: int = None,
|
|
182
|
+
enable_error_bars_filtering: bool = True,
|
|
183
|
+
filter_threshold=0.3,
|
|
184
|
+
filter_remove_singles=True,
|
|
185
|
+
filter_remove_consecutives=True,
|
|
186
|
+
filter_consecutive=3,
|
|
187
|
+
filter_q_start_trunc=0.1,
|
|
188
|
+
):
|
|
189
|
+
"""Preprocess experimental data (clean, truncate, filter, interpolate) and run prediction. This wrapper prepares inputs according to the model's Q generator calls `predict(...)` on the interpolated/padded data, and (optionally) performs a polishing step on the original data (pre-interpolation)
|
|
190
|
+
|
|
191
|
+
Args:
|
|
192
|
+
reflectivity_curve (Union[np.ndarray, Tensor]): 1D array of experimental reflectivity values.
|
|
193
|
+
q_values (Union[np.ndarray, Tensor]): 1D array of momentum transfer values for the reflectivity curve (in units of inverse angstroms).
|
|
194
|
+
prior_bounds (Union[np.ndarray, List[Tuple]]): Prior bounds for all parameters, shape ``(num_params, 2)`` as ``[(min, max), …]``.
|
|
195
|
+
sigmas (Union[np.ndarray, Tensor], optional): 1D array of experimental uncertainties (same length as `reflectivity_curve`). Used for error-bar filtering (if enabled) and for polishing (if requested).
|
|
196
|
+
q_resolution (Union[float, np.ndarray], optional): The q resolution for neutron reflectometry models. Can be either a float (dq/q) for linear resolution smearing (e.g. 0.05 meaning 5% reolution smearing) or an array of dq values for pointwise resolution smearing.
|
|
197
|
+
ambient_sld (float, optional): The SLD of the fronting (i.e. ambient) medium for structure with fronting medium different than air.
|
|
198
|
+
clip_prediction (bool, optional): If ``True``, the values of the predicted parameters are clipped to not be outside the interval set by the prior bounds. Defaults to True.
|
|
199
|
+
polish_prediction (bool, optional): If ``True``, the neural network predictions are further polished using a simple least mean squares (LMS) fit. Defaults to False.
|
|
200
|
+
polishing_method (str): {'trf', 'dogbox', 'lm'} SciPy least-squares method used for polishing.
|
|
201
|
+
use_sigmas_for_polishing (bool): If ``True``, weigh residuals by `sigmas` during polishing.
|
|
202
|
+
polishing_max_steps (int, optional): Maximum number of function evaluations for the SciPy optimizer.
|
|
203
|
+
fit_growth (bool, optional): (Deprecated) If ``True``, an additional parameters is introduced during the LMS polishing to account for the change in the thickness of the upper layer during the in-situ measurement of the reflectivity curve (a linear growth is assumed). Defaults to False.
|
|
204
|
+
max_d_change (float): The maximum possible change in the thickness of the upper layer during the in-situ measurement, relevant when polish_prediction and fit_growth are True. Defaults to 5.
|
|
205
|
+
calc_pred_curve (bool, optional): Whether to calculate the curve corresponding to the predicted parameters. Defaults to True.
|
|
206
|
+
calc_pred_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the predicted parameters. Defaults to False.
|
|
207
|
+
calc_polished_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the polished parameters. Defaults to False.
|
|
208
|
+
sld_profile_padding_left (float, optional): Controls the amount of padding applied to the left side of the computed SLD profiles.
|
|
209
|
+
sld_profile_padding_right (float, optional): Controls the amount of padding applied to the right side of the computed SLD profiles.
|
|
210
|
+
truncate_index_left (int, optional): The data provided as input to the neural network will be truncated between the indices [truncate_index_left, truncate_index_right].
|
|
211
|
+
truncate_index_right (int, optional): The data provided as input to the neural network will be truncated between the indices [truncate_index_left, truncate_index_right].
|
|
212
|
+
enable_error_bars_filtering (bool, optional). If ``True``, the data points with high error bars (above a threshold) will be removed before constructing the input to the neural network (they are still used in the polishing step). Default to True.
|
|
213
|
+
filter_threshold (float, optional). The relative threshold (dR/R) for error bar filtering. Defaults to 0.3.
|
|
214
|
+
filter_remove_singles (float, optional). If ``True``, all isolated points exceeding the filtering threshold will be eliminated. Default to True.
|
|
215
|
+
filter_remove_consecutives (float, optional). If ``True``, in the situation when a number of ``filter_consecutive`` consecutive points exceeding the filtering threshold are detected at a position higher than ``filter_q_start_trunc``, all the subsequent points in the curve are eliminated.
|
|
216
|
+
|
|
217
|
+
Returns:
|
|
218
|
+
dict: dictionary containing the predictions
|
|
219
|
+
"""
|
|
220
|
+
|
|
221
|
+
## Preprocess the data for inference (remove negative intensities, truncation, filer out points with high error bars)
|
|
222
|
+
(q_values, reflectivity_curve, sigmas, q_resolution,
|
|
223
|
+
q_values_original, reflectivity_curve_original, sigmas_original, q_resolution_original) = self._preprocess_input_data(
|
|
224
|
+
reflectivity_curve=reflectivity_curve,
|
|
225
|
+
q_values=q_values,
|
|
226
|
+
sigmas=sigmas,
|
|
227
|
+
q_resolution=q_resolution,
|
|
228
|
+
truncate_index_left=truncate_index_left,
|
|
229
|
+
truncate_index_right=truncate_index_right,
|
|
230
|
+
enable_error_bars_filtering=enable_error_bars_filtering,
|
|
231
|
+
filter_threshold=filter_threshold,
|
|
232
|
+
filter_remove_singles=filter_remove_singles,
|
|
233
|
+
filter_remove_consecutives=filter_remove_consecutives,
|
|
234
|
+
filter_consecutive=filter_consecutive,
|
|
235
|
+
filter_q_start_trunc=filter_q_start_trunc,
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
### Interpolate the experimental data if needed by the embedding network
|
|
239
|
+
interp_data = self.interpolate_data_to_model_q(
|
|
240
|
+
q_exp=q_values,
|
|
241
|
+
refl_exp=reflectivity_curve,
|
|
242
|
+
sigmas_exp=sigmas,
|
|
243
|
+
q_res_exp=q_resolution,
|
|
244
|
+
as_dict=True
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
q_model = interp_data["q_model"]
|
|
248
|
+
reflectivity_curve_interp = interp_data["reflectivity"]
|
|
249
|
+
sigmas_interp = interp_data.get("sigmas")
|
|
250
|
+
q_resolution_interp = interp_data.get("q_resolution")
|
|
251
|
+
key_padding_mask = interp_data.get("key_padding_mask")
|
|
252
|
+
|
|
253
|
+
### Make the prediction
|
|
254
|
+
prediction_dict = self.predict(
|
|
255
|
+
reflectivity_curve=reflectivity_curve_interp,
|
|
256
|
+
q_values=q_model,
|
|
257
|
+
sigmas=sigmas_interp,
|
|
258
|
+
q_resolution=q_resolution_interp,
|
|
259
|
+
key_padding_mask=key_padding_mask,
|
|
260
|
+
prior_bounds=prior_bounds,
|
|
261
|
+
ambient_sld=ambient_sld,
|
|
262
|
+
clip_prediction=clip_prediction,
|
|
263
|
+
polish_prediction=False, ###do the polishing outside the predict method on the full data
|
|
264
|
+
supress_sld_amb_back_shift=True, ###do not shift back the slds by the ambient yet
|
|
265
|
+
calc_pred_curve=calc_pred_curve,
|
|
266
|
+
calc_pred_sld_profile=calc_pred_sld_profile,
|
|
267
|
+
sld_profile_padding_left=sld_profile_padding_left,
|
|
268
|
+
sld_profile_padding_right=sld_profile_padding_right,
|
|
269
|
+
kwargs_param_labels=kwargs_param_labels,
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
### Save interpolated data
|
|
273
|
+
prediction_dict['q_model'] = q_model
|
|
274
|
+
prediction_dict['reflectivity_curve_interp'] = reflectivity_curve_interp
|
|
275
|
+
if q_resolution_interp is not None:
|
|
276
|
+
prediction_dict['q_resolution_interp'] = q_resolution_interp
|
|
277
|
+
if sigmas_interp is not None:
|
|
278
|
+
prediction_dict['sigmas_interp'] = sigmas_interp
|
|
279
|
+
if key_padding_mask is not None:
|
|
280
|
+
prediction_dict['key_padding_mask'] = key_padding_mask
|
|
281
|
+
|
|
282
|
+
### Shift the slds for nonzero ambient
|
|
283
|
+
prior_bounds = np.array(prior_bounds)
|
|
284
|
+
if ambient_sld:
|
|
285
|
+
sld_indices = self._shift_slds_by_ambient(prior_bounds, ambient_sld)
|
|
286
|
+
|
|
287
|
+
### Perform polishing on the original data
|
|
288
|
+
if polish_prediction:
|
|
289
|
+
polishing_kwargs = polishing_kwargs_reflectivity or {}
|
|
290
|
+
polishing_kwargs.setdefault('dq', q_resolution_original)
|
|
291
|
+
|
|
292
|
+
polished_dict = self._polish_prediction(
|
|
293
|
+
q=q_values_original,
|
|
294
|
+
curve=reflectivity_curve_original,
|
|
295
|
+
predicted_params=prediction_dict['predicted_params_object'],
|
|
296
|
+
priors=prior_bounds,
|
|
297
|
+
ambient_sld_tensor=torch.atleast_2d(torch.as_tensor(ambient_sld)) if ambient_sld is not None else None,
|
|
298
|
+
calc_polished_sld_profile=calc_polished_sld_profile,
|
|
299
|
+
sld_x_axis=torch.from_numpy(prediction_dict['predicted_sld_xaxis']),
|
|
300
|
+
polishing_kwargs_reflectivity = polishing_kwargs,
|
|
301
|
+
error_bars=sigmas_original if use_sigmas_for_polishing else None,
|
|
302
|
+
polishing_method=polishing_method,
|
|
303
|
+
polishing_max_steps=polishing_max_steps,
|
|
304
|
+
fit_growth=fit_growth,
|
|
305
|
+
max_d_change=max_d_change,
|
|
306
|
+
)
|
|
307
|
+
|
|
308
|
+
prediction_dict.update(polished_dict)
|
|
309
|
+
if fit_growth and "polished_params_array" in prediction_dict:
|
|
310
|
+
prediction_dict["param_names"].append("max_d_change")
|
|
311
|
+
|
|
312
|
+
### Shift back the slds for nonzero ambient
|
|
313
|
+
if ambient_sld:
|
|
314
|
+
self._restore_slds_after_ambient_shift(prediction_dict, sld_indices, ambient_sld)
|
|
315
|
+
|
|
316
|
+
return prediction_dict
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def predict(self,
|
|
320
|
+
reflectivity_curve: Union[np.ndarray, Tensor],
|
|
321
|
+
q_values: Union[np.ndarray, Tensor] = None,
|
|
322
|
+
prior_bounds: Union[np.ndarray, List[Tuple]] = None,
|
|
323
|
+
sigmas: Union[np.ndarray, Tensor] = None,
|
|
324
|
+
key_padding_mask: Union[np.ndarray, Tensor] = None,
|
|
325
|
+
q_resolution: Union[float, np.ndarray] = None,
|
|
326
|
+
ambient_sld: float = None,
|
|
327
|
+
clip_prediction: bool = True,
|
|
328
|
+
polish_prediction: bool = False,
|
|
329
|
+
polishing_method: str = 'trf',
|
|
330
|
+
polishing_kwargs_reflectivity: dict = None,
|
|
331
|
+
polishing_max_steps: int = None,
|
|
332
|
+
fit_growth: bool = False,
|
|
333
|
+
max_d_change: float = 5.,
|
|
334
|
+
use_q_shift: bool = False,
|
|
335
|
+
calc_pred_curve: bool = True,
|
|
336
|
+
calc_pred_sld_profile: bool = False,
|
|
337
|
+
calc_polished_sld_profile: bool = False,
|
|
338
|
+
sld_profile_padding_left: float = 0.2,
|
|
339
|
+
sld_profile_padding_right: float = 1.1,
|
|
340
|
+
supress_sld_amb_back_shift: bool = False,
|
|
341
|
+
kwargs_param_labels: dict = {},
|
|
342
|
+
):
|
|
343
|
+
"""Predict the thin film parameters
|
|
344
|
+
|
|
345
|
+
Args:
|
|
346
|
+
reflectivity_curve (Union[np.ndarray, Tensor]): The reflectivity curve (which has been already preprocessed, normalized and interpolated).
|
|
347
|
+
q_values (Union[np.ndarray, Tensor], optional): The momentum transfer (q) values for the reflectivity curve (in units of inverse angstroms).
|
|
348
|
+
prior_bounds (Union[np.ndarray, List[Tuple]]): The prior bounds for the predicted parameters.
|
|
349
|
+
sigmas (Union[np.ndarray, Tensor], optional): The error bars of the reflectivity curve, if available. They are used for filtering out points with high error bars if ``enable_error_bars_filtering`` is ``True``, as well as for the polishing step if ``use_sigmas_for_polishing`` is ``True``.
|
|
350
|
+
key_padding_mask (Union[np.ndarray, Tensor], optional): The key padding mask required for some embedding networks.
|
|
351
|
+
q_resolution (Union[float, np.ndarray], optional): The q resolution for neutron reflectometry models. Can be either a float dq/q for linear resolution smearing (e.g. 0.05 meaning 5% reolution smearing) or an array of dq values for pointwise resolution smearing.
|
|
352
|
+
ambient_sld (float, optional): The SLD of the fronting (i.e. ambient) medium for structure with fronting medium different than air.
|
|
353
|
+
clip_prediction (bool, optional): If ``True``, the values of the predicted parameters are clipped to not be outside the interval set by the prior bounds. Defaults to True.
|
|
354
|
+
polish_prediction (bool, optional): If ``True``, the neural network predictions are further polished using a simple least mean squares (LMS) fit. Defaults to False.
|
|
355
|
+
polishing_method (str): Type of scipy method used for polishing.
|
|
356
|
+
polishing_max_steps (int, optional): Sets the maximum number of steps for the polishing algorithm.
|
|
357
|
+
fit_growth (bool, optional): (Deprecated) If ``True``, an additional parameters is introduced during the LMS polishing to account for the change in the thickness of the upper layer during the in-situ measurement of the reflectivity curve (a linear growth is assumed). Defaults to False.
|
|
358
|
+
max_d_change (float): The maximum possible change in the thickness of the upper layer during the in-situ measurement, relevant when polish_prediction and fit_growth are True. Defaults to 5.
|
|
359
|
+
use_q_shift: (Deprecated) If ``True``, the prediction is performed for a batch of slightly shifted versions of the input curve and the best result is returned, which is meant to mitigate the influence of imperfect sample alignment, as introduced in Greco et al. (only for models with fixed q-discretization). Defaults to False.
|
|
360
|
+
calc_pred_curve (bool, optional): Whether to calculate the curve corresponding to the predicted parameters. Defaults to True.
|
|
361
|
+
calc_pred_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the predicted parameters. Defaults to False.
|
|
362
|
+
calc_polished_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the polished parameters. Defaults to False.
|
|
363
|
+
sld_profile_padding_left (float, optional): Controls the amount of padding applied to the left side of the computed SLD profiles.
|
|
364
|
+
sld_profile_padding_right (float, optional): Controls the amount of padding applied to the right side of the computed SLD profiles.
|
|
365
|
+
|
|
366
|
+
Returns:
|
|
367
|
+
dict: dictionary containing the predictions
|
|
368
|
+
"""
|
|
369
|
+
|
|
370
|
+
scaled_curve = self._scale_curve(reflectivity_curve)
|
|
371
|
+
if prior_bounds is None:
|
|
372
|
+
raise ValueError(f'Prior bounds were not provided')
|
|
373
|
+
prior_bounds = np.array(prior_bounds)
|
|
374
|
+
|
|
375
|
+
if ambient_sld:
|
|
376
|
+
sld_indices = self._shift_slds_by_ambient(prior_bounds, ambient_sld)
|
|
377
|
+
|
|
378
|
+
scaled_prior_bounds = self._scale_prior_bounds(prior_bounds)
|
|
379
|
+
|
|
380
|
+
if isinstance(self.trainer.loader.q_generator, ConstantQ):
|
|
381
|
+
q_values = self.trainer.loader.q_generator.q
|
|
382
|
+
else:
|
|
383
|
+
if q_values is None:
|
|
384
|
+
raise ValueError(f'The q values were not provided')
|
|
385
|
+
q_values = torch.atleast_2d(to_t(q_values)).to(scaled_curve)
|
|
386
|
+
|
|
387
|
+
scaled_q_values = self.trainer.loader.q_generator.scale_q(q_values).to(torch.float32) if self.trainer.train_with_q_input else None
|
|
388
|
+
|
|
389
|
+
if q_resolution is None and self.trainer.loader.smearing is not None:
|
|
390
|
+
raise ValueError(f'The q resolution must be provided for NR models')
|
|
391
|
+
|
|
392
|
+
if q_resolution is not None:
|
|
393
|
+
q_resolution_tensor = torch.atleast_2d(torch.as_tensor(q_resolution)).to(scaled_curve)
|
|
394
|
+
if isinstance(q_resolution, float):
|
|
395
|
+
unscaled_q_resolutions = q_resolution_tensor
|
|
396
|
+
else:
|
|
397
|
+
unscaled_q_resolutions = (q_resolution_tensor / q_values).nanmean(dim=-1, keepdim=True) ##when q_values is padded with 0s, there will be nan at the padded positions
|
|
398
|
+
scaled_q_resolutions = self.trainer.loader.smearing.scale_resolutions(unscaled_q_resolutions) if self.trainer.condition_on_q_resolutions else None
|
|
399
|
+
scaled_conditioning_params = scaled_q_resolutions
|
|
400
|
+
if polishing_kwargs_reflectivity is None:
|
|
401
|
+
polishing_kwargs_reflectivity = {'dq': q_resolution}
|
|
402
|
+
else:
|
|
403
|
+
q_resolution_tensor = None
|
|
404
|
+
scaled_conditioning_params = None
|
|
405
|
+
|
|
406
|
+
if key_padding_mask is not None:
|
|
407
|
+
key_padding_mask = torch.as_tensor(key_padding_mask, device=self.device)
|
|
408
|
+
key_padding_mask = key_padding_mask.unsqueeze(0) if key_padding_mask.dim() == 1 else key_padding_mask
|
|
409
|
+
|
|
410
|
+
if use_q_shift and not self.trainer.train_with_q_input:
|
|
411
|
+
predicted_params = self._qshift_prediction(reflectivity_curve, scaled_prior_bounds, num = 1024, dq_coef = 1.)
|
|
412
|
+
else:
|
|
413
|
+
with torch.no_grad():
|
|
414
|
+
self.trainer.model.eval()
|
|
415
|
+
|
|
416
|
+
scaled_predicted_params = self.trainer.model(
|
|
417
|
+
curves=scaled_curve,
|
|
418
|
+
bounds=scaled_prior_bounds,
|
|
419
|
+
q_values=scaled_q_values,
|
|
420
|
+
conditioning_params = scaled_conditioning_params,
|
|
421
|
+
key_padding_mask = key_padding_mask,
|
|
422
|
+
unscaled_q_values = q_values,
|
|
423
|
+
)
|
|
424
|
+
|
|
425
|
+
predicted_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
|
|
426
|
+
|
|
427
|
+
if clip_prediction:
|
|
428
|
+
predicted_params = self.trainer.loader.prior_sampler.clamp_params(predicted_params)
|
|
429
|
+
|
|
430
|
+
prediction_dict = {
|
|
431
|
+
"predicted_params_object": predicted_params,
|
|
432
|
+
"predicted_params_array": predicted_params.parameters.squeeze().cpu().numpy(),
|
|
433
|
+
"param_names" : self.trainer.loader.prior_sampler.param_model.get_param_labels(**kwargs_param_labels)
|
|
434
|
+
}
|
|
435
|
+
|
|
436
|
+
key_padding_mask = None if key_padding_mask is None else key_padding_mask.squeeze().cpu().numpy()
|
|
437
|
+
|
|
438
|
+
if calc_pred_curve:
|
|
439
|
+
predicted_curve = predicted_params.reflectivity(q=q_values, dq=q_resolution_tensor).squeeze().cpu().numpy()
|
|
440
|
+
prediction_dict[ "predicted_curve"] = predicted_curve if key_padding_mask is None else predicted_curve[key_padding_mask]
|
|
441
|
+
|
|
442
|
+
ambient_sld_tensor = torch.atleast_2d(torch.as_tensor(ambient_sld, device=self.device)) if ambient_sld is not None else None
|
|
443
|
+
if calc_pred_sld_profile:
|
|
444
|
+
predicted_sld_xaxis, predicted_sld_profile, _ = get_density_profiles(
|
|
445
|
+
predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds + (ambient_sld_tensor or 0), ambient_sld_tensor,
|
|
446
|
+
num=1024, padding_left=sld_profile_padding_left, padding_right=sld_profile_padding_right,
|
|
447
|
+
)
|
|
448
|
+
prediction_dict['predicted_sld_profile'] = predicted_sld_profile.squeeze().cpu().numpy()
|
|
449
|
+
prediction_dict['predicted_sld_xaxis'] = predicted_sld_xaxis.squeeze().cpu().numpy()
|
|
450
|
+
else:
|
|
451
|
+
predicted_sld_xaxis = None
|
|
452
|
+
|
|
453
|
+
refl_curve_polish = reflectivity_curve if key_padding_mask is None else reflectivity_curve[key_padding_mask]
|
|
454
|
+
q_polish = q_values.squeeze().cpu().numpy() if key_padding_mask is None else q_values.squeeze().cpu().numpy()[key_padding_mask]
|
|
455
|
+
prediction_dict['q_plot_pred'] = q_polish
|
|
456
|
+
|
|
457
|
+
if polish_prediction:
|
|
458
|
+
if ambient_sld_tensor:
|
|
459
|
+
ambient_sld_tensor = ambient_sld_tensor.cpu()
|
|
460
|
+
|
|
461
|
+
polished_dict = self._polish_prediction(
|
|
462
|
+
q = q_polish,
|
|
463
|
+
curve = refl_curve_polish,
|
|
464
|
+
predicted_params = predicted_params,
|
|
465
|
+
priors = np.array(prior_bounds),
|
|
466
|
+
error_bars = sigmas,
|
|
467
|
+
fit_growth = fit_growth,
|
|
468
|
+
max_d_change = max_d_change,
|
|
469
|
+
calc_polished_curve = calc_pred_curve,
|
|
470
|
+
calc_polished_sld_profile = calc_polished_sld_profile,
|
|
471
|
+
ambient_sld_tensor=ambient_sld_tensor,
|
|
472
|
+
sld_x_axis = predicted_sld_xaxis,
|
|
473
|
+
polishing_method=polishing_method,
|
|
474
|
+
polishing_max_steps=polishing_max_steps,
|
|
475
|
+
polishing_kwargs_reflectivity=polishing_kwargs_reflectivity,
|
|
476
|
+
)
|
|
477
|
+
prediction_dict.update(polished_dict)
|
|
478
|
+
|
|
479
|
+
if fit_growth and "polished_params_array" in prediction_dict:
|
|
480
|
+
prediction_dict["param_names"].append("max_d_change")
|
|
481
|
+
|
|
482
|
+
if ambient_sld and not supress_sld_amb_back_shift: #Note: the SLD shift will only be reflected in predicted_params_array but not in predicted_params_object; supress_sld_amb_back_shift is required for the 'preprocess_and_predict' method
|
|
483
|
+
self._restore_slds_after_ambient_shift(prediction_dict, sld_indices, ambient_sld)
|
|
484
|
+
|
|
485
|
+
return prediction_dict
|
|
486
|
+
|
|
487
|
+
def _polish_prediction(self,
|
|
488
|
+
q: np.ndarray,
|
|
489
|
+
curve: np.ndarray,
|
|
490
|
+
predicted_params: BasicParams,
|
|
491
|
+
priors: np.ndarray,
|
|
492
|
+
sld_x_axis,
|
|
493
|
+
ambient_sld_tensor: Tensor = None,
|
|
494
|
+
fit_growth: bool = False,
|
|
495
|
+
max_d_change: float = 5.,
|
|
496
|
+
calc_polished_curve: bool = True,
|
|
497
|
+
calc_polished_sld_profile: bool = False,
|
|
498
|
+
error_bars: np.ndarray = None,
|
|
499
|
+
polishing_method: str = 'trf',
|
|
500
|
+
polishing_max_steps: int = None,
|
|
501
|
+
polishing_kwargs_reflectivity: dict = None,
|
|
502
|
+
) -> dict:
|
|
503
|
+
params = predicted_params.parameters.squeeze().cpu().numpy()
|
|
504
|
+
|
|
505
|
+
polished_params_dict = {}
|
|
506
|
+
polishing_kwargs_reflectivity = polishing_kwargs_reflectivity or {}
|
|
507
|
+
|
|
508
|
+
try:
|
|
509
|
+
if fit_growth:
|
|
510
|
+
polished_params_arr, curve_polished = get_fit_with_growth(
|
|
511
|
+
q = q,
|
|
512
|
+
curve = curve,
|
|
513
|
+
init_params = params,
|
|
514
|
+
bounds = priors.T,
|
|
515
|
+
max_d_change = max_d_change,
|
|
516
|
+
)
|
|
517
|
+
polished_params = BasicParams(
|
|
518
|
+
torch.from_numpy(polished_params_arr[:-1][None]),
|
|
519
|
+
torch.from_numpy(priors.T[0][None]),
|
|
520
|
+
torch.from_numpy(priors.T[1][None]),
|
|
521
|
+
self.trainer.loader.prior_sampler.max_num_layers,
|
|
522
|
+
self.trainer.loader.prior_sampler.param_model
|
|
523
|
+
)
|
|
524
|
+
else:
|
|
525
|
+
polished_params_arr, polished_params_err, curve_polished = refl_fit(
|
|
526
|
+
q = q,
|
|
527
|
+
curve = curve,
|
|
528
|
+
init_params = params,
|
|
529
|
+
bounds=priors.T,
|
|
530
|
+
prior_sampler=self.trainer.loader.prior_sampler,
|
|
531
|
+
error_bars=error_bars,
|
|
532
|
+
method=polishing_method,
|
|
533
|
+
polishing_max_steps=polishing_max_steps,
|
|
534
|
+
reflectivity_kwargs=polishing_kwargs_reflectivity,
|
|
535
|
+
)
|
|
536
|
+
polished_params = BasicParams(
|
|
537
|
+
torch.from_numpy(polished_params_arr[None]),
|
|
538
|
+
torch.from_numpy(priors.T[0][None]),
|
|
539
|
+
torch.from_numpy(priors.T[1][None]),
|
|
540
|
+
self.trainer.loader.prior_sampler.max_num_layers,
|
|
541
|
+
self.trainer.loader.prior_sampler.param_model
|
|
542
|
+
)
|
|
543
|
+
except Exception as err:
|
|
544
|
+
polished_params = predicted_params
|
|
545
|
+
polished_params_arr = get_prediction_array(polished_params)
|
|
546
|
+
curve_polished = np.zeros_like(q)
|
|
547
|
+
polished_params_err = None
|
|
548
|
+
|
|
549
|
+
polished_params_dict['polished_params_array'] = polished_params_arr
|
|
550
|
+
|
|
551
|
+
polished_params_dict['polished_params_error_array'] = (
|
|
552
|
+
np.array(polished_params_err)
|
|
553
|
+
if polished_params_err is not None
|
|
554
|
+
else np.full_like(polished_params, np.nan, dtype=np.float64)
|
|
555
|
+
)
|
|
556
|
+
if calc_polished_curve:
|
|
557
|
+
polished_params_dict['polished_curve'] = curve_polished
|
|
558
|
+
|
|
559
|
+
if ambient_sld_tensor is not None:
|
|
560
|
+
ambient_sld_tensor = ambient_sld_tensor.to(polished_params.slds.device)
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
if calc_polished_sld_profile:
|
|
564
|
+
_, sld_profile_polished, _ = get_density_profiles(
|
|
565
|
+
polished_params.thicknesses, polished_params.roughnesses, polished_params.slds + (ambient_sld_tensor or 0), ambient_sld_tensor,
|
|
566
|
+
z_axis=sld_x_axis.to(polished_params.slds.device),
|
|
567
|
+
)
|
|
568
|
+
polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().cpu().numpy()
|
|
569
|
+
|
|
570
|
+
return polished_params_dict
|
|
571
|
+
|
|
572
|
+
def _scale_curve(self, curve: Union[np.ndarray, Tensor]):
|
|
573
|
+
if not isinstance(curve, Tensor):
|
|
574
|
+
curve = torch.from_numpy(curve).float()
|
|
575
|
+
curve = curve.unsqueeze(0).to(self.device)
|
|
576
|
+
scaled_curve = self.trainer.loader.curves_scaler.scale(curve)
|
|
577
|
+
return scaled_curve
|
|
578
|
+
|
|
579
|
+
    def _scale_prior_bounds(self, prior_bounds: List[Tuple]):
        """Scale (min, max) prior bounds to the network's normalized input range.

        Args:
            prior_bounds: sequence of ``(min, max)`` pairs, one per parameter.

        Returns:
            Float tensor with the scaled min bounds concatenated with the
            scaled max bounds along the last dimension.

        Raises:
            ValueError: when the number of bounds does not match the parameter
                dimension the model was trained with.
        """
        try:
            prior_bounds = torch.tensor(prior_bounds)
            # Transpose to shape (2, num_params): row 0 = mins, row 1 = maxs.
            prior_bounds = prior_bounds.to(self.device).T
            # Insert a batch dimension, then unpack into (1, num_params) rows.
            min_bounds, max_bounds = prior_bounds[:, None]

            scaled_bounds = torch.cat([
                self.trainer.loader.prior_sampler.scale_bounds(min_bounds),
                self.trainer.loader.prior_sampler.scale_bounds(max_bounds)
            ], -1)

            return scaled_bounds.float()

        except RuntimeError as e:
            # A RuntimeError here typically means a shape mismatch between the
            # provided bounds and the model's parameterization; re-raise with a
            # diagnostic message.
            expected_param_dim = self.trainer.loader.prior_sampler.param_dim
            actual_param_dim = prior_bounds.shape[1] if prior_bounds.ndim == 2 else len(prior_bounds)

            msg = (
                f"\n **Parameter dimension mismatch during inference!**\n"
                f"- Model expects **{expected_param_dim}** parameters.\n"
                f"- You provided **{actual_param_dim}** prior bounds.\n\n"
                f"💡This often occurs when:\n"
                f"- The model was trained with additional nuisance parameters like `r_scale`, `q_shift`, or `log10_background`,\n"
                f"  but they were not included in the `prior_bounds` passed to `.predict()`.\n"
                f"- The number of layers or parameterization type differs from the one used during training.\n\n"
                f" Check the configuration or the summary of expected parameters."
            )
            raise ValueError(msg) from e
|
|
607
|
+
|
|
608
|
+
    def _shift_slds_by_ambient(self, prior_bounds: np.ndarray, ambient_sld: float):
        """Shift the SLD prior bounds in place by the ambient SLD.

        Translates the layer SLD bounds into the frame where the ambient
        medium has zero SLD (the frame the model was trained in), and checks
        that the shifted bounds still lie inside the training ranges.

        Args:
            prior_bounds: array of shape (num_params, 2); modified in place.
            ambient_sld: SLD of the fronting/ambient medium.

        Returns:
            slice selecting the SLD rows of ``prior_bounds``.
        """
        n_layers = self.trainer.loader.prior_sampler.max_num_layers
        # Rows 2n+1 .. 3n+1 hold the SLD parameters — presumably the n layer
        # SLDs plus the substrate SLD, given the parameter layout of the prior
        # sampler (TODO confirm against the parametric model definition).
        sld_indices = slice(2*n_layers+1, 3*n_layers+2)
        prior_bounds[sld_indices, ...] -= ambient_sld

        training_min_bounds = self.trainer.loader.prior_sampler.min_bounds.squeeze().cpu().numpy()
        training_max_bounds = self.trainer.loader.prior_sampler.max_bounds.squeeze().cpu().numpy()
        lower_bound_check = (prior_bounds[sld_indices, 0] >= training_min_bounds[sld_indices]).all()
        upper_bound_check = (prior_bounds[sld_indices, 1] <= training_max_bounds[sld_indices]).all()
        # NOTE(review): this validation is an `assert`, which is stripped under
        # `python -O`; consider raising ValueError instead.
        assert lower_bound_check and upper_bound_check, "Shifting the layer SLDs by the ambient SLD exceeded the training ranges."

        return sld_indices
|
|
620
|
+
|
|
621
|
+
def _restore_slds_after_ambient_shift(self, prediction_dict, sld_indices, ambient_sld):
|
|
622
|
+
prediction_dict["predicted_params_array"][sld_indices] += ambient_sld
|
|
623
|
+
if "polished_params_array" in prediction_dict:
|
|
624
|
+
prediction_dict["polished_params_array"][sld_indices] += ambient_sld
|
|
625
|
+
|
|
626
|
+
def _get_likelihood(self, q, curve, rel_err: float = 0.1, abs_err: float = 1e-12):
|
|
627
|
+
return LogLikelihood(
|
|
628
|
+
q, curve, self.trainer.loader.prior_sampler, curve * rel_err + abs_err
|
|
629
|
+
)
|
|
630
|
+
|
|
631
|
+
def get_param_labels(self, **kwargs):
|
|
632
|
+
return self.trainer.loader.prior_sampler.param_model.get_param_labels(**kwargs)
|
|
633
|
+
|
|
634
|
+
@staticmethod
|
|
635
|
+
def _preprocess_input_data(
|
|
636
|
+
reflectivity_curve,
|
|
637
|
+
q_values,
|
|
638
|
+
sigmas=None,
|
|
639
|
+
q_resolution=None,
|
|
640
|
+
truncate_index_left=None,
|
|
641
|
+
truncate_index_right=None,
|
|
642
|
+
enable_error_bars_filtering=True,
|
|
643
|
+
filter_threshold=0.3,
|
|
644
|
+
filter_remove_singles=True,
|
|
645
|
+
filter_remove_consecutives=True,
|
|
646
|
+
filter_consecutive=3,
|
|
647
|
+
filter_q_start_trunc=0.1):
|
|
648
|
+
|
|
649
|
+
# Save originals for polishing
|
|
650
|
+
reflectivity_curve_original = reflectivity_curve.copy()
|
|
651
|
+
q_values_original = q_values.copy() if q_values is not None else None
|
|
652
|
+
q_resolution_original = q_resolution.copy() if isinstance(q_resolution, np.ndarray) else q_resolution
|
|
653
|
+
sigmas_original = sigmas.copy() if sigmas is not None else None
|
|
654
|
+
|
|
655
|
+
# Remove points with non-positive intensities
|
|
656
|
+
nonnegative_mask = reflectivity_curve > 0.0
|
|
657
|
+
reflectivity_curve = reflectivity_curve[nonnegative_mask]
|
|
658
|
+
q_values = q_values[nonnegative_mask]
|
|
659
|
+
if sigmas is not None:
|
|
660
|
+
sigmas = sigmas[nonnegative_mask]
|
|
661
|
+
if isinstance(q_resolution, np.ndarray):
|
|
662
|
+
q_resolution = q_resolution[nonnegative_mask]
|
|
663
|
+
|
|
664
|
+
# Truncate arrays
|
|
665
|
+
if truncate_index_left is not None or truncate_index_right is not None:
|
|
666
|
+
slice_obj = slice(truncate_index_left, truncate_index_right)
|
|
667
|
+
reflectivity_curve = reflectivity_curve[slice_obj]
|
|
668
|
+
q_values = q_values[slice_obj]
|
|
669
|
+
if sigmas is not None:
|
|
670
|
+
sigmas = sigmas[slice_obj]
|
|
671
|
+
if isinstance(q_resolution, np.ndarray):
|
|
672
|
+
q_resolution = q_resolution[slice_obj]
|
|
673
|
+
|
|
674
|
+
# Filter high-error points
|
|
675
|
+
if enable_error_bars_filtering and sigmas is not None:
|
|
676
|
+
valid_mask = get_filtering_mask(
|
|
677
|
+
q_values,
|
|
678
|
+
reflectivity_curve,
|
|
679
|
+
sigmas,
|
|
680
|
+
threshold=filter_threshold,
|
|
681
|
+
consecutive=filter_consecutive,
|
|
682
|
+
remove_singles=filter_remove_singles,
|
|
683
|
+
remove_consecutives=filter_remove_consecutives,
|
|
684
|
+
q_start_trunc=filter_q_start_trunc
|
|
685
|
+
)
|
|
686
|
+
reflectivity_curve = reflectivity_curve[valid_mask]
|
|
687
|
+
q_values = q_values[valid_mask]
|
|
688
|
+
sigmas = sigmas[valid_mask]
|
|
689
|
+
if isinstance(q_resolution, np.ndarray):
|
|
690
|
+
q_resolution = q_resolution[valid_mask]
|
|
691
|
+
|
|
692
|
+
return (q_values, reflectivity_curve, sigmas, q_resolution,
|
|
693
|
+
q_values_original, reflectivity_curve_original,
|
|
694
|
+
sigmas_original, q_resolution_original)
|
|
695
|
+
|
|
696
|
+
def interpolate_data_to_model_q(
    self,
    q_exp,
    refl_exp,
    sigmas_exp=None,
    q_res_exp=None,
    as_dict=False
):
    """Resample experimental data onto the q discretization expected by the model.

    Behaviour is dispatched on the type of the loader's q generator:

    - ``ConstantQ``: interpolate everything onto the generator's fixed q grid.
    - ``VariableQ``: if the generator uses a fixed number of points, build a
      lin-/log-spaced grid restricted to the overlap of the experimental range
      and the generator's allowed range, then interpolate onto it; otherwise
      the experimental arrays are returned unchanged.
    - ``MaskedVariableQ``: pad (and, when the point count falls outside the
      allowed range, first interpolate) the arrays to the maximum length and
      additionally return a boolean key-padding mask (True marks valid points).

    Args:
        q_exp: experimental q values (1-D numpy array).
        refl_exp: experimental reflectivity (same length as ``q_exp``).
        sigmas_exp: optional error bars, resampled alongside the curve.
        q_res_exp: optional q resolution; a numpy array is resampled like the
            curve, any other value (e.g. a scalar) is passed through unchanged.
        as_dict: if True return a dict with named entries, otherwise a tuple.

    Returns:
        Either a dict with keys ``q_model``/``reflectivity`` (plus ``sigmas``,
        ``q_resolution`` and ``key_padding_mask`` when available) or a tuple in
        that order with ``None`` entries omitted.

    Raises:
        TypeError: if the q generator type is not supported.
    """
    q_generator = self.trainer.loader.q_generator

    def _pad(arr, pad_to, value=0.0):
        # Right-pad a 1-D array to length `pad_to` with `value`; None passes through.
        if arr is None:
            return None
        return np.pad(arr, (0, pad_to - len(arr)), constant_values=value)

    def _interp_or_keep(q_model, q_exp, arr):
        """Interpolate arrays, keep floats or None unchanged."""
        if arr is None:
            return None
        return np.interp(q_model, q_exp, arr) if isinstance(arr, np.ndarray) else arr

    def _pad_or_keep(arr, max_n):
        """Pad arrays, keep floats or None unchanged."""
        if arr is None:
            return None
        return _pad(arr, max_n, 0.0) if isinstance(arr, np.ndarray) else arr

    def _prepare_return(q, refl, sigmas=None, q_res=None, mask=None, as_dict=False):
        # Assemble the output. Optional entries are skipped when None, so the
        # tuple layout depends on which inputs were provided by the caller.
        if as_dict:
            result = {"q_model": q, "reflectivity": refl}
            if sigmas is not None: result["sigmas"] = sigmas
            if q_res is not None: result["q_resolution"] = q_res
            if mask is not None: result["key_padding_mask"] = mask
            return result
        result = [q, refl]
        if sigmas is not None: result.append(sigmas)
        if q_res is not None: result.append(q_res)
        if mask is not None: result.append(mask)
        return tuple(result)

    # ConstantQ
    if isinstance(q_generator, ConstantQ):
        # Fixed grid: interpolate everything onto the generator's q tensor.
        q_model = q_generator.q.cpu().numpy()
        refl_out = interp_reflectivity(q_model, q_exp, refl_exp)
        sigmas_out = _interp_or_keep(q_model, q_exp, sigmas_exp)
        q_res_out = _interp_or_keep(q_model, q_exp, q_res_exp)
        return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, None, as_dict)

    # VariableQ
    elif isinstance(q_generator, VariableQ):
        if q_generator.n_q_range[0] == q_generator.n_q_range[1]:
            # Fixed number of points: build a fresh grid clipped to the overlap
            # of the experimental range and the generator's allowed range.
            n_q_model = q_generator.n_q_range[0]
            q_min = max(q_exp.min(), q_generator.q_min_range[0])
            q_max = min(q_exp.max(), q_generator.q_max_range[1])
            if self.trainer.loader.q_generator.mode == 'logspace':
                # NOTE(review): this branch yields a (CPU) torch tensor while the
                # linspace branch yields a numpy array — interp_reflectivity and
                # np.interp downstream must accept both; confirm.
                q_model = torch.logspace(start=torch.log10(torch.tensor(q_min, device=self.device)),
                                         end=torch.log10(torch.tensor(q_max, device=self.device)),
                                         steps=n_q_model, device=self.device).to('cpu')
                logspace = True
            else:
                q_model = np.linspace(q_min, q_max, n_q_model)
                logspace = False
        else:
            # Variable number of points: the model can consume the data as-is.
            return _prepare_return(q_exp, refl_exp, sigmas_exp, q_res_exp, None, as_dict)

        refl_out = interp_reflectivity(q_model, q_exp, refl_exp, logspace=logspace)
        sigmas_out = _interp_or_keep(q_model, q_exp, sigmas_exp)
        q_res_out = _interp_or_keep(q_model, q_exp, q_res_exp)
        return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, None, as_dict)

    # MaskedVariableQ
    elif isinstance(q_generator, MaskedVariableQ):
        min_n, max_n = q_generator.n_q_range
        n_exp = len(q_exp)

        if min_n <= n_exp <= max_n:
            # Pad only
            q_model = _pad(q_exp, max_n, 0.0)
            refl_out = _pad(refl_exp, max_n, 0.0)
            sigmas_out = _pad_or_keep(sigmas_exp, max_n)
            q_res_out = _pad_or_keep(q_res_exp, max_n)
            key_padding_mask = np.zeros(max_n, dtype=bool)
            key_padding_mask[:n_exp] = True  # True marks real (unpadded) points

        else:
            # Interpolate + pad
            # Clamp the point count into [min_n, max_n] before resampling.
            n_interp = min(max(n_exp, min_n), max_n)
            q_min = max(q_exp.min(), q_generator.q_min_range[0])
            q_max = min(q_exp.max(), q_generator.q_max_range[1])
            q_interp = np.linspace(q_min, q_max, n_interp)

            refl_interp = interp_reflectivity(q_interp, q_exp, refl_exp)
            sigmas_interp = _interp_or_keep(q_interp, q_exp, sigmas_exp)
            q_res_interp = _interp_or_keep(q_interp, q_exp, q_res_exp)

            q_model = _pad(q_interp, max_n, 0.0)
            refl_out = _pad(refl_interp, max_n, 0.0)
            sigmas_out = _pad_or_keep(sigmas_interp, max_n)
            q_res_out = _pad_or_keep(q_res_interp, max_n)
            key_padding_mask = np.zeros(max_n, dtype=bool)
            key_padding_mask[:n_interp] = True

        return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, key_padding_mask, as_dict)

    else:
        raise TypeError(f"Unsupported QGenerator type: {type(q_generator)}")
|
|
802
|
+
|
|
803
|
+
def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> BasicParams:
    """Run the network on a batch of slightly q-shifted copies of the curve
    and keep the candidate parameters that score best.

    Args:
        curve: measured reflectivity on the generator's fixed q grid.
        scaled_bounds: scaled prior bounds fed to the network.
        num: number of q shifts to evaluate.
        dq_coef: maximum shift as a fraction of one grid spacing.

    Returns:
        BasicParams: the best-scoring restored parameters.
    """
    generator = self.trainer.loader.q_generator
    assert isinstance(generator, ConstantQ), "Prediction with q shifts available only for models with fixed discretization"

    # Shift grid spans +/- one (scaled) grid spacing around zero.
    q_grid = generator.q.squeeze().float()
    max_shift = (q_grid[1] - q_grid[0]) * dq_coef
    shift_grid = torch.linspace(-max_shift, max_shift, num).to(q_grid)

    curve = to_t(curve).to(scaled_bounds)
    candidate_curves = _qshift_interp(q_grid.squeeze(), curve, shift_grid)
    assert candidate_curves.shape == (num, q_grid.shape[0])

    loader = self.trainer.loader
    scaled_curves = loader.curves_scaler.scale(candidate_curves)
    # Broadcast the single bounds row across the whole candidate batch.
    prior_bounds_batch = torch.atleast_2d(scaled_bounds).expand(scaled_curves.shape[0], -1)

    model = self.trainer.model
    model.eval()
    with torch.no_grad():
        predicted = model(scaled_curves, prior_bounds_batch)
        restored = loader.prior_sampler.restore_params(
            torch.cat([predicted, prior_bounds_batch], dim=-1)
        )

    # Score every candidate against the measured curve and keep the best one.
    likelihood = self._get_likelihood(q=generator.q, curve=curve)
    return get_best_mse_param(restored, likelihood)
|
|
827
|
+
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
# Alias kept so code importing `EasyInferenceModel` keeps working —
# presumably the class's former public name; confirm against older releases.
EasyInferenceModel = InferenceModel
|
|
831
|
+
|
|
832
|
+
def get_prediction_array(params: BasicParams) -> np.ndarray:
    """Flatten predicted film parameters into a single 1-D numpy vector.

    The vector concatenates, in order: thicknesses, roughnesses, SLDs
    (each squeezed to 1-D and moved to the CPU first).

    Args:
        params: predicted parameters holding `thicknesses`, `roughnesses`
            and `slds` tensors.

    Returns:
        np.ndarray: the concatenated parameter vector.
    """
    components = (
        params.thicknesses.squeeze(),
        params.roughnesses.squeeze(),
        params.slds.squeeze(),
    )
    return torch.cat(components).cpu().numpy()
|
|
840
|
+
|
|
841
|
+
|
|
842
|
+
def _qshift_interp(q, r, q_shifts):
|
|
843
|
+
qs = q[None] + q_shifts[:, None]
|
|
844
|
+
eps = torch.finfo(r.dtype).eps
|
|
845
|
+
ind = torch.searchsorted(q[None].expand_as(qs).contiguous(), qs.contiguous())
|
|
846
|
+
ind = torch.clamp(ind - 1, 0, q.shape[0] - 2)
|
|
847
|
+
slopes = (r[1:] - r[:-1]) / (eps + (q[1:] - q[:-1]))
|
|
852
848
|
return r[ind] + slopes[ind] * (qs - q[ind])
|