quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +32 -0
- quadra/configs/callbacks/default.yaml +37 -0
- quadra/configs/callbacks/default_anomalib.yaml +67 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +46 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +331 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1264 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +579 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +524 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.1.13.dist-info/LICENSE +201 -0
- quadra-2.1.13.dist-info/METADATA +386 -0
- quadra-2.1.13.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
- quadra-2.1.13.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
input_height: 224
|
|
2
|
+
input_width: 224
|
|
3
|
+
mean: [0.485, 0.456, 0.406]
|
|
4
|
+
std: [0.229, 0.224, 0.225]
|
|
5
|
+
|
|
6
|
+
normalize:
|
|
7
|
+
_target_: albumentations.Compose
|
|
8
|
+
transforms:
|
|
9
|
+
- _target_: albumentations.Normalize
|
|
10
|
+
mean: ${transforms.mean}
|
|
11
|
+
std: ${transforms.std}
|
|
12
|
+
always_apply: True
|
|
13
|
+
- _target_: albumentations.pytorch.ToTensorV2
|
|
14
|
+
always_apply: True
|
|
15
|
+
|
|
16
|
+
resize_center_crop:
|
|
17
|
+
_target_: albumentations.Compose
|
|
18
|
+
transforms:
|
|
19
|
+
- _target_: albumentations.Resize
|
|
20
|
+
height: 256
|
|
21
|
+
width: 256
|
|
22
|
+
interpolation: 2
|
|
23
|
+
always_apply: True
|
|
24
|
+
- _target_: albumentations.CenterCrop
|
|
25
|
+
height: ${transforms.input_height}
|
|
26
|
+
width: ${transforms.input_width}
|
|
27
|
+
always_apply: true
|
|
28
|
+
|
|
29
|
+
standard_transform:
|
|
30
|
+
_target_: albumentations.Compose
|
|
31
|
+
transforms:
|
|
32
|
+
- ${transforms.resize_center_crop}
|
|
33
|
+
- ${transforms.normalize}
|
|
34
|
+
|
|
35
|
+
train_transform: ${transforms.standard_transform}
|
|
36
|
+
val_transform: ${transforms.standard_transform}
|
|
37
|
+
test_transform: ${transforms.standard_transform}
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
input_height: 224
|
|
2
|
+
input_width: 224
|
|
3
|
+
|
|
4
|
+
resize_center_crop:
|
|
5
|
+
_target_: albumentations.Compose
|
|
6
|
+
transforms:
|
|
7
|
+
- _target_: albumentations.Resize
|
|
8
|
+
height: 256
|
|
9
|
+
width: 256
|
|
10
|
+
interpolation: 2
|
|
11
|
+
always_apply: True
|
|
12
|
+
- _target_: albumentations.CenterCrop
|
|
13
|
+
height: ${transforms.input_height}
|
|
14
|
+
width: ${transforms.input_width}
|
|
15
|
+
always_apply: true
|
|
16
|
+
|
|
17
|
+
standard_transform:
|
|
18
|
+
_target_: albumentations.Compose
|
|
19
|
+
transforms:
|
|
20
|
+
- ${transforms.resize_center_crop}
|
|
21
|
+
|
|
22
|
+
train_transform: ${transforms.standard_transform}
|
|
23
|
+
val_transform: ${transforms.standard_transform}
|
|
24
|
+
test_transform: ${transforms.standard_transform}
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- default
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
input_height: 224
|
|
6
|
+
input_width: 224
|
|
7
|
+
|
|
8
|
+
standard_transform:
|
|
9
|
+
_target_: albumentations.Compose
|
|
10
|
+
transforms:
|
|
11
|
+
- _target_: albumentations.Resize
|
|
12
|
+
height: ${transforms.input_height}
|
|
13
|
+
width: ${transforms.input_width}
|
|
14
|
+
interpolation: 2
|
|
15
|
+
always_apply: True
|
|
16
|
+
- ${transforms.normalize}
|
|
17
|
+
|
|
18
|
+
train_transform: ${transforms.standard_transform}
|
|
19
|
+
val_transform: ${transforms.standard_transform}
|
|
20
|
+
test_transform: ${transforms.standard_transform}
|
|
21
|
+
|
|
22
|
+
name: default_resize
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- default
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
flip_and_jitter:
|
|
6
|
+
_target_: albumentations.Compose
|
|
7
|
+
transforms:
|
|
8
|
+
- _target_: albumentations.HorizontalFlip
|
|
9
|
+
p: 0.5
|
|
10
|
+
- _target_: albumentations.ColorJitter
|
|
11
|
+
brightness: 0.4
|
|
12
|
+
contrast: 0.4
|
|
13
|
+
saturation: 0.4
|
|
14
|
+
hue: 0.1
|
|
15
|
+
- _target_: albumentations.ToGray
|
|
16
|
+
p: 0.2
|
|
17
|
+
|
|
18
|
+
global_transforms:
|
|
19
|
+
- _target_: albumentations.Compose
|
|
20
|
+
transforms:
|
|
21
|
+
- _target_: albumentations.RandomResizedCrop
|
|
22
|
+
height: ${transforms.input_height}
|
|
23
|
+
width: ${transforms.input_width}
|
|
24
|
+
scale: [0.4, 1.0]
|
|
25
|
+
interpolation: 2
|
|
26
|
+
- ${transforms.flip_and_jitter}
|
|
27
|
+
- _target_: albumentations.GaussianBlur
|
|
28
|
+
blur_limit: 5
|
|
29
|
+
sigma_limit: [0.1, 2]
|
|
30
|
+
p: 1.0
|
|
31
|
+
- ${transforms.normalize}
|
|
32
|
+
|
|
33
|
+
- _target_: albumentations.Compose
|
|
34
|
+
transforms:
|
|
35
|
+
- _target_: albumentations.RandomResizedCrop
|
|
36
|
+
height: ${transforms.input_height}
|
|
37
|
+
width: ${transforms.input_width}
|
|
38
|
+
scale: [0.4, 1.0]
|
|
39
|
+
interpolation: 2
|
|
40
|
+
- ${transforms.flip_and_jitter}
|
|
41
|
+
- _target_: albumentations.GaussianBlur
|
|
42
|
+
blur_limit: 5
|
|
43
|
+
sigma_limit: [0.1, 2]
|
|
44
|
+
p: 0.1
|
|
45
|
+
- _target_: albumentations.Solarize
|
|
46
|
+
threshold: 170
|
|
47
|
+
p: 0.2
|
|
48
|
+
- ${transforms.normalize}
|
|
49
|
+
|
|
50
|
+
local_transform:
|
|
51
|
+
_target_: albumentations.Compose
|
|
52
|
+
transforms:
|
|
53
|
+
- _target_: albumentations.RandomResizedCrop
|
|
54
|
+
height: ${transforms.input_height}
|
|
55
|
+
width: ${transforms.input_width}
|
|
56
|
+
scale: [0.05, 0.4]
|
|
57
|
+
interpolation: 2
|
|
58
|
+
- ${transforms.flip_and_jitter}
|
|
59
|
+
- _target_: albumentations.GaussianBlur
|
|
60
|
+
blur_limit: 5
|
|
61
|
+
sigma_limit: [0.1, 2]
|
|
62
|
+
p: 0.5
|
|
63
|
+
- ${transforms.normalize}
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- default
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
train_transform:
|
|
6
|
+
_target_: albumentations.Compose
|
|
7
|
+
transforms:
|
|
8
|
+
- _target_: albumentations.RandomResizedCrop
|
|
9
|
+
height: ${transforms.input_height}
|
|
10
|
+
width: ${transforms.input_width}
|
|
11
|
+
scale: [0.08, 1.0]
|
|
12
|
+
interpolation: 2
|
|
13
|
+
always_apply: True
|
|
14
|
+
- _target_: albumentations.HorizontalFlip
|
|
15
|
+
p: 0.5
|
|
16
|
+
- ${transforms.normalize}
|
|
17
|
+
val_transform: ${transforms.standard_transform}
|
|
18
|
+
test_transform: ${transforms.standard_transform}
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from .anomaly import AnomalyDataModule
|
|
2
|
+
from .classification import (
|
|
3
|
+
ClassificationDataModule,
|
|
4
|
+
MultilabelClassificationDataModule,
|
|
5
|
+
SklearnClassificationDataModule,
|
|
6
|
+
)
|
|
7
|
+
from .patch import PatchSklearnClassificationDataModule
|
|
8
|
+
from .segmentation import SegmentationDataModule, SegmentationMulticlassDataModule
|
|
9
|
+
from .ssl import SSLDataModule
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"AnomalyDataModule",
|
|
13
|
+
"ClassificationDataModule",
|
|
14
|
+
"SklearnClassificationDataModule",
|
|
15
|
+
"SegmentationDataModule",
|
|
16
|
+
"SegmentationMulticlassDataModule",
|
|
17
|
+
"PatchSklearnClassificationDataModule",
|
|
18
|
+
"MultilabelClassificationDataModule",
|
|
19
|
+
"SSLDataModule",
|
|
20
|
+
]
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import pathlib
|
|
5
|
+
|
|
6
|
+
import albumentations
|
|
7
|
+
import pandas as pd
|
|
8
|
+
from torch.utils.data import DataLoader
|
|
9
|
+
|
|
10
|
+
from quadra.datamodules.base import BaseDataModule
|
|
11
|
+
from quadra.datasets import AnomalyDataset
|
|
12
|
+
from quadra.datasets.anomaly import make_anomaly_dataset
|
|
13
|
+
from quadra.utils import utils
|
|
14
|
+
|
|
15
|
+
log = utils.get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AnomalyDataModule(BaseDataModule):
    """Anomalib-like Lightning Data Module.

    Args:
        data_path: Path to the dataset.
        category: Name of the sub category to use. When given, samples are read from
            ``os.path.join(data_path, category)``, otherwise from ``data_path`` itself.
        image_size: Variable to which image is resized. This class only stores the value.
        train_batch_size: Training batch size.
        test_batch_size: Testing batch size, also used by the validation and predict dataloaders.
        num_workers: Number of workers.
        train_transform: transformations for training. Defaults to None.
        val_transform: transformations for validation. Defaults to None.
        test_transform: transformations for testing. Defaults to None.
        seed: seed used for the random subset splitting
        task: Whether we are interested in segmenting the anomalies (segmentation) or not (classification)
        mask_suffix: String to append to the base filename to get the mask name, by default for MVTec dataset masks
            are saved as imagename_mask.png in this case the parameter should be filled with "_mask"
        create_test_set_if_empty: If True, the test set is created from good images if it is empty.
        phase: Either train or test. Train and validation datasets are only built when the phase is "train"
            and the setup stage is "fit"; the test dataset is built when the stage or the phase is "test".
        name: Name of the data module.
        valid_area_mask: Optional path to the mask to use to filter out the valid area of the image. If None, the whole
            image is considered valid. The mask should match the image size even if the image is cropped.
        crop_area: Optional tuple of 4 integers (x1, y1, x2, y2) to crop the image to the specified area. If None, the
            whole image is considered valid.
        **kwargs: Additional keyword arguments forwarded to ``BaseDataModule``.
    """

    def __init__(
        self,
        data_path: str,
        category: str | None = None,
        image_size: int | tuple[int, int] | None = None,
        train_batch_size: int = 32,
        test_batch_size: int = 32,
        num_workers: int = 8,
        train_transform: albumentations.Compose | None = None,
        val_transform: albumentations.Compose | None = None,
        test_transform: albumentations.Compose | None = None,
        seed: int = 0,
        task: str = "segmentation",
        mask_suffix: str | None = None,
        create_test_set_if_empty: bool = True,
        phase: str = "train",
        name: str = "anomaly_datamodule",
        valid_area_mask: str | None = None,
        crop_area: tuple[int, int, int, int] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            data_path=data_path,
            name=name,
            seed=seed,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            num_workers=num_workers,
            **kwargs,
        )

        # Keep the original root and point ``data_path`` at the category sub-folder when one is requested.
        self.root = data_path
        self.category = category
        self.data_path = os.path.join(self.root, self.category) if self.category is not None else self.root
        self.image_size = image_size

        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.task = task

        # Datasets are declared here and only assigned in ``setup``.
        self.train_dataset: AnomalyDataset
        self.test_dataset: AnomalyDataset
        self.val_dataset: AnomalyDataset
        self.mask_suffix = mask_suffix
        self.create_test_set_if_empty = create_test_set_if_empty
        self.phase = phase
        self.valid_area_mask = valid_area_mask
        self.crop_area = crop_area

    @property
    def val_data(self) -> pd.DataFrame:
        """Get validation data, falling back to the test data when the validation split is empty."""
        _val_data = super().val_data
        if len(_val_data) == 0:
            return self.test_data
        return _val_data

    def _prepare_data(self) -> None:
        """Prepare data for training and testing by building the samples table stored in ``self.data``."""
        self.data = make_anomaly_dataset(
            path=pathlib.Path(self.data_path),
            split=None,
            seed=self.seed,
            mask_suffix=self.mask_suffix,
            create_test_set_if_empty=self.create_test_set_if_empty,
        )

    def setup(self, stage: str | None = None) -> None:
        """Setup data module based on stages of training.

        Args:
            stage: Stage requested by the trainer. "fit" (with phase "train") builds the train and validation
                datasets, "test" (or phase "test") builds the test dataset.
        """
        if stage == "fit" and self.phase == "train":
            self.train_dataset = AnomalyDataset(
                transform=self.train_transform,
                task=self.task,
                samples=self.train_data,
                valid_area_mask=self.valid_area_mask,
                crop_area=self.crop_area,
            )

            # ``self.val_data`` already substitutes the test split when the validation split is empty, so it is
            # empty only when both splits are empty; in that case the whole dataset is used.
            val_samples = self.val_data
            if len(val_samples) == 0:
                log.info("Validation and test datasets are empty, using the whole dataset for validation")
                val_samples = self.data
            elif len(super().val_data) == 0:
                log.info("Validation dataset is empty, using test set instead")

            # NOTE(review): the validation dataset is built with ``test_transform``; ``val_transform`` is not used
            # in this class. Confirm this is intended before changing it.
            self.val_dataset = AnomalyDataset(
                transform=self.test_transform,
                task=self.task,
                samples=val_samples,
                valid_area_mask=self.valid_area_mask,
                crop_area=self.crop_area,
            )
        if stage == "test" or self.phase == "test":
            self.test_dataset = AnomalyDataset(
                transform=self.test_transform,
                task=self.task,
                samples=self.test_data,
                valid_area_mask=self.valid_area_mask,
                crop_area=self.crop_area,
            )

    def train_dataloader(self) -> DataLoader:
        """Get train dataloader (shuffled, ``train_batch_size`` samples per batch)."""
        return DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Get validation dataloader (not shuffled, ``test_batch_size`` samples per batch)."""
        return DataLoader(
            dataset=self.val_dataset,
            shuffle=False,
            batch_size=self.test_batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self) -> DataLoader:
        """Get test dataloader (not shuffled, ``test_batch_size`` samples per batch)."""
        return DataLoader(
            self.test_dataset,
            shuffle=False,
            batch_size=self.test_batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def predict_dataloader(self) -> DataLoader:
        """Returns a dataloader used for predictions, built on the test dataset."""
        return DataLoader(
            self.test_dataset,
            shuffle=False,
            batch_size=self.test_batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )
|