reflectorch-1.3.0-py3-none-any.whl → reflectorch-1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,32 +1,32 @@
1
- import matplotlib.pyplot as plt
2
-
3
-
4
- def plot_losses(
5
- losses: dict,
6
- log: bool = False,
7
- show: bool = True,
8
- title: str = 'Losses',
9
- x_label: str = 'Iterations',
10
- best_epoch: float = None,
11
- **kwargs
12
- ):
13
- func = plt.semilogy if log else plt.plot
14
-
15
- if len(losses) <= 2:
16
- losses = {'loss': losses['total_loss']}
17
-
18
- for k, data in losses.items():
19
- func(data, label=k, **kwargs)
20
-
21
- if best_epoch is not None:
22
- plt.axvline(best_epoch, ls='--', color='red')
23
-
24
- plt.xlabel(x_label)
25
-
26
- if len(losses) > 2:
27
- plt.legend()
28
-
29
- plt.title(title)
30
-
31
- if show:
32
- plt.show()
1
+ import matplotlib.pyplot as plt
2
+
3
+
4
+ def plot_losses(
5
+ losses: dict,
6
+ log: bool = False,
7
+ show: bool = True,
8
+ title: str = 'Losses',
9
+ x_label: str = 'Iterations',
10
+ best_epoch: float = None,
11
+ **kwargs
12
+ ):
13
+ func = plt.semilogy if log else plt.plot
14
+
15
+ if len(losses) <= 2:
16
+ losses = {'loss': losses['total_loss']}
17
+
18
+ for k, data in losses.items():
19
+ func(data, label=k, **kwargs)
20
+
21
+ if best_epoch is not None:
22
+ plt.axvline(best_epoch, ls='--', color='red')
23
+
24
+ plt.xlabel(x_label)
25
+
26
+ if len(losses) > 2:
27
+ plt.legend()
28
+
29
+ plt.title(title)
30
+
31
+ if show:
32
+ plt.show()
@@ -1,77 +1,77 @@
1
- import numpy as np
2
- from functools import reduce
3
- from operator import or_
4
-
5
- from reflectorch.inference.inference_model import EasyInferenceModel
6
- from reflectorch import BasicParams
7
-
8
- import refnx
9
- from refnx.dataset import ReflectDataset, Data1D
10
- from refnx.analysis import Transform, CurveFitter, Objective, Model, Parameter
11
- from refnx.reflect import SLD, Slab, ReflectModel
12
-
13
- def covert_reflectorch_prediction_to_refnx_structure(inference_model: EasyInferenceModel, pred_params_object: BasicParams, prior_bounds: np.array):
14
- assert inference_model.trainer.loader.prior_sampler.param_model.__class__.__name__ == 'StandardModel'
15
-
16
- n_layers = inference_model.trainer.loader.prior_sampler.max_num_layers
17
- init_thicknesses = pred_params_object.thicknesses.squeeze().tolist()
18
- init_roughnesses = pred_params_object.roughnesses.squeeze().tolist()
19
- init_slds = pred_params_object.slds.squeeze().tolist()
20
-
21
- sld_objects = []
22
-
23
- for sld in init_slds:
24
- sld_objects.append(SLD(value=sld))
25
-
26
- layer_objects = [SLD(0)()]
27
- for i in range(n_layers):
28
- layer_objects.append(sld_objects[i](init_thicknesses[i], init_roughnesses[i]))
29
-
30
- layer_objects.append(sld_objects[-1](0, init_roughnesses[-1]))
31
-
32
- thickness_bounds = prior_bounds[:n_layers]
33
- roughness_bounds = prior_bounds[n_layers:2*n_layers+1]
34
- sld_bounds = prior_bounds[2*n_layers+1:]
35
-
36
- for i, layer in enumerate(layer_objects):
37
- if i == 0:
38
- print("Ambient (air)")
39
- print(80 * '-')
40
- elif i < n_layers+1:
41
- layer.thick.setp(bounds=thickness_bounds[i-1], vary=True)
42
- layer.rough.setp(bounds=roughness_bounds[i-1], vary=True)
43
- layer.sld.real.setp(bounds=sld_bounds[i-1], vary=True)
44
-
45
- print(f'Layer {i}')
46
- print(f'Thickness: value {layer.thick.value}, vary {layer.thick.vary}, bounds {layer.thick.bounds}')
47
- print(f'Roughness: value {layer.rough.value}, vary {layer.rough.vary}, bounds {layer.rough.bounds}')
48
- print(f'SLD: value {layer.sld.real.value}, vary {layer.sld.real.vary}, bounds {layer.sld.real.bounds}')
49
- print(80 * '-')
50
- else: #substrate
51
- layer.rough.setp(bounds=roughness_bounds[i-1], vary=True)
52
- layer.sld.real.setp(bounds=sld_bounds[i-1], vary=True)
53
-
54
- print(f'Substrate')
55
- print(f'Thickness: value {layer.thick.value}, vary {layer.thick.vary}, bounds {layer.thick.bounds}')
56
- print(f'Roughness: value {layer.rough.value}, vary {layer.rough.vary}, bounds {layer.rough.bounds}')
57
- print(f'SLD: value {layer.sld.real.value}, vary {layer.sld.real.vary}, bounds {layer.sld.real.bounds}')
58
-
59
- refnx_structure = reduce(or_, layer_objects)
60
-
61
- return refnx_structure
62
-
63
-
64
- ###Example usage:
65
- # refnx_structure = covert_reflectorch_prediction_to_refnx_structure(inference_model, pred_params_object, prior_bounds)
66
-
67
- # refnx_reflect_model = ReflectModel(refnx_structure, bkg=1e-10, dq=0.0)
68
- # refnx_reflect_model.scale.setp(bounds=(0.8, 1.2), vary=True)
69
- # refnx_reflect_model.q_offset.setp(bounds=(-0.01, 0.01), vary=True)
70
- # refnx_reflect_model.bkg.setp(bounds=(1e-10, 1e-8), vary=True)
71
-
72
-
73
- # data = Data1D(data=(q_model, exp_curve_interp))
74
-
75
- # refnx_objective = Objective(refnx_reflect_model, data, transform=Transform("logY"))
76
- # fitter = CurveFitter(refnx_objective)
1
+ import numpy as np
2
+ from functools import reduce
3
+ from operator import or_
4
+
5
+ from reflectorch.inference.inference_model import EasyInferenceModel
6
+ from reflectorch import BasicParams
7
+
8
+ import refnx
9
+ from refnx.dataset import ReflectDataset, Data1D
10
+ from refnx.analysis import Transform, CurveFitter, Objective, Model, Parameter
11
+ from refnx.reflect import SLD, Slab, ReflectModel
12
+
13
+ def covert_reflectorch_prediction_to_refnx_structure(inference_model: EasyInferenceModel, pred_params_object: BasicParams, prior_bounds: np.array):
14
+ assert inference_model.trainer.loader.prior_sampler.param_model.__class__.__name__ == 'StandardModel'
15
+
16
+ n_layers = inference_model.trainer.loader.prior_sampler.max_num_layers
17
+ init_thicknesses = pred_params_object.thicknesses.squeeze().tolist()
18
+ init_roughnesses = pred_params_object.roughnesses.squeeze().tolist()
19
+ init_slds = pred_params_object.slds.squeeze().tolist()
20
+
21
+ sld_objects = []
22
+
23
+ for sld in init_slds:
24
+ sld_objects.append(SLD(value=sld))
25
+
26
+ layer_objects = [SLD(0)()]
27
+ for i in range(n_layers):
28
+ layer_objects.append(sld_objects[i](init_thicknesses[i], init_roughnesses[i]))
29
+
30
+ layer_objects.append(sld_objects[-1](0, init_roughnesses[-1]))
31
+
32
+ thickness_bounds = prior_bounds[:n_layers]
33
+ roughness_bounds = prior_bounds[n_layers:2*n_layers+1]
34
+ sld_bounds = prior_bounds[2*n_layers+1:]
35
+
36
+ for i, layer in enumerate(layer_objects):
37
+ if i == 0:
38
+ print("Ambient (air)")
39
+ print(80 * '-')
40
+ elif i < n_layers+1:
41
+ layer.thick.setp(bounds=thickness_bounds[i-1], vary=True)
42
+ layer.rough.setp(bounds=roughness_bounds[i-1], vary=True)
43
+ layer.sld.real.setp(bounds=sld_bounds[i-1], vary=True)
44
+
45
+ print(f'Layer {i}')
46
+ print(f'Thickness: value {layer.thick.value}, vary {layer.thick.vary}, bounds {layer.thick.bounds}')
47
+ print(f'Roughness: value {layer.rough.value}, vary {layer.rough.vary}, bounds {layer.rough.bounds}')
48
+ print(f'SLD: value {layer.sld.real.value}, vary {layer.sld.real.vary}, bounds {layer.sld.real.bounds}')
49
+ print(80 * '-')
50
+ else: #substrate
51
+ layer.rough.setp(bounds=roughness_bounds[i-1], vary=True)
52
+ layer.sld.real.setp(bounds=sld_bounds[i-1], vary=True)
53
+
54
+ print(f'Substrate')
55
+ print(f'Thickness: value {layer.thick.value}, vary {layer.thick.vary}, bounds {layer.thick.bounds}')
56
+ print(f'Roughness: value {layer.rough.value}, vary {layer.rough.vary}, bounds {layer.rough.bounds}')
57
+ print(f'SLD: value {layer.sld.real.value}, vary {layer.sld.real.vary}, bounds {layer.sld.real.bounds}')
58
+
59
+ refnx_structure = reduce(or_, layer_objects)
60
+
61
+ return refnx_structure
62
+
63
+
64
+ ###Example usage:
65
+ # refnx_structure = covert_reflectorch_prediction_to_refnx_structure(inference_model, pred_params_object, prior_bounds)
66
+
67
+ # refnx_reflect_model = ReflectModel(refnx_structure, bkg=1e-10, dq=0.0)
68
+ # refnx_reflect_model.scale.setp(bounds=(0.8, 1.2), vary=True)
69
+ # refnx_reflect_model.q_offset.setp(bounds=(-0.01, 0.01), vary=True)
70
+ # refnx_reflect_model.bkg.setp(bounds=(1e-10, 1e-8), vary=True)
71
+
72
+
73
+ # data = Data1D(data=(q_model, exp_curve_interp))
74
+
75
+ # refnx_objective = Objective(refnx_reflect_model, data, transform=Transform("logY"))
76
+ # fitter = CurveFitter(refnx_objective)
77
77
  # fitter.fit('least_squares')
@@ -1,24 +1,28 @@
1
- from reflectorch.inference.inference_model import InferenceModel, EasyInferenceModel
2
- from reflectorch.inference.query_matcher import HuggingfaceQueryMatcher
3
- from reflectorch.inference.multilayer_inference_model import MultilayerInferenceModel
4
- from reflectorch.inference.preprocess_exp import (
5
- StandardPreprocessing,
6
- standard_preprocessing,
7
- interp_reflectivity,
8
- apply_attenuation_correction,
9
- apply_footprint_correction,
10
- )
11
- from reflectorch.inference.torch_fitter import ReflGradientFit
12
-
13
- __all__ = [
14
- "InferenceModel",
15
- "EasyInferenceModel",
16
- "MultilayerInferenceModel",
17
- "HuggingfaceQueryMatcher",
18
- "StandardPreprocessing",
19
- "standard_preprocessing",
20
- "ReflGradientFit",
21
- "interp_reflectivity",
22
- "apply_attenuation_correction",
23
- "apply_footprint_correction",
24
- ]
1
+ from reflectorch.inference.inference_model import InferenceModel, EasyInferenceModel
2
+ from reflectorch.inference.query_matcher import HuggingfaceQueryMatcher
3
+ from reflectorch.inference.multilayer_inference_model import MultilayerInferenceModel
4
+ from reflectorch.inference.preprocess_exp import (
5
+ StandardPreprocessing,
6
+ standard_preprocessing,
7
+ interp_reflectivity,
8
+ apply_attenuation_correction,
9
+ apply_footprint_correction,
10
+ )
11
+ from reflectorch.inference.torch_fitter import ReflGradientFit
12
+ from reflectorch.inference.input_interface import Layer, Backing, Structure
13
+
14
+ __all__ = [
15
+ "InferenceModel",
16
+ "EasyInferenceModel",
17
+ "MultilayerInferenceModel",
18
+ "HuggingfaceQueryMatcher",
19
+ "StandardPreprocessing",
20
+ "standard_preprocessing",
21
+ "ReflGradientFit",
22
+ "Layer",
23
+ "Backing",
24
+ "Structure",
25
+ "interp_reflectivity",
26
+ "apply_attenuation_correction",
27
+ "apply_footprint_correction",
28
+ ]