autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
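
Most of the churn above is a reorganization rather than new behavior: the optimization package becomes optim (gaining losses, lr, and metrics subpackages), the environment and optimization config groups become env and optim, and presets.py, problem_types.py, and registry.py move under utils. The sketch below maps a few old import paths to their new locations; it is inferred from the rename list only, since the re-exports in the new __init__.py files are not shown in this diff.

# Module moves inferred from the file list above (old path -> new path):
#   autogluon/multimodal/optimization/lit_module.py   -> autogluon/multimodal/optim/lit_module.py
#   autogluon/multimodal/optimization/lr_scheduler.py -> autogluon/multimodal/optim/lr/lr_schedulers.py
#   autogluon/multimodal/optimization/losses.py       -> autogluon/multimodal/optim/losses/ (split per loss)
#   autogluon/multimodal/presets.py                   -> autogluon/multimodal/utils/presets.py
#   autogluon/multimodal/registry.py                  -> autogluon/multimodal/utils/registry.py

# Hypothetical usage under the new layout; the exact re-exported names are not visible here.
from autogluon.multimodal.optim import lit_module  # was: from autogluon.multimodal.optimization import lit_module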

autogluon/multimodal/models/ft_transformer.py
@@ -1,6 +1,7 @@
+import logging
 import os
 import tempfile
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import torch
 from torch import Tensor, nn
@@ -9,6 +10,8 @@ from ..constants import CATEGORICAL, FEATURES, LABEL, LOGITS, NUMERICAL
 from .custom_transformer import CLSToken, Custom_Transformer, _TokenInitialization
 from .utils import init_weights
 
+logger = logging.getLogger(__name__)
+
 
 class CategoricalFeatureTokenizer(nn.Module):
     """
@@ -21,7 +24,7 @@ class CategoricalFeatureTokenizer(nn.Module):
     def __init__(
         self,
         num_categories: List[int],
-        d_token: int,
+        token_dim: int,
         bias: Optional[bool] = True,
         initialization: Optional[str] = "normal",
     ) -> None:
@@ -30,7 +33,7 @@ class CategoricalFeatureTokenizer(nn.Module):
        ----------
        num_categories:
            A list of integers. Each one is the number of categories in one categorical column.
-        d_token:
+        token_dim:
            The size of one token.
        bias:
            If `True`, for each feature, an additional trainable vector will be added to the
@@ -51,21 +54,21 @@ class CategoricalFeatureTokenizer(nn.Module):
         category_offsets = torch.tensor([0] + num_categories[:-1]).cumsum(0)
 
         self.register_buffer("category_offsets", category_offsets, persistent=False)
-        self.embeddings = nn.Embedding(sum(num_categories), d_token)
-        self.bias = nn.Parameter(Tensor(len(num_categories), d_token)) if bias else None
+        self.embeddings = nn.Embedding(sum(num_categories), token_dim)
+        self.bias = nn.Parameter(Tensor(len(num_categories), token_dim)) if bias else None
         initialization_ = _TokenInitialization.from_str(initialization)
 
         for parameter in [self.embeddings.weight, self.bias]:
             if parameter is not None:
-                initialization_.apply(parameter, d_token)
+                initialization_.apply(parameter, token_dim)
 
     @property
-    def n_tokens(self) -> int:
+    def num_tokens(self) -> int:
        """The number of tokens."""
        return len(self.num_categories)
 
     @property
-    def d_token(self) -> int:
+    def token_dim(self) -> int:
        """The size of one token."""
        return self.embeddings.embedding_dim
 
@@ -190,7 +193,7 @@ class NumericalFeatureTokenizer(nn.Module):
     def __init__(
         self,
         in_features: int,
-        d_token: int,
+        token_dim: int,
         bias: Optional[bool] = True,
         initialization: Optional[str] = "normal",
     ):
@@ -199,7 +202,7 @@ class NumericalFeatureTokenizer(nn.Module):
        ----------
        in_features:
            Dimension of input features i.e. the number of continuous (scalar) features
-        d_token:
+        token_dim:
            The size of one token.
        bias:
            If `True`, for each feature, an additional trainable vector will be added to the
@@ -216,19 +219,19 @@ class NumericalFeatureTokenizer(nn.Module):
         super().__init__()
 
         initialization_ = _TokenInitialization.from_str(initialization)
-        self.weight = nn.Parameter(Tensor(in_features, d_token))
-        self.bias = nn.Parameter(Tensor(in_features, d_token)) if bias else None
+        self.weight = nn.Parameter(Tensor(in_features, token_dim))
+        self.bias = nn.Parameter(Tensor(in_features, token_dim)) if bias else None
         for parameter in [self.weight, self.bias]:
             if parameter is not None:
-                initialization_.apply(parameter, d_token)
+                initialization_.apply(parameter, token_dim)
 
     @property
-    def n_tokens(self) -> int:
+    def num_tokens(self) -> int:
        """The number of tokens."""
        return len(self.weight)
 
     @property
-    def d_token(self) -> int:
+    def token_dim(self) -> int:
        """The size of one token."""
        return self.weight.shape[1]
 
@@ -266,7 +269,7 @@ class AutoDis(nn.Module):
         super().__init__()
         self.first_layer = NumericalFeatureTokenizer(
             in_features=in_features,
-            d_token=n_meta_embeddings,
+            token_dim=n_meta_embeddings,
             bias=False,
             initialization="uniform",
         )
@@ -357,7 +360,7 @@ class NumEmbeddings(nn.Module):
         if embedding_arch[0] == "linear":
             layers.append(
                 NumericalFeatureTokenizer(
-                    in_features=in_features, d_token=d_embedding, bias=True, initialization="normal"
+                    in_features=in_features, token_dim=d_embedding, bias=True, initialization="normal"
                 )
             )
         elif embedding_arch[0] == "positional":
@@ -406,13 +409,13 @@ class NumEmbeddings(nn.Module):
         self.layers = nn.Sequential(*layers)
 
     @property
-    def n_tokens(self) -> int:
+    def num_tokens(self) -> int:
        """The number of tokens."""
        y = self.forward(torch.ones(1, self.in_features))
        return y.shape[1]
 
     @property
-    def d_token(self) -> int:
+    def token_dim(self) -> int:
        """The size of one token."""
        y = self.forward(torch.ones(1, self.in_features))
        return y.shape[-1]
@@ -432,7 +435,8 @@ class FT_Transformer(nn.Module):
         self,
         prefix: str,
         num_numerical_columns: int,
-        num_categories: List[int],
+        num_categories: Dict,
+        numerical_fill_values: Dict,
         embedding_arch: List[str],
         token_dim: int,
         hidden_size: Optional[int] = 192,
@@ -441,7 +445,7 @@ class FT_Transformer(nn.Module):
         token_bias: Optional[bool] = True,
         token_initialization: Optional[str] = "normal",
         num_blocks: Optional[int] = 0,
-        attention_n_heads: Optional[int] = 8,
+        attention_num_heads: Optional[int] = 8,
         attention_initialization: Optional[str] = "kaiming",
         attention_normalization: Optional[str] = "layer_norm",
         attention_dropout: Optional[str] = 0.2,
@@ -485,7 +489,7 @@ class FT_Transformer(nn.Module):
            Must be one of `['uniform', 'normal']`.
        num_blocks
            Number of the `FT_Transformer` blocks, which should be non-negative.
-        attention_n_heads
+        attention_num_heads
            Number of attention heads in each `FT_Transformer` block, which should be positive.
        attention_initialization
            Weights initialization scheme for Multi Headed Attention module.
@@ -527,11 +531,11 @@ class FT_Transformer(nn.Module):
        """
 
         super().__init__()
-
+        logger.debug(f"initializing {prefix} (FT_Transformer)")
         assert num_categories or num_numerical_columns > 0, "there must be categorical columns or numerical columns"
-        assert token_dim > 0, "d_token must be positive"
-        assert num_blocks >= 0, "n_blocks must be non-negative"
-        assert attention_n_heads > 0, "attention_n_heads must be positive"
+        assert token_dim > 0, "token_dim must be positive"
+        assert num_blocks >= 0, "num_blocks must be non-negative"
+        assert attention_num_heads > 0, "attention_num_heads must be positive"
         assert token_initialization in ["uniform", "normal"], "initialization must be uniform or normal"
 
         self.prefix = prefix
@@ -544,14 +548,15 @@ class FT_Transformer(nn.Module):
         if num_categories:
             self.num_categories = num_categories
             self.categorical_feature_tokenizer = CategoricalFeatureTokenizer(
-                num_categories=num_categories,
-                d_token=token_dim,
+                num_categories=list(num_categories.values()),
+                token_dim=token_dim,
                 bias=token_bias,
                 initialization=token_initialization,
             )
             self.categorical_adapter = nn.Linear(token_dim, hidden_size)
 
         if num_numerical_columns > 0:
+            self.numerical_fill_values = numerical_fill_values
             self.numerical_feature_tokenizer = NumEmbeddings(
                 in_features=num_numerical_columns,
                 d_embedding=token_dim,
@@ -560,13 +565,13 @@ class FT_Transformer(nn.Module):
             self.numerical_adapter = nn.Linear(token_dim, hidden_size)
 
         self.transformer = Custom_Transformer(
-            d_token=hidden_size,
-            n_blocks=num_blocks,
-            attention_n_heads=attention_n_heads,
+            token_dim=hidden_size,
+            num_blocks=num_blocks,
+            attention_num_heads=attention_num_heads,
             attention_dropout=attention_dropout,
             attention_initialization=attention_initialization,
             attention_normalization=attention_normalization,
-            ffn_d_hidden=ffn_hidden_size,
+            ffn_hidden_size=ffn_hidden_size,
             ffn_dropout=ffn_dropout,
             ffn_activation=ffn_activation,
             ffn_normalization=ffn_normalization,
@@ -574,7 +579,7 @@ class FT_Transformer(nn.Module):
             prenormalization=prenormalization,
             first_prenormalization=first_prenormalization,
             last_layer_query_idx=None,
-            n_tokens=None,
+            num_tokens=None,
             kv_compression_ratio=kv_compression_ratio,
             kv_compression_sharing=kv_compression_sharing,
             head_activation=head_activation,
@@ -594,7 +599,7 @@ class FT_Transformer(nn.Module):
         )
 
         self.cls_token = CLSToken(
-            d_token=hidden_size,
+            token_dim=hidden_size,
             initialization="uniform",
         )
 
@@ -605,7 +610,7 @@ class FT_Transformer(nn.Module):
         self.categorical_adapter.apply(init_weights)
         self.head.apply(init_weights)
         # init transformer backbone from provided checkpoint
-        from autogluon.multimodal.utils.download import download
+        from ..utils.download import download
 
         if pretrained and checkpoint_name:
             if os.path.exists(checkpoint_name):
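
The renames in ft_transformer.py are mechanical (d_token to token_dim, n_tokens to num_tokens, attention_n_heads to attention_num_heads), but FT_Transformer now also takes num_categories as a dict and a new numerical_fill_values dict. A minimal sketch of the categorical tokenizer under the new keyword names, assuming the forward contract is unchanged (it is not shown in this diff):

import torch

from autogluon.multimodal.models.ft_transformer import CategoricalFeatureTokenizer

# Two categorical columns with 3 and 5 categories each; token_dim was called d_token before this diff.
tokenizer = CategoricalFeatureTokenizer(num_categories=[3, 5], token_dim=8, bias=True, initialization="normal")
x = torch.tensor([[0, 2], [1, 4]])      # (batch_size, num_categorical_columns)
tokens = tokenizer(x)                   # expected shape: (2, 2, 8)
print(tokens.shape, tokenizer.num_tokens, tokenizer.token_dim)  # properties renamed from n_tokens / d_token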

autogluon/multimodal/models/fusion/base.py
@@ -4,8 +4,6 @@ from typing import Optional
 
 from torch import nn
 
-from ...constants import AUTOMM, LABEL
-
 logger = logging.getLogger(__name__)
 
 
@@ -18,12 +16,12 @@ class AbstractMultimodalFusionModel(ABC, nn.Module):
         self,
         prefix: str,
         models: list,
-        loss_weight: Optional[float] = None,
+        aux_loss_weight: Optional[float] = None,
     ):
         super().__init__()
 
         self.prefix = prefix
-        self.loss_weight = loss_weight
+        self.aux_loss_weight = aux_loss_weight
         self.model = nn.ModuleList(models)
 
     @property

autogluon/multimodal/models/fusion/fusion_mlp.py
@@ -4,7 +4,19 @@ from typing import List, Optional
 import torch
 from torch import nn
 
-from ...constants import AUTOMM, FEATURES, LABEL, LOGITS, WEIGHT
+from ...constants import (
+    AUG_LOGITS,
+    FEATURES,
+    LABEL,
+    LOGITS,
+    MULTIMODAL_FEATURES,
+    MULTIMODAL_FEATURES_POST_AUG,
+    MULTIMODAL_FEATURES_PRE_AUG,
+    ORI_LOGITS,
+    VAE_MEAN,
+    VAE_VAR,
+    WEIGHT,
+)
 from ..mlp import MLP
 from ..utils import init_weights, run_model
 from .base import AbstractMultimodalFusionModel
@@ -27,9 +39,9 @@ class MultimodalFusionMLP(AbstractMultimodalFusionModel):
         num_classes: int,
         adapt_in_features: Optional[str] = None,
         activation: Optional[str] = "gelu",
-        dropout_prob: Optional[float] = 0.5,
+        dropout: Optional[float] = 0.5,
         normalization: Optional[str] = "layer_norm",
-        loss_weight: Optional[float] = None,
+        aux_loss_weight: Optional[float] = None,
     ):
        """
        Parameters
@@ -56,24 +68,26 @@ class MultimodalFusionMLP(AbstractMultimodalFusionModel):
            dimension 768.
        activation
            Name of activation function.
-        dropout_prob
+        dropout
            Dropout probability.
        normalization
            Name of normalization function.
-        loss_weight
+        aux_loss_weight
            The weight of individual models. For example, if we fuse the features of ViT, CLIP, and BERT,
-            The loss will be computed as "loss = fusion_loss + loss_weight(vit_loss + clip_loss + bert_loss)".
+            The loss will be computed as "loss = fusion_loss + aux_loss_weight(vit_loss + clip_loss + bert_loss)".
            Basically, it supports adding an auxiliary loss for each individual model.
        """
         super().__init__(
             prefix=prefix,
             models=models,
-            loss_weight=loss_weight,
+            aux_loss_weight=aux_loss_weight,
         )
-        logger.debug("initializing MultimodalFusionMLP")
-        if loss_weight is not None:
-            assert loss_weight > 0
+        logger.debug(f"initializing {prefix} (MultimodalFusionMLP)")
+        if aux_loss_weight is not None:
+            assert aux_loss_weight >= 0
+            logger.debug(f"auxiliary loss weight: {aux_loss_weight}")
         self.num_classes = num_classes
+        self.augmenter = None
 
         raw_in_features = [per_model.out_features for per_model in models]
         if adapt_in_features is not None:
@@ -92,6 +106,7 @@ class MultimodalFusionMLP(AbstractMultimodalFusionModel):
             in_features = sum(raw_in_features)
 
         assert len(self.adapter) == len(self.model)
+        self.augmenter_in_features = in_features
 
         fusion_mlp = []
         for per_hidden_features in hidden_features:
@@ -102,7 +117,7 @@ class MultimodalFusionMLP(AbstractMultimodalFusionModel):
                     out_features=per_hidden_features,
                     num_layers=1,
                     activation=activation,
-                    dropout_prob=dropout_prob,
+                    dropout=dropout,
                     normalization=normalization,
                 )
             )
@@ -146,12 +161,16 @@ class MultimodalFusionMLP(AbstractMultimodalFusionModel):
 
        Returns
        -------
-        If "loss_weight" is None, it returns dictionary containing the fusion model's logits and
+        If "aux_loss_weight" is None, it returns dictionary containing the fusion model's logits and
        features. Otherwise, it returns a list of dictionaries collecting all the models' output,
        including the fusion model's.
        """
         multimodal_features = []
         multimodal_logits = []
+        multimodal_features_pre_aug = None
+        multimodal_features_post_aug = None
+        vae_mean = None
+        vae_var = None
         offset = 0
         for per_model, per_adapter in zip(self.model, self.adapter):
             per_model_args = args[offset : offset + len(per_model.input_keys)]
@@ -163,23 +182,68 @@ class MultimodalFusionMLP(AbstractMultimodalFusionModel):
             multimodal_logits.append(per_output[per_model.prefix][LOGITS])
             offset += len(per_model.input_keys)
 
-        features = self.fusion_mlp(torch.cat(multimodal_features, dim=1))
+        # make sure the returned multimodal features contain unimodal encoder features
+        multimodal_features_ret = multimodal_features
+        multimodal_features = torch.cat(multimodal_features, dim=1)
+        batch_size = multimodal_features.shape[0]
+        if self.training and self.augmenter is not None:
+            multimodal_features_pre_aug = multimodal_features.detach().clone()  # [bs, dim]
+            multimodal_features_post_aug, vae_mean, vae_var = self.augmenter(multimodal_features_pre_aug)
+            multimodal_features_post_aug_clone = multimodal_features_post_aug.clone()
+            multimodal_features_post_aug_clone.register_hook(lambda grad: -grad * self.augmenter.adv_weight)
+            multimodal_features = torch.cat([multimodal_features, multimodal_features_post_aug_clone], dim=0)
+
+        features = self.fusion_mlp(multimodal_features)
         logits = self.head(features)
+        ori_logits = logits[:batch_size].detach()  # detach the original logits when computing the consistency loss
+        aug_logits = logits[batch_size:]
+
+        return (
+            features,
+            logits,
+            multimodal_logits,
+            multimodal_features_ret,
+            multimodal_features_pre_aug,
+            multimodal_features_post_aug,
+            ori_logits,
+            aug_logits,
+            vae_mean,
+            vae_var,
+        )
 
-        return features, logits, multimodal_logits
-
-    def get_output_dict(self, features: torch.Tensor, logits: torch.Tensor, multimodal_logits: List[torch.Tensor]):
+    def get_output_dict(
+        self,
+        features: torch.Tensor,
+        logits: torch.Tensor,
+        multimodal_logits: List[torch.Tensor],
+        multimodal_features: List[torch.Tensor],
+        multimodal_features_pre_aug: torch.Tensor,
+        multimodal_features_post_aug: torch.Tensor,
+        ori_logits: torch.Tensor,
+        aug_logits: torch.Tensor,
+        vae_mean: torch.Tensor,
+        vae_var: torch.Tensor,
+    ):
         fusion_output = {
             self.prefix: {
                 LOGITS: logits,
                 FEATURES: features,
+                MULTIMODAL_FEATURES: multimodal_features,
+                MULTIMODAL_FEATURES_PRE_AUG: multimodal_features_pre_aug,
+                MULTIMODAL_FEATURES_POST_AUG: multimodal_features_post_aug,
+                ORI_LOGITS: ori_logits,
+                AUG_LOGITS: aug_logits,
+                VAE_MEAN: vae_mean,
+                VAE_VAR: vae_var,
             }
         }
-        if self.loss_weight is not None:
+        # filter out None
+        fusion_output = {self.prefix: {k: v for k, v in fusion_output[self.prefix].items() if v is not None}}
+        if self.aux_loss_weight is not None:
             output = {}
             for per_model, per_logits in zip(self.model, multimodal_logits):
                 per_output = {per_model.prefix: {}}
-                per_output[per_model.prefix][WEIGHT] = torch.tensor(self.loss_weight).to(per_logits.dtype)
+                per_output[per_model.prefix][WEIGHT] = torch.tensor(self.aux_loss_weight).to(per_logits.dtype)
                 per_output[per_model.prefix][LOGITS] = per_logits
                 output.update(per_output)
             fusion_output[self.prefix].update({WEIGHT: torch.tensor(1.0).to(logits)})
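
The largest behavioral change is in MultimodalFusionMLP.forward: when an augmenter module is attached (presumably the new models/augmenter.py, given the VAE_MEAN/VAE_VAR keys), the detached fused features are perturbed and a backward hook flips and scales the gradients flowing through the augmented copy, so the augmenter is trained adversarially against the fusion head. A standalone sketch of that gradient-reversal trick with made-up names; only register_hook and the sign flip mirror the diff:

import torch
from torch import nn

adv_weight = 0.1
features = torch.randn(4, 16, requires_grad=True)          # stands in for the concatenated unimodal features
augmented = features * 1.05                                 # stands in for self.augmenter(...)
augmented = augmented.clone()
augmented.register_hook(lambda grad: -grad * adv_weight)    # gradients through this branch are reversed and scaled
head = nn.Linear(16, 3)
loss = head(torch.cat([features, augmented], dim=0)).sum()  # doubled batch, as in the new forward()
loss.backward()
print(features.grad.shape)  # (4, 16): the augmented half contributes negated, scaled gradients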

autogluon/multimodal/models/fusion/fusion_ner.py
@@ -5,7 +5,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ...constants import AUTOMM, FEATURES, LABEL, LOGITS, NER_ANNOTATION, NER_TEXT, TOKEN_WORD_MAPPING, WORD_OFFSETS
+from ...constants import FEATURES, LABEL, LOGITS, NER_ANNOTATION, NER_TEXT, TOKEN_WORD_MAPPING, WORD_OFFSETS
 from ..mlp import MLP
 from ..utils import run_model
 from .base import AbstractMultimodalFusionModel

autogluon/multimodal/models/fusion/fusion_transformer.py
@@ -4,7 +4,7 @@ from typing import Optional
 import torch
 from torch import nn
 
-from ...constants import AUTOMM, FEATURES, LABEL, LOGITS, WEIGHT
+from ...constants import FEATURES, LABEL, LOGITS, WEIGHT
 from ..custom_transformer import CLSToken, Custom_Transformer
 from ..utils import init_weights, run_model
 from .base import AbstractMultimodalFusionModel
@@ -25,15 +25,15 @@ class MultimodalFusionTransformer(AbstractMultimodalFusionModel):
         models: list,
         hidden_features: int,
         num_classes: int,
-        n_blocks: Optional[int] = 0,
-        attention_n_heads: Optional[int] = 8,
+        num_blocks: Optional[int] = 0,
+        attention_num_heads: Optional[int] = 8,
         attention_initialization: Optional[str] = "kaiming",
         attention_normalization: Optional[str] = "layer_norm",
         attention_dropout: Optional[str] = 0.2,
         residual_dropout: Optional[str] = 0.0,
         ffn_activation: Optional[str] = "reglu",
         ffn_normalization: Optional[str] = "layer_norm",
-        ffn_d_hidden: Optional[str] = 192,
+        ffn_hidden_size: Optional[str] = 192,
         ffn_dropout: Optional[str] = 0.0,
         prenormalization: Optional[bool] = True,
         first_prenormalization: Optional[bool] = False,
@@ -42,7 +42,7 @@ class MultimodalFusionTransformer(AbstractMultimodalFusionModel):
         head_activation: Optional[str] = "relu",
         head_normalization: Optional[str] = "layer_norm",
         adapt_in_features: Optional[str] = None,
-        loss_weight: Optional[float] = None,
+        aux_loss_weight: Optional[float] = None,
         additive_attention: Optional[bool] = False,
         share_qv_weights: Optional[bool] = False,
     ):
@@ -59,9 +59,9 @@ class MultimodalFusionTransformer(AbstractMultimodalFusionModel):
            feature dimensions.
        num_classes
            The number of classes.
-        n_blocks
+        num_blocks
            Number of the `FT_Transformer` blocks, which should be non-negative.
-        attention_n_heads
+        attention_num_heads
            Number of attention heads in each `FT_Transformer` block, which should be positive.
        attention_dropout
            Dropout ratio for the Multi Headed Attention module.
@@ -71,7 +71,7 @@ class MultimodalFusionTransformer(AbstractMultimodalFusionModel):
            Normalization policy for attention layers. "layer_norm" is a good default.
        residual_dropout
            Dropout ratio for the linear layers in FT_Transformer block.
-        ffn_d_hidden
+        ffn_hidden_size
            Number of the hidden nodes of the linear layers in the Feed-Forward Network module.
        ffn_dropout
            Dropout ratio of the hidden nodes of the linear layers in the Feed-Forward Network module.
@@ -99,9 +99,9 @@ class MultimodalFusionTransformer(AbstractMultimodalFusionModel):
            Adapt all features to the maximum dimension. For example, if three models have
            feature dimensions are [512, 768, 64], it will linearly map all the features to
            dimension 768.
-        loss_weight
+        aux_loss_weight
            The weight of individual models. For example, if we fuse the features of ViT, CLIP, and BERT,
-            The loss will be computed as "loss = fusion_loss + loss_weight(vit_loss + clip_loss + bert_loss)".
+            The loss will be computed as "loss = fusion_loss + aux_loss_weight(vit_loss + clip_loss + bert_loss)".
            Basically, it supports adding an auxiliary loss for each individual model.
        additive_attention
            If 'true' the transformer will use additive attention with linear complexity to sequence length.
@@ -111,11 +111,11 @@ class MultimodalFusionTransformer(AbstractMultimodalFusionModel):
         super().__init__(
             prefix=prefix,
             models=models,
-            loss_weight=loss_weight,
+            aux_loss_weight=aux_loss_weight,
         )
-        logger.debug("initializing MultimodalFusionTransformer")
-        if loss_weight is not None:
-            assert loss_weight > 0
+        logger.debug(f"initializing {prefix} (MultimodalFusionTransformer)")
+        if aux_loss_weight is not None:
+            assert aux_loss_weight >= 0
 
         raw_in_features = [per_model.out_features for per_model in models]
 
@@ -133,13 +133,13 @@ class MultimodalFusionTransformer(AbstractMultimodalFusionModel):
         assert len(self.adapter) == len(self.model)
 
         self.fusion_transformer = Custom_Transformer(
-            d_token=in_features,
-            n_blocks=n_blocks,
-            attention_n_heads=attention_n_heads,
+            token_dim=in_features,
+            num_blocks=num_blocks,
+            attention_num_heads=attention_num_heads,
             attention_dropout=attention_dropout,
             attention_initialization=attention_initialization,
             attention_normalization=attention_normalization,
-            ffn_d_hidden=ffn_d_hidden,
+            ffn_hidden_size=ffn_hidden_size,
             ffn_dropout=ffn_dropout,
             ffn_activation=ffn_activation,
             ffn_normalization=ffn_normalization,
@@ -147,7 +147,7 @@ class MultimodalFusionTransformer(AbstractMultimodalFusionModel):
             prenormalization=prenormalization,
             first_prenormalization=first_prenormalization,
             last_layer_query_idx=None,
-            n_tokens=None,
+            num_tokens=None,
             kv_compression_ratio=kv_compression_ratio,
             kv_compression_sharing=kv_compression_sharing,
             head_activation=head_activation,
@@ -167,7 +167,7 @@ class MultimodalFusionTransformer(AbstractMultimodalFusionModel):
         )
 
         self.cls_token = CLSToken(
-            d_token=in_features,
+            token_dim=in_features,
             initialization="uniform",
         )
 
@@ -196,9 +196,9 @@ class MultimodalFusionTransformer(AbstractMultimodalFusionModel):
             multimodal_feature = torch.unsqueeze(multimodal_feature, dim=1)
             multimodal_features.append(multimodal_feature)
 
-            if self.loss_weight is not None:
+            if self.aux_loss_weight is not None:
                 per_output[per_model.prefix].update(
-                    {WEIGHT: torch.tensor(self.loss_weight).to(multimodal_features[0])}
+                    {WEIGHT: torch.tensor(self.aux_loss_weight).to(multimodal_features[0])}
                 )
             output.update(per_output)
 
@@ -213,7 +213,7 @@ class MultimodalFusionTransformer(AbstractMultimodalFusionModel):
                 FEATURES: features,
             }
         }
-        if self.loss_weight is not None:
+        if self.aux_loss_weight is not None:
            fusion_output[self.prefix].update({WEIGHT: torch.tensor(1.0).to(logits)})
            output.update(fusion_output)
            return output
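
Both fusion models now call the knob aux_loss_weight and allow it to be zero (the assert changed from > 0 to >= 0). The docstring's weighting scheme, written out as a toy computation with made-up loss values:

import torch

aux_loss_weight = 0.1                       # 0 now silences the auxiliary term without failing the assert
fusion_loss = torch.tensor(0.7)
unimodal_losses = [torch.tensor(0.9), torch.tensor(1.1), torch.tensor(0.8)]  # e.g. ViT, CLIP, BERT heads
total_loss = fusion_loss + aux_loss_weight * sum(unimodal_losses)
print(total_loss)  # tensor(0.9800)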

autogluon/multimodal/models/{huggingface_text.py → hf_text.py}
@@ -7,7 +7,6 @@ from transformers import logging as hf_logging
 from transformers.models.t5 import T5PreTrainedModel
 
 from ..constants import (
-    AUTOMM,
     COLUMN,
     COLUMN_FEATURES,
     FEATURES,
@@ -24,6 +23,8 @@ from .utils import (
     get_column_features,
     get_hf_config_and_model,
     get_pretrained_tokenizer,
+    get_text_segment_num,
+    get_text_token_max_len,
     init_weights,
 )
 
@@ -48,6 +49,8 @@ class HFAutoModelForTextPrediction(nn.Module):
         low_cpu_mem_usage: Optional[bool] = False,
         pretrained: Optional[bool] = True,
         tokenizer_name: Optional[str] = "hf_auto",
+        max_text_len: Optional[int] = None,
+        text_segment_num: Optional[int] = 1,
         use_fast: Optional[bool] = True,
     ):
        """
@@ -82,13 +85,18 @@ class HFAutoModelForTextPrediction(nn.Module):
            Whether using the pretrained weights. If pretrained=True, download the pretrained model.
        tokenizer_name
            Name of the huggingface tokenizer type.
+        max_text_len
+            The maximum length of text tokens.
+        text_segment_num
+            The number of text segments.
        use_fast
            Use a fast Rust-based tokenizer if it is supported for a given model.
            If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
            See: https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoTokenizer.from_pretrained.use_fast
        """
         super().__init__()
-        logger.debug(f"initializing {checkpoint_name}")
+        logger.debug(f"initializing {prefix} (HFAutoModelForTextPrediction)")
+        logger.debug(f"model checkpoint: {checkpoint_name}")
         self.checkpoint_name = checkpoint_name
         self.num_classes = num_classes
 
@@ -101,6 +109,17 @@ class HFAutoModelForTextPrediction(nn.Module):
             checkpoint_name=self.checkpoint_name,
             use_fast=use_fast,
         )
+        self.max_text_len = get_text_token_max_len(
+            provided_max_len=max_text_len,
+            config=self.config,
+            tokenizer=self.tokenizer,
+            checkpoint_name=self.checkpoint_name,
+        )
+        self.text_segment_num = get_text_segment_num(
+            config=self.config,
+            provided_segment_num=text_segment_num,
+            checkpoint_name=self.checkpoint_name,
+        )
 
         if isinstance(self.model, T5PreTrainedModel):
             self.is_t5 = True
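
HFAutoModelForTextPrediction gains two constructor knobs, max_text_len and text_segment_num, which are resolved against the checkpoint's config and tokenizer via the new get_text_token_max_len and get_text_segment_num helpers. A minimal sketch of constructing the model directly with the new arguments (checkpoint name and values chosen only for illustration; the checkpoint is downloaded when pretrained=True):

from autogluon.multimodal.models.hf_text import HFAutoModelForTextPrediction

model = HFAutoModelForTextPrediction(
    prefix="hf_text",
    checkpoint_name="prajjwal1/bert-tiny",  # small public checkpoint, illustration only
    num_classes=2,
    max_text_len=256,       # new: requested token budget, reconciled with the checkpoint's own limit
    text_segment_num=2,     # new: number of token-type segments to use
)
print(model.max_text_len, model.text_segment_num)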