quadra 2.2.7__tar.gz → 2.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {quadra-2.2.7 → quadra-2.3.0}/PKG-INFO +9 -9
- {quadra-2.2.7 → quadra-2.3.0}/pyproject.toml +14 -15
- {quadra-2.2.7 → quadra-2.3.0}/quadra/__init__.py +1 -1
- {quadra-2.2.7 → quadra-2.3.0}/quadra/callbacks/anomalib.py +3 -2
- {quadra-2.2.7 → quadra-2.3.0}/quadra/callbacks/lightning.py +2 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datamodules/base.py +5 -5
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datamodules/classification.py +2 -2
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datamodules/segmentation.py +6 -6
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datasets/anomaly.py +2 -2
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datasets/classification.py +7 -7
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datasets/patch.py +1 -1
- {quadra-2.2.7 → quadra-2.3.0}/quadra/metrics/segmentation.py +1 -1
- {quadra-2.2.7 → quadra-2.3.0}/quadra/models/base.py +1 -1
- {quadra-2.2.7 → quadra-2.3.0}/quadra/models/evaluation.py +1 -1
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/base.py +3 -2
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/ssl/byol.py +1 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/tasks/anomaly.py +7 -4
- {quadra-2.2.7 → quadra-2.3.0}/quadra/tasks/base.py +8 -4
- {quadra-2.2.7 → quadra-2.3.0}/quadra/tasks/classification.py +6 -2
- {quadra-2.2.7 → quadra-2.3.0}/quadra/tasks/patch.py +1 -1
- {quadra-2.2.7 → quadra-2.3.0}/quadra/tasks/segmentation.py +7 -5
- {quadra-2.2.7 → quadra-2.3.0}/quadra/tasks/ssl.py +2 -3
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/classification.py +8 -10
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/evaluation.py +12 -3
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/export.py +4 -4
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/mlflow.py +2 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/models.py +5 -7
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/patch/dataset.py +7 -6
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/patch/metrics.py +9 -6
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/patch/visualization.py +2 -2
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/utils.py +1 -1
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/validator.py +1 -3
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/visualization.py +8 -5
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/vit_explainability.py +1 -1
- {quadra-2.2.7 → quadra-2.3.0}/LICENSE +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/README.md +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/callbacks/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/callbacks/mlflow.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/callbacks/scheduler.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/caformer_m36.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/caformer_s36.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/convnextv2_base.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/convnextv2_femto.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/convnextv2_tiny.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/dino_vitb8.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/dino_vits8.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/dinov2_vitb14.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/dinov2_vits14.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/efficientnet_b0.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/efficientnet_b1.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/efficientnet_b2.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/efficientnet_b3.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/efficientnetv2_s.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/levit_128s.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/mnasnet0_5.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/resnet101.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/resnet18.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/resnet18_ssl.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/resnet50.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/smp.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/tiny_vit_21m_224.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/unetr.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/vit16_base.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/vit16_small.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/vit16_tiny.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/callbacks/all.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/callbacks/default.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/callbacks/default_anomalib.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/config.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/core/default.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/datamodule/base/anomaly.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/datamodule/base/classification.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/datamodule/base/multilabel_classification.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/datamodule/base/segmentation.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/datamodule/base/segmentation_multiclass.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/datamodule/base/sklearn_classification.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/datamodule/base/sklearn_classification_patch.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/datamodule/base/ssl.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/datamodule/generic/imagenette/classification/base.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/anomaly/cfa.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/anomaly/cflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/anomaly/csflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/anomaly/draem.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/anomaly/efficient_ad.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/anomaly/fastflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/anomaly/inference.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/anomaly/padim.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/anomaly/patchcore.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/classification/classification.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/classification/classification_evaluation.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/classification/multilabel_classification.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/classification/sklearn_classification.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/segmentation/smp.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/ssl/barlow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/ssl/byol.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/ssl/dino.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/ssl/linear_eval.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/ssl/simclr.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/base/ssl/simsiam.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/custom/cls.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/default.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/imagenette/classification/default.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/export/default.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/hydra/anomaly_custom.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/hydra/default.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/inference/default.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/logger/comet.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/logger/csv.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/logger/mlflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/logger/tensorboard.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/loss/asl.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/loss/barlow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/loss/bce.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/loss/byol.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/loss/cross_entropy.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/loss/dino.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/loss/simclr.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/loss/simsiam.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/loss/smp_ce.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/loss/smp_dice.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/loss/smp_dice_multiclass.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/loss/smp_mcc.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/loss/vicreg.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/anomalib/cfa.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/anomalib/cflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/anomalib/csflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/anomalib/dfm.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/anomalib/draem.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/anomalib/efficient_ad.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/anomalib/fastflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/anomalib/padim.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/anomalib/patchcore.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/barlow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/byol.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/classification.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/dino.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/logistic_regression.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/multilabel_classification.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/simclr.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/simsiam.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/smp.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/smp_multiclass.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/model/vicreg.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/optimizer/adam.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/optimizer/adamw.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/optimizer/default.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/optimizer/lars.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/optimizer/sgd.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/scheduler/default.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/scheduler/rop.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/scheduler/step.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/scheduler/warmrestart.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/scheduler/warmup.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/anomalib/cfa.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/anomalib/cflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/anomalib/csflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/anomalib/draem.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/anomalib/efficient_ad.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/anomalib/fastflow.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/anomalib/inference.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/anomalib/padim.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/anomalib/patchcore.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/classification.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/classification_evaluation.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/default.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/segmentation.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/segmentation_evaluation.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/sklearn_classification.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/sklearn_classification_patch.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/sklearn_classification_patch_test.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/sklearn_classification_test.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/task/ssl.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/trainer/lightning_cpu.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/trainer/lightning_gpu.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/trainer/lightning_gpu_bf16.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/trainer/lightning_gpu_fp16.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/trainer/lightning_multigpu.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/trainer/sklearn_classification.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/transforms/byol.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/transforms/byol_no_random_resize.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/transforms/default.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/transforms/default_numpy.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/transforms/default_resize.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/transforms/dino.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/configs/transforms/linear_eval.yaml +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datamodules/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datamodules/anomaly.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datamodules/generic/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datamodules/generic/imagenette.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datamodules/generic/mnist.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datamodules/generic/mvtec.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datamodules/generic/oxford_pet.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datamodules/patch.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datamodules/ssl.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datasets/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datasets/segmentation.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/datasets/ssl.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/losses/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/losses/classification/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/losses/classification/asl.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/losses/classification/focal.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/losses/classification/prototypical.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/losses/ssl/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/losses/ssl/barlowtwins.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/losses/ssl/byol.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/losses/ssl/dino.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/losses/ssl/hyperspherical.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/losses/ssl/idmm.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/losses/ssl/simclr.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/losses/ssl/simsiam.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/losses/ssl/vicreg.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/main.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/metrics/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/models/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/models/classification/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/models/classification/backbones.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/models/classification/base.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/backbone.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/classification/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/classification/base.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/ssl/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/ssl/barlowtwins.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/ssl/common.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/ssl/dino.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/ssl/hyperspherical.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/ssl/idmm.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/ssl/simclr.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/ssl/simsiam.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/modules/ssl/vicreg.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/optimizers/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/optimizers/lars.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/optimizers/sam.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/schedulers/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/schedulers/base.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/schedulers/warmup.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/tasks/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/trainers/README.md +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/trainers/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/trainers/classification.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/anomaly.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/deprecation.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/imaging.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/logger.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/model_manager.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/patch/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/patch/model.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/resolver.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/segmentation.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/tests/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/tests/fixtures/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/tests/fixtures/dataset/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/tests/fixtures/dataset/anomaly.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/tests/fixtures/dataset/classification.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/tests/fixtures/dataset/imagenette.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/tests/fixtures/dataset/segmentation.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/tests/fixtures/models/__init__.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/tests/fixtures/models/anomaly.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/tests/fixtures/models/classification.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/tests/fixtures/models/segmentation.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/tests/helpers.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra/utils/tests/models.py +0 -0
- {quadra-2.2.7 → quadra-2.3.0}/quadra_hydra_plugin/hydra_plugins/quadra_searchpath_plugin.py +0 -0
|
@@ -1,19 +1,18 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: quadra
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.3.0
|
|
4
4
|
Summary: Deep Learning experiment orchestration library
|
|
5
5
|
Home-page: https://orobix.github.io/quadra
|
|
6
6
|
License: Apache-2.0
|
|
7
7
|
Keywords: deep learning,experiment,lightning,hydra-core
|
|
8
8
|
Author: Federico Belotti
|
|
9
9
|
Author-email: federico.belotti@orobix.com
|
|
10
|
-
Requires-Python: >=3.
|
|
10
|
+
Requires-Python: >=3.10,<3.11
|
|
11
11
|
Classifier: Intended Audience :: Developers
|
|
12
12
|
Classifier: Intended Audience :: Education
|
|
13
13
|
Classifier: Intended Audience :: Science/Research
|
|
14
14
|
Classifier: License :: OSI Approved :: Apache Software License
|
|
15
15
|
Classifier: Programming Language :: Python :: 3
|
|
16
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
17
16
|
Classifier: Programming Language :: Python :: 3.10
|
|
18
17
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
18
|
Classifier: Topic :: Software Development
|
|
@@ -21,7 +20,7 @@ Classifier: Topic :: Software Development :: Libraries
|
|
|
21
20
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
22
21
|
Provides-Extra: onnx
|
|
23
22
|
Requires-Dist: albumentations (>=1.3,<1.4)
|
|
24
|
-
Requires-Dist: anomalib-orobix (==0.7.0.
|
|
23
|
+
Requires-Dist: anomalib-orobix (==0.7.0.dev150)
|
|
25
24
|
Requires-Dist: boto3 (>=1.26,<1.27)
|
|
26
25
|
Requires-Dist: grad-cam-orobix (==1.5.3.dev001)
|
|
27
26
|
Requires-Dist: h5py (>=3.8,<3.9)
|
|
@@ -32,17 +31,18 @@ Requires-Dist: label_studio_converter (>=0.0,<0.1)
|
|
|
32
31
|
Requires-Dist: matplotlib (>=3.6,<3.7)
|
|
33
32
|
Requires-Dist: minio (>=7.1,<7.2)
|
|
34
33
|
Requires-Dist: mlflow-skinny (>=2.3.1,<3.0.0)
|
|
34
|
+
Requires-Dist: numpy (<2)
|
|
35
35
|
Requires-Dist: nvitop (>=0.11,<0.12)
|
|
36
36
|
Requires-Dist: onnx (==1.15.0) ; extra == "onnx"
|
|
37
37
|
Requires-Dist: onnxconverter-common (>=1.14.0,<2.0.0) ; extra == "onnx"
|
|
38
|
-
Requires-Dist: onnxruntime_gpu (==1.
|
|
38
|
+
Requires-Dist: onnxruntime_gpu (==1.20.0) ; extra == "onnx"
|
|
39
39
|
Requires-Dist: onnxsim (==0.4.28) ; extra == "onnx"
|
|
40
40
|
Requires-Dist: opencv_python_headless (>=4.7.0,<4.8.0)
|
|
41
41
|
Requires-Dist: pandas (<2.0)
|
|
42
|
-
Requires-Dist: pillow (>=
|
|
42
|
+
Requires-Dist: pillow (>=10,<11)
|
|
43
43
|
Requires-Dist: pydantic (==1.10.10)
|
|
44
44
|
Requires-Dist: python_dotenv (>=0.21,<0.22)
|
|
45
|
-
Requires-Dist: pytorch_lightning (>=2.
|
|
45
|
+
Requires-Dist: pytorch_lightning (>=2.4,<2.5)
|
|
46
46
|
Requires-Dist: rich (>=13.2,<13.3)
|
|
47
47
|
Requires-Dist: scikit_learn (>=1.2,<1.3)
|
|
48
48
|
Requires-Dist: scikit_multilearn (>=0.2,<0.3)
|
|
@@ -50,11 +50,11 @@ Requires-Dist: seaborn (>=0.12,<0.13)
|
|
|
50
50
|
Requires-Dist: segmentation_models_pytorch-orobix (==0.3.3.dev1)
|
|
51
51
|
Requires-Dist: tensorboard (>=2.11,<2.12)
|
|
52
52
|
Requires-Dist: timm (==0.9.12)
|
|
53
|
-
Requires-Dist: torch (==2.1
|
|
53
|
+
Requires-Dist: torch (==2.4.1)
|
|
54
54
|
Requires-Dist: torchinfo (>=1.8,<1.9)
|
|
55
55
|
Requires-Dist: torchmetrics (>=0.10,<0.11)
|
|
56
56
|
Requires-Dist: torchsummary (>=1.5,<1.6)
|
|
57
|
-
Requires-Dist: torchvision (>=0.
|
|
57
|
+
Requires-Dist: torchvision (>=0.19,<0.20)
|
|
58
58
|
Requires-Dist: tripy (>=1.0,<1.1)
|
|
59
59
|
Requires-Dist: typing_extensions (==4.11.0) ; python_version < "3.10"
|
|
60
60
|
Requires-Dist: xxhash (>=3.2,<3.3)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "quadra"
|
|
3
|
-
version = "2.
|
|
3
|
+
version = "2.3.0"
|
|
4
4
|
description = "Deep Learning experiment orchestration library"
|
|
5
5
|
authors = [
|
|
6
6
|
"Federico Belotti <federico.belotti@orobix.com>",
|
|
@@ -38,12 +38,13 @@ build-backend = "poetry.core.masonry.api"
|
|
|
38
38
|
quadra = "quadra.main:main"
|
|
39
39
|
|
|
40
40
|
[tool.poetry.dependencies]
|
|
41
|
-
python = ">=3.
|
|
41
|
+
python = ">=3.10,<3.11"
|
|
42
42
|
|
|
43
|
-
torch = { version = "2.1
|
|
44
|
-
torchvision = { version = "~0.
|
|
43
|
+
torch = { version = "2.4.1", source = "torch_cu121" }
|
|
44
|
+
torchvision = { version = "~0.19", source = "torch_cu121" }
|
|
45
45
|
|
|
46
|
-
pytorch_lightning = "~2.
|
|
46
|
+
pytorch_lightning = "~2.4"
|
|
47
|
+
numpy = "<2"
|
|
47
48
|
torchsummary = "~1.5"
|
|
48
49
|
torchmetrics = "~0.10"
|
|
49
50
|
hydra_core = "~1.3"
|
|
@@ -53,7 +54,7 @@ mlflow-skinny = "^2.3.1"
|
|
|
53
54
|
boto3 = "~1.26"
|
|
54
55
|
minio = "~7.1"
|
|
55
56
|
tensorboard = "~2.11"
|
|
56
|
-
pillow = "
|
|
57
|
+
pillow = "^10"
|
|
57
58
|
pandas = "<2.0"
|
|
58
59
|
opencv_python_headless = "~4.7.0"
|
|
59
60
|
python_dotenv = "~0.21"
|
|
@@ -72,7 +73,7 @@ h5py = "~3.8"
|
|
|
72
73
|
timm = "0.9.12"
|
|
73
74
|
|
|
74
75
|
segmentation_models_pytorch-orobix = "0.3.3.dev1"
|
|
75
|
-
anomalib-orobix = "0.7.0.
|
|
76
|
+
anomalib-orobix = "0.7.0.dev150"
|
|
76
77
|
xxhash = "~3.2"
|
|
77
78
|
torchinfo = "~1.8"
|
|
78
79
|
typing_extensions = { version = "4.11.0", python = "<3.10" }
|
|
@@ -80,7 +81,7 @@ typing_extensions = { version = "4.11.0", python = "<3.10" }
|
|
|
80
81
|
# ONNX dependencies
|
|
81
82
|
onnx = { version = "1.15.0", optional = true }
|
|
82
83
|
onnxsim = { version = "0.4.28", optional = true }
|
|
83
|
-
onnxruntime_gpu = { version = "1.
|
|
84
|
+
onnxruntime_gpu = { version = "1.20.0", optional = true }
|
|
84
85
|
onnxconverter-common = { version = "^1.14.0", optional = true }
|
|
85
86
|
|
|
86
87
|
[[tool.poetry.source]]
|
|
@@ -88,10 +89,6 @@ name = "torch_cu121"
|
|
|
88
89
|
url = "https://download.pytorch.org/whl/cu121"
|
|
89
90
|
priority = "explicit"
|
|
90
91
|
|
|
91
|
-
[[tool.poetry.source]]
|
|
92
|
-
name = "onnx_cu12"
|
|
93
|
-
url = "https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/"
|
|
94
|
-
priority = "explicit"
|
|
95
92
|
|
|
96
93
|
[tool.poetry.group.dev]
|
|
97
94
|
optional = true
|
|
@@ -100,12 +97,14 @@ optional = true
|
|
|
100
97
|
hydra-plugins = { path = "quadra_hydra_plugin" }
|
|
101
98
|
# Dev dependencies
|
|
102
99
|
interrogate = "~1.5"
|
|
103
|
-
pre_commit = "
|
|
104
|
-
pylint = "
|
|
100
|
+
pre_commit = "^3.0"
|
|
101
|
+
pylint = "^3.3"
|
|
105
102
|
types_pyyaml = "~6.0.12"
|
|
106
103
|
mypy = "^1.9.0"
|
|
107
104
|
pandas_stubs = "~1.5.3"
|
|
108
105
|
twine = "~4.0"
|
|
106
|
+
ipython = ">8"
|
|
107
|
+
ipykernel = ">6"
|
|
109
108
|
|
|
110
109
|
|
|
111
110
|
# Test dependencies
|
|
@@ -229,7 +228,7 @@ exclude = ["quadra/utils/tests", "tests"]
|
|
|
229
228
|
|
|
230
229
|
[tool.ruff]
|
|
231
230
|
extend-include = ["*.ipynb"]
|
|
232
|
-
target-version = "
|
|
231
|
+
target-version = "py310"
|
|
233
232
|
# Orobix guidelines
|
|
234
233
|
line-length = 120
|
|
235
234
|
indent-width = 4
|
|
@@ -64,7 +64,7 @@ class Visualizer:
|
|
|
64
64
|
self.figure.subplots_adjust(right=0.9)
|
|
65
65
|
|
|
66
66
|
axes = self.axis if len(self.images) > 1 else [self.axis]
|
|
67
|
-
for axis, image_dict in zip(axes, self.images):
|
|
67
|
+
for axis, image_dict in zip(axes, self.images, strict=False):
|
|
68
68
|
axis.axes.xaxis.set_visible(False)
|
|
69
69
|
axis.axes.yaxis.set_visible(False)
|
|
70
70
|
axis.imshow(image_dict["image"], image_dict["color_map"], vmin=0, vmax=255)
|
|
@@ -201,6 +201,7 @@ class VisualizerCallback(Callback):
|
|
|
201
201
|
outputs["label"],
|
|
202
202
|
outputs["pred_labels"],
|
|
203
203
|
outputs["pred_scores"],
|
|
204
|
+
strict=False,
|
|
204
205
|
)
|
|
205
206
|
):
|
|
206
207
|
denormalized_image = Denormalize()(image.cpu())
|
|
@@ -256,7 +257,7 @@ class VisualizerCallback(Callback):
|
|
|
256
257
|
visualizer.close()
|
|
257
258
|
|
|
258
259
|
if self.plot_raw_outputs:
|
|
259
|
-
for raw_output, raw_name in zip([heatmap, vis_img], ["heatmap", "segmentation"]):
|
|
260
|
+
for raw_output, raw_name in zip([heatmap, vis_img], ["heatmap", "segmentation"], strict=False):
|
|
260
261
|
current_raw_output = raw_output
|
|
261
262
|
if raw_name == "segmentation":
|
|
262
263
|
current_raw_output = (raw_output * 255).astype(np.uint8)
|
|
@@ -79,6 +79,8 @@ def _scale_batch_size(
|
|
|
79
79
|
new_size = _run_power_scaling(trainer, init_val, batch_arg_name, max_trials, params)
|
|
80
80
|
elif mode == "binsearch":
|
|
81
81
|
new_size = _run_binary_scaling(trainer, init_val, batch_arg_name, max_trials, params)
|
|
82
|
+
else:
|
|
83
|
+
raise ValueError(f"Unknown mode {mode}")
|
|
82
84
|
|
|
83
85
|
garbage_collection_cuda()
|
|
84
86
|
|
|
@@ -7,7 +7,7 @@ import pickle as pkl
|
|
|
7
7
|
import typing
|
|
8
8
|
from collections.abc import Callable, Iterable, Sequence
|
|
9
9
|
from functools import wraps
|
|
10
|
-
from typing import Any, Literal,
|
|
10
|
+
from typing import Any, Literal, cast
|
|
11
11
|
|
|
12
12
|
import albumentations
|
|
13
13
|
import numpy as np
|
|
@@ -20,8 +20,8 @@ from tqdm import tqdm
|
|
|
20
20
|
from quadra.utils import utils
|
|
21
21
|
|
|
22
22
|
log = utils.get_logger(__name__)
|
|
23
|
-
TrainDataset =
|
|
24
|
-
ValDataset =
|
|
23
|
+
TrainDataset = torch.utils.data.Dataset | Sequence[torch.utils.data.Dataset]
|
|
24
|
+
ValDataset = torch.utils.data.Dataset | Sequence[torch.utils.data.Dataset]
|
|
25
25
|
TestDataset = torch.utils.data.Dataset
|
|
26
26
|
|
|
27
27
|
|
|
@@ -260,7 +260,7 @@ class BaseDataModule(LightningDataModule, metaclass=DecorateParentMethod):
|
|
|
260
260
|
return
|
|
261
261
|
|
|
262
262
|
# TODO: We need to find a way to annotate the columns of data.
|
|
263
|
-
paths_and_hash_length = zip(self.data["samples"], [self.hash_size] * len(self.data))
|
|
263
|
+
paths_and_hash_length = zip(self.data["samples"], [self.hash_size] * len(self.data), strict=False)
|
|
264
264
|
|
|
265
265
|
with mp.Pool(min(8, mp.cpu_count() - 1)) as pool:
|
|
266
266
|
self.data["hash"] = list(
|
|
@@ -355,7 +355,7 @@ class BaseDataModule(LightningDataModule, metaclass=DecorateParentMethod):
|
|
|
355
355
|
raise ValueError("`n_aug_to_take` is not set. Cannot load augmented samples.")
|
|
356
356
|
aug_samples = []
|
|
357
357
|
aug_labels = []
|
|
358
|
-
for sample, label in zip(samples, targets):
|
|
358
|
+
for sample, label in zip(samples, targets, strict=False):
|
|
359
359
|
aug_samples.append(sample)
|
|
360
360
|
aug_labels.append(label)
|
|
361
361
|
final_sample = sample
|
|
@@ -243,7 +243,7 @@ class ClassificationDataModule(BaseDataModule):
|
|
|
243
243
|
samples_test, targets_test = self._read_split(self.test_split_file)
|
|
244
244
|
if not self.train_split_file:
|
|
245
245
|
samples_train, targets_train = [], []
|
|
246
|
-
for sample, target in zip(all_samples, all_targets):
|
|
246
|
+
for sample, target in zip(all_samples, all_targets, strict=False):
|
|
247
247
|
if sample not in samples_test:
|
|
248
248
|
samples_train.append(sample)
|
|
249
249
|
targets_train.append(target)
|
|
@@ -251,7 +251,7 @@ class ClassificationDataModule(BaseDataModule):
|
|
|
251
251
|
samples_train, targets_train = self._read_split(self.train_split_file)
|
|
252
252
|
if not self.test_split_file:
|
|
253
253
|
samples_test, targets_test = [], []
|
|
254
|
-
for sample, target in zip(all_samples, all_targets):
|
|
254
|
+
for sample, target in zip(all_samples, all_targets, strict=False):
|
|
255
255
|
if sample not in samples_train:
|
|
256
256
|
samples_test.append(sample)
|
|
257
257
|
targets_test.append(target)
|
|
@@ -187,7 +187,7 @@ class SegmentationDataModule(BaseDataModule):
|
|
|
187
187
|
samples_test, targets_test, masks_test = self._read_split(self.test_split_file)
|
|
188
188
|
if not self.train_split_file:
|
|
189
189
|
samples_train, targets_train, masks_train = [], [], []
|
|
190
|
-
for sample, target, mask in zip(all_samples, all_targets, all_masks):
|
|
190
|
+
for sample, target, mask in zip(all_samples, all_targets, all_masks, strict=False):
|
|
191
191
|
if sample not in samples_test:
|
|
192
192
|
samples_train.append(sample)
|
|
193
193
|
targets_train.append(target)
|
|
@@ -197,7 +197,7 @@ class SegmentationDataModule(BaseDataModule):
|
|
|
197
197
|
samples_train, targets_train, masks_train = self._read_split(self.train_split_file)
|
|
198
198
|
if not self.test_split_file:
|
|
199
199
|
samples_test, targets_test, masks_test = [], [], []
|
|
200
|
-
for sample, target, mask in zip(all_samples, all_targets, all_masks):
|
|
200
|
+
for sample, target, mask in zip(all_samples, all_targets, all_masks, strict=False):
|
|
201
201
|
if sample not in samples_train:
|
|
202
202
|
samples_test.append(sample)
|
|
203
203
|
targets_test.append(target)
|
|
@@ -549,7 +549,7 @@ class SegmentationMulticlassDataModule(BaseDataModule):
|
|
|
549
549
|
samples_and_masks_test,
|
|
550
550
|
targets_test,
|
|
551
551
|
) = iterative_train_test_split(
|
|
552
|
-
np.expand_dims(np.array(list(zip(all_samples, all_masks))), 1),
|
|
552
|
+
np.expand_dims(np.array(list(zip(all_samples, all_masks, strict=False))), 1),
|
|
553
553
|
np.array(all_targets),
|
|
554
554
|
test_size=self.test_size,
|
|
555
555
|
)
|
|
@@ -561,7 +561,7 @@ class SegmentationMulticlassDataModule(BaseDataModule):
|
|
|
561
561
|
samples_test, targets_test, masks_test = self._read_split(self.test_split_file)
|
|
562
562
|
if not self.train_split_file:
|
|
563
563
|
samples_train, targets_train, masks_train = [], [], []
|
|
564
|
-
for sample, target, mask in zip(all_samples, all_targets, all_masks):
|
|
564
|
+
for sample, target, mask in zip(all_samples, all_targets, all_masks, strict=False):
|
|
565
565
|
if sample not in samples_test:
|
|
566
566
|
samples_train.append(sample)
|
|
567
567
|
targets_train.append(target)
|
|
@@ -571,7 +571,7 @@ class SegmentationMulticlassDataModule(BaseDataModule):
|
|
|
571
571
|
samples_train, targets_train, masks_train = self._read_split(self.train_split_file)
|
|
572
572
|
if not self.test_split_file:
|
|
573
573
|
samples_test, targets_test, masks_test = [], [], []
|
|
574
|
-
for sample, target, mask in zip(all_samples, all_targets, all_masks):
|
|
574
|
+
for sample, target, mask in zip(all_samples, all_targets, all_masks, strict=False):
|
|
575
575
|
if sample not in samples_train:
|
|
576
576
|
samples_test.append(sample)
|
|
577
577
|
targets_test.append(target)
|
|
@@ -583,7 +583,7 @@ class SegmentationMulticlassDataModule(BaseDataModule):
|
|
|
583
583
|
raise ValueError("Validation split file is specified but no train or test split file is specified.")
|
|
584
584
|
else:
|
|
585
585
|
samples_and_masks_train, targets_train, samples_and_masks_val, targets_val = iterative_train_test_split(
|
|
586
|
-
np.expand_dims(np.array(list(zip(samples_train, masks_train))), 1),
|
|
586
|
+
np.expand_dims(np.array(list(zip(samples_train, masks_train, strict=False))), 1),
|
|
587
587
|
np.array(targets_train),
|
|
588
588
|
test_size=self.val_size,
|
|
589
589
|
)
|
|
@@ -220,7 +220,7 @@ class AnomalyDataset(Dataset):
|
|
|
220
220
|
if not os.path.exists(valid_area_mask):
|
|
221
221
|
raise RuntimeError(f"Valid area mask {valid_area_mask} does not exist.")
|
|
222
222
|
|
|
223
|
-
self.valid_area_mask = cv2.imread(valid_area_mask, 0) > 0
|
|
223
|
+
self.valid_area_mask = cv2.imread(valid_area_mask, 0) > 0
|
|
224
224
|
|
|
225
225
|
def __len__(self) -> int:
|
|
226
226
|
"""Get length of the dataset."""
|
|
@@ -265,7 +265,7 @@ class AnomalyDataset(Dataset):
|
|
|
265
265
|
if label_index == 0:
|
|
266
266
|
mask = np.zeros(shape=original_image_shape[:2])
|
|
267
267
|
elif os.path.isfile(mask_path):
|
|
268
|
-
mask = cv2.imread(mask_path, flags=0) / 255.0
|
|
268
|
+
mask = cv2.imread(mask_path, flags=0) / 255.0
|
|
269
269
|
else:
|
|
270
270
|
# We need ones in the mask to compute correctly at least image level f1 score
|
|
271
271
|
mask = np.ones(shape=original_image_shape[:2])
|
|
@@ -50,9 +50,9 @@ class ImageClassificationListDataset(Dataset):
|
|
|
50
50
|
allow_missing_label: bool | None = False,
|
|
51
51
|
):
|
|
52
52
|
super().__init__()
|
|
53
|
-
assert len(samples) == len(
|
|
54
|
-
targets
|
|
55
|
-
)
|
|
53
|
+
assert len(samples) == len(targets), (
|
|
54
|
+
f"Samples ({len(samples)}) and targets ({len(targets)}) must have the same length"
|
|
55
|
+
)
|
|
56
56
|
# Setting the ROI
|
|
57
57
|
self.roi = roi
|
|
58
58
|
|
|
@@ -201,9 +201,9 @@ class MultilabelClassificationDataset(torch.utils.data.Dataset):
|
|
|
201
201
|
rgb: bool = True,
|
|
202
202
|
):
|
|
203
203
|
super().__init__()
|
|
204
|
-
assert len(samples) == len(
|
|
205
|
-
targets
|
|
206
|
-
)
|
|
204
|
+
assert len(samples) == len(targets), (
|
|
205
|
+
f"Samples ({len(samples)}) and targets ({len(targets)}) must have the same length"
|
|
206
|
+
)
|
|
207
207
|
|
|
208
208
|
# Data
|
|
209
209
|
self.x = samples
|
|
@@ -215,7 +215,7 @@ class MultilabelClassificationDataset(torch.utils.data.Dataset):
|
|
|
215
215
|
class_to_idx = {c: i for i, c in enumerate(range(unique_targets))}
|
|
216
216
|
self.class_to_idx = class_to_idx
|
|
217
217
|
self.idx_to_class = {v: k for k, v in class_to_idx.items()}
|
|
218
|
-
self.samples = list(zip(self.x, self.y))
|
|
218
|
+
self.samples = list(zip(self.x, self.y, strict=False))
|
|
219
219
|
self.rgb = rgb
|
|
220
220
|
self.transform = transform
|
|
221
221
|
|
|
@@ -58,7 +58,7 @@ class PatchSklearnClassificationTrainDataset(Dataset):
|
|
|
58
58
|
|
|
59
59
|
cls, counts = np.unique(targets_array, return_counts=True)
|
|
60
60
|
max_count = np.max(counts)
|
|
61
|
-
for cl, count in zip(cls, counts):
|
|
61
|
+
for cl, count in zip(cls, counts, strict=False):
|
|
62
62
|
idx_to_pick = list(np.where(targets_array == cl)[0])
|
|
63
63
|
|
|
64
64
|
if count < max_count:
|
|
@@ -171,7 +171,7 @@ def segmentation_props(
|
|
|
171
171
|
# Add dummy Dices so LSA is unique and i can compute FP and FN
|
|
172
172
|
dice_mat = _pad_to_shape(dice_mat, (max_dim, max_dim), 1)
|
|
173
173
|
lsa = linear_sum_assignment(dice_mat, maximize=False)
|
|
174
|
-
for row, col in zip(lsa[0], lsa[1]):
|
|
174
|
+
for row, col in zip(lsa[0], lsa[1], strict=False):
|
|
175
175
|
# More preds than GTs --> False Positive
|
|
176
176
|
if row < n_labels_pred and col >= n_labels_mask:
|
|
177
177
|
min_row = pred_bbox[row][0]
|
|
@@ -76,7 +76,7 @@ class ModelSignatureWrapper(nn.Module):
|
|
|
76
76
|
|
|
77
77
|
if isinstance(self.instance.forward, torch.ScriptMethod):
|
|
78
78
|
# Handle torchscript backbones
|
|
79
|
-
for i, argument in enumerate(self.instance.forward.schema.arguments):
|
|
79
|
+
for i, argument in enumerate(self.instance.forward.schema.arguments): # type: ignore[attr-defined]
|
|
80
80
|
if i < (len(args) + 1): # +1 for self
|
|
81
81
|
continue
|
|
82
82
|
|
|
@@ -209,7 +209,7 @@ class ONNXEvaluationModel(BaseEvaluationModel):
|
|
|
209
209
|
|
|
210
210
|
onnx_inputs: dict[str, np.ndarray | torch.Tensor] = {}
|
|
211
211
|
|
|
212
|
-
for onnx_input, current_input in zip(self.model.get_inputs(), inputs):
|
|
212
|
+
for onnx_input, current_input in zip(self.model.get_inputs(), inputs, strict=False):
|
|
213
213
|
if isinstance(current_input, torch.Tensor):
|
|
214
214
|
onnx_inputs[onnx_input.name] = current_input
|
|
215
215
|
use_pytorch = True
|
|
@@ -7,6 +7,7 @@ import pytorch_lightning as pl
|
|
|
7
7
|
import sklearn
|
|
8
8
|
import torch
|
|
9
9
|
import torchmetrics
|
|
10
|
+
from pytorch_lightning.utilities.types import OptimizerLRScheduler
|
|
10
11
|
from sklearn.linear_model import LogisticRegression
|
|
11
12
|
from torch import nn
|
|
12
13
|
from torch.optim import Optimizer
|
|
@@ -48,7 +49,7 @@ class BaseLightningModule(pl.LightningModule):
|
|
|
48
49
|
"""
|
|
49
50
|
return self.model(x)
|
|
50
51
|
|
|
51
|
-
def configure_optimizers(self) ->
|
|
52
|
+
def configure_optimizers(self) -> OptimizerLRScheduler:
|
|
52
53
|
"""Get default optimizer if not passed a value.
|
|
53
54
|
|
|
54
55
|
Returns:
|
|
@@ -68,7 +69,7 @@ class BaseLightningModule(pl.LightningModule):
|
|
|
68
69
|
"monitor": "val_loss",
|
|
69
70
|
"strict": False,
|
|
70
71
|
}
|
|
71
|
-
return [self.optimizer], [lr_scheduler_conf]
|
|
72
|
+
return [self.optimizer], [lr_scheduler_conf] # type: ignore[return-value]
|
|
72
73
|
|
|
73
74
|
# pylint: disable=unused-argument
|
|
74
75
|
def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx: int = 0):
|
|
@@ -110,6 +110,7 @@ class BYOL(SSLModule):
|
|
|
110
110
|
for student_ps, teacher_ps in zip(
|
|
111
111
|
list(self.model.parameters()) + list(self.student_projection_mlp.parameters()),
|
|
112
112
|
list(self.teacher.parameters()) + list(self.teacher_projection_mlp.parameters()),
|
|
113
|
+
strict=False,
|
|
113
114
|
):
|
|
114
115
|
teacher_ps.data = teacher_ps.data * teacher_momentum + (1 - teacher_momentum) * student_ps.data
|
|
115
116
|
|
|
@@ -161,7 +161,7 @@ class AnomalibDetection(Generic[AnomalyDataModuleT], LightningTask[AnomalyDataMo
|
|
|
161
161
|
all_output_flatten: dict[str, torch.Tensor | list] = {}
|
|
162
162
|
|
|
163
163
|
for key in all_output[0]:
|
|
164
|
-
if
|
|
164
|
+
if isinstance(all_output[0][key], torch.Tensor):
|
|
165
165
|
tensor_gatherer = torch.cat([x[key] for x in all_output])
|
|
166
166
|
all_output_flatten[key] = tensor_gatherer
|
|
167
167
|
else:
|
|
@@ -205,13 +205,15 @@ class AnomalibDetection(Generic[AnomalyDataModuleT], LightningTask[AnomalyDataMo
|
|
|
205
205
|
class_to_idx.pop("false_defect")
|
|
206
206
|
|
|
207
207
|
anomaly_scores = all_output_flatten["pred_scores"]
|
|
208
|
+
|
|
209
|
+
exportable_anomaly_scores: list[Any] | np.ndarray
|
|
208
210
|
if isinstance(anomaly_scores, torch.Tensor):
|
|
209
211
|
exportable_anomaly_scores = anomaly_scores.cpu().numpy()
|
|
210
212
|
else:
|
|
211
213
|
exportable_anomaly_scores = anomaly_scores
|
|
212
214
|
|
|
213
215
|
# Zip the lists together to create rows for the CSV file
|
|
214
|
-
rows = zip(image_paths, pred_labels, gt_labels, exportable_anomaly_scores)
|
|
216
|
+
rows = zip(image_paths, pred_labels, gt_labels, exportable_anomaly_scores, strict=False)
|
|
215
217
|
# Specify the CSV file name
|
|
216
218
|
csv_file = "test_predictions.csv"
|
|
217
219
|
# Write the data to the CSV file
|
|
@@ -483,7 +485,7 @@ class AnomalibEvaluation(Evaluation[AnomalyDataModule]):
|
|
|
483
485
|
|
|
484
486
|
if hasattr(self.datamodule, "valid_area_mask") and self.datamodule.valid_area_mask is not None:
|
|
485
487
|
mask_area = cv2.imread(self.datamodule.valid_area_mask, 0)
|
|
486
|
-
mask_area = (mask_area > 0).astype(np.uint8)
|
|
488
|
+
mask_area = (mask_area > 0).astype(np.uint8)
|
|
487
489
|
|
|
488
490
|
if hasattr(self.datamodule, "crop_area") and self.datamodule.crop_area is not None:
|
|
489
491
|
crop_area = self.datamodule.crop_area
|
|
@@ -499,12 +501,13 @@ class AnomalibEvaluation(Evaluation[AnomalyDataModule]):
|
|
|
499
501
|
self.metadata["image_labels"],
|
|
500
502
|
anomaly_scores,
|
|
501
503
|
anomaly_maps,
|
|
504
|
+
strict=False,
|
|
502
505
|
),
|
|
503
506
|
total=len(self.metadata["image_paths"]),
|
|
504
507
|
):
|
|
505
508
|
img = cv2.imread(img_path, 0)
|
|
506
509
|
if mask_area is not None:
|
|
507
|
-
img = img * mask_area
|
|
510
|
+
img = img * mask_area
|
|
508
511
|
|
|
509
512
|
if crop_area is not None:
|
|
510
513
|
img = img[crop_area[1] : crop_area[3], crop_area[0] : crop_area[2]]
|
|
@@ -382,15 +382,19 @@ class Evaluation(Generic[DataModuleT], Task[DataModuleT]):
|
|
|
382
382
|
# We assume that each input size has the same height and width
|
|
383
383
|
if input_size[1] != self.config.transforms.input_height:
|
|
384
384
|
log.warning(
|
|
385
|
-
|
|
386
|
-
+
|
|
385
|
+
"Input height of the model (%s) is different from the one specified "
|
|
386
|
+
+ "in the config (%s). Fixing the config.",
|
|
387
|
+
input_size[1],
|
|
388
|
+
self.config.transforms.input_height,
|
|
387
389
|
)
|
|
388
390
|
self.config.transforms.input_height = input_size[1]
|
|
389
391
|
|
|
390
392
|
if input_size[2] != self.config.transforms.input_width:
|
|
391
393
|
log.warning(
|
|
392
|
-
|
|
393
|
-
+
|
|
394
|
+
"Input width of the model (%s) is different from the one specified "
|
|
395
|
+
+ "in the config (%s). Fixing the config.",
|
|
396
|
+
input_size[2],
|
|
397
|
+
self.config.transforms.input_width,
|
|
394
398
|
)
|
|
395
399
|
self.config.transforms.input_width = input_size[2]
|
|
396
400
|
|
|
@@ -623,7 +623,9 @@ class SklearnClassification(Generic[SklearnClassificationDataModuleT], Task[Skle
|
|
|
623
623
|
all_labels = all_labels[sorted_indices]
|
|
624
624
|
|
|
625
625
|
# cycle over all train/test split
|
|
626
|
-
for train_dataloader, test_dataloader in zip(
|
|
626
|
+
for train_dataloader, test_dataloader in zip(
|
|
627
|
+
self.train_dataloader_list, self.test_dataloader_list, strict=False
|
|
628
|
+
):
|
|
627
629
|
# Reinit classifier
|
|
628
630
|
self.model = self.config.model
|
|
629
631
|
self.trainer.change_classifier(self.model)
|
|
@@ -685,7 +687,7 @@ class SklearnClassification(Generic[SklearnClassificationDataModuleT], Task[Skle
|
|
|
685
687
|
dl: PyTorch dataloader
|
|
686
688
|
feature_extractor: PyTorch backbone
|
|
687
689
|
"""
|
|
688
|
-
if isinstance(feature_extractor,
|
|
690
|
+
if isinstance(feature_extractor, TorchEvaluationModel | TorchscriptEvaluationModel):
|
|
689
691
|
# TODO: I'm not sure torchinfo supports torchscript models
|
|
690
692
|
# If we are working with torch based evaluation models we need to extract the model
|
|
691
693
|
feature_extractor = feature_extractor.model
|
|
@@ -1202,6 +1204,8 @@ class ClassificationEvaluation(Evaluation[ClassificationDataModuleT]):
|
|
|
1202
1204
|
probabilities = [max(item) for sublist in probabilities for item in sublist]
|
|
1203
1205
|
if self.datamodule.class_to_idx is not None:
|
|
1204
1206
|
idx_to_class = {v: k for k, v in self.datamodule.class_to_idx.items()}
|
|
1207
|
+
else:
|
|
1208
|
+
idx_to_class = None
|
|
1205
1209
|
|
|
1206
1210
|
_, pd_cm, test_accuracy = get_results(
|
|
1207
1211
|
test_labels=image_labels,
|
|
@@ -301,7 +301,7 @@ class PatchSklearnTestClassification(Evaluation[PatchSklearnClassificationDataMo
|
|
|
301
301
|
"test_results": None,
|
|
302
302
|
"test_labels": None,
|
|
303
303
|
}
|
|
304
|
-
self.class_to_skip: list[str] = []
|
|
304
|
+
self.class_to_skip: list[str] | None = []
|
|
305
305
|
self.reconstruction_results: dict[str, Any]
|
|
306
306
|
self.return_polygon: bool = True
|
|
307
307
|
|
|
@@ -92,8 +92,10 @@ class Segmentation(Generic[SegmentationDataModuleT], LightningTask[SegmentationD
|
|
|
92
92
|
len(self.datamodule.idx_to_class) + 1
|
|
93
93
|
):
|
|
94
94
|
log.warning(
|
|
95
|
-
|
|
96
|
-
+
|
|
95
|
+
"Number of classes in the model (%s) does not match the number of "
|
|
96
|
+
+ "classes in the datamodule (%d). Updating the model...",
|
|
97
|
+
module_config.model.num_classes,
|
|
98
|
+
len(self.datamodule.idx_to_class),
|
|
97
99
|
)
|
|
98
100
|
module_config.model.num_classes = len(self.datamodule.idx_to_class) + 1
|
|
99
101
|
|
|
@@ -341,7 +343,7 @@ class SegmentationAnalysisEvaluation(SegmentationEvaluation):
|
|
|
341
343
|
if self.datamodule.test_dataset_available:
|
|
342
344
|
stages.append("test")
|
|
343
345
|
dataloaders.append(self.datamodule.test_dataloader())
|
|
344
|
-
for stage, dataloader in zip(stages, dataloaders):
|
|
346
|
+
for stage, dataloader in zip(stages, dataloaders, strict=False):
|
|
345
347
|
log.info("Running inference on %s set with batch size: %d", stage, dataloader.batch_size)
|
|
346
348
|
image_list, mask_list, mask_pred_list, label_list = [], [], [], []
|
|
347
349
|
for batch in dataloader:
|
|
@@ -369,10 +371,10 @@ class SegmentationAnalysisEvaluation(SegmentationEvaluation):
|
|
|
369
371
|
|
|
370
372
|
for stage, output in self.test_output.items():
|
|
371
373
|
image_mean = OmegaConf.to_container(self.config.transforms.mean)
|
|
372
|
-
if not isinstance(image_mean, list) or any(not isinstance(x,
|
|
374
|
+
if not isinstance(image_mean, list) or any(not isinstance(x, int | float) for x in image_mean):
|
|
373
375
|
raise ValueError("Image mean is not a list of float or integer values, please check your config")
|
|
374
376
|
image_std = OmegaConf.to_container(self.config.transforms.std)
|
|
375
|
-
if not isinstance(image_std, list) or any(not isinstance(x,
|
|
377
|
+
if not isinstance(image_std, list) or any(not isinstance(x, int | float) for x in image_std):
|
|
376
378
|
raise ValueError("Image std is not a list of float or integer values, please check your config")
|
|
377
379
|
reports = create_mask_report(
|
|
378
380
|
stage=stage,
|
|
@@ -468,8 +468,7 @@ class EmbeddingVisualization(Task):
|
|
|
468
468
|
self.report_folder = report_folder
|
|
469
469
|
if self.model_path is None:
|
|
470
470
|
raise ValueError(
|
|
471
|
-
"Model path cannot be found!, please specify it in the config or pass it as an argument for"
|
|
472
|
-
" evaluation"
|
|
471
|
+
"Model path cannot be found!, please specify it in the config or pass it as an argument for evaluation"
|
|
473
472
|
)
|
|
474
473
|
self.embeddings_path = os.path.join(self.model_path, self.report_folder)
|
|
475
474
|
if not os.path.exists(self.embeddings_path):
|
|
@@ -547,7 +546,7 @@ class EmbeddingVisualization(Task):
|
|
|
547
546
|
im = interpolate(im, self.embedding_image_size)
|
|
548
547
|
|
|
549
548
|
images.append(im.cpu())
|
|
550
|
-
metadata.extend(zip(targets, class_names, file_paths))
|
|
549
|
+
metadata.extend(zip(targets, class_names, file_paths, strict=False))
|
|
551
550
|
counter += len(im)
|
|
552
551
|
images = torch.cat(images, dim=0)
|
|
553
552
|
embeddings = torch.cat(embeddings, dim=0)
|