reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of reflectorch might be problematic.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/data_generation/priors/utils.py
@@ -1,118 +1,118 @@
- from typing import Tuple
-
- import torch
- from torch import Tensor
-
- from reflectorch.data_generation.utils import (
-     get_d_rhos,
-     uniform_sampler,
- )
-
-
- def get_max_allowed_roughness(thicknesses: Tensor, mask: Tensor = None, coef: float = 0.5):
-     """gets the maximum allowed interlayer roughnesses such that they do not exceed a fraction of the thickness of either layers meeting at that interface"""
-     batch_size, layers_num = thicknesses.shape
-     max_roughness = torch.ones(
-         batch_size, layers_num + 1, device=thicknesses.device, dtype=thicknesses.dtype
-     ) * float('inf')
-
-     boundary = thicknesses * coef
-     if mask is not None:
-         boundary[get_thickness_mask_from_sld_mask(mask)] = float('inf')
-
-     max_roughness[:, :-1] = boundary
-     max_roughness[:, 1:] = torch.minimum(max_roughness[:, 1:], boundary)
-     return max_roughness
-
-
- def get_allowed_contrast_indices(slds: Tensor, min_contrast: float, mask: Tensor = None) -> Tensor:
-     d_rhos = get_d_rhos(slds)
-     indices = d_rhos.abs() >= min_contrast
-     if mask is not None:
-         indices = indices | mask
-     indices = torch.all(indices, -1)
-     return indices
-
-
- def params_within_bounds(params_t: Tensor, min_t: Tensor, max_t: Tensor, mask: Tensor = None) -> Tensor:
-     indices = (params_t >= min_t[None]) & (params_t <= max_t[None])
-     if mask is not None:
-         indices = indices | mask
-     indices = torch.all(indices, -1)
-     return indices
-
-
- def get_allowed_roughness_indices(thicknesses: Tensor, roughnesses: Tensor, mask: Tensor = None) -> Tensor:
-     max_roughness = get_max_allowed_roughness(thicknesses, mask)
-     indices = roughnesses <= max_roughness
-     if mask is not None:
-         indices = indices | mask
-     indices = torch.all(indices, -1)
-     return indices
-
-
- def get_thickness_mask_from_sld_mask(mask: Tensor):
-     return mask[:, :-1]
-
-
- def generate_roughnesses(thicknesses: Tensor, roughness_range: Tuple[float, float], mask: Tensor = None):
-     batch_size, layers_num = thicknesses.shape
-     max_roughness = get_max_allowed_roughness(thicknesses, mask)
-     max_roughness = torch.clamp_(max_roughness, max=roughness_range[1])
-
-     roughnesses = uniform_sampler(
-         roughness_range[0], max_roughness, batch_size, layers_num + 1,
-         device=thicknesses.device, dtype=thicknesses.dtype
-     )
-
-     if mask is not None:
-         roughnesses[mask] = 0.
-
-     return roughnesses
-
-
- def generate_thicknesses(
-         thickness_range: Tuple[float, float],
-         batch_size: int,
-         layers_num: int,
-         device: torch.device,
-         dtype: torch.dtype,
-         mask: Tensor = None
- ):
-     thicknesses = uniform_sampler(
-         *thickness_range, batch_size, layers_num, device=device, dtype=dtype
-     )
-     if mask is not None:
-         thicknesses[get_thickness_mask_from_sld_mask(mask)] = 0.
-     return thicknesses
-
-
- def generate_slds_with_min_contrast(
-         sld_range: Tuple[float, float],
-         batch_size: int,
-         layers_num: int,
-         min_contrast: float,
-         device: torch.device,
-         dtype: torch.dtype,
-         mask: Tensor = None,
-         *,
-         _depth: int = 0
- ):
-     # rejection sampling
-     slds = uniform_sampler(
-         *sld_range, batch_size, layers_num + 1, device=device, dtype=dtype
-     )
-
-     if mask is not None:
-         slds[mask] = 0.
-
-     rejected_indices = ~get_allowed_contrast_indices(slds, min_contrast, mask)
-     rejected_num = rejected_indices.sum(0).item()
-
-     if rejected_num:
-         if mask is not None:
-             mask = mask[rejected_indices]
-         slds[rejected_indices] = generate_slds_with_min_contrast(
-             sld_range, rejected_num, layers_num, min_contrast, device, dtype, mask, _depth=_depth+1
-         )
-     return slds
+ from typing import Tuple
+
+ import torch
+ from torch import Tensor
+
+ from reflectorch.data_generation.utils import (
+     get_d_rhos,
+     uniform_sampler,
+ )
+
+
+ def get_max_allowed_roughness(thicknesses: Tensor, mask: Tensor = None, coef: float = 0.5):
+     """gets the maximum allowed interlayer roughnesses such that they do not exceed a fraction of the thickness of either layers meeting at that interface"""
+     batch_size, layers_num = thicknesses.shape
+     max_roughness = torch.ones(
+         batch_size, layers_num + 1, device=thicknesses.device, dtype=thicknesses.dtype
+     ) * float('inf')
+
+     boundary = thicknesses * coef
+     if mask is not None:
+         boundary[get_thickness_mask_from_sld_mask(mask)] = float('inf')
+
+     max_roughness[:, :-1] = boundary
+     max_roughness[:, 1:] = torch.minimum(max_roughness[:, 1:], boundary)
+     return max_roughness
+
+
+ def get_allowed_contrast_indices(slds: Tensor, min_contrast: float, mask: Tensor = None) -> Tensor:
+     d_rhos = get_d_rhos(slds)
+     indices = d_rhos.abs() >= min_contrast
+     if mask is not None:
+         indices = indices | mask
+     indices = torch.all(indices, -1)
+     return indices
+
+
+ def params_within_bounds(params_t: Tensor, min_t: Tensor, max_t: Tensor, mask: Tensor = None) -> Tensor:
+     indices = (params_t >= min_t[None]) & (params_t <= max_t[None])
+     if mask is not None:
+         indices = indices | mask
+     indices = torch.all(indices, -1)
+     return indices
+
+
+ def get_allowed_roughness_indices(thicknesses: Tensor, roughnesses: Tensor, mask: Tensor = None) -> Tensor:
+     max_roughness = get_max_allowed_roughness(thicknesses, mask)
+     indices = roughnesses <= max_roughness
+     if mask is not None:
+         indices = indices | mask
+     indices = torch.all(indices, -1)
+     return indices
+
+
+ def get_thickness_mask_from_sld_mask(mask: Tensor):
+     return mask[:, :-1]
+
+
+ def generate_roughnesses(thicknesses: Tensor, roughness_range: Tuple[float, float], mask: Tensor = None):
+     batch_size, layers_num = thicknesses.shape
+     max_roughness = get_max_allowed_roughness(thicknesses, mask)
+     max_roughness = torch.clamp_(max_roughness, max=roughness_range[1])
+
+     roughnesses = uniform_sampler(
+         roughness_range[0], max_roughness, batch_size, layers_num + 1,
+         device=thicknesses.device, dtype=thicknesses.dtype
+     )
+
+     if mask is not None:
+         roughnesses[mask] = 0.
+
+     return roughnesses
+
+
+ def generate_thicknesses(
+         thickness_range: Tuple[float, float],
+         batch_size: int,
+         layers_num: int,
+         device: torch.device,
+         dtype: torch.dtype,
+         mask: Tensor = None
+ ):
+     thicknesses = uniform_sampler(
+         *thickness_range, batch_size, layers_num, device=device, dtype=dtype
+     )
+     if mask is not None:
+         thicknesses[get_thickness_mask_from_sld_mask(mask)] = 0.
+     return thicknesses
+
+
+ def generate_slds_with_min_contrast(
+         sld_range: Tuple[float, float],
+         batch_size: int,
+         layers_num: int,
+         min_contrast: float,
+         device: torch.device,
+         dtype: torch.dtype,
+         mask: Tensor = None,
+         *,
+         _depth: int = 0
+ ):
+     # rejection sampling
+     slds = uniform_sampler(
+         *sld_range, batch_size, layers_num + 1, device=device, dtype=dtype
+     )
+
+     if mask is not None:
+         slds[mask] = 0.
+
+     rejected_indices = ~get_allowed_contrast_indices(slds, min_contrast, mask)
+     rejected_num = rejected_indices.sum(0).item()
+
+     if rejected_num:
+         if mask is not None:
+             mask = mask[rejected_indices]
+         slds[rejected_indices] = generate_slds_with_min_contrast(
+             sld_range, rejected_num, layers_num, min_contrast, device, dtype, mask, _depth=_depth+1
+         )
+     return slds
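
For orientation, here is a minimal sketch of how these sampling helpers compose into a prior sample. It is an illustration only, not part of the diff: the import path assumes the functions live in reflectorch.data_generation.priors.utils (as the file list suggests), and the ranges and batch sizes are hypothetical values.

import torch

from reflectorch.data_generation.priors.utils import (  # assumed module path
    generate_thicknesses,
    generate_roughnesses,
    generate_slds_with_min_contrast,
)

device, dtype = torch.device("cpu"), torch.float64
batch_size, layers_num = 8, 3

# layer thicknesses drawn uniformly within a range (illustrative bounds)
thicknesses = generate_thicknesses((10.0, 300.0), batch_size, layers_num, device, dtype)

# interface roughnesses, clipped so each stays below half of the adjacent layer thicknesses
roughnesses = generate_roughnesses(thicknesses, (0.0, 20.0), mask=None)

# SLDs drawn by rejection sampling until adjacent layers differ by at least min_contrast
slds = generate_slds_with_min_contrast((0.0, 60.0), batch_size, layers_num, 0.1, device, dtype)

print(thicknesses.shape, roughnesses.shape, slds.shape)  # (8, 3), (8, 4), (8, 4)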
reflectorch/data_generation/process_data.py
@@ -1,41 +1,41 @@
- from typing import Any
-
- __all__ = [
-     "ProcessData",
-     "ProcessPipeline",
- ]
-
-
- class ProcessData(object):
-     def __add__(self, other):
-         if isinstance(other, ProcessData):
-             return ProcessPipeline(self, other)
-
-     def apply(self, args: Any, context: dict = None):
-         return args
-
-     def __call__(self, args: Any, context: dict = None):
-         return self.apply(args, context)
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}()'
-
-
- class ProcessPipeline(ProcessData):
-     def __init__(self, *processes):
-         self._processes = list(processes)
-
-     def apply(self, args: Any, context: dict = None):
-         for process in self._processes:
-             args = process(args, context)
-         return args
-
-     def __add__(self, other):
-         if isinstance(other, ProcessPipeline):
-             return ProcessPipeline(*self._processes, *other._processes)
-         elif isinstance(other, ProcessData):
-             return ProcessPipeline(*self._processes, other)
-
-     def __repr__(self):
-         processes = ", ".join(repr(p) for p in self._processes)
-         return f'ProcessPipeline({processes})'
+ from typing import Any
+
+ __all__ = [
+     "ProcessData",
+     "ProcessPipeline",
+ ]
+
+
+ class ProcessData(object):
+     def __add__(self, other):
+         if isinstance(other, ProcessData):
+             return ProcessPipeline(self, other)
+
+     def apply(self, args: Any, context: dict = None):
+         return args
+
+     def __call__(self, args: Any, context: dict = None):
+         return self.apply(args, context)
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}()'
+
+
+ class ProcessPipeline(ProcessData):
+     def __init__(self, *processes):
+         self._processes = list(processes)
+
+     def apply(self, args: Any, context: dict = None):
+         for process in self._processes:
+             args = process(args, context)
+         return args
+
+     def __add__(self, other):
+         if isinstance(other, ProcessPipeline):
+             return ProcessPipeline(*self._processes, *other._processes)
+         elif isinstance(other, ProcessData):
+             return ProcessPipeline(*self._processes, other)
+
+     def __repr__(self):
+         processes = ", ".join(repr(p) for p in self._processes)
+         return f'ProcessPipeline({processes})'
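
As a quick illustration of the pipeline composition implemented in process_data.py (not part of the diff): the Scale class below is hypothetical and exists only to show how __add__ chains ProcessData steps into a ProcessPipeline.

from reflectorch.data_generation.process_data import ProcessData, ProcessPipeline

class Scale(ProcessData):  # hypothetical step, for illustration only
    def __init__(self, factor):
        self.factor = factor

    def apply(self, args, context: dict = None):
        return args * self.factor

# __add__ on ProcessData wraps two steps into a ProcessPipeline;
# adding a further step appends it to the pipeline's process list.
pipeline = Scale(2.0) + Scale(0.5) + Scale(10.0)
print(pipeline)       # ProcessPipeline(Scale(), Scale(), Scale())
print(pipeline(3.0))  # steps applied left to right -> 30.0

Because both classes share the same apply/__call__ interface, pipelines can themselves be added together, concatenating their step lists.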