reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch has been flagged as potentially problematic; see the registry's advisory page for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -128
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -280
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -223
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -1374
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +36 -36
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +523 -516
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -19
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -262
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -200
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -15
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -19
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -434
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -404
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +97 -97
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.4.0.dist-info/RECORD +0 -88
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/runs/train.py
CHANGED
|
@@ -1,78 +1,78 @@
|
|
|
1
|
-
import click
|
|
2
|
-
|
|
3
|
-
from reflectorch.runs.slurm_utils import save_sbatch_and_run
|
|
4
|
-
from reflectorch.runs.utils import train_from_config
|
|
5
|
-
from reflectorch.runs.config import load_config
|
|
6
|
-
|
|
7
|
-
# Names of the click command entry points exported by this module.
__all__ = ['run_train', 'run_train_on_cluster', 'run_test_config']
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
@click.command()
@click.argument('config_name', type=str)
def run_train(config_name: str):
    """Train a model from a named YAML configuration (command line entry point).

    Example: python -m reflectorch.train 'conf_name'

    Args:
        config_name (str): name of the YAML configuration file to load
    """
    train_from_config(load_config(config_name))
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
@click.command()
@click.argument('config_name', type=str)
@click.argument('batch_size', type=int, default=512)
@click.argument('num_iterations', type=int, default=10)
def run_test_config(config_name: str, batch_size: int, num_iterations: int):
    """Check a configuration file by running a shortened training.

    The configuration is loaded and passed through ``_change_to_test_config``,
    which overrides the training batch size and number of iterations. Training
    is then started with the resulting configuration.

    Example: python -m reflectorch.test_config 'conf_name.yaml' 512 10

    Args:
        config_name (str): name of the YAML configuration file
        batch_size (int): replaces the batch size given in the configuration file
        num_iterations (int): replaces the number of iterations given in the configuration file
    """
    test_config = _change_to_test_config(
        load_config(config_name),
        batch_size=batch_size,
        num_iterations=num_iterations,
    )
    train_from_config(test_config)
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
@click.command()
@click.argument('config_name')
def run_train_on_cluster(config_name: str):
    """Submit training for a configuration through ``save_sbatch_and_run``.

    The job name comes from ``config['general']['name']``. The SLURM settings
    come from ``config['slurm']``. ``time`` and ``partition`` are required.
    ``reservation``, ``chdir``, ``run_dir`` and ``confirm`` are optional and
    default to ``False``, ``'~/maxwell_output'``, ``None`` and ``True``.

    A falsy return value from ``save_sbatch_and_run`` is reported as
    ``'Aborted.'``. Otherwise the return value is unpacked as an
    ``(out, err)`` pair. A non-empty ``err`` is printed as an error, and
    ``out`` is printed otherwise.

    Args:
        config_name (str): name of the YAML configuration file
    """
    config = load_config(config_name)
    job_name = config['general']['name']
    slurm_settings = config['slurm']

    submission = save_sbatch_and_run(
        job_name,
        config_name,
        time=slurm_settings['time'],
        partition=slurm_settings['partition'],
        reservation=slurm_settings.get('reservation', False),
        chdir=slurm_settings.get('chdir', '~/maxwell_output'),
        run_dir=slurm_settings.get('run_dir', None),
        confirm=slurm_settings.get('confirm', True),
    )

    if not submission:
        print('Aborted.')
        return

    stdout_text, stderr_text = submission
    if not stderr_text:
        print('Success!', stdout_text)
    else:
        print('Error occurred: ', stderr_text)
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
def _change_to_test_config(config, batch_size: int, num_iterations: int):
|
|
74
|
-
config = dict(config)
|
|
75
|
-
config['training']['num_iterations'] = num_iterations
|
|
76
|
-
config['training']['batch_size'] = batch_size
|
|
77
|
-
config['training']['update_tqdm_freq'] = 1
|
|
78
|
-
return config
|
|
1
|
+
import click
|
|
2
|
+
|
|
3
|
+
from reflectorch.runs.slurm_utils import save_sbatch_and_run
|
|
4
|
+
from reflectorch.runs.utils import train_from_config
|
|
5
|
+
from reflectorch.runs.config import load_config
|
|
6
|
+
|
|
7
|
+
# Names of the click command entry points exported by this module.
__all__ = ['run_train', 'run_train_on_cluster', 'run_test_config']
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@click.command()
@click.argument('config_name', type=str)
def run_train(config_name: str):
    """Train a model from a named YAML configuration (command line entry point).

    Example: python -m reflectorch.train 'conf_name'

    Args:
        config_name (str): name of the YAML configuration file to load
    """
    train_from_config(load_config(config_name))
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@click.command()
@click.argument('config_name', type=str)
@click.argument('batch_size', type=int, default=512)
@click.argument('num_iterations', type=int, default=10)
def run_test_config(config_name: str, batch_size: int, num_iterations: int):
    """Check a configuration file by running a shortened training.

    The configuration is loaded and passed through ``_change_to_test_config``,
    which overrides the training batch size and number of iterations. Training
    is then started with the resulting configuration.

    Example: python -m reflectorch.test_config 'conf_name.yaml' 512 10

    Args:
        config_name (str): name of the YAML configuration file
        batch_size (int): replaces the batch size given in the configuration file
        num_iterations (int): replaces the number of iterations given in the configuration file
    """
    test_config = _change_to_test_config(
        load_config(config_name),
        batch_size=batch_size,
        num_iterations=num_iterations,
    )
    train_from_config(test_config)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@click.command()
@click.argument('config_name')
def run_train_on_cluster(config_name: str):
    """Submit training for a configuration through ``save_sbatch_and_run``.

    The job name comes from ``config['general']['name']``. The SLURM settings
    come from ``config['slurm']``. ``time`` and ``partition`` are required.
    ``reservation``, ``chdir``, ``run_dir`` and ``confirm`` are optional and
    default to ``False``, ``'~/maxwell_output'``, ``None`` and ``True``.

    A falsy return value from ``save_sbatch_and_run`` is reported as
    ``'Aborted.'``. Otherwise the return value is unpacked as an
    ``(out, err)`` pair. A non-empty ``err`` is printed as an error, and
    ``out`` is printed otherwise.

    Args:
        config_name (str): name of the YAML configuration file
    """
    config = load_config(config_name)
    job_name = config['general']['name']
    slurm_settings = config['slurm']

    submission = save_sbatch_and_run(
        job_name,
        config_name,
        time=slurm_settings['time'],
        partition=slurm_settings['partition'],
        reservation=slurm_settings.get('reservation', False),
        chdir=slurm_settings.get('chdir', '~/maxwell_output'),
        run_dir=slurm_settings.get('run_dir', None),
        confirm=slurm_settings.get('confirm', True),
    )

    if not submission:
        print('Aborted.')
        return

    stdout_text, stderr_text = submission
    if not stderr_text:
        print('Success!', stdout_text)
    else:
        print('Error occurred: ', stderr_text)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _change_to_test_config(config, batch_size: int, num_iterations: int):
|
|
74
|
+
config = dict(config)
|
|
75
|
+
config['training']['num_iterations'] = num_iterations
|
|
76
|
+
config['training']['batch_size'] = batch_size
|
|
77
|
+
config['training']['update_tqdm_freq'] = 1
|
|
78
|
+
return config
|