reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic; see the package's registry page for more details.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,222 +1,223 @@
1
- from typing import List, Union
2
- from math import sqrt, pi, log10
3
-
4
- import torch
5
- from torch import Tensor
6
-
7
# Names exported by ``from <module> import *``.
__all__ = [
    "get_reversed_params",
    "get_density_profiles",
    "uniform_sampler",
    "logdist_sampler",
    "triangular_sampler",
    "get_param_labels",
    "get_d_rhos",
    "get_slds_from_d_rhos",
]
17
-
18
-
19
def uniform_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    """Draw samples uniformly from ``[low, high)``.

    Args:
        low: lower bound(s). If a tensor, its device and dtype replace ``device`` / ``dtype``.
        high: upper bound(s).
        *shape: shape of the underlying ``torch.rand`` draw (broadcast against the bounds).
        device: device for the draw when ``low`` is not a tensor.
        dtype: dtype for the draw when ``low`` is not a tensor.

    Returns:
        Tensor: ``rand * (high - low) + low``.
    """
    if isinstance(low, Tensor):
        device = low.device
        dtype = low.dtype
    span = high - low
    unit_draw = torch.rand(*shape, device=device, dtype=dtype)
    return unit_draw * span + low
23
-
24
-
25
def logdist_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    """Draw samples whose ``log10`` is uniform in ``[log10(low), log10(high))``.

    If either bound is a tensor, the logarithms are taken with torch and that tensor's
    device/dtype replace ``device``/``dtype`` (``low`` takes precedence when both are tensors).
    A bound that is a plain number is first converted to a tensor on that device/dtype, so
    mixing a tensor bound with a number bound works. When both bounds are numbers,
    ``math.log10`` is used.

    Args:
        low: lower bound(s); must be positive since its log10 is taken.
        high: upper bound(s); must be positive since its log10 is taken.
        *shape: shape of the underlying ``torch.rand`` draw (broadcast against the bounds).
        device: device for the draw when neither bound is a tensor.
        dtype: dtype for the draw when neither bound is a tensor.

    Returns:
        Tensor: ``10 ** (rand * (log10(high) - log10(low)) + log10(low))``.
    """
    if isinstance(low, Tensor):
        reference = low
    elif isinstance(high, Tensor):
        reference = high
    else:
        reference = None

    if reference is None:
        log_low, log_high = log10(low), log10(high)
    else:
        device, dtype = reference.device, reference.dtype

        def _log10(bound):
            # Only convert plain numbers; existing tensors keep their own dtype/device.
            if not isinstance(bound, Tensor):
                bound = torch.as_tensor(bound, device=device, dtype=dtype)
            return torch.log10(bound)

        log_low, log_high = _log10(low), _log10(high)

    return 10 ** (torch.rand(*shape, device=device, dtype=dtype) * (log_high - log_low) + log_low)
32
-
33
-
34
def triangular_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    """Draw samples from a linearly decreasing density on ``[low, high]``.

    With ``u ~ U[0, 1)`` the sample is ``low + (high - low) * (1 - sqrt(u))``, so values near
    ``low`` are the most likely.

    Args:
        low: lower bound(s). If a tensor, its device and dtype replace ``device`` / ``dtype``.
        high: upper bound(s).
        *shape: shape of the underlying ``torch.rand`` draw.
        device: device for the draw when ``low`` is not a tensor.
        dtype: dtype for the draw when ``low`` is not a tensor.
    """
    if isinstance(low, Tensor):
        device = low.device
        dtype = low.dtype
    u = torch.rand(*shape, device=device, dtype=dtype)
    width = high - low
    weight = 1 - torch.sqrt(u)
    return width * weight + low
41
-
42
-
43
def get_reversed_params(thicknesses: Tensor, roughnesses: Tensor, slds: Tensor):
    """Return the parameter vector with the layer order flipped along the last dimension.

    Thicknesses and roughnesses are simply flipped. The SLDs are first converted into
    steps relative to a zero-SLD start (``diff`` of ``[0, slds]``), the steps are flipped,
    and the cumulative sum of the flipped steps gives the returned SLDs.
    NOTE(review): presumably this describes the stack as seen from the opposite side — confirm.

    Args:
        thicknesses (Tensor): batched layer thicknesses, 2D (batch_size, ...).
        roughnesses (Tensor): batched roughnesses.
        slds (Tensor): batched SLDs, 2D (batch_size, num_slds).

    Returns:
        Tensor: ``cat([flipped thicknesses, flipped roughnesses, reversed slds], -1)``.
    """
    zero_col = torch.zeros(slds.shape[0], 1).to(slds)
    sld_steps = torch.diff(torch.cat([zero_col, slds], dim=-1), dim=-1)
    reversed_slds = torch.cumsum(torch.flip(sld_steps, (-1,)), dim=-1)
    flipped_thicknesses = torch.flip(thicknesses, [-1])
    flipped_roughnesses = torch.flip(roughnesses, [-1])
    return torch.cat([flipped_thicknesses, flipped_roughnesses, reversed_slds], -1)
58
-
59
-
60
def get_density_profiles_sld(
        thicknesses: Tensor,
        roughnesses: Tensor,
        slds: Tensor,
        z_axis: Tensor = None,
        num: int = 1000
):
    """Generates SLD profiles (and their derivative) based on batches of thicknesses, roughnesses and layer SLDs.

    The axis has its zero at the top (ambient medium) interface and is positive inside the film.
    The ambient medium is taken to have zero SLD: the first interface step is ``slds[:, 0]``.

    Args:
        thicknesses (Tensor): the layer thicknesses (top to bottom), shape (batch_size, num_layers)
        roughnesses (Tensor): the interlayer roughnesses (top to bottom)
        slds (Tensor): the layer SLDs (top to bottom)
        z_axis (Tensor, optional): a custom depth (z) axis (1D, or 2D broadcastable to the batch). Defaults to None.
        num (int, optional): number of discretization points for the profile. Defaults to 1000.

    Returns:
        tuple: the z axis, the computed density profile rho(z) and the derivative of the density profile drho/dz(z)
    """
    assert torch.all(roughnesses >= 0), 'Negative roughness happened'
    assert torch.all(thicknesses >= 0), 'Negative thickness happened'

    sample_num = thicknesses.shape[0]
    d_rhos = get_d_rhos(slds)
    # Interface positions: 0 for the top interface, then cumulative layer thicknesses.
    zs = torch.cumsum(torch.cat([torch.zeros(sample_num, 1).to(thicknesses), thicknesses], dim=-1), dim=-1)

    if z_axis is None:
        z_axis = torch.linspace(- zs.max() * 0.1, zs.max() * 1.1, num, device=thicknesses.device)[None]
    elif len(z_axis.shape) == 1:
        z_axis = z_axis[None]

    erf_sigmas, gauss_sigmas = _interface_widths(roughnesses)

    profile = get_erf(z_axis[:, None], zs[..., None], erf_sigmas[..., None], d_rhos[..., None]).sum(1)
    # The z-derivative of each erf step is a Gaussian whose std is the roughness itself.
    d_profile = get_gauss(z_axis[:, None], zs[..., None], gauss_sigmas[..., None], d_rhos[..., None]).sum(1)

    return z_axis[0], profile, d_profile


def get_d_rhos(slds: Tensor) -> Tensor:
    """Convert absolute SLDs into SLD steps along the last dimension.

    The first step is measured from zero (``d_rhos[..., 0] == slds[..., 0]``); the remaining
    entries are consecutive differences. ``get_slds_from_d_rhos`` is the inverse.
    """
    return torch.cat([slds[..., :1], torch.diff(slds, dim=-1)], -1)


def get_slds_from_d_rhos(d_rhos: Tensor) -> Tensor:
    """Inverse of ``get_d_rhos``: cumulative sum of SLD steps along the last dimension."""
    return torch.cumsum(d_rhos, dim=-1)


def get_erf(z, z0, sigma, amp):
    """Smoothed step of height ``amp`` centred at ``z0``: ``(erf((z - z0) / sigma) + 1) * amp / 2``."""
    return (torch.erf((z - z0) / sigma) + 1) * amp / 2


def get_gauss(z, z0, sigma, amp):
    """Gaussian with standard deviation ``sigma`` centred at ``z0``, normalized to integrate to ``amp``."""
    return amp / (sigma * sqrt(2 * pi)) * torch.exp(- (z - z0) ** 2 / 2 / sigma ** 2)


def _interface_widths(roughnesses: Tensor):
    """Return ``(erf_sigmas, gauss_sigmas)`` for interfaces of the given roughnesses.

    An interface of roughness ``s`` is modelled as ``get_erf(z, z0, sqrt(2) * s, amp)``; its
    z-derivative is ``get_gauss(z, z0, s, amp)``. Widths are clamped to the dtype's machine
    epsilon so that perfectly sharp interfaces (``s == 0``) give finite values instead of NaN
    (``0 / 0`` in the erf argument, ``inf * 0`` in the Gaussian).
    """
    erf_sigmas = roughnesses * sqrt(2)
    erf_sigmas = erf_sigmas.clamp_min(torch.finfo(erf_sigmas.dtype).eps)
    return erf_sigmas, erf_sigmas / sqrt(2)


def get_density_profiles(
        thicknesses: torch.Tensor,
        roughnesses: torch.Tensor,
        slds: torch.Tensor,
        ambient_sld: torch.Tensor = None,
        z_axis: torch.Tensor = None,
        num: int = 1000,
        padding_left: float = 0.2,
        padding_right: float = 1.1,
):
    """
    Args:
        thicknesses (Tensor): finite layer thicknesses, shape (batch_size, num_layers).
        roughnesses (Tensor): interface roughnesses for all transitions (ambient→layer1 ... layerN→substrate).
        slds (Tensor): SLDs for the finite layers + substrate.
        ambient_sld (Tensor or float, optional): SLD for the top ambient; a scalar, a (batch_size,)/(1,)
            tensor or a (batch_size, 1)/(1, 1) tensor. Defaults to 0.0 if None.
        z_axis (Tensor, optional): a custom depth axis, 1D (num,) or 2D (1 or batch_size, num).
            If None, a linear axis is generated.
        num (int): number of points in the generated z-axis (if z_axis is None).
        padding_left (float): factor to extend the negative (above the surface) portion of z-axis.
        padding_right (float): factor to extend the positive (into the sample) portion of z-axis.

    Returns:
        (z_axis, profile, d_profile)
        z_axis: the depth coordinates; shape (num,) for a generated or 1D axis.
        profile: 2D Tensor of shape (batch_size, num) giving the SLD at each depth.
        d_profile: 2D Tensor of shape (batch_size, num) giving d(SLD)/dz at each depth.
    """
    bs, n = thicknesses.shape
    assert roughnesses.shape == (bs, n + 1), (
        f"Roughnesses must be (batch_size, num_layers+1). Found {roughnesses.shape} instead."
    )
    assert slds.shape == (bs, n + 1), (
        f"SLDs must be (batch_size, num_layers+1). Found {slds.shape} instead."
    )
    assert torch.all(thicknesses >= 0), "Negative thickness encountered."
    assert torch.all(roughnesses >= 0), "Negative roughness encountered."

    if ambient_sld is None:
        ambient_sld = torch.zeros((bs, 1), dtype=slds.dtype, device=slds.device)
    else:
        if not isinstance(ambient_sld, Tensor):
            # Accept plain Python numbers as well as tensors.
            ambient_sld = torch.as_tensor(ambient_sld, dtype=slds.dtype, device=slds.device)
        if ambient_sld.ndim == 1:
            ambient_sld = ambient_sld.unsqueeze(-1)
        # Broadcast a shared ambient (scalar, (1,), (1, 1)) to one value per sample.
        ambient_sld = ambient_sld.expand(bs, 1)

    slds_all = torch.cat([ambient_sld, slds], dim=-1)  # (bs, n+2)
    d_rhos = torch.diff(slds_all, dim=-1)  # (bs, n+1): SLD step at each interface

    interfaces = torch.cat([
        torch.zeros((bs, 1), dtype=thicknesses.dtype, device=thicknesses.device),  # z=0 for ambient→layer1
        thicknesses
    ], dim=-1).cumsum(dim=-1)  # (bs, n+1)

    if z_axis is None:
        total_thickness = interfaces[..., -1].max()
        z_axis = torch.linspace(
            -padding_left * total_thickness,
            padding_right * total_thickness,
            num,
            device=thicknesses.device
        )  # (num,)
    if z_axis.ndim == 1:
        z_axis = z_axis.unsqueeze(0)  # (1, num)

    # (1 or bs, 1, num): broadcasting replaces repeat(bs, 1), which turned a (bs, num) axis into (bs*bs, num).
    z_b = z_axis.unsqueeze(1)
    interfaces_b = interfaces.unsqueeze(-1)  # (bs, n+1, 1)
    d_rhos_b = d_rhos.unsqueeze(-1)  # (bs, n+1, 1)
    erf_sigmas, gauss_sigmas = _interface_widths(roughnesses)

    profile = get_erf(z_b, interfaces_b, erf_sigmas.unsqueeze(-1), d_rhos_b).sum(dim=1) + ambient_sld
    # Exact derivative of the erf steps above: Gaussians with std equal to the roughness.
    d_profile = get_gauss(z_b, interfaces_b, gauss_sigmas.unsqueeze(-1), d_rhos_b).sum(dim=1)

    return z_axis.squeeze(0), profile, d_profile
198
-
199
def get_param_labels(
        num_layers: int, *,
        thickness_name: str = 'Thickness',
        roughness_name: str = 'Roughness',
        sld_name: str = 'SLD',
        imag_sld_name: str = 'SLD imag',
        substrate_name: str = 'sub',
        parameterization_type: str = 'standard',
        number_top_to_bottom: bool = False,
) -> List[str]:
    """Build human-readable labels for a layered-model parameter vector.

    Order: thicknesses, roughnesses (+ substrate), SLDs (+ substrate), and for
    ``parameterization_type == 'absorption'`` also imaginary SLDs (+ substrate).
    Layers are numbered ``L1..LN`` from the top when ``number_top_to_bottom`` is True,
    otherwise ``LN..L1`` (the first listed layer gets the highest number).
    """
    if number_top_to_bottom:
        layer_numbers = list(range(1, num_layers + 1))
    else:
        layer_numbers = list(range(num_layers, 0, -1))

    def _labels(name: str, with_substrate: bool) -> List[str]:
        names = [f'{name} L{k}' for k in layer_numbers]
        if with_substrate:
            names.append(f'{name} {substrate_name}')
        return names

    labels = (
        _labels(thickness_name, False)
        + _labels(roughness_name, True)
        + _labels(sld_name, True)
    )
    if parameterization_type == 'absorption':
        labels = labels + _labels(imag_sld_name, True)
    return labels
1
+ from typing import List, Union
2
+ from math import sqrt, pi, log10
3
+
4
+ import torch
5
+ from torch import Tensor
6
+
7
# Names exported by ``from <module> import *``.
__all__ = [
    "get_reversed_params",
    "get_density_profiles",
    "uniform_sampler",
    "logdist_sampler",
    "triangular_sampler",
    "get_param_labels",
    "get_d_rhos",
    "get_slds_from_d_rhos",
]
17
+
18
+
19
def uniform_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    """Draw samples uniformly from ``[low, high)``.

    Args:
        low: lower bound(s). If a tensor, its device and dtype replace ``device`` / ``dtype``.
        high: upper bound(s).
        *shape: shape of the underlying ``torch.rand`` draw (broadcast against the bounds).
        device: device for the draw when ``low`` is not a tensor.
        dtype: dtype for the draw when ``low`` is not a tensor.

    Returns:
        Tensor: ``rand * (high - low) + low``.
    """
    if isinstance(low, Tensor):
        device = low.device
        dtype = low.dtype
    span = high - low
    unit_draw = torch.rand(*shape, device=device, dtype=dtype)
    return unit_draw * span + low
23
+
24
+
25
def logdist_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    """Draw samples whose ``log10`` is uniform in ``[log10(low), log10(high))``.

    If either bound is a tensor, the logarithms are taken with torch and that tensor's
    device/dtype replace ``device``/``dtype`` (``low`` takes precedence when both are tensors).
    A bound that is a plain number is first converted to a tensor on that device/dtype, so
    mixing a tensor bound with a number bound works. When both bounds are numbers,
    ``math.log10`` is used.

    Args:
        low: lower bound(s); must be positive since its log10 is taken.
        high: upper bound(s); must be positive since its log10 is taken.
        *shape: shape of the underlying ``torch.rand`` draw (broadcast against the bounds).
        device: device for the draw when neither bound is a tensor.
        dtype: dtype for the draw when neither bound is a tensor.

    Returns:
        Tensor: ``10 ** (rand * (log10(high) - log10(low)) + log10(low))``.
    """
    if isinstance(low, Tensor):
        reference = low
    elif isinstance(high, Tensor):
        reference = high
    else:
        reference = None

    if reference is None:
        log_low, log_high = log10(low), log10(high)
    else:
        device, dtype = reference.device, reference.dtype

        def _log10(bound):
            # Only convert plain numbers; existing tensors keep their own dtype/device.
            if not isinstance(bound, Tensor):
                bound = torch.as_tensor(bound, device=device, dtype=dtype)
            return torch.log10(bound)

        log_low, log_high = _log10(low), _log10(high)

    return 10 ** (torch.rand(*shape, device=device, dtype=dtype) * (log_high - log_low) + log_low)
32
+
33
+
34
def triangular_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    """Draw samples from a linearly decreasing density on ``[low, high]``.

    With ``u ~ U[0, 1)`` the sample is ``low + (high - low) * (1 - sqrt(u))``, so values near
    ``low`` are the most likely.

    Args:
        low: lower bound(s). If a tensor, its device and dtype replace ``device`` / ``dtype``.
        high: upper bound(s).
        *shape: shape of the underlying ``torch.rand`` draw.
        device: device for the draw when ``low`` is not a tensor.
        dtype: dtype for the draw when ``low`` is not a tensor.
    """
    if isinstance(low, Tensor):
        device = low.device
        dtype = low.dtype
    u = torch.rand(*shape, device=device, dtype=dtype)
    width = high - low
    weight = 1 - torch.sqrt(u)
    return width * weight + low
41
+
42
+
43
def get_reversed_params(thicknesses: Tensor, roughnesses: Tensor, slds: Tensor):
    """Return the parameter vector with the layer order flipped along the last dimension.

    Thicknesses and roughnesses are simply flipped. The SLDs are first converted into
    steps relative to a zero-SLD start (``diff`` of ``[0, slds]``), the steps are flipped,
    and the cumulative sum of the flipped steps gives the returned SLDs.
    NOTE(review): presumably this describes the stack as seen from the opposite side — confirm.

    Args:
        thicknesses (Tensor): batched layer thicknesses, 2D (batch_size, ...).
        roughnesses (Tensor): batched roughnesses.
        slds (Tensor): batched SLDs, 2D (batch_size, num_slds).

    Returns:
        Tensor: ``cat([flipped thicknesses, flipped roughnesses, reversed slds], -1)``.
    """
    zero_col = torch.zeros(slds.shape[0], 1).to(slds)
    sld_steps = torch.diff(torch.cat([zero_col, slds], dim=-1), dim=-1)
    reversed_slds = torch.cumsum(torch.flip(sld_steps, (-1,)), dim=-1)
    flipped_thicknesses = torch.flip(thicknesses, [-1])
    flipped_roughnesses = torch.flip(roughnesses, [-1])
    return torch.cat([flipped_thicknesses, flipped_roughnesses, reversed_slds], -1)
58
+
59
+
60
def get_density_profiles_sld(
        thicknesses: Tensor,
        roughnesses: Tensor,
        slds: Tensor,
        z_axis: Tensor = None,
        num: int = 1000
):
    """Generates SLD profiles (and their derivative) based on batches of thicknesses, roughnesses and layer SLDs.

    The axis has its zero at the top (ambient medium) interface and is positive inside the film.
    The ambient medium is taken to have zero SLD: the first interface step is ``slds[:, 0]``.

    Args:
        thicknesses (Tensor): the layer thicknesses (top to bottom), shape (batch_size, num_layers)
        roughnesses (Tensor): the interlayer roughnesses (top to bottom)
        slds (Tensor): the layer SLDs (top to bottom)
        z_axis (Tensor, optional): a custom depth (z) axis (1D, or 2D broadcastable to the batch). Defaults to None.
        num (int, optional): number of discretization points for the profile. Defaults to 1000.

    Returns:
        tuple: the z axis, the computed density profile rho(z) and the derivative of the density profile drho/dz(z)
    """
    assert torch.all(roughnesses >= 0), 'Negative roughness happened'
    assert torch.all(thicknesses >= 0), 'Negative thickness happened'

    sample_num = thicknesses.shape[0]
    d_rhos = get_d_rhos(slds)
    # Interface positions: 0 for the top interface, then cumulative layer thicknesses.
    zs = torch.cumsum(torch.cat([torch.zeros(sample_num, 1).to(thicknesses), thicknesses], dim=-1), dim=-1)

    if z_axis is None:
        z_axis = torch.linspace(- zs.max() * 0.1, zs.max() * 1.1, num, device=thicknesses.device)[None]
    elif len(z_axis.shape) == 1:
        z_axis = z_axis[None]

    erf_sigmas, gauss_sigmas = _interface_widths(roughnesses)

    profile = get_erf(z_axis[:, None], zs[..., None], erf_sigmas[..., None], d_rhos[..., None]).sum(1)
    # The z-derivative of each erf step is a Gaussian whose std is the roughness itself.
    d_profile = get_gauss(z_axis[:, None], zs[..., None], gauss_sigmas[..., None], d_rhos[..., None]).sum(1)

    return z_axis[0], profile, d_profile


def get_d_rhos(slds: Tensor) -> Tensor:
    """Convert absolute SLDs into SLD steps along the last dimension.

    The first step is measured from zero (``d_rhos[..., 0] == slds[..., 0]``); the remaining
    entries are consecutive differences. ``get_slds_from_d_rhos`` is the inverse.
    """
    return torch.cat([slds[..., :1], torch.diff(slds, dim=-1)], -1)


def get_slds_from_d_rhos(d_rhos: Tensor) -> Tensor:
    """Inverse of ``get_d_rhos``: cumulative sum of SLD steps along the last dimension."""
    return torch.cumsum(d_rhos, dim=-1)


def get_erf(z, z0, sigma, amp):
    """Smoothed step of height ``amp`` centred at ``z0``: ``(erf((z - z0) / sigma) + 1) * amp / 2``."""
    return (torch.erf((z - z0) / sigma) + 1) * amp / 2


def get_gauss(z, z0, sigma, amp):
    """Gaussian with standard deviation ``sigma`` centred at ``z0``, normalized to integrate to ``amp``."""
    return amp / (sigma * sqrt(2 * pi)) * torch.exp(- (z - z0) ** 2 / 2 / sigma ** 2)


def _interface_widths(roughnesses: Tensor):
    """Return ``(erf_sigmas, gauss_sigmas)`` for interfaces of the given roughnesses.

    An interface of roughness ``s`` is modelled as ``get_erf(z, z0, sqrt(2) * s, amp)``; its
    z-derivative is ``get_gauss(z, z0, s, amp)``. Widths are clamped to the dtype's machine
    epsilon so that perfectly sharp interfaces (``s == 0``) give finite values instead of NaN
    (``0 / 0`` in the erf argument, ``inf * 0`` in the Gaussian).
    """
    erf_sigmas = roughnesses * sqrt(2)
    erf_sigmas = erf_sigmas.clamp_min(torch.finfo(erf_sigmas.dtype).eps)
    return erf_sigmas, erf_sigmas / sqrt(2)


def get_density_profiles(
        thicknesses: torch.Tensor,
        roughnesses: torch.Tensor,
        slds: torch.Tensor,
        ambient_sld: torch.Tensor = None,
        z_axis: torch.Tensor = None,
        num: int = 1000,
        padding_left: float = 0.2,
        padding_right: float = 1.1,
):
    """
    Args:
        thicknesses (Tensor): finite layer thicknesses, shape (batch_size, num_layers).
        roughnesses (Tensor): interface roughnesses for all transitions (ambient→layer1 ... layerN→substrate).
        slds (Tensor): SLDs for the finite layers + substrate.
        ambient_sld (Tensor or float, optional): SLD for the top ambient; a scalar, a (batch_size,)/(1,)
            tensor or a (batch_size, 1)/(1, 1) tensor. Defaults to 0.0 if None.
        z_axis (Tensor, optional): a custom depth axis, 1D (num,) or 2D (1 or batch_size, num).
            If None, a linear axis is generated.
        num (int): number of points in the generated z-axis (if z_axis is None).
        padding_left (float): factor to extend the negative (above the surface) portion of z-axis.
        padding_right (float): factor to extend the positive (into the sample) portion of z-axis.

    Returns:
        (z_axis, profile, d_profile)
        z_axis: the depth coordinates; shape (num,) for a generated or 1D axis.
        profile: 2D Tensor of shape (batch_size, num) giving the SLD at each depth.
        d_profile: 2D Tensor of shape (batch_size, num) giving d(SLD)/dz at each depth.
    """
    bs, n = thicknesses.shape
    assert roughnesses.shape == (bs, n + 1), (
        f"Roughnesses must be (batch_size, num_layers+1). Found {roughnesses.shape} instead."
    )
    assert slds.shape == (bs, n + 1), (
        f"SLDs must be (batch_size, num_layers+1). Found {slds.shape} instead."
    )
    assert torch.all(thicknesses >= 0), "Negative thickness encountered."
    assert torch.all(roughnesses >= 0), "Negative roughness encountered."

    if ambient_sld is None:
        ambient_sld = torch.zeros((bs, 1), dtype=slds.dtype, device=slds.device)
    else:
        if not isinstance(ambient_sld, Tensor):
            # Accept plain Python numbers as well as tensors.
            ambient_sld = torch.as_tensor(ambient_sld, dtype=slds.dtype, device=slds.device)
        if ambient_sld.ndim == 1:
            ambient_sld = ambient_sld.unsqueeze(-1)
        # Broadcast a shared ambient (scalar, (1,), (1, 1)) to one value per sample.
        ambient_sld = ambient_sld.expand(bs, 1)

    slds_all = torch.cat([ambient_sld, slds], dim=-1)  # (bs, n+2)
    d_rhos = torch.diff(slds_all, dim=-1)  # (bs, n+1): SLD step at each interface

    interfaces = torch.cat([
        torch.zeros((bs, 1), dtype=thicknesses.dtype, device=thicknesses.device),  # z=0 for ambient→layer1
        thicknesses
    ], dim=-1).cumsum(dim=-1)  # (bs, n+1)

    if z_axis is None:
        total_thickness = interfaces[..., -1].max()
        z_axis = torch.linspace(
            -padding_left * total_thickness,
            padding_right * total_thickness,
            num,
            device=thicknesses.device
        )  # (num,)
    if z_axis.ndim == 1:
        z_axis = z_axis.unsqueeze(0)  # (1, num)

    # (1 or bs, 1, num): broadcasting replaces repeat(bs, 1), which turned a (bs, num) axis into (bs*bs, num).
    z_b = z_axis.unsqueeze(1)
    interfaces_b = interfaces.unsqueeze(-1)  # (bs, n+1, 1)
    d_rhos_b = d_rhos.unsqueeze(-1)  # (bs, n+1, 1)
    erf_sigmas, gauss_sigmas = _interface_widths(roughnesses)

    profile = get_erf(z_b, interfaces_b, erf_sigmas.unsqueeze(-1), d_rhos_b).sum(dim=1) + ambient_sld
    # Exact derivative of the erf steps above: Gaussians with std equal to the roughness.
    d_profile = get_gauss(z_b, interfaces_b, gauss_sigmas.unsqueeze(-1), d_rhos_b).sum(dim=1)

    return z_axis.squeeze(0), profile, d_profile
199
+
200
def get_param_labels(
        num_layers: int, *,
        thickness_name: str = 'Thickness',
        roughness_name: str = 'Roughness',
        sld_name: str = 'SLD',
        imag_sld_name: str = 'SLD imag',
        substrate_name: str = 'sub',
        parameterization_type: str = 'standard',
        number_top_to_bottom: bool = False,
) -> List[str]:
    """Build human-readable labels for a layered-model parameter vector.

    Order: thicknesses, roughnesses (+ substrate), SLDs (+ substrate), and for
    ``parameterization_type == 'absorption'`` also imaginary SLDs (+ substrate).
    Layers are numbered ``L1..LN`` from the top when ``number_top_to_bottom`` is True,
    otherwise ``LN..L1`` (the first listed layer gets the highest number).
    """
    if number_top_to_bottom:
        layer_numbers = list(range(1, num_layers + 1))
    else:
        layer_numbers = list(range(num_layers, 0, -1))

    def _labels(name: str, with_substrate: bool) -> List[str]:
        names = [f'{name} L{k}' for k in layer_numbers]
        if with_substrate:
            names.append(f'{name} {substrate_name}')
        return names

    labels = (
        _labels(thickness_name, False)
        + _labels(roughness_name, True)
        + _labels(sld_name, True)
    )
    if parameterization_type == 'absorption':
        labels = labels + _labels(imag_sld_name, True)
    return labels
@@ -1,6 +1,11 @@
1
- from .callbacks import JPlotLoss
2
-
3
-
4
# Public API of the Jupyter extension package.
__all__ = ['JPlotLoss']
1
+ """
2
+ Reflectorch Jupyter Extensions
3
+ """
4
+ from reflectorch.extensions.jupyter.api import create_widget, ReflectorchPlotlyWidget
5
+ from reflectorch.extensions.jupyter.callbacks import JPlotLoss
6
+
7
# Public API of the Jupyter extension package.
__all__ = [
    'create_widget',
    'JPlotLoss',
    'ReflectorchPlotlyWidget',
]
@@ -0,0 +1,85 @@
1
+ """
2
+ This module provides API for creating and using
3
+ Reflectorch widgets and plots in Jupyter notebooks.
4
+ """
5
+
6
+ import numpy as np
7
+ from typing import Optional, Union, TYPE_CHECKING
8
+
9
+ if TYPE_CHECKING:
10
+ from reflectorch.inference.inference_model import InferenceModel
11
+
12
+ from reflectorch.extensions.jupyter.widget import ReflectorchPlotlyWidget
13
+
14
+
15
def create_widget(
    reflectivity_curve: np.ndarray,
    q_values: np.ndarray,
    model: Optional["InferenceModel"] = None,
    sigmas: Optional[np.ndarray] = None,
    q_resolution: Optional[Union[float, np.ndarray]] = None,
    initial_prior_bounds: Optional[np.ndarray] = None,
    ambient_sld: Optional[float] = None,
    controls_width: int = 700,
    plot_width: int = 400,
    plot_height: int = 300,
) -> ReflectorchPlotlyWidget:
    """Create a Reflectorch analysis widget and display it in the notebook.

    This is the main entry point for building Reflectorch widgets: it constructs a
    ``ReflectorchPlotlyWidget`` from the data and then calls its ``display`` method
    with the requested layout sizes.

    Parameters
    ----------
    reflectivity_curve : np.ndarray
        Experimental reflectivity data.
    q_values : np.ndarray
        Momentum transfer values.
    model : InferenceModel, optional
        Model used for making predictions.
    sigmas : np.ndarray, optional
        Experimental uncertainties.
    q_resolution : float or np.ndarray, optional
        Q-resolution.
    initial_prior_bounds : np.ndarray, optional
        Initial bounds for the priors, shape (n_params, 2).
    ambient_sld : float, optional
        Ambient SLD value.
    controls_width : int
        Width of the controls area in pixels. Default is 700.
    plot_width : int
        Width of the plots in pixels. Default is 400.
    plot_height : int
        Height of the plots in pixels. Default is 300.

    Returns
    -------
    ReflectorchPlotlyWidget
        The widget instance, already displayed.

    Examples
    --------
    ```python
    import numpy as np
    from reflectorch.paths import ROOT_DIR
    from reflectorch.extensions.jupyter import create_widget

    data = np.loadtxt(ROOT_DIR / "exp_data/data_C60.txt")
    # the widget is displayed automatically
    widget = create_widget(q_values=data[..., 0], reflectivity_curve=data[..., 1])
    ```
    """
    data_kwargs = dict(
        reflectivity_curve=reflectivity_curve,
        q_values=q_values,
        sigmas=sigmas,
        q_resolution=q_resolution,
        initial_prior_bounds=initial_prior_bounds,
        ambient_sld=ambient_sld,
        model=model,
    )
    layout_kwargs = dict(
        controls_width=controls_width,
        plot_width=plot_width,
        plot_height=plot_height,
    )

    widget = ReflectorchPlotlyWidget(**data_kwargs)
    widget.display(**layout_kwargs)
    return widget
79
+
80
+
81
+ # Export the main widget class for direct usage
82
# Main widget API exported by this module.
__all__ = [
    'create_widget',
    'ReflectorchPlotlyWidget',
]
@@ -1,34 +1,34 @@
1
- from IPython.display import clear_output
2
-
3
- from ...ml import TrainerCallback, Trainer
4
-
5
- from ..matplotlib import plot_losses
6
-
7
-
8
class JPlotLoss(TrainerCallback):
    """Trainer callback that periodically plots the training loss in a Jupyter notebook."""

    def __init__(self, frequency: int, log: bool = True, clear: bool = True, **kwargs):
        """
        Args:
            frequency (int): plotting frequency; a plot is drawn when ``batch_num % frequency == 0``.
            log (bool, optional): if True, the plot is on a logarithmic scale. Defaults to True.
            clear (bool, optional): if True, the cell output is cleared (``clear_output(wait=True)``)
                before each new plot. Defaults to True.
            **kwargs: extra keyword arguments forwarded to ``plot_losses``.
        """
        self.frequency = frequency
        self.log = log
        self.clear = clear
        self.kwargs = kwargs

    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        """Redraw the loss plot on every ``frequency``-th batch."""
        if batch_num % self.frequency:
            return

        if self.clear:
            clear_output(wait=True)

        plot_losses(
            trainer.losses,
            log=self.log,
            best_epoch=trainer.callback_params.get('saved_iteration', None),
            **self.kwargs,
        )
1
+ from IPython.display import clear_output
2
+
3
+ from ...ml import TrainerCallback, Trainer
4
+
5
+ from ..matplotlib import plot_losses
6
+
7
+
8
class JPlotLoss(TrainerCallback):
    """Trainer callback that periodically plots the training loss in a Jupyter notebook."""

    def __init__(self, frequency: int, log: bool = True, clear: bool = True, **kwargs):
        """
        Args:
            frequency (int): plotting frequency; a plot is drawn when ``batch_num % frequency == 0``.
            log (bool, optional): if True, the plot is on a logarithmic scale. Defaults to True.
            clear (bool, optional): if True, the cell output is cleared (``clear_output(wait=True)``)
                before each new plot. Defaults to True.
            **kwargs: extra keyword arguments forwarded to ``plot_losses``.
        """
        self.frequency = frequency
        self.log = log
        self.clear = clear
        self.kwargs = kwargs

    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        """Redraw the loss plot on every ``frequency``-th batch."""
        if batch_num % self.frequency:
            return

        if self.clear:
            clear_output(wait=True)

        plot_losses(
            trainer.losses,
            log=self.log,
            best_epoch=trainer.callback_params.get('saved_iteration', None),
            **self.kwargs,
        )