reflectorch-1.5.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/__init__.py +17 -0
- reflectorch/data_generation/__init__.py +128 -0
- reflectorch/data_generation/dataset.py +216 -0
- reflectorch/data_generation/likelihoods.py +80 -0
- reflectorch/data_generation/noise.py +471 -0
- reflectorch/data_generation/priors/__init__.py +60 -0
- reflectorch/data_generation/priors/base.py +55 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
- reflectorch/data_generation/priors/independent_priors.py +195 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +104 -0
- reflectorch/data_generation/priors/no_constraints.py +206 -0
- reflectorch/data_generation/priors/parametric_models.py +842 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
- reflectorch/data_generation/priors/params.py +252 -0
- reflectorch/data_generation/priors/sampler_strategies.py +370 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +371 -0
- reflectorch/data_generation/priors/utils.py +118 -0
- reflectorch/data_generation/process_data.py +41 -0
- reflectorch/data_generation/q_generator.py +280 -0
- reflectorch/data_generation/reflectivity/__init__.py +102 -0
- reflectorch/data_generation/reflectivity/abeles.py +97 -0
- reflectorch/data_generation/reflectivity/kinematical.py +71 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +138 -0
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/scale_curves.py +112 -0
- reflectorch/data_generation/smearing.py +99 -0
- reflectorch/data_generation/utils.py +223 -0
- reflectorch/extensions/__init__.py +0 -0
- reflectorch/extensions/jupyter/__init__.py +11 -0
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -0
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -0
- reflectorch/extensions/matplotlib/losses.py +32 -0
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/__init__.py +28 -0
- reflectorch/inference/inference_model.py +848 -0
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +55 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/plotting.py +524 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +19 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/query_matcher.py +82 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +364 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +32 -0
- reflectorch/ml/basic_trainer.py +292 -0
- reflectorch/ml/callbacks.py +81 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +56 -0
- reflectorch/ml/schedulers.py +356 -0
- reflectorch/ml/trainers.py +201 -0
- reflectorch/ml/utils.py +2 -0
- reflectorch/models/__init__.py +16 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +19 -0
- reflectorch/models/encoders/conv_encoder.py +219 -0
- reflectorch/models/encoders/conv_res_net.py +115 -0
- reflectorch/models/encoders/fno.py +134 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -0
- reflectorch/models/networks/mlp_networks.py +434 -0
- reflectorch/models/networks/residual_net.py +157 -0
- reflectorch/paths.py +29 -0
- reflectorch/runs/__init__.py +31 -0
- reflectorch/runs/config.py +25 -0
- reflectorch/runs/slurm_utils.py +93 -0
- reflectorch/runs/train.py +78 -0
- reflectorch/runs/utils.py +405 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +98 -0
- reflectorch-1.5.1.dist-info/METADATA +151 -0
- reflectorch-1.5.1.dist-info/RECORD +96 -0
- reflectorch-1.5.1.dist-info/WHEEL +5 -0
- reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
- reflectorch-1.5.1.dist-info/top_level.txt +1 -0
reflectorch/data_generation/utils.py
@@ -0,0 +1,223 @@

```python
from typing import List, Union
from math import sqrt, pi, log10

import torch
from torch import Tensor

__all__ = [
    "get_reversed_params",
    "get_density_profiles",
    "uniform_sampler",
    "logdist_sampler",
    "triangular_sampler",
    "get_param_labels",
    "get_d_rhos",
    "get_slds_from_d_rhos",
]


def uniform_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    if isinstance(low, Tensor):
        device, dtype = low.device, low.dtype
    return torch.rand(*shape, device=device, dtype=dtype) * (high - low) + low


def logdist_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    if isinstance(low, Tensor):
        device, dtype = low.device, low.dtype
        low, high = map(torch.log10, (low, high))
    else:
        low, high = map(log10, (low, high))
    return 10 ** (torch.rand(*shape, device=device, dtype=dtype) * (high - low) + low)


def triangular_sampler(low: Union[float, Tensor], high: Union[float, Tensor], *shape, device=None, dtype=None):
    if isinstance(low, Tensor):
        device, dtype = low.device, low.dtype

    x = torch.rand(*shape, device=device, dtype=dtype)

    return (high - low) * (1 - torch.sqrt(x)) + low
```
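For orientation, here is a minimal usage sketch of the three samplers defined above; the shapes and bound values are illustrative, not taken from the package:

```python
import torch

# Uniform draws between scalar bounds; the trailing arguments give the output shape.
thicknesses = uniform_sampler(0.0, 300.0, 16, 3)        # shape (16, 3)

# Log-uniform draws, useful for parameters spanning several decades.
scales = logdist_sampler(1e-3, 1.0, 16, 1)              # shape (16, 1)

# Tensor bounds are also accepted; device and dtype are then taken from `low`.
low = torch.zeros(16, 3)
high = torch.full((16, 3), 60.0)
roughnesses = triangular_sampler(low, high, 16, 3)      # distribution peaked towards `low`
```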
```python
def get_reversed_params(thicknesses: Tensor, roughnesses: Tensor, slds: Tensor):
    reversed_slds = torch.cumsum(
        torch.flip(
            torch.diff(
                torch.cat([torch.zeros(slds.shape[0], 1).to(slds), slds], dim=-1),
                dim=-1
            ), (-1,)
        ),
        dim=-1
    )
    reversed_thicknesses = torch.flip(thicknesses, [-1])
    reversed_roughnesses = torch.flip(roughnesses, [-1])
    reversed_params = torch.cat([reversed_thicknesses, reversed_roughnesses, reversed_slds], -1)

    return reversed_params


def get_density_profiles_sld(
        thicknesses: Tensor,
        roughnesses: Tensor,
        slds: Tensor,
        z_axis: Tensor = None,
        num: int = 1000
):
    """Generates SLD profiles (and their derivative) based on batches of thicknesses, roughnesses and layer SLDs.

    The axis has its zero at the top (ambient medium) interface and is positive inside the film.

    Args:
        thicknesses (Tensor): the layer thicknesses (top to bottom)
        roughnesses (Tensor): the interlayer roughnesses (top to bottom)
        slds (Tensor): the layer SLDs (top to bottom)
        z_axis (Tensor, optional): a custom depth (z) axis. Defaults to None.
        num (int, optional): number of discretization points for the profile. Defaults to 1000.

    Returns:
        tuple: the z axis, the computed density profile rho(z) and the derivative of the density profile drho/dz(z)
    """
    assert torch.all(roughnesses >= 0), 'Negative roughness happened'
    assert torch.all(thicknesses >= 0), 'Negative thickness happened'

    sample_num = thicknesses.shape[0]

    d_rhos = get_d_rhos(slds)

    zs = torch.cumsum(torch.cat([torch.zeros(sample_num, 1).to(thicknesses), thicknesses], dim=-1), dim=-1)

    if z_axis is None:
        z_axis = torch.linspace(- zs.max() * 0.1, zs.max() * 1.1, num, device=thicknesses.device)[None]
    elif len(z_axis.shape) == 1:
        z_axis = z_axis[None]

    sigmas = roughnesses * sqrt(2)

    profile = get_erf(z_axis[:, None], zs[..., None], sigmas[..., None], d_rhos[..., None]).sum(1)

    d_profile = get_gauss(z_axis[:, None], zs[..., None], sigmas[..., None], d_rhos[..., None]).sum(1)

    z_axis = z_axis[0]

    return z_axis, profile, d_profile


def get_d_rhos(slds: Tensor) -> Tensor:
    d_rhos = torch.cat([slds[:, 0][:, None], torch.diff(slds, dim=-1)], -1)
    return d_rhos


def get_slds_from_d_rhos(d_rhos: Tensor) -> Tensor:
    slds = torch.cumsum(d_rhos, dim=-1)
    return slds
```
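get_d_rhos converts a batch of layer SLDs into the SLD steps across successive interfaces, and get_slds_from_d_rhos inverts that via a cumulative sum. A small round-trip sketch with made-up values:

```python
import torch

slds = torch.tensor([[2.0, 4.5, 2.1]])     # (batch, layers + substrate)
d_rhos = get_d_rhos(slds)                  # tensor([[2.0, 2.5, -2.4]])
restored = get_slds_from_d_rhos(d_rhos)    # cumulative sum recovers the SLDs
assert torch.allclose(restored, slds)
```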
```python
def get_erf(z, z0, sigma, amp):
    return (torch.erf((z - z0) / sigma) + 1) * amp / 2


def get_gauss(z, z0, sigma, amp):
    return amp / (sigma * sqrt(2 * pi)) * torch.exp(- (z - z0) ** 2 / 2 / sigma ** 2)

def get_density_profiles(
        thicknesses: torch.Tensor,
        roughnesses: torch.Tensor,
        slds: torch.Tensor,
        ambient_sld: torch.Tensor = None,
        z_axis: torch.Tensor = None,
        num: int = 1000,
        padding_left: float = 0.2,
        padding_right: float = 1.1,
):
    """
    Args:
        thicknesses (Tensor): finite layer thicknesses.
        roughnesses (Tensor): interface roughnesses for all transitions (ambient→layer1 ... layerN→substrate).
        slds (Tensor): SLDs for the finite layers + substrate.
        ambient_sld (Tensor, optional): SLD for the top ambient. Defaults to 0.0 if None.
        z_axis (Tensor, optional): a custom depth axis. If None, a linear axis is generated.
        num (int): number of points in the generated z-axis (if z_axis is None).
        padding_left (float): factor to extend the negative (above the surface) portion of z-axis.
        padding_right (float): factor to extend the positive (into the sample) portion of z-axis.

    Returns:
        (z_axis, profile, d_profile)
        z_axis: 1D Tensor of shape (num, ) with the depth coordinates.
        profile: 2D Tensor of shape (batch_size, num) giving the SLD at each depth.
        d_profile: 2D Tensor of shape (batch_size, num) giving d(SLD)/dz at each depth.
    """

    bs, n = thicknesses.shape
    assert roughnesses.shape == (bs, n + 1), (
        f"Roughnesses must be (batch_size, num_layers+1). Found {roughnesses.shape} instead."
    )
    assert slds.shape == (bs, n + 1), (
        f"SLDs must be (batch_size, num_layers+1). Found {slds.shape} instead."
    )
    assert torch.all(thicknesses >= 0), "Negative thickness encountered."
    assert torch.all(roughnesses >= 0), "Negative roughness encountered."

    if ambient_sld is None:
        ambient_sld = torch.zeros((bs, 1), device=thicknesses.device)
    else:
        if ambient_sld.ndim == 1:
            ambient_sld = ambient_sld.unsqueeze(-1)
        ambient_sld = ambient_sld.expand(bs, 1)

    slds_all = torch.cat([ambient_sld, slds], dim=-1)  # new dimension: n+2
    d_rhos = torch.diff(slds_all, dim=-1)  # (bs, n+1)

    interfaces = torch.cat([
        torch.zeros((bs, 1), device=thicknesses.device),  # z=0 for ambient→layer1
        thicknesses
    ], dim=-1).cumsum(dim=-1)  # now shape => (bs, n+1)

    total_thickness = interfaces[..., -1].max()
    if z_axis is None:
        z_axis = torch.linspace(
            -padding_left * total_thickness,
            padding_right * total_thickness,
            num,
            device=thicknesses.device
        )  # shape => (num,)
    if z_axis.ndim == 1:
        z_axis = z_axis.unsqueeze(0)  # shape => (1, num)

    z_b = z_axis.repeat(bs, 1).unsqueeze(1)  # (bs, 1, num)
    interfaces_b = interfaces.unsqueeze(-1)  # (bs, n+1, 1)
    sigmas_b = (roughnesses * sqrt(2)).unsqueeze(-1)  # (bs, n+1, 1)
    d_rhos_b = d_rhos.unsqueeze(-1)  # (bs, n+1, 1)

    profile = get_erf(z_b, interfaces_b, sigmas_b, d_rhos_b).sum(dim=1)  # (bs, num)
    if ambient_sld is not None:
        profile = profile + ambient_sld

    d_profile = get_gauss(z_b, interfaces_b, sigmas_b, d_rhos_b).sum(dim=1)  # (bs, num)

    return z_axis.squeeze(0), profile, d_profile
```
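A short sketch of calling get_density_profiles with the shape conventions from its docstring; the layer values below are purely illustrative:

```python
import torch

# One sample with two finite layers on a substrate.
thicknesses = torch.tensor([[80.0, 120.0]])      # (batch_size, num_layers)
roughnesses = torch.tensor([[3.0, 5.0, 2.0]])    # (batch_size, num_layers + 1)
slds = torch.tensor([[4.5, 2.1, 2.07]])          # layer SLDs + substrate SLD

z, rho, drho = get_density_profiles(thicknesses, roughnesses, slds, num=500)
print(z.shape, rho.shape, drho.shape)            # (500,), (1, 500), (1, 500)
```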
```python
def get_param_labels(
        num_layers: int, *,
        thickness_name: str = 'Thickness',
        roughness_name: str = 'Roughness',
        sld_name: str = 'SLD',
        imag_sld_name: str = 'SLD imag',
        substrate_name: str = 'sub',
        parameterization_type: str = 'standard',
        number_top_to_bottom: bool = True,
) -> List[str]:
    def pos(i):
        return i + 1 if number_top_to_bottom else num_layers - i

    thickness_labels = [f'{thickness_name} L{pos(i)}' for i in range(num_layers)]
    roughness_labels = [f'{roughness_name} L{pos(i)}' for i in range(num_layers)] + [f'{roughness_name} {substrate_name}']
    sld_labels = [f'{sld_name} L{pos(i)}' for i in range(num_layers)] + [f'{sld_name} {substrate_name}']

    all_labels = thickness_labels + roughness_labels + sld_labels

    if parameterization_type == 'absorption':
        imag_sld_labels = [f'{imag_sld_name} L{pos(i)}' for i in range(num_layers)] + [f'{imag_sld_name} {substrate_name}']
        all_labels = all_labels + imag_sld_labels

    return all_labels
```
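With the default keyword arguments, the labels come out grouped as thicknesses, roughnesses and SLDs, numbered from the top layer down; for example:

```python
labels = get_param_labels(2)
# ['Thickness L1', 'Thickness L2',
#  'Roughness L1', 'Roughness L2', 'Roughness sub',
#  'SLD L1', 'SLD L2', 'SLD sub']
```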
reflectorch/extensions/__init__.py
File without changes
reflectorch/extensions/jupyter/__init__.py
@@ -0,0 +1,11 @@

```python
"""
Reflectorch Jupyter Extensions
"""
from reflectorch.extensions.jupyter.api import create_widget, ReflectorchPlotlyWidget
from reflectorch.extensions.jupyter.callbacks import JPlotLoss

__all__ = [
    'create_widget',
    'JPlotLoss',
    'ReflectorchPlotlyWidget',
]
```
reflectorch/extensions/jupyter/api.py
@@ -0,0 +1,85 @@

````python
"""
This module provides API for creating and using
Reflectorch widgets and plots in Jupyter notebooks.
"""

import numpy as np
from typing import Optional, Union, TYPE_CHECKING

if TYPE_CHECKING:
    from reflectorch.inference.inference_model import InferenceModel

from reflectorch.extensions.jupyter.widget import ReflectorchPlotlyWidget


def create_widget(
    reflectivity_curve: np.ndarray,
    q_values: np.ndarray,
    model: Optional["InferenceModel"] = None,
    sigmas: Optional[np.ndarray] = None,
    q_resolution: Optional[Union[float, np.ndarray]] = None,
    initial_prior_bounds: Optional[np.ndarray] = None,
    ambient_sld: Optional[float] = None,
    controls_width: int = 700,
    plot_width: int = 400,
    plot_height: int = 300,
) -> ReflectorchPlotlyWidget:
    """
    Create and display a Reflectorch analysis widget

    This is the main function for creating Reflectorch widgets.

    Parameters:
    ----------
    reflectivity_curve: Experimental reflectivity data
    q_values: Momentum transfer values
    model: InferenceModel instance for making predictions (optional)
    sigmas: Experimental uncertainties (optional)
    q_resolution: Q-resolution, float or array (optional)
    initial_prior_bounds: Initial bounds for priors, shape (n_params, 2) (optional)
    ambient_sld: Ambient SLD value (optional)
    controls_width: Width of the controls area in pixels. Default is 700px.
    plot_width: Width of the plots in pixels. Default is 400px.
    plot_height: Height of the plots in pixels. Default is 300px.

    Returns:
    -------
    ReflectorchPlotlyWidget instance with the widget displayed

    Example:
    -------
    ```python
    # Load data
    from reflectorch.paths import ROOT_DIR
    data = np.loadtxt(ROOT_DIR / "exp_data/data_C60.txt")

    # create widget (displayed automatically)
    widget = create_widget(q_values=data[..., 0], reflectivity_curve=data[..., 1])
    ```
    """
    # Create widget instance
    widget = ReflectorchPlotlyWidget(
        reflectivity_curve=reflectivity_curve,
        q_values=q_values,
        sigmas=sigmas,
        q_resolution=q_resolution,
        initial_prior_bounds=initial_prior_bounds,
        ambient_sld=ambient_sld,
        model=model,
    )

    # Display the widget interface
    widget.display(
        controls_width=controls_width,
        plot_width=plot_width,
        plot_height=plot_height
    )

    return widget


# Export the main widget class for direct usage
__all__ = [
    'create_widget',
    'ReflectorchPlotlyWidget'
]
````
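Beyond the minimal call in the docstring, the optional arguments from the signature above can be supplied as well. A hedged sketch follows; the file name and its four-column layout (q, R, dR, dq) are assumptions made for illustration, not part of this package:

```python
import numpy as np

data = np.loadtxt("my_reflectivity.dat")   # hypothetical file with columns q, R, dR, dq

widget = create_widget(
    reflectivity_curve=data[:, 1],
    q_values=data[:, 0],
    sigmas=data[:, 2],                     # experimental uncertainties
    q_resolution=data[:, 3],               # pointwise q-resolution (a single float works too)
    ambient_sld=0.0,
    plot_width=500,
    plot_height=350,
)
```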
reflectorch/extensions/jupyter/callbacks.py
@@ -0,0 +1,34 @@

```python
from IPython.display import clear_output

from ...ml import TrainerCallback, Trainer

from ..matplotlib import plot_losses


class JPlotLoss(TrainerCallback):
    """Callback for plotting the loss in a Jupyter notebook
    """
    def __init__(self, frequency: int, log: bool = True, clear: bool = True, **kwargs):
        """

        Args:
            frequency (int): plotting frequency
            log (bool, optional): if True, the plot is on a logarithmic scale. Defaults to True.
            clear (bool, optional):
        """
        self.frequency = frequency
        self.log = log
        self.kwargs = kwargs
        self.clear = clear

    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        if not batch_num % self.frequency:
            if self.clear:
                clear_output(wait=True)

            plot_losses(
                trainer.losses,
                log=self.log,
                best_epoch=trainer.callback_params.get('saved_iteration', None),
                **self.kwargs
            )
```
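A small configuration sketch for the callback above; extra keyword arguments are forwarded to plot_losses(). How the callback is registered with a Trainer is not shown in this diff, so the commented line only illustrates the hook the trainer is expected to call:

```python
# Plot the running loss every 200 batches on a log scale,
# clearing the previous cell output before each redraw.
plot_cb = JPlotLoss(frequency=200, log=True, clear=True)

# The trainer invokes this hook at the end of each batch:
# plot_cb.end_batch(trainer, batch_num=200)
```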