careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (103)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +92 -55
  4. careamics/config/__init__.py +0 -1
  5. careamics/config/algorithm_model.py +5 -3
  6. careamics/config/architectures/architecture_model.py +7 -0
  7. careamics/config/architectures/custom_model.py +8 -1
  8. careamics/config/architectures/register_model.py +3 -1
  9. careamics/config/architectures/unet_model.py +3 -0
  10. careamics/config/architectures/vae_model.py +2 -0
  11. careamics/config/callback_model.py +4 -15
  12. careamics/config/configuration_example.py +4 -4
  13. careamics/config/configuration_factory.py +113 -55
  14. careamics/config/configuration_model.py +14 -16
  15. careamics/config/data_model.py +63 -165
  16. careamics/config/inference_model.py +9 -75
  17. careamics/config/optimizer_models.py +4 -4
  18. careamics/config/references/algorithm_descriptions.py +1 -0
  19. careamics/config/references/references.py +1 -0
  20. careamics/config/support/__init__.py +0 -2
  21. careamics/config/support/supported_activations.py +2 -0
  22. careamics/config/support/supported_algorithms.py +3 -1
  23. careamics/config/support/supported_architectures.py +2 -0
  24. careamics/config/support/supported_data.py +2 -0
  25. careamics/config/support/supported_loggers.py +2 -0
  26. careamics/config/support/supported_losses.py +2 -0
  27. careamics/config/support/supported_optimizers.py +2 -0
  28. careamics/config/support/supported_pixel_manipulations.py +3 -3
  29. careamics/config/support/supported_struct_axis.py +2 -0
  30. careamics/config/support/supported_transforms.py +4 -15
  31. careamics/config/tile_information.py +2 -0
  32. careamics/config/training_model.py +1 -0
  33. careamics/config/transformations/__init__.py +3 -2
  34. careamics/config/transformations/n2v_manipulate_model.py +1 -0
  35. careamics/config/transformations/normalize_model.py +1 -0
  36. careamics/config/transformations/transform_model.py +1 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +13 -7
  39. careamics/config/validators/validator_utils.py +1 -0
  40. careamics/conftest.py +13 -0
  41. careamics/dataset/dataset_utils/__init__.py +0 -1
  42. careamics/dataset/dataset_utils/dataset_utils.py +5 -4
  43. careamics/dataset/dataset_utils/file_utils.py +4 -3
  44. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  45. careamics/dataset/dataset_utils/read_utils.py +2 -0
  46. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  47. careamics/dataset/in_memory_dataset.py +84 -76
  48. careamics/dataset/iterable_dataset.py +166 -134
  49. careamics/dataset/patching/__init__.py +0 -7
  50. careamics/dataset/patching/patching.py +56 -14
  51. careamics/dataset/patching/random_patching.py +8 -2
  52. careamics/dataset/patching/sequential_patching.py +20 -14
  53. careamics/dataset/patching/tiled_patching.py +13 -7
  54. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  55. careamics/dataset/zarr_dataset.py +2 -0
  56. careamics/lightning_datamodule.py +63 -41
  57. careamics/lightning_module.py +9 -3
  58. careamics/lightning_prediction_datamodule.py +15 -20
  59. careamics/lightning_prediction_loop.py +8 -6
  60. careamics/losses/__init__.py +1 -3
  61. careamics/losses/loss_factory.py +2 -1
  62. careamics/losses/losses.py +11 -7
  63. careamics/model_io/__init__.py +0 -1
  64. careamics/model_io/bioimage/_readme_factory.py +2 -1
  65. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  66. careamics/model_io/bioimage/model_description.py +1 -0
  67. careamics/model_io/bmz_io.py +4 -3
  68. careamics/models/activation.py +2 -0
  69. careamics/models/layers.py +122 -25
  70. careamics/models/model_factory.py +2 -1
  71. careamics/models/unet.py +114 -19
  72. careamics/prediction/stitch_prediction.py +2 -5
  73. careamics/transforms/__init__.py +4 -25
  74. careamics/transforms/compose.py +124 -0
  75. careamics/transforms/n2v_manipulate.py +65 -34
  76. careamics/transforms/normalize.py +91 -28
  77. careamics/transforms/pixel_manipulation.py +7 -7
  78. careamics/transforms/struct_mask_parameters.py +3 -1
  79. careamics/transforms/transform.py +24 -0
  80. careamics/transforms/tta.py +2 -2
  81. careamics/transforms/xy_flip.py +123 -0
  82. careamics/transforms/xy_random_rotate90.py +66 -60
  83. careamics/utils/__init__.py +0 -1
  84. careamics/utils/base_enum.py +28 -0
  85. careamics/utils/context.py +1 -0
  86. careamics/utils/logging.py +1 -0
  87. careamics/utils/metrics.py +1 -0
  88. careamics/utils/path_utils.py +2 -0
  89. careamics/utils/ram.py +2 -0
  90. careamics/utils/receptive_field.py +93 -87
  91. careamics/utils/torch_utils.py +1 -0
  92. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
  93. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  94. careamics/config/noise_models.py +0 -162
  95. careamics/config/support/supported_extraction_strategies.py +0 -24
  96. careamics/config/transformations/nd_flip_model.py +0 -32
  97. careamics/dataset/patching/patch_transform.py +0 -44
  98. careamics/losses/noise_model_factory.py +0 -40
  99. careamics/losses/noise_models.py +0 -524
  100. careamics/transforms/nd_flip.py +0 -93
  101. careamics-0.1.0rc4.dist-info/RECORD +0 -110
  102. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  103. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
careamics/losses/noise_models.py (deleted)
@@ -1,524 +0,0 @@
-from abc import ABC, abstractmethod
-
-import numpy as np
-import torch
-
-from ..utils.logging import get_logger
-
-logger = get_logger(__name__)
-
-
-# TODO here "Model" clashes a bit with the naming convention of the Pydantic models
-class NoiseModel(ABC):
-    """Base class for noise models."""
-
-    @abstractmethod
-    def instantiate(self):
-        """Instantiate the noise model.
-
-        Method that should produce a ready-to-use noise model.
-        """
-        pass
-
-    @abstractmethod
-    def likelihood(self, observations, signals):
-        """Return the likelihood of observations given the signals."""
-        pass
-
-
-class HistogramNoiseModel(NoiseModel):
-    """Creates a NoiseModel object.
-
-    Parameters
-    ----------
-    histogram : numpy array
-        A histogram as created by the `instantiate(...)` method.
-    device :
-        The device your NoiseModel lives on, e.g. your GPU.
-    """
-
-    def __init__(self, **kwargs):
-        pass
-
-    def instantiate(self, bins, min_value, max_value, observation, signal):
-        """Create an nD histogram from `observation` and `signal`.
-
-        Parameters
-        ----------
-        bins : int
-            The number of bins in all dimensions. The total number of bins is
-            `bins` ** number_of_dimensions.
-        min_value : float
-            The lower bound of the lowest bin.
-        max_value : float
-            The upper bound of the highest bin.
-        observation : np.array
-            A stack of noisy images. The number has to be divisible by the number
-            of images in `signal`. N subsequent images in `observation` belong to
-            one image in `signal`.
-        signal : np.array
-            A stack of clean images.
-
-        Returns
-        -------
-        histogram : numpy array
-            A 3D array:
-            `histogram[0, ...]` holds the normalized nD counts.
-            Each row sums to 1, describing p(x_i|s_i).
-            `histogram[1, ...]` holds the lower boundaries of each bin in y.
-            `histogram[2, ...]` holds the upper boundaries of each bin in y.
-            The values for x can be obtained by transposing `histogram[1, ...]`
-            and `histogram[2, ...]`.
-        """
-        img_factor = int(observation.shape[0] / signal.shape[0])
-        histogram = np.zeros((3, bins, bins))
-        value_range = [min_value, max_value]
-
-        for i in range(observation.shape[0]):
-            observation_i = observation[i].copy().ravel()
-
-            signal_i = (signal[i // img_factor].copy()).ravel()
-
-            histogram_i = np.histogramdd(
-                (signal_i, observation_i), bins=bins, range=[value_range, value_range]
-            )
-            # Adding a constant for numerical stability
-            histogram[0] = histogram[0] + histogram_i[0] + 1e-30
-
-        for i in range(bins):
-            # Exclude empty rows from normalization
-            if np.sum(histogram[0, i, :]) > 1e-20:
-                # Normalize each non-empty row
-                histogram[0, i, :] /= np.sum(histogram[0, i, :])
-
-        for i in range(bins):
-            # The lower boundaries of each bin in y are stored in dimension 1
-            histogram[1, :, i] = histogram_i[1][:-1]
-            # The upper boundaries of each bin in y are stored in dimension 2
-            histogram[2, :, i] = histogram_i[1][1:]
-            # The corresponding numbers for x are just transposed.
-
-        return histogram
-
-    def likelihood(self, observed, signal):
-        """Calculate the likelihood using a histogram-based noise model.
-
-        For every pixel in a tensor, calculate p(x_i|s_i). To ensure
-        differentiability in the direction of s_i, we linearly interpolate in
-        this direction.
-
-        Parameters
-        ----------
-        observed : torch.Tensor
-            Tensor holding the observed intensities x_i.
-        signal : torch.Tensor
-            Tensor holding hypotheses for the clean signal at every pixel s_i^k.
-
-        Returns
-        -------
-        torch.Tensor
-            The observation likelihoods according to the noise model.
-        """
-        observed_float = self.get_index_observed_float(observed)
-        observed_long = observed_float.floor().long()
-        signal_float = self.get_index_signal_float(signal)
-        signal_long = signal_float.floor().long()
-        fact = signal_float - signal_long.float()
-
-        # Finally, look up the values and interpolate
-        return self.fullHist[signal_long, observed_long] * (1.0 - fact) + self.fullHist[
-            torch.clamp((signal_long + 1).long(), 0, self.bins.long()), observed_long
-        ] * (fact)
-
-    def get_index_observed_float(self, x: float):
-        """Map observed intensities to fractional histogram bin indices.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Observed intensities.
-
-        Returns
-        -------
-        torch.Tensor
-            Fractional bin indices, clamped to the valid index range.
-        """
-        return torch.clamp(
-            self.bins * (x - self.minv) / (self.maxv - self.minv),
-            min=0.0,
-            max=self.bins - 1 - 1e-3,
-        )
-
-    def get_index_signal_float(self, x):
-        """Map signal intensities to fractional histogram bin indices.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Signal intensities.
-
-        Returns
-        -------
-        torch.Tensor
-            Fractional bin indices, clamped to the valid index range.
-        """
-        return torch.clamp(
-            self.bins * (x - self.minv) / (self.maxv - self.minv),
-            min=0.0,
-            max=self.bins - 1 - 1e-3,
-        )
-
-
-# TODO refactor this into a Pydantic model
-class GaussianMixtureNoiseModel(NoiseModel):
-    """Describes a noise model parameterized as a mixture of Gaussians.
-
-    To initialize a new object from scratch, set `params=None` and specify the
-    other parameters as keyword arguments. To load a trained model instead, use
-    only `params`.
-
-    Parameters
-    ----------
-    **kwargs : keyworded, variable-length argument dictionary.
-        Arguments include:
-        min_signal : float
-            Minimum signal intensity expected in the image.
-        max_signal : float
-            Maximum signal intensity expected in the image.
-        weight : array
-            A [3*n_gaussian, n_coeff] sized array containing the values of the
-            weights describing the noise model. Each Gaussian contributes three
-            parameters (mean, standard deviation and weight), hence the number
-            of rows in `weight` is 3*n_gaussian. If `weight=None`, the weight
-            array is initialized using the `min_signal` and `max_signal`
-            parameters.
-        n_gaussian : int
-            Number of Gaussians.
-        n_coeff : int
-            Number of coefficients describing the functional relationship
-            between Gaussian parameters and the signal. 2 implies a linear
-            relationship, 3 implies a quadratic relationship, and so on.
-        device : device
-            GPU device.
-        min_sigma : int
-            All values of sigma (standard deviation) below `min_sigma` are
-            clamped to `min_sigma`.
-        params : dictionary
-            Use `params` to load a model with trained weights. When
-            initializing a new `GaussianMixtureNoiseModel` from scratch, set
-            this to `None`.
-    """
-
-    def __init__(self, **kwargs):
-        if kwargs.get("params") is None:
-            weight = kwargs.get("weight")
-            n_gaussian = kwargs.get("n_gaussian")
-            n_coeff = kwargs.get("n_coeff")
-            min_signal = kwargs.get("min_signal")
-            max_signal = kwargs.get("max_signal")
-            self.device = kwargs.get("device")
-            self.path = kwargs.get("path")
-            self.min_sigma = kwargs.get("min_sigma")
-            if weight is None:
-                weight = np.random.randn(n_gaussian * 3, n_coeff)
-                weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal)
-                weight = (
-                    torch.from_numpy(weight.astype(np.float32)).float().to(self.device)
-                )
-                weight.requires_grad = True
-            self.n_gaussian = weight.shape[0] // 3
-            self.n_coeff = weight.shape[1]
-            self.weight = weight
-            self.min_signal = torch.Tensor([min_signal]).to(self.device)
-            self.max_signal = torch.Tensor([max_signal]).to(self.device)
-            self.tol = torch.Tensor([1e-10]).to(self.device)
-        else:
-            params = kwargs.get("params")
-            self.device = kwargs.get("device")
-
-            self.min_signal = torch.Tensor(params["min_signal"]).to(self.device)
-            self.max_signal = torch.Tensor(params["max_signal"]).to(self.device)
-
-            self.weight = torch.Tensor(params["trained_weight"]).to(self.device)
-            self.min_sigma = np.ndarray.item(params["min_sigma"])
-            self.n_gaussian = self.weight.shape[0] // 3
-            self.n_coeff = self.weight.shape[1]
-            self.tol = torch.Tensor([1e-10]).to(self.device)
-            self.min_signal = torch.Tensor([self.min_signal]).to(self.device)
-            self.max_signal = torch.Tensor([self.max_signal]).to(self.device)
-
-    def fast_shuffle(self, series, num):
-        """Shuffle the rows of `series` `num` times.
-
-        Parameters
-        ----------
-        series : np.ndarray
-            Two-column array of signal-observation pairs.
-        num : int
-            Number of shuffling passes.
-
-        Returns
-        -------
-        np.ndarray
-            The shuffled array.
-        """
-        length = series.shape[0]
-        for _i in range(num):
-            series = series[np.random.permutation(length), :]
-        return series
-
-    def polynomial_regressor(self, weightParams, signals):
-        """Combine `weightParams` and `signals` to perform regression.
-
-        Parameters
-        ----------
-        weightParams : torch.cuda.FloatTensor
-            Corresponds to specific rows of `self.weight`.
-        signals : torch.cuda.FloatTensor
-            Signals.
-
-        Returns
-        -------
-        value : torch.cuda.FloatTensor
-            Corresponds to one of mean, standard deviation or weight, evaluated
-            at `signals`.
-        """
-        value = 0
-        for i in range(weightParams.shape[0]):
-            value += weightParams[i] * (
-                ((signals - self.min_signal) / (self.max_signal - self.min_signal)) ** i
-            )
-        return value
-
-    def normal_density(self, x, m_=0.0, std_=None):
-        """Evaluate the normal probability density.
-
-        Parameters
-        ----------
-        x : torch.cuda.FloatTensor
-            Observations.
-        m_ : torch.cuda.FloatTensor
-            Mean.
-        std_ : torch.cuda.FloatTensor
-            Standard deviation.
-
-        Returns
-        -------
-        tmp : torch.cuda.FloatTensor
-            Normal probability density of `x` given `m_` and `std_`.
-        """
-        tmp = -((x - m_) ** 2)
-        tmp = tmp / (2.0 * std_ * std_)
-        tmp = torch.exp(tmp)
-        tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_)
-        return tmp
-
-    def likelihood(self, observations, signals):
-        """Evaluate the likelihood of observations.
-
-        Given the signals and the corresponding Gaussian parameters, evaluates
-        the likelihood of observations.
-
-        Parameters
-        ----------
-        observations : torch.cuda.FloatTensor
-            Noisy observations.
-        signals : torch.cuda.FloatTensor
-            Underlying signals.
-
-        Returns
-        -------
-        value : torch.cuda.FloatTensor
-            Likelihood of the observations given the signals and the GMM noise
-            model.
-        """
-        gaussianParameters = self.get_gaussian_parameters(signals)
-        p = 0
-        for gaussian in range(self.n_gaussian):
-            p += (
-                self.normal_density(
-                    observations,
-                    gaussianParameters[gaussian],
-                    gaussianParameters[self.n_gaussian + gaussian],
-                )
-                * gaussianParameters[2 * self.n_gaussian + gaussian]
-            )
-        return p + self.tol
-
-    def get_gaussian_parameters(self, signals):
-        """Return the noise model parameters for the given signals.
-
-        Parameters
-        ----------
-        signals : torch.cuda.FloatTensor
-            Underlying signals.
-
-        Returns
-        -------
-        noiseModel : list of torch.cuda.FloatTensor
-            A list of `mu`, `sigma` and `alpha` for the `signals`.
-        """
-        noiseModel = []
-        mu = []
-        sigma = []
-        alpha = []
-        kernels = self.weight.shape[0] // 3
-        for num in range(kernels):
-            mu.append(self.polynomial_regressor(self.weight[num, :], signals))
-
-            sigmaTemp = self.polynomial_regressor(
-                torch.exp(self.weight[kernels + num, :]), signals
-            )
-            sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma)
-            sigma.append(torch.sqrt(sigmaTemp))
-            alpha.append(
-                torch.exp(
-                    self.polynomial_regressor(self.weight[2 * kernels + num, :], signals)
-                    + self.tol
-                )
-            )
-
-        sum_alpha = 0
-        for al in range(kernels):
-            sum_alpha = alpha[al] + sum_alpha
-        for ker in range(kernels):
-            alpha[ker] = alpha[ker] / sum_alpha
-
-        sum_means = 0
-        for ker in range(kernels):
-            sum_means = alpha[ker] * mu[ker] + sum_means
-
-        for ker in range(kernels):
-            mu[ker] = mu[ker] - sum_means + signals
-
-        for i in range(kernels):
-            noiseModel.append(mu[i])
-        for j in range(kernels):
-            noiseModel.append(sigma[j])
-        for k in range(kernels):
-            noiseModel.append(alpha[k])
-
-        return noiseModel
-
-    def get_signal_observation_pairs(self, signal, observation, lowerClip, upperClip):
-        """Return the signal-observation pixel intensities as a two-column array.
-
-        Parameters
-        ----------
-        signal : numpy array
-            Clean signal data.
-        observation : numpy array
-            Noisy observation data.
-        lowerClip : float
-            Lower percentile bound for clipping.
-        upperClip : float
-            Upper percentile bound for clipping.
-
-        Returns
-        -------
-        np.ndarray
-            Shuffled two-column array of (signal, observation) pixel pairs.
-        """
-        lb = np.percentile(signal, lowerClip)
-        ub = np.percentile(signal, upperClip)
-        stepsize = observation[0].size
-        n_observations = observation.shape[0]
-        n_signals = signal.shape[0]
-        sig_obs_pairs = np.zeros((n_observations * stepsize, 2))
-
-        for i in range(n_observations):
-            j = i // (n_observations // n_signals)
-            sig_obs_pairs[stepsize * i : stepsize * (i + 1), 0] = signal[j].ravel()
-            sig_obs_pairs[stepsize * i : stepsize * (i + 1), 1] = observation[i].ravel()
-        sig_obs_pairs = sig_obs_pairs[
-            (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub)
-        ]
-        return self.fast_shuffle(sig_obs_pairs, 2)
-
-    def train(
-        self,
-        signal,
-        observation,
-        learning_rate=1e-1,
-        batchSize=250000,
-        n_epochs=2000,
-        name="GMMNoiseModel.npz",
-        lowerClip=0,
-        upperClip=100,
-    ):
-        """Train the noise model on signal-observation pairs.
-
-        Parameters
-        ----------
-        signal : numpy array
-            Clean signal data.
-        observation : numpy array
-            Noisy observation data.
-        learning_rate : float
-            Learning rate. Default = 1e-1.
-        batchSize : int
-            Mini-batch size. Default = 250000.
-        n_epochs : int
-            Number of epochs. Default = 2000.
-        name : string
-            Model name. Default is `GMMNoiseModel.npz`. After training, the
-            model is saved at the location `path`.
-        lowerClip : int
-            Lower percentile for clipping. Default is 0.
-        upperClip : int
-            Upper percentile for clipping. Default is 100.
-        """
-        sig_obs_pairs = self.get_signal_observation_pairs(
-            signal, observation, lowerClip, upperClip
-        )
-        counter = 0
-        optimizer = torch.optim.Adam([self.weight], lr=learning_rate)
-        for t in range(n_epochs):
-            jointLoss = 0
-            if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]:
-                counter = 0
-                sig_obs_pairs = self.fast_shuffle(sig_obs_pairs, 1)
-
-            batch_vectors = sig_obs_pairs[
-                counter * batchSize : (counter + 1) * batchSize, :
-            ]
-            observations = batch_vectors[:, 1].astype(np.float32)
-            signals = batch_vectors[:, 0].astype(np.float32)
-            observations = (
-                torch.from_numpy(observations.astype(np.float32))
-                .float()
-                .to(self.device)
-            )
-            signals = torch.from_numpy(signals).float().to(self.device)
-            p = self.likelihood(observations, signals)
-            loss = torch.mean(-torch.log(p))
-            jointLoss = jointLoss + loss
-
-            if t % 100 == 0:
-                print(t, jointLoss.item())
-
-            if t % (int(n_epochs * 0.5)) == 0:
-                trained_weight = self.weight.cpu().detach().numpy()
-                min_signal = self.min_signal.cpu().detach().numpy()
-                max_signal = self.max_signal.cpu().detach().numpy()
-                np.savez(
-                    self.path + name,
-                    trained_weight=trained_weight,
-                    min_signal=min_signal,
-                    max_signal=max_signal,
-                    min_sigma=self.min_sigma,
-                )
-
-            optimizer.zero_grad()
-            jointLoss.backward()
-            optimizer.step()
-            counter += 1
-
-        logger.info(f"The trained parameters {name} are saved at location: {self.path}")
careamics/transforms/nd_flip.py (deleted)
@@ -1,93 +0,0 @@
-from typing import Any, Dict, Tuple
-
-import numpy as np
-from albumentations import DualTransform
-
-
-class NDFlip(DualTransform):
-    """Flip ND arrays along a single axis.
-
-    This transform ignores singleton axes and randomly flips one of the other
-    axes, with the exception of the last axis (channels).
-
-    This transform expects (Z)YXC dimensions.
-    """
-
-    def __init__(self, p: float = 0.5, is_3D: bool = False, flip_z: bool = True):
-        """Constructor.
-
-        Parameters
-        ----------
-        p : float, optional
-            Probability of applying the transform, by default 0.5.
-        is_3D : bool, optional
-            Whether the data is 3D, by default False.
-        flip_z : bool, optional
-            Whether to flip the Z dimension, by default True.
-        """
-        super().__init__(p=p)
-
-        self.is_3D = is_3D
-        self.flip_z = flip_z
-
-        # "flippable" axes
-        if is_3D:
-            self.axis_indices = [0, 1, 2] if flip_z else [1, 2]
-        else:
-            self.axis_indices = [0, 1]
-
-    def get_params(self, **kwargs: Any) -> Dict[str, int]:
-        """Get the transform parameters.
-
-        Returns
-        -------
-        Dict[str, int]
-            Transform parameters.
-        """
-        return {"flip_axis": np.random.choice(self.axis_indices)}
-
-    def apply(self, patch: np.ndarray, flip_axis: int, **kwargs: Any) -> np.ndarray:
-        """Apply the transform to the image.
-
-        Parameters
-        ----------
-        patch : np.ndarray
-            Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
-        flip_axis : int
-            Axis along which to flip the patch.
-
-        Returns
-        -------
-        np.ndarray
-            Flipped patch.
-        """
-        if len(patch.shape) == 3 and self.is_3D:
-            raise ValueError(
-                "Incompatible patch shape and dimensionality. ZYXC patch shape "
-                "expected, but got YXC shape."
-            )
-
-        return np.ascontiguousarray(np.flip(patch, axis=flip_axis))
-
-    def apply_to_mask(
-        self, mask: np.ndarray, flip_axis: int, **kwargs: Any
-    ) -> np.ndarray:
-        """Apply the transform to the mask.
-
-        Parameters
-        ----------
-        mask : np.ndarray
-            Mask or mask patch, 2D or 3D, shape (y, x, c) or (z, y, x, c).
-        flip_axis : int
-            Axis along which to flip the mask.
-
-        Returns
-        -------
-        np.ndarray
-            Flipped mask.
-        """
-        if len(mask.shape) == 3 and self.is_3D:
-            raise ValueError(
-                "Incompatible mask shape and dimensionality. ZYXC mask shape "
-                "expected, but got YXC shape."
-            )
-
-        return np.ascontiguousarray(np.flip(mask, axis=flip_axis))
-
-    def get_transform_init_args_names(self, **kwargs: Any) -> Tuple[str, ...]:
-        """Get the names of the transform arguments.
-
-        Returns
-        -------
-        Tuple[str, ...]
-            Transform argument names.
-        """
-        return ("is_3D", "flip_z")