autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
@@ -1,394 +0,0 @@
1
- import logging
2
- from typing import Optional
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
- try:
9
- import torch.distributed.nn
10
- from torch import distributed as dist
11
-
12
- has_distributed = True
13
- except ImportError:
14
- has_distributed = False
15
-
16
- try:
17
- import horovod.torch as hvd
18
- except ImportError:
19
- hvd = None
20
-
21
- logger = logging.getLogger(__name__)
22
-
23
-
24
class RKDLoss(nn.Module):
    """
    Compute the Relational Knowledge Distillation (RKD) loss.

    The loss combines a distance term (pairwise Euclidean distances, normalized by their
    mean positive value) and an angle term (cosines between pairwise difference vectors),
    each compared between student and teacher with smooth-L1.

    Paper Refer to: Relational Knowledge Distillation, CVPR2019. https://arxiv.org/abs/1904.05068
    Code Refer to: https://github.com/HobbitLong/RepDistiller/blob/master/distiller_zoo/RKD.py
    and https://github.com/lenscloth/RKD/blob/master/metric/loss.py
    """

    def __init__(self, distance_loss_weight: float = 25.0, angle_loss_weight: float = 50.0):
        """
        Parameters
        ----------
        distance_loss_weight
            Weight of the RKD distance loss. The term is skipped entirely when the weight is not positive.
        angle_loss_weight
            Weight of the RKD angle loss. The term is skipped entirely when the weight is not positive.
        """
        super(RKDLoss, self).__init__()
        self.distance_loss_weight = distance_loss_weight
        self.angle_loss_weight = angle_loss_weight

    def forward(self, feature_student: torch.Tensor, feature_teacher: torch.Tensor):
        """
        Parameters
        ----------
        feature_student
            Output feature of student model, shape: (N, D)
        feature_teacher
            Output feature of teacher model, shape: (N, D)

        Returns
        -------
        The weighted sum of the enabled RKD terms. This is a tensor, or the int ``0`` when
        both weights are non-positive.
        """
        loss = 0

        # RKD distance loss; teacher targets carry no gradient.
        if self.distance_loss_weight > 0:
            with torch.no_grad():
                t_dist = self._normalize_by_mean_positive(self.pdist(feature_teacher, squared=False))
            s_dist = self._normalize_by_mean_positive(self.pdist(feature_student, squared=False))
            loss = loss + self.distance_loss_weight * F.smooth_l1_loss(s_dist, t_dist)

        # RKD angle loss; teacher targets carry no gradient.
        if self.angle_loss_weight > 0:
            with torch.no_grad():
                t_angle = self._pairwise_angles(feature_teacher)
            s_angle = self._pairwise_angles(feature_student)
            loss = loss + self.angle_loss_weight * F.smooth_l1_loss(s_angle, t_angle)

        return loss

    @staticmethod
    def _normalize_by_mean_positive(dist: torch.Tensor) -> torch.Tensor:
        """Divide a distance matrix by the mean of its positive entries, if any exist."""
        positive = dist[dist > 0]
        # With a single sample only the (zeroed) diagonal exists, so there is nothing to
        # average; dividing by the NaN mean of an empty tensor would poison the loss.
        if positive.numel() == 0:
            return dist
        return dist / positive.mean()

    @staticmethod
    def _pairwise_angles(features: torch.Tensor) -> torch.Tensor:
        """Flattened cosines between all pairs of pairwise difference vectors, shape (N*N*N,)."""
        diff = features.unsqueeze(0) - features.unsqueeze(1)  # (N, N, D)
        unit = F.normalize(diff, p=2, dim=2)
        return torch.bmm(unit, unit.transpose(1, 2)).view(-1)

    @staticmethod
    def pdist(embeddings: torch.Tensor, squared: bool = False, eps: float = 1e-12):
        """
        Compute pairwise Euclidean distances between embeddings in n-dimensional space.

        Parameters
        ----------
        embeddings
            The embeddings to compute pairwise distance between. Shape: (N,D)
        squared
            If the result is square of Euclidean distance.
        eps
            Min value of each entry (applied before the square root, keeping sqrt differentiable).

        Returns
        -------
        Pairwise Euclidean distances with a zero diagonal. Shape: (N,N)
        """
        e_square = embeddings.pow(2).sum(dim=1)
        prod = embeddings @ embeddings.t()
        res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)

        if not squared:
            res = res.sqrt()

        # Clone before the in-place diagonal write so autograd keeps the sqrt output intact.
        res = res.clone()
        res[range(len(embeddings)), range(len(embeddings))] = 0

        return res
120
-
121
-
122
class SoftTargetCrossEntropy(nn.Module):
    """
    Cross-entropy against soft (probability-distribution) targets, adapted from timm.
    https://github.com/rwightman/pytorch-image-models/blob/e4360e6125bb0bb4279785810c8eb33b40af3ebd/timm/loss/cross_entropy.py

    Useful with mixup, where targets are blended one-hot vectors; with a plain one-hot
    target it reduces to ordinary cross-entropy.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the batch mean of ``-sum(target * log_softmax(input))`` over the last dim."""
        log_probs = F.log_softmax(input, dim=-1)
        per_sample = (-target * log_probs).sum(dim=-1)
        return per_sample.mean()
136
-
137
-
138
def gather_features(
    image_features, text_features, local_loss=False, gather_with_grad=False, rank=0, world_size=1, use_horovod=False
):
    """
    Collect the image and text features of every process and concatenate them along dim 0.

    Parameters
    ----------
    image_features
        Image features held by this process.
    text_features
        Text features held by this process.
    local_loss
        When False and gathering without gradients, this process's own slice is put back
        into the gathered result so that it still receives gradients.
    gather_with_grad
        Gather through autograd-aware collectives so gradients flow to every slice.
    rank
        Index of this process, in ``[0, world_size)``.
    world_size
        Number of participating processes.
    use_horovod
        Use horovod collectives instead of ``torch.distributed``.

    Returns
    -------
    A ``(all_image_features, all_text_features)`` pair gathered from every process.
    """
    assert has_distributed, "torch.distributed did not import correctly, please use a PyTorch version with support."

    if not use_horovod:
        if gather_with_grad:
            image_parts = torch.distributed.nn.all_gather(image_features)
            text_parts = torch.distributed.nn.all_gather(text_features)
        else:
            image_parts = [torch.zeros_like(image_features) for _ in range(world_size)]
            text_parts = [torch.zeros_like(text_features) for _ in range(world_size)]
            dist.all_gather(image_parts, image_features)
            dist.all_gather(text_parts, text_features)
            if not local_loss:
                # dist.all_gather output has no grad; reinstate the local tensors for this rank.
                image_parts[rank] = image_features
                text_parts[rank] = text_features
        return torch.cat(image_parts, dim=0), torch.cat(text_parts, dim=0)

    assert hvd is not None, "Please install horovod"
    if gather_with_grad:
        return hvd.allgather(image_features), hvd.allgather(text_features)

    with torch.no_grad():
        gathered_image = hvd.allgather(image_features)
        gathered_text = hvd.allgather(text_features)
    if local_loss:
        return gathered_image, gathered_text

    # The no-grad gather dropped gradients; splice the local tensors back in at this rank.
    image_parts = list(gathered_image.chunk(world_size, dim=0))
    text_parts = list(gathered_text.chunk(world_size, dim=0))
    image_parts[rank] = image_features
    text_parts[rank] = text_features
    return torch.cat(image_parts, dim=0), torch.cat(text_parts, dim=0)
201
-
202
-
203
class MultiNegativesSoftmaxLoss(nn.Module):
    """
    Symmetric in-batch contrastive loss over paired features.

    The input is a batch of pairs (a_1, p_1), ..., (a_n, p_n). Each (a_i, p_i) is treated
    as a positive pair and every (a_i, p_j) with i != j as a negative. The loss is the
    softmax cross-entropy of the scaled similarity matrix, averaged over both directions
    (a -> p and p -> a). When running with several processes, negatives can be gathered
    across all of them.
    """

    def __init__(
        self,
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        use_horovod=False,
    ):
        """
        Parameters
        ----------
        local_loss
            Score only this process's samples against the gathered features.
        gather_with_grad
            Gather features from other processes with gradients enabled.
        cache_labels
            Reuse the ground-truth label tensor across calls while the logit count is unchanged.
        use_horovod
            Use horovod for cross-process gathering.
        """
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.use_horovod = use_horovod

        # Label cache, keyed by device, valid for `prev_num_logits` rows.
        self.prev_num_logits = 0
        self.labels = {}

    def _labels_for(self, num_logits, device, rank, world_size):
        """Return target class indices for each row of the logits, using/filling the cache."""
        if self.prev_num_logits == num_logits and device in self.labels:
            return self.labels[device]

        targets = torch.arange(num_logits, device=device, dtype=torch.long)
        if world_size > 1 and self.local_loss:
            # Local rows sit at an offset inside the gathered columns.
            targets = targets + num_logits * rank
        if self.cache_labels:
            self.labels[device] = targets
            self.prev_num_logits = num_logits
        return targets

    def forward(self, features_a, features_b, logit_scale, rank=0, world_size=1):
        """
        Parameters
        ----------
        features_a, features_b
            Paired feature matrices; row i of each forms a positive pair
            (presumably (batch, dim) — used with ``@`` and ``.T``).
        logit_scale
            Multiplier applied to the similarity scores.
        rank, world_size
            Position of this process and number of processes; gathering only happens when world_size > 1.

        Returns
        -------
        The mean of the a->b and b->a cross-entropy losses.
        """
        device = features_a.device

        if world_size <= 1:
            logits_per_a = logit_scale * features_a @ features_b.T
            logits_per_b = logit_scale * features_b @ features_a.T
        else:
            gathered_a, gathered_b = gather_features(
                features_a, features_b, self.local_loss, self.gather_with_grad, rank, world_size, self.use_horovod
            )
            if self.local_loss:
                logits_per_a = logit_scale * features_a @ gathered_b.T
                logits_per_b = logit_scale * features_b @ gathered_a.T
            else:
                logits_per_a = logit_scale * gathered_a @ gathered_b.T
                logits_per_b = logits_per_a.T

        labels = self._labels_for(logits_per_a.shape[0], device, rank, world_size)

        loss_a = F.cross_entropy(logits_per_a, labels)
        loss_b = F.cross_entropy(logits_per_b, labels)
        return (loss_a + loss_b) / 2
273
-
274
-
275
class FocalLoss(nn.Module):
    """
    Focal loss based on https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py

    loss = (1 - p_t) ** gamma * (-alpha_t * log(p_t)), where p_t is the softmax probability
    of the target class.

    References:
        [1] https://arxiv.org/abs/1708.02002
    """

    def __init__(
        self,
        alpha: Optional[torch.Tensor] = None,
        gamma: Optional[float] = 2.0,
        reduction: Optional[str] = "mean",
        eps: Optional[float] = 1e-6,
    ):
        """

        Parameters
        ----------
        alpha
            Weighting factor for each class, of shape (num_classes). Also accepted as a
            sequence of numbers or as a string such as "(0.25, 0.75)". A string is what
            Ray Tune HPO produces when it samples this hyperparameter.
        gamma
            The focal parameter for calculating weights on easy/hard samples.
        reduction
            The reduction to apply to the final loss output. Default: "mean". Options:
            "mean", "sum". Any other value returns the unreduced per-element loss.
        eps
            Epsilon for numerical stability (stored; not used by ``forward``).

        Raises
        ------
        ValueError
            If ``alpha`` is a string that cannot be parsed into floats.
        """
        super(FocalLoss, self).__init__()

        self.gamma = gamma
        self.reduction = reduction
        self.eps = eps
        if alpha is not None:
            if isinstance(alpha, str):  # handles Ray Tune HPO sampled hyperparameter
                try:
                    alpha = [float(num) for num in alpha.strip("()").split(",")]
                except ValueError as err:
                    raise ValueError(f"{type(alpha)} {alpha} is not in a supported format.") from err
            # Copy tensors the way torch recommends instead of torch.tensor(tensor), which warns.
            alpha = alpha.clone().detach() if torch.is_tensor(alpha) else torch.tensor(alpha)
        # reduction="none" so the focal term can be applied per element before reducing.
        self.nll_loss = nn.NLLLoss(weight=alpha, reduction="none")

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """
        Parameters
        ----------
        input
            Logits of shape (N, C) or (N, C, d1, ..., dK).
        target
            Class indices of shape (N,) or (N, d1, ..., dK).

        Returns
        -------
        The focal loss, reduced according to ``self.reduction``.

        Raises
        ------
        TypeError
            If ``input`` is not a tensor.
        """
        if not torch.is_tensor(input):
            raise TypeError("input type is not a torch.Tensor. Got {}".format(type(input)))
        if input.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            num_class = input.shape[1]
            input = input.permute(0, *range(2, input.ndim), 1).reshape(-1, num_class)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,); reshape also accepts non-contiguous targets.
            target = target.reshape(-1)

        pt = F.softmax(input, dim=-1)

        # -alpha_t * log(pt) term
        log_p = torch.log_softmax(input, dim=-1)
        ce = self.nll_loss(log_p, target)

        # (1 - pt)^gamma term: pick each row's target-class probability.
        all_rows = torch.arange(input.shape[0], device=input.device)
        pt = pt[all_rows, target]
        focal_term = (1 - pt) ** self.gamma

        loss = focal_term * ce

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss
349
-
350
-
351
class StructureLoss(nn.Module):
    """
    Structure Loss based on https://github.com/DengPingFan/PraNet/blob/master/MyTrain.py
    The loss is the sum of a weighted IoU loss (global restriction) and a weighted binary
    cross entropy (BCE) loss (local, pixel-level restriction). Pixels whose mask value
    differs from its 31x31 local average (i.e. near boundaries) receive up to 6x weight.

    References:
        [1] https://arxiv.org/abs/2006.11392
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """
        Parameters
        ----------
        input
            Logits of shape (N, 1, H, W) or (N, H, W).
        target
            Mask of shape (N, 1, H, W) or (N, H, W).

        Returns
        -------
        Scalar mean of (weighted BCE + weighted IoU loss) over batch and channels.
        """
        if input.dim() == 3:
            input = input.unsqueeze(1)
        if target.dim() == 3:
            # Match the channel dim added to input; a 3-D target previously failed the size check.
            target = target.unsqueeze(1)
        weit = 1 + 5 * torch.abs(F.avg_pool2d(target, kernel_size=31, stride=1, padding=15) - target)
        # reduction="none" keeps the per-pixel BCE so it can be weighted. The legacy
        # `reduce="none"` argument is read as a truthy bool and returns a scalar mean,
        # which silently disables the weighting.
        wbce = F.binary_cross_entropy_with_logits(input, target, reduction="none")
        wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

        prob = torch.sigmoid(input)
        inter = ((prob * target) * weit).sum(dim=(2, 3))
        union = ((prob + target) * weit).sum(dim=(2, 3))
        wiou = 1 - (inter + 1) / (union - inter + 1)
        return (wbce + wiou).mean()
372
-
373
-
374
class BBCEWithLogitLoss(nn.Module):
    """
    Class-balanced BCE-with-logits loss, following
    https://github.com/NiFangBaAGe/Explicit-Visual-Prompt/blob/latest_branch/models/segformer.py

    Positive pixels are up-weighted by (#negatives / #positives), and the whole loss is
    scaled by the fraction of positives, #positives / (#positives + #negatives).
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """Return the balanced BCE loss; a 3-D ``input`` gets a channel dim inserted at dim 1."""
        logits = input.unsqueeze(1) if input.dim() == 3 else input

        # Small epsilon keeps the positive count non-zero for all-background targets.
        num_pos = target.sum() + 1e-10
        num_neg = (1.0 - target).sum()
        neg_weight = num_pos / (num_pos + num_neg)

        bce = F.binary_cross_entropy_with_logits(logits, target, pos_weight=num_neg / num_pos)
        return neg_weight * bce