reflectorch 1.0.1__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +0 -6
- reflectorch/data_generation/__init__.py +0 -6
- reflectorch/data_generation/dataset.py +0 -6
- reflectorch/data_generation/likelihoods.py +0 -6
- reflectorch/data_generation/noise.py +0 -6
- reflectorch/data_generation/priors/__init__.py +0 -6
- reflectorch/data_generation/priors/base.py +0 -6
- reflectorch/data_generation/priors/exp_subprior_sampler.py +0 -6
- reflectorch/data_generation/priors/independent_priors.py +0 -6
- reflectorch/data_generation/priors/multilayer_structures.py +0 -6
- reflectorch/data_generation/priors/no_constraints.py +0 -6
- reflectorch/data_generation/priors/parametric_subpriors.py +0 -6
- reflectorch/data_generation/priors/params.py +0 -6
- reflectorch/data_generation/priors/subprior_sampler.py +0 -6
- reflectorch/data_generation/priors/utils.py +0 -6
- reflectorch/data_generation/process_data.py +0 -6
- reflectorch/data_generation/q_generator.py +0 -6
- reflectorch/data_generation/reflectivity/__init__.py +13 -9
- reflectorch/data_generation/reflectivity/abeles.py +6 -5
- reflectorch/data_generation/reflectivity/kinematical.py +14 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +13 -0
- reflectorch/data_generation/reflectivity/smearing.py +2 -2
- reflectorch/data_generation/scale_curves.py +0 -6
- reflectorch/data_generation/smearing.py +3 -2
- reflectorch/data_generation/utils.py +0 -6
- reflectorch/extensions/__init__.py +0 -6
- reflectorch/extensions/jupyter/__init__.py +0 -6
- reflectorch/extensions/jupyter/callbacks.py +0 -6
- reflectorch/extensions/matplotlib/__init__.py +0 -6
- reflectorch/extensions/matplotlib/losses.py +0 -6
- reflectorch/inference/__init__.py +2 -0
- reflectorch/inference/inference_model.py +9 -6
- reflectorch/inference/query_matcher.py +82 -0
- reflectorch/ml/__init__.py +0 -7
- reflectorch/ml/basic_trainer.py +0 -6
- reflectorch/ml/callbacks.py +0 -6
- reflectorch/ml/loggers.py +0 -7
- reflectorch/ml/schedulers.py +0 -6
- reflectorch/ml/trainers.py +1 -34
- reflectorch/ml/utils.py +0 -7
- reflectorch/models/__init__.py +0 -7
- reflectorch/models/encoders/__init__.py +0 -8
- reflectorch/models/encoders/conv_encoder.py +0 -6
- reflectorch/models/encoders/conv_res_net.py +1 -5
- reflectorch/models/networks/__init__.py +0 -6
- reflectorch/paths.py +0 -6
- reflectorch/runs/__init__.py +2 -6
- reflectorch/runs/config.py +0 -6
- reflectorch/runs/slurm_utils.py +0 -6
- reflectorch/runs/train.py +0 -6
- reflectorch/runs/utils.py +82 -14
- reflectorch/utils.py +0 -6
- {reflectorch-1.0.1.dist-info → reflectorch-1.2.0.dist-info}/METADATA +15 -10
- reflectorch-1.2.0.dist-info/RECORD +83 -0
- {reflectorch-1.0.1.dist-info → reflectorch-1.2.0.dist-info}/WHEEL +1 -1
- reflectorch/models/encoders/transformers.py +0 -56
- reflectorch-1.0.1.dist-info/RECORD +0 -83
- {reflectorch-1.0.1.dist-info → reflectorch-1.2.0.dist-info}/LICENSE.txt +0 -0
- {reflectorch-1.0.1.dist-info → reflectorch-1.2.0.dist-info}/top_level.txt +0 -0
reflectorch/__init__.py
CHANGED
|
@@ -1,9 +1,3 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the GPL license found in the
|
|
5
|
-
# LICENSE file in the root directory of this source tree.
|
|
6
|
-
|
|
7
1
|
from reflectorch.data_generation import *
|
|
8
2
|
from reflectorch.ml import *
|
|
9
3
|
from reflectorch.models import *
|
|
@@ -1,9 +1,3 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the GPL license found in the
|
|
5
|
-
# LICENSE file in the root directory of this source tree.
|
|
6
|
-
|
|
7
1
|
from reflectorch.data_generation.dataset import BasicDataset, BATCH_DATA_TYPE
|
|
8
2
|
from reflectorch.data_generation.priors import (
|
|
9
3
|
Params,
|
|
@@ -1,9 +1,3 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the GPL license found in the
|
|
5
|
-
# LICENSE file in the root directory of this source tree.
|
|
6
|
-
|
|
7
1
|
from reflectorch.data_generation.priors.params import Params
|
|
8
2
|
from reflectorch.data_generation.priors.base import PriorSampler
|
|
9
3
|
from reflectorch.data_generation.priors.no_constraints import BasicPriorSampler
|
|
@@ -24,21 +24,25 @@ def reflectivity(
|
|
|
24
24
|
log: bool = False,
|
|
25
25
|
abeles_func=None,
|
|
26
26
|
):
|
|
27
|
-
"""Function which computes the reflectivity from thin film parameters
|
|
27
|
+
"""Function which computes the reflectivity curves from thin film parameters.
|
|
28
|
+
By default it uses the fast implementation of the Abeles matrix formalism.
|
|
28
29
|
|
|
29
30
|
Args:
|
|
30
|
-
q (Tensor):
|
|
31
|
-
thickness (Tensor): the layer thicknesses
|
|
32
|
-
roughness (Tensor): the interlayer roughnesses
|
|
33
|
-
sld (Tensor): the SLDs
|
|
34
|
-
|
|
31
|
+
q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
|
|
32
|
+
thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
|
|
33
|
+
roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
|
|
34
|
+
sld (Tensor): tensors containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
|
|
35
|
+
It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
|
|
36
|
+
dq (Tensor, optional): tensor of resolutions used for curve smearing with shape [batch_size, 1].
|
|
37
|
+
Either dq if ``constant_dq`` is ``True`` or dq/q if ``constant_dq`` is ``False``. Defaults to None.
|
|
35
38
|
gauss_num (int, optional): the number of gaussians for curve smearing. Defaults to 51.
|
|
36
|
-
constant_dq (bool, optional):
|
|
39
|
+
constant_dq (bool, optional): if ``True`` the smearing is constant (constant dq at each point in the curve)
|
|
40
|
+
otherwise the smearing is linear (constant dq/q at each point in the curve). Defaults to True.
|
|
37
41
|
log (bool, optional): if True the base 10 logarithm of the reflectivity curves is returned. Defaults to False.
|
|
38
|
-
abeles_func (Callable, optional): a function implementing the simulation of the reflectivity curves, if different than the default implementation. Defaults to None.
|
|
42
|
+
abeles_func (Callable, optional): a function implementing the simulation of the reflectivity curves, if different than the default Abeles matrix implementation. Defaults to None.
|
|
39
43
|
|
|
40
44
|
Returns:
|
|
41
|
-
Tensor: the
|
|
45
|
+
Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
|
|
42
46
|
"""
|
|
43
47
|
abeles_func = abeles_func or abeles
|
|
44
48
|
q = torch.atleast_2d(q)
|
|
@@ -14,13 +14,14 @@ def abeles(
|
|
|
14
14
|
"""Simulates reflectivity curves for SLD profiles with box model parameterization using the Abeles matrix method
|
|
15
15
|
|
|
16
16
|
Args:
|
|
17
|
-
q (Tensor): q values
|
|
18
|
-
thickness (Tensor): layer thicknesses
|
|
19
|
-
roughness (Tensor): interlayer roughnesses
|
|
20
|
-
sld (Tensor): layer SLDs
|
|
17
|
+
q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
|
|
18
|
+
thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
|
|
19
|
+
roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
|
|
20
|
+
sld (Tensor): tensors containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
|
|
21
|
+
It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
|
|
21
22
|
|
|
22
23
|
Returns:
|
|
23
|
-
Tensor: simulated reflectivity curves
|
|
24
|
+
Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
|
|
24
25
|
"""
|
|
25
26
|
c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
|
|
26
27
|
|
|
@@ -13,6 +13,20 @@ def kinematical_approximation(
|
|
|
13
13
|
apply_fresnel: bool = True,
|
|
14
14
|
log: bool = False,
|
|
15
15
|
):
|
|
16
|
+
"""Simulates reflectivity curves for SLD profiles with box model parameterization using the kinematical approximation
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
|
|
20
|
+
thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
|
|
21
|
+
roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
|
|
22
|
+
sld (Tensor): tensors containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
|
|
23
|
+
It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
|
|
24
|
+
apply_fresnel (bool, optional): whether to use the Fresnel coefficient in the computation. Defaults to ``True``.
|
|
25
|
+
log (bool, optional): if True the base 10 logarithm of the reflectivity curves is returned. Defaults to ``False``.
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
|
|
29
|
+
"""
|
|
16
30
|
c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
|
|
17
31
|
|
|
18
32
|
batch_size, num_layers = thickness.shape
|
|
@@ -12,6 +12,19 @@ def abeles_memory_eff(
|
|
|
12
12
|
roughness: Tensor,
|
|
13
13
|
sld: Tensor,
|
|
14
14
|
):
|
|
15
|
+
"""Simulates reflectivity curves for SLD profiles with box model parameterization using a memory-efficient implementation the Abeles matrix method.
|
|
16
|
+
It is computationally slower compared to the implementation in the 'abeles' function.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
|
|
20
|
+
thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
|
|
21
|
+
roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
|
|
22
|
+
sld (Tensor): tensors containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
|
|
23
|
+
It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
|
|
24
|
+
|
|
25
|
+
Returns:
|
|
26
|
+
Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
|
|
27
|
+
"""
|
|
15
28
|
c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
|
|
16
29
|
|
|
17
30
|
batch_size, num_layers = thickness.shape
|
|
@@ -68,8 +68,8 @@ def _get_q_axes_for_linear_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51
|
|
|
68
68
|
lowq = torch.clamp_min_(q.min(1).values, 1e-6)
|
|
69
69
|
highq = q.max(1).values
|
|
70
70
|
|
|
71
|
-
start = torch.log10(lowq) - 6 * resolutions / _FWHM
|
|
72
|
-
end = torch.log10(highq * (1 + 6 * resolutions / _FWHM))
|
|
71
|
+
start = torch.log10(lowq)[:, None] - 6 * resolutions / _FWHM
|
|
72
|
+
end = torch.log10(highq[:, None] * (1 + 6 * resolutions / _FWHM))
|
|
73
73
|
|
|
74
74
|
interpnums = torch.abs(
|
|
75
75
|
(torch.abs(end - start)) / (1.7 * resolutions / _FWHM / gaussgpoint)
|
|
@@ -9,8 +9,9 @@ class Smearing(object):
|
|
|
9
9
|
The intensity at a q point will be the average of the intensities of neighbouring q points, weighted by a gaussian profile.
|
|
10
10
|
|
|
11
11
|
Args:
|
|
12
|
-
sigma_range (tuple, optional): the range for sampling the
|
|
13
|
-
constant_dq (bool, optional):
|
|
12
|
+
sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (1e-4, 5e-3).
|
|
13
|
+
constant_dq (bool, optional): if ``True`` the smearing is constant (the resolution is given by the constant dq at each point in the curve)
|
|
14
|
+
otherwise the smearing is linear (the resolution is given by the constant dq/q at each point in the curve). Defaults to True.
|
|
14
15
|
gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
|
|
15
16
|
share_smeared (float, optional): the share of curves in the batch for which the resolution smearing is applied. Defaults to 0.2.
|
|
16
17
|
"""
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from reflectorch.inference.inference_model import InferenceModel, EasyInferenceModel
|
|
2
|
+
from reflectorch.inference.query_matcher import HuggingfaceQueryMatcher
|
|
2
3
|
from reflectorch.inference.multilayer_inference_model import MultilayerInferenceModel
|
|
3
4
|
from reflectorch.inference.preprocess_exp import (
|
|
4
5
|
StandardPreprocessing,
|
|
@@ -13,6 +14,7 @@ __all__ = [
|
|
|
13
14
|
"InferenceModel",
|
|
14
15
|
"EasyInferenceModel",
|
|
15
16
|
"MultilayerInferenceModel",
|
|
17
|
+
"HuggingfaceQueryMatcher",
|
|
16
18
|
"StandardPreprocessing",
|
|
17
19
|
"standard_preprocessing",
|
|
18
20
|
"ReflGradientFit",
|
|
@@ -35,15 +35,17 @@ class EasyInferenceModel(object):
|
|
|
35
35
|
config_name (str, optional): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension). Defaults to None.
|
|
36
36
|
model_name (str, optional): the name of the file containing the weights of the model (either with or without the '.pt' extension), only required if different than: `'model_' + config_name + '.pt'`. Defaults to None
|
|
37
37
|
root_dir (str, optional): path to root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR). Defaults to None.
|
|
38
|
+
weights_format (str, optional): format (extension) of the weights file, either 'pt' or 'safetensors'. Defaults to 'safetensors'.
|
|
38
39
|
repo_id (str, optional): the id of the Huggingface repository from which the configuration files and model weights should be downloaded automatically if not found locally (in the 'configs' and 'saved_models' subdirectories of the root directory). Defaults to 'valentinsingularity/reflectivity'.
|
|
39
40
|
trainer (PointEstimatorTrainer, optional): if provided, this trainer instance is used directly instead of being initialized from the configuration file. Defaults to None.
|
|
40
41
|
device (str, optional): the Pytorch device ('cuda' or 'cpu'). Defaults to 'cuda'.
|
|
41
42
|
"""
|
|
42
|
-
def __init__(self, config_name: str = None, model_name: str = None, root_dir:str = None,
|
|
43
|
-
trainer: PointEstimatorTrainer = None, device='cuda'):
|
|
43
|
+
def __init__(self, config_name: str = None, model_name: str = None, root_dir:str = None, weights_format: str = 'safetensors',
|
|
44
|
+
repo_id: str = 'valentinsingularity/reflectivity', trainer: PointEstimatorTrainer = None, device='cuda'):
|
|
44
45
|
self.config_name = config_name
|
|
45
46
|
self.model_name = model_name
|
|
46
47
|
self.root_dir = root_dir
|
|
48
|
+
self.weights_format = weights_format
|
|
47
49
|
self.repo_id = repo_id
|
|
48
50
|
self.trainer = trainer
|
|
49
51
|
self.device = device
|
|
@@ -58,7 +60,7 @@ class EasyInferenceModel(object):
|
|
|
58
60
|
|
|
59
61
|
Args:
|
|
60
62
|
config_name (str): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension).
|
|
61
|
-
model_name (str): the name of the file containing the weights of the model (either with or without the '.pt' extension), only required if different than: `'model_' + config_name +
|
|
63
|
+
model_name (str): the name of the file containing the weights of the model (either with or without the '.pt' or '.safetensors' extension), only required if different than: `'model_' + config_name + extension`.
|
|
62
64
|
root_dir (str): path to root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR).
|
|
63
65
|
"""
|
|
64
66
|
if self.config_name == config_name and self.trainer is not None:
|
|
@@ -72,9 +74,10 @@ class EasyInferenceModel(object):
|
|
|
72
74
|
self.config_name = config_name
|
|
73
75
|
|
|
74
76
|
self.config_dir = Path(root_dir) / 'configs' if root_dir else CONFIG_DIR
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
77
|
+
weights_extension = '.' + self.weights_format
|
|
78
|
+
self.model_name = model_name or 'model_' + config_name_no_extension + weights_extension
|
|
79
|
+
if not self.model_name.endswith(weights_extension):
|
|
80
|
+
self.model_name += weights_extension
|
|
78
81
|
self.model_dir = Path(root_dir) / 'saved_models' if root_dir else SAVED_MODELS_DIR
|
|
79
82
|
|
|
80
83
|
config_path = Path(self.config_dir) / self.config_name
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import tempfile
|
|
3
|
+
import yaml
|
|
4
|
+
from huggingface_hub import hf_hub_download, list_repo_files
|
|
5
|
+
|
|
6
|
+
class HuggingfaceQueryMatcher:
    """Downloads the available configuration files to a temporary directory and provides functionality
    for filtering those configuration files which match user-specified queries.

    The temporary directory (exposed as ``self.cache['temp_dir']``) is removed when the matcher is
    garbage collected, when the interpreter exits, or when the cache is renewed.

    Args:
        repo_id (str): The Hugging Face repository ID.
        config_dir (str): Directory within the repo where YAML files are stored.
    """
    def __init__(self, repo_id='valentinsingularity/reflectivity', config_dir='configs'):
        self.repo_id = repo_id
        self.config_dir = config_dir
        self.cache = {
            'parsed_configs': None,
            'temp_dir': None
        }
        self._renew_cache()

    @staticmethod
    def _select_config_files(repo_files, config_dir):
        """Return the '.yaml' files from ``repo_files`` that are located inside ``config_dir``."""
        # Compare against "<config_dir>/" so that e.g. "configs_old/a.yaml" or "configs.yaml"
        # are not mistaken for files inside "configs". An empty config_dir selects the whole repo.
        prefix = config_dir.rstrip('/') + '/' if config_dir else ''
        return [file for file in repo_files if file.startswith(prefix) and file.endswith('.yaml')]

    def _renew_cache(self):
        """Download all configuration files into a fresh temporary directory and parse them.

        Populates ``self.cache['parsed_configs']`` (base file name -> parsed YAML content) and
        ``self.cache['temp_dir']`` (path of the temporary directory holding the downloaded files).
        """
        import shutil
        import weakref

        # Remove the directory created by a previous call instead of leaking it.
        previous_cleanup = getattr(self, '_temp_dir_cleanup', None)
        if previous_cleanup is not None:
            previous_cleanup()

        temp_dir = tempfile.mkdtemp()
        # mkdtemp never deletes its directory; tie its removal to this object's lifetime
        # (also runs at interpreter exit).
        self._temp_dir_cleanup = weakref.finalize(self, shutil.rmtree, temp_dir, ignore_errors=True)
        print(f"Temporary directory created at: {temp_dir}")

        repo_files = list_repo_files(self.repo_id, repo_type='model')
        config_files = self._select_config_files(repo_files, self.config_dir)

        parsed_configs = {}
        for file in config_files:
            file_path = hf_hub_download(repo_id=self.repo_id, filename=file, local_dir=temp_dir, repo_type='model')
            with open(file_path, 'r', encoding='utf-8') as f:
                # safe_load: downloaded YAML must never construct arbitrary Python objects.
                config_data = yaml.safe_load(f)
            # Keyed by base name: files with the same name in different subdirectories overwrite each other.
            parsed_configs[os.path.basename(file_path)] = config_data

        self.cache['parsed_configs'] = parsed_configs
        self.cache['temp_dir'] = temp_dir

    def get_matching_configs(self, query):
        """Retrieves the configuration files that match the user-specified query.

        Args:
            query (dict): Dictionary of key-value pairs to filter configurations, e.g. ``query = {'dset.prior_sampler.kwargs.max_num_layers': 3, 'dset.prior_sampler.kwargs.param_ranges.slds': [0., 100.]}``.
                For keys containing the ``param_ranges`` subkey a configuration is selected if the value of the query (i.e. desired parameter range)
                is a subrange of the parameter range in the configuration, in all other cases the values must match exactly.

        Returns:
            list: List of file names that match the query.
        """
        return [
            file_name
            for file_name, config_data in self.cache['parsed_configs'].items()
            if self.matches_query(config_data, query)
        ]

    def matches_query(self, config_data, query):
        """Check whether a parsed configuration satisfies every condition of ``query``.

        Args:
            config_data (dict): the parsed content of one configuration file.
            query (dict): mapping of dot-separated key paths to required values (see ``get_matching_configs``).

        Returns:
            bool: ``True`` if all conditions hold, ``False`` otherwise.
        """
        for q_key, q_value in query.items():
            keys = q_key.split('.')
            value = self.deep_get(config_data, keys)
            if 'param_ranges' in keys:
                # A missing or malformed range in the config cannot contain the requested range.
                if not isinstance(value, (list, tuple)) or len(value) < 2:
                    return False
                # The requested range must lie within the configured [min, max] range.
                if q_value[0] < value[0] or q_value[1] > value[1]:
                    return False
            elif value != q_value:
                return False

        return True

    def deep_get(self, d, keys, default=None):
        """Follow a sequence of keys through nested dictionaries.

        Args:
            d: the (possibly nested) dictionary to look into.
            keys (Sequence[str]): the keys to follow, outermost first.
            default: value returned when the path cannot be followed. Defaults to None.

        Returns:
            The value found at the end of the key path, or ``default`` if a key is missing
            or an intermediate value is not a dictionary.
        """
        for key in keys:
            if not isinstance(d, dict) or key not in d:
                return default
            d = d[key]

        return d
|
reflectorch/ml/__init__.py
CHANGED
|
@@ -1,9 +1,3 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the GPL license found in the
|
|
5
|
-
# LICENSE file in the root directory of this source tree.
|
|
6
|
-
|
|
7
1
|
from reflectorch.ml.basic_trainer import *
|
|
8
2
|
from reflectorch.ml.callbacks import *
|
|
9
3
|
from reflectorch.ml.trainers import *
|
|
@@ -32,6 +26,5 @@ __all__ = [
|
|
|
32
26
|
'MultilayerDataLoader',
|
|
33
27
|
'RealTimeSimTrainer',
|
|
34
28
|
'DenoisingAETrainer',
|
|
35
|
-
'VAETrainer',
|
|
36
29
|
'PointEstimatorTrainer',
|
|
37
30
|
]
|
reflectorch/ml/basic_trainer.py
CHANGED
|
@@ -1,9 +1,3 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the GPL license found in the
|
|
5
|
-
# LICENSE file in the root directory of this source tree.
|
|
6
|
-
|
|
7
1
|
from typing import Optional, Tuple, Iterable, Any, Union, Type
|
|
8
2
|
from collections import defaultdict
|
|
9
3
|
|
reflectorch/ml/callbacks.py
CHANGED
reflectorch/ml/loggers.py
CHANGED
reflectorch/ml/schedulers.py
CHANGED
reflectorch/ml/trainers.py
CHANGED
|
@@ -1,9 +1,3 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the GPL license found in the
|
|
5
|
-
# LICENSE file in the root directory of this source tree.
|
|
6
|
-
|
|
7
1
|
import numpy as np
|
|
8
2
|
import torch
|
|
9
3
|
import torch.nn.functional as F
|
|
@@ -16,7 +10,6 @@ from reflectorch.ml.dataloaders import ReflectivityDataLoader
|
|
|
16
10
|
__all__ = [
|
|
17
11
|
'RealTimeSimTrainer',
|
|
18
12
|
'DenoisingAETrainer',
|
|
19
|
-
'VAETrainer',
|
|
20
13
|
'PointEstimatorTrainer',
|
|
21
14
|
]
|
|
22
15
|
|
|
@@ -97,30 +90,4 @@ class DenoisingAETrainer(RealTimeSimTrainer):
|
|
|
97
90
|
restored_curves = self.model(scaled_noisy_curves)
|
|
98
91
|
loss = self.criterion(scaled_curves, restored_curves)
|
|
99
92
|
return {'loss': loss}
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
class VAETrainer(DenoisingAETrainer):
|
|
103
|
-
"""Trainer which can be used for training a denoising autoencoder model. Overrides _get_batch and get_loss_dict methods """
|
|
104
|
-
def init(self):
|
|
105
|
-
self.loader.calc_denoised_curves = True
|
|
106
|
-
self.freebits = 0.05
|
|
107
|
-
|
|
108
|
-
def calc_kl(self, z_mu, z_logvar):
|
|
109
|
-
return 0.5*(z_mu**2 + torch.exp(z_logvar) - 1 - z_logvar)
|
|
110
|
-
|
|
111
|
-
def gaussian_log_prob(self, z, mu, logvar):
|
|
112
|
-
return -0.5*(np.log(2*np.pi) + logvar + (z-mu)**2/torch.exp(logvar))
|
|
113
|
-
|
|
114
|
-
def get_loss_dict(self, batch_data):
|
|
115
|
-
"""returns the reconstruction loss of the autoencoder"""
|
|
116
|
-
scaled_noisy_curves, scaled_curves = batch_data
|
|
117
|
-
_, (z_mu, z_logvar, restored_curves_mu, restored_curves_logvar) = self.model(scaled_noisy_curves)
|
|
118
|
-
|
|
119
|
-
l_rec = -torch.mean(self.gaussian_log_prob(scaled_curves, restored_curves_mu, restored_curves_logvar), dim=-1)
|
|
120
|
-
l_kl = torch.mean(F.relu(self.calc_kl(z_mu, z_logvar) - self.freebits*np.log(2)) + self.freebits*np.log(2), dim=-1)
|
|
121
|
-
loss = torch.mean(l_rec + l_kl)/np.log(2)
|
|
122
|
-
|
|
123
|
-
l_rec = torch.mean(l_rec)
|
|
124
|
-
l_kl = torch.mean(l_kl)
|
|
125
|
-
|
|
126
|
-
return {'loss': loss}
|
|
93
|
+
|
reflectorch/ml/utils.py
CHANGED