quadra 2.3.0a2__py3-none-any.whl → 2.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (300)
  1. hydra_plugins/quadra_searchpath_plugin.py +0 -0
  2. quadra/__init__.py +1 -1
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +3 -2
  5. quadra/callbacks/lightning.py +3 -1
  6. quadra/callbacks/mlflow.py +0 -0
  7. quadra/callbacks/scheduler.py +0 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +0 -0
  10. quadra/configs/backbone/caformer_s36.yaml +0 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +0 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +0 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +0 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +0 -0
  15. quadra/configs/backbone/dino_vits8.yaml +0 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +0 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +0 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +0 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +0 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +0 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +0 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +0 -0
  23. quadra/configs/backbone/levit_128s.yaml +0 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +0 -0
  25. quadra/configs/backbone/resnet101.yaml +0 -0
  26. quadra/configs/backbone/resnet18.yaml +0 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +0 -0
  28. quadra/configs/backbone/resnet50.yaml +0 -0
  29. quadra/configs/backbone/smp.yaml +0 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +0 -0
  31. quadra/configs/backbone/unetr.yaml +0 -0
  32. quadra/configs/backbone/vit16_base.yaml +0 -0
  33. quadra/configs/backbone/vit16_small.yaml +0 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +0 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +0 -0
  36. quadra/configs/callbacks/all.yaml +0 -0
  37. quadra/configs/callbacks/default.yaml +0 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +0 -0
  39. quadra/configs/config.yaml +0 -0
  40. quadra/configs/core/default.yaml +0 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +0 -0
  42. quadra/configs/datamodule/base/classification.yaml +0 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +0 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +0 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +0 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +0 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +0 -0
  48. quadra/configs/datamodule/base/ssl.yaml +0 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +0 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +0 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +0 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +0 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +0 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +0 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +0 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +0 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +0 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +0 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +0 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +0 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +0 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +0 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +0 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +0 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +0 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +0 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +0 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +0 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +0 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +0 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +0 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +0 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +0 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +0 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +0 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +0 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +0 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +0 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +0 -0
  80. quadra/configs/experiment/custom/cls.yaml +0 -0
  81. quadra/configs/experiment/default.yaml +0 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +0 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +0 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +0 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +0 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +0 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +0 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +0 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +0 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +0 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +0 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +0 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +0 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +0 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +0 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +0 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +0 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +0 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +0 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +0 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +0 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +0 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +0 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +0 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +0 -0
  106. quadra/configs/export/default.yaml +0 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +0 -0
  108. quadra/configs/hydra/default.yaml +0 -0
  109. quadra/configs/inference/default.yaml +0 -0
  110. quadra/configs/logger/comet.yaml +0 -0
  111. quadra/configs/logger/csv.yaml +0 -0
  112. quadra/configs/logger/mlflow.yaml +0 -0
  113. quadra/configs/logger/tensorboard.yaml +0 -0
  114. quadra/configs/loss/asl.yaml +0 -0
  115. quadra/configs/loss/barlow.yaml +0 -0
  116. quadra/configs/loss/bce.yaml +0 -0
  117. quadra/configs/loss/byol.yaml +0 -0
  118. quadra/configs/loss/cross_entropy.yaml +0 -0
  119. quadra/configs/loss/dino.yaml +0 -0
  120. quadra/configs/loss/simclr.yaml +0 -0
  121. quadra/configs/loss/simsiam.yaml +0 -0
  122. quadra/configs/loss/smp_ce.yaml +0 -0
  123. quadra/configs/loss/smp_dice.yaml +0 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +0 -0
  125. quadra/configs/loss/smp_mcc.yaml +0 -0
  126. quadra/configs/loss/vicreg.yaml +0 -0
  127. quadra/configs/model/anomalib/cfa.yaml +0 -0
  128. quadra/configs/model/anomalib/cflow.yaml +0 -0
  129. quadra/configs/model/anomalib/csflow.yaml +0 -0
  130. quadra/configs/model/anomalib/dfm.yaml +0 -0
  131. quadra/configs/model/anomalib/draem.yaml +0 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +0 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +0 -0
  134. quadra/configs/model/anomalib/padim.yaml +0 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +0 -0
  136. quadra/configs/model/barlow.yaml +0 -0
  137. quadra/configs/model/byol.yaml +0 -0
  138. quadra/configs/model/classification.yaml +0 -0
  139. quadra/configs/model/dino.yaml +0 -0
  140. quadra/configs/model/logistic_regression.yaml +0 -0
  141. quadra/configs/model/multilabel_classification.yaml +0 -0
  142. quadra/configs/model/simclr.yaml +0 -0
  143. quadra/configs/model/simsiam.yaml +0 -0
  144. quadra/configs/model/smp.yaml +0 -0
  145. quadra/configs/model/smp_multiclass.yaml +0 -0
  146. quadra/configs/model/vicreg.yaml +0 -0
  147. quadra/configs/optimizer/adam.yaml +0 -0
  148. quadra/configs/optimizer/adamw.yaml +0 -0
  149. quadra/configs/optimizer/default.yaml +0 -0
  150. quadra/configs/optimizer/lars.yaml +0 -0
  151. quadra/configs/optimizer/sgd.yaml +0 -0
  152. quadra/configs/scheduler/default.yaml +0 -0
  153. quadra/configs/scheduler/rop.yaml +0 -0
  154. quadra/configs/scheduler/step.yaml +0 -0
  155. quadra/configs/scheduler/warmrestart.yaml +0 -0
  156. quadra/configs/scheduler/warmup.yaml +0 -0
  157. quadra/configs/task/anomalib/cfa.yaml +0 -0
  158. quadra/configs/task/anomalib/cflow.yaml +0 -0
  159. quadra/configs/task/anomalib/csflow.yaml +0 -0
  160. quadra/configs/task/anomalib/draem.yaml +0 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +0 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +0 -0
  163. quadra/configs/task/anomalib/inference.yaml +0 -0
  164. quadra/configs/task/anomalib/padim.yaml +0 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +0 -0
  166. quadra/configs/task/classification.yaml +0 -0
  167. quadra/configs/task/classification_evaluation.yaml +0 -0
  168. quadra/configs/task/default.yaml +0 -0
  169. quadra/configs/task/segmentation.yaml +0 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +0 -0
  171. quadra/configs/task/sklearn_classification.yaml +0 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +0 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +0 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +0 -0
  175. quadra/configs/task/ssl.yaml +0 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +0 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +0 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +0 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +0 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +0 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +0 -0
  182. quadra/configs/transforms/byol.yaml +0 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +0 -0
  184. quadra/configs/transforms/default.yaml +0 -0
  185. quadra/configs/transforms/default_numpy.yaml +0 -0
  186. quadra/configs/transforms/default_resize.yaml +0 -0
  187. quadra/configs/transforms/dino.yaml +0 -0
  188. quadra/configs/transforms/linear_eval.yaml +0 -0
  189. quadra/datamodules/__init__.py +0 -0
  190. quadra/datamodules/anomaly.py +0 -0
  191. quadra/datamodules/base.py +5 -5
  192. quadra/datamodules/classification.py +2 -2
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +0 -0
  195. quadra/datamodules/generic/mnist.py +0 -0
  196. quadra/datamodules/generic/mvtec.py +0 -0
  197. quadra/datamodules/generic/oxford_pet.py +0 -0
  198. quadra/datamodules/patch.py +0 -0
  199. quadra/datamodules/segmentation.py +6 -6
  200. quadra/datamodules/ssl.py +0 -0
  201. quadra/datasets/__init__.py +0 -0
  202. quadra/datasets/anomaly.py +2 -2
  203. quadra/datasets/classification.py +7 -7
  204. quadra/datasets/patch.py +1 -1
  205. quadra/datasets/segmentation.py +0 -0
  206. quadra/datasets/ssl.py +3 -3
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +0 -0
  209. quadra/losses/classification/asl.py +0 -0
  210. quadra/losses/classification/focal.py +0 -0
  211. quadra/losses/classification/prototypical.py +0 -0
  212. quadra/losses/ssl/__init__.py +0 -0
  213. quadra/losses/ssl/barlowtwins.py +0 -0
  214. quadra/losses/ssl/byol.py +0 -0
  215. quadra/losses/ssl/dino.py +0 -0
  216. quadra/losses/ssl/hyperspherical.py +0 -0
  217. quadra/losses/ssl/idmm.py +0 -0
  218. quadra/losses/ssl/simclr.py +0 -0
  219. quadra/losses/ssl/simsiam.py +0 -0
  220. quadra/losses/ssl/vicreg.py +0 -0
  221. quadra/main.py +0 -0
  222. quadra/metrics/__init__.py +0 -0
  223. quadra/metrics/segmentation.py +1 -1
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +1 -1
  226. quadra/models/classification/__init__.py +0 -0
  227. quadra/models/classification/backbones.py +0 -0
  228. quadra/models/classification/base.py +0 -0
  229. quadra/models/evaluation.py +1 -1
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +0 -0
  232. quadra/modules/base.py +3 -2
  233. quadra/modules/classification/__init__.py +0 -0
  234. quadra/modules/classification/base.py +0 -0
  235. quadra/modules/ssl/__init__.py +0 -0
  236. quadra/modules/ssl/barlowtwins.py +0 -0
  237. quadra/modules/ssl/byol.py +1 -0
  238. quadra/modules/ssl/common.py +0 -0
  239. quadra/modules/ssl/dino.py +0 -0
  240. quadra/modules/ssl/hyperspherical.py +0 -0
  241. quadra/modules/ssl/idmm.py +0 -0
  242. quadra/modules/ssl/simclr.py +0 -0
  243. quadra/modules/ssl/simsiam.py +0 -0
  244. quadra/modules/ssl/vicreg.py +0 -0
  245. quadra/optimizers/__init__.py +0 -0
  246. quadra/optimizers/lars.py +0 -0
  247. quadra/optimizers/sam.py +0 -0
  248. quadra/schedulers/__init__.py +0 -0
  249. quadra/schedulers/base.py +0 -0
  250. quadra/schedulers/warmup.py +0 -0
  251. quadra/tasks/__init__.py +0 -0
  252. quadra/tasks/anomaly.py +7 -4
  253. quadra/tasks/base.py +8 -4
  254. quadra/tasks/classification.py +6 -2
  255. quadra/tasks/patch.py +1 -1
  256. quadra/tasks/segmentation.py +7 -5
  257. quadra/tasks/ssl.py +2 -3
  258. quadra/trainers/README.md +0 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +0 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +0 -0
  263. quadra/utils/classification.py +8 -10
  264. quadra/utils/deprecation.py +0 -0
  265. quadra/utils/evaluation.py +12 -3
  266. quadra/utils/export.py +5 -5
  267. quadra/utils/imaging.py +0 -0
  268. quadra/utils/logger.py +0 -0
  269. quadra/utils/mlflow.py +2 -0
  270. quadra/utils/model_manager.py +0 -0
  271. quadra/utils/models.py +5 -7
  272. quadra/utils/patch/__init__.py +0 -0
  273. quadra/utils/patch/dataset.py +7 -6
  274. quadra/utils/patch/metrics.py +9 -6
  275. quadra/utils/patch/model.py +0 -0
  276. quadra/utils/patch/visualization.py +2 -2
  277. quadra/utils/resolver.py +0 -0
  278. quadra/utils/segmentation.py +0 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +0 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +0 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +0 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +0 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +1 -1
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +0 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +0 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +0 -0
  288. quadra/utils/tests/fixtures/models/classification.py +0 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +0 -0
  290. quadra/utils/tests/helpers.py +0 -0
  291. quadra/utils/tests/models.py +0 -0
  292. quadra/utils/utils.py +1 -1
  293. quadra/utils/validator.py +1 -3
  294. quadra/utils/visualization.py +8 -5
  295. quadra/utils/vit_explainability.py +1 -1
  296. {quadra-2.3.0a2.dist-info → quadra-2.3.1.dist-info}/LICENSE +0 -0
  297. {quadra-2.3.0a2.dist-info → quadra-2.3.1.dist-info}/METADATA +1 -1
  298. {quadra-2.3.0a2.dist-info → quadra-2.3.1.dist-info}/RECORD +39 -39
  299. {quadra-2.3.0a2.dist-info → quadra-2.3.1.dist-info}/WHEEL +1 -1
  300. {quadra-2.3.0a2.dist-info → quadra-2.3.1.dist-info}/entry_points.txt +0 -0
File without changes
File without changes
@@ -187,7 +187,7 @@ class SegmentationDataModule(BaseDataModule):
187
187
  samples_test, targets_test, masks_test = self._read_split(self.test_split_file)
188
188
  if not self.train_split_file:
189
189
  samples_train, targets_train, masks_train = [], [], []
190
- for sample, target, mask in zip(all_samples, all_targets, all_masks):
190
+ for sample, target, mask in zip(all_samples, all_targets, all_masks, strict=False):
191
191
  if sample not in samples_test:
192
192
  samples_train.append(sample)
193
193
  targets_train.append(target)
@@ -197,7 +197,7 @@ class SegmentationDataModule(BaseDataModule):
197
197
  samples_train, targets_train, masks_train = self._read_split(self.train_split_file)
198
198
  if not self.test_split_file:
199
199
  samples_test, targets_test, masks_test = [], [], []
200
- for sample, target, mask in zip(all_samples, all_targets, all_masks):
200
+ for sample, target, mask in zip(all_samples, all_targets, all_masks, strict=False):
201
201
  if sample not in samples_train:
202
202
  samples_test.append(sample)
203
203
  targets_test.append(target)
@@ -549,7 +549,7 @@ class SegmentationMulticlassDataModule(BaseDataModule):
549
549
  samples_and_masks_test,
550
550
  targets_test,
551
551
  ) = iterative_train_test_split(
552
- np.expand_dims(np.array(list(zip(all_samples, all_masks))), 1),
552
+ np.expand_dims(np.array(list(zip(all_samples, all_masks, strict=False))), 1),
553
553
  np.array(all_targets),
554
554
  test_size=self.test_size,
555
555
  )
@@ -561,7 +561,7 @@ class SegmentationMulticlassDataModule(BaseDataModule):
561
561
  samples_test, targets_test, masks_test = self._read_split(self.test_split_file)
562
562
  if not self.train_split_file:
563
563
  samples_train, targets_train, masks_train = [], [], []
564
- for sample, target, mask in zip(all_samples, all_targets, all_masks):
564
+ for sample, target, mask in zip(all_samples, all_targets, all_masks, strict=False):
565
565
  if sample not in samples_test:
566
566
  samples_train.append(sample)
567
567
  targets_train.append(target)
@@ -571,7 +571,7 @@ class SegmentationMulticlassDataModule(BaseDataModule):
571
571
  samples_train, targets_train, masks_train = self._read_split(self.train_split_file)
572
572
  if not self.test_split_file:
573
573
  samples_test, targets_test, masks_test = [], [], []
574
- for sample, target, mask in zip(all_samples, all_targets, all_masks):
574
+ for sample, target, mask in zip(all_samples, all_targets, all_masks, strict=False):
575
575
  if sample not in samples_train:
576
576
  samples_test.append(sample)
577
577
  targets_test.append(target)
@@ -583,7 +583,7 @@ class SegmentationMulticlassDataModule(BaseDataModule):
583
583
  raise ValueError("Validation split file is specified but no train or test split file is specified.")
584
584
  else:
585
585
  samples_and_masks_train, targets_train, samples_and_masks_val, targets_val = iterative_train_test_split(
586
- np.expand_dims(np.array(list(zip(samples_train, masks_train))), 1),
586
+ np.expand_dims(np.array(list(zip(samples_train, masks_train, strict=False))), 1),
587
587
  np.array(targets_train),
588
588
  test_size=self.val_size,
589
589
  )
quadra/datamodules/ssl.py CHANGED
File without changes
File without changes
@@ -220,7 +220,7 @@ class AnomalyDataset(Dataset):
220
220
  if not os.path.exists(valid_area_mask):
221
221
  raise RuntimeError(f"Valid area mask {valid_area_mask} does not exist.")
222
222
 
223
- self.valid_area_mask = cv2.imread(valid_area_mask, 0) > 0 # type: ignore[operator]
223
+ self.valid_area_mask = cv2.imread(valid_area_mask, 0) > 0
224
224
 
225
225
  def __len__(self) -> int:
226
226
  """Get length of the dataset."""
@@ -265,7 +265,7 @@ class AnomalyDataset(Dataset):
265
265
  if label_index == 0:
266
266
  mask = np.zeros(shape=original_image_shape[:2])
267
267
  elif os.path.isfile(mask_path):
268
- mask = cv2.imread(mask_path, flags=0) / 255.0 # type: ignore[operator]
268
+ mask = cv2.imread(mask_path, flags=0) / 255.0
269
269
  else:
270
270
  # We need ones in the mask to compute correctly at least image level f1 score
271
271
  mask = np.ones(shape=original_image_shape[:2])
@@ -50,9 +50,9 @@ class ImageClassificationListDataset(Dataset):
50
50
  allow_missing_label: bool | None = False,
51
51
  ):
52
52
  super().__init__()
53
- assert len(samples) == len(
54
- targets
55
- ), f"Samples ({len(samples)}) and targets ({len(targets)}) must have the same length"
53
+ assert len(samples) == len(targets), (
54
+ f"Samples ({len(samples)}) and targets ({len(targets)}) must have the same length"
55
+ )
56
56
  # Setting the ROI
57
57
  self.roi = roi
58
58
 
@@ -201,9 +201,9 @@ class MultilabelClassificationDataset(torch.utils.data.Dataset):
201
201
  rgb: bool = True,
202
202
  ):
203
203
  super().__init__()
204
- assert len(samples) == len(
205
- targets
206
- ), f"Samples ({len(samples)}) and targets ({len(targets)}) must have the same length"
204
+ assert len(samples) == len(targets), (
205
+ f"Samples ({len(samples)}) and targets ({len(targets)}) must have the same length"
206
+ )
207
207
 
208
208
  # Data
209
209
  self.x = samples
@@ -215,7 +215,7 @@ class MultilabelClassificationDataset(torch.utils.data.Dataset):
215
215
  class_to_idx = {c: i for i, c in enumerate(range(unique_targets))}
216
216
  self.class_to_idx = class_to_idx
217
217
  self.idx_to_class = {v: k for k, v in class_to_idx.items()}
218
- self.samples = list(zip(self.x, self.y))
218
+ self.samples = list(zip(self.x, self.y, strict=False))
219
219
  self.rgb = rgb
220
220
  self.transform = transform
221
221
 
quadra/datasets/patch.py CHANGED
@@ -58,7 +58,7 @@ class PatchSklearnClassificationTrainDataset(Dataset):
58
58
 
59
59
  cls, counts = np.unique(targets_array, return_counts=True)
60
60
  max_count = np.max(counts)
61
- for cl, count in zip(cls, counts):
61
+ for cl, count in zip(cls, counts, strict=False):
62
62
  idx_to_pick = list(np.where(targets_array == cl)[0])
63
63
 
64
64
  if count < max_count:
File without changes
quadra/datasets/ssl.py CHANGED
@@ -75,9 +75,9 @@ class TwoSetAugmentationDataset(Dataset):
75
75
  return the original image.
76
76
 
77
77
  Example:
78
- >>> images[0] = global_transform[0](original_image)
79
- >>> images[1] = global_transform[1](original_image)
80
- >>> images[2:] = local_transform(s)(original_image)
78
+ >>> `images[0] = global_transform[0](original_image)`
79
+ >>> `images[1] = global_transform[1](original_image)`
80
+ >>> `images[2:] = local_transform(s)(original_image)`
81
81
  """
82
82
 
83
83
  def __init__(
quadra/losses/__init__.py CHANGED
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
quadra/losses/ssl/byol.py CHANGED
File without changes
quadra/losses/ssl/dino.py CHANGED
File without changes
File without changes
quadra/losses/ssl/idmm.py CHANGED
File without changes
File without changes
File without changes
File without changes
quadra/main.py CHANGED
File without changes
File without changes
@@ -171,7 +171,7 @@ def segmentation_props(
171
171
  # Add dummy Dices so LSA is unique and i can compute FP and FN
172
172
  dice_mat = _pad_to_shape(dice_mat, (max_dim, max_dim), 1)
173
173
  lsa = linear_sum_assignment(dice_mat, maximize=False)
174
- for row, col in zip(lsa[0], lsa[1]):
174
+ for row, col in zip(lsa[0], lsa[1], strict=False):
175
175
  # More preds than GTs --> False Positive
176
176
  if row < n_labels_pred and col >= n_labels_mask:
177
177
  min_row = pred_bbox[row][0]
quadra/models/__init__.py CHANGED
File without changes
quadra/models/base.py CHANGED
@@ -76,7 +76,7 @@ class ModelSignatureWrapper(nn.Module):
76
76
 
77
77
  if isinstance(self.instance.forward, torch.ScriptMethod):
78
78
  # Handle torchscript backbones
79
- for i, argument in enumerate(self.instance.forward.schema.arguments):
79
+ for i, argument in enumerate(self.instance.forward.schema.arguments): # type: ignore[attr-defined]
80
80
  if i < (len(args) + 1): # +1 for self
81
81
  continue
82
82
 
File without changes
File without changes
File without changes
@@ -209,7 +209,7 @@ class ONNXEvaluationModel(BaseEvaluationModel):
209
209
 
210
210
  onnx_inputs: dict[str, np.ndarray | torch.Tensor] = {}
211
211
 
212
- for onnx_input, current_input in zip(self.model.get_inputs(), inputs):
212
+ for onnx_input, current_input in zip(self.model.get_inputs(), inputs, strict=False):
213
213
  if isinstance(current_input, torch.Tensor):
214
214
  onnx_inputs[onnx_input.name] = current_input
215
215
  use_pytorch = True
File without changes
File without changes
quadra/modules/base.py CHANGED
@@ -7,6 +7,7 @@ import pytorch_lightning as pl
7
7
  import sklearn
8
8
  import torch
9
9
  import torchmetrics
10
+ from pytorch_lightning.utilities.types import OptimizerLRScheduler
10
11
  from sklearn.linear_model import LogisticRegression
11
12
  from torch import nn
12
13
  from torch.optim import Optimizer
@@ -48,7 +49,7 @@ class BaseLightningModule(pl.LightningModule):
48
49
  """
49
50
  return self.model(x)
50
51
 
51
- def configure_optimizers(self) -> tuple[list[Any], list[dict[str, Any]]]:
52
+ def configure_optimizers(self) -> OptimizerLRScheduler:
52
53
  """Get default optimizer if not passed a value.
53
54
 
54
55
  Returns:
@@ -68,7 +69,7 @@ class BaseLightningModule(pl.LightningModule):
68
69
  "monitor": "val_loss",
69
70
  "strict": False,
70
71
  }
71
- return [self.optimizer], [lr_scheduler_conf]
72
+ return [self.optimizer], [lr_scheduler_conf] # type: ignore[return-value]
72
73
 
73
74
  # pylint: disable=unused-argument
74
75
  def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx: int = 0):
File without changes
File without changes
File without changes
File without changes
@@ -110,6 +110,7 @@ class BYOL(SSLModule):
110
110
  for student_ps, teacher_ps in zip(
111
111
  list(self.model.parameters()) + list(self.student_projection_mlp.parameters()),
112
112
  list(self.teacher.parameters()) + list(self.teacher_projection_mlp.parameters()),
113
+ strict=False,
113
114
  ):
114
115
  teacher_ps.data = teacher_ps.data * teacher_momentum + (1 - teacher_momentum) * student_ps.data
115
116
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
quadra/optimizers/lars.py CHANGED
File without changes
quadra/optimizers/sam.py CHANGED
File without changes
File without changes
quadra/schedulers/base.py CHANGED
File without changes
File without changes
quadra/tasks/__init__.py CHANGED
File without changes
quadra/tasks/anomaly.py CHANGED
@@ -161,7 +161,7 @@ class AnomalibDetection(Generic[AnomalyDataModuleT], LightningTask[AnomalyDataMo
161
161
  all_output_flatten: dict[str, torch.Tensor | list] = {}
162
162
 
163
163
  for key in all_output[0]:
164
- if type(all_output[0][key]) == torch.Tensor:
164
+ if isinstance(all_output[0][key], torch.Tensor):
165
165
  tensor_gatherer = torch.cat([x[key] for x in all_output])
166
166
  all_output_flatten[key] = tensor_gatherer
167
167
  else:
@@ -205,13 +205,15 @@ class AnomalibDetection(Generic[AnomalyDataModuleT], LightningTask[AnomalyDataMo
205
205
  class_to_idx.pop("false_defect")
206
206
 
207
207
  anomaly_scores = all_output_flatten["pred_scores"]
208
+
209
+ exportable_anomaly_scores: list[Any] | np.ndarray
208
210
  if isinstance(anomaly_scores, torch.Tensor):
209
211
  exportable_anomaly_scores = anomaly_scores.cpu().numpy()
210
212
  else:
211
213
  exportable_anomaly_scores = anomaly_scores
212
214
 
213
215
  # Zip the lists together to create rows for the CSV file
214
- rows = zip(image_paths, pred_labels, gt_labels, exportable_anomaly_scores)
216
+ rows = zip(image_paths, pred_labels, gt_labels, exportable_anomaly_scores, strict=False)
215
217
  # Specify the CSV file name
216
218
  csv_file = "test_predictions.csv"
217
219
  # Write the data to the CSV file
@@ -483,7 +485,7 @@ class AnomalibEvaluation(Evaluation[AnomalyDataModule]):
483
485
 
484
486
  if hasattr(self.datamodule, "valid_area_mask") and self.datamodule.valid_area_mask is not None:
485
487
  mask_area = cv2.imread(self.datamodule.valid_area_mask, 0)
486
- mask_area = (mask_area > 0).astype(np.uint8) # type: ignore[operator]
488
+ mask_area = (mask_area > 0).astype(np.uint8)
487
489
 
488
490
  if hasattr(self.datamodule, "crop_area") and self.datamodule.crop_area is not None:
489
491
  crop_area = self.datamodule.crop_area
@@ -499,12 +501,13 @@ class AnomalibEvaluation(Evaluation[AnomalyDataModule]):
499
501
  self.metadata["image_labels"],
500
502
  anomaly_scores,
501
503
  anomaly_maps,
504
+ strict=False,
502
505
  ),
503
506
  total=len(self.metadata["image_paths"]),
504
507
  ):
505
508
  img = cv2.imread(img_path, 0)
506
509
  if mask_area is not None:
507
- img = img * mask_area # type: ignore[operator]
510
+ img = img * mask_area
508
511
 
509
512
  if crop_area is not None:
510
513
  img = img[crop_area[1] : crop_area[3], crop_area[0] : crop_area[2]]
quadra/tasks/base.py CHANGED
@@ -382,15 +382,19 @@ class Evaluation(Generic[DataModuleT], Task[DataModuleT]):
382
382
  # We assume that each input size has the same height and width
383
383
  if input_size[1] != self.config.transforms.input_height:
384
384
  log.warning(
385
- f"Input height of the model ({input_size[1]}) is different from the one specified "
386
- + f"in the config ({self.config.transforms.input_height}). Fixing the config."
385
+ "Input height of the model (%s) is different from the one specified "
386
+ + "in the config (%s). Fixing the config.",
387
+ input_size[1],
388
+ self.config.transforms.input_height,
387
389
  )
388
390
  self.config.transforms.input_height = input_size[1]
389
391
 
390
392
  if input_size[2] != self.config.transforms.input_width:
391
393
  log.warning(
392
- f"Input width of the model ({input_size[2]}) is different from the one specified "
393
- + f"in the config ({self.config.transforms.input_width}). Fixing the config."
394
+ "Input width of the model (%s) is different from the one specified "
395
+ + "in the config (%s). Fixing the config.",
396
+ input_size[2],
397
+ self.config.transforms.input_width,
394
398
  )
395
399
  self.config.transforms.input_width = input_size[2]
396
400
 
@@ -623,7 +623,9 @@ class SklearnClassification(Generic[SklearnClassificationDataModuleT], Task[Skle
623
623
  all_labels = all_labels[sorted_indices]
624
624
 
625
625
  # cycle over all train/test split
626
- for train_dataloader, test_dataloader in zip(self.train_dataloader_list, self.test_dataloader_list):
626
+ for train_dataloader, test_dataloader in zip(
627
+ self.train_dataloader_list, self.test_dataloader_list, strict=False
628
+ ):
627
629
  # Reinit classifier
628
630
  self.model = self.config.model
629
631
  self.trainer.change_classifier(self.model)
@@ -685,7 +687,7 @@ class SklearnClassification(Generic[SklearnClassificationDataModuleT], Task[Skle
685
687
  dl: PyTorch dataloader
686
688
  feature_extractor: PyTorch backbone
687
689
  """
688
- if isinstance(feature_extractor, (TorchEvaluationModel, TorchscriptEvaluationModel)):
690
+ if isinstance(feature_extractor, TorchEvaluationModel | TorchscriptEvaluationModel):
689
691
  # TODO: I'm not sure torchinfo supports torchscript models
690
692
  # If we are working with torch based evaluation models we need to extract the model
691
693
  feature_extractor = feature_extractor.model
@@ -1202,6 +1204,8 @@ class ClassificationEvaluation(Evaluation[ClassificationDataModuleT]):
1202
1204
  probabilities = [max(item) for sublist in probabilities for item in sublist]
1203
1205
  if self.datamodule.class_to_idx is not None:
1204
1206
  idx_to_class = {v: k for k, v in self.datamodule.class_to_idx.items()}
1207
+ else:
1208
+ idx_to_class = None
1205
1209
 
1206
1210
  _, pd_cm, test_accuracy = get_results(
1207
1211
  test_labels=image_labels,
quadra/tasks/patch.py CHANGED
@@ -301,7 +301,7 @@ class PatchSklearnTestClassification(Evaluation[PatchSklearnClassificationDataMo
301
301
  "test_results": None,
302
302
  "test_labels": None,
303
303
  }
304
- self.class_to_skip: list[str] = []
304
+ self.class_to_skip: list[str] | None = []
305
305
  self.reconstruction_results: dict[str, Any]
306
306
  self.return_polygon: bool = True
307
307
 
@@ -92,8 +92,10 @@ class Segmentation(Generic[SegmentationDataModuleT], LightningTask[SegmentationD
92
92
  len(self.datamodule.idx_to_class) + 1
93
93
  ):
94
94
  log.warning(
95
- f"Number of classes in the model ({module_config.model.num_classes}) does not match the number of "
96
- + f"classes in the datamodule ({len(self.datamodule.idx_to_class)}). Updating the model..."
95
+ "Number of classes in the model (%s) does not match the number of "
96
+ + "classes in the datamodule (%d). Updating the model...",
97
+ module_config.model.num_classes,
98
+ len(self.datamodule.idx_to_class),
97
99
  )
98
100
  module_config.model.num_classes = len(self.datamodule.idx_to_class) + 1
99
101
 
@@ -341,7 +343,7 @@ class SegmentationAnalysisEvaluation(SegmentationEvaluation):
341
343
  if self.datamodule.test_dataset_available:
342
344
  stages.append("test")
343
345
  dataloaders.append(self.datamodule.test_dataloader())
344
- for stage, dataloader in zip(stages, dataloaders):
346
+ for stage, dataloader in zip(stages, dataloaders, strict=False):
345
347
  log.info("Running inference on %s set with batch size: %d", stage, dataloader.batch_size)
346
348
  image_list, mask_list, mask_pred_list, label_list = [], [], [], []
347
349
  for batch in dataloader:
@@ -369,10 +371,10 @@ class SegmentationAnalysisEvaluation(SegmentationEvaluation):
369
371
 
370
372
  for stage, output in self.test_output.items():
371
373
  image_mean = OmegaConf.to_container(self.config.transforms.mean)
372
- if not isinstance(image_mean, list) or any(not isinstance(x, (int, float)) for x in image_mean):
374
+ if not isinstance(image_mean, list) or any(not isinstance(x, int | float) for x in image_mean):
373
375
  raise ValueError("Image mean is not a list of float or integer values, please check your config")
374
376
  image_std = OmegaConf.to_container(self.config.transforms.std)
375
- if not isinstance(image_std, list) or any(not isinstance(x, (int, float)) for x in image_std):
377
+ if not isinstance(image_std, list) or any(not isinstance(x, int | float) for x in image_std):
376
378
  raise ValueError("Image std is not a list of float or integer values, please check your config")
377
379
  reports = create_mask_report(
378
380
  stage=stage,
quadra/tasks/ssl.py CHANGED
@@ -468,8 +468,7 @@ class EmbeddingVisualization(Task):
468
468
  self.report_folder = report_folder
469
469
  if self.model_path is None:
470
470
  raise ValueError(
471
- "Model path cannot be found!, please specify it in the config or pass it as an argument for"
472
- " evaluation"
471
+ "Model path cannot be found!, please specify it in the config or pass it as an argument for evaluation"
473
472
  )
474
473
  self.embeddings_path = os.path.join(self.model_path, self.report_folder)
475
474
  if not os.path.exists(self.embeddings_path):
@@ -547,7 +546,7 @@ class EmbeddingVisualization(Task):
547
546
  im = interpolate(im, self.embedding_image_size)
548
547
 
549
548
  images.append(im.cpu())
550
- metadata.extend(zip(targets, class_names, file_paths))
549
+ metadata.extend(zip(targets, class_names, file_paths, strict=False))
551
550
  counter += len(im)
552
551
  images = torch.cat(images, dim=0)
553
552
  embeddings = torch.cat(embeddings, dim=0)
quadra/trainers/README.md CHANGED
File without changes
File without changes
File without changes
quadra/utils/__init__.py CHANGED
File without changes
quadra/utils/anomaly.py CHANGED
File without changes
@@ -46,12 +46,10 @@ def get_file_condition(
46
46
  if any(fil in root for fil in exclude_filter):
47
47
  return False
48
48
 
49
- if include_filter is not None and (
50
- not any(fil in file_name for fil in include_filter) and not any(fil in root for fil in include_filter)
51
- ):
52
- return False
53
-
54
- return True
49
+ return not (
50
+ include_filter is not None
51
+ and (not any(fil in file_name for fil in include_filter) and not any(fil in root for fil in include_filter))
52
+ )
55
53
 
56
54
 
57
55
  def natural_key(string_):
@@ -130,7 +128,7 @@ def find_images_and_targets(
130
128
  sorted_labels = sorted(unique_labels, key=natural_key)
131
129
  class_to_idx = {str(c): idx for idx, c in enumerate(sorted_labels)}
132
130
 
133
- images_and_targets = [(f, l) for f, l in zip(filenames, labels) if l in class_to_idx]
131
+ images_and_targets = [(f, l) for f, l in zip(filenames, labels, strict=False) if l in class_to_idx]
134
132
 
135
133
  if sort:
136
134
  images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
@@ -210,7 +208,7 @@ def find_test_image(
210
208
  file_samples.append(sample_path)
211
209
 
212
210
  test_split = [os.path.join(folder, sample.strip()) for sample in file_samples]
213
- labels = [t for s, t in zip(filenames, labels) if s in file_samples]
211
+ labels = [t for s, t in zip(filenames, labels, strict=False) if s in file_samples]
214
212
  filenames = [s for s in filenames if s in file_samples]
215
213
  log.info("Selected %d images using test_split_file for the test", len(filenames))
216
214
  if len(filenames) != len(file_samples):
@@ -353,7 +351,7 @@ def get_split(
353
351
 
354
352
  cl, counts = np.unique(targets, return_counts=True)
355
353
 
356
- for num, _cl in zip(counts, cl):
354
+ for num, _cl in zip(counts, cl, strict=False):
357
355
  if num == 1:
358
356
  to_remove = np.where(np.array(targets) == _cl)[0][0]
359
357
  samples = np.delete(np.array(samples), to_remove)
@@ -378,7 +376,7 @@ def get_split(
378
376
  file_samples.append(sample_path)
379
377
 
380
378
  train_split = [os.path.join(image_dir, sample.strip()) for sample in file_samples]
381
- targets = np.array([t for s, t in zip(samples, targets) if s in file_samples])
379
+ targets = np.array([t for s, t in zip(samples, targets, strict=False) if s in file_samples])
382
380
  samples = np.array([s for s in samples if s in file_samples])
383
381
 
384
382
  if limit_training_data is not None:
File without changes
@@ -4,6 +4,7 @@ import os
4
4
  from ast import literal_eval
5
5
  from collections.abc import Callable
6
6
  from functools import wraps
7
+ from typing import Any
7
8
 
8
9
  import matplotlib.pyplot as plt
9
10
  import numpy as np
@@ -123,7 +124,7 @@ def calculate_mask_based_metrics(
123
124
  th_thresh_preds = (th_preds > threshold).float().cpu()
124
125
  thresh_preds = th_thresh_preds.squeeze(0).numpy()
125
126
  dice_scores = metric(th_thresh_preds, th_masks, reduction=None).numpy()
126
- result = {}
127
+ result: dict[str, Any] = {}
127
128
  if multilabel:
128
129
  if n_classes is None:
129
130
  raise ValueError("n_classes arg shouldn't be None when multilabel is True")
@@ -167,7 +168,7 @@ def calculate_mask_based_metrics(
167
168
  "Accuracy": [],
168
169
  }
169
170
  for idx, (image, pred, mask, thresh_pred, dice_score) in enumerate(
170
- zip(images, preds, masks, thresh_preds, dice_scores)
171
+ zip(images, preds, masks, thresh_preds, dice_scores, strict=False)
171
172
  ):
172
173
  if np.sum(mask) == 0:
173
174
  good_dice.append(dice_score)
@@ -261,6 +262,7 @@ def create_mask_report(
261
262
  th_labels = output["label"]
262
263
  n_classes = th_preds.shape[1]
263
264
  # TODO: Apply sigmoid is a wrong name now
265
+ # TODO: Apply sigmoid false is untested
264
266
  if apply_sigmoid:
265
267
  if n_classes == 1:
266
268
  th_preds = torch.nn.Sigmoid()(th_preds)
@@ -271,6 +273,13 @@ def create_mask_report(
271
273
  # Compute labels from the given masks since by default they are all 0
272
274
  th_labels = th_masks.max(dim=2)[0].max(dim=2)[0].squeeze(dim=1)
273
275
  show_orj_predictions = False
276
+ elif n_classes == 1:
277
+ th_thresh_preds = (th_preds > threshold).float()
278
+ else:
279
+ th_thresh_preds = torch.argmax(th_preds, dim=1).float().unsqueeze(1)
280
+ # Compute labels from the given masks since by default they are all 0
281
+ th_labels = th_masks.max(dim=2)[0].max(dim=2)[0].squeeze(dim=1)
282
+ show_orj_predictions = False
274
283
 
275
284
  mean = np.asarray(mean)
276
285
  std = np.asarray(std)
@@ -303,7 +312,7 @@ def create_mask_report(
303
312
  non_zero_score_idx = sorted_idx[~binary_labels]
304
313
  zero_score_idx = sorted_idx[binary_labels]
305
314
  file_paths = []
306
- for name, current_score_idx in zip(["good", "bad"], [zero_score_idx, non_zero_score_idx]):
315
+ for name, current_score_idx in zip(["good", "bad"], [zero_score_idx, non_zero_score_idx], strict=False):
307
316
  if len(current_score_idx) == 0:
308
317
  continue
309
318