reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic; see the registry's advisory page for more details.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -128
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -280
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -223
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -1374
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +36 -36
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +523 -516
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -19
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -262
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -200
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -15
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -19
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -434
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -404
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +97 -97
  91. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  94. reflectorch-1.4.0.dist-info/RECORD +0 -88
  95. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
  96. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,134 +1,134 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from reflectorch.models.activations import activation_by_name
6
-
7
class SpectralConv1d(nn.Module):
    """1D Fourier layer: FFT, learned linear transform of the lowest modes, inverse FFT.

    The input is transformed with ``torch.fft.rfft`` along its last dimension.
    The first ``modes`` frequency coefficients are mixed across channels with a
    learned complex weight tensor, and all higher frequencies are set to zero.
    The result is transformed back with ``torch.fft.irfft`` to the input length.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        modes (int): number of low-frequency Fourier modes to keep. If an input
            is so short that ``length // 2 + 1 < modes``, all of its available
            frequencies are used instead.
    """
    def __init__(self, in_channels, out_channels, modes):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes  # Number of Fourier modes to multiply, at most floor(N/2) + 1

        self.scale = (1 / (in_channels*out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, modes, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        """Mix channels per frequency.

        (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
        """
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        """Apply the spectral convolution.

        Args:
            x (torch.Tensor): real tensor of shape (batch, in_channels, length)

        Returns:
            torch.Tensor: real tensor of shape (batch, out_channels, length)
        """
        batchsize = x.shape[0]
        x_ft = torch.fft.rfft(x)  # (batch, in_channels, length // 2 + 1)
        n_freqs = x_ft.size(-1)

        # Short inputs may have fewer frequencies than ``self.modes``; use only the
        # ones that exist instead of failing with a size mismatch in einsum.
        n_modes = min(self.modes, n_freqs)
        # Follow the spectrum's complex dtype (e.g. complex128 for float64 input).
        # This is a no-op for the default float32 / complex64 case.
        weights = self.weights1[:, :, :n_modes].to(x_ft.dtype)

        # Multiply the retained Fourier modes; all higher modes stay zero.
        out_ft = torch.zeros(batchsize, self.out_channels, n_freqs, device=x.device, dtype=x_ft.dtype)
        out_ft[:, :, :n_modes] = self.compl_mul1d(x_ft[:, :, :n_modes], weights)

        # Return to physical space with the original length.
        return torch.fft.irfft(out_ft, n=x.size(-1))
39
-
40
-
41
class FnoEncoder(nn.Module):
    """An embedding network based on the Fourier Neural Operator (FNO) architecture

    .. image:: ../documentation/fig_reflectometry_embedding_networks.png
        :width: 400px
        :align: center

    The input of shape (batch, in_channels, seq_len) is lifted point-wise to
    ``width_fno`` channels. It then passes through ``n_fno_blocks`` blocks; each
    adds a spectral convolution and a 1x1 convolution of its input and applies
    the activation. The result is pooled over the sequence dimension (mean, or
    fusion self attention) and projected to ``dim_embedding``.

    Args:
        in_channels (int): number of input channels
        dim_embedding (int): dimension of the output embedding
        modes (int): number of Fourier modes
        width_fno (int): number of channels of the intermediate representations
        n_fno_blocks (int): number of FNO blocks
        activation (str): the activation function
        fusion_self_attention (bool): whether to use fusion self attention for merging the tokens (instead of mean)
        fsa_activation (str): the activation function of the fusion self attention block
    """
    def __init__(
        self,
        in_channels: int = 2,
        dim_embedding: int = 128,
        modes: int = 32,
        width_fno: int = 64,
        n_fno_blocks: int = 6,
        activation: str = 'gelu',
        fusion_self_attention: bool = False,
        fsa_activation: str = 'tanh',
    ):
        super().__init__()

        # Hyper-parameters stored as plain attributes.
        self.in_channels, self.dim_embedding = in_channels, dim_embedding
        self.modes, self.width_fno, self.n_fno_blocks = modes, width_fno, n_fno_blocks
        self.activation = activation_by_name(activation)()
        self.fusion_self_attention = fusion_self_attention

        # Sub-modules are created in a fixed order (lift, spectral convs, 1x1 convs,
        # output head, optional fusion). This keeps parameter initialisation under a
        # fixed seed and the state_dict layout stable.
        self.fc0 = nn.Linear(in_channels, width_fno)  # point-wise lift of the input channels, (r(q), q)
        self.spectral_convs = nn.ModuleList(
            SpectralConv1d(in_channels=width_fno, out_channels=width_fno, modes=modes)
            for _ in range(n_fno_blocks)
        )
        self.w_convs = nn.ModuleList(
            nn.Conv1d(in_channels=width_fno, out_channels=width_fno, kernel_size=1)
            for _ in range(n_fno_blocks)
        )
        self.fc_out = nn.Linear(width_fno, dim_embedding)

        if fusion_self_attention:
            self.fusion = FusionSelfAttention(
                embed_dim=width_fno,
                hidden_dim=2 * width_fno,
                activation=fsa_activation,
            )

    def forward(self, x):
        """Embed ``x`` of shape (batch, in_channels, seq_len) into (batch, dim_embedding)."""
        # Channels-last for the Linear lift, then back to (B, width_fno, S).
        h = self.fc0(x.permute(0, 2, 1)).permute(0, 2, 1)

        # Each FNO block: global (spectral) path + local (1x1 conv) path, then activation.
        for spectral, pointwise in zip(self.spectral_convs, self.w_convs):
            h = self.activation(spectral(h) + pointwise(h))

        if not self.fusion_self_attention:
            # Plain mean over the sequence dimension.
            return self.fc_out(h.mean(dim=-1))

        # The fusion block expects (B, S, width_fno).
        return self.fc_out(self.fusion(h.permute(0, 2, 1)))
117
-
118
-
119
class FusionSelfAttention(nn.Module):
    """Attention-weighted pooling of a token sequence.

    A small MLP (``Linear -> activation -> Linear(hidden_dim, 1)``) produces one
    score per token. ``exp`` of the scores, optionally multiplied by ``mask``, is
    normalised over the sequence dimension to give the pooling weights.
    """
    def __init__(self, embed_dim: int = 64, hidden_dim: int = 64, activation: str = 'gelu'):
        super().__init__()
        activation = activation_by_name(activation)()
        self.fuser = nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                   activation,
                                   nn.Linear(hidden_dim, 1, bias=False))

    def forward(self,
                c: torch.Tensor,  # (batch_size x seq_len x embed_dim)
                mask: torch.Tensor = None,  # (batch_size x seq_len)
                ):
        # NOTE(review): this span stops before the method's return statement. In this
        # diff hunk, the return is the unchanged context line at the end of the hunk.
        a = self.fuser(c)  # one unnormalised score per token (last dim of size 1)
        # NOTE(review): exp() is applied without subtracting the max score, so large
        # scores can overflow to inf and make the normalised weights NaN.
        alpha = torch.exp(a)*mask.unsqueeze(-1) if mask is not None else torch.exp(a)
        # Normalise over the sequence dimension (dim=1) so the weights sum to 1.
        alpha = alpha/alpha.sum(dim=1, keepdim=True)
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from reflectorch.models.activations import activation_by_name
6
+
7
class SpectralConv1d(nn.Module):
    """1D Fourier layer: FFT, learned linear transform of the lowest modes, inverse FFT.

    The input is transformed with ``torch.fft.rfft`` along its last dimension.
    The first ``modes`` frequency coefficients are mixed across channels with a
    learned complex weight tensor, and all higher frequencies are set to zero.
    The result is transformed back with ``torch.fft.irfft`` to the input length.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        modes (int): number of low-frequency Fourier modes to keep. If an input
            is so short that ``length // 2 + 1 < modes``, all of its available
            frequencies are used instead.
    """
    def __init__(self, in_channels, out_channels, modes):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes  # Number of Fourier modes to multiply, at most floor(N/2) + 1

        self.scale = (1 / (in_channels*out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, modes, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        """Mix channels per frequency.

        (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
        """
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        """Apply the spectral convolution.

        Args:
            x (torch.Tensor): real tensor of shape (batch, in_channels, length)

        Returns:
            torch.Tensor: real tensor of shape (batch, out_channels, length)
        """
        batchsize = x.shape[0]
        x_ft = torch.fft.rfft(x)  # (batch, in_channels, length // 2 + 1)
        n_freqs = x_ft.size(-1)

        # Short inputs may have fewer frequencies than ``self.modes``; use only the
        # ones that exist instead of failing with a size mismatch in einsum.
        n_modes = min(self.modes, n_freqs)
        # Follow the spectrum's complex dtype (e.g. complex128 for float64 input).
        # This is a no-op for the default float32 / complex64 case.
        weights = self.weights1[:, :, :n_modes].to(x_ft.dtype)

        # Multiply the retained Fourier modes; all higher modes stay zero.
        out_ft = torch.zeros(batchsize, self.out_channels, n_freqs, device=x.device, dtype=x_ft.dtype)
        out_ft[:, :, :n_modes] = self.compl_mul1d(x_ft[:, :, :n_modes], weights)

        # Return to physical space with the original length.
        return torch.fft.irfft(out_ft, n=x.size(-1))
39
+
40
+
41
class FnoEncoder(nn.Module):
    """An embedding network based on the Fourier Neural Operator (FNO) architecture

    .. image:: ../documentation/fig_reflectometry_embedding_networks.png
        :width: 400px
        :align: center

    The input of shape (batch, in_channels, seq_len) is lifted point-wise to
    ``width_fno`` channels. It then passes through ``n_fno_blocks`` blocks; each
    adds a spectral convolution and a 1x1 convolution of its input and applies
    the activation. The result is pooled over the sequence dimension (mean, or
    fusion self attention) and projected to ``dim_embedding``.

    Args:
        in_channels (int): number of input channels
        dim_embedding (int): dimension of the output embedding
        modes (int): number of Fourier modes
        width_fno (int): number of channels of the intermediate representations
        n_fno_blocks (int): number of FNO blocks
        activation (str): the activation function
        fusion_self_attention (bool): whether to use fusion self attention for merging the tokens (instead of mean)
        fsa_activation (str): the activation function of the fusion self attention block
    """
    def __init__(
        self,
        in_channels: int = 2,
        dim_embedding: int = 128,
        modes: int = 32,
        width_fno: int = 64,
        n_fno_blocks: int = 6,
        activation: str = 'gelu',
        fusion_self_attention: bool = False,
        fsa_activation: str = 'tanh',
    ):
        super().__init__()

        # Hyper-parameters stored as plain attributes.
        self.in_channels, self.dim_embedding = in_channels, dim_embedding
        self.modes, self.width_fno, self.n_fno_blocks = modes, width_fno, n_fno_blocks
        self.activation = activation_by_name(activation)()
        self.fusion_self_attention = fusion_self_attention

        # Sub-modules are created in a fixed order (lift, spectral convs, 1x1 convs,
        # output head, optional fusion). This keeps parameter initialisation under a
        # fixed seed and the state_dict layout stable.
        self.fc0 = nn.Linear(in_channels, width_fno)  # point-wise lift of the input channels, (r(q), q)
        self.spectral_convs = nn.ModuleList(
            SpectralConv1d(in_channels=width_fno, out_channels=width_fno, modes=modes)
            for _ in range(n_fno_blocks)
        )
        self.w_convs = nn.ModuleList(
            nn.Conv1d(in_channels=width_fno, out_channels=width_fno, kernel_size=1)
            for _ in range(n_fno_blocks)
        )
        self.fc_out = nn.Linear(width_fno, dim_embedding)

        if fusion_self_attention:
            self.fusion = FusionSelfAttention(
                embed_dim=width_fno,
                hidden_dim=2 * width_fno,
                activation=fsa_activation,
            )

    def forward(self, x):
        """Embed ``x`` of shape (batch, in_channels, seq_len) into (batch, dim_embedding)."""
        # Channels-last for the Linear lift, then back to (B, width_fno, S).
        h = self.fc0(x.permute(0, 2, 1)).permute(0, 2, 1)

        # Each FNO block: global (spectral) path + local (1x1 conv) path, then activation.
        for spectral, pointwise in zip(self.spectral_convs, self.w_convs):
            h = self.activation(spectral(h) + pointwise(h))

        if not self.fusion_self_attention:
            # Plain mean over the sequence dimension.
            return self.fc_out(h.mean(dim=-1))

        # The fusion block expects (B, S, width_fno).
        return self.fc_out(self.fusion(h.permute(0, 2, 1)))
117
+
118
+
119
class FusionSelfAttention(nn.Module):
    """Attention-weighted pooling that fuses a sequence of tokens into one vector.

    A small MLP (``Linear -> activation -> Linear(hidden_dim, 1)``) scores every
    token. The scores are normalised over the sequence dimension with a
    numerically stable (max-shifted) softmax, optionally restricted and weighted
    by ``mask``. The weights are then used to sum the tokens.

    Args:
        embed_dim (int): size of the last dimension of the input tokens
        hidden_dim (int): width of the hidden layer of the scoring MLP
        activation (str): name of the activation passed to ``activation_by_name``
    """
    def __init__(self, embed_dim: int = 64, hidden_dim: int = 64, activation: str = 'gelu'):
        super().__init__()
        activation = activation_by_name(activation)()
        self.fuser = nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                   activation,
                                   nn.Linear(hidden_dim, 1, bias=False))

    def forward(self,
                c: torch.Tensor,  # (batch_size x seq_len x embed_dim)
                mask: torch.Tensor = None,  # (batch_size x seq_len)
                ):
        """Pool ``c`` over its sequence dimension.

        Args:
            c (torch.Tensor): tokens of shape (batch_size, seq_len, embed_dim)
            mask (torch.Tensor, optional): shape (batch_size, seq_len). Positions
                where it is zero are excluded; non-zero values additionally scale
                the (unnormalised) attention weights.

        Returns:
            torch.Tensor: pooled tensor of shape (batch_size, embed_dim). Rows
            whose mask is entirely zero still yield NaN (0 / 0), as before.
        """
        a = self.fuser(c)  # (batch_size x seq_len x 1) unnormalised scores
        if mask is not None:
            m = mask.unsqueeze(-1)
            # Drop masked scores before exp() so a large masked score cannot
            # overflow to inf and turn into NaN via inf * 0.
            scores = a.masked_fill(m == 0, float('-inf'))
            # Shift by the largest unmasked score; the shift cancels in the
            # normalisation below. Fully-masked rows have max = -inf, so use 0.
            shift = torch.nan_to_num(scores.amax(dim=1, keepdim=True), neginf=0.0)
            alpha = torch.exp(scores - shift) * m
        else:
            alpha = torch.exp(a - a.amax(dim=1, keepdim=True))
        alpha = alpha / alpha.sum(dim=1, keepdim=True)
        return (alpha * c).sum(dim=1)  # (batch_size x embed_dim)