quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/losses/classification/focal.py
@@ -0,0 +1,320 @@
+ from __future__ import annotations
+
+ import warnings
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+
+ def one_hot(
+     labels: torch.Tensor,
+     num_classes: int,
+     device: torch.device | None = None,
+     dtype: torch.dtype | None = None,
+     eps: float = 1e-6,
+ ) -> torch.Tensor:
+     r"""Convert an integer label x-D tensor to a one-hot (x+1)-D tensor.
+
+     Args:
+         labels: tensor with labels of shape :math:`(N, *)`, where N is batch size.
+             Each value is an integer representing correct classification.
+         num_classes: number of classes in labels.
+         device: the desired device of returned tensor.
+         dtype: the desired data type of returned tensor.
+         eps: a value added to the returned tensor.
+
+     Returns:
+         the labels in a one-hot tensor of shape :math:`(N, C, *)`.
+
+     Examples:
+         >>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
+         >>> one_hot(labels, num_classes=3)
+         tensor([[[[1.0000e+00, 1.0000e-06],
+                   [1.0000e-06, 1.0000e+00]],
+         <BLANKLINE>
+                  [[1.0000e-06, 1.0000e+00],
+                   [1.0000e-06, 1.0000e-06]],
+         <BLANKLINE>
+                  [[1.0000e-06, 1.0000e-06],
+                   [1.0000e+00, 1.0000e-06]]]])
+
+     """
+     if not isinstance(labels, torch.Tensor):
+         raise TypeError(f"Input labels type is not a torch.Tensor. Got {type(labels)}")
+
+     if not labels.dtype == torch.int64:
+         raise ValueError(f"labels must be of dtype torch.int64. Got: {labels.dtype}")
+
+     if num_classes < 1:
+         raise ValueError(f"The number of classes must be a positive integer. Got: {num_classes}")
+
+     shape = labels.shape
+     one_hot_output = torch.zeros((shape[0], num_classes) + shape[1:], device=device, dtype=dtype)
+
+     return one_hot_output.scatter_(1, labels.unsqueeze(1), 1.0) + eps
+
+
+ # based on: https://github.com/zhezh/focalloss/blob/master/focalloss.py
+ def focal_loss(
+     input_tensor: torch.Tensor,
+     target: torch.Tensor,
+     alpha: float,
+     gamma: float = 2.0,
+     reduction: str = "none",
+     eps: float | None = None,
+ ) -> torch.Tensor:
+     r"""Criterion that computes Focal loss.
+
+     According to :cite:`lin2018focal`, the Focal loss is computed as follows:
+
+     .. math::
+
+         \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
+
+     Where:
+         - :math:`p_t` is the model's estimated probability for each class.
+
+     Args:
+         input_tensor: Logits tensor with shape :math:`(N, C, *)` where C = number of classes.
+         target: Labels tensor with shape :math:`(N, *)` where each value is :math:`0 ≤ targets[i] ≤ C−1`.
+         alpha: Weighting factor :math:`\alpha \in [0, 1]`.
+         gamma: Focusing parameter :math:`\gamma >= 0`.
+         reduction: Specifies the reduction to apply to the
+             output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
+             will be applied, ``'mean'``: the sum of the output will be divided by
+             the number of elements in the output, ``'sum'``: the output will be
+             summed.
+         eps: Deprecated: scalar to enforce numerical stability. This is no longer used.
+
+     Returns:
+         The computed loss.
+
+     Example:
+         >>> N = 5  # num_classes
+         >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
+         >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
+         >>> output = focal_loss(input, target, alpha=0.5, gamma=2.0, reduction='mean')
+         >>> output.backward()
+     """
+     if eps is not None and not torch.jit.is_scripting():
+         warnings.warn(
+             "`focal_loss` has been reworked for improved numerical stability "
+             "and the `eps` argument is no longer necessary",
+             DeprecationWarning,
+             stacklevel=2,
+         )
+
+     if not isinstance(input_tensor, torch.Tensor):
+         raise TypeError(f"Input type is not a torch.Tensor. Got {type(input_tensor)}")
+
+     if not len(input_tensor.shape) >= 2:
+         raise ValueError(f"Invalid input shape, we expect BxCx*. Got: {input_tensor.shape}")
+
+     if input_tensor.size(0) != target.size(0):
+         raise ValueError(
+             f"Expected input batch_size ({input_tensor.size(0)}) to match target batch_size ({target.size(0)})."
+         )
+
+     n = input_tensor.size(0)
+     out_size = (n,) + input_tensor.size()[2:]
+     if target.size()[1:] != input_tensor.size()[2:]:
+         raise ValueError(f"Expected target size {out_size}, got {target.size()}")
+
+     if not input_tensor.device == target.device:
+         raise ValueError(f"input and target must be on the same device. Got: {input_tensor.device} and {target.device}")
+
+     # compute softmax over the classes axis
+     input_soft: torch.Tensor = F.softmax(input_tensor, dim=1)
+     log_input_soft: torch.Tensor = F.log_softmax(input_tensor, dim=1)
+
+     # create the labels one hot tensor
+     target_one_hot: torch.Tensor = one_hot(
+         target, num_classes=input_tensor.shape[1], device=input_tensor.device, dtype=input_tensor.dtype
+     )
+
+     # compute the actual focal loss
+     weight = torch.pow(-input_soft + 1.0, gamma)
+
+     focal = -alpha * weight * log_input_soft
+     loss_tmp = torch.einsum("bc...,bc...->b...", (target_one_hot, focal))
+
+     if reduction == "none":
+         loss = loss_tmp
+     elif reduction == "mean":
+         loss = torch.mean(loss_tmp)
+     elif reduction == "sum":
+         loss = torch.sum(loss_tmp)
+     else:
+         raise NotImplementedError(f"Invalid reduction mode: {reduction}")
+     return loss
+
+
+ class FocalLoss(nn.Module):
+     r"""Criterion that computes Focal loss.
+
+     According to :cite:`lin2018focal`, the Focal loss is computed as follows:
+
+     .. math::
+
+         \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
+
+     Where:
+         - :math:`p_t` is the model's estimated probability for each class.
+
+     Args:
+         alpha: Weighting factor :math:`\alpha \in [0, 1]`.
+         gamma: Focusing parameter :math:`\gamma >= 0`.
+         reduction: Specifies the reduction to apply to the
+             output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
+             will be applied, ``'mean'``: the sum of the output will be divided by
+             the number of elements in the output, ``'sum'``: the output will be
+             summed.
+         eps: Deprecated: scalar to enforce numerical stability. This is no longer
+             used.
+
+     Shape:
+         - Input: :math:`(N, C, *)` where C = number of classes.
+         - Target: :math:`(N, *)` where each value is
+           :math:`0 ≤ targets[i] ≤ C−1`.
+
+     Example:
+         >>> N = 5  # num_classes
+         >>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'}
+         >>> criterion = FocalLoss(**kwargs)
+         >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
+         >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
+         >>> output = criterion(input, target)
+         >>> output.backward()
+     """
+
+     def __init__(self, alpha: float, gamma: float = 2.0, reduction: str = "none", eps: float | None = None) -> None:
+         super().__init__()
+         self.alpha: float = alpha
+         self.gamma: float = gamma
+         self.reduction: str = reduction
+         self.eps: float | None = eps
+
+     def forward(self, input_tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         """Forward call computation."""
+         return focal_loss(input_tensor, target, self.alpha, self.gamma, self.reduction, self.eps)
+
+
+ def binary_focal_loss_with_logits(
+     input_tensor: torch.Tensor,
+     target: torch.Tensor,
+     alpha: float = 0.25,
+     gamma: float = 2.0,
+     reduction: str = "none",
+     eps: float | None = None,
+ ) -> torch.Tensor:
+     r"""Function that computes Binary Focal loss.
+
+     .. math::
+
+         \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
+
+     where:
+         - :math:`p_t` is the model's estimated probability for each class.
+
+     Args:
+         input_tensor: input data tensor of arbitrary shape.
+         target: the target tensor with shape matching input.
+         alpha: Weighting factor for the rare class :math:`\alpha \in [0, 1]`.
+         gamma: Focusing parameter :math:`\gamma >= 0`.
+         reduction: Specifies the reduction to apply to the
+             output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
+             will be applied, ``'mean'``: the sum of the output will be divided by
+             the number of elements in the output, ``'sum'``: the output will be
+             summed.
+         eps: Deprecated: scalar for numerical stability when dividing. This is no longer used.
+
+     Returns:
+         the computed loss.
+
+     Examples:
+         >>> kwargs = {"alpha": 0.25, "gamma": 2.0, "reduction": 'mean'}
+         >>> logits = torch.tensor([[[6.325]],[[5.26]],[[87.49]]])
+         >>> labels = torch.tensor([[[1.]],[[1.]],[[0.]]])
+         >>> binary_focal_loss_with_logits(logits, labels, **kwargs)
+         tensor(21.8725)
+     """
+     if eps is not None and not torch.jit.is_scripting():
+         warnings.warn(
+             "`binary_focal_loss_with_logits` has been reworked for improved numerical stability "
+             "and the `eps` argument is no longer necessary",
+             DeprecationWarning,
+             stacklevel=2,
+         )
+
+     if not isinstance(input_tensor, torch.Tensor):
+         raise TypeError(f"Input type is not a torch.Tensor. Got {type(input_tensor)}")
+
+     if not len(input_tensor.shape) >= 2:
+         raise ValueError(f"Invalid input shape, we expect BxCx*. Got: {input_tensor.shape}")
+
+     if input_tensor.size(0) != target.size(0):
+         raise ValueError(
+             f"Expected input batch_size ({input_tensor.size(0)}) to match target batch_size ({target.size(0)})."
+         )
+
+     probs_pos = torch.sigmoid(input_tensor)
+     probs_neg = torch.sigmoid(-input_tensor)
+     loss_tmp = -alpha * torch.pow(probs_neg, gamma) * target * F.logsigmoid(input_tensor) - (1 - alpha) * torch.pow(
+         probs_pos, gamma
+     ) * (1.0 - target) * F.logsigmoid(-input_tensor)
+
+     if reduction == "none":
+         loss = loss_tmp
+     elif reduction == "mean":
+         loss = torch.mean(loss_tmp)
+     elif reduction == "sum":
+         loss = torch.sum(loss_tmp)
+     else:
+         raise NotImplementedError(f"Invalid reduction mode: {reduction}")
+     return loss
+
+
+ class BinaryFocalLossWithLogits(nn.Module):
+     r"""Criterion that computes Binary Focal loss.
+
+     According to :cite:`lin2018focal`, the Focal loss is computed as follows:
+
+     .. math::
+
+         \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
+
+     where:
+         - :math:`p_t` is the model's estimated probability for each class.
+
+     Args:
+         alpha: Weighting factor for the rare class :math:`\alpha \in [0, 1]`.
+         gamma: Focusing parameter :math:`\gamma >= 0`.
+         reduction: Specifies the reduction to apply to the
+             output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
+             will be applied, ``'mean'``: the sum of the output will be divided by
+             the number of elements in the output, ``'sum'``: the output will be
+             summed.
+
+     Shape:
+         - Input: :math:`(N, *)`.
+         - Target: :math:`(N, *)`.
+
+     Examples:
+         >>> kwargs = {"alpha": 0.25, "gamma": 2.0, "reduction": 'mean'}
+         >>> loss = BinaryFocalLossWithLogits(**kwargs)
+         >>> input = torch.randn(1, 3, 5, requires_grad=True)
+         >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(2)
+         >>> output = loss(input, target)
+         >>> output.backward()
+     """
+
+     def __init__(self, alpha: float, gamma: float = 2.0, reduction: str = "none") -> None:
+         super().__init__()
+         self.alpha: float = alpha
+         self.gamma: float = gamma
+         self.reduction: str = reduction
+
+     def forward(self, input_tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         """Forward call computation."""
+         return binary_focal_loss_with_logits(input_tensor, target, self.alpha, self.gamma, self.reduction)
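Usage illustration (hypothetical, not file content from the wheel): a minimal sketch of calling the two criteria above. The module path is taken from the file list; shapes follow the docstrings and the sizes are arbitrary.

import torch

from quadra.losses.classification.focal import BinaryFocalLossWithLogits, FocalLoss

# Multi-class focal loss: logits of shape (N, C, *) with integer labels of shape (N, *).
criterion = FocalLoss(alpha=0.5, gamma=2.0, reduction="mean")
logits = torch.randn(8, 4, 32, 32, requires_grad=True)
labels = torch.randint(0, 4, (8, 32, 32))
criterion(logits, labels).backward()

# Binary variant: raw logits with a float target of matching shape.
binary_criterion = BinaryFocalLossWithLogits(alpha=0.25, gamma=2.0, reduction="mean")
binary_logits = torch.randn(8, 1, 32, 32, requires_grad=True)
binary_targets = torch.randint(0, 2, (8, 1, 32, 32)).float()
binary_loss = binary_criterion(binary_logits, binary_targets)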
quadra/losses/classification/prototypical.py
@@ -0,0 +1,148 @@
+ from __future__ import annotations
+
+ import torch
+ from torch.nn import functional as F
+
+
+ def euclidean_dist(
+     query: torch.Tensor,
+     prototypes: torch.Tensor,
+     sen: bool = True,
+     eps_pos: float = 1.0,
+     eps_neg: float = -1e-7,
+     eps: float = 1e-7,
+ ) -> torch.Tensor:
+     """Compute the euclidean distance between two tensors.
+     SEN dissimilarity from https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123680120.pdf
+     Args:
+         query: features of the network
+         prototypes: prototypes of the centers
+         sen: SEN dissimilarity flag
+         eps_pos: similarity arg
+         eps_neg: similarity arg
+         eps: similarity arg.
+
+     Returns:
+         Euclidean distance
+
+     """
+     # query: (n_classes * n_query) x d
+     # prototypes: n_classes x d
+     n = query.size(0)
+     m = prototypes.size(0)
+     d = query.size(1)
+     if d != prototypes.size(1):
+         raise ValueError("query and prototypes size[1] should be equal")
+
+     if sen:
+         # SEN dissimilarity from https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123680120.pdf
+         norm_query = torch.linalg.norm(query, ord=2, dim=1)  # (n_classes * n_query) X 1
+         norm_prototypes = torch.linalg.norm(prototypes, ord=2, dim=1)  # n_classes X 1
+
+         # We have to compute (||z|| - ||c||)^2 between all query points w.r.t.
+         # all support points
+
+         # Replicate each single query norm value m times
+         norm_query = norm_query.view(-1, 1).unsqueeze(1).expand(n, m, 1)
+         # Replicate all prototypes norm values n times
+         norm_prototypes = norm_prototypes.view(-1, 1).unsqueeze(0).expand(n, m, 1)
+         norm_diff = torch.pow(norm_query - norm_prototypes, 2).squeeze(2)
+         epsilon = torch.full((n, m), eps_neg).type_as(query)
+         if eps_pos != eps_neg:
+             # n_query = n // m
+             # for i in range(m):
+             #     epsilon[i * n_query : (i + 1) * n_query, i] = 1.0
+
+             # Since query points with class i need to have a positive epsilon
+             # whenever they refer to support point with class i and since
+             # query and support points are ordered, we need to set:
+             # the 1st column of the 1st n_query rows to eps_pos
+             # the 2nd column of the 2nd n_query rows to eps_pos
+             # and so on
+             idxs = torch.eye(m, dtype=torch.bool).unsqueeze(1).expand(m, n // m, m).reshape(-1, m)
+             epsilon[idxs] = eps_pos
+         norm_diff = norm_diff * epsilon
+
+     # Replicate each single query point value m times
+     query = query.unsqueeze(1).expand(n, m, d)
+     # Replicate all prototype points values n times
+     prototypes = prototypes.unsqueeze(0).expand(n, m, d)
+
+     norm = torch.pow(query - prototypes, 2).sum(2)
+     if sen:
+         return torch.sqrt(norm + norm_diff + eps)
+
+     return norm
+
+
+ def prototypical_loss(
+     coords: torch.Tensor,
+     target: torch.Tensor,
+     n_support: int,
+     prototypes: torch.Tensor | None = None,
+     sen: bool = True,
+     eps_pos: float = 1.0,
+     eps_neg: float = -1e-7,
+ ):
+     """Prototypical loss implementation.
+
+     Inspired by https://github.com/jakesnell/prototypical-networks/blob/master/protonets/models/few_shot.py
+     Computes the barycentres by averaging the features of n_support
+     samples for each class in target, then computes the distances from each
+     sample's features to each one of the barycentres and the
+     log-probability of each of the n_query samples belonging to one of the
+     current classes c; loss and accuracy are then computed and returned.
+
+     Args:
+         coords: The model output for a batch of samples
+         target: Ground truth for the above batch of samples
+         n_support: Number of samples to take into account when computing
+             barycentres, for each one of the current classes
+         prototypes: if not None, used for classification
+         sen: SEN dissimilarity flag
+         eps_pos: SEN positive similarity arg
+         eps_neg: SEN negative similarity arg
+     """
+     classes = torch.unique(target, sorted=True)
+     n_classes = len(classes)
+     n_query = len(torch.where(target == classes[0])[0]) - n_support
+
+     # Check equality between classes and target with broadcasting:
+     # class_idxs[i, j] = True iff classes[i] == target[j]
+     class_idxs = classes.unsqueeze(1) == target
+     if prototypes is None:
+         # Get the prototypes as the mean of the support points,
+         # ordered by class
+         prototypes = torch.stack([coords[idx_list][:n_support] for idx_list in class_idxs]).mean(1)  # n_classes X d
+     # Get query samples as the points NOT in the support set,
+     # where, after .view(-1, d), one has that
+     # the 1st n_query points refer to class 1
+     # the 2nd n_query points refer to class 2
+     # and so on
+     query_samples = torch.stack([coords[idx_list][n_support:] for idx_list in class_idxs]).view(
+         -1, prototypes.shape[-1]
+     )  # (n_classes * n_query) X d
+     # Get distances, where dists[i, j] is the distance between
+     # query point i and support point j
+     dists = euclidean_dist(
+         query_samples, prototypes, sen=sen, eps_pos=eps_pos, eps_neg=eps_neg
+     )  # (n_classes * n_query) X n_classes
+     log_p_y = F.log_softmax(-dists, dim=1)
+     log_p_y = log_p_y.view(n_classes, n_query, -1)  # n_classes X n_query X n_classes
+
+     target_inds = torch.arange(0, n_classes).view(n_classes, 1, 1)
+     # One solution is to use type_as(coords[0])
+     target_inds = target_inds.type_as(coords)
+     target_inds = target_inds.expand(n_classes, n_query, 1).long()
+
+     # Since we need to backpropagate the log softmax of query points
+     # of class i that refers to support of the same class for every i,
+     # and since query and support are ordered we select:
+     # from the 1st n_query X n_classes the 1st column
+     # from the 2nd n_query X n_classes the 2nd column
+     # and so on
+     loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
+     _, y_hat = log_p_y.max(2)
+     acc_val = y_hat.eq(target_inds.squeeze()).float().mean()
+
+     return loss_val, acc_val, prototypes
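Usage illustration (hypothetical, not file content from the wheel): a toy few-shot episode. Per the implementation above, every class must contribute the same number of samples (n_support + n_query), grouped and ordered by class.

import torch

from quadra.losses.classification.prototypical import prototypical_loss

# 3 classes, each with 5 support + 5 query embeddings of dimension 16.
n_classes, n_support, d = 3, 5, 16
coords = torch.randn(n_classes * 10, d, requires_grad=True)  # model embeddings
target = torch.arange(n_classes).repeat_interleave(10)       # [0]*10 + [1]*10 + [2]*10

loss, acc, prototypes = prototypical_loss(coords, target, n_support=n_support)
loss.backward()

# The returned prototypes can be passed back in to classify a later episode.
loss2, acc2, _ = prototypical_loss(coords, target, n_support=n_support, prototypes=prototypes)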
quadra/losses/ssl/__init__.py
@@ -0,0 +1,17 @@
+ from .barlowtwins import BarlowTwinsLoss
+ from .byol import BYOLRegressionLoss
+ from .dino import DinoDistillationLoss
+ from .idmm import IDMMLoss
+ from .simclr import SimCLRLoss
+ from .simsiam import SimSIAMLoss
+ from .vicreg import VICRegLoss
+
+ __all__ = [
+     "BarlowTwinsLoss",
+     "BYOLRegressionLoss",
+     "IDMMLoss",
+     "SimCLRLoss",
+     "SimSIAMLoss",
+     "VICRegLoss",
+     "DinoDistillationLoss",
+ ]
quadra/losses/ssl/barlowtwins.py
@@ -0,0 +1,47 @@
+ import torch
+
+
+ def barlowtwins_loss(
+     z1: torch.Tensor,
+     z2: torch.Tensor,
+     lambd: float,
+ ) -> torch.Tensor:
+     """BarlowTwins loss described in https://arxiv.org/abs/2103.03230.
+
+     Args:
+         z1: First `augmented` features (i.e. f(T(x))). These are
+             standardized internally as
+             z1_norm = (z1 - z1.mean(0)) / z1.std(0)
+         z2: Second `augmented` features (i.e. f(T(x))). These are
+             standardized internally as
+             z2_norm = (z2 - z2.mean(0)) / z2.std(0)
+         lambd: lambda multiplier for the redundancy term.
+
+     Returns:
+         BarlowTwins loss
+     """
+     z1 = (z1 - z1.mean(0)) / z1.std(0)
+     z2 = (z2 - z2.mean(0)) / z2.std(0)
+     cov = z1.T @ z2
+     cov.div_(z1.size(0))
+     n = cov.size(0)
+     invariance_term = torch.diagonal(cov).add_(-1).pow_(2).sum()
+     off_diag = cov.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
+     redundancy_term = off_diag.pow_(2).sum()
+     return invariance_term + lambd * redundancy_term
+
+
+ class BarlowTwinsLoss(torch.nn.Module):
+     """BarlowTwins loss.
+
+     Args:
+         lambd: lambda of the loss.
+     """
+
+     def __init__(self, lambd: float):
+         super().__init__()
+         self.lambd = lambd
+
+     def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
+         """Compute the BarlowTwins loss."""
+         return barlowtwins_loss(z1, z2, self.lambd)
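Usage illustration (hypothetical, not file content from the wheel): the loss is applied to projector outputs of two augmented views of the same batch; lambd = 5e-3 is the value suggested in the Barlow Twins paper, assumed here.

import torch

from quadra.losses.ssl.barlowtwins import BarlowTwinsLoss

criterion = BarlowTwinsLoss(lambd=5e-3)
z1 = torch.randn(256, 128)  # projector output for view 1 (batch x embedding dim)
z2 = torch.randn(256, 128)  # projector output for view 2
loss = criterion(z1, z2)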
quadra/losses/ssl/byol.py
@@ -0,0 +1,37 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+
+ def byol_regression_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+     """BYOL regression loss.
+     Args:
+         x: first tensor
+         y: second tensor.
+
+     Returns:
+         the BYOL regression loss
+     """
+     x = F.normalize(x, dim=-1)
+     y = F.normalize(y, dim=-1)
+     return 2 - 2 * (x * y).sum(dim=1).mean()
+
+
+ class BYOLRegressionLoss(nn.Module):
+     """BYOL regression loss module."""
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         y: torch.Tensor,
+     ) -> torch.Tensor:
+         """Compute the BYOL regression loss.
+
+         Args:
+             x: First tensor
+             y: Second tensor
+
+         Returns:
+             BYOL regression loss
+         """
+         return byol_regression_loss(x, y)
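Usage illustration (hypothetical, not file content from the wheel): BYOL is typically symmetrized over the two augmented views, comparing each online predictor output against the other view's target-network projection.

import torch

from quadra.losses.ssl.byol import BYOLRegressionLoss

criterion = BYOLRegressionLoss()
p1, p2 = torch.randn(64, 256), torch.randn(64, 256)  # online predictor outputs
t1, t2 = torch.randn(64, 256), torch.randn(64, 256)  # target network projections
loss = criterion(p1, t2) + criterion(p2, t1)         # symmetrized BYOL loss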
quadra/losses/ssl/dino.py
@@ -0,0 +1,129 @@
+ import logging
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ log = logging.getLogger(__name__)
+
+
+ def dino_distillation_loss(
+     student_output: torch.Tensor,
+     teacher_output: torch.Tensor,
+     center_vector: torch.Tensor,
+     teacher_temp: float = 0.04,
+     student_temp: float = 0.1,
+ ) -> torch.Tensor:
+     """Compute the DINO distillation loss.
+
+     Args:
+         student_output: per-crop outputs of the student
+         teacher_output: per-crop outputs of the teacher
+         center_vector: center vector of the teacher distribution
+         teacher_temp: teacher temperature
+         student_temp: student temperature.
+
+     Returns:
+         The computed loss
+     """
+     student_temp = [s / student_temp for s in student_output]
+     teacher_temp = [(t - center_vector) / teacher_temp for t in teacher_output]
+
+     student_sm = [F.log_softmax(s, dim=-1) for s in student_temp]
+     teacher_sm = [F.softmax(t, dim=-1).detach() for t in teacher_temp]
+
+     total_loss = torch.tensor(0.0, device=student_output[0].device)
+     n_loss_terms = torch.tensor(0.0, device=student_output[0].device)
+
+     for t_ix, t in enumerate(teacher_sm):
+         for s_ix, s in enumerate(student_sm):
+             if t_ix == s_ix:
+                 continue
+
+             loss = torch.sum(-t * s, dim=-1)  # (n_samples,)
+             total_loss += loss.mean()  # scalar
+             n_loss_terms += 1
+
+     total_loss /= n_loss_terms
+     return total_loss
+
+
+ class DinoDistillationLoss(nn.Module):
+     """DINO distillation loss module.
+
+     Args:
+         output_dim: output dimension.
+         max_epochs: max epochs.
+         warmup_teacher_temp: warmup teacher temperature.
+         teacher_temp: teacher temperature.
+         warmup_teacher_temp_epochs: number of teacher temperature warmup epochs.
+         student_temp: student temperature.
+         center_momentum: center momentum.
+     """
+
+     def __init__(
+         self,
+         output_dim: int,
+         max_epochs: int,
+         warmup_teacher_temp: float = 0.04,
+         teacher_temp: float = 0.07,
+         warmup_teacher_temp_epochs: int = 30,
+         student_temp: float = 0.1,
+         center_momentum: float = 0.9,
+     ):
+         super().__init__()
+         self.student_temp = student_temp
+         self.center_momentum = center_momentum
+         self.center: torch.Tensor
+         # we apply a warm-up to the teacher temperature because a
+         # temperature that is too high makes the training unstable at the beginning
+
+         if warmup_teacher_temp_epochs >= max_epochs:
+             raise ValueError(
+                 f"Number of warmup epochs ({warmup_teacher_temp_epochs}) must be smaller than max_epochs ({max_epochs})"
+             )
+
+         if warmup_teacher_temp_epochs < 30:
+             log.warning("Warmup teacher epochs is very small (< 30). This may cause instabilities in the training")
+
+         self.teacher_temp_schedule = np.concatenate(
+             (
+                 np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
+                 np.ones(max_epochs - warmup_teacher_temp_epochs) * teacher_temp,
+             )
+         )
+         self.register_buffer("center", torch.zeros(1, output_dim))
+
+     def forward(
+         self,
+         current_epoch: int,
+         student_output: torch.Tensor,
+         teacher_output: torch.Tensor,
+     ) -> torch.Tensor:
+         """Runs forward."""
+         teacher_temp = self.teacher_temp_schedule[current_epoch]
+         loss = dino_distillation_loss(
+             student_output,
+             teacher_output,
+             center_vector=self.center,
+             teacher_temp=teacher_temp,
+             student_temp=self.student_temp,
+         )
+
+         self.update_center(teacher_output)
+         return loss
+
+     @torch.no_grad()
+     def update_center(self, teacher_output: torch.Tensor) -> None:
+         """Update the center of the teacher's output distribution.
+         Args:
+             teacher_output: teacher output.
+
+         Returns:
+             None
+         """
+         # TODO: check if this is correct
+         # torch.cat expects a list of tensors but teacher_output is a tensor
+         batch_center = torch.cat(teacher_output).mean(dim=0, keepdim=True)  # type: ignore[call-overload]
+         self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
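Usage illustration (hypothetical, not file content from the wheel): despite the torch.Tensor annotations, forward iterates over its inputs and torch.cat is applied to teacher_output, so lists of per-crop (batch, output_dim) tensors are assumed here.

import torch

from quadra.losses.ssl.dino import DinoDistillationLoss

criterion = DinoDistillationLoss(output_dim=256, max_epochs=100, warmup_teacher_temp_epochs=30)
student_out = [torch.randn(32, 256) for _ in range(4)]  # e.g. 2 global + 2 local crops
teacher_out = [torch.randn(32, 256) for _ in range(2)]  # global crops only
loss = criterion(current_epoch=0, student_output=student_out, teacher_output=teacher_out)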