reflectorch 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic; see the package registry page for more details.

Files changed (57)
  1. reflectorch/__init__.py +0 -6
  2. reflectorch/data_generation/__init__.py +0 -6
  3. reflectorch/data_generation/dataset.py +0 -6
  4. reflectorch/data_generation/likelihoods.py +0 -6
  5. reflectorch/data_generation/noise.py +0 -6
  6. reflectorch/data_generation/priors/__init__.py +0 -6
  7. reflectorch/data_generation/priors/base.py +0 -6
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +0 -6
  9. reflectorch/data_generation/priors/independent_priors.py +0 -6
  10. reflectorch/data_generation/priors/multilayer_structures.py +0 -6
  11. reflectorch/data_generation/priors/no_constraints.py +0 -6
  12. reflectorch/data_generation/priors/parametric_subpriors.py +0 -6
  13. reflectorch/data_generation/priors/params.py +0 -6
  14. reflectorch/data_generation/priors/subprior_sampler.py +0 -6
  15. reflectorch/data_generation/priors/utils.py +0 -6
  16. reflectorch/data_generation/process_data.py +0 -6
  17. reflectorch/data_generation/q_generator.py +0 -6
  18. reflectorch/data_generation/reflectivity/__init__.py +13 -9
  19. reflectorch/data_generation/reflectivity/abeles.py +6 -5
  20. reflectorch/data_generation/reflectivity/kinematical.py +14 -0
  21. reflectorch/data_generation/reflectivity/memory_eff.py +13 -0
  22. reflectorch/data_generation/reflectivity/smearing.py +2 -2
  23. reflectorch/data_generation/scale_curves.py +0 -6
  24. reflectorch/data_generation/smearing.py +3 -2
  25. reflectorch/data_generation/utils.py +0 -6
  26. reflectorch/extensions/__init__.py +0 -6
  27. reflectorch/extensions/jupyter/__init__.py +0 -6
  28. reflectorch/extensions/jupyter/callbacks.py +0 -6
  29. reflectorch/extensions/matplotlib/__init__.py +0 -6
  30. reflectorch/extensions/matplotlib/losses.py +0 -6
  31. reflectorch/inference/inference_model.py +9 -6
  32. reflectorch/ml/__init__.py +0 -7
  33. reflectorch/ml/basic_trainer.py +0 -6
  34. reflectorch/ml/callbacks.py +0 -6
  35. reflectorch/ml/loggers.py +0 -7
  36. reflectorch/ml/schedulers.py +0 -6
  37. reflectorch/ml/trainers.py +1 -34
  38. reflectorch/ml/utils.py +0 -7
  39. reflectorch/models/__init__.py +0 -7
  40. reflectorch/models/encoders/__init__.py +0 -8
  41. reflectorch/models/encoders/conv_encoder.py +0 -6
  42. reflectorch/models/encoders/conv_res_net.py +1 -5
  43. reflectorch/models/networks/__init__.py +0 -6
  44. reflectorch/paths.py +0 -6
  45. reflectorch/runs/__init__.py +2 -6
  46. reflectorch/runs/config.py +0 -6
  47. reflectorch/runs/slurm_utils.py +0 -6
  48. reflectorch/runs/train.py +0 -6
  49. reflectorch/runs/utils.py +77 -8
  50. reflectorch/utils.py +0 -6
  51. {reflectorch-1.1.0.dist-info → reflectorch-1.2.0.dist-info}/METADATA +2 -1
  52. reflectorch-1.2.0.dist-info/RECORD +83 -0
  53. {reflectorch-1.1.0.dist-info → reflectorch-1.2.0.dist-info}/WHEEL +1 -1
  54. reflectorch/models/encoders/transformers.py +0 -56
  55. reflectorch-1.1.0.dist-info/RECORD +0 -84
  56. {reflectorch-1.1.0.dist-info → reflectorch-1.2.0.dist-info}/LICENSE.txt +0 -0
  57. {reflectorch-1.1.0.dist-info → reflectorch-1.2.0.dist-info}/top_level.txt +0 -0
reflectorch/__init__.py CHANGED
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from reflectorch.data_generation import *
8
2
  from reflectorch.ml import *
9
3
  from reflectorch.models import *
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from reflectorch.data_generation.dataset import BasicDataset, BATCH_DATA_TYPE
8
2
  from reflectorch.data_generation.priors import (
9
3
  Params,
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Dict, Union
8
2
  import warnings
9
3
 
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Union, Tuple
8
2
 
9
3
  import torch
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Union, Tuple
8
2
  from math import log10
9
3
 
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from reflectorch.data_generation.priors.params import Params
8
2
  from reflectorch.data_generation.priors.base import PriorSampler
9
3
  from reflectorch.data_generation.priors.no_constraints import BasicPriorSampler
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from torch import Tensor
8
2
 
9
3
  from reflectorch.data_generation.priors.params import Params
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Tuple, Union, List
8
2
 
9
3
  import torch
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Union, Tuple
8
2
  from math import log
9
3
 
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Tuple, Dict
8
2
 
9
3
  import numpy as np
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  import logging
8
2
  from functools import lru_cache
9
3
  from typing import Tuple
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Tuple, Dict, Type, List
8
2
 
9
3
  import torch
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import List, Tuple
8
2
 
9
3
  import torch
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from functools import lru_cache
8
2
  from typing import Tuple
9
3
 
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Tuple
8
2
 
9
3
  import torch
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Any
8
2
 
9
3
  __all__ = [
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Tuple, Union
8
2
 
9
3
  import numpy as np
@@ -24,21 +24,25 @@ def reflectivity(
24
24
  log: bool = False,
25
25
  abeles_func=None,
26
26
  ):
27
- """Function which computes the reflectivity from thin film parameters using the Abeles matrix formalism
27
+ """Function which computes the reflectivity curves from thin film parameters.
28
+ By default it uses the fast implementation of the Abeles matrix formalism.
28
29
 
29
30
  Args:
30
- q (Tensor): the momentum transfer (q) values
31
- thickness (Tensor): the layer thicknesses
32
- roughness (Tensor): the interlayer roughnesses
33
- sld (Tensor): the SLDs of the layers
34
- dq (Tensor, optional): the resolution for curve smearing. Defaults to None.
31
+ q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
32
+ thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
33
+ roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
34
+ sld (Tensor): tensors containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
35
+ It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
36
+ dq (Tensor, optional): tensor of resolutions used for curve smearing with shape [batch_size, 1].
37
+ Either dq if ``constant_dq`` is ``True`` or dq/q if ``constant_dq`` is ``False``. Defaults to None.
35
38
  gauss_num (int, optional): the number of gaussians for curve smearing. Defaults to 51.
36
- constant_dq (bool, optional): whether the smearing is constant. Defaults to True.
39
+ constant_dq (bool, optional): if ``True`` the smearing is constant (constant dq at each point in the curve)
40
+ otherwise the smearing is linear (constant dq/q at each point in the curve). Defaults to True.
37
41
  log (bool, optional): if True the base 10 logarithm of the reflectivity curves is returned. Defaults to False.
38
- abeles_func (Callable, optional): a function implementing the simulation of the reflectivity curves, if different than the default implementation. Defaults to None.
42
+ abeles_func (Callable, optional): a function implementing the simulation of the reflectivity curves, if different than the default Abeles matrix implementation. Defaults to None.
39
43
 
40
44
  Returns:
41
- Tensor: the computed reflectivity curves
45
+ Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
42
46
  """
43
47
  abeles_func = abeles_func or abeles
44
48
  q = torch.atleast_2d(q)
@@ -14,13 +14,14 @@ def abeles(
14
14
  """Simulates reflectivity curves for SLD profiles with box model parameterization using the Abeles matrix method
15
15
 
16
16
  Args:
17
- q (Tensor): q values
18
- thickness (Tensor): layer thicknesses
19
- roughness (Tensor): interlayer roughnesses
20
- sld (Tensor): layer SLDs
17
+ q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
18
+ thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
19
+ roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
20
+ sld (Tensor): tensors containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
21
+ It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
21
22
 
22
23
  Returns:
23
- Tensor: simulated reflectivity curves
24
+ Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
24
25
  """
25
26
  c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
26
27
 
@@ -13,6 +13,20 @@ def kinematical_approximation(
13
13
  apply_fresnel: bool = True,
14
14
  log: bool = False,
15
15
  ):
16
+ """Simulates reflectivity curves for SLD profiles with box model parameterization using the kinematical approximation
17
+
18
+ Args:
19
+ q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
20
+ thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
21
+ roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
22
+ sld (Tensor): tensors containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
23
+ It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
24
+ apply_fresnel (bool, optional): whether to use the Fresnel coefficient in the computation. Defaults to ``True``.
25
+ log (bool, optional): if True the base 10 logarithm of the reflectivity curves is returned. Defaults to ``False``.
26
+
27
+ Returns:
28
+ Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
29
+ """
16
30
  c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
17
31
 
18
32
  batch_size, num_layers = thickness.shape
@@ -12,6 +12,19 @@ def abeles_memory_eff(
12
12
  roughness: Tensor,
13
13
  sld: Tensor,
14
14
  ):
15
+ """Simulates reflectivity curves for SLD profiles with box model parameterization using a memory-efficient implementation the Abeles matrix method.
16
+ It is computationally slower compared to the implementation in the 'abeles' function.
17
+
18
+ Args:
19
+ q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
20
+ thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
21
+ roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
22
+ sld (Tensor): tensors containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
23
+ It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
24
+
25
+ Returns:
26
+ Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
27
+ """
15
28
  c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
16
29
 
17
30
  batch_size, num_layers = thickness.shape
@@ -68,8 +68,8 @@ def _get_q_axes_for_linear_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51
68
68
  lowq = torch.clamp_min_(q.min(1).values, 1e-6)
69
69
  highq = q.max(1).values
70
70
 
71
- start = torch.log10(lowq) - 6 * resolutions / _FWHM
72
- end = torch.log10(highq * (1 + 6 * resolutions / _FWHM))
71
+ start = torch.log10(lowq)[:, None] - 6 * resolutions / _FWHM
72
+ end = torch.log10(highq[:, None] * (1 + 6 * resolutions / _FWHM))
73
73
 
74
74
  interpnums = torch.abs(
75
75
  (torch.abs(end - start)) / (1.7 * resolutions / _FWHM / gaussgpoint)
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from pathlib import Path
8
2
 
9
3
  import torch
@@ -9,8 +9,9 @@ class Smearing(object):
9
9
  The intensity at a q point will be the average of the intensities of neighbouring q points, weighted by a gaussian profile.
10
10
 
11
11
  Args:
12
- sigma_range (tuple, optional): the range for sampling the standard deviation of the gaussians. Defaults to (1e-4, 5e-3).
13
- constant_dq (bool, optional): whether the smearing is constant for each q point. Defaults to True.
12
+ sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (1e-4, 5e-3).
13
+ constant_dq (bool, optional): if ``True`` the smearing is constant (the resolution is given by the constant dq at each point in the curve)
14
+ otherwise the smearing is linear (the resolution is given by the constant dq/q at each point in the curve). Defaults to True.
14
15
  gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
15
16
  share_smeared (float, optional): the share of curves in the batch for which the resolution smearing is applied. Defaults to 0.2.
16
17
  """
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import List, Union
8
2
  from math import sqrt, pi, log10
9
3
 
@@ -1,6 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from .callbacks import JPlotLoss
8
2
 
9
3
 
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from IPython.display import clear_output
8
2
 
9
3
  from ...ml import TrainerCallback, Trainer
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from reflectorch.extensions.matplotlib.losses import plot_losses
8
2
 
9
3
  __all__ = [
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  import matplotlib.pyplot as plt
8
2
 
9
3
 
@@ -35,15 +35,17 @@ class EasyInferenceModel(object):
35
35
  config_name (str, optional): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension). Defaults to None.
36
36
  model_name (str, optional): the name of the file containing the weights of the model (either with or without the '.pt' extension), only required if different than: `'model_' + config_name + '.pt'`. Defaults to None
37
37
  root_dir (str, optional): path to root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR). Defaults to None.
38
+ weights_format (str, optional): format (extension) of the weights file, either 'pt' or 'safetensors'. Defaults to 'safetensors'.
38
39
  repo_id (str, optional): the id of the Huggingface repository from which the configuration files and model weights should be downloaded automatically if not found locally (in the 'configs' and 'saved_models' subdirectories of the root directory). Defaults to 'valentinsingularity/reflectivity'.
39
40
  trainer (PointEstimatorTrainer, optional): if provided, this trainer instance is used directly instead of being initialized from the configuration file. Defaults to None.
40
41
  device (str, optional): the Pytorch device ('cuda' or 'cpu'). Defaults to 'cuda'.
41
42
  """
42
- def __init__(self, config_name: str = None, model_name: str = None, root_dir:str = None, repo_id: str = 'valentinsingularity/reflectivity',
43
- trainer: PointEstimatorTrainer = None, device='cuda'):
43
+ def __init__(self, config_name: str = None, model_name: str = None, root_dir:str = None, weights_format: str = 'safetensors',
44
+ repo_id: str = 'valentinsingularity/reflectivity', trainer: PointEstimatorTrainer = None, device='cuda'):
44
45
  self.config_name = config_name
45
46
  self.model_name = model_name
46
47
  self.root_dir = root_dir
48
+ self.weights_format = weights_format
47
49
  self.repo_id = repo_id
48
50
  self.trainer = trainer
49
51
  self.device = device
@@ -58,7 +60,7 @@ class EasyInferenceModel(object):
58
60
 
59
61
  Args:
60
62
  config_name (str): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension).
61
- model_name (str): the name of the file containing the weights of the model (either with or without the '.pt' extension), only required if different than: `'model_' + config_name + '.pt'`.
63
+ model_name (str): the name of the file containing the weights of the model (either with or without the '.pt' or '.safetensors' extension), only required if different than: `'model_' + config_name + extension`.
62
64
  root_dir (str): path to root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR).
63
65
  """
64
66
  if self.config_name == config_name and self.trainer is not None:
@@ -72,9 +74,10 @@ class EasyInferenceModel(object):
72
74
  self.config_name = config_name
73
75
 
74
76
  self.config_dir = Path(root_dir) / 'configs' if root_dir else CONFIG_DIR
75
- self.model_name = model_name or 'model_' + config_name_no_extension + '.pt'
76
- if not self.model_name.endswith('.pt'):
77
- self.model_name += '.pt'
77
+ weights_extension = '.' + self.weights_format
78
+ self.model_name = model_name or 'model_' + config_name_no_extension + weights_extension
79
+ if not self.model_name.endswith(weights_extension):
80
+ self.model_name += weights_extension
78
81
  self.model_dir = Path(root_dir) / 'saved_models' if root_dir else SAVED_MODELS_DIR
79
82
 
80
83
  config_path = Path(self.config_dir) / self.config_name
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from reflectorch.ml.basic_trainer import *
8
2
  from reflectorch.ml.callbacks import *
9
3
  from reflectorch.ml.trainers import *
@@ -32,6 +26,5 @@ __all__ = [
32
26
  'MultilayerDataLoader',
33
27
  'RealTimeSimTrainer',
34
28
  'DenoisingAETrainer',
35
- 'VAETrainer',
36
29
  'PointEstimatorTrainer',
37
30
  ]
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Optional, Tuple, Iterable, Any, Union, Type
8
2
  from collections import defaultdict
9
3
 
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  import torch
8
2
 
9
3
  import numpy as np
reflectorch/ml/loggers.py CHANGED
@@ -1,10 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
-
8
1
  __all__ = [
9
2
  'Logger',
10
3
  'Loggers',
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from torch.optim import lr_scheduler
8
2
 
9
3
  import numpy as np
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  import numpy as np
8
2
  import torch
9
3
  import torch.nn.functional as F
@@ -16,7 +10,6 @@ from reflectorch.ml.dataloaders import ReflectivityDataLoader
16
10
  __all__ = [
17
11
  'RealTimeSimTrainer',
18
12
  'DenoisingAETrainer',
19
- 'VAETrainer',
20
13
  'PointEstimatorTrainer',
21
14
  ]
22
15
 
@@ -97,30 +90,4 @@ class DenoisingAETrainer(RealTimeSimTrainer):
97
90
  restored_curves = self.model(scaled_noisy_curves)
98
91
  loss = self.criterion(scaled_curves, restored_curves)
99
92
  return {'loss': loss}
100
-
101
-
102
- class VAETrainer(DenoisingAETrainer):
103
- """Trainer which can be used for training a denoising autoencoder model. Overrides _get_batch and get_loss_dict methods """
104
- def init(self):
105
- self.loader.calc_denoised_curves = True
106
- self.freebits = 0.05
107
-
108
- def calc_kl(self, z_mu, z_logvar):
109
- return 0.5*(z_mu**2 + torch.exp(z_logvar) - 1 - z_logvar)
110
-
111
- def gaussian_log_prob(self, z, mu, logvar):
112
- return -0.5*(np.log(2*np.pi) + logvar + (z-mu)**2/torch.exp(logvar))
113
-
114
- def get_loss_dict(self, batch_data):
115
- """returns the reconstruction loss of the autoencoder"""
116
- scaled_noisy_curves, scaled_curves = batch_data
117
- _, (z_mu, z_logvar, restored_curves_mu, restored_curves_logvar) = self.model(scaled_noisy_curves)
118
-
119
- l_rec = -torch.mean(self.gaussian_log_prob(scaled_curves, restored_curves_mu, restored_curves_logvar), dim=-1)
120
- l_kl = torch.mean(F.relu(self.calc_kl(z_mu, z_logvar) - self.freebits*np.log(2)) + self.freebits*np.log(2), dim=-1)
121
- loss = torch.mean(l_rec + l_kl)/np.log(2)
122
-
123
- l_rec = torch.mean(l_rec)
124
- l_kl = torch.mean(l_kl)
125
-
126
- return {'loss': loss}
93
+
reflectorch/ml/utils.py CHANGED
@@ -1,9 +1,2 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
-
8
1
  def is_divisor(num: int, div: int):
9
2
  return num and not num % div
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from reflectorch.models.encoders import *
8
2
  from reflectorch.models.networks import *
9
3
 
@@ -12,7 +6,6 @@ __all__ = [
12
6
  "ConvDecoder",
13
7
  "ConvAutoencoder",
14
8
  "ConvVAE",
15
- "TransformerEncoder",
16
9
  "FnoEncoder",
17
10
  "SpectralConv1d",
18
11
  "ConvResidualNet1D",
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from reflectorch.models.encoders.conv_encoder import (
8
2
  ConvEncoder,
9
3
  ConvDecoder,
@@ -11,12 +5,10 @@ from reflectorch.models.encoders.conv_encoder import (
11
5
  ConvVAE,
12
6
  )
13
7
  from reflectorch.models.encoders.fno import FnoEncoder, SpectralConv1d
14
- from reflectorch.models.encoders.transformers import TransformerEncoder
15
8
  from reflectorch.models.encoders.conv_res_net import ConvResidualNet1D
16
9
 
17
10
 
18
11
  __all__ = [
19
- "TransformerEncoder",
20
12
  "ConvEncoder",
21
13
  "ConvDecoder",
22
14
  "ConvAutoencoder",
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  import logging
8
2
  from pathlib import Path
9
3
 
@@ -1,8 +1,4 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
1
+
6
2
 
7
3
  from torch import nn
8
4
  from torch.nn import functional as F
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from reflectorch.models.networks.mlp_networks import (
8
2
  NetworkWithPriorsConvEmb,
9
3
  NetworkWithPriorsFnoEmb,
reflectorch/paths.py CHANGED
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Union
8
2
  from pathlib import Path
9
3
 
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from reflectorch.runs.train import (
8
2
  run_train,
9
3
  run_train_on_cluster,
@@ -17,6 +11,7 @@ from reflectorch.runs.utils import (
17
11
  get_callbacks_from_config,
18
12
  get_trainer_by_name,
19
13
  get_callbacks_by_name,
14
+ convert_pt_to_safetensors,
20
15
  )
21
16
 
22
17
  from reflectorch.runs.config import load_config
@@ -31,5 +26,6 @@ __all__ = [
31
26
  'get_callbacks_from_config',
32
27
  'get_trainer_by_name',
33
28
  'get_callbacks_by_name',
29
+ 'convert_pt_to_safetensors',
34
30
  'load_config',
35
31
  ]
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  import yaml
8
2
 
9
3
  from pathlib import Path
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Tuple, Union
8
2
  from pathlib import Path
9
3
  import subprocess
reflectorch/runs/train.py CHANGED
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  import click
8
2
 
9
3
  from reflectorch.runs.slurm_utils import save_sbatch_and_run
reflectorch/runs/utils.py CHANGED
@@ -2,6 +2,8 @@ from pathlib import Path
2
2
  from typing import Tuple
3
3
 
4
4
  import torch
5
+ import safetensors.torch
6
+ import os
5
7
 
6
8
  from reflectorch import *
7
9
  from reflectorch.runs.config import load_config
@@ -13,6 +15,7 @@ __all__ = [
13
15
  "get_callbacks_from_config",
14
16
  "get_trainer_by_name",
15
17
  "get_callbacks_by_name",
18
+ "convert_pt_to_safetensors",
16
19
  ]
17
20
 
18
21
 
@@ -214,15 +217,25 @@ def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weig
214
217
  model_name = f'model_{config_name}.pt'
215
218
  model_path = SAVED_MODELS_DIR / model_name
216
219
 
217
- try:
218
- state_dict = torch.load(model_path, map_location=inference_device)
219
- except Exception as err:
220
- raise RuntimeError(f'Could not load model from {model_path}') from err
221
-
222
- if 'model' in state_dict:
223
- trainer.model.load_state_dict(state_dict['model'])
220
+ if str(model_path).endswith('.pt'):
221
+ try:
222
+ state_dict = torch.load(model_path, map_location=inference_device)
223
+ except Exception as err:
224
+ raise RuntimeError(f'Could not load model from {model_path}') from err
225
+
226
+ if 'model' in state_dict:
227
+ trainer.model.load_state_dict(state_dict['model'])
228
+ else:
229
+ trainer.model.load_state_dict(state_dict)
230
+
231
+ elif str(model_path).endswith('.safetensors'):
232
+ try:
233
+ load_state_dict_safetensors(model=trainer.model, filename=model_path, device=inference_device)
234
+ except Exception as err:
235
+ raise RuntimeError(f'Could not load model from {model_path}') from err
236
+
224
237
  else:
225
- trainer.model.load_state_dict(state_dict)
238
+ raise RuntimeError('Weigths file with unknown extension')
226
239
 
227
240
  return trainer
228
241
 
@@ -297,3 +310,59 @@ def init_dset(config: dict):
297
310
  )
298
311
 
299
312
  return dset
313
+
314
+ def split_complex_tensors(state_dict):
315
+ new_state_dict = {}
316
+ for key, tensor in state_dict.items():
317
+ if tensor.is_complex():
318
+ new_state_dict[f"{key}_real"] = tensor.real.clone()
319
+ new_state_dict[f"{key}_imag"] = tensor.imag.clone()
320
+ else:
321
+ new_state_dict[key] = tensor
322
+ return new_state_dict
323
+
324
+ def recombine_complex_tensors(state_dict):
325
+ new_state_dict = {}
326
+ keys = list(state_dict.keys())
327
+ visited = set()
328
+
329
+ for key in keys:
330
+ if key.endswith('_real') or key.endswith('_imag'):
331
+ base_key = key[:-5]
332
+ new_state_dict[base_key] = torch.complex(state_dict[base_key + '_real'], state_dict[base_key + '_imag'])
333
+ visited.add(base_key + '_real')
334
+ visited.add(base_key + '_imag')
335
+ elif key not in visited:
336
+ new_state_dict[key] = state_dict[key]
337
+
338
+ return new_state_dict
339
+
340
+ def convert_pt_to_safetensors(input_dir):
341
+ """Creates '.safetensors' files for all the model state dictionaries inside '.pt' files in the specified directory.
342
+
343
+ Args:
344
+ input_dir (str): directory containing model weights
345
+ """
346
+ if not os.path.isdir(input_dir):
347
+ raise ValueError(f"Input directory {input_dir} does not exist")
348
+
349
+ for file_name in os.listdir(input_dir):
350
+ if file_name.endswith('.pt'):
351
+ pt_file_path = os.path.join(input_dir, file_name)
352
+ safetensors_file_path = os.path.join(input_dir, file_name[:-3] + '.safetensors')
353
+
354
+ if os.path.exists(safetensors_file_path):
355
+ print(f"Skipping {pt_file_path}, corresponding .safetensors file already exists.")
356
+ continue
357
+
358
+ print(f"Converting {pt_file_path} to .safetensors format.")
359
+ data_pt = torch.load(pt_file_path)
360
+ model_state_dict = data_pt["model"]
361
+ model_state_dict = split_complex_tensors(model_state_dict) #handle tensors with complex dtype which are not natively supported by safetensors
362
+
363
+ safetensors.torch.save_file(tensors=model_state_dict, filename=safetensors_file_path)
364
+
365
+ def load_state_dict_safetensors(model, filename, device):
366
+ state_dict = safetensors.torch.load_file(filename=filename, device=device)
367
+ state_dict = recombine_complex_tensors(state_dict)
368
+ model.load_state_dict(state_dict)
reflectorch/utils.py CHANGED
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  import numpy as np
8
2
  from numpy import ndarray
9
3
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: reflectorch
3
- Version: 1.1.0
3
+ Version: 1.2.0
4
4
  Summary: A Pytorch-based package for the analysis of reflectometry data
5
5
  Author-email: Vladimir Starostin <vladimir.starostin@uni-tuebingen.de>, Valentin Munteanu <valentin.munteanu@uni-tuebingen.de>
6
6
  Maintainer-email: Valentin Munteanu <valentin.munteanu@uni-tuebingen.de>, Vladimir Starostin <vladimir.starostin@uni-tuebingen.de>, Alexander Hinderhofer <alexander.hinderhofer@uni-tuebingen.de>
@@ -27,6 +27,7 @@ Requires-Dist: click
27
27
  Requires-Dist: matplotlib
28
28
  Requires-Dist: ipywidgets
29
29
  Requires-Dist: huggingface-hub
30
+ Requires-Dist: safetensors
30
31
  Provides-Extra: build
31
32
  Requires-Dist: build ; extra == 'build'
32
33
  Requires-Dist: twine ; extra == 'build'
@@ -0,0 +1,83 @@
1
+ reflectorch/__init__.py,sha256=N98TX0LJtHBT8Q8MbUedjVlHSQJXii0EMmwuhSKHvOQ,736
2
+ reflectorch/paths.py,sha256=Z_VRVknkqRn03ULShc8YCp0czqqMJ9w2CRuUi8e2OV0,814
3
+ reflectorch/test_config.py,sha256=1T7pMJ-WYLEu-4WtYMQxcJrqXvgdpvJ1yi2qINd0kNA,99
4
+ reflectorch/train.py,sha256=-c8ac1fpjrCiEwnAaXg_rcBNl1stO1V5p5afx_78xHs,87
5
+ reflectorch/train_on_cluster.py,sha256=aG3_g5_rzL8iL1tvtdY9ueJTo1f2Pn8lGJgudrrRknU,109
6
+ reflectorch/utils.py,sha256=ehG5dU02_WIFPIVGPsSjHl3Ji10MQ8-jsCwRkwUO7D4,1993
7
+ reflectorch/data_generation/__init__.py,sha256=TcF7kf16GCsQpp6cIXj-s4vb_gjrtv7tgXwLSeI8Xy8,3353
8
+ reflectorch/data_generation/dataset.py,sha256=wEHMIzA1XQGQkzkN8WQRNqG53lyzxIVc80nMJKs0mgI,7273
9
+ reflectorch/data_generation/likelihoods.py,sha256=gnqmsEfsZnCC6WuPcIe9rFl4SeiLv_LyNlTVM-5YTE8,2862
10
+ reflectorch/data_generation/noise.py,sha256=c4ytClId3t3T5FO8r0NqAVz-x_zYWAJ2VkpxcvssrVY,16159
11
+ reflectorch/data_generation/process_data.py,sha256=iBvWB4_X5TTIZTacsq9baBZPyGCRyI1G-J3l1BVMuD4,1180
12
+ reflectorch/data_generation/q_generator.py,sha256=Etlp_6Tj2it2nK8PQSWgQZBLh8_459P985GrZs9U1lE,8529
13
+ reflectorch/data_generation/scale_curves.py,sha256=hNZoGkA9FMdLUarp0fwj5nCDYvpy9NbSiO85_rCaNr8,4084
14
+ reflectorch/data_generation/smearing.py,sha256=MfHQ2sa-wPu5stO835m8i1-3NnJYvxS3qpzNs1jXWAE,3109
15
+ reflectorch/data_generation/utils.py,sha256=3Di6I5Ihy2rm8NMeCEvIVqH0IXFEX2bcxkHSc6y9WgU,5607
16
+ reflectorch/data_generation/priors/__init__.py,sha256=ZUaQUgNR44MQGYcPVJaSzzoE270710RqdKAPcaBNHWo,1941
17
+ reflectorch/data_generation/priors/base.py,sha256=JNa2A3F4IWaEwV7gyTQjVdCRevUI43UUpTOm1U3XA8k,1705
18
+ reflectorch/data_generation/priors/exp_subprior_sampler.py,sha256=hjHey_32HGgZ4ojOAhmmx8jcQPCyvcPRgxMlfthSdOo,11551
19
+ reflectorch/data_generation/priors/independent_priors.py,sha256=ZdFCW4Ea6cK9f0Pnk2F-F6JLOgZSMsc77ZWrbQl1gkE,7241
20
+ reflectorch/data_generation/priors/multilayer_models.py,sha256=lAf-HJPbIDnbD1ecDTvx03TfA4jLN5tTLbwaBCiYgWM,7763
21
+ reflectorch/data_generation/priors/multilayer_structures.py,sha256=b1LTkzMK_R2hgbAmif96zRRzZRV2V-dD7eSTaNalMU8,3693
22
+ reflectorch/data_generation/priors/no_constraints.py,sha256=qioTThJ17NixCaMIIf4dusPtvcK_uI2jOoysyXKnkZ4,7292
23
+ reflectorch/data_generation/priors/parametric_models.py,sha256=vaAGcm9Ky-coidNliE4R1YvoqCcbNwEwJVCwwmkKI4E,25139
24
+ reflectorch/data_generation/priors/parametric_subpriors.py,sha256=mN4MELBOQurNponknB_1n46fiNncXkI5PnnPL3hJee4,14352
25
+ reflectorch/data_generation/priors/params.py,sha256=JpH-LRxTztc0QP-3QwjLdu0MsVc1rSixtcscYGs_2Ew,8238
26
+ reflectorch/data_generation/priors/sampler_strategies.py,sha256=U-v5dXpFLJq939aQKvNl7n1Lih-y97K8WJCXBFRQiA0,13059
27
+ reflectorch/data_generation/priors/scaler_mixin.py,sha256=gI64v2KOZugSJWaLKASR34fn6qVFl-aoeVA1BR5yXNg,2648
28
+ reflectorch/data_generation/priors/subprior_sampler.py,sha256=TE8DOQhxVr69VmWSjwHyElpgVjOhZKBIPChFWNVYzRc,14769
29
+ reflectorch/data_generation/priors/utils.py,sha256=bmIZYHq95gG68kETfNH6ygR39oitUEJ0eCO_Fb8maH0,3836
30
+ reflectorch/data_generation/reflectivity/__init__.py,sha256=x8XivTg1ygz1Ky_N1eVvDfz2KS-9j9U2px0kQM3Bgf4,3145
31
+ reflectorch/data_generation/reflectivity/abeles.py,sha256=3NXGXaBiOyAiFfy3Za7WATNLhTpKrNzieYdCsmljvQ8,2625
32
+ reflectorch/data_generation/reflectivity/kinematical.py,sha256=QLX3erfSmEwP2n_x8gQlfaGr3pNAxcJ95Efx4Uu2CKk,2626
33
+ reflectorch/data_generation/reflectivity/memory_eff.py,sha256=iIufbdEJGv9nc-Sr51gFQHEYVdgmbrAD-F3ydNE30nU,4024
34
+ reflectorch/data_generation/reflectivity/numpy_implementations.py,sha256=QBzn4yVnOdlkHeeR-ZFPS115GnLdO9lMTGO2d3YhG9I,3177
35
+ reflectorch/data_generation/reflectivity/smearing.py,sha256=pc95Lig9NIWtHDCKnKbLp7G_kmL7_3YB6ZpDphgC2D8,4001
36
+ reflectorch/extensions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
+ reflectorch/extensions/jupyter/__init__.py,sha256=bs6RQQ8LRcKAJu2I9K8BZEMQSll3kLdirRvxLbgsY9w,72
38
+ reflectorch/extensions/jupyter/callbacks.py,sha256=K4Z8Mlm_Pa80YDsNOdMZLIJZyVrg09kBVzVdFHKgQFk,1074
39
+ reflectorch/extensions/matplotlib/__init__.py,sha256=ARmc228xbZUj4fWNq5fVxZ8PN444fjDNxNP4ytVCyfo,104
40
+ reflectorch/extensions/matplotlib/losses.py,sha256=H2vrjNqzmQ0BP4377uN4uhlH88aS5OAfA7pwRM2WHqA,686
41
+ reflectorch/inference/__init__.py,sha256=i0KNn83XN33mLrV7bpHdLd0SXxuGCKfbQcIoa247Uts,834
42
+ reflectorch/inference/inference_model.py,sha256=aVAZkTfVe2cweitcVVyVuqDsb6FkLbh4Wun4_K_MkTM,36599
43
+ reflectorch/inference/multilayer_fitter.py,sha256=0CxDpLOEp1terR4N39yFlxhvA8qAbHf_01NbmvYadck,5510
44
+ reflectorch/inference/multilayer_inference_model.py,sha256=hH_-dJGdMOox8GHXdM_nODXDlNgh_v449xW5FmklRdo,7575
45
+ reflectorch/inference/query_matcher.py,sha256=Dk49dW0XreeCjufzYBTKchfTdVbG6759ryV6I-wQL60,3387
46
+ reflectorch/inference/record_time.py,sha256=3er-aoR8Sd_Kc4qNwUmRqkEz4FYhVxdi1ARnBohybzM,1140
47
+ reflectorch/inference/sampler_solution.py,sha256=DeJM3EXEb6S5EiASj3mmdNI-Y07Cr5UzzA5oq-vEB-Q,2288
48
+ reflectorch/inference/scipy_fitter.py,sha256=339M33OdmfgOpifJGLYk4KVcnnNJrY6_aH7Lz6Vtt24,5404
49
+ reflectorch/inference/torch_fitter.py,sha256=j1NzkzLCmQ4H6LfIi82LsSBmIdunnWzm3kbGx-hqvDs,3391
50
+ reflectorch/inference/preprocess_exp/__init__.py,sha256=bR6H-xgBD96z4P9P1T2ECnWvalrimdMTfTNArIWPLy0,383
51
+ reflectorch/inference/preprocess_exp/attenuation.py,sha256=UKDfUjCKKMgAuEs5Ttyo0KEQmpvHZI52UgVflh7T81A,1518
52
+ reflectorch/inference/preprocess_exp/cut_with_q_ratio.py,sha256=CbtwIw7iJNkoVxqTHKzONBgGFwOsCUyFoTIQ8XMLTfY,1149
53
+ reflectorch/inference/preprocess_exp/footprint.py,sha256=xc409M5X-QW0Ce_6dEZdj8NkOY1vd0LaGpPQFxiOOR0,2625
54
+ reflectorch/inference/preprocess_exp/interpolation.py,sha256=o2v-mlfRYzeaaikeQVeY7EA7j-K42dMfT12oN3mk51k,694
55
+ reflectorch/inference/preprocess_exp/normalize.py,sha256=09v7nZdtw6SnW_67xFrPnqzOA5AtFBGarjK4Pfn4VoE,695
56
+ reflectorch/inference/preprocess_exp/preprocess.py,sha256=pyyq8fSvcm1bWAShzGHYnKOc55Rofh4FIw1AC7Smq-U,5111
57
+ reflectorch/ml/__init__.py,sha256=TJERkE_itNOH3GwtC-9Xv0rZ70qEkMukFLD1qXsMdOQ,730
58
+ reflectorch/ml/basic_trainer.py,sha256=2z_3Iy_9dEgOha8da69RjiMQ_89C2SoJM9omh0WO-ek,9647
59
+ reflectorch/ml/callbacks.py,sha256=blZNFubtRQkcx5sQrTNkARiPj75T8VdCn6CYJNN7hKg,2743
60
+ reflectorch/ml/dataloaders.py,sha256=IvKmsH5gX_b-00KRFeL-x3keEfBcvYkQFWGWd8Caj-I,1073
61
+ reflectorch/ml/loggers.py,sha256=oYZskGMbbWfW3sOXOC-4F-DIL-cLhYrRFeggrmPtPGM,743
62
+ reflectorch/ml/schedulers.py,sha256=AO-dS1bZgn7p-gGJJhocmL6Vc8XLMwynZxvko4BCEsw,10048
63
+ reflectorch/ml/trainers.py,sha256=DXO0_ue3iVQKi8vDttEQZbNLcn3xwkRBIFrw0XTomn4,3643
64
+ reflectorch/ml/utils.py,sha256=_OKkc6o5od7vnfBNZKWsGApJxA62TrAlLok92s0nG4k,71
65
+ reflectorch/models/__init__.py,sha256=vsGXi1BoLVyki_OF_Ezv9GXCpuJ22liGwolVZ0rxyyM,335
66
+ reflectorch/models/activations.py,sha256=LDiIxCnLFb8r_TRBZSt6vdOZmexCWAGa5DfE_SotUL8,1431
67
+ reflectorch/models/encoders/__init__.py,sha256=X9cHeWjJGVXNFwii9ZJasNbqrV4kt9QyJYPdKfxgp04,443
68
+ reflectorch/models/encoders/conv_encoder.py,sha256=Xa9Yo2lDctFWlnz8vIIdzwakcE_cgBR_J92UL59XhbA,7653
69
+ reflectorch/models/encoders/conv_res_net.py,sha256=VP1CCKdNqsva2R4l-IV6d63mQk6bI4Aresfd6b6_YKU,3343
70
+ reflectorch/models/encoders/fno.py,sha256=s_S7hnpLE7iGfyvnQ-QvTh0rKO5KFiy5tUYau4sJbvI,4693
71
+ reflectorch/models/networks/__init__.py,sha256=zHUvrlb0KVOpRrwVjwjR1g8sVWX6VH80FLGqu4jiux8,291
72
+ reflectorch/models/networks/mlp_networks.py,sha256=C7py6qCBVaYYt0FMEf8gbT4lndArKpUYYgTN1001-T8,11614
73
+ reflectorch/models/networks/residual_net.py,sha256=msDJaDw7qL9ebEW1Avw6Qw0lgni68AMgF4kXiJKzeaQ,4637
74
+ reflectorch/runs/__init__.py,sha256=xjuZGjqZuEovlpe9Jj5d8Nn5ii-5jvAYdeHT5oYaYVI,723
75
+ reflectorch/runs/config.py,sha256=8YtUOXr_DvNvgpu59CNAr3KrQijx1AGDE95gYrsuwsM,804
76
+ reflectorch/runs/slurm_utils.py,sha256=Zyj4_K5YpiWNJhgpFLWYHsSaaI-mgVEWsN15Gd_BhI0,2600
77
+ reflectorch/runs/train.py,sha256=e-Jj0fwYlUB2NLDxCzy0cLsSrJbNqd5pN5T1L7-Eiig,2560
78
+ reflectorch/runs/utils.py,sha256=bp3Nwd6pfX5yFlTVX28zDtSTo2gg2JZh9HCpZpvnJWk,12637
79
+ reflectorch-1.2.0.dist-info/LICENSE.txt,sha256=2kX9kLKiIRiQRqUXwk3J-Ba3fqmztNu8ORskLBlAuKM,1098
80
+ reflectorch-1.2.0.dist-info/METADATA,sha256=dulVsQJvbWfaEPJ71tgp_CVS1AwUnba2Y-Wph3c1pRk,7805
81
+ reflectorch-1.2.0.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
82
+ reflectorch-1.2.0.dist-info/top_level.txt,sha256=2EyIWrt4SeZ3hNadLXvEVpPFhyoZ4An7YflP4y_E3Fc,12
83
+ reflectorch-1.2.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (71.0.4)
2
+ Generator: setuptools (72.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,56 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
-
8
- import torch
9
- from torch import nn
10
-
11
-
12
- class TransformerEncoder(nn.Module):
13
- def __init__(
14
- self,
15
- dim: int = 64,
16
- nhead: int = 8,
17
- num_encoder_layers: int = 4,
18
- num_decoder_layers: int = 2,
19
- dim_feedforward: int = 512,
20
- dropout: float = 0.01,
21
- activation: str = 'gelu',
22
- in_dim: int = 2,
23
- out_dim: int = None,
24
- ):
25
-
26
- super().__init__()
27
-
28
- self.in_projector = nn.Linear(in_dim, dim)
29
-
30
- self.dim = dim
31
-
32
- self.transformer = nn.Transformer(
33
- dim, nhead=nhead,
34
- num_encoder_layers=num_encoder_layers,
35
- num_decoder_layers=num_decoder_layers,
36
- dim_feedforward=dim_feedforward,
37
- dropout=dropout,
38
- activation=activation
39
- )
40
-
41
- if out_dim:
42
- self.out_projector = nn.Linear(dim, out_dim)
43
- else:
44
- self.out_projector = None
45
-
46
- def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None, **kwargs):
47
- src = self.in_projector(src.transpose(1, 2)).transpose(0, 1)
48
-
49
- res = self.transformer(
50
- src, tgt, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, **kwargs
51
- )
52
-
53
- if self.out_projector:
54
- res = self.out_projector(res).squeeze(-1)
55
-
56
- return res.squeeze(0)
@@ -1,84 +0,0 @@
1
- reflectorch/__init__.py,sha256=Xs-ZhOJ1D_yluLD98aj6Pr2ssy-tDFwtCnIIukDFOW4,895
2
- reflectorch/paths.py,sha256=8xAtonoxLxfwnZF4rriUPidPaaz9pNvOEibFXiXu9yc,973
3
- reflectorch/test_config.py,sha256=1T7pMJ-WYLEu-4WtYMQxcJrqXvgdpvJ1yi2qINd0kNA,99
4
- reflectorch/train.py,sha256=-c8ac1fpjrCiEwnAaXg_rcBNl1stO1V5p5afx_78xHs,87
5
- reflectorch/train_on_cluster.py,sha256=aG3_g5_rzL8iL1tvtdY9ueJTo1f2Pn8lGJgudrrRknU,109
6
- reflectorch/utils.py,sha256=TIRilK7px-AAw3MWNmhBPC8dzq7RI5YbkNbEcqDvrNs,2152
7
- reflectorch/data_generation/__init__.py,sha256=7lbWlxj4tMuhgTzz_lFgrRkYp2AVJm262ZLoJx_Riyw,3512
8
- reflectorch/data_generation/dataset.py,sha256=O8R_RaRBXFYfvsI3co67fDeclEG1pMO50RXZTIPUf_4,7432
9
- reflectorch/data_generation/likelihoods.py,sha256=mXplS5nwoH4sAeHqp3ciEf--GMA3ulPsnxasix0HMn0,3021
10
- reflectorch/data_generation/noise.py,sha256=cZDTYPlteIclEhnRcV5DNKG64cvz-Q4VHiWGGdZkeDk,16318
11
- reflectorch/data_generation/process_data.py,sha256=0kpVWqkpDkCDEybIu2uMYdr-ytfT0sVWKfY-BbXVo9c,1339
12
- reflectorch/data_generation/q_generator.py,sha256=8_TLXamHF-5Lsr1g3WgAd6gcNuQgXPZ9hSJNazY247A,8688
13
- reflectorch/data_generation/scale_curves.py,sha256=vMiyq9Y3S5nFaf5RA9g_Id96lrcr4D_O0s_Amz_3fWM,4243
14
- reflectorch/data_generation/smearing.py,sha256=UqcWBATSyvds-P_Soq0hUfBipe3VJstBE88m_v0z-rc,2929
15
- reflectorch/data_generation/utils.py,sha256=aBICDneVaZhlEqCyMJAzuDXbCe0kLUyJ_9VHxUTBrao,5766
16
- reflectorch/data_generation/priors/__init__.py,sha256=5tN-0V8GduS8N1riZ6sDUXp_Wr7WM6840tlka4SsFqU,2100
17
- reflectorch/data_generation/priors/base.py,sha256=y7e6AWxbMekw5MHtW_h3_VhmzcGS4T6hnhqso4p2MDA,1864
18
- reflectorch/data_generation/priors/exp_subprior_sampler.py,sha256=WPjc7aGLr_qJiiNHA1yfKPd1EHJKcmOIixkbhhy7QKg,11710
19
- reflectorch/data_generation/priors/independent_priors.py,sha256=T4_uvX74iHScPGa5u_-fvdwurUu97AuRssfDL3tYxKY,7400
20
- reflectorch/data_generation/priors/multilayer_models.py,sha256=lAf-HJPbIDnbD1ecDTvx03TfA4jLN5tTLbwaBCiYgWM,7763
21
- reflectorch/data_generation/priors/multilayer_structures.py,sha256=nGVoMstkn--0kdKlT5o3VsGZ0dHUdMoeSXeI9t9V61Q,3852
22
- reflectorch/data_generation/priors/no_constraints.py,sha256=jLzsKXyQFHW2dtKViCpJpEbNjk2njjWpOaUxD8Hc3wE,7451
23
- reflectorch/data_generation/priors/parametric_models.py,sha256=vaAGcm9Ky-coidNliE4R1YvoqCcbNwEwJVCwwmkKI4E,25139
24
- reflectorch/data_generation/priors/parametric_subpriors.py,sha256=hSuSlZO1KPVAftgYzdsF6CtSDGV709v_DFjcPNKzc0g,14511
25
- reflectorch/data_generation/priors/params.py,sha256=fND4ZNlplNertLHvIimfW0KKc2uWtPTUpwygCRskIu4,8397
26
- reflectorch/data_generation/priors/sampler_strategies.py,sha256=U-v5dXpFLJq939aQKvNl7n1Lih-y97K8WJCXBFRQiA0,13059
27
- reflectorch/data_generation/priors/scaler_mixin.py,sha256=gI64v2KOZugSJWaLKASR34fn6qVFl-aoeVA1BR5yXNg,2648
28
- reflectorch/data_generation/priors/subprior_sampler.py,sha256=6kaLIvJ_dSNfPCAbOzW1vyYKU3zQ6Qdnuu8nMRFkEoQ,14928
29
- reflectorch/data_generation/priors/utils.py,sha256=-iCLSkb_MrDIvmfEZKxpWMJDwj9PtZGP8LwbbM9JYos,3995
30
- reflectorch/data_generation/reflectivity/__init__.py,sha256=whto-vWKjWt3ZziqajgJ5M2RLZWXavSZpJLvyYcuom4,2280
31
- reflectorch/data_generation/reflectivity/abeles.py,sha256=4G42f1XPcu6hvYbvV3HCfivHx7oRYq_a6AzHyGK_UeM,2091
32
- reflectorch/data_generation/reflectivity/kinematical.py,sha256=tNmh1aYHA9-eFxnY-hJZBbctkpC89tht4o9rxbDBMdU,1475
33
- reflectorch/data_generation/reflectivity/memory_eff.py,sha256=OzPZJoFupRdlz8Qchndz0A5aYhO33yU_yYMdnVxgsNg,2997
34
- reflectorch/data_generation/reflectivity/numpy_implementations.py,sha256=QBzn4yVnOdlkHeeR-ZFPS115GnLdO9lMTGO2d3YhG9I,3177
35
- reflectorch/data_generation/reflectivity/smearing.py,sha256=XVtFfQFT-Ouyl8wa_IKc0p8ZZg8PCNtZIqe_RFy6E6E,3983
36
- reflectorch/extensions/__init__.py,sha256=XeuLafCqNwwmfWTcJgbuzCzpiypBG7ZatbIZrT9TvBA,159
37
- reflectorch/extensions/jupyter/__init__.py,sha256=inEXUpeVWeAhkW5nkW_dASBzsAlv4htvj7GIS7svIGk,231
38
- reflectorch/extensions/jupyter/callbacks.py,sha256=piDR4ax6JFSOPyqfkk-nxrhyWYdMrxgC8ocoaJbbbu8,1233
39
- reflectorch/extensions/matplotlib/__init__.py,sha256=8II5pU8015VrMjFI8szCKBP1zjz0dFAzBn7smNQzGuA,263
40
- reflectorch/extensions/matplotlib/losses.py,sha256=TqcyrFrls1N6RXotFyXDF64Xz6nJGg7n5XMSXFdeRtQ,845
41
- reflectorch/inference/__init__.py,sha256=i0KNn83XN33mLrV7bpHdLd0SXxuGCKfbQcIoa247Uts,834
42
- reflectorch/inference/inference_model.py,sha256=QvnQDRRcZByHDJh84T5W8O3X_aJLZI6AmlskE6BaBlU,36265
43
- reflectorch/inference/multilayer_fitter.py,sha256=0CxDpLOEp1terR4N39yFlxhvA8qAbHf_01NbmvYadck,5510
44
- reflectorch/inference/multilayer_inference_model.py,sha256=hH_-dJGdMOox8GHXdM_nODXDlNgh_v449xW5FmklRdo,7575
45
- reflectorch/inference/query_matcher.py,sha256=Dk49dW0XreeCjufzYBTKchfTdVbG6759ryV6I-wQL60,3387
46
- reflectorch/inference/record_time.py,sha256=3er-aoR8Sd_Kc4qNwUmRqkEz4FYhVxdi1ARnBohybzM,1140
47
- reflectorch/inference/sampler_solution.py,sha256=DeJM3EXEb6S5EiASj3mmdNI-Y07Cr5UzzA5oq-vEB-Q,2288
48
- reflectorch/inference/scipy_fitter.py,sha256=339M33OdmfgOpifJGLYk4KVcnnNJrY6_aH7Lz6Vtt24,5404
49
- reflectorch/inference/torch_fitter.py,sha256=j1NzkzLCmQ4H6LfIi82LsSBmIdunnWzm3kbGx-hqvDs,3391
50
- reflectorch/inference/preprocess_exp/__init__.py,sha256=bR6H-xgBD96z4P9P1T2ECnWvalrimdMTfTNArIWPLy0,383
51
- reflectorch/inference/preprocess_exp/attenuation.py,sha256=UKDfUjCKKMgAuEs5Ttyo0KEQmpvHZI52UgVflh7T81A,1518
52
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py,sha256=CbtwIw7iJNkoVxqTHKzONBgGFwOsCUyFoTIQ8XMLTfY,1149
53
- reflectorch/inference/preprocess_exp/footprint.py,sha256=xc409M5X-QW0Ce_6dEZdj8NkOY1vd0LaGpPQFxiOOR0,2625
54
- reflectorch/inference/preprocess_exp/interpolation.py,sha256=o2v-mlfRYzeaaikeQVeY7EA7j-K42dMfT12oN3mk51k,694
55
- reflectorch/inference/preprocess_exp/normalize.py,sha256=09v7nZdtw6SnW_67xFrPnqzOA5AtFBGarjK4Pfn4VoE,695
56
- reflectorch/inference/preprocess_exp/preprocess.py,sha256=pyyq8fSvcm1bWAShzGHYnKOc55Rofh4FIw1AC7Smq-U,5111
57
- reflectorch/ml/__init__.py,sha256=wdItiY13KD6PlCrHnHVcdpQOgTB5iUSj_qn4BZFM_uU,908
58
- reflectorch/ml/basic_trainer.py,sha256=Kr-oVAlmZjkL9MuJDxHAKA_1tTqUvX-3Q2BETWWlsmE,9806
59
- reflectorch/ml/callbacks.py,sha256=YxA_VUlfsE9Uh9MotPe2tXq6rbCyoG52LfI3e_YQy3w,2902
60
- reflectorch/ml/dataloaders.py,sha256=IvKmsH5gX_b-00KRFeL-x3keEfBcvYkQFWGWd8Caj-I,1073
61
- reflectorch/ml/loggers.py,sha256=XC7KwqHDTSr_2iWyBatOQO6EuFtK1bvwUVBcoA-D7fg,904
62
- reflectorch/ml/schedulers.py,sha256=xIloPpmCSnB35YniyzcDZoXHJFMT_rz0CWh2xiXnDak,10207
63
- reflectorch/ml/trainers.py,sha256=36R_oU33UHoebd7F1eNVlQ1GdhJXeGMgWsg-RrId2Mg,5014
64
- reflectorch/ml/utils.py,sha256=VfgWVjnXTrvw8eIMhFJXEaf7gkmp3rTUHrZvy42b_2k,232
65
- reflectorch/models/__init__.py,sha256=4k6JTr4XOhxtchCIlkcYNW51CmdIPsVAGfAwuhhTgYI,521
66
- reflectorch/models/activations.py,sha256=LDiIxCnLFb8r_TRBZSt6vdOZmexCWAGa5DfE_SotUL8,1431
67
- reflectorch/models/encoders/__init__.py,sha256=9PT31292CtfXlm1jucd7-2h69M_2vQNYQeaFX0lM2EM,702
68
- reflectorch/models/encoders/conv_encoder.py,sha256=Ns5df_baTh-7lu-xRaO_jnnar1apsXGKNDfbaFIHv0U,7812
69
- reflectorch/models/encoders/conv_res_net.py,sha256=_TYbF9GMThOtYuGmiyzIkClbq8wwRA251IFzUlxMwdU,3497
70
- reflectorch/models/encoders/fno.py,sha256=s_S7hnpLE7iGfyvnQ-QvTh0rKO5KFiy5tUYau4sJbvI,4693
71
- reflectorch/models/encoders/transformers.py,sha256=hfgGr2HiTj7DvaQnm_5RU_osPxVZn-L0r5OGqF8ZJZ4,1610
72
- reflectorch/models/networks/__init__.py,sha256=_NBjIl4QNLAuzBb2IaOIGG37iWwGzVQwuQhbcP9lxpI,450
73
- reflectorch/models/networks/mlp_networks.py,sha256=C7py6qCBVaYYt0FMEf8gbT4lndArKpUYYgTN1001-T8,11614
74
- reflectorch/models/networks/residual_net.py,sha256=msDJaDw7qL9ebEW1Avw6Qw0lgni68AMgF4kXiJKzeaQ,4637
75
- reflectorch/runs/__init__.py,sha256=2BcdMJul5yd726p8w4iqlKhygAAxiu1zu0MKDe96bWk,816
76
- reflectorch/runs/config.py,sha256=6aEub3NV0jmoREdegV7S3Nz-5o1xPZnmPpNgYfMpdys,963
77
- reflectorch/runs/slurm_utils.py,sha256=T5vsWrcduq_N9mS9XAXjAbx7PHcYiiiwjdS0iiXh_TI,2759
78
- reflectorch/runs/train.py,sha256=NaHMUYApjOCeajyS5UMQkeCVyxVtroohXK5ceHNLOkM,2719
79
- reflectorch/runs/utils.py,sha256=j_gJYrw4fIZvKJWXPdt1mOR0d_Ht6pg0rDjE2iOTLc8,9737
80
- reflectorch-1.1.0.dist-info/LICENSE.txt,sha256=2kX9kLKiIRiQRqUXwk3J-Ba3fqmztNu8ORskLBlAuKM,1098
81
- reflectorch-1.1.0.dist-info/METADATA,sha256=Jj8WKCgTrNn8_TT7GZS6gbLX2BxLauF3J4FJbo-18ZM,7777
82
- reflectorch-1.1.0.dist-info/WHEEL,sha256=rWxmBtp7hEUqVLOnTaDOPpR-cZpCDkzhhcBce-Zyd5k,91
83
- reflectorch-1.1.0.dist-info/top_level.txt,sha256=2EyIWrt4SeZ3hNadLXvEVpPFhyoZ4An7YflP4y_E3Fc,12
84
- reflectorch-1.1.0.dist-info/RECORD,,