reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of reflectorch might be problematic.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,112 +1,112 @@
- from pathlib import Path
-
- import torch
- from torch import Tensor
-
- from reflectorch.data_generation.priors import PriorSampler
- from reflectorch.paths import SAVED_MODELS_DIR
-
-
- class CurvesScaler(object):
-     """Base class for curve scalers"""
-     def scale(self, curves: Tensor):
-         raise NotImplementedError
-
-     def restore(self, curves: Tensor):
-         raise NotImplementedError
-
-
- class LogAffineCurvesScaler(CurvesScaler):
-     """ Curve scaler which scales the reflectivity curves according to the logarithmic affine transformation:
-     :math:`\log_{10}(R + eps) \cdot weight + bias`.
-
-     Args:
-         weight (float): multiplication factor in the transformation
-         bias (float): addition term in the transformation
-         eps (float): sets the minimum intensity value of the reflectivity curves which is considered
-     """
-     def __init__(self, weight: float = 0.1, bias: float = 0.5, eps: float = 1e-10):
-         self.weight = weight
-         self.bias = bias
-         self.eps = eps
-
-     def scale(self, curves: Tensor):
-         """scales the reflectivity curves to a ML-friendly range
-
-         Args:
-             curves (Tensor): original reflectivity curves
-
-         Returns:
-             Tensor: reflectivity curves scaled to a ML-friendly range
-         """
-         return torch.log10(curves + self.eps) * self.weight + self.bias
-
-     def restore(self, curves: Tensor):
-         """restores the physical reflectivity curves
-
-         Args:
-             curves (Tensor): scaled reflectivity curves
-
-         Returns:
-             Tensor: reflectivity curves restored to the physical range
-         """
-         return 10 ** ((curves - self.bias) / self.weight) - self.eps
-
-
- class MeanNormalizationCurvesScaler(CurvesScaler):
-     """Curve scaler which scales the reflectivity curves by the precomputed mean of a batch of curves
-
-     Args:
-         path (str, optional): path to the precomputed mean of the curves, only used if ``curves_mean`` is None. Defaults to None.
-         curves_mean (Tensor, optional): the precomputed mean of the curves. Defaults to None.
-         device (torch.device, optional): the Pytorch device. Defaults to 'cuda'.
-     """
-
-     def __init__(self, path: str = None, curves_mean: Tensor = None, device: torch.device = 'cuda'):
-         if curves_mean is None:
-             curves_mean = torch.load(self.get_path(path))
-         self.curves_mean = curves_mean.to(device)
-
-     def scale(self, curves: Tensor):
-         """scales the reflectivity curves to a ML-friendly range
-
-         Args:
-             curves (Tensor): original reflectivity curves
-
-         Returns:
-             Tensor: reflectivity curves scaled to a ML-friendly range
-         """
-         self.curves_mean = self.curves_mean.to(curves)
-         return curves / self.curves_mean - 1
-
-     def restore(self, curves: Tensor):
-         """restores the physical reflectivity curves
-
-         Args:
-             curves (Tensor): scaled reflectivity curves
-
-         Returns:
-             Tensor: reflectivity curves restored to the physical range
-         """
-         self.curves_mean = self.curves_mean.to(curves)
-         return (curves + 1) * self.curves_mean
-
-     @staticmethod
-     def save(prior_sampler: PriorSampler, q: Tensor, path: str, num: int = 16384):
-         """computes the mean of a batch of reflectivity curves and saves it
-
-         Args:
-             prior_sampler (PriorSampler): the prior sampler
-             q (Tensor): the q values
-             path (str): the path for saving the mean of the curves
-             num (int, optional): the number of curves used to compute the mean. Defaults to 16384.
-         """
-         params = prior_sampler.sample(num)
-         curves_mean = params.reflectivity(q, log=False).mean(0).cpu()
-         torch.save(curves_mean, MeanNormalizationCurvesScaler.get_path(path))
-
-     @staticmethod
-     def get_path(path: str) -> Path:
-         if not path.endswith('.pt'):
-             path = path + '.pt'
-         return SAVED_MODELS_DIR / path
+ from pathlib import Path
+
+ import torch
+ from torch import Tensor
+
+ from reflectorch.data_generation.priors import PriorSampler
+ from reflectorch.paths import SAVED_MODELS_DIR
+
+
+ class CurvesScaler(object):
+     """Base class for curve scalers"""
+     def scale(self, curves: Tensor):
+         raise NotImplementedError
+
+     def restore(self, curves: Tensor):
+         raise NotImplementedError
+
+
+ class LogAffineCurvesScaler(CurvesScaler):
+     """ Curve scaler which scales the reflectivity curves according to the logarithmic affine transformation:
+     :math:`\log_{10}(R + eps) \cdot weight + bias`.
+
+     Args:
+         weight (float): multiplication factor in the transformation
+         bias (float): addition term in the transformation
+         eps (float): sets the minimum intensity value of the reflectivity curves which is considered
+     """
+     def __init__(self, weight: float = 0.1, bias: float = 0.5, eps: float = 1e-10):
+         self.weight = weight
+         self.bias = bias
+         self.eps = eps
+
+     def scale(self, curves: Tensor):
+         """scales the reflectivity curves to a ML-friendly range
+
+         Args:
+             curves (Tensor): original reflectivity curves
+
+         Returns:
+             Tensor: reflectivity curves scaled to a ML-friendly range
+         """
+         return torch.log10(curves + self.eps) * self.weight + self.bias
+
+     def restore(self, curves: Tensor):
+         """restores the physical reflectivity curves
+
+         Args:
+             curves (Tensor): scaled reflectivity curves
+
+         Returns:
+             Tensor: reflectivity curves restored to the physical range
+         """
+         return 10 ** ((curves - self.bias) / self.weight) - self.eps
+
+
+ class MeanNormalizationCurvesScaler(CurvesScaler):
+     """Curve scaler which scales the reflectivity curves by the precomputed mean of a batch of curves
+
+     Args:
+         path (str, optional): path to the precomputed mean of the curves, only used if ``curves_mean`` is None. Defaults to None.
+         curves_mean (Tensor, optional): the precomputed mean of the curves. Defaults to None.
+         device (torch.device, optional): the Pytorch device. Defaults to 'cuda'.
+     """
+
+     def __init__(self, path: str = None, curves_mean: Tensor = None, device: torch.device = 'cuda'):
+         if curves_mean is None:
+             curves_mean = torch.load(self.get_path(path))
+         self.curves_mean = curves_mean.to(device)
+
+     def scale(self, curves: Tensor):
+         """scales the reflectivity curves to a ML-friendly range
+
+         Args:
+             curves (Tensor): original reflectivity curves
+
+         Returns:
+             Tensor: reflectivity curves scaled to a ML-friendly range
+         """
+         self.curves_mean = self.curves_mean.to(curves)
+         return curves / self.curves_mean - 1
+
+     def restore(self, curves: Tensor):
+         """restores the physical reflectivity curves
+
+         Args:
+             curves (Tensor): scaled reflectivity curves
+
+         Returns:
+             Tensor: reflectivity curves restored to the physical range
+         """
+         self.curves_mean = self.curves_mean.to(curves)
+         return (curves + 1) * self.curves_mean
+
+     @staticmethod
+     def save(prior_sampler: PriorSampler, q: Tensor, path: str, num: int = 16384):
+         """computes the mean of a batch of reflectivity curves and saves it
+
+         Args:
+             prior_sampler (PriorSampler): the prior sampler
+             q (Tensor): the q values
+             path (str): the path for saving the mean of the curves
+             num (int, optional): the number of curves used to compute the mean. Defaults to 16384.
+         """
+         params = prior_sampler.sample(num)
+         curves_mean = params.reflectivity(q, log=False).mean(0).cpu()
+         torch.save(curves_mean, MeanNormalizationCurvesScaler.get_path(path))
+
+     @staticmethod
+     def get_path(path: str) -> Path:
+         if not path.endswith('.pt'):
+             path = path + '.pt'
+         return SAVED_MODELS_DIR / path
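The hunk above appears to correspond to reflectorch/data_generation/scale_curves.py. As a quick illustration of the log-affine transform used by LogAffineCurvesScaler, the following minimal sketch (not part of the package diff, using the class defaults and synthetic reflectivity values) shows how curves spanning many decades are mapped into a narrow range and recovered by the inverse transform:

import torch

# Minimal sketch with the defaults weight=0.1, bias=0.5, eps=1e-10:
# reflectivities spanning ~[1e-9, 1] land in a narrow range around zero,
# and the inverse transform recovers the original curve.
weight, bias, eps = 0.1, 0.5, 1e-10
R = torch.logspace(0, -9, steps=5)                 # synthetic reflectivity values
scaled = torch.log10(R + eps) * weight + bias      # roughly in [-0.4, 0.5]
restored = 10 ** ((scaled - bias) / weight) - eps
assert torch.allclose(R, restored, rtol=1e-4)      # round trip recovers R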
@@ -1,99 +1,99 @@
- import torch
- from torch import Tensor
-
- from reflectorch.data_generation.priors.parametric_subpriors import BasicParams
-
-
- class Smearing(object):
-     """Class which applies resolution smearing to the reflectivity curves.
-     The intensity at a q point will be the average of the intensities of neighbouring q points, weighted by a gaussian profile.
-
-     Args:
-         sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (0.01, 0.1).
-         constant_dq (bool, optional): if ``True`` the smearing is constant (the resolution is given by the constant dq at each point in the curve)
-             otherwise the smearing is linear (the resolution is given by the constant dq/q at each point in the curve). Defaults to True.
-         gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
-         share_smeared (float, optional): the share of curves in the batch for which the resolution smearing is applied. Defaults to 0.2.
-     """
-     def __init__(self,
-                  sigma_range: tuple = (0.01, 0.1),
-                  constant_dq: bool = False,
-                  gauss_num: int = 31,
-                  share_smeared: float = 0.2,
-                  ):
-         self.sigma_min, self.sigma_max = sigma_range
-         self.sigma_delta = self.sigma_max - self.sigma_min
-         self.constant_dq = constant_dq
-         self.gauss_num = gauss_num
-         self.share_smeared = share_smeared
-
-     def __repr__(self):
-         return f'Smearing(({self.sigma_min}, {self.sigma_max})'
-
-     def generate_resolutions(self, batch_size: int, device=None, dtype=None):
-         num_smeared = int(batch_size * self.share_smeared)
-         if not num_smeared:
-             return None, None
-         dq = torch.rand(num_smeared, 1, device=device, dtype=dtype) * self.sigma_delta + self.sigma_min
-         indices = torch.zeros(batch_size, device=device, dtype=torch.bool)
-         indices[torch.randperm(batch_size, device=device)[:num_smeared]] = True
-         return dq, indices
-
-     def scale_resolutions(self, resolutions: Tensor) -> Tensor:
-         """Scales the q-resolution values to [-1,1] range using the internal sigma range"""
-         sigma_min = 0.0 if self.share_smeared != 1.0 else self.sigma_min
-         return 2 * (resolutions - sigma_min) / (self.sigma_max - sigma_min) - 1
-
-     def get_curves(self, q_values: Tensor, params: BasicParams, refl_kwargs:dict = None):
-         refl_kwargs = refl_kwargs or {}
-
-         dq, indices = self.generate_resolutions(params.batch_size, device=params.device, dtype=params.dtype)
-         q_resolutions = torch.zeros(q_values.shape[0], 1, dtype=q_values.dtype, device=q_values.device)
-
-         if dq is None:
-             return params.reflectivity(q_values, **refl_kwargs), q_resolutions
-
-         refl_kwargs_not_smeared = {}
-         refl_kwargs_smeared = {}
-         for key, value in refl_kwargs.items():
-             if isinstance(value, torch.Tensor) and value.shape[0] == params.batch_size:
-                 refl_kwargs_not_smeared[key] = value[~indices]
-                 refl_kwargs_smeared[key] = value[indices]
-             else:
-                 refl_kwargs_not_smeared[key] = value
-                 refl_kwargs_smeared[key] = value
-
-         # Compute unsmeared reflectivity
-         if (~indices).sum().item():
-             if q_values.dim() == 2 and q_values.shape[0] > 1:
-                 q = q_values[~indices]
-             else:
-                 q = q_values
-
-             reflectivity_not_smeared = params[~indices].reflectivity(q, **refl_kwargs_not_smeared)
-         else:
-             reflectivity_not_smeared = None
-
-         # Compute smeared reflectivity
-         if indices.sum().item():
-             if q_values.dim() == 2 and q_values.shape[0] > 1:
-                 q = q_values[indices]
-             else:
-                 q = q_values
-
-             reflectivity_smeared = params[indices].reflectivity(
-                 q, dq=dq, constant_dq=self.constant_dq, gauss_num=self.gauss_num, **refl_kwargs_smeared
-             )
-         else:
-             reflectivity_smeared = None
-
-         curves = torch.empty(params.batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)
-
-         if (~indices).sum().item():
-             curves[~indices] = reflectivity_not_smeared
-
-         curves[indices] = reflectivity_smeared
-
-         q_resolutions[indices] = dq
-
+ import torch
+ from torch import Tensor
+
+ from reflectorch.data_generation.priors.parametric_subpriors import BasicParams
+
+
+ class Smearing(object):
+     """Class which applies resolution smearing to the reflectivity curves.
+     The intensity at a q point will be the average of the intensities of neighbouring q points, weighted by a gaussian profile.
+
+     Args:
+         sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (0.01, 0.1).
+         constant_dq (bool, optional): if ``True`` the smearing is constant (the resolution is given by the constant dq at each point in the curve)
+             otherwise the smearing is linear (the resolution is given by the constant dq/q at each point in the curve). Defaults to True.
+         gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
+         share_smeared (float, optional): the share of curves in the batch for which the resolution smearing is applied. Defaults to 0.2.
+     """
+     def __init__(self,
+                  sigma_range: tuple = (0.01, 0.1),
+                  constant_dq: bool = False,
+                  gauss_num: int = 31,
+                  share_smeared: float = 0.2,
+                  ):
+         self.sigma_min, self.sigma_max = sigma_range
+         self.sigma_delta = self.sigma_max - self.sigma_min
+         self.constant_dq = constant_dq
+         self.gauss_num = gauss_num
+         self.share_smeared = share_smeared
+
+     def __repr__(self):
+         return f'Smearing(({self.sigma_min}, {self.sigma_max})'
+
+     def generate_resolutions(self, batch_size: int, device=None, dtype=None):
+         num_smeared = int(batch_size * self.share_smeared)
+         if not num_smeared:
+             return None, None
+         dq = torch.rand(num_smeared, 1, device=device, dtype=dtype) * self.sigma_delta + self.sigma_min
+         indices = torch.zeros(batch_size, device=device, dtype=torch.bool)
+         indices[torch.randperm(batch_size, device=device)[:num_smeared]] = True
+         return dq, indices
+
+     def scale_resolutions(self, resolutions: Tensor) -> Tensor:
+         """Scales the q-resolution values to [-1,1] range using the internal sigma range"""
+         sigma_min = 0.0 if self.share_smeared != 1.0 else self.sigma_min
+         return 2 * (resolutions - sigma_min) / (self.sigma_max - sigma_min) - 1
+
+     def get_curves(self, q_values: Tensor, params: BasicParams, refl_kwargs:dict = None):
+         refl_kwargs = refl_kwargs or {}
+
+         dq, indices = self.generate_resolutions(params.batch_size, device=params.device, dtype=params.dtype)
+         q_resolutions = torch.zeros(q_values.shape[0], 1, dtype=q_values.dtype, device=q_values.device)
+
+         if dq is None:
+             return params.reflectivity(q_values, **refl_kwargs), q_resolutions
+
+         refl_kwargs_not_smeared = {}
+         refl_kwargs_smeared = {}
+         for key, value in refl_kwargs.items():
+             if isinstance(value, torch.Tensor) and value.shape[0] == params.batch_size:
+                 refl_kwargs_not_smeared[key] = value[~indices]
+                 refl_kwargs_smeared[key] = value[indices]
+             else:
+                 refl_kwargs_not_smeared[key] = value
+                 refl_kwargs_smeared[key] = value
+
+         # Compute unsmeared reflectivity
+         if (~indices).sum().item():
+             if q_values.dim() == 2 and q_values.shape[0] > 1:
+                 q = q_values[~indices]
+             else:
+                 q = q_values
+
+             reflectivity_not_smeared = params[~indices].reflectivity(q, **refl_kwargs_not_smeared)
+         else:
+             reflectivity_not_smeared = None
+
+         # Compute smeared reflectivity
+         if indices.sum().item():
+             if q_values.dim() == 2 and q_values.shape[0] > 1:
+                 q = q_values[indices]
+             else:
+                 q = q_values
+
+             reflectivity_smeared = params[indices].reflectivity(
+                 q, dq=dq, constant_dq=self.constant_dq, gauss_num=self.gauss_num, **refl_kwargs_smeared
+             )
+         else:
+             reflectivity_smeared = None
+
+         curves = torch.empty(params.batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)
+
+         if (~indices).sum().item():
+             curves[~indices] = reflectivity_not_smeared
+
+         curves[indices] = reflectivity_smeared
+
+         q_resolutions[indices] = dq
+
  return curves, q_resolutions
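The hunk above appears to correspond to reflectorch/data_generation/smearing.py. The docstring describes resolution smearing as a Gaussian-weighted average of the intensities at neighbouring q points. A minimal, self-contained sketch of that idea (illustrative only, not the package implementation; constant-dq case, with dq playing the role of the Gaussian sigma and gauss_num offsets sampling the profile):

import numpy as np

def smear_constant_dq(q, R, dq, gauss_num=31):
    # Each smeared point is a Gaussian-weighted average of R over neighbouring
    # q values, sampled at gauss_num offsets spanning +/- 3 sigma.
    offsets = np.linspace(-3.0, 3.0, gauss_num)
    weights = np.exp(-0.5 * offsets ** 2)
    weights /= weights.sum()
    smeared = np.zeros_like(R)
    for w, off in zip(weights, offsets):
        # shift the q grid by off*dq and read R there by linear interpolation
        smeared += w * np.interp(q + off * dq, q, R)
    return smeared

# usage on a synthetic curve: smearing washes out sharp oscillations
q = np.linspace(0.01, 0.3, 256)
R = np.exp(-q * 20) * (1 + 0.5 * np.cos(q * 200))
R_smeared = smear_constant_dq(q, R, dq=0.005)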