reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic; see the package's registry page for more details.

Files changed (96):
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/models/networks/residual_net.py CHANGED
@@ -1,157 +1,157 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn import functional as F
4
- from torch.nn import init
5
-
6
- from reflectorch.models.activations import activation_by_name
7
-
8
- class ResidualMLP(nn.Module):
9
- """Multilayer perceptron with residual blocks (BN-Act-Linear-BN-Act-Linear)"""
10
-
11
- def __init__(
12
- self,
13
- dim_in: int,
14
- dim_out: int,
15
- dim_condition: int = 0,
16
- layer_width: int = 512,
17
- num_blocks: int = 4,
18
- repeats_per_block: int = 2,
19
- activation: str = 'relu',
20
- use_batch_norm: bool = True,
21
- use_layer_norm: bool = False,
22
- dropout_rate: float = 0.0,
23
- residual: bool = True,
24
- adaptive_activation: bool = False,
25
- conditioning: str = 'glu',
26
- concat_condition_first_layer: bool = True,
27
- film_with_tanh: bool = False,
28
- ):
29
- super().__init__()
30
-
31
- self.concat_condition_first_layer = concat_condition_first_layer
32
-
33
- dim_first_layer = dim_in + dim_condition if concat_condition_first_layer else dim_in
34
- self.first_layer = nn.Linear(dim_first_layer, layer_width)
35
-
36
- self.blocks = nn.ModuleList(
37
- [
38
- ResidualBlock(
39
- layer_width=layer_width,
40
- dim_condition=dim_condition,
41
- repeats_per_block=repeats_per_block,
42
- activation=activation,
43
- use_batch_norm=use_batch_norm,
44
- use_layer_norm=use_layer_norm,
45
- dropout_rate=dropout_rate,
46
- residual=residual,
47
- adaptive_activation=adaptive_activation,
48
- conditioning = conditioning,
49
- film_with_tanh = film_with_tanh,
50
- )
51
- for _ in range(num_blocks)
52
- ]
53
- )
54
-
55
- self.last_layer = nn.Linear(layer_width, dim_out)
56
-
57
- def forward(self, x, condition=None):
58
- if self.concat_condition_first_layer and condition is not None:
59
- x = self.first_layer(torch.cat([x, condition], dim=-1))
60
- else:
61
- x = self.first_layer(x)
62
-
63
- for block in self.blocks:
64
- x = block(x, condition=condition)
65
-
66
- x = self.last_layer(x)
67
-
68
- return x
69
-
70
-
71
- class ResidualBlock(nn.Module):
72
- """Residual block (BN-Act-Linear-BN-Act-Linear)"""
73
-
74
- def __init__(
75
- self,
76
- layer_width: int,
77
- dim_condition: int = 0,
78
- repeats_per_block: int = 2,
79
- activation: str = 'relu',
80
- use_batch_norm: bool = False,
81
- use_layer_norm: bool = False,
82
- dropout_rate: float = 0.0,
83
- residual: bool = True,
84
- adaptive_activation: bool = False,
85
- conditioning: str = 'glu',
86
- film_with_tanh: bool = False,
87
- ):
88
- super().__init__()
89
-
90
- self.residual = residual
91
- self.repeats_per_block = repeats_per_block
92
- self.use_batch_norm = use_batch_norm
93
- self.use_layer_norm = use_layer_norm
94
- self.dropout_rate = dropout_rate
95
- self.adaptive_activation = adaptive_activation
96
- self.conditioning = conditioning
97
- self.film_with_tanh = film_with_tanh
98
-
99
- if not adaptive_activation:
100
- self.activation = activation_by_name(activation)()
101
- else:
102
- self.activation_layers = nn.ModuleList(
103
- [activation_by_name(activation)() for _ in range(repeats_per_block)]
104
- )
105
-
106
- if use_batch_norm:
107
- self.batch_norm_layers = nn.ModuleList(
108
- [nn.BatchNorm1d(layer_width, eps=1e-3) for _ in range(repeats_per_block)]
109
- )
110
- elif use_layer_norm:
111
- self.layer_norm_layers = nn.ModuleList(
112
- [nn.LayerNorm(layer_width) for _ in range(repeats_per_block)]
113
- )
114
-
115
- if dim_condition:
116
- if conditioning == 'glu':
117
- self.condition_layer = nn.Linear(dim_condition, layer_width)
118
- elif conditioning == 'film':
119
- self.condition_layer = nn.Linear(dim_condition, 2*layer_width)
120
-
121
- self.linear_layers = nn.ModuleList(
122
- [nn.Linear(layer_width, layer_width) for _ in range(repeats_per_block)]
123
- )
124
-
125
- if self.dropout_rate > 0:
126
- self.dropout = nn.Dropout(p=dropout_rate)
127
-
128
- def forward(self, x, condition=None):
129
- x0 = x
130
-
131
- for i in range(self.repeats_per_block):
132
- if self.use_batch_norm:
133
- x = self.batch_norm_layers[i](x)
134
- elif self.use_layer_norm:
135
- x = self.layer_norm_layers[i](x)
136
-
137
- if not self.adaptive_activation:
138
- x = self.activation(x)
139
- else:
140
- x = self.activation_layers[i](x)
141
-
142
- if self.dropout_rate > 0 and i == self.repeats_per_block - 1:
143
- x = self.dropout(x)
144
-
145
- x = self.linear_layers[i](x)
146
-
147
- if condition is not None:
148
- if self.conditioning == 'glu':
149
- x = F.glu(torch.cat((x, self.condition_layer(condition)), dim=-1), dim=-1)
150
- elif self.conditioning == 'film':
151
- gamma, beta = torch.chunk(self.condition_layer(condition), chunks=2, dim=-1)
152
- if self.film_with_tanh:
153
- tanh = nn.Tanh()
154
- gamma, beta = tanh(gamma), tanh(beta)
155
- x = x * gamma + beta
156
-
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from torch.nn import init
5
+
6
+ from reflectorch.models.activations import activation_by_name
7
+
8
+ class ResidualMLP(nn.Module):
9
+ """Multilayer perceptron with residual blocks (BN-Act-Linear-BN-Act-Linear)"""
10
+
11
+ def __init__(
12
+ self,
13
+ dim_in: int,
14
+ dim_out: int,
15
+ dim_condition: int = 0,
16
+ layer_width: int = 512,
17
+ num_blocks: int = 4,
18
+ repeats_per_block: int = 2,
19
+ activation: str = 'relu',
20
+ use_batch_norm: bool = True,
21
+ use_layer_norm: bool = False,
22
+ dropout_rate: float = 0.0,
23
+ residual: bool = True,
24
+ adaptive_activation: bool = False,
25
+ conditioning: str = 'glu',
26
+ concat_condition_first_layer: bool = True,
27
+ film_with_tanh: bool = False,
28
+ ):
29
+ super().__init__()
30
+
31
+ self.concat_condition_first_layer = concat_condition_first_layer
32
+
33
+ dim_first_layer = dim_in + dim_condition if concat_condition_first_layer else dim_in
34
+ self.first_layer = nn.Linear(dim_first_layer, layer_width)
35
+
36
+ self.blocks = nn.ModuleList(
37
+ [
38
+ ResidualBlock(
39
+ layer_width=layer_width,
40
+ dim_condition=dim_condition,
41
+ repeats_per_block=repeats_per_block,
42
+ activation=activation,
43
+ use_batch_norm=use_batch_norm,
44
+ use_layer_norm=use_layer_norm,
45
+ dropout_rate=dropout_rate,
46
+ residual=residual,
47
+ adaptive_activation=adaptive_activation,
48
+ conditioning = conditioning,
49
+ film_with_tanh = film_with_tanh,
50
+ )
51
+ for _ in range(num_blocks)
52
+ ]
53
+ )
54
+
55
+ self.last_layer = nn.Linear(layer_width, dim_out)
56
+
57
+ def forward(self, x, condition=None):
58
+ if self.concat_condition_first_layer and condition is not None:
59
+ x = self.first_layer(torch.cat([x, condition], dim=-1))
60
+ else:
61
+ x = self.first_layer(x)
62
+
63
+ for block in self.blocks:
64
+ x = block(x, condition=condition)
65
+
66
+ x = self.last_layer(x)
67
+
68
+ return x
69
+
70
+
71
+ class ResidualBlock(nn.Module):
72
+ """Residual block (BN-Act-Linear-BN-Act-Linear)"""
73
+
74
+ def __init__(
75
+ self,
76
+ layer_width: int,
77
+ dim_condition: int = 0,
78
+ repeats_per_block: int = 2,
79
+ activation: str = 'relu',
80
+ use_batch_norm: bool = False,
81
+ use_layer_norm: bool = False,
82
+ dropout_rate: float = 0.0,
83
+ residual: bool = True,
84
+ adaptive_activation: bool = False,
85
+ conditioning: str = 'glu',
86
+ film_with_tanh: bool = False,
87
+ ):
88
+ super().__init__()
89
+
90
+ self.residual = residual
91
+ self.repeats_per_block = repeats_per_block
92
+ self.use_batch_norm = use_batch_norm
93
+ self.use_layer_norm = use_layer_norm
94
+ self.dropout_rate = dropout_rate
95
+ self.adaptive_activation = adaptive_activation
96
+ self.conditioning = conditioning
97
+ self.film_with_tanh = film_with_tanh
98
+
99
+ if not adaptive_activation:
100
+ self.activation = activation_by_name(activation)()
101
+ else:
102
+ self.activation_layers = nn.ModuleList(
103
+ [activation_by_name(activation)() for _ in range(repeats_per_block)]
104
+ )
105
+
106
+ if use_batch_norm:
107
+ self.batch_norm_layers = nn.ModuleList(
108
+ [nn.BatchNorm1d(layer_width, eps=1e-3) for _ in range(repeats_per_block)]
109
+ )
110
+ elif use_layer_norm:
111
+ self.layer_norm_layers = nn.ModuleList(
112
+ [nn.LayerNorm(layer_width) for _ in range(repeats_per_block)]
113
+ )
114
+
115
+ if dim_condition:
116
+ if conditioning == 'glu':
117
+ self.condition_layer = nn.Linear(dim_condition, layer_width)
118
+ elif conditioning == 'film':
119
+ self.condition_layer = nn.Linear(dim_condition, 2*layer_width)
120
+
121
+ self.linear_layers = nn.ModuleList(
122
+ [nn.Linear(layer_width, layer_width) for _ in range(repeats_per_block)]
123
+ )
124
+
125
+ if self.dropout_rate > 0:
126
+ self.dropout = nn.Dropout(p=dropout_rate)
127
+
128
+ def forward(self, x, condition=None):
129
+ x0 = x
130
+
131
+ for i in range(self.repeats_per_block):
132
+ if self.use_batch_norm:
133
+ x = self.batch_norm_layers[i](x)
134
+ elif self.use_layer_norm:
135
+ x = self.layer_norm_layers[i](x)
136
+
137
+ if not self.adaptive_activation:
138
+ x = self.activation(x)
139
+ else:
140
+ x = self.activation_layers[i](x)
141
+
142
+ if self.dropout_rate > 0 and i == self.repeats_per_block - 1:
143
+ x = self.dropout(x)
144
+
145
+ x = self.linear_layers[i](x)
146
+
147
+ if condition is not None:
148
+ if self.conditioning == 'glu':
149
+ x = F.glu(torch.cat((x, self.condition_layer(condition)), dim=-1), dim=-1)
150
+ elif self.conditioning == 'film':
151
+ gamma, beta = torch.chunk(self.condition_layer(condition), chunks=2, dim=-1)
152
+ if self.film_with_tanh:
153
+ tanh = nn.Tanh()
154
+ gamma, beta = tanh(gamma), tanh(beta)
155
+ x = x * gamma + beta
156
+
157
157
  return x0 + x if self.residual else x
reflectorch/paths.py CHANGED
@@ -1,27 +1,29 @@
1
- from typing import Union
2
- from pathlib import Path
3
-
4
- __all__ = [
5
- 'ROOT_DIR',
6
- 'SAVED_MODELS_DIR',
7
- 'SAVED_LOSSES_DIR',
8
- 'RUN_SCRIPTS_DIR',
9
- 'CONFIG_DIR',
10
- 'TESTS_PATH',
11
- 'TEST_DATA_PATH',
12
- 'listdir',
13
- ]
14
-
15
- ROOT_DIR: Path = Path(__file__).parents[1]
16
- SAVED_MODELS_DIR: Path = ROOT_DIR / 'saved_models'
17
- SAVED_LOSSES_DIR: Path = ROOT_DIR / 'saved_losses'
18
- RUN_SCRIPTS_DIR: Path = ROOT_DIR / 'runs'
19
- CONFIG_DIR: Path = ROOT_DIR / 'configs'
20
- TESTS_PATH: Path = ROOT_DIR / 'tests'
21
- TEST_DATA_PATH = TESTS_PATH / 'data'
22
-
23
-
24
- def listdir(path: Union[Path, str], pattern: str = '*', recursive: bool = False, *, sort_key=None, reverse=False):
25
- path = Path(path)
26
- func = path.rglob if recursive else path.glob
27
- return sorted(list(func(pattern)), key=sort_key, reverse=reverse)
1
+ from typing import Union
2
+ from pathlib import Path
3
+
4
+ __all__ = [
5
+ 'ROOT_DIR',
6
+ 'EXP_DATA_DIR',
7
+ 'SAVED_MODELS_DIR',
8
+ 'SAVED_LOSSES_DIR',
9
+ 'RUN_SCRIPTS_DIR',
10
+ 'CONFIG_DIR',
11
+ 'TESTS_PATH',
12
+ 'TEST_DATA_PATH',
13
+ 'listdir',
14
+ ]
15
+
16
+ ROOT_DIR: Path = Path(__file__).parents[1]
17
+ EXP_DATA_DIR: Path = ROOT_DIR / 'exp_data'
18
+ SAVED_MODELS_DIR: Path = ROOT_DIR / 'saved_models'
19
+ SAVED_LOSSES_DIR: Path = ROOT_DIR / 'saved_losses'
20
+ RUN_SCRIPTS_DIR: Path = ROOT_DIR / 'runs'
21
+ CONFIG_DIR: Path = ROOT_DIR / 'configs'
22
+ TESTS_PATH: Path = ROOT_DIR / 'tests'
23
+ TEST_DATA_PATH = TESTS_PATH / 'data'
24
+
25
+
26
+ def listdir(path: Union[Path, str], pattern: str = '*', recursive: bool = False, *, sort_key=None, reverse=False):
27
+ path = Path(path)
28
+ func = path.rglob if recursive else path.glob
29
+ return sorted(list(func(pattern)), key=sort_key, reverse=reverse)
reflectorch/runs/__init__.py CHANGED
@@ -1,31 +1,31 @@
1
- from reflectorch.runs.train import (
2
- run_train,
3
- run_train_on_cluster,
4
- run_test_config,
5
- )
6
-
7
- from reflectorch.runs.utils import (
8
- train_from_config,
9
- get_trainer_from_config,
10
- get_paths_from_config,
11
- get_callbacks_from_config,
12
- get_trainer_by_name,
13
- get_callbacks_by_name,
14
- convert_pt_to_safetensors,
15
- )
16
-
17
- from reflectorch.runs.config import load_config
18
-
19
- __all__ = [
20
- 'run_train',
21
- 'run_train_on_cluster',
22
- 'train_from_config',
23
- 'run_test_config',
24
- 'get_trainer_from_config',
25
- 'get_paths_from_config',
26
- 'get_callbacks_from_config',
27
- 'get_trainer_by_name',
28
- 'get_callbacks_by_name',
29
- 'convert_pt_to_safetensors',
30
- 'load_config',
31
- ]
1
+ from reflectorch.runs.train import (
2
+ run_train,
3
+ run_train_on_cluster,
4
+ run_test_config,
5
+ )
6
+
7
+ from reflectorch.runs.utils import (
8
+ train_from_config,
9
+ get_trainer_from_config,
10
+ get_paths_from_config,
11
+ get_callbacks_from_config,
12
+ get_trainer_by_name,
13
+ get_callbacks_by_name,
14
+ convert_pt_to_safetensors,
15
+ )
16
+
17
+ from reflectorch.runs.config import load_config
18
+
19
+ __all__ = [
20
+ 'run_train',
21
+ 'run_train_on_cluster',
22
+ 'train_from_config',
23
+ 'run_test_config',
24
+ 'get_trainer_from_config',
25
+ 'get_paths_from_config',
26
+ 'get_callbacks_from_config',
27
+ 'get_trainer_by_name',
28
+ 'get_callbacks_by_name',
29
+ 'convert_pt_to_safetensors',
30
+ 'load_config',
31
+ ]
reflectorch/runs/config.py CHANGED
@@ -1,25 +1,25 @@
1
- import yaml
2
-
3
- from pathlib import Path
4
- from reflectorch.paths import CONFIG_DIR
5
-
6
-
7
- def load_config(config_name: str, config_dir: str = None) -> dict:
8
- """Loads a configuration dictionary from a YAML configuration file located in the configuration directory
9
-
10
- Args:
11
- config_name (str): name of the YAML configuration file
12
- config_dir (str): path of the configuration directory
13
-
14
- Returns:
15
- dict: the configuration dictionary
16
- """
17
- if not config_name.endswith('.yaml'):
18
- config_name = f'{config_name}.yaml'
19
- config_dir = Path(config_dir) if config_dir else CONFIG_DIR
20
- path = config_dir / config_name
21
- with open(path, 'r') as f:
22
- config = yaml.safe_load(f)
23
- config['config_path'] = str(path.absolute())
24
-
25
- return config
1
+ import yaml
2
+
3
+ from pathlib import Path
4
+ from reflectorch.paths import CONFIG_DIR
5
+
6
+
7
+ def load_config(config_name: str, config_dir: str = None) -> dict:
8
+ """Loads a configuration dictionary from a YAML configuration file located in the configuration directory
9
+
10
+ Args:
11
+ config_name (str): name of the YAML configuration file
12
+ config_dir (str): path of the configuration directory
13
+
14
+ Returns:
15
+ dict: the configuration dictionary
16
+ """
17
+ if not config_name.endswith('.yaml'):
18
+ config_name = f'{config_name}.yaml'
19
+ config_dir = Path(config_dir) if config_dir else CONFIG_DIR
20
+ path = config_dir / config_name
21
+ with open(path, 'r') as f:
22
+ config = yaml.safe_load(f)
23
+ config['config_path'] = str(path.absolute())
24
+
25
+ return config
reflectorch/runs/slurm_utils.py CHANGED
@@ -1,93 +1,93 @@
1
- from typing import Tuple, Union
2
- from pathlib import Path
3
- import subprocess
4
-
5
- from reflectorch.paths import RUN_SCRIPTS_DIR
6
-
7
-
8
- def save_sbatch_and_run(
9
- name: str,
10
- args: str,
11
- time: str,
12
- partition: str = None,
13
- reservation: bool = False,
14
- chdir: str = '~/maxwell_output',
15
- run_dir: Path = None,
16
- confirm: bool = False,
17
- ) -> Union[Tuple[str, str], None]:
18
- run_dir = Path(run_dir) if run_dir else RUN_SCRIPTS_DIR
19
- sbatch_path = run_dir / f'{name}.sh'
20
-
21
- if sbatch_path.is_file():
22
- import warnings
23
- warnings.warn(f'Sbatch file {str(sbatch_path)} already exists!')
24
- if confirm and not confirm_input('Continue?'):
25
- return
26
-
27
- file_content = _generate_sbatch_str(
28
- name,
29
- args,
30
- time=time,
31
- reservation=reservation,
32
- partition=partition,
33
- chdir=chdir,
34
- )
35
-
36
- if confirm:
37
- print(f'Generated file content: \n{file_content}\n')
38
- if not confirm_input(f'Save to {str(sbatch_path)} and run?'):
39
- return
40
-
41
- with open(str(sbatch_path), 'w') as f:
42
- f.write(file_content)
43
-
44
- res = submit_job(str(sbatch_path))
45
- return res
46
-
47
-
48
- def _generate_sbatch_str(name: str,
49
- args: str,
50
- time: str,
51
- partition: str = None,
52
- reservation: bool = False,
53
- chdir: str = '~/maxwell_output',
54
- entry_point: str = 'python -m reflectorch.train',
55
- ) -> str:
56
- chdir = str(Path(chdir).expanduser().absolute())
57
- partition_keyword = 'reservation' if reservation else 'partition'
58
-
59
- return f'''#!/bin/bash
60
- #SBATCH --chdir {chdir}
61
- #SBATCH --{partition_keyword}={partition}
62
- #SBATCH --constraint=P100
63
- #SBATCH --nodes=1
64
- #SBATCH --job-name {name}
65
- #SBATCH --time={time}
66
- #SBATCH --output {name}.out
67
- #SBATCH --error {name}.err
68
-
69
- {entry_point} {args}
70
- '''
71
-
72
-
73
- def confirm_input(message: str) -> bool:
74
- yes = ('y', 'yes')
75
- no = ('n', 'no')
76
- res = ''
77
- valid_results = list(yes) + list(no)
78
- message = f'{message} Y/n: '
79
-
80
- while res not in valid_results:
81
- res = input(message).lower()
82
- return res in yes
83
-
84
-
85
- def submit_job(sbatch_path: str) -> Tuple[str, str]:
86
- process = subprocess.Popen(
87
- ['sbatch', str(sbatch_path)],
88
- stdout=subprocess.PIPE,
89
- stderr=subprocess.PIPE,
90
- )
91
-
92
- stdout, stderr = process.communicate()
93
- return stdout.decode(), stderr.decode()
1
+ from typing import Tuple, Union
2
+ from pathlib import Path
3
+ import subprocess
4
+
5
+ from reflectorch.paths import RUN_SCRIPTS_DIR
6
+
7
+
8
+ def save_sbatch_and_run(
9
+ name: str,
10
+ args: str,
11
+ time: str,
12
+ partition: str = None,
13
+ reservation: bool = False,
14
+ chdir: str = '~/maxwell_output',
15
+ run_dir: Path = None,
16
+ confirm: bool = False,
17
+ ) -> Union[Tuple[str, str], None]:
18
+ run_dir = Path(run_dir) if run_dir else RUN_SCRIPTS_DIR
19
+ sbatch_path = run_dir / f'{name}.sh'
20
+
21
+ if sbatch_path.is_file():
22
+ import warnings
23
+ warnings.warn(f'Sbatch file {str(sbatch_path)} already exists!')
24
+ if confirm and not confirm_input('Continue?'):
25
+ return
26
+
27
+ file_content = _generate_sbatch_str(
28
+ name,
29
+ args,
30
+ time=time,
31
+ reservation=reservation,
32
+ partition=partition,
33
+ chdir=chdir,
34
+ )
35
+
36
+ if confirm:
37
+ print(f'Generated file content: \n{file_content}\n')
38
+ if not confirm_input(f'Save to {str(sbatch_path)} and run?'):
39
+ return
40
+
41
+ with open(str(sbatch_path), 'w') as f:
42
+ f.write(file_content)
43
+
44
+ res = submit_job(str(sbatch_path))
45
+ return res
46
+
47
+
48
+ def _generate_sbatch_str(name: str,
49
+ args: str,
50
+ time: str,
51
+ partition: str = None,
52
+ reservation: bool = False,
53
+ chdir: str = '~/maxwell_output',
54
+ entry_point: str = 'python -m reflectorch.train',
55
+ ) -> str:
56
+ chdir = str(Path(chdir).expanduser().absolute())
57
+ partition_keyword = 'reservation' if reservation else 'partition'
58
+
59
+ return f'''#!/bin/bash
60
+ #SBATCH --chdir {chdir}
61
+ #SBATCH --{partition_keyword}={partition}
62
+ #SBATCH --constraint=P100
63
+ #SBATCH --nodes=1
64
+ #SBATCH --job-name {name}
65
+ #SBATCH --time={time}
66
+ #SBATCH --output {name}.out
67
+ #SBATCH --error {name}.err
68
+
69
+ {entry_point} {args}
70
+ '''
71
+
72
+
73
+ def confirm_input(message: str) -> bool:
74
+ yes = ('y', 'yes')
75
+ no = ('n', 'no')
76
+ res = ''
77
+ valid_results = list(yes) + list(no)
78
+ message = f'{message} Y/n: '
79
+
80
+ while res not in valid_results:
81
+ res = input(message).lower()
82
+ return res in yes
83
+
84
+
85
+ def submit_job(sbatch_path: str) -> Tuple[str, str]:
86
+ process = subprocess.Popen(
87
+ ['sbatch', str(sbatch_path)],
88
+ stdout=subprocess.PIPE,
89
+ stderr=subprocess.PIPE,
90
+ )
91
+
92
+ stdout, stderr = process.communicate()
93
+ return stdout.decode(), stderr.decode()