reflectorch-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of reflectorch might be problematic.

Files changed (83)
  1. reflectorch/__init__.py +23 -0
  2. reflectorch/data_generation/__init__.py +130 -0
  3. reflectorch/data_generation/dataset.py +196 -0
  4. reflectorch/data_generation/likelihoods.py +86 -0
  5. reflectorch/data_generation/noise.py +371 -0
  6. reflectorch/data_generation/priors/__init__.py +66 -0
  7. reflectorch/data_generation/priors/base.py +61 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
  9. reflectorch/data_generation/priors/independent_priors.py +201 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +110 -0
  12. reflectorch/data_generation/priors/no_constraints.py +212 -0
  13. reflectorch/data_generation/priors/parametric_models.py +767 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
  15. reflectorch/data_generation/priors/params.py +258 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +306 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +377 -0
  19. reflectorch/data_generation/priors/utils.py +124 -0
  20. reflectorch/data_generation/process_data.py +47 -0
  21. reflectorch/data_generation/q_generator.py +232 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +56 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +81 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +58 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +123 -0
  28. reflectorch/data_generation/scale_curves.py +118 -0
  29. reflectorch/data_generation/smearing.py +67 -0
  30. reflectorch/data_generation/utils.py +154 -0
  31. reflectorch/extensions/__init__.py +6 -0
  32. reflectorch/extensions/jupyter/__init__.py +12 -0
  33. reflectorch/extensions/jupyter/callbacks.py +40 -0
  34. reflectorch/extensions/matplotlib/__init__.py +11 -0
  35. reflectorch/extensions/matplotlib/losses.py +38 -0
  36. reflectorch/inference/__init__.py +22 -0
  37. reflectorch/inference/inference_model.py +734 -0
  38. reflectorch/inference/multilayer_fitter.py +171 -0
  39. reflectorch/inference/multilayer_inference_model.py +193 -0
  40. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  41. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  42. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  43. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  44. reflectorch/inference/preprocess_exp/interpolation.py +16 -0
  45. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  46. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  47. reflectorch/inference/record_time.py +43 -0
  48. reflectorch/inference/sampler_solution.py +56 -0
  49. reflectorch/inference/scipy_fitter.py +171 -0
  50. reflectorch/inference/torch_fitter.py +87 -0
  51. reflectorch/ml/__init__.py +37 -0
  52. reflectorch/ml/basic_trainer.py +286 -0
  53. reflectorch/ml/callbacks.py +86 -0
  54. reflectorch/ml/dataloaders.py +27 -0
  55. reflectorch/ml/loggers.py +38 -0
  56. reflectorch/ml/schedulers.py +246 -0
  57. reflectorch/ml/trainers.py +126 -0
  58. reflectorch/ml/utils.py +9 -0
  59. reflectorch/models/__init__.py +22 -0
  60. reflectorch/models/activations.py +50 -0
  61. reflectorch/models/encoders/__init__.py +27 -0
  62. reflectorch/models/encoders/conv_encoder.py +211 -0
  63. reflectorch/models/encoders/conv_res_net.py +119 -0
  64. reflectorch/models/encoders/fno.py +127 -0
  65. reflectorch/models/encoders/transformers.py +56 -0
  66. reflectorch/models/networks/__init__.py +18 -0
  67. reflectorch/models/networks/mlp_networks.py +256 -0
  68. reflectorch/models/networks/residual_net.py +131 -0
  69. reflectorch/paths.py +33 -0
  70. reflectorch/runs/__init__.py +35 -0
  71. reflectorch/runs/config.py +31 -0
  72. reflectorch/runs/slurm_utils.py +99 -0
  73. reflectorch/runs/train.py +85 -0
  74. reflectorch/runs/utils.py +300 -0
  75. reflectorch/test_config.py +4 -0
  76. reflectorch/train.py +4 -0
  77. reflectorch/train_on_cluster.py +4 -0
  78. reflectorch/utils.py +74 -0
  79. reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
  80. reflectorch-1.0.0.dist-info/METADATA +115 -0
  81. reflectorch-1.0.0.dist-info/RECORD +83 -0
  82. reflectorch-1.0.0.dist-info/WHEEL +5 -0
  83. reflectorch-1.0.0.dist-info/top_level.txt +1 -0
reflectorch/models/networks/mlp_networks.py ADDED
@@ -0,0 +1,256 @@
+ # -*- coding: utf-8 -*-
+
+ import math
+ from typing import Optional
+ import torch
+ from torch import nn, cat, split, Tensor
+
+ from reflectorch.models.networks.residual_net import ResidualMLP
+ from reflectorch.models.encoders.conv_encoder import ConvEncoder
+ from reflectorch.models.encoders.fno import FnoEncoder
+ from reflectorch.models.activations import activation_by_name
+
+ class NetworkWithPriorsConvEmb(nn.Module):
+     """MLP network with 1D CNN embedding network
+
+     .. image:: ../documentation/FigureReflectometryNetwork.png
+         :width: 800px
+         :align: center
+
+     Args:
+         in_channels (int, optional): the number of input channels of the 1D CNN. Defaults to 1.
+         hidden_channels (tuple, optional): list with the number of channels for each layer of the 1D CNN. Defaults to (32, 64, 128, 256, 512).
+         dim_embedding (int, optional): the dimension of the embedding produced by the 1D CNN. Defaults to 128.
+         dim_avpool (int, optional): the output size of the adaptive average pooling layer of the 1D CNN. Defaults to 1.
+         embedding_net_activation (str, optional): the type of activation function in the 1D CNN. Defaults to 'gelu'.
+         use_batch_norm (bool, optional): whether to use batch normalization (in both the 1D CNN and the MLP). Defaults to False.
+         dim_out (int, optional): the dimension of the output produced by the MLP. Defaults to 8.
+         layer_width (int, optional): the width of a linear layer in the MLP. Defaults to 512.
+         num_blocks (int, optional): the number of residual blocks in the MLP. Defaults to 4.
+         repeats_per_block (int, optional): the number of normalization/activation/linear repeats in a block. Defaults to 2.
+         mlp_activation (str, optional): the type of activation function in the MLP. Defaults to 'gelu'.
+         dropout_rate (float, optional): dropout rate for each block. Defaults to 0.0.
+         use_selu_init (bool, optional): whether to use the special weights initialization for the 'selu' activation function. Defaults to False.
+         pretrained_embedding_net (str, optional): the path to the weights of a pretrained embedding network. Defaults to None.
+         residual (bool, optional): whether the blocks have a residual skip connection. Defaults to True.
+         adaptive_activation (bool, optional): must be set to ``True`` if the activation function is adaptive. Defaults to False.
+         conditioning (str, optional): the manner in which the prior bounds are provided as input to the network. Defaults to 'concat'.
+     """
+     def __init__(self,
+                  in_channels: int = 1,
+                  hidden_channels: tuple = (32, 64, 128, 256, 512),
+                  dim_embedding: int = 128,
+                  dim_avpool: int = 1,
+                  embedding_net_activation: str = 'gelu',
+                  use_batch_norm: bool = False,
+                  dim_out: int = 8,
+                  layer_width: int = 512,
+                  num_blocks: int = 4,
+                  repeats_per_block: int = 2,
+                  mlp_activation: str = 'gelu',
+                  dropout_rate: float = 0.0,
+                  use_selu_init: bool = False,
+                  pretrained_embedding_net: str = None,
+                  residual: bool = True,
+                  adaptive_activation: bool = False,
+                  conditioning: str = 'concat',
+                  ):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.conditioning = conditioning
+
+         self.embedding_net = ConvEncoder(
+             in_channels=in_channels,
+             hidden_channels=hidden_channels,
+             dim_latent=dim_embedding,
+             dim_avpool=dim_avpool,
+             use_batch_norm=use_batch_norm,
+             activation=embedding_net_activation
+         )
+
+         self.dim_prior_bounds = 2 * dim_out
+
+         if conditioning == 'concat':
+             dim_mlp_in = dim_embedding + self.dim_prior_bounds
+             dim_condition = 0
+         elif conditioning == 'glu' or conditioning == 'film':
+             dim_mlp_in = dim_embedding
+             dim_condition = self.dim_prior_bounds
+         else:
+             raise NotImplementedError
+
+         self.mlp = ResidualMLP(
+             dim_in=dim_mlp_in,
+             dim_out=dim_out,
+             dim_condition=dim_condition,
+             layer_width=layer_width,
+             num_blocks=num_blocks,
+             repeats_per_block=repeats_per_block,
+             activation=mlp_activation,
+             use_batch_norm=use_batch_norm,
+             dropout_rate=dropout_rate,
+             residual=residual,
+             adaptive_activation=adaptive_activation,
+             conditioning=conditioning,
+         )
+
+         if use_selu_init and embedding_net_activation == 'selu':
+             self.embedding_net.apply(selu_init)
+
+         if use_selu_init and mlp_activation == 'selu':
+             self.mlp.apply(selu_init)
+
+         if pretrained_embedding_net:
+             self.embedding_net.load_weights(pretrained_embedding_net)
+
+     def forward(self, curves: Tensor, bounds: Tensor, q_values: Optional[Tensor] = None):
+         """
+         Args:
+             curves (Tensor): reflectivity curves
+             bounds (Tensor): prior bounds
+             q_values (Tensor, optional): q values. Defaults to None.
+
+         Returns:
+             Tensor: prediction
+         """
+         if q_values is not None:
+             curves = torch.cat([curves[:, None, :], q_values[:, None, :]], dim=1)
+
+         if self.conditioning == 'concat':
+             x = torch.cat([self.embedding_net(curves), bounds], dim=-1)
+             x = self.mlp(x)
+
+         elif self.conditioning == 'glu' or self.conditioning == 'film':
+             x = self.mlp(self.embedding_net(curves), condition=bounds)
+
+         return x
+
+
+ class NetworkWithPriorsFnoEmb(nn.Module):
+     """MLP network with FNO embedding network
+
+     Args:
+         in_channels (int, optional): the number of input channels to the FNO-based embedding network. Defaults to 2.
+         dim_embedding (int, optional): the dimension of the embedding produced by the FNO. Defaults to 128.
+         modes (int, optional): the number of Fourier modes that are utilized. Defaults to 16.
+         width_fno (int, optional): the number of channels in the FNO blocks. Defaults to 64.
+         embedding_net_activation (str, optional): the type of activation function in the embedding network. Defaults to 'gelu'.
+         n_fno_blocks (int, optional): the number of FNO blocks. Defaults to 6.
+         fusion_self_attention (bool, optional): if ``True`` a fusion layer is used after the FNO blocks to produce the final output. Defaults to False.
+         dim_out (int, optional): the dimension of the output produced by the MLP. Defaults to 8.
+         layer_width (int, optional): the width of a linear layer in the MLP. Defaults to 512.
+         num_blocks (int, optional): the number of residual blocks in the MLP. Defaults to 4.
+         repeats_per_block (int, optional): the number of normalization/activation/linear repeats in a block. Defaults to 2.
+         use_batch_norm (bool, optional): whether to use batch normalization (only in the MLP). Defaults to False.
+         mlp_activation (str, optional): the type of activation function in the MLP. Defaults to 'gelu'.
+         dropout_rate (float, optional): dropout rate for each block. Defaults to 0.0.
+         use_selu_init (bool, optional): whether to use the special weights initialization for the 'selu' activation function. Defaults to False.
+         residual (bool, optional): whether the blocks have a residual skip connection. Defaults to True.
+         adaptive_activation (bool, optional): must be set to ``True`` if the activation function is adaptive. Defaults to False.
+         conditioning (str, optional): the manner in which the prior bounds are provided as input to the network. Defaults to 'concat'.
+     """
+     def __init__(self,
+                  in_channels: int = 2,
+                  dim_embedding: int = 128,
+                  modes: int = 16,
+                  width_fno: int = 64,
+                  embedding_net_activation: str = 'gelu',
+                  n_fno_blocks: int = 6,
+                  fusion_self_attention: bool = False,
+                  dim_out: int = 8,
+                  layer_width: int = 512,
+                  num_blocks: int = 4,
+                  repeats_per_block: int = 2,
+                  use_batch_norm: bool = False,
+                  mlp_activation: str = 'gelu',
+                  dropout_rate: float = 0.0,
+                  use_selu_init: bool = False,
+                  residual: bool = True,
+                  adaptive_activation: bool = False,
+                  conditioning: str = 'concat',
+                  ):
+         super().__init__()
+
+         self.conditioning = conditioning
+
+         self.embedding_net = FnoEncoder(
+             ch_in=in_channels,
+             dim_embedding=dim_embedding,
+             modes=modes,
+             width_fno=width_fno,
+             n_fno_blocks=n_fno_blocks,
+             activation=embedding_net_activation,
+             fusion_self_attention=fusion_self_attention
+         )
+
+         self.dim_prior_bounds = 2 * dim_out
+
+         if conditioning == 'concat':
+             dim_mlp_in = dim_embedding + self.dim_prior_bounds
+             dim_condition = 0
+         elif conditioning == 'glu' or conditioning == 'film':
+             dim_mlp_in = dim_embedding
+             dim_condition = self.dim_prior_bounds
+         else:
+             raise NotImplementedError
+
+         self.mlp = ResidualMLP(
+             dim_in=dim_mlp_in,
+             dim_out=dim_out,
+             dim_condition=dim_condition,
+             layer_width=layer_width,
+             num_blocks=num_blocks,
+             repeats_per_block=repeats_per_block,
+             activation=mlp_activation,
+             use_batch_norm=use_batch_norm,
+             dropout_rate=dropout_rate,
+             residual=residual,
+             adaptive_activation=adaptive_activation,
+             conditioning=conditioning,
+         )
+
+         if use_selu_init and embedding_net_activation == 'selu':
+             self.embedding_net.apply(selu_init)
+
+         if use_selu_init and mlp_activation == 'selu':
+             self.mlp.apply(selu_init)
+
+     def forward(self, curves: Tensor, bounds: Tensor, q_values: Optional[Tensor] = None):
+         """
+         Args:
+             curves (Tensor): reflectivity curves
+             bounds (Tensor): prior bounds
+             q_values (Tensor, optional): q values. Defaults to None.
+
+         Returns:
+             Tensor: prediction
+         """
+         if curves.dim() < 3:
+             curves = curves[:, None, :]
+         if q_values is not None:
+             curves = torch.cat([curves, q_values[:, None, :]], dim=1)
+
+         if self.conditioning == 'concat':
+             x = torch.cat([self.embedding_net(curves), bounds], dim=-1)
+             x = self.mlp(x)
+
+         elif self.conditioning == 'glu' or self.conditioning == 'film':
+             x = self.mlp(self.embedding_net(curves), condition=bounds)
+
+         return x
+
+
+ def selu_init(m):
+     if isinstance(m, (nn.Conv1d, nn.Linear)):
+         m.weight.data.normal_(0.0, 0.5 / math.sqrt(m.weight.numel()))
+         nn.init.constant_(m.bias, 0)
+     elif isinstance(m, nn.BatchNorm1d):
+         size = m.weight.size()
+         fan_in = size[0]
+
+         m.weight.data.normal_(0.0, 1.0 / math.sqrt(fan_in))
+         m.bias.data.fill_(0)
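Both classes above share the same calling convention: an embedding network compresses the (optionally two-channel) curve into a fixed-size vector, which is then combined with the prior-bound vector of length 2 * dim_out. As a rough orientation, a minimal usage sketch (not part of the package; the shapes and the in_channels=2 choice are assumptions consistent with the forward pass above):

import torch
from reflectorch.models.networks.mlp_networks import NetworkWithPriorsConvEmb

# assumed setup: 16 curves sampled at 128 q points, 8 predicted parameters
net = NetworkWithPriorsConvEmb(in_channels=2, dim_out=8, conditioning='concat')

curves = torch.rand(16, 128)    # reflectivity curves
q_values = torch.rand(16, 128)  # q grid, stacked as a second input channel
bounds = torch.rand(16, 16)     # prior bounds, 2 * dim_out values per sample

prediction = net(curves, bounds, q_values=q_values)  # expected shape: (16, 8)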
reflectorch/models/networks/residual_net.py ADDED
@@ -0,0 +1,131 @@
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.nn import init
+
+ from reflectorch.models.activations import activation_by_name
+
+ class ResidualMLP(nn.Module):
+     """Multilayer perceptron with residual blocks (BN-Act-Linear-BN-Act-Linear)"""
+
+     def __init__(
+         self,
+         dim_in: int,
+         dim_out: int,
+         dim_condition: int = 0,
+         layer_width: int = 512,
+         num_blocks: int = 4,
+         repeats_per_block: int = 2,
+         activation: str = 'relu',
+         use_batch_norm: bool = True,
+         dropout_rate: float = 0.0,
+         residual: bool = True,
+         adaptive_activation: bool = False,
+         conditioning: str = 'glu',
+     ):
+         super().__init__()
+
+         dim_first_layer = dim_in + dim_condition
+         self.first_layer = nn.Linear(dim_first_layer, layer_width)
+         self.blocks = nn.ModuleList(
+             [
+                 ResidualBlock(
+                     layer_width=layer_width,
+                     dim_condition=dim_condition,
+                     repeats_per_block=repeats_per_block,
+                     activation=activation,
+                     use_batch_norm=use_batch_norm,
+                     dropout_rate=dropout_rate,
+                     residual=residual,
+                     adaptive_activation=adaptive_activation,
+                     conditioning=conditioning,
+                 )
+                 for _ in range(num_blocks)
+             ]
+         )
+         self.last_layer = nn.Linear(layer_width, dim_out)
+
+     def forward(self, x, condition=None):
+         if condition is None:
+             x = self.first_layer(x)
+         else:
+             x = self.first_layer(torch.cat([x, condition], dim=-1))
+
+         for block in self.blocks:
+             x = block(x, condition=condition)
+         x = self.last_layer(x)
+
+         return x
+
+
+ class ResidualBlock(nn.Module):
+     """Residual block (BN-Act-Linear-BN-Act-Linear)"""
+
+     def __init__(
+         self,
+         layer_width: int,
+         dim_condition: int = 0,
+         repeats_per_block: int = 2,
+         activation: str = 'relu',
+         use_batch_norm: bool = False,
+         dropout_rate: float = 0.0,
+         residual: bool = True,
+         adaptive_activation: bool = False,
+         conditioning: str = 'glu',
+     ):
+         super().__init__()
+
+         self.residual = residual
+         self.repeats_per_block = repeats_per_block
+         self.use_batch_norm = use_batch_norm
+         self.dropout_rate = dropout_rate
+         self.adaptive_activation = adaptive_activation
+         self.conditioning = conditioning
+
+         if not adaptive_activation:
+             self.activation = activation_by_name(activation)()
+         else:
+             self.activation_layers = nn.ModuleList(
+                 [activation_by_name(activation)() for _ in range(repeats_per_block)]
+             )
+
+         if use_batch_norm:
+             self.batch_norm_layers = nn.ModuleList(
+                 [nn.BatchNorm1d(layer_width, eps=1e-3) for _ in range(repeats_per_block)]
+             )
+
+         if dim_condition:
+             if conditioning == 'glu':
+                 self.condition_layer = nn.Linear(dim_condition, layer_width)
+             elif conditioning == 'film':
+                 self.condition_layer = nn.Linear(dim_condition, 2 * layer_width)
+
+         self.linear_layers = nn.ModuleList(
+             [nn.Linear(layer_width, layer_width) for _ in range(repeats_per_block)]
+         )
+
+         if self.dropout_rate > 0:
+             self.dropout = nn.Dropout(p=dropout_rate)
+
+     def forward(self, x, condition=None):
+         x0 = x
+
+         for i in range(self.repeats_per_block):
+             if self.use_batch_norm:
+                 x = self.batch_norm_layers[i](x)
+             if not self.adaptive_activation:
+                 x = self.activation(x)
+             else:
+                 x = self.activation_layers[i](x)
+             if self.dropout_rate > 0 and i == self.repeats_per_block - 1:
+                 x = self.dropout(x)
+             x = self.linear_layers[i](x)
+
+         if condition is not None:
+             if self.conditioning == 'glu':
+                 x = F.glu(torch.cat((x, self.condition_layer(condition)), dim=-1), dim=-1)
+             elif self.conditioning == 'film':
+                 gamma, beta = torch.chunk(self.condition_layer(condition), chunks=2, dim=-1)
+                 x = x * gamma + beta
+
+         return x0 + x if self.residual else x
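ResidualBlock injects the condition (the prior bounds) either by gating with a gated linear unit ('glu') or by a feature-wise affine transform ('film'); plain concatenation is handled one level up in the network classes. A standalone sketch of the two transforms, with shapes chosen only for illustration:

import torch
import torch.nn.functional as F
from torch import nn

x = torch.rand(4, 512)     # block activations, layer_width = 512
cond = torch.rand(4, 16)   # condition vector, e.g. 2 * dim_out prior bounds

# 'glu': project the condition to layer_width and use it as a sigmoid gate
glu_proj = nn.Linear(16, 512)
x_glu = F.glu(torch.cat((x, glu_proj(cond)), dim=-1), dim=-1)  # x * sigmoid(proj(cond))

# 'film': project to 2 * layer_width and apply a feature-wise affine transform
film_proj = nn.Linear(16, 2 * 512)
gamma, beta = torch.chunk(film_proj(cond), chunks=2, dim=-1)
x_film = x * gamma + beta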
reflectorch/paths.py ADDED
@@ -0,0 +1,33 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Union
+ from pathlib import Path
+
+ __all__ = [
+     'ROOT_DIR',
+     'SAVED_MODELS_DIR',
+     'SAVED_LOSSES_DIR',
+     'RUN_SCRIPTS_DIR',
+     'CONFIG_DIR',
+     'TESTS_PATH',
+     'TEST_DATA_PATH',
+     'listdir',
+ ]
+
+ ROOT_DIR: Path = Path(__file__).parents[1]
+ SAVED_MODELS_DIR: Path = ROOT_DIR / 'saved_models'
+ SAVED_LOSSES_DIR: Path = ROOT_DIR / 'saved_losses'
+ RUN_SCRIPTS_DIR: Path = ROOT_DIR / 'runs'
+ CONFIG_DIR: Path = ROOT_DIR / 'configs'
+ TESTS_PATH: Path = ROOT_DIR / 'tests'
+ TEST_DATA_PATH = TESTS_PATH / 'data'
+
+
+ def listdir(path: Union[Path, str], pattern: str = '*', recursive: bool = False, *, sort_key=None, reverse=False):
+     path = Path(path)
+     func = path.rglob if recursive else path.glob
+     return sorted(list(func(pattern)), key=sort_key, reverse=reverse)
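A small usage sketch for the helpers defined in paths.py (the '*.pt' pattern is only an example):

from reflectorch.paths import SAVED_MODELS_DIR, listdir

# recursively collect checkpoint files under saved_models/, sorted by name
checkpoints = listdir(SAVED_MODELS_DIR, pattern='*.pt', recursive=True)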
reflectorch/runs/__init__.py ADDED
@@ -0,0 +1,35 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from reflectorch.runs.train import (
+     run_train,
+     run_train_on_cluster,
+     run_test_config,
+ )
+
+ from reflectorch.runs.utils import (
+     train_from_config,
+     get_trainer_from_config,
+     get_paths_from_config,
+     get_callbacks_from_config,
+     get_trainer_by_name,
+     get_callbacks_by_name,
+ )
+
+ from reflectorch.runs.config import load_config
+
+ __all__ = [
+     'run_train',
+     'run_train_on_cluster',
+     'train_from_config',
+     'run_test_config',
+     'get_trainer_from_config',
+     'get_paths_from_config',
+     'get_callbacks_from_config',
+     'get_trainer_by_name',
+     'get_callbacks_by_name',
+     'load_config',
+ ]
reflectorch/runs/config.py ADDED
@@ -0,0 +1,31 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import yaml
+
+ from pathlib import Path
+ from reflectorch.paths import CONFIG_DIR
+
+
+ def load_config(config_name: str, config_dir: str = None) -> dict:
+     """Loads a configuration dictionary from a YAML configuration file located in the configuration directory
+
+     Args:
+         config_name (str): name of the YAML configuration file
+         config_dir (str): path of the configuration directory
+
+     Returns:
+         dict: the configuration dictionary
+     """
+     if not config_name.endswith('.yaml'):
+         config_name = f'{config_name}.yaml'
+     config_dir = Path(config_dir) if config_dir else CONFIG_DIR
+     path = config_dir / config_name
+     with open(path, 'r') as f:
+         config = yaml.safe_load(f)
+     config['config_path'] = str(path.absolute())
+
+     return config
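For reference, a minimal sketch of how load_config could be called ('my_model' is a hypothetical configuration name, not one shipped with the package):

from reflectorch.runs import load_config

config = load_config('my_model')   # reads configs/my_model.yaml ('.yaml' may be omitted)
print(config['config_path'])       # absolute path of the YAML file that was read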
reflectorch/runs/slurm_utils.py ADDED
@@ -0,0 +1,99 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Tuple, Union
+ from pathlib import Path
+ import subprocess
+
+ from reflectorch.paths import RUN_SCRIPTS_DIR
+
+
+ def save_sbatch_and_run(
+         name: str,
+         args: str,
+         time: str,
+         partition: str = None,
+         reservation: bool = False,
+         chdir: str = '~/maxwell_output',
+         run_dir: Path = None,
+         confirm: bool = False,
+ ) -> Union[Tuple[str, str], None]:
+     run_dir = Path(run_dir) if run_dir else RUN_SCRIPTS_DIR
+     sbatch_path = run_dir / f'{name}.sh'
+
+     if sbatch_path.is_file():
+         import warnings
+         warnings.warn(f'Sbatch file {str(sbatch_path)} already exists!')
+         if confirm and not confirm_input('Continue?'):
+             return
+
+     file_content = _generate_sbatch_str(
+         name,
+         args,
+         time=time,
+         reservation=reservation,
+         partition=partition,
+         chdir=chdir,
+     )
+
+     if confirm:
+         print(f'Generated file content: \n{file_content}\n')
+         if not confirm_input(f'Save to {str(sbatch_path)} and run?'):
+             return
+
+     with open(str(sbatch_path), 'w') as f:
+         f.write(file_content)
+
+     res = submit_job(str(sbatch_path))
+     return res
+
+
+ def _generate_sbatch_str(name: str,
+                          args: str,
+                          time: str,
+                          partition: str = None,
+                          reservation: bool = False,
+                          chdir: str = '~/maxwell_output',
+                          entry_point: str = 'python -m reflectorch.train',
+                          ) -> str:
+     chdir = str(Path(chdir).expanduser().absolute())
+     partition_keyword = 'reservation' if reservation else 'partition'
+
+     return f'''#!/bin/bash
+ #SBATCH --chdir {chdir}
+ #SBATCH --{partition_keyword}={partition}
+ #SBATCH --constraint=P100
+ #SBATCH --nodes=1
+ #SBATCH --job-name {name}
+ #SBATCH --time={time}
+ #SBATCH --output {name}.out
+ #SBATCH --error {name}.err
+
+ {entry_point} {args}
+ '''
+
+
+ def confirm_input(message: str) -> bool:
+     yes = ('y', 'yes')
+     no = ('n', 'no')
+     res = ''
+     valid_results = list(yes) + list(no)
+     message = f'{message} Y/n: '
+
+     while res not in valid_results:
+         res = input(message).lower()
+     return res in yes
+
+
+ def submit_job(sbatch_path: str) -> Tuple[str, str]:
+     process = subprocess.Popen(
+         ['sbatch', str(sbatch_path)],
+         stdout=subprocess.PIPE,
+         stderr=subprocess.PIPE,
+     )
+
+     stdout, stderr = process.communicate()
+     return stdout.decode(), stderr.decode()
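A hedged usage sketch of the SLURM helper above: the job name, partition, and time limit are placeholder values, and the call writes <run_dir>/my_job.sh before invoking sbatch via submit_job.

from reflectorch.runs.slurm_utils import save_sbatch_and_run

result = save_sbatch_and_run(
    name='my_job',          # job and script name (illustrative)
    args='my_config',       # forwarded to `python -m reflectorch.train`
    time='12:00:00',
    partition='allgpu',     # placeholder partition name
    confirm=True,           # ask before writing and submitting
)
if result:
    stdout, stderr = result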
reflectorch/runs/train.py ADDED
@@ -0,0 +1,85 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import click
+
+ from reflectorch.runs.slurm_utils import save_sbatch_and_run
+ from reflectorch.runs.utils import train_from_config
+ from reflectorch.runs.config import load_config
+
+ __all__ = [
+     'run_train',
+     'run_train_on_cluster',
+     'run_test_config',
+ ]
+
+
+ @click.command()
+ @click.argument('config_name', type=str)
+ def run_train(config_name: str):
+     """Runs the training from the command line interface.
+     Example: python -m reflectorch.train 'conf_name'
+
+     Args:
+         config_name (str): name of the YAML configuration file
+     """
+     config = load_config(config_name)
+     train_from_config(config)
+
+
+ @click.command()
+ @click.argument('config_name', type=str)
+ @click.argument('batch_size', type=int, default=512)
+ @click.argument('num_iterations', type=int, default=10)
+ def run_test_config(config_name: str, batch_size: int, num_iterations: int):
+     """Run for the purpose of testing the configuration file.
+     Example: python -m reflectorch.test_config 'conf_name.yaml' 512 10
+
+     Args:
+         config_name (str): name of the YAML configuration file
+         batch_size (int): overwrites the batch size in the configuration file
+         num_iterations (int): overwrites the number of iterations in the configuration file
+     """
+     config = load_config(config_name)
+     config = _change_to_test_config(config, batch_size=batch_size, num_iterations=num_iterations)
+     train_from_config(config)
+
+
+ @click.command()
+ @click.argument('config_name')
+ def run_train_on_cluster(config_name: str):
+     config = load_config(config_name)
+     name = config['general']['name']
+     slurm_conf = config['slurm']
+
+     res = save_sbatch_and_run(
+         name,
+         config_name,
+         time=slurm_conf['time'],
+         partition=slurm_conf['partition'],
+         reservation=slurm_conf.get('reservation', False),
+         chdir=slurm_conf.get('chdir', '~/maxwell_output'),
+         run_dir=slurm_conf.get('run_dir', None),
+         confirm=slurm_conf.get('confirm', True),
+     )
+     if not res:
+         print('Aborted.')
+         return
+     out, err = res
+
+     if err:
+         print('Error occurred: ', err)
+     else:
+         print('Success!', out)
+
+
+ def _change_to_test_config(config, batch_size: int, num_iterations: int):
+     config = dict(config)
+     config['training']['logger']['use_neptune'] = False
+     config['training']['num_iterations'] = num_iterations
+     config['training']['batch_size'] = batch_size
+     config['training']['update_tqdm_freq'] = 1
+     return config
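The click commands above can also be mirrored programmatically; run_train, for instance, reduces to the two calls below ('my_config' is a placeholder for a YAML file in the configs directory):

from reflectorch.runs import load_config, train_from_config

config = load_config('my_config')
train_from_config(config)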