reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/ml/callbacks.py CHANGED
@@ -1,81 +1,81 @@ (the removed and re-added lines are identical)
import torch

import numpy as np

from reflectorch.ml.basic_trainer import (
    TrainerCallback,
    Trainer,
)
from reflectorch.ml.utils import is_divisor

__all__ = [
    'SaveBestModel',
    'LogLosses',
]


class SaveBestModel(TrainerCallback):
    """Callback for periodically saving the best model weights

    Args:
        path (str): path for saving the model weights
        freq (int, optional): frequency in iterations at which the current average loss is evaluated. Defaults to 50.
        average (int, optional): number of recent iterations over which the average loss is computed. Defaults to 10.
    """

    def __init__(self, path: str, freq: int = 50, average: int = 10):
        self.path = path
        self.average = average
        self._best_loss = np.inf
        self.freq = freq

    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        """checks if the current average loss has improved from the previous save, if true the model is saved

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the current iteration / batch
        """
        if is_divisor(batch_num, self.freq):

            loss = np.mean(trainer.losses['total_loss'][-self.average:])

            if loss < self._best_loss:
                self._best_loss = loss
                self.save(trainer, batch_num)

    def save(self, trainer: Trainer, batch_num: int):
        """saves a dictionary containing the network weights, the learning rates, the losses and the current \
        best loss with its corresponding iteration to the disk

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the current iteration / batch
        """
        prev_save = trainer.callback_params.pop('saved_iteration', 0)
        trainer.callback_params['saved_iteration'] = batch_num
        save_dict = {
            'model': trainer.model.state_dict(),
            'lrs': trainer.lrs,
            'losses': trainer.losses,
            'prev_save': prev_save,
            'batch_num': batch_num,
            'best_loss': self._best_loss
        }
        torch.save(save_dict, self.path)


class LogLosses(TrainerCallback):
    """Callback for logging the training losses"""
    def end_batch(self, trainer: Trainer, batch_num: int) -> None:
        """log loss at the current iteration

        Args:
            trainer (Trainer): the trainer object
            batch_num (int): the index of the current iteration / batch
        """
        for loss_name, loss_values in trainer.losses.items():
            try:
                trainer.log(f'train/{loss_name}', loss_values[-1])
            except IndexError:
                continue
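
The checkpoint written by SaveBestModel is a plain dictionary, so it can be restored with torch.load. A minimal sketch, assuming a network object with the same architecture already exists (its construction is omitted here) and that the callback was given the path 'best_model.pt':

import torch

checkpoint = torch.load('best_model.pt', map_location='cpu')  # path passed to SaveBestModel

network.load_state_dict(checkpoint['model'])   # restore the network weights
loss_history = checkpoint['losses']            # per-loss histories recorded by the trainer
best_loss = checkpoint['best_loss']            # best average loss at the time of saving
saved_iteration = checkpoint['batch_num']      # iteration at which this checkpoint was written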
reflectorch/ml/dataloaders.py CHANGED
@@ -1,27 +1,27 @@ (the removed and re-added lines are identical)
from torch import Tensor


from reflectorch.data_generation import BasicDataset
from reflectorch.data_generation.reflectivity import kinematical_approximation
from reflectorch.data_generation.priors import BasicParams
from reflectorch.ml.basic_trainer import DataLoader


__all__ = [
    "ReflectivityDataLoader",
    "MultilayerDataLoader",
]


class ReflectivityDataLoader(BasicDataset, DataLoader):
    """Dataloader for reflectivity data, combining functionality from the ``BasicDataset`` (basic dataset class for reflectivity) and the ``DataLoader`` (which inherits from ``TrainerCallback``) classes"""
    pass


class MultilayerDataLoader(ReflectivityDataLoader):
    """Dataloader for reflectivity curves simulated using the kinematical approximation"""
    def _sample_from_prior(self, batch_size: int):
        return self.prior_sampler.optimized_sample(batch_size)

    def _calc_curves(self, q_values: Tensor, params: BasicParams):
        return kinematical_approximation(q_values, params.thicknesses, params.roughnesses, params.slds)
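
MultilayerDataLoader shows the intended extension pattern: subclass ReflectivityDataLoader and override the hooks for prior sampling and curve simulation. Below is a hedged sketch of the same pattern with a user-supplied simulator; CustomCurveDataLoader and my_reflectivity_fn are hypothetical names for illustration, not part of reflectorch:

from torch import Tensor

from reflectorch.ml.dataloaders import ReflectivityDataLoader
from reflectorch.data_generation.priors import BasicParams


def my_reflectivity_fn(q_values, thicknesses, roughnesses, slds):
    """Placeholder for any curve simulator taking (q, thicknesses, roughnesses, slds)."""
    raise NotImplementedError


class CustomCurveDataLoader(ReflectivityDataLoader):
    """Hypothetical loader that plugs in a user-supplied reflectivity simulator."""

    def _calc_curves(self, q_values: Tensor, params: BasicParams):
        # Same signature as MultilayerDataLoader._calc_curves above.
        return my_reflectivity_fn(q_values, params.thicknesses, params.roughnesses, params.slds)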
reflectorch/ml/loggers.py CHANGED
@@ -1,56 +1,56 @@ (the removed and re-added lines are identical)
from torch.utils.tensorboard import SummaryWriter

__all__ = [
    'Logger',
    'Loggers',
    'PrintLogger',
    'TensorBoardLogger',
]


class Logger(object):
    "Base class defining a common interface for logging"
    def log(self, name: str, data):
        pass

    def __setitem__(self, key, value):
        """Enable dictionary-style setting to log data."""
        self.log(key, value)


class Loggers(Logger):
    """Class for using multiple loggers"""
    def __init__(self, *loggers):
        self._loggers = tuple(loggers)

    def log(self, name: str, data):
        for logger in self._loggers:
            logger.log(name, data)


class PrintLogger(Logger):
    """Logger which prints to the console"""
    def log(self, name: str, data):
        print(name, ': ', data)

class TensorBoardLogger(Logger):
    def __init__(self, log_dir: str):
        """
        Args:
            log_dir (str): Directory where TensorBoard logs will be written
        """
        super().__init__()
        self.writer = SummaryWriter(log_dir=log_dir)
        self.step = 1

    def log(self, name: str, data):
        """Log scalar data to TensorBoard

        Args:
            name (str): Name/tag for the data
            data: Scalar value to log
        """
        if hasattr(data, 'item'):
            data = data.item()
        self.writer.add_scalar(name, data, self.step)
        self.step += 1
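
A short usage sketch of the classes above: several loggers can be combined through Loggers, and values can be recorded either via log() or via dictionary-style assignment; the log directory 'runs/example' is an arbitrary choice:

from reflectorch.ml.loggers import Loggers, PrintLogger, TensorBoardLogger

logger = Loggers(PrintLogger(), TensorBoardLogger(log_dir='runs/example'))

logger.log('train/total_loss', 0.123)  # printed to the console and written to TensorBoard
logger['train/lr'] = 1e-4              # __setitem__ delegates to log()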