reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (96) hide show
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -128
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -280
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -223
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -1374
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +36 -36
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +523 -516
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -19
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -262
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -200
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -15
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -19
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -434
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -404
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +97 -97
  91. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  94. reflectorch-1.4.0.dist-info/RECORD +0 -88
  95. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
  96. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,201 +1,201 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn.functional as F
4
- from torch import nn
5
- from dataclasses import dataclass
6
- from typing import Optional
7
-
8
- from reflectorch.data_generation import BATCH_DATA_TYPE
9
- from reflectorch.ml.basic_trainer import Trainer
10
- from reflectorch.ml.dataloaders import ReflectivityDataLoader
11
-
12
- __all__ = [
13
- 'RealTimeSimTrainer',
14
- 'DenoisingAETrainer',
15
- 'PointEstimatorTrainer',
16
- ]
17
-
18
-
19
- @dataclass
20
- class BasicBatchData:
21
- scaled_curves: torch.Tensor
22
- scaled_bounds: torch.Tensor
23
- scaled_params: torch.Tensor = None
24
- scaled_sigmas: Optional[torch.Tensor] = None
25
- scaled_q_values: Optional[torch.Tensor] = None
26
- scaled_denoised_curves: Optional[torch.Tensor] = None
27
- key_padding_mask: Optional[torch.Tensor] = None
28
- scaled_conditioning_params: Optional[torch.Tensor] = None
29
- unscaled_q_values: Optional[torch.Tensor] = None
30
-
31
- class RealTimeSimTrainer(Trainer):
32
- """Trainer with functionality to customize the sampled batch of data"""
33
- loader: ReflectivityDataLoader
34
-
35
- def get_batch_by_idx(self, batch_num: int):
36
- """Gets a batch of data with the default batch size"""
37
- batch_data = self.loader.get_batch(self.batch_size)
38
- return self._get_batch(batch_data)
39
-
40
- def get_batch_by_size(self, batch_size: int):
41
- """Gets a batch of data with a custom batch size"""
42
- batch_data = self.loader.get_batch(batch_size)
43
- return self._get_batch(batch_data)
44
-
45
- def _get_batch(self, batch_data: BATCH_DATA_TYPE):
46
- """Modify the batch of data sampled from the data loader"""
47
- raise NotImplementedError
48
-
49
-
50
- class PointEstimatorTrainer(RealTimeSimTrainer):
51
- """Point estimator trainer for the inverse problem."""
52
-
53
- def init(self):
54
- if getattr(self, 'use_l1_loss', False):
55
- self.criterion = nn.L1Loss(reduction='none')
56
- else:
57
- self.criterion = nn.MSELoss(reduction='none')
58
- self.use_curve_reconstruction_loss = getattr(self, 'use_curve_reconstruction_loss', False)
59
- self.rescale_loss_interval_width = getattr(self, 'rescale_loss_interval_width', False)
60
- if self.use_curve_reconstruction_loss:
61
- self.loader.calc_denoised_curves = True
62
-
63
- self.train_with_q_input = getattr(self, 'train_with_q_input', False)
64
- self.train_with_sigmas = getattr(self, 'train_with_sigmas', False)
65
- self.condition_on_q_resolutions = getattr(self, 'condition_on_q_resolutions', False)
66
-
67
- def _get_batch(self, batch_data: BATCH_DATA_TYPE) -> BasicBatchData:
68
- def get_scaled_or_none(key, scaler=None):
69
- value = batch_data.get(key)
70
- if value is None:
71
- return None
72
- scale_func = scaler or (lambda x: x)
73
- return scale_func(value).to(torch.float32)
74
-
75
- scaled_params = batch_data['scaled_params'].to(torch.float32)
76
- scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
77
- scaled_denoised_curves = get_scaled_or_none('curves', self.loader.curves_scaler.scale)
78
- scaled_q_values = get_scaled_or_none('q_values', self.loader.q_generator.scale_q) if self.train_with_q_input else None
79
- key_padding_mask = batch_data.get('key_padding_mask', None)
80
-
81
- scaled_q_resolutions = get_scaled_or_none('q_resolutions', self.loader.smearing.scale_resolutions) if self.condition_on_q_resolutions else None
82
- conditioning_params = []
83
- if scaled_q_resolutions is not None:
84
- conditioning_params.append(scaled_q_resolutions)
85
- scaled_conditioning_params = torch.cat(conditioning_params, dim=-1) if len(conditioning_params) > 0 else None
86
-
87
- num_params = scaled_params.shape[-1] // 3
88
- assert num_params * 3 == scaled_params.shape[-1]
89
- scaled_params, scaled_bounds = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
90
-
91
- return BasicBatchData(
92
- scaled_params=scaled_params,
93
- scaled_bounds=scaled_bounds,
94
- scaled_curves=scaled_curves,
95
- scaled_q_values=scaled_q_values,
96
- scaled_denoised_curves=scaled_denoised_curves,
97
- scaled_conditioning_params=scaled_conditioning_params,
98
- unscaled_q_values=batch_data['q_values'],
99
- key_padding_mask=key_padding_mask,
100
- )
101
-
102
- def get_loss_dict(self, batch_data: BasicBatchData):
103
- """Returns the regression loss"""
104
- scaled_params=batch_data.scaled_params
105
- scaled_curves=batch_data.scaled_curves
106
- scaled_bounds=batch_data.scaled_bounds
107
- scaled_q_values=batch_data.scaled_q_values
108
- key_padding_mask=batch_data.key_padding_mask
109
- scaled_conditioning_params=batch_data.scaled_conditioning_params
110
- unscaled_q_values=batch_data.unscaled_q_values
111
-
112
- predicted_params = self.model(
113
- curves = scaled_curves,
114
- bounds = scaled_bounds,
115
- q_values = scaled_q_values,
116
- conditioning_params = scaled_conditioning_params,
117
- key_padding_mask = key_padding_mask,
118
- unscaled_q_values = unscaled_q_values,
119
- )
120
-
121
- if not self.rescale_loss_interval_width:
122
- loss = self.criterion(predicted_params, scaled_params).mean()
123
- else:
124
- n_params = scaled_params.shape[-1]
125
- b_min = scaled_bounds[..., :n_params]
126
- b_max = scaled_bounds[..., n_params:]
127
- interval_width = b_max - b_min
128
-
129
- base_loss = self.criterion(predicted_params, scaled_params)
130
- if isinstance(self.criterion, torch.nn.MSELoss):
131
- width_factors = (interval_width / 2) ** 2
132
- elif isinstance(self.criterion, torch.nn.L1Loss):
133
- width_factors = interval_width / 2
134
-
135
- loss = (width_factors * base_loss).mean()
136
-
137
- return {'loss': loss}
138
-
139
-
140
- # class PointEstimatorTrainer(RealTimeSimTrainer):
141
- # """Trainer for the regression inverse problem with incorporation of prior bounds"""
142
- # add_sigmas_to_context: bool = False
143
-
144
- # def _get_batch(self, batch_data: BATCH_DATA_TYPE):
145
- # scaled_params = batch_data['scaled_params'].to(torch.float32)
146
- # scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
147
- # if self.train_with_q_input:
148
- # q_values = batch_data['q_values'].to(torch.float32)
149
- # scaled_q_values = self.loader.q_generator.scale_q(q_values)
150
- # else:
151
- # scaled_q_values = None
152
-
153
- # num_params = scaled_params.shape[-1] // 3
154
- # assert num_params * 3 == scaled_params.shape[-1]
155
- # scaled_params, scaled_bounds = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
156
-
157
- # return scaled_params, scaled_bounds, scaled_curves, scaled_q_values
158
-
159
- # def get_loss_dict(self, batch_data):
160
- # """computes the loss dictionary"""
161
-
162
- # scaled_params, scaled_bounds, scaled_curves, scaled_q_values = batch_data
163
-
164
- # if self.train_with_q_input:
165
- # predicted_params = self.model(scaled_curves, scaled_bounds, scaled_q_values)
166
- # else:
167
- # predicted_params = self.model(scaled_curves, scaled_bounds)
168
-
169
- # loss = self.mse(predicted_params, scaled_params)
170
- # return {'loss': loss}
171
-
172
- # def init(self):
173
- # self.mse = nn.MSELoss()
174
-
175
-
176
- class DenoisingAETrainer(RealTimeSimTrainer):
177
- """Trainer which can be used for training a denoising autoencoder model. Overrides _get_batch and get_loss_dict methods """
178
- def init(self):
179
- self.loader.calc_denoised_curves = True
180
-
181
- if getattr(self, 'use_l1_loss', False):
182
- self.criterion = nn.L1Loss()
183
- else:
184
- self.criterion = nn.MSELoss()
185
-
186
- def _get_batch(self, batch_data: BATCH_DATA_TYPE):
187
- """returns scaled curves with and without noise"""
188
- scaled_noisy_curves, curves = batch_data['scaled_noisy_curves'], batch_data['curves']
189
- scaled_curves = self.loader.curves_scaler.scale(curves)
190
-
191
- scaled_noisy_curves, scaled_curves = scaled_noisy_curves.to(torch.float32), scaled_curves.to(torch.float32)
192
-
193
- return scaled_noisy_curves, scaled_curves
194
-
195
- def get_loss_dict(self, batch_data):
196
- """returns the reconstruction loss of the autoencoder"""
197
- scaled_noisy_curves, scaled_curves = batch_data
198
- restored_curves = self.model(scaled_noisy_curves)
199
- loss = self.criterion(scaled_curves, restored_curves)
200
- return {'loss': loss}
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+ from reflectorch.data_generation import BATCH_DATA_TYPE
9
+ from reflectorch.ml.basic_trainer import Trainer
10
+ from reflectorch.ml.dataloaders import ReflectivityDataLoader
11
+
12
+ __all__ = [
13
+ 'RealTimeSimTrainer',
14
+ 'DenoisingAETrainer',
15
+ 'PointEstimatorTrainer',
16
+ ]
17
+
18
+
19
+ @dataclass
20
+ class BasicBatchData:
21
+ scaled_curves: torch.Tensor
22
+ scaled_bounds: torch.Tensor
23
+ scaled_params: torch.Tensor = None
24
+ scaled_sigmas: Optional[torch.Tensor] = None
25
+ scaled_q_values: Optional[torch.Tensor] = None
26
+ scaled_denoised_curves: Optional[torch.Tensor] = None
27
+ key_padding_mask: Optional[torch.Tensor] = None
28
+ scaled_conditioning_params: Optional[torch.Tensor] = None
29
+ unscaled_q_values: Optional[torch.Tensor] = None
30
+
31
+ class RealTimeSimTrainer(Trainer):
32
+ """Trainer with functionality to customize the sampled batch of data"""
33
+ loader: ReflectivityDataLoader
34
+
35
+ def get_batch_by_idx(self, batch_num: int):
36
+ """Gets a batch of data with the default batch size"""
37
+ batch_data = self.loader.get_batch(self.batch_size)
38
+ return self._get_batch(batch_data)
39
+
40
+ def get_batch_by_size(self, batch_size: int):
41
+ """Gets a batch of data with a custom batch size"""
42
+ batch_data = self.loader.get_batch(batch_size)
43
+ return self._get_batch(batch_data)
44
+
45
+ def _get_batch(self, batch_data: BATCH_DATA_TYPE):
46
+ """Modify the batch of data sampled from the data loader"""
47
+ raise NotImplementedError
48
+
49
+
50
+ class PointEstimatorTrainer(RealTimeSimTrainer):
51
+ """Point estimator trainer for the inverse problem."""
52
+
53
+ def init(self):
54
+ if getattr(self, 'use_l1_loss', False):
55
+ self.criterion = nn.L1Loss(reduction='none')
56
+ else:
57
+ self.criterion = nn.MSELoss(reduction='none')
58
+ self.use_curve_reconstruction_loss = getattr(self, 'use_curve_reconstruction_loss', False)
59
+ self.rescale_loss_interval_width = getattr(self, 'rescale_loss_interval_width', False)
60
+ if self.use_curve_reconstruction_loss:
61
+ self.loader.calc_denoised_curves = True
62
+
63
+ self.train_with_q_input = getattr(self, 'train_with_q_input', False)
64
+ self.train_with_sigmas = getattr(self, 'train_with_sigmas', False)
65
+ self.condition_on_q_resolutions = getattr(self, 'condition_on_q_resolutions', False)
66
+
67
+ def _get_batch(self, batch_data: BATCH_DATA_TYPE) -> BasicBatchData:
68
+ def get_scaled_or_none(key, scaler=None):
69
+ value = batch_data.get(key)
70
+ if value is None:
71
+ return None
72
+ scale_func = scaler or (lambda x: x)
73
+ return scale_func(value).to(torch.float32)
74
+
75
+ scaled_params = batch_data['scaled_params'].to(torch.float32)
76
+ scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
77
+ scaled_denoised_curves = get_scaled_or_none('curves', self.loader.curves_scaler.scale)
78
+ scaled_q_values = get_scaled_or_none('q_values', self.loader.q_generator.scale_q) if self.train_with_q_input else None
79
+ key_padding_mask = batch_data.get('key_padding_mask', None)
80
+
81
+ scaled_q_resolutions = get_scaled_or_none('q_resolutions', self.loader.smearing.scale_resolutions) if self.condition_on_q_resolutions else None
82
+ conditioning_params = []
83
+ if scaled_q_resolutions is not None:
84
+ conditioning_params.append(scaled_q_resolutions)
85
+ scaled_conditioning_params = torch.cat(conditioning_params, dim=-1) if len(conditioning_params) > 0 else None
86
+
87
+ num_params = scaled_params.shape[-1] // 3
88
+ assert num_params * 3 == scaled_params.shape[-1]
89
+ scaled_params, scaled_bounds = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
90
+
91
+ return BasicBatchData(
92
+ scaled_params=scaled_params,
93
+ scaled_bounds=scaled_bounds,
94
+ scaled_curves=scaled_curves,
95
+ scaled_q_values=scaled_q_values,
96
+ scaled_denoised_curves=scaled_denoised_curves,
97
+ scaled_conditioning_params=scaled_conditioning_params,
98
+ unscaled_q_values=batch_data['q_values'],
99
+ key_padding_mask=key_padding_mask,
100
+ )
101
+
102
+ def get_loss_dict(self, batch_data: BasicBatchData):
103
+ """Returns the regression loss"""
104
+ scaled_params=batch_data.scaled_params
105
+ scaled_curves=batch_data.scaled_curves
106
+ scaled_bounds=batch_data.scaled_bounds
107
+ scaled_q_values=batch_data.scaled_q_values
108
+ key_padding_mask=batch_data.key_padding_mask
109
+ scaled_conditioning_params=batch_data.scaled_conditioning_params
110
+ unscaled_q_values=batch_data.unscaled_q_values
111
+
112
+ predicted_params = self.model(
113
+ curves = scaled_curves,
114
+ bounds = scaled_bounds,
115
+ q_values = scaled_q_values,
116
+ conditioning_params = scaled_conditioning_params,
117
+ key_padding_mask = key_padding_mask,
118
+ unscaled_q_values = unscaled_q_values,
119
+ )
120
+
121
+ if not self.rescale_loss_interval_width:
122
+ loss = self.criterion(predicted_params, scaled_params).mean()
123
+ else:
124
+ n_params = scaled_params.shape[-1]
125
+ b_min = scaled_bounds[..., :n_params]
126
+ b_max = scaled_bounds[..., n_params:]
127
+ interval_width = b_max - b_min
128
+
129
+ base_loss = self.criterion(predicted_params, scaled_params)
130
+ if isinstance(self.criterion, torch.nn.MSELoss):
131
+ width_factors = (interval_width / 2) ** 2
132
+ elif isinstance(self.criterion, torch.nn.L1Loss):
133
+ width_factors = interval_width / 2
134
+
135
+ loss = (width_factors * base_loss).mean()
136
+
137
+ return {'loss': loss}
138
+
139
+
140
+ # class PointEstimatorTrainer(RealTimeSimTrainer):
141
+ # """Trainer for the regression inverse problem with incorporation of prior bounds"""
142
+ # add_sigmas_to_context: bool = False
143
+
144
+ # def _get_batch(self, batch_data: BATCH_DATA_TYPE):
145
+ # scaled_params = batch_data['scaled_params'].to(torch.float32)
146
+ # scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
147
+ # if self.train_with_q_input:
148
+ # q_values = batch_data['q_values'].to(torch.float32)
149
+ # scaled_q_values = self.loader.q_generator.scale_q(q_values)
150
+ # else:
151
+ # scaled_q_values = None
152
+
153
+ # num_params = scaled_params.shape[-1] // 3
154
+ # assert num_params * 3 == scaled_params.shape[-1]
155
+ # scaled_params, scaled_bounds = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
156
+
157
+ # return scaled_params, scaled_bounds, scaled_curves, scaled_q_values
158
+
159
+ # def get_loss_dict(self, batch_data):
160
+ # """computes the loss dictionary"""
161
+
162
+ # scaled_params, scaled_bounds, scaled_curves, scaled_q_values = batch_data
163
+
164
+ # if self.train_with_q_input:
165
+ # predicted_params = self.model(scaled_curves, scaled_bounds, scaled_q_values)
166
+ # else:
167
+ # predicted_params = self.model(scaled_curves, scaled_bounds)
168
+
169
+ # loss = self.mse(predicted_params, scaled_params)
170
+ # return {'loss': loss}
171
+
172
+ # def init(self):
173
+ # self.mse = nn.MSELoss()
174
+
175
+
176
+ class DenoisingAETrainer(RealTimeSimTrainer):
177
+ """Trainer which can be used for training a denoising autoencoder model. Overrides _get_batch and get_loss_dict methods """
178
+ def init(self):
179
+ self.loader.calc_denoised_curves = True
180
+
181
+ if getattr(self, 'use_l1_loss', False):
182
+ self.criterion = nn.L1Loss()
183
+ else:
184
+ self.criterion = nn.MSELoss()
185
+
186
+ def _get_batch(self, batch_data: BATCH_DATA_TYPE):
187
+ """returns scaled curves with and without noise"""
188
+ scaled_noisy_curves, curves = batch_data['scaled_noisy_curves'], batch_data['curves']
189
+ scaled_curves = self.loader.curves_scaler.scale(curves)
190
+
191
+ scaled_noisy_curves, scaled_curves = scaled_noisy_curves.to(torch.float32), scaled_curves.to(torch.float32)
192
+
193
+ return scaled_noisy_curves, scaled_curves
194
+
195
+ def get_loss_dict(self, batch_data):
196
+ """returns the reconstruction loss of the autoencoder"""
197
+ scaled_noisy_curves, scaled_curves = batch_data
198
+ restored_curves = self.model(scaled_noisy_curves)
199
+ loss = self.criterion(scaled_curves, restored_curves)
200
+ return {'loss': loss}
201
201
 
reflectorch/ml/utils.py CHANGED
@@ -1,2 +1,2 @@
1
- def is_divisor(num: int, div: int):
2
- return num and not num % div
1
+ def is_divisor(num: int, div: int):
2
+ return num and not num % div
@@ -1,16 +1,16 @@
1
- from reflectorch.models.encoders import *
2
- from reflectorch.models.networks import *
3
-
4
- __all__ = [
5
- "ConvEncoder",
6
- "ConvDecoder",
7
- "ConvAutoencoder",
8
- "FnoEncoder",
9
- "IntegralConvEmbedding",
10
- "SpectralConv1d",
11
- "ConvResidualNet1D",
12
- "ResidualMLP",
13
- "NetworkWithPriors",
14
- "NetworkWithPriorsConvEmb",
15
- "NetworkWithPriorsFnoEmb",
1
+ from reflectorch.models.encoders import *
2
+ from reflectorch.models.networks import *
3
+
4
+ __all__ = [
5
+ "ConvEncoder",
6
+ "ConvDecoder",
7
+ "ConvAutoencoder",
8
+ "FnoEncoder",
9
+ "IntegralConvEmbedding",
10
+ "SpectralConv1d",
11
+ "ConvResidualNet1D",
12
+ "ResidualMLP",
13
+ "NetworkWithPriors",
14
+ "NetworkWithPriorsConvEmb",
15
+ "NetworkWithPriorsFnoEmb",
16
16
  ]
@@ -1,50 +1,50 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn.functional import relu
4
-
5
- class Rowdy(nn.Module):
6
- """adaptive activation function"""
7
- def __init__(self, K=9):
8
- super().__init__()
9
- self.K = K
10
- self.alpha = nn.Parameter(torch.cat((torch.ones(1), torch.zeros(K-1))))
11
- self.alpha.requiresGrad = True
12
- self.omega = nn.Parameter(torch.ones(K))
13
- self.omega.requiresGrad = True
14
-
15
- def forward(self, x):
16
- rowdy = self.alpha[0]*relu(self.omega[0]*x)
17
- for k in range(1, self.K):
18
- rowdy += self.alpha[k]*torch.sin(self.omega[k]*k*x)
19
- return rowdy
20
-
21
-
22
- ACTIVATIONS = {
23
- 'relu': nn.ReLU,
24
- 'lrelu': nn.LeakyReLU,
25
- 'gelu': nn.GELU,
26
- 'selu': nn.SELU,
27
- 'elu': nn.ELU,
28
- 'sigmoid': nn.Sigmoid,
29
- 'tanh': nn.Tanh,
30
- 'silu': nn.SiLU,
31
- 'mish': nn.Mish,
32
- 'rowdy': Rowdy,
33
- }
34
-
35
-
36
- def activation_by_name(name):
37
- """returns an activation function module corresponding to its name
38
-
39
- Args:
40
- name (str): string denoting the activation function ('relu', 'lrelu', 'gelu', 'selu', 'elu', 'sigmoid', 'silu', 'mish', 'rowdy')
41
-
42
- Returns:
43
- nn.Module: Pytorch activation function module
44
- """
45
- if not isinstance(name, str):
46
- return name
47
- try:
48
- return ACTIVATIONS[name.lower()]
49
- except KeyError:
50
- raise KeyError(f'Unknown activation function {name}')
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn.functional import relu
4
+
5
+ class Rowdy(nn.Module):
6
+ """adaptive activation function"""
7
+ def __init__(self, K=9):
8
+ super().__init__()
9
+ self.K = K
10
+ self.alpha = nn.Parameter(torch.cat((torch.ones(1), torch.zeros(K-1))))
11
+ self.alpha.requiresGrad = True
12
+ self.omega = nn.Parameter(torch.ones(K))
13
+ self.omega.requiresGrad = True
14
+
15
+ def forward(self, x):
16
+ rowdy = self.alpha[0]*relu(self.omega[0]*x)
17
+ for k in range(1, self.K):
18
+ rowdy += self.alpha[k]*torch.sin(self.omega[k]*k*x)
19
+ return rowdy
20
+
21
+
22
+ ACTIVATIONS = {
23
+ 'relu': nn.ReLU,
24
+ 'lrelu': nn.LeakyReLU,
25
+ 'gelu': nn.GELU,
26
+ 'selu': nn.SELU,
27
+ 'elu': nn.ELU,
28
+ 'sigmoid': nn.Sigmoid,
29
+ 'tanh': nn.Tanh,
30
+ 'silu': nn.SiLU,
31
+ 'mish': nn.Mish,
32
+ 'rowdy': Rowdy,
33
+ }
34
+
35
+
36
+ def activation_by_name(name):
37
+ """returns an activation function module corresponding to its name
38
+
39
+ Args:
40
+ name (str): string denoting the activation function ('relu', 'lrelu', 'gelu', 'selu', 'elu', 'sigmoid', 'silu', 'mish', 'rowdy')
41
+
42
+ Returns:
43
+ nn.Module: Pytorch activation function module
44
+ """
45
+ if not isinstance(name, str):
46
+ return name
47
+ try:
48
+ return ACTIVATIONS[name.lower()]
49
+ except KeyError:
50
+ raise KeyError(f'Unknown activation function {name}')
@@ -1,19 +1,19 @@
1
- from reflectorch.models.encoders.conv_encoder import (
2
- ConvEncoder,
3
- ConvDecoder,
4
- ConvAutoencoder,
5
- )
6
- from reflectorch.models.encoders.fno import FnoEncoder, SpectralConv1d
7
- from reflectorch.models.encoders.integral_kernel_embedding import IntegralConvEmbedding
8
- from reflectorch.models.encoders.conv_res_net import ConvResidualNet1D
9
-
10
-
11
- __all__ = [
12
- "ConvEncoder",
13
- "ConvDecoder",
14
- "ConvAutoencoder",
15
- "ConvResidualNet1D",
16
- "FnoEncoder",
17
- "SpectralConv1d",
18
- "IntegralConvEmbedding",
19
- ]
1
+ from reflectorch.models.encoders.conv_encoder import (
2
+ ConvEncoder,
3
+ ConvDecoder,
4
+ ConvAutoencoder,
5
+ )
6
+ from reflectorch.models.encoders.fno import FnoEncoder, SpectralConv1d
7
+ from reflectorch.models.encoders.integral_kernel_embedding import IntegralConvEmbedding
8
+ from reflectorch.models.encoders.conv_res_net import ConvResidualNet1D
9
+
10
+
11
+ __all__ = [
12
+ "ConvEncoder",
13
+ "ConvDecoder",
14
+ "ConvAutoencoder",
15
+ "ConvResidualNet1D",
16
+ "FnoEncoder",
17
+ "SpectralConv1d",
18
+ "IntegralConvEmbedding",
19
+ ]