quadra 2.4.0a0__tar.gz → 2.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {quadra-2.4.0a0 → quadra-2.5.1}/PKG-INFO +3 -2
  2. {quadra-2.4.0a0 → quadra-2.5.1}/pyproject.toml +1 -1
  3. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/__init__.py +1 -1
  4. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/core/default.yaml +1 -0
  5. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/tasks/anomaly.py +29 -3
  6. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/tasks/classification.py +12 -1
  7. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/anomaly.py +35 -0
  8. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/export.py +86 -11
  9. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/utils.py +70 -35
  10. {quadra-2.4.0a0 → quadra-2.5.1}/LICENSE +0 -0
  11. {quadra-2.4.0a0 → quadra-2.5.1}/README.md +0 -0
  12. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/callbacks/__init__.py +0 -0
  13. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/callbacks/anomalib.py +0 -0
  14. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/callbacks/lightning.py +0 -0
  15. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/callbacks/mlflow.py +0 -0
  16. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/callbacks/scheduler.py +0 -0
  17. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/__init__.py +0 -0
  18. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/caformer_m36.yaml +0 -0
  19. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/caformer_s36.yaml +0 -0
  20. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/convnextv2_base.yaml +0 -0
  21. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/convnextv2_femto.yaml +0 -0
  22. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/convnextv2_tiny.yaml +0 -0
  23. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/dino_vitb8.yaml +0 -0
  24. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/dino_vits8.yaml +0 -0
  25. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/dinov2_vitb14.yaml +0 -0
  26. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/dinov2_vits14.yaml +0 -0
  27. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/efficientnet_b0.yaml +0 -0
  28. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/efficientnet_b1.yaml +0 -0
  29. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/efficientnet_b2.yaml +0 -0
  30. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/efficientnet_b3.yaml +0 -0
  31. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/efficientnetv2_s.yaml +0 -0
  32. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/levit_128s.yaml +0 -0
  33. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/mnasnet0_5.yaml +0 -0
  34. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/resnet101.yaml +0 -0
  35. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/resnet18.yaml +0 -0
  36. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/resnet18_ssl.yaml +0 -0
  37. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/resnet50.yaml +0 -0
  38. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/smp.yaml +0 -0
  39. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/tiny_vit_21m_224.yaml +0 -0
  40. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/unetr.yaml +0 -0
  41. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/vit16_base.yaml +0 -0
  42. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/vit16_small.yaml +0 -0
  43. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/vit16_tiny.yaml +0 -0
  44. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +0 -0
  45. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/callbacks/all.yaml +0 -0
  46. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/callbacks/default.yaml +0 -0
  47. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/callbacks/default_anomalib.yaml +0 -0
  48. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/config.yaml +0 -0
  49. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/datamodule/base/anomaly.yaml +0 -0
  50. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/datamodule/base/classification.yaml +0 -0
  51. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/datamodule/base/multilabel_classification.yaml +0 -0
  52. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/datamodule/base/segmentation.yaml +0 -0
  53. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/datamodule/base/segmentation_multiclass.yaml +0 -0
  54. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/datamodule/base/sklearn_classification.yaml +0 -0
  55. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/datamodule/base/sklearn_classification_patch.yaml +0 -0
  56. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/datamodule/base/ssl.yaml +0 -0
  57. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/datamodule/generic/imagenette/classification/base.yaml +0 -0
  58. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +0 -0
  59. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +0 -0
  60. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +0 -0
  61. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +0 -0
  62. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/anomaly/cfa.yaml +0 -0
  63. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/anomaly/cflow.yaml +0 -0
  64. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/anomaly/csflow.yaml +0 -0
  65. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/anomaly/draem.yaml +0 -0
  66. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/anomaly/efficient_ad.yaml +0 -0
  67. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/anomaly/fastflow.yaml +0 -0
  68. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/anomaly/inference.yaml +0 -0
  69. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/anomaly/padim.yaml +0 -0
  70. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/anomaly/patchcore.yaml +0 -0
  71. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/classification/classification.yaml +0 -0
  72. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/classification/classification_evaluation.yaml +0 -0
  73. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/classification/multilabel_classification.yaml +0 -0
  74. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/classification/sklearn_classification.yaml +0 -0
  75. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +0 -0
  76. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +0 -0
  77. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +0 -0
  78. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/segmentation/smp.yaml +0 -0
  79. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +0 -0
  80. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +0 -0
  81. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +0 -0
  82. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/ssl/barlow.yaml +0 -0
  83. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/ssl/byol.yaml +0 -0
  84. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/ssl/dino.yaml +0 -0
  85. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/ssl/linear_eval.yaml +0 -0
  86. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/ssl/simclr.yaml +0 -0
  87. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/base/ssl/simsiam.yaml +0 -0
  88. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/custom/cls.yaml +0 -0
  89. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/default.yaml +0 -0
  90. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/imagenette/classification/default.yaml +0 -0
  91. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +0 -0
  92. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +0 -0
  93. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +0 -0
  94. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +0 -0
  95. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +0 -0
  96. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +0 -0
  97. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +0 -0
  98. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +0 -0
  99. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +0 -0
  100. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +0 -0
  101. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +0 -0
  102. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +0 -0
  103. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +0 -0
  104. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +0 -0
  105. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +0 -0
  106. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +0 -0
  107. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +0 -0
  108. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +0 -0
  109. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +0 -0
  110. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +0 -0
  111. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +0 -0
  112. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +0 -0
  113. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +0 -0
  114. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/export/default.yaml +0 -0
  115. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/hydra/anomaly_custom.yaml +0 -0
  116. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/hydra/default.yaml +0 -0
  117. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/inference/default.yaml +0 -0
  118. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/logger/comet.yaml +0 -0
  119. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/logger/csv.yaml +0 -0
  120. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/logger/mlflow.yaml +0 -0
  121. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/logger/tensorboard.yaml +0 -0
  122. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/loss/asl.yaml +0 -0
  123. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/loss/barlow.yaml +0 -0
  124. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/loss/bce.yaml +0 -0
  125. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/loss/byol.yaml +0 -0
  126. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/loss/cross_entropy.yaml +0 -0
  127. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/loss/dino.yaml +0 -0
  128. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/loss/simclr.yaml +0 -0
  129. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/loss/simsiam.yaml +0 -0
  130. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/loss/smp_ce.yaml +0 -0
  131. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/loss/smp_dice.yaml +0 -0
  132. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/loss/smp_dice_multiclass.yaml +0 -0
  133. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/loss/smp_mcc.yaml +0 -0
  134. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/loss/vicreg.yaml +0 -0
  135. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/anomalib/cfa.yaml +0 -0
  136. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/anomalib/cflow.yaml +0 -0
  137. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/anomalib/csflow.yaml +0 -0
  138. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/anomalib/dfm.yaml +0 -0
  139. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/anomalib/draem.yaml +0 -0
  140. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/anomalib/efficient_ad.yaml +0 -0
  141. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/anomalib/fastflow.yaml +0 -0
  142. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/anomalib/padim.yaml +0 -0
  143. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/anomalib/patchcore.yaml +0 -0
  144. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/barlow.yaml +0 -0
  145. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/byol.yaml +0 -0
  146. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/classification.yaml +0 -0
  147. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/dino.yaml +0 -0
  148. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/logistic_regression.yaml +0 -0
  149. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/multilabel_classification.yaml +0 -0
  150. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/simclr.yaml +0 -0
  151. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/simsiam.yaml +0 -0
  152. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/smp.yaml +0 -0
  153. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/smp_multiclass.yaml +0 -0
  154. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/model/vicreg.yaml +0 -0
  155. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/optimizer/adam.yaml +0 -0
  156. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/optimizer/adamw.yaml +0 -0
  157. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/optimizer/default.yaml +0 -0
  158. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/optimizer/lars.yaml +0 -0
  159. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/optimizer/sgd.yaml +0 -0
  160. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/scheduler/default.yaml +0 -0
  161. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/scheduler/rop.yaml +0 -0
  162. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/scheduler/step.yaml +0 -0
  163. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/scheduler/warmrestart.yaml +0 -0
  164. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/scheduler/warmup.yaml +0 -0
  165. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/anomalib/cfa.yaml +0 -0
  166. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/anomalib/cflow.yaml +0 -0
  167. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/anomalib/csflow.yaml +0 -0
  168. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/anomalib/draem.yaml +0 -0
  169. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/anomalib/efficient_ad.yaml +0 -0
  170. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/anomalib/fastflow.yaml +0 -0
  171. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/anomalib/inference.yaml +0 -0
  172. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/anomalib/padim.yaml +0 -0
  173. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/anomalib/patchcore.yaml +0 -0
  174. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/classification.yaml +0 -0
  175. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/classification_evaluation.yaml +0 -0
  176. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/default.yaml +0 -0
  177. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/segmentation.yaml +0 -0
  178. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/segmentation_evaluation.yaml +0 -0
  179. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/sklearn_classification.yaml +0 -0
  180. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/sklearn_classification_patch.yaml +0 -0
  181. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/sklearn_classification_patch_test.yaml +0 -0
  182. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/sklearn_classification_test.yaml +0 -0
  183. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/task/ssl.yaml +0 -0
  184. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/trainer/lightning_cpu.yaml +0 -0
  185. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/trainer/lightning_gpu.yaml +0 -0
  186. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/trainer/lightning_gpu_bf16.yaml +0 -0
  187. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/trainer/lightning_gpu_fp16.yaml +0 -0
  188. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/trainer/lightning_multigpu.yaml +0 -0
  189. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/trainer/sklearn_classification.yaml +0 -0
  190. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/transforms/byol.yaml +0 -0
  191. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/transforms/byol_no_random_resize.yaml +0 -0
  192. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/transforms/default.yaml +0 -0
  193. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/transforms/default_numpy.yaml +0 -0
  194. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/transforms/default_resize.yaml +0 -0
  195. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/transforms/dino.yaml +0 -0
  196. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/configs/transforms/linear_eval.yaml +0 -0
  197. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datamodules/__init__.py +0 -0
  198. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datamodules/anomaly.py +0 -0
  199. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datamodules/base.py +0 -0
  200. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datamodules/classification.py +0 -0
  201. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datamodules/generic/__init__.py +0 -0
  202. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datamodules/generic/imagenette.py +0 -0
  203. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datamodules/generic/mnist.py +0 -0
  204. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datamodules/generic/mvtec.py +0 -0
  205. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datamodules/generic/oxford_pet.py +0 -0
  206. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datamodules/patch.py +0 -0
  207. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datamodules/segmentation.py +0 -0
  208. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datamodules/ssl.py +0 -0
  209. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datasets/__init__.py +0 -0
  210. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datasets/anomaly.py +0 -0
  211. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datasets/classification.py +0 -0
  212. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datasets/patch.py +0 -0
  213. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datasets/segmentation.py +0 -0
  214. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/datasets/ssl.py +0 -0
  215. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/losses/__init__.py +0 -0
  216. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/losses/classification/__init__.py +0 -0
  217. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/losses/classification/asl.py +0 -0
  218. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/losses/classification/focal.py +0 -0
  219. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/losses/classification/prototypical.py +0 -0
  220. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/losses/ssl/__init__.py +0 -0
  221. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/losses/ssl/barlowtwins.py +0 -0
  222. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/losses/ssl/byol.py +0 -0
  223. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/losses/ssl/dino.py +0 -0
  224. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/losses/ssl/hyperspherical.py +0 -0
  225. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/losses/ssl/idmm.py +0 -0
  226. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/losses/ssl/simclr.py +0 -0
  227. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/losses/ssl/simsiam.py +0 -0
  228. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/losses/ssl/vicreg.py +0 -0
  229. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/main.py +0 -0
  230. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/metrics/__init__.py +0 -0
  231. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/metrics/segmentation.py +0 -0
  232. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/models/__init__.py +0 -0
  233. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/models/base.py +0 -0
  234. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/models/classification/__init__.py +0 -0
  235. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/models/classification/backbones.py +0 -0
  236. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/models/classification/base.py +0 -0
  237. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/models/evaluation.py +0 -0
  238. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/__init__.py +0 -0
  239. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/backbone.py +0 -0
  240. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/base.py +0 -0
  241. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/classification/__init__.py +0 -0
  242. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/classification/base.py +0 -0
  243. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/ssl/__init__.py +0 -0
  244. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/ssl/barlowtwins.py +0 -0
  245. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/ssl/byol.py +0 -0
  246. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/ssl/common.py +0 -0
  247. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/ssl/dino.py +0 -0
  248. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/ssl/hyperspherical.py +0 -0
  249. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/ssl/idmm.py +0 -0
  250. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/ssl/simclr.py +0 -0
  251. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/ssl/simsiam.py +0 -0
  252. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/modules/ssl/vicreg.py +0 -0
  253. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/optimizers/__init__.py +0 -0
  254. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/optimizers/lars.py +0 -0
  255. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/optimizers/sam.py +0 -0
  256. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/schedulers/__init__.py +0 -0
  257. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/schedulers/base.py +0 -0
  258. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/schedulers/warmup.py +0 -0
  259. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/tasks/__init__.py +0 -0
  260. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/tasks/base.py +0 -0
  261. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/tasks/patch.py +0 -0
  262. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/tasks/segmentation.py +0 -0
  263. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/tasks/ssl.py +0 -0
  264. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/trainers/README.md +0 -0
  265. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/trainers/__init__.py +0 -0
  266. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/trainers/classification.py +0 -0
  267. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/__init__.py +0 -0
  268. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/classification.py +0 -0
  269. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/deprecation.py +0 -0
  270. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/evaluation.py +0 -0
  271. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/imaging.py +0 -0
  272. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/logger.py +0 -0
  273. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/mlflow.py +0 -0
  274. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/model_manager.py +0 -0
  275. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/models.py +0 -0
  276. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/patch/__init__.py +0 -0
  277. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/patch/dataset.py +0 -0
  278. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/patch/metrics.py +0 -0
  279. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/patch/model.py +0 -0
  280. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/patch/visualization.py +0 -0
  281. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/resolver.py +0 -0
  282. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/segmentation.py +0 -0
  283. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/tests/__init__.py +0 -0
  284. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/tests/fixtures/__init__.py +0 -0
  285. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/tests/fixtures/dataset/__init__.py +0 -0
  286. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/tests/fixtures/dataset/anomaly.py +0 -0
  287. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/tests/fixtures/dataset/classification.py +0 -0
  288. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/tests/fixtures/dataset/imagenette.py +0 -0
  289. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/tests/fixtures/dataset/segmentation.py +0 -0
  290. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/tests/fixtures/models/__init__.py +0 -0
  291. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/tests/fixtures/models/anomaly.py +0 -0
  292. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/tests/fixtures/models/classification.py +0 -0
  293. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/tests/fixtures/models/segmentation.py +0 -0
  294. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/tests/helpers.py +0 -0
  295. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/tests/models.py +0 -0
  296. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/validator.py +0 -0
  297. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/visualization.py +0 -0
  298. {quadra-2.4.0a0 → quadra-2.5.1}/quadra/utils/vit_explainability.py +0 -0
  299. {quadra-2.4.0a0 → quadra-2.5.1}/quadra_hydra_plugin/hydra_plugins/quadra_searchpath_plugin.py +0 -0
@@ -1,8 +1,9 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: quadra
3
- Version: 2.4.0a0
3
+ Version: 2.5.1
4
4
  Summary: Deep Learning experiment orchestration library
5
5
  License: Apache-2.0
6
+ License-File: LICENSE
6
7
  Keywords: deep learning,experiment,lightning,hydra-core
7
8
  Author: Federico Belotti
8
9
  Author-email: federico.belotti@orobix.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "quadra"
3
- version = "2.4.0a0"
3
+ version = "2.5.1"
4
4
  description = "Deep Learning experiment orchestration library"
5
5
  authors = [
6
6
  "Federico Belotti <federico.belotti@orobix.com>",
@@ -1,4 +1,4 @@
1
- __version__ = "2.4.0a0"
1
+ __version__ = "2.5.1"
2
2
 
3
3
 
4
4
  def get_version():
@@ -9,3 +9,4 @@ experiment_path: null
9
9
  upload_artifacts: False
10
10
  upload_models: ${export.types} # Default behavior in quadra <= 1.5.6
11
11
  log_level: info
12
+ mlflow_zip_models: False
@@ -244,7 +244,7 @@ class AnomalibDetection(Generic[AnomalyDataModuleT], LightningTask[AnomalyDataMo
244
244
  ):
245
245
  threshold = torch.tensor(100.0)
246
246
  else:
247
- threshold = self.module.image_metrics.F1Score.threshold
247
+ threshold = self.module.image_metrics.F1Score.threshold # type: ignore[union-attr,assignment]
248
248
 
249
249
  # The output of the prediction is a normalized score so the cumulative histogram is displayed with the
250
250
  # normalized scores
@@ -328,6 +328,22 @@ class AnomalibDetection(Generic[AnomalyDataModuleT], LightningTask[AnomalyDataMo
328
328
  else:
329
329
  utils.upload_file_tensorboard(a, tensorboard_logger)
330
330
 
331
+ def execute(self):
332
+ """Execute the experiment and all the steps."""
333
+ self.prepare()
334
+ self.train()
335
+ # When training in fp16 mixed precision, export function casts model weights from fp32 to fp16,
336
+ # for this reason, predictions logits could slightly change and predictions could be inconsistent between
337
+ # test and generated report.
338
+ # Performing export before test allows to have consistent results in test metrics and generated report.
339
+ if self.config.export is not None and len(self.config.export.types) > 0:
340
+ self.export()
341
+ if self.run_test:
342
+ self.test()
343
+ if self.report:
344
+ self.generate_report()
345
+ self.finalize()
346
+
331
347
 
332
348
  class AnomalibEvaluation(Evaluation[AnomalyDataModule]):
333
349
  """Evaluation task for Anomalib.
@@ -445,12 +461,22 @@ class AnomalibEvaluation(Evaluation[AnomalyDataModule]):
445
461
  training_threshold = self.model_data[f"{self.training_threshold_type}_threshold"]
446
462
  optimal_threshold = self.metadata["threshold"]
447
463
 
448
- normalized_optimal_threshold = cast(float, normalize_anomaly_score(optimal_threshold, training_threshold))
449
-
450
464
  os.makedirs(os.path.join(self.report_path, "predictions"), exist_ok=True)
451
465
  os.makedirs(os.path.join(self.report_path, "heatmaps"), exist_ok=True)
452
466
 
453
467
  anomaly_scores = self.metadata["anomaly_scores"].cpu().numpy()
468
+
469
+ # The reason I have to expand dims and cast the optimal threshold to anomaly_scores dtype is because
470
+ # of internal roundings performed differently by numpy and python
471
+ # Particularly the normalized_optimal_threshold computed directly using float values might be higher than the
472
+ # actual value obtained by the anomaly_scores
473
+ normalized_optimal_threshold = cast(
474
+ np.ndarray,
475
+ normalize_anomaly_score(
476
+ np.expand_dims(np.array(optimal_threshold, dtype=anomaly_scores.dtype), -1), training_threshold
477
+ ),
478
+ ).item()
479
+
454
480
  anomaly_scores = normalize_anomaly_score(anomaly_scores, training_threshold)
455
481
 
456
482
  if not isinstance(anomaly_scores, np.ndarray):
@@ -307,6 +307,14 @@ class Classification(Generic[ClassificationDataModuleT], LightningTask[Classific
307
307
  # TODO: What happens if we have 64 precision?
308
308
  half_precision = "16" in self.trainer.precision
309
309
 
310
+ example_input: torch.Tensor | None = None
311
+
312
+ if hasattr(self.trainer, "datamodule") and hasattr(self.trainer.datamodule, "val_dataset"):
313
+ # Retrieve a better input to evaluate fp16 performance or efficientnetb0 does not sometimes export properly
314
+ example_input = self.trainer.datamodule.val_dataset[0][0]
315
+
316
+ # Selected rtol and atol are quite high, this is mostly done for efficientnetb0 that seems to be
317
+ # quite unstable in fp16
310
318
  self.model_json, export_paths = export_model(
311
319
  config=self.config,
312
320
  model=module.model,
@@ -314,6 +322,9 @@ class Classification(Generic[ClassificationDataModuleT], LightningTask[Classific
314
322
  half_precision=half_precision,
315
323
  input_shapes=input_shapes,
316
324
  idx_to_class=idx_to_class,
325
+ example_inputs=example_input,
326
+ rtol=0.05,
327
+ atol=0.01,
317
328
  )
318
329
 
319
330
  if len(export_paths) == 0:
@@ -1136,7 +1147,7 @@ class ClassificationEvaluation(Evaluation[ClassificationDataModuleT]):
1136
1147
  return
1137
1148
 
1138
1149
  if isinstance(self.deployment_model.model.features_extractor, timm.models.resnet.ResNet):
1139
- target_layers = [cast(BaseNetworkBuilder, self.deployment_model.model).features_extractor.layer4[-1]]
1150
+ target_layers = [cast(BaseNetworkBuilder, self.deployment_model.model).features_extractor.layer4[-1]] # type: ignore[index]
1140
1151
  self.cam = GradCAM(
1141
1152
  model=self.deployment_model.model,
1142
1153
  target_layers=target_layers,
@@ -5,6 +5,8 @@
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
+ from typing import cast
9
+
8
10
  try:
9
11
  from typing import Any, TypeAlias
10
12
  except ImportError:
@@ -43,6 +45,39 @@ def normalize_anomaly_score(raw_score: MapOrValue, threshold: float) -> MapOrVal
43
45
  else:
44
46
  normalized_score = 200.0 - ((raw_score / threshold) * 100.0)
45
47
 
48
+ # Ensures that the normalized scores are consistent with the raw scores
49
+ # For all the items whose prediction changes after normalization, force the normalized score to be
50
+ # consistent with the prediction made on the raw score by clipping the score:
51
+ # - to 100.0 if the prediction was "anomaly" on the raw score and "good" on the normalized score
52
+ # - to 99.99 if the prediction was "good" on the raw score and "anomaly" on the normalized score
53
+ score = raw_score
54
+ if isinstance(score, torch.Tensor):
55
+ score = score.cpu().numpy()
56
+ # Anomalib classify as anomaly if anomaly_score gte threshold
57
+ is_anomaly_mask = score >= threshold
58
+ is_not_anomaly_mask = np.bitwise_not(is_anomaly_mask)
59
+ if isinstance(normalized_score, torch.Tensor):
60
+ if normalized_score.dim() == 0:
61
+ normalized_score = (
62
+ normalized_score.clamp(min=100.0) if is_anomaly_mask else normalized_score.clamp(max=99.99)
63
+ )
64
+ else:
65
+ normalized_score[is_anomaly_mask] = normalized_score[is_anomaly_mask].clamp(min=100.0)
66
+ normalized_score[is_not_anomaly_mask] = normalized_score[is_not_anomaly_mask].clamp(max=99.99)
67
+ elif isinstance(normalized_score, np.ndarray) or np.isscalar(normalized_score):
68
+ if np.isscalar(normalized_score) or normalized_score.ndim == 0: # type: ignore[union-attr]
69
+ normalized_score = (
70
+ np.clip(normalized_score, a_min=100.0, a_max=None)
71
+ if is_anomaly_mask
72
+ else np.clip(normalized_score, a_min=None, a_max=99.99)
73
+ )
74
+ else:
75
+ normalized_score = cast(np.ndarray, normalized_score)
76
+ normalized_score[is_anomaly_mask] = np.clip(normalized_score[is_anomaly_mask], a_min=100.0, a_max=None)
77
+ normalized_score[is_not_anomaly_mask] = np.clip(
78
+ normalized_score[is_not_anomaly_mask], a_min=None, a_max=99.99
79
+ )
80
+
46
81
  if isinstance(normalized_score, torch.Tensor):
47
82
  return torch.clamp(normalized_score, 0.0, 1000.0)
48
83
 
@@ -119,6 +119,7 @@ def export_torchscript_model(
119
119
  input_shapes: list[Any] | None = None,
120
120
  half_precision: bool = False,
121
121
  model_name: str = "model.pt",
122
+ example_inputs: list[torch.Tensor] | tuple[torch.Tensor, ...] | torch.Tensor | None = None,
122
123
  ) -> tuple[str, Any] | None:
123
124
  """Export a PyTorch model with TorchScript.
124
125
 
@@ -128,6 +129,8 @@ def export_torchscript_model(
128
129
  output_path: Path to save the model
129
130
  half_precision: If True, the model will be exported with half precision
130
131
  model_name: Name of the exported model
132
+ example_inputs: If provided use this to evaluate the model instead of generating random inputs, it's expected to
133
+ be a list of tensors or a single tensor without batch dimension
131
134
 
132
135
  Returns:
133
136
  If the model is exported successfully, the path to the model and the input shape are returned.
@@ -144,7 +147,32 @@ def export_torchscript_model(
144
147
  else:
145
148
  model.cpu()
146
149
 
147
- model_inputs = extract_torch_model_inputs(model, input_shapes, half_precision)
150
+ batch_size = 1
151
+ model_inputs: tuple[list[Any] | tuple[Any, ...] | torch.Tensor, list[Any]] | None
152
+ if example_inputs is not None:
153
+ if isinstance(example_inputs, Sequence):
154
+ model_input_tensors = []
155
+ model_input_shapes = []
156
+
157
+ for example_input in example_inputs:
158
+ new_inp = example_input.to(
159
+ device="cuda:0" if half_precision else "cpu",
160
+ dtype=torch.float16 if half_precision else torch.float32,
161
+ )
162
+ new_inp = new_inp.unsqueeze(0).repeat(batch_size, *(1 for x in new_inp.shape))
163
+ model_input_tensors.append(new_inp)
164
+ model_input_shapes.append(new_inp[0].shape)
165
+
166
+ model_inputs = (model_input_tensors, [model_input_shapes])
167
+ else:
168
+ new_inp = example_inputs.to(
169
+ device="cuda:0" if half_precision else "cpu",
170
+ dtype=torch.float16 if half_precision else torch.float32,
171
+ )
172
+ new_inp = new_inp.unsqueeze(0).repeat(batch_size, *(1 for x in new_inp.shape))
173
+ model_inputs = (new_inp, [new_inp[0].shape])
174
+ else:
175
+ model_inputs = extract_torch_model_inputs(model, input_shapes, half_precision)
148
176
 
149
177
  if model_inputs is None:
150
178
  return None
@@ -182,6 +210,9 @@ def export_onnx_model(
182
210
  input_shapes: list[Any] | None = None,
183
211
  half_precision: bool = False,
184
212
  model_name: str = "model.onnx",
213
+ example_inputs: list[torch.Tensor] | tuple[torch.Tensor, ...] | torch.Tensor | None = None,
214
+ rtol: float = 0.01,
215
+ atol: float = 5e-3,
185
216
  ) -> tuple[str, Any] | None:
186
217
  """Export a PyTorch model with ONNX.
187
218
 
@@ -192,6 +223,10 @@ def export_onnx_model(
192
223
  onnx_config: ONNX export configuration
193
224
  half_precision: If True, the model will be exported with half precision
194
225
  model_name: Name of the exported model
226
+ example_inputs: If provided use this to evaluate the model instead of generating random inputs, it's expected to
227
+ be a list of tensors or a single tensor without batch dimension
228
+ rtol: Relative tolerance for the ONNX safe export in fp16
229
+ atol: Absolute tolerance for the ONNX safe export in fp16
195
230
  """
196
231
  if not ONNX_AVAILABLE:
197
232
  log.warning("ONNX is not installed, can not export model in this format.")
@@ -210,9 +245,32 @@ def export_onnx_model(
210
245
  else:
211
246
  batch_size = 1
212
247
 
213
- model_inputs = extract_torch_model_inputs(
214
- model=model, input_shapes=input_shapes, half_precision=half_precision, batch_size=batch_size
215
- )
248
+ model_inputs: tuple[list[Any] | tuple[Any, ...] | torch.Tensor, list[Any]] | None
249
+ if example_inputs is not None:
250
+ if isinstance(example_inputs, Sequence):
251
+ model_input_tensors = []
252
+ model_input_shapes = []
253
+
254
+ for example_input in example_inputs:
255
+ new_inp = example_input.to(
256
+ device="cuda:0" if half_precision else "cpu",
257
+ dtype=torch.float16 if half_precision else torch.float32,
258
+ )
259
+ new_inp = new_inp.unsqueeze(0).repeat(batch_size, *(1 for x in new_inp.shape))
260
+ model_input_tensors.append(new_inp)
261
+ model_input_shapes.append(new_inp[0].shape)
262
+
263
+ model_inputs = (model_input_tensors, [model_input_shapes])
264
+ else:
265
+ new_inp = example_inputs.to(
266
+ device="cuda:0" if half_precision else "cpu",
267
+ dtype=torch.float16 if half_precision else torch.float32,
268
+ )
269
+ new_inp = new_inp.unsqueeze(0).repeat(batch_size, *(1 for x in new_inp.shape))
270
+ model_inputs = ([new_inp], [new_inp[0].shape])
271
+ else:
272
+ model_inputs = extract_torch_model_inputs(model, input_shapes, half_precision)
273
+
216
274
  if model_inputs is None:
217
275
  return None
218
276
 
@@ -266,6 +324,8 @@ def export_onnx_model(
266
324
 
267
325
  if isinstance(inp, list):
268
326
  inp = tuple(inp) # onnx doesn't like lists representing tuples of inputs
327
+ elif isinstance(inp, torch.Tensor):
328
+ inp = (inp,)
269
329
 
270
330
  if isinstance(inp, dict):
271
331
  raise ValueError("ONNX export does not support model with dict inputs")
@@ -290,6 +350,8 @@ def export_onnx_model(
290
350
  onnx_config=onnx_config,
291
351
  input_shapes=input_shapes,
292
352
  input_names=input_names,
353
+ rtol=rtol,
354
+ atol=atol,
293
355
  )
294
356
 
295
357
  if not is_export_ok:
@@ -324,6 +386,8 @@ def _safe_export_half_precision_onnx(
324
386
  onnx_config: DictConfig,
325
387
  input_shapes: list[Any],
326
388
  input_names: list[str],
389
+ rtol: float = 0.01,
390
+ atol: float = 5e-3,
327
391
  ) -> bool:
328
392
  """Check that the exported half precision ONNX model does not contain NaN values. If it does, attempt to export
329
393
  the model with a more stable export and overwrite the original model.
@@ -335,6 +399,8 @@ def _safe_export_half_precision_onnx(
335
399
  onnx_config: ONNX export configuration
336
400
  input_shapes: Input shapes for the model
337
401
  input_names: Input names for the model
402
+ rtol: Relative tolerance to evaluate the model
403
+ atol: Absolute tolerance to evaluate the model
338
404
 
339
405
  Returns:
340
406
  True if the model is stable or it was possible to export a more stable model, False otherwise.
@@ -364,16 +430,15 @@ def _safe_export_half_precision_onnx(
364
430
  export_output = export_onnx_model(
365
431
  model=model,
366
432
  output_path=os.path.dirname(export_model_path),
367
- onnx_config=onnx_config,
433
+ # Force to not simplify fp32 model
434
+ onnx_config=DictConfig({**onnx_config, "simplify": False}),
368
435
  input_shapes=input_shapes,
369
436
  half_precision=False,
370
437
  model_name=os.path.basename(export_model_path),
371
438
  )
372
- if export_output is not None:
373
- export_model_path, _ = export_output
374
- else:
375
- log.warning("Failed to export model")
376
- return False
439
+ if export_output is None:
440
+ # This should not happen
441
+ raise RuntimeError("Failed to export model")
377
442
 
378
443
  model_fp32 = onnx.load(export_model_path)
379
444
  test_data = {input_names[i]: inp[i].float().cpu().numpy() for i in range(len(inp))}
@@ -381,7 +446,7 @@ def _safe_export_half_precision_onnx(
381
446
  with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
382
447
  # This function prints a lot of information that is not useful for the user
383
448
  model_fp16 = auto_convert_mixed_precision(
384
- model_fp32, test_data, rtol=0.01, atol=5e-3, keep_io_types=False
449
+ model_fp32, test_data, rtol=rtol, atol=atol, keep_io_types=False
385
450
  )
386
451
  onnx.save(model_fp16, export_model_path)
387
452
 
@@ -431,6 +496,9 @@ def export_model(
431
496
  input_shapes: list[Any] | None = None,
432
497
  idx_to_class: dict[int, str] | None = None,
433
498
  pytorch_model_type: Literal["backbone", "model"] = "model",
499
+ example_inputs: list[Any] | tuple[Any, ...] | torch.Tensor | None = None,
500
+ rtol: float = 0.01,
501
+ atol: float = 5e-3,
434
502
  ) -> tuple[dict[str, Any], dict[str, str]]:
435
503
  """Generate deployment models for the task.
436
504
 
@@ -443,6 +511,9 @@ def export_model(
443
511
  idx_to_class: Mapping from class index to class name
444
512
  pytorch_model_type: Type of the pytorch model config to be exported, if it's backbone on disk we will save the
445
513
  config.backbone config, otherwise we will save the config.model
514
+ example_inputs: If provided use this to evaluate the model instead of generating random inputs
515
+ rtol: Relative tolerance for the ONNX safe export in fp16
516
+ atol: Absolute tolerance for the ONNX safe export in fp16
446
517
 
447
518
  Returns:
448
519
  If the model is exported successfully, return a dictionary containing information about the exported model and
@@ -468,6 +539,7 @@ def export_model(
468
539
  input_shapes=input_shapes,
469
540
  output_path=export_folder,
470
541
  half_precision=half_precision,
542
+ example_inputs=example_inputs,
471
543
  )
472
544
 
473
545
  if out is None:
@@ -495,6 +567,9 @@ def export_model(
495
567
  onnx_config=config.export.onnx,
496
568
  input_shapes=input_shapes,
497
569
  half_precision=half_precision,
570
+ example_inputs=example_inputs,
571
+ rtol=rtol,
572
+ atol=atol,
498
573
  )
499
574
 
500
575
  if out is None:
@@ -8,10 +8,12 @@ import glob
8
8
  import json
9
9
  import logging
10
10
  import os
11
+ import shutil
11
12
  import subprocess
12
13
  import sys
13
14
  import warnings
14
15
  from collections.abc import Iterable, Iterator, Sequence
16
+ from tempfile import TemporaryDirectory
15
17
  from typing import Any, cast
16
18
 
17
19
  import cv2
@@ -299,45 +301,78 @@ def finish(
299
301
  quadra_export.generate_torch_inputs(input_size, device=device, half_precision=half_precision),
300
302
  )
301
303
  types_to_upload = config.core.get("upload_models")
302
- for model_path in deployed_models:
303
- model_type = model_type_from_path(model_path)
304
- if model_type is None:
305
- logging.warning("%s model type not supported", model_path)
306
- continue
307
- if model_type is not None and model_type in types_to_upload:
308
- if model_type == "pytorch":
309
- logging.warning("Pytorch format still not supported for mlflow upload")
304
+ mlflow_zip_models = config.core.get("mlflow_zip_models", False)
305
+ model_uploaded = False
306
+ with mlflow.start_run(run_id=mlflow_logger.run_id) as _:
307
+ for model_path in deployed_models:
308
+ model_type = model_type_from_path(model_path)
309
+ model_name = os.path.basename(model_path)
310
+
311
+ if model_type is None:
312
+ logging.warning("%s model type not supported", model_path)
310
313
  continue
311
-
312
- model = quadra_export.import_deployment_model(
313
- model_path,
314
- device=device,
315
- inference_config=config.inference,
316
- )
317
-
318
- if model_type in ["torchscript", "pytorch"]:
319
- signature = infer_signature_model(model.model, inputs)
320
- with mlflow.start_run(run_id=mlflow_logger.run_id) as _:
321
- mlflow.pytorch.log_model(
322
- model.model,
323
- artifact_path=model_path,
324
- signature=signature,
325
- )
326
- elif model_type in ["onnx", "simplified_onnx"] and ONNX_AVAILABLE:
327
- signature = infer_signature_model(model, inputs)
328
- with mlflow.start_run(run_id=mlflow_logger.run_id) as _:
329
- if model.model_path is None:
330
- logging.warning(
331
- "Cannot log onnx model on mlflow, \
332
- BaseEvaluationModel 'model_path' attribute is None"
314
+ if model_type is not None and model_type in types_to_upload:
315
+ if model_type == "pytorch" and not mlflow_zip_models:
316
+ logging.warning("Pytorch format still not supported for mlflow upload")
317
+ continue
318
+
319
+ if mlflow_zip_models:
320
+ with TemporaryDirectory() as temp_dir:
321
+ if model_type == "pytorch" and os.path.isfile(
322
+ os.path.join(export_folder, "model_config.yaml")
323
+ ):
324
+ shutil.copy(model_path, temp_dir)
325
+ shutil.copy(os.path.join(export_folder, "model_config.yaml"), temp_dir)
326
+ shutil.make_archive("assets", "zip", root_dir=temp_dir)
327
+ else:
328
+ shutil.make_archive(
329
+ "assets",
330
+ "zip",
331
+ root_dir=os.path.dirname(model_path),
332
+ base_dir=model_name,
333
+ )
334
+ shutil.move("assets.zip", temp_dir)
335
+ mlflow.pyfunc.log_model(
336
+ artifact_path=model_path,
337
+ loader_module="not.used",
338
+ data_path=os.path.join(temp_dir, "assets.zip"),
339
+ pip_requirements=[""],
333
340
  )
334
- else:
335
- model_proto = onnx.load(model.model_path)
336
- mlflow.onnx.log_model(
337
- model_proto,
341
+ model_uploaded = True
342
+ else:
343
+ model = quadra_export.import_deployment_model(
344
+ model_path,
345
+ device=device,
346
+ inference_config=config.inference,
347
+ )
348
+
349
+ if model_type in ["torchscript", "pytorch"]:
350
+ signature = infer_signature_model(model.model, inputs)
351
+ mlflow.pytorch.log_model(
352
+ model.model,
338
353
  artifact_path=model_path,
339
354
  signature=signature,
340
355
  )
356
+ model_uploaded = True
357
+
358
+ elif model_type in ["onnx", "simplified_onnx"] and ONNX_AVAILABLE:
359
+ if model.model_path is None:
360
+ logging.warning(
361
+ "Cannot log onnx model on mlflow, \
362
+ BaseEvaluationModel 'model_path' attribute is None"
363
+ )
364
+ else:
365
+ signature = infer_signature_model(model, inputs)
366
+ model_proto = onnx.load(model.model_path)
367
+ mlflow.onnx.log_model(
368
+ model_proto,
369
+ artifact_path=model_path,
370
+ signature=signature,
371
+ )
372
+ model_uploaded = True
373
+
374
+ if model_uploaded:
375
+ mlflow.log_artifact(os.path.join(export_folder, "model.json"), export_folder)
341
376
 
342
377
  if tensorboard_logger is not None:
343
378
  config_paths = []
@@ -376,7 +411,7 @@ def model_type_from_path(model_path: str) -> str | None:
376
411
  - "pytorch" if the model has a '.pth' extension (PyTorch).
377
412
  - "simplified_onnx" if the model file ends with 'simplified.onnx' (Simplified ONNX).
378
413
  - "onnx" if the model has a '.onnx' extension (ONNX).
379
- - "json" id the model has a '.json' extension (JSON).
414
+ - "json" if the model has a '.json' extension (JSON).
380
415
  - None if model extension is not supported.
381
416
 
382
417
  Example:
File without changes
File without changes