quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,37 @@
1
+ input_height: 224
2
+ input_width: 224
3
+ mean: [0.485, 0.456, 0.406]
4
+ std: [0.229, 0.224, 0.225]
5
+
6
+ normalize:
7
+ _target_: albumentations.Compose
8
+ transforms:
9
+ - _target_: albumentations.Normalize
10
+ mean: ${transforms.mean}
11
+ std: ${transforms.std}
12
+ always_apply: True
13
+ - _target_: albumentations.pytorch.ToTensorV2
14
+ always_apply: True
15
+
16
+ resize_center_crop:
17
+ _target_: albumentations.Compose
18
+ transforms:
19
+ - _target_: albumentations.Resize
20
+ height: 256
21
+ width: 256
22
+ interpolation: 2
23
+ always_apply: True
24
+ - _target_: albumentations.CenterCrop
25
+ height: ${transforms.input_height}
26
+ width: ${transforms.input_width}
27
+ always_apply: true
28
+
29
+ standard_transform:
30
+ _target_: albumentations.Compose
31
+ transforms:
32
+ - ${transforms.resize_center_crop}
33
+ - ${transforms.normalize}
34
+
35
+ train_transform: ${transforms.standard_transform}
36
+ val_transform: ${transforms.standard_transform}
37
+ test_transform: ${transforms.standard_transform}
@@ -0,0 +1,24 @@
1
+ input_height: 224
2
+ input_width: 224
3
+
4
+ resize_center_crop:
5
+ _target_: albumentations.Compose
6
+ transforms:
7
+ - _target_: albumentations.Resize
8
+ height: 256
9
+ width: 256
10
+ interpolation: 2
11
+ always_apply: True
12
+ - _target_: albumentations.CenterCrop
13
+ height: ${transforms.input_height}
14
+ width: ${transforms.input_width}
15
+ always_apply: true
16
+
17
+ standard_transform:
18
+ _target_: albumentations.Compose
19
+ transforms:
20
+ - ${transforms.resize_center_crop}
21
+
22
+ train_transform: ${transforms.standard_transform}
23
+ val_transform: ${transforms.standard_transform}
24
+ test_transform: ${transforms.standard_transform}
@@ -0,0 +1,22 @@
1
+ defaults:
2
+ - default
3
+ - _self_
4
+
5
+ input_height: 224
6
+ input_width: 224
7
+
8
+ standard_transform:
9
+ _target_: albumentations.Compose
10
+ transforms:
11
+ - _target_: albumentations.Resize
12
+ height: ${transforms.input_height}
13
+ width: ${transforms.input_width}
14
+ interpolation: 2
15
+ always_apply: True
16
+ - ${transforms.normalize}
17
+
18
+ train_transform: ${transforms.standard_transform}
19
+ val_transform: ${transforms.standard_transform}
20
+ test_transform: ${transforms.standard_transform}
21
+
22
+ name: default_resize
@@ -0,0 +1,63 @@
1
+ defaults:
2
+ - default
3
+ - _self_
4
+
5
+ flip_and_jitter:
6
+ _target_: albumentations.Compose
7
+ transforms:
8
+ - _target_: albumentations.HorizontalFlip
9
+ p: 0.5
10
+ - _target_: albumentations.ColorJitter
11
+ brightness: 0.4
12
+ contrast: 0.4
13
+ saturation: 0.4
14
+ hue: 0.1
15
+ - _target_: albumentations.ToGray
16
+ p: 0.2
17
+
18
+ global_transforms:
19
+ - _target_: albumentations.Compose
20
+ transforms:
21
+ - _target_: albumentations.RandomResizedCrop
22
+ height: ${transforms.input_height}
23
+ width: ${transforms.input_width}
24
+ scale: [0.4, 1.0]
25
+ interpolation: 2
26
+ - ${transforms.flip_and_jitter}
27
+ - _target_: albumentations.GaussianBlur
28
+ blur_limit: 5
29
+ sigma_limit: [0.1, 2]
30
+ p: 1.0
31
+ - ${transforms.normalize}
32
+
33
+ - _target_: albumentations.Compose
34
+ transforms:
35
+ - _target_: albumentations.RandomResizedCrop
36
+ height: ${transforms.input_height}
37
+ width: ${transforms.input_width}
38
+ scale: [0.4, 1.0]
39
+ interpolation: 2
40
+ - ${transforms.flip_and_jitter}
41
+ - _target_: albumentations.GaussianBlur
42
+ blur_limit: 5
43
+ sigma_limit: [0.1, 2]
44
+ p: 0.1
45
+ - _target_: albumentations.Solarize
46
+ threshold: 170
47
+ p: 0.2
48
+ - ${transforms.normalize}
49
+
50
+ local_transform:
51
+ _target_: albumentations.Compose
52
+ transforms:
53
+ - _target_: albumentations.RandomResizedCrop
54
+ height: ${transforms.input_height}
55
+ width: ${transforms.input_width}
56
+ scale: [0.05, 0.4]
57
+ interpolation: 2
58
+ - ${transforms.flip_and_jitter}
59
+ - _target_: albumentations.GaussianBlur
60
+ blur_limit: 5
61
+ sigma_limit: [0.1, 2]
62
+ p: 0.5
63
+ - ${transforms.normalize}
@@ -0,0 +1,18 @@
1
+ defaults:
2
+ - default
3
+ - _self_
4
+
5
+ train_transform:
6
+ _target_: albumentations.Compose
7
+ transforms:
8
+ - _target_: albumentations.RandomResizedCrop
9
+ height: ${transforms.input_height}
10
+ width: ${transforms.input_width}
11
+ scale: [0.08, 1.0]
12
+ interpolation: 2
13
+ always_apply: True
14
+ - _target_: albumentations.HorizontalFlip
15
+ p: 0.5
16
+ - ${transforms.normalize}
17
+ val_transform: ${transforms.standard_transform}
18
+ test_transform: ${transforms.standard_transform}
@@ -0,0 +1,20 @@
1
+ from .anomaly import AnomalyDataModule
2
+ from .classification import (
3
+ ClassificationDataModule,
4
+ MultilabelClassificationDataModule,
5
+ SklearnClassificationDataModule,
6
+ )
7
+ from .patch import PatchSklearnClassificationDataModule
8
+ from .segmentation import SegmentationDataModule, SegmentationMulticlassDataModule
9
+ from .ssl import SSLDataModule
10
+
11
# Public API of this package: the data module classes re-exported by ``import *``.
# The declaration order below is intentionally kept identical to the original listing.
__all__ = [
    "AnomalyDataModule",
    "ClassificationDataModule",
    "SklearnClassificationDataModule",
    "SegmentationDataModule",
    "SegmentationMulticlassDataModule",
    "PatchSklearnClassificationDataModule",
    "MultilabelClassificationDataModule",
    "SSLDataModule",
]
@@ -0,0 +1,180 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pathlib
5
+
6
+ import albumentations
7
+ import pandas as pd
8
+ from torch.utils.data import DataLoader
9
+
10
+ from quadra.datamodules.base import BaseDataModule
11
+ from quadra.datasets import AnomalyDataset
12
+ from quadra.datasets.anomaly import make_anomaly_dataset
13
+ from quadra.utils import utils
14
+
15
log = utils.get_logger(__name__)


class AnomalyDataModule(BaseDataModule):
    """Anomalib-like Lightning Data Module.

    Args:
        data_path: Path to the dataset. When ``category`` is given, data is read from ``data_path/category``.
        category: Name of the sub category to use.
        image_size: Variable to which image is resized. NOTE(review): this class only stores the value on
            ``self.image_size`` and never uses it itself — confirm whether other code relies on it.
        train_batch_size: Training batch size.
        test_batch_size: Batch size used by the validation, test and predict dataloaders.
        num_workers: Number of workers.
        train_transform: transformations for training. Defaults to None.
        val_transform: transformations for validation. Defaults to None. It is forwarded to the base class;
            note that the validation dataset built in ``setup`` uses ``test_transform``.
        test_transform: transformations for testing, also applied to the validation dataset. Defaults to None.
        seed: seed used for the random subset splitting
        task: Whether we are interested in segmenting the anomalies (segmentation) or not (classification)
        mask_suffix: String to append to the base filename to get the mask name, by default for MVTec dataset masks
            are saved as imagename_mask.png in this case the parameter should be filled with "_mask"
        create_test_set_if_empty: If True, the test set is created from good images if it is empty.
        phase: Either train or test.
        name: Name of the data module.
        valid_area_mask: Optional path to the mask to use to filter out the valid area of the image. If None, the whole
            image is considered valid. The mask should match the image size even if the image is cropped.
        crop_area: Optional tuple of 4 integers (x1, y1, x2, y2) to crop the image to the specified area. If None, the
            whole image is considered valid.
        **kwargs: Extra keyword arguments forwarded to ``BaseDataModule``.
    """

    def __init__(
        self,
        data_path: str,
        category: str | None = None,
        image_size: int | tuple[int, int] | None = None,
        train_batch_size: int = 32,
        test_batch_size: int = 32,
        num_workers: int = 8,
        train_transform: albumentations.Compose | None = None,
        val_transform: albumentations.Compose | None = None,
        test_transform: albumentations.Compose | None = None,
        seed: int = 0,
        task: str = "segmentation",
        mask_suffix: str | None = None,
        create_test_set_if_empty: bool = True,
        phase: str = "train",
        name: str = "anomaly_datamodule",
        valid_area_mask: str | None = None,
        crop_area: tuple[int, int, int, int] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            data_path=data_path,
            name=name,
            seed=seed,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            num_workers=num_workers,
            **kwargs,
        )

        self.root = data_path
        self.category = category
        # Overrides the path stored by the base class so it points at the selected category, if any.
        self.data_path = os.path.join(self.root, self.category) if self.category is not None else self.root
        self.image_size = image_size

        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.task = task

        # Declared here, assigned in ``setup`` depending on the stage/phase.
        self.train_dataset: AnomalyDataset
        self.test_dataset: AnomalyDataset
        self.val_dataset: AnomalyDataset
        self.mask_suffix = mask_suffix
        self.create_test_set_if_empty = create_test_set_if_empty
        self.phase = phase
        self.valid_area_mask = valid_area_mask
        self.crop_area = crop_area

    @property
    def val_data(self) -> pd.DataFrame:
        """Get validation data, falling back to the test data when the validation split is empty."""
        _val_data = super().val_data
        if len(_val_data) == 0:
            return self.test_data
        return _val_data

    def _prepare_data(self) -> None:
        """Prepare data for training and testing by building the sample table with ``make_anomaly_dataset``."""
        self.data = make_anomaly_dataset(
            path=pathlib.Path(self.data_path),
            split=None,
            seed=self.seed,
            mask_suffix=self.mask_suffix,
            create_test_set_if_empty=self.create_test_set_if_empty,
        )

    def _create_anomaly_dataset(
        self, transform: albumentations.Compose | None, samples: pd.DataFrame
    ) -> AnomalyDataset:
        """Build an ``AnomalyDataset`` that shares this module's task, valid-area mask and crop area."""
        return AnomalyDataset(
            transform=transform,
            task=self.task,
            samples=samples,
            valid_area_mask=self.valid_area_mask,
            crop_area=self.crop_area,
        )

    def _create_anomaly_dataloader(self, dataset: AnomalyDataset, shuffle: bool, batch_size: int) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using this module's worker count and pinned memory."""
        return DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def setup(self, stage: str | None = None) -> None:
        """Setup data module based on stages of training.

        Datasets are built as follows:

        - ``train_dataset``: when ``stage == "fit"`` and ``phase == "train"``.
        - ``val_dataset``: in the same situation, and also for ``stage == "validate"`` so that
          ``val_dataloader`` works outside of fitting.
        - ``test_dataset``: for ``stage`` ``"test"`` or ``"predict"`` (``predict_dataloader`` reads it),
          or whenever ``phase == "test"``.

        Args:
            stage: Lightning stage name, or None.
        """
        is_training = stage == "fit" and self.phase == "train"

        if is_training:
            self.train_dataset = self._create_anomaly_dataset(self.train_transform, self.train_data)

        if is_training or stage == "validate":
            # ``self.val_data`` already falls back to the test split, so an empty result here means
            # both the validation and the test splits are empty.
            val_samples = self.val_data
            if len(val_samples) == 0:
                log.info("Validation and test datasets are empty, using the whole dataset for validation instead")
                val_samples = self.data
            # Validation intentionally reuses the test transform (no training augmentation).
            self.val_dataset = self._create_anomaly_dataset(self.test_transform, val_samples)

        if stage in ("test", "predict") or self.phase == "test":
            self.test_dataset = self._create_anomaly_dataset(self.test_transform, self.test_data)

    def train_dataloader(self) -> DataLoader:
        """Get train dataloader (shuffled, ``train_batch_size``)."""
        return self._create_anomaly_dataloader(self.train_dataset, shuffle=True, batch_size=self.train_batch_size)

    def val_dataloader(self) -> DataLoader:
        """Get validation dataloader (not shuffled, ``test_batch_size``)."""
        return self._create_anomaly_dataloader(self.val_dataset, shuffle=False, batch_size=self.test_batch_size)

    def test_dataloader(self) -> DataLoader:
        """Get test dataloader (not shuffled, ``test_batch_size``)."""
        return self._create_anomaly_dataloader(self.test_dataset, shuffle=False, batch_size=self.test_batch_size)

    def predict_dataloader(self) -> DataLoader:
        """Returns a dataloader used for predictions; it iterates over the test dataset."""
        return self._create_anomaly_dataloader(self.test_dataset, shuffle=False, batch_size=self.test_batch_size)