reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of reflectorch might be problematic.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -128
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -280
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -223
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -1374
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +36 -36
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +523 -516
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -19
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -262
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -200
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -15
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -19
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -434
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -404
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +97 -97
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.4.0.dist-info/RECORD +0 -88
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/models/networks/residual_net.py
CHANGED

@@ -1,157 +1,157 @@
(The diff removes and re-adds lines 1–156 with identical content, leaving only line 157 unchanged; the file is shown once below.)

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init

from reflectorch.models.activations import activation_by_name

class ResidualMLP(nn.Module):
    """Multilayer perceptron with residual blocks (BN-Act-Linear-BN-Act-Linear)"""

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        dim_condition: int = 0,
        layer_width: int = 512,
        num_blocks: int = 4,
        repeats_per_block: int = 2,
        activation: str = 'relu',
        use_batch_norm: bool = True,
        use_layer_norm: bool = False,
        dropout_rate: float = 0.0,
        residual: bool = True,
        adaptive_activation: bool = False,
        conditioning: str = 'glu',
        concat_condition_first_layer: bool = True,
        film_with_tanh: bool = False,
    ):
        super().__init__()

        self.concat_condition_first_layer = concat_condition_first_layer

        dim_first_layer = dim_in + dim_condition if concat_condition_first_layer else dim_in
        self.first_layer = nn.Linear(dim_first_layer, layer_width)

        self.blocks = nn.ModuleList(
            [
                ResidualBlock(
                    layer_width=layer_width,
                    dim_condition=dim_condition,
                    repeats_per_block=repeats_per_block,
                    activation=activation,
                    use_batch_norm=use_batch_norm,
                    use_layer_norm=use_layer_norm,
                    dropout_rate=dropout_rate,
                    residual=residual,
                    adaptive_activation=adaptive_activation,
                    conditioning=conditioning,
                    film_with_tanh=film_with_tanh,
                )
                for _ in range(num_blocks)
            ]
        )

        self.last_layer = nn.Linear(layer_width, dim_out)

    def forward(self, x, condition=None):
        if self.concat_condition_first_layer and condition is not None:
            x = self.first_layer(torch.cat([x, condition], dim=-1))
        else:
            x = self.first_layer(x)

        for block in self.blocks:
            x = block(x, condition=condition)

        x = self.last_layer(x)

        return x


class ResidualBlock(nn.Module):
    """Residual block (BN-Act-Linear-BN-Act-Linear)"""

    def __init__(
        self,
        layer_width: int,
        dim_condition: int = 0,
        repeats_per_block: int = 2,
        activation: str = 'relu',
        use_batch_norm: bool = False,
        use_layer_norm: bool = False,
        dropout_rate: float = 0.0,
        residual: bool = True,
        adaptive_activation: bool = False,
        conditioning: str = 'glu',
        film_with_tanh: bool = False,
    ):
        super().__init__()

        self.residual = residual
        self.repeats_per_block = repeats_per_block
        self.use_batch_norm = use_batch_norm
        self.use_layer_norm = use_layer_norm
        self.dropout_rate = dropout_rate
        self.adaptive_activation = adaptive_activation
        self.conditioning = conditioning
        self.film_with_tanh = film_with_tanh

        if not adaptive_activation:
            self.activation = activation_by_name(activation)()
        else:
            self.activation_layers = nn.ModuleList(
                [activation_by_name(activation)() for _ in range(repeats_per_block)]
            )

        if use_batch_norm:
            self.batch_norm_layers = nn.ModuleList(
                [nn.BatchNorm1d(layer_width, eps=1e-3) for _ in range(repeats_per_block)]
            )
        elif use_layer_norm:
            self.layer_norm_layers = nn.ModuleList(
                [nn.LayerNorm(layer_width) for _ in range(repeats_per_block)]
            )

        if dim_condition:
            if conditioning == 'glu':
                self.condition_layer = nn.Linear(dim_condition, layer_width)
            elif conditioning == 'film':
                self.condition_layer = nn.Linear(dim_condition, 2*layer_width)

        self.linear_layers = nn.ModuleList(
            [nn.Linear(layer_width, layer_width) for _ in range(repeats_per_block)]
        )

        if self.dropout_rate > 0:
            self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x, condition=None):
        x0 = x

        for i in range(self.repeats_per_block):
            if self.use_batch_norm:
                x = self.batch_norm_layers[i](x)
            elif self.use_layer_norm:
                x = self.layer_norm_layers[i](x)

            if not self.adaptive_activation:
                x = self.activation(x)
            else:
                x = self.activation_layers[i](x)

            if self.dropout_rate > 0 and i == self.repeats_per_block - 1:
                x = self.dropout(x)

            x = self.linear_layers[i](x)

            if condition is not None:
                if self.conditioning == 'glu':
                    x = F.glu(torch.cat((x, self.condition_layer(condition)), dim=-1), dim=-1)
                elif self.conditioning == 'film':
                    gamma, beta = torch.chunk(self.condition_layer(condition), chunks=2, dim=-1)
                    if self.film_with_tanh:
                        tanh = nn.Tanh()
                        gamma, beta = tanh(gamma), tanh(beta)
                    x = x * gamma + beta

        return x0 + x if self.residual else x
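To make the conditioning paths concrete, here is a minimal usage sketch of the module above; the dimensions, batch size, and FiLM settings are illustrative assumptions, not values taken from reflectorch configs:

import torch
from reflectorch.models.networks.residual_net import ResidualMLP

# Hypothetical sizes for illustration only.
net = ResidualMLP(
    dim_in=64,            # input feature size
    dim_out=8,            # e.g. number of predicted parameters
    dim_condition=16,     # size of the conditioning vector
    conditioning='film',  # per block: x = x * gamma + beta
    film_with_tanh=True,  # squash gamma and beta through tanh first
)

x = torch.randn(32, 64)          # batch of 32 inputs
condition = torch.randn(32, 16)  # per-sample conditioning
y = net(x, condition=condition)  # -> torch.Size([32, 8])

With the default concat_condition_first_layer=True, the condition is both concatenated to the input of the first linear layer and injected into every residual block.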
reflectorch/paths.py
CHANGED

@@ -1,27 +1,29 @@

Old version (content after line 5 is truncated in this rendering):

from typing import Union
from pathlib import Path

__all__ = [
    'ROOT_DIR',
    …

New version:

from typing import Union
from pathlib import Path

__all__ = [
    'ROOT_DIR',
    'EXP_DATA_DIR',
    'SAVED_MODELS_DIR',
    'SAVED_LOSSES_DIR',
    'RUN_SCRIPTS_DIR',
    'CONFIG_DIR',
    'TESTS_PATH',
    'TEST_DATA_PATH',
    'listdir',
]

ROOT_DIR: Path = Path(__file__).parents[1]
EXP_DATA_DIR: Path = ROOT_DIR / 'exp_data'
SAVED_MODELS_DIR: Path = ROOT_DIR / 'saved_models'
SAVED_LOSSES_DIR: Path = ROOT_DIR / 'saved_losses'
RUN_SCRIPTS_DIR: Path = ROOT_DIR / 'runs'
CONFIG_DIR: Path = ROOT_DIR / 'configs'
TESTS_PATH: Path = ROOT_DIR / 'tests'
TEST_DATA_PATH = TESTS_PATH / 'data'


def listdir(path: Union[Path, str], pattern: str = '*', recursive: bool = False, *, sort_key=None, reverse=False):
    path = Path(path)
    func = path.rglob if recursive else path.glob
    return sorted(list(func(pattern)), key=sort_key, reverse=reverse)
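The new listdir helper wraps Path.glob / Path.rglob and returns a sorted list; a minimal sketch, assuming a hypothetical checkpoint layout under SAVED_MODELS_DIR:

from reflectorch.paths import SAVED_MODELS_DIR, listdir

# Newest '*.pt' files first, searching subdirectories as well.
checkpoints = listdir(
    SAVED_MODELS_DIR, pattern='*.pt', recursive=True,
    sort_key=lambda p: p.stat().st_mtime, reverse=True,
)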
reflectorch/runs/__init__.py
CHANGED

@@ -1,31 +1,31 @@
(The diff removes and re-adds the entire file with identical content; it is shown once below.)

from reflectorch.runs.train import (
    run_train,
    run_train_on_cluster,
    run_test_config,
)

from reflectorch.runs.utils import (
    train_from_config,
    get_trainer_from_config,
    get_paths_from_config,
    get_callbacks_from_config,
    get_trainer_by_name,
    get_callbacks_by_name,
    convert_pt_to_safetensors,
)

from reflectorch.runs.config import load_config

__all__ = [
    'run_train',
    'run_train_on_cluster',
    'train_from_config',
    'run_test_config',
    'get_trainer_from_config',
    'get_paths_from_config',
    'get_callbacks_from_config',
    'get_trainer_by_name',
    'get_callbacks_by_name',
    'convert_pt_to_safetensors',
    'load_config',
]
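Since these names are re-exported at the subpackage level, downstream code can import them directly, for example:

from reflectorch.runs import load_config, train_from_config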
reflectorch/runs/config.py
CHANGED

@@ -1,25 +1,25 @@
(The diff removes and re-adds the entire file with identical content; it is shown once below.)

import yaml

from pathlib import Path
from reflectorch.paths import CONFIG_DIR


def load_config(config_name: str, config_dir: str = None) -> dict:
    """Loads a configuration dictionary from a YAML configuration file located in the configuration directory

    Args:
        config_name (str): name of the YAML configuration file
        config_dir (str): path of the configuration directory

    Returns:
        dict: the configuration dictionary
    """
    if not config_name.endswith('.yaml'):
        config_name = f'{config_name}.yaml'
    config_dir = Path(config_dir) if config_dir else CONFIG_DIR
    path = config_dir / config_name
    with open(path, 'r') as f:
        config = yaml.safe_load(f)
    config['config_path'] = str(path.absolute())

    return config
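load_config appends the '.yaml' suffix when it is missing and falls back to CONFIG_DIR when no directory is given; a minimal sketch, with a hypothetical config name and directory:

from reflectorch.runs.config import load_config

config = load_config('base_model')                             # reads CONFIG_DIR / 'base_model.yaml'
config = load_config('base_model', config_dir='/tmp/configs')  # or an explicit directory
print(config['config_path'])                                   # absolute path injected by load_config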
reflectorch/runs/slurm_utils.py
CHANGED

@@ -1,93 +1,93 @@
(The diff removes and re-adds the entire file with identical content; it is shown once below.)

from typing import Tuple, Union
from pathlib import Path
import subprocess

from reflectorch.paths import RUN_SCRIPTS_DIR


def save_sbatch_and_run(
        name: str,
        args: str,
        time: str,
        partition: str = None,
        reservation: bool = False,
        chdir: str = '~/maxwell_output',
        run_dir: Path = None,
        confirm: bool = False,
) -> Union[Tuple[str, str], None]:
    run_dir = Path(run_dir) if run_dir else RUN_SCRIPTS_DIR
    sbatch_path = run_dir / f'{name}.sh'

    if sbatch_path.is_file():
        import warnings
        warnings.warn(f'Sbatch file {str(sbatch_path)} already exists!')
        if confirm and not confirm_input('Continue?'):
            return

    file_content = _generate_sbatch_str(
        name,
        args,
        time=time,
        reservation=reservation,
        partition=partition,
        chdir=chdir,
    )

    if confirm:
        print(f'Generated file content: \n{file_content}\n')
        if not confirm_input(f'Save to {str(sbatch_path)} and run?'):
            return

    with open(str(sbatch_path), 'w') as f:
        f.write(file_content)

    res = submit_job(str(sbatch_path))
    return res


def _generate_sbatch_str(name: str,
                         args: str,
                         time: str,
                         partition: str = None,
                         reservation: bool = False,
                         chdir: str = '~/maxwell_output',
                         entry_point: str = 'python -m reflectorch.train',
                         ) -> str:
    chdir = str(Path(chdir).expanduser().absolute())
    partition_keyword = 'reservation' if reservation else 'partition'

    return f'''#!/bin/bash
#SBATCH --chdir {chdir}
#SBATCH --{partition_keyword}={partition}
#SBATCH --constraint=P100
#SBATCH --nodes=1
#SBATCH --job-name {name}
#SBATCH --time={time}
#SBATCH --output {name}.out
#SBATCH --error {name}.err

{entry_point} {args}
'''


def confirm_input(message: str) -> bool:
    yes = ('y', 'yes')
    no = ('n', 'no')
    res = ''
    valid_results = list(yes) + list(no)
    message = f'{message} Y/n: '

    while res not in valid_results:
        res = input(message).lower()
    return res in yes


def submit_job(sbatch_path: str) -> Tuple[str, str]:
    process = subprocess.Popen(
        ['sbatch', str(sbatch_path)],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    stdout, stderr = process.communicate()
    return stdout.decode(), stderr.decode()
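For orientation, a sketch of how these helpers compose; the job name, partition, and training arguments are illustrative, and submit_job requires a working sbatch on the PATH:

from reflectorch.runs.slurm_utils import save_sbatch_and_run

# Writes RUN_SCRIPTS_DIR / 'fit_run.sh', previews it, then submits via sbatch.
res = save_sbatch_and_run(
    name='fit_run',      # hypothetical job name
    args='my_config',    # appended to 'python -m reflectorch.train'
    time='12:00:00',     # SLURM time limit
    partition='allgpu',  # hypothetical partition name
    confirm=True,        # print the script and ask before saving/submitting
)
if res is not None:
    stdout, stderr = res  # sbatch output, e.g. 'Submitted batch job <id>'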