careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. See the package's registry page for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +164 -231
- careamics/config/algorithm_model.py +5 -18
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +11 -4
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -5
- careamics/config/configuration_factory.py +27 -41
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +89 -63
- careamics/config/inference_model.py +28 -81
- careamics/config/optimizer_models.py +11 -11
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -16
- careamics/config/tile_information.py +28 -58
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/config/validators/validator_utils.py +1 -1
- careamics/conftest.py +12 -0
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +6 -11
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +88 -154
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +121 -191
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +109 -39
- careamics/dataset/patching/random_patching.py +17 -6
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/validate_patch_dimension.py +7 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +46 -25
- careamics/lightning_module.py +19 -9
- careamics/lightning_prediction_datamodule.py +54 -84
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +3 -3
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +52 -14
- careamics/transforms/normalize.py +171 -48
- careamics/transforms/pixel_manipulation.py +35 -11
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/tta.py +43 -29
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +4 -2
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
- careamics-0.1.0rc7.dist-info/RECORD +130 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/lightning_prediction_loop.py +0 -116
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -74
- careamics/transforms/nd_flip.py +0 -67
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -13,8 +13,8 @@ from bioimageio.spec.model.v0_5 import (
|
|
|
13
13
|
ChannelAxis,
|
|
14
14
|
EnvironmentFileDescr,
|
|
15
15
|
FileDescr,
|
|
16
|
+
FixedZeroMeanUnitVarianceAlongAxisKwargs,
|
|
16
17
|
FixedZeroMeanUnitVarianceDescr,
|
|
17
|
-
FixedZeroMeanUnitVarianceKwargs,
|
|
18
18
|
Identifier,
|
|
19
19
|
InputTensorDescr,
|
|
20
20
|
ModelDescr,
|
|
@@ -134,44 +134,52 @@ def _create_inputs_ouputs(
|
|
|
134
134
|
output_axes = _create_axes(output_array, data_config, channel_names, False)
|
|
135
135
|
|
|
136
136
|
# mean and std
|
|
137
|
-
assert data_config.
|
|
138
|
-
assert data_config.
|
|
139
|
-
|
|
140
|
-
|
|
137
|
+
assert data_config.image_means is not None, "Mean cannot be None."
|
|
138
|
+
assert data_config.image_means is not None, "Std cannot be None."
|
|
139
|
+
means = data_config.image_means
|
|
140
|
+
stds = data_config.image_stds
|
|
141
141
|
|
|
142
142
|
# and the mean and std required to invert the normalization
|
|
143
143
|
# CAREamics denormalization: x = y * (std + eps) + mean
|
|
144
144
|
# BMZ normalization : x = (y - mean') / (std' + eps)
|
|
145
145
|
# to apply the BMZ normalization as a denormalization step, we need:
|
|
146
146
|
eps = 1e-6
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
)
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
test_tensor=FileDescr(source=output_path),
|
|
165
|
-
postprocessing=[
|
|
166
|
-
FixedZeroMeanUnitVarianceDescr(
|
|
167
|
-
kwargs=FixedZeroMeanUnitVarianceKwargs( # invert normalization
|
|
168
|
-
mean=inv_mean, std=inv_std
|
|
147
|
+
inv_means = []
|
|
148
|
+
inv_stds = []
|
|
149
|
+
if means and stds:
|
|
150
|
+
for mean, std in zip(means, stds):
|
|
151
|
+
inv_means.append(-mean / (std + eps))
|
|
152
|
+
inv_stds.append(1 / (std + eps) - eps)
|
|
153
|
+
|
|
154
|
+
# create input/output descriptions
|
|
155
|
+
input_descr = InputTensorDescr(
|
|
156
|
+
id=TensorId("input"),
|
|
157
|
+
axes=input_axes,
|
|
158
|
+
test_tensor=FileDescr(source=input_path),
|
|
159
|
+
preprocessing=[
|
|
160
|
+
FixedZeroMeanUnitVarianceDescr(
|
|
161
|
+
kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
|
|
162
|
+
mean=means, std=stds, axis="channel"
|
|
163
|
+
)
|
|
169
164
|
)
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
165
|
+
],
|
|
166
|
+
)
|
|
167
|
+
output_descr = OutputTensorDescr(
|
|
168
|
+
id=TensorId("prediction"),
|
|
169
|
+
axes=output_axes,
|
|
170
|
+
test_tensor=FileDescr(source=output_path),
|
|
171
|
+
postprocessing=[
|
|
172
|
+
FixedZeroMeanUnitVarianceDescr(
|
|
173
|
+
kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( # invert norm
|
|
174
|
+
mean=inv_means, std=inv_stds, axis="channel"
|
|
175
|
+
)
|
|
176
|
+
)
|
|
177
|
+
],
|
|
178
|
+
)
|
|
173
179
|
|
|
174
|
-
|
|
180
|
+
return input_descr, output_descr
|
|
181
|
+
else:
|
|
182
|
+
raise ValueError("Mean and std cannot be None.")
|
|
175
183
|
|
|
176
184
|
|
|
177
185
|
def create_model_description(
|
|
@@ -280,7 +288,7 @@ def create_model_description(
|
|
|
280
288
|
"bioimageio": {
|
|
281
289
|
"test_kwargs": {
|
|
282
290
|
"pytorch_state_dict": {
|
|
283
|
-
"decimals":
|
|
291
|
+
"decimals": 0, # ...so we relax the constraints on the decimals
|
|
284
292
|
}
|
|
285
293
|
}
|
|
286
294
|
}
|
careamics/model_io/bmz_io.py
CHANGED
|
@@ -104,9 +104,9 @@ def export_to_bmz(
|
|
|
104
104
|
authors : List[dict]
|
|
105
105
|
Authors of the model.
|
|
106
106
|
input_array : np.ndarray
|
|
107
|
-
Input array.
|
|
107
|
+
Input array, should not have been normalized.
|
|
108
108
|
output_array : np.ndarray
|
|
109
|
-
Output array.
|
|
109
|
+
Output array, should have been denormalized.
|
|
110
110
|
channel_names : Optional[List[str]], optional
|
|
111
111
|
Channel names, by default None.
|
|
112
112
|
data_description : Optional[str], optional
|
|
@@ -178,7 +178,7 @@ def export_to_bmz(
|
|
|
178
178
|
)
|
|
179
179
|
|
|
180
180
|
# test model description
|
|
181
|
-
summary: ValidationSummary = test_model(model_description, decimal=
|
|
181
|
+
summary: ValidationSummary = test_model(model_description, decimal=1)
|
|
182
182
|
if summary.status == "failed":
|
|
183
183
|
raise ValueError(f"Model description test failed: {summary}")
|
|
184
184
|
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
from typing import Tuple, Union
|
|
5
5
|
|
|
6
|
-
|
|
6
|
+
import torch
|
|
7
7
|
|
|
8
8
|
from careamics.config import Configuration
|
|
9
9
|
from careamics.lightning_module import CAREamicsModule
|
|
@@ -64,7 +64,10 @@ def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configura
|
|
|
64
64
|
If the checkpoint file does not contain hyper parameters (configuration).
|
|
65
65
|
"""
|
|
66
66
|
# load checkpoint
|
|
67
|
-
|
|
67
|
+
# here we might run into issues between devices
|
|
68
|
+
# see https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html
|
|
69
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
70
|
+
checkpoint: dict = torch.load(path, map_location=device)
|
|
68
71
|
|
|
69
72
|
# attempt to load configuration
|
|
70
73
|
try:
|
careamics/models/activation.py
CHANGED
careamics/models/layers.py
CHANGED
|
@@ -162,6 +162,18 @@ def _unpack_kernel_size(
|
|
|
162
162
|
"""Unpack kernel_size to a tuple of ints.
|
|
163
163
|
|
|
164
164
|
Inspired by Kornia implementation. TODO: link
|
|
165
|
+
|
|
166
|
+
Parameters
|
|
167
|
+
----------
|
|
168
|
+
kernel_size : Union[Tuple[int, ...], int]
|
|
169
|
+
Kernel size.
|
|
170
|
+
dim : int
|
|
171
|
+
Number of dimensions.
|
|
172
|
+
|
|
173
|
+
Returns
|
|
174
|
+
-------
|
|
175
|
+
Tuple[int, ...]
|
|
176
|
+
Kernel size tuple.
|
|
165
177
|
"""
|
|
166
178
|
if isinstance(kernel_size, int):
|
|
167
179
|
kernel_dims = tuple([kernel_size for _ in range(dim)])
|
|
@@ -173,7 +185,20 @@ def _unpack_kernel_size(
|
|
|
173
185
|
def _compute_zero_padding(
|
|
174
186
|
kernel_size: Union[Tuple[int, ...], int], dim: int
|
|
175
187
|
) -> Tuple[int, ...]:
|
|
176
|
-
"""Utility function that computes zero padding tuple.
|
|
188
|
+
"""Utility function that computes zero padding tuple.
|
|
189
|
+
|
|
190
|
+
Parameters
|
|
191
|
+
----------
|
|
192
|
+
kernel_size : Union[Tuple[int, ...], int]
|
|
193
|
+
Kernel size.
|
|
194
|
+
dim : int
|
|
195
|
+
Number of dimensions.
|
|
196
|
+
|
|
197
|
+
Returns
|
|
198
|
+
-------
|
|
199
|
+
Tuple[int, ...]
|
|
200
|
+
Zero padding tuple.
|
|
201
|
+
"""
|
|
177
202
|
kernel_dims = _unpack_kernel_size(kernel_size, dim)
|
|
178
203
|
return tuple([(kd - 1) // 2 for kd in kernel_dims])
|
|
179
204
|
|
|
@@ -191,14 +216,19 @@ def get_pascal_kernel_1d(
|
|
|
191
216
|
|
|
192
217
|
Parameters
|
|
193
218
|
----------
|
|
194
|
-
kernel_size:
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
219
|
+
kernel_size : int
|
|
220
|
+
Kernel size.
|
|
221
|
+
norm : bool
|
|
222
|
+
Normalize the kernel, by default False.
|
|
223
|
+
device : Optional[torch.device]
|
|
224
|
+
Device of the tensor, by default None.
|
|
225
|
+
dtype : Optional[torch.dtype]
|
|
226
|
+
Data type of the tensor, by default None.
|
|
198
227
|
|
|
199
228
|
Returns
|
|
200
229
|
-------
|
|
201
|
-
|
|
230
|
+
torch.Tensor
|
|
231
|
+
Pascal kernel.
|
|
202
232
|
|
|
203
233
|
Examples
|
|
204
234
|
--------
|
|
@@ -245,19 +275,28 @@ def _get_pascal_kernel_nd(
|
|
|
245
275
|
) -> torch.Tensor:
|
|
246
276
|
"""Generate pascal filter kernel by kernel size.
|
|
247
277
|
|
|
278
|
+
If kernel_size is an integer the kernel will be shaped as (kernel_size, kernel_size)
|
|
279
|
+
otherwise the kernel will be shaped as kernel_size
|
|
280
|
+
|
|
248
281
|
Inspired by Kornia implementation.
|
|
249
282
|
|
|
250
283
|
Parameters
|
|
251
284
|
----------
|
|
252
|
-
kernel_size:
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
285
|
+
kernel_size : Union[Tuple[int, int], int]
|
|
286
|
+
Kernel size for the pascal kernel.
|
|
287
|
+
norm : bool
|
|
288
|
+
Normalize the kernel, by default True.
|
|
289
|
+
dim : int
|
|
290
|
+
Number of dimensions, by default 2.
|
|
291
|
+
device : Optional[torch.device]
|
|
292
|
+
Device of the tensor, by default None.
|
|
293
|
+
dtype : Optional[torch.dtype]
|
|
294
|
+
Data type of the tensor, by default None.
|
|
256
295
|
|
|
257
296
|
Returns
|
|
258
297
|
-------
|
|
259
|
-
|
|
260
|
-
|
|
298
|
+
torch.Tensor
|
|
299
|
+
Pascal kernel.
|
|
261
300
|
|
|
262
301
|
Examples
|
|
263
302
|
--------
|
|
@@ -303,6 +342,24 @@ def _max_blur_pool_by_kernel2d(
|
|
|
303
342
|
"""Compute max_blur_pool by a given :math:`CxC_(out, None)xNxN` kernel.
|
|
304
343
|
|
|
305
344
|
Inspired by Kornia implementation.
|
|
345
|
+
|
|
346
|
+
Parameters
|
|
347
|
+
----------
|
|
348
|
+
x : torch.Tensor
|
|
349
|
+
Input tensor.
|
|
350
|
+
kernel : torch.Tensor
|
|
351
|
+
Kernel tensor.
|
|
352
|
+
stride : int
|
|
353
|
+
Stride.
|
|
354
|
+
max_pool_size : int
|
|
355
|
+
Maximum pool size.
|
|
356
|
+
ceil_mode : bool
|
|
357
|
+
Ceil mode, by default False. Set to True to match output size of conv2d.
|
|
358
|
+
|
|
359
|
+
Returns
|
|
360
|
+
-------
|
|
361
|
+
torch.Tensor
|
|
362
|
+
Output tensor.
|
|
306
363
|
"""
|
|
307
364
|
# compute local maxima
|
|
308
365
|
x = F.max_pool2d(
|
|
@@ -323,6 +380,24 @@ def _max_blur_pool_by_kernel3d(
|
|
|
323
380
|
"""Compute max_blur_pool by a given :math:`CxC_(out, None)xNxNxN` kernel.
|
|
324
381
|
|
|
325
382
|
Inspired by Kornia implementation.
|
|
383
|
+
|
|
384
|
+
Parameters
|
|
385
|
+
----------
|
|
386
|
+
x : torch.Tensor
|
|
387
|
+
Input tensor.
|
|
388
|
+
kernel : torch.Tensor
|
|
389
|
+
Kernel tensor.
|
|
390
|
+
stride : int
|
|
391
|
+
Stride.
|
|
392
|
+
max_pool_size : int
|
|
393
|
+
Maximum pool size.
|
|
394
|
+
ceil_mode : bool
|
|
395
|
+
Ceil mode, by default False. Set to True to match output size of conv2d.
|
|
396
|
+
|
|
397
|
+
Returns
|
|
398
|
+
-------
|
|
399
|
+
torch.Tensor
|
|
400
|
+
Output tensor.
|
|
326
401
|
"""
|
|
327
402
|
# compute local maxima
|
|
328
403
|
x = F.max_pool3d(
|
|
@@ -343,21 +418,16 @@ class MaxBlurPool(nn.Module):
|
|
|
343
418
|
|
|
344
419
|
Parameters
|
|
345
420
|
----------
|
|
346
|
-
dim: int
|
|
347
|
-
Toggles between 2D and 3D
|
|
348
|
-
kernel_size: Union[Tuple[int, int], int]
|
|
421
|
+
dim : int
|
|
422
|
+
Toggles between 2D and 3D.
|
|
423
|
+
kernel_size : Union[Tuple[int, int], int]
|
|
349
424
|
Kernel size for max pooling.
|
|
350
|
-
stride: int
|
|
425
|
+
stride : int
|
|
351
426
|
Stride for pooling.
|
|
352
|
-
max_pool_size: int
|
|
427
|
+
max_pool_size : int
|
|
353
428
|
Max kernel size for max pooling.
|
|
354
|
-
ceil_mode: bool
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
Returns
|
|
358
|
-
-------
|
|
359
|
-
torch.Tensor
|
|
360
|
-
The pooled and blurred tensor.
|
|
429
|
+
ceil_mode : bool
|
|
430
|
+
Ceil mode, by default False. Set to True to match output size of conv2d.
|
|
361
431
|
"""
|
|
362
432
|
|
|
363
433
|
def __init__(
|
|
@@ -368,6 +438,21 @@ class MaxBlurPool(nn.Module):
|
|
|
368
438
|
max_pool_size: int = 2,
|
|
369
439
|
ceil_mode: bool = False,
|
|
370
440
|
) -> None:
|
|
441
|
+
"""Constructor.
|
|
442
|
+
|
|
443
|
+
Parameters
|
|
444
|
+
----------
|
|
445
|
+
dim : int
|
|
446
|
+
Dimension of the convolution.
|
|
447
|
+
kernel_size : Union[Tuple[int, int], int]
|
|
448
|
+
Kernel size for max pooling.
|
|
449
|
+
stride : int, optional
|
|
450
|
+
Stride, by default 2.
|
|
451
|
+
max_pool_size : int, optional
|
|
452
|
+
Maximum pool size, by default 2.
|
|
453
|
+
ceil_mode : bool, optional
|
|
454
|
+
Ceil mode, by default False. Set to True to match output size of conv2d.
|
|
455
|
+
"""
|
|
371
456
|
super().__init__()
|
|
372
457
|
self.dim = dim
|
|
373
458
|
self.kernel_size = kernel_size
|
|
@@ -377,7 +462,18 @@ class MaxBlurPool(nn.Module):
|
|
|
377
462
|
self.kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
|
|
378
463
|
|
|
379
464
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
380
|
-
"""Forward pass of the function.
|
|
465
|
+
"""Forward pass of the function.
|
|
466
|
+
|
|
467
|
+
Parameters
|
|
468
|
+
----------
|
|
469
|
+
x : torch.Tensor
|
|
470
|
+
Input tensor.
|
|
471
|
+
|
|
472
|
+
Returns
|
|
473
|
+
-------
|
|
474
|
+
torch.Tensor
|
|
475
|
+
Output tensor.
|
|
476
|
+
"""
|
|
381
477
|
self.kernel = torch.as_tensor(self.kernel, device=x.device, dtype=x.dtype)
|
|
382
478
|
if self.dim == 2:
|
|
383
479
|
return _max_blur_pool_by_kernel2d(
|
|
File without changes
|