reflectorch 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/__init__.py +17 -0
- reflectorch/data_generation/__init__.py +128 -0
- reflectorch/data_generation/dataset.py +216 -0
- reflectorch/data_generation/likelihoods.py +80 -0
- reflectorch/data_generation/noise.py +471 -0
- reflectorch/data_generation/priors/__init__.py +60 -0
- reflectorch/data_generation/priors/base.py +55 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
- reflectorch/data_generation/priors/independent_priors.py +195 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +104 -0
- reflectorch/data_generation/priors/no_constraints.py +206 -0
- reflectorch/data_generation/priors/parametric_models.py +842 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
- reflectorch/data_generation/priors/params.py +252 -0
- reflectorch/data_generation/priors/sampler_strategies.py +370 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +371 -0
- reflectorch/data_generation/priors/utils.py +118 -0
- reflectorch/data_generation/process_data.py +41 -0
- reflectorch/data_generation/q_generator.py +280 -0
- reflectorch/data_generation/reflectivity/__init__.py +102 -0
- reflectorch/data_generation/reflectivity/abeles.py +97 -0
- reflectorch/data_generation/reflectivity/kinematical.py +71 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +138 -0
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/scale_curves.py +112 -0
- reflectorch/data_generation/smearing.py +99 -0
- reflectorch/data_generation/utils.py +223 -0
- reflectorch/extensions/__init__.py +0 -0
- reflectorch/extensions/jupyter/__init__.py +11 -0
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -0
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -0
- reflectorch/extensions/matplotlib/losses.py +32 -0
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/__init__.py +28 -0
- reflectorch/inference/inference_model.py +848 -0
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +55 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/plotting.py +524 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +19 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/query_matcher.py +82 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +364 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +32 -0
- reflectorch/ml/basic_trainer.py +292 -0
- reflectorch/ml/callbacks.py +81 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +56 -0
- reflectorch/ml/schedulers.py +356 -0
- reflectorch/ml/trainers.py +201 -0
- reflectorch/ml/utils.py +2 -0
- reflectorch/models/__init__.py +16 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +19 -0
- reflectorch/models/encoders/conv_encoder.py +219 -0
- reflectorch/models/encoders/conv_res_net.py +115 -0
- reflectorch/models/encoders/fno.py +134 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -0
- reflectorch/models/networks/mlp_networks.py +434 -0
- reflectorch/models/networks/residual_net.py +157 -0
- reflectorch/paths.py +29 -0
- reflectorch/runs/__init__.py +31 -0
- reflectorch/runs/config.py +25 -0
- reflectorch/runs/slurm_utils.py +93 -0
- reflectorch/runs/train.py +78 -0
- reflectorch/runs/utils.py +405 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +98 -0
- reflectorch-1.5.1.dist-info/METADATA +151 -0
- reflectorch-1.5.1.dist-info/RECORD +96 -0
- reflectorch-1.5.1.dist-info/WHEEL +5 -0
- reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
- reflectorch-1.5.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from functools import reduce
|
|
3
|
+
from operator import or_
|
|
4
|
+
|
|
5
|
+
from reflectorch.inference.inference_model import EasyInferenceModel
|
|
6
|
+
from reflectorch import BasicParams
|
|
7
|
+
|
|
8
|
+
import refnx
|
|
9
|
+
from refnx.dataset import ReflectDataset, Data1D
|
|
10
|
+
from refnx.analysis import Transform, CurveFitter, Objective, Model, Parameter
|
|
11
|
+
from refnx.reflect import SLD, Slab, ReflectModel
|
|
12
|
+
|
|
13
|
+
def _squeeze_to_list(values) -> list:
    """Return ``values.squeeze().tolist()``, always as a Python list.

    ``.squeeze().tolist()`` collapses a single-element tensor/array into a bare
    Python scalar, which can be neither indexed nor iterated. That happens e.g.
    for the thicknesses of a one-layer model predicted for a single curve.
    Wrapping the scalar in a one-element list lets callers index uniformly.
    """
    squeezed = values.squeeze().tolist()
    return squeezed if isinstance(squeezed, list) else [squeezed]


def covert_reflectorch_prediction_to_refnx_structure(
    inference_model: EasyInferenceModel,
    pred_params_object: BasicParams,
    prior_bounds: np.ndarray,
    *,
    verbose: bool = True,
):
    """Build a refnx structure initialised from a reflectorch prediction.

    The structure is ``ambient | layer_1 | ... | layer_n | substrate``.
    Every layer's thickness, roughness and real SLD, and the substrate's
    roughness and real SLD, are made variable (``vary=True``) within the
    corresponding entries of ``prior_bounds``.

    Args:
        inference_model: Model whose ``trainer.loader.prior_sampler`` must use
            a parameter model named ``'StandardModel'``. Its ``max_num_layers``
            sets the number of layers ``n``.
        pred_params_object: Predicted parameters. Its ``thicknesses``,
            ``roughnesses`` and ``slds`` are squeezed and converted to lists.
            NOTE(review): presumably a single prediction (batch size 1); a
            larger batch would not squeeze to 1-D.
        prior_bounds: Bounds indexed as ``[thicknesses (n entries),
            roughnesses (n + 1 entries), SLDs (remaining entries)]``. Each
            entry is passed as ``bounds=`` to the refnx parameter's ``setp``.
            Presumably these are (min, max) pairs.
        verbose: If True (the default, matching previous behaviour), print a
            summary of every layer's value, vary flag and bounds.

    Returns:
        The refnx structure obtained by combining the slabs with ``|``.

    Raises:
        AssertionError: If the prior sampler's parameter model is not
            ``StandardModel``.
    """
    # The index arithmetic below assumes the StandardModel parameter layout.
    param_model_name = inference_model.trainer.loader.prior_sampler.param_model.__class__.__name__
    assert param_model_name == 'StandardModel', (
        f"Only 'StandardModel' is supported, got {param_model_name!r}"
    )

    n_layers = inference_model.trainer.loader.prior_sampler.max_num_layers

    # Always produce lists: a plain squeeze().tolist() yields a bare float for
    # single-element tensors (e.g. one layer), which breaks the indexing below.
    init_thicknesses = _squeeze_to_list(pred_params_object.thicknesses)
    init_roughnesses = _squeeze_to_list(pred_params_object.roughnesses)
    init_slds = _squeeze_to_list(pred_params_object.slds)

    sld_objects = [SLD(value=sld) for sld in init_slds]

    # Ambient medium: SLD 0, slab created with SLD's default call arguments.
    layer_objects = [SLD(0)()]
    for i in range(n_layers):
        layer_objects.append(sld_objects[i](init_thicknesses[i], init_roughnesses[i]))
    # Substrate: thickness 0, last predicted SLD and last predicted roughness.
    layer_objects.append(sld_objects[-1](0, init_roughnesses[-1]))

    thickness_bounds = prior_bounds[:n_layers]
    roughness_bounds = prior_bounds[n_layers:2 * n_layers + 1]
    sld_bounds = prior_bounds[2 * n_layers + 1:]

    for i, layer in enumerate(layer_objects):
        if i == 0:
            # Ambient parameters are left untouched (not fitted).
            if verbose:
                print("Ambient (air)")
                print(80 * '-')
        elif i < n_layers + 1:
            layer.thick.setp(bounds=thickness_bounds[i - 1], vary=True)
            layer.rough.setp(bounds=roughness_bounds[i - 1], vary=True)
            layer.sld.real.setp(bounds=sld_bounds[i - 1], vary=True)

            if verbose:
                print(f'Layer {i}')
                print(f'Thickness: value {layer.thick.value}, vary {layer.thick.vary}, bounds {layer.thick.bounds}')
                print(f'Roughness: value {layer.rough.value}, vary {layer.rough.vary}, bounds {layer.rough.bounds}')
                print(f'SLD: value {layer.sld.real.value}, vary {layer.sld.real.vary}, bounds {layer.sld.real.bounds}')
                print(80 * '-')
        else:
            # Substrate: thickness stays fixed; roughness and SLD are fitted.
            layer.rough.setp(bounds=roughness_bounds[i - 1], vary=True)
            layer.sld.real.setp(bounds=sld_bounds[i - 1], vary=True)

            if verbose:
                print('Substrate')
                print(f'Thickness: value {layer.thick.value}, vary {layer.thick.vary}, bounds {layer.thick.bounds}')
                print(f'Roughness: value {layer.rough.value}, vary {layer.rough.vary}, bounds {layer.rough.bounds}')
                print(f'SLD: value {layer.sld.real.value}, vary {layer.sld.real.vary}, bounds {layer.sld.real.bounds}')

    refnx_structure = reduce(or_, layer_objects)

    return refnx_structure
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
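### Example usage (assumes `inference_model`, `pred_params_object`, `prior_bounds`, `q_model` and `exp_curve_interp` are already defined):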
|
|
65
|
+
# refnx_structure = covert_reflectorch_prediction_to_refnx_structure(inference_model, pred_params_object, prior_bounds)
|
|
66
|
+
|
|
67
|
+
# refnx_reflect_model = ReflectModel(refnx_structure, bkg=1e-10, dq=0.0)
|
|
68
|
+
# refnx_reflect_model.scale.setp(bounds=(0.8, 1.2), vary=True)
|
|
69
|
+
# refnx_reflect_model.q_offset.setp(bounds=(-0.01, 0.01), vary=True)
|
|
70
|
+
# refnx_reflect_model.bkg.setp(bounds=(1e-10, 1e-8), vary=True)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# data = Data1D(data=(q_model, exp_curve_interp))
|
|
74
|
+
|
|
75
|
+
# refnx_objective = Objective(refnx_reflect_model, data, transform=Transform("logY"))
|
|
76
|
+
# fitter = CurveFitter(refnx_objective)
|
|
77
|
+
# fitter.fit('least_squares')
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from reflectorch.inference.inference_model import InferenceModel, EasyInferenceModel
|
|
2
|
+
from reflectorch.inference.query_matcher import HuggingfaceQueryMatcher
|
|
3
|
+
from reflectorch.inference.multilayer_inference_model import MultilayerInferenceModel
|
|
4
|
+
from reflectorch.inference.preprocess_exp import (
|
|
5
|
+
StandardPreprocessing,
|
|
6
|
+
standard_preprocessing,
|
|
7
|
+
interp_reflectivity,
|
|
8
|
+
apply_attenuation_correction,
|
|
9
|
+
apply_footprint_correction,
|
|
10
|
+
)
|
|
11
|
+
from reflectorch.inference.torch_fitter import ReflGradientFit
|
|
12
|
+
from reflectorch.inference.input_interface import Layer, Backing, Structure
|
|
13
|
+
|
|
14
|
+
# Public API of ``reflectorch.inference``: every name re-exported by the
# imports at the top of this module, in the original export order.
__all__ = [
    "InferenceModel", "EasyInferenceModel", "MultilayerInferenceModel",
    "HuggingfaceQueryMatcher",
    "StandardPreprocessing", "standard_preprocessing",
    "ReflGradientFit",
    "Layer", "Backing", "Structure",
    "interp_reflectivity", "apply_attenuation_correction", "apply_footprint_correction",
]
|